From 5f1ed8c725931d0a59df7f0eb163af5db93665cf Mon Sep 17 00:00:00 2001 From: shahules786 Date: Sat, 22 Oct 2022 11:17:37 +0530 Subject: [PATCH] iterable dataset --- enhancer/data/dataset.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py index c48aded..80055b6 100644 --- a/enhancer/data/dataset.py +++ b/enhancer/data/dataset.py @@ -179,13 +179,11 @@ class TaskDataset(pl.LightningDataModule): return metadata def train_collatefn(self, batch): - names = [] output = {"noisy": [], "clean": []} for item in batch: output["noisy"].append(item["noisy"]) output["clean"].append(item["clean"]) - names.append(item["name"]) - print(names) + output["clean"] = torch.stack(output["clean"], dim=0) output["noisy"] = torch.stack(output["noisy"], dim=0) return output @@ -355,14 +353,7 @@ class EnhancerDataset(TaskDataset): } def train__len__(self): - worker_info = torch.utils.data.get_worker_info() - if worker_info is None: - train_data = self.train_data - else: - train_data = worker_info.dataset.data - len = sum([len(item) for item in train_data]) // (self.batch_size) - print("workers", len) - return len + return sum([len(item) for item in self.train_data]) // (self.batch_size) def val__len__(self): return len(self._validation)