fix optimizer scheduler
This commit is contained in:
parent
58de41598e
commit
04782ba6e9
|
|
@ -4,10 +4,14 @@ from types import MethodType
|
||||||
import hydra
|
import hydra
|
||||||
from hydra.utils import instantiate
|
from hydra.utils import instantiate
|
||||||
from omegaconf import DictConfig, OmegaConf
|
from omegaconf import DictConfig, OmegaConf
|
||||||
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
|
from pytorch_lightning.callbacks import (
|
||||||
|
EarlyStopping,
|
||||||
|
LearningRateMonitor,
|
||||||
|
ModelCheckpoint,
|
||||||
|
)
|
||||||
from pytorch_lightning.loggers import MLFlowLogger
|
from pytorch_lightning.loggers import MLFlowLogger
|
||||||
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
||||||
from torch_audiomentations import BandPassFilter, Compose, Shift
|
from torch_audiomentations import Compose, Shift
|
||||||
|
|
||||||
os.environ["HYDRA_FULL_ERROR"] = "1"
|
os.environ["HYDRA_FULL_ERROR"] = "1"
|
||||||
JOB_ID = os.environ.get("SLURM_JOBID", "0")
|
JOB_ID = os.environ.get("SLURM_JOBID", "0")
|
||||||
|
|
@ -29,7 +33,6 @@ def main(config: DictConfig):
|
||||||
apply_augmentations = Compose(
|
apply_augmentations = Compose(
|
||||||
[
|
[
|
||||||
Shift(min_shift=0.0, max_shift=1.0, shift_unit="seconds", p=0.5),
|
Shift(min_shift=0.0, max_shift=1.0, shift_unit="seconds", p=0.5),
|
||||||
BandPassFilter(p=0.5),
|
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -52,6 +55,8 @@ def main(config: DictConfig):
|
||||||
every_n_epochs=1,
|
every_n_epochs=1,
|
||||||
)
|
)
|
||||||
callbacks.append(checkpoint)
|
callbacks.append(checkpoint)
|
||||||
|
callbacks.append(LearningRateMonitor(logging_interval="epoch"))
|
||||||
|
|
||||||
if parameters.get("Early_stop", False):
|
if parameters.get("Early_stop", False):
|
||||||
early_stopping = EarlyStopping(
|
early_stopping = EarlyStopping(
|
||||||
monitor="val_loss",
|
monitor="val_loss",
|
||||||
|
|
@ -63,11 +68,11 @@ def main(config: DictConfig):
|
||||||
)
|
)
|
||||||
callbacks.append(early_stopping)
|
callbacks.append(early_stopping)
|
||||||
|
|
||||||
def configure_optimizer(self):
|
def configure_optimizers(self):
|
||||||
optimizer = instantiate(
|
optimizer = instantiate(
|
||||||
config.optimizer,
|
config.optimizer,
|
||||||
lr=parameters.get("lr"),
|
lr=parameters.get("lr"),
|
||||||
parameters=self.parameters(),
|
params=self.parameters(),
|
||||||
)
|
)
|
||||||
scheduler = ReduceLROnPlateau(
|
scheduler = ReduceLROnPlateau(
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
|
|
@ -77,9 +82,13 @@ def main(config: DictConfig):
|
||||||
min_lr=parameters.get("min_lr", 1e-6),
|
min_lr=parameters.get("min_lr", 1e-6),
|
||||||
patience=parameters.get("ReduceLr_patience", 3),
|
patience=parameters.get("ReduceLr_patience", 3),
|
||||||
)
|
)
|
||||||
return {"optimizer": optimizer, "lr_scheduler": scheduler}
|
return {
|
||||||
|
"optimizer": optimizer,
|
||||||
|
"lr_scheduler": scheduler,
|
||||||
|
"monitor": f'valid_{parameters.get("ReduceLr_monitor", "loss")}',
|
||||||
|
}
|
||||||
|
|
||||||
model.configure_parameters = MethodType(configure_optimizer, model)
|
model.configure_optimizers = MethodType(configure_optimizers, model)
|
||||||
|
|
||||||
trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks)
|
trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks)
|
||||||
trainer.fit(model)
|
trainer.fit(model)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue