mayavoz/tests/models/complexnn_test.py

14 lines
353 B
Python

import torch
from enhancer.models.complexnn.conv import ComplexConv2d
def test_complexconv2d():
    """Check that ComplexConv2d produces the expected output shape.

    A random (1, 2, 256, 13) input is passed through a ComplexConv2d with
    2 input channels, 32 output channels, kernel (5, 2), stride (2, 1) and
    padding (2, 1). The forward pass runs under ``torch.no_grad()`` because
    only the output shape is checked.
    """
    sample_input = torch.rand(1, 2, 256, 13)
    conv = ComplexConv2d(
        2,  # in_channels
        32,  # out_channels
        kernel_size=(5, 2),
        stride=(2, 1),
        padding=(2, 1),
    )
    with torch.no_grad():
        out = conv(sample_input)

    # Expected spatial size from the convolution output-size formula
    # floor((n + 2 * padding - kernel) / stride) + 1 (dilation 1):
    #   height: (256 + 2*2 - 5) // 2 + 1 = 128
    #   width:  (13  + 2*1 - 2) // 1 + 1 = 14
    # NOTE(review): assumes ComplexConv2d follows nn.Conv2d size arithmetic.
    assert out.shape == torch.Size([1, 32, 128, 14])