44 lines
		
	
	
		
			1.2 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			44 lines
		
	
	
		
			1.2 KiB
		
	
	
	
		
			Python
		
	
	
	
import pytest
 | 
						|
import torch
 | 
						|
 | 
						|
from enhancer.data.dataset import EnhancerDataset
 | 
						|
from enhancer.models import Demucs
 | 
						|
from enhancer.utils.config import Files
 | 
						|
 | 
						|
 | 
						|
@pytest.fixture
def vctk_dataset():
    """Provide an ``EnhancerDataset`` named "vctk" rooted at ``tests/data/vctk``.

    The train and test entries of ``Files`` are given the same pair of
    folder names, so both splits refer to the same clean/noisy directories.
    """
    clean_dir = "clean_testset_wav"
    noisy_dir = "noisy_testset_wav"
    split_files = Files(
        train_clean=clean_dir,
        train_noisy=noisy_dir,
        test_clean=clean_dir,
        test_noisy=noisy_dir,
    )
    return EnhancerDataset(
        name="vctk",
        root_dir="tests/data/vctk",
        files=split_files,
    )
 | 
						|
 | 
						|
 | 
						|
@pytest.mark.parametrize("batch_size,samples", [(1, 1000)])
def test_forward(batch_size, samples):
    """Check the forward pass of a default ``Demucs`` model in eval mode.

    A random batch with one channel must pass through the model without
    raising; a random batch with two channels must raise ``TypeError``.
    Both calls run with gradient tracking disabled.
    """
    net = Demucs()
    net.eval()

    # Single-channel input: the call only has to complete.
    mono_batch = torch.rand(batch_size, 1, samples, requires_grad=False)
    with torch.no_grad():
        net(mono_batch)

    # Two-channel input: the model is expected to reject it.
    stereo_batch = torch.rand(batch_size, 2, samples, requires_grad=False)
    with torch.no_grad(), pytest.raises(TypeError):
        net(stereo_batch)
 | 
						|
 | 
						|
 | 
						|
@pytest.mark.parametrize(
    "dataset,channels,loss",
    [
        (pytest.lazy_fixture("vctk_dataset"), 1, ["mae", "mse"]),
    ],
)
def test_demucs_init(dataset, channels, loss):
    """Check that ``Demucs`` can be constructed with a dataset and loss list.

    The test passes when the constructor returns without raising; the
    resulting model is discarded.
    """
    init_kwargs = {
        "num_channels": channels,
        "dataset": dataset,
        "loss": loss,
    }
    with torch.no_grad():
        Demucs(**init_kwargs)
 |