{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ccd61d5c",
   "metadata": {},
   "source": [
    "## Custom model training using mayavoz [advanced]\n",
    "\n",
    "In this tutorial, we will cover advanced usage and customization options for training your own speech enhancement model.\n",
    "\n",
    " - [Data preparation using MayaDataset](#dataprep)\n",
    " - [Model customization](#modelcustom)\n",
    " - [Callbacks & LR schedulers](#callbacks)\n",
    " - [Model training & testing](#train)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "726c320f",
   "metadata": {},
   "source": [
    "- **install mayavoz**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c987c799",
   "metadata": {},
   "outputs": [],
   "source": [
    "! pip install -q mayavoz"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8ff9857b",
   "metadata": {},
   "source": [
    "<div id=\"dataprep\"></div>\n",
    "\n",
    "### Data preparation\n",
    "\n",
    "`Files` is a dataclass that wraps the train/test paths and holds them together. There is usually one folder each for clean and noisy data. These paths must be relative to a `root_dir` in which all of these directories reside. For example:\n",
    "\n",
    "```\n",
    "- VCTK/\n",
    "    |__ clean_train_wav/\n",
    "    |__ noisy_train_wav/\n",
    "    |__ clean_test_wav/\n",
    "    |__ noisy_test_wav/\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "64cbc0c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from mayavoz.utils import Files\n",
    "\n",
    "files = Files(train_clean=\"clean_train_wav\",\n",
    "              train_noisy=\"noisy_train_wav\",\n",
    "              test_clean=\"clean_test_wav\",\n",
    "              test_noisy=\"noisy_test_wav\")\n",
    "root_dir = \"VCTK\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2d324bd1",
   "metadata": {},
   "source": [
    "- `name`: name of the dataset.\n",
    "- `duration`: duration (in seconds) of each audio instance fed into your model.\n",
    "- `stride`: if set, the hop (in seconds) used to move the sliding window over each file.\n",
    "- `sampling_rate`: desired sampling rate for the audio.\n",
    "- `batch_size`: model batch size.\n",
    "- `min_valid_minutes`: minimum amount of validation data, in minutes. Validation data is automatically selected from the training set (speakers exclusive to validation).\n",
    "- `matching_function`: there are two types of mapping functions.\n",
    "    - `one_to_one`: each clean file has exactly one corresponding noisy file, for example the Valentini datasets.\n",
    "    - `one_to_many`: each clean file has multiple corresponding noisy files, for example the DNS dataset.\n"
   ]
  },
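The two matching modes above can be illustrated with a small sketch. This is a hypothetical illustration of what `one_to_one` matching implies (pairing clean and noisy files by identical filenames); the helper `one_to_one_pairs` is my own example, not part of the mayavoz API.

```python
# Hypothetical sketch (not mayavoz API): `one_to_one` pairs each clean file
# with the single noisy file that shares its filename.
from pathlib import Path

def one_to_one_pairs(clean_dir, noisy_dir):
    pairs = []
    for clean in sorted(Path(clean_dir).glob("*.wav")):
        noisy = Path(noisy_dir) / clean.name  # same filename, noisy folder
        if noisy.exists():
            pairs.append((clean, noisy))
    return pairs
```

Under `one_to_many`, a single clean file would instead be matched with every noisy variant derived from it.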
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "6834941d",
   "metadata": {},
   "outputs": [],
   "source": [
    "name = \"vctk\"\n",
    "duration = 4.5\n",
    "stride = 2.0\n",
    "sampling_rate = 16000\n",
    "min_valid_minutes = 20.0\n",
    "batch_size = 32\n",
    "matching_function = \"one_to_one\"\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d08c6bf8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from mayavoz.dataset import MayaDataset\n",
    "\n",
    "dataset = MayaDataset(\n",
    "    name=name,\n",
    "    root_dir=root_dir,\n",
    "    files=files,\n",
    "    duration=duration,\n",
    "    stride=stride,\n",
    "    sampling_rate=sampling_rate,\n",
    "    batch_size=batch_size,\n",
    "    min_valid_minutes=min_valid_minutes,\n",
    "    matching_function=matching_function,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5b315bde",
   "metadata": {},
   "source": [
    "Now your custom dataloader is ready!"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "01548fe5",
   "metadata": {},
   "source": [
    "<div id=\"modelcustom\"></div>\n",
    "\n",
    "### Model Customization\n",
    "Now, this is very easy.\n",
    "\n",
    "- Import the preferred model from `mayavoz.models`. Currently 3 models are implemented:\n",
    "   - `WaveUnet`\n",
    "   - `Demucs`\n",
    "   - `DCCRN`\n",
    "- Each model's hyperparameters, such as depth, kernel_size and stride, can be controlled by you. Just check the parameters and pass them as required.\n",
    "- `sampling_rate`: sampling rate (should be equal to the dataset sampling rate).\n",
    "- `dataset`: the mayavoz dataset object prepared earlier.\n",
    "- `loss`: model loss. Multiple loss functions are available.\n",
    "\n",
    "You can pass one (as a string) or more (as a list of strings) of these loss functions as per your requirements. For example, the model will automatically calculate the loss as the average of `mae` and `mse` if you pass loss as `[\"mae\",\"mse\"]`. Available loss functions are `mse`, `mae`, `si-snr`.\n",
    "\n",
    "mayavoz can also accept **custom loss functions**. They should be of the form:\n",
    "```\n",
    "class your_custom_loss(nn.Module):\n",
    "    def __init__(self, **kwargs):\n",
    "        super().__init__()\n",
    "        self.higher_better = False   ## loss minimization direction\n",
    "        self.name = \"your_loss_name\"  ## loss name used for logging\n",
    "        ...\n",
    "    def forward(self, prediction, target):\n",
    "        loss = ...\n",
    "        return loss\n",
    "```\n",
    "\n",
    "- `metrics`: validation metrics. Available options are `mae`, `mse`, `si-sdr`, `pesq`, `stoi`. One or more can be used.\n"
   ]
  },
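As a concrete illustration of the template above, here is a hypothetical log-cosh loss written in that form. The class name and the loss itself are my own example, not something shipped with mayavoz.

```python
import torch
import torch.nn as nn

class LogCoshLoss(nn.Module):
    """Hypothetical custom loss following the mayavoz template."""

    def __init__(self, **kwargs):
        super().__init__()
        self.higher_better = False  ## lower is better, so minimize
        self.name = "logcosh"       ## name used for logging

    def forward(self, prediction, target):
        ## log(cosh(x)) behaves like x**2/2 near zero and like |x| for large
        ## errors, giving a smooth, outlier-robust regression loss
        return torch.mean(torch.log(torch.cosh(prediction - target)))
```

An instance of this class could then be passed through the `loss` argument in place of (or alongside) the built-in loss names.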
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b36b457c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from mayavoz.models import Demucs\n",
    "\n",
    "model = Demucs(\n",
    "    sampling_rate=16000,\n",
    "    dataset=dataset,\n",
    "    loss=[\"mae\"],\n",
    "    metrics=[\"stoi\", \"pesq\"],\n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1523d638",
   "metadata": {},
   "source": [
    "<div id=\"callbacks\"></div>\n",
    "\n",
    "### Learning rate schedulers and callbacks\n",
    "Here I am using `ReduceLROnPlateau`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8de6931c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from types import MethodType\n",
    "\n",
    "from torch.optim import Adam\n",
    "from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
    "\n",
    "\n",
    "def configure_optimizers(self):\n",
    "    optimizer = Adam(self.parameters(), lr=1e-3)\n",
    "    scheduler = ReduceLROnPlateau(\n",
    "        optimizer=optimizer,\n",
    "        mode=\"min\",  ## \"min\" because valid_loss should decrease\n",
    "        factor=0.1,\n",
    "        patience=3,\n",
    "        min_lr=1e-6,\n",
    "        verbose=True,\n",
    "    )\n",
    "    return {\n",
    "        \"optimizer\": optimizer,\n",
    "        \"lr_scheduler\": scheduler,\n",
    "        \"monitor\": \"valid_loss\",\n",
    "    }\n",
    "\n",
    "\n",
    "## bind the new configure_optimizers onto the model instance\n",
    "model.configure_optimizers = MethodType(configure_optimizers, model)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2f7b5af5",
   "metadata": {},
   "source": [
    "You can use any number of callbacks and pass them directly to the pytorch lightning trainer. Here I am using only `ModelCheckpoint`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f6b62a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pytorch_lightning.callbacks import ModelCheckpoint\n",
    "\n",
    "callbacks = []\n",
    "direction = model.valid_monitor  ## \"min\" or \"max\"\n",
    "checkpoint = ModelCheckpoint(\n",
    "    dirpath=\"./model\",\n",
    "    filename=\"model_filename\",\n",
    "    monitor=\"valid_loss\",\n",
    "    verbose=False,\n",
    "    mode=direction,\n",
    "    every_n_epochs=1,\n",
    ")\n",
    "callbacks.append(checkpoint)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f3534445",
   "metadata": {},
   "source": [
    "<div id=\"train\"></div>\n",
    "\n",
    "### Train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3dc0348b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pytorch_lightning as pl\n",
    "\n",
    "trainer = pl.Trainer(max_epochs=1, callbacks=callbacks, accelerator=\"gpu\")\n",
    "trainer.fit(model)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "56dcfec1",
   "metadata": {},
   "source": [
    "- Test your model against the test dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63851feb",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.test(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4d3f5350",
   "metadata": {},
   "source": [
    "**Hurray! You have trained and tested your own speech enhancement model.**\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "enhancer",
   "language": "python",
   "name": "enhancer"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}