diff --git a/notebooks/Custom_model_training.ipynb b/notebooks/Custom_model_training.ipynb
new file mode 100644
index 0000000..7c963c2
--- /dev/null
+++ b/notebooks/Custom_model_training.ipynb
@@ -0,0 +1,338 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "ccd61d5c",
+   "metadata": {},
+   "source": [
+    "## Custom model training using mayavoz [advanced]\n",
+    "\n",
+    "In this tutorial, we will cover advanced usage and customization options for training your own speech enhancement model.\n",
+    "\n",
+    " - [Data preparation using MayaDataset](#dataprep)\n",
+    " - [Model customization](#modelcustom)\n",
+    " - [Callbacks & LR schedulers](#callbacks)\n",
+    " - [Model training & testing](#train)\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "726c320f",
+   "metadata": {},
+   "source": [
+    "- **install mayavoz**"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "c987c799",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "! pip install -q mayavoz"
+   ]
+  },
\n", + "\n", + "### Data preparation\n", + "\n", + "`Files` is a dataclass that wraps and holds train/test paths togethor. There are usually one folder each for clean and noisy data. These paths must be relative to a `root_dir` where all these directories reside. For example\n", + "\n", + "```\n", + "- VCTK/\n", + " |__ clean_train_wav/\n", + " |__ noisy_train_wav/\n", + " |__ clean_test_wav/\n", + " |__ noisy_test_wav/\n", + " \n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "64cbc0c8", + "metadata": {}, + "outputs": [], + "source": [ + "from mayavoz.utils import Files\n", + "file = Files(train_clean=\"clean_train_wav\",\n", + " train_noisy=\"noisy_train_wav\",\n", + " test_clean=\"clean_test_wav\",\n", + " test_noisy=\"noisy_test_wav\")\n", + "root_dir = \"VCTK\"" + ] + }, + { + "cell_type": "markdown", + "id": "2d324bd1", + "metadata": {}, + "source": [ + "- `name`: name of the dataset. \n", + "- `duration`: control the duration of each audio instance fed into your model.\n", + "- `stride` is used if set to move the sliding window.\n", + "- `sampling_rate`: desired sampling rate for audio\n", + "- `batch_size`: model batch size\n", + "- `min_valid_minutes`: minimum validation in minutes. Validation is automatically selected from training set. (exclusive users).\n", + "- `matching_function`: there are two types of mapping functions.\n", + " - `one_to_one` : In this one clean file will only have one corresponding noisy file. For example Valentini datasets\n", + " - `one_to_many` : In this one clean file will only have one corresponding noisy file. For example DNS dataset.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "6834941d", + "metadata": {}, + "outputs": [], + "source": [ + "name = \"vctk\"\n", + "duration : 4.5\n", + "stride : 2.0\n", + "sampling_rate : 16000\n", + "min_valid_minutes : 20.0\n", + "batch_size : 32\n", + "matching_function : \"one_to_one\"\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d08c6bf8", + "metadata": {}, + "outputs": [], + "source": [ + "from mayavoz.dataset import MayaDataset\n", + "dataset = MayaDataset(\n", + " name=name,\n", + " root_dir=root_dir,\n", + " files=files,\n", + " duration=duration,\n", + " stride=stride,\n", + " sampling_rate=sampling_rate,\n", + " batch_size=batch_size,\n", + " min_valid_minutes=min_valid_minutes,\n", + " matching_function=matching_function\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "5b315bde", + "metadata": {}, + "source": [ + "Now your custom dataloader is ready!" + ] + }, + { + "cell_type": "markdown", + "id": "01548fe5", + "metadata": {}, + "source": [ + "\n", + "\n", + "### Model Customization\n", + "Now, this is very easy. \n", + "\n", + "- Import the preferred model from `mayavoz.models`. Currently 3 models are implemented.\n", + " - `WaveUnet`\n", + " - `Demucs`\n", + " - `DCCRN`\n", + "- Each of model hyperparameters such as depth,kernel_size,stride etc can be controlled by you. Just check the parameters and pass it to as required.\n", + "- `sampling_rate`: sampling rate (should be equal to dataset sampling rate)\n", + "- `dataset`: mayavoz dataset object as prepared earlier.\n", + "- `loss` : model loss. Multiple loss functions are available.\n", + "\n", + " \n", + " \n", + "you can pass one (as string)/more (as list of strings) of these loss functions as per your requirements. For example, model will automatically calculate loss as average of `mae` and `mse` if you pass loss as `[\"mae\",\"mse\"]`. 
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b36b457c",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from mayavoz.models import Demucs\n",
+    "\n",
+    "model = Demucs(\n",
+    "    sampling_rate=16000,\n",
+    "    dataset=dataset,\n",
+    "    loss=[\"mae\"],\n",
+    "    metrics=[\"stoi\", \"pesq\"],\n",
+    ")\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "1523d638",
+   "metadata": {},
+   "source": [
+    "<a id=\"callbacks\"></a>\n",
+    "\n",
+    "### Learning rate schedulers and callbacks\n",
+    "Here I am using `Adam` together with a `ReduceLROnPlateau` scheduler. The scheduler is attached by overriding `configure_optimizers` on the model instance (a standard pytorch lightning hook)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "8de6931c",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from types import MethodType\n",
+    "\n",
+    "from torch.optim import Adam\n",
+    "from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
+    "\n",
+    "\n",
+    "def configure_optimizers(self):\n",
+    "    optimizer = Adam(self.parameters(), lr=1e-3)\n",
+    "    scheduler = ReduceLROnPlateau(\n",
+    "        optimizer=optimizer,\n",
+    "        mode=self.valid_monitor,  ## \"min\" or \"max\", depending on the metric\n",
+    "        factor=0.1,\n",
+    "        patience=3,\n",
+    "        min_lr=1e-6,\n",
+    "        verbose=True,\n",
+    "    )\n",
+    "    return {\n",
+    "        \"optimizer\": optimizer,\n",
+    "        \"lr_scheduler\": scheduler,\n",
+    "        \"monitor\": \"valid_loss\",\n",
+    "    }\n",
+    "\n",
+    "\n",
+    "## bind the override to this model instance\n",
+    "model.configure_optimizers = MethodType(configure_optimizers, model)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "2f7b5af5",
+   "metadata": {},
+   "source": [
+    "You can use any number of callbacks and pass them directly to the pytorch lightning trainer. Here I am using only `ModelCheckpoint`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "6f6b62a1",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from pytorch_lightning.callbacks import ModelCheckpoint\n",
+    "\n",
+    "callbacks = []\n",
+    "direction = model.valid_monitor  ## \"min\" or \"max\"\n",
+    "checkpoint = ModelCheckpoint(\n",
+    "    dirpath=\"./model\",\n",
+    "    filename=\"model_filename\",\n",
+    "    monitor=\"valid_loss\",\n",
+    "    verbose=False,\n",
+    "    mode=direction,\n",
+    "    every_n_epochs=1,\n",
+    ")\n",
+    "callbacks.append(checkpoint)"
+   ]
+  },
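+  {
+   "cell_type": "markdown",
+   "id": "c9d0e1f2",
+   "metadata": {},
+   "source": [
+    "As a further (optional) illustration, you could append more callbacks the same way, e.g. pytorch lightning's `EarlyStopping` to stop training once `valid_loss` stops improving:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d3e4f5a6",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from pytorch_lightning.callbacks import EarlyStopping\n",
+    "\n",
+    "early_stopping = EarlyStopping(\n",
+    "    monitor=\"valid_loss\",\n",
+    "    mode=direction,  ## same direction as the checkpoint callback\n",
+    "    patience=10,  ## epochs without improvement before stopping\n",
+    ")\n",
+    "callbacks.append(early_stopping)"
+   ]
+  },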
+  {
+   "cell_type": "markdown",
+   "id": "f3534445",
+   "metadata": {},
+   "source": [
+    "<a id=\"train\"></a>\n",
+    "\n",
+    "### Train"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "3dc0348b",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import pytorch_lightning as pl\n",
+    "\n",
+    "trainer = pl.Trainer(max_epochs=1, callbacks=callbacks, accelerator=\"gpu\")\n",
+    "trainer.fit(model)\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "56dcfec1",
+   "metadata": {},
+   "source": [
+    "- Test your model against the test dataset"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "63851feb",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "trainer.test(model)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "4d3f5350",
+   "metadata": {},
+   "source": [
+    "**Hurray! You have trained and tested your own speech enhancement model.**\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "10d630e8",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "enhancer",
+   "language": "python",
+   "name": "enhancer"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.13"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}