diff --git a/enhancer/models/model.py b/enhancer/models/model.py index 20c8196..8e607ed 100644 --- a/enhancer/models/model.py +++ b/enhancer/models/model.py @@ -86,7 +86,7 @@ class Model(pl.LightningModule): self.logger.experiment.log_metric(run_id=self.logger.run_id, key="train_loss", value=loss.item(), step=self.global_step) - + self.log("train_loss",loss.item()) return {"loss":loss} def validation_step(self,batch,batch_idx:int): @@ -95,13 +95,20 @@ class Model(pl.LightningModule): target = batch["clean"] prediction = self(mixed_waveform) - loss = self.metric(prediction, target) + metric_val = self.metric(prediction, target) + loss_val = self.loss(prediction, target) + self.log("val_metric",metric_val.item()) + self.log("val_loss",loss_val.item()) + if self.logger: self.logger.experiment.log_metric(run_id=self.logger.run_id, - key="val_loss",value=loss.item(), + key="val_loss",value=loss_val.item(), + step=self.global_step) + self.logger.experiment.log_metric(run_id=self.logger.run_id, + key="val_metric",value=metric_val.item(), step=self.global_step) - return {"loss":loss} + return {"loss":loss_val} def on_save_checkpoint(self, checkpoint):