add doc/refactor black

parent d31a6d2ebd
commit 7f00707733
@@ -1,9 +1,7 @@
-from json import load
-import wave
 import numpy as np
 from scipy.signal import get_window
 from scipy.io import wavfile
-from typing import List, Optional, Union
+from typing import Optional, Union
 import torch
 import torch.nn.functional as F
 from pathlib import Path
@@ -11,10 +9,21 @@ from librosa import load as load_audio
 
 from enhancer.utils import Audio
 
 
 class Inference:
+    """
+    contains methods used for inference.
+    """
 
     @staticmethod
     def read_input(audio, sr, model_sr):
+        """
+        read and verify audio input regardless of the input format.
+
+        arguments:
+        audio : audio input
+        sr : sampling rate of input audio
+        model_sr : sampling rate used for model training.
+        """
 
         if isinstance(audio, (np.ndarray, torch.Tensor)):
             assert sr is not None, "Invalid sampling rate!"
@@ -24,11 +33,16 @@ class Inference:
             if not audio.is_file():
                 raise ValueError(f"Input file {audio} does not exist")
             else:
-                audio,sr = load_audio(audio,sr=sr,)
+                audio, sr = load_audio(
+                    audio,
+                    sr=sr,
+                )
         if len(audio.shape) == 1:
             audio = audio.reshape(1, -1)
         else:
-            assert audio.shape[0] == 1, "Enhance inference only supports single waveform"
+            assert (
+                audio.shape[0] == 1
+            ), "Enhance inference only supports single waveform"
 
         waveform = Audio.resample_audio(audio, sr=sr, target_sr=model_sr)
         waveform = Audio.convert_mono(waveform)
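
For orientation, a minimal usage sketch of the refactored read_input (the file name, rates, and array shape below are illustrative assumptions, not from this commit):

    # Hypothetical callers of Inference.read_input (placeholder values).
    # Path-like input: librosa loads the file; sr=None keeps the native rate.
    waveform = Inference.read_input("noisy.wav", sr=None, model_sr=16000)

    # Array input: an in-memory waveform must declare its own sampling rate,
    # otherwise the assert above fires.
    import numpy as np
    x = np.random.randn(1, 32000).astype("float32")  # (channels, samples)
    waveform = Inference.read_input(x, sr=8000, model_sr=16000)

Either way the method hands back a mono waveform resampled to model_sr, per the resample_audio and convert_mono calls at the end of the hunk.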
@@ -38,28 +52,52 @@ class Inference:
         return waveform
 
     @staticmethod
-    def batchify(waveform: torch.Tensor, window_size:int, step_size:Optional[int]=None):
+    def batchify(
+        waveform: torch.Tensor,
+        window_size: int,
+        step_size: Optional[int] = None,
+    ):
         """
-        break input waveform into samples with duration specified.
+        break input waveform into samples with duration specified. (Overlap-add)
+
+        arguments:
+        waveform : audio waveform
+        window_size : window size used for splitting waveform into batches
+        step_size : step size used for splitting waveform into batches
         """
-        assert waveform.ndim == 2, f"Expcted input waveform with 2 dimensions (channels,samples), got {waveform.ndim}"
+        assert (
+            waveform.ndim == 2
+        ), f"Expected input waveform with 2 dimensions (channels, samples), got {waveform.ndim}"
         _, num_samples = waveform.shape
         waveform = waveform.unsqueeze(-1)
         step_size = window_size // 2 if step_size is None else step_size
 
         if num_samples >= window_size:
-            waveform_batch = F.unfold(waveform[None,...], kernel_size=(window_size,1),
-                                      stride=(step_size,1), padding=(window_size,0))
+            waveform_batch = F.unfold(
+                waveform[None, ...],
+                kernel_size=(window_size, 1),
+                stride=(step_size, 1),
+                padding=(window_size, 0),
+            )
             waveform_batch = waveform_batch.permute(2, 0, 1)
 
         return waveform_batch
 
     @staticmethod
-    def aggreagate(data:torch.Tensor,window_size:int,total_frames:int,step_size:Optional[int]=None,
-                   window="hanning",):
+    def aggreagate(
+        data: torch.Tensor,
+        window_size: int,
+        total_frames: int,
+        step_size: Optional[int] = None,
+        window="hanning",
+    ):
         """
-        takes input as tensor outputs aggregated waveform
+        stitch batched waveform into single waveform. (Overlap-add)
+
+        arguments:
+        data : batched waveform
+        window_size : window_size used to batch waveform
+        step_size : step_size used to batch waveform
+        total_frames : total number of frames present in original waveform
+        window : type of window used for overlap-add mechanism.
         """
         num_chunks, n_channels, num_frames = data.shape
         window = get_window(window=window, Nx=data.shape[-1])
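
The F.unfold call in batchify is easier to see on toy sizes: the waveform is reshaped into an Nx1 single-channel "image" and a (window_size, 1) kernel slides down it with stride step_size, so each output column is one overlapping chunk. A standalone sketch under those assumptions (toy numbers, not from the commit):

    # Overlap-add splitting with F.unfold on a tiny waveform.
    import torch
    import torch.nn.functional as F

    waveform = torch.arange(10.0).reshape(1, -1)  # (channels=1, samples=10)
    window_size, step_size = 4, 2
    x = waveform.unsqueeze(-1)                    # (1, 10, 1)
    chunks = F.unfold(
        x[None, ...],                             # (1, 1, 10, 1) as an "image"
        kernel_size=(window_size, 1),
        stride=(step_size, 1),
        padding=(window_size, 0),
    )                                             # (1, window_size, num_chunks)
    chunks = chunks.permute(2, 0, 1)              # (8, 1, 4) here

Note that, as written, batchify only assigns waveform_batch inside the num_samples >= window_size branch, so a shorter input would raise NameError at the return.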
@@ -67,16 +105,27 @@ class Inference:
         data *= window
 
         data = data.permute(1, 2, 0)
-        data = F.fold(data,
-                      (total_frames, 1),
-                      kernel_size=(window_size, 1),
-                      stride=(step_size, 1),
-                      padding=(window_size,0)).squeeze(-1)
+        data = F.fold(
+            data,
+            (total_frames, 1),
+            kernel_size=(window_size, 1),
+            stride=(step_size, 1),
+            padding=(window_size, 0),
+        ).squeeze(-1)
 
         return data.reshape(1, n_channels, -1)
 
     @staticmethod
-    def write_output(waveform:torch.Tensor,filename:Union[str,Path],sr:int):
+    def write_output(
+        waveform: torch.Tensor, filename: Union[str, Path], sr: int
+    ):
+        """
+        write audio output as wav file
+
+        arguments:
+        waveform : audio waveform
+        filename : name of the wave file. Output will be written as cleaned_filename.wav
+        sr : sampling rate
+        """
 
         if isinstance(filename, str):
             filename = Path(filename)
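
F.fold in aggreagate is the inverse of the unfold above: each windowed chunk is summed back in at its original offset (overlap-add). The "hanning" default pairs with the default hop of window_size // 2 because a Hann window at 50% overlap sums to a constant gain away from the edges. Continuing the toy sketch above (chunks, window_size, and step_size as defined there):

    # Overlap-add reconstruction with F.fold.
    from scipy.signal import get_window

    total_frames = 10
    win = torch.from_numpy(get_window("hann", chunks.shape[-1])).float()
    folded = F.fold(
        (chunks * win).permute(1, 2, 0),  # (1, window_size, num_chunks)
        (total_frames, 1),
        kernel_size=(window_size, 1),
        stride=(step_size, 1),
        padding=(window_size, 0),
    ).squeeze(-1)                         # (1, 1, total_frames)
    # Away from the edges the shifted Hann windows sum to 1, so `folded`
    # approximates the original waveform without extra normalization.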
@@ -85,15 +134,28 @@ class Inference:
         else:
             wavfile.write(filename, rate=sr, data=waveform.detach().cpu())
+
+    @staticmethod
+    def prepare_output(
+        waveform: torch.Tensor,
+        model_sampling_rate: int,
+        audio: Union[str, np.ndarray, torch.Tensor],
+        sampling_rate: Optional[int],
+    ):
+        """
+        prepare output audio based on input format
+
+        arguments:
+        waveform : predicted audio waveform
+        model_sampling_rate : sampling rate used to train the model
+        audio : input audio
+        sampling_rate : input audio sampling rate
+        """
+        if isinstance(audio, np.ndarray):
+            waveform = waveform.detach().cpu().numpy()
+
+        if sampling_rate is not None:
+            waveform = Audio.resample_audio(
+                waveform, sr=model_sampling_rate, target_sr=sampling_rate
+            )
+
+        return waveform
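
Putting the new prepare_output together with the rest of the class, a hypothetical end-to-end pass could look like this (the model callable, window size, and 16 kHz rate are assumptions for illustration, not part of the commit):

    # Sketch of a full enhancement pass (placeholders throughout).
    model_sr = 16000
    waveform = Inference.read_input("noisy.wav", sr=None, model_sr=model_sr)
    batch = Inference.batchify(waveform, window_size=16384)
    enhanced = model(batch)  # enhancement model, not part of this diff
    merged = Inference.aggreagate(
        enhanced, window_size=16384, total_frames=waveform.shape[-1]
    )
    # Write a wav file, or return an array/tensor matching the input format:
    Inference.write_output(merged.squeeze(0), "noisy.wav", sr=model_sr)
    out = Inference.prepare_output(merged, model_sr, "noisy.wav", sampling_rate=None)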