fourier transforms using cnn
This commit is contained in:
parent
23da02d47d
commit
085a85d9ae
|
|
@ -2,6 +2,7 @@ from typing import Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
from scipy.signal import get_window
|
from scipy.signal import get_window
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
|
|
@ -42,7 +43,22 @@ class ConvSTFT(ConvFFT):
|
||||||
self.register_buffer("weight", self.init_kernel)
|
self.register_buffer("weight", self.init_kernel)
|
||||||
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
|
|
||||||
|
if input.dim() < 2:
|
||||||
|
raise ValueError(
|
||||||
|
f"Expected signal with shape 2 or 3 got {input.dim()}"
|
||||||
|
)
|
||||||
|
elif input.dim() == 2:
|
||||||
|
input = input.unsqueeze(1)
|
||||||
|
else:
|
||||||
pass
|
pass
|
||||||
|
input = F.pad(
|
||||||
|
input,
|
||||||
|
(self.window_len - self.hop_size, self.window_len - self.hop_size),
|
||||||
|
)
|
||||||
|
output = F.conv1d(input, self.weight, stride=self.hop_size)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
class ConviSTFT(ConvFFT):
|
class ConviSTFT(ConvFFT):
|
||||||
|
|
@ -58,6 +74,11 @@ class ConviSTFT(ConvFFT):
|
||||||
)
|
)
|
||||||
self.hop_size = hop_size if hop_size else window_len // 2
|
self.hop_size = hop_size if hop_size else window_len // 2
|
||||||
self.register_buffer("weight", self.init_kernel)
|
self.register_buffer("weight", self.init_kernel)
|
||||||
|
self.register_buffer("enframe", np.eye(window_len).unsqueeze(1))
|
||||||
|
|
||||||
def forward(self, input):
|
def forward(self, input, phase=None):
|
||||||
pass
|
|
||||||
|
if phase is not None:
|
||||||
|
real = input * torch.cos(phase)
|
||||||
|
imag = input * torch.sin(phase)
|
||||||
|
input = torch.cat([real, imag], 1)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue