{ "cells": [ { "cell_type": "markdown", "id": "ccd61d5c", "metadata": {}, "source": [ "## Custom model training using mayavoz [advanced]\n", "\n", "In this tutorial, we will cover advanced usage and customization for training your own speech enhancement model.\n", "\n", " - [Data preparation using MayaDataset](#dataprep)\n", " - [Model customization](#modelcustom)\n", " - [Callbacks & LR schedulers](#callbacks)\n", " - [Model training & testing](#train)\n" ] }, { "cell_type": "markdown", "id": "726c320f", "metadata": {}, "source": [ "- **Install mayavoz**" ] }, { "cell_type": "code", "execution_count": null, "id": "c987c799", "metadata": {}, "outputs": [], "source": [ "! pip install -q mayavoz" ] }, { "cell_type": "markdown", "id": "8ff9857b", "metadata": {}, "source": [ "
\n", "\n", "### Data preparation\n", "\n", "`Files` is a dataclass that wraps and holds the train/test paths together. There is usually one folder each for clean and noisy data. These paths must be relative to a `root_dir` where all these directories reside. For example:\n", "\n", "```\n", "- VCTK/\n", " |__ clean_train_wav/\n", " |__ noisy_train_wav/\n", " |__ clean_test_wav/\n", " |__ noisy_test_wav/\n", " \n", "```" ] }, { "cell_type": "code", "execution_count": 2, "id": "64cbc0c8", "metadata": {}, "outputs": [], "source": [ "from mayavoz.utils import Files\n", "\n", "# Folder names are relative to root_dir\n", "files = Files(train_clean=\"clean_train_wav\",\n", "              train_noisy=\"noisy_train_wav\",\n", "              test_clean=\"clean_test_wav\",\n", "              test_noisy=\"noisy_test_wav\")\n", "root_dir = \"VCTK\"" ] }, { "cell_type": "markdown", "id": "2d324bd1", "metadata": {}, "source": [ "Next, set the dataset parameters:\n", "\n", "- `name`: name of the dataset.\n", "- `duration`: duration of each audio instance fed into your model.\n", "- `stride`: if set, the step used to move the sliding window.\n", "- `sampling_rate`: desired sampling rate for the audio.\n", "- `batch_size`: model batch size.\n", "- `min_valid_minutes`: minimum duration of the validation set in minutes. Validation data is automatically selected from the training set (exclusive users).\n", "- `matching_function`: there are two types of matching functions.\n", "    - `one_to_one`: each clean file has exactly one corresponding noisy file. For example, the Valentini dataset.\n", "    - `one_to_many`: each clean file can have multiple corresponding noisy files. 
For example, the DNS dataset.\n" ] }, { "cell_type": "code", "execution_count": 3, "id": "6834941d", "metadata": {}, "outputs": [], "source": [ "name = \"vctk\"\n", "duration = 4.5\n", "stride = 2.0\n", "sampling_rate = 16000\n", "min_valid_minutes = 20.0\n", "batch_size = 32\n", "matching_function = \"one_to_one\"\n" ] }, { "cell_type": "code", "execution_count": null, "id": "d08c6bf8", "metadata": {}, "outputs": [], "source": [ "from mayavoz.dataset import MayaDataset\n", "\n", "dataset = MayaDataset(\n", "    name=name,\n", "    root_dir=root_dir,\n", "    files=files,\n", "    duration=duration,\n", "    stride=stride,\n", "    sampling_rate=sampling_rate,\n", "    batch_size=batch_size,\n", "    min_valid_minutes=min_valid_minutes,\n", "    matching_function=matching_function,\n", ")" ] }, { "cell_type": "markdown", "id": "5b315bde", "metadata": {}, "source": [ "Now your custom dataloader is ready!" ] }, { "cell_type": "markdown", "id": "01548fe5", "metadata": {}, "source": [ "\n", "\n", "### Model Customization\n", "Now, this is very easy.\n", "\n", "- Import the preferred model from `mayavoz.models`. Currently 3 models are implemented.\n", "    - `WaveUnet`\n", "    - `Demucs`\n", "    - `DCCRN`\n", "- Model hyperparameters such as `depth`, `kernel_size`, and `stride` can be controlled by you. Check the parameters of the chosen model and pass them as required.\n", "- `sampling_rate`: sampling rate (should be equal to the dataset sampling rate).\n", "- `dataset`: mayavoz dataset object as prepared earlier.\n", "- `loss`: model loss. Multiple loss functions are available.\n", "\n", "Available loss functions are `mse`, `mae`, and `si-snr`. You can pass one (as a string) or several (as a list of strings) of these loss functions as per your requirements. For example, the model will automatically calculate the loss as the average of `mae` and `mse` if you pass loss as `[\"mae\",\"mse\"]`.\n", "\n", "mayavoz can accept **custom loss functions**. 
It should be of the following form:\n", "```\n", "class your_custom_loss(nn.Module):\n", "    def __init__(self, **kwargs):\n", "        super().__init__()\n", "        self.higher_better = False  # loss minimization direction\n", "        self.name = \"your_loss_name\"  # loss name used for logging\n", "        ...\n", "    def forward(self, prediction, target):\n", "        loss = ...\n", "        return loss\n", "```\n", "\n", "- `metrics`: validation metrics. Available options are `mae`, `mse`, `si-sdr`, `pesq`, and `stoi`. One or more can be used.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "b36b457c", "metadata": {}, "outputs": [], "source": [ "from mayavoz.models import Demucs\n", "model = Demucs(\n", "    sampling_rate=16000,\n", "    dataset=dataset,\n", "    loss=[\"mae\"],\n", "    metrics=[\"stoi\",\"pesq\"])\n" ] }, { "cell_type": "markdown", "id": "1523d638", "metadata": {}, "source": [ "\n", "\n", "### Learning rate schedulers and callbacks\n", "Here I am using `ReduceLROnPlateau`." ] }, { "cell_type": "code", "execution_count": null, "id": "8de6931c", "metadata": {}, "outputs": [], "source": [ "from types import MethodType\n", "\n", "from torch.optim import Adam\n", "from torch.optim.lr_scheduler import ReduceLROnPlateau\n", "\n", "# Optimizer/scheduler settings. The learning rate is an example value;\n", "# missing keys fall back to the defaults given in .get() below.\n", "parameters = {\"lr\": 1e-3}\n", "\n", "def configure_optimizers(self):\n", "    # Adam is an assumption here; any torch optimizer can be substituted.\n", "    optimizer = Adam(\n", "        params=self.parameters(),\n", "        lr=parameters.get(\"lr\"),\n", "    )\n", "    scheduler = ReduceLROnPlateau(\n", "        optimizer=optimizer,\n", "        mode=self.valid_monitor,  # \"min\" or \"max\"\n", "        factor=parameters.get(\"ReduceLr_factor\", 0.1),\n", "        min_lr=parameters.get(\"min_lr\", 1e-6),\n", "        patience=parameters.get(\"ReduceLr_patience\", 3),\n", "    )\n", "    return {\n", "        \"optimizer\": optimizer,\n", "        \"lr_scheduler\": scheduler,\n", "        \"monitor\": f'valid_{parameters.get(\"ReduceLr_monitor\", \"loss\")}',\n", "    }\n", "\n", "\n", "# Bind the function to the model instance so Lightning picks it up.\n", "model.configure_optimizers = MethodType(configure_optimizers, model)" ] }, { "cell_type": "markdown", "id": "2f7b5af5", "metadata": {}, "source": [ "You can use any number of callbacks and pass them directly to the PyTorch Lightning trainer. 
Here I am using only `ModelCheckpoint`." ] }, { "cell_type": "code", "execution_count": null, "id": "6f6b62a1", "metadata": {}, "outputs": [], "source": [ "from pytorch_lightning.callbacks import ModelCheckpoint\n", "\n", "callbacks = []\n", "direction = model.valid_monitor  # \"min\" or \"max\"\n", "# Save a checkpoint every epoch, tracking the validation loss\n", "checkpoint = ModelCheckpoint(\n", "    dirpath=\"./model\",\n", "    filename=\"model_filename\",\n", "    monitor=\"valid_loss\",\n", "    verbose=False,\n", "    mode=direction,\n", "    every_n_epochs=1,\n", ")\n", "callbacks.append(checkpoint)" ] }, { "cell_type": "markdown", "id": "f3534445", "metadata": {}, "source": [ "\n", "\n", "\n", "### Train" ] }, { "cell_type": "code", "execution_count": null, "id": "3dc0348b", "metadata": {}, "outputs": [], "source": [ "import pytorch_lightning as pl\n", "trainer = pl.Trainer(max_epochs=1,callbacks=callbacks,accelerator=\"gpu\")\n", "trainer.fit(model)\n" ] }, { "cell_type": "markdown", "id": "56dcfec1", "metadata": {}, "source": [ "- Test your model against the test dataset" ] }, { "cell_type": "code", "execution_count": null, "id": "63851feb", "metadata": {}, "outputs": [], "source": [ "trainer.test(model)" ] }, { "cell_type": "markdown", "id": "4d3f5350", "metadata": {}, "source": [ "**Hurray! You have your speech enhancement model trained and tested.**\n" ] }, { "cell_type": "code", "execution_count": null, "id": "10d630e8", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "mayavoz", "language": "python", "name": "mayavoz" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.13" } }, "nbformat": 4, "nbformat_minor": 5 }