diff --git a/cli/train.py b/cli/train.py index 9aa497d..5a056a2 100644 --- a/cli/train.py +++ b/cli/train.py @@ -3,11 +3,14 @@ from hydra.utils import instantiate from omegaconf import DictConfig from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping from pytorch_lightning.loggers import MLFlowLogger +from pytorch_lightning.callbacks import TQDMProgressBar + @hydra.main(config_path="train_config",config_name="config") def main(config: DictConfig): callbacks = [] + callbacks.append(TQDMProgressBar(refresh_rate=10)) logger = MLFlowLogger(experiment_name=config.mlflow.experiment_name, run_name=config.mlflow.run_name)