From b98599f21e0e940246f041395b6cd2fe6f40e451 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Sat, 5 Nov 2022 16:36:27 +0530 Subject: [PATCH] rename module --- .../models/complexnn/{norm.py => utils.py} | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) rename enhancer/models/complexnn/{norm.py => utils.py} (91%) diff --git a/enhancer/models/complexnn/norm.py b/enhancer/models/complexnn/utils.py similarity index 91% rename from enhancer/models/complexnn/norm.py rename to enhancer/models/complexnn/utils.py index 5dd0104..d5de558 100644 --- a/enhancer/models/complexnn/norm.py +++ b/enhancer/models/complexnn/utils.py @@ -76,6 +76,11 @@ class ComplexBatchNorm2D(nn.Module): self.Var_ii.fill_(1) self.num_batches_tracked.zero_() + def extra_repr(self): + return "{num_features}, eps={eps}, momentum={momentum}, affine={affine}, track_running_stats={track_running_stats}".format( + **self.__dict__ + ) + def forward(self, input): real, imag = torch.chunk(input, 2, 1) @@ -159,3 +164,17 @@ class ComplexBatchNorm2D(nn.Module): outputs = torch.cat([yr, yi], 1) return outputs + + +class ComplexRelu(nn.Module): + def __init__(self): + super().__init__() + self.real_relu = nn.PReLU() + self.imag_relu = nn.PReLU() + + def forward(self, input): + + real, imag = torch.chunk(input, 2, 1) + real = self.real_relu(real) + imag = self.imag_relu(imag) + return torch.cat([real, imag], dim=1)