This commit is contained in:
shahules786 2022-10-13 10:38:58 +05:30
parent 2e58091543
commit 88112e6ae3
1 changed files with 16 additions and 8 deletions

View File

@ -1,5 +1,6 @@
import logging import logging
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
@ -116,17 +117,24 @@ class Stoi:
class Pesq: class Pesq:
def __init__(self, sr: int, mode="nb"): def __init__(self, sr: int, mode="wb"):
self.pesq = PerceptualEvaluationSpeechQuality(fs=sr, mode=mode) self.sr = sr
self.name = "pesq" self.name = "pesq"
self.mode = mode
self.pesq = PerceptualEvaluationSpeechQuality(fs=sr, mode=mode)
def __call__(self, prediction: torch.Tensor, target: torch.Tensor): def __call__(self, prediction: torch.Tensor, target: torch.Tensor):
try:
return self.pesq(prediction, target) pesq_values = []
except Exception as e: for pred, target_ in zip(prediction, target):
logging.warning(f"{e} error occured while calculating PESQ") try:
return torch.tensor(0.0) pesq_values.append(
self.pesq(pred.squeeze(), target_.squeeze()).item()
)
except Exception as e:
logging.warning(f"{e} error occured while calculating PESQ")
return torch.tensor(np.mean(pesq_values))
class LossWrapper(nn.Module): class LossWrapper(nn.Module):
@ -177,7 +185,7 @@ class LossWrapper(nn.Module):
LOSS_MAP = { LOSS_MAP = {
"mae": mean_absolute_error, "mae": mean_absolute_error,
"mse": mean_squared_error, "mse": mean_squared_error,
"SI-SDR": Si_SDR, "si-sdr": Si_SDR,
"pesq": Pesq, "pesq": Pesq,
"stoi": Stoi, "stoi": Stoi,
} }