commit 1d366d6096

@@ -0,0 +1,9 @@
[flake8]
per-file-ignores = __init__.py:F401
ignore = E203, E266, E501, W503
# line length is intentionally set to 80 here because black uses Bugbear
# See https://github.com/psf/black/blob/master/README.md#line-length for more details
max-line-length = 80
max-complexity = 18
select = B,C,E,F,W,T4,B9
exclude = tools/kaldi_decoder
@@ -0,0 +1,51 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions

name: Enhancer

on:
  push:
    branches: [ dev ]
  pull_request:
    branches: [ dev ]

jobs:
  build:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: [3.8]
    steps:
    - uses: actions/checkout@v2
    - name: Set up Python
      uses: actions/setup-python@v1.1.1
      env:
        ACTIONS_ALLOW_UNSECURE_COMMANDS: true
      with:
        python-version: ${{ matrix.python-version }}
    - name: Cache pip
      uses: actions/cache@v1
      with:
        path: ~/.cache/pip # This path is specific to Ubuntu
        # Look to see if there is a cache hit for the corresponding requirements file
        key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }}
        restore-keys: |
          ${{ runner.os }}-pip-
          ${{ runner.os }}-
    # You can test your matrix by printing the current Python version
    - name: Display Python version
      run: python -c "import sys; print(sys.version)"
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        sudo apt-get install libsndfile1
        pip install -r requirements.txt
        pip install black pytest-cov
    - name: Install enhancer
      run: |
        pip install -e .[dev,testing]
    - name: Run black
      run: black --check . --exclude enhancer/version.py
    - name: Test with pytest
      run: pytest tests --cov=enhancer/
@@ -1,3 +1,10 @@
#local
*.ckpt
*_local.yaml
cli/train_config/dataset/Vctk_local.yaml
.DS_Store
outputs/
datasets/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
@@ -0,0 +1,43 @@
repos:
  # # Clean Notebooks
  # - repo: https://github.com/kynan/nbstripout
  #   rev: master
  #   hooks:
  #     - id: nbstripout

  # Format Code
  - repo: https://github.com/ambv/black
    rev: 22.8.0
    hooks:
      - id: black

  # Sort imports
  - repo: https://github.com/PyCQA/isort
    rev: 5.10.1
    hooks:
      - id: isort
        args: ["--profile", "black"]

  - repo: https://gitlab.com/pycqa/flake8
    rev: 5.0.4
    hooks:
      - id: flake8
        args: ['--ignore=E203,E501,F811,E712,W503']

  # Formatting, Whitespace, etc
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v3.2.0
    hooks:
      - id: trailing-whitespace
      - id: check-added-large-files
        args: ['--maxkb=1000']
      - id: check-ast
      - id: check-json
      - id: check-merge-conflict
      - id: check-xml
      - id: check-yaml
      - id: debug-statements
      - id: end-of-file-fixer
      - id: requirements-txt-fixer
      - id: mixed-line-ending
        args: ['--fix=no']
README.md
@@ -1 +1,43 @@
# enhancer
<p align="center">
  <img src="https://user-images.githubusercontent.com/25312635/195514652-e4526cd1-1177-48e9-a80d-c8bfdb95d35f.png" />
</p>

mayavoz is a PyTorch-based open-source toolkit for speech enhancement. It is designed to save time for audio researchers. It provides easy-to-use pretrained audio enhancement models and facilitates highly customisable model training.

| **[Quick Start]()** | **[Installation]()** | **[Tutorials]()** | **[Available Recipes]()**

## Key features :key:

* Various pretrained models nicely integrated with huggingface :hugs: that users can select and use without any hassle.
* :package: Ability to train and validate your own custom speech enhancement models with just under 10 lines of code! (A minimal training sketch follows this README.)
* :magic_wand: A command line tool that facilitates training of highly customisable speech enhancement models from the terminal itself!
* :zap: Supports multi-gpu training integrated with PyTorch Lightning.

## Quick Start :fire:

``` python
from mayavoz import Mayamodel

model = Mayamodel.from_pretrained("mayavoz/waveunet")
model("noisy_audio.wav")
```

## Installation
Only Python 3.8+ is officially supported (though it might work with Python 3.7).

- With Pypi
```
pip install mayavoz
```

- With conda
```
conda env create -f environment.yml
conda activate mayavoz
```

- From source code
```
git clone url
cd mayavoz
pip install -e .
```
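To make the "under 10 lines" training claim concrete, here is a minimal training sketch. It assumes the `EnhancerDataset`, `Demucs`, and `Files` APIs introduced later in this commit; the folder names and the exact `Demucs` keyword arguments are illustrative assumptions, not documented API.

```python
# Minimal training sketch (assumed API; folder names are illustrative).
import pytorch_lightning as pl

from enhancer.data import EnhancerDataset
from enhancer.models import Demucs
from enhancer.utils.config import Files

files = Files(
    train_clean="clean_trainset_28spk_wav",
    train_noisy="noisy_trainset_28spk_wav",
    test_clean="clean_testset_wav",
    test_noisy="noisy_testset_wav",
)
dataset = EnhancerDataset(
    name="vctk", root_dir="datasets/vctk", files=files, sampling_rate=16000
)
model = Demucs(dataset=dataset)  # loss/metric kwargs as passed in cli/train.py
pl.Trainer(max_epochs=1).fit(model)
```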
@@ -0,0 +1 @@
__import__("pkg_resources").declare_namespace(__name__)
@@ -0,0 +1,120 @@
import os
from types import MethodType

import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.callbacks import (
    EarlyStopping,
    LearningRateMonitor,
    ModelCheckpoint,
)
from pytorch_lightning.loggers import MLFlowLogger
from torch.optim.lr_scheduler import ReduceLROnPlateau

# from torch_audiomentations import Compose, Shift

os.environ["HYDRA_FULL_ERROR"] = "1"
JOB_ID = os.environ.get("SLURM_JOBID", "0")


@hydra.main(config_path="train_config", config_name="config")
def main(config: DictConfig):

    OmegaConf.save(config, "config_log.yaml")

    callbacks = []
    logger = MLFlowLogger(
        experiment_name=config.mlflow.experiment_name,
        run_name=config.mlflow.run_name,
        tags={"JOB_ID": JOB_ID},
    )

    parameters = config.hyperparameters
    # apply_augmentations = Compose(
    #     [
    #         Shift(min_shift=0.5, max_shift=1.0, shift_unit="seconds", p=0.5),
    #     ]
    # )

    dataset = instantiate(config.dataset, augmentations=None)
    model = instantiate(
        config.model,
        dataset=dataset,
        lr=parameters.get("lr"),
        loss=parameters.get("loss"),
        metric=parameters.get("metric"),
    )

    direction = model.valid_monitor
    checkpoint = ModelCheckpoint(
        dirpath="./model",
        filename=f"model_{JOB_ID}",
        monitor="valid_loss",
        verbose=False,
        mode=direction,
        every_n_epochs=1,
    )
    callbacks.append(checkpoint)
    callbacks.append(LearningRateMonitor(logging_interval="epoch"))

    if parameters.get("Early_stop", False):
        early_stopping = EarlyStopping(
            monitor="valid_loss",
            mode=direction,
            min_delta=0.0,
            patience=parameters.get("EarlyStopping_patience", 10),
            strict=True,
            verbose=False,
        )
        callbacks.append(early_stopping)

    def configure_optimizers(self):
        optimizer = instantiate(
            config.optimizer,
            lr=parameters.get("lr"),
            params=self.parameters(),
        )
        scheduler = ReduceLROnPlateau(
            optimizer=optimizer,
            mode=direction,
            factor=parameters.get("ReduceLr_factor", 0.1),
            verbose=True,
            min_lr=parameters.get("min_lr", 1e-6),
            patience=parameters.get("ReduceLr_patience", 3),
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": scheduler,
            "monitor": f'valid_{parameters.get("ReduceLr_monitor", "loss")}',
        }

    model.configure_optimizers = MethodType(configure_optimizers, model)

    trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks)
    trainer.fit(model)
    trainer.test(model)

    logger.experiment.log_artifact(
        logger.run_id, f"{trainer.default_root_dir}/config_log.yaml"
    )

    saved_location = os.path.join(
        trainer.default_root_dir, "model", f"model_{JOB_ID}.ckpt"
    )
    if os.path.isfile(saved_location):
        logger.experiment.log_artifact(logger.run_id, saved_location)
    logger.experiment.log_param(
        logger.run_id,
        "num_train_steps_per_epoch",
        dataset.train__len__() / dataset.batch_size,
    )
    logger.experiment.log_param(
        logger.run_id,
        "num_valid_steps_per_epoch",
        dataset.val__len__() / dataset.batch_size,
    )


if __name__ == "__main__":
    main()
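The training script above patches `configure_optimizers` onto the instantiated model at runtime rather than subclassing it. A standalone illustration of this `MethodType` pattern (generic Python, not part of the commit):

```python
from types import MethodType


class Greeter:
    def greet(self):
        return "hello"


def shout(self):
    # `self` is bound to the instance, exactly like a regular method
    return self.greet().upper()


g = Greeter()
g.shout = MethodType(shout, g)  # only this instance gains .shout
print(g.shout())  # HELLO
```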
@@ -0,0 +1,7 @@
defaults:
  - model: Demucs
  - dataset: Vctk
  - optimizer: Adam
  - hyperparameters: default
  - trainer: default
  - mlflow: experiment
@@ -0,0 +1,12 @@
_target_: enhancer.data.dataset.EnhancerDataset
root_dir: /Users/shahules/Myprojects/MS-SNSD
name: dns-2020
duration: 2.0
sampling_rate: 16000
batch_size: 32
valid_size: 0.05
files:
  train_clean: CleanSpeech_training
  test_clean: CleanSpeech_training
  train_noisy: NoisySpeech_training
  test_noisy: NoisySpeech_training
@@ -0,0 +1,13 @@
_target_: enhancer.data.dataset.EnhancerDataset
name: vctk
root_dir: /scratch/c.sistc3/DS_10283_2791
duration: 4.5
stride: 2
sampling_rate: 16000
batch_size: 32
valid_minutes: 15
files:
  train_clean: clean_trainset_28spk_wav
  test_clean: clean_testset_wav
  train_noisy: noisy_trainset_28spk_wav
  test_noisy: noisy_testset_wav
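Each of these dataset files resolves through `hydra.utils.instantiate`, exactly as `cli/train.py` does: the `_target_` path is imported and called with the remaining keys as keyword arguments. A minimal sketch, assuming the file lives at the path used in this commit:

```python
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.load("cli/train_config/dataset/Vctk.yaml")
# equivalent to EnhancerDataset(name="vctk", root_dir=..., files=..., ...)
dataset = instantiate(cfg, augmentations=None)
```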
@@ -0,0 +1,13 @@
_target_: enhancer.data.dataset.EnhancerDataset
name: vctk
root_dir: /Users/shahules/Myprojects/enhancer/datasets/vctk
duration: 1.0
sampling_rate: 16000
batch_size: 64
num_workers: 0

files:
  train_clean: clean_testset_wav
  test_clean: clean_testset_wav
  train_noisy: noisy_testset_wav
  test_noisy: noisy_testset_wav
@@ -0,0 +1,7 @@
loss: mae
metric: [stoi, pesq, si-sdr]
lr: 0.0003
ReduceLr_patience: 5
ReduceLr_factor: 0.2
min_lr: 0.000001
EarlyStopping_patience: 10
@@ -0,0 +1,2 @@
experiment_name: shahules/enhancer
run_name: Demucs + Vctk with stride + augmentations
@@ -0,0 +1,25 @@
_target_: enhancer.models.dccrn.DCCRN
num_channels: 1
sampling_rate: 16000
complex_lstm: True
complex_norm: True
complex_relu: True
masking_mode: True

encoder_decoder:
  initial_output_channels: 32
  depth: 6
  kernel_size: 5
  growth_factor: 2
  stride: 2
  padding: 2
  output_padding: 1

lstm:
  num_layers: 2
  hidden_size: 256

stft:
  window_len: 400
  hop_size: 100
  nfft: 512
@@ -0,0 +1,16 @@
_target_: enhancer.models.demucs.Demucs
num_channels: 1
resample: 4
sampling_rate: 16000

encoder_decoder:
  depth: 4
  initial_output_channels: 64
  kernel_size: 8
  stride: 4
  growth_factor: 2
  glu: True

lstm:
  bidirectional: False
  num_layers: 2
@@ -0,0 +1,5 @@
_target_: enhancer.models.waveunet.WaveUnet
num_channels: 1
depth: 9
initial_output_channels: 24
sampling_rate: 16000
@@ -0,0 +1,6 @@
_target_: torch.optim.Adam
lr: 1e-3
betas: [0.9, 0.999]
eps: 1e-08
weight_decay: 0
amsgrad: False
@@ -0,0 +1,46 @@
_target_: pytorch_lightning.Trainer
accelerator: gpu
accumulate_grad_batches: 1
amp_backend: native
auto_lr_find: True
auto_scale_batch_size: False
auto_select_gpus: True
benchmark: False
check_val_every_n_epoch: 1
detect_anomaly: False
deterministic: False
devices: 2
enable_checkpointing: True
enable_model_summary: True
enable_progress_bar: True
fast_dev_run: False
gpus: null
gradient_clip_val: 0
gradient_clip_algorithm: norm
ipus: null
limit_predict_batches: 1.0
limit_test_batches: 1.0
limit_train_batches: 1.0
limit_val_batches: 1.0
log_every_n_steps: 50
max_epochs: 200
max_steps: -1
max_time: null
min_epochs: 1
min_steps: null
move_metrics_to_cpu: False
multiple_trainloader_mode: max_size_cycle
num_nodes: 1
num_processes: 1
num_sanity_val_steps: 2
overfit_batches: 0.0
precision: 32
profiler: null
reload_dataloaders_every_n_epochs: 0
replace_sampler_ddp: True
strategy: ddp
sync_batchnorm: False
tpu_cores: null
track_grad_norm: -1
val_check_interval: 1.0
weights_save_path: null
@@ -0,0 +1,2 @@
_target_: pytorch_lightning.Trainer
fast_dev_run: True
@@ -0,0 +1 @@
from enhancer.data.dataset import EnhancerDataset
@@ -0,0 +1,376 @@
import math
import multiprocessing
import os
from pathlib import Path
from typing import Optional

import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, RandomSampler
from torch_audiomentations import Compose

from enhancer.data.fileprocessor import Fileprocessor
from enhancer.utils import check_files
from enhancer.utils.config import Files
from enhancer.utils.io import Audio
from enhancer.utils.random import create_unique_rng

LARGE_NUM = 2147483647


class TrainDataset(Dataset):
    def __init__(self, dataset):
        self.dataset = dataset

    def __getitem__(self, idx):
        return self.dataset.train__getitem__(idx)

    def __len__(self):
        return self.dataset.train__len__()


class ValidDataset(Dataset):
    def __init__(self, dataset):
        self.dataset = dataset

    def __getitem__(self, idx):
        return self.dataset.val__getitem__(idx)

    def __len__(self):
        return self.dataset.val__len__()


class TestDataset(Dataset):
    def __init__(self, dataset):
        self.dataset = dataset

    def __getitem__(self, idx):
        return self.dataset.test__getitem__(idx)

    def __len__(self):
        return self.dataset.test__len__()


class TaskDataset(pl.LightningDataModule):
    def __init__(
        self,
        name: str,
        root_dir: str,
        files: Files,
        min_valid_minutes: float = 0.20,
        duration: float = 1.0,
        stride=None,
        sampling_rate: int = 48000,
        matching_function=None,
        batch_size=32,
        num_workers: Optional[int] = None,
        augmentations: Optional[Compose] = None,
    ):
        super().__init__()

        self.name = name
        self.files, self.root_dir = check_files(root_dir, files)
        self.duration = duration
        self.stride = stride or duration
        self.sampling_rate = sampling_rate
        self.batch_size = batch_size
        self.matching_function = matching_function
        self._validation = []
        if num_workers is None:
            num_workers = multiprocessing.cpu_count() // 2
        self.num_workers = num_workers
        if min_valid_minutes > 0.0:
            self.min_valid_minutes = min_valid_minutes
        else:
            raise ValueError("min_valid_minutes must be greater than 0")

        self.augmentations = augmentations

    def setup(self, stage: Optional[str] = None):
        """
        prepare train/validation/test data splits
        """

        if stage in ("fit", None):

            train_clean = os.path.join(self.root_dir, self.files.train_clean)
            train_noisy = os.path.join(self.root_dir, self.files.train_noisy)
            fp = Fileprocessor.from_name(
                self.name, train_clean, train_noisy, self.matching_function
            )
            train_data = fp.prepare_matching_dict()
            train_data, self.val_data = self.train_valid_split(
                train_data,
                min_valid_minutes=self.min_valid_minutes,
                random_state=42,
            )

            self.train_data = self.prepare_traindata(train_data)
            self._validation = self.prepare_mapstype(self.val_data)

            test_clean = os.path.join(self.root_dir, self.files.test_clean)
            test_noisy = os.path.join(self.root_dir, self.files.test_noisy)
            fp = Fileprocessor.from_name(
                self.name, test_clean, test_noisy, self.matching_function
            )
            test_data = fp.prepare_matching_dict()
            self._test = self.prepare_mapstype(test_data)

    def train_valid_split(
        self, data, min_valid_minutes: float = 20, random_state: int = 42
    ):

        min_valid_sec = min_valid_minutes * 60
        valid_sec_now = 0.0
        valid_indices = []
        all_speakers = np.unique(
            [Path(file["clean"]).name.split("_")[0] for file in data]
        )
        possible_indices = list(range(0, len(all_speakers)))
        rng = create_unique_rng(len(all_speakers))

        while valid_sec_now <= min_valid_sec:
            speaker_index = rng.choice(possible_indices)
            possible_indices.remove(speaker_index)
            speaker_name = all_speakers[speaker_index]
            print(f"Selected {speaker_name} for valid")
            file_indices = [
                i
                for i, file in enumerate(data)
                if speaker_name == Path(file["clean"]).name.split("_")[0]
            ]
            for i in file_indices:
                valid_indices.append(i)
                valid_sec_now += data[i]["duration"]

        train_data = [
            item for i, item in enumerate(data) if i not in valid_indices
        ]
        valid_data = [item for i, item in enumerate(data) if i in valid_indices]
        return train_data, valid_data

    def prepare_traindata(self, data):
        train_data = []
        for item in data:
            clean, noisy, total_dur = item.values()
            num_segments = self.get_num_segments(
                total_dur, self.duration, self.stride
            )
            samples_metadata = ({"clean": clean, "noisy": noisy}, num_segments)
            train_data.append(samples_metadata)
        return train_data

    @staticmethod
    def get_num_segments(file_duration, duration, stride):

        if file_duration < duration:
            num_segments = 1
        else:
            num_segments = math.ceil((file_duration - duration) / stride) + 1

        return num_segments

    def prepare_mapstype(self, data):

        metadata = []
        for item in data:
            clean, noisy, total_dur = item.values()
            if total_dur < self.duration:
                metadata.append(({"clean": clean, "noisy": noisy}, 0.0))
            else:
                num_segments = self.get_num_segments(
                    total_dur, self.duration, self.duration
                )
                for index in range(num_segments):
                    start_time = index * self.duration
                    metadata.append(
                        ({"clean": clean, "noisy": noisy}, start_time)
                    )
        return metadata

    def train_collatefn(self, batch):

        output = {"clean": [], "noisy": []}
        for item in batch:
            output["clean"].append(item["clean"])
            output["noisy"].append(item["noisy"])

        output["clean"] = torch.stack(output["clean"], dim=0)
        output["noisy"] = torch.stack(output["noisy"], dim=0)

        if self.augmentations is not None:
            noise = output["noisy"] - output["clean"]
            output["clean"] = self.augmentations(
                output["clean"], sample_rate=self.sampling_rate
            )
            self.augmentations.freeze_parameters()
            output["noisy"] = (
                self.augmentations(noise, sample_rate=self.sampling_rate)
                + output["clean"]
            )

        return output

    @property
    def generator(self):
        generator = torch.Generator()
        if hasattr(self, "model"):
            seed = self.model.current_epoch + LARGE_NUM
        else:
            seed = LARGE_NUM
        return generator.manual_seed(seed)

    def train_dataloader(self):
        dataset = TrainDataset(self)
        sampler = RandomSampler(dataset, generator=self.generator)
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            sampler=sampler,
            collate_fn=self.train_collatefn,
        )

    def val_dataloader(self):
        return DataLoader(
            ValidDataset(self),
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )

    def test_dataloader(self):
        return DataLoader(
            TestDataset(self),
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )


class EnhancerDataset(TaskDataset):
    """
    Dataset object for creating clean-noisy speech enhancement datasets

    parameters:
    name : str
        name of the dataset
    root_dir : str
        root directory of the dataset containing clean/noisy folders
    files : Files
        dataclass containing train_clean, train_noisy, test_clean, test_noisy
        folder names (refer enhancer.utils.Files dataclass)
    min_valid_minutes : float
        minimum validation split size, in minutes.
        The algorithm randomly selects speakers (totalling >= min_valid_minutes)
        from the train data to form the validation data.
    duration : float
        expected audio duration of a single audio sample for training
    sampling_rate : int
        desired sampling rate
    batch_size : int
        batch size of each batch
    num_workers : int
        number of workers to be used while training
    matching_function : str
        matching function - (one_to_one, one_to_many). Defaults to None.
        Use one_to_one mapping for datasets with one noisy file for each clean file;
        use one_to_many mapping for multiple noisy files for each clean file.
    """

    def __init__(
        self,
        name: str,
        root_dir: str,
        files: Files,
        min_valid_minutes=5.0,
        duration=1.0,
        stride=None,
        sampling_rate=48000,
        matching_function=None,
        batch_size=32,
        num_workers: Optional[int] = None,
        augmentations: Optional[Compose] = None,
    ):

        super().__init__(
            name=name,
            root_dir=root_dir,
            files=files,
            min_valid_minutes=min_valid_minutes,
            sampling_rate=sampling_rate,
            duration=duration,
            matching_function=matching_function,
            batch_size=batch_size,
            num_workers=num_workers,
            augmentations=augmentations,
        )

        self.sampling_rate = sampling_rate
        self.files = files
        self.duration = max(1.0, duration)
        self.audio = Audio(self.sampling_rate, mono=True, return_tensor=True)
        self.stride = stride or duration

    def setup(self, stage: Optional[str] = None):

        super().setup(stage=stage)

    def train__getitem__(self, idx):

        for filedict, num_samples in self.train_data:
            if idx >= num_samples:
                idx -= num_samples
                continue
            else:
                start = 0
                if self.duration is not None:
                    start = idx * self.stride
                return self.prepare_segment(filedict, start)

    def val__getitem__(self, idx):
        return self.prepare_segment(*self._validation[idx])

    def test__getitem__(self, idx):
        return self.prepare_segment(*self._test[idx])

    def prepare_segment(self, file_dict: dict, start_time: float):
        clean_segment = self.audio(
            file_dict["clean"], offset=start_time, duration=self.duration
        )
        noisy_segment = self.audio(
            file_dict["noisy"], offset=start_time, duration=self.duration
        )
        clean_segment = F.pad(
            clean_segment,
            (
                0,
                int(
                    self.duration * self.sampling_rate - clean_segment.shape[-1]
                ),
            ),
        )
        noisy_segment = F.pad(
            noisy_segment,
            (
                0,
                int(
                    self.duration * self.sampling_rate - noisy_segment.shape[-1]
                ),
            ),
        )
        return {
            "clean": clean_segment,
            "noisy": noisy_segment,
        }

    def train__len__(self):
        _, num_examples = list(zip(*self.train_data))
        return sum(num_examples)

    def val__len__(self):
        return len(self._validation)

    def test__len__(self):
        return len(self._test)
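A quick worked check of the segment-count formula in `get_num_segments` above: a 4.5 s file, 2.0 s windows, and a 1.0 s stride give `ceil((4.5 - 2.0) / 1.0) + 1 = 4` segments starting at 0, 1, 2 and 3 s; the last segment runs past the file end and is zero-padded by `prepare_segment`.

```python
import math


def num_segments(file_duration, duration, stride):
    # mirrors TaskDataset.get_num_segments
    if file_duration < duration:
        return 1
    return math.ceil((file_duration - duration) / stride) + 1


assert num_segments(4.5, 2.0, 1.0) == 4
assert num_segments(0.8, 2.0, 1.0) == 1  # short files still yield one padded segment
```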
@@ -0,0 +1,121 @@
import glob
import os

import numpy as np
from scipy.io import wavfile

MATCHING_FNS = ("one_to_one", "one_to_many")


class ProcessorFunctions:
    """
    Preprocessing methods for different types of speech enhancement datasets.
    """

    @staticmethod
    def one_to_one(clean_path, noisy_path):
        """
        One clean audio file can have only one noisy audio file
        """

        matching_wavfiles = list()
        clean_filenames = [
            file.split("/")[-1]
            for file in glob.glob(os.path.join(clean_path, "*.wav"))
        ]
        noisy_filenames = [
            file.split("/")[-1]
            for file in glob.glob(os.path.join(noisy_path, "*.wav"))
        ]
        common_filenames = np.intersect1d(noisy_filenames, clean_filenames)

        for file_name in common_filenames:

            sr_clean, clean_file = wavfile.read(
                os.path.join(clean_path, file_name)
            )
            sr_noisy, noisy_file = wavfile.read(
                os.path.join(noisy_path, file_name)
            )
            if (clean_file.shape[-1] == noisy_file.shape[-1]) and (
                sr_clean == sr_noisy
            ):
                matching_wavfiles.append(
                    {
                        "clean": os.path.join(clean_path, file_name),
                        "noisy": os.path.join(noisy_path, file_name),
                        "duration": clean_file.shape[-1] / sr_clean,
                    }
                )
        return matching_wavfiles

    @staticmethod
    def one_to_many(clean_path, noisy_path):
        """
        One clean audio file can have multiple noisy audio files
        """

        matching_wavfiles = list()
        clean_filenames = [
            file.split("/")[-1]
            for file in glob.glob(os.path.join(clean_path, "*.wav"))
        ]
        for clean_file in clean_filenames:
            noisy_filenames = glob.glob(
                os.path.join(noisy_path, f"*_{clean_file}")
            )
            for noisy_file in noisy_filenames:

                sr_clean, clean_wav = wavfile.read(
                    os.path.join(clean_path, clean_file)
                )
                sr_noisy, noisy_wav = wavfile.read(noisy_file)
                if (clean_wav.shape[-1] == noisy_wav.shape[-1]) and (
                    sr_clean == sr_noisy
                ):
                    matching_wavfiles.append(
                        {
                            "clean": os.path.join(clean_path, clean_file),
                            "noisy": noisy_file,
                            "duration": clean_wav.shape[-1] / sr_clean,
                        }
                    )
        return matching_wavfiles


class Fileprocessor:
    def __init__(self, clean_dir, noisy_dir, matching_function=None):
        self.clean_dir = clean_dir
        self.noisy_dir = noisy_dir
        self.matching_function = matching_function

    @classmethod
    def from_name(cls, name: str, clean_dir, noisy_dir, matching_function=None):

        if matching_function is None:
            if name.lower() == "vctk":
                return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_one)
            elif name.lower() == "dns-2020":
                return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_many)
            else:
                raise ValueError(
                    f"Invalid matching function. Please use a valid matching function from {MATCHING_FNS}"
                )
        else:
            if matching_function not in MATCHING_FNS:
                raise ValueError(
                    f"Invalid matching function! Available options are {MATCHING_FNS}"
                )
            else:
                return cls(
                    clean_dir,
                    noisy_dir,
                    getattr(ProcessorFunctions, matching_function),
                )

    def prepare_matching_dict(self):

        if self.matching_function is None:
            raise ValueError("Not a valid matching function")

        return self.matching_function(self.clean_dir, self.noisy_dir)
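Usage sketch for the matching flow above; the directory paths and the returned values are hypothetical:

```python
from enhancer.data.fileprocessor import Fileprocessor

fp = Fileprocessor.from_name(
    "vctk",
    "datasets/vctk/clean_testset_wav",  # hypothetical paths
    "datasets/vctk/noisy_testset_wav",
)
matches = fp.prepare_matching_dict()
# -> [{"clean": ".../p232_001.wav", "noisy": ".../p232_001.wav", "duration": 2.95}, ...]
```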
@@ -0,0 +1,170 @@
from pathlib import Path
from typing import Optional, Union

import numpy as np
import torch
import torch.nn.functional as F
from librosa import load as load_audio
from scipy.io import wavfile
from scipy.signal import get_window

from enhancer.utils import Audio


class Inference:
    """
    contains methods used for inference.
    """

    @staticmethod
    def read_input(audio, sr, model_sr):
        """
        read and verify audio input regardless of the input format.
        arguments:
            audio : audio input
            sr : sampling rate of input audio
            model_sr : sampling rate used for model training.
        """

        if isinstance(audio, (np.ndarray, torch.Tensor)):
            assert sr is not None, "Invalid sampling rate!"
            if len(audio.shape) == 1:
                audio = audio.reshape(1, -1)

        if isinstance(audio, str):
            audio = Path(audio)
            if not audio.is_file():
                raise ValueError(f"Input file {audio} does not exist")
            else:
                audio, sr = load_audio(
                    audio,
                    sr=sr,
                )
                if len(audio.shape) == 1:
                    audio = audio.reshape(1, -1)
        else:
            assert (
                audio.shape[0] == 1
            ), "Enhance inference only supports single waveform"

        waveform = Audio.resample_audio(audio, sr=sr, target_sr=model_sr)
        waveform = Audio.convert_mono(waveform)
        if isinstance(waveform, np.ndarray):
            waveform = torch.from_numpy(waveform)

        return waveform

    @staticmethod
    def batchify(
        waveform: torch.Tensor,
        window_size: int,
        step_size: Optional[int] = None,
    ):
        """
        break input waveform into overlapping chunks of the specified duration (overlap-add).
        arguments:
            waveform : audio waveform
            window_size : window size used for splitting waveform into batches
            step_size : step size used for splitting waveform into batches
        """
        assert (
            waveform.ndim == 2
        ), f"Expected input waveform with 2 dimensions (channels, samples), got {waveform.ndim}"
        _, num_samples = waveform.shape
        waveform = waveform.unsqueeze(-1)
        step_size = window_size // 2 if step_size is None else step_size

        if num_samples >= window_size:
            waveform_batch = F.unfold(
                waveform[None, ...],
                kernel_size=(window_size, 1),
                stride=(step_size, 1),
                padding=(window_size, 0),
            )
            waveform_batch = waveform_batch.permute(2, 0, 1)
        else:
            # pad short inputs up to one full window so a single chunk is returned
            waveform_batch = F.pad(
                waveform.squeeze(-1), (0, window_size - num_samples)
            ).unsqueeze(0)

        return waveform_batch

    @staticmethod
    def aggreagate(
        data: torch.Tensor,
        window_size: int,
        total_frames: int,
        step_size: Optional[int] = None,
        window="hamming",
    ):
        """
        stitch batched waveform into a single waveform (overlap-add).
        arguments:
            data : batched waveform
            window_size : window_size used to batch waveform
            step_size : step_size used to batch waveform
            total_frames : total number of frames present in original waveform
            window : type of window used for the overlap-add mechanism.
        """
        num_chunks, n_channels, num_frames = data.shape
        window = get_window(window=window, Nx=data.shape[-1])
        # match dtype and device so the in-place multiply below is valid
        window = torch.from_numpy(window).to(device=data.device, dtype=data.dtype)
        data *= window
        step_size = window_size // 2 if step_size is None else step_size

        data = data.permute(1, 2, 0)
        data = F.fold(
            data,
            (total_frames, 1),
            kernel_size=(window_size, 1),
            stride=(step_size, 1),
            padding=(window_size, 0),
        ).squeeze(-1)

        return data.reshape(1, n_channels, -1)

    @staticmethod
    def write_output(
        waveform: torch.Tensor, filename: Union[str, Path], sr: int
    ):
        """
        write audio output as a wav file
        arguments:
            waveform : audio waveform
            filename : name of the wav file. Output will be written as cleaned_filename.wav
            sr : sampling rate
        """

        if isinstance(filename, str):
            filename = Path(filename)

        parent, name = filename.parent, "cleaned_" + filename.name
        filename = parent / Path(name)
        if filename.is_file():
            raise FileExistsError(f"file {filename} already exists")
        else:
            wavfile.write(
                filename, rate=sr, data=waveform.detach().cpu().numpy()
            )

    @staticmethod
    def prepare_output(
        waveform: torch.Tensor,
        model_sampling_rate: int,
        audio: Union[str, np.ndarray, torch.Tensor],
        sampling_rate: Optional[int],
    ):
        """
        prepare output audio based on input format
        arguments:
            waveform : predicted audio waveform
            model_sampling_rate : sampling rate used to train the model
            audio : input audio
            sampling_rate : input audio sampling rate
        """
        if isinstance(audio, np.ndarray):
            waveform = waveform.detach().cpu().numpy()

        if sampling_rate is not None:
            waveform = Audio.resample_audio(
                waveform, sr=model_sampling_rate, target_sr=sampling_rate
            )

        return waveform
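A shape sketch of the overlap-add round trip above. The window and input sizes are arbitrary choices, and the module path `enhancer.inference` is an assumption about where this file lands:

```python
import torch

from enhancer.inference import Inference  # assumed module path

wave = torch.randn(1, 16000)                        # (channels, samples)
chunks = Inference.batchify(wave, window_size=1600)  # (num_chunks, 1, 1600)
merged = Inference.aggreagate(chunks, window_size=1600, total_frames=16000)
print(chunks.shape, merged.shape)                    # merged is (1, 1, 16000)
```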
@@ -0,0 +1,216 @@
import logging

import numpy as np
import torch
import torch.nn as nn
from torchmetrics import ScaleInvariantSignalNoiseRatio
from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility


class mean_squared_error(nn.Module):
    """
    Mean squared error / L2 loss
    """

    def __init__(self, reduction="mean"):
        super().__init__()

        self.loss_fun = nn.MSELoss(reduction=reduction)
        self.higher_better = False
        self.name = "mse"

    def forward(self, prediction: torch.Tensor, target: torch.Tensor):

        if prediction.size() != target.size() or target.ndim < 3:
            raise TypeError(
                f"""Inputs must be of the same shape (batch_size, channels, samples),
                got {prediction.size()} and {target.size()} instead"""
            )

        return self.loss_fun(prediction, target)


class mean_absolute_error(nn.Module):
    """
    Mean absolute error / L1 loss
    """

    def __init__(self, reduction="mean"):
        super().__init__()

        self.loss_fun = nn.L1Loss(reduction=reduction)
        self.higher_better = False
        self.name = "mae"

    def forward(self, prediction: torch.Tensor, target: torch.Tensor):

        if prediction.size() != target.size() or target.ndim < 3:
            raise TypeError(
                f"""Inputs must be of the same shape (batch_size, channels, samples),
                got {prediction.size()} and {target.size()} instead"""
            )

        return self.loss_fun(prediction, target)


class Si_SDR:
    """
    SI-SDR metric based on SDR – HALF-BAKED OR WELL DONE? (https://arxiv.org/pdf/1811.02508.pdf)
    """

    def __init__(self, reduction: str = "mean"):
        if reduction in ["sum", "mean", None]:
            self.reduction = reduction
        else:
            raise TypeError(
                "Invalid reduction, valid options are sum, mean, None"
            )
        self.higher_better = True
        self.name = "si-sdr"

    def __call__(self, prediction: torch.Tensor, target: torch.Tensor):

        if prediction.size() != target.size() or target.ndim < 3:
            raise TypeError(
                f"""Inputs must be of the same shape (batch_size, channels, samples),
                got {prediction.size()} and {target.size()} instead"""
            )

        target_energy = torch.sum(target**2, keepdim=True, dim=-1)
        scaling_factor = (
            torch.sum(prediction * target, keepdim=True, dim=-1) / target_energy
        )
        target_projection = target * scaling_factor
        noise = prediction - target_projection
        ratio = torch.sum(target_projection**2, dim=-1) / torch.sum(
            noise**2, dim=-1
        )
        si_sdr = 10 * torch.log10(ratio).mean(dim=-1)

        if self.reduction == "sum":
            si_sdr = si_sdr.sum()
        elif self.reduction == "mean":
            si_sdr = si_sdr.mean()
        else:
            pass

        return si_sdr


class Stoi:
    """
    STOI (Short-Time Objective Intelligibility, see [2,3]), a wrapper for the pystoi package [1].
    Note that input will be moved to cpu to perform the metric calculation.
    parameters:
        sr: int
            sampling rate
    """

    def __init__(self, sr: int):
        self.sr = sr
        self.stoi = ShortTimeObjectiveIntelligibility(fs=sr)
        self.name = "stoi"

    def __call__(self, prediction: torch.Tensor, target: torch.Tensor):

        return self.stoi(prediction, target)


class Pesq:
    def __init__(self, sr: int, mode="wb"):

        self.sr = sr
        self.name = "pesq"
        self.mode = mode
        self.pesq = PerceptualEvaluationSpeechQuality(
            fs=self.sr, mode=self.mode
        )

    def __call__(self, prediction: torch.Tensor, target: torch.Tensor):

        pesq_values = []
        for pred, target_ in zip(prediction, target):
            try:
                pesq_values.append(self.pesq(pred.squeeze(), target_.squeeze()))
            except Exception as e:
                logging.warning(f"{e} error occurred while calculating PESQ")
        return torch.tensor(np.mean(pesq_values))


class LossWrapper(nn.Module):
    """
    Combine multiple metrics of the same nature,
    for example ["mse", "mae"].
    parameters:
        losses : loss function names to be combined
    """

    def __init__(self, losses):
        super().__init__()

        self.valid_losses = nn.ModuleList()

        direction = [
            getattr(LOSS_MAP[loss](), "higher_better") for loss in losses
        ]
        if len(set(direction)) > 1:
            raise ValueError(
                "all cost functions should be of the same nature, maximize or minimize!"
            )

        self.higher_better = direction[0]
        self.name = ""
        for loss in losses:
            loss = self.validate_loss(loss)
            self.valid_losses.append(loss())
            self.name += f"{loss().name}_"

    def validate_loss(self, loss: str):
        if loss not in LOSS_MAP.keys():
            raise ValueError(
                f"""Invalid loss function {loss}, available loss functions are
                {tuple([loss for loss in LOSS_MAP.keys()])}"""
            )
        else:
            return LOSS_MAP[loss]

    def forward(self, prediction: torch.Tensor, target: torch.Tensor):
        loss = 0.0
        for loss_fun in self.valid_losses:
            loss += loss_fun(prediction, target)

        return loss


class Si_snr(nn.Module):
    """
    SI-SNR
    """

    def __init__(self, **kwargs):
        super().__init__()

        self.loss_fun = ScaleInvariantSignalNoiseRatio(**kwargs)
        self.higher_better = True
        self.name = "si_snr"

    def forward(self, prediction: torch.Tensor, target: torch.Tensor):

        if prediction.size() != target.size() or target.ndim < 3:
            raise TypeError(
                f"""Inputs must be of the same shape (batch_size, channels, samples),
                got {prediction.size()} and {target.size()} instead"""
            )

        return self.loss_fun(prediction, target)


LOSS_MAP = {
    "mae": mean_absolute_error,
    "mse": mean_squared_error,
    "si-sdr": Si_SDR,
    "pesq": Pesq,
    "stoi": Stoi,
    "si-snr": Si_snr,
}
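Usage sketch for `LossWrapper`: both names must point the same way (here both are minimized); mixing "mae" with "si-sdr" would raise the `ValueError` above. The module path `enhancer.loss` is an assumption about where this file lands.

```python
import torch

from enhancer.loss import LossWrapper  # assumed module path

loss_fn = LossWrapper(["mae", "mse"])  # higher_better is False for both
prediction = torch.randn(4, 1, 16000)
target = torch.randn(4, 1, 16000)
print(loss_fn.name)                    # "mae_mse_"
print(loss_fn(prediction, target))     # scalar: mae + mse
```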
@@ -0,0 +1,3 @@
from enhancer.models.demucs import Demucs
from enhancer.models.model import Model
from enhancer.models.waveunet import WaveUnet
@@ -0,0 +1,5 @@
from enhancer.models.complexnn.conv import ComplexConv2d  # noqa
from enhancer.models.complexnn.conv import ComplexConvTranspose2d  # noqa
from enhancer.models.complexnn.rnn import ComplexLSTM  # noqa
from enhancer.models.complexnn.utils import ComplexBatchNorm2D  # noqa
from enhancer.models.complexnn.utils import ComplexRelu  # noqa
@@ -0,0 +1,136 @@
from typing import Tuple

import torch
import torch.nn.functional as F
from torch import nn


def init_weights(nnet):
    nn.init.xavier_normal_(nnet.weight.data)
    nn.init.constant_(nnet.bias, 0.0)
    return nnet


class ComplexConv2d(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Tuple[int, int] = (1, 1),
        stride: Tuple[int, int] = (1, 1),
        padding: Tuple[int, int] = (0, 0),
        groups: int = 1,
        dilation: int = 1,
    ):
        """
        Complex Conv2d (non-causal)
        """
        super().__init__()
        self.in_channels = in_channels // 2
        self.out_channels = out_channels // 2
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.groups = groups
        self.dilation = dilation

        self.real_conv = nn.Conv2d(
            self.in_channels,
            self.out_channels,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=(self.padding[0], 0),
            groups=self.groups,
            dilation=self.dilation,
        )
        self.imag_conv = nn.Conv2d(
            self.in_channels,
            self.out_channels,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=(self.padding[0], 0),
            groups=self.groups,
            dilation=self.dilation,
        )
        self.imag_conv = init_weights(self.imag_conv)
        self.real_conv = init_weights(self.real_conv)

    def forward(self, input):
        """
        complex axis should always be dim 1
        """
        input = F.pad(input, [self.padding[1], 0, 0, 0])

        real, imag = torch.chunk(input, 2, 1)

        real_real = self.real_conv(real)
        real_imag = self.imag_conv(real)

        imag_imag = self.imag_conv(imag)
        imag_real = self.real_conv(imag)

        # (a + ib)(c + id) = (ac - bd) + i(ad + bc)
        real = real_real - imag_imag
        imag = real_imag + imag_real

        out = torch.cat([real, imag], 1)
        return out


class ComplexConvTranspose2d(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Tuple[int, int] = (1, 1),
        stride: Tuple[int, int] = (1, 1),
        padding: Tuple[int, int] = (0, 0),
        output_padding: Tuple[int, int] = (0, 0),
        groups: int = 1,
    ):
        super().__init__()
        self.in_channels = in_channels // 2
        self.out_channels = out_channels // 2
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.groups = groups
        self.output_padding = output_padding

        self.real_conv = nn.ConvTranspose2d(
            self.in_channels,
            self.out_channels,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            output_padding=self.output_padding,
            groups=self.groups,
        )

        self.imag_conv = nn.ConvTranspose2d(
            self.in_channels,
            self.out_channels,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            output_padding=self.output_padding,
            groups=self.groups,
        )

        self.real_conv = init_weights(self.real_conv)
        self.imag_conv = init_weights(self.imag_conv)

    def forward(self, input):

        real, imag = torch.chunk(input, 2, 1)
        real_real = self.real_conv(real)
        real_imag = self.imag_conv(real)

        imag_imag = self.imag_conv(imag)
        imag_real = self.real_conv(imag)

        # (a + ib)(c + id) = (ac - bd) + i(ad + bc)
        real = real_real - imag_imag
        imag = real_imag + imag_real

        out = torch.cat([real, imag], 1)

        return out
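The four real convolutions above implement the complex product `(a + ib)(c + id) = (ac - bd) + i(ad + bc)`; a quick numeric check of that identity:

```python
import torch

a, b, c, d = torch.randn(4)
z = torch.complex(a, b) * torch.complex(c, d)
assert torch.allclose(z.real, a * c - b * d)
assert torch.allclose(z.imag, a * d + b * c)
```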
@@ -0,0 +1,68 @@
from typing import List, Optional

import torch
from torch import nn


class ComplexLSTM(nn.Module):
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int = 1,
        projection_size: Optional[int] = None,
        bidirectional: bool = False,
    ):
        super().__init__()
        self.input_size = input_size // 2
        self.hidden_size = hidden_size // 2
        self.num_layers = num_layers

        self.real_lstm = nn.LSTM(
            self.input_size,
            self.hidden_size,
            self.num_layers,
            bidirectional=bidirectional,
            batch_first=False,
        )
        self.imag_lstm = nn.LSTM(
            self.input_size,
            self.hidden_size,
            self.num_layers,
            bidirectional=bidirectional,
            batch_first=False,
        )

        bidirectional = 2 if bidirectional else 1
        if projection_size is not None:
            self.projection_size = projection_size // 2
            self.real_linear = nn.Linear(
                self.hidden_size * bidirectional, self.projection_size
            )
            self.imag_linear = nn.Linear(
                self.hidden_size * bidirectional, self.projection_size
            )
        else:
            self.projection_size = None

    def forward(self, input):

        if isinstance(input, List):
            real, imag = input
        else:
            real, imag = torch.chunk(input, 2, 1)

        real_real = self.real_lstm(real)[0]
        real_imag = self.imag_lstm(real)[0]

        imag_imag = self.imag_lstm(imag)[0]
        imag_real = self.real_lstm(imag)[0]

        real = real_real - imag_imag
        imag = imag_real + real_imag

        if self.projection_size is not None:
            real = self.real_linear(real)
            imag = self.imag_linear(imag)

        return [real, imag]
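Shape sketch for `ComplexLSTM` when called with an explicit `[real, imag]` pair. The sizes are arbitrary assumptions; with `batch_first=False` each half is laid out as `(seq, batch, input_size // 2)`:

```python
import torch

from enhancer.models.complexnn import ComplexLSTM

lstm = ComplexLSTM(input_size=128, hidden_size=256)
real = torch.randn(10, 4, 64)            # (seq, batch, input_size // 2)
imag = torch.randn(10, 4, 64)
out_real, out_imag = lstm([real, imag])  # each (10, 4, 128) = hidden_size // 2
```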
@ -0,0 +1,199 @@
import torch
from torch import nn


class ComplexBatchNorm2D(nn.Module):
    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: float = 0.1,
        affine: bool = True,
        track_running_stats: bool = True,
    ):
        """
        Complex batch normalization 2D
        https://arxiv.org/abs/1705.09792
        """
        super().__init__()
        self.num_features = num_features // 2
        self.affine = affine
        self.momentum = momentum
        self.track_running_stats = track_running_stats
        self.eps = eps

        if self.affine:
            self.Wrr = nn.parameter.Parameter(torch.Tensor(self.num_features))
            self.Wri = nn.parameter.Parameter(torch.Tensor(self.num_features))
            self.Wii = nn.parameter.Parameter(torch.Tensor(self.num_features))
            self.Br = nn.parameter.Parameter(torch.Tensor(self.num_features))
            self.Bi = nn.parameter.Parameter(torch.Tensor(self.num_features))
        else:
            self.register_parameter("Wrr", None)
            self.register_parameter("Wri", None)
            self.register_parameter("Wii", None)
            self.register_parameter("Br", None)
            self.register_parameter("Bi", None)

        if self.track_running_stats:
            # each buffer needs its own tensor; registering one shared tensor
            # five times would make all running statistics alias the same
            # storage, so fill_/zero_ on one would corrupt the others
            self.register_buffer("Mean_real", torch.zeros(self.num_features))
            self.register_buffer("Mean_imag", torch.zeros(self.num_features))
            self.register_buffer("Var_rr", torch.zeros(self.num_features))
            self.register_buffer("Var_ri", torch.zeros(self.num_features))
            self.register_buffer("Var_ii", torch.zeros(self.num_features))
            self.register_buffer(
                "num_batches_tracked", torch.tensor(0, dtype=torch.long)
            )
        else:
            self.register_parameter("Mean_real", None)
            self.register_parameter("Mean_imag", None)
            self.register_parameter("Var_rr", None)
            self.register_parameter("Var_ri", None)
            self.register_parameter("Var_ii", None)
            self.register_parameter("num_batches_tracked", None)

        self.reset_parameters()

    def reset_parameters(self):
        if self.affine:
            self.Wrr.data.fill_(1)
            self.Wii.data.fill_(1)
            self.Wri.data.uniform_(-0.9, 0.9)
            self.Br.data.fill_(0)
            self.Bi.data.fill_(0)
        self.reset_running_stats()

    def reset_running_stats(self):
        if self.track_running_stats:
            self.Mean_real.zero_()
            self.Mean_imag.zero_()
            self.Var_rr.fill_(1)
            self.Var_ri.zero_()
            self.Var_ii.fill_(1)
            self.num_batches_tracked.zero_()

    def extra_repr(self):
        return "{num_features}, eps={eps}, momentum={momentum}, affine={affine}, track_running_stats={track_running_stats}".format(
            **self.__dict__
        )

    def forward(self, input):

        real, imag = torch.chunk(input, 2, 1)
        exp_avg_factor = 0.0

        training = self.training and self.track_running_stats
        if training:
            self.num_batches_tracked += 1
            if self.momentum is None:
                exp_avg_factor = 1 / self.num_batches_tracked
            else:
                exp_avg_factor = self.momentum

        redux = [i for i in reversed(range(real.dim())) if i != 1]
        vdim = [1] * real.dim()
        vdim[1] = real.size(1)

        if training:
            batch_mean_real, batch_mean_imag = real, imag
            for dim in redux:
                batch_mean_real = batch_mean_real.mean(dim, keepdim=True)
                batch_mean_imag = batch_mean_imag.mean(dim, keepdim=True)
            if self.track_running_stats:
                self.Mean_real.lerp_(batch_mean_real.squeeze(), exp_avg_factor)
                self.Mean_imag.lerp_(batch_mean_imag.squeeze(), exp_avg_factor)
        else:
            batch_mean_real = self.Mean_real.view(vdim)
            batch_mean_imag = self.Mean_imag.view(vdim)

        real = real - batch_mean_real
        imag = imag - batch_mean_imag

        if training:
            batch_var_rr = real * real
            batch_var_ri = real * imag
            batch_var_ii = imag * imag
            for dim in redux:
                batch_var_rr = batch_var_rr.mean(dim, keepdim=True)
                batch_var_ri = batch_var_ri.mean(dim, keepdim=True)
                batch_var_ii = batch_var_ii.mean(dim, keepdim=True)
            if self.track_running_stats:
                self.Var_rr.lerp_(batch_var_rr.squeeze(), exp_avg_factor)
                self.Var_ri.lerp_(batch_var_ri.squeeze(), exp_avg_factor)
                self.Var_ii.lerp_(batch_var_ii.squeeze(), exp_avg_factor)
        else:
            batch_var_rr = self.Var_rr.view(vdim)
            batch_var_ii = self.Var_ii.view(vdim)
            batch_var_ri = self.Var_ri.view(vdim)

        # add eps out of place: in eval mode these are views of the running
        # buffers, and += would corrupt the stored statistics on every call
        batch_var_rr = batch_var_rr + self.eps
        batch_var_ii = batch_var_ii + self.eps

        # Covariance matrix
        # | batch_var_rr batch_var_ri |
        # | batch_var_ir batch_var_ii |  (batch_var_ir == batch_var_ri)
        # Inverse square root of the covariance matrix by combining
        # https://en.wikipedia.org/wiki/Square_root_of_a_2_by_2_matrix
        # https://mathworld.wolfram.com/MatrixInverse.html

        tau = batch_var_rr + batch_var_ii
        det = batch_var_rr * batch_var_ii - batch_var_ri * batch_var_ri
        # the 2x2 square-root formula needs sqrt(det), not det itself
        s = det.sqrt()
        t = (tau + 2 * s).sqrt()

        rst = (s * t).reciprocal()
        Urr = (batch_var_ii + s) * rst
        Uri = -batch_var_ri * rst
        Uii = (batch_var_rr + s) * rst

        if self.affine:
            Wrr, Wri, Wii = (
                self.Wrr.view(vdim),
                self.Wri.view(vdim),
                self.Wii.view(vdim),
            )
            Zrr = (Wrr * Urr) + (Wri * Uri)
            Zri = (Wrr * Uri) + (Wri * Uii)
            Zir = (Wii * Uri) + (Wri * Urr)
            Zii = (Wri * Uri) + (Wii * Uii)
        else:
            Zrr, Zri, Zir, Zii = Urr, Uri, Uri, Uii

        yr = (Zrr * real) + (Zri * imag)
        yi = (Zir * real) + (Zii * imag)

        if self.affine:
            yr = yr + self.Br.view(vdim)
            yi = yi + self.Bi.view(vdim)

        outputs = torch.cat([yr, yi], 1)
        return outputs


class ComplexRelu(nn.Module):
    def __init__(self):
        super().__init__()
        self.real_relu = nn.PReLU()
        self.imag_relu = nn.PReLU()

    def forward(self, input):

        real, imag = torch.chunk(input, 2, 1)
        real = self.real_relu(real)
        imag = self.imag_relu(imag)
        return torch.cat([real, imag], dim=1)


def complex_cat(inputs, axis=1):

    real, imag = [], []
    for data in inputs:
        real_data, imag_data = torch.chunk(data, 2, axis)
        real.append(real_data)
        imag.append(imag_data)
    real = torch.cat(real, axis)
    imag = torch.cat(imag, axis)
    return torch.cat([real, imag], axis)
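# --- Hedged usage sketch (not part of the original diff) ---------------------
# These modules treat "complex" tensors as real tensors whose channel axis
# stacks the real halves over the imaginary halves, so num_features counts the
# full stacked axis. The shapes below are illustrative assumptions only.

x = torch.randn(2, 8, 64, 100)  # (batch, 2 * complex_channels, freq, time)
bn = ComplexBatchNorm2D(num_features=8)
y = bn(x)  # same shape; real/imag statistics are whitened jointly
z = complex_cat([y, y], axis=1)  # (2, 16, 64, 100): all reals first, then imags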
@ -0,0 +1,338 @@
import logging
from typing import Any, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn

from enhancer.data import EnhancerDataset
from enhancer.models import Model
from enhancer.models.complexnn import (
    ComplexBatchNorm2D,
    ComplexConv2d,
    ComplexConvTranspose2d,
    ComplexLSTM,
    ComplexRelu,
)
from enhancer.models.complexnn.utils import complex_cat
from enhancer.utils.transforms import ConviSTFT, ConvSTFT
from enhancer.utils.utils import merge_dict


class DCCRN_ENCODER(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channel: int,
        kernel_size: Tuple[int, int],
        complex_norm: bool = True,
        complex_relu: bool = True,
        stride: Tuple[int, int] = (2, 1),
        padding: Tuple[int, int] = (2, 1),
    ):
        super().__init__()
        batchnorm = ComplexBatchNorm2D if complex_norm else nn.BatchNorm2d
        activation = ComplexRelu() if complex_relu else nn.PReLU()

        self.encoder = nn.Sequential(
            ComplexConv2d(
                in_channels,
                out_channel,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
            ),
            batchnorm(out_channel),
            activation,
        )

    def forward(self, waveform):
        return self.encoder(waveform)


class DCCRN_DECODER(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Tuple[int, int],
        layer: int = 0,
        complex_norm: bool = True,
        complex_relu: bool = True,
        stride: Tuple[int, int] = (2, 1),
        padding: Tuple[int, int] = (2, 0),
        output_padding: Tuple[int, int] = (1, 0),
    ):
        super().__init__()
        batchnorm = ComplexBatchNorm2D if complex_norm else nn.BatchNorm2d
        activation = ComplexRelu() if complex_relu else nn.PReLU()

        if layer != 0:
            self.decoder = nn.Sequential(
                ComplexConvTranspose2d(
                    in_channels,
                    out_channels,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                    output_padding=output_padding,
                ),
                batchnorm(out_channels),
                activation,
            )
        else:
            self.decoder = nn.Sequential(
                ComplexConvTranspose2d(
                    in_channels,
                    out_channels,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                    output_padding=output_padding,
                )
            )

    def forward(self, waveform):
        return self.decoder(waveform)


class DCCRN(Model):

    STFT_DEFAULTS = {
        "window_len": 400,
        "hop_size": 100,
        "nfft": 512,
        "window": "hamming",
    }

    ED_DEFAULTS = {
        "initial_output_channels": 32,
        "depth": 6,
        "kernel_size": 5,
        "growth_factor": 2,
        "stride": 2,
        "padding": 2,
        "output_padding": 1,
    }

    LSTM_DEFAULTS = {
        "num_layers": 2,
        "hidden_size": 256,
    }

    def __init__(
        self,
        stft: Optional[dict] = None,
        encoder_decoder: Optional[dict] = None,
        lstm: Optional[dict] = None,
        complex_lstm: bool = True,
        complex_norm: bool = True,
        complex_relu: bool = True,
        masking_mode: str = "E",
        num_channels: int = 1,
        sampling_rate=16000,
        lr: float = 1e-3,
        dataset: Optional[EnhancerDataset] = None,
        duration: Optional[float] = None,
        loss: Union[str, List, Any] = "mse",
        metric: Union[str, List] = "mse",
    ):
        duration = (
            dataset.duration if isinstance(dataset, EnhancerDataset) else None
        )
        if dataset is not None:
            if sampling_rate != dataset.sampling_rate:
                logging.warning(
                    f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
                )
                sampling_rate = dataset.sampling_rate
        super().__init__(
            num_channels=num_channels,
            sampling_rate=sampling_rate,
            lr=lr,
            dataset=dataset,
            duration=duration,
            loss=loss,
            metric=metric,
        )

        encoder_decoder = merge_dict(self.ED_DEFAULTS, encoder_decoder)
        lstm = merge_dict(self.LSTM_DEFAULTS, lstm)
        stft = merge_dict(self.STFT_DEFAULTS, stft)
        self.save_hyperparameters(
            "encoder_decoder",
            "lstm",
            "stft",
            "complex_lstm",
            "complex_norm",
            "masking_mode",
        )
        self.complex_lstm = complex_lstm
        self.complex_norm = complex_norm
        self.masking_mode = masking_mode

        self.stft = ConvSTFT(
            stft["window_len"], stft["hop_size"], stft["nfft"], stft["window"]
        )
        self.istft = ConviSTFT(
            stft["window_len"], stft["hop_size"], stft["nfft"], stft["window"]
        )

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        num_channels *= 2
        hidden_size = encoder_decoder["initial_output_channels"]
        growth_factor = 2

        for layer in range(encoder_decoder["depth"]):

            encoder_ = DCCRN_ENCODER(
                num_channels,
                hidden_size,
                kernel_size=(encoder_decoder["kernel_size"], 2),
                stride=(encoder_decoder["stride"], 1),
                padding=(encoder_decoder["padding"], 1),
                complex_norm=complex_norm,
                complex_relu=complex_relu,
            )
            self.encoder.append(encoder_)

            decoder_ = DCCRN_DECODER(
                hidden_size + hidden_size,
                num_channels,
                layer=layer,
                kernel_size=(encoder_decoder["kernel_size"], 2),
                stride=(encoder_decoder["stride"], 1),
                padding=(encoder_decoder["padding"], 0),
                output_padding=(encoder_decoder["output_padding"], 0),
                complex_norm=complex_norm,
                complex_relu=complex_relu,
            )

            self.decoder.insert(0, decoder_)

            if layer < encoder_decoder["depth"] - 3:
                num_channels = hidden_size
                hidden_size *= growth_factor
            else:
                num_channels = hidden_size

        kernel_size = hidden_size / 2
        hidden_size = stft["nfft"] / 2 ** (encoder_decoder["depth"])

        if self.complex_lstm:
            lstms = []
            for layer in range(lstm["num_layers"]):

                if layer == 0:
                    input_size = int(hidden_size * kernel_size)
                else:
                    input_size = lstm["hidden_size"]

                if layer == lstm["num_layers"] - 1:
                    projection_size = int(hidden_size * kernel_size)
                else:
                    projection_size = None

                kwargs = {
                    "input_size": input_size,
                    "hidden_size": lstm["hidden_size"],
                    "num_layers": 1,
                }

                lstms.append(
                    ComplexLSTM(projection_size=projection_size, **kwargs)
                )
            self.lstm = nn.Sequential(*lstms)
        else:
            # nn.LSTM returns an (output, (h, c)) tuple, so it cannot sit
            # inside nn.Sequential; keep the LSTM and its projection layer as
            # separate modules and unpack the tuple in forward()
            self.lstm = nn.LSTM(
                input_size=int(hidden_size * kernel_size),
                hidden_size=lstm["hidden_size"],
                num_layers=lstm["num_layers"],
                dropout=0.0,
                batch_first=False,
            )
            self.lstm_projection = nn.Linear(
                lstm["hidden_size"], int(hidden_size * kernel_size)
            )

    def forward(self, waveform):

        if waveform.dim() == 2:
            waveform = waveform.unsqueeze(1)

        if waveform.size(1) != self.hparams.num_channels:
            raise ValueError(
                f"Number of input channels initialized is {self.hparams.num_channels} but got {waveform.size(1)} channels"
            )

        waveform_stft = self.stft(waveform)
        real = waveform_stft[:, : self.stft.nfft // 2 + 1]
        imag = waveform_stft[:, self.stft.nfft // 2 + 1 :]

        mag_spec = torch.sqrt(real**2 + imag**2 + 1e-9)
        phase_spec = torch.atan2(imag, real)
        complex_spec = torch.stack([mag_spec, phase_spec], 1)[:, :, 1:]

        encoder_outputs = []
        out = complex_spec
        for _, encoder in enumerate(self.encoder):
            out = encoder(out)
            encoder_outputs.append(out)

        B, C, D, T = out.size()
        out = out.permute(3, 0, 1, 2)
        if self.complex_lstm:

            lstm_real = out[:, :, : C // 2]
            lstm_imag = out[:, :, C // 2 :]
            lstm_real = lstm_real.reshape(T, B, C // 2 * D)
            lstm_imag = lstm_imag.reshape(T, B, C // 2 * D)
            lstm_real, lstm_imag = self.lstm([lstm_real, lstm_imag])
            lstm_real = lstm_real.reshape(T, B, C // 2, D)
            lstm_imag = lstm_imag.reshape(T, B, C // 2, D)
            out = torch.cat([lstm_real, lstm_imag], 2)
        else:
            out = out.reshape(T, B, C * D)
            out, _ = self.lstm(out)
            out = self.lstm_projection(out)
            # mirror the (T, B, C * D) flatten above
            out = out.reshape(T, B, C, D)

        out = out.permute(1, 2, 3, 0)
        for layer, decoder in enumerate(self.decoder):
            skip_connection = encoder_outputs.pop(-1)
            out = complex_cat([skip_connection, out])
            out = decoder(out)
            out = out[..., 1:]
        mask_real, mask_imag = out[:, 0], out[:, 1]
        mask_real = F.pad(mask_real, [0, 0, 1, 0])
        mask_imag = F.pad(mask_imag, [0, 0, 1, 0])
        if self.masking_mode == "E":

            mask_mag = torch.sqrt(mask_real**2 + mask_imag**2)
            real_phase = mask_real / (mask_mag + 1e-8)
            imag_phase = mask_imag / (mask_mag + 1e-8)
            mask_phase = torch.atan2(imag_phase, real_phase)
            mask_mag = torch.tanh(mask_mag)
            est_mag = mask_mag * mag_spec
            est_phase = phase_spec + mask_phase
            # back to rectangular form: mag * (cos(theta) + i sin(theta))
            real = est_mag * torch.cos(est_phase)
            imag = est_mag * torch.sin(est_phase)

        elif self.masking_mode == "C":
            # complex multiplication; assign both at once so `real` is not
            # overwritten before `imag` is computed from it
            real, imag = (
                real * mask_real - imag * mask_imag,
                real * mask_imag + imag * mask_real,
            )

        else:

            real = real * mask_real
            imag = imag * mask_imag

        spec = torch.cat([real, imag], 1)
        wav = self.istft(spec)
        wav = wav.clamp_(-1, 1)
        return wav
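# --- Hedged usage sketch (not part of the original diff) ---------------------
# Assumes the complexnn building blocks line up with the default
# hyper-parameters; shapes are illustrative only.
#
#     model = DCCRN()                   # 16 kHz, mono, masking_mode="E"
#     noisy = torch.randn(4, 1, 16000)  # (batch, channels, samples)
#     enhanced = model(noisy)           # denoised waveform, clamped to [-1, 1]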
@ -0,0 +1,274 @@
import logging
import math
from typing import List, Optional, Union

import torch.nn.functional as F
from torch import nn

from enhancer.data.dataset import EnhancerDataset
from enhancer.models.model import Model
from enhancer.utils.io import Audio as audio
from enhancer.utils.utils import merge_dict


class DemucsLSTM(nn.Module):
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int,
        bidirectional: bool = True,
    ):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size,
            hidden_size,
            num_layers,
            bidirectional=bidirectional,
            batch_first=True,  # Demucs.forward feeds (batch, time, channels)
        )
        dim = 2 if bidirectional else 1
        self.linear = nn.Linear(dim * hidden_size, hidden_size)

    def forward(self, x):

        output, (h, c) = self.lstm(x)
        output = self.linear(output)

        return output, (h, c)


class DemucsEncoder(nn.Module):
    def __init__(
        self,
        num_channels: int,
        hidden_size: int,
        kernel_size: int,
        stride: int = 1,
        glu: bool = False,
    ):
        super().__init__()
        activation = nn.GLU(1) if glu else nn.ReLU()
        multi_factor = 2 if glu else 1
        self.encoder = nn.Sequential(
            nn.Conv1d(num_channels, hidden_size, kernel_size, stride),
            nn.ReLU(),
            nn.Conv1d(hidden_size, hidden_size * multi_factor, 1, 1),
            activation,
        )

    def forward(self, waveform):
        return self.encoder(waveform)


class DemucsDecoder(nn.Module):
    def __init__(
        self,
        num_channels: int,
        hidden_size: int,
        kernel_size: int,
        stride: int = 1,
        glu: bool = False,
        layer: int = 0,
    ):
        super().__init__()
        activation = nn.GLU(1) if glu else nn.ReLU()
        multi_factor = 2 if glu else 1
        self.decoder = nn.Sequential(
            nn.Conv1d(hidden_size, hidden_size * multi_factor, 1, 1),
            activation,
            nn.ConvTranspose1d(hidden_size, num_channels, kernel_size, stride),
        )
        if layer > 0:
            self.decoder.add_module("4", nn.ReLU())

    def forward(
        self,
        waveform,
    ):
        out = self.decoder(waveform)
        return out


class Demucs(Model):
    """
    Demucs model from https://arxiv.org/pdf/1911.13254.pdf
    parameters:
        encoder_decoder: dict, optional
            keyword arguments passed to the encoder/decoder blocks
        lstm : dict, optional
            keyword arguments passed to the LSTM block
        num_channels: int, defaults to 1
            number of channels in input audio
        resample: int, defaults to 4
            upsampling factor applied before the encoder and undone at the output
        sampling_rate: int, defaults to 16KHz
            sampling rate of input audio
        normalize: bool, defaults to True
            normalize the input by its standard deviation before enhancement
        lr : float, defaults to 1e-3
            learning rate used for training
        dataset: EnhancerDataset, optional
            EnhancerDataset object containing train/validation data for training
        loss : string or List of strings
            loss function to be used, available ("mse","mae","SI-SDR")
        metric : string or List of strings
            metric function to be used, available ("mse","mae","SI-SDR")
        floor: float, defaults to 1e-3
            stability floor added to the standard deviation during normalization

    """

    ED_DEFAULTS = {
        "initial_output_channels": 48,
        "kernel_size": 8,
        "stride": 4,
        "depth": 5,
        "glu": True,
        "growth_factor": 2,
    }
    LSTM_DEFAULTS = {
        "bidirectional": True,
        "num_layers": 2,
    }

    def __init__(
        self,
        encoder_decoder: Optional[dict] = None,
        lstm: Optional[dict] = None,
        num_channels: int = 1,
        resample: int = 4,
        sampling_rate=16000,
        normalize=True,
        lr: float = 1e-3,
        dataset: Optional[EnhancerDataset] = None,
        loss: Union[str, List] = "mse",
        metric: Union[str, List] = "mse",
        floor=1e-3,
    ):
        duration = (
            dataset.duration if isinstance(dataset, EnhancerDataset) else None
        )
        if dataset is not None:
            if sampling_rate != dataset.sampling_rate:
                logging.warning(
                    f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
                )
                sampling_rate = dataset.sampling_rate
        super().__init__(
            num_channels=num_channels,
            sampling_rate=sampling_rate,
            lr=lr,
            dataset=dataset,
            duration=duration,
            loss=loss,
            metric=metric,
        )

        encoder_decoder = merge_dict(self.ED_DEFAULTS, encoder_decoder)
        lstm = merge_dict(self.LSTM_DEFAULTS, lstm)
        self.save_hyperparameters("encoder_decoder", "lstm", "resample")
        hidden = encoder_decoder["initial_output_channels"]
        self.normalize = normalize
        self.floor = floor
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        for layer in range(encoder_decoder["depth"]):

            encoder_layer = DemucsEncoder(
                num_channels=num_channels,
                hidden_size=hidden,
                kernel_size=encoder_decoder["kernel_size"],
                stride=encoder_decoder["stride"],
                glu=encoder_decoder["glu"],
            )
            self.encoder.append(encoder_layer)

            decoder_layer = DemucsDecoder(
                num_channels=num_channels,
                hidden_size=hidden,
                kernel_size=encoder_decoder["kernel_size"],
                stride=encoder_decoder["stride"],
                glu=encoder_decoder["glu"],
                layer=layer,
            )
            self.decoder.insert(0, decoder_layer)

            num_channels = hidden
            # use the merged settings so a custom growth_factor is honoured
            hidden = encoder_decoder["growth_factor"] * hidden

        self.de_lstm = DemucsLSTM(
            input_size=num_channels,
            hidden_size=num_channels,
            num_layers=lstm["num_layers"],
            bidirectional=lstm["bidirectional"],
        )

    def forward(self, waveform):

        if waveform.dim() == 2:
            waveform = waveform.unsqueeze(1)

        if waveform.size(1) != self.hparams.num_channels:
            raise ValueError(
                f"Number of input channels initialized is {self.hparams.num_channels} but got {waveform.size(1)} channels"
            )
        if self.normalize:
            waveform = waveform.mean(dim=1, keepdim=True)
            std = waveform.std(dim=-1, keepdim=True)
            waveform = waveform / (self.floor + std)
        else:
            std = 1
        length = waveform.shape[-1]
        x = F.pad(waveform, (0, self.get_padding_length(length) - length))
        if self.hparams.resample > 1:
            x = audio.resample_audio(
                audio=x,
                sr=self.hparams.sampling_rate,
                target_sr=int(
                    self.hparams.sampling_rate * self.hparams.resample
                ),
            )

        encoder_outputs = []
        for encoder in self.encoder:
            x = encoder(x)
            encoder_outputs.append(x)
        x = x.permute(0, 2, 1)
        x, _ = self.de_lstm(x)

        x = x.permute(0, 2, 1)
        for decoder in self.decoder:
            skip_connection = encoder_outputs.pop(-1)
            x = x + skip_connection[..., : x.shape[-1]]
            x = decoder(x)

        if self.hparams.resample > 1:
            x = audio.resample_audio(
                x,
                int(self.hparams.sampling_rate * self.hparams.resample),
                self.hparams.sampling_rate,
            )

        out = x[..., :length]
        return std * out

    def get_padding_length(self, input_length):

        input_length = math.ceil(input_length * self.hparams.resample)

        for layer in range(
            self.hparams.encoder_decoder["depth"]
        ):  # encoder operation
            input_length = (
                math.ceil(
                    (input_length - self.hparams.encoder_decoder["kernel_size"])
                    / self.hparams.encoder_decoder["stride"]
                )
                + 1
            )
            input_length = max(1, input_length)
        for layer in range(
            self.hparams.encoder_decoder["depth"]
        ):  # decoder operation
            input_length = (input_length - 1) * self.hparams.encoder_decoder[
                "stride"
            ] + self.hparams.encoder_decoder["kernel_size"]
        input_length = math.ceil(input_length / self.hparams.resample)

        return int(input_length)
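# --- Hedged usage sketch (not part of the original diff) ---------------------
# Demucs pads the input to the nearest length that survives all encoder
# strides (get_padding_length) and crops back at the end, so arbitrary-length
# mono audio is accepted. Shapes are illustrative assumptions.
import torch

model = Demucs()
noisy = torch.randn(2, 1, 16000)  # (batch, channels, samples)
enhanced = model(noisy)           # cropped back to 16000 samples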
@ -0,0 +1,431 @@
import os
from collections import defaultdict
from importlib import import_module
from pathlib import Path
from typing import Any, List, Optional, Text, Union
from urllib.parse import urlparse

import numpy as np
import pytorch_lightning as pl
import torch
from huggingface_hub import cached_download, hf_hub_url
from pytorch_lightning.utilities.cloud_io import load as pl_load
from torch import nn
from torch.optim import Adam

from enhancer.data.dataset import EnhancerDataset
from enhancer.inference import Inference
from enhancer.loss import LOSS_MAP, LossWrapper
from enhancer.version import __version__

CACHE_DIR = ""
HF_TORCH_WEIGHTS = ""
DEFAULT_DEVICE = "cpu"


class Model(pl.LightningModule):
    """
    Base class for all models
    parameters:
        num_channels: int, defaults to 1
            number of channels in input audio
        sampling_rate : int, defaults to 16KHz
            audio sampling rate
        lr: float, optional
            learning rate for model training
        dataset: EnhancerDataset, optional
            Enhancer dataset used for training/validation
        duration: float, optional
            duration used for training/inference
        loss : string or List of strings or custom loss (nn.Module), defaults to "mse"
            loss functions to be used. Available ("mse","mae","SI-SDR")
        metric : string or List of strings, defaults to "mse"
            validation metrics to be tracked. Available ("mse","mae","SI-SDR")

    """

    def __init__(
        self,
        num_channels: int = 1,
        sampling_rate: int = 16000,
        lr: float = 1e-3,
        dataset: Optional[EnhancerDataset] = None,
        duration: Optional[float] = None,
        loss: Union[str, List] = "mse",
        metric: Union[str, List, Any] = "mse",
    ):
        super().__init__()
        assert (
            num_channels == 1
        ), "Enhancer only supports mono channel models"
        self.dataset = dataset
        self.save_hyperparameters(
            "num_channels", "sampling_rate", "lr", "loss", "metric", "duration"
        )
        if self.logger:
            self.logger.experiment.log_dict(
                dict(self.hparams), "hyperparameters.json"
            )

        self.loss = loss
        self.metric = metric

    @property
    def loss(self):
        return self._loss

    @loss.setter
    def loss(self, loss):

        if isinstance(loss, str):
            loss = [loss]

        self._loss = LossWrapper(loss)

    @property
    def metric(self):
        return self._metric

    @metric.setter
    def metric(self, metric):
        self._metric = []
        if isinstance(metric, (str, nn.Module)):
            metric = [metric]

        for func in metric:
            if isinstance(func, str):
                if func in LOSS_MAP.keys():
                    if func in ("pesq", "stoi"):
                        self._metric.append(
                            LOSS_MAP[func](self.hparams.sampling_rate)
                        )
                    else:
                        self._metric.append(LOSS_MAP[func]())
                else:
                    # the exception must be raised, not just constructed
                    raise ValueError(f"Invalid metrics {func}")

            elif isinstance(func, nn.Module):
                self._metric.append(func)
            else:
                raise ValueError("Invalid metrics")

    @property
    def dataset(self):
        return self._dataset

    @dataset.setter
    def dataset(self, dataset):
        self._dataset = dataset

    def setup(self, stage: Optional[str] = None):
        if stage == "fit":
            torch.cuda.empty_cache()
            self.dataset.setup(stage)
            self.dataset.model = self

            print(
                "Total train duration",
                self.dataset.train_dataloader().dataset.__len__()
                * self.dataset.duration
                / 60,
                "minutes",
            )
            print(
                "Total validation duration",
                self.dataset.val_dataloader().dataset.__len__()
                * self.dataset.duration
                / 60,
                "minutes",
            )
            print(
                "Total test duration",
                self.dataset.test_dataloader().dataset.__len__()
                * self.dataset.duration
                / 60,
                "minutes",
            )

    def train_dataloader(self):
        return self.dataset.train_dataloader()

    def val_dataloader(self):
        return self.dataset.val_dataloader()

    def test_dataloader(self):
        return self.dataset.test_dataloader()

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=self.hparams.lr)

    def training_step(self, batch, batch_idx: int):

        mixed_waveform = batch["noisy"]
        target = batch["clean"]
        prediction = self(mixed_waveform)
        loss = self.loss(prediction, target)

        self.log(
            "train_loss",
            loss.item(),
            on_epoch=True,
            on_step=True,
            logger=True,
            prog_bar=True,
        )

        return {"loss": loss}

    def validation_step(self, batch, batch_idx: int):

        metric_dict = {}
        mixed_waveform = batch["noisy"]
        target = batch["clean"]
        prediction = self(mixed_waveform)

        metric_dict["valid_loss"] = self.loss(target, prediction).item()
        for metric in self.metric:
            value = metric(target, prediction)
            metric_dict[f"valid_{metric.name}"] = value.item()

        self.log_dict(
            metric_dict,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )

        return metric_dict

    def test_step(self, batch, batch_idx):

        metric_dict = {}
        mixed_waveform = batch["noisy"]
        target = batch["clean"]
        prediction = self(mixed_waveform)

        for metric in self.metric:
            value = metric(target, prediction)
            metric_dict[f"test_{metric.name}"] = value

        self.log_dict(
            metric_dict,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )

        return metric_dict

    def test_epoch_end(self, outputs):

        test_mean_metrics = defaultdict(int)
        for output in outputs:
            for metric, value in output.items():
                test_mean_metrics[metric] += value.item()
        for metric in test_mean_metrics.keys():
            test_mean_metrics[metric] /= len(outputs)

        print("----------TEST REPORT----------\n")
        for metric in test_mean_metrics.keys():
            print(f"|{metric.upper()} | {test_mean_metrics[metric]} |")
        print("--------------------------------")

    def on_save_checkpoint(self, checkpoint):

        checkpoint["enhancer"] = {
            "version": {"enhancer": __version__, "pytorch": torch.__version__},
            "architecture": {
                "module": self.__class__.__module__,
                "class": self.__class__.__name__,
            },
        }

    @classmethod
    def from_pretrained(
        cls,
        checkpoint: Union[Path, Text],
        map_location=None,
        hparams_file: Union[Path, Text] = None,
        strict: bool = True,
        use_auth_token: Union[Text, None] = None,
        cached_dir: Union[Path, Text] = CACHE_DIR,
        **kwargs,
    ):
        """
        Load Pretrained model

        parameters:
        checkpoint : Path or str
            Path to checkpoint, or a remote URL, or a model identifier from
            the huggingface.co model hub.
        map_location: optional
            Same role as in torch.load().
            Defaults to `lambda storage, loc: storage`.
        hparams_file : Path or str, optional
            Path to a .yaml file with hierarchical structure as in this example:
                drop_prob: 0.2
                dataloader:
                    batch_size: 32
            You most likely won’t need this since Lightning will always save the
            hyperparameters to the checkpoint. However, if your checkpoint weights
            do not have the hyperparameters saved, use this method to pass in a .yaml
            file with the hparams you would like to use. These will be converted
            into a dict and passed into your Model for use.
        strict : bool, optional
            Whether to strictly enforce that the keys in checkpoint match
            the keys returned by this module’s state dict. Defaults to True.
        use_auth_token : str, optional
            When loading a private huggingface.co model, set `use_auth_token`
            to True or to a string containing your huggingface.co authentication
            token that can be obtained by running `huggingface-cli login`
        cached_dir: Path or str, optional
            Path to model cache directory
        kwargs: optional
            Any extra keyword args needed to init the model.
            Can also be used to override saved hyperparameter values.

        Returns
        -------
        model : Model
            Model

        See also
        --------
        torch.load
        """

        checkpoint = str(checkpoint)
        if hparams_file is not None:
            hparams_file = str(hparams_file)

        if os.path.isfile(checkpoint):
            model_path_pl = checkpoint
        elif urlparse(checkpoint).scheme in ("http", "https"):
            model_path_pl = checkpoint
        else:

            if "@" in checkpoint:
                model_id = checkpoint.split("@")[0]
                revision_id = checkpoint.split("@")[1]
            else:
                model_id = checkpoint
                revision_id = None

            url = hf_hub_url(
                model_id, filename=HF_TORCH_WEIGHTS, revision=revision_id
            )
            model_path_pl = cached_download(
                url=url,
                library_name="enhancer",
                library_version=__version__,
                cache_dir=cached_dir,
                use_auth_token=use_auth_token,
            )

        if map_location is None:
            map_location = torch.device(DEFAULT_DEVICE)

        loaded_checkpoint = pl_load(model_path_pl, map_location)
        module_name = loaded_checkpoint["enhancer"]["architecture"]["module"]
        class_name = loaded_checkpoint["enhancer"]["architecture"]["class"]
        module = import_module(module_name)
        Klass = getattr(module, class_name)

        try:
            model = Klass.load_from_checkpoint(
                checkpoint_path=model_path_pl,
                map_location=map_location,
                hparams_file=hparams_file,
                strict=strict,
                **kwargs,
            )
        except Exception as e:
            # re-raise instead of silently returning an unbound `model`
            print(e)
            raise

        return model

    def infer(self, batch: torch.Tensor, batch_size: int = 32):
        """
        perform model inference
        parameters:
        batch : torch.Tensor
            input data
        batch_size : int, default 32
            batch size for inference
        """

        assert (
            batch.ndim == 3
        ), f"Expected batch with 3 dimensions (batch,channels,samples) got only {batch.ndim}"
        batch_predictions = []
        self.eval().to(self.device)
        with torch.no_grad():
            for batch_id in range(0, batch.shape[0], batch_size):
                batch_data = batch[batch_id : (batch_id + batch_size), :, :].to(
                    self.device
                )
                prediction = self(batch_data)
                batch_predictions.append(prediction)

        return torch.vstack(batch_predictions)

    def enhance(
        self,
        audio: Union[Path, np.ndarray, torch.Tensor],
        sampling_rate: Optional[int] = None,
        batch_size: int = 32,
        save_output: bool = False,
        duration: Optional[int] = None,
        step_size: Optional[int] = None,
    ):
        """
        Enhance audio using the loaded pretrained model.

        parameters:
        audio: Path to audio file or numpy array or torch tensor
            single input audio
        sampling_rate: int, optional when the input is a file path
            sampling rate of input
        batch_size: int, default 32
            input audio is split into multiple chunks. Inference is done on batches
            of these chunks according to given batch size.
        save_output : bool, default False
            whether to save output to file
        duration : float, optional
            chunk duration in seconds, defaults to duration of loaded pretrained model.
        step_size: int, optional
            step size between consecutive durations, defaults to 50% of duration
        """

        model_sampling_rate = self.hparams["sampling_rate"]
        if duration is None:
            duration = self.hparams["duration"]
        waveform = Inference.read_input(
            audio, sampling_rate, model_sampling_rate
        )
        # Tensor.to is not in-place; keep the device-mapped result
        waveform = waveform.to(self.device)
        window_size = round(duration * model_sampling_rate)
        batched_waveform = Inference.batchify(
            waveform, window_size, step_size=step_size
        )
        batch_prediction = self.infer(batched_waveform, batch_size=batch_size)
        waveform = Inference.aggreagate(
            batch_prediction,
            window_size,
            waveform.shape[-1],
            step_size,
        )

        if save_output and isinstance(audio, (str, Path)):
            Inference.write_output(waveform, audio, model_sampling_rate)

        else:
            waveform = Inference.prepare_output(
                waveform, model_sampling_rate, audio, sampling_rate
            )
            return waveform

    @property
    def valid_monitor(self):

        return "max" if self.loss.higher_better else "min"
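# --- Hedged usage sketch (hypothetical checkpoint path, not part of the diff):
#
#     model = Model.from_pretrained("path/to/checkpoint.ckpt")
#     enhanced = model.enhance("noisy.wav", batch_size=16)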
@ -0,0 +1,201 @@
import logging
from typing import List, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from enhancer.data.dataset import EnhancerDataset
from enhancer.models.model import Model


class WavenetDecoder(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 5,
        padding: int = 2,
        stride: int = 1,
        dilation: int = 1,
    ):
        super().__init__()
        self.decoder = nn.Sequential(
            nn.Conv1d(
                in_channels,
                out_channels,
                kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
            ),
            nn.BatchNorm1d(out_channels),
            nn.LeakyReLU(negative_slope=0.1),
        )

    def forward(self, waveform):
        return self.decoder(waveform)


class WavenetEncoder(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 15,
        padding: int = 7,
        stride: int = 1,
        dilation: int = 1,
    ):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv1d(
                in_channels,
                out_channels,
                kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
            ),
            nn.BatchNorm1d(out_channels),
            nn.LeakyReLU(negative_slope=0.1),
        )

    def forward(self, waveform):
        return self.encoder(waveform)


class WaveUnet(Model):
    """
    Wave-U-Net model from https://arxiv.org/pdf/1811.11307.pdf
    parameters:
        num_channels: int, defaults to 1
            number of channels in input audio
        depth : int, defaults to 12
            depth of network
        initial_output_channels: int, defaults to 24
            number of output channels in initial upsampling layer
        sampling_rate: int, defaults to 16KHz
            sampling rate of input audio
        lr : float, defaults to 1e-3
            learning rate used for training
        dataset: EnhancerDataset, optional
            EnhancerDataset object containing train/validation data for training
        duration : float, optional
            chunk duration in seconds
        loss : string or List of strings
            loss function to be used, available ("mse","mae","SI-SDR")
        metric : string or List of strings
            metric function to be used, available ("mse","mae","SI-SDR")
    """

    def __init__(
        self,
        num_channels: int = 1,
        depth: int = 12,
        initial_output_channels: int = 24,
        sampling_rate: int = 16000,
        lr: float = 1e-3,
        dataset: Optional[EnhancerDataset] = None,
        duration: Optional[float] = None,
        loss: Union[str, List] = "mse",
        metric: Union[str, List] = "mse",
    ):
        duration = (
            dataset.duration if isinstance(dataset, EnhancerDataset) else None
        )
        if dataset is not None:
            if sampling_rate != dataset.sampling_rate:
                logging.warning(
                    f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
                )
                sampling_rate = dataset.sampling_rate
        super().__init__(
            num_channels=num_channels,
            sampling_rate=sampling_rate,
            lr=lr,
            dataset=dataset,
            duration=duration,
            loss=loss,
            metric=metric,
        )
        self.save_hyperparameters("depth")
        self.encoders = nn.ModuleList()
        self.decoders = nn.ModuleList()
        out_channels = initial_output_channels
        for layer in range(depth):

            encoder = WavenetEncoder(num_channels, out_channels)
            self.encoders.append(encoder)

            num_channels = out_channels
            out_channels += initial_output_channels
            if layer == depth - 1:
                decoder = WavenetDecoder(
                    depth * initial_output_channels + num_channels, num_channels
                )
            else:
                decoder = WavenetDecoder(
                    num_channels + out_channels, num_channels
                )

            self.decoders.insert(0, decoder)

        bottleneck_dim = depth * initial_output_channels
        self.bottleneck = nn.Sequential(
            nn.Conv1d(bottleneck_dim, bottleneck_dim, 15, stride=1, padding=7),
            nn.BatchNorm1d(bottleneck_dim),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
        )
        self.final = nn.Sequential(
            nn.Conv1d(1 + initial_output_channels, 1, kernel_size=1, stride=1),
            nn.Tanh(),
        )

    def forward(self, waveform):
        if waveform.dim() == 2:
            waveform = waveform.unsqueeze(1)

        if waveform.size(1) != 1:
            raise TypeError(
                f"Wave-U-Net can only process mono channel audio, input has {waveform.size(1)} channels"
            )

        encoder_outputs = []
        out = waveform
        for encoder in self.encoders:
            out = encoder(out)
            encoder_outputs.insert(0, out)
            out = out[:, :, ::2]

        out = self.bottleneck(out)

        for layer, decoder in enumerate(self.decoders):
            out = F.interpolate(out, scale_factor=2, mode="linear")
            out = self.fix_last_dim(out, encoder_outputs[layer])
            out = torch.cat([out, encoder_outputs[layer]], dim=1)
            out = decoder(out)

        out = torch.cat([out, waveform], dim=1)
        out = self.final(out)
        return out

    def fix_last_dim(self, x, target):
        """
        centre crop along the last dimension
        """

        assert (
            x.shape[-1] >= target.shape[-1]
        ), "input dimension cannot be smaller than target dimension"
        if x.shape[-1] == target.shape[-1]:
            return x

        diff = x.shape[-1] - target.shape[-1]
        if diff % 2 != 0:
            x = F.pad(x, (0, 1))
            diff += 1

        crop = diff // 2
        return x[:, :, crop:-crop]
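# --- Hedged usage sketch (not part of the original diff) ---------------------
# The encoder decimates time by 2 at every one of the 12 layers, so an input
# whose length is a multiple of 2**12 (e.g. 16384) lines up with the skip
# connections without extra cropping. Shapes are illustrative assumptions.

model = WaveUnet()
noisy = torch.randn(1, 1, 16384)  # (batch, channels, samples)
enhanced = model(noisy)           # same shape as the input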
@ -0,0 +1,3 @@
from enhancer.utils.config import Files
from enhancer.utils.io import Audio
from enhancer.utils.utils import check_files
@ -0,0 +1,9 @@
from dataclasses import dataclass


@dataclass
class Files:
    train_clean: str
    train_noisy: str
    test_clean: str
    test_noisy: str
@ -0,0 +1,128 @@
import os
from pathlib import Path
from typing import Optional, Union

import librosa
import numpy as np
import torch
import torchaudio


class Audio:
    """
    Audio utils
    parameters:
        sampling_rate : int, defaults to 16KHz
            audio sampling rate
        mono: bool, defaults to True
            convert audio to mono when set to True
        return_tensor: bool, defaults to True
            returns torch tensor type if set to True else numpy ndarray
    """

    def __init__(
        self, sampling_rate: int = 16000, mono: bool = True, return_tensor=True
    ) -> None:

        self.sampling_rate = sampling_rate
        self.mono = mono
        self.return_tensor = return_tensor

    def __call__(
        self,
        audio: Union[Path, np.ndarray, torch.Tensor],
        sampling_rate: Optional[int] = None,
        offset: Optional[float] = None,
        duration: Optional[float] = None,
    ):
        """
        read and process input audio
        parameters:
            audio: Path to audio file or numpy array or torch tensor
                single input audio
            sampling_rate : int, optional
                sampling rate of the audio input
            offset: float, optional
                offset from which the audio must be read, reads from beginning if unused.
            duration: float (seconds), optional
                read duration, reads full audio starting from offset if not used
        """
        if isinstance(audio, (str, Path)):
            if os.path.exists(audio):
                audio, sampling_rate = librosa.load(
                    audio,
                    sr=sampling_rate,
                    mono=False,
                    offset=offset,
                    duration=duration,
                )
                if len(audio.shape) == 1:
                    audio = audio.reshape(1, -1)
            else:
                raise FileNotFoundError(f"File {audio} does not exist")
        elif isinstance(audio, np.ndarray):
            if len(audio.shape) == 1:
                audio = audio.reshape(1, -1)
        else:
            raise ValueError("audio should be either filepath or numpy ndarray")

        if self.mono:
            audio = self.convert_mono(audio)

        if sampling_rate:
            audio = self.__class__.resample_audio(
                audio, sampling_rate, self.sampling_rate
            )
        if self.return_tensor:
            return torch.tensor(audio)
        else:
            return audio

    @staticmethod
    def convert_mono(audio: Union[np.ndarray, torch.Tensor]):
        """
        convert input audio into mono (1)
        parameters:
            audio: np.ndarray or torch.Tensor
        """
        if len(audio.shape) > 2:
            assert (
                audio.shape[0] == 1
            ), "convert mono only accepts single waveform"
            audio = audio.reshape(audio.shape[1], audio.shape[2])

        # ">" (not the bit-shift ">>") enforces more samples than channels
        assert (
            audio.shape[1] > audio.shape[0]
        ), f"expected input format (num_channels,num_samples) got {audio.shape}"
        num_channels, num_samples = audio.shape
        if num_channels > 1:
            return audio.mean(axis=0).reshape(1, num_samples)
        return audio

    @staticmethod
    def resample_audio(
        audio: Union[np.ndarray, torch.Tensor], sr: int, target_sr: int
    ):
        """
        resample audio to desired sampling rate
        parameters:
            audio : numpy array or torch tensor
                audio waveform
            sr : int
                current sampling rate
            target_sr : int
                target sampling rate
        """
        if sr != target_sr:
            if isinstance(audio, np.ndarray):
                audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sr)
            elif isinstance(audio, torch.Tensor):
                audio = torchaudio.functional.resample(
                    audio, orig_freq=sr, new_freq=target_sr
                )
            else:
                raise ValueError(
                    "Input should be either numpy array or torch tensor"
                )

        return audio
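# --- Hedged usage sketch (hypothetical file name, not part of the diff) ------
#
#     reader = Audio(sampling_rate=16000, mono=True)
#     waveform = reader("speech.wav")  # (1, num_samples) tensor at 16 kHz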
@ -0,0 +1,36 @@
import os
import random

import torch


def create_unique_rng(epoch: int):
    """create unique random number generator for each (worker_id,epoch) combination"""

    rng = random.Random()

    global_seed = int(os.environ.get("PL_GLOBAL_SEED", "0"))
    global_rank = int(os.environ.get("GLOBAL_RANK", "0"))
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    node_rank = int(os.environ.get("NODE_RANK", "0"))
    # default to a world size of 1 so the epoch term still perturbs the seed
    # in single-process runs
    world_size = int(os.environ.get("WORLD_SIZE", "1"))

    worker_info = torch.utils.data.get_worker_info()
    if worker_info is not None:
        num_workers = worker_info.num_workers
        worker_id = worker_info.id
    else:
        num_workers = 1
        worker_id = 0

    seed = (
        global_seed
        + worker_id
        + local_rank * num_workers
        + node_rank * num_workers * global_rank
        + epoch * num_workers * world_size
    )

    rng.seed(seed)

    return rng
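# --- Hedged usage sketch (not part of the original diff) ---------------------
# With single-process defaults, only the epoch varies the seed, so each epoch
# gets an independent stream:

rng0 = create_unique_rng(epoch=0)
rng1 = create_unique_rng(epoch=1)
assert rng0.random() != rng1.random()  # distinct per-epoch streams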
@ -0,0 +1,93 @@
from typing import Optional

import numpy as np
import torch
import torch.nn.functional as F
from scipy.signal import get_window
from torch import nn


class ConvFFT(nn.Module):
    def __init__(
        self,
        window_len: int,
        nfft: Optional[int] = None,
        window: str = "hamming",
    ):
        super().__init__()
        self.window_len = window_len
        # np.int is deprecated; plain int works for the power-of-two nfft
        self.nfft = nfft if nfft else int(2 ** np.ceil(np.log2(window_len)))
        self.window = torch.from_numpy(
            get_window(window, window_len, fftbins=True).astype("float32")
        )

    def init_kernel(self, inverse=False):

        fourier_basis = np.fft.rfft(np.eye(self.nfft))[: self.window_len]
        real, imag = np.real(fourier_basis), np.imag(fourier_basis)
        kernel = np.concatenate([real, imag], 1).T
        if inverse:
            kernel = np.linalg.pinv(kernel).T
        kernel = torch.from_numpy(kernel.astype("float32")).unsqueeze(1)
        kernel *= self.window
        return kernel


class ConvSTFT(ConvFFT):
    def __init__(
        self,
        window_len: int,
        hop_size: Optional[int] = None,
        nfft: Optional[int] = None,
        window: str = "hamming",
    ):
        super().__init__(window_len=window_len, nfft=nfft, window=window)
        self.hop_size = hop_size if hop_size else window_len // 2
        self.register_buffer("weight", self.init_kernel())

    def forward(self, input):

        if input.dim() < 2:
            raise ValueError(
                f"Expected signal with 2 or 3 dimensions, got {input.dim()}"
            )
        elif input.dim() == 2:
            input = input.unsqueeze(1)
        input = F.pad(
            input,
            (self.window_len - self.hop_size, self.window_len - self.hop_size),
        )
        output = F.conv1d(input, self.weight, stride=self.hop_size)

        return output


class ConviSTFT(ConvFFT):
    def __init__(
        self,
        window_len: int,
        hop_size: Optional[int] = None,
        nfft: Optional[int] = None,
        window: str = "hamming",
    ):
        super().__init__(window_len=window_len, nfft=nfft, window=window)
        self.hop_size = hop_size if hop_size else window_len // 2
        self.register_buffer("weight", self.init_kernel(True))
        self.register_buffer("enframe", torch.eye(window_len).unsqueeze(1))

    def forward(self, input, phase=None):

        if phase is not None:
            real = input * torch.cos(phase)
            imag = input * torch.sin(phase)
            input = torch.cat([real, imag], 1)
        out = F.conv_transpose1d(input, self.weight, stride=self.hop_size)
        coeff = self.window.unsqueeze(1).repeat(1, 1, input.size(-1)) ** 2
        coeff = coeff.to(input.device)
        coeff = F.conv_transpose1d(coeff, self.enframe, stride=self.hop_size)
        out = out / (coeff + 1e-8)
        pad = self.window_len - self.hop_size
        out = out[..., pad:-pad]
        return out
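# --- Hedged round-trip sketch (not part of the original diff) ----------------
# With the DCCRN defaults, the analysis kernel produces nfft + 2 output rows
# (real rows stacked over imaginary rows), and the inverse transform
# approximately reconstructs the padded input:

stft = ConvSTFT(window_len=400, hop_size=100, nfft=512)
istft = ConviSTFT(window_len=400, hop_size=100, nfft=512)
x = torch.randn(1, 1, 16000)
spec = stft(x)   # (1, nfft + 2, frames)
y = istft(spec)  # (1, 1, 16000): approximate reconstruction of x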
@ -0,0 +1,27 @@
import os
from typing import Optional

from enhancer.utils.config import Files


def check_files(root_dir: str, files: Files):
    # Collect the public attributes of the Files config object and make sure
    # each one resolves to an existing directory under root_dir.
    path_variables = [
        member_var
        for member_var in dir(files)
        if not member_var.startswith("__")
    ]
    for variable in path_variables:
        path = getattr(files, variable)
        if not os.path.isdir(os.path.join(root_dir, path)):
            raise ValueError(f"Invalid {path}: not a directory")

    return files, root_dir


def merge_dict(default_dict: dict, custom: Optional[dict] = None):
    # Return a copy of default_dict, overridden by any custom entries.
    params = dict(default_dict)
    if custom:
        params.update(custom)
    return params
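
A hedged usage sketch for merge_dict (values are illustrative): custom entries override the defaults, and the default dict itself is left unmodified.

defaults = {"lr": 1e-3, "batch_size": 32}
params = merge_dict(defaults, {"batch_size": 64})
assert params == {"lr": 1e-3, "batch_size": 64}
assert defaults["batch_size"] == 32  # original defaults untouched
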
@@ -0,0 +1 @@
__version__ = "0.0.1"
@@ -0,0 +1,8 @@
name: enhancer

dependencies:
  - pip=21.0.1
  - python=3.8
  - pip:
      - -r requirements.txt
      - --find-links https://download.pytorch.org/whl/cu113/torch_stable.html
@@ -0,0 +1,15 @@
[tool.black]
line-length = 80
target-version = ['py38']
exclude = '''

(
  /(
      \.eggs         # exclude a few common directories in the
    | \.git          # root of the project
    | \.mypy_cache
    | \.tox
    | \.venv
  )/
)
'''
@@ -0,0 +1,19 @@
boto3>=1.24.86
huggingface-hub>=0.10.0
hydra-core>=1.2.0
joblib>=1.2.0
librosa>=0.9.2
mlflow>=1.29.0
numpy>=1.23.3
pesq==0.0.4
protobuf>=3.19.6
pystoi==0.3.3
pytest-lazy-fixture>=0.6.3
pytorch-lightning>=1.7.7
scikit-learn>=1.1.2
scipy>=1.9.1
soundfile>=0.11.0
torch>=1.12.1
torch-audiomentations==0.11.0
torchaudio>=0.12.1
tqdm>=4.64.1
@@ -0,0 +1,100 @@
# This file is used to configure your project.
# Read more about the various options under:
# http://setuptools.readthedocs.io/en/latest/setuptools.html#configuring-setup-using-setup-cfg-files

[metadata]
name = enhancer
description = Deep learning for speech enhancement
author = Shahul Ess
author-email = shahules786@gmail.com
license = mit
long-description = file: README.md
long-description-content-type = text/markdown; charset=UTF-8; variant=GFM
# Change if running only on Windows, Mac or Linux (comma-separated)
platforms = Linux, Mac
# Add here all kinds of additional classifiers as defined under
# https://pypi.python.org/pypi?%3Aaction=list_classifiers
classifiers =
    Development Status :: 4 - Beta
    Programming Language :: Python

[options]
zip_safe = False
packages = find:
include_package_data = True
# DON'T CHANGE THE FOLLOWING LINE! IT WILL BE UPDATED BY PYSCAFFOLD!
setup_requires = setuptools
# Add here dependencies of your project (semicolon/line-separated), e.g.
# install_requires = numpy; scipy
# Require a specific Python version, e.g. Python 2.7 or >= 3.4
python_requires = >=3.8

[options.packages.find]
where = .
exclude =
    tests

[options.extras_require]
# Add here additional requirements for extra features, to install with:
# `pip install enhancer[PDF]` like:
# PDF = ReportLab; RXP
# Add here test requirements (semicolon/line-separated)
testing =
    pytest>=7.1.3
    pytest-cov>=4.0.0
dev =
    pre-commit>=2.20.0
    black>=22.8.0
    flake8>=5.0.4
cli =
    hydra-core >=1.1,<=1.2

[options.entry_points]
console_scripts =
    enhancer-train=enhancer.cli.train:train

[test]
# py.test options when running `python setup.py test`
# addopts = --verbose
extras = True

[tool:pytest]
# Options for py.test:
# Specify command line options as you would do when invoking py.test directly.
# e.g. --cov-report html (or xml) for html/xml output or --junitxml junit.xml
# in order to write a coverage file that can be read by Jenkins.
addopts =
    --cov enhancer --cov-report term-missing
    --verbose
norecursedirs =
    dist
    build
    .tox
testpaths = tests

[aliases]
dists = bdist_wheel

[bdist_wheel]
# Use this option if your package is pure-python
universal = 1

[build_sphinx]
source_dir = doc
build_dir = build/sphinx

[devpi:upload]
# Options for the devpi: PyPI server and packaging tool
# VCS export must be deactivated since we are using setuptools-scm
no-vcs = 1
formats = bdist_wheel

[flake8]
# Some sane defaults for the code style checker flake8
exclude =
    .tox
    build
    dist
    .eggs
@@ -0,0 +1,63 @@
import os
import sys
from pathlib import Path

from pkg_resources import VersionConflict, require
from setuptools import find_packages, setup

with open("README.md") as f:
    long_description = f.read()

with open("requirements.txt") as f:
    requirements = f.read().splitlines()

try:
    require("setuptools>=38.3")
except VersionConflict:
    print("Error: version of setuptools is too old (<38.3)!")
    sys.exit(1)


ROOT_DIR = Path(__file__).parent.resolve()

with open("version.txt") as f:
    version = f.read()

version = version.strip()
sha = "Unknown"

if os.getenv("BUILD_VERSION"):
    version = os.getenv("BUILD_VERSION")
elif sha != "Unknown":
    version += "+" + sha[:7]
print("-- Building version " + version)

# Creating the version file
version_path = ROOT_DIR / "enhancer" / "version.py"

with open(version_path, "w") as f:
    f.write("__version__ = '{}'\n".format(version))

if __name__ == "__main__":
    setup(
        name="enhancer",
        namespace_packages=["enhancer"],
        version=version,
        packages=find_packages(),
        install_requires=requirements,
        description="Deep learning toolkit for speech enhancement",
        long_description=long_description,
        long_description_content_type="text/markdown",
        author="Shahul Es",
        author_email="shahules786@gmail.com",
        url="",
        classifiers=[
            "Development Status :: 4 - Beta",
            "Intended Audience :: Science/Research",
            "License :: OSI Approved :: MIT License",
            "Natural Language :: English",
            "Programming Language :: Python :: 3.8",
            "Topic :: Scientific/Engineering",
        ],
    )
@@ -0,0 +1,32 @@
import pytest
import torch

from enhancer.loss import mean_absolute_error, mean_squared_error

loss_functions = [mean_absolute_error(), mean_squared_error()]


def check_loss_shapes_compatibility(loss_fun):
    batch_size = 4
    shape = (1, 1000)
    loss_fun(torch.rand(batch_size, *shape), torch.rand(batch_size, *shape))

    with pytest.raises(TypeError):
        loss_fun(torch.rand(4, *shape), torch.rand(6, *shape))


@pytest.mark.parametrize("loss", loss_functions)
def test_loss_input_shapes(loss):
    check_loss_shapes_compatibility(loss)


@pytest.mark.parametrize("loss", loss_functions)
def test_loss_output_type(loss):
    batch_size = 4
    prediction, target = torch.rand(batch_size, 1, 1000), torch.rand(
        batch_size, 1, 1000
    )
    loss_value = loss(prediction, target)
    assert isinstance(loss_value.item(), float)
@@ -0,0 +1,50 @@
import torch

from enhancer.models.complexnn.conv import ComplexConv2d, ComplexConvTranspose2d
from enhancer.models.complexnn.rnn import ComplexLSTM
from enhancer.models.complexnn.utils import ComplexBatchNorm2D


def test_complexconv2d():
    sample_input = torch.rand(1, 2, 256, 13)
    conv = ComplexConv2d(
        2, 32, kernel_size=(5, 2), stride=(2, 1), padding=(2, 1)
    )
    with torch.no_grad():
        out = conv(sample_input)
    assert out.shape == torch.Size([1, 32, 128, 13])


def test_complexconvtranspose2d():
    sample_input = torch.rand(1, 512, 4, 13)
    conv = ComplexConvTranspose2d(
        256 * 2,
        128 * 2,
        kernel_size=(5, 2),
        stride=(2, 1),
        padding=(2, 0),
        output_padding=(1, 0),
    )
    with torch.no_grad():
        out = conv(sample_input)

    assert out.shape == torch.Size([1, 256, 8, 14])


def test_complexlstm():
    sample_input = torch.rand(13, 2, 128)
    lstm = ComplexLSTM(128 * 2, 128 * 2, projection_size=512 * 2)
    with torch.no_grad():
        out = lstm(sample_input)

    assert out[0].shape == torch.Size([13, 1, 512])
    assert out[1].shape == torch.Size([13, 1, 512])


def test_complexbatchnorm2d():
    sample_input = torch.rand(1, 64, 64, 14)
    batchnorm = ComplexBatchNorm2D(num_features=64)
    with torch.no_grad():
        out = batchnorm(sample_input)

    assert out.size() == sample_input.size()
@@ -0,0 +1,43 @@
import pytest
import torch

from enhancer.data.dataset import EnhancerDataset
from enhancer.models import Demucs
from enhancer.utils.config import Files


@pytest.fixture
def vctk_dataset():
    root_dir = "tests/data/vctk"
    files = Files(
        train_clean="clean_testset_wav",
        train_noisy="noisy_testset_wav",
        test_clean="clean_testset_wav",
        test_noisy="noisy_testset_wav",
    )
    dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files)
    return dataset


@pytest.mark.parametrize("batch_size,samples", [(1, 1000)])
def test_forward(batch_size, samples):
    model = Demucs()
    model.eval()

    data = torch.rand(batch_size, 1, samples, requires_grad=False)
    with torch.no_grad():
        _ = model(data)

    data = torch.rand(batch_size, 2, samples, requires_grad=False)
    with torch.no_grad():
        with pytest.raises(ValueError):
            _ = model(data)


@pytest.mark.parametrize(
    "dataset,channels,loss",
    [(pytest.lazy_fixture("vctk_dataset"), 1, ["mae", "mse"])],
)
def test_demucs_init(dataset, channels, loss):
    with torch.no_grad():
        _ = Demucs(num_channels=channels, dataset=dataset, loss=loss)
@@ -0,0 +1,43 @@
import pytest
import torch

from enhancer.data.dataset import EnhancerDataset
from enhancer.models.dccrn import DCCRN
from enhancer.utils.config import Files


@pytest.fixture
def vctk_dataset():
    root_dir = "tests/data/vctk"
    files = Files(
        train_clean="clean_testset_wav",
        train_noisy="noisy_testset_wav",
        test_clean="clean_testset_wav",
        test_noisy="noisy_testset_wav",
    )
    dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files)
    return dataset


@pytest.mark.parametrize("batch_size,samples", [(1, 1000)])
def test_forward(batch_size, samples):
    model = DCCRN()
    model.eval()

    data = torch.rand(batch_size, 1, samples, requires_grad=False)
    with torch.no_grad():
        _ = model(data)

    data = torch.rand(batch_size, 2, samples, requires_grad=False)
    with torch.no_grad():
        with pytest.raises(ValueError):
            _ = model(data)


@pytest.mark.parametrize(
    "dataset,channels,loss",
    [(pytest.lazy_fixture("vctk_dataset"), 1, ["mae", "mse"])],
)
def test_dccrn_init(dataset, channels, loss):
    with torch.no_grad():
        _ = DCCRN(num_channels=channels, dataset=dataset, loss=loss)
@@ -0,0 +1,43 @@
import pytest
import torch

from enhancer.data.dataset import EnhancerDataset
from enhancer.models import WaveUnet
from enhancer.utils.config import Files


@pytest.fixture
def vctk_dataset():
    root_dir = "tests/data/vctk"
    files = Files(
        train_clean="clean_testset_wav",
        train_noisy="noisy_testset_wav",
        test_clean="clean_testset_wav",
        test_noisy="noisy_testset_wav",
    )
    dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files)
    return dataset


@pytest.mark.parametrize("batch_size,samples", [(1, 1000)])
def test_forward(batch_size, samples):
    model = WaveUnet()
    model.eval()

    data = torch.rand(batch_size, 1, samples, requires_grad=False)
    with torch.no_grad():
        _ = model(data)

    data = torch.rand(batch_size, 2, samples, requires_grad=False)
    with torch.no_grad():
        with pytest.raises(TypeError):
            _ = model(data)


@pytest.mark.parametrize(
    "dataset,channels,loss",
    [(pytest.lazy_fixture("vctk_dataset"), 1, ["mae", "mse"])],
)
def test_waveunet_init(dataset, channels, loss):
    with torch.no_grad():
        _ = WaveUnet(num_channels=channels, dataset=dataset, loss=loss)
@@ -0,0 +1,29 @@
import pytest
import torch

from enhancer.inference import Inference


@pytest.mark.parametrize(
    "audio",
    ["tests/data/vctk/clean_testset_wav/p257_166.wav", torch.rand(1, 2, 48000)],
)
def test_read_input(audio):
    read_audio = Inference.read_input(audio, 48000, 16000)
    assert isinstance(read_audio, torch.Tensor)
    assert read_audio.shape[0] == 1


def test_batchify():
    rand = torch.rand(1, 1000)
    batched_rand = Inference.batchify(rand, window_size=100, step_size=100)
    assert batched_rand.shape[0] == 12


def test_aggregate():
    rand = torch.rand(12, 1, 100)
    agg_rand = Inference.aggreagate(
        data=rand, window_size=100, total_frames=1000, step_size=100
    )
    assert agg_rand.shape[-1] == 1000
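
A hedged sketch (hypothetical composition; call signatures taken from the tests above) of how these helpers chain for sliding-window inference: batchify splits the waveform into fixed-size chunks and aggreagate overlap-adds them back to the original length.

import torch

waveform = torch.rand(1, 1000)  # stand-in for Inference.read_input(...)
total = waveform.shape[-1]
batches = Inference.batchify(waveform, window_size=100, step_size=100)
# enhanced = model(batches) would go here; the chunks are then merged:
merged = Inference.aggreagate(
    data=batches, window_size=100, total_frames=total, step_size=100
)
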
@@ -0,0 +1,18 @@
import torch

from enhancer.utils.transforms import ConviSTFT, ConvSTFT


def test_stft_istft():
    sample_input = torch.rand(1, 1, 16000)
    stft = ConvSTFT(window_len=400, hop_size=100, nfft=512)
    istft = ConviSTFT(window_len=400, hop_size=100, nfft=512)

    with torch.no_grad():
        spectrogram = stft(sample_input)
        waveform = istft(spectrogram)
    assert sample_input.shape == waveform.shape
    assert (
        torch.isclose(waveform, sample_input).sum().item()
        > sample_input.shape[-1] // 2
    )
@@ -0,0 +1,49 @@
import numpy as np
import pytest
import torch

from enhancer.data.fileprocessor import Fileprocessor
from enhancer.utils.io import Audio


def test_io_channel():
    input_audio = np.random.rand(2, 32000)
    audio = Audio(mono=True, return_tensor=False)
    output_audio = audio(input_audio)
    assert output_audio.shape[0] == 1


def test_io_resampling():
    input_audio = np.random.rand(1, 32000)
    resampled_audio = Audio.resample_audio(input_audio, 16000, 8000)

    input_audio = torch.rand(1, 32000)
    resampled_audio_pt = Audio.resample_audio(input_audio, 16000, 8000)

    assert resampled_audio.shape[1] == resampled_audio_pt.size(1) == 16000


def test_fileprocessor_vctk():
    fp = Fileprocessor.from_name(
        "vctk",
        "tests/data/vctk/clean_testset_wav",
        "tests/data/vctk/noisy_testset_wav",
    )
    matching_dict = fp.prepare_matching_dict()
    assert len(matching_dict) == 2


@pytest.mark.parametrize("dataset_name", ["vctk", "dns-2020"])
def test_fileprocessor_names(dataset_name):
    fp = Fileprocessor.from_name(dataset_name, "clean_dir", "noisy_dir")
    assert callable(fp.matching_function)


def test_fileprocessor_invalidname():
    with pytest.raises(ValueError):
        _ = Fileprocessor.from_name(
            "undefined", "clean_dir", "noisy_dir", 16000
        ).prepare_matching_dict()
@@ -0,0 +1 @@
0.0.1