merge & test
This commit is contained in:
		
						commit
						b0483d2fa8
					
				
							
								
								
									
										23
									
								
								cli/train.py
								
								
								
								
							
							
						
						
									
										23
									
								
								cli/train.py
								
								
								
								
							|  | @ -1,7 +1,9 @@ | |||
| import os | ||||
| from types import MethodType | ||||
| import hydra | ||||
| from hydra.utils import instantiate | ||||
| from omegaconf import DictConfig | ||||
| from torch.optim.lr_scheduler import ReduceLROnPlateau | ||||
| from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping | ||||
| from pytorch_lightning.loggers import MLFlowLogger | ||||
| os.environ["HYDRA_FULL_ERROR"] = "1" | ||||
|  | @ -23,7 +25,7 @@ def main(config: DictConfig): | |||
| 
 | ||||
|     direction = model.valid_monitor | ||||
|     checkpoint = ModelCheckpoint( | ||||
|         dirpath="./model",filename=f"model_{JOB_ID}",monitor="val_loss",verbose=False, | ||||
|         dirpath="./model",filename=f"model_{JOB_ID}",monitor="val_loss",verbose=True, | ||||
|         mode=direction,every_n_epochs=1 | ||||
|     ) | ||||
|     callbacks.append(checkpoint) | ||||
|  | @ -31,15 +33,30 @@ def main(config: DictConfig): | |||
|             monitor="val_loss", | ||||
|             mode=direction, | ||||
|             min_delta=0.0, | ||||
|             patience=10, | ||||
|             patience=parameters.get("EarlyStopping_patience",10), | ||||
|             strict=True, | ||||
|             verbose=False, | ||||
|         ) | ||||
|     callbacks.append(early_stopping) | ||||
|      | ||||
|     def configure_optimizer(self): | ||||
|         optimizer = instantiate(config.optimizer,lr=parameters.get("lr"),parameters=self.parameters()) | ||||
|         scheduler = ReduceLROnPlateau( | ||||
|             optimizer=optimizer, | ||||
|             mode=direction, | ||||
|             factor=parameters.get("ReduceLr_factor",0.1), | ||||
|             verbose=True, | ||||
|             min_lr=parameters.get("min_lr",1e-6), | ||||
|             patience=parameters.get("ReduceLr_patience",3) | ||||
|         ) | ||||
|         return {"optimizer":optimizer, "lr_scheduler":scheduler} | ||||
| 
 | ||||
|     model.configure_parameters = MethodType(configure_optimizer,model) | ||||
| 
 | ||||
|     trainer = instantiate(config.trainer,logger=logger,callbacks=callbacks) | ||||
|     trainer.fit(model) | ||||
|     logger.experiment.log_artifact(logger.run_id,f"./model/model_{JOB_ID}") | ||||
|     if os.path.exists("./model/"): | ||||
|         logger.experiment.log_artifact(logger.run_id,f"./model/.*") | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|  |  | |||
|  | @ -1,5 +1,5 @@ | |||
| defaults: | ||||
|   - model : Demucs | ||||
|   - model : WaveUnet | ||||
|   - dataset : Vctk | ||||
|   - optimizer : Adam | ||||
|   - hyperparameters : default | ||||
|  |  | |||
|  | @ -1,3 +1,9 @@ | |||
| loss : mse | ||||
| metric : mae | ||||
| lr : 0.0001 | ||||
| num_epochs : 100 | ||||
| ReduceLr_patience : 5 | ||||
| ReduceLr_factor : 0.1 | ||||
| min_lr : 0.000001 | ||||
| EarlyStopping_factor : 10 | ||||
| 
 | ||||
|  |  | |||
|  | @ -22,8 +22,8 @@ limit_predict_batches: 1.0 | |||
| limit_test_batches: 1.0 | ||||
| limit_train_batches: 1.0 | ||||
| limit_val_batches: 1.0 | ||||
| log_every_n_steps: 50 | ||||
| max_epochs: 500 | ||||
| log_every_n_steps: 10 | ||||
| max_epochs: 30 | ||||
| max_steps: null | ||||
| max_time: null | ||||
| min_epochs: 1 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	 shahules786
						shahules786