fix padding & init
This commit is contained in:
parent
70d17f6586
commit
c21f05e307
|
|
@ -59,9 +59,10 @@ class ComplexConv2d(nn.Module):
|
|||
"""
|
||||
complex axis should be always 1 dim
|
||||
"""
|
||||
input = F.pad(input, [self.padding[1], self.padding[1], 0, 0])
|
||||
input = F.pad(input, [self.padding[1], 0, 0, 0])
|
||||
|
||||
real, imag = torch.chunk(input, 2, 1)
|
||||
|
||||
real_real = self.real_conv(real)
|
||||
real_imag = self.imag_conv(real)
|
||||
|
||||
|
|
@ -72,7 +73,6 @@ class ComplexConv2d(nn.Module):
|
|||
imag = real_imag - imag_real
|
||||
|
||||
out = torch.cat([real, imag], 1)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
|
|
@ -116,13 +116,12 @@ class ComplexConvTranspose2d(nn.Module):
|
|||
groups=self.groups,
|
||||
)
|
||||
|
||||
init_weights(self.real_conv)
|
||||
init_weights(self.imag_conv)
|
||||
self.real_conv = init_weights(self.real_conv)
|
||||
self.imag_conv = init_weights(self.imag_conv)
|
||||
|
||||
def forward(self, input):
|
||||
|
||||
real, imag = torch.chunk(input, 2, 1)
|
||||
|
||||
real_real = self.real_conv(real)
|
||||
real_imag = self.imag_conv(real)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue