diff --git a/tests/models/complexnn_test.py b/tests/models/complexnn_test.py index 53ffba2..74c2baa 100644 --- a/tests/models/complexnn_test.py +++ b/tests/models/complexnn_test.py @@ -1,6 +1,7 @@ import torch from enhancer.models.complexnn.conv import ComplexConv2d, ComplexConvTranspose2d +from enhancer.models.complexnn.rnn import ComplexLSTM def test_complexconv2d(): @@ -27,3 +28,13 @@ def test_complexconvtranspose2d(): out = conv(sample_input) assert out.shape == torch.Size([1, 256, 8, 14]) + + +def test_complexlstm(): + sample_input = torch.rand(13, 2, 128) + lstm = ComplexLSTM(128 * 2, 128 * 2, projection_size=512 * 2) + with torch.no_grad(): + out = lstm(sample_input) + + assert out[0].shape == torch.Size([13, 1, 512]) + assert out[1].shape == torch.Size([13, 1, 512])