tests
This commit is contained in:
parent 8ac01b846d
commit 8a07cb8712
@@ -6,26 +6,28 @@ from enhancer.loss import mean_absolute_error, mean_squared_error
 loss_functions = [mean_absolute_error(), mean_squared_error()]


 def check_loss_shapes_compatibility(loss_fun):

     batch_size = 4
-    shape = (1,1000)
-    loss_fun(torch.rand(batch_size,*shape),torch.rand(batch_size,*shape))
+    shape = (1, 1000)
+    loss_fun(torch.rand(batch_size, *shape), torch.rand(batch_size, *shape))

     with pytest.raises(TypeError):
-        loss_fun(torch.rand(4,*shape),torch.rand(6,*shape))
+        loss_fun(torch.rand(4, *shape), torch.rand(6, *shape))


-@pytest.mark.parametrize("loss",loss_functions)
+@pytest.mark.parametrize("loss", loss_functions)
 def test_loss_input_shapes(loss):
     check_loss_shapes_compatibility(loss)


-@pytest.mark.parametrize("loss",loss_functions)
+@pytest.mark.parametrize("loss", loss_functions)
 def test_loss_output_type(loss):

     batch_size = 4
-    prediction, target = torch.rand(batch_size,1,1000),torch.rand(batch_size,1,1000)
+    prediction, target = torch.rand(batch_size, 1, 1000), torch.rand(
+        batch_size, 1, 1000
+    )
     loss_value = loss(prediction, target)
-    assert isinstance(loss_value.item(),float)
+    assert isinstance(loss_value.item(), float)
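Context for reviewers: the assertions in this hunk pin down only three behaviours of the loss API — each loss is used as an instantiated callable, mismatched batch sizes raise TypeError, and the result is a scalar tensor whose .item() is a float. A minimal sketch consistent with that contract (an illustration only, not the enhancer.loss implementation) looks like this:

import torch
import torch.nn as nn


class MeanAbsoluteErrorSketch(nn.Module):
    # Hypothetical stand-in for enhancer.loss.mean_absolute_error.
    def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        if prediction.size(0) != target.size(0):
            # check_loss_shapes_compatibility expects a TypeError on batch mismatch.
            raise TypeError(
                f"batch size mismatch: {prediction.size(0)} vs {target.size(0)}"
            )
        # Reduce to a scalar tensor so loss_value.item() is a Python float.
        return torch.mean(torch.abs(prediction - target))


loss = MeanAbsoluteErrorSketch()
value = loss(torch.rand(4, 1, 1000), torch.rand(4, 1, 1000))
assert isinstance(value.item(), float)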
@@ -10,37 +10,35 @@ from enhancer.data.dataset import EnhancerDataset
 @pytest.fixture
 def vctk_dataset():
     root_dir = "tests/data/vctk"
-    files = Files(train_clean="clean_testset_wav",train_noisy="noisy_testset_wav",
-                  test_clean="clean_testset_wav", test_noisy="noisy_testset_wav")
-    dataset = EnhancerDataset(name="vctk",root_dir=root_dir,files=files)
+    files = Files(
+        train_clean="clean_testset_wav",
+        train_noisy="noisy_testset_wav",
+        test_clean="clean_testset_wav",
+        test_noisy="noisy_testset_wav",
+    )
+    dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files)
     return dataset


-@pytest.mark.parametrize("batch_size,samples",[(1,1000)])
-def test_forward(batch_size,samples):
+@pytest.mark.parametrize("batch_size,samples", [(1, 1000)])
+def test_forward(batch_size, samples):
     model = Demucs()
     model.eval()

-    data = torch.rand(batch_size,1,samples,requires_grad=False)
+    data = torch.rand(batch_size, 1, samples, requires_grad=False)
     with torch.no_grad():
         _ = model(data)

-    data = torch.rand(batch_size,2,samples,requires_grad=False)
+    data = torch.rand(batch_size, 2, samples, requires_grad=False)
     with torch.no_grad():
         with pytest.raises(TypeError):
             _ = model(data)


-@pytest.mark.parametrize("dataset,channels,loss",
-    [(pytest.lazy_fixture("vctk_dataset"),1,["mae","mse"])])
-def test_demucs_init(dataset,channels,loss):
+@pytest.mark.parametrize(
+    "dataset,channels,loss",
+    [(pytest.lazy_fixture("vctk_dataset"), 1, ["mae", "mse"])],
+)
+def test_demucs_init(dataset, channels, loss):
     with torch.no_grad():
-        model = Demucs(num_channels=channels,dataset=dataset,loss=loss)
+        model = Demucs(num_channels=channels, dataset=dataset, loss=loss)
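Note on the parametrization above: pytest.lazy_fixture is not part of core pytest; it is provided by the pytest-lazy-fixture plugin, so that plugin is an implicit test dependency of this file. The pattern in isolation (a minimal sketch, not code from this repo):

import pytest


@pytest.fixture
def numbers():
    return [1, 2, 3]


# lazy_fixture defers fixture resolution so a fixture can appear as a parametrize value.
@pytest.mark.parametrize("data", [pytest.lazy_fixture("numbers")])
def test_sum(data):
    assert sum(data) == 6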
@@ -10,37 +10,35 @@ from enhancer.data.dataset import EnhancerDataset
 @pytest.fixture
 def vctk_dataset():
     root_dir = "tests/data/vctk"
-    files = Files(train_clean="clean_testset_wav",train_noisy="noisy_testset_wav",
-                  test_clean="clean_testset_wav", test_noisy="noisy_testset_wav")
-    dataset = EnhancerDataset(name="vctk",root_dir=root_dir,files=files)
+    files = Files(
+        train_clean="clean_testset_wav",
+        train_noisy="noisy_testset_wav",
+        test_clean="clean_testset_wav",
+        test_noisy="noisy_testset_wav",
+    )
+    dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files)
     return dataset


-@pytest.mark.parametrize("batch_size,samples",[(1,1000)])
-def test_forward(batch_size,samples):
+@pytest.mark.parametrize("batch_size,samples", [(1, 1000)])
+def test_forward(batch_size, samples):
     model = WaveUnet()
     model.eval()

-    data = torch.rand(batch_size,1,samples,requires_grad=False)
+    data = torch.rand(batch_size, 1, samples, requires_grad=False)
     with torch.no_grad():
         _ = model(data)

-    data = torch.rand(batch_size,2,samples,requires_grad=False)
+    data = torch.rand(batch_size, 2, samples, requires_grad=False)
     with torch.no_grad():
         with pytest.raises(TypeError):
             _ = model(data)


-@pytest.mark.parametrize("dataset,channels,loss",
-    [(pytest.lazy_fixture("vctk_dataset"),1,["mae","mse"])])
-def test_demucs_init(dataset,channels,loss):
+@pytest.mark.parametrize(
+    "dataset,channels,loss",
+    [(pytest.lazy_fixture("vctk_dataset"), 1, ["mae", "mse"])],
+)
+def test_demucs_init(dataset, channels, loss):
     with torch.no_grad():
-        model = WaveUnet(num_channels=channels,dataset=dataset,loss=loss)
+        model = WaveUnet(num_channels=channels, dataset=dataset, loss=loss)
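The two forward tests (Demucs above, WaveUnet here) encode the same input contract: both models accept mono waveforms shaped (batch, 1, samples) and are expected to raise TypeError when a second channel is present. A standalone sketch of that check (a hypothetical helper for illustration, not code taken from the models):

import torch


def check_mono_waveform(batch: torch.Tensor) -> torch.Tensor:
    # Accept only (batch, 1, samples); a stereo tensor triggers the TypeError
    # that both parametrized forward tests assert on.
    if batch.ndim != 3 or batch.size(1) != 1:
        raise TypeError(f"expected (batch, 1, samples), got {tuple(batch.shape)}")
    return batch


check_mono_waveform(torch.rand(1, 1, 1000))    # passes
# check_mono_waveform(torch.rand(1, 2, 1000))  # would raise TypeError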
@@ -4,22 +4,26 @@ import torch
 from enhancer.inference import Inference


-@pytest.mark.parametrize("audio",["tests/data/vctk/clean_testset_wav/p257_166.wav",torch.rand(1,2,48000)])
+@pytest.mark.parametrize(
+    "audio",
+    ["tests/data/vctk/clean_testset_wav/p257_166.wav", torch.rand(1, 2, 48000)],
+)
 def test_read_input(audio):

-    read_audio = Inference.read_input(audio,48000,16000)
-    assert isinstance(read_audio,torch.Tensor)
+    read_audio = Inference.read_input(audio, 48000, 16000)
+    assert isinstance(read_audio, torch.Tensor)
     assert read_audio.shape[0] == 1


 def test_batchify():
-    rand = torch.rand(1,1000)
-    batched_rand = Inference.batchify(rand, window_size = 100, step_size=100)
+    rand = torch.rand(1, 1000)
+    batched_rand = Inference.batchify(rand, window_size=100, step_size=100)
     assert batched_rand.shape[0] == 12


 def test_aggregate():
-    rand = torch.rand(12,1,100)
-    agg_rand = Inference.aggreagate(data=rand,window_size=100,total_frames=1000,step_size=100)
+    rand = torch.rand(12, 1, 100)
+    agg_rand = Inference.aggreagate(
+        data=rand, window_size=100, total_frames=1000, step_size=100
+    )
     assert agg_rand.shape[-1] == 1000
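test_batchify expects 12 windows from 1000 frames with a 100-frame window and 100-frame step — more than the naive 1000 / 100 = 10 — so some padding must happen before windowing, and test_aggregate checks that Inference.aggreagate (spelled that way in the code under test) folds those 12 windows back to the original 1000 frames. One padding scheme consistent with the asserted count, stated purely as an assumption about the implementation:

# Assumed padding of one full window on each side; the tests only fix the
# resulting count of 12 and the 1000-frame round trip, not the scheme itself.
total_frames, window_size, step_size = 1000, 100, 100
padded = total_frames + 2 * window_size
num_windows = (padded - window_size) // step_size + 1
assert num_windows == 12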
@@ -7,40 +7,46 @@ from enhancer.utils.io import Audio
 from enhancer.utils.config import Files
 from enhancer.data.fileprocessor import Fileprocessor


 def test_io_channel():

-    input_audio = np.random.rand(2,32000)
-    audio = Audio(mono=True,return_tensor=False)
+    input_audio = np.random.rand(2, 32000)
+    audio = Audio(mono=True, return_tensor=False)
     output_audio = audio(input_audio)
     assert output_audio.shape[0] == 1


 def test_io_resampling():

-    input_audio = np.random.rand(1,32000)
-    resampled_audio = Audio.resample_audio(input_audio,16000,8000)
+    input_audio = np.random.rand(1, 32000)
+    resampled_audio = Audio.resample_audio(input_audio, 16000, 8000)

-    input_audio = torch.rand(1,32000)
-    resampled_audio_pt = Audio.resample_audio(input_audio,16000,8000)
+    input_audio = torch.rand(1, 32000)
+    resampled_audio_pt = Audio.resample_audio(input_audio, 16000, 8000)

     assert resampled_audio.shape[1] == resampled_audio_pt.size(1) == 16000


 def test_fileprocessor_vctk():

-    fp = Fileprocessor.from_name("vctk","tests/data/vctk/clean_testset_wav",
-                                 "tests/data/vctk/noisy_testset_wav",48000)
+    fp = Fileprocessor.from_name(
+        "vctk",
+        "tests/data/vctk/clean_testset_wav",
+        "tests/data/vctk/noisy_testset_wav",
+        48000,
+    )
     matching_dict = fp.prepare_matching_dict()
-    assert len(matching_dict)==2
+    assert len(matching_dict) == 2


-@pytest.mark.parametrize("dataset_name",["vctk","dns-2020"])
+@pytest.mark.parametrize("dataset_name", ["vctk", "dns-2020"])
 def test_fileprocessor_names(dataset_name):
-    fp = Fileprocessor.from_name(dataset_name,"clean_dir","noisy_dir",16000)
-    assert hasattr(fp.matching_function, '__call__')
+    fp = Fileprocessor.from_name(dataset_name, "clean_dir", "noisy_dir", 16000)
+    assert hasattr(fp.matching_function, "__call__")


 def test_fileprocessor_invaliname():
     with pytest.raises(ValueError):
-        fp = Fileprocessor.from_name("undefined","clean_dir","noisy_dir",16000).prepare_matching_dict()
+        fp = Fileprocessor.from_name(
+            "undefined", "clean_dir", "noisy_dir", 16000
+        ).prepare_matching_dict()
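test_fileprocessor_vctk asserts only that the prepared matching dict has two entries for the bundled test data and that matching_function is callable. For VCTK-style data the clean and noisy directories typically hold identically named wav files, so pairing can be done on the basename; a hypothetical sketch of such a matcher (not the Fileprocessor implementation):

import os
from glob import glob


def match_by_basename(clean_dir: str, noisy_dir: str) -> dict:
    # Pair each clean wav with the noisy wav sharing the same basename.
    noisy = {os.path.basename(p): p for p in glob(os.path.join(noisy_dir, "*.wav"))}
    return {
        clean: noisy[os.path.basename(clean)]
        for clean in glob(os.path.join(clean_dir, "*.wav"))
        if os.path.basename(clean) in noisy
    }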