refactor IO functions

This commit is contained in:
shahules786 2022-09-19 22:34:24 +05:30
parent 43b1dd190a
commit 2a293a1d40
1 changed files with 13 additions and 14 deletions

View File

@ -50,13 +50,18 @@ class Audio:
else: else:
return audio return audio
@staticmethod
def convert_mono( def convert_mono(
self,
audio audio
): ):
if len(audio.shape)>2:
assert audio.shape[0] == 1, "convert mono only accepts single waveform"
audio = audio.reshape(audio.shape[1],audio.shape[2])
assert audio.shape[0] > audio.shape[1], "expected input format (num_channels,num_samples)"
num_channels,num_samples = audio.shape num_channels,num_samples = audio.shape
if num_channels>1 and self.mono: if num_channels>1:
return audio.mean(axis=0).reshape(1,num_samples) return audio.mean(axis=0).reshape(1,num_samples)
return audio return audio
@ -68,17 +73,11 @@ class Audio:
target_sr:int target_sr:int
): ):
if sr!=target_sr: if sr!=target_sr:
audio = librosa.resample(audio,orig_sr=sr,target_sr=target_sr) if isinstance(audio,np.ndarray):
audio = librosa.resample(audio,orig_sr=sr,target_sr=target_sr)
elif isinstance(audio,torch.Tensor):
audio = torchaudio.functional.resample(audio,orig_freq=sr,new_freq=target_sr)
else:
raise ValueError("Input should be either numpy array or torch tensor")
return audio return audio
@staticmethod
def pt_resample_audio(
audio,
sr:int,
target_sr:int
):
if sr!=target_sr:
audio = torchaudio.functional.resample(audio,orig_freq=sr,new_freq=target_sr)
return audio