refactor Audio

This commit is contained in:
shahules786 2022-10-05 15:19:17 +05:30
parent b92310c93d
commit aca4521ef2
1 changed file with 83 additions and 38 deletions

View File

@ -1,17 +1,25 @@
import os import os
import librosa import librosa
from typing import Optional from pathlib import Path
from typing import Optional, Union
import numpy as np import numpy as np
import torch import torch
import torchaudio import torchaudio
class Audio: class Audio:
"""
Audio utils
parameters:
sampling_rate : int, defaults to 16KHz
audio sampling rate
mono: bool, defaults to True
return_tensors: bool, defaults to True
returns torch tensor type if set to True else numpy ndarray
"""
def __init__( def __init__(
self, self, sampling_rate: int = 16000, mono: bool = True, return_tensor=True
sampling_rate:int=16000,
mono:bool=True,
return_tensor=True
) -> None: ) -> None:
self.sampling_rate = sampling_rate self.sampling_rate = sampling_rate
@ -20,22 +28,39 @@ class Audio:
def __call__( def __call__(
self, self,
audio, audio: Union[Path, np.ndarray, torch.Tensor],
sampling_rate:Optional[int]=None, sampling_rate: Optional[int] = None,
offset:Optional[float] = None, offset: Optional[float] = None,
duration:Optional[float] = None duration: Optional[float] = None,
): ):
if isinstance(audio,str): """
read and process input audio
parameters:
audio: Path to audio file or numpy array or torch tensor
single input audio
sampling_rate : int, optional
sampling rate of the audio input
offset: float, optional
offset from which the audio must be read, reads from beginning if unused.
duration: float (seconds), optional
read duration, reads full audio starting from offset if not used
"""
if isinstance(audio, str):
if os.path.exists(audio): if os.path.exists(audio):
audio,sampling_rate = librosa.load(audio,sr=sampling_rate,mono=False, audio, sampling_rate = librosa.load(
offset=offset,duration=duration) audio,
sr=sampling_rate,
mono=False,
offset=offset,
duration=duration,
)
if len(audio.shape) == 1: if len(audio.shape) == 1:
audio = audio.reshape(1,-1) audio = audio.reshape(1, -1)
else: else:
raise FileNotFoundError(f"File {audio} deos not exist") raise FileNotFoundError(f"File {audio} deos not exist")
elif isinstance(audio,np.ndarray): elif isinstance(audio, np.ndarray):
if len(audio.shape) == 1: if len(audio.shape) == 1:
audio = audio.reshape(1,-1) audio = audio.reshape(1, -1)
else: else:
raise ValueError("audio should be either filepath or numpy ndarray") raise ValueError("audio should be either filepath or numpy ndarray")
@ -43,40 +68,60 @@ class Audio:
audio = self.convert_mono(audio) audio = self.convert_mono(audio)
if sampling_rate: if sampling_rate:
audio = self.__class__.resample_audio(audio,self.sampling_rate,sampling_rate) audio = self.__class__.resample_audio(
audio, self.sampling_rate, sampling_rate
)
if self.return_tensor: if self.return_tensor:
return torch.tensor(audio) return torch.tensor(audio)
else: else:
return audio return audio
@staticmethod @staticmethod
def convert_mono( def convert_mono(audio: Union[np.ndarray, torch.Tensor]):
audio """
convert input audio into mono (1)
parameters:
audio: np.ndarray or torch.Tensor
"""
if len(audio.shape) > 2:
assert (
audio.shape[0] == 1
), "convert mono only accepts single waveform"
audio = audio.reshape(audio.shape[1], audio.shape[2])
): assert (
if len(audio.shape)>2: audio.shape[1] >> audio.shape[0]
assert audio.shape[0] == 1, "convert mono only accepts single waveform" ), f"expected input format (num_channels,num_samples) got {audio.shape}"
audio = audio.reshape(audio.shape[1],audio.shape[2]) num_channels, num_samples = audio.shape
if num_channels > 1:
assert audio.shape[1] >> audio.shape[0], f"expected input format (num_channels,num_samples) got {audio.shape}" return audio.mean(axis=0).reshape(1, num_samples)
num_channels,num_samples = audio.shape
if num_channels>1:
return audio.mean(axis=0).reshape(1,num_samples)
return audio return audio
@staticmethod @staticmethod
def resample_audio( def resample_audio(
audio, audio: Union[np.ndarray, torch.Tensor], sr: int, target_sr: int
sr:int,
target_sr:int
): ):
if sr!=target_sr: """
if isinstance(audio,np.ndarray): resample audio to desired sampling rate
audio = librosa.resample(audio,orig_sr=sr,target_sr=target_sr) parameters:
elif isinstance(audio,torch.Tensor): audio : Path to audio file or numpy array or torch tensor
audio = torchaudio.functional.resample(audio,orig_freq=sr,new_freq=target_sr) audio waveform
sr : int
current sampling rate
target_sr : int
target sampling rate
"""
if sr != target_sr:
if isinstance(audio, np.ndarray):
audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sr)
elif isinstance(audio, torch.Tensor):
audio = torchaudio.functional.resample(
audio, orig_freq=sr, new_freq=target_sr
)
else: else:
raise ValueError("Input should be either numpy array or torch tensor") raise ValueError(
"Input should be either numpy array or torch tensor"
)
return audio return audio