diff --git a/cli/train.py b/cli/train.py index a5c83f0..391df13 100644 --- a/cli/train.py +++ b/cli/train.py @@ -5,13 +5,14 @@ from omegaconf import DictConfig from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping from pytorch_lightning.loggers import MLFlowLogger os.environ["HYDRA_FULL_ERROR"] = "1" +JOB_ID = os.environ.get("SLURM_JOBID") @hydra.main(config_path="train_config",config_name="config") def main(config: DictConfig): callbacks = [] logger = MLFlowLogger(experiment_name=config.mlflow.experiment_name, - run_name=config.mlflow.run_name, tags={"JOB_ID":os.environ.get("SLURM_JOBID")}) + run_name=config.mlflow.run_name, tags={"JOB_ID":JOB_ID}) parameters = config.hyperparameters @@ -22,7 +23,7 @@ def main(config: DictConfig): direction = model.valid_monitor checkpoint = ModelCheckpoint( - dirpath="",filename="model",monitor="val_loss",verbose=False, + dirpath="./model",filename=f"model_{JOB_ID}",monitor="val_loss",verbose=False, mode=direction,every_n_epochs=1 ) callbacks.append(checkpoint) @@ -30,7 +31,7 @@ def main(config: DictConfig): monitor="val_loss", mode=direction, min_delta=0.0, - patience=100, + patience=10, strict=True, verbose=False, ) @@ -38,6 +39,7 @@ def main(config: DictConfig): trainer = instantiate(config.trainer,logger=logger,callbacks=callbacks) trainer.fit(model) + logger.experiment.log_artifact(logger.run_id,f"./model/model_{JOB_ID}") diff --git a/cli/train_config/config.yaml b/cli/train_config/config.yaml index 61551bd..6b5d98e 100644 --- a/cli/train_config/config.yaml +++ b/cli/train_config/config.yaml @@ -1,5 +1,5 @@ defaults: - - model : WaveUnet + - model : Demucs - dataset : Vctk - optimizer : Adam - hyperparameters : default diff --git a/cli/train_config/model/Demucs.yaml b/cli/train_config/model/Demucs.yaml index 27603dc..1006e71 100644 --- a/cli/train_config/model/Demucs.yaml +++ b/cli/train_config/model/Demucs.yaml @@ -1,11 +1,11 @@ _target_: enhancer.models.demucs.Demucs num_channels: 1 -resample: 4 +resample: 2 sampling_rate : 16000 encoder_decoder: depth: 5 - initial_output_channels: 48 + initial_output_channels: 32 kernel_size: 8 stride: 1 growth_factor: 2