complex lstm
This commit is contained in:
parent
7abd266ab2
commit
0b50a573e8
|
|
@ -0,0 +1,66 @@
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
|
class ComplexLSTM(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_size: int,
|
||||||
|
hidden_size: int,
|
||||||
|
num_layers: int = 1,
|
||||||
|
projection_size: Optional[int] = None,
|
||||||
|
bidirectional: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.input_size = input_size // 2
|
||||||
|
self.hidden_size = hidden_size // 2
|
||||||
|
self.num_layers = num_layers
|
||||||
|
|
||||||
|
self.real_lstm = nn.LSTM(
|
||||||
|
self.input_size,
|
||||||
|
self.hidden_size,
|
||||||
|
self.num_layers,
|
||||||
|
bidirectional=bidirectional,
|
||||||
|
batch_first=False,
|
||||||
|
)
|
||||||
|
self.imag_lstm = nn.LSTM(
|
||||||
|
self.input_size,
|
||||||
|
self.hidden_size,
|
||||||
|
self.num_layers,
|
||||||
|
bidirectional=bidirectional,
|
||||||
|
batch_first=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
bidirectional = 2 if bidirectional else 1
|
||||||
|
if projection_size is not None:
|
||||||
|
self.projection_size = projection_size // 2
|
||||||
|
self.real_linear = nn.Linear(
|
||||||
|
self.hidden_size * bidirectional, self.projection_size
|
||||||
|
)
|
||||||
|
self.imag_linear = nn.Linear(
|
||||||
|
self.hidden_size * bidirectional, self.projection_size
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, input):
|
||||||
|
|
||||||
|
if isinstance(input, List):
|
||||||
|
real, imag = input
|
||||||
|
else:
|
||||||
|
real, imag = torch.chunk(input, 2, 1)
|
||||||
|
|
||||||
|
real_real = self.real_lstm(real)[0]
|
||||||
|
real_imag = self.imag_lstm(real)[0]
|
||||||
|
|
||||||
|
imag_imag = self.imag_lstm(imag)[0]
|
||||||
|
imag_real = self.real_lstm(imag)[0]
|
||||||
|
|
||||||
|
real = real_real - imag_imag
|
||||||
|
imag = imag_real + real_imag
|
||||||
|
|
||||||
|
if self.projection_size is not None:
|
||||||
|
real = self.real_linear(real)
|
||||||
|
imag = self.imag_linear(imag)
|
||||||
|
|
||||||
|
return [real, imag]
|
||||||
Loading…
Reference in New Issue