diff --git a/enhancer/cli/train.py b/enhancer/cli/train.py index 5562cfd..c00c024 100644 --- a/enhancer/cli/train.py +++ b/enhancer/cli/train.py @@ -11,7 +11,8 @@ from pytorch_lightning.callbacks import ( ) from pytorch_lightning.loggers import MLFlowLogger from torch.optim.lr_scheduler import ReduceLROnPlateau -from torch_audiomentations import Compose, Shift + +# from torch_audiomentations import Compose, Shift os.environ["HYDRA_FULL_ERROR"] = "1" JOB_ID = os.environ.get("SLURM_JOBID", "0") @@ -30,13 +31,13 @@ def main(config: DictConfig): ) parameters = config.hyperparameters - apply_augmentations = Compose( - [ - Shift(min_shift=0.5, max_shift=1.0, shift_unit="seconds", p=0.5), - ] - ) + # apply_augmentations = Compose( + # [ + # Shift(min_shift=0.5, max_shift=1.0, shift_unit="seconds", p=0.5), + # ] + # ) - dataset = instantiate(config.dataset, augmentations=apply_augmentations) + dataset = instantiate(config.dataset, augmentations=None) model = instantiate( config.model, dataset=dataset,