"""Fourier-basis kernel modules: ``ConvFFT``, ``ConvSTFT`` and ``ConviSTFT``."""

from typing import Optional

import numpy as np
import torch
from scipy.signal import get_window
from torch import nn


class ConvFFT(nn.Module):
    """Base module that builds a windowed Fourier-basis kernel.

    Args:
        window_len: Number of samples in the analysis window.
        nfft: FFT size. When omitted (or falsy), the smallest power of two
            that is greater than or equal to ``window_len`` is used.
        window: Window name forwarded to ``scipy.signal.get_window``.
    """

    def __init__(
        self,
        window_len: int,
        nfft: Optional[int] = None,
        window: str = "hamming",
    ):
        # nn.Module must be initialised before buffers can be registered
        # by subclasses.
        super().__init__()
        self.window_len = window_len
        # ``np.int`` was removed from NumPy; the builtin ``int`` is equivalent.
        self.nfft = nfft if nfft else int(2 ** np.ceil(np.log2(window_len)))
        self.window = get_window(window, window_len)

    @property
    def init_kernel(self) -> torch.Tensor:
        """Return the windowed Fourier basis as a tensor.

        The real and imaginary parts of ``rfft(eye(nfft))`` (truncated to the
        first ``window_len`` rows) are concatenated and transposed, giving a
        ``(2 * (nfft // 2 + 1), window_len)`` array whose rows are each
        multiplied element-wise by the window.
        """
        fourier_basis = np.fft.rfft(np.eye(self.nfft))[: self.window_len]
        real, imag = np.real(fourier_basis), np.imag(fourier_basis)
        kernel = np.concatenate([real, imag], 1).T
        kernel *= self.window
        return torch.from_numpy(kernel)


class ConvSTFT(ConvFFT):
    """Module holding the Fourier-basis kernel as the ``weight`` buffer.

    Args:
        window_len: Number of samples in the analysis window.
        hop_size: Hop between frames; defaults to ``window_len // 2`` when
            omitted (or falsy).
        nfft: FFT size, see :class:`ConvFFT`.
        window: Window name forwarded to ``scipy.signal.get_window``.
    """

    def __init__(
        self,
        window_len: int,
        hop_size: Optional[int] = None,
        nfft: Optional[int] = None,
        window: str = "hamming",
    ):
        super().__init__(window_len=window_len, nfft=nfft, window=window)
        self.hop_size = hop_size if hop_size else window_len // 2
        self.register_buffer("weight", self.init_kernel)

    def forward(self, input):
        """Placeholder: not implemented yet, currently returns ``None``."""
        pass


class ConviSTFT(ConvFFT):
    """Counterpart of :class:`ConvSTFT`; currently builds the same buffer.

    Args:
        window_len: Number of samples in the window.
        hop_size: Hop between frames; defaults to ``window_len // 2`` when
            omitted (or falsy).
        nfft: FFT size, see :class:`ConvFFT`.
        window: Window name forwarded to ``scipy.signal.get_window``.
    """

    def __init__(
        self,
        window_len: int,
        hop_size: Optional[int] = None,
        nfft: Optional[int] = None,
        window: str = "hamming",
    ):
        super().__init__(window_len=window_len, nfft=nfft, window=window)
        self.hop_size = hop_size if hop_size else window_len // 2
        self.register_buffer("weight", self.init_kernel)

    def forward(self, input):
        """Placeholder: not implemented yet, currently returns ``None``."""
        pass