Compare commits
72 Commits
| Author | SHA1 | Date |
|---|---|---|
| | 31cd404e03 | |
| | 8cf8dd9717 | |
| | c2b2b83fd5 | |
| | bcb94d5f34 | |
| | e0abb458d5 | |
| | e322c6280d | |
| | 7548647dfb | |
| | 1035fbb236 | |
| | a06c0a3865 | |
| | 64f52fe010 | |
| | 078d3eb244 | |
| | 629adf0232 | |
| | 501948e866 | |
| | aa61056376 | |
| | 187e48d125 | |
| | 7882b8cca3 | |
| | 789da44114 | |
| | 78138c5f93 | |
| | b8a05c775c | |
| | 2b68598d7b | |
| | 29a432540e | |
| | 8d1c057b86 | |
| | 6e0f69f575 | |
| | 0e58691a2c | |
| | 807f4b93ea | |
| | 315d646347 | |
| | f34e49e341 | |
| | fa47860f57 | |
| | f7eb0a600c | |
| | ba2d00648c | |
| | 8a55a77640 | |
| | 94a4ea38ed | |
| | 8d25b0ed79 | |
| | 09ba645315 | |
| | 8906496366 | |
| | e4a2eb7844 | |
| | 8a6af87627 | |
| | 5a392332ba | |
| | f66a5236e1 | |
| | d415bb0c59 | |
| | 8c1524a998 | |
| | 7161f84a27 | |
| | 2c79e60a85 | |
| | 41ee2fce0b | |
| | 0c5db496e2 | |
| | 031221b79e | |
| | 50062eaf40 | |
| | 0b02b73094 | |
| | 2ccc2822cd | |
| | 1667de624e | |
| | 32579b7a39 | |
| | bb68e9e4eb | |
| | a21ef707ad | |
| | 81c5f13ff6 | |
| | a417e226f3 | |
| | 5d8f49d78e | |
| | 14156743f9 | |
| | 845575a2ad | |
| | c9b78b0e73 | |
| | 3068476512 | |
| | ffb364196e | |
| | 52cefcb962 | |
| | 61923f6d68 | |
| | e90efe3163 | |
| | aa043aaf40 | |
| | 4f6ccadf4b | |
| | 0e982cd493 | |
| | 0787d946da | |
| | e06ba07889 | |
| | 741fd7b87c | |
| | a064151e2e | |
| | 25557757c7 | |

.flake8 (2 changes)

@@ -1,5 +1,5 @@
 [flake8]
-per-file-ignores = "mayavoz/model/__init__.py:F401"
+per-file-ignores = __init__.py:F401
 ignore = E203, E266, E501, W503
 # line length is intentionally set to 80 here because black uses Bugbear
 # See https://github.com/psf/black/blob/master/README.md#line-length for more details

@@ -1 +0,0 @@
-notebooks/** linguist-vendored

@@ -1,13 +1,13 @@
 # This workflow will install Python dependencies, run tests and lint with a variety of Python versions
 # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions

-name: mayavoz
+name: Enhancer

 on:
   push:
-    branches: [ main ]
+    branches: [ dev ]
   pull_request:
-    branches: [ main ]
+    branches: [ dev ]
 jobs:
   build:
     runs-on: ubuntu-latest
@@ -40,12 +40,12 @@ jobs:
         sudo apt-get install libsndfile1
         pip install -r requirements.txt
         pip install black pytest-cov
-    - name: Install mayavoz
+    - name: Install enhancer
      run: |
        pip install -e .[dev,testing]
    - name: Run black
      run:
-       black --check . --exclude mayavoz/version.py
+       black --check . --exclude enhancer/version.py
    - name: Test with pytest
      run:
-       pytest tests --cov=mayavoz/
+       pytest tests --cov=enhancer/

@@ -1,10 +1,5 @@
 #local
-cleaned_my_voice.wav
-lightning_logs/
-my_voice.wav
-pretrained/
 *.ckpt
-*_local.yaml
 cli/train_config/dataset/Vctk_local.yaml
 .DS_Store
 outputs/

@@ -23,7 +23,6 @@ repos:
   hooks:
   - id: flake8
     args: ['--ignore=E203,E501,F811,E712,W503']
-    exclude: __init__.py

 # Formatting, Whitespace, etc
 - repo: https://github.com/pre-commit/pre-commit-hooks
@@ -41,4 +40,5 @@ repos:
   - id: end-of-file-fixer
   - id: requirements-txt-fixer
   - id: mixed-line-ending
+    exclude: noisyspeech_synthesizer.cfg
    args: ['--fix=no']

@@ -1,46 +0,0 @@
-# Contributing
-
-Hi there 👋
-
-If you're reading this I hope that you're looking forward to adding value to Mayavoz. This document will help you get started with your journey.
-
-## How to get your code in Mayavoz
-
-1. We use git and GitHub.
-
-2. Fork the mayavoz repository (https://github.com/shahules786/mayavoz) on GitHub under your own account. (This creates a copy of mayavoz under your account; GitHub knows where it came from, and we typically call this “upstream”.)
-
-3. Clone your own mayavoz repository: git clone https://github.com/<your-account>/mayavoz (This downloads the git repository to your machine; git knows where it came from and calls it “origin”.)
-
-4. Create a branch for each specific feature you are developing: git checkout -b your-branch-name
-
-5. Make and commit changes: git add files-you-changed ... git commit -m "Short message about what you did"
-
-6. Push the branch to your GitHub repository: git push origin your-branch-name
-
-7. Navigate to GitHub and create a pull request from your branch to the upstream repository mayavoz/mayavoz, to the “develop” branch.
-
-8. The pull request (PR) appears on the upstream repository. Discuss your contribution there. If you push more changes to your branch on GitHub (on your repository), they are added to the PR.
-
-9. When the reviewer is satisfied that the code improves repository quality, they can merge.
-
-Note that CI tests will be run when you create a PR. If you want to be sure that your code will not fail these tests, we have set up pre-commit hooks that you can install.
-
-**If you're worried about things not being perfect with your code, we will work together and make it perfect. So, make your move!**
-
-## Formatting
-
-We use [black](https://black.readthedocs.io/en/stable/) and [flake8](https://flake8.pycqa.org/en/latest/) for code formatting. Please ensure that you use the same before submitting the PR.
-
-
-## Testing
-We adopt unit testing using [pytest](https://docs.pytest.org/en/latest/contents.html).
-Please make sure that adding your new component does not decrease test coverage.
-
-## Other tools
-The use of [pre-commit](https://pre-commit.com/) is recommended to ensure different requirements such as code formatting, etc.
-
-## How to start contributing to Mayavoz?
-
-1. Check out issues marked as `good first issue`, and let us know you're interested in working on one by commenting under it.
-2. For others, I would suggest you explore mayavoz. One way to do this is to use it to train your own model. This way you might end up finding a new unreported bug or getting an idea to improve Mayavoz.

LICENSE (20 changes)

@@ -1,20 +0,0 @@
-MIT License
-
-Copyright (c) 2022 Shahul Es
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.

@@ -1,4 +0,0 @@
-recursive-include mayavoz *.py
-recursive-include mayavoz *.yaml
-global-exclude *.pyc
-global-exclude __pycache__

README.md (51 changes)

@@ -2,52 +2,24 @@
   <img src="https://user-images.githubusercontent.com/25312635/195514652-e4526cd1-1177-48e9-a80d-c8bfdb95d35f.png" />
 </p>

-
-
-
-

-mayavoz is a Pytorch-based opensource toolkit for speech enhancement. It is designed to save time for audio practitioners & researchers. It provides easy to use pretrained speech enhancement models and facilitates highly customisable model training.
+mayavoz is a Pytorch-based opensource toolkit for speech enhancement. It is designed to save time for audio researchers. It provides easy to use pretrained audio enhancement models and facilitates highly customisable model training.

-| **[Quick Start](#quick-start-fire)** | **[Installation](#installation)** | **[Tutorials](https://github.com/shahules786/enhancer/tree/main/notebooks)** | **[Available Recipes](#recipes)** | **[Demo](#demo)**
+| **[Quick Start]()** | **[Installation]()** | **[Tutorials]()** | **[Available Recipes]()**

 ## Key features :key:

-* Various pretrained models nicely integrated with [huggingface hub](https://huggingface.co/docs/hub/index) :hugs: that users can select and use without any hassle.
+* Various pretrained models nicely integrated with huggingface :hugs: that users can select and use without any hassle.
-* :package: Ability to train and validate your own custom speech enhancement models with just under 10 lines of code!
+* :package: Ability to train and validation your own custom speech enhancement models with just under 10 lines of code!
 * :magic_wand: A command line tool that facilitates training of highly customisable speech enhancement models from the terminal itself!
-* :zap: Supports multi-gpu training integrated with [Pytorch Lightning](https://pytorchlightning.ai/).
+* :zap: Supports multi-gpu training integrated with Pytorch Lightning.
-* :shield: data augmentations integrated using [torch-augmentations](https://github.com/asteroid-team/torch-audiomentations)


-## Demo

-Noisy speech followed by enhanced version.

-https://user-images.githubusercontent.com/25312635/203756185-737557f4-6e21-4146-aa2c-95da69d0de4c.mp4


 ## Quick Start :fire:
 ``` python
-from mayavoz.models import Mayamodel
+from mayavoz import Mayamodel

-model = Mayamodel.from_pretrained("shahules786/mayavoz-waveunet-valentini-28spk")
+model = Mayamodel.from_pretrained("mayavoz/waveunet")
-model.enhance("noisy_audio.wav")
+model("noisy_audio.wav")
 ```

-## Recipes

-| Model | Dataset | STOI | PESQ | URL |
-| :---: | :---: | :---: | :---: | :---: |
-| WaveUnet | Valentini-28spk | 0.836 | 2.78 | shahules786/mayavoz-waveunet-valentini-28spk |
-| Demucs | Valentini-28spk | 0.961 | 2.56 | shahules786/mayavoz-demucs-valentini-28spk |
-| DCCRN | Valentini-28spk | 0.724 | 2.55 | shahules786/mayavoz-dccrn-valentini-28spk |
-| Demucs | MS-SNSD-20hrs | 0.56 | 1.26 | shahules786/mayavoz-demucs-ms-snsd-20 |

-Test scores are based on the respective test set associated with the training dataset.

-**See [tutorials](/notebooks/) to train your custom model**

 ## Installation
 Only Python 3.8+ is officially supported (though it might work with Python 3.7)
@@ -69,10 +41,3 @@ git clone url
 cd mayavoz
 pip install -e .
 ```

-## Support

-For commercial enquiries and scientific consulting, please [contact me](https://shahules786.github.io/).

-### Acknowledgements
-Sincere gratitude to [AMPLYFI](https://amplyfi.com/) for supporting this project.

@@ -0,0 +1,76 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Wed Jun 26 15:54:05 2019
+
+@author: chkarada
+"""
+import os
+
+import numpy as np
+import soundfile as sf
+
+
+# Function to read audio
+def audioread(path, norm=True, start=0, stop=None):
+    path = os.path.abspath(path)
+    if not os.path.exists(path):
+        raise ValueError("[{}] does not exist!".format(path))
+    try:
+        x, sr = sf.read(path, start=start, stop=stop)
+    except RuntimeError:  # fix for sph pcm-embedded shortened v2
+        print("WARNING: Audio type not supported")
+
+    if len(x.shape) == 1:  # mono
+        if norm:
+            rms = (x**2).mean() ** 0.5
+            scalar = 10 ** (-25 / 20) / (rms)
+            x = x * scalar
+        return x, sr
+    else:  # multi-channel
+        x = x.T
+        x = x.sum(axis=0) / x.shape[0]
+        if norm:
+            rms = (x**2).mean() ** 0.5
+            scalar = 10 ** (-25 / 20) / (rms)
+            x = x * scalar
+        return x, sr
+
+
+# Function to write audio
+def audiowrite(data, fs, destpath, norm=False):
+    if norm:
+        eps = 0.0
+        rms = (data**2).mean() ** 0.5
+        scalar = 10 ** (-25 / 10) / (rms + eps)
+        data = data * scalar
+        if max(abs(data)) >= 1:
+            data = data / max(abs(data), eps)
+
+    destpath = os.path.abspath(destpath)
+    destdir = os.path.dirname(destpath)
+
+    if not os.path.exists(destdir):
+        os.makedirs(destdir)
+
+    sf.write(destpath, data, fs)
+    return
+
+
+# Function to mix clean speech and noise at various SNR levels
+def snr_mixer(clean, noise, snr):
+    # Normalizing to -25 dB FS
+    rmsclean = (clean**2).mean() ** 0.5
+    scalarclean = 10 ** (-25 / 20) / rmsclean
+    clean = clean * scalarclean
+    rmsclean = (clean**2).mean() ** 0.5
+
+    rmsnoise = (noise**2).mean() ** 0.5
+    scalarnoise = 10 ** (-25 / 20) / rmsnoise
+    noise = noise * scalarnoise
+    rmsnoise = (noise**2).mean() ** 0.5
+
+    # Set the noise level for a given SNR
+    noisescalar = np.sqrt(rmsclean / (10 ** (snr / 20)) / rmsnoise)
+    noisenewlevel = noise * noisescalar
+    noisyspeech = clean + noisenewlevel
+    return clean, noisenewlevel, noisyspeech
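
For orientation, a minimal sketch of how the three helpers added above might be chained to synthesize a noisy mixture at a fixed SNR. It assumes the functions are in scope (for example, imported from the file above); the file paths are hypothetical:

```python
# a sketch using audioread / snr_mixer / audiowrite from the file above
clean, sr = audioread("clean/p232_001.wav")  # normalized to -25 dB FS
noise, _ = audioread("noise/babble.wav")

# trim the noise to the clean speech length before mixing
noise = noise[: len(clean)]

# mix at 5 dB SNR; snr_mixer re-normalizes both signals internally
clean_lvl, noise_lvl, noisy = snr_mixer(clean, noise, snr=5)

audiowrite(noisy, sr, "noisy/p232_001_snr5.wav")
```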

@@ -1,2 +1 @@
 __import__("pkg_resources").declare_namespace(__name__)
-from mayavoz.models import Mayamodel

@@ -4,16 +4,10 @@ from types import MethodType
 import hydra
 from hydra.utils import instantiate
 from omegaconf import DictConfig, OmegaConf
-from pytorch_lightning.callbacks import (
-    EarlyStopping,
-    LearningRateMonitor,
-    ModelCheckpoint,
-)
+from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
 from pytorch_lightning.loggers import MLFlowLogger
 from torch.optim.lr_scheduler import ReduceLROnPlateau

-# from torch_audiomentations import Compose, Shift
-
 os.environ["HYDRA_FULL_ERROR"] = "1"
 JOB_ID = os.environ.get("SLURM_JOBID", "0")
@@ -31,13 +25,8 @@ def main(config: DictConfig):
     )

     parameters = config.hyperparameters
-    # apply_augmentations = Compose(
-    #     [
-    #         Shift(min_shift=0.5, max_shift=1.0, shift_unit="seconds", p=0.5),
-    #     ]
-    # )

-    dataset = instantiate(config.dataset, augmentations=None)
+    dataset = instantiate(config.dataset)
     model = instantiate(
         config.model,
         dataset=dataset,
@@ -56,8 +45,6 @@ def main(config: DictConfig):
         every_n_epochs=1,
     )
     callbacks.append(checkpoint)
-    callbacks.append(LearningRateMonitor(logging_interval="epoch"))

     if parameters.get("Early_stop", False):
         early_stopping = EarlyStopping(
             monitor="val_loss",
@@ -69,11 +56,11 @@ def main(config: DictConfig):
         )
         callbacks.append(early_stopping)

-    def configure_optimizers(self):
+    def configure_optimizer(self):
         optimizer = instantiate(
             config.optimizer,
             lr=parameters.get("lr"),
-            params=self.parameters(),
+            parameters=self.parameters(),
         )
         scheduler = ReduceLROnPlateau(
             optimizer=optimizer,
@@ -83,13 +70,9 @@ def main(config: DictConfig):
             min_lr=parameters.get("min_lr", 1e-6),
             patience=parameters.get("ReduceLr_patience", 3),
         )
-        return {
-            "optimizer": optimizer,
-            "lr_scheduler": scheduler,
-            "monitor": f'valid_{parameters.get("ReduceLr_monitor", "loss")}',
-        }
+        return {"optimizer": optimizer, "lr_scheduler": scheduler}

-    model.configure_optimizers = MethodType(configure_optimizers, model)
+    model.configure_parameters = MethodType(configure_optimizer, model)

     trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks)
     trainer.fit(model)
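
Both sides of the hunks above attach the optimizer factory to the model at runtime rather than defining it on the class. A minimal, self-contained sketch of that `types.MethodType` binding pattern:

```python
from types import MethodType


class Model:
    pass


def configure_optimizers(self):
    # `self` is whatever instance the function gets bound to below
    return {"optimizer": "adam", "lr_scheduler": "plateau"}


model = Model()
# bind the free function as an instance method, as the training script does
model.configure_optimizers = MethodType(configure_optimizers, model)
print(model.configure_optimizers())
```

Note that the right-hand column binds the function under the name `configure_parameters`; PyTorch Lightning only calls the `configure_optimizers` hook, which is the name the left-hand column uses.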

@@ -3,5 +3,5 @@ defaults:
   - dataset : Vctk
   - optimizer : Adam
   - hyperparameters : default
-  - trainer : default
+  - trainer : fastrun_dev
   - mlflow : experiment

@@ -0,0 +1,11 @@
+_target_: enhancer.data.dataset.EnhancerDataset
+root_dir : /Users/shahules/Myprojects/enhancer/datasets/vctk_test
+name : dns-2020
+duration : 1.0
+sampling_rate: 8000
+batch_size: 32
+files:
+  train_clean : clean_test_wav
+  test_clean : clean_test_wav
+  train_noisy : clean_test_wav
+  test_noisy : clean_test_wav

@@ -1,11 +1,11 @@
-_target_: mayavoz.data.dataset.MayaDataset
+_target_: enhancer.data.dataset.EnhancerDataset
 name : vctk
 root_dir : /scratch/c.sistc3/DS_10283_2791
-duration : 2
+duration : 1.5
-stride : 1
 sampling_rate: 16000
-batch_size: 128
+batch_size: 256
-valid_minutes : 25
+valid_size : 0.05

 files:
   train_clean : clean_trainset_28spk_wav
   test_clean : clean_testset_wav
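
For context, `hydra.utils.instantiate` builds the object named by `_target_` in configs like the one above and passes the remaining keys as constructor arguments. A minimal runnable sketch; the `Dummy` class is hypothetical, standing in for `EnhancerDataset`:

```python
from hydra.utils import instantiate
from omegaconf import OmegaConf


class Dummy:
    def __init__(self, name: str, batch_size: int):
        self.name = name
        self.batch_size = batch_size


# _target_ names the class to build; every other key becomes a keyword argument
cfg = OmegaConf.create(
    {"_target_": "__main__.Dummy", "name": "vctk", "batch_size": 256}
)
obj = instantiate(cfg)
print(obj.name, obj.batch_size)  # vctk 256
```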

@@ -0,0 +1,13 @@
+_target_: enhancer.data.dataset.EnhancerDataset
+name : vctk
+root_dir : /Users/shahules/Myprojects/enhancer/datasets/vctk
+duration : 1.0
+sampling_rate: 16000
+batch_size: 64
+num_workers : 0
+
+files:
+  train_clean : clean_testset_wav
+  test_clean : clean_testset_wav
+  train_noisy : noisy_testset_wav
+  test_noisy : noisy_testset_wav

@@ -1,7 +1,7 @@
-loss : si-snr
+loss : mse
-metric : [stoi,pesq]
+metric : [stoi,pesq,si-sdr]
 lr : 0.001
 ReduceLr_patience : 10
 ReduceLr_factor : 0.5
-min_lr : 0.000001
+min_lr : 0.00
 EarlyStopping_factor : 10

@@ -0,0 +1,2 @@
+experiment_name : shahules/enhancer
+run_name : baseline

@@ -1,11 +1,11 @@
-_target_: mayavoz.models.demucs.Demucs
+_target_: enhancer.models.demucs.Demucs
 num_channels: 1
-resample: 4
+resample: 2
 sampling_rate : 16000

 encoder_decoder:
-  depth: 4
+  depth: 5
-  initial_output_channels: 64
+  initial_output_channels: 32
   kernel_size: 8
   stride: 4
   growth_factor: 2

@@ -1,5 +1,5 @@
-_target_: mayavoz.models.waveunet.WaveUnet
+_target_: enhancer.models.waveunet.WaveUnet
 num_channels : 1
-depth : 9
+depth : 12
 initial_output_channels: 24
 sampling_rate : 16000

@@ -2,14 +2,14 @@ _target_: pytorch_lightning.Trainer
 accelerator: gpu
 accumulate_grad_batches: 1
 amp_backend: native
-auto_lr_find: True
+auto_lr_find: False
 auto_scale_batch_size: False
 auto_select_gpus: True
 benchmark: False
 check_val_every_n_epoch: 1
 detect_anomaly: False
 deterministic: False
-devices: 1
+devices: 2
 enable_checkpointing: True
 enable_model_summary: True
 enable_progress_bar: True
@@ -22,9 +22,8 @@ limit_predict_batches: 1.0
 limit_test_batches: 1.0
 limit_train_batches: 1.0
 limit_val_batches: 1.0
-log_every_n_steps: 50
+log_every_n_steps: 100
-max_epochs: 200
+max_epochs: 250
-max_steps: -1
 max_time: null
 min_epochs: 1
 min_steps: null

@@ -0,0 +1 @@
+from enhancer.data.dataset import EnhancerDataset

@@ -0,0 +1,263 @@
+import math
+import multiprocessing
+import os
+from typing import Optional
+
+import pytorch_lightning as pl
+import torch.nn.functional as F
+from sklearn.model_selection import train_test_split
+from torch.utils.data import DataLoader, Dataset, IterableDataset
+
+from enhancer.data.fileprocessor import Fileprocessor
+from enhancer.utils import check_files
+from enhancer.utils.config import Files
+from enhancer.utils.io import Audio
+from enhancer.utils.random import create_unique_rng
+
+
+class TrainDataset(IterableDataset):
+    def __init__(self, dataset):
+        self.dataset = dataset
+
+    def __iter__(self):
+        return self.dataset.train__iter__()
+
+    def __len__(self):
+        return self.dataset.train__len__()
+
+
+class ValidDataset(Dataset):
+    def __init__(self, dataset):
+        self.dataset = dataset
+
+    def __getitem__(self, idx):
+        return self.dataset.val__getitem__(idx)
+
+    def __len__(self):
+        return self.dataset.val__len__()
+
+
+class TestDataset(Dataset):
+    def __init__(self, dataset):
+        self.dataset = dataset
+
+    def __getitem__(self, idx):
+        return self.dataset.test__getitem__(idx)
+
+    def __len__(self):
+        return self.dataset.test__len__()
+
+
+class TaskDataset(pl.LightningDataModule):
+    def __init__(
+        self,
+        name: str,
+        root_dir: str,
+        files: Files,
+        valid_size: float = 0.20,
+        duration: float = 1.0,
+        sampling_rate: int = 48000,
+        matching_function=None,
+        batch_size=32,
+        num_workers: Optional[int] = None,
+    ):
+        super().__init__()
+
+        self.name = name
+        self.files, self.root_dir = check_files(root_dir, files)
+        self.duration = duration
+        self.sampling_rate = sampling_rate
+        self.batch_size = batch_size
+        self.matching_function = matching_function
+        self._validation = []
+        if num_workers is None:
+            num_workers = multiprocessing.cpu_count() // 2
+        self.num_workers = num_workers
+        if valid_size > 0.0:
+            self.valid_size = valid_size
+        else:
+            raise ValueError("valid_size must be greater than 0")
+
+    def setup(self, stage: Optional[str] = None):
+        """
+        prepare train/validation/test data splits
+        """
+
+        if stage in ("fit", None):
+
+            train_clean = os.path.join(self.root_dir, self.files.train_clean)
+            train_noisy = os.path.join(self.root_dir, self.files.train_noisy)
+            fp = Fileprocessor.from_name(
+                self.name, train_clean, train_noisy, self.matching_function
+            )
+            train_data = fp.prepare_matching_dict()
+            self.train_data, self.val_data = train_test_split(
+                train_data, test_size=0.20, shuffle=True, random_state=42
+            )
+
+            self._validation = self.prepare_mapstype(self.val_data)
+
+            test_clean = os.path.join(self.root_dir, self.files.test_clean)
+            test_noisy = os.path.join(self.root_dir, self.files.test_noisy)
+            fp = Fileprocessor.from_name(
+                self.name, test_clean, test_noisy, self.matching_function
+            )
+            test_data = fp.prepare_matching_dict()
+            self._test = self.prepare_mapstype(test_data)
+
+    def prepare_mapstype(self, data):
+
+        metadata = []
+        for item in data:
+            clean, noisy, total_dur = item.values()
+            if total_dur < self.duration:
+                continue
+            num_segments = round(total_dur / self.duration)
+            for index in range(num_segments):
+                start_time = index * self.duration
+                metadata.append(({"clean": clean, "noisy": noisy}, start_time))
+        return metadata
+
+    def train_dataloader(self):
+        return DataLoader(
+            TrainDataset(self),
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+        )
+
+    def val_dataloader(self):
+        return DataLoader(
+            ValidDataset(self),
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+        )
+
+    def test_dataloader(self):
+        return DataLoader(
+            TestDataset(self),
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+        )
+
+
+class EnhancerDataset(TaskDataset):
+    """
+    Dataset object for creating clean-noisy speech enhancement datasets
+    parameters:
+    name : str
+        name of the dataset
+    root_dir : str
+        root directory of the dataset containing clean/noisy folders
+    files : Files
+        dataclass containing train_clean, train_noisy, test_clean, test_noisy
+        folder names (refer enhancer.utils.Files dataclass)
+    duration : float
+        expected audio duration of single audio sample for training
+    sampling_rate : int
+        desired sampling rate
+    batch_size : int
+        batch size of each batch
+    num_workers : int
+        num workers to be used while training
+    matching_function : str
+        matching functions - (one_to_one, one_to_many). Default set to None.
+        use one_to_one mapping for datasets with one noisy file for each clean file
+        use one_to_many mapping for multiple noisy files for each clean file
+    """
+
+    def __init__(
+        self,
+        name: str,
+        root_dir: str,
+        files: Files,
+        valid_size=0.2,
+        duration=1.0,
+        sampling_rate=48000,
+        matching_function=None,
+        batch_size=32,
+        num_workers: Optional[int] = None,
+    ):
+
+        super().__init__(
+            name=name,
+            root_dir=root_dir,
+            files=files,
+            valid_size=valid_size,
+            sampling_rate=sampling_rate,
+            duration=duration,
+            matching_function=matching_function,
+            batch_size=batch_size,
+            num_workers=num_workers,
+        )
+
+        self.sampling_rate = sampling_rate
+        self.files = files
+        self.duration = max(1.0, duration)
+        self.audio = Audio(self.sampling_rate, mono=True, return_tensor=True)
+
+    def setup(self, stage: Optional[str] = None):
+
+        super().setup(stage=stage)
+
+    def train__iter__(self):
+
+        rng = create_unique_rng(self.model.current_epoch)
+
+        while True:
+
+            file_dict, *_ = rng.choices(
+                self.train_data,
+                k=1,
+                weights=[file["duration"] for file in self.train_data],
+            )
+            file_duration = file_dict["duration"]
+            start_time = round(rng.uniform(0, file_duration - self.duration), 2)
+            data = self.prepare_segment(file_dict, start_time)
+            yield data
+
+    def val__getitem__(self, idx):
+        return self.prepare_segment(*self._validation[idx])
+
+    def test__getitem__(self, idx):
+        return self.prepare_segment(*self._test[idx])
+
+    def prepare_segment(self, file_dict: dict, start_time: float):
+
+        clean_segment = self.audio(
+            file_dict["clean"], offset=start_time, duration=self.duration
+        )
+        noisy_segment = self.audio(
+            file_dict["noisy"], offset=start_time, duration=self.duration
+        )
+        clean_segment = F.pad(
+            clean_segment,
+            (
+                0,
+                int(
+                    self.duration * self.sampling_rate - clean_segment.shape[-1]
+                ),
+            ),
+        )
+        noisy_segment = F.pad(
+            noisy_segment,
+            (
+                0,
+                int(
+                    self.duration * self.sampling_rate - noisy_segment.shape[-1]
+                ),
+            ),
+        )
+        return {"clean": clean_segment, "noisy": noisy_segment}
+
+    def train__len__(self):
+        return math.ceil(
+            sum([file["duration"] for file in self.train_data]) / self.duration
+        )
+
+    def val__len__(self):
+        return len(self._validation)
+
+    def test__len__(self):
+        return len(self._test)
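
To make the moving parts above concrete, a sketch of how this datamodule might be instantiated. The folder names follow the Vctk config shown earlier (the noisy training folder name is an assumption), and the exact `Files` constructor signature is assumed from its documented fields:

```python
from enhancer.data.dataset import EnhancerDataset
from enhancer.utils.config import Files

# Files fields per the docstring above; folder names partly assumed
files = Files(
    train_clean="clean_trainset_28spk_wav",
    train_noisy="noisy_trainset_28spk_wav",  # hypothetical folder name
    test_clean="clean_testset_wav",
    test_noisy="noisy_testset_wav",
)

dataset = EnhancerDataset(
    name="vctk",            # dispatches to the one_to_one matching function
    root_dir="/scratch/c.sistc3/DS_10283_2791",
    files=files,
    duration=1.0,           # seconds per training segment
    sampling_rate=16000,
    batch_size=32,
)
dataset.setup("fit")
train_loader = dataset.train_dataloader()
```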

@@ -62,24 +62,25 @@ class ProcessorFunctions:
         ]
         for clean_file in clean_filenames:
             noisy_filenames = glob.glob(
-                os.path.join(noisy_path, f"*_{clean_file}")
+                os.path.join(noisy_path, f"*_{clean_file}.wav")
             )
             for noisy_file in noisy_filenames:
+
-                sr_clean, clean_wav = wavfile.read(
+                sr_clean, clean_file = wavfile.read(
                     os.path.join(clean_path, clean_file)
                 )
-                sr_noisy, noisy_wav = wavfile.read(noisy_file)
+                sr_noisy, noisy_file = wavfile.read(noisy_file)
-                if (clean_wav.shape[-1] == noisy_wav.shape[-1]) and (
+                if (clean_file.shape[-1] == noisy_file.shape[-1]) and (
                     sr_clean == sr_noisy
                 ):
                     matching_wavfiles.append(
                         {
                             "clean": os.path.join(clean_path, clean_file),
                             "noisy": noisy_file,
-                            "duration": clean_wav.shape[-1] / sr_clean,
+                            "duration": clean_file.shape[-1] / sr_clean,
                         }
                     )

         return matching_wavfiles
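
The hunk above pairs each clean file with its noisy variants purely by filename. A small sketch of the glob pattern it relies on; the directory layout and file names are hypothetical:

```python
import glob
import os

# hypothetical layout: noisy variants are named "<noise-id>_<clean-name>"
clean_file = "clnsp1.wav"
noisy_path = "NoisySpeech_training"

# one_to_many matching: every noisy file ending in the clean filename pairs up
matches = glob.glob(os.path.join(noisy_path, f"*_{clean_file}"))
print(matches)  # e.g. ['NoisySpeech_training/noisy1_SNRdb_20.0_clnsp1.wav']
```
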
@@ -93,9 +94,9 @@ class Fileprocessor:
     def from_name(cls, name: str, clean_dir, noisy_dir, matching_function=None):

         if matching_function is None:
-            if name.lower() in ("vctk", "valentini"):
+            if name.lower() == "vctk":
                 return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_one)
-            elif name.lower() == "ms-snsd":
+            elif name.lower() == "dns-2020":
                 return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_many)
             else:
                 raise ValueError(

@@ -8,7 +8,7 @@ from librosa import load as load_audio
 from scipy.io import wavfile
 from scipy.signal import get_window

-from mayavoz.utils import Audio
+from enhancer.utils import Audio


 class Inference:
@@ -95,7 +95,6 @@ class Inference:
     ):
         """
        stitch batched waveform into single waveform. (Overlap-add)
-       inspired from https://github.com/asteroid-team/asteroid
        arguments:
            data: batched waveform
            window_size : window_size used to batch waveform

@@ -1,9 +1,8 @@
-import warnings
+import logging

 import numpy as np
 import torch
 import torch.nn as nn
-from torchmetrics import ScaleInvariantSignalNoiseRatio
 from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
 from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility
@@ -66,8 +65,8 @@ class Si_SDR:
             raise TypeError(
                 "Invalid reduction, valid options are sum, mean, None"
             )
-        self.higher_better = True
+        self.higher_better = False
-        self.name = "si-sdr"
+        self.name = "Si-SDR"

     def __call__(self, prediction: torch.Tensor, target: torch.Tensor):
@@ -123,18 +122,18 @@ class Pesq:
         self.sr = sr
         self.name = "pesq"
         self.mode = mode
-        self.pesq = PerceptualEvaluationSpeechQuality(
-            fs=self.sr, mode=self.mode
-        )
+        self.pesq = PerceptualEvaluationSpeechQuality(fs=sr, mode=mode)

     def __call__(self, prediction: torch.Tensor, target: torch.Tensor):

         pesq_values = []
         for pred, target_ in zip(prediction, target):
             try:
-                pesq_values.append(self.pesq(pred.squeeze(), target_.squeeze()))
+                pesq_values.append(
+                    self.pesq(pred.squeeze(), target_.squeeze()).item()
+                )
             except Exception as e:
-                warnings.warn(f"{e} error occurred while calculating PESQ")
+                logging.warning(f"{e} error occurred while calculating PESQ")
         return torch.tensor(np.mean(pesq_values))
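
For reference, a minimal sketch of the torchmetrics PESQ metric that both sides wrap; `.item()` on the right-hand side converts the returned 0-d tensor to a plain float:

```python
import torch
from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality

# wide-band PESQ is defined for 16 kHz audio ("nb" mode covers 8 kHz)
pesq = PerceptualEvaluationSpeechQuality(fs=16000, mode="wb")

pred = torch.randn(16000)    # hypothetical 1 s enhanced waveform
target = torch.randn(16000)  # hypothetical clean reference
score = pesq(pred, target)   # 0-d tensor; score.item() gives a float
```
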
@@ -183,34 +182,10 @@ class LossWrapper(nn.Module):
         return loss


-class Si_snr(nn.Module):
-    """
-    SI-SNR
-    """
-
-    def __init__(self, **kwargs):
-        super().__init__()
-
-        self.loss_fun = ScaleInvariantSignalNoiseRatio(**kwargs)
-        self.higher_better = False
-        self.name = "si_snr"
-
-    def forward(self, prediction: torch.Tensor, target: torch.Tensor):
-
-        if prediction.size() != target.size() or target.ndim < 3:
-            raise TypeError(
-                f"""Inputs must be of the same shape (batch_size,channels,samples)
-                got {prediction.size()} and {target.size()} instead"""
-            )
-
-        return -1 * self.loss_fun(prediction, target)
-
-
 LOSS_MAP = {
     "mae": mean_absolute_error,
     "mse": mean_squared_error,
     "si-sdr": Si_SDR,
     "pesq": Pesq,
     "stoi": Stoi,
-    "si-snr": Si_snr,
 }
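
The removed `Si_snr` class is a thin wrapper that negates the torchmetrics SI-SNR so a higher-is-better metric can be minimised as a training loss. A minimal sketch of the same idea:

```python
import torch
from torchmetrics import ScaleInvariantSignalNoiseRatio

si_snr = ScaleInvariantSignalNoiseRatio()

# (batch_size, channels, samples) shapes, as the removed class asserted
prediction = torch.randn(4, 1, 16000)
target = torch.randn(4, 1, 16000)

# negate: maximising SI-SNR is equivalent to minimising this loss
loss = -1 * si_snr(prediction, target)
```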

@@ -0,0 +1,3 @@
+from enhancer.models.demucs import Demucs
+from enhancer.models.model import Model
+from enhancer.models.waveunet import WaveUnet

@@ -1,14 +1,14 @@
+import logging
 import math
-import warnings
 from typing import List, Optional, Union

 import torch.nn.functional as F
 from torch import nn

-from mayavoz.data.dataset import MayaDataset
+from enhancer.data.dataset import EnhancerDataset
-from mayavoz.models.model import Mayamodel
+from enhancer.models.model import Model
-from mayavoz.utils.io import Audio as audio
+from enhancer.utils.io import Audio as audio
-from mayavoz.utils.utils import merge_dict
+from enhancer.utils.utils import merge_dict


 class DemucsLSTM(nn.Module):
@@ -88,7 +88,7 @@ class DemucsDecoder(nn.Module):
         return out


-class Demucs(Mayamodel):
+class Demucs(Model):
     """
     Demucs model from https://arxiv.org/pdf/1911.13254.pdf
     parameters:
@@ -102,8 +102,8 @@ class Demucs(Mayamodel):
         sampling rate of input audio
     lr : float, defaults to 1e-3
         learning rate used for training
-    dataset: MayaDataset, optional
+    dataset: EnhancerDataset, optional
-        MayaDataset object containing train/validation data for training
+        EnhancerDataset object containing train/validation data for training
     duration : float, optional
         chunk duration in seconds
     loss : string or List of strings
@@ -133,20 +133,17 @@ class Demucs(Mayamodel):
         num_channels: int = 1,
         resample: int = 4,
         sampling_rate=16000,
-        normalize=True,
         lr: float = 1e-3,
-        dataset: Optional[MayaDataset] = None,
+        dataset: Optional[EnhancerDataset] = None,
-        duration: Optional[float] = None,
         loss: Union[str, List] = "mse",
         metric: Union[str, List] = "mse",
-        floor=1e-3,
     ):
         duration = (
-            dataset.duration if isinstance(dataset, MayaDataset) else duration
+            dataset.duration if isinstance(dataset, EnhancerDataset) else None
         )
         if dataset is not None:
             if sampling_rate != dataset.sampling_rate:
-                warnings.warn(
+                logging.warning(
                     f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
                 )
                 sampling_rate = dataset.sampling_rate
@@ -164,8 +161,6 @@ class Demucs(Mayamodel):
         lstm = merge_dict(self.LSTM_DEFAULTS, lstm)
         self.save_hyperparameters("encoder_decoder", "lstm", "resample")
         hidden = encoder_decoder["initial_output_channels"]
-        self.normalize = normalize
-        self.floor = floor
         self.encoder = nn.ModuleList()
         self.decoder = nn.ModuleList()
@@ -205,16 +200,11 @@ class Demucs(Mayamodel):
         if waveform.dim() == 2:
             waveform = waveform.unsqueeze(1)

-        if waveform.size(1) != self.hparams.num_channels:
+        if waveform.size(1) != 1:
-            raise ValueError(
+            raise TypeError(
-                f"Number of input channels initialized is {self.hparams.num_channels} but got {waveform.size(1)} channels"
+                f"Demucs can only process mono channel audio, input has {waveform.size(1)} channels"
             )
-        if self.normalize:
-            waveform = waveform.mean(dim=1, keepdim=True)
-            std = waveform.std(dim=-1, keepdim=True)
-            waveform = waveform / (self.floor + std)
-        else:
-            std = 1
         length = waveform.shape[-1]
         x = F.pad(waveform, (0, self.get_padding_length(length) - length))
         if self.hparams.resample > 1:
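
The `normalize` branch deleted above is a standard Demucs trick: divide the input by its standard deviation before the network and multiply it back afterwards (the matching `return std * out` appears in the next hunk). A standalone sketch of the idea:

```python
import torch

floor = 1e-3                         # keeps the division numerically safe
waveform = torch.randn(4, 1, 16000)  # hypothetical (batch, channel, samples)

mono = waveform.mean(dim=1, keepdim=True)
std = mono.std(dim=-1, keepdim=True)
normalized = mono / (floor + std)

out = normalized                     # stand-in for the network forward pass
restored = std * out                 # undo the scaling on the way out
```
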
@@ -247,7 +237,7 @@ class Demucs(Mayamodel):
         )

         out = x[..., :length]
-        return std * out
+        return out

     def get_padding_length(self, input_length):

@@ -2,7 +2,7 @@ import os
 from collections import defaultdict
 from importlib import import_module
 from pathlib import Path
-from typing import Any, List, Optional, Text, Union
+from typing import List, Optional, Text, Union
 from urllib.parse import urlparse

 import numpy as np
@@ -10,24 +10,19 @@ import pytorch_lightning as pl
 import torch
 from huggingface_hub import cached_download, hf_hub_url
 from pytorch_lightning.utilities.cloud_io import load as pl_load
-from torch import nn
 from torch.optim import Adam

-from mayavoz.data.dataset import MayaDataset
+from enhancer.data.dataset import EnhancerDataset
-from mayavoz.inference import Inference
+from enhancer.inference import Inference
-from mayavoz.loss import LOSS_MAP, LossWrapper
+from enhancer.loss import LOSS_MAP, LossWrapper
-from mayavoz.version import __version__
+from enhancer.version import __version__

-CACHE_DIR = os.getenv(
-    "ENHANCER_CACHE",
-    os.path.expanduser("~/.cache/torch/mayavoz"),
-)
-HF_TORCH_WEIGHTS = "pytorch_model.ckpt"
+CACHE_DIR = ""
+HF_TORCH_WEIGHTS = ""
 DEFAULT_DEVICE = "cpu"
-SAVE_NAME = "mayavoz"


-class Mayamodel(pl.LightningModule):
+class Model(pl.LightningModule):
     """
     Base class for all models
     parameters:
@@ -37,11 +32,11 @@ class Mayamodel(pl.LightningModule):
         audio sampling rate
     lr: float, optional
         learning rate for model training
-    dataset: MayaDataset, optional
+    dataset: EnhancerDataset, optional
-        mayavoz dataset used for training/validation
+        Enhancer dataset used for training/validation
     duration: float, optional
         duration used for training/inference
-    loss : string or List of strings or custom loss (nn.Module), default to "mse"
+    loss : string or List of strings, default to "mse"
         loss functions to be used. Available ("mse","mae","Si-SDR")

     """
@@ -51,13 +46,15 @@ class Mayamodel(pl.LightningModule):
         num_channels: int = 1,
         sampling_rate: int = 16000,
         lr: float = 1e-3,
-        dataset: Optional[MayaDataset] = None,
+        dataset: Optional[EnhancerDataset] = None,
         duration: Optional[float] = None,
         loss: Union[str, List] = "mse",
-        metric: Union[str, List, Any] = "mse",
+        metric: Union[str, List] = "mse",
     ):
         super().__init__()
-        assert num_channels == 1, "mayavoz only supports mono channel models"
+        assert (
+            num_channels == 1
+        ), "Enhancer only supports mono channel models"
         self.dataset = dataset
         self.save_hyperparameters(
             "num_channels", "sampling_rate", "lr", "loss", "metric", "duration"
@@ -89,11 +86,10 @@ class Mayamodel(pl.LightningModule):
     @metric.setter
     def metric(self, metric):
         self._metric = []
-        if isinstance(metric, (str, nn.Module)):
+        if isinstance(metric, str):
             metric = [metric]

         for func in metric:
-            if isinstance(func, str):
             if func in LOSS_MAP.keys():
                 if func in ("pesq", "stoi"):
                     self._metric.append(
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self._metric.append(LOSS_MAP[func]())
|
self._metric.append(LOSS_MAP[func]())
|
||||||
else:
|
|
||||||
ValueError(f"Invalid metrics {func}")
|
|
||||||
|
|
||||||
elif isinstance(func, nn.Module):
|
|
||||||
self._metric.append(func)
|
|
||||||
else:
|
else:
|
||||||
raise ValueError("Invalid metrics")
|
raise ValueError(f"Invalid metrics {func}")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dataset(self):
|
def dataset(self):
|
||||||
|
|
@@ -121,29 +113,22 @@ class Mayamodel(pl.LightningModule):
         if stage == "fit":
             torch.cuda.empty_cache()
             self.dataset.setup(stage)
-            self.dataset.model = self

             print(
                 "Total train duration",
-                self.dataset.train_dataloader().dataset.__len__()
-                * self.dataset.duration
-                / 60,
+                self.dataset.train_dataloader().dataset.__len__() / 60,
                 "minutes",
             )
             print(
                 "Total validation duration",
-                self.dataset.val_dataloader().dataset.__len__()
-                * self.dataset.duration
-                / 60,
+                self.dataset.val_dataloader().dataset.__len__() / 60,
                 "minutes",
             )
             print(
                 "Total test duration",
-                self.dataset.test_dataloader().dataset.__len__()
-                * self.dataset.duration
-                / 60,
+                self.dataset.test_dataloader().dataset.__len__() / 60,
                 "minutes",
             )
+            self.dataset.model = self

     def train_dataloader(self):
         return self.dataset.train_dataloader()
@@ -234,8 +219,8 @@ class Mayamodel(pl.LightningModule):

     def on_save_checkpoint(self, checkpoint):

-        checkpoint[SAVE_NAME] = {
+        checkpoint["enhancer"] = {
-            "version": {SAVE_NAME: __version__, "pytorch": torch.__version__},
+            "version": {"enhancer": __version__, "pytorch": torch.__version__},
             "architecture": {
                 "module": self.__class__.__module__,
                 "class": self.__class__.__name__,
@@ -288,8 +273,8 @@ class Mayamodel(pl.LightningModule):

         Returns
         -------
-        model : Mayamodel
+        model : Model
-            Mayamodel
+            Model

         See also
         --------
@@ -318,7 +303,7 @@ class Mayamodel(pl.LightningModule):
         )
         model_path_pl = cached_download(
             url=url,
-            library_name="mayavoz",
+            library_name="enhancer",
             library_version=__version__,
             cache_dir=cached_dir,
             use_auth_token=use_auth_token,
@@ -328,8 +313,8 @@ class Mayamodel(pl.LightningModule):
         map_location = torch.device(DEFAULT_DEVICE)

         loaded_checkpoint = pl_load(model_path_pl, map_location)
-        module_name = loaded_checkpoint[SAVE_NAME]["architecture"]["module"]
+        module_name = loaded_checkpoint["enhancer"]["architecture"]["module"]
-        class_name = loaded_checkpoint[SAVE_NAME]["architecture"]["class"]
+        class_name = loaded_checkpoint["enhancer"]["architecture"]["class"]
         module = import_module(module_name)
         Klass = getattr(module, class_name)
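
Putting the two hunks above together: `on_save_checkpoint` embeds an extra metadata dict in the Lightning checkpoint, and `from_pretrained` reads it back to locate the model class. A sketch of that round trip; the checkpoint path is hypothetical:

```python
from importlib import import_module

import torch

# hypothetical checkpoint produced by on_save_checkpoint()
ckpt = torch.load("pytorch_model.ckpt", map_location="cpu")

meta = ckpt["enhancer"]  # keyed by SAVE_NAME ("mayavoz") on the left-hand side
module_name = meta["architecture"]["module"]  # e.g. "enhancer.models.demucs"
class_name = meta["architecture"]["class"]    # e.g. "Demucs"

# resolve the class exactly as from_pretrained does
Klass = getattr(import_module(module_name), class_name)
```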

@@ -1,12 +1,12 @@
-import warnings
+import logging
 from typing import List, Optional, Union

 import torch
 import torch.nn as nn
 import torch.nn.functional as F

-from mayavoz.data.dataset import MayaDataset
+from enhancer.data.dataset import EnhancerDataset
-from mayavoz.models.model import Mayamodel
+from enhancer.models.model import Model


 class WavenetDecoder(nn.Module):
@@ -66,7 +66,7 @@ class WavenetEncoder(nn.Module):
         return self.encoder(waveform)


-class WaveUnet(Mayamodel):
+class WaveUnet(Model):
     """
     Wave-U-Net model from https://arxiv.org/pdf/1811.11307.pdf
     parameters:
@@ -80,8 +80,8 @@ class WaveUnet(Mayamodel):
         sampling rate of input audio
     lr : float, defaults to 1e-3
         learning rate used for training
-    dataset: MayaDataset, optional
+    dataset: EnhancerDataset, optional
-        MayaDataset object containing train/validation data for training
+        EnhancerDataset object containing train/validation data for training
     duration : float, optional
         chunk duration in seconds
     loss : string or List of strings
@ -97,17 +97,17 @@ class WaveUnet(Mayamodel):
|
||||||
initial_output_channels: int = 24,
|
initial_output_channels: int = 24,
|
||||||
sampling_rate: int = 16000,
|
sampling_rate: int = 16000,
|
||||||
lr: float = 1e-3,
|
lr: float = 1e-3,
|
||||||
dataset: Optional[MayaDataset] = None,
|
dataset: Optional[EnhancerDataset] = None,
|
||||||
duration: Optional[float] = None,
|
duration: Optional[float] = None,
|
||||||
loss: Union[str, List] = "mse",
|
loss: Union[str, List] = "mse",
|
||||||
metric: Union[str, List] = "mse",
|
metric: Union[str, List] = "mse",
|
||||||
):
|
):
|
||||||
duration = (
|
duration = (
|
||||||
dataset.duration if isinstance(dataset, MayaDataset) else duration
|
dataset.duration if isinstance(dataset, EnhancerDataset) else None
|
||||||
)
|
)
|
||||||
if dataset is not None:
|
if dataset is not None:
|
||||||
if sampling_rate != dataset.sampling_rate:
|
if sampling_rate != dataset.sampling_rate:
|
||||||
warnings.warn(
|
logging.warning(
|
||||||
f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
|
f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
|
||||||
)
|
)
|
||||||
sampling_rate = dataset.sampling_rate
|
sampling_rate = dataset.sampling_rate
|
||||||
|
|
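Two things stand out in this file. First, the right-hand side resolves `duration` to `None` whenever no dataset is supplied, so a standalone `duration=` argument no longer survives `__init__`. Second, the constructor signature makes the intended entry point clear; a minimal usage sketch, with module paths guessed from the imports above and folder names borrowed from the Valentini config deleted later in this diff (the root path is a placeholder):

    from enhancer.data.dataset import EnhancerDataset
    from enhancer.models.waveunet import WaveUnet  # assumed module path
    from enhancer.utils.config import Files

    files = Files(
        train_clean="clean_trainset_28spk_wav",
        train_noisy="noisy_trainset_28spk_wav",
        test_clean="clean_testset_wav",
        test_noisy="noisy_testset_wav",
    )
    dataset = EnhancerDataset(
        name="Valentini",                   # dataset name, as in the deleted config
        root_dir="/path/to/DS_10283_2791",  # placeholder
        files=files,
        duration=4.5,
        sampling_rate=16000,
        batch_size=32,
    )
    model = WaveUnet(dataset=dataset, lr=1e-3, loss="mae", metric=["stoi", "pesq"])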
@@ -0,0 +1,3 @@
+from enhancer.utils.config import Files
+from enhancer.utils.io import Audio
+from enhancer.utils.utils import check_files
@@ -70,7 +70,7 @@ class Audio:

         if sampling_rate:
             audio = self.__class__.resample_audio(
-                audio, sampling_rate, self.sampling_rate
+                audio, self.sampling_rate, sampling_rate
             )
         if self.return_tensor:
             return torch.tensor(audio)
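The only change here is the order of the two rate arguments, which silently flips an upsample into a downsample if the helper's convention is `(signal, source_rate, target_rate)`. A sketch of that convention using torchaudio's resampler; the function name mirrors the call site but is illustrative, not the repository's implementation:

    import torch
    import torchaudio.functional as AF

    def resample_audio(audio: torch.Tensor, sr_from: int, sr_to: int) -> torch.Tensor:
        # 2nd argument = rate the audio currently has, 3rd = rate we want.
        return AF.resample(audio, orig_freq=sr_from, new_freq=sr_to)

    x = torch.randn(1, 48000)            # 1 second at 48 kHz
    y = resample_audio(x, 48000, 16000)  # y.shape[-1] == 16000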
@@ -1,7 +1,7 @@
 import os
 from typing import Optional

-from mayavoz.utils.config import Files
+from enhancer.utils.config import Files


 def check_files(root_dir: str, files: Files):
@@ -1,4 +1,4 @@
-name: mayavoz
+name: enhancer

 dependencies:
   - pip=21.0.1
@@ -0,0 +1,52 @@
+#!/bin/bash
+set -e
+
+echo '----------------------------------------------------'
+echo ' SLURM_CLUSTER_NAME = '$SLURM_CLUSTER_NAME
+echo ' SLURMD_NODENAME = '$SLURMD_NODENAME
+echo ' SLURM_JOBID = '$SLURM_JOBID
+echo ' SLURM_JOB_USER = '$SLURM_JOB_USER
+echo ' SLURM_PARTITION = '$SLURM_JOB_PARTITION
+echo ' SLURM_JOB_ACCOUNT = '$SLURM_JOB_ACCOUNT
+echo '----------------------------------------------------'
+
+#TeamCity Output
+cat << EOF
+##teamcity[buildNumber '$SLURM_JOBID']
+EOF
+
+echo "Load HPC modules"
+module load anaconda
+
+echo "Activate Environment"
+source activate enhancer
+export TRANSFORMERS_OFFLINE=True
+export PYTHONPATH=${PYTHONPATH}:/scratch/c.sistc3/enhancer
+export HYDRA_FULL_ERROR=1
+
+echo $PYTHONPATH
+
+source ~/mlflow_settings.sh
+
+echo "Making temp dir"
+mkdir temp
+pwd
+
+# echo "files"
+# rm -rf /scratch/c.sistc3/MS-SNSD/DNS30/CleanSpeech_training
+# rm -rf /scratch/c.sistc3/MS-SNSD/DNS30/NoisySpeech_training
+# rm -rf /scratch/c.sistc3/MS-SNSD/DNS30/NoisySpeech_testing
+# rm -rf /scratch/c.sistc3/MS-SNSD/DNS30/CleanSpeech_testing
+
+# cp -r /scratch/c.sistc3/MS-SNSD/DNS30/NoisySpeech_testing /scratch/c.sistc3/MS-SNSD/DNS15/
+# cp -r /scratch/c.sistc3/MS-SNSD/DNS30/CleanSpeech_testing /scratch/c.sistc3/MS-SNSD/DNS15/
+# rm -rf /scratch/c.sistc3/MS-SNSD/DNS20
+
+# mkdir /scratch/c.sistc3/MS-SNSD/DNS20
+
+python noisyspeech_synthesizer.py
+
+mv ./CleanSpeech_testing/ /scratch/c.sistc3/MS-SNSD/DNS20
+mv ./NoisySpeech_testing/ /scratch/c.sistc3/MS-SNSD/DNS20
+ls /scratch/c.sistc3/MS-SNSD/DNS20
+#python enhancer/cli/train.py
@@ -1,120 +0,0 @@
-import os
-from types import MethodType
-
-import hydra
-from hydra.utils import instantiate
-from omegaconf import DictConfig, OmegaConf
-from pytorch_lightning.callbacks import (
-    EarlyStopping,
-    LearningRateMonitor,
-    ModelCheckpoint,
-)
-from pytorch_lightning.loggers import MLFlowLogger
-from torch.optim.lr_scheduler import ReduceLROnPlateau
-
-# from torch_audiomentations import Compose, Shift
-
-os.environ["HYDRA_FULL_ERROR"] = "1"
-JOB_ID = os.environ.get("SLURM_JOBID", "0")
-
-
-@hydra.main(config_path="train_config", config_name="config")
-def train(config: DictConfig):
-
-    OmegaConf.save(config, "config.yaml")
-
-    callbacks = []
-    logger = MLFlowLogger(
-        experiment_name=config.mlflow.experiment_name,
-        run_name=config.mlflow.run_name,
-        tags={"JOB_ID": JOB_ID},
-    )
-
-    parameters = config.hyperparameters
-    # apply_augmentations = Compose(
-    #     [
-    #         Shift(min_shift=0.5, max_shift=1.0, shift_unit="seconds", p=0.5),
-    #     ]
-    # )
-
-    dataset = instantiate(config.dataset, augmentations=None)
-    model = instantiate(
-        config.model,
-        dataset=dataset,
-        lr=parameters.get("lr"),
-        loss=parameters.get("loss"),
-        metric=parameters.get("metric"),
-    )
-
-    direction = model.valid_monitor
-    checkpoint = ModelCheckpoint(
-        dirpath="./model",
-        filename=f"model_{JOB_ID}",
-        monitor="valid_loss",
-        verbose=False,
-        mode=direction,
-        every_n_epochs=1,
-    )
-    callbacks.append(checkpoint)
-    callbacks.append(LearningRateMonitor(logging_interval="epoch"))
-
-    if parameters.get("Early_stop", False):
-        early_stopping = EarlyStopping(
-            monitor="val_loss",
-            mode=direction,
-            min_delta=0.0,
-            patience=parameters.get("EarlyStopping_patience", 10),
-            strict=True,
-            verbose=False,
-        )
-        callbacks.append(early_stopping)
-
-    def configure_optimizers(self):
-        optimizer = instantiate(
-            config.optimizer,
-            lr=parameters.get("lr"),
-            params=self.parameters(),
-        )
-        scheduler = ReduceLROnPlateau(
-            optimizer=optimizer,
-            mode=direction,
-            factor=parameters.get("ReduceLr_factor", 0.1),
-            verbose=True,
-            min_lr=parameters.get("min_lr", 1e-6),
-            patience=parameters.get("ReduceLr_patience", 3),
-        )
-        return {
-            "optimizer": optimizer,
-            "lr_scheduler": scheduler,
-            "monitor": f'valid_{parameters.get("ReduceLr_monitor", "loss")}',
-        }
-
-    model.configure_optimizers = MethodType(configure_optimizers, model)
-
-    trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks)
-    trainer.fit(model)
-    trainer.test(model)
-
-    logger.experiment.log_artifact(
-        logger.run_id, f"{trainer.default_root_dir}/config.yaml"
-    )
-
-    saved_location = os.path.join(
-        trainer.default_root_dir, "model", f"model_{JOB_ID}.ckpt"
-    )
-    if os.path.isfile(saved_location):
-        logger.experiment.log_artifact(logger.run_id, saved_location)
-        logger.experiment.log_param(
-            logger.run_id,
-            "num_train_steps_per_epoch",
-            dataset.train__len__() / dataset.batch_size,
-        )
-        logger.experiment.log_param(
-            logger.run_id,
-            "num_valid_steps_per_epoch",
-            dataset.val__len__() / dataset.batch_size,
-        )
-
-
-if __name__ == "__main__":
-    train()
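One pattern in the deleted trainer deserves a second look before it disappears: `configure_optimizers` is built inside `train()` (so it closes over the Hydra config) and then bound onto the model instance with `types.MethodType`, overriding whatever the model class defines. A self-contained sketch of that rebinding trick; the tiny model and the two optimizers are placeholders:

    from types import MethodType

    import torch
    from torch import nn

    class TinyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.layer = nn.Linear(4, 1)

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)  # class default

    model = TinyModel()

    def configure_optimizers(self):
        # Instance-level override, e.g. built from an external config.
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    # Bind to this one instance; the class default is untouched.
    model.configure_optimizers = MethodType(configure_optimizers, model)
    assert isinstance(model.configure_optimizers(), torch.optim.Adam)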
@@ -1,12 +0,0 @@
-_target_: mayavoz.data.dataset.MayaDataset
-name : MS-SDSD
-root_dir : /Users/shahules/Myprojects/MS-SNSD
-duration : 2.0
-sampling_rate: 16000
-batch_size: 32
-min_valid_minutes: 15
-files:
-  train_clean : CleanSpeech_training
-  test_clean : CleanSpeech_training
-  train_noisy : NoisySpeech_training
-  test_noisy : NoisySpeech_training

@@ -1,13 +0,0 @@
-_target_: mayavoz.data.dataset.MayaDataset
-name : Valentini
-root_dir : /scratch/c.sistc3/DS_10283_2791
-duration : 4.5
-stride : 2
-sampling_rate: 16000
-batch_size: 32
-valid_minutes : 15
-files:
-  train_clean : clean_trainset_28spk_wav
-  test_clean : clean_testset_wav
-  train_noisy : noisy_trainset_28spk_wav
-  test_noisy : noisy_testset_wav

@@ -1,7 +0,0 @@
-loss : mae
-metric : [stoi,pesq,si-sdr]
-lr : 0.0003
-ReduceLr_patience : 5
-ReduceLr_factor : 0.2
-min_lr : 0.000001
-EarlyStopping_factor : 10

@@ -1,2 +0,0 @@
-experiment_name : shahules/mayavoz
-run_name : Demucs + Vtck with stride + augmentations

@@ -1,25 +0,0 @@
-_target_: mayavoz.models.dccrn.DCCRN
-num_channels: 1
-sampling_rate : 16000
-complex_lstm : True
-complex_norm : True
-complex_relu : True
-masking_mode : True
-
-encoder_decoder:
-  initial_output_channels : 32
-  depth : 6
-  kernel_size : 5
-  growth_factor : 2
-  stride : 2
-  padding : 2
-  output_padding : 1
-
-lstm:
-  num_layers : 2
-  hidden_size : 256
-
-stft:
-  window_len : 400
-  hop_size : 100
-  nfft : 512

@@ -1,46 +0,0 @@
-_target_: pytorch_lightning.Trainer
-accelerator: gpu
-accumulate_grad_batches: 1
-amp_backend: native
-auto_lr_find: True
-auto_scale_batch_size: False
-auto_select_gpus: True
-benchmark: False
-check_val_every_n_epoch: 1
-detect_anomaly: False
-deterministic: False
-devices: 2
-enable_checkpointing: True
-enable_model_summary: True
-enable_progress_bar: True
-fast_dev_run: False
-gpus: null
-gradient_clip_val: 0
-gradient_clip_algorithm: norm
-ipus: null
-limit_predict_batches: 1.0
-limit_test_batches: 1.0
-limit_train_batches: 1.0
-limit_val_batches: 1.0
-log_every_n_steps: 50
-max_epochs: 200
-max_steps: -1
-max_time: null
-min_epochs: 1
-min_steps: null
-move_metrics_to_cpu: False
-multiple_trainloader_mode: max_size_cycle
-num_nodes: 1
-num_processes: 1
-num_sanity_val_steps: 2
-overfit_batches: 0.0
-precision: 32
-profiler: null
-reload_dataloaders_every_n_epochs: 0
-replace_sampler_ddp: True
-strategy: ddp
-sync_batchnorm: False
-tpu_cores: null
-track_grad_norm: -1
-val_check_interval: 1.0
-weights_save_path: null

@@ -1 +0,0 @@
-from mayavoz.data.dataset import MayaDataset
@@ -1,393 +0,0 @@
-import math
-import multiprocessing
-import os
-import sys
-import warnings
-from pathlib import Path
-from typing import Optional
-
-import numpy as np
-import pytorch_lightning as pl
-import torch
-import torch.nn.functional as F
-from torch.utils.data import DataLoader, Dataset, RandomSampler
-from torch_audiomentations import Compose
-
-from mayavoz.data.fileprocessor import Fileprocessor
-from mayavoz.utils import check_files
-from mayavoz.utils.config import Files
-from mayavoz.utils.io import Audio
-from mayavoz.utils.random import create_unique_rng
-
-LARGE_NUM = 2147483647
-
-
-class TrainDataset(Dataset):
-    def __init__(self, dataset):
-        self.dataset = dataset
-
-    def __getitem__(self, idx):
-        return self.dataset.train__getitem__(idx)
-
-    def __len__(self):
-        return self.dataset.train__len__()
-
-
-class ValidDataset(Dataset):
-    def __init__(self, dataset):
-        self.dataset = dataset
-
-    def __getitem__(self, idx):
-        return self.dataset.val__getitem__(idx)
-
-    def __len__(self):
-        return self.dataset.val__len__()
-
-
-class TestDataset(Dataset):
-    def __init__(self, dataset):
-        self.dataset = dataset
-
-    def __getitem__(self, idx):
-        return self.dataset.test__getitem__(idx)
-
-    def __len__(self):
-        return self.dataset.test__len__()
-
-
-class TaskDataset(pl.LightningDataModule):
-    def __init__(
-        self,
-        name: str,
-        root_dir: str,
-        files: Files,
-        min_valid_minutes: float = 0.20,
-        duration: float = 1.0,
-        stride=None,
-        sampling_rate: int = 48000,
-        matching_function=None,
-        batch_size=32,
-        num_workers: Optional[int] = None,
-        augmentations: Optional[Compose] = None,
-    ):
-        super().__init__()
-
-        self.name = name
-        self.files, self.root_dir = check_files(root_dir, files)
-        self.duration = duration
-        self.stride = stride or duration
-        self.sampling_rate = sampling_rate
-        self.batch_size = batch_size
-        self.matching_function = matching_function
-        self._validation = []
-        if num_workers is None:
-            num_workers = multiprocessing.cpu_count() // 2
-        if num_workers is None:
-            num_workers = multiprocessing.cpu_count() // 2
-
-        if (
-            num_workers > 0
-            and sys.platform == "darwin"
-            and sys.version_info[0] >= 3
-            and sys.version_info[1] >= 8
-        ):
-            warnings.warn(
-                "num_workers > 0 is not supported with macOS and Python 3.8+: "
-                "setting num_workers = 0."
-            )
-            num_workers = 0
-
-        self.num_workers = num_workers
-        if min_valid_minutes > 0.0:
-            self.min_valid_minutes = min_valid_minutes
-        else:
-            raise ValueError("min_valid_minutes must be greater than 0")
-
-        self.augmentations = augmentations
-
-    def setup(self, stage: Optional[str] = None):
-        """
-        prepare train/validation/test data splits
-        """
-
-        if stage in ("fit", None):
-
-            train_clean = os.path.join(self.root_dir, self.files.train_clean)
-            train_noisy = os.path.join(self.root_dir, self.files.train_noisy)
-            fp = Fileprocessor.from_name(
-                self.name, train_clean, train_noisy, self.matching_function
-            )
-            train_data = fp.prepare_matching_dict()
-            train_data, self.val_data = self.train_valid_split(
-                train_data,
-                min_valid_minutes=self.min_valid_minutes,
-                random_state=42,
-            )
-
-            self.train_data = self.prepare_traindata(train_data)
-            self._validation = self.prepare_mapstype(self.val_data)
-
-            test_clean = os.path.join(self.root_dir, self.files.test_clean)
-            test_noisy = os.path.join(self.root_dir, self.files.test_noisy)
-            fp = Fileprocessor.from_name(
-                self.name, test_clean, test_noisy, self.matching_function
-            )
-            test_data = fp.prepare_matching_dict()
-            self._test = self.prepare_mapstype(test_data)
-
-    def train_valid_split(
-        self, data, min_valid_minutes: float = 20, random_state: int = 42
-    ):
-
-        min_valid_minutes *= 60
-        valid_sec_now = 0.0
-        valid_indices = []
-        all_speakers = np.unique(
-            [Path(file["clean"]).name.split("_")[0] for file in data]
-        )
-        possible_indices = list(range(0, len(all_speakers)))
-        rng = create_unique_rng(len(all_speakers))
-
-        while valid_sec_now <= min_valid_minutes:
-            speaker_index = rng.choice(possible_indices)
-            possible_indices.remove(speaker_index)
-            speaker_name = all_speakers[speaker_index]
-            print(f"Selected f{speaker_name} for valid")
-            file_indices = [
-                i
-                for i, file in enumerate(data)
-                if speaker_name == Path(file["clean"]).name.split("_")[0]
-            ]
-            for i in file_indices:
-                valid_indices.append(i)
-                valid_sec_now += data[i]["duration"]
-
-        train_data = [
-            item for i, item in enumerate(data) if i not in valid_indices
-        ]
-        valid_data = [item for i, item in enumerate(data) if i in valid_indices]
-        return train_data, valid_data
-
-    def prepare_traindata(self, data):
-        train_data = []
-        for item in data:
-            clean, noisy, total_dur = item.values()
-            num_segments = self.get_num_segments(
-                total_dur, self.duration, self.stride
-            )
-            samples_metadata = ({"clean": clean, "noisy": noisy}, num_segments)
-            train_data.append(samples_metadata)
-        return train_data
-
-    @staticmethod
-    def get_num_segments(file_duration, duration, stride):
-
-        if file_duration < duration:
-            num_segments = 1
-        else:
-            num_segments = math.ceil((file_duration - duration) / stride) + 1
-
-        return num_segments
-
-    def prepare_mapstype(self, data):
-
-        metadata = []
-        for item in data:
-            clean, noisy, total_dur = item.values()
-            if total_dur < self.duration:
-                metadata.append(({"clean": clean, "noisy": noisy}, 0.0))
-            else:
-                num_segments = self.get_num_segments(
-                    total_dur, self.duration, self.duration
-                )
-                for index in range(num_segments):
-                    start_time = index * self.duration
-                    metadata.append(
-                        ({"clean": clean, "noisy": noisy}, start_time)
-                    )
-        return metadata
-
-    def train_collatefn(self, batch):
-
-        output = {"clean": [], "noisy": []}
-        for item in batch:
-            output["clean"].append(item["clean"])
-            output["noisy"].append(item["noisy"])
-
-        output["clean"] = torch.stack(output["clean"], dim=0)
-        output["noisy"] = torch.stack(output["noisy"], dim=0)
-
-        if self.augmentations is not None:
-            noise = output["noisy"] - output["clean"]
-            output["clean"] = self.augmentations(
-                output["clean"], sample_rate=self.sampling_rate
-            )
-            self.augmentations.freeze_parameters()
-            output["noisy"] = (
-                self.augmentations(noise, sample_rate=self.sampling_rate)
-                + output["clean"]
-            )
-
-        return output
-
-    @property
-    def generator(self):
-        generator = torch.Generator()
-        if hasattr(self, "model"):
-            seed = self.model.current_epoch + LARGE_NUM
-        else:
-            seed = LARGE_NUM
-        return generator.manual_seed(seed)
-
-    def train_dataloader(self):
-        dataset = TrainDataset(self)
-        sampler = RandomSampler(dataset, generator=self.generator)
-        return DataLoader(
-            dataset,
-            batch_size=self.batch_size,
-            num_workers=self.num_workers,
-            sampler=sampler,
-            collate_fn=self.train_collatefn,
-        )
-
-    def val_dataloader(self):
-        return DataLoader(
-            ValidDataset(self),
-            batch_size=self.batch_size,
-            num_workers=self.num_workers,
-        )
-
-    def test_dataloader(self):
-        return DataLoader(
-            TestDataset(self),
-            batch_size=self.batch_size,
-            num_workers=self.num_workers,
-        )
-
-
-class MayaDataset(TaskDataset):
-    """
-    Dataset object for creating clean-noisy speech enhancement datasets
-    paramters:
-    name : str
-        name of the dataset
-    root_dir : str
-        root directory of the dataset containing clean/noisy folders
-    files : Files
-        dataclass containing train_clean, train_noisy, test_clean, test_noisy
-        folder names (refer mayavoz.utils.Files dataclass)
-    min_valid_minutes: float
-        minimum validation split size time in minutes
-        algorithm randomly select n speakers (>=min_valid_minutes) from train data to form validation data.
-    duration : float
-        expected audio duration of single audio sample for training
-    sampling_rate : int
-        desired sampling rate
-    batch_size : int
-        batch size of each batch
-    num_workers : int
-        num workers to be used while training
-    matching_function : str
-        maching functions - (one_to_one,one_to_many). Default set to None.
-        use one_to_one mapping for datasets with one noisy file for each clean file
-        use one_to_many mapping for multiple noisy files for each clean file
-
-
-
-    """
-
-    def __init__(
-        self,
-        name: str,
-        root_dir: str,
-        files: Files,
-        min_valid_minutes=5.0,
-        duration=1.0,
-        stride=None,
-        sampling_rate=48000,
-        matching_function=None,
-        batch_size=32,
-        num_workers: Optional[int] = None,
-        augmentations: Optional[Compose] = None,
-    ):
-
-        super().__init__(
-            name=name,
-            root_dir=root_dir,
-            files=files,
-            min_valid_minutes=min_valid_minutes,
-            sampling_rate=sampling_rate,
-            duration=duration,
-            matching_function=matching_function,
-            batch_size=batch_size,
-            num_workers=num_workers,
-            augmentations=augmentations,
-        )
-
-        self.sampling_rate = sampling_rate
-        self.files = files
-        self.duration = max(1.0, duration)
-        self.audio = Audio(self.sampling_rate, mono=True, return_tensor=True)
-        self.stride = stride or duration
-
-    def setup(self, stage: Optional[str] = None):
-
-        super().setup(stage=stage)
-
-    def train__getitem__(self, idx):
-
-        for filedict, num_samples in self.train_data:
-            if idx >= num_samples:
-                idx -= num_samples
-                continue
-            else:
-                start = 0
-                if self.duration is not None:
-                    start = idx * self.stride
-                return self.prepare_segment(filedict, start)
-
-    def val__getitem__(self, idx):
-        return self.prepare_segment(*self._validation[idx])
-
-    def test__getitem__(self, idx):
-        return self.prepare_segment(*self._test[idx])
-
-    def prepare_segment(self, file_dict: dict, start_time: float):
-        clean_segment = self.audio(
-            file_dict["clean"], offset=start_time, duration=self.duration
-        )
-        noisy_segment = self.audio(
-            file_dict["noisy"], offset=start_time, duration=self.duration
-        )
-        clean_segment = F.pad(
-            clean_segment,
-            (
-                0,
-                int(
-                    self.duration * self.sampling_rate - clean_segment.shape[-1]
-                ),
-            ),
-        )
-        noisy_segment = F.pad(
-            noisy_segment,
-            (
-                0,
-                int(
-                    self.duration * self.sampling_rate - noisy_segment.shape[-1]
-                ),
-            ),
-        )
-        return {
-            "clean": clean_segment,
-            "noisy": noisy_segment,
-        }
-
-    def train__len__(self):
-        _, num_examples = list(zip(*self.train_data))
-        return sum(num_examples)
-
-    def val__len__(self):
-        return len(self._validation)
-
-    def test__len__(self):
-        return len(self._test)
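The deleted `get_num_segments` carries the chunking arithmetic everything else in this file relies on: a file of `file_duration` seconds yields `ceil((file_duration - duration) / stride) + 1` overlapping windows, with a floor of one window for short files. A quick standalone check of that formula:

    import math

    def get_num_segments(file_duration, duration, stride):
        # mirrors the deleted staticmethod
        if file_duration < duration:
            return 1
        return math.ceil((file_duration - duration) / stride) + 1

    assert get_num_segments(10.0, 2.0, 1.0) == 9  # windows start at 0, 1, ..., 8 s
    assert get_num_segments(4.5, 4.5, 2.0) == 1   # exactly one full window
    assert get_num_segments(1.0, 2.0, 1.0) == 1   # short file still yields one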
@@ -1,3 +0,0 @@
-from mayavoz.models.demucs import Demucs
-from mayavoz.models.model import Mayamodel
-from mayavoz.models.waveunet import WaveUnet

@@ -1,5 +0,0 @@
-from mayavoz.models.complexnn.conv import ComplexConv2d  # noqa
-from mayavoz.models.complexnn.conv import ComplexConvTranspose2d  # noqa
-from mayavoz.models.complexnn.rnn import ComplexLSTM  # noqa
-from mayavoz.models.complexnn.utils import ComplexBatchNorm2D  # noqa
-from mayavoz.models.complexnn.utils import ComplexRelu  # noqa
@@ -1,136 +0,0 @@
-from typing import Tuple
-
-import torch
-import torch.nn.functional as F
-from torch import nn
-
-
-def init_weights(nnet):
-    nn.init.xavier_normal_(nnet.weight.data)
-    nn.init.constant_(nnet.bias, 0.0)
-    return nnet
-
-
-class ComplexConv2d(nn.Module):
-    def __init__(
-        self,
-        in_channels: int,
-        out_channels: int,
-        kernel_size: Tuple[int, int] = (1, 1),
-        stride: Tuple[int, int] = (1, 1),
-        padding: Tuple[int, int] = (0, 0),
-        groups: int = 1,
-        dilation: int = 1,
-    ):
-        """
-        Complex Conv2d (non-causal)
-        """
-        super().__init__()
-        self.in_channels = in_channels // 2
-        self.out_channels = out_channels // 2
-        self.kernel_size = kernel_size
-        self.stride = stride
-        self.padding = padding
-        self.groups = groups
-        self.dilation = dilation
-
-        self.real_conv = nn.Conv2d(
-            self.in_channels,
-            self.out_channels,
-            kernel_size=self.kernel_size,
-            stride=self.stride,
-            padding=(self.padding[0], 0),
-            groups=self.groups,
-            dilation=self.dilation,
-        )
-        self.imag_conv = nn.Conv2d(
-            self.in_channels,
-            self.out_channels,
-            kernel_size=self.kernel_size,
-            stride=self.stride,
-            padding=(self.padding[0], 0),
-            groups=self.groups,
-            dilation=self.dilation,
-        )
-        self.imag_conv = init_weights(self.imag_conv)
-        self.real_conv = init_weights(self.real_conv)
-
-    def forward(self, input):
-        """
-        complex axis should be always 1 dim
-        """
-        input = F.pad(input, [self.padding[1], 0, 0, 0])
-
-        real, imag = torch.chunk(input, 2, 1)
-
-        real_real = self.real_conv(real)
-        real_imag = self.imag_conv(real)
-
-        imag_imag = self.imag_conv(imag)
-        imag_real = self.real_conv(imag)
-
-        real = real_real - imag_imag
-        imag = real_imag - imag_real
-
-        out = torch.cat([real, imag], 1)
-        return out
-
-
-class ComplexConvTranspose2d(nn.Module):
-    def __init__(
-        self,
-        in_channels: int,
-        out_channels: int,
-        kernel_size: Tuple[int, int] = (1, 1),
-        stride: Tuple[int, int] = (1, 1),
-        padding: Tuple[int, int] = (0, 0),
-        output_padding: Tuple[int, int] = (0, 0),
-        groups: int = 1,
-    ):
-        super().__init__()
-        self.in_channels = in_channels // 2
-        self.out_channels = out_channels // 2
-        self.kernel_size = kernel_size
-        self.stride = stride
-        self.padding = padding
-        self.groups = groups
-        self.output_padding = output_padding
-
-        self.real_conv = nn.ConvTranspose2d(
-            self.in_channels,
-            self.out_channels,
-            kernel_size=self.kernel_size,
-            stride=self.stride,
-            padding=self.padding,
-            output_padding=self.output_padding,
-            groups=self.groups,
-        )
-
-        self.imag_conv = nn.ConvTranspose2d(
-            self.in_channels,
-            self.out_channels,
-            kernel_size=self.kernel_size,
-            stride=self.stride,
-            padding=self.padding,
-            output_padding=self.output_padding,
-            groups=self.groups,
-        )
-
-        self.real_conv = init_weights(self.real_conv)
-        self.imag_conv = init_weights(self.imag_conv)
-
-    def forward(self, input):
-
-        real, imag = torch.chunk(input, 2, 1)
-        real_real = self.real_conv(real)
-        real_imag = self.imag_conv(real)
-
-        imag_imag = self.imag_conv(imag)
-        imag_real = self.real_conv(imag)
-
-        real = real_real - imag_imag
-        imag = real_imag + imag_real
-
-        out = torch.cat([real, imag], 1)
-
-        return out
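Both deleted layers realise a complex weight acting on a complex input through four real convolutions, following the standard product rule (general complex arithmetic, not repository-specific):

    (W_r + i W_i)(x_r + i x_i) = (W_r x_r - W_i x_i) + i\,(W_r x_i + W_i x_r)

Measured against that identity, `ComplexConvTranspose2d` combines the cross terms with a plus, while `ComplexConv2d` computes `imag = real_imag - imag_real`, i.e. W_i x_r - W_r x_i; if that minus sign was intentional, nothing in the file says why.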
@@ -1,68 +0,0 @@
-from typing import List, Optional
-
-import torch
-from torch import nn
-
-
-class ComplexLSTM(nn.Module):
-    def __init__(
-        self,
-        input_size: int,
-        hidden_size: int,
-        num_layers: int = 1,
-        projection_size: Optional[int] = None,
-        bidirectional: bool = False,
-    ):
-        super().__init__()
-        self.input_size = input_size // 2
-        self.hidden_size = hidden_size // 2
-        self.num_layers = num_layers
-
-        self.real_lstm = nn.LSTM(
-            self.input_size,
-            self.hidden_size,
-            self.num_layers,
-            bidirectional=bidirectional,
-            batch_first=False,
-        )
-        self.imag_lstm = nn.LSTM(
-            self.input_size,
-            self.hidden_size,
-            self.num_layers,
-            bidirectional=bidirectional,
-            batch_first=False,
-        )
-
-        bidirectional = 2 if bidirectional else 1
-        if projection_size is not None:
-            self.projection_size = projection_size // 2
-            self.real_linear = nn.Linear(
-                self.hidden_size * bidirectional, self.projection_size
-            )
-            self.imag_linear = nn.Linear(
-                self.hidden_size * bidirectional, self.projection_size
-            )
-        else:
-            self.projection_size = None
-
-    def forward(self, input):
-
-        if isinstance(input, List):
-            real, imag = input
-        else:
-            real, imag = torch.chunk(input, 2, 1)
-
-        real_real = self.real_lstm(real)[0]
-        real_imag = self.imag_lstm(real)[0]
-
-        imag_imag = self.imag_lstm(imag)[0]
-        imag_real = self.real_lstm(imag)[0]
-
-        real = real_real - imag_imag
-        imag = imag_real + real_imag
-
-        if self.projection_size is not None:
-            real = self.real_linear(real)
-            imag = self.imag_linear(imag)
-
-        return [real, imag]
@@ -1,199 +0,0 @@
-import torch
-from torch import nn
-
-
-class ComplexBatchNorm2D(nn.Module):
-    def __init__(
-        self,
-        num_features: int,
-        eps: float = 1e-5,
-        momentum: float = 0.1,
-        affine: bool = True,
-        track_running_stats: bool = True,
-    ):
-        """
-        Complex batch normalization 2D
-        https://arxiv.org/abs/1705.09792
-
-
-        """
-        super().__init__()
-        self.num_features = num_features // 2
-        self.affine = affine
-        self.momentum = momentum
-        self.track_running_stats = track_running_stats
-        self.eps = eps
-
-        if self.affine:
-            self.Wrr = nn.parameter.Parameter(torch.Tensor(self.num_features))
-            self.Wri = nn.parameter.Parameter(torch.Tensor(self.num_features))
-            self.Wii = nn.parameter.Parameter(torch.Tensor(self.num_features))
-            self.Br = nn.parameter.Parameter(torch.Tensor(self.num_features))
-            self.Bi = nn.parameter.Parameter(torch.Tensor(self.num_features))
-        else:
-            self.register_parameter("Wrr", None)
-            self.register_parameter("Wri", None)
-            self.register_parameter("Wii", None)
-            self.register_parameter("Br", None)
-            self.register_parameter("Bi", None)
-
-        if self.track_running_stats:
-            values = torch.zeros(self.num_features)
-            self.register_buffer("Mean_real", values)
-            self.register_buffer("Mean_imag", values)
-            self.register_buffer("Var_rr", values)
-            self.register_buffer("Var_ri", values)
-            self.register_buffer("Var_ii", values)
-            self.register_buffer(
-                "num_batches_tracked", torch.tensor(0, dtype=torch.long)
-            )
-        else:
-            self.register_parameter("Mean_real", None)
-            self.register_parameter("Mean_imag", None)
-            self.register_parameter("Var_rr", None)
-            self.register_parameter("Var_ri", None)
-            self.register_parameter("Var_ii", None)
-            self.register_parameter("num_batches_tracked", None)
-
-        self.reset_parameters()
-
-    def reset_parameters(self):
-        if self.affine:
-            self.Wrr.data.fill_(1)
-            self.Wii.data.fill_(1)
-            self.Wri.data.uniform_(-0.9, 0.9)
-            self.Br.data.fill_(0)
-            self.Bi.data.fill_(0)
-        self.reset_running_stats()
-
-    def reset_running_stats(self):
-        if self.track_running_stats:
-            self.Mean_real.zero_()
-            self.Mean_imag.zero_()
-            self.Var_rr.fill_(1)
-            self.Var_ri.zero_()
-            self.Var_ii.fill_(1)
-            self.num_batches_tracked.zero_()
-
-    def extra_repr(self):
-        return "{num_features}, eps={eps}, momentum={momentum}, affine={affine}, track_running_stats={track_running_stats}".format(
-            **self.__dict__
-        )
-
-    def forward(self, input):
-
-        real, imag = torch.chunk(input, 2, 1)
-        exp_avg_factor = 0.0
-
-        training = self.training and self.track_running_stats
-        if training:
-            self.num_batches_tracked += 1
-            if self.momentum is None:
-                exp_avg_factor = 1 / self.num_batches_tracked
-            else:
-                exp_avg_factor = self.momentum
-
-        redux = [i for i in reversed(range(real.dim())) if i != 1]
-        vdim = [1] * real.dim()
-        vdim[1] = real.size(1)
-
-        if training:
-            batch_mean_real, batch_mean_imag = real, imag
-            for dim in redux:
-                batch_mean_real = batch_mean_real.mean(dim, keepdim=True)
-                batch_mean_imag = batch_mean_imag.mean(dim, keepdim=True)
-            if self.track_running_stats:
-                self.Mean_real.lerp_(batch_mean_real.squeeze(), exp_avg_factor)
-                self.Mean_imag.lerp_(batch_mean_imag.squeeze(), exp_avg_factor)
-
-        else:
-            batch_mean_real = self.Mean_real.view(vdim)
-            batch_mean_imag = self.Mean_imag.view(vdim)
-
-        real = real - batch_mean_real
-        imag = imag - batch_mean_imag
-
-        if training:
-            batch_var_rr = real * real
-            batch_var_ri = real * imag
-            batch_var_ii = imag * imag
-            for dim in redux:
-                batch_var_rr = batch_var_rr.mean(dim, keepdim=True)
-                batch_var_ri = batch_var_ri.mean(dim, keepdim=True)
-                batch_var_ii = batch_var_ii.mean(dim, keepdim=True)
-            if self.track_running_stats:
-                self.Var_rr.lerp_(batch_var_rr.squeeze(), exp_avg_factor)
-                self.Var_ri.lerp_(batch_var_ri.squeeze(), exp_avg_factor)
-                self.Var_ii.lerp_(batch_var_ii.squeeze(), exp_avg_factor)
-        else:
-            batch_var_rr = self.Var_rr.view(vdim)
-            batch_var_ii = self.Var_ii.view(vdim)
-            batch_var_ri = self.Var_ri.view(vdim)
-
-        batch_var_rr += self.eps
-        batch_var_ii += self.eps
-
-        # Covariance matrics
-        # | batch_var_rr batch_var_ri |
-        # | batch_var_ir batch_var_ii |   here batch_var_ir == batch_var_ri
-        # Inverse square root of cov matrix by combining below two formulas
-        # https://en.wikipedia.org/wiki/Square_root_of_a_2_by_2_matrix
-        # https://mathworld.wolfram.com/MatrixInverse.html
-
-        tau = batch_var_rr + batch_var_ii
-        s = batch_var_rr * batch_var_ii - batch_var_ri * batch_var_ri
-        t = (tau + 2 * s).sqrt()
-
-        rst = (s * t).reciprocal()
-        Urr = (batch_var_ii + s) * rst
-        Uri = -batch_var_ri * rst
-        Uii = (batch_var_rr + s) * rst
-
-        if self.affine:
-            Wrr, Wri, Wii = (
-                self.Wrr.view(vdim),
-                self.Wri.view(vdim),
-                self.Wii.view(vdim),
-            )
-            Zrr = (Wrr * Urr) + (Wri * Uri)
-            Zri = (Wrr * Uri) + (Wri * Uii)
-            Zir = (Wii * Uri) + (Wri * Urr)
-            Zii = (Wri * Uri) + (Wii * Uii)
-        else:
-            Zrr, Zri, Zir, Zii = Urr, Uri, Uri, Uii
-
-        yr = (Zrr * real) + (Zri * imag)
-        yi = (Zir * real) + (Zii * imag)
-
-        if self.affine:
-            yr = yr + self.Br.view(vdim)
-            yi = yi + self.Bi.view(vdim)
-
-        outputs = torch.cat([yr, yi], 1)
-        return outputs
-
-
-class ComplexRelu(nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.real_relu = nn.PReLU()
-        self.imag_relu = nn.PReLU()
-
-    def forward(self, input):
-
-        real, imag = torch.chunk(input, 2, 1)
-        real = self.real_relu(real)
-        imag = self.imag_relu(imag)
-        return torch.cat([real, imag], dim=1)
-
-
-def complex_cat(inputs, axis=1):
-
-    real, imag = [], []
-    for data in inputs:
-        real_data, imag_data = torch.chunk(data, 2, axis)
-        real.append(real_data)
-        imag.append(imag_data)
-    real = torch.cat(real, axis)
-    imag = torch.cat(imag, axis)
-    return torch.cat([real, imag], axis)
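The comment block inside `forward` compresses the derivation it links to, so here it is spelled out; this is standard 2x2 matrix algebra, with symbols matching the code. For the centred covariance

    V = \begin{pmatrix} V_{rr} & V_{ri} \\ V_{ri} & V_{ii} \end{pmatrix},
    \qquad s = \sqrt{\det V}, \qquad t = \sqrt{\operatorname{tr} V + 2s},

the whitening matrix is

    V^{-1/2} = \frac{1}{s\,t} \begin{pmatrix} V_{ii} + s & -V_{ri} \\ -V_{ri} & V_{rr} + s \end{pmatrix},

which is exactly the `rst`, `Urr`, `Uri`, `Uii` computation above. Against this reference, the deleted code assigns `s` the determinant itself (`batch_var_rr * batch_var_ii - batch_var_ri * batch_var_ri`) rather than its square root; whether that deviation was deliberate is not visible from the diff.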
@@ -1,338 +0,0 @@
-import warnings
-from typing import Any, List, Optional, Tuple, Union
-
-import torch
-import torch.nn.functional as F
-from torch import nn
-
-from mayavoz.data import MayaDataset
-from mayavoz.models import Mayamodel
-from mayavoz.models.complexnn import (
-    ComplexBatchNorm2D,
-    ComplexConv2d,
-    ComplexConvTranspose2d,
-    ComplexLSTM,
-    ComplexRelu,
-)
-from mayavoz.models.complexnn.utils import complex_cat
-from mayavoz.utils.transforms import ConviSTFT, ConvSTFT
-from mayavoz.utils.utils import merge_dict
-
-
-class DCCRN_ENCODER(nn.Module):
-    def __init__(
-        self,
-        in_channels: int,
-        out_channel: int,
-        kernel_size: Tuple[int, int],
-        complex_norm: bool = True,
-        complex_relu: bool = True,
-        stride: Tuple[int, int] = (2, 1),
-        padding: Tuple[int, int] = (2, 1),
-    ):
-        super().__init__()
-        batchnorm = ComplexBatchNorm2D if complex_norm else nn.BatchNorm2d
-        activation = ComplexRelu() if complex_relu else nn.PReLU()
-
-        self.encoder = nn.Sequential(
-            ComplexConv2d(
-                in_channels,
-                out_channel,
-                kernel_size=kernel_size,
-                stride=stride,
-                padding=padding,
-            ),
-            batchnorm(out_channel),
-            activation,
-        )
-
-    def forward(self, waveform):
-
-        return self.encoder(waveform)
-
-
-class DCCRN_DECODER(nn.Module):
-    def __init__(
-        self,
-        in_channels: int,
-        out_channels: int,
-        kernel_size: Tuple[int, int],
-        layer: int = 0,
-        complex_norm: bool = True,
-        complex_relu: bool = True,
-        stride: Tuple[int, int] = (2, 1),
-        padding: Tuple[int, int] = (2, 0),
-        output_padding: Tuple[int, int] = (1, 0),
-    ):
-        super().__init__()
-        batchnorm = ComplexBatchNorm2D if complex_norm else nn.BatchNorm2d
-        activation = ComplexRelu() if complex_relu else nn.PReLU()
-
-        if layer != 0:
-            self.decoder = nn.Sequential(
-                ComplexConvTranspose2d(
-                    in_channels,
-                    out_channels,
-                    kernel_size=kernel_size,
-                    stride=stride,
-                    padding=padding,
-                    output_padding=output_padding,
-                ),
-                batchnorm(out_channels),
-                activation,
-            )
-        else:
-            self.decoder = nn.Sequential(
-                ComplexConvTranspose2d(
-                    in_channels,
-                    out_channels,
-                    kernel_size=kernel_size,
-                    stride=stride,
-                    padding=padding,
-                    output_padding=output_padding,
-                )
-            )
-
-    def forward(self, waveform):
-
-        return self.decoder(waveform)
-
-
-class DCCRN(Mayamodel):
-
-    STFT_DEFAULTS = {
-        "window_len": 400,
-        "hop_size": 100,
-        "nfft": 512,
-        "window": "hamming",
-    }
-
-    ED_DEFAULTS = {
-        "initial_output_channels": 32,
-        "depth": 6,
-        "kernel_size": 5,
-        "growth_factor": 2,
-        "stride": 2,
-        "padding": 2,
-        "output_padding": 1,
-    }
-
-    LSTM_DEFAULTS = {
-        "num_layers": 2,
-        "hidden_size": 256,
-    }
-
-    def __init__(
-        self,
-        stft: Optional[dict] = None,
-        encoder_decoder: Optional[dict] = None,
-        lstm: Optional[dict] = None,
-        complex_lstm: bool = True,
-        complex_norm: bool = True,
-        complex_relu: bool = True,
-        masking_mode: str = "E",
-        num_channels: int = 1,
-        sampling_rate=16000,
-        lr: float = 1e-3,
-        dataset: Optional[MayaDataset] = None,
-        duration: Optional[float] = None,
-        loss: Union[str, List, Any] = "mse",
-        metric: Union[str, List] = "mse",
-    ):
-        duration = (
-            dataset.duration if isinstance(dataset, MayaDataset) else duration
-        )
-        if dataset is not None:
-            if sampling_rate != dataset.sampling_rate:
-                warnings.warn(
-                    f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
-                )
-                sampling_rate = dataset.sampling_rate
-        super().__init__(
-            num_channels=num_channels,
-            sampling_rate=sampling_rate,
-            lr=lr,
-            dataset=dataset,
-            duration=duration,
-            loss=loss,
-            metric=metric,
-        )
-
-        encoder_decoder = merge_dict(self.ED_DEFAULTS, encoder_decoder)
-        lstm = merge_dict(self.LSTM_DEFAULTS, lstm)
-        stft = merge_dict(self.STFT_DEFAULTS, stft)
-        self.save_hyperparameters(
-            "encoder_decoder",
-            "lstm",
-            "stft",
-            "complex_lstm",
-            "complex_norm",
-            "masking_mode",
-        )
-        self.complex_lstm = complex_lstm
-        self.complex_norm = complex_norm
-        self.masking_mode = masking_mode
-
-        self.stft = ConvSTFT(
-            stft["window_len"], stft["hop_size"], stft["nfft"], stft["window"]
-        )
-        self.istft = ConviSTFT(
-            stft["window_len"], stft["hop_size"], stft["nfft"], stft["window"]
-        )
-
-        self.encoder = nn.ModuleList()
-        self.decoder = nn.ModuleList()
-
-        num_channels *= 2
-        hidden_size = encoder_decoder["initial_output_channels"]
-        growth_factor = 2
-
-        for layer in range(encoder_decoder["depth"]):
-
-            encoder_ = DCCRN_ENCODER(
-                num_channels,
-                hidden_size,
-                kernel_size=(encoder_decoder["kernel_size"], 2),
-                stride=(encoder_decoder["stride"], 1),
-                padding=(encoder_decoder["padding"], 1),
-                complex_norm=complex_norm,
-                complex_relu=complex_relu,
-            )
-            self.encoder.append(encoder_)
-
-            decoder_ = DCCRN_DECODER(
-                hidden_size + hidden_size,
-                num_channels,
-                layer=layer,
-                kernel_size=(encoder_decoder["kernel_size"], 2),
-                stride=(encoder_decoder["stride"], 1),
-                padding=(encoder_decoder["padding"], 0),
-                output_padding=(encoder_decoder["output_padding"], 0),
-                complex_norm=complex_norm,
-                complex_relu=complex_relu,
-            )
-
-            self.decoder.insert(0, decoder_)
-
-            if layer < encoder_decoder["depth"] - 3:
-                num_channels = hidden_size
-                hidden_size *= growth_factor
-            else:
-                num_channels = hidden_size
-
-        kernel_size = hidden_size / 2
-        hidden_size = stft["nfft"] / 2 ** (encoder_decoder["depth"])
-
-        if self.complex_lstm:
-            lstms = []
-            for layer in range(lstm["num_layers"]):
-
-                if layer == 0:
-                    input_size = int(hidden_size * kernel_size)
-                else:
-                    input_size = lstm["hidden_size"]
-
-                if layer == lstm["num_layers"] - 1:
-                    projection_size = int(hidden_size * kernel_size)
-                else:
-                    projection_size = None
-
-                kwargs = {
-                    "input_size": input_size,
-                    "hidden_size": lstm["hidden_size"],
-                    "num_layers": 1,
-                }
-
-                lstms.append(
-                    ComplexLSTM(projection_size=projection_size, **kwargs)
-                )
-            self.lstm = nn.Sequential(*lstms)
-        else:
-            self.lstm = nn.Sequential(
-                nn.LSTM(
-                    input_size=hidden_size * kernel_size,
-                    hidden_sizs=lstm["hidden_size"],
-                    num_layers=lstm["num_layers"],
-                    dropout=0.0,
-                    batch_first=False,
-                )[0],
-                nn.Linear(lstm["hidden"], hidden_size * kernel_size),
-            )
-
-    def forward(self, waveform):
-
-        if waveform.dim() == 2:
-            waveform = waveform.unsqueeze(1)
-
-        if waveform.size(1) != self.hparams.num_channels:
-            raise ValueError(
-                f"Number of input channels initialized is {self.hparams.num_channels} but got {waveform.size(1)} channels"
-            )
-
-        waveform_stft = self.stft(waveform)
-        real = waveform_stft[:, : self.stft.nfft // 2 + 1]
-        imag = waveform_stft[:, self.stft.nfft // 2 + 1 :]
-
-        mag_spec = torch.sqrt(real**2 + imag**2 + 1e-9)
-        phase_spec = torch.atan2(imag, real)
-        complex_spec = torch.stack([mag_spec, phase_spec], 1)[:, :, 1:]
-
-        encoder_outputs = []
-        out = complex_spec
-        for _, encoder in enumerate(self.encoder):
-            out = encoder(out)
-            encoder_outputs.append(out)
-
-        B, C, D, T = out.size()
-        out = out.permute(3, 0, 1, 2)
-        if self.complex_lstm:
-
-            lstm_real = out[:, :, : C // 2]
-            lstm_imag = out[:, :, C // 2 :]
-            lstm_real = lstm_real.reshape(T, B, C // 2 * D)
-            lstm_imag = lstm_imag.reshape(T, B, C // 2 * D)
-            lstm_real, lstm_imag = self.lstm([lstm_real, lstm_imag])
-            lstm_real = lstm_real.reshape(T, B, C // 2, D)
-            lstm_imag = lstm_imag.reshape(T, B, C // 2, D)
-            out = torch.cat([lstm_real, lstm_imag], 2)
-        else:
-            out = out.reshape(T, B, C * D)
-            out = self.lstm(out)
-            out = out.reshape(T, B, D, C)
-
-        out = out.permute(1, 2, 3, 0)
-        for layer, decoder in enumerate(self.decoder):
-            skip_connection = encoder_outputs.pop(-1)
-            out = complex_cat([skip_connection, out])
-            out = decoder(out)
-            out = out[..., 1:]
-        mask_real, mask_imag = out[:, 0], out[:, 1]
-        mask_real = F.pad(mask_real, [0, 0, 1, 0])
-        mask_imag = F.pad(mask_imag, [0, 0, 1, 0])
-        if self.masking_mode == "E":
-
-            mask_mag = torch.sqrt(mask_real**2 + mask_imag**2)
-            real_phase = mask_real / (mask_mag + 1e-8)
-            imag_phase = mask_imag / (mask_mag + 1e-8)
-            mask_phase = torch.atan2(imag_phase, real_phase)
-            mask_mag = torch.tanh(mask_mag)
-            est_mag = mask_mag * mag_spec
-            est_phase = mask_phase * phase_spec
-            # cos(theta) + isin(theta)
-            real = est_mag + torch.cos(est_phase)
-            imag = est_mag + torch.sin(est_phase)
-
-        if self.masking_mode == "C":
-
-            real = real * mask_real - imag * mask_imag
-            imag = real * mask_imag + imag * mask_real
-
-        else:
-
-            real = real * mask_real
-            imag = imag * mask_imag
-
-        spec = torch.cat([real, imag], 1)
-        wav = self.istft(spec)
-        wav = wav.clamp_(-1, 1)
-        return wav
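In the `masking_mode == "E"` branch, the comment `# cos(theta) + isin(theta)` points at a polar-to-cartesian step, whose textbook form is

    r\,e^{i\theta} = r\cos\theta + i\,r\sin\theta,

i.e. magnitude times cosine and sine; the deleted code adds instead (`real = est_mag + torch.cos(est_phase)`), which does not match its own comment. Note also that the `"C"` check and its `else` form a separate conditional from the `"E"` check, so when `masking_mode == "E"` the trailing `else` still runs and re-masks `real`/`imag`. Both observations are from the diff alone.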
@@ -1,3 +0,0 @@
-from mayavoz.utils.config import Files
-from mayavoz.utils.io import Audio
-from mayavoz.utils.utils import check_files

@@ -1,93 +0,0 @@
-from typing import Optional
-
-import numpy as np
-import torch
-import torch.nn.functional as F
-from scipy.signal import get_window
-from torch import nn
-
-
-class ConvFFT(nn.Module):
-    def __init__(
-        self,
-        window_len: int,
-        nfft: Optional[int] = None,
-        window: str = "hamming",
-    ):
-        super().__init__()
-        self.window_len = window_len
-        self.nfft = nfft if nfft else np.int(2 ** np.ceil(np.log2(window_len)))
-        self.window = torch.from_numpy(
-            get_window(window, window_len, fftbins=True).astype("float32")
-        )
-
-    def init_kernel(self, inverse=False):
-
-        fourier_basis = np.fft.rfft(np.eye(self.nfft))[: self.window_len]
-        real, imag = np.real(fourier_basis), np.imag(fourier_basis)
-        kernel = np.concatenate([real, imag], 1).T
-        if inverse:
-            kernel = np.linalg.pinv(kernel).T
-        kernel = torch.from_numpy(kernel.astype("float32")).unsqueeze(1)
-        kernel *= self.window
-        return kernel
-
-
-class ConvSTFT(ConvFFT):
-    def __init__(
-        self,
-        window_len: int,
-        hop_size: Optional[int] = None,
-        nfft: Optional[int] = None,
-        window: str = "hamming",
-    ):
-        super().__init__(window_len=window_len, nfft=nfft, window=window)
-        self.hop_size = hop_size if hop_size else window_len // 2
-        self.register_buffer("weight", self.init_kernel())
-
-    def forward(self, input):
-
-        if input.dim() < 2:
-            raise ValueError(
-                f"Expected signal with shape 2 or 3 got {input.dim()}"
-            )
-        elif input.dim() == 2:
-            input = input.unsqueeze(1)
-        else:
-            pass
-        input = F.pad(
-            input,
-            (self.window_len - self.hop_size, self.window_len - self.hop_size),
-        )
-        output = F.conv1d(input, self.weight, stride=self.hop_size)
-
-        return output
-
-
-class ConviSTFT(ConvFFT):
-    def __init__(
-        self,
-        window_len: int,
-        hop_size: Optional[int] = None,
-        nfft: Optional[int] = None,
-        window: str = "hamming",
-    ):
-        super().__init__(window_len=window_len, nfft=nfft, window=window)
-        self.hop_size = hop_size if hop_size else window_len // 2
-        self.register_buffer("weight", self.init_kernel(True))
-        self.register_buffer("enframe", torch.eye(window_len).unsqueeze(1))
-
-    def forward(self, input, phase=None):
-
-        if phase is not None:
-            real = input * torch.cos(phase)
-            imag = input * torch.sin(phase)
-            input = torch.cat([real, imag], 1)
-        out = F.conv_transpose1d(input, self.weight, stride=self.hop_size)
-        coeff = self.window.unsqueeze(1).repeat(1, 1, input.size(-1)) ** 2
-        coeff = coeff.to(input.device)
-        coeff = F.conv_transpose1d(coeff, self.enframe, stride=self.hop_size)
-        out = out / (coeff + 1e-8)
-        pad = self.window_len - self.hop_size
-        out = out[..., pad:-pad]
-        return out
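These transforms express the STFT as a strided `conv1d` against a fixed real/imaginary Fourier basis (and the inverse as a `conv_transpose1d` with overlap-add normalisation), which keeps the front end differentiable. A minimal round-trip sketch using the classes as defined above; the shapes are my reading of the kernel construction, not documented in the diff:

    import torch

    # Defaults mirror the DCCRN stft config deleted earlier in this diff.
    stft = ConvSTFT(window_len=400, hop_size=100, nfft=512, window="hamming")
    istft = ConviSTFT(window_len=400, hop_size=100, nfft=512, window="hamming")

    x = torch.randn(1, 16000)  # (batch, samples): 1 s at 16 kHz
    spec = stft(x)             # (batch, nfft + 2, frames): real block stacked over imag
    y = istft(spec)            # approximate reconstruction of x
    print(spec.shape, y.shape)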
@@ -0,0 +1,30 @@
+# Configuration for generating Noisy Speech Dataset
+
+# - sampling_rate: Specify the sampling rate. Default is 16 kHz
+# - audioformat: default is .wav
+# - audio_length: Minimum Length of each audio clip (noisy and clean speech) in seconds that will be generated by augmenting utterances.
+# - silence_length: Duration of silence introduced between clean speech utterances.
+# - total_hours: Total number of hours of data required. Units are in hours.
+# - snr_lower: Lower bound for SNR required (default: 0 dB)
+# - snr_upper: Upper bound for SNR required (default: 40 dB)
+# - total_snrlevels: Number of SNR levels required (default: 5, which means there are 5 levels between snr_lower and snr_upper)
+# - noise_dir: Default is None. But specify the noise directory path if noise files are not in the source directory
+# - Speech_dir: Default is None. But specify the speech directory path if speech files are not in the source directory
+# - noise_types_excluded: Noise files starting with the following tags to be excluded in the noise list. Example: noise_types_excluded: Babble, AirConditioner
+#   Specify 'None' if no noise files to be excluded.
+
+[noisy_speech]
+
+sampling_rate: 16000
+audioformat: *.wav
+audio_length: 10
+silence_length: 0.2
+total_hours: 1
+snr_lower: 0
+snr_upper: 40
+total_snrlevels: 2
+naming: test
+
+noise_dir: /scratch/c.sistc3/MS-SNSD/noise_test
+speech_dir: /scratch/c.sistc3/MS-SNSD/clean_test
+noise_types_excluded: None
@ -0,0 +1,155 @@
"""
@author: chkarada
"""
import argparse
import configparser as CP
import glob
import os

import numpy as np

from audiolib import audioread, audiowrite, snr_mixer


def main(cfg):
    snr_lower = float(cfg["snr_lower"])
    snr_upper = float(cfg["snr_upper"])
    total_snrlevels = int(cfg["total_snrlevels"])

    clean_dir = os.path.join(os.path.dirname(__file__), "clean_train")
    if cfg["speech_dir"] != "None":
        clean_dir = cfg["speech_dir"]
    if not os.path.exists(clean_dir):
        raise FileNotFoundError("Clean speech data is required")

    noise_dir = os.path.join(os.path.dirname(__file__), "noise_train")
    if cfg["noise_dir"] != "None":
        noise_dir = cfg["noise_dir"]
    if not os.path.exists(noise_dir):
        raise FileNotFoundError("Noise data is required")

    name = cfg["naming"]
    fs = float(cfg["sampling_rate"])
    audioformat = cfg["audioformat"]
    total_hours = float(cfg["total_hours"])
    audio_length = float(cfg["audio_length"])
    silence_length = float(cfg["silence_length"])
    noisyspeech_dir = os.path.join(
        os.path.dirname(__file__), f"NoisySpeech_{name}ing"
    )
    if not os.path.exists(noisyspeech_dir):
        os.makedirs(noisyspeech_dir)
    clean_proc_dir = os.path.join(
        os.path.dirname(__file__), f"CleanSpeech_{name}ing"
    )
    if not os.path.exists(clean_proc_dir):
        os.makedirs(clean_proc_dir)
    noise_proc_dir = os.path.join(
        os.path.dirname(__file__), f"NoisySpeech_{name}ing"
    )
    if not os.path.exists(noise_proc_dir):
        os.makedirs(noise_proc_dir)

    total_secs = total_hours * 60 * 60
    total_samples = int(total_secs * fs)
    audio_length = int(audio_length * fs)
    SNR = np.linspace(snr_lower, snr_upper, total_snrlevels)
    cleanfilenames = glob.glob(os.path.join(clean_dir, audioformat))
    if cfg["noise_types_excluded"] == "None":
        noisefilenames = glob.glob(os.path.join(noise_dir, audioformat))
    else:
        filestoexclude = cfg["noise_types_excluded"].split(",")
        noisefilenames = glob.glob(os.path.join(noise_dir, audioformat))
        for i in range(len(filestoexclude)):
            noisefilenames = [
                fn
                for fn in noisefilenames
                if not os.path.basename(fn).startswith(filestoexclude[i])
            ]

    filecounter = 0
    num_samples = 0

    while num_samples < total_samples:
        idx_s = np.random.randint(0, np.size(cleanfilenames))
        clean, fs = audioread(cleanfilenames[idx_s])

        # concatenate clean utterances (separated by short silences) until
        # the clip reaches the requested minimum length
        while len(clean) <= audio_length:
            idx_s = idx_s + 1
            if idx_s >= np.size(cleanfilenames) - 1:
                idx_s = np.random.randint(0, np.size(cleanfilenames))
            newclean, fs = audioread(cleanfilenames[idx_s])
            cleanconcat = np.append(
                clean, np.zeros(int(fs * silence_length))
            )
            clean = np.append(cleanconcat, newclean)

        idx_n = np.random.randint(0, np.size(noisefilenames))
        noise, fs = audioread(noisefilenames[idx_n])

        if len(noise) >= len(clean):
            noise = noise[0 : len(clean)]
        else:
            # likewise grow the noise until it covers the clean clip
            while len(noise) <= len(clean):
                idx_n = idx_n + 1
                if idx_n >= np.size(noisefilenames) - 1:
                    idx_n = np.random.randint(0, np.size(noisefilenames))
                newnoise, fs = audioread(noisefilenames[idx_n])
                noiseconcat = np.append(
                    noise, np.zeros(int(fs * silence_length))
                )
                noise = np.append(noiseconcat, newnoise)
            noise = noise[0 : len(clean)]
        filecounter = filecounter + 1

        for i in range(np.size(SNR)):
            clean_snr, noise_snr, noisy_snr = snr_mixer(
                clean=clean, noise=noise, snr=SNR[i]
            )
            noisyfilename = (
                "noisy"
                + str(filecounter)
                + "_SNRdb_"
                + str(SNR[i])
                + "_clnsp"
                + str(filecounter)
                + ".wav"
            )
            cleanfilename = "clnsp" + str(filecounter) + ".wav"
            noisefilename = (
                "noisy" + str(filecounter) + "_SNRdb_" + str(SNR[i]) + ".wav"
            )
            noisypath = os.path.join(noisyspeech_dir, noisyfilename)
            cleanpath = os.path.join(clean_proc_dir, cleanfilename)
            noisepath = os.path.join(noise_proc_dir, noisefilename)
            audiowrite(noisy_snr, fs, noisypath, norm=False)
            audiowrite(clean_snr, fs, cleanpath, norm=False)
            audiowrite(noise_snr, fs, noisepath, norm=False)
            num_samples = num_samples + len(noisy_snr)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Configurations: read noisyspeech_synthesizer.cfg
    parser.add_argument(
        "--cfg",
        default="noisyspeech_synthesizer.cfg",
        help="Read noisyspeech_synthesizer.cfg for all the details",
    )
    parser.add_argument("--cfg_str", type=str, default="noisy_speech")
    args = parser.parse_args()

    cfgpath = os.path.join(os.path.dirname(__file__), args.cfg)
    assert os.path.exists(cfgpath), f"No configuration file as [{cfgpath}]"
    cfg = CP.ConfigParser()
    cfg._interpolation = CP.ExtendedInterpolation()
    cfg.read(cfgpath)

    main(cfg._sections[args.cfg_str])
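`snr_mixer` is imported from `audiolib`, which is not part of this diff. For orientation, a minimal sketch of what an SNR-based mixer of this kind does; the RMS-based scaling below is an assumption, not the verbatim MS-SNSD implementation:

```python
import numpy as np

def snr_mixer_sketch(clean, noise, snr):
    """Scale `noise` so the clean/noise power ratio equals `snr` dB, then mix."""
    rms_clean = np.sqrt(np.mean(clean**2))
    rms_noise = np.sqrt(np.mean(noise**2)) + 1e-12
    # gain that brings the noise to the target SNR relative to the clean signal
    gain = rms_clean / (10 ** (snr / 20) * rms_noise)
    noise_scaled = gain * noise
    noisy = clean + noise_scaled
    return clean, noise_scaled, noisy
```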
@ -1,338 +0,0 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ccd61d5c",
   "metadata": {},
   "source": [
    "## Custom model training using mayavoz [advanced]\n",
    "\n",
    "In this tutorial, we will cover advanced usage and customization for training your own speech enhancement model.\n",
    "\n",
    " - [Data preparation using MayaDataset](#dataprep)\n",
    " - [Model customization](#modelcustom)\n",
    " - [Callbacks & LR schedulers](#callbacks)\n",
    " - [Model training & testing](#train)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "726c320f",
   "metadata": {},
   "source": [
    "- **Install mayavoz**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c987c799",
   "metadata": {},
   "outputs": [],
   "source": [
    "! pip install -q mayavoz"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8ff9857b",
   "metadata": {},
   "source": [
    "<div id=\"dataprep\"></div>\n",
    "\n",
    "### Data preparation\n",
    "\n",
    "`Files` is a dataclass that wraps the train/test paths together. There is usually one folder each for clean and noisy data. These paths must be relative to a `root_dir` in which all of these directories reside. For example:\n",
    "\n",
    "```\n",
    "- VCTK/\n",
    "    |__ clean_train_wav/\n",
    "    |__ noisy_train_wav/\n",
    "    |__ clean_test_wav/\n",
    "    |__ noisy_test_wav/\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "64cbc0c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from mayavoz.utils import Files\n",
    "files = Files(train_clean=\"clean_train_wav\",\n",
    "              train_noisy=\"noisy_train_wav\",\n",
    "              test_clean=\"clean_test_wav\",\n",
    "              test_noisy=\"noisy_test_wav\")\n",
    "root_dir = \"VCTK\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2d324bd1",
   "metadata": {},
   "source": [
    "- `name`: name of the dataset.\n",
    "- `duration`: duration of each audio instance fed into your model.\n",
    "- `stride`: if set, the stride used to move the sliding window.\n",
    "- `sampling_rate`: desired sampling rate for the audio.\n",
    "- `batch_size`: model batch size.\n",
    "- `min_valid_minutes`: minimum amount of validation data, in minutes. Validation data is automatically selected from the training set (speaker-exclusive).\n",
    "- `matching_function`: there are two types of mapping functions.\n",
    "    - `one_to_one`: each clean file has exactly one corresponding noisy file, for example the Valentini datasets.\n",
    "    - `one_to_many`: each clean file can have multiple corresponding noisy files, for example the MS-SNSD dataset.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "6834941d",
   "metadata": {},
   "outputs": [],
   "source": [
    "name = \"vctk\"\n",
    "duration = 4.5\n",
    "stride = 2.0\n",
    "sampling_rate = 16000\n",
    "min_valid_minutes = 20.0\n",
    "batch_size = 32\n",
    "matching_function = \"one_to_one\"\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d08c6bf8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from mayavoz.dataset import MayaDataset\n",
    "dataset = MayaDataset(\n",
    "    name=name,\n",
    "    root_dir=root_dir,\n",
    "    files=files,\n",
    "    duration=duration,\n",
    "    stride=stride,\n",
    "    sampling_rate=sampling_rate,\n",
    "    batch_size=batch_size,\n",
    "    min_valid_minutes=min_valid_minutes,\n",
    "    matching_function=matching_function\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5b315bde",
   "metadata": {},
   "source": [
    "Now your custom dataloader is ready!"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "01548fe5",
   "metadata": {},
   "source": [
    "<div id=\"modelcustom\"></div>\n",
    "\n",
    "### Model customization\n",
    "Now, this is very easy.\n",
    "\n",
    "- Import the preferred model from `mayavoz.models`. Currently 3 models are implemented:\n",
    "    - `WaveUnet`\n",
    "    - `Demucs`\n",
    "    - `DCCRN`\n",
    "- Every model hyperparameter, such as depth, kernel_size and stride, is under your control: check each model's parameters and pass them as required.\n",
    "- `sampling_rate`: sampling rate (should be equal to the dataset sampling rate).\n",
    "- `dataset`: the mayavoz dataset object prepared earlier.\n",
    "- `loss`: model loss. Multiple loss functions are available.\n",
    "\n",
    "You can pass one (as a string) or more (as a list of strings) of these loss functions as per your requirements. For example, the model will automatically calculate the loss as the average of `mae` and `mse` if you pass `[\"mae\",\"mse\"]`. Available loss functions are `mse`, `mae` and `si-snr`.\n",
    "\n",
    "mayavoz can accept **custom loss functions**. They should be of the form:\n",
    "```\n",
    "class your_custom_loss(nn.Module):\n",
    "    def __init__(self, **kwargs):\n",
    "        super().__init__()\n",
    "        self.higher_better = False  ## loss minimization direction\n",
    "        self.name = \"your_loss_name\"  ## loss name for logging\n",
    "        ...\n",
    "    def forward(self, prediction, target):\n",
    "        loss = ...\n",
    "        return loss\n",
    "```\n",
    "\n",
    "- `metrics`: validation metrics. Available options are `mae`, `mse`, `si-sdr`, `pesq` and `stoi`. One or more can be used.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b36b457c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from mayavoz.models import Demucs\n",
    "model = Demucs(\n",
    "    sampling_rate=16000,\n",
    "    dataset=dataset,\n",
    "    loss=[\"mae\"],\n",
    "    metrics=[\"stoi\",\"pesq\"])\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1523d638",
   "metadata": {},
   "source": [
    "<div id=\"callbacks\"></div>\n",
    "\n",
    "### Learning rate schedulers and callbacks\n",
    "Here I am using `ReduceLROnPlateau`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8de6931c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from types import MethodType\n",
    "\n",
    "from hydra.utils import instantiate\n",
    "from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
    "\n",
    "# `config`, `parameters` and `direction` hold your optimizer config,\n",
    "# hyperparameters and monitor mode (min/max), respectively\n",
    "def configure_optimizers(self):\n",
    "    optimizer = instantiate(\n",
    "        config.optimizer,\n",
    "        lr=parameters.get(\"lr\"),\n",
    "        params=self.parameters(),\n",
    "    )\n",
    "    scheduler = ReduceLROnPlateau(\n",
    "        optimizer=optimizer,\n",
    "        mode=direction,\n",
    "        factor=parameters.get(\"ReduceLr_factor\", 0.1),\n",
    "        verbose=True,\n",
    "        min_lr=parameters.get(\"min_lr\", 1e-6),\n",
    "        patience=parameters.get(\"ReduceLr_patience\", 3),\n",
    "    )\n",
    "    return {\n",
    "        \"optimizer\": optimizer,\n",
    "        \"lr_scheduler\": scheduler,\n",
    "        \"monitor\": f'valid_{parameters.get(\"ReduceLr_monitor\", \"loss\")}',\n",
    "    }\n",
    "\n",
    "\n",
    "model.configure_optimizers = MethodType(configure_optimizers, model)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2f7b5af5",
   "metadata": {},
   "source": [
    "You can use any number of callbacks and pass them directly to the pytorch lightning trainer. Here I am using only `ModelCheckpoint`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f6b62a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pytorch_lightning.callbacks import ModelCheckpoint\n",
    "\n",
    "callbacks = []\n",
    "direction = model.valid_monitor  ## min or max\n",
    "checkpoint = ModelCheckpoint(\n",
    "    dirpath=\"./model\",\n",
    "    filename=\"model_filename\",\n",
    "    monitor=\"valid_loss\",\n",
    "    verbose=False,\n",
    "    mode=direction,\n",
    "    every_n_epochs=1,\n",
    ")\n",
    "callbacks.append(checkpoint)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f3534445",
   "metadata": {},
   "source": [
    "<div id=\"train\"></div>\n",
    "\n",
    "### Train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3dc0348b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pytorch_lightning as pl\n",
    "trainer = pl.Trainer(max_epochs=1, callbacks=callbacks, accelerator=\"gpu\")\n",
    "trainer.fit(model)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "56dcfec1",
   "metadata": {},
   "source": [
    "- Test your model against the test dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63851feb",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.test(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4d3f5350",
   "metadata": {},
   "source": [
    "**Hurray! You have your speech enhancement model trained and tested.**\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10d630e8",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "enhancer",
   "language": "python",
   "name": "enhancer"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
File diff suppressed because one or more lines are too long
@ -2,7 +2,6 @@
 line-length = 80
 target-version = ['py38']
 exclude = '''
-
 (
   /(
       \.eggs  # exclude a few common directories in the
@ -10,6 +9,9 @@ exclude = '''
     | \.mypy_cache
     | \.tox
     | \.venv
+    | noisyspeech_synthesizer.py
+    | noisyspeech_synthesizer.cfg
   )/
 )
 '''
@ -1,120 +0,0 @@
import os
from types import MethodType

import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.callbacks import (
    EarlyStopping,
    LearningRateMonitor,
    ModelCheckpoint,
)
from pytorch_lightning.loggers import MLFlowLogger
from torch.optim.lr_scheduler import ReduceLROnPlateau

# from torch_audiomentations import Compose, Shift

os.environ["HYDRA_FULL_ERROR"] = "1"
JOB_ID = os.environ.get("SLURM_JOBID", "0")


@hydra.main(config_path="train_config", config_name="config")
def train(config: DictConfig):

    OmegaConf.save(config, "config.yaml")

    callbacks = []
    logger = MLFlowLogger(
        experiment_name=config.mlflow.experiment_name,
        run_name=config.mlflow.run_name,
        tags={"JOB_ID": JOB_ID},
    )

    parameters = config.hyperparameters
    # apply_augmentations = Compose(
    #     [
    #         Shift(min_shift=0.5, max_shift=1.0, shift_unit="seconds", p=0.5),
    #     ]
    # )

    dataset = instantiate(config.dataset, augmentations=None)
    model = instantiate(
        config.model,
        dataset=dataset,
        lr=parameters.get("lr"),
        loss=parameters.get("loss"),
        metric=parameters.get("metric"),
    )

    direction = model.valid_monitor
    checkpoint = ModelCheckpoint(
        dirpath="./model",
        filename=f"model_{JOB_ID}",
        monitor="valid_loss",
        verbose=False,
        mode=direction,
        every_n_epochs=1,
    )
    callbacks.append(checkpoint)
    callbacks.append(LearningRateMonitor(logging_interval="epoch"))

    if parameters.get("Early_stop", False):
        early_stopping = EarlyStopping(
            monitor="val_loss",
            mode=direction,
            min_delta=0.0,
            patience=parameters.get("EarlyStopping_patience", 10),
            strict=True,
            verbose=False,
        )
        callbacks.append(early_stopping)

    def configure_optimizers(self):
        optimizer = instantiate(
            config.optimizer,
            lr=parameters.get("lr"),
            params=self.parameters(),
        )
        scheduler = ReduceLROnPlateau(
            optimizer=optimizer,
            mode=direction,
            factor=parameters.get("ReduceLr_factor", 0.1),
            verbose=True,
            min_lr=parameters.get("min_lr", 1e-6),
            patience=parameters.get("ReduceLr_patience", 3),
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": scheduler,
            "monitor": f'valid_{parameters.get("ReduceLr_monitor", "loss")}',
        }

    model.configure_optimizers = MethodType(configure_optimizers, model)

    trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks)
    trainer.fit(model)
    trainer.test(model)

    logger.experiment.log_artifact(
        logger.run_id, f"{trainer.default_root_dir}/config.yaml"
    )

    saved_location = os.path.join(
        trainer.default_root_dir, "model", f"model_{JOB_ID}.ckpt"
    )
    if os.path.isfile(saved_location):
        logger.experiment.log_artifact(logger.run_id, saved_location)
        logger.experiment.log_param(
            logger.run_id,
            "num_train_steps_per_epoch",
            dataset.train__len__() / dataset.batch_size,
        )
        logger.experiment.log_param(
            logger.run_id,
            "num_valid_steps_per_epoch",
            dataset.val__len__() / dataset.batch_size,
        )


if __name__ == "__main__":
    train()
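This deleted trainer swaps in a custom `configure_optimizers` at the instance level via `types.MethodType`. A self-contained sketch of that binding pattern (toy class, not mayavoz code):

```python
from types import MethodType

class Model:
    def configure(self):
        return {"lr": 1e-3}

def custom_configure(self):
    # plain function; `self` is bound to the instance by MethodType
    return {"lr": 3e-4, "scheduler": "ReduceLROnPlateau"}

m = Model()
m.configure = MethodType(custom_configure, m)  # override this instance only
print(m.configure())        # {'lr': 0.0003, 'scheduler': 'ReduceLROnPlateau'}
print(Model().configure())  # other instances keep the class default
```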
@ -1,7 +0,0 @@
defaults:
  - model : Demucs
  - dataset : MS-SNSD
  - optimizer : Adam
  - hyperparameters : default
  - trainer : default
  - mlflow : experiment
@ -1,13 +0,0 @@
_target_: mayavoz.data.dataset.MayaDataset
name : MS-SNSD
root_dir : /Users/shahules/Myprojects/MS-SNSD
duration : 1.5
stride : 1
sampling_rate: 16000
batch_size: 32
min_valid_minutes: 25
files:
  train_clean : CleanSpeech_training
  test_clean : CleanSpeech_training
  train_noisy : NoisySpeech_training
  test_noisy : NoisySpeech_training
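The `_target_` key in configs like this one is resolved by `hydra.utils.instantiate`, which imports the dotted path and calls it with the remaining keys as keyword arguments (plus any overrides passed at call time, as `train.py` does with `augmentations=None`). A minimal sketch of the mechanism, using a plain dict in place of the composed config (supported in hydra-core 1.1+):

```python
import torch
from hydra.utils import instantiate

model = torch.nn.Linear(4, 2)
cfg = {
    "_target_": "torch.optim.Adam",  # dotted path to the callable
    "lr": 1e-3,
    "betas": [0.9, 0.999],
}
# equivalent to torch.optim.Adam(params=model.parameters(), lr=1e-3, ...)
optimizer = instantiate(cfg, params=model.parameters())
print(type(optimizer))
```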
@ -1,2 +0,0 @@
experiment_name : shahules/mayavoz
run_name : Demucs + Vctk with stride + augmentations
@ -1,25 +0,0 @@
_target_: mayavoz.models.dccrn.DCCRN
num_channels: 1
sampling_rate : 16000
complex_lstm : True
complex_norm : True
complex_relu : True
masking_mode : True

encoder_decoder:
  initial_output_channels : 32
  depth : 6
  kernel_size : 5
  growth_factor : 2
  stride : 2
  padding : 2
  output_padding : 1

lstm:
  num_layers : 2
  hidden_size : 256

stft:
  window_len : 400
  hop_size : 100
  nfft : 512
@ -1,6 +0,0 @@
_target_: torch.optim.Adam
lr: 1e-3
betas: [0.9, 0.999]
eps: 1e-08
weight_decay: 0
amsgrad: False
@ -1,46 +0,0 @@
_target_: pytorch_lightning.Trainer
accelerator: gpu
accumulate_grad_batches: 1
amp_backend: native
auto_lr_find: True
auto_scale_batch_size: False
auto_select_gpus: True
benchmark: False
check_val_every_n_epoch: 1
detect_anomaly: False
deterministic: False
devices: 2
enable_checkpointing: True
enable_model_summary: True
enable_progress_bar: True
fast_dev_run: False
gpus: null
gradient_clip_val: 0
gradient_clip_algorithm: norm
ipus: null
limit_predict_batches: 1.0
limit_test_batches: 1.0
limit_train_batches: 1.0
limit_val_batches: 1.0
log_every_n_steps: 50
max_epochs: 200
max_steps: -1
max_time: null
min_epochs: 1
min_steps: null
move_metrics_to_cpu: False
multiple_trainloader_mode: max_size_cycle
num_nodes: 1
num_processes: 1
num_sanity_val_steps: 2
overfit_batches: 0.0
precision: 32
profiler: null
reload_dataloaders_every_n_epochs: 0
replace_sampler_ddp: True
strategy: ddp
sync_batchnorm: False
tpu_cores: null
track_grad_norm: -1
val_check_interval: 1.0
weights_save_path: null
@ -1,2 +0,0 @@
_target_: pytorch_lightning.Trainer
fast_dev_run: True
@ -1,120 +0,0 @@
import os
from types import MethodType

import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.callbacks import (
    EarlyStopping,
    LearningRateMonitor,
    ModelCheckpoint,
)
from pytorch_lightning.loggers import MLFlowLogger
from torch.optim.lr_scheduler import ReduceLROnPlateau

# from torch_audiomentations import Compose, Shift

os.environ["HYDRA_FULL_ERROR"] = "1"
JOB_ID = os.environ.get("SLURM_JOBID", "0")


@hydra.main(config_path="train_config", config_name="config")
def train(config: DictConfig):

    OmegaConf.save(config, "config.yaml")

    callbacks = []
    logger = MLFlowLogger(
        experiment_name=config.mlflow.experiment_name,
        run_name=config.mlflow.run_name,
        tags={"JOB_ID": JOB_ID},
    )

    parameters = config.hyperparameters
    # apply_augmentations = Compose(
    #     [
    #         Shift(min_shift=0.5, max_shift=1.0, shift_unit="seconds", p=0.5),
    #     ]
    # )

    dataset = instantiate(config.dataset, augmentations=None)
    model = instantiate(
        config.model,
        dataset=dataset,
        lr=parameters.get("lr"),
        loss=parameters.get("loss"),
        metric=parameters.get("metric"),
    )

    direction = model.valid_monitor
    checkpoint = ModelCheckpoint(
        dirpath="./model",
        filename=f"model_{JOB_ID}",
        monitor="valid_loss",
        verbose=False,
        mode=direction,
        every_n_epochs=1,
    )
    callbacks.append(checkpoint)
    callbacks.append(LearningRateMonitor(logging_interval="epoch"))

    if parameters.get("Early_stop", False):
        early_stopping = EarlyStopping(
            monitor="val_loss",
            mode=direction,
            min_delta=0.0,
            patience=parameters.get("EarlyStopping_patience", 10),
            strict=True,
            verbose=False,
        )
        callbacks.append(early_stopping)

    def configure_optimizers(self):
        optimizer = instantiate(
            config.optimizer,
            lr=parameters.get("lr"),
            params=self.parameters(),
        )
        scheduler = ReduceLROnPlateau(
            optimizer=optimizer,
            mode=direction,
            factor=parameters.get("ReduceLr_factor", 0.1),
            verbose=True,
            min_lr=parameters.get("min_lr", 1e-6),
            patience=parameters.get("ReduceLr_patience", 3),
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": scheduler,
            "monitor": f'valid_{parameters.get("ReduceLr_monitor", "loss")}',
        }

    model.configure_optimizers = MethodType(configure_optimizers, model)

    trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks)
    trainer.fit(model)
    trainer.test(model)

    logger.experiment.log_artifact(
        logger.run_id, f"{trainer.default_root_dir}/config.yaml"
    )

    saved_location = os.path.join(
        trainer.default_root_dir, "model", f"model_{JOB_ID}.ckpt"
    )
    if os.path.isfile(saved_location):
        logger.experiment.log_artifact(logger.run_id, saved_location)
        logger.experiment.log_param(
            logger.run_id,
            "num_train_steps_per_epoch",
            dataset.train__len__() / dataset.batch_size,
        )
        logger.experiment.log_param(
            logger.run_id,
            "num_valid_steps_per_epoch",
            dataset.val__len__() / dataset.batch_size,
        )


if __name__ == "__main__":
    train()
@ -1,7 +0,0 @@
defaults:
  - model : Demucs
  - dataset : MS-SNSD
  - optimizer : Adam
  - hyperparameters : default
  - trainer : default
  - mlflow : experiment
@ -1,13 +0,0 @@
_target_: mayavoz.data.dataset.MayaDataset
name : MS-SNSD
root_dir : /Users/shahules/Myprojects/MS-SNSD
duration : 5
stride : 1
sampling_rate: 16000
batch_size: 32
min_valid_minutes: 25
files:
  train_clean : CleanSpeech_training
  test_clean : CleanSpeech_training
  train_noisy : NoisySpeech_training
  test_noisy : NoisySpeech_training
@ -1,7 +0,0 @@
loss : mae
metric : [stoi,pesq]
lr : 0.0003
ReduceLr_patience : 10
ReduceLr_factor : 0.5
min_lr : 0.000001
EarlyStopping_factor : 10
@ -1,2 +0,0 @@
experiment_name : shahules/mayavoz
run_name : demucs-ms-snsd
@ -1,16 +0,0 @@
_target_: mayavoz.models.demucs.Demucs
num_channels: 1
resample: 4
sampling_rate : 16000

encoder_decoder:
  depth: 4
  initial_output_channels: 64
  kernel_size: 8
  stride: 4
  growth_factor: 2
  glu: True

lstm:
  bidirectional: False
  num_layers: 2
@ -1,6 +0,0 @@
_target_: torch.optim.Adam
lr: 1e-3
betas: [0.9, 0.999]
eps: 1e-08
weight_decay: 0
amsgrad: False
@ -1,46 +0,0 @@
_target_: pytorch_lightning.Trainer
accelerator: gpu
accumulate_grad_batches: 1
amp_backend: native
auto_lr_find: True
auto_scale_batch_size: False
auto_select_gpus: True
benchmark: False
check_val_every_n_epoch: 1
detect_anomaly: False
deterministic: False
devices: 2
enable_checkpointing: True
enable_model_summary: True
enable_progress_bar: True
fast_dev_run: False
gpus: null
gradient_clip_val: 0
gradient_clip_algorithm: norm
ipus: null
limit_predict_batches: 1.0
limit_test_batches: 1.0
limit_train_batches: 1.0
limit_val_batches: 1.0
log_every_n_steps: 50
max_epochs: 200
max_steps: -1
max_time: null
min_epochs: 1
min_steps: null
move_metrics_to_cpu: False
multiple_trainloader_mode: max_size_cycle
num_nodes: 1
num_processes: 1
num_sanity_val_steps: 2
overfit_batches: 0.0
precision: 32
profiler: null
reload_dataloaders_every_n_epochs: 0
replace_sampler_ddp: True
strategy: ddp
sync_batchnorm: False
tpu_cores: null
track_grad_norm: -1
val_check_interval: 1.0
weights_save_path: null
@ -1,2 +0,0 @@
_target_: pytorch_lightning.Trainer
fast_dev_run: True
@ -1,17 +0,0 @@
### Microsoft Scalable Noisy Speech Dataset (MS-SNSD)

MS-SNSD is a speech dataset that can scale to arbitrary sizes depending on the number of speakers, noise types, and speech-to-noise ratio (SNR) levels desired.

### Dataset download & setup
- Follow the steps in the official repo [here](https://github.com/microsoft/MS-SNSD) to download and set up the dataset.

**References**
```BibTex
@article{reddy2019scalable,
  title={A Scalable Noisy Speech Dataset and Online Subjective Test Framework},
  author={Reddy, Chandan KA and Beyrami, Ebrahim and Pool, Jamie and Cutler, Ross and Srinivasan, Sriram and Gehrke, Johannes},
  journal={Proc. Interspeech 2019},
  pages={1816--1820},
  year={2019}
}
```
@ -1,120 +0,0 @@
import os
from types import MethodType

import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.callbacks import (
    EarlyStopping,
    LearningRateMonitor,
    ModelCheckpoint,
)
from pytorch_lightning.loggers import MLFlowLogger
from torch.optim.lr_scheduler import ReduceLROnPlateau

# from torch_audiomentations import Compose, Shift

os.environ["HYDRA_FULL_ERROR"] = "1"
JOB_ID = os.environ.get("SLURM_JOBID", "0")


@hydra.main(config_path="train_config", config_name="config")
def main(config: DictConfig):

    OmegaConf.save(config, "config_log.yaml")

    callbacks = []
    logger = MLFlowLogger(
        experiment_name=config.mlflow.experiment_name,
        run_name=config.mlflow.run_name,
        tags={"JOB_ID": JOB_ID},
    )

    parameters = config.hyperparameters
    # apply_augmentations = Compose(
    #     [
    #         Shift(min_shift=0.5, max_shift=1.0, shift_unit="seconds", p=0.5),
    #     ]
    # )

    dataset = instantiate(config.dataset, augmentations=None)
    model = instantiate(
        config.model,
        dataset=dataset,
        lr=parameters.get("lr"),
        loss=parameters.get("loss"),
        metric=parameters.get("metric"),
    )

    direction = model.valid_monitor
    checkpoint = ModelCheckpoint(
        dirpath="./model",
        filename=f"model_{JOB_ID}",
        monitor="valid_loss",
        verbose=False,
        mode=direction,
        every_n_epochs=1,
    )
    callbacks.append(checkpoint)
    callbacks.append(LearningRateMonitor(logging_interval="epoch"))

    if parameters.get("Early_stop", False):
        early_stopping = EarlyStopping(
            monitor="val_loss",
            mode=direction,
            min_delta=0.0,
            patience=parameters.get("EarlyStopping_patience", 10),
            strict=True,
            verbose=False,
        )
        callbacks.append(early_stopping)

    def configure_optimizers(self):
        optimizer = instantiate(
            config.optimizer,
            lr=parameters.get("lr"),
            params=self.parameters(),
        )
        scheduler = ReduceLROnPlateau(
            optimizer=optimizer,
            mode=direction,
            factor=parameters.get("ReduceLr_factor", 0.1),
            verbose=True,
            min_lr=parameters.get("min_lr", 1e-6),
            patience=parameters.get("ReduceLr_patience", 3),
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": scheduler,
            "monitor": f'valid_{parameters.get("ReduceLr_monitor", "loss")}',
        }

    model.configure_optimizers = MethodType(configure_optimizers, model)

    trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks)
    trainer.fit(model)
    trainer.test(model)

    logger.experiment.log_artifact(
        logger.run_id, f"{trainer.default_root_dir}/config_log.yaml"
    )

    saved_location = os.path.join(
        trainer.default_root_dir, "model", f"model_{JOB_ID}.ckpt"
    )
    if os.path.isfile(saved_location):
        logger.experiment.log_artifact(logger.run_id, saved_location)
        logger.experiment.log_param(
            logger.run_id,
            "num_train_steps_per_epoch",
            dataset.train__len__() / dataset.batch_size,
        )
        logger.experiment.log_param(
            logger.run_id,
            "num_valid_steps_per_epoch",
            dataset.val__len__() / dataset.batch_size,
        )


if __name__ == "__main__":
    main()
@ -1,7 +0,0 @@
defaults:
  - model : Demucs
  - dataset : Vctk
  - optimizer : Adam
  - hyperparameters : default
  - trainer : default
  - mlflow : experiment
@ -1,13 +0,0 @@
_target_: mayavoz.data.dataset.MayaDataset
name : vctk
root_dir : /scratch/c.sistc3/DS_10283_2791
duration : 4.5
stride : 0.5
sampling_rate: 16000
batch_size: 32
min_valid_minutes : 25
files:
  train_clean : clean_trainset_28spk_wav
  test_clean : clean_testset_wav
  train_noisy : noisy_trainset_28spk_wav
  test_noisy : noisy_testset_wav
@ -1,8 +0,0 @@
loss : mae
metric : [stoi,pesq,si-sdr]
lr : 0.0003
Early_stop : False
ReduceLr_patience : 10
ReduceLr_factor : 0.1
min_lr : 0.000001
EarlyStopping_factor : 10
@ -1,2 +0,0 @@
experiment_name : shahules/mayavoz
run_name : baseline
@ -1,16 +0,0 @@
_target_: mayavoz.models.demucs.Demucs
num_channels: 1
resample: 4
sampling_rate : 16000

encoder_decoder:
  depth: 4
  initial_output_channels: 64
  kernel_size: 8
  stride: 4
  growth_factor: 2
  glu: True

lstm:
  bidirectional: True
  num_layers: 2
@ -1,6 +0,0 @@
_target_: torch.optim.Adam
lr: 1e-3
betas: [0.9, 0.999]
eps: 1e-08
weight_decay: 0
amsgrad: False
@ -1,7 +0,0 @@
defaults:
  - model : WaveUnet
  - dataset : Vctk
  - optimizer : Adam
  - hyperparameters : default
  - trainer : default
  - mlflow : experiment
@ -1,8 +0,0 @@
loss : mae
metric : [stoi,pesq,si-sdr]
lr : 0.003
ReduceLr_patience : 10
ReduceLr_factor : 0.1
min_lr : 0.000001
EarlyStopping_factor : 10
Early_stop : False
@ -1,2 +0,0 @@
experiment_name : shahules/mayavoz
run_name : baseline
@ -1,5 +0,0 @@
_target_: mayavoz.models.waveunet.WaveUnet
num_channels : 1
depth : 9
initial_output_channels: 24
sampling_rate : 16000
@ -1,6 +0,0 @@
_target_: torch.optim.Adam
lr: 1e-3
betas: [0.9, 0.999]
eps: 1e-08
weight_decay: 0
amsgrad: False
@ -1,46 +0,0 @@
_target_: pytorch_lightning.Trainer
accelerator: gpu
accumulate_grad_batches: 1
amp_backend: native
auto_lr_find: True
auto_scale_batch_size: False
auto_select_gpus: True
benchmark: False
check_val_every_n_epoch: 1
detect_anomaly: False
deterministic: False
devices: 2
enable_checkpointing: True
enable_model_summary: True
enable_progress_bar: True
fast_dev_run: False
gpus: null
gradient_clip_val: 0
gradient_clip_algorithm: norm
ipus: null
limit_predict_batches: 1.0
limit_test_batches: 1.0
limit_train_batches: 1.0
limit_val_batches: 1.0
log_every_n_steps: 50
max_epochs: 200
max_steps: -1
max_time: null
min_epochs: 1
min_steps: null
move_metrics_to_cpu: False
multiple_trainloader_mode: max_size_cycle
num_nodes: 1
num_processes: 1
num_sanity_val_steps: 2
overfit_batches: 0.0
precision: 32
profiler: null
reload_dataloaders_every_n_epochs: 0
replace_sampler_ddp: True
strategy: ddp
sync_batchnorm: False
tpu_cores: null
track_grad_norm: -1
val_check_interval: 1.0
weights_save_path: null
@ -1,2 +0,0 @@
_target_: pytorch_lightning.Trainer
fast_dev_run: True
@ -1,12 +0,0 @@
## Valentini dataset

A parallel database of clean and noisy speech. The database was designed to train and test speech enhancement methods that operate at 48 kHz. A more detailed description can be found in the papers associated with the database ([official page](https://datashare.ed.ac.uk/handle/10283/2791)).

**References**
```BibTex
@misc{valentini2017,
  title={Noisy speech database for training speech enhancement algorithms and TTS models},
  author={Valentini-Botinhao, Cassia},
  year={2017},
  doi={10.7488/ds/2117},
}
```
@ -1,19 +1,19 @@
-boto3>=1.24.86
-huggingface-hub>=0.10.0
-hydra-core>=1.2.0
-joblib>=1.2.0
-librosa>=0.9.2
-mlflow>=1.28.0
+# torch>=1.12.1
+# torchaudio>=0.12.1
+# tqdm>=4.64.1
+configparser
+# boto3>=1.24.86
+# huggingface-hub>=0.10.0
+# hydra-core>=1.2.0
+# joblib>=1.2.0
+# librosa>=0.9.2
+# mlflow>=1.29.0
 numpy>=1.23.3
-pesq==0.0.4
-protobuf>=3.19.6
-pystoi==0.3.3
-pytest-lazy-fixture>=0.6.3
-pytorch-lightning>=1.7.7
-scikit-learn>=1.1.2
+# pesq==0.0.4
+# protobuf>=3.19.6
+# pystoi==0.3.3
+# pytest-lazy-fixture>=0.6.3
+# pytorch-lightning>=1.7.7
+# scikit-learn>=1.1.2
 scipy>=1.9.1
 soundfile>=0.11.0
-torch>=1.12.1
-torch-audiomentations==0.11.0
-torchaudio>=0.12.1
-tqdm>=4.64.1
10 setup.cfg
@ -3,7 +3,7 @@
 # http://setuptools.readthedocs.io/en/latest/setuptools.html#configuring-setup-using-setup-cfg-files

 [metadata]
-name = mayavoz
+name = enhancer
 description = Deep learning for speech enhancement
 author = Shahul Ess
 author-email = shahules786@gmail.com
@ -53,7 +53,7 @@ cli =
 [options.entry_points]

 console_scripts =
-    mayavoz-train=mayavoz.cli.train:train
+    enhancer-train=enhancer.cli.train:train

 [test]
 # py.test options when running `python setup.py test`
@ -66,7 +66,7 @@ extras = True
 # e.g. --cov-report html (or xml) for html/xml output or --junitxml junit.xml
 # in order to write a coverage file that can be read by Jenkins.
 addopts =
-    --cov mayavoz --cov-report term-missing
+    --cov enhancer --cov-report term-missing
     --verbose
 norecursedirs =
     dist
@ -98,7 +98,3 @@ exclude =
     build
     dist
     .eggs
-
-[options.data_files]
-. = requirements.txt
-_ = version.txt
6 setup.py
@ -33,15 +33,15 @@ elif sha != "Unknown":
     version += "+" + sha[:7]
 print("-- Building version " + version)

-version_path = ROOT_DIR / "mayavoz" / "version.py"
+version_path = ROOT_DIR / "enhancer" / "version.py"

 with open(version_path, "w") as f:
     f.write("__version__ = '{}'\n".format(version))

 if __name__ == "__main__":
     setup(
-        name="mayavoz",
-        namespace_packages=["mayavoz"],
+        name="enhancer",
+        namespace_packages=["enhancer"],
         version=version,
         packages=find_packages(),
         install_requires=requirements,
@ -0,0 +1,13 @@
#!/bin/bash
set -e

echo "Loading Anaconda Module"
module load anaconda

echo "Creating Virtual Environment"
conda env create -f environment.yml || conda env update -f environment.yml

source activate enhancer

echo "copying files"
# cp /scratch/$USER/TIMIT/.* /deep-transcriber
@ -1,7 +1,7 @@
 import pytest
 import torch

-from mayavoz.loss import mean_absolute_error, mean_squared_error
+from enhancer.loss import mean_absolute_error, mean_squared_error

 loss_functions = [mean_absolute_error(), mean_squared_error()]
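For reference, these loss objects are used as `nn.Module`-style callables taking `(prediction, target)`, per the custom-loss contract described in the notebook earlier. A hedged usage sketch; the tensor shapes and the `.name` attribute follow that contract and are otherwise assumptions:

```python
import torch
from enhancer.loss import mean_absolute_error, mean_squared_error

prediction = torch.rand(4, 1, 16000)  # (batch, channel, samples), assumed shape
target = torch.rand(4, 1, 16000)

for loss_fn in (mean_absolute_error(), mean_squared_error()):
    value = loss_fn(prediction, target)  # scalar tensor (assumed reduction)
    print(loss_fn.name, value.item())    # `.name` per the loss contract
```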
Some files were not shown because too many files have changed in this diff