fix tests

This commit is contained in:
shahules786 2022-11-07 11:34:21 +05:30
parent 6573bc4c5e
commit 6626ad75e7
2 changed files with 44 additions and 1 deletions

View File

@ -30,7 +30,7 @@ def test_forward(batch_size, samples):
data = torch.rand(batch_size, 2, samples, requires_grad=False) data = torch.rand(batch_size, 2, samples, requires_grad=False)
with torch.no_grad(): with torch.no_grad():
with pytest.raises(TypeError): with pytest.raises(ValueError):
_ = model(data) _ = model(data)

View File

@ -0,0 +1,43 @@
import pytest
import torch
from enhancer.data.dataset import EnhancerDataset
from enhancer.models.dccrn import DCCRN
from enhancer.utils.config import Files
@pytest.fixture
def vctk_dataset():
root_dir = "tests/data/vctk"
files = Files(
train_clean="clean_testset_wav",
train_noisy="noisy_testset_wav",
test_clean="clean_testset_wav",
test_noisy="noisy_testset_wav",
)
dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files)
return dataset
@pytest.mark.parametrize("batch_size,samples", [(1, 1000)])
def test_forward(batch_size, samples):
model = DCCRN()
model.eval()
data = torch.rand(batch_size, 1, samples, requires_grad=False)
with torch.no_grad():
_ = model(data)
data = torch.rand(batch_size, 2, samples, requires_grad=False)
with torch.no_grad():
with pytest.raises(ValueError):
_ = model(data)
@pytest.mark.parametrize(
"dataset,channels,loss",
[(pytest.lazy_fixture("vctk_dataset"), 1, ["mae", "mse"])],
)
def test_demucs_init(dataset, channels, loss):
with torch.no_grad():
_ = DCCRN(num_channels=channels, dataset=dataset, loss=loss)