Merge branch 'dev' of https://github.com/shahules786/enhancer into dev-hawk
commit a064151e2e
@@ -0,0 +1,9 @@
[flake8]
per-file-ignores = __init__.py:F401
ignore = E203, E266, E501, W503
# line length is intentionally set to 80 here because black uses Bugbear
# See https://github.com/psf/black/blob/master/README.md#line-length for more details
max-line-length = 80
max-complexity = 18
select = B,C,E,F,W,T4,B9
exclude = tools/kaldi_decoder
@@ -0,0 +1,43 @@
repos:
  # # Clean Notebooks
  # - repo: https://github.com/kynan/nbstripout
  #   rev: master
  #   hooks:
  #     - id: nbstripout

  # Format Code
  - repo: https://github.com/ambv/black
    rev: 22.8.0
    hooks:
      - id: black

  # Sort imports
  - repo: https://github.com/PyCQA/isort
    rev: 5.10.1
    hooks:
      - id: isort
        args: ["--profile", "black"]

  - repo: https://gitlab.com/pycqa/flake8
    rev: 5.0.4
    hooks:
      - id: flake8
        args: ['--ignore=E203,E501,F811,E712,W503']

  # Formatting, Whitespace, etc
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v3.2.0
    hooks:
      - id: trailing-whitespace
      - id: check-added-large-files
        args: ['--maxkb=1000']
      - id: check-ast
      - id: check-json
      - id: check-merge-conflict
      - id: check-xml
      - id: check-yaml
      - id: debug-statements
      - id: end-of-file-fixer
      - id: requirements-txt-fixer
      - id: mixed-line-ending
        args: ['--fix=no']
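For anyone reproducing this setup: these hooks are normally activated once per clone with `pre-commit install`, after which `pre-commit run --all-files` checks the whole tree; that is most likely how the formatting-only changes further down in this commit were produced.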
@@ -1 +1,6 @@
# enhancer

Enhancer is a PyTorch-based open-source toolkit for speech enhancement, designed to save time for audio researchers. It provides easy-to-use pretrained audio enhancement models and facilitates highly customisable model training. Enhancer provides:

* Various pretrained models nicely integrated with Hugging Face, which users can select and use without any hassle.
* The ability to train and validate your own custom speech enhancement models with just under 10 lines of code!
* A command line tool that facilitates training of highly customisable speech enhancement models from the terminal itself!
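As a rough illustration of the second bullet, here is a minimal training sketch, not taken from the repository itself: the `EnhancerDataset`, `Demucs`, and `Files` names do come from this codebase, but the dataset path, folder names, and trainer settings are placeholder assumptions.

import pytorch_lightning as pl

from enhancer.data.dataset import EnhancerDataset
from enhancer.models import Demucs
from enhancer.utils.config import Files

# Hypothetical quick start: paths, folder names and epoch count are placeholders.
files = Files(
    train_clean="clean_trainset_wav",
    train_noisy="noisy_trainset_wav",
    test_clean="clean_testset_wav",
    test_noisy="noisy_testset_wav",
)
dataset = EnhancerDataset(
    name="vctk", root_dir="/path/to/vctk", files=files, sampling_rate=16000
)
model = Demucs(dataset=dataset, loss="mse", metric="mae")
trainer = pl.Trainer(max_epochs=10)
trainer.fit(model)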
67 cli/train.py
@@ -1,67 +0,0 @@
from genericpath import isfile
import os
from types import MethodType
import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig
from torch.optim.lr_scheduler import ReduceLROnPlateau
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import MLFlowLogger

os.environ["HYDRA_FULL_ERROR"] = "1"
JOB_ID = os.environ.get("SLURM_JOBID","0")


@hydra.main(config_path="train_config",config_name="config")
def main(config: DictConfig):

    callbacks = []
    logger = MLFlowLogger(experiment_name=config.mlflow.experiment_name,
                run_name=config.mlflow.run_name, tags={"JOB_ID":JOB_ID})

    parameters = config.hyperparameters

    dataset = instantiate(config.dataset)
    model = instantiate(config.model,dataset=dataset,lr=parameters.get("lr"),
                loss=parameters.get("loss"), metric = parameters.get("metric"))

    direction = model.valid_monitor
    checkpoint = ModelCheckpoint(
        dirpath="./model",filename=f"model_{JOB_ID}",monitor="val_loss",verbose=True,
        mode=direction,every_n_epochs=1
    )
    callbacks.append(checkpoint)
    early_stopping = EarlyStopping(
        monitor="val_loss",
        mode=direction,
        min_delta=0.0,
        patience=parameters.get("EarlyStopping_patience",10),
        strict=True,
        verbose=False,
    )
    callbacks.append(early_stopping)

    def configure_optimizer(self):
        optimizer = instantiate(config.optimizer,lr=parameters.get("lr"),parameters=self.parameters())
        scheduler = ReduceLROnPlateau(
            optimizer=optimizer,
            mode=direction,
            factor=parameters.get("ReduceLr_factor",0.1),
            verbose=True,
            min_lr=parameters.get("min_lr",1e-6),
            patience=parameters.get("ReduceLr_patience",3)
        )
        return {"optimizer":optimizer, "lr_scheduler":scheduler}

    model.configure_parameters = MethodType(configure_optimizer,model)

    trainer = instantiate(config.trainer,logger=logger,callbacks=callbacks)
    trainer.fit(model)

    saved_location = os.path.join(trainer.default_root_dir,"model",f"model_{JOB_ID}.ckpt")
    if os.path.isfile(saved_location):
        logger.experiment.log_artifact(logger.run_id,saved_location)


if __name__=="__main__":
    main()
@@ -1 +1 @@
__version__ = "0.0.1"
@@ -0,0 +1,85 @@
import os
from types import MethodType

import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import MLFlowLogger
from torch.optim.lr_scheduler import ReduceLROnPlateau

os.environ["HYDRA_FULL_ERROR"] = "1"
JOB_ID = os.environ.get("SLURM_JOBID", "0")


@hydra.main(config_path="train_config", config_name="config")
def main(config: DictConfig):

    callbacks = []
    logger = MLFlowLogger(
        experiment_name=config.mlflow.experiment_name,
        run_name=config.mlflow.run_name,
        tags={"JOB_ID": JOB_ID},
    )

    parameters = config.hyperparameters

    dataset = instantiate(config.dataset)
    model = instantiate(
        config.model,
        dataset=dataset,
        lr=parameters.get("lr"),
        loss=parameters.get("loss"),
        metric=parameters.get("metric"),
    )

    direction = model.valid_monitor
    checkpoint = ModelCheckpoint(
        dirpath="./model",
        filename=f"model_{JOB_ID}",
        monitor="val_loss",
        verbose=True,
        mode=direction,
        every_n_epochs=1,
    )
    callbacks.append(checkpoint)
    early_stopping = EarlyStopping(
        monitor="val_loss",
        mode=direction,
        min_delta=0.0,
        patience=parameters.get("EarlyStopping_patience", 10),
        strict=True,
        verbose=False,
    )
    callbacks.append(early_stopping)

    def configure_optimizer(self):
        optimizer = instantiate(
            config.optimizer,
            lr=parameters.get("lr"),
            parameters=self.parameters(),
        )
        scheduler = ReduceLROnPlateau(
            optimizer=optimizer,
            mode=direction,
            factor=parameters.get("ReduceLr_factor", 0.1),
            verbose=True,
            min_lr=parameters.get("min_lr", 1e-6),
            patience=parameters.get("ReduceLr_patience", 3),
        )
        return {"optimizer": optimizer, "lr_scheduler": scheduler}

    model.configure_parameters = MethodType(configure_optimizer, model)

    trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks)
    trainer.fit(model)

    saved_location = os.path.join(
        trainer.default_root_dir, "model", f"model_{JOB_ID}.ckpt"
    )
    if os.path.isfile(saved_location):
        logger.experiment.log_artifact(logger.run_id, saved_location)


if __name__ == "__main__":
    main()
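Because the script is wrapped in `@hydra.main(config_path="train_config", config_name="config")`, any composed config value can be overridden from the command line in the usual Hydra way, e.g. `python cli/train.py hyperparameters.lr=0.001` (assuming an `lr` key in the hyperparameters group, which the `parameters.get("lr")` calls above suggest). The config groups themselves follow the `defaults` list below.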
@@ -4,4 +4,4 @@ defaults:
  - optimizer : Adam
  - hyperparameters : default
  - trainer : default
  - mlflow : experiment
@@ -10,4 +10,3 @@ files:
  test_clean : clean_test_wav
  train_noisy : clean_test_wav
  test_noisy : clean_test_wav
@@ -10,6 +10,3 @@ files:
  test_clean : clean_testset_wav
  train_noisy : noisy_trainset_28spk_wav
  test_noisy : noisy_testset_wav
@@ -0,0 +1,13 @@
_target_: enhancer.data.dataset.EnhancerDataset
name : vctk
root_dir : /Users/shahules/Myprojects/enhancer/datasets/vctk
duration : 1.0
sampling_rate: 16000
batch_size: 64
num_workers : 0

files:
  train_clean : clean_testset_wav
  test_clean : clean_testset_wav
  train_noisy : noisy_testset_wav
  test_noisy : noisy_testset_wav
@@ -5,4 +5,3 @@ ReduceLr_patience : 5
ReduceLr_factor : 0.1
min_lr : 0.000001
EarlyStopping_factor : 10
@@ -1,2 +1,2 @@
experiment_name : shahules/enhancer
run_name : baseline
@@ -14,5 +14,3 @@ encoder_decoder:
lstm:
  bidirectional: False
  num_layers: 2
@@ -0,0 +1 @@
from enhancer.data.dataset import EnhancerDataset
@@ -1,20 +1,21 @@
import math
import multiprocessing
import os
from typing import Optional

import pytorch_lightning as pl
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, IterableDataset

from enhancer.data.fileprocessor import Fileprocessor
from enhancer.utils import check_files
from enhancer.utils.config import Files
from enhancer.utils.io import Audio
from enhancer.utils.random import create_unique_rng


class TrainDataset(IterableDataset):
    def __init__(self, dataset):
        self.dataset = dataset

    def __iter__(self):
@@ -23,88 +24,102 @@ class TrainDataset(IterableDataset):
    def __len__(self):
        return self.dataset.train__len__()


class ValidDataset(Dataset):
    def __init__(self, dataset):
        self.dataset = dataset

    def __getitem__(self, idx):
        return self.dataset.val__getitem__(idx)

    def __len__(self):
        return self.dataset.val__len__()


class TaskDataset(pl.LightningDataModule):
    def __init__(
        self,
        name: str,
        root_dir: str,
        files: Files,
        duration: float = 1.0,
        sampling_rate: int = 48000,
        matching_function=None,
        batch_size=32,
        num_workers: Optional[int] = None,
    ):
        super().__init__()

        self.name = name
        self.files, self.root_dir = check_files(root_dir, files)
        self.duration = duration
        self.sampling_rate = sampling_rate
        self.batch_size = batch_size
        self.matching_function = matching_function
        self._validation = []
        if num_workers is None:
            num_workers = multiprocessing.cpu_count() // 2
        self.num_workers = num_workers

    def setup(self, stage: Optional[str] = None):

        if stage in ("fit", None):

            train_clean = os.path.join(self.root_dir, self.files.train_clean)
            train_noisy = os.path.join(self.root_dir, self.files.train_noisy)
            fp = Fileprocessor.from_name(
                self.name, train_clean, train_noisy, self.matching_function
            )
            self.train_data = fp.prepare_matching_dict()

            val_clean = os.path.join(self.root_dir, self.files.test_clean)
            val_noisy = os.path.join(self.root_dir, self.files.test_noisy)
            fp = Fileprocessor.from_name(
                self.name, val_clean, val_noisy, self.matching_function
            )
            val_data = fp.prepare_matching_dict()

            for item in val_data:
                clean, noisy, total_dur = item.values()
                if total_dur < self.duration:
                    continue
                num_segments = round(total_dur / self.duration)
                for index in range(num_segments):
                    start_time = index * self.duration
                    self._validation.append(
                        ({"clean": clean, "noisy": noisy}, start_time)
                    )

    def train_dataloader(self):
        return DataLoader(
            TrainDataset(self),
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        return DataLoader(
            ValidDataset(self),
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )


class EnhancerDataset(TaskDataset):
    """
    Dataset object for creating clean-noisy speech enhancement datasets

    parameters:
    name : str
        name of the dataset
    root_dir : str
        root directory of the dataset containing clean/noisy folders
    files : Files
        dataclass containing train_clean, train_noisy, test_clean, test_noisy
        folder names (refer to the enhancer.utils.Files dataclass)
    duration : float
        expected audio duration of single audio sample for training
    sampling_rate : int
        desired sampling rate
    batch_size : int
        batch size of each batch
    num_workers : int
@@ -114,71 +129,92 @@ class EnhancerDataset(TaskDataset):
        use one_to_one mapping for datasets with one noisy file for each clean file
        use one_to_many mapping for multiple noisy files for each clean file

    """

    def __init__(
        self,
        name: str,
        root_dir: str,
        files: Files,
        duration=1.0,
        sampling_rate=48000,
        matching_function=None,
        batch_size=32,
        num_workers: Optional[int] = None,
    ):

        super().__init__(
            name=name,
            root_dir=root_dir,
            files=files,
            sampling_rate=sampling_rate,
            duration=duration,
            matching_function=matching_function,
            batch_size=batch_size,
            num_workers=num_workers,
        )

        self.sampling_rate = sampling_rate
        self.files = files
        self.duration = max(1.0, duration)
        self.audio = Audio(self.sampling_rate, mono=True, return_tensor=True)

    def setup(self, stage: Optional[str] = None):

        super().setup(stage=stage)

    def train__iter__(self):

        rng = create_unique_rng(self.model.current_epoch)

        while True:

            file_dict, *_ = rng.choices(
                self.train_data,
                k=1,
                weights=[file["duration"] for file in self.train_data],
            )
            file_duration = file_dict["duration"]
            start_time = round(rng.uniform(0, file_duration - self.duration), 2)
            data = self.prepare_segment(file_dict, start_time)
            yield data

    def val__getitem__(self, idx):
        return self.prepare_segment(*self._validation[idx])

    def prepare_segment(self, file_dict: dict, start_time: float):

        clean_segment = self.audio(
            file_dict["clean"], offset=start_time, duration=self.duration
        )
        noisy_segment = self.audio(
            file_dict["noisy"], offset=start_time, duration=self.duration
        )
        clean_segment = F.pad(
            clean_segment,
            (
                0,
                int(
                    self.duration * self.sampling_rate - clean_segment.shape[-1]
                ),
            ),
        )
        noisy_segment = F.pad(
            noisy_segment,
            (
                0,
                int(
                    self.duration * self.sampling_rate - noisy_segment.shape[-1]
                ),
            ),
        )
        return {"clean": clean_segment, "noisy": noisy_segment}

    def train__len__(self):
        return math.ceil(
            sum([file["duration"] for file in self.train_data]) / self.duration
        )

    def val__len__(self):
        return len(self._validation)
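The `train__iter__` method above samples training chunks with probability proportional to file duration, so longer recordings contribute proportionally more segments per epoch. A self-contained sketch of that sampling idea, with a made-up file list:

import random

# Made-up file list; in the dataset class this comes from prepare_matching_dict().
train_data = [
    {"clean": "a_clean.wav", "noisy": "a_noisy.wav", "duration": 30.0},
    {"clean": "b_clean.wav", "noisy": "b_noisy.wav", "duration": 3.0},
]
duration = 1.0
rng = random.Random(42)  # stands in for create_unique_rng(current_epoch)

file_dict, *_ = rng.choices(
    train_data, k=1, weights=[f["duration"] for f in train_data]
)
start_time = round(rng.uniform(0, file_dict["duration"] - duration), 2)
# a_clean.wav is picked roughly ten times as often as b_clean.wav
print(file_dict["clean"], start_time)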
@@ -1,108 +1,118 @@
import glob
import os

import numpy as np
from scipy.io import wavfile

MATCHING_FNS = ("one_to_one", "one_to_many")


class ProcessorFunctions:
    """
    Preprocessing methods for different types of speech enhancement datasets.
    """

    @staticmethod
    def one_to_one(clean_path, noisy_path):
        """
        One clean audio can have only one noisy audio file
        """

        matching_wavfiles = list()
        clean_filenames = [
            file.split("/")[-1]
            for file in glob.glob(os.path.join(clean_path, "*.wav"))
        ]
        noisy_filenames = [
            file.split("/")[-1]
            for file in glob.glob(os.path.join(noisy_path, "*.wav"))
        ]
        common_filenames = np.intersect1d(noisy_filenames, clean_filenames)

        for file_name in common_filenames:

            sr_clean, clean_file = wavfile.read(
                os.path.join(clean_path, file_name)
            )
            sr_noisy, noisy_file = wavfile.read(
                os.path.join(noisy_path, file_name)
            )
            if (clean_file.shape[-1] == noisy_file.shape[-1]) and (
                sr_clean == sr_noisy
            ):
                matching_wavfiles.append(
                    {
                        "clean": os.path.join(clean_path, file_name),
                        "noisy": os.path.join(noisy_path, file_name),
                        "duration": clean_file.shape[-1] / sr_clean,
                    }
                )
        return matching_wavfiles

    @staticmethod
    def one_to_many(clean_path, noisy_path):
        """
        One clean audio can have multiple noisy audio files
        """

        matching_wavfiles = dict()
        clean_filenames = [
            file.split("/")[-1]
            for file in glob.glob(os.path.join(clean_path, "*.wav"))
        ]
        for clean_file in clean_filenames:
            noisy_filenames = glob.glob(
                os.path.join(noisy_path, f"*_{clean_file}.wav")
            )
            for noisy_file in noisy_filenames:

                sr_clean, clean_file = wavfile.read(
                    os.path.join(clean_path, clean_file)
                )
                sr_noisy, noisy_file = wavfile.read(noisy_file)
                if (clean_file.shape[-1] == noisy_file.shape[-1]) and (
                    sr_clean == sr_noisy
                ):
                    matching_wavfiles.update(
                        {
                            "clean": os.path.join(clean_path, clean_file),
                            "noisy": noisy_file,
                            "duration": clean_file.shape[-1] / sr_clean,
                        }
                    )

        return matching_wavfiles


class Fileprocessor:
    def __init__(self, clean_dir, noisy_dir, matching_function=None):
        self.clean_dir = clean_dir
        self.noisy_dir = noisy_dir
        self.matching_function = matching_function

    @classmethod
    def from_name(cls, name: str, clean_dir, noisy_dir, matching_function=None):

        if matching_function is None:
            if name.lower() == "vctk":
                return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_one)
            elif name.lower() == "dns-2020":
                return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_many)
        else:
            if matching_function not in MATCHING_FNS:
                raise ValueError(
                    f"Invalid matching function! Available options are {MATCHING_FNS}"
                )
            else:
                return cls(
                    clean_dir,
                    noisy_dir,
                    getattr(ProcessorFunctions, matching_function),
                )

    def prepare_matching_dict(self):

        if self.matching_function is None:
            raise ValueError("Not a valid matching function")

        return self.matching_function(self.clean_dir, self.noisy_dir)
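To make the dispatch in `from_name` concrete, a small hypothetical usage (the directory paths are placeholders):

from enhancer.data.fileprocessor import Fileprocessor

# "vctk" resolves to ProcessorFunctions.one_to_one automatically.
fp = Fileprocessor.from_name("vctk", "/data/vctk/clean", "/data/vctk/noisy")
matches = fp.prepare_matching_dict()
# Each entry looks like {"clean": path, "noisy": path, "duration": seconds}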
@@ -1,119 +1,168 @@
from pathlib import Path
from typing import Optional, Union

import numpy as np
import torch
import torch.nn.functional as F
from librosa import load as load_audio
from scipy.io import wavfile
from scipy.signal import get_window

from enhancer.utils import Audio


class Inference:
    """
    contains methods used for inference.
    """

    @staticmethod
    def read_input(audio, sr, model_sr):
        """
        read and verify audio input regardless of the input format.
        arguments:
            audio : audio input
            sr : sampling rate of input audio
            model_sr : sampling rate used for model training.
        """

        if isinstance(audio, (np.ndarray, torch.Tensor)):
            assert sr is not None, "Invalid sampling rate!"
            if len(audio.shape) == 1:
                audio = audio.reshape(1, -1)

        if isinstance(audio, str):
            audio = Path(audio)
            if not audio.is_file():
                raise ValueError(f"Input file {audio} does not exist")
            else:
                audio, sr = load_audio(
                    audio,
                    sr=sr,
                )
            if len(audio.shape) == 1:
                audio = audio.reshape(1, -1)
            else:
                assert (
                    audio.shape[0] == 1
                ), "Enhance inference only supports single waveform"

        waveform = Audio.resample_audio(audio, sr=sr, target_sr=model_sr)
        waveform = Audio.convert_mono(waveform)
        if isinstance(waveform, np.ndarray):
            waveform = torch.from_numpy(waveform)

        return waveform

    @staticmethod
    def batchify(
        waveform: torch.Tensor,
        window_size: int,
        step_size: Optional[int] = None,
    ):
        """
        break input waveform into samples with duration specified. (Overlap-add)
        arguments:
            waveform : audio waveform
            window_size : window size used for splitting waveform into batches
            step_size : step_size used for splitting waveform into batches
        """
        assert (
            waveform.ndim == 2
        ), f"Expected input waveform with 2 dimensions (channels,samples), got {waveform.ndim}"
        _, num_samples = waveform.shape
        waveform = waveform.unsqueeze(-1)
        step_size = window_size // 2 if step_size is None else step_size

        if num_samples >= window_size:
            waveform_batch = F.unfold(
                waveform[None, ...],
                kernel_size=(window_size, 1),
                stride=(step_size, 1),
                padding=(window_size, 0),
            )
            waveform_batch = waveform_batch.permute(2, 0, 1)

        return waveform_batch

    @staticmethod
    def aggreagate(
        data: torch.Tensor,
        window_size: int,
        total_frames: int,
        step_size: Optional[int] = None,
        window="hanning",
    ):
        """
        stitch batched waveform into single waveform. (Overlap-add)
        arguments:
            data : batched waveform
            window_size : window_size used to batch waveform
            total_frames : total number of frames present in original waveform
            step_size : step_size used to batch waveform
            window : type of window used for overlap-add mechanism.
        """
        num_chunks, n_channels, num_frames = data.shape
        window = get_window(window=window, Nx=data.shape[-1])
        window = torch.from_numpy(window).to(data.device)
        data *= window
        step_size = window_size // 2 if step_size is None else step_size

        data = data.permute(1, 2, 0)
        data = F.fold(
            data,
            (total_frames, 1),
            kernel_size=(window_size, 1),
            stride=(step_size, 1),
            padding=(window_size, 0),
        ).squeeze(-1)

        return data.reshape(1, n_channels, -1)

    @staticmethod
    def write_output(
        waveform: torch.Tensor, filename: Union[str, Path], sr: int
    ):
        """
        write audio output as wav file
        arguments:
            waveform : audio waveform
            filename : name of the wave file. Output will be written as cleaned_filename.wav
            sr : sampling rate
        """

        if isinstance(filename, str):
            filename = Path(filename)

        parent, name = filename.parent, "cleaned_" + filename.name
        filename = parent / Path(name)
        if filename.is_file():
            raise FileExistsError(f"file {filename} already exists")
        else:
            wavfile.write(filename, rate=sr, data=waveform.detach().cpu())

    @staticmethod
    def prepare_output(
        waveform: torch.Tensor,
        model_sampling_rate: int,
        audio: Union[str, np.ndarray, torch.Tensor],
        sampling_rate: Optional[int],
    ):
        """
        prepare output audio based on input format
        arguments:
            waveform : predicted audio waveform
            model_sampling_rate : sampling rate used to train the model
            audio : input audio
            sampling_rate : input audio sampling rate
        """
        if isinstance(audio, np.ndarray):
            waveform = waveform.detach().cpu().numpy()

        if sampling_rate is not None:
            waveform = Audio.resample_audio(
                waveform, sr=model_sampling_rate, target_sr=sampling_rate
            )

        return waveform
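As a shape-level sanity check of the `batchify`/`aggreagate` pair, here is a self-contained sketch of the same unfold/fold overlap-add on a dummy waveform (it does not normalize the window overlap, mirroring the code above):

import torch
import torch.nn.functional as F

window_size, step_size, num_samples = 1600, 800, 16000
wave = torch.randn(1, num_samples).unsqueeze(-1)  # (channels, samples, 1)
chunks = F.unfold(
    wave[None, ...],  # (1, channels, samples, 1)
    kernel_size=(window_size, 1),
    stride=(step_size, 1),
    padding=(window_size, 0),
)  # (1, window_size, num_chunks)
chunks = chunks.permute(2, 0, 1)  # (num_chunks, channels, window_size)
stitched = F.fold(
    chunks.permute(1, 2, 0),  # back to (1, window_size, num_chunks)
    (num_samples, 1),
    kernel_size=(window_size, 1),
    stride=(step_size, 1),
    padding=(window_size, 0),
).squeeze(-1)  # (1, channels, num_samples)
print(chunks.shape, stitched.shape)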
113 enhancer/loss.py
@@ -3,62 +3,82 @@ import torch.nn as nn


class mean_squared_error(nn.Module):
    """
    Mean squared error / L2 loss
    """

    def __init__(self, reduction="mean"):
        super().__init__()

        self.loss_fun = nn.MSELoss(reduction=reduction)
        self.higher_better = False

    def forward(self, prediction: torch.Tensor, target: torch.Tensor):

        if prediction.size() != target.size() or target.ndim < 3:
            raise TypeError(
                f"""Inputs must be of the same shape (batch_size,channels,samples)
                got {prediction.size()} and {target.size()} instead"""
            )

        return self.loss_fun(prediction, target)


class mean_absolute_error(nn.Module):
    """
    Mean absolute error / L1 loss
    """

    def __init__(self, reduction="mean"):
        super().__init__()

        self.loss_fun = nn.L1Loss(reduction=reduction)
        self.higher_better = False

    def forward(self, prediction: torch.Tensor, target: torch.Tensor):

        if prediction.size() != target.size() or target.ndim < 3:
            raise TypeError(
                f"""Inputs must be of the same shape (batch_size,channels,samples)
                got {prediction.size()} and {target.size()} instead"""
            )

        return self.loss_fun(prediction, target)


class Si_SDR(nn.Module):
    """
    SI-SDR metric based on SDR – HALF-BAKED OR WELL DONE? (https://arxiv.org/pdf/1811.02508.pdf)
    """

    def __init__(self, reduction: str = "mean"):
        super().__init__()
        if reduction in ["sum", "mean", None]:
            self.reduction = reduction
        else:
            raise TypeError(
                "Invalid reduction, valid options are sum, mean, None"
            )
        self.higher_better = False

    def forward(self, prediction: torch.Tensor, target: torch.Tensor):

        if prediction.size() != target.size() or target.ndim < 3:
            raise TypeError(
                f"""Inputs must be of the same shape (batch_size,channels,samples)
                got {prediction.size()} and {target.size()} instead"""
            )

        target_energy = torch.sum(target**2, keepdim=True, dim=-1)
        scaling_factor = (
            torch.sum(prediction * target, keepdim=True, dim=-1) / target_energy
        )
        target_projection = target * scaling_factor
        noise = prediction - target_projection
        ratio = torch.sum(target_projection**2, dim=-1) / torch.sum(
            noise**2, dim=-1
        )
        si_sdr = 10 * torch.log10(ratio).mean(dim=-1)

        if self.reduction == "sum":
            si_sdr = si_sdr.sum()
@@ -66,46 +86,55 @@ class Si_SDR(nn.Module):
            si_sdr = si_sdr.mean()
        else:
            pass

        return si_sdr


class Avergeloss(nn.Module):
    """
    Combine multiple metrics of the same nature,
    for example, ["mse","mae"]
    parameters:
        losses : loss function names to be combined
    """

    def __init__(self, losses):
        super().__init__()

        self.valid_losses = nn.ModuleList()

        direction = [
            getattr(LOSS_MAP[loss](), "higher_better") for loss in losses
        ]
        if len(set(direction)) > 1:
            raise ValueError(
                "all cost functions should be of same nature, maximize or minimize!"
            )

        self.higher_better = direction[0]
        for loss in losses:
            loss = self.validate_loss(loss)
            self.valid_losses.append(loss())

    def validate_loss(self, loss: str):
        if loss not in LOSS_MAP.keys():
            raise ValueError(
                f"""Invalid loss function {loss}, available loss functions are
                {tuple([loss for loss in LOSS_MAP.keys()])}"""
            )
        else:
            return LOSS_MAP[loss]

    def forward(self, prediction: torch.Tensor, target: torch.Tensor):
        loss = 0.0
        for loss_fun in self.valid_losses:
            loss += loss_fun(prediction, target)

        return loss


LOSS_MAP = {
    "mae": mean_absolute_error,
    "mse": mean_squared_error,
    "SI-SDR": Si_SDR,
}
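A minimal usage sketch for `Avergeloss` (dummy tensors; "mse" and "mae" are both minimized, so they can be combined):

import torch

from enhancer.loss import Avergeloss

criterion = Avergeloss(["mse", "mae"])
prediction = torch.randn(4, 1, 16000)  # (batch_size, channels, samples)
target = torch.randn(4, 1, 16000)
print(criterion(prediction, target))  # sum of the MSE and MAE terms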
@@ -1,3 +1,3 @@
from enhancer.models.demucs import Demucs
from enhancer.models.model import Model
from enhancer.models.waveunet import WaveUnet
@@ -1,217 +1,264 @@
import logging
import math
from typing import List, Optional, Union

import torch.nn.functional as F
from torch import nn

from enhancer.data.dataset import EnhancerDataset
from enhancer.models.model import Model
from enhancer.utils.io import Audio as audio
from enhancer.utils.utils import merge_dict


class DemucsLSTM(nn.Module):
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int,
        bidirectional: bool = True,
    ):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size, hidden_size, num_layers, bidirectional=bidirectional
        )
        dim = 2 if bidirectional else 1
        self.linear = nn.Linear(dim * hidden_size, hidden_size)

    def forward(self, x):

        output, (h, c) = self.lstm(x)
        output = self.linear(output)

        return output, (h, c)


class DemucsEncoder(nn.Module):
    def __init__(
        self,
        num_channels: int,
        hidden_size: int,
        kernel_size: int,
        stride: int = 1,
        glu: bool = False,
    ):
        super().__init__()
        activation = nn.GLU(1) if glu else nn.ReLU()
        multi_factor = 2 if glu else 1
        self.encoder = nn.Sequential(
            nn.Conv1d(num_channels, hidden_size, kernel_size, stride),
            nn.ReLU(),
            nn.Conv1d(hidden_size, hidden_size * multi_factor, kernel_size, 1),
            activation,
        )

    def forward(self, waveform):

        return self.encoder(waveform)


class DemucsDecoder(nn.Module):
    def __init__(
        self,
        num_channels: int,
        hidden_size: int,
        kernel_size: int,
        stride: int = 1,
        glu: bool = False,
        layer: int = 0,
    ):
        super().__init__()
        activation = nn.GLU(1) if glu else nn.ReLU()
        multi_factor = 2 if glu else 1
        self.decoder = nn.Sequential(
            nn.Conv1d(hidden_size, hidden_size * multi_factor, kernel_size, 1),
            activation,
            nn.ConvTranspose1d(hidden_size, num_channels, kernel_size, stride),
        )
        if layer > 0:
            self.decoder.add_module("4", nn.ReLU())

    def forward(
        self,
        waveform,
    ):

        out = self.decoder(waveform)
        return out


class Demucs(Model):
    """
    Demucs model from https://arxiv.org/pdf/1911.13254.pdf
    parameters:
        encoder_decoder: dict, optional
            keyword arguments passed to the encoder-decoder block
        lstm : dict, optional
            keyword arguments passed to the LSTM block
        num_channels: int, defaults to 1
            number of channels in input audio
        sampling_rate: int, defaults to 16kHz
            sampling rate of input audio
        lr : float, defaults to 1e-3
            learning rate used for training
        dataset: EnhancerDataset, optional
            EnhancerDataset object containing train/validation data for training
        duration : float, optional
            chunk duration in seconds
        loss : string or List of strings
            loss function to be used, available ("mse","mae","SI-SDR")
        metric : string or List of strings
            metric function to be used, available ("mse","mae","SI-SDR")
    """

    ED_DEFAULTS = {
        "initial_output_channels": 48,
        "kernel_size": 8,
        "stride": 1,
        "depth": 5,
        "glu": True,
        "growth_factor": 2,
    }
    LSTM_DEFAULTS = {
        "bidirectional": True,
        "num_layers": 2,
    }

    def __init__(
        self,
        encoder_decoder: Optional[dict] = None,
        lstm: Optional[dict] = None,
        num_channels: int = 1,
        resample: int = 4,
        sampling_rate=16000,
        lr: float = 1e-3,
        dataset: Optional[EnhancerDataset] = None,
        loss: Union[str, List] = "mse",
        metric: Union[str, List] = "mse",
    ):
        duration = (
            dataset.duration if isinstance(dataset, EnhancerDataset) else None
        )
        if dataset is not None:
            if sampling_rate != dataset.sampling_rate:
                logging.warn(
                    f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
                )
                sampling_rate = dataset.sampling_rate
        super().__init__(
            num_channels=num_channels,
            sampling_rate=sampling_rate,
            lr=lr,
            dataset=dataset,
            duration=duration,
            loss=loss,
            metric=metric,
        )

        encoder_decoder = merge_dict(self.ED_DEFAULTS, encoder_decoder)
        lstm = merge_dict(self.LSTM_DEFAULTS, lstm)
        self.save_hyperparameters("encoder_decoder", "lstm", "resample")
        hidden = encoder_decoder["initial_output_channels"]
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        for layer in range(encoder_decoder["depth"]):

            encoder_layer = DemucsEncoder(
                num_channels=num_channels,
                hidden_size=hidden,
                kernel_size=encoder_decoder["kernel_size"],
                stride=encoder_decoder["stride"],
                glu=encoder_decoder["glu"],
            )
            self.encoder.append(encoder_layer)

            decoder_layer = DemucsDecoder(
                num_channels=num_channels,
                hidden_size=hidden,
                kernel_size=encoder_decoder["kernel_size"],
                stride=1,
                glu=encoder_decoder["glu"],
                layer=layer,
            )
            self.decoder.insert(0, decoder_layer)

            num_channels = hidden
            hidden = self.ED_DEFAULTS["growth_factor"] * hidden

        self.de_lstm = DemucsLSTM(
            input_size=num_channels,
            hidden_size=num_channels,
            num_layers=lstm["num_layers"],
            bidirectional=lstm["bidirectional"],
        )

    def forward(self, waveform):

        if waveform.dim() == 2:
            waveform = waveform.unsqueeze(1)

        if waveform.size(1) != 1:
            raise TypeError(
                f"Demucs can only process mono channel audio, input has {waveform.size(1)} channels"
            )

        length = waveform.shape[-1]
        x = F.pad(waveform, (0, self.get_padding_length(length) - length))
        if self.hparams.resample > 1:
            x = audio.resample_audio(
                audio=x,
                sr=self.hparams.sampling_rate,
                target_sr=int(
                    self.hparams.sampling_rate * self.hparams.resample
                ),
            )

        encoder_outputs = []
        for encoder in self.encoder:
            x = encoder(x)
            encoder_outputs.append(x)
        x = x.permute(0, 2, 1)
        x, _ = self.de_lstm(x)

        x = x.permute(0, 2, 1)
        for decoder in self.decoder:
            skip_connection = encoder_outputs.pop(-1)
            x += skip_connection[..., : x.shape[-1]]
            x = decoder(x)

        if self.hparams.resample > 1:
            x = audio.resample_audio(
                x,
                int(self.hparams.sampling_rate * self.hparams.resample),
                self.hparams.sampling_rate,
            )

        return x

    def get_padding_length(self, input_length):

        input_length = math.ceil(input_length * self.hparams.resample)

        for layer in range(
            self.hparams.encoder_decoder["depth"]
        ):  # encoder operation
|
||||||
input_length = math.ceil((input_length - self.hparams.encoder_decoder["kernel_size"])/self.hparams.encoder_decoder["stride"])+1
|
): # encoder operation
|
||||||
input_length = max(1,input_length)
|
input_length = (
|
||||||
for layer in range(self.hparams.encoder_decoder["depth"]): # decoder operaration
|
math.ceil(
|
||||||
input_length = (input_length-1) * self.hparams.encoder_decoder["stride"] + self.hparams.encoder_decoder["kernel_size"]
|
(input_length - self.hparams.encoder_decoder["kernel_size"])
|
||||||
input_length = math.ceil(input_length/self.hparams.resample)
|
/ self.hparams.encoder_decoder["stride"]
|
||||||
|
)
|
||||||
|
+ 1
|
||||||
|
)
|
||||||
|
input_length = max(1, input_length)
|
||||||
|
for layer in range(
|
||||||
|
self.hparams.encoder_decoder["depth"]
|
||||||
|
): # decoder operaration
|
||||||
|
input_length = (input_length - 1) * self.hparams.encoder_decoder[
|
||||||
|
"stride"
|
||||||
|
] + self.hparams.encoder_decoder["kernel_size"]
|
||||||
|
input_length = math.ceil(input_length / self.hparams.resample)
|
||||||
|
|
||||||
return int(input_length)
|
return int(input_length)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
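A minimal standalone sketch of the valid-length arithmetic in get_padding_length above; the kernel_size, stride, depth and resample values are illustrative stand-ins for the real ED_DEFAULTS entries, which are defined earlier in the class.

import math


def padded_length(input_length, kernel_size=8, stride=4, depth=5, resample=4):
    # mirror the encoder: each layer shrinks the time axis by roughly `stride`
    length = math.ceil(input_length * resample)
    for _ in range(depth):
        length = max(1, math.ceil((length - kernel_size) / stride) + 1)
    # mirror the decoder: invert the encoder arithmetic layer by layer
    for _ in range(depth):
        length = (length - 1) * stride + kernel_size
    return int(math.ceil(length / resample))


# forward() zero-pads a 16000-sample input up to this length before encoding
print(padded_length(16000))
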
@ -1,92 +1,108 @@
import os
from importlib import import_module
from pathlib import Path
from typing import Any, Dict, List, Optional, Text, Union
from urllib.parse import urlparse

import numpy as np
import pytorch_lightning as pl
import torch
from huggingface_hub import cached_download, hf_hub_url
from pytorch_lightning.utilities.cloud_io import load as pl_load
from torch.optim import Adam

from enhancer import __version__
from enhancer.data.dataset import EnhancerDataset
from enhancer.inference import Inference
from enhancer.loss import Avergeloss

CACHE_DIR = ""
HF_TORCH_WEIGHTS = ""
DEFAULT_DEVICE = "cpu"


class Model(pl.LightningModule):
    """
    Base class for all models
    parameters:
        num_channels: int, defaults to 1
            number of channels in input audio
        sampling_rate : int, defaults to 16kHz
            audio sampling rate
        lr: float, optional
            learning rate for model training
        dataset: EnhancerDataset, optional
            Enhancer dataset used for training/validation
        duration: float, optional
            duration used for training/inference
        loss : string or List of strings, defaults to "mse"
            loss functions to be used. Available ("mse","mae","Si-SDR")
    """

    def __init__(
        self,
        num_channels: int = 1,
        sampling_rate: int = 16000,
        lr: float = 1e-3,
        dataset: Optional[EnhancerDataset] = None,
        duration: Optional[float] = None,
        loss: Union[str, List] = "mse",
        metric: Union[str, List] = "mse",
    ):
        super().__init__()
        assert (
            num_channels == 1
        ), "Enhancer only supports mono channel models"
        self.dataset = dataset
        self.save_hyperparameters(
            "num_channels", "sampling_rate", "lr", "loss", "metric", "duration"
        )
        if self.logger:
            self.logger.experiment.log_dict(
                dict(self.hparams), "hyperparameters.json"
            )

        self.loss = loss
        self.metric = metric

    @property
    def loss(self):
        return self._loss

    @loss.setter
    def loss(self, loss):

        if isinstance(loss, str):
            loss = [loss]

        self._loss = Avergeloss(loss)

    @property
    def metric(self):
        return self._metric

    @metric.setter
    def metric(self, metric):

        if isinstance(metric, str):
            metric = [metric]

        self._metric = Avergeloss(metric)

    @property
    def dataset(self):
        return self._dataset

    @dataset.setter
    def dataset(self, dataset):
        self._dataset = dataset

    def setup(self, stage: Optional[str] = None):
        if stage == "fit":
            self.dataset.setup(stage)
            self.dataset.model = self

    def train_dataloader(self):
        return self.dataset.train_dataloader()

@ -94,9 +110,9 @@ class Model(pl.LightningModule):
        return self.dataset.val_dataloader()

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=self.hparams.lr)

    def training_step(self, batch, batch_idx: int):

        mixed_waveform = batch["noisy"]
        target = batch["clean"]

@ -105,13 +121,16 @@ class Model(pl.LightningModule):
        loss = self.loss(prediction, target)

        if self.logger:
            self.logger.experiment.log_metric(
                run_id=self.logger.run_id,
                key="train_loss",
                value=loss.item(),
                step=self.global_step,
            )
        self.log("train_loss", loss.item())
        return {"loss": loss}

    def validation_step(self, batch, batch_idx: int):

        mixed_waveform = batch["noisy"]
        target = batch["clean"]

@ -119,48 +138,92 @@ class Model(pl.LightningModule):
        metric_val = self.metric(prediction, target)
        loss_val = self.loss(prediction, target)
        self.log("val_metric", metric_val.item())
        self.log("val_loss", loss_val.item())

        if self.logger:
            self.logger.experiment.log_metric(
                run_id=self.logger.run_id,
                key="val_loss",
                value=loss_val.item(),
                step=self.global_step,
            )
            self.logger.experiment.log_metric(
                run_id=self.logger.run_id,
                key="val_metric",
                value=metric_val.item(),
                step=self.global_step,
            )

        return {"loss": loss_val}

    def on_save_checkpoint(self, checkpoint):

        checkpoint["enhancer"] = {
            "version": {"enhancer": __version__, "pytorch": torch.__version__},
            "architecture": {
                "module": self.__class__.__module__,
                "class": self.__class__.__name__,
            },
        }

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]):
        pass

    @classmethod
    def from_pretrained(
        cls,
        checkpoint: Union[Path, Text],
        map_location=None,
        hparams_file: Union[Path, Text] = None,
        strict: bool = True,
        use_auth_token: Union[Text, None] = None,
        cached_dir: Union[Path, Text] = CACHE_DIR,
        **kwargs,
    ):
        """
        Load pretrained model

        parameters:
            checkpoint : Path or str
                Path to checkpoint, or a remote URL, or a model identifier from
                the huggingface.co model hub.
            map_location: optional
                Same role as in torch.load().
                Defaults to `lambda storage, loc: storage`.
            hparams_file : Path or str, optional
                Path to a .yaml file with hierarchical structure as in this example:
                    drop_prob: 0.2
                    dataloader:
                        batch_size: 32
                You most likely won't need this since Lightning will always save the
                hyperparameters to the checkpoint. However, if your checkpoint weights
                do not have the hyperparameters saved, use this method to pass in a .yaml
                file with the hparams you would like to use. These will be converted
                into a dict and passed into your Model for use.
            strict : bool, optional
                Whether to strictly enforce that the keys in checkpoint match
                the keys returned by this module's state dict. Defaults to True.
            use_auth_token : str, optional
                When loading a private huggingface.co model, set `use_auth_token`
                to True or to a string containing your huggingface.co authentication
                token that can be obtained by running `huggingface-cli login`
            cache_dir: Path or str, optional
                Path to model cache directory. Defaults to content of PYANNOTE_CACHE
                environment variable, or "~/.cache/torch/pyannote" when unset.
            kwargs: optional
                Any extra keyword args needed to init the model.
                Can also be used to override saved hyperparameter values.

        Returns
        -------
        model : Model
            Model

        See also
        --------
        torch.load
        """

        checkpoint = str(checkpoint)
        if hparams_file is not None:

@ -168,104 +231,133 @@ class Model(pl.LightningModule):
        if os.path.isfile(checkpoint):
            model_path_pl = checkpoint
        elif urlparse(checkpoint).scheme in ("http", "https"):
            model_path_pl = checkpoint
        else:

            if "@" in checkpoint:
                model_id = checkpoint.split("@")[0]
                revision_id = checkpoint.split("@")[1]
            else:
                model_id = checkpoint
                revision_id = None

            url = hf_hub_url(
                model_id, filename=HF_TORCH_WEIGHTS, revision=revision_id
            )
            model_path_pl = cached_download(
                url=url,
                library_name="enhancer",
                library_version=__version__,
                cache_dir=cached_dir,
                use_auth_token=use_auth_token,
            )

        if map_location is None:
            map_location = torch.device(DEFAULT_DEVICE)

        loaded_checkpoint = pl_load(model_path_pl, map_location)
        module_name = loaded_checkpoint["enhancer"]["architecture"]["module"]
        class_name = loaded_checkpoint["enhancer"]["architecture"]["class"]
        module = import_module(module_name)
        Klass = getattr(module, class_name)

        try:
            model = Klass.load_from_checkpoint(
                checkpoint_path=model_path_pl,
                map_location=map_location,
                hparams_file=hparams_file,
                strict=strict,
                **kwargs,
            )
        except Exception as e:
            print(e)

        return model

    def infer(self, batch: torch.Tensor, batch_size: int = 32):
        """
        perform model inference
        parameters:
            batch : torch.Tensor
                input data
            batch_size : int, default 32
                batch size for inference
        """

        assert (
            batch.ndim == 3
        ), f"Expected batch with 3 dimensions (batch,channels,samples) got only {batch.ndim}"
        batch_predictions = []
        self.eval().to(self.device)

        with torch.no_grad():
            for batch_id in range(0, batch.shape[0], batch_size):
                batch_data = batch[batch_id : batch_id + batch_size, :, :].to(
                    self.device
                )
                prediction = self(batch_data)
                batch_predictions.append(prediction)

        return torch.vstack(batch_predictions)

    def enhance(
        self,
        audio: Union[Path, np.ndarray, torch.Tensor],
        sampling_rate: Optional[int] = None,
        batch_size: int = 32,
        save_output: bool = False,
        duration: Optional[int] = None,
        step_size: Optional[int] = None,
    ):
        """
        Enhance audio using the loaded pretrained model.

        parameters:
            audio: Path to audio file or numpy array or torch tensor
                single input audio
            sampling_rate: int, optional in case input is a path
                sampling rate of input
            batch_size: int, default 32
                input audio is split into multiple chunks. Inference is done on batches
                of these chunks according to given batch size.
            save_output : bool, default False
                whether to save output to file
            duration : float, optional
                chunk duration in seconds, defaults to duration of loaded pretrained model.
            step_size: int, optional
                step size between consecutive durations, defaults to 50% of duration
        """

        model_sampling_rate = self.hparams["sampling_rate"]
        if duration is None:
            duration = self.hparams["duration"]
        waveform = Inference.read_input(
            audio, sampling_rate, model_sampling_rate
        )
        waveform.to(self.device)
        window_size = round(duration * model_sampling_rate)
        batched_waveform = Inference.batchify(
            waveform, window_size, step_size=step_size
        )
        batch_prediction = self.infer(batched_waveform, batch_size=batch_size)
        waveform = Inference.aggreagate(
            batch_prediction,
            window_size,
            waveform.shape[-1],
            step_size,
        )

        if save_output and isinstance(audio, (str, Path)):
            Inference.write_output(waveform, audio, model_sampling_rate)

        else:
            waveform = Inference.prepare_output(
                waveform, model_sampling_rate, audio, sampling_rate
            )
        return waveform

    @property
    def valid_monitor(self):

        return "max" if self.loss.higher_better else "min"

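A hedged usage sketch of the loading and enhancement API above; the checkpoint path and input file name are hypothetical placeholders.

from enhancer.models import Demucs

# hypothetical local checkpoint; a hub id such as "user/model" would also resolve
model = Demucs.from_pretrained("path/to/checkpoint.ckpt")
enhanced = model.enhance(
    "noisy.wav",         # hypothetical input file
    sampling_rate=16000,
    batch_size=16,
)
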
@ -1,82 +1,124 @@
import logging
from typing import List, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from enhancer.data.dataset import EnhancerDataset
from enhancer.models.model import Model


class WavenetDecoder(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 5,
        padding: int = 2,
        stride: int = 1,
        dilation: int = 1,
    ):
        super(WavenetDecoder, self).__init__()
        self.decoder = nn.Sequential(
            nn.Conv1d(
                in_channels,
                out_channels,
                kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
            ),
            nn.BatchNorm1d(out_channels),
            nn.LeakyReLU(negative_slope=0.1),
        )

    def forward(self, waveform):

        return self.decoder(waveform)


class WavenetEncoder(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 15,
        padding: int = 7,
        stride: int = 1,
        dilation: int = 1,
    ):
        super(WavenetEncoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv1d(
                in_channels,
                out_channels,
                kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
            ),
            nn.BatchNorm1d(out_channels),
            nn.LeakyReLU(negative_slope=0.1),
        )

    def forward(self, waveform):

        return self.encoder(waveform)


class WaveUnet(Model):
    """
    Wave-U-Net model from https://arxiv.org/pdf/1811.11307.pdf
    parameters:
        num_channels: int, defaults to 1
            number of channels in input audio
        depth : int, defaults to 12
            depth of network
        initial_output_channels: int, defaults to 24
            number of output channels in initial upsampling layer
        sampling_rate: int, defaults to 16KHz
            sampling rate of input audio
        lr : float, defaults to 1e-3
            learning rate used for training
        dataset: EnhancerDataset, optional
            EnhancerDataset object containing train/validation data for training
        duration : float, optional
            chunk duration in seconds
        loss : string or List of strings
            loss function to be used, available ("mse","mae","SI-SDR")
        metric : string or List of strings
            metric function to be used, available ("mse","mae","SI-SDR")
    """

    def __init__(
        self,
        num_channels: int = 1,
        depth: int = 12,
        initial_output_channels: int = 24,
        sampling_rate: int = 16000,
        lr: float = 1e-3,
        dataset: Optional[EnhancerDataset] = None,
        duration: Optional[float] = None,
        loss: Union[str, List] = "mse",
        metric: Union[str, List] = "mse",
    ):
        duration = (
            dataset.duration if isinstance(dataset, EnhancerDataset) else None
        )
        if dataset is not None:
            if sampling_rate != dataset.sampling_rate:
                logging.warn(
                    f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
                )
                sampling_rate = dataset.sampling_rate
        super().__init__(
            num_channels=num_channels,
            sampling_rate=sampling_rate,
            lr=lr,
            dataset=dataset,
            duration=duration,
            loss=loss,
            metric=metric,
        )
        self.save_hyperparameters("depth")
        self.encoders = nn.ModuleList()

@ -84,72 +126,76 @@ class WaveUnet(Model):
        out_channels = initial_output_channels
        for layer in range(depth):

            encoder = WavenetEncoder(num_channels, out_channels)
            self.encoders.append(encoder)

            num_channels = out_channels
            out_channels += initial_output_channels
            if layer == depth - 1:
                decoder = WavenetDecoder(
                    depth * initial_output_channels + num_channels, num_channels
                )
            else:
                decoder = WavenetDecoder(
                    num_channels + out_channels, num_channels
                )

            self.decoders.insert(0, decoder)

        bottleneck_dim = depth * initial_output_channels
        self.bottleneck = nn.Sequential(
            nn.Conv1d(bottleneck_dim, bottleneck_dim, 15, stride=1, padding=7),
            nn.BatchNorm1d(bottleneck_dim),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
        )
        self.final = nn.Sequential(
            nn.Conv1d(1 + initial_output_channels, 1, kernel_size=1, stride=1),
            nn.Tanh(),
        )

    def forward(self, waveform):
        if waveform.dim() == 2:
            waveform = waveform.unsqueeze(1)

        if waveform.size(1) != 1:
            raise TypeError(
                f"Wave-U-Net can only process mono channel audio, input has {waveform.size(1)} channels"
            )

        encoder_outputs = []
        out = waveform
        for encoder in self.encoders:
            out = encoder(out)
            encoder_outputs.insert(0, out)
            out = out[:, :, ::2]

        out = self.bottleneck(out)

        for layer, decoder in enumerate(self.decoders):
            out = F.interpolate(out, scale_factor=2, mode="linear")
            out = self.fix_last_dim(out, encoder_outputs[layer])
            out = torch.cat([out, encoder_outputs[layer]], dim=1)
            out = decoder(out)

        out = torch.cat([out, waveform], dim=1)
        out = self.final(out)
        return out

    def fix_last_dim(self, x, target):
        """
        centre crop along last dimension
        """

        assert (
            x.shape[-1] >= target.shape[-1]
        ), "input dimension cannot be smaller than target dimension"
        if x.shape[-1] == target.shape[-1]:
            return x

        diff = x.shape[-1] - target.shape[-1]
        if diff % 2 != 0:
            x = F.pad(x, (0, 1))
            diff += 1

        crop = diff // 2
        return x[:, :, crop:-crop]

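Standalone sketch of the centre crop performed by fix_last_dim above; the tensor shapes are illustrative.

import torch
import torch.nn.functional as F

x = torch.rand(1, 24, 1005)       # upsampled decoder activation
target = torch.rand(1, 24, 1000)  # matching skip connection

diff = x.shape[-1] - target.shape[-1]
if diff % 2 != 0:        # pad one frame so the crop can be symmetric
    x = F.pad(x, (0, 1))
    diff += 1
crop = diff // 2
x = x[:, :, crop:-crop]
assert x.shape[-1] == target.shape[-1]
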
@ -1,3 +1,3 @@
from enhancer.utils.config import Files
from enhancer.utils.io import Audio
from enhancer.utils.utils import check_files

@ -1,10 +1,9 @@
from dataclasses import dataclass


@dataclass
class Files:
    train_clean: str
    train_noisy: str
    test_clean: str
    test_noisy: str

@ -1,41 +1,67 @@
import os
from pathlib import Path
from typing import Optional, Union

import librosa
import numpy as np
import torch
import torchaudio


class Audio:
    """
    Audio utils
    parameters:
        sampling_rate : int, defaults to 16KHz
            audio sampling rate
        mono: bool, defaults to True
        return_tensor: bool, defaults to True
            returns torch tensor type if set to True else numpy ndarray
    """

    def __init__(
        self, sampling_rate: int = 16000, mono: bool = True, return_tensor=True
    ) -> None:

        self.sampling_rate = sampling_rate
        self.mono = mono
        self.return_tensor = return_tensor

    def __call__(
        self,
        audio: Union[Path, np.ndarray, torch.Tensor],
        sampling_rate: Optional[int] = None,
        offset: Optional[float] = None,
        duration: Optional[float] = None,
    ):
        """
        read and process input audio
        parameters:
            audio: Path to audio file or numpy array or torch tensor
                single input audio
            sampling_rate : int, optional
                sampling rate of the audio input
            offset: float, optional
                offset from which the audio must be read, reads from beginning if unused.
            duration: float (seconds), optional
                read duration, reads full audio starting from offset if not used
        """
        if isinstance(audio, str):
            if os.path.exists(audio):
                audio, sampling_rate = librosa.load(
                    audio,
                    sr=sampling_rate,
                    mono=False,
                    offset=offset,
                    duration=duration,
                )
                if len(audio.shape) == 1:
                    audio = audio.reshape(1, -1)
            else:
                raise FileNotFoundError(f"File {audio} does not exist")
        elif isinstance(audio, np.ndarray):
            if len(audio.shape) == 1:
                audio = audio.reshape(1, -1)
        else:
            raise ValueError("audio should be either filepath or numpy ndarray")

@ -43,40 +69,60 @@ class Audio:
        audio = self.convert_mono(audio)

        if sampling_rate:
            audio = self.__class__.resample_audio(
                audio, self.sampling_rate, sampling_rate
            )
        if self.return_tensor:
            return torch.tensor(audio)
        else:
            return audio

    @staticmethod
    def convert_mono(audio: Union[np.ndarray, torch.Tensor]):
        """
        convert input audio into mono (1)
        parameters:
            audio: np.ndarray or torch.Tensor
        """
        if len(audio.shape) > 2:
            assert (
                audio.shape[0] == 1
            ), "convert mono only accepts single waveform"
            audio = audio.reshape(audio.shape[1], audio.shape[2])

        assert (
            audio.shape[1] > audio.shape[0]
        ), f"expected input format (num_channels,num_samples) got {audio.shape}"
        num_channels, num_samples = audio.shape
        if num_channels > 1:
            return audio.mean(axis=0).reshape(1, num_samples)

        return audio

    @staticmethod
    def resample_audio(
        audio: Union[np.ndarray, torch.Tensor], sr: int, target_sr: int
    ):
        """
        resample audio to desired sampling rate
        parameters:
            audio : numpy array or torch tensor
                audio waveform
            sr : int
                current sampling rate
            target_sr : int
                target sampling rate
        """
        if sr != target_sr:
            if isinstance(audio, np.ndarray):
                audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sr)
            elif isinstance(audio, torch.Tensor):
                audio = torchaudio.functional.resample(
                    audio, orig_freq=sr, new_freq=target_sr
                )
            else:
                raise ValueError(
                    "Input should be either numpy array or torch tensor"
                )

        return audio

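A short usage sketch of the Audio helper above on synthetic data; shapes and rates are illustrative.

import numpy as np

from enhancer.utils.io import Audio

audio = Audio(sampling_rate=16000, mono=True, return_tensor=True)
stereo = np.random.rand(2, 32000)        # (num_channels, num_samples)
mono = audio(stereo)                     # downmixed to a (1, 32000) tensor
halved = Audio.resample_audio(mono, sr=16000, target_sr=8000)  # (1, 16000)
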
@ -1,19 +1,19 @@
import os
import random

import torch


def create_unique_rng(epoch: int):
    """create unique random number generator for each (worker_id,epoch) combination"""

    rng = random.Random()

    global_seed = int(os.environ.get("PL_GLOBAL_SEED", "0"))
    global_rank = int(os.environ.get("GLOBAL_RANK", "0"))
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    node_rank = int(os.environ.get("NODE_RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "0"))

    worker_info = torch.utils.data.get_worker_info()
    if worker_info is not None:

@ -24,17 +24,13 @@ def create_unique_rng(epoch:int):
        worker_id = 0

    seed = (
        global_seed
        + worker_id
        + local_rank * num_workers
        + node_rank * num_workers * global_rank
        + epoch * num_workers * world_size
    )

    rng.seed(seed)

    return rng

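A hedged sketch of how the generator above behaves, assuming at least one dataloader worker; note the epoch term is scaled by num_workers * WORLD_SIZE, so WORLD_SIZE must be set for the seed to vary across epochs. The module path is assumed from the import style used elsewhere in this commit.

import os

from enhancer.utils.random import create_unique_rng  # assumed module path

os.environ["WORLD_SIZE"] = "1"  # make the epoch term non-zero
rng_a = create_unique_rng(epoch=0)
rng_b = create_unique_rng(epoch=1)
print(rng_a.random() != rng_b.random())  # expected: True
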
@ -1,19 +1,26 @@
import os
from typing import Optional

from enhancer.utils.config import Files


def check_files(root_dir: str, files: Files):

    path_variables = [
        member_var
        for member_var in dir(files)
        if not member_var.startswith("__")
    ]
    for variable in path_variables:
        path = getattr(files, variable)
        if not os.path.isdir(os.path.join(root_dir, path)):
            raise ValueError(f"Invalid {path}, is not a directory")

    return files, root_dir


def merge_dict(default_dict: dict, custom: Optional[dict] = None):

    params = dict(default_dict)
    if custom:
        params.update(custom)

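Quick sketch of merge_dict above: custom keys override the defaults while untouched defaults survive. This assumes the function returns params, which the hunk cuts off before showing.

from enhancer.utils.utils import merge_dict

merged = merge_dict({"num_layers": 2, "bidirectional": True}, {"num_layers": 4})
assert merged == {"num_layers": 4, "bidirectional": True}
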
@ -5,4 +5,4 @@ dependencies:
  - python=3.8
  - pip:
      - -r requirements.txt
      - --find-links https://download.pytorch.org/whl/cu113/torch_stable.html

@ -33,7 +33,7 @@ mkdir temp
pwd

#python transcriber/tasks/embeddings/timit.py --directory /scratch/$USER/TIMIT/data/lisa/data/timit/raw/TIMIT/TRAIN --output ./data/train
#python transcriber/tasks/embeddings/timit.py --directory /scratch/$USER/TIMIT/data/lisa/data/timit/raw/TIMIT/TEST --output ./data/test

echo "Start Training..."
python cli/train.py

@ -0,0 +1,15 @@
[tool.black]
line-length = 80
target-version = ['py38']
exclude = '''

(
  /(
      \.eggs         # exclude a few common directories in the
    | \.git          # root of the project
    | \.mypy_cache
    | \.tox
    | \.venv
  )/
)
'''

@ -1,15 +1,16 @@
black>=22.8.0
boto3>=1.24.86
flake8>=5.0.4
huggingface-hub>=0.10.0
hydra-core>=1.2.0
joblib>=1.2.0
librosa>=0.9.2
mlflow>=1.29.0
numpy>=1.23.3
protobuf>=3.19.6
pytorch-lightning>=1.7.7
scikit-learn>=1.1.2
scipy>=1.9.1
torch>=1.12.1
torchaudio>=0.12.1
tqdm>=4.64.1

setup.sh
@ -10,4 +10,4 @@ conda env create -f environment.yml || conda env update -f environment.yml
source activate enhancer

echo "copying files"
# cp /scratch/$USER/TIMIT/.* /deep-transcriber

@ -1,31 +1,32 @@
import pytest
import torch

from enhancer.loss import mean_absolute_error, mean_squared_error

loss_functions = [mean_absolute_error(), mean_squared_error()]


def check_loss_shapes_compatibility(loss_fun):

    batch_size = 4
    shape = (1, 1000)
    loss_fun(torch.rand(batch_size, *shape), torch.rand(batch_size, *shape))

    with pytest.raises(TypeError):
        loss_fun(torch.rand(4, *shape), torch.rand(6, *shape))


@pytest.mark.parametrize("loss", loss_functions)
def test_loss_input_shapes(loss):
    check_loss_shapes_compatibility(loss)


@pytest.mark.parametrize("loss", loss_functions)
def test_loss_output_type(loss):

    batch_size = 4
    prediction, target = torch.rand(batch_size, 1, 1000), torch.rand(
        batch_size, 1, 1000
    )
    loss_value = loss(prediction, target)
    assert isinstance(loss_value.item(), float)

@ -1,46 +1,43 @@
import pytest
import torch

from enhancer.data.dataset import EnhancerDataset
from enhancer.models import Demucs
from enhancer.utils.config import Files


@pytest.fixture
def vctk_dataset():
    root_dir = "tests/data/vctk"
    files = Files(
        train_clean="clean_testset_wav",
        train_noisy="noisy_testset_wav",
        test_clean="clean_testset_wav",
        test_noisy="noisy_testset_wav",
    )
    dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files)
    return dataset


@pytest.mark.parametrize("batch_size,samples", [(1, 1000)])
def test_forward(batch_size, samples):
    model = Demucs()
    model.eval()

    data = torch.rand(batch_size, 1, samples, requires_grad=False)
    with torch.no_grad():
        _ = model(data)

    data = torch.rand(batch_size, 2, samples, requires_grad=False)
    with torch.no_grad():
        with pytest.raises(TypeError):
            _ = model(data)


@pytest.mark.parametrize(
    "dataset,channels,loss",
    [(pytest.lazy_fixture("vctk_dataset"), 1, ["mae", "mse"])],
)
def test_demucs_init(dataset, channels, loss):
    with torch.no_grad():
        _ = Demucs(num_channels=channels, dataset=dataset, loss=loss)

@ -1,46 +1,43 @@
import pytest
import torch

from enhancer.data.dataset import EnhancerDataset
from enhancer.models import WaveUnet
from enhancer.utils.config import Files


@pytest.fixture
def vctk_dataset():
    root_dir = "tests/data/vctk"
    files = Files(
        train_clean="clean_testset_wav",
        train_noisy="noisy_testset_wav",
        test_clean="clean_testset_wav",
        test_noisy="noisy_testset_wav",
    )
    dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files)
    return dataset


@pytest.mark.parametrize("batch_size,samples", [(1, 1000)])
def test_forward(batch_size, samples):
    model = WaveUnet()
    model.eval()

    data = torch.rand(batch_size, 1, samples, requires_grad=False)
    with torch.no_grad():
        _ = model(data)

    data = torch.rand(batch_size, 2, samples, requires_grad=False)
    with torch.no_grad():
        with pytest.raises(TypeError):
            _ = model(data)


@pytest.mark.parametrize(
    "dataset,channels,loss",
    [(pytest.lazy_fixture("vctk_dataset"), 1, ["mae", "mse"])],
)
def test_waveunet_init(dataset, channels, loss):
    with torch.no_grad():
        _ = WaveUnet(num_channels=channels, dataset=dataset, loss=loss)

@@ -4,22 +4,26 @@ import torch

from enhancer.inference import Inference


@pytest.mark.parametrize(
    "audio",
    ["tests/data/vctk/clean_testset_wav/p257_166.wav", torch.rand(1, 2, 48000)],
)
def test_read_input(audio):

    read_audio = Inference.read_input(audio, 48000, 16000)
    assert isinstance(read_audio, torch.Tensor)
    assert read_audio.shape[0] == 1


def test_batchify():
    rand = torch.rand(1, 1000)
    batched_rand = Inference.batchify(rand, window_size=100, step_size=100)
    assert batched_rand.shape[0] == 12


def test_aggregate():
    rand = torch.rand(12, 1, 100)
    # (sic: "aggreagate" is the method name as defined in enhancer.inference)
    agg_rand = Inference.aggreagate(
        data=rand, window_size=100, total_frames=1000, step_size=100
    )
    assert agg_rand.shape[-1] == 1000
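The frame count asserted in test_batchify follows standard sliding-window arithmetic. A small sketch of how 1000 samples can yield 12 windows of size 100 at step 100, assuming Inference.batchify pads one full window on each side before unfolding (an assumption about the implementation, not something this diff confirms):

# Sliding-window frame count, assuming symmetric padding of one full
# window on each side (illustrative -- the real padding may differ).
window_size, step_size, total = 100, 100, 1000
padded = total + 2 * window_size                      # 1200 samples
num_frames = (padded - window_size) // step_size + 1  # 12 frames
assert num_frames == 12                               # matches the test
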
@@ -1,46 +1,50 @@

import numpy as np
import pytest
import torch

from enhancer.data.fileprocessor import Fileprocessor
from enhancer.utils.io import Audio


def test_io_channel():

    input_audio = np.random.rand(2, 32000)
    audio = Audio(mono=True, return_tensor=False)
    output_audio = audio(input_audio)
    assert output_audio.shape[0] == 1


def test_io_resampling():

    input_audio = np.random.rand(1, 32000)
    resampled_audio = Audio.resample_audio(input_audio, 16000, 8000)

    input_audio = torch.rand(1, 32000)
    resampled_audio_pt = Audio.resample_audio(input_audio, 16000, 8000)

    # 32000 samples taken from 16 kHz down to 8 kHz leaves 16000 samples.
    assert resampled_audio.shape[1] == resampled_audio_pt.size(1) == 16000
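The length assertion in test_io_resampling is plain ratio arithmetic: resampling scales the sample count by target_sr / source_sr. For instance:

# Expected length after resampling: length * target_sr / source_sr.
input_length, source_sr, target_sr = 32000, 16000, 8000
expected = input_length * target_sr // source_sr
assert expected == 16000
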
def test_fileprocessor_vctk():

    fp = Fileprocessor.from_name(
        "vctk",
        "tests/data/vctk/clean_testset_wav",
        "tests/data/vctk/noisy_testset_wav",
        48000,
    )
    matching_dict = fp.prepare_matching_dict()
    assert len(matching_dict) == 2


@pytest.mark.parametrize("dataset_name", ["vctk", "dns-2020"])
def test_fileprocessor_names(dataset_name):
    fp = Fileprocessor.from_name(dataset_name, "clean_dir", "noisy_dir", 16000)
    assert hasattr(fp.matching_function, "__call__")


def test_fileprocessor_invalidname():
    with pytest.raises(ValueError):
        _ = Fileprocessor.from_name(
            "undefined", "clean_dir", "noisy_dir", 16000
        ).prepare_matching_dict()
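Taken together, the three Fileprocessor tests describe a small factory-by-name API: known dataset names yield a processor whose matching function is callable, while an unknown name constructs fine and only fails once prepare_matching_dict is invoked. A minimal sketch of that pattern under those assumptions (the class, names, and behaviour below are illustrative, not Enhancer's actual implementation):

def _match_pairs(clean_dir, noisy_dir):
    # A real matcher would walk both directories and pair files by id.
    return {}


_MATCHERS = {"vctk": _match_pairs, "dns-2020": _match_pairs}


class DemoProcessor:
    def __init__(self, clean_dir, noisy_dir, matching_function):
        self.clean_dir = clean_dir
        self.noisy_dir = noisy_dir
        self.matching_function = matching_function

    @classmethod
    def from_name(cls, name, clean_dir, noisy_dir, sampling_rate):
        # Unknown names still construct an object; the error surfaces
        # later, mirroring the behaviour the tests above assert.
        return cls(clean_dir, noisy_dir, _MATCHERS.get(name.lower()))

    def prepare_matching_dict(self):
        if self.matching_function is None:
            raise ValueError("no matching function registered for dataset")
        return self.matching_function(self.clean_dir, self.noisy_dir)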