From a880125322c9201c5c3059248f130c17001b7971 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Mon, 3 Oct 2022 20:01:19 +0530 Subject: [PATCH] fix logging path --- cli/train.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/cli/train.py b/cli/train.py index cd19ffc..dee3d2e 100644 --- a/cli/train.py +++ b/cli/train.py @@ -1,3 +1,4 @@ +from genericpath import isfile import os from types import MethodType import hydra @@ -7,7 +8,7 @@ from torch.optim.lr_scheduler import ReduceLROnPlateau from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping from pytorch_lightning.loggers import MLFlowLogger os.environ["HYDRA_FULL_ERROR"] = "1" -JOB_ID = os.environ.get("SLURM_JOBID") +JOB_ID = os.environ.get("SLURM_JOBID","0") @hydra.main(config_path="train_config",config_name="config") def main(config: DictConfig): @@ -55,8 +56,10 @@ def main(config: DictConfig): trainer = instantiate(config.trainer,logger=logger,callbacks=callbacks) trainer.fit(model) - if os.path.exists("./model/"): - logger.experiment.log_artifact(logger.run_id,f"model_{JOB_ID}.ckpt") + + saved_location = os.path.join(trainer.default_root_dir,"model",f"model_{JOB_ID}.ckpt") + if os.path.isfile(saved_location): + logger.experiment.log_artifact(logger.run_id,saved_location)