add num_workers as arg
This commit is contained in:
parent
60b5d00bab
commit
658e4d08a5
|
|
@ -1,6 +1,4 @@
|
|||
|
||||
from dataclasses import dataclass
|
||||
import glob
|
||||
import multiprocessing
|
||||
import math
|
||||
import os
|
||||
import pytorch_lightning as pl
|
||||
|
|
@ -46,7 +44,8 @@ class TaskDataset(pl.LightningDataModule):
|
|||
duration:float=1.0,
|
||||
sampling_rate:int=48000,
|
||||
matching_function = None,
|
||||
batch_size=32):
|
||||
batch_size=32,
|
||||
num_workers:Optional[int]=None):
|
||||
super().__init__()
|
||||
|
||||
self.name = name
|
||||
|
|
@ -56,6 +55,9 @@ class TaskDataset(pl.LightningDataModule):
|
|||
self.batch_size = batch_size
|
||||
self.matching_function = matching_function
|
||||
self._validation = []
|
||||
if num_workers is None:
|
||||
num_workers = multiprocessing.cpu_count()//2
|
||||
self.num_workers = num_workers
|
||||
|
||||
def setup(self, stage: Optional[str] = None):
|
||||
|
||||
|
|
@ -85,10 +87,10 @@ class TaskDataset(pl.LightningDataModule):
|
|||
self._validation.append(({"clean":clean,"noisy":noisy},
|
||||
start_time))
|
||||
def train_dataloader(self):
|
||||
return DataLoader(TrainDataset(self), batch_size = self.batch_size,num_workers=2)
|
||||
return DataLoader(TrainDataset(self), batch_size = self.batch_size,num_workers=self.num_workers)
|
||||
|
||||
def val_dataloader(self):
|
||||
return DataLoader(ValidDataset(self), batch_size = self.batch_size,num_workers=2)
|
||||
return DataLoader(ValidDataset(self), batch_size = self.batch_size,num_workers=self.num_workers)
|
||||
|
||||
class EnhancerDataset(TaskDataset):
|
||||
"""Dataset object for creating clean-noisy speech enhancement datasets"""
|
||||
|
|
@ -101,7 +103,8 @@ class EnhancerDataset(TaskDataset):
|
|||
duration=1.0,
|
||||
sampling_rate=48000,
|
||||
matching_function=None,
|
||||
batch_size=32):
|
||||
batch_size=32,
|
||||
num_workers:Optional[int]=None):
|
||||
|
||||
super().__init__(
|
||||
name=name,
|
||||
|
|
@ -110,7 +113,8 @@ class EnhancerDataset(TaskDataset):
|
|||
sampling_rate=sampling_rate,
|
||||
duration=duration,
|
||||
matching_function = matching_function,
|
||||
batch_size=batch_size
|
||||
batch_size=batch_size,
|
||||
num_workers = num_workers,
|
||||
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue