add SI-SDR

This commit is contained in:
shahules786 2022-09-29 11:00:41 +05:30
parent 25568bb84a
commit 838b7d2357
1 changed files with 28 additions and 4 deletions

View File

@ -37,16 +37,38 @@ class mean_absolute_error(nn.Module):
class Si_SDR(nn.Module):
def __init__(
self
self,
reduction:str="mean"
):
pass
super().__init__()
if reduction in ["sum","mean",None]:
self.reduction = reduction
else:
raise TypeError("Invalid reduction, valid options are sum, mean, None")
def forward(self,prediction:torch.Tensor, target:torch.Tensor):
if prediction.size() != target.size() or target.ndim < 3:
raise TypeError(f"""Inputs must be of the same shape (batch_size,channels,samples)
got {prediction.size()} and {target.size()} instead""")
prediction,target = prediction.unsqueeze(1),target.unsqueeze(1)
target_energy = torch.sum(target**2,keepdim=True,dim=-1)
scaling_factor = torch.sum(prediction*target,keepdim=True,dim=-1) / target_energy
target_projection = target * scaling_factor
noise = prediction - target_projection
ratio = torch.sum(target_projection**2,dim=-1) / torch.sum(noise**2,dim=-1)
si_sdr = 10*torch.log10(ratio).mean(dim=-1)
if self.reduction == "sum":
si_sdr = si_sdr.sum()
elif self.reduction == "mean":
si_sdr = si_sdr.mean()
else:
pass
return si_sdr
class Avergeloss(nn.Module):
@ -75,6 +97,8 @@ class Avergeloss(nn.Module):
LOSS_MAP = {"mae":mean_absolute_error, "mse": mean_squared_error}
LOSS_MAP = {"mae":mean_absolute_error,
"mse": mean_squared_error,
"SI-SDR":Si_SDR}