rename loss
This commit is contained in:
parent
00ef644179
commit
cb0040b508
|
|
@ -135,7 +135,7 @@ class Model(pl.LightningModule):
|
|||
loss = self.loss(prediction, target)
|
||||
|
||||
self.log(
|
||||
f"train_{self.loss.name}",
|
||||
"train_loss",
|
||||
loss.item(),
|
||||
on_epoch=True,
|
||||
on_step=True,
|
||||
|
|
@ -152,6 +152,7 @@ class Model(pl.LightningModule):
|
|||
target = batch["clean"]
|
||||
prediction = self(mixed_waveform)
|
||||
|
||||
metric_dict["valid_loss"] = self.loss(target, prediction).item()
|
||||
for metric in self.metric:
|
||||
value = metric(target, prediction)
|
||||
metric_dict[f"valid_{metric.name}"] = value.item()
|
||||
|
|
|
|||
Loading…
Reference in New Issue