From 6626ad75e71f8639549e372ccad0f856f1ceb373 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Mon, 7 Nov 2022 11:34:21 +0530 Subject: [PATCH] fix tests --- tests/models/demucs_test.py | 2 +- tests/models/test_dccrn.py | 43 +++++++++++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 1 deletion(-) create mode 100644 tests/models/test_dccrn.py diff --git a/tests/models/demucs_test.py b/tests/models/demucs_test.py index f5a0ec4..29e030e 100644 --- a/tests/models/demucs_test.py +++ b/tests/models/demucs_test.py @@ -30,7 +30,7 @@ def test_forward(batch_size, samples): data = torch.rand(batch_size, 2, samples, requires_grad=False) with torch.no_grad(): - with pytest.raises(TypeError): + with pytest.raises(ValueError): _ = model(data) diff --git a/tests/models/test_dccrn.py b/tests/models/test_dccrn.py new file mode 100644 index 0000000..96a853b --- /dev/null +++ b/tests/models/test_dccrn.py @@ -0,0 +1,43 @@ +import pytest +import torch + +from enhancer.data.dataset import EnhancerDataset +from enhancer.models.dccrn import DCCRN +from enhancer.utils.config import Files + + +@pytest.fixture +def vctk_dataset(): + root_dir = "tests/data/vctk" + files = Files( + train_clean="clean_testset_wav", + train_noisy="noisy_testset_wav", + test_clean="clean_testset_wav", + test_noisy="noisy_testset_wav", + ) + dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files) + return dataset + + +@pytest.mark.parametrize("batch_size,samples", [(1, 1000)]) +def test_forward(batch_size, samples): + model = DCCRN() + model.eval() + + data = torch.rand(batch_size, 1, samples, requires_grad=False) + with torch.no_grad(): + _ = model(data) + + data = torch.rand(batch_size, 2, samples, requires_grad=False) + with torch.no_grad(): + with pytest.raises(ValueError): + _ = model(data) + + +@pytest.mark.parametrize( + "dataset,channels,loss", + [(pytest.lazy_fixture("vctk_dataset"), 1, ["mae", "mse"])], +) +def test_demucs_init(dataset, channels, loss): + with torch.no_grad(): + _ = DCCRN(num_channels=channels, dataset=dataset, loss=loss)