diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py index f4e7e4a..98abe8a 100644 --- a/enhancer/data/dataset.py +++ b/enhancer/data/dataset.py @@ -1,6 +1,4 @@ - -from dataclasses import dataclass -import glob +import multiprocessing import math import os import pytorch_lightning as pl @@ -46,7 +44,8 @@ class TaskDataset(pl.LightningDataModule): duration:float=1.0, sampling_rate:int=48000, matching_function = None, - batch_size=32): + batch_size=32, + num_workers:Optional[int]=None): super().__init__() self.name = name @@ -56,6 +55,9 @@ class TaskDataset(pl.LightningDataModule): self.batch_size = batch_size self.matching_function = matching_function self._validation = [] + if num_workers is None: + num_workers = multiprocessing.cpu_count()//2 + self.num_workers = num_workers def setup(self, stage: Optional[str] = None): @@ -85,10 +87,10 @@ class TaskDataset(pl.LightningDataModule): self._validation.append(({"clean":clean,"noisy":noisy}, start_time)) def train_dataloader(self): - return DataLoader(TrainDataset(self), batch_size = self.batch_size,num_workers=2) + return DataLoader(TrainDataset(self), batch_size = self.batch_size,num_workers=self.num_workers) def val_dataloader(self): - return DataLoader(ValidDataset(self), batch_size = self.batch_size,num_workers=2) + return DataLoader(ValidDataset(self), batch_size = self.batch_size,num_workers=self.num_workers) class EnhancerDataset(TaskDataset): """Dataset object for creating clean-noisy speech enhancement datasets""" @@ -101,7 +103,8 @@ class EnhancerDataset(TaskDataset): duration=1.0, sampling_rate=48000, matching_function=None, - batch_size=32): + batch_size=32, + num_workers:Optional[int]=None): super().__init__( name=name, @@ -110,7 +113,8 @@ class EnhancerDataset(TaskDataset): sampling_rate=sampling_rate, duration=duration, matching_function = matching_function, - batch_size=batch_size + batch_size=batch_size, + num_workers = num_workers, )