200 lines
6.8 KiB
Python
200 lines
6.8 KiB
Python
import torch
|
|
from torch import nn
|
|
|
|
|
|
class ComplexBatchNorm2D(nn.Module):
    """Complex batch normalization (Trabelsi et al., "Deep Complex Networks").

    https://arxiv.org/abs/1705.09792

    The input packs complex data along dim 1: the first half of the channels
    holds real parts, the second half the matching imaginary parts. Each
    complex channel is centred and then whitened with the inverse square root
    of its 2x2 real/imaginary covariance matrix. With ``affine=True`` a
    learnable symmetric 2x2 matrix ``[[Wrr, Wri], [Wri, Wii]]`` and a complex
    bias ``(Br, Bi)`` are applied afterwards.

    Args:
        num_features: number of channels in dim 1 (real + imaginary). The
            module normalizes ``num_features // 2`` complex channels and stores
            that halved value in ``self.num_features``.
        eps: added to the diagonal of the covariance matrix before inversion.
        momentum: weight of the current batch in the running-statistics
            update; ``None`` uses a cumulative average over
            ``num_batches_tracked`` batches.
        affine: whether to learn the 2x2 scaling matrix and complex bias.
        track_running_stats: whether to keep running mean/covariance buffers
            used in eval mode. When False, batch statistics are always used.
    """

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: float = 0.1,
        affine: bool = True,
        track_running_stats: bool = True,
    ):
        super().__init__()
        # Number of *complex* channels (half of the channels in dim 1).
        self.num_features = num_features // 2
        self.affine = affine
        self.momentum = momentum
        self.track_running_stats = track_running_stats
        self.eps = eps

        if self.affine:
            self.Wrr = nn.Parameter(torch.empty(self.num_features))
            self.Wri = nn.Parameter(torch.empty(self.num_features))
            self.Wii = nn.Parameter(torch.empty(self.num_features))
            self.Br = nn.Parameter(torch.empty(self.num_features))
            self.Bi = nn.Parameter(torch.empty(self.num_features))
        else:
            for name in ("Wrr", "Wri", "Wii", "Br", "Bi"):
                self.register_parameter(name, None)

        stat_names = ("Mean_real", "Mean_imag", "Var_rr", "Var_ri", "Var_ii")
        if self.track_running_stats:
            # Each statistic needs its own storage: registering one shared
            # tensor under five names makes every buffer alias the others.
            for name in stat_names:
                self.register_buffer(name, torch.zeros(self.num_features))
            self.register_buffer(
                "num_batches_tracked", torch.tensor(0, dtype=torch.long)
            )
        else:
            for name in stat_names + ("num_batches_tracked",):
                self.register_buffer(name, None)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the affine parameters (if any) and running stats."""
        if self.affine:
            with torch.no_grad():
                self.Wrr.fill_(1)
                self.Wii.fill_(1)
                # Wrr = Wii = 1 with |Wri| < 1 keeps W positive definite.
                self.Wri.uniform_(-0.9, 0.9)
                self.Br.zero_()
                self.Bi.zero_()
        self.reset_running_stats()

    def reset_running_stats(self):
        """Reset running mean to 0, running covariance to the identity and
        the batch counter to 0. Does nothing when stats are not tracked."""
        if self.track_running_stats:
            self.Mean_real.zero_()
            self.Mean_imag.zero_()
            self.Var_rr.fill_(1)
            self.Var_ri.zero_()
            self.Var_ii.fill_(1)
            self.num_batches_tracked.zero_()

    def extra_repr(self):
        return (
            f"{self.num_features}, eps={self.eps}, momentum={self.momentum}, "
            f"affine={self.affine}, track_running_stats={self.track_running_stats}"
        )

    @staticmethod
    def _reduce_mean(x, dims):
        # Mean over each dim in ``dims``, kept as size-1 axes so the result
        # broadcasts back against ``x``.
        for dim in dims:
            x = x.mean(dim, keepdim=True)
        return x

    def forward(self, input):
        """Normalize ``input`` whose dim 1 is [real channels, imag channels].

        Batch statistics are used in training mode, and whenever running
        statistics are not tracked; otherwise the running statistics are used.
        Returns a tensor with the same layout as ``input``.
        """
        real, imag = torch.chunk(input, 2, 1)

        # Same convention as torch.nn.BatchNorm: normalize with batch stats
        # while training or when there is no running estimate to fall back on;
        # update the running estimate only when training AND tracking.
        use_batch_stats = self.training or not self.track_running_stats
        update_running = self.training and self.track_running_stats

        exp_avg_factor = 0.0
        if update_running:
            self.num_batches_tracked += 1
            if self.momentum is None:
                # Cumulative moving average over all batches seen so far.
                exp_avg_factor = 1.0 / float(self.num_batches_tracked)
            else:
                exp_avg_factor = self.momentum

        # Reduce over every dim except the channel dim (1); ``vdim`` reshapes
        # per-channel vectors so they broadcast against ``real`` / ``imag``.
        redux = [i for i in reversed(range(real.dim())) if i != 1]
        vdim = [1] * real.dim()
        vdim[1] = real.size(1)

        if use_batch_stats:
            mean_real = self._reduce_mean(real, redux)
            mean_imag = self._reduce_mean(imag, redux)
        else:
            mean_real = self.Mean_real.view(vdim)
            mean_imag = self.Mean_imag.view(vdim)

        real = real - mean_real
        imag = imag - mean_imag

        if use_batch_stats:
            var_rr = self._reduce_mean(real * real, redux)
            var_ri = self._reduce_mean(real * imag, redux)
            var_ii = self._reduce_mean(imag * imag, redux)
        else:
            var_rr = self.Var_rr.view(vdim)
            var_ri = self.Var_ri.view(vdim)
            var_ii = self.Var_ii.view(vdim)

        if update_running:
            # Running stats are bookkeeping, not part of the model graph;
            # without no_grad the buffers would chain every batch's graph.
            with torch.no_grad():
                self.Mean_real.lerp_(mean_real.reshape(-1), exp_avg_factor)
                self.Mean_imag.lerp_(mean_imag.reshape(-1), exp_avg_factor)
                self.Var_rr.lerp_(var_rr.reshape(-1), exp_avg_factor)
                self.Var_ri.lerp_(var_ri.reshape(-1), exp_avg_factor)
                self.Var_ii.lerp_(var_ii.reshape(-1), exp_avg_factor)

        # Out-of-place on purpose: in eval mode these are views of the running
        # buffers, and an in-place add would grow them by eps on every call.
        var_rr = var_rr + self.eps
        var_ii = var_ii + self.eps

        # Whitening matrix U = V^{-1/2} for the per-channel covariance
        #   V = | var_rr  var_ri |
        #       | var_ri  var_ii |
        # For a 2x2 SPD matrix with trace tau and determinant delta:
        #   sqrt(V) = (V + s*I) / t,  s = sqrt(delta),  t = sqrt(tau + 2*s)
        # and inverting that closed form gives
        #   V^{-1/2} = | var_ii + s   -var_ri    | / (s * t)
        #              | -var_ri      var_rr + s |
        # https://en.wikipedia.org/wiki/Square_root_of_a_2_by_2_matrix
        # https://mathworld.wolfram.com/MatrixInverse.html
        tau = var_rr + var_ii
        delta = var_rr * var_ii - var_ri * var_ri
        s = delta.sqrt()
        t = (tau + 2 * s).sqrt()

        rst = (s * t).reciprocal()
        Urr = (var_ii + s) * rst
        Uri = -var_ri * rst
        Uii = (var_rr + s) * rst

        if self.affine:
            Wrr = self.Wrr.view(vdim)
            Wri = self.Wri.view(vdim)
            Wii = self.Wii.view(vdim)
            # Z = W @ U with the symmetric W = [[Wrr, Wri], [Wri, Wii]].
            Zrr = (Wrr * Urr) + (Wri * Uri)
            Zri = (Wrr * Uri) + (Wri * Uii)
            Zir = (Wii * Uri) + (Wri * Urr)
            Zii = (Wri * Uri) + (Wii * Uii)
        else:
            Zrr, Zri, Zir, Zii = Urr, Uri, Uri, Uii

        yr = (Zrr * real) + (Zri * imag)
        yi = (Zir * real) + (Zii * imag)

        if self.affine:
            yr = yr + self.Br.view(vdim)
            yi = yi + self.Bi.view(vdim)

        return torch.cat([yr, yi], 1)
|
|
|
|
|
|
class ComplexRelu(nn.Module):
    """Apply an independent PReLU to the real and imaginary halves of dim 1.

    The input is split with ``torch.chunk`` into two pieces along dim 1
    (real parts first, imaginary parts second). Each piece goes through its
    own ``nn.PReLU`` with a learnable slope, and the two results are
    concatenated back along dim 1 in the same order.
    """

    def __init__(self):
        super().__init__()
        # These attribute names form the state_dict keys; keep them stable.
        self.real_relu = nn.PReLU()
        self.imag_relu = nn.PReLU()

    def forward(self, input):
        re_half, im_half = input.chunk(2, dim=1)
        activated = (self.real_relu(re_half), self.imag_relu(im_half))
        return torch.cat(activated, dim=1)
|
|
|
|
|
|
def complex_cat(inputs, axis=1):
    """Concatenate packed complex tensors along ``axis``.

    Each tensor in ``inputs`` is split with ``torch.chunk`` into a real half
    (first) and an imaginary half (second) along ``axis``. The result keeps
    the same packing: all real halves in input order, followed by all
    imaginary halves in input order.
    """
    real_parts = []
    imag_parts = []
    for tensor in inputs:
        re_half, im_half = torch.chunk(tensor, 2, axis)
        real_parts.append(re_half)
        imag_parts.append(im_half)
    return torch.cat(
        [torch.cat(real_parts, axis), torch.cat(imag_parts, axis)], axis
    )
|