From b1e15c7552ae1551af8d7e26e3bfe596c9a15e7b Mon Sep 17 00:00:00 2001
From: shahules786
Date: Sun, 2 Oct 2022 18:12:21 +0530
Subject: [PATCH 1/4] add lr scheduler

---
 cli/train.py | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/cli/train.py b/cli/train.py
index 391df13..d951f33 100644
--- a/cli/train.py
+++ b/cli/train.py
@@ -1,7 +1,9 @@
 import os
+from types import MethodType
 import hydra
 from hydra.utils import instantiate
 from omegaconf import DictConfig
+from torch.optim.lr_scheduler import ReduceLROnPlateau
 from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
 from pytorch_lightning.loggers import MLFlowLogger
 os.environ["HYDRA_FULL_ERROR"] = "1"
@@ -36,6 +38,20 @@ def main(config: DictConfig):
         verbose=False,
     )
     callbacks.append(early_stopping)
+
+    def configure_optimizer(self):
+        optimizer = instantiate(config.optimizer,lr=parameters.get("lr"),params=self.parameters())
+        scheduler = ReduceLROnPlateau(
+            optimizer=optimizer,
+            mode=direction,
+            factor=0.1,
+            verbose=True,
+            min_lr=parameters.get("min_lr",1e-6),
+            patience=3
+        )
+        return {"optimizer":optimizer, "lr_scheduler":scheduler}
+
+    model.configure_optimizers = MethodType(configure_optimizer,model)
 
     trainer = instantiate(config.trainer,logger=logger,callbacks=callbacks)
     trainer.fit(model)

From c11bea6aa0a5b5df1aaecb0d48a4bbd402f2fa6a Mon Sep 17 00:00:00 2001
From: shahules786
Date: Sun, 2 Oct 2022 18:18:32 +0530
Subject: [PATCH 2/4] add to hyperparameters

---
 cli/train.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/cli/train.py b/cli/train.py
index d951f33..7d49acc 100644
--- a/cli/train.py
+++ b/cli/train.py
@@ -33,7 +33,7 @@ def main(config: DictConfig):
         monitor="val_loss",
         mode=direction,
         min_delta=0.0,
-        patience=10,
+        patience=parameters.get("EarlyStopping_patience",10),
         strict=True,
         verbose=False,
     )
@@ -44,10 +44,10 @@ def main(config: DictConfig):
         scheduler = ReduceLROnPlateau(
             optimizer=optimizer,
             mode=direction,
-            factor=0.1,
+            factor=parameters.get("ReduceLr_factor",0.1),
             verbose=True,
             min_lr=parameters.get("min_lr",1e-6),
-            patience=3
+            patience=parameters.get("ReduceLr_patience",3)
         )
         return {"optimizer":optimizer, "lr_scheduler":scheduler}
 

From 6faffdcb1731b9b0f799a68afaf3f87baba44794 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Sun, 2 Oct 2022 18:18:56 +0530
Subject: [PATCH 3/4] include more params

---
 cli/train_config/hyperparameters/default.yaml | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/cli/train_config/hyperparameters/default.yaml b/cli/train_config/hyperparameters/default.yaml
index 04b099b..49a7f80 100644
--- a/cli/train_config/hyperparameters/default.yaml
+++ b/cli/train_config/hyperparameters/default.yaml
@@ -2,3 +2,8 @@ loss : mse
 metric : mae
 lr : 0.0001
 num_epochs : 100
+ReduceLr_patience : 5
+ReduceLr_factor : 0.1
+min_lr : 0.000001
+EarlyStopping_patience : 10
+

From d31a6d2ebd25e1f7b31aa4aeaf47a849311d926b Mon Sep 17 00:00:00 2001
From: shahules786
Date: Mon, 3 Oct 2022 09:57:51 +0530
Subject: [PATCH 4/4] check file before logging

---
 cli/train.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/cli/train.py b/cli/train.py
index 7d49acc..8a4beaf 100644
--- a/cli/train.py
+++ b/cli/train.py
@@ -25,7 +25,7 @@ def main(config: DictConfig):
     direction = model.valid_monitor
 
     checkpoint = ModelCheckpoint(
-        dirpath="./model",filename=f"model_{JOB_ID}",monitor="val_loss",verbose=False,
+        dirpath="./model",filename=f"model_{JOB_ID}",monitor="val_loss",verbose=True,
         mode=direction,every_n_epochs=1
     )
     callbacks.append(checkpoint)
@@ -55,7 +55,8 @@ def main(config: DictConfig):
     trainer = instantiate(config.trainer,logger=logger,callbacks=callbacks)
     trainer.fit(model)
 
-    logger.experiment.log_artifact(logger.run_id,f"./model/model_{JOB_ID}")
+    if os.path.exists("./model/"):
+        logger.experiment.log_artifacts(logger.run_id,"./model/")
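
The core trick in PATCH 1/4 is worth seeing outside the diff: because `configure_optimizer` closes over the Hydra config, it is defined inside `main()` and then bound onto the already-instantiated LightningModule with `types.MethodType`, so Lightning picks it up as the model's `configure_optimizers` hook. Below is a minimal self-contained sketch of that pattern. `ToyModel`, the `torch.optim.Adam` call, and the literal hyperparameter values are illustrative stand-ins, not code from this repository; Adam stands in for `instantiate(config.optimizer, ...)`.

from types import MethodType

import torch
import pytorch_lightning as pl
from torch.optim.lr_scheduler import ReduceLROnPlateau


class ToyModel(pl.LightningModule):
    """Illustrative stand-in for the project's model; just supplies parameters()."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def forward(self, x):
        return self.layer(x)


def configure_optimizer(self):
    # In the patch this is instantiate(config.optimizer, lr=..., params=self.parameters());
    # a plain torch optimizer keeps the sketch self-contained.
    optimizer = torch.optim.Adam(self.parameters(), lr=1e-4)
    scheduler = ReduceLROnPlateau(
        optimizer, mode="min", factor=0.1, patience=3, min_lr=1e-6
    )
    # ReduceLROnPlateau steps on a monitored metric, so Lightning expects the
    # scheduler wrapped in a dict that names which logged metric to watch.
    return {
        "optimizer": optimizer,
        "lr_scheduler": {"scheduler": scheduler, "monitor": "val_loss"},
    }


model = ToyModel()
# MethodType binds the function so `self` resolves to this instance; Lightning
# then finds and calls it as model.configure_optimizers() during trainer.fit().
model.configure_optimizers = MethodType(configure_optimizer, model)

Binding per-instance rather than subclassing keeps the model class free of any dependency on the Hydra config object, at the cost of the hook being invisible to static inspection of the class.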