diff --git a/tests/models/complexnn_test.py b/tests/models/complexnn_test.py index 74c2baa..8c18ed5 100644 --- a/tests/models/complexnn_test.py +++ b/tests/models/complexnn_test.py @@ -1,6 +1,7 @@ import torch from enhancer.models.complexnn.conv import ComplexConv2d, ComplexConvTranspose2d +from enhancer.models.complexnn.norm import ComplexBatchNorm2D from enhancer.models.complexnn.rnn import ComplexLSTM @@ -38,3 +39,12 @@ def test_complexlstm(): assert out[0].shape == torch.Size([13, 1, 512]) assert out[1].shape == torch.Size([13, 1, 512]) + + +def test_complexbatchnorm2d(): + sample_input = torch.rand(1, 64, 64, 14) + batchnorm = ComplexBatchNorm2D(num_features=64) + with torch.no_grad(): + out = batchnorm(sample_input) + + assert out.size() == sample_input.size()