iterable dataset

This commit is contained in:
shahules786 2022-10-22 11:17:37 +05:30
parent 05e40f84b6
commit 5f1ed8c725
1 changed files with 2 additions and 11 deletions

View File

@ -179,13 +179,11 @@ class TaskDataset(pl.LightningDataModule):
return metadata return metadata
def train_collatefn(self, batch): def train_collatefn(self, batch):
names = []
output = {"noisy": [], "clean": []} output = {"noisy": [], "clean": []}
for item in batch: for item in batch:
output["noisy"].append(item["noisy"]) output["noisy"].append(item["noisy"])
output["clean"].append(item["clean"]) output["clean"].append(item["clean"])
names.append(item["name"])
print(names)
output["clean"] = torch.stack(output["clean"], dim=0) output["clean"] = torch.stack(output["clean"], dim=0)
output["noisy"] = torch.stack(output["noisy"], dim=0) output["noisy"] = torch.stack(output["noisy"], dim=0)
return output return output
@ -355,14 +353,7 @@ class EnhancerDataset(TaskDataset):
} }
def train__len__(self): def train__len__(self):
worker_info = torch.utils.data.get_worker_info() return sum([len(item) for item in self.train_data]) // (self.batch_size)
if worker_info is None:
train_data = self.train_data
else:
train_data = worker_info.dataset.data
len = sum([len(item) for item in train_data]) // (self.batch_size)
print("workers", len)
return len
def val__len__(self): def val__len__(self):
return len(self._validation) return len(self._validation)