fix logging path
This commit is contained in:
parent
ecd47905dd
commit
a880125322
|
|
@ -1,3 +1,4 @@
|
|||
from genericpath import isfile
|
||||
import os
|
||||
from types import MethodType
|
||||
import hydra
|
||||
|
|
@ -7,7 +8,7 @@ from torch.optim.lr_scheduler import ReduceLROnPlateau
|
|||
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
|
||||
from pytorch_lightning.loggers import MLFlowLogger
|
||||
os.environ["HYDRA_FULL_ERROR"] = "1"
|
||||
JOB_ID = os.environ.get("SLURM_JOBID")
|
||||
JOB_ID = os.environ.get("SLURM_JOBID","0")
|
||||
|
||||
@hydra.main(config_path="train_config",config_name="config")
|
||||
def main(config: DictConfig):
|
||||
|
|
@ -55,8 +56,10 @@ def main(config: DictConfig):
|
|||
|
||||
trainer = instantiate(config.trainer,logger=logger,callbacks=callbacks)
|
||||
trainer.fit(model)
|
||||
if os.path.exists("./model/"):
|
||||
logger.experiment.log_artifact(logger.run_id,f"model_{JOB_ID}.ckpt")
|
||||
|
||||
saved_location = os.path.join(trainer.default_root_dir,"model",f"model_{JOB_ID}.ckpt")
|
||||
if os.path.isfile(saved_location):
|
||||
logger.experiment.log_artifact(logger.run_id,saved_location)
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue