from typing import Optional

import numpy as np
import torch
import torch.nn.functional as F
from scipy.signal import get_window
from torch import nn


class ConvFFT(nn.Module):
    """Base module that builds a windowed real-DFT basis usable as a conv kernel.

    Args:
        window_len: Length of the analysis window in samples.
        nfft: FFT size. When omitted (or falsy), the next power of two that is
            greater than or equal to ``window_len`` is used.
        window: Window name forwarded to ``scipy.signal.get_window``.

    Raises:
        ValueError: If the resolved ``nfft`` is smaller than ``window_len``.
    """

    def __init__(
        self,
        window_len: int,
        nfft: Optional[int] = None,
        window: str = "hamming",
    ):
        # nn.Module must be initialised before subclasses call register_buffer.
        super().__init__()
        self.window_len = window_len
        # Builtin int: the np.int alias was removed from NumPy.
        self.nfft = nfft if nfft else int(2 ** np.ceil(np.log2(window_len)))
        if self.nfft < window_len:
            # The basis has `nfft` taps per row; a longer window cannot be
            # broadcast against it in `init_kernel`.
            raise ValueError(
                f"nfft ({self.nfft}) must be >= window_len ({window_len})"
            )
        self.window = get_window(window, window_len)

    @property
    def init_kernel(self) -> torch.Tensor:
        """Return the windowed Fourier kernel as a conv1d-compatible weight.

        The result is a float32 tensor of shape
        ``(2 * (nfft // 2 + 1), 1, window_len)``: the first half of the rows
        holds the real parts of the one-sided DFT basis, the second half the
        imaginary parts, each multiplied by the analysis window.
        """
        # rfft of the identity yields the one-sided DFT matrix, shape
        # (nfft, nfft // 2 + 1); keep only the first `window_len` time taps.
        fourier_basis = np.fft.rfft(np.eye(self.nfft))[: self.window_len]
        real, imag = np.real(fourier_basis), np.imag(fourier_basis)
        # Stack real/imag bins, then transpose to (bins, window_len).
        kernel = np.concatenate([real, imag], 1).T
        kernel = kernel * self.window
        # conv1d weights are (out_channels, in_channels, kernel_size); add the
        # singleton in_channels axis and match torch's default float32 dtype.
        kernel = kernel[:, None, :].astype(np.float32)
        return torch.from_numpy(kernel)


class ConvSTFT(ConvFFT):
    """Short-time Fourier transform implemented as a strided 1-D convolution.

    Args:
        window_len: Length of the analysis window in samples.
        hop_size: Stride between frames. Defaults to ``window_len // 2``.
        nfft: FFT size (see ``ConvFFT``).
        window: Window name forwarded to ``scipy.signal.get_window``.
    """

    def __init__(
        self,
        window_len: int,
        hop_size: Optional[int] = None,
        nfft: Optional[int] = None,
        window: str = "hamming",
    ):
        super().__init__(window_len=window_len, nfft=nfft, window=window)
        self.hop_size = hop_size if hop_size else window_len // 2
        # Registered as a buffer so it follows the module across devices and
        # dtypes without being treated as a trainable parameter.
        self.register_buffer("weight", self.init_kernel)

    def forward(self, input):
        """Apply the analysis kernel to a signal.

        Args:
            input: 2-D tensor (a channel axis is inserted at dim 1) or 3-D
                tensor passed to ``conv1d`` as is.

        Returns:
            Tensor whose dim 1 holds the stacked real and imaginary bins
            (``2 * (nfft // 2 + 1)`` channels) and whose last dim indexes
            frames.

        Raises:
            ValueError: If ``input`` is not 2-D or 3-D.
        """
        if input.dim() not in (2, 3):
            raise ValueError(
                f"Expected signal with shape 2 or 3 got {input.dim()}"
            )
        if input.dim() == 2:
            input = input.unsqueeze(1)

        # Zero-pad both ends by (window_len - hop_size) before framing.
        pad = self.window_len - self.hop_size
        input = F.pad(input, (pad, pad))
        output = F.conv1d(input, self.weight, stride=self.hop_size)
        return output


class ConviSTFT(ConvFFT):
    """Inverse-STFT counterpart of ``ConvSTFT``.

    Args:
        window_len: Length of the window in samples.
        hop_size: Stride between frames. Defaults to ``window_len // 2``.
        nfft: FFT size (see ``ConvFFT``).
        window: Window name forwarded to ``scipy.signal.get_window``.
    """

    def __init__(
        self,
        window_len: int,
        hop_size: Optional[int] = None,
        nfft: Optional[int] = None,
        window: str = "hamming",
    ):
        super().__init__(window_len=window_len, nfft=nfft, window=window)
        self.hop_size = hop_size if hop_size else window_len // 2
        self.register_buffer("weight", self.init_kernel)
        # Identity kernel of shape (window_len, 1, window_len); built with
        # torch because NumPy arrays have no `unsqueeze`.
        self.register_buffer("enframe", torch.eye(window_len).unsqueeze(1))

    def forward(self, input, phase=None):
        """Prepare a spectrum for synthesis.

        When ``phase`` is given, ``input`` is combined with it as
        ``input * cos(phase)`` / ``input * sin(phase)`` and the two parts are
        concatenated along dim 1.
        """
        if phase is not None:
            real = input * torch.cos(phase)
            imag = input * torch.sin(phase)
            input = torch.cat([real, imag], 1)
        # NOTE(review): the visible source ends here with no synthesis step or
        # return statement; the remainder of this method appears to be
        # truncated. Confirm against the full file.