Compare commits


No commits in common. "main" and "dev-reformat" have entirely different histories.

105 changed files with 487 additions and 3995 deletions


@@ -1,5 +1,5 @@
 [flake8]
-per-file-ignores = "mayavoz/model/__init__.py:F401"
+per-file-ignores = __init__.py:F401
 ignore = E203, E266, E501, W503
 # line length is intentionally set to 80 here because black uses Bugbear
 # See https://github.com/psf/black/blob/master/README.md#line-length for more details

.gitattributes vendored

@ -1 +0,0 @@
notebooks/** linguist-vendored


@ -1,51 +0,0 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
name: mayavoz
on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
jobs:
build:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.8]
steps:
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v1.1.1
env :
ACTIONS_ALLOW_UNSECURE_COMMANDS : true
with:
python-version: ${{ matrix.python-version }}
- name: Cache pip
uses: actions/cache@v1
with:
path: ~/.cache/pip # This path is specific to Ubuntu
# Look to see if there is a cache hit for the corresponding requirements file
key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }}
restore-keys: |
${{ runner.os }}-pip-
${{ runner.os }}-
# You can test your matrix by printing the current Python version
- name: Display Python version
run: python -c "import sys; print(sys.version)"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
sudo apt-get install libsndfile1
pip install -r requirements.txt
pip install black pytest-cov
- name: Install mayavoz
run: |
pip install -e .[dev,testing]
- name: Run black
run:
black --check . --exclude mayavoz/version.py
- name: Test with pytest
run:
pytest tests --cov=mayavoz/

.gitignore vendored

@@ -1,10 +1,4 @@
 #local
-cleaned_my_voice.wav
-lightning_logs/
-my_voice.wav
-pretrained/
-*.ckpt
-*_local.yaml
 cli/train_config/dataset/Vctk_local.yaml
 .DS_Store
 outputs/


@@ -23,7 +23,6 @@ repos:
 hooks:
 - id: flake8
 args: ['--ignore=E203,E501,F811,E712,W503']
-exclude: __init__.py
 # Formatting, Whitespace, etc
 - repo: https://github.com/pre-commit/pre-commit-hooks


@ -1,46 +0,0 @@
# Contributing
Hi there 👋
If you're reading this I hope that you're looking forward to adding value to Mayavoz. This document will help you to get started with your journey.
## How to get your code in Mayavoz
1. We use git and GitHub.
2. Fork the mayavoz repository (https://github.com/shahules786/mayavoz) on GitHub under your own account. (This creates a copy of mayavoz under your account, and GitHub knows where it came from, and we typically call this “upstream”.)
3. Clone your own mayavoz repository. git clone https://github.com/ <your-account> /mayavoz (This downloads the git repository to your machine, git knows where it came from, and calls it “origin”.)
4. Create a branch for each specific feature you are developing. git checkout -b your-branch-name
5. Make + commit changes. git add files-you-changed ... git commit -m "Short message about what you did"
6. Push the branch to your GitHub repository. git push origin your-branch-name
7. Navigate to GitHub, and create a pull request from your branch to the upstream repository mayavoz/mayavoz, to the “develop” branch.
8. The Pull Request (PR) appears on the upstream repository. Discuss your contribution there. If you push more changes to your branch on GitHub (on your repository), they are added to the PR.
9. When the reviewer is satisfied that the code improves repository quality, they can merge.
Note that CI tests will be run when you create a PR. If you want to be sure that your code will not fail these tests, we have set up pre-commit hooks that you can install.
**If you're worried about things not being perfect with your code, we will work together and make it perfect. So, make your move!**
## Formatting
We use [black](https://black.readthedocs.io/en/stable/) and [flake8](https://flake8.pycqa.org/en/latest/) for code formatting. Please run both before submitting the PR.
## Testing
We adopt unit testing using [pytest](https://docs.pytest.org/en/latest/contents.html).
Please make sure that adding your new component does not decrease test coverage.
## Other tools
The use of [pre-commit](https://pre-commit.com/) is recommended to enforce requirements such as code formatting.
## How to start contributing to Mayavoz?
1. Check out issues marked as `good first issue`, and let us know you're interested in working on one by commenting under it.
2. For others, I would suggest you explore mayavoz. One way to do this is to use it to train your own model. This way you might end up finding a new unreported bug or getting an idea to improve Mayavoz.

LICENSE

@ -1,20 +0,0 @@
MIT License
Copyright (c) 2022 Shahul Es
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


@ -1,4 +0,0 @@
recursive-include mayavoz *.py
recursive-include mayavoz *.yaml
global-exclude *.pyc
global-exclude __pycache__


@ -1,78 +1,6 @@
<p align="center"> # enhancer
<img src="https://user-images.githubusercontent.com/25312635/195514652-e4526cd1-1177-48e9-a80d-c8bfdb95d35f.png" /> Enhancer is a Pytorch-based opensource toolkit for speech enhancement. It is designed to save time for audio researchers. It provides easy to use pretrained audio enhancement models and facilitates highly customisable model training. Enhancer provides
</p>
![GitHub Workflow Status](https://img.shields.io/github/actions/workflow/status/shahules786/mayavoz/ci.yaml?branch=main) * Various pretrained models nicely integrated with huggingface that users can select and use without any hassle.
![GitHub](https://img.shields.io/github/license/shahules786/enhancer) * Ability to train and validate your own custom speech enhancement models with just under 10 lines of code!
![GitHub issues](https://img.shields.io/github/issues/shahules786/enhancer?logo=GitHub) * A command line tool that facilitates training of highly customisable speech enhancement models from the terminal itself!
![GitHub Repo stars](https://img.shields.io/github/stars/shahules786/enhancer?style=social)
mayavoz is a Pytorch-based opensource toolkit for speech enhancement. It is designed to save time for audio practitioners & researchers. It provides easy to use pretrained speech enhancement models and facilitates highly customisable model training.
| **[Quick Start](#quick-start-fire)** | **[Installation](#installation)** | **[Tutorials](https://github.com/shahules786/enhancer/tree/main/notebooks)** | **[Available Recipes](#recipes)** | **[Demo](#demo)**
## Key features :key:
* Various pretrained models nicely integrated with [huggingface hub](https://huggingface.co/docs/hub/index) :hugs: that users can select and use without any hassle.
* :package: Ability to train and validate your own custom speech enhancement models with just under 10 lines of code!
* :magic_wand: A command line tool that facilitates training of highly customisable speech enhancement models from the terminal itself!
* :zap: Supports multi-gpu training integrated with [Pytorch Lightning](https://pytorchlightning.ai/).
* :shield: data augmentations integrated using [torch-augmentations](https://github.com/asteroid-team/torch-audiomentations)
## Demo
Noisy speech followed by enhanced version.
https://user-images.githubusercontent.com/25312635/203756185-737557f4-6e21-4146-aa2c-95da69d0de4c.mp4
## Quick Start :fire:
``` python
from mayavoz.models import Mayamodel
model = Mayamodel.from_pretrained("shahules786/mayavoz-waveunet-valentini-28spk")
model.enhance("noisy_audio.wav")
```
## Recipes
| Model | Dataset | STOI | PESQ | URL |
| :---: | :---: | :---: | :---: | :---: |
| WaveUnet | Valentini-28spk | 0.836 | 2.78 | shahules786/mayavoz-waveunet-valentini-28spk |
| Demucs | Valentini-28spk | 0.961 | 2.56 | shahules786/mayavoz-demucs-valentini-28spk |
| DCCRN | Valentini-28spk | 0.724 | 2.55 | shahules786/mayavoz-dccrn-valentini-28spk |
| Demucs | MS-SNSD-20hrs | 0.56 | 1.26 | shahules786/mayavoz-demucs-ms-snsd-20 |
Test scores are based on respective test set associated with train dataset.
**See [tutorials](/notebooks/) to train your custom model**
## Installation
Only Python 3.8+ is officially supported (though it might work with Python 3.7)
- With Pypi
```
pip install mayavoz
```
- With conda
```
conda env create -f environment.yml
conda activate mayavoz
```
- From source code
```
git clone url
cd mayavoz
pip install -e .
```
## Support
For commercial enquiries and scientific consulting, please [contact me](https://shahules786.github.io/).
### Acknowledgements
Sincere gratitude to [AMPLYFI](https://amplyfi.com/) for supporting this project.


@ -3,17 +3,11 @@ from types import MethodType
import hydra import hydra
from hydra.utils import instantiate from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf from omegaconf import DictConfig
from pytorch_lightning.callbacks import ( from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
EarlyStopping,
LearningRateMonitor,
ModelCheckpoint,
)
from pytorch_lightning.loggers import MLFlowLogger from pytorch_lightning.loggers import MLFlowLogger
from torch.optim.lr_scheduler import ReduceLROnPlateau from torch.optim.lr_scheduler import ReduceLROnPlateau
# from torch_audiomentations import Compose, Shift
os.environ["HYDRA_FULL_ERROR"] = "1" os.environ["HYDRA_FULL_ERROR"] = "1"
JOB_ID = os.environ.get("SLURM_JOBID", "0") JOB_ID = os.environ.get("SLURM_JOBID", "0")
@ -21,8 +15,6 @@ JOB_ID = os.environ.get("SLURM_JOBID", "0")
@hydra.main(config_path="train_config", config_name="config") @hydra.main(config_path="train_config", config_name="config")
def main(config: DictConfig): def main(config: DictConfig):
OmegaConf.save(config, "config_log.yaml")
callbacks = [] callbacks = []
logger = MLFlowLogger( logger = MLFlowLogger(
experiment_name=config.mlflow.experiment_name, experiment_name=config.mlflow.experiment_name,
@ -31,13 +23,8 @@ def main(config: DictConfig):
) )
parameters = config.hyperparameters parameters = config.hyperparameters
# apply_augmentations = Compose(
# [
# Shift(min_shift=0.5, max_shift=1.0, shift_unit="seconds", p=0.5),
# ]
# )
dataset = instantiate(config.dataset, augmentations=None) dataset = instantiate(config.dataset)
model = instantiate( model = instantiate(
config.model, config.model,
dataset=dataset, dataset=dataset,
@ -50,30 +37,27 @@ def main(config: DictConfig):
checkpoint = ModelCheckpoint( checkpoint = ModelCheckpoint(
dirpath="./model", dirpath="./model",
filename=f"model_{JOB_ID}", filename=f"model_{JOB_ID}",
monitor="valid_loss", monitor="val_loss",
verbose=False, verbose=True,
mode=direction, mode=direction,
every_n_epochs=1, every_n_epochs=1,
) )
callbacks.append(checkpoint) callbacks.append(checkpoint)
callbacks.append(LearningRateMonitor(logging_interval="epoch")) early_stopping = EarlyStopping(
monitor="val_loss",
mode=direction,
min_delta=0.0,
patience=parameters.get("EarlyStopping_patience", 10),
strict=True,
verbose=False,
)
callbacks.append(early_stopping)
if parameters.get("Early_stop", False): def configure_optimizer(self):
early_stopping = EarlyStopping(
monitor="val_loss",
mode=direction,
min_delta=0.0,
patience=parameters.get("EarlyStopping_patience", 10),
strict=True,
verbose=False,
)
callbacks.append(early_stopping)
def configure_optimizers(self):
optimizer = instantiate( optimizer = instantiate(
config.optimizer, config.optimizer,
lr=parameters.get("lr"), lr=parameters.get("lr"),
params=self.parameters(), parameters=self.parameters(),
) )
scheduler = ReduceLROnPlateau( scheduler = ReduceLROnPlateau(
optimizer=optimizer, optimizer=optimizer,
@ -83,37 +67,18 @@ def main(config: DictConfig):
min_lr=parameters.get("min_lr", 1e-6), min_lr=parameters.get("min_lr", 1e-6),
patience=parameters.get("ReduceLr_patience", 3), patience=parameters.get("ReduceLr_patience", 3),
) )
return { return {"optimizer": optimizer, "lr_scheduler": scheduler}
"optimizer": optimizer,
"lr_scheduler": scheduler,
"monitor": f'valid_{parameters.get("ReduceLr_monitor", "loss")}',
}
model.configure_optimizers = MethodType(configure_optimizers, model) model.configure_parameters = MethodType(configure_optimizer, model)
trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks) trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks)
trainer.fit(model) trainer.fit(model)
trainer.test(model)
logger.experiment.log_artifact(
logger.run_id, f"{trainer.default_root_dir}/config_log.yaml"
)
saved_location = os.path.join( saved_location = os.path.join(
trainer.default_root_dir, "model", f"model_{JOB_ID}.ckpt" trainer.default_root_dir, "model", f"model_{JOB_ID}.ckpt"
) )
if os.path.isfile(saved_location): if os.path.isfile(saved_location):
logger.experiment.log_artifact(logger.run_id, saved_location) logger.experiment.log_artifact(logger.run_id, saved_location)
logger.experiment.log_param(
logger.run_id,
"num_train_steps_per_epoch",
dataset.train__len__() / dataset.batch_size,
)
logger.experiment.log_param(
logger.run_id,
"num_valid_steps_per_epoch",
dataset.val__len__() / dataset.batch_size,
)
if __name__ == "__main__": if __name__ == "__main__":
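The entrypoint above swaps in a custom `configure_optimizers` on the model instance at runtime via `types.MethodType`. Below is a minimal, self-contained sketch of that pattern, using only a toy `nn.Module` and illustrative hyperparameters rather than the repository's classes:

```python
# Sketch of the MethodType pattern used above: bind a custom optimizer setup
# onto an existing model instance. ToyModel and the hyperparameters are
# illustrative placeholders, not repository code.
from types import MethodType

import torch
from torch import nn
from torch.optim.lr_scheduler import ReduceLROnPlateau


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(8, 1)


def configure_optimizers(self):
    # `self` is the instance the function gets bound to below.
    optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
    scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.1, patience=3)
    return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "valid_loss"}


model = ToyModel()
# Bind the function as a method of this particular instance, as the CLI does
# before handing the model to the Lightning Trainer.
model.configure_optimizers = MethodType(configure_optimizers, model)
print(model.configure_optimizers())
```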


@ -0,0 +1,12 @@
_target_: enhancer.data.dataset.EnhancerDataset
root_dir : /Users/shahules/Myprojects/enhancer/datasets/vctk_test
name : dns-2020
duration : 1.0
sampling_rate: 16000
batch_size: 32
files:
root_dir : /Users/shahules/Myprojects/enhancer/datasets/vctk_test
train_clean : clean_test_wav
test_clean : clean_test_wav
train_noisy : clean_test_wav
test_noisy : clean_test_wav
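For context, configs like the one above are consumed through Hydra's `instantiate`, which imports the class named in `_target_` and passes the remaining keys as constructor arguments. A minimal sketch under that assumption, using a stand-in dataclass and a placeholder path instead of the real dataset class:

```python
# Hedged sketch of how a `_target_` config becomes an object via Hydra.
# DummyDataset and the path below are placeholders, not repository code.
from dataclasses import dataclass

from hydra.utils import instantiate
from omegaconf import OmegaConf


@dataclass
class DummyDataset:  # stand-in for enhancer.data.dataset.EnhancerDataset
    name: str
    root_dir: str
    duration: float
    sampling_rate: int
    batch_size: int


cfg = OmegaConf.create(
    {
        "_target_": "__main__.DummyDataset",
        "name": "dns-2020",
        "root_dir": "/path/to/dataset",  # placeholder path
        "duration": 1.0,
        "sampling_rate": 16000,
        "batch_size": 32,
    }
)
dataset = instantiate(cfg)  # builds DummyDataset(name="dns-2020", ...)
print(dataset)
```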


@@ -1,11 +1,10 @@
-_target_: mayavoz.data.dataset.MayaDataset
+_target_: enhancer.data.dataset.EnhancerDataset
 name : vctk
 root_dir : /scratch/c.sistc3/DS_10283_2791
-duration : 2
-stride : 1
+duration : 1.0
 sampling_rate: 16000
-batch_size: 128
-valid_minutes : 25
+batch_size: 64
 files:
 train_clean : clean_trainset_28spk_wav
 test_clean : clean_testset_wav

@ -0,0 +1,13 @@
_target_: enhancer.data.dataset.EnhancerDataset
name : vctk
root_dir : /Users/shahules/Myprojects/enhancer/datasets/vctk
duration : 1.0
sampling_rate: 16000
batch_size: 64
num_workers : 0
files:
train_clean : clean_testset_wav
test_clean : clean_testset_wav
train_noisy : noisy_testset_wav
test_noisy : noisy_testset_wav


@ -0,0 +1,7 @@
loss : mse
metric : mae
lr : 0.0001
ReduceLr_patience : 5
ReduceLr_factor : 0.1
min_lr : 0.000001
EarlyStopping_factor : 10


@ -0,0 +1,2 @@
experiment_name : shahules/enhancer
run_name : baseline


@@ -1,13 +1,13 @@
-_target_: mayavoz.models.demucs.Demucs
+_target_: enhancer.models.demucs.Demucs
 num_channels: 1
-resample: 4
+resample: 2
 sampling_rate : 16000
 encoder_decoder:
-depth: 4
-initial_output_channels: 64
+depth: 5
+initial_output_channels: 32
 kernel_size: 8
-stride: 4
+stride: 1
 growth_factor: 2
 glu: True


@@ -1,5 +1,5 @@
-_target_: mayavoz.models.waveunet.WaveUnet
+_target_: enhancer.models.waveunet.WaveUnet
 num_channels : 1
-depth : 9
+depth : 12
 initial_output_channels: 24
 sampling_rate : 16000


@@ -1,5 +1,5 @@
 _target_: pytorch_lightning.Trainer
-accelerator: gpu
+accelerator: auto
 accumulate_grad_batches: 1
 amp_backend: native
 auto_lr_find: True
@@ -9,7 +9,7 @@ benchmark: False
 check_val_every_n_epoch: 1
 detect_anomaly: False
 deterministic: False
-devices: 1
+devices: -1
 enable_checkpointing: True
 enable_model_summary: True
 enable_progress_bar: True
@@ -22,9 +22,9 @@ limit_predict_batches: 1.0
 limit_test_batches: 1.0
 limit_train_batches: 1.0
 limit_val_batches: 1.0
-log_every_n_steps: 50
-max_epochs: 200
-max_steps: -1
+log_every_n_steps: 1
+max_epochs: 10
+max_steps: null
 max_time: null
 min_epochs: 1
 min_steps: null


@ -0,0 +1 @@
from enhancer.data.dataset import EnhancerDataset

enhancer/data/dataset.py Normal file

@ -0,0 +1,220 @@
import math
import multiprocessing
import os
from typing import Optional
import pytorch_lightning as pl
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, IterableDataset
from enhancer.data.fileprocessor import Fileprocessor
from enhancer.utils import check_files
from enhancer.utils.config import Files
from enhancer.utils.io import Audio
from enhancer.utils.random import create_unique_rng
class TrainDataset(IterableDataset):
def __init__(self, dataset):
self.dataset = dataset
def __iter__(self):
return self.dataset.train__iter__()
def __len__(self):
return self.dataset.train__len__()
class ValidDataset(Dataset):
def __init__(self, dataset):
self.dataset = dataset
def __getitem__(self, idx):
return self.dataset.val__getitem__(idx)
def __len__(self):
return self.dataset.val__len__()
class TaskDataset(pl.LightningDataModule):
def __init__(
self,
name: str,
root_dir: str,
files: Files,
duration: float = 1.0,
sampling_rate: int = 48000,
matching_function=None,
batch_size=32,
num_workers: Optional[int] = None,
):
super().__init__()
self.name = name
self.files, self.root_dir = check_files(root_dir, files)
self.duration = duration
self.sampling_rate = sampling_rate
self.batch_size = batch_size
self.matching_function = matching_function
self._validation = []
if num_workers is None:
num_workers = multiprocessing.cpu_count() // 2
self.num_workers = num_workers
def setup(self, stage: Optional[str] = None):
if stage in ("fit", None):
train_clean = os.path.join(self.root_dir, self.files.train_clean)
train_noisy = os.path.join(self.root_dir, self.files.train_noisy)
fp = Fileprocessor.from_name(
self.name, train_clean, train_noisy, self.matching_function
)
self.train_data = fp.prepare_matching_dict()
val_clean = os.path.join(self.root_dir, self.files.test_clean)
val_noisy = os.path.join(self.root_dir, self.files.test_noisy)
fp = Fileprocessor.from_name(
self.name, val_clean, val_noisy, self.matching_function
)
val_data = fp.prepare_matching_dict()
for item in val_data:
clean, noisy, total_dur = item.values()
if total_dur < self.duration:
continue
num_segments = round(total_dur / self.duration)
for index in range(num_segments):
start_time = index * self.duration
self._validation.append(
({"clean": clean, "noisy": noisy}, start_time)
)
def train_dataloader(self):
return DataLoader(
TrainDataset(self),
batch_size=self.batch_size,
num_workers=self.num_workers,
)
def val_dataloader(self):
return DataLoader(
ValidDataset(self),
batch_size=self.batch_size,
num_workers=self.num_workers,
)
class EnhancerDataset(TaskDataset):
"""
Dataset object for creating clean-noisy speech enhancement datasets
parameters:
name : str
name of the dataset
root_dir : str
root directory of the dataset containing clean/noisy folders
files : Files
dataclass containing train_clean, train_noisy, test_clean, test_noisy
folder names (refer enhancer.utils.Files dataclass)
duration : float
expected audio duration of single audio sample for training
sampling_rate : int
desired sampling rate
batch_size : int
batch size of each batch
num_workers : int
num workers to be used while training
matching_function : str
matching functions - (one_to_one, one_to_many). Default set to None.
use one_to_one mapping for datasets with one noisy file for each clean file
use one_to_many mapping for multiple noisy files for each clean file
"""
def __init__(
self,
name: str,
root_dir: str,
files: Files,
duration=1.0,
sampling_rate=48000,
matching_function=None,
batch_size=32,
num_workers: Optional[int] = None,
):
super().__init__(
name=name,
root_dir=root_dir,
files=files,
sampling_rate=sampling_rate,
duration=duration,
matching_function=matching_function,
batch_size=batch_size,
num_workers=num_workers,
)
self.sampling_rate = sampling_rate
self.files = files
self.duration = max(1.0, duration)
self.audio = Audio(self.sampling_rate, mono=True, return_tensor=True)
def setup(self, stage: Optional[str] = None):
super().setup(stage=stage)
def train__iter__(self):
rng = create_unique_rng(self.model.current_epoch)
while True:
file_dict, *_ = rng.choices(
self.train_data,
k=1,
weights=[file["duration"] for file in self.train_data],
)
file_duration = file_dict["duration"]
start_time = round(rng.uniform(0, file_duration - self.duration), 2)
data = self.prepare_segment(file_dict, start_time)
yield data
def val__getitem__(self, idx):
return self.prepare_segment(*self._validation[idx])
def prepare_segment(self, file_dict: dict, start_time: float):
clean_segment = self.audio(
file_dict["clean"], offset=start_time, duration=self.duration
)
noisy_segment = self.audio(
file_dict["noisy"], offset=start_time, duration=self.duration
)
clean_segment = F.pad(
clean_segment,
(
0,
int(
self.duration * self.sampling_rate - clean_segment.shape[-1]
),
),
)
noisy_segment = F.pad(
noisy_segment,
(
0,
int(
self.duration * self.sampling_rate - noisy_segment.shape[-1]
),
),
)
return {"clean": clean_segment, "noisy": noisy_segment}
def train__len__(self):
return math.ceil(
sum([file["duration"] for file in self.train_data]) / self.duration
)
def val__len__(self):
return len(self._validation)
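A possible usage sketch for the dataset defined above, assuming `Files` is the dataclass of folder names described in the docstring; the root directory and folder names are placeholders:

```python
# Hedged usage sketch: construct the dataset and pull one validation batch.
# Paths and folder names are placeholders; adapt them to your data layout.
from enhancer.data.dataset import EnhancerDataset
from enhancer.utils.config import Files

files = Files(
    train_clean="clean_trainset_28spk_wav",
    train_noisy="noisy_trainset_28spk_wav",
    test_clean="clean_testset_wav",
    test_noisy="noisy_testset_wav",
)
dataset = EnhancerDataset(
    name="vctk",                        # picks the one_to_one matching function
    root_dir="/path/to/DS_10283_2791",  # placeholder root directory
    files=files,
    duration=1.0,
    sampling_rate=16000,
    batch_size=32,
)
dataset.setup("fit")
batch = next(iter(dataset.val_dataloader()))  # dict with "clean" and "noisy" tensors
```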


@ -55,31 +55,32 @@ class ProcessorFunctions:
One clean audio have multiple noisy audio files One clean audio have multiple noisy audio files
""" """
matching_wavfiles = list() matching_wavfiles = dict()
clean_filenames = [ clean_filenames = [
file.split("/")[-1] file.split("/")[-1]
for file in glob.glob(os.path.join(clean_path, "*.wav")) for file in glob.glob(os.path.join(clean_path, "*.wav"))
] ]
for clean_file in clean_filenames: for clean_file in clean_filenames:
noisy_filenames = glob.glob( noisy_filenames = glob.glob(
os.path.join(noisy_path, f"*_{clean_file}") os.path.join(noisy_path, f"*_{clean_file}.wav")
) )
for noisy_file in noisy_filenames: for noisy_file in noisy_filenames:
sr_clean, clean_wav = wavfile.read( sr_clean, clean_file = wavfile.read(
os.path.join(clean_path, clean_file) os.path.join(clean_path, clean_file)
) )
sr_noisy, noisy_wav = wavfile.read(noisy_file) sr_noisy, noisy_file = wavfile.read(noisy_file)
if (clean_wav.shape[-1] == noisy_wav.shape[-1]) and ( if (clean_file.shape[-1] == noisy_file.shape[-1]) and (
sr_clean == sr_noisy sr_clean == sr_noisy
): ):
matching_wavfiles.append( matching_wavfiles.update(
{ {
"clean": os.path.join(clean_path, clean_file), "clean": os.path.join(clean_path, clean_file),
"noisy": noisy_file, "noisy": noisy_file,
"duration": clean_wav.shape[-1] / sr_clean, "duration": clean_file.shape[-1] / sr_clean,
} }
) )
return matching_wavfiles return matching_wavfiles
@ -93,14 +94,10 @@ class Fileprocessor:
def from_name(cls, name: str, clean_dir, noisy_dir, matching_function=None): def from_name(cls, name: str, clean_dir, noisy_dir, matching_function=None):
if matching_function is None: if matching_function is None:
if name.lower() in ("vctk", "valentini"): if name.lower() == "vctk":
return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_one) return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_one)
elif name.lower() == "ms-snsd": elif name.lower() == "dns-2020":
return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_many) return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_many)
else:
raise ValueError(
f"Invalid matching function, Please use valid matching function from {MATCHING_FNS}"
)
else: else:
if matching_function not in MATCHING_FNS: if matching_function not in MATCHING_FNS:
raise ValueError( raise ValueError(
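A small usage sketch of the matcher above; the directories are placeholders, and per the logic shown, a dataset name such as "vctk" selects one-to-one matching while a matching function name can also be passed explicitly:

```python
# Hedged sketch: build clean/noisy pairs for a VCTK-style layout.
# The directories are placeholders, not paths from the repository.
from enhancer.data.fileprocessor import Fileprocessor

fp = Fileprocessor.from_name(
    name="vctk",
    clean_dir="/path/to/clean_trainset_28spk_wav",
    noisy_dir="/path/to/noisy_trainset_28spk_wav",
)
pairs = fp.prepare_matching_dict()  # entries hold clean path, noisy path and duration
```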


@ -8,7 +8,7 @@ from librosa import load as load_audio
from scipy.io import wavfile from scipy.io import wavfile
from scipy.signal import get_window from scipy.signal import get_window
from mayavoz.utils import Audio from enhancer.utils import Audio
class Inference: class Inference:
@ -91,11 +91,10 @@ class Inference:
window_size: int, window_size: int,
total_frames: int, total_frames: int,
step_size: Optional[int] = None, step_size: Optional[int] = None,
window="hamming", window="hanning",
): ):
""" """
stitch batched waveform into single waveform. (Overlap-add) stitch batched waveform into single waveform. (Overlap-add)
inspired from https://github.com/asteroid-team/asteroid
arguments: arguments:
data: batched waveform data: batched waveform
window_size : window_size used to batch waveform window_size : window_size used to batch waveform
@ -140,9 +139,7 @@ class Inference:
if filename.is_file(): if filename.is_file():
raise FileExistsError(f"file {filename} already exists") raise FileExistsError(f"file {filename} already exists")
else: else:
wavfile.write( wavfile.write(filename, rate=sr, data=waveform.detach().cpu())
filename, rate=sr, data=waveform.detach().cpu().numpy()
)
@staticmethod @staticmethod
def prepare_output( def prepare_output(
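The stitch method above reassembles windowed chunks into a single waveform by overlap-add. Below is a minimal, self-contained sketch of the general technique (not the repository's implementation), using a Hamming window and illustrative sizes:

```python
# Generic overlap-add sketch: windowed frames taken every `step_size` samples
# are summed back into one waveform and normalised by the accumulated window.
import torch


def overlap_add(frames: torch.Tensor, window_size: int, step_size: int, total_len: int) -> torch.Tensor:
    window = torch.hamming_window(window_size)
    output = torch.zeros(total_len)
    norm = torch.zeros(total_len)
    for i, frame in enumerate(frames):
        start = i * step_size
        end = min(start + window_size, total_len)
        n = end - start
        output[start:end] += frame[:n] * window[:n]
        norm[start:end] += window[:n]
    return output / norm.clamp(min=1e-8)


# Round-trip example: frame a waveform with a hop of 100 samples and rebuild it.
waveform = torch.randn(16000)
frames = waveform.unfold(0, 400, 100)  # (num_frames, window_size)
restored = overlap_add(frames, window_size=400, step_size=100, total_len=16000)
```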


@ -1,11 +1,5 @@
import warnings
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from torchmetrics import ScaleInvariantSignalNoiseRatio
from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility
class mean_squared_error(nn.Module): class mean_squared_error(nn.Module):
@ -18,7 +12,6 @@ class mean_squared_error(nn.Module):
self.loss_fun = nn.MSELoss(reduction=reduction) self.loss_fun = nn.MSELoss(reduction=reduction)
self.higher_better = False self.higher_better = False
self.name = "mse"
def forward(self, prediction: torch.Tensor, target: torch.Tensor): def forward(self, prediction: torch.Tensor, target: torch.Tensor):
@ -41,7 +34,6 @@ class mean_absolute_error(nn.Module):
self.loss_fun = nn.L1Loss(reduction=reduction) self.loss_fun = nn.L1Loss(reduction=reduction)
self.higher_better = False self.higher_better = False
self.name = "mae"
def forward(self, prediction: torch.Tensor, target: torch.Tensor): def forward(self, prediction: torch.Tensor, target: torch.Tensor):
@ -54,22 +46,22 @@ class mean_absolute_error(nn.Module):
return self.loss_fun(prediction, target) return self.loss_fun(prediction, target)
class Si_SDR: class Si_SDR(nn.Module):
""" """
SI-SDR metric based on SDR HALF-BAKED OR WELL DONE?(https://arxiv.org/pdf/1811.02508.pdf) SI-SDR metric based on SDR HALF-BAKED OR WELL DONE?(https://arxiv.org/pdf/1811.02508.pdf)
""" """
def __init__(self, reduction: str = "mean"): def __init__(self, reduction: str = "mean"):
super().__init__()
if reduction in ["sum", "mean", None]: if reduction in ["sum", "mean", None]:
self.reduction = reduction self.reduction = reduction
else: else:
raise TypeError( raise TypeError(
"Invalid reduction, valid options are sum, mean, None" "Invalid reduction, valid options are sum, mean, None"
) )
self.higher_better = True self.higher_better = False
self.name = "si-sdr"
def __call__(self, prediction: torch.Tensor, target: torch.Tensor): def forward(self, prediction: torch.Tensor, target: torch.Tensor):
if prediction.size() != target.size() or target.ndim < 3: if prediction.size() != target.size() or target.ndim < 3:
raise TypeError( raise TypeError(
@ -98,47 +90,7 @@ class Si_SDR:
return si_sdr return si_sdr
class Stoi: class Avergeloss(nn.Module):
"""
STOI (Short-Time Objective Intelligibility, see [2,3]), a wrapper for the pystoi package [1].
Note that input will be moved to cpu to perform the metric calculation.
parameters:
sr: int
sampling rate
"""
def __init__(self, sr: int):
self.sr = sr
self.stoi = ShortTimeObjectiveIntelligibility(fs=sr)
self.name = "stoi"
def __call__(self, prediction: torch.Tensor, target: torch.Tensor):
return self.stoi(prediction, target)
class Pesq:
def __init__(self, sr: int, mode="wb"):
self.sr = sr
self.name = "pesq"
self.mode = mode
self.pesq = PerceptualEvaluationSpeechQuality(
fs=self.sr, mode=self.mode
)
def __call__(self, prediction: torch.Tensor, target: torch.Tensor):
pesq_values = []
for pred, target_ in zip(prediction, target):
try:
pesq_values.append(self.pesq(pred.squeeze(), target_.squeeze()))
except Exception as e:
warnings.warn(f"{e} error occured while calculating PESQ")
return torch.tensor(np.mean(pesq_values))
class LossWrapper(nn.Module):
""" """
Combine multiple metics of same nature. Combine multiple metics of same nature.
for example, ["mea","mae"] for example, ["mea","mae"]
@ -160,11 +112,9 @@ class LossWrapper(nn.Module):
) )
self.higher_better = direction[0] self.higher_better = direction[0]
self.name = ""
for loss in losses: for loss in losses:
loss = self.validate_loss(loss) loss = self.validate_loss(loss)
self.valid_losses.append(loss()) self.valid_losses.append(loss())
self.name += f"{loss().name}_"
def validate_loss(self, loss: str): def validate_loss(self, loss: str):
if loss not in LOSS_MAP.keys(): if loss not in LOSS_MAP.keys():
@ -183,34 +133,8 @@ class LossWrapper(nn.Module):
return loss return loss
class Si_snr(nn.Module):
"""
SI-SNR
"""
def __init__(self, **kwargs):
super().__init__()
self.loss_fun = ScaleInvariantSignalNoiseRatio(**kwargs)
self.higher_better = False
self.name = "si_snr"
def forward(self, prediction: torch.Tensor, target: torch.Tensor):
if prediction.size() != target.size() or target.ndim < 3:
raise TypeError(
f"""Inputs must be of the same shape (batch_size,channels,samples)
got {prediction.size()} and {target.size()} instead"""
)
return -1 * self.loss_fun(prediction, target)
LOSS_MAP = { LOSS_MAP = {
"mae": mean_absolute_error, "mae": mean_absolute_error,
"mse": mean_squared_error, "mse": mean_squared_error,
"si-sdr": Si_SDR, "SI-SDR": Si_SDR,
"pesq": Pesq,
"stoi": Stoi,
"si-snr": Si_snr,
} }
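LossWrapper on one branch and Avergeloss on the other combine several criteria of the same direction into a single scalar. A simplified stand-in sketch of that idea (not the repository's class):

```python
# Hedged sketch: average several losses of the same direction into one criterion.
import torch
from torch import nn


class CombinedLoss(nn.Module):
    def __init__(self, losses):
        super().__init__()
        self.losses = nn.ModuleList(losses)

    def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # Averaging keeps the combined objective on the same scale as its parts.
        values = [loss(prediction, target) for loss in self.losses]
        return torch.stack(values).mean()


criterion = CombinedLoss([nn.MSELoss(), nn.L1Loss()])  # e.g. "mse" + "mae"
prediction = torch.randn(4, 1, 16000)
clean = torch.randn(4, 1, 16000)
print(criterion(prediction, clean))
```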


@ -0,0 +1,3 @@
from enhancer.models.demucs import Demucs
from enhancer.models.model import Model
from enhancer.models.waveunet import WaveUnet


@ -1,14 +1,14 @@
import logging
import math import math
import warnings
from typing import List, Optional, Union from typing import List, Optional, Union
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn from torch import nn
from mayavoz.data.dataset import MayaDataset from enhancer.data.dataset import EnhancerDataset
from mayavoz.models.model import Mayamodel from enhancer.models.model import Model
from mayavoz.utils.io import Audio as audio from enhancer.utils.io import Audio as audio
from mayavoz.utils.utils import merge_dict from enhancer.utils.utils import merge_dict
class DemucsLSTM(nn.Module): class DemucsLSTM(nn.Module):
@ -49,7 +49,7 @@ class DemucsEncoder(nn.Module):
self.encoder = nn.Sequential( self.encoder = nn.Sequential(
nn.Conv1d(num_channels, hidden_size, kernel_size, stride), nn.Conv1d(num_channels, hidden_size, kernel_size, stride),
nn.ReLU(), nn.ReLU(),
nn.Conv1d(hidden_size, hidden_size * multi_factor, 1, 1), nn.Conv1d(hidden_size, hidden_size * multi_factor, kernel_size, 1),
activation, activation,
) )
@ -72,7 +72,7 @@ class DemucsDecoder(nn.Module):
activation = nn.GLU(1) if glu else nn.ReLU() activation = nn.GLU(1) if glu else nn.ReLU()
multi_factor = 2 if glu else 1 multi_factor = 2 if glu else 1
self.decoder = nn.Sequential( self.decoder = nn.Sequential(
nn.Conv1d(hidden_size, hidden_size * multi_factor, 1, 1), nn.Conv1d(hidden_size, hidden_size * multi_factor, kernel_size, 1),
activation, activation,
nn.ConvTranspose1d(hidden_size, num_channels, kernel_size, stride), nn.ConvTranspose1d(hidden_size, num_channels, kernel_size, stride),
) )
@ -88,7 +88,7 @@ class DemucsDecoder(nn.Module):
return out return out
class Demucs(Mayamodel): class Demucs(Model):
""" """
Demucs model from https://arxiv.org/pdf/1911.13254.pdf Demucs model from https://arxiv.org/pdf/1911.13254.pdf
parameters: parameters:
@ -102,8 +102,8 @@ class Demucs(Mayamodel):
sampling rate of input audio sampling rate of input audio
lr : float, defaults to 1e-3 lr : float, defaults to 1e-3
learning rate used for training learning rate used for training
dataset: MayaDataset, optional dataset: EnhancerDataset, optional
MayaDataset object containing train/validation data for training EnhancerDataset object containing train/validation data for training
duration : float, optional duration : float, optional
chunk duration in seconds chunk duration in seconds
loss : string or List of strings loss : string or List of strings
@ -116,7 +116,7 @@ class Demucs(Mayamodel):
ED_DEFAULTS = { ED_DEFAULTS = {
"initial_output_channels": 48, "initial_output_channels": 48,
"kernel_size": 8, "kernel_size": 8,
"stride": 4, "stride": 1,
"depth": 5, "depth": 5,
"glu": True, "glu": True,
"growth_factor": 2, "growth_factor": 2,
@ -133,20 +133,17 @@ class Demucs(Mayamodel):
num_channels: int = 1, num_channels: int = 1,
resample: int = 4, resample: int = 4,
sampling_rate=16000, sampling_rate=16000,
normalize=True,
lr: float = 1e-3, lr: float = 1e-3,
dataset: Optional[MayaDataset] = None, dataset: Optional[EnhancerDataset] = None,
duration: Optional[float] = None,
loss: Union[str, List] = "mse", loss: Union[str, List] = "mse",
metric: Union[str, List] = "mse", metric: Union[str, List] = "mse",
floor=1e-3,
): ):
duration = ( duration = (
dataset.duration if isinstance(dataset, MayaDataset) else duration dataset.duration if isinstance(dataset, EnhancerDataset) else None
) )
if dataset is not None: if dataset is not None:
if sampling_rate != dataset.sampling_rate: if sampling_rate != dataset.sampling_rate:
warnings.warn( logging.warn(
f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}" f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
) )
sampling_rate = dataset.sampling_rate sampling_rate = dataset.sampling_rate
@ -164,8 +161,6 @@ class Demucs(Mayamodel):
lstm = merge_dict(self.LSTM_DEFAULTS, lstm) lstm = merge_dict(self.LSTM_DEFAULTS, lstm)
self.save_hyperparameters("encoder_decoder", "lstm", "resample") self.save_hyperparameters("encoder_decoder", "lstm", "resample")
hidden = encoder_decoder["initial_output_channels"] hidden = encoder_decoder["initial_output_channels"]
self.normalize = normalize
self.floor = floor
self.encoder = nn.ModuleList() self.encoder = nn.ModuleList()
self.decoder = nn.ModuleList() self.decoder = nn.ModuleList()
@ -184,7 +179,7 @@ class Demucs(Mayamodel):
num_channels=num_channels, num_channels=num_channels,
hidden_size=hidden, hidden_size=hidden,
kernel_size=encoder_decoder["kernel_size"], kernel_size=encoder_decoder["kernel_size"],
stride=encoder_decoder["stride"], stride=1,
glu=encoder_decoder["glu"], glu=encoder_decoder["glu"],
layer=layer, layer=layer,
) )
@ -205,16 +200,11 @@ class Demucs(Mayamodel):
if waveform.dim() == 2: if waveform.dim() == 2:
waveform = waveform.unsqueeze(1) waveform = waveform.unsqueeze(1)
if waveform.size(1) != self.hparams.num_channels: if waveform.size(1) != 1:
raise ValueError( raise TypeError(
f"Number of input channels initialized is {self.hparams.num_channels} but got {waveform.size(1)} channels" f"Demucs can only process mono channel audio, input has {waveform.size(1)} channels"
) )
if self.normalize:
waveform = waveform.mean(dim=1, keepdim=True)
std = waveform.std(dim=-1, keepdim=True)
waveform = waveform / (self.floor + std)
else:
std = 1
length = waveform.shape[-1] length = waveform.shape[-1]
x = F.pad(waveform, (0, self.get_padding_length(length) - length)) x = F.pad(waveform, (0, self.get_padding_length(length) - length))
if self.hparams.resample > 1: if self.hparams.resample > 1:
@ -236,7 +226,7 @@ class Demucs(Mayamodel):
x = x.permute(0, 2, 1) x = x.permute(0, 2, 1)
for decoder in self.decoder: for decoder in self.decoder:
skip_connection = encoder_outputs.pop(-1) skip_connection = encoder_outputs.pop(-1)
x = x + skip_connection[..., : x.shape[-1]] x += skip_connection[..., : x.shape[-1]]
x = decoder(x) x = decoder(x)
if self.hparams.resample > 1: if self.hparams.resample > 1:
@ -246,8 +236,7 @@ class Demucs(Mayamodel):
self.hparams.sampling_rate, self.hparams.sampling_rate,
) )
out = x[..., :length] return x
return std * out
def get_padding_length(self, input_length): def get_padding_length(self, input_length):


@ -1,8 +1,7 @@
import os import os
from collections import defaultdict
from importlib import import_module from importlib import import_module
from pathlib import Path from pathlib import Path
from typing import Any, List, Optional, Text, Union from typing import Any, Dict, List, Optional, Text, Union
from urllib.parse import urlparse from urllib.parse import urlparse
import numpy as np import numpy as np
@ -10,24 +9,19 @@ import pytorch_lightning as pl
import torch import torch
from huggingface_hub import cached_download, hf_hub_url from huggingface_hub import cached_download, hf_hub_url
from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.cloud_io import load as pl_load
from torch import nn
from torch.optim import Adam from torch.optim import Adam
from mayavoz.data.dataset import MayaDataset from enhancer import __version__
from mayavoz.inference import Inference from enhancer.data.dataset import EnhancerDataset
from mayavoz.loss import LOSS_MAP, LossWrapper from enhancer.inference import Inference
from mayavoz.version import __version__ from enhancer.loss import Avergeloss
CACHE_DIR = os.getenv( CACHE_DIR = ""
"ENHANCER_CACHE", HF_TORCH_WEIGHTS = ""
os.path.expanduser("~/.cache/torch/mayavoz"),
)
HF_TORCH_WEIGHTS = "pytorch_model.ckpt"
DEFAULT_DEVICE = "cpu" DEFAULT_DEVICE = "cpu"
SAVE_NAME = "mayavoz"
class Mayamodel(pl.LightningModule): class Model(pl.LightningModule):
""" """
Base class for all models Base class for all models
parameters: parameters:
@ -37,11 +31,11 @@ class Mayamodel(pl.LightningModule):
audio sampling rate audio sampling rate
lr: float, optional lr: float, optional
learning rate for model training learning rate for model training
dataset: MayaDataset, optional dataset: EnhancerDataset, optional
mayavoz dataset used for training/validation Enhancer dataset used for training/validation
duration: float, optional duration: float, optional
duration used for training/inference duration used for training/inference
loss : string or List of strings or custom loss (nn.Module), default to "mse" loss : string or List of strings, default to "mse"
loss functions to be used. Available ("mse","mae","Si-SDR") loss functions to be used. Available ("mse","mae","Si-SDR")
""" """
@ -51,13 +45,15 @@ class Mayamodel(pl.LightningModule):
num_channels: int = 1, num_channels: int = 1,
sampling_rate: int = 16000, sampling_rate: int = 16000,
lr: float = 1e-3, lr: float = 1e-3,
dataset: Optional[MayaDataset] = None, dataset: Optional[EnhancerDataset] = None,
duration: Optional[float] = None, duration: Optional[float] = None,
loss: Union[str, List] = "mse", loss: Union[str, List] = "mse",
metric: Union[str, List, Any] = "mse", metric: Union[str, List] = "mse",
): ):
super().__init__() super().__init__()
assert num_channels == 1, "mayavoz only support for mono channel models" assert (
num_channels == 1
), "Enhancer only support for mono channel models"
self.dataset = dataset self.dataset = dataset
self.save_hyperparameters( self.save_hyperparameters(
"num_channels", "sampling_rate", "lr", "loss", "metric", "duration" "num_channels", "sampling_rate", "lr", "loss", "metric", "duration"
@ -78,9 +74,9 @@ class Mayamodel(pl.LightningModule):
def loss(self, loss): def loss(self, loss):
if isinstance(loss, str): if isinstance(loss, str):
loss = [loss] losses = [loss]
self._loss = LossWrapper(loss) self._loss = Avergeloss(losses)
@property @property
def metric(self): def metric(self):
@ -88,26 +84,11 @@ class Mayamodel(pl.LightningModule):
@metric.setter @metric.setter
def metric(self, metric): def metric(self, metric):
self._metric = []
if isinstance(metric, (str, nn.Module)): if isinstance(metric, str):
metric = [metric] metric = [metric]
for func in metric: self._metric = Avergeloss(metric)
if isinstance(func, str):
if func in LOSS_MAP.keys():
if func in ("pesq", "stoi"):
self._metric.append(
LOSS_MAP[func](self.hparams.sampling_rate)
)
else:
self._metric.append(LOSS_MAP[func]())
else:
ValueError(f"Invalid metrics {func}")
elif isinstance(func, nn.Module):
self._metric.append(func)
else:
raise ValueError("Invalid metrics")
@property @property
def dataset(self): def dataset(self):
@ -119,41 +100,15 @@ class Mayamodel(pl.LightningModule):
def setup(self, stage: Optional[str] = None): def setup(self, stage: Optional[str] = None):
if stage == "fit": if stage == "fit":
torch.cuda.empty_cache()
self.dataset.setup(stage) self.dataset.setup(stage)
self.dataset.model = self self.dataset.model = self
print(
"Total train duration",
self.dataset.train_dataloader().dataset.__len__()
* self.dataset.duration
/ 60,
"minutes",
)
print(
"Total validation duration",
self.dataset.val_dataloader().dataset.__len__()
* self.dataset.duration
/ 60,
"minutes",
)
print(
"Total test duration",
self.dataset.test_dataloader().dataset.__len__()
* self.dataset.duration
/ 60,
"minutes",
)
def train_dataloader(self): def train_dataloader(self):
return self.dataset.train_dataloader() return self.dataset.train_dataloader()
def val_dataloader(self): def val_dataloader(self):
return self.dataset.val_dataloader() return self.dataset.val_dataloader()
def test_dataloader(self):
return self.dataset.test_dataloader()
def configure_optimizers(self): def configure_optimizers(self):
return Adam(self.parameters(), lr=self.hparams.lr) return Adam(self.parameters(), lr=self.hparams.lr)
@ -162,86 +117,59 @@ class Mayamodel(pl.LightningModule):
mixed_waveform = batch["noisy"] mixed_waveform = batch["noisy"]
target = batch["clean"] target = batch["clean"]
prediction = self(mixed_waveform) prediction = self(mixed_waveform)
loss = self.loss(prediction, target) loss = self.loss(prediction, target)
self.log( if self.logger:
"train_loss", self.logger.experiment.log_metric(
loss.item(), run_id=self.logger.run_id,
on_epoch=True, key="train_loss",
on_step=True, value=loss.item(),
logger=True, step=self.global_step,
prog_bar=True, )
) self.log("train_loss", loss.item())
return {"loss": loss} return {"loss": loss}
def validation_step(self, batch, batch_idx: int): def validation_step(self, batch, batch_idx: int):
metric_dict = {}
mixed_waveform = batch["noisy"] mixed_waveform = batch["noisy"]
target = batch["clean"] target = batch["clean"]
prediction = self(mixed_waveform) prediction = self(mixed_waveform)
metric_dict["valid_loss"] = self.loss(target, prediction).item() metric_val = self.metric(prediction, target)
for metric in self.metric: loss_val = self.loss(prediction, target)
value = metric(target, prediction) self.log("val_metric", metric_val.item())
metric_dict[f"valid_{metric.name}"] = value.item() self.log("val_loss", loss_val.item())
self.log_dict( if self.logger:
metric_dict, self.logger.experiment.log_metric(
on_step=True, run_id=self.logger.run_id,
on_epoch=True, key="val_loss",
prog_bar=True, value=loss_val.item(),
logger=True, step=self.global_step,
) )
self.logger.experiment.log_metric(
run_id=self.logger.run_id,
key="val_metric",
value=metric_val.item(),
step=self.global_step,
)
return metric_dict return {"loss": loss_val}
def test_step(self, batch, batch_idx):
metric_dict = {}
mixed_waveform = batch["noisy"]
target = batch["clean"]
prediction = self(mixed_waveform)
for metric in self.metric:
value = metric(target, prediction)
metric_dict[f"test_{metric.name}"] = value
self.log_dict(
metric_dict,
on_step=True,
on_epoch=True,
prog_bar=True,
logger=True,
)
return metric_dict
def test_epoch_end(self, outputs):
test_mean_metrics = defaultdict(int)
for output in outputs:
for metric, value in output.items():
test_mean_metrics[metric] += value.item()
for metric in test_mean_metrics.keys():
test_mean_metrics[metric] /= len(outputs)
print("----------TEST REPORT----------\n")
for metric in test_mean_metrics.keys():
print(f"|{metric.upper()} | {test_mean_metrics[metric]} |")
print("--------------------------------")
def on_save_checkpoint(self, checkpoint): def on_save_checkpoint(self, checkpoint):
checkpoint[SAVE_NAME] = { checkpoint["enhancer"] = {
"version": {SAVE_NAME: __version__, "pytorch": torch.__version__}, "version": {"enhancer": __version__, "pytorch": torch.__version__},
"architecture": { "architecture": {
"module": self.__class__.__module__, "module": self.__class__.__module__,
"class": self.__class__.__name__, "class": self.__class__.__name__,
}, },
} }
def on_load_checkpoint(self, checkpoint: Dict[str, Any]):
pass
@classmethod @classmethod
def from_pretrained( def from_pretrained(
cls, cls,
@ -281,15 +209,16 @@ class Mayamodel(pl.LightningModule):
to True or to a string containing your hugginface.co authentication to True or to a string containing your hugginface.co authentication
token that can be obtained by running `huggingface-cli login` token that can be obtained by running `huggingface-cli login`
cache_dir: Path or str, optional cache_dir: Path or str, optional
Path to model cache directory Path to model cache directory. Defaults to content of PYANNOTE_CACHE
environment variable, or "~/.cache/torch/pyannote" when unset.
kwargs: optional kwargs: optional
Any extra keyword args needed to init the model. Any extra keyword args needed to init the model.
Can also be used to override saved hyperparameter values. Can also be used to override saved hyperparameter values.
Returns Returns
------- -------
model : Mayamodel model : Model
Mayamodel Model
See also See also
-------- --------
@ -318,7 +247,7 @@ class Mayamodel(pl.LightningModule):
) )
model_path_pl = cached_download( model_path_pl = cached_download(
url=url, url=url,
library_name="mayavoz", library_name="enhancer",
library_version=__version__, library_version=__version__,
cache_dir=cached_dir, cache_dir=cached_dir,
use_auth_token=use_auth_token, use_auth_token=use_auth_token,
@ -328,8 +257,8 @@ class Mayamodel(pl.LightningModule):
map_location = torch.device(DEFAULT_DEVICE) map_location = torch.device(DEFAULT_DEVICE)
loaded_checkpoint = pl_load(model_path_pl, map_location) loaded_checkpoint = pl_load(model_path_pl, map_location)
module_name = loaded_checkpoint[SAVE_NAME]["architecture"]["module"] module_name = loaded_checkpoint["enhancer"]["architecture"]["module"]
class_name = loaded_checkpoint[SAVE_NAME]["architecture"]["class"] class_name = loaded_checkpoint["enhancer"]["architecture"]["class"]
module = import_module(module_name) module = import_module(module_name)
Klass = getattr(module, class_name) Klass = getattr(module, class_name)
@ -361,9 +290,10 @@ class Mayamodel(pl.LightningModule):
), f"Expected batch with 3 dimensions (batch,channels,samples) got only {batch.ndim}" ), f"Expected batch with 3 dimensions (batch,channels,samples) got only {batch.ndim}"
batch_predictions = [] batch_predictions = []
self.eval().to(self.device) self.eval().to(self.device)
with torch.no_grad(): with torch.no_grad():
for batch_id in range(0, batch.shape[0], batch_size): for batch_id in range(0, batch.shape[0], batch_size):
batch_data = batch[batch_id : (batch_id + batch_size), :, :].to( batch_data = batch[batch_id : batch_id + batch_size, :, :].to(
self.device self.device
) )
prediction = self(batch_data) prediction = self(batch_data)
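The inference path above runs the model over a long batch in fixed-size chunks under `torch.no_grad()`. A minimal sketch of that loop with a stand-in model and illustrative shapes:

```python
# Hedged sketch of chunked batched inference; the Conv1d stands in for a full model.
import torch
from torch import nn

model = nn.Conv1d(1, 1, kernel_size=3, padding=1)  # stand-in for the enhancement model
batch = torch.randn(37, 1, 16000)                   # (num_chunks, channels, samples)
batch_size = 8

predictions = []
model.eval()
with torch.no_grad():
    for start in range(0, batch.shape[0], batch_size):
        chunk = batch[start : start + batch_size]
        predictions.append(model(chunk))
enhanced = torch.cat(predictions, dim=0)            # back to (37, 1, 16000)
```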


@ -1,12 +1,12 @@
import warnings import logging
from typing import List, Optional, Union from typing import List, Optional, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from mayavoz.data.dataset import MayaDataset from enhancer.data.dataset import EnhancerDataset
from mayavoz.models.model import Mayamodel from enhancer.models.model import Model
class WavenetDecoder(nn.Module): class WavenetDecoder(nn.Module):
@ -66,7 +66,7 @@ class WavenetEncoder(nn.Module):
return self.encoder(waveform) return self.encoder(waveform)
class WaveUnet(Mayamodel): class WaveUnet(Model):
""" """
Wave-U-Net model from https://arxiv.org/pdf/1811.11307.pdf Wave-U-Net model from https://arxiv.org/pdf/1811.11307.pdf
parameters: parameters:
@ -80,8 +80,8 @@ class WaveUnet(Mayamodel):
sampling rate of input audio sampling rate of input audio
lr : float, defaults to 1e-3 lr : float, defaults to 1e-3
learning rate used for training learning rate used for training
dataset: MayaDataset, optional dataset: EnhancerDataset, optional
MayaDataset object containing train/validation data for training EnhancerDataset object containing train/validation data for training
duration : float, optional duration : float, optional
chunk duration in seconds chunk duration in seconds
loss : string or List of strings loss : string or List of strings
@ -97,17 +97,17 @@ class WaveUnet(Mayamodel):
initial_output_channels: int = 24, initial_output_channels: int = 24,
sampling_rate: int = 16000, sampling_rate: int = 16000,
lr: float = 1e-3, lr: float = 1e-3,
dataset: Optional[MayaDataset] = None, dataset: Optional[EnhancerDataset] = None,
duration: Optional[float] = None, duration: Optional[float] = None,
loss: Union[str, List] = "mse", loss: Union[str, List] = "mse",
metric: Union[str, List] = "mse", metric: Union[str, List] = "mse",
): ):
duration = ( duration = (
dataset.duration if isinstance(dataset, MayaDataset) else duration dataset.duration if isinstance(dataset, EnhancerDataset) else None
) )
if dataset is not None: if dataset is not None:
if sampling_rate != dataset.sampling_rate: if sampling_rate != dataset.sampling_rate:
warnings.warn( logging.warn(
f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}" f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
) )
sampling_rate = dataset.sampling_rate sampling_rate = dataset.sampling_rate


@ -0,0 +1,3 @@
from enhancer.utils.config import Files
from enhancer.utils.io import Audio
from enhancer.utils.utils import check_files


@@ -70,7 +70,7 @@ class Audio:
 if sampling_rate:
 audio = self.__class__.resample_audio(
-audio, sampling_rate, self.sampling_rate
+audio, self.sampling_rate, sampling_rate
 )
 if self.return_tensor:
 return torch.tensor(audio)


@@ -1,7 +1,7 @@
 import os
 from typing import Optional
-from mayavoz.utils.config import Files
+from enhancer.utils.config import Files
 def check_files(root_dir: str, files: Files):


@@ -1,4 +1,4 @@
-name: mayavoz
+name: enhancer
 dependencies:
 - pip=21.0.1

hpc_entrypoint.sh Normal file

@ -0,0 +1,39 @@
#!/bin/bash
set -e
echo '----------------------------------------------------'
echo ' SLURM_CLUSTER_NAME = '$SLURM_CLUSTER_NAME
echo ' SLURMD_NODENAME = '$SLURMD_NODENAME
echo ' SLURM_JOBID = '$SLURM_JOBID
echo ' SLURM_JOB_USER = '$SLURM_JOB_USER
echo ' SLURM_PARTITION = '$SLURM_JOB_PARTITION
echo ' SLURM_JOB_ACCOUNT = '$SLURM_JOB_ACCOUNT
echo '----------------------------------------------------'
#TeamCity Output
cat << EOF
##teamcity[buildNumber '$SLURM_JOBID']
EOF
echo "Load HPC modules"
module load anaconda
echo "Activate Environment"
source activate enhancer
export TRANSFORMERS_OFFLINE=True
export PYTHONPATH=${PYTHONPATH}:/scratch/c.sistc3/enhancer
export HYDRA_FULL_ERROR=1
echo $PYTHONPATH
source ~/mlflow_settings.sh
echo "Making temp dir"
mkdir temp
pwd
#python transcriber/tasks/embeddings/timit.py --directory /scratch/$USER/TIMIT/data/lisa/data/timit/raw/TIMIT/TRAIN --output ./data/train
#python transcriber/tasks/embeddings/timit.py --directory /scratch/$USER/TIMIT/data/lisa/data/timit/raw/TIMIT/TEST --output ./data/test
echo "Start Training..."
python cli/train.py


@ -1,2 +0,0 @@
__import__("pkg_resources").declare_namespace(__name__)
from mayavoz.models import Mayamodel


@ -1,120 +0,0 @@
import os
from types import MethodType
import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.callbacks import (
EarlyStopping,
LearningRateMonitor,
ModelCheckpoint,
)
from pytorch_lightning.loggers import MLFlowLogger
from torch.optim.lr_scheduler import ReduceLROnPlateau
# from torch_audiomentations import Compose, Shift
os.environ["HYDRA_FULL_ERROR"] = "1"
JOB_ID = os.environ.get("SLURM_JOBID", "0")
@hydra.main(config_path="train_config", config_name="config")
def train(config: DictConfig):
OmegaConf.save(config, "config.yaml")
callbacks = []
logger = MLFlowLogger(
experiment_name=config.mlflow.experiment_name,
run_name=config.mlflow.run_name,
tags={"JOB_ID": JOB_ID},
)
parameters = config.hyperparameters
# apply_augmentations = Compose(
# [
# Shift(min_shift=0.5, max_shift=1.0, shift_unit="seconds", p=0.5),
# ]
# )
dataset = instantiate(config.dataset, augmentations=None)
model = instantiate(
config.model,
dataset=dataset,
lr=parameters.get("lr"),
loss=parameters.get("loss"),
metric=parameters.get("metric"),
)
direction = model.valid_monitor
checkpoint = ModelCheckpoint(
dirpath="./model",
filename=f"model_{JOB_ID}",
monitor="valid_loss",
verbose=False,
mode=direction,
every_n_epochs=1,
)
callbacks.append(checkpoint)
callbacks.append(LearningRateMonitor(logging_interval="epoch"))
if parameters.get("Early_stop", False):
early_stopping = EarlyStopping(
monitor="val_loss",
mode=direction,
min_delta=0.0,
patience=parameters.get("EarlyStopping_patience", 10),
strict=True,
verbose=False,
)
callbacks.append(early_stopping)
def configure_optimizers(self):
optimizer = instantiate(
config.optimizer,
lr=parameters.get("lr"),
params=self.parameters(),
)
scheduler = ReduceLROnPlateau(
optimizer=optimizer,
mode=direction,
factor=parameters.get("ReduceLr_factor", 0.1),
verbose=True,
min_lr=parameters.get("min_lr", 1e-6),
patience=parameters.get("ReduceLr_patience", 3),
)
return {
"optimizer": optimizer,
"lr_scheduler": scheduler,
"monitor": f'valid_{parameters.get("ReduceLr_monitor", "loss")}',
}
model.configure_optimizers = MethodType(configure_optimizers, model)
trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks)
trainer.fit(model)
trainer.test(model)
logger.experiment.log_artifact(
logger.run_id, f"{trainer.default_root_dir}/config.yaml"
)
saved_location = os.path.join(
trainer.default_root_dir, "model", f"model_{JOB_ID}.ckpt"
)
if os.path.isfile(saved_location):
logger.experiment.log_artifact(logger.run_id, saved_location)
logger.experiment.log_param(
logger.run_id,
"num_train_steps_per_epoch",
dataset.train__len__() / dataset.batch_size,
)
logger.experiment.log_param(
logger.run_id,
"num_valid_steps_per_epoch",
dataset.val__len__() / dataset.batch_size,
)
if __name__ == "__main__":
train()


@ -1,7 +0,0 @@
defaults:
- model : Demucs
- dataset : Vctk
- optimizer : Adam
- hyperparameters : default
- trainer : default
- mlflow : experiment


@ -1,12 +0,0 @@
_target_: mayavoz.data.dataset.MayaDataset
name : MS-SNSD
root_dir : /Users/shahules/Myprojects/MS-SNSD
duration : 2.0
sampling_rate: 16000
batch_size: 32
min_valid_minutes: 15
files:
train_clean : CleanSpeech_training
test_clean : CleanSpeech_training
train_noisy : NoisySpeech_training
test_noisy : NoisySpeech_training

View File

@ -1,13 +0,0 @@
_target_: mayavoz.data.dataset.MayaDataset
name : Valentini
root_dir : /scratch/c.sistc3/DS_10283_2791
duration : 4.5
stride : 2
sampling_rate: 16000
batch_size: 32
min_valid_minutes : 15
files:
train_clean : clean_trainset_28spk_wav
test_clean : clean_testset_wav
train_noisy : noisy_trainset_28spk_wav
test_noisy : noisy_testset_wav

View File

@ -1,7 +0,0 @@
loss : mae
metric : [stoi,pesq,si-sdr]
lr : 0.0003
ReduceLr_patience : 5
ReduceLr_factor : 0.2
min_lr : 0.000001
EarlyStopping_patience : 10

View File

@ -1,2 +0,0 @@
experiment_name : shahules/mayavoz
run_name : Demucs + Vctk with stride + augmentations

View File

@ -1,25 +0,0 @@
_target_: mayavoz.models.dccrn.DCCRN
num_channels: 1
sampling_rate : 16000
complex_lstm : True
complex_norm : True
complex_relu : True
masking_mode : E
encoder_decoder:
initial_output_channels : 32
depth : 6
kernel_size : 5
growth_factor : 2
stride : 2
padding : 2
output_padding : 1
lstm:
num_layers : 2
hidden_size : 256
stft:
window_len : 400
hop_size : 100
nfft : 512

View File

@ -1,46 +0,0 @@
_target_: pytorch_lightning.Trainer
accelerator: gpu
accumulate_grad_batches: 1
amp_backend: native
auto_lr_find: True
auto_scale_batch_size: False
auto_select_gpus: True
benchmark: False
check_val_every_n_epoch: 1
detect_anomaly: False
deterministic: False
devices: 2
enable_checkpointing: True
enable_model_summary: True
enable_progress_bar: True
fast_dev_run: False
gpus: null
gradient_clip_val: 0
gradient_clip_algorithm: norm
ipus: null
limit_predict_batches: 1.0
limit_test_batches: 1.0
limit_train_batches: 1.0
limit_val_batches: 1.0
log_every_n_steps: 50
max_epochs: 200
max_steps: -1
max_time: null
min_epochs: 1
min_steps: null
move_metrics_to_cpu: False
multiple_trainloader_mode: max_size_cycle
num_nodes: 1
num_processes: 1
num_sanity_val_steps: 2
overfit_batches: 0.0
precision: 32
profiler: null
reload_dataloaders_every_n_epochs: 0
replace_sampler_ddp: True
strategy: ddp
sync_batchnorm: False
tpu_cores: null
track_grad_norm: -1
val_check_interval: 1.0
weights_save_path: null

View File

@ -1 +0,0 @@
from mayavoz.data.dataset import MayaDataset

View File

@ -1,393 +0,0 @@
import math
import multiprocessing
import os
import sys
import warnings
from pathlib import Path
from typing import Optional
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, RandomSampler
from torch_audiomentations import Compose
from mayavoz.data.fileprocessor import Fileprocessor
from mayavoz.utils import check_files
from mayavoz.utils.config import Files
from mayavoz.utils.io import Audio
from mayavoz.utils.random import create_unique_rng
LARGE_NUM = 2147483647
class TrainDataset(Dataset):
def __init__(self, dataset):
self.dataset = dataset
def __getitem__(self, idx):
return self.dataset.train__getitem__(idx)
def __len__(self):
return self.dataset.train__len__()
class ValidDataset(Dataset):
def __init__(self, dataset):
self.dataset = dataset
def __getitem__(self, idx):
return self.dataset.val__getitem__(idx)
def __len__(self):
return self.dataset.val__len__()
class TestDataset(Dataset):
def __init__(self, dataset):
self.dataset = dataset
def __getitem__(self, idx):
return self.dataset.test__getitem__(idx)
def __len__(self):
return self.dataset.test__len__()
class TaskDataset(pl.LightningDataModule):
def __init__(
self,
name: str,
root_dir: str,
files: Files,
min_valid_minutes: float = 0.20,
duration: float = 1.0,
stride=None,
sampling_rate: int = 48000,
matching_function=None,
batch_size=32,
num_workers: Optional[int] = None,
augmentations: Optional[Compose] = None,
):
super().__init__()
self.name = name
self.files, self.root_dir = check_files(root_dir, files)
self.duration = duration
self.stride = stride or duration
self.sampling_rate = sampling_rate
self.batch_size = batch_size
self.matching_function = matching_function
self._validation = []
        if num_workers is None:
            num_workers = multiprocessing.cpu_count() // 2
if (
num_workers > 0
and sys.platform == "darwin"
and sys.version_info[0] >= 3
and sys.version_info[1] >= 8
):
warnings.warn(
"num_workers > 0 is not supported with macOS and Python 3.8+: "
"setting num_workers = 0."
)
num_workers = 0
self.num_workers = num_workers
if min_valid_minutes > 0.0:
self.min_valid_minutes = min_valid_minutes
else:
raise ValueError("min_valid_minutes must be greater than 0")
self.augmentations = augmentations
def setup(self, stage: Optional[str] = None):
"""
prepare train/validation/test data splits
"""
if stage in ("fit", None):
train_clean = os.path.join(self.root_dir, self.files.train_clean)
train_noisy = os.path.join(self.root_dir, self.files.train_noisy)
fp = Fileprocessor.from_name(
self.name, train_clean, train_noisy, self.matching_function
)
train_data = fp.prepare_matching_dict()
train_data, self.val_data = self.train_valid_split(
train_data,
min_valid_minutes=self.min_valid_minutes,
random_state=42,
)
self.train_data = self.prepare_traindata(train_data)
self._validation = self.prepare_mapstype(self.val_data)
test_clean = os.path.join(self.root_dir, self.files.test_clean)
test_noisy = os.path.join(self.root_dir, self.files.test_noisy)
fp = Fileprocessor.from_name(
self.name, test_clean, test_noisy, self.matching_function
)
test_data = fp.prepare_matching_dict()
self._test = self.prepare_mapstype(test_data)
def train_valid_split(
self, data, min_valid_minutes: float = 20, random_state: int = 42
):
min_valid_minutes *= 60
valid_sec_now = 0.0
valid_indices = []
all_speakers = np.unique(
[Path(file["clean"]).name.split("_")[0] for file in data]
)
possible_indices = list(range(0, len(all_speakers)))
rng = create_unique_rng(len(all_speakers))
while valid_sec_now <= min_valid_minutes:
speaker_index = rng.choice(possible_indices)
possible_indices.remove(speaker_index)
speaker_name = all_speakers[speaker_index]
print(f"Selected f{speaker_name} for valid")
file_indices = [
i
for i, file in enumerate(data)
if speaker_name == Path(file["clean"]).name.split("_")[0]
]
for i in file_indices:
valid_indices.append(i)
valid_sec_now += data[i]["duration"]
train_data = [
item for i, item in enumerate(data) if i not in valid_indices
]
valid_data = [item for i, item in enumerate(data) if i in valid_indices]
return train_data, valid_data
def prepare_traindata(self, data):
train_data = []
for item in data:
clean, noisy, total_dur = item.values()
num_segments = self.get_num_segments(
total_dur, self.duration, self.stride
)
samples_metadata = ({"clean": clean, "noisy": noisy}, num_segments)
train_data.append(samples_metadata)
return train_data
@staticmethod
def get_num_segments(file_duration, duration, stride):
if file_duration < duration:
num_segments = 1
else:
num_segments = math.ceil((file_duration - duration) / stride) + 1
return num_segments
def prepare_mapstype(self, data):
metadata = []
for item in data:
clean, noisy, total_dur = item.values()
if total_dur < self.duration:
metadata.append(({"clean": clean, "noisy": noisy}, 0.0))
else:
num_segments = self.get_num_segments(
total_dur, self.duration, self.duration
)
for index in range(num_segments):
start_time = index * self.duration
metadata.append(
({"clean": clean, "noisy": noisy}, start_time)
)
return metadata
def train_collatefn(self, batch):
output = {"clean": [], "noisy": []}
for item in batch:
output["clean"].append(item["clean"])
output["noisy"].append(item["noisy"])
output["clean"] = torch.stack(output["clean"], dim=0)
output["noisy"] = torch.stack(output["noisy"], dim=0)
if self.augmentations is not None:
noise = output["noisy"] - output["clean"]
output["clean"] = self.augmentations(
output["clean"], sample_rate=self.sampling_rate
)
self.augmentations.freeze_parameters()
output["noisy"] = (
self.augmentations(noise, sample_rate=self.sampling_rate)
+ output["clean"]
)
return output
@property
def generator(self):
generator = torch.Generator()
if hasattr(self, "model"):
seed = self.model.current_epoch + LARGE_NUM
else:
seed = LARGE_NUM
return generator.manual_seed(seed)
def train_dataloader(self):
dataset = TrainDataset(self)
sampler = RandomSampler(dataset, generator=self.generator)
return DataLoader(
dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
sampler=sampler,
collate_fn=self.train_collatefn,
)
def val_dataloader(self):
return DataLoader(
ValidDataset(self),
batch_size=self.batch_size,
num_workers=self.num_workers,
)
def test_dataloader(self):
return DataLoader(
TestDataset(self),
batch_size=self.batch_size,
num_workers=self.num_workers,
)
class MayaDataset(TaskDataset):
"""
Dataset object for creating clean-noisy speech enhancement datasets
    parameters:
    name : str
        name of the dataset
    root_dir : str
        root directory of the dataset containing the clean/noisy folders
    files : Files
        dataclass containing the train_clean, train_noisy, test_clean, test_noisy
        folder names (refer to the mayavoz.utils.Files dataclass)
    min_valid_minutes : float
        minimum size of the validation split, in minutes.
        speakers are randomly selected (until >= min_valid_minutes of audio is
        collected) from the training data to form the validation data.
    duration : float
        expected duration of a single audio sample used for training
    sampling_rate : int
        desired sampling rate
    batch_size : int
        number of samples per batch
    num_workers : int
        number of dataloader workers used while training
    matching_function : str
        matching function - one of (one_to_one, one_to_many). Defaults to None.
        use one_to_one mapping for datasets with one noisy file for each clean file
        use one_to_many mapping for datasets with multiple noisy files for each clean file
"""
def __init__(
self,
name: str,
root_dir: str,
files: Files,
min_valid_minutes=5.0,
duration=1.0,
stride=None,
sampling_rate=48000,
matching_function=None,
batch_size=32,
num_workers: Optional[int] = None,
augmentations: Optional[Compose] = None,
):
super().__init__(
name=name,
root_dir=root_dir,
files=files,
min_valid_minutes=min_valid_minutes,
sampling_rate=sampling_rate,
duration=duration,
matching_function=matching_function,
batch_size=batch_size,
num_workers=num_workers,
augmentations=augmentations,
)
self.sampling_rate = sampling_rate
self.files = files
self.duration = max(1.0, duration)
self.audio = Audio(self.sampling_rate, mono=True, return_tensor=True)
self.stride = stride or duration
def setup(self, stage: Optional[str] = None):
super().setup(stage=stage)
def train__getitem__(self, idx):
for filedict, num_samples in self.train_data:
if idx >= num_samples:
idx -= num_samples
continue
else:
start = 0
if self.duration is not None:
start = idx * self.stride
return self.prepare_segment(filedict, start)
def val__getitem__(self, idx):
return self.prepare_segment(*self._validation[idx])
def test__getitem__(self, idx):
return self.prepare_segment(*self._test[idx])
def prepare_segment(self, file_dict: dict, start_time: float):
clean_segment = self.audio(
file_dict["clean"], offset=start_time, duration=self.duration
)
noisy_segment = self.audio(
file_dict["noisy"], offset=start_time, duration=self.duration
)
clean_segment = F.pad(
clean_segment,
(
0,
int(
self.duration * self.sampling_rate - clean_segment.shape[-1]
),
),
)
noisy_segment = F.pad(
noisy_segment,
(
0,
int(
self.duration * self.sampling_rate - noisy_segment.shape[-1]
),
),
)
return {
"clean": clean_segment,
"noisy": noisy_segment,
}
def train__len__(self):
_, num_examples = list(zip(*self.train_data))
return sum(num_examples)
def val__len__(self):
return len(self._validation)
def test__len__(self):
return len(self._test)
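
As a quick check on the sliding-window bookkeeping above (a standalone sketch, not part of the module), the number of training windows per file follows directly from `get_num_segments`:

```python
import math


def get_num_segments(file_duration, duration, stride):
    # mirrors TaskDataset.get_num_segments above
    if file_duration < duration:
        return 1
    return math.ceil((file_duration - duration) / stride) + 1


# a 10 s file with duration=4.5 and stride=2.0 (the Valentini recipe values)
# yields ceil((10 - 4.5) / 2) + 1 = 4 windows
assert get_num_segments(10.0, 4.5, 2.0) == 4
```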

View File

@ -1,3 +0,0 @@
from mayavoz.models.demucs import Demucs
from mayavoz.models.model import Mayamodel
from mayavoz.models.waveunet import WaveUnet

View File

@ -1,5 +0,0 @@
from mayavoz.models.complexnn.conv import ComplexConv2d # noqa
from mayavoz.models.complexnn.conv import ComplexConvTranspose2d # noqa
from mayavoz.models.complexnn.rnn import ComplexLSTM # noqa
from mayavoz.models.complexnn.utils import ComplexBatchNorm2D # noqa
from mayavoz.models.complexnn.utils import ComplexRelu # noqa

View File

@ -1,136 +0,0 @@
from typing import Tuple
import torch
import torch.nn.functional as F
from torch import nn
def init_weights(nnet):
nn.init.xavier_normal_(nnet.weight.data)
nn.init.constant_(nnet.bias, 0.0)
return nnet
class ComplexConv2d(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Tuple[int, int] = (1, 1),
stride: Tuple[int, int] = (1, 1),
padding: Tuple[int, int] = (0, 0),
groups: int = 1,
dilation: int = 1,
):
"""
Complex Conv2d (non-causal)
"""
super().__init__()
self.in_channels = in_channels // 2
self.out_channels = out_channels // 2
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
self.groups = groups
self.dilation = dilation
self.real_conv = nn.Conv2d(
self.in_channels,
self.out_channels,
kernel_size=self.kernel_size,
stride=self.stride,
padding=(self.padding[0], 0),
groups=self.groups,
dilation=self.dilation,
)
self.imag_conv = nn.Conv2d(
self.in_channels,
self.out_channels,
kernel_size=self.kernel_size,
stride=self.stride,
padding=(self.padding[0], 0),
groups=self.groups,
dilation=self.dilation,
)
self.imag_conv = init_weights(self.imag_conv)
self.real_conv = init_weights(self.real_conv)
def forward(self, input):
"""
        real/imag channels are always stacked along dim 1
"""
input = F.pad(input, [self.padding[1], 0, 0, 0])
real, imag = torch.chunk(input, 2, 1)
real_real = self.real_conv(real)
real_imag = self.imag_conv(real)
imag_imag = self.imag_conv(imag)
imag_real = self.real_conv(imag)
        # (a + ib) * (c + id) = (ac - bd) + i(ad + bc)
        real = real_real - imag_imag
        imag = real_imag + imag_real
out = torch.cat([real, imag], 1)
return out
class ComplexConvTranspose2d(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Tuple[int, int] = (1, 1),
stride: Tuple[int, int] = (1, 1),
padding: Tuple[int, int] = (0, 0),
output_padding: Tuple[int, int] = (0, 0),
groups: int = 1,
):
super().__init__()
self.in_channels = in_channels // 2
self.out_channels = out_channels // 2
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
self.groups = groups
self.output_padding = output_padding
self.real_conv = nn.ConvTranspose2d(
self.in_channels,
self.out_channels,
kernel_size=self.kernel_size,
stride=self.stride,
padding=self.padding,
output_padding=self.output_padding,
groups=self.groups,
)
self.imag_conv = nn.ConvTranspose2d(
self.in_channels,
self.out_channels,
kernel_size=self.kernel_size,
stride=self.stride,
padding=self.padding,
output_padding=self.output_padding,
groups=self.groups,
)
self.real_conv = init_weights(self.real_conv)
self.imag_conv = init_weights(self.imag_conv)
def forward(self, input):
real, imag = torch.chunk(input, 2, 1)
real_real = self.real_conv(real)
real_imag = self.imag_conv(real)
imag_imag = self.imag_conv(imag)
imag_real = self.real_conv(imag)
real = real_real - imag_imag
imag = real_imag + imag_real
out = torch.cat([real, imag], 1)
return out
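
A small shape check may help clarify the layout convention used throughout `complexnn`: real and imaginary parts are stacked along the channel dimension, so channel counts are always even. A sketch, assuming the package is installed:

```python
import torch

from mayavoz.models.complexnn import ComplexConv2d

conv = ComplexConv2d(
    in_channels=2, out_channels=32, kernel_size=(5, 2), stride=(2, 1), padding=(2, 1)
)
x = torch.randn(4, 2, 256, 100)  # (batch, real+imag, freq, time)
print(conv(x).shape)  # torch.Size([4, 32, 128, 100])
```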

View File

@ -1,68 +0,0 @@
from typing import List, Optional
import torch
from torch import nn
class ComplexLSTM(nn.Module):
def __init__(
self,
input_size: int,
hidden_size: int,
num_layers: int = 1,
projection_size: Optional[int] = None,
bidirectional: bool = False,
):
super().__init__()
self.input_size = input_size // 2
self.hidden_size = hidden_size // 2
self.num_layers = num_layers
self.real_lstm = nn.LSTM(
self.input_size,
self.hidden_size,
self.num_layers,
bidirectional=bidirectional,
batch_first=False,
)
self.imag_lstm = nn.LSTM(
self.input_size,
self.hidden_size,
self.num_layers,
bidirectional=bidirectional,
batch_first=False,
)
bidirectional = 2 if bidirectional else 1
if projection_size is not None:
self.projection_size = projection_size // 2
self.real_linear = nn.Linear(
self.hidden_size * bidirectional, self.projection_size
)
self.imag_linear = nn.Linear(
self.hidden_size * bidirectional, self.projection_size
)
else:
self.projection_size = None
def forward(self, input):
if isinstance(input, List):
real, imag = input
else:
real, imag = torch.chunk(input, 2, 1)
real_real = self.real_lstm(real)[0]
real_imag = self.imag_lstm(real)[0]
imag_imag = self.imag_lstm(imag)[0]
imag_real = self.real_lstm(imag)[0]
real = real_real - imag_imag
imag = imag_real + real_imag
if self.projection_size is not None:
real = self.real_linear(real)
imag = self.imag_linear(imag)
return [real, imag]

View File

@ -1,199 +0,0 @@
import torch
from torch import nn
class ComplexBatchNorm2D(nn.Module):
def __init__(
self,
num_features: int,
eps: float = 1e-5,
momentum: float = 0.1,
affine: bool = True,
track_running_stats: bool = True,
):
"""
Complex batch normalization 2D
https://arxiv.org/abs/1705.09792
"""
super().__init__()
self.num_features = num_features // 2
self.affine = affine
self.momentum = momentum
self.track_running_stats = track_running_stats
self.eps = eps
if self.affine:
self.Wrr = nn.parameter.Parameter(torch.Tensor(self.num_features))
self.Wri = nn.parameter.Parameter(torch.Tensor(self.num_features))
self.Wii = nn.parameter.Parameter(torch.Tensor(self.num_features))
self.Br = nn.parameter.Parameter(torch.Tensor(self.num_features))
self.Bi = nn.parameter.Parameter(torch.Tensor(self.num_features))
else:
self.register_parameter("Wrr", None)
self.register_parameter("Wri", None)
self.register_parameter("Wii", None)
self.register_parameter("Br", None)
self.register_parameter("Bi", None)
if self.track_running_stats:
            # register separate tensors so the running statistics do not share storage
            self.register_buffer("Mean_real", torch.zeros(self.num_features))
            self.register_buffer("Mean_imag", torch.zeros(self.num_features))
            self.register_buffer("Var_rr", torch.zeros(self.num_features))
            self.register_buffer("Var_ri", torch.zeros(self.num_features))
            self.register_buffer("Var_ii", torch.zeros(self.num_features))
self.register_buffer(
"num_batches_tracked", torch.tensor(0, dtype=torch.long)
)
else:
self.register_parameter("Mean_real", None)
self.register_parameter("Mean_imag", None)
self.register_parameter("Var_rr", None)
self.register_parameter("Var_ri", None)
self.register_parameter("Var_ii", None)
self.register_parameter("num_batches_tracked", None)
self.reset_parameters()
def reset_parameters(self):
if self.affine:
self.Wrr.data.fill_(1)
self.Wii.data.fill_(1)
self.Wri.data.uniform_(-0.9, 0.9)
self.Br.data.fill_(0)
self.Bi.data.fill_(0)
self.reset_running_stats()
def reset_running_stats(self):
if self.track_running_stats:
self.Mean_real.zero_()
self.Mean_imag.zero_()
self.Var_rr.fill_(1)
self.Var_ri.zero_()
self.Var_ii.fill_(1)
self.num_batches_tracked.zero_()
def extra_repr(self):
return "{num_features}, eps={eps}, momentum={momentum}, affine={affine}, track_running_stats={track_running_stats}".format(
**self.__dict__
)
def forward(self, input):
real, imag = torch.chunk(input, 2, 1)
exp_avg_factor = 0.0
training = self.training and self.track_running_stats
if training:
self.num_batches_tracked += 1
if self.momentum is None:
exp_avg_factor = 1 / self.num_batches_tracked
else:
exp_avg_factor = self.momentum
redux = [i for i in reversed(range(real.dim())) if i != 1]
vdim = [1] * real.dim()
vdim[1] = real.size(1)
if training:
batch_mean_real, batch_mean_imag = real, imag
for dim in redux:
batch_mean_real = batch_mean_real.mean(dim, keepdim=True)
batch_mean_imag = batch_mean_imag.mean(dim, keepdim=True)
if self.track_running_stats:
self.Mean_real.lerp_(batch_mean_real.squeeze(), exp_avg_factor)
self.Mean_imag.lerp_(batch_mean_imag.squeeze(), exp_avg_factor)
else:
batch_mean_real = self.Mean_real.view(vdim)
batch_mean_imag = self.Mean_imag.view(vdim)
real = real - batch_mean_real
imag = imag - batch_mean_imag
if training:
batch_var_rr = real * real
batch_var_ri = real * imag
batch_var_ii = imag * imag
for dim in redux:
batch_var_rr = batch_var_rr.mean(dim, keepdim=True)
batch_var_ri = batch_var_ri.mean(dim, keepdim=True)
batch_var_ii = batch_var_ii.mean(dim, keepdim=True)
if self.track_running_stats:
self.Var_rr.lerp_(batch_var_rr.squeeze(), exp_avg_factor)
self.Var_ri.lerp_(batch_var_ri.squeeze(), exp_avg_factor)
self.Var_ii.lerp_(batch_var_ii.squeeze(), exp_avg_factor)
else:
batch_var_rr = self.Var_rr.view(vdim)
batch_var_ii = self.Var_ii.view(vdim)
batch_var_ri = self.Var_ri.view(vdim)
        batch_var_rr = batch_var_rr + self.eps
        batch_var_ii = batch_var_ii + self.eps
        # Covariance matrix
        # | batch_var_rr batch_var_ri |
        # | batch_var_ir batch_var_ii |  here batch_var_ir == batch_var_ri
        # Inverse square root of the covariance matrix, combining the two formulas below
        # https://en.wikipedia.org/wiki/Square_root_of_a_2_by_2_matrix
        # https://mathworld.wolfram.com/MatrixInverse.html
        tau = batch_var_rr + batch_var_ii
        det = batch_var_rr * batch_var_ii - batch_var_ri * batch_var_ri
        s = det.sqrt()
        t = (tau + 2 * s).sqrt()
        rst = (s * t).reciprocal()
Urr = (batch_var_ii + s) * rst
Uri = -batch_var_ri * rst
Uii = (batch_var_rr + s) * rst
if self.affine:
Wrr, Wri, Wii = (
self.Wrr.view(vdim),
self.Wri.view(vdim),
self.Wii.view(vdim),
)
Zrr = (Wrr * Urr) + (Wri * Uri)
Zri = (Wrr * Uri) + (Wri * Uii)
Zir = (Wii * Uri) + (Wri * Urr)
Zii = (Wri * Uri) + (Wii * Uii)
else:
Zrr, Zri, Zir, Zii = Urr, Uri, Uri, Uii
yr = (Zrr * real) + (Zri * imag)
yi = (Zir * real) + (Zii * imag)
if self.affine:
yr = yr + self.Br.view(vdim)
yi = yi + self.Bi.view(vdim)
outputs = torch.cat([yr, yi], 1)
return outputs
class ComplexRelu(nn.Module):
def __init__(self):
super().__init__()
self.real_relu = nn.PReLU()
self.imag_relu = nn.PReLU()
def forward(self, input):
real, imag = torch.chunk(input, 2, 1)
real = self.real_relu(real)
imag = self.imag_relu(imag)
return torch.cat([real, imag], dim=1)
def complex_cat(inputs, axis=1):
real, imag = [], []
for data in inputs:
real_data, imag_data = torch.chunk(data, 2, axis)
real.append(real_data)
imag.append(imag_data)
real = torch.cat(real, axis)
imag = torch.cat(imag, axis)
return torch.cat([real, imag], axis)
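
For reference, the whitening step in `ComplexBatchNorm2D` above uses the closed-form inverse square root of the 2x2 real/imaginary covariance matrix that the in-code comments point to. In the same notation as the code:

```latex
V = \begin{pmatrix} V_{rr} & V_{ri} \\ V_{ri} & V_{ii} \end{pmatrix},
\qquad \tau = V_{rr} + V_{ii},
\qquad s = \sqrt{V_{rr} V_{ii} - V_{ri}^{2}},
\qquad t = \sqrt{\tau + 2s},

V^{-1/2} = \frac{1}{s\,t}
\begin{pmatrix} V_{ii} + s & -V_{ri} \\ -V_{ri} & V_{rr} + s \end{pmatrix}
= \begin{pmatrix} U_{rr} & U_{ri} \\ U_{ri} & U_{ii} \end{pmatrix}
```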

View File

@ -1,338 +0,0 @@
import warnings
from typing import Any, List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import nn
from mayavoz.data import MayaDataset
from mayavoz.models import Mayamodel
from mayavoz.models.complexnn import (
ComplexBatchNorm2D,
ComplexConv2d,
ComplexConvTranspose2d,
ComplexLSTM,
ComplexRelu,
)
from mayavoz.models.complexnn.utils import complex_cat
from mayavoz.utils.transforms import ConviSTFT, ConvSTFT
from mayavoz.utils.utils import merge_dict
class DCCRN_ENCODER(nn.Module):
def __init__(
self,
in_channels: int,
out_channel: int,
kernel_size: Tuple[int, int],
complex_norm: bool = True,
complex_relu: bool = True,
stride: Tuple[int, int] = (2, 1),
padding: Tuple[int, int] = (2, 1),
):
super().__init__()
batchnorm = ComplexBatchNorm2D if complex_norm else nn.BatchNorm2d
activation = ComplexRelu() if complex_relu else nn.PReLU()
self.encoder = nn.Sequential(
ComplexConv2d(
in_channels,
out_channel,
kernel_size=kernel_size,
stride=stride,
padding=padding,
),
batchnorm(out_channel),
activation,
)
def forward(self, waveform):
return self.encoder(waveform)
class DCCRN_DECODER(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Tuple[int, int],
layer: int = 0,
complex_norm: bool = True,
complex_relu: bool = True,
stride: Tuple[int, int] = (2, 1),
padding: Tuple[int, int] = (2, 0),
output_padding: Tuple[int, int] = (1, 0),
):
super().__init__()
batchnorm = ComplexBatchNorm2D if complex_norm else nn.BatchNorm2d
activation = ComplexRelu() if complex_relu else nn.PReLU()
if layer != 0:
self.decoder = nn.Sequential(
ComplexConvTranspose2d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
output_padding=output_padding,
),
batchnorm(out_channels),
activation,
)
else:
self.decoder = nn.Sequential(
ComplexConvTranspose2d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
output_padding=output_padding,
)
)
def forward(self, waveform):
return self.decoder(waveform)
class DCCRN(Mayamodel):
STFT_DEFAULTS = {
"window_len": 400,
"hop_size": 100,
"nfft": 512,
"window": "hamming",
}
ED_DEFAULTS = {
"initial_output_channels": 32,
"depth": 6,
"kernel_size": 5,
"growth_factor": 2,
"stride": 2,
"padding": 2,
"output_padding": 1,
}
LSTM_DEFAULTS = {
"num_layers": 2,
"hidden_size": 256,
}
def __init__(
self,
stft: Optional[dict] = None,
encoder_decoder: Optional[dict] = None,
lstm: Optional[dict] = None,
complex_lstm: bool = True,
complex_norm: bool = True,
complex_relu: bool = True,
masking_mode: str = "E",
num_channels: int = 1,
sampling_rate=16000,
lr: float = 1e-3,
dataset: Optional[MayaDataset] = None,
duration: Optional[float] = None,
loss: Union[str, List, Any] = "mse",
metric: Union[str, List] = "mse",
):
duration = (
dataset.duration if isinstance(dataset, MayaDataset) else duration
)
if dataset is not None:
if sampling_rate != dataset.sampling_rate:
warnings.warn(
f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
)
sampling_rate = dataset.sampling_rate
super().__init__(
num_channels=num_channels,
sampling_rate=sampling_rate,
lr=lr,
dataset=dataset,
duration=duration,
loss=loss,
metric=metric,
)
encoder_decoder = merge_dict(self.ED_DEFAULTS, encoder_decoder)
lstm = merge_dict(self.LSTM_DEFAULTS, lstm)
stft = merge_dict(self.STFT_DEFAULTS, stft)
self.save_hyperparameters(
"encoder_decoder",
"lstm",
"stft",
"complex_lstm",
"complex_norm",
"masking_mode",
)
self.complex_lstm = complex_lstm
self.complex_norm = complex_norm
self.masking_mode = masking_mode
self.stft = ConvSTFT(
stft["window_len"], stft["hop_size"], stft["nfft"], stft["window"]
)
self.istft = ConviSTFT(
stft["window_len"], stft["hop_size"], stft["nfft"], stft["window"]
)
self.encoder = nn.ModuleList()
self.decoder = nn.ModuleList()
num_channels *= 2
hidden_size = encoder_decoder["initial_output_channels"]
growth_factor = 2
for layer in range(encoder_decoder["depth"]):
encoder_ = DCCRN_ENCODER(
num_channels,
hidden_size,
kernel_size=(encoder_decoder["kernel_size"], 2),
stride=(encoder_decoder["stride"], 1),
padding=(encoder_decoder["padding"], 1),
complex_norm=complex_norm,
complex_relu=complex_relu,
)
self.encoder.append(encoder_)
decoder_ = DCCRN_DECODER(
hidden_size + hidden_size,
num_channels,
layer=layer,
kernel_size=(encoder_decoder["kernel_size"], 2),
stride=(encoder_decoder["stride"], 1),
padding=(encoder_decoder["padding"], 0),
output_padding=(encoder_decoder["output_padding"], 0),
complex_norm=complex_norm,
complex_relu=complex_relu,
)
self.decoder.insert(0, decoder_)
if layer < encoder_decoder["depth"] - 3:
num_channels = hidden_size
hidden_size *= growth_factor
else:
num_channels = hidden_size
                kernel_size = hidden_size // 2
                hidden_size = stft["nfft"] // 2 ** (encoder_decoder["depth"])
if self.complex_lstm:
lstms = []
for layer in range(lstm["num_layers"]):
if layer == 0:
input_size = int(hidden_size * kernel_size)
else:
input_size = lstm["hidden_size"]
if layer == lstm["num_layers"] - 1:
projection_size = int(hidden_size * kernel_size)
else:
projection_size = None
kwargs = {
"input_size": input_size,
"hidden_size": lstm["hidden_size"],
"num_layers": 1,
}
lstms.append(
ComplexLSTM(projection_size=projection_size, **kwargs)
)
self.lstm = nn.Sequential(*lstms)
        else:
            # plain (non-complex) LSTM over the flattened encoder features,
            # followed by a linear projection back to that feature size
            self.lstm = nn.LSTM(
                input_size=hidden_size * kernel_size,
                hidden_size=lstm["hidden_size"],
                num_layers=lstm["num_layers"],
                dropout=0.0,
                batch_first=False,
            )
            self.lstm_linear = nn.Linear(
                lstm["hidden_size"], hidden_size * kernel_size
            )
def forward(self, waveform):
if waveform.dim() == 2:
waveform = waveform.unsqueeze(1)
if waveform.size(1) != self.hparams.num_channels:
raise ValueError(
f"Number of input channels initialized is {self.hparams.num_channels} but got {waveform.size(1)} channels"
)
waveform_stft = self.stft(waveform)
real = waveform_stft[:, : self.stft.nfft // 2 + 1]
imag = waveform_stft[:, self.stft.nfft // 2 + 1 :]
mag_spec = torch.sqrt(real**2 + imag**2 + 1e-9)
phase_spec = torch.atan2(imag, real)
complex_spec = torch.stack([mag_spec, phase_spec], 1)[:, :, 1:]
encoder_outputs = []
out = complex_spec
for _, encoder in enumerate(self.encoder):
out = encoder(out)
encoder_outputs.append(out)
B, C, D, T = out.size()
out = out.permute(3, 0, 1, 2)
if self.complex_lstm:
lstm_real = out[:, :, : C // 2]
lstm_imag = out[:, :, C // 2 :]
lstm_real = lstm_real.reshape(T, B, C // 2 * D)
lstm_imag = lstm_imag.reshape(T, B, C // 2 * D)
lstm_real, lstm_imag = self.lstm([lstm_real, lstm_imag])
lstm_real = lstm_real.reshape(T, B, C // 2, D)
lstm_imag = lstm_imag.reshape(T, B, C // 2, D)
out = torch.cat([lstm_real, lstm_imag], 2)
        else:
            out = out.reshape(T, B, C * D)
            out, _ = self.lstm(out)
            out = self.lstm_linear(out)
            out = out.reshape(T, B, C, D)
out = out.permute(1, 2, 3, 0)
for layer, decoder in enumerate(self.decoder):
skip_connection = encoder_outputs.pop(-1)
out = complex_cat([skip_connection, out])
out = decoder(out)
out = out[..., 1:]
mask_real, mask_imag = out[:, 0], out[:, 1]
mask_real = F.pad(mask_real, [0, 0, 1, 0])
mask_imag = F.pad(mask_imag, [0, 0, 1, 0])
if self.masking_mode == "E":
mask_mag = torch.sqrt(mask_real**2 + mask_imag**2)
real_phase = mask_real / (mask_mag + 1e-8)
imag_phase = mask_imag / (mask_mag + 1e-8)
mask_phase = torch.atan2(imag_phase, real_phase)
mask_mag = torch.tanh(mask_mag)
est_mag = mask_mag * mag_spec
est_phase = mask_phase * phase_spec
            # est_mag * (cos(est_phase) + i sin(est_phase))
            real = est_mag * torch.cos(est_phase)
            imag = est_mag * torch.sin(est_phase)
        elif self.masking_mode == "C":
real = real * mask_real - imag * mask_imag
imag = real * mask_imag + imag * mask_real
else:
real = real * mask_real
imag = imag * mask_imag
spec = torch.cat([real, imag], 1)
wav = self.istft(spec)
wav = wav.clamp_(-1, 1)
return wav

View File

@ -1,3 +0,0 @@
from mayavoz.utils.config import Files
from mayavoz.utils.io import Audio
from mayavoz.utils.utils import check_files

View File

@ -1,93 +0,0 @@
from typing import Optional
import numpy as np
import torch
import torch.nn.functional as F
from scipy.signal import get_window
from torch import nn
class ConvFFT(nn.Module):
def __init__(
self,
window_len: int,
nfft: Optional[int] = None,
window: str = "hamming",
):
super().__init__()
self.window_len = window_len
        self.nfft = nfft if nfft else int(2 ** np.ceil(np.log2(window_len)))
self.window = torch.from_numpy(
get_window(window, window_len, fftbins=True).astype("float32")
)
def init_kernel(self, inverse=False):
fourier_basis = np.fft.rfft(np.eye(self.nfft))[: self.window_len]
real, imag = np.real(fourier_basis), np.imag(fourier_basis)
kernel = np.concatenate([real, imag], 1).T
if inverse:
kernel = np.linalg.pinv(kernel).T
kernel = torch.from_numpy(kernel.astype("float32")).unsqueeze(1)
kernel *= self.window
return kernel
class ConvSTFT(ConvFFT):
def __init__(
self,
window_len: int,
hop_size: Optional[int] = None,
nfft: Optional[int] = None,
window: str = "hamming",
):
super().__init__(window_len=window_len, nfft=nfft, window=window)
self.hop_size = hop_size if hop_size else window_len // 2
self.register_buffer("weight", self.init_kernel())
def forward(self, input):
if input.dim() < 2:
raise ValueError(
f"Expected signal with shape 2 or 3 got {input.dim()}"
)
elif input.dim() == 2:
input = input.unsqueeze(1)
else:
pass
input = F.pad(
input,
(self.window_len - self.hop_size, self.window_len - self.hop_size),
)
output = F.conv1d(input, self.weight, stride=self.hop_size)
return output
class ConviSTFT(ConvFFT):
def __init__(
self,
window_len: int,
hop_size: Optional[int] = None,
nfft: Optional[int] = None,
window: str = "hamming",
):
super().__init__(window_len=window_len, nfft=nfft, window=window)
self.hop_size = hop_size if hop_size else window_len // 2
self.register_buffer("weight", self.init_kernel(True))
self.register_buffer("enframe", torch.eye(window_len).unsqueeze(1))
def forward(self, input, phase=None):
if phase is not None:
real = input * torch.cos(phase)
imag = input * torch.sin(phase)
input = torch.cat([real, imag], 1)
out = F.conv_transpose1d(input, self.weight, stride=self.hop_size)
coeff = self.window.unsqueeze(1).repeat(1, 1, input.size(-1)) ** 2
coeff = coeff.to(input.device)
coeff = F.conv_transpose1d(coeff, self.enframe, stride=self.hop_size)
out = out / (coeff + 1e-8)
pad = self.window_len - self.hop_size
out = out[..., pad:-pad]
return out
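
A short round-trip sketch (not part of the module) illustrating the expected tensor shapes: the analysis transform stacks real and imaginary parts along the channel dimension (`nfft + 2` channels), and the synthesis transform with the same settings approximately reconstructs the input up to edge effects:

```python
import torch

from mayavoz.utils.transforms import ConviSTFT, ConvSTFT

stft = ConvSTFT(window_len=400, hop_size=100, nfft=512, window="hamming")
istft = ConviSTFT(window_len=400, hop_size=100, nfft=512, window="hamming")

x = torch.randn(1, 1, 16000)  # one second of audio at 16 kHz
spec = stft(x)                # (1, nfft + 2, num_frames) -> (1, 514, 163)
y = istft(spec)               # (1, 1, 16000)
print(spec.shape, y.shape)
```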

View File

@ -1,338 +0,0 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "ccd61d5c",
"metadata": {},
"source": [
"## Custom model training using mayavoz [advanced]\n",
"\n",
"In this tutorial, we will cover advanced usages and customizations for training your own speecg enhancement model. \n",
"\n",
" - [Data preparation using MayaDataset](#dataprep)\n",
" - [Model customization](#modelcustom)\n",
" - [callbacks & LR schedulers](#callbacks)\n",
" - [Model training & testing](#train)\n"
]
},
{
"cell_type": "markdown",
"id": "726c320f",
"metadata": {},
"source": [
"- **install mayavoz**"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c987c799",
"metadata": {},
"outputs": [],
"source": [
"! pip install -q mayavoz"
]
},
{
"cell_type": "markdown",
"id": "8ff9857b",
"metadata": {},
"source": [
"<div id=\"dataprep\"></div>\n",
"\n",
"### Data preparation\n",
"\n",
"`Files` is a dataclass that wraps and holds train/test paths togethor. There are usually one folder each for clean and noisy data. These paths must be relative to a `root_dir` where all these directories reside. For example\n",
"\n",
"```\n",
"- VCTK/\n",
" |__ clean_train_wav/\n",
" |__ noisy_train_wav/\n",
" |__ clean_test_wav/\n",
" |__ noisy_test_wav/\n",
" \n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "64cbc0c8",
"metadata": {},
"outputs": [],
"source": [
"from mayavoz.utils import Files\n",
"file = Files(train_clean=\"clean_train_wav\",\n",
" train_noisy=\"noisy_train_wav\",\n",
" test_clean=\"clean_test_wav\",\n",
" test_noisy=\"noisy_test_wav\")\n",
"root_dir = \"VCTK\""
]
},
{
"cell_type": "markdown",
"id": "2d324bd1",
"metadata": {},
"source": [
"- `name`: name of the dataset. \n",
"- `duration`: control the duration of each audio instance fed into your model.\n",
"- `stride` is used if set to move the sliding window.\n",
"- `sampling_rate`: desired sampling rate for audio\n",
"- `batch_size`: model batch size\n",
"- `min_valid_minutes`: minimum validation in minutes. Validation is automatically selected from training set. (exclusive users).\n",
"- `matching_function`: there are two types of mapping functions.\n",
" - `one_to_one` : In this one clean file will only have one corresponding noisy file. For example Valentini datasets\n",
" - `one_to_many` : In this one clean file will only have one corresponding noisy file. For example MS-SNSD dataset.\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "6834941d",
"metadata": {},
"outputs": [],
"source": [
"name = \"vctk\"\n",
"duration : 4.5\n",
"stride : 2.0\n",
"sampling_rate : 16000\n",
"min_valid_minutes : 20.0\n",
"batch_size : 32\n",
"matching_function : \"one_to_one\"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d08c6bf8",
"metadata": {},
"outputs": [],
"source": [
"from mayavoz.dataset import MayaDataset\n",
"dataset = MayaDataset(\n",
" name=name,\n",
" root_dir=root_dir,\n",
" files=files,\n",
" duration=duration,\n",
" stride=stride,\n",
" sampling_rate=sampling_rate,\n",
" batch_size=batch_size,\n",
" min_valid_minutes=min_valid_minutes,\n",
" matching_function=matching_function\n",
" )"
]
},
{
"cell_type": "markdown",
"id": "5b315bde",
"metadata": {},
"source": [
"Now your custom dataloader is ready!"
]
},
{
"cell_type": "markdown",
"id": "01548fe5",
"metadata": {},
"source": [
"<div id=\"modelcustom\"></div>\n",
"\n",
"### Model Customization\n",
"Now, this is very easy. \n",
"\n",
"- Import the preferred model from `mayavoz.models`. Currently 3 models are implemented.\n",
" - `WaveUnet`\n",
" - `Demucs`\n",
" - `DCCRN`\n",
"- Each of model hyperparameters such as depth,kernel_size,stride etc can be controlled by you. Just check the parameters and pass it to as required.\n",
"- `sampling_rate`: sampling rate (should be equal to dataset sampling rate)\n",
"- `dataset`: mayavoz dataset object as prepared earlier.\n",
"- `loss` : model loss. Multiple loss functions are available.\n",
"\n",
" \n",
" \n",
"you can pass one (as string)/more (as list of strings) of these loss functions as per your requirements. For example, model will automatically calculate loss as average of `mae` and `mse` if you pass loss as `[\"mae\",\"mse\"]`. Available loss functions are `mse`,`mae`,`si-snr`.\n",
"\n",
"mayavoz can accept **custom loss functions**. It should be of the form.\n",
"```\n",
"class your_custom_loss(nn.Module):\n",
" def __init__(self,**kwargs):\n",
" self.higher_better = False ## loss minimization direction\n",
" self.name = \"your_loss_name\" ## loss name logging \n",
" ...\n",
" def forward(self,prediction, target):\n",
" loss = ....\n",
" return loss\n",
" \n",
"```\n",
"\n",
"- metrics : validation metrics. Available options `mae`,`mse`,`si-sdr`,`si-sdr`,`pesq`,`stoi`. One or more can be used.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b36b457c",
"metadata": {},
"outputs": [],
"source": [
"from mayavoz.models import Demucs\n",
"model = Demucs(\n",
" sampling_rate=16000,\n",
" dataset=dataset,\n",
" loss=[\"mae\"],\n",
" metrics=[\"stoi\",\"pesq\"])\n"
]
},
{
"cell_type": "markdown",
"id": "1523d638",
"metadata": {},
"source": [
"<div id=\"callbacks\"></div>\n",
"\n",
"### learning rate schedulers and callbacks\n",
"Here I am using `ReduceLROnPlateau`"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8de6931c",
"metadata": {},
"outputs": [],
"source": [
"from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
"\n",
"def configure_optimizers(self):\n",
" optimizer = instantiate(\n",
" config.optimizer,\n",
" lr=parameters.get(\"lr\"),\n",
" params=self.parameters(),\n",
" )\n",
" scheduler = ReduceLROnPlateau(\n",
" optimizer=optimizer,\n",
" mode=direction,\n",
" factor=parameters.get(\"ReduceLr_factor\", 0.1),\n",
" verbose=True,\n",
" min_lr=parameters.get(\"min_lr\", 1e-6),\n",
" patience=parameters.get(\"ReduceLr_patience\", 3),\n",
" )\n",
" return {\n",
" \"optimizer\": optimizer,\n",
" \"lr_scheduler\": scheduler,\n",
" \"monitor\": f'valid_{parameters.get(\"ReduceLr_monitor\", \"loss\")}',\n",
" }\n",
"\n",
"\n",
"model.configure_optimizers = MethodType(configure_optimizers, model)"
]
},
{
"cell_type": "markdown",
"id": "2f7b5af5",
"metadata": {},
"source": [
"you can use any number of callbacks and pass it directly to pytorch lightning trainer. Here I am using only `ModelCheckpoint`"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6f6b62a1",
"metadata": {},
"outputs": [],
"source": [
"callbacks = []\n",
"direction = model.valid_monitor ## min or max \n",
"checkpoint = ModelCheckpoint(\n",
" dirpath=\"./model\",\n",
" filename=f\"model_filename\",\n",
" monitor=\"valid_loss\",\n",
" verbose=False,\n",
" mode=direction,\n",
" every_n_epochs=1,\n",
" )\n",
"callbacks.append(checkpoint)"
]
},
{
"cell_type": "markdown",
"id": "f3534445",
"metadata": {},
"source": [
"<div id=\"train\"></div>\n",
"\n",
"\n",
"### Train"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3dc0348b",
"metadata": {},
"outputs": [],
"source": [
"import pytorch_lightning as pl\n",
"trainer = plt.Trainer(max_epochs=1,callbacks=callbacks,accelerator=\"gpu\")\n",
"trainer.fit(model)\n"
]
},
{
"cell_type": "markdown",
"id": "56dcfec1",
"metadata": {},
"source": [
"- Test your model agaist test dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "63851feb",
"metadata": {},
"outputs": [],
"source": [
"trainer.test(model)"
]
},
{
"cell_type": "markdown",
"id": "4d3f5350",
"metadata": {},
"source": [
"**Hurray! you have your speech enhancement model trained and tested.**\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "10d630e8",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "enhancer",
"language": "python",
"name": "enhancer"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
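
Complementing the custom-loss skeleton described in the notebook above, a minimal runnable version might look like the sketch below (the class name and scaling factor are illustrative, not part of mayavoz). Per the model signatures above (`loss: Union[str, List, Any]`), an instance of such a class should be accepted in place of a string loss name.

```python
import torch
from torch import nn


class ScaledMAELoss(nn.Module):
    """Mean absolute error scaled by a constant factor (illustrative example)."""

    def __init__(self, scale: float = 1.0):
        super().__init__()
        self.higher_better = False  # the loss is minimized
        self.name = "scaled_mae"    # name used for logging
        self.scale = scale

    def forward(self, prediction, target):
        return self.scale * torch.mean(torch.abs(prediction - target))
```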

File diff suppressed because one or more lines are too long

View File

@ -1,120 +0,0 @@
import os
from types import MethodType
import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.callbacks import (
EarlyStopping,
LearningRateMonitor,
ModelCheckpoint,
)
from pytorch_lightning.loggers import MLFlowLogger
from torch.optim.lr_scheduler import ReduceLROnPlateau
# from torch_audiomentations import Compose, Shift
os.environ["HYDRA_FULL_ERROR"] = "1"
JOB_ID = os.environ.get("SLURM_JOBID", "0")
@hydra.main(config_path="train_config", config_name="config")
def train(config: DictConfig):
OmegaConf.save(config, "config.yaml")
callbacks = []
logger = MLFlowLogger(
experiment_name=config.mlflow.experiment_name,
run_name=config.mlflow.run_name,
tags={"JOB_ID": JOB_ID},
)
parameters = config.hyperparameters
# apply_augmentations = Compose(
# [
# Shift(min_shift=0.5, max_shift=1.0, shift_unit="seconds", p=0.5),
# ]
# )
dataset = instantiate(config.dataset, augmentations=None)
model = instantiate(
config.model,
dataset=dataset,
lr=parameters.get("lr"),
loss=parameters.get("loss"),
metric=parameters.get("metric"),
)
direction = model.valid_monitor
checkpoint = ModelCheckpoint(
dirpath="./model",
filename=f"model_{JOB_ID}",
monitor="valid_loss",
verbose=False,
mode=direction,
every_n_epochs=1,
)
callbacks.append(checkpoint)
callbacks.append(LearningRateMonitor(logging_interval="epoch"))
if parameters.get("Early_stop", False):
early_stopping = EarlyStopping(
monitor="val_loss",
mode=direction,
min_delta=0.0,
patience=parameters.get("EarlyStopping_patience", 10),
strict=True,
verbose=False,
)
callbacks.append(early_stopping)
def configure_optimizers(self):
optimizer = instantiate(
config.optimizer,
lr=parameters.get("lr"),
params=self.parameters(),
)
scheduler = ReduceLROnPlateau(
optimizer=optimizer,
mode=direction,
factor=parameters.get("ReduceLr_factor", 0.1),
verbose=True,
min_lr=parameters.get("min_lr", 1e-6),
patience=parameters.get("ReduceLr_patience", 3),
)
return {
"optimizer": optimizer,
"lr_scheduler": scheduler,
"monitor": f'valid_{parameters.get("ReduceLr_monitor", "loss")}',
}
model.configure_optimizers = MethodType(configure_optimizers, model)
trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks)
trainer.fit(model)
trainer.test(model)
logger.experiment.log_artifact(
logger.run_id, f"{trainer.default_root_dir}/config.yaml"
)
saved_location = os.path.join(
trainer.default_root_dir, "model", f"model_{JOB_ID}.ckpt"
)
if os.path.isfile(saved_location):
logger.experiment.log_artifact(logger.run_id, saved_location)
logger.experiment.log_param(
logger.run_id,
"num_train_steps_per_epoch",
dataset.train__len__() / dataset.batch_size,
)
logger.experiment.log_param(
logger.run_id,
"num_valid_steps_per_epoch",
dataset.val__len__() / dataset.batch_size,
)
if __name__ == "__main__":
train()

View File

@ -1,7 +0,0 @@
defaults:
- model : Demucs
- dataset : MS-SNSD
- optimizer : Adam
- hyperparameters : default
- trainer : default
- mlflow : experiment

View File

@ -1,13 +0,0 @@
_target_: mayavoz.data.dataset.MayaDataset
name : MS-SNSD
root_dir : /Users/shahules/Myprojects/MS-SNSD
duration : 1.5
stride : 1
sampling_rate: 16000
batch_size: 32
min_valid_minutes: 25
files:
train_clean : CleanSpeech_training
test_clean : CleanSpeech_training
train_noisy : NoisySpeech_training
test_noisy : NoisySpeech_training

View File

@ -1,7 +0,0 @@
loss : si-snr
metric : [stoi,pesq]
lr : 0.001
ReduceLr_patience : 10
ReduceLr_factor : 0.5
min_lr : 0.000001
EarlyStopping_patience : 10

View File

@ -1,2 +0,0 @@
experiment_name : shahules/mayavoz
run_name : Demucs + Vctk with stride + augmentations

View File

@ -1,25 +0,0 @@
_target_: mayavoz.models.dccrn.DCCRN
num_channels: 1
sampling_rate : 16000
complex_lstm : True
complex_norm : True
complex_relu : True
masking_mode : E
encoder_decoder:
initial_output_channels : 32
depth : 6
kernel_size : 5
growth_factor : 2
stride : 2
padding : 2
output_padding : 1
lstm:
num_layers : 2
hidden_size : 256
stft:
window_len : 400
hop_size : 100
nfft : 512

View File

@ -1,6 +0,0 @@
_target_: torch.optim.Adam
lr: 1e-3
betas: [0.9, 0.999]
eps: 1e-08
weight_decay: 0
amsgrad: False

View File

@ -1,46 +0,0 @@
_target_: pytorch_lightning.Trainer
accelerator: gpu
accumulate_grad_batches: 1
amp_backend: native
auto_lr_find: True
auto_scale_batch_size: False
auto_select_gpus: True
benchmark: False
check_val_every_n_epoch: 1
detect_anomaly: False
deterministic: False
devices: 2
enable_checkpointing: True
enable_model_summary: True
enable_progress_bar: True
fast_dev_run: False
gpus: null
gradient_clip_val: 0
gradient_clip_algorithm: norm
ipus: null
limit_predict_batches: 1.0
limit_test_batches: 1.0
limit_train_batches: 1.0
limit_val_batches: 1.0
log_every_n_steps: 50
max_epochs: 200
max_steps: -1
max_time: null
min_epochs: 1
min_steps: null
move_metrics_to_cpu: False
multiple_trainloader_mode: max_size_cycle
num_nodes: 1
num_processes: 1
num_sanity_val_steps: 2
overfit_batches: 0.0
precision: 32
profiler: null
reload_dataloaders_every_n_epochs: 0
replace_sampler_ddp: True
strategy: ddp
sync_batchnorm: False
tpu_cores: null
track_grad_norm: -1
val_check_interval: 1.0
weights_save_path: null

View File

@ -1,2 +0,0 @@
_target_: pytorch_lightning.Trainer
fast_dev_run: True

View File

@ -1,120 +0,0 @@
import os
from types import MethodType
import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.callbacks import (
EarlyStopping,
LearningRateMonitor,
ModelCheckpoint,
)
from pytorch_lightning.loggers import MLFlowLogger
from torch.optim.lr_scheduler import ReduceLROnPlateau
# from torch_audiomentations import Compose, Shift
os.environ["HYDRA_FULL_ERROR"] = "1"
JOB_ID = os.environ.get("SLURM_JOBID", "0")
@hydra.main(config_path="train_config", config_name="config")
def train(config: DictConfig):
OmegaConf.save(config, "config.yaml")
callbacks = []
logger = MLFlowLogger(
experiment_name=config.mlflow.experiment_name,
run_name=config.mlflow.run_name,
tags={"JOB_ID": JOB_ID},
)
parameters = config.hyperparameters
# apply_augmentations = Compose(
# [
# Shift(min_shift=0.5, max_shift=1.0, shift_unit="seconds", p=0.5),
# ]
# )
dataset = instantiate(config.dataset, augmentations=None)
model = instantiate(
config.model,
dataset=dataset,
lr=parameters.get("lr"),
loss=parameters.get("loss"),
metric=parameters.get("metric"),
)
direction = model.valid_monitor
checkpoint = ModelCheckpoint(
dirpath="./model",
filename=f"model_{JOB_ID}",
monitor="valid_loss",
verbose=False,
mode=direction,
every_n_epochs=1,
)
callbacks.append(checkpoint)
callbacks.append(LearningRateMonitor(logging_interval="epoch"))
if parameters.get("Early_stop", False):
early_stopping = EarlyStopping(
monitor="val_loss",
mode=direction,
min_delta=0.0,
patience=parameters.get("EarlyStopping_patience", 10),
strict=True,
verbose=False,
)
callbacks.append(early_stopping)
def configure_optimizers(self):
optimizer = instantiate(
config.optimizer,
lr=parameters.get("lr"),
params=self.parameters(),
)
scheduler = ReduceLROnPlateau(
optimizer=optimizer,
mode=direction,
factor=parameters.get("ReduceLr_factor", 0.1),
verbose=True,
min_lr=parameters.get("min_lr", 1e-6),
patience=parameters.get("ReduceLr_patience", 3),
)
return {
"optimizer": optimizer,
"lr_scheduler": scheduler,
"monitor": f'valid_{parameters.get("ReduceLr_monitor", "loss")}',
}
model.configure_optimizers = MethodType(configure_optimizers, model)
trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks)
trainer.fit(model)
trainer.test(model)
logger.experiment.log_artifact(
logger.run_id, f"{trainer.default_root_dir}/config.yaml"
)
saved_location = os.path.join(
trainer.default_root_dir, "model", f"model_{JOB_ID}.ckpt"
)
if os.path.isfile(saved_location):
logger.experiment.log_artifact(logger.run_id, saved_location)
logger.experiment.log_param(
logger.run_id,
"num_train_steps_per_epoch",
dataset.train__len__() / dataset.batch_size,
)
logger.experiment.log_param(
logger.run_id,
"num_valid_steps_per_epoch",
dataset.val__len__() / dataset.batch_size,
)
if __name__ == "__main__":
train()

View File

@ -1,7 +0,0 @@
defaults:
- model : Demucs
- dataset : MS-SNSD
- optimizer : Adam
- hyperparameters : default
- trainer : default
- mlflow : experiment

View File

@ -1,13 +0,0 @@
_target_: mayavoz.data.dataset.MayaDataset
name : MS-SNSD
root_dir : /Users/shahules/Myprojects/MS-SNSD
duration : 5
stride : 1
sampling_rate: 16000
batch_size: 32
min_valid_minutes: 25
files:
train_clean : CleanSpeech_training
test_clean : CleanSpeech_training
train_noisy : NoisySpeech_training
test_noisy : NoisySpeech_training

View File

@ -1,7 +0,0 @@
loss : mae
metric : [stoi,pesq]
lr : 0.0003
ReduceLr_patience : 10
ReduceLr_factor : 0.5
min_lr : 0.000001
EarlyStopping_patience : 10

View File

@ -1,2 +0,0 @@
experiment_name : shahules/mayavoz
run_name : demucs-ms-snsd

View File

@ -1,16 +0,0 @@
_target_: mayavoz.models.demucs.Demucs
num_channels: 1
resample: 4
sampling_rate : 16000
encoder_decoder:
depth: 4
initial_output_channels: 64
kernel_size: 8
stride: 4
growth_factor: 2
glu: True
lstm:
bidirectional: False
num_layers: 2

View File

@ -1,6 +0,0 @@
_target_: torch.optim.Adam
lr: 1e-3
betas: [0.9, 0.999]
eps: 1e-08
weight_decay: 0
amsgrad: False

View File

@ -1,46 +0,0 @@
_target_: pytorch_lightning.Trainer
accelerator: gpu
accumulate_grad_batches: 1
amp_backend: native
auto_lr_find: True
auto_scale_batch_size: False
auto_select_gpus: True
benchmark: False
check_val_every_n_epoch: 1
detect_anomaly: False
deterministic: False
devices: 2
enable_checkpointing: True
enable_model_summary: True
enable_progress_bar: True
fast_dev_run: False
gpus: null
gradient_clip_val: 0
gradient_clip_algorithm: norm
ipus: null
limit_predict_batches: 1.0
limit_test_batches: 1.0
limit_train_batches: 1.0
limit_val_batches: 1.0
log_every_n_steps: 50
max_epochs: 200
max_steps: -1
max_time: null
min_epochs: 1
min_steps: null
move_metrics_to_cpu: False
multiple_trainloader_mode: max_size_cycle
num_nodes: 1
num_processes: 1
num_sanity_val_steps: 2
overfit_batches: 0.0
precision: 32
profiler: null
reload_dataloaders_every_n_epochs: 0
replace_sampler_ddp: True
strategy: ddp
sync_batchnorm: False
tpu_cores: null
track_grad_norm: -1
val_check_interval: 1.0
weights_save_path: null

View File

@ -1,2 +0,0 @@
_target_: pytorch_lightning.Trainer
fast_dev_run: True

View File

@ -1,17 +0,0 @@
### Microsoft Scalable Noisy Speech Dataset (MS-SNSD)
MS-SNSD is a speech dataset that can scale to arbitrary sizes depending on the number of speakers, noise types, and Speech to Noise Ratio (SNR) levels desired.
### Dataset download & setup
- Follow steps in the official repo [here](https://github.com/microsoft/MS-SNSD) to download and setup the dataset.
**References**
```BibTex
@article{reddy2019scalable,
title={A Scalable Noisy Speech Dataset and Online Subjective Test Framework},
author={Reddy, Chandan KA and Beyrami, Ebrahim and Pool, Jamie and Cutler, Ross and Srinivasan, Sriram and Gehrke, Johannes},
journal={Proc. Interspeech 2019},
pages={1816--1820},
year={2019}
}
```
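
A minimal sketch of pointing mayavoz at an MS-SNSD checkout (the folder names follow the MS-SNSD dataset config shown above; the root path is illustrative):

```python
from mayavoz.data import MayaDataset
from mayavoz.utils import Files

files = Files(
    train_clean="CleanSpeech_training",
    train_noisy="NoisySpeech_training",
    test_clean="CleanSpeech_training",
    test_noisy="NoisySpeech_training",
)
dataset = MayaDataset(
    name="MS-SNSD",
    root_dir="/path/to/MS-SNSD",      # illustrative path
    files=files,
    duration=1.5,
    stride=1,
    sampling_rate=16000,
    batch_size=32,
    min_valid_minutes=25,
    matching_function="one_to_many",  # several noisy mixtures per clean file
)
```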

View File

@ -1,120 +0,0 @@
import os
from types import MethodType
import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.callbacks import (
EarlyStopping,
LearningRateMonitor,
ModelCheckpoint,
)
from pytorch_lightning.loggers import MLFlowLogger
from torch.optim.lr_scheduler import ReduceLROnPlateau
# from torch_audiomentations import Compose, Shift
os.environ["HYDRA_FULL_ERROR"] = "1"
JOB_ID = os.environ.get("SLURM_JOBID", "0")
@hydra.main(config_path="train_config", config_name="config")
def main(config: DictConfig):
OmegaConf.save(config, "config_log.yaml")
callbacks = []
logger = MLFlowLogger(
experiment_name=config.mlflow.experiment_name,
run_name=config.mlflow.run_name,
tags={"JOB_ID": JOB_ID},
)
parameters = config.hyperparameters
# apply_augmentations = Compose(
# [
# Shift(min_shift=0.5, max_shift=1.0, shift_unit="seconds", p=0.5),
# ]
# )
dataset = instantiate(config.dataset, augmentations=None)
model = instantiate(
config.model,
dataset=dataset,
lr=parameters.get("lr"),
loss=parameters.get("loss"),
metric=parameters.get("metric"),
)
direction = model.valid_monitor
checkpoint = ModelCheckpoint(
dirpath="./model",
filename=f"model_{JOB_ID}",
monitor="valid_loss",
verbose=False,
mode=direction,
every_n_epochs=1,
)
callbacks.append(checkpoint)
callbacks.append(LearningRateMonitor(logging_interval="epoch"))
if parameters.get("Early_stop", False):
early_stopping = EarlyStopping(
monitor="val_loss",
mode=direction,
min_delta=0.0,
patience=parameters.get("EarlyStopping_patience", 10),
strict=True,
verbose=False,
)
callbacks.append(early_stopping)
def configure_optimizers(self):
optimizer = instantiate(
config.optimizer,
lr=parameters.get("lr"),
params=self.parameters(),
)
scheduler = ReduceLROnPlateau(
optimizer=optimizer,
mode=direction,
factor=parameters.get("ReduceLr_factor", 0.1),
verbose=True,
min_lr=parameters.get("min_lr", 1e-6),
patience=parameters.get("ReduceLr_patience", 3),
)
return {
"optimizer": optimizer,
"lr_scheduler": scheduler,
"monitor": f'valid_{parameters.get("ReduceLr_monitor", "loss")}',
}
model.configure_optimizers = MethodType(configure_optimizers, model)
trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks)
trainer.fit(model)
trainer.test(model)
logger.experiment.log_artifact(
logger.run_id, f"{trainer.default_root_dir}/config_log.yaml"
)
saved_location = os.path.join(
trainer.default_root_dir, "model", f"model_{JOB_ID}.ckpt"
)
if os.path.isfile(saved_location):
logger.experiment.log_artifact(logger.run_id, saved_location)
logger.experiment.log_param(
logger.run_id,
"num_train_steps_per_epoch",
dataset.train__len__() / dataset.batch_size,
)
logger.experiment.log_param(
logger.run_id,
"num_valid_steps_per_epoch",
dataset.val__len__() / dataset.batch_size,
)
if __name__ == "__main__":
main()

View File

@ -1,7 +0,0 @@
defaults:
- model : Demucs
- dataset : Vctk
- optimizer : Adam
- hyperparameters : default
- trainer : default
- mlflow : experiment

View File

@ -1,13 +0,0 @@
_target_: mayavoz.data.dataset.MayaDataset
name : vctk
root_dir : /scratch/c.sistc3/DS_10283_2791
duration : 4.5
stride : 0.5
sampling_rate: 16000
batch_size: 32
min_valid_minutes : 25
files:
train_clean : clean_trainset_28spk_wav
test_clean : clean_testset_wav
train_noisy : noisy_trainset_28spk_wav
test_noisy : noisy_testset_wav

View File

@ -1,8 +0,0 @@
loss : mae
metric : [stoi,pesq,si-sdr]
lr : 0.0003
Early_stop : False
ReduceLr_patience : 10
ReduceLr_factor : 0.1
min_lr : 0.000001
EarlyStopping_patience : 10

View File

@ -1,2 +0,0 @@
experiment_name : shahules/mayavoz
run_name : baseline

View File

@ -1,16 +0,0 @@
_target_: mayavoz.models.demucs.Demucs
num_channels: 1
resample: 4
sampling_rate : 16000
encoder_decoder:
depth: 4
initial_output_channels: 64
kernel_size: 8
stride: 4
growth_factor: 2
glu: True
lstm:
bidirectional: True
num_layers: 2

View File

@ -1,6 +0,0 @@
_target_: torch.optim.Adam
lr: 1e-3
betas: [0.9, 0.999]
eps: 1e-08
weight_decay: 0
amsgrad: False

View File

@ -1,8 +0,0 @@
loss : mae
metric : [stoi,pesq,si-sdr]
lr : 0.003
ReduceLr_patience : 10
ReduceLr_factor : 0.1
min_lr : 0.000001
EarlyStopping_patience : 10
Early_stop : False

View File

@ -1,2 +0,0 @@
experiment_name : shahules/mayavoz
run_name : baseline

View File

@ -1,5 +0,0 @@
_target_: mayavoz.models.waveunet.WaveUnet
num_channels : 1
depth : 9
initial_output_channels: 24
sampling_rate : 16000

View File

@ -1,6 +0,0 @@
_target_: torch.optim.Adam
lr: 1e-3
betas: [0.9, 0.999]
eps: 1e-08
weight_decay: 0
amsgrad: False

View File

@ -1,46 +0,0 @@
_target_: pytorch_lightning.Trainer
accelerator: gpu
accumulate_grad_batches: 1
amp_backend: native
auto_lr_find: True
auto_scale_batch_size: False
auto_select_gpus: True
benchmark: False
check_val_every_n_epoch: 1
detect_anomaly: False
deterministic: False
devices: 2
enable_checkpointing: True
enable_model_summary: True
enable_progress_bar: True
fast_dev_run: False
gpus: null
gradient_clip_val: 0
gradient_clip_algorithm: norm
ipus: null
limit_predict_batches: 1.0
limit_test_batches: 1.0
limit_train_batches: 1.0
limit_val_batches: 1.0
log_every_n_steps: 50
max_epochs: 200
max_steps: -1
max_time: null
min_epochs: 1
min_steps: null
move_metrics_to_cpu: False
multiple_trainloader_mode: max_size_cycle
num_nodes: 1
num_processes: 1
num_sanity_val_steps: 2
overfit_batches: 0.0
precision: 32
profiler: null
reload_dataloaders_every_n_epochs: 0
replace_sampler_ddp: True
strategy: ddp
sync_batchnorm: False
tpu_cores: null
track_grad_norm: -1
val_check_interval: 1.0
weights_save_path: null
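
A rough hand-written equivalent of a few of these options (PyTorch Lightning 1.7.x flag names); in the CLI the whole block is passed to `instantiate(config.trainer, ...)` instead:

```python
from pytorch_lightning import Trainer

# subset of the YAML above: a 2-GPU DDP run with the same epoch/logging settings
trainer = Trainer(
    accelerator="gpu",
    devices=2,
    strategy="ddp",
    precision=32,
    max_epochs=200,
    log_every_n_steps=50,
    num_sanity_val_steps=2,
)
```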

View File

@ -1,2 +0,0 @@
_target_: pytorch_lightning.Trainer
fast_dev_run: True

View File

@ -1,12 +0,0 @@
## Valentini dataset
A parallel corpus of clean and noisy speech, designed for training and testing speech enhancement methods that operate at 48 kHz. A more detailed description can be found in the papers associated with the database ([official page](https://datashare.ed.ac.uk/handle/10283/2791)).
**References**
```BibTex
@misc{valentini2017noisy,
  title={Noisy speech database for training speech enhancement algorithms and TTS models},
  author={Valentini-Botinhao, Cassia},
  year={2017},
  doi={10.7488/ds/2117},
}
```

View File

@ -1,19 +1,16 @@
-black>=22.8.0
 boto3>=1.24.86
-huggingface-hub>=0.10.0
+flake8>=5.0.4
+huggingface-hu>=0.10.0
 hydra-core>=1.2.0
 joblib>=1.2.0
 librosa>=0.9.2
-mlflow>=1.28.0
+mlflow>=1.29.0
 numpy>=1.23.3
-pesq==0.0.4
 protobuf>=3.19.6
-pystoi==0.3.3
-pytest-lazy-fixture>=0.6.3
 pytorch-lightning>=1.7.7
 scikit-learn>=1.1.2
 scipy>=1.9.1
-soundfile>=0.11.0
 torch>=1.12.1
-torch-audiomentations==0.11.0
 torchaudio>=0.12.1
 tqdm>=4.64.1

setup.cfg
View File

@ -1,104 +0,0 @@
# This file is used to configure your project.
# Read more about the various options under:
# http://setuptools.readthedocs.io/en/latest/setuptools.html#configuring-setup-using-setup-cfg-files
[metadata]
name = mayavoz
description = Deep learning for speech enhancement
author = Shahul Ess
author-email = shahules786@gmail.com
license = mit
long-description = file: README.md
long-description-content-type = text/markdown; charset=UTF-8; variant=GFM
# Change if running only on Windows, Mac or Linux (comma-separated)
platforms = Linux, Mac
# Add here all kinds of additional classifiers as defined under
# https://pypi.python.org/pypi?%3Aaction=list_classifiers
classifiers =
    Development Status :: 4 - Beta
    Programming Language :: Python
[options]
zip_safe = False
packages = find:
include_package_data = True
# DON'T CHANGE THE FOLLOWING LINE! IT WILL BE UPDATED BY PYSCAFFOLD!
setup_requires = setuptools
# Add here dependencies of your project (semicolon/line-separated), e.g.
# install_requires = numpy; scipy
# Require a specific Python version, e.g. Python 2.7 or >= 3.4
python_requires = >=3.8
[options.packages.find]
where = .
exclude =
    tests
[options.extras_require]
# Add here additional requirements for extra features, to install with:
# `pip install mayavoz[PDF]` like:
# PDF = ReportLab; RXP
# Add here test requirements (semicolon/line-separated)
testing =
    pytest>=7.1.3
    pytest-cov>=4.0.0
dev =
    pre-commit>=2.20.0
    black>=22.8.0
    flake8>=5.0.4
cli =
    hydra-core >=1.1,<=1.2
[options.entry_points]
console_scripts =
    mayavoz-train=mayavoz.cli.train:train
[test]
# py.test options when running `python setup.py test`
# addopts = --verbose
extras = True
[tool:pytest]
# Options for py.test:
# Specify command line options as you would do when invoking py.test directly.
# e.g. --cov-report html (or xml) for html/xml output or --junitxml junit.xml
# in order to write a coverage file that can be read by Jenkins.
addopts =
    --cov mayavoz --cov-report term-missing
    --verbose
norecursedirs =
    dist
    build
    .tox
testpaths = tests
[aliases]
dists = bdist_wheel
[bdist_wheel]
# Use this option if your package is pure-python
universal = 1
[build_sphinx]
source_dir = doc
build_dir = build/sphinx
[devpi:upload]
# Options for the devpi: PyPI server and packaging tool
# VCS export must be deactivated since we are using setuptools-scm
no-vcs = 1
formats = bdist_wheel
[flake8]
# Some sane defaults for the code style checker flake8
exclude =
    .tox
    build
    dist
    .eggs
[options.data_files]
. = requirements.txt
_ = version.txt

View File

@ -1,63 +0,0 @@
import os
import sys
from pathlib import Path

from pkg_resources import VersionConflict, require
from setuptools import find_packages, setup

with open("README.md") as f:
    long_description = f.read()

with open("requirements.txt") as f:
    requirements = f.read().splitlines()

try:
    require("setuptools>=38.3")
except VersionConflict:
    print("Error: version of setuptools is too old (<38.3)!")
    sys.exit(1)

ROOT_DIR = Path(__file__).parent.resolve()

# Creating the version file
with open("version.txt") as f:
    version = f.read()
version = version.strip()
sha = "Unknown"

if os.getenv("BUILD_VERSION"):
    version = os.getenv("BUILD_VERSION")
elif sha != "Unknown":
    version += "+" + sha[:7]
print("-- Building version " + version)

version_path = ROOT_DIR / "mayavoz" / "version.py"
with open(version_path, "w") as f:
    f.write("__version__ = '{}'\n".format(version))

if __name__ == "__main__":
    setup(
        name="mayavoz",
        namespace_packages=["mayavoz"],
        version=version,
        packages=find_packages(),
        install_requires=requirements,
        description="Deep learning toolkit for speech enhancement",
        long_description=long_description,
        long_description_content_type="text/markdown",
        author="Shahul Es",
        author_email="shahules786@gmail.com",
        url="",
        classifiers=[
            "Development Status :: 4 - Beta",
            "Intended Audience :: Science/Research",
            "License :: OSI Approved :: MIT License",
            "Natural Language :: English",
            "Programming Language :: Python :: 3.7",
            "Programming Language :: Python :: 3.8",
            "Topic :: Scientific/Engineering",
        ],
    )

setup.sh (new file)
View File

@ -0,0 +1,13 @@
#!/bin/bash
set -e
echo "Loading Anaconda Module"
module load anaconda
echo "Creating Virtual Environment"
conda env create -f environment.yml || conda env update -f environment.yml
source activate enhancer
echo "copying files"
# cp /scratch/$USER/TIMIT/.* /deep-transcriber

View File

@ -1,7 +1,7 @@
 import pytest
 import torch
-from mayavoz.loss import mean_absolute_error, mean_squared_error
+from enhancer.loss import mean_absolute_error, mean_squared_error

 loss_functions = [mean_absolute_error(), mean_squared_error()]

View File

@ -1,50 +0,0 @@
import torch

from mayavoz.models.complexnn.conv import ComplexConv2d, ComplexConvTranspose2d
from mayavoz.models.complexnn.rnn import ComplexLSTM
from mayavoz.models.complexnn.utils import ComplexBatchNorm2D


def test_complexconv2d():
    sample_input = torch.rand(1, 2, 256, 13)
    conv = ComplexConv2d(
        2, 32, kernel_size=(5, 2), stride=(2, 1), padding=(2, 1)
    )
    with torch.no_grad():
        out = conv(sample_input)
    assert out.shape == torch.Size([1, 32, 128, 13])


def test_complexconvtranspose2d():
    sample_input = torch.rand(1, 512, 4, 13)
    conv = ComplexConvTranspose2d(
        256 * 2,
        128 * 2,
        kernel_size=(5, 2),
        stride=(2, 1),
        padding=(2, 0),
        output_padding=(1, 0),
    )
    with torch.no_grad():
        out = conv(sample_input)
    assert out.shape == torch.Size([1, 256, 8, 14])


def test_complexlstm():
    sample_input = torch.rand(13, 2, 128)
    lstm = ComplexLSTM(128 * 2, 128 * 2, projection_size=512 * 2)
    with torch.no_grad():
        out = lstm(sample_input)
    assert out[0].shape == torch.Size([13, 1, 512])
    assert out[1].shape == torch.Size([13, 1, 512])


def test_complexbatchnorm2d():
    sample_input = torch.rand(1, 64, 64, 14)
    batchnorm = ComplexBatchNorm2D(num_features=64)
    with torch.no_grad():
        out = batchnorm(sample_input)
    assert out.size() == sample_input.size()

View File

@ -1,9 +1,9 @@
 import pytest
 import torch
-from mayavoz.data.dataset import MayaDataset
-from mayavoz.models import Demucs
-from mayavoz.utils.config import Files
+from enhancer.data.dataset import EnhancerDataset
+from enhancer.models import Demucs
+from enhancer.utils.config import Files

 @pytest.fixture
@ -15,9 +15,7 @@ def vctk_dataset():
         test_clean="clean_testset_wav",
         test_noisy="noisy_testset_wav",
     )
-    dataset = MayaDataset(
-        name="vctk", root_dir=root_dir, files=files, sampling_rate=16000
-    )
+    dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files)
     return dataset
@ -32,7 +30,7 @@ def test_forward(batch_size, samples):
     data = torch.rand(batch_size, 2, samples, requires_grad=False)
     with torch.no_grad():
-        with pytest.raises(ValueError):
+        with pytest.raises(TypeError):
             _ = model(data)

View File

@ -1,45 +0,0 @@
import pytest
import torch

from mayavoz.data.dataset import MayaDataset
from mayavoz.models.dccrn import DCCRN
from mayavoz.utils.config import Files


@pytest.fixture
def vctk_dataset():
    root_dir = "tests/data/vctk"
    files = Files(
        train_clean="clean_testset_wav",
        train_noisy="noisy_testset_wav",
        test_clean="clean_testset_wav",
        test_noisy="noisy_testset_wav",
    )
    dataset = MayaDataset(
        name="vctk", root_dir=root_dir, files=files, sampling_rate=16000
    )
    return dataset


@pytest.mark.parametrize("batch_size,samples", [(1, 1000)])
def test_forward(batch_size, samples):
    model = DCCRN()
    model.eval()

    data = torch.rand(batch_size, 1, samples, requires_grad=False)
    with torch.no_grad():
        _ = model(data)

    data = torch.rand(batch_size, 2, samples, requires_grad=False)
    with torch.no_grad():
        with pytest.raises(ValueError):
            _ = model(data)


@pytest.mark.parametrize(
    "dataset,channels,loss",
    [(pytest.lazy_fixture("vctk_dataset"), 1, ["mae", "mse"])],
)
def test_demucs_init(dataset, channels, loss):
    with torch.no_grad():
        _ = DCCRN(num_channels=channels, dataset=dataset, loss=loss)

Some files were not shown because too many files have changed in this diff.