add num_workers as arg

This commit is contained in:
shahules786 2022-09-28 22:05:10 +05:30
parent 60b5d00bab
commit 658e4d08a5
1 changed file with 12 additions and 8 deletions

View File

@@ -1,6 +1,4 @@
from dataclasses import dataclass
import glob
import multiprocessing
import math
import os
import pytorch_lightning as pl
@@ -46,7 +44,8 @@ class TaskDataset(pl.LightningDataModule):
duration:float=1.0,
sampling_rate:int=48000,
matching_function = None,
batch_size=32):
batch_size=32,
num_workers:Optional[int]=None):
super().__init__()
self.name = name
@@ -56,6 +55,9 @@ class TaskDataset(pl.LightningDataModule):
self.batch_size = batch_size
self.matching_function = matching_function
self._validation = []
if num_workers is None:
num_workers = multiprocessing.cpu_count()//2
self.num_workers = num_workers
def setup(self, stage: Optional[str] = None):
@@ -85,10 +87,10 @@ class TaskDataset(pl.LightningDataModule):
self._validation.append(({"clean":clean,"noisy":noisy},
start_time))
def train_dataloader(self):
# Builds the training DataLoader over TrainDataset(self) using self.batch_size.
# NOTE(review): this text appears to be a rendered commit diff whose +/- markers
# and indentation were stripped. The two `return` lines below are most likely
# the removed side (hard-coded num_workers=2) and the added side
# (num_workers=self.num_workers) of the same change -- confirm against the
# repository. Read literally as Python, only the first `return` would execute.
return DataLoader(TrainDataset(self), batch_size = self.batch_size,num_workers=2)
return DataLoader(TrainDataset(self), batch_size = self.batch_size,num_workers=self.num_workers)
def val_dataloader(self):
# Builds the validation DataLoader over ValidDataset(self) using self.batch_size.
# NOTE(review): this text appears to be a rendered commit diff whose +/- markers
# and indentation were stripped. The two `return` lines below are most likely
# the removed side (hard-coded num_workers=2) and the added side
# (num_workers=self.num_workers) of the same change -- confirm against the
# repository. Read literally as Python, only the first `return` would execute.
return DataLoader(ValidDataset(self), batch_size = self.batch_size,num_workers=2)
return DataLoader(ValidDataset(self), batch_size = self.batch_size,num_workers=self.num_workers)
class EnhancerDataset(TaskDataset):
"""Dataset object for creating clean-noisy speech enhancement datasets"""
@@ -101,7 +103,8 @@ class EnhancerDataset(TaskDataset):
duration=1.0,
sampling_rate=48000,
matching_function=None,
batch_size=32):
batch_size=32,
num_workers:Optional[int]=None):
super().__init__(
name=name,
@@ -110,7 +113,8 @@ class EnhancerDataset(TaskDataset):
sampling_rate=sampling_rate,
duration=duration,
matching_function = matching_function,
batch_size=batch_size
batch_size=batch_size,
num_workers = num_workers,
)