From 7fa54fc414f438a1cae888e93a31edb1f604c01e Mon Sep 17 00:00:00 2001 From: shahules786 Date: Sat, 22 Oct 2022 10:30:27 +0530 Subject: [PATCH] debug --- enhancer/data/dataset.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py index 1ae48ed..05dd287 100644 --- a/enhancer/data/dataset.py +++ b/enhancer/data/dataset.py @@ -355,14 +355,12 @@ class EnhancerDataset(TaskDataset): } def train__len__(self): - if self.num_workers > 1: - num_workers = 2 + worker_info = torch.utils.data.get_worker_info() + if worker_info is None: + train_data = self.train_data else: - num_workers = 1 - print("num_workers", num_workers) - return sum([len(item) for item in self.train_data]) // ( - self.batch_size * self.num_workers - ) + train_data = worker_info.dataset.data + return sum([len(item) for item in train_data]) // (self.batch_size) def val__len__(self): return len(self._validation)