add direction si-snr

This commit is contained in:
shahules786 2022-11-07 12:26:58 +05:30
parent 5b635a82a9
commit 69f6bb4926
1 changed files with 25 additions and 2 deletions

View File

@ -66,7 +66,7 @@ class Si_SDR:
raise TypeError(
"Invalid reduction, valid options are sum, mean, None"
)
self.higher_better = False
self.higher_better = True
self.name = "si-sdr"
def __call__(self, prediction: torch.Tensor, target: torch.Tensor):
@ -183,11 +183,34 @@ class LossWrapper(nn.Module):
return loss
class Si_snr(nn.Module):
"""
SI-SNR
"""
def __init__(self, **kwargs):
super().__init__()
self.loss_fun = ScaleInvariantSignalNoiseRatio(**kwargs)
self.higher_better = True
self.name = "si_snr"
def forward(self, prediction: torch.Tensor, target: torch.Tensor):
if prediction.size() != target.size() or target.ndim < 3:
raise TypeError(
f"""Inputs must be of the same shape (batch_size,channels,samples)
got {prediction.size()} and {target.size()} instead"""
)
return self.loss_fun(prediction, target)
LOSS_MAP = {
"mae": mean_absolute_error,
"mse": mean_squared_error,
"si-sdr": Si_SDR,
"pesq": Pesq,
"stoi": Stoi,
"si-snr": ScaleInvariantSignalNoiseRatio,
"si-snr": Si_snr,
}