fix logging
parent e389acefa0
commit 144b9d6128
@@ -132,45 +132,39 @@ class Model(pl.LightningModule):
         mixed_waveform = batch["noisy"]
         target = batch["clean"]
         prediction = self(mixed_waveform)
 
         loss = self.loss(prediction, target)
 
-        if (
-            (self.logger)
-            and (self.global_step > 50)
-            and (self.global_step % 50 == 0)
-        ):
-            self.logger.experiment.log_metric(
-                run_id=self.logger.run_id,
-                key="train_loss",
-                value=loss.item(),
-                step=self.global_step,
-            )
+        self.log(
+            "train_loss",
+            loss.item(),
+            on_epoch=True,
+            on_step=True,
+            logger=True,
+            prog_bar=True,
+        )
+        self.log("train_loss", loss.item())
 
         return {"loss": loss}
 
     def validation_step(self, batch, batch_idx: int):
 
         metric_dict = {}
         mixed_waveform = batch["noisy"]
         target = batch["clean"]
         prediction = self(mixed_waveform)
 
         loss_val = self.loss(prediction, target)
+        self.log("val_loss", loss_val.item())
         for metric in self.metric:
             value = metric(target, prediction)
             metric_dict[f"valid_{metric.name}"] = value.item()
 
-        if (
-            (self.logger)
-            and (self.global_step > 50)
-            and (self.global_step % 50 == 0)
-        ):
-            self.logger.experiment.log_metric(
-                run_id=self.logger.run_id,
-                key="val_loss",
-                value=loss_val.item(),
-                step=self.global_step,
-            )
+        self.log_dict(
+            metric_dict,
+            on_step=True,
+            on_epoch=True,
+            prog_bar=True,
+            logger=True,
+        )
 
-        return {"loss": loss_val}
+        return metric_dict
 
     def test_step(self, batch, batch_idx):
 
@@ -183,44 +177,16 @@ class Model(pl.LightningModule):
             value = metric(target, prediction)
             metric_dict[metric.name] = value
 
-        for k, v in metric_dict.items():
-            self.logger.experiment.log_metric(
-                run_id=self.logger.run_id,
-                key=k,
-                value=v,
-                step=self.global_step,
-            )
+        self.log_dict(
+            metric_dict,
+            on_step=True,
+            on_epoch=True,
+            prog_bar=True,
+            logger=True,
+        )
 
         return metric_dict
 
-    def training_epoch_end(self, outputs):
-        train_mean_loss = 0.0
-        for output in outputs:
-            train_mean_loss += output["loss"]
-        train_mean_loss /= len(outputs)
-
-        if self.logger:
-            self.logger.experiment.log_metric(
-                run_id=self.logger.run_id,
-                key="train_loss_epoch",
-                value=train_mean_loss,
-                step=self.current_epoch,
-            )
-
-    def validation_epoch_end(self, outputs):
-        valid_mean_loss = 0.0
-        for output in outputs:
-            valid_mean_loss += output["loss"]
-        valid_mean_loss /= len(outputs)
-
-        if self.logger:
-            self.logger.experiment.log_metric(
-                run_id=self.logger.run_id,
-                key="valid_loss_epoch",
-                value=valid_mean_loss,
-                step=self.current_epoch,
-            )
 
     def test_epoch_end(self, outputs):
 
         test_mean_metrics = defaultdict(int)
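For reference, a minimal runnable sketch of the self.log pattern the new code relies on. TinyDenoiser, its conv layer, and the L1 loss are illustrative stand-ins rather than the repository's Model class; only the logging call mirrors this commit.

import torch
import pytorch_lightning as pl


class TinyDenoiser(pl.LightningModule):
    """Hypothetical stand-in module used only to demonstrate self.log."""

    def __init__(self):
        super().__init__()
        self.net = torch.nn.Conv1d(1, 1, kernel_size=3, padding=1)
        self.loss = torch.nn.L1Loss()

    def forward(self, x):
        return self.net(x)

    def training_step(self, batch, batch_idx):
        prediction = self(batch["noisy"])
        loss = self.loss(prediction, batch["clean"])
        # self.log hands the scalar to the attached logger and aggregates it
        # per step and per epoch, which is what lets the manual
        # logger.experiment.log_metric calls and the *_epoch_end averaging
        # be deleted.
        self.log("train_loss", loss.item(), on_step=True, on_epoch=True,
                 prog_bar=True, logger=True)
        return {"loss": loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())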