prepare output type/sr
This commit is contained in:
parent
07c525ca15
commit
5e5fd9d9b0
|
|
@ -18,6 +18,8 @@ class Inference:
|
||||||
|
|
||||||
if isinstance(audio,(np.ndarray,torch.Tensor)):
|
if isinstance(audio,(np.ndarray,torch.Tensor)):
|
||||||
assert sr is not None, "Invalid sampling rate!"
|
assert sr is not None, "Invalid sampling rate!"
|
||||||
|
if len(audio.shape) == 1:
|
||||||
|
audio = audio.reshape(1,-1)
|
||||||
|
|
||||||
if isinstance(audio,str):
|
if isinstance(audio,str):
|
||||||
audio = Path(audio)
|
audio = Path(audio)
|
||||||
|
|
@ -65,6 +67,8 @@ class Inference:
|
||||||
window = get_window(window=window,Nx=data.shape[-1])
|
window = get_window(window=window,Nx=data.shape[-1])
|
||||||
window = torch.from_numpy(window).to(data.device)
|
window = torch.from_numpy(window).to(data.device)
|
||||||
data *= window
|
data *= window
|
||||||
|
step_size = window_size//2 if step_size is None else step_size
|
||||||
|
|
||||||
|
|
||||||
data = data.permute(1,2,0)
|
data = data.permute(1,2,0)
|
||||||
data = F.fold(data,
|
data = F.fold(data,
|
||||||
|
|
@ -85,6 +89,17 @@ class Inference:
|
||||||
else:
|
else:
|
||||||
wavfile.write(filename,rate=sr,data=waveform.detach().cpu())
|
wavfile.write(filename,rate=sr,data=waveform.detach().cpu())
|
||||||
|
|
||||||
|
@staticmethod
def prepare_output(
    waveform: torch.Tensor,
    model_sampling_rate: int,
    audio: Union[str, np.ndarray, torch.Tensor],
    sampling_rate: Optional[int],
):
    """Match the output waveform to the caller's input type and sampling rate.

    Args:
        waveform: Tensor to hand back to the caller.
        model_sampling_rate: Sampling rate passed to the resampler as the
            source rate (``sr``) of ``waveform``.
        audio: The original input given by the caller; only its type is
            inspected here.
        sampling_rate: Target sampling rate requested by the caller, or
            ``None`` to skip resampling.

    Returns:
        ``waveform`` converted to a NumPy array when ``audio`` is a NumPy
        array, and passed through ``Audio.resample_audio`` when
        ``sampling_rate`` is not ``None``; otherwise ``waveform`` unchanged.
    """
    # Hand back the same container type the caller supplied: NumPy in,
    # NumPy out. Tensors and file paths keep the tensor as-is.
    if isinstance(audio, np.ndarray):
        waveform = waveform.detach().cpu().numpy()

    # Identity comparison is the correct check against the None singleton.
    if sampling_rate is not None:
        # NOTE(review): presumably Audio.resample_audio accepts both NumPy
        # arrays and tensors, since either may reach this point -- confirm.
        waveform = Audio.resample_audio(
            waveform, sr=model_sampling_rate, target_sr=sampling_rate
        )

    return waveform
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue