fix logging path

This commit is contained in:
shahules786 2022-10-03 20:01:19 +05:30
parent ecd47905dd
commit a880125322
1 changed files with 6 additions and 3 deletions

View File

@ -1,3 +1,4 @@
from genericpath import isfile
import os
from types import MethodType
import hydra
@ -7,7 +8,7 @@ from torch.optim.lr_scheduler import ReduceLROnPlateau
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import MLFlowLogger
os.environ["HYDRA_FULL_ERROR"] = "1"
JOB_ID = os.environ.get("SLURM_JOBID")
JOB_ID = os.environ.get("SLURM_JOBID","0")
@hydra.main(config_path="train_config",config_name="config")
def main(config: DictConfig):
@ -55,8 +56,10 @@ def main(config: DictConfig):
trainer = instantiate(config.trainer,logger=logger,callbacks=callbacks)
trainer.fit(model)
if os.path.exists("./model/"):
logger.experiment.log_artifact(logger.run_id,f"model_{JOB_ID}.ckpt")
saved_location = os.path.join(trainer.default_root_dir,"model",f"model_{JOB_ID}.ckpt")
if os.path.isfile(saved_location):
logger.experiment.log_artifact(logger.run_id,saved_location)