rename module
This commit is contained in:
parent
981763207a
commit
b98599f21e
|
|
@ -76,6 +76,11 @@ class ComplexBatchNorm2D(nn.Module):
|
|||
self.Var_ii.fill_(1)
|
||||
self.num_batches_tracked.zero_()
|
||||
|
||||
def extra_repr(self):
|
||||
return "{num_features}, eps={eps}, momentum={momentum}, affine={affine}, track_running_stats={track_running_stats}".format(
|
||||
**self.__dict__
|
||||
)
|
||||
|
||||
def forward(self, input):
|
||||
|
||||
real, imag = torch.chunk(input, 2, 1)
|
||||
|
|
@ -159,3 +164,17 @@ class ComplexBatchNorm2D(nn.Module):
|
|||
|
||||
outputs = torch.cat([yr, yi], 1)
|
||||
return outputs
|
||||
|
||||
|
||||
class ComplexRelu(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.real_relu = nn.PReLU()
|
||||
self.imag_relu = nn.PReLU()
|
||||
|
||||
def forward(self, input):
|
||||
|
||||
real, imag = torch.chunk(input, 2, 1)
|
||||
real = self.real_relu(real)
|
||||
imag = self.imag_relu(imag)
|
||||
return torch.cat([real, imag], dim=1)
|
||||
Loading…
Reference in New Issue