67 lines
		
	
	
		
			1.8 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			67 lines
		
	
	
		
			1.8 KiB
		
	
	
	
		
			Python
		
	
	
	
| from typing import List, Optional
 | |
| 
 | |
| import torch
 | |
| from torch import nn
 | |
| 
 | |
| 
 | |
| class ComplexLSTM(nn.Module):
 | |
|     def __init__(
 | |
|         self,
 | |
|         input_size: int,
 | |
|         hidden_size: int,
 | |
|         num_layers: int = 1,
 | |
|         projection_size: Optional[int] = None,
 | |
|         bidirectional: bool = False,
 | |
|     ):
 | |
|         super().__init__()
 | |
|         self.input_size = input_size // 2
 | |
|         self.hidden_size = hidden_size // 2
 | |
|         self.num_layers = num_layers
 | |
| 
 | |
|         self.real_lstm = nn.LSTM(
 | |
|             self.input_size,
 | |
|             self.hidden_size,
 | |
|             self.num_layers,
 | |
|             bidirectional=bidirectional,
 | |
|             batch_first=False,
 | |
|         )
 | |
|         self.imag_lstm = nn.LSTM(
 | |
|             self.input_size,
 | |
|             self.hidden_size,
 | |
|             self.num_layers,
 | |
|             bidirectional=bidirectional,
 | |
|             batch_first=False,
 | |
|         )
 | |
| 
 | |
|         bidirectional = 2 if bidirectional else 1
 | |
|         if projection_size is not None:
 | |
|             self.projection_size = projection_size // 2
 | |
|             self.real_linear = nn.Linear(
 | |
|                 self.hidden_size * bidirectional, self.projection_size
 | |
|             )
 | |
|             self.imag_linear = nn.Linear(
 | |
|                 self.hidden_size * bidirectional, self.projection_size
 | |
|             )
 | |
| 
 | |
|     def forward(self, input):
 | |
| 
 | |
|         if isinstance(input, List):
 | |
|             real, imag = input
 | |
|         else:
 | |
|             real, imag = torch.chunk(input, 2, 1)
 | |
| 
 | |
|         real_real = self.real_lstm(real)[0]
 | |
|         real_imag = self.imag_lstm(real)[0]
 | |
| 
 | |
|         imag_imag = self.imag_lstm(imag)[0]
 | |
|         imag_real = self.real_lstm(imag)[0]
 | |
| 
 | |
|         real = real_real - imag_imag
 | |
|         imag = imag_real + real_imag
 | |
| 
 | |
|         if self.projection_size is not None:
 | |
|             real = self.real_linear(real)
 | |
|             imag = self.imag_linear(imag)
 | |
| 
 | |
|         return [real, imag]
 |