tests waveunet
This commit is contained in:
parent
21cee225c2
commit
26a1f862f6
|
|
@ -0,0 +1,46 @@
|
|||
import pytest
|
||||
import torch
|
||||
from enhancer import data
|
||||
|
||||
from enhancer.utils.config import Files
|
||||
from enhancer.models import WaveUnet
|
||||
from enhancer.data.dataset import EnhancerDataset
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vctk_dataset():
|
||||
root_dir = "tests/data/vctk"
|
||||
files = Files(train_clean="clean_testset_wav",train_noisy="noisy_testset_wav",
|
||||
test_clean="clean_testset_wav", test_noisy="noisy_testset_wav")
|
||||
dataset = EnhancerDataset(name="vctk",root_dir=root_dir,files=files)
|
||||
return dataset
|
||||
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size,samples",[(1,1000)])
|
||||
def test_forward(batch_size,samples):
|
||||
model = WaveUnet()
|
||||
model.eval()
|
||||
|
||||
data = torch.rand(batch_size,1,samples,requires_grad=False)
|
||||
with torch.no_grad():
|
||||
_ = model(data)
|
||||
|
||||
data = torch.rand(batch_size,2,samples,requires_grad=False)
|
||||
with torch.no_grad():
|
||||
with pytest.raises(TypeError):
|
||||
_ = model(data)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dataset,channels,loss",
|
||||
[(pytest.lazy_fixture("vctk_dataset"),1,["mae","mse"])])
|
||||
def test_demucs_init(dataset,channels,loss):
|
||||
with torch.no_grad():
|
||||
model = WaveUnet(num_channels=channels,dataset=dataset,loss=loss)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue