diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py index f71d612..34ecb8f 100644 --- a/enhancer/data/dataset.py +++ b/enhancer/data/dataset.py @@ -186,12 +186,14 @@ class TaskDataset(pl.LightningDataModule): output["noisy"] = torch.stack(output["noisy"], dim=0) if self.augmentations is not None: + noise = output["noisy"] - output["clean"] output["clean"] = self.augmentations( output["clean"], sample_rate=self.sampling_rate ) self.augmentations.freeze_parameters() - output["noisy"] = self.augmentations( - output["noisy"], sample_rate=self.sampling_rate + output["noisy"] = ( + self.augmentations(noise, sample_rate=self.sampling_rate) + + output["clean"] ) return output