19 lines
		
	
	
		
			532 B
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			19 lines
		
	
	
		
			532 B
		
	
	
	
		
			Python
		
	
	
	
| import torch
 | |
| 
 | |
| from enhancer.utils.transforms import ConviSTFT, ConvSTFT
 | |
| 
 | |
| 
 | |
def test_stft_istft():
    """Round-trip a random waveform through ConvSTFT followed by ConviSTFT.

    Checks that:
      * the inverse transform returns a tensor with the same shape as the
        input waveform, and
      * more than half of the reconstructed samples match the original within
        ``torch.isclose``'s default tolerances.

    NOTE(review): the "more than half" threshold is deliberately loose —
    presumably to tolerate imperfect reconstruction (e.g. near the signal
    edges); TODO confirm the expected reconstruction quality.
    """
    # A local generator with a fixed seed makes the input, and therefore the
    # tolerance-based match count asserted below, reproducible from run to run
    # without mutating torch's global RNG state.
    generator = torch.Generator().manual_seed(0)
    sample_input = torch.rand(1, 1, 16000, generator=generator)

    stft = ConvSTFT(window_len=400, hop_size=100, nfft=512)
    istft = ConviSTFT(window_len=400, hop_size=100, nfft=512)

    # Pure forward pass; no gradients are needed for a reconstruction check.
    with torch.no_grad():
        spectrogram = stft(sample_input)
        waveform = istft(spectrogram)

    assert sample_input.shape == waveform.shape
    assert (
        torch.isclose(waveform, sample_input).sum().item()
        > sample_input.shape[-1] // 2
    )
 |