add SI-SDR
This commit is contained in:
parent
25568bb84a
commit
838b7d2357
|
|
@ -37,16 +37,38 @@ class mean_absolute_error(nn.Module):
|
|||
class Si_SDR(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self
|
||||
self,
|
||||
reduction:str="mean"
|
||||
):
|
||||
pass
|
||||
super().__init__()
|
||||
if reduction in ["sum","mean",None]:
|
||||
self.reduction = reduction
|
||||
else:
|
||||
raise TypeError("Invalid reduction, valid options are sum, mean, None")
|
||||
|
||||
def forward(self,prediction:torch.Tensor, target:torch.Tensor):
|
||||
|
||||
if prediction.size() != target.size() or target.ndim < 3:
|
||||
raise TypeError(f"""Inputs must be of the same shape (batch_size,channels,samples)
|
||||
got {prediction.size()} and {target.size()} instead""")
|
||||
prediction,target = prediction.unsqueeze(1),target.unsqueeze(1)
|
||||
|
||||
target_energy = torch.sum(target**2,keepdim=True,dim=-1)
|
||||
scaling_factor = torch.sum(prediction*target,keepdim=True,dim=-1) / target_energy
|
||||
target_projection = target * scaling_factor
|
||||
noise = prediction - target_projection
|
||||
ratio = torch.sum(target_projection**2,dim=-1) / torch.sum(noise**2,dim=-1)
|
||||
si_sdr = 10*torch.log10(ratio).mean(dim=-1)
|
||||
|
||||
if self.reduction == "sum":
|
||||
si_sdr = si_sdr.sum()
|
||||
elif self.reduction == "mean":
|
||||
si_sdr = si_sdr.mean()
|
||||
else:
|
||||
pass
|
||||
|
||||
return si_sdr
|
||||
|
||||
|
||||
|
||||
class Avergeloss(nn.Module):
|
||||
|
||||
|
|
@ -75,6 +97,8 @@ class Avergeloss(nn.Module):
|
|||
|
||||
|
||||
|
||||
LOSS_MAP = {"mae":mean_absolute_error, "mse": mean_squared_error}
|
||||
LOSS_MAP = {"mae":mean_absolute_error,
|
||||
"mse": mean_squared_error,
|
||||
"SI-SDR":Si_SDR}
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue