Merge branch 'dev' of https://github.com/shahules786/enhancer into dev-hawk

This commit is contained in:
shahules786 2022-10-06 09:54:14 +05:30
commit a064151e2e
41 changed files with 1268 additions and 810 deletions

9
.flake8 Normal file
View File

@ -0,0 +1,9 @@
[flake8]
per-file-ignores = __init__.py:F401
ignore = E203, E266, E501, W503
# line length is intentionally set to 80 here because black uses Bugbear
# See https://github.com/psf/black/blob/master/README.md#line-length for more details
max-line-length = 80
max-complexity = 18
select = B,C,E,F,W,T4,B9
exclude = tools/kaldi_decoder

43
.pre-commit-config.yaml Normal file
View File

@ -0,0 +1,43 @@
repos:
# # Clean Notebooks
# - repo: https://github.com/kynan/nbstripout
# rev: master
# hooks:
# - id: nbstripout
# Format Code
- repo: https://github.com/ambv/black
rev: 22.8.0
hooks:
- id: black
# Sort imports
- repo: https://github.com/PyCQA/isort
rev: 5.10.1
hooks:
- id: isort
args: ["--profile", "black"]
- repo: https://gitlab.com/pycqa/flake8
rev: 5.0.4
hooks:
- id: flake8
args: ['--ignore=E203,E501,F811,E712,W503']
# Formatting, Whitespace, etc
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v3.2.0
hooks:
- id: trailing-whitespace
- id: check-added-large-files
args: ['--maxkb=1000']
- id: check-ast
- id: check-json
- id: check-merge-conflict
- id: check-xml
- id: check-yaml
- id: debug-statements
- id: end-of-file-fixer
- id: requirements-txt-fixer
- id: mixed-line-ending
args: ['--fix=no']

View File

@ -1 +1,6 @@
# enhancer
# enhancer
Enhancer is a Pytorch-based opensource toolkit for speech enhancement. It is designed to save time for audio researchers. Is provides easy to use pretrained audio enhancement models and facilitates highly customisable custom model training . Enhancer provides
* Various pretrained models nicely integrated with huggingface that users can select and use without any hastle.
* Ability to train and validation your own custom speech enhancement models with just under 10 lines of code!
* A command line tool that facilitates training of highly customisable speech enhacement models from the terminal itself!

View File

@ -1,67 +0,0 @@
from genericpath import isfile
import os
from types import MethodType
import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig
from torch.optim.lr_scheduler import ReduceLROnPlateau
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import MLFlowLogger
os.environ["HYDRA_FULL_ERROR"] = "1"
JOB_ID = os.environ.get("SLURM_JOBID","0")
@hydra.main(config_path="train_config",config_name="config")
def main(config: DictConfig):
callbacks = []
logger = MLFlowLogger(experiment_name=config.mlflow.experiment_name,
run_name=config.mlflow.run_name, tags={"JOB_ID":JOB_ID})
parameters = config.hyperparameters
dataset = instantiate(config.dataset)
model = instantiate(config.model,dataset=dataset,lr=parameters.get("lr"),
loss=parameters.get("loss"), metric = parameters.get("metric"))
direction = model.valid_monitor
checkpoint = ModelCheckpoint(
dirpath="./model",filename=f"model_{JOB_ID}",monitor="val_loss",verbose=True,
mode=direction,every_n_epochs=1
)
callbacks.append(checkpoint)
early_stopping = EarlyStopping(
monitor="val_loss",
mode=direction,
min_delta=0.0,
patience=parameters.get("EarlyStopping_patience",10),
strict=True,
verbose=False,
)
callbacks.append(early_stopping)
def configure_optimizer(self):
optimizer = instantiate(config.optimizer,lr=parameters.get("lr"),parameters=self.parameters())
scheduler = ReduceLROnPlateau(
optimizer=optimizer,
mode=direction,
factor=parameters.get("ReduceLr_factor",0.1),
verbose=True,
min_lr=parameters.get("min_lr",1e-6),
patience=parameters.get("ReduceLr_patience",3)
)
return {"optimizer":optimizer, "lr_scheduler":scheduler}
model.configure_parameters = MethodType(configure_optimizer,model)
trainer = instantiate(config.trainer,logger=logger,callbacks=callbacks)
trainer.fit(model)
saved_location = os.path.join(trainer.default_root_dir,"model",f"model_{JOB_ID}.ckpt")
if os.path.isfile(saved_location):
logger.experiment.log_artifact(logger.run_id,saved_location)
if __name__=="__main__":
main()

View File

@ -1 +1 @@
__version__ = "0.0.1"
__version__ = "0.0.1"

85
enhancer/cli/train.py Normal file
View File

@ -0,0 +1,85 @@
import os
from types import MethodType
import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import MLFlowLogger
from torch.optim.lr_scheduler import ReduceLROnPlateau
os.environ["HYDRA_FULL_ERROR"] = "1"
JOB_ID = os.environ.get("SLURM_JOBID", "0")
@hydra.main(config_path="train_config", config_name="config")
def main(config: DictConfig):
callbacks = []
logger = MLFlowLogger(
experiment_name=config.mlflow.experiment_name,
run_name=config.mlflow.run_name,
tags={"JOB_ID": JOB_ID},
)
parameters = config.hyperparameters
dataset = instantiate(config.dataset)
model = instantiate(
config.model,
dataset=dataset,
lr=parameters.get("lr"),
loss=parameters.get("loss"),
metric=parameters.get("metric"),
)
direction = model.valid_monitor
checkpoint = ModelCheckpoint(
dirpath="./model",
filename=f"model_{JOB_ID}",
monitor="val_loss",
verbose=True,
mode=direction,
every_n_epochs=1,
)
callbacks.append(checkpoint)
early_stopping = EarlyStopping(
monitor="val_loss",
mode=direction,
min_delta=0.0,
patience=parameters.get("EarlyStopping_patience", 10),
strict=True,
verbose=False,
)
callbacks.append(early_stopping)
def configure_optimizer(self):
optimizer = instantiate(
config.optimizer,
lr=parameters.get("lr"),
parameters=self.parameters(),
)
scheduler = ReduceLROnPlateau(
optimizer=optimizer,
mode=direction,
factor=parameters.get("ReduceLr_factor", 0.1),
verbose=True,
min_lr=parameters.get("min_lr", 1e-6),
patience=parameters.get("ReduceLr_patience", 3),
)
return {"optimizer": optimizer, "lr_scheduler": scheduler}
model.configure_parameters = MethodType(configure_optimizer, model)
trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks)
trainer.fit(model)
saved_location = os.path.join(
trainer.default_root_dir, "model", f"model_{JOB_ID}.ckpt"
)
if os.path.isfile(saved_location):
logger.experiment.log_artifact(logger.run_id, saved_location)
if __name__ == "__main__":
main()

View File

@ -4,4 +4,4 @@ defaults:
- optimizer : Adam
- hyperparameters : default
- trainer : default
- mlflow : experiment
- mlflow : experiment

View File

@ -10,4 +10,3 @@ files:
test_clean : clean_test_wav
train_noisy : clean_test_wav
test_noisy : clean_test_wav

View File

@ -10,6 +10,3 @@ files:
test_clean : clean_testset_wav
train_noisy : noisy_trainset_28spk_wav
test_noisy : noisy_testset_wav

View File

@ -0,0 +1,13 @@
_target_: enhancer.data.dataset.EnhancerDataset
name : vctk
root_dir : /Users/shahules/Myprojects/enhancer/datasets/vctk
duration : 1.0
sampling_rate: 16000
batch_size: 64
num_workers : 0
files:
train_clean : clean_testset_wav
test_clean : clean_testset_wav
train_noisy : noisy_testset_wav
test_noisy : noisy_testset_wav

View File

@ -5,4 +5,3 @@ ReduceLr_patience : 5
ReduceLr_factor : 0.1
min_lr : 0.000001
EarlyStopping_factor : 10

View File

@ -1,2 +1,2 @@
experiment_name : shahules/enhancer
run_name : baseline
run_name : baseline

View File

@ -14,5 +14,3 @@ encoder_decoder:
lstm:
bidirectional: False
num_layers: 2

View File

@ -0,0 +1 @@
from enhancer.data.dataset import EnhancerDataset

View File

@ -1,20 +1,21 @@
import multiprocessing
import math
import multiprocessing
import os
import pytorch_lightning as pl
from torch.utils.data import IterableDataset, DataLoader, Dataset
import torch.nn.functional as F
from typing import Optional
import pytorch_lightning as pl
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, IterableDataset
from enhancer.data.fileprocessor import Fileprocessor
from enhancer.utils.random import create_unique_rng
from enhancer.utils.io import Audio
from enhancer.utils import check_files
from enhancer.utils.config import Files
from enhancer.utils.io import Audio
from enhancer.utils.random import create_unique_rng
class TrainDataset(IterableDataset):
def __init__(self,dataset):
def __init__(self, dataset):
self.dataset = dataset
def __iter__(self):
@ -23,88 +24,102 @@ class TrainDataset(IterableDataset):
def __len__(self):
return self.dataset.train__len__()
class ValidDataset(Dataset):
def __init__(self,dataset):
def __init__(self, dataset):
self.dataset = dataset
def __getitem__(self,idx):
def __getitem__(self, idx):
return self.dataset.val__getitem__(idx)
def __len__(self):
return self.dataset.val__len__()
class TaskDataset(pl.LightningDataModule):
class TaskDataset(pl.LightningDataModule):
def __init__(
self,
name:str,
root_dir:str,
files:Files,
duration:float=1.0,
sampling_rate:int=48000,
matching_function = None,
name: str,
root_dir: str,
files: Files,
duration: float = 1.0,
sampling_rate: int = 48000,
matching_function=None,
batch_size=32,
num_workers:Optional[int]=None):
num_workers: Optional[int] = None,
):
super().__init__()
self.name = name
self.files,self.root_dir = check_files(root_dir,files)
self.files, self.root_dir = check_files(root_dir, files)
self.duration = duration
self.sampling_rate = sampling_rate
self.batch_size = batch_size
self.matching_function = matching_function
self._validation = []
if num_workers is None:
num_workers = multiprocessing.cpu_count()//2
num_workers = multiprocessing.cpu_count() // 2
self.num_workers = num_workers
def setup(self, stage: Optional[str] = None):
if stage in ("fit",None):
if stage in ("fit", None):
train_clean = os.path.join(self.root_dir,self.files.train_clean)
train_noisy = os.path.join(self.root_dir,self.files.train_noisy)
fp = Fileprocessor.from_name(self.name,train_clean,
train_noisy, self.matching_function)
train_clean = os.path.join(self.root_dir, self.files.train_clean)
train_noisy = os.path.join(self.root_dir, self.files.train_noisy)
fp = Fileprocessor.from_name(
self.name, train_clean, train_noisy, self.matching_function
)
self.train_data = fp.prepare_matching_dict()
val_clean = os.path.join(self.root_dir,self.files.test_clean)
val_noisy = os.path.join(self.root_dir,self.files.test_noisy)
fp = Fileprocessor.from_name(self.name,val_clean,
val_noisy, self.matching_function)
val_clean = os.path.join(self.root_dir, self.files.test_clean)
val_noisy = os.path.join(self.root_dir, self.files.test_noisy)
fp = Fileprocessor.from_name(
self.name, val_clean, val_noisy, self.matching_function
)
val_data = fp.prepare_matching_dict()
for item in val_data:
clean,noisy,total_dur = item.values()
clean, noisy, total_dur = item.values()
if total_dur < self.duration:
continue
num_segments = round(total_dur/self.duration)
num_segments = round(total_dur / self.duration)
for index in range(num_segments):
start_time = index * self.duration
self._validation.append(({"clean":clean,"noisy":noisy},
start_time))
self._validation.append(
({"clean": clean, "noisy": noisy}, start_time)
)
def train_dataloader(self):
return DataLoader(TrainDataset(self), batch_size = self.batch_size,num_workers=self.num_workers)
return DataLoader(
TrainDataset(self),
batch_size=self.batch_size,
num_workers=self.num_workers,
)
def val_dataloader(self):
return DataLoader(ValidDataset(self), batch_size = self.batch_size,num_workers=self.num_workers)
return DataLoader(
ValidDataset(self),
batch_size=self.batch_size,
num_workers=self.num_workers,
)
class EnhancerDataset(TaskDataset):
"""
Dataset object for creating clean-noisy speech enhancement datasets
paramters:
name : str
name of the dataset
name of the dataset
root_dir : str
root directory of the dataset containing clean/noisy folders
files : Files
dataclass containing train_clean, train_noisy, test_clean, test_noisy
folder names (refer cli/train_config/dataset)
dataclass containing train_clean, train_noisy, test_clean, test_noisy
folder names (refer enhancer.utils.Files dataclass)
duration : float
expected audio duration of single audio sample for training
sampling_rate : int
desired sampling rate
desired sampling rate
batch_size : int
batch size of each batch
num_workers : int
@ -114,71 +129,92 @@ class EnhancerDataset(TaskDataset):
use one_to_one mapping for datasets with one noisy file for each clean file
use one_to_many mapping for multiple noisy files for each clean file
"""
def __init__(
self,
name:str,
root_dir:str,
files:Files,
name: str,
root_dir: str,
files: Files,
duration=1.0,
sampling_rate=48000,
matching_function=None,
batch_size=32,
num_workers:Optional[int]=None):
num_workers: Optional[int] = None,
):
super().__init__(
name=name,
root_dir=root_dir,
files=files,
sampling_rate=sampling_rate,
duration=duration,
matching_function = matching_function,
matching_function=matching_function,
batch_size=batch_size,
num_workers = num_workers,
num_workers=num_workers,
)
self.sampling_rate = sampling_rate
self.files = files
self.duration = max(1.0,duration)
self.audio = Audio(self.sampling_rate,mono=True,return_tensor=True)
self.duration = max(1.0, duration)
self.audio = Audio(self.sampling_rate, mono=True, return_tensor=True)
def setup(self, stage: Optional[str] = None):
def setup(self, stage:Optional[str]=None):
super().setup(stage=stage)
def train__iter__(self):
rng = create_unique_rng(self.model.current_epoch)
rng = create_unique_rng(self.model.current_epoch)
while True:
file_dict,*_ = rng.choices(self.train_data,k=1,
weights=[file["duration"] for file in self.train_data])
file_duration = file_dict['duration']
start_time = round(rng.uniform(0,file_duration- self.duration),2)
data = self.prepare_segment(file_dict,start_time)
file_dict, *_ = rng.choices(
self.train_data,
k=1,
weights=[file["duration"] for file in self.train_data],
)
file_duration = file_dict["duration"]
start_time = round(rng.uniform(0, file_duration - self.duration), 2)
data = self.prepare_segment(file_dict, start_time)
yield data
def val__getitem__(self,idx):
def val__getitem__(self, idx):
return self.prepare_segment(*self._validation[idx])
def prepare_segment(self,file_dict:dict, start_time:float):
clean_segment = self.audio(file_dict["clean"],
offset=start_time,duration=self.duration)
noisy_segment = self.audio(file_dict["noisy"],
offset=start_time,duration=self.duration)
clean_segment = F.pad(clean_segment,(0,int(self.duration*self.sampling_rate-clean_segment.shape[-1])))
noisy_segment = F.pad(noisy_segment,(0,int(self.duration*self.sampling_rate-noisy_segment.shape[-1])))
return {"clean": clean_segment,"noisy":noisy_segment}
def prepare_segment(self, file_dict: dict, start_time: float):
clean_segment = self.audio(
file_dict["clean"], offset=start_time, duration=self.duration
)
noisy_segment = self.audio(
file_dict["noisy"], offset=start_time, duration=self.duration
)
clean_segment = F.pad(
clean_segment,
(
0,
int(
self.duration * self.sampling_rate - clean_segment.shape[-1]
),
),
)
noisy_segment = F.pad(
noisy_segment,
(
0,
int(
self.duration * self.sampling_rate - noisy_segment.shape[-1]
),
),
)
return {"clean": clean_segment, "noisy": noisy_segment}
def train__len__(self):
return math.ceil(sum([file["duration"] for file in self.train_data])/self.duration)
return math.ceil(
sum([file["duration"] for file in self.train_data]) / self.duration
)
def val__len__(self):
return len(self._validation)

View File

@ -1,108 +1,118 @@
import glob
import os
from re import S
import numpy as np
from scipy.io import wavfile
MATCHING_FNS = ("one_to_one","one_to_many")
MATCHING_FNS = ("one_to_one", "one_to_many")
class ProcessorFunctions:
"""
Preprocessing methods for different types of speech enhacement datasets.
"""
@staticmethod
def one_to_one(clean_path,noisy_path):
def one_to_one(clean_path, noisy_path):
"""
One clean audio can have only one noisy audio file
"""
matching_wavfiles = list()
clean_filenames = [file.split('/')[-1] for file in glob.glob(os.path.join(clean_path,"*.wav"))]
noisy_filenames = [file.split('/')[-1] for file in glob.glob(os.path.join(noisy_path,"*.wav"))]
common_filenames = np.intersect1d(noisy_filenames,clean_filenames)
clean_filenames = [
file.split("/")[-1]
for file in glob.glob(os.path.join(clean_path, "*.wav"))
]
noisy_filenames = [
file.split("/")[-1]
for file in glob.glob(os.path.join(noisy_path, "*.wav"))
]
common_filenames = np.intersect1d(noisy_filenames, clean_filenames)
for file_name in common_filenames:
sr_clean, clean_file = wavfile.read(os.path.join(clean_path,file_name))
sr_noisy, noisy_file = wavfile.read(os.path.join(noisy_path,file_name))
if ((clean_file.shape[-1]==noisy_file.shape[-1]) and
(sr_clean==sr_noisy)):
sr_clean, clean_file = wavfile.read(
os.path.join(clean_path, file_name)
)
sr_noisy, noisy_file = wavfile.read(
os.path.join(noisy_path, file_name)
)
if (clean_file.shape[-1] == noisy_file.shape[-1]) and (
sr_clean == sr_noisy
):
matching_wavfiles.append(
{"clean":os.path.join(clean_path,file_name),"noisy":os.path.join(noisy_path,file_name),
"duration":clean_file.shape[-1]/sr_clean}
)
{
"clean": os.path.join(clean_path, file_name),
"noisy": os.path.join(noisy_path, file_name),
"duration": clean_file.shape[-1] / sr_clean,
}
)
return matching_wavfiles
@staticmethod
def one_to_many(clean_path,noisy_path):
def one_to_many(clean_path, noisy_path):
"""
One clean audio have multiple noisy audio files
"""
matching_wavfiles = dict()
clean_filenames = [file.split('/')[-1] for file in glob.glob(os.path.join(clean_path,"*.wav"))]
clean_filenames = [
file.split("/")[-1]
for file in glob.glob(os.path.join(clean_path, "*.wav"))
]
for clean_file in clean_filenames:
noisy_filenames = glob.glob(os.path.join(noisy_path,f"*_{clean_file}.wav"))
noisy_filenames = glob.glob(
os.path.join(noisy_path, f"*_{clean_file}.wav")
)
for noisy_file in noisy_filenames:
sr_clean, clean_file = wavfile.read(os.path.join(clean_path,clean_file))
sr_clean, clean_file = wavfile.read(
os.path.join(clean_path, clean_file)
)
sr_noisy, noisy_file = wavfile.read(noisy_file)
if ((clean_file.shape[-1]==noisy_file.shape[-1]) and
(sr_clean==sr_noisy)):
if (clean_file.shape[-1] == noisy_file.shape[-1]) and (
sr_clean == sr_noisy
):
matching_wavfiles.update(
{"clean":os.path.join(clean_path,clean_file),"noisy":noisy_file,
"duration":clean_file.shape[-1]/sr_clean}
)
{
"clean": os.path.join(clean_path, clean_file),
"noisy": noisy_file,
"duration": clean_file.shape[-1] / sr_clean,
}
)
return matching_wavfiles
class Fileprocessor:
def __init__(
self,
clean_dir,
noisy_dir,
matching_function = None
):
def __init__(self, clean_dir, noisy_dir, matching_function=None):
self.clean_dir = clean_dir
self.noisy_dir = noisy_dir
self.matching_function = matching_function
@classmethod
def from_name(cls,
name:str,
clean_dir,
noisy_dir,
matching_function=None
):
def from_name(cls, name: str, clean_dir, noisy_dir, matching_function=None):
if matching_function is None:
if name.lower() == "vctk":
return cls(clean_dir,noisy_dir, ProcessorFunctions.one_to_one)
return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_one)
elif name.lower() == "dns-2020":
return cls(clean_dir,noisy_dir, ProcessorFunctions.one_to_many)
return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_many)
else:
if matching_function not in MATCHING_FNS:
raise ValueError(F"Invalid matching function! Avaialble options are {MATCHING_FNS}")
raise ValueError(
f"Invalid matching function! Avaialble options are {MATCHING_FNS}"
)
else:
return cls(clean_dir,noisy_dir, getattr(ProcessorFunctions,matching_function))
return cls(
clean_dir,
noisy_dir,
getattr(ProcessorFunctions, matching_function),
)
def prepare_matching_dict(self):
if self.matching_function is None:
raise ValueError("Not a valid matching function")
return self.matching_function(self.clean_dir,self.noisy_dir)
return self.matching_function(self.clean_dir, self.noisy_dir)

View File

@ -1,119 +1,168 @@
from json import load
import wave
from pathlib import Path
from typing import Optional, Union
import numpy as np
from scipy.signal import get_window
from scipy.io import wavfile
from typing import List, Optional, Union
import torch
import torch.nn.functional as F
from pathlib import Path
from librosa import load as load_audio
from scipy.io import wavfile
from scipy.signal import get_window
from enhancer.utils import Audio
class Inference:
"""
contains methods used for inference.
"""
@staticmethod
def read_input(audio, sr, model_sr):
"""
read and verify audio input regardless of the input format.
arguments:
audio : audio input
sr : sampling rate of input audio
model_sr : sampling rate used for model training.
"""
if isinstance(audio,(np.ndarray,torch.Tensor)):
if isinstance(audio, (np.ndarray, torch.Tensor)):
assert sr is not None, "Invalid sampling rate!"
if len(audio.shape) == 1:
audio = audio.reshape(1,-1)
audio = audio.reshape(1, -1)
if isinstance(audio,str):
if isinstance(audio, str):
audio = Path(audio)
if not audio.is_file():
raise ValueError(f"Input file {audio} does not exist")
else:
audio,sr = load_audio(audio,sr=sr,)
audio, sr = load_audio(
audio,
sr=sr,
)
if len(audio.shape) == 1:
audio = audio.reshape(1,-1)
audio = audio.reshape(1, -1)
else:
assert audio.shape[0] == 1, "Enhance inference only supports single waveform"
assert (
audio.shape[0] == 1
), "Enhance inference only supports single waveform"
waveform = Audio.resample_audio(audio,sr=sr,target_sr=model_sr)
waveform = Audio.resample_audio(audio, sr=sr, target_sr=model_sr)
waveform = Audio.convert_mono(waveform)
if isinstance(waveform,np.ndarray):
if isinstance(waveform, np.ndarray):
waveform = torch.from_numpy(waveform)
return waveform
@staticmethod
def batchify(waveform: torch.Tensor, window_size:int, step_size:Optional[int]=None):
def batchify(
waveform: torch.Tensor,
window_size: int,
step_size: Optional[int] = None,
):
"""
break input waveform into samples with duration specified.
break input waveform into samples with duration specified.(Overlap-add)
arguments:
waveform : audio waveform
window_size : window size used for splitting waveform into batches
step_size : step_size used for splitting waveform into batches
"""
assert waveform.ndim == 2, f"Expcted input waveform with 2 dimensions (channels,samples), got {waveform.ndim}"
_,num_samples = waveform.shape
assert (
waveform.ndim == 2
), f"Expcted input waveform with 2 dimensions (channels,samples), got {waveform.ndim}"
_, num_samples = waveform.shape
waveform = waveform.unsqueeze(-1)
step_size = window_size//2 if step_size is None else step_size
step_size = window_size // 2 if step_size is None else step_size
if num_samples >= window_size:
waveform_batch = F.unfold(waveform[None,...], kernel_size=(window_size,1),
stride=(step_size,1), padding=(window_size,0))
waveform_batch = waveform_batch.permute(2,0,1)
waveform_batch = F.unfold(
waveform[None, ...],
kernel_size=(window_size, 1),
stride=(step_size, 1),
padding=(window_size, 0),
)
waveform_batch = waveform_batch.permute(2, 0, 1)
return waveform_batch
@staticmethod
def aggreagate(data:torch.Tensor,window_size:int,total_frames:int,step_size:Optional[int]=None,
window="hanning",):
def aggreagate(
data: torch.Tensor,
window_size: int,
total_frames: int,
step_size: Optional[int] = None,
window="hanning",
):
"""
takes input as tensor outputs aggregated waveform
stitch batched waveform into single waveform. (Overlap-add)
arguments:
data: batched waveform
window_size : window_size used to batch waveform
step_size : step_size used to batch waveform
total_frames : total number of frames present in original waveform
window : type of window used for overlap-add mechanism.
"""
num_chunks,n_channels,num_frames = data.shape
window = get_window(window=window,Nx=data.shape[-1])
num_chunks, n_channels, num_frames = data.shape
window = get_window(window=window, Nx=data.shape[-1])
window = torch.from_numpy(window).to(data.device)
data *= window
step_size = window_size//2 if step_size is None else step_size
step_size = window_size // 2 if step_size is None else step_size
data = data.permute(1, 2, 0)
data = F.fold(
data,
(total_frames, 1),
kernel_size=(window_size, 1),
stride=(step_size, 1),
padding=(window_size, 0),
).squeeze(-1)
data = data.permute(1,2,0)
data = F.fold(data,
(total_frames,1),
kernel_size=(window_size,1),
stride=(step_size,1),
padding=(window_size,0)).squeeze(-1)
return data.reshape(1,n_channels,-1)
return data.reshape(1, n_channels, -1)
@staticmethod
def write_output(waveform:torch.Tensor,filename:Union[str,Path],sr:int):
def write_output(
waveform: torch.Tensor, filename: Union[str, Path], sr: int
):
"""
write audio output as wav file
arguments:
waveform : audio waveform
filename : name of the wave file. Output will be written as cleaned_filename.wav
sr : sampling rate
"""
if isinstance(filename,str):
if isinstance(filename, str):
filename = Path(filename)
parent, name = filename.parent, "cleaned_"+filename.name
filename = parent/Path(name)
parent, name = filename.parent, "cleaned_" + filename.name
filename = parent / Path(name)
if filename.is_file():
raise FileExistsError(f"file {filename} already exists")
else:
if isinstance(waveform,torch.Tensor):
waveform = waveform.detach().cpu().squeeze().numpy()
wavfile.write(filename,rate=sr,data=waveform)
wavfile.write(filename, rate=sr, data=waveform.detach().cpu())
@staticmethod
def prepare_output(waveform:torch.Tensor, model_sampling_rate:int,
audio:Union[str,np.ndarray,torch.Tensor], sampling_rate:Optional[int]
def prepare_output(
waveform: torch.Tensor,
model_sampling_rate: int,
audio: Union[str, np.ndarray, torch.Tensor],
sampling_rate: Optional[int],
):
if isinstance(audio,np.ndarray):
"""
prepare output audio based on input format
arguments:
waveform : predicted audio waveform
model_sampling_rate : sampling rate used to train the model
audio : input audio
sampling_rate : input audio sampling rate
"""
if isinstance(audio, np.ndarray):
waveform = waveform.detach().cpu().numpy()
if sampling_rate!=None:
waveform = Audio.resample_audio(waveform, sr=model_sampling_rate, target_sr=sampling_rate)
if sampling_rate is not None:
waveform = Audio.resample_audio(
waveform, sr=model_sampling_rate, target_sr=sampling_rate
)
return waveform
return waveform

View File

@ -3,62 +3,82 @@ import torch.nn as nn
class mean_squared_error(nn.Module):
"""
Mean squared error / L1 loss
"""
def __init__(self,reduction="mean"):
def __init__(self, reduction="mean"):
super().__init__()
self.loss_fun = nn.MSELoss(reduction=reduction)
self.higher_better = False
def forward(self,prediction:torch.Tensor, target: torch.Tensor):
def forward(self, prediction: torch.Tensor, target: torch.Tensor):
if prediction.size() != target.size() or target.ndim < 3:
raise TypeError(f"""Inputs must be of the same shape (batch_size,channels,samples)
got {prediction.size()} and {target.size()} instead""")
raise TypeError(
f"""Inputs must be of the same shape (batch_size,channels,samples)
got {prediction.size()} and {target.size()} instead"""
)
return self.loss_fun(prediction, target)
class mean_absolute_error(nn.Module):
def __init__(self,reduction="mean"):
class mean_absolute_error(nn.Module):
"""
Mean absolute error / L2 loss
"""
def __init__(self, reduction="mean"):
super().__init__()
self.loss_fun = nn.L1Loss(reduction=reduction)
self.higher_better = False
def forward(self, prediction:torch.Tensor, target: torch.Tensor):
def forward(self, prediction: torch.Tensor, target: torch.Tensor):
if prediction.size() != target.size() or target.ndim < 3:
raise TypeError(f"""Inputs must be of the same shape (batch_size,channels,samples)
got {prediction.size()} and {target.size()} instead""")
raise TypeError(
f"""Inputs must be of the same shape (batch_size,channels,samples)
got {prediction.size()} and {target.size()} instead"""
)
return self.loss_fun(prediction, target)
class Si_SDR(nn.Module):
def __init__(
self,
reduction:str="mean"
):
class Si_SDR(nn.Module):
"""
SI-SDR metric based on SDR HALF-BAKED OR WELL DONE?(https://arxiv.org/pdf/1811.02508.pdf)
"""
def __init__(self, reduction: str = "mean"):
super().__init__()
if reduction in ["sum","mean",None]:
if reduction in ["sum", "mean", None]:
self.reduction = reduction
else:
raise TypeError("Invalid reduction, valid options are sum, mean, None")
raise TypeError(
"Invalid reduction, valid options are sum, mean, None"
)
self.higher_better = False
def forward(self,prediction:torch.Tensor, target:torch.Tensor):
def forward(self, prediction: torch.Tensor, target: torch.Tensor):
if prediction.size() != target.size() or target.ndim < 3:
raise TypeError(f"""Inputs must be of the same shape (batch_size,channels,samples)
got {prediction.size()} and {target.size()} instead""")
target_energy = torch.sum(target**2,keepdim=True,dim=-1)
scaling_factor = torch.sum(prediction*target,keepdim=True,dim=-1) / target_energy
raise TypeError(
f"""Inputs must be of the same shape (batch_size,channels,samples)
got {prediction.size()} and {target.size()} instead"""
)
target_energy = torch.sum(target**2, keepdim=True, dim=-1)
scaling_factor = (
torch.sum(prediction * target, keepdim=True, dim=-1) / target_energy
)
target_projection = target * scaling_factor
noise = prediction - target_projection
ratio = torch.sum(target_projection**2,dim=-1) / torch.sum(noise**2,dim=-1)
si_sdr = 10*torch.log10(ratio).mean(dim=-1)
ratio = torch.sum(target_projection**2, dim=-1) / torch.sum(
noise**2, dim=-1
)
si_sdr = 10 * torch.log10(ratio).mean(dim=-1)
if self.reduction == "sum":
si_sdr = si_sdr.sum()
@ -66,46 +86,55 @@ class Si_SDR(nn.Module):
si_sdr = si_sdr.mean()
else:
pass
return si_sdr
class Avergeloss(nn.Module):
"""
Combine multiple metics of same nature.
for example, ["mea","mae"]
parameters:
losses : loss function names to be combined
"""
def __init__(self,losses):
def __init__(self, losses):
super().__init__()
self.valid_losses = nn.ModuleList()
direction = [getattr(LOSS_MAP[loss](),"higher_better") for loss in losses]
direction = [
getattr(LOSS_MAP[loss](), "higher_better") for loss in losses
]
if len(set(direction)) > 1:
raise ValueError("all cost functions should be of same nature, maximize or minimize!")
raise ValueError(
"all cost functions should be of same nature, maximize or minimize!"
)
self.higher_better = direction[0]
for loss in losses:
loss = self.validate_loss(loss)
self.valid_losses.append(loss())
def validate_loss(self,loss:str):
def validate_loss(self, loss: str):
if loss not in LOSS_MAP.keys():
raise ValueError(f"Invalid loss function {loss}, available loss functions are {tuple([loss for loss in LOSS_MAP.keys()])}")
raise ValueError(
f"""Invalid loss function {loss}, available loss functions are
{tuple([loss for loss in LOSS_MAP.keys()])}"""
)
else:
return LOSS_MAP[loss]
def forward(self,prediction:torch.Tensor, target:torch.Tensor):
def forward(self, prediction: torch.Tensor, target: torch.Tensor):
loss = 0.0
for loss_fun in self.valid_losses:
loss += loss_fun(prediction, target)
return loss
LOSS_MAP = {"mae":mean_absolute_error,
"mse": mean_squared_error,
"SI-SDR":Si_SDR}
LOSS_MAP = {
"mae": mean_absolute_error,
"mse": mean_squared_error,
"SI-SDR": Si_SDR,
}

View File

@ -1,3 +1,3 @@
from enhancer.models.demucs import Demucs
from enhancer.models.model import Model
from enhancer.models.waveunet import WaveUnet
from enhancer.models.model import Model

View File

@ -1,217 +1,264 @@
import logging
from typing import Optional, Union, List
from torch import nn
import torch.nn.functional as F
import math
from typing import List, Optional, Union
import torch.nn.functional as F
from torch import nn
from enhancer.models.model import Model
from enhancer.data.dataset import EnhancerDataset
from enhancer.models.model import Model
from enhancer.utils.io import Audio as audio
from enhancer.utils.utils import merge_dict
class DemucsLSTM(nn.Module):
def __init__(
self,
input_size:int,
hidden_size:int,
num_layers:int,
bidirectional:bool=True
input_size: int,
hidden_size: int,
num_layers: int,
bidirectional: bool = True,
):
super().__init__()
self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bidirectional=bidirectional)
self.lstm = nn.LSTM(
input_size, hidden_size, num_layers, bidirectional=bidirectional
)
dim = 2 if bidirectional else 1
self.linear = nn.Linear(dim*hidden_size,hidden_size)
self.linear = nn.Linear(dim * hidden_size, hidden_size)
def forward(self,x):
def forward(self, x):
output,(h,c) = self.lstm(x)
output, (h, c) = self.lstm(x)
output = self.linear(output)
return output,(h,c)
return output, (h, c)
class DemucsEncoder(nn.Module):
def __init__(
self,
num_channels:int,
hidden_size:int,
kernel_size:int,
stride:int=1,
glu:bool=False,
num_channels: int,
hidden_size: int,
kernel_size: int,
stride: int = 1,
glu: bool = False,
):
super().__init__()
activation = nn.GLU(1) if glu else nn.ReLU()
multi_factor = 2 if glu else 1
self.encoder = nn.Sequential(
nn.Conv1d(num_channels,hidden_size,kernel_size,stride),
nn.Conv1d(num_channels, hidden_size, kernel_size, stride),
nn.ReLU(),
nn.Conv1d(hidden_size, hidden_size*multi_factor,kernel_size,1),
activation
nn.Conv1d(hidden_size, hidden_size * multi_factor, kernel_size, 1),
activation,
)
def forward(self,waveform):
def forward(self, waveform):
return self.encoder(waveform)
class DemucsDecoder(nn.Module):
class DemucsDecoder(nn.Module):
def __init__(
self,
num_channels:int,
hidden_size:int,
kernel_size:int,
stride:int=1,
glu:bool=False,
layer:int=0
num_channels: int,
hidden_size: int,
kernel_size: int,
stride: int = 1,
glu: bool = False,
layer: int = 0,
):
super().__init__()
activation = nn.GLU(1) if glu else nn.ReLU()
multi_factor = 2 if glu else 1
self.decoder = nn.Sequential(
nn.Conv1d(hidden_size,hidden_size*multi_factor,kernel_size,1),
nn.Conv1d(hidden_size, hidden_size * multi_factor, kernel_size, 1),
activation,
nn.ConvTranspose1d(hidden_size,num_channels,kernel_size,stride)
nn.ConvTranspose1d(hidden_size, num_channels, kernel_size, stride),
)
if layer>0:
if layer > 0:
self.decoder.add_module("4", nn.ReLU())
def forward(self,waveform,):
def forward(
self,
waveform,
):
out = self.decoder(waveform)
return out
class Demucs(Model):
"""
Demucs model from https://arxiv.org/pdf/1911.13254.pdf
parameters:
encoder_decoder: dict, optional
keyword arguments passsed to encoder decoder block
lstm : dict, optional
keyword arguments passsed to LSTM block
num_channels: int, defaults to 1
number channels in input audio
sampling_rate: int, defaults to 16KHz
sampling rate of input audio
lr : float, defaults to 1e-3
learning rate used for training
dataset: EnhancerDataset, optional
EnhancerDataset object containing train/validation data for training
duration : float, optional
chunk duration in seconds
loss : string or List of strings
loss function to be used, available ("mse","mae","SI-SDR")
metric : string or List of strings
metric function to be used, available ("mse","mae","SI-SDR")
"""
ED_DEFAULTS = {
"initial_output_channels":48,
"kernel_size":8,
"stride":1,
"depth":5,
"glu":True,
"growth_factor":2,
"initial_output_channels": 48,
"kernel_size": 8,
"stride": 1,
"depth": 5,
"glu": True,
"growth_factor": 2,
}
LSTM_DEFAULTS = {
"bidirectional":True,
"num_layers":2,
"bidirectional": True,
"num_layers": 2,
}
def __init__(
self,
encoder_decoder:Optional[dict]=None,
lstm:Optional[dict]=None,
num_channels:int=1,
resample:int=4,
sampling_rate = 16000,
lr:float=1e-3,
dataset:Optional[EnhancerDataset]=None,
loss:Union[str, List] = "mse",
metric:Union[str, List] = "mse"
encoder_decoder: Optional[dict] = None,
lstm: Optional[dict] = None,
num_channels: int = 1,
resample: int = 4,
sampling_rate=16000,
lr: float = 1e-3,
dataset: Optional[EnhancerDataset] = None,
loss: Union[str, List] = "mse",
metric: Union[str, List] = "mse",
):
duration = dataset.duration if isinstance(dataset,EnhancerDataset) else None
duration = (
dataset.duration if isinstance(dataset, EnhancerDataset) else None
)
if dataset is not None:
if sampling_rate!=dataset.sampling_rate:
logging.warn(f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}")
if sampling_rate != dataset.sampling_rate:
logging.warn(
f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
)
sampling_rate = dataset.sampling_rate
super().__init__(num_channels=num_channels,
sampling_rate=sampling_rate,lr=lr,
dataset=dataset,duration=duration,loss=loss, metric=metric)
encoder_decoder = merge_dict(self.ED_DEFAULTS,encoder_decoder)
lstm = merge_dict(self.LSTM_DEFAULTS,lstm)
self.save_hyperparameters("encoder_decoder","lstm","resample")
super().__init__(
num_channels=num_channels,
sampling_rate=sampling_rate,
lr=lr,
dataset=dataset,
duration=duration,
loss=loss,
metric=metric,
)
encoder_decoder = merge_dict(self.ED_DEFAULTS, encoder_decoder)
lstm = merge_dict(self.LSTM_DEFAULTS, lstm)
self.save_hyperparameters("encoder_decoder", "lstm", "resample")
hidden = encoder_decoder["initial_output_channels"]
self.encoder = nn.ModuleList()
self.decoder = nn.ModuleList()
for layer in range(encoder_decoder["depth"]):
encoder_layer = DemucsEncoder(num_channels=num_channels,
hidden_size=hidden,
kernel_size=encoder_decoder["kernel_size"],
stride=encoder_decoder["stride"],
glu=encoder_decoder["glu"],
)
encoder_layer = DemucsEncoder(
num_channels=num_channels,
hidden_size=hidden,
kernel_size=encoder_decoder["kernel_size"],
stride=encoder_decoder["stride"],
glu=encoder_decoder["glu"],
)
self.encoder.append(encoder_layer)
decoder_layer = DemucsDecoder(num_channels=num_channels,
hidden_size=hidden,
kernel_size=encoder_decoder["kernel_size"],
stride=1,
glu=encoder_decoder["glu"],
layer=layer
)
self.decoder.insert(0,decoder_layer)
decoder_layer = DemucsDecoder(
num_channels=num_channels,
hidden_size=hidden,
kernel_size=encoder_decoder["kernel_size"],
stride=1,
glu=encoder_decoder["glu"],
layer=layer,
)
self.decoder.insert(0, decoder_layer)
num_channels = hidden
hidden = self.ED_DEFAULTS["growth_factor"] * hidden
self.de_lstm = DemucsLSTM(input_size=num_channels,
hidden_size=num_channels,
num_layers=lstm["num_layers"],
bidirectional=lstm["bidirectional"]
)
def forward(self,waveform):
self.de_lstm = DemucsLSTM(
input_size=num_channels,
hidden_size=num_channels,
num_layers=lstm["num_layers"],
bidirectional=lstm["bidirectional"],
)
def forward(self, waveform):
if waveform.dim() == 2:
waveform = waveform.unsqueeze(1)
if waveform.size(1)!=1:
raise TypeError(f"Demucs can only process mono channel audio, input has {waveform.size(1)} channels")
if waveform.size(1) != 1:
raise TypeError(
f"Demucs can only process mono channel audio, input has {waveform.size(1)} channels"
)
length = waveform.shape[-1]
x = F.pad(waveform, (0,self.get_padding_length(length) - length))
if self.hparams.resample>1:
x = audio.resample_audio(audio=x, sr=self.hparams.sampling_rate,
target_sr=int(self.hparams.sampling_rate * self.hparams.resample))
x = F.pad(waveform, (0, self.get_padding_length(length) - length))
if self.hparams.resample > 1:
x = audio.resample_audio(
audio=x,
sr=self.hparams.sampling_rate,
target_sr=int(
self.hparams.sampling_rate * self.hparams.resample
),
)
encoder_outputs = []
for encoder in self.encoder:
x = encoder(x)
encoder_outputs.append(x)
x = x.permute(0,2,1)
x,_ = self.de_lstm(x)
x = x.permute(0, 2, 1)
x, _ = self.de_lstm(x)
x = x.permute(0,2,1)
x = x.permute(0, 2, 1)
for decoder in self.decoder:
skip_connection = encoder_outputs.pop(-1)
x += skip_connection[..., :x.shape[-1]]
x += skip_connection[..., : x.shape[-1]]
x = decoder(x)
if self.hparams.resample > 1:
x = audio.resample_audio(x,int(self.hparams.sampling_rate * self.hparams.resample),
self.hparams.sampling_rate)
x = audio.resample_audio(
x,
int(self.hparams.sampling_rate * self.hparams.resample),
self.hparams.sampling_rate,
)
return x
def get_padding_length(self,input_length):
def get_padding_length(self, input_length):
input_length = math.ceil(input_length * self.hparams.resample)
for layer in range(self.hparams.encoder_decoder["depth"]): # encoder operation
input_length = math.ceil((input_length - self.hparams.encoder_decoder["kernel_size"])/self.hparams.encoder_decoder["stride"])+1
input_length = max(1,input_length)
for layer in range(self.hparams.encoder_decoder["depth"]): # decoder operaration
input_length = (input_length-1) * self.hparams.encoder_decoder["stride"] + self.hparams.encoder_decoder["kernel_size"]
input_length = math.ceil(input_length/self.hparams.resample)
for layer in range(
self.hparams.encoder_decoder["depth"]
): # encoder operation
input_length = (
math.ceil(
(input_length - self.hparams.encoder_decoder["kernel_size"])
/ self.hparams.encoder_decoder["stride"]
)
+ 1
)
input_length = max(1, input_length)
for layer in range(
self.hparams.encoder_decoder["depth"]
): # decoder operaration
input_length = (input_length - 1) * self.hparams.encoder_decoder[
"stride"
] + self.hparams.encoder_decoder["kernel_size"]
input_length = math.ceil(input_length / self.hparams.resample)
return int(input_length)

View File

@ -1,92 +1,108 @@
try:
from functools import cached_property
except ImportError:
from backports.cached_property import cached_property
from importlib import import_module
from huggingface_hub import cached_download, hf_hub_url
import logging
import numpy as np
import os
from typing import Optional, Union, List, Text, Dict, Any
from torch.optim import Adam
import torch
from torch.nn.functional import pad
import pytorch_lightning as pl
from pytorch_lightning.utilities.cloud_io import load as pl_load
from urllib.parse import urlparse
from importlib import import_module
from pathlib import Path
from typing import Any, Dict, List, Optional, Text, Union
from urllib.parse import urlparse
import numpy as np
import pytorch_lightning as pl
import torch
from huggingface_hub import cached_download, hf_hub_url
from pytorch_lightning.utilities.cloud_io import load as pl_load
from torch.optim import Adam
from enhancer import __version__
from enhancer.data.dataset import EnhancerDataset
from enhancer.utils.io import Audio
from enhancer.loss import Avergeloss
from enhancer.inference import Inference
from enhancer.loss import Avergeloss
CACHE_DIR = ""
HF_TORCH_WEIGHTS = ""
DEFAULT_DEVICE = "cpu"
class Model(pl.LightningModule):
"""
Base class for all models
parameters:
num_channels: int, default to 1
number of channels in input audio
sampling_rate : int, default 16khz
audio sampling rate
lr: float, optional
learning rate for model training
dataset: EnhancerDataset, optional
Enhancer dataset used for training/validation
duration: float, optional
duration used for training/inference
loss : string or List of strings, default to "mse"
loss functions to be used. Available ("mse","mae","Si-SDR")
"""
def __init__(
self,
num_channels:int=1,
sampling_rate:int=16000,
lr:float=1e-3,
dataset:Optional[EnhancerDataset]=None,
duration:Optional[float]=None,
num_channels: int = 1,
sampling_rate: int = 16000,
lr: float = 1e-3,
dataset: Optional[EnhancerDataset] = None,
duration: Optional[float] = None,
loss: Union[str, List] = "mse",
metric:Union[str,List] = "mse"
metric: Union[str, List] = "mse",
):
super().__init__()
assert num_channels ==1 , "Enhancer only support for mono channel models"
assert (
num_channels == 1
), "Enhancer only support for mono channel models"
self.dataset = dataset
self.save_hyperparameters("num_channels","sampling_rate","lr","loss","metric","duration")
self.save_hyperparameters(
"num_channels", "sampling_rate", "lr", "loss", "metric", "duration"
)
if self.logger:
self.logger.experiment.log_dict(dict(self.hparams),"hyperparameters.json")
self.logger.experiment.log_dict(
dict(self.hparams), "hyperparameters.json"
)
self.loss = loss
self.metric = metric
@property
def loss(self):
return self._loss
@loss.setter
def loss(self,loss):
if isinstance(loss,str):
losses = [loss]
@loss.setter
def loss(self, loss):
if isinstance(loss, str):
losses = [loss]
self._loss = Avergeloss(losses)
@property
def metric(self):
return self._metric
@metric.setter
def metric(self,metric):
def metric(self, metric):
if isinstance(metric, str):
metric = [metric]
if isinstance(metric,str):
metric = [metric]
self._metric = Avergeloss(metric)
@property
def dataset(self):
return self._dataset
@dataset.setter
def dataset(self,dataset):
def dataset(self, dataset):
self._dataset = dataset
def setup(self,stage:Optional[str]=None):
def setup(self, stage: Optional[str] = None):
if stage == "fit":
self.dataset.setup(stage)
self.dataset.model = self
def train_dataloader(self):
return self.dataset.train_dataloader()
@ -94,9 +110,9 @@ class Model(pl.LightningModule):
return self.dataset.val_dataloader()
def configure_optimizers(self):
return Adam(self.parameters(), lr = self.hparams.lr)
return Adam(self.parameters(), lr=self.hparams.lr)
def training_step(self,batch, batch_idx:int):
def training_step(self, batch, batch_idx: int):
mixed_waveform = batch["noisy"]
target = batch["clean"]
@ -105,13 +121,16 @@ class Model(pl.LightningModule):
loss = self.loss(prediction, target)
if self.logger:
self.logger.experiment.log_metric(run_id=self.logger.run_id,
key="train_loss", value=loss.item(),
step=self.global_step)
self.log("train_loss",loss.item())
return {"loss":loss}
self.logger.experiment.log_metric(
run_id=self.logger.run_id,
key="train_loss",
value=loss.item(),
step=self.global_step,
)
self.log("train_loss", loss.item())
return {"loss": loss}
def validation_step(self,batch,batch_idx:int):
def validation_step(self, batch, batch_idx: int):
mixed_waveform = batch["noisy"]
target = batch["clean"]
@ -119,48 +138,92 @@ class Model(pl.LightningModule):
metric_val = self.metric(prediction, target)
loss_val = self.loss(prediction, target)
self.log("val_metric",metric_val.item())
self.log("val_loss",loss_val.item())
self.log("val_metric", metric_val.item())
self.log("val_loss", loss_val.item())
if self.logger:
self.logger.experiment.log_metric(run_id=self.logger.run_id,
key="val_loss",value=loss_val.item(),
step=self.global_step)
self.logger.experiment.log_metric(run_id=self.logger.run_id,
key="val_metric",value=metric_val.item(),
step=self.global_step)
self.logger.experiment.log_metric(
run_id=self.logger.run_id,
key="val_loss",
value=loss_val.item(),
step=self.global_step,
)
self.logger.experiment.log_metric(
run_id=self.logger.run_id,
key="val_metric",
value=metric_val.item(),
step=self.global_step,
)
return {"loss":loss_val}
return {"loss": loss_val}
def on_save_checkpoint(self, checkpoint):
checkpoint["enhancer"] = {
"version": {
"enhancer":__version__,
"pytorch":torch.__version__
"version": {"enhancer": __version__, "pytorch": torch.__version__},
"architecture": {
"module": self.__class__.__module__,
"class": self.__class__.__name__,
},
"architecture":{
"module":self.__class__.__module__,
"class":self.__class__.__name__
}
}
def on_load_checkpoint(self, checkpoint: Dict[str, Any]):
pass
@classmethod
def from_pretrained(
cls,
checkpoint: Union[Path, Text],
map_location = None,
map_location=None,
hparams_file: Union[Path, Text] = None,
strict: bool = True,
use_auth_token: Union[Text, None] = None,
cached_dir: Union[Path, Text]=CACHE_DIR,
**kwargs
cached_dir: Union[Path, Text] = CACHE_DIR,
**kwargs,
):
"""
Load Pretrained model
parameters:
checkpoint : Path or str
Path to checkpoint, or a remote URL, or a model identifier from
the huggingface.co model hub.
map_location: optional
Same role as in torch.load().
Defaults to `lambda storage, loc: storage`.
hparams_file : Path or str, optional
Path to a .yaml file with hierarchical structure as in this example:
drop_prob: 0.2
dataloader:
batch_size: 32
You most likely wont need this since Lightning will always save the
hyperparameters to the checkpoint. However, if your checkpoint weights
do not have the hyperparameters saved, use this method to pass in a .yaml
file with the hparams you would like to use. These will be converted
into a dict and passed into your Model for use.
strict : bool, optional
Whether to strictly enforce that the keys in checkpoint match
the keys returned by this modules state dict. Defaults to True.
use_auth_token : str, optional
When loading a private huggingface.co model, set `use_auth_token`
to True or to a string containing your hugginface.co authentication
token that can be obtained by running `huggingface-cli login`
cache_dir: Path or str, optional
Path to model cache directory. Defaults to content of PYANNOTE_CACHE
environment variable, or "~/.cache/torch/pyannote" when unset.
kwargs: optional
Any extra keyword args needed to init the model.
Can also be used to override saved hyperparameter values.
Returns
-------
model : Model
Model
See also
--------
torch.load
"""
checkpoint = str(checkpoint)
if hparams_file is not None:
@ -168,104 +231,133 @@ class Model(pl.LightningModule):
if os.path.isfile(checkpoint):
model_path_pl = checkpoint
elif urlparse(checkpoint).scheme in ("http","https"):
elif urlparse(checkpoint).scheme in ("http", "https"):
model_path_pl = checkpoint
else:
if "@" in checkpoint:
model_id = checkpoint.split("@")[0]
revision_id = checkpoint.split("@")[1]
else:
model_id = checkpoint
revision_id = None
url = hf_hub_url(
model_id,filename=HF_TORCH_WEIGHTS,revision=revision_id
model_id, filename=HF_TORCH_WEIGHTS, revision=revision_id
)
model_path_pl = cached_download(
url=url,library_name="enhancer",library_version=__version__,
cache_dir=cached_dir,use_auth_token=use_auth_token
url=url,
library_name="enhancer",
library_version=__version__,
cache_dir=cached_dir,
use_auth_token=use_auth_token,
)
if map_location is None:
map_location = torch.device(DEFAULT_DEVICE)
loaded_checkpoint = pl_load(model_path_pl,map_location)
loaded_checkpoint = pl_load(model_path_pl, map_location)
module_name = loaded_checkpoint["enhancer"]["architecture"]["module"]
class_name = loaded_checkpoint["enhancer"]["architecture"]["class"]
class_name = loaded_checkpoint["enhancer"]["architecture"]["class"]
module = import_module(module_name)
Klass = getattr(module, class_name)
try:
model = Klass.load_from_checkpoint(
checkpoint_path = model_path_pl,
map_location = map_location,
hparams_file = hparams_file,
strict = strict,
**kwargs
checkpoint_path=model_path_pl,
map_location=map_location,
hparams_file=hparams_file,
strict=strict,
**kwargs,
)
except Exception as e:
print(e)
return model
return model
def infer(self, batch: torch.Tensor, batch_size: int = 32):
"""
perform model inference
parameters:
batch : torch.Tensor
input data
batch_size : int, default 32
batch size for inference
"""
def infer(self,batch:torch.Tensor,batch_size:int=32):
assert batch.ndim == 3, f"Expected batch with 3 dimensions (batch,channels,samples) got only {batch.ndim}"
assert (
batch.ndim == 3
), f"Expected batch with 3 dimensions (batch,channels,samples) got only {batch.ndim}"
batch_predictions = []
self.eval().to(self.device)
with torch.no_grad():
for batch_id in range(0,batch.shape[0],batch_size):
batch_data = batch[batch_id:batch_id+batch_size,:,:].to(self.device)
for batch_id in range(0, batch.shape[0], batch_size):
batch_data = batch[batch_id : batch_id + batch_size, :, :].to(
self.device
)
prediction = self(batch_data)
batch_predictions.append(prediction)
return torch.vstack(batch_predictions)
def enhance(
self,
audio:Union[Path,np.ndarray,torch.Tensor],
sampling_rate:Optional[int]=None,
batch_size:int=32,
save_output:bool=False,
duration:Optional[int]=None,
step_size:Optional[int]=None,):
audio: Union[Path, np.ndarray, torch.Tensor],
sampling_rate: Optional[int] = None,
batch_size: int = 32,
save_output: bool = False,
duration: Optional[int] = None,
step_size: Optional[int] = None,
):
"""
Enhance audio using loaded pretained model.
parameters:
audio: Path to audio file or numpy array or torch tensor
single input audio
sampling_rate: int, optional incase input is path
sampling rate of input
batch_size: int, default 32
input audio is split into multiple chunks. Inference is done on batches
of these chunks according to given batch size.
save_output : bool, default False
weather to save output to file
duration : float, optional
chunk duration in seconds, defaults to duration of loaded pretrained model.
step_size: int, optional
step size between consecutive durations, defaults to 50% of duration
"""
model_sampling_rate = self.hparams["sampling_rate"]
if duration is None:
duration = self.hparams["duration"]
waveform = Inference.read_input(audio,sampling_rate,model_sampling_rate)
waveform = Inference.read_input(
audio, sampling_rate, model_sampling_rate
)
waveform.to(self.device)
window_size = round(duration * model_sampling_rate)
batched_waveform = Inference.batchify(waveform,window_size,step_size=step_size)
batch_prediction = self.infer(batched_waveform,batch_size=batch_size)
waveform = Inference.aggreagate(batch_prediction,window_size,waveform.shape[-1],step_size,)
if save_output and isinstance(audio,(str,Path)):
Inference.write_output(waveform,audio,model_sampling_rate)
batched_waveform = Inference.batchify(
waveform, window_size, step_size=step_size
)
batch_prediction = self.infer(batched_waveform, batch_size=batch_size)
waveform = Inference.aggreagate(
batch_prediction,
window_size,
waveform.shape[-1],
step_size,
)
if save_output and isinstance(audio, (str, Path)):
Inference.write_output(waveform, audio, model_sampling_rate)
else:
waveform = Inference.prepare_output(waveform, model_sampling_rate,
audio, sampling_rate)
waveform = Inference.prepare_output(
waveform, model_sampling_rate, audio, sampling_rate
)
return waveform
@property
def valid_monitor(self):
return "max" if self.loss.higher_better else "min"
return "max" if self.loss.higher_better else "min"

View File

@ -1,82 +1,124 @@
import logging
from typing import List, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Union, List
from enhancer.models.model import Model
from enhancer.data.dataset import EnhancerDataset
from enhancer.models.model import Model
class WavenetDecoder(nn.Module):
def __init__(
self,
in_channels:int,
out_channels:int,
kernel_size:int=5,
padding:int=2,
stride:int=1,
dilation:int=1,
in_channels: int,
out_channels: int,
kernel_size: int = 5,
padding: int = 2,
stride: int = 1,
dilation: int = 1,
):
super(WavenetDecoder,self).__init__()
super(WavenetDecoder, self).__init__()
self.decoder = nn.Sequential(
nn.Conv1d(in_channels,out_channels,kernel_size,stride=stride,padding=padding,dilation=dilation),
nn.Conv1d(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
),
nn.BatchNorm1d(out_channels),
nn.LeakyReLU(negative_slope=0.1)
nn.LeakyReLU(negative_slope=0.1),
)
def forward(self,waveform):
def forward(self, waveform):
return self.decoder(waveform)
class WavenetEncoder(nn.Module):
class WavenetEncoder(nn.Module):
def __init__(
self,
in_channels:int,
out_channels:int,
kernel_size:int=15,
padding:int=7,
stride:int=1,
dilation:int=1,
in_channels: int,
out_channels: int,
kernel_size: int = 15,
padding: int = 7,
stride: int = 1,
dilation: int = 1,
):
super(WavenetEncoder,self).__init__()
super(WavenetEncoder, self).__init__()
self.encoder = nn.Sequential(
nn.Conv1d(in_channels,out_channels,kernel_size,stride=stride,padding=padding,dilation=dilation),
nn.Conv1d(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
),
nn.BatchNorm1d(out_channels),
nn.LeakyReLU(negative_slope=0.1)
nn.LeakyReLU(negative_slope=0.1),
)
def forward(
self,
waveform
):
def forward(self, waveform):
return self.encoder(waveform)
class WaveUnet(Model):
"""
Wave-U-Net model from https://arxiv.org/pdf/1811.11307.pdf
parameters:
num_channels: int, defaults to 1
number of channels in input audio
depth : int, defaults to 12
depth of network
initial_output_channels: int, defaults to 24
number of output channels in intial upsampling layer
sampling_rate: int, defaults to 16KHz
sampling rate of input audio
lr : float, defaults to 1e-3
learning rate used for training
dataset: EnhancerDataset, optional
EnhancerDataset object containing train/validation data for training
duration : float, optional
chunk duration in seconds
loss : string or List of strings
loss function to be used, available ("mse","mae","SI-SDR")
metric : string or List of strings
metric function to be used, available ("mse","mae","SI-SDR")
"""
def __init__(
self,
num_channels:int=1,
depth:int=12,
initial_output_channels:int=24,
sampling_rate:int=16000,
lr:float=1e-3,
dataset:Optional[EnhancerDataset]=None,
duration:Optional[float]=None,
num_channels: int = 1,
depth: int = 12,
initial_output_channels: int = 24,
sampling_rate: int = 16000,
lr: float = 1e-3,
dataset: Optional[EnhancerDataset] = None,
duration: Optional[float] = None,
loss: Union[str, List] = "mse",
metric:Union[str,List] = "mse"
metric: Union[str, List] = "mse",
):
duration = dataset.duration if isinstance(dataset,EnhancerDataset) else None
duration = (
dataset.duration if isinstance(dataset, EnhancerDataset) else None
)
if dataset is not None:
if sampling_rate!=dataset.sampling_rate:
logging.warn(f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}")
if sampling_rate != dataset.sampling_rate:
logging.warn(
f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
)
sampling_rate = dataset.sampling_rate
super().__init__(num_channels=num_channels,
sampling_rate=sampling_rate,lr=lr,
dataset=dataset,duration=duration,loss=loss, metric=metric
super().__init__(
num_channels=num_channels,
sampling_rate=sampling_rate,
lr=lr,
dataset=dataset,
duration=duration,
loss=loss,
metric=metric,
)
self.save_hyperparameters("depth")
self.encoders = nn.ModuleList()
@ -84,72 +126,76 @@ class WaveUnet(Model):
out_channels = initial_output_channels
for layer in range(depth):
encoder = WavenetEncoder(num_channels,out_channels)
encoder = WavenetEncoder(num_channels, out_channels)
self.encoders.append(encoder)
num_channels = out_channels
out_channels += initial_output_channels
if layer == depth -1 :
decoder = WavenetDecoder(depth * initial_output_channels + num_channels,num_channels)
if layer == depth - 1:
decoder = WavenetDecoder(
depth * initial_output_channels + num_channels, num_channels
)
else:
decoder = WavenetDecoder(num_channels+out_channels,num_channels)
decoder = WavenetDecoder(
num_channels + out_channels, num_channels
)
self.decoders.insert(0,decoder)
self.decoders.insert(0, decoder)
bottleneck_dim = depth * initial_output_channels
self.bottleneck = nn.Sequential(
nn.Conv1d(bottleneck_dim,bottleneck_dim, 15, stride=1,
padding=7),
nn.Conv1d(bottleneck_dim, bottleneck_dim, 15, stride=1, padding=7),
nn.BatchNorm1d(bottleneck_dim),
nn.LeakyReLU(negative_slope=0.1, inplace=True)
nn.LeakyReLU(negative_slope=0.1, inplace=True),
)
self.final = nn.Sequential(
nn.Conv1d(1 + initial_output_channels, 1, kernel_size=1, stride=1),
nn.Tanh()
nn.Tanh(),
)
def forward(
self,waveform
):
def forward(self, waveform):
if waveform.dim() == 2:
waveform = waveform.unsqueeze(1)
if waveform.size(1)!=1:
raise TypeError(f"Wave-U-Net can only process mono channel audio, input has {waveform.size(1)} channels")
if waveform.size(1) != 1:
raise TypeError(
f"Wave-U-Net can only process mono channel audio, input has {waveform.size(1)} channels"
)
encoder_outputs = []
out = waveform
for encoder in self.encoders:
out = encoder(out)
encoder_outputs.insert(0,out)
out = out[:,:,::2]
encoder_outputs.insert(0, out)
out = out[:, :, ::2]
out = self.bottleneck(out)
for layer,decoder in enumerate(self.decoders):
for layer, decoder in enumerate(self.decoders):
out = F.interpolate(out, scale_factor=2, mode="linear")
out = self.fix_last_dim(out,encoder_outputs[layer])
out = torch.cat([out,encoder_outputs[layer]],dim=1)
out = self.fix_last_dim(out, encoder_outputs[layer])
out = torch.cat([out, encoder_outputs[layer]], dim=1)
out = decoder(out)
out = torch.cat([out, waveform],dim=1)
out = torch.cat([out, waveform], dim=1)
out = self.final(out)
return out
def fix_last_dim(self,x,target):
def fix_last_dim(self, x, target):
"""
trying to do centre crop along last dimension
centre crop along last dimension
"""
assert x.shape[-1] >= target.shape[-1], "input dimension cannot be larger than target dimension"
assert (
x.shape[-1] >= target.shape[-1]
), "input dimension cannot be larger than target dimension"
if x.shape[-1] == target.shape[-1]:
return x
diff = x.shape[-1] - target.shape[-1]
if diff%2!=0:
x = F.pad(x,(0,1))
if diff % 2 != 0:
x = F.pad(x, (0, 1))
diff += 1
crop = diff//2
return x[:,:,crop:-crop]
crop = diff // 2
return x[:, :, crop:-crop]

View File

@ -1,3 +1,3 @@
from enhancer.utils.utils import check_files
from enhancer.utils.config import Files
from enhancer.utils.io import Audio
from enhancer.utils.config import Files
from enhancer.utils.utils import check_files

View File

@ -1,10 +1,9 @@
from dataclasses import dataclass
@dataclass
class Files:
train_clean : str
train_noisy : str
test_clean : str
test_noisy : str
train_clean: str
train_noisy: str
test_clean: str
test_noisy: str

View File

@ -1,41 +1,67 @@
import os
from pathlib import Path
from typing import Optional, Union
import librosa
from typing import Optional
import numpy as np
import torch
import torchaudio
class Audio:
"""
Audio utils
parameters:
sampling_rate : int, defaults to 16KHz
audio sampling rate
mono: bool, defaults to True
return_tensors: bool, defaults to True
returns torch tensor type if set to True else numpy ndarray
"""
def __init__(
self,
sampling_rate:int=16000,
mono:bool=True,
return_tensor=True
self, sampling_rate: int = 16000, mono: bool = True, return_tensor=True
) -> None:
self.sampling_rate = sampling_rate
self.mono = mono
self.return_tensor = return_tensor
def __call__(
self,
audio,
sampling_rate:Optional[int]=None,
offset:Optional[float] = None,
duration:Optional[float] = None
audio: Union[Path, np.ndarray, torch.Tensor],
sampling_rate: Optional[int] = None,
offset: Optional[float] = None,
duration: Optional[float] = None,
):
if isinstance(audio,str):
"""
read and process input audio
parameters:
audio: Path to audio file or numpy array or torch tensor
single input audio
sampling_rate : int, optional
sampling rate of the audio input
offset: float, optional
offset from which the audio must be read, reads from beginning if unused.
duration: float (seconds), optional
read duration, reads full audio starting from offset if not used
"""
if isinstance(audio, str):
if os.path.exists(audio):
audio,sampling_rate = librosa.load(audio,sr=sampling_rate,mono=False,
offset=offset,duration=duration)
audio, sampling_rate = librosa.load(
audio,
sr=sampling_rate,
mono=False,
offset=offset,
duration=duration,
)
if len(audio.shape) == 1:
audio = audio.reshape(1,-1)
audio = audio.reshape(1, -1)
else:
raise FileNotFoundError(f"File {audio} deos not exist")
elif isinstance(audio,np.ndarray):
elif isinstance(audio, np.ndarray):
if len(audio.shape) == 1:
audio = audio.reshape(1,-1)
audio = audio.reshape(1, -1)
else:
raise ValueError("audio should be either filepath or numpy ndarray")
@ -43,40 +69,60 @@ class Audio:
audio = self.convert_mono(audio)
if sampling_rate:
audio = self.__class__.resample_audio(audio,self.sampling_rate,sampling_rate)
audio = self.__class__.resample_audio(
audio, self.sampling_rate, sampling_rate
)
if self.return_tensor:
return torch.tensor(audio)
else:
return audio
@staticmethod
def convert_mono(
audio
def convert_mono(audio: Union[np.ndarray, torch.Tensor]):
"""
convert input audio into mono (1)
parameters:
audio: np.ndarray or torch.Tensor
"""
if len(audio.shape) > 2:
assert (
audio.shape[0] == 1
), "convert mono only accepts single waveform"
audio = audio.reshape(audio.shape[1], audio.shape[2])
):
if len(audio.shape)>2:
assert audio.shape[0] == 1, "convert mono only accepts single waveform"
audio = audio.reshape(audio.shape[1],audio.shape[2])
assert audio.shape[1] >> audio.shape[0], f"expected input format (num_channels,num_samples) got {audio.shape}"
num_channels,num_samples = audio.shape
if num_channels>1:
return audio.mean(axis=0).reshape(1,num_samples)
assert (
audio.shape[1] >> audio.shape[0]
), f"expected input format (num_channels,num_samples) got {audio.shape}"
num_channels, num_samples = audio.shape
if num_channels > 1:
return audio.mean(axis=0).reshape(1, num_samples)
return audio
@staticmethod
def resample_audio(
audio,
sr:int,
target_sr:int
audio: Union[np.ndarray, torch.Tensor], sr: int, target_sr: int
):
if sr!=target_sr:
if isinstance(audio,np.ndarray):
audio = librosa.resample(audio,orig_sr=sr,target_sr=target_sr)
elif isinstance(audio,torch.Tensor):
audio = torchaudio.functional.resample(audio,orig_freq=sr,new_freq=target_sr)
"""
resample audio to desired sampling rate
parameters:
audio : Path to audio file or numpy array or torch tensor
audio waveform
sr : int
current sampling rate
target_sr : int
target sampling rate
"""
if sr != target_sr:
if isinstance(audio, np.ndarray):
audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sr)
elif isinstance(audio, torch.Tensor):
audio = torchaudio.functional.resample(
audio, orig_freq=sr, new_freq=target_sr
)
else:
raise ValueError("Input should be either numpy array or torch tensor")
raise ValueError(
"Input should be either numpy array or torch tensor"
)
return audio

View File

@ -1,19 +1,19 @@
import os
import random
import torch
def create_unique_rng(epoch:int):
def create_unique_rng(epoch: int):
"""create unique random number generator for each (worker_id,epoch) combination"""
rng = random.Random()
global_seed = int(os.environ.get("PL_GLOBAL_SEED","0"))
global_rank = int(os.environ.get('GLOBAL_RANK',"0"))
local_rank = int(os.environ.get('LOCAL_RANK',"0"))
node_rank = int(os.environ.get('NODE_RANK',"0"))
world_size = int(os.environ.get('WORLD_SIZE',"0"))
global_seed = int(os.environ.get("PL_GLOBAL_SEED", "0"))
global_rank = int(os.environ.get("GLOBAL_RANK", "0"))
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
node_rank = int(os.environ.get("NODE_RANK", "0"))
world_size = int(os.environ.get("WORLD_SIZE", "0"))
worker_info = torch.utils.data.get_worker_info()
if worker_info is not None:
@ -24,17 +24,13 @@ def create_unique_rng(epoch:int):
worker_id = 0
seed = (
global_seed
+ worker_id
+ local_rank * num_workers
+ node_rank * num_workers * global_rank
+ epoch * num_workers * world_size
)
global_seed
+ worker_id
+ local_rank * num_workers
+ node_rank * num_workers * global_rank
+ epoch * num_workers * world_size
)
rng.seed(seed)
return rng

View File

@ -1,19 +1,26 @@
import os
from typing import Optional
from enhancer.utils.config import Files
def check_files(root_dir:str, files:Files):
path_variables = [member_var for member_var in dir(files) if not member_var.startswith('__')]
def check_files(root_dir: str, files: Files):
path_variables = [
member_var
for member_var in dir(files)
if not member_var.startswith("__")
]
for variable in path_variables:
path = getattr(files,variable)
if not os.path.isdir(os.path.join(root_dir,path)):
path = getattr(files, variable)
if not os.path.isdir(os.path.join(root_dir, path)):
raise ValueError(f"Invalid {path}, is not a directory")
return files,root_dir
def merge_dict(default_dict:dict, custom:Optional[dict]=None):
return files, root_dir
def merge_dict(default_dict: dict, custom: Optional[dict] = None):
params = dict(default_dict)
if custom:
params.update(custom)

View File

@ -5,4 +5,4 @@ dependencies:
- python=3.8
- pip:
- -r requirements.txt
- --find-links https://download.pytorch.org/whl/cu113/torch_stable.html
- --find-links https://download.pytorch.org/whl/cu113/torch_stable.html

View File

@ -33,7 +33,7 @@ mkdir temp
pwd
#python transcriber/tasks/embeddings/timit.py --directory /scratch/$USER/TIMIT/data/lisa/data/timit/raw/TIMIT/TRAIN --output ./data/train
#python transcriber/tasks/embeddings/timit.py --directory /scratch/$USER/TIMIT/data/lisa/data/timit/raw/TIMIT/TEST --output ./data/test
#python transcriber/tasks/embeddings/timit.py --directory /scratch/$USER/TIMIT/data/lisa/data/timit/raw/TIMIT/TEST --output ./data/test
echo "Start Training..."
python cli/train.py

15
pyproject.toml Normal file
View File

@ -0,0 +1,15 @@
[tool.black]
line-length = 80
target-version = ['py38']
exclude = '''
(
/(
\.eggs # exclude a few common directories in the
| \.git # root of the project
| \.mypy_cache
| \.tox
| \.venv
)/
)
'''

View File

@ -1,15 +1,16 @@
joblib==1.1.0
numpy==1.19.5
librosa==0.9.1
numpy==1.19.5
hydra-core==1.2.0
scikit-learn==0.24.2
scipy==1.5.4
torch==1.10.2
tqdm==4.64.0
mlflow==1.23.1
protobuf==3.19.3
boto3==1.23.9
torchaudio==0.10.2
huggingface-hub==0.4.0
pytorch-lightning==1.5.10
black>=22.8.0
boto3>=1.24.86
flake8>=5.0.4
huggingface-hu>=0.10.0
hydra-core>=1.2.0
joblib>=1.2.0
librosa>=0.9.2
mlflow>=1.29.0
numpy>=1.23.3
protobuf>=3.19.6
pytorch-lightning>=1.7.7
scikit-learn>=1.1.2
scipy>=1.9.1
torch>=1.12.1
torchaudio>=0.12.1
tqdm>=4.64.1

View File

@ -10,4 +10,4 @@ conda env create -f environment.yml || conda env update -f environment.yml
source activate enhancer
echo "copying files"
# cp /scratch/$USER/TIMIT/.* /deep-transcriber
# cp /scratch/$USER/TIMIT/.* /deep-transcriber

View File

@ -1,31 +1,32 @@
from asyncio import base_tasks
import torch
import pytest
import torch
from enhancer.loss import mean_absolute_error, mean_squared_error
loss_functions = [mean_absolute_error(), mean_squared_error()]
def check_loss_shapes_compatibility(loss_fun):
batch_size = 4
shape = (1,1000)
loss_fun(torch.rand(batch_size,*shape),torch.rand(batch_size,*shape))
shape = (1, 1000)
loss_fun(torch.rand(batch_size, *shape), torch.rand(batch_size, *shape))
with pytest.raises(TypeError):
loss_fun(torch.rand(4,*shape),torch.rand(6,*shape))
loss_fun(torch.rand(4, *shape), torch.rand(6, *shape))
@pytest.mark.parametrize("loss",loss_functions)
@pytest.mark.parametrize("loss", loss_functions)
def test_loss_input_shapes(loss):
check_loss_shapes_compatibility(loss)
@pytest.mark.parametrize("loss",loss_functions)
@pytest.mark.parametrize("loss", loss_functions)
def test_loss_output_type(loss):
batch_size = 4
prediction, target = torch.rand(batch_size,1,1000),torch.rand(batch_size,1,1000)
prediction, target = torch.rand(batch_size, 1, 1000), torch.rand(
batch_size, 1, 1000
)
loss_value = loss(prediction, target)
assert isinstance(loss_value.item(),float)
assert isinstance(loss_value.item(), float)

View File

@ -1,46 +1,43 @@
import pytest
import torch
from enhancer import data
from enhancer.utils.config import Files
from enhancer.models import Demucs
from enhancer.data.dataset import EnhancerDataset
from enhancer.models import Demucs
from enhancer.utils.config import Files
@pytest.fixture
def vctk_dataset():
root_dir = "tests/data/vctk"
files = Files(train_clean="clean_testset_wav",train_noisy="noisy_testset_wav",
test_clean="clean_testset_wav", test_noisy="noisy_testset_wav")
dataset = EnhancerDataset(name="vctk",root_dir=root_dir,files=files)
files = Files(
train_clean="clean_testset_wav",
train_noisy="noisy_testset_wav",
test_clean="clean_testset_wav",
test_noisy="noisy_testset_wav",
)
dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files)
return dataset
@pytest.mark.parametrize("batch_size,samples",[(1,1000)])
def test_forward(batch_size,samples):
@pytest.mark.parametrize("batch_size,samples", [(1, 1000)])
def test_forward(batch_size, samples):
model = Demucs()
model.eval()
data = torch.rand(batch_size,1,samples,requires_grad=False)
data = torch.rand(batch_size, 1, samples, requires_grad=False)
with torch.no_grad():
_ = model(data)
data = torch.rand(batch_size,2,samples,requires_grad=False)
data = torch.rand(batch_size, 2, samples, requires_grad=False)
with torch.no_grad():
with pytest.raises(TypeError):
_ = model(data)
@pytest.mark.parametrize("dataset,channels,loss",
[(pytest.lazy_fixture("vctk_dataset"),1,["mae","mse"])])
def test_demucs_init(dataset,channels,loss):
@pytest.mark.parametrize(
"dataset,channels,loss",
[(pytest.lazy_fixture("vctk_dataset"), 1, ["mae", "mse"])],
)
def test_demucs_init(dataset, channels, loss):
with torch.no_grad():
model = Demucs(num_channels=channels,dataset=dataset,loss=loss)
_ = Demucs(num_channels=channels, dataset=dataset, loss=loss)

View File

@ -1,46 +1,43 @@
import pytest
import torch
from enhancer import data
from enhancer.utils.config import Files
from enhancer.models import WaveUnet
from enhancer.data.dataset import EnhancerDataset
from enhancer.models import WaveUnet
from enhancer.utils.config import Files
@pytest.fixture
def vctk_dataset():
root_dir = "tests/data/vctk"
files = Files(train_clean="clean_testset_wav",train_noisy="noisy_testset_wav",
test_clean="clean_testset_wav", test_noisy="noisy_testset_wav")
dataset = EnhancerDataset(name="vctk",root_dir=root_dir,files=files)
files = Files(
train_clean="clean_testset_wav",
train_noisy="noisy_testset_wav",
test_clean="clean_testset_wav",
test_noisy="noisy_testset_wav",
)
dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files)
return dataset
@pytest.mark.parametrize("batch_size,samples",[(1,1000)])
def test_forward(batch_size,samples):
@pytest.mark.parametrize("batch_size,samples", [(1, 1000)])
def test_forward(batch_size, samples):
model = WaveUnet()
model.eval()
data = torch.rand(batch_size,1,samples,requires_grad=False)
data = torch.rand(batch_size, 1, samples, requires_grad=False)
with torch.no_grad():
_ = model(data)
data = torch.rand(batch_size,2,samples,requires_grad=False)
data = torch.rand(batch_size, 2, samples, requires_grad=False)
with torch.no_grad():
with pytest.raises(TypeError):
_ = model(data)
@pytest.mark.parametrize("dataset,channels,loss",
[(pytest.lazy_fixture("vctk_dataset"),1,["mae","mse"])])
def test_demucs_init(dataset,channels,loss):
@pytest.mark.parametrize(
"dataset,channels,loss",
[(pytest.lazy_fixture("vctk_dataset"), 1, ["mae", "mse"])],
)
def test_demucs_init(dataset, channels, loss):
with torch.no_grad():
model = WaveUnet(num_channels=channels,dataset=dataset,loss=loss)
_ = WaveUnet(num_channels=channels, dataset=dataset, loss=loss)

View File

@ -4,22 +4,26 @@ import torch
from enhancer.inference import Inference
@pytest.mark.parametrize("audio",["tests/data/vctk/clean_testset_wav/p257_166.wav",torch.rand(1,2,48000)])
@pytest.mark.parametrize(
"audio",
["tests/data/vctk/clean_testset_wav/p257_166.wav", torch.rand(1, 2, 48000)],
)
def test_read_input(audio):
read_audio = Inference.read_input(audio,48000,16000)
assert isinstance(read_audio,torch.Tensor)
read_audio = Inference.read_input(audio, 48000, 16000)
assert isinstance(read_audio, torch.Tensor)
assert read_audio.shape[0] == 1
def test_batchify():
rand = torch.rand(1,1000)
batched_rand = Inference.batchify(rand, window_size = 100, step_size=100)
rand = torch.rand(1, 1000)
batched_rand = Inference.batchify(rand, window_size=100, step_size=100)
assert batched_rand.shape[0] == 12
def test_aggregate():
rand = torch.rand(12,1,100)
agg_rand = Inference.aggreagate(data=rand,window_size=100,total_frames=1000,step_size=100)
rand = torch.rand(12, 1, 100)
agg_rand = Inference.aggreagate(
data=rand, window_size=100, total_frames=1000, step_size=100
)
assert agg_rand.shape[-1] == 1000

View File

@ -1,46 +1,50 @@
from logging import root
import numpy as np
import pytest
import torch
import numpy as np
from enhancer.utils.io import Audio
from enhancer.utils.config import Files
from enhancer.data.fileprocessor import Fileprocessor
from enhancer.utils.io import Audio
def test_io_channel():
input_audio = np.random.rand(2,32000)
audio = Audio(mono=True,return_tensor=False)
input_audio = np.random.rand(2, 32000)
audio = Audio(mono=True, return_tensor=False)
output_audio = audio(input_audio)
assert output_audio.shape[0] == 1
def test_io_resampling():
input_audio = np.random.rand(1,32000)
resampled_audio = Audio.resample_audio(input_audio,16000,8000)
input_audio = np.random.rand(1, 32000)
resampled_audio = Audio.resample_audio(input_audio, 16000, 8000)
input_audio = torch.rand(1,32000)
resampled_audio_pt = Audio.resample_audio(input_audio,16000,8000)
input_audio = torch.rand(1, 32000)
resampled_audio_pt = Audio.resample_audio(input_audio, 16000, 8000)
assert resampled_audio.shape[1] == resampled_audio_pt.size(1) == 16000
def test_fileprocessor_vctk():
fp = Fileprocessor.from_name("vctk","tests/data/vctk/clean_testset_wav",
"tests/data/vctk/noisy_testset_wav",48000)
fp = Fileprocessor.from_name(
"vctk",
"tests/data/vctk/clean_testset_wav",
"tests/data/vctk/noisy_testset_wav",
48000,
)
matching_dict = fp.prepare_matching_dict()
assert len(matching_dict)==2
assert len(matching_dict) == 2
@pytest.mark.parametrize("dataset_name",["vctk","dns-2020"])
@pytest.mark.parametrize("dataset_name", ["vctk", "dns-2020"])
def test_fileprocessor_names(dataset_name):
fp = Fileprocessor.from_name(dataset_name,"clean_dir","noisy_dir",16000)
assert hasattr(fp.matching_function, '__call__')
fp = Fileprocessor.from_name(dataset_name, "clean_dir", "noisy_dir", 16000)
assert hasattr(fp.matching_function, "__call__")
def test_fileprocessor_invaliname():
with pytest.raises(ValueError):
fp = Fileprocessor.from_name("undefined","clean_dir","noisy_dir",16000).prepare_matching_dict()
_ = Fileprocessor.from_name(
"undefined", "clean_dir", "noisy_dir", 16000
).prepare_matching_dict()