add direction si-snr
This commit is contained in:
		
							parent
							
								
									234e1a89de
								
							
						
					
					
						commit
						82308750dc
					
				|  | @ -66,7 +66,7 @@ class Si_SDR: | ||||||
|             raise TypeError( |             raise TypeError( | ||||||
|                 "Invalid reduction, valid options are sum, mean, None" |                 "Invalid reduction, valid options are sum, mean, None" | ||||||
|             ) |             ) | ||||||
|         self.higher_better = False |         self.higher_better = True | ||||||
|         self.name = "si-sdr" |         self.name = "si-sdr" | ||||||
| 
 | 
 | ||||||
|     def __call__(self, prediction: torch.Tensor, target: torch.Tensor): |     def __call__(self, prediction: torch.Tensor, target: torch.Tensor): | ||||||
|  | @ -183,11 +183,34 @@ class LossWrapper(nn.Module): | ||||||
|         return loss |         return loss | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | class Si_snr(nn.Module): | ||||||
|  |     """ | ||||||
|  |     SI-SNR | ||||||
|  |     """ | ||||||
|  | 
 | ||||||
|  |     def __init__(self, **kwargs): | ||||||
|  |         super().__init__() | ||||||
|  | 
 | ||||||
|  |         self.loss_fun = ScaleInvariantSignalNoiseRatio(**kwargs) | ||||||
|  |         self.higher_better = True | ||||||
|  |         self.name = "si_snr" | ||||||
|  | 
 | ||||||
|  |     def forward(self, prediction: torch.Tensor, target: torch.Tensor): | ||||||
|  | 
 | ||||||
|  |         if prediction.size() != target.size() or target.ndim < 3: | ||||||
|  |             raise TypeError( | ||||||
|  |                 f"""Inputs must be of the same shape (batch_size,channels,samples) | ||||||
|  |                     got {prediction.size()} and {target.size()} instead""" | ||||||
|  |             ) | ||||||
|  | 
 | ||||||
|  |         return self.loss_fun(prediction, target) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| LOSS_MAP = { | LOSS_MAP = { | ||||||
|     "mae": mean_absolute_error, |     "mae": mean_absolute_error, | ||||||
|     "mse": mean_squared_error, |     "mse": mean_squared_error, | ||||||
|     "si-sdr": Si_SDR, |     "si-sdr": Si_SDR, | ||||||
|     "pesq": Pesq, |     "pesq": Pesq, | ||||||
|     "stoi": Stoi, |     "stoi": Stoi, | ||||||
|     "si-snr": ScaleInvariantSignalNoiseRatio, |     "si-snr": Si_snr, | ||||||
| } | } | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	 shahules786
						shahules786