fix tests
This commit is contained in:
parent
1a4102cc53
commit
77699ce7f9
|
|
@ -1,8 +1,8 @@
|
|||
import torch
|
||||
|
||||
from enhancer.models.complexnn.conv import ComplexConv2d, ComplexConvTranspose2d
|
||||
from enhancer.models.complexnn.norm import ComplexBatchNorm2D
|
||||
from enhancer.models.complexnn.rnn import ComplexLSTM
|
||||
from enhancer.models.complexnn.utils import ComplexBatchNorm2D
|
||||
|
||||
|
||||
def test_complexconv2d():
|
||||
|
|
@ -12,7 +12,7 @@ def test_complexconv2d():
|
|||
)
|
||||
with torch.no_grad():
|
||||
out = conv(sample_input)
|
||||
assert out.shape == torch.Size([1, 32, 128, 14])
|
||||
assert out.shape == torch.Size([1, 32, 128, 13])
|
||||
|
||||
|
||||
def test_complexconvtranspose2d():
|
||||
|
|
|
|||
Loading…
Reference in New Issue