revert to torchmetric pesq
This commit is contained in:
parent
1edc10e9f5
commit
c51dea6885
|
|
@ -3,7 +3,7 @@ import logging
|
|||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from pesq import pesq
|
||||
from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
|
||||
from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility
|
||||
|
||||
|
||||
|
|
@ -122,20 +122,16 @@ class Pesq:
|
|||
self.sr = sr
|
||||
self.name = "pesq"
|
||||
self.mode = mode
|
||||
self.pesq = PerceptualEvaluationSpeechQuality(
|
||||
fs=self.sr, mode=self.mode
|
||||
)
|
||||
|
||||
def __call__(self, prediction: torch.Tensor, target: torch.Tensor):
|
||||
|
||||
pesq_values = []
|
||||
for pred, target_ in zip(prediction, target):
|
||||
try:
|
||||
pesq_values.append(
|
||||
pesq(
|
||||
self.sr,
|
||||
target_.squeeze().detach().cpu().numpy(),
|
||||
pred.squeeze().detach().cpu().numpy(),
|
||||
self.mode,
|
||||
)
|
||||
)
|
||||
pesq_values.append(self.pesq(pred.squeeze(), target_.squeeze()))
|
||||
except Exception as e:
|
||||
logging.warning(f"{e} error occured while calculating PESQ")
|
||||
return torch.tensor(np.mean(pesq_values))
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ joblib>=1.2.0
|
|||
librosa>=0.9.2
|
||||
mlflow>=1.29.0
|
||||
numpy>=1.23.3
|
||||
git+https://github.com/ludlows/python-pesq#egg=pesq
|
||||
pesq==0.0.4
|
||||
protobuf>=3.19.6
|
||||
pystoi==0.3.3
|
||||
pytest-lazy-fixture>=0.6.3
|
||||
|
|
|
|||
Loading…
Reference in New Issue