47 lines
		
	
	
		
			1.2 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			47 lines
		
	
	
		
			1.2 KiB
		
	
	
	
		
			Python
		
	
	
	
| import pytest
 | |
| import torch
 | |
| from enhancer import data
 | |
| 
 | |
| from enhancer.utils.config import Files
 | |
| from enhancer.models import WaveUnet
 | |
| from enhancer.data.dataset import EnhancerDataset
 | |
| 
 | |
| 
 | |
@pytest.fixture
def vctk_dataset():
    """Provide an ``EnhancerDataset`` named "vctk" rooted at ``tests/data/vctk``.

    The train and test splits point at the same clean/noisy directories,
    so a single pair of folders serves both.
    """
    clean_dir = "clean_testset_wav"
    noisy_dir = "noisy_testset_wav"
    file_layout = Files(
        train_clean=clean_dir,
        train_noisy=noisy_dir,
        test_clean=clean_dir,
        test_noisy=noisy_dir,
    )
    return EnhancerDataset(
        name="vctk",
        root_dir="tests/data/vctk",
        files=file_layout,
    )
 | |
|     
 | |
| 
 | |
| 
 | |
@pytest.mark.parametrize("batch_size,samples", [(1, 1000)])
def test_forward(batch_size, samples):
    """Run a default ``WaveUnet`` forward pass on random input.

    A single-channel batch must pass through the model, while a
    two-channel batch is expected to raise ``TypeError``.
    """
    net = WaveUnet()
    net.eval()

    # Single-channel input: the call itself is the check; output is discarded.
    mono_batch = torch.rand(batch_size, 1, samples, requires_grad=False)
    with torch.no_grad():
        _ = net(mono_batch)

    # Two-channel input: the default model is expected to reject it.
    stereo_batch = torch.rand(batch_size, 2, samples, requires_grad=False)
    with torch.no_grad(), pytest.raises(TypeError):
        _ = net(stereo_batch)
 | |
| 
 | |
| 
 | |
@pytest.mark.parametrize(
    "dataset,channels,loss",
    [(pytest.lazy_fixture("vctk_dataset"), 1, ["mae", "mse"])],
)
def test_demucs_init(dataset, channels, loss):
    """Check that ``WaveUnet`` can be constructed with a dataset and a loss list.

    There is no explicit assertion: the test passes when construction
    completes without raising.
    """
    # NOTE(review): the name says "demucs" but the model built here is
    # WaveUnet; the name is kept so the existing test ID does not change.
    with torch.no_grad():
        _ = WaveUnet(num_channels=channels, dataset=dataset, loss=loss)
 | |
| 
 | |
| 
 | |
| 
 | |
| 
 | |
| 
 | |
|     
 | |
| 
 |