fix logging path

parent ecd47905dd
commit a880125322
@@ -1,3 +1,4 @@
+from genericpath import isfile
 import os
 from types import MethodType
 import hydra
@@ -7,7 +8,7 @@ from torch.optim.lr_scheduler import ReduceLROnPlateau
 from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
 from pytorch_lightning.loggers import MLFlowLogger
 os.environ["HYDRA_FULL_ERROR"] = "1"
-JOB_ID = os.environ.get("SLURM_JOBID")
+JOB_ID = os.environ.get("SLURM_JOBID","0")

 @hydra.main(config_path="train_config",config_name="config")
 def main(config: DictConfig):
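The only functional change in this hunk is the "0" fallback for SLURM_JOBID. A minimal sketch of its effect, assuming the script is also launched outside a SLURM allocation (the names mirror the diff; the print line is illustrative only):

import os

# Without a fallback, os.environ.get("SLURM_JOBID") returns None outside SLURM,
# and the checkpoint name used further down becomes "model_None.ckpt".
JOB_ID = os.environ.get("SLURM_JOBID", "0")

# With the "0" default the filename stays well-formed for local runs.
print(f"model_{JOB_ID}.ckpt")  # -> "model_0.ckpt" when SLURM_JOBID is unset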
@@ -55,8 +56,10 @@ def main(config: DictConfig):
     trainer = instantiate(config.trainer,logger=logger,callbacks=callbacks)
     trainer.fit(model)
-    if os.path.exists("./model/"):
-        logger.experiment.log_artifact(logger.run_id,f"model_{JOB_ID}.ckpt")
+    saved_location = os.path.join(trainer.default_root_dir,"model",f"model_{JOB_ID}.ckpt")
+    if os.path.isfile(saved_location):
+        logger.experiment.log_artifact(logger.run_id,saved_location)
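For context, a self-contained sketch of the corrected upload step added above. It assumes PyTorch Lightning's MLFlowLogger, whose experiment property is an MlflowClient, so log_artifact(run_id, local_path) uploads a local file to the logger's run; the helper name upload_checkpoint is illustrative and not part of the original script. The point of the fix is that the checkpoint path is now built under the trainer's root directory and verified with os.path.isfile, instead of passing a bare filename relative to the current working directory (which Hydra typically changes at run time).

import os

from pytorch_lightning.loggers import MLFlowLogger

def upload_checkpoint(logger: MLFlowLogger, root_dir: str, job_id: str) -> None:
    # Absolute checkpoint path under the trainer's root directory instead of a
    # bare filename relative to the (Hydra-managed) working directory.
    saved_location = os.path.join(root_dir, "model", f"model_{job_id}.ckpt")
    # Only upload if the checkpoint file was actually written.
    if os.path.isfile(saved_location):
        # logger.experiment is an MlflowClient; this attaches the file to the
        # run identified by logger.run_id.
        logger.experiment.log_artifact(logger.run_id, saved_location)

# In main() this corresponds to:
#   upload_checkpoint(logger, trainer.default_root_dir, JOB_ID)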