diff --git a/cli/train.py b/cli/train.py
index a5c83f0..391df13 100644
--- a/cli/train.py
+++ b/cli/train.py
@@ -5,13 +5,14 @@ from omegaconf import DictConfig
 from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
 from pytorch_lightning.loggers import MLFlowLogger
 os.environ["HYDRA_FULL_ERROR"] = "1"
+JOB_ID = os.environ.get("SLURM_JOBID")
 
 @hydra.main(config_path="train_config",config_name="config")
 def main(config: DictConfig):
     callbacks = []
     logger = MLFlowLogger(experiment_name=config.mlflow.experiment_name,
-                          run_name=config.mlflow.run_name, tags={"JOB_ID":os.environ.get("SLURM_JOBID")})
+                          run_name=config.mlflow.run_name, tags={"JOB_ID":JOB_ID})
     parameters = config.hyperparameters
@@ -22,7 +23,7 @@ def main(config: DictConfig):
     direction = model.valid_monitor
     checkpoint = ModelCheckpoint(
-        dirpath="",filename="model",monitor="val_loss",verbose=False,
+        dirpath="./model",filename=f"model_{JOB_ID}",monitor="val_loss",verbose=False,
         mode=direction,every_n_epochs=1
     )
     callbacks.append(checkpoint)
@@ -30,7 +31,7 @@ def main(config: DictConfig):
         monitor="val_loss",
         mode=direction,
         min_delta=0.0,
-        patience=100,
+        patience=10,
         strict=True,
         verbose=False,
     )
@@ -38,6 +39,7 @@ def main(config: DictConfig):
     trainer = instantiate(config.trainer,logger=logger,callbacks=callbacks)
     trainer.fit(model)
+    logger.experiment.log_artifact(logger.run_id,f"./model/model_{JOB_ID}.ckpt")
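
For context on the final hunk: logger.experiment on an MLFlowLogger is a plain mlflow.tracking.MlflowClient, and MlflowClient.log_artifact(run_id, local_path) uploads a single local file under that run. The artifact path ends in ".ckpt" because Lightning's ModelCheckpoint appends that extension to the configured filename. A minimal sketch of that hand-off, runnable without SLURM or a Trainer; the "demo" experiment name, the "local" fallback job id, and the stand-in checkpoint file are illustrative assumptions, not part of this change:

    import os

    from pytorch_lightning.loggers import MLFlowLogger

    # Fallback keeps the sketch runnable outside SLURM (assumption, not in the diff).
    JOB_ID = os.environ.get("SLURM_JOBID", "local")

    logger = MLFlowLogger(experiment_name="demo", tags={"JOB_ID": JOB_ID})

    # Stand-in for the file ModelCheckpoint writes: filename=f"model_{JOB_ID}"
    # becomes "model_<job>.ckpt" on disk, since Lightning appends ".ckpt".
    os.makedirs("./model", exist_ok=True)
    ckpt_path = f"./model/model_{JOB_ID}.ckpt"
    open(ckpt_path, "wb").close()

    # logger.experiment is an mlflow.tracking.MlflowClient and logger.run_id is
    # the run the logger created, so this uploads the checkpoint under that run.
    logger.experiment.log_artifact(logger.run_id, ckpt_path)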