complex batchnorm 2d test

This commit is contained in:
shahules786 2022-11-03 16:06:14 +05:30
parent da1b986d31
commit d3e052c5f3
1 changed files with 10 additions and 0 deletions

View File

@ -1,6 +1,7 @@
import torch import torch
from enhancer.models.complexnn.conv import ComplexConv2d, ComplexConvTranspose2d from enhancer.models.complexnn.conv import ComplexConv2d, ComplexConvTranspose2d
from enhancer.models.complexnn.norm import ComplexBatchNorm2D
from enhancer.models.complexnn.rnn import ComplexLSTM from enhancer.models.complexnn.rnn import ComplexLSTM
@ -38,3 +39,12 @@ def test_complexlstm():
assert out[0].shape == torch.Size([13, 1, 512]) assert out[0].shape == torch.Size([13, 1, 512])
assert out[1].shape == torch.Size([13, 1, 512]) assert out[1].shape == torch.Size([13, 1, 512])
def test_complexbatchnorm2d():
sample_input = torch.rand(1, 64, 64, 14)
batchnorm = ComplexBatchNorm2D(num_features=64)
with torch.no_grad():
out = batchnorm(sample_input)
assert out.size() == sample_input.size()