tests complexnn

This commit is contained in:
shahules786 2022-11-01 10:35:49 +05:30
parent 0b50a573e8
commit b1144e7b81
1 changed files with 11 additions and 0 deletions

View File

@ -1,6 +1,7 @@
import torch import torch
from enhancer.models.complexnn.conv import ComplexConv2d, ComplexConvTranspose2d from enhancer.models.complexnn.conv import ComplexConv2d, ComplexConvTranspose2d
from enhancer.models.complexnn.rnn import ComplexLSTM
def test_complexconv2d(): def test_complexconv2d():
@ -27,3 +28,13 @@ def test_complexconvtranspose2d():
out = conv(sample_input) out = conv(sample_input)
assert out.shape == torch.Size([1, 256, 8, 14]) assert out.shape == torch.Size([1, 256, 8, 14])
def test_complexlstm():
sample_input = torch.rand(13, 2, 128)
lstm = ComplexLSTM(128 * 2, 128 * 2, projection_size=512 * 2)
with torch.no_grad():
out = lstm(sample_input)
assert out[0].shape == torch.Size([13, 1, 512])
assert out[1].shape == torch.Size([13, 1, 512])