fix padding & init

shahules786 2022-11-07 10:23:46 +05:30
parent 70d17f6586
commit c21f05e307
1 changed file with 4 additions and 5 deletions


@@ -59,9 +59,10 @@ class ComplexConv2d(nn.Module):
        """
        complex axis should be always 1 dim
        """
-       input = F.pad(input, [self.padding[1], self.padding[1], 0, 0])
+       input = F.pad(input, [self.padding[1], 0, 0, 0])
        real, imag = torch.chunk(input, 2, 1)
        real_real = self.real_conv(real)
        real_imag = self.imag_conv(real)
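For reference, F.pad reads the padding list from the last dimension inward, so the first two entries pad the left and right of the final axis. A minimal sketch of the difference between the old two-sided padding and the new one-sided padding, with an illustrative [batch, channels, freq, time] tensor (the shapes below are assumptions, not taken from the repository):

    import torch
    import torch.nn.functional as F

    x = torch.zeros(1, 2, 4, 10)              # illustrative [batch, channels, freq, time]
    pad = 3                                   # stands in for self.padding[1]

    both_sides = F.pad(x, [pad, pad, 0, 0])   # old call: pads the last axis on both sides
    left_only = F.pad(x, [pad, 0, 0, 0])      # new call: pads the last axis on the left only

    print(both_sides.shape)                   # torch.Size([1, 2, 4, 16])
    print(left_only.shape)                    # torch.Size([1, 2, 4, 13])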
@@ -72,7 +73,6 @@ class ComplexConv2d(nn.Module):
        imag = real_imag - imag_real
        out = torch.cat([real, imag], 1)
        return out
@@ -116,13 +116,12 @@ class ComplexConvTranspose2d(nn.Module):
            groups=self.groups,
        )
-       init_weights(self.real_conv)
-       init_weights(self.imag_conv)
+       self.real_conv = init_weights(self.real_conv)
+       self.imag_conv = init_weights(self.imag_conv)

    def forward(self, input):
        real, imag = torch.chunk(input, 2, 1)
        real_real = self.real_conv(real)
        real_imag = self.imag_conv(real)
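Whether the new assignment is strictly needed depends on what init_weights returns, and that helper is not shown in this diff. A hedged sketch of a return-style initializer under which the assignment form is the safe one (the body below is an assumption, not the repository's actual helper):

    import torch.nn as nn

    def init_weights(module: nn.Module) -> nn.Module:
        # Hypothetical initializer: fills the weights and hands the module back.
        if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        return module

    # self.real_conv = init_weights(self.real_conv) keeps whatever init_weights
    # returns; the bare call init_weights(self.real_conv) only works if the
    # helper mutates the module in place rather than wrapping or replacing it.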