revert to torchmetric pesq

This commit is contained in:
shahules786 2022-10-26 21:46:19 +05:30
parent 1edc10e9f5
commit c51dea6885
2 changed files with 6 additions and 10 deletions

View File

@ -3,7 +3,7 @@ import logging
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from pesq import pesq from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility
@ -122,20 +122,16 @@ class Pesq:
self.sr = sr self.sr = sr
self.name = "pesq" self.name = "pesq"
self.mode = mode self.mode = mode
self.pesq = PerceptualEvaluationSpeechQuality(
fs=self.sr, mode=self.mode
)
def __call__(self, prediction: torch.Tensor, target: torch.Tensor): def __call__(self, prediction: torch.Tensor, target: torch.Tensor):
pesq_values = [] pesq_values = []
for pred, target_ in zip(prediction, target): for pred, target_ in zip(prediction, target):
try: try:
pesq_values.append( pesq_values.append(self.pesq(pred.squeeze(), target_.squeeze()))
pesq(
self.sr,
target_.squeeze().detach().cpu().numpy(),
pred.squeeze().detach().cpu().numpy(),
self.mode,
)
)
except Exception as e: except Exception as e:
logging.warning(f"{e} error occured while calculating PESQ") logging.warning(f"{e} error occured while calculating PESQ")
return torch.tensor(np.mean(pesq_values)) return torch.tensor(np.mean(pesq_values))

View File

@ -5,7 +5,7 @@ joblib>=1.2.0
librosa>=0.9.2 librosa>=0.9.2
mlflow>=1.29.0 mlflow>=1.29.0
numpy>=1.23.3 numpy>=1.23.3
git+https://github.com/ludlows/python-pesq#egg=pesq pesq==0.0.4
protobuf>=3.19.6 protobuf>=3.19.6
pystoi==0.3.3 pystoi==0.3.3
pytest-lazy-fixture>=0.6.3 pytest-lazy-fixture>=0.6.3