fix padding & init
This commit is contained in:
parent
70d17f6586
commit
c21f05e307
|
|
@ -59,9 +59,10 @@ class ComplexConv2d(nn.Module):
|
||||||
"""
|
"""
|
||||||
complex axis should be always 1 dim
|
complex axis should be always 1 dim
|
||||||
"""
|
"""
|
||||||
input = F.pad(input, [self.padding[1], self.padding[1], 0, 0])
|
input = F.pad(input, [self.padding[1], 0, 0, 0])
|
||||||
|
|
||||||
real, imag = torch.chunk(input, 2, 1)
|
real, imag = torch.chunk(input, 2, 1)
|
||||||
|
|
||||||
real_real = self.real_conv(real)
|
real_real = self.real_conv(real)
|
||||||
real_imag = self.imag_conv(real)
|
real_imag = self.imag_conv(real)
|
||||||
|
|
||||||
|
|
@ -72,7 +73,6 @@ class ComplexConv2d(nn.Module):
|
||||||
imag = real_imag - imag_real
|
imag = real_imag - imag_real
|
||||||
|
|
||||||
out = torch.cat([real, imag], 1)
|
out = torch.cat([real, imag], 1)
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -116,13 +116,12 @@ class ComplexConvTranspose2d(nn.Module):
|
||||||
groups=self.groups,
|
groups=self.groups,
|
||||||
)
|
)
|
||||||
|
|
||||||
init_weights(self.real_conv)
|
self.real_conv = init_weights(self.real_conv)
|
||||||
init_weights(self.imag_conv)
|
self.imag_conv = init_weights(self.imag_conv)
|
||||||
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
|
|
||||||
real, imag = torch.chunk(input, 2, 1)
|
real, imag = torch.chunk(input, 2, 1)
|
||||||
|
|
||||||
real_real = self.real_conv(real)
|
real_real = self.real_conv(real)
|
||||||
real_imag = self.imag_conv(real)
|
real_imag = self.imag_conv(real)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue