add pesq/stoi
This commit is contained in:
parent
3e654d10a7
commit
5945ddccaa
|
|
@ -1,5 +1,9 @@
|
||||||
|
import logging
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
|
||||||
|
from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility
|
||||||
|
|
||||||
|
|
||||||
class mean_squared_error(nn.Module):
|
class mean_squared_error(nn.Module):
|
||||||
|
|
@ -12,6 +16,7 @@ class mean_squared_error(nn.Module):
|
||||||
|
|
||||||
self.loss_fun = nn.MSELoss(reduction=reduction)
|
self.loss_fun = nn.MSELoss(reduction=reduction)
|
||||||
self.higher_better = False
|
self.higher_better = False
|
||||||
|
self.name = "mse"
|
||||||
|
|
||||||
def forward(self, prediction: torch.Tensor, target: torch.Tensor):
|
def forward(self, prediction: torch.Tensor, target: torch.Tensor):
|
||||||
|
|
||||||
|
|
@ -34,6 +39,7 @@ class mean_absolute_error(nn.Module):
|
||||||
|
|
||||||
self.loss_fun = nn.L1Loss(reduction=reduction)
|
self.loss_fun = nn.L1Loss(reduction=reduction)
|
||||||
self.higher_better = False
|
self.higher_better = False
|
||||||
|
self.name = "mae"
|
||||||
|
|
||||||
def forward(self, prediction: torch.Tensor, target: torch.Tensor):
|
def forward(self, prediction: torch.Tensor, target: torch.Tensor):
|
||||||
|
|
||||||
|
|
@ -46,13 +52,12 @@ class mean_absolute_error(nn.Module):
|
||||||
return self.loss_fun(prediction, target)
|
return self.loss_fun(prediction, target)
|
||||||
|
|
||||||
|
|
||||||
class Si_SDR(nn.Module):
|
class Si_SDR:
|
||||||
"""
|
"""
|
||||||
SI-SDR metric based on SDR – HALF-BAKED OR WELL DONE?(https://arxiv.org/pdf/1811.02508.pdf)
|
SI-SDR metric based on SDR – HALF-BAKED OR WELL DONE?(https://arxiv.org/pdf/1811.02508.pdf)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, reduction: str = "mean"):
|
def __init__(self, reduction: str = "mean"):
|
||||||
super().__init__()
|
|
||||||
if reduction in ["sum", "mean", None]:
|
if reduction in ["sum", "mean", None]:
|
||||||
self.reduction = reduction
|
self.reduction = reduction
|
||||||
else:
|
else:
|
||||||
|
|
@ -60,8 +65,9 @@ class Si_SDR(nn.Module):
|
||||||
"Invalid reduction, valid options are sum, mean, None"
|
"Invalid reduction, valid options are sum, mean, None"
|
||||||
)
|
)
|
||||||
self.higher_better = False
|
self.higher_better = False
|
||||||
|
self.name = "Si-SDR"
|
||||||
|
|
||||||
def forward(self, prediction: torch.Tensor, target: torch.Tensor):
|
def __call__(self, prediction: torch.Tensor, target: torch.Tensor):
|
||||||
|
|
||||||
if prediction.size() != target.size() or target.ndim < 3:
|
if prediction.size() != target.size() or target.ndim < 3:
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
|
|
@ -90,7 +96,40 @@ class Si_SDR(nn.Module):
|
||||||
return si_sdr
|
return si_sdr
|
||||||
|
|
||||||
|
|
||||||
class Avergeloss(nn.Module):
|
class Stoi:
|
||||||
|
"""
|
||||||
|
STOI (Short-Time Objective Intelligibility, see [2,3]), a wrapper for the pystoi package [1].
|
||||||
|
Note that input will be moved to cpu to perform the metric calculation.
|
||||||
|
parameters:
|
||||||
|
sr: int
|
||||||
|
sampling rate
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, sr: int):
|
||||||
|
self.sr = sr
|
||||||
|
self.stoi = ShortTimeObjectiveIntelligibility(fs=sr)
|
||||||
|
self.name = "stoi"
|
||||||
|
|
||||||
|
def __call__(self, prediction: torch.Tensor, target: torch.Tensor):
|
||||||
|
|
||||||
|
return self.stoi(prediction, target)
|
||||||
|
|
||||||
|
|
||||||
|
class Pesq:
|
||||||
|
def __init__(self, sr: int, mode="nb"):
|
||||||
|
|
||||||
|
self.pesq = PerceptualEvaluationSpeechQuality(fs=sr, mode=mode)
|
||||||
|
self.name = "pesq"
|
||||||
|
|
||||||
|
def __call__(self, prediction: torch.Tensor, target: torch.Tensor):
|
||||||
|
try:
|
||||||
|
return self.pesq(prediction, target)
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning(f"{e} error occured while calculating PESQ")
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
|
||||||
|
class LossWrapper(nn.Module):
|
||||||
"""
|
"""
|
||||||
Combine multiple metics of same nature.
|
Combine multiple metics of same nature.
|
||||||
for example, ["mea","mae"]
|
for example, ["mea","mae"]
|
||||||
|
|
@ -137,4 +176,6 @@ LOSS_MAP = {
|
||||||
"mae": mean_absolute_error,
|
"mae": mean_absolute_error,
|
||||||
"mse": mean_squared_error,
|
"mse": mean_squared_error,
|
||||||
"SI-SDR": Si_SDR,
|
"SI-SDR": Si_SDR,
|
||||||
|
"pesq": Pesq,
|
||||||
|
"stoi": Stoi,
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue