class Si_SDR(nn.Module):
    """Scale-invariant signal-to-distortion ratio (SI-SDR), in decibels.

    The prediction is split into its projection onto the target and a
    residual ("noise") term; the returned value is
    ``10 * log10(||projection||^2 / ||noise||^2)``, averaged over the
    second-to-last axis of the per-signal ratios and then reduced over
    what remains according to ``reduction``.

    The accompanying change registers this class in ``LOSS_MAP`` under the
    key ``"SI-SDR"``.

    NOTE(review): the value returned is the SI-SDR itself (higher is better),
    not its negative. If it is minimised directly as a training loss the sign
    probably needs flipping by the caller -- confirm against the training loop.

    Args:
        reduction: ``"mean"``, ``"sum"`` or ``None``. ``None`` returns the
            unreduced values.

    Raises:
        TypeError: if ``reduction`` is not one of the accepted options.
    """

    def __init__(self, reduction: str = "mean"):
        super().__init__()
        if reduction not in ("sum", "mean", None):
            raise TypeError("Invalid reduction, valid options are sum, mean, None")
        self.reduction = reduction

    def forward(self, prediction: torch.Tensor, target: torch.Tensor):
        """Compute SI-SDR between ``prediction`` and ``target``.

        Args:
            prediction: estimated signal, shape (batch_size, channels, samples).
            target: reference signal with the same shape as ``prediction``.

        Returns:
            A scalar tensor for ``"mean"``/``"sum"`` reduction, otherwise the
            unreduced SI-SDR values.

        Raises:
            TypeError: if the shapes differ or ``target`` has fewer than
                three dimensions.
        """
        if prediction.size() != target.size() or target.ndim < 3:
            raise TypeError(
                "Inputs must be of the same shape (batch_size,channels,samples) "
                f"got {prediction.size()} and {target.size()} instead"
            )

        # Machine epsilon keeps every division and the logarithm finite for
        # silent targets (zero energy) and perfect predictions (zero noise),
        # which would otherwise yield NaN / +-inf and poison the gradients.
        eps_dtype = (
            prediction.dtype
            if prediction.is_floating_point()
            else torch.get_default_dtype()
        )
        eps = torch.finfo(eps_dtype).eps

        # Optimal scaling of the target: <prediction, target> / ||target||^2.
        target_energy = torch.sum(target ** 2, dim=-1, keepdim=True) + eps
        cross_energy = torch.sum(prediction * target, dim=-1, keepdim=True) + eps
        scaling_factor = cross_energy / target_energy

        target_projection = target * scaling_factor
        noise = prediction - target_projection

        signal_power = torch.sum(target_projection ** 2, dim=-1) + eps
        noise_power = torch.sum(noise ** 2, dim=-1) + eps
        ratio = signal_power / noise_power

        # Average the per-channel dB values (last remaining axis).
        si_sdr = 10 * torch.log10(ratio).mean(dim=-1)

        if self.reduction == "sum":
            return si_sdr.sum()
        if self.reduction == "mean":
            return si_sdr.mean()
        return si_sdr