diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..abbbc73 --- /dev/null +++ b/.flake8 @@ -0,0 +1,9 @@ +[flake8] +per-file-ignores = __init__.py:F401 +ignore = E203, E266, E501, W503 +# line length is intentionally set to 80 here because black uses Bugbear +# See https://github.com/psf/black/blob/master/README.md#line-length for more details +max-line-length = 80 +max-complexity = 18 +select = B,C,E,F,W,T4,B9 +exclude = tools/kaldi_decoder diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..807429c --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,43 @@ + +repos: + # # Clean Notebooks + # - repo: https://github.com/kynan/nbstripout + # rev: master + # hooks: + # - id: nbstripout + # Format Code + - repo: https://github.com/ambv/black + rev: 22.8.0 + hooks: + - id: black + + # Sort imports + - repo: https://github.com/PyCQA/isort + rev: 5.10.1 + hooks: + - id: isort + args: ["--profile", "black"] + + - repo: https://gitlab.com/pycqa/flake8 + rev: 5.0.4 + hooks: + - id: flake8 + args: ['--ignore=E203,E501,F811,E712,W503'] + + # Formatting, Whitespace, etc + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v3.2.0 + hooks: + - id: trailing-whitespace + - id: check-added-large-files + args: ['--maxkb=1000'] + - id: check-ast + - id: check-json + - id: check-merge-conflict + - id: check-xml + - id: check-yaml + - id: debug-statements + - id: end-of-file-fixer + - id: requirements-txt-fixer + - id: mixed-line-ending + args: ['--fix=no'] diff --git a/README.md b/README.md index e462afa..97e0c7e 100644 --- a/README.md +++ b/README.md @@ -1 +1,6 @@ -# enhancer \ No newline at end of file +# enhancer +Enhancer is a Pytorch-based opensource toolkit for speech enhancement. It is designed to save time for audio researchers. Is provides easy to use pretrained audio enhancement models and facilitates highly customisable custom model training . Enhancer provides + +* Various pretrained models nicely integrated with huggingface that users can select and use without any hastle. +* Ability to train and validation your own custom speech enhancement models with just under 10 lines of code! +* A command line tool that facilitates training of highly customisable speech enhacement models from the terminal itself! \ No newline at end of file diff --git a/cli/train.py b/cli/train.py deleted file mode 100644 index dee3d2e..0000000 --- a/cli/train.py +++ /dev/null @@ -1,67 +0,0 @@ -from genericpath import isfile -import os -from types import MethodType -import hydra -from hydra.utils import instantiate -from omegaconf import DictConfig -from torch.optim.lr_scheduler import ReduceLROnPlateau -from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping -from pytorch_lightning.loggers import MLFlowLogger -os.environ["HYDRA_FULL_ERROR"] = "1" -JOB_ID = os.environ.get("SLURM_JOBID","0") - -@hydra.main(config_path="train_config",config_name="config") -def main(config: DictConfig): - - callbacks = [] - logger = MLFlowLogger(experiment_name=config.mlflow.experiment_name, - run_name=config.mlflow.run_name, tags={"JOB_ID":JOB_ID}) - - - parameters = config.hyperparameters - - dataset = instantiate(config.dataset) - model = instantiate(config.model,dataset=dataset,lr=parameters.get("lr"), - loss=parameters.get("loss"), metric = parameters.get("metric")) - - direction = model.valid_monitor - checkpoint = ModelCheckpoint( - dirpath="./model",filename=f"model_{JOB_ID}",monitor="val_loss",verbose=True, - mode=direction,every_n_epochs=1 - ) - callbacks.append(checkpoint) - early_stopping = EarlyStopping( - monitor="val_loss", - mode=direction, - min_delta=0.0, - patience=parameters.get("EarlyStopping_patience",10), - strict=True, - verbose=False, - ) - callbacks.append(early_stopping) - - def configure_optimizer(self): - optimizer = instantiate(config.optimizer,lr=parameters.get("lr"),parameters=self.parameters()) - scheduler = ReduceLROnPlateau( - optimizer=optimizer, - mode=direction, - factor=parameters.get("ReduceLr_factor",0.1), - verbose=True, - min_lr=parameters.get("min_lr",1e-6), - patience=parameters.get("ReduceLr_patience",3) - ) - return {"optimizer":optimizer, "lr_scheduler":scheduler} - - model.configure_parameters = MethodType(configure_optimizer,model) - - trainer = instantiate(config.trainer,logger=logger,callbacks=callbacks) - trainer.fit(model) - - saved_location = os.path.join(trainer.default_root_dir,"model",f"model_{JOB_ID}.ckpt") - if os.path.isfile(saved_location): - logger.experiment.log_artifact(logger.run_id,saved_location) - - - -if __name__=="__main__": - main() \ No newline at end of file diff --git a/enhancer/__init__.py b/enhancer/__init__.py index b3c06d4..f102a9c 100644 --- a/enhancer/__init__.py +++ b/enhancer/__init__.py @@ -1 +1 @@ -__version__ = "0.0.1" \ No newline at end of file +__version__ = "0.0.1" diff --git a/enhancer/cli/train.py b/enhancer/cli/train.py new file mode 100644 index 0000000..cb3c7c1 --- /dev/null +++ b/enhancer/cli/train.py @@ -0,0 +1,85 @@ +import os +from types import MethodType + +import hydra +from hydra.utils import instantiate +from omegaconf import DictConfig +from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint +from pytorch_lightning.loggers import MLFlowLogger +from torch.optim.lr_scheduler import ReduceLROnPlateau + +os.environ["HYDRA_FULL_ERROR"] = "1" +JOB_ID = os.environ.get("SLURM_JOBID", "0") + + +@hydra.main(config_path="train_config", config_name="config") +def main(config: DictConfig): + + callbacks = [] + logger = MLFlowLogger( + experiment_name=config.mlflow.experiment_name, + run_name=config.mlflow.run_name, + tags={"JOB_ID": JOB_ID}, + ) + + parameters = config.hyperparameters + + dataset = instantiate(config.dataset) + model = instantiate( + config.model, + dataset=dataset, + lr=parameters.get("lr"), + loss=parameters.get("loss"), + metric=parameters.get("metric"), + ) + + direction = model.valid_monitor + checkpoint = ModelCheckpoint( + dirpath="./model", + filename=f"model_{JOB_ID}", + monitor="val_loss", + verbose=True, + mode=direction, + every_n_epochs=1, + ) + callbacks.append(checkpoint) + early_stopping = EarlyStopping( + monitor="val_loss", + mode=direction, + min_delta=0.0, + patience=parameters.get("EarlyStopping_patience", 10), + strict=True, + verbose=False, + ) + callbacks.append(early_stopping) + + def configure_optimizer(self): + optimizer = instantiate( + config.optimizer, + lr=parameters.get("lr"), + parameters=self.parameters(), + ) + scheduler = ReduceLROnPlateau( + optimizer=optimizer, + mode=direction, + factor=parameters.get("ReduceLr_factor", 0.1), + verbose=True, + min_lr=parameters.get("min_lr", 1e-6), + patience=parameters.get("ReduceLr_patience", 3), + ) + return {"optimizer": optimizer, "lr_scheduler": scheduler} + + model.configure_parameters = MethodType(configure_optimizer, model) + + trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks) + trainer.fit(model) + + saved_location = os.path.join( + trainer.default_root_dir, "model", f"model_{JOB_ID}.ckpt" + ) + if os.path.isfile(saved_location): + logger.experiment.log_artifact(logger.run_id, saved_location) + + +if __name__ == "__main__": + main() diff --git a/cli/train_config/config.yaml b/enhancer/cli/train_config/config.yaml similarity index 83% rename from cli/train_config/config.yaml rename to enhancer/cli/train_config/config.yaml index 61551bd..c0b2cf6 100644 --- a/cli/train_config/config.yaml +++ b/enhancer/cli/train_config/config.yaml @@ -4,4 +4,4 @@ defaults: - optimizer : Adam - hyperparameters : default - trainer : default - - mlflow : experiment \ No newline at end of file + - mlflow : experiment diff --git a/cli/train_config/dataset/DNS-2020.yaml b/enhancer/cli/train_config/dataset/DNS-2020.yaml similarity index 99% rename from cli/train_config/dataset/DNS-2020.yaml rename to enhancer/cli/train_config/dataset/DNS-2020.yaml index f59cb2b..3bd0e67 100644 --- a/cli/train_config/dataset/DNS-2020.yaml +++ b/enhancer/cli/train_config/dataset/DNS-2020.yaml @@ -10,4 +10,3 @@ files: test_clean : clean_test_wav train_noisy : clean_test_wav test_noisy : clean_test_wav - diff --git a/cli/train_config/dataset/Vctk.yaml b/enhancer/cli/train_config/dataset/Vctk.yaml similarity index 99% rename from cli/train_config/dataset/Vctk.yaml rename to enhancer/cli/train_config/dataset/Vctk.yaml index 129d9a8..5c19320 100644 --- a/cli/train_config/dataset/Vctk.yaml +++ b/enhancer/cli/train_config/dataset/Vctk.yaml @@ -10,6 +10,3 @@ files: test_clean : clean_testset_wav train_noisy : noisy_trainset_28spk_wav test_noisy : noisy_testset_wav - - - diff --git a/enhancer/cli/train_config/dataset/Vctk_local.yaml b/enhancer/cli/train_config/dataset/Vctk_local.yaml new file mode 100644 index 0000000..ba44597 --- /dev/null +++ b/enhancer/cli/train_config/dataset/Vctk_local.yaml @@ -0,0 +1,13 @@ +_target_: enhancer.data.dataset.EnhancerDataset +name : vctk +root_dir : /Users/shahules/Myprojects/enhancer/datasets/vctk +duration : 1.0 +sampling_rate: 16000 +batch_size: 64 +num_workers : 0 + +files: + train_clean : clean_testset_wav + test_clean : clean_testset_wav + train_noisy : noisy_testset_wav + test_noisy : noisy_testset_wav diff --git a/cli/train_config/hyperparameters/default.yaml b/enhancer/cli/train_config/hyperparameters/default.yaml similarity index 99% rename from cli/train_config/hyperparameters/default.yaml rename to enhancer/cli/train_config/hyperparameters/default.yaml index 82ac3c2..7e4cda3 100644 --- a/cli/train_config/hyperparameters/default.yaml +++ b/enhancer/cli/train_config/hyperparameters/default.yaml @@ -5,4 +5,3 @@ ReduceLr_patience : 5 ReduceLr_factor : 0.1 min_lr : 0.000001 EarlyStopping_factor : 10 - diff --git a/cli/train_config/mlflow/experiment.yaml b/enhancer/cli/train_config/mlflow/experiment.yaml similarity index 64% rename from cli/train_config/mlflow/experiment.yaml rename to enhancer/cli/train_config/mlflow/experiment.yaml index 2995c60..e8893f6 100644 --- a/cli/train_config/mlflow/experiment.yaml +++ b/enhancer/cli/train_config/mlflow/experiment.yaml @@ -1,2 +1,2 @@ experiment_name : shahules/enhancer -run_name : baseline \ No newline at end of file +run_name : baseline diff --git a/cli/train_config/model/Demucs.yaml b/enhancer/cli/train_config/model/Demucs.yaml similarity index 98% rename from cli/train_config/model/Demucs.yaml rename to enhancer/cli/train_config/model/Demucs.yaml index 1006e71..d91b5ff 100644 --- a/cli/train_config/model/Demucs.yaml +++ b/enhancer/cli/train_config/model/Demucs.yaml @@ -14,5 +14,3 @@ encoder_decoder: lstm: bidirectional: False num_layers: 2 - - diff --git a/cli/train_config/model/WaveUnet.yaml b/enhancer/cli/train_config/model/WaveUnet.yaml similarity index 100% rename from cli/train_config/model/WaveUnet.yaml rename to enhancer/cli/train_config/model/WaveUnet.yaml diff --git a/cli/train_config/optimizer/Adam.yaml b/enhancer/cli/train_config/optimizer/Adam.yaml similarity index 100% rename from cli/train_config/optimizer/Adam.yaml rename to enhancer/cli/train_config/optimizer/Adam.yaml diff --git a/cli/train_config/trainer/default.yaml b/enhancer/cli/train_config/trainer/default.yaml similarity index 100% rename from cli/train_config/trainer/default.yaml rename to enhancer/cli/train_config/trainer/default.yaml diff --git a/cli/train_config/trainer/fastrun_dev.yaml b/enhancer/cli/train_config/trainer/fastrun_dev.yaml similarity index 100% rename from cli/train_config/trainer/fastrun_dev.yaml rename to enhancer/cli/train_config/trainer/fastrun_dev.yaml diff --git a/enhancer/data/__init__.py b/enhancer/data/__init__.py index e69de29..7efd946 100644 --- a/enhancer/data/__init__.py +++ b/enhancer/data/__init__.py @@ -0,0 +1 @@ +from enhancer.data.dataset import EnhancerDataset diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py index 4c485c8..95c73a1 100644 --- a/enhancer/data/dataset.py +++ b/enhancer/data/dataset.py @@ -1,20 +1,21 @@ -import multiprocessing import math +import multiprocessing import os -import pytorch_lightning as pl -from torch.utils.data import IterableDataset, DataLoader, Dataset -import torch.nn.functional as F from typing import Optional +import pytorch_lightning as pl +import torch.nn.functional as F +from torch.utils.data import DataLoader, Dataset, IterableDataset + from enhancer.data.fileprocessor import Fileprocessor -from enhancer.utils.random import create_unique_rng -from enhancer.utils.io import Audio from enhancer.utils import check_files from enhancer.utils.config import Files +from enhancer.utils.io import Audio +from enhancer.utils.random import create_unique_rng + class TrainDataset(IterableDataset): - - def __init__(self,dataset): + def __init__(self, dataset): self.dataset = dataset def __iter__(self): @@ -23,88 +24,102 @@ class TrainDataset(IterableDataset): def __len__(self): return self.dataset.train__len__() + class ValidDataset(Dataset): - - def __init__(self,dataset): + def __init__(self, dataset): self.dataset = dataset - def __getitem__(self,idx): + def __getitem__(self, idx): return self.dataset.val__getitem__(idx) def __len__(self): return self.dataset.val__len__() -class TaskDataset(pl.LightningDataModule): +class TaskDataset(pl.LightningDataModule): def __init__( self, - name:str, - root_dir:str, - files:Files, - duration:float=1.0, - sampling_rate:int=48000, - matching_function = None, + name: str, + root_dir: str, + files: Files, + duration: float = 1.0, + sampling_rate: int = 48000, + matching_function=None, batch_size=32, - num_workers:Optional[int]=None): + num_workers: Optional[int] = None, + ): super().__init__() self.name = name - self.files,self.root_dir = check_files(root_dir,files) + self.files, self.root_dir = check_files(root_dir, files) self.duration = duration self.sampling_rate = sampling_rate self.batch_size = batch_size self.matching_function = matching_function self._validation = [] if num_workers is None: - num_workers = multiprocessing.cpu_count()//2 + num_workers = multiprocessing.cpu_count() // 2 self.num_workers = num_workers def setup(self, stage: Optional[str] = None): - if stage in ("fit",None): + if stage in ("fit", None): - train_clean = os.path.join(self.root_dir,self.files.train_clean) - train_noisy = os.path.join(self.root_dir,self.files.train_noisy) - fp = Fileprocessor.from_name(self.name,train_clean, - train_noisy, self.matching_function) + train_clean = os.path.join(self.root_dir, self.files.train_clean) + train_noisy = os.path.join(self.root_dir, self.files.train_noisy) + fp = Fileprocessor.from_name( + self.name, train_clean, train_noisy, self.matching_function + ) self.train_data = fp.prepare_matching_dict() - - val_clean = os.path.join(self.root_dir,self.files.test_clean) - val_noisy = os.path.join(self.root_dir,self.files.test_noisy) - fp = Fileprocessor.from_name(self.name,val_clean, - val_noisy, self.matching_function) + + val_clean = os.path.join(self.root_dir, self.files.test_clean) + val_noisy = os.path.join(self.root_dir, self.files.test_noisy) + fp = Fileprocessor.from_name( + self.name, val_clean, val_noisy, self.matching_function + ) val_data = fp.prepare_matching_dict() for item in val_data: - clean,noisy,total_dur = item.values() + clean, noisy, total_dur = item.values() if total_dur < self.duration: continue - num_segments = round(total_dur/self.duration) + num_segments = round(total_dur / self.duration) for index in range(num_segments): start_time = index * self.duration - self._validation.append(({"clean":clean,"noisy":noisy}, - start_time)) + self._validation.append( + ({"clean": clean, "noisy": noisy}, start_time) + ) + def train_dataloader(self): - return DataLoader(TrainDataset(self), batch_size = self.batch_size,num_workers=self.num_workers) + return DataLoader( + TrainDataset(self), + batch_size=self.batch_size, + num_workers=self.num_workers, + ) def val_dataloader(self): - return DataLoader(ValidDataset(self), batch_size = self.batch_size,num_workers=self.num_workers) + return DataLoader( + ValidDataset(self), + batch_size=self.batch_size, + num_workers=self.num_workers, + ) + class EnhancerDataset(TaskDataset): """ Dataset object for creating clean-noisy speech enhancement datasets paramters: name : str - name of the dataset + name of the dataset root_dir : str root directory of the dataset containing clean/noisy folders files : Files - dataclass containing train_clean, train_noisy, test_clean, test_noisy - folder names (refer cli/train_config/dataset) + dataclass containing train_clean, train_noisy, test_clean, test_noisy + folder names (refer enhancer.utils.Files dataclass) duration : float expected audio duration of single audio sample for training sampling_rate : int - desired sampling rate + desired sampling rate batch_size : int batch size of each batch num_workers : int @@ -114,71 +129,92 @@ class EnhancerDataset(TaskDataset): use one_to_one mapping for datasets with one noisy file for each clean file use one_to_many mapping for multiple noisy files for each clean file - + """ def __init__( self, - name:str, - root_dir:str, - files:Files, + name: str, + root_dir: str, + files: Files, duration=1.0, sampling_rate=48000, matching_function=None, batch_size=32, - num_workers:Optional[int]=None): - + num_workers: Optional[int] = None, + ): + super().__init__( name=name, root_dir=root_dir, files=files, sampling_rate=sampling_rate, duration=duration, - matching_function = matching_function, + matching_function=matching_function, batch_size=batch_size, - num_workers = num_workers, - + num_workers=num_workers, ) self.sampling_rate = sampling_rate self.files = files - self.duration = max(1.0,duration) - self.audio = Audio(self.sampling_rate,mono=True,return_tensor=True) + self.duration = max(1.0, duration) + self.audio = Audio(self.sampling_rate, mono=True, return_tensor=True) + + def setup(self, stage: Optional[str] = None): - def setup(self, stage:Optional[str]=None): - super().setup(stage=stage) def train__iter__(self): - rng = create_unique_rng(self.model.current_epoch) - + rng = create_unique_rng(self.model.current_epoch) + while True: - file_dict,*_ = rng.choices(self.train_data,k=1, - weights=[file["duration"] for file in self.train_data]) - file_duration = file_dict['duration'] - start_time = round(rng.uniform(0,file_duration- self.duration),2) - data = self.prepare_segment(file_dict,start_time) + file_dict, *_ = rng.choices( + self.train_data, + k=1, + weights=[file["duration"] for file in self.train_data], + ) + file_duration = file_dict["duration"] + start_time = round(rng.uniform(0, file_duration - self.duration), 2) + data = self.prepare_segment(file_dict, start_time) yield data - def val__getitem__(self,idx): + def val__getitem__(self, idx): return self.prepare_segment(*self._validation[idx]) - - def prepare_segment(self,file_dict:dict, start_time:float): - clean_segment = self.audio(file_dict["clean"], - offset=start_time,duration=self.duration) - noisy_segment = self.audio(file_dict["noisy"], - offset=start_time,duration=self.duration) - clean_segment = F.pad(clean_segment,(0,int(self.duration*self.sampling_rate-clean_segment.shape[-1]))) - noisy_segment = F.pad(noisy_segment,(0,int(self.duration*self.sampling_rate-noisy_segment.shape[-1]))) - return {"clean": clean_segment,"noisy":noisy_segment} - + def prepare_segment(self, file_dict: dict, start_time: float): + + clean_segment = self.audio( + file_dict["clean"], offset=start_time, duration=self.duration + ) + noisy_segment = self.audio( + file_dict["noisy"], offset=start_time, duration=self.duration + ) + clean_segment = F.pad( + clean_segment, + ( + 0, + int( + self.duration * self.sampling_rate - clean_segment.shape[-1] + ), + ), + ) + noisy_segment = F.pad( + noisy_segment, + ( + 0, + int( + self.duration * self.sampling_rate - noisy_segment.shape[-1] + ), + ), + ) + return {"clean": clean_segment, "noisy": noisy_segment} + def train__len__(self): - return math.ceil(sum([file["duration"] for file in self.train_data])/self.duration) + return math.ceil( + sum([file["duration"] for file in self.train_data]) / self.duration + ) def val__len__(self): return len(self._validation) - - diff --git a/enhancer/data/fileprocessor.py b/enhancer/data/fileprocessor.py index eab41a0..66d4d75 100644 --- a/enhancer/data/fileprocessor.py +++ b/enhancer/data/fileprocessor.py @@ -1,108 +1,118 @@ import glob import os -from re import S + import numpy as np from scipy.io import wavfile -MATCHING_FNS = ("one_to_one","one_to_many") +MATCHING_FNS = ("one_to_one", "one_to_many") + class ProcessorFunctions: + """ + Preprocessing methods for different types of speech enhacement datasets. + """ @staticmethod - def one_to_one(clean_path,noisy_path): + def one_to_one(clean_path, noisy_path): """ One clean audio can have only one noisy audio file """ matching_wavfiles = list() - clean_filenames = [file.split('/')[-1] for file in glob.glob(os.path.join(clean_path,"*.wav"))] - noisy_filenames = [file.split('/')[-1] for file in glob.glob(os.path.join(noisy_path,"*.wav"))] - common_filenames = np.intersect1d(noisy_filenames,clean_filenames) + clean_filenames = [ + file.split("/")[-1] + for file in glob.glob(os.path.join(clean_path, "*.wav")) + ] + noisy_filenames = [ + file.split("/")[-1] + for file in glob.glob(os.path.join(noisy_path, "*.wav")) + ] + common_filenames = np.intersect1d(noisy_filenames, clean_filenames) for file_name in common_filenames: - sr_clean, clean_file = wavfile.read(os.path.join(clean_path,file_name)) - sr_noisy, noisy_file = wavfile.read(os.path.join(noisy_path,file_name)) - if ((clean_file.shape[-1]==noisy_file.shape[-1]) and - (sr_clean==sr_noisy)): + sr_clean, clean_file = wavfile.read( + os.path.join(clean_path, file_name) + ) + sr_noisy, noisy_file = wavfile.read( + os.path.join(noisy_path, file_name) + ) + if (clean_file.shape[-1] == noisy_file.shape[-1]) and ( + sr_clean == sr_noisy + ): matching_wavfiles.append( - {"clean":os.path.join(clean_path,file_name),"noisy":os.path.join(noisy_path,file_name), - "duration":clean_file.shape[-1]/sr_clean} - ) + { + "clean": os.path.join(clean_path, file_name), + "noisy": os.path.join(noisy_path, file_name), + "duration": clean_file.shape[-1] / sr_clean, + } + ) return matching_wavfiles @staticmethod - def one_to_many(clean_path,noisy_path): + def one_to_many(clean_path, noisy_path): """ One clean audio have multiple noisy audio files """ - + matching_wavfiles = dict() - clean_filenames = [file.split('/')[-1] for file in glob.glob(os.path.join(clean_path,"*.wav"))] + clean_filenames = [ + file.split("/")[-1] + for file in glob.glob(os.path.join(clean_path, "*.wav")) + ] for clean_file in clean_filenames: - noisy_filenames = glob.glob(os.path.join(noisy_path,f"*_{clean_file}.wav")) + noisy_filenames = glob.glob( + os.path.join(noisy_path, f"*_{clean_file}.wav") + ) for noisy_file in noisy_filenames: - sr_clean, clean_file = wavfile.read(os.path.join(clean_path,clean_file)) + sr_clean, clean_file = wavfile.read( + os.path.join(clean_path, clean_file) + ) sr_noisy, noisy_file = wavfile.read(noisy_file) - if ((clean_file.shape[-1]==noisy_file.shape[-1]) and - (sr_clean==sr_noisy)): + if (clean_file.shape[-1] == noisy_file.shape[-1]) and ( + sr_clean == sr_noisy + ): matching_wavfiles.update( - {"clean":os.path.join(clean_path,clean_file),"noisy":noisy_file, - "duration":clean_file.shape[-1]/sr_clean} - ) + { + "clean": os.path.join(clean_path, clean_file), + "noisy": noisy_file, + "duration": clean_file.shape[-1] / sr_clean, + } + ) return matching_wavfiles class Fileprocessor: - - def __init__( - self, - clean_dir, - noisy_dir, - matching_function = None - ): + def __init__(self, clean_dir, noisy_dir, matching_function=None): self.clean_dir = clean_dir self.noisy_dir = noisy_dir self.matching_function = matching_function @classmethod - def from_name(cls, - name:str, - clean_dir, - noisy_dir, - matching_function=None - ): + def from_name(cls, name: str, clean_dir, noisy_dir, matching_function=None): if matching_function is None: if name.lower() == "vctk": - return cls(clean_dir,noisy_dir, ProcessorFunctions.one_to_one) + return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_one) elif name.lower() == "dns-2020": - return cls(clean_dir,noisy_dir, ProcessorFunctions.one_to_many) + return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_many) else: if matching_function not in MATCHING_FNS: - raise ValueError(F"Invalid matching function! Avaialble options are {MATCHING_FNS}") + raise ValueError( + f"Invalid matching function! Avaialble options are {MATCHING_FNS}" + ) else: - return cls(clean_dir,noisy_dir, getattr(ProcessorFunctions,matching_function)) - - + return cls( + clean_dir, + noisy_dir, + getattr(ProcessorFunctions, matching_function), + ) def prepare_matching_dict(self): if self.matching_function is None: raise ValueError("Not a valid matching function") - return self.matching_function(self.clean_dir,self.noisy_dir) - - - - - - - - - - - - + return self.matching_function(self.clean_dir, self.noisy_dir) diff --git a/enhancer/inference.py b/enhancer/inference.py index fd2f57a..ae399f1 100644 --- a/enhancer/inference.py +++ b/enhancer/inference.py @@ -1,119 +1,168 @@ -from json import load -import wave +from pathlib import Path +from typing import Optional, Union + import numpy as np -from scipy.signal import get_window -from scipy.io import wavfile -from typing import List, Optional, Union import torch import torch.nn.functional as F -from pathlib import Path from librosa import load as load_audio +from scipy.io import wavfile +from scipy.signal import get_window from enhancer.utils import Audio + class Inference: + """ + contains methods used for inference. + """ @staticmethod def read_input(audio, sr, model_sr): + """ + read and verify audio input regardless of the input format. + arguments: + audio : audio input + sr : sampling rate of input audio + model_sr : sampling rate used for model training. + """ - if isinstance(audio,(np.ndarray,torch.Tensor)): + if isinstance(audio, (np.ndarray, torch.Tensor)): assert sr is not None, "Invalid sampling rate!" if len(audio.shape) == 1: - audio = audio.reshape(1,-1) + audio = audio.reshape(1, -1) - if isinstance(audio,str): + if isinstance(audio, str): audio = Path(audio) if not audio.is_file(): raise ValueError(f"Input file {audio} does not exist") else: - audio,sr = load_audio(audio,sr=sr,) + audio, sr = load_audio( + audio, + sr=sr, + ) if len(audio.shape) == 1: - audio = audio.reshape(1,-1) + audio = audio.reshape(1, -1) else: - assert audio.shape[0] == 1, "Enhance inference only supports single waveform" + assert ( + audio.shape[0] == 1 + ), "Enhance inference only supports single waveform" - waveform = Audio.resample_audio(audio,sr=sr,target_sr=model_sr) + waveform = Audio.resample_audio(audio, sr=sr, target_sr=model_sr) waveform = Audio.convert_mono(waveform) - if isinstance(waveform,np.ndarray): + if isinstance(waveform, np.ndarray): waveform = torch.from_numpy(waveform) return waveform @staticmethod - def batchify(waveform: torch.Tensor, window_size:int, step_size:Optional[int]=None): + def batchify( + waveform: torch.Tensor, + window_size: int, + step_size: Optional[int] = None, + ): """ - break input waveform into samples with duration specified. + break input waveform into samples with duration specified.(Overlap-add) + arguments: + waveform : audio waveform + window_size : window size used for splitting waveform into batches + step_size : step_size used for splitting waveform into batches """ - assert waveform.ndim == 2, f"Expcted input waveform with 2 dimensions (channels,samples), got {waveform.ndim}" - _,num_samples = waveform.shape + assert ( + waveform.ndim == 2 + ), f"Expcted input waveform with 2 dimensions (channels,samples), got {waveform.ndim}" + _, num_samples = waveform.shape waveform = waveform.unsqueeze(-1) - step_size = window_size//2 if step_size is None else step_size + step_size = window_size // 2 if step_size is None else step_size if num_samples >= window_size: - waveform_batch = F.unfold(waveform[None,...], kernel_size=(window_size,1), - stride=(step_size,1), padding=(window_size,0)) - waveform_batch = waveform_batch.permute(2,0,1) - - + waveform_batch = F.unfold( + waveform[None, ...], + kernel_size=(window_size, 1), + stride=(step_size, 1), + padding=(window_size, 0), + ) + waveform_batch = waveform_batch.permute(2, 0, 1) + return waveform_batch @staticmethod - def aggreagate(data:torch.Tensor,window_size:int,total_frames:int,step_size:Optional[int]=None, - window="hanning",): + def aggreagate( + data: torch.Tensor, + window_size: int, + total_frames: int, + step_size: Optional[int] = None, + window="hanning", + ): """ - takes input as tensor outputs aggregated waveform + stitch batched waveform into single waveform. (Overlap-add) + arguments: + data: batched waveform + window_size : window_size used to batch waveform + step_size : step_size used to batch waveform + total_frames : total number of frames present in original waveform + window : type of window used for overlap-add mechanism. """ - num_chunks,n_channels,num_frames = data.shape - window = get_window(window=window,Nx=data.shape[-1]) + num_chunks, n_channels, num_frames = data.shape + window = get_window(window=window, Nx=data.shape[-1]) window = torch.from_numpy(window).to(data.device) data *= window - step_size = window_size//2 if step_size is None else step_size + step_size = window_size // 2 if step_size is None else step_size + data = data.permute(1, 2, 0) + data = F.fold( + data, + (total_frames, 1), + kernel_size=(window_size, 1), + stride=(step_size, 1), + padding=(window_size, 0), + ).squeeze(-1) - data = data.permute(1,2,0) - data = F.fold(data, - (total_frames,1), - kernel_size=(window_size,1), - stride=(step_size,1), - padding=(window_size,0)).squeeze(-1) - - return data.reshape(1,n_channels,-1) + return data.reshape(1, n_channels, -1) @staticmethod - def write_output(waveform:torch.Tensor,filename:Union[str,Path],sr:int): + def write_output( + waveform: torch.Tensor, filename: Union[str, Path], sr: int + ): + """ + write audio output as wav file + arguments: + waveform : audio waveform + filename : name of the wave file. Output will be written as cleaned_filename.wav + sr : sampling rate + """ - if isinstance(filename,str): + if isinstance(filename, str): filename = Path(filename) - parent, name = filename.parent, "cleaned_"+filename.name - filename = parent/Path(name) + parent, name = filename.parent, "cleaned_" + filename.name + filename = parent / Path(name) if filename.is_file(): raise FileExistsError(f"file {filename} already exists") else: - if isinstance(waveform,torch.Tensor): - waveform = waveform.detach().cpu().squeeze().numpy() - wavfile.write(filename,rate=sr,data=waveform) + wavfile.write(filename, rate=sr, data=waveform.detach().cpu()) @staticmethod - def prepare_output(waveform:torch.Tensor, model_sampling_rate:int, - audio:Union[str,np.ndarray,torch.Tensor], sampling_rate:Optional[int] + def prepare_output( + waveform: torch.Tensor, + model_sampling_rate: int, + audio: Union[str, np.ndarray, torch.Tensor], + sampling_rate: Optional[int], ): - if isinstance(audio,np.ndarray): + """ + prepare output audio based on input format + arguments: + waveform : predicted audio waveform + model_sampling_rate : sampling rate used to train the model + audio : input audio + sampling_rate : input audio sampling rate + + """ + if isinstance(audio, np.ndarray): waveform = waveform.detach().cpu().numpy() - if sampling_rate!=None: - waveform = Audio.resample_audio(waveform, sr=model_sampling_rate, target_sr=sampling_rate) + if sampling_rate is not None: + waveform = Audio.resample_audio( + waveform, sr=model_sampling_rate, target_sr=sampling_rate + ) - return waveform - - - - - - - - - - - - \ No newline at end of file + return waveform diff --git a/enhancer/loss.py b/enhancer/loss.py index ef33161..db1d222 100644 --- a/enhancer/loss.py +++ b/enhancer/loss.py @@ -3,62 +3,82 @@ import torch.nn as nn class mean_squared_error(nn.Module): + """ + Mean squared error / L1 loss + """ - def __init__(self,reduction="mean"): + def __init__(self, reduction="mean"): super().__init__() self.loss_fun = nn.MSELoss(reduction=reduction) self.higher_better = False - def forward(self,prediction:torch.Tensor, target: torch.Tensor): + def forward(self, prediction: torch.Tensor, target: torch.Tensor): if prediction.size() != target.size() or target.ndim < 3: - raise TypeError(f"""Inputs must be of the same shape (batch_size,channels,samples) - got {prediction.size()} and {target.size()} instead""") + raise TypeError( + f"""Inputs must be of the same shape (batch_size,channels,samples) + got {prediction.size()} and {target.size()} instead""" + ) return self.loss_fun(prediction, target) -class mean_absolute_error(nn.Module): - def __init__(self,reduction="mean"): +class mean_absolute_error(nn.Module): + """ + Mean absolute error / L2 loss + """ + + def __init__(self, reduction="mean"): super().__init__() self.loss_fun = nn.L1Loss(reduction=reduction) self.higher_better = False - def forward(self, prediction:torch.Tensor, target: torch.Tensor): + def forward(self, prediction: torch.Tensor, target: torch.Tensor): if prediction.size() != target.size() or target.ndim < 3: - raise TypeError(f"""Inputs must be of the same shape (batch_size,channels,samples) - got {prediction.size()} and {target.size()} instead""") + raise TypeError( + f"""Inputs must be of the same shape (batch_size,channels,samples) + got {prediction.size()} and {target.size()} instead""" + ) return self.loss_fun(prediction, target) -class Si_SDR(nn.Module): - def __init__( - self, - reduction:str="mean" - ): +class Si_SDR(nn.Module): + """ + SI-SDR metric based on SDR – HALF-BAKED OR WELL DONE?(https://arxiv.org/pdf/1811.02508.pdf) + """ + + def __init__(self, reduction: str = "mean"): super().__init__() - if reduction in ["sum","mean",None]: + if reduction in ["sum", "mean", None]: self.reduction = reduction else: - raise TypeError("Invalid reduction, valid options are sum, mean, None") + raise TypeError( + "Invalid reduction, valid options are sum, mean, None" + ) self.higher_better = False - def forward(self,prediction:torch.Tensor, target:torch.Tensor): + def forward(self, prediction: torch.Tensor, target: torch.Tensor): if prediction.size() != target.size() or target.ndim < 3: - raise TypeError(f"""Inputs must be of the same shape (batch_size,channels,samples) - got {prediction.size()} and {target.size()} instead""") - - target_energy = torch.sum(target**2,keepdim=True,dim=-1) - scaling_factor = torch.sum(prediction*target,keepdim=True,dim=-1) / target_energy + raise TypeError( + f"""Inputs must be of the same shape (batch_size,channels,samples) + got {prediction.size()} and {target.size()} instead""" + ) + + target_energy = torch.sum(target**2, keepdim=True, dim=-1) + scaling_factor = ( + torch.sum(prediction * target, keepdim=True, dim=-1) / target_energy + ) target_projection = target * scaling_factor noise = prediction - target_projection - ratio = torch.sum(target_projection**2,dim=-1) / torch.sum(noise**2,dim=-1) - si_sdr = 10*torch.log10(ratio).mean(dim=-1) + ratio = torch.sum(target_projection**2, dim=-1) / torch.sum( + noise**2, dim=-1 + ) + si_sdr = 10 * torch.log10(ratio).mean(dim=-1) if self.reduction == "sum": si_sdr = si_sdr.sum() @@ -66,46 +86,55 @@ class Si_SDR(nn.Module): si_sdr = si_sdr.mean() else: pass - + return si_sdr - class Avergeloss(nn.Module): + """ + Combine multiple metics of same nature. + for example, ["mea","mae"] + parameters: + losses : loss function names to be combined + """ - def __init__(self,losses): + def __init__(self, losses): super().__init__() self.valid_losses = nn.ModuleList() - - direction = [getattr(LOSS_MAP[loss](),"higher_better") for loss in losses] + + direction = [ + getattr(LOSS_MAP[loss](), "higher_better") for loss in losses + ] if len(set(direction)) > 1: - raise ValueError("all cost functions should be of same nature, maximize or minimize!") + raise ValueError( + "all cost functions should be of same nature, maximize or minimize!" + ) self.higher_better = direction[0] for loss in losses: loss = self.validate_loss(loss) self.valid_losses.append(loss()) - - def validate_loss(self,loss:str): + def validate_loss(self, loss: str): if loss not in LOSS_MAP.keys(): - raise ValueError(f"Invalid loss function {loss}, available loss functions are {tuple([loss for loss in LOSS_MAP.keys()])}") + raise ValueError( + f"""Invalid loss function {loss}, available loss functions are + {tuple([loss for loss in LOSS_MAP.keys()])}""" + ) else: return LOSS_MAP[loss] - def forward(self,prediction:torch.Tensor, target:torch.Tensor): + def forward(self, prediction: torch.Tensor, target: torch.Tensor): loss = 0.0 for loss_fun in self.valid_losses: loss += loss_fun(prediction, target) - + return loss - - - -LOSS_MAP = {"mae":mean_absolute_error, - "mse": mean_squared_error, - "SI-SDR":Si_SDR} - +LOSS_MAP = { + "mae": mean_absolute_error, + "mse": mean_squared_error, + "SI-SDR": Si_SDR, +} diff --git a/enhancer/models/__init__.py b/enhancer/models/__init__.py index 534a608..2d97568 100644 --- a/enhancer/models/__init__.py +++ b/enhancer/models/__init__.py @@ -1,3 +1,3 @@ from enhancer.models.demucs import Demucs +from enhancer.models.model import Model from enhancer.models.waveunet import WaveUnet -from enhancer.models.model import Model \ No newline at end of file diff --git a/enhancer/models/demucs.py b/enhancer/models/demucs.py index 7c9d8ff..65f119d 100644 --- a/enhancer/models/demucs.py +++ b/enhancer/models/demucs.py @@ -1,217 +1,264 @@ import logging -from typing import Optional, Union, List -from torch import nn -import torch.nn.functional as F import math +from typing import List, Optional, Union + +import torch.nn.functional as F +from torch import nn -from enhancer.models.model import Model from enhancer.data.dataset import EnhancerDataset +from enhancer.models.model import Model from enhancer.utils.io import Audio as audio from enhancer.utils.utils import merge_dict + class DemucsLSTM(nn.Module): def __init__( self, - input_size:int, - hidden_size:int, - num_layers:int, - bidirectional:bool=True - + input_size: int, + hidden_size: int, + num_layers: int, + bidirectional: bool = True, ): super().__init__() - self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bidirectional=bidirectional) + self.lstm = nn.LSTM( + input_size, hidden_size, num_layers, bidirectional=bidirectional + ) dim = 2 if bidirectional else 1 - self.linear = nn.Linear(dim*hidden_size,hidden_size) + self.linear = nn.Linear(dim * hidden_size, hidden_size) - def forward(self,x): + def forward(self, x): - output,(h,c) = self.lstm(x) + output, (h, c) = self.lstm(x) output = self.linear(output) - return output,(h,c) + return output, (h, c) class DemucsEncoder(nn.Module): - def __init__( self, - num_channels:int, - hidden_size:int, - kernel_size:int, - stride:int=1, - glu:bool=False, + num_channels: int, + hidden_size: int, + kernel_size: int, + stride: int = 1, + glu: bool = False, ): super().__init__() activation = nn.GLU(1) if glu else nn.ReLU() multi_factor = 2 if glu else 1 self.encoder = nn.Sequential( - nn.Conv1d(num_channels,hidden_size,kernel_size,stride), + nn.Conv1d(num_channels, hidden_size, kernel_size, stride), nn.ReLU(), - nn.Conv1d(hidden_size, hidden_size*multi_factor,kernel_size,1), - activation + nn.Conv1d(hidden_size, hidden_size * multi_factor, kernel_size, 1), + activation, ) - def forward(self,waveform): - + def forward(self, waveform): + return self.encoder(waveform) -class DemucsDecoder(nn.Module): +class DemucsDecoder(nn.Module): def __init__( self, - num_channels:int, - hidden_size:int, - kernel_size:int, - stride:int=1, - glu:bool=False, - layer:int=0 + num_channels: int, + hidden_size: int, + kernel_size: int, + stride: int = 1, + glu: bool = False, + layer: int = 0, ): super().__init__() activation = nn.GLU(1) if glu else nn.ReLU() multi_factor = 2 if glu else 1 self.decoder = nn.Sequential( - nn.Conv1d(hidden_size,hidden_size*multi_factor,kernel_size,1), + nn.Conv1d(hidden_size, hidden_size * multi_factor, kernel_size, 1), activation, - nn.ConvTranspose1d(hidden_size,num_channels,kernel_size,stride) + nn.ConvTranspose1d(hidden_size, num_channels, kernel_size, stride), ) - if layer>0: + if layer > 0: self.decoder.add_module("4", nn.ReLU()) - def forward(self,waveform,): + def forward( + self, + waveform, + ): out = self.decoder(waveform) return out class Demucs(Model): + """ + Demucs model from https://arxiv.org/pdf/1911.13254.pdf + parameters: + encoder_decoder: dict, optional + keyword arguments passsed to encoder decoder block + lstm : dict, optional + keyword arguments passsed to LSTM block + num_channels: int, defaults to 1 + number channels in input audio + sampling_rate: int, defaults to 16KHz + sampling rate of input audio + lr : float, defaults to 1e-3 + learning rate used for training + dataset: EnhancerDataset, optional + EnhancerDataset object containing train/validation data for training + duration : float, optional + chunk duration in seconds + loss : string or List of strings + loss function to be used, available ("mse","mae","SI-SDR") + metric : string or List of strings + metric function to be used, available ("mse","mae","SI-SDR") + + """ ED_DEFAULTS = { - "initial_output_channels":48, - "kernel_size":8, - "stride":1, - "depth":5, - "glu":True, - "growth_factor":2, + "initial_output_channels": 48, + "kernel_size": 8, + "stride": 1, + "depth": 5, + "glu": True, + "growth_factor": 2, } LSTM_DEFAULTS = { - "bidirectional":True, - "num_layers":2, + "bidirectional": True, + "num_layers": 2, } - + def __init__( self, - encoder_decoder:Optional[dict]=None, - lstm:Optional[dict]=None, - num_channels:int=1, - resample:int=4, - sampling_rate = 16000, - lr:float=1e-3, - dataset:Optional[EnhancerDataset]=None, - loss:Union[str, List] = "mse", - metric:Union[str, List] = "mse" - - + encoder_decoder: Optional[dict] = None, + lstm: Optional[dict] = None, + num_channels: int = 1, + resample: int = 4, + sampling_rate=16000, + lr: float = 1e-3, + dataset: Optional[EnhancerDataset] = None, + loss: Union[str, List] = "mse", + metric: Union[str, List] = "mse", ): - duration = dataset.duration if isinstance(dataset,EnhancerDataset) else None + duration = ( + dataset.duration if isinstance(dataset, EnhancerDataset) else None + ) if dataset is not None: - if sampling_rate!=dataset.sampling_rate: - logging.warn(f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}") + if sampling_rate != dataset.sampling_rate: + logging.warn( + f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}" + ) sampling_rate = dataset.sampling_rate - super().__init__(num_channels=num_channels, - sampling_rate=sampling_rate,lr=lr, - dataset=dataset,duration=duration,loss=loss, metric=metric) - - encoder_decoder = merge_dict(self.ED_DEFAULTS,encoder_decoder) - lstm = merge_dict(self.LSTM_DEFAULTS,lstm) - self.save_hyperparameters("encoder_decoder","lstm","resample") + super().__init__( + num_channels=num_channels, + sampling_rate=sampling_rate, + lr=lr, + dataset=dataset, + duration=duration, + loss=loss, + metric=metric, + ) + + encoder_decoder = merge_dict(self.ED_DEFAULTS, encoder_decoder) + lstm = merge_dict(self.LSTM_DEFAULTS, lstm) + self.save_hyperparameters("encoder_decoder", "lstm", "resample") hidden = encoder_decoder["initial_output_channels"] self.encoder = nn.ModuleList() self.decoder = nn.ModuleList() for layer in range(encoder_decoder["depth"]): - encoder_layer = DemucsEncoder(num_channels=num_channels, - hidden_size=hidden, - kernel_size=encoder_decoder["kernel_size"], - stride=encoder_decoder["stride"], - glu=encoder_decoder["glu"], - ) + encoder_layer = DemucsEncoder( + num_channels=num_channels, + hidden_size=hidden, + kernel_size=encoder_decoder["kernel_size"], + stride=encoder_decoder["stride"], + glu=encoder_decoder["glu"], + ) self.encoder.append(encoder_layer) - decoder_layer = DemucsDecoder(num_channels=num_channels, - hidden_size=hidden, - kernel_size=encoder_decoder["kernel_size"], - stride=1, - glu=encoder_decoder["glu"], - layer=layer - ) - self.decoder.insert(0,decoder_layer) + decoder_layer = DemucsDecoder( + num_channels=num_channels, + hidden_size=hidden, + kernel_size=encoder_decoder["kernel_size"], + stride=1, + glu=encoder_decoder["glu"], + layer=layer, + ) + self.decoder.insert(0, decoder_layer) num_channels = hidden hidden = self.ED_DEFAULTS["growth_factor"] * hidden - - self.de_lstm = DemucsLSTM(input_size=num_channels, - hidden_size=num_channels, - num_layers=lstm["num_layers"], - bidirectional=lstm["bidirectional"] - ) - def forward(self,waveform): + self.de_lstm = DemucsLSTM( + input_size=num_channels, + hidden_size=num_channels, + num_layers=lstm["num_layers"], + bidirectional=lstm["bidirectional"], + ) + + def forward(self, waveform): if waveform.dim() == 2: waveform = waveform.unsqueeze(1) - if waveform.size(1)!=1: - raise TypeError(f"Demucs can only process mono channel audio, input has {waveform.size(1)} channels") + if waveform.size(1) != 1: + raise TypeError( + f"Demucs can only process mono channel audio, input has {waveform.size(1)} channels" + ) length = waveform.shape[-1] - x = F.pad(waveform, (0,self.get_padding_length(length) - length)) - if self.hparams.resample>1: - x = audio.resample_audio(audio=x, sr=self.hparams.sampling_rate, - target_sr=int(self.hparams.sampling_rate * self.hparams.resample)) - + x = F.pad(waveform, (0, self.get_padding_length(length) - length)) + if self.hparams.resample > 1: + x = audio.resample_audio( + audio=x, + sr=self.hparams.sampling_rate, + target_sr=int( + self.hparams.sampling_rate * self.hparams.resample + ), + ) + encoder_outputs = [] for encoder in self.encoder: x = encoder(x) encoder_outputs.append(x) - x = x.permute(0,2,1) - x,_ = self.de_lstm(x) + x = x.permute(0, 2, 1) + x, _ = self.de_lstm(x) - x = x.permute(0,2,1) + x = x.permute(0, 2, 1) for decoder in self.decoder: skip_connection = encoder_outputs.pop(-1) - x += skip_connection[..., :x.shape[-1]] + x += skip_connection[..., : x.shape[-1]] x = decoder(x) - + if self.hparams.resample > 1: - x = audio.resample_audio(x,int(self.hparams.sampling_rate * self.hparams.resample), - self.hparams.sampling_rate) + x = audio.resample_audio( + x, + int(self.hparams.sampling_rate * self.hparams.resample), + self.hparams.sampling_rate, + ) return x - - def get_padding_length(self,input_length): + + def get_padding_length(self, input_length): input_length = math.ceil(input_length * self.hparams.resample) - - for layer in range(self.hparams.encoder_decoder["depth"]): # encoder operation - input_length = math.ceil((input_length - self.hparams.encoder_decoder["kernel_size"])/self.hparams.encoder_decoder["stride"])+1 - input_length = max(1,input_length) - for layer in range(self.hparams.encoder_decoder["depth"]): # decoder operaration - input_length = (input_length-1) * self.hparams.encoder_decoder["stride"] + self.hparams.encoder_decoder["kernel_size"] - input_length = math.ceil(input_length/self.hparams.resample) + for layer in range( + self.hparams.encoder_decoder["depth"] + ): # encoder operation + input_length = ( + math.ceil( + (input_length - self.hparams.encoder_decoder["kernel_size"]) + / self.hparams.encoder_decoder["stride"] + ) + + 1 + ) + input_length = max(1, input_length) + for layer in range( + self.hparams.encoder_decoder["depth"] + ): # decoder operaration + input_length = (input_length - 1) * self.hparams.encoder_decoder[ + "stride" + ] + self.hparams.encoder_decoder["kernel_size"] + input_length = math.ceil(input_length / self.hparams.resample) return int(input_length) - - - - - - - - - - - - - \ No newline at end of file diff --git a/enhancer/models/model.py b/enhancer/models/model.py index de2edab..39dbe80 100644 --- a/enhancer/models/model.py +++ b/enhancer/models/model.py @@ -1,92 +1,108 @@ -try: - from functools import cached_property -except ImportError: - from backports.cached_property import cached_property -from importlib import import_module -from huggingface_hub import cached_download, hf_hub_url -import logging -import numpy as np import os -from typing import Optional, Union, List, Text, Dict, Any -from torch.optim import Adam -import torch -from torch.nn.functional import pad -import pytorch_lightning as pl -from pytorch_lightning.utilities.cloud_io import load as pl_load -from urllib.parse import urlparse +from importlib import import_module from pathlib import Path +from typing import Any, Dict, List, Optional, Text, Union +from urllib.parse import urlparse +import numpy as np +import pytorch_lightning as pl +import torch +from huggingface_hub import cached_download, hf_hub_url +from pytorch_lightning.utilities.cloud_io import load as pl_load +from torch.optim import Adam from enhancer import __version__ from enhancer.data.dataset import EnhancerDataset -from enhancer.utils.io import Audio -from enhancer.loss import Avergeloss from enhancer.inference import Inference +from enhancer.loss import Avergeloss CACHE_DIR = "" HF_TORCH_WEIGHTS = "" DEFAULT_DEVICE = "cpu" + class Model(pl.LightningModule): + """ + Base class for all models + parameters: + num_channels: int, default to 1 + number of channels in input audio + sampling_rate : int, default 16khz + audio sampling rate + lr: float, optional + learning rate for model training + dataset: EnhancerDataset, optional + Enhancer dataset used for training/validation + duration: float, optional + duration used for training/inference + loss : string or List of strings, default to "mse" + loss functions to be used. Available ("mse","mae","Si-SDR") + + """ def __init__( self, - num_channels:int=1, - sampling_rate:int=16000, - lr:float=1e-3, - dataset:Optional[EnhancerDataset]=None, - duration:Optional[float]=None, + num_channels: int = 1, + sampling_rate: int = 16000, + lr: float = 1e-3, + dataset: Optional[EnhancerDataset] = None, + duration: Optional[float] = None, loss: Union[str, List] = "mse", - metric:Union[str,List] = "mse" + metric: Union[str, List] = "mse", ): super().__init__() - assert num_channels ==1 , "Enhancer only support for mono channel models" + assert ( + num_channels == 1 + ), "Enhancer only support for mono channel models" self.dataset = dataset - self.save_hyperparameters("num_channels","sampling_rate","lr","loss","metric","duration") + self.save_hyperparameters( + "num_channels", "sampling_rate", "lr", "loss", "metric", "duration" + ) if self.logger: - self.logger.experiment.log_dict(dict(self.hparams),"hyperparameters.json") - + self.logger.experiment.log_dict( + dict(self.hparams), "hyperparameters.json" + ) + self.loss = loss self.metric = metric @property def loss(self): return self._loss - - @loss.setter - def loss(self,loss): - if isinstance(loss,str): - losses = [loss] - + @loss.setter + def loss(self, loss): + + if isinstance(loss, str): + losses = [loss] + self._loss = Avergeloss(losses) @property def metric(self): return self._metric - + @metric.setter - def metric(self,metric): + def metric(self, metric): + + if isinstance(metric, str): + metric = [metric] - if isinstance(metric,str): - metric = [metric] - self._metric = Avergeloss(metric) - @property def dataset(self): return self._dataset @dataset.setter - def dataset(self,dataset): + def dataset(self, dataset): self._dataset = dataset - def setup(self,stage:Optional[str]=None): + def setup(self, stage: Optional[str] = None): if stage == "fit": self.dataset.setup(stage) self.dataset.model = self - + def train_dataloader(self): return self.dataset.train_dataloader() @@ -94,9 +110,9 @@ class Model(pl.LightningModule): return self.dataset.val_dataloader() def configure_optimizers(self): - return Adam(self.parameters(), lr = self.hparams.lr) + return Adam(self.parameters(), lr=self.hparams.lr) - def training_step(self,batch, batch_idx:int): + def training_step(self, batch, batch_idx: int): mixed_waveform = batch["noisy"] target = batch["clean"] @@ -105,13 +121,16 @@ class Model(pl.LightningModule): loss = self.loss(prediction, target) if self.logger: - self.logger.experiment.log_metric(run_id=self.logger.run_id, - key="train_loss", value=loss.item(), - step=self.global_step) - self.log("train_loss",loss.item()) - return {"loss":loss} + self.logger.experiment.log_metric( + run_id=self.logger.run_id, + key="train_loss", + value=loss.item(), + step=self.global_step, + ) + self.log("train_loss", loss.item()) + return {"loss": loss} - def validation_step(self,batch,batch_idx:int): + def validation_step(self, batch, batch_idx: int): mixed_waveform = batch["noisy"] target = batch["clean"] @@ -119,48 +138,92 @@ class Model(pl.LightningModule): metric_val = self.metric(prediction, target) loss_val = self.loss(prediction, target) - self.log("val_metric",metric_val.item()) - self.log("val_loss",loss_val.item()) + self.log("val_metric", metric_val.item()) + self.log("val_loss", loss_val.item()) if self.logger: - self.logger.experiment.log_metric(run_id=self.logger.run_id, - key="val_loss",value=loss_val.item(), - step=self.global_step) - self.logger.experiment.log_metric(run_id=self.logger.run_id, - key="val_metric",value=metric_val.item(), - step=self.global_step) + self.logger.experiment.log_metric( + run_id=self.logger.run_id, + key="val_loss", + value=loss_val.item(), + step=self.global_step, + ) + self.logger.experiment.log_metric( + run_id=self.logger.run_id, + key="val_metric", + value=metric_val.item(), + step=self.global_step, + ) - return {"loss":loss_val} + return {"loss": loss_val} def on_save_checkpoint(self, checkpoint): checkpoint["enhancer"] = { - "version": { - "enhancer":__version__, - "pytorch":torch.__version__ + "version": {"enhancer": __version__, "pytorch": torch.__version__}, + "architecture": { + "module": self.__class__.__module__, + "class": self.__class__.__name__, }, - "architecture":{ - "module":self.__class__.__module__, - "class":self.__class__.__name__ - } - } def on_load_checkpoint(self, checkpoint: Dict[str, Any]): pass - @classmethod def from_pretrained( cls, checkpoint: Union[Path, Text], - map_location = None, + map_location=None, hparams_file: Union[Path, Text] = None, strict: bool = True, use_auth_token: Union[Text, None] = None, - cached_dir: Union[Path, Text]=CACHE_DIR, - **kwargs + cached_dir: Union[Path, Text] = CACHE_DIR, + **kwargs, ): + """ + Load Pretrained model + + parameters: + checkpoint : Path or str + Path to checkpoint, or a remote URL, or a model identifier from + the huggingface.co model hub. + map_location: optional + Same role as in torch.load(). + Defaults to `lambda storage, loc: storage`. + hparams_file : Path or str, optional + Path to a .yaml file with hierarchical structure as in this example: + drop_prob: 0.2 + dataloader: + batch_size: 32 + You most likely won’t need this since Lightning will always save the + hyperparameters to the checkpoint. However, if your checkpoint weights + do not have the hyperparameters saved, use this method to pass in a .yaml + file with the hparams you would like to use. These will be converted + into a dict and passed into your Model for use. + strict : bool, optional + Whether to strictly enforce that the keys in checkpoint match + the keys returned by this module’s state dict. Defaults to True. + use_auth_token : str, optional + When loading a private huggingface.co model, set `use_auth_token` + to True or to a string containing your hugginface.co authentication + token that can be obtained by running `huggingface-cli login` + cache_dir: Path or str, optional + Path to model cache directory. Defaults to content of PYANNOTE_CACHE + environment variable, or "~/.cache/torch/pyannote" when unset. + kwargs: optional + Any extra keyword args needed to init the model. + Can also be used to override saved hyperparameter values. + + Returns + ------- + model : Model + Model + + See also + -------- + torch.load + """ checkpoint = str(checkpoint) if hparams_file is not None: @@ -168,104 +231,133 @@ class Model(pl.LightningModule): if os.path.isfile(checkpoint): model_path_pl = checkpoint - elif urlparse(checkpoint).scheme in ("http","https"): + elif urlparse(checkpoint).scheme in ("http", "https"): model_path_pl = checkpoint else: - + if "@" in checkpoint: model_id = checkpoint.split("@")[0] revision_id = checkpoint.split("@")[1] else: model_id = checkpoint revision_id = None - + url = hf_hub_url( - model_id,filename=HF_TORCH_WEIGHTS,revision=revision_id + model_id, filename=HF_TORCH_WEIGHTS, revision=revision_id ) model_path_pl = cached_download( - url=url,library_name="enhancer",library_version=__version__, - cache_dir=cached_dir,use_auth_token=use_auth_token + url=url, + library_name="enhancer", + library_version=__version__, + cache_dir=cached_dir, + use_auth_token=use_auth_token, ) if map_location is None: map_location = torch.device(DEFAULT_DEVICE) - loaded_checkpoint = pl_load(model_path_pl,map_location) + loaded_checkpoint = pl_load(model_path_pl, map_location) module_name = loaded_checkpoint["enhancer"]["architecture"]["module"] - class_name = loaded_checkpoint["enhancer"]["architecture"]["class"] + class_name = loaded_checkpoint["enhancer"]["architecture"]["class"] module = import_module(module_name) Klass = getattr(module, class_name) try: model = Klass.load_from_checkpoint( - checkpoint_path = model_path_pl, - map_location = map_location, - hparams_file = hparams_file, - strict = strict, - **kwargs + checkpoint_path=model_path_pl, + map_location=map_location, + hparams_file=hparams_file, + strict=strict, + **kwargs, ) except Exception as e: print(e) + return model - return model + def infer(self, batch: torch.Tensor, batch_size: int = 32): + """ + perform model inference + parameters: + batch : torch.Tensor + input data + batch_size : int, default 32 + batch size for inference + """ - def infer(self,batch:torch.Tensor,batch_size:int=32): - - assert batch.ndim == 3, f"Expected batch with 3 dimensions (batch,channels,samples) got only {batch.ndim}" + assert ( + batch.ndim == 3 + ), f"Expected batch with 3 dimensions (batch,channels,samples) got only {batch.ndim}" batch_predictions = [] self.eval().to(self.device) with torch.no_grad(): - for batch_id in range(0,batch.shape[0],batch_size): - batch_data = batch[batch_id:batch_id+batch_size,:,:].to(self.device) + for batch_id in range(0, batch.shape[0], batch_size): + batch_data = batch[batch_id : batch_id + batch_size, :, :].to( + self.device + ) prediction = self(batch_data) batch_predictions.append(prediction) - + return torch.vstack(batch_predictions) def enhance( self, - audio:Union[Path,np.ndarray,torch.Tensor], - sampling_rate:Optional[int]=None, - batch_size:int=32, - save_output:bool=False, - duration:Optional[int]=None, - step_size:Optional[int]=None,): + audio: Union[Path, np.ndarray, torch.Tensor], + sampling_rate: Optional[int] = None, + batch_size: int = 32, + save_output: bool = False, + duration: Optional[int] = None, + step_size: Optional[int] = None, + ): + """ + Enhance audio using loaded pretained model. + + parameters: + audio: Path to audio file or numpy array or torch tensor + single input audio + sampling_rate: int, optional incase input is path + sampling rate of input + batch_size: int, default 32 + input audio is split into multiple chunks. Inference is done on batches + of these chunks according to given batch size. + save_output : bool, default False + weather to save output to file + duration : float, optional + chunk duration in seconds, defaults to duration of loaded pretrained model. + step_size: int, optional + step size between consecutive durations, defaults to 50% of duration + """ model_sampling_rate = self.hparams["sampling_rate"] if duration is None: duration = self.hparams["duration"] - waveform = Inference.read_input(audio,sampling_rate,model_sampling_rate) + waveform = Inference.read_input( + audio, sampling_rate, model_sampling_rate + ) waveform.to(self.device) window_size = round(duration * model_sampling_rate) - batched_waveform = Inference.batchify(waveform,window_size,step_size=step_size) - batch_prediction = self.infer(batched_waveform,batch_size=batch_size) - waveform = Inference.aggreagate(batch_prediction,window_size,waveform.shape[-1],step_size,) - - if save_output and isinstance(audio,(str,Path)): - Inference.write_output(waveform,audio,model_sampling_rate) + batched_waveform = Inference.batchify( + waveform, window_size, step_size=step_size + ) + batch_prediction = self.infer(batched_waveform, batch_size=batch_size) + waveform = Inference.aggreagate( + batch_prediction, + window_size, + waveform.shape[-1], + step_size, + ) + + if save_output and isinstance(audio, (str, Path)): + Inference.write_output(waveform, audio, model_sampling_rate) else: - waveform = Inference.prepare_output(waveform, model_sampling_rate, - audio, sampling_rate) + waveform = Inference.prepare_output( + waveform, model_sampling_rate, audio, sampling_rate + ) return waveform + @property def valid_monitor(self): - return "max" if self.loss.higher_better else "min" - - - - - - - - - - - - - - - \ No newline at end of file + return "max" if self.loss.higher_better else "min" diff --git a/enhancer/models/waveunet.py b/enhancer/models/waveunet.py index f799352..ebb4b1f 100644 --- a/enhancer/models/waveunet.py +++ b/enhancer/models/waveunet.py @@ -1,82 +1,124 @@ import logging +from typing import List, Optional, Union + import torch import torch.nn as nn import torch.nn.functional as F -from typing import Optional, Union, List -from enhancer.models.model import Model from enhancer.data.dataset import EnhancerDataset +from enhancer.models.model import Model + class WavenetDecoder(nn.Module): - def __init__( self, - in_channels:int, - out_channels:int, - kernel_size:int=5, - padding:int=2, - stride:int=1, - dilation:int=1, + in_channels: int, + out_channels: int, + kernel_size: int = 5, + padding: int = 2, + stride: int = 1, + dilation: int = 1, ): - super(WavenetDecoder,self).__init__() + super(WavenetDecoder, self).__init__() self.decoder = nn.Sequential( - nn.Conv1d(in_channels,out_channels,kernel_size,stride=stride,padding=padding,dilation=dilation), + nn.Conv1d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ), nn.BatchNorm1d(out_channels), - nn.LeakyReLU(negative_slope=0.1) + nn.LeakyReLU(negative_slope=0.1), ) - - def forward(self,waveform): - + + def forward(self, waveform): + return self.decoder(waveform) -class WavenetEncoder(nn.Module): +class WavenetEncoder(nn.Module): def __init__( self, - in_channels:int, - out_channels:int, - kernel_size:int=15, - padding:int=7, - stride:int=1, - dilation:int=1, + in_channels: int, + out_channels: int, + kernel_size: int = 15, + padding: int = 7, + stride: int = 1, + dilation: int = 1, ): - super(WavenetEncoder,self).__init__() + super(WavenetEncoder, self).__init__() self.encoder = nn.Sequential( - nn.Conv1d(in_channels,out_channels,kernel_size,stride=stride,padding=padding,dilation=dilation), + nn.Conv1d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ), nn.BatchNorm1d(out_channels), - nn.LeakyReLU(negative_slope=0.1) + nn.LeakyReLU(negative_slope=0.1), ) - - def forward( - self, - waveform - ): + def forward(self, waveform): return self.encoder(waveform) class WaveUnet(Model): + """ + Wave-U-Net model from https://arxiv.org/pdf/1811.11307.pdf + parameters: + num_channels: int, defaults to 1 + number of channels in input audio + depth : int, defaults to 12 + depth of network + initial_output_channels: int, defaults to 24 + number of output channels in intial upsampling layer + sampling_rate: int, defaults to 16KHz + sampling rate of input audio + lr : float, defaults to 1e-3 + learning rate used for training + dataset: EnhancerDataset, optional + EnhancerDataset object containing train/validation data for training + duration : float, optional + chunk duration in seconds + loss : string or List of strings + loss function to be used, available ("mse","mae","SI-SDR") + metric : string or List of strings + metric function to be used, available ("mse","mae","SI-SDR") + """ def __init__( self, - num_channels:int=1, - depth:int=12, - initial_output_channels:int=24, - sampling_rate:int=16000, - lr:float=1e-3, - dataset:Optional[EnhancerDataset]=None, - duration:Optional[float]=None, + num_channels: int = 1, + depth: int = 12, + initial_output_channels: int = 24, + sampling_rate: int = 16000, + lr: float = 1e-3, + dataset: Optional[EnhancerDataset] = None, + duration: Optional[float] = None, loss: Union[str, List] = "mse", - metric:Union[str,List] = "mse" + metric: Union[str, List] = "mse", ): - duration = dataset.duration if isinstance(dataset,EnhancerDataset) else None + duration = ( + dataset.duration if isinstance(dataset, EnhancerDataset) else None + ) if dataset is not None: - if sampling_rate!=dataset.sampling_rate: - logging.warn(f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}") + if sampling_rate != dataset.sampling_rate: + logging.warn( + f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}" + ) sampling_rate = dataset.sampling_rate - super().__init__(num_channels=num_channels, - sampling_rate=sampling_rate,lr=lr, - dataset=dataset,duration=duration,loss=loss, metric=metric + super().__init__( + num_channels=num_channels, + sampling_rate=sampling_rate, + lr=lr, + dataset=dataset, + duration=duration, + loss=loss, + metric=metric, ) self.save_hyperparameters("depth") self.encoders = nn.ModuleList() @@ -84,72 +126,76 @@ class WaveUnet(Model): out_channels = initial_output_channels for layer in range(depth): - encoder = WavenetEncoder(num_channels,out_channels) + encoder = WavenetEncoder(num_channels, out_channels) self.encoders.append(encoder) num_channels = out_channels out_channels += initial_output_channels - if layer == depth -1 : - decoder = WavenetDecoder(depth * initial_output_channels + num_channels,num_channels) + if layer == depth - 1: + decoder = WavenetDecoder( + depth * initial_output_channels + num_channels, num_channels + ) else: - decoder = WavenetDecoder(num_channels+out_channels,num_channels) + decoder = WavenetDecoder( + num_channels + out_channels, num_channels + ) - self.decoders.insert(0,decoder) + self.decoders.insert(0, decoder) bottleneck_dim = depth * initial_output_channels self.bottleneck = nn.Sequential( - nn.Conv1d(bottleneck_dim,bottleneck_dim, 15, stride=1, - padding=7), + nn.Conv1d(bottleneck_dim, bottleneck_dim, 15, stride=1, padding=7), nn.BatchNorm1d(bottleneck_dim), - nn.LeakyReLU(negative_slope=0.1, inplace=True) + nn.LeakyReLU(negative_slope=0.1, inplace=True), ) self.final = nn.Sequential( nn.Conv1d(1 + initial_output_channels, 1, kernel_size=1, stride=1), - nn.Tanh() + nn.Tanh(), ) - - def forward( - self,waveform - ): + def forward(self, waveform): if waveform.dim() == 2: waveform = waveform.unsqueeze(1) - if waveform.size(1)!=1: - raise TypeError(f"Wave-U-Net can only process mono channel audio, input has {waveform.size(1)} channels") + if waveform.size(1) != 1: + raise TypeError( + f"Wave-U-Net can only process mono channel audio, input has {waveform.size(1)} channels" + ) encoder_outputs = [] out = waveform for encoder in self.encoders: out = encoder(out) - encoder_outputs.insert(0,out) - out = out[:,:,::2] - + encoder_outputs.insert(0, out) + out = out[:, :, ::2] + out = self.bottleneck(out) - for layer,decoder in enumerate(self.decoders): + for layer, decoder in enumerate(self.decoders): out = F.interpolate(out, scale_factor=2, mode="linear") - out = self.fix_last_dim(out,encoder_outputs[layer]) - out = torch.cat([out,encoder_outputs[layer]],dim=1) + out = self.fix_last_dim(out, encoder_outputs[layer]) + out = torch.cat([out, encoder_outputs[layer]], dim=1) out = decoder(out) - out = torch.cat([out, waveform],dim=1) + out = torch.cat([out, waveform], dim=1) out = self.final(out) return out - - def fix_last_dim(self,x,target): + + def fix_last_dim(self, x, target): """ - trying to do centre crop along last dimension + centre crop along last dimension """ - assert x.shape[-1] >= target.shape[-1], "input dimension cannot be larger than target dimension" + assert ( + x.shape[-1] >= target.shape[-1] + ), "input dimension cannot be larger than target dimension" if x.shape[-1] == target.shape[-1]: return x - + diff = x.shape[-1] - target.shape[-1] - if diff%2!=0: - x = F.pad(x,(0,1)) + if diff % 2 != 0: + x = F.pad(x, (0, 1)) diff += 1 - crop = diff//2 - return x[:,:,crop:-crop] + crop = diff // 2 + return x[:, :, crop:-crop] diff --git a/enhancer/utils/__init__.py b/enhancer/utils/__init__.py index 3da7ede..de0db9f 100644 --- a/enhancer/utils/__init__.py +++ b/enhancer/utils/__init__.py @@ -1,3 +1,3 @@ -from enhancer.utils.utils import check_files +from enhancer.utils.config import Files from enhancer.utils.io import Audio -from enhancer.utils.config import Files \ No newline at end of file +from enhancer.utils.utils import check_files diff --git a/enhancer/utils/config.py b/enhancer/utils/config.py index 1bbc51d..252e6c9 100644 --- a/enhancer/utils/config.py +++ b/enhancer/utils/config.py @@ -1,10 +1,9 @@ from dataclasses import dataclass + @dataclass class Files: - train_clean : str - train_noisy : str - test_clean : str - test_noisy : str - - + train_clean: str + train_noisy: str + test_clean: str + test_noisy: str diff --git a/enhancer/utils/io.py b/enhancer/utils/io.py index afc19e8..9e9ce32 100644 --- a/enhancer/utils/io.py +++ b/enhancer/utils/io.py @@ -1,41 +1,67 @@ import os +from pathlib import Path +from typing import Optional, Union + import librosa -from typing import Optional import numpy as np import torch import torchaudio + class Audio: + """ + Audio utils + parameters: + sampling_rate : int, defaults to 16KHz + audio sampling rate + mono: bool, defaults to True + return_tensors: bool, defaults to True + returns torch tensor type if set to True else numpy ndarray + """ def __init__( - self, - sampling_rate:int=16000, - mono:bool=True, - return_tensor=True + self, sampling_rate: int = 16000, mono: bool = True, return_tensor=True ) -> None: - + self.sampling_rate = sampling_rate self.mono = mono self.return_tensor = return_tensor def __call__( self, - audio, - sampling_rate:Optional[int]=None, - offset:Optional[float] = None, - duration:Optional[float] = None + audio: Union[Path, np.ndarray, torch.Tensor], + sampling_rate: Optional[int] = None, + offset: Optional[float] = None, + duration: Optional[float] = None, ): - if isinstance(audio,str): + """ + read and process input audio + parameters: + audio: Path to audio file or numpy array or torch tensor + single input audio + sampling_rate : int, optional + sampling rate of the audio input + offset: float, optional + offset from which the audio must be read, reads from beginning if unused. + duration: float (seconds), optional + read duration, reads full audio starting from offset if not used + """ + if isinstance(audio, str): if os.path.exists(audio): - audio,sampling_rate = librosa.load(audio,sr=sampling_rate,mono=False, - offset=offset,duration=duration) + audio, sampling_rate = librosa.load( + audio, + sr=sampling_rate, + mono=False, + offset=offset, + duration=duration, + ) if len(audio.shape) == 1: - audio = audio.reshape(1,-1) + audio = audio.reshape(1, -1) else: raise FileNotFoundError(f"File {audio} deos not exist") - elif isinstance(audio,np.ndarray): + elif isinstance(audio, np.ndarray): if len(audio.shape) == 1: - audio = audio.reshape(1,-1) + audio = audio.reshape(1, -1) else: raise ValueError("audio should be either filepath or numpy ndarray") @@ -43,40 +69,60 @@ class Audio: audio = self.convert_mono(audio) if sampling_rate: - audio = self.__class__.resample_audio(audio,self.sampling_rate,sampling_rate) + audio = self.__class__.resample_audio( + audio, self.sampling_rate, sampling_rate + ) if self.return_tensor: return torch.tensor(audio) else: return audio @staticmethod - def convert_mono( - audio + def convert_mono(audio: Union[np.ndarray, torch.Tensor]): + """ + convert input audio into mono (1) + parameters: + audio: np.ndarray or torch.Tensor + """ + if len(audio.shape) > 2: + assert ( + audio.shape[0] == 1 + ), "convert mono only accepts single waveform" + audio = audio.reshape(audio.shape[1], audio.shape[2]) - ): - if len(audio.shape)>2: - assert audio.shape[0] == 1, "convert mono only accepts single waveform" - audio = audio.reshape(audio.shape[1],audio.shape[2]) - - assert audio.shape[1] >> audio.shape[0], f"expected input format (num_channels,num_samples) got {audio.shape}" - num_channels,num_samples = audio.shape - if num_channels>1: - return audio.mean(axis=0).reshape(1,num_samples) + assert ( + audio.shape[1] >> audio.shape[0] + ), f"expected input format (num_channels,num_samples) got {audio.shape}" + num_channels, num_samples = audio.shape + if num_channels > 1: + return audio.mean(axis=0).reshape(1, num_samples) return audio - @staticmethod def resample_audio( - audio, - sr:int, - target_sr:int + audio: Union[np.ndarray, torch.Tensor], sr: int, target_sr: int ): - if sr!=target_sr: - if isinstance(audio,np.ndarray): - audio = librosa.resample(audio,orig_sr=sr,target_sr=target_sr) - elif isinstance(audio,torch.Tensor): - audio = torchaudio.functional.resample(audio,orig_freq=sr,new_freq=target_sr) + """ + resample audio to desired sampling rate + parameters: + audio : Path to audio file or numpy array or torch tensor + audio waveform + sr : int + current sampling rate + target_sr : int + target sampling rate + + """ + if sr != target_sr: + if isinstance(audio, np.ndarray): + audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sr) + elif isinstance(audio, torch.Tensor): + audio = torchaudio.functional.resample( + audio, orig_freq=sr, new_freq=target_sr + ) else: - raise ValueError("Input should be either numpy array or torch tensor") + raise ValueError( + "Input should be either numpy array or torch tensor" + ) return audio diff --git a/enhancer/utils/random.py b/enhancer/utils/random.py index 3b1acac..dd9395a 100644 --- a/enhancer/utils/random.py +++ b/enhancer/utils/random.py @@ -1,19 +1,19 @@ import os import random + import torch - -def create_unique_rng(epoch:int): +def create_unique_rng(epoch: int): """create unique random number generator for each (worker_id,epoch) combination""" rng = random.Random() - global_seed = int(os.environ.get("PL_GLOBAL_SEED","0")) - global_rank = int(os.environ.get('GLOBAL_RANK',"0")) - local_rank = int(os.environ.get('LOCAL_RANK',"0")) - node_rank = int(os.environ.get('NODE_RANK',"0")) - world_size = int(os.environ.get('WORLD_SIZE',"0")) + global_seed = int(os.environ.get("PL_GLOBAL_SEED", "0")) + global_rank = int(os.environ.get("GLOBAL_RANK", "0")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + node_rank = int(os.environ.get("NODE_RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "0")) worker_info = torch.utils.data.get_worker_info() if worker_info is not None: @@ -24,17 +24,13 @@ def create_unique_rng(epoch:int): worker_id = 0 seed = ( - global_seed - + worker_id - + local_rank * num_workers - + node_rank * num_workers * global_rank - + epoch * num_workers * world_size - ) + global_seed + + worker_id + + local_rank * num_workers + + node_rank * num_workers * global_rank + + epoch * num_workers * world_size + ) rng.seed(seed) return rng - - - - diff --git a/enhancer/utils/utils.py b/enhancer/utils/utils.py index be74dc2..ad45139 100644 --- a/enhancer/utils/utils.py +++ b/enhancer/utils/utils.py @@ -1,19 +1,26 @@ - import os from typing import Optional + from enhancer.utils.config import Files -def check_files(root_dir:str, files:Files): - path_variables = [member_var for member_var in dir(files) if not member_var.startswith('__')] +def check_files(root_dir: str, files: Files): + + path_variables = [ + member_var + for member_var in dir(files) + if not member_var.startswith("__") + ] for variable in path_variables: - path = getattr(files,variable) - if not os.path.isdir(os.path.join(root_dir,path)): + path = getattr(files, variable) + if not os.path.isdir(os.path.join(root_dir, path)): raise ValueError(f"Invalid {path}, is not a directory") - - return files,root_dir -def merge_dict(default_dict:dict, custom:Optional[dict]=None): + return files, root_dir + + +def merge_dict(default_dict: dict, custom: Optional[dict] = None): + params = dict(default_dict) if custom: params.update(custom) diff --git a/environment.yml b/environment.yml index 4f211bf..8da22e1 100644 --- a/environment.yml +++ b/environment.yml @@ -5,4 +5,4 @@ dependencies: - python=3.8 - pip: - -r requirements.txt - - --find-links https://download.pytorch.org/whl/cu113/torch_stable.html \ No newline at end of file + - --find-links https://download.pytorch.org/whl/cu113/torch_stable.html diff --git a/hpc_entrypoint.sh b/hpc_entrypoint.sh index 7372eb9..6d6a3a0 100644 --- a/hpc_entrypoint.sh +++ b/hpc_entrypoint.sh @@ -33,7 +33,7 @@ mkdir temp pwd #python transcriber/tasks/embeddings/timit.py --directory /scratch/$USER/TIMIT/data/lisa/data/timit/raw/TIMIT/TRAIN --output ./data/train -#python transcriber/tasks/embeddings/timit.py --directory /scratch/$USER/TIMIT/data/lisa/data/timit/raw/TIMIT/TEST --output ./data/test +#python transcriber/tasks/embeddings/timit.py --directory /scratch/$USER/TIMIT/data/lisa/data/timit/raw/TIMIT/TEST --output ./data/test echo "Start Training..." python cli/train.py diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..b3e5d7c --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,15 @@ +[tool.black] +line-length = 80 +target-version = ['py38'] +exclude = ''' + +( + /( + \.eggs # exclude a few common directories in the + | \.git # root of the project + | \.mypy_cache + | \.tox + | \.venv + )/ +) +''' diff --git a/requirements.txt b/requirements.txt index e7fcd24..afa3641 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,15 +1,16 @@ -joblib==1.1.0 -numpy==1.19.5 -librosa==0.9.1 -numpy==1.19.5 -hydra-core==1.2.0 -scikit-learn==0.24.2 -scipy==1.5.4 -torch==1.10.2 -tqdm==4.64.0 -mlflow==1.23.1 -protobuf==3.19.3 -boto3==1.23.9 -torchaudio==0.10.2 -huggingface-hub==0.4.0 -pytorch-lightning==1.5.10 +black>=22.8.0 +boto3>=1.24.86 +flake8>=5.0.4 +huggingface-hu>=0.10.0 +hydra-core>=1.2.0 +joblib>=1.2.0 +librosa>=0.9.2 +mlflow>=1.29.0 +numpy>=1.23.3 +protobuf>=3.19.6 +pytorch-lightning>=1.7.7 +scikit-learn>=1.1.2 +scipy>=1.9.1 +torch>=1.12.1 +torchaudio>=0.12.1 +tqdm>=4.64.1 diff --git a/setup.sh b/setup.sh index adad46c..43adc89 100644 --- a/setup.sh +++ b/setup.sh @@ -10,4 +10,4 @@ conda env create -f environment.yml || conda env update -f environment.yml source activate enhancer echo "copying files" -# cp /scratch/$USER/TIMIT/.* /deep-transcriber \ No newline at end of file +# cp /scratch/$USER/TIMIT/.* /deep-transcriber diff --git a/tests/loss_function_test.py b/tests/loss_function_test.py index fbc982c..4d14871 100644 --- a/tests/loss_function_test.py +++ b/tests/loss_function_test.py @@ -1,31 +1,32 @@ -from asyncio import base_tasks -import torch import pytest +import torch from enhancer.loss import mean_absolute_error, mean_squared_error loss_functions = [mean_absolute_error(), mean_squared_error()] + def check_loss_shapes_compatibility(loss_fun): batch_size = 4 - shape = (1,1000) - loss_fun(torch.rand(batch_size,*shape),torch.rand(batch_size,*shape)) + shape = (1, 1000) + loss_fun(torch.rand(batch_size, *shape), torch.rand(batch_size, *shape)) with pytest.raises(TypeError): - loss_fun(torch.rand(4,*shape),torch.rand(6,*shape)) + loss_fun(torch.rand(4, *shape), torch.rand(6, *shape)) -@pytest.mark.parametrize("loss",loss_functions) +@pytest.mark.parametrize("loss", loss_functions) def test_loss_input_shapes(loss): check_loss_shapes_compatibility(loss) -@pytest.mark.parametrize("loss",loss_functions) + +@pytest.mark.parametrize("loss", loss_functions) def test_loss_output_type(loss): batch_size = 4 - prediction, target = torch.rand(batch_size,1,1000),torch.rand(batch_size,1,1000) + prediction, target = torch.rand(batch_size, 1, 1000), torch.rand( + batch_size, 1, 1000 + ) loss_value = loss(prediction, target) - assert isinstance(loss_value.item(),float) - - + assert isinstance(loss_value.item(), float) diff --git a/tests/models/demucs_test.py b/tests/models/demucs_test.py index a59fa04..f5a0ec4 100644 --- a/tests/models/demucs_test.py +++ b/tests/models/demucs_test.py @@ -1,46 +1,43 @@ import pytest import torch -from enhancer import data -from enhancer.utils.config import Files -from enhancer.models import Demucs from enhancer.data.dataset import EnhancerDataset +from enhancer.models import Demucs +from enhancer.utils.config import Files @pytest.fixture def vctk_dataset(): root_dir = "tests/data/vctk" - files = Files(train_clean="clean_testset_wav",train_noisy="noisy_testset_wav", - test_clean="clean_testset_wav", test_noisy="noisy_testset_wav") - dataset = EnhancerDataset(name="vctk",root_dir=root_dir,files=files) + files = Files( + train_clean="clean_testset_wav", + train_noisy="noisy_testset_wav", + test_clean="clean_testset_wav", + test_noisy="noisy_testset_wav", + ) + dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files) return dataset - -@pytest.mark.parametrize("batch_size,samples",[(1,1000)]) -def test_forward(batch_size,samples): +@pytest.mark.parametrize("batch_size,samples", [(1, 1000)]) +def test_forward(batch_size, samples): model = Demucs() model.eval() - data = torch.rand(batch_size,1,samples,requires_grad=False) + data = torch.rand(batch_size, 1, samples, requires_grad=False) with torch.no_grad(): _ = model(data) - data = torch.rand(batch_size,2,samples,requires_grad=False) + data = torch.rand(batch_size, 2, samples, requires_grad=False) with torch.no_grad(): with pytest.raises(TypeError): _ = model(data) -@pytest.mark.parametrize("dataset,channels,loss", - [(pytest.lazy_fixture("vctk_dataset"),1,["mae","mse"])]) -def test_demucs_init(dataset,channels,loss): +@pytest.mark.parametrize( + "dataset,channels,loss", + [(pytest.lazy_fixture("vctk_dataset"), 1, ["mae", "mse"])], +) +def test_demucs_init(dataset, channels, loss): with torch.no_grad(): - model = Demucs(num_channels=channels,dataset=dataset,loss=loss) - - - - - - - + _ = Demucs(num_channels=channels, dataset=dataset, loss=loss) diff --git a/tests/models/test_waveunet.py b/tests/models/test_waveunet.py index 43fd14d..9c4dd96 100644 --- a/tests/models/test_waveunet.py +++ b/tests/models/test_waveunet.py @@ -1,46 +1,43 @@ import pytest import torch -from enhancer import data -from enhancer.utils.config import Files -from enhancer.models import WaveUnet from enhancer.data.dataset import EnhancerDataset +from enhancer.models import WaveUnet +from enhancer.utils.config import Files @pytest.fixture def vctk_dataset(): root_dir = "tests/data/vctk" - files = Files(train_clean="clean_testset_wav",train_noisy="noisy_testset_wav", - test_clean="clean_testset_wav", test_noisy="noisy_testset_wav") - dataset = EnhancerDataset(name="vctk",root_dir=root_dir,files=files) + files = Files( + train_clean="clean_testset_wav", + train_noisy="noisy_testset_wav", + test_clean="clean_testset_wav", + test_noisy="noisy_testset_wav", + ) + dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files) return dataset - -@pytest.mark.parametrize("batch_size,samples",[(1,1000)]) -def test_forward(batch_size,samples): +@pytest.mark.parametrize("batch_size,samples", [(1, 1000)]) +def test_forward(batch_size, samples): model = WaveUnet() model.eval() - data = torch.rand(batch_size,1,samples,requires_grad=False) + data = torch.rand(batch_size, 1, samples, requires_grad=False) with torch.no_grad(): _ = model(data) - data = torch.rand(batch_size,2,samples,requires_grad=False) + data = torch.rand(batch_size, 2, samples, requires_grad=False) with torch.no_grad(): with pytest.raises(TypeError): _ = model(data) -@pytest.mark.parametrize("dataset,channels,loss", - [(pytest.lazy_fixture("vctk_dataset"),1,["mae","mse"])]) -def test_demucs_init(dataset,channels,loss): +@pytest.mark.parametrize( + "dataset,channels,loss", + [(pytest.lazy_fixture("vctk_dataset"), 1, ["mae", "mse"])], +) +def test_demucs_init(dataset, channels, loss): with torch.no_grad(): - model = WaveUnet(num_channels=channels,dataset=dataset,loss=loss) - - - - - - - + _ = WaveUnet(num_channels=channels, dataset=dataset, loss=loss) diff --git a/tests/test_inference.py b/tests/test_inference.py index 5eb7442..a6e2423 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -4,22 +4,26 @@ import torch from enhancer.inference import Inference -@pytest.mark.parametrize("audio",["tests/data/vctk/clean_testset_wav/p257_166.wav",torch.rand(1,2,48000)]) +@pytest.mark.parametrize( + "audio", + ["tests/data/vctk/clean_testset_wav/p257_166.wav", torch.rand(1, 2, 48000)], +) def test_read_input(audio): - read_audio = Inference.read_input(audio,48000,16000) - assert isinstance(read_audio,torch.Tensor) + read_audio = Inference.read_input(audio, 48000, 16000) + assert isinstance(read_audio, torch.Tensor) assert read_audio.shape[0] == 1 + def test_batchify(): - rand = torch.rand(1,1000) - batched_rand = Inference.batchify(rand, window_size = 100, step_size=100) + rand = torch.rand(1, 1000) + batched_rand = Inference.batchify(rand, window_size=100, step_size=100) assert batched_rand.shape[0] == 12 + def test_aggregate(): - rand = torch.rand(12,1,100) - agg_rand = Inference.aggreagate(data=rand,window_size=100,total_frames=1000,step_size=100) + rand = torch.rand(12, 1, 100) + agg_rand = Inference.aggreagate( + data=rand, window_size=100, total_frames=1000, step_size=100 + ) assert agg_rand.shape[-1] == 1000 - - - \ No newline at end of file diff --git a/tests/utils_test.py b/tests/utils_test.py index 413bfac..65c723d 100644 --- a/tests/utils_test.py +++ b/tests/utils_test.py @@ -1,46 +1,50 @@ -from logging import root +import numpy as np import pytest import torch -import numpy as np -from enhancer.utils.io import Audio -from enhancer.utils.config import Files from enhancer.data.fileprocessor import Fileprocessor +from enhancer.utils.io import Audio + def test_io_channel(): - input_audio = np.random.rand(2,32000) - audio = Audio(mono=True,return_tensor=False) + input_audio = np.random.rand(2, 32000) + audio = Audio(mono=True, return_tensor=False) output_audio = audio(input_audio) assert output_audio.shape[0] == 1 + def test_io_resampling(): - input_audio = np.random.rand(1,32000) - resampled_audio = Audio.resample_audio(input_audio,16000,8000) + input_audio = np.random.rand(1, 32000) + resampled_audio = Audio.resample_audio(input_audio, 16000, 8000) - input_audio = torch.rand(1,32000) - resampled_audio_pt = Audio.resample_audio(input_audio,16000,8000) + input_audio = torch.rand(1, 32000) + resampled_audio_pt = Audio.resample_audio(input_audio, 16000, 8000) assert resampled_audio.shape[1] == resampled_audio_pt.size(1) == 16000 + def test_fileprocessor_vctk(): - fp = Fileprocessor.from_name("vctk","tests/data/vctk/clean_testset_wav", - "tests/data/vctk/noisy_testset_wav",48000) + fp = Fileprocessor.from_name( + "vctk", + "tests/data/vctk/clean_testset_wav", + "tests/data/vctk/noisy_testset_wav", + 48000, + ) matching_dict = fp.prepare_matching_dict() - assert len(matching_dict)==2 + assert len(matching_dict) == 2 -@pytest.mark.parametrize("dataset_name",["vctk","dns-2020"]) + +@pytest.mark.parametrize("dataset_name", ["vctk", "dns-2020"]) def test_fileprocessor_names(dataset_name): - fp = Fileprocessor.from_name(dataset_name,"clean_dir","noisy_dir",16000) - assert hasattr(fp.matching_function, '__call__') + fp = Fileprocessor.from_name(dataset_name, "clean_dir", "noisy_dir", 16000) + assert hasattr(fp.matching_function, "__call__") + def test_fileprocessor_invaliname(): with pytest.raises(ValueError): - fp = Fileprocessor.from_name("undefined","clean_dir","noisy_dir",16000).prepare_matching_dict() - - - - - + _ = Fileprocessor.from_name( + "undefined", "clean_dir", "noisy_dir", 16000 + ).prepare_matching_dict()