test complexnn
This commit is contained in:
parent
26cccc6772
commit
7abd266ab2
|
|
@ -1,6 +1,6 @@
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from enhancer.models.complexnn.conv import ComplexConv2d
|
from enhancer.models.complexnn.conv import ComplexConv2d, ComplexConvTranspose2d
|
||||||
|
|
||||||
|
|
||||||
def test_complexconv2d():
|
def test_complexconv2d():
|
||||||
|
|
@ -11,3 +11,19 @@ def test_complexconv2d():
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
out = conv(sample_input)
|
out = conv(sample_input)
|
||||||
assert out.shape == torch.Size([1, 32, 128, 14])
|
assert out.shape == torch.Size([1, 32, 128, 14])
|
||||||
|
|
||||||
|
|
||||||
|
def test_complexconvtranspose2d():
|
||||||
|
sample_input = torch.rand(1, 512, 4, 13)
|
||||||
|
conv = ComplexConvTranspose2d(
|
||||||
|
256 * 2,
|
||||||
|
128 * 2,
|
||||||
|
kernel_size=(5, 2),
|
||||||
|
stride=(2, 1),
|
||||||
|
padding=(2, 0),
|
||||||
|
output_padding=(1, 0),
|
||||||
|
)
|
||||||
|
with torch.no_grad():
|
||||||
|
out = conv(sample_input)
|
||||||
|
|
||||||
|
assert out.shape == torch.Size([1, 256, 8, 14])
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue