rename to train
This commit is contained in:
parent
7afe928ee1
commit
9ee809a047
|
|
@ -19,9 +19,9 @@ JOB_ID = os.environ.get("SLURM_JOBID", "0")
|
||||||
|
|
||||||
|
|
||||||
@hydra.main(config_path="train_config", config_name="config")
|
@hydra.main(config_path="train_config", config_name="config")
|
||||||
def main(config: DictConfig):
|
def train(config: DictConfig):
|
||||||
|
|
||||||
OmegaConf.save(config, "config_log.yaml")
|
OmegaConf.save(config, "config.yaml")
|
||||||
|
|
||||||
callbacks = []
|
callbacks = []
|
||||||
logger = MLFlowLogger(
|
logger = MLFlowLogger(
|
||||||
|
|
@ -96,7 +96,7 @@ def main(config: DictConfig):
|
||||||
trainer.test(model)
|
trainer.test(model)
|
||||||
|
|
||||||
logger.experiment.log_artifact(
|
logger.experiment.log_artifact(
|
||||||
logger.run_id, f"{trainer.default_root_dir}/config_log.yaml"
|
logger.run_id, f"{trainer.default_root_dir}/config.yaml"
|
||||||
)
|
)
|
||||||
|
|
||||||
saved_location = os.path.join(
|
saved_location = os.path.join(
|
||||||
|
|
@ -117,4 +117,4 @@ def main(config: DictConfig):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
train()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue