diff --git a/enhancer/utils/io.py b/enhancer/utils/io.py index 2fb2779..ba3b4d2 100644 --- a/enhancer/utils/io.py +++ b/enhancer/utils/io.py @@ -50,13 +50,18 @@ class Audio: else: return audio + @staticmethod def convert_mono( - self, audio ): + if len(audio.shape)>2: + assert audio.shape[0] == 1, "convert mono only accepts single waveform" + audio = audio.reshape(audio.shape[1],audio.shape[2]) + + assert audio.shape[0] > audio.shape[1], "expected input format (num_channels,num_samples)" num_channels,num_samples = audio.shape - if num_channels>1 and self.mono: + if num_channels>1: return audio.mean(axis=0).reshape(1,num_samples) return audio @@ -68,17 +73,11 @@ class Audio: target_sr:int ): if sr!=target_sr: - audio = librosa.resample(audio,orig_sr=sr,target_sr=target_sr) + if isinstance(audio,np.ndarray): + audio = librosa.resample(audio,orig_sr=sr,target_sr=target_sr) + elif isinstance(audio,torch.Tensor): + audio = torchaudio.functional.resample(audio,orig_freq=sr,new_freq=target_sr) + else: + raise ValueError("Input should be either numpy array or torch tensor") return audio - - @staticmethod - def pt_resample_audio( - audio, - sr:int, - target_sr:int - ): - if sr!=target_sr: - audio = torchaudio.functional.resample(audio,orig_freq=sr,new_freq=target_sr) - - return audio \ No newline at end of file