commit 9989705a60
Shahul ES
@@ -1,3 +1,3 @@
from enhancer.utils.utils import check_files
from enhancer.utils.io import Audio
from enhancer.utils.config import Files
@@ -1,10 +1,9 @@
from dataclasses import dataclass


@dataclass
class Files:
    train_clean: str
    train_noisy: str
    test_clean: str
    test_noisy: str
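For orientation, a minimal sketch of how this dataclass might be filled in; the directory names below are hypothetical, not taken from the repository:

from enhancer.utils.config import Files

# Hypothetical dataset layout (names are placeholders).
files = Files(
    train_clean="train/clean",
    train_noisy="train/noisy",
    test_clean="test/clean",
    test_noisy="test/noisy",
)
print(files.train_clean)  # -> train/clean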
@@ -1,41 +1,66 @@
import os
import librosa
from pathlib import Path
from typing import Optional, Union
import numpy as np
import torch
import torchaudio


class Audio:
    """
    Audio utils
    parameters:
        sampling_rate: int, defaults to 16000
            audio sampling rate
        mono: bool, defaults to True
            downmix multichannel input to mono if True
        return_tensor: bool, defaults to True
            returns a torch tensor if set to True, else a numpy ndarray
    """

    def __init__(
        self, sampling_rate: int = 16000, mono: bool = True, return_tensor=True
    ) -> None:

        self.sampling_rate = sampling_rate
        self.mono = mono
        self.return_tensor = return_tensor

    def __call__(
        self,
        audio: Union[str, Path, np.ndarray, torch.Tensor],
        sampling_rate: Optional[int] = None,
        offset: Optional[float] = None,
        duration: Optional[float] = None,
    ):
        """
        read and process input audio
        parameters:
            audio: path to an audio file, numpy array or torch tensor
                single input audio
            sampling_rate: int, optional
                sampling rate of the audio input
            offset: float (seconds), optional
                offset from which the audio must be read; reads from the beginning if unused
            duration: float (seconds), optional
                read duration; reads the full audio starting from offset if unused
        """
        if isinstance(audio, (str, Path)):
            if os.path.exists(audio):
                audio, sampling_rate = librosa.load(
                    audio,
                    sr=sampling_rate,
                    mono=False,
                    offset=offset,
                    duration=duration,
                )
                if len(audio.shape) == 1:
                    audio = audio.reshape(1, -1)
            else:
                raise FileNotFoundError(f"File {audio} does not exist")
        elif isinstance(audio, (np.ndarray, torch.Tensor)):
            if len(audio.shape) == 1:
                audio = audio.reshape(1, -1)
        else:
            raise ValueError(
                "audio should be a file path, numpy ndarray or torch tensor"
            )
@@ -43,40 +68,60 @@ class Audio:
            audio = self.convert_mono(audio)

        if sampling_rate:
            audio = self.__class__.resample_audio(
                audio, self.sampling_rate, sampling_rate
            )
        if self.return_tensor:
            return torch.tensor(audio)
        else:
            return audio

    @staticmethod
    def convert_mono(audio: Union[np.ndarray, torch.Tensor]):
        """
        downmix multichannel audio to mono
        parameters:
            audio: np.ndarray or torch.Tensor
                waveform of shape (num_channels, num_samples)
        """
        if len(audio.shape) > 2:
            assert (
                audio.shape[0] == 1
            ), "convert_mono only accepts a single waveform"
            audio = audio.reshape(audio.shape[1], audio.shape[2])

        assert (
            audio.shape[1] > audio.shape[0]
        ), f"expected input format (num_channels, num_samples), got {audio.shape}"
        num_channels, num_samples = audio.shape
        if num_channels > 1:
            return audio.mean(axis=0).reshape(1, num_samples)
        return audio

    @staticmethod
    def resample_audio(
        audio: Union[np.ndarray, torch.Tensor], sr: int, target_sr: int
    ):
        """
        resample audio to the desired sampling rate
        parameters:
            audio: np.ndarray or torch.Tensor
                audio waveform
            sr: int
                current sampling rate
            target_sr: int
                target sampling rate
        """
        if sr != target_sr:
            if isinstance(audio, np.ndarray):
                audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sr)
            elif isinstance(audio, torch.Tensor):
                audio = torchaudio.functional.resample(
                    audio, orig_freq=sr, new_freq=target_sr
                )
            else:
                raise ValueError(
                    "Input should be either numpy array or torch tensor"
                )

        return audio
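A short usage sketch of the class as it stands after this commit; the file path in the commented line is a placeholder:

import numpy as np
from enhancer.utils.io import Audio

audio = Audio(sampling_rate=16000, mono=True, return_tensor=True)

# In-memory input: a 2-channel, 1-second waveform at 16 kHz.
waveform = np.random.randn(2, 16000).astype("float32")
out = audio(waveform)  # downmixed to mono -> torch.Tensor of shape (1, 16000)

# From disk (placeholder path), reading 2 seconds starting at 0.5 s:
# out = audio("sample.wav", offset=0.5, duration=2.0)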
@@ -3,17 +3,16 @@ import random
import torch


def create_unique_rng(epoch: int):
    """create a unique random number generator for each (worker_id, epoch) combination"""

    rng = random.Random()

    global_seed = int(os.environ.get("PL_GLOBAL_SEED", "0"))
    global_rank = int(os.environ.get("GLOBAL_RANK", "0"))
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    node_rank = int(os.environ.get("NODE_RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "0"))

    worker_info = torch.utils.data.get_worker_info()
    if worker_info is not None:
@@ -24,17 +23,13 @@ def create_unique_rng(epoch:int):
        worker_id = 0

    # Combine the distributed ranks, DataLoader worker id and epoch into one
    # seed, so every (worker, epoch) pair gets its own deterministic stream.
    seed = (
        global_seed
        + worker_id
        + local_rank * num_workers
        + node_rank * num_workers * global_rank
        + epoch * num_workers * world_size
    )

    rng.seed(seed)

    return rng
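A sketch of the intended behaviour in a plain single-process session. The import path is assumed, and note that with WORLD_SIZE unset it defaults to "0", which zeroes out the epoch term, so it is set explicitly here:

import os

from enhancer.utils.random import create_unique_rng  # module path assumed

os.environ["WORLD_SIZE"] = "1"

draw = create_unique_rng(epoch=5).random()
assert draw == create_unique_rng(epoch=5).random()  # same epoch -> same stream
assert draw != create_unique_rng(epoch=6).random()  # new epoch -> (almost surely) new stream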
@@ -1,19 +1,24 @@
import os
from typing import Optional

from enhancer.utils.config import Files


def check_files(root_dir: str, files: Files):

    path_variables = [
        member_var
        for member_var in dir(files)
        if not member_var.startswith("__")
    ]
    for variable in path_variables:
        path = getattr(files, variable)
        if not os.path.isdir(os.path.join(root_dir, path)):
            raise ValueError(f"Invalid {path}: not a directory")

    return files, root_dir


def merge_dict(default_dict: dict, custom: Optional[dict] = None):
    params = dict(default_dict)
    if custom:
        params.update(custom)
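A self-contained sketch of both helpers; the directory names are arbitrary, and merge_dict is assumed to return the merged dict (its final line falls outside the hunk above):

import os
import tempfile

from enhancer.utils.config import Files
from enhancer.utils.utils import check_files, merge_dict

root = tempfile.mkdtemp()
files = Files(
    train_clean="train_clean",
    train_noisy="train_noisy",
    test_clean="test_clean",
    test_noisy="test_noisy",
)
# check_files expects every Files field to name a directory under root.
for field in (files.train_clean, files.train_noisy, files.test_clean, files.test_noisy):
    os.makedirs(os.path.join(root, field))
files, root_dir = check_files(root, files)

# merge_dict: defaults plus optional overrides, leaving the defaults untouched.
params = merge_dict({"lr": 1e-3, "batch_size": 32}, {"lr": 3e-4})
print(params)  # {'lr': 0.0003, 'batch_size': 32}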