This commit is contained in:
shahules786 2022-10-13 10:38:58 +05:30
parent 2e58091543
commit 88112e6ae3
1 changed files with 16 additions and 8 deletions

View File

@ -1,5 +1,6 @@
import logging
import numpy as np
import torch
import torch.nn as nn
from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
@ -116,17 +117,24 @@ class Stoi:
class Pesq:
def __init__(self, sr: int, mode="nb"):
def __init__(self, sr: int, mode="wb"):
self.pesq = PerceptualEvaluationSpeechQuality(fs=sr, mode=mode)
self.sr = sr
self.name = "pesq"
self.mode = mode
self.pesq = PerceptualEvaluationSpeechQuality(fs=sr, mode=mode)
def __call__(self, prediction: torch.Tensor, target: torch.Tensor):
pesq_values = []
for pred, target_ in zip(prediction, target):
try:
return self.pesq(prediction, target)
pesq_values.append(
self.pesq(pred.squeeze(), target_.squeeze()).item()
)
except Exception as e:
logging.warning(f"{e} error occured while calculating PESQ")
return torch.tensor(0.0)
return torch.tensor(np.mean(pesq_values))
class LossWrapper(nn.Module):
@ -177,7 +185,7 @@ class LossWrapper(nn.Module):
LOSS_MAP = {
"mae": mean_absolute_error,
"mse": mean_squared_error,
"SI-SDR": Si_SDR,
"si-sdr": Si_SDR,
"pesq": Pesq,
"stoi": Stoi,
}