"""Tests for the DCCRN model: forward pass and construction."""

import pytest
import torch

from mayavoz.data.dataset import MayaDataset
from mayavoz.models.dccrn import DCCRN
from mayavoz.utils.config import Files


@pytest.fixture
def vctk_dataset():
    """Return a ``MayaDataset`` named "vctk" rooted at ``tests/data/vctk``.

    The train and test entries both point at the ``*_testset_wav``
    directories, so the same files back both splits.
    """
    vctk_files = Files(
        train_clean="clean_testset_wav",
        train_noisy="noisy_testset_wav",
        test_clean="clean_testset_wav",
        test_noisy="noisy_testset_wav",
    )
    return MayaDataset(name="vctk", root_dir="tests/data/vctk", files=vctk_files)


@pytest.mark.parametrize("batch_size,samples", [(1, 1000)])
def test_forward(batch_size, samples):
    """Run a default ``DCCRN`` on random input and check input validation.

    An input whose second dimension is 1 must pass through the model; an
    input whose second dimension is 2 must raise ``ValueError``.
    """
    net = DCCRN()
    net.eval()

    # Second dimension of size 1 (presumably the channel axis -- confirm
    # against the model): the forward pass is expected to succeed.
    accepted = torch.rand(batch_size, 1, samples, requires_grad=False)
    with torch.no_grad():
        net(accepted)

    # Second dimension of size 2: the model is expected to reject it.
    rejected = torch.rand(batch_size, 2, samples, requires_grad=False)
    with torch.no_grad(), pytest.raises(ValueError):
        net(rejected)


# NOTE(review): the name says "demucs" but this constructs DCCRN -- it looks
# like a copy-paste leftover. Kept as is so the collected test ID is unchanged.
@pytest.mark.parametrize(
    "dataset,channels,loss",
    [(pytest.lazy_fixture("vctk_dataset"), 1, ["mae", "mse"])],
)
def test_demucs_init(dataset, channels, loss):
    """Construct ``DCCRN`` with an explicit channel count, dataset and losses.

    The test passes if construction completes without raising.
    """
    with torch.no_grad():
        DCCRN(num_channels=channels, dataset=dataset, loss=loss)