From c51dea68859f20458801b2e4c8fd0a73f690175a Mon Sep 17 00:00:00 2001 From: shahules786 Date: Wed, 26 Oct 2022 21:46:19 +0530 Subject: [PATCH] revert to torchmetric pesq --- enhancer/loss.py | 14 +++++--------- requirements.txt | 2 +- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/enhancer/loss.py b/enhancer/loss.py index fc8afae..32b30cf 100644 --- a/enhancer/loss.py +++ b/enhancer/loss.py @@ -3,7 +3,7 @@ import logging import numpy as np import torch import torch.nn as nn -from pesq import pesq +from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility @@ -122,20 +122,16 @@ class Pesq: self.sr = sr self.name = "pesq" self.mode = mode + self.pesq = PerceptualEvaluationSpeechQuality( + fs=self.sr, mode=self.mode + ) def __call__(self, prediction: torch.Tensor, target: torch.Tensor): pesq_values = [] for pred, target_ in zip(prediction, target): try: - pesq_values.append( - pesq( - self.sr, - target_.squeeze().detach().cpu().numpy(), - pred.squeeze().detach().cpu().numpy(), - self.mode, - ) - ) + pesq_values.append(self.pesq(pred.squeeze(), target_.squeeze())) except Exception as e: logging.warning(f"{e} error occured while calculating PESQ") return torch.tensor(np.mean(pesq_values)) diff --git a/requirements.txt b/requirements.txt index cf8992d..fb54920 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,7 @@ joblib>=1.2.0 librosa>=0.9.2 mlflow>=1.29.0 numpy>=1.23.3 -git+https://github.com/ludlows/python-pesq#egg=pesq +pesq==0.0.4 protobuf>=3.19.6 pystoi==0.3.3 pytest-lazy-fixture>=0.6.3