batchnorm
This commit is contained in:
parent
b1144e7b81
commit
e932dc6c75
|
|
@ -0,0 +1,72 @@
|
|||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class ComplexBatchNorm(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_features: int,
|
||||
eps: float = 1e-5,
|
||||
momentum: bool = True,
|
||||
affine: bool = True,
|
||||
track_running_stats: bool = True,
|
||||
):
|
||||
self.num_features = num_features // 2
|
||||
self.affine = affine
|
||||
self.momentum = momentum
|
||||
self.track_running_stats = track_running_stats
|
||||
|
||||
if self.affine:
|
||||
values = torch.Tensor(self.num_features)
|
||||
self.Wrr = nn.parameter.Parameter(values)
|
||||
self.Wri = nn.parameter.Parameter(values)
|
||||
self.Wii = nn.parameter.Parameter(values)
|
||||
self.Br = nn.parameter.Parameter(values)
|
||||
self.Bi = nn.parameter.Parameter(values)
|
||||
else:
|
||||
self.register_parameter("Wrr", None)
|
||||
self.register_parameter("Wri", None)
|
||||
self.register_parameter("Wii", None)
|
||||
self.register_parameter("Br", None)
|
||||
self.register_parameter("Bi", None)
|
||||
|
||||
if self.track_running_stats:
|
||||
values = torch.Tensor(self.num_features)
|
||||
self.register_buffer("Mean_real", values)
|
||||
self.register_buffer("Mean_imag", values)
|
||||
self.register_buffer("Var_rr", values)
|
||||
self.register_buffer("Var_ri", values)
|
||||
self.register_buffer("Var_ii", values)
|
||||
self.register_buffer(
|
||||
"num_batches_tracked", torch.tensor(0, dtype=torch.long)
|
||||
)
|
||||
else:
|
||||
self.register_parameter("Mean_real", None)
|
||||
self.register_parameter("Mean_imag", None)
|
||||
self.register_parameter("Var_rr", None)
|
||||
self.register_parameter("Var_ri", None)
|
||||
self.register_parameter("Var_ii", None)
|
||||
self.register_parameter("num_batches_tracked", None)
|
||||
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
if self.affine:
|
||||
self.Wrr.data.fill_(1)
|
||||
self.Wii.data.fill(1)
|
||||
self.Wri.data.uniform_(-0.9, 0.9)
|
||||
self.Br.data.fill_(0)
|
||||
self.Bi.data.fill_(0)
|
||||
self.reset_running_stats()
|
||||
|
||||
def reset_running_stats(self):
|
||||
if self.track_running_stats:
|
||||
self.Mean_real.zero_()
|
||||
self.Mean_imag.zero_()
|
||||
self.Var_rr.fill_(1)
|
||||
self.Var_ri.zero_()
|
||||
self.Var_ii.fill_(1)
|
||||
self.num_batches_tracked.zero_()
|
||||
|
||||
def forward(self, input):
|
||||
pass
|
||||
Loading…
Reference in New Issue