add pesq/stoi

This commit is contained in:
shahules786 2022-10-10 12:46:36 +05:30
parent 3e654d10a7
commit 5945ddccaa
1 changed files with 45 additions and 4 deletions

View File

@ -1,5 +1,9 @@
import logging
import torch
import torch.nn as nn
from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility
class mean_squared_error(nn.Module):
@ -12,6 +16,7 @@ class mean_squared_error(nn.Module):
self.loss_fun = nn.MSELoss(reduction=reduction)
self.higher_better = False
self.name = "mse"
def forward(self, prediction: torch.Tensor, target: torch.Tensor):
@ -34,6 +39,7 @@ class mean_absolute_error(nn.Module):
self.loss_fun = nn.L1Loss(reduction=reduction)
self.higher_better = False
self.name = "mae"
def forward(self, prediction: torch.Tensor, target: torch.Tensor):
@ -46,13 +52,12 @@ class mean_absolute_error(nn.Module):
return self.loss_fun(prediction, target)
class Si_SDR(nn.Module):
class Si_SDR:
"""
SI-SDR metric based on SDR HALF-BAKED OR WELL DONE?(https://arxiv.org/pdf/1811.02508.pdf)
"""
def __init__(self, reduction: str = "mean"):
super().__init__()
if reduction in ["sum", "mean", None]:
self.reduction = reduction
else:
@ -60,8 +65,9 @@ class Si_SDR(nn.Module):
"Invalid reduction, valid options are sum, mean, None"
)
self.higher_better = False
self.name = "Si-SDR"
def forward(self, prediction: torch.Tensor, target: torch.Tensor):
def __call__(self, prediction: torch.Tensor, target: torch.Tensor):
if prediction.size() != target.size() or target.ndim < 3:
raise TypeError(
@ -90,7 +96,40 @@ class Si_SDR(nn.Module):
return si_sdr
class Avergeloss(nn.Module):
class Stoi:
    """
    STOI (Short-Time Objective Intelligibility) metric.

    Thin wrapper around torchmetrics' ``ShortTimeObjectiveIntelligibility``.
    NOTE(review): per the torchmetrics documentation the score is computed
    with the pystoi package and the input is moved to cpu to perform the
    metric calculation.

    parameters:
        sr: int
            sampling rate
    """

    def __init__(self, sr: int):
        self.sr = sr
        self.stoi = ShortTimeObjectiveIntelligibility(fs=sr)
        # The other metric classes in this module expose `higher_better`;
        # a larger STOI score means more intelligible speech.
        self.higher_better = True
        self.name = "stoi"

    def __call__(self, prediction: torch.Tensor, target: torch.Tensor):
        """Return the STOI score of `prediction` measured against `target`."""
        return self.stoi(prediction, target)
class Pesq:
    """
    PESQ (Perceptual Evaluation of Speech Quality) metric.

    Thin wrapper around torchmetrics' ``PerceptualEvaluationSpeechQuality``.
    Any error raised while scoring is logged as a warning and 0.0 is
    returned instead, so a failing sample does not abort the caller.

    parameters:
        sr: int
            sampling rate
        mode: str
            PESQ mode forwarded to torchmetrics, "nb" by default
            (NOTE(review): torchmetrics documents "nb"/"wb" for
            narrow-band/wide-band - confirm against the installed version)
    """

    def __init__(self, sr: int, mode="nb"):
        # Kept on the instance for parity with Stoi, which also stores `sr`.
        self.sr = sr
        self.pesq = PerceptualEvaluationSpeechQuality(fs=sr, mode=mode)
        # The other metric classes in this module expose `higher_better`;
        # a larger PESQ score means better perceived quality.
        self.higher_better = True
        self.name = "pesq"

    def __call__(self, prediction: torch.Tensor, target: torch.Tensor):
        """Return the PESQ score of `prediction` against `target`, or 0.0 on error."""
        try:
            return self.pesq(prediction, target)
        except Exception as e:
            # Deliberate best-effort: the backend's failure modes are not
            # visible here, so everything is caught and reported, and a
            # neutral score is returned. Lazy %-formatting emits the same
            # message as before.
            logging.warning("%s error occured while calculating PESQ", e)
            return 0.0
class LossWrapper(nn.Module):
"""
Combine multiple metrics of the same nature.
for example, ["mse","mae"]
@ -137,4 +176,6 @@ LOSS_MAP = {
"mae": mean_absolute_error,
"mse": mean_squared_error,
"SI-SDR": Si_SDR,
"pesq": Pesq,
"stoi": Stoi,
}