complex batchnorm 2d test
This commit is contained in:
parent
da1b986d31
commit
d3e052c5f3
|
|
@ -1,6 +1,7 @@
|
|||
import torch
|
||||
|
||||
from enhancer.models.complexnn.conv import ComplexConv2d, ComplexConvTranspose2d
|
||||
from enhancer.models.complexnn.norm import ComplexBatchNorm2D
|
||||
from enhancer.models.complexnn.rnn import ComplexLSTM
|
||||
|
||||
|
||||
|
|
@ -38,3 +39,12 @@ def test_complexlstm():
|
|||
|
||||
assert out[0].shape == torch.Size([13, 1, 512])
|
||||
assert out[1].shape == torch.Size([13, 1, 512])
|
||||
|
||||
|
||||
def test_complexbatchnorm2d():
|
||||
sample_input = torch.rand(1, 64, 64, 14)
|
||||
batchnorm = ComplexBatchNorm2D(num_features=64)
|
||||
with torch.no_grad():
|
||||
out = batchnorm(sample_input)
|
||||
|
||||
assert out.size() == sample_input.size()
|
||||
|
|
|
|||
Loading…
Reference in New Issue