log config to mlflow
This commit is contained in:
parent
e4b2965b45
commit
43a22d2432
|
|
@ -15,8 +15,7 @@ JOB_ID = os.environ.get("SLURM_JOBID", "0")
|
|||
@hydra.main(config_path="train_config", config_name="config")
|
||||
def main(config: DictConfig):
|
||||
|
||||
yaml_conf = OmegaConf.to_yaml(config)
|
||||
OmegaConf.save(yaml_conf, "config_log.yaml")
|
||||
OmegaConf.save(config, "config_log.yaml")
|
||||
|
||||
callbacks = []
|
||||
logger = MLFlowLogger(
|
||||
|
|
@ -24,7 +23,6 @@ def main(config: DictConfig):
|
|||
run_name=config.mlflow.run_name,
|
||||
tags={"JOB_ID": JOB_ID},
|
||||
)
|
||||
logger.experiment.log_artifact(logger.run_id, "config_log.yaml")
|
||||
|
||||
parameters = config.hyperparameters
|
||||
|
||||
|
|
@ -78,6 +76,10 @@ def main(config: DictConfig):
|
|||
trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks)
|
||||
trainer.fit(model)
|
||||
|
||||
logger.experiment.log_artifact(
|
||||
logger.run_id, f"{trainer.default_root_dir}/config_log.yaml"
|
||||
)
|
||||
|
||||
saved_location = os.path.join(
|
||||
trainer.default_root_dir, "model", f"model_{JOB_ID}.ckpt"
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue