From 4a2865ff03ce37902aeb2c3467ab453a17c68b9e Mon Sep 17 00:00:00 2001 From: shahules786 Date: Mon, 14 Nov 2022 10:48:31 +0530 Subject: [PATCH 1/7] negate si-snr --- enhancer/loss.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/enhancer/loss.py b/enhancer/loss.py index 75527bb..9955efd 100644 --- a/enhancer/loss.py +++ b/enhancer/loss.py @@ -192,7 +192,7 @@ class Si_snr(nn.Module): super().__init__() self.loss_fun = ScaleInvariantSignalNoiseRatio(**kwargs) - self.higher_better = True + self.higher_better = False self.name = "si_snr" def forward(self, prediction: torch.Tensor, target: torch.Tensor): @@ -203,7 +203,7 @@ class Si_snr(nn.Module): got {prediction.size()} and {target.size()} instead""" ) - return self.loss_fun(prediction, target) + return -1 * self.loss_fun(prediction, target) LOSS_MAP = { From 7838e744a98a840aa14443e22208a804bf18835e Mon Sep 17 00:00:00 2001 From: shahules786 Date: Mon, 14 Nov 2022 11:37:26 +0530 Subject: [PATCH 2/7] rename package --- .flake8 | 2 +- .github/workflows/ci.yaml | 8 +- .pre-commit-config.yaml | 1 + .../cli/train_config/dataset/Vctk_local.yaml | 13 -- enhancer/data/__init__.py | 1 - enhancer/models/__init__.py | 3 - enhancer/models/complexnn/__init__.py | 5 - enhancer/utils/__init__.py | 3 - environment.yml | 2 +- {enhancer => mayavoz}/__init__.py | 0 {enhancer => mayavoz}/cli/train.py | 0 .../cli/train_config/config.yaml | 0 .../cli/train_config/dataset/DNS-2020.yaml | 2 +- .../cli/train_config/dataset/Vctk.yaml | 2 +- .../train_config/hyperparameters/default.yaml | 0 .../cli/train_config/mlflow/experiment.yaml | 2 +- .../cli/train_config/model/DCCRN.yaml | 2 +- .../cli/train_config/model/Demucs.yaml | 2 +- .../cli/train_config/model/WaveUnet.yaml | 2 +- .../cli/train_config/optimizer/Adam.yaml | 0 .../cli/train_config/trainer/default.yaml | 0 .../cli/train_config/trainer/fastrun_dev.yaml | 0 mayavoz/data/__init__.py | 1 + {enhancer => mayavoz}/data/dataset.py | 12 +- {enhancer => mayavoz}/data/fileprocessor.py | 0 {enhancer => mayavoz}/inference.py | 2 +- {enhancer => mayavoz}/loss.py | 0 mayavoz/models/__init__.py | 3 + mayavoz/models/complexnn/__init__.py | 5 + .../models/complexnn/conv.py | 0 {enhancer => mayavoz}/models/complexnn/rnn.py | 0 .../models/complexnn/utils.py | 0 {enhancer => mayavoz}/models/dccrn.py | 12 +- {enhancer => mayavoz}/models/demucs.py | 8 +- {enhancer => mayavoz}/models/model.py | 26 ++-- {enhancer => mayavoz}/models/waveunet.py | 4 +- mayavoz/utils/__init__.py | 3 + {enhancer => mayavoz}/utils/config.py | 0 {enhancer => mayavoz}/utils/io.py | 0 {enhancer => mayavoz}/utils/random.py | 0 {enhancer => mayavoz}/utils/transforms.py | 0 {enhancer => mayavoz}/utils/utils.py | 2 +- {enhancer => mayavoz}/version.py | 0 notebooks/Custom_model_training.ipynb | 4 +- notebooks/Getting_started.ipynb | 6 +- recipes/DNS/DNS-2020/cli/train.py | 120 ++++++++++++++++++ .../DNS/DNS-2020/cli/train_config/config.yaml | 7 + .../cli/train_config/dataset/DNS-2020.yaml | 12 ++ .../cli/train_config/dataset/Vctk.yaml | 13 ++ .../train_config/hyperparameters/default.yaml | 7 + .../cli/train_config/mlflow/experiment.yaml | 2 + .../cli/train_config/model/DCCRN.yaml | 25 ++++ .../cli/train_config/model/Demucs.yaml | 16 +++ .../cli/train_config/model/WaveUnet.yaml | 5 + .../cli/train_config/optimizer/Adam.yaml | 6 + .../cli/train_config/trainer/default.yaml | 46 +++++++ .../cli/train_config/trainer/fastrun_dev.yaml | 2 + .../Demucs/train_config/dataset/Vctk.yaml | 2 +- .../train_config/mlflow/experiment.yaml | 2 +- .../Demucs/train_config/model/Demucs.yaml | 2 +- .../WaveUnet/train_config/dataset/Vctk.yaml | 2 +- .../train_config/mlflow/experiment.yaml | 2 +- .../WaveUnet/train_config/model/WaveUnet.yaml | 2 +- .../cli/train_config/dataset/DNS-2020.yaml | 2 +- .../28spk/cli/train_config/dataset/Vctk.yaml | 2 +- .../cli/train_config/mlflow/experiment.yaml | 2 +- .../28spk/cli/train_config/model/DCCRN.yaml | 2 +- .../28spk/cli/train_config/model/Demucs.yaml | 2 +- .../cli/train_config/model/WaveUnet.yaml | 2 +- setup.cfg | 6 +- setup.py | 6 +- tests/loss_function_test.py | 2 +- tests/models/complexnn_test.py | 6 +- tests/models/demucs_test.py | 6 +- tests/models/test_dccrn.py | 6 +- tests/models/test_waveunet.py | 6 +- tests/test_inference.py | 2 +- tests/transforms_test.py | 2 +- tests/utils_test.py | 4 +- 79 files changed, 358 insertions(+), 111 deletions(-) delete mode 100644 enhancer/cli/train_config/dataset/Vctk_local.yaml delete mode 100644 enhancer/data/__init__.py delete mode 100644 enhancer/models/__init__.py delete mode 100644 enhancer/models/complexnn/__init__.py delete mode 100644 enhancer/utils/__init__.py rename {enhancer => mayavoz}/__init__.py (100%) rename {enhancer => mayavoz}/cli/train.py (100%) rename {enhancer => mayavoz}/cli/train_config/config.yaml (100%) rename {enhancer => mayavoz}/cli/train_config/dataset/DNS-2020.yaml (85%) rename {enhancer => mayavoz}/cli/train_config/dataset/Vctk.yaml (85%) rename {enhancer => mayavoz}/cli/train_config/hyperparameters/default.yaml (100%) rename {enhancer => mayavoz}/cli/train_config/mlflow/experiment.yaml (59%) rename {enhancer => mayavoz}/cli/train_config/model/DCCRN.yaml (90%) rename {enhancer => mayavoz}/cli/train_config/model/Demucs.yaml (84%) rename {enhancer => mayavoz}/cli/train_config/model/WaveUnet.yaml (63%) rename {enhancer => mayavoz}/cli/train_config/optimizer/Adam.yaml (100%) rename {enhancer => mayavoz}/cli/train_config/trainer/default.yaml (100%) rename {enhancer => mayavoz}/cli/train_config/trainer/fastrun_dev.yaml (100%) create mode 100644 mayavoz/data/__init__.py rename {enhancer => mayavoz}/data/dataset.py (97%) rename {enhancer => mayavoz}/data/fileprocessor.py (100%) rename {enhancer => mayavoz}/inference.py (99%) rename {enhancer => mayavoz}/loss.py (100%) create mode 100644 mayavoz/models/__init__.py create mode 100644 mayavoz/models/complexnn/__init__.py rename {enhancer => mayavoz}/models/complexnn/conv.py (100%) rename {enhancer => mayavoz}/models/complexnn/rnn.py (100%) rename {enhancer => mayavoz}/models/complexnn/utils.py (100%) rename {enhancer => mayavoz}/models/dccrn.py (97%) rename {enhancer => mayavoz}/models/demucs.py (97%) rename {enhancer => mayavoz}/models/model.py (94%) rename {enhancer => mayavoz}/models/waveunet.py (98%) create mode 100644 mayavoz/utils/__init__.py rename {enhancer => mayavoz}/utils/config.py (100%) rename {enhancer => mayavoz}/utils/io.py (100%) rename {enhancer => mayavoz}/utils/random.py (100%) rename {enhancer => mayavoz}/utils/transforms.py (100%) rename {enhancer => mayavoz}/utils/utils.py (93%) rename {enhancer => mayavoz}/version.py (100%) create mode 100644 recipes/DNS/DNS-2020/cli/train.py create mode 100644 recipes/DNS/DNS-2020/cli/train_config/config.yaml create mode 100644 recipes/DNS/DNS-2020/cli/train_config/dataset/DNS-2020.yaml create mode 100644 recipes/DNS/DNS-2020/cli/train_config/dataset/Vctk.yaml create mode 100644 recipes/DNS/DNS-2020/cli/train_config/hyperparameters/default.yaml create mode 100644 recipes/DNS/DNS-2020/cli/train_config/mlflow/experiment.yaml create mode 100644 recipes/DNS/DNS-2020/cli/train_config/model/DCCRN.yaml create mode 100644 recipes/DNS/DNS-2020/cli/train_config/model/Demucs.yaml create mode 100644 recipes/DNS/DNS-2020/cli/train_config/model/WaveUnet.yaml create mode 100644 recipes/DNS/DNS-2020/cli/train_config/optimizer/Adam.yaml create mode 100644 recipes/DNS/DNS-2020/cli/train_config/trainer/default.yaml create mode 100644 recipes/DNS/DNS-2020/cli/train_config/trainer/fastrun_dev.yaml diff --git a/.flake8 b/.flake8 index abbbc73..d738193 100644 --- a/.flake8 +++ b/.flake8 @@ -1,5 +1,5 @@ [flake8] -per-file-ignores = __init__.py:F401 +per-file-ignores = "mayavoz/model/__init__.py:F401" ignore = E203, E266, E501, W503 # line length is intentionally set to 80 here because black uses Bugbear # See https://github.com/psf/black/blob/master/README.md#line-length for more details diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 4c64745..b42bdc5 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -1,7 +1,7 @@ # This workflow will install Python dependencies, run tests and lint with a variety of Python versions # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions -name: Enhancer +name: mayavoz on: push: @@ -40,12 +40,12 @@ jobs: sudo apt-get install libsndfile1 pip install -r requirements.txt pip install black pytest-cov - - name: Install enhancer + - name: Install mayavoz run: | pip install -e .[dev,testing] - name: Run black run: - black --check . --exclude enhancer/version.py + black --check . --exclude mayavoz/version.py - name: Test with pytest run: - pytest tests --cov=enhancer/ + pytest tests --cov=mayavoz/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 807429c..b0a3da3 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -23,6 +23,7 @@ repos: hooks: - id: flake8 args: ['--ignore=E203,E501,F811,E712,W503'] + exclude: __init__.py # Formatting, Whitespace, etc - repo: https://github.com/pre-commit/pre-commit-hooks diff --git a/enhancer/cli/train_config/dataset/Vctk_local.yaml b/enhancer/cli/train_config/dataset/Vctk_local.yaml deleted file mode 100644 index ba44597..0000000 --- a/enhancer/cli/train_config/dataset/Vctk_local.yaml +++ /dev/null @@ -1,13 +0,0 @@ -_target_: enhancer.data.dataset.EnhancerDataset -name : vctk -root_dir : /Users/shahules/Myprojects/enhancer/datasets/vctk -duration : 1.0 -sampling_rate: 16000 -batch_size: 64 -num_workers : 0 - -files: - train_clean : clean_testset_wav - test_clean : clean_testset_wav - train_noisy : noisy_testset_wav - test_noisy : noisy_testset_wav diff --git a/enhancer/data/__init__.py b/enhancer/data/__init__.py deleted file mode 100644 index 7efd946..0000000 --- a/enhancer/data/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from enhancer.data.dataset import EnhancerDataset diff --git a/enhancer/models/__init__.py b/enhancer/models/__init__.py deleted file mode 100644 index 2d97568..0000000 --- a/enhancer/models/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from enhancer.models.demucs import Demucs -from enhancer.models.model import Model -from enhancer.models.waveunet import WaveUnet diff --git a/enhancer/models/complexnn/__init__.py b/enhancer/models/complexnn/__init__.py deleted file mode 100644 index 918a261..0000000 --- a/enhancer/models/complexnn/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from enhancer.models.complexnn.conv import ComplexConv2d # noqa -from enhancer.models.complexnn.conv import ComplexConvTranspose2d # noqa -from enhancer.models.complexnn.rnn import ComplexLSTM # noqa -from enhancer.models.complexnn.utils import ComplexBatchNorm2D # noqa -from enhancer.models.complexnn.utils import ComplexRelu # noqa diff --git a/enhancer/utils/__init__.py b/enhancer/utils/__init__.py deleted file mode 100644 index de0db9f..0000000 --- a/enhancer/utils/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from enhancer.utils.config import Files -from enhancer.utils.io import Audio -from enhancer.utils.utils import check_files diff --git a/environment.yml b/environment.yml index 8da22e1..2e7f6cf 100644 --- a/environment.yml +++ b/environment.yml @@ -1,4 +1,4 @@ -name: enhancer +name: mayavoz dependencies: - pip=21.0.1 diff --git a/enhancer/__init__.py b/mayavoz/__init__.py similarity index 100% rename from enhancer/__init__.py rename to mayavoz/__init__.py diff --git a/enhancer/cli/train.py b/mayavoz/cli/train.py similarity index 100% rename from enhancer/cli/train.py rename to mayavoz/cli/train.py diff --git a/enhancer/cli/train_config/config.yaml b/mayavoz/cli/train_config/config.yaml similarity index 100% rename from enhancer/cli/train_config/config.yaml rename to mayavoz/cli/train_config/config.yaml diff --git a/enhancer/cli/train_config/dataset/DNS-2020.yaml b/mayavoz/cli/train_config/dataset/DNS-2020.yaml similarity index 85% rename from enhancer/cli/train_config/dataset/DNS-2020.yaml rename to mayavoz/cli/train_config/dataset/DNS-2020.yaml index 09a14fb..520efc9 100644 --- a/enhancer/cli/train_config/dataset/DNS-2020.yaml +++ b/mayavoz/cli/train_config/dataset/DNS-2020.yaml @@ -1,4 +1,4 @@ -_target_: enhancer.data.dataset.EnhancerDataset +_target_: mayavoz.data.dataset.EnhancerDataset root_dir : /Users/shahules/Myprojects/MS-SNSD name : dns-2020 duration : 2.0 diff --git a/enhancer/cli/train_config/dataset/Vctk.yaml b/mayavoz/cli/train_config/dataset/Vctk.yaml similarity index 85% rename from enhancer/cli/train_config/dataset/Vctk.yaml rename to mayavoz/cli/train_config/dataset/Vctk.yaml index c33d29a..f30a835 100644 --- a/enhancer/cli/train_config/dataset/Vctk.yaml +++ b/mayavoz/cli/train_config/dataset/Vctk.yaml @@ -1,4 +1,4 @@ -_target_: enhancer.data.dataset.EnhancerDataset +_target_: mayavoz.data.dataset.EnhancerDataset name : vctk root_dir : /scratch/c.sistc3/DS_10283_2791 duration : 4.5 diff --git a/enhancer/cli/train_config/hyperparameters/default.yaml b/mayavoz/cli/train_config/hyperparameters/default.yaml similarity index 100% rename from enhancer/cli/train_config/hyperparameters/default.yaml rename to mayavoz/cli/train_config/hyperparameters/default.yaml diff --git a/enhancer/cli/train_config/mlflow/experiment.yaml b/mayavoz/cli/train_config/mlflow/experiment.yaml similarity index 59% rename from enhancer/cli/train_config/mlflow/experiment.yaml rename to mayavoz/cli/train_config/mlflow/experiment.yaml index d597333..9173e38 100644 --- a/enhancer/cli/train_config/mlflow/experiment.yaml +++ b/mayavoz/cli/train_config/mlflow/experiment.yaml @@ -1,2 +1,2 @@ -experiment_name : shahules/enhancer +experiment_name : shahules/mayavoz run_name : Demucs + Vtck with stride + augmentations diff --git a/enhancer/cli/train_config/model/DCCRN.yaml b/mayavoz/cli/train_config/model/DCCRN.yaml similarity index 90% rename from enhancer/cli/train_config/model/DCCRN.yaml rename to mayavoz/cli/train_config/model/DCCRN.yaml index 3190391..d2ffcf1 100644 --- a/enhancer/cli/train_config/model/DCCRN.yaml +++ b/mayavoz/cli/train_config/model/DCCRN.yaml @@ -1,4 +1,4 @@ -_target_: enhancer.models.dccrn.DCCRN +_target_: mayavoz.models.dccrn.DCCRN num_channels: 1 sampling_rate : 16000 complex_lstm : True diff --git a/enhancer/cli/train_config/model/Demucs.yaml b/mayavoz/cli/train_config/model/Demucs.yaml similarity index 84% rename from enhancer/cli/train_config/model/Demucs.yaml rename to mayavoz/cli/train_config/model/Demucs.yaml index 513e603..f8d2eb8 100644 --- a/enhancer/cli/train_config/model/Demucs.yaml +++ b/mayavoz/cli/train_config/model/Demucs.yaml @@ -1,4 +1,4 @@ -_target_: enhancer.models.demucs.Demucs +_target_: mayavoz.models.demucs.Demucs num_channels: 1 resample: 4 sampling_rate : 16000 diff --git a/enhancer/cli/train_config/model/WaveUnet.yaml b/mayavoz/cli/train_config/model/WaveUnet.yaml similarity index 63% rename from enhancer/cli/train_config/model/WaveUnet.yaml rename to mayavoz/cli/train_config/model/WaveUnet.yaml index 29d48c7..7c17448 100644 --- a/enhancer/cli/train_config/model/WaveUnet.yaml +++ b/mayavoz/cli/train_config/model/WaveUnet.yaml @@ -1,4 +1,4 @@ -_target_: enhancer.models.waveunet.WaveUnet +_target_: mayavoz.models.waveunet.WaveUnet num_channels : 1 depth : 9 initial_output_channels: 24 diff --git a/enhancer/cli/train_config/optimizer/Adam.yaml b/mayavoz/cli/train_config/optimizer/Adam.yaml similarity index 100% rename from enhancer/cli/train_config/optimizer/Adam.yaml rename to mayavoz/cli/train_config/optimizer/Adam.yaml diff --git a/enhancer/cli/train_config/trainer/default.yaml b/mayavoz/cli/train_config/trainer/default.yaml similarity index 100% rename from enhancer/cli/train_config/trainer/default.yaml rename to mayavoz/cli/train_config/trainer/default.yaml diff --git a/enhancer/cli/train_config/trainer/fastrun_dev.yaml b/mayavoz/cli/train_config/trainer/fastrun_dev.yaml similarity index 100% rename from enhancer/cli/train_config/trainer/fastrun_dev.yaml rename to mayavoz/cli/train_config/trainer/fastrun_dev.yaml diff --git a/mayavoz/data/__init__.py b/mayavoz/data/__init__.py new file mode 100644 index 0000000..c7663d7 --- /dev/null +++ b/mayavoz/data/__init__.py @@ -0,0 +1 @@ +from mayavoz.data.dataset import EnhancerDataset diff --git a/enhancer/data/dataset.py b/mayavoz/data/dataset.py similarity index 97% rename from enhancer/data/dataset.py rename to mayavoz/data/dataset.py index 284dfdb..5296499 100644 --- a/enhancer/data/dataset.py +++ b/mayavoz/data/dataset.py @@ -11,11 +11,11 @@ import torch.nn.functional as F from torch.utils.data import DataLoader, Dataset, RandomSampler from torch_audiomentations import Compose -from enhancer.data.fileprocessor import Fileprocessor -from enhancer.utils import check_files -from enhancer.utils.config import Files -from enhancer.utils.io import Audio -from enhancer.utils.random import create_unique_rng +from mayavoz.data.fileprocessor import Fileprocessor +from mayavoz.utils import check_files +from mayavoz.utils.config import Files +from mayavoz.utils.io import Audio +from mayavoz.utils.random import create_unique_rng LARGE_NUM = 2147483647 @@ -258,7 +258,7 @@ class EnhancerDataset(TaskDataset): root directory of the dataset containing clean/noisy folders files : Files dataclass containing train_clean, train_noisy, test_clean, test_noisy - folder names (refer enhancer.utils.Files dataclass) + folder names (refer mayavoz.utils.Files dataclass) min_valid_minutes: float minimum validation split size time in minutes algorithm randomly select n speakers (>=min_valid_minutes) from train data to form validation data. diff --git a/enhancer/data/fileprocessor.py b/mayavoz/data/fileprocessor.py similarity index 100% rename from enhancer/data/fileprocessor.py rename to mayavoz/data/fileprocessor.py diff --git a/enhancer/inference.py b/mayavoz/inference.py similarity index 99% rename from enhancer/inference.py rename to mayavoz/inference.py index d9282fd..112f619 100644 --- a/enhancer/inference.py +++ b/mayavoz/inference.py @@ -8,7 +8,7 @@ from librosa import load as load_audio from scipy.io import wavfile from scipy.signal import get_window -from enhancer.utils import Audio +from mayavoz.utils import Audio class Inference: diff --git a/enhancer/loss.py b/mayavoz/loss.py similarity index 100% rename from enhancer/loss.py rename to mayavoz/loss.py diff --git a/mayavoz/models/__init__.py b/mayavoz/models/__init__.py new file mode 100644 index 0000000..9cf2b9b --- /dev/null +++ b/mayavoz/models/__init__.py @@ -0,0 +1,3 @@ +from mayavoz.models.demucs import Demucs +from mayavoz.models.model import Model +from mayavoz.models.waveunet import WaveUnet diff --git a/mayavoz/models/complexnn/__init__.py b/mayavoz/models/complexnn/__init__.py new file mode 100644 index 0000000..d304e81 --- /dev/null +++ b/mayavoz/models/complexnn/__init__.py @@ -0,0 +1,5 @@ +from mayavoz.models.complexnn.conv import ComplexConv2d # noqa +from mayavoz.models.complexnn.conv import ComplexConvTranspose2d # noqa +from mayavoz.models.complexnn.rnn import ComplexLSTM # noqa +from mayavoz.models.complexnn.utils import ComplexBatchNorm2D # noqa +from mayavoz.models.complexnn.utils import ComplexRelu # noqa diff --git a/enhancer/models/complexnn/conv.py b/mayavoz/models/complexnn/conv.py similarity index 100% rename from enhancer/models/complexnn/conv.py rename to mayavoz/models/complexnn/conv.py diff --git a/enhancer/models/complexnn/rnn.py b/mayavoz/models/complexnn/rnn.py similarity index 100% rename from enhancer/models/complexnn/rnn.py rename to mayavoz/models/complexnn/rnn.py diff --git a/enhancer/models/complexnn/utils.py b/mayavoz/models/complexnn/utils.py similarity index 100% rename from enhancer/models/complexnn/utils.py rename to mayavoz/models/complexnn/utils.py diff --git a/enhancer/models/dccrn.py b/mayavoz/models/dccrn.py similarity index 97% rename from enhancer/models/dccrn.py rename to mayavoz/models/dccrn.py index 7b1e5b1..372696c 100644 --- a/enhancer/models/dccrn.py +++ b/mayavoz/models/dccrn.py @@ -5,18 +5,18 @@ import torch import torch.nn.functional as F from torch import nn -from enhancer.data import EnhancerDataset -from enhancer.models import Model -from enhancer.models.complexnn import ( +from mayavoz.data import EnhancerDataset +from mayavoz.models import Model +from mayavoz.models.complexnn import ( ComplexBatchNorm2D, ComplexConv2d, ComplexConvTranspose2d, ComplexLSTM, ComplexRelu, ) -from enhancer.models.complexnn.utils import complex_cat -from enhancer.utils.transforms import ConviSTFT, ConvSTFT -from enhancer.utils.utils import merge_dict +from mayavoz.models.complexnn.utils import complex_cat +from mayavoz.utils.transforms import ConviSTFT, ConvSTFT +from mayavoz.utils.utils import merge_dict class DCCRN_ENCODER(nn.Module): diff --git a/enhancer/models/demucs.py b/mayavoz/models/demucs.py similarity index 97% rename from enhancer/models/demucs.py rename to mayavoz/models/demucs.py index fafb84e..a5e3147 100644 --- a/enhancer/models/demucs.py +++ b/mayavoz/models/demucs.py @@ -5,10 +5,10 @@ from typing import List, Optional, Union import torch.nn.functional as F from torch import nn -from enhancer.data.dataset import EnhancerDataset -from enhancer.models.model import Model -from enhancer.utils.io import Audio as audio -from enhancer.utils.utils import merge_dict +from mayavoz.data.dataset import EnhancerDataset +from mayavoz.models.model import Model +from mayavoz.utils.io import Audio as audio +from mayavoz.utils.utils import merge_dict class DemucsLSTM(nn.Module): diff --git a/enhancer/models/model.py b/mayavoz/models/model.py similarity index 94% rename from enhancer/models/model.py rename to mayavoz/models/model.py index 9f285d3..2957e5b 100644 --- a/enhancer/models/model.py +++ b/mayavoz/models/model.py @@ -13,14 +13,14 @@ from pytorch_lightning.utilities.cloud_io import load as pl_load from torch import nn from torch.optim import Adam -from enhancer.data.dataset import EnhancerDataset -from enhancer.inference import Inference -from enhancer.loss import LOSS_MAP, LossWrapper -from enhancer.version import __version__ +from mayavoz.data.dataset import EnhancerDataset +from mayavoz.inference import Inference +from mayavoz.loss import LOSS_MAP, LossWrapper +from mayavoz.version import __version__ CACHE_DIR = os.getenv( "ENHANCER_CACHE", - os.path.expanduser("~/.cache/torch/enhancer"), + os.path.expanduser("~/.cache/torch/mayavoz"), ) HF_TORCH_WEIGHTS = "pytorch_model.ckpt" DEFAULT_DEVICE = "cpu" @@ -37,7 +37,7 @@ class Model(pl.LightningModule): lr: float, optional learning rate for model training dataset: EnhancerDataset, optional - Enhancer dataset used for training/validation + mayavoz dataset used for training/validation duration: float, optional duration used for training/inference loss : string or List of strings or custom loss (nn.Module), default to "mse" @@ -56,9 +56,7 @@ class Model(pl.LightningModule): metric: Union[str, List, Any] = "mse", ): super().__init__() - assert ( - num_channels == 1 - ), "Enhancer only support for mono channel models" + assert num_channels == 1, "mayavoz only support for mono channel models" self.dataset = dataset self.save_hyperparameters( "num_channels", "sampling_rate", "lr", "loss", "metric", "duration" @@ -235,8 +233,8 @@ class Model(pl.LightningModule): def on_save_checkpoint(self, checkpoint): - checkpoint["enhancer"] = { - "version": {"enhancer": __version__, "pytorch": torch.__version__}, + checkpoint["mayavoz"] = { + "version": {"mayavoz": __version__, "pytorch": torch.__version__}, "architecture": { "module": self.__class__.__module__, "class": self.__class__.__name__, @@ -319,7 +317,7 @@ class Model(pl.LightningModule): ) model_path_pl = cached_download( url=url, - library_name="enhancer", + library_name="mayavoz", library_version=__version__, cache_dir=cached_dir, use_auth_token=use_auth_token, @@ -329,8 +327,8 @@ class Model(pl.LightningModule): map_location = torch.device(DEFAULT_DEVICE) loaded_checkpoint = pl_load(model_path_pl, map_location) - module_name = loaded_checkpoint["enhancer"]["architecture"]["module"] - class_name = loaded_checkpoint["enhancer"]["architecture"]["class"] + module_name = loaded_checkpoint["mayavoz"]["architecture"]["module"] + class_name = loaded_checkpoint["mayavoz"]["architecture"]["class"] module = import_module(module_name) Klass = getattr(module, class_name) diff --git a/enhancer/models/waveunet.py b/mayavoz/models/waveunet.py similarity index 98% rename from enhancer/models/waveunet.py rename to mayavoz/models/waveunet.py index ea5646a..ead2146 100644 --- a/enhancer/models/waveunet.py +++ b/mayavoz/models/waveunet.py @@ -5,8 +5,8 @@ import torch import torch.nn as nn import torch.nn.functional as F -from enhancer.data.dataset import EnhancerDataset -from enhancer.models.model import Model +from mayavoz.data.dataset import EnhancerDataset +from mayavoz.models.model import Model class WavenetDecoder(nn.Module): diff --git a/mayavoz/utils/__init__.py b/mayavoz/utils/__init__.py new file mode 100644 index 0000000..cbd3785 --- /dev/null +++ b/mayavoz/utils/__init__.py @@ -0,0 +1,3 @@ +from mayavoz.utils.config import Files +from mayavoz.utils.io import Audio +from mayavoz.utils.utils import check_files diff --git a/enhancer/utils/config.py b/mayavoz/utils/config.py similarity index 100% rename from enhancer/utils/config.py rename to mayavoz/utils/config.py diff --git a/enhancer/utils/io.py b/mayavoz/utils/io.py similarity index 100% rename from enhancer/utils/io.py rename to mayavoz/utils/io.py diff --git a/enhancer/utils/random.py b/mayavoz/utils/random.py similarity index 100% rename from enhancer/utils/random.py rename to mayavoz/utils/random.py diff --git a/enhancer/utils/transforms.py b/mayavoz/utils/transforms.py similarity index 100% rename from enhancer/utils/transforms.py rename to mayavoz/utils/transforms.py diff --git a/enhancer/utils/utils.py b/mayavoz/utils/utils.py similarity index 93% rename from enhancer/utils/utils.py rename to mayavoz/utils/utils.py index ad45139..17730a4 100644 --- a/enhancer/utils/utils.py +++ b/mayavoz/utils/utils.py @@ -1,7 +1,7 @@ import os from typing import Optional -from enhancer.utils.config import Files +from mayavoz.utils.config import Files def check_files(root_dir: str, files: Files): diff --git a/enhancer/version.py b/mayavoz/version.py similarity index 100% rename from enhancer/version.py rename to mayavoz/version.py diff --git a/notebooks/Custom_model_training.ipynb b/notebooks/Custom_model_training.ipynb index 7c963c2..2e5ed67 100644 --- a/notebooks/Custom_model_training.ipynb +++ b/notebooks/Custom_model_training.ipynb @@ -316,9 +316,9 @@ ], "metadata": { "kernelspec": { - "display_name": "enhancer", + "display_name": "mayavoz", "language": "python", - "name": "enhancer" + "name": "mayavoz" }, "language_info": { "codemirror_mode": { diff --git a/notebooks/Getting_started.ipynb b/notebooks/Getting_started.ipynb index 8e9506c..c9a47dd 100644 --- a/notebooks/Getting_started.ipynb +++ b/notebooks/Getting_started.ipynb @@ -374,7 +374,7 @@ "```\n", "\n", "```yaml\n", - "_target_: enhancer.models.demucs.Demucs\n", + "_target_: mayavoz.models.demucs.Demucs\n", "num_channels: 1\n", "resample: 4\n", "sampling_rate : 16000\n", @@ -405,9 +405,9 @@ ], "metadata": { "kernelspec": { - "display_name": "enhancer", + "display_name": "mayavoz", "language": "python", - "name": "enhancer" + "name": "mayavoz" }, "language_info": { "codemirror_mode": { diff --git a/recipes/DNS/DNS-2020/cli/train.py b/recipes/DNS/DNS-2020/cli/train.py new file mode 100644 index 0000000..c00c024 --- /dev/null +++ b/recipes/DNS/DNS-2020/cli/train.py @@ -0,0 +1,120 @@ +import os +from types import MethodType + +import hydra +from hydra.utils import instantiate +from omegaconf import DictConfig, OmegaConf +from pytorch_lightning.callbacks import ( + EarlyStopping, + LearningRateMonitor, + ModelCheckpoint, +) +from pytorch_lightning.loggers import MLFlowLogger +from torch.optim.lr_scheduler import ReduceLROnPlateau + +# from torch_audiomentations import Compose, Shift + +os.environ["HYDRA_FULL_ERROR"] = "1" +JOB_ID = os.environ.get("SLURM_JOBID", "0") + + +@hydra.main(config_path="train_config", config_name="config") +def main(config: DictConfig): + + OmegaConf.save(config, "config_log.yaml") + + callbacks = [] + logger = MLFlowLogger( + experiment_name=config.mlflow.experiment_name, + run_name=config.mlflow.run_name, + tags={"JOB_ID": JOB_ID}, + ) + + parameters = config.hyperparameters + # apply_augmentations = Compose( + # [ + # Shift(min_shift=0.5, max_shift=1.0, shift_unit="seconds", p=0.5), + # ] + # ) + + dataset = instantiate(config.dataset, augmentations=None) + model = instantiate( + config.model, + dataset=dataset, + lr=parameters.get("lr"), + loss=parameters.get("loss"), + metric=parameters.get("metric"), + ) + + direction = model.valid_monitor + checkpoint = ModelCheckpoint( + dirpath="./model", + filename=f"model_{JOB_ID}", + monitor="valid_loss", + verbose=False, + mode=direction, + every_n_epochs=1, + ) + callbacks.append(checkpoint) + callbacks.append(LearningRateMonitor(logging_interval="epoch")) + + if parameters.get("Early_stop", False): + early_stopping = EarlyStopping( + monitor="val_loss", + mode=direction, + min_delta=0.0, + patience=parameters.get("EarlyStopping_patience", 10), + strict=True, + verbose=False, + ) + callbacks.append(early_stopping) + + def configure_optimizers(self): + optimizer = instantiate( + config.optimizer, + lr=parameters.get("lr"), + params=self.parameters(), + ) + scheduler = ReduceLROnPlateau( + optimizer=optimizer, + mode=direction, + factor=parameters.get("ReduceLr_factor", 0.1), + verbose=True, + min_lr=parameters.get("min_lr", 1e-6), + patience=parameters.get("ReduceLr_patience", 3), + ) + return { + "optimizer": optimizer, + "lr_scheduler": scheduler, + "monitor": f'valid_{parameters.get("ReduceLr_monitor", "loss")}', + } + + model.configure_optimizers = MethodType(configure_optimizers, model) + + trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks) + trainer.fit(model) + trainer.test(model) + + logger.experiment.log_artifact( + logger.run_id, f"{trainer.default_root_dir}/config_log.yaml" + ) + + saved_location = os.path.join( + trainer.default_root_dir, "model", f"model_{JOB_ID}.ckpt" + ) + if os.path.isfile(saved_location): + logger.experiment.log_artifact(logger.run_id, saved_location) + logger.experiment.log_param( + logger.run_id, + "num_train_steps_per_epoch", + dataset.train__len__() / dataset.batch_size, + ) + logger.experiment.log_param( + logger.run_id, + "num_valid_steps_per_epoch", + dataset.val__len__() / dataset.batch_size, + ) + + +if __name__ == "__main__": + main() diff --git a/recipes/DNS/DNS-2020/cli/train_config/config.yaml b/recipes/DNS/DNS-2020/cli/train_config/config.yaml new file mode 100644 index 0000000..8d0ab14 --- /dev/null +++ b/recipes/DNS/DNS-2020/cli/train_config/config.yaml @@ -0,0 +1,7 @@ +defaults: + - model : Demucs + - dataset : Vctk + - optimizer : Adam + - hyperparameters : default + - trainer : default + - mlflow : experiment diff --git a/recipes/DNS/DNS-2020/cli/train_config/dataset/DNS-2020.yaml b/recipes/DNS/DNS-2020/cli/train_config/dataset/DNS-2020.yaml new file mode 100644 index 0000000..520efc9 --- /dev/null +++ b/recipes/DNS/DNS-2020/cli/train_config/dataset/DNS-2020.yaml @@ -0,0 +1,12 @@ +_target_: mayavoz.data.dataset.EnhancerDataset +root_dir : /Users/shahules/Myprojects/MS-SNSD +name : dns-2020 +duration : 2.0 +sampling_rate: 16000 +batch_size: 32 +valid_size: 0.05 +files: + train_clean : CleanSpeech_training + test_clean : CleanSpeech_training + train_noisy : NoisySpeech_training + test_noisy : NoisySpeech_training diff --git a/recipes/DNS/DNS-2020/cli/train_config/dataset/Vctk.yaml b/recipes/DNS/DNS-2020/cli/train_config/dataset/Vctk.yaml new file mode 100644 index 0000000..f30a835 --- /dev/null +++ b/recipes/DNS/DNS-2020/cli/train_config/dataset/Vctk.yaml @@ -0,0 +1,13 @@ +_target_: mayavoz.data.dataset.EnhancerDataset +name : vctk +root_dir : /scratch/c.sistc3/DS_10283_2791 +duration : 4.5 +stride : 2 +sampling_rate: 16000 +batch_size: 32 +valid_minutes : 15 +files: + train_clean : clean_trainset_28spk_wav + test_clean : clean_testset_wav + train_noisy : noisy_trainset_28spk_wav + test_noisy : noisy_testset_wav diff --git a/recipes/DNS/DNS-2020/cli/train_config/hyperparameters/default.yaml b/recipes/DNS/DNS-2020/cli/train_config/hyperparameters/default.yaml new file mode 100644 index 0000000..1782ea9 --- /dev/null +++ b/recipes/DNS/DNS-2020/cli/train_config/hyperparameters/default.yaml @@ -0,0 +1,7 @@ +loss : mae +metric : [stoi,pesq,si-sdr] +lr : 0.0003 +ReduceLr_patience : 5 +ReduceLr_factor : 0.2 +min_lr : 0.000001 +EarlyStopping_factor : 10 diff --git a/recipes/DNS/DNS-2020/cli/train_config/mlflow/experiment.yaml b/recipes/DNS/DNS-2020/cli/train_config/mlflow/experiment.yaml new file mode 100644 index 0000000..9173e38 --- /dev/null +++ b/recipes/DNS/DNS-2020/cli/train_config/mlflow/experiment.yaml @@ -0,0 +1,2 @@ +experiment_name : shahules/mayavoz +run_name : Demucs + Vtck with stride + augmentations diff --git a/recipes/DNS/DNS-2020/cli/train_config/model/DCCRN.yaml b/recipes/DNS/DNS-2020/cli/train_config/model/DCCRN.yaml new file mode 100644 index 0000000..d2ffcf1 --- /dev/null +++ b/recipes/DNS/DNS-2020/cli/train_config/model/DCCRN.yaml @@ -0,0 +1,25 @@ +_target_: mayavoz.models.dccrn.DCCRN +num_channels: 1 +sampling_rate : 16000 +complex_lstm : True +complex_norm : True +complex_relu : True +masking_mode : True + +encoder_decoder: + initial_output_channels : 32 + depth : 6 + kernel_size : 5 + growth_factor : 2 + stride : 2 + padding : 2 + output_padding : 1 + +lstm: + num_layers : 2 + hidden_size : 256 + +stft: + window_len : 400 + hop_size : 100 + nfft : 512 diff --git a/recipes/DNS/DNS-2020/cli/train_config/model/Demucs.yaml b/recipes/DNS/DNS-2020/cli/train_config/model/Demucs.yaml new file mode 100644 index 0000000..f8d2eb8 --- /dev/null +++ b/recipes/DNS/DNS-2020/cli/train_config/model/Demucs.yaml @@ -0,0 +1,16 @@ +_target_: mayavoz.models.demucs.Demucs +num_channels: 1 +resample: 4 +sampling_rate : 16000 + +encoder_decoder: + depth: 4 + initial_output_channels: 64 + kernel_size: 8 + stride: 4 + growth_factor: 2 + glu: True + +lstm: + bidirectional: False + num_layers: 2 diff --git a/recipes/DNS/DNS-2020/cli/train_config/model/WaveUnet.yaml b/recipes/DNS/DNS-2020/cli/train_config/model/WaveUnet.yaml new file mode 100644 index 0000000..7c17448 --- /dev/null +++ b/recipes/DNS/DNS-2020/cli/train_config/model/WaveUnet.yaml @@ -0,0 +1,5 @@ +_target_: mayavoz.models.waveunet.WaveUnet +num_channels : 1 +depth : 9 +initial_output_channels: 24 +sampling_rate : 16000 diff --git a/recipes/DNS/DNS-2020/cli/train_config/optimizer/Adam.yaml b/recipes/DNS/DNS-2020/cli/train_config/optimizer/Adam.yaml new file mode 100644 index 0000000..7952b81 --- /dev/null +++ b/recipes/DNS/DNS-2020/cli/train_config/optimizer/Adam.yaml @@ -0,0 +1,6 @@ +_target_: torch.optim.Adam +lr: 1e-3 +betas: [0.9, 0.999] +eps: 1e-08 +weight_decay: 0 +amsgrad: False diff --git a/recipes/DNS/DNS-2020/cli/train_config/trainer/default.yaml b/recipes/DNS/DNS-2020/cli/train_config/trainer/default.yaml new file mode 100644 index 0000000..958c418 --- /dev/null +++ b/recipes/DNS/DNS-2020/cli/train_config/trainer/default.yaml @@ -0,0 +1,46 @@ +_target_: pytorch_lightning.Trainer +accelerator: gpu +accumulate_grad_batches: 1 +amp_backend: native +auto_lr_find: True +auto_scale_batch_size: False +auto_select_gpus: True +benchmark: False +check_val_every_n_epoch: 1 +detect_anomaly: False +deterministic: False +devices: 2 +enable_checkpointing: True +enable_model_summary: True +enable_progress_bar: True +fast_dev_run: False +gpus: null +gradient_clip_val: 0 +gradient_clip_algorithm: norm +ipus: null +limit_predict_batches: 1.0 +limit_test_batches: 1.0 +limit_train_batches: 1.0 +limit_val_batches: 1.0 +log_every_n_steps: 50 +max_epochs: 200 +max_steps: -1 +max_time: null +min_epochs: 1 +min_steps: null +move_metrics_to_cpu: False +multiple_trainloader_mode: max_size_cycle +num_nodes: 1 +num_processes: 1 +num_sanity_val_steps: 2 +overfit_batches: 0.0 +precision: 32 +profiler: null +reload_dataloaders_every_n_epochs: 0 +replace_sampler_ddp: True +strategy: ddp +sync_batchnorm: False +tpu_cores: null +track_grad_norm: -1 +val_check_interval: 1.0 +weights_save_path: null diff --git a/recipes/DNS/DNS-2020/cli/train_config/trainer/fastrun_dev.yaml b/recipes/DNS/DNS-2020/cli/train_config/trainer/fastrun_dev.yaml new file mode 100644 index 0000000..682149e --- /dev/null +++ b/recipes/DNS/DNS-2020/cli/train_config/trainer/fastrun_dev.yaml @@ -0,0 +1,2 @@ +_target_: pytorch_lightning.Trainer +fast_dev_run: True diff --git a/recipes/Valentini-dataset/28spk/Demucs/train_config/dataset/Vctk.yaml b/recipes/Valentini-dataset/28spk/Demucs/train_config/dataset/Vctk.yaml index 25278eb..cca932a 100644 --- a/recipes/Valentini-dataset/28spk/Demucs/train_config/dataset/Vctk.yaml +++ b/recipes/Valentini-dataset/28spk/Demucs/train_config/dataset/Vctk.yaml @@ -1,4 +1,4 @@ -_target_: enhancer.data.dataset.EnhancerDataset +_target_: mayavoz.data.dataset.EnhancerDataset name : vctk root_dir : /scratch/c.sistc3/DS_10283_2791 duration : 4.5 diff --git a/recipes/Valentini-dataset/28spk/Demucs/train_config/mlflow/experiment.yaml b/recipes/Valentini-dataset/28spk/Demucs/train_config/mlflow/experiment.yaml index e8893f6..933701b 100644 --- a/recipes/Valentini-dataset/28spk/Demucs/train_config/mlflow/experiment.yaml +++ b/recipes/Valentini-dataset/28spk/Demucs/train_config/mlflow/experiment.yaml @@ -1,2 +1,2 @@ -experiment_name : shahules/enhancer +experiment_name : shahules/mayavoz run_name : baseline diff --git a/recipes/Valentini-dataset/28spk/Demucs/train_config/model/Demucs.yaml b/recipes/Valentini-dataset/28spk/Demucs/train_config/model/Demucs.yaml index 0a051b5..2baf87d 100644 --- a/recipes/Valentini-dataset/28spk/Demucs/train_config/model/Demucs.yaml +++ b/recipes/Valentini-dataset/28spk/Demucs/train_config/model/Demucs.yaml @@ -1,4 +1,4 @@ -_target_: enhancer.models.demucs.Demucs +_target_: mayavoz.models.demucs.Demucs num_channels: 1 resample: 4 sampling_rate : 16000 diff --git a/recipes/Valentini-dataset/28spk/WaveUnet/train_config/dataset/Vctk.yaml b/recipes/Valentini-dataset/28spk/WaveUnet/train_config/dataset/Vctk.yaml index 831d576..870bbb9 100644 --- a/recipes/Valentini-dataset/28spk/WaveUnet/train_config/dataset/Vctk.yaml +++ b/recipes/Valentini-dataset/28spk/WaveUnet/train_config/dataset/Vctk.yaml @@ -1,4 +1,4 @@ -_target_: enhancer.data.dataset.EnhancerDataset +_target_: mayavoz.data.dataset.EnhancerDataset name : vctk root_dir : /scratch/c.sistc3/DS_10283_2791 duration : 2 diff --git a/recipes/Valentini-dataset/28spk/WaveUnet/train_config/mlflow/experiment.yaml b/recipes/Valentini-dataset/28spk/WaveUnet/train_config/mlflow/experiment.yaml index e8893f6..933701b 100644 --- a/recipes/Valentini-dataset/28spk/WaveUnet/train_config/mlflow/experiment.yaml +++ b/recipes/Valentini-dataset/28spk/WaveUnet/train_config/mlflow/experiment.yaml @@ -1,2 +1,2 @@ -experiment_name : shahules/enhancer +experiment_name : shahules/mayavoz run_name : baseline diff --git a/recipes/Valentini-dataset/28spk/WaveUnet/train_config/model/WaveUnet.yaml b/recipes/Valentini-dataset/28spk/WaveUnet/train_config/model/WaveUnet.yaml index 29d48c7..7c17448 100644 --- a/recipes/Valentini-dataset/28spk/WaveUnet/train_config/model/WaveUnet.yaml +++ b/recipes/Valentini-dataset/28spk/WaveUnet/train_config/model/WaveUnet.yaml @@ -1,4 +1,4 @@ -_target_: enhancer.models.waveunet.WaveUnet +_target_: mayavoz.models.waveunet.WaveUnet num_channels : 1 depth : 9 initial_output_channels: 24 diff --git a/recipes/Valentini-dataset/28spk/cli/train_config/dataset/DNS-2020.yaml b/recipes/Valentini-dataset/28spk/cli/train_config/dataset/DNS-2020.yaml index 09a14fb..520efc9 100644 --- a/recipes/Valentini-dataset/28spk/cli/train_config/dataset/DNS-2020.yaml +++ b/recipes/Valentini-dataset/28spk/cli/train_config/dataset/DNS-2020.yaml @@ -1,4 +1,4 @@ -_target_: enhancer.data.dataset.EnhancerDataset +_target_: mayavoz.data.dataset.EnhancerDataset root_dir : /Users/shahules/Myprojects/MS-SNSD name : dns-2020 duration : 2.0 diff --git a/recipes/Valentini-dataset/28spk/cli/train_config/dataset/Vctk.yaml b/recipes/Valentini-dataset/28spk/cli/train_config/dataset/Vctk.yaml index c33d29a..f30a835 100644 --- a/recipes/Valentini-dataset/28spk/cli/train_config/dataset/Vctk.yaml +++ b/recipes/Valentini-dataset/28spk/cli/train_config/dataset/Vctk.yaml @@ -1,4 +1,4 @@ -_target_: enhancer.data.dataset.EnhancerDataset +_target_: mayavoz.data.dataset.EnhancerDataset name : vctk root_dir : /scratch/c.sistc3/DS_10283_2791 duration : 4.5 diff --git a/recipes/Valentini-dataset/28spk/cli/train_config/mlflow/experiment.yaml b/recipes/Valentini-dataset/28spk/cli/train_config/mlflow/experiment.yaml index d597333..9173e38 100644 --- a/recipes/Valentini-dataset/28spk/cli/train_config/mlflow/experiment.yaml +++ b/recipes/Valentini-dataset/28spk/cli/train_config/mlflow/experiment.yaml @@ -1,2 +1,2 @@ -experiment_name : shahules/enhancer +experiment_name : shahules/mayavoz run_name : Demucs + Vtck with stride + augmentations diff --git a/recipes/Valentini-dataset/28spk/cli/train_config/model/DCCRN.yaml b/recipes/Valentini-dataset/28spk/cli/train_config/model/DCCRN.yaml index 3190391..d2ffcf1 100644 --- a/recipes/Valentini-dataset/28spk/cli/train_config/model/DCCRN.yaml +++ b/recipes/Valentini-dataset/28spk/cli/train_config/model/DCCRN.yaml @@ -1,4 +1,4 @@ -_target_: enhancer.models.dccrn.DCCRN +_target_: mayavoz.models.dccrn.DCCRN num_channels: 1 sampling_rate : 16000 complex_lstm : True diff --git a/recipes/Valentini-dataset/28spk/cli/train_config/model/Demucs.yaml b/recipes/Valentini-dataset/28spk/cli/train_config/model/Demucs.yaml index 513e603..f8d2eb8 100644 --- a/recipes/Valentini-dataset/28spk/cli/train_config/model/Demucs.yaml +++ b/recipes/Valentini-dataset/28spk/cli/train_config/model/Demucs.yaml @@ -1,4 +1,4 @@ -_target_: enhancer.models.demucs.Demucs +_target_: mayavoz.models.demucs.Demucs num_channels: 1 resample: 4 sampling_rate : 16000 diff --git a/recipes/Valentini-dataset/28spk/cli/train_config/model/WaveUnet.yaml b/recipes/Valentini-dataset/28spk/cli/train_config/model/WaveUnet.yaml index 29d48c7..7c17448 100644 --- a/recipes/Valentini-dataset/28spk/cli/train_config/model/WaveUnet.yaml +++ b/recipes/Valentini-dataset/28spk/cli/train_config/model/WaveUnet.yaml @@ -1,4 +1,4 @@ -_target_: enhancer.models.waveunet.WaveUnet +_target_: mayavoz.models.waveunet.WaveUnet num_channels : 1 depth : 9 initial_output_channels: 24 diff --git a/setup.cfg b/setup.cfg index 309ac9a..b860772 100644 --- a/setup.cfg +++ b/setup.cfg @@ -3,7 +3,7 @@ # http://setuptools.readthedocs.io/en/latest/setuptools.html#configuring-setup-using-setup-cfg-files [metadata] -name = enhancer +name = mayavoz description = Deep learning for speech enhacement author = Shahul Ess author-email = shahules786@gmail.com @@ -53,7 +53,7 @@ cli = [options.entry_points] console_scripts = - enhancer-train=enhancer.cli.train:train + mayavoz-train=.cli.train:train [test] # py.test options when running `python setup.py test` @@ -66,7 +66,7 @@ extras = True # e.g. --cov-report html (or xml) for html/xml output or --junitxml junit.xml # in order to write a coverage file that can be read by Jenkins. addopts = - --cov enhancer --cov-report term-missing + --cov mayavoz --cov-report term-missing --verbose norecursedirs = dist diff --git a/setup.py b/setup.py index 79282b3..26ee10c 100644 --- a/setup.py +++ b/setup.py @@ -33,15 +33,15 @@ elif sha != "Unknown": version += "+" + sha[:7] print("-- Building version " + version) -version_path = ROOT_DIR / "enhancer" / "version.py" +version_path = ROOT_DIR / "mayavoz" / "version.py" with open(version_path, "w") as f: f.write("__version__ = '{}'\n".format(version)) if __name__ == "__main__": setup( - name="enhancer", - namespace_packages=["enhancer"], + name="mayavoz", + namespace_packages=["mayavoz"], version=version, packages=find_packages(), install_requires=requirements, diff --git a/tests/loss_function_test.py b/tests/loss_function_test.py index 4d14871..8f42082 100644 --- a/tests/loss_function_test.py +++ b/tests/loss_function_test.py @@ -1,7 +1,7 @@ import pytest import torch -from enhancer.loss import mean_absolute_error, mean_squared_error +from mayavoz.loss import mean_absolute_error, mean_squared_error loss_functions = [mean_absolute_error(), mean_squared_error()] diff --git a/tests/models/complexnn_test.py b/tests/models/complexnn_test.py index 524a6cf..8c9c8e0 100644 --- a/tests/models/complexnn_test.py +++ b/tests/models/complexnn_test.py @@ -1,8 +1,8 @@ import torch -from enhancer.models.complexnn.conv import ComplexConv2d, ComplexConvTranspose2d -from enhancer.models.complexnn.rnn import ComplexLSTM -from enhancer.models.complexnn.utils import ComplexBatchNorm2D +from mayavoz.models.complexnn.conv import ComplexConv2d, ComplexConvTranspose2d +from mayavoz.models.complexnn.rnn import ComplexLSTM +from mayavoz.models.complexnn.utils import ComplexBatchNorm2D def test_complexconv2d(): diff --git a/tests/models/demucs_test.py b/tests/models/demucs_test.py index 29e030e..51bdb27 100644 --- a/tests/models/demucs_test.py +++ b/tests/models/demucs_test.py @@ -1,9 +1,9 @@ import pytest import torch -from enhancer.data.dataset import EnhancerDataset -from enhancer.models import Demucs -from enhancer.utils.config import Files +from mayavoz.data.dataset import EnhancerDataset +from mayavoz.models import Demucs +from mayavoz.utils.config import Files @pytest.fixture diff --git a/tests/models/test_dccrn.py b/tests/models/test_dccrn.py index 96a853b..7e93af3 100644 --- a/tests/models/test_dccrn.py +++ b/tests/models/test_dccrn.py @@ -1,9 +1,9 @@ import pytest import torch -from enhancer.data.dataset import EnhancerDataset -from enhancer.models.dccrn import DCCRN -from enhancer.utils.config import Files +from mayavoz.data.dataset import EnhancerDataset +from mayavoz.models.dccrn import DCCRN +from mayavoz.utils.config import Files @pytest.fixture diff --git a/tests/models/test_waveunet.py b/tests/models/test_waveunet.py index 9c4dd96..9526820 100644 --- a/tests/models/test_waveunet.py +++ b/tests/models/test_waveunet.py @@ -1,9 +1,9 @@ import pytest import torch -from enhancer.data.dataset import EnhancerDataset -from enhancer.models import WaveUnet -from enhancer.utils.config import Files +from mayavoz.data.dataset import EnhancerDataset +from mayavoz.models import WaveUnet +from mayavoz.utils.config import Files @pytest.fixture diff --git a/tests/test_inference.py b/tests/test_inference.py index a6e2423..f727938 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -1,7 +1,7 @@ import pytest import torch -from enhancer.inference import Inference +from mayavoz.inference import Inference @pytest.mark.parametrize( diff --git a/tests/transforms_test.py b/tests/transforms_test.py index 89425ad..d9399ea 100644 --- a/tests/transforms_test.py +++ b/tests/transforms_test.py @@ -1,6 +1,6 @@ import torch -from enhancer.utils.transforms import ConviSTFT, ConvSTFT +from mayavoz.utils.transforms import ConviSTFT, ConvSTFT def test_stft_istft(): diff --git a/tests/utils_test.py b/tests/utils_test.py index cd5240c..2b3e7a4 100644 --- a/tests/utils_test.py +++ b/tests/utils_test.py @@ -2,8 +2,8 @@ import numpy as np import pytest import torch -from enhancer.data.fileprocessor import Fileprocessor -from enhancer.utils.io import Audio +from mayavoz.data.fileprocessor import Fileprocessor +from mayavoz.utils.io import Audio def test_io_channel(): From f8a44f823a125673c3e3944f0dfd43279f4439cd Mon Sep 17 00:00:00 2001 From: shahules786 Date: Mon, 14 Nov 2022 16:19:57 +0530 Subject: [PATCH 3/7] fix typo --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index b860772..b31929e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -53,7 +53,7 @@ cli = [options.entry_points] console_scripts = - mayavoz-train=.cli.train:train + mayavoz-train=mayavoz.cli.train:train [test] # py.test options when running `python setup.py test` From 12cde1b0abbd8ef8bf7b99b5393c843d8579232c Mon Sep 17 00:00:00 2001 From: shahules786 Date: Mon, 14 Nov 2022 16:30:14 +0530 Subject: [PATCH 4/7] change save name --- mayavoz/models/model.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/mayavoz/models/model.py b/mayavoz/models/model.py index 2957e5b..d82c5c5 100644 --- a/mayavoz/models/model.py +++ b/mayavoz/models/model.py @@ -24,6 +24,7 @@ CACHE_DIR = os.getenv( ) HF_TORCH_WEIGHTS = "pytorch_model.ckpt" DEFAULT_DEVICE = "cpu" +SAVE_NAME = "enhancer" class Model(pl.LightningModule): @@ -233,8 +234,8 @@ class Model(pl.LightningModule): def on_save_checkpoint(self, checkpoint): - checkpoint["mayavoz"] = { - "version": {"mayavoz": __version__, "pytorch": torch.__version__}, + checkpoint[SAVE_NAME] = { + "version": {SAVE_NAME: __version__, "pytorch": torch.__version__}, "architecture": { "module": self.__class__.__module__, "class": self.__class__.__name__, @@ -327,8 +328,8 @@ class Model(pl.LightningModule): map_location = torch.device(DEFAULT_DEVICE) loaded_checkpoint = pl_load(model_path_pl, map_location) - module_name = loaded_checkpoint["mayavoz"]["architecture"]["module"] - class_name = loaded_checkpoint["mayavoz"]["architecture"]["class"] + module_name = loaded_checkpoint[SAVE_NAME]["architecture"]["module"] + class_name = loaded_checkpoint[SAVE_NAME]["architecture"]["class"] module = import_module(module_name) Klass = getattr(module, class_name) From ba63c5439961d6778e95e6fef987f258ed7b5e97 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Mon, 14 Nov 2022 16:31:51 +0530 Subject: [PATCH 5/7] ci-cd --- .github/workflows/ci.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index b42bdc5..9ba5718 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -7,7 +7,7 @@ on: push: branches: [ dev ] pull_request: - branches: [ dev ] + branches: [ main ] jobs: build: runs-on: ubuntu-latest From bfd53937c29dcdd01f991eb6f59175d991bf2a58 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Tue, 15 Nov 2022 14:29:04 +0530 Subject: [PATCH 6/7] rename to mayamodel --- mayavoz/models/__init__.py | 2 +- mayavoz/models/dccrn.py | 4 ++-- mayavoz/models/demucs.py | 4 ++-- mayavoz/models/model.py | 6 +++--- mayavoz/models/waveunet.py | 4 ++-- 5 files changed, 10 insertions(+), 10 deletions(-) diff --git a/mayavoz/models/__init__.py b/mayavoz/models/__init__.py index 9cf2b9b..6e82eb3 100644 --- a/mayavoz/models/__init__.py +++ b/mayavoz/models/__init__.py @@ -1,3 +1,3 @@ from mayavoz.models.demucs import Demucs -from mayavoz.models.model import Model +from mayavoz.models.model import Mayamodel from mayavoz.models.waveunet import WaveUnet diff --git a/mayavoz/models/dccrn.py b/mayavoz/models/dccrn.py index 372696c..278072f 100644 --- a/mayavoz/models/dccrn.py +++ b/mayavoz/models/dccrn.py @@ -6,7 +6,7 @@ import torch.nn.functional as F from torch import nn from mayavoz.data import EnhancerDataset -from mayavoz.models import Model +from mayavoz.models import Mayamodel from mayavoz.models.complexnn import ( ComplexBatchNorm2D, ComplexConv2d, @@ -98,7 +98,7 @@ class DCCRN_DECODER(nn.Module): return self.decoder(waveform) -class DCCRN(Model): +class DCCRN(Mayamodel): STFT_DEFAULTS = { "window_len": 400, diff --git a/mayavoz/models/demucs.py b/mayavoz/models/demucs.py index a5e3147..db69c80 100644 --- a/mayavoz/models/demucs.py +++ b/mayavoz/models/demucs.py @@ -6,7 +6,7 @@ import torch.nn.functional as F from torch import nn from mayavoz.data.dataset import EnhancerDataset -from mayavoz.models.model import Model +from mayavoz.models.model import Mayamodel from mayavoz.utils.io import Audio as audio from mayavoz.utils.utils import merge_dict @@ -88,7 +88,7 @@ class DemucsDecoder(nn.Module): return out -class Demucs(Model): +class Demucs(Mayamodel): """ Demucs model from https://arxiv.org/pdf/1911.13254.pdf parameters: diff --git a/mayavoz/models/model.py b/mayavoz/models/model.py index d82c5c5..aede7a3 100644 --- a/mayavoz/models/model.py +++ b/mayavoz/models/model.py @@ -27,7 +27,7 @@ DEFAULT_DEVICE = "cpu" SAVE_NAME = "enhancer" -class Model(pl.LightningModule): +class Mayamodel(pl.LightningModule): """ Base class for all models parameters: @@ -288,8 +288,8 @@ class Model(pl.LightningModule): Returns ------- - model : Model - Model + model : Mayamodel + Mayamodel See also -------- diff --git a/mayavoz/models/waveunet.py b/mayavoz/models/waveunet.py index ead2146..9e5a4ae 100644 --- a/mayavoz/models/waveunet.py +++ b/mayavoz/models/waveunet.py @@ -6,7 +6,7 @@ import torch.nn as nn import torch.nn.functional as F from mayavoz.data.dataset import EnhancerDataset -from mayavoz.models.model import Model +from mayavoz.models.model import Mayamodel class WavenetDecoder(nn.Module): @@ -66,7 +66,7 @@ class WavenetEncoder(nn.Module): return self.encoder(waveform) -class WaveUnet(Model): +class WaveUnet(Mayamodel): """ Wave-U-Net model from https://arxiv.org/pdf/1811.11307.pdf parameters: From 8bc63becce164987d2b2565767eab73e0b00e1f7 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Tue, 15 Nov 2022 14:33:27 +0530 Subject: [PATCH 7/7] rename dataset --- mayavoz/cli/train_config/dataset/DNS-2020.yaml | 2 +- mayavoz/cli/train_config/dataset/Vctk.yaml | 2 +- mayavoz/data/__init__.py | 2 +- mayavoz/data/dataset.py | 2 +- mayavoz/models/dccrn.py | 6 +++--- mayavoz/models/demucs.py | 10 +++++----- mayavoz/models/model.py | 6 +++--- mayavoz/models/waveunet.py | 10 +++++----- notebooks/Custom_model_training.ipynb | 4 ++-- .../DNS-2020/cli/train_config/dataset/DNS-2020.yaml | 2 +- .../DNS/DNS-2020/cli/train_config/dataset/Vctk.yaml | 2 +- .../28spk/Demucs/train_config/dataset/Vctk.yaml | 2 +- .../28spk/WaveUnet/train_config/dataset/Vctk.yaml | 2 +- .../28spk/cli/train_config/dataset/DNS-2020.yaml | 2 +- .../28spk/cli/train_config/dataset/Vctk.yaml | 2 +- tests/models/demucs_test.py | 4 ++-- tests/models/test_dccrn.py | 4 ++-- tests/models/test_waveunet.py | 4 ++-- 18 files changed, 34 insertions(+), 34 deletions(-) diff --git a/mayavoz/cli/train_config/dataset/DNS-2020.yaml b/mayavoz/cli/train_config/dataset/DNS-2020.yaml index 520efc9..5c67be2 100644 --- a/mayavoz/cli/train_config/dataset/DNS-2020.yaml +++ b/mayavoz/cli/train_config/dataset/DNS-2020.yaml @@ -1,4 +1,4 @@ -_target_: mayavoz.data.dataset.EnhancerDataset +_target_: mayavoz.data.dataset.MayaDataset root_dir : /Users/shahules/Myprojects/MS-SNSD name : dns-2020 duration : 2.0 diff --git a/mayavoz/cli/train_config/dataset/Vctk.yaml b/mayavoz/cli/train_config/dataset/Vctk.yaml index f30a835..584abe7 100644 --- a/mayavoz/cli/train_config/dataset/Vctk.yaml +++ b/mayavoz/cli/train_config/dataset/Vctk.yaml @@ -1,4 +1,4 @@ -_target_: mayavoz.data.dataset.EnhancerDataset +_target_: mayavoz.data.dataset.MayaDataset name : vctk root_dir : /scratch/c.sistc3/DS_10283_2791 duration : 4.5 diff --git a/mayavoz/data/__init__.py b/mayavoz/data/__init__.py index c7663d7..02604df 100644 --- a/mayavoz/data/__init__.py +++ b/mayavoz/data/__init__.py @@ -1 +1 @@ -from mayavoz.data.dataset import EnhancerDataset +from mayavoz.data.dataset import MayaDataset diff --git a/mayavoz/data/dataset.py b/mayavoz/data/dataset.py index 5296499..d47967c 100644 --- a/mayavoz/data/dataset.py +++ b/mayavoz/data/dataset.py @@ -248,7 +248,7 @@ class TaskDataset(pl.LightningDataModule): ) -class EnhancerDataset(TaskDataset): +class MayaDataset(TaskDataset): """ Dataset object for creating clean-noisy speech enhancement datasets paramters: diff --git a/mayavoz/models/dccrn.py b/mayavoz/models/dccrn.py index 278072f..6b8646c 100644 --- a/mayavoz/models/dccrn.py +++ b/mayavoz/models/dccrn.py @@ -5,7 +5,7 @@ import torch import torch.nn.functional as F from torch import nn -from mayavoz.data import EnhancerDataset +from mayavoz.data import MayaDataset from mayavoz.models import Mayamodel from mayavoz.models.complexnn import ( ComplexBatchNorm2D, @@ -134,13 +134,13 @@ class DCCRN(Mayamodel): num_channels: int = 1, sampling_rate=16000, lr: float = 1e-3, - dataset: Optional[EnhancerDataset] = None, + dataset: Optional[MayaDataset] = None, duration: Optional[float] = None, loss: Union[str, List, Any] = "mse", metric: Union[str, List] = "mse", ): duration = ( - dataset.duration if isinstance(dataset, EnhancerDataset) else None + dataset.duration if isinstance(dataset, MayaDataset) else None ) if dataset is not None: if sampling_rate != dataset.sampling_rate: diff --git a/mayavoz/models/demucs.py b/mayavoz/models/demucs.py index db69c80..8424f17 100644 --- a/mayavoz/models/demucs.py +++ b/mayavoz/models/demucs.py @@ -5,7 +5,7 @@ from typing import List, Optional, Union import torch.nn.functional as F from torch import nn -from mayavoz.data.dataset import EnhancerDataset +from mayavoz.data.dataset import MayaDataset from mayavoz.models.model import Mayamodel from mayavoz.utils.io import Audio as audio from mayavoz.utils.utils import merge_dict @@ -102,8 +102,8 @@ class Demucs(Mayamodel): sampling rate of input audio lr : float, defaults to 1e-3 learning rate used for training - dataset: EnhancerDataset, optional - EnhancerDataset object containing train/validation data for training + dataset: MayaDataset, optional + MayaDataset object containing train/validation data for training duration : float, optional chunk duration in seconds loss : string or List of strings @@ -135,13 +135,13 @@ class Demucs(Mayamodel): sampling_rate=16000, normalize=True, lr: float = 1e-3, - dataset: Optional[EnhancerDataset] = None, + dataset: Optional[MayaDataset] = None, loss: Union[str, List] = "mse", metric: Union[str, List] = "mse", floor=1e-3, ): duration = ( - dataset.duration if isinstance(dataset, EnhancerDataset) else None + dataset.duration if isinstance(dataset, MayaDataset) else None ) if dataset is not None: if sampling_rate != dataset.sampling_rate: diff --git a/mayavoz/models/model.py b/mayavoz/models/model.py index aede7a3..e248b2c 100644 --- a/mayavoz/models/model.py +++ b/mayavoz/models/model.py @@ -13,7 +13,7 @@ from pytorch_lightning.utilities.cloud_io import load as pl_load from torch import nn from torch.optim import Adam -from mayavoz.data.dataset import EnhancerDataset +from mayavoz.data.dataset import MayaDataset from mayavoz.inference import Inference from mayavoz.loss import LOSS_MAP, LossWrapper from mayavoz.version import __version__ @@ -37,7 +37,7 @@ class Mayamodel(pl.LightningModule): audio sampling rate lr: float, optional learning rate for model training - dataset: EnhancerDataset, optional + dataset: MayaDataset, optional mayavoz dataset used for training/validation duration: float, optional duration used for training/inference @@ -51,7 +51,7 @@ class Mayamodel(pl.LightningModule): num_channels: int = 1, sampling_rate: int = 16000, lr: float = 1e-3, - dataset: Optional[EnhancerDataset] = None, + dataset: Optional[MayaDataset] = None, duration: Optional[float] = None, loss: Union[str, List] = "mse", metric: Union[str, List, Any] = "mse", diff --git a/mayavoz/models/waveunet.py b/mayavoz/models/waveunet.py index 9e5a4ae..c9acfda 100644 --- a/mayavoz/models/waveunet.py +++ b/mayavoz/models/waveunet.py @@ -5,7 +5,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from mayavoz.data.dataset import EnhancerDataset +from mayavoz.data.dataset import MayaDataset from mayavoz.models.model import Mayamodel @@ -80,8 +80,8 @@ class WaveUnet(Mayamodel): sampling rate of input audio lr : float, defaults to 1e-3 learning rate used for training - dataset: EnhancerDataset, optional - EnhancerDataset object containing train/validation data for training + dataset: MayaDataset, optional + MayaDataset object containing train/validation data for training duration : float, optional chunk duration in seconds loss : string or List of strings @@ -97,13 +97,13 @@ class WaveUnet(Mayamodel): initial_output_channels: int = 24, sampling_rate: int = 16000, lr: float = 1e-3, - dataset: Optional[EnhancerDataset] = None, + dataset: Optional[MayaDataset] = None, duration: Optional[float] = None, loss: Union[str, List] = "mse", metric: Union[str, List] = "mse", ): duration = ( - dataset.duration if isinstance(dataset, EnhancerDataset) else None + dataset.duration if isinstance(dataset, MayaDataset) else None ) if dataset is not None: if sampling_rate != dataset.sampling_rate: diff --git a/notebooks/Custom_model_training.ipynb b/notebooks/Custom_model_training.ipynb index 2e5ed67..7c963c2 100644 --- a/notebooks/Custom_model_training.ipynb +++ b/notebooks/Custom_model_training.ipynb @@ -316,9 +316,9 @@ ], "metadata": { "kernelspec": { - "display_name": "mayavoz", + "display_name": "enhancer", "language": "python", - "name": "mayavoz" + "name": "enhancer" }, "language_info": { "codemirror_mode": { diff --git a/recipes/DNS/DNS-2020/cli/train_config/dataset/DNS-2020.yaml b/recipes/DNS/DNS-2020/cli/train_config/dataset/DNS-2020.yaml index 520efc9..5c67be2 100644 --- a/recipes/DNS/DNS-2020/cli/train_config/dataset/DNS-2020.yaml +++ b/recipes/DNS/DNS-2020/cli/train_config/dataset/DNS-2020.yaml @@ -1,4 +1,4 @@ -_target_: mayavoz.data.dataset.EnhancerDataset +_target_: mayavoz.data.dataset.MayaDataset root_dir : /Users/shahules/Myprojects/MS-SNSD name : dns-2020 duration : 2.0 diff --git a/recipes/DNS/DNS-2020/cli/train_config/dataset/Vctk.yaml b/recipes/DNS/DNS-2020/cli/train_config/dataset/Vctk.yaml index f30a835..584abe7 100644 --- a/recipes/DNS/DNS-2020/cli/train_config/dataset/Vctk.yaml +++ b/recipes/DNS/DNS-2020/cli/train_config/dataset/Vctk.yaml @@ -1,4 +1,4 @@ -_target_: mayavoz.data.dataset.EnhancerDataset +_target_: mayavoz.data.dataset.MayaDataset name : vctk root_dir : /scratch/c.sistc3/DS_10283_2791 duration : 4.5 diff --git a/recipes/Valentini-dataset/28spk/Demucs/train_config/dataset/Vctk.yaml b/recipes/Valentini-dataset/28spk/Demucs/train_config/dataset/Vctk.yaml index cca932a..8e726d1 100644 --- a/recipes/Valentini-dataset/28spk/Demucs/train_config/dataset/Vctk.yaml +++ b/recipes/Valentini-dataset/28spk/Demucs/train_config/dataset/Vctk.yaml @@ -1,4 +1,4 @@ -_target_: mayavoz.data.dataset.EnhancerDataset +_target_: mayavoz.data.dataset.MayaDataset name : vctk root_dir : /scratch/c.sistc3/DS_10283_2791 duration : 4.5 diff --git a/recipes/Valentini-dataset/28spk/WaveUnet/train_config/dataset/Vctk.yaml b/recipes/Valentini-dataset/28spk/WaveUnet/train_config/dataset/Vctk.yaml index 870bbb9..d2e6b30 100644 --- a/recipes/Valentini-dataset/28spk/WaveUnet/train_config/dataset/Vctk.yaml +++ b/recipes/Valentini-dataset/28spk/WaveUnet/train_config/dataset/Vctk.yaml @@ -1,4 +1,4 @@ -_target_: mayavoz.data.dataset.EnhancerDataset +_target_: mayavoz.data.dataset.MayaDataset name : vctk root_dir : /scratch/c.sistc3/DS_10283_2791 duration : 2 diff --git a/recipes/Valentini-dataset/28spk/cli/train_config/dataset/DNS-2020.yaml b/recipes/Valentini-dataset/28spk/cli/train_config/dataset/DNS-2020.yaml index 520efc9..5c67be2 100644 --- a/recipes/Valentini-dataset/28spk/cli/train_config/dataset/DNS-2020.yaml +++ b/recipes/Valentini-dataset/28spk/cli/train_config/dataset/DNS-2020.yaml @@ -1,4 +1,4 @@ -_target_: mayavoz.data.dataset.EnhancerDataset +_target_: mayavoz.data.dataset.MayaDataset root_dir : /Users/shahules/Myprojects/MS-SNSD name : dns-2020 duration : 2.0 diff --git a/recipes/Valentini-dataset/28spk/cli/train_config/dataset/Vctk.yaml b/recipes/Valentini-dataset/28spk/cli/train_config/dataset/Vctk.yaml index f30a835..584abe7 100644 --- a/recipes/Valentini-dataset/28spk/cli/train_config/dataset/Vctk.yaml +++ b/recipes/Valentini-dataset/28spk/cli/train_config/dataset/Vctk.yaml @@ -1,4 +1,4 @@ -_target_: mayavoz.data.dataset.EnhancerDataset +_target_: mayavoz.data.dataset.MayaDataset name : vctk root_dir : /scratch/c.sistc3/DS_10283_2791 duration : 4.5 diff --git a/tests/models/demucs_test.py b/tests/models/demucs_test.py index 51bdb27..e1203b7 100644 --- a/tests/models/demucs_test.py +++ b/tests/models/demucs_test.py @@ -1,7 +1,7 @@ import pytest import torch -from mayavoz.data.dataset import EnhancerDataset +from mayavoz.data.dataset import MayaDataset from mayavoz.models import Demucs from mayavoz.utils.config import Files @@ -15,7 +15,7 @@ def vctk_dataset(): test_clean="clean_testset_wav", test_noisy="noisy_testset_wav", ) - dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files) + dataset = MayaDataset(name="vctk", root_dir=root_dir, files=files) return dataset diff --git a/tests/models/test_dccrn.py b/tests/models/test_dccrn.py index 7e93af3..bc2a039 100644 --- a/tests/models/test_dccrn.py +++ b/tests/models/test_dccrn.py @@ -1,7 +1,7 @@ import pytest import torch -from mayavoz.data.dataset import EnhancerDataset +from mayavoz.data.dataset import MayaDataset from mayavoz.models.dccrn import DCCRN from mayavoz.utils.config import Files @@ -15,7 +15,7 @@ def vctk_dataset(): test_clean="clean_testset_wav", test_noisy="noisy_testset_wav", ) - dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files) + dataset = MayaDataset(name="vctk", root_dir=root_dir, files=files) return dataset diff --git a/tests/models/test_waveunet.py b/tests/models/test_waveunet.py index 9526820..bc250d1 100644 --- a/tests/models/test_waveunet.py +++ b/tests/models/test_waveunet.py @@ -1,7 +1,7 @@ import pytest import torch -from mayavoz.data.dataset import EnhancerDataset +from mayavoz.data.dataset import MayaDataset from mayavoz.models import WaveUnet from mayavoz.utils.config import Files @@ -15,7 +15,7 @@ def vctk_dataset(): test_clean="clean_testset_wav", test_noisy="noisy_testset_wav", ) - dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files) + dataset = MayaDataset(name="vctk", root_dir=root_dir, files=files) return dataset