diff --git a/tests/transforms_test.py b/tests/transforms_test.py new file mode 100644 index 0000000..3053b09 --- /dev/null +++ b/tests/transforms_test.py @@ -0,0 +1,14 @@ +import torch + +from enhancer.utils.transforms import ConviSTFT, ConvSTFT + + +def test_stft_istft(): + sample_input = torch.rand(1, 1, 16000) + stft = ConvSTFT(window_len=400, hop_size=100, nfft=512) + istft = ConviSTFT(window_len=400, hop_size=100, nfft=512) + + with torch.no_grad(): + spectrogram = stft(sample_input) + waveform = istft(spectrogram) + assert sample_input.shape == waveform.shape