diff --git a/.flake8 b/.flake8
new file mode 100644
index 0000000..abbbc73
--- /dev/null
+++ b/.flake8
@@ -0,0 +1,9 @@
+[flake8]
+per-file-ignores = __init__.py:F401
+ignore = E203, E266, E501, W503
+# line length is intentionally set to 80 here because we rely on Bugbear's B950
+# (enabled via B9 in select), which allows roughly 10% leeway over Black's 88
+# See https://github.com/psf/black/blob/master/README.md#line-length for more details
+max-line-length = 80
+max-complexity = 18
+select = B,C,E,F,W,T4,B9
+exclude = tools/kaldi_decoder
diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
new file mode 100644
index 0000000..4c64745
--- /dev/null
+++ b/.github/workflows/ci.yaml
@@ -0,0 +1,51 @@
+# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
+# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
+
+name: Enhancer
+
+on:
+ push:
+ branches: [ dev ]
+ pull_request:
+ branches: [ dev ]
+jobs:
+ build:
+ runs-on: ubuntu-latest
+ strategy:
+ matrix:
+ python-version: [3.8]
+ steps:
+ - uses: actions/checkout@v2
+ - name: Set up Python
+ uses: actions/setup-python@v1.1.1
+ env :
+ ACTIONS_ALLOW_UNSECURE_COMMANDS : true
+ with:
+ python-version: ${{ matrix.python-version }}
+ - name: Cache pip
+ uses: actions/cache@v1
+ with:
+ path: ~/.cache/pip # This path is specific to Ubuntu
+ # Look to see if there is a cache hit for the corresponding requirements file
+ key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }}
+ restore-keys: |
+ ${{ runner.os }}-pip-
+ ${{ runner.os }}-
+ # You can test your matrix by printing the current Python version
+ - name: Display Python version
+ run: python -c "import sys; print(sys.version)"
+ - name: Install dependencies
+ run: |
+ python -m pip install --upgrade pip
+ sudo apt-get install libsndfile1
+ pip install -r requirements.txt
+ pip install black pytest-cov
+ - name: Install enhancer
+ run: |
+ pip install -e .[dev,testing]
+ - name: Run black
+ run:
+ black --check . --exclude enhancer/version.py
+ - name: Test with pytest
+ run:
+ pytest tests --cov=enhancer/
diff --git a/.gitignore b/.gitignore
index b6e4761..cd1b1e9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,10 @@
+#local
+*.ckpt
+*_local.yaml
+cli/train_config/dataset/Vctk_local.yaml
+.DS_Store
+outputs/
+datasets/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..807429c
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,43 @@
+
+repos:
+ # # Clean Notebooks
+ # - repo: https://github.com/kynan/nbstripout
+ # rev: master
+ # hooks:
+ # - id: nbstripout
+ # Format Code
+ - repo: https://github.com/ambv/black
+ rev: 22.8.0
+ hooks:
+ - id: black
+
+ # Sort imports
+ - repo: https://github.com/PyCQA/isort
+ rev: 5.10.1
+ hooks:
+ - id: isort
+ args: ["--profile", "black"]
+
+ - repo: https://gitlab.com/pycqa/flake8
+ rev: 5.0.4
+ hooks:
+ - id: flake8
+ args: ['--ignore=E203,E501,F811,E712,W503']
+
+ # Formatting, Whitespace, etc
+ - repo: https://github.com/pre-commit/pre-commit-hooks
+ rev: v3.2.0
+ hooks:
+ - id: trailing-whitespace
+ - id: check-added-large-files
+ args: ['--maxkb=1000']
+ - id: check-ast
+ - id: check-json
+ - id: check-merge-conflict
+ - id: check-xml
+ - id: check-yaml
+ - id: debug-statements
+ - id: end-of-file-fixer
+ - id: requirements-txt-fixer
+ - id: mixed-line-ending
+ args: ['--fix=no']
diff --git a/README.md b/README.md
index e462afa..13b8e14 100644
--- a/README.md
+++ b/README.md
@@ -1 +1,43 @@
-# enhancer
\ No newline at end of file
+
+
+
+
+mayavoz is a PyTorch-based open-source toolkit for speech enhancement. It is designed to save time for audio researchers. It provides easy-to-use pretrained audio enhancement models and facilitates highly customisable model training.
+
+| **[Quick Start]()** | **[Installation]()** | **[Tutorials]()** | **[Available Recipes]()**
+## Key features :key:
+
+* Various pretrained models nicely integrated with huggingface :hugs: that users can select and use without any hassle.
+* :package: Ability to train and validate your own custom speech enhancement models with just under 10 lines of code!
+* :magic_wand: A command line tool that facilitates training of highly customisable speech enhancement models from the terminal itself!
+* :zap: Supports multi-gpu training integrated with Pytorch Lightning.
+
+## Quick Start :fire:
+``` python
+from mayavoz import Mayamodel
+
+model = Mayamodel.from_pretrained("mayavoz/waveunet")
+model("noisy_audio.wav")
+```
+
+## Installation
+Only Python 3.8+ is officially supported (though it might work with Python 3.7)
+
+- With Pypi
+```
+pip install mayavoz
+```
+
+- With conda
+
+```
+conda env create -f environment.yml
+conda activate mayavoz
+```
+
+- From source code
+```
+git clone url
+cd mayavoz
+pip install -e .
+```
diff --git a/enhancer/__init__.py b/enhancer/__init__.py
new file mode 100644
index 0000000..5284146
--- /dev/null
+++ b/enhancer/__init__.py
@@ -0,0 +1 @@
+__import__("pkg_resources").declare_namespace(__name__)
diff --git a/enhancer/cli/train.py b/enhancer/cli/train.py
new file mode 100644
index 0000000..c00c024
--- /dev/null
+++ b/enhancer/cli/train.py
@@ -0,0 +1,120 @@
+import os
+from types import MethodType
+
+import hydra
+from hydra.utils import instantiate
+from omegaconf import DictConfig, OmegaConf
+from pytorch_lightning.callbacks import (
+ EarlyStopping,
+ LearningRateMonitor,
+ ModelCheckpoint,
+)
+from pytorch_lightning.loggers import MLFlowLogger
+from torch.optim.lr_scheduler import ReduceLROnPlateau
+
+# from torch_audiomentations import Compose, Shift
+
+os.environ["HYDRA_FULL_ERROR"] = "1"
+JOB_ID = os.environ.get("SLURM_JOBID", "0")
+
+
+@hydra.main(config_path="train_config", config_name="config")
+def main(config: DictConfig):
+
+ OmegaConf.save(config, "config_log.yaml")
+
+ callbacks = []
+ logger = MLFlowLogger(
+ experiment_name=config.mlflow.experiment_name,
+ run_name=config.mlflow.run_name,
+ tags={"JOB_ID": JOB_ID},
+ )
+
+ parameters = config.hyperparameters
+ # apply_augmentations = Compose(
+ # [
+ # Shift(min_shift=0.5, max_shift=1.0, shift_unit="seconds", p=0.5),
+ # ]
+ # )
+
+ dataset = instantiate(config.dataset, augmentations=None)
+ model = instantiate(
+ config.model,
+ dataset=dataset,
+ lr=parameters.get("lr"),
+ loss=parameters.get("loss"),
+ metric=parameters.get("metric"),
+ )
+
+ direction = model.valid_monitor
+ checkpoint = ModelCheckpoint(
+ dirpath="./model",
+ filename=f"model_{JOB_ID}",
+ monitor="valid_loss",
+ verbose=False,
+ mode=direction,
+ every_n_epochs=1,
+ )
+ callbacks.append(checkpoint)
+ callbacks.append(LearningRateMonitor(logging_interval="epoch"))
+
+ if parameters.get("Early_stop", False):
+ early_stopping = EarlyStopping(
+ monitor="val_loss",
+ mode=direction,
+ min_delta=0.0,
+ patience=parameters.get("EarlyStopping_patience", 10),
+ strict=True,
+ verbose=False,
+ )
+ callbacks.append(early_stopping)
+
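+    # configure_optimizers is defined here, outside the model class, so that the
+    # optimizer and LR scheduler can be built from the hydra config; it is bound
+    # to the model instance further below via MethodType.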
+ def configure_optimizers(self):
+ optimizer = instantiate(
+ config.optimizer,
+ lr=parameters.get("lr"),
+ params=self.parameters(),
+ )
+ scheduler = ReduceLROnPlateau(
+ optimizer=optimizer,
+ mode=direction,
+ factor=parameters.get("ReduceLr_factor", 0.1),
+ verbose=True,
+ min_lr=parameters.get("min_lr", 1e-6),
+ patience=parameters.get("ReduceLr_patience", 3),
+ )
+ return {
+ "optimizer": optimizer,
+ "lr_scheduler": scheduler,
+ "monitor": f'valid_{parameters.get("ReduceLr_monitor", "loss")}',
+ }
+
+ model.configure_optimizers = MethodType(configure_optimizers, model)
+
+ trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks)
+ trainer.fit(model)
+ trainer.test(model)
+
+ logger.experiment.log_artifact(
+ logger.run_id, f"{trainer.default_root_dir}/config_log.yaml"
+ )
+
+ saved_location = os.path.join(
+ trainer.default_root_dir, "model", f"model_{JOB_ID}.ckpt"
+ )
+ if os.path.isfile(saved_location):
+ logger.experiment.log_artifact(logger.run_id, saved_location)
+ logger.experiment.log_param(
+ logger.run_id,
+ "num_train_steps_per_epoch",
+ dataset.train__len__() / dataset.batch_size,
+ )
+ logger.experiment.log_param(
+ logger.run_id,
+ "num_valid_steps_per_epoch",
+ dataset.val__len__() / dataset.batch_size,
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/enhancer/cli/train_config/config.yaml b/enhancer/cli/train_config/config.yaml
new file mode 100644
index 0000000..8d0ab14
--- /dev/null
+++ b/enhancer/cli/train_config/config.yaml
@@ -0,0 +1,7 @@
+defaults:
+ - model : Demucs
+ - dataset : Vctk
+ - optimizer : Adam
+ - hyperparameters : default
+ - trainer : default
+ - mlflow : experiment
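+# each group above maps to a folder in train_config/ and can be overridden from
+# the command line, e.g. (sketch, assuming the script is run directly):
+#   python enhancer/cli/train.py model=WaveUnet dataset=DNS-2020 trainer=fastrun_dev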
diff --git a/enhancer/cli/train_config/dataset/DNS-2020.yaml b/enhancer/cli/train_config/dataset/DNS-2020.yaml
new file mode 100644
index 0000000..09a14fb
--- /dev/null
+++ b/enhancer/cli/train_config/dataset/DNS-2020.yaml
@@ -0,0 +1,12 @@
+_target_: enhancer.data.dataset.EnhancerDataset
+root_dir : /Users/shahules/Myprojects/MS-SNSD
+name : dns-2020
+duration : 2.0
+sampling_rate: 16000
+batch_size: 32
+valid_size: 0.05
+files:
+ train_clean : CleanSpeech_training
+ test_clean : CleanSpeech_training
+ train_noisy : NoisySpeech_training
+ test_noisy : NoisySpeech_training
diff --git a/enhancer/cli/train_config/dataset/Vctk.yaml b/enhancer/cli/train_config/dataset/Vctk.yaml
new file mode 100644
index 0000000..c33d29a
--- /dev/null
+++ b/enhancer/cli/train_config/dataset/Vctk.yaml
@@ -0,0 +1,13 @@
+_target_: enhancer.data.dataset.EnhancerDataset
+name : vctk
+root_dir : /scratch/c.sistc3/DS_10283_2791
+duration : 4.5
+stride : 2
+sampling_rate: 16000
+batch_size: 32
+valid_minutes : 15
+files:
+ train_clean : clean_trainset_28spk_wav
+ test_clean : clean_testset_wav
+ train_noisy : noisy_trainset_28spk_wav
+ test_noisy : noisy_testset_wav
diff --git a/enhancer/cli/train_config/dataset/Vctk_local.yaml b/enhancer/cli/train_config/dataset/Vctk_local.yaml
new file mode 100644
index 0000000..ba44597
--- /dev/null
+++ b/enhancer/cli/train_config/dataset/Vctk_local.yaml
@@ -0,0 +1,13 @@
+_target_: enhancer.data.dataset.EnhancerDataset
+name : vctk
+root_dir : /Users/shahules/Myprojects/enhancer/datasets/vctk
+duration : 1.0
+sampling_rate: 16000
+batch_size: 64
+num_workers : 0
+
+files:
+ train_clean : clean_testset_wav
+ test_clean : clean_testset_wav
+ train_noisy : noisy_testset_wav
+ test_noisy : noisy_testset_wav
diff --git a/enhancer/cli/train_config/hyperparameters/default.yaml b/enhancer/cli/train_config/hyperparameters/default.yaml
new file mode 100644
index 0000000..1782ea9
--- /dev/null
+++ b/enhancer/cli/train_config/hyperparameters/default.yaml
@@ -0,0 +1,7 @@
+loss : mae
+metric : [stoi,pesq,si-sdr]
+lr : 0.0003
+ReduceLr_patience : 5
+ReduceLr_factor : 0.2
+min_lr : 0.000001
+EarlyStopping_patience : 10
diff --git a/enhancer/cli/train_config/mlflow/experiment.yaml b/enhancer/cli/train_config/mlflow/experiment.yaml
new file mode 100644
index 0000000..d597333
--- /dev/null
+++ b/enhancer/cli/train_config/mlflow/experiment.yaml
@@ -0,0 +1,2 @@
+experiment_name : shahules/enhancer
+run_name : Demucs + Vctk with stride + augmentations
diff --git a/enhancer/cli/train_config/model/DCCRN.yaml b/enhancer/cli/train_config/model/DCCRN.yaml
new file mode 100644
index 0000000..3190391
--- /dev/null
+++ b/enhancer/cli/train_config/model/DCCRN.yaml
@@ -0,0 +1,25 @@
+_target_: enhancer.models.dccrn.DCCRN
+num_channels: 1
+sampling_rate : 16000
+complex_lstm : True
+complex_norm : True
+complex_relu : True
+masking_mode : E
+
+encoder_decoder:
+ initial_output_channels : 32
+ depth : 6
+ kernel_size : 5
+ growth_factor : 2
+ stride : 2
+ padding : 2
+ output_padding : 1
+
+lstm:
+ num_layers : 2
+ hidden_size : 256
+
+stft:
+ window_len : 400
+ hop_size : 100
+ nfft : 512
diff --git a/enhancer/cli/train_config/model/Demucs.yaml b/enhancer/cli/train_config/model/Demucs.yaml
new file mode 100644
index 0000000..513e603
--- /dev/null
+++ b/enhancer/cli/train_config/model/Demucs.yaml
@@ -0,0 +1,16 @@
+_target_: enhancer.models.demucs.Demucs
+num_channels: 1
+resample: 4
+sampling_rate : 16000
+
+encoder_decoder:
+ depth: 4
+ initial_output_channels: 64
+ kernel_size: 8
+ stride: 4
+ growth_factor: 2
+ glu: True
+
+lstm:
+ bidirectional: False
+ num_layers: 2
diff --git a/enhancer/cli/train_config/model/WaveUnet.yaml b/enhancer/cli/train_config/model/WaveUnet.yaml
new file mode 100644
index 0000000..29d48c7
--- /dev/null
+++ b/enhancer/cli/train_config/model/WaveUnet.yaml
@@ -0,0 +1,5 @@
+_target_: enhancer.models.waveunet.WaveUnet
+num_channels : 1
+depth : 9
+initial_output_channels: 24
+sampling_rate : 16000
diff --git a/enhancer/cli/train_config/optimizer/Adam.yaml b/enhancer/cli/train_config/optimizer/Adam.yaml
new file mode 100644
index 0000000..7952b81
--- /dev/null
+++ b/enhancer/cli/train_config/optimizer/Adam.yaml
@@ -0,0 +1,6 @@
+_target_: torch.optim.Adam
+lr: 1e-3
+betas: [0.9, 0.999]
+eps: 1e-08
+weight_decay: 0
+amsgrad: False
diff --git a/enhancer/cli/train_config/trainer/default.yaml b/enhancer/cli/train_config/trainer/default.yaml
new file mode 100644
index 0000000..958c418
--- /dev/null
+++ b/enhancer/cli/train_config/trainer/default.yaml
@@ -0,0 +1,46 @@
+_target_: pytorch_lightning.Trainer
+accelerator: gpu
+accumulate_grad_batches: 1
+amp_backend: native
+auto_lr_find: True
+auto_scale_batch_size: False
+auto_select_gpus: True
+benchmark: False
+check_val_every_n_epoch: 1
+detect_anomaly: False
+deterministic: False
+devices: 2
+enable_checkpointing: True
+enable_model_summary: True
+enable_progress_bar: True
+fast_dev_run: False
+gpus: null
+gradient_clip_val: 0
+gradient_clip_algorithm: norm
+ipus: null
+limit_predict_batches: 1.0
+limit_test_batches: 1.0
+limit_train_batches: 1.0
+limit_val_batches: 1.0
+log_every_n_steps: 50
+max_epochs: 200
+max_steps: -1
+max_time: null
+min_epochs: 1
+min_steps: null
+move_metrics_to_cpu: False
+multiple_trainloader_mode: max_size_cycle
+num_nodes: 1
+num_processes: 1
+num_sanity_val_steps: 2
+overfit_batches: 0.0
+precision: 32
+profiler: null
+reload_dataloaders_every_n_epochs: 0
+replace_sampler_ddp: True
+strategy: ddp
+sync_batchnorm: False
+tpu_cores: null
+track_grad_norm: -1
+val_check_interval: 1.0
+weights_save_path: null
diff --git a/enhancer/cli/train_config/trainer/fastrun_dev.yaml b/enhancer/cli/train_config/trainer/fastrun_dev.yaml
new file mode 100644
index 0000000..682149e
--- /dev/null
+++ b/enhancer/cli/train_config/trainer/fastrun_dev.yaml
@@ -0,0 +1,2 @@
+_target_: pytorch_lightning.Trainer
+fast_dev_run: True
diff --git a/enhancer/data/__init__.py b/enhancer/data/__init__.py
new file mode 100644
index 0000000..7efd946
--- /dev/null
+++ b/enhancer/data/__init__.py
@@ -0,0 +1 @@
+from enhancer.data.dataset import EnhancerDataset
diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py
new file mode 100644
index 0000000..284dfdb
--- /dev/null
+++ b/enhancer/data/dataset.py
@@ -0,0 +1,376 @@
+import math
+import multiprocessing
+import os
+from pathlib import Path
+from typing import Optional
+
+import numpy as np
+import pytorch_lightning as pl
+import torch
+import torch.nn.functional as F
+from torch.utils.data import DataLoader, Dataset, RandomSampler
+from torch_audiomentations import Compose
+
+from enhancer.data.fileprocessor import Fileprocessor
+from enhancer.utils import check_files
+from enhancer.utils.config import Files
+from enhancer.utils.io import Audio
+from enhancer.utils.random import create_unique_rng
+
+LARGE_NUM = 2147483647
+
+
+class TrainDataset(Dataset):
+ def __init__(self, dataset):
+ self.dataset = dataset
+
+ def __getitem__(self, idx):
+ return self.dataset.train__getitem__(idx)
+
+ def __len__(self):
+ return self.dataset.train__len__()
+
+
+class ValidDataset(Dataset):
+ def __init__(self, dataset):
+ self.dataset = dataset
+
+ def __getitem__(self, idx):
+ return self.dataset.val__getitem__(idx)
+
+ def __len__(self):
+ return self.dataset.val__len__()
+
+
+class TestDataset(Dataset):
+ def __init__(self, dataset):
+ self.dataset = dataset
+
+ def __getitem__(self, idx):
+ return self.dataset.test__getitem__(idx)
+
+ def __len__(self):
+ return self.dataset.test__len__()
+
+
+class TaskDataset(pl.LightningDataModule):
+ def __init__(
+ self,
+ name: str,
+ root_dir: str,
+ files: Files,
+ min_valid_minutes: float = 0.20,
+ duration: float = 1.0,
+ stride=None,
+ sampling_rate: int = 48000,
+ matching_function=None,
+ batch_size=32,
+ num_workers: Optional[int] = None,
+ augmentations: Optional[Compose] = None,
+ ):
+ super().__init__()
+
+ self.name = name
+ self.files, self.root_dir = check_files(root_dir, files)
+ self.duration = duration
+ self.stride = stride or duration
+ self.sampling_rate = sampling_rate
+ self.batch_size = batch_size
+ self.matching_function = matching_function
+ self._validation = []
+ if num_workers is None:
+ num_workers = multiprocessing.cpu_count() // 2
+ self.num_workers = num_workers
+ if min_valid_minutes > 0.0:
+ self.min_valid_minutes = min_valid_minutes
+ else:
+ raise ValueError("min_valid_minutes must be greater than 0")
+
+ self.augmentations = augmentations
+
+ def setup(self, stage: Optional[str] = None):
+ """
+ prepare train/validation/test data splits
+ """
+
+ if stage in ("fit", None):
+
+ train_clean = os.path.join(self.root_dir, self.files.train_clean)
+ train_noisy = os.path.join(self.root_dir, self.files.train_noisy)
+ fp = Fileprocessor.from_name(
+ self.name, train_clean, train_noisy, self.matching_function
+ )
+ train_data = fp.prepare_matching_dict()
+ train_data, self.val_data = self.train_valid_split(
+ train_data,
+ min_valid_minutes=self.min_valid_minutes,
+ random_state=42,
+ )
+
+ self.train_data = self.prepare_traindata(train_data)
+ self._validation = self.prepare_mapstype(self.val_data)
+
+ test_clean = os.path.join(self.root_dir, self.files.test_clean)
+ test_noisy = os.path.join(self.root_dir, self.files.test_noisy)
+ fp = Fileprocessor.from_name(
+ self.name, test_clean, test_noisy, self.matching_function
+ )
+ test_data = fp.prepare_matching_dict()
+ self._test = self.prepare_mapstype(test_data)
+
+ def train_valid_split(
+ self, data, min_valid_minutes: float = 20, random_state: int = 42
+ ):
+
+ min_valid_minutes *= 60
+ valid_sec_now = 0.0
+ valid_indices = []
+ all_speakers = np.unique(
+ [Path(file["clean"]).name.split("_")[0] for file in data]
+ )
+ possible_indices = list(range(0, len(all_speakers)))
+ rng = create_unique_rng(len(all_speakers))
+
+ while valid_sec_now <= min_valid_minutes:
+ speaker_index = rng.choice(possible_indices)
+ possible_indices.remove(speaker_index)
+ speaker_name = all_speakers[speaker_index]
+ print(f"Selected f{speaker_name} for valid")
+ file_indices = [
+ i
+ for i, file in enumerate(data)
+ if speaker_name == Path(file["clean"]).name.split("_")[0]
+ ]
+ for i in file_indices:
+ valid_indices.append(i)
+ valid_sec_now += data[i]["duration"]
+
+ train_data = [
+ item for i, item in enumerate(data) if i not in valid_indices
+ ]
+ valid_data = [item for i, item in enumerate(data) if i in valid_indices]
+ return train_data, valid_data
+
+ def prepare_traindata(self, data):
+ train_data = []
+ for item in data:
+ clean, noisy, total_dur = item.values()
+ num_segments = self.get_num_segments(
+ total_dur, self.duration, self.stride
+ )
+ samples_metadata = ({"clean": clean, "noisy": noisy}, num_segments)
+ train_data.append(samples_metadata)
+ return train_data
+
+ @staticmethod
+ def get_num_segments(file_duration, duration, stride):
+
+ if file_duration < duration:
+ num_segments = 1
+ else:
+ num_segments = math.ceil((file_duration - duration) / stride) + 1
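+            # e.g. a 10 s file with duration=4.5 s and stride=2 s yields
+            # ceil((10 - 4.5) / 2) + 1 = 4 overlapping segments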
+
+ return num_segments
+
+ def prepare_mapstype(self, data):
+
+ metadata = []
+ for item in data:
+ clean, noisy, total_dur = item.values()
+ if total_dur < self.duration:
+ metadata.append(({"clean": clean, "noisy": noisy}, 0.0))
+ else:
+ num_segments = self.get_num_segments(
+ total_dur, self.duration, self.duration
+ )
+ for index in range(num_segments):
+ start_time = index * self.duration
+ metadata.append(
+ ({"clean": clean, "noisy": noisy}, start_time)
+ )
+ return metadata
+
+ def train_collatefn(self, batch):
+
+ output = {"clean": [], "noisy": []}
+ for item in batch:
+ output["clean"].append(item["clean"])
+ output["noisy"].append(item["noisy"])
+
+ output["clean"] = torch.stack(output["clean"], dim=0)
+ output["noisy"] = torch.stack(output["noisy"], dim=0)
+
+ if self.augmentations is not None:
+ noise = output["noisy"] - output["clean"]
+ output["clean"] = self.augmentations(
+ output["clean"], sample_rate=self.sampling_rate
+ )
+ self.augmentations.freeze_parameters()
+ output["noisy"] = (
+ self.augmentations(noise, sample_rate=self.sampling_rate)
+ + output["clean"]
+ )
+
+ return output
+
+ @property
+ def generator(self):
+ generator = torch.Generator()
+ if hasattr(self, "model"):
+ seed = self.model.current_epoch + LARGE_NUM
+ else:
+ seed = LARGE_NUM
+ return generator.manual_seed(seed)
+
+ def train_dataloader(self):
+ dataset = TrainDataset(self)
+ sampler = RandomSampler(dataset, generator=self.generator)
+ return DataLoader(
+ dataset,
+ batch_size=self.batch_size,
+ num_workers=self.num_workers,
+ sampler=sampler,
+ collate_fn=self.train_collatefn,
+ )
+
+ def val_dataloader(self):
+ return DataLoader(
+ ValidDataset(self),
+ batch_size=self.batch_size,
+ num_workers=self.num_workers,
+ )
+
+ def test_dataloader(self):
+ return DataLoader(
+ TestDataset(self),
+ batch_size=self.batch_size,
+ num_workers=self.num_workers,
+ )
+
+
+class EnhancerDataset(TaskDataset):
+ """
+ Dataset object for creating clean-noisy speech enhancement datasets
+    parameters:
+    name : str
+        name of the dataset
+    root_dir : str
+        root directory of the dataset containing clean/noisy folders
+    files : Files
+        dataclass containing train_clean, train_noisy, test_clean, test_noisy
+        folder names (refer enhancer.utils.Files dataclass)
+    min_valid_minutes: float
+        minimum validation split size, in minutes.
+        The algorithm randomly selects speakers from the training data until at
+        least min_valid_minutes of audio is set aside for validation.
+    duration : float
+        expected audio duration of a single training sample
+    stride : float
+        stride (in seconds) between consecutive training segments; defaults to duration
+    sampling_rate : int
+        desired sampling rate
+    batch_size : int
+        batch size of each batch
+    num_workers : int
+        number of dataloader workers to be used while training
+    matching_function : str
+        matching function - one of (one_to_one, one_to_many). Default is None.
+        use one_to_one mapping for datasets with one noisy file for each clean file
+        use one_to_many mapping for datasets with multiple noisy files for each clean file
+    augmentations : Compose
+        optional torch_audiomentations pipeline applied to training batches
+    """
+
+ def __init__(
+ self,
+ name: str,
+ root_dir: str,
+ files: Files,
+ min_valid_minutes=5.0,
+ duration=1.0,
+ stride=None,
+ sampling_rate=48000,
+ matching_function=None,
+ batch_size=32,
+ num_workers: Optional[int] = None,
+ augmentations: Optional[Compose] = None,
+ ):
+
+ super().__init__(
+ name=name,
+ root_dir=root_dir,
+ files=files,
+ min_valid_minutes=min_valid_minutes,
+ sampling_rate=sampling_rate,
+ duration=duration,
+ matching_function=matching_function,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ augmentations=augmentations,
+ )
+
+ self.sampling_rate = sampling_rate
+ self.files = files
+ self.duration = max(1.0, duration)
+ self.audio = Audio(self.sampling_rate, mono=True, return_tensor=True)
+ self.stride = stride or duration
+
+ def setup(self, stage: Optional[str] = None):
+
+ super().setup(stage=stage)
+
+ def train__getitem__(self, idx):
+
+ for filedict, num_samples in self.train_data:
+ if idx >= num_samples:
+ idx -= num_samples
+ continue
+ else:
+ start = 0
+ if self.duration is not None:
+ start = idx * self.stride
+ return self.prepare_segment(filedict, start)
+
+ def val__getitem__(self, idx):
+ return self.prepare_segment(*self._validation[idx])
+
+ def test__getitem__(self, idx):
+ return self.prepare_segment(*self._test[idx])
+
+ def prepare_segment(self, file_dict: dict, start_time: float):
+ clean_segment = self.audio(
+ file_dict["clean"], offset=start_time, duration=self.duration
+ )
+ noisy_segment = self.audio(
+ file_dict["noisy"], offset=start_time, duration=self.duration
+ )
+ clean_segment = F.pad(
+ clean_segment,
+ (
+ 0,
+ int(
+ self.duration * self.sampling_rate - clean_segment.shape[-1]
+ ),
+ ),
+ )
+ noisy_segment = F.pad(
+ noisy_segment,
+ (
+ 0,
+ int(
+ self.duration * self.sampling_rate - noisy_segment.shape[-1]
+ ),
+ ),
+ )
+ return {
+ "clean": clean_segment,
+ "noisy": noisy_segment,
+ }
+
+ def train__len__(self):
+ _, num_examples = list(zip(*self.train_data))
+ return sum(num_examples)
+
+ def val__len__(self):
+ return len(self._validation)
+
+ def test__len__(self):
+ return len(self._test)
diff --git a/enhancer/data/fileprocessor.py b/enhancer/data/fileprocessor.py
new file mode 100644
index 0000000..5b099d4
--- /dev/null
+++ b/enhancer/data/fileprocessor.py
@@ -0,0 +1,121 @@
+import glob
+import os
+
+import numpy as np
+from scipy.io import wavfile
+
+MATCHING_FNS = ("one_to_one", "one_to_many")
+
+
+class ProcessorFunctions:
+ """
+    Preprocessing methods for different types of speech enhancement datasets.
+ """
+
+ @staticmethod
+ def one_to_one(clean_path, noisy_path):
+ """
+ One clean audio can have only one noisy audio file
+ """
+
+ matching_wavfiles = list()
+ clean_filenames = [
+ file.split("/")[-1]
+ for file in glob.glob(os.path.join(clean_path, "*.wav"))
+ ]
+ noisy_filenames = [
+ file.split("/")[-1]
+ for file in glob.glob(os.path.join(noisy_path, "*.wav"))
+ ]
+ common_filenames = np.intersect1d(noisy_filenames, clean_filenames)
+
+ for file_name in common_filenames:
+
+ sr_clean, clean_file = wavfile.read(
+ os.path.join(clean_path, file_name)
+ )
+ sr_noisy, noisy_file = wavfile.read(
+ os.path.join(noisy_path, file_name)
+ )
+ if (clean_file.shape[-1] == noisy_file.shape[-1]) and (
+ sr_clean == sr_noisy
+ ):
+ matching_wavfiles.append(
+ {
+ "clean": os.path.join(clean_path, file_name),
+ "noisy": os.path.join(noisy_path, file_name),
+ "duration": clean_file.shape[-1] / sr_clean,
+ }
+ )
+ return matching_wavfiles
+
+ @staticmethod
+ def one_to_many(clean_path, noisy_path):
+ """
+        One clean audio file can have multiple noisy audio files
+ """
+
+ matching_wavfiles = list()
+ clean_filenames = [
+ file.split("/")[-1]
+ for file in glob.glob(os.path.join(clean_path, "*.wav"))
+ ]
+ for clean_file in clean_filenames:
+ noisy_filenames = glob.glob(
+ os.path.join(noisy_path, f"*_{clean_file}")
+ )
+ for noisy_file in noisy_filenames:
+
+ sr_clean, clean_wav = wavfile.read(
+ os.path.join(clean_path, clean_file)
+ )
+ sr_noisy, noisy_wav = wavfile.read(noisy_file)
+ if (clean_wav.shape[-1] == noisy_wav.shape[-1]) and (
+ sr_clean == sr_noisy
+ ):
+ matching_wavfiles.append(
+ {
+ "clean": os.path.join(clean_path, clean_file),
+ "noisy": noisy_file,
+ "duration": clean_wav.shape[-1] / sr_clean,
+ }
+ )
+ return matching_wavfiles
+
+
+class Fileprocessor:
+ def __init__(self, clean_dir, noisy_dir, matching_function=None):
+ self.clean_dir = clean_dir
+ self.noisy_dir = noisy_dir
+ self.matching_function = matching_function
+
+ @classmethod
+ def from_name(cls, name: str, clean_dir, noisy_dir, matching_function=None):
+
+ if matching_function is None:
+ if name.lower() == "vctk":
+ return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_one)
+ elif name.lower() == "dns-2020":
+ return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_many)
+ else:
+ raise ValueError(
+ f"Invalid matching function, Please use valid matching function from {MATCHING_FNS}"
+ )
+ else:
+ if matching_function not in MATCHING_FNS:
+ raise ValueError(
+ f"Invalid matching function! Avaialble options are {MATCHING_FNS}"
+ )
+ else:
+ return cls(
+ clean_dir,
+ noisy_dir,
+ getattr(ProcessorFunctions, matching_function),
+ )
+
+ def prepare_matching_dict(self):
+
+ if self.matching_function is None:
+ raise ValueError("Not a valid matching function")
+
+ return self.matching_function(self.clean_dir, self.noisy_dir)
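+
+
+# Usage sketch (paths are placeholders): EnhancerDataset.setup() builds a
+# Fileprocessor per split and calls prepare_matching_dict(), which returns a
+# list of {"clean": ..., "noisy": ..., "duration": ...} entries:
+#
+#   fp = Fileprocessor.from_name(
+#       "vctk", "/data/clean_trainset_28spk_wav", "/data/noisy_trainset_28spk_wav"
+#   )
+#   matching = fp.prepare_matching_dict()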
diff --git a/enhancer/inference.py b/enhancer/inference.py
new file mode 100644
index 0000000..d9282fd
--- /dev/null
+++ b/enhancer/inference.py
@@ -0,0 +1,170 @@
+from pathlib import Path
+from typing import Optional, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from librosa import load as load_audio
+from scipy.io import wavfile
+from scipy.signal import get_window
+
+from enhancer.utils import Audio
+
+
+class Inference:
+ """
+ contains methods used for inference.
+ """
+
+ @staticmethod
+ def read_input(audio, sr, model_sr):
+ """
+ read and verify audio input regardless of the input format.
+ arguments:
+ audio : audio input
+ sr : sampling rate of input audio
+ model_sr : sampling rate used for model training.
+ """
+
+ if isinstance(audio, (np.ndarray, torch.Tensor)):
+ assert sr is not None, "Invalid sampling rate!"
+ if len(audio.shape) == 1:
+ audio = audio.reshape(1, -1)
+
+ if isinstance(audio, str):
+ audio = Path(audio)
+ if not audio.is_file():
+ raise ValueError(f"Input file {audio} does not exist")
+ else:
+ audio, sr = load_audio(
+ audio,
+ sr=sr,
+ )
+ if len(audio.shape) == 1:
+ audio = audio.reshape(1, -1)
+ else:
+ assert (
+ audio.shape[0] == 1
+ ), "Enhance inference only supports single waveform"
+
+ waveform = Audio.resample_audio(audio, sr=sr, target_sr=model_sr)
+ waveform = Audio.convert_mono(waveform)
+ if isinstance(waveform, np.ndarray):
+ waveform = torch.from_numpy(waveform)
+
+ return waveform
+
+ @staticmethod
+ def batchify(
+ waveform: torch.Tensor,
+ window_size: int,
+ step_size: Optional[int] = None,
+ ):
+ """
+        break input waveform into overlapping chunks of the given window size (overlap-add analysis)
+ arguments:
+ waveform : audio waveform
+ window_size : window size used for splitting waveform into batches
+ step_size : step_size used for splitting waveform into batches
+ """
+ assert (
+ waveform.ndim == 2
+ ), f"Expcted input waveform with 2 dimensions (channels,samples), got {waveform.ndim}"
+ _, num_samples = waveform.shape
+ waveform = waveform.unsqueeze(-1)
+ step_size = window_size // 2 if step_size is None else step_size
+
+ if num_samples >= window_size:
+ waveform_batch = F.unfold(
+ waveform[None, ...],
+ kernel_size=(window_size, 1),
+ stride=(step_size, 1),
+ padding=(window_size, 0),
+ )
+ waveform_batch = waveform_batch.permute(2, 0, 1)
+
+ return waveform_batch
+
+ @staticmethod
+ def aggreagate(
+ data: torch.Tensor,
+ window_size: int,
+ total_frames: int,
+ step_size: Optional[int] = None,
+ window="hamming",
+ ):
+ """
+ stitch batched waveform into single waveform. (Overlap-add)
+ arguments:
+ data: batched waveform
+ window_size : window_size used to batch waveform
+ step_size : step_size used to batch waveform
+ total_frames : total number of frames present in original waveform
+ window : type of window used for overlap-add mechanism.
+ """
+ num_chunks, n_channels, num_frames = data.shape
+ window = get_window(window=window, Nx=data.shape[-1])
+ window = torch.from_numpy(window).to(data.device)
+ data *= window
+ step_size = window_size // 2 if step_size is None else step_size
+
+ data = data.permute(1, 2, 0)
+ data = F.fold(
+ data,
+ (total_frames, 1),
+ kernel_size=(window_size, 1),
+ stride=(step_size, 1),
+ padding=(window_size, 0),
+ ).squeeze(-1)
+
+ return data.reshape(1, n_channels, -1)
+
+ @staticmethod
+ def write_output(
+ waveform: torch.Tensor, filename: Union[str, Path], sr: int
+ ):
+ """
+ write audio output as wav file
+ arguments:
+ waveform : audio waveform
+ filename : name of the wave file. Output will be written as cleaned_filename.wav
+ sr : sampling rate
+ """
+
+ if isinstance(filename, str):
+ filename = Path(filename)
+
+ parent, name = filename.parent, "cleaned_" + filename.name
+ filename = parent / Path(name)
+ if filename.is_file():
+ raise FileExistsError(f"file {filename} already exists")
+ else:
+ wavfile.write(
+ filename, rate=sr, data=waveform.detach().cpu().numpy()
+ )
+
+ @staticmethod
+ def prepare_output(
+ waveform: torch.Tensor,
+ model_sampling_rate: int,
+ audio: Union[str, np.ndarray, torch.Tensor],
+ sampling_rate: Optional[int],
+ ):
+ """
+ prepare output audio based on input format
+ arguments:
+ waveform : predicted audio waveform
+ model_sampling_rate : sampling rate used to train the model
+ audio : input audio
+ sampling_rate : input audio sampling rate
+
+ """
+ if isinstance(audio, np.ndarray):
+ waveform = waveform.detach().cpu().numpy()
+
+ if sampling_rate is not None:
+ waveform = Audio.resample_audio(
+ waveform, sr=model_sampling_rate, target_sr=sampling_rate
+ )
+
+ return waveform
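+
+
+# Usage sketch: batchify/aggreagate implement overlap-add framing for inference
+# on long files; window_size and step_size are in samples, e.g. for a 16 kHz model:
+#
+#   chunks = Inference.batchify(waveform, window_size=16000, step_size=8000)
+#   merged = Inference.aggreagate(chunks, window_size=16000,
+#                                 total_frames=waveform.shape[-1], step_size=8000)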
diff --git a/enhancer/loss.py b/enhancer/loss.py
new file mode 100644
index 0000000..75527bb
--- /dev/null
+++ b/enhancer/loss.py
@@ -0,0 +1,216 @@
+import logging
+
+import numpy as np
+import torch
+import torch.nn as nn
+from torchmetrics import ScaleInvariantSignalNoiseRatio
+from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
+from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility
+
+
+class mean_squared_error(nn.Module):
+ """
+    Mean squared error (L2) loss
+ """
+
+ def __init__(self, reduction="mean"):
+ super().__init__()
+
+ self.loss_fun = nn.MSELoss(reduction=reduction)
+ self.higher_better = False
+ self.name = "mse"
+
+ def forward(self, prediction: torch.Tensor, target: torch.Tensor):
+
+ if prediction.size() != target.size() or target.ndim < 3:
+ raise TypeError(
+ f"""Inputs must be of the same shape (batch_size,channels,samples)
+ got {prediction.size()} and {target.size()} instead"""
+ )
+
+ return self.loss_fun(prediction, target)
+
+
+class mean_absolute_error(nn.Module):
+ """
+    Mean absolute error (L1) loss
+ """
+
+ def __init__(self, reduction="mean"):
+ super().__init__()
+
+ self.loss_fun = nn.L1Loss(reduction=reduction)
+ self.higher_better = False
+ self.name = "mae"
+
+ def forward(self, prediction: torch.Tensor, target: torch.Tensor):
+
+ if prediction.size() != target.size() or target.ndim < 3:
+ raise TypeError(
+ f"""Inputs must be of the same shape (batch_size,channels,samples)
+ got {prediction.size()} and {target.size()} instead"""
+ )
+
+ return self.loss_fun(prediction, target)
+
+
+class Si_SDR:
+ """
+ SI-SDR metric based on SDR – HALF-BAKED OR WELL DONE?(https://arxiv.org/pdf/1811.02508.pdf)
+ """
+
+ def __init__(self, reduction: str = "mean"):
+ if reduction in ["sum", "mean", None]:
+ self.reduction = reduction
+ else:
+ raise TypeError(
+ "Invalid reduction, valid options are sum, mean, None"
+ )
+ self.higher_better = True
+ self.name = "si-sdr"
+
+ def __call__(self, prediction: torch.Tensor, target: torch.Tensor):
+
+ if prediction.size() != target.size() or target.ndim < 3:
+ raise TypeError(
+ f"""Inputs must be of the same shape (batch_size,channels,samples)
+ got {prediction.size()} and {target.size()} instead"""
+ )
+
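+        # SI-SDR = 10 * log10(||s_target||^2 / ||e_noise||^2), where
+        # s_target = (<prediction, target> / ||target||^2) * target
+        # and e_noise = prediction - s_target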
+ target_energy = torch.sum(target**2, keepdim=True, dim=-1)
+ scaling_factor = (
+ torch.sum(prediction * target, keepdim=True, dim=-1) / target_energy
+ )
+ target_projection = target * scaling_factor
+ noise = prediction - target_projection
+ ratio = torch.sum(target_projection**2, dim=-1) / torch.sum(
+ noise**2, dim=-1
+ )
+ si_sdr = 10 * torch.log10(ratio).mean(dim=-1)
+
+ if self.reduction == "sum":
+ si_sdr = si_sdr.sum()
+ elif self.reduction == "mean":
+ si_sdr = si_sdr.mean()
+ else:
+ pass
+
+ return si_sdr
+
+
+class Stoi:
+ """
+    STOI (Short-Time Objective Intelligibility), computed via torchmetrics' ShortTimeObjectiveIntelligibility (which wraps the pystoi package).
+ Note that input will be moved to cpu to perform the metric calculation.
+ parameters:
+ sr: int
+ sampling rate
+ """
+
+ def __init__(self, sr: int):
+ self.sr = sr
+ self.stoi = ShortTimeObjectiveIntelligibility(fs=sr)
+ self.name = "stoi"
+
+ def __call__(self, prediction: torch.Tensor, target: torch.Tensor):
+
+ return self.stoi(prediction, target)
+
+
+class Pesq:
+ def __init__(self, sr: int, mode="wb"):
+
+ self.sr = sr
+ self.name = "pesq"
+ self.mode = mode
+ self.pesq = PerceptualEvaluationSpeechQuality(
+ fs=self.sr, mode=self.mode
+ )
+
+ def __call__(self, prediction: torch.Tensor, target: torch.Tensor):
+
+ pesq_values = []
+ for pred, target_ in zip(prediction, target):
+ try:
+ pesq_values.append(self.pesq(pred.squeeze(), target_.squeeze()))
+ except Exception as e:
+ logging.warning(f"{e} error occured while calculating PESQ")
+ return torch.tensor(np.mean(pesq_values))
+
+
+class LossWrapper(nn.Module):
+ """
+    Combine multiple metrics of the same nature,
+    for example ["mse", "mae"].
+ parameters:
+ losses : loss function names to be combined
+ """
+
+ def __init__(self, losses):
+ super().__init__()
+
+ self.valid_losses = nn.ModuleList()
+
+ direction = [
+ getattr(LOSS_MAP[loss](), "higher_better") for loss in losses
+ ]
+ if len(set(direction)) > 1:
+ raise ValueError(
+ "all cost functions should be of same nature, maximize or minimize!"
+ )
+
+ self.higher_better = direction[0]
+ self.name = ""
+ for loss in losses:
+ loss = self.validate_loss(loss)
+ self.valid_losses.append(loss())
+ self.name += f"{loss().name}_"
+
+ def validate_loss(self, loss: str):
+ if loss not in LOSS_MAP.keys():
+ raise ValueError(
+ f"""Invalid loss function {loss}, available loss functions are
+ {tuple([loss for loss in LOSS_MAP.keys()])}"""
+ )
+ else:
+ return LOSS_MAP[loss]
+
+ def forward(self, prediction: torch.Tensor, target: torch.Tensor):
+ loss = 0.0
+ for loss_fun in self.valid_losses:
+ loss += loss_fun(prediction, target)
+
+ return loss
+
+
+class Si_snr(nn.Module):
+ """
+ SI-SNR
+ """
+
+ def __init__(self, **kwargs):
+ super().__init__()
+
+ self.loss_fun = ScaleInvariantSignalNoiseRatio(**kwargs)
+ self.higher_better = True
+ self.name = "si_snr"
+
+ def forward(self, prediction: torch.Tensor, target: torch.Tensor):
+
+ if prediction.size() != target.size() or target.ndim < 3:
+ raise TypeError(
+ f"""Inputs must be of the same shape (batch_size,channels,samples)
+ got {prediction.size()} and {target.size()} instead"""
+ )
+
+ return self.loss_fun(prediction, target)
+
+
+LOSS_MAP = {
+ "mae": mean_absolute_error,
+ "mse": mean_squared_error,
+ "si-sdr": Si_SDR,
+ "pesq": Pesq,
+ "stoi": Stoi,
+ "si-snr": Si_snr,
+}
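+
+# Usage sketch: LossWrapper(["mse", "mae"]) sums the two losses per batch;
+# mixing losses to be minimized with metrics to be maximized
+# (e.g. "mae" with "si-snr") raises a ValueError.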
diff --git a/enhancer/models/__init__.py b/enhancer/models/__init__.py
new file mode 100644
index 0000000..2d97568
--- /dev/null
+++ b/enhancer/models/__init__.py
@@ -0,0 +1,3 @@
+from enhancer.models.demucs import Demucs
+from enhancer.models.model import Model
+from enhancer.models.waveunet import WaveUnet
diff --git a/enhancer/models/complexnn/__init__.py b/enhancer/models/complexnn/__init__.py
new file mode 100644
index 0000000..918a261
--- /dev/null
+++ b/enhancer/models/complexnn/__init__.py
@@ -0,0 +1,5 @@
+from enhancer.models.complexnn.conv import ComplexConv2d # noqa
+from enhancer.models.complexnn.conv import ComplexConvTranspose2d # noqa
+from enhancer.models.complexnn.rnn import ComplexLSTM # noqa
+from enhancer.models.complexnn.utils import ComplexBatchNorm2D # noqa
+from enhancer.models.complexnn.utils import ComplexRelu # noqa
diff --git a/enhancer/models/complexnn/conv.py b/enhancer/models/complexnn/conv.py
new file mode 100644
index 0000000..d9a4d0f
--- /dev/null
+++ b/enhancer/models/complexnn/conv.py
@@ -0,0 +1,136 @@
+from typing import Tuple
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+def init_weights(nnet):
+ nn.init.xavier_normal_(nnet.weight.data)
+ nn.init.constant_(nnet.bias, 0.0)
+ return nnet
+
+
+class ComplexConv2d(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: Tuple[int, int] = (1, 1),
+ stride: Tuple[int, int] = (1, 1),
+ padding: Tuple[int, int] = (0, 0),
+ groups: int = 1,
+ dilation: int = 1,
+ ):
+ """
+ Complex Conv2d (non-causal)
+ """
+ super().__init__()
+ self.in_channels = in_channels // 2
+ self.out_channels = out_channels // 2
+ self.kernel_size = kernel_size
+ self.stride = stride
+ self.padding = padding
+ self.groups = groups
+ self.dilation = dilation
+
+ self.real_conv = nn.Conv2d(
+ self.in_channels,
+ self.out_channels,
+ kernel_size=self.kernel_size,
+ stride=self.stride,
+ padding=(self.padding[0], 0),
+ groups=self.groups,
+ dilation=self.dilation,
+ )
+ self.imag_conv = nn.Conv2d(
+ self.in_channels,
+ self.out_channels,
+ kernel_size=self.kernel_size,
+ stride=self.stride,
+ padding=(self.padding[0], 0),
+ groups=self.groups,
+ dilation=self.dilation,
+ )
+ self.imag_conv = init_weights(self.imag_conv)
+ self.real_conv = init_weights(self.real_conv)
+
+ def forward(self, input):
+ """
+        the real/imaginary halves are always split along dim 1
+ """
+ input = F.pad(input, [self.padding[1], 0, 0, 0])
+
+ real, imag = torch.chunk(input, 2, 1)
+
+ real_real = self.real_conv(real)
+ real_imag = self.imag_conv(real)
+
+ imag_imag = self.imag_conv(imag)
+ imag_real = self.real_conv(imag)
+
+ real = real_real - imag_imag
+        imag = real_imag + imag_real
+
+ out = torch.cat([real, imag], 1)
+ return out
+
+
+class ComplexConvTranspose2d(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: Tuple[int, int] = (1, 1),
+ stride: Tuple[int, int] = (1, 1),
+ padding: Tuple[int, int] = (0, 0),
+ output_padding: Tuple[int, int] = (0, 0),
+ groups: int = 1,
+ ):
+ super().__init__()
+ self.in_channels = in_channels // 2
+ self.out_channels = out_channels // 2
+ self.kernel_size = kernel_size
+ self.stride = stride
+ self.padding = padding
+ self.groups = groups
+ self.output_padding = output_padding
+
+ self.real_conv = nn.ConvTranspose2d(
+ self.in_channels,
+ self.out_channels,
+ kernel_size=self.kernel_size,
+ stride=self.stride,
+ padding=self.padding,
+ output_padding=self.output_padding,
+ groups=self.groups,
+ )
+
+ self.imag_conv = nn.ConvTranspose2d(
+ self.in_channels,
+ self.out_channels,
+ kernel_size=self.kernel_size,
+ stride=self.stride,
+ padding=self.padding,
+ output_padding=self.output_padding,
+ groups=self.groups,
+ )
+
+ self.real_conv = init_weights(self.real_conv)
+ self.imag_conv = init_weights(self.imag_conv)
+
+ def forward(self, input):
+
+ real, imag = torch.chunk(input, 2, 1)
+ real_real = self.real_conv(real)
+ real_imag = self.imag_conv(real)
+
+ imag_imag = self.imag_conv(imag)
+ imag_real = self.real_conv(imag)
+
+ real = real_real - imag_imag
+        imag = real_imag + imag_real
+
+ out = torch.cat([real, imag], 1)
+
+ return out
diff --git a/enhancer/models/complexnn/rnn.py b/enhancer/models/complexnn/rnn.py
new file mode 100644
index 0000000..847030b
--- /dev/null
+++ b/enhancer/models/complexnn/rnn.py
@@ -0,0 +1,68 @@
+from typing import List, Optional
+
+import torch
+from torch import nn
+
+
+class ComplexLSTM(nn.Module):
+ def __init__(
+ self,
+ input_size: int,
+ hidden_size: int,
+ num_layers: int = 1,
+ projection_size: Optional[int] = None,
+ bidirectional: bool = False,
+ ):
+ super().__init__()
+ self.input_size = input_size // 2
+ self.hidden_size = hidden_size // 2
+ self.num_layers = num_layers
+
+ self.real_lstm = nn.LSTM(
+ self.input_size,
+ self.hidden_size,
+ self.num_layers,
+ bidirectional=bidirectional,
+ batch_first=False,
+ )
+ self.imag_lstm = nn.LSTM(
+ self.input_size,
+ self.hidden_size,
+ self.num_layers,
+ bidirectional=bidirectional,
+ batch_first=False,
+ )
+
+ bidirectional = 2 if bidirectional else 1
+ if projection_size is not None:
+ self.projection_size = projection_size // 2
+ self.real_linear = nn.Linear(
+ self.hidden_size * bidirectional, self.projection_size
+ )
+ self.imag_linear = nn.Linear(
+ self.hidden_size * bidirectional, self.projection_size
+ )
+ else:
+ self.projection_size = None
+
+ def forward(self, input):
+
+ if isinstance(input, List):
+ real, imag = input
+ else:
+ real, imag = torch.chunk(input, 2, 1)
+
+ real_real = self.real_lstm(real)[0]
+ real_imag = self.imag_lstm(real)[0]
+
+ imag_imag = self.imag_lstm(imag)[0]
+ imag_real = self.real_lstm(imag)[0]
+
+ real = real_real - imag_imag
+ imag = imag_real + real_imag
+
+ if self.projection_size is not None:
+ real = self.real_linear(real)
+ imag = self.imag_linear(imag)
+
+ return [real, imag]
diff --git a/enhancer/models/complexnn/utils.py b/enhancer/models/complexnn/utils.py
new file mode 100644
index 0000000..0c28f9b
--- /dev/null
+++ b/enhancer/models/complexnn/utils.py
@@ -0,0 +1,199 @@
+import torch
+from torch import nn
+
+
+class ComplexBatchNorm2D(nn.Module):
+ def __init__(
+ self,
+ num_features: int,
+ eps: float = 1e-5,
+ momentum: float = 0.1,
+ affine: bool = True,
+ track_running_stats: bool = True,
+ ):
+ """
+ Complex batch normalization 2D
+ https://arxiv.org/abs/1705.09792
+
+
+ """
+ super().__init__()
+ self.num_features = num_features // 2
+ self.affine = affine
+ self.momentum = momentum
+ self.track_running_stats = track_running_stats
+ self.eps = eps
+
+ if self.affine:
+ self.Wrr = nn.parameter.Parameter(torch.Tensor(self.num_features))
+ self.Wri = nn.parameter.Parameter(torch.Tensor(self.num_features))
+ self.Wii = nn.parameter.Parameter(torch.Tensor(self.num_features))
+ self.Br = nn.parameter.Parameter(torch.Tensor(self.num_features))
+ self.Bi = nn.parameter.Parameter(torch.Tensor(self.num_features))
+ else:
+ self.register_parameter("Wrr", None)
+ self.register_parameter("Wri", None)
+ self.register_parameter("Wii", None)
+ self.register_parameter("Br", None)
+ self.register_parameter("Bi", None)
+
+ if self.track_running_stats:
+            # use a separate tensor per buffer so the running stats do not alias
+            self.register_buffer("Mean_real", torch.zeros(self.num_features))
+            self.register_buffer("Mean_imag", torch.zeros(self.num_features))
+            self.register_buffer("Var_rr", torch.zeros(self.num_features))
+            self.register_buffer("Var_ri", torch.zeros(self.num_features))
+            self.register_buffer("Var_ii", torch.zeros(self.num_features))
+ self.register_buffer(
+ "num_batches_tracked", torch.tensor(0, dtype=torch.long)
+ )
+ else:
+ self.register_parameter("Mean_real", None)
+ self.register_parameter("Mean_imag", None)
+ self.register_parameter("Var_rr", None)
+ self.register_parameter("Var_ri", None)
+ self.register_parameter("Var_ii", None)
+ self.register_parameter("num_batches_tracked", None)
+
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ if self.affine:
+ self.Wrr.data.fill_(1)
+ self.Wii.data.fill_(1)
+ self.Wri.data.uniform_(-0.9, 0.9)
+ self.Br.data.fill_(0)
+ self.Bi.data.fill_(0)
+ self.reset_running_stats()
+
+ def reset_running_stats(self):
+ if self.track_running_stats:
+ self.Mean_real.zero_()
+ self.Mean_imag.zero_()
+ self.Var_rr.fill_(1)
+ self.Var_ri.zero_()
+ self.Var_ii.fill_(1)
+ self.num_batches_tracked.zero_()
+
+ def extra_repr(self):
+ return "{num_features}, eps={eps}, momentum={momentum}, affine={affine}, track_running_stats={track_running_stats}".format(
+ **self.__dict__
+ )
+
+ def forward(self, input):
+
+ real, imag = torch.chunk(input, 2, 1)
+ exp_avg_factor = 0.0
+
+ training = self.training and self.track_running_stats
+ if training:
+ self.num_batches_tracked += 1
+ if self.momentum is None:
+ exp_avg_factor = 1 / self.num_batches_tracked
+ else:
+ exp_avg_factor = self.momentum
+
+ redux = [i for i in reversed(range(real.dim())) if i != 1]
+ vdim = [1] * real.dim()
+ vdim[1] = real.size(1)
+
+ if training:
+ batch_mean_real, batch_mean_imag = real, imag
+ for dim in redux:
+ batch_mean_real = batch_mean_real.mean(dim, keepdim=True)
+ batch_mean_imag = batch_mean_imag.mean(dim, keepdim=True)
+ if self.track_running_stats:
+ self.Mean_real.lerp_(batch_mean_real.squeeze(), exp_avg_factor)
+ self.Mean_imag.lerp_(batch_mean_imag.squeeze(), exp_avg_factor)
+
+ else:
+ batch_mean_real = self.Mean_real.view(vdim)
+ batch_mean_imag = self.Mean_imag.view(vdim)
+
+ real = real - batch_mean_real
+ imag = imag - batch_mean_imag
+
+ if training:
+ batch_var_rr = real * real
+ batch_var_ri = real * imag
+ batch_var_ii = imag * imag
+ for dim in redux:
+ batch_var_rr = batch_var_rr.mean(dim, keepdim=True)
+ batch_var_ri = batch_var_ri.mean(dim, keepdim=True)
+ batch_var_ii = batch_var_ii.mean(dim, keepdim=True)
+ if self.track_running_stats:
+ self.Var_rr.lerp_(batch_var_rr.squeeze(), exp_avg_factor)
+ self.Var_ri.lerp_(batch_var_ri.squeeze(), exp_avg_factor)
+ self.Var_ii.lerp_(batch_var_ii.squeeze(), exp_avg_factor)
+ else:
+ batch_var_rr = self.Var_rr.view(vdim)
+ batch_var_ii = self.Var_ii.view(vdim)
+ batch_var_ri = self.Var_ri.view(vdim)
+
+        # add eps out-of-place so the running-stat buffers are not modified in eval mode
+        batch_var_rr = batch_var_rr + self.eps
+        batch_var_ii = batch_var_ii + self.eps
+
+        # Covariance matrix
+        # | batch_var_rr batch_var_ri |
+        # | batch_var_ir batch_var_ii |   (batch_var_ir == batch_var_ri)
+ # Inverse square root of cov matrix by combining below two formulas
+ # https://en.wikipedia.org/wiki/Square_root_of_a_2_by_2_matrix
+ # https://mathworld.wolfram.com/MatrixInverse.html
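+        # With V = [[Vrr, Vri], [Vri, Vii]], tau = trace(V),
+        # s = sqrt(det(V)) and t = sqrt(tau + 2 * s):
+        #   sqrt(V) = (V + s * I) / t   and   det(V + s * I) = s * t**2,
+        # so the whitening matrix is
+        #   U = V^(-1/2) = [[Vii + s, -Vri], [-Vri, Vrr + s]] / (s * t)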
+
+ tau = batch_var_rr + batch_var_ii
+        s = (batch_var_rr * batch_var_ii - batch_var_ri * batch_var_ri).sqrt()
+ t = (tau + 2 * s).sqrt()
+
+ rst = (s * t).reciprocal()
+ Urr = (batch_var_ii + s) * rst
+ Uri = -batch_var_ri * rst
+ Uii = (batch_var_rr + s) * rst
+
+ if self.affine:
+ Wrr, Wri, Wii = (
+ self.Wrr.view(vdim),
+ self.Wri.view(vdim),
+ self.Wii.view(vdim),
+ )
+ Zrr = (Wrr * Urr) + (Wri * Uri)
+ Zri = (Wrr * Uri) + (Wri * Uii)
+ Zir = (Wii * Uri) + (Wri * Urr)
+ Zii = (Wri * Uri) + (Wii * Uii)
+ else:
+ Zrr, Zri, Zir, Zii = Urr, Uri, Uri, Uii
+
+ yr = (Zrr * real) + (Zri * imag)
+ yi = (Zir * real) + (Zii * imag)
+
+ if self.affine:
+ yr = yr + self.Br.view(vdim)
+ yi = yi + self.Bi.view(vdim)
+
+ outputs = torch.cat([yr, yi], 1)
+ return outputs
+
+
+class ComplexRelu(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.real_relu = nn.PReLU()
+ self.imag_relu = nn.PReLU()
+
+ def forward(self, input):
+
+ real, imag = torch.chunk(input, 2, 1)
+ real = self.real_relu(real)
+ imag = self.imag_relu(imag)
+ return torch.cat([real, imag], dim=1)
+
+
+def complex_cat(inputs, axis=1):
+
+ real, imag = [], []
+ for data in inputs:
+ real_data, imag_data = torch.chunk(data, 2, axis)
+ real.append(real_data)
+ imag.append(imag_data)
+ real = torch.cat(real, axis)
+ imag = torch.cat(imag, axis)
+ return torch.cat([real, imag], axis)
diff --git a/enhancer/models/dccrn.py b/enhancer/models/dccrn.py
new file mode 100644
index 0000000..7b1e5b1
--- /dev/null
+++ b/enhancer/models/dccrn.py
@@ -0,0 +1,338 @@
+import logging
+from typing import Any, List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from enhancer.data import EnhancerDataset
+from enhancer.models import Model
+from enhancer.models.complexnn import (
+ ComplexBatchNorm2D,
+ ComplexConv2d,
+ ComplexConvTranspose2d,
+ ComplexLSTM,
+ ComplexRelu,
+)
+from enhancer.models.complexnn.utils import complex_cat
+from enhancer.utils.transforms import ConviSTFT, ConvSTFT
+from enhancer.utils.utils import merge_dict
+
+
+class DCCRN_ENCODER(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channel: int,
+ kernel_size: Tuple[int, int],
+ complex_norm: bool = True,
+ complex_relu: bool = True,
+ stride: Tuple[int, int] = (2, 1),
+ padding: Tuple[int, int] = (2, 1),
+ ):
+ super().__init__()
+ batchnorm = ComplexBatchNorm2D if complex_norm else nn.BatchNorm2d
+ activation = ComplexRelu() if complex_relu else nn.PReLU()
+
+ self.encoder = nn.Sequential(
+ ComplexConv2d(
+ in_channels,
+ out_channel,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ ),
+ batchnorm(out_channel),
+ activation,
+ )
+
+ def forward(self, waveform):
+
+ return self.encoder(waveform)
+
+
+class DCCRN_DECODER(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: Tuple[int, int],
+ layer: int = 0,
+ complex_norm: bool = True,
+ complex_relu: bool = True,
+ stride: Tuple[int, int] = (2, 1),
+ padding: Tuple[int, int] = (2, 0),
+ output_padding: Tuple[int, int] = (1, 0),
+ ):
+ super().__init__()
+ batchnorm = ComplexBatchNorm2D if complex_norm else nn.BatchNorm2d
+ activation = ComplexRelu() if complex_relu else nn.PReLU()
+
+ if layer != 0:
+ self.decoder = nn.Sequential(
+ ComplexConvTranspose2d(
+ in_channels,
+ out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ output_padding=output_padding,
+ ),
+ batchnorm(out_channels),
+ activation,
+ )
+ else:
+ self.decoder = nn.Sequential(
+ ComplexConvTranspose2d(
+ in_channels,
+ out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ output_padding=output_padding,
+ )
+ )
+
+ def forward(self, waveform):
+
+ return self.decoder(waveform)
+
+
+class DCCRN(Model):
+
+ STFT_DEFAULTS = {
+ "window_len": 400,
+ "hop_size": 100,
+ "nfft": 512,
+ "window": "hamming",
+ }
+
+ ED_DEFAULTS = {
+ "initial_output_channels": 32,
+ "depth": 6,
+ "kernel_size": 5,
+ "growth_factor": 2,
+ "stride": 2,
+ "padding": 2,
+ "output_padding": 1,
+ }
+
+ LSTM_DEFAULTS = {
+ "num_layers": 2,
+ "hidden_size": 256,
+ }
+
+ def __init__(
+ self,
+ stft: Optional[dict] = None,
+ encoder_decoder: Optional[dict] = None,
+ lstm: Optional[dict] = None,
+ complex_lstm: bool = True,
+ complex_norm: bool = True,
+ complex_relu: bool = True,
+ masking_mode: str = "E",
+ num_channels: int = 1,
+ sampling_rate=16000,
+ lr: float = 1e-3,
+ dataset: Optional[EnhancerDataset] = None,
+ duration: Optional[float] = None,
+ loss: Union[str, List, Any] = "mse",
+ metric: Union[str, List] = "mse",
+ ):
+ duration = (
+ dataset.duration if isinstance(dataset, EnhancerDataset) else None
+ )
+ if dataset is not None:
+ if sampling_rate != dataset.sampling_rate:
+ logging.warning(
+ f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
+ )
+ sampling_rate = dataset.sampling_rate
+ super().__init__(
+ num_channels=num_channels,
+ sampling_rate=sampling_rate,
+ lr=lr,
+ dataset=dataset,
+ duration=duration,
+ loss=loss,
+ metric=metric,
+ )
+
+ encoder_decoder = merge_dict(self.ED_DEFAULTS, encoder_decoder)
+ lstm = merge_dict(self.LSTM_DEFAULTS, lstm)
+ stft = merge_dict(self.STFT_DEFAULTS, stft)
+ self.save_hyperparameters(
+ "encoder_decoder",
+ "lstm",
+ "stft",
+ "complex_lstm",
+ "complex_norm",
+ "masking_mode",
+ )
+ self.complex_lstm = complex_lstm
+ self.complex_norm = complex_norm
+ self.masking_mode = masking_mode
+
+ self.stft = ConvSTFT(
+ stft["window_len"], stft["hop_size"], stft["nfft"], stft["window"]
+ )
+ self.istft = ConviSTFT(
+ stft["window_len"], stft["hop_size"], stft["nfft"], stft["window"]
+ )
+
+ self.encoder = nn.ModuleList()
+ self.decoder = nn.ModuleList()
+
+ num_channels *= 2
+ hidden_size = encoder_decoder["initial_output_channels"]
+ growth_factor = 2
+
+ for layer in range(encoder_decoder["depth"]):
+
+ encoder_ = DCCRN_ENCODER(
+ num_channels,
+ hidden_size,
+ kernel_size=(encoder_decoder["kernel_size"], 2),
+ stride=(encoder_decoder["stride"], 1),
+ padding=(encoder_decoder["padding"], 1),
+ complex_norm=complex_norm,
+ complex_relu=complex_relu,
+ )
+ self.encoder.append(encoder_)
+
+ decoder_ = DCCRN_DECODER(
+ hidden_size + hidden_size,
+ num_channels,
+ layer=layer,
+ kernel_size=(encoder_decoder["kernel_size"], 2),
+ stride=(encoder_decoder["stride"], 1),
+ padding=(encoder_decoder["padding"], 0),
+ output_padding=(encoder_decoder["output_padding"], 0),
+ complex_norm=complex_norm,
+ complex_relu=complex_relu,
+ )
+
+ self.decoder.insert(0, decoder_)
+
+ if layer < encoder_decoder["depth"] - 3:
+ num_channels = hidden_size
+ hidden_size *= growth_factor
+ else:
+ num_channels = hidden_size
+
+ kernel_size = hidden_size / 2
+ hidden_size = stft["nfft"] / 2 ** (encoder_decoder["depth"])
+
+ if self.complex_lstm:
+ lstms = []
+ for layer in range(lstm["num_layers"]):
+
+ if layer == 0:
+ input_size = int(hidden_size * kernel_size)
+ else:
+ input_size = lstm["hidden_size"]
+
+ if layer == lstm["num_layers"] - 1:
+ projection_size = int(hidden_size * kernel_size)
+ else:
+ projection_size = None
+
+ kwargs = {
+ "input_size": input_size,
+ "hidden_size": lstm["hidden_size"],
+ "num_layers": 1,
+ }
+
+ lstms.append(
+ ComplexLSTM(projection_size=projection_size, **kwargs)
+ )
+ self.lstm = nn.Sequential(*lstms)
+        else:
+            self.lstm = nn.LSTM(
+                input_size=int(hidden_size * kernel_size),
+                hidden_size=lstm["hidden_size"],
+                num_layers=lstm["num_layers"],
+                dropout=0.0,
+                batch_first=False,
+            )
+            self.lstm_linear = nn.Linear(
+                lstm["hidden_size"], int(hidden_size * kernel_size)
+            )
+
+ def forward(self, waveform):
+
+ if waveform.dim() == 2:
+ waveform = waveform.unsqueeze(1)
+
+ if waveform.size(1) != self.hparams.num_channels:
+ raise ValueError(
+ f"Number of input channels initialized is {self.hparams.num_channels} but got {waveform.size(1)} channels"
+ )
+
+ waveform_stft = self.stft(waveform)
+ real = waveform_stft[:, : self.stft.nfft // 2 + 1]
+ imag = waveform_stft[:, self.stft.nfft // 2 + 1 :]
+
+ mag_spec = torch.sqrt(real**2 + imag**2 + 1e-9)
+ phase_spec = torch.atan2(imag, real)
+ complex_spec = torch.stack([mag_spec, phase_spec], 1)[:, :, 1:]
+
+ encoder_outputs = []
+ out = complex_spec
+ for _, encoder in enumerate(self.encoder):
+ out = encoder(out)
+ encoder_outputs.append(out)
+
+ B, C, D, T = out.size()
+ out = out.permute(3, 0, 1, 2)
+ if self.complex_lstm:
+
+ lstm_real = out[:, :, : C // 2]
+ lstm_imag = out[:, :, C // 2 :]
+ lstm_real = lstm_real.reshape(T, B, C // 2 * D)
+ lstm_imag = lstm_imag.reshape(T, B, C // 2 * D)
+ lstm_real, lstm_imag = self.lstm([lstm_real, lstm_imag])
+ lstm_real = lstm_real.reshape(T, B, C // 2, D)
+ lstm_imag = lstm_imag.reshape(T, B, C // 2, D)
+ out = torch.cat([lstm_real, lstm_imag], 2)
+        else:
+            out = out.reshape(T, B, C * D)
+            out, _ = self.lstm(out)
+            out = self.lstm_projection(out)
+            out = out.reshape(T, B, D, C)
+
+ out = out.permute(1, 2, 3, 0)
+ for layer, decoder in enumerate(self.decoder):
+ skip_connection = encoder_outputs.pop(-1)
+ out = complex_cat([skip_connection, out])
+ out = decoder(out)
+ out = out[..., 1:]
+ mask_real, mask_imag = out[:, 0], out[:, 1]
+ mask_real = F.pad(mask_real, [0, 0, 1, 0])
+ mask_imag = F.pad(mask_imag, [0, 0, 1, 0])
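+        # masking_mode "E": polar (magnitude/phase) masking,
+        # "C": complex multiplicative masking, otherwise the real and
+        # imaginary parts are masked independently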
+ if self.masking_mode == "E":
+
+ mask_mag = torch.sqrt(mask_real**2 + mask_imag**2)
+ real_phase = mask_real / (mask_mag + 1e-8)
+ imag_phase = mask_imag / (mask_mag + 1e-8)
+ mask_phase = torch.atan2(imag_phase, real_phase)
+ mask_mag = torch.tanh(mask_mag)
+ est_mag = mask_mag * mag_spec
+                est_phase = phase_spec + mask_phase
+                # est_mag * (cos(theta) + i sin(theta))
+                real = est_mag * torch.cos(est_phase)
+                imag = est_mag * torch.sin(est_phase)
+
+        elif self.masking_mode == "C":
+
+ real = real * mask_real - imag * mask_imag
+ imag = real * mask_imag + imag * mask_real
+
+ else:
+
+ real = real * mask_real
+ imag = imag * mask_imag
+
+ spec = torch.cat([real, imag], 1)
+ wav = self.istft(spec)
+ wav = wav.clamp_(-1, 1)
+ return wav
diff --git a/enhancer/models/demucs.py b/enhancer/models/demucs.py
new file mode 100644
index 0000000..fafb84e
--- /dev/null
+++ b/enhancer/models/demucs.py
@@ -0,0 +1,274 @@
+import logging
+import math
+from typing import List, Optional, Union
+
+import torch.nn.functional as F
+from torch import nn
+
+from enhancer.data.dataset import EnhancerDataset
+from enhancer.models.model import Model
+from enhancer.utils.io import Audio as audio
+from enhancer.utils.utils import merge_dict
+
+
+class DemucsLSTM(nn.Module):
+ def __init__(
+ self,
+ input_size: int,
+ hidden_size: int,
+ num_layers: int,
+ bidirectional: bool = True,
+ ):
+ super().__init__()
+ self.lstm = nn.LSTM(
+ input_size, hidden_size, num_layers, bidirectional=bidirectional
+ )
+ dim = 2 if bidirectional else 1
+ self.linear = nn.Linear(dim * hidden_size, hidden_size)
+
+ def forward(self, x):
+
+ output, (h, c) = self.lstm(x)
+ output = self.linear(output)
+
+ return output, (h, c)
+
+
+class DemucsEncoder(nn.Module):
+ def __init__(
+ self,
+ num_channels: int,
+ hidden_size: int,
+ kernel_size: int,
+ stride: int = 1,
+ glu: bool = False,
+ ):
+ super().__init__()
+ activation = nn.GLU(1) if glu else nn.ReLU()
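+        # nn.GLU(dim=1) halves the channel dimension, so the pointwise conv
+        # below doubles its output channels whenever GLU is used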
+ multi_factor = 2 if glu else 1
+ self.encoder = nn.Sequential(
+ nn.Conv1d(num_channels, hidden_size, kernel_size, stride),
+ nn.ReLU(),
+ nn.Conv1d(hidden_size, hidden_size * multi_factor, 1, 1),
+ activation,
+ )
+
+ def forward(self, waveform):
+
+ return self.encoder(waveform)
+
+
+class DemucsDecoder(nn.Module):
+ def __init__(
+ self,
+ num_channels: int,
+ hidden_size: int,
+ kernel_size: int,
+ stride: int = 1,
+ glu: bool = False,
+ layer: int = 0,
+ ):
+ super().__init__()
+ activation = nn.GLU(1) if glu else nn.ReLU()
+ multi_factor = 2 if glu else 1
+ self.decoder = nn.Sequential(
+ nn.Conv1d(hidden_size, hidden_size * multi_factor, 1, 1),
+ activation,
+ nn.ConvTranspose1d(hidden_size, num_channels, kernel_size, stride),
+ )
+ if layer > 0:
+ self.decoder.add_module("4", nn.ReLU())
+
+ def forward(
+ self,
+ waveform,
+ ):
+
+ out = self.decoder(waveform)
+ return out
+
+
+class Demucs(Model):
+ """
+ Demucs model from https://arxiv.org/pdf/1911.13254.pdf
+ parameters:
+ encoder_decoder: dict, optional
+        keyword arguments passed to the encoder-decoder block
+    lstm : dict, optional
+        keyword arguments passed to the LSTM block
+    num_channels: int, defaults to 1
+        number of channels in input audio
+ sampling_rate: int, defaults to 16KHz
+ sampling rate of input audio
+ lr : float, defaults to 1e-3
+ learning rate used for training
+ dataset: EnhancerDataset, optional
+ EnhancerDataset object containing train/validation data for training
+ duration : float, optional
+ chunk duration in seconds
+ loss : string or List of strings
+ loss function to be used, available ("mse","mae","SI-SDR")
+ metric : string or List of strings
+ metric function to be used, available ("mse","mae","SI-SDR")
+
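+    example (illustrative; the random tensor below only demonstrates the
+    expected (batch, channels, samples) input layout):
+        >>> import torch
+        >>> model = Demucs()
+        >>> enhanced = model(torch.rand(1, 1, 1000))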
+ """
+
+ ED_DEFAULTS = {
+ "initial_output_channels": 48,
+ "kernel_size": 8,
+ "stride": 4,
+ "depth": 5,
+ "glu": True,
+ "growth_factor": 2,
+ }
+ LSTM_DEFAULTS = {
+ "bidirectional": True,
+ "num_layers": 2,
+ }
+
+ def __init__(
+ self,
+ encoder_decoder: Optional[dict] = None,
+ lstm: Optional[dict] = None,
+ num_channels: int = 1,
+ resample: int = 4,
+ sampling_rate=16000,
+ normalize=True,
+ lr: float = 1e-3,
+ dataset: Optional[EnhancerDataset] = None,
+ loss: Union[str, List] = "mse",
+ metric: Union[str, List] = "mse",
+ floor=1e-3,
+ ):
+ duration = (
+ dataset.duration if isinstance(dataset, EnhancerDataset) else None
+ )
+ if dataset is not None:
+ if sampling_rate != dataset.sampling_rate:
+ logging.warning(
+ f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
+ )
+ sampling_rate = dataset.sampling_rate
+ super().__init__(
+ num_channels=num_channels,
+ sampling_rate=sampling_rate,
+ lr=lr,
+ dataset=dataset,
+ duration=duration,
+ loss=loss,
+ metric=metric,
+ )
+
+ encoder_decoder = merge_dict(self.ED_DEFAULTS, encoder_decoder)
+ lstm = merge_dict(self.LSTM_DEFAULTS, lstm)
+ self.save_hyperparameters("encoder_decoder", "lstm", "resample")
+ hidden = encoder_decoder["initial_output_channels"]
+ self.normalize = normalize
+ self.floor = floor
+ self.encoder = nn.ModuleList()
+ self.decoder = nn.ModuleList()
+
+ for layer in range(encoder_decoder["depth"]):
+
+ encoder_layer = DemucsEncoder(
+ num_channels=num_channels,
+ hidden_size=hidden,
+ kernel_size=encoder_decoder["kernel_size"],
+ stride=encoder_decoder["stride"],
+ glu=encoder_decoder["glu"],
+ )
+ self.encoder.append(encoder_layer)
+
+ decoder_layer = DemucsDecoder(
+ num_channels=num_channels,
+ hidden_size=hidden,
+ kernel_size=encoder_decoder["kernel_size"],
+ stride=encoder_decoder["stride"],
+ glu=encoder_decoder["glu"],
+ layer=layer,
+ )
+ self.decoder.insert(0, decoder_layer)
+
+ num_channels = hidden
+            hidden = encoder_decoder["growth_factor"] * hidden
+
+ self.de_lstm = DemucsLSTM(
+ input_size=num_channels,
+ hidden_size=num_channels,
+ num_layers=lstm["num_layers"],
+ bidirectional=lstm["bidirectional"],
+ )
+
+ def forward(self, waveform):
+
+ if waveform.dim() == 2:
+ waveform = waveform.unsqueeze(1)
+
+ if waveform.size(1) != self.hparams.num_channels:
+ raise ValueError(
+ f"Number of input channels initialized is {self.hparams.num_channels} but got {waveform.size(1)} channels"
+ )
+ if self.normalize:
+ waveform = waveform.mean(dim=1, keepdim=True)
+ std = waveform.std(dim=-1, keepdim=True)
+ waveform = waveform / (self.floor + std)
+ else:
+ std = 1
+ length = waveform.shape[-1]
+ x = F.pad(waveform, (0, self.get_padding_length(length) - length))
+ if self.hparams.resample > 1:
+ x = audio.resample_audio(
+ audio=x,
+ sr=self.hparams.sampling_rate,
+ target_sr=int(
+ self.hparams.sampling_rate * self.hparams.resample
+ ),
+ )
+
+ encoder_outputs = []
+ for encoder in self.encoder:
+ x = encoder(x)
+ encoder_outputs.append(x)
+ x = x.permute(0, 2, 1)
+ x, _ = self.de_lstm(x)
+
+ x = x.permute(0, 2, 1)
+ for decoder in self.decoder:
+ skip_connection = encoder_outputs.pop(-1)
+ x = x + skip_connection[..., : x.shape[-1]]
+ x = decoder(x)
+
+ if self.hparams.resample > 1:
+ x = audio.resample_audio(
+ x,
+ int(self.hparams.sampling_rate * self.hparams.resample),
+ self.hparams.sampling_rate,
+ )
+
+ out = x[..., :length]
+ return std * out
+
+ def get_padding_length(self, input_length):
+
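+        # smallest length >= input_length that passes through `depth` strided
+        # encoder convolutions and the matching transposed decoder convolutions
+        # without dropping samples, accounting for the internal resampling factor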
+ input_length = math.ceil(input_length * self.hparams.resample)
+
+ for layer in range(
+ self.hparams.encoder_decoder["depth"]
+ ): # encoder operation
+ input_length = (
+ math.ceil(
+ (input_length - self.hparams.encoder_decoder["kernel_size"])
+ / self.hparams.encoder_decoder["stride"]
+ )
+ + 1
+ )
+ input_length = max(1, input_length)
+ for layer in range(
+ self.hparams.encoder_decoder["depth"]
+        ):  # decoder operation
+ input_length = (input_length - 1) * self.hparams.encoder_decoder[
+ "stride"
+ ] + self.hparams.encoder_decoder["kernel_size"]
+ input_length = math.ceil(input_length / self.hparams.resample)
+
+ return int(input_length)
diff --git a/enhancer/models/model.py b/enhancer/models/model.py
new file mode 100644
index 0000000..c679669
--- /dev/null
+++ b/enhancer/models/model.py
@@ -0,0 +1,431 @@
+import os
+from collections import defaultdict
+from importlib import import_module
+from pathlib import Path
+from typing import Any, List, Optional, Text, Union
+from urllib.parse import urlparse
+
+import numpy as np
+import pytorch_lightning as pl
+import torch
+from huggingface_hub import cached_download, hf_hub_url
+from pytorch_lightning.utilities.cloud_io import load as pl_load
+from torch import nn
+from torch.optim import Adam
+
+from enhancer.data.dataset import EnhancerDataset
+from enhancer.inference import Inference
+from enhancer.loss import LOSS_MAP, LossWrapper
+from enhancer.version import __version__
+
+CACHE_DIR = ""
+HF_TORCH_WEIGHTS = ""
+DEFAULT_DEVICE = "cpu"
+
+
+class Model(pl.LightningModule):
+ """
+ Base class for all models
+ parameters:
+        num_channels: int, defaults to 1
+            number of channels in input audio
+        sampling_rate : int, defaults to 16kHz
+            audio sampling rate
+ lr: float, optional
+ learning rate for model training
+ dataset: EnhancerDataset, optional
+ Enhancer dataset used for training/validation
+ duration: float, optional
+ duration used for training/inference
+ loss : string or List of strings or custom loss (nn.Module), default to "mse"
+ loss functions to be used. Available ("mse","mae","Si-SDR")
+
+ """
+
+ def __init__(
+ self,
+ num_channels: int = 1,
+ sampling_rate: int = 16000,
+ lr: float = 1e-3,
+ dataset: Optional[EnhancerDataset] = None,
+ duration: Optional[float] = None,
+ loss: Union[str, List] = "mse",
+ metric: Union[str, List, Any] = "mse",
+ ):
+ super().__init__()
+ assert (
+ num_channels == 1
+        ), "Enhancer only supports mono (single-channel) audio"
+ self.dataset = dataset
+ self.save_hyperparameters(
+ "num_channels", "sampling_rate", "lr", "loss", "metric", "duration"
+ )
+ if self.logger:
+ self.logger.experiment.log_dict(
+ dict(self.hparams), "hyperparameters.json"
+ )
+
+ self.loss = loss
+ self.metric = metric
+
+ @property
+ def loss(self):
+ return self._loss
+
+ @loss.setter
+ def loss(self, loss):
+
+ if isinstance(loss, str):
+ loss = [loss]
+
+ self._loss = LossWrapper(loss)
+
+ @property
+ def metric(self):
+ return self._metric
+
+ @metric.setter
+ def metric(self, metric):
+ self._metric = []
+ if isinstance(metric, (str, nn.Module)):
+ metric = [metric]
+
+ for func in metric:
+ if isinstance(func, str):
+ if func in LOSS_MAP.keys():
+ if func in ("pesq", "stoi"):
+ self._metric.append(
+ LOSS_MAP[func](self.hparams.sampling_rate)
+ )
+ else:
+ self._metric.append(LOSS_MAP[func]())
+ else:
+                    raise ValueError(f"Invalid metric {func}")
+
+ elif isinstance(func, nn.Module):
+ self._metric.append(func)
+ else:
+ raise ValueError("Invalid metrics")
+
+ @property
+ def dataset(self):
+ return self._dataset
+
+ @dataset.setter
+ def dataset(self, dataset):
+ self._dataset = dataset
+
+ def setup(self, stage: Optional[str] = None):
+ if stage == "fit":
+ torch.cuda.empty_cache()
+ self.dataset.setup(stage)
+ self.dataset.model = self
+
+ print(
+ "Total train duration",
+ self.dataset.train_dataloader().dataset.__len__()
+ * self.dataset.duration
+ / 60,
+ "minutes",
+ )
+ print(
+ "Total validation duration",
+ self.dataset.val_dataloader().dataset.__len__()
+ * self.dataset.duration
+ / 60,
+ "minutes",
+ )
+ print(
+ "Total test duration",
+ self.dataset.test_dataloader().dataset.__len__()
+ * self.dataset.duration
+ / 60,
+ "minutes",
+ )
+
+ def train_dataloader(self):
+ return self.dataset.train_dataloader()
+
+ def val_dataloader(self):
+ return self.dataset.val_dataloader()
+
+ def test_dataloader(self):
+ return self.dataset.test_dataloader()
+
+ def configure_optimizers(self):
+ return Adam(self.parameters(), lr=self.hparams.lr)
+
+ def training_step(self, batch, batch_idx: int):
+
+ mixed_waveform = batch["noisy"]
+ target = batch["clean"]
+ prediction = self(mixed_waveform)
+ loss = self.loss(prediction, target)
+
+ self.log(
+ "train_loss",
+ loss.item(),
+ on_epoch=True,
+ on_step=True,
+ logger=True,
+ prog_bar=True,
+ )
+
+ return {"loss": loss}
+
+ def validation_step(self, batch, batch_idx: int):
+
+ metric_dict = {}
+ mixed_waveform = batch["noisy"]
+ target = batch["clean"]
+ prediction = self(mixed_waveform)
+
+ metric_dict["valid_loss"] = self.loss(target, prediction).item()
+ for metric in self.metric:
+ value = metric(target, prediction)
+ metric_dict[f"valid_{metric.name}"] = value.item()
+
+ self.log_dict(
+ metric_dict,
+ on_step=True,
+ on_epoch=True,
+ prog_bar=True,
+ logger=True,
+ )
+
+ return metric_dict
+
+ def test_step(self, batch, batch_idx):
+
+ metric_dict = {}
+ mixed_waveform = batch["noisy"]
+ target = batch["clean"]
+ prediction = self(mixed_waveform)
+
+ for metric in self.metric:
+ value = metric(target, prediction)
+ metric_dict[f"test_{metric.name}"] = value
+
+ self.log_dict(
+ metric_dict,
+ on_step=True,
+ on_epoch=True,
+ prog_bar=True,
+ logger=True,
+ )
+
+ return metric_dict
+
+ def test_epoch_end(self, outputs):
+
+ test_mean_metrics = defaultdict(int)
+ for output in outputs:
+ for metric, value in output.items():
+ test_mean_metrics[metric] += value.item()
+ for metric in test_mean_metrics.keys():
+ test_mean_metrics[metric] /= len(outputs)
+
+ print("----------TEST REPORT----------\n")
+ for metric in test_mean_metrics.keys():
+ print(f"|{metric.upper()} | {test_mean_metrics[metric]} |")
+ print("--------------------------------")
+
+ def on_save_checkpoint(self, checkpoint):
+
+ checkpoint["enhancer"] = {
+ "version": {"enhancer": __version__, "pytorch": torch.__version__},
+ "architecture": {
+ "module": self.__class__.__module__,
+ "class": self.__class__.__name__,
+ },
+ }
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ checkpoint: Union[Path, Text],
+ map_location=None,
+ hparams_file: Union[Path, Text] = None,
+ strict: bool = True,
+ use_auth_token: Union[Text, None] = None,
+ cached_dir: Union[Path, Text] = CACHE_DIR,
+ **kwargs,
+ ):
+ """
+ Load Pretrained model
+
+ parameters:
+ checkpoint : Path or str
+ Path to checkpoint, or a remote URL, or a model identifier from
+ the huggingface.co model hub.
+ map_location: optional
+ Same role as in torch.load().
+ Defaults to `lambda storage, loc: storage`.
+ hparams_file : Path or str, optional
+ Path to a .yaml file with hierarchical structure as in this example:
+ drop_prob: 0.2
+ dataloader:
+ batch_size: 32
+ You most likely won’t need this since Lightning will always save the
+ hyperparameters to the checkpoint. However, if your checkpoint weights
+ do not have the hyperparameters saved, use this method to pass in a .yaml
+ file with the hparams you would like to use. These will be converted
+ into a dict and passed into your Model for use.
+ strict : bool, optional
+ Whether to strictly enforce that the keys in checkpoint match
+ the keys returned by this module’s state dict. Defaults to True.
+ use_auth_token : str, optional
+ When loading a private huggingface.co model, set `use_auth_token`
+            to True or to a string containing your huggingface.co authentication
+ token that can be obtained by running `huggingface-cli login`
+        cached_dir: Path or str, optional
+ Path to model cache directory
+ kwargs: optional
+ Any extra keyword args needed to init the model.
+ Can also be used to override saved hyperparameter values.
+
+ Returns
+ -------
+ model : Model
+ Model
+
+ See also
+ --------
+ torch.load
+ """
+
+ checkpoint = str(checkpoint)
+ if hparams_file is not None:
+ hparams_file = str(hparams_file)
+
+ if os.path.isfile(checkpoint):
+ model_path_pl = checkpoint
+ elif urlparse(checkpoint).scheme in ("http", "https"):
+ model_path_pl = checkpoint
+ else:
+
+ if "@" in checkpoint:
+ model_id = checkpoint.split("@")[0]
+ revision_id = checkpoint.split("@")[1]
+ else:
+ model_id = checkpoint
+ revision_id = None
+
+ url = hf_hub_url(
+ model_id, filename=HF_TORCH_WEIGHTS, revision=revision_id
+ )
+ model_path_pl = cached_download(
+ url=url,
+ library_name="enhancer",
+ library_version=__version__,
+ cache_dir=cached_dir,
+ use_auth_token=use_auth_token,
+ )
+
+ if map_location is None:
+ map_location = torch.device(DEFAULT_DEVICE)
+
+ loaded_checkpoint = pl_load(model_path_pl, map_location)
+ module_name = loaded_checkpoint["enhancer"]["architecture"]["module"]
+ class_name = loaded_checkpoint["enhancer"]["architecture"]["class"]
+ module = import_module(module_name)
+ Klass = getattr(module, class_name)
+
+ try:
+ model = Klass.load_from_checkpoint(
+ checkpoint_path=model_path_pl,
+ map_location=map_location,
+ hparams_file=hparams_file,
+ strict=strict,
+ **kwargs,
+ )
+        except Exception as e:
+            print(e)
+            raise
+
+ return model
+
+ def infer(self, batch: torch.Tensor, batch_size: int = 32):
+ """
+ perform model inference
+ parameters:
+ batch : torch.Tensor
+ input data
+ batch_size : int, default 32
+ batch size for inference
+ """
+
+ assert (
+ batch.ndim == 3
+ ), f"Expected batch with 3 dimensions (batch,channels,samples) got only {batch.ndim}"
+ batch_predictions = []
+ self.eval().to(self.device)
+ with torch.no_grad():
+ for batch_id in range(0, batch.shape[0], batch_size):
+ batch_data = batch[batch_id : (batch_id + batch_size), :, :].to(
+ self.device
+ )
+ prediction = self(batch_data)
+ batch_predictions.append(prediction)
+
+ return torch.vstack(batch_predictions)
+
+ def enhance(
+ self,
+ audio: Union[Path, np.ndarray, torch.Tensor],
+ sampling_rate: Optional[int] = None,
+ batch_size: int = 32,
+ save_output: bool = False,
+        duration: Optional[float] = None,
+ step_size: Optional[int] = None,
+ ):
+ """
+        Enhance audio using loaded pretrained model.
+
+ parameters:
+ audio: Path to audio file or numpy array or torch tensor
+ single input audio
+            sampling_rate: int, optional when input is a file path
+                sampling rate of the input audio
+ batch_size: int, default 32
+ input audio is split into multiple chunks. Inference is done on batches
+ of these chunks according to given batch size.
+ save_output : bool, default False
+                whether to save the enhanced output to file
+ duration : float, optional
+ chunk duration in seconds, defaults to duration of loaded pretrained model.
+ step_size: int, optional
+ step size between consecutive durations, defaults to 50% of duration
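+
+        example (illustrative; the checkpoint and audio paths below are
+        placeholders):
+            >>> model = Model.from_pretrained("path/to/checkpoint.ckpt")
+            >>> enhanced = model.enhance("noisy.wav", save_output=True)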
+ """
+
+ model_sampling_rate = self.hparams["sampling_rate"]
+ if duration is None:
+ duration = self.hparams["duration"]
+ waveform = Inference.read_input(
+ audio, sampling_rate, model_sampling_rate
+ )
+        waveform = waveform.to(self.device)
+ window_size = round(duration * model_sampling_rate)
+ batched_waveform = Inference.batchify(
+ waveform, window_size, step_size=step_size
+ )
+ batch_prediction = self.infer(batched_waveform, batch_size=batch_size)
+ waveform = Inference.aggreagate(
+ batch_prediction,
+ window_size,
+ waveform.shape[-1],
+ step_size,
+ )
+
+ if save_output and isinstance(audio, (str, Path)):
+ Inference.write_output(waveform, audio, model_sampling_rate)
+
+ else:
+ waveform = Inference.prepare_output(
+ waveform, model_sampling_rate, audio, sampling_rate
+ )
+ return waveform
+
+ @property
+ def valid_monitor(self):
+
+ return "max" if self.loss.higher_better else "min"
diff --git a/enhancer/models/waveunet.py b/enhancer/models/waveunet.py
new file mode 100644
index 0000000..ea5646a
--- /dev/null
+++ b/enhancer/models/waveunet.py
@@ -0,0 +1,201 @@
+import logging
+from typing import List, Optional, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from enhancer.data.dataset import EnhancerDataset
+from enhancer.models.model import Model
+
+
+class WavenetDecoder(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int = 5,
+ padding: int = 2,
+ stride: int = 1,
+ dilation: int = 1,
+ ):
+ super(WavenetDecoder, self).__init__()
+ self.decoder = nn.Sequential(
+ nn.Conv1d(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ ),
+ nn.BatchNorm1d(out_channels),
+ nn.LeakyReLU(negative_slope=0.1),
+ )
+
+ def forward(self, waveform):
+
+ return self.decoder(waveform)
+
+
+class WavenetEncoder(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int = 15,
+ padding: int = 7,
+ stride: int = 1,
+ dilation: int = 1,
+ ):
+ super(WavenetEncoder, self).__init__()
+ self.encoder = nn.Sequential(
+ nn.Conv1d(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ ),
+ nn.BatchNorm1d(out_channels),
+ nn.LeakyReLU(negative_slope=0.1),
+ )
+
+ def forward(self, waveform):
+ return self.encoder(waveform)
+
+
+class WaveUnet(Model):
+ """
+ Wave-U-Net model from https://arxiv.org/pdf/1811.11307.pdf
+ parameters:
+ num_channels: int, defaults to 1
+ number of channels in input audio
+ depth : int, defaults to 12
+ depth of network
+ initial_output_channels: int, defaults to 24
+        number of output channels in the initial layer
+ sampling_rate: int, defaults to 16KHz
+ sampling rate of input audio
+ lr : float, defaults to 1e-3
+ learning rate used for training
+ dataset: EnhancerDataset, optional
+ EnhancerDataset object containing train/validation data for training
+ duration : float, optional
+ chunk duration in seconds
+ loss : string or List of strings
+ loss function to be used, available ("mse","mae","SI-SDR")
+ metric : string or List of strings
+ metric function to be used, available ("mse","mae","SI-SDR")
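+
+    example (illustrative; mirrors the shapes used in the unit tests):
+        >>> import torch
+        >>> model = WaveUnet()
+        >>> enhanced = model(torch.rand(1, 1, 1000))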
+ """
+
+ def __init__(
+ self,
+ num_channels: int = 1,
+ depth: int = 12,
+ initial_output_channels: int = 24,
+ sampling_rate: int = 16000,
+ lr: float = 1e-3,
+ dataset: Optional[EnhancerDataset] = None,
+ duration: Optional[float] = None,
+ loss: Union[str, List] = "mse",
+ metric: Union[str, List] = "mse",
+ ):
+ duration = (
+ dataset.duration if isinstance(dataset, EnhancerDataset) else None
+ )
+ if dataset is not None:
+ if sampling_rate != dataset.sampling_rate:
+ logging.warning(
+ f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
+ )
+ sampling_rate = dataset.sampling_rate
+ super().__init__(
+ num_channels=num_channels,
+ sampling_rate=sampling_rate,
+ lr=lr,
+ dataset=dataset,
+ duration=duration,
+ loss=loss,
+ metric=metric,
+ )
+ self.save_hyperparameters("depth")
+ self.encoders = nn.ModuleList()
+ self.decoders = nn.ModuleList()
+ out_channels = initial_output_channels
+ for layer in range(depth):
+
+ encoder = WavenetEncoder(num_channels, out_channels)
+ self.encoders.append(encoder)
+
+ num_channels = out_channels
+ out_channels += initial_output_channels
+ if layer == depth - 1:
+ decoder = WavenetDecoder(
+ depth * initial_output_channels + num_channels, num_channels
+ )
+ else:
+ decoder = WavenetDecoder(
+ num_channels + out_channels, num_channels
+ )
+
+ self.decoders.insert(0, decoder)
+
+ bottleneck_dim = depth * initial_output_channels
+ self.bottleneck = nn.Sequential(
+ nn.Conv1d(bottleneck_dim, bottleneck_dim, 15, stride=1, padding=7),
+ nn.BatchNorm1d(bottleneck_dim),
+ nn.LeakyReLU(negative_slope=0.1, inplace=True),
+ )
+ self.final = nn.Sequential(
+ nn.Conv1d(1 + initial_output_channels, 1, kernel_size=1, stride=1),
+ nn.Tanh(),
+ )
+
+ def forward(self, waveform):
+ if waveform.dim() == 2:
+ waveform = waveform.unsqueeze(1)
+
+ if waveform.size(1) != 1:
+ raise TypeError(
+ f"Wave-U-Net can only process mono channel audio, input has {waveform.size(1)} channels"
+ )
+
+ encoder_outputs = []
+ out = waveform
+ for encoder in self.encoders:
+ out = encoder(out)
+ encoder_outputs.insert(0, out)
+ out = out[:, :, ::2]
+
+ out = self.bottleneck(out)
+
+ for layer, decoder in enumerate(self.decoders):
+ out = F.interpolate(out, scale_factor=2, mode="linear")
+ out = self.fix_last_dim(out, encoder_outputs[layer])
+ out = torch.cat([out, encoder_outputs[layer]], dim=1)
+ out = decoder(out)
+
+ out = torch.cat([out, waveform], dim=1)
+ out = self.final(out)
+ return out
+
+ def fix_last_dim(self, x, target):
+ """
+ centre crop along last dimension
+ """
+
+ assert (
+ x.shape[-1] >= target.shape[-1]
+ ), "input dimension cannot be larger than target dimension"
+ if x.shape[-1] == target.shape[-1]:
+ return x
+
+ diff = x.shape[-1] - target.shape[-1]
+ if diff % 2 != 0:
+ x = F.pad(x, (0, 1))
+ diff += 1
+
+ crop = diff // 2
+ return x[:, :, crop:-crop]
diff --git a/enhancer/utils/__init__.py b/enhancer/utils/__init__.py
new file mode 100644
index 0000000..de0db9f
--- /dev/null
+++ b/enhancer/utils/__init__.py
@@ -0,0 +1,3 @@
+from enhancer.utils.config import Files
+from enhancer.utils.io import Audio
+from enhancer.utils.utils import check_files
diff --git a/enhancer/utils/config.py b/enhancer/utils/config.py
new file mode 100644
index 0000000..252e6c9
--- /dev/null
+++ b/enhancer/utils/config.py
@@ -0,0 +1,9 @@
+from dataclasses import dataclass
+
+
+@dataclass
+class Files:
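+    """Dataset sub-directory names, relative to a common root directory.
+
+    e.g. (illustrative directory names)
+    Files(train_clean="clean_trainset_wav", train_noisy="noisy_trainset_wav",
+          test_clean="clean_testset_wav", test_noisy="noisy_testset_wav")
+    """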
+ train_clean: str
+ train_noisy: str
+ test_clean: str
+ test_noisy: str
diff --git a/enhancer/utils/io.py b/enhancer/utils/io.py
new file mode 100644
index 0000000..d151ef8
--- /dev/null
+++ b/enhancer/utils/io.py
@@ -0,0 +1,128 @@
+import os
+from pathlib import Path
+from typing import Optional, Union
+
+import librosa
+import numpy as np
+import torch
+import torchaudio
+
+
+class Audio:
+ """
+ Audio utils
+ parameters:
+ sampling_rate : int, defaults to 16KHz
+ audio sampling rate
+    mono: bool, defaults to True
+        convert multi-channel audio to mono
+    return_tensor: bool, defaults to True
+        returns torch tensor type if set to True else numpy ndarray
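+
+    example (illustrative; any readable wav file works):
+        >>> audio = Audio(sampling_rate=16000, mono=True)
+        >>> waveform = audio("tests/data/vctk/clean_testset_wav/p257_166.wav")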
+ """
+
+ def __init__(
+ self, sampling_rate: int = 16000, mono: bool = True, return_tensor=True
+ ) -> None:
+
+ self.sampling_rate = sampling_rate
+ self.mono = mono
+ self.return_tensor = return_tensor
+
+ def __call__(
+ self,
+ audio: Union[Path, np.ndarray, torch.Tensor],
+ sampling_rate: Optional[int] = None,
+ offset: Optional[float] = None,
+ duration: Optional[float] = None,
+ ):
+ """
+ read and process input audio
+ parameters:
+ audio: Path to audio file or numpy array or torch tensor
+ single input audio
+ sampling_rate : int, optional
+ sampling rate of the audio input
+ offset: float, optional
+ offset from which the audio must be read, reads from beginning if unused.
+ duration: float (seconds), optional
+ read duration, reads full audio starting from offset if not used
+ """
+ if isinstance(audio, str):
+ if os.path.exists(audio):
+ audio, sampling_rate = librosa.load(
+ audio,
+ sr=sampling_rate,
+ mono=False,
+ offset=offset,
+ duration=duration,
+ )
+ if len(audio.shape) == 1:
+ audio = audio.reshape(1, -1)
+ else:
+                raise FileNotFoundError(f"File {audio} does not exist")
+ elif isinstance(audio, np.ndarray):
+ if len(audio.shape) == 1:
+ audio = audio.reshape(1, -1)
+ else:
+ raise ValueError("audio should be either filepath or numpy ndarray")
+
+ if self.mono:
+ audio = self.convert_mono(audio)
+
+ if sampling_rate:
+ audio = self.__class__.resample_audio(
+ audio, sampling_rate, self.sampling_rate
+ )
+ if self.return_tensor:
+ return torch.tensor(audio)
+ else:
+ return audio
+
+ @staticmethod
+ def convert_mono(audio: Union[np.ndarray, torch.Tensor]):
+ """
+        convert input audio to mono (single channel)
+ parameters:
+ audio: np.ndarray or torch.Tensor
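+
+        example (illustrative):
+            >>> Audio.convert_mono(np.random.rand(2, 16000)).shape
+            (1, 16000)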
+ """
+ if len(audio.shape) > 2:
+ assert (
+ audio.shape[0] == 1
+ ), "convert mono only accepts single waveform"
+ audio = audio.reshape(audio.shape[1], audio.shape[2])
+
+ assert (
+            audio.shape[1] > audio.shape[0]
+ ), f"expected input format (num_channels,num_samples) got {audio.shape}"
+ num_channels, num_samples = audio.shape
+ if num_channels > 1:
+ return audio.mean(axis=0).reshape(1, num_samples)
+ return audio
+
+ @staticmethod
+ def resample_audio(
+ audio: Union[np.ndarray, torch.Tensor], sr: int, target_sr: int
+ ):
+ """
+ resample audio to desired sampling rate
+ parameters:
+ audio : Path to audio file or numpy array or torch tensor
+ audio waveform
+ sr : int
+ current sampling rate
+ target_sr : int
+ target sampling rate
+
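+        example (illustrative):
+            >>> Audio.resample_audio(np.random.rand(1, 32000), 16000, 8000).shape
+            (1, 16000)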
+ """
+ if sr != target_sr:
+ if isinstance(audio, np.ndarray):
+ audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sr)
+ elif isinstance(audio, torch.Tensor):
+ audio = torchaudio.functional.resample(
+ audio, orig_freq=sr, new_freq=target_sr
+ )
+ else:
+ raise ValueError(
+ "Input should be either numpy array or torch tensor"
+ )
+
+ return audio
diff --git a/enhancer/utils/random.py b/enhancer/utils/random.py
new file mode 100644
index 0000000..dd9395a
--- /dev/null
+++ b/enhancer/utils/random.py
@@ -0,0 +1,36 @@
+import os
+import random
+
+import torch
+
+
+def create_unique_rng(epoch: int):
+ """create unique random number generator for each (worker_id,epoch) combination"""
+
+ rng = random.Random()
+
+ global_seed = int(os.environ.get("PL_GLOBAL_SEED", "0"))
+ global_rank = int(os.environ.get("GLOBAL_RANK", "0"))
+ local_rank = int(os.environ.get("LOCAL_RANK", "0"))
+ node_rank = int(os.environ.get("NODE_RANK", "0"))
+ world_size = int(os.environ.get("WORLD_SIZE", "0"))
+
+ worker_info = torch.utils.data.get_worker_info()
+ if worker_info is not None:
+ num_workers = worker_info.num_workers
+ worker_id = worker_info.id
+ else:
+ num_workers = 1
+ worker_id = 0
+
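+    # combine global seed, dataloader worker id, process ranks and epoch so that
+    # every (worker, rank, epoch) combination draws from an independent stream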
+ seed = (
+ global_seed
+ + worker_id
+ + local_rank * num_workers
+ + node_rank * num_workers * global_rank
+ + epoch * num_workers * world_size
+ )
+
+ rng.seed(seed)
+
+ return rng
diff --git a/enhancer/utils/transforms.py b/enhancer/utils/transforms.py
new file mode 100644
index 0000000..5af1f92
--- /dev/null
+++ b/enhancer/utils/transforms.py
@@ -0,0 +1,93 @@
+from typing import Optional
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from scipy.signal import get_window
+from torch import nn
+
+
+class ConvFFT(nn.Module):
+ def __init__(
+ self,
+ window_len: int,
+ nfft: Optional[int] = None,
+ window: str = "hamming",
+ ):
+ super().__init__()
+ self.window_len = window_len
+        self.nfft = nfft if nfft else int(2 ** np.ceil(np.log2(window_len)))
+ self.window = torch.from_numpy(
+ get_window(window, window_len, fftbins=True).astype("float32")
+ )
+
+ def init_kernel(self, inverse=False):
+
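+        # build a real-valued conv kernel from the DFT basis: the first half of
+        # the output channels carries the real part and the second half the
+        # imaginary part; the pseudo-inverse of that basis gives the synthesis
+        # (inverse STFT) kernel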
+ fourier_basis = np.fft.rfft(np.eye(self.nfft))[: self.window_len]
+ real, imag = np.real(fourier_basis), np.imag(fourier_basis)
+ kernel = np.concatenate([real, imag], 1).T
+ if inverse:
+ kernel = np.linalg.pinv(kernel).T
+ kernel = torch.from_numpy(kernel.astype("float32")).unsqueeze(1)
+ kernel *= self.window
+ return kernel
+
+
+class ConvSTFT(ConvFFT):
+ def __init__(
+ self,
+ window_len: int,
+ hop_size: Optional[int] = None,
+ nfft: Optional[int] = None,
+ window: str = "hamming",
+ ):
+ super().__init__(window_len=window_len, nfft=nfft, window=window)
+ self.hop_size = hop_size if hop_size else window_len // 2
+ self.register_buffer("weight", self.init_kernel())
+
+ def forward(self, input):
+
+ if input.dim() < 2:
+ raise ValueError(
+                f"Expected a 2D or 3D input signal, got {input.dim()} dimensions"
+ )
+ elif input.dim() == 2:
+ input = input.unsqueeze(1)
+ input = F.pad(
+ input,
+ (self.window_len - self.hop_size, self.window_len - self.hop_size),
+ )
+ output = F.conv1d(input, self.weight, stride=self.hop_size)
+
+ return output
+
+
+class ConviSTFT(ConvFFT):
+ def __init__(
+ self,
+ window_len: int,
+ hop_size: Optional[int] = None,
+ nfft: Optional[int] = None,
+ window: str = "hamming",
+ ):
+ super().__init__(window_len=window_len, nfft=nfft, window=window)
+ self.hop_size = hop_size if hop_size else window_len // 2
+ self.register_buffer("weight", self.init_kernel(True))
+ self.register_buffer("enframe", torch.eye(window_len).unsqueeze(1))
+
+ def forward(self, input, phase=None):
+
+ if phase is not None:
+ real = input * torch.cos(phase)
+ imag = input * torch.sin(phase)
+ input = torch.cat([real, imag], 1)
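+        # transposed convolution performs overlap-add synthesis; `coeff`
+        # accumulates the squared analysis window so overlapping frames can be
+        # normalised before trimming the padding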
+ out = F.conv_transpose1d(input, self.weight, stride=self.hop_size)
+ coeff = self.window.unsqueeze(1).repeat(1, 1, input.size(-1)) ** 2
+ coeff = coeff.to(input.device)
+ coeff = F.conv_transpose1d(coeff, self.enframe, stride=self.hop_size)
+ out = out / (coeff + 1e-8)
+ pad = self.window_len - self.hop_size
+ out = out[..., pad:-pad]
+ return out
diff --git a/enhancer/utils/utils.py b/enhancer/utils/utils.py
new file mode 100644
index 0000000..ad45139
--- /dev/null
+++ b/enhancer/utils/utils.py
@@ -0,0 +1,27 @@
+import os
+from typing import Optional
+
+from enhancer.utils.config import Files
+
+
+def check_files(root_dir: str, files: Files):
+
+ path_variables = [
+ member_var
+ for member_var in dir(files)
+ if not member_var.startswith("__")
+ ]
+ for variable in path_variables:
+ path = getattr(files, variable)
+ if not os.path.isdir(os.path.join(root_dir, path)):
+            raise ValueError(f"Invalid path {path}: not a directory")
+
+ return files, root_dir
+
+
+def merge_dict(default_dict: dict, custom: Optional[dict] = None):
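+    """Return a copy of `default_dict`, updated with the entries in `custom`.
+
+    e.g. merge_dict({"depth": 5, "glu": True}, {"depth": 3})
+    -> {"depth": 3, "glu": True}
+    """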
+
+ params = dict(default_dict)
+ if custom:
+ params.update(custom)
+ return params
diff --git a/enhancer/version.py b/enhancer/version.py
new file mode 100644
index 0000000..f102a9c
--- /dev/null
+++ b/enhancer/version.py
@@ -0,0 +1 @@
+__version__ = "0.0.1"
diff --git a/environment.yml b/environment.yml
new file mode 100644
index 0000000..8da22e1
--- /dev/null
+++ b/environment.yml
@@ -0,0 +1,8 @@
+name: enhancer
+
+dependencies:
+ - pip=21.0.1
+ - python=3.8
+ - pip:
+ - -r requirements.txt
+ - --find-links https://download.pytorch.org/whl/cu113/torch_stable.html
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..b3e5d7c
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,15 @@
+[tool.black]
+line-length = 80
+target-version = ['py38']
+exclude = '''
+
+(
+ /(
+ \.eggs # exclude a few common directories in the
+ | \.git # root of the project
+ | \.mypy_cache
+ | \.tox
+ | \.venv
+ )/
+)
+'''
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..fb54920
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,19 @@
+boto3>=1.24.86
+huggingface-hub>=0.10.0
+hydra-core>=1.2.0
+joblib>=1.2.0
+librosa>=0.9.2
+mlflow>=1.29.0
+numpy>=1.23.3
+pesq==0.0.4
+protobuf>=3.19.6
+pystoi==0.3.3
+pytest-lazy-fixture>=0.6.3
+pytorch-lightning>=1.7.7
+scikit-learn>=1.1.2
+scipy>=1.9.1
+soundfile>=0.11.0
+torch>=1.12.1
+torch-audiomentations==0.11.0
+torchaudio>=0.12.1
+tqdm>=4.64.1
diff --git a/setup.cfg b/setup.cfg
new file mode 100644
index 0000000..309ac9a
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,100 @@
+# This file is used to configure your project.
+# Read more about the various options under:
+# http://setuptools.readthedocs.io/en/latest/setuptools.html#configuring-setup-using-setup-cfg-files
+
+[metadata]
+name = enhancer
+description = Deep learning for speech enhancement
+author = Shahul Ess
+author-email = shahules786@gmail.com
+license = mit
+long-description = file: README.md
+long-description-content-type = text/markdown; charset=UTF-8; variant=GFM
+# Change if running only on Windows, Mac or Linux (comma-separated)
+platforms = Linux, Mac
+# Add here all kinds of additional classifiers as defined under
+# https://pypi.python.org/pypi?%3Aaction=list_classifiers
+classifiers =
+ Development Status :: 4 - Beta
+ Programming Language :: Python
+
+[options]
+zip_safe = False
+packages = find:
+include_package_data = True
+# DON'T CHANGE THE FOLLOWING LINE! IT WILL BE UPDATED BY PYSCAFFOLD!
+setup_requires = setuptools
+# Add here dependencies of your project (semicolon/line-separated), e.g.
+# install_requires = numpy; scipy
+# Require a specific Python version, e.g. Python 2.7 or >= 3.4
+python_requires = >=3.8
+
+[options.packages.find]
+where = .
+exclude =
+ tests
+
+[options.extras_require]
+# Add here additional requirements for extra features, to install with:
+# `pip install fastaudio[PDF]` like:
+# PDF = ReportLab; RXP
+# Add here test requirements (semicolon/line-separated)
+testing =
+ pytest>=7.1.3
+ pytest-cov>=4.0.0
+dev =
+ pre-commit>=2.20.0
+ black>=22.8.0
+ flake8>=5.0.4
+cli =
+ hydra-core >=1.1,<=1.2
+
+
+[options.entry_points]
+
+console_scripts =
+ enhancer-train=enhancer.cli.train:train
+
+[test]
+# py.test options when running `python setup.py test`
+# addopts = --verbose
+extras = True
+
+[tool:pytest]
+# Options for py.test:
+# Specify command line options as you would do when invoking py.test directly.
+# e.g. --cov-report html (or xml) for html/xml output or --junitxml junit.xml
+# in order to write a coverage file that can be read by Jenkins.
+addopts =
+ --cov enhancer --cov-report term-missing
+ --verbose
+norecursedirs =
+ dist
+ build
+ .tox
+testpaths = tests
+
+[aliases]
+dists = bdist_wheel
+
+[bdist_wheel]
+# Use this option if your package is pure-python
+universal = 1
+
+[build_sphinx]
+source_dir = doc
+build_dir = build/sphinx
+
+[devpi:upload]
+# Options for the devpi: PyPI server and packaging tool
+# VCS export must be deactivated since we are using setuptools-scm
+no-vcs = 1
+formats = bdist_wheel
+
+[flake8]
+# Some sane defaults for the code style checker flake8
+exclude =
+ .tox
+ build
+ dist
+ .eggs
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..79282b3
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,63 @@
+import os
+import sys
+from pathlib import Path
+
+from pkg_resources import VersionConflict, require
+from setuptools import find_packages, setup
+
+with open("README.md") as f:
+ long_description = f.read()
+
+with open("requirements.txt") as f:
+ requirements = f.read().splitlines()
+
+try:
+ require("setuptools>=38.3")
+except VersionConflict:
+ print("Error: version of setuptools is too old (<38.3)!")
+ sys.exit(1)
+
+
+ROOT_DIR = Path(__file__).parent.resolve()
+# Creating the version file
+
+with open("version.txt") as f:
+ version = f.read()
+
+version = version.strip()
+sha = "Unknown"
+
+if os.getenv("BUILD_VERSION"):
+ version = os.getenv("BUILD_VERSION")
+elif sha != "Unknown":
+ version += "+" + sha[:7]
+print("-- Building version " + version)
+
+version_path = ROOT_DIR / "enhancer" / "version.py"
+
+with open(version_path, "w") as f:
+ f.write("__version__ = '{}'\n".format(version))
+
+if __name__ == "__main__":
+ setup(
+ name="enhancer",
+ namespace_packages=["enhancer"],
+ version=version,
+ packages=find_packages(),
+ install_requires=requirements,
+ description="Deep learning toolkit for speech enhancement",
+ long_description=long_description,
+ long_description_content_type="text/markdown",
+ author="Shahul Es",
+ author_email="shahules786@gmail.com",
+ url="",
+ classifiers=[
+ "Development Status :: 4 - Beta",
+ "Intended Audience :: Science/Research",
+ "License :: OSI Approved :: MIT License",
+ "Natural Language :: English",
+ "Programming Language :: Python :: 3.7",
+ "Programming Language :: Python :: 3.8",
+ "Topic :: Scientific/Engineering",
+ ],
+ )
diff --git a/tests/data/vctk/clean_testset_wav/p257_166.wav b/tests/data/vctk/clean_testset_wav/p257_166.wav
new file mode 100644
index 0000000..932df27
Binary files /dev/null and b/tests/data/vctk/clean_testset_wav/p257_166.wav differ
diff --git a/tests/data/vctk/clean_testset_wav/p257_167.wav b/tests/data/vctk/clean_testset_wav/p257_167.wav
new file mode 100644
index 0000000..6b72b76
Binary files /dev/null and b/tests/data/vctk/clean_testset_wav/p257_167.wav differ
diff --git a/tests/data/vctk/noisy_testset_wav/p257_166.wav b/tests/data/vctk/noisy_testset_wav/p257_166.wav
new file mode 100644
index 0000000..139b7cb
Binary files /dev/null and b/tests/data/vctk/noisy_testset_wav/p257_166.wav differ
diff --git a/tests/data/vctk/noisy_testset_wav/p257_167.wav b/tests/data/vctk/noisy_testset_wav/p257_167.wav
new file mode 100644
index 0000000..59788b7
Binary files /dev/null and b/tests/data/vctk/noisy_testset_wav/p257_167.wav differ
diff --git a/tests/loss_function_test.py b/tests/loss_function_test.py
new file mode 100644
index 0000000..4d14871
--- /dev/null
+++ b/tests/loss_function_test.py
@@ -0,0 +1,32 @@
+import pytest
+import torch
+
+from enhancer.loss import mean_absolute_error, mean_squared_error
+
+loss_functions = [mean_absolute_error(), mean_squared_error()]
+
+
+def check_loss_shapes_compatibility(loss_fun):
+
+ batch_size = 4
+ shape = (1, 1000)
+ loss_fun(torch.rand(batch_size, *shape), torch.rand(batch_size, *shape))
+
+ with pytest.raises(TypeError):
+ loss_fun(torch.rand(4, *shape), torch.rand(6, *shape))
+
+
+@pytest.mark.parametrize("loss", loss_functions)
+def test_loss_input_shapes(loss):
+ check_loss_shapes_compatibility(loss)
+
+
+@pytest.mark.parametrize("loss", loss_functions)
+def test_loss_output_type(loss):
+
+ batch_size = 4
+ prediction, target = torch.rand(batch_size, 1, 1000), torch.rand(
+ batch_size, 1, 1000
+ )
+ loss_value = loss(prediction, target)
+ assert isinstance(loss_value.item(), float)
diff --git a/tests/models/complexnn_test.py b/tests/models/complexnn_test.py
new file mode 100644
index 0000000..524a6cf
--- /dev/null
+++ b/tests/models/complexnn_test.py
@@ -0,0 +1,50 @@
+import torch
+
+from enhancer.models.complexnn.conv import ComplexConv2d, ComplexConvTranspose2d
+from enhancer.models.complexnn.rnn import ComplexLSTM
+from enhancer.models.complexnn.utils import ComplexBatchNorm2D
+
+
+def test_complexconv2d():
+ sample_input = torch.rand(1, 2, 256, 13)
+ conv = ComplexConv2d(
+ 2, 32, kernel_size=(5, 2), stride=(2, 1), padding=(2, 1)
+ )
+ with torch.no_grad():
+ out = conv(sample_input)
+ assert out.shape == torch.Size([1, 32, 128, 13])
+
+
+def test_complexconvtranspose2d():
+ sample_input = torch.rand(1, 512, 4, 13)
+ conv = ComplexConvTranspose2d(
+ 256 * 2,
+ 128 * 2,
+ kernel_size=(5, 2),
+ stride=(2, 1),
+ padding=(2, 0),
+ output_padding=(1, 0),
+ )
+ with torch.no_grad():
+ out = conv(sample_input)
+
+ assert out.shape == torch.Size([1, 256, 8, 14])
+
+
+def test_complexlstm():
+ sample_input = torch.rand(13, 2, 128)
+ lstm = ComplexLSTM(128 * 2, 128 * 2, projection_size=512 * 2)
+ with torch.no_grad():
+ out = lstm(sample_input)
+
+ assert out[0].shape == torch.Size([13, 1, 512])
+ assert out[1].shape == torch.Size([13, 1, 512])
+
+
+def test_complexbatchnorm2d():
+ sample_input = torch.rand(1, 64, 64, 14)
+ batchnorm = ComplexBatchNorm2D(num_features=64)
+ with torch.no_grad():
+ out = batchnorm(sample_input)
+
+ assert out.size() == sample_input.size()
diff --git a/tests/models/demucs_test.py b/tests/models/demucs_test.py
new file mode 100644
index 0000000..29e030e
--- /dev/null
+++ b/tests/models/demucs_test.py
@@ -0,0 +1,43 @@
+import pytest
+import torch
+
+from enhancer.data.dataset import EnhancerDataset
+from enhancer.models import Demucs
+from enhancer.utils.config import Files
+
+
+@pytest.fixture
+def vctk_dataset():
+ root_dir = "tests/data/vctk"
+ files = Files(
+ train_clean="clean_testset_wav",
+ train_noisy="noisy_testset_wav",
+ test_clean="clean_testset_wav",
+ test_noisy="noisy_testset_wav",
+ )
+ dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files)
+ return dataset
+
+
+@pytest.mark.parametrize("batch_size,samples", [(1, 1000)])
+def test_forward(batch_size, samples):
+ model = Demucs()
+ model.eval()
+
+ data = torch.rand(batch_size, 1, samples, requires_grad=False)
+ with torch.no_grad():
+ _ = model(data)
+
+ data = torch.rand(batch_size, 2, samples, requires_grad=False)
+ with torch.no_grad():
+ with pytest.raises(ValueError):
+ _ = model(data)
+
+
+@pytest.mark.parametrize(
+ "dataset,channels,loss",
+ [(pytest.lazy_fixture("vctk_dataset"), 1, ["mae", "mse"])],
+)
+def test_demucs_init(dataset, channels, loss):
+ with torch.no_grad():
+ _ = Demucs(num_channels=channels, dataset=dataset, loss=loss)
diff --git a/tests/models/test_dccrn.py b/tests/models/test_dccrn.py
new file mode 100644
index 0000000..96a853b
--- /dev/null
+++ b/tests/models/test_dccrn.py
@@ -0,0 +1,43 @@
+import pytest
+import torch
+
+from enhancer.data.dataset import EnhancerDataset
+from enhancer.models.dccrn import DCCRN
+from enhancer.utils.config import Files
+
+
+@pytest.fixture
+def vctk_dataset():
+ root_dir = "tests/data/vctk"
+ files = Files(
+ train_clean="clean_testset_wav",
+ train_noisy="noisy_testset_wav",
+ test_clean="clean_testset_wav",
+ test_noisy="noisy_testset_wav",
+ )
+ dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files)
+ return dataset
+
+
+@pytest.mark.parametrize("batch_size,samples", [(1, 1000)])
+def test_forward(batch_size, samples):
+ model = DCCRN()
+ model.eval()
+
+ data = torch.rand(batch_size, 1, samples, requires_grad=False)
+ with torch.no_grad():
+ _ = model(data)
+
+ data = torch.rand(batch_size, 2, samples, requires_grad=False)
+ with torch.no_grad():
+ with pytest.raises(ValueError):
+ _ = model(data)
+
+
+@pytest.mark.parametrize(
+ "dataset,channels,loss",
+ [(pytest.lazy_fixture("vctk_dataset"), 1, ["mae", "mse"])],
+)
+def test_dccrn_init(dataset, channels, loss):
+ with torch.no_grad():
+ _ = DCCRN(num_channels=channels, dataset=dataset, loss=loss)
diff --git a/tests/models/test_waveunet.py b/tests/models/test_waveunet.py
new file mode 100644
index 0000000..9c4dd96
--- /dev/null
+++ b/tests/models/test_waveunet.py
@@ -0,0 +1,43 @@
+import pytest
+import torch
+
+from enhancer.data.dataset import EnhancerDataset
+from enhancer.models import WaveUnet
+from enhancer.utils.config import Files
+
+
+@pytest.fixture
+def vctk_dataset():
+ root_dir = "tests/data/vctk"
+ files = Files(
+ train_clean="clean_testset_wav",
+ train_noisy="noisy_testset_wav",
+ test_clean="clean_testset_wav",
+ test_noisy="noisy_testset_wav",
+ )
+ dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files)
+ return dataset
+
+
+@pytest.mark.parametrize("batch_size,samples", [(1, 1000)])
+def test_forward(batch_size, samples):
+ model = WaveUnet()
+ model.eval()
+
+ data = torch.rand(batch_size, 1, samples, requires_grad=False)
+ with torch.no_grad():
+ _ = model(data)
+
+ data = torch.rand(batch_size, 2, samples, requires_grad=False)
+ with torch.no_grad():
+ with pytest.raises(TypeError):
+ _ = model(data)
+
+
+@pytest.mark.parametrize(
+ "dataset,channels,loss",
+ [(pytest.lazy_fixture("vctk_dataset"), 1, ["mae", "mse"])],
+)
+def test_waveunet_init(dataset, channels, loss):
+ with torch.no_grad():
+ _ = WaveUnet(num_channels=channels, dataset=dataset, loss=loss)
diff --git a/tests/test_inference.py b/tests/test_inference.py
new file mode 100644
index 0000000..a6e2423
--- /dev/null
+++ b/tests/test_inference.py
@@ -0,0 +1,29 @@
+import pytest
+import torch
+
+from enhancer.inference import Inference
+
+
+@pytest.mark.parametrize(
+ "audio",
+ ["tests/data/vctk/clean_testset_wav/p257_166.wav", torch.rand(1, 2, 48000)],
+)
+def test_read_input(audio):
+
+ read_audio = Inference.read_input(audio, 48000, 16000)
+ assert isinstance(read_audio, torch.Tensor)
+ assert read_audio.shape[0] == 1
+
+
+def test_batchify():
+ rand = torch.rand(1, 1000)
+ batched_rand = Inference.batchify(rand, window_size=100, step_size=100)
+ assert batched_rand.shape[0] == 12
+
+
+def test_aggregate():
+ rand = torch.rand(12, 1, 100)
+ agg_rand = Inference.aggreagate(
+ data=rand, window_size=100, total_frames=1000, step_size=100
+ )
+ assert agg_rand.shape[-1] == 1000
diff --git a/tests/transforms_test.py b/tests/transforms_test.py
new file mode 100644
index 0000000..89425ad
--- /dev/null
+++ b/tests/transforms_test.py
@@ -0,0 +1,18 @@
+import torch
+
+from enhancer.utils.transforms import ConviSTFT, ConvSTFT
+
+
+def test_stft_istft():
+ sample_input = torch.rand(1, 1, 16000)
+ stft = ConvSTFT(window_len=400, hop_size=100, nfft=512)
+ istft = ConviSTFT(window_len=400, hop_size=100, nfft=512)
+
+ with torch.no_grad():
+ spectrogram = stft(sample_input)
+ waveform = istft(spectrogram)
+ assert sample_input.shape == waveform.shape
+ assert (
+ torch.isclose(waveform, sample_input).sum().item()
+ > sample_input.shape[-1] // 2
+ )
diff --git a/tests/utils_test.py b/tests/utils_test.py
new file mode 100644
index 0000000..cd5240c
--- /dev/null
+++ b/tests/utils_test.py
@@ -0,0 +1,49 @@
+import numpy as np
+import pytest
+import torch
+
+from enhancer.data.fileprocessor import Fileprocessor
+from enhancer.utils.io import Audio
+
+
+def test_io_channel():
+
+ input_audio = np.random.rand(2, 32000)
+ audio = Audio(mono=True, return_tensor=False)
+ output_audio = audio(input_audio)
+ assert output_audio.shape[0] == 1
+
+
+def test_io_resampling():
+
+ input_audio = np.random.rand(1, 32000)
+ resampled_audio = Audio.resample_audio(input_audio, 16000, 8000)
+
+ input_audio = torch.rand(1, 32000)
+ resampled_audio_pt = Audio.resample_audio(input_audio, 16000, 8000)
+
+ assert resampled_audio.shape[1] == resampled_audio_pt.size(1) == 16000
+
+
+def test_fileprocessor_vctk():
+
+ fp = Fileprocessor.from_name(
+ "vctk",
+ "tests/data/vctk/clean_testset_wav",
+ "tests/data/vctk/noisy_testset_wav",
+ )
+ matching_dict = fp.prepare_matching_dict()
+ assert len(matching_dict) == 2
+
+
+@pytest.mark.parametrize("dataset_name", ["vctk", "dns-2020"])
+def test_fileprocessor_names(dataset_name):
+ fp = Fileprocessor.from_name(dataset_name, "clean_dir", "noisy_dir")
+ assert hasattr(fp.matching_function, "__call__")
+
+
+def test_fileprocessor_invalid_name():
+ with pytest.raises(ValueError):
+ _ = Fileprocessor.from_name(
+ "undefined", "clean_dir", "noisy_dir", 16000
+ ).prepare_matching_dict()
diff --git a/version.txt b/version.txt
new file mode 100644
index 0000000..8acdd82
--- /dev/null
+++ b/version.txt
@@ -0,0 +1 @@
+0.0.1