tests complexnn
This commit is contained in:
parent
0b50a573e8
commit
b1144e7b81
|
|
@ -1,6 +1,7 @@
|
|||
import torch
|
||||
|
||||
from enhancer.models.complexnn.conv import ComplexConv2d, ComplexConvTranspose2d
|
||||
from enhancer.models.complexnn.rnn import ComplexLSTM
|
||||
|
||||
|
||||
def test_complexconv2d():
|
||||
|
|
@ -27,3 +28,13 @@ def test_complexconvtranspose2d():
|
|||
out = conv(sample_input)
|
||||
|
||||
assert out.shape == torch.Size([1, 256, 8, 14])
|
||||
|
||||
|
||||
def test_complexlstm():
|
||||
sample_input = torch.rand(13, 2, 128)
|
||||
lstm = ComplexLSTM(128 * 2, 128 * 2, projection_size=512 * 2)
|
||||
with torch.no_grad():
|
||||
out = lstm(sample_input)
|
||||
|
||||
assert out[0].shape == torch.Size([13, 1, 512])
|
||||
assert out[1].shape == torch.Size([13, 1, 512])
|
||||
|
|
|
|||
Loading…
Reference in New Issue