diff --git a/tests/models/demucs_test.py b/tests/models/demucs_test.py index f5a0ec4..29e030e 100644 --- a/tests/models/demucs_test.py +++ b/tests/models/demucs_test.py @@ -30,7 +30,7 @@ def test_forward(batch_size, samples): data = torch.rand(batch_size, 2, samples, requires_grad=False) with torch.no_grad(): - with pytest.raises(TypeError): + with pytest.raises(ValueError): _ = model(data) diff --git a/tests/models/test_dccrn.py b/tests/models/test_dccrn.py new file mode 100644 index 0000000..96a853b --- /dev/null +++ b/tests/models/test_dccrn.py @@ -0,0 +1,43 @@ +import pytest +import torch + +from enhancer.data.dataset import EnhancerDataset +from enhancer.models.dccrn import DCCRN +from enhancer.utils.config import Files + + +@pytest.fixture +def vctk_dataset(): + root_dir = "tests/data/vctk" + files = Files( + train_clean="clean_testset_wav", + train_noisy="noisy_testset_wav", + test_clean="clean_testset_wav", + test_noisy="noisy_testset_wav", + ) + dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files) + return dataset + + +@pytest.mark.parametrize("batch_size,samples", [(1, 1000)]) +def test_forward(batch_size, samples): + model = DCCRN() + model.eval() + + data = torch.rand(batch_size, 1, samples, requires_grad=False) + with torch.no_grad(): + _ = model(data) + + data = torch.rand(batch_size, 2, samples, requires_grad=False) + with torch.no_grad(): + with pytest.raises(ValueError): + _ = model(data) + + +@pytest.mark.parametrize( + "dataset,channels,loss", + [(pytest.lazy_fixture("vctk_dataset"), 1, ["mae", "mse"])], +) +def test_demucs_init(dataset, channels, loss): + with torch.no_grad(): + _ = DCCRN(num_channels=channels, dataset=dataset, loss=loss)