fix augmentation

This commit is contained in:
shahules786 2022-10-25 15:10:13 +05:30
parent b070613b64
commit 4acad6ede8
1 changed files with 4 additions and 2 deletions

View File

@ -186,12 +186,14 @@ class TaskDataset(pl.LightningDataModule):
output["noisy"] = torch.stack(output["noisy"], dim=0) output["noisy"] = torch.stack(output["noisy"], dim=0)
if self.augmentations is not None: if self.augmentations is not None:
noise = output["noisy"] - output["clean"]
output["clean"] = self.augmentations( output["clean"] = self.augmentations(
output["clean"], sample_rate=self.sampling_rate output["clean"], sample_rate=self.sampling_rate
) )
self.augmentations.freeze_parameters() self.augmentations.freeze_parameters()
output["noisy"] = self.augmentations( output["noisy"] = (
output["noisy"], sample_rate=self.sampling_rate self.augmentations(noise, sample_rate=self.sampling_rate)
+ output["clean"]
) )
return output return output