add callbacks
This commit is contained in:
parent 634f146ca7
commit 9ab996e5c7
cli/train.py | 28
--- a/cli/train.py
+++ b/cli/train.py
@@ -1,21 +1,14 @@
 import hydra
-from hydra.core.config_store import ConfigStore
 from hydra.utils import instantiate
-from omegaconf import DictConfig
+from omegaconf import DictConfig,OmegaConf
-from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
 from pytorch_lightning.loggers import MLFlowLogger
-import torch
-
-from enhancer.models.demucs import Demucs
-from enhancer.data.dataset import EnhancerDataset
-
-
 
 
 @hydra.main(config_path="train_config",config_name="config")
 def main(config: DictConfig):
 
+    callbacks = []
     logger = MLFlowLogger(experiment_name=config.mlflow.experiment_name,
                             run_name=config.mlflow.run_name)
 
@@ -26,7 +19,22 @@ def main(config: DictConfig):
     model = instantiate(config.model,dataset=dataset,lr=parameters.get("lr"),
                             loss=parameters.get("loss"), metric = parameters.get("metric"))
 
-    trainer = instantiate(config.trainer,logger=logger)
+    checkpoint = ModelCheckpoint(
+        dirpath="",filename="model",monitor=parameters.get("loss"),verbose=False,
+        mode="min",every_n_epochs=1
+    )
+    callbacks.append(checkpoint)
+    early_stopping = EarlyStopping(
+        monitor=parameters.get("loss"),
+        mode="min",
+        min_delta=0.0,
+        patience=100,
+        strict=True,
+        verbose=False,
+    )
+    callbacks.append(early_stopping)
+
+    trainer = instantiate(config.trainer,logger=logger,callbacks=callbacks)
     trainer.fit(model)
 
 
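Note: the change above follows the usual PyTorch Lightning pattern of collecting callbacks in a list and handing them to the Trainer. The monitor argument of both ModelCheckpoint and EarlyStopping must name a metric that the LightningModule actually logs; in this commit it is taken from parameters.get("loss"). A minimal sketch of the same wiring outside Hydra, using a placeholder metric name "val_loss" and a placeholder model that are not part of this commit:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

# Keep a checkpoint of the best epoch according to the monitored metric.
checkpoint = ModelCheckpoint(
    dirpath="checkpoints",   # placeholder directory, not from the commit
    filename="model",
    monitor="val_loss",      # must match a metric logged via self.log(...) in the model
    mode="min",
    every_n_epochs=1,
)

# Stop training once the monitored metric has not improved for `patience` epochs.
early_stopping = EarlyStopping(
    monitor="val_loss",
    mode="min",
    min_delta=0.0,
    patience=100,
)

trainer = Trainer(max_epochs=100, callbacks=[checkpoint, early_stopping])
# trainer.fit(model)  # model is assumed to be a LightningModule that logs "val_loss"

In the commit itself the Trainer is built with hydra.utils.instantiate(config.trainer, ...), so the list is passed through the callbacks= keyword rather than hard-coded as above.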