fix tests

This commit is contained in:
shahules786 2022-11-07 11:15:30 +05:30
parent 1a4102cc53
commit 77699ce7f9
1 changed files with 2 additions and 2 deletions

View File

@ -1,8 +1,8 @@
import torch import torch
from enhancer.models.complexnn.conv import ComplexConv2d, ComplexConvTranspose2d from enhancer.models.complexnn.conv import ComplexConv2d, ComplexConvTranspose2d
from enhancer.models.complexnn.norm import ComplexBatchNorm2D
from enhancer.models.complexnn.rnn import ComplexLSTM from enhancer.models.complexnn.rnn import ComplexLSTM
from enhancer.models.complexnn.utils import ComplexBatchNorm2D
def test_complexconv2d(): def test_complexconv2d():
@ -12,7 +12,7 @@ def test_complexconv2d():
) )
with torch.no_grad(): with torch.no_grad():
out = conv(sample_input) out = conv(sample_input)
assert out.shape == torch.Size([1, 32, 128, 14]) assert out.shape == torch.Size([1, 32, 128, 13])
def test_complexconvtranspose2d(): def test_complexconvtranspose2d():