From 82308750dc67758368e58281c2d525042f65a7eb Mon Sep 17 00:00:00 2001 From: shahules786 Date: Mon, 7 Nov 2022 12:28:25 +0530 Subject: [PATCH] add direction si-snr --- enhancer/loss.py | 27 +++++++++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/enhancer/loss.py b/enhancer/loss.py index ec753d4..75527bb 100644 --- a/enhancer/loss.py +++ b/enhancer/loss.py @@ -66,7 +66,7 @@ class Si_SDR: raise TypeError( "Invalid reduction, valid options are sum, mean, None" ) - self.higher_better = False + self.higher_better = True self.name = "si-sdr" def __call__(self, prediction: torch.Tensor, target: torch.Tensor): @@ -183,11 +183,34 @@ class LossWrapper(nn.Module): return loss +class Si_snr(nn.Module): + """ + SI-SNR + """ + + def __init__(self, **kwargs): + super().__init__() + + self.loss_fun = ScaleInvariantSignalNoiseRatio(**kwargs) + self.higher_better = True + self.name = "si_snr" + + def forward(self, prediction: torch.Tensor, target: torch.Tensor): + + if prediction.size() != target.size() or target.ndim < 3: + raise TypeError( + f"""Inputs must be of the same shape (batch_size,channels,samples) + got {prediction.size()} and {target.size()} instead""" + ) + + return self.loss_fun(prediction, target) + + LOSS_MAP = { "mae": mean_absolute_error, "mse": mean_squared_error, "si-sdr": Si_SDR, "pesq": Pesq, "stoi": Stoi, - "si-snr": ScaleInvariantSignalNoiseRatio, + "si-snr": Si_snr, }