## Custom model training using mayavoz [advanced]

In this tutorial, we will cover advanced usage and customization for training your own speech enhancement model.

 - [Data preparation using MayaDataset](#dataprep)
 - [Model customization](#modelcustom)
 - [Callbacks & LR schedulers](#callbacks)
 - [Model training & testing](#train)


- **install mayavoz**

In [None]:
! pip install -q mayavoz

<div id="dataprep"></div>

### Data preparation

`Files` is a dataclass that holds the train and test paths together. There is usually one folder each for clean and noisy data. These paths must be relative to a `root_dir` where all these directories reside. For example:

```
- VCTK/
    |__ clean_train_wav/
    |__ noisy_train_wav/
    |__ clean_test_wav/
    |__ noisy_test_wav/
    
```

In [2]:
from mayavoz.utils import Files
files = Files(train_clean="clean_train_wav",
              train_noisy="noisy_train_wav",
              test_clean="clean_test_wav",
              test_noisy="noisy_test_wav")
root_dir = "VCTK"

- `name`: name of the dataset.
- `duration`: duration (in seconds) of each audio segment fed into your model.
- `stride`: if set, the step by which the sliding window moves between consecutive segments.
- `sampling_rate`: desired sampling rate for the audio.
- `batch_size`: model batch size.
- `min_valid_minutes`: minimum amount of validation data, in minutes. The validation set is selected automatically from the training set (exclusive users).
- `matching_function`: how clean files are matched to noisy files. There are two matching functions:
    - `one_to_one`: each clean file has exactly one corresponding noisy file. For example, the Valentini dataset.
    - `one_to_many`: each clean file can have several corresponding noisy files. For example, the DNS dataset.

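To make the interaction between `duration` and `stride` concrete, here is a minimal sketch of a plain sliding window over a single recording. The helper `window_starts` is hypothetical and is not part of mayavoz; it assumes that windows which do not fit entirely inside the recording are dropped, and that a missing stride means non-overlapping windows. The library may handle the last partial window differently (for example by padding).

```python
def window_starts(total_seconds, duration, stride=None):
    """Start times (in seconds) of fixed-length windows cut from one recording.

    Hypothetical helper for illustration only, not a mayavoz function.
    """
    # Assumption: without a stride, windows do not overlap (step == duration).
    step = stride if stride else duration
    starts = []
    t = 0.0
    # Keep only windows that fit entirely inside the recording.
    while t + duration <= total_seconds:
        starts.append(t)
        t += step
    return starts


# A 10 second recording, 4.5 second windows, moved 2 seconds at a time.
print(window_starts(10.0, duration=4.5, stride=2.0))  # [0.0, 2.0, 4.0]
# Same recording without a stride: back-to-back windows.
print(window_starts(10.0, duration=4.5))  # [0.0, 4.5]
```

A smaller stride yields more (overlapping) training segments from the same audio, at the cost of more redundancy between them.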

In [3]:
name = "vctk"
duration = 4.5
stride = 2.0
sampling_rate = 16000
min_valid_minutes = 20.0
batch_size = 32
matching_function = "one_to_one"


In [None]:
from mayavoz.dataset import MayaDataset
dataset = MayaDataset(
            name=name,
            root_dir=root_dir,
            files=files,
            duration=duration,
            stride=stride,
            sampling_rate=sampling_rate,
            batch_size=batch_size,
            min_valid_minutes=min_valid_minutes,
            matching_function=matching_function
        )

Now your custom dataloader is ready!

<div id="modelcustom"></div>

### Model Customization
Now, this is very easy. 

- Import the preferred model from `mayavoz.models`. Currently 3 models are implemented.
   - `WaveUnet`
   - `Demucs`
   - `DCCRN`
- Each model's hyperparameters, such as `depth`, `kernel_size`, `stride`, etc., can be controlled by you. Just check the model's parameters and pass them as required.
- `sampling_rate`: sampling rate (should be equal to dataset sampling rate)
- `dataset`: mayavoz dataset object as prepared earlier.
- `loss` : model loss. Multiple loss functions are available.

        
        
You can pass one loss function (as a string) or several (as a list of strings), as per your requirements. For example, the model will automatically calculate the loss as the average of `mae` and `mse` if you pass loss as `["mae","mse"]`. Available loss functions are `mse`, `mae`, and `si-snr`.

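The averaging described above can be sketched in plain Python. This is only an illustration of the arithmetic, assuming an unweighted mean over the selected losses; the names `LOSSES` and `combined_loss` are hypothetical and do not mirror mayavoz internals, which operate on tensors.

```python
def mae(prediction, target):
    """Mean absolute error between two equal-length sequences."""
    return sum(abs(p - t) for p, t in zip(prediction, target)) / len(target)


def mse(prediction, target):
    """Mean squared error between two equal-length sequences."""
    return sum((p - t) ** 2 for p, t in zip(prediction, target)) / len(target)


# Hypothetical registry mapping loss names to functions.
LOSSES = {"mae": mae, "mse": mse}


def combined_loss(names, prediction, target):
    """Unweighted average of the named losses; accepts a string or a list."""
    if isinstance(names, str):
        names = [names]
    values = [LOSSES[name](prediction, target) for name in names]
    return sum(values) / len(values)


prediction = [1.0, 2.0, 4.0]
target = [1.0, 3.0, 2.0]
print(combined_loss("mae", prediction, target))           # single loss
print(combined_loss(["mae", "mse"], prediction, target))  # average of both
```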
mayavoz can accept **custom loss functions**. A custom loss should be of the form:
```
import torch.nn as nn

class your_custom_loss(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()  ## required before setting attributes on an nn.Module
        self.higher_better = False  ## loss minimization direction
        self.name = "your_loss_name"  ## loss name used for logging
        ...
    def forward(self, prediction, target):
        loss = ...  ## compute your loss here
        return loss

```

- `metrics`: validation metrics. Available options are `mae`, `mse`, `si-sdr`, `pesq`, and `stoi`. One or more can be used.

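For readers unfamiliar with `si-sdr`, the following is a minimal sketch of the usual scale-invariant signal-to-distortion ratio on plain Python lists. It assumes zero-mean signals (mean removal is skipped) and adds a small `eps` for numerical safety; it is not mayavoz's implementation, which may differ in details.

```python
import math


def si_sdr(prediction, target, eps=1e-8):
    """Scale-invariant SDR in dB; higher is better.

    Illustrative sketch only: assumes both signals are already zero-mean.
    """
    # Project the prediction onto the target to get the "signal" part.
    dot = sum(p * t for p, t in zip(prediction, target))
    target_energy = sum(t * t for t in target) + eps
    scale = dot / target_energy
    projection = [scale * t for t in target]
    # Everything not explained by the target counts as distortion.
    noise = [p - s for p, s in zip(prediction, projection)]
    signal_energy = sum(s * s for s in projection)
    noise_energy = sum(n * n for n in noise) + eps
    return 10 * math.log10(signal_energy / noise_energy + eps)


target = [1.0, 0.0]
print(si_sdr([1.0, 1.0], target))  # equal signal and distortion energy, about 0 dB
print(si_sdr([1.0, 0.1], target))  # less distortion, about 20 dB
print(si_sdr([2.0, 2.0], target))  # rescaling the prediction leaves the score unchanged
```

Because the prediction is projected onto the target before comparing energies, simply making the output louder or quieter does not change the score, which is why this family of metrics is popular for enhancement models.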

In [None]:
from mayavoz.models import Demucs
model = Demucs(
        sampling_rate=16000,
        dataset=dataset,
        loss=["mae"],
        metrics=["stoi","pesq"])


<div id="callbacks"></div>

### Learning rate schedulers and callbacks
Here I am using `ReduceLROnPlateau`.

In [None]:
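from types import MethodType

from hydra.utils import instantiate  ## assumption: `instantiate` is Hydra's helper
from torch.optim.lr_scheduler import ReduceLROnPlateau

## NOTE: `config` and `parameters` are not defined in this notebook. They are
## assumed to be your training config (with an `optimizer` entry) and a
## dict-like object of hyperparameters; define them before training.
## `direction` is defined in the callbacks cell below; it is only looked up
## when training starts.

def configure_optimizers(self):
    optimizer = instantiate(
        config.optimizer,
        lr=parameters.get("lr"),
        params=self.parameters(),
    )
    scheduler = ReduceLROnPlateau(
        optimizer=optimizer,
        mode=direction,
        factor=parameters.get("ReduceLr_factor", 0.1),
        verbose=True,
        min_lr=parameters.get("min_lr", 1e-6),
        patience=parameters.get("ReduceLr_patience", 3),
    )
    return {
        "optimizer": optimizer,
        "lr_scheduler": scheduler,
        "monitor": f'valid_{parameters.get("ReduceLr_monitor", "loss")}',
    }


## bind the function to this model instance so Lightning calls it
model.configure_optimizers = MethodType(configure_optimizers, model)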

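To build intuition for the `factor`, `patience`, and `min_lr` arguments of `ReduceLROnPlateau`, here is a simplified pure-Python sketch of the plateau rule. It is not PyTorch's implementation: the class `PlateauRule` is hypothetical, and it ignores the improvement threshold and cooldown options, treating any strict improvement of the monitored value as progress.

```python
class PlateauRule:
    """Simplified reduce-on-plateau rule (illustration only, not PyTorch code)."""

    def __init__(self, lr, mode="min", factor=0.1, patience=3, min_lr=1e-6):
        self.lr = lr
        self.mode = mode          # "min" for losses, "max" for scores
        self.factor = factor      # multiplier applied when a plateau is detected
        self.patience = patience  # number of non-improving epochs to tolerate
        self.min_lr = min_lr      # lower bound for the learning rate
        self.best = None
        self.bad_epochs = 0

    def step(self, metric):
        """Record one epoch's monitored value and return the learning rate."""
        improved = (
            self.best is None
            or (self.mode == "min" and metric < self.best)
            or (self.mode == "max" and metric > self.best)
        )
        if improved:
            self.best = metric
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        # Reduce only once the number of bad epochs exceeds the patience.
        if self.bad_epochs > self.patience:
            self.lr = max(self.lr * self.factor, self.min_lr)
            self.bad_epochs = 0
        return self.lr


rule = PlateauRule(lr=1e-3, mode="min", factor=0.1, patience=3)
for epoch, valid_loss in enumerate([1.0, 0.9, 0.9, 0.9, 0.9, 0.9]):
    print(epoch, rule.step(valid_loss))  # the rate drops on the last epoch
```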
You can use any number of callbacks and pass them directly to the PyTorch Lightning trainer. Here I am using only `ModelCheckpoint`.

In [None]:
from pytorch_lightning.callbacks import ModelCheckpoint

callbacks = []
direction = model.valid_monitor  ## min or max
checkpoint = ModelCheckpoint(
    dirpath="./model",
    filename="model_filename",
    monitor="valid_loss",
    verbose=False,
    mode=direction,
    every_n_epochs=1,
)
callbacks.append(checkpoint)

<div id="train"></div>


### Train

In [None]:
import pytorch_lightning as pl
trainer = pl.Trainer(max_epochs=1, callbacks=callbacks, accelerator="gpu")
trainer.fit(model)


- Test your model against the test dataset

In [None]:
trainer.test(model)

**Hurray! You have trained and tested your speech enhancement model.**
