diff --git a/enhancer/inference.py b/enhancer/inference.py index 6e9cff7..2c63d54 100644 --- a/enhancer/inference.py +++ b/enhancer/inference.py @@ -18,6 +18,8 @@ class Inference: if isinstance(audio,(np.ndarray,torch.Tensor)): assert sr is not None, "Invalid sampling rate!" + if len(audio.shape) == 1: + audio = audio.reshape(1,-1) if isinstance(audio,str): audio = Path(audio) @@ -65,6 +67,8 @@ class Inference: window = get_window(window=window,Nx=data.shape[-1]) window = torch.from_numpy(window).to(data.device) data *= window + step_size = window_size//2 if step_size is None else step_size + data = data.permute(1,2,0) data = F.fold(data, @@ -84,7 +88,18 @@ class Inference: raise FileExistsError(f"file {filename} already exists") else: wavfile.write(filename,rate=sr,data=waveform.detach().cpu()) - + + @staticmethod + def prepare_output(waveform:torch.Tensor, model_sampling_rate:int, + audio:Union[str,np.ndarray,torch.Tensor], sampling_rate:Optional[int] + ): + if isinstance(audio,np.ndarray): + waveform = waveform.detach().cpu().numpy() + + if sampling_rate!=None: + waveform = Audio.resample_audio(waveform, sr=model_sampling_rate, target_sr=sampling_rate) + + return waveform