fix tests
This commit is contained in:
parent
1a4102cc53
commit
77699ce7f9
|
|
@ -1,8 +1,8 @@
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from enhancer.models.complexnn.conv import ComplexConv2d, ComplexConvTranspose2d
|
from enhancer.models.complexnn.conv import ComplexConv2d, ComplexConvTranspose2d
|
||||||
from enhancer.models.complexnn.norm import ComplexBatchNorm2D
|
|
||||||
from enhancer.models.complexnn.rnn import ComplexLSTM
|
from enhancer.models.complexnn.rnn import ComplexLSTM
|
||||||
|
from enhancer.models.complexnn.utils import ComplexBatchNorm2D
|
||||||
|
|
||||||
|
|
||||||
def test_complexconv2d():
|
def test_complexconv2d():
|
||||||
|
|
@ -12,7 +12,7 @@ def test_complexconv2d():
|
||||||
)
|
)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
out = conv(sample_input)
|
out = conv(sample_input)
|
||||||
assert out.shape == torch.Size([1, 32, 128, 14])
|
assert out.shape == torch.Size([1, 32, 128, 13])
|
||||||
|
|
||||||
|
|
||||||
def test_complexconvtranspose2d():
|
def test_complexconvtranspose2d():
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue