negate
This commit is contained in:
parent
4e58df5e37
commit
d5b17f3745
|
|
@ -192,7 +192,7 @@ class Si_snr(nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.loss_fun = ScaleInvariantSignalNoiseRatio(**kwargs)
|
self.loss_fun = ScaleInvariantSignalNoiseRatio(**kwargs)
|
||||||
self.higher_better = True
|
self.higher_better = False
|
||||||
self.name = "si_snr"
|
self.name = "si_snr"
|
||||||
|
|
||||||
def forward(self, prediction: torch.Tensor, target: torch.Tensor):
|
def forward(self, prediction: torch.Tensor, target: torch.Tensor):
|
||||||
|
|
@ -203,7 +203,7 @@ class Si_snr(nn.Module):
|
||||||
got {prediction.size()} and {target.size()} instead"""
|
got {prediction.size()} and {target.size()} instead"""
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.loss_fun(prediction, target)
|
return -1 * self.loss_fun(prediction, target)
|
||||||
|
|
||||||
|
|
||||||
LOSS_MAP = {
|
LOSS_MAP = {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue