diff --git a/enhancer/cli/train_config/dataset/Vctk.yaml b/enhancer/cli/train_config/dataset/Vctk.yaml
index c128404..2ea4018 100644
--- a/enhancer/cli/train_config/dataset/Vctk.yaml
+++ b/enhancer/cli/train_config/dataset/Vctk.yaml
@@ -4,7 +4,8 @@ root_dir : /scratch/c.sistc3/DS_10283_2791
 duration : 4.5
 stride : 0.5
 sampling_rate: 16000
-batch_size: 16
+batch_size: 4
+valid_minutes : 1
 
 files:
   train_clean : clean_trainset_28spk_wav
diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py
index 218cd7c..3722763 100644
--- a/enhancer/data/dataset.py
+++ b/enhancer/data/dataset.py
@@ -145,8 +145,7 @@ class TaskDataset(pl.LightningDataModule):
                     ({"clean": clean, "noisy": noisy}, start)
                 )
             train_data.append(samples_metadata)
-        print(train_data[:10])
-        return train_data
+        return train_data[:25]
 
     @staticmethod
     def get_num_segments(file_duration, duration, stride):
@@ -175,12 +174,14 @@ class TaskDataset(pl.LightningDataModule):
         return metadata
 
     def train_collatefn(self, batch):
-
+        names = []
         output = {"noisy": [], "clean": []}
         for item in batch:
             output["noisy"].append(item["noisy"])
             output["clean"].append(item["clean"])
+            names.append(item["name"])
+        print(names)
         output["clean"] = torch.stack(output["clean"], dim=0)
         output["noisy"] = torch.stack(output["noisy"], dim=0)
         return output
 
@@ -318,7 +319,6 @@ class EnhancerDataset(TaskDataset):
 
     def prepare_segment(self, file_dict: dict, start_time: float):
 
-        print(file_dict["clean"].split("/")[-1], "->", start_time)
         clean_segment = self.audio(
             file_dict["clean"], offset=start_time, duration=self.duration
         )
@@ -343,7 +343,11 @@ class EnhancerDataset(TaskDataset):
                 ),
             ),
         )
-        return {"clean": clean_segment, "noisy": noisy_segment}
+        return {
+            "clean": clean_segment,
+            "noisy": noisy_segment,
+            "name": file_dict["clean"].split("/")[-1] + "->" + str(start_time),
+        }
 
     def train__len__(self):