diff --git a/enhancer/cli/train.py b/enhancer/cli/train.py
index 38300fd..de48e64 100644
--- a/enhancer/cli/train.py
+++ b/enhancer/cli/train.py
@@ -4,7 +4,7 @@ from types import MethodType
 import hydra
 from hydra.utils import instantiate
 from omegaconf import DictConfig, OmegaConf
-from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
+from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.loggers import MLFlowLogger
 from torch.optim.lr_scheduler import ReduceLROnPlateau
 
@@ -45,15 +45,15 @@ def main(config: DictConfig):
         every_n_epochs=1,
     )
     callbacks.append(checkpoint)
-    early_stopping = EarlyStopping(
-        monitor="val_loss",
-        mode=direction,
-        min_delta=0.0,
-        patience=parameters.get("EarlyStopping_patience", 10),
-        strict=True,
-        verbose=False,
-    )
-    callbacks.append(early_stopping)
+    # early_stopping = EarlyStopping(
+    #     monitor="val_loss",
+    #     mode=direction,
+    #     min_delta=0.0,
+    #     patience=parameters.get("EarlyStopping_patience", 10),
+    #     strict=True,
+    #     verbose=False,
+    # )
+    # callbacks.append(early_stopping)
 
     def configure_optimizer(self):
         optimizer = instantiate(
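
Rather than leaving the callback commented out, the same effect could be achieved by gating it behind a config flag, so early stopping can be toggled back on without another code change. A minimal sketch of that alternative, reusing the `parameters` dict and `direction` variable already in scope in `main`; the `early_stopping_enabled` key is a hypothetical addition, not part of the existing config:

    from pytorch_lightning.callbacks import EarlyStopping

    # Hypothetical config key; defaults to False, matching this diff's behavior.
    if parameters.get("early_stopping_enabled", False):
        early_stopping = EarlyStopping(
            monitor="val_loss",
            mode=direction,
            min_delta=0.0,
            patience=parameters.get("EarlyStopping_patience", 10),
            strict=True,
            verbose=False,
        )
        callbacks.append(early_stopping)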