Merge pull request #10 from shahules786/dev-reformat

Dev reformat
This commit is contained in:
Shahul ES 2022-10-05 15:23:03 +05:30 committed by GitHub
commit 9989705a60
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 114 additions and 70 deletions

View File

@ -1,3 +1,3 @@
from enhancer.utils.utils import check_files
from enhancer.utils.io import Audio
from enhancer.utils.config import Files
from enhancer.utils.config import Files

View File

@ -1,10 +1,9 @@
from dataclasses import dataclass
@dataclass
class Files:
train_clean : str
train_noisy : str
test_clean : str
test_noisy : str
train_clean: str
train_noisy: str
test_clean: str
test_noisy: str

View File

@ -1,41 +1,66 @@
import os
import librosa
from typing import Optional
from pathlib import Path
from typing import Optional, Union
import numpy as np
import torch
import torchaudio
class Audio:
"""
Audio utils
parameters:
sampling_rate : int, defaults to 16KHz
audio sampling rate
mono: bool, defaults to True
return_tensors: bool, defaults to True
returns torch tensor type if set to True else numpy ndarray
"""
def __init__(
self,
sampling_rate:int=16000,
mono:bool=True,
return_tensor=True
self, sampling_rate: int = 16000, mono: bool = True, return_tensor=True
) -> None:
self.sampling_rate = sampling_rate
self.mono = mono
self.return_tensor = return_tensor
def __call__(
self,
audio,
sampling_rate:Optional[int]=None,
offset:Optional[float] = None,
duration:Optional[float] = None
audio: Union[Path, np.ndarray, torch.Tensor],
sampling_rate: Optional[int] = None,
offset: Optional[float] = None,
duration: Optional[float] = None,
):
if isinstance(audio,str):
"""
read and process input audio
parameters:
audio: Path to audio file or numpy array or torch tensor
single input audio
sampling_rate : int, optional
sampling rate of the audio input
offset: float, optional
offset from which the audio must be read, reads from beginning if unused.
duration: float (seconds), optional
read duration, reads full audio starting from offset if not used
"""
if isinstance(audio, str):
if os.path.exists(audio):
audio,sampling_rate = librosa.load(audio,sr=sampling_rate,mono=False,
offset=offset,duration=duration)
audio, sampling_rate = librosa.load(
audio,
sr=sampling_rate,
mono=False,
offset=offset,
duration=duration,
)
if len(audio.shape) == 1:
audio = audio.reshape(1,-1)
audio = audio.reshape(1, -1)
else:
raise FileNotFoundError(f"File {audio} deos not exist")
elif isinstance(audio,np.ndarray):
elif isinstance(audio, np.ndarray):
if len(audio.shape) == 1:
audio = audio.reshape(1,-1)
audio = audio.reshape(1, -1)
else:
raise ValueError("audio should be either filepath or numpy ndarray")
@ -43,40 +68,60 @@ class Audio:
audio = self.convert_mono(audio)
if sampling_rate:
audio = self.__class__.resample_audio(audio,self.sampling_rate,sampling_rate)
audio = self.__class__.resample_audio(
audio, self.sampling_rate, sampling_rate
)
if self.return_tensor:
return torch.tensor(audio)
else:
return audio
@staticmethod
def convert_mono(
audio
def convert_mono(audio: Union[np.ndarray, torch.Tensor]):
"""
convert input audio into mono (1)
parameters:
audio: np.ndarray or torch.Tensor
"""
if len(audio.shape) > 2:
assert (
audio.shape[0] == 1
), "convert mono only accepts single waveform"
audio = audio.reshape(audio.shape[1], audio.shape[2])
):
if len(audio.shape)>2:
assert audio.shape[0] == 1, "convert mono only accepts single waveform"
audio = audio.reshape(audio.shape[1],audio.shape[2])
assert audio.shape[1] >> audio.shape[0], f"expected input format (num_channels,num_samples) got {audio.shape}"
num_channels,num_samples = audio.shape
if num_channels>1:
return audio.mean(axis=0).reshape(1,num_samples)
assert (
audio.shape[1] >> audio.shape[0]
), f"expected input format (num_channels,num_samples) got {audio.shape}"
num_channels, num_samples = audio.shape
if num_channels > 1:
return audio.mean(axis=0).reshape(1, num_samples)
return audio
@staticmethod
def resample_audio(
audio,
sr:int,
target_sr:int
audio: Union[np.ndarray, torch.Tensor], sr: int, target_sr: int
):
if sr!=target_sr:
if isinstance(audio,np.ndarray):
audio = librosa.resample(audio,orig_sr=sr,target_sr=target_sr)
elif isinstance(audio,torch.Tensor):
audio = torchaudio.functional.resample(audio,orig_freq=sr,new_freq=target_sr)
"""
resample audio to desired sampling rate
parameters:
audio : Path to audio file or numpy array or torch tensor
audio waveform
sr : int
current sampling rate
target_sr : int
target sampling rate
"""
if sr != target_sr:
if isinstance(audio, np.ndarray):
audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sr)
elif isinstance(audio, torch.Tensor):
audio = torchaudio.functional.resample(
audio, orig_freq=sr, new_freq=target_sr
)
else:
raise ValueError("Input should be either numpy array or torch tensor")
raise ValueError(
"Input should be either numpy array or torch tensor"
)
return audio

View File

@ -3,17 +3,16 @@ import random
import torch
def create_unique_rng(epoch:int):
def create_unique_rng(epoch: int):
"""create unique random number generator for each (worker_id,epoch) combination"""
rng = random.Random()
global_seed = int(os.environ.get("PL_GLOBAL_SEED","0"))
global_rank = int(os.environ.get('GLOBAL_RANK',"0"))
local_rank = int(os.environ.get('LOCAL_RANK',"0"))
node_rank = int(os.environ.get('NODE_RANK',"0"))
world_size = int(os.environ.get('WORLD_SIZE',"0"))
global_seed = int(os.environ.get("PL_GLOBAL_SEED", "0"))
global_rank = int(os.environ.get("GLOBAL_RANK", "0"))
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
node_rank = int(os.environ.get("NODE_RANK", "0"))
world_size = int(os.environ.get("WORLD_SIZE", "0"))
worker_info = torch.utils.data.get_worker_info()
if worker_info is not None:
@ -24,17 +23,13 @@ def create_unique_rng(epoch:int):
worker_id = 0
seed = (
global_seed
+ worker_id
+ local_rank * num_workers
+ node_rank * num_workers * global_rank
+ epoch * num_workers * world_size
)
global_seed
+ worker_id
+ local_rank * num_workers
+ node_rank * num_workers * global_rank
+ epoch * num_workers * world_size
)
rng.seed(seed)
return rng

View File

@ -1,19 +1,24 @@
import os
from typing import Optional
from enhancer.utils.config import Files
def check_files(root_dir:str, files:Files):
path_variables = [member_var for member_var in dir(files) if not member_var.startswith('__')]
def check_files(root_dir: str, files: Files):
path_variables = [
member_var
for member_var in dir(files)
if not member_var.startswith("__")
]
for variable in path_variables:
path = getattr(files,variable)
if not os.path.isdir(os.path.join(root_dir,path)):
path = getattr(files, variable)
if not os.path.isdir(os.path.join(root_dir, path)):
raise ValueError(f"Invalid {path}, is not a directory")
return files,root_dir
def merge_dict(default_dict:dict, custom:Optional[dict]=None):
return files, root_dir
def merge_dict(default_dict: dict, custom: Optional[dict] = None):
params = dict(default_dict)
if custom:
params.update(custom)