add num_workers as arg
This commit is contained in:
parent
60b5d00bab
commit
658e4d08a5
|
|
@ -1,6 +1,4 @@
|
||||||
|
import multiprocessing
|
||||||
from dataclasses import dataclass
|
|
||||||
import glob
|
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
|
|
@ -46,7 +44,8 @@ class TaskDataset(pl.LightningDataModule):
|
||||||
duration:float=1.0,
|
duration:float=1.0,
|
||||||
sampling_rate:int=48000,
|
sampling_rate:int=48000,
|
||||||
matching_function = None,
|
matching_function = None,
|
||||||
batch_size=32):
|
batch_size=32,
|
||||||
|
num_workers:Optional[int]=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.name = name
|
self.name = name
|
||||||
|
|
@ -56,6 +55,9 @@ class TaskDataset(pl.LightningDataModule):
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.matching_function = matching_function
|
self.matching_function = matching_function
|
||||||
self._validation = []
|
self._validation = []
|
||||||
|
if num_workers is None:
|
||||||
|
num_workers = multiprocessing.cpu_count()//2
|
||||||
|
self.num_workers = num_workers
|
||||||
|
|
||||||
def setup(self, stage: Optional[str] = None):
|
def setup(self, stage: Optional[str] = None):
|
||||||
|
|
||||||
|
|
@ -85,10 +87,10 @@ class TaskDataset(pl.LightningDataModule):
|
||||||
self._validation.append(({"clean":clean,"noisy":noisy},
|
self._validation.append(({"clean":clean,"noisy":noisy},
|
||||||
start_time))
|
start_time))
|
||||||
def train_dataloader(self):
|
def train_dataloader(self):
|
||||||
return DataLoader(TrainDataset(self), batch_size = self.batch_size,num_workers=2)
|
return DataLoader(TrainDataset(self), batch_size = self.batch_size,num_workers=self.num_workers)
|
||||||
|
|
||||||
def val_dataloader(self):
|
def val_dataloader(self):
|
||||||
return DataLoader(ValidDataset(self), batch_size = self.batch_size,num_workers=2)
|
return DataLoader(ValidDataset(self), batch_size = self.batch_size,num_workers=self.num_workers)
|
||||||
|
|
||||||
class EnhancerDataset(TaskDataset):
|
class EnhancerDataset(TaskDataset):
|
||||||
"""Dataset object for creating clean-noisy speech enhancement datasets"""
|
"""Dataset object for creating clean-noisy speech enhancement datasets"""
|
||||||
|
|
@ -101,7 +103,8 @@ class EnhancerDataset(TaskDataset):
|
||||||
duration=1.0,
|
duration=1.0,
|
||||||
sampling_rate=48000,
|
sampling_rate=48000,
|
||||||
matching_function=None,
|
matching_function=None,
|
||||||
batch_size=32):
|
batch_size=32,
|
||||||
|
num_workers:Optional[int]=None):
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
name=name,
|
name=name,
|
||||||
|
|
@ -110,7 +113,8 @@ class EnhancerDataset(TaskDataset):
|
||||||
sampling_rate=sampling_rate,
|
sampling_rate=sampling_rate,
|
||||||
duration=duration,
|
duration=duration,
|
||||||
matching_function = matching_function,
|
matching_function = matching_function,
|
||||||
batch_size=batch_size
|
batch_size=batch_size,
|
||||||
|
num_workers = num_workers,
|
||||||
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue