negate si-snr

This commit is contained in:
shahules786 2022-11-14 10:48:31 +05:30
parent d90db16bce
commit 4a2865ff03
1 changed files with 2 additions and 2 deletions

View File

@ -192,7 +192,7 @@ class Si_snr(nn.Module):
super().__init__() super().__init__()
self.loss_fun = ScaleInvariantSignalNoiseRatio(**kwargs) self.loss_fun = ScaleInvariantSignalNoiseRatio(**kwargs)
self.higher_better = True self.higher_better = False
self.name = "si_snr" self.name = "si_snr"
def forward(self, prediction: torch.Tensor, target: torch.Tensor): def forward(self, prediction: torch.Tensor, target: torch.Tensor):
@ -203,7 +203,7 @@ class Si_snr(nn.Module):
got {prediction.size()} and {target.size()} instead""" got {prediction.size()} and {target.size()} instead"""
) )
return self.loss_fun(prediction, target) return -1 * self.loss_fun(prediction, target)
LOSS_MAP = { LOSS_MAP = {