From cfff2bed115c6d054b4827392e3c5ebe987c1cc9 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Sun, 25 Sep 2022 11:58:43 +0530 Subject: [PATCH] rmv loss --- enhancer/utils/loss.py | 65 ------------------------------------------ 1 file changed, 65 deletions(-) delete mode 100644 enhancer/utils/loss.py diff --git a/enhancer/utils/loss.py b/enhancer/utils/loss.py deleted file mode 100644 index 3c1cc06..0000000 --- a/enhancer/utils/loss.py +++ /dev/null @@ -1,65 +0,0 @@ -from turtle import forward -import torch -import torch.nn as nn - - -class mean_squared_error(nn.Module): - - def __init__(self,reduction="mean"): - super().__init__() - - self.loss_fun = nn.MSELoss(reduction=reduction) - - def forward(self,prediction:torch.Tensor, target: torch.Tensor): - - if prediction.size() != target.size() or target.ndim < 3: - raise TypeError(f"""Inputs must be of the same shape (batch_size,channels,samples) - got {prediction.size()} and {target.size()} instead""") - - return self.loss_fun(prediction, target) - -class mean_absolute_error(nn.Module): - - def __init__(self,reduction="mean"): - super().__init__() - - self.loss_fun = nn.L1Loss(reduction=reduction) - - def forward(self, prediction:torch.Tensor, target: torch.Tensor): - - if prediction.size() != target.size() or target.ndim < 3: - raise TypeError(f"""Inputs must be of the same shape (batch_size,channels,samples) - got {prediction.size()} and {target.size()} instead""") - - return self.loss_fun(prediction, target) - -class Avergeloss(nn.Module): - - def __init__(self,losses): - super().__init__() - - self.valid_losses = nn.ModuleList() - for loss in losses: - loss = self.validate_loss(loss) - self.valid_losses.append(loss()) - - - def validate_loss(self,loss:str): - if loss not in LOSS_MAP.keys(): - raise ValueError(f"Invalid loss function {loss}, available loss functions are {LOSS_MAP.keys()}") - else: - return LOSS_MAP[loss] - - def forward(self,prediction:torch.Tensor, target:torch.Tensor): - loss = 0.0 - for loss_fun in self.valid_losses: - loss += loss_fun(prediction, target) - - return loss - - - - -LOSS_MAP = {"mea":mean_absolute_error, "mse": mean_squared_error} - -