Merge pull request #10 from shahules786/dev-reformat

Dev reformat
Shahul ES 2022-10-05 15:23:03 +05:30 committed by GitHub
commit 9989705a60
5 changed files with 114 additions and 70 deletions

View File · enhancer/utils/__init__.py

@@ -1,3 +1,3 @@
 from enhancer.utils.utils import check_files
 from enhancer.utils.io import Audio
 from enhancer.utils.config import Files

View File · enhancer/utils/config.py

@@ -1,10 +1,9 @@
 from dataclasses import dataclass
 @dataclass
 class Files:
-    train_clean : str
-    train_noisy : str
-    test_clean : str
-    test_noisy : str
+    train_clean: str
+    train_noisy: str
+    test_clean: str
+    test_noisy: str
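
For orientation, Files is a plain dataclass naming the four dataset sub-directories. A minimal construction sketch, assuming a VCTK-style layout (the directory names below are hypothetical):

    from enhancer.utils.config import Files

    # hypothetical directory names under some dataset root
    files = Files(
        train_clean="clean_trainset_wav",
        train_noisy="noisy_trainset_wav",
        test_clean="clean_testset_wav",
        test_noisy="noisy_testset_wav",
    )
    print(files.train_clean)  # clean_trainset_wav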

View File · enhancer/utils/io.py

@@ -1,41 +1,66 @@
 import os
 import librosa
-from typing import Optional
+from pathlib import Path
+from typing import Optional, Union
 import numpy as np
 import torch
 import torchaudio
 class Audio:
+    """
+    Audio utils
+    parameters:
+    sampling_rate : int, defaults to 16KHz
+        audio sampling rate
+    mono : bool, defaults to True
+    return_tensor : bool, defaults to True
+        returns torch tensor type if set to True else numpy ndarray
+    """
     def __init__(
-        self,
-        sampling_rate:int=16000,
-        mono:bool=True,
-        return_tensor=True
+        self, sampling_rate: int = 16000, mono: bool = True, return_tensor=True
     ) -> None:
         self.sampling_rate = sampling_rate
         self.mono = mono
         self.return_tensor = return_tensor
     def __call__(
         self,
-        audio,
-        sampling_rate:Optional[int]=None,
-        offset:Optional[float] = None,
-        duration:Optional[float] = None
+        audio: Union[Path, np.ndarray, torch.Tensor],
+        sampling_rate: Optional[int] = None,
+        offset: Optional[float] = None,
+        duration: Optional[float] = None,
     ):
-        if isinstance(audio,str):
+        """
+        read and process input audio
+        parameters:
+        audio : Path to audio file or numpy array or torch tensor
+            single input audio
+        sampling_rate : int, optional
+            sampling rate of the audio input
+        offset : float, optional
+            offset from which the audio must be read, reads from beginning if unused.
+        duration : float (seconds), optional
+            read duration, reads full audio starting from offset if not used
+        """
+        if isinstance(audio, str):
             if os.path.exists(audio):
-                audio,sampling_rate = librosa.load(audio,sr=sampling_rate,mono=False,
-                offset=offset,duration=duration)
+                audio, sampling_rate = librosa.load(
+                    audio,
+                    sr=sampling_rate,
+                    mono=False,
+                    offset=offset,
+                    duration=duration,
+                )
                 if len(audio.shape) == 1:
-                    audio = audio.reshape(1,-1)
+                    audio = audio.reshape(1, -1)
             else:
                 raise FileNotFoundError(f"File {audio} does not exist")
-        elif isinstance(audio,np.ndarray):
+        elif isinstance(audio, np.ndarray):
             if len(audio.shape) == 1:
-                audio = audio.reshape(1,-1)
+                audio = audio.reshape(1, -1)
         else:
             raise ValueError("audio should be either filepath or numpy ndarray")
@@ -43,40 +68,60 @@ class Audio:
         audio = self.convert_mono(audio)
         if sampling_rate:
-            audio = self.__class__.resample_audio(audio,self.sampling_rate,sampling_rate)
+            audio = self.__class__.resample_audio(
+                audio, self.sampling_rate, sampling_rate
+            )
         if self.return_tensor:
             return torch.tensor(audio)
         else:
             return audio
     @staticmethod
-    def convert_mono(
-        audio
-    ):
+    def convert_mono(audio: Union[np.ndarray, torch.Tensor]):
+        """
+        convert input audio into mono (single channel)
+        parameters:
+        audio : np.ndarray or torch.Tensor
+        """
-        if len(audio.shape)>2:
-            assert audio.shape[0] == 1, "convert mono only accepts single waveform"
-            audio = audio.reshape(audio.shape[1],audio.shape[2])
-        assert audio.shape[1] > audio.shape[0], f"expected input format (num_channels,num_samples) got {audio.shape}"
-        num_channels,num_samples = audio.shape
-        if num_channels>1:
-            return audio.mean(axis=0).reshape(1,num_samples)
+        if len(audio.shape) > 2:
+            assert (
+                audio.shape[0] == 1
+            ), "convert mono only accepts single waveform"
+            audio = audio.reshape(audio.shape[1], audio.shape[2])
+        assert (
+            audio.shape[1] > audio.shape[0]
+        ), f"expected input format (num_channels,num_samples) got {audio.shape}"
+        num_channels, num_samples = audio.shape
+        if num_channels > 1:
+            return audio.mean(axis=0).reshape(1, num_samples)
         return audio
     @staticmethod
-    def resample_audio(
-        audio,
-        sr:int,
-        target_sr:int
-    ):
+    def resample_audio(
+        audio: Union[np.ndarray, torch.Tensor], sr: int, target_sr: int
+    ):
+        """
+        resample audio to desired sampling rate
+        parameters:
+        audio : numpy array or torch tensor
+            audio waveform
+        sr : int
+            current sampling rate
+        target_sr : int
+            target sampling rate
+        """
-        if sr!=target_sr:
-            if isinstance(audio,np.ndarray):
-                audio = librosa.resample(audio,orig_sr=sr,target_sr=target_sr)
-            elif isinstance(audio,torch.Tensor):
-                audio = torchaudio.functional.resample(audio,orig_freq=sr,new_freq=target_sr)
+        if sr != target_sr:
+            if isinstance(audio, np.ndarray):
+                audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sr)
+            elif isinstance(audio, torch.Tensor):
+                audio = torchaudio.functional.resample(
+                    audio, orig_freq=sr, new_freq=target_sr
+                )
             else:
-                raise ValueError("Input should be either numpy array or torch tensor")
+                raise ValueError(
+                    "Input should be either numpy array or torch tensor"
+                )
         return audio
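
Taken together, the reformatted Audio class reads a file (or accepts an in-memory array), downmixes to mono, optionally resamples, and returns a tensor or ndarray. A usage sketch against the API shown above; the path and shapes are illustrative:

    import numpy as np
    import torch
    from enhancer.utils.io import Audio

    audio = Audio(sampling_rate=16000, mono=True, return_tensor=True)

    # read two seconds from a file on disk (hypothetical path); yields shape (1, num_samples)
    waveform = audio("speech.wav", offset=0.0, duration=2.0)

    # a stereo (2, N) numpy buffer is averaged down to a single channel
    stereo = np.random.randn(2, 32000).astype(np.float32)
    mono = Audio.convert_mono(stereo)  # shape (1, 32000)

    # standalone resampling dispatches to librosa for arrays, torchaudio for tensors
    resampled = Audio.resample_audio(torch.from_numpy(mono), sr=32000, target_sr=16000)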

View File

@@ -3,17 +3,16 @@
 import random
 import torch
-def create_unique_rng(epoch:int):
+def create_unique_rng(epoch: int):
     """create unique random number generator for each (worker_id,epoch) combination"""
     rng = random.Random()
-    global_seed = int(os.environ.get("PL_GLOBAL_SEED","0"))
-    global_rank = int(os.environ.get('GLOBAL_RANK',"0"))
-    local_rank = int(os.environ.get('LOCAL_RANK',"0"))
-    node_rank = int(os.environ.get('NODE_RANK',"0"))
-    world_size = int(os.environ.get('WORLD_SIZE',"0"))
+    global_seed = int(os.environ.get("PL_GLOBAL_SEED", "0"))
+    global_rank = int(os.environ.get("GLOBAL_RANK", "0"))
+    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
+    node_rank = int(os.environ.get("NODE_RANK", "0"))
+    world_size = int(os.environ.get("WORLD_SIZE", "0"))
     worker_info = torch.utils.data.get_worker_info()
     if worker_info is not None:
@@ -24,17 +23,13 @@ def create_unique_rng(epoch:int):
         worker_id = 0
     seed = (
         global_seed
         + worker_id
         + local_rank * num_workers
         + node_rank * num_workers * global_rank
         + epoch * num_workers * world_size
     )
     rng.seed(seed)
     return rng
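
To make the seed formula concrete: with global_seed=0, worker_id=1, local_rank=0, node_rank=0, num_workers=4 and world_size=1, epoch 2 seeds the generator with 0 + 1 + 0 + 0 + 2*4*1 = 9 and epoch 3 with 13, so each (worker_id, epoch) pair draws from its own reproducible stream. A quick sketch, assuming create_unique_rng from the file above is in scope:

    # same epoch -> same seed, so the draw is reproducible across runs
    indices_a = create_unique_rng(epoch=5).sample(range(1000), k=8)
    indices_b = create_unique_rng(epoch=5).sample(range(1000), k=8)
    assert indices_a == indices_b

    # a later epoch reseeds, so the draw (almost surely) differs
    indices_c = create_unique_rng(epoch=6).sample(range(1000), k=8)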

View File · enhancer/utils/utils.py

@@ -1,19 +1,24 @@
 import os
 from typing import Optional
 from enhancer.utils.config import Files
-def check_files(root_dir:str, files:Files):
-    path_variables = [member_var for member_var in dir(files) if not member_var.startswith('__')]
+def check_files(root_dir: str, files: Files):
+    path_variables = [
+        member_var
+        for member_var in dir(files)
+        if not member_var.startswith("__")
+    ]
     for variable in path_variables:
-        path = getattr(files,variable)
-        if not os.path.isdir(os.path.join(root_dir,path)):
+        path = getattr(files, variable)
+        if not os.path.isdir(os.path.join(root_dir, path)):
             raise ValueError(f"Invalid {path}, is not a directory")
-    return files,root_dir
-def merge_dict(default_dict:dict, custom:Optional[dict]=None):
+    return files, root_dir
+def merge_dict(default_dict: dict, custom: Optional[dict] = None):
     params = dict(default_dict)
     if custom:
         params.update(custom)
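
merge_dict copies the defaults and layers any user overrides on top, leaving both inputs untouched; the diff is truncated here, but the function presumably ends with return params. A quick sketch under that assumption:

    defaults = {"lr": 1e-3, "batch_size": 32, "num_workers": 4}
    overrides = {"batch_size": 64}

    params = merge_dict(defaults, overrides)  # assumes merge_dict returns params
    print(params["batch_size"])    # 64
    print(defaults["batch_size"])  # 32 -- defaults are not mutated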