transforms test

This commit is contained in:
shahules786 2022-11-07 10:25:27 +05:30
parent d7f3847917
commit c1d5e56ec0
1 changed files with 7 additions and 6 deletions

View File

@ -18,15 +18,16 @@ class ConvFFT(nn.Module):
self.window_len = window_len self.window_len = window_len
self.nfft = nfft if nfft else np.int(2 ** np.ceil(np.log2(window_len))) self.nfft = nfft if nfft else np.int(2 ** np.ceil(np.log2(window_len)))
self.window = torch.from_numpy( self.window = torch.from_numpy(
get_window(window, window_len).astype("float32") get_window(window, window_len, fftbins=True).astype("float32")
) )
@property def init_kernel(self, inverse=False):
def init_kernel(self):
fourier_basis = np.fft.rfft(np.eye(self.nfft))[: self.window_len] fourier_basis = np.fft.rfft(np.eye(self.nfft))[: self.window_len]
real, imag = np.real(fourier_basis), np.imag(fourier_basis) real, imag = np.real(fourier_basis), np.imag(fourier_basis)
kernel = np.concatenate([real, imag], 1).T kernel = np.concatenate([real, imag], 1).T
if inverse:
kernel = np.linalg.pinv(kernel).T
kernel = torch.from_numpy(kernel.astype("float32")).unsqueeze(1) kernel = torch.from_numpy(kernel.astype("float32")).unsqueeze(1)
kernel *= self.window kernel *= self.window
return kernel return kernel
@ -42,7 +43,7 @@ class ConvSTFT(ConvFFT):
): ):
super().__init__(window_len=window_len, nfft=nfft, window=window) super().__init__(window_len=window_len, nfft=nfft, window=window)
self.hop_size = hop_size if hop_size else window_len // 2 self.hop_size = hop_size if hop_size else window_len // 2
self.register_buffer("weight", self.init_kernel) self.register_buffer("weight", self.init_kernel())
def forward(self, input): def forward(self, input):
@ -73,7 +74,7 @@ class ConviSTFT(ConvFFT):
): ):
super().__init__(window_len=window_len, nfft=nfft, window=window) super().__init__(window_len=window_len, nfft=nfft, window=window)
self.hop_size = hop_size if hop_size else window_len // 2 self.hop_size = hop_size if hop_size else window_len // 2
self.register_buffer("weight", self.init_kernel) self.register_buffer("weight", self.init_kernel(True))
self.register_buffer("enframe", torch.eye(window_len).unsqueeze(1)) self.register_buffer("enframe", torch.eye(window_len).unsqueeze(1))
def forward(self, input, phase=None): def forward(self, input, phase=None):
@ -85,7 +86,7 @@ class ConviSTFT(ConvFFT):
out = F.conv_transpose1d(input, self.weight, stride=self.hop_size) out = F.conv_transpose1d(input, self.weight, stride=self.hop_size)
coeff = self.window.unsqueeze(1).repeat(1, 1, input.size(-1)) ** 2 coeff = self.window.unsqueeze(1).repeat(1, 1, input.size(-1)) ** 2
coeff = F.conv_transpose1d(coeff, self.enframe, stride=self.hop_size) coeff = F.conv_transpose1d(coeff, self.enframe, stride=self.hop_size)
out = out / coeff out = out / (coeff + 1e-8)
pad = self.window_len - self.hop_size pad = self.window_len - self.hop_size
out = out[..., pad:-pad] out = out[..., pad:-pad]
return out return out