diff --git a/cli/train.py b/cli/train.py
index 391df13..8a4beaf 100644
--- a/cli/train.py
+++ b/cli/train.py
@@ -1,7 +1,9 @@
 import os
+from types import MethodType
 import hydra
 from hydra.utils import instantiate
 from omegaconf import DictConfig
+from torch.optim.lr_scheduler import ReduceLROnPlateau
 from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
 from pytorch_lightning.loggers import MLFlowLogger
 os.environ["HYDRA_FULL_ERROR"] = "1"
@@ -23,7 +25,7 @@ def main(config: DictConfig):
     direction = model.valid_monitor
     checkpoint = ModelCheckpoint(
-        dirpath="./model",filename=f"model_{JOB_ID}",monitor="val_loss",verbose=False,
+        dirpath="./model",filename=f"model_{JOB_ID}",monitor="val_loss",verbose=True,
         mode=direction,every_n_epochs=1
     )
     callbacks.append(checkpoint)
@@ -31,15 +33,30 @@ def main(config: DictConfig):
         monitor="val_loss",
         mode=direction,
         min_delta=0.0,
-        patience=10,
+        patience=parameters.get("EarlyStopping_patience", 10),
         strict=True,
         verbose=False,
     )
     callbacks.append(early_stopping)
+
+    def configure_optimizers(self):
+        # Instantiate the optimizer from the Hydra config, binding the model's parameters.
+        optimizer = instantiate(config.optimizer, lr=parameters.get("lr"), params=self.parameters())
+        scheduler = ReduceLROnPlateau(
+            optimizer=optimizer,
+            mode=direction,
+            factor=parameters.get("ReduceLr_factor", 0.1),
+            verbose=True,
+            min_lr=parameters.get("min_lr", 1e-6),
+            patience=parameters.get("ReduceLr_patience", 3),
+        )
+        # ReduceLROnPlateau requires a monitored metric in the Lightning scheduler dict.
+        return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "monitor": "val_loss"}}
+
+    model.configure_optimizers = MethodType(configure_optimizers, model)
     trainer = instantiate(config.trainer,logger=logger,callbacks=callbacks)
     trainer.fit(model)
-    logger.experiment.log_artifact(logger.run_id,f"./model/model_{JOB_ID}")
+    if os.path.exists("./model/"):
+        logger.experiment.log_artifacts(logger.run_id, "./model/")
diff --git a/cli/train_config/config.yaml b/cli/train_config/config.yaml
index 6b5d98e..61551bd 100644
--- a/cli/train_config/config.yaml
+++ b/cli/train_config/config.yaml
@@ -1,5 +1,5 @@
 defaults:
-  - model : Demucs
+  - model : WaveUnet
   - dataset : Vctk
   - optimizer : Adam
   - hyperparameters : default
diff --git a/cli/train_config/hyperparameters/default.yaml b/cli/train_config/hyperparameters/default.yaml
index a0ac704..49a7f80 100644
--- a/cli/train_config/hyperparameters/default.yaml
+++ b/cli/train_config/hyperparameters/default.yaml
@@ -1,3 +1,9 @@
 loss : mse
 metric : mae
 lr : 0.0001
+num_epochs : 100
+ReduceLr_patience : 5
+ReduceLr_factor : 0.1
+min_lr : 0.000001
+EarlyStopping_patience : 10
+
diff --git a/cli/train_config/trainer/default.yaml b/cli/train_config/trainer/default.yaml
index 6c693d8..2aa5083 100644
--- a/cli/train_config/trainer/default.yaml
+++ b/cli/train_config/trainer/default.yaml
@@ -22,8 +22,8 @@ limit_predict_batches: 1.0
 limit_test_batches: 1.0
 limit_train_batches: 1.0
 limit_val_batches: 1.0
-log_every_n_steps: 50
-max_epochs: 500
+log_every_n_steps: 10
+max_epochs: 30
 max_steps: null
 max_time: null
 min_epochs: 1