Merge branch 'dev' of https://github.com/shahules786/enhancer into dev-hawk

This commit is contained in:
shahules786 2022-10-06 09:54:14 +05:30
commit a064151e2e
41 changed files with 1268 additions and 810 deletions

9
.flake8 Normal file
View File

@@ -0,0 +1,9 @@
[flake8]
per-file-ignores = __init__.py:F401
ignore = E203, E266, E501, W503
# line length is intentionally set to 80 here because black uses Bugbear
# See https://github.com/psf/black/blob/master/README.md#line-length for more details
max-line-length = 80
max-complexity = 18
select = B,C,E,F,W,T4,B9
exclude = tools/kaldi_decoder

43
.pre-commit-config.yaml Normal file
View File

@@ -0,0 +1,43 @@
repos:
# # Clean Notebooks
# - repo: https://github.com/kynan/nbstripout
# rev: master
# hooks:
# - id: nbstripout
# Format Code
- repo: https://github.com/ambv/black
rev: 22.8.0
hooks:
- id: black
# Sort imports
- repo: https://github.com/PyCQA/isort
rev: 5.10.1
hooks:
- id: isort
args: ["--profile", "black"]
- repo: https://gitlab.com/pycqa/flake8
rev: 5.0.4
hooks:
- id: flake8
args: ['--ignore=E203,E501,F811,E712,W503']
# Formatting, Whitespace, etc
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v3.2.0
hooks:
- id: trailing-whitespace
- id: check-added-large-files
args: ['--maxkb=1000']
- id: check-ast
- id: check-json
- id: check-merge-conflict
- id: check-xml
- id: check-yaml
- id: debug-statements
- id: end-of-file-fixer
- id: requirements-txt-fixer
- id: mixed-line-ending
args: ['--fix=no']

View File

@@ -1 +1,6 @@
# enhancer
Enhancer is a PyTorch-based open-source toolkit for speech enhancement, designed to save audio researchers time. It provides easy-to-use pretrained audio enhancement models and facilitates highly customisable model training. Enhancer provides
* Various pretrained models nicely integrated with Hugging Face, which users can select and use without any hassle (see the usage sketch below).
* The ability to train and validate your own custom speech enhancement models in just under 10 lines of code!
* A command line tool that facilitates training of highly customisable speech enhancement models from the terminal itself!
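As a rough sketch of the pretrained-model workflow (based on the Model.from_pretrained and Model.enhance APIs added in this commit; the checkpoint path and input file below are hypothetical):

from enhancer.models import Demucs

# Hypothetical local checkpoint; from_pretrained also accepts remote URLs
# and huggingface.co model identifiers.
model = Demucs.from_pretrained("model/model_1234.ckpt")

# Splits the input into chunks, runs batched inference, and stitches the
# enhanced chunks back together; save_output is expected to write a
# cleaned_<name>.wav next to the input (see Inference.write_output).
enhanced = model.enhance("noisy_speech.wav", save_output=True)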

View File

@ -1,67 +0,0 @@
from genericpath import isfile
import os
from types import MethodType
import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig
from torch.optim.lr_scheduler import ReduceLROnPlateau
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import MLFlowLogger
os.environ["HYDRA_FULL_ERROR"] = "1"
JOB_ID = os.environ.get("SLURM_JOBID","0")
@hydra.main(config_path="train_config",config_name="config")
def main(config: DictConfig):
callbacks = []
logger = MLFlowLogger(experiment_name=config.mlflow.experiment_name,
run_name=config.mlflow.run_name, tags={"JOB_ID":JOB_ID})
parameters = config.hyperparameters
dataset = instantiate(config.dataset)
model = instantiate(config.model,dataset=dataset,lr=parameters.get("lr"),
loss=parameters.get("loss"), metric = parameters.get("metric"))
direction = model.valid_monitor
checkpoint = ModelCheckpoint(
dirpath="./model",filename=f"model_{JOB_ID}",monitor="val_loss",verbose=True,
mode=direction,every_n_epochs=1
)
callbacks.append(checkpoint)
early_stopping = EarlyStopping(
monitor="val_loss",
mode=direction,
min_delta=0.0,
patience=parameters.get("EarlyStopping_patience",10),
strict=True,
verbose=False,
)
callbacks.append(early_stopping)
def configure_optimizer(self):
optimizer = instantiate(config.optimizer,lr=parameters.get("lr"),parameters=self.parameters())
scheduler = ReduceLROnPlateau(
optimizer=optimizer,
mode=direction,
factor=parameters.get("ReduceLr_factor",0.1),
verbose=True,
min_lr=parameters.get("min_lr",1e-6),
patience=parameters.get("ReduceLr_patience",3)
)
return {"optimizer":optimizer, "lr_scheduler":scheduler}
model.configure_parameters = MethodType(configure_optimizer,model)
trainer = instantiate(config.trainer,logger=logger,callbacks=callbacks)
trainer.fit(model)
saved_location = os.path.join(trainer.default_root_dir,"model",f"model_{JOB_ID}.ckpt")
if os.path.isfile(saved_location):
logger.experiment.log_artifact(logger.run_id,saved_location)
if __name__=="__main__":
main()

85
enhancer/cli/train.py Normal file
View File

@@ -0,0 +1,85 @@
import os
from types import MethodType
import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import MLFlowLogger
from torch.optim.lr_scheduler import ReduceLROnPlateau
os.environ["HYDRA_FULL_ERROR"] = "1"
JOB_ID = os.environ.get("SLURM_JOBID", "0")
@hydra.main(config_path="train_config", config_name="config")
def main(config: DictConfig):
callbacks = []
logger = MLFlowLogger(
experiment_name=config.mlflow.experiment_name,
run_name=config.mlflow.run_name,
tags={"JOB_ID": JOB_ID},
)
parameters = config.hyperparameters
dataset = instantiate(config.dataset)
model = instantiate(
config.model,
dataset=dataset,
lr=parameters.get("lr"),
loss=parameters.get("loss"),
metric=parameters.get("metric"),
)
direction = model.valid_monitor
checkpoint = ModelCheckpoint(
dirpath="./model",
filename=f"model_{JOB_ID}",
monitor="val_loss",
verbose=True,
mode=direction,
every_n_epochs=1,
)
callbacks.append(checkpoint)
early_stopping = EarlyStopping(
monitor="val_loss",
mode=direction,
min_delta=0.0,
patience=parameters.get("EarlyStopping_patience", 10),
strict=True,
verbose=False,
)
callbacks.append(early_stopping)
def configure_optimizer(self):
optimizer = instantiate(
config.optimizer,
lr=parameters.get("lr"),
parameters=self.parameters(),
)
scheduler = ReduceLROnPlateau(
optimizer=optimizer,
mode=direction,
factor=parameters.get("ReduceLr_factor", 0.1),
verbose=True,
min_lr=parameters.get("min_lr", 1e-6),
patience=parameters.get("ReduceLr_patience", 3),
)
return {"optimizer": optimizer, "lr_scheduler": scheduler}
model.configure_parameters = MethodType(configure_optimizer, model)
trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks)
trainer.fit(model)
saved_location = os.path.join(
trainer.default_root_dir, "model", f"model_{JOB_ID}.ckpt"
)
if os.path.isfile(saved_location):
logger.experiment.log_artifact(logger.run_id, saved_location)
if __name__ == "__main__":
main()

View File

@@ -10,4 +10,3 @@ files:
  test_clean : clean_test_wav
  train_noisy : clean_test_wav
  test_noisy : clean_test_wav

View File

@@ -10,6 +10,3 @@ files:
  test_clean : clean_testset_wav
  train_noisy : noisy_trainset_28spk_wav
  test_noisy : noisy_testset_wav

View File

@@ -0,0 +1,13 @@
_target_: enhancer.data.dataset.EnhancerDataset
name : vctk
root_dir : /Users/shahules/Myprojects/enhancer/datasets/vctk
duration : 1.0
sampling_rate: 16000
batch_size: 64
num_workers : 0
files:
  train_clean : clean_testset_wav
  test_clean : clean_testset_wav
  train_noisy : noisy_testset_wav
  test_noisy : noisy_testset_wav
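For reference, hydra's instantiate resolves the _target_ above into an EnhancerDataset; a minimal hand-written equivalent might look like this (a sketch, assuming the Files dataclass takes the four folder names shown as keyword fields):

from enhancer.data import EnhancerDataset
from enhancer.utils.config import Files

# Mirrors the YAML config above.
files = Files(
    train_clean="clean_testset_wav",
    test_clean="clean_testset_wav",
    train_noisy="noisy_testset_wav",
    test_noisy="noisy_testset_wav",
)
dataset = EnhancerDataset(
    name="vctk",
    root_dir="/Users/shahules/Myprojects/enhancer/datasets/vctk",
    files=files,
    duration=1.0,
    sampling_rate=16000,
    batch_size=64,
    num_workers=0,
)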

View File

@@ -5,4 +5,3 @@ ReduceLr_patience : 5
ReduceLr_factor : 0.1
min_lr : 0.000001
EarlyStopping_factor : 10

View File

@@ -14,5 +14,3 @@ encoder_decoder:
lstm:
  bidirectional: False
  num_layers: 2

View File

@@ -0,0 +1 @@
from enhancer.data.dataset import EnhancerDataset

View File

@@ -1,20 +1,21 @@
import math
import multiprocessing
import os
from typing import Optional

import pytorch_lightning as pl
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, IterableDataset

from enhancer.data.fileprocessor import Fileprocessor
from enhancer.utils import check_files
from enhancer.utils.config import Files
from enhancer.utils.io import Audio
from enhancer.utils.random import create_unique_rng


class TrainDataset(IterableDataset):
    def __init__(self, dataset):
        self.dataset = dataset

    def __iter__(self):
@@ -23,72 +24,86 @@ class TrainDataset(IterableDataset):
    def __len__(self):
        return self.dataset.train__len__()


class ValidDataset(Dataset):
    def __init__(self, dataset):
        self.dataset = dataset

    def __getitem__(self, idx):
        return self.dataset.val__getitem__(idx)

    def __len__(self):
        return self.dataset.val__len__()


class TaskDataset(pl.LightningDataModule):
    def __init__(
        self,
        name: str,
        root_dir: str,
        files: Files,
        duration: float = 1.0,
        sampling_rate: int = 48000,
        matching_function=None,
        batch_size=32,
        num_workers: Optional[int] = None,
    ):
        super().__init__()
        self.name = name
        self.files, self.root_dir = check_files(root_dir, files)
        self.duration = duration
        self.sampling_rate = sampling_rate
        self.batch_size = batch_size
        self.matching_function = matching_function
        self._validation = []
        if num_workers is None:
            num_workers = multiprocessing.cpu_count() // 2
        self.num_workers = num_workers

    def setup(self, stage: Optional[str] = None):
        if stage in ("fit", None):
            train_clean = os.path.join(self.root_dir, self.files.train_clean)
            train_noisy = os.path.join(self.root_dir, self.files.train_noisy)
            fp = Fileprocessor.from_name(
                self.name, train_clean, train_noisy, self.matching_function
            )
            self.train_data = fp.prepare_matching_dict()

            val_clean = os.path.join(self.root_dir, self.files.test_clean)
            val_noisy = os.path.join(self.root_dir, self.files.test_noisy)
            fp = Fileprocessor.from_name(
                self.name, val_clean, val_noisy, self.matching_function
            )
            val_data = fp.prepare_matching_dict()

            for item in val_data:
                clean, noisy, total_dur = item.values()
                if total_dur < self.duration:
                    continue
                num_segments = round(total_dur / self.duration)
                for index in range(num_segments):
                    start_time = index * self.duration
                    self._validation.append(
                        ({"clean": clean, "noisy": noisy}, start_time)
                    )

    def train_dataloader(self):
        return DataLoader(
            TrainDataset(self),
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        return DataLoader(
            ValidDataset(self),
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )


class EnhancerDataset(TaskDataset):
    """
@@ -100,7 +115,7 @@ class EnhancerDataset(TaskDataset):
        root directory of the dataset containing clean/noisy folders
    files : Files
        dataclass containing train_clean, train_noisy, test_clean, test_noisy
        folder names (refer enhancer.utils.Files dataclass)
    duration : float
        expected audio duration of single audio sample for training
    sampling_rate : int
@@ -119,14 +134,15 @@ class EnhancerDataset(TaskDataset):
    def __init__(
        self,
        name: str,
        root_dir: str,
        files: Files,
        duration=1.0,
        sampling_rate=48000,
        matching_function=None,
        batch_size=32,
        num_workers: Optional[int] = None,
    ):

        super().__init__(
            name=name,
@@ -134,18 +150,17 @@ class EnhancerDataset(TaskDataset):
            files=files,
            sampling_rate=sampling_rate,
            duration=duration,
            matching_function=matching_function,
            batch_size=batch_size,
            num_workers=num_workers,
        )

        self.sampling_rate = sampling_rate
        self.files = files
        self.duration = max(1.0, duration)
        self.audio = Audio(self.sampling_rate, mono=True, return_tensor=True)

    def setup(self, stage: Optional[str] = None):

        super().setup(stage=stage)
@@ -155,30 +170,51 @@ class EnhancerDataset(TaskDataset):
        while True:

            file_dict, *_ = rng.choices(
                self.train_data,
                k=1,
                weights=[file["duration"] for file in self.train_data],
            )
            file_duration = file_dict["duration"]
            start_time = round(rng.uniform(0, file_duration - self.duration), 2)
            data = self.prepare_segment(file_dict, start_time)
            yield data

    def val__getitem__(self, idx):
        return self.prepare_segment(*self._validation[idx])

    def prepare_segment(self, file_dict: dict, start_time: float):
        clean_segment = self.audio(
            file_dict["clean"], offset=start_time, duration=self.duration
        )
        noisy_segment = self.audio(
            file_dict["noisy"], offset=start_time, duration=self.duration
        )
        clean_segment = F.pad(
            clean_segment,
            (
                0,
                int(
                    self.duration * self.sampling_rate - clean_segment.shape[-1]
                ),
            ),
        )
        noisy_segment = F.pad(
            noisy_segment,
            (
                0,
                int(
                    self.duration * self.sampling_rate - noisy_segment.shape[-1]
                ),
            ),
        )
        return {"clean": clean_segment, "noisy": noisy_segment}

    def train__len__(self):
        return math.ceil(
            sum([file["duration"] for file in self.train_data]) / self.duration
        )

    def val__len__(self):
        return len(self._validation)

View File

@@ -1,108 +1,118 @@
import glob
import os

import numpy as np
from scipy.io import wavfile

MATCHING_FNS = ("one_to_one", "one_to_many")


class ProcessorFunctions:
    """
    Preprocessing methods for different types of speech enhancement datasets.
    """

    @staticmethod
    def one_to_one(clean_path, noisy_path):
        """
        One clean audio can have only one noisy audio file
        """
        matching_wavfiles = list()
        clean_filenames = [
            file.split("/")[-1]
            for file in glob.glob(os.path.join(clean_path, "*.wav"))
        ]
        noisy_filenames = [
            file.split("/")[-1]
            for file in glob.glob(os.path.join(noisy_path, "*.wav"))
        ]
        common_filenames = np.intersect1d(noisy_filenames, clean_filenames)
        for file_name in common_filenames:
            sr_clean, clean_file = wavfile.read(
                os.path.join(clean_path, file_name)
            )
            sr_noisy, noisy_file = wavfile.read(
                os.path.join(noisy_path, file_name)
            )
            if (clean_file.shape[-1] == noisy_file.shape[-1]) and (
                sr_clean == sr_noisy
            ):
                matching_wavfiles.append(
                    {
                        "clean": os.path.join(clean_path, file_name),
                        "noisy": os.path.join(noisy_path, file_name),
                        "duration": clean_file.shape[-1] / sr_clean,
                    }
                )

        return matching_wavfiles

    @staticmethod
    def one_to_many(clean_path, noisy_path):
        """
        One clean audio can have multiple noisy audio files
        """
        matching_wavfiles = dict()
        clean_filenames = [
            file.split("/")[-1]
            for file in glob.glob(os.path.join(clean_path, "*.wav"))
        ]
        for clean_file in clean_filenames:
            noisy_filenames = glob.glob(
                os.path.join(noisy_path, f"*_{clean_file}.wav")
            )
            for noisy_file in noisy_filenames:
                sr_clean, clean_file = wavfile.read(
                    os.path.join(clean_path, clean_file)
                )
                sr_noisy, noisy_file = wavfile.read(noisy_file)
                if (clean_file.shape[-1] == noisy_file.shape[-1]) and (
                    sr_clean == sr_noisy
                ):
                    matching_wavfiles.update(
                        {
                            "clean": os.path.join(clean_path, clean_file),
                            "noisy": noisy_file,
                            "duration": clean_file.shape[-1] / sr_clean,
                        }
                    )

        return matching_wavfiles


class Fileprocessor:
    def __init__(self, clean_dir, noisy_dir, matching_function=None):
        self.clean_dir = clean_dir
        self.noisy_dir = noisy_dir
        self.matching_function = matching_function

    @classmethod
    def from_name(cls, name: str, clean_dir, noisy_dir, matching_function=None):

        if matching_function is None:
            if name.lower() == "vctk":
                return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_one)
            elif name.lower() == "dns-2020":
                return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_many)
        else:
            if matching_function not in MATCHING_FNS:
                raise ValueError(
                    f"Invalid matching function! Available options are {MATCHING_FNS}"
                )
            else:
                return cls(
                    clean_dir,
                    noisy_dir,
                    getattr(ProcessorFunctions, matching_function),
                )

    def prepare_matching_dict(self):

        if self.matching_function is None:
            raise ValueError("Not a valid matching function")

        return self.matching_function(self.clean_dir, self.noisy_dir)
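For instance, the VCTK layout resolves to the one_to_one matcher when no matching function is given; a small sketch (directory paths hypothetical):

from enhancer.data.fileprocessor import Fileprocessor

# "vctk" maps to ProcessorFunctions.one_to_one by default.
fp = Fileprocessor.from_name(
    "vctk",
    "/data/vctk/clean_testset_wav",
    "/data/vctk/noisy_testset_wav",
)
matches = fp.prepare_matching_dict()
# Each entry: {"clean": path, "noisy": path, "duration": seconds}
print(len(matches))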

View File

@@ -1,119 +1,168 @@
from pathlib import Path
from typing import Optional, Union

import numpy as np
import torch
import torch.nn.functional as F
from librosa import load as load_audio
from scipy.io import wavfile
from scipy.signal import get_window

from enhancer.utils import Audio


class Inference:
    """
    contains methods used for inference.
    """

    @staticmethod
    def read_input(audio, sr, model_sr):
        """
        read and verify audio input regardless of the input format.
        arguments:
            audio : audio input
            sr : sampling rate of input audio
            model_sr : sampling rate used for model training.
        """

        if isinstance(audio, (np.ndarray, torch.Tensor)):
            assert sr is not None, "Invalid sampling rate!"
            if len(audio.shape) == 1:
                audio = audio.reshape(1, -1)

        if isinstance(audio, str):
            audio = Path(audio)
            if not audio.is_file():
                raise ValueError(f"Input file {audio} does not exist")
            else:
                audio, sr = load_audio(
                    audio,
                    sr=sr,
                )
                if len(audio.shape) == 1:
                    audio = audio.reshape(1, -1)
        else:
            assert (
                audio.shape[0] == 1
            ), "Enhance inference only supports single waveform"

        waveform = Audio.resample_audio(audio, sr=sr, target_sr=model_sr)
        waveform = Audio.convert_mono(waveform)
        if isinstance(waveform, np.ndarray):
            waveform = torch.from_numpy(waveform)
        return waveform

    @staticmethod
    def batchify(
        waveform: torch.Tensor,
        window_size: int,
        step_size: Optional[int] = None,
    ):
        """
        break input waveform into samples with duration specified. (Overlap-add)
        arguments:
            waveform : audio waveform
            window_size : window size used for splitting waveform into batches
            step_size : step_size used for splitting waveform into batches
        """
        assert (
            waveform.ndim == 2
        ), f"Expected input waveform with 2 dimensions (channels,samples), got {waveform.ndim}"
        _, num_samples = waveform.shape
        waveform = waveform.unsqueeze(-1)
        step_size = window_size // 2 if step_size is None else step_size
        if num_samples >= window_size:
            waveform_batch = F.unfold(
                waveform[None, ...],
                kernel_size=(window_size, 1),
                stride=(step_size, 1),
                padding=(window_size, 0),
            )
            waveform_batch = waveform_batch.permute(2, 0, 1)

        return waveform_batch

    @staticmethod
    def aggreagate(
        data: torch.Tensor,
        window_size: int,
        total_frames: int,
        step_size: Optional[int] = None,
        window="hanning",
    ):
        """
        stitch batched waveform into single waveform. (Overlap-add)
        arguments:
            data: batched waveform
            window_size : window_size used to batch waveform
            step_size : step_size used to batch waveform
            total_frames : total number of frames present in original waveform
            window : type of window used for overlap-add mechanism.
        """
        num_chunks, n_channels, num_frames = data.shape
        window = get_window(window=window, Nx=data.shape[-1])
        window = torch.from_numpy(window).to(data.device)
        data *= window
        step_size = window_size // 2 if step_size is None else step_size
        data = data.permute(1, 2, 0)
        data = F.fold(
            data,
            (total_frames, 1),
            kernel_size=(window_size, 1),
            stride=(step_size, 1),
            padding=(window_size, 0),
        ).squeeze(-1)

        return data.reshape(1, n_channels, -1)

    @staticmethod
    def write_output(
        waveform: torch.Tensor, filename: Union[str, Path], sr: int
    ):
        """
        write audio output as wav file
        arguments:
            waveform : audio waveform
            filename : name of the wave file. Output will be written as cleaned_filename.wav
            sr : sampling rate
        """

        if isinstance(filename, str):
            filename = Path(filename)

        parent, name = filename.parent, "cleaned_" + filename.name
        filename = parent / Path(name)
        if filename.is_file():
            raise FileExistsError(f"file {filename} already exists")
        else:
            wavfile.write(filename, rate=sr, data=waveform.detach().cpu())

    @staticmethod
    def prepare_output(
        waveform: torch.Tensor,
        model_sampling_rate: int,
        audio: Union[str, np.ndarray, torch.Tensor],
        sampling_rate: Optional[int],
    ):
        """
        prepare output audio based on input format
        arguments:
            waveform : predicted audio waveform
            model_sampling_rate : sampling rate used to train the model
            audio : input audio
            sampling_rate : input audio sampling rate
        """

        if isinstance(audio, np.ndarray):
            waveform = waveform.detach().cpu().numpy()

        if sampling_rate is not None:
            waveform = Audio.resample_audio(
                waveform, sr=model_sampling_rate, target_sr=sampling_rate
            )

        return waveform
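A quick sanity check of read_input with an in-memory array (a sketch; the sample rates are arbitrary and the package from this commit is assumed installed):

import numpy as np

from enhancer.inference import Inference

# Two seconds of synthetic 48 kHz audio, resampled for a 16 kHz model.
audio = np.random.randn(1, 96000).astype(np.float32)
waveform = Inference.read_input(audio, sr=48000, model_sr=16000)
print(type(waveform), waveform.shape)  # torch.Tensor, shape (1, 32000)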

View File

@@ -3,62 +3,82 @@ import torch.nn as nn
class mean_squared_error(nn.Module):
    """
    Mean squared error / L2 loss
    """

    def __init__(self, reduction="mean"):
        super().__init__()
        self.loss_fun = nn.MSELoss(reduction=reduction)
        self.higher_better = False

    def forward(self, prediction: torch.Tensor, target: torch.Tensor):
        if prediction.size() != target.size() or target.ndim < 3:
            raise TypeError(
                f"""Inputs must be of the same shape (batch_size,channels,samples)
                got {prediction.size()} and {target.size()} instead"""
            )

        return self.loss_fun(prediction, target)


class mean_absolute_error(nn.Module):
    """
    Mean absolute error / L1 loss
    """

    def __init__(self, reduction="mean"):
        super().__init__()
        self.loss_fun = nn.L1Loss(reduction=reduction)
        self.higher_better = False

    def forward(self, prediction: torch.Tensor, target: torch.Tensor):
        if prediction.size() != target.size() or target.ndim < 3:
            raise TypeError(
                f"""Inputs must be of the same shape (batch_size,channels,samples)
                got {prediction.size()} and {target.size()} instead"""
            )

        return self.loss_fun(prediction, target)


class Si_SDR(nn.Module):
    """
    SI-SDR metric based on SDR - HALF-BAKED OR WELL DONE? (https://arxiv.org/pdf/1811.02508.pdf)
    """

    def __init__(self, reduction: str = "mean"):
        super().__init__()
        if reduction in ["sum", "mean", None]:
            self.reduction = reduction
        else:
            raise TypeError(
                "Invalid reduction, valid options are sum, mean, None"
            )
        self.higher_better = False

    def forward(self, prediction: torch.Tensor, target: torch.Tensor):
        if prediction.size() != target.size() or target.ndim < 3:
            raise TypeError(
                f"""Inputs must be of the same shape (batch_size,channels,samples)
                got {prediction.size()} and {target.size()} instead"""
            )

        target_energy = torch.sum(target**2, keepdim=True, dim=-1)
        scaling_factor = (
            torch.sum(prediction * target, keepdim=True, dim=-1) / target_energy
        )
        target_projection = target * scaling_factor
        noise = prediction - target_projection
        ratio = torch.sum(target_projection**2, dim=-1) / torch.sum(
            noise**2, dim=-1
        )
        si_sdr = 10 * torch.log10(ratio).mean(dim=-1)

        if self.reduction == "sum":
            si_sdr = si_sdr.sum()
@@ -70,31 +90,42 @@ class Si_SDR(nn.Module):
        return si_sdr


class Avergeloss(nn.Module):
    """
    Combine multiple metrics of the same nature,
    for example ["mse","mae"]
    parameters:
        losses : loss function names to be combined
    """

    def __init__(self, losses):
        super().__init__()

        self.valid_losses = nn.ModuleList()
        direction = [
            getattr(LOSS_MAP[loss](), "higher_better") for loss in losses
        ]
        if len(set(direction)) > 1:
            raise ValueError(
                "all cost functions should be of same nature, maximize or minimize!"
            )
        self.higher_better = direction[0]
        for loss in losses:
            loss = self.validate_loss(loss)
            self.valid_losses.append(loss())

    def validate_loss(self, loss: str):
        if loss not in LOSS_MAP.keys():
            raise ValueError(
                f"""Invalid loss function {loss}, available loss functions are
                {tuple([loss for loss in LOSS_MAP.keys()])}"""
            )
        else:
            return LOSS_MAP[loss]

    def forward(self, prediction: torch.Tensor, target: torch.Tensor):

        loss = 0.0
        for loss_fun in self.valid_losses:
            loss += loss_fun(prediction, target)
@@ -102,10 +133,8 @@ class Avergeloss(nn.Module):
        return loss


LOSS_MAP = {
    "mae": mean_absolute_error,
    "mse": mean_squared_error,
    "SI-SDR": Si_SDR,
}
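To illustrate how loss names are combined (a sketch; both criteria below are minimised, so their directions agree and Avergeloss simply sums them):

import torch

from enhancer.loss import Avergeloss

criterion = Avergeloss(["mae", "mse"])

# Inputs must be (batch_size, channels, samples), as enforced in forward().
prediction = torch.randn(4, 1, 16000)
target = torch.randn(4, 1, 16000)
print(criterion(prediction, target))  # scalar tensor: L1 loss + L2 loss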

View File

@@ -1,3 +1,3 @@
from enhancer.models.demucs import Demucs
from enhancer.models.model import Model
from enhancer.models.waveunet import WaveUnet

View File

@@ -1,136 +1,173 @@
import logging
import math
from typing import List, Optional, Union

import torch.nn.functional as F
from torch import nn

from enhancer.data.dataset import EnhancerDataset
from enhancer.models.model import Model
from enhancer.utils.io import Audio as audio
from enhancer.utils.utils import merge_dict


class DemucsLSTM(nn.Module):
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int,
        bidirectional: bool = True,
    ):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size, hidden_size, num_layers, bidirectional=bidirectional
        )
        dim = 2 if bidirectional else 1
        self.linear = nn.Linear(dim * hidden_size, hidden_size)

    def forward(self, x):
        output, (h, c) = self.lstm(x)
        output = self.linear(output)
        return output, (h, c)


class DemucsEncoder(nn.Module):
    def __init__(
        self,
        num_channels: int,
        hidden_size: int,
        kernel_size: int,
        stride: int = 1,
        glu: bool = False,
    ):
        super().__init__()
        activation = nn.GLU(1) if glu else nn.ReLU()
        multi_factor = 2 if glu else 1
        self.encoder = nn.Sequential(
            nn.Conv1d(num_channels, hidden_size, kernel_size, stride),
            nn.ReLU(),
            nn.Conv1d(hidden_size, hidden_size * multi_factor, kernel_size, 1),
            activation,
        )

    def forward(self, waveform):
        return self.encoder(waveform)


class DemucsDecoder(nn.Module):
    def __init__(
        self,
        num_channels: int,
        hidden_size: int,
        kernel_size: int,
        stride: int = 1,
        glu: bool = False,
        layer: int = 0,
    ):
        super().__init__()
        activation = nn.GLU(1) if glu else nn.ReLU()
        multi_factor = 2 if glu else 1
        self.decoder = nn.Sequential(
            nn.Conv1d(hidden_size, hidden_size * multi_factor, kernel_size, 1),
            activation,
            nn.ConvTranspose1d(hidden_size, num_channels, kernel_size, stride),
        )
        if layer > 0:
            self.decoder.add_module("4", nn.ReLU())

    def forward(
        self,
        waveform,
    ):
        out = self.decoder(waveform)
        return out


class Demucs(Model):
    """
    Demucs model from https://arxiv.org/pdf/1911.13254.pdf
    parameters:
        encoder_decoder : dict, optional
            keyword arguments passed to the encoder decoder block
        lstm : dict, optional
            keyword arguments passed to the LSTM block
        num_channels : int, defaults to 1
            number of channels in input audio
        sampling_rate : int, defaults to 16 kHz
            sampling rate of input audio
        lr : float, defaults to 1e-3
            learning rate used for training
        dataset : EnhancerDataset, optional
            EnhancerDataset object containing train/validation data for training
        duration : float, optional
            chunk duration in seconds
        loss : string or list of strings
            loss function to be used, available ("mse","mae","SI-SDR")
        metric : string or list of strings
            metric function to be used, available ("mse","mae","SI-SDR")
    """

    ED_DEFAULTS = {
        "initial_output_channels": 48,
        "kernel_size": 8,
        "stride": 1,
        "depth": 5,
        "glu": True,
        "growth_factor": 2,
    }
    LSTM_DEFAULTS = {
        "bidirectional": True,
        "num_layers": 2,
    }

    def __init__(
        self,
        encoder_decoder: Optional[dict] = None,
        lstm: Optional[dict] = None,
        num_channels: int = 1,
        resample: int = 4,
        sampling_rate=16000,
        lr: float = 1e-3,
        dataset: Optional[EnhancerDataset] = None,
        loss: Union[str, List] = "mse",
        metric: Union[str, List] = "mse",
    ):
        duration = (
            dataset.duration if isinstance(dataset, EnhancerDataset) else None
        )
        if dataset is not None:
            if sampling_rate != dataset.sampling_rate:
                logging.warn(
                    f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
                )
                sampling_rate = dataset.sampling_rate
        super().__init__(
            num_channels=num_channels,
            sampling_rate=sampling_rate,
            lr=lr,
            dataset=dataset,
            duration=duration,
            loss=loss,
            metric=metric,
        )

        encoder_decoder = merge_dict(self.ED_DEFAULTS, encoder_decoder)
        lstm = merge_dict(self.LSTM_DEFAULTS, lstm)
        self.save_hyperparameters("encoder_decoder", "lstm", "resample")
        hidden = encoder_decoder["initial_output_channels"]
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        for layer in range(encoder_decoder["depth"]):
            encoder_layer = DemucsEncoder(
                num_channels=num_channels,
                hidden_size=hidden,
                kernel_size=encoder_decoder["kernel_size"],
                stride=encoder_decoder["stride"],
@@ -138,80 +175,90 @@ class Demucs(Model):
            )
            self.encoder.append(encoder_layer)

            decoder_layer = DemucsDecoder(
                num_channels=num_channels,
                hidden_size=hidden,
                kernel_size=encoder_decoder["kernel_size"],
                stride=1,
                glu=encoder_decoder["glu"],
                layer=layer,
            )
            self.decoder.insert(0, decoder_layer)

            num_channels = hidden
            hidden = self.ED_DEFAULTS["growth_factor"] * hidden

        self.de_lstm = DemucsLSTM(
            input_size=num_channels,
            hidden_size=num_channels,
            num_layers=lstm["num_layers"],
            bidirectional=lstm["bidirectional"],
        )

    def forward(self, waveform):
        if waveform.dim() == 2:
            waveform = waveform.unsqueeze(1)

        if waveform.size(1) != 1:
            raise TypeError(
                f"Demucs can only process mono channel audio, input has {waveform.size(1)} channels"
            )
        length = waveform.shape[-1]
        x = F.pad(waveform, (0, self.get_padding_length(length) - length))
        if self.hparams.resample > 1:
            x = audio.resample_audio(
                audio=x,
                sr=self.hparams.sampling_rate,
                target_sr=int(
                    self.hparams.sampling_rate * self.hparams.resample
                ),
            )

        encoder_outputs = []
        for encoder in self.encoder:
            x = encoder(x)
            encoder_outputs.append(x)
        x = x.permute(0, 2, 1)
        x, _ = self.de_lstm(x)

        x = x.permute(0, 2, 1)
        for decoder in self.decoder:
            skip_connection = encoder_outputs.pop(-1)
            x += skip_connection[..., : x.shape[-1]]
            x = decoder(x)

        if self.hparams.resample > 1:
            x = audio.resample_audio(
                x,
                int(self.hparams.sampling_rate * self.hparams.resample),
                self.hparams.sampling_rate,
            )

        return x

    def get_padding_length(self, input_length):
        input_length = math.ceil(input_length * self.hparams.resample)

        for layer in range(
            self.hparams.encoder_decoder["depth"]
        ):  # encoder operation
            input_length = (
                math.ceil(
                    (input_length - self.hparams.encoder_decoder["kernel_size"])
                    / self.hparams.encoder_decoder["stride"]
                )
                + 1
            )
            input_length = max(1, input_length)
        for layer in range(
            self.hparams.encoder_decoder["depth"]
        ):  # decoder operation
            input_length = (input_length - 1) * self.hparams.encoder_decoder[
                "stride"
            ] + self.hparams.encoder_decoder["kernel_size"]

        input_length = math.ceil(input_length / self.hparams.resample)
        return int(input_length)
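A minimal smoke test of the architecture above (a sketch: untrained weights, default hyperparameters, random mono input; the output length can differ slightly from the input because of padding):

import torch

from enhancer.models import Demucs

model = Demucs(sampling_rate=16000)
model.eval()

with torch.no_grad():
    noisy = torch.randn(1, 1, 16000)  # (batch, channels, samples), 1 s
    enhanced = model(noisy)
print(enhanced.shape)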

View File

@ -1,50 +1,67 @@
try:
from functools import cached_property
except ImportError:
from backports.cached_property import cached_property
from importlib import import_module
from huggingface_hub import cached_download, hf_hub_url
import logging
import numpy as np
import os import os
from typing import Optional, Union, List, Text, Dict, Any from importlib import import_module
from torch.optim import Adam
import torch
from torch.nn.functional import pad
import pytorch_lightning as pl
from pytorch_lightning.utilities.cloud_io import load as pl_load
from urllib.parse import urlparse
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Text, Union
from urllib.parse import urlparse
import numpy as np
import pytorch_lightning as pl
import torch
from huggingface_hub import cached_download, hf_hub_url
from pytorch_lightning.utilities.cloud_io import load as pl_load
from torch.optim import Adam
from enhancer import __version__ from enhancer import __version__
from enhancer.data.dataset import EnhancerDataset from enhancer.data.dataset import EnhancerDataset
from enhancer.utils.io import Audio
from enhancer.loss import Avergeloss
from enhancer.inference import Inference from enhancer.inference import Inference
from enhancer.loss import Avergeloss
CACHE_DIR = "" CACHE_DIR = ""
HF_TORCH_WEIGHTS = "" HF_TORCH_WEIGHTS = ""
DEFAULT_DEVICE = "cpu" DEFAULT_DEVICE = "cpu"
class Model(pl.LightningModule): class Model(pl.LightningModule):
"""
Base class for all models
parameters:
num_channels: int, default to 1
number of channels in input audio
sampling_rate : int, default 16khz
audio sampling rate
lr: float, optional
learning rate for model training
dataset: EnhancerDataset, optional
Enhancer dataset used for training/validation
duration: float, optional
duration used for training/inference
loss : string or List of strings, default to "mse"
loss functions to be used. Available ("mse","mae","Si-SDR")
"""
def __init__( def __init__(
self, self,
num_channels:int=1, num_channels: int = 1,
sampling_rate:int=16000, sampling_rate: int = 16000,
lr:float=1e-3, lr: float = 1e-3,
dataset:Optional[EnhancerDataset]=None, dataset: Optional[EnhancerDataset] = None,
duration:Optional[float]=None, duration: Optional[float] = None,
loss: Union[str, List] = "mse", loss: Union[str, List] = "mse",
metric:Union[str,List] = "mse" metric: Union[str, List] = "mse",
): ):
super().__init__() super().__init__()
assert num_channels ==1 , "Enhancer only support for mono channel models" assert (
num_channels == 1
), "Enhancer only support for mono channel models"
self.dataset = dataset self.dataset = dataset
self.save_hyperparameters("num_channels","sampling_rate","lr","loss","metric","duration") self.save_hyperparameters(
"num_channels", "sampling_rate", "lr", "loss", "metric", "duration"
)
if self.logger: if self.logger:
self.logger.experiment.log_dict(dict(self.hparams),"hyperparameters.json") self.logger.experiment.log_dict(
dict(self.hparams), "hyperparameters.json"
)
self.loss = loss self.loss = loss
self.metric = metric self.metric = metric
@ -54,9 +71,9 @@ class Model(pl.LightningModule):
return self._loss return self._loss
@loss.setter @loss.setter
def loss(self,loss): def loss(self, loss):
if isinstance(loss,str): if isinstance(loss, str):
losses = [loss] losses = [loss]
self._loss = Avergeloss(losses) self._loss = Avergeloss(losses)
@ -66,23 +83,22 @@ class Model(pl.LightningModule):
return self._metric return self._metric
@metric.setter @metric.setter
def metric(self,metric): def metric(self, metric):
if isinstance(metric,str): if isinstance(metric, str):
metric = [metric] metric = [metric]
self._metric = Avergeloss(metric) self._metric = Avergeloss(metric)
@property @property
def dataset(self): def dataset(self):
return self._dataset return self._dataset
@dataset.setter @dataset.setter
def dataset(self,dataset): def dataset(self, dataset):
self._dataset = dataset self._dataset = dataset
def setup(self,stage:Optional[str]=None): def setup(self, stage: Optional[str] = None):
if stage == "fit": if stage == "fit":
self.dataset.setup(stage) self.dataset.setup(stage)
self.dataset.model = self self.dataset.model = self
@ -94,9 +110,9 @@ class Model(pl.LightningModule):
return self.dataset.val_dataloader() return self.dataset.val_dataloader()
def configure_optimizers(self): def configure_optimizers(self):
return Adam(self.parameters(), lr = self.hparams.lr) return Adam(self.parameters(), lr=self.hparams.lr)
def training_step(self,batch, batch_idx:int): def training_step(self, batch, batch_idx: int):
mixed_waveform = batch["noisy"] mixed_waveform = batch["noisy"]
target = batch["clean"] target = batch["clean"]
@ -105,13 +121,16 @@ class Model(pl.LightningModule):
loss = self.loss(prediction, target) loss = self.loss(prediction, target)
if self.logger: if self.logger:
self.logger.experiment.log_metric(run_id=self.logger.run_id, self.logger.experiment.log_metric(
key="train_loss", value=loss.item(), run_id=self.logger.run_id,
step=self.global_step) key="train_loss",
self.log("train_loss",loss.item()) value=loss.item(),
return {"loss":loss} step=self.global_step,
)
self.log("train_loss", loss.item())
return {"loss": loss}
def validation_step(self,batch,batch_idx:int): def validation_step(self, batch, batch_idx: int):
mixed_waveform = batch["noisy"] mixed_waveform = batch["noisy"]
target = batch["clean"] target = batch["clean"]
@ -119,48 +138,92 @@ class Model(pl.LightningModule):
metric_val = self.metric(prediction, target) metric_val = self.metric(prediction, target)
loss_val = self.loss(prediction, target) loss_val = self.loss(prediction, target)
self.log("val_metric",metric_val.item()) self.log("val_metric", metric_val.item())
self.log("val_loss",loss_val.item()) self.log("val_loss", loss_val.item())
if self.logger: if self.logger:
self.logger.experiment.log_metric(run_id=self.logger.run_id, self.logger.experiment.log_metric(
key="val_loss",value=loss_val.item(), run_id=self.logger.run_id,
step=self.global_step) key="val_loss",
self.logger.experiment.log_metric(run_id=self.logger.run_id, value=loss_val.item(),
key="val_metric",value=metric_val.item(), step=self.global_step,
step=self.global_step) )
self.logger.experiment.log_metric(
run_id=self.logger.run_id,
key="val_metric",
value=metric_val.item(),
step=self.global_step,
)
return {"loss":loss_val} return {"loss": loss_val}
def on_save_checkpoint(self, checkpoint): def on_save_checkpoint(self, checkpoint):
checkpoint["enhancer"] = { checkpoint["enhancer"] = {
"version": { "version": {"enhancer": __version__, "pytorch": torch.__version__},
"enhancer":__version__, "architecture": {
"pytorch":torch.__version__ "module": self.__class__.__module__,
"class": self.__class__.__name__,
}, },
"architecture":{
"module":self.__class__.__module__,
"class":self.__class__.__name__
}
} }
def on_load_checkpoint(self, checkpoint: Dict[str, Any]): def on_load_checkpoint(self, checkpoint: Dict[str, Any]):
pass pass
@classmethod @classmethod
def from_pretrained( def from_pretrained(
cls, cls,
checkpoint: Union[Path, Text], checkpoint: Union[Path, Text],
map_location = None, map_location=None,
hparams_file: Union[Path, Text] = None, hparams_file: Union[Path, Text] = None,
strict: bool = True, strict: bool = True,
use_auth_token: Union[Text, None] = None, use_auth_token: Union[Text, None] = None,
cached_dir: Union[Path, Text]=CACHE_DIR, cached_dir: Union[Path, Text] = CACHE_DIR,
**kwargs **kwargs,
): ):
"""
Load Pretrained model
parameters:
checkpoint : Path or str
Path to checkpoint, or a remote URL, or a model identifier from
the huggingface.co model hub.
map_location: optional
Same role as in torch.load().
Defaults to `lambda storage, loc: storage`.
hparams_file : Path or str, optional
Path to a .yaml file with hierarchical structure as in this example:
drop_prob: 0.2
dataloader:
batch_size: 32
You most likely wont need this since Lightning will always save the
hyperparameters to the checkpoint. However, if your checkpoint weights
do not have the hyperparameters saved, use this method to pass in a .yaml
file with the hparams you would like to use. These will be converted
into a dict and passed into your Model for use.
strict : bool, optional
Whether to strictly enforce that the keys in checkpoint match
the keys returned by this modules state dict. Defaults to True.
use_auth_token : str, optional
When loading a private huggingface.co model, set `use_auth_token`
to True or to a string containing your hugginface.co authentication
token that can be obtained by running `huggingface-cli login`
cache_dir: Path or str, optional
Path to model cache directory. Defaults to content of PYANNOTE_CACHE
environment variable, or "~/.cache/torch/pyannote" when unset.
kwargs: optional
Any extra keyword args needed to init the model.
Can also be used to override saved hyperparameter values.
Returns
-------
model : Model
Model
See also
--------
torch.load
"""
checkpoint = str(checkpoint) checkpoint = str(checkpoint)
if hparams_file is not None: if hparams_file is not None:
@ -168,7 +231,7 @@ class Model(pl.LightningModule):
if os.path.isfile(checkpoint): if os.path.isfile(checkpoint):
model_path_pl = checkpoint model_path_pl = checkpoint
elif urlparse(checkpoint).scheme in ("http","https"): elif urlparse(checkpoint).scheme in ("http", "https"):
model_path_pl = checkpoint model_path_pl = checkpoint
else: else:
@@ -180,17 +243,20 @@ class Model(pl.LightningModule):
            revision_id = None
            url = hf_hub_url(
                model_id, filename=HF_TORCH_WEIGHTS, revision=revision_id
            )
            model_path_pl = cached_download(
                url=url,
                library_name="enhancer",
                library_version=__version__,
                cache_dir=cached_dir,
                use_auth_token=use_auth_token,
            )

        if map_location is None:
            map_location = torch.device(DEFAULT_DEVICE)
        loaded_checkpoint = pl_load(model_path_pl, map_location)
        module_name = loaded_checkpoint["enhancer"]["architecture"]["module"]
        class_name = loaded_checkpoint["enhancer"]["architecture"]["class"]
        module = import_module(module_name)
@@ -198,27 +264,38 @@ class Model(pl.LightningModule):
        try:
            model = Klass.load_from_checkpoint(
                checkpoint_path=model_path_pl,
                map_location=map_location,
                hparams_file=hparams_file,
                strict=strict,
                **kwargs,
            )
        except Exception as e:
            print(e)

        return model
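
A minimal usage sketch for from_pretrained (not from this diff; the local checkpoint path is hypothetical):

    from enhancer.models import Demucs

    # loads a local Lightning checkpoint (or a hub identifier / URL) and
    # re-instantiates the class recorded in checkpoint["enhancer"]["architecture"]
    model = Demucs.from_pretrained("model/model_0.ckpt")
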
    def infer(self, batch: torch.Tensor, batch_size: int = 32):
"""
perform model inference
parameters:
batch : torch.Tensor
input data
batch_size : int, default 32
batch size for inference
"""
        assert (
            batch.ndim == 3
        ), f"Expected batch with 3 dimensions (batch,channels,samples) got only {batch.ndim}"
        batch_predictions = []
        self.eval().to(self.device)
        with torch.no_grad():
            for batch_id in range(0, batch.shape[0], batch_size):
                batch_data = batch[batch_id : batch_id + batch_size, :, :].to(
                    self.device
                )
                prediction = self(batch_data)
                batch_predictions.append(prediction)
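
A quick sketch of the input layout infer expects (illustrative only; assumes `model` is any loaded enhancer Model):

    batch = torch.rand(8, 1, 16000)              # (batch, channels, samples)
    enhanced = model.infer(batch, batch_size=4)  # processed in chunks of 4
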
@@ -226,46 +303,61 @@ class Model(pl.LightningModule):
    def enhance(
        self,
        audio: Union[Path, np.ndarray, torch.Tensor],
        sampling_rate: Optional[int] = None,
        batch_size: int = 32,
        save_output: bool = False,
        duration: Optional[int] = None,
        step_size: Optional[int] = None,
    ):
"""
Enhance audio using loaded pretained model.
parameters:
audio: Path to audio file or numpy array or torch tensor
single input audio
sampling_rate: int, optional incase input is path
sampling rate of input
batch_size: int, default 32
input audio is split into multiple chunks. Inference is done on batches
of these chunks according to given batch size.
save_output : bool, default False
weather to save output to file
duration : float, optional
chunk duration in seconds, defaults to duration of loaded pretrained model.
step_size: int, optional
step size between consecutive durations, defaults to 50% of duration
"""
        model_sampling_rate = self.hparams["sampling_rate"]
        if duration is None:
            duration = self.hparams["duration"]
        waveform = Inference.read_input(
            audio, sampling_rate, model_sampling_rate
        )
        waveform.to(self.device)
        window_size = round(duration * model_sampling_rate)
        batched_waveform = Inference.batchify(
            waveform, window_size, step_size=step_size
        )
        batch_prediction = self.infer(batched_waveform, batch_size=batch_size)
        waveform = Inference.aggreagate(
            batch_prediction,
            window_size,
            waveform.shape[-1],
            step_size,
        )
        if save_output and isinstance(audio, (str, Path)):
            Inference.write_output(waveform, audio, model_sampling_rate)
        else:
            waveform = Inference.prepare_output(
                waveform, model_sampling_rate, audio, sampling_rate
            )

        return waveform
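
End-to-end enhancement sketch (not from this diff; the wav path is hypothetical, and for array or tensor inputs the sampling rate must be passed explicitly):

    model.enhance("noisy_speech.wav", save_output=True)
    enhanced = model.enhance(torch.rand(1, 48000), sampling_rate=48000)
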
    @property
    def valid_monitor(self):
        return "max" if self.loss.higher_better else "min"

View File

@@ -1,82 +1,124 @@
import logging
from typing import List, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from enhancer.data.dataset import EnhancerDataset
from enhancer.models.model import Model
class WavenetDecoder(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 5,
        padding: int = 2,
        stride: int = 1,
        dilation: int = 1,
    ):
        super(WavenetDecoder, self).__init__()
        self.decoder = nn.Sequential(
            nn.Conv1d(
                in_channels,
                out_channels,
                kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
            ),
            nn.BatchNorm1d(out_channels),
            nn.LeakyReLU(negative_slope=0.1),
        )

    def forward(self, waveform):
        return self.decoder(waveform)
class WavenetEncoder(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 15,
        padding: int = 7,
        stride: int = 1,
        dilation: int = 1,
    ):
        super(WavenetEncoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv1d(
                in_channels,
                out_channels,
                kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
            ),
            nn.BatchNorm1d(out_channels),
            nn.LeakyReLU(negative_slope=0.1),
        )

    def forward(self, waveform):
        return self.encoder(waveform)
class WaveUnet(Model):
"""
Wave-U-Net model from https://arxiv.org/pdf/1811.11307.pdf
parameters:
num_channels: int, defaults to 1
number of channels in input audio
depth : int, defaults to 12
depth of network
initial_output_channels: int, defaults to 24
number of output channels in intial upsampling layer
sampling_rate: int, defaults to 16KHz
sampling rate of input audio
lr : float, defaults to 1e-3
learning rate used for training
dataset: EnhancerDataset, optional
EnhancerDataset object containing train/validation data for training
duration : float, optional
chunk duration in seconds
loss : string or List of strings
loss function to be used, available ("mse","mae","SI-SDR")
metric : string or List of strings
metric function to be used, available ("mse","mae","SI-SDR")
"""
    def __init__(
        self,
        num_channels: int = 1,
        depth: int = 12,
        initial_output_channels: int = 24,
        sampling_rate: int = 16000,
        lr: float = 1e-3,
        dataset: Optional[EnhancerDataset] = None,
        duration: Optional[float] = None,
        loss: Union[str, List] = "mse",
        metric: Union[str, List] = "mse",
    ):
        duration = (
            dataset.duration if isinstance(dataset, EnhancerDataset) else None
        )
        if dataset is not None:
            if sampling_rate != dataset.sampling_rate:
                logging.warn(
                    f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
                )
                sampling_rate = dataset.sampling_rate
        super().__init__(
            num_channels=num_channels,
            sampling_rate=sampling_rate,
            lr=lr,
            dataset=dataset,
            duration=duration,
            loss=loss,
            metric=metric,
        )
        self.save_hyperparameters("depth")
        self.encoders = nn.ModuleList()
@@ -84,72 +126,76 @@ class WaveUnet(Model):
        out_channels = initial_output_channels
        for layer in range(depth):
            encoder = WavenetEncoder(num_channels, out_channels)
            self.encoders.append(encoder)

            num_channels = out_channels
            out_channels += initial_output_channels
            if layer == depth - 1:
                decoder = WavenetDecoder(
                    depth * initial_output_channels + num_channels, num_channels
                )
            else:
                decoder = WavenetDecoder(
                    num_channels + out_channels, num_channels
                )
            self.decoders.insert(0, decoder)

        bottleneck_dim = depth * initial_output_channels
        self.bottleneck = nn.Sequential(
            nn.Conv1d(bottleneck_dim, bottleneck_dim, 15, stride=1, padding=7),
            nn.BatchNorm1d(bottleneck_dim),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
        )

        self.final = nn.Sequential(
            nn.Conv1d(1 + initial_output_channels, 1, kernel_size=1, stride=1),
            nn.Tanh(),
        )
    def forward(self, waveform):

        if waveform.dim() == 2:
            waveform = waveform.unsqueeze(1)
        if waveform.size(1) != 1:
            raise TypeError(
                f"Wave-U-Net can only process mono channel audio, input has {waveform.size(1)} channels"
            )

        encoder_outputs = []
        out = waveform
        for encoder in self.encoders:
            out = encoder(out)
            encoder_outputs.insert(0, out)
            out = out[:, :, ::2]

        out = self.bottleneck(out)

        for layer, decoder in enumerate(self.decoders):
            out = F.interpolate(out, scale_factor=2, mode="linear")
            out = self.fix_last_dim(out, encoder_outputs[layer])
            out = torch.cat([out, encoder_outputs[layer]], dim=1)
            out = decoder(out)

        out = torch.cat([out, waveform], dim=1)
        out = self.final(out)

        return out
    def fix_last_dim(self, x, target):
        """
        centre crop along last dimension
        """
        assert (
            x.shape[-1] >= target.shape[-1]
        ), "input dimension cannot be smaller than target dimension"
        if x.shape[-1] == target.shape[-1]:
            return x

        diff = x.shape[-1] - target.shape[-1]
        if diff % 2 != 0:
            x = F.pad(x, (0, 1))
            diff += 1
        crop = diff // 2

        return x[:, :, crop:-crop]
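
A shape-preserving forward pass, mirroring the unit tests further below (sketch; the reduced depth just keeps the example light):

    model = WaveUnet(depth=3)
    model.eval()
    with torch.no_grad():
        out = model(torch.rand(2, 1, 16000))   # mono input: (batch, 1, samples)
    assert out.shape == (2, 1, 16000)
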

View File

@@ -1,3 +1,3 @@
from enhancer.utils.config import Files
from enhancer.utils.io import Audio
from enhancer.utils.utils import check_files

View File

@@ -1,10 +1,9 @@
from dataclasses import dataclass


@dataclass
class Files:
    train_clean: str
    train_noisy: str
    test_clean: str
    test_noisy: str

View File

@@ -1,17 +1,26 @@
import os
from pathlib import Path
from typing import Optional, Union

import librosa
import numpy as np
import torch
import torchaudio
class Audio:
    """
    Audio utils
    parameters:
        sampling_rate : int, defaults to 16KHz
            audio sampling rate
        mono : bool, defaults to True
            convert multichannel audio to mono if set to True
        return_tensor : bool, defaults to True
            returns torch tensor type if set to True, else numpy ndarray
    """

    def __init__(
        self, sampling_rate: int = 16000, mono: bool = True, return_tensor=True
    ) -> None:
        self.sampling_rate = sampling_rate
@@ -20,22 +29,39 @@ class Audio:
    def __call__(
        self,
        audio: Union[Path, np.ndarray, torch.Tensor],
        sampling_rate: Optional[int] = None,
        offset: Optional[float] = None,
        duration: Optional[float] = None,
    ):
        """
        read and process input audio
        parameters:
            audio : Path to audio file or numpy array or torch tensor
                single input audio
            sampling_rate : int, optional
                sampling rate of the audio input
            offset : float, optional
                offset from which the audio must be read; reads from the
                beginning if unused.
            duration : float (seconds), optional
                duration to read; reads the full audio starting from offset
                if unused.
        """
        if isinstance(audio, str):
            if os.path.exists(audio):
                audio, sampling_rate = librosa.load(
                    audio,
                    sr=sampling_rate,
                    mono=False,
                    offset=offset,
                    duration=duration,
                )
                if len(audio.shape) == 1:
                    audio = audio.reshape(1, -1)
            else:
                raise FileNotFoundError(f"File {audio} does not exist")
        elif isinstance(audio, np.ndarray):
            if len(audio.shape) == 1:
                audio = audio.reshape(1, -1)
        else:
            raise ValueError("audio should be either filepath or numpy ndarray")
@@ -43,40 +69,60 @@ class Audio:
        audio = self.convert_mono(audio)

        if sampling_rate:
            audio = self.__class__.resample_audio(
                audio, self.sampling_rate, sampling_rate
            )
        if self.return_tensor:
            return torch.tensor(audio)
        else:
            return audio
    @staticmethod
    def convert_mono(audio: Union[np.ndarray, torch.Tensor]):
        """
        convert input audio into mono (single channel)
        parameters:
            audio : np.ndarray or torch.Tensor
        """
        if len(audio.shape) > 2:
            assert (
                audio.shape[0] == 1
            ), "convert mono only accepts single waveform"
            audio = audio.reshape(audio.shape[1], audio.shape[2])

        assert (
            audio.shape[1] > audio.shape[0]
        ), f"expected input format (num_channels,num_samples) got {audio.shape}"
        num_channels, num_samples = audio.shape
        if num_channels > 1:
            return audio.mean(axis=0).reshape(1, num_samples)
        return audio
    @staticmethod
    def resample_audio(
        audio: Union[np.ndarray, torch.Tensor], sr: int, target_sr: int
    ):
        """
        resample audio to the desired sampling rate
        parameters:
            audio : numpy array or torch tensor
                audio waveform
            sr : int
                current sampling rate
            target_sr : int
                target sampling rate
        """
        if sr != target_sr:
            if isinstance(audio, np.ndarray):
                audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sr)
            elif isinstance(audio, torch.Tensor):
                audio = torchaudio.functional.resample(
                    audio, orig_freq=sr, new_freq=target_sr
                )
            else:
                raise ValueError(
                    "Input should be either numpy array or torch tensor"
                )

        return audio
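
Usage sketch for the Audio helper (reusing the test asset referenced in the tests below):

    audio = Audio(sampling_rate=16000, mono=True, return_tensor=True)
    waveform = audio("tests/data/vctk/clean_testset_wav/p257_166.wav")   # (1, num_samples)
    downsampled = Audio.resample_audio(waveform, sr=16000, target_sr=8000)
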

View File

@@ -1,19 +1,19 @@
import os
import random

import torch


def create_unique_rng(epoch: int):
    """create unique random number generator for each (worker_id,epoch) combination"""

    rng = random.Random()

    global_seed = int(os.environ.get("PL_GLOBAL_SEED", "0"))
    global_rank = int(os.environ.get("GLOBAL_RANK", "0"))
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    node_rank = int(os.environ.get("NODE_RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "0"))

    worker_info = torch.utils.data.get_worker_info()
    if worker_info is not None:
@@ -34,7 +34,3 @@ def create_unique_rng(epoch:int):
        rng.seed(seed)

    return rng
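
A sketch of the intended behaviour, assuming the default single-process environment:

    rng = create_unique_rng(epoch=3)
    picks = rng.sample(range(100), k=5)   # reproducible: identical across runs for epoch == 3
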

View File

@@ -1,19 +1,26 @@
import os
from typing import Optional

from enhancer.utils.config import Files


def check_files(root_dir: str, files: Files):
    path_variables = [
        member_var
        for member_var in dir(files)
        if not member_var.startswith("__")
    ]
    for variable in path_variables:
        path = getattr(files, variable)
        if not os.path.isdir(os.path.join(root_dir, path)):
            raise ValueError(f"Invalid {path}, is not a directory")

    return files, root_dir


def merge_dict(default_dict: dict, custom: Optional[dict] = None):
    params = dict(default_dict)
    if custom:
        params.update(custom)
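
Illustrative usage (directory names are hypothetical; merge_dict is assumed to return the merged dict, its return statement falling outside this hunk):

    files = Files(
        train_clean="clean_train",
        train_noisy="noisy_train",
        test_clean="clean_test",
        test_noisy="noisy_test",
    )
    files, root_dir = check_files("/data/vctk", files)

    defaults = {"lr": 1e-3, "batch_size": 32}
    params = merge_dict(defaults, {"lr": 3e-4})   # -> {"lr": 3e-4, "batch_size": 32}
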

15
pyproject.toml Normal file
View File

@@ -0,0 +1,15 @@
[tool.black]
line-length = 80
target-version = ['py38']
exclude = '''
(
/(
\.eggs # exclude a few common directories in the
| \.git # root of the project
| \.mypy_cache
| \.tox
| \.venv
)/
)
'''

View File

@@ -1,15 +1,16 @@
black>=22.8.0
boto3>=1.24.86
flake8>=5.0.4
huggingface-hub>=0.10.0
hydra-core>=1.2.0
joblib>=1.2.0
librosa>=0.9.2
mlflow>=1.29.0
numpy>=1.23.3
protobuf>=3.19.6
pytorch-lightning>=1.7.7
scikit-learn>=1.1.2
scipy>=1.9.1
torch>=1.12.1
torchaudio>=0.12.1
tqdm>=4.64.1

View File

@@ -1,31 +1,32 @@
import pytest
import torch

from enhancer.loss import mean_absolute_error, mean_squared_error

loss_functions = [mean_absolute_error(), mean_squared_error()]


def check_loss_shapes_compatibility(loss_fun):

    batch_size = 4
    shape = (1, 1000)
    loss_fun(torch.rand(batch_size, *shape), torch.rand(batch_size, *shape))

    with pytest.raises(TypeError):
        loss_fun(torch.rand(4, *shape), torch.rand(6, *shape))


@pytest.mark.parametrize("loss", loss_functions)
def test_loss_input_shapes(loss):
    check_loss_shapes_compatibility(loss)


@pytest.mark.parametrize("loss", loss_functions)
def test_loss_output_type(loss):

    batch_size = 4
    prediction, target = torch.rand(batch_size, 1, 1000), torch.rand(
        batch_size, 1, 1000
    )
    loss_value = loss(prediction, target)
    assert isinstance(loss_value.item(), float)

View File

@@ -1,46 +1,43 @@
import pytest
import torch

from enhancer.data.dataset import EnhancerDataset
from enhancer.models import Demucs
from enhancer.utils.config import Files


@pytest.fixture
def vctk_dataset():
    root_dir = "tests/data/vctk"
    files = Files(
        train_clean="clean_testset_wav",
        train_noisy="noisy_testset_wav",
        test_clean="clean_testset_wav",
        test_noisy="noisy_testset_wav",
    )
    dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files)
    return dataset


@pytest.mark.parametrize("batch_size,samples", [(1, 1000)])
def test_forward(batch_size, samples):
    model = Demucs()
    model.eval()

    data = torch.rand(batch_size, 1, samples, requires_grad=False)
    with torch.no_grad():
        _ = model(data)

    data = torch.rand(batch_size, 2, samples, requires_grad=False)
    with torch.no_grad():
        with pytest.raises(TypeError):
            _ = model(data)


@pytest.mark.parametrize(
    "dataset,channels,loss",
    [(pytest.lazy_fixture("vctk_dataset"), 1, ["mae", "mse"])],
)
def test_demucs_init(dataset, channels, loss):
    with torch.no_grad():
        _ = Demucs(num_channels=channels, dataset=dataset, loss=loss)

View File

@@ -1,46 +1,43 @@
import pytest
import torch

from enhancer.data.dataset import EnhancerDataset
from enhancer.models import WaveUnet
from enhancer.utils.config import Files


@pytest.fixture
def vctk_dataset():
    root_dir = "tests/data/vctk"
    files = Files(
        train_clean="clean_testset_wav",
        train_noisy="noisy_testset_wav",
        test_clean="clean_testset_wav",
        test_noisy="noisy_testset_wav",
    )
    dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files)
    return dataset


@pytest.mark.parametrize("batch_size,samples", [(1, 1000)])
def test_forward(batch_size, samples):
    model = WaveUnet()
    model.eval()

    data = torch.rand(batch_size, 1, samples, requires_grad=False)
    with torch.no_grad():
        _ = model(data)

    data = torch.rand(batch_size, 2, samples, requires_grad=False)
    with torch.no_grad():
        with pytest.raises(TypeError):
            _ = model(data)


@pytest.mark.parametrize(
    "dataset,channels,loss",
    [(pytest.lazy_fixture("vctk_dataset"), 1, ["mae", "mse"])],
)
def test_waveunet_init(dataset, channels, loss):
    with torch.no_grad():
        _ = WaveUnet(num_channels=channels, dataset=dataset, loss=loss)

View File

@@ -4,22 +4,26 @@ import torch

from enhancer.inference import Inference


@pytest.mark.parametrize(
    "audio",
    ["tests/data/vctk/clean_testset_wav/p257_166.wav", torch.rand(1, 2, 48000)],
)
def test_read_input(audio):
    read_audio = Inference.read_input(audio, 48000, 16000)
    assert isinstance(read_audio, torch.Tensor)
    assert read_audio.shape[0] == 1


def test_batchify():
    rand = torch.rand(1, 1000)
    batched_rand = Inference.batchify(rand, window_size=100, step_size=100)
    assert batched_rand.shape[0] == 12


def test_aggregate():
    rand = torch.rand(12, 1, 100)
    agg_rand = Inference.aggreagate(
        data=rand, window_size=100, total_frames=1000, step_size=100
    )
    assert agg_rand.shape[-1] == 1000

View File

@@ -1,46 +1,50 @@
import numpy as np
import pytest
import torch

from enhancer.data.fileprocessor import Fileprocessor
from enhancer.utils.io import Audio


def test_io_channel():
    input_audio = np.random.rand(2, 32000)
    audio = Audio(mono=True, return_tensor=False)
    output_audio = audio(input_audio)
    assert output_audio.shape[0] == 1


def test_io_resampling():
    input_audio = np.random.rand(1, 32000)
    resampled_audio = Audio.resample_audio(input_audio, 16000, 8000)

    input_audio = torch.rand(1, 32000)
    resampled_audio_pt = Audio.resample_audio(input_audio, 16000, 8000)

    assert resampled_audio.shape[1] == resampled_audio_pt.size(1) == 16000


def test_fileprocessor_vctk():
    fp = Fileprocessor.from_name(
        "vctk",
        "tests/data/vctk/clean_testset_wav",
        "tests/data/vctk/noisy_testset_wav",
        48000,
    )
    matching_dict = fp.prepare_matching_dict()
    assert len(matching_dict) == 2


@pytest.mark.parametrize("dataset_name", ["vctk", "dns-2020"])
def test_fileprocessor_names(dataset_name):
    fp = Fileprocessor.from_name(dataset_name, "clean_dir", "noisy_dir", 16000)
    assert hasattr(fp.matching_function, "__call__")


def test_fileprocessor_invalid_name():
    with pytest.raises(ValueError):
        _ = Fileprocessor.from_name(
            "undefined", "clean_dir", "noisy_dir", 16000
        ).prepare_matching_dict()