num steps
This commit is contained in:
parent
1b9b6c9a9f
commit
1c81d629c4
|
|
@ -32,7 +32,7 @@ def main(config: DictConfig):
|
||||||
loss=parameters.get("loss"),
|
loss=parameters.get("loss"),
|
||||||
metric=parameters.get("metric"),
|
metric=parameters.get("metric"),
|
||||||
)
|
)
|
||||||
|
|
||||||
direction = model.valid_monitor
|
direction = model.valid_monitor
|
||||||
checkpoint = ModelCheckpoint(
|
checkpoint = ModelCheckpoint(
|
||||||
dirpath="./model",
|
dirpath="./model",
|
||||||
|
|
@ -79,6 +79,11 @@ def main(config: DictConfig):
|
||||||
)
|
)
|
||||||
if os.path.isfile(saved_location):
|
if os.path.isfile(saved_location):
|
||||||
logger.experiment.log_artifact(logger.run_id, saved_location)
|
logger.experiment.log_artifact(logger.run_id, saved_location)
|
||||||
|
logger.experiment.log_param(logger.run_id, "num_train_steps_per_epoch",
|
||||||
|
dataset.train__len__() / dataset.batch_size)
|
||||||
|
logger.experiment.log_param(logger.run_id, "num_valid_steps_per_epoch",
|
||||||
|
dataset.val__len__() / dataset.batch_size)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue