Merge pull request #20 from shahules786/dev-recipe
recipes and tutorials
This commit is contained in:
commit ebba5952e5
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2019 Pariente Manuel

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md
@@ -2,9 +2,12 @@
<img src="https://user-images.githubusercontent.com/25312635/195514652-e4526cd1-1177-48e9-a80d-c8bfdb95d35f.png" />
</p>

mayavoz is a PyTorch-based open-source toolkit for speech enhancement. It is designed to save time for audio researchers. It provides easy-to-use pretrained audio enhancement models and facilitates highly customisable model training.

| **[Quick Start]()** | **[Installation]()** | **[Tutorials]()** | **[Available Recipes]()**
| **[Quick Start](#quick-start-fire)** | **[Installation](#installation)** | **[Tutorials](https://github.com/shahules786/enhancer/notebooks/)** | **[Available Recipes](#recipes)** | **[Demo]()**
## Key features :key:

* Various pretrained models nicely integrated with huggingface :hugs: that users can select and use without any hassle.
@@ -20,6 +23,15 @@ model = Mayamodel.from_pretrained("mayavoz/waveunet")
model("noisy_audio.wav")
```

## Recipes

| Model | Dataset | STOI | PESQ | URL |
| :---: | :---: | :---: | :---: | :---: |
| WaveUnet | Vctk-28spk | 0.836 | 2.78 | shahules786/mayavoz-waveunet-valentini-28spk |
| Demucs | DNS-2020 (30hrs) | 0.961 | 2.56 | shahules786/mayavoz-demucs-valentini-28spk |
| DCCRN | DNS-2020 (30hrs) | | | mayavoz/dccrn-vctk28 |
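Each recipe's pretrained weights can be loaded by passing the model ID from the URL column above to `Mayamodel.from_pretrained`, mirroring the quick-start snippet (a sketch; the IDs are taken from the table as-is):

```python
from mayavoz import Mayamodel

# Model ID taken from the URL column of the recipes table above.
model = Mayamodel.from_pretrained("shahules786/mayavoz-waveunet-valentini-28spk")
enhanced = model("noisy_audio.wav")  # same call pattern as the quick start
```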
## Installation

Only Python 3.8+ is officially supported (though it might work with Python 3.7).
@@ -18,8 +18,11 @@ from enhancer.inference import Inference
from enhancer.loss import LOSS_MAP, LossWrapper
from enhancer.version import __version__

CACHE_DIR = ""
HF_TORCH_WEIGHTS = ""
CACHE_DIR = os.getenv(
    "ENHANCER_CACHE",
    os.path.expanduser("~/.cache/torch/enhancer"),
)
HF_TORCH_WEIGHTS = "pytorch_model.ckpt"
DEFAULT_DEVICE = "cpu"
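With this change the cache location can be redirected per run through the `ENHANCER_CACHE` environment variable; a minimal sketch of how the fallback resolves (variable name and default taken from the diff above):

```python
import os

# Unset: the expanduser default is used.
os.environ.pop("ENHANCER_CACHE", None)
print(os.getenv("ENHANCER_CACHE", os.path.expanduser("~/.cache/torch/enhancer")))
# e.g. /home/user/.cache/torch/enhancer

# Set before launching: downloads land in the custom directory.
os.environ["ENHANCER_CACHE"] = "/tmp/enhancer-cache"
print(os.getenv("ENHANCER_CACHE", os.path.expanduser("~/.cache/torch/enhancer")))
# /tmp/enhancer-cache
```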
@@ -0,0 +1,338 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ccd61d5c",
   "metadata": {},
   "source": [
    "## Custom model training using mayavoz [advanced]\n",
    "\n",
    "In this tutorial, we will cover advanced usages and customizations for training your own speech enhancement model. \n",
    "\n",
    " - [Data preparation using MayaDataset](#dataprep)\n",
    " - [Model customization](#modelcustom)\n",
    " - [Callbacks & LR schedulers](#callbacks)\n",
    " - [Model training & testing](#train)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "726c320f",
   "metadata": {},
   "source": [
    "- **install mayavoz**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c987c799",
   "metadata": {},
   "outputs": [],
   "source": [
    "! pip install -q mayavoz"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8ff9857b",
   "metadata": {},
   "source": [
    "<div id=\"dataprep\"></div>\n",
    "\n",
    "### Data preparation\n",
    "\n",
    "`Files` is a dataclass that wraps and holds train/test paths together. There is usually one folder each for clean and noisy data. These paths must be relative to a `root_dir` in which all of these directories reside. For example:\n",
    "\n",
    "```\n",
    "- VCTK/\n",
    "    |__ clean_train_wav/\n",
    "    |__ noisy_train_wav/\n",
    "    |__ clean_test_wav/\n",
    "    |__ noisy_test_wav/\n",
    "    \n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "64cbc0c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from mayavoz.utils import Files\n",
    "files = Files(train_clean=\"clean_train_wav\",\n",
    "              train_noisy=\"noisy_train_wav\",\n",
    "              test_clean=\"clean_test_wav\",\n",
    "              test_noisy=\"noisy_test_wav\")\n",
    "root_dir = \"VCTK\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2d324bd1",
   "metadata": {},
   "source": [
    "- `name`: name of the dataset. \n",
    "- `duration`: controls the duration of each audio instance fed into your model.\n",
    "- `stride`: if set, moves the sliding window that cuts audio into `duration`-second chunks by this many seconds.\n",
    "- `sampling_rate`: desired sampling rate for audio\n",
    "- `batch_size`: model batch size\n",
    "- `min_valid_minutes`: minimum amount of validation data, in minutes. Validation data is automatically selected from the training set (and kept exclusive to validation).\n",
    "- `matching_function`: there are two types of mapping functions.\n",
    "    - `one_to_one`: each clean file has exactly one corresponding noisy file. For example, the Valentini datasets.\n",
    "    - `one_to_many`: each clean file can have multiple corresponding noisy files. For example, the DNS dataset.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "6834941d",
   "metadata": {},
   "outputs": [],
   "source": [
    "name = \"vctk\"\n",
    "duration = 4.5\n",
    "stride = 2.0\n",
    "sampling_rate = 16000\n",
    "min_valid_minutes = 20.0\n",
    "batch_size = 32\n",
    "matching_function = \"one_to_one\"\n"
   ]
  },
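  {
   "cell_type": "markdown",
   "id": "f1e2a301",
   "metadata": {},
   "source": [
    "How `duration` and `stride` interact (a sketch added for illustration, not part of the original recipe): each training example is a `duration`-second window, and consecutive windows start `stride` seconds apart, so a smaller stride yields more (overlapping) examples per file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1e2a302",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative arithmetic only (assumed windowing; check MayaDataset for the exact rule):\n",
    "# a 10 s clip with duration=4.5 and stride=2.0 gives windows starting at 0.0, 2.0 and 4.0 s.\n",
    "file_length = 10.0\n",
    "num_windows = int((file_length - duration) // stride) + 1\n",
    "print(num_windows)  # 3\n"
   ]
  },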
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d08c6bf8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from mayavoz.dataset import MayaDataset\n",
    "dataset = MayaDataset(\n",
    "    name=name,\n",
    "    root_dir=root_dir,\n",
    "    files=files,\n",
    "    duration=duration,\n",
    "    stride=stride,\n",
    "    sampling_rate=sampling_rate,\n",
    "    batch_size=batch_size,\n",
    "    min_valid_minutes=min_valid_minutes,\n",
    "    matching_function=matching_function\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5b315bde",
   "metadata": {},
   "source": [
    "Now your custom dataloader is ready!"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "01548fe5",
   "metadata": {},
   "source": [
    "<div id=\"modelcustom\"></div>\n",
    "\n",
    "### Model Customization\n",
    "Now, this is very easy. \n",
    "\n",
    "- Import the preferred model from `mayavoz.models`. Currently 3 models are implemented:\n",
    "    - `WaveUnet`\n",
    "    - `Demucs`\n",
    "    - `DCCRN`\n",
    "- Model hyperparameters such as depth, kernel_size, and stride can all be controlled; check each model's parameters and pass them as required.\n",
    "- `sampling_rate`: sampling rate (should be equal to the dataset sampling rate)\n",
    "- `dataset`: mayavoz dataset object, as prepared earlier.\n",
    "- `loss`: model loss. Multiple loss functions are available.\n",
    "\n",
    "You can pass one (as a string) or more (as a list of strings) of these loss functions as per your requirements. For example, the model will automatically calculate the loss as the average of `mae` and `mse` if you pass `loss=[\"mae\",\"mse\"]`. Available loss functions are `mse`, `mae`, `si-snr`.\n",
    "\n",
    "mayavoz can accept **custom loss functions**. They should be of the form:\n",
    "```\n",
    "class your_custom_loss(nn.Module):\n",
    "    def __init__(self, **kwargs):\n",
    "        super().__init__()\n",
    "        self.higher_better = False  ## loss minimization direction\n",
    "        self.name = \"your_loss_name\"  ## loss name, used for logging\n",
    "        ...\n",
    "    def forward(self, prediction, target):\n",
    "        loss = ...\n",
    "        return loss\n",
    "    \n",
    "```\n",
    "\n",
    "- `metrics`: validation metrics. Available options: `mae`, `mse`, `si-sdr`, `pesq`, `stoi`. One or more can be used.\n"
   ]
  },
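  {
   "cell_type": "markdown",
   "id": "f1e2a303",
   "metadata": {},
   "source": [
    "As a concrete example of that contract (added for illustration; the class and its penalty are hypothetical, only the `higher_better`/`name`/`forward` interface comes from the description above), here is a log-cosh waveform loss:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1e2a304",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "class LogCoshLoss(nn.Module):\n",
    "    \"\"\"Hypothetical custom loss following the interface described above.\"\"\"\n",
    "    def __init__(self, **kwargs):\n",
    "        super().__init__()\n",
    "        self.higher_better = False  # loss is minimized\n",
    "        self.name = \"log_cosh\"      # name used for logging\n",
    "\n",
    "    def forward(self, prediction, target):\n",
    "        # smooth L1-like penalty on the waveform difference\n",
    "        return torch.mean(torch.log(torch.cosh(prediction - target)))\n"
   ]
  },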
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b36b457c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from mayavoz.models import Demucs\n",
    "model = Demucs(\n",
    "    sampling_rate=16000,\n",
    "    dataset=dataset,\n",
    "    loss=[\"mae\"],\n",
    "    metrics=[\"stoi\",\"pesq\"])\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1523d638",
   "metadata": {},
   "source": [
    "<div id=\"callbacks\"></div>\n",
    "\n",
    "### Learning rate schedulers and callbacks\n",
    "Here I am using `ReduceLROnPlateau`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8de6931c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from types import MethodType\n",
    "\n",
    "from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
    "\n",
    "direction = model.valid_monitor  # \"min\" or \"max\"\n",
    "\n",
    "def configure_optimizers(self):\n",
    "    # `instantiate`, `config` and `parameters` come from the Hydra-based recipe\n",
    "    # training script; replace them with a plain torch optimizer if unavailable.\n",
    "    optimizer = instantiate(\n",
    "        config.optimizer,\n",
    "        lr=parameters.get(\"lr\"),\n",
    "        params=self.parameters(),\n",
    "    )\n",
    "    scheduler = ReduceLROnPlateau(\n",
    "        optimizer=optimizer,\n",
    "        mode=direction,\n",
    "        factor=parameters.get(\"ReduceLr_factor\", 0.1),\n",
    "        verbose=True,\n",
    "        min_lr=parameters.get(\"min_lr\", 1e-6),\n",
    "        patience=parameters.get(\"ReduceLr_patience\", 3),\n",
    "    )\n",
    "    return {\n",
    "        \"optimizer\": optimizer,\n",
    "        \"lr_scheduler\": scheduler,\n",
    "        \"monitor\": f'valid_{parameters.get(\"ReduceLr_monitor\", \"loss\")}',\n",
    "    }\n",
    "\n",
    "\n",
    "model.configure_optimizers = MethodType(configure_optimizers, model)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2f7b5af5",
   "metadata": {},
   "source": [
    "You can use any number of callbacks and pass them directly to the PyTorch Lightning trainer. Here I am using only `ModelCheckpoint`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f6b62a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pytorch_lightning.callbacks import ModelCheckpoint\n",
    "\n",
    "callbacks = []\n",
    "direction = model.valid_monitor  ## min or max \n",
    "checkpoint = ModelCheckpoint(\n",
    "    dirpath=\"./model\",\n",
    "    filename=\"model_filename\",\n",
    "    monitor=\"valid_loss\",\n",
    "    verbose=False,\n",
    "    mode=direction,\n",
    "    every_n_epochs=1,\n",
    ")\n",
    "callbacks.append(checkpoint)"
   ]
  },
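  {
   "cell_type": "markdown",
   "id": "f1e2a305",
   "metadata": {},
   "source": [
    "Other callbacks plug in the same way. For instance, an early-stopping callback (a sketch added for illustration; the monitored metric name is assumed to match the checkpoint above):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1e2a306",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pytorch_lightning.callbacks import EarlyStopping\n",
    "\n",
    "early_stopping = EarlyStopping(\n",
    "    monitor=\"valid_loss\",  # assumed: same metric the checkpoint monitors\n",
    "    mode=direction,\n",
    "    patience=10,\n",
    ")\n",
    "callbacks.append(early_stopping)"
   ]
  },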
  {
   "cell_type": "markdown",
   "id": "f3534445",
   "metadata": {},
   "source": [
    "<div id=\"train\"></div>\n",
    "\n",
    "\n",
    "### Train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3dc0348b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pytorch_lightning as pl\n",
    "trainer = pl.Trainer(max_epochs=1, callbacks=callbacks, accelerator=\"gpu\")\n",
    "trainer.fit(model)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "56dcfec1",
   "metadata": {},
   "source": [
    "- Test your model against the test dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63851feb",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.test(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4d3f5350",
   "metadata": {},
   "source": [
    "**Hurray! You have trained and tested your own speech enhancement model.**\n"
   ]
  },
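  {
   "cell_type": "markdown",
   "id": "f1e2a307",
   "metadata": {},
   "source": [
    "The freshly trained model can be used for inference right away (added note; this assumes the same `enhance` API demonstrated in the getting-started tutorial):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1e2a308",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Assumed to mirror the getting-started tutorial's inference API.\n",
    "audio = model.enhance(\"noisy_audio.wav\", save_output=True)"
   ]
  },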
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10d630e8",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "enhancer",
   "language": "python",
   "name": "enhancer"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
@@ -0,0 +1,427 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "7bd11665",
   "metadata": {},
   "source": [
    "## Getting Started with Mayavoz\n",
    "\n",
    "#### Contents:\n",
    "- [How to do inference using a pretrained model](#inference)\n",
    "- [How to train your custom model](#basictrain)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d3c589bb",
   "metadata": {},
   "source": [
    "### Install Mayavoz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b68e053",
   "metadata": {},
   "outputs": [],
   "source": [
    "! pip install -q mayavoz "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "87ee497f",
   "metadata": {},
   "source": [
    "<div id=\"inference\"></div>\n",
    "\n",
    "### Pretrained Model\n",
    "\n",
    "To start using a pretrained model, select any of the available recipes from [here](). \n",
    "For this exercise I am selecting [mayavoz/waveunet]().\n",
    "\n",
    "- Mayavoz supports multiple input and output formats. Input for inference can be in any of the formats below:\n",
    "    - audio file path\n",
    "    - numpy audio data\n",
    "    - torch tensor audio data\n",
    "    \n",
    "It auto-detects the input format and does inference for you.\n",
    "    \n",
    "At the moment mayavoz only accepts a single audio input."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bd514ff4",
   "metadata": {},
   "source": [
    "**Load model**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67698871",
   "metadata": {},
   "outputs": [],
   "source": [
    "from mayavoz import Mayamodel\n",
    "model = Mayamodel.from_pretrained(\"mayavoz/waveunet\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c7fd4cbe",
   "metadata": {},
   "source": [
    "**Inference using file path**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7996c16",
   "metadata": {},
   "outputs": [],
   "source": [
    "file = \"myvoice.wav\"\n",
    "audio = model.enhance(file)\n",
    "audio.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8ee20a83",
   "metadata": {},
   "source": [
    "**Inference using torch tensor**\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1a1c718",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "audio_tensor = torch.rand(1, 1, 32000)  ## random audio data\n",
    "audio = model.enhance(audio_tensor)\n",
    "audio.shape"
   ]
  },
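  {
   "cell_type": "markdown",
   "id": "f1e2a309",
   "metadata": {},
   "source": [
    "**Inference using numpy audio data** (added example; the shape convention is assumed to match the tensor example above)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1e2a310",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "audio_numpy = np.random.rand(1, 1, 32000)  ## random audio data, assumed (batch, channel, samples)\n",
    "audio = model.enhance(audio_numpy)\n",
    "audio.shape"
   ]
  },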
  {
   "cell_type": "markdown",
   "id": "2ac27920",
   "metadata": {},
   "source": [
    "- If you want to save the output, just pass `save_output=True`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e0313f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "audio = model.enhance(\"myvoice.wav\", save_output=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25077720",
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import Audio\n",
    "\n",
    "SAMPLING_RATE = 16000  # sampling rate of this recipe's model\n",
    "Audio(\"myvoice_cleaned.wav\", rate=SAMPLING_RATE)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3170bb0b",
   "metadata": {},
   "source": [
    "<div id=\"basictrain\"></div>\n",
    "\n",
    "\n",
    "## Training your own custom model\n",
    "\n",
    "There are two ways of doing this:\n",
    "\n",
    "* [Using the mayavoz framework](#code)\n",
    "* [Using the mayavoz command line tool](#cli)\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a44fc314",
   "metadata": {},
   "source": [
    "<div id=\"code\"></div>\n",
    "\n",
    "**Using the Mayavoz framework** [Basic]\n",
    "- Prepare the dataloader\n",
    "- Import the preferred model\n",
    "- Train"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dbc14b36",
   "metadata": {},
   "source": [
    "`Files` is a dataclass that helps you organise your train/test file paths."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c8c2b12",
   "metadata": {},
   "outputs": [],
   "source": [
    "from mayavoz.utils import Files\n",
    "\n",
    "name = \"dataset_name\"\n",
    "root_dir = \"root_directory_of_your_dataset\"\n",
    "files = Files(train_clean=\"train_cleanfiles_foldername\",\n",
    "              train_noisy=\"noisy_train_foldername\",\n",
    "              test_clean=\"clean_test_foldername\",\n",
    "              test_noisy=\"noisy_test_foldername\")\n",
    "duration = 4.0 \n",
    "stride = None\n",
    "sampling_rate = 16000"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "07ef8721",
   "metadata": {},
   "source": [
    "Now, there are two types of `matching_function`:\n",
    "- `one_to_one`: each clean file has exactly one corresponding noisy file. For example, the VCTK datasets.\n",
    "- `one_to_many`: each clean file can have multiple corresponding noisy files. For example, the DNS dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b0fdc62",
   "metadata": {},
   "outputs": [],
   "source": [
    "matching_function = \"one_to_one\"\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff0cfe60",
   "metadata": {},
   "outputs": [],
   "source": [
    "from mayavoz.dataset import MayaDataset\n",
    "dataset = MayaDataset(\n",
    "    name=name,\n",
    "    root_dir=root_dir,\n",
    "    files=files,\n",
    "    duration=duration,\n",
    "    stride=stride,\n",
    "    sampling_rate=sampling_rate\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acfdc655",
   "metadata": {},
   "outputs": [],
   "source": [
    "from mayavoz.models import Demucs\n",
    "model = Demucs(dataset=dataset, loss=\"mae\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "4fabe46d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pytorch_lightning as pl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20d98ed0",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = pl.Trainer(max_epochs=1)\n",
    "trainer.fit(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "28bc697b",
   "metadata": {},
   "source": [
    "**mayavoz models and datasets are highly customizable**; see [here]() for advanced usage."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "df01aa1e",
   "metadata": {},
   "source": [
    "<div id=\"cli\"></div>\n",
    "\n",
    "\n",
    "## Mayavoz CLI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2bbf2747",
   "metadata": {},
   "outputs": [],
   "source": [
    "! pip install mayavoz[cli]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4447dd07",
   "metadata": {},
   "source": [
    "### TL;DR\n",
    "Calling the following command would train the mayavoz Demucs model on the DNS-2020 dataset.\n",
    "\n",
    "```bash\n",
    "mayavoz-train \\\n",
    "    model=Demucs \\\n",
    "    Demucs.sampling_rate=16000 \\\n",
    "    dataset=DNS-2020 \\\n",
    "    DNS-2020.name=\"dns-2020\" \\\n",
    "    DNS-2020.root_dir=\"your_root_dir\" \\\n",
    "    DNS-2020.train_clean=\"\" \\\n",
    "    DNS-2020.train_noisy=\"\" \\\n",
    "    DNS-2020.test_clean=\"\" \\\n",
    "    DNS-2020.test_noisy=\"\" \\\n",
    "    DNS-2020.sampling_rate=16000 \\\n",
    "    DNS-2020.duration=2.0 \\\n",
    "    trainer=default \\\n",
    "    default.max_epochs=1\n",
    "```\n",
    "\n",
    "This is more or less equivalent to the code below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9278742a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pytorch_lightning import Trainer\n",
    "\n",
    "from mayavoz.utils import Files\n",
    "from mayavoz.data import MayaDataset\n",
    "from mayavoz.models import Demucs\n",
    "\n",
    "files = Files(\n",
    "    train_clean=\"\",\n",
    "    train_noisy=\"\",\n",
    "    test_clean=\"\",\n",
    "    test_noisy=\"\"\n",
    ")\n",
    "dataset = MayaDataset(\n",
    "    name=\"dns-2020\",\n",
    "    root_dir=\"your_root_dir\",\n",
    "    files=files,\n",
    "    sampling_rate=16000,\n",
    "    duration=2.0)\n",
    "model = Demucs(dataset=dataset, sampling_rate=16000)\n",
    "trainer = Trainer(max_epochs=1)\n",
    "trainer.fit(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eb26692c",
   "metadata": {},
   "source": [
    "**Hydra-based configuration**\n",
    "\n",
    "`mayavoz-train` relies on Hydra to configure the training process. Adding the `--cfg job` option to the previous command will show the actual configuration used for training:\n",
    "\n",
    "```bash\n",
    "mayavoz-train --cfg job \\\n",
    "    model=Demucs \\\n",
    "    Demucs.sampling_rate=16000 \\\n",
    "    dataset=DNS-2020\n",
    "```\n",
    "\n",
    "```yaml\n",
    "_target_: enhancer.models.demucs.Demucs\n",
    "num_channels: 1\n",
    "resample: 4\n",
    "sampling_rate : 16000\n",
    "\n",
    "encoder_decoder:\n",
    "  depth: 4\n",
    "  initial_output_channels: 64\n",
    "  \n",
    "[...]\n",
    "```\n",
    "\n",
    "To change the sampling_rate, you can override it on the command line:\n",
    "\n",
    "```bash\n",
    "mayavoz-train \\\n",
    "    model=Demucs model.sampling_rate=16000 \\\n",
    "    dataset=DNS-2020\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "93555860",
   "metadata": {},
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "enhancer",
   "language": "python",
   "name": "enhancer"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
@@ -0,0 +1,15 @@
### DNS Challenge's dataset

The Deep Noise Suppression (DNS) Challenge is a single-channel speech enhancement
challenge organized by Microsoft, with a focus on real-time applications.
More info can be found on the [official page](https://dns-challenge.azurewebsites.net/).

**References**

The challenge paper is available [here](https://arxiv.org/abs/2001.08662).
```BibTex
@misc{DNSChallenge2020,
    title={The INTERSPEECH 2020 Deep Noise Suppression Challenge: Datasets, Subjective Speech Quality and Testing Framework},
    author={Chandan K. A. Reddy and Ebrahim Beyrami and Harishchandra Dubey and Vishak Gopal and Roger Cheng and Ross Cutler and Sergiy Matusevych and Robert Aichner and Ashkan Aazami and Sebastian Braun and Puneet Rana and Sriram Srinivasan and Johannes Gehrke},
    year={2020},
    doi={10.48550/arXiv.2001.08662},
}
```
@@ -0,0 +1,120 @@
import os
from types import MethodType

import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.callbacks import (
    EarlyStopping,
    LearningRateMonitor,
    ModelCheckpoint,
)
from pytorch_lightning.loggers import MLFlowLogger
from torch.optim.lr_scheduler import ReduceLROnPlateau

# from torch_audiomentations import Compose, Shift

os.environ["HYDRA_FULL_ERROR"] = "1"
JOB_ID = os.environ.get("SLURM_JOBID", "0")


@hydra.main(config_path="train_config", config_name="config")
def main(config: DictConfig):

    OmegaConf.save(config, "config_log.yaml")

    callbacks = []
    logger = MLFlowLogger(
        experiment_name=config.mlflow.experiment_name,
        run_name=config.mlflow.run_name,
        tags={"JOB_ID": JOB_ID},
    )

    parameters = config.hyperparameters
    # apply_augmentations = Compose(
    #     [
    #         Shift(min_shift=0.5, max_shift=1.0, shift_unit="seconds", p=0.5),
    #     ]
    # )

    dataset = instantiate(config.dataset, augmentations=None)
    model = instantiate(
        config.model,
        dataset=dataset,
        lr=parameters.get("lr"),
        loss=parameters.get("loss"),
        metric=parameters.get("metric"),
    )

    direction = model.valid_monitor
    checkpoint = ModelCheckpoint(
        dirpath="./model",
        filename=f"model_{JOB_ID}",
        monitor="valid_loss",
        verbose=False,
        mode=direction,
        every_n_epochs=1,
    )
    callbacks.append(checkpoint)
    callbacks.append(LearningRateMonitor(logging_interval="epoch"))

    if parameters.get("Early_stop", False):
        early_stopping = EarlyStopping(
            monitor="valid_loss",
            mode=direction,
            min_delta=0.0,
            patience=parameters.get("EarlyStopping_patience", 10),
            strict=True,
            verbose=False,
        )
        callbacks.append(early_stopping)

    def configure_optimizers(self):
        optimizer = instantiate(
            config.optimizer,
            lr=parameters.get("lr"),
            params=self.parameters(),
        )
        scheduler = ReduceLROnPlateau(
            optimizer=optimizer,
            mode=direction,
            factor=parameters.get("ReduceLr_factor", 0.1),
            verbose=True,
            min_lr=parameters.get("min_lr", 1e-6),
            patience=parameters.get("ReduceLr_patience", 3),
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": scheduler,
            "monitor": f'valid_{parameters.get("ReduceLr_monitor", "loss")}',
        }

    model.configure_optimizers = MethodType(configure_optimizers, model)

    trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks)
    trainer.fit(model)
    trainer.test(model)

    logger.experiment.log_artifact(
        logger.run_id, f"{trainer.default_root_dir}/config_log.yaml"
    )

    saved_location = os.path.join(
        trainer.default_root_dir, "model", f"model_{JOB_ID}.ckpt"
    )
    if os.path.isfile(saved_location):
        logger.experiment.log_artifact(logger.run_id, saved_location)

    logger.experiment.log_param(
        logger.run_id,
        "num_train_steps_per_epoch",
        dataset.train__len__() / dataset.batch_size,
    )
    logger.experiment.log_param(
        logger.run_id,
        "num_valid_steps_per_epoch",
        dataset.val__len__() / dataset.batch_size,
    )


if __name__ == "__main__":
    main()
@@ -0,0 +1,7 @@
defaults:
  - model : Demucs
  - dataset : Vctk
  - optimizer : Adam
  - hyperparameters : default
  - trainer : default
  - mlflow : experiment
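This `defaults` list composes one option from each config group (model, dataset, optimizer, hyperparameters, trainer, mlflow), and any entry can be swapped at launch time via Hydra overrides, as the tutorials show. A hedged sketch of inspecting the composed config with Hydra's Python API (the `train_config` path is taken from the training script above):

```python
from hydra import compose, initialize

# Compose the same config the training script sees, with an override applied.
with initialize(config_path="train_config"):
    cfg = compose(config_name="config", overrides=["model=WaveUnet"])
print(cfg.model._target_)  # which model class will be instantiated
```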
@@ -0,0 +1,13 @@
_target_: enhancer.data.dataset.EnhancerDataset
name : vctk
root_dir : /scratch/c.sistc3/DS_10283_2791
duration : 4.5
stride : 0.5
sampling_rate: 16000
batch_size: 32
min_valid_minutes : 25
files:
  train_clean : clean_trainset_28spk_wav
  test_clean : clean_testset_wav
  train_noisy : noisy_trainset_28spk_wav
  test_noisy : noisy_testset_wav
@@ -0,0 +1,8 @@
loss : mae
metric : [stoi,pesq,si-sdr]
lr : 0.0003
Early_stop : False
ReduceLr_patience : 10
ReduceLr_factor : 0.1
min_lr : 0.000001
EarlyStopping_factor : 10
@@ -0,0 +1,2 @@
experiment_name : shahules/enhancer
run_name : baseline
@@ -0,0 +1,16 @@
_target_: enhancer.models.demucs.Demucs
num_channels: 1
resample: 4
sampling_rate : 16000

encoder_decoder:
  depth: 4
  initial_output_channels: 64
  kernel_size: 8
  stride: 4
  growth_factor: 2
  glu: True

lstm:
  bidirectional: True
  num_layers: 2
@@ -0,0 +1,6 @@
_target_: torch.optim.Adam
lr: 1e-3
betas: [0.9, 0.999]
eps: 1e-08
weight_decay: 0
amsgrad: False
@@ -0,0 +1,46 @@
_target_: pytorch_lightning.Trainer
accelerator: gpu
accumulate_grad_batches: 1
amp_backend: native
auto_lr_find: True
auto_scale_batch_size: False
auto_select_gpus: True
benchmark: False
check_val_every_n_epoch: 1
detect_anomaly: False
deterministic: False
devices: 1
enable_checkpointing: True
enable_model_summary: True
enable_progress_bar: True
fast_dev_run: False
gpus: null
gradient_clip_val: 0
gradient_clip_algorithm: norm
ipus: null
limit_predict_batches: 1.0
limit_test_batches: 1.0
limit_train_batches: 1.0
limit_val_batches: 1.0
log_every_n_steps: 50
max_epochs: 200
max_steps: -1
max_time: null
min_epochs: 1
min_steps: null
move_metrics_to_cpu: False
multiple_trainloader_mode: max_size_cycle
num_nodes: 1
num_processes: 1
num_sanity_val_steps: 2
overfit_batches: 0.0
precision: 32
profiler: null
reload_dataloaders_every_n_epochs: 0
replace_sampler_ddp: True
strategy: null
sync_batchnorm: False
tpu_cores: null
track_grad_norm: -1
val_check_interval: 1.0
weights_save_path: null
@@ -0,0 +1,120 @@
import os
from types import MethodType

import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.callbacks import (
    EarlyStopping,
    LearningRateMonitor,
    ModelCheckpoint,
)
from pytorch_lightning.loggers import MLFlowLogger
from torch.optim.lr_scheduler import ReduceLROnPlateau

# from torch_audiomentations import Compose, Shift

os.environ["HYDRA_FULL_ERROR"] = "1"
JOB_ID = os.environ.get("SLURM_JOBID", "0")


@hydra.main(config_path="train_config", config_name="config")
def main(config: DictConfig):

    OmegaConf.save(config, "config_log.yaml")

    callbacks = []
    logger = MLFlowLogger(
        experiment_name=config.mlflow.experiment_name,
        run_name=config.mlflow.run_name,
        tags={"JOB_ID": JOB_ID},
    )

    parameters = config.hyperparameters
    # apply_augmentations = Compose(
    #     [
    #         Shift(min_shift=0.5, max_shift=1.0, shift_unit="seconds", p=0.5),
    #     ]
    # )

    dataset = instantiate(config.dataset, augmentations=None)
    model = instantiate(
        config.model,
        dataset=dataset,
        lr=parameters.get("lr"),
        loss=parameters.get("loss"),
        metric=parameters.get("metric"),
    )

    direction = model.valid_monitor
    checkpoint = ModelCheckpoint(
        dirpath="./model",
        filename=f"model_{JOB_ID}",
        monitor="valid_loss",
        verbose=False,
        mode=direction,
        every_n_epochs=1,
    )
    callbacks.append(checkpoint)
    callbacks.append(LearningRateMonitor(logging_interval="epoch"))

    if parameters.get("Early_stop", False):
        early_stopping = EarlyStopping(
            monitor="valid_loss",
            mode=direction,
            min_delta=0.0,
            patience=parameters.get("EarlyStopping_patience", 10),
            strict=True,
            verbose=False,
        )
        callbacks.append(early_stopping)

    def configure_optimizers(self):
        optimizer = instantiate(
            config.optimizer,
            lr=parameters.get("lr"),
            params=self.parameters(),
        )
        scheduler = ReduceLROnPlateau(
            optimizer=optimizer,
            mode=direction,
            factor=parameters.get("ReduceLr_factor", 0.1),
            verbose=True,
            min_lr=parameters.get("min_lr", 1e-6),
            patience=parameters.get("ReduceLr_patience", 3),
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": scheduler,
            "monitor": f'valid_{parameters.get("ReduceLr_monitor", "loss")}',
        }

    model.configure_optimizers = MethodType(configure_optimizers, model)

    trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks)
    trainer.fit(model)
    trainer.test(model)

    logger.experiment.log_artifact(
        logger.run_id, f"{trainer.default_root_dir}/config_log.yaml"
    )

    saved_location = os.path.join(
        trainer.default_root_dir, "model", f"model_{JOB_ID}.ckpt"
    )
    if os.path.isfile(saved_location):
        logger.experiment.log_artifact(logger.run_id, saved_location)

    logger.experiment.log_param(
        logger.run_id,
        "num_train_steps_per_epoch",
        dataset.train__len__() / dataset.batch_size,
    )
    logger.experiment.log_param(
        logger.run_id,
        "num_valid_steps_per_epoch",
        dataset.val__len__() / dataset.batch_size,
    )


if __name__ == "__main__":
    main()
@@ -0,0 +1,7 @@
defaults:
  - model : WaveUnet
  - dataset : Vctk
  - optimizer : Adam
  - hyperparameters : default
  - trainer : default
  - mlflow : experiment
@@ -0,0 +1,13 @@
_target_: enhancer.data.dataset.EnhancerDataset
name : vctk
root_dir : /scratch/c.sistc3/DS_10283_2791
duration : 2
stride : 1
sampling_rate: 16000
batch_size: 128
valid_minutes : 25
files:
  train_clean : clean_trainset_28spk_wav
  test_clean : clean_testset_wav
  train_noisy : noisy_trainset_28spk_wav
  test_noisy : noisy_testset_wav
@@ -0,0 +1,8 @@
loss : mae
metric : [stoi,pesq,si-sdr]
lr : 0.003
ReduceLr_patience : 10
ReduceLr_factor : 0.1
min_lr : 0.000001
EarlyStopping_factor : 10
Early_stop : False
@@ -0,0 +1,2 @@
experiment_name : shahules/enhancer
run_name : baseline
@@ -0,0 +1,5 @@
_target_: enhancer.models.waveunet.WaveUnet
num_channels : 1
depth : 9
initial_output_channels: 24
sampling_rate : 16000
@@ -0,0 +1,6 @@
_target_: torch.optim.Adam
lr: 1e-3
betas: [0.9, 0.999]
eps: 1e-08
weight_decay: 0
amsgrad: False
@@ -0,0 +1,46 @@
_target_: pytorch_lightning.Trainer
accelerator: gpu
accumulate_grad_batches: 1
amp_backend: native
auto_lr_find: True
auto_scale_batch_size: False
auto_select_gpus: True
benchmark: False
check_val_every_n_epoch: 1
detect_anomaly: False
deterministic: False
devices: 2
enable_checkpointing: True
enable_model_summary: True
enable_progress_bar: True
fast_dev_run: False
gpus: null
gradient_clip_val: 0
gradient_clip_algorithm: norm
ipus: null
limit_predict_batches: 1.0
limit_test_batches: 1.0
limit_train_batches: 1.0
limit_val_batches: 1.0
log_every_n_steps: 50
max_epochs: 200
max_steps: -1
max_time: null
min_epochs: 1
min_steps: null
move_metrics_to_cpu: False
multiple_trainloader_mode: max_size_cycle
num_nodes: 1
num_processes: 1
num_sanity_val_steps: 2
overfit_batches: 0.0
precision: 32
profiler: null
reload_dataloaders_every_n_epochs: 0
replace_sampler_ddp: True
strategy: ddp
sync_batchnorm: False
tpu_cores: null
track_grad_norm: -1
val_check_interval: 1.0
weights_save_path: null
@@ -0,0 +1,2 @@
_target_: pytorch_lightning.Trainer
fast_dev_run: True
@@ -0,0 +1,120 @@
import os
from types import MethodType

import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.callbacks import (
    EarlyStopping,
    LearningRateMonitor,
    ModelCheckpoint,
)
from pytorch_lightning.loggers import MLFlowLogger
from torch.optim.lr_scheduler import ReduceLROnPlateau

# from torch_audiomentations import Compose, Shift

os.environ["HYDRA_FULL_ERROR"] = "1"
JOB_ID = os.environ.get("SLURM_JOBID", "0")


@hydra.main(config_path="train_config", config_name="config")
def main(config: DictConfig):

    OmegaConf.save(config, "config_log.yaml")

    callbacks = []
    logger = MLFlowLogger(
        experiment_name=config.mlflow.experiment_name,
        run_name=config.mlflow.run_name,
        tags={"JOB_ID": JOB_ID},
    )

    parameters = config.hyperparameters
    # apply_augmentations = Compose(
    #     [
    #         Shift(min_shift=0.5, max_shift=1.0, shift_unit="seconds", p=0.5),
    #     ]
    # )

    dataset = instantiate(config.dataset, augmentations=None)
    model = instantiate(
        config.model,
        dataset=dataset,
        lr=parameters.get("lr"),
        loss=parameters.get("loss"),
        metric=parameters.get("metric"),
    )

    direction = model.valid_monitor
    checkpoint = ModelCheckpoint(
        dirpath="./model",
        filename=f"model_{JOB_ID}",
        monitor="valid_loss",
        verbose=False,
        mode=direction,
        every_n_epochs=1,
    )
    callbacks.append(checkpoint)
    callbacks.append(LearningRateMonitor(logging_interval="epoch"))

    if parameters.get("Early_stop", False):
        early_stopping = EarlyStopping(
            monitor="valid_loss",
            mode=direction,
            min_delta=0.0,
            patience=parameters.get("EarlyStopping_patience", 10),
            strict=True,
            verbose=False,
        )
        callbacks.append(early_stopping)

    def configure_optimizers(self):
        optimizer = instantiate(
            config.optimizer,
            lr=parameters.get("lr"),
            params=self.parameters(),
        )
        scheduler = ReduceLROnPlateau(
            optimizer=optimizer,
            mode=direction,
            factor=parameters.get("ReduceLr_factor", 0.1),
            verbose=True,
            min_lr=parameters.get("min_lr", 1e-6),
            patience=parameters.get("ReduceLr_patience", 3),
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": scheduler,
            "monitor": f'valid_{parameters.get("ReduceLr_monitor", "loss")}',
        }

    model.configure_optimizers = MethodType(configure_optimizers, model)

    trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks)
    trainer.fit(model)
    trainer.test(model)

    logger.experiment.log_artifact(
        logger.run_id, f"{trainer.default_root_dir}/config_log.yaml"
    )

    saved_location = os.path.join(
        trainer.default_root_dir, "model", f"model_{JOB_ID}.ckpt"
    )
    if os.path.isfile(saved_location):
        logger.experiment.log_artifact(logger.run_id, saved_location)

    logger.experiment.log_param(
        logger.run_id,
        "num_train_steps_per_epoch",
        dataset.train__len__() / dataset.batch_size,
    )
    logger.experiment.log_param(
        logger.run_id,
        "num_valid_steps_per_epoch",
        dataset.val__len__() / dataset.batch_size,
    )


if __name__ == "__main__":
    main()
@@ -0,0 +1,7 @@
defaults:
  - model : Demucs
  - dataset : Vctk
  - optimizer : Adam
  - hyperparameters : default
  - trainer : default
  - mlflow : experiment
@@ -0,0 +1,12 @@
_target_: enhancer.data.dataset.EnhancerDataset
root_dir : /Users/shahules/Myprojects/MS-SNSD
name : dns-2020
duration : 2.0
sampling_rate: 16000
batch_size: 32
valid_size: 0.05
files:
  train_clean : CleanSpeech_training
  test_clean : CleanSpeech_training
  train_noisy : NoisySpeech_training
  test_noisy : NoisySpeech_training
@@ -0,0 +1,13 @@
_target_: enhancer.data.dataset.EnhancerDataset
name : vctk
root_dir : /scratch/c.sistc3/DS_10283_2791
duration : 4.5
stride : 2
sampling_rate: 16000
batch_size: 32
valid_minutes : 15
files:
  train_clean : clean_trainset_28spk_wav
  test_clean : clean_testset_wav
  train_noisy : noisy_trainset_28spk_wav
  test_noisy : noisy_testset_wav
@@ -0,0 +1,7 @@
loss : mae
metric : [stoi,pesq,si-sdr]
lr : 0.0003
ReduceLr_patience : 5
ReduceLr_factor : 0.2
min_lr : 0.000001
EarlyStopping_factor : 10
@@ -0,0 +1,2 @@
experiment_name : shahules/enhancer
run_name : Demucs + Vctk with stride + augmentations
@@ -0,0 +1,25 @@
_target_: enhancer.models.dccrn.DCCRN
num_channels: 1
sampling_rate : 16000
complex_lstm : True
complex_norm : True
complex_relu : True
masking_mode : True

encoder_decoder:
  initial_output_channels : 32
  depth : 6
  kernel_size : 5
  growth_factor : 2
  stride : 2
  padding : 2
  output_padding : 1

lstm:
  num_layers : 2
  hidden_size : 256

stft:
  window_len : 400
  hop_size : 100
  nfft : 512
@@ -0,0 +1,16 @@
_target_: enhancer.models.demucs.Demucs
num_channels: 1
resample: 4
sampling_rate : 16000

encoder_decoder:
  depth: 4
  initial_output_channels: 64
  kernel_size: 8
  stride: 4
  growth_factor: 2
  glu: True

lstm:
  bidirectional: False
  num_layers: 2
@@ -0,0 +1,5 @@
_target_: enhancer.models.waveunet.WaveUnet
num_channels : 1
depth : 9
initial_output_channels: 24
sampling_rate : 16000
@@ -0,0 +1,6 @@
_target_: torch.optim.Adam
lr: 1e-3
betas: [0.9, 0.999]
eps: 1e-08
weight_decay: 0
amsgrad: False
@@ -0,0 +1,46 @@
_target_: pytorch_lightning.Trainer
accelerator: gpu
accumulate_grad_batches: 1
amp_backend: native
auto_lr_find: True
auto_scale_batch_size: False
auto_select_gpus: True
benchmark: False
check_val_every_n_epoch: 1
detect_anomaly: False
deterministic: False
devices: 2
enable_checkpointing: True
enable_model_summary: True
enable_progress_bar: True
fast_dev_run: False
gpus: null
gradient_clip_val: 0
gradient_clip_algorithm: norm
ipus: null
limit_predict_batches: 1.0
limit_test_batches: 1.0
limit_train_batches: 1.0
limit_val_batches: 1.0
log_every_n_steps: 50
max_epochs: 200
max_steps: -1
max_time: null
min_epochs: 1
min_steps: null
move_metrics_to_cpu: False
multiple_trainloader_mode: max_size_cycle
num_nodes: 1
num_processes: 1
num_sanity_val_steps: 2
overfit_batches: 0.0
precision: 32
profiler: null
reload_dataloaders_every_n_epochs: 0
replace_sampler_ddp: True
strategy: ddp
sync_batchnorm: False
tpu_cores: null
track_grad_norm: -1
val_check_interval: 1.0
weights_save_path: null
@@ -0,0 +1,2 @@
_target_: pytorch_lightning.Trainer
fast_dev_run: True
@@ -0,0 +1,12 @@
## Valentini dataset

Clean and noisy parallel speech database. The database was designed to train and test speech enhancement methods that operate at 48kHz. A more detailed description can be found in the papers associated with the database. See the [official page](https://datashare.ed.ac.uk/handle/10283/2791).

**References**
```BibTex
@misc{Valentini2017,
    title={Noisy speech database for training speech enhancement algorithms and TTS models},
    author={Valentini-Botinhao, Cassia},
    year={2017},
    doi={10.7488/ds/2117},
}
```