Compare commits

..

132 Commits

Author SHA1 Message Date
shahules786 957db4fcba dec duration 2022-11-24 10:31:10 +05:30
shahules786 24ac484b23 config 2022-11-23 19:26:16 +05:30
shahules786 530dab4648 fix metrics 2022-11-21 11:45:43 +05:30
shahules786 ebfb64c766 config 2022-11-21 09:45:46 +05:30
shahules786 9da80bbfb1 config 2022-11-21 09:45:24 +05:30
shahules786 75aa54b9d7 config 2022-11-15 14:55:24 +05:30
shahules786 07eb1d53ef config 2022-11-14 11:10:11 +05:30
shahules786 d522dc4233 config 2022-11-14 10:53:11 +05:30
shahules786 d5b17f3745 negate 2022-11-14 10:51:26 +05:30
shahules786 4e58df5e37 reduce bs 2022-11-10 12:07:02 +05:30
shahules786 ea9218077e config 2022-11-10 10:41:24 +05:30
shahules786 ca6797c3f1 config 2022-11-10 10:41:07 +05:30
shahules786 effb4b03fb fix valid 2022-11-10 10:33:35 +05:30
shahules786 2dda2fa1c1 config 2022-11-09 19:08:41 +05:30
shahules786 b5582832f3 config 2022-11-09 19:07:53 +05:30
shahules786 86b71ce090 mv coeff to device 2022-11-09 18:53:31 +05:30
shahules786 cad0bbedc8 dccrn 2022-11-08 17:18:45 +05:30
shahules786 c4e392aff5 config 2022-11-07 19:38:27 +05:30
shahules786 24fa16ca25 Merge branch 'dev' of https://github.com/shahules786/enhancer into dev-hawk 2022-11-07 16:01:10 +05:30
shahules786 846b64ab88 Merge branch 'dev' of https://github.com/shahules786/enhancer into dev-hawk 2022-11-07 13:01:21 +05:30
shahules786 e304e36c8a Merge branch 'dev' of https://github.com/shahules786/enhancer into dev-hawk 2022-11-07 12:28:44 +05:30
shahules786 69f6bb4926 add direction si-snr 2022-11-07 12:26:58 +05:30
shahules786 5b635a82a9 fix param 2022-11-07 12:14:41 +05:30
shahules786 ba25365eab config 2022-11-07 12:08:45 +05:30
shahules786 2d0b309b4d Merge branch 'dev' of https://github.com/shahules786/enhancer into dev-hawk 2022-11-07 12:02:42 +05:30
shahules786 ce04720e59 Merge branch 'dev' of https://github.com/shahules786/enhancer into dev-hawk 2022-11-07 11:54:25 +05:30
shahules786 c0b18872b7 config 2022-11-07 10:55:10 +05:30
shahules786 cdda1deb87 rmv arg 2022-11-03 11:38:43 +05:30
shahules786 a3e488f101 config 2022-11-02 10:47:03 +05:30
shahules786 4badf64540 waveunet 2022-10-31 10:42:15 +05:30
shahules786 887a792d21 waveunet 2022-10-31 10:41:48 +05:30
shahules786 9e315ca6c4 waveunet 2022-10-31 10:11:22 +05:30
shahules786 00bb38c95b vctk 2022-10-31 10:08:24 +05:30
shahules786 3879dce620 config 2022-10-31 10:06:20 +05:30
shahules786 53e223954e config 2022-10-29 10:43:33 +05:30
shahules786 ce37ac06c6 config 2022-10-29 09:42:57 +05:30
shahules786 c4b8c5dfc8 rmv sampler print 2022-10-28 17:22:36 +05:30
shahules786 acb68c9855 merge dev 2022-10-28 13:10:36 +05:30
shahules786 8d8eaa80d5 vctk+demucs 2022-10-28 13:09:29 +05:30
shahules786 6028d918b6 debug ddp 2022-10-28 10:17:51 +05:30
shahules786 c321abe2ec DNS 2022-10-27 21:33:08 +05:30
shahules786 73cc925059 fix earlystop 2022-10-27 21:26:56 +05:30
shahules786 3a9e577ccb Merge branch 'dev' of https://github.com/shahules786/enhancer into dev-hawk 2022-10-27 16:18:53 +05:30
shahules786 430696cfac config 2022-10-27 15:31:47 +05:30
shahules786 77f3658e5c config 2022-10-27 15:30:06 +05:30
shahules786 dbfa580618 merge dev 2022-10-27 15:23:17 +05:30
shahules786 7425c9bf3a replace with mse 2022-10-22 20:25:43 +05:30
shahules786 96c934e96f w/o striding 2022-10-22 18:09:55 +05:30
shahules786 4cfe7a2463 Demus + DNS 2022-10-22 12:07:45 +05:30
shahules786 f492e44e6b Merge branch 'dev' of https://github.com/shahules786/enhancer into dev-hawk 2022-10-22 12:00:46 +05:30
shahules786 4b34cf6980 merge dev 2022-10-22 11:21:22 +05:30
shahules786 4b0e8a5ef1 DNS-2020 + VTCK 2022-10-22 11:20:25 +05:30
shahules786 ed30840f4e cpu 2022-10-19 16:41:23 +05:30
shahules786 23bee75ceb dataloader 2022-10-19 16:39:31 +05:30
shahules786 38e0de689e stride=1 2022-10-19 15:28:00 +05:30
shahules786 ba855e39e5 run waveunet 2022-10-19 15:27:41 +05:30
shahules786 2d8ca3f4b2 400 epochs 2022-10-19 12:40:27 +05:30
shahules786 11fbba6f77 Merge branch 'dev' of https://github.com/shahules786/enhancer into dev-hawk 2022-10-19 12:38:54 +05:30
shahules786 7b3626c912 dns 2022-10-19 10:03:00 +05:30
shahules786 a2992cf759 dns 2022-10-19 09:33:30 +05:30
shahules786 3bec8c7723 dns 2022-10-18 21:31:39 +05:30
shahules786 3df4b27132 dns 2022-10-18 21:29:55 +05:30
shahules786 8737ed8066 config 2022-10-18 21:17:24 +05:30
shahules786 f426c3d880 vctk 4.5 2022-10-18 15:31:02 +05:30
shahules786 982520f30d vctk + demucs 2022-10-18 15:30:42 +05:30
shahules786 94d70c4ddf Merge branch 'dev' of https://github.com/shahules786/enhancer into dev-hawk 2022-10-18 15:23:57 +05:30
shahules786 d7991a5c0e fix arg 2022-10-18 09:51:56 +05:30
shahules786 37a4471b07 demucs + vctk 56 2022-10-17 21:43:48 +05:30
shahules786 399e7062f2 rmv mv operation 2022-10-17 15:33:29 +05:30
shahules786 897e913cfa valid size 30mins 2022-10-17 13:12:02 +05:30
shahules786 77e5a14908 Merge branch 'dev' of https://github.com/shahules786/enhancer into dev-hawk 2022-10-17 13:10:54 +05:30
shahules786 3cb5c18c39 30hrs data 2022-10-17 12:39:29 +05:30
shahules786 4ed7fe3ce5 dns 30 hrs demucs 2022-10-17 11:03:09 +05:30
shahules786 32da7b347c demucs + vctk 3 sec 2022-10-16 17:40:29 +05:30
shahules786 3014a41501 load best model to test 2022-10-16 12:22:46 +05:30
shahules786 dab68de260 config 2022-10-16 12:04:43 +05:30
shahules786 288b5f4906 merge dev 2022-10-16 11:17:04 +05:30
shahules786 45b6fe0f3d Merge branch 'dev' of https://github.com/shahules786/enhancer into dev-hawk 2022-10-15 12:24:20 +05:30
shahules786 4744390dc6 configure dns 2022-10-15 11:51:58 +05:30
shahules786 b0c73bd109 dns 2020 2022-10-15 11:20:09 +05:30
shahules786 a66089a920 set max time 2022-10-15 11:18:30 +05:30
shahules786 8bbd1abf2f dns 2020 2022-10-15 11:18:20 +05:30
shahules786 6e0f69f575 Merge branch 'dev' of https://github.com/shahules786/enhancer into dev-hawk 2022-10-14 12:46:48 +05:30
shahules786 0e58691a2c demucs 250 2022-10-14 12:45:34 +05:30
shahules786 807f4b93ea Merge branch 'dev' of https://github.com/shahules786/enhancer into dev-hawk 2022-10-14 12:43:47 +05:30
shahules786 315d646347 Merge branch 'dev' of https://github.com/shahules786/enhancer into dev-hawk 2022-10-14 11:32:59 +05:30
shahules786 f34e49e341 WaveUnet 2022-10-14 11:15:16 +05:30
shahules786 fa47860f57 set BS to 256 2022-10-14 11:12:16 +05:30
shahules786 f7eb0a600c 500 epochs 2022-10-14 10:47:20 +05:30
shahules786 ba2d00648c demucs 100 epochs 2022-10-13 10:57:24 +05:30
shahules786 8a55a77640 run 100 epochs 2022-10-13 10:52:22 +05:30
shahules786 94a4ea38ed Merge branch 'dev' of https://github.com/shahules786/enhancer into dev-hawk 2022-10-13 10:50:59 +05:30
shahules786 8d25b0ed79 reduce epochs 2022-10-12 20:27:05 +05:30
shahules786 09ba645315 fix logging 2022-10-12 20:23:55 +05:30
shahules786 8906496366 waveunet 500 epochs 2022-10-12 10:49:00 +05:30
shahules786 e4a2eb7844 Merge branch 'dev' of https://github.com/shahules786/enhancer into dev-hawk 2022-10-12 10:32:52 +05:30
shahules786 8a6af87627 pesq 2022-10-11 21:56:55 +05:30
shahules786 5a392332ba ensure 2 gpus 2022-10-11 21:56:35 +05:30
shahules786 f66a5236e1 Revert "demucs" (This reverts commit d415bb0c59.) 2022-10-11 21:54:47 +05:30
shahules786 d415bb0c59 demucs 2022-10-11 21:41:19 +05:30
shahules786 8c1524a998 500 epochs 2022-10-11 21:38:27 +05:30
shahules786 7161f84a27 Merge branch 'dev' of https://github.com/shahules786/enhancer into dev-hawk 2022-10-11 21:36:59 +05:30
shahules786 2c79e60a85 params 2022-10-11 21:33:19 +05:30
shahules786 41ee2fce0b Merge branch 'dev' of https://github.com/shahules786/enhancer into dev-hawk 2022-10-11 21:30:40 +05:30
shahules786 0c5db496e2 run waveunet 2022-10-11 16:51:41 +05:30
shahules786 031221b79e merge dev 2022-10-11 16:50:09 +05:30
shahules786 50062eaf40 rmv inplace operation 2022-10-11 15:10:34 +05:30
shahules786 0b02b73094 run demucs 32 2022-10-11 11:12:44 +05:30
shahules786 2ccc2822cd Merge branch 'dev' of https://github.com/shahules786/enhancer into dev-hawk 2022-10-11 11:12:02 +05:30
shahules786 1667de624e min settings 2022-10-10 21:04:43 +05:30
shahules786 32579b7a39 Merge branch 'dev' of https://github.com/shahules786/enhancer into dev-hawk 2022-10-10 21:04:01 +05:30
shahules786 bb68e9e4eb demucs 2022-10-10 16:48:40 +05:30
shahules786 a21ef707ad ensure gpu 2022-10-10 15:59:48 +05:30
shahules786 81c5f13ff6 log metric 2022-10-10 15:32:37 +05:30
shahules786 a417e226f3 testrun for metrics 2022-10-10 12:49:41 +05:30
shahules786 5d8f49d78e Merge branch 'dev' of https://github.com/shahules786/enhancer into dev-hawk 2022-10-10 12:48:11 +05:30
shahules786 14156743f9 Merge branch 'dev' of https://github.com/shahules786/enhancer into dev-hawk 2022-10-08 11:04:32 +05:30
shahules786 845575a2ad config 2022-10-08 10:18:22 +05:30
shahules786 c9b78b0e73 Merge branch 'dev' of https://github.com/shahules786/enhancer into dev-hawk 2022-10-08 10:12:38 +05:30
shahules786 3068476512 reduce batch_size 2022-10-08 09:59:23 +05:30
shahules786 ffb364196e increase sr 2022-10-07 11:32:33 +05:30
shahules786 52cefcb962 run demucs 2022-10-07 10:56:14 +05:30
shahules786 61923f6d68 config 2022-10-07 10:46:06 +05:30
shahules786 e90efe3163 Merge branch 'dev' of https://github.com/shahules786/enhancer into dev-hawk 2022-10-07 10:43:34 +05:30
shahules786 aa043aaf40 rmv max_steps 2022-10-06 11:52:05 +05:30
shahules786 4f6ccadf4b Merge branch 'dev' of https://github.com/shahules786/enhancer into dev-hawk 2022-10-06 11:49:40 +05:30
shahules786 0e982cd493 Merge branch 'dev' of https://github.com/shahules786/enhancer into dev-hawk 2022-10-06 10:33:26 +05:30
shahules786 0787d946da decrease epochs 2022-10-06 10:21:07 +05:30
shahules786 e06ba07889 Merge branch 'dev' of https://github.com/shahules786/enhancer into dev-hawk 2022-10-06 10:19:38 +05:30
shahules786 741fd7b87c run cli 2022-10-06 09:55:01 +05:30
shahules786 a064151e2e Merge branch 'dev' of https://github.com/shahules786/enhancer into dev-hawk 2022-10-06 09:54:14 +05:30
shahules786 25557757c7 inc epochs 2022-10-03 21:26:59 +05:30
103 changed files with 221 additions and 2143 deletions

@@ -1,5 +1,5 @@
 [flake8]
-per-file-ignores = "mayavoz/model/__init__.py:F401"
+per-file-ignores = __init__.py:F401
 ignore = E203, E266, E501, W503
 # line length is intentionally set to 80 here because black uses Bugbear
 # See https://github.com/psf/black/blob/master/README.md#line-length for more details

.gitattributes

@@ -1 +0,0 @@
notebooks/** linguist-vendored

@@ -1,13 +1,13 @@
 # This workflow will install Python dependencies, run tests and lint with a variety of Python versions
 # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
-name: mayavoz
+name: Enhancer
 on:
   push:
-    branches: [ main ]
+    branches: [ dev ]
   pull_request:
-    branches: [ main ]
+    branches: [ dev ]
 jobs:
   build:
     runs-on: ubuntu-latest
@@ -40,12 +40,12 @@ jobs:
        sudo apt-get install libsndfile1
        pip install -r requirements.txt
        pip install black pytest-cov
-    - name: Install mayavoz
+    - name: Install enhancer
      run: |
        pip install -e .[dev,testing]
    - name: Run black
      run:
-       black --check . --exclude mayavoz/version.py
+       black --check . --exclude enhancer/version.py
    - name: Test with pytest
      run:
-       pytest tests --cov=mayavoz/
+       pytest tests --cov=enhancer/

.gitignore

@@ -1,8 +1,4 @@
 #local
-cleaned_my_voice.wav
-lightning_logs/
-my_voice.wav
-pretrained/
 *.ckpt
 *_local.yaml
 cli/train_config/dataset/Vctk_local.yaml

@@ -23,7 +23,6 @@ repos:
     hooks:
       - id: flake8
         args: ['--ignore=E203,E501,F811,E712,W503']
-        exclude: __init__.py
 # Formatting, Whitespace, etc
 - repo: https://github.com/pre-commit/pre-commit-hooks

@@ -1,46 +0,0 @@
# Contributing
Hi there 👋
If you're reading this I hope that you're looking forward to adding value to Mayavoz. This document will help you to get started with your journey.
## How to get your code in Mayavoz
1. We use git and GitHub.
2. Fork the mayavoz repository (https://github.com/shahules786/mayavoz) on GitHub under your own account. (This creates a copy of mayavoz under your account, and GitHub knows where it came from, and we typically call this “upstream”.)
3. Clone your own mayavoz repository. git clone https://github.com/<your-account>/mayavoz (This downloads the git repository to your machine, git knows where it came from, and calls it “origin”.)
4. Create a branch for each specific feature you are developing. git checkout -b your-branch-name
5. Make + commit changes. git add files-you-changed ... git commit -m "Short message about what you did"
6. Push the branch to your GitHub repository. git push origin your-branch-name
7. Navigate to GitHub, and create a pull request from your branch to the upstream repository mayavoz/mayavoz, to the “develop” branch.
8. The Pull Request (PR) appears on the upstream repository. Discuss your contribution there. If you push more changes to your branch on GitHub (on your repository), they are added to the PR.
9. When the reviewer is satisfied that the code improves repository quality, they can merge.
Note that CI tests will be run when you create a PR. If you want to be sure that your code will not fail these tests, we have set up pre-commit hooks that you can install.
**If you're worried about things not being perfect with your code, we will work together and make it perfect. So, make your move!**
## Formatting
We use [black](https://black.readthedocs.io/en/stable/) and [flake8](https://flake8.pycqa.org/en/latest/) for code formatting. Please ensure that you use the same before submitting the PR.
## Testing
We adopt unit testing using [pytest](https://docs.pytest.org/en/latest/contents.html)
Please make sure that adding your new component does not decrease test coverage.
## Other tools
The use of [pre-commit](https://pre-commit.com/) is recommended to ensure different requirements such as code formatting, etc.
## How to start contributing to Mayavoz?
1. Check out issues marked as `good first issue` and let us know you're interested in working on one by commenting under it.
2. For others, I would suggest you explore mayavoz. One way to do this is to use it to train your own model. This way you might end up finding a new unreported bug or getting an idea to improve Mayavoz.

LICENSE

@@ -1,20 +0,0 @@
MIT License
Copyright (c) 2022 Shahul Es
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

@@ -1,4 +0,0 @@
recursive-include mayavoz *.py
recursive-include mayavoz *.yaml
global-exclude *.pyc
global-exclude __pycache__

@@ -2,52 +2,24 @@
   <img src="https://user-images.githubusercontent.com/25312635/195514652-e4526cd1-1177-48e9-a80d-c8bfdb95d35f.png" />
 </p>
-![GitHub Workflow Status](https://img.shields.io/github/actions/workflow/status/shahules786/mayavoz/ci.yaml?branch=main)
-![GitHub](https://img.shields.io/github/license/shahules786/enhancer)
-![GitHub issues](https://img.shields.io/github/issues/shahules786/enhancer?logo=GitHub)
-![GitHub Repo stars](https://img.shields.io/github/stars/shahules786/enhancer?style=social)
-mayavoz is a Pytorch-based opensource toolkit for speech enhancement. It is designed to save time for audio practioners & researchers. It provides easy to use pretrained speech enhancement models and facilitates highly customisable model training.
-| **[Quick Start](#quick-start-fire)** | **[Installation](#installation)** | **[Tutorials](https://github.com/shahules786/enhancer/tree/main/notebooks)** | **[Available Recipes](#recipes)** | **[Demo](#demo)**
+mayavoz is a Pytorch-based opensource toolkit for speech enhancement. It is designed to save time for audio researchers. Is provides easy to use pretrained audio enhancement models and facilitates highly customisable model training.
+| **[Quick Start]()** | **[Installation]()** | **[Tutorials]()** | **[Available Recipes]()**
 ## Key features :key:
-* Various pretrained models nicely integrated with [huggingface hub](https://huggingface.co/docs/hub/index) :hugs: that users can select and use without any hastle.
+* Various pretrained models nicely integrated with huggingface :hugs: that users can select and use without any hastle.
-* :package: Ability to train and validate your own custom speech enhancement models with just under 10 lines of code!
+* :package: Ability to train and validation your own custom speech enhancement models with just under 10 lines of code!
 * :magic_wand: A command line tool that facilitates training of highly customisable speech enhacement models from the terminal itself!
-* :zap: Supports multi-gpu training integrated with [Pytorch Lightning](https://pytorchlightning.ai/).
+* :zap: Supports multi-gpu training integrated with Pytorch Lightning.
-* :shield: data augmentations integrated using [torch-augmentations](https://github.com/asteroid-team/torch-audiomentations)
-## Demo
-Noisy speech followed by enhanced version.
-https://user-images.githubusercontent.com/25312635/203756185-737557f4-6e21-4146-aa2c-95da69d0de4c.mp4
 ## Quick Start :fire:
 ``` python
-from mayavoz.models import Mayamodel
+from mayavoz import Mayamodel
-model = Mayamodel.from_pretrained("shahules786/mayavoz-waveunet-valentini-28spk")
+model = Mayamodel.from_pretrained("mayavoz/waveunet")
-model.enhance("noisy_audio.wav")
+model("noisy_audio.wav")
 ```
-## Recipes
-| Model | Dataset | STOI | PESQ | URL |
-| :---: | :---: | :---: | :---: | :---: |
-| WaveUnet | Valentini-28spk | 0.836 | 2.78 | shahules786/mayavoz-waveunet-valentini-28spk |
-| Demucs | Valentini-28spk | 0.961 | 2.56 | shahules786/mayavoz-demucs-valentini-28spk |
-| DCCRN | Valentini-28spk | 0.724 | 2.55 | shahules786/mayavoz-dccrn-valentini-28spk |
-| Demucs | MS-SNSD-20hrs | 0.56 | 1.26 | shahules786/mayavoz-demucs-ms-snsd-20 |
-Test scores are based on respective test set associated with train dataset.
-**See [tutorials](/notebooks/) to train your custom model**
 ## Installation
 Only Python 3.8+ is officially supported (though it might work with Python 3.7)
@@ -69,10 +41,3 @@ git clone url
 cd mayavoz
 pip install -e .
 ```
-## Support
-For commercial enquiries and scientific consulting, please [contact me](https://shahules786.github.io/).
-### Acknowledgements
-Sincere gratitude to [AMPLYFI](https://amplyfi.com/) for supporting this project.

@@ -1,2 +1 @@
 __import__("pkg_resources").declare_namespace(__name__)
-from mayavoz.models import Mayamodel

@@ -60,9 +60,9 @@ def main(config: DictConfig):
     if parameters.get("Early_stop", False):
         early_stopping = EarlyStopping(
-            monitor="val_loss",
+            monitor=f"valid_{parameters.get('EarlyStopping_metric','loss')}",
             mode=direction,
-            min_delta=0.0,
+            min_delta=parameters.get("EarlyStopping_delta", 0.00),
             patience=parameters.get("EarlyStopping_patience", 10),
             strict=True,
             verbose=False,
@@ -93,7 +93,7 @@ def main(config: DictConfig):
     trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks)
     trainer.fit(model)
-    trainer.test(model)
+    trainer.test(ckpt_path="best")
     logger.experiment.log_artifact(
         logger.run_id, f"{trainer.default_root_dir}/config_log.yaml"
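For reference, the right-hand side now derives the early-stopping monitor name from the hyperparameter config; a minimal sketch of the resulting callback wiring (only the keys shown in the hunk above are real, the values and the fixed `mode` here are illustrative) is:

``` python
from pytorch_lightning.callbacks import EarlyStopping

# Illustrative hyperparameter dict; only these keys appear in the diff above.
parameters = {
    "Early_stop": True,
    "EarlyStopping_metric": "loss",
    "EarlyStopping_delta": 0.0,
    "EarlyStopping_patience": 10,
}

if parameters.get("Early_stop", False):
    early_stopping = EarlyStopping(
        # resolves to "valid_loss" by default, or e.g. "valid_pesq" if configured
        monitor=f"valid_{parameters.get('EarlyStopping_metric', 'loss')}",
        mode="min",  # in the actual script this comes from model.valid_monitor
        min_delta=parameters.get("EarlyStopping_delta", 0.00),
        patience=parameters.get("EarlyStopping_patience", 10),
        strict=True,
        verbose=False,
    )
```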

@@ -1,6 +1,6 @@
 defaults:
-  - model : Demucs
+  - model : DCCRN
-  - dataset : Vctk
+  - dataset : DNS-2020
   - optimizer : Adam
   - hyperparameters : default
   - trainer : default

@@ -0,0 +1,13 @@
_target_: enhancer.data.dataset.EnhancerDataset
root_dir : /scratch/c.sistc3/MS-SNSD/DNS20
name : dns-2020
duration : 1.5
stride : 1
sampling_rate: 16000
batch_size: 32
min_valid_minutes: 25.0
files:
train_clean : CleanSpeech_training
test_clean : CleanSpeech_testing
train_noisy : NoisySpeech_training
test_noisy : NoisySpeech_testing

@@ -1,8 +1,8 @@
-_target_: mayavoz.data.dataset.MayaDataset
+_target_: enhancer.data.dataset.EnhancerDataset
 name : vctk
 root_dir : /scratch/c.sistc3/DS_10283_2791
-duration : 4.5
+duration : 2
-stride : 0.5
+stride : 1
 sampling_rate: 16000
 batch_size: 32
 min_valid_minutes : 25

@@ -0,0 +1,13 @@
_target_: enhancer.data.dataset.EnhancerDataset
name : vctk
root_dir : /Users/shahules/Myprojects/enhancer/datasets/vctk
duration : 1.0
sampling_rate: 16000
batch_size: 64
num_workers : 0
files:
train_clean : clean_testset_wav
test_clean : clean_testset_wav
train_noisy : noisy_testset_wav
test_noisy : noisy_testset_wav

@@ -2,6 +2,7 @@ loss : si-snr
 metric : [stoi,pesq]
 lr : 0.001
 ReduceLr_patience : 10
+Early_stop : False
 ReduceLr_factor : 0.5
-min_lr : 0.000001
+min_lr : 0.0000001
-EarlyStopping_factor : 10
+EarlyStopping_patience : 10

@@ -0,0 +1,2 @@
experiment_name : shahules/enhancer
run_name : dccrn-dns20

@@ -1,10 +1,10 @@
-_target_: mayavoz.models.dccrn.DCCRN
+_target_: enhancer.models.dccrn.DCCRN
 num_channels: 1
 sampling_rate : 16000
 complex_lstm : True
 complex_norm : True
 complex_relu : True
-masking_mode : True
+masking_mode : "E"
 encoder_decoder:
   initial_output_channels : 32

@@ -1,4 +1,4 @@
-_target_: mayavoz.models.demucs.Demucs
+_target_: enhancer.models.demucs.Demucs
 num_channels: 1
 resample: 4
 sampling_rate : 16000

@@ -1,4 +1,4 @@
-_target_: mayavoz.models.waveunet.WaveUnet
+_target_: enhancer.models.waveunet.WaveUnet
 num_channels : 1
 depth : 9
 initial_output_channels: 24

@@ -2,7 +2,7 @@ _target_: pytorch_lightning.Trainer
 accelerator: gpu
 accumulate_grad_batches: 1
 amp_backend: native
-auto_lr_find: True
+auto_lr_find: False
 auto_scale_batch_size: False
 auto_select_gpus: True
 benchmark: False
@@ -23,7 +23,7 @@ limit_test_batches: 1.0
 limit_train_batches: 1.0
 limit_val_batches: 1.0
 log_every_n_steps: 50
-max_epochs: 200
+max_epochs: 100
 max_steps: -1
 max_time: null
 min_epochs: 1
@@ -43,4 +43,3 @@ sync_batchnorm: False
 tpu_cores: null
 track_grad_norm: -1
 val_check_interval: 1.0
-weights_save_path: null

@@ -0,0 +1 @@
from enhancer.data.dataset import EnhancerDataset

@@ -1,8 +1,6 @@
 import math
 import multiprocessing
 import os
-import sys
-import warnings
 from pathlib import Path
 from typing import Optional
@@ -13,11 +11,11 @@ import torch.nn.functional as F
 from torch.utils.data import DataLoader, Dataset, RandomSampler
 from torch_audiomentations import Compose
-from mayavoz.data.fileprocessor import Fileprocessor
+from enhancer.data.fileprocessor import Fileprocessor
-from mayavoz.utils import check_files
+from enhancer.utils import check_files
-from mayavoz.utils.config import Files
+from enhancer.utils.config import Files
-from mayavoz.utils.io import Audio
+from enhancer.utils.io import Audio
-from mayavoz.utils.random import create_unique_rng
+from enhancer.utils.random import create_unique_rng
 LARGE_NUM = 2147483647
@@ -82,21 +80,6 @@ class TaskDataset(pl.LightningDataModule):
         self._validation = []
         if num_workers is None:
             num_workers = multiprocessing.cpu_count() // 2
-        if num_workers is None:
-            num_workers = multiprocessing.cpu_count() // 2
-        if (
-            num_workers > 0
-            and sys.platform == "darwin"
-            and sys.version_info[0] >= 3
-            and sys.version_info[1] >= 8
-        ):
-            warnings.warn(
-                "num_workers > 0 is not supported with macOS and Python 3.8+: "
-                "setting num_workers = 0."
-            )
-            num_workers = 0
         self.num_workers = num_workers
         if min_valid_minutes > 0.0:
             self.min_valid_minutes = min_valid_minutes
@@ -152,7 +135,7 @@ class TaskDataset(pl.LightningDataModule):
             speaker_index = rng.choice(possible_indices)
             possible_indices.remove(speaker_index)
             speaker_name = all_speakers[speaker_index]
-            print(f"Selected f{speaker_name} for valid")
+            print(f"Selected {speaker_name} for valid")
             file_indices = [
                 i
                 for i, file in enumerate(data)
@@ -265,7 +248,7 @@ class TaskDataset(pl.LightningDataModule):
         )
-class MayaDataset(TaskDataset):
+class EnhancerDataset(TaskDataset):
     """
     Dataset object for creating clean-noisy speech enhancement datasets
     paramters:
@@ -275,7 +258,7 @@ class MayaDataset(TaskDataset):
         root directory of the dataset containing clean/noisy folders
     files : Files
         dataclass containing train_clean, train_noisy, test_clean, test_noisy
-        folder names (refer mayavoz.utils.Files dataclass)
+        folder names (refer enhancer.utils.Files dataclass)
     min_valid_minutes: float
         minimum validation split size time in minutes
         algorithm randomly select n speakers (>=min_valid_minutes) from train data to form validation data.

@@ -93,9 +93,9 @@ class Fileprocessor:
     def from_name(cls, name: str, clean_dir, noisy_dir, matching_function=None):
         if matching_function is None:
-            if name.lower() in ("vctk", "valentini"):
+            if name.lower() == "vctk":
                 return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_one)
-            elif name.lower() == "ms-snsd":
+            elif name.lower() == "dns-2020":
                 return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_many)
             else:
                 raise ValueError(
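The `one_to_one` / `one_to_many` matching functions selected above are described only in prose in the tutorial notebook further down this compare; as an illustration (the function names, filename conventions, and return shape here are assumptions, not the actual `ProcessorFunctions` code), the distinction amounts to:

``` python
from pathlib import Path
from typing import List, Tuple


def match_one_to_one(clean_dir: str, noisy_dir: str) -> List[Tuple[Path, Path]]:
    # e.g. VCTK/Valentini: each clean file has exactly one noisy file
    # with the same name in the other directory (assumed convention).
    return [
        (clean, Path(noisy_dir) / clean.name)
        for clean in sorted(Path(clean_dir).glob("*.wav"))
    ]


def match_one_to_many(clean_dir: str, noisy_dir: str) -> List[Tuple[Path, Path]]:
    # e.g. MS-SNSD/DNS: one clean file maps to several noisy mixtures whose
    # names contain the clean file's stem (assumed convention).
    pairs = []
    for clean in sorted(Path(clean_dir).glob("*.wav")):
        for noisy in sorted(Path(noisy_dir).glob(f"*{clean.stem}*.wav")):
            pairs.append((clean, noisy))
    return pairs
```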

@@ -8,7 +8,7 @@ from librosa import load as load_audio
 from scipy.io import wavfile
 from scipy.signal import get_window
-from mayavoz.utils import Audio
+from enhancer.utils import Audio

 class Inference:
@@ -95,7 +95,6 @@ class Inference:
     ):
         """
         stitch batched waveform into single waveform. (Overlap-add)
-        inspired from https://github.com/asteroid-team/asteroid
         arguments:
             data: batched waveform
             window_size : window_size used to batch waveform

@@ -1,4 +1,4 @@
-import warnings
+import logging

 import numpy as np
 import torch
@@ -134,7 +134,7 @@ class Pesq:
             try:
                 pesq_values.append(self.pesq(pred.squeeze(), target_.squeeze()))
             except Exception as e:
-                warnings.warn(f"{e} error occured while calculating PESQ")
+                logging.warning(f"{e} error occured while calculating PESQ")
         return torch.tensor(np.mean(pesq_values))

@@ -0,0 +1,3 @@
from enhancer.models.demucs import Demucs
from enhancer.models.model import Model
from enhancer.models.waveunet import WaveUnet

@@ -0,0 +1,5 @@
from enhancer.models.complexnn.conv import ComplexConv2d # noqa
from enhancer.models.complexnn.conv import ComplexConvTranspose2d # noqa
from enhancer.models.complexnn.rnn import ComplexLSTM # noqa
from enhancer.models.complexnn.utils import ComplexBatchNorm2D # noqa
from enhancer.models.complexnn.utils import ComplexRelu # noqa

@@ -129,7 +129,7 @@ class ComplexConvTranspose2d(nn.Module):
         imag_real = self.real_conv(imag)
         real = real_real - imag_imag
-        imag = real_imag + imag_real
+        imag = real_imag - imag_real
         out = torch.cat([real, imag], 1)

@@ -1,22 +1,22 @@
-import warnings
+import logging
 from typing import Any, List, Optional, Tuple, Union

 import torch
 import torch.nn.functional as F
 from torch import nn

-from mayavoz.data import MayaDataset
+from enhancer.data import EnhancerDataset
-from mayavoz.models import Mayamodel
+from enhancer.models import Model
-from mayavoz.models.complexnn import (
+from enhancer.models.complexnn import (
     ComplexBatchNorm2D,
     ComplexConv2d,
     ComplexConvTranspose2d,
     ComplexLSTM,
     ComplexRelu,
 )
-from mayavoz.models.complexnn.utils import complex_cat
+from enhancer.models.complexnn.utils import complex_cat
-from mayavoz.utils.transforms import ConviSTFT, ConvSTFT
+from enhancer.utils.transforms import ConviSTFT, ConvSTFT
-from mayavoz.utils.utils import merge_dict
+from enhancer.utils.utils import merge_dict

 class DCCRN_ENCODER(nn.Module):
@@ -98,7 +98,7 @@ class DCCRN_DECODER(nn.Module):
         return self.decoder(waveform)

-class DCCRN(Mayamodel):
+class DCCRN(Model):
     STFT_DEFAULTS = {
         "window_len": 400,
@@ -134,17 +134,17 @@ class DCCRN(Mayamodel):
         num_channels: int = 1,
         sampling_rate=16000,
         lr: float = 1e-3,
-        dataset: Optional[MayaDataset] = None,
+        dataset: Optional[EnhancerDataset] = None,
         duration: Optional[float] = None,
         loss: Union[str, List, Any] = "mse",
         metric: Union[str, List] = "mse",
     ):
         duration = (
-            dataset.duration if isinstance(dataset, MayaDataset) else duration
+            dataset.duration if isinstance(dataset, EnhancerDataset) else None
         )
         if dataset is not None:
             if sampling_rate != dataset.sampling_rate:
-                warnings.warn(
+                logging.warning(
                     f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
                 )
                 sampling_rate = dataset.sampling_rate

@@ -1,14 +1,14 @@
+import logging
 import math
-import warnings
 from typing import List, Optional, Union

 import torch.nn.functional as F
 from torch import nn

-from mayavoz.data.dataset import MayaDataset
+from enhancer.data.dataset import EnhancerDataset
-from mayavoz.models.model import Mayamodel
+from enhancer.models.model import Model
-from mayavoz.utils.io import Audio as audio
+from enhancer.utils.io import Audio as audio
-from mayavoz.utils.utils import merge_dict
+from enhancer.utils.utils import merge_dict

 class DemucsLSTM(nn.Module):
@@ -88,7 +88,7 @@ class DemucsDecoder(nn.Module):
         return out

-class Demucs(Mayamodel):
+class Demucs(Model):
     """
     Demucs model from https://arxiv.org/pdf/1911.13254.pdf
     parameters:
@@ -102,8 +102,8 @@ class Demucs(Mayamodel):
         sampling rate of input audio
     lr : float, defaults to 1e-3
         learning rate used for training
-    dataset: MayaDataset, optional
+    dataset: EnhancerDataset, optional
-        MayaDataset object containing train/validation data for training
+        EnhancerDataset object containing train/validation data for training
     duration : float, optional
         chunk duration in seconds
     loss : string or List of strings
@@ -135,18 +135,17 @@ class Demucs(Mayamodel):
         sampling_rate=16000,
         normalize=True,
         lr: float = 1e-3,
-        dataset: Optional[MayaDataset] = None,
+        dataset: Optional[EnhancerDataset] = None,
-        duration: Optional[float] = None,
         loss: Union[str, List] = "mse",
         metric: Union[str, List] = "mse",
         floor=1e-3,
     ):
         duration = (
-            dataset.duration if isinstance(dataset, MayaDataset) else duration
+            dataset.duration if isinstance(dataset, EnhancerDataset) else None
         )
         if dataset is not None:
             if sampling_rate != dataset.sampling_rate:
-                warnings.warn(
+                logging.warning(
                     f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
                 )
                 sampling_rate = dataset.sampling_rate

@@ -13,21 +13,17 @@ from pytorch_lightning.utilities.cloud_io import load as pl_load
 from torch import nn
 from torch.optim import Adam

-from mayavoz.data.dataset import MayaDataset
+from enhancer.data.dataset import EnhancerDataset
-from mayavoz.inference import Inference
+from enhancer.inference import Inference
-from mayavoz.loss import LOSS_MAP, LossWrapper
+from enhancer.loss import LOSS_MAP, LossWrapper
-from mayavoz.version import __version__
+from enhancer.version import __version__

-CACHE_DIR = os.getenv(
-    "ENHANCER_CACHE",
-    os.path.expanduser("~/.cache/torch/mayavoz"),
-)
-HF_TORCH_WEIGHTS = "pytorch_model.ckpt"
+CACHE_DIR = ""
+HF_TORCH_WEIGHTS = ""
 DEFAULT_DEVICE = "cpu"
-SAVE_NAME = "mayavoz"

-class Mayamodel(pl.LightningModule):
+class Model(pl.LightningModule):
     """
     Base class for all models
     parameters:
@@ -37,8 +33,8 @@ class Mayamodel(pl.LightningModule):
         audio sampling rate
     lr: float, optional
         learning rate for model training
-    dataset: MayaDataset, optional
+    dataset: EnhancerDataset, optional
-        mayavoz dataset used for training/validation
+        Enhancer dataset used for training/validation
     duration: float, optional
         duration used for training/inference
     loss : string or List of strings or custom loss (nn.Module), default to "mse"
@@ -51,13 +47,15 @@ class Mayamodel(pl.LightningModule):
         num_channels: int = 1,
         sampling_rate: int = 16000,
         lr: float = 1e-3,
-        dataset: Optional[MayaDataset] = None,
+        dataset: Optional[EnhancerDataset] = None,
         duration: Optional[float] = None,
         loss: Union[str, List] = "mse",
         metric: Union[str, List, Any] = "mse",
     ):
         super().__init__()
-        assert num_channels == 1, "mayavoz only support for mono channel models"
+        assert (
+            num_channels == 1
+        ), "Enhancer only support for mono channel models"
         self.dataset = dataset
         self.save_hyperparameters(
             "num_channels", "sampling_rate", "lr", "loss", "metric", "duration"
@@ -234,8 +232,8 @@ class Mayamodel(pl.LightningModule):
     def on_save_checkpoint(self, checkpoint):
-        checkpoint[SAVE_NAME] = {
+        checkpoint["enhancer"] = {
-            "version": {SAVE_NAME: __version__, "pytorch": torch.__version__},
+            "version": {"enhancer": __version__, "pytorch": torch.__version__},
             "architecture": {
                 "module": self.__class__.__module__,
                 "class": self.__class__.__name__,
@@ -288,8 +286,8 @@ class Mayamodel(pl.LightningModule):
         Returns
         -------
-        model : Mayamodel
+        model : Model
-            Mayamodel
+            Model
         See also
         --------
@@ -318,7 +316,7 @@ class Mayamodel(pl.LightningModule):
         )
         model_path_pl = cached_download(
             url=url,
-            library_name="mayavoz",
+            library_name="enhancer",
             library_version=__version__,
             cache_dir=cached_dir,
             use_auth_token=use_auth_token,
@@ -328,8 +326,8 @@ class Mayamodel(pl.LightningModule):
         map_location = torch.device(DEFAULT_DEVICE)
         loaded_checkpoint = pl_load(model_path_pl, map_location)
-        module_name = loaded_checkpoint[SAVE_NAME]["architecture"]["module"]
+        module_name = loaded_checkpoint["enhancer"]["architecture"]["module"]
-        class_name = loaded_checkpoint[SAVE_NAME]["architecture"]["class"]
+        class_name = loaded_checkpoint["enhancer"]["architecture"]["class"]
         module = import_module(module_name)
         Klass = getattr(module, class_name)

@@ -1,12 +1,12 @@
-import warnings
+import logging
 from typing import List, Optional, Union

 import torch
 import torch.nn as nn
 import torch.nn.functional as F

-from mayavoz.data.dataset import MayaDataset
+from enhancer.data.dataset import EnhancerDataset
-from mayavoz.models.model import Mayamodel
+from enhancer.models.model import Model

 class WavenetDecoder(nn.Module):
@@ -66,7 +66,7 @@ class WavenetEncoder(nn.Module):
         return self.encoder(waveform)

-class WaveUnet(Mayamodel):
+class WaveUnet(Model):
     """
     Wave-U-Net model from https://arxiv.org/pdf/1811.11307.pdf
     parameters:
@@ -80,8 +80,8 @@ class WaveUnet(Mayamodel):
         sampling rate of input audio
     lr : float, defaults to 1e-3
         learning rate used for training
-    dataset: MayaDataset, optional
+    dataset: EnhancerDataset, optional
-        MayaDataset object containing train/validation data for training
+        EnhancerDataset object containing train/validation data for training
     duration : float, optional
         chunk duration in seconds
     loss : string or List of strings
@@ -97,17 +97,17 @@ class WaveUnet(Mayamodel):
         initial_output_channels: int = 24,
         sampling_rate: int = 16000,
         lr: float = 1e-3,
-        dataset: Optional[MayaDataset] = None,
+        dataset: Optional[EnhancerDataset] = None,
         duration: Optional[float] = None,
         loss: Union[str, List] = "mse",
         metric: Union[str, List] = "mse",
     ):
         duration = (
-            dataset.duration if isinstance(dataset, MayaDataset) else duration
+            dataset.duration if isinstance(dataset, EnhancerDataset) else None
         )
         if dataset is not None:
             if sampling_rate != dataset.sampling_rate:
-                warnings.warn(
+                logging.warning(
                     f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
                 )
                 sampling_rate = dataset.sampling_rate

@@ -0,0 +1,3 @@
from enhancer.utils.config import Files
from enhancer.utils.io import Audio
from enhancer.utils.utils import check_files

@@ -1,7 +1,7 @@
 import os
 from typing import Optional

-from mayavoz.utils.config import Files
+from enhancer.utils.config import Files

 def check_files(root_dir: str, files: Files):

@@ -1,4 +1,4 @@
-name: mayavoz
+name: enhancer
 dependencies:
   - pip=21.0.1

hpc_entrypoint.sh (new file)

@@ -0,0 +1,41 @@
#!/bin/bash
set -e
echo '----------------------------------------------------'
echo ' SLURM_CLUSTER_NAME = '$SLURM_CLUSTER_NAME
echo ' SLURMD_NODENAME = '$SLURMD_NODENAME
echo ' SLURM_JOBID = '$SLURM_JOBID
echo ' SLURM_JOB_USER = '$SLURM_JOB_USER
echo ' SLURM_PARTITION = '$SLURM_JOB_PARTITION
echo ' SLURM_JOB_ACCOUNT = '$SLURM_JOB_ACCOUNT
echo '----------------------------------------------------'
#TeamCity Output
cat << EOF
##teamcity[buildNumber '$SLURM_JOBID']
EOF
echo "Load HPC modules"
module load anaconda
echo "Activate Environment"
source activate enhancer
export TRANSFORMERS_OFFLINE=True
export PYTHONPATH=${PYTHONPATH}:/scratch/c.sistc3/enhancer
export HYDRA_FULL_ERROR=1
echo $PYTHONPATH
source ~/mlflow_settings.sh
echo "Making temp dir"
mkdir temp
pwd
#python transcriber/tasks/embeddings/timit.py --directory /scratch/$USER/TIMIT/data/lisa/data/timit/raw/TIMIT/TRAIN --output ./data/train
#python transcriber/tasks/embeddings/timit.py --directory /scratch/$USER/TIMIT/data/lisa/data/timit/raw/TIMIT/TEST --output ./data/test
# mv /scratch/c.sistc3/MS-SNSD/DNS20/CleanSpeech_testing /scratch/c.sistc3/MS-SNSD/DNS30/CleanSpeech_testing
# mv /scratch/c.sistc3/MS-SNSD/DNS20/NoisySpeech_testing /scratch/c.sistc3/MS-SNSD/DNS30/NoisySpeech_testing
echo "Start Training..."
python enhancer/cli/train.py

@@ -1,120 +0,0 @@
import os
from types import MethodType

import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.callbacks import (
    EarlyStopping,
    LearningRateMonitor,
    ModelCheckpoint,
)
from pytorch_lightning.loggers import MLFlowLogger
from torch.optim.lr_scheduler import ReduceLROnPlateau

# from torch_audiomentations import Compose, Shift

os.environ["HYDRA_FULL_ERROR"] = "1"
JOB_ID = os.environ.get("SLURM_JOBID", "0")


@hydra.main(config_path="train_config", config_name="config")
def train(config: DictConfig):
    OmegaConf.save(config, "config.yaml")
    callbacks = []
    logger = MLFlowLogger(
        experiment_name=config.mlflow.experiment_name,
        run_name=config.mlflow.run_name,
        tags={"JOB_ID": JOB_ID},
    )
    parameters = config.hyperparameters
    # apply_augmentations = Compose(
    #     [
    #         Shift(min_shift=0.5, max_shift=1.0, shift_unit="seconds", p=0.5),
    #     ]
    # )
    dataset = instantiate(config.dataset, augmentations=None)
    model = instantiate(
        config.model,
        dataset=dataset,
        lr=parameters.get("lr"),
        loss=parameters.get("loss"),
        metric=parameters.get("metric"),
    )
    direction = model.valid_monitor
    checkpoint = ModelCheckpoint(
        dirpath="./model",
        filename=f"model_{JOB_ID}",
        monitor="valid_loss",
        verbose=False,
        mode=direction,
        every_n_epochs=1,
    )
    callbacks.append(checkpoint)
    callbacks.append(LearningRateMonitor(logging_interval="epoch"))
    if parameters.get("Early_stop", False):
        early_stopping = EarlyStopping(
            monitor="val_loss",
            mode=direction,
            min_delta=0.0,
            patience=parameters.get("EarlyStopping_patience", 10),
            strict=True,
            verbose=False,
        )
        callbacks.append(early_stopping)

    def configure_optimizers(self):
        optimizer = instantiate(
            config.optimizer,
            lr=parameters.get("lr"),
            params=self.parameters(),
        )
        scheduler = ReduceLROnPlateau(
            optimizer=optimizer,
            mode=direction,
            factor=parameters.get("ReduceLr_factor", 0.1),
            verbose=True,
            min_lr=parameters.get("min_lr", 1e-6),
            patience=parameters.get("ReduceLr_patience", 3),
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": scheduler,
            "monitor": f'valid_{parameters.get("ReduceLr_monitor", "loss")}',
        }

    model.configure_optimizers = MethodType(configure_optimizers, model)
    trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks)
    trainer.fit(model)
    trainer.test(model)
    logger.experiment.log_artifact(
        logger.run_id, f"{trainer.default_root_dir}/config.yaml"
    )
    saved_location = os.path.join(
        trainer.default_root_dir, "model", f"model_{JOB_ID}.ckpt"
    )
    if os.path.isfile(saved_location):
        logger.experiment.log_artifact(logger.run_id, saved_location)
    logger.experiment.log_param(
        logger.run_id,
        "num_train_steps_per_epoch",
        dataset.train__len__() / dataset.batch_size,
    )
    logger.experiment.log_param(
        logger.run_id,
        "num_valid_steps_per_epoch",
        dataset.val__len__() / dataset.batch_size,
    )


if __name__ == "__main__":
    train()

@@ -1,12 +0,0 @@
_target_: mayavoz.data.dataset.MayaDataset
name : MS-SDSD
root_dir : /Users/shahules/Myprojects/MS-SNSD
duration : 2.0
sampling_rate: 16000
batch_size: 32
min_valid_minutes: 15
files:
train_clean : CleanSpeech_training
test_clean : CleanSpeech_training
train_noisy : NoisySpeech_training
test_noisy : NoisySpeech_training

@@ -1,13 +0,0 @@
_target_: mayavoz.data.dataset.MayaDataset
name : Valentini
root_dir : /scratch/c.sistc3/DS_10283_2791
duration : 4.5
stride : 2
sampling_rate: 16000
batch_size: 32
valid_minutes : 15
files:
train_clean : clean_trainset_28spk_wav
test_clean : clean_testset_wav
train_noisy : noisy_trainset_28spk_wav
test_noisy : noisy_testset_wav

@@ -1,7 +0,0 @@
loss : mae
metric : [stoi,pesq,si-sdr]
lr : 0.0003
ReduceLr_patience : 5
ReduceLr_factor : 0.2
min_lr : 0.000001
EarlyStopping_factor : 10

@@ -1,2 +0,0 @@
experiment_name : shahules/mayavoz
run_name : Demucs + Vtck with stride + augmentations

@@ -1,16 +0,0 @@
_target_: mayavoz.models.demucs.Demucs
num_channels: 1
resample: 4
sampling_rate : 16000
encoder_decoder:
depth: 4
initial_output_channels: 64
kernel_size: 8
stride: 4
growth_factor: 2
glu: True
lstm:
bidirectional: False
num_layers: 2

@@ -1,46 +0,0 @@
_target_: pytorch_lightning.Trainer
accelerator: gpu
accumulate_grad_batches: 1
amp_backend: native
auto_lr_find: True
auto_scale_batch_size: False
auto_select_gpus: True
benchmark: False
check_val_every_n_epoch: 1
detect_anomaly: False
deterministic: False
devices: 2
enable_checkpointing: True
enable_model_summary: True
enable_progress_bar: True
fast_dev_run: False
gpus: null
gradient_clip_val: 0
gradient_clip_algorithm: norm
ipus: null
limit_predict_batches: 1.0
limit_test_batches: 1.0
limit_train_batches: 1.0
limit_val_batches: 1.0
log_every_n_steps: 50
max_epochs: 200
max_steps: -1
max_time: null
min_epochs: 1
min_steps: null
move_metrics_to_cpu: False
multiple_trainloader_mode: max_size_cycle
num_nodes: 1
num_processes: 1
num_sanity_val_steps: 2
overfit_batches: 0.0
precision: 32
profiler: null
reload_dataloaders_every_n_epochs: 0
replace_sampler_ddp: True
strategy: ddp
sync_batchnorm: False
tpu_cores: null
track_grad_norm: -1
val_check_interval: 1.0
weights_save_path: null

@@ -1 +0,0 @@
from mayavoz.data.dataset import MayaDataset

@@ -1,3 +0,0 @@
from mayavoz.models.demucs import Demucs
from mayavoz.models.model import Mayamodel
from mayavoz.models.waveunet import WaveUnet

@@ -1,5 +0,0 @@
from mayavoz.models.complexnn.conv import ComplexConv2d # noqa
from mayavoz.models.complexnn.conv import ComplexConvTranspose2d # noqa
from mayavoz.models.complexnn.rnn import ComplexLSTM # noqa
from mayavoz.models.complexnn.utils import ComplexBatchNorm2D # noqa
from mayavoz.models.complexnn.utils import ComplexRelu # noqa

@@ -1,3 +0,0 @@
from mayavoz.utils.config import Files
from mayavoz.utils.io import Audio
from mayavoz.utils.utils import check_files

@@ -1,338 +0,0 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "ccd61d5c",
"metadata": {},
"source": [
"## Custom model training using mayavoz [advanced]\n",
"\n",
"In this tutorial, we will cover advanced usage and customization for training your own speech enhancement model. \n",
"\n",
" - [Data preparation using MayaDataset](#dataprep)\n",
" - [Model customization](#modelcustom)\n",
" - [callbacks & LR schedulers](#callbacks)\n",
" - [Model training & testing](#train)\n"
]
},
{
"cell_type": "markdown",
"id": "726c320f",
"metadata": {},
"source": [
"- **install mayavoz**"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c987c799",
"metadata": {},
"outputs": [],
"source": [
"! pip install -q mayavoz"
]
},
{
"cell_type": "markdown",
"id": "8ff9857b",
"metadata": {},
"source": [
"<div id=\"dataprep\"></div>\n",
"\n",
"### Data preparation\n",
"\n",
"`Files` is a dataclass that wraps and holds train/test paths together. There is usually one folder each for clean and noisy data. These paths must be relative to a `root_dir` where all these directories reside. For example\n",
"\n",
"```\n",
"- VCTK/\n",
" |__ clean_train_wav/\n",
" |__ noisy_train_wav/\n",
" |__ clean_test_wav/\n",
" |__ noisy_test_wav/\n",
" \n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "64cbc0c8",
"metadata": {},
"outputs": [],
"source": [
"from mayavoz.utils import Files\n",
"files = Files(train_clean=\"clean_train_wav\",\n",
" train_noisy=\"noisy_train_wav\",\n",
" test_clean=\"clean_test_wav\",\n",
" test_noisy=\"noisy_test_wav\")\n",
"root_dir = \"VCTK\""
]
},
{
"cell_type": "markdown",
"id": "2d324bd1",
"metadata": {},
"source": [
"- `name`: name of the dataset. \n",
"- `duration`: controls the duration of each audio instance fed into your model.\n",
"- `stride`: if set, moves the sliding window by this amount.\n",
"- `sampling_rate`: desired sampling rate for audio\n",
"- `batch_size`: model batch size\n",
"- `min_valid_minutes`: minimum validation duration in minutes. Validation data is automatically selected from the training set, using speakers held out exclusively for validation.\n",
"- `matching_function`: there are two types of mapping functions.\n",
" - `one_to_one` : each clean file has exactly one corresponding noisy file, for example the Valentini datasets.\n",
" - `one_to_many` : one clean file can have multiple corresponding noisy files, for example the MS-SNSD dataset.\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "6834941d",
"metadata": {},
"outputs": [],
"source": [
"name = \"vctk\"\n",
"duration = 4.5\n",
"stride = 2.0\n",
"sampling_rate = 16000\n",
"min_valid_minutes = 20.0\n",
"batch_size = 32\n",
"matching_function = \"one_to_one\"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d08c6bf8",
"metadata": {},
"outputs": [],
"source": [
"from mayavoz.data.dataset import MayaDataset\n",
"dataset = MayaDataset(\n",
" name=name,\n",
" root_dir=root_dir,\n",
" files=files,\n",
" duration=duration,\n",
" stride=stride,\n",
" sampling_rate=sampling_rate,\n",
" batch_size=batch_size,\n",
" min_valid_minutes=min_valid_minutes,\n",
" matching_function=matching_function\n",
" )"
]
},
{
"cell_type": "markdown",
"id": "5b315bde",
"metadata": {},
"source": [
"Now your custom dataloader is ready!"
]
},
{
"cell_type": "markdown",
"id": "01548fe5",
"metadata": {},
"source": [
"<div id=\"modelcustom\"></div>\n",
"\n",
"### Model Customization\n",
"Now, this is very easy. \n",
"\n",
"- Import the preferred model from `mayavoz.models`. Currently 3 models are implemented.\n",
" - `WaveUnet`\n",
" - `Demucs`\n",
" - `DCCRN`\n",
"- Each of the model hyperparameters such as depth, kernel_size, stride, etc. can be controlled by you. Just check the parameters and pass them as required.\n",
"- `sampling_rate`: sampling rate (should be equal to dataset sampling rate)\n",
"- `dataset`: mayavoz dataset object as prepared earlier.\n",
"- `loss` : model loss. Multiple loss functions are available.\n",
"\n",
" \n",
" \n",
"you can pass one (as a string) or more (as a list of strings) of these loss functions as per your requirements. For example, the model will automatically calculate the loss as the average of `mae` and `mse` if you pass loss as `[\"mae\",\"mse\"]`. Available loss functions are `mse`,`mae`,`si-snr`.\n",
"\n",
"mayavoz can accept **custom loss functions**. It should be of the form.\n",
"```\n",
"class your_custom_loss(nn.Module):\n",
" def __init__(self,**kwargs):\n",
" self.higher_better = False ## loss minimization direction\n",
" self.name = \"your_loss_name\" ## loss name logging \n",
" ...\n",
" def forward(self,prediction, target):\n",
" loss = ....\n",
" return loss\n",
" \n",
"```\n",
"\n",
"- metrics : validation metrics. Available options `mae`,`mse`,`si-sdr`,`pesq`,`stoi`. One or more can be used.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b36b457c",
"metadata": {},
"outputs": [],
"source": [
"from mayavoz.models import Demucs\n",
"model = Demucs(\n",
" sampling_rate=16000,\n",
" dataset=dataset,\n",
" loss=[\"mae\"],\n",
" metric=[\"stoi\",\"pesq\"])\n"
]
},
{
"cell_type": "markdown",
"id": "1523d638",
"metadata": {},
"source": [
"<div id=\"callbacks\"></div>\n",
"\n",
"### learning rate schedulers and callbacks\n",
"Here I am using `ReduceLROnPlateau`"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8de6931c",
"metadata": {},
"outputs": [],
"source": [
"from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
"\n",
"def configure_optimizers(self):\n",
" optimizer = instantiate(\n",
" config.optimizer,\n",
" lr=parameters.get(\"lr\"),\n",
" params=self.parameters(),\n",
" )\n",
" scheduler = ReduceLROnPlateau(\n",
" optimizer=optimizer,\n",
" mode=direction,\n",
" factor=parameters.get(\"ReduceLr_factor\", 0.1),\n",
" verbose=True,\n",
" min_lr=parameters.get(\"min_lr\", 1e-6),\n",
" patience=parameters.get(\"ReduceLr_patience\", 3),\n",
" )\n",
" return {\n",
" \"optimizer\": optimizer,\n",
" \"lr_scheduler\": scheduler,\n",
" \"monitor\": f'valid_{parameters.get(\"ReduceLr_monitor\", \"loss\")}',\n",
" }\n",
"\n",
"\n",
"model.configure_optimizers = MethodType(configure_optimizers, model)"
]
},
{
"cell_type": "markdown",
"id": "2f7b5af5",
"metadata": {},
"source": [
"you can use any number of callbacks and pass them directly to the pytorch lightning trainer. Here I am using only `ModelCheckpoint`"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6f6b62a1",
"metadata": {},
"outputs": [],
"source": [
"callbacks = []\n",
"direction = model.valid_monitor ## min or max \n",
"checkpoint = ModelCheckpoint(\n",
" dirpath=\"./model\",\n",
" filename=f\"model_filename\",\n",
" monitor=\"valid_loss\",\n",
" verbose=False,\n",
" mode=direction,\n",
" every_n_epochs=1,\n",
" )\n",
"callbacks.append(checkpoint)"
]
},
{
"cell_type": "markdown",
"id": "f3534445",
"metadata": {},
"source": [
"<div id=\"train\"></div>\n",
"\n",
"\n",
"### Train"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3dc0348b",
"metadata": {},
"outputs": [],
"source": [
"import pytorch_lightning as pl\n",
"trainer = plt.Trainer(max_epochs=1,callbacks=callbacks,accelerator=\"gpu\")\n",
"trainer.fit(model)\n"
]
},
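{
"cell_type": "markdown",
"id": "4a5b6c7d",
"metadata": {},
"source": [
"If you do not have a GPU available, pass `accelerator=\"cpu\"` instead. `max_epochs=1` is used here only for a quick demonstration run; increase it for real training.\n"
]
},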
{
"cell_type": "markdown",
"id": "56dcfec1",
"metadata": {},
"source": [
"- Test your model agaist test dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "63851feb",
"metadata": {},
"outputs": [],
"source": [
"trainer.test(model)"
]
},
{
"cell_type": "markdown",
"id": "4d3f5350",
"metadata": {},
"source": [
"**Hurray! you have your speech enhancement model trained and tested.**\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "enhancer",
"language": "python",
"name": "enhancer"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

File diff suppressed because one or more lines are too long

View File

@ -1,120 +0,0 @@
import os
from types import MethodType
import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.callbacks import (
EarlyStopping,
LearningRateMonitor,
ModelCheckpoint,
)
from pytorch_lightning.loggers import MLFlowLogger
from torch.optim.lr_scheduler import ReduceLROnPlateau
# from torch_audiomentations import Compose, Shift
os.environ["HYDRA_FULL_ERROR"] = "1"
JOB_ID = os.environ.get("SLURM_JOBID", "0")
@hydra.main(config_path="train_config", config_name="config")
def train(config: DictConfig):
OmegaConf.save(config, "config.yaml")
callbacks = []
logger = MLFlowLogger(
experiment_name=config.mlflow.experiment_name,
run_name=config.mlflow.run_name,
tags={"JOB_ID": JOB_ID},
)
parameters = config.hyperparameters
# apply_augmentations = Compose(
# [
# Shift(min_shift=0.5, max_shift=1.0, shift_unit="seconds", p=0.5),
# ]
# )
dataset = instantiate(config.dataset, augmentations=None)
model = instantiate(
config.model,
dataset=dataset,
lr=parameters.get("lr"),
loss=parameters.get("loss"),
metric=parameters.get("metric"),
)
direction = model.valid_monitor
checkpoint = ModelCheckpoint(
dirpath="./model",
filename=f"model_{JOB_ID}",
monitor="valid_loss",
verbose=False,
mode=direction,
every_n_epochs=1,
)
callbacks.append(checkpoint)
callbacks.append(LearningRateMonitor(logging_interval="epoch"))
if parameters.get("Early_stop", False):
early_stopping = EarlyStopping(
monitor="val_loss",
mode=direction,
min_delta=0.0,
patience=parameters.get("EarlyStopping_patience", 10),
strict=True,
verbose=False,
)
callbacks.append(early_stopping)
def configure_optimizers(self):
optimizer = instantiate(
config.optimizer,
lr=parameters.get("lr"),
params=self.parameters(),
)
scheduler = ReduceLROnPlateau(
optimizer=optimizer,
mode=direction,
factor=parameters.get("ReduceLr_factor", 0.1),
verbose=True,
min_lr=parameters.get("min_lr", 1e-6),
patience=parameters.get("ReduceLr_patience", 3),
)
return {
"optimizer": optimizer,
"lr_scheduler": scheduler,
"monitor": f'valid_{parameters.get("ReduceLr_monitor", "loss")}',
}
model.configure_optimizers = MethodType(configure_optimizers, model)
trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks)
trainer.fit(model)
trainer.test(model)
logger.experiment.log_artifact(
logger.run_id, f"{trainer.default_root_dir}/config.yaml"
)
saved_location = os.path.join(
trainer.default_root_dir, "model", f"model_{JOB_ID}.ckpt"
)
if os.path.isfile(saved_location):
logger.experiment.log_artifact(logger.run_id, saved_location)
logger.experiment.log_param(
logger.run_id,
"num_train_steps_per_epoch",
dataset.train__len__() / dataset.batch_size,
)
logger.experiment.log_param(
logger.run_id,
"num_valid_steps_per_epoch",
dataset.val__len__() / dataset.batch_size,
)
if __name__ == "__main__":
train()

View File

@ -1,7 +0,0 @@
defaults:
- model : Demucs
- dataset : MS-SNSD
- optimizer : Adam
- hyperparameters : default
- trainer : default
- mlflow : experiment

View File

@ -1,13 +0,0 @@
_target_: mayavoz.data.dataset.MayaDataset
name : MS-SDSD
root_dir : /Users/shahules/Myprojects/MS-SNSD
duration : 1.5
stride : 1
sampling_rate: 16000
batch_size: 32
min_valid_minutes: 25
files:
train_clean : CleanSpeech_training
test_clean : CleanSpeech_training
train_noisy : NoisySpeech_training
test_noisy : NoisySpeech_training

View File

@ -1,2 +0,0 @@
experiment_name : shahules/mayavoz
run_name : Demucs + Vtck with stride + augmentations

View File

@ -1,25 +0,0 @@
_target_: mayavoz.models.dccrn.DCCRN
num_channels: 1
sampling_rate : 16000
complex_lstm : True
complex_norm : True
complex_relu : True
masking_mode : True
encoder_decoder:
initial_output_channels : 32
depth : 6
kernel_size : 5
growth_factor : 2
stride : 2
padding : 2
output_padding : 1
lstm:
num_layers : 2
hidden_size : 256
stft:
window_len : 400
hop_size : 100
nfft : 512

View File

@ -1,6 +0,0 @@
_target_: torch.optim.Adam
lr: 1e-3
betas: [0.9, 0.999]
eps: 1e-08
weight_decay: 0
amsgrad: False

View File

@ -1,46 +0,0 @@
_target_: pytorch_lightning.Trainer
accelerator: gpu
accumulate_grad_batches: 1
amp_backend: native
auto_lr_find: True
auto_scale_batch_size: False
auto_select_gpus: True
benchmark: False
check_val_every_n_epoch: 1
detect_anomaly: False
deterministic: False
devices: 2
enable_checkpointing: True
enable_model_summary: True
enable_progress_bar: True
fast_dev_run: False
gpus: null
gradient_clip_val: 0
gradient_clip_algorithm: norm
ipus: null
limit_predict_batches: 1.0
limit_test_batches: 1.0
limit_train_batches: 1.0
limit_val_batches: 1.0
log_every_n_steps: 50
max_epochs: 200
max_steps: -1
max_time: null
min_epochs: 1
min_steps: null
move_metrics_to_cpu: False
multiple_trainloader_mode: max_size_cycle
num_nodes: 1
num_processes: 1
num_sanity_val_steps: 2
overfit_batches: 0.0
precision: 32
profiler: null
reload_dataloaders_every_n_epochs: 0
replace_sampler_ddp: True
strategy: ddp
sync_batchnorm: False
tpu_cores: null
track_grad_norm: -1
val_check_interval: 1.0
weights_save_path: null

View File

@ -1,2 +0,0 @@
_target_: pytorch_lightning.Trainer
fast_dev_run: True

View File

@ -1,120 +0,0 @@
import os
from types import MethodType
import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.callbacks import (
EarlyStopping,
LearningRateMonitor,
ModelCheckpoint,
)
from pytorch_lightning.loggers import MLFlowLogger
from torch.optim.lr_scheduler import ReduceLROnPlateau
# from torch_audiomentations import Compose, Shift
os.environ["HYDRA_FULL_ERROR"] = "1"
JOB_ID = os.environ.get("SLURM_JOBID", "0")
@hydra.main(config_path="train_config", config_name="config")
def train(config: DictConfig):
OmegaConf.save(config, "config.yaml")
callbacks = []
logger = MLFlowLogger(
experiment_name=config.mlflow.experiment_name,
run_name=config.mlflow.run_name,
tags={"JOB_ID": JOB_ID},
)
parameters = config.hyperparameters
# apply_augmentations = Compose(
# [
# Shift(min_shift=0.5, max_shift=1.0, shift_unit="seconds", p=0.5),
# ]
# )
dataset = instantiate(config.dataset, augmentations=None)
model = instantiate(
config.model,
dataset=dataset,
lr=parameters.get("lr"),
loss=parameters.get("loss"),
metric=parameters.get("metric"),
)
direction = model.valid_monitor
checkpoint = ModelCheckpoint(
dirpath="./model",
filename=f"model_{JOB_ID}",
monitor="valid_loss",
verbose=False,
mode=direction,
every_n_epochs=1,
)
callbacks.append(checkpoint)
callbacks.append(LearningRateMonitor(logging_interval="epoch"))
if parameters.get("Early_stop", False):
early_stopping = EarlyStopping(
monitor="val_loss",
mode=direction,
min_delta=0.0,
patience=parameters.get("EarlyStopping_patience", 10),
strict=True,
verbose=False,
)
callbacks.append(early_stopping)
def configure_optimizers(self):
optimizer = instantiate(
config.optimizer,
lr=parameters.get("lr"),
params=self.parameters(),
)
scheduler = ReduceLROnPlateau(
optimizer=optimizer,
mode=direction,
factor=parameters.get("ReduceLr_factor", 0.1),
verbose=True,
min_lr=parameters.get("min_lr", 1e-6),
patience=parameters.get("ReduceLr_patience", 3),
)
return {
"optimizer": optimizer,
"lr_scheduler": scheduler,
"monitor": f'valid_{parameters.get("ReduceLr_monitor", "loss")}',
}
model.configure_optimizers = MethodType(configure_optimizers, model)
trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks)
trainer.fit(model)
trainer.test(model)
logger.experiment.log_artifact(
logger.run_id, f"{trainer.default_root_dir}/config.yaml"
)
saved_location = os.path.join(
trainer.default_root_dir, "model", f"model_{JOB_ID}.ckpt"
)
if os.path.isfile(saved_location):
logger.experiment.log_artifact(logger.run_id, saved_location)
logger.experiment.log_param(
logger.run_id,
"num_train_steps_per_epoch",
dataset.train__len__() / dataset.batch_size,
)
logger.experiment.log_param(
logger.run_id,
"num_valid_steps_per_epoch",
dataset.val__len__() / dataset.batch_size,
)
if __name__ == "__main__":
train()

View File

@ -1,7 +0,0 @@
defaults:
- model : Demucs
- dataset : MS-SNSD
- optimizer : Adam
- hyperparameters : default
- trainer : default
- mlflow : experiment

View File

@ -1,13 +0,0 @@
_target_: mayavoz.data.dataset.MayaDataset
name : MS-SDSD
root_dir : /Users/shahules/Myprojects/MS-SNSD
duration : 5
stride : 1
sampling_rate: 16000
batch_size: 32
min_valid_minutes: 25
files:
train_clean : CleanSpeech_training
test_clean : CleanSpeech_training
train_noisy : NoisySpeech_training
test_noisy : NoisySpeech_training

View File

@ -1,7 +0,0 @@
loss : mae
metric : [stoi,pesq]
lr : 0.0003
ReduceLr_patience : 10
ReduceLr_factor : 0.5
min_lr : 0.000001
EarlyStopping_factor : 10

View File

@ -1,2 +0,0 @@
experiment_name : shahules/mayavoz
run_name : demucs-ms-snsd

View File

@ -1,16 +0,0 @@
_target_: mayavoz.models.demucs.Demucs
num_channels: 1
resample: 4
sampling_rate : 16000
encoder_decoder:
depth: 4
initial_output_channels: 64
kernel_size: 8
stride: 4
growth_factor: 2
glu: True
lstm:
bidirectional: False
num_layers: 2

View File

@ -1,6 +0,0 @@
_target_: torch.optim.Adam
lr: 1e-3
betas: [0.9, 0.999]
eps: 1e-08
weight_decay: 0
amsgrad: False

View File

@ -1,46 +0,0 @@
_target_: pytorch_lightning.Trainer
accelerator: gpu
accumulate_grad_batches: 1
amp_backend: native
auto_lr_find: True
auto_scale_batch_size: False
auto_select_gpus: True
benchmark: False
check_val_every_n_epoch: 1
detect_anomaly: False
deterministic: False
devices: 2
enable_checkpointing: True
enable_model_summary: True
enable_progress_bar: True
fast_dev_run: False
gpus: null
gradient_clip_val: 0
gradient_clip_algorithm: norm
ipus: null
limit_predict_batches: 1.0
limit_test_batches: 1.0
limit_train_batches: 1.0
limit_val_batches: 1.0
log_every_n_steps: 50
max_epochs: 200
max_steps: -1
max_time: null
min_epochs: 1
min_steps: null
move_metrics_to_cpu: False
multiple_trainloader_mode: max_size_cycle
num_nodes: 1
num_processes: 1
num_sanity_val_steps: 2
overfit_batches: 0.0
precision: 32
profiler: null
reload_dataloaders_every_n_epochs: 0
replace_sampler_ddp: True
strategy: ddp
sync_batchnorm: False
tpu_cores: null
track_grad_norm: -1
val_check_interval: 1.0
weights_save_path: null

View File

@ -1,2 +0,0 @@
_target_: pytorch_lightning.Trainer
fast_dev_run: True

View File

@ -1,17 +0,0 @@
### Microsoft Scalable Noisy Speech Dataset (MS-SNSD)
MS-SNSD is a speech dataset that can scale to arbitrary sizes depending on the number of speakers, noise types, and Speech to Noise Ratio (SNR) levels desired.
### Dataset download & setup
- Follow steps in the official repo [here](https://github.com/microsoft/MS-SNSD) to download and setup the dataset.
**References**
```BibTex
@article{reddy2019scalable,
title={A Scalable Noisy Speech Dataset and Online Subjective Test Framework},
author={Reddy, Chandan KA and Beyrami, Ebrahim and Pool, Jamie and Cutler, Ross and Srinivasan, Sriram and Gehrke, Johannes},
journal={Proc. Interspeech 2019},
pages={1816--1820},
year={2019}
}
```

View File

@ -1,120 +0,0 @@
import os
from types import MethodType
import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.callbacks import (
EarlyStopping,
LearningRateMonitor,
ModelCheckpoint,
)
from pytorch_lightning.loggers import MLFlowLogger
from torch.optim.lr_scheduler import ReduceLROnPlateau
# from torch_audiomentations import Compose, Shift
os.environ["HYDRA_FULL_ERROR"] = "1"
JOB_ID = os.environ.get("SLURM_JOBID", "0")
@hydra.main(config_path="train_config", config_name="config")
def main(config: DictConfig):
OmegaConf.save(config, "config_log.yaml")
callbacks = []
logger = MLFlowLogger(
experiment_name=config.mlflow.experiment_name,
run_name=config.mlflow.run_name,
tags={"JOB_ID": JOB_ID},
)
parameters = config.hyperparameters
# apply_augmentations = Compose(
# [
# Shift(min_shift=0.5, max_shift=1.0, shift_unit="seconds", p=0.5),
# ]
# )
dataset = instantiate(config.dataset, augmentations=None)
model = instantiate(
config.model,
dataset=dataset,
lr=parameters.get("lr"),
loss=parameters.get("loss"),
metric=parameters.get("metric"),
)
direction = model.valid_monitor
checkpoint = ModelCheckpoint(
dirpath="./model",
filename=f"model_{JOB_ID}",
monitor="valid_loss",
verbose=False,
mode=direction,
every_n_epochs=1,
)
callbacks.append(checkpoint)
callbacks.append(LearningRateMonitor(logging_interval="epoch"))
if parameters.get("Early_stop", False):
early_stopping = EarlyStopping(
monitor="val_loss",
mode=direction,
min_delta=0.0,
patience=parameters.get("EarlyStopping_patience", 10),
strict=True,
verbose=False,
)
callbacks.append(early_stopping)
def configure_optimizers(self):
optimizer = instantiate(
config.optimizer,
lr=parameters.get("lr"),
params=self.parameters(),
)
scheduler = ReduceLROnPlateau(
optimizer=optimizer,
mode=direction,
factor=parameters.get("ReduceLr_factor", 0.1),
verbose=True,
min_lr=parameters.get("min_lr", 1e-6),
patience=parameters.get("ReduceLr_patience", 3),
)
return {
"optimizer": optimizer,
"lr_scheduler": scheduler,
"monitor": f'valid_{parameters.get("ReduceLr_monitor", "loss")}',
}
model.configure_optimizers = MethodType(configure_optimizers, model)
trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks)
trainer.fit(model)
trainer.test(model)
logger.experiment.log_artifact(
logger.run_id, f"{trainer.default_root_dir}/config_log.yaml"
)
saved_location = os.path.join(
trainer.default_root_dir, "model", f"model_{JOB_ID}.ckpt"
)
if os.path.isfile(saved_location):
logger.experiment.log_artifact(logger.run_id, saved_location)
logger.experiment.log_param(
logger.run_id,
"num_train_steps_per_epoch",
dataset.train__len__() / dataset.batch_size,
)
logger.experiment.log_param(
logger.run_id,
"num_valid_steps_per_epoch",
dataset.val__len__() / dataset.batch_size,
)
if __name__ == "__main__":
main()

View File

@ -1,7 +0,0 @@
defaults:
- model : Demucs
- dataset : Vctk
- optimizer : Adam
- hyperparameters : default
- trainer : default
- mlflow : experiment

View File

@ -1,8 +0,0 @@
loss : mae
metric : [stoi,pesq,si-sdr]
lr : 0.0003
Early_stop : False
ReduceLr_patience : 10
ReduceLr_factor : 0.1
min_lr : 0.000001
EarlyStopping_factor : 10

View File

@ -1,2 +0,0 @@
experiment_name : shahules/mayavoz
run_name : baseline

View File

@ -1,6 +0,0 @@
_target_: torch.optim.Adam
lr: 1e-3
betas: [0.9, 0.999]
eps: 1e-08
weight_decay: 0
amsgrad: False

View File

@ -1,7 +0,0 @@
defaults:
- model : WaveUnet
- dataset : Vctk
- optimizer : Adam
- hyperparameters : default
- trainer : default
- mlflow : experiment

View File

@ -1,13 +0,0 @@
_target_: mayavoz.data.dataset.MayaDataset
name : vctk
root_dir : /scratch/c.sistc3/DS_10283_2791
duration : 2
stride : 1
sampling_rate: 16000
batch_size: 128
valid_minutes : 25
files:
train_clean : clean_trainset_28spk_wav
test_clean : clean_testset_wav
train_noisy : noisy_trainset_28spk_wav
test_noisy : noisy_testset_wav

View File

@ -1,8 +0,0 @@
loss : mae
metric : [stoi,pesq,si-sdr]
lr : 0.003
ReduceLr_patience : 10
ReduceLr_factor : 0.1
min_lr : 0.000001
EarlyStopping_factor : 10
Early_stop : False

View File

@ -1,2 +0,0 @@
experiment_name : shahules/mayavoz
run_name : baseline

View File

@ -1,5 +0,0 @@
_target_: mayavoz.models.waveunet.WaveUnet
num_channels : 1
depth : 9
initial_output_channels: 24
sampling_rate : 16000

View File

@ -1,6 +0,0 @@
_target_: torch.optim.Adam
lr: 1e-3
betas: [0.9, 0.999]
eps: 1e-08
weight_decay: 0
amsgrad: False

View File

@ -1,46 +0,0 @@
_target_: pytorch_lightning.Trainer
accelerator: gpu
accumulate_grad_batches: 1
amp_backend: native
auto_lr_find: True
auto_scale_batch_size: False
auto_select_gpus: True
benchmark: False
check_val_every_n_epoch: 1
detect_anomaly: False
deterministic: False
devices: 2
enable_checkpointing: True
enable_model_summary: True
enable_progress_bar: True
fast_dev_run: False
gpus: null
gradient_clip_val: 0
gradient_clip_algorithm: norm
ipus: null
limit_predict_batches: 1.0
limit_test_batches: 1.0
limit_train_batches: 1.0
limit_val_batches: 1.0
log_every_n_steps: 50
max_epochs: 200
max_steps: -1
max_time: null
min_epochs: 1
min_steps: null
move_metrics_to_cpu: False
multiple_trainloader_mode: max_size_cycle
num_nodes: 1
num_processes: 1
num_sanity_val_steps: 2
overfit_batches: 0.0
precision: 32
profiler: null
reload_dataloaders_every_n_epochs: 0
replace_sampler_ddp: True
strategy: ddp
sync_batchnorm: False
tpu_cores: null
track_grad_norm: -1
val_check_interval: 1.0
weights_save_path: null

View File

@ -1,2 +0,0 @@
_target_: pytorch_lightning.Trainer
fast_dev_run: True

View File

@ -1,12 +0,0 @@
## Valentini dataset
Clean and noisy parallel speech database. The database was designed to train and test speech enhancement methods that operate at 48kHz. A more detailed description can be found in the papers associated with the database.[official page](https://datashare.ed.ac.uk/handle/10283/2791)
**References**
```BibTex
@misc{
title={Noisy speech database for training speech enhancement algorithms and TTS models},
author={Valentini-Botinhao, Cassia}, year={2017},
doi=https://doi.org/10.7488/ds/2117,
}
```

View File

@ -3,7 +3,7 @@ huggingface-hub>=0.10.0
 hydra-core>=1.2.0
 joblib>=1.2.0
 librosa>=0.9.2
-mlflow>=1.28.0
+mlflow>=1.29.0
 numpy>=1.23.3
 pesq==0.0.4
 protobuf>=3.19.6

View File

@ -3,7 +3,7 @@
 # http://setuptools.readthedocs.io/en/latest/setuptools.html#configuring-setup-using-setup-cfg-files
 [metadata]
-name = mayavoz
+name = enhancer
 description = Deep learning for speech enhacement
 author = Shahul Ess
 author-email = shahules786@gmail.com
@ -53,7 +53,7 @@ cli =
 [options.entry_points]
 console_scripts =
-    mayavoz-train=mayavoz.cli.train:train
+    enhancer-train=enhancer.cli.train:train
 [test]
 # py.test options when running `python setup.py test`
@ -66,7 +66,7 @@ extras = True
 # e.g. --cov-report html (or xml) for html/xml output or --junitxml junit.xml
 # in order to write a coverage file that can be read by Jenkins.
 addopts =
-    --cov mayavoz --cov-report term-missing
+    --cov enhancer --cov-report term-missing
 --verbose
 norecursedirs =
 dist
@ -98,7 +98,3 @@ exclude =
 build
 dist
 .eggs
-[options.data_files]
-. = requirements.txt
-_ = version.txt

View File

@ -33,15 +33,15 @@ elif sha != "Unknown":
     version += "+" + sha[:7]
 print("-- Building version " + version)
-version_path = ROOT_DIR / "mayavoz" / "version.py"
+version_path = ROOT_DIR / "enhancer" / "version.py"
 with open(version_path, "w") as f:
     f.write("__version__ = '{}'\n".format(version))
 if __name__ == "__main__":
     setup(
-        name="mayavoz",
-        namespace_packages=["mayavoz"],
+        name="enhancer",
+        namespace_packages=["enhancer"],
         version=version,
         packages=find_packages(),
         install_requires=requirements,

setup.sh (new file, 13 lines)
View File

@ -0,0 +1,13 @@
#!/bin/bash
set -e
echo "Loading Anaconda Module"
module load anaconda
echo "Creating Virtual Environment"
conda env create -f environment.yml || conda env update -f environment.yml
source activate enhancer
echo "copying files"
# cp /scratch/$USER/TIMIT/.* /deep-transcriber

View File

@ -1,7 +1,7 @@
 import pytest
 import torch
-from mayavoz.loss import mean_absolute_error, mean_squared_error
+from enhancer.loss import mean_absolute_error, mean_squared_error
 loss_functions = [mean_absolute_error(), mean_squared_error()]

View File

@ -1,8 +1,8 @@
 import torch
-from mayavoz.models.complexnn.conv import ComplexConv2d, ComplexConvTranspose2d
-from mayavoz.models.complexnn.rnn import ComplexLSTM
-from mayavoz.models.complexnn.utils import ComplexBatchNorm2D
+from enhancer.models.complexnn.conv import ComplexConv2d, ComplexConvTranspose2d
+from enhancer.models.complexnn.rnn import ComplexLSTM
+from enhancer.models.complexnn.utils import ComplexBatchNorm2D
 def test_complexconv2d():

View File

@ -1,9 +1,9 @@
 import pytest
 import torch
-from mayavoz.data.dataset import MayaDataset
-from mayavoz.models import Demucs
-from mayavoz.utils.config import Files
+from enhancer.data.dataset import EnhancerDataset
+from enhancer.models import Demucs
+from enhancer.utils.config import Files
 @pytest.fixture
@ -15,9 +15,7 @@ def vctk_dataset():
         test_clean="clean_testset_wav",
         test_noisy="noisy_testset_wav",
     )
-    dataset = MayaDataset(
-        name="vctk", root_dir=root_dir, files=files, sampling_rate=16000
-    )
+    dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files)
     return dataset

View File

@ -1,9 +1,9 @@
 import pytest
 import torch
-from mayavoz.data.dataset import MayaDataset
-from mayavoz.models.dccrn import DCCRN
-from mayavoz.utils.config import Files
+from enhancer.data.dataset import EnhancerDataset
+from enhancer.models.dccrn import DCCRN
+from enhancer.utils.config import Files
 @pytest.fixture
@ -15,9 +15,7 @@ def vctk_dataset():
         test_clean="clean_testset_wav",
         test_noisy="noisy_testset_wav",
     )
-    dataset = MayaDataset(
-        name="vctk", root_dir=root_dir, files=files, sampling_rate=16000
-    )
+    dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files)
     return dataset

View File

@ -1,9 +1,9 @@
 import pytest
 import torch
-from mayavoz.data.dataset import MayaDataset
-from mayavoz.models import WaveUnet
-from mayavoz.utils.config import Files
+from enhancer.data.dataset import EnhancerDataset
+from enhancer.models import WaveUnet
+from enhancer.utils.config import Files
 @pytest.fixture
@ -15,9 +15,7 @@ def vctk_dataset():
         test_clean="clean_testset_wav",
         test_noisy="noisy_testset_wav",
     )
-    dataset = MayaDataset(
-        name="vctk", root_dir=root_dir, files=files, sampling_rate=16000
-    )
+    dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files)
     return dataset

Some files were not shown because too many files have changed in this diff.