Merge pull request #2 from shahules786/dev-hawk
include model artifacts logging
commit bf07cae301
@@ -5,13 +5,14 @@ from omegaconf import DictConfig
 from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
 from pytorch_lightning.loggers import MLFlowLogger
 os.environ["HYDRA_FULL_ERROR"] = "1"
+JOB_ID = os.environ.get("SLURM_JOBID")
 
 @hydra.main(config_path="train_config",config_name="config")
 def main(config: DictConfig):
 
     callbacks = []
     logger = MLFlowLogger(experiment_name=config.mlflow.experiment_name,
-                            run_name=config.mlflow.run_name, tags={"JOB_ID":os.environ.get("SLURM_JOBID")})
+                            run_name=config.mlflow.run_name, tags={"JOB_ID":JOB_ID})
 
 
     parameters = config.hyperparameters
@@ -22,7 +23,7 @@ def main(config: DictConfig):
 
     direction = model.valid_monitor
     checkpoint = ModelCheckpoint(
-        dirpath="",filename="model",monitor="val_loss",verbose=False,
+        dirpath="./model",filename=f"model_{JOB_ID}",monitor="val_loss",verbose=False,
         mode=direction,every_n_epochs=1
     )
     callbacks.append(checkpoint)
@@ -30,7 +31,7 @@ def main(config: DictConfig):
             monitor="val_loss",
             mode=direction,
             min_delta=0.0,
-            patience=100,
+            patience=10,
             strict=True,
             verbose=False,
         )
@@ -38,6 +39,7 @@ def main(config: DictConfig):
 
     trainer = instantiate(config.trainer,logger=logger,callbacks=callbacks)
     trainer.fit(model)
+    logger.experiment.log_artifact(logger.run_id,f"./model/model_{JOB_ID}")
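The artifact-logging line added above works because Lightning's MLFlowLogger.experiment exposes the underlying mlflow.tracking.MlflowClient, and MlflowClient.log_artifact(run_id, local_path) uploads a local file into that run's artifact store (note that ModelCheckpoint saves its files with a .ckpt suffix). A minimal sketch of the same pattern, with an illustrative experiment name and a stand-in file in place of a real checkpoint:

import os

from pytorch_lightning.loggers import MLFlowLogger

JOB_ID = os.environ.get("SLURM_JOBID")  # None outside a SLURM job

logger = MLFlowLogger(experiment_name="demo", tags={"JOB_ID": JOB_ID})

# Stand-in for a checkpoint written by ModelCheckpoint(dirpath="./model", ...)
os.makedirs("./model", exist_ok=True)
checkpoint_path = f"./model/model_{JOB_ID}.ckpt"
open(checkpoint_path, "w").close()

# logger.experiment is an MlflowClient; this uploads the file into the
# artifact store of the run created by the logger.
logger.experiment.log_artifact(logger.run_id, checkpoint_path)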
@@ -1,5 +1,5 @@
 defaults:
-  - model : WaveUnet
+  - model : Demucs
   - dataset : Vctk
   - optimizer : Adam
   - hyperparameters : default
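For context, the defaults list is Hydra's composition mechanism: "- model : Demucs" makes Hydra merge train_config/model/Demucs.yaml (the file changed below) into config.model, so swapping WaveUnet for Demucs is a one-line change. A sketch of how such a node is typically materialized, assuming enhancer.models.demucs.Demucs accepts the keys in that file as keyword arguments:

import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig

@hydra.main(config_path="train_config", config_name="config")
def main(config: DictConfig):
    # config.model now holds the merged contents of train_config/model/Demucs.yaml;
    # instantiate() imports the class named in _target_ and calls it with the
    # remaining keys (num_channels, resample, ...) as keyword arguments.
    model = instantiate(config.model)
    print(type(model).__name__)

if __name__ == "__main__":
    main()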
@@ -1,11 +1,11 @@
 _target_: enhancer.models.demucs.Demucs
 num_channels: 1
-resample: 4
+resample: 2
 sampling_rate : 16000
 
 encoder_decoder:
   depth: 5
-  initial_output_channels: 48
+  initial_output_channels: 32
   kernel_size: 8
   stride: 1
   growth_factor: 2
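The two value changes make the model lighter: resample: 2 halves the internal resampling factor and initial_output_channels: 32 narrows the encoder. Assuming each encoder layer multiplies its channel count by growth_factor (an inference from the key names, not verified against the Demucs implementation), the per-layer widths implied by the new config are:

# Per-layer encoder widths implied by the updated config, assuming each of
# the `depth` layers scales channels by `growth_factor` (inferred from the
# key names; the actual enhancer.models.demucs.Demucs may differ).
depth = 5
initial_output_channels = 32
growth_factor = 2

channels = [initial_output_channels * growth_factor**i for i in range(depth)]
print(channels)  # [32, 64, 128, 256, 512]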
Shahul ES