From 26cccc67721803c2b3fd0b0f988d76fce3f6278a Mon Sep 17 00:00:00 2001
From: shahules786
Date: Mon, 31 Oct 2022 11:43:32 +0530
Subject: [PATCH] complex transposed conv

---
 enhancer/models/complexnn/conv.py | 64 ++++++++++++++++++++++++++++++-
 1 file changed, 62 insertions(+), 2 deletions(-)

diff --git a/enhancer/models/complexnn/conv.py b/enhancer/models/complexnn/conv.py
index eddcce3..55acc07 100644
--- a/enhancer/models/complexnn/conv.py
+++ b/enhancer/models/complexnn/conv.py
@@ -7,7 +7,7 @@ from torch import nn
 
 def init_weights(nnet):
     nn.init.xavier_normal_(nnet.weight.data)
-    nn.init.constant(nnet.bias, 0.0)
+    nn.init.constant_(nnet.bias, 0.0)
 
     return nnet
 
@@ -57,7 +57,6 @@ class ComplexConv2d(nn.Module):
 
     def forward(self, input):
         """
-        forward
         complex axis should be always 1 dim
         """
         input = F.pad(input, [self.padding[1], self.padding[1], 0, 0])
@@ -75,3 +74,64 @@ class ComplexConv2d(nn.Module):
         out = torch.cat([real, imag], 1)
 
         return out
+
+
+class ComplexConvTranspose2d(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: Tuple[int, int] = (1, 1),
+        stride: Tuple[int, int] = (1, 1),
+        padding: Tuple[int, int] = (0, 0),
+        output_padding: Tuple[int, int] = (0, 0),
+        groups: int = 1,
+    ):
+        super().__init__()
+        self.in_channels = in_channels // 2
+        self.out_channels = out_channels // 2
+        self.kernel_size = kernel_size
+        self.stride = stride
+        self.padding = padding
+        self.groups = groups
+        self.output_padding = output_padding
+
+        self.real_conv = nn.ConvTranspose2d(
+            self.in_channels,
+            self.out_channels,
+            kernel_size=self.kernel_size,
+            stride=self.stride,
+            padding=self.padding,
+            output_padding=self.output_padding,
+            groups=self.groups,
+        )
+
+        self.imag_conv = nn.ConvTranspose2d(
+            self.in_channels,
+            self.out_channels,
+            kernel_size=self.kernel_size,
+            stride=self.stride,
+            padding=self.padding,
+            output_padding=self.output_padding,
+            groups=self.groups,
+        )
+
+        init_weights(self.real_conv)
+        init_weights(self.imag_conv)
+
+    def forward(self, input):
+
+        real, imag = torch.chunk(input, 2, 1)
+
+        real_real = self.real_conv(real)
+        real_imag = self.imag_conv(real)
+
+        imag_imag = self.imag_conv(imag)
+        imag_real = self.real_conv(imag)
+
+        real = real_real - imag_imag
+        imag = real_imag + imag_real
+
+        out = torch.cat([real, imag], 1)
+
+        return out