Merge pull request #21 from shahules786/dev

Merge changes to main
This commit is contained in:
Shahul ES 2022-11-10 10:43:12 +05:30 committed by GitHub
commit 1d366d6096
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
58 changed files with 3777 additions and 1 deletion

9
.flake8 Normal file

@ -0,0 +1,9 @@
[flake8]
per-file-ignores = __init__.py:F401
ignore = E203, E266, E501, W503
# line length is intentionally set to 80 here because black uses Bugbear
# See https://github.com/psf/black/blob/master/README.md#line-length for more details
max-line-length = 80
max-complexity = 18
select = B,C,E,F,W,T4,B9
exclude = tools/kaldi_decoder

51
.github/workflows/ci.yaml vendored Normal file

@ -0,0 +1,51 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
name: Enhancer
on:
push:
branches: [ dev ]
pull_request:
branches: [ dev ]
jobs:
build:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.8]
steps:
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v1.1.1
env :
ACTIONS_ALLOW_UNSECURE_COMMANDS : true
with:
python-version: ${{ matrix.python-version }}
- name: Cache pip
uses: actions/cache@v1
with:
path: ~/.cache/pip # This path is specific to Ubuntu
# Look to see if there is a cache hit for the corresponding requirements file
key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }}
restore-keys: |
${{ runner.os }}-pip-
${{ runner.os }}-
# You can test your matrix by printing the current Python version
- name: Display Python version
run: python -c "import sys; print(sys.version)"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
sudo apt-get install libsndfile1
pip install -r requirements.txt
pip install black pytest-cov
- name: Install enhancer
run: |
pip install -e .[dev,testing]
- name: Run black
run:
black --check . --exclude enhancer/version.py
- name: Test with pytest
run:
pytest tests --cov=enhancer/

7
.gitignore vendored

@ -1,3 +1,10 @@
#local
*.ckpt
*_local.yaml
cli/train_config/dataset/Vctk_local.yaml
.DS_Store
outputs/
datasets/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]

43
.pre-commit-config.yaml Normal file

@ -0,0 +1,43 @@
repos:
# # Clean Notebooks
# - repo: https://github.com/kynan/nbstripout
# rev: master
# hooks:
# - id: nbstripout
# Format Code
- repo: https://github.com/ambv/black
rev: 22.8.0
hooks:
- id: black
# Sort imports
- repo: https://github.com/PyCQA/isort
rev: 5.10.1
hooks:
- id: isort
args: ["--profile", "black"]
- repo: https://gitlab.com/pycqa/flake8
rev: 5.0.4
hooks:
- id: flake8
args: ['--ignore=E203,E501,F811,E712,W503']
# Formatting, Whitespace, etc
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v3.2.0
hooks:
- id: trailing-whitespace
- id: check-added-large-files
args: ['--maxkb=1000']
- id: check-ast
- id: check-json
- id: check-merge-conflict
- id: check-xml
- id: check-yaml
- id: debug-statements
- id: end-of-file-fixer
- id: requirements-txt-fixer
- id: mixed-line-ending
args: ['--fix=no']


@ -1 +1,43 @@
# enhancer
<p align="center">
<img src="https://user-images.githubusercontent.com/25312635/195514652-e4526cd1-1177-48e9-a80d-c8bfdb95d35f.png" />
</p>
mayavoz is a PyTorch-based open-source toolkit for speech enhancement. It is designed to save time for audio researchers. It provides easy-to-use pretrained audio enhancement models and facilitates highly customisable model training.
| **[Quick Start]()** | **[Installation]()** | **[Tutorials]()** | **[Available Recipes]()**
## Key features :key:
* Various pretrained models nicely integrated with huggingface :hugs: that users can select and use without any hassle.
* :package: Ability to train and validate your own custom speech enhancement models with just under 10 lines of code!
* :magic_wand: A command line tool that facilitates training of highly customisable speech enhancement models from the terminal itself!
* :zap: Supports multi-GPU training integrated with PyTorch Lightning.
## Quick Start :fire:
``` python
from mayavoz import Mayamodel
model = Mayamodel.from_pretrained("mayavoz/waveunet")
model("noisy_audio.wav")
```
## Installation
Only Python 3.8+ is officially supported (though it might work with Python 3.7)
- With Pypi
```
pip install mayavoz
```
- With conda
```
conda env create -f environment.yml
conda activate mayavoz
```
- From source code
```
git clone url
cd mayavoz
pip install -e .
```

1
enhancer/__init__.py Normal file

@ -0,0 +1 @@
__import__("pkg_resources").declare_namespace(__name__)

120
enhancer/cli/train.py Normal file

@ -0,0 +1,120 @@
import os
from types import MethodType
import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.callbacks import (
EarlyStopping,
LearningRateMonitor,
ModelCheckpoint,
)
from pytorch_lightning.loggers import MLFlowLogger
from torch.optim.lr_scheduler import ReduceLROnPlateau
# from torch_audiomentations import Compose, Shift
os.environ["HYDRA_FULL_ERROR"] = "1"
JOB_ID = os.environ.get("SLURM_JOBID", "0")
@hydra.main(config_path="train_config", config_name="config")
def main(config: DictConfig):
OmegaConf.save(config, "config_log.yaml")
callbacks = []
logger = MLFlowLogger(
experiment_name=config.mlflow.experiment_name,
run_name=config.mlflow.run_name,
tags={"JOB_ID": JOB_ID},
)
parameters = config.hyperparameters
# apply_augmentations = Compose(
# [
# Shift(min_shift=0.5, max_shift=1.0, shift_unit="seconds", p=0.5),
# ]
# )
dataset = instantiate(config.dataset, augmentations=None)
model = instantiate(
config.model,
dataset=dataset,
lr=parameters.get("lr"),
loss=parameters.get("loss"),
metric=parameters.get("metric"),
)
direction = model.valid_monitor
checkpoint = ModelCheckpoint(
dirpath="./model",
filename=f"model_{JOB_ID}",
monitor="valid_loss",
verbose=False,
mode=direction,
every_n_epochs=1,
)
callbacks.append(checkpoint)
callbacks.append(LearningRateMonitor(logging_interval="epoch"))
if parameters.get("Early_stop", False):
early_stopping = EarlyStopping(
monitor="val_loss",
mode=direction,
min_delta=0.0,
patience=parameters.get("EarlyStopping_patience", 10),
strict=True,
verbose=False,
)
callbacks.append(early_stopping)
def configure_optimizers(self):
optimizer = instantiate(
config.optimizer,
lr=parameters.get("lr"),
params=self.parameters(),
)
scheduler = ReduceLROnPlateau(
optimizer=optimizer,
mode=direction,
factor=parameters.get("ReduceLr_factor", 0.1),
verbose=True,
min_lr=parameters.get("min_lr", 1e-6),
patience=parameters.get("ReduceLr_patience", 3),
)
return {
"optimizer": optimizer,
"lr_scheduler": scheduler,
"monitor": f'valid_{parameters.get("ReduceLr_monitor", "loss")}',
}
model.configure_optimizers = MethodType(configure_optimizers, model)
trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks)
trainer.fit(model)
trainer.test(model)
logger.experiment.log_artifact(
logger.run_id, f"{trainer.default_root_dir}/config_log.yaml"
)
saved_location = os.path.join(
trainer.default_root_dir, "model", f"model_{JOB_ID}.ckpt"
)
if os.path.isfile(saved_location):
logger.experiment.log_artifact(logger.run_id, saved_location)
logger.experiment.log_param(
logger.run_id,
"num_train_steps_per_epoch",
dataset.train__len__() / dataset.batch_size,
)
logger.experiment.log_param(
logger.run_id,
"num_valid_steps_per_epoch",
dataset.val__len__() / dataset.batch_size,
)
if __name__ == "__main__":
main()


@ -0,0 +1,7 @@
defaults:
- model : Demucs
- dataset : Vctk
- optimizer : Adam
- hyperparameters : default
- trainer : default
- mlflow : experiment


@ -0,0 +1,12 @@
_target_: enhancer.data.dataset.EnhancerDataset
root_dir : /Users/shahules/Myprojects/MS-SNSD
name : dns-2020
duration : 2.0
sampling_rate: 16000
batch_size: 32
valid_size: 0.05
files:
train_clean : CleanSpeech_training
test_clean : CleanSpeech_training
train_noisy : NoisySpeech_training
test_noisy : NoisySpeech_training


@ -0,0 +1,13 @@
_target_: enhancer.data.dataset.EnhancerDataset
name : vctk
root_dir : /scratch/c.sistc3/DS_10283_2791
duration : 4.5
stride : 2
sampling_rate: 16000
batch_size: 32
valid_minutes : 15
files:
train_clean : clean_trainset_28spk_wav
test_clean : clean_testset_wav
train_noisy : noisy_trainset_28spk_wav
test_noisy : noisy_testset_wav


@ -0,0 +1,13 @@
_target_: enhancer.data.dataset.EnhancerDataset
name : vctk
root_dir : /Users/shahules/Myprojects/enhancer/datasets/vctk
duration : 1.0
sampling_rate: 16000
batch_size: 64
num_workers : 0
files:
train_clean : clean_testset_wav
test_clean : clean_testset_wav
train_noisy : noisy_testset_wav
test_noisy : noisy_testset_wav


@ -0,0 +1,7 @@
loss : mae
metric : [stoi,pesq,si-sdr]
lr : 0.0003
ReduceLr_patience : 5
ReduceLr_factor : 0.2
min_lr : 0.000001
EarlyStopping_factor : 10


@ -0,0 +1,2 @@
experiment_name : shahules/enhancer
run_name : Demucs + Vctk with stride + augmentations


@ -0,0 +1,25 @@
_target_: enhancer.models.dccrn.DCCRN
num_channels: 1
sampling_rate : 16000
complex_lstm : True
complex_norm : True
complex_relu : True
masking_mode : E
encoder_decoder:
initial_output_channels : 32
depth : 6
kernel_size : 5
growth_factor : 2
stride : 2
padding : 2
output_padding : 1
lstm:
num_layers : 2
hidden_size : 256
stft:
window_len : 400
hop_size : 100
nfft : 512


@ -0,0 +1,16 @@
_target_: enhancer.models.demucs.Demucs
num_channels: 1
resample: 4
sampling_rate : 16000
encoder_decoder:
depth: 4
initial_output_channels: 64
kernel_size: 8
stride: 4
growth_factor: 2
glu: True
lstm:
bidirectional: False
num_layers: 2


@ -0,0 +1,5 @@
_target_: enhancer.models.waveunet.WaveUnet
num_channels : 1
depth : 9
initial_output_channels: 24
sampling_rate : 16000


@ -0,0 +1,6 @@
_target_: torch.optim.Adam
lr: 1e-3
betas: [0.9, 0.999]
eps: 1e-08
weight_decay: 0
amsgrad: False


@ -0,0 +1,46 @@
_target_: pytorch_lightning.Trainer
accelerator: gpu
accumulate_grad_batches: 1
amp_backend: native
auto_lr_find: True
auto_scale_batch_size: False
auto_select_gpus: True
benchmark: False
check_val_every_n_epoch: 1
detect_anomaly: False
deterministic: False
devices: 2
enable_checkpointing: True
enable_model_summary: True
enable_progress_bar: True
fast_dev_run: False
gpus: null
gradient_clip_val: 0
gradient_clip_algorithm: norm
ipus: null
limit_predict_batches: 1.0
limit_test_batches: 1.0
limit_train_batches: 1.0
limit_val_batches: 1.0
log_every_n_steps: 50
max_epochs: 200
max_steps: -1
max_time: null
min_epochs: 1
min_steps: null
move_metrics_to_cpu: False
multiple_trainloader_mode: max_size_cycle
num_nodes: 1
num_processes: 1
num_sanity_val_steps: 2
overfit_batches: 0.0
precision: 32
profiler: null
reload_dataloaders_every_n_epochs: 0
replace_sampler_ddp: True
strategy: ddp
sync_batchnorm: False
tpu_cores: null
track_grad_norm: -1
val_check_interval: 1.0
weights_save_path: null


@ -0,0 +1,2 @@
_target_: pytorch_lightning.Trainer
fast_dev_run: True


@ -0,0 +1 @@
from enhancer.data.dataset import EnhancerDataset

376
enhancer/data/dataset.py Normal file

@ -0,0 +1,376 @@
import math
import multiprocessing
import os
from pathlib import Path
from typing import Optional
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, RandomSampler
from torch_audiomentations import Compose
from enhancer.data.fileprocessor import Fileprocessor
from enhancer.utils import check_files
from enhancer.utils.config import Files
from enhancer.utils.io import Audio
from enhancer.utils.random import create_unique_rng
LARGE_NUM = 2147483647
class TrainDataset(Dataset):
def __init__(self, dataset):
self.dataset = dataset
def __getitem__(self, idx):
return self.dataset.train__getitem__(idx)
def __len__(self):
return self.dataset.train__len__()
class ValidDataset(Dataset):
def __init__(self, dataset):
self.dataset = dataset
def __getitem__(self, idx):
return self.dataset.val__getitem__(idx)
def __len__(self):
return self.dataset.val__len__()
class TestDataset(Dataset):
def __init__(self, dataset):
self.dataset = dataset
def __getitem__(self, idx):
return self.dataset.test__getitem__(idx)
def __len__(self):
return self.dataset.test__len__()
class TaskDataset(pl.LightningDataModule):
def __init__(
self,
name: str,
root_dir: str,
files: Files,
min_valid_minutes: float = 0.20,
duration: float = 1.0,
stride=None,
sampling_rate: int = 48000,
matching_function=None,
batch_size=32,
num_workers: Optional[int] = None,
augmentations: Optional[Compose] = None,
):
super().__init__()
self.name = name
self.files, self.root_dir = check_files(root_dir, files)
self.duration = duration
self.stride = stride or duration
self.sampling_rate = sampling_rate
self.batch_size = batch_size
self.matching_function = matching_function
self._validation = []
if num_workers is None:
num_workers = multiprocessing.cpu_count() // 2
self.num_workers = num_workers
if min_valid_minutes > 0.0:
self.min_valid_minutes = min_valid_minutes
else:
raise ValueError("min_valid_minutes must be greater than 0")
self.augmentations = augmentations
def setup(self, stage: Optional[str] = None):
"""
prepare train/validation/test data splits
"""
if stage in ("fit", None):
train_clean = os.path.join(self.root_dir, self.files.train_clean)
train_noisy = os.path.join(self.root_dir, self.files.train_noisy)
fp = Fileprocessor.from_name(
self.name, train_clean, train_noisy, self.matching_function
)
train_data = fp.prepare_matching_dict()
train_data, self.val_data = self.train_valid_split(
train_data,
min_valid_minutes=self.min_valid_minutes,
random_state=42,
)
self.train_data = self.prepare_traindata(train_data)
self._validation = self.prepare_mapstype(self.val_data)
test_clean = os.path.join(self.root_dir, self.files.test_clean)
test_noisy = os.path.join(self.root_dir, self.files.test_noisy)
fp = Fileprocessor.from_name(
self.name, test_clean, test_noisy, self.matching_function
)
test_data = fp.prepare_matching_dict()
self._test = self.prepare_mapstype(test_data)
def train_valid_split(
self, data, min_valid_minutes: float = 20, random_state: int = 42
):
min_valid_minutes *= 60
valid_sec_now = 0.0
valid_indices = []
all_speakers = np.unique(
[Path(file["clean"]).name.split("_")[0] for file in data]
)
possible_indices = list(range(0, len(all_speakers)))
rng = create_unique_rng(len(all_speakers))
while valid_sec_now <= min_valid_minutes:
speaker_index = rng.choice(possible_indices)
possible_indices.remove(speaker_index)
speaker_name = all_speakers[speaker_index]
print(f"Selected f{speaker_name} for valid")
file_indices = [
i
for i, file in enumerate(data)
if speaker_name == Path(file["clean"]).name.split("_")[0]
]
for i in file_indices:
valid_indices.append(i)
valid_sec_now += data[i]["duration"]
train_data = [
item for i, item in enumerate(data) if i not in valid_indices
]
valid_data = [item for i, item in enumerate(data) if i in valid_indices]
return train_data, valid_data
def prepare_traindata(self, data):
train_data = []
for item in data:
clean, noisy, total_dur = item.values()
num_segments = self.get_num_segments(
total_dur, self.duration, self.stride
)
samples_metadata = ({"clean": clean, "noisy": noisy}, num_segments)
train_data.append(samples_metadata)
return train_data
@staticmethod
def get_num_segments(file_duration, duration, stride):
if file_duration < duration:
num_segments = 1
else:
num_segments = math.ceil((file_duration - duration) / stride) + 1
return num_segments
def prepare_mapstype(self, data):
metadata = []
for item in data:
clean, noisy, total_dur = item.values()
if total_dur < self.duration:
metadata.append(({"clean": clean, "noisy": noisy}, 0.0))
else:
num_segments = self.get_num_segments(
total_dur, self.duration, self.duration
)
for index in range(num_segments):
start_time = index * self.duration
metadata.append(
({"clean": clean, "noisy": noisy}, start_time)
)
return metadata
def train_collatefn(self, batch):
output = {"clean": [], "noisy": []}
for item in batch:
output["clean"].append(item["clean"])
output["noisy"].append(item["noisy"])
output["clean"] = torch.stack(output["clean"], dim=0)
output["noisy"] = torch.stack(output["noisy"], dim=0)
if self.augmentations is not None:
noise = output["noisy"] - output["clean"]
output["clean"] = self.augmentations(
output["clean"], sample_rate=self.sampling_rate
)
self.augmentations.freeze_parameters()
output["noisy"] = (
self.augmentations(noise, sample_rate=self.sampling_rate)
+ output["clean"]
)
return output
@property
def generator(self):
generator = torch.Generator()
if hasattr(self, "model"):
seed = self.model.current_epoch + LARGE_NUM
else:
seed = LARGE_NUM
return generator.manual_seed(seed)
def train_dataloader(self):
dataset = TrainDataset(self)
sampler = RandomSampler(dataset, generator=self.generator)
return DataLoader(
dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
sampler=sampler,
collate_fn=self.train_collatefn,
)
def val_dataloader(self):
return DataLoader(
ValidDataset(self),
batch_size=self.batch_size,
num_workers=self.num_workers,
)
def test_dataloader(self):
return DataLoader(
TestDataset(self),
batch_size=self.batch_size,
num_workers=self.num_workers,
)
class EnhancerDataset(TaskDataset):
"""
Dataset object for creating clean-noisy speech enhancement datasets
parameters:
name : str
name of the dataset
root_dir : str
root directory of the dataset containing clean/noisy folders
files : Files
dataclass containing train_clean, train_noisy, test_clean, test_noisy
folder names (refer to the enhancer.utils.Files dataclass)
min_valid_minutes: float
minimum validation split size, in minutes.
The algorithm randomly selects speakers from the train data until at least min_valid_minutes of audio is set aside for validation.
duration : float
expected audio duration of a single audio sample for training
sampling_rate : int
desired sampling rate
batch_size : int
batch size
num_workers : int
number of workers to be used for data loading
matching_function : str
matching function - one of (one_to_one, one_to_many). Defaults to None.
use one_to_one mapping for datasets with one noisy file for each clean file
use one_to_many mapping for datasets with multiple noisy files for each clean file
"""
def __init__(
self,
name: str,
root_dir: str,
files: Files,
min_valid_minutes=5.0,
duration=1.0,
stride=None,
sampling_rate=48000,
matching_function=None,
batch_size=32,
num_workers: Optional[int] = None,
augmentations: Optional[Compose] = None,
):
super().__init__(
name=name,
root_dir=root_dir,
files=files,
min_valid_minutes=min_valid_minutes,
sampling_rate=sampling_rate,
duration=duration,
matching_function=matching_function,
batch_size=batch_size,
num_workers=num_workers,
augmentations=augmentations,
)
self.sampling_rate = sampling_rate
self.files = files
self.duration = max(1.0, duration)
self.audio = Audio(self.sampling_rate, mono=True, return_tensor=True)
self.stride = stride or duration
def setup(self, stage: Optional[str] = None):
super().setup(stage=stage)
def train__getitem__(self, idx):
for filedict, num_samples in self.train_data:
if idx >= num_samples:
idx -= num_samples
continue
else:
start = 0
if self.duration is not None:
start = idx * self.stride
return self.prepare_segment(filedict, start)
def val__getitem__(self, idx):
return self.prepare_segment(*self._validation[idx])
def test__getitem__(self, idx):
return self.prepare_segment(*self._test[idx])
def prepare_segment(self, file_dict: dict, start_time: float):
clean_segment = self.audio(
file_dict["clean"], offset=start_time, duration=self.duration
)
noisy_segment = self.audio(
file_dict["noisy"], offset=start_time, duration=self.duration
)
clean_segment = F.pad(
clean_segment,
(
0,
int(
self.duration * self.sampling_rate - clean_segment.shape[-1]
),
),
)
noisy_segment = F.pad(
noisy_segment,
(
0,
int(
self.duration * self.sampling_rate - noisy_segment.shape[-1]
),
),
)
return {
"clean": clean_segment,
"noisy": noisy_segment,
}
def train__len__(self):
_, num_examples = list(zip(*self.train_data))
return sum(num_examples)
def val__len__(self):
return len(self._validation)
def test__len__(self):
return len(self._test)
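
For orientation, a minimal sketch of constructing the dataset (illustrative only: the root path is a placeholder, the folder names mirror the Vctk config above, and the `Files` dataclass is assumed to expose the four fields listed in the docstring):

``` python
from enhancer.data import EnhancerDataset
from enhancer.utils.config import Files

# folder names mirror cli/train_config/dataset/Vctk.yaml shown earlier in this diff
files = Files(
    train_clean="clean_trainset_28spk_wav",
    train_noisy="noisy_trainset_28spk_wav",
    test_clean="clean_testset_wav",
    test_noisy="noisy_testset_wav",
)
dataset = EnhancerDataset(
    name="vctk",                        # selects the one_to_one matching function
    root_dir="/path/to/DS_10283_2791",  # placeholder root directory
    files=files,
    duration=4.5,
    stride=2,
    sampling_rate=16000,
    batch_size=32,
)
dataset.setup(stage="fit")              # builds train/validation/test splits
loader = dataset.train_dataloader()     # yields {"clean": ..., "noisy": ...} batches
```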


@ -0,0 +1,121 @@
import glob
import os
import numpy as np
from scipy.io import wavfile
MATCHING_FNS = ("one_to_one", "one_to_many")
class ProcessorFunctions:
"""
Preprocessing methods for different types of speech enhancement datasets.
"""
@staticmethod
def one_to_one(clean_path, noisy_path):
"""
One clean audio can have only one noisy audio file
"""
matching_wavfiles = list()
clean_filenames = [
file.split("/")[-1]
for file in glob.glob(os.path.join(clean_path, "*.wav"))
]
noisy_filenames = [
file.split("/")[-1]
for file in glob.glob(os.path.join(noisy_path, "*.wav"))
]
common_filenames = np.intersect1d(noisy_filenames, clean_filenames)
for file_name in common_filenames:
sr_clean, clean_file = wavfile.read(
os.path.join(clean_path, file_name)
)
sr_noisy, noisy_file = wavfile.read(
os.path.join(noisy_path, file_name)
)
if (clean_file.shape[-1] == noisy_file.shape[-1]) and (
sr_clean == sr_noisy
):
matching_wavfiles.append(
{
"clean": os.path.join(clean_path, file_name),
"noisy": os.path.join(noisy_path, file_name),
"duration": clean_file.shape[-1] / sr_clean,
}
)
return matching_wavfiles
@staticmethod
def one_to_many(clean_path, noisy_path):
"""
One clean audio file can have multiple noisy audio files
"""
matching_wavfiles = list()
clean_filenames = [
file.split("/")[-1]
for file in glob.glob(os.path.join(clean_path, "*.wav"))
]
for clean_file in clean_filenames:
noisy_filenames = glob.glob(
os.path.join(noisy_path, f"*_{clean_file}")
)
for noisy_file in noisy_filenames:
sr_clean, clean_wav = wavfile.read(
os.path.join(clean_path, clean_file)
)
sr_noisy, noisy_wav = wavfile.read(noisy_file)
if (clean_wav.shape[-1] == noisy_wav.shape[-1]) and (
sr_clean == sr_noisy
):
matching_wavfiles.append(
{
"clean": os.path.join(clean_path, clean_file),
"noisy": noisy_file,
"duration": clean_wav.shape[-1] / sr_clean,
}
)
return matching_wavfiles
class Fileprocessor:
def __init__(self, clean_dir, noisy_dir, matching_function=None):
self.clean_dir = clean_dir
self.noisy_dir = noisy_dir
self.matching_function = matching_function
@classmethod
def from_name(cls, name: str, clean_dir, noisy_dir, matching_function=None):
if matching_function is None:
if name.lower() == "vctk":
return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_one)
elif name.lower() == "dns-2020":
return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_many)
else:
raise ValueError(
f"Invalid matching function, Please use valid matching function from {MATCHING_FNS}"
)
else:
if matching_function not in MATCHING_FNS:
raise ValueError(
f"Invalid matching function! Avaialble options are {MATCHING_FNS}"
)
else:
return cls(
clean_dir,
noisy_dir,
getattr(ProcessorFunctions, matching_function),
)
def prepare_matching_dict(self):
if self.matching_function is None:
raise ValueError("Not a valid matching function")
return self.matching_function(self.clean_dir, self.noisy_dir)
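
A small sketch of how the file processor is meant to be driven (directory paths are placeholders):

``` python
from enhancer.data.fileprocessor import Fileprocessor

# "vctk" maps to the one_to_one matcher, "dns-2020" to one_to_many
fp = Fileprocessor.from_name(
    "vctk",
    clean_dir="/path/to/clean_trainset_28spk_wav",  # placeholder
    noisy_dir="/path/to/noisy_trainset_28spk_wav",  # placeholder
)
pairs = fp.prepare_matching_dict()
# each entry looks like {"clean": <path>, "noisy": <path>, "duration": <seconds>}
print(len(pairs))
```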

170
enhancer/inference.py Normal file

@ -0,0 +1,170 @@
from pathlib import Path
from typing import Optional, Union
import numpy as np
import torch
import torch.nn.functional as F
from librosa import load as load_audio
from scipy.io import wavfile
from scipy.signal import get_window
from enhancer.utils import Audio
class Inference:
"""
contains methods used for inference.
"""
@staticmethod
def read_input(audio, sr, model_sr):
"""
read and verify audio input regardless of the input format.
arguments:
audio : audio input
sr : sampling rate of input audio
model_sr : sampling rate used for model training.
"""
if isinstance(audio, (np.ndarray, torch.Tensor)):
assert sr is not None, "Invalid sampling rate!"
if len(audio.shape) == 1:
audio = audio.reshape(1, -1)
if isinstance(audio, str):
audio = Path(audio)
if not audio.is_file():
raise ValueError(f"Input file {audio} does not exist")
else:
audio, sr = load_audio(
audio,
sr=sr,
)
if len(audio.shape) == 1:
audio = audio.reshape(1, -1)
else:
assert (
audio.shape[0] == 1
), "Enhance inference only supports single waveform"
waveform = Audio.resample_audio(audio, sr=sr, target_sr=model_sr)
waveform = Audio.convert_mono(waveform)
if isinstance(waveform, np.ndarray):
waveform = torch.from_numpy(waveform)
return waveform
@staticmethod
def batchify(
waveform: torch.Tensor,
window_size: int,
step_size: Optional[int] = None,
):
"""
break the input waveform into overlapping chunks of the specified window size (overlap-add).
arguments:
waveform : audio waveform
window_size : window size used for splitting waveform into batches
step_size : step_size used for splitting waveform into batches
"""
assert (
waveform.ndim == 2
), f"Expcted input waveform with 2 dimensions (channels,samples), got {waveform.ndim}"
_, num_samples = waveform.shape
waveform = waveform.unsqueeze(-1)
step_size = window_size // 2 if step_size is None else step_size
if num_samples >= window_size:
waveform_batch = F.unfold(
waveform[None, ...],
kernel_size=(window_size, 1),
stride=(step_size, 1),
padding=(window_size, 0),
)
waveform_batch = waveform_batch.permute(2, 0, 1)
return waveform_batch
@staticmethod
def aggreagate(
data: torch.Tensor,
window_size: int,
total_frames: int,
step_size: Optional[int] = None,
window="hamming",
):
"""
stitch batched waveform into single waveform. (Overlap-add)
arguments:
data: batched waveform
window_size : window_size used to batch waveform
step_size : step_size used to batch waveform
total_frames : total number of frames present in original waveform
window : type of window used for overlap-add mechanism.
"""
num_chunks, n_channels, num_frames = data.shape
window = get_window(window=window, Nx=data.shape[-1])
window = torch.from_numpy(window).to(data.device)
data *= window
step_size = window_size // 2 if step_size is None else step_size
data = data.permute(1, 2, 0)
data = F.fold(
data,
(total_frames, 1),
kernel_size=(window_size, 1),
stride=(step_size, 1),
padding=(window_size, 0),
).squeeze(-1)
return data.reshape(1, n_channels, -1)
@staticmethod
def write_output(
waveform: torch.Tensor, filename: Union[str, Path], sr: int
):
"""
write audio output as wav file
arguments:
waveform : audio waveform
filename : name of the wave file. Output will be written as cleaned_filename.wav
sr : sampling rate
"""
if isinstance(filename, str):
filename = Path(filename)
parent, name = filename.parent, "cleaned_" + filename.name
filename = parent / Path(name)
if filename.is_file():
raise FileExistsError(f"file {filename} already exists")
else:
wavfile.write(
filename, rate=sr, data=waveform.detach().cpu().numpy()
)
@staticmethod
def prepare_output(
waveform: torch.Tensor,
model_sampling_rate: int,
audio: Union[str, np.ndarray, torch.Tensor],
sampling_rate: Optional[int],
):
"""
prepare output audio based on input format
arguments:
waveform : predicted audio waveform
model_sampling_rate : sampling rate used to train the model
audio : input audio
sampling_rate : input audio sampling rate
"""
if isinstance(audio, np.ndarray):
waveform = waveform.detach().cpu().numpy()
if sampling_rate is not None:
waveform = Audio.resample_audio(
waveform, sr=model_sampling_rate, target_sr=sampling_rate
)
return waveform
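
A rough sketch of the intended flow through these helpers (it assumes the `Audio` utilities accept numpy input, which is not shown in this diff):

``` python
import numpy as np
from enhancer.inference import Inference

# hypothetical 2-second noisy clip at 16 kHz
noisy = np.random.randn(32000).astype("float32")

waveform = Inference.read_input(noisy, sr=16000, model_sr=16000)  # -> (1, num_samples) tensor
chunks = Inference.batchify(waveform, window_size=16000)          # overlapping chunks, 50% hop by default
# a model would be applied to `chunks` here; the prediction is stitched back
# together with Inference.aggreagate(...) and saved with Inference.write_output(...)
```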

216
enhancer/loss.py Normal file

@ -0,0 +1,216 @@
import logging
import numpy as np
import torch
import torch.nn as nn
from torchmetrics import ScaleInvariantSignalNoiseRatio
from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility
class mean_squared_error(nn.Module):
"""
Mean squared error / L2 loss
"""
def __init__(self, reduction="mean"):
super().__init__()
self.loss_fun = nn.MSELoss(reduction=reduction)
self.higher_better = False
self.name = "mse"
def forward(self, prediction: torch.Tensor, target: torch.Tensor):
if prediction.size() != target.size() or target.ndim < 3:
raise TypeError(
f"""Inputs must be of the same shape (batch_size,channels,samples)
got {prediction.size()} and {target.size()} instead"""
)
return self.loss_fun(prediction, target)
class mean_absolute_error(nn.Module):
"""
Mean absolute error / L1 loss
"""
def __init__(self, reduction="mean"):
super().__init__()
self.loss_fun = nn.L1Loss(reduction=reduction)
self.higher_better = False
self.name = "mae"
def forward(self, prediction: torch.Tensor, target: torch.Tensor):
if prediction.size() != target.size() or target.ndim < 3:
raise TypeError(
f"""Inputs must be of the same shape (batch_size,channels,samples)
got {prediction.size()} and {target.size()} instead"""
)
return self.loss_fun(prediction, target)
class Si_SDR:
"""
SI-SDR metric based on SDR HALF-BAKED OR WELL DONE?(https://arxiv.org/pdf/1811.02508.pdf)
"""
def __init__(self, reduction: str = "mean"):
if reduction in ["sum", "mean", None]:
self.reduction = reduction
else:
raise TypeError(
"Invalid reduction, valid options are sum, mean, None"
)
self.higher_better = True
self.name = "si-sdr"
def __call__(self, prediction: torch.Tensor, target: torch.Tensor):
if prediction.size() != target.size() or target.ndim < 3:
raise TypeError(
f"""Inputs must be of the same shape (batch_size,channels,samples)
got {prediction.size()} and {target.size()} instead"""
)
target_energy = torch.sum(target**2, keepdim=True, dim=-1)
scaling_factor = (
torch.sum(prediction * target, keepdim=True, dim=-1) / target_energy
)
target_projection = target * scaling_factor
noise = prediction - target_projection
ratio = torch.sum(target_projection**2, dim=-1) / torch.sum(
noise**2, dim=-1
)
si_sdr = 10 * torch.log10(ratio).mean(dim=-1)
if self.reduction == "sum":
si_sdr = si_sdr.sum()
elif self.reduction == "mean":
si_sdr = si_sdr.mean()
else:
pass
return si_sdr
class Stoi:
"""
STOI (Short-Time Objective Intelligibility, see [2,3]), a wrapper for the pystoi package [1].
Note that input will be moved to cpu to perform the metric calculation.
parameters:
sr: int
sampling rate
"""
def __init__(self, sr: int):
self.sr = sr
self.stoi = ShortTimeObjectiveIntelligibility(fs=sr)
self.name = "stoi"
def __call__(self, prediction: torch.Tensor, target: torch.Tensor):
return self.stoi(prediction, target)
class Pesq:
def __init__(self, sr: int, mode="wb"):
self.sr = sr
self.name = "pesq"
self.mode = mode
self.pesq = PerceptualEvaluationSpeechQuality(
fs=self.sr, mode=self.mode
)
def __call__(self, prediction: torch.Tensor, target: torch.Tensor):
pesq_values = []
for pred, target_ in zip(prediction, target):
try:
pesq_values.append(self.pesq(pred.squeeze(), target_.squeeze()))
except Exception as e:
logging.warning(f"{e} error occured while calculating PESQ")
return torch.tensor(np.mean(pesq_values))
class LossWrapper(nn.Module):
"""
Combine multiple metrics of the same nature,
for example ["mse", "mae"].
parameters:
losses : loss function names to be combined
"""
def __init__(self, losses):
super().__init__()
self.valid_losses = nn.ModuleList()
direction = [
getattr(LOSS_MAP[loss](), "higher_better") for loss in losses
]
if len(set(direction)) > 1:
raise ValueError(
"all cost functions should be of same nature, maximize or minimize!"
)
self.higher_better = direction[0]
self.name = ""
for loss in losses:
loss = self.validate_loss(loss)
self.valid_losses.append(loss())
self.name += f"{loss().name}_"
def validate_loss(self, loss: str):
if loss not in LOSS_MAP.keys():
raise ValueError(
f"""Invalid loss function {loss}, available loss functions are
{tuple([loss for loss in LOSS_MAP.keys()])}"""
)
else:
return LOSS_MAP[loss]
def forward(self, prediction: torch.Tensor, target: torch.Tensor):
loss = 0.0
for loss_fun in self.valid_losses:
loss += loss_fun(prediction, target)
return loss
class Si_snr(nn.Module):
"""
SI-SNR
"""
def __init__(self, **kwargs):
super().__init__()
self.loss_fun = ScaleInvariantSignalNoiseRatio(**kwargs)
self.higher_better = True
self.name = "si_snr"
def forward(self, prediction: torch.Tensor, target: torch.Tensor):
if prediction.size() != target.size() or target.ndim < 3:
raise TypeError(
f"""Inputs must be of the same shape (batch_size,channels,samples)
got {prediction.size()} and {target.size()} instead"""
)
return self.loss_fun(prediction, target)
LOSS_MAP = {
"mae": mean_absolute_error,
"mse": mean_squared_error,
"si-sdr": Si_SDR,
"pesq": Pesq,
"stoi": Stoi,
"si-snr": Si_snr,
}
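
For illustration, combining two losses of the same direction through `LossWrapper` (a minimal sketch):

``` python
import torch
from enhancer.loss import LossWrapper

loss_fn = LossWrapper(["mae", "mse"])   # both are minimised, so they can be combined
prediction = torch.rand(4, 1, 16000)    # (batch, channels, samples)
target = torch.rand(4, 1, 16000)
loss = loss_fn(prediction, target)      # sum of the L1 and L2 terms
print(loss_fn.name, loss.item())        # name is the concatenation "mae_mse_"
```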


@ -0,0 +1,3 @@
from enhancer.models.demucs import Demucs
from enhancer.models.model import Model
from enhancer.models.waveunet import WaveUnet


@ -0,0 +1,5 @@
from enhancer.models.complexnn.conv import ComplexConv2d # noqa
from enhancer.models.complexnn.conv import ComplexConvTranspose2d # noqa
from enhancer.models.complexnn.rnn import ComplexLSTM # noqa
from enhancer.models.complexnn.utils import ComplexBatchNorm2D # noqa
from enhancer.models.complexnn.utils import ComplexRelu # noqa


@ -0,0 +1,136 @@
from typing import Tuple
import torch
import torch.nn.functional as F
from torch import nn
def init_weights(nnet):
nn.init.xavier_normal_(nnet.weight.data)
nn.init.constant_(nnet.bias, 0.0)
return nnet
class ComplexConv2d(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Tuple[int, int] = (1, 1),
stride: Tuple[int, int] = (1, 1),
padding: Tuple[int, int] = (0, 0),
groups: int = 1,
dilation: int = 1,
):
"""
Complex Conv2d (non-causal)
"""
super().__init__()
self.in_channels = in_channels // 2
self.out_channels = out_channels // 2
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
self.groups = groups
self.dilation = dilation
self.real_conv = nn.Conv2d(
self.in_channels,
self.out_channels,
kernel_size=self.kernel_size,
stride=self.stride,
padding=(self.padding[0], 0),
groups=self.groups,
dilation=self.dilation,
)
self.imag_conv = nn.Conv2d(
self.in_channels,
self.out_channels,
kernel_size=self.kernel_size,
stride=self.stride,
padding=(self.padding[0], 0),
groups=self.groups,
dilation=self.dilation,
)
self.imag_conv = init_weights(self.imag_conv)
self.real_conv = init_weights(self.real_conv)
def forward(self, input):
"""
real and imaginary parts are always stacked along dim 1
"""
input = F.pad(input, [self.padding[1], 0, 0, 0])
real, imag = torch.chunk(input, 2, 1)
real_real = self.real_conv(real)
real_imag = self.imag_conv(real)
imag_imag = self.imag_conv(imag)
imag_real = self.real_conv(imag)
real = real_real - imag_imag
imag = real_imag + imag_real
out = torch.cat([real, imag], 1)
return out
class ComplexConvTranspose2d(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Tuple[int, int] = (1, 1),
stride: Tuple[int, int] = (1, 1),
padding: Tuple[int, int] = (0, 0),
output_padding: Tuple[int, int] = (0, 0),
groups: int = 1,
):
super().__init__()
self.in_channels = in_channels // 2
self.out_channels = out_channels // 2
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
self.groups = groups
self.output_padding = output_padding
self.real_conv = nn.ConvTranspose2d(
self.in_channels,
self.out_channels,
kernel_size=self.kernel_size,
stride=self.stride,
padding=self.padding,
output_padding=self.output_padding,
groups=self.groups,
)
self.imag_conv = nn.ConvTranspose2d(
self.in_channels,
self.out_channels,
kernel_size=self.kernel_size,
stride=self.stride,
padding=self.padding,
output_padding=self.output_padding,
groups=self.groups,
)
self.real_conv = init_weights(self.real_conv)
self.imag_conv = init_weights(self.imag_conv)
def forward(self, input):
real, imag = torch.chunk(input, 2, 1)
real_real = self.real_conv(real)
real_imag = self.imag_conv(real)
imag_imag = self.imag_conv(imag)
imag_real = self.real_conv(imag)
real = real_real - imag_imag
imag = real_imag + imag_real
out = torch.cat([real, imag], 1)
return out
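
A quick sketch of the channel convention used throughout `complexnn` (real and imaginary parts stacked along dim 1); the shapes are illustrative:

``` python
import torch
from enhancer.models.complexnn import ComplexConv2d

# (batch, 2 * channels, freq, time): first half real, second half imaginary
x = torch.rand(2, 4, 257, 100)
conv = ComplexConv2d(4, 8, kernel_size=(5, 2), stride=(2, 1), padding=(2, 1))
print(conv(x).shape)   # (2, 8, 129, 100): the frequency axis is roughly halved
```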


@ -0,0 +1,68 @@
from typing import List, Optional
import torch
from torch import nn
class ComplexLSTM(nn.Module):
def __init__(
self,
input_size: int,
hidden_size: int,
num_layers: int = 1,
projection_size: Optional[int] = None,
bidirectional: bool = False,
):
super().__init__()
self.input_size = input_size // 2
self.hidden_size = hidden_size // 2
self.num_layers = num_layers
self.real_lstm = nn.LSTM(
self.input_size,
self.hidden_size,
self.num_layers,
bidirectional=bidirectional,
batch_first=False,
)
self.imag_lstm = nn.LSTM(
self.input_size,
self.hidden_size,
self.num_layers,
bidirectional=bidirectional,
batch_first=False,
)
bidirectional = 2 if bidirectional else 1
if projection_size is not None:
self.projection_size = projection_size // 2
self.real_linear = nn.Linear(
self.hidden_size * bidirectional, self.projection_size
)
self.imag_linear = nn.Linear(
self.hidden_size * bidirectional, self.projection_size
)
else:
self.projection_size = None
def forward(self, input):
if isinstance(input, List):
real, imag = input
else:
real, imag = torch.chunk(input, 2, 1)
real_real = self.real_lstm(real)[0]
real_imag = self.imag_lstm(real)[0]
imag_imag = self.imag_lstm(imag)[0]
imag_real = self.real_lstm(imag)[0]
real = real_real - imag_imag
imag = imag_real + real_imag
if self.projection_size is not None:
real = self.real_linear(real)
imag = self.imag_linear(imag)
return [real, imag]
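
The same convention applies to the LSTM; a minimal sketch with illustrative sizes (inputs may be passed as a [real, imag] pair):

``` python
import torch
from enhancer.models.complexnn import ComplexLSTM

# input_size/hidden_size are halved internally between the real and imaginary branches
lstm = ComplexLSTM(input_size=1024, hidden_size=256, projection_size=1024)
real_in = torch.rand(100, 8, 512)   # (time, batch, features) for the real part
imag_in = torch.rand(100, 8, 512)   # same shape for the imaginary part
real_out, imag_out = lstm([real_in, imag_in])
print(real_out.shape, imag_out.shape)   # each (100, 8, 512) after projection
```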


@ -0,0 +1,199 @@
import torch
from torch import nn
class ComplexBatchNorm2D(nn.Module):
def __init__(
self,
num_features: int,
eps: float = 1e-5,
momentum: float = 0.1,
affine: bool = True,
track_running_stats: bool = True,
):
"""
Complex batch normalization 2D
https://arxiv.org/abs/1705.09792
"""
super().__init__()
self.num_features = num_features // 2
self.affine = affine
self.momentum = momentum
self.track_running_stats = track_running_stats
self.eps = eps
if self.affine:
self.Wrr = nn.parameter.Parameter(torch.Tensor(self.num_features))
self.Wri = nn.parameter.Parameter(torch.Tensor(self.num_features))
self.Wii = nn.parameter.Parameter(torch.Tensor(self.num_features))
self.Br = nn.parameter.Parameter(torch.Tensor(self.num_features))
self.Bi = nn.parameter.Parameter(torch.Tensor(self.num_features))
else:
self.register_parameter("Wrr", None)
self.register_parameter("Wri", None)
self.register_parameter("Wii", None)
self.register_parameter("Br", None)
self.register_parameter("Bi", None)
if self.track_running_stats:
# separate tensors so the running statistics do not share storage
self.register_buffer("Mean_real", torch.zeros(self.num_features))
self.register_buffer("Mean_imag", torch.zeros(self.num_features))
self.register_buffer("Var_rr", torch.zeros(self.num_features))
self.register_buffer("Var_ri", torch.zeros(self.num_features))
self.register_buffer("Var_ii", torch.zeros(self.num_features))
self.register_buffer(
"num_batches_tracked", torch.tensor(0, dtype=torch.long)
)
else:
self.register_parameter("Mean_real", None)
self.register_parameter("Mean_imag", None)
self.register_parameter("Var_rr", None)
self.register_parameter("Var_ri", None)
self.register_parameter("Var_ii", None)
self.register_parameter("num_batches_tracked", None)
self.reset_parameters()
def reset_parameters(self):
if self.affine:
self.Wrr.data.fill_(1)
self.Wii.data.fill_(1)
self.Wri.data.uniform_(-0.9, 0.9)
self.Br.data.fill_(0)
self.Bi.data.fill_(0)
self.reset_running_stats()
def reset_running_stats(self):
if self.track_running_stats:
self.Mean_real.zero_()
self.Mean_imag.zero_()
self.Var_rr.fill_(1)
self.Var_ri.zero_()
self.Var_ii.fill_(1)
self.num_batches_tracked.zero_()
def extra_repr(self):
return "{num_features}, eps={eps}, momentum={momentum}, affine={affine}, track_running_stats={track_running_stats}".format(
**self.__dict__
)
def forward(self, input):
real, imag = torch.chunk(input, 2, 1)
exp_avg_factor = 0.0
training = self.training and self.track_running_stats
if training:
self.num_batches_tracked += 1
if self.momentum is None:
exp_avg_factor = 1 / self.num_batches_tracked
else:
exp_avg_factor = self.momentum
redux = [i for i in reversed(range(real.dim())) if i != 1]
vdim = [1] * real.dim()
vdim[1] = real.size(1)
if training:
batch_mean_real, batch_mean_imag = real, imag
for dim in redux:
batch_mean_real = batch_mean_real.mean(dim, keepdim=True)
batch_mean_imag = batch_mean_imag.mean(dim, keepdim=True)
if self.track_running_stats:
self.Mean_real.lerp_(batch_mean_real.squeeze(), exp_avg_factor)
self.Mean_imag.lerp_(batch_mean_imag.squeeze(), exp_avg_factor)
else:
batch_mean_real = self.Mean_real.view(vdim)
batch_mean_imag = self.Mean_imag.view(vdim)
real = real - batch_mean_real
imag = imag - batch_mean_imag
if training:
batch_var_rr = real * real
batch_var_ri = real * imag
batch_var_ii = imag * imag
for dim in redux:
batch_var_rr = batch_var_rr.mean(dim, keepdim=True)
batch_var_ri = batch_var_ri.mean(dim, keepdim=True)
batch_var_ii = batch_var_ii.mean(dim, keepdim=True)
if self.track_running_stats:
self.Var_rr.lerp_(batch_var_rr.squeeze(), exp_avg_factor)
self.Var_ri.lerp_(batch_var_ri.squeeze(), exp_avg_factor)
self.Var_ii.lerp_(batch_var_ii.squeeze(), exp_avg_factor)
else:
batch_var_rr = self.Var_rr.view(vdim)
batch_var_ii = self.Var_ii.view(vdim)
batch_var_ri = self.Var_ri.view(vdim)
batch_var_rr += self.eps
batch_var_ii += self.eps
# Covariance matrics
# | batch_var_rr batch_var_ri |
# | batch_var_ir batch_var_ii | here batch_var_ir == batch_var_ri
# Inverse square root of cov matrix by combining below two formulas
# https://en.wikipedia.org/wiki/Square_root_of_a_2_by_2_matrix
# https://mathworld.wolfram.com/MatrixInverse.html
tau = batch_var_rr + batch_var_ii
s = (batch_var_rr * batch_var_ii - batch_var_ri * batch_var_ri).sqrt()  # sqrt of the determinant
t = (tau + 2 * s).sqrt()
rst = (s * t).reciprocal()
Urr = (batch_var_ii + s) * rst
Uri = -batch_var_ri * rst
Uii = (batch_var_rr + s) * rst
if self.affine:
Wrr, Wri, Wii = (
self.Wrr.view(vdim),
self.Wri.view(vdim),
self.Wii.view(vdim),
)
Zrr = (Wrr * Urr) + (Wri * Uri)
Zri = (Wrr * Uri) + (Wri * Uii)
Zir = (Wii * Uri) + (Wri * Urr)
Zii = (Wri * Uri) + (Wii * Uii)
else:
Zrr, Zri, Zir, Zii = Urr, Uri, Uri, Uii
yr = (Zrr * real) + (Zri * imag)
yi = (Zir * real) + (Zii * imag)
if self.affine:
yr = yr + self.Br.view(vdim)
yi = yi + self.Bi.view(vdim)
outputs = torch.cat([yr, yi], 1)
return outputs
class ComplexRelu(nn.Module):
def __init__(self):
super().__init__()
self.real_relu = nn.PReLU()
self.imag_relu = nn.PReLU()
def forward(self, input):
real, imag = torch.chunk(input, 2, 1)
real = self.real_relu(real)
imag = self.imag_relu(imag)
return torch.cat([real, imag], dim=1)
def complex_cat(inputs, axis=1):
real, imag = [], []
for data in inputs:
real_data, imag_data = torch.chunk(data, 2, axis)
real.append(real_data)
imag.append(imag_data)
real = torch.cat(real, axis)
imag = torch.cat(imag, axis)
return torch.cat([real, imag], axis)

338
enhancer/models/dccrn.py Normal file

@ -0,0 +1,338 @@
import logging
from typing import Any, List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import nn
from enhancer.data import EnhancerDataset
from enhancer.models import Model
from enhancer.models.complexnn import (
ComplexBatchNorm2D,
ComplexConv2d,
ComplexConvTranspose2d,
ComplexLSTM,
ComplexRelu,
)
from enhancer.models.complexnn.utils import complex_cat
from enhancer.utils.transforms import ConviSTFT, ConvSTFT
from enhancer.utils.utils import merge_dict
class DCCRN_ENCODER(nn.Module):
def __init__(
self,
in_channels: int,
out_channel: int,
kernel_size: Tuple[int, int],
complex_norm: bool = True,
complex_relu: bool = True,
stride: Tuple[int, int] = (2, 1),
padding: Tuple[int, int] = (2, 1),
):
super().__init__()
batchnorm = ComplexBatchNorm2D if complex_norm else nn.BatchNorm2d
activation = ComplexRelu() if complex_relu else nn.PReLU()
self.encoder = nn.Sequential(
ComplexConv2d(
in_channels,
out_channel,
kernel_size=kernel_size,
stride=stride,
padding=padding,
),
batchnorm(out_channel),
activation,
)
def forward(self, waveform):
return self.encoder(waveform)
class DCCRN_DECODER(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Tuple[int, int],
layer: int = 0,
complex_norm: bool = True,
complex_relu: bool = True,
stride: Tuple[int, int] = (2, 1),
padding: Tuple[int, int] = (2, 0),
output_padding: Tuple[int, int] = (1, 0),
):
super().__init__()
batchnorm = ComplexBatchNorm2D if complex_norm else nn.BatchNorm2d
activation = ComplexRelu() if complex_relu else nn.PReLU()
if layer != 0:
self.decoder = nn.Sequential(
ComplexConvTranspose2d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
output_padding=output_padding,
),
batchnorm(out_channels),
activation,
)
else:
self.decoder = nn.Sequential(
ComplexConvTranspose2d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
output_padding=output_padding,
)
)
def forward(self, waveform):
return self.decoder(waveform)
class DCCRN(Model):
STFT_DEFAULTS = {
"window_len": 400,
"hop_size": 100,
"nfft": 512,
"window": "hamming",
}
ED_DEFAULTS = {
"initial_output_channels": 32,
"depth": 6,
"kernel_size": 5,
"growth_factor": 2,
"stride": 2,
"padding": 2,
"output_padding": 1,
}
LSTM_DEFAULTS = {
"num_layers": 2,
"hidden_size": 256,
}
def __init__(
self,
stft: Optional[dict] = None,
encoder_decoder: Optional[dict] = None,
lstm: Optional[dict] = None,
complex_lstm: bool = True,
complex_norm: bool = True,
complex_relu: bool = True,
masking_mode: str = "E",
num_channels: int = 1,
sampling_rate=16000,
lr: float = 1e-3,
dataset: Optional[EnhancerDataset] = None,
duration: Optional[float] = None,
loss: Union[str, List, Any] = "mse",
metric: Union[str, List] = "mse",
):
duration = (
dataset.duration if isinstance(dataset, EnhancerDataset) else None
)
if dataset is not None:
if sampling_rate != dataset.sampling_rate:
logging.warning(
f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
)
sampling_rate = dataset.sampling_rate
super().__init__(
num_channels=num_channels,
sampling_rate=sampling_rate,
lr=lr,
dataset=dataset,
duration=duration,
loss=loss,
metric=metric,
)
encoder_decoder = merge_dict(self.ED_DEFAULTS, encoder_decoder)
lstm = merge_dict(self.LSTM_DEFAULTS, lstm)
stft = merge_dict(self.STFT_DEFAULTS, stft)
self.save_hyperparameters(
"encoder_decoder",
"lstm",
"stft",
"complex_lstm",
"complex_norm",
"masking_mode",
)
self.complex_lstm = complex_lstm
self.complex_norm = complex_norm
self.masking_mode = masking_mode
self.stft = ConvSTFT(
stft["window_len"], stft["hop_size"], stft["nfft"], stft["window"]
)
self.istft = ConviSTFT(
stft["window_len"], stft["hop_size"], stft["nfft"], stft["window"]
)
self.encoder = nn.ModuleList()
self.decoder = nn.ModuleList()
num_channels *= 2
hidden_size = encoder_decoder["initial_output_channels"]
growth_factor = 2
for layer in range(encoder_decoder["depth"]):
encoder_ = DCCRN_ENCODER(
num_channels,
hidden_size,
kernel_size=(encoder_decoder["kernel_size"], 2),
stride=(encoder_decoder["stride"], 1),
padding=(encoder_decoder["padding"], 1),
complex_norm=complex_norm,
complex_relu=complex_relu,
)
self.encoder.append(encoder_)
decoder_ = DCCRN_DECODER(
hidden_size + hidden_size,
num_channels,
layer=layer,
kernel_size=(encoder_decoder["kernel_size"], 2),
stride=(encoder_decoder["stride"], 1),
padding=(encoder_decoder["padding"], 0),
output_padding=(encoder_decoder["output_padding"], 0),
complex_norm=complex_norm,
complex_relu=complex_relu,
)
self.decoder.insert(0, decoder_)
if layer < encoder_decoder["depth"] - 3:
num_channels = hidden_size
hidden_size *= growth_factor
else:
num_channels = hidden_size
kernel_size = hidden_size / 2
hidden_size = stft["nfft"] / 2 ** (encoder_decoder["depth"])
if self.complex_lstm:
lstms = []
for layer in range(lstm["num_layers"]):
if layer == 0:
input_size = int(hidden_size * kernel_size)
else:
input_size = lstm["hidden_size"]
if layer == lstm["num_layers"] - 1:
projection_size = int(hidden_size * kernel_size)
else:
projection_size = None
kwargs = {
"input_size": input_size,
"hidden_size": lstm["hidden_size"],
"num_layers": 1,
}
lstms.append(
ComplexLSTM(projection_size=projection_size, **kwargs)
)
self.lstm = nn.Sequential(*lstms)
else:
self.lstm = nn.Sequential(
nn.LSTM(
input_size=int(hidden_size * kernel_size),
hidden_size=lstm["hidden_size"],
num_layers=lstm["num_layers"],
dropout=0.0,
batch_first=False,
),
nn.Linear(lstm["hidden_size"], int(hidden_size * kernel_size)),
)
def forward(self, waveform):
if waveform.dim() == 2:
waveform = waveform.unsqueeze(1)
if waveform.size(1) != self.hparams.num_channels:
raise ValueError(
f"Number of input channels initialized is {self.hparams.num_channels} but got {waveform.size(1)} channels"
)
waveform_stft = self.stft(waveform)
real = waveform_stft[:, : self.stft.nfft // 2 + 1]
imag = waveform_stft[:, self.stft.nfft // 2 + 1 :]
mag_spec = torch.sqrt(real**2 + imag**2 + 1e-9)
phase_spec = torch.atan2(imag, real)
complex_spec = torch.stack([mag_spec, phase_spec], 1)[:, :, 1:]
encoder_outputs = []
out = complex_spec
for _, encoder in enumerate(self.encoder):
out = encoder(out)
encoder_outputs.append(out)
B, C, D, T = out.size()
out = out.permute(3, 0, 1, 2)
if self.complex_lstm:
lstm_real = out[:, :, : C // 2]
lstm_imag = out[:, :, C // 2 :]
lstm_real = lstm_real.reshape(T, B, C // 2 * D)
lstm_imag = lstm_imag.reshape(T, B, C // 2 * D)
lstm_real, lstm_imag = self.lstm([lstm_real, lstm_imag])
lstm_real = lstm_real.reshape(T, B, C // 2, D)
lstm_imag = lstm_imag.reshape(T, B, C // 2, D)
out = torch.cat([lstm_real, lstm_imag], 2)
else:
out = out.reshape(T, B, C * D)
out = self.lstm(out)
out = out.reshape(T, B, D, C)
out = out.permute(1, 2, 3, 0)
for layer, decoder in enumerate(self.decoder):
skip_connection = encoder_outputs.pop(-1)
out = complex_cat([skip_connection, out])
out = decoder(out)
out = out[..., 1:]
mask_real, mask_imag = out[:, 0], out[:, 1]
mask_real = F.pad(mask_real, [0, 0, 1, 0])
mask_imag = F.pad(mask_imag, [0, 0, 1, 0])
if self.masking_mode == "E":
mask_mag = torch.sqrt(mask_real**2 + mask_imag**2)
real_phase = mask_real / (mask_mag + 1e-8)
imag_phase = mask_imag / (mask_mag + 1e-8)
mask_phase = torch.atan2(imag_phase, real_phase)
mask_mag = torch.tanh(mask_mag)
est_mag = mask_mag * mag_spec
est_phase = mask_phase * phase_spec
# cos(theta) + isin(theta)
real = est_mag * torch.cos(est_phase)
imag = est_mag * torch.sin(est_phase)
elif self.masking_mode == "C":
real = real * mask_real - imag * mask_imag
imag = real * mask_imag + imag * mask_real
else:
real = real * mask_real
imag = imag * mask_imag
spec = torch.cat([real, imag], 1)
wav = self.istft(spec)
wav = wav.clamp_(-1, 1)
return wav

274
enhancer/models/demucs.py Normal file

@ -0,0 +1,274 @@
import logging
import math
from typing import List, Optional, Union
import torch.nn.functional as F
from torch import nn
from enhancer.data.dataset import EnhancerDataset
from enhancer.models.model import Model
from enhancer.utils.io import Audio as audio
from enhancer.utils.utils import merge_dict
class DemucsLSTM(nn.Module):
def __init__(
self,
input_size: int,
hidden_size: int,
num_layers: int,
bidirectional: bool = True,
):
super().__init__()
self.lstm = nn.LSTM(
input_size, hidden_size, num_layers, bidirectional=bidirectional
)
dim = 2 if bidirectional else 1
self.linear = nn.Linear(dim * hidden_size, hidden_size)
def forward(self, x):
output, (h, c) = self.lstm(x)
output = self.linear(output)
return output, (h, c)
class DemucsEncoder(nn.Module):
def __init__(
self,
num_channels: int,
hidden_size: int,
kernel_size: int,
stride: int = 1,
glu: bool = False,
):
super().__init__()
activation = nn.GLU(1) if glu else nn.ReLU()
multi_factor = 2 if glu else 1
self.encoder = nn.Sequential(
nn.Conv1d(num_channels, hidden_size, kernel_size, stride),
nn.ReLU(),
nn.Conv1d(hidden_size, hidden_size * multi_factor, 1, 1),
activation,
)
def forward(self, waveform):
return self.encoder(waveform)
class DemucsDecoder(nn.Module):
def __init__(
self,
num_channels: int,
hidden_size: int,
kernel_size: int,
stride: int = 1,
glu: bool = False,
layer: int = 0,
):
super().__init__()
activation = nn.GLU(1) if glu else nn.ReLU()
multi_factor = 2 if glu else 1
self.decoder = nn.Sequential(
nn.Conv1d(hidden_size, hidden_size * multi_factor, 1, 1),
activation,
nn.ConvTranspose1d(hidden_size, num_channels, kernel_size, stride),
)
if layer > 0:
self.decoder.add_module("4", nn.ReLU())
def forward(
self,
waveform,
):
out = self.decoder(waveform)
return out
class Demucs(Model):
"""
Demucs model from https://arxiv.org/pdf/1911.13254.pdf
parameters:
encoder_decoder: dict, optional
keyword arguments passed to the encoder/decoder block
lstm : dict, optional
keyword arguments passed to the LSTM block
num_channels: int, defaults to 1
number of channels in input audio
sampling_rate: int, defaults to 16 kHz
sampling rate of input audio
lr : float, defaults to 1e-3
learning rate used for training
dataset: EnhancerDataset, optional
EnhancerDataset object containing train/validation data for training
duration : float, optional
chunk duration in seconds
loss : string or List of strings
loss function to be used, available ("mse", "mae", "si-sdr")
metric : string or List of strings
metric function to be used, available ("mse", "mae", "si-sdr")
"""
ED_DEFAULTS = {
"initial_output_channels": 48,
"kernel_size": 8,
"stride": 4,
"depth": 5,
"glu": True,
"growth_factor": 2,
}
LSTM_DEFAULTS = {
"bidirectional": True,
"num_layers": 2,
}
def __init__(
self,
encoder_decoder: Optional[dict] = None,
lstm: Optional[dict] = None,
num_channels: int = 1,
resample: int = 4,
sampling_rate=16000,
normalize=True,
lr: float = 1e-3,
dataset: Optional[EnhancerDataset] = None,
loss: Union[str, List] = "mse",
metric: Union[str, List] = "mse",
floor=1e-3,
):
duration = (
dataset.duration if isinstance(dataset, EnhancerDataset) else None
)
if dataset is not None:
if sampling_rate != dataset.sampling_rate:
logging.warning(
f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
)
sampling_rate = dataset.sampling_rate
super().__init__(
num_channels=num_channels,
sampling_rate=sampling_rate,
lr=lr,
dataset=dataset,
duration=duration,
loss=loss,
metric=metric,
)
encoder_decoder = merge_dict(self.ED_DEFAULTS, encoder_decoder)
lstm = merge_dict(self.LSTM_DEFAULTS, lstm)
self.save_hyperparameters("encoder_decoder", "lstm", "resample")
hidden = encoder_decoder["initial_output_channels"]
self.normalize = normalize
self.floor = floor
self.encoder = nn.ModuleList()
self.decoder = nn.ModuleList()
for layer in range(encoder_decoder["depth"]):
encoder_layer = DemucsEncoder(
num_channels=num_channels,
hidden_size=hidden,
kernel_size=encoder_decoder["kernel_size"],
stride=encoder_decoder["stride"],
glu=encoder_decoder["glu"],
)
self.encoder.append(encoder_layer)
decoder_layer = DemucsDecoder(
num_channels=num_channels,
hidden_size=hidden,
kernel_size=encoder_decoder["kernel_size"],
stride=encoder_decoder["stride"],
glu=encoder_decoder["glu"],
layer=layer,
)
self.decoder.insert(0, decoder_layer)
num_channels = hidden
hidden = encoder_decoder["growth_factor"] * hidden
self.de_lstm = DemucsLSTM(
input_size=num_channels,
hidden_size=num_channels,
num_layers=lstm["num_layers"],
bidirectional=lstm["bidirectional"],
)
def forward(self, waveform):
if waveform.dim() == 2:
waveform = waveform.unsqueeze(1)
if waveform.size(1) != self.hparams.num_channels:
raise ValueError(
f"Number of input channels initialized is {self.hparams.num_channels} but got {waveform.size(1)} channels"
)
if self.normalize:
waveform = waveform.mean(dim=1, keepdim=True)
std = waveform.std(dim=-1, keepdim=True)
waveform = waveform / (self.floor + std)
else:
std = 1
length = waveform.shape[-1]
x = F.pad(waveform, (0, self.get_padding_length(length) - length))
if self.hparams.resample > 1:
x = audio.resample_audio(
audio=x,
sr=self.hparams.sampling_rate,
target_sr=int(
self.hparams.sampling_rate * self.hparams.resample
),
)
encoder_outputs = []
for encoder in self.encoder:
x = encoder(x)
encoder_outputs.append(x)
x = x.permute(0, 2, 1)
x, _ = self.de_lstm(x)
x = x.permute(0, 2, 1)
for decoder in self.decoder:
skip_connection = encoder_outputs.pop(-1)
x = x + skip_connection[..., : x.shape[-1]]
x = decoder(x)
if self.hparams.resample > 1:
x = audio.resample_audio(
x,
int(self.hparams.sampling_rate * self.hparams.resample),
self.hparams.sampling_rate,
)
out = x[..., :length]
return std * out
def get_padding_length(self, input_length):
input_length = math.ceil(input_length * self.hparams.resample)
for layer in range(
self.hparams.encoder_decoder["depth"]
): # encoder operation
input_length = (
math.ceil(
(input_length - self.hparams.encoder_decoder["kernel_size"])
/ self.hparams.encoder_decoder["stride"]
)
+ 1
)
input_length = max(1, input_length)
for layer in range(
self.hparams.encoder_decoder["depth"]
): # decoder operaration
input_length = (input_length - 1) * self.hparams.encoder_decoder[
"stride"
] + self.hparams.encoder_decoder["kernel_size"]
input_length = math.ceil(input_length / self.hparams.resample)
return int(input_length)
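
For completeness, a hedged sketch of constructing the model and running a waveform through it (it assumes the bundled `Audio.resample_audio` helper behaves as its call sites here suggest):

``` python
import torch
from enhancer.models import Demucs

model = Demucs(sampling_rate=16000)   # defaults: depth 5, kernel 8, stride 4, GLU on
model.eval()

noisy = torch.rand(1, 1, 16000)       # (batch, channels, samples): one second at 16 kHz
with torch.no_grad():
    enhanced = model(noisy)
print(enhanced.shape)                  # trimmed back to the input length
```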

431
enhancer/models/model.py Normal file
View File

@ -0,0 +1,431 @@
import os
from collections import defaultdict
from importlib import import_module
from pathlib import Path
from typing import Any, List, Optional, Text, Union
from urllib.parse import urlparse
import numpy as np
import pytorch_lightning as pl
import torch
from huggingface_hub import cached_download, hf_hub_url
from pytorch_lightning.utilities.cloud_io import load as pl_load
from torch import nn
from torch.optim import Adam
from enhancer.data.dataset import EnhancerDataset
from enhancer.inference import Inference
from enhancer.loss import LOSS_MAP, LossWrapper
from enhancer.version import __version__
CACHE_DIR = ""
HF_TORCH_WEIGHTS = ""
DEFAULT_DEVICE = "cpu"
class Model(pl.LightningModule):
"""
Base class for all models
parameters:
num_channels: int, default to 1
number of channels in input audio
sampling_rate : int, default 16khz
audio sampling rate
lr: float, optional
learning rate for model training
dataset: EnhancerDataset, optional
Enhancer dataset used for training/validation
duration: float, optional
duration used for training/inference
loss : string or List of strings or custom loss (nn.Module), defaults to "mse"
loss functions to be used. Available ("mse","mae","Si-SDR")
"""
def __init__(
self,
num_channels: int = 1,
sampling_rate: int = 16000,
lr: float = 1e-3,
dataset: Optional[EnhancerDataset] = None,
duration: Optional[float] = None,
loss: Union[str, List] = "mse",
metric: Union[str, List, Any] = "mse",
):
super().__init__()
assert (
num_channels == 1
), "Enhancer only support for mono channel models"
self.dataset = dataset
self.save_hyperparameters(
"num_channels", "sampling_rate", "lr", "loss", "metric", "duration"
)
if self.logger:
self.logger.experiment.log_dict(
dict(self.hparams), "hyperparameters.json"
)
self.loss = loss
self.metric = metric
@property
def loss(self):
return self._loss
@loss.setter
def loss(self, loss):
if isinstance(loss, str):
loss = [loss]
self._loss = LossWrapper(loss)
@property
def metric(self):
return self._metric
@metric.setter
def metric(self, metric):
self._metric = []
if isinstance(metric, (str, nn.Module)):
metric = [metric]
for func in metric:
if isinstance(func, str):
if func in LOSS_MAP.keys():
if func in ("pesq", "stoi"):
self._metric.append(
LOSS_MAP[func](self.hparams.sampling_rate)
)
else:
self._metric.append(LOSS_MAP[func]())
else:
ValueError(f"Invalid metrics {func}")
elif isinstance(func, nn.Module):
self._metric.append(func)
else:
raise ValueError("Invalid metrics")
@property
def dataset(self):
return self._dataset
@dataset.setter
def dataset(self, dataset):
self._dataset = dataset
def setup(self, stage: Optional[str] = None):
if stage == "fit":
torch.cuda.empty_cache()
self.dataset.setup(stage)
self.dataset.model = self
print(
"Total train duration",
self.dataset.train_dataloader().dataset.__len__()
* self.dataset.duration
/ 60,
"minutes",
)
print(
"Total validation duration",
self.dataset.val_dataloader().dataset.__len__()
* self.dataset.duration
/ 60,
"minutes",
)
print(
"Total test duration",
self.dataset.test_dataloader().dataset.__len__()
* self.dataset.duration
/ 60,
"minutes",
)
def train_dataloader(self):
return self.dataset.train_dataloader()
def val_dataloader(self):
return self.dataset.val_dataloader()
def test_dataloader(self):
return self.dataset.test_dataloader()
def configure_optimizers(self):
return Adam(self.parameters(), lr=self.hparams.lr)
def training_step(self, batch, batch_idx: int):
mixed_waveform = batch["noisy"]
target = batch["clean"]
prediction = self(mixed_waveform)
loss = self.loss(prediction, target)
self.log(
"train_loss",
loss.item(),
on_epoch=True,
on_step=True,
logger=True,
prog_bar=True,
)
return {"loss": loss}
def validation_step(self, batch, batch_idx: int):
metric_dict = {}
mixed_waveform = batch["noisy"]
target = batch["clean"]
prediction = self(mixed_waveform)
metric_dict["valid_loss"] = self.loss(target, prediction).item()
for metric in self.metric:
value = metric(target, prediction)
metric_dict[f"valid_{metric.name}"] = value.item()
self.log_dict(
metric_dict,
on_step=True,
on_epoch=True,
prog_bar=True,
logger=True,
)
return metric_dict
def test_step(self, batch, batch_idx):
metric_dict = {}
mixed_waveform = batch["noisy"]
target = batch["clean"]
prediction = self(mixed_waveform)
for metric in self.metric:
value = metric(target, prediction)
metric_dict[f"test_{metric.name}"] = value
self.log_dict(
metric_dict,
on_step=True,
on_epoch=True,
prog_bar=True,
logger=True,
)
return metric_dict
def test_epoch_end(self, outputs):
test_mean_metrics = defaultdict(int)
for output in outputs:
for metric, value in output.items():
test_mean_metrics[metric] += value.item()
for metric in test_mean_metrics.keys():
test_mean_metrics[metric] /= len(outputs)
print("----------TEST REPORT----------\n")
for metric in test_mean_metrics.keys():
print(f"|{metric.upper()} | {test_mean_metrics[metric]} |")
print("--------------------------------")
def on_save_checkpoint(self, checkpoint):
checkpoint["enhancer"] = {
"version": {"enhancer": __version__, "pytorch": torch.__version__},
"architecture": {
"module": self.__class__.__module__,
"class": self.__class__.__name__,
},
}
@classmethod
def from_pretrained(
cls,
checkpoint: Union[Path, Text],
map_location=None,
hparams_file: Union[Path, Text] = None,
strict: bool = True,
use_auth_token: Union[Text, None] = None,
cached_dir: Union[Path, Text] = CACHE_DIR,
**kwargs,
):
"""
Load Pretrained model
parameters:
checkpoint : Path or str
Path to checkpoint, or a remote URL, or a model identifier from
the huggingface.co model hub.
map_location: optional
Same role as in torch.load().
Defaults to `lambda storage, loc: storage`.
hparams_file : Path or str, optional
Path to a .yaml file with hierarchical structure as in this example:
drop_prob: 0.2
dataloader:
batch_size: 32
You most likely won't need this since Lightning will always save the
hyperparameters to the checkpoint. However, if your checkpoint weights
do not have the hyperparameters saved, use this method to pass in a .yaml
file with the hparams you would like to use. These will be converted
into a dict and passed into your Model for use.
strict : bool, optional
Whether to strictly enforce that the keys in checkpoint match
the keys returned by this module's state dict. Defaults to True.
use_auth_token : str, optional
When loading a private huggingface.co model, set `use_auth_token`
to True or to a string containing your huggingface.co authentication
token that can be obtained by running `huggingface-cli login`
cached_dir: Path or str, optional
Path to model cache directory
kwargs: optional
Any extra keyword args needed to init the model.
Can also be used to override saved hyperparameter values.
Returns
-------
model : Model
Model
See also
--------
torch.load
"""
checkpoint = str(checkpoint)
if hparams_file is not None:
hparams_file = str(hparams_file)
if os.path.isfile(checkpoint):
model_path_pl = checkpoint
elif urlparse(checkpoint).scheme in ("http", "https"):
model_path_pl = checkpoint
else:
if "@" in checkpoint:
model_id = checkpoint.split("@")[0]
revision_id = checkpoint.split("@")[1]
else:
model_id = checkpoint
revision_id = None
url = hf_hub_url(
model_id, filename=HF_TORCH_WEIGHTS, revision=revision_id
)
model_path_pl = cached_download(
url=url,
library_name="enhancer",
library_version=__version__,
cache_dir=cached_dir,
use_auth_token=use_auth_token,
)
if map_location is None:
map_location = torch.device(DEFAULT_DEVICE)
loaded_checkpoint = pl_load(model_path_pl, map_location)
module_name = loaded_checkpoint["enhancer"]["architecture"]["module"]
class_name = loaded_checkpoint["enhancer"]["architecture"]["class"]
module = import_module(module_name)
Klass = getattr(module, class_name)
try:
model = Klass.load_from_checkpoint(
checkpoint_path=model_path_pl,
map_location=map_location,
hparams_file=hparams_file,
strict=strict,
**kwargs,
)
except Exception as e:
print(e)
raise  # re-raise instead of returning an unbound `model` when loading fails
return model
def infer(self, batch: torch.Tensor, batch_size: int = 32):
"""
perform model inference
parameters:
batch : torch.Tensor
input data
batch_size : int, default 32
batch size for inference
"""
assert (
batch.ndim == 3
), f"Expected batch with 3 dimensions (batch,channels,samples) got only {batch.ndim}"
batch_predictions = []
self.eval().to(self.device)
with torch.no_grad():
for batch_id in range(0, batch.shape[0], batch_size):
batch_data = batch[batch_id : (batch_id + batch_size), :, :].to(
self.device
)
prediction = self(batch_data)
batch_predictions.append(prediction)
return torch.vstack(batch_predictions)
def enhance(
self,
audio: Union[Path, np.ndarray, torch.Tensor],
sampling_rate: Optional[int] = None,
batch_size: int = 32,
save_output: bool = False,
duration: Optional[int] = None,
step_size: Optional[int] = None,
):
"""
Enhance audio using a loaded pretrained model.
parameters:
audio: Path to audio file or numpy array or torch tensor
single input audio
sampling_rate: int, optional in case input is a file path
sampling rate of the input audio
batch_size: int, default 32
input audio is split into multiple chunks. Inference is done on batches
of these chunks according to the given batch size.
save_output : bool, default False
whether to save the output to a file
duration : float, optional
chunk duration in seconds, defaults to the duration of the loaded pretrained model.
step_size: int, optional
step size between consecutive chunks, defaults to 50% of duration
"""
model_sampling_rate = self.hparams["sampling_rate"]
if duration is None:
duration = self.hparams["duration"]
waveform = Inference.read_input(
audio, sampling_rate, model_sampling_rate
)
waveform = waveform.to(self.device)
window_size = round(duration * model_sampling_rate)
batched_waveform = Inference.batchify(
waveform, window_size, step_size=step_size
)
batch_prediction = self.infer(batched_waveform, batch_size=batch_size)
waveform = Inference.aggreagate(
batch_prediction,
window_size,
waveform.shape[-1],
step_size,
)
if save_output and isinstance(audio, (str, Path)):
Inference.write_output(waveform, audio, model_sampling_rate)
else:
waveform = Inference.prepare_output(
waveform, model_sampling_rate, audio, sampling_rate
)
return waveform
@property
def valid_monitor(self):
return "max" if self.loss.higher_better else "min"

201
enhancer/models/waveunet.py Normal file
View File

@ -0,0 +1,201 @@
import logging
from typing import List, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from enhancer.data.dataset import EnhancerDataset
from enhancer.models.model import Model
class WavenetDecoder(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int = 5,
padding: int = 2,
stride: int = 1,
dilation: int = 1,
):
super(WavenetDecoder, self).__init__()
self.decoder = nn.Sequential(
nn.Conv1d(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
),
nn.BatchNorm1d(out_channels),
nn.LeakyReLU(negative_slope=0.1),
)
def forward(self, waveform):
return self.decoder(waveform)
class WavenetEncoder(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int = 15,
padding: int = 7,
stride: int = 1,
dilation: int = 1,
):
super(WavenetEncoder, self).__init__()
self.encoder = nn.Sequential(
nn.Conv1d(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
),
nn.BatchNorm1d(out_channels),
nn.LeakyReLU(negative_slope=0.1),
)
def forward(self, waveform):
return self.encoder(waveform)
class WaveUnet(Model):
"""
Wave-U-Net model from https://arxiv.org/pdf/1811.11307.pdf
parameters:
num_channels: int, defaults to 1
number of channels in input audio
depth : int, defaults to 12
depth of network
initial_output_channels: int, defaults to 24
number of output channels in initial upsampling layer
sampling_rate: int, defaults to 16KHz
sampling rate of input audio
lr : float, defaults to 1e-3
learning rate used for training
dataset: EnhancerDataset, optional
EnhancerDataset object containing train/validation data for training
duration : float, optional
chunk duration in seconds
loss : string or List of strings
loss function to be used, available ("mse","mae","SI-SDR")
metric : string or List of strings
metric function to be used, available ("mse","mae","SI-SDR")
"""
def __init__(
self,
num_channels: int = 1,
depth: int = 12,
initial_output_channels: int = 24,
sampling_rate: int = 16000,
lr: float = 1e-3,
dataset: Optional[EnhancerDataset] = None,
duration: Optional[float] = None,
loss: Union[str, List] = "mse",
metric: Union[str, List] = "mse",
):
duration = (
dataset.duration if isinstance(dataset, EnhancerDataset) else None
)
if dataset is not None:
if sampling_rate != dataset.sampling_rate:
logging.warning(
f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
)
sampling_rate = dataset.sampling_rate
super().__init__(
num_channels=num_channels,
sampling_rate=sampling_rate,
lr=lr,
dataset=dataset,
duration=duration,
loss=loss,
metric=metric,
)
self.save_hyperparameters("depth")
self.encoders = nn.ModuleList()
self.decoders = nn.ModuleList()
out_channels = initial_output_channels
for layer in range(depth):
encoder = WavenetEncoder(num_channels, out_channels)
self.encoders.append(encoder)
num_channels = out_channels
out_channels += initial_output_channels
if layer == depth - 1:
decoder = WavenetDecoder(
depth * initial_output_channels + num_channels, num_channels
)
else:
decoder = WavenetDecoder(
num_channels + out_channels, num_channels
)
self.decoders.insert(0, decoder)
bottleneck_dim = depth * initial_output_channels
self.bottleneck = nn.Sequential(
nn.Conv1d(bottleneck_dim, bottleneck_dim, 15, stride=1, padding=7),
nn.BatchNorm1d(bottleneck_dim),
nn.LeakyReLU(negative_slope=0.1, inplace=True),
)
self.final = nn.Sequential(
nn.Conv1d(1 + initial_output_channels, 1, kernel_size=1, stride=1),
nn.Tanh(),
)
def forward(self, waveform):
if waveform.dim() == 2:
waveform = waveform.unsqueeze(1)
if waveform.size(1) != 1:
raise TypeError(
f"Wave-U-Net can only process mono channel audio, input has {waveform.size(1)} channels"
)
encoder_outputs = []
out = waveform
for encoder in self.encoders:
out = encoder(out)
encoder_outputs.insert(0, out)
out = out[:, :, ::2]
out = self.bottleneck(out)
for layer, decoder in enumerate(self.decoders):
out = F.interpolate(out, scale_factor=2, mode="linear")
out = self.fix_last_dim(out, encoder_outputs[layer])
out = torch.cat([out, encoder_outputs[layer]], dim=1)
out = decoder(out)
out = torch.cat([out, waveform], dim=1)
out = self.final(out)
return out
def fix_last_dim(self, x, target):
"""
centre crop along last dimension
"""
assert (
x.shape[-1] >= target.shape[-1]
), "input dimension cannot be larger than target dimension"
if x.shape[-1] == target.shape[-1]:
return x
diff = x.shape[-1] - target.shape[-1]
if diff % 2 != 0:
x = F.pad(x, (0, 1))
diff += 1
crop = diff // 2
return x[:, :, crop:-crop]
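Because of the `out[:, :, ::2]` decimation in each encoder stage and the centre cropping in `fix_last_dim`, Wave-U-Net accepts mono input of fairly arbitrary length. A minimal sketch, using the same shapes as the tests below:

import torch

from enhancer.models import WaveUnet

model = WaveUnet()  # depth=12, 24 initial output channels by default
model.eval()

mono = torch.rand(1, 1, 1000)  # (batch, channels, samples); must be single channel
with torch.no_grad():
    out = model(mono)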

View File

@ -0,0 +1,3 @@
from enhancer.utils.config import Files
from enhancer.utils.io import Audio
from enhancer.utils.utils import check_files

9
enhancer/utils/config.py Normal file
View File

@ -0,0 +1,9 @@
from dataclasses import dataclass
@dataclass
class Files:
train_clean: str
train_noisy: str
test_clean: str
test_noisy: str
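`Files` only records the directory names of the clean/noisy train and test splits relative to a dataset root; `EnhancerDataset` and `check_files` consume it. A small sketch with illustrative directory names (following the VCTK-style layout used in the tests):

from enhancer.utils.config import Files

files = Files(
    train_clean="clean_trainset_wav",   # illustrative directory names
    train_noisy="noisy_trainset_wav",
    test_clean="clean_testset_wav",
    test_noisy="noisy_testset_wav",
)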

128
enhancer/utils/io.py Normal file
View File

@ -0,0 +1,128 @@
import os
from pathlib import Path
from typing import Optional, Union
import librosa
import numpy as np
import torch
import torchaudio
class Audio:
"""
Audio utils
parameters:
sampling_rate : int, defaults to 16KHz
audio sampling rate
mono: bool, defaults to True
return_tensor: bool, defaults to True
returns torch tensor type if set to True else numpy ndarray
"""
def __init__(
self, sampling_rate: int = 16000, mono: bool = True, return_tensor=True
) -> None:
self.sampling_rate = sampling_rate
self.mono = mono
self.return_tensor = return_tensor
def __call__(
self,
audio: Union[Path, np.ndarray, torch.Tensor],
sampling_rate: Optional[int] = None,
offset: Optional[float] = None,
duration: Optional[float] = None,
):
"""
read and process input audio
parameters:
audio: Path to audio file or numpy array or torch tensor
single input audio
sampling_rate : int, optional
sampling rate of the audio input
offset: float, optional
offset from which the audio must be read, reads from beginning if unused.
duration: float (seconds), optional
read duration, reads full audio starting from offset if not used
"""
if isinstance(audio, str):
if os.path.exists(audio):
audio, sampling_rate = librosa.load(
audio,
sr=sampling_rate,
mono=False,
offset=offset,
duration=duration,
)
if len(audio.shape) == 1:
audio = audio.reshape(1, -1)
else:
raise FileNotFoundError(f"File {audio} does not exist")
elif isinstance(audio, np.ndarray):
if len(audio.shape) == 1:
audio = audio.reshape(1, -1)
else:
raise ValueError("audio should be either filepath or numpy ndarray")
if self.mono:
audio = self.convert_mono(audio)
if sampling_rate:
audio = self.__class__.resample_audio(
audio, sampling_rate, self.sampling_rate
)
if self.return_tensor:
return torch.tensor(audio)
else:
return audio
@staticmethod
def convert_mono(audio: Union[np.ndarray, torch.Tensor]):
"""
convert input audio into mono (single channel)
parameters:
audio: np.ndarray or torch.Tensor
"""
if len(audio.shape) > 2:
assert (
audio.shape[0] == 1
), "convert mono only accepts single waveform"
audio = audio.reshape(audio.shape[1], audio.shape[2])
assert (
audio.shape[1] > audio.shape[0]
), f"expected input format (num_channels,num_samples) got {audio.shape}"
num_channels, num_samples = audio.shape
if num_channels > 1:
return audio.mean(axis=0).reshape(1, num_samples)
return audio
@staticmethod
def resample_audio(
audio: Union[np.ndarray, torch.Tensor], sr: int, target_sr: int
):
"""
resample audio to desired sampling rate
parameters:
audio : Path to audio file or numpy array or torch tensor
audio waveform
sr : int
current sampling rate
target_sr : int
target sampling rate
"""
if sr != target_sr:
if isinstance(audio, np.ndarray):
audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sr)
elif isinstance(audio, torch.Tensor):
audio = torchaudio.functional.resample(
audio, orig_freq=sr, new_freq=target_sr
)
else:
raise ValueError(
"Input should be either numpy array or torch tensor"
)
return audio
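A short sketch of the `Audio` helper: multi-channel numpy input is averaged to mono, optionally resampled to the configured rate, and returned as a torch tensor (or ndarray if `return_tensor=False`):

import numpy as np

from enhancer.utils.io import Audio

audio = Audio(sampling_rate=16000, mono=True, return_tensor=True)

stereo = np.random.rand(2, 32000)          # one second of fake stereo audio at 32 kHz
mono = audio(stereo, sampling_rate=32000)  # averaged to mono and resampled to 16 kHz
print(mono.shape)                          # expected (1, 16000)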

36
enhancer/utils/random.py Normal file
View File

@ -0,0 +1,36 @@
import os
import random
import torch
def create_unique_rng(epoch: int):
"""create unique random number generator for each (worker_id,epoch) combination"""
rng = random.Random()
global_seed = int(os.environ.get("PL_GLOBAL_SEED", "0"))
global_rank = int(os.environ.get("GLOBAL_RANK", "0"))
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
node_rank = int(os.environ.get("NODE_RANK", "0"))
world_size = int(os.environ.get("WORLD_SIZE", "0"))
worker_info = torch.utils.data.get_worker_info()
if worker_info is not None:
num_workers = worker_info.num_workers
worker_id = worker_info.id
else:
num_workers = 1
worker_id = 0
seed = (
global_seed
+ worker_id
+ local_rank * num_workers
+ node_rank * num_workers * global_rank
+ epoch * num_workers * world_size
)
rng.seed(seed)
return rng
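`create_unique_rng` folds the global seed, worker id, ranks and epoch into one seed, so every (worker, epoch) pair samples differently but reproducibly. A sketch of how a dataset might use it (the epoch value and crop offset are illustrative):

from enhancer.utils.random import create_unique_rng

rng = create_unique_rng(epoch=3)   # typically called at the start of each epoch
offset = rng.uniform(0.0, 10.0)    # e.g. a reproducible random crop offset in seconds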

View File

@ -0,0 +1,93 @@
from typing import Optional
import numpy as np
import torch
import torch.nn.functional as F
from scipy.signal import get_window
from torch import nn
class ConvFFT(nn.Module):
def __init__(
self,
window_len: int,
nfft: Optional[int] = None,
window: str = "hamming",
):
super().__init__()
self.window_len = window_len
self.nfft = nfft if nfft else int(2 ** np.ceil(np.log2(window_len)))
self.window = torch.from_numpy(
get_window(window, window_len, fftbins=True).astype("float32")
)
def init_kernel(self, inverse=False):
fourier_basis = np.fft.rfft(np.eye(self.nfft))[: self.window_len]
real, imag = np.real(fourier_basis), np.imag(fourier_basis)
kernel = np.concatenate([real, imag], 1).T
if inverse:
kernel = np.linalg.pinv(kernel).T
kernel = torch.from_numpy(kernel.astype("float32")).unsqueeze(1)
kernel *= self.window
return kernel
class ConvSTFT(ConvFFT):
def __init__(
self,
window_len: int,
hop_size: Optional[int] = None,
nfft: Optional[int] = None,
window: str = "hamming",
):
super().__init__(window_len=window_len, nfft=nfft, window=window)
self.hop_size = hop_size if hop_size else window_len // 2
self.register_buffer("weight", self.init_kernel())
def forward(self, input):
if input.dim() < 2:
raise ValueError(
f"Expected signal with shape 2 or 3 got {input.dim()}"
)
elif input.dim() == 2:
input = input.unsqueeze(1)
else:
pass
input = F.pad(
input,
(self.window_len - self.hop_size, self.window_len - self.hop_size),
)
output = F.conv1d(input, self.weight, stride=self.hop_size)
return output
class ConviSTFT(ConvFFT):
def __init__(
self,
window_len: int,
hop_size: Optional[int] = None,
nfft: Optional[int] = None,
window: str = "hamming",
):
super().__init__(window_len=window_len, nfft=nfft, window=window)
self.hop_size = hop_size if hop_size else window_len // 2
self.register_buffer("weight", self.init_kernel(True))
self.register_buffer("enframe", torch.eye(window_len).unsqueeze(1))
def forward(self, input, phase=None):
if phase is not None:
real = input * torch.cos(phase)
imag = input * torch.sin(phase)
input = torch.cat([real, imag], 1)
out = F.conv_transpose1d(input, self.weight, stride=self.hop_size)
coeff = self.window.unsqueeze(1).repeat(1, 1, input.size(-1)) ** 2
coeff = coeff.to(input.device)
coeff = F.conv_transpose1d(coeff, self.enframe, stride=self.hop_size)
out = out / (coeff + 1e-8)
pad = self.window_len - self.hop_size
out = out[..., pad:-pad]
return out
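`ConvSTFT` and `ConviSTFT` implement the forward and inverse STFT as 1-d convolutions with fixed Fourier kernels, so they stay on the autograd graph. A round-trip sketch with the same settings as the transform test below:

import torch

from enhancer.utils.transforms import ConviSTFT, ConvSTFT

stft = ConvSTFT(window_len=400, hop_size=100, nfft=512)
istft = ConviSTFT(window_len=400, hop_size=100, nfft=512)

signal = torch.rand(1, 1, 16000)
with torch.no_grad():
    spec = stft(signal)   # stacked real and imaginary parts, shape (1, nfft + 2, frames)
    recon = istft(spec)   # back to (1, 1, 16000), up to windowing error at the edges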

27
enhancer/utils/utils.py Normal file
View File

@ -0,0 +1,27 @@
import os
from typing import Optional
from enhancer.utils.config import Files
def check_files(root_dir: str, files: Files):
path_variables = [
member_var
for member_var in dir(files)
if not member_var.startswith("__")
]
for variable in path_variables:
path = getattr(files, variable)
if not os.path.isdir(os.path.join(root_dir, path)):
raise ValueError(f"Invalid {path}, is not a directory")
return files, root_dir
def merge_dict(default_dict: dict, custom: Optional[dict] = None):
params = dict(default_dict)
if custom:
params.update(custom)
return params
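`merge_dict` is how the models combine their hard-coded defaults (e.g. Demucs' ED_DEFAULTS) with user-supplied overrides: keys in the custom dict win on collision. A tiny example with made-up values:

from enhancer.utils.utils import merge_dict

defaults = {"depth": 4, "kernel_size": 8, "stride": 4}   # illustrative defaults
params = merge_dict(defaults, {"depth": 6})
# {'depth': 6, 'kernel_size': 8, 'stride': 4}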

1
enhancer/version.py Normal file
View File

@ -0,0 +1 @@
__version__ = "0.0.1"

8
environment.yml Normal file
View File

@ -0,0 +1,8 @@
name: enhancer
dependencies:
- pip=21.0.1
- python=3.8
- pip:
- -r requirements.txt
- --find-links https://download.pytorch.org/whl/cu113/torch_stable.html

15
pyproject.toml Normal file
View File

@ -0,0 +1,15 @@
[tool.black]
line-length = 80
target-version = ['py38']
exclude = '''
(
/(
\.eggs # exclude a few common directories in the
| \.git # root of the project
| \.mypy_cache
| \.tox
| \.venv
)/
)
'''

19
requirements.txt Normal file
View File

@ -0,0 +1,19 @@
boto3>=1.24.86
huggingface-hub>=0.10.0
hydra-core>=1.2.0
joblib>=1.2.0
librosa>=0.9.2
mlflow>=1.29.0
numpy>=1.23.3
pesq==0.0.4
protobuf>=3.19.6
pystoi==0.3.3
pytest-lazy-fixture>=0.6.3
pytorch-lightning>=1.7.7
scikit-learn>=1.1.2
scipy>=1.9.1
soundfile>=0.11.0
torch>=1.12.1
torch-audiomentations==0.11.0
torchaudio>=0.12.1
tqdm>=4.64.1

100
setup.cfg Normal file
View File

@ -0,0 +1,100 @@
# This file is used to configure your project.
# Read more about the various options under:
# http://setuptools.readthedocs.io/en/latest/setuptools.html#configuring-setup-using-setup-cfg-files
[metadata]
name = enhancer
description = Deep learning for speech enhancement
author = Shahul Ess
author-email = shahules786@gmail.com
license = mit
long-description = file: README.md
long-description-content-type = text/markdown; charset=UTF-8; variant=GFM
# Change if running only on Windows, Mac or Linux (comma-separated)
platforms = Linux, Mac
# Add here all kinds of additional classifiers as defined under
# https://pypi.python.org/pypi?%3Aaction=list_classifiers
classifiers =
Development Status :: 4 - Beta
Programming Language :: Python
[options]
zip_safe = False
packages = find:
include_package_data = True
# DON'T CHANGE THE FOLLOWING LINE! IT WILL BE UPDATED BY PYSCAFFOLD!
setup_requires = setuptools
# Add here dependencies of your project (semicolon/line-separated), e.g.
# install_requires = numpy; scipy
# Require a specific Python version, e.g. Python 2.7 or >= 3.4
python_requires = >=3.8
[options.packages.find]
where = .
exclude =
tests
[options.extras_require]
# Add here additional requirements for extra features, to install with:
# `pip install fastaudio[PDF]` like:
# PDF = ReportLab; RXP
# Add here test requirements (semicolon/line-separated)
testing =
pytest>=7.1.3
pytest-cov>=4.0.0
dev =
pre-commit>=2.20.0
black>=22.8.0
flake8>=5.0.4
cli =
hydra-core >=1.1,<=1.2
[options.entry_points]
console_scripts =
enhancer-train=enhancer.cli.train:train
[test]
# py.test options when running `python setup.py test`
# addopts = --verbose
extras = True
[tool:pytest]
# Options for py.test:
# Specify command line options as you would do when invoking py.test directly.
# e.g. --cov-report html (or xml) for html/xml output or --junitxml junit.xml
# in order to write a coverage file that can be read by Jenkins.
addopts =
--cov enhancer --cov-report term-missing
--verbose
norecursedirs =
dist
build
.tox
testpaths = tests
[aliases]
dists = bdist_wheel
[bdist_wheel]
# Use this option if your package is pure-python
universal = 1
[build_sphinx]
source_dir = doc
build_dir = build/sphinx
[devpi:upload]
# Options for the devpi: PyPI server and packaging tool
# VCS export must be deactivated since we are using setuptools-scm
no-vcs = 1
formats = bdist_wheel
[flake8]
# Some sane defaults for the code style checker flake8
exclude =
.tox
build
dist
.eggs

63
setup.py Normal file
View File

@ -0,0 +1,63 @@
import os
import sys
from pathlib import Path
from pkg_resources import VersionConflict, require
from setuptools import find_packages, setup
with open("README.md") as f:
long_description = f.read()
with open("requirements.txt") as f:
requirements = f.read().splitlines()
try:
require("setuptools>=38.3")
except VersionConflict:
print("Error: version of setuptools is too old (<38.3)!")
sys.exit(1)
ROOT_DIR = Path(__file__).parent.resolve()
# Creating the version file
with open("version.txt") as f:
version = f.read()
version = version.strip()
sha = "Unknown"
if os.getenv("BUILD_VERSION"):
version = os.getenv("BUILD_VERSION")
elif sha != "Unknown":
version += "+" + sha[:7]
print("-- Building version " + version)
version_path = ROOT_DIR / "enhancer" / "version.py"
with open(version_path, "w") as f:
f.write("__version__ = '{}'\n".format(version))
if __name__ == "__main__":
setup(
name="enhancer",
namespace_packages=["enhancer"],
version=version,
packages=find_packages(),
install_requires=requirements,
description="Deep learning toolkit for speech enhancement",
long_description=long_description,
long_description_content_type="text/markdown",
author="Shahul Es",
author_email="shahules786@gmail.com",
url="",
classifiers=[
"Development Status :: 4 - Beta",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: MIT License",
"Natural Language :: English",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Topic :: Scientific/Engineering",
],
)

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -0,0 +1,32 @@
import pytest
import torch
from enhancer.loss import mean_absolute_error, mean_squared_error
loss_functions = [mean_absolute_error(), mean_squared_error()]
def check_loss_shapes_compatibility(loss_fun):
batch_size = 4
shape = (1, 1000)
loss_fun(torch.rand(batch_size, *shape), torch.rand(batch_size, *shape))
with pytest.raises(TypeError):
loss_fun(torch.rand(4, *shape), torch.rand(6, *shape))
@pytest.mark.parametrize("loss", loss_functions)
def test_loss_input_shapes(loss):
check_loss_shapes_compatibility(loss)
@pytest.mark.parametrize("loss", loss_functions)
def test_loss_output_type(loss):
batch_size = 4
prediction, target = torch.rand(batch_size, 1, 1000), torch.rand(
batch_size, 1, 1000
)
loss_value = loss(prediction, target)
assert isinstance(loss_value.item(), float)

View File

@ -0,0 +1,50 @@
import torch
from enhancer.models.complexnn.conv import ComplexConv2d, ComplexConvTranspose2d
from enhancer.models.complexnn.rnn import ComplexLSTM
from enhancer.models.complexnn.utils import ComplexBatchNorm2D
def test_complexconv2d():
sample_input = torch.rand(1, 2, 256, 13)
conv = ComplexConv2d(
2, 32, kernel_size=(5, 2), stride=(2, 1), padding=(2, 1)
)
with torch.no_grad():
out = conv(sample_input)
assert out.shape == torch.Size([1, 32, 128, 13])
def test_complexconvtranspose2d():
sample_input = torch.rand(1, 512, 4, 13)
conv = ComplexConvTranspose2d(
256 * 2,
128 * 2,
kernel_size=(5, 2),
stride=(2, 1),
padding=(2, 0),
output_padding=(1, 0),
)
with torch.no_grad():
out = conv(sample_input)
assert out.shape == torch.Size([1, 256, 8, 14])
def test_complexlstm():
sample_input = torch.rand(13, 2, 128)
lstm = ComplexLSTM(128 * 2, 128 * 2, projection_size=512 * 2)
with torch.no_grad():
out = lstm(sample_input)
assert out[0].shape == torch.Size([13, 1, 512])
assert out[1].shape == torch.Size([13, 1, 512])
def test_complexbatchnorm2d():
sample_input = torch.rand(1, 64, 64, 14)
batchnorm = ComplexBatchNorm2D(num_features=64)
with torch.no_grad():
out = batchnorm(sample_input)
assert out.size() == sample_input.size()

View File

@ -0,0 +1,43 @@
import pytest
import torch
from enhancer.data.dataset import EnhancerDataset
from enhancer.models import Demucs
from enhancer.utils.config import Files
@pytest.fixture
def vctk_dataset():
root_dir = "tests/data/vctk"
files = Files(
train_clean="clean_testset_wav",
train_noisy="noisy_testset_wav",
test_clean="clean_testset_wav",
test_noisy="noisy_testset_wav",
)
dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files)
return dataset
@pytest.mark.parametrize("batch_size,samples", [(1, 1000)])
def test_forward(batch_size, samples):
model = Demucs()
model.eval()
data = torch.rand(batch_size, 1, samples, requires_grad=False)
with torch.no_grad():
_ = model(data)
data = torch.rand(batch_size, 2, samples, requires_grad=False)
with torch.no_grad():
with pytest.raises(ValueError):
_ = model(data)
@pytest.mark.parametrize(
"dataset,channels,loss",
[(pytest.lazy_fixture("vctk_dataset"), 1, ["mae", "mse"])],
)
def test_demucs_init(dataset, channels, loss):
with torch.no_grad():
_ = Demucs(num_channels=channels, dataset=dataset, loss=loss)

View File

@ -0,0 +1,43 @@
import pytest
import torch
from enhancer.data.dataset import EnhancerDataset
from enhancer.models.dccrn import DCCRN
from enhancer.utils.config import Files
@pytest.fixture
def vctk_dataset():
root_dir = "tests/data/vctk"
files = Files(
train_clean="clean_testset_wav",
train_noisy="noisy_testset_wav",
test_clean="clean_testset_wav",
test_noisy="noisy_testset_wav",
)
dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files)
return dataset
@pytest.mark.parametrize("batch_size,samples", [(1, 1000)])
def test_forward(batch_size, samples):
model = DCCRN()
model.eval()
data = torch.rand(batch_size, 1, samples, requires_grad=False)
with torch.no_grad():
_ = model(data)
data = torch.rand(batch_size, 2, samples, requires_grad=False)
with torch.no_grad():
with pytest.raises(ValueError):
_ = model(data)
@pytest.mark.parametrize(
"dataset,channels,loss",
[(pytest.lazy_fixture("vctk_dataset"), 1, ["mae", "mse"])],
)
def test_demucs_init(dataset, channels, loss):
with torch.no_grad():
_ = DCCRN(num_channels=channels, dataset=dataset, loss=loss)

View File

@ -0,0 +1,43 @@
import pytest
import torch
from enhancer.data.dataset import EnhancerDataset
from enhancer.models import WaveUnet
from enhancer.utils.config import Files
@pytest.fixture
def vctk_dataset():
root_dir = "tests/data/vctk"
files = Files(
train_clean="clean_testset_wav",
train_noisy="noisy_testset_wav",
test_clean="clean_testset_wav",
test_noisy="noisy_testset_wav",
)
dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files)
return dataset
@pytest.mark.parametrize("batch_size,samples", [(1, 1000)])
def test_forward(batch_size, samples):
model = WaveUnet()
model.eval()
data = torch.rand(batch_size, 1, samples, requires_grad=False)
with torch.no_grad():
_ = model(data)
data = torch.rand(batch_size, 2, samples, requires_grad=False)
with torch.no_grad():
with pytest.raises(TypeError):
_ = model(data)
@pytest.mark.parametrize(
"dataset,channels,loss",
[(pytest.lazy_fixture("vctk_dataset"), 1, ["mae", "mse"])],
)
def test_demucs_init(dataset, channels, loss):
with torch.no_grad():
_ = WaveUnet(num_channels=channels, dataset=dataset, loss=loss)

29
tests/test_inference.py Normal file
View File

@ -0,0 +1,29 @@
import pytest
import torch
from enhancer.inference import Inference
@pytest.mark.parametrize(
"audio",
["tests/data/vctk/clean_testset_wav/p257_166.wav", torch.rand(1, 2, 48000)],
)
def test_read_input(audio):
read_audio = Inference.read_input(audio, 48000, 16000)
assert isinstance(read_audio, torch.Tensor)
assert read_audio.shape[0] == 1
def test_batchify():
rand = torch.rand(1, 1000)
batched_rand = Inference.batchify(rand, window_size=100, step_size=100)
assert batched_rand.shape[0] == 12
def test_aggregate():
rand = torch.rand(12, 1, 100)
agg_rand = Inference.aggreagate(
data=rand, window_size=100, total_frames=1000, step_size=100
)
assert agg_rand.shape[-1] == 1000

18
tests/transforms_test.py Normal file
View File

@ -0,0 +1,18 @@
import torch
from enhancer.utils.transforms import ConviSTFT, ConvSTFT
def test_stft_istft():
sample_input = torch.rand(1, 1, 16000)
stft = ConvSTFT(window_len=400, hop_size=100, nfft=512)
istft = ConviSTFT(window_len=400, hop_size=100, nfft=512)
with torch.no_grad():
spectrogram = stft(sample_input)
waveform = istft(spectrogram)
assert sample_input.shape == waveform.shape
assert (
torch.isclose(waveform, sample_input).sum().item()
> sample_input.shape[-1] // 2
)

49
tests/utils_test.py Normal file
View File

@ -0,0 +1,49 @@
import numpy as np
import pytest
import torch
from enhancer.data.fileprocessor import Fileprocessor
from enhancer.utils.io import Audio
def test_io_channel():
input_audio = np.random.rand(2, 32000)
audio = Audio(mono=True, return_tensor=False)
output_audio = audio(input_audio)
assert output_audio.shape[0] == 1
def test_io_resampling():
input_audio = np.random.rand(1, 32000)
resampled_audio = Audio.resample_audio(input_audio, 16000, 8000)
input_audio = torch.rand(1, 32000)
resampled_audio_pt = Audio.resample_audio(input_audio, 16000, 8000)
assert resampled_audio.shape[1] == resampled_audio_pt.size(1) == 16000
def test_fileprocessor_vctk():
fp = Fileprocessor.from_name(
"vctk",
"tests/data/vctk/clean_testset_wav",
"tests/data/vctk/noisy_testset_wav",
)
matching_dict = fp.prepare_matching_dict()
assert len(matching_dict) == 2
@pytest.mark.parametrize("dataset_name", ["vctk", "dns-2020"])
def test_fileprocessor_names(dataset_name):
fp = Fileprocessor.from_name(dataset_name, "clean_dir", "noisy_dir")
assert hasattr(fp.matching_function, "__call__")
def test_fileprocessor_invalid_name():
with pytest.raises(ValueError):
_ = Fileprocessor.from_name(
"undefined", "clean_dir", "noisy_dir", 16000
).prepare_matching_dict()

1
version.txt Normal file
View File

@ -0,0 +1 @@
0.0.1