fix tests
This commit is contained in:
		
							parent
							
								
									1a4102cc53
								
							
						
					
					
						commit
						77699ce7f9
					
				|  | @ -1,8 +1,8 @@ | |||
| import torch | ||||
| 
 | ||||
| from enhancer.models.complexnn.conv import ComplexConv2d, ComplexConvTranspose2d | ||||
| from enhancer.models.complexnn.norm import ComplexBatchNorm2D | ||||
| from enhancer.models.complexnn.rnn import ComplexLSTM | ||||
| from enhancer.models.complexnn.utils import ComplexBatchNorm2D | ||||
| 
 | ||||
| 
 | ||||
| def test_complexconv2d(): | ||||
|  | @ -12,7 +12,7 @@ def test_complexconv2d(): | |||
|     ) | ||||
|     with torch.no_grad(): | ||||
|         out = conv(sample_input) | ||||
|     assert out.shape == torch.Size([1, 32, 128, 14]) | ||||
|     assert out.shape == torch.Size([1, 32, 128, 13]) | ||||
| 
 | ||||
| 
 | ||||
| def test_complexconvtranspose2d(): | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	 shahules786
						shahules786