diff --git a/enhancer/cli/train.py b/enhancer/cli/train.py index 6d5f182..131db4f 100644 --- a/enhancer/cli/train.py +++ b/enhancer/cli/train.py @@ -4,10 +4,14 @@ from types import MethodType import hydra from hydra.utils import instantiate from omegaconf import DictConfig, OmegaConf -from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint +from pytorch_lightning.callbacks import ( + EarlyStopping, + LearningRateMonitor, + ModelCheckpoint, +) from pytorch_lightning.loggers import MLFlowLogger from torch.optim.lr_scheduler import ReduceLROnPlateau -from torch_audiomentations import BandPassFilter, Compose, Shift +from torch_audiomentations import Compose, Shift os.environ["HYDRA_FULL_ERROR"] = "1" JOB_ID = os.environ.get("SLURM_JOBID", "0") @@ -29,7 +33,6 @@ def main(config: DictConfig): apply_augmentations = Compose( [ Shift(min_shift=0.0, max_shift=1.0, shift_unit="seconds", p=0.5), - BandPassFilter(p=0.5), ] ) @@ -52,6 +55,8 @@ def main(config: DictConfig): every_n_epochs=1, ) callbacks.append(checkpoint) + callbacks.append(LearningRateMonitor(logging_interval="epoch")) + if parameters.get("Early_stop", False): early_stopping = EarlyStopping( monitor="val_loss", @@ -63,11 +68,11 @@ def main(config: DictConfig): ) callbacks.append(early_stopping) - def configure_optimizer(self): + def configure_optimizers(self): optimizer = instantiate( config.optimizer, lr=parameters.get("lr"), - parameters=self.parameters(), + params=self.parameters(), ) scheduler = ReduceLROnPlateau( optimizer=optimizer, @@ -77,9 +82,13 @@ def main(config: DictConfig): min_lr=parameters.get("min_lr", 1e-6), patience=parameters.get("ReduceLr_patience", 3), ) - return {"optimizer": optimizer, "lr_scheduler": scheduler} + return { + "optimizer": optimizer, + "lr_scheduler": scheduler, + "monitor": f'valid_{parameters.get("ReduceLr_monitor", "loss")}', + } - model.configure_parameters = MethodType(configure_optimizer, model) + model.configure_optimizers = MethodType(configure_optimizers, model) trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks) trainer.fit(model)