14 lines
353 B
Python
14 lines
353 B
Python
import torch
|
|
|
|
from enhancer.models.complexnn.conv import ComplexConv2d
|
|
|
|
|
|
def test_complexconv2d():
    """Check the output shape produced by ``ComplexConv2d``.

    A random tensor of shape (1, 2, 256, 13) is passed through a layer
    built with 2 input channels, 32 output channels, kernel (5, 2),
    stride (2, 1) and padding (2, 1); the result must have shape
    (1, 32, 128, 14).
    """
    input_shape = (1, 2, 256, 13)
    expected_shape = torch.Size((1, 32, 128, 14))
    layer_kwargs = {
        "kernel_size": (5, 2),
        "stride": (2, 1),
        "padding": (2, 1),
    }

    # Draw the input before building the layer, matching the original order.
    features = torch.rand(*input_shape)
    layer = ComplexConv2d(2, 32, **layer_kwargs)

    # Only the shape is inspected, so gradient tracking is not needed.
    with torch.no_grad():
        result = layer(features)

    assert result.shape == expected_shape
|