| { | ||||
|  "cells": [ | ||||
|   { | ||||
|    "cell_type": "markdown", | ||||
|    "id": "ccd61d5c", | ||||
|    "metadata": {}, | ||||
|    "source": [ | ||||
|     "## Custom model training using mayavoz [advanced]\n", | ||||
|     "\n", | ||||
|     "In this tutorial, we will cover advanced usages and customizations for training your own speecg enhancement model. \n", | ||||
|     "\n", | ||||
|     " - [Data preparation using MayaDataset](#dataprep)\n", | ||||
|     " - [Model customization](#modelcustom)\n", | ||||
|     " - [callbacks & LR schedulers](#callbacks)\n", | ||||
|     " - [Model training & testing](#train)\n" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "markdown", | ||||
|    "id": "726c320f", | ||||
|    "metadata": {}, | ||||
|    "source": [ | ||||
|     "- **install mayavoz**" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "id": "c987c799", | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "! pip install -q mayavoz" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "markdown", | ||||
|    "id": "8ff9857b", | ||||
|    "metadata": {}, | ||||
|    "source": [ | ||||
|     "<div id=\"dataprep\"></div>\n", | ||||
|     "\n", | ||||
|     "### Data preparation\n", | ||||
|     "\n", | ||||
|     "`Files` is a dataclass that wraps and holds train/test paths togethor. There are usually one folder each for clean and noisy data. These paths must be relative to a `root_dir` where all these directories reside. For example\n", | ||||
|     "\n", | ||||
|     "```\n", | ||||
|     "- VCTK/\n", | ||||
|     "    |__ clean_train_wav/\n", | ||||
|     "    |__ noisy_train_wav/\n", | ||||
|     "    |__ clean_test_wav/\n", | ||||
|     "    |__ noisy_test_wav/\n", | ||||
|     "    \n", | ||||
|     "```" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 2, | ||||
|    "id": "64cbc0c8", | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "from mayavoz.utils import Files\n", | ||||
|     "file = Files(train_clean=\"clean_train_wav\",\n", | ||||
|     "            train_noisy=\"noisy_train_wav\",\n", | ||||
|     "            test_clean=\"clean_test_wav\",\n", | ||||
|     "            test_noisy=\"noisy_test_wav\")\n", | ||||
|     "root_dir = \"VCTK\"" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "markdown", | ||||
|    "id": "2d324bd1", | ||||
|    "metadata": {}, | ||||
|    "source": [ | ||||
|     "- `name`: name of the dataset. \n", | ||||
|     "- `duration`: control the duration of each audio instance fed into your model.\n", | ||||
|     "- `stride` is used if set to move the sliding window.\n", | ||||
|     "- `sampling_rate`: desired sampling rate for audio\n", | ||||
|     "- `batch_size`: model batch size\n", | ||||
|     "- `min_valid_minutes`: minimum validation in minutes. Validation is automatically selected from training set. (exclusive users).\n", | ||||
|     "- `matching_function`: there are two types of mapping functions.\n", | ||||
|     "    - `one_to_one` : In this one clean file will only have one corresponding noisy file. For example Valentini datasets\n", | ||||
|     "    - `one_to_many` : In this one clean file will only have one corresponding noisy file. For example DNS dataset.\n" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 3, | ||||
|    "id": "6834941d", | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "name = \"vctk\"\n", | ||||
|     "duration : 4.5\n", | ||||
|     "stride : 2.0\n", | ||||
|     "sampling_rate : 16000\n", | ||||
|     "min_valid_minutes : 20.0\n", | ||||
|     "batch_size : 32\n", | ||||
|     "matching_function : \"one_to_one\"\n" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "id": "d08c6bf8", | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "from mayavoz.dataset import MayaDataset\n", | ||||
|     "dataset = MayaDataset(\n", | ||||
|     "            name=name,\n", | ||||
|     "            root_dir=root_dir,\n", | ||||
|     "            files=files,\n", | ||||
|     "            duration=duration,\n", | ||||
|     "            stride=stride,\n", | ||||
|     "            sampling_rate=sampling_rate,\n", | ||||
|     "            batch_size=batch_size,\n", | ||||
|     "            min_valid_minutes=min_valid_minutes,\n", | ||||
|     "            matching_function=matching_function\n", | ||||
|     "        )" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "markdown", | ||||
|    "id": "5b315bde", | ||||
|    "metadata": {}, | ||||
|    "source": [ | ||||
|     "Now your custom dataloader is ready!" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "markdown", | ||||
|    "id": "01548fe5", | ||||
|    "metadata": {}, | ||||
|    "source": [ | ||||
|     "<div id=\"modelcustom\"></div>\n", | ||||
|     "\n", | ||||
|     "### Model Customization\n", | ||||
|     "Now, this is very easy. \n", | ||||
|     "\n", | ||||
|     "- Import the preferred model from `mayavoz.models`. Currently 3 models are implemented.\n", | ||||
|     "   - `WaveUnet`\n", | ||||
|     "   - `Demucs`\n", | ||||
|     "   - `DCCRN`\n", | ||||
|     "- Each of model hyperparameters such as depth,kernel_size,stride etc can be controlled by you.   Just check the parameters and pass it to as required.\n", | ||||
|     "- `sampling_rate`: sampling rate (should be equal to dataset sampling rate)\n", | ||||
|     "- `dataset`: mayavoz dataset object as prepared earlier.\n", | ||||
|     "- `loss` : model loss. Multiple loss functions are available.\n", | ||||
|     "\n", | ||||
|     "        \n", | ||||
|     "        \n", | ||||
|     "you can pass one (as string)/more (as list of strings) of these loss functions as per your requirements. For example, model will automatically calculate loss as average of `mae` and `mse` if you pass loss as `[\"mae\",\"mse\"]`. Available loss functions are `mse`,`mae`,`si-snr`.\n", | ||||
|     "\n", | ||||
|     "mayavoz can accept **custom loss functions**. It should be of the form.\n", | ||||
|     "```\n", | ||||
|     "class your_custom_loss(nn.Module):\n", | ||||
|     "    def __init__(self,**kwargs):\n", | ||||
|     "        self.higher_better = False  ## loss minimization direction\n", | ||||
|     "        self.name = \"your_loss_name\" ## loss name logging \n", | ||||
|     "        ...\n", | ||||
|     "    def forward(self,prediction, target):\n", | ||||
|     "        loss = ....\n", | ||||
|     "        return loss\n", | ||||
|     "        \n", | ||||
|     "```\n", | ||||
|     "\n", | ||||
|     "- metrics : validation metrics. Available options `mae`,`mse`,`si-sdr`,`si-sdr`,`pesq`,`stoi`. One or more can be used.\n" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "id": "b36b457c", | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "from mayavoz.models import Demucs\n", | ||||
|     "model = Demucs(\n", | ||||
|     "        sampling_rate=16000,\n", | ||||
|     "        dataset=dataset,\n", | ||||
|     "        loss=[\"mae\"],\n", | ||||
|     "        metrics=[\"stoi\",\"pesq\"])\n" | ||||
|    ] | ||||
|   }, | ||||
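|   { | ||||
|    "cell_type": "markdown", | ||||
|    "id": "a7c4e1f0", | ||||
|    "metadata": {}, | ||||
|    "source": [ | ||||
|     "As a concrete illustration of the custom loss interface described above, below is a minimal sketch of a scaled mean-absolute-error loss implemented as an `nn.Module` with the `higher_better` and `name` attributes. Note that passing a loss *instance* through the `loss` argument (instead of a string) is an assumption here; check the mayavoz documentation for the exact mechanism." | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "id": "a7c4e1f1", | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "import torch\n", | ||||
|     "import torch.nn as nn\n", | ||||
|     "\n", | ||||
|     "\n", | ||||
|     "class ScaledMAELoss(nn.Module):\n", | ||||
|     "    ## hypothetical custom loss: mean absolute error scaled by a constant\n", | ||||
|     "    def __init__(self, scale: float = 1.0):\n", | ||||
|     "        super().__init__()\n", | ||||
|     "        self.higher_better = False  ## lower loss is better\n", | ||||
|     "        self.name = \"scaled_mae\"    ## name used for logging\n", | ||||
|     "        self.scale = scale\n", | ||||
|     "\n", | ||||
|     "    def forward(self, prediction, target):\n", | ||||
|     "        return self.scale * torch.mean(torch.abs(prediction - target))\n", | ||||
|     "\n", | ||||
|     "\n", | ||||
|     "## assumption: an instance can be passed via `loss` alongside/instead of string names\n", | ||||
|     "# model = Demucs(sampling_rate=16000, dataset=dataset, loss=[ScaledMAELoss()], metrics=[\"stoi\"])" | ||||
|    ] | ||||
|   }, | ||||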
|   { | ||||
|    "cell_type": "markdown", | ||||
|    "id": "1523d638", | ||||
|    "metadata": {}, | ||||
|    "source": [ | ||||
|     "<div id=\"callbacks\"></div>\n", | ||||
|     "\n", | ||||
|     "### learning rate schedulers and callbacks\n", | ||||
|     "Here I am using `ReduceLROnPlateau`" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "id": "8de6931c", | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "from torch.optim.lr_scheduler import ReduceLROnPlateau\n", | ||||
|     "\n", | ||||
|     "def configure_optimizers(self):\n", | ||||
|     "        optimizer = instantiate(\n", | ||||
|     "            config.optimizer,\n", | ||||
|     "            lr=parameters.get(\"lr\"),\n", | ||||
|     "            params=self.parameters(),\n", | ||||
|     "        )\n", | ||||
|     "        scheduler = ReduceLROnPlateau(\n", | ||||
|     "            optimizer=optimizer,\n", | ||||
|     "            mode=direction,\n", | ||||
|     "            factor=parameters.get(\"ReduceLr_factor\", 0.1),\n", | ||||
|     "            verbose=True,\n", | ||||
|     "            min_lr=parameters.get(\"min_lr\", 1e-6),\n", | ||||
|     "            patience=parameters.get(\"ReduceLr_patience\", 3),\n", | ||||
|     "        )\n", | ||||
|     "        return {\n", | ||||
|     "            \"optimizer\": optimizer,\n", | ||||
|     "            \"lr_scheduler\": scheduler,\n", | ||||
|     "            \"monitor\": f'valid_{parameters.get(\"ReduceLr_monitor\", \"loss\")}',\n", | ||||
|     "        }\n", | ||||
|     "\n", | ||||
|     "\n", | ||||
|     "model.configure_optimizers = MethodType(configure_optimizers, model)" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "markdown", | ||||
|    "id": "2f7b5af5", | ||||
|    "metadata": {}, | ||||
|    "source": [ | ||||
|     "you can use any number of callbacks and pass it directly to pytorch lightning trainer. Here I am using only `ModelCheckpoint`" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "id": "6f6b62a1", | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "callbacks = []\n", | ||||
|     "direction = model.valid_monitor ## min or max \n", | ||||
|     "checkpoint = ModelCheckpoint(\n", | ||||
|     "        dirpath=\"./model\",\n", | ||||
|     "        filename=f\"model_filename\",\n", | ||||
|     "        monitor=\"valid_loss\",\n", | ||||
|     "        verbose=False,\n", | ||||
|     "        mode=direction,\n", | ||||
|     "        every_n_epochs=1,\n", | ||||
|     "    )\n", | ||||
|     "callbacks.append(checkpoint)" | ||||
|    ] | ||||
|   }, | ||||
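|   { | ||||
|    "cell_type": "markdown", | ||||
|    "id": "b2d9c3e7", | ||||
|    "metadata": {}, | ||||
|    "source": [ | ||||
|     "As a sketch of adding further callbacks, you could also append PyTorch Lightning's `EarlyStopping`. The monitored metric name (`valid_loss`, matching the checkpoint above) and the patience value are assumptions you may want to adjust." | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "id": "b2d9c3e8", | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "from pytorch_lightning.callbacks import EarlyStopping\n", | ||||
|     "\n", | ||||
|     "## stop training when the monitored validation metric stops improving\n", | ||||
|     "early_stopping = EarlyStopping(\n", | ||||
|     "        monitor=\"valid_loss\",  ## assumed metric name, same as the checkpoint monitor\n", | ||||
|     "        mode=direction,\n", | ||||
|     "        patience=5,\n", | ||||
|     "    )\n", | ||||
|     "callbacks.append(early_stopping)" | ||||
|    ] | ||||
|   }, | ||||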
|   { | ||||
|    "cell_type": "markdown", | ||||
|    "id": "f3534445", | ||||
|    "metadata": {}, | ||||
|    "source": [ | ||||
|     "<div id=\"train\"></div>\n", | ||||
|     "\n", | ||||
|     "\n", | ||||
|     "### Train" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "id": "3dc0348b", | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "import pytorch_lightning as pl\n", | ||||
|     "trainer = plt.Trainer(max_epochs=1,callbacks=callbacks,accelerator=\"gpu\")\n", | ||||
|     "trainer.fit(model)\n" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "markdown", | ||||
|    "id": "56dcfec1", | ||||
|    "metadata": {}, | ||||
|    "source": [ | ||||
|     "- Test your model agaist test dataset" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "id": "63851feb", | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "trainer.test(model)" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "markdown", | ||||
|    "id": "4d3f5350", | ||||
|    "metadata": {}, | ||||
|    "source": [ | ||||
|     "**Hurray! you have your speech enhancement model trained and tested.**\n" | ||||
|    ] | ||||
|   }, | ||||
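|   { | ||||
|    "cell_type": "markdown", | ||||
|    "id": "c5a1d9f2", | ||||
|    "metadata": {}, | ||||
|    "source": [ | ||||
|     "To reuse the trained model later, a minimal sketch assuming standard PyTorch Lightning checkpointing (and the `./model/model_filename.ckpt` path implied by the `ModelCheckpoint` above) could look like the following. You may need to pass constructor arguments explicitly if they were not stored in the checkpoint." | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "id": "c5a1d9f3", | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "## path assumed from the ModelCheckpoint configuration (default .ckpt extension)\n", | ||||
|     "checkpoint_path = \"./model/model_filename.ckpt\"\n", | ||||
|     "\n", | ||||
|     "## standard LightningModule loading; adjust arguments as needed\n", | ||||
|     "restored_model = Demucs.load_from_checkpoint(checkpoint_path)\n", | ||||
|     "restored_model.eval()" | ||||
|    ] | ||||
|   }, | ||||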
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "id": "10d630e8", | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [] | ||||
|   } | ||||
|  ], | ||||
|  "metadata": { | ||||
|   "kernelspec": { | ||||
|    "display_name": "enhancer", | ||||
|    "language": "python", | ||||
|    "name": "enhancer" | ||||
|   }, | ||||
|   "language_info": { | ||||
|    "codemirror_mode": { | ||||
|     "name": "ipython", | ||||
|     "version": 3 | ||||
|    }, | ||||
|    "file_extension": ".py", | ||||
|    "mimetype": "text/x-python", | ||||
|    "name": "python", | ||||
|    "nbconvert_exporter": "python", | ||||
|    "pygments_lexer": "ipython3", | ||||
|    "version": "3.8.13" | ||||
|   } | ||||
|  }, | ||||
|  "nbformat": 4, | ||||
|  "nbformat_minor": 5 | ||||
| } | ||||