From 204de08a9ac065f32997aab12c34738e0debced2 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Fri, 14 Oct 2022 10:52:56 +0530
Subject: [PATCH] add early stopping

---
 enhancer/cli/train.py | 21 +++++++++++----------
 1 file changed, 11 insertions(+), 10 deletions(-)

diff --git a/enhancer/cli/train.py b/enhancer/cli/train.py
index 49f7b3b..08f4d3e 100644
--- a/enhancer/cli/train.py
+++ b/enhancer/cli/train.py
@@ -4,7 +4,7 @@ from types import MethodType
 import hydra
 from hydra.utils import instantiate
 from omegaconf import DictConfig, OmegaConf
-from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
 from pytorch_lightning.loggers import MLFlowLogger
 from torch.optim.lr_scheduler import ReduceLROnPlateau
 
@@ -45,15 +45,16 @@ def main(config: DictConfig):
         every_n_epochs=1,
     )
     callbacks.append(checkpoint)
-    # early_stopping = EarlyStopping(
-    #     monitor="val_loss",
-    #     mode=direction,
-    #     min_delta=0.0,
-    #     patience=parameters.get("EarlyStopping_patience", 10),
-    #     strict=True,
-    #     verbose=False,
-    # )
-    # callbacks.append(early_stopping)
+    if parameters.get("Early_stop", False):
+        early_stopping = EarlyStopping(
+            monitor="val_loss",
+            mode=direction,
+            min_delta=0.0,
+            patience=parameters.get("EarlyStopping_patience", 10),
+            strict=True,
+            verbose=False,
+        )
+        callbacks.append(early_stopping)
 
     def configure_optimizer(self):
         optimizer = instantiate(
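
Note (not part of the patch): for reviewers unfamiliar with the callback, below is a minimal, self-contained sketch of how PyTorch Lightning consumes the EarlyStopping callback wired up above. The Trainer, model, and literal argument values are illustrative assumptions; in train.py, `mode` is derived from `direction` and `patience` from the `EarlyStopping_patience` parameter.

# Illustrative sketch only -- not part of this patch.
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

# Stop training once "val_loss" (assumed to be logged by the LightningModule
# during validation) fails to improve for `patience` validation epochs.
early_stopping = EarlyStopping(
    monitor="val_loss",  # metric name; must match what the model logs
    mode="min",          # assumed here; the patch passes `direction` instead
    min_delta=0.0,       # any improvement, however small, resets the counter
    patience=10,         # default the patch uses when EarlyStopping_patience is unset
    strict=True,         # raise if "val_loss" is missing from logged metrics
    verbose=False,
)

trainer = Trainer(max_epochs=100, callbacks=[early_stopping])
# trainer.fit(model, datamodule)  # model and datamodule are hypothetical here

Gating the callback behind `parameters.get("Early_stop", False)` keeps the previous default behavior (no early stopping) for existing configs that do not set the flag.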