diff --git a/cli/train.py b/cli/train.py index 7d49acc..8a4beaf 100644 --- a/cli/train.py +++ b/cli/train.py @@ -25,7 +25,7 @@ def main(config: DictConfig): direction = model.valid_monitor checkpoint = ModelCheckpoint( - dirpath="./model",filename=f"model_{JOB_ID}",monitor="val_loss",verbose=False, + dirpath="./model",filename=f"model_{JOB_ID}",monitor="val_loss",verbose=True, mode=direction,every_n_epochs=1 ) callbacks.append(checkpoint) @@ -55,7 +55,8 @@ def main(config: DictConfig): trainer = instantiate(config.trainer,logger=logger,callbacks=callbacks) trainer.fit(model) - logger.experiment.log_artifact(logger.run_id,f"./model/model_{JOB_ID}") + if os.path.exists("./model/"): + logger.experiment.log_artifact(logger.run_id,f"./model/.*")