add direction si-snr
This commit is contained in:
		
							parent
							
								
									234e1a89de
								
							
						
					
					
						commit
						82308750dc
					
				|  | @ -66,7 +66,7 @@ class Si_SDR: | |||
|             raise TypeError( | ||||
|                 "Invalid reduction, valid options are sum, mean, None" | ||||
|             ) | ||||
|         self.higher_better = False | ||||
|         self.higher_better = True | ||||
|         self.name = "si-sdr" | ||||
| 
 | ||||
|     def __call__(self, prediction: torch.Tensor, target: torch.Tensor): | ||||
|  | @ -183,11 +183,34 @@ class LossWrapper(nn.Module): | |||
|         return loss | ||||
| 
 | ||||
| 
 | ||||
| class Si_snr(nn.Module): | ||||
|     """ | ||||
|     SI-SNR | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, **kwargs): | ||||
|         super().__init__() | ||||
| 
 | ||||
|         self.loss_fun = ScaleInvariantSignalNoiseRatio(**kwargs) | ||||
|         self.higher_better = True | ||||
|         self.name = "si_snr" | ||||
| 
 | ||||
|     def forward(self, prediction: torch.Tensor, target: torch.Tensor): | ||||
| 
 | ||||
|         if prediction.size() != target.size() or target.ndim < 3: | ||||
|             raise TypeError( | ||||
|                 f"""Inputs must be of the same shape (batch_size,channels,samples) | ||||
|                     got {prediction.size()} and {target.size()} instead""" | ||||
|             ) | ||||
| 
 | ||||
|         return self.loss_fun(prediction, target) | ||||
| 
 | ||||
| 
 | ||||
| LOSS_MAP = { | ||||
|     "mae": mean_absolute_error, | ||||
|     "mse": mean_squared_error, | ||||
|     "si-sdr": Si_SDR, | ||||
|     "pesq": Pesq, | ||||
|     "stoi": Stoi, | ||||
|     "si-snr": ScaleInvariantSignalNoiseRatio, | ||||
|     "si-snr": Si_snr, | ||||
| } | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	 shahules786
						shahules786