log model artifacts
This commit is contained in:
parent
10ec1a76c8
commit
d25be3a59d
|
|
@ -5,13 +5,14 @@ from omegaconf import DictConfig
|
||||||
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
|
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
|
||||||
from pytorch_lightning.loggers import MLFlowLogger
|
from pytorch_lightning.loggers import MLFlowLogger
|
||||||
os.environ["HYDRA_FULL_ERROR"] = "1"
|
os.environ["HYDRA_FULL_ERROR"] = "1"
|
||||||
|
JOB_ID = os.environ.get("SLURM_JOBID")
|
||||||
|
|
||||||
@hydra.main(config_path="train_config",config_name="config")
|
@hydra.main(config_path="train_config",config_name="config")
|
||||||
def main(config: DictConfig):
|
def main(config: DictConfig):
|
||||||
|
|
||||||
callbacks = []
|
callbacks = []
|
||||||
logger = MLFlowLogger(experiment_name=config.mlflow.experiment_name,
|
logger = MLFlowLogger(experiment_name=config.mlflow.experiment_name,
|
||||||
run_name=config.mlflow.run_name, tags={"JOB_ID":os.environ.get("SLURM_JOBID")})
|
run_name=config.mlflow.run_name, tags={"JOB_ID":JOB_ID})
|
||||||
|
|
||||||
|
|
||||||
parameters = config.hyperparameters
|
parameters = config.hyperparameters
|
||||||
|
|
@ -22,7 +23,7 @@ def main(config: DictConfig):
|
||||||
|
|
||||||
direction = model.valid_monitor
|
direction = model.valid_monitor
|
||||||
checkpoint = ModelCheckpoint(
|
checkpoint = ModelCheckpoint(
|
||||||
dirpath="",filename="model",monitor="val_loss",verbose=False,
|
dirpath="./model",filename=f"model_{JOB_ID}",monitor="val_loss",verbose=False,
|
||||||
mode=direction,every_n_epochs=1
|
mode=direction,every_n_epochs=1
|
||||||
)
|
)
|
||||||
callbacks.append(checkpoint)
|
callbacks.append(checkpoint)
|
||||||
|
|
@ -30,7 +31,7 @@ def main(config: DictConfig):
|
||||||
monitor="val_loss",
|
monitor="val_loss",
|
||||||
mode=direction,
|
mode=direction,
|
||||||
min_delta=0.0,
|
min_delta=0.0,
|
||||||
patience=100,
|
patience=10,
|
||||||
strict=True,
|
strict=True,
|
||||||
verbose=False,
|
verbose=False,
|
||||||
)
|
)
|
||||||
|
|
@ -38,6 +39,7 @@ def main(config: DictConfig):
|
||||||
|
|
||||||
trainer = instantiate(config.trainer,logger=logger,callbacks=callbacks)
|
trainer = instantiate(config.trainer,logger=logger,callbacks=callbacks)
|
||||||
trainer.fit(model)
|
trainer.fit(model)
|
||||||
|
logger.experiment.log_artifact(logger.run_id,f"./model/model_{JOB_ID}")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue