From e932dc6c75b73e3a3f25221ac78cbdf9d0dd8862 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Thu, 3 Nov 2022 11:37:58 +0530 Subject: [PATCH] batchnorm --- enhancer/models/complexnn/norm.py | 72 +++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 enhancer/models/complexnn/norm.py diff --git a/enhancer/models/complexnn/norm.py b/enhancer/models/complexnn/norm.py new file mode 100644 index 0000000..eec2130 --- /dev/null +++ b/enhancer/models/complexnn/norm.py @@ -0,0 +1,72 @@ +import torch +from torch import nn + + +class ComplexBatchNorm(nn.Module): + def __init__( + self, + num_features: int, + eps: float = 1e-5, + momentum: bool = True, + affine: bool = True, + track_running_stats: bool = True, + ): + self.num_features = num_features // 2 + self.affine = affine + self.momentum = momentum + self.track_running_stats = track_running_stats + + if self.affine: + values = torch.Tensor(self.num_features) + self.Wrr = nn.parameter.Parameter(values) + self.Wri = nn.parameter.Parameter(values) + self.Wii = nn.parameter.Parameter(values) + self.Br = nn.parameter.Parameter(values) + self.Bi = nn.parameter.Parameter(values) + else: + self.register_parameter("Wrr", None) + self.register_parameter("Wri", None) + self.register_parameter("Wii", None) + self.register_parameter("Br", None) + self.register_parameter("Bi", None) + + if self.track_running_stats: + values = torch.Tensor(self.num_features) + self.register_buffer("Mean_real", values) + self.register_buffer("Mean_imag", values) + self.register_buffer("Var_rr", values) + self.register_buffer("Var_ri", values) + self.register_buffer("Var_ii", values) + self.register_buffer( + "num_batches_tracked", torch.tensor(0, dtype=torch.long) + ) + else: + self.register_parameter("Mean_real", None) + self.register_parameter("Mean_imag", None) + self.register_parameter("Var_rr", None) + self.register_parameter("Var_ri", None) + self.register_parameter("Var_ii", None) + self.register_parameter("num_batches_tracked", None) + + self.reset_parameters() + + def reset_parameters(self): + if self.affine: + self.Wrr.data.fill_(1) + self.Wii.data.fill(1) + self.Wri.data.uniform_(-0.9, 0.9) + self.Br.data.fill_(0) + self.Bi.data.fill_(0) + self.reset_running_stats() + + def reset_running_stats(self): + if self.track_running_stats: + self.Mean_real.zero_() + self.Mean_imag.zero_() + self.Var_rr.fill_(1) + self.Var_ri.zero_() + self.Var_ii.fill_(1) + self.num_batches_tracked.zero_() + + def forward(self, input): + pass