From 085a85d9ae84e6e5ae2d7d751899e26c2d08d3f3 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Thu, 27 Oct 2022 11:32:50 +0530 Subject: [PATCH] fourier transforms using cnn --- enhancer/utils/transforms.py | 27 ++++++++++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/enhancer/utils/transforms.py b/enhancer/utils/transforms.py index bf481e8..bbbae90 100644 --- a/enhancer/utils/transforms.py +++ b/enhancer/utils/transforms.py @@ -2,6 +2,7 @@ from typing import Optional import numpy as np import torch +import torch.nn.functional as F from scipy.signal import get_window from torch import nn @@ -42,7 +43,22 @@ class ConvSTFT(ConvFFT): self.register_buffer("weight", self.init_kernel) def forward(self, input): - pass + + if input.dim() < 2: + raise ValueError( + f"Expected signal with shape 2 or 3 got {input.dim()}" + ) + elif input.dim() == 2: + input = input.unsqueeze(1) + else: + pass + input = F.pad( + input, + (self.window_len - self.hop_size, self.window_len - self.hop_size), + ) + output = F.conv1d(input, self.weight, stride=self.hop_size) + + return output class ConviSTFT(ConvFFT): @@ -58,6 +74,11 @@ class ConviSTFT(ConvFFT): ) self.hop_size = hop_size if hop_size else window_len // 2 self.register_buffer("weight", self.init_kernel) + self.register_buffer("enframe", np.eye(window_len).unsqueeze(1)) - def forward(self, input): - pass + def forward(self, input, phase=None): + + if phase is not None: + real = input * torch.cos(phase) + imag = input * torch.sin(phase) + input = torch.cat([real, imag], 1)