mayavoz/enhancer/models/complexnn/rnn.py

67 lines
1.8 KiB
Python

from typing import List, Optional
import torch
from torch import nn
class ComplexLSTM(nn.Module):
    """LSTM over a complex-valued sequence, built from two real-valued LSTMs.

    The real and imaginary parts of the input are each passed through both
    ``real_lstm`` and ``imag_lstm``; the four results are then combined in
    the pattern of a complex multiplication::

        real_out = real_lstm(real) - imag_lstm(imag)
        imag_out = real_lstm(imag) + imag_lstm(real)

    An optional pair of linear layers projects each part afterwards.

    Args:
        input_size: Combined (real + imaginary) feature size. Each internal
            LSTM is built with ``input_size // 2`` input features.
        hidden_size: Combined hidden size. Each internal LSTM is built with
            ``hidden_size // 2`` hidden units.
        num_layers: Number of stacked layers in each internal LSTM.
        projection_size: Combined output size of the optional linear
            projection. Each part is projected to ``projection_size // 2``
            features. ``None`` (the default) disables the projection.
        bidirectional: Whether the internal LSTMs are bidirectional.

    Note:
        The sizes are halved with integer division, so odd values are
        silently rounded down.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int = 1,
        projection_size: Optional[int] = None,
        bidirectional: bool = False,
    ):
        super().__init__()
        # Each of the two LSTMs handles one half of the combined sizes.
        self.input_size = input_size // 2
        self.hidden_size = hidden_size // 2
        self.num_layers = num_layers

        self.real_lstm = nn.LSTM(
            self.input_size,
            self.hidden_size,
            self.num_layers,
            bidirectional=bidirectional,
            batch_first=False,
        )
        self.imag_lstm = nn.LSTM(
            self.input_size,
            self.hidden_size,
            self.num_layers,
            bidirectional=bidirectional,
            batch_first=False,
        )

        # A bidirectional LSTM concatenates both directions, doubling the
        # feature size that the projection layers receive.
        num_directions = 2 if bidirectional else 1

        # Always define the attribute: ``forward`` reads it unconditionally,
        # and a missing attribute on an ``nn.Module`` raises AttributeError.
        self.projection_size = None
        if projection_size is not None:
            self.projection_size = projection_size // 2
            self.real_linear = nn.Linear(
                self.hidden_size * num_directions, self.projection_size
            )
            self.imag_linear = nn.Linear(
                self.hidden_size * num_directions, self.projection_size
            )

    def forward(self, input) -> List[torch.Tensor]:
        """Run the complex LSTM and return ``[real, imag]``.

        Args:
            input: Either a two-element list/tuple ``(real, imag)`` of
                tensors, or a single tensor that is split into two equal
                chunks along dimension 1 to obtain the real and imaginary
                parts. Each part is fed to ``nn.LSTM`` modules created with
                ``batch_first=False``.

        Returns:
            A two-element list ``[real, imag]`` holding the real and
            imaginary outputs, linearly projected if a projection size was
            configured.
        """
        # The parameter name ``input`` shadows the builtin but is kept
        # because it is part of the public call signature.
        if isinstance(input, (list, tuple)):
            real, imag = input
        else:
            # NOTE(review): the split is along dim 1; with batch_first=False
            # that is the batch axis for 3-D inputs -- confirm with callers.
            real, imag = torch.chunk(input, 2, 1)

        # Only the output sequences are used; the (h_n, c_n) states returned
        # by each LSTM call are discarded.
        real_real = self.real_lstm(real)[0]
        real_imag = self.imag_lstm(real)[0]
        imag_imag = self.imag_lstm(imag)[0]
        imag_real = self.real_lstm(imag)[0]

        real = real_real - imag_imag
        imag = imag_real + real_imag

        if self.projection_size is not None:
            real = self.real_linear(real)
            imag = self.imag_linear(imag)
        return [real, imag]