diff --git a/mayavoz/cli/train.py b/mayavoz/cli/train.py index c00c024..8f12ea7 100644 --- a/mayavoz/cli/train.py +++ b/mayavoz/cli/train.py @@ -19,9 +19,9 @@ JOB_ID = os.environ.get("SLURM_JOBID", "0") @hydra.main(config_path="train_config", config_name="config") -def main(config: DictConfig): +def train(config: DictConfig): - OmegaConf.save(config, "config_log.yaml") + OmegaConf.save(config, "config.yaml") callbacks = [] logger = MLFlowLogger( @@ -96,7 +96,7 @@ def main(config: DictConfig): trainer.test(model) logger.experiment.log_artifact( - logger.run_id, f"{trainer.default_root_dir}/config_log.yaml" + logger.run_id, f"{trainer.default_root_dir}/config.yaml" ) saved_location = os.path.join( @@ -117,4 +117,4 @@ def main(config: DictConfig): if __name__ == "__main__": - main() + train()