rename module

This commit is contained in:
shahules786 2022-11-05 16:36:27 +05:30
parent 981763207a
commit b98599f21e
1 changed files with 19 additions and 0 deletions

View File

@ -76,6 +76,11 @@ class ComplexBatchNorm2D(nn.Module):
self.Var_ii.fill_(1)
self.num_batches_tracked.zero_()
def extra_repr(self):
return "{num_features}, eps={eps}, momentum={momentum}, affine={affine}, track_running_stats={track_running_stats}".format(
**self.__dict__
)
def forward(self, input):
real, imag = torch.chunk(input, 2, 1)
@ -159,3 +164,17 @@ class ComplexBatchNorm2D(nn.Module):
outputs = torch.cat([yr, yi], 1)
return outputs
class ComplexRelu(nn.Module):
def __init__(self):
super().__init__()
self.real_relu = nn.PReLU()
self.imag_relu = nn.PReLU()
def forward(self, input):
real, imag = torch.chunk(input, 2, 1)
real = self.real_relu(real)
imag = self.imag_relu(imag)
return torch.cat([real, imag], dim=1)