From da1b986d311579f60445e599d29cdf42306815d8 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Thu, 3 Nov 2022 16:05:55 +0530
Subject: [PATCH] complex batchnorm 2d

---
 enhancer/models/complexnn/norm.py | 121 ++++++++++++++++++++++++++++++-
 1 file changed, 118 insertions(+), 3 deletions(-)

diff --git a/enhancer/models/complexnn/norm.py b/enhancer/models/complexnn/norm.py
index eec2130..5dd0104 100644
--- a/enhancer/models/complexnn/norm.py
+++ b/enhancer/models/complexnn/norm.py
@@ -2,7 +2,7 @@ import torch
 from torch import nn
 
 
-class ComplexBatchNorm(nn.Module):
+class ComplexBatchNorm2D(nn.Module):
     def __init__(
         self,
         num_features: int,
@@ -11,10 +11,18 @@ class ComplexBatchNorm(nn.Module):
         affine: bool = True,
         track_running_stats: bool = True,
     ):
+        """
+        Complex batch normalization 2D
+        https://arxiv.org/abs/1705.09792
+
+
+        """
+
         super().__init__()
         self.num_features = num_features // 2
         self.affine = affine
         self.momentum = momentum
         self.track_running_stats = track_running_stats
+        self.eps = eps
         if self.affine:
             values = torch.Tensor(self.num_features)
@@ -53,7 +61,7 @@ class ComplexBatchNorm(nn.Module):
     def reset_parameters(self):
         if self.affine:
             self.Wrr.data.fill_(1)
-            self.Wii.data.fill(1)
+            self.Wii.data.fill_(1)
             self.Wri.data.uniform_(-0.9, 0.9)
             self.Br.data.fill_(0)
             self.Bi.data.fill_(0)
@@ -69,4 +77,111 @@ class ComplexBatchNorm(nn.Module):
             self.num_batches_tracked.zero_()
 
     def forward(self, input):
-        pass
+        """
+        Whiten the complex input with batch or running statistics.
+
+        Arguments:
+            input: tensor whose dim 1 stacks the real channels followed by
+                the imaginary channels (it is split in half with torch.chunk).
+
+        Returns:
+            tensor with the same layout as input, decorrelated per channel
+            and, when affine is set, scaled by W and shifted by Br / Bi.
+        """
+
+        real, imag = torch.chunk(input, 2, 1)
+        exp_avg_factor = 0.0
+
+        if self.training and self.track_running_stats:
+            self.num_batches_tracked += 1
+            if self.momentum is None:
+                # cumulative moving average
+                exp_avg_factor = 1.0 / float(self.num_batches_tracked)
+            else:
+                exp_avg_factor = self.momentum
+
+        # Batch statistics are used in training mode and whenever no running
+        # statistics are tracked; running statistics are used otherwise.
+        training = self.training or not self.track_running_stats
+
+        redux = [i for i in range(real.dim()) if i != 1]
+        vdim = [1] * real.dim()
+        vdim[1] = real.size(1)
+
+        if training:
+            batch_mean_real = real.mean(redux, keepdim=True)
+            batch_mean_imag = imag.mean(redux, keepdim=True)
+            if self.track_running_stats:
+                # no_grad keeps the buffers out of the autograd graph
+                with torch.no_grad():
+                    self.Mean_real.lerp_(
+                        batch_mean_real.view(-1), exp_avg_factor
+                    )
+                    self.Mean_imag.lerp_(
+                        batch_mean_imag.view(-1), exp_avg_factor
+                    )
+        else:
+            batch_mean_real = self.Mean_real.view(vdim)
+            batch_mean_imag = self.Mean_imag.view(vdim)
+
+        # Out-of-place: real / imag are views of the caller's tensor
+        real = real - batch_mean_real
+        imag = imag - batch_mean_imag
+
+        if training:
+            batch_var_rr = (real * real).mean(redux, keepdim=True)
+            batch_var_ri = (real * imag).mean(redux, keepdim=True)
+            batch_var_ii = (imag * imag).mean(redux, keepdim=True)
+            if self.track_running_stats:
+                with torch.no_grad():
+                    self.Var_rr.lerp_(batch_var_rr.view(-1), exp_avg_factor)
+                    self.Var_ri.lerp_(batch_var_ri.view(-1), exp_avg_factor)
+                    self.Var_ii.lerp_(batch_var_ii.view(-1), exp_avg_factor)
+        else:
+            batch_var_rr = self.Var_rr.view(vdim)
+            batch_var_ri = self.Var_ri.view(vdim)
+            batch_var_ii = self.Var_ii.view(vdim)
+
+        # Out-of-place so the running buffers are never modified here
+        batch_var_rr = batch_var_rr + self.eps
+        batch_var_ii = batch_var_ii + self.eps
+
+        # Covariance matrix
+        # | batch_var_rr batch_var_ri |
+        # | batch_var_ir batch_var_ii | here batch_var_ir == batch_var_ri
+        # Inverse square root of cov matrix by combining below two formulas
+        # https://en.wikipedia.org/wiki/Square_root_of_a_2_by_2_matrix
+        # https://mathworld.wolfram.com/MatrixInverse.html
+
+        tau = batch_var_rr + batch_var_ii
+        delta = batch_var_rr * batch_var_ii - batch_var_ri * batch_var_ri
+        s = delta.sqrt()  # square root of the determinant
+        t = (tau + 2 * s).sqrt()
+
+        rst = 1 / (s * t)
+        Urr = (batch_var_ii + s) * rst
+        Uri = -batch_var_ri * rst
+        Uii = (batch_var_rr + s) * rst
+
+        if self.affine:
+            Wrr, Wri, Wii = (
+                self.Wrr.view(vdim),
+                self.Wri.view(vdim),
+                self.Wii.view(vdim),
+            )
+            Zrr = (Wrr * Urr) + (Wri * Uri)
+            Zri = (Wrr * Uri) + (Wri * Uii)
+            Zir = (Wii * Uri) + (Wri * Urr)
+            Zii = (Wri * Uri) + (Wii * Uii)
+        else:
+            Zrr, Zri, Zir, Zii = Urr, Uri, Uri, Uii
+
+        yr = (Zrr * real) + (Zri * imag)
+        yi = (Zir * real) + (Zii * imag)
+
+        if self.affine:
+            yr = yr + self.Br.view(vdim)
+            yi = yi + self.Bi.view(vdim)
+
+        outputs = torch.cat([yr, yi], 1)
+        return outputs