diff --git a/cli/train.py b/cli/train.py
index cd19ffc..dee3d2e 100644
--- a/cli/train.py
+++ b/cli/train.py
@@ -1,3 +1,4 @@
+from genericpath import isfile
 import os
 from types import MethodType
 import hydra
@@ -7,7 +8,7 @@
 from torch.optim.lr_scheduler import ReduceLROnPlateau
 from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
 from pytorch_lightning.loggers import MLFlowLogger
 os.environ["HYDRA_FULL_ERROR"] = "1"
-JOB_ID = os.environ.get("SLURM_JOBID")
+JOB_ID = os.environ.get("SLURM_JOBID","0")
 @hydra.main(config_path="train_config",config_name="config")
 def main(config: DictConfig):
@@ -55,8 +56,10 @@ def main(config: DictConfig):
     trainer = instantiate(config.trainer,logger=logger,callbacks=callbacks)
     trainer.fit(model)
-    if os.path.exists("./model/"):
-        logger.experiment.log_artifact(logger.run_id,f"model_{JOB_ID}.ckpt")
+
+    saved_location = os.path.join(trainer.default_root_dir,"model",f"model_{JOB_ID}.ckpt")
+    if os.path.isfile(saved_location):
+        logger.experiment.log_artifact(logger.run_id,saved_location)
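
For context, a minimal sketch of what the revised artifact-logging step does after training. It assumes `trainer` is the fitted pytorch_lightning Trainer and `logger` is the MLFlowLogger instantiated earlier in main(), and that (as the diff implies) the ModelCheckpoint callback writes the checkpoint to <default_root_dir>/model/model_<JOB_ID>.ckpt; the likely motivation is that Hydra changes the working directory at run time, so CWD-relative paths such as "./model/" are unreliable:

    import os

    # Fall back to "0" so the checkpoint filename is well-formed when the
    # script runs outside a SLURM allocation (SLURM_JOBID unset).
    JOB_ID = os.environ.get("SLURM_JOBID", "0")

    # Build the checkpoint path from trainer.default_root_dir rather than
    # from the current working directory.
    saved_location = os.path.join(trainer.default_root_dir, "model", f"model_{JOB_ID}.ckpt")

    # Upload only if the checkpoint file itself exists, not just its directory.
    if os.path.isfile(saved_location):
        # logger.experiment is an MlflowClient; log_artifact uploads the local
        # file to the MLflow run identified by logger.run_id.
        logger.experiment.log_artifact(logger.run_id, saved_location)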