add callbacks

This commit is contained in:
shahules786 2022-09-27 22:26:47 +05:30
parent 634f146ca7
commit 9ab996e5c7
1 changed files with 18 additions and 10 deletions

View File

@ -1,21 +1,14 @@
import hydra
from hydra.core.config_store import ConfigStore
from hydra.utils import instantiate
from omegaconf import DictConfig,OmegaConf
from pytorch_lightning import Trainer
from omegaconf import DictConfig
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import MLFlowLogger
import torch
from enhancer.models.demucs import Demucs
from enhancer.data.dataset import EnhancerDataset
@hydra.main(config_path="train_config",config_name="config")
def main(config: DictConfig):
callbacks = []
logger = MLFlowLogger(experiment_name=config.mlflow.experiment_name,
run_name=config.mlflow.run_name)
@ -26,7 +19,22 @@ def main(config: DictConfig):
model = instantiate(config.model,dataset=dataset,lr=parameters.get("lr"),
loss=parameters.get("loss"), metric = parameters.get("metric"))
trainer = instantiate(config.trainer,logger=logger)
checkpoint = ModelCheckpoint(
dirpath="",filename="model",monitor=parameters.get("loss"),verbose=False,
mode="min",every_n_epochs=1
)
callbacks.append(checkpoint)
early_stopping = EarlyStopping(
monitor=parameters.get("loss"),
mode="min",
min_delta=0.0,
patience=100,
strict=True,
verbose=False,
)
callbacks.append(early_stopping)
trainer = instantiate(config.trainer,logger=logger,callbacks=callbacks)
trainer.fit(model)