diff --git a/enhancer/cli/train.py b/enhancer/cli/train.py index 08f4d3e..398fa2b 100644 --- a/enhancer/cli/train.py +++ b/enhancer/cli/train.py @@ -76,7 +76,7 @@ def main(config: DictConfig): trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks) trainer.fit(model) - trainer.test(model) + trainer.test(ckpt_path="best") logger.experiment.log_artifact( logger.run_id, f"{trainer.default_root_dir}/config_log.yaml"