fix padding & init

This commit is contained in:
shahules786 2022-11-07 10:23:46 +05:30
parent 70d17f6586
commit c21f05e307
1 changed files with 4 additions and 5 deletions

View File

@ -59,9 +59,10 @@ class ComplexConv2d(nn.Module):
""" """
complex axis should be always 1 dim complex axis should be always 1 dim
""" """
input = F.pad(input, [self.padding[1], self.padding[1], 0, 0]) input = F.pad(input, [self.padding[1], 0, 0, 0])
real, imag = torch.chunk(input, 2, 1) real, imag = torch.chunk(input, 2, 1)
real_real = self.real_conv(real) real_real = self.real_conv(real)
real_imag = self.imag_conv(real) real_imag = self.imag_conv(real)
@ -72,7 +73,6 @@ class ComplexConv2d(nn.Module):
imag = real_imag - imag_real imag = real_imag - imag_real
out = torch.cat([real, imag], 1) out = torch.cat([real, imag], 1)
return out return out
@ -116,13 +116,12 @@ class ComplexConvTranspose2d(nn.Module):
groups=self.groups, groups=self.groups,
) )
init_weights(self.real_conv) self.real_conv = init_weights(self.real_conv)
init_weights(self.imag_conv) self.imag_conv = init_weights(self.imag_conv)
def forward(self, input): def forward(self, input):
real, imag = torch.chunk(input, 2, 1) real, imag = torch.chunk(input, 2, 1)
real_real = self.real_conv(real) real_real = self.real_conv(real)
real_imag = self.imag_conv(real) real_imag = self.imag_conv(real)