From 34755f33aa8fa79cd6c3b6242232a58b12f3d2e7 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Mon, 26 Sep 2022 17:09:29 +0530 Subject: [PATCH] minor bug fixes --- enhancer/inference.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/enhancer/inference.py b/enhancer/inference.py index 414f6ae..404fe95 100644 --- a/enhancer/inference.py +++ b/enhancer/inference.py @@ -26,6 +26,8 @@ class Inference: raise ValueError(f"Input file {audio} does not exist") else: audio,sr = load_audio(audio,sr=sr,) + if len(audio.shape) == 1: + audio = audio.reshape(1,-1) else: assert audio.shape[0] == 1, "Enhance inference only supports single waveform" @@ -40,11 +42,10 @@ class Inference: def batchify(waveform: torch.Tensor, window_size:int, step_size:Optional[int]=None): """ break input waveform into samples with duration specified. - Wrap into tensors of specified batch size """ assert waveform.ndim == 2, f"Expcted input waveform with 2 dimensions (channels,samples), got {waveform.ndim}" _,num_samples = waveform.shape - waveform = waveform.unsqueeze(0) + waveform = waveform.unsqueeze(-1) step_size = window_size//2 if step_size is None else step_size if num_samples >= window_size: @@ -55,24 +56,25 @@ class Inference: return waveform_batch - - def aggreagate(self,data:torch.Tensor,window_size:int, step_size:Optional[int]=None): + @staticmethod + def aggreagate(data:torch.Tensor,window_size:int,total_frames:int,step_size:Optional[int]=None, + window="hanning",): """ takes input as tensor outputs aggregated waveform """ - batch_size,n_channels,num_frames = data.shape + num_chunks,n_channels,num_frames = data.shape window = get_window(window=window,Nx=data.shape[-1]) window = torch.from_numpy(window).to(data.device) data *= window data = data.permute(1,2,0) data = F.fold(data, - (num_frames,1), + (total_frames,1), kernel_size=(window_size,1), stride=(step_size,1), - padding=(window_size,0)) + padding=(window_size,0)).squeeze(-1) - return data + return data.reshape(1,n_channels,-1) @staticmethod def write_output(waveform:torch.Tensor,filename:Union[str,Path],sr:int):