From 8a07cb8712a437faeff8bd7475d33e7c445fbb28 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Wed, 5 Oct 2022 15:54:54 +0530 Subject: [PATCH] tests --- tests/loss_function_test.py | 20 +++++++++-------- tests/models/demucs_test.py | 36 ++++++++++++++---------------- tests/models/test_waveunet.py | 36 ++++++++++++++---------------- tests/test_inference.py | 24 +++++++++++--------- tests/utils_test.py | 42 ++++++++++++++++++++--------------- 5 files changed, 83 insertions(+), 75 deletions(-) diff --git a/tests/loss_function_test.py b/tests/loss_function_test.py index fbc982c..a4fdc62 100644 --- a/tests/loss_function_test.py +++ b/tests/loss_function_test.py @@ -6,26 +6,28 @@ from enhancer.loss import mean_absolute_error, mean_squared_error loss_functions = [mean_absolute_error(), mean_squared_error()] + def check_loss_shapes_compatibility(loss_fun): batch_size = 4 - shape = (1,1000) - loss_fun(torch.rand(batch_size,*shape),torch.rand(batch_size,*shape)) + shape = (1, 1000) + loss_fun(torch.rand(batch_size, *shape), torch.rand(batch_size, *shape)) with pytest.raises(TypeError): - loss_fun(torch.rand(4,*shape),torch.rand(6,*shape)) + loss_fun(torch.rand(4, *shape), torch.rand(6, *shape)) -@pytest.mark.parametrize("loss",loss_functions) +@pytest.mark.parametrize("loss", loss_functions) def test_loss_input_shapes(loss): check_loss_shapes_compatibility(loss) -@pytest.mark.parametrize("loss",loss_functions) + +@pytest.mark.parametrize("loss", loss_functions) def test_loss_output_type(loss): batch_size = 4 - prediction, target = torch.rand(batch_size,1,1000),torch.rand(batch_size,1,1000) + prediction, target = torch.rand(batch_size, 1, 1000), torch.rand( + batch_size, 1, 1000 + ) loss_value = loss(prediction, target) - assert isinstance(loss_value.item(),float) - - + assert isinstance(loss_value.item(), float) diff --git a/tests/models/demucs_test.py b/tests/models/demucs_test.py index a59fa04..6660888 100644 --- a/tests/models/demucs_test.py +++ b/tests/models/demucs_test.py @@ -10,37 +10,35 @@ from enhancer.data.dataset import EnhancerDataset @pytest.fixture def vctk_dataset(): root_dir = "tests/data/vctk" - files = Files(train_clean="clean_testset_wav",train_noisy="noisy_testset_wav", - test_clean="clean_testset_wav", test_noisy="noisy_testset_wav") - dataset = EnhancerDataset(name="vctk",root_dir=root_dir,files=files) + files = Files( + train_clean="clean_testset_wav", + train_noisy="noisy_testset_wav", + test_clean="clean_testset_wav", + test_noisy="noisy_testset_wav", + ) + dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files) return dataset - -@pytest.mark.parametrize("batch_size,samples",[(1,1000)]) -def test_forward(batch_size,samples): +@pytest.mark.parametrize("batch_size,samples", [(1, 1000)]) +def test_forward(batch_size, samples): model = Demucs() model.eval() - data = torch.rand(batch_size,1,samples,requires_grad=False) + data = torch.rand(batch_size, 1, samples, requires_grad=False) with torch.no_grad(): _ = model(data) - data = torch.rand(batch_size,2,samples,requires_grad=False) + data = torch.rand(batch_size, 2, samples, requires_grad=False) with torch.no_grad(): with pytest.raises(TypeError): _ = model(data) -@pytest.mark.parametrize("dataset,channels,loss", - [(pytest.lazy_fixture("vctk_dataset"),1,["mae","mse"])]) -def test_demucs_init(dataset,channels,loss): +@pytest.mark.parametrize( + "dataset,channels,loss", + [(pytest.lazy_fixture("vctk_dataset"), 1, ["mae", "mse"])], +) +def test_demucs_init(dataset, channels, loss): with torch.no_grad(): - model = Demucs(num_channels=channels,dataset=dataset,loss=loss) - - - - - - - + model = Demucs(num_channels=channels, dataset=dataset, loss=loss) diff --git a/tests/models/test_waveunet.py b/tests/models/test_waveunet.py index 43fd14d..c83966b 100644 --- a/tests/models/test_waveunet.py +++ b/tests/models/test_waveunet.py @@ -10,37 +10,35 @@ from enhancer.data.dataset import EnhancerDataset @pytest.fixture def vctk_dataset(): root_dir = "tests/data/vctk" - files = Files(train_clean="clean_testset_wav",train_noisy="noisy_testset_wav", - test_clean="clean_testset_wav", test_noisy="noisy_testset_wav") - dataset = EnhancerDataset(name="vctk",root_dir=root_dir,files=files) + files = Files( + train_clean="clean_testset_wav", + train_noisy="noisy_testset_wav", + test_clean="clean_testset_wav", + test_noisy="noisy_testset_wav", + ) + dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files) return dataset - -@pytest.mark.parametrize("batch_size,samples",[(1,1000)]) -def test_forward(batch_size,samples): +@pytest.mark.parametrize("batch_size,samples", [(1, 1000)]) +def test_forward(batch_size, samples): model = WaveUnet() model.eval() - data = torch.rand(batch_size,1,samples,requires_grad=False) + data = torch.rand(batch_size, 1, samples, requires_grad=False) with torch.no_grad(): _ = model(data) - data = torch.rand(batch_size,2,samples,requires_grad=False) + data = torch.rand(batch_size, 2, samples, requires_grad=False) with torch.no_grad(): with pytest.raises(TypeError): _ = model(data) -@pytest.mark.parametrize("dataset,channels,loss", - [(pytest.lazy_fixture("vctk_dataset"),1,["mae","mse"])]) -def test_demucs_init(dataset,channels,loss): +@pytest.mark.parametrize( + "dataset,channels,loss", + [(pytest.lazy_fixture("vctk_dataset"), 1, ["mae", "mse"])], +) +def test_demucs_init(dataset, channels, loss): with torch.no_grad(): - model = WaveUnet(num_channels=channels,dataset=dataset,loss=loss) - - - - - - - + model = WaveUnet(num_channels=channels, dataset=dataset, loss=loss) diff --git a/tests/test_inference.py b/tests/test_inference.py index 5eb7442..a6e2423 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -4,22 +4,26 @@ import torch from enhancer.inference import Inference -@pytest.mark.parametrize("audio",["tests/data/vctk/clean_testset_wav/p257_166.wav",torch.rand(1,2,48000)]) +@pytest.mark.parametrize( + "audio", + ["tests/data/vctk/clean_testset_wav/p257_166.wav", torch.rand(1, 2, 48000)], +) def test_read_input(audio): - read_audio = Inference.read_input(audio,48000,16000) - assert isinstance(read_audio,torch.Tensor) + read_audio = Inference.read_input(audio, 48000, 16000) + assert isinstance(read_audio, torch.Tensor) assert read_audio.shape[0] == 1 + def test_batchify(): - rand = torch.rand(1,1000) - batched_rand = Inference.batchify(rand, window_size = 100, step_size=100) + rand = torch.rand(1, 1000) + batched_rand = Inference.batchify(rand, window_size=100, step_size=100) assert batched_rand.shape[0] == 12 + def test_aggregate(): - rand = torch.rand(12,1,100) - agg_rand = Inference.aggreagate(data=rand,window_size=100,total_frames=1000,step_size=100) + rand = torch.rand(12, 1, 100) + agg_rand = Inference.aggreagate( + data=rand, window_size=100, total_frames=1000, step_size=100 + ) assert agg_rand.shape[-1] == 1000 - - - \ No newline at end of file diff --git a/tests/utils_test.py b/tests/utils_test.py index 413bfac..93a9094 100644 --- a/tests/utils_test.py +++ b/tests/utils_test.py @@ -7,40 +7,46 @@ from enhancer.utils.io import Audio from enhancer.utils.config import Files from enhancer.data.fileprocessor import Fileprocessor + def test_io_channel(): - input_audio = np.random.rand(2,32000) - audio = Audio(mono=True,return_tensor=False) + input_audio = np.random.rand(2, 32000) + audio = Audio(mono=True, return_tensor=False) output_audio = audio(input_audio) assert output_audio.shape[0] == 1 + def test_io_resampling(): - input_audio = np.random.rand(1,32000) - resampled_audio = Audio.resample_audio(input_audio,16000,8000) + input_audio = np.random.rand(1, 32000) + resampled_audio = Audio.resample_audio(input_audio, 16000, 8000) - input_audio = torch.rand(1,32000) - resampled_audio_pt = Audio.resample_audio(input_audio,16000,8000) + input_audio = torch.rand(1, 32000) + resampled_audio_pt = Audio.resample_audio(input_audio, 16000, 8000) assert resampled_audio.shape[1] == resampled_audio_pt.size(1) == 16000 + def test_fileprocessor_vctk(): - fp = Fileprocessor.from_name("vctk","tests/data/vctk/clean_testset_wav", - "tests/data/vctk/noisy_testset_wav",48000) + fp = Fileprocessor.from_name( + "vctk", + "tests/data/vctk/clean_testset_wav", + "tests/data/vctk/noisy_testset_wav", + 48000, + ) matching_dict = fp.prepare_matching_dict() - assert len(matching_dict)==2 + assert len(matching_dict) == 2 -@pytest.mark.parametrize("dataset_name",["vctk","dns-2020"]) + +@pytest.mark.parametrize("dataset_name", ["vctk", "dns-2020"]) def test_fileprocessor_names(dataset_name): - fp = Fileprocessor.from_name(dataset_name,"clean_dir","noisy_dir",16000) - assert hasattr(fp.matching_function, '__call__') + fp = Fileprocessor.from_name(dataset_name, "clean_dir", "noisy_dir", 16000) + assert hasattr(fp.matching_function, "__call__") + def test_fileprocessor_invaliname(): with pytest.raises(ValueError): - fp = Fileprocessor.from_name("undefined","clean_dir","noisy_dir",16000).prepare_matching_dict() - - - - - + fp = Fileprocessor.from_name( + "undefined", "clean_dir", "noisy_dir", 16000 + ).prepare_matching_dict()