104 lines
		
	
	
		
			3.1 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			104 lines
		
	
	
		
			3.1 KiB
		
	
	
	
		
			Python
		
	
	
	
| import os
 | |
| from types import MethodType
 | |
| 
 | |
| import hydra
 | |
| from hydra.utils import instantiate
 | |
| from omegaconf import DictConfig, OmegaConf
 | |
| from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
 | |
| from pytorch_lightning.loggers import MLFlowLogger
 | |
| from torch.optim.lr_scheduler import ReduceLROnPlateau
 | |
| 
 | |
| os.environ["HYDRA_FULL_ERROR"] = "1"
 | |
| JOB_ID = os.environ.get("SLURM_JOBID", "0")
 | |
| 
 | |
| 
 | |
| @hydra.main(config_path="train_config", config_name="config")
 | |
| def main(config: DictConfig):
 | |
| 
 | |
|     OmegaConf.save(config, "config_log.yaml")
 | |
| 
 | |
|     callbacks = []
 | |
|     logger = MLFlowLogger(
 | |
|         experiment_name=config.mlflow.experiment_name,
 | |
|         run_name=config.mlflow.run_name,
 | |
|         tags={"JOB_ID": JOB_ID},
 | |
|     )
 | |
| 
 | |
|     parameters = config.hyperparameters
 | |
| 
 | |
|     dataset = instantiate(config.dataset)
 | |
|     model = instantiate(
 | |
|         config.model,
 | |
|         dataset=dataset,
 | |
|         lr=parameters.get("lr"),
 | |
|         loss=parameters.get("loss"),
 | |
|         metric=parameters.get("metric"),
 | |
|     )
 | |
| 
 | |
|     direction = model.valid_monitor
 | |
|     checkpoint = ModelCheckpoint(
 | |
|         dirpath="./model",
 | |
|         filename=f"model_{JOB_ID}",
 | |
|         monitor="valid_loss",
 | |
|         verbose=False,
 | |
|         mode=direction,
 | |
|         every_n_epochs=1,
 | |
|     )
 | |
|     callbacks.append(checkpoint)
 | |
|     if parameters.get("Early_stop", False):
 | |
|         early_stopping = EarlyStopping(
 | |
|             monitor=f"valid_{parameters.get('EarlyStopping_metric','loss')}",
 | |
|             mode=direction,
 | |
|             min_delta=parameters.get("EarlyStopping_delta", 0.00),
 | |
|             patience=parameters.get("EarlyStopping_patience", 10),
 | |
|             strict=True,
 | |
|             verbose=False,
 | |
|         )
 | |
|         callbacks.append(early_stopping)
 | |
| 
 | |
|     def configure_optimizer(self):
 | |
|         optimizer = instantiate(
 | |
|             config.optimizer,
 | |
|             lr=parameters.get("lr"),
 | |
|             parameters=self.parameters(),
 | |
|         )
 | |
|         scheduler = ReduceLROnPlateau(
 | |
|             optimizer=optimizer,
 | |
|             mode=direction,
 | |
|             factor=parameters.get("ReduceLr_factor", 0.1),
 | |
|             verbose=True,
 | |
|             min_lr=parameters.get("min_lr", 1e-6),
 | |
|             patience=parameters.get("ReduceLr_patience", 3),
 | |
|         )
 | |
|         return {"optimizer": optimizer, "lr_scheduler": scheduler}
 | |
| 
 | |
|     model.configure_parameters = MethodType(configure_optimizer, model)
 | |
| 
 | |
|     trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks)
 | |
|     trainer.fit(model)
 | |
|     trainer.test(ckpt_path="best")
 | |
| 
 | |
|     logger.experiment.log_artifact(
 | |
|         logger.run_id, f"{trainer.default_root_dir}/config_log.yaml"
 | |
|     )
 | |
| 
 | |
|     saved_location = os.path.join(
 | |
|         trainer.default_root_dir, "model", f"model_{JOB_ID}.ckpt"
 | |
|     )
 | |
|     if os.path.isfile(saved_location):
 | |
|         logger.experiment.log_artifact(logger.run_id, saved_location)
 | |
|         logger.experiment.log_param(
 | |
|             logger.run_id,
 | |
|             "num_train_steps_per_epoch",
 | |
|             dataset.train__len__() / dataset.batch_size,
 | |
|         )
 | |
|         logger.experiment.log_param(
 | |
|             logger.run_id,
 | |
|             "num_valid_steps_per_epoch",
 | |
|             dataset.val__len__() / dataset.batch_size,
 | |
|         )
 | |
| 
 | |
| 
 | |
| if __name__ == "__main__":
 | |
|     main()
 |