fourier transforms using cnn
This commit is contained in:
parent
23da02d47d
commit
085a85d9ae
|
|
@ -2,6 +2,7 @@ from typing import Optional
|
|||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from scipy.signal import get_window
|
||||
from torch import nn
|
||||
|
||||
|
|
@ -42,7 +43,22 @@ class ConvSTFT(ConvFFT):
|
|||
self.register_buffer("weight", self.init_kernel)
|
||||
|
||||
def forward(self, input):
|
||||
|
||||
if input.dim() < 2:
|
||||
raise ValueError(
|
||||
f"Expected signal with shape 2 or 3 got {input.dim()}"
|
||||
)
|
||||
elif input.dim() == 2:
|
||||
input = input.unsqueeze(1)
|
||||
else:
|
||||
pass
|
||||
input = F.pad(
|
||||
input,
|
||||
(self.window_len - self.hop_size, self.window_len - self.hop_size),
|
||||
)
|
||||
output = F.conv1d(input, self.weight, stride=self.hop_size)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class ConviSTFT(ConvFFT):
|
||||
|
|
@ -58,6 +74,11 @@ class ConviSTFT(ConvFFT):
|
|||
)
|
||||
self.hop_size = hop_size if hop_size else window_len // 2
|
||||
self.register_buffer("weight", self.init_kernel)
|
||||
self.register_buffer("enframe", np.eye(window_len).unsqueeze(1))
|
||||
|
||||
def forward(self, input):
|
||||
pass
|
||||
def forward(self, input, phase=None):
|
||||
|
||||
if phase is not None:
|
||||
real = input * torch.cos(phase)
|
||||
imag = input * torch.sin(phase)
|
||||
input = torch.cat([real, imag], 1)
|
||||
|
|
|
|||
Loading…
Reference in New Issue