minor bug fixes

shahules786 2022-09-26 17:09:29 +05:30
parent e9ea0d1695
commit 34755f33aa
1 changed file with 10 additions and 8 deletions

@@ -26,6 +26,8 @@ class Inference:
                 raise ValueError(f"Input file {audio} does not exist")
             else:
                 audio,sr = load_audio(audio,sr=sr,)
+                if len(audio.shape) == 1:
+                    audio = audio.reshape(1,-1)
         else:
             assert audio.shape[0] == 1, "Enhance inference only supports single waveform"
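
The added branch promotes a 1-D mono signal to the (channels, samples) layout that the single-channel assert expects. A minimal sketch of the same idea, using a random array as a stand-in for the output of load_audio (assumed here to return a 1-D signal for mono files):

import numpy as np

audio = np.random.randn(16000)       # stand-in for a mono load_audio(...) result
if len(audio.shape) == 1:            # 1-D -> add an explicit channel axis
    audio = audio.reshape(1, -1)     # shape becomes (1, num_samples)
assert audio.shape[0] == 1           # single-waveform check now holds
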
@@ -40,11 +42,10 @@ class Inference:
     def batchify(waveform: torch.Tensor, window_size:int, step_size:Optional[int]=None):
         """
         break input waveform into samples with duration specified.
-        Wrap into tensors of specified batch size
         """
         assert waveform.ndim == 2, f"Expcted input waveform with 2 dimensions (channels,samples), got {waveform.ndim}"
         _,num_samples = waveform.shape
-        waveform = waveform.unsqueeze(0)
+        waveform = waveform.unsqueeze(-1)
         step_size = window_size//2 if step_size is None else step_size

         if num_samples >= window_size:
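
For context, batchify slices a (channels, samples) waveform into overlapping windows, with the hop defaulting to half the window size. A sketch of that chunking using Tensor.unfold; the helper name and exact output shape are illustrative, and the real method handles padding and batching itself:

from typing import Optional
import torch

def make_chunks(waveform: torch.Tensor, window_size: int,
                step_size: Optional[int] = None) -> torch.Tensor:
    """Slice (channels, samples) into overlapping (num_chunks, channels, window_size)."""
    assert waveform.ndim == 2, "expected (channels, samples)"
    step_size = window_size // 2 if step_size is None else step_size
    chunks = waveform.unfold(-1, window_size, step_size)  # (channels, num_chunks, window_size)
    return chunks.permute(1, 0, 2)

chunks = make_chunks(torch.randn(1, 16000), window_size=4000)  # -> (7, 1, 4000)
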
@@ -55,24 +56,25 @@ class Inference:
         return waveform_batch

-    def aggreagate(self,data:torch.Tensor,window_size:int, step_size:Optional[int]=None):
+    @staticmethod
+    def aggreagate(data:torch.Tensor,window_size:int,total_frames:int,step_size:Optional[int]=None,
+                   window="hanning",):
         """
         takes input as tensor outputs aggregated waveform
         """
-        batch_size,n_channels,num_frames = data.shape
+        num_chunks,n_channels,num_frames = data.shape
         window = get_window(window=window,Nx=data.shape[-1])
         window = torch.from_numpy(window).to(data.device)
         data *= window
         data = data.permute(1,2,0)
         data = F.fold(data,
-                      (num_frames,1),
+                      (total_frames,1),
                       kernel_size=(window_size,1),
                       stride=(step_size,1),
-                      padding=(window_size,0))
-        return data
+                      padding=(window_size,0)).squeeze(-1)
+        return data.reshape(1,n_channels,-1)

     @staticmethod
     def write_output(waveform:torch.Tensor,filename:Union[str,Path],sr:int):
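
aggreagate undoes the chunking by overlap-add: each chunk is tapered with a window and the chunks are summed back at their original offsets, which is what F.fold does given the kernel/stride geometry above. A hedged sketch of that reconstruction; torch.hann_window stands in for scipy's get_window, and no padding or overlap normalisation is applied, so this illustrates the mechanism rather than the library's exact routine:

import torch
import torch.nn.functional as F

def overlap_add(chunks: torch.Tensor, window_size: int, step_size: int,
                total_frames: int) -> torch.Tensor:
    """Reassemble (num_chunks, n_channels, window_size) chunks into (1, n_channels, total_frames)."""
    num_chunks, n_channels, _ = chunks.shape
    window = torch.hann_window(window_size, periodic=False)  # taper to smooth chunk boundaries
    blocks = (chunks * window).permute(1, 2, 0)              # (n_channels, window_size, num_chunks)
    out = F.fold(blocks,                                     # sums overlapping blocks at their offsets
                 output_size=(total_frames, 1),
                 kernel_size=(window_size, 1),
                 stride=(step_size, 1))                      # (n_channels, 1, total_frames, 1)
    return out.squeeze(-1).reshape(1, n_channels, -1)

# total_frames must match the chunk layout: window_size + (num_chunks - 1) * step_size
wave = overlap_add(torch.randn(4, 1, 8), window_size=8, step_size=4, total_frames=8 + 3 * 4)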