diff --git a/enhancer/utils/transforms.py b/enhancer/utils/transforms.py index bf481e8..bbbae90 100644 --- a/enhancer/utils/transforms.py +++ b/enhancer/utils/transforms.py @@ -2,6 +2,7 @@ from typing import Optional import numpy as np import torch +import torch.nn.functional as F from scipy.signal import get_window from torch import nn @@ -42,7 +43,22 @@ class ConvSTFT(ConvFFT): self.register_buffer("weight", self.init_kernel) def forward(self, input): - pass + + if input.dim() < 2: + raise ValueError( + f"Expected signal with shape 2 or 3 got {input.dim()}" + ) + elif input.dim() == 2: + input = input.unsqueeze(1) + else: + pass + input = F.pad( + input, + (self.window_len - self.hop_size, self.window_len - self.hop_size), + ) + output = F.conv1d(input, self.weight, stride=self.hop_size) + + return output class ConviSTFT(ConvFFT): @@ -58,6 +74,11 @@ class ConviSTFT(ConvFFT): ) self.hop_size = hop_size if hop_size else window_len // 2 self.register_buffer("weight", self.init_kernel) + self.register_buffer("enframe", np.eye(window_len).unsqueeze(1)) - def forward(self, input): - pass + def forward(self, input, phase=None): + + if phase is not None: + real = input * torch.cos(phase) + imag = input * torch.sin(phase) + input = torch.cat([real, imag], 1)