add complex-cat

This commit is contained in:
shahules786 2022-11-07 10:24:47 +05:30
parent 60fc4607d0
commit d7f3847917
1 changed files with 26 additions and 11 deletions

View File

@ -7,7 +7,7 @@ class ComplexBatchNorm2D(nn.Module):
self, self,
num_features: int, num_features: int,
eps: float = 1e-5, eps: float = 1e-5,
momentum: bool = True, momentum: float = 0.1,
affine: bool = True, affine: bool = True,
track_running_stats: bool = True, track_running_stats: bool = True,
): ):
@ -25,12 +25,11 @@ class ComplexBatchNorm2D(nn.Module):
self.eps = eps self.eps = eps
if self.affine: if self.affine:
values = torch.Tensor(self.num_features) self.Wrr = nn.parameter.Parameter(torch.Tensor(self.num_features))
self.Wrr = nn.parameter.Parameter(values) self.Wri = nn.parameter.Parameter(torch.Tensor(self.num_features))
self.Wri = nn.parameter.Parameter(values) self.Wii = nn.parameter.Parameter(torch.Tensor(self.num_features))
self.Wii = nn.parameter.Parameter(values) self.Br = nn.parameter.Parameter(torch.Tensor(self.num_features))
self.Br = nn.parameter.Parameter(values) self.Bi = nn.parameter.Parameter(torch.Tensor(self.num_features))
self.Bi = nn.parameter.Parameter(values)
else: else:
self.register_parameter("Wrr", None) self.register_parameter("Wrr", None)
self.register_parameter("Wri", None) self.register_parameter("Wri", None)
@ -39,7 +38,7 @@ class ComplexBatchNorm2D(nn.Module):
self.register_parameter("Bi", None) self.register_parameter("Bi", None)
if self.track_running_stats: if self.track_running_stats:
values = torch.Tensor(self.num_features) values = torch.zeros(self.num_features)
self.register_buffer("Mean_real", values) self.register_buffer("Mean_real", values)
self.register_buffer("Mean_imag", values) self.register_buffer("Mean_imag", values)
self.register_buffer("Var_rr", values) self.register_buffer("Var_rr", values)
@ -111,8 +110,8 @@ class ComplexBatchNorm2D(nn.Module):
batch_mean_real = self.Mean_real.view(vdim) batch_mean_real = self.Mean_real.view(vdim)
batch_mean_imag = self.Mean_imag.view(vdim) batch_mean_imag = self.Mean_imag.view(vdim)
real -= batch_mean_real real = real - batch_mean_real
imag -= batch_mean_imag imag = imag - batch_mean_imag
if training: if training:
batch_var_rr = real * real batch_var_rr = real * real
@ -141,7 +140,7 @@ class ComplexBatchNorm2D(nn.Module):
s = batch_var_rr * batch_var_ii - batch_var_ri * batch_var_ri s = batch_var_rr * batch_var_ii - batch_var_ri * batch_var_ri
t = (tau + 2 * s).sqrt() t = (tau + 2 * s).sqrt()
rst = 1 / (s * t) rst = (s * t).reciprocal()
Urr = (batch_var_ii + s) * rst Urr = (batch_var_ii + s) * rst
Uri = -batch_var_ri * rst Uri = -batch_var_ri * rst
Uii = (batch_var_rr + s) * rst Uii = (batch_var_rr + s) * rst
@ -162,6 +161,10 @@ class ComplexBatchNorm2D(nn.Module):
yr = (Zrr * real) + (Zri * imag) yr = (Zrr * real) + (Zri * imag)
yi = (Zir * real) + (Zii * imag) yi = (Zir * real) + (Zii * imag)
if self.affine:
yr = yr + self.Br.view(vdim)
yi = yi + self.Bi.view(vdim)
outputs = torch.cat([yr, yi], 1) outputs = torch.cat([yr, yi], 1)
return outputs return outputs
@ -178,3 +181,15 @@ class ComplexRelu(nn.Module):
real = self.real_relu(real) real = self.real_relu(real)
imag = self.imag_relu(imag) imag = self.imag_relu(imag)
return torch.cat([real, imag], dim=1) return torch.cat([real, imag], dim=1)
def complex_cat(inputs, axis=1):
real, imag = [], []
for data in inputs:
real_data, imag_data = torch.chunk(data, 2, axis)
real.append(real_data)
imag.append(imag_data)
real = torch.cat(real, axis)
imag = torch.cat(imag, axis)
return torch.cat([real, imag], axis)