import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import MLFlowLogger

from enhancer.data.dataset import EnhancerDataset
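
# A minimal sketch of the train_config/config.yaml this script assumes, based on
# the keys read below (mlflow, hyperparameters, dataset, model, trainer). The
# _target_ values are illustrative placeholders, not the repository's actual classes:
#
#   mlflow:
#     experiment_name: enhancer
#     run_name: baseline
#   hyperparameters:
#     lr: 1e-3
#     loss: mae
#     metric: mae
#   dataset:
#     _target_: enhancer.data.dataset.EnhancerDataset
#   model:
#     _target_: enhancer.models.SomeModel   # hypothetical
#   trainer:
#     _target_: pytorch_lightning.Trainer
#     max_epochs: 100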


@hydra.main(config_path="train_config", config_name="config")
def main(config: DictConfig):
    # Collect Lightning callbacks and set up MLflow experiment tracking.
    callbacks = []
    logger = MLFlowLogger(
        experiment_name=config.mlflow.experiment_name,
        run_name=config.mlflow.run_name,
    )
    parameters = config.hyperparameters

    # Build the dataset and model from their config nodes; instantiate() imports
    # whatever class each node's _target_ names. The lr/loss/metric values come
    # from the hyperparameters block, and .get() returns None for a missing key.
    dataset = instantiate(config.dataset)
    model = instantiate(
        config.model,
        dataset=dataset,
        lr=parameters.get("lr"),
        loss=parameters.get("loss"),
        metric=parameters.get("metric"),
    )
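    # Note: keyword arguments passed to instantiate() are merged into the config
    # node and take precedence, so e.g. instantiate(config.model, lr=1e-3) would
    # override an lr set in config.model (per Hydra's instantiate semantics).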

    # Checkpoint once per epoch on the monitored loss (lower is better). The
    # monitored name must match a metric the model logs during training.
    checkpoint = ModelCheckpoint(
        dirpath="",
        filename="model",
        monitor=parameters.get("loss"),
        verbose=False,
        mode="min",
        every_n_epochs=1,
    )
    callbacks.append(checkpoint)

    # Stop early if the monitored loss fails to improve for 100 consecutive epochs.
    early_stopping = EarlyStopping(
        monitor=parameters.get("loss"),
        mode="min",
        min_delta=0.0,
        patience=100,
        strict=True,
        verbose=False,
    )
    callbacks.append(early_stopping)

    # Build the Trainer from config with the MLflow logger and callbacks attached.
    trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks)
    trainer.fit(model)


if __name__ == "__main__":
    main()
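
# Example invocation (values illustrative); any config key can be overridden on
# the command line via Hydra's dotted-path syntax:
#
#   python train.py hyperparameters.lr=1e-3 mlflow.run_name=baseline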