diff --git a/enhancer/cli/train.py b/enhancer/cli/train.py index 81fd443..a32c41f 100644 --- a/enhancer/cli/train.py +++ b/enhancer/cli/train.py @@ -15,8 +15,7 @@ JOB_ID = os.environ.get("SLURM_JOBID", "0") @hydra.main(config_path="train_config", config_name="config") def main(config: DictConfig): - yaml_conf = OmegaConf.to_yaml(config) - OmegaConf.save(yaml_conf, "config_log.yaml") + OmegaConf.save(config, "config_log.yaml") callbacks = [] logger = MLFlowLogger( @@ -24,7 +23,6 @@ def main(config: DictConfig): run_name=config.mlflow.run_name, tags={"JOB_ID": JOB_ID}, ) - logger.experiment.log_artifact(logger.run_id, "config_log.yaml") parameters = config.hyperparameters @@ -78,6 +76,10 @@ def main(config: DictConfig): trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks) trainer.fit(model) + logger.experiment.log_artifact( + logger.run_id, f"{trainer.default_root_dir}/config_log.yaml" + ) + saved_location = os.path.join( trainer.default_root_dir, "model", f"model_{JOB_ID}.ckpt" )