diff --git a/enhancer/models/complexnn/conv.py b/enhancer/models/complexnn/conv.py index 55acc07..d9a4d0f 100644 --- a/enhancer/models/complexnn/conv.py +++ b/enhancer/models/complexnn/conv.py @@ -59,9 +59,10 @@ class ComplexConv2d(nn.Module): """ complex axis should be always 1 dim """ - input = F.pad(input, [self.padding[1], self.padding[1], 0, 0]) + input = F.pad(input, [self.padding[1], 0, 0, 0]) real, imag = torch.chunk(input, 2, 1) + real_real = self.real_conv(real) real_imag = self.imag_conv(real) @@ -72,7 +73,6 @@ class ComplexConv2d(nn.Module): imag = real_imag - imag_real out = torch.cat([real, imag], 1) - return out @@ -116,13 +116,12 @@ class ComplexConvTranspose2d(nn.Module): groups=self.groups, ) - init_weights(self.real_conv) - init_weights(self.imag_conv) + self.real_conv = init_weights(self.real_conv) + self.imag_conv = init_weights(self.imag_conv) def forward(self, input): real, imag = torch.chunk(input, 2, 1) - real_real = self.real_conv(real) real_imag = self.imag_conv(real)