fix optimizer scheduler

shahules786 2022-10-26 10:26:27 +05:30
parent 58de41598e
commit 04782ba6e9
1 changed file with 16 additions and 7 deletions

@@ -4,10 +4,14 @@ from types import MethodType
 import hydra
 from hydra.utils import instantiate
 from omegaconf import DictConfig, OmegaConf
-from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
+from pytorch_lightning.callbacks import (
+    EarlyStopping,
+    LearningRateMonitor,
+    ModelCheckpoint,
+)
 from pytorch_lightning.loggers import MLFlowLogger
 from torch.optim.lr_scheduler import ReduceLROnPlateau
-from torch_audiomentations import BandPassFilter, Compose, Shift
+from torch_audiomentations import Compose, Shift
 
 os.environ["HYDRA_FULL_ERROR"] = "1"
 JOB_ID = os.environ.get("SLURM_JOBID", "0")
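The newly imported LearningRateMonitor is what makes the scheduler fix observable: it logs the optimizer's current learning rate to the attached logger (MLflow here), so LR drops triggered by ReduceLROnPlateau show up in the run history. A minimal sketch of the callback in isolation; the Trainer arguments are placeholders, not this repo's Hydra config:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor

# Log the learning rate once per epoch under the active logger;
# "step" is the other supported logging_interval.
lr_monitor = LearningRateMonitor(logging_interval="epoch")
trainer = Trainer(max_epochs=10, callbacks=[lr_monitor])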
@@ -29,7 +33,6 @@ def main(config: DictConfig):
     apply_augmentations = Compose(
         [
             Shift(min_shift=0.0, max_shift=1.0, shift_unit="seconds", p=0.5),
-            BandPassFilter(p=0.5),
         ]
     )
 
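With BandPassFilter dropped, the augmentation pipeline reduces to a random time shift. A sketch of how such a torch_audiomentations pipeline is applied; the sample rate, batch shape, and clip length below are illustrative assumptions, not values from this repo:

import torch
from torch_audiomentations import Compose, Shift

apply_augmentations = Compose(
    [Shift(min_shift=0.0, max_shift=1.0, shift_unit="seconds", p=0.5)]
)

# torch_audiomentations operates on (batch, channels, samples) tensors;
# 16 kHz and 2-second clips here are assumptions for the sketch.
waveform = torch.randn(4, 1, 32000)
augmented = apply_augmentations(waveform, sample_rate=16000)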
@@ -52,6 +55,8 @@ def main(config: DictConfig):
             every_n_epochs=1,
         )
         callbacks.append(checkpoint)
+        callbacks.append(LearningRateMonitor(logging_interval="epoch"))
+
         if parameters.get("Early_stop", False):
             early_stopping = EarlyStopping(
                 monitor="val_loss",
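The hunk boundary cuts the EarlyStopping call short; its remaining arguments sit outside the diff context. For reference, a self-contained sketch of the same callback pattern, with hypothetical mode and patience values that are not taken from this repo's config:

from pytorch_lightning.callbacks import EarlyStopping

# Stop training when the monitored metric plateaus; monitor must match
# a metric name the LightningModule logs via self.log(...).
early_stopping = EarlyStopping(
    monitor="val_loss",
    mode="min",   # hypothetical: "min" because lower loss is better
    patience=10,  # hypothetical patience, not this repo's value
)
callbacks = [early_stopping]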
@@ -63,11 +68,11 @@ def main(config: DictConfig):
             )
             callbacks.append(early_stopping)
 
-    def configure_optimizer(self):
+    def configure_optimizers(self):
         optimizer = instantiate(
             config.optimizer,
             lr=parameters.get("lr"),
-            parameters=self.parameters(),
+            params=self.parameters(),
         )
         scheduler = ReduceLROnPlateau(
             optimizer=optimizer,
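Two fixes land in this hunk. First, Lightning resolves the optimizer hook by its exact plural name, so a method called configure_optimizer is silently ignored. Second, hydra.utils.instantiate forwards keyword arguments verbatim to the target constructor, and torch optimizers take the parameter iterable as params, not parameters. A minimal sketch of the corrected hook, using Adam as a stand-in for whatever config.optimizer points at:

import pytorch_lightning as pl
import torch

class LitModel(pl.LightningModule):
    # Lightning only calls a hook named exactly configure_optimizers.
    def configure_optimizers(self):
        # torch optimizers expect the iterable under the keyword `params`,
        # which is why the instantiate(...) call above forwards params=.
        return torch.optim.Adam(params=self.parameters(), lr=1e-3)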
@@ -77,9 +82,13 @@ def main(config: DictConfig):
             min_lr=parameters.get("min_lr", 1e-6),
             patience=parameters.get("ReduceLr_patience", 3),
         )
-        return {"optimizer": optimizer, "lr_scheduler": scheduler}
+        return {
+            "optimizer": optimizer,
+            "lr_scheduler": scheduler,
+            "monitor": f'valid_{parameters.get("ReduceLr_monitor", "loss")}',
+        }
 
-    model.configure_parameters = MethodType(configure_optimizer, model)
+    model.configure_optimizers = MethodType(configure_optimizers, model)
 
     trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks)
     trainer.fit(model)
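The last hunk carries the core of the fix. ReduceLROnPlateau, unlike step-based schedulers, needs a metric value on every scheduler.step(metric) call, and the "monitor" key is how Lightning knows which logged metric to feed it; with the diff's defaults the f-string resolves to "valid_loss". The commit also repairs the monkey-patch itself, binding the function to the correctly named hook instead of the old configure_parameters attribute. A self-contained sketch of the whole pattern, with Adam and hard-coded values standing in for the Hydra-instantiated objects:

from types import MethodType

import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

def configure_optimizers(self):
    optimizer = torch.optim.Adam(params=self.parameters(), lr=1e-3)
    scheduler = ReduceLROnPlateau(optimizer=optimizer, mode="min", patience=3)
    # The "monitor" key tells Lightning which logged metric to pass into
    # scheduler.step(metric) at the end of each validation epoch.
    return {
        "optimizer": optimizer,
        "lr_scheduler": scheduler,
        "monitor": "valid_loss",  # must match a metric logged via self.log(...)
    }

# MethodType binds the function to this one instance, overriding the hook
# without subclassing; `model` is assumed to be an existing LightningModule.
model.configure_optimizers = MethodType(configure_optimizers, model)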