diff --git a/enhancer/utils/transforms.py b/enhancer/utils/transforms.py index bbbae90..acfc323 100644 --- a/enhancer/utils/transforms.py +++ b/enhancer/utils/transforms.py @@ -14,9 +14,12 @@ class ConvFFT(nn.Module): nfft: Optional[int] = None, window: str = "hamming", ): + super().__init__() self.window_len = window_len self.nfft = nfft if nfft else np.int(2 ** np.ceil(np.log2(window_len))) - self.window = get_window(window, window_len) + self.window = torch.from_numpy( + get_window(window, window_len).astype("float32") + ) @property def init_kernel(self): @@ -24,8 +27,9 @@ class ConvFFT(nn.Module): fourier_basis = np.fft.rfft(np.eye(self.nfft))[: self.window_len] real, imag = np.real(fourier_basis), np.imag(fourier_basis) kernel = np.concatenate([real, imag], 1).T + kernel = torch.from_numpy(kernel.astype("float32")).unsqueeze(1) kernel *= self.window - return torch.from_numpy(kernel) + return kernel class ConvSTFT(ConvFFT): @@ -36,9 +40,7 @@ class ConvSTFT(ConvFFT): nfft: Optional[int] = None, window: str = "hamming", ): - super(self, ConvSTFT).__init__( - window_len=window_len, nfft=nfft, window=window - ) + super().__init__(window_len=window_len, nfft=nfft, window=window) self.hop_size = hop_size if hop_size else window_len // 2 self.register_buffer("weight", self.init_kernel) @@ -69,12 +71,10 @@ class ConviSTFT(ConvFFT): nfft: Optional[int] = None, window: str = "hamming", ): - super(self, ConvSTFT).__init__( - window_len=window_len, nfft=nfft, window=window - ) + super().__init__(window_len=window_len, nfft=nfft, window=window) self.hop_size = hop_size if hop_size else window_len // 2 self.register_buffer("weight", self.init_kernel) - self.register_buffer("enframe", np.eye(window_len).unsqueeze(1)) + self.register_buffer("enframe", torch.eye(window_len).unsqueeze(1)) def forward(self, input, phase=None): @@ -82,3 +82,10 @@ class ConviSTFT(ConvFFT): real = input * torch.cos(phase) imag = input * torch.sin(phase) input = torch.cat([real, imag], 1) + out = F.conv_transpose1d(input, self.weight, stride=self.hop_size) + coeff = self.window.unsqueeze(1).repeat(1, 1, input.size(-1)) ** 2 + coeff = F.conv_transpose1d(coeff, self.enframe, stride=self.hop_size) + out = out / coeff + pad = self.window_len - self.hop_size + out = out[..., pad:-pad] + return out