From 2587d5b8675edf5239ebcf14265e80958e6764e5 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Mon, 10 Oct 2022 12:47:24 +0530 Subject: [PATCH] add test step --- enhancer/models/model.py | 48 ++++++++++++++++++++++++++++++---------- 1 file changed, 36 insertions(+), 12 deletions(-) diff --git a/enhancer/models/model.py b/enhancer/models/model.py index 7ff15e4..07564cd 100644 --- a/enhancer/models/model.py +++ b/enhancer/models/model.py @@ -13,7 +13,7 @@ from torch.optim import Adam from enhancer.data.dataset import EnhancerDataset from enhancer.inference import Inference -from enhancer.loss import Avergeloss +from enhancer.loss import LOSS_MAP, LossWrapper from enhancer.version import __version__ CACHE_DIR = "" @@ -76,7 +76,7 @@ class Model(pl.LightningModule): if isinstance(loss, str): loss = [loss] - self._loss = Avergeloss(loss) + self._loss = LossWrapper(loss) @property def metric(self): @@ -84,11 +84,21 @@ class Model(pl.LightningModule): @metric.setter def metric(self, metric): - + self._metric = [] if isinstance(metric, str): metric = [metric] - self._metric = Avergeloss(metric) + for func in metric: + if func in LOSS_MAP.keys(): + if func in ("pesq", "stoi"): + self._metric.append( + LOSS_MAP[func](self.hparams.sampling_rate) + ) + else: + self._metric.append(LOSS_MAP[func]()) + + else: + raise ValueError(f"Invalid metrics {func}") @property def dataset(self): @@ -109,6 +119,9 @@ class Model(pl.LightningModule): def val_dataloader(self): return self.dataset.val_dataloader() + def test_dataloader(self): + return self.dataset.test_dataloader() + def configure_optimizers(self): return Adam(self.parameters(), lr=self.hparams.lr) @@ -140,9 +153,7 @@ class Model(pl.LightningModule): target = batch["clean"] prediction = self(mixed_waveform) - metric_val = self.metric(prediction, target) loss_val = self.loss(prediction, target) - self.log("val_metric", metric_val.item()) self.log("val_loss", loss_val.item()) if ( @@ -156,15 +167,28 @@ class Model(pl.LightningModule): value=loss_val.item(), step=self.global_step, ) - self.logger.experiment.log_metric( - run_id=self.logger.run_id, - key="val_metric", - value=metric_val.item(), - step=self.global_step, - ) return {"loss": loss_val} + def test_step(self, batch, batch_idx): + + metric_dict = {} + mixed_waveform = batch["noisy"] + target = batch["clean"] + prediction = self(mixed_waveform) + + for metric in self.metric: + value = metric(target, prediction) + metric_dict[metric.name] = value + + self.logger.experiment.log_metrics( + run_id=self.logger.run_id, + metrics=metric_dict, + step=self.global_step, + ) + + return metric_dict + def on_save_checkpoint(self, checkpoint): checkpoint["enhancer"] = {