diff --git a/enhancer/cli/train.py b/enhancer/cli/train.py index cb3c7c1..5ad61b1 100644 --- a/enhancer/cli/train.py +++ b/enhancer/cli/train.py @@ -32,7 +32,7 @@ def main(config: DictConfig): loss=parameters.get("loss"), metric=parameters.get("metric"), ) - + direction = model.valid_monitor checkpoint = ModelCheckpoint( dirpath="./model", @@ -79,6 +79,11 @@ def main(config: DictConfig): ) if os.path.isfile(saved_location): logger.experiment.log_artifact(logger.run_id, saved_location) + logger.experiment.log_param(logger.run_id, "num_train_steps_per_epoch", + dataset.train__len__() / dataset.batch_size) + logger.experiment.log_param(logger.run_id, "num_valid_steps_per_epoch", + dataset.val__len__() / dataset.batch_size) + if __name__ == "__main__":