complex batchnorm 2d test
This commit is contained in:
parent
da1b986d31
commit
d3e052c5f3
|
|
@ -1,6 +1,7 @@
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from enhancer.models.complexnn.conv import ComplexConv2d, ComplexConvTranspose2d
|
from enhancer.models.complexnn.conv import ComplexConv2d, ComplexConvTranspose2d
|
||||||
|
from enhancer.models.complexnn.norm import ComplexBatchNorm2D
|
||||||
from enhancer.models.complexnn.rnn import ComplexLSTM
|
from enhancer.models.complexnn.rnn import ComplexLSTM
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -38,3 +39,12 @@ def test_complexlstm():
|
||||||
|
|
||||||
assert out[0].shape == torch.Size([13, 1, 512])
|
assert out[0].shape == torch.Size([13, 1, 512])
|
||||||
assert out[1].shape == torch.Size([13, 1, 512])
|
assert out[1].shape == torch.Size([13, 1, 512])
|
||||||
|
|
||||||
|
|
||||||
|
def test_complexbatchnorm2d():
|
||||||
|
sample_input = torch.rand(1, 64, 64, 14)
|
||||||
|
batchnorm = ComplexBatchNorm2D(num_features=64)
|
||||||
|
with torch.no_grad():
|
||||||
|
out = batchnorm(sample_input)
|
||||||
|
|
||||||
|
assert out.size() == sample_input.size()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue