From f8c8884ce9f6413b27967c1392ba6381a9d671df Mon Sep 17 00:00:00 2001 From: shahules786 Date: Sat, 10 Sep 2022 11:42:14 +0530 Subject: [PATCH] average loss --- enhancer/utils/loss.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/enhancer/utils/loss.py b/enhancer/utils/loss.py index b9c8bee..c410914 100644 --- a/enhancer/utils/loss.py +++ b/enhancer/utils/loss.py @@ -1,3 +1,4 @@ +from turtle import forward import torch import torch.nn as nn @@ -22,7 +23,33 @@ class mean_absolute_error(nn.Module): def forward(self, prediction:torch.Tensor, target: torch.Tensor): return self.loss_fun(prediction, target) + +class Avergeloss(nn.Module): + + def __init__(self,losses): + + self.valid_losses = nn.ModuleList() + for loss in losses: + loss = self.validate_loss(loss) + self.valid_losses.append(loss()) + + + def validate_loss(self,loss:str): + if loss not in LOSS_MAP.keys(): + raise ValueError() + else: + return LOSS_MAP[loss] + + def forward(self,prediction:torch.Tensor, target:torch.Tensor): + loss = 0.0 + for loss_fun in self.valid_losses: + loss += loss_fun(prediction, target) + return loss + + + + LOSS_MAP = {"mea":mean_absolute_error, "mse": mean_squared_error}