fix tests
This commit is contained in:
		
							parent
							
								
									1a4102cc53
								
							
						
					
					
						commit
						77699ce7f9
					
				|  | @ -1,8 +1,8 @@ | ||||||
| import torch | import torch | ||||||
| 
 | 
 | ||||||
| from enhancer.models.complexnn.conv import ComplexConv2d, ComplexConvTranspose2d | from enhancer.models.complexnn.conv import ComplexConv2d, ComplexConvTranspose2d | ||||||
| from enhancer.models.complexnn.norm import ComplexBatchNorm2D |  | ||||||
| from enhancer.models.complexnn.rnn import ComplexLSTM | from enhancer.models.complexnn.rnn import ComplexLSTM | ||||||
|  | from enhancer.models.complexnn.utils import ComplexBatchNorm2D | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def test_complexconv2d(): | def test_complexconv2d(): | ||||||
|  | @ -12,7 +12,7 @@ def test_complexconv2d(): | ||||||
|     ) |     ) | ||||||
|     with torch.no_grad(): |     with torch.no_grad(): | ||||||
|         out = conv(sample_input) |         out = conv(sample_input) | ||||||
|     assert out.shape == torch.Size([1, 32, 128, 14]) |     assert out.shape == torch.Size([1, 32, 128, 13]) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def test_complexconvtranspose2d(): | def test_complexconvtranspose2d(): | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	 shahules786
						shahules786