prepare output type/sr
This commit is contained in:
parent
07c525ca15
commit
5e5fd9d9b0
|
|
@ -18,6 +18,8 @@ class Inference:
|
|||
|
||||
if isinstance(audio,(np.ndarray,torch.Tensor)):
|
||||
assert sr is not None, "Invalid sampling rate!"
|
||||
if len(audio.shape) == 1:
|
||||
audio = audio.reshape(1,-1)
|
||||
|
||||
if isinstance(audio,str):
|
||||
audio = Path(audio)
|
||||
|
|
@ -65,6 +67,8 @@ class Inference:
|
|||
window = get_window(window=window,Nx=data.shape[-1])
|
||||
window = torch.from_numpy(window).to(data.device)
|
||||
data *= window
|
||||
step_size = window_size//2 if step_size is None else step_size
|
||||
|
||||
|
||||
data = data.permute(1,2,0)
|
||||
data = F.fold(data,
|
||||
|
|
@ -84,7 +88,18 @@ class Inference:
|
|||
raise FileExistsError(f"file {filename} already exists")
|
||||
else:
|
||||
wavfile.write(filename,rate=sr,data=waveform.detach().cpu())
|
||||
|
||||
|
||||
@staticmethod
|
||||
def prepare_output(waveform:torch.Tensor, model_sampling_rate:int,
|
||||
audio:Union[str,np.ndarray,torch.Tensor], sampling_rate:Optional[int]
|
||||
):
|
||||
if isinstance(audio,np.ndarray):
|
||||
waveform = waveform.detach().cpu().numpy()
|
||||
|
||||
if sampling_rate!=None:
|
||||
waveform = Audio.resample_audio(waveform, sr=model_sampling_rate, target_sr=sampling_rate)
|
||||
|
||||
return waveform
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue