add complex-cat
This commit is contained in:
parent
60fc4607d0
commit
d7f3847917
|
|
@ -7,7 +7,7 @@ class ComplexBatchNorm2D(nn.Module):
|
||||||
self,
|
self,
|
||||||
num_features: int,
|
num_features: int,
|
||||||
eps: float = 1e-5,
|
eps: float = 1e-5,
|
||||||
momentum: bool = True,
|
momentum: float = 0.1,
|
||||||
affine: bool = True,
|
affine: bool = True,
|
||||||
track_running_stats: bool = True,
|
track_running_stats: bool = True,
|
||||||
):
|
):
|
||||||
|
|
@ -25,12 +25,11 @@ class ComplexBatchNorm2D(nn.Module):
|
||||||
self.eps = eps
|
self.eps = eps
|
||||||
|
|
||||||
if self.affine:
|
if self.affine:
|
||||||
values = torch.Tensor(self.num_features)
|
self.Wrr = nn.parameter.Parameter(torch.Tensor(self.num_features))
|
||||||
self.Wrr = nn.parameter.Parameter(values)
|
self.Wri = nn.parameter.Parameter(torch.Tensor(self.num_features))
|
||||||
self.Wri = nn.parameter.Parameter(values)
|
self.Wii = nn.parameter.Parameter(torch.Tensor(self.num_features))
|
||||||
self.Wii = nn.parameter.Parameter(values)
|
self.Br = nn.parameter.Parameter(torch.Tensor(self.num_features))
|
||||||
self.Br = nn.parameter.Parameter(values)
|
self.Bi = nn.parameter.Parameter(torch.Tensor(self.num_features))
|
||||||
self.Bi = nn.parameter.Parameter(values)
|
|
||||||
else:
|
else:
|
||||||
self.register_parameter("Wrr", None)
|
self.register_parameter("Wrr", None)
|
||||||
self.register_parameter("Wri", None)
|
self.register_parameter("Wri", None)
|
||||||
|
|
@ -39,7 +38,7 @@ class ComplexBatchNorm2D(nn.Module):
|
||||||
self.register_parameter("Bi", None)
|
self.register_parameter("Bi", None)
|
||||||
|
|
||||||
if self.track_running_stats:
|
if self.track_running_stats:
|
||||||
values = torch.Tensor(self.num_features)
|
values = torch.zeros(self.num_features)
|
||||||
self.register_buffer("Mean_real", values)
|
self.register_buffer("Mean_real", values)
|
||||||
self.register_buffer("Mean_imag", values)
|
self.register_buffer("Mean_imag", values)
|
||||||
self.register_buffer("Var_rr", values)
|
self.register_buffer("Var_rr", values)
|
||||||
|
|
@ -111,8 +110,8 @@ class ComplexBatchNorm2D(nn.Module):
|
||||||
batch_mean_real = self.Mean_real.view(vdim)
|
batch_mean_real = self.Mean_real.view(vdim)
|
||||||
batch_mean_imag = self.Mean_imag.view(vdim)
|
batch_mean_imag = self.Mean_imag.view(vdim)
|
||||||
|
|
||||||
real -= batch_mean_real
|
real = real - batch_mean_real
|
||||||
imag -= batch_mean_imag
|
imag = imag - batch_mean_imag
|
||||||
|
|
||||||
if training:
|
if training:
|
||||||
batch_var_rr = real * real
|
batch_var_rr = real * real
|
||||||
|
|
@ -141,7 +140,7 @@ class ComplexBatchNorm2D(nn.Module):
|
||||||
s = batch_var_rr * batch_var_ii - batch_var_ri * batch_var_ri
|
s = batch_var_rr * batch_var_ii - batch_var_ri * batch_var_ri
|
||||||
t = (tau + 2 * s).sqrt()
|
t = (tau + 2 * s).sqrt()
|
||||||
|
|
||||||
rst = 1 / (s * t)
|
rst = (s * t).reciprocal()
|
||||||
Urr = (batch_var_ii + s) * rst
|
Urr = (batch_var_ii + s) * rst
|
||||||
Uri = -batch_var_ri * rst
|
Uri = -batch_var_ri * rst
|
||||||
Uii = (batch_var_rr + s) * rst
|
Uii = (batch_var_rr + s) * rst
|
||||||
|
|
@ -162,6 +161,10 @@ class ComplexBatchNorm2D(nn.Module):
|
||||||
yr = (Zrr * real) + (Zri * imag)
|
yr = (Zrr * real) + (Zri * imag)
|
||||||
yi = (Zir * real) + (Zii * imag)
|
yi = (Zir * real) + (Zii * imag)
|
||||||
|
|
||||||
|
if self.affine:
|
||||||
|
yr = yr + self.Br.view(vdim)
|
||||||
|
yi = yi + self.Bi.view(vdim)
|
||||||
|
|
||||||
outputs = torch.cat([yr, yi], 1)
|
outputs = torch.cat([yr, yi], 1)
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
@ -178,3 +181,15 @@ class ComplexRelu(nn.Module):
|
||||||
real = self.real_relu(real)
|
real = self.real_relu(real)
|
||||||
imag = self.imag_relu(imag)
|
imag = self.imag_relu(imag)
|
||||||
return torch.cat([real, imag], dim=1)
|
return torch.cat([real, imag], dim=1)
|
||||||
|
|
||||||
|
|
||||||
|
def complex_cat(inputs, axis=1):
|
||||||
|
|
||||||
|
real, imag = [], []
|
||||||
|
for data in inputs:
|
||||||
|
real_data, imag_data = torch.chunk(data, 2, axis)
|
||||||
|
real.append(real_data)
|
||||||
|
imag.append(imag_data)
|
||||||
|
real = torch.cat(real, axis)
|
||||||
|
imag = torch.cat(imag, axis)
|
||||||
|
return torch.cat([real, imag], axis)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue