commit
a4f0fda6a5
2
.flake8
2
.flake8
|
|
@ -1,5 +1,5 @@
|
||||||
[flake8]
|
[flake8]
|
||||||
per-file-ignores = __init__.py:F401
|
per-file-ignores = "mayavoz/model/__init__.py:F401"
|
||||||
ignore = E203, E266, E501, W503
|
ignore = E203, E266, E501, W503
|
||||||
# line length is intentionally set to 80 here because black uses Bugbear
|
# line length is intentionally set to 80 here because black uses Bugbear
|
||||||
# See https://github.com/psf/black/blob/master/README.md#line-length for more details
|
# See https://github.com/psf/black/blob/master/README.md#line-length for more details
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,13 @@
|
||||||
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
|
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
|
||||||
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
|
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
|
||||||
|
|
||||||
name: Enhancer
|
name: mayavoz
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
branches: [ dev ]
|
branches: [ dev ]
|
||||||
pull_request:
|
pull_request:
|
||||||
branches: [ dev ]
|
branches: [ main ]
|
||||||
jobs:
|
jobs:
|
||||||
build:
|
build:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
|
@ -40,12 +40,12 @@ jobs:
|
||||||
sudo apt-get install libsndfile1
|
sudo apt-get install libsndfile1
|
||||||
pip install -r requirements.txt
|
pip install -r requirements.txt
|
||||||
pip install black pytest-cov
|
pip install black pytest-cov
|
||||||
- name: Install enhancer
|
- name: Install mayavoz
|
||||||
run: |
|
run: |
|
||||||
pip install -e .[dev,testing]
|
pip install -e .[dev,testing]
|
||||||
- name: Run black
|
- name: Run black
|
||||||
run:
|
run:
|
||||||
black --check . --exclude enhancer/version.py
|
black --check . --exclude mayavoz/version.py
|
||||||
- name: Test with pytest
|
- name: Test with pytest
|
||||||
run:
|
run:
|
||||||
pytest tests --cov=enhancer/
|
pytest tests --cov=mayavoz/
|
||||||
|
|
|
||||||
|
|
@ -23,6 +23,7 @@ repos:
|
||||||
hooks:
|
hooks:
|
||||||
- id: flake8
|
- id: flake8
|
||||||
args: ['--ignore=E203,E501,F811,E712,W503']
|
args: ['--ignore=E203,E501,F811,E712,W503']
|
||||||
|
exclude: __init__.py
|
||||||
|
|
||||||
# Formatting, Whitespace, etc
|
# Formatting, Whitespace, etc
|
||||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||||
|
|
|
||||||
|
|
@ -1,13 +0,0 @@
|
||||||
_target_: enhancer.data.dataset.EnhancerDataset
|
|
||||||
name : vctk
|
|
||||||
root_dir : /Users/shahules/Myprojects/enhancer/datasets/vctk
|
|
||||||
duration : 1.0
|
|
||||||
sampling_rate: 16000
|
|
||||||
batch_size: 64
|
|
||||||
num_workers : 0
|
|
||||||
|
|
||||||
files:
|
|
||||||
train_clean : clean_testset_wav
|
|
||||||
test_clean : clean_testset_wav
|
|
||||||
train_noisy : noisy_testset_wav
|
|
||||||
test_noisy : noisy_testset_wav
|
|
||||||
|
|
@ -1 +0,0 @@
|
||||||
from enhancer.data.dataset import EnhancerDataset
|
|
||||||
|
|
@ -1,3 +0,0 @@
|
||||||
from enhancer.models.demucs import Demucs
|
|
||||||
from enhancer.models.model import Model
|
|
||||||
from enhancer.models.waveunet import WaveUnet
|
|
||||||
|
|
@ -1,5 +0,0 @@
|
||||||
from enhancer.models.complexnn.conv import ComplexConv2d # noqa
|
|
||||||
from enhancer.models.complexnn.conv import ComplexConvTranspose2d # noqa
|
|
||||||
from enhancer.models.complexnn.rnn import ComplexLSTM # noqa
|
|
||||||
from enhancer.models.complexnn.utils import ComplexBatchNorm2D # noqa
|
|
||||||
from enhancer.models.complexnn.utils import ComplexRelu # noqa
|
|
||||||
|
|
@ -1,3 +0,0 @@
|
||||||
from enhancer.utils.config import Files
|
|
||||||
from enhancer.utils.io import Audio
|
|
||||||
from enhancer.utils.utils import check_files
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
name: enhancer
|
name: mayavoz
|
||||||
|
|
||||||
dependencies:
|
dependencies:
|
||||||
- pip=21.0.1
|
- pip=21.0.1
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
_target_: enhancer.data.dataset.EnhancerDataset
|
_target_: mayavoz.data.dataset.MayaDataset
|
||||||
root_dir : /Users/shahules/Myprojects/MS-SNSD
|
root_dir : /Users/shahules/Myprojects/MS-SNSD
|
||||||
name : dns-2020
|
name : dns-2020
|
||||||
duration : 2.0
|
duration : 2.0
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
_target_: enhancer.data.dataset.EnhancerDataset
|
_target_: mayavoz.data.dataset.MayaDataset
|
||||||
name : vctk
|
name : vctk
|
||||||
root_dir : /scratch/c.sistc3/DS_10283_2791
|
root_dir : /scratch/c.sistc3/DS_10283_2791
|
||||||
duration : 4.5
|
duration : 4.5
|
||||||
|
|
@ -1,2 +1,2 @@
|
||||||
experiment_name : shahules/enhancer
|
experiment_name : shahules/mayavoz
|
||||||
run_name : Demucs + Vtck with stride + augmentations
|
run_name : Demucs + Vtck with stride + augmentations
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
_target_: enhancer.models.dccrn.DCCRN
|
_target_: mayavoz.models.dccrn.DCCRN
|
||||||
num_channels: 1
|
num_channels: 1
|
||||||
sampling_rate : 16000
|
sampling_rate : 16000
|
||||||
complex_lstm : True
|
complex_lstm : True
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
_target_: enhancer.models.demucs.Demucs
|
_target_: mayavoz.models.demucs.Demucs
|
||||||
num_channels: 1
|
num_channels: 1
|
||||||
resample: 4
|
resample: 4
|
||||||
sampling_rate : 16000
|
sampling_rate : 16000
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
_target_: enhancer.models.waveunet.WaveUnet
|
_target_: mayavoz.models.waveunet.WaveUnet
|
||||||
num_channels : 1
|
num_channels : 1
|
||||||
depth : 9
|
depth : 9
|
||||||
initial_output_channels: 24
|
initial_output_channels: 24
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
from mayavoz.data.dataset import MayaDataset
|
||||||
|
|
@ -11,11 +11,11 @@ import torch.nn.functional as F
|
||||||
from torch.utils.data import DataLoader, Dataset, RandomSampler
|
from torch.utils.data import DataLoader, Dataset, RandomSampler
|
||||||
from torch_audiomentations import Compose
|
from torch_audiomentations import Compose
|
||||||
|
|
||||||
from enhancer.data.fileprocessor import Fileprocessor
|
from mayavoz.data.fileprocessor import Fileprocessor
|
||||||
from enhancer.utils import check_files
|
from mayavoz.utils import check_files
|
||||||
from enhancer.utils.config import Files
|
from mayavoz.utils.config import Files
|
||||||
from enhancer.utils.io import Audio
|
from mayavoz.utils.io import Audio
|
||||||
from enhancer.utils.random import create_unique_rng
|
from mayavoz.utils.random import create_unique_rng
|
||||||
|
|
||||||
LARGE_NUM = 2147483647
|
LARGE_NUM = 2147483647
|
||||||
|
|
||||||
|
|
@ -248,7 +248,7 @@ class TaskDataset(pl.LightningDataModule):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class EnhancerDataset(TaskDataset):
|
class MayaDataset(TaskDataset):
|
||||||
"""
|
"""
|
||||||
Dataset object for creating clean-noisy speech enhancement datasets
|
Dataset object for creating clean-noisy speech enhancement datasets
|
||||||
paramters:
|
paramters:
|
||||||
|
|
@ -258,7 +258,7 @@ class EnhancerDataset(TaskDataset):
|
||||||
root directory of the dataset containing clean/noisy folders
|
root directory of the dataset containing clean/noisy folders
|
||||||
files : Files
|
files : Files
|
||||||
dataclass containing train_clean, train_noisy, test_clean, test_noisy
|
dataclass containing train_clean, train_noisy, test_clean, test_noisy
|
||||||
folder names (refer enhancer.utils.Files dataclass)
|
folder names (refer mayavoz.utils.Files dataclass)
|
||||||
min_valid_minutes: float
|
min_valid_minutes: float
|
||||||
minimum validation split size time in minutes
|
minimum validation split size time in minutes
|
||||||
algorithm randomly select n speakers (>=min_valid_minutes) from train data to form validation data.
|
algorithm randomly select n speakers (>=min_valid_minutes) from train data to form validation data.
|
||||||
|
|
@ -8,7 +8,7 @@ from librosa import load as load_audio
|
||||||
from scipy.io import wavfile
|
from scipy.io import wavfile
|
||||||
from scipy.signal import get_window
|
from scipy.signal import get_window
|
||||||
|
|
||||||
from enhancer.utils import Audio
|
from mayavoz.utils import Audio
|
||||||
|
|
||||||
|
|
||||||
class Inference:
|
class Inference:
|
||||||
|
|
@ -192,7 +192,7 @@ class Si_snr(nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.loss_fun = ScaleInvariantSignalNoiseRatio(**kwargs)
|
self.loss_fun = ScaleInvariantSignalNoiseRatio(**kwargs)
|
||||||
self.higher_better = True
|
self.higher_better = False
|
||||||
self.name = "si_snr"
|
self.name = "si_snr"
|
||||||
|
|
||||||
def forward(self, prediction: torch.Tensor, target: torch.Tensor):
|
def forward(self, prediction: torch.Tensor, target: torch.Tensor):
|
||||||
|
|
@ -203,7 +203,7 @@ class Si_snr(nn.Module):
|
||||||
got {prediction.size()} and {target.size()} instead"""
|
got {prediction.size()} and {target.size()} instead"""
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.loss_fun(prediction, target)
|
return -1 * self.loss_fun(prediction, target)
|
||||||
|
|
||||||
|
|
||||||
LOSS_MAP = {
|
LOSS_MAP = {
|
||||||
|
|
@ -0,0 +1,3 @@
|
||||||
|
from mayavoz.models.demucs import Demucs
|
||||||
|
from mayavoz.models.model import Mayamodel
|
||||||
|
from mayavoz.models.waveunet import WaveUnet
|
||||||
|
|
@ -0,0 +1,5 @@
|
||||||
|
from mayavoz.models.complexnn.conv import ComplexConv2d # noqa
|
||||||
|
from mayavoz.models.complexnn.conv import ComplexConvTranspose2d # noqa
|
||||||
|
from mayavoz.models.complexnn.rnn import ComplexLSTM # noqa
|
||||||
|
from mayavoz.models.complexnn.utils import ComplexBatchNorm2D # noqa
|
||||||
|
from mayavoz.models.complexnn.utils import ComplexRelu # noqa
|
||||||
|
|
@ -5,18 +5,18 @@ import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from enhancer.data import EnhancerDataset
|
from mayavoz.data import MayaDataset
|
||||||
from enhancer.models import Model
|
from mayavoz.models import Mayamodel
|
||||||
from enhancer.models.complexnn import (
|
from mayavoz.models.complexnn import (
|
||||||
ComplexBatchNorm2D,
|
ComplexBatchNorm2D,
|
||||||
ComplexConv2d,
|
ComplexConv2d,
|
||||||
ComplexConvTranspose2d,
|
ComplexConvTranspose2d,
|
||||||
ComplexLSTM,
|
ComplexLSTM,
|
||||||
ComplexRelu,
|
ComplexRelu,
|
||||||
)
|
)
|
||||||
from enhancer.models.complexnn.utils import complex_cat
|
from mayavoz.models.complexnn.utils import complex_cat
|
||||||
from enhancer.utils.transforms import ConviSTFT, ConvSTFT
|
from mayavoz.utils.transforms import ConviSTFT, ConvSTFT
|
||||||
from enhancer.utils.utils import merge_dict
|
from mayavoz.utils.utils import merge_dict
|
||||||
|
|
||||||
|
|
||||||
class DCCRN_ENCODER(nn.Module):
|
class DCCRN_ENCODER(nn.Module):
|
||||||
|
|
@ -98,7 +98,7 @@ class DCCRN_DECODER(nn.Module):
|
||||||
return self.decoder(waveform)
|
return self.decoder(waveform)
|
||||||
|
|
||||||
|
|
||||||
class DCCRN(Model):
|
class DCCRN(Mayamodel):
|
||||||
|
|
||||||
STFT_DEFAULTS = {
|
STFT_DEFAULTS = {
|
||||||
"window_len": 400,
|
"window_len": 400,
|
||||||
|
|
@ -134,13 +134,13 @@ class DCCRN(Model):
|
||||||
num_channels: int = 1,
|
num_channels: int = 1,
|
||||||
sampling_rate=16000,
|
sampling_rate=16000,
|
||||||
lr: float = 1e-3,
|
lr: float = 1e-3,
|
||||||
dataset: Optional[EnhancerDataset] = None,
|
dataset: Optional[MayaDataset] = None,
|
||||||
duration: Optional[float] = None,
|
duration: Optional[float] = None,
|
||||||
loss: Union[str, List, Any] = "mse",
|
loss: Union[str, List, Any] = "mse",
|
||||||
metric: Union[str, List] = "mse",
|
metric: Union[str, List] = "mse",
|
||||||
):
|
):
|
||||||
duration = (
|
duration = (
|
||||||
dataset.duration if isinstance(dataset, EnhancerDataset) else None
|
dataset.duration if isinstance(dataset, MayaDataset) else None
|
||||||
)
|
)
|
||||||
if dataset is not None:
|
if dataset is not None:
|
||||||
if sampling_rate != dataset.sampling_rate:
|
if sampling_rate != dataset.sampling_rate:
|
||||||
|
|
@ -5,10 +5,10 @@ from typing import List, Optional, Union
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from enhancer.data.dataset import EnhancerDataset
|
from mayavoz.data.dataset import MayaDataset
|
||||||
from enhancer.models.model import Model
|
from mayavoz.models.model import Mayamodel
|
||||||
from enhancer.utils.io import Audio as audio
|
from mayavoz.utils.io import Audio as audio
|
||||||
from enhancer.utils.utils import merge_dict
|
from mayavoz.utils.utils import merge_dict
|
||||||
|
|
||||||
|
|
||||||
class DemucsLSTM(nn.Module):
|
class DemucsLSTM(nn.Module):
|
||||||
|
|
@ -88,7 +88,7 @@ class DemucsDecoder(nn.Module):
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
class Demucs(Model):
|
class Demucs(Mayamodel):
|
||||||
"""
|
"""
|
||||||
Demucs model from https://arxiv.org/pdf/1911.13254.pdf
|
Demucs model from https://arxiv.org/pdf/1911.13254.pdf
|
||||||
parameters:
|
parameters:
|
||||||
|
|
@ -102,8 +102,8 @@ class Demucs(Model):
|
||||||
sampling rate of input audio
|
sampling rate of input audio
|
||||||
lr : float, defaults to 1e-3
|
lr : float, defaults to 1e-3
|
||||||
learning rate used for training
|
learning rate used for training
|
||||||
dataset: EnhancerDataset, optional
|
dataset: MayaDataset, optional
|
||||||
EnhancerDataset object containing train/validation data for training
|
MayaDataset object containing train/validation data for training
|
||||||
duration : float, optional
|
duration : float, optional
|
||||||
chunk duration in seconds
|
chunk duration in seconds
|
||||||
loss : string or List of strings
|
loss : string or List of strings
|
||||||
|
|
@ -135,13 +135,13 @@ class Demucs(Model):
|
||||||
sampling_rate=16000,
|
sampling_rate=16000,
|
||||||
normalize=True,
|
normalize=True,
|
||||||
lr: float = 1e-3,
|
lr: float = 1e-3,
|
||||||
dataset: Optional[EnhancerDataset] = None,
|
dataset: Optional[MayaDataset] = None,
|
||||||
loss: Union[str, List] = "mse",
|
loss: Union[str, List] = "mse",
|
||||||
metric: Union[str, List] = "mse",
|
metric: Union[str, List] = "mse",
|
||||||
floor=1e-3,
|
floor=1e-3,
|
||||||
):
|
):
|
||||||
duration = (
|
duration = (
|
||||||
dataset.duration if isinstance(dataset, EnhancerDataset) else None
|
dataset.duration if isinstance(dataset, MayaDataset) else None
|
||||||
)
|
)
|
||||||
if dataset is not None:
|
if dataset is not None:
|
||||||
if sampling_rate != dataset.sampling_rate:
|
if sampling_rate != dataset.sampling_rate:
|
||||||
|
|
@ -13,20 +13,21 @@ from pytorch_lightning.utilities.cloud_io import load as pl_load
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.optim import Adam
|
from torch.optim import Adam
|
||||||
|
|
||||||
from enhancer.data.dataset import EnhancerDataset
|
from mayavoz.data.dataset import MayaDataset
|
||||||
from enhancer.inference import Inference
|
from mayavoz.inference import Inference
|
||||||
from enhancer.loss import LOSS_MAP, LossWrapper
|
from mayavoz.loss import LOSS_MAP, LossWrapper
|
||||||
from enhancer.version import __version__
|
from mayavoz.version import __version__
|
||||||
|
|
||||||
CACHE_DIR = os.getenv(
|
CACHE_DIR = os.getenv(
|
||||||
"ENHANCER_CACHE",
|
"ENHANCER_CACHE",
|
||||||
os.path.expanduser("~/.cache/torch/enhancer"),
|
os.path.expanduser("~/.cache/torch/mayavoz"),
|
||||||
)
|
)
|
||||||
HF_TORCH_WEIGHTS = "pytorch_model.ckpt"
|
HF_TORCH_WEIGHTS = "pytorch_model.ckpt"
|
||||||
DEFAULT_DEVICE = "cpu"
|
DEFAULT_DEVICE = "cpu"
|
||||||
|
SAVE_NAME = "enhancer"
|
||||||
|
|
||||||
|
|
||||||
class Model(pl.LightningModule):
|
class Mayamodel(pl.LightningModule):
|
||||||
"""
|
"""
|
||||||
Base class for all models
|
Base class for all models
|
||||||
parameters:
|
parameters:
|
||||||
|
|
@ -36,8 +37,8 @@ class Model(pl.LightningModule):
|
||||||
audio sampling rate
|
audio sampling rate
|
||||||
lr: float, optional
|
lr: float, optional
|
||||||
learning rate for model training
|
learning rate for model training
|
||||||
dataset: EnhancerDataset, optional
|
dataset: MayaDataset, optional
|
||||||
Enhancer dataset used for training/validation
|
mayavoz dataset used for training/validation
|
||||||
duration: float, optional
|
duration: float, optional
|
||||||
duration used for training/inference
|
duration used for training/inference
|
||||||
loss : string or List of strings or custom loss (nn.Module), default to "mse"
|
loss : string or List of strings or custom loss (nn.Module), default to "mse"
|
||||||
|
|
@ -50,15 +51,13 @@ class Model(pl.LightningModule):
|
||||||
num_channels: int = 1,
|
num_channels: int = 1,
|
||||||
sampling_rate: int = 16000,
|
sampling_rate: int = 16000,
|
||||||
lr: float = 1e-3,
|
lr: float = 1e-3,
|
||||||
dataset: Optional[EnhancerDataset] = None,
|
dataset: Optional[MayaDataset] = None,
|
||||||
duration: Optional[float] = None,
|
duration: Optional[float] = None,
|
||||||
loss: Union[str, List] = "mse",
|
loss: Union[str, List] = "mse",
|
||||||
metric: Union[str, List, Any] = "mse",
|
metric: Union[str, List, Any] = "mse",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert (
|
assert num_channels == 1, "mayavoz only support for mono channel models"
|
||||||
num_channels == 1
|
|
||||||
), "Enhancer only support for mono channel models"
|
|
||||||
self.dataset = dataset
|
self.dataset = dataset
|
||||||
self.save_hyperparameters(
|
self.save_hyperparameters(
|
||||||
"num_channels", "sampling_rate", "lr", "loss", "metric", "duration"
|
"num_channels", "sampling_rate", "lr", "loss", "metric", "duration"
|
||||||
|
|
@ -235,8 +234,8 @@ class Model(pl.LightningModule):
|
||||||
|
|
||||||
def on_save_checkpoint(self, checkpoint):
|
def on_save_checkpoint(self, checkpoint):
|
||||||
|
|
||||||
checkpoint["enhancer"] = {
|
checkpoint[SAVE_NAME] = {
|
||||||
"version": {"enhancer": __version__, "pytorch": torch.__version__},
|
"version": {SAVE_NAME: __version__, "pytorch": torch.__version__},
|
||||||
"architecture": {
|
"architecture": {
|
||||||
"module": self.__class__.__module__,
|
"module": self.__class__.__module__,
|
||||||
"class": self.__class__.__name__,
|
"class": self.__class__.__name__,
|
||||||
|
|
@ -289,8 +288,8 @@ class Model(pl.LightningModule):
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
model : Model
|
model : Mayamodel
|
||||||
Model
|
Mayamodel
|
||||||
|
|
||||||
See also
|
See also
|
||||||
--------
|
--------
|
||||||
|
|
@ -319,7 +318,7 @@ class Model(pl.LightningModule):
|
||||||
)
|
)
|
||||||
model_path_pl = cached_download(
|
model_path_pl = cached_download(
|
||||||
url=url,
|
url=url,
|
||||||
library_name="enhancer",
|
library_name="mayavoz",
|
||||||
library_version=__version__,
|
library_version=__version__,
|
||||||
cache_dir=cached_dir,
|
cache_dir=cached_dir,
|
||||||
use_auth_token=use_auth_token,
|
use_auth_token=use_auth_token,
|
||||||
|
|
@ -329,8 +328,8 @@ class Model(pl.LightningModule):
|
||||||
map_location = torch.device(DEFAULT_DEVICE)
|
map_location = torch.device(DEFAULT_DEVICE)
|
||||||
|
|
||||||
loaded_checkpoint = pl_load(model_path_pl, map_location)
|
loaded_checkpoint = pl_load(model_path_pl, map_location)
|
||||||
module_name = loaded_checkpoint["enhancer"]["architecture"]["module"]
|
module_name = loaded_checkpoint[SAVE_NAME]["architecture"]["module"]
|
||||||
class_name = loaded_checkpoint["enhancer"]["architecture"]["class"]
|
class_name = loaded_checkpoint[SAVE_NAME]["architecture"]["class"]
|
||||||
module = import_module(module_name)
|
module = import_module(module_name)
|
||||||
Klass = getattr(module, class_name)
|
Klass = getattr(module, class_name)
|
||||||
|
|
||||||
|
|
@ -5,8 +5,8 @@ import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from enhancer.data.dataset import EnhancerDataset
|
from mayavoz.data.dataset import MayaDataset
|
||||||
from enhancer.models.model import Model
|
from mayavoz.models.model import Mayamodel
|
||||||
|
|
||||||
|
|
||||||
class WavenetDecoder(nn.Module):
|
class WavenetDecoder(nn.Module):
|
||||||
|
|
@ -66,7 +66,7 @@ class WavenetEncoder(nn.Module):
|
||||||
return self.encoder(waveform)
|
return self.encoder(waveform)
|
||||||
|
|
||||||
|
|
||||||
class WaveUnet(Model):
|
class WaveUnet(Mayamodel):
|
||||||
"""
|
"""
|
||||||
Wave-U-Net model from https://arxiv.org/pdf/1811.11307.pdf
|
Wave-U-Net model from https://arxiv.org/pdf/1811.11307.pdf
|
||||||
parameters:
|
parameters:
|
||||||
|
|
@ -80,8 +80,8 @@ class WaveUnet(Model):
|
||||||
sampling rate of input audio
|
sampling rate of input audio
|
||||||
lr : float, defaults to 1e-3
|
lr : float, defaults to 1e-3
|
||||||
learning rate used for training
|
learning rate used for training
|
||||||
dataset: EnhancerDataset, optional
|
dataset: MayaDataset, optional
|
||||||
EnhancerDataset object containing train/validation data for training
|
MayaDataset object containing train/validation data for training
|
||||||
duration : float, optional
|
duration : float, optional
|
||||||
chunk duration in seconds
|
chunk duration in seconds
|
||||||
loss : string or List of strings
|
loss : string or List of strings
|
||||||
|
|
@ -97,13 +97,13 @@ class WaveUnet(Model):
|
||||||
initial_output_channels: int = 24,
|
initial_output_channels: int = 24,
|
||||||
sampling_rate: int = 16000,
|
sampling_rate: int = 16000,
|
||||||
lr: float = 1e-3,
|
lr: float = 1e-3,
|
||||||
dataset: Optional[EnhancerDataset] = None,
|
dataset: Optional[MayaDataset] = None,
|
||||||
duration: Optional[float] = None,
|
duration: Optional[float] = None,
|
||||||
loss: Union[str, List] = "mse",
|
loss: Union[str, List] = "mse",
|
||||||
metric: Union[str, List] = "mse",
|
metric: Union[str, List] = "mse",
|
||||||
):
|
):
|
||||||
duration = (
|
duration = (
|
||||||
dataset.duration if isinstance(dataset, EnhancerDataset) else None
|
dataset.duration if isinstance(dataset, MayaDataset) else None
|
||||||
)
|
)
|
||||||
if dataset is not None:
|
if dataset is not None:
|
||||||
if sampling_rate != dataset.sampling_rate:
|
if sampling_rate != dataset.sampling_rate:
|
||||||
|
|
@ -0,0 +1,3 @@
|
||||||
|
from mayavoz.utils.config import Files
|
||||||
|
from mayavoz.utils.io import Audio
|
||||||
|
from mayavoz.utils.utils import check_files
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
import os
|
import os
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from enhancer.utils.config import Files
|
from mayavoz.utils.config import Files
|
||||||
|
|
||||||
|
|
||||||
def check_files(root_dir: str, files: Files):
|
def check_files(root_dir: str, files: Files):
|
||||||
|
|
@ -374,7 +374,7 @@
|
||||||
"```\n",
|
"```\n",
|
||||||
"\n",
|
"\n",
|
||||||
"```yaml\n",
|
"```yaml\n",
|
||||||
"_target_: enhancer.models.demucs.Demucs\n",
|
"_target_: mayavoz.models.demucs.Demucs\n",
|
||||||
"num_channels: 1\n",
|
"num_channels: 1\n",
|
||||||
"resample: 4\n",
|
"resample: 4\n",
|
||||||
"sampling_rate : 16000\n",
|
"sampling_rate : 16000\n",
|
||||||
|
|
@ -405,9 +405,9 @@
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"kernelspec": {
|
"kernelspec": {
|
||||||
"display_name": "enhancer",
|
"display_name": "mayavoz",
|
||||||
"language": "python",
|
"language": "python",
|
||||||
"name": "enhancer"
|
"name": "mayavoz"
|
||||||
},
|
},
|
||||||
"language_info": {
|
"language_info": {
|
||||||
"codemirror_mode": {
|
"codemirror_mode": {
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,120 @@
|
||||||
|
import os
|
||||||
|
from types import MethodType
|
||||||
|
|
||||||
|
import hydra
|
||||||
|
from hydra.utils import instantiate
|
||||||
|
from omegaconf import DictConfig, OmegaConf
|
||||||
|
from pytorch_lightning.callbacks import (
|
||||||
|
EarlyStopping,
|
||||||
|
LearningRateMonitor,
|
||||||
|
ModelCheckpoint,
|
||||||
|
)
|
||||||
|
from pytorch_lightning.loggers import MLFlowLogger
|
||||||
|
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
||||||
|
|
||||||
|
# from torch_audiomentations import Compose, Shift
|
||||||
|
|
||||||
|
os.environ["HYDRA_FULL_ERROR"] = "1"
|
||||||
|
JOB_ID = os.environ.get("SLURM_JOBID", "0")
|
||||||
|
|
||||||
|
|
||||||
|
@hydra.main(config_path="train_config", config_name="config")
|
||||||
|
def main(config: DictConfig):
|
||||||
|
|
||||||
|
OmegaConf.save(config, "config_log.yaml")
|
||||||
|
|
||||||
|
callbacks = []
|
||||||
|
logger = MLFlowLogger(
|
||||||
|
experiment_name=config.mlflow.experiment_name,
|
||||||
|
run_name=config.mlflow.run_name,
|
||||||
|
tags={"JOB_ID": JOB_ID},
|
||||||
|
)
|
||||||
|
|
||||||
|
parameters = config.hyperparameters
|
||||||
|
# apply_augmentations = Compose(
|
||||||
|
# [
|
||||||
|
# Shift(min_shift=0.5, max_shift=1.0, shift_unit="seconds", p=0.5),
|
||||||
|
# ]
|
||||||
|
# )
|
||||||
|
|
||||||
|
dataset = instantiate(config.dataset, augmentations=None)
|
||||||
|
model = instantiate(
|
||||||
|
config.model,
|
||||||
|
dataset=dataset,
|
||||||
|
lr=parameters.get("lr"),
|
||||||
|
loss=parameters.get("loss"),
|
||||||
|
metric=parameters.get("metric"),
|
||||||
|
)
|
||||||
|
|
||||||
|
direction = model.valid_monitor
|
||||||
|
checkpoint = ModelCheckpoint(
|
||||||
|
dirpath="./model",
|
||||||
|
filename=f"model_{JOB_ID}",
|
||||||
|
monitor="valid_loss",
|
||||||
|
verbose=False,
|
||||||
|
mode=direction,
|
||||||
|
every_n_epochs=1,
|
||||||
|
)
|
||||||
|
callbacks.append(checkpoint)
|
||||||
|
callbacks.append(LearningRateMonitor(logging_interval="epoch"))
|
||||||
|
|
||||||
|
if parameters.get("Early_stop", False):
|
||||||
|
early_stopping = EarlyStopping(
|
||||||
|
monitor="val_loss",
|
||||||
|
mode=direction,
|
||||||
|
min_delta=0.0,
|
||||||
|
patience=parameters.get("EarlyStopping_patience", 10),
|
||||||
|
strict=True,
|
||||||
|
verbose=False,
|
||||||
|
)
|
||||||
|
callbacks.append(early_stopping)
|
||||||
|
|
||||||
|
def configure_optimizers(self):
|
||||||
|
optimizer = instantiate(
|
||||||
|
config.optimizer,
|
||||||
|
lr=parameters.get("lr"),
|
||||||
|
params=self.parameters(),
|
||||||
|
)
|
||||||
|
scheduler = ReduceLROnPlateau(
|
||||||
|
optimizer=optimizer,
|
||||||
|
mode=direction,
|
||||||
|
factor=parameters.get("ReduceLr_factor", 0.1),
|
||||||
|
verbose=True,
|
||||||
|
min_lr=parameters.get("min_lr", 1e-6),
|
||||||
|
patience=parameters.get("ReduceLr_patience", 3),
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"optimizer": optimizer,
|
||||||
|
"lr_scheduler": scheduler,
|
||||||
|
"monitor": f'valid_{parameters.get("ReduceLr_monitor", "loss")}',
|
||||||
|
}
|
||||||
|
|
||||||
|
model.configure_optimizers = MethodType(configure_optimizers, model)
|
||||||
|
|
||||||
|
trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks)
|
||||||
|
trainer.fit(model)
|
||||||
|
trainer.test(model)
|
||||||
|
|
||||||
|
logger.experiment.log_artifact(
|
||||||
|
logger.run_id, f"{trainer.default_root_dir}/config_log.yaml"
|
||||||
|
)
|
||||||
|
|
||||||
|
saved_location = os.path.join(
|
||||||
|
trainer.default_root_dir, "model", f"model_{JOB_ID}.ckpt"
|
||||||
|
)
|
||||||
|
if os.path.isfile(saved_location):
|
||||||
|
logger.experiment.log_artifact(logger.run_id, saved_location)
|
||||||
|
logger.experiment.log_param(
|
||||||
|
logger.run_id,
|
||||||
|
"num_train_steps_per_epoch",
|
||||||
|
dataset.train__len__() / dataset.batch_size,
|
||||||
|
)
|
||||||
|
logger.experiment.log_param(
|
||||||
|
logger.run_id,
|
||||||
|
"num_valid_steps_per_epoch",
|
||||||
|
dataset.val__len__() / dataset.batch_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
@ -0,0 +1,7 @@
|
||||||
|
defaults:
|
||||||
|
- model : Demucs
|
||||||
|
- dataset : Vctk
|
||||||
|
- optimizer : Adam
|
||||||
|
- hyperparameters : default
|
||||||
|
- trainer : default
|
||||||
|
- mlflow : experiment
|
||||||
|
|
@ -0,0 +1,12 @@
|
||||||
|
_target_: mayavoz.data.dataset.MayaDataset
|
||||||
|
root_dir : /Users/shahules/Myprojects/MS-SNSD
|
||||||
|
name : dns-2020
|
||||||
|
duration : 2.0
|
||||||
|
sampling_rate: 16000
|
||||||
|
batch_size: 32
|
||||||
|
valid_size: 0.05
|
||||||
|
files:
|
||||||
|
train_clean : CleanSpeech_training
|
||||||
|
test_clean : CleanSpeech_training
|
||||||
|
train_noisy : NoisySpeech_training
|
||||||
|
test_noisy : NoisySpeech_training
|
||||||
|
|
@ -0,0 +1,13 @@
|
||||||
|
_target_: mayavoz.data.dataset.MayaDataset
|
||||||
|
name : vctk
|
||||||
|
root_dir : /scratch/c.sistc3/DS_10283_2791
|
||||||
|
duration : 4.5
|
||||||
|
stride : 2
|
||||||
|
sampling_rate: 16000
|
||||||
|
batch_size: 32
|
||||||
|
valid_minutes : 15
|
||||||
|
files:
|
||||||
|
train_clean : clean_trainset_28spk_wav
|
||||||
|
test_clean : clean_testset_wav
|
||||||
|
train_noisy : noisy_trainset_28spk_wav
|
||||||
|
test_noisy : noisy_testset_wav
|
||||||
|
|
@ -0,0 +1,7 @@
|
||||||
|
loss : mae
|
||||||
|
metric : [stoi,pesq,si-sdr]
|
||||||
|
lr : 0.0003
|
||||||
|
ReduceLr_patience : 5
|
||||||
|
ReduceLr_factor : 0.2
|
||||||
|
min_lr : 0.000001
|
||||||
|
EarlyStopping_factor : 10
|
||||||
|
|
@ -0,0 +1,2 @@
|
||||||
|
experiment_name : shahules/mayavoz
|
||||||
|
run_name : Demucs + Vtck with stride + augmentations
|
||||||
|
|
@ -0,0 +1,25 @@
|
||||||
|
_target_: mayavoz.models.dccrn.DCCRN
|
||||||
|
num_channels: 1
|
||||||
|
sampling_rate : 16000
|
||||||
|
complex_lstm : True
|
||||||
|
complex_norm : True
|
||||||
|
complex_relu : True
|
||||||
|
masking_mode : True
|
||||||
|
|
||||||
|
encoder_decoder:
|
||||||
|
initial_output_channels : 32
|
||||||
|
depth : 6
|
||||||
|
kernel_size : 5
|
||||||
|
growth_factor : 2
|
||||||
|
stride : 2
|
||||||
|
padding : 2
|
||||||
|
output_padding : 1
|
||||||
|
|
||||||
|
lstm:
|
||||||
|
num_layers : 2
|
||||||
|
hidden_size : 256
|
||||||
|
|
||||||
|
stft:
|
||||||
|
window_len : 400
|
||||||
|
hop_size : 100
|
||||||
|
nfft : 512
|
||||||
|
|
@ -0,0 +1,16 @@
|
||||||
|
_target_: mayavoz.models.demucs.Demucs
|
||||||
|
num_channels: 1
|
||||||
|
resample: 4
|
||||||
|
sampling_rate : 16000
|
||||||
|
|
||||||
|
encoder_decoder:
|
||||||
|
depth: 4
|
||||||
|
initial_output_channels: 64
|
||||||
|
kernel_size: 8
|
||||||
|
stride: 4
|
||||||
|
growth_factor: 2
|
||||||
|
glu: True
|
||||||
|
|
||||||
|
lstm:
|
||||||
|
bidirectional: False
|
||||||
|
num_layers: 2
|
||||||
|
|
@ -0,0 +1,5 @@
|
||||||
|
_target_: mayavoz.models.waveunet.WaveUnet
|
||||||
|
num_channels : 1
|
||||||
|
depth : 9
|
||||||
|
initial_output_channels: 24
|
||||||
|
sampling_rate : 16000
|
||||||
|
|
@ -0,0 +1,6 @@
|
||||||
|
_target_: torch.optim.Adam
|
||||||
|
lr: 1e-3
|
||||||
|
betas: [0.9, 0.999]
|
||||||
|
eps: 1e-08
|
||||||
|
weight_decay: 0
|
||||||
|
amsgrad: False
|
||||||
|
|
@ -0,0 +1,46 @@
|
||||||
|
_target_: pytorch_lightning.Trainer
|
||||||
|
accelerator: gpu
|
||||||
|
accumulate_grad_batches: 1
|
||||||
|
amp_backend: native
|
||||||
|
auto_lr_find: True
|
||||||
|
auto_scale_batch_size: False
|
||||||
|
auto_select_gpus: True
|
||||||
|
benchmark: False
|
||||||
|
check_val_every_n_epoch: 1
|
||||||
|
detect_anomaly: False
|
||||||
|
deterministic: False
|
||||||
|
devices: 2
|
||||||
|
enable_checkpointing: True
|
||||||
|
enable_model_summary: True
|
||||||
|
enable_progress_bar: True
|
||||||
|
fast_dev_run: False
|
||||||
|
gpus: null
|
||||||
|
gradient_clip_val: 0
|
||||||
|
gradient_clip_algorithm: norm
|
||||||
|
ipus: null
|
||||||
|
limit_predict_batches: 1.0
|
||||||
|
limit_test_batches: 1.0
|
||||||
|
limit_train_batches: 1.0
|
||||||
|
limit_val_batches: 1.0
|
||||||
|
log_every_n_steps: 50
|
||||||
|
max_epochs: 200
|
||||||
|
max_steps: -1
|
||||||
|
max_time: null
|
||||||
|
min_epochs: 1
|
||||||
|
min_steps: null
|
||||||
|
move_metrics_to_cpu: False
|
||||||
|
multiple_trainloader_mode: max_size_cycle
|
||||||
|
num_nodes: 1
|
||||||
|
num_processes: 1
|
||||||
|
num_sanity_val_steps: 2
|
||||||
|
overfit_batches: 0.0
|
||||||
|
precision: 32
|
||||||
|
profiler: null
|
||||||
|
reload_dataloaders_every_n_epochs: 0
|
||||||
|
replace_sampler_ddp: True
|
||||||
|
strategy: ddp
|
||||||
|
sync_batchnorm: False
|
||||||
|
tpu_cores: null
|
||||||
|
track_grad_norm: -1
|
||||||
|
val_check_interval: 1.0
|
||||||
|
weights_save_path: null
|
||||||
|
|
@ -0,0 +1,2 @@
|
||||||
|
_target_: pytorch_lightning.Trainer
|
||||||
|
fast_dev_run: True
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
_target_: enhancer.data.dataset.EnhancerDataset
|
_target_: mayavoz.data.dataset.MayaDataset
|
||||||
name : vctk
|
name : vctk
|
||||||
root_dir : /scratch/c.sistc3/DS_10283_2791
|
root_dir : /scratch/c.sistc3/DS_10283_2791
|
||||||
duration : 4.5
|
duration : 4.5
|
||||||
|
|
|
||||||
|
|
@ -1,2 +1,2 @@
|
||||||
experiment_name : shahules/enhancer
|
experiment_name : shahules/mayavoz
|
||||||
run_name : baseline
|
run_name : baseline
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
_target_: enhancer.models.demucs.Demucs
|
_target_: mayavoz.models.demucs.Demucs
|
||||||
num_channels: 1
|
num_channels: 1
|
||||||
resample: 4
|
resample: 4
|
||||||
sampling_rate : 16000
|
sampling_rate : 16000
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
_target_: enhancer.data.dataset.EnhancerDataset
|
_target_: mayavoz.data.dataset.MayaDataset
|
||||||
name : vctk
|
name : vctk
|
||||||
root_dir : /scratch/c.sistc3/DS_10283_2791
|
root_dir : /scratch/c.sistc3/DS_10283_2791
|
||||||
duration : 2
|
duration : 2
|
||||||
|
|
|
||||||
|
|
@ -1,2 +1,2 @@
|
||||||
experiment_name : shahules/enhancer
|
experiment_name : shahules/mayavoz
|
||||||
run_name : baseline
|
run_name : baseline
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
_target_: enhancer.models.waveunet.WaveUnet
|
_target_: mayavoz.models.waveunet.WaveUnet
|
||||||
num_channels : 1
|
num_channels : 1
|
||||||
depth : 9
|
depth : 9
|
||||||
initial_output_channels: 24
|
initial_output_channels: 24
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
_target_: enhancer.data.dataset.EnhancerDataset
|
_target_: mayavoz.data.dataset.MayaDataset
|
||||||
root_dir : /Users/shahules/Myprojects/MS-SNSD
|
root_dir : /Users/shahules/Myprojects/MS-SNSD
|
||||||
name : dns-2020
|
name : dns-2020
|
||||||
duration : 2.0
|
duration : 2.0
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
_target_: enhancer.data.dataset.EnhancerDataset
|
_target_: mayavoz.data.dataset.MayaDataset
|
||||||
name : vctk
|
name : vctk
|
||||||
root_dir : /scratch/c.sistc3/DS_10283_2791
|
root_dir : /scratch/c.sistc3/DS_10283_2791
|
||||||
duration : 4.5
|
duration : 4.5
|
||||||
|
|
|
||||||
|
|
@ -1,2 +1,2 @@
|
||||||
experiment_name : shahules/enhancer
|
experiment_name : shahules/mayavoz
|
||||||
run_name : Demucs + Vtck with stride + augmentations
|
run_name : Demucs + Vtck with stride + augmentations
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
_target_: enhancer.models.dccrn.DCCRN
|
_target_: mayavoz.models.dccrn.DCCRN
|
||||||
num_channels: 1
|
num_channels: 1
|
||||||
sampling_rate : 16000
|
sampling_rate : 16000
|
||||||
complex_lstm : True
|
complex_lstm : True
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
_target_: enhancer.models.demucs.Demucs
|
_target_: mayavoz.models.demucs.Demucs
|
||||||
num_channels: 1
|
num_channels: 1
|
||||||
resample: 4
|
resample: 4
|
||||||
sampling_rate : 16000
|
sampling_rate : 16000
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
_target_: enhancer.models.waveunet.WaveUnet
|
_target_: mayavoz.models.waveunet.WaveUnet
|
||||||
num_channels : 1
|
num_channels : 1
|
||||||
depth : 9
|
depth : 9
|
||||||
initial_output_channels: 24
|
initial_output_channels: 24
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@
|
||||||
# http://setuptools.readthedocs.io/en/latest/setuptools.html#configuring-setup-using-setup-cfg-files
|
# http://setuptools.readthedocs.io/en/latest/setuptools.html#configuring-setup-using-setup-cfg-files
|
||||||
|
|
||||||
[metadata]
|
[metadata]
|
||||||
name = enhancer
|
name = mayavoz
|
||||||
description = Deep learning for speech enhacement
|
description = Deep learning for speech enhacement
|
||||||
author = Shahul Ess
|
author = Shahul Ess
|
||||||
author-email = shahules786@gmail.com
|
author-email = shahules786@gmail.com
|
||||||
|
|
@ -53,7 +53,7 @@ cli =
|
||||||
[options.entry_points]
|
[options.entry_points]
|
||||||
|
|
||||||
console_scripts =
|
console_scripts =
|
||||||
enhancer-train=enhancer.cli.train:train
|
mayavoz-train=mayavoz.cli.train:train
|
||||||
|
|
||||||
[test]
|
[test]
|
||||||
# py.test options when running `python setup.py test`
|
# py.test options when running `python setup.py test`
|
||||||
|
|
@ -66,7 +66,7 @@ extras = True
|
||||||
# e.g. --cov-report html (or xml) for html/xml output or --junitxml junit.xml
|
# e.g. --cov-report html (or xml) for html/xml output or --junitxml junit.xml
|
||||||
# in order to write a coverage file that can be read by Jenkins.
|
# in order to write a coverage file that can be read by Jenkins.
|
||||||
addopts =
|
addopts =
|
||||||
--cov enhancer --cov-report term-missing
|
--cov mayavoz --cov-report term-missing
|
||||||
--verbose
|
--verbose
|
||||||
norecursedirs =
|
norecursedirs =
|
||||||
dist
|
dist
|
||||||
|
|
|
||||||
6
setup.py
6
setup.py
|
|
@ -33,15 +33,15 @@ elif sha != "Unknown":
|
||||||
version += "+" + sha[:7]
|
version += "+" + sha[:7]
|
||||||
print("-- Building version " + version)
|
print("-- Building version " + version)
|
||||||
|
|
||||||
version_path = ROOT_DIR / "enhancer" / "version.py"
|
version_path = ROOT_DIR / "mayavoz" / "version.py"
|
||||||
|
|
||||||
with open(version_path, "w") as f:
|
with open(version_path, "w") as f:
|
||||||
f.write("__version__ = '{}'\n".format(version))
|
f.write("__version__ = '{}'\n".format(version))
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
setup(
|
setup(
|
||||||
name="enhancer",
|
name="mayavoz",
|
||||||
namespace_packages=["enhancer"],
|
namespace_packages=["mayavoz"],
|
||||||
version=version,
|
version=version,
|
||||||
packages=find_packages(),
|
packages=find_packages(),
|
||||||
install_requires=requirements,
|
install_requires=requirements,
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from enhancer.loss import mean_absolute_error, mean_squared_error
|
from mayavoz.loss import mean_absolute_error, mean_squared_error
|
||||||
|
|
||||||
loss_functions = [mean_absolute_error(), mean_squared_error()]
|
loss_functions = [mean_absolute_error(), mean_squared_error()]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,8 @@
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from enhancer.models.complexnn.conv import ComplexConv2d, ComplexConvTranspose2d
|
from mayavoz.models.complexnn.conv import ComplexConv2d, ComplexConvTranspose2d
|
||||||
from enhancer.models.complexnn.rnn import ComplexLSTM
|
from mayavoz.models.complexnn.rnn import ComplexLSTM
|
||||||
from enhancer.models.complexnn.utils import ComplexBatchNorm2D
|
from mayavoz.models.complexnn.utils import ComplexBatchNorm2D
|
||||||
|
|
||||||
|
|
||||||
def test_complexconv2d():
|
def test_complexconv2d():
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,9 @@
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from enhancer.data.dataset import EnhancerDataset
|
from mayavoz.data.dataset import MayaDataset
|
||||||
from enhancer.models import Demucs
|
from mayavoz.models import Demucs
|
||||||
from enhancer.utils.config import Files
|
from mayavoz.utils.config import Files
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|
@ -15,7 +15,7 @@ def vctk_dataset():
|
||||||
test_clean="clean_testset_wav",
|
test_clean="clean_testset_wav",
|
||||||
test_noisy="noisy_testset_wav",
|
test_noisy="noisy_testset_wav",
|
||||||
)
|
)
|
||||||
dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files)
|
dataset = MayaDataset(name="vctk", root_dir=root_dir, files=files)
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,9 @@
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from enhancer.data.dataset import EnhancerDataset
|
from mayavoz.data.dataset import MayaDataset
|
||||||
from enhancer.models.dccrn import DCCRN
|
from mayavoz.models.dccrn import DCCRN
|
||||||
from enhancer.utils.config import Files
|
from mayavoz.utils.config import Files
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|
@ -15,7 +15,7 @@ def vctk_dataset():
|
||||||
test_clean="clean_testset_wav",
|
test_clean="clean_testset_wav",
|
||||||
test_noisy="noisy_testset_wav",
|
test_noisy="noisy_testset_wav",
|
||||||
)
|
)
|
||||||
dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files)
|
dataset = MayaDataset(name="vctk", root_dir=root_dir, files=files)
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,9 @@
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from enhancer.data.dataset import EnhancerDataset
|
from mayavoz.data.dataset import MayaDataset
|
||||||
from enhancer.models import WaveUnet
|
from mayavoz.models import WaveUnet
|
||||||
from enhancer.utils.config import Files
|
from mayavoz.utils.config import Files
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|
@ -15,7 +15,7 @@ def vctk_dataset():
|
||||||
test_clean="clean_testset_wav",
|
test_clean="clean_testset_wav",
|
||||||
test_noisy="noisy_testset_wav",
|
test_noisy="noisy_testset_wav",
|
||||||
)
|
)
|
||||||
dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files)
|
dataset = MayaDataset(name="vctk", root_dir=root_dir, files=files)
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from enhancer.inference import Inference
|
from mayavoz.inference import Inference
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from enhancer.utils.transforms import ConviSTFT, ConvSTFT
|
from mayavoz.utils.transforms import ConviSTFT, ConvSTFT
|
||||||
|
|
||||||
|
|
||||||
def test_stft_istft():
|
def test_stft_istft():
|
||||||
|
|
|
||||||
|
|
@ -2,8 +2,8 @@ import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from enhancer.data.fileprocessor import Fileprocessor
|
from mayavoz.data.fileprocessor import Fileprocessor
|
||||||
from enhancer.utils.io import Audio
|
from mayavoz.utils.io import Audio
|
||||||
|
|
||||||
|
|
||||||
def test_io_channel():
|
def test_io_channel():
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue