Merge pull request #2 from shahules786/dev-hawk
include model artifacts logging
Commit: bf07cae301
@@ -5,13 +5,14 @@ from omegaconf import DictConfig
 from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
 from pytorch_lightning.loggers import MLFlowLogger
 os.environ["HYDRA_FULL_ERROR"] = "1"
+JOB_ID = os.environ.get("SLURM_JOBID")
 
 @hydra.main(config_path="train_config",config_name="config")
 def main(config: DictConfig):
 
     callbacks = []
     logger = MLFlowLogger(experiment_name=config.mlflow.experiment_name,
-                            run_name=config.mlflow.run_name, tags={"JOB_ID":os.environ.get("SLURM_JOBID")})
+                            run_name=config.mlflow.run_name, tags={"JOB_ID":JOB_ID})
 
 
     parameters = config.hyperparameters
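Note: the hunk above hoists the SLURM job id into a module-level JOB_ID so the same value can be reused as an MLflow run tag and in the checkpoint name. A minimal sketch of the tagging part in isolation, assuming a SLURM environment; the fallback value "local" and the experiment/run names are illustrative, not part of this commit.

import os
from pytorch_lightning.loggers import MLFlowLogger

# SLURM_JOBID is only set inside a SLURM allocation; fall back to a
# placeholder so the tag still exists for local runs (assumption, not in the PR).
JOB_ID = os.environ.get("SLURM_JOBID", "local")

logger = MLFlowLogger(
    experiment_name="enhancer",      # hypothetical experiment name
    run_name="demucs-baseline",      # hypothetical run name
    tags={"JOB_ID": JOB_ID},         # the tag this PR attaches to every run
)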
@@ -22,7 +23,7 @@ def main(config: DictConfig):
 
     direction = model.valid_monitor
     checkpoint = ModelCheckpoint(
-        dirpath="",filename="model",monitor="val_loss",verbose=False,
+        dirpath="./model",filename=f"model_{JOB_ID}",monitor="val_loss",verbose=False,
         mode=direction,every_n_epochs=1
     )
     callbacks.append(checkpoint)
@@ -30,7 +31,7 @@ def main(config: DictConfig):
             monitor="val_loss",
             mode=direction,
             min_delta=0.0,
-            patience=100,
+            patience=10,
             strict=True,
             verbose=False,
         )
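What the tighter patience means in practice: training now stops once val_loss has failed to improve by more than min_delta for 10 consecutive validation epochs, instead of 100. A small sketch, assuming model.valid_monitor resolves to "min" since val_loss is a loss-like metric (that value is an assumption here).

from pytorch_lightning.callbacks import EarlyStopping

early_stopping = EarlyStopping(
    monitor="val_loss",
    mode="min",      # assumed value of model.valid_monitor
    min_delta=0.0,
    patience=10,     # stop after 10 stagnant validation checks (was 100)
    strict=True,     # raise if "val_loss" is never logged
    verbose=False,
)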
@@ -38,6 +39,7 @@ def main(config: DictConfig):
 
     trainer = instantiate(config.trainer,logger=logger,callbacks=callbacks)
     trainer.fit(model)
+    logger.experiment.log_artifact(logger.run_id,f"./model/model_{JOB_ID}")
 
 
 
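This added line is the model-artifact logging named in the PR title: MLFlowLogger.experiment exposes the underlying mlflow.tracking.MlflowClient, and log_artifact(run_id, local_path) uploads a local file into the run's artifact store. One caveat worth checking: Lightning's ModelCheckpoint appends ".ckpt" to the configured filename, so the saved file is ./model/model_<JOB_ID>.ckpt, while the path logged above has no suffix. A minimal sketch of the same step, assuming logger, checkpoint, trainer, and model are the objects created earlier; reading the path from checkpoint.best_model_path is a suggested alternative, not code from this commit.

trainer.fit(model)

# ModelCheckpoint appends ".ckpt" to the configured filename, so prefer the
# path it reports over reconstructing it by hand.
ckpt_path = checkpoint.best_model_path or f"./model/model_{JOB_ID}.ckpt"

# logger.experiment is an mlflow.tracking.MlflowClient; log_artifact copies
# the local checkpoint file into the run's artifact store.
logger.experiment.log_artifact(logger.run_id, ckpt_path)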
@@ -1,5 +1,5 @@
 defaults:
-  - model : WaveUnet
+  - model : Demucs
   - dataset : Vctk
   - optimizer : Adam
   - hyperparameters : default
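For readers unfamiliar with Hydra's defaults list: the entry "- model : Demucs" makes Hydra compose model/Demucs.yaml (relative to the train_config directory passed to @hydra.main) into config.model, which is what switches the default architecture from WaveUnet to Demucs. A minimal sketch of how that node becomes a model object, assuming the config-group layout implied by the diff; instantiate and the _target_ key are standard Hydra, and the values mirror the next hunk.

from hydra.utils import instantiate
from omegaconf import OmegaConf

# Roughly what config.model contains after composition (flat keys only;
# nested groups such as encoder_decoder ride along the same way).
model_cfg = OmegaConf.create(
    {
        "_target_": "enhancer.models.demucs.Demucs",
        "num_channels": 1,
        "resample": 2,
        "sampling_rate": 16000,
    }
)

# instantiate() imports the _target_ class and calls it with the remaining
# keys as keyword arguments, i.e. Demucs(num_channels=1, resample=2, ...).
# Requires the enhancer package to be importable.
model = instantiate(model_cfg)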
@@ -1,11 +1,11 @@
 _target_: enhancer.models.demucs.Demucs
 num_channels: 1
-resample: 4
+resample: 2
 sampling_rate : 16000
 
 encoder_decoder:
   depth: 5
-  initial_output_channels: 48
+  initial_output_channels: 32
   kernel_size: 8
   stride: 1
   growth_factor: 2
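This hunk shrinks the default Demucs: the internal resampling factor drops from 4 to 2 and the first encoder layer narrows from 48 to 32 channels. As a quick sanity check on what initial_output_channels, depth, and growth_factor imply, assuming each encoder layer multiplies the channel count by growth_factor (the usual Demucs convention; the helper below is illustrative, not code from this repository):

def encoder_channels(initial_output_channels: int, depth: int, growth_factor: int):
    """Per-layer output channels, assuming each encoder layer scales the
    channel count by growth_factor."""
    return [initial_output_channels * growth_factor**i for i in range(depth)]

# New config: 32 channels, depth 5, growth factor 2
print(encoder_channels(32, 5, 2))  # [32, 64, 128, 256, 512]

# Old config started wider at 48 channels
print(encoder_channels(48, 5, 2))  # [48, 96, 192, 384, 768]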