From 9170666592dbba76035e2cc3237fcc5a7c60ed50 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Fri, 30 Sep 2022 10:55:39 +0530 Subject: [PATCH 1/2] reduce params --- cli/train_config/config.yaml | 2 +- cli/train_config/model/Demucs.yaml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/cli/train_config/config.yaml b/cli/train_config/config.yaml index 61551bd..6b5d98e 100644 --- a/cli/train_config/config.yaml +++ b/cli/train_config/config.yaml @@ -1,5 +1,5 @@ defaults: - - model : WaveUnet + - model : Demucs - dataset : Vctk - optimizer : Adam - hyperparameters : default diff --git a/cli/train_config/model/Demucs.yaml b/cli/train_config/model/Demucs.yaml index 27603dc..1006e71 100644 --- a/cli/train_config/model/Demucs.yaml +++ b/cli/train_config/model/Demucs.yaml @@ -1,11 +1,11 @@ _target_: enhancer.models.demucs.Demucs num_channels: 1 -resample: 4 +resample: 2 sampling_rate : 16000 encoder_decoder: depth: 5 - initial_output_channels: 48 + initial_output_channels: 32 kernel_size: 8 stride: 1 growth_factor: 2 From d25be3a59dfdfd782faadaa85b57b2ee2170c895 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Sun, 2 Oct 2022 09:02:46 +0530 Subject: [PATCH 2/2] log model artifacts --- cli/train.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/cli/train.py b/cli/train.py index a5c83f0..391df13 100644 --- a/cli/train.py +++ b/cli/train.py @@ -5,13 +5,14 @@ from omegaconf import DictConfig from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping from pytorch_lightning.loggers import MLFlowLogger os.environ["HYDRA_FULL_ERROR"] = "1" +JOB_ID = os.environ.get("SLURM_JOBID") @hydra.main(config_path="train_config",config_name="config") def main(config: DictConfig): callbacks = [] logger = MLFlowLogger(experiment_name=config.mlflow.experiment_name, - run_name=config.mlflow.run_name, tags={"JOB_ID":os.environ.get("SLURM_JOBID")}) + run_name=config.mlflow.run_name, tags={"JOB_ID":JOB_ID}) parameters = config.hyperparameters @@ -22,7 +23,7 @@ def main(config: DictConfig): direction = model.valid_monitor checkpoint = ModelCheckpoint( - dirpath="",filename="model",monitor="val_loss",verbose=False, + dirpath="./model",filename=f"model_{JOB_ID}",monitor="val_loss",verbose=False, mode=direction,every_n_epochs=1 ) callbacks.append(checkpoint) @@ -30,7 +31,7 @@ def main(config: DictConfig): monitor="val_loss", mode=direction, min_delta=0.0, - patience=100, + patience=10, strict=True, verbose=False, ) @@ -38,6 +39,7 @@ def main(config: DictConfig): trainer = instantiate(config.trainer,logger=logger,callbacks=callbacks) trainer.fit(model) + logger.experiment.log_artifact(logger.run_id,f"./model/model_{JOB_ID}")