fourier transforms using cnn

This commit is contained in:
shahules786 2022-10-27 11:32:50 +05:30
parent 23da02d47d
commit 085a85d9ae
1 changed files with 24 additions and 3 deletions

View File

@ -2,6 +2,7 @@ from typing import Optional
import numpy as np
import torch
import torch.nn.functional as F
from scipy.signal import get_window
from torch import nn
@ -42,7 +43,22 @@ class ConvSTFT(ConvFFT):
self.register_buffer("weight", self.init_kernel)
def forward(self, input):
pass
if input.dim() < 2:
raise ValueError(
f"Expected signal with shape 2 or 3 got {input.dim()}"
)
elif input.dim() == 2:
input = input.unsqueeze(1)
else:
pass
input = F.pad(
input,
(self.window_len - self.hop_size, self.window_len - self.hop_size),
)
output = F.conv1d(input, self.weight, stride=self.hop_size)
return output
class ConviSTFT(ConvFFT):
@ -58,6 +74,11 @@ class ConviSTFT(ConvFFT):
)
self.hop_size = hop_size if hop_size else window_len // 2
self.register_buffer("weight", self.init_kernel)
self.register_buffer("enframe", np.eye(window_len).unsqueeze(1))
def forward(self, input):
pass
def forward(self, input, phase=None):
if phase is not None:
real = input * torch.cos(phase)
imag = input * torch.sin(phase)
input = torch.cat([real, imag], 1)