tests
parent 8ac01b846d
commit 8a07cb8712
@@ -6,6 +6,7 @@ from enhancer.loss import mean_absolute_error, mean_squared_error
 loss_functions = [mean_absolute_error(), mean_squared_error()]


 def check_loss_shapes_compatibility(loss_fun):

     batch_size = 4

@@ -20,12 +21,13 @@ def check_loss_shapes_compatibility(loss_fun):
 def test_loss_input_shapes(loss):
     check_loss_shapes_compatibility(loss)


 @pytest.mark.parametrize("loss", loss_functions)
 def test_loss_output_type(loss):

     batch_size = 4
-    prediction, target = torch.rand(batch_size,1,1000),torch.rand(batch_size,1,1000)
+    prediction, target = torch.rand(batch_size, 1, 1000), torch.rand(
+        batch_size, 1, 1000
+    )
     loss_value = loss(prediction, target)
     assert isinstance(loss_value.item(), float)

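The hunks above are formatting-only; what these tests check is that each loss accepts (batch, channel, time) tensors and returns a scalar. A minimal standalone sketch of the same check, using torch.nn.L1Loss and MSELoss as stand-ins for the enhancer loss classes (the mean_absolute_error / mean_squared_error API is assumed to be callable the same way):

    import torch

    # Stand-in losses; the enhancer loss classes are assumed to be callable like this.
    for loss_fn in (torch.nn.L1Loss(), torch.nn.MSELoss()):
        prediction = torch.rand(4, 1, 1000)
        target = torch.rand(4, 1, 1000)
        loss_value = loss_fn(prediction, target)
        assert isinstance(loss_value.item(), float)  # scalar tensor -> Python float
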
@@ -10,13 +10,16 @@ from enhancer.data.dataset import EnhancerDataset
 @pytest.fixture
 def vctk_dataset():
     root_dir = "tests/data/vctk"
-    files = Files(train_clean="clean_testset_wav",train_noisy="noisy_testset_wav",
-                  test_clean="clean_testset_wav", test_noisy="noisy_testset_wav")
+    files = Files(
+        train_clean="clean_testset_wav",
+        train_noisy="noisy_testset_wav",
+        test_clean="clean_testset_wav",
+        test_noisy="noisy_testset_wav",
+    )
     dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files)
     return dataset


 @pytest.mark.parametrize("batch_size,samples", [(1, 1000)])
 def test_forward(batch_size, samples):
     model = Demucs()

@@ -32,15 +35,10 @@ def test_forward(batch_size,samples):
     _ = model(data)


-@pytest.mark.parametrize("dataset,channels,loss",
-                         [(pytest.lazy_fixture("vctk_dataset"),1,["mae","mse"])])
+@pytest.mark.parametrize(
+    "dataset,channels,loss",
+    [(pytest.lazy_fixture("vctk_dataset"), 1, ["mae", "mse"])],
+)
 def test_demucs_init(dataset, channels, loss):
     with torch.no_grad():
         model = Demucs(num_channels=channels, dataset=dataset, loss=loss)

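For reference, pytest.lazy_fixture in the decorator above comes from the pytest-lazy-fixture plugin, which lets a fixture value be passed through @pytest.mark.parametrize. A reduced sketch of the same pattern with a dummy fixture (the names here are illustrative, not from the repo):

    import pytest

    @pytest.fixture
    def dummy_dataset():
        return {"name": "dummy"}

    # pytest-lazy-fixture resolves the fixture reference at test setup time,
    # so the fixture value arrives as an ordinary parametrize argument.
    @pytest.mark.parametrize(
        "dataset,channels", [(pytest.lazy_fixture("dummy_dataset"), 1)]
    )
    def test_init_sketch(dataset, channels):
        assert dataset["name"] == "dummy"
        assert channels == 1
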
@@ -10,13 +10,16 @@ from enhancer.data.dataset import EnhancerDataset
 @pytest.fixture
 def vctk_dataset():
     root_dir = "tests/data/vctk"
-    files = Files(train_clean="clean_testset_wav",train_noisy="noisy_testset_wav",
-                  test_clean="clean_testset_wav", test_noisy="noisy_testset_wav")
+    files = Files(
+        train_clean="clean_testset_wav",
+        train_noisy="noisy_testset_wav",
+        test_clean="clean_testset_wav",
+        test_noisy="noisy_testset_wav",
+    )
     dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files)
     return dataset


 @pytest.mark.parametrize("batch_size,samples", [(1, 1000)])
 def test_forward(batch_size, samples):
     model = WaveUnet()

@@ -32,15 +35,10 @@ def test_forward(batch_size,samples):
     _ = model(data)


-@pytest.mark.parametrize("dataset,channels,loss",
-                         [(pytest.lazy_fixture("vctk_dataset"),1,["mae","mse"])])
+@pytest.mark.parametrize(
+    "dataset,channels,loss",
+    [(pytest.lazy_fixture("vctk_dataset"), 1, ["mae", "mse"])],
+)
 def test_demucs_init(dataset, channels, loss):
     with torch.no_grad():
         model = WaveUnet(num_channels=channels, dataset=dataset, loss=loss)

@@ -4,22 +4,26 @@ import torch
 from enhancer.inference import Inference


-@pytest.mark.parametrize("audio",["tests/data/vctk/clean_testset_wav/p257_166.wav",torch.rand(1,2,48000)])
+@pytest.mark.parametrize(
+    "audio",
+    ["tests/data/vctk/clean_testset_wav/p257_166.wav", torch.rand(1, 2, 48000)],
+)
 def test_read_input(audio):

     read_audio = Inference.read_input(audio, 48000, 16000)
     assert isinstance(read_audio, torch.Tensor)
     assert read_audio.shape[0] == 1


 def test_batchify():
     rand = torch.rand(1, 1000)
     batched_rand = Inference.batchify(rand, window_size=100, step_size=100)
     assert batched_rand.shape[0] == 12


 def test_aggregate():
     rand = torch.rand(12, 1, 100)
-    agg_rand = Inference.aggreagate(data=rand,window_size=100,total_frames=1000,step_size=100)
+    agg_rand = Inference.aggreagate(
+        data=rand, window_size=100, total_frames=1000, step_size=100
+    )
     assert agg_rand.shape[-1] == 1000

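A note on the expected shapes: test_batchify expects 12 windows from 1000 frames with a 100-frame window and step, which suggests Inference.batchify pads the signal at the edges; the enhancer implementation itself is not shown in this diff. The sketch below skips padding (so it yields 10 windows here) but illustrates the chunk-and-overlap-add round trip that test_aggregate relies on; all names in it are illustrative.

    import torch

    def batchify_sketch(wave: torch.Tensor, window_size: int, step_size: int) -> torch.Tensor:
        # (channels, frames) -> (num_windows, channels, window_size), no edge padding
        return wave.unfold(-1, window_size, step_size).permute(1, 0, 2)

    def aggregate_sketch(batches: torch.Tensor, total_frames: int, step_size: int) -> torch.Tensor:
        # (num_windows, channels, window_size) -> (channels, total_frames) via overlap-add
        num_windows, channels, window_size = batches.shape
        out = torch.zeros(channels, total_frames)
        for i in range(num_windows):
            start = i * step_size
            end = min(start + window_size, total_frames)
            out[:, start:end] += batches[i, :, : end - start]
        return out

    wave = torch.rand(1, 1000)
    batches = batchify_sketch(wave, window_size=100, step_size=100)   # (10, 1, 100)
    restored = aggregate_sketch(batches, total_frames=1000, step_size=100)
    assert torch.allclose(restored, wave)  # non-overlapping windows reassemble exactly
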
@@ -7,6 +7,7 @@ from enhancer.utils.io import Audio
 from enhancer.utils.config import Files
 from enhancer.data.fileprocessor import Fileprocessor


 def test_io_channel():

     input_audio = np.random.rand(2, 32000)

@@ -14,6 +15,7 @@ def test_io_channel():
     output_audio = audio(input_audio)
     assert output_audio.shape[0] == 1


 def test_io_resampling():

     input_audio = np.random.rand(1, 32000)

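test_io_channel and test_io_resampling exercise downmixing (2 channels to 1) and resampling (32 kHz to 16 kHz) in enhancer.utils.io.Audio, whose internals are not part of this diff. A rough standalone sketch of the same two operations, with numpy and scipy used purely as stand-ins:

    import numpy as np
    from scipy.signal import resample_poly

    # Downmix: average the channel axis, keeping a leading channel dimension.
    stereo = np.random.rand(2, 32000)
    mono = stereo.mean(axis=0, keepdims=True)
    assert mono.shape[0] == 1

    # Resample 32 kHz -> 16 kHz (rational factor 1/2).
    audio_32k = np.random.rand(1, 32000)
    audio_16k = resample_poly(audio_32k, up=1, down=2, axis=1)
    assert audio_16k.shape[1] == 16000
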
@@ -24,23 +26,27 @@ def test_io_resampling():

     assert resampled_audio.shape[1] == resampled_audio_pt.size(1) == 16000


 def test_fileprocessor_vctk():

-    fp = Fileprocessor.from_name("vctk","tests/data/vctk/clean_testset_wav",
-                                 "tests/data/vctk/noisy_testset_wav",48000)
+    fp = Fileprocessor.from_name(
+        "vctk",
+        "tests/data/vctk/clean_testset_wav",
+        "tests/data/vctk/noisy_testset_wav",
+        48000,
+    )
     matching_dict = fp.prepare_matching_dict()
     assert len(matching_dict) == 2


 @pytest.mark.parametrize("dataset_name", ["vctk", "dns-2020"])
 def test_fileprocessor_names(dataset_name):
     fp = Fileprocessor.from_name(dataset_name, "clean_dir", "noisy_dir", 16000)
-    assert hasattr(fp.matching_function, '__call__')
+    assert hasattr(fp.matching_function, "__call__")


 def test_fileprocessor_invaliname():
     with pytest.raises(ValueError):
-        fp = Fileprocessor.from_name("undefined","clean_dir","noisy_dir",16000).prepare_matching_dict()
+        fp = Fileprocessor.from_name(
+            "undefined", "clean_dir", "noisy_dir", 16000
+        ).prepare_matching_dict()
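test_fileprocessor_names only asserts that fp.matching_function is callable; the actual pairing logic lives in enhancer.data.fileprocessor and is not shown in this diff. A hypothetical matching function of the kind being checked (name and behaviour assumed for illustration) could pair clean and noisy files that share a basename:

    import os
    from typing import Dict

    def match_by_filename(clean_dir: str, noisy_dir: str) -> Dict[str, str]:
        # Illustrative only: map each clean file to the noisy file with the same name.
        pairs = {}
        for name in sorted(os.listdir(clean_dir)):
            candidate = os.path.join(noisy_dir, name)
            if os.path.isfile(candidate):
                pairs[os.path.join(clean_dir, name)] = candidate
        return pairs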