test complexnn
This commit is contained in:
parent
26cccc6772
commit
7abd266ab2
|
|
@ -1,6 +1,6 @@
|
|||
import torch
|
||||
|
||||
from enhancer.models.complexnn.conv import ComplexConv2d
|
||||
from enhancer.models.complexnn.conv import ComplexConv2d, ComplexConvTranspose2d
|
||||
|
||||
|
||||
def test_complexconv2d():
|
||||
|
|
@ -11,3 +11,19 @@ def test_complexconv2d():
|
|||
with torch.no_grad():
|
||||
out = conv(sample_input)
|
||||
assert out.shape == torch.Size([1, 32, 128, 14])
|
||||
|
||||
|
||||
def test_complexconvtranspose2d():
|
||||
sample_input = torch.rand(1, 512, 4, 13)
|
||||
conv = ComplexConvTranspose2d(
|
||||
256 * 2,
|
||||
128 * 2,
|
||||
kernel_size=(5, 2),
|
||||
stride=(2, 1),
|
||||
padding=(2, 0),
|
||||
output_padding=(1, 0),
|
||||
)
|
||||
with torch.no_grad():
|
||||
out = conv(sample_input)
|
||||
|
||||
assert out.shape == torch.Size([1, 256, 8, 14])
|
||||
|
|
|
|||
Loading…
Reference in New Issue