from typing import Tuple

import torch
import torch.nn.functional as F
from torch import nn


def init_weights(nnet):
    """Initialize a layer in place: Xavier-normal weights, zero bias.

    Returns the same module for call-chaining convenience.
    """
    nn.init.xavier_normal_(nnet.weight.data)
    # FIX: `nn.init.constant` is the deprecated (and in newer torch, removed)
    # alias; the supported in-place initializer is `constant_`.
    nn.init.constant_(nnet.bias, 0.0)
    return nnet


class ComplexConv2d(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Tuple[int, int] = (1, 1),
        stride: Tuple[int, int] = (1, 1),
        padding: Tuple[int, int] = (0, 0),
        groups: int = 1,
        dilation: int = 1,
    ):
        """Complex-valued Conv2d (non-causal).

        The input/output channel axis is split in half: the first half is the
        real part, the second half the imaginary part, so `in_channels` and
        `out_channels` count real+imag channels together (must be even).

        Args:
            in_channels: total channels (real + imag) of the input.
            out_channels: total channels (real + imag) of the output.
            kernel_size: (freq, time) kernel dimensions.
            stride: (freq, time) stride.
            padding: (freq, time) padding; frequency padding is handled by the
                convolutions, time padding is applied symmetrically in forward.
            groups: conv groups, applied to each of the two sub-convolutions.
            dilation: conv dilation, applied to each sub-convolution.
        """
        super().__init__()
        # Each of the real/imag sub-convolutions works on half the channels.
        self.in_channels = in_channels // 2
        self.out_channels = out_channels // 2
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.groups = groups
        self.dilation = dilation

        # Frequency padding (padding[0]) is done by the convs themselves;
        # time padding (padding[1]) is applied explicitly in forward().
        self.real_conv = nn.Conv2d(
            self.in_channels,
            self.out_channels,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=(self.padding[0], 0),
            groups=self.groups,
            dilation=self.dilation,
        )
        self.imag_conv = nn.Conv2d(
            self.in_channels,
            self.out_channels,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=(self.padding[0], 0),
            groups=self.groups,
            dilation=self.dilation,
        )
        self.imag_conv = init_weights(self.imag_conv)
        self.real_conv = init_weights(self.real_conv)

    def forward(self, input):
        """Apply the complex convolution.

        The complex (real/imag) split is always along dim 1: the first half of
        the channels is the real part, the second half the imaginary part.
        """
        # Symmetric padding along the last (time) axis only.
        input = F.pad(input, [self.padding[1], self.padding[1], 0, 0])

        real, imag = torch.chunk(input, 2, 1)
        real_real = self.real_conv(real)
        real_imag = self.imag_conv(real)

        imag_imag = self.imag_conv(imag)
        imag_real = self.real_conv(imag)

        # Complex multiplication (a + ib)(c + id) = (ac - bd) + i(ad + bc):
        real = real_real - imag_imag
        # FIX: the imaginary part is the SUM ad + bc, not a difference.
        # The original `real_imag - imag_real` broke the complex product.
        imag = real_imag + imag_real

        out = torch.cat([real, imag], 1)

        return out


def test_complexconv2d():
    """Shape regression: (1, 2, 256, 13) -> (1, 32, 128, 14)."""
    sample_input = torch.rand(1, 2, 256, 13)
    conv = ComplexConv2d(
        2, 32, kernel_size=(5, 2), stride=(2, 1), padding=(2, 1)
    )
    with torch.no_grad():
        out = conv(sample_input)
    assert out.shape == torch.Size([1, 32, 128, 14])