fix optimizer scheduler

This commit is contained in:
shahules786 2022-10-26 10:26:27 +05:30
parent 58de41598e
commit 04782ba6e9
1 changed files with 16 additions and 7 deletions

View File

@ -4,10 +4,14 @@ from types import MethodType
import hydra import hydra
from hydra.utils import instantiate from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint from pytorch_lightning.callbacks import (
EarlyStopping,
LearningRateMonitor,
ModelCheckpoint,
)
from pytorch_lightning.loggers import MLFlowLogger from pytorch_lightning.loggers import MLFlowLogger
from torch.optim.lr_scheduler import ReduceLROnPlateau from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch_audiomentations import BandPassFilter, Compose, Shift from torch_audiomentations import Compose, Shift
os.environ["HYDRA_FULL_ERROR"] = "1" os.environ["HYDRA_FULL_ERROR"] = "1"
JOB_ID = os.environ.get("SLURM_JOBID", "0") JOB_ID = os.environ.get("SLURM_JOBID", "0")
@ -29,7 +33,6 @@ def main(config: DictConfig):
apply_augmentations = Compose( apply_augmentations = Compose(
[ [
Shift(min_shift=0.0, max_shift=1.0, shift_unit="seconds", p=0.5), Shift(min_shift=0.0, max_shift=1.0, shift_unit="seconds", p=0.5),
BandPassFilter(p=0.5),
] ]
) )
@ -52,6 +55,8 @@ def main(config: DictConfig):
every_n_epochs=1, every_n_epochs=1,
) )
callbacks.append(checkpoint) callbacks.append(checkpoint)
callbacks.append(LearningRateMonitor(logging_interval="epoch"))
if parameters.get("Early_stop", False): if parameters.get("Early_stop", False):
early_stopping = EarlyStopping( early_stopping = EarlyStopping(
monitor="val_loss", monitor="val_loss",
@ -63,11 +68,11 @@ def main(config: DictConfig):
) )
callbacks.append(early_stopping) callbacks.append(early_stopping)
def configure_optimizer(self): def configure_optimizers(self):
optimizer = instantiate( optimizer = instantiate(
config.optimizer, config.optimizer,
lr=parameters.get("lr"), lr=parameters.get("lr"),
parameters=self.parameters(), params=self.parameters(),
) )
scheduler = ReduceLROnPlateau( scheduler = ReduceLROnPlateau(
optimizer=optimizer, optimizer=optimizer,
@ -77,9 +82,13 @@ def main(config: DictConfig):
min_lr=parameters.get("min_lr", 1e-6), min_lr=parameters.get("min_lr", 1e-6),
patience=parameters.get("ReduceLr_patience", 3), patience=parameters.get("ReduceLr_patience", 3),
) )
return {"optimizer": optimizer, "lr_scheduler": scheduler} return {
"optimizer": optimizer,
"lr_scheduler": scheduler,
"monitor": f'valid_{parameters.get("ReduceLr_monitor", "loss")}',
}
model.configure_parameters = MethodType(configure_optimizer, model) model.configure_optimizers = MethodType(configure_optimizers, model)
trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks) trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks)
trainer.fit(model) trainer.fit(model)