# Post-patch state of the `Pesq` metric from enhancer/loss.py.
# The same change set also adds `import numpy as np` to the module imports and
# renames the LOSS_MAP key "SI-SDR" to "si-sdr" (that dict is defined elsewhere
# in the module and is not reproduced here).
import logging

import numpy as np
import torch
import torch.nn as nn


class Pesq:
    """Batch-averaged PESQ score computed one sample at a time.

    Each (prediction, target) pair is scored independently so that a failure
    on one sample does not discard the scores of the others.
    """

    def __init__(self, sr: int, mode: str = "wb"):
        """Create the underlying torchmetrics PESQ metric.

        Args:
            sr: Sampling rate, forwarded to torchmetrics as ``fs``.
            mode: PESQ mode string, forwarded to torchmetrics unchanged.
        """
        # Imported here so this block is self-contained; the original module
        # imports the same name at file level.
        from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality

        self.sr = sr
        self.name = "pesq"
        self.mode = mode
        self.pesq = PerceptualEvaluationSpeechQuality(fs=sr, mode=mode)

    def __call__(
        self, prediction: torch.Tensor, target: torch.Tensor
    ) -> torch.Tensor:
        """Return the mean PESQ over all samples that could be scored.

        Args:
            prediction: Enhanced signals; iterated along the first dimension.
            target: Reference signals, paired element-wise with ``prediction``.
                NOTE(review): presumably both are (batch, 1, time) — confirm
                against the caller; ``squeeze()`` drops every size-1 dim.

        Returns:
            A 0-dim float64 tensor holding the mean of the per-sample scores,
            or ``0.0`` when no sample could be scored.
        """
        pesq_values = []
        for pred, target_ in zip(prediction, target):
            try:
                score = self.pesq(pred.squeeze(), target_.squeeze())
                pesq_values.append(score.item())
            except Exception as e:
                # Best effort: a sample PESQ cannot handle is logged and
                # skipped rather than failing the whole batch.
                logging.warning(f"{e} error occured while calculating PESQ")

        if not pesq_values:
            # np.mean([]) would yield nan (plus a RuntimeWarning). Fall back to
            # 0.0, as the pre-patch implementation did on failure; float64
            # matches the dtype produced by the np.mean path below.
            return torch.tensor(0.0, dtype=torch.float64)
        return torch.tensor(np.mean(pesq_values))