This commit is contained in:
shahules786 2022-11-14 10:51:26 +05:30
parent 4e58df5e37
commit d5b17f3745
1 changed files with 2 additions and 2 deletions

View File

@ -192,7 +192,7 @@ class Si_snr(nn.Module):
super().__init__() super().__init__()
self.loss_fun = ScaleInvariantSignalNoiseRatio(**kwargs) self.loss_fun = ScaleInvariantSignalNoiseRatio(**kwargs)
self.higher_better = True self.higher_better = False
self.name = "si_snr" self.name = "si_snr"
def forward(self, prediction: torch.Tensor, target: torch.Tensor): def forward(self, prediction: torch.Tensor, target: torch.Tensor):
@ -203,7 +203,7 @@ class Si_snr(nn.Module):
got {prediction.size()} and {target.size()} instead""" got {prediction.size()} and {target.size()} instead"""
) )
return self.loss_fun(prediction, target) return -1 * self.loss_fun(prediction, target)
LOSS_MAP = { LOSS_MAP = {