valid monitor fix
This commit is contained in:
parent
ba271f8a2a
commit
fffdf02b93
11
cli/train.py
11
cli/train.py
|
|
@ -4,7 +4,7 @@ from hydra.utils import instantiate
|
||||||
from omegaconf import DictConfig
|
from omegaconf import DictConfig
|
||||||
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
|
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
|
||||||
from pytorch_lightning.loggers import MLFlowLogger
|
from pytorch_lightning.loggers import MLFlowLogger
|
||||||
|
os.environ["HYDRA_FULL_ERROR"] = "1"
|
||||||
|
|
||||||
@hydra.main(config_path="train_config",config_name="config")
|
@hydra.main(config_path="train_config",config_name="config")
|
||||||
def main(config: DictConfig):
|
def main(config: DictConfig):
|
||||||
|
|
@ -20,14 +20,15 @@ def main(config: DictConfig):
|
||||||
model = instantiate(config.model,dataset=dataset,lr=parameters.get("lr"),
|
model = instantiate(config.model,dataset=dataset,lr=parameters.get("lr"),
|
||||||
loss=parameters.get("loss"), metric = parameters.get("metric"))
|
loss=parameters.get("loss"), metric = parameters.get("metric"))
|
||||||
|
|
||||||
|
direction = model.valid_monitor
|
||||||
checkpoint = ModelCheckpoint(
|
checkpoint = ModelCheckpoint(
|
||||||
dirpath="",filename="model",monitor="valid_loss",verbose=False,
|
dirpath="",filename="model",monitor="val_loss",verbose=False,
|
||||||
mode="min",every_n_epochs=1
|
mode=direction,every_n_epochs=1
|
||||||
)
|
)
|
||||||
callbacks.append(checkpoint)
|
callbacks.append(checkpoint)
|
||||||
early_stopping = EarlyStopping(
|
early_stopping = EarlyStopping(
|
||||||
monitor="valid_loss",
|
monitor="val_loss",
|
||||||
mode="min",
|
mode=direction,
|
||||||
min_delta=0.0,
|
min_delta=0.0,
|
||||||
patience=100,
|
patience=100,
|
||||||
strict=True,
|
strict=True,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue