prepare output type/sr

This commit is contained in:
shahules786 2022-10-03 20:00:35 +05:30
parent 07c525ca15
commit 5e5fd9d9b0
1 changed files with 16 additions and 1 deletions

View File

@ -18,6 +18,8 @@ class Inference:
if isinstance(audio,(np.ndarray,torch.Tensor)):
assert sr is not None, "Invalid sampling rate!"
if len(audio.shape) == 1:
audio = audio.reshape(1,-1)
if isinstance(audio,str):
audio = Path(audio)
@ -65,6 +67,8 @@ class Inference:
window = get_window(window=window,Nx=data.shape[-1])
window = torch.from_numpy(window).to(data.device)
data *= window
step_size = window_size//2 if step_size is None else step_size
data = data.permute(1,2,0)
data = F.fold(data,
@ -84,7 +88,18 @@ class Inference:
raise FileExistsError(f"file {filename} already exists")
else:
wavfile.write(filename,rate=sr,data=waveform.detach().cpu())
@staticmethod
def prepare_output(waveform:torch.Tensor, model_sampling_rate:int,
audio:Union[str,np.ndarray,torch.Tensor], sampling_rate:Optional[int]
):
if isinstance(audio,np.ndarray):
waveform = waveform.detach().cpu().numpy()
if sampling_rate!=None:
waveform = Audio.resample_audio(waveform, sr=model_sampling_rate, target_sr=sampling_rate)
return waveform