add early stopping
This commit is contained in:
parent
891446f7db
commit
204de08a9a
|
|
@ -4,7 +4,7 @@ from types import MethodType
|
||||||
import hydra
|
import hydra
|
||||||
from hydra.utils import instantiate
|
from hydra.utils import instantiate
|
||||||
from omegaconf import DictConfig, OmegaConf
|
from omegaconf import DictConfig, OmegaConf
|
||||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
|
||||||
from pytorch_lightning.loggers import MLFlowLogger
|
from pytorch_lightning.loggers import MLFlowLogger
|
||||||
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
||||||
|
|
||||||
|
|
@ -45,15 +45,16 @@ def main(config: DictConfig):
|
||||||
every_n_epochs=1,
|
every_n_epochs=1,
|
||||||
)
|
)
|
||||||
callbacks.append(checkpoint)
|
callbacks.append(checkpoint)
|
||||||
# early_stopping = EarlyStopping(
|
if parameters.get("Early_stop", False):
|
||||||
# monitor="val_loss",
|
early_stopping = EarlyStopping(
|
||||||
# mode=direction,
|
monitor="val_loss",
|
||||||
# min_delta=0.0,
|
mode=direction,
|
||||||
# patience=parameters.get("EarlyStopping_patience", 10),
|
min_delta=0.0,
|
||||||
# strict=True,
|
patience=parameters.get("EarlyStopping_patience", 10),
|
||||||
# verbose=False,
|
strict=True,
|
||||||
# )
|
verbose=False,
|
||||||
# callbacks.append(early_stopping)
|
)
|
||||||
|
callbacks.append(early_stopping)
|
||||||
|
|
||||||
def configure_optimizer(self):
|
def configure_optimizer(self):
|
||||||
optimizer = instantiate(
|
optimizer = instantiate(
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue