tests complexnn
This commit is contained in:
parent
0b50a573e8
commit
b1144e7b81
|
|
@ -1,6 +1,7 @@
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from enhancer.models.complexnn.conv import ComplexConv2d, ComplexConvTranspose2d
|
from enhancer.models.complexnn.conv import ComplexConv2d, ComplexConvTranspose2d
|
||||||
|
from enhancer.models.complexnn.rnn import ComplexLSTM
|
||||||
|
|
||||||
|
|
||||||
def test_complexconv2d():
|
def test_complexconv2d():
|
||||||
|
|
@ -27,3 +28,13 @@ def test_complexconvtranspose2d():
|
||||||
out = conv(sample_input)
|
out = conv(sample_input)
|
||||||
|
|
||||||
assert out.shape == torch.Size([1, 256, 8, 14])
|
assert out.shape == torch.Size([1, 256, 8, 14])
|
||||||
|
|
||||||
|
|
||||||
|
def test_complexlstm():
|
||||||
|
sample_input = torch.rand(13, 2, 128)
|
||||||
|
lstm = ComplexLSTM(128 * 2, 128 * 2, projection_size=512 * 2)
|
||||||
|
with torch.no_grad():
|
||||||
|
out = lstm(sample_input)
|
||||||
|
|
||||||
|
assert out[0].shape == torch.Size([13, 1, 512])
|
||||||
|
assert out[1].shape == torch.Size([13, 1, 512])
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue