diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py index fb4d04c..a8b3896 100644 --- a/enhancer/data/dataset.py +++ b/enhancer/data/dataset.py @@ -353,8 +353,11 @@ class EnhancerDataset(TaskDataset): } def train__len__(self): - - return sum([len(item) for item in self.train_data]) // self.batch_size + worker_info = torch.utils.data.get_worker_info() + num_workers = worker_info.num_workers if worker_info else 1 + return sum([len(item) for item in self.train_data]) // ( + self.batch_size * num_workers + ) def val__len__(self): return len(self._validation)