negate
This commit is contained in:
parent
4e58df5e37
commit
d5b17f3745
|
|
@ -192,7 +192,7 @@ class Si_snr(nn.Module):
|
|||
super().__init__()
|
||||
|
||||
self.loss_fun = ScaleInvariantSignalNoiseRatio(**kwargs)
|
||||
self.higher_better = True
|
||||
self.higher_better = False
|
||||
self.name = "si_snr"
|
||||
|
||||
def forward(self, prediction: torch.Tensor, target: torch.Tensor):
|
||||
|
|
@ -203,7 +203,7 @@ class Si_snr(nn.Module):
|
|||
got {prediction.size()} and {target.size()} instead"""
|
||||
)
|
||||
|
||||
return self.loss_fun(prediction, target)
|
||||
return -1 * self.loss_fun(prediction, target)
|
||||
|
||||
|
||||
LOSS_MAP = {
|
||||
|
|
|
|||
Loading…
Reference in New Issue