test complexnn

This commit is contained in:
shahules786 2022-10-31 11:43:50 +05:30
parent 26cccc6772
commit 7abd266ab2
1 changed files with 17 additions and 1 deletions

View File

@ -1,6 +1,6 @@
import torch
from enhancer.models.complexnn.conv import ComplexConv2d
from enhancer.models.complexnn.conv import ComplexConv2d, ComplexConvTranspose2d
def test_complexconv2d():
@ -11,3 +11,19 @@ def test_complexconv2d():
with torch.no_grad():
out = conv(sample_input)
assert out.shape == torch.Size([1, 32, 128, 14])
def test_complexconvtranspose2d():
sample_input = torch.rand(1, 512, 4, 13)
conv = ComplexConvTranspose2d(
256 * 2,
128 * 2,
kernel_size=(5, 2),
stride=(2, 1),
padding=(2, 0),
output_padding=(1, 0),
)
with torch.no_grad():
out = conv(sample_input)
assert out.shape == torch.Size([1, 256, 8, 14])