diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py index ab2f1ce..1e0ec04 100644 --- a/enhancer/data/dataset.py +++ b/enhancer/data/dataset.py @@ -7,6 +7,7 @@ import pytorch_lightning as pl import torch import torch.nn.functional as F from torch.utils.data import DataLoader, Dataset +from torch_audiomentations import Compose from enhancer.data.fileprocessor import Fileprocessor from enhancer.utils import check_files @@ -14,9 +15,6 @@ from enhancer.utils.config import Files from enhancer.utils.io import Audio from enhancer.utils.random import create_unique_rng -# from torch_audiomentations import Compose - - LARGE_NUM = 2147483647 @@ -66,7 +64,7 @@ class TaskDataset(pl.LightningDataModule): matching_function=None, batch_size=32, num_workers: Optional[int] = None, - # augmentations: Optional[Compose] = None, + augmentations: Optional[Compose] = None, ): super().__init__() @@ -86,7 +84,7 @@ class TaskDataset(pl.LightningDataModule): else: raise ValueError("valid_minutes must be greater than 0") - # self.augmentations = augmentations + self.augmentations = augmentations def setup(self, stage: Optional[str] = None): """ @@ -178,7 +176,25 @@ class TaskDataset(pl.LightningDataModule): return metadata def train_collatefn(self, batch): - raise NotImplementedError("Not implemented") + + output = {"clean": [], "noisy": []} + for item in batch: + output["clean"].append(item["clean"]) + output["noisy"].append(item["noisy"]) + + output["clean"] = torch.stack(output["clean"], dim=0) + output["noisy"] = torch.stack(output["noisy"], dim=0) + + if self.augmentations is not None: + output["clean"] = self.augmentations( + output["clean"], sample_rate=self.sampling_rate + ) + self.augmentations.freeze_parameters() + output["noisy"] = self.augmentations( + output["noisy"], sample_rate=self.sampling_rate + ) + + return output @property def generator(self): @@ -251,7 +267,7 @@ class EnhancerDataset(TaskDataset): matching_function=None, batch_size=32, num_workers: Optional[int] = None, - # augmentations: Optional[Compose] = None, + augmentations: Optional[Compose] = None, ): super().__init__( @@ -264,7 +280,7 @@ class EnhancerDataset(TaskDataset): matching_function=matching_function, batch_size=batch_size, num_workers=num_workers, - # augmentations=augmentations, + augmentations=augmentations, ) self.sampling_rate = sampling_rate