add lr scheduler
parent bf07cae301
commit b1e15c7552

cli/train.py | 16 ++++++++++++++++
1 changed file with 16 additions and 0 deletions
@@ -1,7 +1,9 @@
 import os
+from types import MethodType
 import hydra
 from hydra.utils import instantiate
 from omegaconf import DictConfig
+from torch.optim.lr_scheduler import ReduceLROnPlateau
 from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
 from pytorch_lightning.loggers import MLFlowLogger
 os.environ["HYDRA_FULL_ERROR"] = "1"
@@ -37,6 +39,20 @@ def main(config: DictConfig):
     )
     callbacks.append(early_stopping)
 
+    def configure_optimizers(self):
+        optimizer = instantiate(config.optimizer, lr=parameters.get("lr"), params=self.parameters())
+        scheduler = ReduceLROnPlateau(
+            optimizer=optimizer,
+            mode=direction,
+            factor=0.1,
+            verbose=True,
+            min_lr=parameters.get("min_lr", 1e-6),
+            patience=3,
+        )
+        return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "monitor": parameters.get("monitor", "val_loss")}}  # "monitor" key assumed; Lightning requires one for ReduceLROnPlateau
+
+    model.configure_optimizers = MethodType(configure_optimizers, model)
+
     trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks)
     trainer.fit(model)
     logger.experiment.log_artifact(logger.run_id, f"./model/model_{JOB_ID}")
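Worth noting: configure_optimizers is not defined on the model class here. It is a closure over the Hydra config that gets bound onto the already-instantiated model with types.MethodType, so Lightning finds it like a normal hook during trainer.fit(model). Below is a minimal, self-contained sketch of that pattern; the ToyModule class, the hard-coded Adam optimizer, and the "train_loss" metric are assumptions standing in for the values the diff reads from config.optimizer and parameters.

from types import MethodType

import torch
import pytorch_lightning as pl
from torch.optim.lr_scheduler import ReduceLROnPlateau


class ToyModule(pl.LightningModule):
    # assumed stand-in for the model instantiated from the Hydra config
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch).mean()
        # log once per epoch so the plateau scheduler has a metric to monitor
        self.log("train_loss", loss, on_step=False, on_epoch=True)
        return loss


def configure_optimizers(self):
    # stand-in for instantiate(config.optimizer, ...) in the diff
    optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
    scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.1, patience=3, min_lr=1e-6)
    # With ReduceLROnPlateau, Lightning requires a "monitor" entry naming the
    # logged metric whose value is passed to scheduler.step(metric)
    return {
        "optimizer": optimizer,
        "lr_scheduler": {"scheduler": scheduler, "monitor": "train_loss"},
    }


model = ToyModule()
# Bind the closure as a method of this one instance; Lightning will call it
# during fit exactly as if it were defined on ToyModule
model.configure_optimizers = MethodType(configure_optimizers, model)

Patching the instance instead of subclassing keeps the script config-driven: whatever model Hydra instantiates gets the same optimizer and scheduler wiring without its class having to know about it.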