log val loss

This commit is contained in:
shahules786 2022-09-30 10:10:35 +05:30
parent bd0bfbeea7
commit c717e7c38c
1 changed files with 11 additions and 4 deletions

View File

@ -86,7 +86,7 @@ class Model(pl.LightningModule):
self.logger.experiment.log_metric(run_id=self.logger.run_id,
key="train_loss", value=loss.item(),
step=self.global_step)
self.log("train_loss",loss.item())
return {"loss":loss}
def validation_step(self,batch,batch_idx:int):
@ -95,13 +95,20 @@ class Model(pl.LightningModule):
target = batch["clean"]
prediction = self(mixed_waveform)
loss = self.metric(prediction, target)
metric_val = self.metric(prediction, target)
loss_val = self.loss(prediction, target)
self.log("val_metric",metric_val.item())
self.log("val_loss",loss_val.item())
if self.logger:
self.logger.experiment.log_metric(run_id=self.logger.run_id,
key="val_loss",value=loss.item(),
key="val_loss",value=loss_val.item(),
step=self.global_step)
self.logger.experiment.log_metric(run_id=self.logger.run_id,
key="val_metric",value=metric_val.item(),
step=self.global_step)
return {"loss":loss}
return {"loss":loss_val}
def on_save_checkpoint(self, checkpoint):