minor bug fixes
This commit is contained in:
parent
e9ea0d1695
commit
34755f33aa
|
|
@ -26,6 +26,8 @@ class Inference:
|
||||||
raise ValueError(f"Input file {audio} does not exist")
|
raise ValueError(f"Input file {audio} does not exist")
|
||||||
else:
|
else:
|
||||||
audio,sr = load_audio(audio,sr=sr,)
|
audio,sr = load_audio(audio,sr=sr,)
|
||||||
|
if len(audio.shape) == 1:
|
||||||
|
audio = audio.reshape(1,-1)
|
||||||
else:
|
else:
|
||||||
assert audio.shape[0] == 1, "Enhance inference only supports single waveform"
|
assert audio.shape[0] == 1, "Enhance inference only supports single waveform"
|
||||||
|
|
||||||
|
|
@ -40,11 +42,10 @@ class Inference:
|
||||||
def batchify(waveform: torch.Tensor, window_size:int, step_size:Optional[int]=None):
|
def batchify(waveform: torch.Tensor, window_size:int, step_size:Optional[int]=None):
|
||||||
"""
|
"""
|
||||||
break input waveform into samples with duration specified.
|
break input waveform into samples with duration specified.
|
||||||
Wrap into tensors of specified batch size
|
|
||||||
"""
|
"""
|
||||||
assert waveform.ndim == 2, f"Expcted input waveform with 2 dimensions (channels,samples), got {waveform.ndim}"
|
assert waveform.ndim == 2, f"Expcted input waveform with 2 dimensions (channels,samples), got {waveform.ndim}"
|
||||||
_,num_samples = waveform.shape
|
_,num_samples = waveform.shape
|
||||||
waveform = waveform.unsqueeze(0)
|
waveform = waveform.unsqueeze(-1)
|
||||||
step_size = window_size//2 if step_size is None else step_size
|
step_size = window_size//2 if step_size is None else step_size
|
||||||
|
|
||||||
if num_samples >= window_size:
|
if num_samples >= window_size:
|
||||||
|
|
@ -55,24 +56,25 @@ class Inference:
|
||||||
|
|
||||||
return waveform_batch
|
return waveform_batch
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
def aggreagate(self,data:torch.Tensor,window_size:int, step_size:Optional[int]=None):
|
def aggreagate(data:torch.Tensor,window_size:int,total_frames:int,step_size:Optional[int]=None,
|
||||||
|
window="hanning",):
|
||||||
"""
|
"""
|
||||||
takes input as tensor outputs aggregated waveform
|
takes input as tensor outputs aggregated waveform
|
||||||
"""
|
"""
|
||||||
batch_size,n_channels,num_frames = data.shape
|
num_chunks,n_channels,num_frames = data.shape
|
||||||
window = get_window(window=window,Nx=data.shape[-1])
|
window = get_window(window=window,Nx=data.shape[-1])
|
||||||
window = torch.from_numpy(window).to(data.device)
|
window = torch.from_numpy(window).to(data.device)
|
||||||
data *= window
|
data *= window
|
||||||
|
|
||||||
data = data.permute(1,2,0)
|
data = data.permute(1,2,0)
|
||||||
data = F.fold(data,
|
data = F.fold(data,
|
||||||
(num_frames,1),
|
(total_frames,1),
|
||||||
kernel_size=(window_size,1),
|
kernel_size=(window_size,1),
|
||||||
stride=(step_size,1),
|
stride=(step_size,1),
|
||||||
padding=(window_size,0))
|
padding=(window_size,0)).squeeze(-1)
|
||||||
|
|
||||||
return data
|
return data.reshape(1,n_channels,-1)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def write_output(waveform:torch.Tensor,filename:Union[str,Path],sr:int):
|
def write_output(waveform:torch.Tensor,filename:Union[str,Path],sr:int):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue