Merge branch 'dev' of https://github.com/shahules786/enhancer into dev-hawk
This commit is contained in: commit a064151e2e
@@ -0,0 +1,9 @@
+[flake8]
+per-file-ignores = __init__.py:F401
+ignore = E203, E266, E501, W503
+# line length is intentionally set to 80 here because black uses Bugbear
+# See https://github.com/psf/black/blob/master/README.md#line-length for more details
+max-line-length = 80
+max-complexity = 18
+select = B,C,E,F,W,T4,B9
+exclude = tools/kaldi_decoder
.pre-commit-config.yaml

@@ -0,0 +1,43 @@
+repos:
+  # # Clean Notebooks
+  # - repo: https://github.com/kynan/nbstripout
+  #   rev: master
+  #   hooks:
+  #     - id: nbstripout
+
+  # Format Code
+  - repo: https://github.com/ambv/black
+    rev: 22.8.0
+    hooks:
+      - id: black
+
+  # Sort imports
+  - repo: https://github.com/PyCQA/isort
+    rev: 5.10.1
+    hooks:
+      - id: isort
+        args: ["--profile", "black"]
+
+  - repo: https://gitlab.com/pycqa/flake8
+    rev: 5.0.4
+    hooks:
+      - id: flake8
+        args: ['--ignore=E203,E501,F811,E712,W503']
+
+  # Formatting, Whitespace, etc
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v3.2.0
+    hooks:
+      - id: trailing-whitespace
+      - id: check-added-large-files
+        args: ['--maxkb=1000']
+      - id: check-ast
+      - id: check-json
+      - id: check-merge-conflict
+      - id: check-xml
+      - id: check-yaml
+      - id: debug-statements
+      - id: end-of-file-fixer
+      - id: requirements-txt-fixer
+      - id: mixed-line-ending
+        args: ['--fix=no']
README.md

@@ -1 +1,6 @@
 # enhancer
+Enhancer is a PyTorch-based open-source toolkit for speech enhancement. It is designed to save time for audio researchers. It provides easy-to-use pretrained audio enhancement models and facilitates highly customisable model training. Enhancer provides:
+
+* Various pretrained models nicely integrated with Hugging Face that users can select and use without any hassle.
+* Ability to train and validate your own custom speech enhancement models with just under 10 lines of code!
+* A command line tool that facilitates training of highly customisable speech enhancement models from the terminal itself!
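The "under 10 lines of code" claim maps onto the classes added in this commit. A minimal training sketch, assuming a local VCTK-style layout; the root_dir path is a placeholder and the `Files` keyword names mirror the dataset config added below:

import pytorch_lightning as pl

from enhancer.data.dataset import EnhancerDataset
from enhancer.models import Demucs
from enhancer.utils.config import Files

# Folder names copied from the vctk config in this commit; adjust to taste.
files = Files(
    train_clean="clean_testset_wav",
    train_noisy="noisy_testset_wav",
    test_clean="clean_testset_wav",
    test_noisy="noisy_testset_wav",
)
dataset = EnhancerDataset(
    name="vctk", root_dir="/path/to/vctk", files=files, sampling_rate=16000
)
model = Demucs(dataset=dataset, loss="mse", metric="mae")
pl.Trainer(max_epochs=1).fit(model)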
cli/train.py
@@ -1,67 +0,0 @@
-from genericpath import isfile
-import os
-from types import MethodType
-import hydra
-from hydra.utils import instantiate
-from omegaconf import DictConfig
-from torch.optim.lr_scheduler import ReduceLROnPlateau
-from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
-from pytorch_lightning.loggers import MLFlowLogger
-
-os.environ["HYDRA_FULL_ERROR"] = "1"
-JOB_ID = os.environ.get("SLURM_JOBID","0")
-
-@hydra.main(config_path="train_config",config_name="config")
-def main(config: DictConfig):
-
-    callbacks = []
-    logger = MLFlowLogger(experiment_name=config.mlflow.experiment_name,
-                run_name=config.mlflow.run_name, tags={"JOB_ID":JOB_ID})
-
-    parameters = config.hyperparameters
-
-    dataset = instantiate(config.dataset)
-    model = instantiate(config.model,dataset=dataset,lr=parameters.get("lr"),
-                loss=parameters.get("loss"), metric = parameters.get("metric"))
-
-    direction = model.valid_monitor
-    checkpoint = ModelCheckpoint(
-        dirpath="./model",filename=f"model_{JOB_ID}",monitor="val_loss",verbose=True,
-        mode=direction,every_n_epochs=1
-    )
-    callbacks.append(checkpoint)
-    early_stopping = EarlyStopping(
-        monitor="val_loss",
-        mode=direction,
-        min_delta=0.0,
-        patience=parameters.get("EarlyStopping_patience",10),
-        strict=True,
-        verbose=False,
-    )
-    callbacks.append(early_stopping)
-
-    def configure_optimizer(self):
-        optimizer = instantiate(config.optimizer,lr=parameters.get("lr"),parameters=self.parameters())
-        scheduler = ReduceLROnPlateau(
-            optimizer=optimizer,
-            mode=direction,
-            factor=parameters.get("ReduceLr_factor",0.1),
-            verbose=True,
-            min_lr=parameters.get("min_lr",1e-6),
-            patience=parameters.get("ReduceLr_patience",3)
-        )
-        return {"optimizer":optimizer, "lr_scheduler":scheduler}
-
-    model.configure_parameters = MethodType(configure_optimizer,model)
-
-    trainer = instantiate(config.trainer,logger=logger,callbacks=callbacks)
-    trainer.fit(model)
-
-    saved_location = os.path.join(trainer.default_root_dir,"model",f"model_{JOB_ID}.ckpt")
-    if os.path.isfile(saved_location):
-        logger.experiment.log_artifact(logger.run_id,saved_location)
-
-
-if __name__=="__main__":
-    main()
@@ -0,0 +1,85 @@
+import os
+from types import MethodType
+
+import hydra
+from hydra.utils import instantiate
+from omegaconf import DictConfig
+from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
+from pytorch_lightning.loggers import MLFlowLogger
+from torch.optim.lr_scheduler import ReduceLROnPlateau
+
+os.environ["HYDRA_FULL_ERROR"] = "1"
+JOB_ID = os.environ.get("SLURM_JOBID", "0")
+
+
+@hydra.main(config_path="train_config", config_name="config")
+def main(config: DictConfig):
+
+    callbacks = []
+    logger = MLFlowLogger(
+        experiment_name=config.mlflow.experiment_name,
+        run_name=config.mlflow.run_name,
+        tags={"JOB_ID": JOB_ID},
+    )
+
+    parameters = config.hyperparameters
+
+    dataset = instantiate(config.dataset)
+    model = instantiate(
+        config.model,
+        dataset=dataset,
+        lr=parameters.get("lr"),
+        loss=parameters.get("loss"),
+        metric=parameters.get("metric"),
+    )
+
+    direction = model.valid_monitor
+    checkpoint = ModelCheckpoint(
+        dirpath="./model",
+        filename=f"model_{JOB_ID}",
+        monitor="val_loss",
+        verbose=True,
+        mode=direction,
+        every_n_epochs=1,
+    )
+    callbacks.append(checkpoint)
+    early_stopping = EarlyStopping(
+        monitor="val_loss",
+        mode=direction,
+        min_delta=0.0,
+        patience=parameters.get("EarlyStopping_patience", 10),
+        strict=True,
+        verbose=False,
+    )
+    callbacks.append(early_stopping)
+
+    def configure_optimizer(self):
+        optimizer = instantiate(
+            config.optimizer,
+            lr=parameters.get("lr"),
+            parameters=self.parameters(),
+        )
+        scheduler = ReduceLROnPlateau(
+            optimizer=optimizer,
+            mode=direction,
+            factor=parameters.get("ReduceLr_factor", 0.1),
+            verbose=True,
+            min_lr=parameters.get("min_lr", 1e-6),
+            patience=parameters.get("ReduceLr_patience", 3),
+        )
+        return {"optimizer": optimizer, "lr_scheduler": scheduler}
+
+    model.configure_parameters = MethodType(configure_optimizer, model)
+
+    trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks)
+    trainer.fit(model)
+
+    saved_location = os.path.join(
+        trainer.default_root_dir, "model", f"model_{JOB_ID}.ckpt"
+    )
+    if os.path.isfile(saved_location):
+        logger.experiment.log_artifact(logger.run_id, saved_location)
+
+
+if __name__ == "__main__":
+    main()
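The `MethodType` line in the script above is what lets a Hydra-configured optimizer factory be swapped onto a single model instance at runtime (note that Lightning looks for `configure_optimizers`, so the `configure_parameters` attribute name it assigns appears to be a typo carried over from the old script). A standalone sketch of the binding pattern, with a hypothetical class:

from types import MethodType


class Model:
    def configure_optimizers(self):
        return "default optimizer"


def configure_optimizers(self):
    # `self` is the bound instance, so its attributes are visible here.
    return "optimizer from config"


model = Model()
# Rebind the free function as a method on this one instance only.
model.configure_optimizers = MethodType(configure_optimizers, model)
print(model.configure_optimizers())  # -> "optimizer from config"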
@@ -10,4 +10,3 @@ files:
   test_clean : clean_test_wav
   train_noisy : clean_test_wav
   test_noisy : clean_test_wav
-
@@ -10,6 +10,3 @@ files:
   test_clean : clean_testset_wav
   train_noisy : noisy_trainset_28spk_wav
   test_noisy : noisy_testset_wav
-
-
-
@@ -0,0 +1,13 @@
+_target_: enhancer.data.dataset.EnhancerDataset
+name : vctk
+root_dir : /Users/shahules/Myprojects/enhancer/datasets/vctk
+duration : 1.0
+sampling_rate: 16000
+batch_size: 64
+num_workers : 0
+
+files:
+  train_clean : clean_testset_wav
+  test_clean : clean_testset_wav
+  train_noisy : noisy_testset_wav
+  test_noisy : noisy_testset_wav
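The `_target_` key in the config above is consumed by `hydra.utils.instantiate` in `cli/train.py`. A sketch of the mechanism with an inline OmegaConf config (shortened; the real config also needs `root_dir` and `files`, so the call is left commented):

from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "_target_": "enhancer.data.dataset.EnhancerDataset",
        "name": "vctk",
        # root_dir, files, duration, sampling_rate, ... as in the YAML above
    }
)
# dataset = instantiate(cfg)  # would build EnhancerDataset(name="vctk", ...)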
@@ -5,4 +5,3 @@ ReduceLr_patience : 5
 ReduceLr_factor : 0.1
 min_lr : 0.000001
 EarlyStopping_factor : 10
-
@@ -14,5 +14,3 @@ encoder_decoder:
 lstm:
   bidirectional: False
   num_layers: 2
-
-
enhancer/data/__init__.py

@@ -0,0 +1 @@
+from enhancer.data.dataset import EnhancerDataset
enhancer/data/dataset.py

@@ -1,20 +1,21 @@
-import multiprocessing
 import math
+import multiprocessing
 import os
-import pytorch_lightning as pl
-from torch.utils.data import IterableDataset, DataLoader, Dataset
-import torch.nn.functional as F
 from typing import Optional
 
+import pytorch_lightning as pl
+import torch.nn.functional as F
+from torch.utils.data import DataLoader, Dataset, IterableDataset
+
 from enhancer.data.fileprocessor import Fileprocessor
-from enhancer.utils.random import create_unique_rng
-from enhancer.utils.io import Audio
 from enhancer.utils import check_files
 from enhancer.utils.config import Files
+from enhancer.utils.io import Audio
+from enhancer.utils.random import create_unique_rng
 
 
 class TrainDataset(IterableDataset):
-
-    def __init__(self,dataset):
+    def __init__(self, dataset):
         self.dataset = dataset
 
     def __iter__(self):
@@ -23,72 +24,86 @@ class TrainDataset(IterableDataset):
     def __len__(self):
         return self.dataset.train__len__()
 
-class ValidDataset(Dataset):
 
-    def __init__(self,dataset):
+class ValidDataset(Dataset):
+    def __init__(self, dataset):
         self.dataset = dataset
 
-    def __getitem__(self,idx):
+    def __getitem__(self, idx):
         return self.dataset.val__getitem__(idx)
 
     def __len__(self):
         return self.dataset.val__len__()
 
-class TaskDataset(pl.LightningDataModule):
 
+class TaskDataset(pl.LightningDataModule):
    def __init__(
        self,
-        name:str,
-        root_dir:str,
-        files:Files,
-        duration:float=1.0,
-        sampling_rate:int=48000,
-        matching_function = None,
+        name: str,
+        root_dir: str,
+        files: Files,
+        duration: float = 1.0,
+        sampling_rate: int = 48000,
+        matching_function=None,
        batch_size=32,
-        num_workers:Optional[int]=None):
+        num_workers: Optional[int] = None,
+    ):
        super().__init__()
 
        self.name = name
-        self.files,self.root_dir = check_files(root_dir,files)
+        self.files, self.root_dir = check_files(root_dir, files)
        self.duration = duration
        self.sampling_rate = sampling_rate
        self.batch_size = batch_size
        self.matching_function = matching_function
        self._validation = []
        if num_workers is None:
-            num_workers = multiprocessing.cpu_count()//2
+            num_workers = multiprocessing.cpu_count() // 2
        self.num_workers = num_workers
 
    def setup(self, stage: Optional[str] = None):
 
-        if stage in ("fit",None):
+        if stage in ("fit", None):
 
-            train_clean = os.path.join(self.root_dir,self.files.train_clean)
-            train_noisy = os.path.join(self.root_dir,self.files.train_noisy)
-            fp = Fileprocessor.from_name(self.name,train_clean,
-                        train_noisy, self.matching_function)
+            train_clean = os.path.join(self.root_dir, self.files.train_clean)
+            train_noisy = os.path.join(self.root_dir, self.files.train_noisy)
+            fp = Fileprocessor.from_name(
+                self.name, train_clean, train_noisy, self.matching_function
+            )
            self.train_data = fp.prepare_matching_dict()
 
-            val_clean = os.path.join(self.root_dir,self.files.test_clean)
-            val_noisy = os.path.join(self.root_dir,self.files.test_noisy)
-            fp = Fileprocessor.from_name(self.name,val_clean,
-                        val_noisy, self.matching_function)
+            val_clean = os.path.join(self.root_dir, self.files.test_clean)
+            val_noisy = os.path.join(self.root_dir, self.files.test_noisy)
+            fp = Fileprocessor.from_name(
+                self.name, val_clean, val_noisy, self.matching_function
+            )
            val_data = fp.prepare_matching_dict()
 
            for item in val_data:
-                clean,noisy,total_dur = item.values()
+                clean, noisy, total_dur = item.values()
                if total_dur < self.duration:
                    continue
-                num_segments = round(total_dur/self.duration)
+                num_segments = round(total_dur / self.duration)
                for index in range(num_segments):
                    start_time = index * self.duration
-                    self._validation.append(({"clean":clean,"noisy":noisy},
-                                start_time))
+                    self._validation.append(
+                        ({"clean": clean, "noisy": noisy}, start_time)
+                    )
 
    def train_dataloader(self):
-        return DataLoader(TrainDataset(self), batch_size = self.batch_size,num_workers=self.num_workers)
+        return DataLoader(
+            TrainDataset(self),
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+        )
 
    def val_dataloader(self):
-        return DataLoader(ValidDataset(self), batch_size = self.batch_size,num_workers=self.num_workers)
+        return DataLoader(
+            ValidDataset(self),
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+        )
 
 
 class EnhancerDataset(TaskDataset):
    """
@@ -100,7 +115,7 @@ class EnhancerDataset(TaskDataset):
        root directory of the dataset containing clean/noisy folders
    files : Files
        dataclass containing train_clean, train_noisy, test_clean, test_noisy
-        folder names (refer cli/train_config/dataset)
+        folder names (refer enhancer.utils.Files dataclass)
    duration : float
        expected audio duration of single audio sample for training
    sampling_rate : int
@@ -119,14 +134,15 @@ class EnhancerDataset(TaskDataset):
 
    def __init__(
        self,
-        name:str,
-        root_dir:str,
-        files:Files,
+        name: str,
+        root_dir: str,
+        files: Files,
        duration=1.0,
        sampling_rate=48000,
        matching_function=None,
        batch_size=32,
-        num_workers:Optional[int]=None):
+        num_workers: Optional[int] = None,
+    ):
 
        super().__init__(
            name=name,
@@ -134,18 +150,17 @@ class EnhancerDataset(TaskDataset):
            files=files,
            sampling_rate=sampling_rate,
            duration=duration,
-            matching_function = matching_function,
+            matching_function=matching_function,
            batch_size=batch_size,
-            num_workers = num_workers,
-
+            num_workers=num_workers,
        )
 
        self.sampling_rate = sampling_rate
        self.files = files
-        self.duration = max(1.0,duration)
-        self.audio = Audio(self.sampling_rate,mono=True,return_tensor=True)
+        self.duration = max(1.0, duration)
+        self.audio = Audio(self.sampling_rate, mono=True, return_tensor=True)
 
-    def setup(self, stage:Optional[str]=None):
+    def setup(self, stage: Optional[str] = None):
 
        super().setup(stage=stage)
 
@@ -155,30 +170,51 @@ class EnhancerDataset(TaskDataset):
 
        while True:
 
-            file_dict,*_ = rng.choices(self.train_data,k=1,
-                    weights=[file["duration"] for file in self.train_data])
-            file_duration = file_dict['duration']
-            start_time = round(rng.uniform(0,file_duration- self.duration),2)
-            data = self.prepare_segment(file_dict,start_time)
+            file_dict, *_ = rng.choices(
+                self.train_data,
+                k=1,
+                weights=[file["duration"] for file in self.train_data],
+            )
+            file_duration = file_dict["duration"]
+            start_time = round(rng.uniform(0, file_duration - self.duration), 2)
+            data = self.prepare_segment(file_dict, start_time)
            yield data
 
-    def val__getitem__(self,idx):
+    def val__getitem__(self, idx):
        return self.prepare_segment(*self._validation[idx])
 
-    def prepare_segment(self,file_dict:dict, start_time:float):
+    def prepare_segment(self, file_dict: dict, start_time: float):
 
-        clean_segment = self.audio(file_dict["clean"],
-                offset=start_time,duration=self.duration)
-        noisy_segment = self.audio(file_dict["noisy"],
-                offset=start_time,duration=self.duration)
-        clean_segment = F.pad(clean_segment,(0,int(self.duration*self.sampling_rate-clean_segment.shape[-1])))
-        noisy_segment = F.pad(noisy_segment,(0,int(self.duration*self.sampling_rate-noisy_segment.shape[-1])))
-        return {"clean": clean_segment,"noisy":noisy_segment}
+        clean_segment = self.audio(
+            file_dict["clean"], offset=start_time, duration=self.duration
+        )
+        noisy_segment = self.audio(
+            file_dict["noisy"], offset=start_time, duration=self.duration
+        )
+        clean_segment = F.pad(
+            clean_segment,
+            (
+                0,
+                int(
+                    self.duration * self.sampling_rate - clean_segment.shape[-1]
+                ),
+            ),
+        )
+        noisy_segment = F.pad(
+            noisy_segment,
+            (
+                0,
+                int(
+                    self.duration * self.sampling_rate - noisy_segment.shape[-1]
+                ),
+            ),
+        )
+        return {"clean": clean_segment, "noisy": noisy_segment}
 
    def train__len__(self):
-        return math.ceil(sum([file["duration"] for file in self.train_data])/self.duration)
+        return math.ceil(
+            sum([file["duration"] for file in self.train_data]) / self.duration
+        )
 
    def val__len__(self):
        return len(self._validation)
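The `__iter__` body above samples files in proportion to their duration, then draws a random chunk start. A self-contained sketch of that sampling rule, with made-up file entries and a fixed seed standing in for `create_unique_rng`:

import random

files = [
    {"path": "long.wav", "duration": 30.0},
    {"path": "short.wav", "duration": 3.0},
]
chunk = 1.0
rng = random.Random(42)  # stands in for create_unique_rng(epoch)

file_dict, *_ = rng.choices(files, k=1, weights=[f["duration"] for f in files])
start_time = round(rng.uniform(0, file_dict["duration"] - chunk), 2)
# long.wav is picked ~10x as often as short.wav, matching its share of audio.
print(file_dict["path"], start_time)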
enhancer/data/fileprocessor.py

@@ -1,108 +1,118 @@
 import glob
 import os
 from re import S
 
 import numpy as np
 from scipy.io import wavfile
 
-MATCHING_FNS = ("one_to_one","one_to_many")
+MATCHING_FNS = ("one_to_one", "one_to_many")
 
 
 class ProcessorFunctions:
    """
    Preprocessing methods for different types of speech enhancement datasets.
    """
 
    @staticmethod
-    def one_to_one(clean_path,noisy_path):
+    def one_to_one(clean_path, noisy_path):
        """
        One clean audio can have only one noisy audio file
        """
 
        matching_wavfiles = list()
-        clean_filenames = [file.split('/')[-1] for file in glob.glob(os.path.join(clean_path,"*.wav"))]
-        noisy_filenames = [file.split('/')[-1] for file in glob.glob(os.path.join(noisy_path,"*.wav"))]
-        common_filenames = np.intersect1d(noisy_filenames,clean_filenames)
+        clean_filenames = [
+            file.split("/")[-1]
+            for file in glob.glob(os.path.join(clean_path, "*.wav"))
+        ]
+        noisy_filenames = [
+            file.split("/")[-1]
+            for file in glob.glob(os.path.join(noisy_path, "*.wav"))
+        ]
+        common_filenames = np.intersect1d(noisy_filenames, clean_filenames)
 
        for file_name in common_filenames:
 
-            sr_clean, clean_file = wavfile.read(os.path.join(clean_path,file_name))
-            sr_noisy, noisy_file = wavfile.read(os.path.join(noisy_path,file_name))
-            if ((clean_file.shape[-1]==noisy_file.shape[-1]) and
-                (sr_clean==sr_noisy)):
+            sr_clean, clean_file = wavfile.read(
+                os.path.join(clean_path, file_name)
+            )
+            sr_noisy, noisy_file = wavfile.read(
+                os.path.join(noisy_path, file_name)
+            )
+            if (clean_file.shape[-1] == noisy_file.shape[-1]) and (
+                sr_clean == sr_noisy
+            ):
                matching_wavfiles.append(
-                    {"clean":os.path.join(clean_path,file_name),"noisy":os.path.join(noisy_path,file_name),
-                    "duration":clean_file.shape[-1]/sr_clean}
-                )
+                    {
+                        "clean": os.path.join(clean_path, file_name),
+                        "noisy": os.path.join(noisy_path, file_name),
+                        "duration": clean_file.shape[-1] / sr_clean,
+                    }
+                )
        return matching_wavfiles
 
    @staticmethod
-    def one_to_many(clean_path,noisy_path):
+    def one_to_many(clean_path, noisy_path):
        """
        One clean audio can have multiple noisy audio files
        """
 
        matching_wavfiles = dict()
-        clean_filenames = [file.split('/')[-1] for file in glob.glob(os.path.join(clean_path,"*.wav"))]
+        clean_filenames = [
+            file.split("/")[-1]
+            for file in glob.glob(os.path.join(clean_path, "*.wav"))
+        ]
        for clean_file in clean_filenames:
-            noisy_filenames = glob.glob(os.path.join(noisy_path,f"*_{clean_file}.wav"))
+            noisy_filenames = glob.glob(
+                os.path.join(noisy_path, f"*_{clean_file}.wav")
+            )
            for noisy_file in noisy_filenames:
 
-                sr_clean, clean_file = wavfile.read(os.path.join(clean_path,clean_file))
+                sr_clean, clean_file = wavfile.read(
+                    os.path.join(clean_path, clean_file)
+                )
                sr_noisy, noisy_file = wavfile.read(noisy_file)
-                if ((clean_file.shape[-1]==noisy_file.shape[-1]) and
-                    (sr_clean==sr_noisy)):
+                if (clean_file.shape[-1] == noisy_file.shape[-1]) and (
+                    sr_clean == sr_noisy
+                ):
                    matching_wavfiles.update(
-                        {"clean":os.path.join(clean_path,clean_file),"noisy":noisy_file,
-                        "duration":clean_file.shape[-1]/sr_clean}
-                    )
+                        {
+                            "clean": os.path.join(clean_path, clean_file),
+                            "noisy": noisy_file,
+                            "duration": clean_file.shape[-1] / sr_clean,
+                        }
+                    )
 
        return matching_wavfiles
 
 
 class Fileprocessor:
-
-    def __init__(
-        self,
-        clean_dir,
-        noisy_dir,
-        matching_function = None
-    ):
+    def __init__(self, clean_dir, noisy_dir, matching_function=None):
        self.clean_dir = clean_dir
        self.noisy_dir = noisy_dir
        self.matching_function = matching_function
 
    @classmethod
-    def from_name(cls,
-                name:str,
-                clean_dir,
-                noisy_dir,
-                matching_function=None
-                ):
+    def from_name(cls, name: str, clean_dir, noisy_dir, matching_function=None):
 
        if matching_function is None:
            if name.lower() == "vctk":
-                return cls(clean_dir,noisy_dir, ProcessorFunctions.one_to_one)
+                return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_one)
            elif name.lower() == "dns-2020":
-                return cls(clean_dir,noisy_dir, ProcessorFunctions.one_to_many)
+                return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_many)
        else:
            if matching_function not in MATCHING_FNS:
-                raise ValueError(F"Invalid matching function! Available options are {MATCHING_FNS}")
+                raise ValueError(
+                    f"Invalid matching function! Available options are {MATCHING_FNS}"
+                )
            else:
-                return cls(clean_dir,noisy_dir, getattr(ProcessorFunctions,matching_function))
-
+                return cls(
+                    clean_dir,
+                    noisy_dir,
+                    getattr(ProcessorFunctions, matching_function),
+                )
 
    def prepare_matching_dict(self):
 
        if self.matching_function is None:
            raise ValueError("Not a valid matching function")
 
-        return self.matching_function(self.clean_dir,self.noisy_dir)
+        return self.matching_function(self.clean_dir, self.noisy_dir)
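A hypothetical `Fileprocessor` call, tying the pieces above together (directory paths are placeholders):

from enhancer.data.fileprocessor import Fileprocessor

fp = Fileprocessor.from_name(
    "vctk",  # resolves to ProcessorFunctions.one_to_one
    "/data/vctk/clean_testset_wav",
    "/data/vctk/noisy_testset_wav",
)
matches = fp.prepare_matching_dict()
# one_to_one returns a list of dicts:
# [{"clean": ..., "noisy": ..., "duration": ...}, ...]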
enhancer/inference.py

@@ -1,119 +1,168 @@
-from json import load
 import wave
-from typing import Optional, Union
+from pathlib import Path
+from typing import List, Optional, Union
 
 import numpy as np
-from scipy.signal import get_window
-from scipy.io import wavfile
 import torch
 import torch.nn.functional as F
-from pathlib import Path
 from librosa import load as load_audio
+from scipy.io import wavfile
+from scipy.signal import get_window
 
 from enhancer.utils import Audio
 
 
 class Inference:
    """
    contains methods used for inference.
    """
 
    @staticmethod
    def read_input(audio, sr, model_sr):
        """
        read and verify audio input regardless of the input format.
        arguments:
            audio : audio input
            sr : sampling rate of input audio
            model_sr : sampling rate used for model training.
        """
 
-        if isinstance(audio,(np.ndarray,torch.Tensor)):
+        if isinstance(audio, (np.ndarray, torch.Tensor)):
            assert sr is not None, "Invalid sampling rate!"
            if len(audio.shape) == 1:
-                audio = audio.reshape(1,-1)
+                audio = audio.reshape(1, -1)
 
-        if isinstance(audio,str):
+        if isinstance(audio, str):
            audio = Path(audio)
            if not audio.is_file():
                raise ValueError(f"Input file {audio} does not exist")
            else:
-                audio,sr = load_audio(audio,sr=sr,)
+                audio, sr = load_audio(
+                    audio,
+                    sr=sr,
+                )
            if len(audio.shape) == 1:
-                audio = audio.reshape(1,-1)
+                audio = audio.reshape(1, -1)
        else:
-            assert audio.shape[0] == 1, "Enhance inference only supports single waveform"
+            assert (
+                audio.shape[0] == 1
+            ), "Enhance inference only supports single waveform"
 
-        waveform = Audio.resample_audio(audio,sr=sr,target_sr=model_sr)
+        waveform = Audio.resample_audio(audio, sr=sr, target_sr=model_sr)
        waveform = Audio.convert_mono(waveform)
-        if isinstance(waveform,np.ndarray):
+        if isinstance(waveform, np.ndarray):
            waveform = torch.from_numpy(waveform)
 
        return waveform
 
    @staticmethod
-    def batchify(waveform: torch.Tensor, window_size:int, step_size:Optional[int]=None):
+    def batchify(
+        waveform: torch.Tensor,
+        window_size: int,
+        step_size: Optional[int] = None,
+    ):
        """
-        break input waveform into samples with duration specified.
+        break input waveform into samples with duration specified. (overlap-add)
        arguments:
            waveform : audio waveform
            window_size : window size used for splitting waveform into batches
            step_size : step_size used for splitting waveform into batches
        """
-        assert waveform.ndim == 2, f"Expected input waveform with 2 dimensions (channels,samples), got {waveform.ndim}"
-        _,num_samples = waveform.shape
+        assert (
+            waveform.ndim == 2
+        ), f"Expected input waveform with 2 dimensions (channels,samples), got {waveform.ndim}"
+        _, num_samples = waveform.shape
        waveform = waveform.unsqueeze(-1)
-        step_size = window_size//2 if step_size is None else step_size
+        step_size = window_size // 2 if step_size is None else step_size
 
        if num_samples >= window_size:
-            waveform_batch = F.unfold(waveform[None,...], kernel_size=(window_size,1),
-                            stride=(step_size,1), padding=(window_size,0))
-            waveform_batch = waveform_batch.permute(2,0,1)
-
+            waveform_batch = F.unfold(
+                waveform[None, ...],
+                kernel_size=(window_size, 1),
+                stride=(step_size, 1),
+                padding=(window_size, 0),
+            )
+            waveform_batch = waveform_batch.permute(2, 0, 1)
 
        return waveform_batch
 
    @staticmethod
-    def aggreagate(data:torch.Tensor,window_size:int,total_frames:int,step_size:Optional[int]=None,
-                window="hanning",):
+    def aggreagate(
+        data: torch.Tensor,
+        window_size: int,
+        total_frames: int,
+        step_size: Optional[int] = None,
+        window="hanning",
+    ):
        """
-        takes input as tensor outputs aggregated waveform
+        stitch batched waveform into single waveform. (overlap-add)
        arguments:
            data: batched waveform
            window_size : window_size used to batch waveform
            step_size : step_size used to batch waveform
            total_frames : total number of frames present in original waveform
            window : type of window used for overlap-add mechanism.
        """
-        num_chunks,n_channels,num_frames = data.shape
-        window = get_window(window=window,Nx=data.shape[-1])
+        num_chunks, n_channels, num_frames = data.shape
+        window = get_window(window=window, Nx=data.shape[-1])
        window = torch.from_numpy(window).to(data.device)
        data *= window
-        step_size = window_size//2 if step_size is None else step_size
+        step_size = window_size // 2 if step_size is None else step_size
 
-        data = data.permute(1,2,0)
-        data = F.fold(data,
-                (total_frames,1),
-                kernel_size=(window_size,1),
-                stride=(step_size,1),
-                padding=(window_size,0)).squeeze(-1)
-
-        return data.reshape(1,n_channels,-1)
+        data = data.permute(1, 2, 0)
+        data = F.fold(
+            data,
+            (total_frames, 1),
+            kernel_size=(window_size, 1),
+            stride=(step_size, 1),
+            padding=(window_size, 0),
+        ).squeeze(-1)
+
+        return data.reshape(1, n_channels, -1)
 
    @staticmethod
-    def write_output(waveform:torch.Tensor,filename:Union[str,Path],sr:int):
+    def write_output(
+        waveform: torch.Tensor, filename: Union[str, Path], sr: int
+    ):
        """
        write audio output as wav file
        arguments:
            waveform : audio waveform
            filename : name of the wave file. Output will be written as cleaned_filename.wav
            sr : sampling rate
        """
 
-        if isinstance(filename,str):
+        if isinstance(filename, str):
            filename = Path(filename)
 
-        parent, name = filename.parent, "cleaned_"+filename.name
-        filename = parent/Path(name)
+        parent, name = filename.parent, "cleaned_" + filename.name
+        filename = parent / Path(name)
        if filename.is_file():
            raise FileExistsError(f"file {filename} already exists")
        else:
-            if isinstance(waveform,torch.Tensor):
-                waveform = waveform.detach().cpu().squeeze().numpy()
-            wavfile.write(filename,rate=sr,data=waveform)
+            wavfile.write(filename, rate=sr, data=waveform.detach().cpu())
 
    @staticmethod
-    def prepare_output(waveform:torch.Tensor, model_sampling_rate:int,
-            audio:Union[str,np.ndarray,torch.Tensor], sampling_rate:Optional[int]
-    ):
-        if isinstance(audio,np.ndarray):
+    def prepare_output(
+        waveform: torch.Tensor,
+        model_sampling_rate: int,
+        audio: Union[str, np.ndarray, torch.Tensor],
+        sampling_rate: Optional[int],
+    ):
+        """
+        prepare output audio based on input format
+        arguments:
+            waveform : predicted audio waveform
+            model_sampling_rate : sampling rate used to train the model
+            audio : input audio
+            sampling_rate : input audio sampling rate
+        """
+        if isinstance(audio, np.ndarray):
            waveform = waveform.detach().cpu().numpy()
 
-        if sampling_rate!=None:
-            waveform = Audio.resample_audio(waveform, sr=model_sampling_rate, target_sr=sampling_rate)
+        if sampling_rate is not None:
+            waveform = Audio.resample_audio(
+                waveform, sr=model_sampling_rate, target_sr=sampling_rate
+            )
 
        return waveform
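A shape-only sketch of the `batchify`/`aggreagate` overlap-add round trip (no model in the loop; the numbers are arbitrary):

import torch

from enhancer.inference import Inference

waveform = torch.randn(1, 16000)  # (channels, samples): 1 s at 16 kHz
batch = Inference.batchify(waveform, window_size=1600)  # step defaults to 800
print(batch.shape)  # (num_chunks, 1, window_size)

# After enhancing each chunk, stitch them back with a window function:
out = Inference.aggreagate(batch, window_size=1600, total_frames=16000)
print(out.shape)  # (1, channels, samples including edge padding)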
enhancer/loss.py
@@ -3,62 +3,82 @@ import torch.nn as nn
 
 
 class mean_squared_error(nn.Module):
+    """
+    Mean squared error / L2 loss
+    """
 
-    def __init__(self,reduction="mean"):
+    def __init__(self, reduction="mean"):
        super().__init__()
 
        self.loss_fun = nn.MSELoss(reduction=reduction)
        self.higher_better = False
 
-    def forward(self,prediction:torch.Tensor, target: torch.Tensor):
+    def forward(self, prediction: torch.Tensor, target: torch.Tensor):
 
        if prediction.size() != target.size() or target.ndim < 3:
-            raise TypeError(f"""Inputs must be of the same shape (batch_size,channels,samples)
-                    got {prediction.size()} and {target.size()} instead""")
+            raise TypeError(
+                f"""Inputs must be of the same shape (batch_size,channels,samples)
+                    got {prediction.size()} and {target.size()} instead"""
+            )
 
        return self.loss_fun(prediction, target)
 
-class mean_absolute_error(nn.Module):
 
-    def __init__(self,reduction="mean"):
+class mean_absolute_error(nn.Module):
+    """
+    Mean absolute error / L1 loss
+    """
+
+    def __init__(self, reduction="mean"):
        super().__init__()
 
        self.loss_fun = nn.L1Loss(reduction=reduction)
        self.higher_better = False
 
-    def forward(self, prediction:torch.Tensor, target: torch.Tensor):
+    def forward(self, prediction: torch.Tensor, target: torch.Tensor):
 
        if prediction.size() != target.size() or target.ndim < 3:
-            raise TypeError(f"""Inputs must be of the same shape (batch_size,channels,samples)
-                    got {prediction.size()} and {target.size()} instead""")
+            raise TypeError(
+                f"""Inputs must be of the same shape (batch_size,channels,samples)
+                    got {prediction.size()} and {target.size()} instead"""
+            )
 
        return self.loss_fun(prediction, target)
 
-class Si_SDR(nn.Module):
-
-    def __init__(
-        self,
-        reduction:str="mean"
-    ):
+
+class Si_SDR(nn.Module):
+    """
+    SI-SDR metric, based on "SDR - Half-baked or Well Done?"
+    (https://arxiv.org/pdf/1811.02508.pdf)
+    """
+
+    def __init__(self, reduction: str = "mean"):
        super().__init__()
-        if reduction in ["sum","mean",None]:
+        if reduction in ["sum", "mean", None]:
            self.reduction = reduction
        else:
-            raise TypeError("Invalid reduction, valid options are sum, mean, None")
+            raise TypeError(
+                "Invalid reduction, valid options are sum, mean, None"
+            )
        self.higher_better = False
 
-    def forward(self,prediction:torch.Tensor, target:torch.Tensor):
+    def forward(self, prediction: torch.Tensor, target: torch.Tensor):
 
        if prediction.size() != target.size() or target.ndim < 3:
-            raise TypeError(f"""Inputs must be of the same shape (batch_size,channels,samples)
-                    got {prediction.size()} and {target.size()} instead""")
+            raise TypeError(
+                f"""Inputs must be of the same shape (batch_size,channels,samples)
+                    got {prediction.size()} and {target.size()} instead"""
+            )
 
-        target_energy = torch.sum(target**2,keepdim=True,dim=-1)
-        scaling_factor = torch.sum(prediction*target,keepdim=True,dim=-1) / target_energy
+        target_energy = torch.sum(target**2, keepdim=True, dim=-1)
+        scaling_factor = (
+            torch.sum(prediction * target, keepdim=True, dim=-1) / target_energy
+        )
        target_projection = target * scaling_factor
        noise = prediction - target_projection
-        ratio = torch.sum(target_projection**2,dim=-1) / torch.sum(noise**2,dim=-1)
-        si_sdr = 10*torch.log10(ratio).mean(dim=-1)
+        ratio = torch.sum(target_projection**2, dim=-1) / torch.sum(
+            noise**2, dim=-1
+        )
+        si_sdr = 10 * torch.log10(ratio).mean(dim=-1)
 
        if self.reduction == "sum":
            si_sdr = si_sdr.sum()
@@ -70,31 +90,42 @@ class Si_SDR(nn.Module):
        return si_sdr
 
 
 class Avergeloss(nn.Module):
+    """
+    Combine multiple losses of the same nature, for example ["mse", "mae"].
+    parameters:
+        losses : loss function names to be combined
+    """
 
-    def __init__(self,losses):
+    def __init__(self, losses):
        super().__init__()
 
        self.valid_losses = nn.ModuleList()
 
-        direction = [getattr(LOSS_MAP[loss](),"higher_better") for loss in losses]
+        direction = [
+            getattr(LOSS_MAP[loss](), "higher_better") for loss in losses
+        ]
        if len(set(direction)) > 1:
-            raise ValueError("all cost functions should be of same nature, maximize or minimize!")
+            raise ValueError(
+                "all cost functions should be of same nature, maximize or minimize!"
+            )
 
        self.higher_better = direction[0]
        for loss in losses:
            loss = self.validate_loss(loss)
            self.valid_losses.append(loss())
 
-    def validate_loss(self,loss:str):
+    def validate_loss(self, loss: str):
        if loss not in LOSS_MAP.keys():
-            raise ValueError(f"Invalid loss function {loss}, available loss functions are {tuple([loss for loss in LOSS_MAP.keys()])}")
+            raise ValueError(
+                f"""Invalid loss function {loss}, available loss functions are
+                {tuple([loss for loss in LOSS_MAP.keys()])}"""
+            )
        else:
            return LOSS_MAP[loss]
 
-    def forward(self,prediction:torch.Tensor, target:torch.Tensor):
+    def forward(self, prediction: torch.Tensor, target: torch.Tensor):
        loss = 0.0
        for loss_fun in self.valid_losses:
            loss += loss_fun(prediction, target)
@@ -102,10 +133,8 @@ class Avergeloss(nn.Module):
        return loss
 
 
-LOSS_MAP = {"mae":mean_absolute_error,
-            "mse": mean_squared_error,
-            "SI-SDR":Si_SDR}
-
-
+LOSS_MAP = {
+    "mae": mean_absolute_error,
+    "mse": mean_squared_error,
+    "SI-SDR": Si_SDR,
+}
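A quick sanity check of the losses above on random tensors shaped (batch, channels, samples), which is what the `forward` methods assert:

import torch

from enhancer.loss import Avergeloss, Si_SDR

prediction = torch.randn(4, 1, 16000)
target = torch.randn(4, 1, 16000)

print(Si_SDR(reduction="mean")(prediction, target))  # scalar tensor, in dB
# Combining two losses of the same (minimization) nature:
print(Avergeloss(["mse", "mae"])(prediction, target))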
enhancer/models/__init__.py

@@ -1,3 +1,3 @@
 from enhancer.models.demucs import Demucs
-from enhancer.models.waveunet import WaveUnet
 from enhancer.models.model import Model
+from enhancer.models.waveunet import WaveUnet
enhancer/models/demucs.py

@@ -1,217 +1,264 @@
 import logging
-from typing import Optional, Union, List
-from torch import nn
-import torch.nn.functional as F
 import math
+from typing import List, Optional, Union
 
+import torch.nn.functional as F
+from torch import nn
 
-from enhancer.models.model import Model
 from enhancer.data.dataset import EnhancerDataset
+from enhancer.models.model import Model
 from enhancer.utils.io import Audio as audio
 from enhancer.utils.utils import merge_dict
 
 
 class DemucsLSTM(nn.Module):
    def __init__(
        self,
-        input_size:int,
-        hidden_size:int,
-        num_layers:int,
-        bidirectional:bool=True
+        input_size: int,
+        hidden_size: int,
+        num_layers: int,
+        bidirectional: bool = True,
    ):
        super().__init__()
-        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bidirectional=bidirectional)
+        self.lstm = nn.LSTM(
+            input_size, hidden_size, num_layers, bidirectional=bidirectional
+        )
        dim = 2 if bidirectional else 1
-        self.linear = nn.Linear(dim*hidden_size,hidden_size)
+        self.linear = nn.Linear(dim * hidden_size, hidden_size)
 
-    def forward(self,x):
+    def forward(self, x):
 
-        output,(h,c) = self.lstm(x)
+        output, (h, c) = self.lstm(x)
        output = self.linear(output)
 
-        return output,(h,c)
+        return output, (h, c)
 
 
 class DemucsEncoder(nn.Module):
-
    def __init__(
        self,
-        num_channels:int,
-        hidden_size:int,
-        kernel_size:int,
-        stride:int=1,
-        glu:bool=False,
+        num_channels: int,
+        hidden_size: int,
+        kernel_size: int,
+        stride: int = 1,
+        glu: bool = False,
    ):
        super().__init__()
        activation = nn.GLU(1) if glu else nn.ReLU()
        multi_factor = 2 if glu else 1
        self.encoder = nn.Sequential(
-            nn.Conv1d(num_channels,hidden_size,kernel_size,stride),
+            nn.Conv1d(num_channels, hidden_size, kernel_size, stride),
            nn.ReLU(),
-            nn.Conv1d(hidden_size, hidden_size*multi_factor,kernel_size,1),
-            activation
+            nn.Conv1d(hidden_size, hidden_size * multi_factor, kernel_size, 1),
+            activation,
        )
 
-    def forward(self,waveform):
+    def forward(self, waveform):
 
        return self.encoder(waveform)
 
-class DemucsDecoder(nn.Module):
 
+class DemucsDecoder(nn.Module):
    def __init__(
        self,
-        num_channels:int,
-        hidden_size:int,
-        kernel_size:int,
-        stride:int=1,
-        glu:bool=False,
-        layer:int=0
+        num_channels: int,
+        hidden_size: int,
+        kernel_size: int,
+        stride: int = 1,
+        glu: bool = False,
+        layer: int = 0,
    ):
        super().__init__()
        activation = nn.GLU(1) if glu else nn.ReLU()
        multi_factor = 2 if glu else 1
        self.decoder = nn.Sequential(
-            nn.Conv1d(hidden_size,hidden_size*multi_factor,kernel_size,1),
+            nn.Conv1d(hidden_size, hidden_size * multi_factor, kernel_size, 1),
            activation,
-            nn.ConvTranspose1d(hidden_size,num_channels,kernel_size,stride)
+            nn.ConvTranspose1d(hidden_size, num_channels, kernel_size, stride),
        )
-        if layer>0:
+        if layer > 0:
            self.decoder.add_module("4", nn.ReLU())
 
-    def forward(self,waveform,):
+    def forward(
+        self,
+        waveform,
+    ):
 
        out = self.decoder(waveform)
        return out
 
 
 class Demucs(Model):
    """
    Demucs model from https://arxiv.org/pdf/1911.13254.pdf
    parameters:
        encoder_decoder: dict, optional
            keyword arguments passed to the encoder decoder block
        lstm : dict, optional
            keyword arguments passed to the LSTM block
        num_channels: int, defaults to 1
            number of channels in input audio
        sampling_rate: int, defaults to 16 kHz
            sampling rate of input audio
        lr : float, defaults to 1e-3
            learning rate used for training
        dataset: EnhancerDataset, optional
            EnhancerDataset object containing train/validation data for training
        duration : float, optional
            chunk duration in seconds
        loss : string or List of strings
            loss function to be used, available ("mse","mae","SI-SDR")
        metric : string or List of strings
            metric function to be used, available ("mse","mae","SI-SDR")
    """
 
    ED_DEFAULTS = {
-        "initial_output_channels":48,
-        "kernel_size":8,
-        "stride":1,
-        "depth":5,
-        "glu":True,
-        "growth_factor":2,
+        "initial_output_channels": 48,
+        "kernel_size": 8,
+        "stride": 1,
+        "depth": 5,
+        "glu": True,
+        "growth_factor": 2,
    }
    LSTM_DEFAULTS = {
-        "bidirectional":True,
-        "num_layers":2,
+        "bidirectional": True,
+        "num_layers": 2,
    }
 
    def __init__(
        self,
-        encoder_decoder:Optional[dict]=None,
-        lstm:Optional[dict]=None,
-        num_channels:int=1,
-        resample:int=4,
-        sampling_rate = 16000,
-        lr:float=1e-3,
-        dataset:Optional[EnhancerDataset]=None,
-        loss:Union[str, List] = "mse",
-        metric:Union[str, List] = "mse"
+        encoder_decoder: Optional[dict] = None,
+        lstm: Optional[dict] = None,
+        num_channels: int = 1,
+        resample: int = 4,
+        sampling_rate=16000,
+        lr: float = 1e-3,
+        dataset: Optional[EnhancerDataset] = None,
+        loss: Union[str, List] = "mse",
+        metric: Union[str, List] = "mse",
    ):
-        duration = dataset.duration if isinstance(dataset,EnhancerDataset) else None
+        duration = (
+            dataset.duration if isinstance(dataset, EnhancerDataset) else None
+        )
        if dataset is not None:
-            if sampling_rate!=dataset.sampling_rate:
-                logging.warn(f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}")
+            if sampling_rate != dataset.sampling_rate:
+                logging.warn(
+                    f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
+                )
                sampling_rate = dataset.sampling_rate
-        super().__init__(num_channels=num_channels,
-                sampling_rate=sampling_rate,lr=lr,
-                dataset=dataset,duration=duration,loss=loss, metric=metric)
+        super().__init__(
+            num_channels=num_channels,
+            sampling_rate=sampling_rate,
+            lr=lr,
+            dataset=dataset,
+            duration=duration,
+            loss=loss,
+            metric=metric,
+        )
 
-        encoder_decoder = merge_dict(self.ED_DEFAULTS,encoder_decoder)
-        lstm = merge_dict(self.LSTM_DEFAULTS,lstm)
-        self.save_hyperparameters("encoder_decoder","lstm","resample")
+        encoder_decoder = merge_dict(self.ED_DEFAULTS, encoder_decoder)
+        lstm = merge_dict(self.LSTM_DEFAULTS, lstm)
+        self.save_hyperparameters("encoder_decoder", "lstm", "resample")
        hidden = encoder_decoder["initial_output_channels"]
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
 
        for layer in range(encoder_decoder["depth"]):
 
-            encoder_layer = DemucsEncoder(num_channels=num_channels,
-                    hidden_size=hidden,
-                    kernel_size=encoder_decoder["kernel_size"],
-                    stride=encoder_decoder["stride"],
-                    glu=encoder_decoder["glu"],
-                    )
+            encoder_layer = DemucsEncoder(
+                num_channels=num_channels,
+                hidden_size=hidden,
+                kernel_size=encoder_decoder["kernel_size"],
+                stride=encoder_decoder["stride"],
+                glu=encoder_decoder["glu"],
+            )
            self.encoder.append(encoder_layer)
 
-            decoder_layer = DemucsDecoder(num_channels=num_channels,
-                    hidden_size=hidden,
-                    kernel_size=encoder_decoder["kernel_size"],
-                    stride=1,
-                    glu=encoder_decoder["glu"],
-                    layer=layer
-                    )
-            self.decoder.insert(0,decoder_layer)
+            decoder_layer = DemucsDecoder(
+                num_channels=num_channels,
+                hidden_size=hidden,
+                kernel_size=encoder_decoder["kernel_size"],
+                stride=1,
+                glu=encoder_decoder["glu"],
+                layer=layer,
+            )
+            self.decoder.insert(0, decoder_layer)
 
            num_channels = hidden
            hidden = self.ED_DEFAULTS["growth_factor"] * hidden
 
-        self.de_lstm = DemucsLSTM(input_size=num_channels,
-                hidden_size=num_channels,
-                num_layers=lstm["num_layers"],
-                bidirectional=lstm["bidirectional"]
-                )
+        self.de_lstm = DemucsLSTM(
+            input_size=num_channels,
+            hidden_size=num_channels,
+            num_layers=lstm["num_layers"],
+            bidirectional=lstm["bidirectional"],
+        )
 
-    def forward(self,waveform):
+    def forward(self, waveform):
 
        if waveform.dim() == 2:
            waveform = waveform.unsqueeze(1)
 
-        if waveform.size(1)!=1:
-            raise TypeError(f"Demucs can only process mono channel audio, input has {waveform.size(1)} channels")
+        if waveform.size(1) != 1:
+            raise TypeError(
+                f"Demucs can only process mono channel audio, input has {waveform.size(1)} channels"
+            )
 
        length = waveform.shape[-1]
-        x = F.pad(waveform, (0,self.get_padding_length(length) - length))
-        if self.hparams.resample>1:
-            x = audio.resample_audio(audio=x, sr=self.hparams.sampling_rate,
-                    target_sr=int(self.hparams.sampling_rate * self.hparams.resample))
+        x = F.pad(waveform, (0, self.get_padding_length(length) - length))
+        if self.hparams.resample > 1:
+            x = audio.resample_audio(
+                audio=x,
+                sr=self.hparams.sampling_rate,
+                target_sr=int(
+                    self.hparams.sampling_rate * self.hparams.resample
+                ),
+            )
 
        encoder_outputs = []
        for encoder in self.encoder:
            x = encoder(x)
            encoder_outputs.append(x)
-        x = x.permute(0,2,1)
-        x,_ = self.de_lstm(x)
+        x = x.permute(0, 2, 1)
+        x, _ = self.de_lstm(x)
 
-        x = x.permute(0,2,1)
+        x = x.permute(0, 2, 1)
        for decoder in self.decoder:
            skip_connection = encoder_outputs.pop(-1)
-            x += skip_connection[..., :x.shape[-1]]
+            x += skip_connection[..., : x.shape[-1]]
            x = decoder(x)
 
        if self.hparams.resample > 1:
-            x = audio.resample_audio(x,int(self.hparams.sampling_rate * self.hparams.resample),
-                    self.hparams.sampling_rate)
+            x = audio.resample_audio(
+                x,
+                int(self.hparams.sampling_rate * self.hparams.resample),
+                self.hparams.sampling_rate,
+            )
 
        return x
 
-    def get_padding_length(self,input_length):
+    def get_padding_length(self, input_length):
 
        input_length = math.ceil(input_length * self.hparams.resample)
 
-        for layer in range(self.hparams.encoder_decoder["depth"]):  # encoder operation
-            input_length = math.ceil((input_length - self.hparams.encoder_decoder["kernel_size"])/self.hparams.encoder_decoder["stride"])+1
-            input_length = max(1,input_length)
-        for layer in range(self.hparams.encoder_decoder["depth"]):  # decoder operation
-            input_length = (input_length-1) * self.hparams.encoder_decoder["stride"] + self.hparams.encoder_decoder["kernel_size"]
-        input_length = math.ceil(input_length/self.hparams.resample)
+        for layer in range(
+            self.hparams.encoder_decoder["depth"]
+        ):  # encoder operation
+            input_length = (
+                math.ceil(
+                    (input_length - self.hparams.encoder_decoder["kernel_size"])
+                    / self.hparams.encoder_decoder["stride"]
+                )
+                + 1
+            )
+            input_length = max(1, input_length)
+        for layer in range(
+            self.hparams.encoder_decoder["depth"]
+        ):  # decoder operation
+            input_length = (input_length - 1) * self.hparams.encoder_decoder[
+                "stride"
+            ] + self.hparams.encoder_decoder["kernel_size"]
+        input_length = math.ceil(input_length / self.hparams.resample)
 
        return int(input_length)
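A shape-only sketch of a Demucs forward pass with the defaults above, assuming the resampling backend behind `Audio.resample_audio` is installed:

import torch

from enhancer.models import Demucs

model = Demucs()  # ED_DEFAULTS / LSTM_DEFAULTS, no dataset attached
with torch.no_grad():
    out = model(torch.randn(1, 1, 16000))  # (batch, channels, samples)
print(out.shape)  # trailing samples padded up to get_padding_length(16000)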
@ -1,50 +1,67 @@
|
|||
try:
|
||||
from functools import cached_property
|
||||
except ImportError:
|
||||
from backports.cached_property import cached_property
|
||||
from importlib import import_module
|
||||
from huggingface_hub import cached_download, hf_hub_url
|
||||
import logging
|
||||
import numpy as np
|
||||
import os
|
||||
from typing import Optional, Union, List, Text, Dict, Any
|
||||
from torch.optim import Adam
|
||||
import torch
|
||||
from torch.nn.functional import pad
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.utilities.cloud_io import load as pl_load
|
||||
from urllib.parse import urlparse
|
||||
from importlib import import_module
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Text, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from huggingface_hub import cached_download, hf_hub_url
|
||||
from pytorch_lightning.utilities.cloud_io import load as pl_load
|
||||
from torch.optim import Adam
|
||||
|
||||
from enhancer import __version__
|
||||
from enhancer.data.dataset import EnhancerDataset
|
||||
from enhancer.utils.io import Audio
|
||||
from enhancer.loss import Avergeloss
|
||||
from enhancer.inference import Inference
|
||||
from enhancer.loss import Avergeloss
|
||||
|
||||
CACHE_DIR = ""
|
||||
HF_TORCH_WEIGHTS = ""
|
||||
DEFAULT_DEVICE = "cpu"
|
||||
|
||||
|
||||
class Model(pl.LightningModule):
|
||||
"""
|
||||
Base class for all models
|
||||
parameters:
|
||||
num_channels: int, default to 1
|
||||
number of channels in input audio
|
||||
sampling_rate : int, default 16khz
|
||||
audio sampling rate
|
||||
lr: float, optional
|
||||
learning rate for model training
|
||||
dataset: EnhancerDataset, optional
|
||||
Enhancer dataset used for training/validation
|
||||
duration: float, optional
|
||||
duration used for training/inference
|
||||
loss : string or List of strings, default to "mse"
|
||||
loss functions to be used. Available ("mse","mae","Si-SDR")
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_channels:int=1,
|
||||
sampling_rate:int=16000,
|
||||
lr:float=1e-3,
|
||||
dataset:Optional[EnhancerDataset]=None,
|
||||
duration:Optional[float]=None,
|
||||
num_channels: int = 1,
|
||||
sampling_rate: int = 16000,
|
||||
lr: float = 1e-3,
|
||||
dataset: Optional[EnhancerDataset] = None,
|
||||
duration: Optional[float] = None,
|
||||
loss: Union[str, List] = "mse",
|
||||
metric:Union[str,List] = "mse"
|
||||
metric: Union[str, List] = "mse",
|
||||
):
|
||||
super().__init__()
|
||||
assert num_channels ==1 , "Enhancer only support for mono channel models"
|
||||
assert (
|
||||
num_channels == 1
|
||||
), "Enhancer only support for mono channel models"
|
||||
self.dataset = dataset
|
||||
self.save_hyperparameters("num_channels","sampling_rate","lr","loss","metric","duration")
|
||||
self.save_hyperparameters(
|
||||
"num_channels", "sampling_rate", "lr", "loss", "metric", "duration"
|
||||
)
|
||||
if self.logger:
|
||||
self.logger.experiment.log_dict(dict(self.hparams),"hyperparameters.json")
|
||||
self.logger.experiment.log_dict(
|
||||
dict(self.hparams), "hyperparameters.json"
|
||||
)
|
||||
|
||||
self.loss = loss
|
||||
self.metric = metric
|
||||
|
|
@ -54,10 +71,10 @@ class Model(pl.LightningModule):
|
|||
return self._loss
|
||||
|
||||
@loss.setter
|
||||
def loss(self,loss):
|
||||
def loss(self, loss):
|
||||
|
||||
if isinstance(loss,str):
|
||||
losses = [loss]
|
||||
if isinstance(loss, str):
|
||||
losses = [loss]
|
||||
|
||||
self._loss = Avergeloss(losses)
|
||||
|
||||
|
|
@ -66,23 +83,22 @@ class Model(pl.LightningModule):
|
|||
return self._metric
|
||||
|
||||
@metric.setter
|
||||
def metric(self,metric):
|
||||
def metric(self, metric):
|
||||
|
||||
if isinstance(metric,str):
|
||||
metric = [metric]
|
||||
if isinstance(metric, str):
|
||||
metric = [metric]
|
||||
|
||||
self._metric = Avergeloss(metric)
|
||||
|
||||
|
||||
@property
|
||||
def dataset(self):
|
||||
return self._dataset
|
||||
|
||||
@dataset.setter
|
||||
def dataset(self,dataset):
|
||||
def dataset(self, dataset):
|
||||
self._dataset = dataset
|
||||
|
||||
def setup(self,stage:Optional[str]=None):
|
||||
def setup(self, stage: Optional[str] = None):
|
||||
if stage == "fit":
|
||||
self.dataset.setup(stage)
|
||||
self.dataset.model = self
|
||||
|
|
@ -94,9 +110,9 @@ class Model(pl.LightningModule):
|
|||
return self.dataset.val_dataloader()
|
||||
|
||||
def configure_optimizers(self):
|
||||
return Adam(self.parameters(), lr = self.hparams.lr)
|
||||
return Adam(self.parameters(), lr=self.hparams.lr)
|
||||
|
||||
def training_step(self,batch, batch_idx:int):
|
||||
def training_step(self, batch, batch_idx: int):
|
||||
|
||||
mixed_waveform = batch["noisy"]
|
||||
target = batch["clean"]
|
||||
|
|
@ -105,13 +121,16 @@ class Model(pl.LightningModule):
|
|||
loss = self.loss(prediction, target)
|
||||
|
||||
if self.logger:
|
||||
self.logger.experiment.log_metric(run_id=self.logger.run_id,
|
||||
key="train_loss", value=loss.item(),
|
||||
step=self.global_step)
|
||||
self.log("train_loss",loss.item())
|
||||
return {"loss":loss}
|
||||
self.logger.experiment.log_metric(
|
||||
run_id=self.logger.run_id,
|
||||
key="train_loss",
|
||||
value=loss.item(),
|
||||
step=self.global_step,
|
||||
)
|
||||
self.log("train_loss", loss.item())
|
||||
return {"loss": loss}
|
||||
|
||||
def validation_step(self,batch,batch_idx:int):
|
||||
def validation_step(self, batch, batch_idx: int):
|
||||
|
||||
mixed_waveform = batch["noisy"]
|
||||
target = batch["clean"]
|
||||
|
|
@@ -119,48 +138,92 @@ class Model(pl.LightningModule):

         metric_val = self.metric(prediction, target)
         loss_val = self.loss(prediction, target)
-        self.log("val_metric",metric_val.item())
-        self.log("val_loss",loss_val.item())
+        self.log("val_metric", metric_val.item())
+        self.log("val_loss", loss_val.item())

         if self.logger:
-            self.logger.experiment.log_metric(run_id=self.logger.run_id,
-                                              key="val_loss",value=loss_val.item(),
-                                              step=self.global_step)
-            self.logger.experiment.log_metric(run_id=self.logger.run_id,
-                                              key="val_metric",value=metric_val.item(),
-                                              step=self.global_step)
+            self.logger.experiment.log_metric(
+                run_id=self.logger.run_id,
+                key="val_loss",
+                value=loss_val.item(),
+                step=self.global_step,
+            )
+            self.logger.experiment.log_metric(
+                run_id=self.logger.run_id,
+                key="val_metric",
+                value=metric_val.item(),
+                step=self.global_step,
+            )

-        return {"loss":loss_val}
+        return {"loss": loss_val}

     def on_save_checkpoint(self, checkpoint):

         checkpoint["enhancer"] = {
-            "version": {
-                "enhancer":__version__,
-                "pytorch":torch.__version__
-            },
-            "architecture":{
-                "module":self.__class__.__module__,
-                "class":self.__class__.__name__
-            }
-
+            "version": {"enhancer": __version__, "pytorch": torch.__version__},
+            "architecture": {
+                "module": self.__class__.__module__,
+                "class": self.__class__.__name__,
+            },
         }

     def on_load_checkpoint(self, checkpoint: Dict[str, Any]):
         pass

     @classmethod
     def from_pretrained(
         cls,
         checkpoint: Union[Path, Text],
-        map_location = None,
+        map_location=None,
         hparams_file: Union[Path, Text] = None,
         strict: bool = True,
         use_auth_token: Union[Text, None] = None,
-        cached_dir: Union[Path, Text]=CACHE_DIR,
-        **kwargs
+        cached_dir: Union[Path, Text] = CACHE_DIR,
+        **kwargs,
     ):
         """
         Load Pretrained model

         parameters:
             checkpoint : Path or str
                 Path to checkpoint, or a remote URL, or a model identifier from
                 the huggingface.co model hub.
             map_location: optional
                 Same role as in torch.load().
                 Defaults to `lambda storage, loc: storage`.
             hparams_file : Path or str, optional
                 Path to a .yaml file with hierarchical structure as in this example:
                     drop_prob: 0.2
                     dataloader:
                         batch_size: 32
                 You most likely won't need this since Lightning will always save the
                 hyperparameters to the checkpoint. However, if your checkpoint weights
                 do not have the hyperparameters saved, use this method to pass in a .yaml
                 file with the hparams you would like to use. These will be converted
                 into a dict and passed into your Model for use.
             strict : bool, optional
                 Whether to strictly enforce that the keys in checkpoint match
                 the keys returned by this module's state dict. Defaults to True.
             use_auth_token : str, optional
                 When loading a private huggingface.co model, set `use_auth_token`
                 to True or to a string containing your huggingface.co authentication
                 token that can be obtained by running `huggingface-cli login`
             cached_dir: Path or str, optional
                 Path to model cache directory. Defaults to content of PYANNOTE_CACHE
                 environment variable, or "~/.cache/torch/pyannote" when unset.
             kwargs: optional
                 Any extra keyword args needed to init the model.
                 Can also be used to override saved hyperparameter values.

         Returns
         -------
         model : Model
             Model

         See also
         --------
         torch.load
         """

         checkpoint = str(checkpoint)
         if hparams_file is not None:
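For orientation, a minimal sketch of how from_pretrained is meant to be called; the checkpoint path below is hypothetical, and any local file, http(s) URL, or huggingface.co model id accepted by the resolution logic in the next hunks would work the same way:

    from enhancer.models import Demucs   # any Model subclass in this repo

    model = Demucs.from_pretrained(
        "model/model_12345.ckpt",         # hypothetical local checkpoint
        map_location="cpu",
    )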
@@ -168,7 +231,7 @@ class Model(pl.LightningModule):

         if os.path.isfile(checkpoint):
             model_path_pl = checkpoint
-        elif urlparse(checkpoint).scheme in ("http","https"):
+        elif urlparse(checkpoint).scheme in ("http", "https"):
             model_path_pl = checkpoint
         else:
@@ -180,45 +243,59 @@ class Model(pl.LightningModule):
                 revision_id = None

             url = hf_hub_url(
-                model_id,filename=HF_TORCH_WEIGHTS,revision=revision_id
+                model_id, filename=HF_TORCH_WEIGHTS, revision=revision_id
             )
             model_path_pl = cached_download(
-                url=url,library_name="enhancer",library_version=__version__,
-                cache_dir=cached_dir,use_auth_token=use_auth_token
+                url=url,
+                library_name="enhancer",
+                library_version=__version__,
+                cache_dir=cached_dir,
+                use_auth_token=use_auth_token,
             )

         if map_location is None:
             map_location = torch.device(DEFAULT_DEVICE)

-        loaded_checkpoint = pl_load(model_path_pl,map_location)
+        loaded_checkpoint = pl_load(model_path_pl, map_location)
         module_name = loaded_checkpoint["enhancer"]["architecture"]["module"]
         class_name = loaded_checkpoint["enhancer"]["architecture"]["class"]
         module = import_module(module_name)
         Klass = getattr(module, class_name)

         try:
             model = Klass.load_from_checkpoint(
-                checkpoint_path = model_path_pl,
-                map_location = map_location,
-                hparams_file = hparams_file,
-                strict = strict,
-                **kwargs
+                checkpoint_path=model_path_pl,
+                map_location=map_location,
+                hparams_file=hparams_file,
+                strict=strict,
+                **kwargs,
             )
         except Exception as e:
             print(e)

         return model

-    def infer(self,batch:torch.Tensor,batch_size:int=32):
+    def infer(self, batch: torch.Tensor, batch_size: int = 32):
         """
         perform model inference
         parameters:
             batch : torch.Tensor
                 input data
             batch_size : int, default 32
                 batch size for inference
         """

-        assert batch.ndim == 3, f"Expected batch with 3 dimensions (batch,channels,samples) got only {batch.ndim}"
+        assert (
+            batch.ndim == 3
+        ), f"Expected batch with 3 dimensions (batch,channels,samples) got only {batch.ndim}"
         batch_predictions = []
         self.eval().to(self.device)

         with torch.no_grad():
-            for batch_id in range(0,batch.shape[0],batch_size):
-                batch_data = batch[batch_id:batch_id+batch_size,:,:].to(self.device)
+            for batch_id in range(0, batch.shape[0], batch_size):
+                batch_data = batch[batch_id : batch_id + batch_size, :, :].to(
+                    self.device
+                )
                 prediction = self(batch_data)
                 batch_predictions.append(prediction)
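A quick sketch of the shape contract infer enforces; the sizes are illustrative, matching the (batch, channels, samples) assert above:

    batch = torch.rand(8, 1, 1000)               # 8 mono chunks
    enhanced = model.infer(batch, batch_size=4)  # processed 4 chunks at a time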
@@ -226,46 +303,61 @@ class Model(pl.LightningModule):

     def enhance(
         self,
-        audio:Union[Path,np.ndarray,torch.Tensor],
-        sampling_rate:Optional[int]=None,
-        batch_size:int=32,
-        save_output:bool=False,
-        duration:Optional[int]=None,
-        step_size:Optional[int]=None,):
+        audio: Union[Path, np.ndarray, torch.Tensor],
+        sampling_rate: Optional[int] = None,
+        batch_size: int = 32,
+        save_output: bool = False,
+        duration: Optional[int] = None,
+        step_size: Optional[int] = None,
+    ):
         """
         Enhance audio using loaded pretrained model.

         parameters:
             audio: Path to audio file or numpy array or torch tensor
                 single input audio
             sampling_rate: int, optional in case input is a path
                 sampling rate of input
             batch_size: int, default 32
                 input audio is split into multiple chunks. Inference is done on batches
                 of these chunks according to the given batch size.
             save_output : bool, default False
                 whether to save output to file
             duration : float, optional
                 chunk duration in seconds, defaults to duration of loaded pretrained model.
             step_size: int, optional
                 step size between consecutive durations, defaults to 50% of duration
         """

         model_sampling_rate = self.hparams["sampling_rate"]
         if duration is None:
             duration = self.hparams["duration"]
-        waveform = Inference.read_input(audio,sampling_rate,model_sampling_rate)
+        waveform = Inference.read_input(
+            audio, sampling_rate, model_sampling_rate
+        )
         waveform.to(self.device)
         window_size = round(duration * model_sampling_rate)
-        batched_waveform = Inference.batchify(waveform,window_size,step_size=step_size)
-        batch_prediction = self.infer(batched_waveform,batch_size=batch_size)
-        waveform = Inference.aggreagate(batch_prediction,window_size,waveform.shape[-1],step_size,)
+        batched_waveform = Inference.batchify(
+            waveform, window_size, step_size=step_size
+        )
+        batch_prediction = self.infer(batched_waveform, batch_size=batch_size)
+        waveform = Inference.aggreagate(
+            batch_prediction,
+            window_size,
+            waveform.shape[-1],
+            step_size,
+        )

-        if save_output and isinstance(audio,(str,Path)):
-            Inference.write_output(waveform,audio,model_sampling_rate)
+        if save_output and isinstance(audio, (str, Path)):
+            Inference.write_output(waveform, audio, model_sampling_rate)

         else:
-            waveform = Inference.prepare_output(waveform, model_sampling_rate,
-                                                audio, sampling_rate)
+            waveform = Inference.prepare_output(
+                waveform, model_sampling_rate, audio, sampling_rate
+            )
             return waveform

     @property
     def valid_monitor(self):

         return "max" if self.loss.higher_better else "min"
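End to end, enhance chains read_input, batchify, infer, and aggreagate. A hedged usage sketch, with a hypothetical input file name:

    denoised = model.enhance(
        "noisy_clip.wav",      # hypothetical input file
        sampling_rate=48000,   # rate of the input; the model rate comes from hparams
        batch_size=32,
    )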
@@ -1,82 +1,124 @@
 import logging
+from typing import List, Optional, Union

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from typing import Optional, Union, List

-from enhancer.models.model import Model
 from enhancer.data.dataset import EnhancerDataset
+from enhancer.models.model import Model


 class WavenetDecoder(nn.Module):
     def __init__(
         self,
-        in_channels:int,
-        out_channels:int,
-        kernel_size:int=5,
-        padding:int=2,
-        stride:int=1,
-        dilation:int=1,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int = 5,
+        padding: int = 2,
+        stride: int = 1,
+        dilation: int = 1,
     ):
-        super(WavenetDecoder,self).__init__()
+        super(WavenetDecoder, self).__init__()
         self.decoder = nn.Sequential(
-            nn.Conv1d(in_channels,out_channels,kernel_size,stride=stride,padding=padding,dilation=dilation),
+            nn.Conv1d(
+                in_channels,
+                out_channels,
+                kernel_size,
+                stride=stride,
+                padding=padding,
+                dilation=dilation,
+            ),
             nn.BatchNorm1d(out_channels),
-            nn.LeakyReLU(negative_slope=0.1)
+            nn.LeakyReLU(negative_slope=0.1),
         )

-    def forward(self,waveform):
+    def forward(self, waveform):

         return self.decoder(waveform)


 class WavenetEncoder(nn.Module):
     def __init__(
         self,
-        in_channels:int,
-        out_channels:int,
-        kernel_size:int=15,
-        padding:int=7,
-        stride:int=1,
-        dilation:int=1,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int = 15,
+        padding: int = 7,
+        stride: int = 1,
+        dilation: int = 1,
     ):
-        super(WavenetEncoder,self).__init__()
+        super(WavenetEncoder, self).__init__()
        self.encoder = nn.Sequential(
-            nn.Conv1d(in_channels,out_channels,kernel_size,stride=stride,padding=padding,dilation=dilation),
+            nn.Conv1d(
+                in_channels,
+                out_channels,
+                kernel_size,
+                stride=stride,
+                padding=padding,
+                dilation=dilation,
+            ),
             nn.BatchNorm1d(out_channels),
-            nn.LeakyReLU(negative_slope=0.1)
+            nn.LeakyReLU(negative_slope=0.1),
         )

-    def forward(
-        self,
-        waveform
-    ):
+    def forward(self, waveform):
         return self.encoder(waveform)


 class WaveUnet(Model):
     """
     Wave-U-Net model from https://arxiv.org/pdf/1811.11307.pdf
     parameters:
         num_channels: int, defaults to 1
             number of channels in input audio
         depth : int, defaults to 12
             depth of network
         initial_output_channels: int, defaults to 24
             number of output channels in initial upsampling layer
         sampling_rate: int, defaults to 16 kHz
             sampling rate of input audio
         lr : float, defaults to 1e-3
             learning rate used for training
         dataset: EnhancerDataset, optional
             EnhancerDataset object containing train/validation data for training
         duration : float, optional
             chunk duration in seconds
         loss : string or List of strings
             loss function to be used, available ("mse","mae","SI-SDR")
         metric : string or List of strings
             metric function to be used, available ("mse","mae","SI-SDR")
     """

     def __init__(
         self,
-        num_channels:int=1,
-        depth:int=12,
-        initial_output_channels:int=24,
-        sampling_rate:int=16000,
-        lr:float=1e-3,
-        dataset:Optional[EnhancerDataset]=None,
-        duration:Optional[float]=None,
+        num_channels: int = 1,
+        depth: int = 12,
+        initial_output_channels: int = 24,
+        sampling_rate: int = 16000,
+        lr: float = 1e-3,
+        dataset: Optional[EnhancerDataset] = None,
+        duration: Optional[float] = None,
         loss: Union[str, List] = "mse",
-        metric:Union[str,List] = "mse"
+        metric: Union[str, List] = "mse",
     ):
-        duration = dataset.duration if isinstance(dataset,EnhancerDataset) else None
+        duration = (
+            dataset.duration if isinstance(dataset, EnhancerDataset) else None
+        )
         if dataset is not None:
-            if sampling_rate!=dataset.sampling_rate:
-                logging.warn(f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}")
+            if sampling_rate != dataset.sampling_rate:
+                logging.warn(
+                    f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
+                )
                 sampling_rate = dataset.sampling_rate
-        super().__init__(num_channels=num_channels,
-                         sampling_rate=sampling_rate,lr=lr,
-                         dataset=dataset,duration=duration,loss=loss, metric=metric
-        )
+        super().__init__(
+            num_channels=num_channels,
+            sampling_rate=sampling_rate,
+            lr=lr,
+            dataset=dataset,
+            duration=duration,
+            loss=loss,
+            metric=metric,
+        )
         self.save_hyperparameters("depth")
         self.encoders = nn.ModuleList()
@@ -84,72 +126,76 @@ class WaveUnet(Model):
         out_channels = initial_output_channels
         for layer in range(depth):

-            encoder = WavenetEncoder(num_channels,out_channels)
+            encoder = WavenetEncoder(num_channels, out_channels)
             self.encoders.append(encoder)

             num_channels = out_channels
             out_channels += initial_output_channels
-            if layer == depth -1 :
-                decoder = WavenetDecoder(depth * initial_output_channels + num_channels,num_channels)
+            if layer == depth - 1:
+                decoder = WavenetDecoder(
+                    depth * initial_output_channels + num_channels, num_channels
+                )
             else:
-                decoder = WavenetDecoder(num_channels+out_channels,num_channels)
+                decoder = WavenetDecoder(
+                    num_channels + out_channels, num_channels
+                )

-            self.decoders.insert(0,decoder)
+            self.decoders.insert(0, decoder)

         bottleneck_dim = depth * initial_output_channels
         self.bottleneck = nn.Sequential(
-            nn.Conv1d(bottleneck_dim,bottleneck_dim, 15, stride=1,
-                      padding=7),
+            nn.Conv1d(bottleneck_dim, bottleneck_dim, 15, stride=1, padding=7),
             nn.BatchNorm1d(bottleneck_dim),
-            nn.LeakyReLU(negative_slope=0.1, inplace=True)
+            nn.LeakyReLU(negative_slope=0.1, inplace=True),
         )
         self.final = nn.Sequential(
             nn.Conv1d(1 + initial_output_channels, 1, kernel_size=1, stride=1),
-            nn.Tanh()
+            nn.Tanh(),
         )

-    def forward(
-        self,waveform
-    ):
+    def forward(self, waveform):
         if waveform.dim() == 2:
             waveform = waveform.unsqueeze(1)

-        if waveform.size(1)!=1:
-            raise TypeError(f"Wave-U-Net can only process mono channel audio, input has {waveform.size(1)} channels")
+        if waveform.size(1) != 1:
+            raise TypeError(
+                f"Wave-U-Net can only process mono channel audio, input has {waveform.size(1)} channels"
+            )

         encoder_outputs = []
         out = waveform
         for encoder in self.encoders:
             out = encoder(out)
-            encoder_outputs.insert(0,out)
-            out = out[:,:,::2]
+            encoder_outputs.insert(0, out)
+            out = out[:, :, ::2]

         out = self.bottleneck(out)

-        for layer,decoder in enumerate(self.decoders):
+        for layer, decoder in enumerate(self.decoders):
             out = F.interpolate(out, scale_factor=2, mode="linear")
-            out = self.fix_last_dim(out,encoder_outputs[layer])
-            out = torch.cat([out,encoder_outputs[layer]],dim=1)
+            out = self.fix_last_dim(out, encoder_outputs[layer])
+            out = torch.cat([out, encoder_outputs[layer]], dim=1)
             out = decoder(out)

-        out = torch.cat([out, waveform],dim=1)
+        out = torch.cat([out, waveform], dim=1)
         out = self.final(out)
         return out

-    def fix_last_dim(self,x,target):
+    def fix_last_dim(self, x, target):
         """
-        trying to do centre crop along last dimension
+        centre crop along last dimension
         """

-        assert x.shape[-1] >= target.shape[-1], "input dimension cannot be larger than target dimension"
+        assert (
+            x.shape[-1] >= target.shape[-1]
+        ), "input dimension cannot be smaller than target dimension"
         if x.shape[-1] == target.shape[-1]:
             return x

         diff = x.shape[-1] - target.shape[-1]
-        if diff%2!=0:
-            x = F.pad(x,(0,1))
+        if diff % 2 != 0:
+            x = F.pad(x, (0, 1))
             diff += 1

-        crop = diff//2
-        return x[:,:,crop:-crop]
+        crop = diff // 2
+        return x[:, :, crop:-crop]
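A short instantiation sketch consistent with the constructor and forward pass above; the input shape mirrors the (batch, 1, samples) mono contract enforced in forward and the sizes used in the tests further down this diff:

    model = WaveUnet(num_channels=1, depth=12, initial_output_channels=24)
    out = model(torch.rand(1, 1, 1000))   # mono input only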
@@ -1,3 +1,3 @@
-from enhancer.utils.utils import check_files
-from enhancer.utils.io import Audio
 from enhancer.utils.config import Files
+from enhancer.utils.io import Audio
+from enhancer.utils.utils import check_files
@@ -1,10 +1,9 @@
 from dataclasses import dataclass


 @dataclass
 class Files:
-    train_clean : str
-    train_noisy : str
-    test_clean : str
-    test_noisy : str
+    train_clean: str
+    train_noisy: str
+    test_clean: str
+    test_noisy: str
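The Files dataclass just names the four data subdirectories under a root directory; the test fixtures later in this diff construct it like so:

    files = Files(
        train_clean="clean_testset_wav",
        train_noisy="noisy_testset_wav",
        test_clean="clean_testset_wav",
        test_noisy="noisy_testset_wav",
    )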
@@ -1,17 +1,26 @@
 import os
+from pathlib import Path
+from typing import Optional, Union

 import librosa
-from typing import Optional
 import numpy as np
 import torch
 import torchaudio


 class Audio:
     """
     Audio utils
     parameters:
         sampling_rate : int, defaults to 16 kHz
             audio sampling rate
         mono: bool, defaults to True
         return_tensor: bool, defaults to True
             returns torch tensor type if set to True else numpy ndarray
     """

     def __init__(
-        self,
-        sampling_rate:int=16000,
-        mono:bool=True,
-        return_tensor=True
+        self, sampling_rate: int = 16000, mono: bool = True, return_tensor=True
     ) -> None:

         self.sampling_rate = sampling_rate
@@ -20,22 +29,39 @@ class Audio:

     def __call__(
         self,
-        audio,
-        sampling_rate:Optional[int]=None,
-        offset:Optional[float] = None,
-        duration:Optional[float] = None
+        audio: Union[Path, np.ndarray, torch.Tensor],
+        sampling_rate: Optional[int] = None,
+        offset: Optional[float] = None,
+        duration: Optional[float] = None,
     ):
-        if isinstance(audio,str):
+        """
+        read and process input audio
+        parameters:
+            audio: Path to audio file or numpy array or torch tensor
+                single input audio
+            sampling_rate : int, optional
+                sampling rate of the audio input
+            offset: float, optional
+                offset from which the audio must be read, reads from beginning if unused.
+            duration: float (seconds), optional
+                read duration, reads full audio starting from offset if not used
+        """
+        if isinstance(audio, str):
             if os.path.exists(audio):
-                audio,sampling_rate = librosa.load(audio,sr=sampling_rate,mono=False,
-                                                   offset=offset,duration=duration)
+                audio, sampling_rate = librosa.load(
+                    audio,
+                    sr=sampling_rate,
+                    mono=False,
+                    offset=offset,
+                    duration=duration,
+                )
                 if len(audio.shape) == 1:
-                    audio = audio.reshape(1,-1)
+                    audio = audio.reshape(1, -1)
             else:
                 raise FileNotFoundError(f"File {audio} does not exist")
-        elif isinstance(audio,np.ndarray):
+        elif isinstance(audio, np.ndarray):
             if len(audio.shape) == 1:
-                audio = audio.reshape(1,-1)
+                audio = audio.reshape(1, -1)
         else:
             raise ValueError("audio should be either filepath or numpy ndarray")
@@ -43,40 +69,60 @@ class Audio:
             audio = self.convert_mono(audio)

         if sampling_rate:
-            audio = self.__class__.resample_audio(audio,self.sampling_rate,sampling_rate)
+            audio = self.__class__.resample_audio(
+                audio, self.sampling_rate, sampling_rate
+            )
         if self.return_tensor:
             return torch.tensor(audio)
         else:
             return audio

     @staticmethod
-    def convert_mono(
-        audio
-    ):
-        if len(audio.shape)>2:
-            assert audio.shape[0] == 1, "convert mono only accepts single waveform"
-            audio = audio.reshape(audio.shape[1],audio.shape[2])
+    def convert_mono(audio: Union[np.ndarray, torch.Tensor]):
+        """
+        convert input audio into mono (1)
+        parameters:
+            audio: np.ndarray or torch.Tensor
+        """
+        if len(audio.shape) > 2:
+            assert (
+                audio.shape[0] == 1
+            ), "convert mono only accepts single waveform"
+            audio = audio.reshape(audio.shape[1], audio.shape[2])

-        assert audio.shape[1] >> audio.shape[0], f"expected input format (num_channels,num_samples) got {audio.shape}"
-        num_channels,num_samples = audio.shape
-        if num_channels>1:
-            return audio.mean(axis=0).reshape(1,num_samples)
+        # ">" here was ">>" (a bitwise shift) upstream, which reads as a typo
+        # for a plain comparison; corrected to the intended sanity check.
+        assert (
+            audio.shape[1] > audio.shape[0]
+        ), f"expected input format (num_channels,num_samples) got {audio.shape}"
+        num_channels, num_samples = audio.shape
+        if num_channels > 1:
+            return audio.mean(axis=0).reshape(1, num_samples)
         return audio

     @staticmethod
-    def resample_audio(
-        audio,
-        sr:int,
-        target_sr:int
-    ):
+    def resample_audio(
+        audio: Union[np.ndarray, torch.Tensor], sr: int, target_sr: int
+    ):
+        """
+        resample audio to desired sampling rate
+        parameters:
+            audio : numpy array or torch tensor
+                audio waveform
+            sr : int
+                current sampling rate
+            target_sr : int
+                target sampling rate
+        """
-        if sr!=target_sr:
-            if isinstance(audio,np.ndarray):
-                audio = librosa.resample(audio,orig_sr=sr,target_sr=target_sr)
-            elif isinstance(audio,torch.Tensor):
-                audio = torchaudio.functional.resample(audio,orig_freq=sr,new_freq=target_sr)
+        if sr != target_sr:
+            if isinstance(audio, np.ndarray):
+                audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sr)
+            elif isinstance(audio, torch.Tensor):
+                audio = torchaudio.functional.resample(
+                    audio, orig_freq=sr, new_freq=target_sr
+                )
             else:
-                raise ValueError("Input should be either numpy array or torch tensor")
+                raise ValueError(
+                    "Input should be either numpy array or torch tensor"
+                )

         return audio
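A minimal sketch of the Audio helper in use; the wav path is hypothetical:

    audio = Audio(sampling_rate=16000, mono=True, return_tensor=True)
    waveform = audio("speech.wav")   # load, downmix to mono, return a tensor
    downsampled = Audio.resample_audio(waveform, sr=16000, target_sr=8000)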
@@ -1,19 +1,19 @@
 import os
 import random

 import torch


-def create_unique_rng(epoch:int):
+def create_unique_rng(epoch: int):
     """create unique random number generator for each (worker_id,epoch) combination"""

     rng = random.Random()

-    global_seed = int(os.environ.get("PL_GLOBAL_SEED","0"))
-    global_rank = int(os.environ.get('GLOBAL_RANK',"0"))
-    local_rank = int(os.environ.get('LOCAL_RANK',"0"))
-    node_rank = int(os.environ.get('NODE_RANK',"0"))
-    world_size = int(os.environ.get('WORLD_SIZE',"0"))
+    global_seed = int(os.environ.get("PL_GLOBAL_SEED", "0"))
+    global_rank = int(os.environ.get("GLOBAL_RANK", "0"))
+    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
+    node_rank = int(os.environ.get("NODE_RANK", "0"))
+    world_size = int(os.environ.get("WORLD_SIZE", "0"))

     worker_info = torch.utils.data.get_worker_info()
     if worker_info is not None:
@@ -24,17 +24,13 @@ def create_unique_rng(epoch: int):
         worker_id = 0

     seed = (
-            global_seed
-            + worker_id
-            + local_rank * num_workers
-            + node_rank * num_workers * global_rank
-            + epoch * num_workers * world_size
-    )
+        global_seed
+        + worker_id
+        + local_rank * num_workers
+        + node_rank * num_workers * global_rank
+        + epoch * num_workers * world_size
+    )

     rng.seed(seed)

     return rng
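A sketch of what the helper gives you; the seed, and hence the draws, are reproducible for a fixed (worker_id, epoch) combination and environment:

    rng = create_unique_rng(epoch=0)
    picks = rng.sample(range(100), k=5)   # same 5 picks on every rerun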
@@ -1,19 +1,26 @@
 import os
 from typing import Optional

 from enhancer.utils.config import Files


-def check_files(root_dir:str, files:Files):
-
-    path_variables = [member_var for member_var in dir(files) if not member_var.startswith('__')]
+def check_files(root_dir: str, files: Files):
+
+    path_variables = [
+        member_var
+        for member_var in dir(files)
+        if not member_var.startswith("__")
+    ]
     for variable in path_variables:
-        path = getattr(files,variable)
-        if not os.path.isdir(os.path.join(root_dir,path)):
+        path = getattr(files, variable)
+        if not os.path.isdir(os.path.join(root_dir, path)):
             raise ValueError(f"Invalid {path}, is not a directory")

-    return files,root_dir
+    return files, root_dir


-def merge_dict(default_dict:dict, custom:Optional[dict]=None):
+def merge_dict(default_dict: dict, custom: Optional[dict] = None):

     params = dict(default_dict)
     if custom:
         params.update(custom)
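For clarity, how merge_dict resolves overrides (the values are illustrative):

    defaults = {"lr": 1e-3, "batch_size": 32}
    params = merge_dict(defaults, {"lr": 3e-4})
    # params == {"lr": 3e-4, "batch_size": 32}; defaults is left untouched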
@@ -0,0 +1,15 @@
+[tool.black]
+line-length = 80
+target-version = ['py38']
+exclude = '''
+
+(
+  /(
+    \.eggs   # exclude a few common directories in the
+    | \.git  # root of the project
+    | \.mypy_cache
+    | \.tox
+    | \.venv
+  )/
+)
+'''
@@ -1,15 +1,16 @@
-joblib==1.1.0
-numpy==1.19.5
-librosa==0.9.1
-numpy==1.19.5
-hydra-core==1.2.0
-scikit-learn==0.24.2
-scipy==1.5.4
-torch==1.10.2
-tqdm==4.64.0
-mlflow==1.23.1
-protobuf==3.19.3
-boto3==1.23.9
-torchaudio==0.10.2
-huggingface-hub==0.4.0
-pytorch-lightning==1.5.10
+black>=22.8.0
+boto3>=1.24.86
+flake8>=5.0.4
+huggingface-hub>=0.10.0
+hydra-core>=1.2.0
+joblib>=1.2.0
+librosa>=0.9.2
+mlflow>=1.29.0
+numpy>=1.23.3
+protobuf>=3.19.6
+pytorch-lightning>=1.7.7
+scikit-learn>=1.1.2
+scipy>=1.9.1
+torch>=1.12.1
+torchaudio>=0.12.1
+tqdm>=4.64.1
@@ -1,31 +1,32 @@
-from asyncio import base_tasks
-import torch
 import pytest
+import torch

 from enhancer.loss import mean_absolute_error, mean_squared_error

 loss_functions = [mean_absolute_error(), mean_squared_error()]


 def check_loss_shapes_compatibility(loss_fun):

     batch_size = 4
-    shape = (1,1000)
-    loss_fun(torch.rand(batch_size,*shape),torch.rand(batch_size,*shape))
+    shape = (1, 1000)
+    loss_fun(torch.rand(batch_size, *shape), torch.rand(batch_size, *shape))

     with pytest.raises(TypeError):
-        loss_fun(torch.rand(4,*shape),torch.rand(6,*shape))
+        loss_fun(torch.rand(4, *shape), torch.rand(6, *shape))


-@pytest.mark.parametrize("loss",loss_functions)
+@pytest.mark.parametrize("loss", loss_functions)
 def test_loss_input_shapes(loss):
     check_loss_shapes_compatibility(loss)

-@pytest.mark.parametrize("loss",loss_functions)
+
+@pytest.mark.parametrize("loss", loss_functions)
 def test_loss_output_type(loss):

     batch_size = 4
-    prediction, target = torch.rand(batch_size,1,1000),torch.rand(batch_size,1,1000)
+    prediction, target = torch.rand(batch_size, 1, 1000), torch.rand(
+        batch_size, 1, 1000
+    )
     loss_value = loss(prediction, target)
-    assert isinstance(loss_value.item(),float)
+    assert isinstance(loss_value.item(), float)
@@ -1,46 +1,43 @@
 import pytest
 import torch
-from enhancer import data

-from enhancer.utils.config import Files
-from enhancer.models import Demucs
 from enhancer.data.dataset import EnhancerDataset
+from enhancer.models import Demucs
+from enhancer.utils.config import Files


 @pytest.fixture
 def vctk_dataset():
     root_dir = "tests/data/vctk"
-    files = Files(train_clean="clean_testset_wav",train_noisy="noisy_testset_wav",
-                  test_clean="clean_testset_wav", test_noisy="noisy_testset_wav")
-    dataset = EnhancerDataset(name="vctk",root_dir=root_dir,files=files)
+    files = Files(
+        train_clean="clean_testset_wav",
+        train_noisy="noisy_testset_wav",
+        test_clean="clean_testset_wav",
+        test_noisy="noisy_testset_wav",
+    )
+    dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files)
     return dataset


-@pytest.mark.parametrize("batch_size,samples",[(1,1000)])
-def test_forward(batch_size,samples):
+@pytest.mark.parametrize("batch_size,samples", [(1, 1000)])
+def test_forward(batch_size, samples):
     model = Demucs()
     model.eval()

-    data = torch.rand(batch_size,1,samples,requires_grad=False)
+    data = torch.rand(batch_size, 1, samples, requires_grad=False)
     with torch.no_grad():
         _ = model(data)

-    data = torch.rand(batch_size,2,samples,requires_grad=False)
+    data = torch.rand(batch_size, 2, samples, requires_grad=False)
     with torch.no_grad():
         with pytest.raises(TypeError):
             _ = model(data)


-@pytest.mark.parametrize("dataset,channels,loss",
-                         [(pytest.lazy_fixture("vctk_dataset"),1,["mae","mse"])])
-def test_demucs_init(dataset,channels,loss):
+@pytest.mark.parametrize(
+    "dataset,channels,loss",
+    [(pytest.lazy_fixture("vctk_dataset"), 1, ["mae", "mse"])],
+)
+def test_demucs_init(dataset, channels, loss):
     with torch.no_grad():
-        model = Demucs(num_channels=channels,dataset=dataset,loss=loss)
+        _ = Demucs(num_channels=channels, dataset=dataset, loss=loss)
@@ -1,46 +1,43 @@
 import pytest
 import torch
-from enhancer import data

-from enhancer.utils.config import Files
-from enhancer.models import WaveUnet
 from enhancer.data.dataset import EnhancerDataset
+from enhancer.models import WaveUnet
+from enhancer.utils.config import Files


 @pytest.fixture
 def vctk_dataset():
     root_dir = "tests/data/vctk"
-    files = Files(train_clean="clean_testset_wav",train_noisy="noisy_testset_wav",
-                  test_clean="clean_testset_wav", test_noisy="noisy_testset_wav")
-    dataset = EnhancerDataset(name="vctk",root_dir=root_dir,files=files)
+    files = Files(
+        train_clean="clean_testset_wav",
+        train_noisy="noisy_testset_wav",
+        test_clean="clean_testset_wav",
+        test_noisy="noisy_testset_wav",
+    )
+    dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files)
     return dataset


-@pytest.mark.parametrize("batch_size,samples",[(1,1000)])
-def test_forward(batch_size,samples):
+@pytest.mark.parametrize("batch_size,samples", [(1, 1000)])
+def test_forward(batch_size, samples):
     model = WaveUnet()
     model.eval()

-    data = torch.rand(batch_size,1,samples,requires_grad=False)
+    data = torch.rand(batch_size, 1, samples, requires_grad=False)
     with torch.no_grad():
         _ = model(data)

-    data = torch.rand(batch_size,2,samples,requires_grad=False)
+    data = torch.rand(batch_size, 2, samples, requires_grad=False)
     with torch.no_grad():
         with pytest.raises(TypeError):
             _ = model(data)


-@pytest.mark.parametrize("dataset,channels,loss",
-                         [(pytest.lazy_fixture("vctk_dataset"),1,["mae","mse"])])
-def test_demucs_init(dataset,channels,loss):
+@pytest.mark.parametrize(
+    "dataset,channels,loss",
+    [(pytest.lazy_fixture("vctk_dataset"), 1, ["mae", "mse"])],
+)
+def test_demucs_init(dataset, channels, loss):
     with torch.no_grad():
-        model = WaveUnet(num_channels=channels,dataset=dataset,loss=loss)
+        _ = WaveUnet(num_channels=channels, dataset=dataset, loss=loss)
@@ -4,22 +4,26 @@ import torch
 from enhancer.inference import Inference


-@pytest.mark.parametrize("audio",["tests/data/vctk/clean_testset_wav/p257_166.wav",torch.rand(1,2,48000)])
+@pytest.mark.parametrize(
+    "audio",
+    ["tests/data/vctk/clean_testset_wav/p257_166.wav", torch.rand(1, 2, 48000)],
+)
 def test_read_input(audio):

-    read_audio = Inference.read_input(audio,48000,16000)
-    assert isinstance(read_audio,torch.Tensor)
+    read_audio = Inference.read_input(audio, 48000, 16000)
+    assert isinstance(read_audio, torch.Tensor)
     assert read_audio.shape[0] == 1


 def test_batchify():
-    rand = torch.rand(1,1000)
-    batched_rand = Inference.batchify(rand, window_size = 100, step_size=100)
+    rand = torch.rand(1, 1000)
+    batched_rand = Inference.batchify(rand, window_size=100, step_size=100)
     assert batched_rand.shape[0] == 12


 def test_aggregate():
-    rand = torch.rand(12,1,100)
-    agg_rand = Inference.aggreagate(data=rand,window_size=100,total_frames=1000,step_size=100)
+    rand = torch.rand(12, 1, 100)
+    agg_rand = Inference.aggreagate(
+        data=rand, window_size=100, total_frames=1000, step_size=100
+    )
     assert agg_rand.shape[-1] == 1000
@@ -1,46 +1,50 @@
-from logging import root
+import numpy as np
 import pytest
 import torch
-import numpy as np

-from enhancer.utils.io import Audio
-from enhancer.utils.config import Files
 from enhancer.data.fileprocessor import Fileprocessor
+from enhancer.utils.io import Audio


 def test_io_channel():

-    input_audio = np.random.rand(2,32000)
-    audio = Audio(mono=True,return_tensor=False)
+    input_audio = np.random.rand(2, 32000)
+    audio = Audio(mono=True, return_tensor=False)
     output_audio = audio(input_audio)
     assert output_audio.shape[0] == 1


 def test_io_resampling():

-    input_audio = np.random.rand(1,32000)
-    resampled_audio = Audio.resample_audio(input_audio,16000,8000)
+    input_audio = np.random.rand(1, 32000)
+    resampled_audio = Audio.resample_audio(input_audio, 16000, 8000)

-    input_audio = torch.rand(1,32000)
-    resampled_audio_pt = Audio.resample_audio(input_audio,16000,8000)
+    input_audio = torch.rand(1, 32000)
+    resampled_audio_pt = Audio.resample_audio(input_audio, 16000, 8000)

     assert resampled_audio.shape[1] == resampled_audio_pt.size(1) == 16000


 def test_fileprocessor_vctk():

-    fp = Fileprocessor.from_name("vctk","tests/data/vctk/clean_testset_wav",
-                                 "tests/data/vctk/noisy_testset_wav",48000)
+    fp = Fileprocessor.from_name(
+        "vctk",
+        "tests/data/vctk/clean_testset_wav",
+        "tests/data/vctk/noisy_testset_wav",
+        48000,
+    )
     matching_dict = fp.prepare_matching_dict()
-    assert len(matching_dict)==2
+    assert len(matching_dict) == 2

-@pytest.mark.parametrize("dataset_name",["vctk","dns-2020"])
+
+@pytest.mark.parametrize("dataset_name", ["vctk", "dns-2020"])
 def test_fileprocessor_names(dataset_name):
-    fp = Fileprocessor.from_name(dataset_name,"clean_dir","noisy_dir",16000)
-    assert hasattr(fp.matching_function, '__call__')
+    fp = Fileprocessor.from_name(dataset_name, "clean_dir", "noisy_dir", 16000)
+    assert hasattr(fp.matching_function, "__call__")


 def test_fileprocessor_invaliname():
     with pytest.raises(ValueError):
-        fp = Fileprocessor.from_name("undefined","clean_dir","noisy_dir",16000).prepare_matching_dict()
+        _ = Fileprocessor.from_name(
+            "undefined", "clean_dir", "noisy_dir", 16000
+        ).prepare_matching_dict()