rename dataset
This commit is contained in:
parent
bfd53937c2
commit
8bc63becce
|
|
@ -1,4 +1,4 @@
|
||||||
_target_: mayavoz.data.dataset.EnhancerDataset
|
_target_: mayavoz.data.dataset.MayaDataset
|
||||||
root_dir : /Users/shahules/Myprojects/MS-SNSD
|
root_dir : /Users/shahules/Myprojects/MS-SNSD
|
||||||
name : dns-2020
|
name : dns-2020
|
||||||
duration : 2.0
|
duration : 2.0
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
_target_: mayavoz.data.dataset.EnhancerDataset
|
_target_: mayavoz.data.dataset.MayaDataset
|
||||||
name : vctk
|
name : vctk
|
||||||
root_dir : /scratch/c.sistc3/DS_10283_2791
|
root_dir : /scratch/c.sistc3/DS_10283_2791
|
||||||
duration : 4.5
|
duration : 4.5
|
||||||
|
|
|
||||||
|
|
@ -1 +1 @@
|
||||||
from mayavoz.data.dataset import EnhancerDataset
|
from mayavoz.data.dataset import MayaDataset
|
||||||
|
|
|
||||||
|
|
@ -248,7 +248,7 @@ class TaskDataset(pl.LightningDataModule):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class EnhancerDataset(TaskDataset):
|
class MayaDataset(TaskDataset):
|
||||||
"""
|
"""
|
||||||
Dataset object for creating clean-noisy speech enhancement datasets
|
Dataset object for creating clean-noisy speech enhancement datasets
|
||||||
paramters:
|
paramters:
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from mayavoz.data import EnhancerDataset
|
from mayavoz.data import MayaDataset
|
||||||
from mayavoz.models import Mayamodel
|
from mayavoz.models import Mayamodel
|
||||||
from mayavoz.models.complexnn import (
|
from mayavoz.models.complexnn import (
|
||||||
ComplexBatchNorm2D,
|
ComplexBatchNorm2D,
|
||||||
|
|
@ -134,13 +134,13 @@ class DCCRN(Mayamodel):
|
||||||
num_channels: int = 1,
|
num_channels: int = 1,
|
||||||
sampling_rate=16000,
|
sampling_rate=16000,
|
||||||
lr: float = 1e-3,
|
lr: float = 1e-3,
|
||||||
dataset: Optional[EnhancerDataset] = None,
|
dataset: Optional[MayaDataset] = None,
|
||||||
duration: Optional[float] = None,
|
duration: Optional[float] = None,
|
||||||
loss: Union[str, List, Any] = "mse",
|
loss: Union[str, List, Any] = "mse",
|
||||||
metric: Union[str, List] = "mse",
|
metric: Union[str, List] = "mse",
|
||||||
):
|
):
|
||||||
duration = (
|
duration = (
|
||||||
dataset.duration if isinstance(dataset, EnhancerDataset) else None
|
dataset.duration if isinstance(dataset, MayaDataset) else None
|
||||||
)
|
)
|
||||||
if dataset is not None:
|
if dataset is not None:
|
||||||
if sampling_rate != dataset.sampling_rate:
|
if sampling_rate != dataset.sampling_rate:
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ from typing import List, Optional, Union
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from mayavoz.data.dataset import EnhancerDataset
|
from mayavoz.data.dataset import MayaDataset
|
||||||
from mayavoz.models.model import Mayamodel
|
from mayavoz.models.model import Mayamodel
|
||||||
from mayavoz.utils.io import Audio as audio
|
from mayavoz.utils.io import Audio as audio
|
||||||
from mayavoz.utils.utils import merge_dict
|
from mayavoz.utils.utils import merge_dict
|
||||||
|
|
@ -102,8 +102,8 @@ class Demucs(Mayamodel):
|
||||||
sampling rate of input audio
|
sampling rate of input audio
|
||||||
lr : float, defaults to 1e-3
|
lr : float, defaults to 1e-3
|
||||||
learning rate used for training
|
learning rate used for training
|
||||||
dataset: EnhancerDataset, optional
|
dataset: MayaDataset, optional
|
||||||
EnhancerDataset object containing train/validation data for training
|
MayaDataset object containing train/validation data for training
|
||||||
duration : float, optional
|
duration : float, optional
|
||||||
chunk duration in seconds
|
chunk duration in seconds
|
||||||
loss : string or List of strings
|
loss : string or List of strings
|
||||||
|
|
@ -135,13 +135,13 @@ class Demucs(Mayamodel):
|
||||||
sampling_rate=16000,
|
sampling_rate=16000,
|
||||||
normalize=True,
|
normalize=True,
|
||||||
lr: float = 1e-3,
|
lr: float = 1e-3,
|
||||||
dataset: Optional[EnhancerDataset] = None,
|
dataset: Optional[MayaDataset] = None,
|
||||||
loss: Union[str, List] = "mse",
|
loss: Union[str, List] = "mse",
|
||||||
metric: Union[str, List] = "mse",
|
metric: Union[str, List] = "mse",
|
||||||
floor=1e-3,
|
floor=1e-3,
|
||||||
):
|
):
|
||||||
duration = (
|
duration = (
|
||||||
dataset.duration if isinstance(dataset, EnhancerDataset) else None
|
dataset.duration if isinstance(dataset, MayaDataset) else None
|
||||||
)
|
)
|
||||||
if dataset is not None:
|
if dataset is not None:
|
||||||
if sampling_rate != dataset.sampling_rate:
|
if sampling_rate != dataset.sampling_rate:
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,7 @@ from pytorch_lightning.utilities.cloud_io import load as pl_load
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.optim import Adam
|
from torch.optim import Adam
|
||||||
|
|
||||||
from mayavoz.data.dataset import EnhancerDataset
|
from mayavoz.data.dataset import MayaDataset
|
||||||
from mayavoz.inference import Inference
|
from mayavoz.inference import Inference
|
||||||
from mayavoz.loss import LOSS_MAP, LossWrapper
|
from mayavoz.loss import LOSS_MAP, LossWrapper
|
||||||
from mayavoz.version import __version__
|
from mayavoz.version import __version__
|
||||||
|
|
@ -37,7 +37,7 @@ class Mayamodel(pl.LightningModule):
|
||||||
audio sampling rate
|
audio sampling rate
|
||||||
lr: float, optional
|
lr: float, optional
|
||||||
learning rate for model training
|
learning rate for model training
|
||||||
dataset: EnhancerDataset, optional
|
dataset: MayaDataset, optional
|
||||||
mayavoz dataset used for training/validation
|
mayavoz dataset used for training/validation
|
||||||
duration: float, optional
|
duration: float, optional
|
||||||
duration used for training/inference
|
duration used for training/inference
|
||||||
|
|
@ -51,7 +51,7 @@ class Mayamodel(pl.LightningModule):
|
||||||
num_channels: int = 1,
|
num_channels: int = 1,
|
||||||
sampling_rate: int = 16000,
|
sampling_rate: int = 16000,
|
||||||
lr: float = 1e-3,
|
lr: float = 1e-3,
|
||||||
dataset: Optional[EnhancerDataset] = None,
|
dataset: Optional[MayaDataset] = None,
|
||||||
duration: Optional[float] = None,
|
duration: Optional[float] = None,
|
||||||
loss: Union[str, List] = "mse",
|
loss: Union[str, List] = "mse",
|
||||||
metric: Union[str, List, Any] = "mse",
|
metric: Union[str, List, Any] = "mse",
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from mayavoz.data.dataset import EnhancerDataset
|
from mayavoz.data.dataset import MayaDataset
|
||||||
from mayavoz.models.model import Mayamodel
|
from mayavoz.models.model import Mayamodel
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -80,8 +80,8 @@ class WaveUnet(Mayamodel):
|
||||||
sampling rate of input audio
|
sampling rate of input audio
|
||||||
lr : float, defaults to 1e-3
|
lr : float, defaults to 1e-3
|
||||||
learning rate used for training
|
learning rate used for training
|
||||||
dataset: EnhancerDataset, optional
|
dataset: MayaDataset, optional
|
||||||
EnhancerDataset object containing train/validation data for training
|
MayaDataset object containing train/validation data for training
|
||||||
duration : float, optional
|
duration : float, optional
|
||||||
chunk duration in seconds
|
chunk duration in seconds
|
||||||
loss : string or List of strings
|
loss : string or List of strings
|
||||||
|
|
@ -97,13 +97,13 @@ class WaveUnet(Mayamodel):
|
||||||
initial_output_channels: int = 24,
|
initial_output_channels: int = 24,
|
||||||
sampling_rate: int = 16000,
|
sampling_rate: int = 16000,
|
||||||
lr: float = 1e-3,
|
lr: float = 1e-3,
|
||||||
dataset: Optional[EnhancerDataset] = None,
|
dataset: Optional[MayaDataset] = None,
|
||||||
duration: Optional[float] = None,
|
duration: Optional[float] = None,
|
||||||
loss: Union[str, List] = "mse",
|
loss: Union[str, List] = "mse",
|
||||||
metric: Union[str, List] = "mse",
|
metric: Union[str, List] = "mse",
|
||||||
):
|
):
|
||||||
duration = (
|
duration = (
|
||||||
dataset.duration if isinstance(dataset, EnhancerDataset) else None
|
dataset.duration if isinstance(dataset, MayaDataset) else None
|
||||||
)
|
)
|
||||||
if dataset is not None:
|
if dataset is not None:
|
||||||
if sampling_rate != dataset.sampling_rate:
|
if sampling_rate != dataset.sampling_rate:
|
||||||
|
|
|
||||||
|
|
@ -316,9 +316,9 @@
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"kernelspec": {
|
"kernelspec": {
|
||||||
"display_name": "mayavoz",
|
"display_name": "enhancer",
|
||||||
"language": "python",
|
"language": "python",
|
||||||
"name": "mayavoz"
|
"name": "enhancer"
|
||||||
},
|
},
|
||||||
"language_info": {
|
"language_info": {
|
||||||
"codemirror_mode": {
|
"codemirror_mode": {
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
_target_: mayavoz.data.dataset.EnhancerDataset
|
_target_: mayavoz.data.dataset.MayaDataset
|
||||||
root_dir : /Users/shahules/Myprojects/MS-SNSD
|
root_dir : /Users/shahules/Myprojects/MS-SNSD
|
||||||
name : dns-2020
|
name : dns-2020
|
||||||
duration : 2.0
|
duration : 2.0
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
_target_: mayavoz.data.dataset.EnhancerDataset
|
_target_: mayavoz.data.dataset.MayaDataset
|
||||||
name : vctk
|
name : vctk
|
||||||
root_dir : /scratch/c.sistc3/DS_10283_2791
|
root_dir : /scratch/c.sistc3/DS_10283_2791
|
||||||
duration : 4.5
|
duration : 4.5
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
_target_: mayavoz.data.dataset.EnhancerDataset
|
_target_: mayavoz.data.dataset.MayaDataset
|
||||||
name : vctk
|
name : vctk
|
||||||
root_dir : /scratch/c.sistc3/DS_10283_2791
|
root_dir : /scratch/c.sistc3/DS_10283_2791
|
||||||
duration : 4.5
|
duration : 4.5
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
_target_: mayavoz.data.dataset.EnhancerDataset
|
_target_: mayavoz.data.dataset.MayaDataset
|
||||||
name : vctk
|
name : vctk
|
||||||
root_dir : /scratch/c.sistc3/DS_10283_2791
|
root_dir : /scratch/c.sistc3/DS_10283_2791
|
||||||
duration : 2
|
duration : 2
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
_target_: mayavoz.data.dataset.EnhancerDataset
|
_target_: mayavoz.data.dataset.MayaDataset
|
||||||
root_dir : /Users/shahules/Myprojects/MS-SNSD
|
root_dir : /Users/shahules/Myprojects/MS-SNSD
|
||||||
name : dns-2020
|
name : dns-2020
|
||||||
duration : 2.0
|
duration : 2.0
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
_target_: mayavoz.data.dataset.EnhancerDataset
|
_target_: mayavoz.data.dataset.MayaDataset
|
||||||
name : vctk
|
name : vctk
|
||||||
root_dir : /scratch/c.sistc3/DS_10283_2791
|
root_dir : /scratch/c.sistc3/DS_10283_2791
|
||||||
duration : 4.5
|
duration : 4.5
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from mayavoz.data.dataset import EnhancerDataset
|
from mayavoz.data.dataset import MayaDataset
|
||||||
from mayavoz.models import Demucs
|
from mayavoz.models import Demucs
|
||||||
from mayavoz.utils.config import Files
|
from mayavoz.utils.config import Files
|
||||||
|
|
||||||
|
|
@ -15,7 +15,7 @@ def vctk_dataset():
|
||||||
test_clean="clean_testset_wav",
|
test_clean="clean_testset_wav",
|
||||||
test_noisy="noisy_testset_wav",
|
test_noisy="noisy_testset_wav",
|
||||||
)
|
)
|
||||||
dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files)
|
dataset = MayaDataset(name="vctk", root_dir=root_dir, files=files)
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from mayavoz.data.dataset import EnhancerDataset
|
from mayavoz.data.dataset import MayaDataset
|
||||||
from mayavoz.models.dccrn import DCCRN
|
from mayavoz.models.dccrn import DCCRN
|
||||||
from mayavoz.utils.config import Files
|
from mayavoz.utils.config import Files
|
||||||
|
|
||||||
|
|
@ -15,7 +15,7 @@ def vctk_dataset():
|
||||||
test_clean="clean_testset_wav",
|
test_clean="clean_testset_wav",
|
||||||
test_noisy="noisy_testset_wav",
|
test_noisy="noisy_testset_wav",
|
||||||
)
|
)
|
||||||
dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files)
|
dataset = MayaDataset(name="vctk", root_dir=root_dir, files=files)
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from mayavoz.data.dataset import EnhancerDataset
|
from mayavoz.data.dataset import MayaDataset
|
||||||
from mayavoz.models import WaveUnet
|
from mayavoz.models import WaveUnet
|
||||||
from mayavoz.utils.config import Files
|
from mayavoz.utils.config import Files
|
||||||
|
|
||||||
|
|
@ -15,7 +15,7 @@ def vctk_dataset():
|
||||||
test_clean="clean_testset_wav",
|
test_clean="clean_testset_wav",
|
||||||
test_noisy="noisy_testset_wav",
|
test_noisy="noisy_testset_wav",
|
||||||
)
|
)
|
||||||
dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files)
|
dataset = MayaDataset(name="vctk", root_dir=root_dir, files=files)
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue