complex lstm

parent 7abd266ab2
commit 0b50a573e8
@@ -0,0 +1,66 @@
from typing import Optional

import torch
from torch import nn


class ComplexLSTM(nn.Module):
    """LSTM over complex-valued features, using separate real and imaginary LSTMs."""

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int = 1,
        projection_size: Optional[int] = None,
        bidirectional: bool = False,
    ):
        super().__init__()
        # The real and imaginary parts each carry half of the feature dimension.
        self.input_size = input_size // 2
        self.hidden_size = hidden_size // 2
        self.num_layers = num_layers

        self.real_lstm = nn.LSTM(
            self.input_size,
            self.hidden_size,
            self.num_layers,
            bidirectional=bidirectional,
            batch_first=False,
        )
        self.imag_lstm = nn.LSTM(
            self.input_size,
            self.hidden_size,
            self.num_layers,
            bidirectional=bidirectional,
            batch_first=False,
        )

        # Number of LSTM output directions (2 if bidirectional, else 1).
        num_directions = 2 if bidirectional else 1
        self.projection_size = None
        if projection_size is not None:
            self.projection_size = projection_size // 2
            self.real_linear = nn.Linear(
                self.hidden_size * num_directions, self.projection_size
            )
            self.imag_linear = nn.Linear(
                self.hidden_size * num_directions, self.projection_size
            )

    def forward(self, input):
        if isinstance(input, (list, tuple)):
            real, imag = input
        else:
            # Split the concatenated features into real and imaginary halves along
            # the last (feature) dimension; the LSTMs run with batch_first=False.
            real, imag = torch.chunk(input, 2, dim=-1)

        real_real = self.real_lstm(real)[0]
        real_imag = self.imag_lstm(real)[0]

        imag_imag = self.imag_lstm(imag)[0]
        imag_real = self.real_lstm(imag)[0]

        # Complex product (L_r + i*L_i)(x_r + i*x_i):
        # real part L_r(x_r) - L_i(x_i); imaginary part L_r(x_i) + L_i(x_r).
        real = real_real - imag_imag
        imag = imag_real + real_imag

        if self.projection_size is not None:
            real = self.real_linear(real)
            imag = self.imag_linear(imag)

        return [real, imag]
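
For reference, a minimal usage sketch of the module above (the sizes, tensor shapes, and spectrogram framing below are illustrative assumptions, not part of the commit):

# Hypothetical sizes: 257 bins per component, concatenated to input_size=514.
lstm = ComplexLSTM(input_size=514, hidden_size=256, projection_size=514)

x = torch.randn(100, 4, 514)   # (seq, batch, feature); batch_first=False
real, imag = lstm(x)           # real/imag halves split off the feature dim
print(real.shape, imag.shape)  # torch.Size([100, 4, 257]) twice

# Equivalently, pass the two components as a list:
real, imag = lstm([x[..., :257], x[..., 257:]])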