diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py index d4686b5..1ae48ed 100644 --- a/enhancer/data/dataset.py +++ b/enhancer/data/dataset.py @@ -77,6 +77,7 @@ class TaskDataset(pl.LightningDataModule): if num_workers is None: num_workers = multiprocessing.cpu_count() // 2 self.num_workers = num_workers + print("num_workers-main", self.num_workers) if valid_minutes > 0.0: self.valid_minutes = valid_minutes else: @@ -184,7 +185,6 @@ class TaskDataset(pl.LightningDataModule): output["noisy"].append(item["noisy"]) output["clean"].append(item["clean"]) names.append(item["name"]) - print(names) output["clean"] = torch.stack(output["clean"], dim=0) output["noisy"] = torch.stack(output["noisy"], dim=0) @@ -359,8 +359,9 @@ class EnhancerDataset(TaskDataset): num_workers = 2 else: num_workers = 1 + print("num_workers", num_workers) return sum([len(item) for item in self.train_data]) // ( - self.batch_size * num_workers + self.batch_size * self.num_workers ) def val__len__(self):