137 lines
		
	
	
		
			3.7 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			137 lines
		
	
	
		
			3.7 KiB
		
	
	
	
		
			Python
		
	
	
	
| from typing import Tuple
 | |
| 
 | |
| import torch
 | |
| import torch.nn.functional as F
 | |
| from torch import nn
 | |
| 
 | |
| 
 | |
def init_weights(nnet):
    """Initialise a layer's parameters in place and return the same layer.

    Applies Xavier (Glorot) normal initialisation to ``nnet.weight`` and sets
    ``nnet.bias`` to zero when the layer has one.

    Args:
        nnet: a module exposing a ``weight`` parameter (at least 2-D, as
            required by ``xavier_normal_``) and a ``bias`` attribute, e.g.
            ``nn.Conv2d``, ``nn.ConvTranspose2d`` or ``nn.Linear``.

    Returns:
        ``nnet`` itself, so callers can write ``layer = init_weights(layer)``.
    """
    # nn.init functions already operate under torch.no_grad(), so the
    # parameter can be passed directly instead of going through ``.data``.
    nn.init.xavier_normal_(nnet.weight)
    # Layers constructed with bias=False have ``bias is None``; constant_
    # would raise on it, so only zero a bias that actually exists.
    if nnet.bias is not None:
        nn.init.constant_(nnet.bias, 0.0)
    return nnet
 | |
| 
 | |
| 
 | |
class ComplexConv2d(nn.Module):
    """2-D convolution over complex feature maps stored as stacked real/imag channels.

    The input's dim 1 holds the real part in its first half and the imaginary
    part in its second half (split with ``torch.chunk``). Two real
    ``nn.Conv2d`` layers act as the real and imaginary parts of a complex
    kernel ``W = W_r + i*W_i``. The output is ``W * x`` computed with the
    complex product rule and laid out the same way: real half, then
    imaginary half, along dim 1.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Tuple[int, int] = (1, 1),
        stride: Tuple[int, int] = (1, 1),
        padding: Tuple[int, int] = (0, 0),
        groups: int = 1,
        dilation: int = 1,
    ):
        """
        Complex Conv2d with left-only padding on the last axis.

        Args:
            in_channels: total input channels (real + imaginary). Each real
                conv receives ``in_channels // 2``.
            out_channels: total output channels (real + imaginary). Each real
                conv produces ``out_channels // 2``.
            kernel_size: kernel size passed to both ``nn.Conv2d`` layers.
            stride: stride passed to both ``nn.Conv2d`` layers.
            padding: ``(p0, p1)``.
                ``p0`` is symmetric zero-padding applied by ``nn.Conv2d`` on
                the second-to-last axis.
                ``p1`` is zero-padding applied in ``forward`` to the LEFT side
                of the last axis only; no right padding is added. Along that
                axis an output therefore only sees current/past positions
                when ``p1 >= dilation * (kernel_size[1] - 1)``.
            groups: groups passed to both ``nn.Conv2d`` layers.
            dilation: dilation passed to both ``nn.Conv2d`` layers.
        """
        super().__init__()
        # Stored channel counts are per half (real or imaginary part).
        self.in_channels = in_channels // 2
        self.out_channels = out_channels // 2
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.groups = groups
        self.dilation = dilation

        # The last-axis padding is done manually in forward (left side only),
        # so the convs themselves pad only the second-to-last axis.
        self.real_conv = nn.Conv2d(
            self.in_channels,
            self.out_channels,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=(self.padding[0], 0),
            groups=self.groups,
            dilation=self.dilation,
        )
        self.imag_conv = nn.Conv2d(
            self.in_channels,
            self.out_channels,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=(self.padding[0], 0),
            groups=self.groups,
            dilation=self.dilation,
        )
        self.imag_conv = init_weights(self.imag_conv)
        self.real_conv = init_weights(self.real_conv)

    def forward(self, input):
        """Apply the complex convolution.

        Args:
            input: tensor whose dim 1 stacks the real half followed by the
                imaginary half. Presumably batched ``(N, C, H, W)`` with
                ``C == 2 * self.in_channels`` (TODO: confirm against callers).

        Returns:
            Tensor with the real part of the result in the first half of dim 1
            and the imaginary part in the second half.
        """
        # Left-only zero padding on the last axis (see __init__ docstring).
        input = F.pad(input, [self.padding[1], 0, 0, 0])

        real, imag = torch.chunk(input, 2, 1)

        real_real = self.real_conv(real)
        real_imag = self.imag_conv(real)
        imag_imag = self.imag_conv(imag)
        imag_real = self.real_conv(imag)

        # (W_r + i W_i)(a + i b) = (W_r a - W_i b) + i (W_i a + W_r b)
        real = real_real - imag_imag
        imag = real_imag + imag_real

        return torch.cat([real, imag], 1)
 | |
| 
 | |
| 
 | |
class ComplexConvTranspose2d(nn.Module):
    """2-D transposed convolution over complex feature maps stored as stacked real/imag channels.

    The input's dim 1 holds the real part in its first half and the imaginary
    part in its second half (split with ``torch.chunk``). Two real
    ``nn.ConvTranspose2d`` layers act as the real and imaginary parts of a
    complex kernel ``W = W_r + i*W_i``. The output is combined with the
    complex product rule and laid out as real half, then imaginary half,
    along dim 1.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Tuple[int, int] = (1, 1),
        stride: Tuple[int, int] = (1, 1),
        padding: Tuple[int, int] = (0, 0),
        output_padding: Tuple[int, int] = (0, 0),
        groups: int = 1,
    ):
        """
        Args:
            in_channels: total input channels (real + imaginary). Each real
                transposed conv receives ``in_channels // 2``.
            out_channels: total output channels (real + imaginary). Each real
                transposed conv produces ``out_channels // 2``.
            kernel_size: passed unchanged to both ``nn.ConvTranspose2d`` layers.
            stride: passed unchanged to both ``nn.ConvTranspose2d`` layers.
            padding: passed unchanged to both ``nn.ConvTranspose2d`` layers.
            output_padding: passed unchanged to both ``nn.ConvTranspose2d``
                layers.
            groups: passed unchanged to both ``nn.ConvTranspose2d`` layers.
        """
        super().__init__()
        # Stored channel counts are per half (real or imaginary part).
        self.in_channels = in_channels // 2
        self.out_channels = out_channels // 2
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.groups = groups
        self.output_padding = output_padding

        self.real_conv = nn.ConvTranspose2d(
            self.in_channels,
            self.out_channels,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            output_padding=self.output_padding,
            groups=self.groups,
        )

        self.imag_conv = nn.ConvTranspose2d(
            self.in_channels,
            self.out_channels,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            output_padding=self.output_padding,
            groups=self.groups,
        )

        self.real_conv = init_weights(self.real_conv)
        self.imag_conv = init_weights(self.imag_conv)

    def forward(self, input):
        """Apply the complex transposed convolution.

        Args:
            input: tensor whose dim 1 stacks the real half followed by the
                imaginary half. Presumably batched ``(N, C, H, W)`` with
                ``C == 2 * self.in_channels`` (TODO: confirm against callers).

        Returns:
            Tensor with the real part of the result in the first half of dim 1
            and the imaginary part in the second half.
        """
        real, imag = torch.chunk(input, 2, 1)

        real_real = self.real_conv(real)
        real_imag = self.imag_conv(real)
        imag_imag = self.imag_conv(imag)
        imag_real = self.real_conv(imag)

        # (W_r + i W_i)(a + i b) = (W_r a - W_i b) + i (W_i a + W_r b)
        real = real_real - imag_imag
        imag = real_imag + imag_real

        return torch.cat([real, imag], 1)
 |