rename dataset

This commit is contained in:
shahules786 2022-11-15 14:33:27 +05:30
parent bfd53937c2
commit 8bc63becce
18 changed files with 34 additions and 34 deletions

View File

@ -1,4 +1,4 @@
_target_: mayavoz.data.dataset.EnhancerDataset _target_: mayavoz.data.dataset.MayaDataset
root_dir : /Users/shahules/Myprojects/MS-SNSD root_dir : /Users/shahules/Myprojects/MS-SNSD
name : dns-2020 name : dns-2020
duration : 2.0 duration : 2.0

View File

@ -1,4 +1,4 @@
_target_: mayavoz.data.dataset.EnhancerDataset _target_: mayavoz.data.dataset.MayaDataset
name : vctk name : vctk
root_dir : /scratch/c.sistc3/DS_10283_2791 root_dir : /scratch/c.sistc3/DS_10283_2791
duration : 4.5 duration : 4.5

View File

@ -1 +1 @@
from mayavoz.data.dataset import EnhancerDataset from mayavoz.data.dataset import MayaDataset

View File

@ -248,7 +248,7 @@ class TaskDataset(pl.LightningDataModule):
) )
class EnhancerDataset(TaskDataset): class MayaDataset(TaskDataset):
""" """
Dataset object for creating clean-noisy speech enhancement datasets Dataset object for creating clean-noisy speech enhancement datasets
paramters: paramters:

View File

@ -5,7 +5,7 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn from torch import nn
from mayavoz.data import EnhancerDataset from mayavoz.data import MayaDataset
from mayavoz.models import Mayamodel from mayavoz.models import Mayamodel
from mayavoz.models.complexnn import ( from mayavoz.models.complexnn import (
ComplexBatchNorm2D, ComplexBatchNorm2D,
@ -134,13 +134,13 @@ class DCCRN(Mayamodel):
num_channels: int = 1, num_channels: int = 1,
sampling_rate=16000, sampling_rate=16000,
lr: float = 1e-3, lr: float = 1e-3,
dataset: Optional[EnhancerDataset] = None, dataset: Optional[MayaDataset] = None,
duration: Optional[float] = None, duration: Optional[float] = None,
loss: Union[str, List, Any] = "mse", loss: Union[str, List, Any] = "mse",
metric: Union[str, List] = "mse", metric: Union[str, List] = "mse",
): ):
duration = ( duration = (
dataset.duration if isinstance(dataset, EnhancerDataset) else None dataset.duration if isinstance(dataset, MayaDataset) else None
) )
if dataset is not None: if dataset is not None:
if sampling_rate != dataset.sampling_rate: if sampling_rate != dataset.sampling_rate:

View File

@ -5,7 +5,7 @@ from typing import List, Optional, Union
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn from torch import nn
from mayavoz.data.dataset import EnhancerDataset from mayavoz.data.dataset import MayaDataset
from mayavoz.models.model import Mayamodel from mayavoz.models.model import Mayamodel
from mayavoz.utils.io import Audio as audio from mayavoz.utils.io import Audio as audio
from mayavoz.utils.utils import merge_dict from mayavoz.utils.utils import merge_dict
@ -102,8 +102,8 @@ class Demucs(Mayamodel):
sampling rate of input audio sampling rate of input audio
lr : float, defaults to 1e-3 lr : float, defaults to 1e-3
learning rate used for training learning rate used for training
dataset: EnhancerDataset, optional dataset: MayaDataset, optional
EnhancerDataset object containing train/validation data for training MayaDataset object containing train/validation data for training
duration : float, optional duration : float, optional
chunk duration in seconds chunk duration in seconds
loss : string or List of strings loss : string or List of strings
@ -135,13 +135,13 @@ class Demucs(Mayamodel):
sampling_rate=16000, sampling_rate=16000,
normalize=True, normalize=True,
lr: float = 1e-3, lr: float = 1e-3,
dataset: Optional[EnhancerDataset] = None, dataset: Optional[MayaDataset] = None,
loss: Union[str, List] = "mse", loss: Union[str, List] = "mse",
metric: Union[str, List] = "mse", metric: Union[str, List] = "mse",
floor=1e-3, floor=1e-3,
): ):
duration = ( duration = (
dataset.duration if isinstance(dataset, EnhancerDataset) else None dataset.duration if isinstance(dataset, MayaDataset) else None
) )
if dataset is not None: if dataset is not None:
if sampling_rate != dataset.sampling_rate: if sampling_rate != dataset.sampling_rate:

View File

@ -13,7 +13,7 @@ from pytorch_lightning.utilities.cloud_io import load as pl_load
from torch import nn from torch import nn
from torch.optim import Adam from torch.optim import Adam
from mayavoz.data.dataset import EnhancerDataset from mayavoz.data.dataset import MayaDataset
from mayavoz.inference import Inference from mayavoz.inference import Inference
from mayavoz.loss import LOSS_MAP, LossWrapper from mayavoz.loss import LOSS_MAP, LossWrapper
from mayavoz.version import __version__ from mayavoz.version import __version__
@ -37,7 +37,7 @@ class Mayamodel(pl.LightningModule):
audio sampling rate audio sampling rate
lr: float, optional lr: float, optional
learning rate for model training learning rate for model training
dataset: EnhancerDataset, optional dataset: MayaDataset, optional
mayavoz dataset used for training/validation mayavoz dataset used for training/validation
duration: float, optional duration: float, optional
duration used for training/inference duration used for training/inference
@ -51,7 +51,7 @@ class Mayamodel(pl.LightningModule):
num_channels: int = 1, num_channels: int = 1,
sampling_rate: int = 16000, sampling_rate: int = 16000,
lr: float = 1e-3, lr: float = 1e-3,
dataset: Optional[EnhancerDataset] = None, dataset: Optional[MayaDataset] = None,
duration: Optional[float] = None, duration: Optional[float] = None,
loss: Union[str, List] = "mse", loss: Union[str, List] = "mse",
metric: Union[str, List, Any] = "mse", metric: Union[str, List, Any] = "mse",

View File

@ -5,7 +5,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from mayavoz.data.dataset import EnhancerDataset from mayavoz.data.dataset import MayaDataset
from mayavoz.models.model import Mayamodel from mayavoz.models.model import Mayamodel
@ -80,8 +80,8 @@ class WaveUnet(Mayamodel):
sampling rate of input audio sampling rate of input audio
lr : float, defaults to 1e-3 lr : float, defaults to 1e-3
learning rate used for training learning rate used for training
dataset: EnhancerDataset, optional dataset: MayaDataset, optional
EnhancerDataset object containing train/validation data for training MayaDataset object containing train/validation data for training
duration : float, optional duration : float, optional
chunk duration in seconds chunk duration in seconds
loss : string or List of strings loss : string or List of strings
@ -97,13 +97,13 @@ class WaveUnet(Mayamodel):
initial_output_channels: int = 24, initial_output_channels: int = 24,
sampling_rate: int = 16000, sampling_rate: int = 16000,
lr: float = 1e-3, lr: float = 1e-3,
dataset: Optional[EnhancerDataset] = None, dataset: Optional[MayaDataset] = None,
duration: Optional[float] = None, duration: Optional[float] = None,
loss: Union[str, List] = "mse", loss: Union[str, List] = "mse",
metric: Union[str, List] = "mse", metric: Union[str, List] = "mse",
): ):
duration = ( duration = (
dataset.duration if isinstance(dataset, EnhancerDataset) else None dataset.duration if isinstance(dataset, MayaDataset) else None
) )
if dataset is not None: if dataset is not None:
if sampling_rate != dataset.sampling_rate: if sampling_rate != dataset.sampling_rate:

View File

@ -316,9 +316,9 @@
], ],
"metadata": { "metadata": {
"kernelspec": { "kernelspec": {
"display_name": "mayavoz", "display_name": "enhancer",
"language": "python", "language": "python",
"name": "mayavoz" "name": "enhancer"
}, },
"language_info": { "language_info": {
"codemirror_mode": { "codemirror_mode": {

View File

@ -1,4 +1,4 @@
_target_: mayavoz.data.dataset.EnhancerDataset _target_: mayavoz.data.dataset.MayaDataset
root_dir : /Users/shahules/Myprojects/MS-SNSD root_dir : /Users/shahules/Myprojects/MS-SNSD
name : dns-2020 name : dns-2020
duration : 2.0 duration : 2.0

View File

@ -1,4 +1,4 @@
_target_: mayavoz.data.dataset.EnhancerDataset _target_: mayavoz.data.dataset.MayaDataset
name : vctk name : vctk
root_dir : /scratch/c.sistc3/DS_10283_2791 root_dir : /scratch/c.sistc3/DS_10283_2791
duration : 4.5 duration : 4.5

View File

@ -1,4 +1,4 @@
_target_: mayavoz.data.dataset.EnhancerDataset _target_: mayavoz.data.dataset.MayaDataset
name : vctk name : vctk
root_dir : /scratch/c.sistc3/DS_10283_2791 root_dir : /scratch/c.sistc3/DS_10283_2791
duration : 4.5 duration : 4.5

View File

@ -1,4 +1,4 @@
_target_: mayavoz.data.dataset.EnhancerDataset _target_: mayavoz.data.dataset.MayaDataset
name : vctk name : vctk
root_dir : /scratch/c.sistc3/DS_10283_2791 root_dir : /scratch/c.sistc3/DS_10283_2791
duration : 2 duration : 2

View File

@ -1,4 +1,4 @@
_target_: mayavoz.data.dataset.EnhancerDataset _target_: mayavoz.data.dataset.MayaDataset
root_dir : /Users/shahules/Myprojects/MS-SNSD root_dir : /Users/shahules/Myprojects/MS-SNSD
name : dns-2020 name : dns-2020
duration : 2.0 duration : 2.0

View File

@ -1,4 +1,4 @@
_target_: mayavoz.data.dataset.EnhancerDataset _target_: mayavoz.data.dataset.MayaDataset
name : vctk name : vctk
root_dir : /scratch/c.sistc3/DS_10283_2791 root_dir : /scratch/c.sistc3/DS_10283_2791
duration : 4.5 duration : 4.5

View File

@ -1,7 +1,7 @@
import pytest import pytest
import torch import torch
from mayavoz.data.dataset import EnhancerDataset from mayavoz.data.dataset import MayaDataset
from mayavoz.models import Demucs from mayavoz.models import Demucs
from mayavoz.utils.config import Files from mayavoz.utils.config import Files
@ -15,7 +15,7 @@ def vctk_dataset():
test_clean="clean_testset_wav", test_clean="clean_testset_wav",
test_noisy="noisy_testset_wav", test_noisy="noisy_testset_wav",
) )
dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files) dataset = MayaDataset(name="vctk", root_dir=root_dir, files=files)
return dataset return dataset

View File

@ -1,7 +1,7 @@
import pytest import pytest
import torch import torch
from mayavoz.data.dataset import EnhancerDataset from mayavoz.data.dataset import MayaDataset
from mayavoz.models.dccrn import DCCRN from mayavoz.models.dccrn import DCCRN
from mayavoz.utils.config import Files from mayavoz.utils.config import Files
@ -15,7 +15,7 @@ def vctk_dataset():
test_clean="clean_testset_wav", test_clean="clean_testset_wav",
test_noisy="noisy_testset_wav", test_noisy="noisy_testset_wav",
) )
dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files) dataset = MayaDataset(name="vctk", root_dir=root_dir, files=files)
return dataset return dataset

View File

@ -1,7 +1,7 @@
import pytest import pytest
import torch import torch
from mayavoz.data.dataset import EnhancerDataset from mayavoz.data.dataset import MayaDataset
from mayavoz.models import WaveUnet from mayavoz.models import WaveUnet
from mayavoz.utils.config import Files from mayavoz.utils.config import Files
@ -15,7 +15,7 @@ def vctk_dataset():
test_clean="clean_testset_wav", test_clean="clean_testset_wav",
test_noisy="noisy_testset_wav", test_noisy="noisy_testset_wav",
) )
dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files) dataset = MayaDataset(name="vctk", root_dir=root_dir, files=files)
return dataset return dataset