rename loss
This commit is contained in:
parent
00ef644179
commit
cb0040b508
|
|
@ -135,7 +135,7 @@ class Model(pl.LightningModule):
|
||||||
loss = self.loss(prediction, target)
|
loss = self.loss(prediction, target)
|
||||||
|
|
||||||
self.log(
|
self.log(
|
||||||
f"train_{self.loss.name}",
|
"train_loss",
|
||||||
loss.item(),
|
loss.item(),
|
||||||
on_epoch=True,
|
on_epoch=True,
|
||||||
on_step=True,
|
on_step=True,
|
||||||
|
|
@ -152,6 +152,7 @@ class Model(pl.LightningModule):
|
||||||
target = batch["clean"]
|
target = batch["clean"]
|
||||||
prediction = self(mixed_waveform)
|
prediction = self(mixed_waveform)
|
||||||
|
|
||||||
|
metric_dict["valid_loss"] = self.loss(target, prediction).item()
|
||||||
for metric in self.metric:
|
for metric in self.metric:
|
||||||
value = metric(target, prediction)
|
value = metric(target, prediction)
|
||||||
metric_dict[f"valid_{metric.name}"] = value.item()
|
metric_dict[f"valid_{metric.name}"] = value.item()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue