diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..abbbc73 --- /dev/null +++ b/.flake8 @@ -0,0 +1,9 @@ +[flake8] +per-file-ignores = __init__.py:F401 +ignore = E203, E266, E501, W503 +# line length is intentionally set to 80 here because black uses Bugbear +# See https://github.com/psf/black/blob/master/README.md#line-length for more details +max-line-length = 80 +max-complexity = 18 +select = B,C,E,F,W,T4,B9 +exclude = tools/kaldi_decoder diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml new file mode 100644 index 0000000..4c64745 --- /dev/null +++ b/.github/workflows/ci.yaml @@ -0,0 +1,51 @@ +# This workflow will install Python dependencies, run tests and lint with a variety of Python versions +# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions + +name: Enhancer + +on: + push: + branches: [ dev ] + pull_request: + branches: [ dev ] +jobs: + build: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: [3.8] + steps: + - uses: actions/checkout@v2 + - name: Set up Python + uses: actions/setup-python@v1.1.1 + env : + ACTIONS_ALLOW_UNSECURE_COMMANDS : true + with: + python-version: ${{ matrix.python-version }} + - name: Cache pip + uses: actions/cache@v1 + with: + path: ~/.cache/pip # This path is specific to Ubuntu + # Look to see if there is a cache hit for the corresponding requirements file + key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }} + restore-keys: | + ${{ runner.os }}-pip- + ${{ runner.os }}- + # You can test your matrix by printing the current Python version + - name: Display Python version + run: python -c "import sys; print(sys.version)" + - name: Install dependencies + run: | + python -m pip install --upgrade pip + sudo apt-get install libsndfile1 + pip install -r requirements.txt + pip install black pytest-cov + - name: Install enhancer + run: | + pip install -e .[dev,testing] + - name: Run black + run: + black --check . 
--exclude enhancer/version.py + - name: Test with pytest + run: + pytest tests --cov=enhancer/ diff --git a/.gitignore b/.gitignore index b6e4761..cd1b1e9 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,10 @@ +#local +*.ckpt +*_local.yaml +cli/train_config/dataset/Vctk_local.yaml +.DS_Store +outputs/ +datasets/ # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..807429c --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,43 @@ + +repos: + # # Clean Notebooks + # - repo: https://github.com/kynan/nbstripout + # rev: master + # hooks: + # - id: nbstripout + # Format Code + - repo: https://github.com/ambv/black + rev: 22.8.0 + hooks: + - id: black + + # Sort imports + - repo: https://github.com/PyCQA/isort + rev: 5.10.1 + hooks: + - id: isort + args: ["--profile", "black"] + + - repo: https://gitlab.com/pycqa/flake8 + rev: 5.0.4 + hooks: + - id: flake8 + args: ['--ignore=E203,E501,F811,E712,W503'] + + # Formatting, Whitespace, etc + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v3.2.0 + hooks: + - id: trailing-whitespace + - id: check-added-large-files + args: ['--maxkb=1000'] + - id: check-ast + - id: check-json + - id: check-merge-conflict + - id: check-xml + - id: check-yaml + - id: debug-statements + - id: end-of-file-fixer + - id: requirements-txt-fixer + - id: mixed-line-ending + args: ['--fix=no'] diff --git a/README.md b/README.md index e462afa..13b8e14 100644 --- a/README.md +++ b/README.md @@ -1 +1,43 @@ -# enhancer \ No newline at end of file +

+ +

+ +mayavoz is a PyTorch-based open-source toolkit for speech enhancement. It is designed to save time for audio researchers. It provides easy-to-use pretrained audio enhancement models and facilitates highly customisable model training. + +| **[Quick Start]()** | **[Installation]()** | **[Tutorials]()** | **[Available Recipes]()** +## Key features :key: + +* Various pretrained models nicely integrated with Hugging Face :hugs: that users can select and use without any hassle. +* :package: Ability to train and validate your own custom speech enhancement models with just under 10 lines of code! +* :magic_wand: A command line tool that facilitates training of highly customisable speech enhancement models from the terminal itself! +* :zap: Supports multi-GPU training integrated with PyTorch Lightning. + +## Quick Start :fire: +``` python +from mayavoz import Mayamodel + +model = Mayamodel.from_pretrained("mayavoz/waveunet") +model("noisy_audio.wav") +``` + +## Installation +Only Python 3.8+ is officially supported (though it might work with Python 3.7) + +- With PyPI +``` +pip install mayavoz +``` + +- With conda + +``` +conda env create -f environment.yml +conda activate mayavoz +``` + +- From source code +``` +git clone url +cd mayavoz +pip install -e . +``` diff --git a/enhancer/__init__.py b/enhancer/__init__.py new file mode 100644 index 0000000..5284146 --- /dev/null +++ b/enhancer/__init__.py @@ -0,0 +1 @@ +__import__("pkg_resources").declare_namespace(__name__) diff --git a/enhancer/cli/train.py b/enhancer/cli/train.py new file mode 100644 index 0000000..c00c024 --- /dev/null +++ b/enhancer/cli/train.py @@ -0,0 +1,120 @@ +import os +from types import MethodType + +import hydra +from hydra.utils import instantiate +from omegaconf import DictConfig, OmegaConf +from pytorch_lightning.callbacks import ( + EarlyStopping, + LearningRateMonitor, + ModelCheckpoint, +) +from pytorch_lightning.loggers import MLFlowLogger +from torch.optim.lr_scheduler import ReduceLROnPlateau + +# from torch_audiomentations import Compose, Shift + +os.environ["HYDRA_FULL_ERROR"] = "1" +JOB_ID = os.environ.get("SLURM_JOBID", "0") + + +@hydra.main(config_path="train_config", config_name="config") +def main(config: DictConfig): + + OmegaConf.save(config, "config_log.yaml") + + callbacks = [] + logger = MLFlowLogger( + experiment_name=config.mlflow.experiment_name, + run_name=config.mlflow.run_name, + tags={"JOB_ID": JOB_ID}, + ) + + parameters = config.hyperparameters + # apply_augmentations = Compose( + # [ + # Shift(min_shift=0.5, max_shift=1.0, shift_unit="seconds", p=0.5), + # ] + # ) + + dataset = instantiate(config.dataset, augmentations=None) + model = instantiate( + config.model, + dataset=dataset, + lr=parameters.get("lr"), + loss=parameters.get("loss"), + metric=parameters.get("metric"), + ) + + direction = model.valid_monitor + checkpoint = ModelCheckpoint( + dirpath="./model", + filename=f"model_{JOB_ID}", + monitor="valid_loss", + verbose=False, + mode=direction, + every_n_epochs=1, + ) + callbacks.append(checkpoint) + callbacks.append(LearningRateMonitor(logging_interval="epoch")) + + if parameters.get("Early_stop", False): + early_stopping = EarlyStopping( + monitor="valid_loss", + mode=direction, + min_delta=0.0, + patience=parameters.get("EarlyStopping_patience", 10), + strict=True, + verbose=False, + ) + callbacks.append(early_stopping) + + def configure_optimizers(self): + optimizer = instantiate( + config.optimizer, + lr=parameters.get("lr"), + params=self.parameters(), + ) + scheduler = 
ReduceLROnPlateau( + optimizer=optimizer, + mode=direction, + factor=parameters.get("ReduceLr_factor", 0.1), + verbose=True, + min_lr=parameters.get("min_lr", 1e-6), + patience=parameters.get("ReduceLr_patience", 3), + ) + return { + "optimizer": optimizer, + "lr_scheduler": scheduler, + "monitor": f'valid_{parameters.get("ReduceLr_monitor", "loss")}', + } + + model.configure_optimizers = MethodType(configure_optimizers, model) + + trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks) + trainer.fit(model) + trainer.test(model) + + logger.experiment.log_artifact( + logger.run_id, f"{trainer.default_root_dir}/config_log.yaml" + ) + + saved_location = os.path.join( + trainer.default_root_dir, "model", f"model_{JOB_ID}.ckpt" + ) + if os.path.isfile(saved_location): + logger.experiment.log_artifact(logger.run_id, saved_location) + logger.experiment.log_param( + logger.run_id, + "num_train_steps_per_epoch", + dataset.train__len__() / dataset.batch_size, + ) + logger.experiment.log_param( + logger.run_id, + "num_valid_steps_per_epoch", + dataset.val__len__() / dataset.batch_size, + ) + + +if __name__ == "__main__": + main() diff --git a/enhancer/cli/train_config/config.yaml b/enhancer/cli/train_config/config.yaml new file mode 100644 index 0000000..8d0ab14 --- /dev/null +++ b/enhancer/cli/train_config/config.yaml @@ -0,0 +1,7 @@ +defaults: + - model : Demucs + - dataset : Vctk + - optimizer : Adam + - hyperparameters : default + - trainer : default + - mlflow : experiment diff --git a/enhancer/cli/train_config/dataset/DNS-2020.yaml b/enhancer/cli/train_config/dataset/DNS-2020.yaml new file mode 100644 index 0000000..09a14fb --- /dev/null +++ b/enhancer/cli/train_config/dataset/DNS-2020.yaml @@ -0,0 +1,12 @@ +_target_: enhancer.data.dataset.EnhancerDataset +root_dir : /Users/shahules/Myprojects/MS-SNSD +name : dns-2020 +duration : 2.0 +sampling_rate: 16000 +batch_size: 32 +valid_size: 0.05 +files: + train_clean : CleanSpeech_training + test_clean : CleanSpeech_training + train_noisy : NoisySpeech_training + test_noisy : NoisySpeech_training diff --git a/enhancer/cli/train_config/dataset/Vctk.yaml b/enhancer/cli/train_config/dataset/Vctk.yaml new file mode 100644 index 0000000..c33d29a --- /dev/null +++ b/enhancer/cli/train_config/dataset/Vctk.yaml @@ -0,0 +1,13 @@ +_target_: enhancer.data.dataset.EnhancerDataset +name : vctk +root_dir : /scratch/c.sistc3/DS_10283_2791 +duration : 4.5 +stride : 2 +sampling_rate: 16000 +batch_size: 32 +valid_minutes : 15 +files: + train_clean : clean_trainset_28spk_wav + test_clean : clean_testset_wav + train_noisy : noisy_trainset_28spk_wav + test_noisy : noisy_testset_wav diff --git a/enhancer/cli/train_config/dataset/Vctk_local.yaml b/enhancer/cli/train_config/dataset/Vctk_local.yaml new file mode 100644 index 0000000..ba44597 --- /dev/null +++ b/enhancer/cli/train_config/dataset/Vctk_local.yaml @@ -0,0 +1,13 @@ +_target_: enhancer.data.dataset.EnhancerDataset +name : vctk +root_dir : /Users/shahules/Myprojects/enhancer/datasets/vctk +duration : 1.0 +sampling_rate: 16000 +batch_size: 64 +num_workers : 0 + +files: + train_clean : clean_testset_wav + test_clean : clean_testset_wav + train_noisy : noisy_testset_wav + test_noisy : noisy_testset_wav diff --git a/enhancer/cli/train_config/hyperparameters/default.yaml b/enhancer/cli/train_config/hyperparameters/default.yaml new file mode 100644 index 0000000..1782ea9 --- /dev/null +++ b/enhancer/cli/train_config/hyperparameters/default.yaml @@ -0,0 +1,7 @@ +loss : mae +metric : [stoi,pesq,si-sdr] +lr 
: 0.0003 +ReduceLr_patience : 5 +ReduceLr_factor : 0.2 +min_lr : 0.000001 +EarlyStopping_factor : 10 diff --git a/enhancer/cli/train_config/mlflow/experiment.yaml b/enhancer/cli/train_config/mlflow/experiment.yaml new file mode 100644 index 0000000..d597333 --- /dev/null +++ b/enhancer/cli/train_config/mlflow/experiment.yaml @@ -0,0 +1,2 @@ +experiment_name : shahules/enhancer +run_name : Demucs + Vtck with stride + augmentations diff --git a/enhancer/cli/train_config/model/DCCRN.yaml b/enhancer/cli/train_config/model/DCCRN.yaml new file mode 100644 index 0000000..3190391 --- /dev/null +++ b/enhancer/cli/train_config/model/DCCRN.yaml @@ -0,0 +1,25 @@ +_target_: enhancer.models.dccrn.DCCRN +num_channels: 1 +sampling_rate : 16000 +complex_lstm : True +complex_norm : True +complex_relu : True +masking_mode : True + +encoder_decoder: + initial_output_channels : 32 + depth : 6 + kernel_size : 5 + growth_factor : 2 + stride : 2 + padding : 2 + output_padding : 1 + +lstm: + num_layers : 2 + hidden_size : 256 + +stft: + window_len : 400 + hop_size : 100 + nfft : 512 diff --git a/enhancer/cli/train_config/model/Demucs.yaml b/enhancer/cli/train_config/model/Demucs.yaml new file mode 100644 index 0000000..513e603 --- /dev/null +++ b/enhancer/cli/train_config/model/Demucs.yaml @@ -0,0 +1,16 @@ +_target_: enhancer.models.demucs.Demucs +num_channels: 1 +resample: 4 +sampling_rate : 16000 + +encoder_decoder: + depth: 4 + initial_output_channels: 64 + kernel_size: 8 + stride: 4 + growth_factor: 2 + glu: True + +lstm: + bidirectional: False + num_layers: 2 diff --git a/enhancer/cli/train_config/model/WaveUnet.yaml b/enhancer/cli/train_config/model/WaveUnet.yaml new file mode 100644 index 0000000..29d48c7 --- /dev/null +++ b/enhancer/cli/train_config/model/WaveUnet.yaml @@ -0,0 +1,5 @@ +_target_: enhancer.models.waveunet.WaveUnet +num_channels : 1 +depth : 9 +initial_output_channels: 24 +sampling_rate : 16000 diff --git a/enhancer/cli/train_config/optimizer/Adam.yaml b/enhancer/cli/train_config/optimizer/Adam.yaml new file mode 100644 index 0000000..7952b81 --- /dev/null +++ b/enhancer/cli/train_config/optimizer/Adam.yaml @@ -0,0 +1,6 @@ +_target_: torch.optim.Adam +lr: 1e-3 +betas: [0.9, 0.999] +eps: 1e-08 +weight_decay: 0 +amsgrad: False diff --git a/enhancer/cli/train_config/trainer/default.yaml b/enhancer/cli/train_config/trainer/default.yaml new file mode 100644 index 0000000..958c418 --- /dev/null +++ b/enhancer/cli/train_config/trainer/default.yaml @@ -0,0 +1,46 @@ +_target_: pytorch_lightning.Trainer +accelerator: gpu +accumulate_grad_batches: 1 +amp_backend: native +auto_lr_find: True +auto_scale_batch_size: False +auto_select_gpus: True +benchmark: False +check_val_every_n_epoch: 1 +detect_anomaly: False +deterministic: False +devices: 2 +enable_checkpointing: True +enable_model_summary: True +enable_progress_bar: True +fast_dev_run: False +gpus: null +gradient_clip_val: 0 +gradient_clip_algorithm: norm +ipus: null +limit_predict_batches: 1.0 +limit_test_batches: 1.0 +limit_train_batches: 1.0 +limit_val_batches: 1.0 +log_every_n_steps: 50 +max_epochs: 200 +max_steps: -1 +max_time: null +min_epochs: 1 +min_steps: null +move_metrics_to_cpu: False +multiple_trainloader_mode: max_size_cycle +num_nodes: 1 +num_processes: 1 +num_sanity_val_steps: 2 +overfit_batches: 0.0 +precision: 32 +profiler: null +reload_dataloaders_every_n_epochs: 0 +replace_sampler_ddp: True +strategy: ddp +sync_batchnorm: False +tpu_cores: null +track_grad_norm: -1 +val_check_interval: 1.0 +weights_save_path: null diff --git 
a/enhancer/cli/train_config/trainer/fastrun_dev.yaml b/enhancer/cli/train_config/trainer/fastrun_dev.yaml new file mode 100644 index 0000000..682149e --- /dev/null +++ b/enhancer/cli/train_config/trainer/fastrun_dev.yaml @@ -0,0 +1,2 @@ +_target_: pytorch_lightning.Trainer +fast_dev_run: True diff --git a/enhancer/data/__init__.py b/enhancer/data/__init__.py new file mode 100644 index 0000000..7efd946 --- /dev/null +++ b/enhancer/data/__init__.py @@ -0,0 +1 @@ +from enhancer.data.dataset import EnhancerDataset diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py new file mode 100644 index 0000000..284dfdb --- /dev/null +++ b/enhancer/data/dataset.py @@ -0,0 +1,376 @@ +import math +import multiprocessing +import os +from pathlib import Path +from typing import Optional + +import numpy as np +import pytorch_lightning as pl +import torch +import torch.nn.functional as F +from torch.utils.data import DataLoader, Dataset, RandomSampler +from torch_audiomentations import Compose + +from enhancer.data.fileprocessor import Fileprocessor +from enhancer.utils import check_files +from enhancer.utils.config import Files +from enhancer.utils.io import Audio +from enhancer.utils.random import create_unique_rng + +LARGE_NUM = 2147483647 + + +class TrainDataset(Dataset): + def __init__(self, dataset): + self.dataset = dataset + + def __getitem__(self, idx): + return self.dataset.train__getitem__(idx) + + def __len__(self): + return self.dataset.train__len__() + + +class ValidDataset(Dataset): + def __init__(self, dataset): + self.dataset = dataset + + def __getitem__(self, idx): + return self.dataset.val__getitem__(idx) + + def __len__(self): + return self.dataset.val__len__() + + +class TestDataset(Dataset): + def __init__(self, dataset): + self.dataset = dataset + + def __getitem__(self, idx): + return self.dataset.test__getitem__(idx) + + def __len__(self): + return self.dataset.test__len__() + + +class TaskDataset(pl.LightningDataModule): + def __init__( + self, + name: str, + root_dir: str, + files: Files, + min_valid_minutes: float = 0.20, + duration: float = 1.0, + stride=None, + sampling_rate: int = 48000, + matching_function=None, + batch_size=32, + num_workers: Optional[int] = None, + augmentations: Optional[Compose] = None, + ): + super().__init__() + + self.name = name + self.files, self.root_dir = check_files(root_dir, files) + self.duration = duration + self.stride = stride or duration + self.sampling_rate = sampling_rate + self.batch_size = batch_size + self.matching_function = matching_function + self._validation = [] + if num_workers is None: + num_workers = multiprocessing.cpu_count() // 2 + self.num_workers = num_workers + if min_valid_minutes > 0.0: + self.min_valid_minutes = min_valid_minutes + else: + raise ValueError("min_valid_minutes must be greater than 0") + + self.augmentations = augmentations + + def setup(self, stage: Optional[str] = None): + """ + prepare train/validation/test data splits + """ + + if stage in ("fit", None): + + train_clean = os.path.join(self.root_dir, self.files.train_clean) + train_noisy = os.path.join(self.root_dir, self.files.train_noisy) + fp = Fileprocessor.from_name( + self.name, train_clean, train_noisy, self.matching_function + ) + train_data = fp.prepare_matching_dict() + train_data, self.val_data = self.train_valid_split( + train_data, + min_valid_minutes=self.min_valid_minutes, + random_state=42, + ) + + self.train_data = self.prepare_traindata(train_data) + self._validation = self.prepare_mapstype(self.val_data) + + test_clean = 
os.path.join(self.root_dir, self.files.test_clean) + test_noisy = os.path.join(self.root_dir, self.files.test_noisy) + fp = Fileprocessor.from_name( + self.name, test_clean, test_noisy, self.matching_function + ) + test_data = fp.prepare_matching_dict() + self._test = self.prepare_mapstype(test_data) + + def train_valid_split( + self, data, min_valid_minutes: float = 20, random_state: int = 42 + ): + + min_valid_minutes *= 60 + valid_sec_now = 0.0 + valid_indices = [] + all_speakers = np.unique( + [Path(file["clean"]).name.split("_")[0] for file in data] + ) + possible_indices = list(range(0, len(all_speakers))) + rng = create_unique_rng(len(all_speakers)) + + while valid_sec_now <= min_valid_minutes: + speaker_index = rng.choice(possible_indices) + possible_indices.remove(speaker_index) + speaker_name = all_speakers[speaker_index] + print(f"Selected f{speaker_name} for valid") + file_indices = [ + i + for i, file in enumerate(data) + if speaker_name == Path(file["clean"]).name.split("_")[0] + ] + for i in file_indices: + valid_indices.append(i) + valid_sec_now += data[i]["duration"] + + train_data = [ + item for i, item in enumerate(data) if i not in valid_indices + ] + valid_data = [item for i, item in enumerate(data) if i in valid_indices] + return train_data, valid_data + + def prepare_traindata(self, data): + train_data = [] + for item in data: + clean, noisy, total_dur = item.values() + num_segments = self.get_num_segments( + total_dur, self.duration, self.stride + ) + samples_metadata = ({"clean": clean, "noisy": noisy}, num_segments) + train_data.append(samples_metadata) + return train_data + + @staticmethod + def get_num_segments(file_duration, duration, stride): + + if file_duration < duration: + num_segments = 1 + else: + num_segments = math.ceil((file_duration - duration) / stride) + 1 + + return num_segments + + def prepare_mapstype(self, data): + + metadata = [] + for item in data: + clean, noisy, total_dur = item.values() + if total_dur < self.duration: + metadata.append(({"clean": clean, "noisy": noisy}, 0.0)) + else: + num_segments = self.get_num_segments( + total_dur, self.duration, self.duration + ) + for index in range(num_segments): + start_time = index * self.duration + metadata.append( + ({"clean": clean, "noisy": noisy}, start_time) + ) + return metadata + + def train_collatefn(self, batch): + + output = {"clean": [], "noisy": []} + for item in batch: + output["clean"].append(item["clean"]) + output["noisy"].append(item["noisy"]) + + output["clean"] = torch.stack(output["clean"], dim=0) + output["noisy"] = torch.stack(output["noisy"], dim=0) + + if self.augmentations is not None: + noise = output["noisy"] - output["clean"] + output["clean"] = self.augmentations( + output["clean"], sample_rate=self.sampling_rate + ) + self.augmentations.freeze_parameters() + output["noisy"] = ( + self.augmentations(noise, sample_rate=self.sampling_rate) + + output["clean"] + ) + + return output + + @property + def generator(self): + generator = torch.Generator() + if hasattr(self, "model"): + seed = self.model.current_epoch + LARGE_NUM + else: + seed = LARGE_NUM + return generator.manual_seed(seed) + + def train_dataloader(self): + dataset = TrainDataset(self) + sampler = RandomSampler(dataset, generator=self.generator) + return DataLoader( + dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + sampler=sampler, + collate_fn=self.train_collatefn, + ) + + def val_dataloader(self): + return DataLoader( + ValidDataset(self), + batch_size=self.batch_size, + 
num_workers=self.num_workers, + ) + + def test_dataloader(self): + return DataLoader( + TestDataset(self), + batch_size=self.batch_size, + num_workers=self.num_workers, + ) + + +class EnhancerDataset(TaskDataset): + """ + Dataset object for creating clean-noisy speech enhancement datasets + paramters: + name : str + name of the dataset + root_dir : str + root directory of the dataset containing clean/noisy folders + files : Files + dataclass containing train_clean, train_noisy, test_clean, test_noisy + folder names (refer enhancer.utils.Files dataclass) + min_valid_minutes: float + minimum validation split size time in minutes + algorithm randomly select n speakers (>=min_valid_minutes) from train data to form validation data. + duration : float + expected audio duration of single audio sample for training + sampling_rate : int + desired sampling rate + batch_size : int + batch size of each batch + num_workers : int + num workers to be used while training + matching_function : str + maching functions - (one_to_one,one_to_many). Default set to None. + use one_to_one mapping for datasets with one noisy file for each clean file + use one_to_many mapping for multiple noisy files for each clean file + + + + """ + + def __init__( + self, + name: str, + root_dir: str, + files: Files, + min_valid_minutes=5.0, + duration=1.0, + stride=None, + sampling_rate=48000, + matching_function=None, + batch_size=32, + num_workers: Optional[int] = None, + augmentations: Optional[Compose] = None, + ): + + super().__init__( + name=name, + root_dir=root_dir, + files=files, + min_valid_minutes=min_valid_minutes, + sampling_rate=sampling_rate, + duration=duration, + matching_function=matching_function, + batch_size=batch_size, + num_workers=num_workers, + augmentations=augmentations, + ) + + self.sampling_rate = sampling_rate + self.files = files + self.duration = max(1.0, duration) + self.audio = Audio(self.sampling_rate, mono=True, return_tensor=True) + self.stride = stride or duration + + def setup(self, stage: Optional[str] = None): + + super().setup(stage=stage) + + def train__getitem__(self, idx): + + for filedict, num_samples in self.train_data: + if idx >= num_samples: + idx -= num_samples + continue + else: + start = 0 + if self.duration is not None: + start = idx * self.stride + return self.prepare_segment(filedict, start) + + def val__getitem__(self, idx): + return self.prepare_segment(*self._validation[idx]) + + def test__getitem__(self, idx): + return self.prepare_segment(*self._test[idx]) + + def prepare_segment(self, file_dict: dict, start_time: float): + clean_segment = self.audio( + file_dict["clean"], offset=start_time, duration=self.duration + ) + noisy_segment = self.audio( + file_dict["noisy"], offset=start_time, duration=self.duration + ) + clean_segment = F.pad( + clean_segment, + ( + 0, + int( + self.duration * self.sampling_rate - clean_segment.shape[-1] + ), + ), + ) + noisy_segment = F.pad( + noisy_segment, + ( + 0, + int( + self.duration * self.sampling_rate - noisy_segment.shape[-1] + ), + ), + ) + return { + "clean": clean_segment, + "noisy": noisy_segment, + } + + def train__len__(self): + _, num_examples = list(zip(*self.train_data)) + return sum(num_examples) + + def val__len__(self): + return len(self._validation) + + def test__len__(self): + return len(self._test) diff --git a/enhancer/data/fileprocessor.py b/enhancer/data/fileprocessor.py new file mode 100644 index 0000000..5b099d4 --- /dev/null +++ b/enhancer/data/fileprocessor.py @@ -0,0 +1,121 @@ +import glob +import os + 
+import numpy as np +from scipy.io import wavfile + +MATCHING_FNS = ("one_to_one", "one_to_many") + + +class ProcessorFunctions: + """ + Preprocessing methods for different types of speech enhacement datasets. + """ + + @staticmethod + def one_to_one(clean_path, noisy_path): + """ + One clean audio can have only one noisy audio file + """ + + matching_wavfiles = list() + clean_filenames = [ + file.split("/")[-1] + for file in glob.glob(os.path.join(clean_path, "*.wav")) + ] + noisy_filenames = [ + file.split("/")[-1] + for file in glob.glob(os.path.join(noisy_path, "*.wav")) + ] + common_filenames = np.intersect1d(noisy_filenames, clean_filenames) + + for file_name in common_filenames: + + sr_clean, clean_file = wavfile.read( + os.path.join(clean_path, file_name) + ) + sr_noisy, noisy_file = wavfile.read( + os.path.join(noisy_path, file_name) + ) + if (clean_file.shape[-1] == noisy_file.shape[-1]) and ( + sr_clean == sr_noisy + ): + matching_wavfiles.append( + { + "clean": os.path.join(clean_path, file_name), + "noisy": os.path.join(noisy_path, file_name), + "duration": clean_file.shape[-1] / sr_clean, + } + ) + return matching_wavfiles + + @staticmethod + def one_to_many(clean_path, noisy_path): + """ + One clean audio have multiple noisy audio files + """ + + matching_wavfiles = list() + clean_filenames = [ + file.split("/")[-1] + for file in glob.glob(os.path.join(clean_path, "*.wav")) + ] + for clean_file in clean_filenames: + noisy_filenames = glob.glob( + os.path.join(noisy_path, f"*_{clean_file}") + ) + for noisy_file in noisy_filenames: + + sr_clean, clean_wav = wavfile.read( + os.path.join(clean_path, clean_file) + ) + sr_noisy, noisy_wav = wavfile.read(noisy_file) + if (clean_wav.shape[-1] == noisy_wav.shape[-1]) and ( + sr_clean == sr_noisy + ): + matching_wavfiles.append( + { + "clean": os.path.join(clean_path, clean_file), + "noisy": noisy_file, + "duration": clean_wav.shape[-1] / sr_clean, + } + ) + return matching_wavfiles + + +class Fileprocessor: + def __init__(self, clean_dir, noisy_dir, matching_function=None): + self.clean_dir = clean_dir + self.noisy_dir = noisy_dir + self.matching_function = matching_function + + @classmethod + def from_name(cls, name: str, clean_dir, noisy_dir, matching_function=None): + + if matching_function is None: + if name.lower() == "vctk": + return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_one) + elif name.lower() == "dns-2020": + return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_many) + else: + raise ValueError( + f"Invalid matching function, Please use valid matching function from {MATCHING_FNS}" + ) + else: + if matching_function not in MATCHING_FNS: + raise ValueError( + f"Invalid matching function! 
Available options are {MATCHING_FNS}" + ) + else: + return cls( + clean_dir, + noisy_dir, + getattr(ProcessorFunctions, matching_function), + ) + + def prepare_matching_dict(self): + + if self.matching_function is None: + raise ValueError("Not a valid matching function") + + return self.matching_function(self.clean_dir, self.noisy_dir) diff --git a/enhancer/inference.py b/enhancer/inference.py new file mode 100644 index 0000000..d9282fd --- /dev/null +++ b/enhancer/inference.py @@ -0,0 +1,170 @@ +from pathlib import Path +from typing import Optional, Union + +import numpy as np +import torch +import torch.nn.functional as F +from librosa import load as load_audio +from scipy.io import wavfile +from scipy.signal import get_window + +from enhancer.utils import Audio + + +class Inference: + """ + contains methods used for inference. + """ + + @staticmethod + def read_input(audio, sr, model_sr): + """ + read and verify audio input regardless of the input format. + arguments: + audio : audio input + sr : sampling rate of input audio + model_sr : sampling rate used for model training. + """ + + if isinstance(audio, (np.ndarray, torch.Tensor)): + assert sr is not None, "Invalid sampling rate!" + if len(audio.shape) == 1: + audio = audio.reshape(1, -1) + + if isinstance(audio, str): + audio = Path(audio) + if not audio.is_file(): + raise ValueError(f"Input file {audio} does not exist") + else: + audio, sr = load_audio( + audio, + sr=sr, + ) + if len(audio.shape) == 1: + audio = audio.reshape(1, -1) + else: + assert ( + audio.shape[0] == 1 + ), "Enhancer inference only supports a single waveform" + + waveform = Audio.resample_audio(audio, sr=sr, target_sr=model_sr) + waveform = Audio.convert_mono(waveform) + if isinstance(waveform, np.ndarray): + waveform = torch.from_numpy(waveform) + + return waveform + + @staticmethod + def batchify( + waveform: torch.Tensor, + window_size: int, + step_size: Optional[int] = None, + ): + """ + break input waveform into overlapping windows of the specified size (overlap-add) + arguments: + waveform : audio waveform + window_size : window size used for splitting waveform into batches + step_size : step_size used for splitting waveform into batches + """ + assert ( + waveform.ndim == 2 + ), f"Expected input waveform with 2 dimensions (channels,samples), got {waveform.ndim}" + _, num_samples = waveform.shape + waveform = waveform.unsqueeze(-1) + step_size = window_size // 2 if step_size is None else step_size + + if num_samples >= window_size: + waveform_batch = F.unfold( + waveform[None, ...], + kernel_size=(window_size, 1), + stride=(step_size, 1), + padding=(window_size, 0), + ) + waveform_batch = waveform_batch.permute(2, 0, 1) + + return waveform_batch + + @staticmethod + def aggreagate( + data: torch.Tensor, + window_size: int, + total_frames: int, + step_size: Optional[int] = None, + window="hamming", + ): + """ + stitch batched waveform into a single waveform (overlap-add) + arguments: + data: batched waveform + window_size : window_size used to batch waveform + step_size : step_size used to batch waveform + total_frames : total number of frames present in original waveform + window : type of window used for overlap-add mechanism. 
+ """ + num_chunks, n_channels, num_frames = data.shape + window = get_window(window=window, Nx=data.shape[-1]) + window = torch.from_numpy(window).to(data.device) + data *= window + step_size = window_size // 2 if step_size is None else step_size + + data = data.permute(1, 2, 0) + data = F.fold( + data, + (total_frames, 1), + kernel_size=(window_size, 1), + stride=(step_size, 1), + padding=(window_size, 0), + ).squeeze(-1) + + return data.reshape(1, n_channels, -1) + + @staticmethod + def write_output( + waveform: torch.Tensor, filename: Union[str, Path], sr: int + ): + """ + write audio output as wav file + arguments: + waveform : audio waveform + filename : name of the wave file. Output will be written as cleaned_filename.wav + sr : sampling rate + """ + + if isinstance(filename, str): + filename = Path(filename) + + parent, name = filename.parent, "cleaned_" + filename.name + filename = parent / Path(name) + if filename.is_file(): + raise FileExistsError(f"file {filename} already exists") + else: + wavfile.write( + filename, rate=sr, data=waveform.detach().cpu().numpy() + ) + + @staticmethod + def prepare_output( + waveform: torch.Tensor, + model_sampling_rate: int, + audio: Union[str, np.ndarray, torch.Tensor], + sampling_rate: Optional[int], + ): + """ + prepare output audio based on input format + arguments: + waveform : predicted audio waveform + model_sampling_rate : sampling rate used to train the model + audio : input audio + sampling_rate : input audio sampling rate + + """ + if isinstance(audio, np.ndarray): + waveform = waveform.detach().cpu().numpy() + + if sampling_rate is not None: + waveform = Audio.resample_audio( + waveform, sr=model_sampling_rate, target_sr=sampling_rate + ) + + return waveform diff --git a/enhancer/loss.py b/enhancer/loss.py new file mode 100644 index 0000000..75527bb --- /dev/null +++ b/enhancer/loss.py @@ -0,0 +1,216 @@ +import logging + +import numpy as np +import torch +import torch.nn as nn +from torchmetrics import ScaleInvariantSignalNoiseRatio +from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality +from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility + + +class mean_squared_error(nn.Module): + """ + Mean squared error / L1 loss + """ + + def __init__(self, reduction="mean"): + super().__init__() + + self.loss_fun = nn.MSELoss(reduction=reduction) + self.higher_better = False + self.name = "mse" + + def forward(self, prediction: torch.Tensor, target: torch.Tensor): + + if prediction.size() != target.size() or target.ndim < 3: + raise TypeError( + f"""Inputs must be of the same shape (batch_size,channels,samples) + got {prediction.size()} and {target.size()} instead""" + ) + + return self.loss_fun(prediction, target) + + +class mean_absolute_error(nn.Module): + """ + Mean absolute error / L2 loss + """ + + def __init__(self, reduction="mean"): + super().__init__() + + self.loss_fun = nn.L1Loss(reduction=reduction) + self.higher_better = False + self.name = "mae" + + def forward(self, prediction: torch.Tensor, target: torch.Tensor): + + if prediction.size() != target.size() or target.ndim < 3: + raise TypeError( + f"""Inputs must be of the same shape (batch_size,channels,samples) + got {prediction.size()} and {target.size()} instead""" + ) + + return self.loss_fun(prediction, target) + + +class Si_SDR: + """ + SI-SDR metric based on SDR – HALF-BAKED OR WELL DONE?(https://arxiv.org/pdf/1811.02508.pdf) + """ + + def __init__(self, reduction: str = "mean"): + if reduction in ["sum", "mean", None]: + self.reduction = 
reduction + else: + raise TypeError( + "Invalid reduction, valid options are sum, mean, None" + ) + self.higher_better = True + self.name = "si-sdr" + + def __call__(self, prediction: torch.Tensor, target: torch.Tensor): + + if prediction.size() != target.size() or target.ndim < 3: + raise TypeError( + f"""Inputs must be of the same shape (batch_size,channels,samples) + got {prediction.size()} and {target.size()} instead""" + ) + + target_energy = torch.sum(target**2, keepdim=True, dim=-1) + scaling_factor = ( + torch.sum(prediction * target, keepdim=True, dim=-1) / target_energy + ) + target_projection = target * scaling_factor + noise = prediction - target_projection + ratio = torch.sum(target_projection**2, dim=-1) / torch.sum( + noise**2, dim=-1 + ) + si_sdr = 10 * torch.log10(ratio).mean(dim=-1) + + if self.reduction == "sum": + si_sdr = si_sdr.sum() + elif self.reduction == "mean": + si_sdr = si_sdr.mean() + else: + pass + + return si_sdr + + +class Stoi: + """ + STOI (Short-Time Objective Intelligibility, see [2,3]), a wrapper for the pystoi package [1]. + Note that input will be moved to cpu to perform the metric calculation. + parameters: + sr: int + sampling rate + """ + + def __init__(self, sr: int): + self.sr = sr + self.stoi = ShortTimeObjectiveIntelligibility(fs=sr) + self.name = "stoi" + + def __call__(self, prediction: torch.Tensor, target: torch.Tensor): + + return self.stoi(prediction, target) + + +class Pesq: + def __init__(self, sr: int, mode="wb"): + + self.sr = sr + self.name = "pesq" + self.mode = mode + self.pesq = PerceptualEvaluationSpeechQuality( + fs=self.sr, mode=self.mode + ) + + def __call__(self, prediction: torch.Tensor, target: torch.Tensor): + + pesq_values = [] + for pred, target_ in zip(prediction, target): + try: + pesq_values.append(self.pesq(pred.squeeze(), target_.squeeze())) + except Exception as e: + logging.warning(f"{e} error occured while calculating PESQ") + return torch.tensor(np.mean(pesq_values)) + + +class LossWrapper(nn.Module): + """ + Combine multiple metics of same nature. + for example, ["mea","mae"] + parameters: + losses : loss function names to be combined + """ + + def __init__(self, losses): + super().__init__() + + self.valid_losses = nn.ModuleList() + + direction = [ + getattr(LOSS_MAP[loss](), "higher_better") for loss in losses + ] + if len(set(direction)) > 1: + raise ValueError( + "all cost functions should be of same nature, maximize or minimize!" 
+ ) + + self.higher_better = direction[0] + self.name = "" + for loss in losses: + loss = self.validate_loss(loss) + self.valid_losses.append(loss()) + self.name += f"{loss().name}_" + + def validate_loss(self, loss: str): + if loss not in LOSS_MAP.keys(): + raise ValueError( + f"""Invalid loss function {loss}, available loss functions are + {tuple([loss for loss in LOSS_MAP.keys()])}""" + ) + else: + return LOSS_MAP[loss] + + def forward(self, prediction: torch.Tensor, target: torch.Tensor): + loss = 0.0 + for loss_fun in self.valid_losses: + loss += loss_fun(prediction, target) + + return loss + + +class Si_snr(nn.Module): + """ + SI-SNR + """ + + def __init__(self, **kwargs): + super().__init__() + + self.loss_fun = ScaleInvariantSignalNoiseRatio(**kwargs) + self.higher_better = True + self.name = "si_snr" + + def forward(self, prediction: torch.Tensor, target: torch.Tensor): + + if prediction.size() != target.size() or target.ndim < 3: + raise TypeError( + f"""Inputs must be of the same shape (batch_size,channels,samples) + got {prediction.size()} and {target.size()} instead""" + ) + + return self.loss_fun(prediction, target) + + +LOSS_MAP = { + "mae": mean_absolute_error, + "mse": mean_squared_error, + "si-sdr": Si_SDR, + "pesq": Pesq, + "stoi": Stoi, + "si-snr": Si_snr, +} diff --git a/enhancer/models/__init__.py b/enhancer/models/__init__.py new file mode 100644 index 0000000..2d97568 --- /dev/null +++ b/enhancer/models/__init__.py @@ -0,0 +1,3 @@ +from enhancer.models.demucs import Demucs +from enhancer.models.model import Model +from enhancer.models.waveunet import WaveUnet diff --git a/enhancer/models/complexnn/__init__.py b/enhancer/models/complexnn/__init__.py new file mode 100644 index 0000000..918a261 --- /dev/null +++ b/enhancer/models/complexnn/__init__.py @@ -0,0 +1,5 @@ +from enhancer.models.complexnn.conv import ComplexConv2d # noqa +from enhancer.models.complexnn.conv import ComplexConvTranspose2d # noqa +from enhancer.models.complexnn.rnn import ComplexLSTM # noqa +from enhancer.models.complexnn.utils import ComplexBatchNorm2D # noqa +from enhancer.models.complexnn.utils import ComplexRelu # noqa diff --git a/enhancer/models/complexnn/conv.py b/enhancer/models/complexnn/conv.py new file mode 100644 index 0000000..d9a4d0f --- /dev/null +++ b/enhancer/models/complexnn/conv.py @@ -0,0 +1,136 @@ +from typing import Tuple + +import torch +import torch.nn.functional as F +from torch import nn + + +def init_weights(nnet): + nn.init.xavier_normal_(nnet.weight.data) + nn.init.constant_(nnet.bias, 0.0) + return nnet + + +class ComplexConv2d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Tuple[int, int] = (1, 1), + stride: Tuple[int, int] = (1, 1), + padding: Tuple[int, int] = (0, 0), + groups: int = 1, + dilation: int = 1, + ): + """ + Complex Conv2d (non-causal) + """ + super().__init__() + self.in_channels = in_channels // 2 + self.out_channels = out_channels // 2 + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.groups = groups + self.dilation = dilation + + self.real_conv = nn.Conv2d( + self.in_channels, + self.out_channels, + kernel_size=self.kernel_size, + stride=self.stride, + padding=(self.padding[0], 0), + groups=self.groups, + dilation=self.dilation, + ) + self.imag_conv = nn.Conv2d( + self.in_channels, + self.out_channels, + kernel_size=self.kernel_size, + stride=self.stride, + padding=(self.padding[0], 0), + groups=self.groups, + dilation=self.dilation, + ) + self.imag_conv = 
init_weights(self.imag_conv) + self.real_conv = init_weights(self.real_conv) + + def forward(self, input): + """ + complex axis should be always 1 dim + """ + input = F.pad(input, [self.padding[1], 0, 0, 0]) + + real, imag = torch.chunk(input, 2, 1) + + real_real = self.real_conv(real) + real_imag = self.imag_conv(real) + + imag_imag = self.imag_conv(imag) + imag_real = self.real_conv(imag) + + real = real_real - imag_imag + imag = real_imag - imag_real + + out = torch.cat([real, imag], 1) + return out + + +class ComplexConvTranspose2d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Tuple[int, int] = (1, 1), + stride: Tuple[int, int] = (1, 1), + padding: Tuple[int, int] = (0, 0), + output_padding: Tuple[int, int] = (0, 0), + groups: int = 1, + ): + super().__init__() + self.in_channels = in_channels // 2 + self.out_channels = out_channels // 2 + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.groups = groups + self.output_padding = output_padding + + self.real_conv = nn.ConvTranspose2d( + self.in_channels, + self.out_channels, + kernel_size=self.kernel_size, + stride=self.stride, + padding=self.padding, + output_padding=self.output_padding, + groups=self.groups, + ) + + self.imag_conv = nn.ConvTranspose2d( + self.in_channels, + self.out_channels, + kernel_size=self.kernel_size, + stride=self.stride, + padding=self.padding, + output_padding=self.output_padding, + groups=self.groups, + ) + + self.real_conv = init_weights(self.real_conv) + self.imag_conv = init_weights(self.imag_conv) + + def forward(self, input): + + real, imag = torch.chunk(input, 2, 1) + real_real = self.real_conv(real) + real_imag = self.imag_conv(real) + + imag_imag = self.imag_conv(imag) + imag_real = self.real_conv(imag) + + real = real_real - imag_imag + imag = real_imag - imag_real + + out = torch.cat([real, imag], 1) + + return out diff --git a/enhancer/models/complexnn/rnn.py b/enhancer/models/complexnn/rnn.py new file mode 100644 index 0000000..847030b --- /dev/null +++ b/enhancer/models/complexnn/rnn.py @@ -0,0 +1,68 @@ +from typing import List, Optional + +import torch +from torch import nn + + +class ComplexLSTM(nn.Module): + def __init__( + self, + input_size: int, + hidden_size: int, + num_layers: int = 1, + projection_size: Optional[int] = None, + bidirectional: bool = False, + ): + super().__init__() + self.input_size = input_size // 2 + self.hidden_size = hidden_size // 2 + self.num_layers = num_layers + + self.real_lstm = nn.LSTM( + self.input_size, + self.hidden_size, + self.num_layers, + bidirectional=bidirectional, + batch_first=False, + ) + self.imag_lstm = nn.LSTM( + self.input_size, + self.hidden_size, + self.num_layers, + bidirectional=bidirectional, + batch_first=False, + ) + + bidirectional = 2 if bidirectional else 1 + if projection_size is not None: + self.projection_size = projection_size // 2 + self.real_linear = nn.Linear( + self.hidden_size * bidirectional, self.projection_size + ) + self.imag_linear = nn.Linear( + self.hidden_size * bidirectional, self.projection_size + ) + else: + self.projection_size = None + + def forward(self, input): + + if isinstance(input, List): + real, imag = input + else: + real, imag = torch.chunk(input, 2, 1) + + real_real = self.real_lstm(real)[0] + real_imag = self.imag_lstm(real)[0] + + imag_imag = self.imag_lstm(imag)[0] + imag_real = self.real_lstm(imag)[0] + + real = real_real - imag_imag + imag = imag_real + real_imag + + if self.projection_size is not None: + real = 
self.real_linear(real) + imag = self.imag_linear(imag) + + return [real, imag] diff --git a/enhancer/models/complexnn/utils.py b/enhancer/models/complexnn/utils.py new file mode 100644 index 0000000..0c28f9b --- /dev/null +++ b/enhancer/models/complexnn/utils.py @@ -0,0 +1,199 @@ +import torch +from torch import nn + + +class ComplexBatchNorm2D(nn.Module): + def __init__( + self, + num_features: int, + eps: float = 1e-5, + momentum: float = 0.1, + affine: bool = True, + track_running_stats: bool = True, + ): + """ + Complex batch normalization 2D + https://arxiv.org/abs/1705.09792 + + + """ + super().__init__() + self.num_features = num_features // 2 + self.affine = affine + self.momentum = momentum + self.track_running_stats = track_running_stats + self.eps = eps + + if self.affine: + self.Wrr = nn.parameter.Parameter(torch.Tensor(self.num_features)) + self.Wri = nn.parameter.Parameter(torch.Tensor(self.num_features)) + self.Wii = nn.parameter.Parameter(torch.Tensor(self.num_features)) + self.Br = nn.parameter.Parameter(torch.Tensor(self.num_features)) + self.Bi = nn.parameter.Parameter(torch.Tensor(self.num_features)) + else: + self.register_parameter("Wrr", None) + self.register_parameter("Wri", None) + self.register_parameter("Wii", None) + self.register_parameter("Br", None) + self.register_parameter("Bi", None) + + if self.track_running_stats: + values = torch.zeros(self.num_features) + self.register_buffer("Mean_real", values) + self.register_buffer("Mean_imag", values) + self.register_buffer("Var_rr", values) + self.register_buffer("Var_ri", values) + self.register_buffer("Var_ii", values) + self.register_buffer( + "num_batches_tracked", torch.tensor(0, dtype=torch.long) + ) + else: + self.register_parameter("Mean_real", None) + self.register_parameter("Mean_imag", None) + self.register_parameter("Var_rr", None) + self.register_parameter("Var_ri", None) + self.register_parameter("Var_ii", None) + self.register_parameter("num_batches_tracked", None) + + self.reset_parameters() + + def reset_parameters(self): + if self.affine: + self.Wrr.data.fill_(1) + self.Wii.data.fill_(1) + self.Wri.data.uniform_(-0.9, 0.9) + self.Br.data.fill_(0) + self.Bi.data.fill_(0) + self.reset_running_stats() + + def reset_running_stats(self): + if self.track_running_stats: + self.Mean_real.zero_() + self.Mean_imag.zero_() + self.Var_rr.fill_(1) + self.Var_ri.zero_() + self.Var_ii.fill_(1) + self.num_batches_tracked.zero_() + + def extra_repr(self): + return "{num_features}, eps={eps}, momentum={momentum}, affine={affine}, track_running_stats={track_running_stats}".format( + **self.__dict__ + ) + + def forward(self, input): + + real, imag = torch.chunk(input, 2, 1) + exp_avg_factor = 0.0 + + training = self.training and self.track_running_stats + if training: + self.num_batches_tracked += 1 + if self.momentum is None: + exp_avg_factor = 1 / self.num_batches_tracked + else: + exp_avg_factor = self.momentum + + redux = [i for i in reversed(range(real.dim())) if i != 1] + vdim = [1] * real.dim() + vdim[1] = real.size(1) + + if training: + batch_mean_real, batch_mean_imag = real, imag + for dim in redux: + batch_mean_real = batch_mean_real.mean(dim, keepdim=True) + batch_mean_imag = batch_mean_imag.mean(dim, keepdim=True) + if self.track_running_stats: + self.Mean_real.lerp_(batch_mean_real.squeeze(), exp_avg_factor) + self.Mean_imag.lerp_(batch_mean_imag.squeeze(), exp_avg_factor) + + else: + batch_mean_real = self.Mean_real.view(vdim) + batch_mean_imag = self.Mean_imag.view(vdim) + + real = real - 
batch_mean_real + imag = imag - batch_mean_imag + + if training: + batch_var_rr = real * real + batch_var_ri = real * imag + batch_var_ii = imag * imag + for dim in redux: + batch_var_rr = batch_var_rr.mean(dim, keepdim=True) + batch_var_ri = batch_var_ri.mean(dim, keepdim=True) + batch_var_ii = batch_var_ii.mean(dim, keepdim=True) + if self.track_running_stats: + self.Var_rr.lerp_(batch_var_rr.squeeze(), exp_avg_factor) + self.Var_ri.lerp_(batch_var_ri.squeeze(), exp_avg_factor) + self.Var_ii.lerp_(batch_var_ii.squeeze(), exp_avg_factor) + else: + batch_var_rr = self.Var_rr.view(vdim) + batch_var_ii = self.Var_ii.view(vdim) + batch_var_ri = self.Var_ri.view(vdim) + + batch_var_rr += self.eps + batch_var_ii += self.eps + + # Covariance matrics + # | batch_var_rr batch_var_ri | + # | batch_var_ir batch_var_ii | here batch_var_ir == batch_var_ri + # Inverse square root of cov matrix by combining below two formulas + # https://en.wikipedia.org/wiki/Square_root_of_a_2_by_2_matrix + # https://mathworld.wolfram.com/MatrixInverse.html + + tau = batch_var_rr + batch_var_ii + s = batch_var_rr * batch_var_ii - batch_var_ri * batch_var_ri + t = (tau + 2 * s).sqrt() + + rst = (s * t).reciprocal() + Urr = (batch_var_ii + s) * rst + Uri = -batch_var_ri * rst + Uii = (batch_var_rr + s) * rst + + if self.affine: + Wrr, Wri, Wii = ( + self.Wrr.view(vdim), + self.Wri.view(vdim), + self.Wii.view(vdim), + ) + Zrr = (Wrr * Urr) + (Wri * Uri) + Zri = (Wrr * Uri) + (Wri * Uii) + Zir = (Wii * Uri) + (Wri * Urr) + Zii = (Wri * Uri) + (Wii * Uii) + else: + Zrr, Zri, Zir, Zii = Urr, Uri, Uri, Uii + + yr = (Zrr * real) + (Zri * imag) + yi = (Zir * real) + (Zii * imag) + + if self.affine: + yr = yr + self.Br.view(vdim) + yi = yi + self.Bi.view(vdim) + + outputs = torch.cat([yr, yi], 1) + return outputs + + +class ComplexRelu(nn.Module): + def __init__(self): + super().__init__() + self.real_relu = nn.PReLU() + self.imag_relu = nn.PReLU() + + def forward(self, input): + + real, imag = torch.chunk(input, 2, 1) + real = self.real_relu(real) + imag = self.imag_relu(imag) + return torch.cat([real, imag], dim=1) + + +def complex_cat(inputs, axis=1): + + real, imag = [], [] + for data in inputs: + real_data, imag_data = torch.chunk(data, 2, axis) + real.append(real_data) + imag.append(imag_data) + real = torch.cat(real, axis) + imag = torch.cat(imag, axis) + return torch.cat([real, imag], axis) diff --git a/enhancer/models/dccrn.py b/enhancer/models/dccrn.py new file mode 100644 index 0000000..7b1e5b1 --- /dev/null +++ b/enhancer/models/dccrn.py @@ -0,0 +1,338 @@ +import logging +from typing import Any, List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from torch import nn + +from enhancer.data import EnhancerDataset +from enhancer.models import Model +from enhancer.models.complexnn import ( + ComplexBatchNorm2D, + ComplexConv2d, + ComplexConvTranspose2d, + ComplexLSTM, + ComplexRelu, +) +from enhancer.models.complexnn.utils import complex_cat +from enhancer.utils.transforms import ConviSTFT, ConvSTFT +from enhancer.utils.utils import merge_dict + + +class DCCRN_ENCODER(nn.Module): + def __init__( + self, + in_channels: int, + out_channel: int, + kernel_size: Tuple[int, int], + complex_norm: bool = True, + complex_relu: bool = True, + stride: Tuple[int, int] = (2, 1), + padding: Tuple[int, int] = (2, 1), + ): + super().__init__() + batchnorm = ComplexBatchNorm2D if complex_norm else nn.BatchNorm2d + activation = ComplexRelu() if complex_relu else nn.PReLU() + + self.encoder = nn.Sequential( + 
ComplexConv2d( + in_channels, + out_channel, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ), + batchnorm(out_channel), + activation, + ) + + def forward(self, waveform): + + return self.encoder(waveform) + + +class DCCRN_DECODER(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Tuple[int, int], + layer: int = 0, + complex_norm: bool = True, + complex_relu: bool = True, + stride: Tuple[int, int] = (2, 1), + padding: Tuple[int, int] = (2, 0), + output_padding: Tuple[int, int] = (1, 0), + ): + super().__init__() + batchnorm = ComplexBatchNorm2D if complex_norm else nn.BatchNorm2d + activation = ComplexRelu() if complex_relu else nn.PReLU() + + if layer != 0: + self.decoder = nn.Sequential( + ComplexConvTranspose2d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + ), + batchnorm(out_channels), + activation, + ) + else: + self.decoder = nn.Sequential( + ComplexConvTranspose2d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + ) + ) + + def forward(self, waveform): + + return self.decoder(waveform) + + +class DCCRN(Model): + + STFT_DEFAULTS = { + "window_len": 400, + "hop_size": 100, + "nfft": 512, + "window": "hamming", + } + + ED_DEFAULTS = { + "initial_output_channels": 32, + "depth": 6, + "kernel_size": 5, + "growth_factor": 2, + "stride": 2, + "padding": 2, + "output_padding": 1, + } + + LSTM_DEFAULTS = { + "num_layers": 2, + "hidden_size": 256, + } + + def __init__( + self, + stft: Optional[dict] = None, + encoder_decoder: Optional[dict] = None, + lstm: Optional[dict] = None, + complex_lstm: bool = True, + complex_norm: bool = True, + complex_relu: bool = True, + masking_mode: str = "E", + num_channels: int = 1, + sampling_rate=16000, + lr: float = 1e-3, + dataset: Optional[EnhancerDataset] = None, + duration: Optional[float] = None, + loss: Union[str, List, Any] = "mse", + metric: Union[str, List] = "mse", + ): + duration = ( + dataset.duration if isinstance(dataset, EnhancerDataset) else None + ) + if dataset is not None: + if sampling_rate != dataset.sampling_rate: + logging.warning( + f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}" + ) + sampling_rate = dataset.sampling_rate + super().__init__( + num_channels=num_channels, + sampling_rate=sampling_rate, + lr=lr, + dataset=dataset, + duration=duration, + loss=loss, + metric=metric, + ) + + encoder_decoder = merge_dict(self.ED_DEFAULTS, encoder_decoder) + lstm = merge_dict(self.LSTM_DEFAULTS, lstm) + stft = merge_dict(self.STFT_DEFAULTS, stft) + self.save_hyperparameters( + "encoder_decoder", + "lstm", + "stft", + "complex_lstm", + "complex_norm", + "masking_mode", + ) + self.complex_lstm = complex_lstm + self.complex_norm = complex_norm + self.masking_mode = masking_mode + + self.stft = ConvSTFT( + stft["window_len"], stft["hop_size"], stft["nfft"], stft["window"] + ) + self.istft = ConviSTFT( + stft["window_len"], stft["hop_size"], stft["nfft"], stft["window"] + ) + + self.encoder = nn.ModuleList() + self.decoder = nn.ModuleList() + + num_channels *= 2 + hidden_size = encoder_decoder["initial_output_channels"] + growth_factor = 2 + + for layer in range(encoder_decoder["depth"]): + + encoder_ = DCCRN_ENCODER( + num_channels, + hidden_size, + kernel_size=(encoder_decoder["kernel_size"], 2), + stride=(encoder_decoder["stride"], 1), + padding=(encoder_decoder["padding"], 
1), + complex_norm=complex_norm, + complex_relu=complex_relu, + ) + self.encoder.append(encoder_) + + decoder_ = DCCRN_DECODER( + hidden_size + hidden_size, + num_channels, + layer=layer, + kernel_size=(encoder_decoder["kernel_size"], 2), + stride=(encoder_decoder["stride"], 1), + padding=(encoder_decoder["padding"], 0), + output_padding=(encoder_decoder["output_padding"], 0), + complex_norm=complex_norm, + complex_relu=complex_relu, + ) + + self.decoder.insert(0, decoder_) + + if layer < encoder_decoder["depth"] - 3: + num_channels = hidden_size + hidden_size *= growth_factor + else: + num_channels = hidden_size + + kernel_size = hidden_size / 2 + hidden_size = stft["nfft"] / 2 ** (encoder_decoder["depth"]) + + if self.complex_lstm: + lstms = [] + for layer in range(lstm["num_layers"]): + + if layer == 0: + input_size = int(hidden_size * kernel_size) + else: + input_size = lstm["hidden_size"] + + if layer == lstm["num_layers"] - 1: + projection_size = int(hidden_size * kernel_size) + else: + projection_size = None + + kwargs = { + "input_size": input_size, + "hidden_size": lstm["hidden_size"], + "num_layers": 1, + } + + lstms.append( + ComplexLSTM(projection_size=projection_size, **kwargs) + ) + self.lstm = nn.Sequential(*lstms) + else: + self.lstm = nn.Sequential( + nn.LSTM( + input_size=hidden_size * kernel_size, + hidden_sizs=lstm["hidden_size"], + num_layers=lstm["num_layers"], + dropout=0.0, + batch_first=False, + )[0], + nn.Linear(lstm["hidden"], hidden_size * kernel_size), + ) + + def forward(self, waveform): + + if waveform.dim() == 2: + waveform = waveform.unsqueeze(1) + + if waveform.size(1) != self.hparams.num_channels: + raise ValueError( + f"Number of input channels initialized is {self.hparams.num_channels} but got {waveform.size(1)} channels" + ) + + waveform_stft = self.stft(waveform) + real = waveform_stft[:, : self.stft.nfft // 2 + 1] + imag = waveform_stft[:, self.stft.nfft // 2 + 1 :] + + mag_spec = torch.sqrt(real**2 + imag**2 + 1e-9) + phase_spec = torch.atan2(imag, real) + complex_spec = torch.stack([mag_spec, phase_spec], 1)[:, :, 1:] + + encoder_outputs = [] + out = complex_spec + for _, encoder in enumerate(self.encoder): + out = encoder(out) + encoder_outputs.append(out) + + B, C, D, T = out.size() + out = out.permute(3, 0, 1, 2) + if self.complex_lstm: + + lstm_real = out[:, :, : C // 2] + lstm_imag = out[:, :, C // 2 :] + lstm_real = lstm_real.reshape(T, B, C // 2 * D) + lstm_imag = lstm_imag.reshape(T, B, C // 2 * D) + lstm_real, lstm_imag = self.lstm([lstm_real, lstm_imag]) + lstm_real = lstm_real.reshape(T, B, C // 2, D) + lstm_imag = lstm_imag.reshape(T, B, C // 2, D) + out = torch.cat([lstm_real, lstm_imag], 2) + else: + out = out.reshape(T, B, C * D) + out = self.lstm(out) + out = out.reshape(T, B, D, C) + + out = out.permute(1, 2, 3, 0) + for layer, decoder in enumerate(self.decoder): + skip_connection = encoder_outputs.pop(-1) + out = complex_cat([skip_connection, out]) + out = decoder(out) + out = out[..., 1:] + mask_real, mask_imag = out[:, 0], out[:, 1] + mask_real = F.pad(mask_real, [0, 0, 1, 0]) + mask_imag = F.pad(mask_imag, [0, 0, 1, 0]) + if self.masking_mode == "E": + + mask_mag = torch.sqrt(mask_real**2 + mask_imag**2) + real_phase = mask_real / (mask_mag + 1e-8) + imag_phase = mask_imag / (mask_mag + 1e-8) + mask_phase = torch.atan2(imag_phase, real_phase) + mask_mag = torch.tanh(mask_mag) + est_mag = mask_mag * mag_spec + est_phase = mask_phase * phase_spec + # cos(theta) + isin(theta) + real = est_mag + torch.cos(est_phase) + imag = 
est_mag + torch.sin(est_phase) + + if self.masking_mode == "C": + + real = real * mask_real - imag * mask_imag + imag = real * mask_imag + imag * mask_real + + else: + + real = real * mask_real + imag = imag * mask_imag + + spec = torch.cat([real, imag], 1) + wav = self.istft(spec) + wav = wav.clamp_(-1, 1) + return wav diff --git a/enhancer/models/demucs.py b/enhancer/models/demucs.py new file mode 100644 index 0000000..fafb84e --- /dev/null +++ b/enhancer/models/demucs.py @@ -0,0 +1,274 @@ +import logging +import math +from typing import List, Optional, Union + +import torch.nn.functional as F +from torch import nn + +from enhancer.data.dataset import EnhancerDataset +from enhancer.models.model import Model +from enhancer.utils.io import Audio as audio +from enhancer.utils.utils import merge_dict + + +class DemucsLSTM(nn.Module): + def __init__( + self, + input_size: int, + hidden_size: int, + num_layers: int, + bidirectional: bool = True, + ): + super().__init__() + self.lstm = nn.LSTM( + input_size, hidden_size, num_layers, bidirectional=bidirectional + ) + dim = 2 if bidirectional else 1 + self.linear = nn.Linear(dim * hidden_size, hidden_size) + + def forward(self, x): + + output, (h, c) = self.lstm(x) + output = self.linear(output) + + return output, (h, c) + + +class DemucsEncoder(nn.Module): + def __init__( + self, + num_channels: int, + hidden_size: int, + kernel_size: int, + stride: int = 1, + glu: bool = False, + ): + super().__init__() + activation = nn.GLU(1) if glu else nn.ReLU() + multi_factor = 2 if glu else 1 + self.encoder = nn.Sequential( + nn.Conv1d(num_channels, hidden_size, kernel_size, stride), + nn.ReLU(), + nn.Conv1d(hidden_size, hidden_size * multi_factor, 1, 1), + activation, + ) + + def forward(self, waveform): + + return self.encoder(waveform) + + +class DemucsDecoder(nn.Module): + def __init__( + self, + num_channels: int, + hidden_size: int, + kernel_size: int, + stride: int = 1, + glu: bool = False, + layer: int = 0, + ): + super().__init__() + activation = nn.GLU(1) if glu else nn.ReLU() + multi_factor = 2 if glu else 1 + self.decoder = nn.Sequential( + nn.Conv1d(hidden_size, hidden_size * multi_factor, 1, 1), + activation, + nn.ConvTranspose1d(hidden_size, num_channels, kernel_size, stride), + ) + if layer > 0: + self.decoder.add_module("4", nn.ReLU()) + + def forward( + self, + waveform, + ): + + out = self.decoder(waveform) + return out + + +class Demucs(Model): + """ + Demucs model from https://arxiv.org/pdf/1911.13254.pdf + parameters: + encoder_decoder: dict, optional + keyword arguments passsed to encoder decoder block + lstm : dict, optional + keyword arguments passsed to LSTM block + num_channels: int, defaults to 1 + number channels in input audio + sampling_rate: int, defaults to 16KHz + sampling rate of input audio + lr : float, defaults to 1e-3 + learning rate used for training + dataset: EnhancerDataset, optional + EnhancerDataset object containing train/validation data for training + duration : float, optional + chunk duration in seconds + loss : string or List of strings + loss function to be used, available ("mse","mae","SI-SDR") + metric : string or List of strings + metric function to be used, available ("mse","mae","SI-SDR") + + """ + + ED_DEFAULTS = { + "initial_output_channels": 48, + "kernel_size": 8, + "stride": 4, + "depth": 5, + "glu": True, + "growth_factor": 2, + } + LSTM_DEFAULTS = { + "bidirectional": True, + "num_layers": 2, + } + + def __init__( + self, + encoder_decoder: Optional[dict] = None, + lstm: Optional[dict] = None, 
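# remaining keyword arguments: `resample` (default 4) internally upsamples the
# audio before the encoder, and `floor` is added to the per-utterance std so
# the normalisation stays stable on near-silent inputs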
+ num_channels: int = 1, + resample: int = 4, + sampling_rate=16000, + normalize=True, + lr: float = 1e-3, + dataset: Optional[EnhancerDataset] = None, + loss: Union[str, List] = "mse", + metric: Union[str, List] = "mse", + floor=1e-3, + ): + duration = ( + dataset.duration if isinstance(dataset, EnhancerDataset) else None + ) + if dataset is not None: + if sampling_rate != dataset.sampling_rate: + logging.warning( + f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}" + ) + sampling_rate = dataset.sampling_rate + super().__init__( + num_channels=num_channels, + sampling_rate=sampling_rate, + lr=lr, + dataset=dataset, + duration=duration, + loss=loss, + metric=metric, + ) + + encoder_decoder = merge_dict(self.ED_DEFAULTS, encoder_decoder) + lstm = merge_dict(self.LSTM_DEFAULTS, lstm) + self.save_hyperparameters("encoder_decoder", "lstm", "resample") + hidden = encoder_decoder["initial_output_channels"] + self.normalize = normalize + self.floor = floor + self.encoder = nn.ModuleList() + self.decoder = nn.ModuleList() + + for layer in range(encoder_decoder["depth"]): + + encoder_layer = DemucsEncoder( + num_channels=num_channels, + hidden_size=hidden, + kernel_size=encoder_decoder["kernel_size"], + stride=encoder_decoder["stride"], + glu=encoder_decoder["glu"], + ) + self.encoder.append(encoder_layer) + + decoder_layer = DemucsDecoder( + num_channels=num_channels, + hidden_size=hidden, + kernel_size=encoder_decoder["kernel_size"], + stride=encoder_decoder["stride"], + glu=encoder_decoder["glu"], + layer=layer, + ) + self.decoder.insert(0, decoder_layer) + + num_channels = hidden + hidden = self.ED_DEFAULTS["growth_factor"] * hidden + + self.de_lstm = DemucsLSTM( + input_size=num_channels, + hidden_size=num_channels, + num_layers=lstm["num_layers"], + bidirectional=lstm["bidirectional"], + ) + + def forward(self, waveform): + + if waveform.dim() == 2: + waveform = waveform.unsqueeze(1) + + if waveform.size(1) != self.hparams.num_channels: + raise ValueError( + f"Number of input channels initialized is {self.hparams.num_channels} but got {waveform.size(1)} channels" + ) + if self.normalize: + waveform = waveform.mean(dim=1, keepdim=True) + std = waveform.std(dim=-1, keepdim=True) + waveform = waveform / (self.floor + std) + else: + std = 1 + length = waveform.shape[-1] + x = F.pad(waveform, (0, self.get_padding_length(length) - length)) + if self.hparams.resample > 1: + x = audio.resample_audio( + audio=x, + sr=self.hparams.sampling_rate, + target_sr=int( + self.hparams.sampling_rate * self.hparams.resample + ), + ) + + encoder_outputs = [] + for encoder in self.encoder: + x = encoder(x) + encoder_outputs.append(x) + x = x.permute(0, 2, 1) + x, _ = self.de_lstm(x) + + x = x.permute(0, 2, 1) + for decoder in self.decoder: + skip_connection = encoder_outputs.pop(-1) + x = x + skip_connection[..., : x.shape[-1]] + x = decoder(x) + + if self.hparams.resample > 1: + x = audio.resample_audio( + x, + int(self.hparams.sampling_rate * self.hparams.resample), + self.hparams.sampling_rate, + ) + + out = x[..., :length] + return std * out + + def get_padding_length(self, input_length): + + input_length = math.ceil(input_length * self.hparams.resample) + + for layer in range( + self.hparams.encoder_decoder["depth"] + ): # encoder operation + input_length = ( + math.ceil( + (input_length - self.hparams.encoder_decoder["kernel_size"]) + / self.hparams.encoder_decoder["stride"] + ) + + 1 + ) + input_length = max(1, input_length) + for layer in range( + 
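# the decoder pass below inverts the encoder arithmetic, applying the
# transposed-conv length formula (L - 1) * stride + kernel_size per layer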
self.hparams.encoder_decoder["depth"] + ): # decoder operaration + input_length = (input_length - 1) * self.hparams.encoder_decoder[ + "stride" + ] + self.hparams.encoder_decoder["kernel_size"] + input_length = math.ceil(input_length / self.hparams.resample) + + return int(input_length) diff --git a/enhancer/models/model.py b/enhancer/models/model.py new file mode 100644 index 0000000..c679669 --- /dev/null +++ b/enhancer/models/model.py @@ -0,0 +1,431 @@ +import os +from collections import defaultdict +from importlib import import_module +from pathlib import Path +from typing import Any, List, Optional, Text, Union +from urllib.parse import urlparse + +import numpy as np +import pytorch_lightning as pl +import torch +from huggingface_hub import cached_download, hf_hub_url +from pytorch_lightning.utilities.cloud_io import load as pl_load +from torch import nn +from torch.optim import Adam + +from enhancer.data.dataset import EnhancerDataset +from enhancer.inference import Inference +from enhancer.loss import LOSS_MAP, LossWrapper +from enhancer.version import __version__ + +CACHE_DIR = "" +HF_TORCH_WEIGHTS = "" +DEFAULT_DEVICE = "cpu" + + +class Model(pl.LightningModule): + """ + Base class for all models + parameters: + num_channels: int, default to 1 + number of channels in input audio + sampling_rate : int, default 16khz + audio sampling rate + lr: float, optional + learning rate for model training + dataset: EnhancerDataset, optional + Enhancer dataset used for training/validation + duration: float, optional + duration used for training/inference + loss : string or List of strings or custom loss (nn.Module), default to "mse" + loss functions to be used. Available ("mse","mae","Si-SDR") + + """ + + def __init__( + self, + num_channels: int = 1, + sampling_rate: int = 16000, + lr: float = 1e-3, + dataset: Optional[EnhancerDataset] = None, + duration: Optional[float] = None, + loss: Union[str, List] = "mse", + metric: Union[str, List, Any] = "mse", + ): + super().__init__() + assert ( + num_channels == 1 + ), "Enhancer only support for mono channel models" + self.dataset = dataset + self.save_hyperparameters( + "num_channels", "sampling_rate", "lr", "loss", "metric", "duration" + ) + if self.logger: + self.logger.experiment.log_dict( + dict(self.hparams), "hyperparameters.json" + ) + + self.loss = loss + self.metric = metric + + @property + def loss(self): + return self._loss + + @loss.setter + def loss(self, loss): + + if isinstance(loss, str): + loss = [loss] + + self._loss = LossWrapper(loss) + + @property + def metric(self): + return self._metric + + @metric.setter + def metric(self, metric): + self._metric = [] + if isinstance(metric, (str, nn.Module)): + metric = [metric] + + for func in metric: + if isinstance(func, str): + if func in LOSS_MAP.keys(): + if func in ("pesq", "stoi"): + self._metric.append( + LOSS_MAP[func](self.hparams.sampling_rate) + ) + else: + self._metric.append(LOSS_MAP[func]()) + else: + ValueError(f"Invalid metrics {func}") + + elif isinstance(func, nn.Module): + self._metric.append(func) + else: + raise ValueError("Invalid metrics") + + @property + def dataset(self): + return self._dataset + + @dataset.setter + def dataset(self, dataset): + self._dataset = dataset + + def setup(self, stage: Optional[str] = None): + if stage == "fit": + torch.cuda.empty_cache() + self.dataset.setup(stage) + self.dataset.model = self + + print( + "Total train duration", + self.dataset.train_dataloader().dataset.__len__() + * self.dataset.duration + / 60, + "minutes", + ) + 
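# the durations printed here and below are rough estimates in minutes
# (number of chunks x chunk duration / 60), logged once when fitting starts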
print( + "Total validation duration", + self.dataset.val_dataloader().dataset.__len__() + * self.dataset.duration + / 60, + "minutes", + ) + print( + "Total test duration", + self.dataset.test_dataloader().dataset.__len__() + * self.dataset.duration + / 60, + "minutes", + ) + + def train_dataloader(self): + return self.dataset.train_dataloader() + + def val_dataloader(self): + return self.dataset.val_dataloader() + + def test_dataloader(self): + return self.dataset.test_dataloader() + + def configure_optimizers(self): + return Adam(self.parameters(), lr=self.hparams.lr) + + def training_step(self, batch, batch_idx: int): + + mixed_waveform = batch["noisy"] + target = batch["clean"] + prediction = self(mixed_waveform) + loss = self.loss(prediction, target) + + self.log( + "train_loss", + loss.item(), + on_epoch=True, + on_step=True, + logger=True, + prog_bar=True, + ) + + return {"loss": loss} + + def validation_step(self, batch, batch_idx: int): + + metric_dict = {} + mixed_waveform = batch["noisy"] + target = batch["clean"] + prediction = self(mixed_waveform) + + metric_dict["valid_loss"] = self.loss(target, prediction).item() + for metric in self.metric: + value = metric(target, prediction) + metric_dict[f"valid_{metric.name}"] = value.item() + + self.log_dict( + metric_dict, + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + ) + + return metric_dict + + def test_step(self, batch, batch_idx): + + metric_dict = {} + mixed_waveform = batch["noisy"] + target = batch["clean"] + prediction = self(mixed_waveform) + + for metric in self.metric: + value = metric(target, prediction) + metric_dict[f"test_{metric.name}"] = value + + self.log_dict( + metric_dict, + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + ) + + return metric_dict + + def test_epoch_end(self, outputs): + + test_mean_metrics = defaultdict(int) + for output in outputs: + for metric, value in output.items(): + test_mean_metrics[metric] += value.item() + for metric in test_mean_metrics.keys(): + test_mean_metrics[metric] /= len(outputs) + + print("----------TEST REPORT----------\n") + for metric in test_mean_metrics.keys(): + print(f"|{metric.upper()} | {test_mean_metrics[metric]} |") + print("--------------------------------") + + def on_save_checkpoint(self, checkpoint): + + checkpoint["enhancer"] = { + "version": {"enhancer": __version__, "pytorch": torch.__version__}, + "architecture": { + "module": self.__class__.__module__, + "class": self.__class__.__name__, + }, + } + + @classmethod + def from_pretrained( + cls, + checkpoint: Union[Path, Text], + map_location=None, + hparams_file: Union[Path, Text] = None, + strict: bool = True, + use_auth_token: Union[Text, None] = None, + cached_dir: Union[Path, Text] = CACHE_DIR, + **kwargs, + ): + """ + Load Pretrained model + + parameters: + checkpoint : Path or str + Path to checkpoint, or a remote URL, or a model identifier from + the huggingface.co model hub. + map_location: optional + Same role as in torch.load(). + Defaults to `lambda storage, loc: storage`. + hparams_file : Path or str, optional + Path to a .yaml file with hierarchical structure as in this example: + drop_prob: 0.2 + dataloader: + batch_size: 32 + You most likely won’t need this since Lightning will always save the + hyperparameters to the checkpoint. However, if your checkpoint weights + do not have the hyperparameters saved, use this method to pass in a .yaml + file with the hparams you would like to use. These will be converted + into a dict and passed into your Model for use. 
+ strict : bool, optional + Whether to strictly enforce that the keys in checkpoint match + the keys returned by this module’s state dict. Defaults to True. + use_auth_token : str, optional + When loading a private huggingface.co model, set `use_auth_token` + to True or to a string containing your hugginface.co authentication + token that can be obtained by running `huggingface-cli login` + cache_dir: Path or str, optional + Path to model cache directory + kwargs: optional + Any extra keyword args needed to init the model. + Can also be used to override saved hyperparameter values. + + Returns + ------- + model : Model + Model + + See also + -------- + torch.load + """ + + checkpoint = str(checkpoint) + if hparams_file is not None: + hparams_file = str(hparams_file) + + if os.path.isfile(checkpoint): + model_path_pl = checkpoint + elif urlparse(checkpoint).scheme in ("http", "https"): + model_path_pl = checkpoint + else: + + if "@" in checkpoint: + model_id = checkpoint.split("@")[0] + revision_id = checkpoint.split("@")[1] + else: + model_id = checkpoint + revision_id = None + + url = hf_hub_url( + model_id, filename=HF_TORCH_WEIGHTS, revision=revision_id + ) + model_path_pl = cached_download( + url=url, + library_name="enhancer", + library_version=__version__, + cache_dir=cached_dir, + use_auth_token=use_auth_token, + ) + + if map_location is None: + map_location = torch.device(DEFAULT_DEVICE) + + loaded_checkpoint = pl_load(model_path_pl, map_location) + module_name = loaded_checkpoint["enhancer"]["architecture"]["module"] + class_name = loaded_checkpoint["enhancer"]["architecture"]["class"] + module = import_module(module_name) + Klass = getattr(module, class_name) + + try: + model = Klass.load_from_checkpoint( + checkpoint_path=model_path_pl, + map_location=map_location, + hparams_file=hparams_file, + strict=strict, + **kwargs, + ) + except Exception as e: + print(e) + + return model + + def infer(self, batch: torch.Tensor, batch_size: int = 32): + """ + perform model inference + parameters: + batch : torch.Tensor + input data + batch_size : int, default 32 + batch size for inference + """ + + assert ( + batch.ndim == 3 + ), f"Expected batch with 3 dimensions (batch,channels,samples) got only {batch.ndim}" + batch_predictions = [] + self.eval().to(self.device) + with torch.no_grad(): + for batch_id in range(0, batch.shape[0], batch_size): + batch_data = batch[batch_id : (batch_id + batch_size), :, :].to( + self.device + ) + prediction = self(batch_data) + batch_predictions.append(prediction) + + return torch.vstack(batch_predictions) + + def enhance( + self, + audio: Union[Path, np.ndarray, torch.Tensor], + sampling_rate: Optional[int] = None, + batch_size: int = 32, + save_output: bool = False, + duration: Optional[int] = None, + step_size: Optional[int] = None, + ): + """ + Enhance audio using loaded pretained model. + + parameters: + audio: Path to audio file or numpy array or torch tensor + single input audio + sampling_rate: int, optional incase input is path + sampling rate of input + batch_size: int, default 32 + input audio is split into multiple chunks. Inference is done on batches + of these chunks according to given batch size. + save_output : bool, default False + weather to save output to file + duration : float, optional + chunk duration in seconds, defaults to duration of loaded pretrained model. 
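(i.e. the `duration` value stored in the checkpoint hyperparameters)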
+ step_size: int, optional + step size between consecutive durations, defaults to 50% of duration + """ + + model_sampling_rate = self.hparams["sampling_rate"] + if duration is None: + duration = self.hparams["duration"] + waveform = Inference.read_input( + audio, sampling_rate, model_sampling_rate + ) + waveform.to(self.device) + window_size = round(duration * model_sampling_rate) + batched_waveform = Inference.batchify( + waveform, window_size, step_size=step_size + ) + batch_prediction = self.infer(batched_waveform, batch_size=batch_size) + waveform = Inference.aggreagate( + batch_prediction, + window_size, + waveform.shape[-1], + step_size, + ) + + if save_output and isinstance(audio, (str, Path)): + Inference.write_output(waveform, audio, model_sampling_rate) + + else: + waveform = Inference.prepare_output( + waveform, model_sampling_rate, audio, sampling_rate + ) + return waveform + + @property + def valid_monitor(self): + + return "max" if self.loss.higher_better else "min" diff --git a/enhancer/models/waveunet.py b/enhancer/models/waveunet.py new file mode 100644 index 0000000..ea5646a --- /dev/null +++ b/enhancer/models/waveunet.py @@ -0,0 +1,201 @@ +import logging +from typing import List, Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from enhancer.data.dataset import EnhancerDataset +from enhancer.models.model import Model + + +class WavenetDecoder(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 5, + padding: int = 2, + stride: int = 1, + dilation: int = 1, + ): + super(WavenetDecoder, self).__init__() + self.decoder = nn.Sequential( + nn.Conv1d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ), + nn.BatchNorm1d(out_channels), + nn.LeakyReLU(negative_slope=0.1), + ) + + def forward(self, waveform): + + return self.decoder(waveform) + + +class WavenetEncoder(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 15, + padding: int = 7, + stride: int = 1, + dilation: int = 1, + ): + super(WavenetEncoder, self).__init__() + self.encoder = nn.Sequential( + nn.Conv1d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ), + nn.BatchNorm1d(out_channels), + nn.LeakyReLU(negative_slope=0.1), + ) + + def forward(self, waveform): + return self.encoder(waveform) + + +class WaveUnet(Model): + """ + Wave-U-Net model from https://arxiv.org/pdf/1811.11307.pdf + parameters: + num_channels: int, defaults to 1 + number of channels in input audio + depth : int, defaults to 12 + depth of network + initial_output_channels: int, defaults to 24 + number of output channels in intial upsampling layer + sampling_rate: int, defaults to 16KHz + sampling rate of input audio + lr : float, defaults to 1e-3 + learning rate used for training + dataset: EnhancerDataset, optional + EnhancerDataset object containing train/validation data for training + duration : float, optional + chunk duration in seconds + loss : string or List of strings + loss function to be used, available ("mse","mae","SI-SDR") + metric : string or List of strings + metric function to be used, available ("mse","mae","SI-SDR") + """ + + def __init__( + self, + num_channels: int = 1, + depth: int = 12, + initial_output_channels: int = 24, + sampling_rate: int = 16000, + lr: float = 1e-3, + dataset: Optional[EnhancerDataset] = None, + duration: Optional[float] = None, + loss: Union[str, List] = 
"mse", + metric: Union[str, List] = "mse", + ): + duration = ( + dataset.duration if isinstance(dataset, EnhancerDataset) else None + ) + if dataset is not None: + if sampling_rate != dataset.sampling_rate: + logging.warning( + f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}" + ) + sampling_rate = dataset.sampling_rate + super().__init__( + num_channels=num_channels, + sampling_rate=sampling_rate, + lr=lr, + dataset=dataset, + duration=duration, + loss=loss, + metric=metric, + ) + self.save_hyperparameters("depth") + self.encoders = nn.ModuleList() + self.decoders = nn.ModuleList() + out_channels = initial_output_channels + for layer in range(depth): + + encoder = WavenetEncoder(num_channels, out_channels) + self.encoders.append(encoder) + + num_channels = out_channels + out_channels += initial_output_channels + if layer == depth - 1: + decoder = WavenetDecoder( + depth * initial_output_channels + num_channels, num_channels + ) + else: + decoder = WavenetDecoder( + num_channels + out_channels, num_channels + ) + + self.decoders.insert(0, decoder) + + bottleneck_dim = depth * initial_output_channels + self.bottleneck = nn.Sequential( + nn.Conv1d(bottleneck_dim, bottleneck_dim, 15, stride=1, padding=7), + nn.BatchNorm1d(bottleneck_dim), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + ) + self.final = nn.Sequential( + nn.Conv1d(1 + initial_output_channels, 1, kernel_size=1, stride=1), + nn.Tanh(), + ) + + def forward(self, waveform): + if waveform.dim() == 2: + waveform = waveform.unsqueeze(1) + + if waveform.size(1) != 1: + raise TypeError( + f"Wave-U-Net can only process mono channel audio, input has {waveform.size(1)} channels" + ) + + encoder_outputs = [] + out = waveform + for encoder in self.encoders: + out = encoder(out) + encoder_outputs.insert(0, out) + out = out[:, :, ::2] + + out = self.bottleneck(out) + + for layer, decoder in enumerate(self.decoders): + out = F.interpolate(out, scale_factor=2, mode="linear") + out = self.fix_last_dim(out, encoder_outputs[layer]) + out = torch.cat([out, encoder_outputs[layer]], dim=1) + out = decoder(out) + + out = torch.cat([out, waveform], dim=1) + out = self.final(out) + return out + + def fix_last_dim(self, x, target): + """ + centre crop along last dimension + """ + + assert ( + x.shape[-1] >= target.shape[-1] + ), "input dimension cannot be larger than target dimension" + if x.shape[-1] == target.shape[-1]: + return x + + diff = x.shape[-1] - target.shape[-1] + if diff % 2 != 0: + x = F.pad(x, (0, 1)) + diff += 1 + + crop = diff // 2 + return x[:, :, crop:-crop] diff --git a/enhancer/utils/__init__.py b/enhancer/utils/__init__.py new file mode 100644 index 0000000..de0db9f --- /dev/null +++ b/enhancer/utils/__init__.py @@ -0,0 +1,3 @@ +from enhancer.utils.config import Files +from enhancer.utils.io import Audio +from enhancer.utils.utils import check_files diff --git a/enhancer/utils/config.py b/enhancer/utils/config.py new file mode 100644 index 0000000..252e6c9 --- /dev/null +++ b/enhancer/utils/config.py @@ -0,0 +1,9 @@ +from dataclasses import dataclass + + +@dataclass +class Files: + train_clean: str + train_noisy: str + test_clean: str + test_noisy: str diff --git a/enhancer/utils/io.py b/enhancer/utils/io.py new file mode 100644 index 0000000..d151ef8 --- /dev/null +++ b/enhancer/utils/io.py @@ -0,0 +1,128 @@ +import os +from pathlib import Path +from typing import Optional, Union + +import librosa +import numpy as np +import torch +import torchaudio + + +class Audio: + """ + Audio 
utils + parameters: + sampling_rate : int, defaults to 16KHz + audio sampling rate + mono: bool, defaults to True + return_tensors: bool, defaults to True + returns torch tensor type if set to True else numpy ndarray + """ + + def __init__( + self, sampling_rate: int = 16000, mono: bool = True, return_tensor=True + ) -> None: + + self.sampling_rate = sampling_rate + self.mono = mono + self.return_tensor = return_tensor + + def __call__( + self, + audio: Union[Path, np.ndarray, torch.Tensor], + sampling_rate: Optional[int] = None, + offset: Optional[float] = None, + duration: Optional[float] = None, + ): + """ + read and process input audio + parameters: + audio: Path to audio file or numpy array or torch tensor + single input audio + sampling_rate : int, optional + sampling rate of the audio input + offset: float, optional + offset from which the audio must be read, reads from beginning if unused. + duration: float (seconds), optional + read duration, reads full audio starting from offset if not used + """ + if isinstance(audio, str): + if os.path.exists(audio): + audio, sampling_rate = librosa.load( + audio, + sr=sampling_rate, + mono=False, + offset=offset, + duration=duration, + ) + if len(audio.shape) == 1: + audio = audio.reshape(1, -1) + else: + raise FileNotFoundError(f"File {audio} deos not exist") + elif isinstance(audio, np.ndarray): + if len(audio.shape) == 1: + audio = audio.reshape(1, -1) + else: + raise ValueError("audio should be either filepath or numpy ndarray") + + if self.mono: + audio = self.convert_mono(audio) + + if sampling_rate: + audio = self.__class__.resample_audio( + audio, sampling_rate, self.sampling_rate + ) + if self.return_tensor: + return torch.tensor(audio) + else: + return audio + + @staticmethod + def convert_mono(audio: Union[np.ndarray, torch.Tensor]): + """ + convert input audio into mono (1) + parameters: + audio: np.ndarray or torch.Tensor + """ + if len(audio.shape) > 2: + assert ( + audio.shape[0] == 1 + ), "convert mono only accepts single waveform" + audio = audio.reshape(audio.shape[1], audio.shape[2]) + + assert ( + audio.shape[1] >> audio.shape[0] + ), f"expected input format (num_channels,num_samples) got {audio.shape}" + num_channels, num_samples = audio.shape + if num_channels > 1: + return audio.mean(axis=0).reshape(1, num_samples) + return audio + + @staticmethod + def resample_audio( + audio: Union[np.ndarray, torch.Tensor], sr: int, target_sr: int + ): + """ + resample audio to desired sampling rate + parameters: + audio : Path to audio file or numpy array or torch tensor + audio waveform + sr : int + current sampling rate + target_sr : int + target sampling rate + + """ + if sr != target_sr: + if isinstance(audio, np.ndarray): + audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sr) + elif isinstance(audio, torch.Tensor): + audio = torchaudio.functional.resample( + audio, orig_freq=sr, new_freq=target_sr + ) + else: + raise ValueError( + "Input should be either numpy array or torch tensor" + ) + + return audio diff --git a/enhancer/utils/random.py b/enhancer/utils/random.py new file mode 100644 index 0000000..dd9395a --- /dev/null +++ b/enhancer/utils/random.py @@ -0,0 +1,36 @@ +import os +import random + +import torch + + +def create_unique_rng(epoch: int): + """create unique random number generator for each (worker_id,epoch) combination""" + + rng = random.Random() + + global_seed = int(os.environ.get("PL_GLOBAL_SEED", "0")) + global_rank = int(os.environ.get("GLOBAL_RANK", "0")) + local_rank = 
int(os.environ.get("LOCAL_RANK", "0")) + node_rank = int(os.environ.get("NODE_RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "0")) + + worker_info = torch.utils.data.get_worker_info() + if worker_info is not None: + num_workers = worker_info.num_workers + worker_id = worker_info.id + else: + num_workers = 1 + worker_id = 0 + + seed = ( + global_seed + + worker_id + + local_rank * num_workers + + node_rank * num_workers * global_rank + + epoch * num_workers * world_size + ) + + rng.seed(seed) + + return rng diff --git a/enhancer/utils/transforms.py b/enhancer/utils/transforms.py new file mode 100644 index 0000000..5af1f92 --- /dev/null +++ b/enhancer/utils/transforms.py @@ -0,0 +1,93 @@ +from typing import Optional + +import numpy as np +import torch +import torch.nn.functional as F +from scipy.signal import get_window +from torch import nn + + +class ConvFFT(nn.Module): + def __init__( + self, + window_len: int, + nfft: Optional[int] = None, + window: str = "hamming", + ): + super().__init__() + self.window_len = window_len + self.nfft = nfft if nfft else np.int(2 ** np.ceil(np.log2(window_len))) + self.window = torch.from_numpy( + get_window(window, window_len, fftbins=True).astype("float32") + ) + + def init_kernel(self, inverse=False): + + fourier_basis = np.fft.rfft(np.eye(self.nfft))[: self.window_len] + real, imag = np.real(fourier_basis), np.imag(fourier_basis) + kernel = np.concatenate([real, imag], 1).T + if inverse: + kernel = np.linalg.pinv(kernel).T + kernel = torch.from_numpy(kernel.astype("float32")).unsqueeze(1) + kernel *= self.window + return kernel + + +class ConvSTFT(ConvFFT): + def __init__( + self, + window_len: int, + hop_size: Optional[int] = None, + nfft: Optional[int] = None, + window: str = "hamming", + ): + super().__init__(window_len=window_len, nfft=nfft, window=window) + self.hop_size = hop_size if hop_size else window_len // 2 + self.register_buffer("weight", self.init_kernel()) + + def forward(self, input): + + if input.dim() < 2: + raise ValueError( + f"Expected signal with shape 2 or 3 got {input.dim()}" + ) + elif input.dim() == 2: + input = input.unsqueeze(1) + else: + pass + input = F.pad( + input, + (self.window_len - self.hop_size, self.window_len - self.hop_size), + ) + output = F.conv1d(input, self.weight, stride=self.hop_size) + + return output + + +class ConviSTFT(ConvFFT): + def __init__( + self, + window_len: int, + hop_size: Optional[int] = None, + nfft: Optional[int] = None, + window: str = "hamming", + ): + super().__init__(window_len=window_len, nfft=nfft, window=window) + self.hop_size = hop_size if hop_size else window_len // 2 + self.register_buffer("weight", self.init_kernel(True)) + self.register_buffer("enframe", torch.eye(window_len).unsqueeze(1)) + + def forward(self, input, phase=None): + + if phase is not None: + real = input * torch.cos(phase) + imag = input * torch.sin(phase) + input = torch.cat([real, imag], 1) + out = F.conv_transpose1d(input, self.weight, stride=self.hop_size) + coeff = self.window.unsqueeze(1).repeat(1, 1, input.size(-1)) ** 2 + coeff = coeff.to(input.device) + coeff = F.conv_transpose1d(coeff, self.enframe, stride=self.hop_size) + out = out / (coeff + 1e-8) + pad = self.window_len - self.hop_size + out = out[..., pad:-pad] + return out diff --git a/enhancer/utils/utils.py b/enhancer/utils/utils.py new file mode 100644 index 0000000..ad45139 --- /dev/null +++ b/enhancer/utils/utils.py @@ -0,0 +1,27 @@ +import os +from typing import Optional + +from enhancer.utils.config import Files + + +def 
check_files(root_dir: str, files: Files): + + path_variables = [ + member_var + for member_var in dir(files) + if not member_var.startswith("__") + ] + for variable in path_variables: + path = getattr(files, variable) + if not os.path.isdir(os.path.join(root_dir, path)): + raise ValueError(f"Invalid {path}, is not a directory") + + return files, root_dir + + +def merge_dict(default_dict: dict, custom: Optional[dict] = None): + + params = dict(default_dict) + if custom: + params.update(custom) + return params diff --git a/enhancer/version.py b/enhancer/version.py new file mode 100644 index 0000000..f102a9c --- /dev/null +++ b/enhancer/version.py @@ -0,0 +1 @@ +__version__ = "0.0.1" diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000..8da22e1 --- /dev/null +++ b/environment.yml @@ -0,0 +1,8 @@ +name: enhancer + +dependencies: + - pip=21.0.1 + - python=3.8 + - pip: + - -r requirements.txt + - --find-links https://download.pytorch.org/whl/cu113/torch_stable.html diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..b3e5d7c --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,15 @@ +[tool.black] +line-length = 80 +target-version = ['py38'] +exclude = ''' + +( + /( + \.eggs # exclude a few common directories in the + | \.git # root of the project + | \.mypy_cache + | \.tox + | \.venv + )/ +) +''' diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..fb54920 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,19 @@ +boto3>=1.24.86 +huggingface-hub>=0.10.0 +hydra-core>=1.2.0 +joblib>=1.2.0 +librosa>=0.9.2 +mlflow>=1.29.0 +numpy>=1.23.3 +pesq==0.0.4 +protobuf>=3.19.6 +pystoi==0.3.3 +pytest-lazy-fixture>=0.6.3 +pytorch-lightning>=1.7.7 +scikit-learn>=1.1.2 +scipy>=1.9.1 +soundfile>=0.11.0 +torch>=1.12.1 +torch-audiomentations==0.11.0 +torchaudio>=0.12.1 +tqdm>=4.64.1 diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..309ac9a --- /dev/null +++ b/setup.cfg @@ -0,0 +1,100 @@ +# This file is used to configure your project. +# Read more about the various options under: +# http://setuptools.readthedocs.io/en/latest/setuptools.html#configuring-setup-using-setup-cfg-files + +[metadata] +name = enhancer +description = Deep learning for speech enhacement +author = Shahul Ess +author-email = shahules786@gmail.com +license = mit +long-description = file: README.md +long-description-content-type = text/markdown; charset=UTF-8; variant=GFM +# Change if running only on Windows, Mac or Linux (comma-separated) +platforms = Linux, Mac +# Add here all kinds of additional classifiers as defined under +# https://pypi.python.org/pypi?%3Aaction=list_classifiers +classifiers = + Development Status :: 4 - Beta + Programming Language :: Python + +[options] +zip_safe = False +packages = find: +include_package_data = True +# DON'T CHANGE THE FOLLOWING LINE! IT WILL BE UPDATED BY PYSCAFFOLD! +setup_requires = setuptools +# Add here dependencies of your project (semicolon/line-separated), e.g. +# install_requires = numpy; scipy +# Require a specific Python version, e.g. Python 2.7 or >= 3.4 +python_requires = >=3.8 + +[options.packages.find] +where = . 
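# packages are discovered from the repository root; test packages are excluded below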
+exclude = + tests + +[options.extras_require] +# Add here additional requirements for extra features, to install with: +# `pip install fastaudio[PDF]` like: +# PDF = ReportLab; RXP +# Add here test requirements (semicolon/line-separated) +testing = + pytest>=7.1.3 + pytest-cov>=4.0.0 +dev = + pre-commit>=2.20.0 + black>=22.8.0 + flake8>=5.0.4 +cli = + hydra-core >=1.1,<=1.2 + + +[options.entry_points] + +console_scripts = + enhancer-train=enhancer.cli.train:train + +[test] +# py.test options when running `python setup.py test` +# addopts = --verbose +extras = True + +[tool:pytest] +# Options for py.test: +# Specify command line options as you would do when invoking py.test directly. +# e.g. --cov-report html (or xml) for html/xml output or --junitxml junit.xml +# in order to write a coverage file that can be read by Jenkins. +addopts = + --cov enhancer --cov-report term-missing + --verbose +norecursedirs = + dist + build + .tox +testpaths = tests + +[aliases] +dists = bdist_wheel + +[bdist_wheel] +# Use this option if your package is pure-python +universal = 1 + +[build_sphinx] +source_dir = doc +build_dir = build/sphinx + +[devpi:upload] +# Options for the devpi: PyPI server and packaging tool +# VCS export must be deactivated since we are using setuptools-scm +no-vcs = 1 +formats = bdist_wheel + +[flake8] +# Some sane defaults for the code style checker flake8 +exclude = + .tox + build + dist + .eggs diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..79282b3 --- /dev/null +++ b/setup.py @@ -0,0 +1,63 @@ +import os +import sys +from pathlib import Path + +from pkg_resources import VersionConflict, require +from setuptools import find_packages, setup + +with open("README.md") as f: + long_description = f.read() + +with open("requirements.txt") as f: + requirements = f.read().splitlines() + +try: + require("setuptools>=38.3") +except VersionConflict: + print("Error: version of setuptools is too old (<38.3)!") + sys.exit(1) + + +ROOT_DIR = Path(__file__).parent.resolve() +# Creating the version file + +with open("version.txt") as f: + version = f.read() + +version = version.strip() +sha = "Unknown" + +if os.getenv("BUILD_VERSION"): + version = os.getenv("BUILD_VERSION") +elif sha != "Unknown": + version += "+" + sha[:7] +print("-- Building version " + version) + +version_path = ROOT_DIR / "enhancer" / "version.py" + +with open(version_path, "w") as f: + f.write("__version__ = '{}'\n".format(version)) + +if __name__ == "__main__": + setup( + name="enhancer", + namespace_packages=["enhancer"], + version=version, + packages=find_packages(), + install_requires=requirements, + description="Deep learning toolkit for speech enhancement", + long_description=long_description, + long_description_content_type="text/markdown", + author="Shahul Es", + author_email="shahules786@gmail.com", + url="", + classifiers=[ + "Development Status :: 4 - Beta", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: MIT License", + "Natural Language :: English", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Topic :: Scientific/Engineering", + ], + ) diff --git a/tests/data/vctk/clean_testset_wav/p257_166.wav b/tests/data/vctk/clean_testset_wav/p257_166.wav new file mode 100644 index 0000000..932df27 Binary files /dev/null and b/tests/data/vctk/clean_testset_wav/p257_166.wav differ diff --git a/tests/data/vctk/clean_testset_wav/p257_167.wav b/tests/data/vctk/clean_testset_wav/p257_167.wav new file mode 100644 index 0000000..6b72b76 Binary 
files /dev/null and b/tests/data/vctk/clean_testset_wav/p257_167.wav differ diff --git a/tests/data/vctk/noisy_testset_wav/p257_166.wav b/tests/data/vctk/noisy_testset_wav/p257_166.wav new file mode 100644 index 0000000..139b7cb Binary files /dev/null and b/tests/data/vctk/noisy_testset_wav/p257_166.wav differ diff --git a/tests/data/vctk/noisy_testset_wav/p257_167.wav b/tests/data/vctk/noisy_testset_wav/p257_167.wav new file mode 100644 index 0000000..59788b7 Binary files /dev/null and b/tests/data/vctk/noisy_testset_wav/p257_167.wav differ diff --git a/tests/loss_function_test.py b/tests/loss_function_test.py new file mode 100644 index 0000000..4d14871 --- /dev/null +++ b/tests/loss_function_test.py @@ -0,0 +1,32 @@ +import pytest +import torch + +from enhancer.loss import mean_absolute_error, mean_squared_error + +loss_functions = [mean_absolute_error(), mean_squared_error()] + + +def check_loss_shapes_compatibility(loss_fun): + + batch_size = 4 + shape = (1, 1000) + loss_fun(torch.rand(batch_size, *shape), torch.rand(batch_size, *shape)) + + with pytest.raises(TypeError): + loss_fun(torch.rand(4, *shape), torch.rand(6, *shape)) + + +@pytest.mark.parametrize("loss", loss_functions) +def test_loss_input_shapes(loss): + check_loss_shapes_compatibility(loss) + + +@pytest.mark.parametrize("loss", loss_functions) +def test_loss_output_type(loss): + + batch_size = 4 + prediction, target = torch.rand(batch_size, 1, 1000), torch.rand( + batch_size, 1, 1000 + ) + loss_value = loss(prediction, target) + assert isinstance(loss_value.item(), float) diff --git a/tests/models/complexnn_test.py b/tests/models/complexnn_test.py new file mode 100644 index 0000000..524a6cf --- /dev/null +++ b/tests/models/complexnn_test.py @@ -0,0 +1,50 @@ +import torch + +from enhancer.models.complexnn.conv import ComplexConv2d, ComplexConvTranspose2d +from enhancer.models.complexnn.rnn import ComplexLSTM +from enhancer.models.complexnn.utils import ComplexBatchNorm2D + + +def test_complexconv2d(): + sample_input = torch.rand(1, 2, 256, 13) + conv = ComplexConv2d( + 2, 32, kernel_size=(5, 2), stride=(2, 1), padding=(2, 1) + ) + with torch.no_grad(): + out = conv(sample_input) + assert out.shape == torch.Size([1, 32, 128, 13]) + + +def test_complexconvtranspose2d(): + sample_input = torch.rand(1, 512, 4, 13) + conv = ComplexConvTranspose2d( + 256 * 2, + 128 * 2, + kernel_size=(5, 2), + stride=(2, 1), + padding=(2, 0), + output_padding=(1, 0), + ) + with torch.no_grad(): + out = conv(sample_input) + + assert out.shape == torch.Size([1, 256, 8, 14]) + + +def test_complexlstm(): + sample_input = torch.rand(13, 2, 128) + lstm = ComplexLSTM(128 * 2, 128 * 2, projection_size=512 * 2) + with torch.no_grad(): + out = lstm(sample_input) + + assert out[0].shape == torch.Size([13, 1, 512]) + assert out[1].shape == torch.Size([13, 1, 512]) + + +def test_complexbatchnorm2d(): + sample_input = torch.rand(1, 64, 64, 14) + batchnorm = ComplexBatchNorm2D(num_features=64) + with torch.no_grad(): + out = batchnorm(sample_input) + + assert out.size() == sample_input.size() diff --git a/tests/models/demucs_test.py b/tests/models/demucs_test.py new file mode 100644 index 0000000..29e030e --- /dev/null +++ b/tests/models/demucs_test.py @@ -0,0 +1,43 @@ +import pytest +import torch + +from enhancer.data.dataset import EnhancerDataset +from enhancer.models import Demucs +from enhancer.utils.config import Files + + +@pytest.fixture +def vctk_dataset(): + root_dir = "tests/data/vctk" + files = Files( + train_clean="clean_testset_wav", + 
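# the two-file VCTK sample under tests/data doubles as both the train and
# test split, which keeps this fixture self-contained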
train_noisy="noisy_testset_wav", + test_clean="clean_testset_wav", + test_noisy="noisy_testset_wav", + ) + dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files) + return dataset + + +@pytest.mark.parametrize("batch_size,samples", [(1, 1000)]) +def test_forward(batch_size, samples): + model = Demucs() + model.eval() + + data = torch.rand(batch_size, 1, samples, requires_grad=False) + with torch.no_grad(): + _ = model(data) + + data = torch.rand(batch_size, 2, samples, requires_grad=False) + with torch.no_grad(): + with pytest.raises(ValueError): + _ = model(data) + + +@pytest.mark.parametrize( + "dataset,channels,loss", + [(pytest.lazy_fixture("vctk_dataset"), 1, ["mae", "mse"])], +) +def test_demucs_init(dataset, channels, loss): + with torch.no_grad(): + _ = Demucs(num_channels=channels, dataset=dataset, loss=loss) diff --git a/tests/models/test_dccrn.py b/tests/models/test_dccrn.py new file mode 100644 index 0000000..96a853b --- /dev/null +++ b/tests/models/test_dccrn.py @@ -0,0 +1,43 @@ +import pytest +import torch + +from enhancer.data.dataset import EnhancerDataset +from enhancer.models.dccrn import DCCRN +from enhancer.utils.config import Files + + +@pytest.fixture +def vctk_dataset(): + root_dir = "tests/data/vctk" + files = Files( + train_clean="clean_testset_wav", + train_noisy="noisy_testset_wav", + test_clean="clean_testset_wav", + test_noisy="noisy_testset_wav", + ) + dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files) + return dataset + + +@pytest.mark.parametrize("batch_size,samples", [(1, 1000)]) +def test_forward(batch_size, samples): + model = DCCRN() + model.eval() + + data = torch.rand(batch_size, 1, samples, requires_grad=False) + with torch.no_grad(): + _ = model(data) + + data = torch.rand(batch_size, 2, samples, requires_grad=False) + with torch.no_grad(): + with pytest.raises(ValueError): + _ = model(data) + + +@pytest.mark.parametrize( + "dataset,channels,loss", + [(pytest.lazy_fixture("vctk_dataset"), 1, ["mae", "mse"])], +) +def test_demucs_init(dataset, channels, loss): + with torch.no_grad(): + _ = DCCRN(num_channels=channels, dataset=dataset, loss=loss) diff --git a/tests/models/test_waveunet.py b/tests/models/test_waveunet.py new file mode 100644 index 0000000..9c4dd96 --- /dev/null +++ b/tests/models/test_waveunet.py @@ -0,0 +1,43 @@ +import pytest +import torch + +from enhancer.data.dataset import EnhancerDataset +from enhancer.models import WaveUnet +from enhancer.utils.config import Files + + +@pytest.fixture +def vctk_dataset(): + root_dir = "tests/data/vctk" + files = Files( + train_clean="clean_testset_wav", + train_noisy="noisy_testset_wav", + test_clean="clean_testset_wav", + test_noisy="noisy_testset_wav", + ) + dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files) + return dataset + + +@pytest.mark.parametrize("batch_size,samples", [(1, 1000)]) +def test_forward(batch_size, samples): + model = WaveUnet() + model.eval() + + data = torch.rand(batch_size, 1, samples, requires_grad=False) + with torch.no_grad(): + _ = model(data) + + data = torch.rand(batch_size, 2, samples, requires_grad=False) + with torch.no_grad(): + with pytest.raises(TypeError): + _ = model(data) + + +@pytest.mark.parametrize( + "dataset,channels,loss", + [(pytest.lazy_fixture("vctk_dataset"), 1, ["mae", "mse"])], +) +def test_demucs_init(dataset, channels, loss): + with torch.no_grad(): + _ = WaveUnet(num_channels=channels, dataset=dataset, loss=loss) diff --git a/tests/test_inference.py b/tests/test_inference.py new file mode 
100644 index 0000000..a6e2423 --- /dev/null +++ b/tests/test_inference.py @@ -0,0 +1,29 @@ +import pytest +import torch + +from enhancer.inference import Inference + + +@pytest.mark.parametrize( + "audio", + ["tests/data/vctk/clean_testset_wav/p257_166.wav", torch.rand(1, 2, 48000)], +) +def test_read_input(audio): + + read_audio = Inference.read_input(audio, 48000, 16000) + assert isinstance(read_audio, torch.Tensor) + assert read_audio.shape[0] == 1 + + +def test_batchify(): + rand = torch.rand(1, 1000) + batched_rand = Inference.batchify(rand, window_size=100, step_size=100) + assert batched_rand.shape[0] == 12 + + +def test_aggregate(): + rand = torch.rand(12, 1, 100) + agg_rand = Inference.aggreagate( + data=rand, window_size=100, total_frames=1000, step_size=100 + ) + assert agg_rand.shape[-1] == 1000 diff --git a/tests/transforms_test.py b/tests/transforms_test.py new file mode 100644 index 0000000..89425ad --- /dev/null +++ b/tests/transforms_test.py @@ -0,0 +1,18 @@ +import torch + +from enhancer.utils.transforms import ConviSTFT, ConvSTFT + + +def test_stft_istft(): + sample_input = torch.rand(1, 1, 16000) + stft = ConvSTFT(window_len=400, hop_size=100, nfft=512) + istft = ConviSTFT(window_len=400, hop_size=100, nfft=512) + + with torch.no_grad(): + spectrogram = stft(sample_input) + waveform = istft(spectrogram) + assert sample_input.shape == waveform.shape + assert ( + torch.isclose(waveform, sample_input).sum().item() + > sample_input.shape[-1] // 2 + ) diff --git a/tests/utils_test.py b/tests/utils_test.py new file mode 100644 index 0000000..cd5240c --- /dev/null +++ b/tests/utils_test.py @@ -0,0 +1,49 @@ +import numpy as np +import pytest +import torch + +from enhancer.data.fileprocessor import Fileprocessor +from enhancer.utils.io import Audio + + +def test_io_channel(): + + input_audio = np.random.rand(2, 32000) + audio = Audio(mono=True, return_tensor=False) + output_audio = audio(input_audio) + assert output_audio.shape[0] == 1 + + +def test_io_resampling(): + + input_audio = np.random.rand(1, 32000) + resampled_audio = Audio.resample_audio(input_audio, 16000, 8000) + + input_audio = torch.rand(1, 32000) + resampled_audio_pt = Audio.resample_audio(input_audio, 16000, 8000) + + assert resampled_audio.shape[1] == resampled_audio_pt.size(1) == 16000 + + +def test_fileprocessor_vctk(): + + fp = Fileprocessor.from_name( + "vctk", + "tests/data/vctk/clean_testset_wav", + "tests/data/vctk/noisy_testset_wav", + ) + matching_dict = fp.prepare_matching_dict() + assert len(matching_dict) == 2 + + +@pytest.mark.parametrize("dataset_name", ["vctk", "dns-2020"]) +def test_fileprocessor_names(dataset_name): + fp = Fileprocessor.from_name(dataset_name, "clean_dir", "noisy_dir") + assert hasattr(fp.matching_function, "__call__") + + +def test_fileprocessor_invaliname(): + with pytest.raises(ValueError): + _ = Fileprocessor.from_name( + "undefined", "clean_dir", "noisy_dir", 16000 + ).prepare_matching_dict() diff --git a/version.txt b/version.txt new file mode 100644 index 0000000..8acdd82 --- /dev/null +++ b/version.txt @@ -0,0 +1 @@ +0.0.1
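As a rough usage sketch (not part of the patch itself): once a checkpoint has been produced by the trainer, the high-level API introduced above could be exercised as follows. The checkpoint path is hypothetical; only calls that appear in this patch (`Demucs.from_pretrained`, `model.enhance`) are assumed.

from enhancer.models import Demucs

# hypothetical path to a checkpoint written by a previous training run
model = Demucs.from_pretrained("/path/to/checkpoint.ckpt", map_location="cpu")

# the bundled VCTK clips are 48 kHz (see tests/test_inference.py); pass
# save_output=True to write the result to disk instead of returning a tensor
enhanced = model.enhance(
    "tests/data/vctk/noisy_testset_wav/p257_166.wav",
    sampling_rate=48000,
    batch_size=8,
)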