From 2a293a1d40af7b4482d38fbc3e43e408a064ba9e Mon Sep 17 00:00:00 2001 From: shahules786 Date: Mon, 19 Sep 2022 22:34:24 +0530 Subject: [PATCH] refactor IO functions --- enhancer/utils/io.py | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/enhancer/utils/io.py b/enhancer/utils/io.py index 2fb2779..ba3b4d2 100644 --- a/enhancer/utils/io.py +++ b/enhancer/utils/io.py @@ -50,13 +50,18 @@ class Audio: else: return audio + @staticmethod def convert_mono( - self, audio ): + if len(audio.shape)>2: + assert audio.shape[0] == 1, "convert mono only accepts single waveform" + audio = audio.reshape(audio.shape[1],audio.shape[2]) + + assert audio.shape[0] > audio.shape[1], "expected input format (num_channels,num_samples)" num_channels,num_samples = audio.shape - if num_channels>1 and self.mono: + if num_channels>1: return audio.mean(axis=0).reshape(1,num_samples) return audio @@ -68,17 +73,11 @@ class Audio: target_sr:int ): if sr!=target_sr: - audio = librosa.resample(audio,orig_sr=sr,target_sr=target_sr) + if isinstance(audio,np.ndarray): + audio = librosa.resample(audio,orig_sr=sr,target_sr=target_sr) + elif isinstance(audio,torch.Tensor): + audio = torchaudio.functional.resample(audio,orig_freq=sr,new_freq=target_sr) + else: + raise ValueError("Input should be either numpy array or torch tensor") return audio - - @staticmethod - def pt_resample_audio( - audio, - sr:int, - target_sr:int - ): - if sr!=target_sr: - audio = torchaudio.functional.resample(audio,orig_freq=sr,new_freq=target_sr) - - return audio \ No newline at end of file