log val loss
This commit is contained in:
parent
bd0bfbeea7
commit
c717e7c38c
|
|
@ -86,7 +86,7 @@ class Model(pl.LightningModule):
|
||||||
self.logger.experiment.log_metric(run_id=self.logger.run_id,
|
self.logger.experiment.log_metric(run_id=self.logger.run_id,
|
||||||
key="train_loss", value=loss.item(),
|
key="train_loss", value=loss.item(),
|
||||||
step=self.global_step)
|
step=self.global_step)
|
||||||
|
self.log("train_loss",loss.item())
|
||||||
return {"loss":loss}
|
return {"loss":loss}
|
||||||
|
|
||||||
def validation_step(self,batch,batch_idx:int):
|
def validation_step(self,batch,batch_idx:int):
|
||||||
|
|
@ -95,13 +95,20 @@ class Model(pl.LightningModule):
|
||||||
target = batch["clean"]
|
target = batch["clean"]
|
||||||
prediction = self(mixed_waveform)
|
prediction = self(mixed_waveform)
|
||||||
|
|
||||||
loss = self.metric(prediction, target)
|
metric_val = self.metric(prediction, target)
|
||||||
|
loss_val = self.loss(prediction, target)
|
||||||
|
self.log("val_metric",metric_val.item())
|
||||||
|
self.log("val_loss",loss_val.item())
|
||||||
|
|
||||||
if self.logger:
|
if self.logger:
|
||||||
self.logger.experiment.log_metric(run_id=self.logger.run_id,
|
self.logger.experiment.log_metric(run_id=self.logger.run_id,
|
||||||
key="val_loss",value=loss.item(),
|
key="val_loss",value=loss_val.item(),
|
||||||
|
step=self.global_step)
|
||||||
|
self.logger.experiment.log_metric(run_id=self.logger.run_id,
|
||||||
|
key="val_metric",value=metric_val.item(),
|
||||||
step=self.global_step)
|
step=self.global_step)
|
||||||
|
|
||||||
return {"loss":loss}
|
return {"loss":loss_val}
|
||||||
|
|
||||||
def on_save_checkpoint(self, checkpoint):
|
def on_save_checkpoint(self, checkpoint):
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue