test transforms

This commit is contained in:
shahules786 2022-10-29 11:35:35 +05:30
parent c18a85b5c8
commit cf1e5c07a9
1 changed files with 14 additions and 0 deletions

14
tests/transforms_test.py Normal file
View File

@ -0,0 +1,14 @@
import torch
from enhancer.utils.transforms import ConviSTFT, ConvSTFT
def test_stft_istft():
sample_input = torch.rand(1, 1, 16000)
stft = ConvSTFT(window_len=400, hop_size=100, nfft=512)
istft = ConviSTFT(window_len=400, hop_size=100, nfft=512)
with torch.no_grad():
spectrogram = stft(sample_input)
waveform = istft(spectrogram)
assert sample_input.shape == waveform.shape