add early stopping

shahules786 2022-10-14 10:52:56 +05:30
parent 891446f7db
commit 204de08a9a
1 changed file with 11 additions and 10 deletions

@@ -4,7 +4,7 @@ from types import MethodType
 import hydra
 from hydra.utils import instantiate
 from omegaconf import DictConfig, OmegaConf
-from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
 from pytorch_lightning.loggers import MLFlowLogger
 from torch.optim.lr_scheduler import ReduceLROnPlateau
@@ -45,15 +45,16 @@ def main(config: DictConfig):
         every_n_epochs=1,
     )
     callbacks.append(checkpoint)
-    # early_stopping = EarlyStopping(
-    #     monitor="val_loss",
-    #     mode=direction,
-    #     min_delta=0.0,
-    #     patience=parameters.get("EarlyStopping_patience", 10),
-    #     strict=True,
-    #     verbose=False,
-    # )
-    # callbacks.append(early_stopping)
+    if parameters.get("Early_stop", False):
+        early_stopping = EarlyStopping(
+            monitor="val_loss",
+            mode=direction,
+            min_delta=0.0,
+            patience=parameters.get("EarlyStopping_patience", 10),
+            strict=True,
+            verbose=False,
+        )
+        callbacks.append(early_stopping)
 
     def configure_optimizer(self):
         optimizer = instantiate(
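
For context, a minimal sketch of where a callbacks list like this ends up, assuming a plain pytorch_lightning Trainer; the parameters dict, direction value, and Trainer settings below are stand-ins for what the actual script reads from its Hydra config, not code from this repository.

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

# Stand-in values; in the real script these come from the Hydra config.
parameters = {"Early_stop": True, "EarlyStopping_patience": 10}
direction = "min"  # minimize val_loss

callbacks = [ModelCheckpoint(monitor="val_loss", mode=direction, every_n_epochs=1)]

if parameters.get("Early_stop", False):
    callbacks.append(
        EarlyStopping(
            monitor="val_loss",
            mode=direction,
            min_delta=0.0,
            patience=parameters.get("EarlyStopping_patience", 10),
            strict=True,
            verbose=False,
        )
    )

# Lightning invokes these callbacks during fit(); training stops early once
# val_loss fails to improve by at least min_delta for `patience` validation epochs.
trainer = Trainer(max_epochs=100, callbacks=callbacks)
# trainer.fit(model)  # model/datamodule omitted in this sketch

Gating the callback on an Early_stop flag, as the commit does, keeps the previous always-on-checkpointing behaviour as the default while letting a config switch turn early stopping on.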