From d1bafb3dc637aa51493656f677a15434af17b742 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Tue, 25 Oct 2022 12:43:54 +0530 Subject: [PATCH] add augmentations --- enhancer/cli/train.py | 9 ++++++++- enhancer/data/dataset.py | 1 + 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/enhancer/cli/train.py b/enhancer/cli/train.py index 08f4d3e..6d5f182 100644 --- a/enhancer/cli/train.py +++ b/enhancer/cli/train.py @@ -7,6 +7,7 @@ from omegaconf import DictConfig, OmegaConf from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint from pytorch_lightning.loggers import MLFlowLogger from torch.optim.lr_scheduler import ReduceLROnPlateau +from torch_audiomentations import BandPassFilter, Compose, Shift os.environ["HYDRA_FULL_ERROR"] = "1" JOB_ID = os.environ.get("SLURM_JOBID", "0") @@ -25,8 +26,14 @@ def main(config: DictConfig): ) parameters = config.hyperparameters + apply_augmentations = Compose( + [ + Shift(min_shift=0.0, max_shift=1.0, shift_unit="seconds", p=0.5), + BandPassFilter(p=0.5), + ] + ) - dataset = instantiate(config.dataset) + dataset = instantiate(config.dataset, augmentations=apply_augmentations) model = instantiate( config.model, dataset=dataset, diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py index 1e0ec04..f71d612 100644 --- a/enhancer/data/dataset.py +++ b/enhancer/data/dataset.py @@ -211,6 +211,7 @@ class TaskDataset(pl.LightningDataModule): batch_size=self.batch_size, num_workers=self.num_workers, generator=self.generator, + collate_fn=self.train_collatefn, ) def val_dataloader(self):