fix pesq
This commit is contained in:
parent
2e58091543
commit
88112e6ae3
|
|
@ -1,5 +1,6 @@
|
|||
import logging
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
|
||||
|
|
@ -116,17 +117,24 @@ class Stoi:
|
|||
|
||||
|
||||
class Pesq:
|
||||
def __init__(self, sr: int, mode="nb"):
|
||||
def __init__(self, sr: int, mode="wb"):
|
||||
|
||||
self.pesq = PerceptualEvaluationSpeechQuality(fs=sr, mode=mode)
|
||||
self.sr = sr
|
||||
self.name = "pesq"
|
||||
self.mode = mode
|
||||
self.pesq = PerceptualEvaluationSpeechQuality(fs=sr, mode=mode)
|
||||
|
||||
def __call__(self, prediction: torch.Tensor, target: torch.Tensor):
|
||||
|
||||
pesq_values = []
|
||||
for pred, target_ in zip(prediction, target):
|
||||
try:
|
||||
return self.pesq(prediction, target)
|
||||
pesq_values.append(
|
||||
self.pesq(pred.squeeze(), target_.squeeze()).item()
|
||||
)
|
||||
except Exception as e:
|
||||
logging.warning(f"{e} error occured while calculating PESQ")
|
||||
return torch.tensor(0.0)
|
||||
return torch.tensor(np.mean(pesq_values))
|
||||
|
||||
|
||||
class LossWrapper(nn.Module):
|
||||
|
|
@ -177,7 +185,7 @@ class LossWrapper(nn.Module):
|
|||
LOSS_MAP = {
|
||||
"mae": mean_absolute_error,
|
||||
"mse": mean_squared_error,
|
||||
"SI-SDR": Si_SDR,
|
||||
"si-sdr": Si_SDR,
|
||||
"pesq": Pesq,
|
||||
"stoi": Stoi,
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue