This commit is contained in:
shahules786 2022-10-22 09:57:27 +05:30
parent 8457e1cbe2
commit c4a27686da
1 changed files with 3 additions and 2 deletions

View File

@ -77,6 +77,7 @@ class TaskDataset(pl.LightningDataModule):
if num_workers is None: if num_workers is None:
num_workers = multiprocessing.cpu_count() // 2 num_workers = multiprocessing.cpu_count() // 2
self.num_workers = num_workers self.num_workers = num_workers
print("num_workers-main", self.num_workers)
if valid_minutes > 0.0: if valid_minutes > 0.0:
self.valid_minutes = valid_minutes self.valid_minutes = valid_minutes
else: else:
@ -184,7 +185,6 @@ class TaskDataset(pl.LightningDataModule):
output["noisy"].append(item["noisy"]) output["noisy"].append(item["noisy"])
output["clean"].append(item["clean"]) output["clean"].append(item["clean"])
names.append(item["name"]) names.append(item["name"])
print(names) print(names)
output["clean"] = torch.stack(output["clean"], dim=0) output["clean"] = torch.stack(output["clean"], dim=0)
output["noisy"] = torch.stack(output["noisy"], dim=0) output["noisy"] = torch.stack(output["noisy"], dim=0)
@ -359,8 +359,9 @@ class EnhancerDataset(TaskDataset):
num_workers = 2 num_workers = 2
else: else:
num_workers = 1 num_workers = 1
print("num_workers", num_workers)
return sum([len(item) for item in self.train_data]) // ( return sum([len(item) for item in self.train_data]) // (
self.batch_size * num_workers self.batch_size * self.num_workers
) )
def val__len__(self): def val__len__(self):