{
"cells": [
{
"cell_type": "markdown",
"id": "ccd61d5c",
"metadata": {},
"source": [
"## Custom model training using mayavoz [advanced]\n",
"\n",
"In this tutorial, we will cover advanced usage and customization for training your own speech enhancement model.\n",
"\n",
" - [Data preparation using MayaDataset](#dataprep)\n",
" - [Model customization](#modelcustom)\n",
" - [Callbacks & LR schedulers](#callbacks)\n",
" - [Model training & testing](#train)\n"
]
},
{
"cell_type": "markdown",
"id": "726c320f",
"metadata": {},
"source": [
"- **install mayavoz**"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c987c799",
"metadata": {},
"outputs": [],
"source": [
"! pip install -q mayavoz"
]
},
{
"cell_type": "markdown",
"id": "8ff9857b",
"metadata": {},
"source": [
"<div id=\"dataprep\"></div>\n",
"\n",
"### Data preparation\n",
"\n",
"`Files` is a dataclass that holds the train/test paths together. There is usually one folder each for clean and noisy data. These paths must be relative to a `root_dir` where all these directories reside. For example:\n",
"\n",
"```\n",
"- VCTK/\n",
"    |__ clean_train_wav/\n",
"    |__ noisy_train_wav/\n",
"    |__ clean_test_wav/\n",
"    |__ noisy_test_wav/\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "64cbc0c8",
"metadata": {},
"outputs": [],
"source": [
"from mayavoz.utils import Files\n",
"\n",
"# Folder names are relative to root_dir.\n",
"files = Files(train_clean=\"clean_train_wav\",\n",
"              train_noisy=\"noisy_train_wav\",\n",
"              test_clean=\"clean_test_wav\",\n",
"              test_noisy=\"noisy_test_wav\")\n",
"root_dir = \"VCTK\""
]
},
{
"cell_type": "markdown",
"id": "2d324bd1",
"metadata": {},
"source": [
"- `name`: name of the dataset.\n",
"- `duration`: controls the duration of each audio instance fed into your model.\n",
"- `stride`: if set, the step by which the sliding window moves over each audio file.\n",
"- `sampling_rate`: desired sampling rate for the audio.\n",
"- `batch_size`: model batch size.\n",
"- `min_valid_minutes`: minimum amount of validation data, in minutes. The validation set is selected automatically from the training set (exclusive users).\n",
"- `matching_function`: there are two types of matching functions.\n",
"    - `one_to_one`: each clean file has exactly one corresponding noisy file, as in the Valentini dataset.\n",
"    - `one_to_many`: each clean file can have several corresponding noisy files, as in the DNS dataset.\n"
]
},
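{
"cell_type": "markdown",
"id": "sliding-window-sketch",
"metadata": {},
"source": [
"How `duration` and `stride` interact can be sketched in plain Python. This is only an illustration of a sliding window, assuming both values are in seconds and that windows must fit entirely inside the clip; `window_starts` is a hypothetical helper and not part of mayavoz, whose actual segmentation may differ.\n",
"\n",
"```python\n",
"def window_starts(total_seconds, duration, stride):\n",
"    # Start times (in seconds) of fixed-length windows that fit in the clip.\n",
"    starts = []\n",
"    t = 0.0\n",
"    while t + duration <= total_seconds:\n",
"        starts.append(t)\n",
"        t += stride\n",
"    return starts\n",
"\n",
"\n",
"# A 10 s clip with duration=4.5 and stride=2.0 gives three overlapping windows.\n",
"print(window_starts(10.0, 4.5, 2.0))  # [0.0, 2.0, 4.0]\n",
"```\n",
"\n",
"A smaller `stride` produces more, heavily overlapping segments per file; a `stride` equal to `duration` produces non-overlapping ones."
]
},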
{
"cell_type": "code",
"execution_count": 3,
"id": "6834941d",
"metadata": {},
"outputs": [],
"source": [
"name = \"vctk\"\n",
"duration = 4.5\n",
"stride = 2.0\n",
"sampling_rate = 16000\n",
"min_valid_minutes = 20.0\n",
"batch_size = 32\n",
"matching_function = \"one_to_one\"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d08c6bf8",
"metadata": {},
"outputs": [],
"source": [
"from mayavoz.dataset import MayaDataset\n",
"dataset = MayaDataset(\n",
"    name=name,\n",
"    root_dir=root_dir,\n",
"    files=files,\n",
"    duration=duration,\n",
"    stride=stride,\n",
"    sampling_rate=sampling_rate,\n",
"    batch_size=batch_size,\n",
"    min_valid_minutes=min_valid_minutes,\n",
"    matching_function=matching_function\n",
")"
]
},
{
"cell_type": "markdown",
"id": "5b315bde",
"metadata": {},
"source": [
"Now your custom dataloader is ready!"
]
},
{
"cell_type": "markdown",
"id": "01548fe5",
"metadata": {},
"source": [
"<div id=\"modelcustom\"></div>\n",
"\n",
"### Model Customization\n",
"Now, this is very easy.\n",
"\n",
"- Import the preferred model from `mayavoz.models`. Currently 3 models are implemented.\n",
"    - `WaveUnet`\n",
"    - `Demucs`\n",
"    - `DCCRN`\n",
"- Each model's hyperparameters, such as `depth`, `kernel_size`, and `stride`, can be set by you. Check the model's parameters and pass them as required.\n",
"- `sampling_rate`: sampling rate (should be equal to the dataset sampling rate).\n",
"- `dataset`: mayavoz dataset object as prepared earlier.\n",
"- `loss`: model loss. Multiple loss functions are available.\n",
"\n",
"You can pass one (as a string) or more (as a list of strings) of these loss functions as per your requirements. For example, the model will automatically calculate the loss as the average of `mae` and `mse` if you pass loss as `[\"mae\",\"mse\"]`. Available loss functions are `mse`, `mae`, and `si-snr`.\n",
"\n",
"mayavoz can accept **custom loss functions**. They should be of the form:\n",
"```\n",
"class your_custom_loss(nn.Module):\n",
"    def __init__(self, **kwargs):\n",
"        super().__init__()\n",
"        self.higher_better = False  ## loss minimization direction\n",
"        self.name = \"your_loss_name\"  ## loss name used for logging\n",
"        ...\n",
"    def forward(self, prediction, target):\n",
"        loss = ....\n",
"        return loss\n",
"```\n",
"\n",
"- `metrics`: validation metrics. Available options are `mae`, `mse`, `si-sdr`, `pesq`, and `stoi`. One or more can be used.\n"
]
},
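{
"cell_type": "markdown",
"id": "loss-averaging-sketch",
"metadata": {},
"source": [
"The averaging of several losses described above can be illustrated with a small, library-free sketch. It uses plain Python lists in place of tensors, and `LOSSES` and `combined_loss` are hypothetical names chosen for illustration; this is not mayavoz's implementation.\n",
"\n",
"```python\n",
"def mae(prediction, target):\n",
"    # Mean absolute error over paired samples.\n",
"    return sum(abs(p - t) for p, t in zip(prediction, target)) / len(target)\n",
"\n",
"\n",
"def mse(prediction, target):\n",
"    # Mean squared error over paired samples.\n",
"    return sum((p - t) ** 2 for p, t in zip(prediction, target)) / len(target)\n",
"\n",
"\n",
"LOSSES = {\"mae\": mae, \"mse\": mse}\n",
"\n",
"\n",
"def combined_loss(names, prediction, target):\n",
"    # Accept a single name or a list of names, then average the chosen losses.\n",
"    if isinstance(names, str):\n",
"        names = [names]\n",
"    values = [LOSSES[name](prediction, target) for name in names]\n",
"    return sum(values) / len(values)\n",
"\n",
"\n",
"# mae = 1.0 and mse = 2.0 for this pair, so their average is 1.5.\n",
"print(combined_loss([\"mae\", \"mse\"], [0.0, 2.0], [0.0, 0.0]))  # 1.5\n",
"```\n",
"\n",
"Because the terms are averaged with equal weight, losses on very different scales can dominate one another; a custom loss is one way to apply explicit weights."
]
},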
{
"cell_type": "code",
"execution_count": null,
"id": "b36b457c",
"metadata": {},
"outputs": [],
"source": [
"from mayavoz.models import Demucs\n",
"model = Demucs(\n",
"    sampling_rate=16000,\n",
"    dataset=dataset,\n",
"    loss=[\"mae\"],\n",
"    metrics=[\"stoi\",\"pesq\"])\n"
]
},
{
"cell_type": "markdown",
"id": "1523d638",
"metadata": {},
"source": [
"<div id=\"callbacks\"></div>\n",
"\n",
"### Learning rate schedulers and callbacks\n",
"Here I am using `ReduceLROnPlateau`, attached to the model by overriding its `configure_optimizers` method."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8de6931c",
"metadata": {},
"outputs": [],
"source": [
"from types import MethodType\n",
"\n",
"from torch.optim import Adam\n",
"from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
"\n",
"# Example hyperparameters (assumed values); adjust as needed.\n",
"parameters = {\"lr\": 1e-3}\n",
"\n",
"def configure_optimizers(self):\n",
"    # Example choice: any torch.optim optimizer can be used here.\n",
"    optimizer = Adam(self.parameters(), lr=parameters.get(\"lr\"))\n",
"    scheduler = ReduceLROnPlateau(\n",
"        optimizer=optimizer,\n",
"        mode=self.valid_monitor,  # \"min\" or \"max\"\n",
"        factor=parameters.get(\"ReduceLr_factor\", 0.1),\n",
"        min_lr=parameters.get(\"min_lr\", 1e-6),\n",
"        patience=parameters.get(\"ReduceLr_patience\", 3),\n",
"    )\n",
"    return {\n",
"        \"optimizer\": optimizer,\n",
"        \"lr_scheduler\": scheduler,\n",
"        \"monitor\": f'valid_{parameters.get(\"ReduceLr_monitor\", \"loss\")}',\n",
"    }\n",
"\n",
"\n",
"# Bind the function to this model instance so Lightning calls it.\n",
"model.configure_optimizers = MethodType(configure_optimizers, model)"
]
},
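{
"cell_type": "markdown",
"id": "plateau-rule-sketch",
"metadata": {},
"source": [
"To see what `factor`, `patience`, and `min_lr` do, here is a simplified simulation of the reduce-on-plateau rule in `min` mode. It is a sketch for intuition only: `simulate_plateau` is a hypothetical helper, and it ignores the threshold and cooldown options of the real PyTorch scheduler.\n",
"\n",
"```python\n",
"def simulate_plateau(losses, lr=1e-3, factor=0.1, patience=3, min_lr=1e-6):\n",
"    # Return the learning rate in effect after each epoch's validation loss.\n",
"    best = float(\"inf\")\n",
"    bad_epochs = 0\n",
"    history = []\n",
"    for loss in losses:\n",
"        if loss < best:\n",
"            best = loss\n",
"            bad_epochs = 0\n",
"        else:\n",
"            bad_epochs += 1\n",
"        # Reduce only once the loss has stalled for more than `patience` epochs.\n",
"        if bad_epochs > patience:\n",
"            lr = max(lr * factor, min_lr)\n",
"            bad_epochs = 0\n",
"        history.append(lr)\n",
"    return history\n",
"\n",
"\n",
"losses = [1.0, 0.9, 0.9, 0.9, 0.9, 0.9, 0.8]\n",
"print(simulate_plateau(losses, lr=1.0, factor=0.5))\n",
"# [1.0, 1.0, 1.0, 1.0, 1.0, 0.5, 0.5]\n",
"```\n",
"\n",
"With `patience=3`, the rate is halved only at the fourth consecutive epoch without improvement, and it never drops below `min_lr`."
]
},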
{
"cell_type": "markdown",
"id": "2f7b5af5",
"metadata": {},
"source": [
"You can use any number of callbacks and pass them directly to the PyTorch Lightning trainer. Here I am using only `ModelCheckpoint`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6f6b62a1",
"metadata": {},
"outputs": [],
"source": [
"from pytorch_lightning.callbacks import ModelCheckpoint\n",
"\n",
"callbacks = []\n",
"direction = model.valid_monitor  ## min or max\n",
"# Save a checkpoint each epoch, tracking the validation loss.\n",
"checkpoint = ModelCheckpoint(\n",
"    dirpath=\"./model\",\n",
"    filename=\"model_filename\",\n",
"    monitor=\"valid_loss\",\n",
"    verbose=False,\n",
"    mode=direction,\n",
"    every_n_epochs=1,\n",
")\n",
"callbacks.append(checkpoint)"
]
},
{
"cell_type": "markdown",
"id": "f3534445",
"metadata": {},
"source": [
"<div id=\"train\"></div>\n",
"\n",
"\n",
"### Train"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3dc0348b",
"metadata": {},
"outputs": [],
"source": [
"import pytorch_lightning as pl\n",
"\n",
"trainer = pl.Trainer(max_epochs=1, callbacks=callbacks, accelerator=\"gpu\")\n",
"trainer.fit(model)\n"
]
},
{
"cell_type": "markdown",
"id": "56dcfec1",
"metadata": {},
"source": [
"- Test your model against the test dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "63851feb",
"metadata": {},
"outputs": [],
"source": [
"trainer.test(model)"
]
},
{
"cell_type": "markdown",
"id": "4d3f5350",
"metadata": {},
"source": [
"**Hurray! You have your speech enhancement model trained and tested.**\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "10d630e8",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "enhancer",
"language": "python",
"name": "enhancer"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}