diff --git a/tests/transforms_test.py b/tests/transforms_test.py index 3053b09..89425ad 100644 --- a/tests/transforms_test.py +++ b/tests/transforms_test.py @@ -12,3 +12,7 @@ def test_stft_istft(): spectrogram = stft(sample_input) waveform = istft(spectrogram) assert sample_input.shape == waveform.shape + assert ( + torch.isclose(waveform, sample_input).sum().item() + > sample_input.shape[-1] // 2 + )