commit 9989705a60
@@ -1,10 +1,9 @@
 from dataclasses import dataclass


 @dataclass
 class Files:
     train_clean: str
     train_noisy: str
     test_clean: str
     test_noisy: str
-
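For context, a minimal sketch of how this dataclass might be instantiated (the directory names below are hypothetical placeholders, not part of this commit):

    from enhancer.utils.config import Files

    # Hypothetical dataset layout; each field holds a path relative to some root dir.
    files = Files(
        train_clean="train/clean",
        train_noisy="train/noisy",
        test_clean="test/clean",
        test_noisy="test/noisy",
    )
    print(files.train_clean)  # -> "train/clean"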
@@ -1,17 +1,25 @@
 import os
 import librosa
-from typing import Optional
+from pathlib import Path
+from typing import Optional, Union
 import numpy as np
 import torch
 import torchaudio


 class Audio:
+    """
+    Audio utils
+    parameters:
+    sampling_rate : int, defaults to 16KHz
+        audio sampling rate
+    mono: bool, defaults to True
+    return_tensor: bool, defaults to True
+        returns torch tensor type if set to True else numpy ndarray
+    """
+
     def __init__(
-        self,
-        sampling_rate:int=16000,
-        mono:bool=True,
-        return_tensor=True
+        self, sampling_rate: int = 16000, mono: bool = True, return_tensor=True
     ) -> None:

         self.sampling_rate = sampling_rate
@@ -20,15 +28,32 @@ class Audio:

     def __call__(
         self,
-        audio,
+        audio: Union[Path, np.ndarray, torch.Tensor],
         sampling_rate: Optional[int] = None,
         offset: Optional[float] = None,
-        duration:Optional[float] = None
+        duration: Optional[float] = None,
     ):
+        """
+        read and process input audio
+        parameters:
+        audio: Path to audio file or numpy array or torch tensor
+            single input audio
+        sampling_rate : int, optional
+            sampling rate of the audio input
+        offset: float, optional
+            offset from which the audio must be read, reads from beginning if unused.
+        duration: float (seconds), optional
+            read duration, reads full audio starting from offset if not used
+        """
         if isinstance(audio, str):
             if os.path.exists(audio):
-                audio,sampling_rate = librosa.load(audio,sr=sampling_rate,mono=False,
-                                offset=offset,duration=duration)
+                audio, sampling_rate = librosa.load(
+                    audio,
+                    sr=sampling_rate,
+                    mono=False,
+                    offset=offset,
+                    duration=duration,
+                )
                 if len(audio.shape) == 1:
                     audio = audio.reshape(1, -1)
             else:
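A note on the reshape in the hunk above: librosa.load with mono=False returns a 1-D array for single-channel files and a (num_channels, num_samples) array for multi-channel files, so the branch normalizes both cases to 2-D. A self-contained illustration:

    import numpy as np

    # Stand-in for what librosa.load(..., mono=False) returns for a mono file.
    waveform = np.zeros(16000)
    if len(waveform.shape) == 1:
        waveform = waveform.reshape(1, -1)
    assert waveform.shape == (1, 16000)  # (num_channels, num_samples)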
@@ -43,40 +68,60 @@ class Audio:
             audio = self.convert_mono(audio)

         if sampling_rate:
-            audio = self.__class__.resample_audio(audio,self.sampling_rate,sampling_rate)
+            audio = self.__class__.resample_audio(
+                audio, self.sampling_rate, sampling_rate
+            )
         if self.return_tensor:
             return torch.tensor(audio)
         else:
             return audio

     @staticmethod
-    def convert_mono(
-        audio
-    ):
+    def convert_mono(audio: Union[np.ndarray, torch.Tensor]):
+        """
+        convert input audio into mono (1)
+        parameters:
+        audio: np.ndarray or torch.Tensor
+        """
         if len(audio.shape) > 2:
-            assert audio.shape[0] == 1, "convert mono only accepts single waveform"
+            assert (
+                audio.shape[0] == 1
+            ), "convert mono only accepts single waveform"
             audio = audio.reshape(audio.shape[1], audio.shape[2])

-        assert audio.shape[1] >> audio.shape[0], f"expected input format (num_channels,num_samples) got {audio.shape}"
+        assert (
+            audio.shape[1] >> audio.shape[0]
+        ), f"expected input format (num_channels,num_samples) got {audio.shape}"
         num_channels, num_samples = audio.shape
         if num_channels > 1:
             return audio.mean(axis=0).reshape(1, num_samples)
         return audio


     @staticmethod
     def resample_audio(
-        audio,
-        sr:int,
-        target_sr:int
+        audio: Union[np.ndarray, torch.Tensor], sr: int, target_sr: int
     ):
+        """
+        resample audio to desired sampling rate
+        parameters:
+        audio : Path to audio file or numpy array or torch tensor
+            audio waveform
+        sr : int
+            current sampling rate
+        target_sr : int
+            target sampling rate
+
+        """
         if sr != target_sr:
             if isinstance(audio, np.ndarray):
                 audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sr)
             elif isinstance(audio, torch.Tensor):
-                audio = torchaudio.functional.resample(audio,orig_freq=sr,new_freq=target_sr)
+                audio = torchaudio.functional.resample(
+                    audio, orig_freq=sr, new_freq=target_sr
+                )
             else:
-                raise ValueError("Input should be either numpy array or torch tensor")
+                raise ValueError(
+                    "Input should be either numpy array or torch tensor"
+                )

         return audio
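Putting the class together, a hedged usage sketch (the module defining Audio is not named anywhere in this diff, so the import path is an assumption, and sample.wav is a placeholder):

    import numpy as np
    from enhancer.utils.io import Audio  # assumed module path

    audio = Audio(sampling_rate=16000, mono=True, return_tensor=True)

    # File input: loaded with librosa, downmixed via convert_mono,
    # returned as a torch.Tensor.
    waveform = audio("sample.wav")  # placeholder path

    # Array input: supplying sampling_rate routes the waveform
    # through resample_audio.
    stereo = np.random.randn(2, 48000).astype(np.float32)
    waveform = audio(stereo, sampling_rate=48000)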
@@ -3,17 +3,16 @@ import random
 import torch
-


 def create_unique_rng(epoch: int):
     """create unique random number generator for each (worker_id,epoch) combination"""

     rng = random.Random()

     global_seed = int(os.environ.get("PL_GLOBAL_SEED", "0"))
-    global_rank = int(os.environ.get('GLOBAL_RANK',"0"))
-    local_rank = int(os.environ.get('LOCAL_RANK',"0"))
-    node_rank = int(os.environ.get('NODE_RANK',"0"))
-    world_size = int(os.environ.get('WORLD_SIZE',"0"))
+    global_rank = int(os.environ.get("GLOBAL_RANK", "0"))
+    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
+    node_rank = int(os.environ.get("NODE_RANK", "0"))
+    world_size = int(os.environ.get("WORLD_SIZE", "0"))

     worker_info = torch.utils.data.get_worker_info()
     if worker_info is not None:
@@ -34,7 +33,3 @@ def create_unique_rng(epoch:int):
     rng.seed(seed)

     return rng
-
-
-
-
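A hedged usage sketch of create_unique_rng (the import path is an assumption; the seed composition between these two hunks is not shown, but the visible lines seed one random.Random per (worker_id, epoch) from the Lightning environment):

    import os
    from enhancer.utils.random import create_unique_rng  # assumed module path

    os.environ.setdefault("PL_GLOBAL_SEED", "42")

    # One independent generator per epoch: shuffling is reproducible for a
    # given epoch yet differs across epochs (and across ranks/workers when
    # those env vars are set by the launcher).
    rng = create_unique_rng(epoch=3)
    order = list(range(10))
    rng.shuffle(order)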
@@ -1,11 +1,15 @@
 import os
 from typing import Optional
 from enhancer.utils.config import Files


 def check_files(root_dir: str, files: Files):

-    path_variables = [member_var for member_var in dir(files) if not member_var.startswith('__')]
+    path_variables = [
+        member_var
+        for member_var in dir(files)
+        if not member_var.startswith("__")
+    ]
     for variable in path_variables:
         path = getattr(files, variable)
         if not os.path.isdir(os.path.join(root_dir, path)):
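A usage sketch for check_files (hedged: the module path and the root directory are assumptions, and the failure branch inside the if is outside this diff):

    from enhancer.utils.config import Files
    from enhancer.utils.utils import check_files  # assumed module path

    files = Files(
        train_clean="train/clean",
        train_noisy="train/noisy",
        test_clean="test/clean",
        test_noisy="test/noisy",
    )
    # Verifies every path attribute of Files is an existing directory under root_dir.
    files, root_dir = check_files("/data/voicebank", files)  # placeholder root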
@@ -13,6 +17,7 @@ def check_files(root_dir:str, files:Files):
+
     return files, root_dir


 def merge_dict(default_dict: dict, custom: Optional[dict] = None):
     params = dict(default_dict)
     if custom:
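The hunk cuts off inside merge_dict, so the following completion is an assumption consistent with the visible lines (copy the defaults, then let custom keys override them); only the first two body lines are confirmed by this diff:

    from typing import Optional


    def merge_dict(default_dict: dict, custom: Optional[dict] = None):
        params = dict(default_dict)  # copy so the defaults are never mutated
        if custom:
            params.update(custom)  # assumed completion: custom values win
        return params


    assert merge_dict({"lr": 1e-3, "mono": True}, {"lr": 3e-4}) == {
        "lr": 3e-4,
        "mono": True,
    }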