complex batchnorm 2d

This commit is contained in:
shahules786 2022-11-03 16:05:55 +05:30
parent e932dc6c75
commit da1b986d31
1 changed files with 92 additions and 3 deletions

View File

@ -2,7 +2,7 @@ import torch
from torch import nn from torch import nn
class ComplexBatchNorm(nn.Module): class ComplexBatchNorm2D(nn.Module):
def __init__( def __init__(
self, self,
num_features: int, num_features: int,
@ -11,10 +11,18 @@ class ComplexBatchNorm(nn.Module):
affine: bool = True, affine: bool = True,
track_running_stats: bool = True, track_running_stats: bool = True,
): ):
"""
Complex batch normalization 2D
https://arxiv.org/abs/1705.09792
"""
super().__init__()
self.num_features = num_features // 2 self.num_features = num_features // 2
self.affine = affine self.affine = affine
self.momentum = momentum self.momentum = momentum
self.track_running_stats = track_running_stats self.track_running_stats = track_running_stats
self.eps = eps
if self.affine: if self.affine:
values = torch.Tensor(self.num_features) values = torch.Tensor(self.num_features)
@ -53,7 +61,7 @@ class ComplexBatchNorm(nn.Module):
def reset_parameters(self): def reset_parameters(self):
if self.affine: if self.affine:
self.Wrr.data.fill_(1) self.Wrr.data.fill_(1)
self.Wii.data.fill(1) self.Wii.data.fill_(1)
self.Wri.data.uniform_(-0.9, 0.9) self.Wri.data.uniform_(-0.9, 0.9)
self.Br.data.fill_(0) self.Br.data.fill_(0)
self.Bi.data.fill_(0) self.Bi.data.fill_(0)
@ -69,4 +77,85 @@ class ComplexBatchNorm(nn.Module):
self.num_batches_tracked.zero_() self.num_batches_tracked.zero_()
def forward(self, input): def forward(self, input):
pass
real, imag = torch.chunk(input, 2, 1)
exp_avg_factor = 0.0
training = self.training and self.track_running_stats
if training:
self.num_batches_tracked += 1
if self.momentum is None:
exp_avg_factor = 1 / self.num_batches_tracked
else:
exp_avg_factor = self.momentum
redux = [i for i in reversed(range(real.dim())) if i != 1]
vdim = [1] * real.dim()
vdim[1] = real.size(1)
if training:
batch_mean_real, batch_mean_imag = real, imag
for dim in redux:
batch_mean_real = batch_mean_real.mean(dim, keepdim=True)
batch_mean_imag = batch_mean_imag.mean(dim, keepdim=True)
if self.track_running_stats:
self.Mean_real.lerp_(batch_mean_real.squeeze(), exp_avg_factor)
self.Mean_imag.lerp_(batch_mean_imag.squeeze(), exp_avg_factor)
else:
batch_mean_real = self.Mean_real.view(vdim)
batch_mean_imag = self.Mean_imag.view(vdim)
real -= batch_mean_real
imag -= batch_mean_imag
if training:
batch_var_rr = real * real
batch_var_ri = real * imag
batch_var_ii = imag * imag
for dim in redux:
batch_var_rr = batch_var_rr.mean(dim, keepdim=True)
batch_var_ri = batch_var_ri.mean(dim, keepdim=True)
batch_var_ii = batch_var_ii.mean(dim, keepdim=True)
if self.track_running_stats:
self.Var_rr.lerp_(batch_var_rr.squeeze(), exp_avg_factor)
self.Var_ri.lerp_(batch_var_ri.squeeze(), exp_avg_factor)
self.Var_ii.lerp_(batch_var_ii.squeeze(), exp_avg_factor)
batch_var_rr += self.eps
batch_var_ii += self.eps
# Covariance matrics
# | batch_var_rr batch_var_ri |
# | batch_var_ir batch_var_ii | here batch_var_ir == batch_var_ri
# Inverse square root of cov matrix by combining below two formulas
# https://en.wikipedia.org/wiki/Square_root_of_a_2_by_2_matrix
# https://mathworld.wolfram.com/MatrixInverse.html
tau = batch_var_rr + batch_var_ii
s = batch_var_rr * batch_var_ii - batch_var_ri * batch_var_ri
t = (tau + 2 * s).sqrt()
rst = 1 / (s * t)
Urr = (batch_var_ii + s) * rst
Uri = -batch_var_ri * rst
Uii = (batch_var_rr + s) * rst
if self.affine:
Wrr, Wri, Wii = (
self.Wrr.view(vdim),
self.Wri.view(vdim),
self.Wii.view(vdim),
)
Zrr = (Wrr * Urr) + (Wri * Uri)
Zri = (Wrr * Uri) + (Wri * Uii)
Zir = (Wii * Uri) + (Wri * Urr)
Zii = (Wri * Uri) + (Wii * Uii)
else:
Zrr, Zri, Zir, Zii = Urr, Uri, Uri, Uii
yr = (Zrr * real) + (Zri * imag)
yi = (Zir * real) + (Zii * imag)
outputs = torch.cat([yr, yi], 1)
return outputs