add SI-SDR
This commit is contained in:
parent
25568bb84a
commit
838b7d2357
|
|
@ -37,16 +37,38 @@ class mean_absolute_error(nn.Module):
|
||||||
class Si_SDR(nn.Module):
|
class Si_SDR(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self
|
self,
|
||||||
|
reduction:str="mean"
|
||||||
):
|
):
|
||||||
pass
|
super().__init__()
|
||||||
|
if reduction in ["sum","mean",None]:
|
||||||
|
self.reduction = reduction
|
||||||
|
else:
|
||||||
|
raise TypeError("Invalid reduction, valid options are sum, mean, None")
|
||||||
|
|
||||||
def forward(self,prediction:torch.Tensor, target:torch.Tensor):
|
def forward(self,prediction:torch.Tensor, target:torch.Tensor):
|
||||||
|
|
||||||
if prediction.size() != target.size() or target.ndim < 3:
|
if prediction.size() != target.size() or target.ndim < 3:
|
||||||
raise TypeError(f"""Inputs must be of the same shape (batch_size,channels,samples)
|
raise TypeError(f"""Inputs must be of the same shape (batch_size,channels,samples)
|
||||||
got {prediction.size()} and {target.size()} instead""")
|
got {prediction.size()} and {target.size()} instead""")
|
||||||
prediction,target = prediction.unsqueeze(1),target.unsqueeze(1)
|
|
||||||
|
target_energy = torch.sum(target**2,keepdim=True,dim=-1)
|
||||||
|
scaling_factor = torch.sum(prediction*target,keepdim=True,dim=-1) / target_energy
|
||||||
|
target_projection = target * scaling_factor
|
||||||
|
noise = prediction - target_projection
|
||||||
|
ratio = torch.sum(target_projection**2,dim=-1) / torch.sum(noise**2,dim=-1)
|
||||||
|
si_sdr = 10*torch.log10(ratio).mean(dim=-1)
|
||||||
|
|
||||||
|
if self.reduction == "sum":
|
||||||
|
si_sdr = si_sdr.sum()
|
||||||
|
elif self.reduction == "mean":
|
||||||
|
si_sdr = si_sdr.mean()
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return si_sdr
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Avergeloss(nn.Module):
|
class Avergeloss(nn.Module):
|
||||||
|
|
||||||
|
|
@ -75,6 +97,8 @@ class Avergeloss(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
LOSS_MAP = {"mae":mean_absolute_error, "mse": mean_squared_error}
|
LOSS_MAP = {"mae":mean_absolute_error,
|
||||||
|
"mse": mean_squared_error,
|
||||||
|
"SI-SDR":Si_SDR}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue