commit
						9989705a60
					
				|  | @ -1,3 +1,3 @@ | |||
| from enhancer.utils.utils import check_files | ||||
| from enhancer.utils.io import Audio | ||||
| from enhancer.utils.config import Files | ||||
| from enhancer.utils.config import Files | ||||
|  |  | |||
|  | @ -1,10 +1,9 @@ | |||
| from dataclasses import dataclass | ||||
| 
 | ||||
| 
 | ||||
| @dataclass | ||||
| class Files: | ||||
|     train_clean : str | ||||
|     train_noisy : str | ||||
|     test_clean : str | ||||
|     test_noisy : str | ||||
| 
 | ||||
| 
 | ||||
|     train_clean: str | ||||
|     train_noisy: str | ||||
|     test_clean: str | ||||
|     test_noisy: str | ||||
|  |  | |||
|  | @ -1,41 +1,66 @@ | |||
| import os | ||||
| import librosa | ||||
| from typing import Optional | ||||
| from pathlib import Path | ||||
| from typing import Optional, Union | ||||
| import numpy as np | ||||
| import torch | ||||
| import torchaudio | ||||
| 
 | ||||
| 
 | ||||
| class Audio: | ||||
|     """ | ||||
|     Audio utils | ||||
|     parameters: | ||||
|         sampling_rate : int, defaults to 16KHz | ||||
|             audio sampling rate | ||||
|         mono: bool, defaults to True | ||||
|         return_tensor: bool, defaults to True | ||||
|             returns torch tensor type if set to True else numpy ndarray | ||||
|     """ | ||||
| 
 | ||||
|     def __init__( | ||||
|         self, | ||||
|         sampling_rate:int=16000, | ||||
|         mono:bool=True, | ||||
|         return_tensor=True | ||||
|         self, sampling_rate: int = 16000, mono: bool = True, return_tensor=True | ||||
|     ) -> None: | ||||
|          | ||||
| 
 | ||||
|         self.sampling_rate = sampling_rate | ||||
|         self.mono = mono | ||||
|         self.return_tensor = return_tensor | ||||
| 
 | ||||
|     def __call__( | ||||
|         self, | ||||
|         audio, | ||||
|         sampling_rate:Optional[int]=None, | ||||
|         offset:Optional[float] = None, | ||||
|         duration:Optional[float] = None | ||||
|         audio: Union[Path, np.ndarray, torch.Tensor], | ||||
|         sampling_rate: Optional[int] = None, | ||||
|         offset: Optional[float] = None, | ||||
|         duration: Optional[float] = None, | ||||
|     ): | ||||
|         if isinstance(audio,str): | ||||
|         """ | ||||
|         read and process input audio | ||||
|         parameters: | ||||
|             audio: Path to audio file or numpy array or torch tensor | ||||
|                 single input audio | ||||
|             sampling_rate : int, optional | ||||
|                 sampling rate of the audio input | ||||
|             offset: float (seconds), optional | ||||
|                 offset from which the audio must be read, reads from beginning if unused. | ||||
|             duration: float (seconds), optional | ||||
|                 read duration, reads full audio starting from offset if not used | ||||
|         """ | ||||
|         if isinstance(audio, str): | ||||
|             if os.path.exists(audio): | ||||
|                 audio,sampling_rate = librosa.load(audio,sr=sampling_rate,mono=False, | ||||
|                 offset=offset,duration=duration) | ||||
|                 audio, sampling_rate = librosa.load( | ||||
|                     audio, | ||||
|                     sr=sampling_rate, | ||||
|                     mono=False, | ||||
|                     offset=offset, | ||||
|                     duration=duration, | ||||
|                 ) | ||||
|                 if len(audio.shape) == 1: | ||||
|                     audio = audio.reshape(1,-1) | ||||
|                     audio = audio.reshape(1, -1) | ||||
|             else: | ||||
|                 raise FileNotFoundError(f"File {audio} deos not exist") | ||||
|         elif isinstance(audio,np.ndarray): | ||||
|         elif isinstance(audio, np.ndarray): | ||||
|             if len(audio.shape) == 1: | ||||
|                 audio = audio.reshape(1,-1) | ||||
|                 audio = audio.reshape(1, -1) | ||||
|         else: | ||||
|             raise ValueError("audio should be either filepath or numpy ndarray") | ||||
| 
 | ||||
|  | @ -43,40 +68,60 @@ class Audio: | |||
|             audio = self.convert_mono(audio) | ||||
| 
 | ||||
|         if sampling_rate: | ||||
|             audio =  self.__class__.resample_audio(audio,self.sampling_rate,sampling_rate) | ||||
|             audio = self.__class__.resample_audio( | ||||
|                 audio, self.sampling_rate, sampling_rate | ||||
|             ) | ||||
|         if self.return_tensor: | ||||
|             return torch.tensor(audio) | ||||
|         else: | ||||
|             return audio | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def convert_mono( | ||||
|         audio | ||||
|     def convert_mono(audio: Union[np.ndarray, torch.Tensor]): | ||||
|         """ | ||||
|         convert input audio into mono (a single channel) | ||||
|         parameters: | ||||
|             audio: np.ndarray or torch.Tensor | ||||
|         """ | ||||
|         if len(audio.shape) > 2: | ||||
|             assert ( | ||||
|                 audio.shape[0] == 1 | ||||
|             ), "convert mono only accepts single waveform" | ||||
|             audio = audio.reshape(audio.shape[1], audio.shape[2]) | ||||
| 
 | ||||
|     ): | ||||
|         if len(audio.shape)>2: | ||||
|             assert audio.shape[0] == 1, "convert mono only accepts single waveform" | ||||
|             audio = audio.reshape(audio.shape[1],audio.shape[2]) | ||||
|           | ||||
|         assert audio.shape[1] >> audio.shape[0], f"expected input format (num_channels,num_samples) got {audio.shape}" | ||||
|         num_channels,num_samples = audio.shape | ||||
|         if num_channels>1: | ||||
|             return audio.mean(axis=0).reshape(1,num_samples) | ||||
|         assert ( | ||||
|             audio.shape[1] >> audio.shape[0] | ||||
|         ), f"expected input format (num_channels,num_samples) got {audio.shape}" | ||||
|         num_channels, num_samples = audio.shape | ||||
|         if num_channels > 1: | ||||
|             return audio.mean(axis=0).reshape(1, num_samples) | ||||
|         return audio | ||||
| 
 | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def resample_audio( | ||||
|         audio, | ||||
|         sr:int, | ||||
|         target_sr:int | ||||
|         audio: Union[np.ndarray, torch.Tensor], sr: int, target_sr: int | ||||
|     ): | ||||
|         if sr!=target_sr: | ||||
|             if isinstance(audio,np.ndarray): | ||||
|                 audio = librosa.resample(audio,orig_sr=sr,target_sr=target_sr) | ||||
|             elif isinstance(audio,torch.Tensor): | ||||
|                 audio = torchaudio.functional.resample(audio,orig_freq=sr,new_freq=target_sr) | ||||
|         """ | ||||
|         resample audio to desired sampling rate | ||||
|         parameters: | ||||
|             audio : numpy array or torch tensor | ||||
|                 audio waveform | ||||
|             sr : int | ||||
|                 current sampling rate | ||||
|             target_sr : int | ||||
|                 target sampling rate | ||||
| 
 | ||||
|         """ | ||||
|         if sr != target_sr: | ||||
|             if isinstance(audio, np.ndarray): | ||||
|                 audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sr) | ||||
|             elif isinstance(audio, torch.Tensor): | ||||
|                 audio = torchaudio.functional.resample( | ||||
|                     audio, orig_freq=sr, new_freq=target_sr | ||||
|                 ) | ||||
|             else: | ||||
|                 raise ValueError("Input should be either numpy array or torch tensor") | ||||
|                 raise ValueError( | ||||
|                     "Input should be either numpy array or torch tensor" | ||||
|                 ) | ||||
| 
 | ||||
|         return audio | ||||
|  |  | |||
|  | @ -3,17 +3,16 @@ import random | |||
| import torch | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| def create_unique_rng(epoch:int): | ||||
| def create_unique_rng(epoch: int): | ||||
|     """create unique random number generator for each (worker_id,epoch) combination""" | ||||
| 
 | ||||
|     rng = random.Random() | ||||
| 
 | ||||
|     global_seed = int(os.environ.get("PL_GLOBAL_SEED","0")) | ||||
|     global_rank = int(os.environ.get('GLOBAL_RANK',"0")) | ||||
|     local_rank = int(os.environ.get('LOCAL_RANK',"0")) | ||||
|     node_rank = int(os.environ.get('NODE_RANK',"0")) | ||||
|     world_size = int(os.environ.get('WORLD_SIZE',"0")) | ||||
|     global_seed = int(os.environ.get("PL_GLOBAL_SEED", "0")) | ||||
|     global_rank = int(os.environ.get("GLOBAL_RANK", "0")) | ||||
|     local_rank = int(os.environ.get("LOCAL_RANK", "0")) | ||||
|     node_rank = int(os.environ.get("NODE_RANK", "0")) | ||||
|     world_size = int(os.environ.get("WORLD_SIZE", "0")) | ||||
| 
 | ||||
|     worker_info = torch.utils.data.get_worker_info() | ||||
|     if worker_info is not None: | ||||
|  | @ -24,17 +23,13 @@ def create_unique_rng(epoch:int): | |||
|         worker_id = 0 | ||||
| 
 | ||||
|     seed = ( | ||||
|             global_seed | ||||
|             + worker_id | ||||
|             + local_rank * num_workers | ||||
|             + node_rank * num_workers * global_rank | ||||
|             + epoch * num_workers * world_size | ||||
|         ) | ||||
|         global_seed | ||||
|         + worker_id | ||||
|         + local_rank * num_workers | ||||
|         + node_rank * num_workers * global_rank | ||||
|         + epoch * num_workers * world_size | ||||
|     ) | ||||
| 
 | ||||
|     rng.seed(seed) | ||||
| 
 | ||||
|     return rng | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|  |  | |||
|  | @ -1,19 +1,24 @@ | |||
| 
 | ||||
| import os | ||||
| from typing import Optional | ||||
| from enhancer.utils.config import Files | ||||
| 
 | ||||
| def check_files(root_dir:str, files:Files): | ||||
| 
 | ||||
|     path_variables = [member_var for member_var in dir(files) if not member_var.startswith('__')] | ||||
| def check_files(root_dir: str, files: Files): | ||||
| 
 | ||||
|     path_variables = [ | ||||
|         member_var | ||||
|         for member_var in dir(files) | ||||
|         if not member_var.startswith("__") | ||||
|     ] | ||||
|     for variable in path_variables: | ||||
|         path = getattr(files,variable) | ||||
|         if not os.path.isdir(os.path.join(root_dir,path)): | ||||
|         path = getattr(files, variable) | ||||
|         if not os.path.isdir(os.path.join(root_dir, path)): | ||||
|             raise ValueError(f"Invalid {path}, is not a directory") | ||||
|              | ||||
|     return files,root_dir | ||||
| 
 | ||||
| def merge_dict(default_dict:dict, custom:Optional[dict]=None): | ||||
|     return files, root_dir | ||||
| 
 | ||||
| 
 | ||||
| def merge_dict(default_dict: dict, custom: Optional[dict] = None): | ||||
|     params = dict(default_dict) | ||||
|     if custom: | ||||
|         params.update(custom) | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	 Shahul ES
						Shahul ES