diff --git a/enhancer/models/model.py b/enhancer/models/model.py index 3f74a74..8eb19a8 100644 --- a/enhancer/models/model.py +++ b/enhancer/models/model.py @@ -132,45 +132,39 @@ class Model(pl.LightningModule): mixed_waveform = batch["noisy"] target = batch["clean"] prediction = self(mixed_waveform) - loss = self.loss(prediction, target) - if ( - (self.logger) - and (self.global_step > 50) - and (self.global_step % 50 == 0) - ): - self.logger.experiment.log_metric( - run_id=self.logger.run_id, - key="train_loss", - value=loss.item(), - step=self.global_step, - ) - self.log("train_loss", loss.item()) + self.log( + "train_loss", + loss.item(), + on_epoch=True, + on_step=True, + logger=True, + prog_bar=True, + ) + return {"loss": loss} def validation_step(self, batch, batch_idx: int): + metric_dict = {} mixed_waveform = batch["noisy"] target = batch["clean"] prediction = self(mixed_waveform) - loss_val = self.loss(prediction, target) - self.log("val_loss", loss_val.item()) + for metric in self.metric: + value = metric(target, prediction) + metric_dict[f"valid_{metric.name}"] = value.item() - if ( - (self.logger) - and (self.global_step > 50) - and (self.global_step % 50 == 0) - ): - self.logger.experiment.log_metric( - run_id=self.logger.run_id, - key="val_loss", - value=loss_val.item(), - step=self.global_step, - ) + self.log_dict( + metric_dict, + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + ) - return {"loss": loss_val} + return metric_dict def test_step(self, batch, batch_idx): @@ -183,44 +177,16 @@ class Model(pl.LightningModule): value = metric(target, prediction) metric_dict[metric.name] = value - for k, v in metric_dict.items(): - self.logger.experiment.log_metric( - run_id=self.logger.run_id, - key=k, - value=v, - step=self.global_step, - ) + self.log_dict( + metric_dict, + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + ) return metric_dict - def training_epoch_end(self, outputs): - train_mean_loss = 0.0 - for output in outputs: - train_mean_loss += output["loss"] - train_mean_loss /= len(outputs) - - if self.logger: - self.logger.experiment.log_metric( - run_id=self.logger.run_id, - key="train_loss_epoch", - value=train_mean_loss, - step=self.current_epoch, - ) - - def validation_epoch_end(self, outputs): - valid_mean_loss = 0.0 - for output in outputs: - valid_mean_loss += output["loss"] - valid_mean_loss /= len(outputs) - - if self.logger: - self.logger.experiment.log_metric( - run_id=self.logger.run_id, - key="valid_loss_epoch", - value=valid_mean_loss, - step=self.current_epoch, - ) - def test_epoch_end(self, outputs): test_mean_metrics = defaultdict(int)