ensure loss direction

This commit is contained in:
shahules786 2022-09-30 15:19:06 +05:30
parent fffdf02b93
commit 1f4947103f
1 changed files with 9 additions and 2 deletions

View File

@ -1,5 +1,3 @@
from modulefinder import Module
from turtle import forward
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -10,6 +8,7 @@ class mean_squared_error(nn.Module):
super().__init__() super().__init__()
self.loss_fun = nn.MSELoss(reduction=reduction) self.loss_fun = nn.MSELoss(reduction=reduction)
self.higher_better = False
def forward(self,prediction:torch.Tensor, target: torch.Tensor): def forward(self,prediction:torch.Tensor, target: torch.Tensor):
@ -25,6 +24,7 @@ class mean_absolute_error(nn.Module):
super().__init__() super().__init__()
self.loss_fun = nn.L1Loss(reduction=reduction) self.loss_fun = nn.L1Loss(reduction=reduction)
self.higher_better = False
def forward(self, prediction:torch.Tensor, target: torch.Tensor): def forward(self, prediction:torch.Tensor, target: torch.Tensor):
@ -45,6 +45,7 @@ class Si_SDR(nn.Module):
self.reduction = reduction self.reduction = reduction
else: else:
raise TypeError("Invalid reduction, valid options are sum, mean, None") raise TypeError("Invalid reduction, valid options are sum, mean, None")
self.higher_better = False
def forward(self,prediction:torch.Tensor, target:torch.Tensor): def forward(self,prediction:torch.Tensor, target:torch.Tensor):
@ -76,6 +77,12 @@ class Avergeloss(nn.Module):
super().__init__() super().__init__()
self.valid_losses = nn.ModuleList() self.valid_losses = nn.ModuleList()
direction = [getattr(LOSS_MAP[loss](),"higher_better") for loss in losses]
if len(set(direction)) > 1:
raise ValueError("all cost functions should be of same nature, maximize or minimize!")
self.higher_better = direction[0]
for loss in losses: for loss in losses:
loss = self.validate_loss(loss) loss = self.validate_loss(loss)
self.valid_losses.append(loss()) self.valid_losses.append(loss())