From 756465c2bfdf6eee06a55526765b10f95eba2e9f Mon Sep 17 00:00:00 2001 From: shahules786 Date: Wed, 5 Oct 2022 11:26:34 +0530 Subject: [PATCH] format loss.py --- enhancer/loss.py | 110 +++++++++++++++++++++++++++++------------------ 1 file changed, 68 insertions(+), 42 deletions(-) diff --git a/enhancer/loss.py b/enhancer/loss.py index ef33161..1e156a0 100644 --- a/enhancer/loss.py +++ b/enhancer/loss.py @@ -3,62 +3,82 @@ import torch.nn as nn class mean_squared_error(nn.Module): + """ + Mean squared error / L1 loss + """ - def __init__(self,reduction="mean"): + def __init__(self, reduction="mean"): super().__init__() self.loss_fun = nn.MSELoss(reduction=reduction) self.higher_better = False - def forward(self,prediction:torch.Tensor, target: torch.Tensor): + def forward(self, prediction: torch.Tensor, target: torch.Tensor): if prediction.size() != target.size() or target.ndim < 3: - raise TypeError(f"""Inputs must be of the same shape (batch_size,channels,samples) - got {prediction.size()} and {target.size()} instead""") + raise TypeError( + f"""Inputs must be of the same shape (batch_size,channels,samples) + got {prediction.size()} and {target.size()} instead""" + ) return self.loss_fun(prediction, target) -class mean_absolute_error(nn.Module): - def __init__(self,reduction="mean"): +class mean_absolute_error(nn.Module): + """ + Mean absolute error / L2 loss + """ + + def __init__(self, reduction="mean"): super().__init__() self.loss_fun = nn.L1Loss(reduction=reduction) self.higher_better = False - def forward(self, prediction:torch.Tensor, target: torch.Tensor): + def forward(self, prediction: torch.Tensor, target: torch.Tensor): if prediction.size() != target.size() or target.ndim < 3: - raise TypeError(f"""Inputs must be of the same shape (batch_size,channels,samples) - got {prediction.size()} and {target.size()} instead""") + raise TypeError( + f"""Inputs must be of the same shape (batch_size,channels,samples) + got {prediction.size()} and {target.size()} instead""" + ) return self.loss_fun(prediction, target) -class Si_SDR(nn.Module): - def __init__( - self, - reduction:str="mean" - ): +class Si_SDR(nn.Module): + """ + SI-SDR metric based on SDR – HALF-BAKED OR WELL DONE?(https://arxiv.org/pdf/1811.02508.pdf) + """ + + def __init__(self, reduction: str = "mean"): super().__init__() - if reduction in ["sum","mean",None]: + if reduction in ["sum", "mean", None]: self.reduction = reduction else: - raise TypeError("Invalid reduction, valid options are sum, mean, None") + raise TypeError( + "Invalid reduction, valid options are sum, mean, None" + ) self.higher_better = False - def forward(self,prediction:torch.Tensor, target:torch.Tensor): + def forward(self, prediction: torch.Tensor, target: torch.Tensor): if prediction.size() != target.size() or target.ndim < 3: - raise TypeError(f"""Inputs must be of the same shape (batch_size,channels,samples) - got {prediction.size()} and {target.size()} instead""") - - target_energy = torch.sum(target**2,keepdim=True,dim=-1) - scaling_factor = torch.sum(prediction*target,keepdim=True,dim=-1) / target_energy + raise TypeError( + f"""Inputs must be of the same shape (batch_size,channels,samples) + got {prediction.size()} and {target.size()} instead""" + ) + + target_energy = torch.sum(target**2, keepdim=True, dim=-1) + scaling_factor = ( + torch.sum(prediction * target, keepdim=True, dim=-1) / target_energy + ) target_projection = target * scaling_factor noise = prediction - target_projection - ratio = torch.sum(target_projection**2,dim=-1) / torch.sum(noise**2,dim=-1) - si_sdr = 10*torch.log10(ratio).mean(dim=-1) + ratio = torch.sum(target_projection**2, dim=-1) / torch.sum( + noise**2, dim=-1 + ) + si_sdr = 10 * torch.log10(ratio).mean(dim=-1) if self.reduction == "sum": si_sdr = si_sdr.sum() @@ -66,46 +86,52 @@ class Si_SDR(nn.Module): si_sdr = si_sdr.mean() else: pass - + return si_sdr - class Avergeloss(nn.Module): + """ + Combine multiple metics of same nature. + for example, ["mea","mae"] + """ - def __init__(self,losses): + def __init__(self, losses): super().__init__() self.valid_losses = nn.ModuleList() - - direction = [getattr(LOSS_MAP[loss](),"higher_better") for loss in losses] + + direction = [ + getattr(LOSS_MAP[loss](), "higher_better") for loss in losses + ] if len(set(direction)) > 1: - raise ValueError("all cost functions should be of same nature, maximize or minimize!") + raise ValueError( + "all cost functions should be of same nature, maximize or minimize!" + ) self.higher_better = direction[0] for loss in losses: loss = self.validate_loss(loss) self.valid_losses.append(loss()) - - def validate_loss(self,loss:str): + def validate_loss(self, loss: str): if loss not in LOSS_MAP.keys(): - raise ValueError(f"Invalid loss function {loss}, available loss functions are {tuple([loss for loss in LOSS_MAP.keys()])}") + raise ValueError( + f"Invalid loss function {loss}, available loss functions are {tuple([loss for loss in LOSS_MAP.keys()])}" + ) else: return LOSS_MAP[loss] - def forward(self,prediction:torch.Tensor, target:torch.Tensor): + def forward(self, prediction: torch.Tensor, target: torch.Tensor): loss = 0.0 for loss_fun in self.valid_losses: loss += loss_fun(prediction, target) - + return loss - - - -LOSS_MAP = {"mae":mean_absolute_error, - "mse": mean_squared_error, - "SI-SDR":Si_SDR} - +LOSS_MAP = { + "mae": mean_absolute_error, + "mse": mean_squared_error, + "SI-SDR": Si_SDR, +}