fix logging path
This commit is contained in:
		
							parent
							
								
									ecd47905dd
								
							
						
					
					
						commit
						a880125322
					
				|  | @ -1,3 +1,4 @@ | |||
| from genericpath import isfile | ||||
| import os | ||||
| from types import MethodType | ||||
| import hydra | ||||
|  | @ -7,7 +8,7 @@ from torch.optim.lr_scheduler import ReduceLROnPlateau | |||
| from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping | ||||
| from pytorch_lightning.loggers import MLFlowLogger | ||||
| os.environ["HYDRA_FULL_ERROR"] = "1" | ||||
| JOB_ID = os.environ.get("SLURM_JOBID") | ||||
| JOB_ID = os.environ.get("SLURM_JOBID","0") | ||||
| 
 | ||||
| @hydra.main(config_path="train_config",config_name="config") | ||||
| def main(config: DictConfig): | ||||
|  | @ -55,8 +56,10 @@ def main(config: DictConfig): | |||
| 
 | ||||
|     trainer = instantiate(config.trainer,logger=logger,callbacks=callbacks) | ||||
|     trainer.fit(model) | ||||
|     if os.path.exists("./model/"): | ||||
|         logger.experiment.log_artifact(logger.run_id,f"model_{JOB_ID}.ckpt") | ||||
| 
 | ||||
|     saved_location = os.path.join(trainer.default_root_dir,"model",f"model_{JOB_ID}.ckpt") | ||||
|     if os.path.isfile(saved_location): | ||||
|         logger.experiment.log_artifact(logger.run_id,saved_location) | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	 shahules786
						shahules786