diff --git a/tests/models/complexnn_test.py b/tests/models/complexnn_test.py index 8c18ed5..524a6cf 100644 --- a/tests/models/complexnn_test.py +++ b/tests/models/complexnn_test.py @@ -1,8 +1,8 @@ import torch from enhancer.models.complexnn.conv import ComplexConv2d, ComplexConvTranspose2d -from enhancer.models.complexnn.norm import ComplexBatchNorm2D from enhancer.models.complexnn.rnn import ComplexLSTM +from enhancer.models.complexnn.utils import ComplexBatchNorm2D def test_complexconv2d(): @@ -12,7 +12,7 @@ def test_complexconv2d(): ) with torch.no_grad(): out = conv(sample_input) - assert out.shape == torch.Size([1, 32, 128, 14]) + assert out.shape == torch.Size([1, 32, 128, 13]) def test_complexconvtranspose2d():