From aca4521ef27725df97ac055297ef8301d7c04d11 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Wed, 5 Oct 2022 15:19:17 +0530 Subject: [PATCH 1/2] refactor Audio --- enhancer/utils/io.py | 121 +++++++++++++++++++++++++++++-------------- 1 file changed, 83 insertions(+), 38 deletions(-) diff --git a/enhancer/utils/io.py b/enhancer/utils/io.py index afc19e8..3703285 100644 --- a/enhancer/utils/io.py +++ b/enhancer/utils/io.py @@ -1,41 +1,66 @@ import os import librosa -from typing import Optional +from pathlib import Path +from typing import Optional, Union import numpy as np import torch import torchaudio + class Audio: + """ + Audio utils + parameters: + sampling_rate : int, defaults to 16KHz + audio sampling rate + mono: bool, defaults to True + return_tensors: bool, defaults to True + returns torch tensor type if set to True else numpy ndarray + """ def __init__( - self, - sampling_rate:int=16000, - mono:bool=True, - return_tensor=True + self, sampling_rate: int = 16000, mono: bool = True, return_tensor=True ) -> None: - + self.sampling_rate = sampling_rate self.mono = mono self.return_tensor = return_tensor def __call__( self, - audio, - sampling_rate:Optional[int]=None, - offset:Optional[float] = None, - duration:Optional[float] = None + audio: Union[Path, np.ndarray, torch.Tensor], + sampling_rate: Optional[int] = None, + offset: Optional[float] = None, + duration: Optional[float] = None, ): - if isinstance(audio,str): + """ + read and process input audio + parameters: + audio: Path to audio file or numpy array or torch tensor + single input audio + sampling_rate : int, optional + sampling rate of the audio input + offset: float, optional + offset from which the audio must be read, reads from beginning if unused. + duration: float (seconds), optional + read duration, reads full audio starting from offset if not used + """ + if isinstance(audio, str): if os.path.exists(audio): - audio,sampling_rate = librosa.load(audio,sr=sampling_rate,mono=False, - offset=offset,duration=duration) + audio, sampling_rate = librosa.load( + audio, + sr=sampling_rate, + mono=False, + offset=offset, + duration=duration, + ) if len(audio.shape) == 1: - audio = audio.reshape(1,-1) + audio = audio.reshape(1, -1) else: raise FileNotFoundError(f"File {audio} deos not exist") - elif isinstance(audio,np.ndarray): + elif isinstance(audio, np.ndarray): if len(audio.shape) == 1: - audio = audio.reshape(1,-1) + audio = audio.reshape(1, -1) else: raise ValueError("audio should be either filepath or numpy ndarray") @@ -43,40 +68,60 @@ class Audio: audio = self.convert_mono(audio) if sampling_rate: - audio = self.__class__.resample_audio(audio,self.sampling_rate,sampling_rate) + audio = self.__class__.resample_audio( + audio, self.sampling_rate, sampling_rate + ) if self.return_tensor: return torch.tensor(audio) else: return audio @staticmethod - def convert_mono( - audio + def convert_mono(audio: Union[np.ndarray, torch.Tensor]): + """ + convert input audio into mono (1) + parameters: + audio: np.ndarray or torch.Tensor + """ + if len(audio.shape) > 2: + assert ( + audio.shape[0] == 1 + ), "convert mono only accepts single waveform" + audio = audio.reshape(audio.shape[1], audio.shape[2]) - ): - if len(audio.shape)>2: - assert audio.shape[0] == 1, "convert mono only accepts single waveform" - audio = audio.reshape(audio.shape[1],audio.shape[2]) - - assert audio.shape[1] >> audio.shape[0], f"expected input format (num_channels,num_samples) got {audio.shape}" - num_channels,num_samples = audio.shape - if num_channels>1: - return audio.mean(axis=0).reshape(1,num_samples) + assert ( + audio.shape[1] >> audio.shape[0] + ), f"expected input format (num_channels,num_samples) got {audio.shape}" + num_channels, num_samples = audio.shape + if num_channels > 1: + return audio.mean(axis=0).reshape(1, num_samples) return audio - @staticmethod def resample_audio( - audio, - sr:int, - target_sr:int + audio: Union[np.ndarray, torch.Tensor], sr: int, target_sr: int ): - if sr!=target_sr: - if isinstance(audio,np.ndarray): - audio = librosa.resample(audio,orig_sr=sr,target_sr=target_sr) - elif isinstance(audio,torch.Tensor): - audio = torchaudio.functional.resample(audio,orig_freq=sr,new_freq=target_sr) + """ + resample audio to desired sampling rate + parameters: + audio : Path to audio file or numpy array or torch tensor + audio waveform + sr : int + current sampling rate + target_sr : int + target sampling rate + + """ + if sr != target_sr: + if isinstance(audio, np.ndarray): + audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sr) + elif isinstance(audio, torch.Tensor): + audio = torchaudio.functional.resample( + audio, orig_freq=sr, new_freq=target_sr + ) else: - raise ValueError("Input should be either numpy array or torch tensor") + raise ValueError( + "Input should be either numpy array or torch tensor" + ) return audio From 6c4bced3607f23db76d911f9a3471de22db53fa1 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Wed, 5 Oct 2022 15:20:43 +0530 Subject: [PATCH 2/2] reformat utils --- enhancer/utils/__init__.py | 2 +- enhancer/utils/config.py | 11 +++++------ enhancer/utils/random.py | 29 ++++++++++++----------------- enhancer/utils/utils.py | 21 +++++++++++++-------- 4 files changed, 31 insertions(+), 32 deletions(-) diff --git a/enhancer/utils/__init__.py b/enhancer/utils/__init__.py index 3da7ede..c9f5438 100644 --- a/enhancer/utils/__init__.py +++ b/enhancer/utils/__init__.py @@ -1,3 +1,3 @@ from enhancer.utils.utils import check_files from enhancer.utils.io import Audio -from enhancer.utils.config import Files \ No newline at end of file +from enhancer.utils.config import Files diff --git a/enhancer/utils/config.py b/enhancer/utils/config.py index 1bbc51d..252e6c9 100644 --- a/enhancer/utils/config.py +++ b/enhancer/utils/config.py @@ -1,10 +1,9 @@ from dataclasses import dataclass + @dataclass class Files: - train_clean : str - train_noisy : str - test_clean : str - test_noisy : str - - + train_clean: str + train_noisy: str + test_clean: str + test_noisy: str diff --git a/enhancer/utils/random.py b/enhancer/utils/random.py index 3b1acac..51e09c0 100644 --- a/enhancer/utils/random.py +++ b/enhancer/utils/random.py @@ -3,17 +3,16 @@ import random import torch - -def create_unique_rng(epoch:int): +def create_unique_rng(epoch: int): """create unique random number generator for each (worker_id,epoch) combination""" rng = random.Random() - global_seed = int(os.environ.get("PL_GLOBAL_SEED","0")) - global_rank = int(os.environ.get('GLOBAL_RANK',"0")) - local_rank = int(os.environ.get('LOCAL_RANK',"0")) - node_rank = int(os.environ.get('NODE_RANK',"0")) - world_size = int(os.environ.get('WORLD_SIZE',"0")) + global_seed = int(os.environ.get("PL_GLOBAL_SEED", "0")) + global_rank = int(os.environ.get("GLOBAL_RANK", "0")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + node_rank = int(os.environ.get("NODE_RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "0")) worker_info = torch.utils.data.get_worker_info() if worker_info is not None: @@ -24,17 +23,13 @@ def create_unique_rng(epoch:int): worker_id = 0 seed = ( - global_seed - + worker_id - + local_rank * num_workers - + node_rank * num_workers * global_rank - + epoch * num_workers * world_size - ) + global_seed + + worker_id + + local_rank * num_workers + + node_rank * num_workers * global_rank + + epoch * num_workers * world_size + ) rng.seed(seed) return rng - - - - diff --git a/enhancer/utils/utils.py b/enhancer/utils/utils.py index be74dc2..73673ed 100644 --- a/enhancer/utils/utils.py +++ b/enhancer/utils/utils.py @@ -1,19 +1,24 @@ - import os from typing import Optional from enhancer.utils.config import Files -def check_files(root_dir:str, files:Files): - path_variables = [member_var for member_var in dir(files) if not member_var.startswith('__')] +def check_files(root_dir: str, files: Files): + + path_variables = [ + member_var + for member_var in dir(files) + if not member_var.startswith("__") + ] for variable in path_variables: - path = getattr(files,variable) - if not os.path.isdir(os.path.join(root_dir,path)): + path = getattr(files, variable) + if not os.path.isdir(os.path.join(root_dir, path)): raise ValueError(f"Invalid {path}, is not a directory") - - return files,root_dir -def merge_dict(default_dict:dict, custom:Optional[dict]=None): + return files, root_dir + + +def merge_dict(default_dict: dict, custom: Optional[dict] = None): params = dict(default_dict) if custom: params.update(custom)