shahules786 2022-10-05 15:54:54 +05:30
parent 8ac01b846d
commit 8a07cb8712
5 changed files with 83 additions and 75 deletions

View File

@@ -6,26 +6,28 @@ from enhancer.loss import mean_absolute_error, mean_squared_error
 loss_functions = [mean_absolute_error(), mean_squared_error()]


 def check_loss_shapes_compatibility(loss_fun):
     batch_size = 4
-    shape = (1,1000)
-    loss_fun(torch.rand(batch_size,*shape),torch.rand(batch_size,*shape))
+    shape = (1, 1000)
+    loss_fun(torch.rand(batch_size, *shape), torch.rand(batch_size, *shape))
     with pytest.raises(TypeError):
-        loss_fun(torch.rand(4,*shape),torch.rand(6,*shape))
+        loss_fun(torch.rand(4, *shape), torch.rand(6, *shape))


-@pytest.mark.parametrize("loss",loss_functions)
+@pytest.mark.parametrize("loss", loss_functions)
 def test_loss_input_shapes(loss):
     check_loss_shapes_compatibility(loss)


-@pytest.mark.parametrize("loss",loss_functions)
+@pytest.mark.parametrize("loss", loss_functions)
 def test_loss_output_type(loss):
     batch_size = 4
-    prediction, target = torch.rand(batch_size,1,1000),torch.rand(batch_size,1,1000)
+    prediction, target = torch.rand(batch_size, 1, 1000), torch.rand(
+        batch_size, 1, 1000
+    )
     loss_value = loss(prediction, target)
-    assert isinstance(loss_value.item(),float)
+    assert isinstance(loss_value.item(), float)
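
The contract exercised by these tests is simple: a loss is a callable object that accepts equally shaped (batch, channel, time) tensors, returns a scalar tensor, and rejects mismatched batch sizes with a TypeError. A minimal sketch of such a callable, assuming the enhancer losses follow this pattern (the class below is illustrative, not the library's implementation):

import torch


class MeanAbsoluteErrorSketch:
    """Illustrative stand-in for a loss with the behaviour the tests expect."""

    def __call__(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # The tests above expect a TypeError when batch sizes differ.
        if prediction.size(0) != target.size(0):
            raise TypeError("prediction and target must have the same batch size")
        # Reduce to a scalar tensor so .item() returns a plain float.
        return torch.mean(torch.abs(prediction - target))

With this sketch, MeanAbsoluteErrorSketch()(torch.rand(4, 1, 1000), torch.rand(4, 1, 1000)).item() is a float, while mismatched batch sizes raise TypeError, mirroring both test cases.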

View File

@@ -10,37 +10,35 @@ from enhancer.data.dataset import EnhancerDataset
 @pytest.fixture
 def vctk_dataset():
     root_dir = "tests/data/vctk"
-    files = Files(train_clean="clean_testset_wav",train_noisy="noisy_testset_wav",
-                test_clean="clean_testset_wav", test_noisy="noisy_testset_wav")
-    dataset = EnhancerDataset(name="vctk",root_dir=root_dir,files=files)
+    files = Files(
+        train_clean="clean_testset_wav",
+        train_noisy="noisy_testset_wav",
+        test_clean="clean_testset_wav",
+        test_noisy="noisy_testset_wav",
+    )
+    dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files)
     return dataset


-@pytest.mark.parametrize("batch_size,samples",[(1,1000)])
-def test_forward(batch_size,samples):
+@pytest.mark.parametrize("batch_size,samples", [(1, 1000)])
+def test_forward(batch_size, samples):
     model = Demucs()
     model.eval()
-    data = torch.rand(batch_size,1,samples,requires_grad=False)
+    data = torch.rand(batch_size, 1, samples, requires_grad=False)
     with torch.no_grad():
         _ = model(data)
-    data = torch.rand(batch_size,2,samples,requires_grad=False)
+    data = torch.rand(batch_size, 2, samples, requires_grad=False)
     with torch.no_grad():
         with pytest.raises(TypeError):
             _ = model(data)


-@pytest.mark.parametrize("dataset,channels,loss",
-    [(pytest.lazy_fixture("vctk_dataset"),1,["mae","mse"])])
-def test_demucs_init(dataset,channels,loss):
+@pytest.mark.parametrize(
+    "dataset,channels,loss",
+    [(pytest.lazy_fixture("vctk_dataset"), 1, ["mae", "mse"])],
+)
+def test_demucs_init(dataset, channels, loss):
     with torch.no_grad():
-        model = Demucs(num_channels=channels,dataset=dataset,loss=loss)
+        model = Demucs(num_channels=channels, dataset=dataset, loss=loss)
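
test_forward above checks an input-channel guard: a mono-configured model accepts (batch, 1, samples) input and raises TypeError for stereo input. A minimal sketch of that kind of guard, assuming the model validates channels in its forward pass (illustrative only, not the Demucs implementation in this repository):

import torch
import torch.nn as nn


class MonoOnlyModel(nn.Module):
    """Illustrative model enforcing the channel contract the test expects."""

    def __init__(self, num_channels: int = 1):
        super().__init__()
        self.num_channels = num_channels
        self.layer = nn.Conv1d(num_channels, num_channels, kernel_size=1)

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        # Mirror the behaviour asserted above: wrong channel count -> TypeError.
        if waveform.size(1) != self.num_channels:
            raise TypeError(
                f"expected {self.num_channels} channel(s), got {waveform.size(1)}"
            )
        return self.layer(waveform)

Under this sketch, MonoOnlyModel()(torch.rand(1, 1, 1000)) succeeds, while torch.rand(1, 2, 1000) triggers the TypeError the test asserts.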

View File

@@ -10,37 +10,35 @@ from enhancer.data.dataset import EnhancerDataset
 @pytest.fixture
 def vctk_dataset():
     root_dir = "tests/data/vctk"
-    files = Files(train_clean="clean_testset_wav",train_noisy="noisy_testset_wav",
-                test_clean="clean_testset_wav", test_noisy="noisy_testset_wav")
-    dataset = EnhancerDataset(name="vctk",root_dir=root_dir,files=files)
+    files = Files(
+        train_clean="clean_testset_wav",
+        train_noisy="noisy_testset_wav",
+        test_clean="clean_testset_wav",
+        test_noisy="noisy_testset_wav",
+    )
+    dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files)
     return dataset


-@pytest.mark.parametrize("batch_size,samples",[(1,1000)])
-def test_forward(batch_size,samples):
+@pytest.mark.parametrize("batch_size,samples", [(1, 1000)])
+def test_forward(batch_size, samples):
     model = WaveUnet()
     model.eval()
-    data = torch.rand(batch_size,1,samples,requires_grad=False)
+    data = torch.rand(batch_size, 1, samples, requires_grad=False)
     with torch.no_grad():
         _ = model(data)
-    data = torch.rand(batch_size,2,samples,requires_grad=False)
+    data = torch.rand(batch_size, 2, samples, requires_grad=False)
     with torch.no_grad():
         with pytest.raises(TypeError):
             _ = model(data)


-@pytest.mark.parametrize("dataset,channels,loss",
-    [(pytest.lazy_fixture("vctk_dataset"),1,["mae","mse"])])
-def test_demucs_init(dataset,channels,loss):
+@pytest.mark.parametrize(
+    "dataset,channels,loss",
+    [(pytest.lazy_fixture("vctk_dataset"), 1, ["mae", "mse"])],
+)
+def test_demucs_init(dataset, channels, loss):
     with torch.no_grad():
-        model = WaveUnet(num_channels=channels,dataset=dataset,loss=loss)
+        model = WaveUnet(num_channels=channels, dataset=dataset, loss=loss)

View File

@@ -4,22 +4,26 @@ import torch
 from enhancer.inference import Inference


-@pytest.mark.parametrize("audio",["tests/data/vctk/clean_testset_wav/p257_166.wav",torch.rand(1,2,48000)])
+@pytest.mark.parametrize(
+    "audio",
+    ["tests/data/vctk/clean_testset_wav/p257_166.wav", torch.rand(1, 2, 48000)],
+)
 def test_read_input(audio):
-    read_audio = Inference.read_input(audio,48000,16000)
-    assert isinstance(read_audio,torch.Tensor)
+    read_audio = Inference.read_input(audio, 48000, 16000)
+    assert isinstance(read_audio, torch.Tensor)
     assert read_audio.shape[0] == 1


 def test_batchify():
-    rand = torch.rand(1,1000)
-    batched_rand = Inference.batchify(rand, window_size = 100, step_size=100)
+    rand = torch.rand(1, 1000)
+    batched_rand = Inference.batchify(rand, window_size=100, step_size=100)
     assert batched_rand.shape[0] == 12


 def test_aggregate():
-    rand = torch.rand(12,1,100)
-    agg_rand = Inference.aggreagate(data=rand,window_size=100,total_frames=1000,step_size=100)
+    rand = torch.rand(12, 1, 100)
+    agg_rand = Inference.aggreagate(
+        data=rand, window_size=100, total_frames=1000, step_size=100
+    )
     assert agg_rand.shape[-1] == 1000
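
The batchify/aggreagate pair tested above implements chunked inference: split a long waveform into fixed windows, run the model per window, then stitch the outputs back to the original length. The expected count of 12 chunks from a 1000-sample signal with window_size=100 and step_size=100 implies some padding around the signal; the sketch below reproduces that arithmetic under an assumed padding of one full window on each side (the actual padding scheme is not visible in this diff):

import torch
import torch.nn.functional as F

signal = torch.rand(1, 1000)
window_size, step_size = 100, 100

# Assumed padding: one window on each side (1000 -> 1200 samples).
padded = F.pad(signal, (window_size, window_size))

# Sliding windows over the time dimension: shape (1, n_windows, window_size).
windows = padded.unfold(1, window_size, step_size)
print(windows.shape[1])  # 12 under this assumption, matching the test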

View File

@@ -7,40 +7,46 @@ from enhancer.utils.io import Audio
 from enhancer.utils.config import Files
 from enhancer.data.fileprocessor import Fileprocessor


 def test_io_channel():
-    input_audio = np.random.rand(2,32000)
-    audio = Audio(mono=True,return_tensor=False)
+    input_audio = np.random.rand(2, 32000)
+    audio = Audio(mono=True, return_tensor=False)
     output_audio = audio(input_audio)
     assert output_audio.shape[0] == 1


 def test_io_resampling():
-    input_audio = np.random.rand(1,32000)
-    resampled_audio = Audio.resample_audio(input_audio,16000,8000)
+    input_audio = np.random.rand(1, 32000)
+    resampled_audio = Audio.resample_audio(input_audio, 16000, 8000)
-    input_audio = torch.rand(1,32000)
-    resampled_audio_pt = Audio.resample_audio(input_audio,16000,8000)
+    input_audio = torch.rand(1, 32000)
+    resampled_audio_pt = Audio.resample_audio(input_audio, 16000, 8000)
     assert resampled_audio.shape[1] == resampled_audio_pt.size(1) == 16000


 def test_fileprocessor_vctk():
-    fp = Fileprocessor.from_name("vctk","tests/data/vctk/clean_testset_wav",
-            "tests/data/vctk/noisy_testset_wav",48000)
+    fp = Fileprocessor.from_name(
+        "vctk",
+        "tests/data/vctk/clean_testset_wav",
+        "tests/data/vctk/noisy_testset_wav",
+        48000,
+    )
     matching_dict = fp.prepare_matching_dict()
-    assert len(matching_dict)==2
+    assert len(matching_dict) == 2


-@pytest.mark.parametrize("dataset_name",["vctk","dns-2020"])
+@pytest.mark.parametrize("dataset_name", ["vctk", "dns-2020"])
 def test_fileprocessor_names(dataset_name):
-    fp = Fileprocessor.from_name(dataset_name,"clean_dir","noisy_dir",16000)
-    assert hasattr(fp.matching_function, '__call__')
+    fp = Fileprocessor.from_name(dataset_name, "clean_dir", "noisy_dir", 16000)
+    assert hasattr(fp.matching_function, "__call__")


 def test_fileprocessor_invaliname():
     with pytest.raises(ValueError):
-        fp = Fileprocessor.from_name("undefined","clean_dir","noisy_dir",16000).prepare_matching_dict()
+        fp = Fileprocessor.from_name(
+            "undefined", "clean_dir", "noisy_dir", 16000
+        ).prepare_matching_dict()
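
For reference, the 16000-sample expectation in test_io_resampling above is just the resampling ratio: 32000 samples taken from 16 kHz down to 8 kHz halves the frame count (32000 * 8000 / 16000 = 16000). One way to reproduce the same arithmetic with torchaudio, shown purely for illustration; the repository's Audio.resample_audio may use a different resampler:

import torch
import torchaudio.functional as AF

x = torch.rand(1, 32000)

# Downsampling from 16 kHz to 8 kHz halves the number of samples.
y = AF.resample(x, orig_freq=16000, new_freq=8000)
print(y.shape)  # torch.Size([1, 16000]), matching the test's assertion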