This commit is contained in:
shahules786 2022-10-21 11:13:17 +05:30
parent 178a4523ef
commit 0d3bfd3412
2 changed files with 11 additions and 6 deletions

View File

@ -4,7 +4,8 @@ root_dir : /scratch/c.sistc3/DS_10283_2791
duration : 4.5 duration : 4.5
stride : 0.5 stride : 0.5
sampling_rate: 16000 sampling_rate: 16000
batch_size: 16 batch_size: 4
valid_minutes : 1
files: files:
train_clean : clean_trainset_28spk_wav train_clean : clean_trainset_28spk_wav

View File

@ -145,8 +145,7 @@ class TaskDataset(pl.LightningDataModule):
({"clean": clean, "noisy": noisy}, start) ({"clean": clean, "noisy": noisy}, start)
) )
train_data.append(samples_metadata) train_data.append(samples_metadata)
print(train_data[:10]) return train_data[:25]
return train_data
@staticmethod @staticmethod
def get_num_segments(file_duration, duration, stride): def get_num_segments(file_duration, duration, stride):
@ -175,12 +174,14 @@ class TaskDataset(pl.LightningDataModule):
return metadata return metadata
def train_collatefn(self, batch): def train_collatefn(self, batch):
names = []
output = {"noisy": [], "clean": []} output = {"noisy": [], "clean": []}
for item in batch: for item in batch:
output["noisy"].append(item["noisy"]) output["noisy"].append(item["noisy"])
output["clean"].append(item["clean"]) output["clean"].append(item["clean"])
names.append(item["name"])
print(names)
output["clean"] = torch.stack(output["clean"], dim=0) output["clean"] = torch.stack(output["clean"], dim=0)
output["noisy"] = torch.stack(output["noisy"], dim=0) output["noisy"] = torch.stack(output["noisy"], dim=0)
return output return output
@ -318,7 +319,6 @@ class EnhancerDataset(TaskDataset):
def prepare_segment(self, file_dict: dict, start_time: float): def prepare_segment(self, file_dict: dict, start_time: float):
print(file_dict["clean"].split("/")[-1], "->", start_time)
clean_segment = self.audio( clean_segment = self.audio(
file_dict["clean"], offset=start_time, duration=self.duration file_dict["clean"], offset=start_time, duration=self.duration
) )
@ -343,7 +343,11 @@ class EnhancerDataset(TaskDataset):
), ),
), ),
) )
return {"clean": clean_segment, "noisy": noisy_segment} return {
"clean": clean_segment,
"noisy": noisy_segment,
"name": file_dict["clean"].split("/")[-1] + "->" + start_time,
}
def train__len__(self): def train__len__(self):