From 0d3bfd341210f2f43d9162e1be68a5e389f1de2c Mon Sep 17 00:00:00 2001
From: shahules786
Date: Fri, 21 Oct 2022 11:13:17 +0530
Subject: [PATCH] debug

---
 enhancer/cli/train_config/dataset/Vctk.yaml |  3 ++-
 enhancer/data/dataset.py                    | 14 +++++++++-----
 2 files changed, 11 insertions(+), 6 deletions(-)

diff --git a/enhancer/cli/train_config/dataset/Vctk.yaml b/enhancer/cli/train_config/dataset/Vctk.yaml
index c128404..2ea4018 100644
--- a/enhancer/cli/train_config/dataset/Vctk.yaml
+++ b/enhancer/cli/train_config/dataset/Vctk.yaml
@@ -4,7 +4,8 @@ root_dir : /scratch/c.sistc3/DS_10283_2791
 duration : 4.5
 stride : 0.5
 sampling_rate: 16000
-batch_size: 16
+batch_size: 4
+valid_minutes : 1
 
 files:
   train_clean : clean_trainset_28spk_wav
diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py
index 218cd7c..3722763 100644
--- a/enhancer/data/dataset.py
+++ b/enhancer/data/dataset.py
@@ -145,8 +145,7 @@ class TaskDataset(pl.LightningDataModule):
                     ({"clean": clean, "noisy": noisy}, start)
                 )
             train_data.append(samples_metadata)
-        print(train_data[:10])
-        return train_data
+        return train_data[:25]
 
     @staticmethod
     def get_num_segments(file_duration, duration, stride):
@@ -175,12 +174,14 @@ class TaskDataset(pl.LightningDataModule):
         return metadata
 
     def train_collatefn(self, batch):
-
+        names = []
         output = {"noisy": [], "clean": []}
         for item in batch:
             output["noisy"].append(item["noisy"])
             output["clean"].append(item["clean"])
+            names.append(item["name"])
+        print(names)
 
         output["clean"] = torch.stack(output["clean"], dim=0)
         output["noisy"] = torch.stack(output["noisy"], dim=0)
         return output
@@ -318,7 +319,6 @@ class EnhancerDataset(TaskDataset):
 
     def prepare_segment(self, file_dict: dict, start_time: float):
-        print(file_dict["clean"].split("/")[-1], "->", start_time)
         clean_segment = self.audio(
             file_dict["clean"], offset=start_time, duration=self.duration
         )
@@ -343,7 +343,11 @@ class EnhancerDataset(TaskDataset):
                 ),
             ),
         )
-        return {"clean": clean_segment, "noisy": noisy_segment}
+        return {
+            "clean": clean_segment,
+            "noisy": noisy_segment,
+            "name": file_dict["clean"].split("/")[-1] + "->" + str(start_time),
+        }
 
     def train__len__(self):