iterable dataset
This commit is contained in:
parent
05e40f84b6
commit
5f1ed8c725
|
|
@ -179,13 +179,11 @@ class TaskDataset(pl.LightningDataModule):
|
||||||
return metadata
|
return metadata
|
||||||
|
|
||||||
def train_collatefn(self, batch):
|
def train_collatefn(self, batch):
|
||||||
names = []
|
|
||||||
output = {"noisy": [], "clean": []}
|
output = {"noisy": [], "clean": []}
|
||||||
for item in batch:
|
for item in batch:
|
||||||
output["noisy"].append(item["noisy"])
|
output["noisy"].append(item["noisy"])
|
||||||
output["clean"].append(item["clean"])
|
output["clean"].append(item["clean"])
|
||||||
names.append(item["name"])
|
|
||||||
print(names)
|
|
||||||
output["clean"] = torch.stack(output["clean"], dim=0)
|
output["clean"] = torch.stack(output["clean"], dim=0)
|
||||||
output["noisy"] = torch.stack(output["noisy"], dim=0)
|
output["noisy"] = torch.stack(output["noisy"], dim=0)
|
||||||
return output
|
return output
|
||||||
|
|
@ -355,14 +353,7 @@ class EnhancerDataset(TaskDataset):
|
||||||
}
|
}
|
||||||
|
|
||||||
def train__len__(self):
|
def train__len__(self):
|
||||||
worker_info = torch.utils.data.get_worker_info()
|
return sum([len(item) for item in self.train_data]) // (self.batch_size)
|
||||||
if worker_info is None:
|
|
||||||
train_data = self.train_data
|
|
||||||
else:
|
|
||||||
train_data = worker_info.dataset.data
|
|
||||||
len = sum([len(item) for item in train_data]) // (self.batch_size)
|
|
||||||
print("workers", len)
|
|
||||||
return len
|
|
||||||
|
|
||||||
def val__len__(self):
|
def val__len__(self):
|
||||||
return len(self._validation)
|
return len(self._validation)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue