diff --git a/enhancer/loss.py b/enhancer/loss.py index 3bc6fa2..ef33161 100644 --- a/enhancer/loss.py +++ b/enhancer/loss.py @@ -1,5 +1,3 @@ -from modulefinder import Module -from turtle import forward import torch import torch.nn as nn @@ -10,6 +8,7 @@ class mean_squared_error(nn.Module): super().__init__() self.loss_fun = nn.MSELoss(reduction=reduction) + self.higher_better = False def forward(self,prediction:torch.Tensor, target: torch.Tensor): @@ -25,6 +24,7 @@ class mean_absolute_error(nn.Module): super().__init__() self.loss_fun = nn.L1Loss(reduction=reduction) + self.higher_better = False def forward(self, prediction:torch.Tensor, target: torch.Tensor): @@ -45,6 +45,7 @@ class Si_SDR(nn.Module): self.reduction = reduction else: raise TypeError("Invalid reduction, valid options are sum, mean, None") + self.higher_better = False def forward(self,prediction:torch.Tensor, target:torch.Tensor): @@ -76,6 +77,12 @@ class Avergeloss(nn.Module): super().__init__() self.valid_losses = nn.ModuleList() + + direction = [getattr(LOSS_MAP[loss](),"higher_better") for loss in losses] + if len(set(direction)) > 1: + raise ValueError("all cost functions should be of same nature, maximize or minimize!") + + self.higher_better = direction[0] for loss in losses: loss = self.validate_loss(loss) self.valid_losses.append(loss())