51 lines
1.4 KiB
Python
51 lines
1.4 KiB
Python
import torch
|
|
|
|
from mayavoz.models.complexnn.conv import ComplexConv2d, ComplexConvTranspose2d
|
|
from mayavoz.models.complexnn.rnn import ComplexLSTM
|
|
from mayavoz.models.complexnn.utils import ComplexBatchNorm2D
|
|
|
|
|
|
def test_complexconv2d():
    """Check the output shape ComplexConv2d yields for a (1, 2, 256, 13) input."""
    expected_shape = (1, 32, 128, 13)
    conv_options = {"kernel_size": (5, 2), "stride": (2, 1), "padding": (2, 1)}

    # Input is drawn before the layer is built so RNG use matches call order.
    features = torch.rand(1, 2, 256, 13)
    layer = ComplexConv2d(2, 32, **conv_options)

    # Only the shape is inspected, so autograd bookkeeping is switched off.
    with torch.no_grad():
        result = layer(features)

    assert tuple(result.shape) == expected_shape
|
|
|
|
|
|
def test_complexconvtranspose2d():
    """Check the output shape ComplexConvTranspose2d yields for a (1, 512, 4, 13) input."""
    expected_shape = (1, 256, 8, 14)
    channels_in = 256 * 2
    channels_out = 128 * 2

    features = torch.rand(1, 512, 4, 13)
    layer = ComplexConvTranspose2d(
        channels_in,
        channels_out,
        kernel_size=(5, 2),
        stride=(2, 1),
        padding=(2, 0),
        output_padding=(1, 0),
    )

    # Shape-only check; gradients are not needed.
    with torch.no_grad():
        result = layer(features)

    assert tuple(result.shape) == expected_shape
|
|
|
|
|
|
def test_complexlstm():
    """Check that the first two ComplexLSTM outputs each have shape (13, 1, 512)."""
    expected_shape = (13, 1, 512)
    width = 128 * 2

    sequence = torch.rand(13, 2, 128)
    recurrent = ComplexLSTM(width, width, projection_size=512 * 2)

    with torch.no_grad():
        outputs = recurrent(sequence)

    # NOTE(review): presumably the two entries are the real and imaginary
    # parts -- confirm against ComplexLSTM; here only their shapes are checked.
    for position in (0, 1):
        assert tuple(outputs[position].shape) == expected_shape
|
|
|
|
|
|
def test_complexbatchnorm2d():
    """Check that ComplexBatchNorm2D returns a tensor shaped like its input."""
    features = torch.rand(1, 64, 64, 14)
    norm_layer = ComplexBatchNorm2D(num_features=64)

    # Shape-only check; gradients are not needed.
    with torch.no_grad():
        normalized = norm_layer(features)

    assert normalized.shape == features.shape
|