revert to torchmetric pesq
This commit is contained in:
parent
1edc10e9f5
commit
c51dea6885
|
|
@ -3,7 +3,7 @@ import logging
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from pesq import pesq
|
from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
|
||||||
from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility
|
from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -122,20 +122,16 @@ class Pesq:
|
||||||
self.sr = sr
|
self.sr = sr
|
||||||
self.name = "pesq"
|
self.name = "pesq"
|
||||||
self.mode = mode
|
self.mode = mode
|
||||||
|
self.pesq = PerceptualEvaluationSpeechQuality(
|
||||||
|
fs=self.sr, mode=self.mode
|
||||||
|
)
|
||||||
|
|
||||||
def __call__(self, prediction: torch.Tensor, target: torch.Tensor):
|
def __call__(self, prediction: torch.Tensor, target: torch.Tensor):
|
||||||
|
|
||||||
pesq_values = []
|
pesq_values = []
|
||||||
for pred, target_ in zip(prediction, target):
|
for pred, target_ in zip(prediction, target):
|
||||||
try:
|
try:
|
||||||
pesq_values.append(
|
pesq_values.append(self.pesq(pred.squeeze(), target_.squeeze()))
|
||||||
pesq(
|
|
||||||
self.sr,
|
|
||||||
target_.squeeze().detach().cpu().numpy(),
|
|
||||||
pred.squeeze().detach().cpu().numpy(),
|
|
||||||
self.mode,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.warning(f"{e} error occured while calculating PESQ")
|
logging.warning(f"{e} error occured while calculating PESQ")
|
||||||
return torch.tensor(np.mean(pesq_values))
|
return torch.tensor(np.mean(pesq_values))
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ joblib>=1.2.0
|
||||||
librosa>=0.9.2
|
librosa>=0.9.2
|
||||||
mlflow>=1.29.0
|
mlflow>=1.29.0
|
||||||
numpy>=1.23.3
|
numpy>=1.23.3
|
||||||
git+https://github.com/ludlows/python-pesq#egg=pesq
|
pesq==0.0.4
|
||||||
protobuf>=3.19.6
|
protobuf>=3.19.6
|
||||||
pystoi==0.3.3
|
pystoi==0.3.3
|
||||||
pytest-lazy-fixture>=0.6.3
|
pytest-lazy-fixture>=0.6.3
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue