This commit is contained in:
shahules786 2022-10-22 10:30:27 +05:30
parent c4a27686da
commit 7fa54fc414
1 changed files with 5 additions and 7 deletions

View File

@ -355,14 +355,12 @@ class EnhancerDataset(TaskDataset):
}
def train__len__(self):
if self.num_workers > 1:
num_workers = 2
worker_info = torch.utils.data.get_worker_info()
if worker_info is None:
train_data = self.train_data
else:
num_workers = 1
print("num_workers", num_workers)
return sum([len(item) for item in self.train_data]) // (
self.batch_size * self.num_workers
)
train_data = worker_info.dataset.data
return sum([len(item) for item in train_data]) // (self.batch_size)
def val__len__(self):
return len(self._validation)