add complex-cat
This commit is contained in:
		
							parent
							
								
									60fc4607d0
								
							
						
					
					
						commit
						d7f3847917
					
				|  | @ -7,7 +7,7 @@ class ComplexBatchNorm2D(nn.Module): | |||
|         self, | ||||
|         num_features: int, | ||||
|         eps: float = 1e-5, | ||||
|         momentum: bool = True, | ||||
|         momentum: float = 0.1, | ||||
|         affine: bool = True, | ||||
|         track_running_stats: bool = True, | ||||
|     ): | ||||
|  | @ -25,12 +25,11 @@ class ComplexBatchNorm2D(nn.Module): | |||
|         self.eps = eps | ||||
| 
 | ||||
|         if self.affine: | ||||
|             values = torch.Tensor(self.num_features) | ||||
|             self.Wrr = nn.parameter.Parameter(values) | ||||
|             self.Wri = nn.parameter.Parameter(values) | ||||
|             self.Wii = nn.parameter.Parameter(values) | ||||
|             self.Br = nn.parameter.Parameter(values) | ||||
|             self.Bi = nn.parameter.Parameter(values) | ||||
|             self.Wrr = nn.parameter.Parameter(torch.Tensor(self.num_features)) | ||||
|             self.Wri = nn.parameter.Parameter(torch.Tensor(self.num_features)) | ||||
|             self.Wii = nn.parameter.Parameter(torch.Tensor(self.num_features)) | ||||
|             self.Br = nn.parameter.Parameter(torch.Tensor(self.num_features)) | ||||
|             self.Bi = nn.parameter.Parameter(torch.Tensor(self.num_features)) | ||||
|         else: | ||||
|             self.register_parameter("Wrr", None) | ||||
|             self.register_parameter("Wri", None) | ||||
|  | @ -39,7 +38,7 @@ class ComplexBatchNorm2D(nn.Module): | |||
|             self.register_parameter("Bi", None) | ||||
| 
 | ||||
|         if self.track_running_stats: | ||||
|             values = torch.Tensor(self.num_features) | ||||
|             values = torch.zeros(self.num_features) | ||||
|             self.register_buffer("Mean_real", values) | ||||
|             self.register_buffer("Mean_imag", values) | ||||
|             self.register_buffer("Var_rr", values) | ||||
|  | @ -111,8 +110,8 @@ class ComplexBatchNorm2D(nn.Module): | |||
|             batch_mean_real = self.Mean_real.view(vdim) | ||||
|             batch_mean_imag = self.Mean_imag.view(vdim) | ||||
| 
 | ||||
|         real -= batch_mean_real | ||||
|         imag -= batch_mean_imag | ||||
|         real = real - batch_mean_real | ||||
|         imag = imag - batch_mean_imag | ||||
| 
 | ||||
|         if training: | ||||
|             batch_var_rr = real * real | ||||
|  | @ -141,7 +140,7 @@ class ComplexBatchNorm2D(nn.Module): | |||
|         s = batch_var_rr * batch_var_ii - batch_var_ri * batch_var_ri | ||||
|         t = (tau + 2 * s).sqrt() | ||||
| 
 | ||||
|         rst = 1 / (s * t) | ||||
|         rst = (s * t).reciprocal() | ||||
|         Urr = (batch_var_ii + s) * rst | ||||
|         Uri = -batch_var_ri * rst | ||||
|         Uii = (batch_var_rr + s) * rst | ||||
|  | @ -162,6 +161,10 @@ class ComplexBatchNorm2D(nn.Module): | |||
|         yr = (Zrr * real) + (Zri * imag) | ||||
|         yi = (Zir * real) + (Zii * imag) | ||||
| 
 | ||||
|         if self.affine: | ||||
|             yr = yr + self.Br.view(vdim) | ||||
|             yi = yi + self.Bi.view(vdim) | ||||
| 
 | ||||
|         outputs = torch.cat([yr, yi], 1) | ||||
|         return outputs | ||||
| 
 | ||||
|  | @ -178,3 +181,15 @@ class ComplexRelu(nn.Module): | |||
|         real = self.real_relu(real) | ||||
|         imag = self.imag_relu(imag) | ||||
|         return torch.cat([real, imag], dim=1) | ||||
| 
 | ||||
| 
 | ||||
| def complex_cat(inputs, axis=1): | ||||
| 
 | ||||
|     real, imag = [], [] | ||||
|     for data in inputs: | ||||
|         real_data, imag_data = torch.chunk(data, 2, axis) | ||||
|         real.append(real_data) | ||||
|         imag.append(imag_data) | ||||
|     real = torch.cat(real, axis) | ||||
|     imag = torch.cat(imag, axis) | ||||
|     return torch.cat([real, imag], axis) | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	 shahules786
						shahules786