rename dataset
This commit is contained in:
parent
bfd53937c2
commit
8bc63becce
|
|
@ -1,4 +1,4 @@
|
|||
_target_: mayavoz.data.dataset.EnhancerDataset
|
||||
_target_: mayavoz.data.dataset.MayaDataset
|
||||
root_dir : /Users/shahules/Myprojects/MS-SNSD
|
||||
name : dns-2020
|
||||
duration : 2.0
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
_target_: mayavoz.data.dataset.EnhancerDataset
|
||||
_target_: mayavoz.data.dataset.MayaDataset
|
||||
name : vctk
|
||||
root_dir : /scratch/c.sistc3/DS_10283_2791
|
||||
duration : 4.5
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
from mayavoz.data.dataset import EnhancerDataset
|
||||
from mayavoz.data.dataset import MayaDataset
|
||||
|
|
|
|||
|
|
@ -248,7 +248,7 @@ class TaskDataset(pl.LightningDataModule):
|
|||
)
|
||||
|
||||
|
||||
class EnhancerDataset(TaskDataset):
|
||||
class MayaDataset(TaskDataset):
|
||||
"""
|
||||
Dataset object for creating clean-noisy speech enhancement datasets
|
||||
paramters:
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import torch
|
|||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from mayavoz.data import EnhancerDataset
|
||||
from mayavoz.data import MayaDataset
|
||||
from mayavoz.models import Mayamodel
|
||||
from mayavoz.models.complexnn import (
|
||||
ComplexBatchNorm2D,
|
||||
|
|
@ -134,13 +134,13 @@ class DCCRN(Mayamodel):
|
|||
num_channels: int = 1,
|
||||
sampling_rate=16000,
|
||||
lr: float = 1e-3,
|
||||
dataset: Optional[EnhancerDataset] = None,
|
||||
dataset: Optional[MayaDataset] = None,
|
||||
duration: Optional[float] = None,
|
||||
loss: Union[str, List, Any] = "mse",
|
||||
metric: Union[str, List] = "mse",
|
||||
):
|
||||
duration = (
|
||||
dataset.duration if isinstance(dataset, EnhancerDataset) else None
|
||||
dataset.duration if isinstance(dataset, MayaDataset) else None
|
||||
)
|
||||
if dataset is not None:
|
||||
if sampling_rate != dataset.sampling_rate:
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from typing import List, Optional, Union
|
|||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from mayavoz.data.dataset import EnhancerDataset
|
||||
from mayavoz.data.dataset import MayaDataset
|
||||
from mayavoz.models.model import Mayamodel
|
||||
from mayavoz.utils.io import Audio as audio
|
||||
from mayavoz.utils.utils import merge_dict
|
||||
|
|
@ -102,8 +102,8 @@ class Demucs(Mayamodel):
|
|||
sampling rate of input audio
|
||||
lr : float, defaults to 1e-3
|
||||
learning rate used for training
|
||||
dataset: EnhancerDataset, optional
|
||||
EnhancerDataset object containing train/validation data for training
|
||||
dataset: MayaDataset, optional
|
||||
MayaDataset object containing train/validation data for training
|
||||
duration : float, optional
|
||||
chunk duration in seconds
|
||||
loss : string or List of strings
|
||||
|
|
@ -135,13 +135,13 @@ class Demucs(Mayamodel):
|
|||
sampling_rate=16000,
|
||||
normalize=True,
|
||||
lr: float = 1e-3,
|
||||
dataset: Optional[EnhancerDataset] = None,
|
||||
dataset: Optional[MayaDataset] = None,
|
||||
loss: Union[str, List] = "mse",
|
||||
metric: Union[str, List] = "mse",
|
||||
floor=1e-3,
|
||||
):
|
||||
duration = (
|
||||
dataset.duration if isinstance(dataset, EnhancerDataset) else None
|
||||
dataset.duration if isinstance(dataset, MayaDataset) else None
|
||||
)
|
||||
if dataset is not None:
|
||||
if sampling_rate != dataset.sampling_rate:
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from pytorch_lightning.utilities.cloud_io import load as pl_load
|
|||
from torch import nn
|
||||
from torch.optim import Adam
|
||||
|
||||
from mayavoz.data.dataset import EnhancerDataset
|
||||
from mayavoz.data.dataset import MayaDataset
|
||||
from mayavoz.inference import Inference
|
||||
from mayavoz.loss import LOSS_MAP, LossWrapper
|
||||
from mayavoz.version import __version__
|
||||
|
|
@ -37,7 +37,7 @@ class Mayamodel(pl.LightningModule):
|
|||
audio sampling rate
|
||||
lr: float, optional
|
||||
learning rate for model training
|
||||
dataset: EnhancerDataset, optional
|
||||
dataset: MayaDataset, optional
|
||||
mayavoz dataset used for training/validation
|
||||
duration: float, optional
|
||||
duration used for training/inference
|
||||
|
|
@ -51,7 +51,7 @@ class Mayamodel(pl.LightningModule):
|
|||
num_channels: int = 1,
|
||||
sampling_rate: int = 16000,
|
||||
lr: float = 1e-3,
|
||||
dataset: Optional[EnhancerDataset] = None,
|
||||
dataset: Optional[MayaDataset] = None,
|
||||
duration: Optional[float] = None,
|
||||
loss: Union[str, List] = "mse",
|
||||
metric: Union[str, List, Any] = "mse",
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from mayavoz.data.dataset import EnhancerDataset
|
||||
from mayavoz.data.dataset import MayaDataset
|
||||
from mayavoz.models.model import Mayamodel
|
||||
|
||||
|
||||
|
|
@ -80,8 +80,8 @@ class WaveUnet(Mayamodel):
|
|||
sampling rate of input audio
|
||||
lr : float, defaults to 1e-3
|
||||
learning rate used for training
|
||||
dataset: EnhancerDataset, optional
|
||||
EnhancerDataset object containing train/validation data for training
|
||||
dataset: MayaDataset, optional
|
||||
MayaDataset object containing train/validation data for training
|
||||
duration : float, optional
|
||||
chunk duration in seconds
|
||||
loss : string or List of strings
|
||||
|
|
@ -97,13 +97,13 @@ class WaveUnet(Mayamodel):
|
|||
initial_output_channels: int = 24,
|
||||
sampling_rate: int = 16000,
|
||||
lr: float = 1e-3,
|
||||
dataset: Optional[EnhancerDataset] = None,
|
||||
dataset: Optional[MayaDataset] = None,
|
||||
duration: Optional[float] = None,
|
||||
loss: Union[str, List] = "mse",
|
||||
metric: Union[str, List] = "mse",
|
||||
):
|
||||
duration = (
|
||||
dataset.duration if isinstance(dataset, EnhancerDataset) else None
|
||||
dataset.duration if isinstance(dataset, MayaDataset) else None
|
||||
)
|
||||
if dataset is not None:
|
||||
if sampling_rate != dataset.sampling_rate:
|
||||
|
|
|
|||
|
|
@ -316,9 +316,9 @@
|
|||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "mayavoz",
|
||||
"display_name": "enhancer",
|
||||
"language": "python",
|
||||
"name": "mayavoz"
|
||||
"name": "enhancer"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
_target_: mayavoz.data.dataset.EnhancerDataset
|
||||
_target_: mayavoz.data.dataset.MayaDataset
|
||||
root_dir : /Users/shahules/Myprojects/MS-SNSD
|
||||
name : dns-2020
|
||||
duration : 2.0
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
_target_: mayavoz.data.dataset.EnhancerDataset
|
||||
_target_: mayavoz.data.dataset.MayaDataset
|
||||
name : vctk
|
||||
root_dir : /scratch/c.sistc3/DS_10283_2791
|
||||
duration : 4.5
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
_target_: mayavoz.data.dataset.EnhancerDataset
|
||||
_target_: mayavoz.data.dataset.MayaDataset
|
||||
name : vctk
|
||||
root_dir : /scratch/c.sistc3/DS_10283_2791
|
||||
duration : 4.5
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
_target_: mayavoz.data.dataset.EnhancerDataset
|
||||
_target_: mayavoz.data.dataset.MayaDataset
|
||||
name : vctk
|
||||
root_dir : /scratch/c.sistc3/DS_10283_2791
|
||||
duration : 2
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
_target_: mayavoz.data.dataset.EnhancerDataset
|
||||
_target_: mayavoz.data.dataset.MayaDataset
|
||||
root_dir : /Users/shahules/Myprojects/MS-SNSD
|
||||
name : dns-2020
|
||||
duration : 2.0
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
_target_: mayavoz.data.dataset.EnhancerDataset
|
||||
_target_: mayavoz.data.dataset.MayaDataset
|
||||
name : vctk
|
||||
root_dir : /scratch/c.sistc3/DS_10283_2791
|
||||
duration : 4.5
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from mayavoz.data.dataset import EnhancerDataset
|
||||
from mayavoz.data.dataset import MayaDataset
|
||||
from mayavoz.models import Demucs
|
||||
from mayavoz.utils.config import Files
|
||||
|
||||
|
|
@ -15,7 +15,7 @@ def vctk_dataset():
|
|||
test_clean="clean_testset_wav",
|
||||
test_noisy="noisy_testset_wav",
|
||||
)
|
||||
dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files)
|
||||
dataset = MayaDataset(name="vctk", root_dir=root_dir, files=files)
|
||||
return dataset
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from mayavoz.data.dataset import EnhancerDataset
|
||||
from mayavoz.data.dataset import MayaDataset
|
||||
from mayavoz.models.dccrn import DCCRN
|
||||
from mayavoz.utils.config import Files
|
||||
|
||||
|
|
@ -15,7 +15,7 @@ def vctk_dataset():
|
|||
test_clean="clean_testset_wav",
|
||||
test_noisy="noisy_testset_wav",
|
||||
)
|
||||
dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files)
|
||||
dataset = MayaDataset(name="vctk", root_dir=root_dir, files=files)
|
||||
return dataset
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from mayavoz.data.dataset import EnhancerDataset
|
||||
from mayavoz.data.dataset import MayaDataset
|
||||
from mayavoz.models import WaveUnet
|
||||
from mayavoz.utils.config import Files
|
||||
|
||||
|
|
@ -15,7 +15,7 @@ def vctk_dataset():
|
|||
test_clean="clean_testset_wav",
|
||||
test_noisy="noisy_testset_wav",
|
||||
)
|
||||
dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files)
|
||||
dataset = MayaDataset(name="vctk", root_dir=root_dir, files=files)
|
||||
return dataset
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue