fix tests

This commit is contained in:
shahules786 2022-11-07 11:15:30 +05:30
parent 1a4102cc53
commit 77699ce7f9
1 changed files with 2 additions and 2 deletions

View File

@ -1,8 +1,8 @@
import torch
from enhancer.models.complexnn.conv import ComplexConv2d, ComplexConvTranspose2d
from enhancer.models.complexnn.norm import ComplexBatchNorm2D
from enhancer.models.complexnn.rnn import ComplexLSTM
from enhancer.models.complexnn.utils import ComplexBatchNorm2D
def test_complexconv2d():
@ -12,7 +12,7 @@ def test_complexconv2d():
)
with torch.no_grad():
out = conv(sample_input)
assert out.shape == torch.Size([1, 32, 128, 14])
assert out.shape == torch.Size([1, 32, 128, 13])
def test_complexconvtranspose2d():