shahules786 2022-10-29 13:20:04 +05:30
parent cf1e5c07a9
commit 6f6e7f7ad8
3 changed files with 91 additions and 0 deletions

View File

@@ -0,0 +1 @@
# from enhancer.models.complexnn.conv import ComplexConv2d

View File

@@ -0,0 +1,77 @@
from typing import Tuple
import torch
import torch.nn.functional as F
from torch import nn

def init_weights(nnet):
    # Xavier-initialize the weights and zero the bias of the given layer.
    nn.init.xavier_normal_(nnet.weight.data)
    nn.init.constant_(nnet.bias, 0.0)
    return nnet

class ComplexConv2d(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Tuple[int, int] = (1, 1),
        stride: Tuple[int, int] = (1, 1),
        padding: Tuple[int, int] = (0, 0),
        groups: int = 1,
        dilation: int = 1,
    ):
        """
        Complex Conv2d (non-causal)

        in_channels and out_channels count the real and imaginary channels
        together, so each of the two underlying convolutions uses half of them.
        """
        super().__init__()
        self.in_channels = in_channels // 2
        self.out_channels = out_channels // 2
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.groups = groups
        self.dilation = dilation
        # Separate convolutions for the real and imaginary parts of the kernel.
        # Only the first spatial axis is padded by the convolutions themselves;
        # the last axis is padded manually in forward.
        self.real_conv = nn.Conv2d(
            self.in_channels,
            self.out_channels,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=(self.padding[0], 0),
            groups=self.groups,
            dilation=self.dilation,
        )
        self.imag_conv = nn.Conv2d(
            self.in_channels,
            self.out_channels,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=(self.padding[0], 0),
            groups=self.groups,
            dilation=self.dilation,
        )
        self.imag_conv = init_weights(self.imag_conv)
        self.real_conv = init_weights(self.real_conv)

    def forward(self, input):
        """
        forward

        the complex axis should always be dim 1: the first half of the
        channels holds the real part, the second half the imaginary part
        """
        # Pad the last axis manually; the convs only pad the first spatial axis.
        input = F.pad(input, [self.padding[1], self.padding[1], 0, 0])
        real, imag = torch.chunk(input, 2, 1)
        # Complex product: (Wr + i*Wi) * (Xr + i*Xi)
        #   real part = Wr*Xr - Wi*Xi
        #   imag part = Wr*Xi + Wi*Xr
        real_real = self.real_conv(real)
        real_imag = self.imag_conv(real)
        imag_imag = self.imag_conv(imag)
        imag_real = self.real_conv(imag)
        real = real_real - imag_imag
        imag = real_imag + imag_real
        out = torch.cat([real, imag], 1)
        return out
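
For context (not part of the commit): the recombination in forward follows the ordinary complex product, (Wr + i*Wi)(Xr + i*Xi) = (Wr*Xr - Wi*Xi) + i*(Wr*Xi + Wi*Xr). A minimal standalone sketch checking that arithmetic against torch's native complex tensors, with elementwise multiplication standing in for the (linear) convolution:

import torch

w_r, w_i = torch.randn(4), torch.randn(4)  # stand-ins for the real/imag weight branches
x_r, x_i = torch.randn(4), torch.randn(4)  # stand-ins for the real/imag parts of the input

out_r = w_r * x_r - w_i * x_i              # real part of the product
out_i = w_r * x_i + w_i * x_r              # imaginary part of the product

ref = torch.complex(w_r, w_i) * torch.complex(x_r, x_i)
assert torch.allclose(out_r, ref.real) and torch.allclose(out_i, ref.imag)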

View File

@@ -0,0 +1,13 @@
import torch
from enhancer.models.complexnn.conv import ComplexConv2d


def test_complexconv2d():
    # dim 1 holds the complex axis: channel 0 is real, channel 1 is imaginary
    sample_input = torch.rand(1, 2, 256, 13)
    conv = ComplexConv2d(
        2, 32, kernel_size=(5, 2), stride=(2, 1), padding=(2, 1)
    )
    with torch.no_grad():
        out = conv(sample_input)
    assert out.shape == torch.Size([1, 32, 128, 14])
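
A small usage sketch (not part of the commit; the import path is taken from the test above): the output keeps the same channel layout as the input, so the first out_channels // 2 channels hold the real part and the rest the imaginary part, and the two halves can be split back with torch.chunk:

import torch

from enhancer.models.complexnn.conv import ComplexConv2d

conv = ComplexConv2d(2, 32, kernel_size=(5, 2), stride=(2, 1), padding=(2, 1))
spec = torch.rand(1, 2, 256, 13)                 # channel 0: real, channel 1: imaginary
out = conv(spec)                                 # -> (1, 32, 128, 14)
out_real, out_imag = torch.chunk(out, 2, dim=1)  # each (1, 16, 128, 14)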