diff --git a/enhancer/loss.py b/enhancer/loss.py index db1d222..9ef90d2 100644 --- a/enhancer/loss.py +++ b/enhancer/loss.py @@ -1,5 +1,9 @@ +import logging + import torch import torch.nn as nn +from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality +from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility class mean_squared_error(nn.Module): @@ -12,6 +16,7 @@ class mean_squared_error(nn.Module): self.loss_fun = nn.MSELoss(reduction=reduction) self.higher_better = False + self.name = "mse" def forward(self, prediction: torch.Tensor, target: torch.Tensor): @@ -34,6 +39,7 @@ class mean_absolute_error(nn.Module): self.loss_fun = nn.L1Loss(reduction=reduction) self.higher_better = False + self.name = "mae" def forward(self, prediction: torch.Tensor, target: torch.Tensor): @@ -46,13 +52,12 @@ class mean_absolute_error(nn.Module): return self.loss_fun(prediction, target) -class Si_SDR(nn.Module): +class Si_SDR: """ SI-SDR metric based on SDR – HALF-BAKED OR WELL DONE?(https://arxiv.org/pdf/1811.02508.pdf) """ def __init__(self, reduction: str = "mean"): - super().__init__() if reduction in ["sum", "mean", None]: self.reduction = reduction else: @@ -60,8 +65,9 @@ class Si_SDR(nn.Module): "Invalid reduction, valid options are sum, mean, None" ) self.higher_better = False + self.name = "Si-SDR" - def forward(self, prediction: torch.Tensor, target: torch.Tensor): + def __call__(self, prediction: torch.Tensor, target: torch.Tensor): if prediction.size() != target.size() or target.ndim < 3: raise TypeError( @@ -90,7 +96,40 @@ class Si_SDR(nn.Module): return si_sdr -class Avergeloss(nn.Module): +class Stoi: + """ + STOI (Short-Time Objective Intelligibility, see [2,3]), a wrapper for the pystoi package [1]. + Note that input will be moved to cpu to perform the metric calculation. + parameters: + sr: int + sampling rate + """ + + def __init__(self, sr: int): + self.sr = sr + self.stoi = ShortTimeObjectiveIntelligibility(fs=sr) + self.name = "stoi" + + def __call__(self, prediction: torch.Tensor, target: torch.Tensor): + + return self.stoi(prediction, target) + + +class Pesq: + def __init__(self, sr: int, mode="nb"): + + self.pesq = PerceptualEvaluationSpeechQuality(fs=sr, mode=mode) + self.name = "pesq" + + def __call__(self, prediction: torch.Tensor, target: torch.Tensor): + try: + return self.pesq(prediction, target) + except Exception as e: + logging.warning(f"{e} error occured while calculating PESQ") + return 0.0 + + +class LossWrapper(nn.Module): """ Combine multiple metics of same nature. for example, ["mea","mae"] @@ -137,4 +176,6 @@ LOSS_MAP = { "mae": mean_absolute_error, "mse": mean_squared_error, "SI-SDR": Si_SDR, + "pesq": Pesq, + "stoi": Stoi, }