rename module
This commit is contained in:
parent
981763207a
commit
b98599f21e
|
|
@ -76,6 +76,11 @@ class ComplexBatchNorm2D(nn.Module):
|
||||||
self.Var_ii.fill_(1)
|
self.Var_ii.fill_(1)
|
||||||
self.num_batches_tracked.zero_()
|
self.num_batches_tracked.zero_()
|
||||||
|
|
||||||
|
def extra_repr(self):
|
||||||
|
return "{num_features}, eps={eps}, momentum={momentum}, affine={affine}, track_running_stats={track_running_stats}".format(
|
||||||
|
**self.__dict__
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
|
|
||||||
real, imag = torch.chunk(input, 2, 1)
|
real, imag = torch.chunk(input, 2, 1)
|
||||||
|
|
@ -159,3 +164,17 @@ class ComplexBatchNorm2D(nn.Module):
|
||||||
|
|
||||||
outputs = torch.cat([yr, yi], 1)
|
outputs = torch.cat([yr, yi], 1)
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
class ComplexRelu(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.real_relu = nn.PReLU()
|
||||||
|
self.imag_relu = nn.PReLU()
|
||||||
|
|
||||||
|
def forward(self, input):
|
||||||
|
|
||||||
|
real, imag = torch.chunk(input, 2, 1)
|
||||||
|
real = self.real_relu(real)
|
||||||
|
imag = self.imag_relu(imag)
|
||||||
|
return torch.cat([real, imag], dim=1)
|
||||||
Loading…
Reference in New Issue