From 9ab996e5c73a89598aa3e695d8d386e65bafdb44 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Tue, 27 Sep 2022 22:26:47 +0530
Subject: [PATCH] add callbacks

---
 cli/train.py | 28 ++++++++++++++++++----------
 1 file changed, 18 insertions(+), 10 deletions(-)

diff --git a/cli/train.py b/cli/train.py
index e6644e8..bc5e287 100644
--- a/cli/train.py
+++ b/cli/train.py
@@ -1,21 +1,14 @@
 import hydra
-from hydra.core.config_store import ConfigStore
 from hydra.utils import instantiate
-
-from omegaconf import DictConfig,OmegaConf
-from pytorch_lightning import Trainer
+from omegaconf import DictConfig
 from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
 from pytorch_lightning.loggers import MLFlowLogger
-import torch
-
-
-from enhancer.models.demucs import Demucs
-from enhancer.data.dataset import EnhancerDataset
 
 
 @hydra.main(config_path="train_config",config_name="config")
 def main(config: DictConfig):
+    callbacks = []
 
     logger = MLFlowLogger(experiment_name=config.mlflow.experiment_name,
                 run_name=config.mlflow.run_name)
 
@@ -26,7 +19,22 @@ def main(config: DictConfig):
     model = instantiate(config.model,dataset=dataset,lr=parameters.get("lr"),
             loss=parameters.get("loss"), metric = parameters.get("metric"))
 
-    trainer = instantiate(config.trainer,logger=logger)
+    checkpoint = ModelCheckpoint(
+        dirpath="",filename="model",monitor=parameters.get("loss"),verbose=False,
+        mode="min",every_n_epochs=1
+    )
+    callbacks.append(checkpoint)
+    early_stopping = EarlyStopping(
+        monitor=parameters.get("loss"),
+        mode="min",
+        min_delta=0.0,
+        patience=100,
+        strict=True,
+        verbose=False,
+    )
+    callbacks.append(early_stopping)
+
+    trainer = instantiate(config.trainer,logger=logger,callbacks=callbacks)
     trainer.fit(model)
 
 