add pesq/stoi
This commit is contained in:
parent
3e654d10a7
commit
5945ddccaa
|
|
@ -1,5 +1,9 @@
|
|||
import logging
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
|
||||
from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility
|
||||
|
||||
|
||||
class mean_squared_error(nn.Module):
|
||||
|
|
@ -12,6 +16,7 @@ class mean_squared_error(nn.Module):
|
|||
|
||||
self.loss_fun = nn.MSELoss(reduction=reduction)
|
||||
self.higher_better = False
|
||||
self.name = "mse"
|
||||
|
||||
def forward(self, prediction: torch.Tensor, target: torch.Tensor):
|
||||
|
||||
|
|
@ -34,6 +39,7 @@ class mean_absolute_error(nn.Module):
|
|||
|
||||
self.loss_fun = nn.L1Loss(reduction=reduction)
|
||||
self.higher_better = False
|
||||
self.name = "mae"
|
||||
|
||||
def forward(self, prediction: torch.Tensor, target: torch.Tensor):
|
||||
|
||||
|
|
@ -46,13 +52,12 @@ class mean_absolute_error(nn.Module):
|
|||
return self.loss_fun(prediction, target)
|
||||
|
||||
|
||||
class Si_SDR(nn.Module):
|
||||
class Si_SDR:
|
||||
"""
|
||||
SI-SDR metric based on SDR – HALF-BAKED OR WELL DONE?(https://arxiv.org/pdf/1811.02508.pdf)
|
||||
"""
|
||||
|
||||
def __init__(self, reduction: str = "mean"):
|
||||
super().__init__()
|
||||
if reduction in ["sum", "mean", None]:
|
||||
self.reduction = reduction
|
||||
else:
|
||||
|
|
@ -60,8 +65,9 @@ class Si_SDR(nn.Module):
|
|||
"Invalid reduction, valid options are sum, mean, None"
|
||||
)
|
||||
self.higher_better = False
|
||||
self.name = "Si-SDR"
|
||||
|
||||
def forward(self, prediction: torch.Tensor, target: torch.Tensor):
|
||||
def __call__(self, prediction: torch.Tensor, target: torch.Tensor):
|
||||
|
||||
if prediction.size() != target.size() or target.ndim < 3:
|
||||
raise TypeError(
|
||||
|
|
@ -90,7 +96,40 @@ class Si_SDR(nn.Module):
|
|||
return si_sdr
|
||||
|
||||
|
||||
class Avergeloss(nn.Module):
|
||||
class Stoi:
    """
    STOI (Short-Time Objective Intelligibility) metric.

    Thin wrapper around torchmetrics' ``ShortTimeObjectiveIntelligibility``,
    which in turn relies on the ``pystoi`` package. Per the torchmetrics
    documentation, inputs are moved to cpu to perform the metric calculation.

    parameters:
        sr: int
            sampling rate of the audio signals being compared

    attributes:
        sr: sampling rate the metric was configured with
        stoi: underlying torchmetrics metric object
        higher_better: True, a larger STOI score means better intelligibility
        name: identifier of the metric ("stoi")
    """

    def __init__(self, sr: int):
        self.sr = sr
        self.stoi = ShortTimeObjectiveIntelligibility(fs=sr)
        # The sibling metrics in this module (mse, mae, Si-SDR) all expose
        # `higher_better`; provide it here too so that code selecting the
        # best score can treat every metric uniformly.
        self.higher_better = True
        self.name = "stoi"

    def __call__(self, prediction: torch.Tensor, target: torch.Tensor):
        """
        Compute STOI between the processed signal and its clean reference.

        parameters:
            prediction: torch.Tensor
                estimated / enhanced signal
            target: torch.Tensor
                reference signal

        returns:
            whatever the underlying torchmetrics metric returns for this call
        """
        return self.stoi(prediction, target)
|
||||
|
||||
|
||||
class Pesq:
    """
    PESQ (Perceptual Evaluation of Speech Quality) metric.

    Thin wrapper around torchmetrics' ``PerceptualEvaluationSpeechQuality``.

    parameters:
        sr: int
            sampling rate of the audio signals being compared
        mode: str
            PESQ mode forwarded to torchmetrics ("nb" by default)

    attributes:
        sr: sampling rate the metric was configured with
        mode: PESQ mode the metric was configured with
        pesq: underlying torchmetrics metric object
        higher_better: True, a larger PESQ score means better quality
        name: identifier of the metric ("pesq")
    """

    def __init__(self, sr: int, mode: str = "nb"):
        # Keep the configuration on the instance, mirroring Stoi.sr.
        self.sr = sr
        self.mode = mode
        self.pesq = PerceptualEvaluationSpeechQuality(fs=sr, mode=mode)
        # The sibling metrics in this module (mse, mae, Si-SDR) all expose
        # `higher_better`; provide it here too so that code selecting the
        # best score can treat every metric uniformly.
        self.higher_better = True
        self.name = "pesq"

    def __call__(self, prediction: torch.Tensor, target: torch.Tensor):
        """
        Compute PESQ between the processed signal and its clean reference.

        parameters:
            prediction: torch.Tensor
                estimated / enhanced signal
            target: torch.Tensor
                reference signal

        returns:
            the value produced by the underlying torchmetrics metric, or the
            plain float 0.0 when the computation raises.
        """
        try:
            return self.pesq(prediction, target)
        except Exception as e:
            # Deliberately best-effort: a failed PESQ computation must not
            # abort the caller, so the error is logged and a neutral score
            # is returned instead of propagating.
            logging.warning("%s error occurred while calculating PESQ", e)
            return 0.0
|
||||
|
||||
|
||||
class LossWrapper(nn.Module):
|
||||
"""
|
||||
Combine multiple metrics of same nature.
|
||||
for example, ["mea","mae"]
|
||||
|
|
@ -137,4 +176,6 @@ LOSS_MAP = {
|
|||
"mae": mean_absolute_error,
|
||||
"mse": mean_squared_error,
|
||||
"SI-SDR": Si_SDR,
|
||||
"pesq": Pesq,
|
||||
"stoi": Stoi,
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue