diff --git a/tests/test_inference.py b/tests/test_inference.py index f727938..8bdcf4e 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -27,3 +27,12 @@ def test_aggregate(): data=rand, window_size=100, total_frames=1000, step_size=100 ) assert agg_rand.shape[-1] == 1000 + + +def test_pretrained(): + from mayavoz.models import Mayamodel + + model = Mayamodel.from_pretrained( + "shahules786/mayavoz-waveunet-valentini-28spk" + ) + _ = model.enhance("tests/data/vctk/clean_testset_wav/p257_166.wav")