add direction si-snr
This commit is contained in:
parent
5b635a82a9
commit
69f6bb4926
|
|
@ -66,7 +66,7 @@ class Si_SDR:
|
|||
raise TypeError(
|
||||
"Invalid reduction, valid options are sum, mean, None"
|
||||
)
|
||||
self.higher_better = False
|
||||
self.higher_better = True
|
||||
self.name = "si-sdr"
|
||||
|
||||
def __call__(self, prediction: torch.Tensor, target: torch.Tensor):
|
||||
|
|
@ -183,11 +183,34 @@ class LossWrapper(nn.Module):
|
|||
return loss
|
||||
|
||||
|
||||
class Si_snr(nn.Module):
|
||||
"""
|
||||
SI-SNR
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__()
|
||||
|
||||
self.loss_fun = ScaleInvariantSignalNoiseRatio(**kwargs)
|
||||
self.higher_better = True
|
||||
self.name = "si_snr"
|
||||
|
||||
def forward(self, prediction: torch.Tensor, target: torch.Tensor):
|
||||
|
||||
if prediction.size() != target.size() or target.ndim < 3:
|
||||
raise TypeError(
|
||||
f"""Inputs must be of the same shape (batch_size,channels,samples)
|
||||
got {prediction.size()} and {target.size()} instead"""
|
||||
)
|
||||
|
||||
return self.loss_fun(prediction, target)
|
||||
|
||||
|
||||
LOSS_MAP = {
|
||||
"mae": mean_absolute_error,
|
||||
"mse": mean_squared_error,
|
||||
"si-sdr": Si_SDR,
|
||||
"pesq": Pesq,
|
||||
"stoi": Stoi,
|
||||
"si-snr": ScaleInvariantSignalNoiseRatio,
|
||||
"si-snr": Si_snr,
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue