diff --git a/mayavoz/cli/train_config/dataset/DNS-2020.yaml b/mayavoz/cli/train_config/dataset/DNS-2020.yaml index 520efc9..5c67be2 100644 --- a/mayavoz/cli/train_config/dataset/DNS-2020.yaml +++ b/mayavoz/cli/train_config/dataset/DNS-2020.yaml @@ -1,4 +1,4 @@ -_target_: mayavoz.data.dataset.EnhancerDataset +_target_: mayavoz.data.dataset.MayaDataset root_dir : /Users/shahules/Myprojects/MS-SNSD name : dns-2020 duration : 2.0 diff --git a/mayavoz/cli/train_config/dataset/Vctk.yaml b/mayavoz/cli/train_config/dataset/Vctk.yaml index f30a835..584abe7 100644 --- a/mayavoz/cli/train_config/dataset/Vctk.yaml +++ b/mayavoz/cli/train_config/dataset/Vctk.yaml @@ -1,4 +1,4 @@ -_target_: mayavoz.data.dataset.EnhancerDataset +_target_: mayavoz.data.dataset.MayaDataset name : vctk root_dir : /scratch/c.sistc3/DS_10283_2791 duration : 4.5 diff --git a/mayavoz/data/__init__.py b/mayavoz/data/__init__.py index c7663d7..02604df 100644 --- a/mayavoz/data/__init__.py +++ b/mayavoz/data/__init__.py @@ -1 +1 @@ -from mayavoz.data.dataset import EnhancerDataset +from mayavoz.data.dataset import MayaDataset diff --git a/mayavoz/data/dataset.py b/mayavoz/data/dataset.py index 5296499..d47967c 100644 --- a/mayavoz/data/dataset.py +++ b/mayavoz/data/dataset.py @@ -248,7 +248,7 @@ class TaskDataset(pl.LightningDataModule): ) -class EnhancerDataset(TaskDataset): +class MayaDataset(TaskDataset): """ Dataset object for creating clean-noisy speech enhancement datasets paramters: diff --git a/mayavoz/models/dccrn.py b/mayavoz/models/dccrn.py index 278072f..6b8646c 100644 --- a/mayavoz/models/dccrn.py +++ b/mayavoz/models/dccrn.py @@ -5,7 +5,7 @@ import torch import torch.nn.functional as F from torch import nn -from mayavoz.data import EnhancerDataset +from mayavoz.data import MayaDataset from mayavoz.models import Mayamodel from mayavoz.models.complexnn import ( ComplexBatchNorm2D, @@ -134,13 +134,13 @@ class DCCRN(Mayamodel): num_channels: int = 1, sampling_rate=16000, lr: float = 1e-3, - dataset: Optional[EnhancerDataset] = None, + dataset: Optional[MayaDataset] = None, duration: Optional[float] = None, loss: Union[str, List, Any] = "mse", metric: Union[str, List] = "mse", ): duration = ( - dataset.duration if isinstance(dataset, EnhancerDataset) else None + dataset.duration if isinstance(dataset, MayaDataset) else None ) if dataset is not None: if sampling_rate != dataset.sampling_rate: diff --git a/mayavoz/models/demucs.py b/mayavoz/models/demucs.py index db69c80..8424f17 100644 --- a/mayavoz/models/demucs.py +++ b/mayavoz/models/demucs.py @@ -5,7 +5,7 @@ from typing import List, Optional, Union import torch.nn.functional as F from torch import nn -from mayavoz.data.dataset import EnhancerDataset +from mayavoz.data.dataset import MayaDataset from mayavoz.models.model import Mayamodel from mayavoz.utils.io import Audio as audio from mayavoz.utils.utils import merge_dict @@ -102,8 +102,8 @@ class Demucs(Mayamodel): sampling rate of input audio lr : float, defaults to 1e-3 learning rate used for training - dataset: EnhancerDataset, optional - EnhancerDataset object containing train/validation data for training + dataset: MayaDataset, optional + MayaDataset object containing train/validation data for training duration : float, optional chunk duration in seconds loss : string or List of strings @@ -135,13 +135,13 @@ class Demucs(Mayamodel): sampling_rate=16000, normalize=True, lr: float = 1e-3, - dataset: Optional[EnhancerDataset] = None, + dataset: Optional[MayaDataset] = None, loss: Union[str, List] = "mse", metric: Union[str, List] = "mse", floor=1e-3, ): duration = ( - dataset.duration if isinstance(dataset, EnhancerDataset) else None + dataset.duration if isinstance(dataset, MayaDataset) else None ) if dataset is not None: if sampling_rate != dataset.sampling_rate: diff --git a/mayavoz/models/model.py b/mayavoz/models/model.py index aede7a3..e248b2c 100644 --- a/mayavoz/models/model.py +++ b/mayavoz/models/model.py @@ -13,7 +13,7 @@ from pytorch_lightning.utilities.cloud_io import load as pl_load from torch import nn from torch.optim import Adam -from mayavoz.data.dataset import EnhancerDataset +from mayavoz.data.dataset import MayaDataset from mayavoz.inference import Inference from mayavoz.loss import LOSS_MAP, LossWrapper from mayavoz.version import __version__ @@ -37,7 +37,7 @@ class Mayamodel(pl.LightningModule): audio sampling rate lr: float, optional learning rate for model training - dataset: EnhancerDataset, optional + dataset: MayaDataset, optional mayavoz dataset used for training/validation duration: float, optional duration used for training/inference @@ -51,7 +51,7 @@ class Mayamodel(pl.LightningModule): num_channels: int = 1, sampling_rate: int = 16000, lr: float = 1e-3, - dataset: Optional[EnhancerDataset] = None, + dataset: Optional[MayaDataset] = None, duration: Optional[float] = None, loss: Union[str, List] = "mse", metric: Union[str, List, Any] = "mse", diff --git a/mayavoz/models/waveunet.py b/mayavoz/models/waveunet.py index 9e5a4ae..c9acfda 100644 --- a/mayavoz/models/waveunet.py +++ b/mayavoz/models/waveunet.py @@ -5,7 +5,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from mayavoz.data.dataset import EnhancerDataset +from mayavoz.data.dataset import MayaDataset from mayavoz.models.model import Mayamodel @@ -80,8 +80,8 @@ class WaveUnet(Mayamodel): sampling rate of input audio lr : float, defaults to 1e-3 learning rate used for training - dataset: EnhancerDataset, optional - EnhancerDataset object containing train/validation data for training + dataset: MayaDataset, optional + MayaDataset object containing train/validation data for training duration : float, optional chunk duration in seconds loss : string or List of strings @@ -97,13 +97,13 @@ class WaveUnet(Mayamodel): initial_output_channels: int = 24, sampling_rate: int = 16000, lr: float = 1e-3, - dataset: Optional[EnhancerDataset] = None, + dataset: Optional[MayaDataset] = None, duration: Optional[float] = None, loss: Union[str, List] = "mse", metric: Union[str, List] = "mse", ): duration = ( - dataset.duration if isinstance(dataset, EnhancerDataset) else None + dataset.duration if isinstance(dataset, MayaDataset) else None ) if dataset is not None: if sampling_rate != dataset.sampling_rate: diff --git a/notebooks/Custom_model_training.ipynb b/notebooks/Custom_model_training.ipynb index 2e5ed67..7c963c2 100644 --- a/notebooks/Custom_model_training.ipynb +++ b/notebooks/Custom_model_training.ipynb @@ -316,9 +316,9 @@ ], "metadata": { "kernelspec": { - "display_name": "mayavoz", + "display_name": "enhancer", "language": "python", - "name": "mayavoz" + "name": "enhancer" }, "language_info": { "codemirror_mode": { diff --git a/recipes/DNS/DNS-2020/cli/train_config/dataset/DNS-2020.yaml b/recipes/DNS/DNS-2020/cli/train_config/dataset/DNS-2020.yaml index 520efc9..5c67be2 100644 --- a/recipes/DNS/DNS-2020/cli/train_config/dataset/DNS-2020.yaml +++ b/recipes/DNS/DNS-2020/cli/train_config/dataset/DNS-2020.yaml @@ -1,4 +1,4 @@ -_target_: mayavoz.data.dataset.EnhancerDataset +_target_: mayavoz.data.dataset.MayaDataset root_dir : /Users/shahules/Myprojects/MS-SNSD name : dns-2020 duration : 2.0 diff --git a/recipes/DNS/DNS-2020/cli/train_config/dataset/Vctk.yaml b/recipes/DNS/DNS-2020/cli/train_config/dataset/Vctk.yaml index f30a835..584abe7 100644 --- a/recipes/DNS/DNS-2020/cli/train_config/dataset/Vctk.yaml +++ b/recipes/DNS/DNS-2020/cli/train_config/dataset/Vctk.yaml @@ -1,4 +1,4 @@ -_target_: mayavoz.data.dataset.EnhancerDataset +_target_: mayavoz.data.dataset.MayaDataset name : vctk root_dir : /scratch/c.sistc3/DS_10283_2791 duration : 4.5 diff --git a/recipes/Valentini-dataset/28spk/Demucs/train_config/dataset/Vctk.yaml b/recipes/Valentini-dataset/28spk/Demucs/train_config/dataset/Vctk.yaml index cca932a..8e726d1 100644 --- a/recipes/Valentini-dataset/28spk/Demucs/train_config/dataset/Vctk.yaml +++ b/recipes/Valentini-dataset/28spk/Demucs/train_config/dataset/Vctk.yaml @@ -1,4 +1,4 @@ -_target_: mayavoz.data.dataset.EnhancerDataset +_target_: mayavoz.data.dataset.MayaDataset name : vctk root_dir : /scratch/c.sistc3/DS_10283_2791 duration : 4.5 diff --git a/recipes/Valentini-dataset/28spk/WaveUnet/train_config/dataset/Vctk.yaml b/recipes/Valentini-dataset/28spk/WaveUnet/train_config/dataset/Vctk.yaml index 870bbb9..d2e6b30 100644 --- a/recipes/Valentini-dataset/28spk/WaveUnet/train_config/dataset/Vctk.yaml +++ b/recipes/Valentini-dataset/28spk/WaveUnet/train_config/dataset/Vctk.yaml @@ -1,4 +1,4 @@ -_target_: mayavoz.data.dataset.EnhancerDataset +_target_: mayavoz.data.dataset.MayaDataset name : vctk root_dir : /scratch/c.sistc3/DS_10283_2791 duration : 2 diff --git a/recipes/Valentini-dataset/28spk/cli/train_config/dataset/DNS-2020.yaml b/recipes/Valentini-dataset/28spk/cli/train_config/dataset/DNS-2020.yaml index 520efc9..5c67be2 100644 --- a/recipes/Valentini-dataset/28spk/cli/train_config/dataset/DNS-2020.yaml +++ b/recipes/Valentini-dataset/28spk/cli/train_config/dataset/DNS-2020.yaml @@ -1,4 +1,4 @@ -_target_: mayavoz.data.dataset.EnhancerDataset +_target_: mayavoz.data.dataset.MayaDataset root_dir : /Users/shahules/Myprojects/MS-SNSD name : dns-2020 duration : 2.0 diff --git a/recipes/Valentini-dataset/28spk/cli/train_config/dataset/Vctk.yaml b/recipes/Valentini-dataset/28spk/cli/train_config/dataset/Vctk.yaml index f30a835..584abe7 100644 --- a/recipes/Valentini-dataset/28spk/cli/train_config/dataset/Vctk.yaml +++ b/recipes/Valentini-dataset/28spk/cli/train_config/dataset/Vctk.yaml @@ -1,4 +1,4 @@ -_target_: mayavoz.data.dataset.EnhancerDataset +_target_: mayavoz.data.dataset.MayaDataset name : vctk root_dir : /scratch/c.sistc3/DS_10283_2791 duration : 4.5 diff --git a/tests/models/demucs_test.py b/tests/models/demucs_test.py index 51bdb27..e1203b7 100644 --- a/tests/models/demucs_test.py +++ b/tests/models/demucs_test.py @@ -1,7 +1,7 @@ import pytest import torch -from mayavoz.data.dataset import EnhancerDataset +from mayavoz.data.dataset import MayaDataset from mayavoz.models import Demucs from mayavoz.utils.config import Files @@ -15,7 +15,7 @@ def vctk_dataset(): test_clean="clean_testset_wav", test_noisy="noisy_testset_wav", ) - dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files) + dataset = MayaDataset(name="vctk", root_dir=root_dir, files=files) return dataset diff --git a/tests/models/test_dccrn.py b/tests/models/test_dccrn.py index 7e93af3..bc2a039 100644 --- a/tests/models/test_dccrn.py +++ b/tests/models/test_dccrn.py @@ -1,7 +1,7 @@ import pytest import torch -from mayavoz.data.dataset import EnhancerDataset +from mayavoz.data.dataset import MayaDataset from mayavoz.models.dccrn import DCCRN from mayavoz.utils.config import Files @@ -15,7 +15,7 @@ def vctk_dataset(): test_clean="clean_testset_wav", test_noisy="noisy_testset_wav", ) - dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files) + dataset = MayaDataset(name="vctk", root_dir=root_dir, files=files) return dataset diff --git a/tests/models/test_waveunet.py b/tests/models/test_waveunet.py index 9526820..bc250d1 100644 --- a/tests/models/test_waveunet.py +++ b/tests/models/test_waveunet.py @@ -1,7 +1,7 @@ import pytest import torch -from mayavoz.data.dataset import EnhancerDataset +from mayavoz.data.dataset import MayaDataset from mayavoz.models import WaveUnet from mayavoz.utils.config import Files @@ -15,7 +15,7 @@ def vctk_dataset(): test_clean="clean_testset_wav", test_noisy="noisy_testset_wav", ) - dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files) + dataset = MayaDataset(name="vctk", root_dir=root_dir, files=files) return dataset