This commit is contained in:
shahules786 2022-10-05 15:54:54 +05:30
parent 8ac01b846d
commit 8a07cb8712
5 changed files with 83 additions and 75 deletions

View File

@ -6,26 +6,28 @@ from enhancer.loss import mean_absolute_error, mean_squared_error
loss_functions = [mean_absolute_error(), mean_squared_error()] loss_functions = [mean_absolute_error(), mean_squared_error()]
def check_loss_shapes_compatibility(loss_fun): def check_loss_shapes_compatibility(loss_fun):
batch_size = 4 batch_size = 4
shape = (1,1000) shape = (1, 1000)
loss_fun(torch.rand(batch_size,*shape),torch.rand(batch_size,*shape)) loss_fun(torch.rand(batch_size, *shape), torch.rand(batch_size, *shape))
with pytest.raises(TypeError): with pytest.raises(TypeError):
loss_fun(torch.rand(4,*shape),torch.rand(6,*shape)) loss_fun(torch.rand(4, *shape), torch.rand(6, *shape))
@pytest.mark.parametrize("loss",loss_functions) @pytest.mark.parametrize("loss", loss_functions)
def test_loss_input_shapes(loss): def test_loss_input_shapes(loss):
check_loss_shapes_compatibility(loss) check_loss_shapes_compatibility(loss)
@pytest.mark.parametrize("loss",loss_functions)
@pytest.mark.parametrize("loss", loss_functions)
def test_loss_output_type(loss): def test_loss_output_type(loss):
batch_size = 4 batch_size = 4
prediction, target = torch.rand(batch_size,1,1000),torch.rand(batch_size,1,1000) prediction, target = torch.rand(batch_size, 1, 1000), torch.rand(
batch_size, 1, 1000
)
loss_value = loss(prediction, target) loss_value = loss(prediction, target)
assert isinstance(loss_value.item(),float) assert isinstance(loss_value.item(), float)

View File

@ -10,37 +10,35 @@ from enhancer.data.dataset import EnhancerDataset
@pytest.fixture @pytest.fixture
def vctk_dataset(): def vctk_dataset():
root_dir = "tests/data/vctk" root_dir = "tests/data/vctk"
files = Files(train_clean="clean_testset_wav",train_noisy="noisy_testset_wav", files = Files(
test_clean="clean_testset_wav", test_noisy="noisy_testset_wav") train_clean="clean_testset_wav",
dataset = EnhancerDataset(name="vctk",root_dir=root_dir,files=files) train_noisy="noisy_testset_wav",
test_clean="clean_testset_wav",
test_noisy="noisy_testset_wav",
)
dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files)
return dataset return dataset
@pytest.mark.parametrize("batch_size,samples",[(1,1000)]) @pytest.mark.parametrize("batch_size,samples", [(1, 1000)])
def test_forward(batch_size,samples): def test_forward(batch_size, samples):
model = Demucs() model = Demucs()
model.eval() model.eval()
data = torch.rand(batch_size,1,samples,requires_grad=False) data = torch.rand(batch_size, 1, samples, requires_grad=False)
with torch.no_grad(): with torch.no_grad():
_ = model(data) _ = model(data)
data = torch.rand(batch_size,2,samples,requires_grad=False) data = torch.rand(batch_size, 2, samples, requires_grad=False)
with torch.no_grad(): with torch.no_grad():
with pytest.raises(TypeError): with pytest.raises(TypeError):
_ = model(data) _ = model(data)
@pytest.mark.parametrize("dataset,channels,loss", @pytest.mark.parametrize(
[(pytest.lazy_fixture("vctk_dataset"),1,["mae","mse"])]) "dataset,channels,loss",
def test_demucs_init(dataset,channels,loss): [(pytest.lazy_fixture("vctk_dataset"), 1, ["mae", "mse"])],
)
def test_demucs_init(dataset, channels, loss):
with torch.no_grad(): with torch.no_grad():
model = Demucs(num_channels=channels,dataset=dataset,loss=loss) model = Demucs(num_channels=channels, dataset=dataset, loss=loss)

View File

@ -10,37 +10,35 @@ from enhancer.data.dataset import EnhancerDataset
@pytest.fixture @pytest.fixture
def vctk_dataset(): def vctk_dataset():
root_dir = "tests/data/vctk" root_dir = "tests/data/vctk"
files = Files(train_clean="clean_testset_wav",train_noisy="noisy_testset_wav", files = Files(
test_clean="clean_testset_wav", test_noisy="noisy_testset_wav") train_clean="clean_testset_wav",
dataset = EnhancerDataset(name="vctk",root_dir=root_dir,files=files) train_noisy="noisy_testset_wav",
test_clean="clean_testset_wav",
test_noisy="noisy_testset_wav",
)
dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files)
return dataset return dataset
@pytest.mark.parametrize("batch_size,samples",[(1,1000)]) @pytest.mark.parametrize("batch_size,samples", [(1, 1000)])
def test_forward(batch_size,samples): def test_forward(batch_size, samples):
model = WaveUnet() model = WaveUnet()
model.eval() model.eval()
data = torch.rand(batch_size,1,samples,requires_grad=False) data = torch.rand(batch_size, 1, samples, requires_grad=False)
with torch.no_grad(): with torch.no_grad():
_ = model(data) _ = model(data)
data = torch.rand(batch_size,2,samples,requires_grad=False) data = torch.rand(batch_size, 2, samples, requires_grad=False)
with torch.no_grad(): with torch.no_grad():
with pytest.raises(TypeError): with pytest.raises(TypeError):
_ = model(data) _ = model(data)
@pytest.mark.parametrize("dataset,channels,loss", @pytest.mark.parametrize(
[(pytest.lazy_fixture("vctk_dataset"),1,["mae","mse"])]) "dataset,channels,loss",
def test_demucs_init(dataset,channels,loss): [(pytest.lazy_fixture("vctk_dataset"), 1, ["mae", "mse"])],
)
def test_demucs_init(dataset, channels, loss):
with torch.no_grad(): with torch.no_grad():
model = WaveUnet(num_channels=channels,dataset=dataset,loss=loss) model = WaveUnet(num_channels=channels, dataset=dataset, loss=loss)

View File

@ -4,22 +4,26 @@ import torch
from enhancer.inference import Inference from enhancer.inference import Inference
@pytest.mark.parametrize("audio",["tests/data/vctk/clean_testset_wav/p257_166.wav",torch.rand(1,2,48000)]) @pytest.mark.parametrize(
"audio",
["tests/data/vctk/clean_testset_wav/p257_166.wav", torch.rand(1, 2, 48000)],
)
def test_read_input(audio): def test_read_input(audio):
read_audio = Inference.read_input(audio,48000,16000) read_audio = Inference.read_input(audio, 48000, 16000)
assert isinstance(read_audio,torch.Tensor) assert isinstance(read_audio, torch.Tensor)
assert read_audio.shape[0] == 1 assert read_audio.shape[0] == 1
def test_batchify(): def test_batchify():
rand = torch.rand(1,1000) rand = torch.rand(1, 1000)
batched_rand = Inference.batchify(rand, window_size = 100, step_size=100) batched_rand = Inference.batchify(rand, window_size=100, step_size=100)
assert batched_rand.shape[0] == 12 assert batched_rand.shape[0] == 12
def test_aggregate(): def test_aggregate():
rand = torch.rand(12,1,100) rand = torch.rand(12, 1, 100)
agg_rand = Inference.aggreagate(data=rand,window_size=100,total_frames=1000,step_size=100) agg_rand = Inference.aggreagate(
data=rand, window_size=100, total_frames=1000, step_size=100
)
assert agg_rand.shape[-1] == 1000 assert agg_rand.shape[-1] == 1000

View File

@ -7,40 +7,46 @@ from enhancer.utils.io import Audio
from enhancer.utils.config import Files from enhancer.utils.config import Files
from enhancer.data.fileprocessor import Fileprocessor from enhancer.data.fileprocessor import Fileprocessor
def test_io_channel(): def test_io_channel():
input_audio = np.random.rand(2,32000) input_audio = np.random.rand(2, 32000)
audio = Audio(mono=True,return_tensor=False) audio = Audio(mono=True, return_tensor=False)
output_audio = audio(input_audio) output_audio = audio(input_audio)
assert output_audio.shape[0] == 1 assert output_audio.shape[0] == 1
def test_io_resampling(): def test_io_resampling():
input_audio = np.random.rand(1,32000) input_audio = np.random.rand(1, 32000)
resampled_audio = Audio.resample_audio(input_audio,16000,8000) resampled_audio = Audio.resample_audio(input_audio, 16000, 8000)
input_audio = torch.rand(1,32000) input_audio = torch.rand(1, 32000)
resampled_audio_pt = Audio.resample_audio(input_audio,16000,8000) resampled_audio_pt = Audio.resample_audio(input_audio, 16000, 8000)
assert resampled_audio.shape[1] == resampled_audio_pt.size(1) == 16000 assert resampled_audio.shape[1] == resampled_audio_pt.size(1) == 16000
def test_fileprocessor_vctk(): def test_fileprocessor_vctk():
fp = Fileprocessor.from_name("vctk","tests/data/vctk/clean_testset_wav", fp = Fileprocessor.from_name(
"tests/data/vctk/noisy_testset_wav",48000) "vctk",
"tests/data/vctk/clean_testset_wav",
"tests/data/vctk/noisy_testset_wav",
48000,
)
matching_dict = fp.prepare_matching_dict() matching_dict = fp.prepare_matching_dict()
assert len(matching_dict)==2 assert len(matching_dict) == 2
@pytest.mark.parametrize("dataset_name",["vctk","dns-2020"])
@pytest.mark.parametrize("dataset_name", ["vctk", "dns-2020"])
def test_fileprocessor_names(dataset_name): def test_fileprocessor_names(dataset_name):
fp = Fileprocessor.from_name(dataset_name,"clean_dir","noisy_dir",16000) fp = Fileprocessor.from_name(dataset_name, "clean_dir", "noisy_dir", 16000)
assert hasattr(fp.matching_function, '__call__') assert hasattr(fp.matching_function, "__call__")
def test_fileprocessor_invaliname(): def test_fileprocessor_invaliname():
with pytest.raises(ValueError): with pytest.raises(ValueError):
fp = Fileprocessor.from_name("undefined","clean_dir","noisy_dir",16000).prepare_matching_dict() fp = Fileprocessor.from_name(
"undefined", "clean_dir", "noisy_dir", 16000
).prepare_matching_dict()