Merge branch 'dev' of https://github.com/shahules786/enhancer into dev-hawk

This commit is contained in:
shahules786 2022-10-08 11:04:32 +05:30
commit 14156743f9
1 changed files with 5 additions and 3 deletions

View File

@ -15,8 +15,7 @@ JOB_ID = os.environ.get("SLURM_JOBID", "0")
@hydra.main(config_path="train_config", config_name="config") @hydra.main(config_path="train_config", config_name="config")
def main(config: DictConfig): def main(config: DictConfig):
yaml_conf = OmegaConf.to_yaml(config) OmegaConf.save(config, "config_log.yaml")
OmegaConf.save(yaml_conf, "config_log.yaml")
callbacks = [] callbacks = []
logger = MLFlowLogger( logger = MLFlowLogger(
@ -24,7 +23,6 @@ def main(config: DictConfig):
run_name=config.mlflow.run_name, run_name=config.mlflow.run_name,
tags={"JOB_ID": JOB_ID}, tags={"JOB_ID": JOB_ID},
) )
logger.experiment.log_artifact(logger.run_id, "config_log.yaml")
parameters = config.hyperparameters parameters = config.hyperparameters
@ -78,6 +76,10 @@ def main(config: DictConfig):
trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks) trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks)
trainer.fit(model) trainer.fit(model)
logger.experiment.log_artifact(
logger.run_id, f"{trainer.default_root_dir}/config_log.yaml"
)
saved_location = os.path.join( saved_location = os.path.join(
trainer.default_root_dir, "model", f"model_{JOB_ID}.ckpt" trainer.default_root_dir, "model", f"model_{JOB_ID}.ckpt"
) )