add lr scheduler

This commit is contained in:
shahules786 2022-10-02 18:12:21 +05:30
parent bf07cae301
commit b1e15c7552
1 changed files with 16 additions and 0 deletions

View File

@ -1,7 +1,9 @@
import os
from types import MethodType
import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig
from torch.optim.lr_scheduler import ReduceLROnPlateau
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import MLFlowLogger
os.environ["HYDRA_FULL_ERROR"] = "1"
@ -37,6 +39,20 @@ def main(config: DictConfig):
)
callbacks.append(early_stopping)
def configure_optimizer(self):
optimizer = instantiate(config.optimizer,lr=parameters.get("lr"),parameters=self.parameters())
scheduler = ReduceLROnPlateau(
optimizer=optimizer,
mode=direction,
factor=0.1,
verbose=True,
min_lr=parameters.get("min_lr",1e-6),
patience=3
)
return {"optimizer":optimizer, "lr_scheduler":scheduler}
model.configure_parameters = MethodType(configure_optimizer,model)
trainer = instantiate(config.trainer,logger=logger,callbacks=callbacks)
trainer.fit(model)
logger.experiment.log_artifact(logger.run_id,f"./model/model_{JOB_ID}")