diff --git a/enhancer/models/model.py b/enhancer/models/model.py index 6e6b4e1..cbbfad8 100644 --- a/enhancer/models/model.py +++ b/enhancer/models/model.py @@ -120,7 +120,11 @@ class Model(pl.LightningModule): loss = self.loss(prediction, target) - if self.logger: + if ( + (self.logger) + and (self.global_step > 50) + and (self.global_step % 50 == 0) + ): self.logger.experiment.log_metric( run_id=self.logger.run_id, key="train_loss", @@ -141,7 +145,11 @@ class Model(pl.LightningModule): self.log("val_metric", metric_val.item()) self.log("val_loss", loss_val.item()) - if self.logger: + if ( + (self.logger) + and (self.global_step > 50) + and (self.global_step % 50 == 0) + ): self.logger.experiment.log_metric( run_id=self.logger.run_id, key="val_loss",