Compare commits
85 Commits
| Author | SHA1 | Date |
|---|---|---|
|
|
7133670252 | |
|
|
40031ab084 | |
|
|
379800d3f6 | |
|
|
f21fa24f0e | |
|
|
afa89749ad | |
|
|
e6fb143c8f | |
|
|
cd7c008d34 | |
|
|
287df5bff4 | |
|
|
b56cdf877a | |
|
|
915574bd30 | |
|
|
c0cdb9e6e9 | |
|
|
d57ef2c10a | |
|
|
0faa06027f | |
|
|
c88a87e109 | |
|
|
fe82b398ee | |
|
|
a47b93b699 | |
|
|
763ea60a52 | |
|
|
0bca3f9949 | |
|
|
31ab30be04 | |
|
|
dd0b060e09 | |
|
|
1d2c5eee55 | |
|
|
f2111321bf | |
|
|
25139d7d3f | |
|
|
b343ea3610 | |
|
|
249c535921 | |
|
|
2de2c715ed | |
|
|
612c022d24 | |
|
|
8b3bc67529 | |
|
|
9187a940e7 | |
|
|
18c95cf219 | |
|
|
502cad0984 | |
|
|
dd27de7467 | |
|
|
d8d61a231b | |
|
|
80320bbf92 | |
|
|
ceb69a09c3 | |
|
|
60b654a065 | |
|
|
22a1f27e63 | |
|
|
65f1924593 | |
|
|
9525d2491f | |
|
|
f94bd22eb6 | |
|
|
f1fe1a803a | |
|
|
386931d09b | |
|
|
9927542713 | |
|
|
da85de13ad | |
|
|
9ee809a047 | |
|
|
7afe928ee1 | |
|
|
434b44ddc9 | |
|
|
191c6a7499 | |
|
|
b99ef95719 | |
|
|
2bfca78caa | |
|
|
003bab91f9 | |
|
|
d9b817f650 | |
|
|
90fbfbce73 | |
|
|
7c7db84c39 | |
|
|
a4f0fda6a5 | |
|
|
8bc63becce | |
|
|
bfd53937c2 | |
|
|
ba63c54399 | |
|
|
12cde1b0ab | |
|
|
f8a44f823a | |
|
|
7838e744a9 | |
|
|
1abc450ef8 | |
|
|
4a2865ff03 | |
|
|
0e664ed371 | |
|
|
cb6f9c20ed | |
|
|
8e4c12b98d | |
|
|
a0e38c5e5c | |
|
|
ebba5952e5 | |
|
|
69c7a0100c | |
|
|
470ec74bcb | |
|
|
a2e083b315 | |
|
|
252d380acc | |
|
|
4eff036c1c | |
|
|
d2a7e3c730 | |
|
|
3b8551640f | |
|
|
1d366d6096 | |
|
|
d90db16bce | |
|
|
e941235ec0 | |
|
|
27ddf0bec9 | |
|
|
bc13fc03bf | |
|
|
7a502671e2 | |
|
|
6384915e17 | |
|
|
94ab778c0b | |
|
|
ef06786d8c | |
|
|
e0fbf55dca |
2
.flake8
2
.flake8
|
|
@ -1,5 +1,5 @@
|
|||
[flake8]
|
||||
per-file-ignores = __init__.py:F401
|
||||
per-file-ignores = "mayavoz/model/__init__.py:F401"
|
||||
ignore = E203, E266, E501, W503
|
||||
# line length is intentionally set to 80 here because black uses Bugbear
|
||||
# See https://github.com/psf/black/blob/master/README.md#line-length for more details
|
||||
|
|
|
|||
|
|
@ -0,0 +1 @@
|
|||
notebooks/** linguist-vendored
|
||||
|
|
@ -1,13 +1,13 @@
|
|||
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
|
||||
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
|
||||
|
||||
name: Enhancer
|
||||
name: mayavoz
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ dev ]
|
||||
branches: [ main ]
|
||||
pull_request:
|
||||
branches: [ dev ]
|
||||
branches: [ main ]
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
|
|
@ -40,12 +40,12 @@ jobs:
|
|||
sudo apt-get install libsndfile1
|
||||
pip install -r requirements.txt
|
||||
pip install black pytest-cov
|
||||
- name: Install enhancer
|
||||
- name: Install mayavoz
|
||||
run: |
|
||||
pip install -e .[dev,testing]
|
||||
- name: Run black
|
||||
run:
|
||||
black --check . --exclude enhancer/version.py
|
||||
black --check . --exclude mayavoz/version.py
|
||||
- name: Test with pytest
|
||||
run:
|
||||
pytest tests --cov=enhancer/
|
||||
pytest tests --cov=mayavoz/
|
||||
|
|
|
|||
|
|
@ -1,4 +1,8 @@
|
|||
#local
|
||||
cleaned_my_voice.wav
|
||||
lightning_logs/
|
||||
my_voice.wav
|
||||
pretrained/
|
||||
*.ckpt
|
||||
*_local.yaml
|
||||
cli/train_config/dataset/Vctk_local.yaml
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ repos:
|
|||
hooks:
|
||||
- id: flake8
|
||||
args: ['--ignore=E203,E501,F811,E712,W503']
|
||||
exclude: __init__.py
|
||||
|
||||
# Formatting, Whitespace, etc
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
|
|
|
|||
|
|
@ -0,0 +1,46 @@
|
|||
# Contributing
|
||||
|
||||
Hi there 👋
|
||||
|
||||
If you're reading this I hope that you're looking forward to adding value to Mayavoz. This document will help you to get started with your journey.
|
||||
|
||||
## How to get your code in Mayavoz
|
||||
|
||||
1. We use git and GitHub.
|
||||
|
||||
2. Fork the mayavoz repository (https://github.com/shahules786/mayavoz) on GitHub under your own account. (This creates a copy of mayavoz under your account, and GitHub knows where it came from, and we typically call this “upstream”.)
|
||||
|
||||
3. Clone your own mayavoz repository. git clone https://github.com/ <your-account> /mayavoz (This downloads the git repository to your machine, git knows where it came from, and calls it “origin”.)
|
||||
|
||||
4. Create a branch for each specific feature you are developing. git checkout -b your-branch-name
|
||||
|
||||
5. Make + commit changes. git add files-you-changed ... git commit -m "Short message about what you did"
|
||||
|
||||
6. Push the branch to your GitHub repository. git push origin your-branch-name
|
||||
|
||||
7. Navigate to GitHub, and create a pull request from your branch to the upstream repository mayavoz/mayavoz, to the “develop” branch.
|
||||
|
||||
8. The Pull Request (PR) appears on the upstream repository. Discuss your contribution there. If you push more changes to your branch on GitHub (on your repository), they are added to the PR.
|
||||
|
||||
9. When the reviewer is satisfied that the code improves repository quality, they can merge.
|
||||
|
||||
Note that CI tests will be run when you create a PR. If you want to be sure that your code will not fail these tests, we have set up pre-commit hooks that you can install.
|
||||
|
||||
**If you're worried about things not being perfect with your code, we will work togethor and make it perfect. So, make your move!**
|
||||
|
||||
## Formating
|
||||
|
||||
We use [black](https://black.readthedocs.io/en/stable/) and [flake8](https://flake8.pycqa.org/en/latest/) for code formating. Please ensure that you use the same before submitting the PR.
|
||||
|
||||
|
||||
## Testing
|
||||
We adopt unit testing using [pytest](https://docs.pytest.org/en/latest/contents.html)
|
||||
Please make sure that adding your new component does not decrease test coverage.
|
||||
|
||||
## Other tools
|
||||
The use of [per-commit](https://pre-commit.com/) is recommended to ensure different requirements such as code formating, etc.
|
||||
|
||||
## How to start contributing to Mayavoz?
|
||||
|
||||
1. Checkout issues marked as `good first issue`, let us know you're interested in working on some issue by commenting under it.
|
||||
2. For others, I would suggest you to explore mayavoz. One way to do is to use it to train your own model. This was you might end by finding a new unreported bug or getting an idea to improve Mayavoz.
|
||||
|
|
@ -0,0 +1,20 @@
|
|||
MIT License
|
||||
|
||||
Copyright (c) 2022 Shahul Es
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
|
@ -0,0 +1,4 @@
|
|||
recursive-include mayavoz *.py
|
||||
recursive-include mayavoz *.yaml
|
||||
global-exclude *.pyc
|
||||
global-exclude __pycache__
|
||||
51
README.md
51
README.md
|
|
@ -2,24 +2,52 @@
|
|||
<img src="https://user-images.githubusercontent.com/25312635/195514652-e4526cd1-1177-48e9-a80d-c8bfdb95d35f.png" />
|
||||
</p>
|
||||
|
||||
mayavoz is a Pytorch-based opensource toolkit for speech enhancement. It is designed to save time for audio researchers. Is provides easy to use pretrained audio enhancement models and facilitates highly customisable model training.
|
||||

|
||||

|
||||

|
||||

|
||||
|
||||
| **[Quick Start]()** | **[Installation]()** | **[Tutorials]()** | **[Available Recipes]()**
|
||||
mayavoz is a Pytorch-based opensource toolkit for speech enhancement. It is designed to save time for audio practioners & researchers. It provides easy to use pretrained speech enhancement models and facilitates highly customisable model training.
|
||||
|
||||
| **[Quick Start](#quick-start-fire)** | **[Installation](#installation)** | **[Tutorials](https://github.com/shahules786/enhancer/tree/main/notebooks)** | **[Available Recipes](#recipes)** | **[Demo](#demo)**
|
||||
## Key features :key:
|
||||
|
||||
* Various pretrained models nicely integrated with huggingface :hugs: that users can select and use without any hastle.
|
||||
* :package: Ability to train and validation your own custom speech enhancement models with just under 10 lines of code!
|
||||
* Various pretrained models nicely integrated with [huggingface hub](https://huggingface.co/docs/hub/index) :hugs: that users can select and use without any hastle.
|
||||
* :package: Ability to train and validate your own custom speech enhancement models with just under 10 lines of code!
|
||||
* :magic_wand: A command line tool that facilitates training of highly customisable speech enhacement models from the terminal itself!
|
||||
* :zap: Supports multi-gpu training integrated with Pytorch Lightning.
|
||||
* :zap: Supports multi-gpu training integrated with [Pytorch Lightning](https://pytorchlightning.ai/).
|
||||
* :shield: data augmentations integrated using [torch-augmentations](https://github.com/asteroid-team/torch-audiomentations)
|
||||
|
||||
|
||||
## Demo
|
||||
|
||||
Noisy speech followed by enhanced version.
|
||||
|
||||
https://user-images.githubusercontent.com/25312635/203756185-737557f4-6e21-4146-aa2c-95da69d0de4c.mp4
|
||||
|
||||
|
||||
|
||||
## Quick Start :fire:
|
||||
``` python
|
||||
from mayavoz import Mayamodel
|
||||
from mayavoz.models import Mayamodel
|
||||
|
||||
model = Mayamodel.from_pretrained("mayavoz/waveunet")
|
||||
model("noisy_audio.wav")
|
||||
model = Mayamodel.from_pretrained("shahules786/mayavoz-waveunet-valentini-28spk")
|
||||
model.enhance("noisy_audio.wav")
|
||||
```
|
||||
|
||||
## Recipes
|
||||
|
||||
| Model | Dataset | STOI | PESQ | URL |
|
||||
| :---: | :---: | :---: | :---: | :---: |
|
||||
| WaveUnet | Valentini-28spk | 0.836 | 2.78 | shahules786/mayavoz-waveunet-valentini-28spk |
|
||||
| Demucs | Valentini-28spk | 0.961 | 2.56 | shahules786/mayavoz-demucs-valentini-28spk |
|
||||
| DCCRN | Valentini-28spk | 0.724 | 2.55 | shahules786/mayavoz-dccrn-valentini-28spk |
|
||||
| Demucs | MS-SNSD-20hrs | 0.56 | 1.26 | shahules786/mayavoz-demucs-ms-snsd-20 |
|
||||
|
||||
Test scores are based on respective test set associated with train dataset.
|
||||
|
||||
**See [tutorials](/notebooks/) to train your custom model**
|
||||
|
||||
## Installation
|
||||
Only Python 3.8+ is officially supported (though it might work with Python 3.7)
|
||||
|
||||
|
|
@ -41,3 +69,10 @@ git clone url
|
|||
cd mayavoz
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
## Support
|
||||
|
||||
For commercial enquiries and scientific consulting, please [contact me](https://shahules786.github.io/).
|
||||
|
||||
### Acknowledgements
|
||||
Sincere gratitude to [AMPLYFI](https://amplyfi.com/) for supporting this project.
|
||||
|
|
|
|||
|
|
@ -1,13 +0,0 @@
|
|||
_target_: enhancer.data.dataset.EnhancerDataset
|
||||
name : vctk
|
||||
root_dir : /Users/shahules/Myprojects/enhancer/datasets/vctk
|
||||
duration : 1.0
|
||||
sampling_rate: 16000
|
||||
batch_size: 64
|
||||
num_workers : 0
|
||||
|
||||
files:
|
||||
train_clean : clean_testset_wav
|
||||
test_clean : clean_testset_wav
|
||||
train_noisy : noisy_testset_wav
|
||||
test_noisy : noisy_testset_wav
|
||||
|
|
@ -1 +0,0 @@
|
|||
from enhancer.data.dataset import EnhancerDataset
|
||||
|
|
@ -1,3 +0,0 @@
|
|||
from enhancer.models.demucs import Demucs
|
||||
from enhancer.models.model import Model
|
||||
from enhancer.models.waveunet import WaveUnet
|
||||
|
|
@ -1,5 +0,0 @@
|
|||
from enhancer.models.complexnn.conv import ComplexConv2d # noqa
|
||||
from enhancer.models.complexnn.conv import ComplexConvTranspose2d # noqa
|
||||
from enhancer.models.complexnn.rnn import ComplexLSTM # noqa
|
||||
from enhancer.models.complexnn.utils import ComplexBatchNorm2D # noqa
|
||||
from enhancer.models.complexnn.utils import ComplexRelu # noqa
|
||||
|
|
@ -1,3 +0,0 @@
|
|||
from enhancer.utils.config import Files
|
||||
from enhancer.utils.io import Audio
|
||||
from enhancer.utils.utils import check_files
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
name: enhancer
|
||||
name: mayavoz
|
||||
|
||||
dependencies:
|
||||
- pip=21.0.1
|
||||
|
|
|
|||
|
|
@ -1,39 +0,0 @@
|
|||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
echo '----------------------------------------------------'
|
||||
echo ' SLURM_CLUSTER_NAME = '$SLURM_CLUSTER_NAME
|
||||
echo ' SLURMD_NODENAME = '$SLURMD_NODENAME
|
||||
echo ' SLURM_JOBID = '$SLURM_JOBID
|
||||
echo ' SLURM_JOB_USER = '$SLURM_JOB_USER
|
||||
echo ' SLURM_PARTITION = '$SLURM_JOB_PARTITION
|
||||
echo ' SLURM_JOB_ACCOUNT = '$SLURM_JOB_ACCOUNT
|
||||
echo '----------------------------------------------------'
|
||||
|
||||
#TeamCity Output
|
||||
cat << EOF
|
||||
##teamcity[buildNumber '$SLURM_JOBID']
|
||||
EOF
|
||||
|
||||
echo "Load HPC modules"
|
||||
module load anaconda
|
||||
|
||||
echo "Activate Environment"
|
||||
source activate enhancer
|
||||
export TRANSFORMERS_OFFLINE=True
|
||||
export PYTHONPATH=${PYTHONPATH}:/scratch/c.sistc3/enhancer
|
||||
export HYDRA_FULL_ERROR=1
|
||||
|
||||
echo $PYTHONPATH
|
||||
|
||||
source ~/mlflow_settings.sh
|
||||
|
||||
echo "Making temp dir"
|
||||
mkdir temp
|
||||
pwd
|
||||
|
||||
#python transcriber/tasks/embeddings/timit.py --directory /scratch/$USER/TIMIT/data/lisa/data/timit/raw/TIMIT/TRAIN --output ./data/train
|
||||
#python transcriber/tasks/embeddings/timit.py --directory /scratch/$USER/TIMIT/data/lisa/data/timit/raw/TIMIT/TEST --output ./data/test
|
||||
|
||||
echo "Start Training..."
|
||||
python enhancer/cli/train.py
|
||||
|
|
@ -1 +1,2 @@
|
|||
__import__("pkg_resources").declare_namespace(__name__)
|
||||
from mayavoz.models import Mayamodel
|
||||
|
|
@ -0,0 +1,120 @@
|
|||
import os
|
||||
from types import MethodType
|
||||
|
||||
import hydra
|
||||
from hydra.utils import instantiate
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from pytorch_lightning.callbacks import (
|
||||
EarlyStopping,
|
||||
LearningRateMonitor,
|
||||
ModelCheckpoint,
|
||||
)
|
||||
from pytorch_lightning.loggers import MLFlowLogger
|
||||
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
||||
|
||||
# from torch_audiomentations import Compose, Shift
|
||||
|
||||
os.environ["HYDRA_FULL_ERROR"] = "1"
|
||||
JOB_ID = os.environ.get("SLURM_JOBID", "0")
|
||||
|
||||
|
||||
@hydra.main(config_path="train_config", config_name="config")
|
||||
def train(config: DictConfig):
|
||||
|
||||
OmegaConf.save(config, "config.yaml")
|
||||
|
||||
callbacks = []
|
||||
logger = MLFlowLogger(
|
||||
experiment_name=config.mlflow.experiment_name,
|
||||
run_name=config.mlflow.run_name,
|
||||
tags={"JOB_ID": JOB_ID},
|
||||
)
|
||||
|
||||
parameters = config.hyperparameters
|
||||
# apply_augmentations = Compose(
|
||||
# [
|
||||
# Shift(min_shift=0.5, max_shift=1.0, shift_unit="seconds", p=0.5),
|
||||
# ]
|
||||
# )
|
||||
|
||||
dataset = instantiate(config.dataset, augmentations=None)
|
||||
model = instantiate(
|
||||
config.model,
|
||||
dataset=dataset,
|
||||
lr=parameters.get("lr"),
|
||||
loss=parameters.get("loss"),
|
||||
metric=parameters.get("metric"),
|
||||
)
|
||||
|
||||
direction = model.valid_monitor
|
||||
checkpoint = ModelCheckpoint(
|
||||
dirpath="./model",
|
||||
filename=f"model_{JOB_ID}",
|
||||
monitor="valid_loss",
|
||||
verbose=False,
|
||||
mode=direction,
|
||||
every_n_epochs=1,
|
||||
)
|
||||
callbacks.append(checkpoint)
|
||||
callbacks.append(LearningRateMonitor(logging_interval="epoch"))
|
||||
|
||||
if parameters.get("Early_stop", False):
|
||||
early_stopping = EarlyStopping(
|
||||
monitor="val_loss",
|
||||
mode=direction,
|
||||
min_delta=0.0,
|
||||
patience=parameters.get("EarlyStopping_patience", 10),
|
||||
strict=True,
|
||||
verbose=False,
|
||||
)
|
||||
callbacks.append(early_stopping)
|
||||
|
||||
def configure_optimizers(self):
|
||||
optimizer = instantiate(
|
||||
config.optimizer,
|
||||
lr=parameters.get("lr"),
|
||||
params=self.parameters(),
|
||||
)
|
||||
scheduler = ReduceLROnPlateau(
|
||||
optimizer=optimizer,
|
||||
mode=direction,
|
||||
factor=parameters.get("ReduceLr_factor", 0.1),
|
||||
verbose=True,
|
||||
min_lr=parameters.get("min_lr", 1e-6),
|
||||
patience=parameters.get("ReduceLr_patience", 3),
|
||||
)
|
||||
return {
|
||||
"optimizer": optimizer,
|
||||
"lr_scheduler": scheduler,
|
||||
"monitor": f'valid_{parameters.get("ReduceLr_monitor", "loss")}',
|
||||
}
|
||||
|
||||
model.configure_optimizers = MethodType(configure_optimizers, model)
|
||||
|
||||
trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks)
|
||||
trainer.fit(model)
|
||||
trainer.test(model)
|
||||
|
||||
logger.experiment.log_artifact(
|
||||
logger.run_id, f"{trainer.default_root_dir}/config.yaml"
|
||||
)
|
||||
|
||||
saved_location = os.path.join(
|
||||
trainer.default_root_dir, "model", f"model_{JOB_ID}.ckpt"
|
||||
)
|
||||
if os.path.isfile(saved_location):
|
||||
logger.experiment.log_artifact(logger.run_id, saved_location)
|
||||
logger.experiment.log_param(
|
||||
logger.run_id,
|
||||
"num_train_steps_per_epoch",
|
||||
dataset.train__len__() / dataset.batch_size,
|
||||
)
|
||||
logger.experiment.log_param(
|
||||
logger.run_id,
|
||||
"num_valid_steps_per_epoch",
|
||||
dataset.val__len__() / dataset.batch_size,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train()
|
||||
|
|
@ -1,10 +1,10 @@
|
|||
_target_: enhancer.data.dataset.EnhancerDataset
|
||||
_target_: mayavoz.data.dataset.MayaDataset
|
||||
name : MS-SDSD
|
||||
root_dir : /Users/shahules/Myprojects/MS-SNSD
|
||||
name : dns-2020
|
||||
duration : 2.0
|
||||
sampling_rate: 16000
|
||||
batch_size: 32
|
||||
valid_size: 0.05
|
||||
min_valid_minutes: 15
|
||||
files:
|
||||
train_clean : CleanSpeech_training
|
||||
test_clean : CleanSpeech_training
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
_target_: enhancer.data.dataset.EnhancerDataset
|
||||
name : vctk
|
||||
_target_: mayavoz.data.dataset.MayaDataset
|
||||
name : Valentini
|
||||
root_dir : /scratch/c.sistc3/DS_10283_2791
|
||||
duration : 4.5
|
||||
stride : 2
|
||||
|
|
@ -1,2 +1,2 @@
|
|||
experiment_name : shahules/enhancer
|
||||
experiment_name : shahules/mayavoz
|
||||
run_name : Demucs + Vtck with stride + augmentations
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
_target_: enhancer.models.dccrn.DCCRN
|
||||
_target_: mayavoz.models.dccrn.DCCRN
|
||||
num_channels: 1
|
||||
sampling_rate : 16000
|
||||
complex_lstm : True
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
_target_: enhancer.models.demucs.Demucs
|
||||
_target_: mayavoz.models.demucs.Demucs
|
||||
num_channels: 1
|
||||
resample: 4
|
||||
sampling_rate : 16000
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
_target_: enhancer.models.waveunet.WaveUnet
|
||||
_target_: mayavoz.models.waveunet.WaveUnet
|
||||
num_channels : 1
|
||||
depth : 9
|
||||
initial_output_channels: 24
|
||||
|
|
@ -0,0 +1 @@
|
|||
from mayavoz.data.dataset import MayaDataset
|
||||
|
|
@ -1,6 +1,8 @@
|
|||
import math
|
||||
import multiprocessing
|
||||
import os
|
||||
import sys
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
|
|
@ -11,11 +13,11 @@ import torch.nn.functional as F
|
|||
from torch.utils.data import DataLoader, Dataset, RandomSampler
|
||||
from torch_audiomentations import Compose
|
||||
|
||||
from enhancer.data.fileprocessor import Fileprocessor
|
||||
from enhancer.utils import check_files
|
||||
from enhancer.utils.config import Files
|
||||
from enhancer.utils.io import Audio
|
||||
from enhancer.utils.random import create_unique_rng
|
||||
from mayavoz.data.fileprocessor import Fileprocessor
|
||||
from mayavoz.utils import check_files
|
||||
from mayavoz.utils.config import Files
|
||||
from mayavoz.utils.io import Audio
|
||||
from mayavoz.utils.random import create_unique_rng
|
||||
|
||||
LARGE_NUM = 2147483647
|
||||
|
||||
|
|
@ -80,6 +82,21 @@ class TaskDataset(pl.LightningDataModule):
|
|||
self._validation = []
|
||||
if num_workers is None:
|
||||
num_workers = multiprocessing.cpu_count() // 2
|
||||
if num_workers is None:
|
||||
num_workers = multiprocessing.cpu_count() // 2
|
||||
|
||||
if (
|
||||
num_workers > 0
|
||||
and sys.platform == "darwin"
|
||||
and sys.version_info[0] >= 3
|
||||
and sys.version_info[1] >= 8
|
||||
):
|
||||
warnings.warn(
|
||||
"num_workers > 0 is not supported with macOS and Python 3.8+: "
|
||||
"setting num_workers = 0."
|
||||
)
|
||||
num_workers = 0
|
||||
|
||||
self.num_workers = num_workers
|
||||
if min_valid_minutes > 0.0:
|
||||
self.min_valid_minutes = min_valid_minutes
|
||||
|
|
@ -248,7 +265,7 @@ class TaskDataset(pl.LightningDataModule):
|
|||
)
|
||||
|
||||
|
||||
class EnhancerDataset(TaskDataset):
|
||||
class MayaDataset(TaskDataset):
|
||||
"""
|
||||
Dataset object for creating clean-noisy speech enhancement datasets
|
||||
paramters:
|
||||
|
|
@ -258,7 +275,7 @@ class EnhancerDataset(TaskDataset):
|
|||
root directory of the dataset containing clean/noisy folders
|
||||
files : Files
|
||||
dataclass containing train_clean, train_noisy, test_clean, test_noisy
|
||||
folder names (refer enhancer.utils.Files dataclass)
|
||||
folder names (refer mayavoz.utils.Files dataclass)
|
||||
min_valid_minutes: float
|
||||
minimum validation split size time in minutes
|
||||
algorithm randomly select n speakers (>=min_valid_minutes) from train data to form validation data.
|
||||
|
|
@ -93,9 +93,9 @@ class Fileprocessor:
|
|||
def from_name(cls, name: str, clean_dir, noisy_dir, matching_function=None):
|
||||
|
||||
if matching_function is None:
|
||||
if name.lower() == "vctk":
|
||||
if name.lower() in ("vctk", "valentini"):
|
||||
return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_one)
|
||||
elif name.lower() == "dns-2020":
|
||||
elif name.lower() == "ms-snsd":
|
||||
return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_many)
|
||||
else:
|
||||
raise ValueError(
|
||||
|
|
@ -8,7 +8,7 @@ from librosa import load as load_audio
|
|||
from scipy.io import wavfile
|
||||
from scipy.signal import get_window
|
||||
|
||||
from enhancer.utils import Audio
|
||||
from mayavoz.utils import Audio
|
||||
|
||||
|
||||
class Inference:
|
||||
|
|
@ -95,6 +95,7 @@ class Inference:
|
|||
):
|
||||
"""
|
||||
stitch batched waveform into single waveform. (Overlap-add)
|
||||
inspired from https://github.com/asteroid-team/asteroid
|
||||
arguments:
|
||||
data: batched waveform
|
||||
window_size : window_size used to batch waveform
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
import logging
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
|
@ -134,7 +134,7 @@ class Pesq:
|
|||
try:
|
||||
pesq_values.append(self.pesq(pred.squeeze(), target_.squeeze()))
|
||||
except Exception as e:
|
||||
logging.warning(f"{e} error occured while calculating PESQ")
|
||||
warnings.warn(f"{e} error occured while calculating PESQ")
|
||||
return torch.tensor(np.mean(pesq_values))
|
||||
|
||||
|
||||
|
|
@ -192,7 +192,7 @@ class Si_snr(nn.Module):
|
|||
super().__init__()
|
||||
|
||||
self.loss_fun = ScaleInvariantSignalNoiseRatio(**kwargs)
|
||||
self.higher_better = True
|
||||
self.higher_better = False
|
||||
self.name = "si_snr"
|
||||
|
||||
def forward(self, prediction: torch.Tensor, target: torch.Tensor):
|
||||
|
|
@ -203,7 +203,7 @@ class Si_snr(nn.Module):
|
|||
got {prediction.size()} and {target.size()} instead"""
|
||||
)
|
||||
|
||||
return self.loss_fun(prediction, target)
|
||||
return -1 * self.loss_fun(prediction, target)
|
||||
|
||||
|
||||
LOSS_MAP = {
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
from mayavoz.models.demucs import Demucs
|
||||
from mayavoz.models.model import Mayamodel
|
||||
from mayavoz.models.waveunet import WaveUnet
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
from mayavoz.models.complexnn.conv import ComplexConv2d # noqa
|
||||
from mayavoz.models.complexnn.conv import ComplexConvTranspose2d # noqa
|
||||
from mayavoz.models.complexnn.rnn import ComplexLSTM # noqa
|
||||
from mayavoz.models.complexnn.utils import ComplexBatchNorm2D # noqa
|
||||
from mayavoz.models.complexnn.utils import ComplexRelu # noqa
|
||||
|
|
@ -129,7 +129,7 @@ class ComplexConvTranspose2d(nn.Module):
|
|||
imag_real = self.real_conv(imag)
|
||||
|
||||
real = real_real - imag_imag
|
||||
imag = real_imag - imag_real
|
||||
imag = real_imag + imag_real
|
||||
|
||||
out = torch.cat([real, imag], 1)
|
||||
|
||||
|
|
@ -1,22 +1,22 @@
|
|||
import logging
|
||||
import warnings
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from enhancer.data import EnhancerDataset
|
||||
from enhancer.models import Model
|
||||
from enhancer.models.complexnn import (
|
||||
from mayavoz.data import MayaDataset
|
||||
from mayavoz.models import Mayamodel
|
||||
from mayavoz.models.complexnn import (
|
||||
ComplexBatchNorm2D,
|
||||
ComplexConv2d,
|
||||
ComplexConvTranspose2d,
|
||||
ComplexLSTM,
|
||||
ComplexRelu,
|
||||
)
|
||||
from enhancer.models.complexnn.utils import complex_cat
|
||||
from enhancer.utils.transforms import ConviSTFT, ConvSTFT
|
||||
from enhancer.utils.utils import merge_dict
|
||||
from mayavoz.models.complexnn.utils import complex_cat
|
||||
from mayavoz.utils.transforms import ConviSTFT, ConvSTFT
|
||||
from mayavoz.utils.utils import merge_dict
|
||||
|
||||
|
||||
class DCCRN_ENCODER(nn.Module):
|
||||
|
|
@ -98,7 +98,7 @@ class DCCRN_DECODER(nn.Module):
|
|||
return self.decoder(waveform)
|
||||
|
||||
|
||||
class DCCRN(Model):
|
||||
class DCCRN(Mayamodel):
|
||||
|
||||
STFT_DEFAULTS = {
|
||||
"window_len": 400,
|
||||
|
|
@ -134,17 +134,17 @@ class DCCRN(Model):
|
|||
num_channels: int = 1,
|
||||
sampling_rate=16000,
|
||||
lr: float = 1e-3,
|
||||
dataset: Optional[EnhancerDataset] = None,
|
||||
dataset: Optional[MayaDataset] = None,
|
||||
duration: Optional[float] = None,
|
||||
loss: Union[str, List, Any] = "mse",
|
||||
metric: Union[str, List] = "mse",
|
||||
):
|
||||
duration = (
|
||||
dataset.duration if isinstance(dataset, EnhancerDataset) else None
|
||||
dataset.duration if isinstance(dataset, MayaDataset) else duration
|
||||
)
|
||||
if dataset is not None:
|
||||
if sampling_rate != dataset.sampling_rate:
|
||||
logging.warning(
|
||||
warnings.warn(
|
||||
f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
|
||||
)
|
||||
sampling_rate = dataset.sampling_rate
|
||||
|
|
@ -1,14 +1,14 @@
|
|||
import logging
|
||||
import math
|
||||
import warnings
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from enhancer.data.dataset import EnhancerDataset
|
||||
from enhancer.models.model import Model
|
||||
from enhancer.utils.io import Audio as audio
|
||||
from enhancer.utils.utils import merge_dict
|
||||
from mayavoz.data.dataset import MayaDataset
|
||||
from mayavoz.models.model import Mayamodel
|
||||
from mayavoz.utils.io import Audio as audio
|
||||
from mayavoz.utils.utils import merge_dict
|
||||
|
||||
|
||||
class DemucsLSTM(nn.Module):
|
||||
|
|
@ -88,7 +88,7 @@ class DemucsDecoder(nn.Module):
|
|||
return out
|
||||
|
||||
|
||||
class Demucs(Model):
|
||||
class Demucs(Mayamodel):
|
||||
"""
|
||||
Demucs model from https://arxiv.org/pdf/1911.13254.pdf
|
||||
parameters:
|
||||
|
|
@ -102,8 +102,8 @@ class Demucs(Model):
|
|||
sampling rate of input audio
|
||||
lr : float, defaults to 1e-3
|
||||
learning rate used for training
|
||||
dataset: EnhancerDataset, optional
|
||||
EnhancerDataset object containing train/validation data for training
|
||||
dataset: MayaDataset, optional
|
||||
MayaDataset object containing train/validation data for training
|
||||
duration : float, optional
|
||||
chunk duration in seconds
|
||||
loss : string or List of strings
|
||||
|
|
@ -135,17 +135,18 @@ class Demucs(Model):
|
|||
sampling_rate=16000,
|
||||
normalize=True,
|
||||
lr: float = 1e-3,
|
||||
dataset: Optional[EnhancerDataset] = None,
|
||||
dataset: Optional[MayaDataset] = None,
|
||||
duration: Optional[float] = None,
|
||||
loss: Union[str, List] = "mse",
|
||||
metric: Union[str, List] = "mse",
|
||||
floor=1e-3,
|
||||
):
|
||||
duration = (
|
||||
dataset.duration if isinstance(dataset, EnhancerDataset) else None
|
||||
dataset.duration if isinstance(dataset, MayaDataset) else duration
|
||||
)
|
||||
if dataset is not None:
|
||||
if sampling_rate != dataset.sampling_rate:
|
||||
logging.warning(
|
||||
warnings.warn(
|
||||
f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
|
||||
)
|
||||
sampling_rate = dataset.sampling_rate
|
||||
|
|
@ -13,17 +13,21 @@ from pytorch_lightning.utilities.cloud_io import load as pl_load
|
|||
from torch import nn
|
||||
from torch.optim import Adam
|
||||
|
||||
from enhancer.data.dataset import EnhancerDataset
|
||||
from enhancer.inference import Inference
|
||||
from enhancer.loss import LOSS_MAP, LossWrapper
|
||||
from enhancer.version import __version__
|
||||
from mayavoz.data.dataset import MayaDataset
|
||||
from mayavoz.inference import Inference
|
||||
from mayavoz.loss import LOSS_MAP, LossWrapper
|
||||
from mayavoz.version import __version__
|
||||
|
||||
CACHE_DIR = ""
|
||||
HF_TORCH_WEIGHTS = ""
|
||||
CACHE_DIR = os.getenv(
|
||||
"ENHANCER_CACHE",
|
||||
os.path.expanduser("~/.cache/torch/mayavoz"),
|
||||
)
|
||||
HF_TORCH_WEIGHTS = "pytorch_model.ckpt"
|
||||
DEFAULT_DEVICE = "cpu"
|
||||
SAVE_NAME = "mayavoz"
|
||||
|
||||
|
||||
class Model(pl.LightningModule):
|
||||
class Mayamodel(pl.LightningModule):
|
||||
"""
|
||||
Base class for all models
|
||||
parameters:
|
||||
|
|
@ -33,8 +37,8 @@ class Model(pl.LightningModule):
|
|||
audio sampling rate
|
||||
lr: float, optional
|
||||
learning rate for model training
|
||||
dataset: EnhancerDataset, optional
|
||||
Enhancer dataset used for training/validation
|
||||
dataset: MayaDataset, optional
|
||||
mayavoz dataset used for training/validation
|
||||
duration: float, optional
|
||||
duration used for training/inference
|
||||
loss : string or List of strings or custom loss (nn.Module), default to "mse"
|
||||
|
|
@ -47,15 +51,13 @@ class Model(pl.LightningModule):
|
|||
num_channels: int = 1,
|
||||
sampling_rate: int = 16000,
|
||||
lr: float = 1e-3,
|
||||
dataset: Optional[EnhancerDataset] = None,
|
||||
dataset: Optional[MayaDataset] = None,
|
||||
duration: Optional[float] = None,
|
||||
loss: Union[str, List] = "mse",
|
||||
metric: Union[str, List, Any] = "mse",
|
||||
):
|
||||
super().__init__()
|
||||
assert (
|
||||
num_channels == 1
|
||||
), "Enhancer only support for mono channel models"
|
||||
assert num_channels == 1, "mayavoz only support for mono channel models"
|
||||
self.dataset = dataset
|
||||
self.save_hyperparameters(
|
||||
"num_channels", "sampling_rate", "lr", "loss", "metric", "duration"
|
||||
|
|
@ -232,8 +234,8 @@ class Model(pl.LightningModule):
|
|||
|
||||
def on_save_checkpoint(self, checkpoint):
|
||||
|
||||
checkpoint["enhancer"] = {
|
||||
"version": {"enhancer": __version__, "pytorch": torch.__version__},
|
||||
checkpoint[SAVE_NAME] = {
|
||||
"version": {SAVE_NAME: __version__, "pytorch": torch.__version__},
|
||||
"architecture": {
|
||||
"module": self.__class__.__module__,
|
||||
"class": self.__class__.__name__,
|
||||
|
|
@ -286,8 +288,8 @@ class Model(pl.LightningModule):
|
|||
|
||||
Returns
|
||||
-------
|
||||
model : Model
|
||||
Model
|
||||
model : Mayamodel
|
||||
Mayamodel
|
||||
|
||||
See also
|
||||
--------
|
||||
|
|
@ -316,7 +318,7 @@ class Model(pl.LightningModule):
|
|||
)
|
||||
model_path_pl = cached_download(
|
||||
url=url,
|
||||
library_name="enhancer",
|
||||
library_name="mayavoz",
|
||||
library_version=__version__,
|
||||
cache_dir=cached_dir,
|
||||
use_auth_token=use_auth_token,
|
||||
|
|
@ -326,8 +328,8 @@ class Model(pl.LightningModule):
|
|||
map_location = torch.device(DEFAULT_DEVICE)
|
||||
|
||||
loaded_checkpoint = pl_load(model_path_pl, map_location)
|
||||
module_name = loaded_checkpoint["enhancer"]["architecture"]["module"]
|
||||
class_name = loaded_checkpoint["enhancer"]["architecture"]["class"]
|
||||
module_name = loaded_checkpoint[SAVE_NAME]["architecture"]["module"]
|
||||
class_name = loaded_checkpoint[SAVE_NAME]["architecture"]["class"]
|
||||
module = import_module(module_name)
|
||||
Klass = getattr(module, class_name)
|
||||
|
||||
|
|
@ -1,12 +1,12 @@
|
|||
import logging
|
||||
import warnings
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from enhancer.data.dataset import EnhancerDataset
|
||||
from enhancer.models.model import Model
|
||||
from mayavoz.data.dataset import MayaDataset
|
||||
from mayavoz.models.model import Mayamodel
|
||||
|
||||
|
||||
class WavenetDecoder(nn.Module):
|
||||
|
|
@ -66,7 +66,7 @@ class WavenetEncoder(nn.Module):
|
|||
return self.encoder(waveform)
|
||||
|
||||
|
||||
class WaveUnet(Model):
|
||||
class WaveUnet(Mayamodel):
|
||||
"""
|
||||
Wave-U-Net model from https://arxiv.org/pdf/1811.11307.pdf
|
||||
parameters:
|
||||
|
|
@ -80,8 +80,8 @@ class WaveUnet(Model):
|
|||
sampling rate of input audio
|
||||
lr : float, defaults to 1e-3
|
||||
learning rate used for training
|
||||
dataset: EnhancerDataset, optional
|
||||
EnhancerDataset object containing train/validation data for training
|
||||
dataset: MayaDataset, optional
|
||||
MayaDataset object containing train/validation data for training
|
||||
duration : float, optional
|
||||
chunk duration in seconds
|
||||
loss : string or List of strings
|
||||
|
|
@ -97,17 +97,17 @@ class WaveUnet(Model):
|
|||
initial_output_channels: int = 24,
|
||||
sampling_rate: int = 16000,
|
||||
lr: float = 1e-3,
|
||||
dataset: Optional[EnhancerDataset] = None,
|
||||
dataset: Optional[MayaDataset] = None,
|
||||
duration: Optional[float] = None,
|
||||
loss: Union[str, List] = "mse",
|
||||
metric: Union[str, List] = "mse",
|
||||
):
|
||||
duration = (
|
||||
dataset.duration if isinstance(dataset, EnhancerDataset) else None
|
||||
dataset.duration if isinstance(dataset, MayaDataset) else duration
|
||||
)
|
||||
if dataset is not None:
|
||||
if sampling_rate != dataset.sampling_rate:
|
||||
logging.warning(
|
||||
warnings.warn(
|
||||
f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
|
||||
)
|
||||
sampling_rate = dataset.sampling_rate
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
from mayavoz.utils.config import Files
|
||||
from mayavoz.utils.io import Audio
|
||||
from mayavoz.utils.utils import check_files
|
||||
|
|
@ -85,7 +85,7 @@ class ConviSTFT(ConvFFT):
|
|||
input = torch.cat([real, imag], 1)
|
||||
out = F.conv_transpose1d(input, self.weight, stride=self.hop_size)
|
||||
coeff = self.window.unsqueeze(1).repeat(1, 1, input.size(-1)) ** 2
|
||||
coeff.to(input.device)
|
||||
coeff = coeff.to(input.device)
|
||||
coeff = F.conv_transpose1d(coeff, self.enframe, stride=self.hop_size)
|
||||
out = out / (coeff + 1e-8)
|
||||
pad = self.window_len - self.hop_size
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
import os
|
||||
from typing import Optional
|
||||
|
||||
from enhancer.utils.config import Files
|
||||
from mayavoz.utils.config import Files
|
||||
|
||||
|
||||
def check_files(root_dir: str, files: Files):
|
||||
|
|
@ -0,0 +1,338 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ccd61d5c",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Custom model training using mayavoz [advanced]\n",
|
||||
"\n",
|
||||
"In this tutorial, we will cover advanced usages and customizations for training your own speecg enhancement model. \n",
|
||||
"\n",
|
||||
" - [Data preparation using MayaDataset](#dataprep)\n",
|
||||
" - [Model customization](#modelcustom)\n",
|
||||
" - [callbacks & LR schedulers](#callbacks)\n",
|
||||
" - [Model training & testing](#train)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "726c320f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"- **install mayavoz**"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c987c799",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"! pip install -q mayavoz"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "8ff9857b",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"<div id=\"dataprep\"></div>\n",
|
||||
"\n",
|
||||
"### Data preparation\n",
|
||||
"\n",
|
||||
"`Files` is a dataclass that wraps and holds train/test paths togethor. There are usually one folder each for clean and noisy data. These paths must be relative to a `root_dir` where all these directories reside. For example\n",
|
||||
"\n",
|
||||
"```\n",
|
||||
"- VCTK/\n",
|
||||
" |__ clean_train_wav/\n",
|
||||
" |__ noisy_train_wav/\n",
|
||||
" |__ clean_test_wav/\n",
|
||||
" |__ noisy_test_wav/\n",
|
||||
" \n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "64cbc0c8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from mayavoz.utils import Files\n",
|
||||
"file = Files(train_clean=\"clean_train_wav\",\n",
|
||||
" train_noisy=\"noisy_train_wav\",\n",
|
||||
" test_clean=\"clean_test_wav\",\n",
|
||||
" test_noisy=\"noisy_test_wav\")\n",
|
||||
"root_dir = \"VCTK\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2d324bd1",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"- `name`: name of the dataset. \n",
|
||||
"- `duration`: control the duration of each audio instance fed into your model.\n",
|
||||
"- `stride` is used if set to move the sliding window.\n",
|
||||
"- `sampling_rate`: desired sampling rate for audio\n",
|
||||
"- `batch_size`: model batch size\n",
|
||||
"- `min_valid_minutes`: minimum validation in minutes. Validation is automatically selected from training set. (exclusive users).\n",
|
||||
"- `matching_function`: there are two types of mapping functions.\n",
|
||||
" - `one_to_one` : In this one clean file will only have one corresponding noisy file. For example Valentini datasets\n",
|
||||
" - `one_to_many` : In this one clean file will only have one corresponding noisy file. For example MS-SNSD dataset.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "6834941d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"name = \"vctk\"\n",
|
||||
"duration : 4.5\n",
|
||||
"stride : 2.0\n",
|
||||
"sampling_rate : 16000\n",
|
||||
"min_valid_minutes : 20.0\n",
|
||||
"batch_size : 32\n",
|
||||
"matching_function : \"one_to_one\"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d08c6bf8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from mayavoz.dataset import MayaDataset\n",
|
||||
"dataset = MayaDataset(\n",
|
||||
" name=name,\n",
|
||||
" root_dir=root_dir,\n",
|
||||
" files=files,\n",
|
||||
" duration=duration,\n",
|
||||
" stride=stride,\n",
|
||||
" sampling_rate=sampling_rate,\n",
|
||||
" batch_size=batch_size,\n",
|
||||
" min_valid_minutes=min_valid_minutes,\n",
|
||||
" matching_function=matching_function\n",
|
||||
" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "5b315bde",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now your custom dataloader is ready!"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "01548fe5",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"<div id=\"modelcustom\"></div>\n",
|
||||
"\n",
|
||||
"### Model Customization\n",
|
||||
"Now, this is very easy. \n",
|
||||
"\n",
|
||||
"- Import the preferred model from `mayavoz.models`. Currently 3 models are implemented.\n",
|
||||
" - `WaveUnet`\n",
|
||||
" - `Demucs`\n",
|
||||
" - `DCCRN`\n",
|
||||
"- Each of model hyperparameters such as depth,kernel_size,stride etc can be controlled by you. Just check the parameters and pass it to as required.\n",
|
||||
"- `sampling_rate`: sampling rate (should be equal to dataset sampling rate)\n",
|
||||
"- `dataset`: mayavoz dataset object as prepared earlier.\n",
|
||||
"- `loss` : model loss. Multiple loss functions are available.\n",
|
||||
"\n",
|
||||
" \n",
|
||||
" \n",
|
||||
"you can pass one (as string)/more (as list of strings) of these loss functions as per your requirements. For example, model will automatically calculate loss as average of `mae` and `mse` if you pass loss as `[\"mae\",\"mse\"]`. Available loss functions are `mse`,`mae`,`si-snr`.\n",
|
||||
"\n",
|
||||
"mayavoz can accept **custom loss functions**. It should be of the form.\n",
|
||||
"```\n",
|
||||
"class your_custom_loss(nn.Module):\n",
|
||||
" def __init__(self,**kwargs):\n",
|
||||
" self.higher_better = False ## loss minimization direction\n",
|
||||
" self.name = \"your_loss_name\" ## loss name logging \n",
|
||||
" ...\n",
|
||||
" def forward(self,prediction, target):\n",
|
||||
" loss = ....\n",
|
||||
" return loss\n",
|
||||
" \n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"- metrics : validation metrics. Available options `mae`,`mse`,`si-sdr`,`si-sdr`,`pesq`,`stoi`. One or more can be used.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b36b457c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from mayavoz.models import Demucs\n",
|
||||
"model = Demucs(\n",
|
||||
" sampling_rate=16000,\n",
|
||||
" dataset=dataset,\n",
|
||||
" loss=[\"mae\"],\n",
|
||||
" metrics=[\"stoi\",\"pesq\"])\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1523d638",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"<div id=\"callbacks\"></div>\n",
|
||||
"\n",
|
||||
"### learning rate schedulers and callbacks\n",
|
||||
"Here I am using `ReduceLROnPlateau`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "8de6931c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
|
||||
"\n",
|
||||
"def configure_optimizers(self):\n",
|
||||
" optimizer = instantiate(\n",
|
||||
" config.optimizer,\n",
|
||||
" lr=parameters.get(\"lr\"),\n",
|
||||
" params=self.parameters(),\n",
|
||||
" )\n",
|
||||
" scheduler = ReduceLROnPlateau(\n",
|
||||
" optimizer=optimizer,\n",
|
||||
" mode=direction,\n",
|
||||
" factor=parameters.get(\"ReduceLr_factor\", 0.1),\n",
|
||||
" verbose=True,\n",
|
||||
" min_lr=parameters.get(\"min_lr\", 1e-6),\n",
|
||||
" patience=parameters.get(\"ReduceLr_patience\", 3),\n",
|
||||
" )\n",
|
||||
" return {\n",
|
||||
" \"optimizer\": optimizer,\n",
|
||||
" \"lr_scheduler\": scheduler,\n",
|
||||
" \"monitor\": f'valid_{parameters.get(\"ReduceLr_monitor\", \"loss\")}',\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"model.configure_optimizers = MethodType(configure_optimizers, model)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2f7b5af5",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"you can use any number of callbacks and pass it directly to pytorch lightning trainer. Here I am using only `ModelCheckpoint`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "6f6b62a1",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"callbacks = []\n",
|
||||
"direction = model.valid_monitor ## min or max \n",
|
||||
"checkpoint = ModelCheckpoint(\n",
|
||||
" dirpath=\"./model\",\n",
|
||||
" filename=f\"model_filename\",\n",
|
||||
" monitor=\"valid_loss\",\n",
|
||||
" verbose=False,\n",
|
||||
" mode=direction,\n",
|
||||
" every_n_epochs=1,\n",
|
||||
" )\n",
|
||||
"callbacks.append(checkpoint)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f3534445",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"<div id=\"train\"></div>\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"### Train"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "3dc0348b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import pytorch_lightning as pl\n",
|
||||
"trainer = plt.Trainer(max_epochs=1,callbacks=callbacks,accelerator=\"gpu\")\n",
|
||||
"trainer.fit(model)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "56dcfec1",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"- Test your model agaist test dataset"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "63851feb",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"trainer.test(model)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4d3f5350",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"**Hurray! you have your speech enhancement model trained and tested.**\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "10d630e8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "enhancer",
|
||||
"language": "python",
|
||||
"name": "enhancer"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.13"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
File diff suppressed because one or more lines are too long
|
|
@ -0,0 +1,120 @@
|
|||
import os
|
||||
from types import MethodType
|
||||
|
||||
import hydra
|
||||
from hydra.utils import instantiate
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from pytorch_lightning.callbacks import (
|
||||
EarlyStopping,
|
||||
LearningRateMonitor,
|
||||
ModelCheckpoint,
|
||||
)
|
||||
from pytorch_lightning.loggers import MLFlowLogger
|
||||
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
||||
|
||||
# from torch_audiomentations import Compose, Shift
|
||||
|
||||
os.environ["HYDRA_FULL_ERROR"] = "1"
|
||||
JOB_ID = os.environ.get("SLURM_JOBID", "0")
|
||||
|
||||
|
||||
@hydra.main(config_path="train_config", config_name="config")
|
||||
def train(config: DictConfig):
|
||||
|
||||
OmegaConf.save(config, "config.yaml")
|
||||
|
||||
callbacks = []
|
||||
logger = MLFlowLogger(
|
||||
experiment_name=config.mlflow.experiment_name,
|
||||
run_name=config.mlflow.run_name,
|
||||
tags={"JOB_ID": JOB_ID},
|
||||
)
|
||||
|
||||
parameters = config.hyperparameters
|
||||
# apply_augmentations = Compose(
|
||||
# [
|
||||
# Shift(min_shift=0.5, max_shift=1.0, shift_unit="seconds", p=0.5),
|
||||
# ]
|
||||
# )
|
||||
|
||||
dataset = instantiate(config.dataset, augmentations=None)
|
||||
model = instantiate(
|
||||
config.model,
|
||||
dataset=dataset,
|
||||
lr=parameters.get("lr"),
|
||||
loss=parameters.get("loss"),
|
||||
metric=parameters.get("metric"),
|
||||
)
|
||||
|
||||
direction = model.valid_monitor
|
||||
checkpoint = ModelCheckpoint(
|
||||
dirpath="./model",
|
||||
filename=f"model_{JOB_ID}",
|
||||
monitor="valid_loss",
|
||||
verbose=False,
|
||||
mode=direction,
|
||||
every_n_epochs=1,
|
||||
)
|
||||
callbacks.append(checkpoint)
|
||||
callbacks.append(LearningRateMonitor(logging_interval="epoch"))
|
||||
|
||||
if parameters.get("Early_stop", False):
|
||||
early_stopping = EarlyStopping(
|
||||
monitor="val_loss",
|
||||
mode=direction,
|
||||
min_delta=0.0,
|
||||
patience=parameters.get("EarlyStopping_patience", 10),
|
||||
strict=True,
|
||||
verbose=False,
|
||||
)
|
||||
callbacks.append(early_stopping)
|
||||
|
||||
def configure_optimizers(self):
|
||||
optimizer = instantiate(
|
||||
config.optimizer,
|
||||
lr=parameters.get("lr"),
|
||||
params=self.parameters(),
|
||||
)
|
||||
scheduler = ReduceLROnPlateau(
|
||||
optimizer=optimizer,
|
||||
mode=direction,
|
||||
factor=parameters.get("ReduceLr_factor", 0.1),
|
||||
verbose=True,
|
||||
min_lr=parameters.get("min_lr", 1e-6),
|
||||
patience=parameters.get("ReduceLr_patience", 3),
|
||||
)
|
||||
return {
|
||||
"optimizer": optimizer,
|
||||
"lr_scheduler": scheduler,
|
||||
"monitor": f'valid_{parameters.get("ReduceLr_monitor", "loss")}',
|
||||
}
|
||||
|
||||
model.configure_optimizers = MethodType(configure_optimizers, model)
|
||||
|
||||
trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks)
|
||||
trainer.fit(model)
|
||||
trainer.test(model)
|
||||
|
||||
logger.experiment.log_artifact(
|
||||
logger.run_id, f"{trainer.default_root_dir}/config.yaml"
|
||||
)
|
||||
|
||||
saved_location = os.path.join(
|
||||
trainer.default_root_dir, "model", f"model_{JOB_ID}.ckpt"
|
||||
)
|
||||
if os.path.isfile(saved_location):
|
||||
logger.experiment.log_artifact(logger.run_id, saved_location)
|
||||
logger.experiment.log_param(
|
||||
logger.run_id,
|
||||
"num_train_steps_per_epoch",
|
||||
dataset.train__len__() / dataset.batch_size,
|
||||
)
|
||||
logger.experiment.log_param(
|
||||
logger.run_id,
|
||||
"num_valid_steps_per_epoch",
|
||||
dataset.val__len__() / dataset.batch_size,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train()
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
defaults:
|
||||
- model : Demucs
|
||||
- dataset : MS-SNSD
|
||||
- optimizer : Adam
|
||||
- hyperparameters : default
|
||||
- trainer : default
|
||||
- mlflow : experiment
|
||||
|
|
@ -0,0 +1,13 @@
|
|||
_target_: mayavoz.data.dataset.MayaDataset
|
||||
name : MS-SDSD
|
||||
root_dir : /Users/shahules/Myprojects/MS-SNSD
|
||||
duration : 1.5
|
||||
stride : 1
|
||||
sampling_rate: 16000
|
||||
batch_size: 32
|
||||
min_valid_minutes: 25
|
||||
files:
|
||||
train_clean : CleanSpeech_training
|
||||
test_clean : CleanSpeech_training
|
||||
train_noisy : NoisySpeech_training
|
||||
test_noisy : NoisySpeech_training
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
loss : si-snr
|
||||
metric : [stoi,pesq]
|
||||
lr : 0.001
|
||||
ReduceLr_patience : 10
|
||||
ReduceLr_factor : 0.5
|
||||
min_lr : 0.000001
|
||||
EarlyStopping_factor : 10
|
||||
|
|
@ -0,0 +1,2 @@
|
|||
experiment_name : shahules/mayavoz
|
||||
run_name : Demucs + Vtck with stride + augmentations
|
||||
|
|
@ -0,0 +1,25 @@
|
|||
_target_: mayavoz.models.dccrn.DCCRN
|
||||
num_channels: 1
|
||||
sampling_rate : 16000
|
||||
complex_lstm : True
|
||||
complex_norm : True
|
||||
complex_relu : True
|
||||
masking_mode : True
|
||||
|
||||
encoder_decoder:
|
||||
initial_output_channels : 32
|
||||
depth : 6
|
||||
kernel_size : 5
|
||||
growth_factor : 2
|
||||
stride : 2
|
||||
padding : 2
|
||||
output_padding : 1
|
||||
|
||||
lstm:
|
||||
num_layers : 2
|
||||
hidden_size : 256
|
||||
|
||||
stft:
|
||||
window_len : 400
|
||||
hop_size : 100
|
||||
nfft : 512
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
_target_: torch.optim.Adam
|
||||
lr: 1e-3
|
||||
betas: [0.9, 0.999]
|
||||
eps: 1e-08
|
||||
weight_decay: 0
|
||||
amsgrad: False
|
||||
|
|
@ -0,0 +1,46 @@
|
|||
_target_: pytorch_lightning.Trainer
|
||||
accelerator: gpu
|
||||
accumulate_grad_batches: 1
|
||||
amp_backend: native
|
||||
auto_lr_find: True
|
||||
auto_scale_batch_size: False
|
||||
auto_select_gpus: True
|
||||
benchmark: False
|
||||
check_val_every_n_epoch: 1
|
||||
detect_anomaly: False
|
||||
deterministic: False
|
||||
devices: 2
|
||||
enable_checkpointing: True
|
||||
enable_model_summary: True
|
||||
enable_progress_bar: True
|
||||
fast_dev_run: False
|
||||
gpus: null
|
||||
gradient_clip_val: 0
|
||||
gradient_clip_algorithm: norm
|
||||
ipus: null
|
||||
limit_predict_batches: 1.0
|
||||
limit_test_batches: 1.0
|
||||
limit_train_batches: 1.0
|
||||
limit_val_batches: 1.0
|
||||
log_every_n_steps: 50
|
||||
max_epochs: 200
|
||||
max_steps: -1
|
||||
max_time: null
|
||||
min_epochs: 1
|
||||
min_steps: null
|
||||
move_metrics_to_cpu: False
|
||||
multiple_trainloader_mode: max_size_cycle
|
||||
num_nodes: 1
|
||||
num_processes: 1
|
||||
num_sanity_val_steps: 2
|
||||
overfit_batches: 0.0
|
||||
precision: 32
|
||||
profiler: null
|
||||
reload_dataloaders_every_n_epochs: 0
|
||||
replace_sampler_ddp: True
|
||||
strategy: ddp
|
||||
sync_batchnorm: False
|
||||
tpu_cores: null
|
||||
track_grad_norm: -1
|
||||
val_check_interval: 1.0
|
||||
weights_save_path: null
|
||||
|
|
@ -0,0 +1,2 @@
|
|||
_target_: pytorch_lightning.Trainer
|
||||
fast_dev_run: True
|
||||
|
|
@ -0,0 +1,120 @@
|
|||
import os
|
||||
from types import MethodType
|
||||
|
||||
import hydra
|
||||
from hydra.utils import instantiate
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from pytorch_lightning.callbacks import (
|
||||
EarlyStopping,
|
||||
LearningRateMonitor,
|
||||
ModelCheckpoint,
|
||||
)
|
||||
from pytorch_lightning.loggers import MLFlowLogger
|
||||
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
||||
|
||||
# from torch_audiomentations import Compose, Shift
|
||||
|
||||
os.environ["HYDRA_FULL_ERROR"] = "1"
|
||||
JOB_ID = os.environ.get("SLURM_JOBID", "0")
|
||||
|
||||
|
||||
@hydra.main(config_path="train_config", config_name="config")
|
||||
def train(config: DictConfig):
|
||||
|
||||
OmegaConf.save(config, "config.yaml")
|
||||
|
||||
callbacks = []
|
||||
logger = MLFlowLogger(
|
||||
experiment_name=config.mlflow.experiment_name,
|
||||
run_name=config.mlflow.run_name,
|
||||
tags={"JOB_ID": JOB_ID},
|
||||
)
|
||||
|
||||
parameters = config.hyperparameters
|
||||
# apply_augmentations = Compose(
|
||||
# [
|
||||
# Shift(min_shift=0.5, max_shift=1.0, shift_unit="seconds", p=0.5),
|
||||
# ]
|
||||
# )
|
||||
|
||||
dataset = instantiate(config.dataset, augmentations=None)
|
||||
model = instantiate(
|
||||
config.model,
|
||||
dataset=dataset,
|
||||
lr=parameters.get("lr"),
|
||||
loss=parameters.get("loss"),
|
||||
metric=parameters.get("metric"),
|
||||
)
|
||||
|
||||
direction = model.valid_monitor
|
||||
checkpoint = ModelCheckpoint(
|
||||
dirpath="./model",
|
||||
filename=f"model_{JOB_ID}",
|
||||
monitor="valid_loss",
|
||||
verbose=False,
|
||||
mode=direction,
|
||||
every_n_epochs=1,
|
||||
)
|
||||
callbacks.append(checkpoint)
|
||||
callbacks.append(LearningRateMonitor(logging_interval="epoch"))
|
||||
|
||||
if parameters.get("Early_stop", False):
|
||||
early_stopping = EarlyStopping(
|
||||
monitor="val_loss",
|
||||
mode=direction,
|
||||
min_delta=0.0,
|
||||
patience=parameters.get("EarlyStopping_patience", 10),
|
||||
strict=True,
|
||||
verbose=False,
|
||||
)
|
||||
callbacks.append(early_stopping)
|
||||
|
||||
def configure_optimizers(self):
|
||||
optimizer = instantiate(
|
||||
config.optimizer,
|
||||
lr=parameters.get("lr"),
|
||||
params=self.parameters(),
|
||||
)
|
||||
scheduler = ReduceLROnPlateau(
|
||||
optimizer=optimizer,
|
||||
mode=direction,
|
||||
factor=parameters.get("ReduceLr_factor", 0.1),
|
||||
verbose=True,
|
||||
min_lr=parameters.get("min_lr", 1e-6),
|
||||
patience=parameters.get("ReduceLr_patience", 3),
|
||||
)
|
||||
return {
|
||||
"optimizer": optimizer,
|
||||
"lr_scheduler": scheduler,
|
||||
"monitor": f'valid_{parameters.get("ReduceLr_monitor", "loss")}',
|
||||
}
|
||||
|
||||
model.configure_optimizers = MethodType(configure_optimizers, model)
|
||||
|
||||
trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks)
|
||||
trainer.fit(model)
|
||||
trainer.test(model)
|
||||
|
||||
logger.experiment.log_artifact(
|
||||
logger.run_id, f"{trainer.default_root_dir}/config.yaml"
|
||||
)
|
||||
|
||||
saved_location = os.path.join(
|
||||
trainer.default_root_dir, "model", f"model_{JOB_ID}.ckpt"
|
||||
)
|
||||
if os.path.isfile(saved_location):
|
||||
logger.experiment.log_artifact(logger.run_id, saved_location)
|
||||
logger.experiment.log_param(
|
||||
logger.run_id,
|
||||
"num_train_steps_per_epoch",
|
||||
dataset.train__len__() / dataset.batch_size,
|
||||
)
|
||||
logger.experiment.log_param(
|
||||
logger.run_id,
|
||||
"num_valid_steps_per_epoch",
|
||||
dataset.val__len__() / dataset.batch_size,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train()
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
defaults:
|
||||
- model : Demucs
|
||||
- dataset : MS-SNSD
|
||||
- optimizer : Adam
|
||||
- hyperparameters : default
|
||||
- trainer : default
|
||||
- mlflow : experiment
|
||||
|
|
@ -0,0 +1,13 @@
|
|||
_target_: mayavoz.data.dataset.MayaDataset
|
||||
name : MS-SDSD
|
||||
root_dir : /Users/shahules/Myprojects/MS-SNSD
|
||||
duration : 5
|
||||
stride : 1
|
||||
sampling_rate: 16000
|
||||
batch_size: 32
|
||||
min_valid_minutes: 25
|
||||
files:
|
||||
train_clean : CleanSpeech_training
|
||||
test_clean : CleanSpeech_training
|
||||
train_noisy : NoisySpeech_training
|
||||
test_noisy : NoisySpeech_training
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
loss : mae
|
||||
metric : [stoi,pesq]
|
||||
lr : 0.0003
|
||||
ReduceLr_patience : 10
|
||||
ReduceLr_factor : 0.5
|
||||
min_lr : 0.000001
|
||||
EarlyStopping_factor : 10
|
||||
|
|
@ -0,0 +1,2 @@
|
|||
experiment_name : shahules/mayavoz
|
||||
run_name : demucs-ms-snsd
|
||||
|
|
@ -0,0 +1,16 @@
|
|||
_target_: mayavoz.models.demucs.Demucs
|
||||
num_channels: 1
|
||||
resample: 4
|
||||
sampling_rate : 16000
|
||||
|
||||
encoder_decoder:
|
||||
depth: 4
|
||||
initial_output_channels: 64
|
||||
kernel_size: 8
|
||||
stride: 4
|
||||
growth_factor: 2
|
||||
glu: True
|
||||
|
||||
lstm:
|
||||
bidirectional: False
|
||||
num_layers: 2
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
_target_: torch.optim.Adam
|
||||
lr: 1e-3
|
||||
betas: [0.9, 0.999]
|
||||
eps: 1e-08
|
||||
weight_decay: 0
|
||||
amsgrad: False
|
||||
|
|
@ -0,0 +1,46 @@
|
|||
_target_: pytorch_lightning.Trainer
|
||||
accelerator: gpu
|
||||
accumulate_grad_batches: 1
|
||||
amp_backend: native
|
||||
auto_lr_find: True
|
||||
auto_scale_batch_size: False
|
||||
auto_select_gpus: True
|
||||
benchmark: False
|
||||
check_val_every_n_epoch: 1
|
||||
detect_anomaly: False
|
||||
deterministic: False
|
||||
devices: 2
|
||||
enable_checkpointing: True
|
||||
enable_model_summary: True
|
||||
enable_progress_bar: True
|
||||
fast_dev_run: False
|
||||
gpus: null
|
||||
gradient_clip_val: 0
|
||||
gradient_clip_algorithm: norm
|
||||
ipus: null
|
||||
limit_predict_batches: 1.0
|
||||
limit_test_batches: 1.0
|
||||
limit_train_batches: 1.0
|
||||
limit_val_batches: 1.0
|
||||
log_every_n_steps: 50
|
||||
max_epochs: 200
|
||||
max_steps: -1
|
||||
max_time: null
|
||||
min_epochs: 1
|
||||
min_steps: null
|
||||
move_metrics_to_cpu: False
|
||||
multiple_trainloader_mode: max_size_cycle
|
||||
num_nodes: 1
|
||||
num_processes: 1
|
||||
num_sanity_val_steps: 2
|
||||
overfit_batches: 0.0
|
||||
precision: 32
|
||||
profiler: null
|
||||
reload_dataloaders_every_n_epochs: 0
|
||||
replace_sampler_ddp: True
|
||||
strategy: ddp
|
||||
sync_batchnorm: False
|
||||
tpu_cores: null
|
||||
track_grad_norm: -1
|
||||
val_check_interval: 1.0
|
||||
weights_save_path: null
|
||||
|
|
@ -0,0 +1,2 @@
|
|||
_target_: pytorch_lightning.Trainer
|
||||
fast_dev_run: True
|
||||
|
|
@ -0,0 +1,17 @@
|
|||
### Microsoft Scalable Noisy Speech Dataset (MS-SNSD)
|
||||
|
||||
MS-SNSD is a speech datasetthat can scale to arbitrary sizes depending on the number of speakers, noise types, and Speech to Noise Ratio (SNR) levels desired.
|
||||
|
||||
### Dataset download & setup
|
||||
- Follow steps in the official repo [here](https://github.com/microsoft/MS-SNSD) to download and setup the dataset.
|
||||
|
||||
**References**
|
||||
```BibTex
|
||||
@article{reddy2019scalable,
|
||||
title={A Scalable Noisy Speech Dataset and Online Subjective Test Framework},
|
||||
author={Reddy, Chandan KA and Beyrami, Ebrahim and Pool, Jamie and Cutler, Ross and Srinivasan, Sriram and Gehrke, Johannes},
|
||||
journal={Proc. Interspeech 2019},
|
||||
pages={1816--1820},
|
||||
year={2019}
|
||||
}
|
||||
```
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
defaults:
|
||||
- model : Demucs
|
||||
- dataset : Vctk
|
||||
- optimizer : Adam
|
||||
- hyperparameters : default
|
||||
- trainer : default
|
||||
- mlflow : experiment
|
||||
|
|
@ -0,0 +1,13 @@
|
|||
_target_: mayavoz.data.dataset.MayaDataset
|
||||
name : vctk
|
||||
root_dir : /scratch/c.sistc3/DS_10283_2791
|
||||
duration : 4.5
|
||||
stride : 0.5
|
||||
sampling_rate: 16000
|
||||
batch_size: 32
|
||||
min_valid_minutes : 25
|
||||
files:
|
||||
train_clean : clean_trainset_28spk_wav
|
||||
test_clean : clean_testset_wav
|
||||
train_noisy : noisy_trainset_28spk_wav
|
||||
test_noisy : noisy_testset_wav
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
loss : mae
|
||||
metric : [stoi,pesq,si-sdr]
|
||||
lr : 0.0003
|
||||
Early_stop : False
|
||||
ReduceLr_patience : 10
|
||||
ReduceLr_factor : 0.1
|
||||
min_lr : 0.000001
|
||||
EarlyStopping_factor : 10
|
||||
|
|
@ -0,0 +1,2 @@
|
|||
experiment_name : shahules/mayavoz
|
||||
run_name : baseline
|
||||
|
|
@ -0,0 +1,16 @@
|
|||
_target_: mayavoz.models.demucs.Demucs
|
||||
num_channels: 1
|
||||
resample: 4
|
||||
sampling_rate : 16000
|
||||
|
||||
encoder_decoder:
|
||||
depth: 4
|
||||
initial_output_channels: 64
|
||||
kernel_size: 8
|
||||
stride: 4
|
||||
growth_factor: 2
|
||||
glu: True
|
||||
|
||||
lstm:
|
||||
bidirectional: True
|
||||
num_layers: 2
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
_target_: torch.optim.Adam
|
||||
lr: 1e-3
|
||||
betas: [0.9, 0.999]
|
||||
eps: 1e-08
|
||||
weight_decay: 0
|
||||
amsgrad: False
|
||||
|
|
@ -0,0 +1,46 @@
|
|||
_target_: pytorch_lightning.Trainer
|
||||
accelerator: gpu
|
||||
accumulate_grad_batches: 1
|
||||
amp_backend: native
|
||||
auto_lr_find: True
|
||||
auto_scale_batch_size: False
|
||||
auto_select_gpus: True
|
||||
benchmark: False
|
||||
check_val_every_n_epoch: 1
|
||||
detect_anomaly: False
|
||||
deterministic: False
|
||||
devices: 1
|
||||
enable_checkpointing: True
|
||||
enable_model_summary: True
|
||||
enable_progress_bar: True
|
||||
fast_dev_run: False
|
||||
gpus: null
|
||||
gradient_clip_val: 0
|
||||
gradient_clip_algorithm: norm
|
||||
ipus: null
|
||||
limit_predict_batches: 1.0
|
||||
limit_test_batches: 1.0
|
||||
limit_train_batches: 1.0
|
||||
limit_val_batches: 1.0
|
||||
log_every_n_steps: 50
|
||||
max_epochs: 200
|
||||
max_steps: -1
|
||||
max_time: null
|
||||
min_epochs: 1
|
||||
min_steps: null
|
||||
move_metrics_to_cpu: False
|
||||
multiple_trainloader_mode: max_size_cycle
|
||||
num_nodes: 1
|
||||
num_processes: 1
|
||||
num_sanity_val_steps: 2
|
||||
overfit_batches: 0.0
|
||||
precision: 32
|
||||
profiler: null
|
||||
reload_dataloaders_every_n_epochs: 0
|
||||
replace_sampler_ddp: True
|
||||
strategy: null
|
||||
sync_batchnorm: False
|
||||
tpu_cores: null
|
||||
track_grad_norm: -1
|
||||
val_check_interval: 1.0
|
||||
weights_save_path: null
|
||||
|
|
@ -0,0 +1,120 @@
|
|||
import os
|
||||
from types import MethodType
|
||||
|
||||
import hydra
|
||||
from hydra.utils import instantiate
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from pytorch_lightning.callbacks import (
|
||||
EarlyStopping,
|
||||
LearningRateMonitor,
|
||||
ModelCheckpoint,
|
||||
)
|
||||
from pytorch_lightning.loggers import MLFlowLogger
|
||||
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
||||
|
||||
# from torch_audiomentations import Compose, Shift
|
||||
|
||||
os.environ["HYDRA_FULL_ERROR"] = "1"
|
||||
JOB_ID = os.environ.get("SLURM_JOBID", "0")
|
||||
|
||||
|
||||
@hydra.main(config_path="train_config", config_name="config")
|
||||
def main(config: DictConfig):
|
||||
|
||||
OmegaConf.save(config, "config_log.yaml")
|
||||
|
||||
callbacks = []
|
||||
logger = MLFlowLogger(
|
||||
experiment_name=config.mlflow.experiment_name,
|
||||
run_name=config.mlflow.run_name,
|
||||
tags={"JOB_ID": JOB_ID},
|
||||
)
|
||||
|
||||
parameters = config.hyperparameters
|
||||
# apply_augmentations = Compose(
|
||||
# [
|
||||
# Shift(min_shift=0.5, max_shift=1.0, shift_unit="seconds", p=0.5),
|
||||
# ]
|
||||
# )
|
||||
|
||||
dataset = instantiate(config.dataset, augmentations=None)
|
||||
model = instantiate(
|
||||
config.model,
|
||||
dataset=dataset,
|
||||
lr=parameters.get("lr"),
|
||||
loss=parameters.get("loss"),
|
||||
metric=parameters.get("metric"),
|
||||
)
|
||||
|
||||
direction = model.valid_monitor
|
||||
checkpoint = ModelCheckpoint(
|
||||
dirpath="./model",
|
||||
filename=f"model_{JOB_ID}",
|
||||
monitor="valid_loss",
|
||||
verbose=False,
|
||||
mode=direction,
|
||||
every_n_epochs=1,
|
||||
)
|
||||
callbacks.append(checkpoint)
|
||||
callbacks.append(LearningRateMonitor(logging_interval="epoch"))
|
||||
|
||||
if parameters.get("Early_stop", False):
|
||||
early_stopping = EarlyStopping(
|
||||
monitor="val_loss",
|
||||
mode=direction,
|
||||
min_delta=0.0,
|
||||
patience=parameters.get("EarlyStopping_patience", 10),
|
||||
strict=True,
|
||||
verbose=False,
|
||||
)
|
||||
callbacks.append(early_stopping)
|
||||
|
||||
def configure_optimizers(self):
|
||||
optimizer = instantiate(
|
||||
config.optimizer,
|
||||
lr=parameters.get("lr"),
|
||||
params=self.parameters(),
|
||||
)
|
||||
scheduler = ReduceLROnPlateau(
|
||||
optimizer=optimizer,
|
||||
mode=direction,
|
||||
factor=parameters.get("ReduceLr_factor", 0.1),
|
||||
verbose=True,
|
||||
min_lr=parameters.get("min_lr", 1e-6),
|
||||
patience=parameters.get("ReduceLr_patience", 3),
|
||||
)
|
||||
return {
|
||||
"optimizer": optimizer,
|
||||
"lr_scheduler": scheduler,
|
||||
"monitor": f'valid_{parameters.get("ReduceLr_monitor", "loss")}',
|
||||
}
|
||||
|
||||
model.configure_optimizers = MethodType(configure_optimizers, model)
|
||||
|
||||
trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks)
|
||||
trainer.fit(model)
|
||||
trainer.test(model)
|
||||
|
||||
logger.experiment.log_artifact(
|
||||
logger.run_id, f"{trainer.default_root_dir}/config_log.yaml"
|
||||
)
|
||||
|
||||
saved_location = os.path.join(
|
||||
trainer.default_root_dir, "model", f"model_{JOB_ID}.ckpt"
|
||||
)
|
||||
if os.path.isfile(saved_location):
|
||||
logger.experiment.log_artifact(logger.run_id, saved_location)
|
||||
logger.experiment.log_param(
|
||||
logger.run_id,
|
||||
"num_train_steps_per_epoch",
|
||||
dataset.train__len__() / dataset.batch_size,
|
||||
)
|
||||
logger.experiment.log_param(
|
||||
logger.run_id,
|
||||
"num_valid_steps_per_epoch",
|
||||
dataset.val__len__() / dataset.batch_size,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
defaults:
|
||||
- model : WaveUnet
|
||||
- dataset : Vctk
|
||||
- optimizer : Adam
|
||||
- hyperparameters : default
|
||||
- trainer : default
|
||||
- mlflow : experiment
|
||||
|
|
@ -0,0 +1,13 @@
|
|||
_target_: mayavoz.data.dataset.MayaDataset
|
||||
name : vctk
|
||||
root_dir : /scratch/c.sistc3/DS_10283_2791
|
||||
duration : 2
|
||||
stride : 1
|
||||
sampling_rate: 16000
|
||||
batch_size: 128
|
||||
valid_minutes : 25
|
||||
files:
|
||||
train_clean : clean_trainset_28spk_wav
|
||||
test_clean : clean_testset_wav
|
||||
train_noisy : noisy_trainset_28spk_wav
|
||||
test_noisy : noisy_testset_wav
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
loss : mae
|
||||
metric : [stoi,pesq,si-sdr]
|
||||
lr : 0.003
|
||||
ReduceLr_patience : 10
|
||||
ReduceLr_factor : 0.1
|
||||
min_lr : 0.000001
|
||||
EarlyStopping_factor : 10
|
||||
Early_stop : False
|
||||
|
|
@ -0,0 +1,2 @@
|
|||
experiment_name : shahules/mayavoz
|
||||
run_name : baseline
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
_target_: mayavoz.models.waveunet.WaveUnet
|
||||
num_channels : 1
|
||||
depth : 9
|
||||
initial_output_channels: 24
|
||||
sampling_rate : 16000
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
_target_: torch.optim.Adam
|
||||
lr: 1e-3
|
||||
betas: [0.9, 0.999]
|
||||
eps: 1e-08
|
||||
weight_decay: 0
|
||||
amsgrad: False
|
||||
|
|
@ -0,0 +1,46 @@
|
|||
_target_: pytorch_lightning.Trainer
|
||||
accelerator: gpu
|
||||
accumulate_grad_batches: 1
|
||||
amp_backend: native
|
||||
auto_lr_find: True
|
||||
auto_scale_batch_size: False
|
||||
auto_select_gpus: True
|
||||
benchmark: False
|
||||
check_val_every_n_epoch: 1
|
||||
detect_anomaly: False
|
||||
deterministic: False
|
||||
devices: 2
|
||||
enable_checkpointing: True
|
||||
enable_model_summary: True
|
||||
enable_progress_bar: True
|
||||
fast_dev_run: False
|
||||
gpus: null
|
||||
gradient_clip_val: 0
|
||||
gradient_clip_algorithm: norm
|
||||
ipus: null
|
||||
limit_predict_batches: 1.0
|
||||
limit_test_batches: 1.0
|
||||
limit_train_batches: 1.0
|
||||
limit_val_batches: 1.0
|
||||
log_every_n_steps: 50
|
||||
max_epochs: 200
|
||||
max_steps: -1
|
||||
max_time: null
|
||||
min_epochs: 1
|
||||
min_steps: null
|
||||
move_metrics_to_cpu: False
|
||||
multiple_trainloader_mode: max_size_cycle
|
||||
num_nodes: 1
|
||||
num_processes: 1
|
||||
num_sanity_val_steps: 2
|
||||
overfit_batches: 0.0
|
||||
precision: 32
|
||||
profiler: null
|
||||
reload_dataloaders_every_n_epochs: 0
|
||||
replace_sampler_ddp: True
|
||||
strategy: ddp
|
||||
sync_batchnorm: False
|
||||
tpu_cores: null
|
||||
track_grad_norm: -1
|
||||
val_check_interval: 1.0
|
||||
weights_save_path: null
|
||||
|
|
@ -0,0 +1,2 @@
|
|||
_target_: pytorch_lightning.Trainer
|
||||
fast_dev_run: True
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
## Valentini dataset
|
||||
|
||||
Clean and noisy parallel speech database. The database was designed to train and test speech enhancement methods that operate at 48kHz. A more detailed description can be found in the papers associated with the database.[official page](https://datashare.ed.ac.uk/handle/10283/2791)
|
||||
|
||||
**References**
|
||||
```BibTex
|
||||
@misc{
|
||||
title={Noisy speech database for training speech enhancement algorithms and TTS models},
|
||||
author={Valentini-Botinhao, Cassia}, year={2017},
|
||||
doi=https://doi.org/10.7488/ds/2117,
|
||||
}
|
||||
```
|
||||
|
|
@ -3,7 +3,7 @@ huggingface-hub>=0.10.0
|
|||
hydra-core>=1.2.0
|
||||
joblib>=1.2.0
|
||||
librosa>=0.9.2
|
||||
mlflow>=1.29.0
|
||||
mlflow>=1.28.0
|
||||
numpy>=1.23.3
|
||||
pesq==0.0.4
|
||||
protobuf>=3.19.6
|
||||
|
|
|
|||
10
setup.cfg
10
setup.cfg
|
|
@ -3,7 +3,7 @@
|
|||
# http://setuptools.readthedocs.io/en/latest/setuptools.html#configuring-setup-using-setup-cfg-files
|
||||
|
||||
[metadata]
|
||||
name = enhancer
|
||||
name = mayavoz
|
||||
description = Deep learning for speech enhacement
|
||||
author = Shahul Ess
|
||||
author-email = shahules786@gmail.com
|
||||
|
|
@ -53,7 +53,7 @@ cli =
|
|||
[options.entry_points]
|
||||
|
||||
console_scripts =
|
||||
enhancer-train=enhancer.cli.train:train
|
||||
mayavoz-train=mayavoz.cli.train:train
|
||||
|
||||
[test]
|
||||
# py.test options when running `python setup.py test`
|
||||
|
|
@ -66,7 +66,7 @@ extras = True
|
|||
# e.g. --cov-report html (or xml) for html/xml output or --junitxml junit.xml
|
||||
# in order to write a coverage file that can be read by Jenkins.
|
||||
addopts =
|
||||
--cov enhancer --cov-report term-missing
|
||||
--cov mayavoz --cov-report term-missing
|
||||
--verbose
|
||||
norecursedirs =
|
||||
dist
|
||||
|
|
@ -98,3 +98,7 @@ exclude =
|
|||
build
|
||||
dist
|
||||
.eggs
|
||||
|
||||
[options.data_files]
|
||||
. = requirements.txt
|
||||
_ = version.txt
|
||||
|
|
|
|||
6
setup.py
6
setup.py
|
|
@ -33,15 +33,15 @@ elif sha != "Unknown":
|
|||
version += "+" + sha[:7]
|
||||
print("-- Building version " + version)
|
||||
|
||||
version_path = ROOT_DIR / "enhancer" / "version.py"
|
||||
version_path = ROOT_DIR / "mayavoz" / "version.py"
|
||||
|
||||
with open(version_path, "w") as f:
|
||||
f.write("__version__ = '{}'\n".format(version))
|
||||
|
||||
if __name__ == "__main__":
|
||||
setup(
|
||||
name="enhancer",
|
||||
namespace_packages=["enhancer"],
|
||||
name="mayavoz",
|
||||
namespace_packages=["mayavoz"],
|
||||
version=version,
|
||||
packages=find_packages(),
|
||||
install_requires=requirements,
|
||||
|
|
|
|||
13
setup.sh
13
setup.sh
|
|
@ -1,13 +0,0 @@
|
|||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
echo "Loading Anaconda Module"
|
||||
module load anaconda
|
||||
|
||||
echo "Creating Virtual Environment"
|
||||
conda env create -f environment.yml || conda env update -f environment.yml
|
||||
|
||||
source activate enhancer
|
||||
|
||||
echo "copying files"
|
||||
# cp /scratch/$USER/TIMIT/.* /deep-transcriber
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from enhancer.loss import mean_absolute_error, mean_squared_error
|
||||
from mayavoz.loss import mean_absolute_error, mean_squared_error
|
||||
|
||||
loss_functions = [mean_absolute_error(), mean_squared_error()]
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
import torch
|
||||
|
||||
from enhancer.models.complexnn.conv import ComplexConv2d, ComplexConvTranspose2d
|
||||
from enhancer.models.complexnn.rnn import ComplexLSTM
|
||||
from enhancer.models.complexnn.utils import ComplexBatchNorm2D
|
||||
from mayavoz.models.complexnn.conv import ComplexConv2d, ComplexConvTranspose2d
|
||||
from mayavoz.models.complexnn.rnn import ComplexLSTM
|
||||
from mayavoz.models.complexnn.utils import ComplexBatchNorm2D
|
||||
|
||||
|
||||
def test_complexconv2d():
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from enhancer.data.dataset import EnhancerDataset
|
||||
from enhancer.models import Demucs
|
||||
from enhancer.utils.config import Files
|
||||
from mayavoz.data.dataset import MayaDataset
|
||||
from mayavoz.models import Demucs
|
||||
from mayavoz.utils.config import Files
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -15,7 +15,9 @@ def vctk_dataset():
|
|||
test_clean="clean_testset_wav",
|
||||
test_noisy="noisy_testset_wav",
|
||||
)
|
||||
dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files)
|
||||
dataset = MayaDataset(
|
||||
name="vctk", root_dir=root_dir, files=files, sampling_rate=16000
|
||||
)
|
||||
return dataset
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from enhancer.data.dataset import EnhancerDataset
|
||||
from enhancer.models.dccrn import DCCRN
|
||||
from enhancer.utils.config import Files
|
||||
from mayavoz.data.dataset import MayaDataset
|
||||
from mayavoz.models.dccrn import DCCRN
|
||||
from mayavoz.utils.config import Files
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -15,7 +15,9 @@ def vctk_dataset():
|
|||
test_clean="clean_testset_wav",
|
||||
test_noisy="noisy_testset_wav",
|
||||
)
|
||||
dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files)
|
||||
dataset = MayaDataset(
|
||||
name="vctk", root_dir=root_dir, files=files, sampling_rate=16000
|
||||
)
|
||||
return dataset
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from enhancer.data.dataset import EnhancerDataset
|
||||
from enhancer.models import WaveUnet
|
||||
from enhancer.utils.config import Files
|
||||
from mayavoz.data.dataset import MayaDataset
|
||||
from mayavoz.models import WaveUnet
|
||||
from mayavoz.utils.config import Files
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -15,7 +15,9 @@ def vctk_dataset():
|
|||
test_clean="clean_testset_wav",
|
||||
test_noisy="noisy_testset_wav",
|
||||
)
|
||||
dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files)
|
||||
dataset = MayaDataset(
|
||||
name="vctk", root_dir=root_dir, files=files, sampling_rate=16000
|
||||
)
|
||||
return dataset
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from enhancer.inference import Inference
|
||||
from mayavoz.inference import Inference
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
|
@ -27,3 +27,12 @@ def test_aggregate():
|
|||
data=rand, window_size=100, total_frames=1000, step_size=100
|
||||
)
|
||||
assert agg_rand.shape[-1] == 1000
|
||||
|
||||
|
||||
def test_pretrained():
|
||||
from mayavoz.models import Mayamodel
|
||||
|
||||
model = Mayamodel.from_pretrained(
|
||||
"shahules786/mayavoz-waveunet-valentini-28spk"
|
||||
)
|
||||
_ = model.enhance("tests/data/vctk/clean_testset_wav/p257_166.wav")
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import torch
|
||||
|
||||
from enhancer.utils.transforms import ConviSTFT, ConvSTFT
|
||||
from mayavoz.utils.transforms import ConviSTFT, ConvSTFT
|
||||
|
||||
|
||||
def test_stft_istft():
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue