Compare commits
	
		
			No commits in common. "main" and "dev-reformat" have entirely different histories.
		
	
	
		
			main
			...
			dev-reform
		
	
		
							
								
								
									
										2
									
								
								.flake8
								
								
								
								
							
							
						
						
									
										2
									
								
								.flake8
								
								
								
								
							|  | @ -1,5 +1,5 @@ | |||
| [flake8] | ||||
| per-file-ignores = "mayavoz/model/__init__.py:F401" | ||||
| per-file-ignores = __init__.py:F401 | ||||
| ignore = E203, E266, E501, W503 | ||||
| # line length is intentionally set to 80 here because black uses Bugbear | ||||
| # See https://github.com/psf/black/blob/master/README.md#line-length for more details | ||||
|  |  | |||
|  | @ -1 +0,0 @@ | |||
| notebooks/** linguist-vendored | ||||
|  | @ -1,51 +0,0 @@ | |||
| # This workflow will install Python dependencies, run tests and lint with a variety of Python versions | ||||
| # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions | ||||
| 
 | ||||
| name: mayavoz | ||||
| 
 | ||||
| on: | ||||
|   push: | ||||
|     branches: [ main ] | ||||
|   pull_request: | ||||
|     branches: [ main ] | ||||
| jobs: | ||||
|   build: | ||||
|     runs-on: ubuntu-latest | ||||
|     strategy: | ||||
|       matrix: | ||||
|         python-version: [3.8] | ||||
|     steps: | ||||
|     - uses: actions/checkout@v2 | ||||
|     - name: Set up Python | ||||
|       uses: actions/setup-python@v1.1.1 | ||||
|       env : | ||||
|         ACTIONS_ALLOW_UNSECURE_COMMANDS : true | ||||
|       with: | ||||
|         python-version: ${{ matrix.python-version }} | ||||
|     - name: Cache pip | ||||
|       uses: actions/cache@v1 | ||||
|       with: | ||||
|        path: ~/.cache/pip # This path is specific to Ubuntu | ||||
|        # Look to see if there is a cache hit for the corresponding requirements file | ||||
|        key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }} | ||||
|        restore-keys: | | ||||
|          ${{ runner.os }}-pip- | ||||
|          ${{ runner.os }}- | ||||
|     # You can test your matrix by printing the current Python version | ||||
|     - name: Display Python version | ||||
|       run: python -c "import sys; print(sys.version)" | ||||
|     - name: Install dependencies | ||||
|       run: | | ||||
|         python -m pip install --upgrade pip | ||||
|         sudo apt-get install libsndfile1 | ||||
|         pip install -r requirements.txt | ||||
|         pip install black pytest-cov | ||||
|     - name: Install mayavoz | ||||
|       run: | | ||||
|           pip install -e .[dev,testing] | ||||
|     - name: Run black | ||||
|       run: | ||||
|         black --check .  --exclude mayavoz/version.py | ||||
|     - name: Test with pytest | ||||
|       run: | ||||
|         pytest tests --cov=mayavoz/ | ||||
|  | @ -1,10 +1,4 @@ | |||
| #local | ||||
| cleaned_my_voice.wav | ||||
| lightning_logs/ | ||||
| my_voice.wav | ||||
| pretrained/ | ||||
| *.ckpt | ||||
| *_local.yaml | ||||
| cli/train_config/dataset/Vctk_local.yaml | ||||
| .DS_Store | ||||
| outputs/ | ||||
|  |  | |||
|  | @ -23,7 +23,6 @@ repos: | |||
|       hooks: | ||||
|       - id: flake8 | ||||
|         args: ['--ignore=E203,E501,F811,E712,W503'] | ||||
|         exclude: __init__.py | ||||
| 
 | ||||
|     # Formatting, Whitespace, etc | ||||
|     - repo: https://github.com/pre-commit/pre-commit-hooks | ||||
|  |  | |||
|  | @ -1,46 +0,0 @@ | |||
| # Contributing | ||||
| 
 | ||||
| Hi there 👋 | ||||
| 
 | ||||
| If you're reading this I hope that you're looking forward to adding value to Mayavoz. This document will help you to get started with your journey. | ||||
| 
 | ||||
| ## How to get your code in Mayavoz | ||||
| 
 | ||||
| 1. We use git and GitHub. | ||||
| 
 | ||||
| 2. Fork the mayavoz repository (https://github.com/shahules786/mayavoz) on GitHub under your own account. (This creates a copy of mayavoz under your account, and GitHub knows where it came from, and we typically call this “upstream”.) | ||||
| 
 | ||||
| 3. Clone your own mayavoz repository. git clone https://github.com/ <your-account> /mayavoz (This downloads the git repository to your machine, git knows where it came from, and calls it “origin”.) | ||||
| 
 | ||||
| 4. Create a branch for each specific feature you are developing. git checkout -b your-branch-name | ||||
| 
 | ||||
| 5. Make + commit changes. git add files-you-changed ... git commit -m "Short message about what you did" | ||||
| 
 | ||||
| 6. Push the branch to your GitHub repository. git push origin your-branch-name | ||||
| 
 | ||||
| 7. Navigate to GitHub, and create a pull request from your branch to the upstream repository mayavoz/mayavoz, to the “develop” branch. | ||||
| 
 | ||||
| 8. The Pull Request (PR) appears on the upstream repository. Discuss your contribution there. If you push more changes to your branch on GitHub (on your repository), they are added to the PR. | ||||
| 
 | ||||
| 9. When the reviewer is satisfied that the code improves repository quality, they can merge. | ||||
| 
 | ||||
| Note that CI tests will be run when you create a PR. If you want to be sure that your code will not fail these tests, we have set up pre-commit hooks that you can install. | ||||
| 
 | ||||
| **If you're worried about things not being perfect with your code, we will work togethor and make it perfect. So, make your move!** | ||||
| 
 | ||||
| ## Formating | ||||
| 
 | ||||
| We use [black](https://black.readthedocs.io/en/stable/) and [flake8](https://flake8.pycqa.org/en/latest/) for code formating. Please ensure that you use the same before submitting the PR. | ||||
| 
 | ||||
| 
 | ||||
| ## Testing | ||||
| We adopt unit testing using [pytest](https://docs.pytest.org/en/latest/contents.html) | ||||
| Please make sure that adding your new component does not decrease test coverage. | ||||
| 
 | ||||
| ## Other tools | ||||
| The use of [per-commit](https://pre-commit.com/) is recommended to ensure different requirements such as code formating, etc. | ||||
| 
 | ||||
| ## How to start contributing to Mayavoz? | ||||
| 
 | ||||
| 1. Checkout issues marked as `good first issue`, let us know you're interested in working on some issue by commenting under it. | ||||
| 2. For others, I would suggest you to explore mayavoz. One way to do is to use it to train your own model. This was you might end by finding a new unreported bug or getting an idea to improve Mayavoz. | ||||
							
								
								
									
										20
									
								
								LICENSE
								
								
								
								
							
							
						
						
									
										20
									
								
								LICENSE
								
								
								
								
							|  | @ -1,20 +0,0 @@ | |||
| MIT License | ||||
| 
 | ||||
| Copyright (c) 2022 Shahul Es | ||||
| Permission is hereby granted, free of charge, to any person obtaining a copy | ||||
| of this software and associated documentation files (the "Software"), to deal | ||||
| in the Software without restriction, including without limitation the rights | ||||
| to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||||
| copies of the Software, and to permit persons to whom the Software is | ||||
| furnished to do so, subject to the following conditions: | ||||
| 
 | ||||
| The above copyright notice and this permission notice shall be included in all | ||||
| copies or substantial portions of the Software. | ||||
| 
 | ||||
| THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||||
| IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||||
| FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||||
| AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||||
| LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||||
| OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||||
| SOFTWARE. | ||||
|  | @ -1,4 +0,0 @@ | |||
| recursive-include mayavoz *.py | ||||
| recursive-include mayavoz *.yaml | ||||
| global-exclude *.pyc | ||||
| global-exclude __pycache__ | ||||
							
								
								
									
										82
									
								
								README.md
								
								
								
								
							
							
						
						
									
										82
									
								
								README.md
								
								
								
								
							|  | @ -1,78 +1,6 @@ | |||
| <p align="center"> | ||||
|   <img src="https://user-images.githubusercontent.com/25312635/195514652-e4526cd1-1177-48e9-a80d-c8bfdb95d35f.png" /> | ||||
| </p> | ||||
| # enhancer | ||||
| Enhancer is a Pytorch-based opensource toolkit for speech enhancement. It is designed to save time for audio researchers. Is provides easy to use pretrained audio enhancement models and facilitates highly customisable custom model training . Enhancer provides | ||||
| 
 | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
| 
 | ||||
| mayavoz is a Pytorch-based opensource toolkit for speech enhancement. It is designed to save time for audio practioners & researchers. It provides easy to use pretrained speech enhancement models and facilitates highly customisable model training. | ||||
| 
 | ||||
| | **[Quick Start](#quick-start-fire)** | **[Installation](#installation)** | **[Tutorials](https://github.com/shahules786/enhancer/tree/main/notebooks)** | **[Available Recipes](#recipes)** | **[Demo](#demo)** | ||||
| ## Key features :key: | ||||
| 
 | ||||
| * Various pretrained models nicely integrated with [huggingface hub](https://huggingface.co/docs/hub/index) :hugs: that users can select and use without any hastle. | ||||
| * :package: Ability to train and validate your own custom speech enhancement models with just under 10 lines of code! | ||||
| * :magic_wand: A command line tool that facilitates training of highly customisable speech enhacement models from the terminal itself! | ||||
| * :zap: Supports multi-gpu training integrated with [Pytorch Lightning](https://pytorchlightning.ai/). | ||||
| * :shield: data augmentations integrated using [torch-augmentations](https://github.com/asteroid-team/torch-audiomentations) | ||||
| 
 | ||||
| 
 | ||||
| ## Demo | ||||
| 
 | ||||
| Noisy speech followed by enhanced version. | ||||
| 
 | ||||
| https://user-images.githubusercontent.com/25312635/203756185-737557f4-6e21-4146-aa2c-95da69d0de4c.mp4 | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| ## Quick Start :fire: | ||||
| ``` python | ||||
| from mayavoz.models import Mayamodel | ||||
| 
 | ||||
| model = Mayamodel.from_pretrained("shahules786/mayavoz-waveunet-valentini-28spk") | ||||
| model.enhance("noisy_audio.wav") | ||||
| ``` | ||||
| 
 | ||||
| ## Recipes | ||||
| 
 | ||||
| | Model     | Dataset      | STOI    | PESQ  | URL                           | | ||||
| | :---:     |  :---:       | :---:   | :---: | :---:                         | | ||||
| | WaveUnet  | Valentini-28spk   | 0.836   | 2.78  |  shahules786/mayavoz-waveunet-valentini-28spk      | | ||||
| | Demucs    | Valentini-28spk   | 0.961   | 2.56  |  shahules786/mayavoz-demucs-valentini-28spk       | | ||||
| | DCCRN     | Valentini-28spk   | 0.724   | 2.55  |  shahules786/mayavoz-dccrn-valentini-28spk         | | ||||
| | Demucs     | MS-SNSD-20hrs  | 0.56 | 1.26  | shahules786/mayavoz-demucs-ms-snsd-20       | | ||||
| 
 | ||||
| Test scores are based on respective test set associated with train dataset. | ||||
| 
 | ||||
| **See [tutorials](/notebooks/) to train your custom model** | ||||
| 
 | ||||
| ## Installation | ||||
| Only Python 3.8+ is officially supported (though it might work with Python 3.7) | ||||
| 
 | ||||
| - With Pypi | ||||
| ``` | ||||
| pip install mayavoz | ||||
| ``` | ||||
| 
 | ||||
| - With conda | ||||
| 
 | ||||
| ``` | ||||
| conda env create -f environment.yml | ||||
| conda activate mayavoz | ||||
| ``` | ||||
| 
 | ||||
| - From source code | ||||
| ``` | ||||
| git clone url | ||||
| cd mayavoz | ||||
| pip install -e . | ||||
| ``` | ||||
| 
 | ||||
| ## Support | ||||
| 
 | ||||
| For commercial enquiries and scientific consulting, please [contact me](https://shahules786.github.io/). | ||||
| 
 | ||||
| ### Acknowledgements | ||||
| Sincere gratitude to [AMPLYFI](https://amplyfi.com/) for supporting this project. | ||||
| * Various pretrained models nicely integrated with huggingface that users can select and use without any hastle. | ||||
| * Ability to train and validation your own custom speech enhancement models with just under 10 lines of code! | ||||
| * A command line tool that facilitates training of highly customisable speech enhacement models from the terminal itself! | ||||
|  | @ -3,17 +3,11 @@ from types import MethodType | |||
| 
 | ||||
| import hydra | ||||
| from hydra.utils import instantiate | ||||
| from omegaconf import DictConfig, OmegaConf | ||||
| from pytorch_lightning.callbacks import ( | ||||
|     EarlyStopping, | ||||
|     LearningRateMonitor, | ||||
|     ModelCheckpoint, | ||||
| ) | ||||
| from omegaconf import DictConfig | ||||
| from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint | ||||
| from pytorch_lightning.loggers import MLFlowLogger | ||||
| from torch.optim.lr_scheduler import ReduceLROnPlateau | ||||
| 
 | ||||
| # from torch_audiomentations import Compose, Shift | ||||
| 
 | ||||
| os.environ["HYDRA_FULL_ERROR"] = "1" | ||||
| JOB_ID = os.environ.get("SLURM_JOBID", "0") | ||||
| 
 | ||||
|  | @ -21,8 +15,6 @@ JOB_ID = os.environ.get("SLURM_JOBID", "0") | |||
| @hydra.main(config_path="train_config", config_name="config") | ||||
| def main(config: DictConfig): | ||||
| 
 | ||||
|     OmegaConf.save(config, "config_log.yaml") | ||||
| 
 | ||||
|     callbacks = [] | ||||
|     logger = MLFlowLogger( | ||||
|         experiment_name=config.mlflow.experiment_name, | ||||
|  | @ -31,13 +23,8 @@ def main(config: DictConfig): | |||
|     ) | ||||
| 
 | ||||
|     parameters = config.hyperparameters | ||||
|     # apply_augmentations = Compose( | ||||
|     #     [ | ||||
|     #         Shift(min_shift=0.5, max_shift=1.0, shift_unit="seconds", p=0.5), | ||||
|     #     ] | ||||
|     # ) | ||||
| 
 | ||||
|     dataset = instantiate(config.dataset, augmentations=None) | ||||
|     dataset = instantiate(config.dataset) | ||||
|     model = instantiate( | ||||
|         config.model, | ||||
|         dataset=dataset, | ||||
|  | @ -50,15 +37,12 @@ def main(config: DictConfig): | |||
|     checkpoint = ModelCheckpoint( | ||||
|         dirpath="./model", | ||||
|         filename=f"model_{JOB_ID}", | ||||
|         monitor="valid_loss", | ||||
|         verbose=False, | ||||
|         monitor="val_loss", | ||||
|         verbose=True, | ||||
|         mode=direction, | ||||
|         every_n_epochs=1, | ||||
|     ) | ||||
|     callbacks.append(checkpoint) | ||||
|     callbacks.append(LearningRateMonitor(logging_interval="epoch")) | ||||
| 
 | ||||
|     if parameters.get("Early_stop", False): | ||||
|     early_stopping = EarlyStopping( | ||||
|         monitor="val_loss", | ||||
|         mode=direction, | ||||
|  | @ -69,11 +53,11 @@ def main(config: DictConfig): | |||
|     ) | ||||
|     callbacks.append(early_stopping) | ||||
| 
 | ||||
|     def configure_optimizers(self): | ||||
|     def configure_optimizer(self): | ||||
|         optimizer = instantiate( | ||||
|             config.optimizer, | ||||
|             lr=parameters.get("lr"), | ||||
|             params=self.parameters(), | ||||
|             parameters=self.parameters(), | ||||
|         ) | ||||
|         scheduler = ReduceLROnPlateau( | ||||
|             optimizer=optimizer, | ||||
|  | @ -83,37 +67,18 @@ def main(config: DictConfig): | |||
|             min_lr=parameters.get("min_lr", 1e-6), | ||||
|             patience=parameters.get("ReduceLr_patience", 3), | ||||
|         ) | ||||
|         return { | ||||
|             "optimizer": optimizer, | ||||
|             "lr_scheduler": scheduler, | ||||
|             "monitor": f'valid_{parameters.get("ReduceLr_monitor", "loss")}', | ||||
|         } | ||||
|         return {"optimizer": optimizer, "lr_scheduler": scheduler} | ||||
| 
 | ||||
|     model.configure_optimizers = MethodType(configure_optimizers, model) | ||||
|     model.configure_parameters = MethodType(configure_optimizer, model) | ||||
| 
 | ||||
|     trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks) | ||||
|     trainer.fit(model) | ||||
|     trainer.test(model) | ||||
| 
 | ||||
|     logger.experiment.log_artifact( | ||||
|         logger.run_id, f"{trainer.default_root_dir}/config_log.yaml" | ||||
|     ) | ||||
| 
 | ||||
|     saved_location = os.path.join( | ||||
|         trainer.default_root_dir, "model", f"model_{JOB_ID}.ckpt" | ||||
|     ) | ||||
|     if os.path.isfile(saved_location): | ||||
|         logger.experiment.log_artifact(logger.run_id, saved_location) | ||||
|         logger.experiment.log_param( | ||||
|             logger.run_id, | ||||
|             "num_train_steps_per_epoch", | ||||
|             dataset.train__len__() / dataset.batch_size, | ||||
|         ) | ||||
|         logger.experiment.log_param( | ||||
|             logger.run_id, | ||||
|             "num_valid_steps_per_epoch", | ||||
|             dataset.val__len__() / dataset.batch_size, | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == "__main__": | ||||
|  | @ -0,0 +1,12 @@ | |||
| _target_: enhancer.data.dataset.EnhancerDataset | ||||
| root_dir : /Users/shahules/Myprojects/enhancer/datasets/vctk_test | ||||
| name : dns-2020 | ||||
| duration : 1.0 | ||||
| sampling_rate: 16000 | ||||
| batch_size: 32 | ||||
| files: | ||||
|   root_dir : /Users/shahules/Myprojects/enhancer/datasets/vctk_test | ||||
|   train_clean : clean_test_wav | ||||
|   test_clean : clean_test_wav | ||||
|   train_noisy : clean_test_wav | ||||
|   test_noisy : clean_test_wav | ||||
|  | @ -1,11 +1,10 @@ | |||
| _target_: mayavoz.data.dataset.MayaDataset | ||||
| _target_: enhancer.data.dataset.EnhancerDataset | ||||
| name : vctk | ||||
| root_dir : /scratch/c.sistc3/DS_10283_2791 | ||||
| duration : 2 | ||||
| stride : 1 | ||||
| duration : 1.0 | ||||
| sampling_rate: 16000 | ||||
| batch_size: 128 | ||||
| valid_minutes : 25 | ||||
| batch_size: 64 | ||||
| 
 | ||||
| files: | ||||
|   train_clean : clean_trainset_28spk_wav | ||||
|   test_clean : clean_testset_wav | ||||
|  | @ -0,0 +1,13 @@ | |||
| _target_: enhancer.data.dataset.EnhancerDataset | ||||
| name : vctk | ||||
| root_dir : /Users/shahules/Myprojects/enhancer/datasets/vctk | ||||
| duration : 1.0 | ||||
| sampling_rate: 16000 | ||||
| batch_size: 64 | ||||
| num_workers : 0 | ||||
| 
 | ||||
| files: | ||||
|   train_clean : clean_testset_wav | ||||
|   test_clean : clean_testset_wav | ||||
|   train_noisy : noisy_testset_wav | ||||
|   test_noisy : noisy_testset_wav | ||||
|  | @ -0,0 +1,7 @@ | |||
| loss : mse | ||||
| metric : mae | ||||
| lr : 0.0001 | ||||
| ReduceLr_patience : 5 | ||||
| ReduceLr_factor : 0.1 | ||||
| min_lr : 0.000001 | ||||
| EarlyStopping_factor : 10 | ||||
|  | @ -0,0 +1,2 @@ | |||
| experiment_name : shahules/enhancer | ||||
| run_name : baseline | ||||
|  | @ -1,13 +1,13 @@ | |||
| _target_: mayavoz.models.demucs.Demucs | ||||
| _target_: enhancer.models.demucs.Demucs | ||||
| num_channels: 1 | ||||
| resample: 4 | ||||
| resample: 2 | ||||
| sampling_rate : 16000 | ||||
| 
 | ||||
| encoder_decoder: | ||||
|   depth: 4 | ||||
|   initial_output_channels: 64 | ||||
|   depth: 5 | ||||
|   initial_output_channels: 32 | ||||
|   kernel_size: 8 | ||||
|   stride: 4 | ||||
|   stride: 1 | ||||
|   growth_factor: 2 | ||||
|   glu: True | ||||
| 
 | ||||
|  | @ -1,5 +1,5 @@ | |||
| _target_: mayavoz.models.waveunet.WaveUnet | ||||
| _target_: enhancer.models.waveunet.WaveUnet | ||||
| num_channels : 1 | ||||
| depth : 9 | ||||
| depth : 12 | ||||
| initial_output_channels: 24 | ||||
| sampling_rate : 16000 | ||||
|  | @ -1,5 +1,5 @@ | |||
| _target_: pytorch_lightning.Trainer | ||||
| accelerator: gpu | ||||
| accelerator: auto | ||||
| accumulate_grad_batches: 1 | ||||
| amp_backend: native | ||||
| auto_lr_find: True | ||||
|  | @ -9,7 +9,7 @@ benchmark: False | |||
| check_val_every_n_epoch: 1 | ||||
| detect_anomaly: False | ||||
| deterministic: False | ||||
| devices: 1 | ||||
| devices: -1 | ||||
| enable_checkpointing: True | ||||
| enable_model_summary: True | ||||
| enable_progress_bar: True | ||||
|  | @ -22,9 +22,9 @@ limit_predict_batches: 1.0 | |||
| limit_test_batches: 1.0 | ||||
| limit_train_batches: 1.0 | ||||
| limit_val_batches: 1.0 | ||||
| log_every_n_steps: 50 | ||||
| max_epochs: 200 | ||||
| max_steps: -1 | ||||
| log_every_n_steps: 1 | ||||
| max_epochs: 10 | ||||
| max_steps: null | ||||
| max_time: null | ||||
| min_epochs: 1 | ||||
| min_steps: null | ||||
|  | @ -0,0 +1 @@ | |||
| from enhancer.data.dataset import EnhancerDataset | ||||
|  | @ -0,0 +1,220 @@ | |||
| import math | ||||
| import multiprocessing | ||||
| import os | ||||
| from typing import Optional | ||||
| 
 | ||||
| import pytorch_lightning as pl | ||||
| import torch.nn.functional as F | ||||
| from torch.utils.data import DataLoader, Dataset, IterableDataset | ||||
| 
 | ||||
| from enhancer.data.fileprocessor import Fileprocessor | ||||
| from enhancer.utils import check_files | ||||
| from enhancer.utils.config import Files | ||||
| from enhancer.utils.io import Audio | ||||
| from enhancer.utils.random import create_unique_rng | ||||
| 
 | ||||
| 
 | ||||
| class TrainDataset(IterableDataset): | ||||
|     def __init__(self, dataset): | ||||
|         self.dataset = dataset | ||||
| 
 | ||||
|     def __iter__(self): | ||||
|         return self.dataset.train__iter__() | ||||
| 
 | ||||
|     def __len__(self): | ||||
|         return self.dataset.train__len__() | ||||
| 
 | ||||
| 
 | ||||
| class ValidDataset(Dataset): | ||||
|     def __init__(self, dataset): | ||||
|         self.dataset = dataset | ||||
| 
 | ||||
|     def __getitem__(self, idx): | ||||
|         return self.dataset.val__getitem__(idx) | ||||
| 
 | ||||
|     def __len__(self): | ||||
|         return self.dataset.val__len__() | ||||
| 
 | ||||
| 
 | ||||
| class TaskDataset(pl.LightningDataModule): | ||||
|     def __init__( | ||||
|         self, | ||||
|         name: str, | ||||
|         root_dir: str, | ||||
|         files: Files, | ||||
|         duration: float = 1.0, | ||||
|         sampling_rate: int = 48000, | ||||
|         matching_function=None, | ||||
|         batch_size=32, | ||||
|         num_workers: Optional[int] = None, | ||||
|     ): | ||||
|         super().__init__() | ||||
| 
 | ||||
|         self.name = name | ||||
|         self.files, self.root_dir = check_files(root_dir, files) | ||||
|         self.duration = duration | ||||
|         self.sampling_rate = sampling_rate | ||||
|         self.batch_size = batch_size | ||||
|         self.matching_function = matching_function | ||||
|         self._validation = [] | ||||
|         if num_workers is None: | ||||
|             num_workers = multiprocessing.cpu_count() // 2 | ||||
|         self.num_workers = num_workers | ||||
| 
 | ||||
|     def setup(self, stage: Optional[str] = None): | ||||
| 
 | ||||
|         if stage in ("fit", None): | ||||
| 
 | ||||
|             train_clean = os.path.join(self.root_dir, self.files.train_clean) | ||||
|             train_noisy = os.path.join(self.root_dir, self.files.train_noisy) | ||||
|             fp = Fileprocessor.from_name( | ||||
|                 self.name, train_clean, train_noisy, self.matching_function | ||||
|             ) | ||||
|             self.train_data = fp.prepare_matching_dict() | ||||
| 
 | ||||
|             val_clean = os.path.join(self.root_dir, self.files.test_clean) | ||||
|             val_noisy = os.path.join(self.root_dir, self.files.test_noisy) | ||||
|             fp = Fileprocessor.from_name( | ||||
|                 self.name, val_clean, val_noisy, self.matching_function | ||||
|             ) | ||||
|             val_data = fp.prepare_matching_dict() | ||||
| 
 | ||||
|             for item in val_data: | ||||
|                 clean, noisy, total_dur = item.values() | ||||
|                 if total_dur < self.duration: | ||||
|                     continue | ||||
|                 num_segments = round(total_dur / self.duration) | ||||
|                 for index in range(num_segments): | ||||
|                     start_time = index * self.duration | ||||
|                     self._validation.append( | ||||
|                         ({"clean": clean, "noisy": noisy}, start_time) | ||||
|                     ) | ||||
| 
 | ||||
|     def train_dataloader(self): | ||||
|         return DataLoader( | ||||
|             TrainDataset(self), | ||||
|             batch_size=self.batch_size, | ||||
|             num_workers=self.num_workers, | ||||
|         ) | ||||
| 
 | ||||
|     def val_dataloader(self): | ||||
|         return DataLoader( | ||||
|             ValidDataset(self), | ||||
|             batch_size=self.batch_size, | ||||
|             num_workers=self.num_workers, | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| class EnhancerDataset(TaskDataset): | ||||
|     """ | ||||
|     Dataset object for creating clean-noisy speech enhancement datasets | ||||
|     paramters: | ||||
|     name : str | ||||
|         name of the dataset | ||||
|     root_dir : str | ||||
|         root directory of the dataset containing clean/noisy folders | ||||
|     files : Files | ||||
|         dataclass containing train_clean, train_noisy, test_clean, test_noisy | ||||
|         folder names (refer enhancer.utils.Files dataclass) | ||||
|     duration : float | ||||
|         expected audio duration of single audio sample for training | ||||
|     sampling_rate : int | ||||
|         desired sampling rate | ||||
|     batch_size : int | ||||
|         batch size of each batch | ||||
|     num_workers : int | ||||
|         num workers to be used while training | ||||
|     matching_function : str | ||||
|         maching functions - (one_to_one,one_to_many). Default set to None. | ||||
|         use one_to_one mapping for datasets with one noisy file for each clean file | ||||
|         use one_to_many mapping for multiple noisy files for each clean file | ||||
| 
 | ||||
| 
 | ||||
|     """ | ||||
| 
 | ||||
|     def __init__( | ||||
|         self, | ||||
|         name: str, | ||||
|         root_dir: str, | ||||
|         files: Files, | ||||
|         duration=1.0, | ||||
|         sampling_rate=48000, | ||||
|         matching_function=None, | ||||
|         batch_size=32, | ||||
|         num_workers: Optional[int] = None, | ||||
|     ): | ||||
| 
 | ||||
|         super().__init__( | ||||
|             name=name, | ||||
|             root_dir=root_dir, | ||||
|             files=files, | ||||
|             sampling_rate=sampling_rate, | ||||
|             duration=duration, | ||||
|             matching_function=matching_function, | ||||
|             batch_size=batch_size, | ||||
|             num_workers=num_workers, | ||||
|         ) | ||||
| 
 | ||||
|         self.sampling_rate = sampling_rate | ||||
|         self.files = files | ||||
|         self.duration = max(1.0, duration) | ||||
|         self.audio = Audio(self.sampling_rate, mono=True, return_tensor=True) | ||||
| 
 | ||||
|     def setup(self, stage: Optional[str] = None): | ||||
| 
 | ||||
|         super().setup(stage=stage) | ||||
| 
 | ||||
|     def train__iter__(self): | ||||
| 
 | ||||
|         rng = create_unique_rng(self.model.current_epoch) | ||||
| 
 | ||||
|         while True: | ||||
| 
 | ||||
|             file_dict, *_ = rng.choices( | ||||
|                 self.train_data, | ||||
|                 k=1, | ||||
|                 weights=[file["duration"] for file in self.train_data], | ||||
|             ) | ||||
|             file_duration = file_dict["duration"] | ||||
|             start_time = round(rng.uniform(0, file_duration - self.duration), 2) | ||||
|             data = self.prepare_segment(file_dict, start_time) | ||||
|             yield data | ||||
| 
 | ||||
|     def val__getitem__(self, idx): | ||||
|         return self.prepare_segment(*self._validation[idx]) | ||||
| 
 | ||||
|     def prepare_segment(self, file_dict: dict, start_time: float): | ||||
| 
 | ||||
|         clean_segment = self.audio( | ||||
|             file_dict["clean"], offset=start_time, duration=self.duration | ||||
|         ) | ||||
|         noisy_segment = self.audio( | ||||
|             file_dict["noisy"], offset=start_time, duration=self.duration | ||||
|         ) | ||||
|         clean_segment = F.pad( | ||||
|             clean_segment, | ||||
|             ( | ||||
|                 0, | ||||
|                 int( | ||||
|                     self.duration * self.sampling_rate - clean_segment.shape[-1] | ||||
|                 ), | ||||
|             ), | ||||
|         ) | ||||
|         noisy_segment = F.pad( | ||||
|             noisy_segment, | ||||
|             ( | ||||
|                 0, | ||||
|                 int( | ||||
|                     self.duration * self.sampling_rate - noisy_segment.shape[-1] | ||||
|                 ), | ||||
|             ), | ||||
|         ) | ||||
|         return {"clean": clean_segment, "noisy": noisy_segment} | ||||
| 
 | ||||
|     def train__len__(self): | ||||
|         return math.ceil( | ||||
|             sum([file["duration"] for file in self.train_data]) / self.duration | ||||
|         ) | ||||
| 
 | ||||
|     def val__len__(self): | ||||
|         return len(self._validation) | ||||
|  | @ -55,31 +55,32 @@ class ProcessorFunctions: | |||
|         One clean audio have multiple noisy audio files | ||||
|         """ | ||||
| 
 | ||||
|         matching_wavfiles = list() | ||||
|         matching_wavfiles = dict() | ||||
|         clean_filenames = [ | ||||
|             file.split("/")[-1] | ||||
|             for file in glob.glob(os.path.join(clean_path, "*.wav")) | ||||
|         ] | ||||
|         for clean_file in clean_filenames: | ||||
|             noisy_filenames = glob.glob( | ||||
|                 os.path.join(noisy_path, f"*_{clean_file}") | ||||
|                 os.path.join(noisy_path, f"*_{clean_file}.wav") | ||||
|             ) | ||||
|             for noisy_file in noisy_filenames: | ||||
| 
 | ||||
|                 sr_clean, clean_wav = wavfile.read( | ||||
|                 sr_clean, clean_file = wavfile.read( | ||||
|                     os.path.join(clean_path, clean_file) | ||||
|                 ) | ||||
|                 sr_noisy, noisy_wav = wavfile.read(noisy_file) | ||||
|                 if (clean_wav.shape[-1] == noisy_wav.shape[-1]) and ( | ||||
|                 sr_noisy, noisy_file = wavfile.read(noisy_file) | ||||
|                 if (clean_file.shape[-1] == noisy_file.shape[-1]) and ( | ||||
|                     sr_clean == sr_noisy | ||||
|                 ): | ||||
|                     matching_wavfiles.append( | ||||
|                     matching_wavfiles.update( | ||||
|                         { | ||||
|                             "clean": os.path.join(clean_path, clean_file), | ||||
|                             "noisy": noisy_file, | ||||
|                             "duration": clean_wav.shape[-1] / sr_clean, | ||||
|                             "duration": clean_file.shape[-1] / sr_clean, | ||||
|                         } | ||||
|                     ) | ||||
| 
 | ||||
|         return matching_wavfiles | ||||
| 
 | ||||
| 
 | ||||
|  | @ -93,14 +94,10 @@ class Fileprocessor: | |||
|     def from_name(cls, name: str, clean_dir, noisy_dir, matching_function=None): | ||||
| 
 | ||||
|         if matching_function is None: | ||||
|             if name.lower() in ("vctk", "valentini"): | ||||
|             if name.lower() == "vctk": | ||||
|                 return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_one) | ||||
|             elif name.lower() == "ms-snsd": | ||||
|             elif name.lower() == "dns-2020": | ||||
|                 return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_many) | ||||
|             else: | ||||
|                 raise ValueError( | ||||
|                     f"Invalid matching function, Please use valid matching function from {MATCHING_FNS}" | ||||
|                 ) | ||||
|         else: | ||||
|             if matching_function not in MATCHING_FNS: | ||||
|                 raise ValueError( | ||||
|  | @ -8,7 +8,7 @@ from librosa import load as load_audio | |||
| from scipy.io import wavfile | ||||
| from scipy.signal import get_window | ||||
| 
 | ||||
| from mayavoz.utils import Audio | ||||
| from enhancer.utils import Audio | ||||
| 
 | ||||
| 
 | ||||
| class Inference: | ||||
|  | @ -91,11 +91,10 @@ class Inference: | |||
|         window_size: int, | ||||
|         total_frames: int, | ||||
|         step_size: Optional[int] = None, | ||||
|         window="hamming", | ||||
|         window="hanning", | ||||
|     ): | ||||
|         """ | ||||
|         stitch batched waveform into single waveform. (Overlap-add) | ||||
|         inspired from https://github.com/asteroid-team/asteroid | ||||
|         arguments: | ||||
|             data: batched waveform | ||||
|             window_size : window_size used to batch waveform | ||||
|  | @ -140,9 +139,7 @@ class Inference: | |||
|         if filename.is_file(): | ||||
|             raise FileExistsError(f"file {filename} already exists") | ||||
|         else: | ||||
|             wavfile.write( | ||||
|                 filename, rate=sr, data=waveform.detach().cpu().numpy() | ||||
|             ) | ||||
|             wavfile.write(filename, rate=sr, data=waveform.detach().cpu()) | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def prepare_output( | ||||
|  | @ -1,11 +1,5 @@ | |||
| import warnings | ||||
| 
 | ||||
| import numpy as np | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from torchmetrics import ScaleInvariantSignalNoiseRatio | ||||
| from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality | ||||
| from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility | ||||
| 
 | ||||
| 
 | ||||
| class mean_squared_error(nn.Module): | ||||
|  | @ -18,7 +12,6 @@ class mean_squared_error(nn.Module): | |||
| 
 | ||||
|         self.loss_fun = nn.MSELoss(reduction=reduction) | ||||
|         self.higher_better = False | ||||
|         self.name = "mse" | ||||
| 
 | ||||
|     def forward(self, prediction: torch.Tensor, target: torch.Tensor): | ||||
| 
 | ||||
|  | @ -41,7 +34,6 @@ class mean_absolute_error(nn.Module): | |||
| 
 | ||||
|         self.loss_fun = nn.L1Loss(reduction=reduction) | ||||
|         self.higher_better = False | ||||
|         self.name = "mae" | ||||
| 
 | ||||
|     def forward(self, prediction: torch.Tensor, target: torch.Tensor): | ||||
| 
 | ||||
|  | @ -54,22 +46,22 @@ class mean_absolute_error(nn.Module): | |||
|         return self.loss_fun(prediction, target) | ||||
| 
 | ||||
| 
 | ||||
| class Si_SDR: | ||||
| class Si_SDR(nn.Module): | ||||
|     """ | ||||
|     SI-SDR metric based on SDR – HALF-BAKED OR WELL DONE?(https://arxiv.org/pdf/1811.02508.pdf) | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, reduction: str = "mean"): | ||||
|         super().__init__() | ||||
|         if reduction in ["sum", "mean", None]: | ||||
|             self.reduction = reduction | ||||
|         else: | ||||
|             raise TypeError( | ||||
|                 "Invalid reduction, valid options are sum, mean, None" | ||||
|             ) | ||||
|         self.higher_better = True | ||||
|         self.name = "si-sdr" | ||||
|         self.higher_better = False | ||||
| 
 | ||||
|     def __call__(self, prediction: torch.Tensor, target: torch.Tensor): | ||||
|     def forward(self, prediction: torch.Tensor, target: torch.Tensor): | ||||
| 
 | ||||
|         if prediction.size() != target.size() or target.ndim < 3: | ||||
|             raise TypeError( | ||||
|  | @ -98,47 +90,7 @@ class Si_SDR: | |||
|         return si_sdr | ||||
| 
 | ||||
| 
 | ||||
| class Stoi: | ||||
|     """ | ||||
|     STOI (Short-Time Objective Intelligibility, see [2,3]), a wrapper for the pystoi package [1]. | ||||
|     Note that input will be moved to cpu to perform the metric calculation. | ||||
|     parameters: | ||||
|         sr: int | ||||
|             sampling rate | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, sr: int): | ||||
|         self.sr = sr | ||||
|         self.stoi = ShortTimeObjectiveIntelligibility(fs=sr) | ||||
|         self.name = "stoi" | ||||
| 
 | ||||
|     def __call__(self, prediction: torch.Tensor, target: torch.Tensor): | ||||
| 
 | ||||
|         return self.stoi(prediction, target) | ||||
| 
 | ||||
| 
 | ||||
| class Pesq: | ||||
|     def __init__(self, sr: int, mode="wb"): | ||||
| 
 | ||||
|         self.sr = sr | ||||
|         self.name = "pesq" | ||||
|         self.mode = mode | ||||
|         self.pesq = PerceptualEvaluationSpeechQuality( | ||||
|             fs=self.sr, mode=self.mode | ||||
|         ) | ||||
| 
 | ||||
|     def __call__(self, prediction: torch.Tensor, target: torch.Tensor): | ||||
| 
 | ||||
|         pesq_values = [] | ||||
|         for pred, target_ in zip(prediction, target): | ||||
|             try: | ||||
|                 pesq_values.append(self.pesq(pred.squeeze(), target_.squeeze())) | ||||
|             except Exception as e: | ||||
|                 warnings.warn(f"{e} error occured while calculating PESQ") | ||||
|         return torch.tensor(np.mean(pesq_values)) | ||||
| 
 | ||||
| 
 | ||||
| class LossWrapper(nn.Module): | ||||
| class Avergeloss(nn.Module): | ||||
|     """ | ||||
|     Combine multiple metics of same nature. | ||||
|     for example, ["mea","mae"] | ||||
|  | @ -160,11 +112,9 @@ class LossWrapper(nn.Module): | |||
|             ) | ||||
| 
 | ||||
|         self.higher_better = direction[0] | ||||
|         self.name = "" | ||||
|         for loss in losses: | ||||
|             loss = self.validate_loss(loss) | ||||
|             self.valid_losses.append(loss()) | ||||
|             self.name += f"{loss().name}_" | ||||
| 
 | ||||
|     def validate_loss(self, loss: str): | ||||
|         if loss not in LOSS_MAP.keys(): | ||||
|  | @ -183,34 +133,8 @@ class LossWrapper(nn.Module): | |||
|         return loss | ||||
| 
 | ||||
| 
 | ||||
| class Si_snr(nn.Module): | ||||
|     """ | ||||
|     SI-SNR | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, **kwargs): | ||||
|         super().__init__() | ||||
| 
 | ||||
|         self.loss_fun = ScaleInvariantSignalNoiseRatio(**kwargs) | ||||
|         self.higher_better = False | ||||
|         self.name = "si_snr" | ||||
| 
 | ||||
|     def forward(self, prediction: torch.Tensor, target: torch.Tensor): | ||||
| 
 | ||||
|         if prediction.size() != target.size() or target.ndim < 3: | ||||
|             raise TypeError( | ||||
|                 f"""Inputs must be of the same shape (batch_size,channels,samples) | ||||
|                     got {prediction.size()} and {target.size()} instead""" | ||||
|             ) | ||||
| 
 | ||||
|         return -1 * self.loss_fun(prediction, target) | ||||
| 
 | ||||
| 
 | ||||
| LOSS_MAP = { | ||||
|     "mae": mean_absolute_error, | ||||
|     "mse": mean_squared_error, | ||||
|     "si-sdr": Si_SDR, | ||||
|     "pesq": Pesq, | ||||
|     "stoi": Stoi, | ||||
|     "si-snr": Si_snr, | ||||
|     "SI-SDR": Si_SDR, | ||||
| } | ||||
|  | @ -0,0 +1,3 @@ | |||
| from enhancer.models.demucs import Demucs | ||||
| from enhancer.models.model import Model | ||||
| from enhancer.models.waveunet import WaveUnet | ||||
|  | @ -1,14 +1,14 @@ | |||
| import logging | ||||
| import math | ||||
| import warnings | ||||
| from typing import List, Optional, Union | ||||
| 
 | ||||
| import torch.nn.functional as F | ||||
| from torch import nn | ||||
| 
 | ||||
| from mayavoz.data.dataset import MayaDataset | ||||
| from mayavoz.models.model import Mayamodel | ||||
| from mayavoz.utils.io import Audio as audio | ||||
| from mayavoz.utils.utils import merge_dict | ||||
| from enhancer.data.dataset import EnhancerDataset | ||||
| from enhancer.models.model import Model | ||||
| from enhancer.utils.io import Audio as audio | ||||
| from enhancer.utils.utils import merge_dict | ||||
| 
 | ||||
| 
 | ||||
| class DemucsLSTM(nn.Module): | ||||
|  | @ -49,7 +49,7 @@ class DemucsEncoder(nn.Module): | |||
|         self.encoder = nn.Sequential( | ||||
|             nn.Conv1d(num_channels, hidden_size, kernel_size, stride), | ||||
|             nn.ReLU(), | ||||
|             nn.Conv1d(hidden_size, hidden_size * multi_factor, 1, 1), | ||||
|             nn.Conv1d(hidden_size, hidden_size * multi_factor, kernel_size, 1), | ||||
|             activation, | ||||
|         ) | ||||
| 
 | ||||
|  | @ -72,7 +72,7 @@ class DemucsDecoder(nn.Module): | |||
|         activation = nn.GLU(1) if glu else nn.ReLU() | ||||
|         multi_factor = 2 if glu else 1 | ||||
|         self.decoder = nn.Sequential( | ||||
|             nn.Conv1d(hidden_size, hidden_size * multi_factor, 1, 1), | ||||
|             nn.Conv1d(hidden_size, hidden_size * multi_factor, kernel_size, 1), | ||||
|             activation, | ||||
|             nn.ConvTranspose1d(hidden_size, num_channels, kernel_size, stride), | ||||
|         ) | ||||
|  | @ -88,7 +88,7 @@ class DemucsDecoder(nn.Module): | |||
|         return out | ||||
| 
 | ||||
| 
 | ||||
| class Demucs(Mayamodel): | ||||
| class Demucs(Model): | ||||
|     """ | ||||
|     Demucs model from https://arxiv.org/pdf/1911.13254.pdf | ||||
|     parameters: | ||||
|  | @ -102,8 +102,8 @@ class Demucs(Mayamodel): | |||
|             sampling rate of input audio | ||||
|         lr : float, defaults to 1e-3 | ||||
|             learning rate used for training | ||||
|         dataset: MayaDataset, optional | ||||
|             MayaDataset object containing train/validation data for training | ||||
|         dataset: EnhancerDataset, optional | ||||
|             EnhancerDataset object containing train/validation data for training | ||||
|         duration : float, optional | ||||
|             chunk duration in seconds | ||||
|         loss : string or List of strings | ||||
|  | @ -116,7 +116,7 @@ class Demucs(Mayamodel): | |||
|     ED_DEFAULTS = { | ||||
|         "initial_output_channels": 48, | ||||
|         "kernel_size": 8, | ||||
|         "stride": 4, | ||||
|         "stride": 1, | ||||
|         "depth": 5, | ||||
|         "glu": True, | ||||
|         "growth_factor": 2, | ||||
|  | @ -133,20 +133,17 @@ class Demucs(Mayamodel): | |||
|         num_channels: int = 1, | ||||
|         resample: int = 4, | ||||
|         sampling_rate=16000, | ||||
|         normalize=True, | ||||
|         lr: float = 1e-3, | ||||
|         dataset: Optional[MayaDataset] = None, | ||||
|         duration: Optional[float] = None, | ||||
|         dataset: Optional[EnhancerDataset] = None, | ||||
|         loss: Union[str, List] = "mse", | ||||
|         metric: Union[str, List] = "mse", | ||||
|         floor=1e-3, | ||||
|     ): | ||||
|         duration = ( | ||||
|             dataset.duration if isinstance(dataset, MayaDataset) else duration | ||||
|             dataset.duration if isinstance(dataset, EnhancerDataset) else None | ||||
|         ) | ||||
|         if dataset is not None: | ||||
|             if sampling_rate != dataset.sampling_rate: | ||||
|                 warnings.warn( | ||||
|                 logging.warn( | ||||
|                     f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}" | ||||
|                 ) | ||||
|                 sampling_rate = dataset.sampling_rate | ||||
|  | @ -164,8 +161,6 @@ class Demucs(Mayamodel): | |||
|         lstm = merge_dict(self.LSTM_DEFAULTS, lstm) | ||||
|         self.save_hyperparameters("encoder_decoder", "lstm", "resample") | ||||
|         hidden = encoder_decoder["initial_output_channels"] | ||||
|         self.normalize = normalize | ||||
|         self.floor = floor | ||||
|         self.encoder = nn.ModuleList() | ||||
|         self.decoder = nn.ModuleList() | ||||
| 
 | ||||
|  | @ -184,7 +179,7 @@ class Demucs(Mayamodel): | |||
|                 num_channels=num_channels, | ||||
|                 hidden_size=hidden, | ||||
|                 kernel_size=encoder_decoder["kernel_size"], | ||||
|                 stride=encoder_decoder["stride"], | ||||
|                 stride=1, | ||||
|                 glu=encoder_decoder["glu"], | ||||
|                 layer=layer, | ||||
|             ) | ||||
|  | @ -205,16 +200,11 @@ class Demucs(Mayamodel): | |||
|         if waveform.dim() == 2: | ||||
|             waveform = waveform.unsqueeze(1) | ||||
| 
 | ||||
|         if waveform.size(1) != self.hparams.num_channels: | ||||
|             raise ValueError( | ||||
|                 f"Number of input channels initialized is {self.hparams.num_channels} but got {waveform.size(1)} channels" | ||||
|         if waveform.size(1) != 1: | ||||
|             raise TypeError( | ||||
|                 f"Demucs can only process mono channel audio, input has {waveform.size(1)} channels" | ||||
|             ) | ||||
|         if self.normalize: | ||||
|             waveform = waveform.mean(dim=1, keepdim=True) | ||||
|             std = waveform.std(dim=-1, keepdim=True) | ||||
|             waveform = waveform / (self.floor + std) | ||||
|         else: | ||||
|             std = 1 | ||||
| 
 | ||||
|         length = waveform.shape[-1] | ||||
|         x = F.pad(waveform, (0, self.get_padding_length(length) - length)) | ||||
|         if self.hparams.resample > 1: | ||||
|  | @ -236,7 +226,7 @@ class Demucs(Mayamodel): | |||
|         x = x.permute(0, 2, 1) | ||||
|         for decoder in self.decoder: | ||||
|             skip_connection = encoder_outputs.pop(-1) | ||||
|             x = x + skip_connection[..., : x.shape[-1]] | ||||
|             x += skip_connection[..., : x.shape[-1]] | ||||
|             x = decoder(x) | ||||
| 
 | ||||
|         if self.hparams.resample > 1: | ||||
|  | @ -246,8 +236,7 @@ class Demucs(Mayamodel): | |||
|                 self.hparams.sampling_rate, | ||||
|             ) | ||||
| 
 | ||||
|         out = x[..., :length] | ||||
|         return std * out | ||||
|         return x | ||||
| 
 | ||||
|     def get_padding_length(self, input_length): | ||||
| 
 | ||||
|  | @ -1,8 +1,7 @@ | |||
| import os | ||||
| from collections import defaultdict | ||||
| from importlib import import_module | ||||
| from pathlib import Path | ||||
| from typing import Any, List, Optional, Text, Union | ||||
| from typing import Any, Dict, List, Optional, Text, Union | ||||
| from urllib.parse import urlparse | ||||
| 
 | ||||
| import numpy as np | ||||
|  | @ -10,24 +9,19 @@ import pytorch_lightning as pl | |||
| import torch | ||||
| from huggingface_hub import cached_download, hf_hub_url | ||||
| from pytorch_lightning.utilities.cloud_io import load as pl_load | ||||
| from torch import nn | ||||
| from torch.optim import Adam | ||||
| 
 | ||||
| from mayavoz.data.dataset import MayaDataset | ||||
| from mayavoz.inference import Inference | ||||
| from mayavoz.loss import LOSS_MAP, LossWrapper | ||||
| from mayavoz.version import __version__ | ||||
| from enhancer import __version__ | ||||
| from enhancer.data.dataset import EnhancerDataset | ||||
| from enhancer.inference import Inference | ||||
| from enhancer.loss import Avergeloss | ||||
| 
 | ||||
| CACHE_DIR = os.getenv( | ||||
|     "ENHANCER_CACHE", | ||||
|     os.path.expanduser("~/.cache/torch/mayavoz"), | ||||
| ) | ||||
| HF_TORCH_WEIGHTS = "pytorch_model.ckpt" | ||||
| CACHE_DIR = "" | ||||
| HF_TORCH_WEIGHTS = "" | ||||
| DEFAULT_DEVICE = "cpu" | ||||
| SAVE_NAME = "mayavoz" | ||||
| 
 | ||||
| 
 | ||||
| class Mayamodel(pl.LightningModule): | ||||
| class Model(pl.LightningModule): | ||||
|     """ | ||||
|     Base class for all models | ||||
|     parameters: | ||||
|  | @ -37,11 +31,11 @@ class Mayamodel(pl.LightningModule): | |||
|             audio sampling rate | ||||
|         lr: float, optional | ||||
|             learning rate for model training | ||||
|         dataset: MayaDataset, optional | ||||
|             mayavoz dataset used for training/validation | ||||
|         dataset: EnhancerDataset, optional | ||||
|             Enhancer dataset used for training/validation | ||||
|         duration: float, optional | ||||
|             duration used for training/inference | ||||
|         loss : string or List of strings or custom loss (nn.Module), default to "mse" | ||||
|         loss : string or List of strings, default to "mse" | ||||
|             loss functions to be used. Available ("mse","mae","Si-SDR") | ||||
| 
 | ||||
|     """ | ||||
|  | @ -51,13 +45,15 @@ class Mayamodel(pl.LightningModule): | |||
|         num_channels: int = 1, | ||||
|         sampling_rate: int = 16000, | ||||
|         lr: float = 1e-3, | ||||
|         dataset: Optional[MayaDataset] = None, | ||||
|         dataset: Optional[EnhancerDataset] = None, | ||||
|         duration: Optional[float] = None, | ||||
|         loss: Union[str, List] = "mse", | ||||
|         metric: Union[str, List, Any] = "mse", | ||||
|         metric: Union[str, List] = "mse", | ||||
|     ): | ||||
|         super().__init__() | ||||
|         assert num_channels == 1, "mayavoz only support for mono channel models" | ||||
|         assert ( | ||||
|             num_channels == 1 | ||||
|         ), "Enhancer only support for mono channel models" | ||||
|         self.dataset = dataset | ||||
|         self.save_hyperparameters( | ||||
|             "num_channels", "sampling_rate", "lr", "loss", "metric", "duration" | ||||
|  | @ -78,9 +74,9 @@ class Mayamodel(pl.LightningModule): | |||
|     def loss(self, loss): | ||||
| 
 | ||||
|         if isinstance(loss, str): | ||||
|             loss = [loss] | ||||
|             losses = [loss] | ||||
| 
 | ||||
|         self._loss = LossWrapper(loss) | ||||
|         self._loss = Avergeloss(losses) | ||||
| 
 | ||||
|     @property | ||||
|     def metric(self): | ||||
|  | @ -88,26 +84,11 @@ class Mayamodel(pl.LightningModule): | |||
| 
 | ||||
|     @metric.setter | ||||
|     def metric(self, metric): | ||||
|         self._metric = [] | ||||
|         if isinstance(metric, (str, nn.Module)): | ||||
| 
 | ||||
|         if isinstance(metric, str): | ||||
|             metric = [metric] | ||||
| 
 | ||||
|         for func in metric: | ||||
|             if isinstance(func, str): | ||||
|                 if func in LOSS_MAP.keys(): | ||||
|                     if func in ("pesq", "stoi"): | ||||
|                         self._metric.append( | ||||
|                             LOSS_MAP[func](self.hparams.sampling_rate) | ||||
|                         ) | ||||
|                     else: | ||||
|                         self._metric.append(LOSS_MAP[func]()) | ||||
|                 else: | ||||
|                     ValueError(f"Invalid metrics {func}") | ||||
| 
 | ||||
|             elif isinstance(func, nn.Module): | ||||
|                 self._metric.append(func) | ||||
|             else: | ||||
|                 raise ValueError("Invalid metrics") | ||||
|         self._metric = Avergeloss(metric) | ||||
| 
 | ||||
|     @property | ||||
|     def dataset(self): | ||||
|  | @ -119,41 +100,15 @@ class Mayamodel(pl.LightningModule): | |||
| 
 | ||||
|     def setup(self, stage: Optional[str] = None): | ||||
|         if stage == "fit": | ||||
|             torch.cuda.empty_cache() | ||||
|             self.dataset.setup(stage) | ||||
|             self.dataset.model = self | ||||
| 
 | ||||
|             print( | ||||
|                 "Total train duration", | ||||
|                 self.dataset.train_dataloader().dataset.__len__() | ||||
|                 * self.dataset.duration | ||||
|                 / 60, | ||||
|                 "minutes", | ||||
|             ) | ||||
|             print( | ||||
|                 "Total validation duration", | ||||
|                 self.dataset.val_dataloader().dataset.__len__() | ||||
|                 * self.dataset.duration | ||||
|                 / 60, | ||||
|                 "minutes", | ||||
|             ) | ||||
|             print( | ||||
|                 "Total test duration", | ||||
|                 self.dataset.test_dataloader().dataset.__len__() | ||||
|                 * self.dataset.duration | ||||
|                 / 60, | ||||
|                 "minutes", | ||||
|             ) | ||||
| 
 | ||||
|     def train_dataloader(self): | ||||
|         return self.dataset.train_dataloader() | ||||
| 
 | ||||
|     def val_dataloader(self): | ||||
|         return self.dataset.val_dataloader() | ||||
| 
 | ||||
|     def test_dataloader(self): | ||||
|         return self.dataset.test_dataloader() | ||||
| 
 | ||||
|     def configure_optimizers(self): | ||||
|         return Adam(self.parameters(), lr=self.hparams.lr) | ||||
| 
 | ||||
|  | @ -162,86 +117,59 @@ class Mayamodel(pl.LightningModule): | |||
|         mixed_waveform = batch["noisy"] | ||||
|         target = batch["clean"] | ||||
|         prediction = self(mixed_waveform) | ||||
| 
 | ||||
|         loss = self.loss(prediction, target) | ||||
| 
 | ||||
|         self.log( | ||||
|             "train_loss", | ||||
|             loss.item(), | ||||
|             on_epoch=True, | ||||
|             on_step=True, | ||||
|             logger=True, | ||||
|             prog_bar=True, | ||||
|         if self.logger: | ||||
|             self.logger.experiment.log_metric( | ||||
|                 run_id=self.logger.run_id, | ||||
|                 key="train_loss", | ||||
|                 value=loss.item(), | ||||
|                 step=self.global_step, | ||||
|             ) | ||||
| 
 | ||||
|         self.log("train_loss", loss.item()) | ||||
|         return {"loss": loss} | ||||
| 
 | ||||
|     def validation_step(self, batch, batch_idx: int): | ||||
| 
 | ||||
|         metric_dict = {} | ||||
|         mixed_waveform = batch["noisy"] | ||||
|         target = batch["clean"] | ||||
|         prediction = self(mixed_waveform) | ||||
| 
 | ||||
|         metric_dict["valid_loss"] = self.loss(target, prediction).item() | ||||
|         for metric in self.metric: | ||||
|             value = metric(target, prediction) | ||||
|             metric_dict[f"valid_{metric.name}"] = value.item() | ||||
|         metric_val = self.metric(prediction, target) | ||||
|         loss_val = self.loss(prediction, target) | ||||
|         self.log("val_metric", metric_val.item()) | ||||
|         self.log("val_loss", loss_val.item()) | ||||
| 
 | ||||
|         self.log_dict( | ||||
|             metric_dict, | ||||
|             on_step=True, | ||||
|             on_epoch=True, | ||||
|             prog_bar=True, | ||||
|             logger=True, | ||||
|         if self.logger: | ||||
|             self.logger.experiment.log_metric( | ||||
|                 run_id=self.logger.run_id, | ||||
|                 key="val_loss", | ||||
|                 value=loss_val.item(), | ||||
|                 step=self.global_step, | ||||
|             ) | ||||
|             self.logger.experiment.log_metric( | ||||
|                 run_id=self.logger.run_id, | ||||
|                 key="val_metric", | ||||
|                 value=metric_val.item(), | ||||
|                 step=self.global_step, | ||||
|             ) | ||||
| 
 | ||||
|         return metric_dict | ||||
| 
 | ||||
|     def test_step(self, batch, batch_idx): | ||||
| 
 | ||||
|         metric_dict = {} | ||||
|         mixed_waveform = batch["noisy"] | ||||
|         target = batch["clean"] | ||||
|         prediction = self(mixed_waveform) | ||||
| 
 | ||||
|         for metric in self.metric: | ||||
|             value = metric(target, prediction) | ||||
|             metric_dict[f"test_{metric.name}"] = value | ||||
| 
 | ||||
|         self.log_dict( | ||||
|             metric_dict, | ||||
|             on_step=True, | ||||
|             on_epoch=True, | ||||
|             prog_bar=True, | ||||
|             logger=True, | ||||
|         ) | ||||
| 
 | ||||
|         return metric_dict | ||||
| 
 | ||||
|     def test_epoch_end(self, outputs): | ||||
| 
 | ||||
|         test_mean_metrics = defaultdict(int) | ||||
|         for output in outputs: | ||||
|             for metric, value in output.items(): | ||||
|                 test_mean_metrics[metric] += value.item() | ||||
|         for metric in test_mean_metrics.keys(): | ||||
|             test_mean_metrics[metric] /= len(outputs) | ||||
| 
 | ||||
|         print("----------TEST REPORT----------\n") | ||||
|         for metric in test_mean_metrics.keys(): | ||||
|             print(f"|{metric.upper()} | {test_mean_metrics[metric]} |") | ||||
|         print("--------------------------------") | ||||
|         return {"loss": loss_val} | ||||
| 
 | ||||
|     def on_save_checkpoint(self, checkpoint): | ||||
| 
 | ||||
|         checkpoint[SAVE_NAME] = { | ||||
|             "version": {SAVE_NAME: __version__, "pytorch": torch.__version__}, | ||||
|         checkpoint["enhancer"] = { | ||||
|             "version": {"enhancer": __version__, "pytorch": torch.__version__}, | ||||
|             "architecture": { | ||||
|                 "module": self.__class__.__module__, | ||||
|                 "class": self.__class__.__name__, | ||||
|             }, | ||||
|         } | ||||
| 
 | ||||
|     def on_load_checkpoint(self, checkpoint: Dict[str, Any]): | ||||
|         pass | ||||
| 
 | ||||
|     @classmethod | ||||
|     def from_pretrained( | ||||
|         cls, | ||||
|  | @ -281,15 +209,16 @@ class Mayamodel(pl.LightningModule): | |||
|             to True or to a string containing your hugginface.co authentication | ||||
|             token that can be obtained by running `huggingface-cli login` | ||||
|         cache_dir: Path or str, optional | ||||
|             Path to model cache directory | ||||
|             Path to model cache directory. Defaults to content of PYANNOTE_CACHE | ||||
|             environment variable, or "~/.cache/torch/pyannote" when unset. | ||||
|         kwargs: optional | ||||
|             Any extra keyword args needed to init the model. | ||||
|             Can also be used to override saved hyperparameter values. | ||||
| 
 | ||||
|         Returns | ||||
|         ------- | ||||
|         model : Mayamodel | ||||
|             Mayamodel | ||||
|         model : Model | ||||
|             Model | ||||
| 
 | ||||
|         See also | ||||
|         -------- | ||||
|  | @ -318,7 +247,7 @@ class Mayamodel(pl.LightningModule): | |||
|             ) | ||||
|             model_path_pl = cached_download( | ||||
|                 url=url, | ||||
|                 library_name="mayavoz", | ||||
|                 library_name="enhancer", | ||||
|                 library_version=__version__, | ||||
|                 cache_dir=cached_dir, | ||||
|                 use_auth_token=use_auth_token, | ||||
|  | @ -328,8 +257,8 @@ class Mayamodel(pl.LightningModule): | |||
|             map_location = torch.device(DEFAULT_DEVICE) | ||||
| 
 | ||||
|         loaded_checkpoint = pl_load(model_path_pl, map_location) | ||||
|         module_name = loaded_checkpoint[SAVE_NAME]["architecture"]["module"] | ||||
|         class_name = loaded_checkpoint[SAVE_NAME]["architecture"]["class"] | ||||
|         module_name = loaded_checkpoint["enhancer"]["architecture"]["module"] | ||||
|         class_name = loaded_checkpoint["enhancer"]["architecture"]["class"] | ||||
|         module = import_module(module_name) | ||||
|         Klass = getattr(module, class_name) | ||||
| 
 | ||||
|  | @ -361,9 +290,10 @@ class Mayamodel(pl.LightningModule): | |||
|         ), f"Expected batch with 3 dimensions (batch,channels,samples) got only {batch.ndim}" | ||||
|         batch_predictions = [] | ||||
|         self.eval().to(self.device) | ||||
| 
 | ||||
|         with torch.no_grad(): | ||||
|             for batch_id in range(0, batch.shape[0], batch_size): | ||||
|                 batch_data = batch[batch_id : (batch_id + batch_size), :, :].to( | ||||
|                 batch_data = batch[batch_id : batch_id + batch_size, :, :].to( | ||||
|                     self.device | ||||
|                 ) | ||||
|                 prediction = self(batch_data) | ||||
|  | @ -1,12 +1,12 @@ | |||
| import warnings | ||||
| import logging | ||||
| from typing import List, Optional, Union | ||||
| 
 | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| import torch.nn.functional as F | ||||
| 
 | ||||
| from mayavoz.data.dataset import MayaDataset | ||||
| from mayavoz.models.model import Mayamodel | ||||
| from enhancer.data.dataset import EnhancerDataset | ||||
| from enhancer.models.model import Model | ||||
| 
 | ||||
| 
 | ||||
| class WavenetDecoder(nn.Module): | ||||
|  | @ -66,7 +66,7 @@ class WavenetEncoder(nn.Module): | |||
|         return self.encoder(waveform) | ||||
| 
 | ||||
| 
 | ||||
| class WaveUnet(Mayamodel): | ||||
| class WaveUnet(Model): | ||||
|     """ | ||||
|     Wave-U-Net model from  https://arxiv.org/pdf/1811.11307.pdf | ||||
|     parameters: | ||||
|  | @ -80,8 +80,8 @@ class WaveUnet(Mayamodel): | |||
|             sampling rate of input audio | ||||
|         lr : float, defaults to 1e-3 | ||||
|             learning rate used for training | ||||
|         dataset: MayaDataset, optional | ||||
|             MayaDataset object containing train/validation data for training | ||||
|         dataset: EnhancerDataset, optional | ||||
|             EnhancerDataset object containing train/validation data for training | ||||
|         duration : float, optional | ||||
|             chunk duration in seconds | ||||
|         loss : string or List of strings | ||||
|  | @ -97,17 +97,17 @@ class WaveUnet(Mayamodel): | |||
|         initial_output_channels: int = 24, | ||||
|         sampling_rate: int = 16000, | ||||
|         lr: float = 1e-3, | ||||
|         dataset: Optional[MayaDataset] = None, | ||||
|         dataset: Optional[EnhancerDataset] = None, | ||||
|         duration: Optional[float] = None, | ||||
|         loss: Union[str, List] = "mse", | ||||
|         metric: Union[str, List] = "mse", | ||||
|     ): | ||||
|         duration = ( | ||||
|             dataset.duration if isinstance(dataset, MayaDataset) else duration | ||||
|             dataset.duration if isinstance(dataset, EnhancerDataset) else None | ||||
|         ) | ||||
|         if dataset is not None: | ||||
|             if sampling_rate != dataset.sampling_rate: | ||||
|                 warnings.warn( | ||||
|                 logging.warn( | ||||
|                     f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}" | ||||
|                 ) | ||||
|                 sampling_rate = dataset.sampling_rate | ||||
|  | @ -0,0 +1,3 @@ | |||
| from enhancer.utils.config import Files | ||||
| from enhancer.utils.io import Audio | ||||
| from enhancer.utils.utils import check_files | ||||
|  | @ -70,7 +70,7 @@ class Audio: | |||
| 
 | ||||
|         if sampling_rate: | ||||
|             audio = self.__class__.resample_audio( | ||||
|                 audio, sampling_rate, self.sampling_rate | ||||
|                 audio, self.sampling_rate, sampling_rate | ||||
|             ) | ||||
|         if self.return_tensor: | ||||
|             return torch.tensor(audio) | ||||
|  | @ -1,7 +1,7 @@ | |||
| import os | ||||
| from typing import Optional | ||||
| 
 | ||||
| from mayavoz.utils.config import Files | ||||
| from enhancer.utils.config import Files | ||||
| 
 | ||||
| 
 | ||||
| def check_files(root_dir: str, files: Files): | ||||
|  | @ -1,4 +1,4 @@ | |||
| name: mayavoz | ||||
| name: enhancer | ||||
| 
 | ||||
| dependencies: | ||||
|   - pip=21.0.1 | ||||
|  |  | |||
|  | @ -0,0 +1,39 @@ | |||
| #!/bin/bash | ||||
| set -e | ||||
| 
 | ||||
| echo '----------------------------------------------------' | ||||
| echo ' SLURM_CLUSTER_NAME = '$SLURM_CLUSTER_NAME | ||||
| echo '    SLURMD_NODENAME = '$SLURMD_NODENAME | ||||
| echo '        SLURM_JOBID = '$SLURM_JOBID | ||||
| echo '     SLURM_JOB_USER = '$SLURM_JOB_USER | ||||
| echo '    SLURM_PARTITION = '$SLURM_JOB_PARTITION | ||||
| echo '  SLURM_JOB_ACCOUNT = '$SLURM_JOB_ACCOUNT | ||||
| echo '----------------------------------------------------' | ||||
| 
 | ||||
| #TeamCity Output | ||||
| cat << EOF | ||||
| ##teamcity[buildNumber '$SLURM_JOBID'] | ||||
| EOF | ||||
| 
 | ||||
| echo "Load HPC modules" | ||||
| module load anaconda | ||||
| 
 | ||||
| echo "Activate Environment" | ||||
| source activate enhancer | ||||
| export TRANSFORMERS_OFFLINE=True | ||||
| export PYTHONPATH=${PYTHONPATH}:/scratch/c.sistc3/enhancer | ||||
| export HYDRA_FULL_ERROR=1 | ||||
| 
 | ||||
| echo $PYTHONPATH | ||||
| 
 | ||||
| source ~/mlflow_settings.sh | ||||
| 
 | ||||
| echo "Making temp dir" | ||||
| mkdir temp | ||||
| pwd | ||||
| 
 | ||||
| #python transcriber/tasks/embeddings/timit.py --directory /scratch/$USER/TIMIT/data/lisa/data/timit/raw/TIMIT/TRAIN --output ./data/train | ||||
| #python transcriber/tasks/embeddings/timit.py --directory /scratch/$USER/TIMIT/data/lisa/data/timit/raw/TIMIT/TEST --output ./data/test | ||||
| 
 | ||||
| echo "Start Training..." | ||||
| python cli/train.py | ||||
|  | @ -1,2 +0,0 @@ | |||
| __import__("pkg_resources").declare_namespace(__name__) | ||||
| from mayavoz.models import Mayamodel | ||||
|  | @ -1,120 +0,0 @@ | |||
| import os | ||||
| from types import MethodType | ||||
| 
 | ||||
| import hydra | ||||
| from hydra.utils import instantiate | ||||
| from omegaconf import DictConfig, OmegaConf | ||||
| from pytorch_lightning.callbacks import ( | ||||
|     EarlyStopping, | ||||
|     LearningRateMonitor, | ||||
|     ModelCheckpoint, | ||||
| ) | ||||
| from pytorch_lightning.loggers import MLFlowLogger | ||||
| from torch.optim.lr_scheduler import ReduceLROnPlateau | ||||
| 
 | ||||
| # from torch_audiomentations import Compose, Shift | ||||
| 
 | ||||
| os.environ["HYDRA_FULL_ERROR"] = "1" | ||||
| JOB_ID = os.environ.get("SLURM_JOBID", "0") | ||||
| 
 | ||||
| 
 | ||||
| @hydra.main(config_path="train_config", config_name="config") | ||||
| def train(config: DictConfig): | ||||
| 
 | ||||
|     OmegaConf.save(config, "config.yaml") | ||||
| 
 | ||||
|     callbacks = [] | ||||
|     logger = MLFlowLogger( | ||||
|         experiment_name=config.mlflow.experiment_name, | ||||
|         run_name=config.mlflow.run_name, | ||||
|         tags={"JOB_ID": JOB_ID}, | ||||
|     ) | ||||
| 
 | ||||
|     parameters = config.hyperparameters | ||||
|     # apply_augmentations = Compose( | ||||
|     #     [ | ||||
|     #         Shift(min_shift=0.5, max_shift=1.0, shift_unit="seconds", p=0.5), | ||||
|     #     ] | ||||
|     # ) | ||||
| 
 | ||||
|     dataset = instantiate(config.dataset, augmentations=None) | ||||
|     model = instantiate( | ||||
|         config.model, | ||||
|         dataset=dataset, | ||||
|         lr=parameters.get("lr"), | ||||
|         loss=parameters.get("loss"), | ||||
|         metric=parameters.get("metric"), | ||||
|     ) | ||||
| 
 | ||||
|     direction = model.valid_monitor | ||||
|     checkpoint = ModelCheckpoint( | ||||
|         dirpath="./model", | ||||
|         filename=f"model_{JOB_ID}", | ||||
|         monitor="valid_loss", | ||||
|         verbose=False, | ||||
|         mode=direction, | ||||
|         every_n_epochs=1, | ||||
|     ) | ||||
|     callbacks.append(checkpoint) | ||||
|     callbacks.append(LearningRateMonitor(logging_interval="epoch")) | ||||
| 
 | ||||
|     if parameters.get("Early_stop", False): | ||||
|         early_stopping = EarlyStopping( | ||||
|             monitor="val_loss", | ||||
|             mode=direction, | ||||
|             min_delta=0.0, | ||||
|             patience=parameters.get("EarlyStopping_patience", 10), | ||||
|             strict=True, | ||||
|             verbose=False, | ||||
|         ) | ||||
|         callbacks.append(early_stopping) | ||||
| 
 | ||||
|     def configure_optimizers(self): | ||||
|         optimizer = instantiate( | ||||
|             config.optimizer, | ||||
|             lr=parameters.get("lr"), | ||||
|             params=self.parameters(), | ||||
|         ) | ||||
|         scheduler = ReduceLROnPlateau( | ||||
|             optimizer=optimizer, | ||||
|             mode=direction, | ||||
|             factor=parameters.get("ReduceLr_factor", 0.1), | ||||
|             verbose=True, | ||||
|             min_lr=parameters.get("min_lr", 1e-6), | ||||
|             patience=parameters.get("ReduceLr_patience", 3), | ||||
|         ) | ||||
|         return { | ||||
|             "optimizer": optimizer, | ||||
|             "lr_scheduler": scheduler, | ||||
|             "monitor": f'valid_{parameters.get("ReduceLr_monitor", "loss")}', | ||||
|         } | ||||
| 
 | ||||
|     model.configure_optimizers = MethodType(configure_optimizers, model) | ||||
| 
 | ||||
|     trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks) | ||||
|     trainer.fit(model) | ||||
|     trainer.test(model) | ||||
| 
 | ||||
|     logger.experiment.log_artifact( | ||||
|         logger.run_id, f"{trainer.default_root_dir}/config.yaml" | ||||
|     ) | ||||
| 
 | ||||
|     saved_location = os.path.join( | ||||
|         trainer.default_root_dir, "model", f"model_{JOB_ID}.ckpt" | ||||
|     ) | ||||
|     if os.path.isfile(saved_location): | ||||
|         logger.experiment.log_artifact(logger.run_id, saved_location) | ||||
|         logger.experiment.log_param( | ||||
|             logger.run_id, | ||||
|             "num_train_steps_per_epoch", | ||||
|             dataset.train__len__() / dataset.batch_size, | ||||
|         ) | ||||
|         logger.experiment.log_param( | ||||
|             logger.run_id, | ||||
|             "num_valid_steps_per_epoch", | ||||
|             dataset.val__len__() / dataset.batch_size, | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == "__main__": | ||||
|     train() | ||||
|  | @ -1,7 +0,0 @@ | |||
| defaults: | ||||
|   - model : Demucs | ||||
|   - dataset : Vctk | ||||
|   - optimizer : Adam | ||||
|   - hyperparameters : default | ||||
|   - trainer : default | ||||
|   - mlflow : experiment | ||||
|  | @ -1,12 +0,0 @@ | |||
| _target_: mayavoz.data.dataset.MayaDataset | ||||
| name : MS-SDSD | ||||
| root_dir : /Users/shahules/Myprojects/MS-SNSD | ||||
| duration : 2.0 | ||||
| sampling_rate: 16000 | ||||
| batch_size: 32 | ||||
| min_valid_minutes: 15 | ||||
| files: | ||||
|   train_clean : CleanSpeech_training | ||||
|   test_clean : CleanSpeech_training | ||||
|   train_noisy : NoisySpeech_training | ||||
|   test_noisy : NoisySpeech_training | ||||
|  | @ -1,13 +0,0 @@ | |||
| _target_: mayavoz.data.dataset.MayaDataset | ||||
| name : Valentini | ||||
| root_dir : /scratch/c.sistc3/DS_10283_2791 | ||||
| duration : 4.5 | ||||
| stride : 2 | ||||
| sampling_rate: 16000 | ||||
| batch_size: 32 | ||||
| valid_minutes : 15 | ||||
| files: | ||||
|   train_clean : clean_trainset_28spk_wav | ||||
|   test_clean : clean_testset_wav | ||||
|   train_noisy : noisy_trainset_28spk_wav | ||||
|   test_noisy : noisy_testset_wav | ||||
|  | @ -1,7 +0,0 @@ | |||
| loss : mae | ||||
| metric : [stoi,pesq,si-sdr] | ||||
| lr : 0.0003 | ||||
| ReduceLr_patience : 5 | ||||
| ReduceLr_factor : 0.2 | ||||
| min_lr : 0.000001 | ||||
| EarlyStopping_factor : 10 | ||||
|  | @ -1,2 +0,0 @@ | |||
| experiment_name : shahules/mayavoz | ||||
| run_name : Demucs + Vtck with stride + augmentations | ||||
|  | @ -1,25 +0,0 @@ | |||
| _target_: mayavoz.models.dccrn.DCCRN | ||||
| num_channels: 1 | ||||
| sampling_rate : 16000 | ||||
| complex_lstm : True | ||||
| complex_norm : True | ||||
| complex_relu : True | ||||
| masking_mode : True | ||||
| 
 | ||||
| encoder_decoder: | ||||
|   initial_output_channels : 32 | ||||
|   depth : 6 | ||||
|   kernel_size : 5 | ||||
|   growth_factor : 2 | ||||
|   stride : 2 | ||||
|   padding : 2 | ||||
|   output_padding : 1 | ||||
| 
 | ||||
| lstm: | ||||
|   num_layers : 2 | ||||
|   hidden_size : 256 | ||||
| 
 | ||||
| stft: | ||||
|   window_len : 400 | ||||
|   hop_size : 100 | ||||
|   nfft : 512 | ||||
|  | @ -1,46 +0,0 @@ | |||
| _target_: pytorch_lightning.Trainer | ||||
| accelerator: gpu | ||||
| accumulate_grad_batches: 1 | ||||
| amp_backend: native | ||||
| auto_lr_find: True | ||||
| auto_scale_batch_size: False | ||||
| auto_select_gpus: True | ||||
| benchmark: False | ||||
| check_val_every_n_epoch: 1 | ||||
| detect_anomaly: False | ||||
| deterministic: False | ||||
| devices: 2 | ||||
| enable_checkpointing: True | ||||
| enable_model_summary: True | ||||
| enable_progress_bar: True | ||||
| fast_dev_run: False | ||||
| gpus: null | ||||
| gradient_clip_val: 0 | ||||
| gradient_clip_algorithm: norm | ||||
| ipus: null | ||||
| limit_predict_batches: 1.0 | ||||
| limit_test_batches: 1.0 | ||||
| limit_train_batches: 1.0 | ||||
| limit_val_batches: 1.0 | ||||
| log_every_n_steps: 50 | ||||
| max_epochs: 200 | ||||
| max_steps: -1 | ||||
| max_time: null | ||||
| min_epochs: 1 | ||||
| min_steps: null | ||||
| move_metrics_to_cpu: False | ||||
| multiple_trainloader_mode: max_size_cycle | ||||
| num_nodes: 1 | ||||
| num_processes: 1 | ||||
| num_sanity_val_steps: 2 | ||||
| overfit_batches: 0.0 | ||||
| precision: 32 | ||||
| profiler: null | ||||
| reload_dataloaders_every_n_epochs: 0 | ||||
| replace_sampler_ddp: True | ||||
| strategy: ddp | ||||
| sync_batchnorm: False | ||||
| tpu_cores: null | ||||
| track_grad_norm: -1 | ||||
| val_check_interval: 1.0 | ||||
| weights_save_path: null | ||||
|  | @ -1 +0,0 @@ | |||
| from mayavoz.data.dataset import MayaDataset | ||||
|  | @ -1,393 +0,0 @@ | |||
| import math | ||||
| import multiprocessing | ||||
| import os | ||||
| import sys | ||||
| import warnings | ||||
| from pathlib import Path | ||||
| from typing import Optional | ||||
| 
 | ||||
| import numpy as np | ||||
| import pytorch_lightning as pl | ||||
| import torch | ||||
| import torch.nn.functional as F | ||||
| from torch.utils.data import DataLoader, Dataset, RandomSampler | ||||
| from torch_audiomentations import Compose | ||||
| 
 | ||||
| from mayavoz.data.fileprocessor import Fileprocessor | ||||
| from mayavoz.utils import check_files | ||||
| from mayavoz.utils.config import Files | ||||
| from mayavoz.utils.io import Audio | ||||
| from mayavoz.utils.random import create_unique_rng | ||||
| 
 | ||||
| LARGE_NUM = 2147483647 | ||||
| 
 | ||||
| 
 | ||||
| class TrainDataset(Dataset): | ||||
|     def __init__(self, dataset): | ||||
|         self.dataset = dataset | ||||
| 
 | ||||
|     def __getitem__(self, idx): | ||||
|         return self.dataset.train__getitem__(idx) | ||||
| 
 | ||||
|     def __len__(self): | ||||
|         return self.dataset.train__len__() | ||||
| 
 | ||||
| 
 | ||||
| class ValidDataset(Dataset): | ||||
|     def __init__(self, dataset): | ||||
|         self.dataset = dataset | ||||
| 
 | ||||
|     def __getitem__(self, idx): | ||||
|         return self.dataset.val__getitem__(idx) | ||||
| 
 | ||||
|     def __len__(self): | ||||
|         return self.dataset.val__len__() | ||||
| 
 | ||||
| 
 | ||||
| class TestDataset(Dataset): | ||||
|     def __init__(self, dataset): | ||||
|         self.dataset = dataset | ||||
| 
 | ||||
|     def __getitem__(self, idx): | ||||
|         return self.dataset.test__getitem__(idx) | ||||
| 
 | ||||
|     def __len__(self): | ||||
|         return self.dataset.test__len__() | ||||
| 
 | ||||
| 
 | ||||
| class TaskDataset(pl.LightningDataModule): | ||||
|     def __init__( | ||||
|         self, | ||||
|         name: str, | ||||
|         root_dir: str, | ||||
|         files: Files, | ||||
|         min_valid_minutes: float = 0.20, | ||||
|         duration: float = 1.0, | ||||
|         stride=None, | ||||
|         sampling_rate: int = 48000, | ||||
|         matching_function=None, | ||||
|         batch_size=32, | ||||
|         num_workers: Optional[int] = None, | ||||
|         augmentations: Optional[Compose] = None, | ||||
|     ): | ||||
|         super().__init__() | ||||
| 
 | ||||
|         self.name = name | ||||
|         self.files, self.root_dir = check_files(root_dir, files) | ||||
|         self.duration = duration | ||||
|         self.stride = stride or duration | ||||
|         self.sampling_rate = sampling_rate | ||||
|         self.batch_size = batch_size | ||||
|         self.matching_function = matching_function | ||||
|         self._validation = [] | ||||
|         if num_workers is None: | ||||
|             num_workers = multiprocessing.cpu_count() // 2 | ||||
|         if num_workers is None: | ||||
|             num_workers = multiprocessing.cpu_count() // 2 | ||||
| 
 | ||||
|         if ( | ||||
|             num_workers > 0 | ||||
|             and sys.platform == "darwin" | ||||
|             and sys.version_info[0] >= 3 | ||||
|             and sys.version_info[1] >= 8 | ||||
|         ): | ||||
|             warnings.warn( | ||||
|                 "num_workers > 0 is not supported with macOS and Python 3.8+: " | ||||
|                 "setting num_workers = 0." | ||||
|             ) | ||||
|             num_workers = 0 | ||||
| 
 | ||||
|         self.num_workers = num_workers | ||||
|         if min_valid_minutes > 0.0: | ||||
|             self.min_valid_minutes = min_valid_minutes | ||||
|         else: | ||||
|             raise ValueError("min_valid_minutes must be greater than 0") | ||||
| 
 | ||||
|         self.augmentations = augmentations | ||||
| 
 | ||||
|     def setup(self, stage: Optional[str] = None): | ||||
|         """ | ||||
|         prepare train/validation/test data splits | ||||
|         """ | ||||
| 
 | ||||
|         if stage in ("fit", None): | ||||
| 
 | ||||
|             train_clean = os.path.join(self.root_dir, self.files.train_clean) | ||||
|             train_noisy = os.path.join(self.root_dir, self.files.train_noisy) | ||||
|             fp = Fileprocessor.from_name( | ||||
|                 self.name, train_clean, train_noisy, self.matching_function | ||||
|             ) | ||||
|             train_data = fp.prepare_matching_dict() | ||||
|             train_data, self.val_data = self.train_valid_split( | ||||
|                 train_data, | ||||
|                 min_valid_minutes=self.min_valid_minutes, | ||||
|                 random_state=42, | ||||
|             ) | ||||
| 
 | ||||
|             self.train_data = self.prepare_traindata(train_data) | ||||
|             self._validation = self.prepare_mapstype(self.val_data) | ||||
| 
 | ||||
|             test_clean = os.path.join(self.root_dir, self.files.test_clean) | ||||
|             test_noisy = os.path.join(self.root_dir, self.files.test_noisy) | ||||
|             fp = Fileprocessor.from_name( | ||||
|                 self.name, test_clean, test_noisy, self.matching_function | ||||
|             ) | ||||
|             test_data = fp.prepare_matching_dict() | ||||
|             self._test = self.prepare_mapstype(test_data) | ||||
| 
 | ||||
|     def train_valid_split( | ||||
|         self, data, min_valid_minutes: float = 20, random_state: int = 42 | ||||
|     ): | ||||
| 
 | ||||
|         min_valid_minutes *= 60 | ||||
|         valid_sec_now = 0.0 | ||||
|         valid_indices = [] | ||||
|         all_speakers = np.unique( | ||||
|             [Path(file["clean"]).name.split("_")[0] for file in data] | ||||
|         ) | ||||
|         possible_indices = list(range(0, len(all_speakers))) | ||||
|         rng = create_unique_rng(len(all_speakers)) | ||||
| 
 | ||||
|         while valid_sec_now <= min_valid_minutes: | ||||
|             speaker_index = rng.choice(possible_indices) | ||||
|             possible_indices.remove(speaker_index) | ||||
|             speaker_name = all_speakers[speaker_index] | ||||
|             print(f"Selected f{speaker_name} for valid") | ||||
|             file_indices = [ | ||||
|                 i | ||||
|                 for i, file in enumerate(data) | ||||
|                 if speaker_name == Path(file["clean"]).name.split("_")[0] | ||||
|             ] | ||||
|             for i in file_indices: | ||||
|                 valid_indices.append(i) | ||||
|                 valid_sec_now += data[i]["duration"] | ||||
| 
 | ||||
|         train_data = [ | ||||
|             item for i, item in enumerate(data) if i not in valid_indices | ||||
|         ] | ||||
|         valid_data = [item for i, item in enumerate(data) if i in valid_indices] | ||||
|         return train_data, valid_data | ||||
| 
 | ||||
|     def prepare_traindata(self, data): | ||||
|         train_data = [] | ||||
|         for item in data: | ||||
|             clean, noisy, total_dur = item.values() | ||||
|             num_segments = self.get_num_segments( | ||||
|                 total_dur, self.duration, self.stride | ||||
|             ) | ||||
|             samples_metadata = ({"clean": clean, "noisy": noisy}, num_segments) | ||||
|             train_data.append(samples_metadata) | ||||
|         return train_data | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def get_num_segments(file_duration, duration, stride): | ||||
| 
 | ||||
|         if file_duration < duration: | ||||
|             num_segments = 1 | ||||
|         else: | ||||
|             num_segments = math.ceil((file_duration - duration) / stride) + 1 | ||||
| 
 | ||||
|         return num_segments | ||||
| 
 | ||||
|     def prepare_mapstype(self, data): | ||||
| 
 | ||||
|         metadata = [] | ||||
|         for item in data: | ||||
|             clean, noisy, total_dur = item.values() | ||||
|             if total_dur < self.duration: | ||||
|                 metadata.append(({"clean": clean, "noisy": noisy}, 0.0)) | ||||
|             else: | ||||
|                 num_segments = self.get_num_segments( | ||||
|                     total_dur, self.duration, self.duration | ||||
|                 ) | ||||
|                 for index in range(num_segments): | ||||
|                     start_time = index * self.duration | ||||
|                     metadata.append( | ||||
|                         ({"clean": clean, "noisy": noisy}, start_time) | ||||
|                     ) | ||||
|         return metadata | ||||
| 
 | ||||
|     def train_collatefn(self, batch): | ||||
| 
 | ||||
|         output = {"clean": [], "noisy": []} | ||||
|         for item in batch: | ||||
|             output["clean"].append(item["clean"]) | ||||
|             output["noisy"].append(item["noisy"]) | ||||
| 
 | ||||
|         output["clean"] = torch.stack(output["clean"], dim=0) | ||||
|         output["noisy"] = torch.stack(output["noisy"], dim=0) | ||||
| 
 | ||||
|         if self.augmentations is not None: | ||||
|             noise = output["noisy"] - output["clean"] | ||||
|             output["clean"] = self.augmentations( | ||||
|                 output["clean"], sample_rate=self.sampling_rate | ||||
|             ) | ||||
|             self.augmentations.freeze_parameters() | ||||
|             output["noisy"] = ( | ||||
|                 self.augmentations(noise, sample_rate=self.sampling_rate) | ||||
|                 + output["clean"] | ||||
|             ) | ||||
| 
 | ||||
|         return output | ||||
| 
 | ||||
|     @property | ||||
|     def generator(self): | ||||
|         generator = torch.Generator() | ||||
|         if hasattr(self, "model"): | ||||
|             seed = self.model.current_epoch + LARGE_NUM | ||||
|         else: | ||||
|             seed = LARGE_NUM | ||||
|         return generator.manual_seed(seed) | ||||
| 
 | ||||
|     def train_dataloader(self): | ||||
|         dataset = TrainDataset(self) | ||||
|         sampler = RandomSampler(dataset, generator=self.generator) | ||||
|         return DataLoader( | ||||
|             dataset, | ||||
|             batch_size=self.batch_size, | ||||
|             num_workers=self.num_workers, | ||||
|             sampler=sampler, | ||||
|             collate_fn=self.train_collatefn, | ||||
|         ) | ||||
| 
 | ||||
|     def val_dataloader(self): | ||||
|         return DataLoader( | ||||
|             ValidDataset(self), | ||||
|             batch_size=self.batch_size, | ||||
|             num_workers=self.num_workers, | ||||
|         ) | ||||
| 
 | ||||
|     def test_dataloader(self): | ||||
|         return DataLoader( | ||||
|             TestDataset(self), | ||||
|             batch_size=self.batch_size, | ||||
|             num_workers=self.num_workers, | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| class MayaDataset(TaskDataset): | ||||
|     """ | ||||
|     Dataset object for creating clean-noisy speech enhancement datasets | ||||
|     paramters: | ||||
|     name : str | ||||
|         name of the dataset | ||||
|     root_dir : str | ||||
|         root directory of the dataset containing clean/noisy folders | ||||
|     files : Files | ||||
|         dataclass containing train_clean, train_noisy, test_clean, test_noisy | ||||
|         folder names (refer mayavoz.utils.Files dataclass) | ||||
|     min_valid_minutes: float | ||||
|         minimum validation split size time in minutes | ||||
|         algorithm randomly select n speakers (>=min_valid_minutes) from train data to form validation data. | ||||
|     duration : float | ||||
|         expected audio duration of single audio sample for training | ||||
|     sampling_rate : int | ||||
|         desired sampling rate | ||||
|     batch_size : int | ||||
|         batch size of each batch | ||||
|     num_workers : int | ||||
|         num workers to be used while training | ||||
|     matching_function : str | ||||
|         maching functions - (one_to_one,one_to_many). Default set to None. | ||||
|         use one_to_one mapping for datasets with one noisy file for each clean file | ||||
|         use one_to_many mapping for multiple noisy files for each clean file | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|     """ | ||||
| 
 | ||||
|     def __init__( | ||||
|         self, | ||||
|         name: str, | ||||
|         root_dir: str, | ||||
|         files: Files, | ||||
|         min_valid_minutes=5.0, | ||||
|         duration=1.0, | ||||
|         stride=None, | ||||
|         sampling_rate=48000, | ||||
|         matching_function=None, | ||||
|         batch_size=32, | ||||
|         num_workers: Optional[int] = None, | ||||
|         augmentations: Optional[Compose] = None, | ||||
|     ): | ||||
| 
 | ||||
|         super().__init__( | ||||
|             name=name, | ||||
|             root_dir=root_dir, | ||||
|             files=files, | ||||
|             min_valid_minutes=min_valid_minutes, | ||||
|             sampling_rate=sampling_rate, | ||||
|             duration=duration, | ||||
|             matching_function=matching_function, | ||||
|             batch_size=batch_size, | ||||
|             num_workers=num_workers, | ||||
|             augmentations=augmentations, | ||||
|         ) | ||||
| 
 | ||||
|         self.sampling_rate = sampling_rate | ||||
|         self.files = files | ||||
|         self.duration = max(1.0, duration) | ||||
|         self.audio = Audio(self.sampling_rate, mono=True, return_tensor=True) | ||||
|         self.stride = stride or duration | ||||
| 
 | ||||
|     def setup(self, stage: Optional[str] = None): | ||||
| 
 | ||||
|         super().setup(stage=stage) | ||||
| 
 | ||||
|     def train__getitem__(self, idx): | ||||
| 
 | ||||
|         for filedict, num_samples in self.train_data: | ||||
|             if idx >= num_samples: | ||||
|                 idx -= num_samples | ||||
|                 continue | ||||
|             else: | ||||
|                 start = 0 | ||||
|                 if self.duration is not None: | ||||
|                     start = idx * self.stride | ||||
|                 return self.prepare_segment(filedict, start) | ||||
| 
 | ||||
|     def val__getitem__(self, idx): | ||||
|         return self.prepare_segment(*self._validation[idx]) | ||||
| 
 | ||||
|     def test__getitem__(self, idx): | ||||
|         return self.prepare_segment(*self._test[idx]) | ||||
| 
 | ||||
|     def prepare_segment(self, file_dict: dict, start_time: float): | ||||
|         clean_segment = self.audio( | ||||
|             file_dict["clean"], offset=start_time, duration=self.duration | ||||
|         ) | ||||
|         noisy_segment = self.audio( | ||||
|             file_dict["noisy"], offset=start_time, duration=self.duration | ||||
|         ) | ||||
|         clean_segment = F.pad( | ||||
|             clean_segment, | ||||
|             ( | ||||
|                 0, | ||||
|                 int( | ||||
|                     self.duration * self.sampling_rate - clean_segment.shape[-1] | ||||
|                 ), | ||||
|             ), | ||||
|         ) | ||||
|         noisy_segment = F.pad( | ||||
|             noisy_segment, | ||||
|             ( | ||||
|                 0, | ||||
|                 int( | ||||
|                     self.duration * self.sampling_rate - noisy_segment.shape[-1] | ||||
|                 ), | ||||
|             ), | ||||
|         ) | ||||
|         return { | ||||
|             "clean": clean_segment, | ||||
|             "noisy": noisy_segment, | ||||
|         } | ||||
| 
 | ||||
|     def train__len__(self): | ||||
|         _, num_examples = list(zip(*self.train_data)) | ||||
|         return sum(num_examples) | ||||
| 
 | ||||
|     def val__len__(self): | ||||
|         return len(self._validation) | ||||
| 
 | ||||
|     def test__len__(self): | ||||
|         return len(self._test) | ||||
|  | @ -1,3 +0,0 @@ | |||
| from mayavoz.models.demucs import Demucs | ||||
| from mayavoz.models.model import Mayamodel | ||||
| from mayavoz.models.waveunet import WaveUnet | ||||
|  | @ -1,5 +0,0 @@ | |||
| from mayavoz.models.complexnn.conv import ComplexConv2d  # noqa | ||||
| from mayavoz.models.complexnn.conv import ComplexConvTranspose2d  # noqa | ||||
| from mayavoz.models.complexnn.rnn import ComplexLSTM  # noqa | ||||
| from mayavoz.models.complexnn.utils import ComplexBatchNorm2D  # noqa | ||||
| from mayavoz.models.complexnn.utils import ComplexRelu  # noqa | ||||
|  | @ -1,136 +0,0 @@ | |||
| from typing import Tuple | ||||
| 
 | ||||
| import torch | ||||
| import torch.nn.functional as F | ||||
| from torch import nn | ||||
| 
 | ||||
| 
 | ||||
| def init_weights(nnet): | ||||
|     nn.init.xavier_normal_(nnet.weight.data) | ||||
|     nn.init.constant_(nnet.bias, 0.0) | ||||
|     return nnet | ||||
| 
 | ||||
| 
 | ||||
| class ComplexConv2d(nn.Module): | ||||
|     def __init__( | ||||
|         self, | ||||
|         in_channels: int, | ||||
|         out_channels: int, | ||||
|         kernel_size: Tuple[int, int] = (1, 1), | ||||
|         stride: Tuple[int, int] = (1, 1), | ||||
|         padding: Tuple[int, int] = (0, 0), | ||||
|         groups: int = 1, | ||||
|         dilation: int = 1, | ||||
|     ): | ||||
|         """ | ||||
|         Complex Conv2d (non-causal) | ||||
|         """ | ||||
|         super().__init__() | ||||
|         self.in_channels = in_channels // 2 | ||||
|         self.out_channels = out_channels // 2 | ||||
|         self.kernel_size = kernel_size | ||||
|         self.stride = stride | ||||
|         self.padding = padding | ||||
|         self.groups = groups | ||||
|         self.dilation = dilation | ||||
| 
 | ||||
|         self.real_conv = nn.Conv2d( | ||||
|             self.in_channels, | ||||
|             self.out_channels, | ||||
|             kernel_size=self.kernel_size, | ||||
|             stride=self.stride, | ||||
|             padding=(self.padding[0], 0), | ||||
|             groups=self.groups, | ||||
|             dilation=self.dilation, | ||||
|         ) | ||||
|         self.imag_conv = nn.Conv2d( | ||||
|             self.in_channels, | ||||
|             self.out_channels, | ||||
|             kernel_size=self.kernel_size, | ||||
|             stride=self.stride, | ||||
|             padding=(self.padding[0], 0), | ||||
|             groups=self.groups, | ||||
|             dilation=self.dilation, | ||||
|         ) | ||||
|         self.imag_conv = init_weights(self.imag_conv) | ||||
|         self.real_conv = init_weights(self.real_conv) | ||||
| 
 | ||||
|     def forward(self, input): | ||||
|         """ | ||||
|         complex axis should be always 1 dim | ||||
|         """ | ||||
|         input = F.pad(input, [self.padding[1], 0, 0, 0]) | ||||
| 
 | ||||
|         real, imag = torch.chunk(input, 2, 1) | ||||
| 
 | ||||
|         real_real = self.real_conv(real) | ||||
|         real_imag = self.imag_conv(real) | ||||
| 
 | ||||
|         imag_imag = self.imag_conv(imag) | ||||
|         imag_real = self.real_conv(imag) | ||||
| 
 | ||||
|         real = real_real - imag_imag | ||||
|         imag = real_imag - imag_real | ||||
| 
 | ||||
|         out = torch.cat([real, imag], 1) | ||||
|         return out | ||||
| 
 | ||||
| 
 | ||||
| class ComplexConvTranspose2d(nn.Module): | ||||
|     def __init__( | ||||
|         self, | ||||
|         in_channels: int, | ||||
|         out_channels: int, | ||||
|         kernel_size: Tuple[int, int] = (1, 1), | ||||
|         stride: Tuple[int, int] = (1, 1), | ||||
|         padding: Tuple[int, int] = (0, 0), | ||||
|         output_padding: Tuple[int, int] = (0, 0), | ||||
|         groups: int = 1, | ||||
|     ): | ||||
|         super().__init__() | ||||
|         self.in_channels = in_channels // 2 | ||||
|         self.out_channels = out_channels // 2 | ||||
|         self.kernel_size = kernel_size | ||||
|         self.stride = stride | ||||
|         self.padding = padding | ||||
|         self.groups = groups | ||||
|         self.output_padding = output_padding | ||||
| 
 | ||||
|         self.real_conv = nn.ConvTranspose2d( | ||||
|             self.in_channels, | ||||
|             self.out_channels, | ||||
|             kernel_size=self.kernel_size, | ||||
|             stride=self.stride, | ||||
|             padding=self.padding, | ||||
|             output_padding=self.output_padding, | ||||
|             groups=self.groups, | ||||
|         ) | ||||
| 
 | ||||
|         self.imag_conv = nn.ConvTranspose2d( | ||||
|             self.in_channels, | ||||
|             self.out_channels, | ||||
|             kernel_size=self.kernel_size, | ||||
|             stride=self.stride, | ||||
|             padding=self.padding, | ||||
|             output_padding=self.output_padding, | ||||
|             groups=self.groups, | ||||
|         ) | ||||
| 
 | ||||
|         self.real_conv = init_weights(self.real_conv) | ||||
|         self.imag_conv = init_weights(self.imag_conv) | ||||
| 
 | ||||
|     def forward(self, input): | ||||
| 
 | ||||
|         real, imag = torch.chunk(input, 2, 1) | ||||
|         real_real = self.real_conv(real) | ||||
|         real_imag = self.imag_conv(real) | ||||
| 
 | ||||
|         imag_imag = self.imag_conv(imag) | ||||
|         imag_real = self.real_conv(imag) | ||||
| 
 | ||||
|         real = real_real - imag_imag | ||||
|         imag = real_imag + imag_real | ||||
| 
 | ||||
|         out = torch.cat([real, imag], 1) | ||||
| 
 | ||||
|         return out | ||||
|  | @ -1,68 +0,0 @@ | |||
| from typing import List, Optional | ||||
| 
 | ||||
| import torch | ||||
| from torch import nn | ||||
| 
 | ||||
| 
 | ||||
| class ComplexLSTM(nn.Module): | ||||
|     def __init__( | ||||
|         self, | ||||
|         input_size: int, | ||||
|         hidden_size: int, | ||||
|         num_layers: int = 1, | ||||
|         projection_size: Optional[int] = None, | ||||
|         bidirectional: bool = False, | ||||
|     ): | ||||
|         super().__init__() | ||||
|         self.input_size = input_size // 2 | ||||
|         self.hidden_size = hidden_size // 2 | ||||
|         self.num_layers = num_layers | ||||
| 
 | ||||
|         self.real_lstm = nn.LSTM( | ||||
|             self.input_size, | ||||
|             self.hidden_size, | ||||
|             self.num_layers, | ||||
|             bidirectional=bidirectional, | ||||
|             batch_first=False, | ||||
|         ) | ||||
|         self.imag_lstm = nn.LSTM( | ||||
|             self.input_size, | ||||
|             self.hidden_size, | ||||
|             self.num_layers, | ||||
|             bidirectional=bidirectional, | ||||
|             batch_first=False, | ||||
|         ) | ||||
| 
 | ||||
|         bidirectional = 2 if bidirectional else 1 | ||||
|         if projection_size is not None: | ||||
|             self.projection_size = projection_size // 2 | ||||
|             self.real_linear = nn.Linear( | ||||
|                 self.hidden_size * bidirectional, self.projection_size | ||||
|             ) | ||||
|             self.imag_linear = nn.Linear( | ||||
|                 self.hidden_size * bidirectional, self.projection_size | ||||
|             ) | ||||
|         else: | ||||
|             self.projection_size = None | ||||
| 
 | ||||
|     def forward(self, input): | ||||
| 
 | ||||
|         if isinstance(input, List): | ||||
|             real, imag = input | ||||
|         else: | ||||
|             real, imag = torch.chunk(input, 2, 1) | ||||
| 
 | ||||
|         real_real = self.real_lstm(real)[0] | ||||
|         real_imag = self.imag_lstm(real)[0] | ||||
| 
 | ||||
|         imag_imag = self.imag_lstm(imag)[0] | ||||
|         imag_real = self.real_lstm(imag)[0] | ||||
| 
 | ||||
|         real = real_real - imag_imag | ||||
|         imag = imag_real + real_imag | ||||
| 
 | ||||
|         if self.projection_size is not None: | ||||
|             real = self.real_linear(real) | ||||
|             imag = self.imag_linear(imag) | ||||
| 
 | ||||
|         return [real, imag] | ||||
|  | @ -1,199 +0,0 @@ | |||
| import torch | ||||
| from torch import nn | ||||
| 
 | ||||
| 
 | ||||
| class ComplexBatchNorm2D(nn.Module): | ||||
|     def __init__( | ||||
|         self, | ||||
|         num_features: int, | ||||
|         eps: float = 1e-5, | ||||
|         momentum: float = 0.1, | ||||
|         affine: bool = True, | ||||
|         track_running_stats: bool = True, | ||||
|     ): | ||||
|         """ | ||||
|         Complex batch normalization 2D | ||||
|         https://arxiv.org/abs/1705.09792 | ||||
| 
 | ||||
| 
 | ||||
|         """ | ||||
|         super().__init__() | ||||
|         self.num_features = num_features // 2 | ||||
|         self.affine = affine | ||||
|         self.momentum = momentum | ||||
|         self.track_running_stats = track_running_stats | ||||
|         self.eps = eps | ||||
| 
 | ||||
|         if self.affine: | ||||
|             self.Wrr = nn.parameter.Parameter(torch.Tensor(self.num_features)) | ||||
|             self.Wri = nn.parameter.Parameter(torch.Tensor(self.num_features)) | ||||
|             self.Wii = nn.parameter.Parameter(torch.Tensor(self.num_features)) | ||||
|             self.Br = nn.parameter.Parameter(torch.Tensor(self.num_features)) | ||||
|             self.Bi = nn.parameter.Parameter(torch.Tensor(self.num_features)) | ||||
|         else: | ||||
|             self.register_parameter("Wrr", None) | ||||
|             self.register_parameter("Wri", None) | ||||
|             self.register_parameter("Wii", None) | ||||
|             self.register_parameter("Br", None) | ||||
|             self.register_parameter("Bi", None) | ||||
| 
 | ||||
|         if self.track_running_stats: | ||||
|             values = torch.zeros(self.num_features) | ||||
|             self.register_buffer("Mean_real", values) | ||||
|             self.register_buffer("Mean_imag", values) | ||||
|             self.register_buffer("Var_rr", values) | ||||
|             self.register_buffer("Var_ri", values) | ||||
|             self.register_buffer("Var_ii", values) | ||||
|             self.register_buffer( | ||||
|                 "num_batches_tracked", torch.tensor(0, dtype=torch.long) | ||||
|             ) | ||||
|         else: | ||||
|             self.register_parameter("Mean_real", None) | ||||
|             self.register_parameter("Mean_imag", None) | ||||
|             self.register_parameter("Var_rr", None) | ||||
|             self.register_parameter("Var_ri", None) | ||||
|             self.register_parameter("Var_ii", None) | ||||
|             self.register_parameter("num_batches_tracked", None) | ||||
| 
 | ||||
|         self.reset_parameters() | ||||
| 
 | ||||
|     def reset_parameters(self): | ||||
|         if self.affine: | ||||
|             self.Wrr.data.fill_(1) | ||||
|             self.Wii.data.fill_(1) | ||||
|             self.Wri.data.uniform_(-0.9, 0.9) | ||||
|             self.Br.data.fill_(0) | ||||
|             self.Bi.data.fill_(0) | ||||
|         self.reset_running_stats() | ||||
| 
 | ||||
|     def reset_running_stats(self): | ||||
|         if self.track_running_stats: | ||||
|             self.Mean_real.zero_() | ||||
|             self.Mean_imag.zero_() | ||||
|             self.Var_rr.fill_(1) | ||||
|             self.Var_ri.zero_() | ||||
|             self.Var_ii.fill_(1) | ||||
|             self.num_batches_tracked.zero_() | ||||
| 
 | ||||
|     def extra_repr(self): | ||||
|         return "{num_features}, eps={eps}, momentum={momentum}, affine={affine}, track_running_stats={track_running_stats}".format( | ||||
|             **self.__dict__ | ||||
|         ) | ||||
| 
 | ||||
|     def forward(self, input): | ||||
| 
 | ||||
|         real, imag = torch.chunk(input, 2, 1) | ||||
|         exp_avg_factor = 0.0 | ||||
| 
 | ||||
|         training = self.training and self.track_running_stats | ||||
|         if training: | ||||
|             self.num_batches_tracked += 1 | ||||
|             if self.momentum is None: | ||||
|                 exp_avg_factor = 1 / self.num_batches_tracked | ||||
|             else: | ||||
|                 exp_avg_factor = self.momentum | ||||
| 
 | ||||
|         redux = [i for i in reversed(range(real.dim())) if i != 1] | ||||
|         vdim = [1] * real.dim() | ||||
|         vdim[1] = real.size(1) | ||||
| 
 | ||||
|         if training: | ||||
|             batch_mean_real, batch_mean_imag = real, imag | ||||
|             for dim in redux: | ||||
|                 batch_mean_real = batch_mean_real.mean(dim, keepdim=True) | ||||
|                 batch_mean_imag = batch_mean_imag.mean(dim, keepdim=True) | ||||
|             if self.track_running_stats: | ||||
|                 self.Mean_real.lerp_(batch_mean_real.squeeze(), exp_avg_factor) | ||||
|                 self.Mean_imag.lerp_(batch_mean_imag.squeeze(), exp_avg_factor) | ||||
| 
 | ||||
|         else: | ||||
|             batch_mean_real = self.Mean_real.view(vdim) | ||||
|             batch_mean_imag = self.Mean_imag.view(vdim) | ||||
| 
 | ||||
|         real = real - batch_mean_real | ||||
|         imag = imag - batch_mean_imag | ||||
| 
 | ||||
|         if training: | ||||
|             batch_var_rr = real * real | ||||
|             batch_var_ri = real * imag | ||||
|             batch_var_ii = imag * imag | ||||
|             for dim in redux: | ||||
|                 batch_var_rr = batch_var_rr.mean(dim, keepdim=True) | ||||
|                 batch_var_ri = batch_var_ri.mean(dim, keepdim=True) | ||||
|                 batch_var_ii = batch_var_ii.mean(dim, keepdim=True) | ||||
|             if self.track_running_stats: | ||||
|                 self.Var_rr.lerp_(batch_var_rr.squeeze(), exp_avg_factor) | ||||
|                 self.Var_ri.lerp_(batch_var_ri.squeeze(), exp_avg_factor) | ||||
|                 self.Var_ii.lerp_(batch_var_ii.squeeze(), exp_avg_factor) | ||||
|         else: | ||||
|             batch_var_rr = self.Var_rr.view(vdim) | ||||
|             batch_var_ii = self.Var_ii.view(vdim) | ||||
|             batch_var_ri = self.Var_ri.view(vdim) | ||||
| 
 | ||||
|         batch_var_rr += self.eps | ||||
|         batch_var_ii += self.eps | ||||
| 
 | ||||
|         # Covariance matrics | ||||
|         # | batch_var_rr    batch_var_ri | | ||||
|         # | batch_var_ir    batch_var_ii |  here batch_var_ir == batch_var_ri | ||||
|         # Inverse square root of cov matrix by combining below two formulas | ||||
|         # https://en.wikipedia.org/wiki/Square_root_of_a_2_by_2_matrix | ||||
|         # https://mathworld.wolfram.com/MatrixInverse.html | ||||
| 
 | ||||
|         tau = batch_var_rr + batch_var_ii | ||||
|         s = batch_var_rr * batch_var_ii - batch_var_ri * batch_var_ri | ||||
|         t = (tau + 2 * s).sqrt() | ||||
| 
 | ||||
|         rst = (s * t).reciprocal() | ||||
|         Urr = (batch_var_ii + s) * rst | ||||
|         Uri = -batch_var_ri * rst | ||||
|         Uii = (batch_var_rr + s) * rst | ||||
| 
 | ||||
|         if self.affine: | ||||
|             Wrr, Wri, Wii = ( | ||||
|                 self.Wrr.view(vdim), | ||||
|                 self.Wri.view(vdim), | ||||
|                 self.Wii.view(vdim), | ||||
|             ) | ||||
|             Zrr = (Wrr * Urr) + (Wri * Uri) | ||||
|             Zri = (Wrr * Uri) + (Wri * Uii) | ||||
|             Zir = (Wii * Uri) + (Wri * Urr) | ||||
|             Zii = (Wri * Uri) + (Wii * Uii) | ||||
|         else: | ||||
|             Zrr, Zri, Zir, Zii = Urr, Uri, Uri, Uii | ||||
| 
 | ||||
|         yr = (Zrr * real) + (Zri * imag) | ||||
|         yi = (Zir * real) + (Zii * imag) | ||||
| 
 | ||||
|         if self.affine: | ||||
|             yr = yr + self.Br.view(vdim) | ||||
|             yi = yi + self.Bi.view(vdim) | ||||
| 
 | ||||
|         outputs = torch.cat([yr, yi], 1) | ||||
|         return outputs | ||||
| 
 | ||||
| 
 | ||||
| class ComplexRelu(nn.Module): | ||||
|     def __init__(self): | ||||
|         super().__init__() | ||||
|         self.real_relu = nn.PReLU() | ||||
|         self.imag_relu = nn.PReLU() | ||||
| 
 | ||||
|     def forward(self, input): | ||||
| 
 | ||||
|         real, imag = torch.chunk(input, 2, 1) | ||||
|         real = self.real_relu(real) | ||||
|         imag = self.imag_relu(imag) | ||||
|         return torch.cat([real, imag], dim=1) | ||||
| 
 | ||||
| 
 | ||||
| def complex_cat(inputs, axis=1): | ||||
| 
 | ||||
|     real, imag = [], [] | ||||
|     for data in inputs: | ||||
|         real_data, imag_data = torch.chunk(data, 2, axis) | ||||
|         real.append(real_data) | ||||
|         imag.append(imag_data) | ||||
|     real = torch.cat(real, axis) | ||||
|     imag = torch.cat(imag, axis) | ||||
|     return torch.cat([real, imag], axis) | ||||
|  | @ -1,338 +0,0 @@ | |||
| import warnings | ||||
| from typing import Any, List, Optional, Tuple, Union | ||||
| 
 | ||||
| import torch | ||||
| import torch.nn.functional as F | ||||
| from torch import nn | ||||
| 
 | ||||
| from mayavoz.data import MayaDataset | ||||
| from mayavoz.models import Mayamodel | ||||
| from mayavoz.models.complexnn import ( | ||||
|     ComplexBatchNorm2D, | ||||
|     ComplexConv2d, | ||||
|     ComplexConvTranspose2d, | ||||
|     ComplexLSTM, | ||||
|     ComplexRelu, | ||||
| ) | ||||
| from mayavoz.models.complexnn.utils import complex_cat | ||||
| from mayavoz.utils.transforms import ConviSTFT, ConvSTFT | ||||
| from mayavoz.utils.utils import merge_dict | ||||
| 
 | ||||
| 
 | ||||
| class DCCRN_ENCODER(nn.Module): | ||||
|     def __init__( | ||||
|         self, | ||||
|         in_channels: int, | ||||
|         out_channel: int, | ||||
|         kernel_size: Tuple[int, int], | ||||
|         complex_norm: bool = True, | ||||
|         complex_relu: bool = True, | ||||
|         stride: Tuple[int, int] = (2, 1), | ||||
|         padding: Tuple[int, int] = (2, 1), | ||||
|     ): | ||||
|         super().__init__() | ||||
|         batchnorm = ComplexBatchNorm2D if complex_norm else nn.BatchNorm2d | ||||
|         activation = ComplexRelu() if complex_relu else nn.PReLU() | ||||
| 
 | ||||
|         self.encoder = nn.Sequential( | ||||
|             ComplexConv2d( | ||||
|                 in_channels, | ||||
|                 out_channel, | ||||
|                 kernel_size=kernel_size, | ||||
|                 stride=stride, | ||||
|                 padding=padding, | ||||
|             ), | ||||
|             batchnorm(out_channel), | ||||
|             activation, | ||||
|         ) | ||||
| 
 | ||||
|     def forward(self, waveform): | ||||
| 
 | ||||
|         return self.encoder(waveform) | ||||
| 
 | ||||
| 
 | ||||
| class DCCRN_DECODER(nn.Module): | ||||
|     def __init__( | ||||
|         self, | ||||
|         in_channels: int, | ||||
|         out_channels: int, | ||||
|         kernel_size: Tuple[int, int], | ||||
|         layer: int = 0, | ||||
|         complex_norm: bool = True, | ||||
|         complex_relu: bool = True, | ||||
|         stride: Tuple[int, int] = (2, 1), | ||||
|         padding: Tuple[int, int] = (2, 0), | ||||
|         output_padding: Tuple[int, int] = (1, 0), | ||||
|     ): | ||||
|         super().__init__() | ||||
|         batchnorm = ComplexBatchNorm2D if complex_norm else nn.BatchNorm2d | ||||
|         activation = ComplexRelu() if complex_relu else nn.PReLU() | ||||
| 
 | ||||
|         if layer != 0: | ||||
|             self.decoder = nn.Sequential( | ||||
|                 ComplexConvTranspose2d( | ||||
|                     in_channels, | ||||
|                     out_channels, | ||||
|                     kernel_size=kernel_size, | ||||
|                     stride=stride, | ||||
|                     padding=padding, | ||||
|                     output_padding=output_padding, | ||||
|                 ), | ||||
|                 batchnorm(out_channels), | ||||
|                 activation, | ||||
|             ) | ||||
|         else: | ||||
|             self.decoder = nn.Sequential( | ||||
|                 ComplexConvTranspose2d( | ||||
|                     in_channels, | ||||
|                     out_channels, | ||||
|                     kernel_size=kernel_size, | ||||
|                     stride=stride, | ||||
|                     padding=padding, | ||||
|                     output_padding=output_padding, | ||||
|                 ) | ||||
|             ) | ||||
| 
 | ||||
|     def forward(self, waveform): | ||||
| 
 | ||||
|         return self.decoder(waveform) | ||||
| 
 | ||||
| 
 | ||||
| class DCCRN(Mayamodel): | ||||
| 
 | ||||
|     STFT_DEFAULTS = { | ||||
|         "window_len": 400, | ||||
|         "hop_size": 100, | ||||
|         "nfft": 512, | ||||
|         "window": "hamming", | ||||
|     } | ||||
| 
 | ||||
|     ED_DEFAULTS = { | ||||
|         "initial_output_channels": 32, | ||||
|         "depth": 6, | ||||
|         "kernel_size": 5, | ||||
|         "growth_factor": 2, | ||||
|         "stride": 2, | ||||
|         "padding": 2, | ||||
|         "output_padding": 1, | ||||
|     } | ||||
| 
 | ||||
|     LSTM_DEFAULTS = { | ||||
|         "num_layers": 2, | ||||
|         "hidden_size": 256, | ||||
|     } | ||||
| 
 | ||||
|     def __init__( | ||||
|         self, | ||||
|         stft: Optional[dict] = None, | ||||
|         encoder_decoder: Optional[dict] = None, | ||||
|         lstm: Optional[dict] = None, | ||||
|         complex_lstm: bool = True, | ||||
|         complex_norm: bool = True, | ||||
|         complex_relu: bool = True, | ||||
|         masking_mode: str = "E", | ||||
|         num_channels: int = 1, | ||||
|         sampling_rate=16000, | ||||
|         lr: float = 1e-3, | ||||
|         dataset: Optional[MayaDataset] = None, | ||||
|         duration: Optional[float] = None, | ||||
|         loss: Union[str, List, Any] = "mse", | ||||
|         metric: Union[str, List] = "mse", | ||||
|     ): | ||||
|         duration = ( | ||||
|             dataset.duration if isinstance(dataset, MayaDataset) else duration | ||||
|         ) | ||||
|         if dataset is not None: | ||||
|             if sampling_rate != dataset.sampling_rate: | ||||
|                 warnings.warn( | ||||
|                     f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}" | ||||
|                 ) | ||||
|                 sampling_rate = dataset.sampling_rate | ||||
|         super().__init__( | ||||
|             num_channels=num_channels, | ||||
|             sampling_rate=sampling_rate, | ||||
|             lr=lr, | ||||
|             dataset=dataset, | ||||
|             duration=duration, | ||||
|             loss=loss, | ||||
|             metric=metric, | ||||
|         ) | ||||
| 
 | ||||
|         encoder_decoder = merge_dict(self.ED_DEFAULTS, encoder_decoder) | ||||
|         lstm = merge_dict(self.LSTM_DEFAULTS, lstm) | ||||
|         stft = merge_dict(self.STFT_DEFAULTS, stft) | ||||
|         self.save_hyperparameters( | ||||
|             "encoder_decoder", | ||||
|             "lstm", | ||||
|             "stft", | ||||
|             "complex_lstm", | ||||
|             "complex_norm", | ||||
|             "masking_mode", | ||||
|         ) | ||||
|         self.complex_lstm = complex_lstm | ||||
|         self.complex_norm = complex_norm | ||||
|         self.masking_mode = masking_mode | ||||
| 
 | ||||
|         self.stft = ConvSTFT( | ||||
|             stft["window_len"], stft["hop_size"], stft["nfft"], stft["window"] | ||||
|         ) | ||||
|         self.istft = ConviSTFT( | ||||
|             stft["window_len"], stft["hop_size"], stft["nfft"], stft["window"] | ||||
|         ) | ||||
| 
 | ||||
|         self.encoder = nn.ModuleList() | ||||
|         self.decoder = nn.ModuleList() | ||||
| 
 | ||||
|         num_channels *= 2 | ||||
|         hidden_size = encoder_decoder["initial_output_channels"] | ||||
|         growth_factor = 2 | ||||
| 
 | ||||
|         for layer in range(encoder_decoder["depth"]): | ||||
| 
 | ||||
|             encoder_ = DCCRN_ENCODER( | ||||
|                 num_channels, | ||||
|                 hidden_size, | ||||
|                 kernel_size=(encoder_decoder["kernel_size"], 2), | ||||
|                 stride=(encoder_decoder["stride"], 1), | ||||
|                 padding=(encoder_decoder["padding"], 1), | ||||
|                 complex_norm=complex_norm, | ||||
|                 complex_relu=complex_relu, | ||||
|             ) | ||||
|             self.encoder.append(encoder_) | ||||
| 
 | ||||
|             decoder_ = DCCRN_DECODER( | ||||
|                 hidden_size + hidden_size, | ||||
|                 num_channels, | ||||
|                 layer=layer, | ||||
|                 kernel_size=(encoder_decoder["kernel_size"], 2), | ||||
|                 stride=(encoder_decoder["stride"], 1), | ||||
|                 padding=(encoder_decoder["padding"], 0), | ||||
|                 output_padding=(encoder_decoder["output_padding"], 0), | ||||
|                 complex_norm=complex_norm, | ||||
|                 complex_relu=complex_relu, | ||||
|             ) | ||||
| 
 | ||||
|             self.decoder.insert(0, decoder_) | ||||
| 
 | ||||
|             if layer < encoder_decoder["depth"] - 3: | ||||
|                 num_channels = hidden_size | ||||
|                 hidden_size *= growth_factor | ||||
|             else: | ||||
|                 num_channels = hidden_size | ||||
| 
 | ||||
|         kernel_size = hidden_size / 2 | ||||
|         hidden_size = stft["nfft"] / 2 ** (encoder_decoder["depth"]) | ||||
| 
 | ||||
|         if self.complex_lstm: | ||||
|             lstms = [] | ||||
|             for layer in range(lstm["num_layers"]): | ||||
| 
 | ||||
|                 if layer == 0: | ||||
|                     input_size = int(hidden_size * kernel_size) | ||||
|                 else: | ||||
|                     input_size = lstm["hidden_size"] | ||||
| 
 | ||||
|                 if layer == lstm["num_layers"] - 1: | ||||
|                     projection_size = int(hidden_size * kernel_size) | ||||
|                 else: | ||||
|                     projection_size = None | ||||
| 
 | ||||
|                 kwargs = { | ||||
|                     "input_size": input_size, | ||||
|                     "hidden_size": lstm["hidden_size"], | ||||
|                     "num_layers": 1, | ||||
|                 } | ||||
| 
 | ||||
|                 lstms.append( | ||||
|                     ComplexLSTM(projection_size=projection_size, **kwargs) | ||||
|                 ) | ||||
|             self.lstm = nn.Sequential(*lstms) | ||||
|         else: | ||||
|             self.lstm = nn.Sequential( | ||||
|                 nn.LSTM( | ||||
|                     input_size=hidden_size * kernel_size, | ||||
|                     hidden_sizs=lstm["hidden_size"], | ||||
|                     num_layers=lstm["num_layers"], | ||||
|                     dropout=0.0, | ||||
|                     batch_first=False, | ||||
|                 )[0], | ||||
|                 nn.Linear(lstm["hidden"], hidden_size * kernel_size), | ||||
|             ) | ||||
| 
 | ||||
|     def forward(self, waveform): | ||||
| 
 | ||||
|         if waveform.dim() == 2: | ||||
|             waveform = waveform.unsqueeze(1) | ||||
| 
 | ||||
|         if waveform.size(1) != self.hparams.num_channels: | ||||
|             raise ValueError( | ||||
|                 f"Number of input channels initialized is {self.hparams.num_channels} but got {waveform.size(1)} channels" | ||||
|             ) | ||||
| 
 | ||||
|         waveform_stft = self.stft(waveform) | ||||
|         real = waveform_stft[:, : self.stft.nfft // 2 + 1] | ||||
|         imag = waveform_stft[:, self.stft.nfft // 2 + 1 :] | ||||
| 
 | ||||
|         mag_spec = torch.sqrt(real**2 + imag**2 + 1e-9) | ||||
|         phase_spec = torch.atan2(imag, real) | ||||
|         complex_spec = torch.stack([mag_spec, phase_spec], 1)[:, :, 1:] | ||||
| 
 | ||||
|         encoder_outputs = [] | ||||
|         out = complex_spec | ||||
|         for _, encoder in enumerate(self.encoder): | ||||
|             out = encoder(out) | ||||
|             encoder_outputs.append(out) | ||||
| 
 | ||||
|         B, C, D, T = out.size() | ||||
|         out = out.permute(3, 0, 1, 2) | ||||
|         if self.complex_lstm: | ||||
| 
 | ||||
|             lstm_real = out[:, :, : C // 2] | ||||
|             lstm_imag = out[:, :, C // 2 :] | ||||
|             lstm_real = lstm_real.reshape(T, B, C // 2 * D) | ||||
|             lstm_imag = lstm_imag.reshape(T, B, C // 2 * D) | ||||
|             lstm_real, lstm_imag = self.lstm([lstm_real, lstm_imag]) | ||||
|             lstm_real = lstm_real.reshape(T, B, C // 2, D) | ||||
|             lstm_imag = lstm_imag.reshape(T, B, C // 2, D) | ||||
|             out = torch.cat([lstm_real, lstm_imag], 2) | ||||
|         else: | ||||
|             out = out.reshape(T, B, C * D) | ||||
|             out = self.lstm(out) | ||||
|             out = out.reshape(T, B, D, C) | ||||
| 
 | ||||
|         out = out.permute(1, 2, 3, 0) | ||||
|         for layer, decoder in enumerate(self.decoder): | ||||
|             skip_connection = encoder_outputs.pop(-1) | ||||
|             out = complex_cat([skip_connection, out]) | ||||
|             out = decoder(out) | ||||
|             out = out[..., 1:] | ||||
|         mask_real, mask_imag = out[:, 0], out[:, 1] | ||||
|         mask_real = F.pad(mask_real, [0, 0, 1, 0]) | ||||
|         mask_imag = F.pad(mask_imag, [0, 0, 1, 0]) | ||||
|         if self.masking_mode == "E": | ||||
| 
 | ||||
|             mask_mag = torch.sqrt(mask_real**2 + mask_imag**2) | ||||
|             real_phase = mask_real / (mask_mag + 1e-8) | ||||
|             imag_phase = mask_imag / (mask_mag + 1e-8) | ||||
|             mask_phase = torch.atan2(imag_phase, real_phase) | ||||
|             mask_mag = torch.tanh(mask_mag) | ||||
|             est_mag = mask_mag * mag_spec | ||||
|             est_phase = mask_phase * phase_spec | ||||
|             # cos(theta) + isin(theta) | ||||
|             real = est_mag + torch.cos(est_phase) | ||||
|             imag = est_mag + torch.sin(est_phase) | ||||
| 
 | ||||
|         if self.masking_mode == "C": | ||||
| 
 | ||||
|             real = real * mask_real - imag * mask_imag | ||||
|             imag = real * mask_imag + imag * mask_real | ||||
| 
 | ||||
|         else: | ||||
| 
 | ||||
|             real = real * mask_real | ||||
|             imag = imag * mask_imag | ||||
| 
 | ||||
|         spec = torch.cat([real, imag], 1) | ||||
|         wav = self.istft(spec) | ||||
|         wav = wav.clamp_(-1, 1) | ||||
|         return wav | ||||
|  | @ -1,3 +0,0 @@ | |||
| from mayavoz.utils.config import Files | ||||
| from mayavoz.utils.io import Audio | ||||
| from mayavoz.utils.utils import check_files | ||||
|  | @ -1,93 +0,0 @@ | |||
| from typing import Optional | ||||
| 
 | ||||
| import numpy as np | ||||
| import torch | ||||
| import torch.nn.functional as F | ||||
| from scipy.signal import get_window | ||||
| from torch import nn | ||||
| 
 | ||||
| 
 | ||||
| class ConvFFT(nn.Module): | ||||
|     def __init__( | ||||
|         self, | ||||
|         window_len: int, | ||||
|         nfft: Optional[int] = None, | ||||
|         window: str = "hamming", | ||||
|     ): | ||||
|         super().__init__() | ||||
|         self.window_len = window_len | ||||
|         self.nfft = nfft if nfft else np.int(2 ** np.ceil(np.log2(window_len))) | ||||
|         self.window = torch.from_numpy( | ||||
|             get_window(window, window_len, fftbins=True).astype("float32") | ||||
|         ) | ||||
| 
 | ||||
|     def init_kernel(self, inverse=False): | ||||
| 
 | ||||
|         fourier_basis = np.fft.rfft(np.eye(self.nfft))[: self.window_len] | ||||
|         real, imag = np.real(fourier_basis), np.imag(fourier_basis) | ||||
|         kernel = np.concatenate([real, imag], 1).T | ||||
|         if inverse: | ||||
|             kernel = np.linalg.pinv(kernel).T | ||||
|         kernel = torch.from_numpy(kernel.astype("float32")).unsqueeze(1) | ||||
|         kernel *= self.window | ||||
|         return kernel | ||||
| 
 | ||||
| 
 | ||||
| class ConvSTFT(ConvFFT): | ||||
|     def __init__( | ||||
|         self, | ||||
|         window_len: int, | ||||
|         hop_size: Optional[int] = None, | ||||
|         nfft: Optional[int] = None, | ||||
|         window: str = "hamming", | ||||
|     ): | ||||
|         super().__init__(window_len=window_len, nfft=nfft, window=window) | ||||
|         self.hop_size = hop_size if hop_size else window_len // 2 | ||||
|         self.register_buffer("weight", self.init_kernel()) | ||||
| 
 | ||||
|     def forward(self, input): | ||||
| 
 | ||||
|         if input.dim() < 2: | ||||
|             raise ValueError( | ||||
|                 f"Expected signal with shape 2 or 3 got {input.dim()}" | ||||
|             ) | ||||
|         elif input.dim() == 2: | ||||
|             input = input.unsqueeze(1) | ||||
|         else: | ||||
|             pass | ||||
|         input = F.pad( | ||||
|             input, | ||||
|             (self.window_len - self.hop_size, self.window_len - self.hop_size), | ||||
|         ) | ||||
|         output = F.conv1d(input, self.weight, stride=self.hop_size) | ||||
| 
 | ||||
|         return output | ||||
| 
 | ||||
| 
 | ||||
| class ConviSTFT(ConvFFT): | ||||
|     def __init__( | ||||
|         self, | ||||
|         window_len: int, | ||||
|         hop_size: Optional[int] = None, | ||||
|         nfft: Optional[int] = None, | ||||
|         window: str = "hamming", | ||||
|     ): | ||||
|         super().__init__(window_len=window_len, nfft=nfft, window=window) | ||||
|         self.hop_size = hop_size if hop_size else window_len // 2 | ||||
|         self.register_buffer("weight", self.init_kernel(True)) | ||||
|         self.register_buffer("enframe", torch.eye(window_len).unsqueeze(1)) | ||||
| 
 | ||||
|     def forward(self, input, phase=None): | ||||
| 
 | ||||
|         if phase is not None: | ||||
|             real = input * torch.cos(phase) | ||||
|             imag = input * torch.sin(phase) | ||||
|             input = torch.cat([real, imag], 1) | ||||
|         out = F.conv_transpose1d(input, self.weight, stride=self.hop_size) | ||||
|         coeff = self.window.unsqueeze(1).repeat(1, 1, input.size(-1)) ** 2 | ||||
|         coeff = coeff.to(input.device) | ||||
|         coeff = F.conv_transpose1d(coeff, self.enframe, stride=self.hop_size) | ||||
|         out = out / (coeff + 1e-8) | ||||
|         pad = self.window_len - self.hop_size | ||||
|         out = out[..., pad:-pad] | ||||
|         return out | ||||
|  | @ -1,338 +0,0 @@ | |||
| { | ||||
|  "cells": [ | ||||
|   { | ||||
|    "cell_type": "markdown", | ||||
|    "id": "ccd61d5c", | ||||
|    "metadata": {}, | ||||
|    "source": [ | ||||
|     "## Custom model training using mayavoz [advanced]\n", | ||||
|     "\n", | ||||
|     "In this tutorial, we will cover advanced usages and customizations for training your own speecg enhancement model. \n", | ||||
|     "\n", | ||||
|     " - [Data preparation using MayaDataset](#dataprep)\n", | ||||
|     " - [Model customization](#modelcustom)\n", | ||||
|     " - [callbacks & LR schedulers](#callbacks)\n", | ||||
|     " - [Model training & testing](#train)\n" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "markdown", | ||||
|    "id": "726c320f", | ||||
|    "metadata": {}, | ||||
|    "source": [ | ||||
|     "- **install mayavoz**" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "id": "c987c799", | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "! pip install -q mayavoz" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "markdown", | ||||
|    "id": "8ff9857b", | ||||
|    "metadata": {}, | ||||
|    "source": [ | ||||
|     "<div id=\"dataprep\"></div>\n", | ||||
|     "\n", | ||||
|     "### Data preparation\n", | ||||
|     "\n", | ||||
|     "`Files` is a dataclass that wraps and holds train/test paths togethor. There are usually one folder each for clean and noisy data. These paths must be relative to a `root_dir` where all these directories reside. For example\n", | ||||
|     "\n", | ||||
|     "```\n", | ||||
|     "- VCTK/\n", | ||||
|     "    |__ clean_train_wav/\n", | ||||
|     "    |__ noisy_train_wav/\n", | ||||
|     "    |__ clean_test_wav/\n", | ||||
|     "    |__ noisy_test_wav/\n", | ||||
|     "    \n", | ||||
|     "```" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 2, | ||||
|    "id": "64cbc0c8", | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "from mayavoz.utils import Files\n", | ||||
|     "file = Files(train_clean=\"clean_train_wav\",\n", | ||||
|     "            train_noisy=\"noisy_train_wav\",\n", | ||||
|     "            test_clean=\"clean_test_wav\",\n", | ||||
|     "            test_noisy=\"noisy_test_wav\")\n", | ||||
|     "root_dir = \"VCTK\"" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "markdown", | ||||
|    "id": "2d324bd1", | ||||
|    "metadata": {}, | ||||
|    "source": [ | ||||
|     "- `name`: name of the dataset. \n", | ||||
|     "- `duration`: control the duration of each audio instance fed into your model.\n", | ||||
|     "- `stride` is used if set to move the sliding window.\n", | ||||
|     "- `sampling_rate`: desired sampling rate for audio\n", | ||||
|     "- `batch_size`: model batch size\n", | ||||
|     "- `min_valid_minutes`: minimum validation in minutes. Validation is automatically selected from training set. (exclusive users).\n", | ||||
|     "- `matching_function`: there are two types of mapping functions.\n", | ||||
|     "    - `one_to_one` : In this one clean file will only have one corresponding noisy file. For example Valentini datasets\n", | ||||
|     "    - `one_to_many` : In this one clean file will only have one corresponding noisy file. For example MS-SNSD dataset.\n" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 3, | ||||
|    "id": "6834941d", | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "name = \"vctk\"\n", | ||||
|     "duration : 4.5\n", | ||||
|     "stride : 2.0\n", | ||||
|     "sampling_rate : 16000\n", | ||||
|     "min_valid_minutes : 20.0\n", | ||||
|     "batch_size : 32\n", | ||||
|     "matching_function : \"one_to_one\"\n" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "id": "d08c6bf8", | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "from mayavoz.dataset import MayaDataset\n", | ||||
|     "dataset = MayaDataset(\n", | ||||
|     "            name=name,\n", | ||||
|     "            root_dir=root_dir,\n", | ||||
|     "            files=files,\n", | ||||
|     "            duration=duration,\n", | ||||
|     "            stride=stride,\n", | ||||
|     "            sampling_rate=sampling_rate,\n", | ||||
|     "            batch_size=batch_size,\n", | ||||
|     "            min_valid_minutes=min_valid_minutes,\n", | ||||
|     "            matching_function=matching_function\n", | ||||
|     "        )" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "markdown", | ||||
|    "id": "5b315bde", | ||||
|    "metadata": {}, | ||||
|    "source": [ | ||||
|     "Now your custom dataloader is ready!" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "markdown", | ||||
|    "id": "01548fe5", | ||||
|    "metadata": {}, | ||||
|    "source": [ | ||||
|     "<div id=\"modelcustom\"></div>\n", | ||||
|     "\n", | ||||
|     "### Model Customization\n", | ||||
|     "Now, this is very easy. \n", | ||||
|     "\n", | ||||
|     "- Import the preferred model from `mayavoz.models`. Currently 3 models are implemented.\n", | ||||
|     "   - `WaveUnet`\n", | ||||
|     "   - `Demucs`\n", | ||||
|     "   - `DCCRN`\n", | ||||
|     "- Each of model hyperparameters such as depth,kernel_size,stride etc can be controlled by you.   Just check the parameters and pass it to as required.\n", | ||||
|     "- `sampling_rate`: sampling rate (should be equal to dataset sampling rate)\n", | ||||
|     "- `dataset`: mayavoz dataset object as prepared earlier.\n", | ||||
|     "- `loss` : model loss. Multiple loss functions are available.\n", | ||||
|     "\n", | ||||
|     "        \n", | ||||
|     "        \n", | ||||
|     "you can pass one (as string)/more (as list of strings) of these loss functions as per your requirements. For example, model will automatically calculate loss as average of `mae` and `mse` if you pass loss as `[\"mae\",\"mse\"]`. Available loss functions are `mse`,`mae`,`si-snr`.\n", | ||||
|     "\n", | ||||
|     "mayavoz can accept **custom loss functions**. It should be of the form.\n", | ||||
|     "```\n", | ||||
|     "class your_custom_loss(nn.Module):\n", | ||||
|     "    def __init__(self,**kwargs):\n", | ||||
|     "        self.higher_better = False  ## loss minimization direction\n", | ||||
|     "        self.name = \"your_loss_name\" ## loss name logging \n", | ||||
|     "        ...\n", | ||||
|     "    def forward(self,prediction, target):\n", | ||||
|     "        loss = ....\n", | ||||
|     "        return loss\n", | ||||
|     "        \n", | ||||
|     "```\n", | ||||
|     "\n", | ||||
|     "- metrics : validation metrics. Available options `mae`,`mse`,`si-sdr`,`si-sdr`,`pesq`,`stoi`. One or more can be used.\n" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "id": "b36b457c", | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "from mayavoz.models import Demucs\n", | ||||
|     "model = Demucs(\n", | ||||
|     "        sampling_rate=16000,\n", | ||||
|     "        dataset=dataset,\n", | ||||
|     "        loss=[\"mae\"],\n", | ||||
|     "        metrics=[\"stoi\",\"pesq\"])\n" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "markdown", | ||||
|    "id": "1523d638", | ||||
|    "metadata": {}, | ||||
|    "source": [ | ||||
|     "<div id=\"callbacks\"></div>\n", | ||||
|     "\n", | ||||
|     "### learning rate schedulers and callbacks\n", | ||||
|     "Here I am using `ReduceLROnPlateau`" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "id": "8de6931c", | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "from torch.optim.lr_scheduler import ReduceLROnPlateau\n", | ||||
|     "\n", | ||||
|     "def configure_optimizers(self):\n", | ||||
|     "        optimizer = instantiate(\n", | ||||
|     "            config.optimizer,\n", | ||||
|     "            lr=parameters.get(\"lr\"),\n", | ||||
|     "            params=self.parameters(),\n", | ||||
|     "        )\n", | ||||
|     "        scheduler = ReduceLROnPlateau(\n", | ||||
|     "            optimizer=optimizer,\n", | ||||
|     "            mode=direction,\n", | ||||
|     "            factor=parameters.get(\"ReduceLr_factor\", 0.1),\n", | ||||
|     "            verbose=True,\n", | ||||
|     "            min_lr=parameters.get(\"min_lr\", 1e-6),\n", | ||||
|     "            patience=parameters.get(\"ReduceLr_patience\", 3),\n", | ||||
|     "        )\n", | ||||
|     "        return {\n", | ||||
|     "            \"optimizer\": optimizer,\n", | ||||
|     "            \"lr_scheduler\": scheduler,\n", | ||||
|     "            \"monitor\": f'valid_{parameters.get(\"ReduceLr_monitor\", \"loss\")}',\n", | ||||
|     "        }\n", | ||||
|     "\n", | ||||
|     "\n", | ||||
|     "model.configure_optimizers = MethodType(configure_optimizers, model)" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "markdown", | ||||
|    "id": "2f7b5af5", | ||||
|    "metadata": {}, | ||||
|    "source": [ | ||||
|     "you can use any number of callbacks and pass it directly to pytorch lightning trainer. Here I am using only `ModelCheckpoint`" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "id": "6f6b62a1", | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "callbacks = []\n", | ||||
|     "direction = model.valid_monitor ## min or max \n", | ||||
|     "checkpoint = ModelCheckpoint(\n", | ||||
|     "        dirpath=\"./model\",\n", | ||||
|     "        filename=f\"model_filename\",\n", | ||||
|     "        monitor=\"valid_loss\",\n", | ||||
|     "        verbose=False,\n", | ||||
|     "        mode=direction,\n", | ||||
|     "        every_n_epochs=1,\n", | ||||
|     "    )\n", | ||||
|     "callbacks.append(checkpoint)" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "markdown", | ||||
|    "id": "f3534445", | ||||
|    "metadata": {}, | ||||
|    "source": [ | ||||
|     "<div id=\"train\"></div>\n", | ||||
|     "\n", | ||||
|     "\n", | ||||
|     "### Train" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "id": "3dc0348b", | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "import pytorch_lightning as pl\n", | ||||
|     "trainer = plt.Trainer(max_epochs=1,callbacks=callbacks,accelerator=\"gpu\")\n", | ||||
|     "trainer.fit(model)\n" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "markdown", | ||||
|    "id": "56dcfec1", | ||||
|    "metadata": {}, | ||||
|    "source": [ | ||||
|     "- Test your model agaist test dataset" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "id": "63851feb", | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "trainer.test(model)" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "markdown", | ||||
|    "id": "4d3f5350", | ||||
|    "metadata": {}, | ||||
|    "source": [ | ||||
|     "**Hurray! you have your speech enhancement model trained and tested.**\n" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "id": "10d630e8", | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [] | ||||
|   } | ||||
|  ], | ||||
|  "metadata": { | ||||
|   "kernelspec": { | ||||
|    "display_name": "enhancer", | ||||
|    "language": "python", | ||||
|    "name": "enhancer" | ||||
|   }, | ||||
|   "language_info": { | ||||
|    "codemirror_mode": { | ||||
|     "name": "ipython", | ||||
|     "version": 3 | ||||
|    }, | ||||
|    "file_extension": ".py", | ||||
|    "mimetype": "text/x-python", | ||||
|    "name": "python", | ||||
|    "nbconvert_exporter": "python", | ||||
|    "pygments_lexer": "ipython3", | ||||
|    "version": "3.8.13" | ||||
|   } | ||||
|  }, | ||||
|  "nbformat": 4, | ||||
|  "nbformat_minor": 5 | ||||
| } | ||||
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							|  | @ -1,120 +0,0 @@ | |||
| import os | ||||
| from types import MethodType | ||||
| 
 | ||||
| import hydra | ||||
| from hydra.utils import instantiate | ||||
| from omegaconf import DictConfig, OmegaConf | ||||
| from pytorch_lightning.callbacks import ( | ||||
|     EarlyStopping, | ||||
|     LearningRateMonitor, | ||||
|     ModelCheckpoint, | ||||
| ) | ||||
| from pytorch_lightning.loggers import MLFlowLogger | ||||
| from torch.optim.lr_scheduler import ReduceLROnPlateau | ||||
| 
 | ||||
| # from torch_audiomentations import Compose, Shift | ||||
| 
 | ||||
| os.environ["HYDRA_FULL_ERROR"] = "1" | ||||
| JOB_ID = os.environ.get("SLURM_JOBID", "0") | ||||
| 
 | ||||
| 
 | ||||
| @hydra.main(config_path="train_config", config_name="config") | ||||
| def train(config: DictConfig): | ||||
| 
 | ||||
|     OmegaConf.save(config, "config.yaml") | ||||
| 
 | ||||
|     callbacks = [] | ||||
|     logger = MLFlowLogger( | ||||
|         experiment_name=config.mlflow.experiment_name, | ||||
|         run_name=config.mlflow.run_name, | ||||
|         tags={"JOB_ID": JOB_ID}, | ||||
|     ) | ||||
| 
 | ||||
|     parameters = config.hyperparameters | ||||
|     # apply_augmentations = Compose( | ||||
|     #     [ | ||||
|     #         Shift(min_shift=0.5, max_shift=1.0, shift_unit="seconds", p=0.5), | ||||
|     #     ] | ||||
|     # ) | ||||
| 
 | ||||
|     dataset = instantiate(config.dataset, augmentations=None) | ||||
|     model = instantiate( | ||||
|         config.model, | ||||
|         dataset=dataset, | ||||
|         lr=parameters.get("lr"), | ||||
|         loss=parameters.get("loss"), | ||||
|         metric=parameters.get("metric"), | ||||
|     ) | ||||
| 
 | ||||
|     direction = model.valid_monitor | ||||
|     checkpoint = ModelCheckpoint( | ||||
|         dirpath="./model", | ||||
|         filename=f"model_{JOB_ID}", | ||||
|         monitor="valid_loss", | ||||
|         verbose=False, | ||||
|         mode=direction, | ||||
|         every_n_epochs=1, | ||||
|     ) | ||||
|     callbacks.append(checkpoint) | ||||
|     callbacks.append(LearningRateMonitor(logging_interval="epoch")) | ||||
| 
 | ||||
|     if parameters.get("Early_stop", False): | ||||
|         early_stopping = EarlyStopping( | ||||
|             monitor="val_loss", | ||||
|             mode=direction, | ||||
|             min_delta=0.0, | ||||
|             patience=parameters.get("EarlyStopping_patience", 10), | ||||
|             strict=True, | ||||
|             verbose=False, | ||||
|         ) | ||||
|         callbacks.append(early_stopping) | ||||
| 
 | ||||
|     def configure_optimizers(self): | ||||
|         optimizer = instantiate( | ||||
|             config.optimizer, | ||||
|             lr=parameters.get("lr"), | ||||
|             params=self.parameters(), | ||||
|         ) | ||||
|         scheduler = ReduceLROnPlateau( | ||||
|             optimizer=optimizer, | ||||
|             mode=direction, | ||||
|             factor=parameters.get("ReduceLr_factor", 0.1), | ||||
|             verbose=True, | ||||
|             min_lr=parameters.get("min_lr", 1e-6), | ||||
|             patience=parameters.get("ReduceLr_patience", 3), | ||||
|         ) | ||||
|         return { | ||||
|             "optimizer": optimizer, | ||||
|             "lr_scheduler": scheduler, | ||||
|             "monitor": f'valid_{parameters.get("ReduceLr_monitor", "loss")}', | ||||
|         } | ||||
| 
 | ||||
|     model.configure_optimizers = MethodType(configure_optimizers, model) | ||||
| 
 | ||||
|     trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks) | ||||
|     trainer.fit(model) | ||||
|     trainer.test(model) | ||||
| 
 | ||||
|     logger.experiment.log_artifact( | ||||
|         logger.run_id, f"{trainer.default_root_dir}/config.yaml" | ||||
|     ) | ||||
| 
 | ||||
|     saved_location = os.path.join( | ||||
|         trainer.default_root_dir, "model", f"model_{JOB_ID}.ckpt" | ||||
|     ) | ||||
|     if os.path.isfile(saved_location): | ||||
|         logger.experiment.log_artifact(logger.run_id, saved_location) | ||||
|         logger.experiment.log_param( | ||||
|             logger.run_id, | ||||
|             "num_train_steps_per_epoch", | ||||
|             dataset.train__len__() / dataset.batch_size, | ||||
|         ) | ||||
|         logger.experiment.log_param( | ||||
|             logger.run_id, | ||||
|             "num_valid_steps_per_epoch", | ||||
|             dataset.val__len__() / dataset.batch_size, | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == "__main__": | ||||
|     train() | ||||
|  | @ -1,7 +0,0 @@ | |||
| defaults: | ||||
|   - model : Demucs | ||||
|   - dataset : MS-SNSD | ||||
|   - optimizer : Adam | ||||
|   - hyperparameters : default | ||||
|   - trainer : default | ||||
|   - mlflow : experiment | ||||
|  | @ -1,13 +0,0 @@ | |||
| _target_: mayavoz.data.dataset.MayaDataset | ||||
| name : MS-SDSD | ||||
| root_dir : /Users/shahules/Myprojects/MS-SNSD | ||||
| duration : 1.5 | ||||
| stride : 1 | ||||
| sampling_rate: 16000 | ||||
| batch_size: 32 | ||||
| min_valid_minutes: 25 | ||||
| files: | ||||
|   train_clean : CleanSpeech_training | ||||
|   test_clean : CleanSpeech_training | ||||
|   train_noisy : NoisySpeech_training | ||||
|   test_noisy : NoisySpeech_training | ||||
|  | @ -1,7 +0,0 @@ | |||
| loss : si-snr | ||||
| metric : [stoi,pesq] | ||||
| lr : 0.001 | ||||
| ReduceLr_patience : 10 | ||||
| ReduceLr_factor : 0.5 | ||||
| min_lr : 0.000001 | ||||
| EarlyStopping_factor : 10 | ||||
|  | @ -1,2 +0,0 @@ | |||
| experiment_name : shahules/mayavoz | ||||
| run_name : Demucs + Vtck with stride + augmentations | ||||
|  | @ -1,25 +0,0 @@ | |||
| _target_: mayavoz.models.dccrn.DCCRN | ||||
| num_channels: 1 | ||||
| sampling_rate : 16000 | ||||
| complex_lstm : True | ||||
| complex_norm : True | ||||
| complex_relu : True | ||||
| masking_mode : True | ||||
| 
 | ||||
| encoder_decoder: | ||||
|   initial_output_channels : 32 | ||||
|   depth : 6 | ||||
|   kernel_size : 5 | ||||
|   growth_factor : 2 | ||||
|   stride : 2 | ||||
|   padding : 2 | ||||
|   output_padding : 1 | ||||
| 
 | ||||
| lstm: | ||||
|   num_layers : 2 | ||||
|   hidden_size : 256 | ||||
| 
 | ||||
| stft: | ||||
|   window_len : 400 | ||||
|   hop_size : 100 | ||||
|   nfft : 512 | ||||
|  | @ -1,6 +0,0 @@ | |||
| _target_: torch.optim.Adam | ||||
| lr: 1e-3 | ||||
| betas: [0.9, 0.999] | ||||
| eps: 1e-08 | ||||
| weight_decay: 0 | ||||
| amsgrad: False | ||||
|  | @ -1,46 +0,0 @@ | |||
| _target_: pytorch_lightning.Trainer | ||||
| accelerator: gpu | ||||
| accumulate_grad_batches: 1 | ||||
| amp_backend: native | ||||
| auto_lr_find: True | ||||
| auto_scale_batch_size: False | ||||
| auto_select_gpus: True | ||||
| benchmark: False | ||||
| check_val_every_n_epoch: 1 | ||||
| detect_anomaly: False | ||||
| deterministic: False | ||||
| devices: 2 | ||||
| enable_checkpointing: True | ||||
| enable_model_summary: True | ||||
| enable_progress_bar: True | ||||
| fast_dev_run: False | ||||
| gpus: null | ||||
| gradient_clip_val: 0 | ||||
| gradient_clip_algorithm: norm | ||||
| ipus: null | ||||
| limit_predict_batches: 1.0 | ||||
| limit_test_batches: 1.0 | ||||
| limit_train_batches: 1.0 | ||||
| limit_val_batches: 1.0 | ||||
| log_every_n_steps: 50 | ||||
| max_epochs: 200 | ||||
| max_steps: -1 | ||||
| max_time: null | ||||
| min_epochs: 1 | ||||
| min_steps: null | ||||
| move_metrics_to_cpu: False | ||||
| multiple_trainloader_mode: max_size_cycle | ||||
| num_nodes: 1 | ||||
| num_processes: 1 | ||||
| num_sanity_val_steps: 2 | ||||
| overfit_batches: 0.0 | ||||
| precision: 32 | ||||
| profiler: null | ||||
| reload_dataloaders_every_n_epochs: 0 | ||||
| replace_sampler_ddp: True | ||||
| strategy: ddp | ||||
| sync_batchnorm: False | ||||
| tpu_cores: null | ||||
| track_grad_norm: -1 | ||||
| val_check_interval: 1.0 | ||||
| weights_save_path: null | ||||
|  | @ -1,2 +0,0 @@ | |||
| _target_: pytorch_lightning.Trainer | ||||
| fast_dev_run: True | ||||
|  | @ -1,120 +0,0 @@ | |||
| import os | ||||
| from types import MethodType | ||||
| 
 | ||||
| import hydra | ||||
| from hydra.utils import instantiate | ||||
| from omegaconf import DictConfig, OmegaConf | ||||
| from pytorch_lightning.callbacks import ( | ||||
|     EarlyStopping, | ||||
|     LearningRateMonitor, | ||||
|     ModelCheckpoint, | ||||
| ) | ||||
| from pytorch_lightning.loggers import MLFlowLogger | ||||
| from torch.optim.lr_scheduler import ReduceLROnPlateau | ||||
| 
 | ||||
| # from torch_audiomentations import Compose, Shift | ||||
| 
 | ||||
| os.environ["HYDRA_FULL_ERROR"] = "1" | ||||
| JOB_ID = os.environ.get("SLURM_JOBID", "0") | ||||
| 
 | ||||
| 
 | ||||
| @hydra.main(config_path="train_config", config_name="config") | ||||
| def train(config: DictConfig): | ||||
| 
 | ||||
|     OmegaConf.save(config, "config.yaml") | ||||
| 
 | ||||
|     callbacks = [] | ||||
|     logger = MLFlowLogger( | ||||
|         experiment_name=config.mlflow.experiment_name, | ||||
|         run_name=config.mlflow.run_name, | ||||
|         tags={"JOB_ID": JOB_ID}, | ||||
|     ) | ||||
| 
 | ||||
|     parameters = config.hyperparameters | ||||
|     # apply_augmentations = Compose( | ||||
|     #     [ | ||||
|     #         Shift(min_shift=0.5, max_shift=1.0, shift_unit="seconds", p=0.5), | ||||
|     #     ] | ||||
|     # ) | ||||
| 
 | ||||
|     dataset = instantiate(config.dataset, augmentations=None) | ||||
|     model = instantiate( | ||||
|         config.model, | ||||
|         dataset=dataset, | ||||
|         lr=parameters.get("lr"), | ||||
|         loss=parameters.get("loss"), | ||||
|         metric=parameters.get("metric"), | ||||
|     ) | ||||
| 
 | ||||
|     direction = model.valid_monitor | ||||
|     checkpoint = ModelCheckpoint( | ||||
|         dirpath="./model", | ||||
|         filename=f"model_{JOB_ID}", | ||||
|         monitor="valid_loss", | ||||
|         verbose=False, | ||||
|         mode=direction, | ||||
|         every_n_epochs=1, | ||||
|     ) | ||||
|     callbacks.append(checkpoint) | ||||
|     callbacks.append(LearningRateMonitor(logging_interval="epoch")) | ||||
| 
 | ||||
|     if parameters.get("Early_stop", False): | ||||
|         early_stopping = EarlyStopping( | ||||
|             monitor="val_loss", | ||||
|             mode=direction, | ||||
|             min_delta=0.0, | ||||
|             patience=parameters.get("EarlyStopping_patience", 10), | ||||
|             strict=True, | ||||
|             verbose=False, | ||||
|         ) | ||||
|         callbacks.append(early_stopping) | ||||
| 
 | ||||
|     def configure_optimizers(self): | ||||
|         optimizer = instantiate( | ||||
|             config.optimizer, | ||||
|             lr=parameters.get("lr"), | ||||
|             params=self.parameters(), | ||||
|         ) | ||||
|         scheduler = ReduceLROnPlateau( | ||||
|             optimizer=optimizer, | ||||
|             mode=direction, | ||||
|             factor=parameters.get("ReduceLr_factor", 0.1), | ||||
|             verbose=True, | ||||
|             min_lr=parameters.get("min_lr", 1e-6), | ||||
|             patience=parameters.get("ReduceLr_patience", 3), | ||||
|         ) | ||||
|         return { | ||||
|             "optimizer": optimizer, | ||||
|             "lr_scheduler": scheduler, | ||||
|             "monitor": f'valid_{parameters.get("ReduceLr_monitor", "loss")}', | ||||
|         } | ||||
| 
 | ||||
|     model.configure_optimizers = MethodType(configure_optimizers, model) | ||||
| 
 | ||||
|     trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks) | ||||
|     trainer.fit(model) | ||||
|     trainer.test(model) | ||||
| 
 | ||||
|     logger.experiment.log_artifact( | ||||
|         logger.run_id, f"{trainer.default_root_dir}/config.yaml" | ||||
|     ) | ||||
| 
 | ||||
|     saved_location = os.path.join( | ||||
|         trainer.default_root_dir, "model", f"model_{JOB_ID}.ckpt" | ||||
|     ) | ||||
|     if os.path.isfile(saved_location): | ||||
|         logger.experiment.log_artifact(logger.run_id, saved_location) | ||||
|         logger.experiment.log_param( | ||||
|             logger.run_id, | ||||
|             "num_train_steps_per_epoch", | ||||
|             dataset.train__len__() / dataset.batch_size, | ||||
|         ) | ||||
|         logger.experiment.log_param( | ||||
|             logger.run_id, | ||||
|             "num_valid_steps_per_epoch", | ||||
|             dataset.val__len__() / dataset.batch_size, | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == "__main__": | ||||
|     train() | ||||
|  | @ -1,7 +0,0 @@ | |||
| defaults: | ||||
|   - model : Demucs | ||||
|   - dataset : MS-SNSD | ||||
|   - optimizer : Adam | ||||
|   - hyperparameters : default | ||||
|   - trainer : default | ||||
|   - mlflow : experiment | ||||
|  | @ -1,13 +0,0 @@ | |||
| _target_: mayavoz.data.dataset.MayaDataset | ||||
| name : MS-SDSD | ||||
| root_dir : /Users/shahules/Myprojects/MS-SNSD | ||||
| duration : 5 | ||||
| stride : 1 | ||||
| sampling_rate: 16000 | ||||
| batch_size: 32 | ||||
| min_valid_minutes: 25 | ||||
| files: | ||||
|   train_clean : CleanSpeech_training | ||||
|   test_clean : CleanSpeech_training | ||||
|   train_noisy : NoisySpeech_training | ||||
|   test_noisy : NoisySpeech_training | ||||
|  | @ -1,7 +0,0 @@ | |||
| loss : mae | ||||
| metric : [stoi,pesq] | ||||
| lr : 0.0003 | ||||
| ReduceLr_patience : 10 | ||||
| ReduceLr_factor : 0.5 | ||||
| min_lr : 0.000001 | ||||
| EarlyStopping_factor : 10 | ||||
|  | @ -1,2 +0,0 @@ | |||
| experiment_name : shahules/mayavoz | ||||
| run_name : demucs-ms-snsd | ||||
|  | @ -1,16 +0,0 @@ | |||
| _target_: mayavoz.models.demucs.Demucs | ||||
| num_channels: 1 | ||||
| resample: 4 | ||||
| sampling_rate : 16000 | ||||
| 
 | ||||
| encoder_decoder: | ||||
|   depth: 4 | ||||
|   initial_output_channels: 64 | ||||
|   kernel_size: 8 | ||||
|   stride: 4 | ||||
|   growth_factor: 2 | ||||
|   glu: True | ||||
| 
 | ||||
| lstm: | ||||
|   bidirectional: False | ||||
|   num_layers: 2 | ||||
|  | @ -1,6 +0,0 @@ | |||
| _target_: torch.optim.Adam | ||||
| lr: 1e-3 | ||||
| betas: [0.9, 0.999] | ||||
| eps: 1e-08 | ||||
| weight_decay: 0 | ||||
| amsgrad: False | ||||
|  | @ -1,46 +0,0 @@ | |||
| _target_: pytorch_lightning.Trainer | ||||
| accelerator: gpu | ||||
| accumulate_grad_batches: 1 | ||||
| amp_backend: native | ||||
| auto_lr_find: True | ||||
| auto_scale_batch_size: False | ||||
| auto_select_gpus: True | ||||
| benchmark: False | ||||
| check_val_every_n_epoch: 1 | ||||
| detect_anomaly: False | ||||
| deterministic: False | ||||
| devices: 2 | ||||
| enable_checkpointing: True | ||||
| enable_model_summary: True | ||||
| enable_progress_bar: True | ||||
| fast_dev_run: False | ||||
| gpus: null | ||||
| gradient_clip_val: 0 | ||||
| gradient_clip_algorithm: norm | ||||
| ipus: null | ||||
| limit_predict_batches: 1.0 | ||||
| limit_test_batches: 1.0 | ||||
| limit_train_batches: 1.0 | ||||
| limit_val_batches: 1.0 | ||||
| log_every_n_steps: 50 | ||||
| max_epochs: 200 | ||||
| max_steps: -1 | ||||
| max_time: null | ||||
| min_epochs: 1 | ||||
| min_steps: null | ||||
| move_metrics_to_cpu: False | ||||
| multiple_trainloader_mode: max_size_cycle | ||||
| num_nodes: 1 | ||||
| num_processes: 1 | ||||
| num_sanity_val_steps: 2 | ||||
| overfit_batches: 0.0 | ||||
| precision: 32 | ||||
| profiler: null | ||||
| reload_dataloaders_every_n_epochs: 0 | ||||
| replace_sampler_ddp: True | ||||
| strategy: ddp | ||||
| sync_batchnorm: False | ||||
| tpu_cores: null | ||||
| track_grad_norm: -1 | ||||
| val_check_interval: 1.0 | ||||
| weights_save_path: null | ||||
|  | @ -1,2 +0,0 @@ | |||
| _target_: pytorch_lightning.Trainer | ||||
| fast_dev_run: True | ||||
|  | @ -1,17 +0,0 @@ | |||
| ### Microsoft Scalable Noisy Speech Dataset (MS-SNSD) | ||||
| 
 | ||||
|  MS-SNSD is a speech datasetthat can scale to arbitrary sizes depending on the number of speakers, noise types, and Speech to Noise Ratio (SNR) levels desired. | ||||
| 
 | ||||
| ### Dataset download & setup | ||||
| - Follow steps in the official repo [here](https://github.com/microsoft/MS-SNSD) to download and setup the dataset. | ||||
| 
 | ||||
| **References** | ||||
| ```BibTex | ||||
| @article{reddy2019scalable, | ||||
|   title={A Scalable Noisy Speech Dataset and Online Subjective Test Framework}, | ||||
|   author={Reddy, Chandan KA and Beyrami, Ebrahim and Pool, Jamie and Cutler, Ross and Srinivasan, Sriram and Gehrke, Johannes}, | ||||
|   journal={Proc. Interspeech 2019}, | ||||
|   pages={1816--1820}, | ||||
|   year={2019} | ||||
| } | ||||
| ``` | ||||
|  | @ -1,120 +0,0 @@ | |||
| import os | ||||
| from types import MethodType | ||||
| 
 | ||||
| import hydra | ||||
| from hydra.utils import instantiate | ||||
| from omegaconf import DictConfig, OmegaConf | ||||
| from pytorch_lightning.callbacks import ( | ||||
|     EarlyStopping, | ||||
|     LearningRateMonitor, | ||||
|     ModelCheckpoint, | ||||
| ) | ||||
| from pytorch_lightning.loggers import MLFlowLogger | ||||
| from torch.optim.lr_scheduler import ReduceLROnPlateau | ||||
| 
 | ||||
| # from torch_audiomentations import Compose, Shift | ||||
| 
 | ||||
| os.environ["HYDRA_FULL_ERROR"] = "1" | ||||
| JOB_ID = os.environ.get("SLURM_JOBID", "0") | ||||
| 
 | ||||
| 
 | ||||
| @hydra.main(config_path="train_config", config_name="config") | ||||
| def main(config: DictConfig): | ||||
| 
 | ||||
|     OmegaConf.save(config, "config_log.yaml") | ||||
| 
 | ||||
|     callbacks = [] | ||||
|     logger = MLFlowLogger( | ||||
|         experiment_name=config.mlflow.experiment_name, | ||||
|         run_name=config.mlflow.run_name, | ||||
|         tags={"JOB_ID": JOB_ID}, | ||||
|     ) | ||||
| 
 | ||||
|     parameters = config.hyperparameters | ||||
|     # apply_augmentations = Compose( | ||||
|     #     [ | ||||
|     #         Shift(min_shift=0.5, max_shift=1.0, shift_unit="seconds", p=0.5), | ||||
|     #     ] | ||||
|     # ) | ||||
| 
 | ||||
|     dataset = instantiate(config.dataset, augmentations=None) | ||||
|     model = instantiate( | ||||
|         config.model, | ||||
|         dataset=dataset, | ||||
|         lr=parameters.get("lr"), | ||||
|         loss=parameters.get("loss"), | ||||
|         metric=parameters.get("metric"), | ||||
|     ) | ||||
| 
 | ||||
|     direction = model.valid_monitor | ||||
|     checkpoint = ModelCheckpoint( | ||||
|         dirpath="./model", | ||||
|         filename=f"model_{JOB_ID}", | ||||
|         monitor="valid_loss", | ||||
|         verbose=False, | ||||
|         mode=direction, | ||||
|         every_n_epochs=1, | ||||
|     ) | ||||
|     callbacks.append(checkpoint) | ||||
|     callbacks.append(LearningRateMonitor(logging_interval="epoch")) | ||||
| 
 | ||||
|     if parameters.get("Early_stop", False): | ||||
|         early_stopping = EarlyStopping( | ||||
|             monitor="val_loss", | ||||
|             mode=direction, | ||||
|             min_delta=0.0, | ||||
|             patience=parameters.get("EarlyStopping_patience", 10), | ||||
|             strict=True, | ||||
|             verbose=False, | ||||
|         ) | ||||
|         callbacks.append(early_stopping) | ||||
| 
 | ||||
|     def configure_optimizers(self): | ||||
|         optimizer = instantiate( | ||||
|             config.optimizer, | ||||
|             lr=parameters.get("lr"), | ||||
|             params=self.parameters(), | ||||
|         ) | ||||
|         scheduler = ReduceLROnPlateau( | ||||
|             optimizer=optimizer, | ||||
|             mode=direction, | ||||
|             factor=parameters.get("ReduceLr_factor", 0.1), | ||||
|             verbose=True, | ||||
|             min_lr=parameters.get("min_lr", 1e-6), | ||||
|             patience=parameters.get("ReduceLr_patience", 3), | ||||
|         ) | ||||
|         return { | ||||
|             "optimizer": optimizer, | ||||
|             "lr_scheduler": scheduler, | ||||
|             "monitor": f'valid_{parameters.get("ReduceLr_monitor", "loss")}', | ||||
|         } | ||||
| 
 | ||||
|     model.configure_optimizers = MethodType(configure_optimizers, model) | ||||
| 
 | ||||
|     trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks) | ||||
|     trainer.fit(model) | ||||
|     trainer.test(model) | ||||
| 
 | ||||
|     logger.experiment.log_artifact( | ||||
|         logger.run_id, f"{trainer.default_root_dir}/config_log.yaml" | ||||
|     ) | ||||
| 
 | ||||
|     saved_location = os.path.join( | ||||
|         trainer.default_root_dir, "model", f"model_{JOB_ID}.ckpt" | ||||
|     ) | ||||
|     if os.path.isfile(saved_location): | ||||
|         logger.experiment.log_artifact(logger.run_id, saved_location) | ||||
|         logger.experiment.log_param( | ||||
|             logger.run_id, | ||||
|             "num_train_steps_per_epoch", | ||||
|             dataset.train__len__() / dataset.batch_size, | ||||
|         ) | ||||
|         logger.experiment.log_param( | ||||
|             logger.run_id, | ||||
|             "num_valid_steps_per_epoch", | ||||
|             dataset.val__len__() / dataset.batch_size, | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == "__main__": | ||||
|     main() | ||||
|  | @ -1,7 +0,0 @@ | |||
| defaults: | ||||
|   - model : Demucs | ||||
|   - dataset : Vctk | ||||
|   - optimizer : Adam | ||||
|   - hyperparameters : default | ||||
|   - trainer : default | ||||
|   - mlflow : experiment | ||||
|  | @ -1,13 +0,0 @@ | |||
| _target_: mayavoz.data.dataset.MayaDataset | ||||
| name : vctk | ||||
| root_dir : /scratch/c.sistc3/DS_10283_2791 | ||||
| duration : 4.5 | ||||
| stride : 0.5 | ||||
| sampling_rate: 16000 | ||||
| batch_size: 32 | ||||
| min_valid_minutes : 25 | ||||
| files: | ||||
|   train_clean : clean_trainset_28spk_wav | ||||
|   test_clean : clean_testset_wav | ||||
|   train_noisy : noisy_trainset_28spk_wav | ||||
|   test_noisy : noisy_testset_wav | ||||
|  | @ -1,8 +0,0 @@ | |||
| loss : mae | ||||
| metric : [stoi,pesq,si-sdr] | ||||
| lr : 0.0003 | ||||
| Early_stop : False | ||||
| ReduceLr_patience : 10 | ||||
| ReduceLr_factor : 0.1 | ||||
| min_lr : 0.000001 | ||||
| EarlyStopping_factor : 10 | ||||
|  | @ -1,2 +0,0 @@ | |||
| experiment_name : shahules/mayavoz | ||||
| run_name : baseline | ||||
|  | @ -1,16 +0,0 @@ | |||
| _target_: mayavoz.models.demucs.Demucs | ||||
| num_channels: 1 | ||||
| resample: 4 | ||||
| sampling_rate : 16000 | ||||
| 
 | ||||
| encoder_decoder: | ||||
|   depth: 4 | ||||
|   initial_output_channels: 64 | ||||
|   kernel_size: 8 | ||||
|   stride: 4 | ||||
|   growth_factor: 2 | ||||
|   glu: True | ||||
| 
 | ||||
| lstm: | ||||
|   bidirectional: True | ||||
|   num_layers: 2 | ||||
|  | @ -1,6 +0,0 @@ | |||
| _target_: torch.optim.Adam | ||||
| lr: 1e-3 | ||||
| betas: [0.9, 0.999] | ||||
| eps: 1e-08 | ||||
| weight_decay: 0 | ||||
| amsgrad: False | ||||
|  | @ -1,8 +0,0 @@ | |||
| loss : mae | ||||
| metric : [stoi,pesq,si-sdr] | ||||
| lr : 0.003 | ||||
| ReduceLr_patience : 10 | ||||
| ReduceLr_factor : 0.1 | ||||
| min_lr : 0.000001 | ||||
| EarlyStopping_factor : 10 | ||||
| Early_stop : False | ||||
|  | @ -1,2 +0,0 @@ | |||
| experiment_name : shahules/mayavoz | ||||
| run_name : baseline | ||||
|  | @ -1,5 +0,0 @@ | |||
| _target_: mayavoz.models.waveunet.WaveUnet | ||||
| num_channels : 1 | ||||
| depth : 9 | ||||
| initial_output_channels: 24 | ||||
| sampling_rate : 16000 | ||||
|  | @ -1,6 +0,0 @@ | |||
| _target_: torch.optim.Adam | ||||
| lr: 1e-3 | ||||
| betas: [0.9, 0.999] | ||||
| eps: 1e-08 | ||||
| weight_decay: 0 | ||||
| amsgrad: False | ||||
|  | @ -1,46 +0,0 @@ | |||
| _target_: pytorch_lightning.Trainer | ||||
| accelerator: gpu | ||||
| accumulate_grad_batches: 1 | ||||
| amp_backend: native | ||||
| auto_lr_find: True | ||||
| auto_scale_batch_size: False | ||||
| auto_select_gpus: True | ||||
| benchmark: False | ||||
| check_val_every_n_epoch: 1 | ||||
| detect_anomaly: False | ||||
| deterministic: False | ||||
| devices: 2 | ||||
| enable_checkpointing: True | ||||
| enable_model_summary: True | ||||
| enable_progress_bar: True | ||||
| fast_dev_run: False | ||||
| gpus: null | ||||
| gradient_clip_val: 0 | ||||
| gradient_clip_algorithm: norm | ||||
| ipus: null | ||||
| limit_predict_batches: 1.0 | ||||
| limit_test_batches: 1.0 | ||||
| limit_train_batches: 1.0 | ||||
| limit_val_batches: 1.0 | ||||
| log_every_n_steps: 50 | ||||
| max_epochs: 200 | ||||
| max_steps: -1 | ||||
| max_time: null | ||||
| min_epochs: 1 | ||||
| min_steps: null | ||||
| move_metrics_to_cpu: False | ||||
| multiple_trainloader_mode: max_size_cycle | ||||
| num_nodes: 1 | ||||
| num_processes: 1 | ||||
| num_sanity_val_steps: 2 | ||||
| overfit_batches: 0.0 | ||||
| precision: 32 | ||||
| profiler: null | ||||
| reload_dataloaders_every_n_epochs: 0 | ||||
| replace_sampler_ddp: True | ||||
| strategy: ddp | ||||
| sync_batchnorm: False | ||||
| tpu_cores: null | ||||
| track_grad_norm: -1 | ||||
| val_check_interval: 1.0 | ||||
| weights_save_path: null | ||||
|  | @ -1,2 +0,0 @@ | |||
| _target_: pytorch_lightning.Trainer | ||||
| fast_dev_run: True | ||||
|  | @ -1,12 +0,0 @@ | |||
| ## Valentini dataset | ||||
| 
 | ||||
| Clean and noisy parallel speech database. The database was designed to train and test speech enhancement methods that operate at 48kHz. A more detailed description can be found in the papers associated with the database.[official page](https://datashare.ed.ac.uk/handle/10283/2791) | ||||
| 
 | ||||
| **References** | ||||
| ```BibTex | ||||
| @misc{ | ||||
| title={Noisy speech database for training speech enhancement algorithms and TTS models}, | ||||
| author={Valentini-Botinhao, Cassia}, year={2017}, | ||||
| doi=https://doi.org/10.7488/ds/2117, | ||||
| } | ||||
| ``` | ||||
|  | @ -1,19 +1,16 @@ | |||
| black>=22.8.0 | ||||
| boto3>=1.24.86 | ||||
| huggingface-hub>=0.10.0 | ||||
| flake8>=5.0.4 | ||||
| huggingface-hu>=0.10.0 | ||||
| hydra-core>=1.2.0 | ||||
| joblib>=1.2.0 | ||||
| librosa>=0.9.2 | ||||
| mlflow>=1.28.0 | ||||
| mlflow>=1.29.0 | ||||
| numpy>=1.23.3 | ||||
| pesq==0.0.4 | ||||
| protobuf>=3.19.6 | ||||
| pystoi==0.3.3 | ||||
| pytest-lazy-fixture>=0.6.3 | ||||
| pytorch-lightning>=1.7.7 | ||||
| scikit-learn>=1.1.2 | ||||
| scipy>=1.9.1 | ||||
| soundfile>=0.11.0 | ||||
| torch>=1.12.1 | ||||
| torch-audiomentations==0.11.0 | ||||
| torchaudio>=0.12.1 | ||||
| tqdm>=4.64.1 | ||||
|  |  | |||
							
								
								
									
										104
									
								
								setup.cfg
								
								
								
								
							
							
						
						
									
										104
									
								
								setup.cfg
								
								
								
								
							|  | @ -1,104 +0,0 @@ | |||
| # This file is used to configure your project. | ||||
| # Read more about the various options under: | ||||
| # http://setuptools.readthedocs.io/en/latest/setuptools.html#configuring-setup-using-setup-cfg-files | ||||
| 
 | ||||
| [metadata] | ||||
| name = mayavoz | ||||
| description = Deep learning for speech enhacement | ||||
| author = Shahul Ess | ||||
| author-email = shahules786@gmail.com | ||||
| license = mit | ||||
| long-description = file: README.md | ||||
| long-description-content-type = text/markdown; charset=UTF-8; variant=GFM | ||||
| # Change if running only on Windows, Mac or Linux (comma-separated) | ||||
| platforms = Linux, Mac | ||||
| # Add here all kinds of additional classifiers as defined under | ||||
| # https://pypi.python.org/pypi?%3Aaction=list_classifiers | ||||
| classifiers = | ||||
|     Development Status :: 4 - Beta | ||||
|     Programming Language :: Python | ||||
| 
 | ||||
| [options] | ||||
| zip_safe = False | ||||
| packages = find: | ||||
| include_package_data = True | ||||
| # DON'T CHANGE THE FOLLOWING LINE! IT WILL BE UPDATED BY PYSCAFFOLD! | ||||
| setup_requires = setuptools | ||||
| # Add here dependencies of your project (semicolon/line-separated), e.g. | ||||
| # install_requires = numpy; scipy | ||||
| # Require a specific Python version, e.g. Python 2.7 or >= 3.4 | ||||
| python_requires = >=3.8 | ||||
| 
 | ||||
| [options.packages.find] | ||||
| where = . | ||||
| exclude = | ||||
|     tests | ||||
| 
 | ||||
| [options.extras_require] | ||||
| # Add here additional requirements for extra features, to install with: | ||||
| # `pip install fastaudio[PDF]` like: | ||||
| # PDF = ReportLab; RXP | ||||
| # Add here test requirements (semicolon/line-separated) | ||||
| testing = | ||||
|     pytest>=7.1.3 | ||||
|     pytest-cov>=4.0.0 | ||||
| dev = | ||||
|     pre-commit>=2.20.0 | ||||
|     black>=22.8.0 | ||||
|     flake8>=5.0.4 | ||||
| cli = | ||||
|     hydra-core >=1.1,<=1.2 | ||||
| 
 | ||||
| 
 | ||||
| [options.entry_points] | ||||
| 
 | ||||
| console_scripts = | ||||
|     mayavoz-train=mayavoz.cli.train:train | ||||
| 
 | ||||
| [test] | ||||
| # py.test options when running `python setup.py test` | ||||
| # addopts = --verbose | ||||
| extras = True | ||||
| 
 | ||||
| [tool:pytest] | ||||
| # Options for py.test: | ||||
| # Specify command line options as you would do when invoking py.test directly. | ||||
| # e.g. --cov-report html (or xml) for html/xml output or --junitxml junit.xml | ||||
| # in order to write a coverage file that can be read by Jenkins. | ||||
| addopts = | ||||
|     --cov mayavoz --cov-report term-missing | ||||
|     --verbose | ||||
| norecursedirs = | ||||
|     dist | ||||
|     build | ||||
|     .tox | ||||
| testpaths = tests | ||||
| 
 | ||||
| [aliases] | ||||
| dists = bdist_wheel | ||||
| 
 | ||||
| [bdist_wheel] | ||||
| # Use this option if your package is pure-python | ||||
| universal = 1 | ||||
| 
 | ||||
| [build_sphinx] | ||||
| source_dir = doc | ||||
| build_dir = build/sphinx | ||||
| 
 | ||||
| [devpi:upload] | ||||
| # Options for the devpi: PyPI server and packaging tool | ||||
| # VCS export must be deactivated since we are using setuptools-scm | ||||
| no-vcs = 1 | ||||
| formats = bdist_wheel | ||||
| 
 | ||||
| [flake8] | ||||
| # Some sane defaults for the code style checker flake8 | ||||
| exclude = | ||||
|     .tox | ||||
|     build | ||||
|     dist | ||||
|     .eggs | ||||
| 
 | ||||
| [options.data_files] | ||||
| . = requirements.txt | ||||
| _ = version.txt | ||||
							
								
								
									
										63
									
								
								setup.py
								
								
								
								
							
							
						
						
									
										63
									
								
								setup.py
								
								
								
								
							|  | @ -1,63 +0,0 @@ | |||
| import os | ||||
| import sys | ||||
| from pathlib import Path | ||||
| 
 | ||||
| from pkg_resources import VersionConflict, require | ||||
| from setuptools import find_packages, setup | ||||
| 
 | ||||
| with open("README.md") as f: | ||||
|     long_description = f.read() | ||||
| 
 | ||||
| with open("requirements.txt") as f: | ||||
|     requirements = f.read().splitlines() | ||||
| 
 | ||||
| try: | ||||
|     require("setuptools>=38.3") | ||||
| except VersionConflict: | ||||
|     print("Error: version of setuptools is too old (<38.3)!") | ||||
|     sys.exit(1) | ||||
| 
 | ||||
| 
 | ||||
| ROOT_DIR = Path(__file__).parent.resolve() | ||||
| # Creating the version file | ||||
| 
 | ||||
| with open("version.txt") as f: | ||||
|     version = f.read() | ||||
| 
 | ||||
| version = version.strip() | ||||
| sha = "Unknown" | ||||
| 
 | ||||
| if os.getenv("BUILD_VERSION"): | ||||
|     version = os.getenv("BUILD_VERSION") | ||||
| elif sha != "Unknown": | ||||
|     version += "+" + sha[:7] | ||||
| print("-- Building version " + version) | ||||
| 
 | ||||
| version_path = ROOT_DIR / "mayavoz" / "version.py" | ||||
| 
 | ||||
| with open(version_path, "w") as f: | ||||
|     f.write("__version__ = '{}'\n".format(version)) | ||||
| 
 | ||||
| if __name__ == "__main__": | ||||
|     setup( | ||||
|         name="mayavoz", | ||||
|         namespace_packages=["mayavoz"], | ||||
|         version=version, | ||||
|         packages=find_packages(), | ||||
|         install_requires=requirements, | ||||
|         description="Deep learning toolkit for speech enhancement", | ||||
|         long_description=long_description, | ||||
|         long_description_content_type="text/markdown", | ||||
|         author="Shahul Es", | ||||
|         author_email="shahules786@gmail.com", | ||||
|         url="", | ||||
|         classifiers=[ | ||||
|             "Development Status :: 4 - Beta", | ||||
|             "Intended Audience :: Science/Research", | ||||
|             "License :: OSI Approved :: MIT License", | ||||
|             "Natural Language :: English", | ||||
|             "Programming Language :: Python :: 3.7", | ||||
|             "Programming Language :: Python :: 3.8", | ||||
|             "Topic :: Scientific/Engineering", | ||||
|         ], | ||||
|     ) | ||||
|  | @ -0,0 +1,13 @@ | |||
| #!/bin/bash | ||||
| set -e | ||||
| 
 | ||||
| echo "Loading Anaconda Module" | ||||
| module load anaconda | ||||
| 
 | ||||
| echo "Creating Virtual Environment" | ||||
| conda env create -f environment.yml ||  conda env update -f environment.yml | ||||
| 
 | ||||
| source activate enhancer | ||||
| 
 | ||||
| echo "copying files" | ||||
| # cp /scratch/$USER/TIMIT/.* /deep-transcriber | ||||
|  | @ -1,7 +1,7 @@ | |||
| import pytest | ||||
| import torch | ||||
| 
 | ||||
| from mayavoz.loss import mean_absolute_error, mean_squared_error | ||||
| from enhancer.loss import mean_absolute_error, mean_squared_error | ||||
| 
 | ||||
| loss_functions = [mean_absolute_error(), mean_squared_error()] | ||||
| 
 | ||||
|  |  | |||
|  | @ -1,50 +0,0 @@ | |||
| import torch | ||||
| 
 | ||||
| from mayavoz.models.complexnn.conv import ComplexConv2d, ComplexConvTranspose2d | ||||
| from mayavoz.models.complexnn.rnn import ComplexLSTM | ||||
| from mayavoz.models.complexnn.utils import ComplexBatchNorm2D | ||||
| 
 | ||||
| 
 | ||||
| def test_complexconv2d(): | ||||
|     sample_input = torch.rand(1, 2, 256, 13) | ||||
|     conv = ComplexConv2d( | ||||
|         2, 32, kernel_size=(5, 2), stride=(2, 1), padding=(2, 1) | ||||
|     ) | ||||
|     with torch.no_grad(): | ||||
|         out = conv(sample_input) | ||||
|     assert out.shape == torch.Size([1, 32, 128, 13]) | ||||
| 
 | ||||
| 
 | ||||
| def test_complexconvtranspose2d(): | ||||
|     sample_input = torch.rand(1, 512, 4, 13) | ||||
|     conv = ComplexConvTranspose2d( | ||||
|         256 * 2, | ||||
|         128 * 2, | ||||
|         kernel_size=(5, 2), | ||||
|         stride=(2, 1), | ||||
|         padding=(2, 0), | ||||
|         output_padding=(1, 0), | ||||
|     ) | ||||
|     with torch.no_grad(): | ||||
|         out = conv(sample_input) | ||||
| 
 | ||||
|     assert out.shape == torch.Size([1, 256, 8, 14]) | ||||
| 
 | ||||
| 
 | ||||
| def test_complexlstm(): | ||||
|     sample_input = torch.rand(13, 2, 128) | ||||
|     lstm = ComplexLSTM(128 * 2, 128 * 2, projection_size=512 * 2) | ||||
|     with torch.no_grad(): | ||||
|         out = lstm(sample_input) | ||||
| 
 | ||||
|     assert out[0].shape == torch.Size([13, 1, 512]) | ||||
|     assert out[1].shape == torch.Size([13, 1, 512]) | ||||
| 
 | ||||
| 
 | ||||
| def test_complexbatchnorm2d(): | ||||
|     sample_input = torch.rand(1, 64, 64, 14) | ||||
|     batchnorm = ComplexBatchNorm2D(num_features=64) | ||||
|     with torch.no_grad(): | ||||
|         out = batchnorm(sample_input) | ||||
| 
 | ||||
|     assert out.size() == sample_input.size() | ||||
|  | @ -1,9 +1,9 @@ | |||
| import pytest | ||||
| import torch | ||||
| 
 | ||||
| from mayavoz.data.dataset import MayaDataset | ||||
| from mayavoz.models import Demucs | ||||
| from mayavoz.utils.config import Files | ||||
| from enhancer.data.dataset import EnhancerDataset | ||||
| from enhancer.models import Demucs | ||||
| from enhancer.utils.config import Files | ||||
| 
 | ||||
| 
 | ||||
| @pytest.fixture | ||||
|  | @ -15,9 +15,7 @@ def vctk_dataset(): | |||
|         test_clean="clean_testset_wav", | ||||
|         test_noisy="noisy_testset_wav", | ||||
|     ) | ||||
|     dataset = MayaDataset( | ||||
|         name="vctk", root_dir=root_dir, files=files, sampling_rate=16000 | ||||
|     ) | ||||
|     dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files) | ||||
|     return dataset | ||||
| 
 | ||||
| 
 | ||||
|  | @ -32,7 +30,7 @@ def test_forward(batch_size, samples): | |||
| 
 | ||||
|     data = torch.rand(batch_size, 2, samples, requires_grad=False) | ||||
|     with torch.no_grad(): | ||||
|         with pytest.raises(ValueError): | ||||
|         with pytest.raises(TypeError): | ||||
|             _ = model(data) | ||||
| 
 | ||||
| 
 | ||||
|  |  | |||
|  | @ -1,45 +0,0 @@ | |||
| import pytest | ||||
| import torch | ||||
| 
 | ||||
| from mayavoz.data.dataset import MayaDataset | ||||
| from mayavoz.models.dccrn import DCCRN | ||||
| from mayavoz.utils.config import Files | ||||
| 
 | ||||
| 
 | ||||
| @pytest.fixture | ||||
| def vctk_dataset(): | ||||
|     root_dir = "tests/data/vctk" | ||||
|     files = Files( | ||||
|         train_clean="clean_testset_wav", | ||||
|         train_noisy="noisy_testset_wav", | ||||
|         test_clean="clean_testset_wav", | ||||
|         test_noisy="noisy_testset_wav", | ||||
|     ) | ||||
|     dataset = MayaDataset( | ||||
|         name="vctk", root_dir=root_dir, files=files, sampling_rate=16000 | ||||
|     ) | ||||
|     return dataset | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.parametrize("batch_size,samples", [(1, 1000)]) | ||||
| def test_forward(batch_size, samples): | ||||
|     model = DCCRN() | ||||
|     model.eval() | ||||
| 
 | ||||
|     data = torch.rand(batch_size, 1, samples, requires_grad=False) | ||||
|     with torch.no_grad(): | ||||
|         _ = model(data) | ||||
| 
 | ||||
|     data = torch.rand(batch_size, 2, samples, requires_grad=False) | ||||
|     with torch.no_grad(): | ||||
|         with pytest.raises(ValueError): | ||||
|             _ = model(data) | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.parametrize( | ||||
|     "dataset,channels,loss", | ||||
|     [(pytest.lazy_fixture("vctk_dataset"), 1, ["mae", "mse"])], | ||||
| ) | ||||
| def test_demucs_init(dataset, channels, loss): | ||||
|     with torch.no_grad(): | ||||
|         _ = DCCRN(num_channels=channels, dataset=dataset, loss=loss) | ||||
Some files were not shown because too many files have changed in this diff Show More
		Loading…
	
		Reference in New Issue