test inference
This commit is contained in:
		
							parent
							
								
									868fde7e69
								
							
						
					
					
						commit
						9890249824
					
				|  | @ -0,0 +1,26 @@ | |||
| import pytest | ||||
| import torch | ||||
| import numpy as np | ||||
| 
 | ||||
| from enhancer.inference import Inference | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.parametrize("audio",["tests/data/vctk/clean_testset_wav/p257_166.wav",torch.rand(1,2,48000)]) | ||||
| def test_read_input(audio): | ||||
| 
 | ||||
|     read_audio = Inference.read_input(audio,48000,16000) | ||||
|     assert isinstance(read_audio,torch.Tensor) | ||||
|     assert read_audio.shape[0] == 1 | ||||
| 
 | ||||
| def test_batchify(): | ||||
|     rand = torch.rand(1,1000) | ||||
|     batched_rand = Inference.batchify(rand, window_size = 100, step_size=100) | ||||
|     assert batched_rand.shape[0] == 12 | ||||
| 
 | ||||
| def test_aggregate(): | ||||
|     rand = torch.rand(12,1,100) | ||||
|     agg_rand = Inference.aggreagate(data=rand,window_size=100,total_frames=1000,step_size=100) | ||||
|     assert agg_rand.shape[-1] == 1000 | ||||
| 
 | ||||
| 
 | ||||
|      | ||||
		Loading…
	
		Reference in New Issue
	
	 shahules786
						shahules786