From 7f00707733ffa9115c3556ec194c18fbea026fe6 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Wed, 5 Oct 2022 11:11:30 +0530 Subject: [PATCH 01/38] add doc/refactor black --- enhancer/inference.py | 150 +++++++++++++++++++++++++++++------------- 1 file changed, 106 insertions(+), 44 deletions(-) diff --git a/enhancer/inference.py b/enhancer/inference.py index 6e9cff7..27a9385 100644 --- a/enhancer/inference.py +++ b/enhancer/inference.py @@ -1,9 +1,7 @@ -from json import load -import wave import numpy as np from scipy.signal import get_window from scipy.io import wavfile -from typing import List, Optional, Union +from typing import Optional, Union import torch import torch.nn.functional as F from pathlib import Path @@ -11,89 +9,153 @@ from librosa import load as load_audio from enhancer.utils import Audio + class Inference: + """ + contains methods used for inference. + """ @staticmethod def read_input(audio, sr, model_sr): + """ + read and verify audio input regardless of the input format. + arguments: + audio : audio input + sr : sampling rate of input audio + model_sr : sampling rate used for model training. + """ - if isinstance(audio,(np.ndarray,torch.Tensor)): + if isinstance(audio, (np.ndarray, torch.Tensor)): assert sr is not None, "Invalid sampling rate!" 
- if isinstance(audio,str): + if isinstance(audio, str): audio = Path(audio) if not audio.is_file(): raise ValueError(f"Input file {audio} does not exist") else: - audio,sr = load_audio(audio,sr=sr,) + audio, sr = load_audio( + audio, + sr=sr, + ) if len(audio.shape) == 1: - audio = audio.reshape(1,-1) + audio = audio.reshape(1, -1) else: - assert audio.shape[0] == 1, "Enhance inference only supports single waveform" + assert ( + audio.shape[0] == 1 + ), "Enhance inference only supports single waveform" - waveform = Audio.resample_audio(audio,sr=sr,target_sr=model_sr) + waveform = Audio.resample_audio(audio, sr=sr, target_sr=model_sr) waveform = Audio.convert_mono(waveform) - if isinstance(waveform,np.ndarray): + if isinstance(waveform, np.ndarray): waveform = torch.from_numpy(waveform) return waveform @staticmethod - def batchify(waveform: torch.Tensor, window_size:int, step_size:Optional[int]=None): + def batchify( + waveform: torch.Tensor, + window_size: int, + step_size: Optional[int] = None, + ): """ - break input waveform into samples with duration specified. 
+ break input waveform into samples with duration specified.(Overlap-add) + arguments: + waveform : audio waveform + window_size : window size used for splitting waveform into batches + step_size : step_size used for splitting waveform into batches """ - assert waveform.ndim == 2, f"Expcted input waveform with 2 dimensions (channels,samples), got {waveform.ndim}" - _,num_samples = waveform.shape + assert ( + waveform.ndim == 2 + ), f"Expcted input waveform with 2 dimensions (channels,samples), got {waveform.ndim}" + _, num_samples = waveform.shape waveform = waveform.unsqueeze(-1) - step_size = window_size//2 if step_size is None else step_size + step_size = window_size // 2 if step_size is None else step_size if num_samples >= window_size: - waveform_batch = F.unfold(waveform[None,...], kernel_size=(window_size,1), - stride=(step_size,1), padding=(window_size,0)) - waveform_batch = waveform_batch.permute(2,0,1) - - + waveform_batch = F.unfold( + waveform[None, ...], + kernel_size=(window_size, 1), + stride=(step_size, 1), + padding=(window_size, 0), + ) + waveform_batch = waveform_batch.permute(2, 0, 1) + return waveform_batch @staticmethod - def aggreagate(data:torch.Tensor,window_size:int,total_frames:int,step_size:Optional[int]=None, - window="hanning",): + def aggreagate( + data: torch.Tensor, + window_size: int, + total_frames: int, + step_size: Optional[int] = None, + window="hanning", + ): """ - takes input as tensor outputs aggregated waveform + stitch batched waveform into single waveform. (Overlap-add) + arguments: + data: batched waveform + window_size : window_size used to batch waveform + step_size : step_size used to batch waveform + total_frames : total number of frames present in original waveform + window : type of window used for overlap-add mechanism. 
""" - num_chunks,n_channels,num_frames = data.shape - window = get_window(window=window,Nx=data.shape[-1]) + num_chunks, n_channels, num_frames = data.shape + window = get_window(window=window, Nx=data.shape[-1]) window = torch.from_numpy(window).to(data.device) data *= window - data = data.permute(1,2,0) - data = F.fold(data, - (total_frames,1), - kernel_size=(window_size,1), - stride=(step_size,1), - padding=(window_size,0)).squeeze(-1) + data = data.permute(1, 2, 0) + data = F.fold( + data, + (total_frames, 1), + kernel_size=(window_size, 1), + stride=(step_size, 1), + padding=(window_size, 0), + ).squeeze(-1) - return data.reshape(1,n_channels,-1) + return data.reshape(1, n_channels, -1) @staticmethod - def write_output(waveform:torch.Tensor,filename:Union[str,Path],sr:int): + def write_output( + waveform: torch.Tensor, filename: Union[str, Path], sr: int + ): + """ + write audio output as wav file + arguments: + waveform : audio waveform + filename : name of the wave file. Output will be written as cleaned_filename.wav + sr : sampling rate + """ - if isinstance(filename,str): + if isinstance(filename, str): filename = Path(filename) if filename.is_file(): raise FileExistsError(f"file {filename} already exists") else: - wavfile.write(filename,rate=sr,data=waveform.detach().cpu()) - + wavfile.write(filename, rate=sr, data=waveform.detach().cpu()) + @staticmethod + def prepare_output( + waveform: torch.Tensor, + model_sampling_rate: int, + audio: Union[str, np.ndarray, torch.Tensor], + sampling_rate: Optional[int], + ): + """ + prepare output audio based on input format + arguments: + waveform : predicted audio waveform + model_sampling_rate : sampling rate used to train the model + audio : input audio + sampling_rate : input audio sampling rate + """ + if isinstance(audio, np.ndarray): + waveform = waveform.detach().cpu().numpy() + if sampling_rate is not None: + waveform = Audio.resample_audio( + waveform, sr=model_sampling_rate, target_sr=sampling_rate + ) - - 
- - - - - - \ No newline at end of file + return waveform From e2c8afdfb9c1048ae9a212198537d51e9d9a477e Mon Sep 17 00:00:00 2001 From: shahules786 Date: Wed, 5 Oct 2022 11:11:57 +0530 Subject: [PATCH 02/38] flake8 --- .flake8 | 8 ++++++++ 1 file changed, 8 insertions(+) create mode 100644 .flake8 diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..ed37421 --- /dev/null +++ b/.flake8 @@ -0,0 +1,8 @@ +[flake8] +ignore = E203, E266, E501, W503 +# line length is intentionally set to 80 here because black uses Bugbear +# See https://github.com/psf/black/blob/master/README.md#line-length for more details +max-line-length = 80 +max-complexity = 18 +select = B,C,E,F,W,T4,B9 +exclude = tools/kaldi_decoder \ No newline at end of file From 1062eb3541bb311499a0fff1e0d61fe98c562cb6 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Wed, 5 Oct 2022 11:12:05 +0530 Subject: [PATCH 03/38] toml --- pyproject.toml | 15 +++++++++++++++ 1 file changed, 15 insertions(+) create mode 100644 pyproject.toml diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..8f12f30 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,15 @@ +[tool.black] +line-length = 80 +target-version = ['py38'] +exclude = ''' + +( + /( + \.eggs # exclude a few common directories in the + | \.git # root of the project + | \.mypy_cache + | \.tox + | \.venv + )/ +) +''' \ No newline at end of file From e8b5e343c7729c7ca52c091045e26101cb525ad4 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Wed, 5 Oct 2022 11:21:30 +0530 Subject: [PATCH 04/38] black reformat --- enhancer/inference.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/enhancer/inference.py b/enhancer/inference.py index 838ef5f..1abd8bb 100644 --- a/enhancer/inference.py +++ b/enhancer/inference.py @@ -28,7 +28,7 @@ class Inference: if isinstance(audio, (np.ndarray, torch.Tensor)): assert sr is not None, "Invalid sampling rate!" 
if len(audio.shape) == 1: - audio = audio.reshape(1,-1) + audio = audio.reshape(1, -1) if isinstance(audio, str): audio = Path(audio) @@ -105,8 +105,7 @@ class Inference: window = get_window(window=window, Nx=data.shape[-1]) window = torch.from_numpy(window).to(data.device) data *= window - step_size = window_size//2 if step_size is None else step_size - + step_size = window_size // 2 if step_size is None else step_size data = data.permute(1, 2, 0) data = F.fold( @@ -134,8 +133,8 @@ class Inference: if isinstance(filename, str): filename = Path(filename) - parent, name = filename.parent, "cleaned_"+filename.name - filename = parent/Path(name) + parent, name = filename.parent, "cleaned_" + filename.name + filename = parent / Path(name) if filename.is_file(): raise FileExistsError(f"file {filename} already exists") else: From 756465c2bfdf6eee06a55526765b10f95eba2e9f Mon Sep 17 00:00:00 2001 From: shahules786 Date: Wed, 5 Oct 2022 11:26:34 +0530 Subject: [PATCH 05/38] format loss.py --- enhancer/loss.py | 110 +++++++++++++++++++++++++++++------------------ 1 file changed, 68 insertions(+), 42 deletions(-) diff --git a/enhancer/loss.py b/enhancer/loss.py index ef33161..1e156a0 100644 --- a/enhancer/loss.py +++ b/enhancer/loss.py @@ -3,62 +3,82 @@ import torch.nn as nn class mean_squared_error(nn.Module): + """ + Mean squared error / L1 loss + """ - def __init__(self,reduction="mean"): + def __init__(self, reduction="mean"): super().__init__() self.loss_fun = nn.MSELoss(reduction=reduction) self.higher_better = False - def forward(self,prediction:torch.Tensor, target: torch.Tensor): + def forward(self, prediction: torch.Tensor, target: torch.Tensor): if prediction.size() != target.size() or target.ndim < 3: - raise TypeError(f"""Inputs must be of the same shape (batch_size,channels,samples) - got {prediction.size()} and {target.size()} instead""") + raise TypeError( + f"""Inputs must be of the same shape (batch_size,channels,samples) + got {prediction.size()} and 
{target.size()} instead""" + ) return self.loss_fun(prediction, target) -class mean_absolute_error(nn.Module): - def __init__(self,reduction="mean"): +class mean_absolute_error(nn.Module): + """ + Mean absolute error / L2 loss + """ + + def __init__(self, reduction="mean"): super().__init__() self.loss_fun = nn.L1Loss(reduction=reduction) self.higher_better = False - def forward(self, prediction:torch.Tensor, target: torch.Tensor): + def forward(self, prediction: torch.Tensor, target: torch.Tensor): if prediction.size() != target.size() or target.ndim < 3: - raise TypeError(f"""Inputs must be of the same shape (batch_size,channels,samples) - got {prediction.size()} and {target.size()} instead""") + raise TypeError( + f"""Inputs must be of the same shape (batch_size,channels,samples) + got {prediction.size()} and {target.size()} instead""" + ) return self.loss_fun(prediction, target) -class Si_SDR(nn.Module): - def __init__( - self, - reduction:str="mean" - ): +class Si_SDR(nn.Module): + """ + SI-SDR metric based on SDR – HALF-BAKED OR WELL DONE?(https://arxiv.org/pdf/1811.02508.pdf) + """ + + def __init__(self, reduction: str = "mean"): super().__init__() - if reduction in ["sum","mean",None]: + if reduction in ["sum", "mean", None]: self.reduction = reduction else: - raise TypeError("Invalid reduction, valid options are sum, mean, None") + raise TypeError( + "Invalid reduction, valid options are sum, mean, None" + ) self.higher_better = False - def forward(self,prediction:torch.Tensor, target:torch.Tensor): + def forward(self, prediction: torch.Tensor, target: torch.Tensor): if prediction.size() != target.size() or target.ndim < 3: - raise TypeError(f"""Inputs must be of the same shape (batch_size,channels,samples) - got {prediction.size()} and {target.size()} instead""") - - target_energy = torch.sum(target**2,keepdim=True,dim=-1) - scaling_factor = torch.sum(prediction*target,keepdim=True,dim=-1) / target_energy + raise TypeError( + f"""Inputs must be of the 
same shape (batch_size,channels,samples) + got {prediction.size()} and {target.size()} instead""" + ) + + target_energy = torch.sum(target**2, keepdim=True, dim=-1) + scaling_factor = ( + torch.sum(prediction * target, keepdim=True, dim=-1) / target_energy + ) target_projection = target * scaling_factor noise = prediction - target_projection - ratio = torch.sum(target_projection**2,dim=-1) / torch.sum(noise**2,dim=-1) - si_sdr = 10*torch.log10(ratio).mean(dim=-1) + ratio = torch.sum(target_projection**2, dim=-1) / torch.sum( + noise**2, dim=-1 + ) + si_sdr = 10 * torch.log10(ratio).mean(dim=-1) if self.reduction == "sum": si_sdr = si_sdr.sum() @@ -66,46 +86,52 @@ class Si_SDR(nn.Module): si_sdr = si_sdr.mean() else: pass - + return si_sdr - class Avergeloss(nn.Module): + """ + Combine multiple metics of same nature. + for example, ["mea","mae"] + """ - def __init__(self,losses): + def __init__(self, losses): super().__init__() self.valid_losses = nn.ModuleList() - - direction = [getattr(LOSS_MAP[loss](),"higher_better") for loss in losses] + + direction = [ + getattr(LOSS_MAP[loss](), "higher_better") for loss in losses + ] if len(set(direction)) > 1: - raise ValueError("all cost functions should be of same nature, maximize or minimize!") + raise ValueError( + "all cost functions should be of same nature, maximize or minimize!" 
+ ) self.higher_better = direction[0] for loss in losses: loss = self.validate_loss(loss) self.valid_losses.append(loss()) - - def validate_loss(self,loss:str): + def validate_loss(self, loss: str): if loss not in LOSS_MAP.keys(): - raise ValueError(f"Invalid loss function {loss}, available loss functions are {tuple([loss for loss in LOSS_MAP.keys()])}") + raise ValueError( + f"Invalid loss function {loss}, available loss functions are {tuple([loss for loss in LOSS_MAP.keys()])}" + ) else: return LOSS_MAP[loss] - def forward(self,prediction:torch.Tensor, target:torch.Tensor): + def forward(self, prediction: torch.Tensor, target: torch.Tensor): loss = 0.0 for loss_fun in self.valid_losses: loss += loss_fun(prediction, target) - + return loss - - - -LOSS_MAP = {"mae":mean_absolute_error, - "mse": mean_squared_error, - "SI-SDR":Si_SDR} - +LOSS_MAP = { + "mae": mean_absolute_error, + "mse": mean_squared_error, + "SI-SDR": Si_SDR, +} From 96c6108ec6994998405b51e57dccca8c6cbb3bf5 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Wed, 5 Oct 2022 11:49:27 +0530 Subject: [PATCH 06/38] document average loss --- enhancer/loss.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/enhancer/loss.py b/enhancer/loss.py index 1e156a0..f2f62d3 100644 --- a/enhancer/loss.py +++ b/enhancer/loss.py @@ -94,6 +94,8 @@ class Avergeloss(nn.Module): """ Combine multiple metics of same nature. 
for example, ["mea","mae"] + parameters: + losses : loss function names to be combined """ def __init__(self, losses): From 2cf9803ed1bcd840608264d2e8640045bbb691b6 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Wed, 5 Oct 2022 12:15:10 +0530 Subject: [PATCH 07/38] refactor model.py --- enhancer/models/model.py | 309 ++++++++++++++++++++++++++------------- 1 file changed, 204 insertions(+), 105 deletions(-) diff --git a/enhancer/models/model.py b/enhancer/models/model.py index de2edab..071bbb6 100644 --- a/enhancer/models/model.py +++ b/enhancer/models/model.py @@ -27,66 +27,89 @@ CACHE_DIR = "" HF_TORCH_WEIGHTS = "" DEFAULT_DEVICE = "cpu" + class Model(pl.LightningModule): + """ + Base class for all models + parameters: + num_channels: int, default to 1 + number of channels in input audio + sampling_rate : int, default 16khz + audio sampling rate + lr: float, optional + learning rate for model training + dataset: EnhancerDataset, optional + Enhancer dataset used for training/validation + duration: float, optional + duration used for training/inference + loss : string or List of strings, default to "mse" + loss functions to be used. 
Available ("mse","mae","Si-SDR") + + """ def __init__( self, - num_channels:int=1, - sampling_rate:int=16000, - lr:float=1e-3, - dataset:Optional[EnhancerDataset]=None, - duration:Optional[float]=None, + num_channels: int = 1, + sampling_rate: int = 16000, + lr: float = 1e-3, + dataset: Optional[EnhancerDataset] = None, + duration: Optional[float] = None, loss: Union[str, List] = "mse", - metric:Union[str,List] = "mse" + metric: Union[str, List] = "mse", ): super().__init__() - assert num_channels ==1 , "Enhancer only support for mono channel models" + assert ( + num_channels == 1 + ), "Enhancer only support for mono channel models" self.dataset = dataset - self.save_hyperparameters("num_channels","sampling_rate","lr","loss","metric","duration") + self.save_hyperparameters( + "num_channels", "sampling_rate", "lr", "loss", "metric", "duration" + ) if self.logger: - self.logger.experiment.log_dict(dict(self.hparams),"hyperparameters.json") - + self.logger.experiment.log_dict( + dict(self.hparams), "hyperparameters.json" + ) + self.loss = loss self.metric = metric @property def loss(self): return self._loss - - @loss.setter - def loss(self,loss): - if isinstance(loss,str): - losses = [loss] - + @loss.setter + def loss(self, loss): + + if isinstance(loss, str): + losses = [loss] + self._loss = Avergeloss(losses) @property def metric(self): return self._metric - + @metric.setter - def metric(self,metric): + def metric(self, metric): + + if isinstance(metric, str): + metric = [metric] - if isinstance(metric,str): - metric = [metric] - self._metric = Avergeloss(metric) - @property def dataset(self): return self._dataset @dataset.setter - def dataset(self,dataset): + def dataset(self, dataset): self._dataset = dataset - def setup(self,stage:Optional[str]=None): + def setup(self, stage: Optional[str] = None): if stage == "fit": self.dataset.setup(stage) self.dataset.model = self - + def train_dataloader(self): return self.dataset.train_dataloader() @@ -94,9 +117,9 @@ class 
Model(pl.LightningModule): return self.dataset.val_dataloader() def configure_optimizers(self): - return Adam(self.parameters(), lr = self.hparams.lr) + return Adam(self.parameters(), lr=self.hparams.lr) - def training_step(self,batch, batch_idx:int): + def training_step(self, batch, batch_idx: int): mixed_waveform = batch["noisy"] target = batch["clean"] @@ -105,13 +128,16 @@ class Model(pl.LightningModule): loss = self.loss(prediction, target) if self.logger: - self.logger.experiment.log_metric(run_id=self.logger.run_id, - key="train_loss", value=loss.item(), - step=self.global_step) - self.log("train_loss",loss.item()) - return {"loss":loss} + self.logger.experiment.log_metric( + run_id=self.logger.run_id, + key="train_loss", + value=loss.item(), + step=self.global_step, + ) + self.log("train_loss", loss.item()) + return {"loss": loss} - def validation_step(self,batch,batch_idx:int): + def validation_step(self, batch, batch_idx: int): mixed_waveform = batch["noisy"] target = batch["clean"] @@ -119,48 +145,92 @@ class Model(pl.LightningModule): metric_val = self.metric(prediction, target) loss_val = self.loss(prediction, target) - self.log("val_metric",metric_val.item()) - self.log("val_loss",loss_val.item()) + self.log("val_metric", metric_val.item()) + self.log("val_loss", loss_val.item()) if self.logger: - self.logger.experiment.log_metric(run_id=self.logger.run_id, - key="val_loss",value=loss_val.item(), - step=self.global_step) - self.logger.experiment.log_metric(run_id=self.logger.run_id, - key="val_metric",value=metric_val.item(), - step=self.global_step) + self.logger.experiment.log_metric( + run_id=self.logger.run_id, + key="val_loss", + value=loss_val.item(), + step=self.global_step, + ) + self.logger.experiment.log_metric( + run_id=self.logger.run_id, + key="val_metric", + value=metric_val.item(), + step=self.global_step, + ) - return {"loss":loss_val} + return {"loss": loss_val} def on_save_checkpoint(self, checkpoint): checkpoint["enhancer"] = { - 
"version": { - "enhancer":__version__, - "pytorch":torch.__version__ + "version": {"enhancer": __version__, "pytorch": torch.__version__}, + "architecture": { + "module": self.__class__.__module__, + "class": self.__class__.__name__, }, - "architecture":{ - "module":self.__class__.__module__, - "class":self.__class__.__name__ - } - } def on_load_checkpoint(self, checkpoint: Dict[str, Any]): pass - @classmethod def from_pretrained( cls, checkpoint: Union[Path, Text], - map_location = None, + map_location=None, hparams_file: Union[Path, Text] = None, strict: bool = True, use_auth_token: Union[Text, None] = None, - cached_dir: Union[Path, Text]=CACHE_DIR, - **kwargs + cached_dir: Union[Path, Text] = CACHE_DIR, + **kwargs, ): + """ + Load Pretrained model + + parameters: + checkpoint : Path or str + Path to checkpoint, or a remote URL, or a model identifier from + the huggingface.co model hub. + map_location: optional + Same role as in torch.load(). + Defaults to `lambda storage, loc: storage`. + hparams_file : Path or str, optional + Path to a .yaml file with hierarchical structure as in this example: + drop_prob: 0.2 + dataloader: + batch_size: 32 + You most likely won’t need this since Lightning will always save the + hyperparameters to the checkpoint. However, if your checkpoint weights + do not have the hyperparameters saved, use this method to pass in a .yaml + file with the hparams you would like to use. These will be converted + into a dict and passed into your Model for use. + strict : bool, optional + Whether to strictly enforce that the keys in checkpoint match + the keys returned by this module’s state dict. Defaults to True. + use_auth_token : str, optional + When loading a private huggingface.co model, set `use_auth_token` + to True or to a string containing your hugginface.co authentication + token that can be obtained by running `huggingface-cli login` + cache_dir: Path or str, optional + Path to model cache directory. 
Defaults to content of PYANNOTE_CACHE + environment variable, or "~/.cache/torch/pyannote" when unset. + kwargs: optional + Any extra keyword args needed to init the model. + Can also be used to override saved hyperparameter values. + + Returns + ------- + model : Model + Model + + See also + -------- + torch.load + """ checkpoint = str(checkpoint) if hparams_file is not None: @@ -168,104 +238,133 @@ class Model(pl.LightningModule): if os.path.isfile(checkpoint): model_path_pl = checkpoint - elif urlparse(checkpoint).scheme in ("http","https"): + elif urlparse(checkpoint).scheme in ("http", "https"): model_path_pl = checkpoint else: - + if "@" in checkpoint: model_id = checkpoint.split("@")[0] revision_id = checkpoint.split("@")[1] else: model_id = checkpoint revision_id = None - + url = hf_hub_url( - model_id,filename=HF_TORCH_WEIGHTS,revision=revision_id + model_id, filename=HF_TORCH_WEIGHTS, revision=revision_id ) model_path_pl = cached_download( - url=url,library_name="enhancer",library_version=__version__, - cache_dir=cached_dir,use_auth_token=use_auth_token + url=url, + library_name="enhancer", + library_version=__version__, + cache_dir=cached_dir, + use_auth_token=use_auth_token, ) if map_location is None: map_location = torch.device(DEFAULT_DEVICE) - loaded_checkpoint = pl_load(model_path_pl,map_location) + loaded_checkpoint = pl_load(model_path_pl, map_location) module_name = loaded_checkpoint["enhancer"]["architecture"]["module"] - class_name = loaded_checkpoint["enhancer"]["architecture"]["class"] + class_name = loaded_checkpoint["enhancer"]["architecture"]["class"] module = import_module(module_name) Klass = getattr(module, class_name) try: model = Klass.load_from_checkpoint( - checkpoint_path = model_path_pl, - map_location = map_location, - hparams_file = hparams_file, - strict = strict, - **kwargs + checkpoint_path=model_path_pl, + map_location=map_location, + hparams_file=hparams_file, + strict=strict, + **kwargs, ) except Exception as e: print(e) + 
return model - return model + def infer(self, batch: torch.Tensor, batch_size: int = 32): + """ + perform model inference + parameters: + batch : torch.Tensor + input data + batch_size : int, default 32 + batch size for inference + """ - def infer(self,batch:torch.Tensor,batch_size:int=32): - - assert batch.ndim == 3, f"Expected batch with 3 dimensions (batch,channels,samples) got only {batch.ndim}" + assert ( + batch.ndim == 3 + ), f"Expected batch with 3 dimensions (batch,channels,samples) got only {batch.ndim}" batch_predictions = [] self.eval().to(self.device) with torch.no_grad(): - for batch_id in range(0,batch.shape[0],batch_size): - batch_data = batch[batch_id:batch_id+batch_size,:,:].to(self.device) + for batch_id in range(0, batch.shape[0], batch_size): + batch_data = batch[batch_id : batch_id + batch_size, :, :].to( + self.device + ) prediction = self(batch_data) batch_predictions.append(prediction) - + return torch.vstack(batch_predictions) def enhance( self, - audio:Union[Path,np.ndarray,torch.Tensor], - sampling_rate:Optional[int]=None, - batch_size:int=32, - save_output:bool=False, - duration:Optional[int]=None, - step_size:Optional[int]=None,): + audio: Union[Path, np.ndarray, torch.Tensor], + sampling_rate: Optional[int] = None, + batch_size: int = 32, + save_output: bool = False, + duration: Optional[int] = None, + step_size: Optional[int] = None, + ): + """ + Enhance audio using loaded pretained model. + + parameters: + audio: Path to audio file or numpy array or torch tensor + single input audio + sampling_rate: int, optional incase input is path + sampling rate of input + batch_size: int, default 32 + input audio is split into multiple chunks. Inference is done on batches + of these chunks according to given batch size. + save_output : bool, default False + weather to save output to file + duration : float, optional + chunk duration in seconds, defaults to duration of loaded pretrained model. 
+ step_size: int, optional + step size between consecutive durations, defaults to 50% of duration + """ model_sampling_rate = self.hparams["sampling_rate"] if duration is None: duration = self.hparams["duration"] - waveform = Inference.read_input(audio,sampling_rate,model_sampling_rate) + waveform = Inference.read_input( + audio, sampling_rate, model_sampling_rate + ) waveform.to(self.device) window_size = round(duration * model_sampling_rate) - batched_waveform = Inference.batchify(waveform,window_size,step_size=step_size) - batch_prediction = self.infer(batched_waveform,batch_size=batch_size) - waveform = Inference.aggreagate(batch_prediction,window_size,waveform.shape[-1],step_size,) - - if save_output and isinstance(audio,(str,Path)): - Inference.write_output(waveform,audio,model_sampling_rate) + batched_waveform = Inference.batchify( + waveform, window_size, step_size=step_size + ) + batch_prediction = self.infer(batched_waveform, batch_size=batch_size) + waveform = Inference.aggreagate( + batch_prediction, + window_size, + waveform.shape[-1], + step_size, + ) + + if save_output and isinstance(audio, (str, Path)): + Inference.write_output(waveform, audio, model_sampling_rate) else: - waveform = Inference.prepare_output(waveform, model_sampling_rate, - audio, sampling_rate) + waveform = Inference.prepare_output( + waveform, model_sampling_rate, audio, sampling_rate + ) return waveform + @property def valid_monitor(self): - return "max" if self.loss.higher_better else "min" - - - - - - - - - - - - - - - \ No newline at end of file + return "max" if self.loss.higher_better else "min" From 451058c29dad413053536e6dc589fc7812a5821e Mon Sep 17 00:00:00 2001 From: shahules786 Date: Wed, 5 Oct 2022 12:50:26 +0530 Subject: [PATCH 08/38] refactor models --- enhancer/models/demucs.py | 284 +++++++++++++++++++++--------------- enhancer/models/waveunet.py | 191 +++++++++++++++--------- 2 files changed, 283 insertions(+), 192 deletions(-) diff --git 
a/enhancer/models/demucs.py b/enhancer/models/demucs.py index 7c9d8ff..76a0bf7 100644 --- a/enhancer/models/demucs.py +++ b/enhancer/models/demucs.py @@ -9,209 +9,255 @@ from enhancer.data.dataset import EnhancerDataset from enhancer.utils.io import Audio as audio from enhancer.utils.utils import merge_dict + class DemucsLSTM(nn.Module): def __init__( self, - input_size:int, - hidden_size:int, - num_layers:int, - bidirectional:bool=True - + input_size: int, + hidden_size: int, + num_layers: int, + bidirectional: bool = True, ): super().__init__() - self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bidirectional=bidirectional) + self.lstm = nn.LSTM( + input_size, hidden_size, num_layers, bidirectional=bidirectional + ) dim = 2 if bidirectional else 1 - self.linear = nn.Linear(dim*hidden_size,hidden_size) + self.linear = nn.Linear(dim * hidden_size, hidden_size) - def forward(self,x): + def forward(self, x): - output,(h,c) = self.lstm(x) + output, (h, c) = self.lstm(x) output = self.linear(output) - return output,(h,c) + return output, (h, c) class DemucsEncoder(nn.Module): - def __init__( self, - num_channels:int, - hidden_size:int, - kernel_size:int, - stride:int=1, - glu:bool=False, + num_channels: int, + hidden_size: int, + kernel_size: int, + stride: int = 1, + glu: bool = False, ): super().__init__() activation = nn.GLU(1) if glu else nn.ReLU() multi_factor = 2 if glu else 1 self.encoder = nn.Sequential( - nn.Conv1d(num_channels,hidden_size,kernel_size,stride), + nn.Conv1d(num_channels, hidden_size, kernel_size, stride), nn.ReLU(), - nn.Conv1d(hidden_size, hidden_size*multi_factor,kernel_size,1), - activation + nn.Conv1d(hidden_size, hidden_size * multi_factor, kernel_size, 1), + activation, ) - def forward(self,waveform): - + def forward(self, waveform): + return self.encoder(waveform) -class DemucsDecoder(nn.Module): +class DemucsDecoder(nn.Module): def __init__( self, - num_channels:int, - hidden_size:int, - kernel_size:int, - stride:int=1, - 
glu:bool=False, - layer:int=0 + num_channels: int, + hidden_size: int, + kernel_size: int, + stride: int = 1, + glu: bool = False, + layer: int = 0, ): super().__init__() activation = nn.GLU(1) if glu else nn.ReLU() multi_factor = 2 if glu else 1 self.decoder = nn.Sequential( - nn.Conv1d(hidden_size,hidden_size*multi_factor,kernel_size,1), + nn.Conv1d(hidden_size, hidden_size * multi_factor, kernel_size, 1), activation, - nn.ConvTranspose1d(hidden_size,num_channels,kernel_size,stride) + nn.ConvTranspose1d(hidden_size, num_channels, kernel_size, stride), ) - if layer>0: + if layer > 0: self.decoder.add_module("4", nn.ReLU()) - def forward(self,waveform,): + def forward( + self, + waveform, + ): out = self.decoder(waveform) return out class Demucs(Model): + """ + Demucs model from https://arxiv.org/pdf/1911.13254.pdf + parameters: + encoder_decoder: dict, optional + keyword arguments passsed to encoder decoder block + lstm : dict, optional + keyword arguments passsed to LSTM block + num_channels: int, defaults to 1 + number channels in input audio + sampling_rate: int, defaults to 16KHz + sampling rate of input audio + lr : float, defaults to 1e-3 + learning rate used for training + dataset: EnhancerDataset, optional + EnhancerDataset object containing train/validation data for training + duration : float, optional + chunk duration in seconds + loss : string or List of strings + loss function to be used, available ("mse","mae","SI-SDR") + metric : string or List of strings + metric function to be used, available ("mse","mae","SI-SDR") + + """ ED_DEFAULTS = { - "initial_output_channels":48, - "kernel_size":8, - "stride":1, - "depth":5, - "glu":True, - "growth_factor":2, + "initial_output_channels": 48, + "kernel_size": 8, + "stride": 1, + "depth": 5, + "glu": True, + "growth_factor": 2, } LSTM_DEFAULTS = { - "bidirectional":True, - "num_layers":2, + "bidirectional": True, + "num_layers": 2, } - + def __init__( self, - encoder_decoder:Optional[dict]=None, - 
lstm:Optional[dict]=None, - num_channels:int=1, - resample:int=4, - sampling_rate = 16000, - lr:float=1e-3, - dataset:Optional[EnhancerDataset]=None, - loss:Union[str, List] = "mse", - metric:Union[str, List] = "mse" - - + encoder_decoder: Optional[dict] = None, + lstm: Optional[dict] = None, + num_channels: int = 1, + resample: int = 4, + sampling_rate=16000, + lr: float = 1e-3, + dataset: Optional[EnhancerDataset] = None, + loss: Union[str, List] = "mse", + metric: Union[str, List] = "mse", ): - duration = dataset.duration if isinstance(dataset,EnhancerDataset) else None + duration = ( + dataset.duration if isinstance(dataset, EnhancerDataset) else None + ) if dataset is not None: - if sampling_rate!=dataset.sampling_rate: - logging.warn(f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}") + if sampling_rate != dataset.sampling_rate: + logging.warn( + f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}" + ) sampling_rate = dataset.sampling_rate - super().__init__(num_channels=num_channels, - sampling_rate=sampling_rate,lr=lr, - dataset=dataset,duration=duration,loss=loss, metric=metric) - - encoder_decoder = merge_dict(self.ED_DEFAULTS,encoder_decoder) - lstm = merge_dict(self.LSTM_DEFAULTS,lstm) - self.save_hyperparameters("encoder_decoder","lstm","resample") + super().__init__( + num_channels=num_channels, + sampling_rate=sampling_rate, + lr=lr, + dataset=dataset, + duration=duration, + loss=loss, + metric=metric, + ) + + encoder_decoder = merge_dict(self.ED_DEFAULTS, encoder_decoder) + lstm = merge_dict(self.LSTM_DEFAULTS, lstm) + self.save_hyperparameters("encoder_decoder", "lstm", "resample") hidden = encoder_decoder["initial_output_channels"] self.encoder = nn.ModuleList() self.decoder = nn.ModuleList() for layer in range(encoder_decoder["depth"]): - encoder_layer = DemucsEncoder(num_channels=num_channels, - hidden_size=hidden, - 
kernel_size=encoder_decoder["kernel_size"], - stride=encoder_decoder["stride"], - glu=encoder_decoder["glu"], - ) + encoder_layer = DemucsEncoder( + num_channels=num_channels, + hidden_size=hidden, + kernel_size=encoder_decoder["kernel_size"], + stride=encoder_decoder["stride"], + glu=encoder_decoder["glu"], + ) self.encoder.append(encoder_layer) - decoder_layer = DemucsDecoder(num_channels=num_channels, - hidden_size=hidden, - kernel_size=encoder_decoder["kernel_size"], - stride=1, - glu=encoder_decoder["glu"], - layer=layer - ) - self.decoder.insert(0,decoder_layer) + decoder_layer = DemucsDecoder( + num_channels=num_channels, + hidden_size=hidden, + kernel_size=encoder_decoder["kernel_size"], + stride=1, + glu=encoder_decoder["glu"], + layer=layer, + ) + self.decoder.insert(0, decoder_layer) num_channels = hidden hidden = self.ED_DEFAULTS["growth_factor"] * hidden - - self.de_lstm = DemucsLSTM(input_size=num_channels, - hidden_size=num_channels, - num_layers=lstm["num_layers"], - bidirectional=lstm["bidirectional"] - ) - def forward(self,waveform): + self.de_lstm = DemucsLSTM( + input_size=num_channels, + hidden_size=num_channels, + num_layers=lstm["num_layers"], + bidirectional=lstm["bidirectional"], + ) + + def forward(self, waveform): if waveform.dim() == 2: waveform = waveform.unsqueeze(1) - if waveform.size(1)!=1: - raise TypeError(f"Demucs can only process mono channel audio, input has {waveform.size(1)} channels") + if waveform.size(1) != 1: + raise TypeError( + f"Demucs can only process mono channel audio, input has {waveform.size(1)} channels" + ) length = waveform.shape[-1] - x = F.pad(waveform, (0,self.get_padding_length(length) - length)) - if self.hparams.resample>1: - x = audio.resample_audio(audio=x, sr=self.hparams.sampling_rate, - target_sr=int(self.hparams.sampling_rate * self.hparams.resample)) - + x = F.pad(waveform, (0, self.get_padding_length(length) - length)) + if self.hparams.resample > 1: + x = audio.resample_audio( + audio=x, + 
sr=self.hparams.sampling_rate, + target_sr=int( + self.hparams.sampling_rate * self.hparams.resample + ), + ) + encoder_outputs = [] for encoder in self.encoder: x = encoder(x) encoder_outputs.append(x) - x = x.permute(0,2,1) - x,_ = self.de_lstm(x) + x = x.permute(0, 2, 1) + x, _ = self.de_lstm(x) - x = x.permute(0,2,1) + x = x.permute(0, 2, 1) for decoder in self.decoder: skip_connection = encoder_outputs.pop(-1) - x += skip_connection[..., :x.shape[-1]] + x += skip_connection[..., : x.shape[-1]] x = decoder(x) - + if self.hparams.resample > 1: - x = audio.resample_audio(x,int(self.hparams.sampling_rate * self.hparams.resample), - self.hparams.sampling_rate) + x = audio.resample_audio( + x, + int(self.hparams.sampling_rate * self.hparams.resample), + self.hparams.sampling_rate, + ) return x - - def get_padding_length(self,input_length): + + def get_padding_length(self, input_length): input_length = math.ceil(input_length * self.hparams.resample) - - for layer in range(self.hparams.encoder_decoder["depth"]): # encoder operation - input_length = math.ceil((input_length - self.hparams.encoder_decoder["kernel_size"])/self.hparams.encoder_decoder["stride"])+1 - input_length = max(1,input_length) - for layer in range(self.hparams.encoder_decoder["depth"]): # decoder operaration - input_length = (input_length-1) * self.hparams.encoder_decoder["stride"] + self.hparams.encoder_decoder["kernel_size"] - input_length = math.ceil(input_length/self.hparams.resample) + for layer in range( + self.hparams.encoder_decoder["depth"] + ): # encoder operation + input_length = ( + math.ceil( + (input_length - self.hparams.encoder_decoder["kernel_size"]) + / self.hparams.encoder_decoder["stride"] + ) + + 1 + ) + input_length = max(1, input_length) + for layer in range( + self.hparams.encoder_decoder["depth"] + ): # decoder operaration + input_length = (input_length - 1) * self.hparams.encoder_decoder[ + "stride" + ] + self.hparams.encoder_decoder["kernel_size"] + input_length = 
math.ceil(input_length / self.hparams.resample) return int(input_length) - - - - - - - - - - - - - \ No newline at end of file diff --git a/enhancer/models/waveunet.py b/enhancer/models/waveunet.py index f799352..4d5cc0a 100644 --- a/enhancer/models/waveunet.py +++ b/enhancer/models/waveunet.py @@ -7,76 +7,117 @@ from typing import Optional, Union, List from enhancer.models.model import Model from enhancer.data.dataset import EnhancerDataset -class WavenetDecoder(nn.Module): +class WavenetDecoder(nn.Module): def __init__( self, - in_channels:int, - out_channels:int, - kernel_size:int=5, - padding:int=2, - stride:int=1, - dilation:int=1, + in_channels: int, + out_channels: int, + kernel_size: int = 5, + padding: int = 2, + stride: int = 1, + dilation: int = 1, ): - super(WavenetDecoder,self).__init__() + super(WavenetDecoder, self).__init__() self.decoder = nn.Sequential( - nn.Conv1d(in_channels,out_channels,kernel_size,stride=stride,padding=padding,dilation=dilation), + nn.Conv1d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ), nn.BatchNorm1d(out_channels), - nn.LeakyReLU(negative_slope=0.1) + nn.LeakyReLU(negative_slope=0.1), ) - - def forward(self,waveform): - + + def forward(self, waveform): + return self.decoder(waveform) -class WavenetEncoder(nn.Module): +class WavenetEncoder(nn.Module): def __init__( self, - in_channels:int, - out_channels:int, - kernel_size:int=15, - padding:int=7, - stride:int=1, - dilation:int=1, + in_channels: int, + out_channels: int, + kernel_size: int = 15, + padding: int = 7, + stride: int = 1, + dilation: int = 1, ): - super(WavenetEncoder,self).__init__() + super(WavenetEncoder, self).__init__() self.encoder = nn.Sequential( - nn.Conv1d(in_channels,out_channels,kernel_size,stride=stride,padding=padding,dilation=dilation), + nn.Conv1d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ), nn.BatchNorm1d(out_channels), - 
nn.LeakyReLU(negative_slope=0.1) + nn.LeakyReLU(negative_slope=0.1), ) - - def forward( - self, - waveform - ): + def forward(self, waveform): return self.encoder(waveform) class WaveUnet(Model): + """ + Wave-U-Net model from https://arxiv.org/pdf/1811.11307.pdf + parameters: + num_channels: int, defaults to 1 + number of channels in input audio + depth : int, defaults to 12 + depth of network + initial_output_channels: int, defaults to 24 + number of output channels in intial upsampling layer + sampling_rate: int, defaults to 16KHz + sampling rate of input audio + lr : float, defaults to 1e-3 + learning rate used for training + dataset: EnhancerDataset, optional + EnhancerDataset object containing train/validation data for training + duration : float, optional + chunk duration in seconds + loss : string or List of strings + loss function to be used, available ("mse","mae","SI-SDR") + metric : string or List of strings + metric function to be used, available ("mse","mae","SI-SDR") + """ def __init__( self, - num_channels:int=1, - depth:int=12, - initial_output_channels:int=24, - sampling_rate:int=16000, - lr:float=1e-3, - dataset:Optional[EnhancerDataset]=None, - duration:Optional[float]=None, + num_channels: int = 1, + depth: int = 12, + initial_output_channels: int = 24, + sampling_rate: int = 16000, + lr: float = 1e-3, + dataset: Optional[EnhancerDataset] = None, + duration: Optional[float] = None, loss: Union[str, List] = "mse", - metric:Union[str,List] = "mse" + metric: Union[str, List] = "mse", ): - duration = dataset.duration if isinstance(dataset,EnhancerDataset) else None + duration = ( + dataset.duration if isinstance(dataset, EnhancerDataset) else None + ) if dataset is not None: - if sampling_rate!=dataset.sampling_rate: - logging.warn(f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}") + if sampling_rate != dataset.sampling_rate: + logging.warn( + f"model sampling rate {sampling_rate} should match dataset 
sampling rate {dataset.sampling_rate}" + ) sampling_rate = dataset.sampling_rate - super().__init__(num_channels=num_channels, - sampling_rate=sampling_rate,lr=lr, - dataset=dataset,duration=duration,loss=loss, metric=metric + super().__init__( + num_channels=num_channels, + sampling_rate=sampling_rate, + lr=lr, + dataset=dataset, + duration=duration, + loss=loss, + metric=metric, ) self.save_hyperparameters("depth") self.encoders = nn.ModuleList() @@ -84,72 +125,76 @@ class WaveUnet(Model): out_channels = initial_output_channels for layer in range(depth): - encoder = WavenetEncoder(num_channels,out_channels) + encoder = WavenetEncoder(num_channels, out_channels) self.encoders.append(encoder) num_channels = out_channels out_channels += initial_output_channels - if layer == depth -1 : - decoder = WavenetDecoder(depth * initial_output_channels + num_channels,num_channels) + if layer == depth - 1: + decoder = WavenetDecoder( + depth * initial_output_channels + num_channels, num_channels + ) else: - decoder = WavenetDecoder(num_channels+out_channels,num_channels) + decoder = WavenetDecoder( + num_channels + out_channels, num_channels + ) - self.decoders.insert(0,decoder) + self.decoders.insert(0, decoder) bottleneck_dim = depth * initial_output_channels self.bottleneck = nn.Sequential( - nn.Conv1d(bottleneck_dim,bottleneck_dim, 15, stride=1, - padding=7), + nn.Conv1d(bottleneck_dim, bottleneck_dim, 15, stride=1, padding=7), nn.BatchNorm1d(bottleneck_dim), - nn.LeakyReLU(negative_slope=0.1, inplace=True) + nn.LeakyReLU(negative_slope=0.1, inplace=True), ) self.final = nn.Sequential( nn.Conv1d(1 + initial_output_channels, 1, kernel_size=1, stride=1), - nn.Tanh() + nn.Tanh(), ) - - def forward( - self,waveform - ): + def forward(self, waveform): if waveform.dim() == 2: waveform = waveform.unsqueeze(1) - if waveform.size(1)!=1: - raise TypeError(f"Wave-U-Net can only process mono channel audio, input has {waveform.size(1)} channels") + if waveform.size(1) != 1: + raise 
TypeError( + f"Wave-U-Net can only process mono channel audio, input has {waveform.size(1)} channels" + ) encoder_outputs = [] out = waveform for encoder in self.encoders: out = encoder(out) - encoder_outputs.insert(0,out) - out = out[:,:,::2] - + encoder_outputs.insert(0, out) + out = out[:, :, ::2] + out = self.bottleneck(out) - for layer,decoder in enumerate(self.decoders): + for layer, decoder in enumerate(self.decoders): out = F.interpolate(out, scale_factor=2, mode="linear") - out = self.fix_last_dim(out,encoder_outputs[layer]) - out = torch.cat([out,encoder_outputs[layer]],dim=1) + out = self.fix_last_dim(out, encoder_outputs[layer]) + out = torch.cat([out, encoder_outputs[layer]], dim=1) out = decoder(out) - out = torch.cat([out, waveform],dim=1) + out = torch.cat([out, waveform], dim=1) out = self.final(out) return out - - def fix_last_dim(self,x,target): + + def fix_last_dim(self, x, target): """ - trying to do centre crop along last dimension + centre crop along last dimension """ - assert x.shape[-1] >= target.shape[-1], "input dimension cannot be larger than target dimension" + assert ( + x.shape[-1] >= target.shape[-1] + ), "input dimension cannot be larger than target dimension" if x.shape[-1] == target.shape[-1]: return x - + diff = x.shape[-1] - target.shape[-1] - if diff%2!=0: - x = F.pad(x,(0,1)) + if diff % 2 != 0: + x = F.pad(x, (0, 1)) diff += 1 - crop = diff//2 - return x[:,:,crop:-crop] + crop = diff // 2 + return x[:, :, crop:-crop] From b92310c93d33b5daf495cf3ee34bfbca5b2d1b19 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Wed, 5 Oct 2022 15:05:40 +0530 Subject: [PATCH 09/38] refactor data modules --- enhancer/data/__init__.py | 1 + enhancer/data/dataset.py | 173 ++++++++++++++++++++------------- enhancer/data/fileprocessor.py | 118 ++++++++++++---------- 3 files changed, 169 insertions(+), 123 deletions(-) diff --git a/enhancer/data/__init__.py b/enhancer/data/__init__.py index e69de29..3ec018e 100644 --- a/enhancer/data/__init__.py +++ 
b/enhancer/data/__init__.py @@ -0,0 +1 @@ +from enhancer.data.dataset import EnhancerDataset \ No newline at end of file diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py index 4c485c8..d194167 100644 --- a/enhancer/data/dataset.py +++ b/enhancer/data/dataset.py @@ -12,9 +12,9 @@ from enhancer.utils.io import Audio from enhancer.utils import check_files from enhancer.utils.config import Files + class TrainDataset(IterableDataset): - - def __init__(self,dataset): + def __init__(self, dataset): self.dataset = dataset def __iter__(self): @@ -23,88 +23,102 @@ class TrainDataset(IterableDataset): def __len__(self): return self.dataset.train__len__() + class ValidDataset(Dataset): - - def __init__(self,dataset): + def __init__(self, dataset): self.dataset = dataset - def __getitem__(self,idx): + def __getitem__(self, idx): return self.dataset.val__getitem__(idx) def __len__(self): return self.dataset.val__len__() -class TaskDataset(pl.LightningDataModule): +class TaskDataset(pl.LightningDataModule): def __init__( self, - name:str, - root_dir:str, - files:Files, - duration:float=1.0, - sampling_rate:int=48000, - matching_function = None, + name: str, + root_dir: str, + files: Files, + duration: float = 1.0, + sampling_rate: int = 48000, + matching_function=None, batch_size=32, - num_workers:Optional[int]=None): + num_workers: Optional[int] = None, + ): super().__init__() self.name = name - self.files,self.root_dir = check_files(root_dir,files) + self.files, self.root_dir = check_files(root_dir, files) self.duration = duration self.sampling_rate = sampling_rate self.batch_size = batch_size self.matching_function = matching_function self._validation = [] if num_workers is None: - num_workers = multiprocessing.cpu_count()//2 + num_workers = multiprocessing.cpu_count() // 2 self.num_workers = num_workers def setup(self, stage: Optional[str] = None): - if stage in ("fit",None): + if stage in ("fit", None): - train_clean = 
os.path.join(self.root_dir,self.files.train_clean) - train_noisy = os.path.join(self.root_dir,self.files.train_noisy) - fp = Fileprocessor.from_name(self.name,train_clean, - train_noisy, self.matching_function) + train_clean = os.path.join(self.root_dir, self.files.train_clean) + train_noisy = os.path.join(self.root_dir, self.files.train_noisy) + fp = Fileprocessor.from_name( + self.name, train_clean, train_noisy, self.matching_function + ) self.train_data = fp.prepare_matching_dict() - - val_clean = os.path.join(self.root_dir,self.files.test_clean) - val_noisy = os.path.join(self.root_dir,self.files.test_noisy) - fp = Fileprocessor.from_name(self.name,val_clean, - val_noisy, self.matching_function) + + val_clean = os.path.join(self.root_dir, self.files.test_clean) + val_noisy = os.path.join(self.root_dir, self.files.test_noisy) + fp = Fileprocessor.from_name( + self.name, val_clean, val_noisy, self.matching_function + ) val_data = fp.prepare_matching_dict() for item in val_data: - clean,noisy,total_dur = item.values() + clean, noisy, total_dur = item.values() if total_dur < self.duration: continue - num_segments = round(total_dur/self.duration) + num_segments = round(total_dur / self.duration) for index in range(num_segments): start_time = index * self.duration - self._validation.append(({"clean":clean,"noisy":noisy}, - start_time)) + self._validation.append( + ({"clean": clean, "noisy": noisy}, start_time) + ) + def train_dataloader(self): - return DataLoader(TrainDataset(self), batch_size = self.batch_size,num_workers=self.num_workers) + return DataLoader( + TrainDataset(self), + batch_size=self.batch_size, + num_workers=self.num_workers, + ) def val_dataloader(self): - return DataLoader(ValidDataset(self), batch_size = self.batch_size,num_workers=self.num_workers) + return DataLoader( + ValidDataset(self), + batch_size=self.batch_size, + num_workers=self.num_workers, + ) + class EnhancerDataset(TaskDataset): """ Dataset object for creating clean-noisy speech 
enhancement datasets paramters: name : str - name of the dataset + name of the dataset root_dir : str root directory of the dataset containing clean/noisy folders files : Files - dataclass containing train_clean, train_noisy, test_clean, test_noisy - folder names (refer cli/train_config/dataset) + dataclass containing train_clean, train_noisy, test_clean, test_noisy + folder names (refer enhancer.utils.Files dataclass) duration : float expected audio duration of single audio sample for training sampling_rate : int - desired sampling rate + desired sampling rate batch_size : int batch size of each batch num_workers : int @@ -114,71 +128,92 @@ class EnhancerDataset(TaskDataset): use one_to_one mapping for datasets with one noisy file for each clean file use one_to_many mapping for multiple noisy files for each clean file - + """ def __init__( self, - name:str, - root_dir:str, - files:Files, + name: str, + root_dir: str, + files: Files, duration=1.0, sampling_rate=48000, matching_function=None, batch_size=32, - num_workers:Optional[int]=None): - + num_workers: Optional[int] = None, + ): + super().__init__( name=name, root_dir=root_dir, files=files, sampling_rate=sampling_rate, duration=duration, - matching_function = matching_function, + matching_function=matching_function, batch_size=batch_size, - num_workers = num_workers, - + num_workers=num_workers, ) self.sampling_rate = sampling_rate self.files = files - self.duration = max(1.0,duration) - self.audio = Audio(self.sampling_rate,mono=True,return_tensor=True) + self.duration = max(1.0, duration) + self.audio = Audio(self.sampling_rate, mono=True, return_tensor=True) + + def setup(self, stage: Optional[str] = None): - def setup(self, stage:Optional[str]=None): - super().setup(stage=stage) def train__iter__(self): - rng = create_unique_rng(self.model.current_epoch) - + rng = create_unique_rng(self.model.current_epoch) + while True: - file_dict,*_ = rng.choices(self.train_data,k=1, - weights=[file["duration"] for file 
in self.train_data]) - file_duration = file_dict['duration'] - start_time = round(rng.uniform(0,file_duration- self.duration),2) - data = self.prepare_segment(file_dict,start_time) + file_dict, *_ = rng.choices( + self.train_data, + k=1, + weights=[file["duration"] for file in self.train_data], + ) + file_duration = file_dict["duration"] + start_time = round(rng.uniform(0, file_duration - self.duration), 2) + data = self.prepare_segment(file_dict, start_time) yield data - def val__getitem__(self,idx): + def val__getitem__(self, idx): return self.prepare_segment(*self._validation[idx]) - - def prepare_segment(self,file_dict:dict, start_time:float): - clean_segment = self.audio(file_dict["clean"], - offset=start_time,duration=self.duration) - noisy_segment = self.audio(file_dict["noisy"], - offset=start_time,duration=self.duration) - clean_segment = F.pad(clean_segment,(0,int(self.duration*self.sampling_rate-clean_segment.shape[-1]))) - noisy_segment = F.pad(noisy_segment,(0,int(self.duration*self.sampling_rate-noisy_segment.shape[-1]))) - return {"clean": clean_segment,"noisy":noisy_segment} - + def prepare_segment(self, file_dict: dict, start_time: float): + + clean_segment = self.audio( + file_dict["clean"], offset=start_time, duration=self.duration + ) + noisy_segment = self.audio( + file_dict["noisy"], offset=start_time, duration=self.duration + ) + clean_segment = F.pad( + clean_segment, + ( + 0, + int( + self.duration * self.sampling_rate - clean_segment.shape[-1] + ), + ), + ) + noisy_segment = F.pad( + noisy_segment, + ( + 0, + int( + self.duration * self.sampling_rate - noisy_segment.shape[-1] + ), + ), + ) + return {"clean": clean_segment, "noisy": noisy_segment} + def train__len__(self): - return math.ceil(sum([file["duration"] for file in self.train_data])/self.duration) + return math.ceil( + sum([file["duration"] for file in self.train_data]) / self.duration + ) def val__len__(self): return len(self._validation) - - diff --git 
a/enhancer/data/fileprocessor.py b/enhancer/data/fileprocessor.py index eab41a0..106f649 100644 --- a/enhancer/data/fileprocessor.py +++ b/enhancer/data/fileprocessor.py @@ -4,105 +4,115 @@ from re import S import numpy as np from scipy.io import wavfile -MATCHING_FNS = ("one_to_one","one_to_many") +MATCHING_FNS = ("one_to_one", "one_to_many") + class ProcessorFunctions: + """ + Preprocessing methods for different types of speech enhacement datasets. + """ @staticmethod - def one_to_one(clean_path,noisy_path): + def one_to_one(clean_path, noisy_path): """ One clean audio can have only one noisy audio file """ matching_wavfiles = list() - clean_filenames = [file.split('/')[-1] for file in glob.glob(os.path.join(clean_path,"*.wav"))] - noisy_filenames = [file.split('/')[-1] for file in glob.glob(os.path.join(noisy_path,"*.wav"))] - common_filenames = np.intersect1d(noisy_filenames,clean_filenames) + clean_filenames = [ + file.split("/")[-1] + for file in glob.glob(os.path.join(clean_path, "*.wav")) + ] + noisy_filenames = [ + file.split("/")[-1] + for file in glob.glob(os.path.join(noisy_path, "*.wav")) + ] + common_filenames = np.intersect1d(noisy_filenames, clean_filenames) for file_name in common_filenames: - sr_clean, clean_file = wavfile.read(os.path.join(clean_path,file_name)) - sr_noisy, noisy_file = wavfile.read(os.path.join(noisy_path,file_name)) - if ((clean_file.shape[-1]==noisy_file.shape[-1]) and - (sr_clean==sr_noisy)): + sr_clean, clean_file = wavfile.read( + os.path.join(clean_path, file_name) + ) + sr_noisy, noisy_file = wavfile.read( + os.path.join(noisy_path, file_name) + ) + if (clean_file.shape[-1] == noisy_file.shape[-1]) and ( + sr_clean == sr_noisy + ): matching_wavfiles.append( - {"clean":os.path.join(clean_path,file_name),"noisy":os.path.join(noisy_path,file_name), - "duration":clean_file.shape[-1]/sr_clean} - ) + { + "clean": os.path.join(clean_path, file_name), + "noisy": os.path.join(noisy_path, file_name), + "duration": 
clean_file.shape[-1] / sr_clean, + } + ) return matching_wavfiles @staticmethod - def one_to_many(clean_path,noisy_path): + def one_to_many(clean_path, noisy_path): """ One clean audio have multiple noisy audio files """ - + matching_wavfiles = dict() - clean_filenames = [file.split('/')[-1] for file in glob.glob(os.path.join(clean_path,"*.wav"))] + clean_filenames = [ + file.split("/")[-1] + for file in glob.glob(os.path.join(clean_path, "*.wav")) + ] for clean_file in clean_filenames: - noisy_filenames = glob.glob(os.path.join(noisy_path,f"*_{clean_file}.wav")) + noisy_filenames = glob.glob( + os.path.join(noisy_path, f"*_{clean_file}.wav") + ) for noisy_file in noisy_filenames: - sr_clean, clean_file = wavfile.read(os.path.join(clean_path,clean_file)) + sr_clean, clean_file = wavfile.read( + os.path.join(clean_path, clean_file) + ) sr_noisy, noisy_file = wavfile.read(noisy_file) - if ((clean_file.shape[-1]==noisy_file.shape[-1]) and - (sr_clean==sr_noisy)): + if (clean_file.shape[-1] == noisy_file.shape[-1]) and ( + sr_clean == sr_noisy + ): matching_wavfiles.update( - {"clean":os.path.join(clean_path,clean_file),"noisy":noisy_file, - "duration":clean_file.shape[-1]/sr_clean} - ) + { + "clean": os.path.join(clean_path, clean_file), + "noisy": noisy_file, + "duration": clean_file.shape[-1] / sr_clean, + } + ) return matching_wavfiles class Fileprocessor: - - def __init__( - self, - clean_dir, - noisy_dir, - matching_function = None - ): + def __init__(self, clean_dir, noisy_dir, matching_function=None): self.clean_dir = clean_dir self.noisy_dir = noisy_dir self.matching_function = matching_function @classmethod - def from_name(cls, - name:str, - clean_dir, - noisy_dir, - matching_function=None - ): + def from_name(cls, name: str, clean_dir, noisy_dir, matching_function=None): if matching_function is None: if name.lower() == "vctk": - return cls(clean_dir,noisy_dir, ProcessorFunctions.one_to_one) + return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_one) 
elif name.lower() == "dns-2020": - return cls(clean_dir,noisy_dir, ProcessorFunctions.one_to_many) + return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_many) else: if matching_function not in MATCHING_FNS: - raise ValueError(F"Invalid matching function! Avaialble options are {MATCHING_FNS}") + raise ValueError( + f"Invalid matching function! Avaialble options are {MATCHING_FNS}" + ) else: - return cls(clean_dir,noisy_dir, getattr(ProcessorFunctions,matching_function)) - - + return cls( + clean_dir, + noisy_dir, + getattr(ProcessorFunctions, matching_function), + ) def prepare_matching_dict(self): if self.matching_function is None: raise ValueError("Not a valid matching function") - return self.matching_function(self.clean_dir,self.noisy_dir) - - - - - - - - - - - - + return self.matching_function(self.clean_dir, self.noisy_dir) From aca4521ef27725df97ac055297ef8301d7c04d11 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Wed, 5 Oct 2022 15:19:17 +0530 Subject: [PATCH 10/38] refactor Audio --- enhancer/utils/io.py | 121 +++++++++++++++++++++++++++++-------------- 1 file changed, 83 insertions(+), 38 deletions(-) diff --git a/enhancer/utils/io.py b/enhancer/utils/io.py index afc19e8..3703285 100644 --- a/enhancer/utils/io.py +++ b/enhancer/utils/io.py @@ -1,41 +1,66 @@ import os import librosa -from typing import Optional +from pathlib import Path +from typing import Optional, Union import numpy as np import torch import torchaudio + class Audio: + """ + Audio utils + parameters: + sampling_rate : int, defaults to 16KHz + audio sampling rate + mono: bool, defaults to True + return_tensors: bool, defaults to True + returns torch tensor type if set to True else numpy ndarray + """ def __init__( - self, - sampling_rate:int=16000, - mono:bool=True, - return_tensor=True + self, sampling_rate: int = 16000, mono: bool = True, return_tensor=True ) -> None: - + self.sampling_rate = sampling_rate self.mono = mono self.return_tensor = return_tensor def __call__( self, - 
audio, - sampling_rate:Optional[int]=None, - offset:Optional[float] = None, - duration:Optional[float] = None + audio: Union[Path, np.ndarray, torch.Tensor], + sampling_rate: Optional[int] = None, + offset: Optional[float] = None, + duration: Optional[float] = None, ): - if isinstance(audio,str): + """ + read and process input audio + parameters: + audio: Path to audio file or numpy array or torch tensor + single input audio + sampling_rate : int, optional + sampling rate of the audio input + offset: float, optional + offset from which the audio must be read, reads from beginning if unused. + duration: float (seconds), optional + read duration, reads full audio starting from offset if not used + """ + if isinstance(audio, str): if os.path.exists(audio): - audio,sampling_rate = librosa.load(audio,sr=sampling_rate,mono=False, - offset=offset,duration=duration) + audio, sampling_rate = librosa.load( + audio, + sr=sampling_rate, + mono=False, + offset=offset, + duration=duration, + ) if len(audio.shape) == 1: - audio = audio.reshape(1,-1) + audio = audio.reshape(1, -1) else: raise FileNotFoundError(f"File {audio} deos not exist") - elif isinstance(audio,np.ndarray): + elif isinstance(audio, np.ndarray): if len(audio.shape) == 1: - audio = audio.reshape(1,-1) + audio = audio.reshape(1, -1) else: raise ValueError("audio should be either filepath or numpy ndarray") @@ -43,40 +68,60 @@ class Audio: audio = self.convert_mono(audio) if sampling_rate: - audio = self.__class__.resample_audio(audio,self.sampling_rate,sampling_rate) + audio = self.__class__.resample_audio( + audio, self.sampling_rate, sampling_rate + ) if self.return_tensor: return torch.tensor(audio) else: return audio @staticmethod - def convert_mono( - audio + def convert_mono(audio: Union[np.ndarray, torch.Tensor]): + """ + convert input audio into mono (1) + parameters: + audio: np.ndarray or torch.Tensor + """ + if len(audio.shape) > 2: + assert ( + audio.shape[0] == 1 + ), "convert mono only accepts 
single waveform" + audio = audio.reshape(audio.shape[1], audio.shape[2]) - ): - if len(audio.shape)>2: - assert audio.shape[0] == 1, "convert mono only accepts single waveform" - audio = audio.reshape(audio.shape[1],audio.shape[2]) - - assert audio.shape[1] >> audio.shape[0], f"expected input format (num_channels,num_samples) got {audio.shape}" - num_channels,num_samples = audio.shape - if num_channels>1: - return audio.mean(axis=0).reshape(1,num_samples) + assert ( + audio.shape[1] >> audio.shape[0] + ), f"expected input format (num_channels,num_samples) got {audio.shape}" + num_channels, num_samples = audio.shape + if num_channels > 1: + return audio.mean(axis=0).reshape(1, num_samples) return audio - @staticmethod def resample_audio( - audio, - sr:int, - target_sr:int + audio: Union[np.ndarray, torch.Tensor], sr: int, target_sr: int ): - if sr!=target_sr: - if isinstance(audio,np.ndarray): - audio = librosa.resample(audio,orig_sr=sr,target_sr=target_sr) - elif isinstance(audio,torch.Tensor): - audio = torchaudio.functional.resample(audio,orig_freq=sr,new_freq=target_sr) + """ + resample audio to desired sampling rate + parameters: + audio : Path to audio file or numpy array or torch tensor + audio waveform + sr : int + current sampling rate + target_sr : int + target sampling rate + + """ + if sr != target_sr: + if isinstance(audio, np.ndarray): + audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sr) + elif isinstance(audio, torch.Tensor): + audio = torchaudio.functional.resample( + audio, orig_freq=sr, new_freq=target_sr + ) else: - raise ValueError("Input should be either numpy array or torch tensor") + raise ValueError( + "Input should be either numpy array or torch tensor" + ) return audio From 6c4bced3607f23db76d911f9a3471de22db53fa1 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Wed, 5 Oct 2022 15:20:43 +0530 Subject: [PATCH 11/38] reformat utils --- enhancer/utils/__init__.py | 2 +- enhancer/utils/config.py | 11 +++++------ 
enhancer/utils/random.py | 29 ++++++++++++----------------- enhancer/utils/utils.py | 21 +++++++++++++-------- 4 files changed, 31 insertions(+), 32 deletions(-) diff --git a/enhancer/utils/__init__.py b/enhancer/utils/__init__.py index 3da7ede..c9f5438 100644 --- a/enhancer/utils/__init__.py +++ b/enhancer/utils/__init__.py @@ -1,3 +1,3 @@ from enhancer.utils.utils import check_files from enhancer.utils.io import Audio -from enhancer.utils.config import Files \ No newline at end of file +from enhancer.utils.config import Files diff --git a/enhancer/utils/config.py b/enhancer/utils/config.py index 1bbc51d..252e6c9 100644 --- a/enhancer/utils/config.py +++ b/enhancer/utils/config.py @@ -1,10 +1,9 @@ from dataclasses import dataclass + @dataclass class Files: - train_clean : str - train_noisy : str - test_clean : str - test_noisy : str - - + train_clean: str + train_noisy: str + test_clean: str + test_noisy: str diff --git a/enhancer/utils/random.py b/enhancer/utils/random.py index 3b1acac..51e09c0 100644 --- a/enhancer/utils/random.py +++ b/enhancer/utils/random.py @@ -3,17 +3,16 @@ import random import torch - -def create_unique_rng(epoch:int): +def create_unique_rng(epoch: int): """create unique random number generator for each (worker_id,epoch) combination""" rng = random.Random() - global_seed = int(os.environ.get("PL_GLOBAL_SEED","0")) - global_rank = int(os.environ.get('GLOBAL_RANK',"0")) - local_rank = int(os.environ.get('LOCAL_RANK',"0")) - node_rank = int(os.environ.get('NODE_RANK',"0")) - world_size = int(os.environ.get('WORLD_SIZE',"0")) + global_seed = int(os.environ.get("PL_GLOBAL_SEED", "0")) + global_rank = int(os.environ.get("GLOBAL_RANK", "0")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + node_rank = int(os.environ.get("NODE_RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "0")) worker_info = torch.utils.data.get_worker_info() if worker_info is not None: @@ -24,17 +23,13 @@ def create_unique_rng(epoch:int): worker_id = 0 seed 
= ( - global_seed - + worker_id - + local_rank * num_workers - + node_rank * num_workers * global_rank - + epoch * num_workers * world_size - ) + global_seed + + worker_id + + local_rank * num_workers + + node_rank * num_workers * global_rank + + epoch * num_workers * world_size + ) rng.seed(seed) return rng - - - - diff --git a/enhancer/utils/utils.py b/enhancer/utils/utils.py index be74dc2..73673ed 100644 --- a/enhancer/utils/utils.py +++ b/enhancer/utils/utils.py @@ -1,19 +1,24 @@ - import os from typing import Optional from enhancer.utils.config import Files -def check_files(root_dir:str, files:Files): - path_variables = [member_var for member_var in dir(files) if not member_var.startswith('__')] +def check_files(root_dir: str, files: Files): + + path_variables = [ + member_var + for member_var in dir(files) + if not member_var.startswith("__") + ] for variable in path_variables: - path = getattr(files,variable) - if not os.path.isdir(os.path.join(root_dir,path)): + path = getattr(files, variable) + if not os.path.isdir(os.path.join(root_dir, path)): raise ValueError(f"Invalid {path}, is not a directory") - - return files,root_dir -def merge_dict(default_dict:dict, custom:Optional[dict]=None): + return files, root_dir + + +def merge_dict(default_dict: dict, custom: Optional[dict] = None): params = dict(default_dict) if custom: params.update(custom) From bf37570d4a70a5951e47921edcc1af7139616a5e Mon Sep 17 00:00:00 2001 From: shahules786 Date: Wed, 5 Oct 2022 15:27:20 +0530 Subject: [PATCH 12/38] relative imports --- enhancer/__init__.py | 2 +- enhancer/data/__init__.py | 2 +- enhancer/models/__init__.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/enhancer/__init__.py b/enhancer/__init__.py index b3c06d4..f102a9c 100644 --- a/enhancer/__init__.py +++ b/enhancer/__init__.py @@ -1 +1 @@ -__version__ = "0.0.1" \ No newline at end of file +__version__ = "0.0.1" diff --git a/enhancer/data/__init__.py b/enhancer/data/__init__.py index 
3ec018e..7efd946 100644 --- a/enhancer/data/__init__.py +++ b/enhancer/data/__init__.py @@ -1 +1 @@ -from enhancer.data.dataset import EnhancerDataset \ No newline at end of file +from enhancer.data.dataset import EnhancerDataset diff --git a/enhancer/models/__init__.py b/enhancer/models/__init__.py index 534a608..368a9d7 100644 --- a/enhancer/models/__init__.py +++ b/enhancer/models/__init__.py @@ -1,3 +1,3 @@ from enhancer.models.demucs import Demucs from enhancer.models.waveunet import WaveUnet -from enhancer.models.model import Model \ No newline at end of file +from enhancer.models.model import Model From 670b70938acd5a357c473389a9798d7e8d3db362 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Wed, 5 Oct 2022 15:47:13 +0530 Subject: [PATCH 13/38] ignore __init__.py --- .flake8 | 1 + 1 file changed, 1 insertion(+) diff --git a/.flake8 b/.flake8 index ed37421..861f69a 100644 --- a/.flake8 +++ b/.flake8 @@ -1,4 +1,5 @@ [flake8] +per-file-ignores = __init__.py:F401 ignore = E203, E266, E501, W503 # line length is intentionally set to 80 here because black uses Bugbear # See https://github.com/psf/black/blob/master/README.md#line-length for more details From babe5776ce25eb9d532fa6add28fc4ffce3f6b22 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Wed, 5 Oct 2022 15:47:25 +0530 Subject: [PATCH 14/38] flake8 changes --- enhancer/data/fileprocessor.py | 1 - enhancer/loss.py | 11 ++++++----- enhancer/models/model.py | 9 +-------- enhancer/utils/utils.py | 1 + 4 files changed, 8 insertions(+), 14 deletions(-) diff --git a/enhancer/data/fileprocessor.py b/enhancer/data/fileprocessor.py index 106f649..5cc9b31 100644 --- a/enhancer/data/fileprocessor.py +++ b/enhancer/data/fileprocessor.py @@ -1,6 +1,5 @@ import glob import os -from re import S import numpy as np from scipy.io import wavfile diff --git a/enhancer/loss.py b/enhancer/loss.py index f2f62d3..db1d222 100644 --- a/enhancer/loss.py +++ b/enhancer/loss.py @@ -17,8 +17,8 @@ class mean_squared_error(nn.Module): if 
prediction.size() != target.size() or target.ndim < 3: raise TypeError( - f"""Inputs must be of the same shape (batch_size,channels,samples) - got {prediction.size()} and {target.size()} instead""" + f"""Inputs must be of the same shape (batch_size,channels,samples) + got {prediction.size()} and {target.size()} instead""" ) return self.loss_fun(prediction, target) @@ -39,7 +39,7 @@ class mean_absolute_error(nn.Module): if prediction.size() != target.size() or target.ndim < 3: raise TypeError( - f"""Inputs must be of the same shape (batch_size,channels,samples) + f"""Inputs must be of the same shape (batch_size,channels,samples) got {prediction.size()} and {target.size()} instead""" ) @@ -65,7 +65,7 @@ class Si_SDR(nn.Module): if prediction.size() != target.size() or target.ndim < 3: raise TypeError( - f"""Inputs must be of the same shape (batch_size,channels,samples) + f"""Inputs must be of the same shape (batch_size,channels,samples) got {prediction.size()} and {target.size()} instead""" ) @@ -119,7 +119,8 @@ class Avergeloss(nn.Module): def validate_loss(self, loss: str): if loss not in LOSS_MAP.keys(): raise ValueError( - f"Invalid loss function {loss}, available loss functions are {tuple([loss for loss in LOSS_MAP.keys()])}" + f"""Invalid loss function {loss}, available loss functions are + {tuple([loss for loss in LOSS_MAP.keys()])}""" ) else: return LOSS_MAP[loss] diff --git a/enhancer/models/model.py b/enhancer/models/model.py index 071bbb6..56f24db 100644 --- a/enhancer/models/model.py +++ b/enhancer/models/model.py @@ -1,16 +1,10 @@ -try: - from functools import cached_property -except ImportError: - from backports.cached_property import cached_property from importlib import import_module from huggingface_hub import cached_download, hf_hub_url -import logging import numpy as np import os from typing import Optional, Union, List, Text, Dict, Any from torch.optim import Adam import torch -from torch.nn.functional import pad import pytorch_lightning as pl 
from pytorch_lightning.utilities.cloud_io import load as pl_load from urllib.parse import urlparse @@ -19,7 +13,6 @@ from pathlib import Path from enhancer import __version__ from enhancer.data.dataset import EnhancerDataset -from enhancer.utils.io import Audio from enhancer.loss import Avergeloss from enhancer.inference import Inference @@ -300,7 +293,7 @@ class Model(pl.LightningModule): with torch.no_grad(): for batch_id in range(0, batch.shape[0], batch_size): - batch_data = batch[batch_id : batch_id + batch_size, :, :].to( + batch_data = batch[batch_id: batch_id + batch_size, :, :].to( self.device ) prediction = self(batch_data) diff --git a/enhancer/utils/utils.py b/enhancer/utils/utils.py index 73673ed..ebb41b4 100644 --- a/enhancer/utils/utils.py +++ b/enhancer/utils/utils.py @@ -19,6 +19,7 @@ def check_files(root_dir: str, files: Files): def merge_dict(default_dict: dict, custom: Optional[dict] = None): + params = dict(default_dict) if custom: params.update(custom) From 80d6795b61868c1d08a66f3b3f969fadff193483 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Wed, 5 Oct 2022 15:52:50 +0530 Subject: [PATCH 15/38] move cli to enhancer --- {cli => enhancer/cli}/train.py | 0 {cli => enhancer/cli}/train_config/config.yaml | 0 .../cli}/train_config/dataset/DNS-2020.yaml | 0 .../cli}/train_config/dataset/Vctk.yaml | 0 enhancer/cli/train_config/dataset/Vctk_local.yaml | 13 +++++++++++++ .../cli}/train_config/hyperparameters/default.yaml | 0 .../cli}/train_config/mlflow/experiment.yaml | 0 .../cli}/train_config/model/Demucs.yaml | 0 .../cli}/train_config/model/WaveUnet.yaml | 0 .../cli}/train_config/optimizer/Adam.yaml | 0 .../cli}/train_config/trainer/default.yaml | 0 .../cli}/train_config/trainer/fastrun_dev.yaml | 0 12 files changed, 13 insertions(+) rename {cli => enhancer/cli}/train.py (100%) rename {cli => enhancer/cli}/train_config/config.yaml (100%) rename {cli => enhancer/cli}/train_config/dataset/DNS-2020.yaml (100%) rename {cli => 
enhancer/cli}/train_config/dataset/Vctk.yaml (100%) create mode 100644 enhancer/cli/train_config/dataset/Vctk_local.yaml rename {cli => enhancer/cli}/train_config/hyperparameters/default.yaml (100%) rename {cli => enhancer/cli}/train_config/mlflow/experiment.yaml (100%) rename {cli => enhancer/cli}/train_config/model/Demucs.yaml (100%) rename {cli => enhancer/cli}/train_config/model/WaveUnet.yaml (100%) rename {cli => enhancer/cli}/train_config/optimizer/Adam.yaml (100%) rename {cli => enhancer/cli}/train_config/trainer/default.yaml (100%) rename {cli => enhancer/cli}/train_config/trainer/fastrun_dev.yaml (100%) diff --git a/cli/train.py b/enhancer/cli/train.py similarity index 100% rename from cli/train.py rename to enhancer/cli/train.py diff --git a/cli/train_config/config.yaml b/enhancer/cli/train_config/config.yaml similarity index 100% rename from cli/train_config/config.yaml rename to enhancer/cli/train_config/config.yaml diff --git a/cli/train_config/dataset/DNS-2020.yaml b/enhancer/cli/train_config/dataset/DNS-2020.yaml similarity index 100% rename from cli/train_config/dataset/DNS-2020.yaml rename to enhancer/cli/train_config/dataset/DNS-2020.yaml diff --git a/cli/train_config/dataset/Vctk.yaml b/enhancer/cli/train_config/dataset/Vctk.yaml similarity index 100% rename from cli/train_config/dataset/Vctk.yaml rename to enhancer/cli/train_config/dataset/Vctk.yaml diff --git a/enhancer/cli/train_config/dataset/Vctk_local.yaml b/enhancer/cli/train_config/dataset/Vctk_local.yaml new file mode 100644 index 0000000..b792b71 --- /dev/null +++ b/enhancer/cli/train_config/dataset/Vctk_local.yaml @@ -0,0 +1,13 @@ +_target_: enhancer.data.dataset.EnhancerDataset +name : vctk +root_dir : /Users/shahules/Myprojects/enhancer/datasets/vctk +duration : 1.0 +sampling_rate: 16000 +batch_size: 64 +num_workers : 0 + +files: + train_clean : clean_testset_wav + test_clean : clean_testset_wav + train_noisy : noisy_testset_wav + test_noisy : noisy_testset_wav \ No newline at end of 
file diff --git a/cli/train_config/hyperparameters/default.yaml b/enhancer/cli/train_config/hyperparameters/default.yaml similarity index 100% rename from cli/train_config/hyperparameters/default.yaml rename to enhancer/cli/train_config/hyperparameters/default.yaml diff --git a/cli/train_config/mlflow/experiment.yaml b/enhancer/cli/train_config/mlflow/experiment.yaml similarity index 100% rename from cli/train_config/mlflow/experiment.yaml rename to enhancer/cli/train_config/mlflow/experiment.yaml diff --git a/cli/train_config/model/Demucs.yaml b/enhancer/cli/train_config/model/Demucs.yaml similarity index 100% rename from cli/train_config/model/Demucs.yaml rename to enhancer/cli/train_config/model/Demucs.yaml diff --git a/cli/train_config/model/WaveUnet.yaml b/enhancer/cli/train_config/model/WaveUnet.yaml similarity index 100% rename from cli/train_config/model/WaveUnet.yaml rename to enhancer/cli/train_config/model/WaveUnet.yaml diff --git a/cli/train_config/optimizer/Adam.yaml b/enhancer/cli/train_config/optimizer/Adam.yaml similarity index 100% rename from cli/train_config/optimizer/Adam.yaml rename to enhancer/cli/train_config/optimizer/Adam.yaml diff --git a/cli/train_config/trainer/default.yaml b/enhancer/cli/train_config/trainer/default.yaml similarity index 100% rename from cli/train_config/trainer/default.yaml rename to enhancer/cli/train_config/trainer/default.yaml diff --git a/cli/train_config/trainer/fastrun_dev.yaml b/enhancer/cli/train_config/trainer/fastrun_dev.yaml similarity index 100% rename from cli/train_config/trainer/fastrun_dev.yaml rename to enhancer/cli/train_config/trainer/fastrun_dev.yaml From 8ac01b846de688d4e060a3c52ed5106e8a8e63bd Mon Sep 17 00:00:00 2001 From: shahules786 Date: Wed, 5 Oct 2022 15:54:10 +0530 Subject: [PATCH 16/38] black --- enhancer/cli/train.py | 79 ++++++++++++++++++++++++++----------------- 1 file changed, 48 insertions(+), 31 deletions(-) diff --git a/enhancer/cli/train.py b/enhancer/cli/train.py index 
dee3d2e..814fa0f 100644 --- a/enhancer/cli/train.py +++ b/enhancer/cli/train.py @@ -1,4 +1,3 @@ -from genericpath import isfile import os from types import MethodType import hydra @@ -7,61 +6,79 @@ from omegaconf import DictConfig from torch.optim.lr_scheduler import ReduceLROnPlateau from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping from pytorch_lightning.loggers import MLFlowLogger -os.environ["HYDRA_FULL_ERROR"] = "1" -JOB_ID = os.environ.get("SLURM_JOBID","0") -@hydra.main(config_path="train_config",config_name="config") +os.environ["HYDRA_FULL_ERROR"] = "1" +JOB_ID = os.environ.get("SLURM_JOBID", "0") + + +@hydra.main(config_path="train_config", config_name="config") def main(config: DictConfig): callbacks = [] - logger = MLFlowLogger(experiment_name=config.mlflow.experiment_name, - run_name=config.mlflow.run_name, tags={"JOB_ID":JOB_ID}) - + logger = MLFlowLogger( + experiment_name=config.mlflow.experiment_name, + run_name=config.mlflow.run_name, + tags={"JOB_ID": JOB_ID}, + ) parameters = config.hyperparameters dataset = instantiate(config.dataset) - model = instantiate(config.model,dataset=dataset,lr=parameters.get("lr"), - loss=parameters.get("loss"), metric = parameters.get("metric")) + model = instantiate( + config.model, + dataset=dataset, + lr=parameters.get("lr"), + loss=parameters.get("loss"), + metric=parameters.get("metric"), + ) direction = model.valid_monitor checkpoint = ModelCheckpoint( - dirpath="./model",filename=f"model_{JOB_ID}",monitor="val_loss",verbose=True, - mode=direction,every_n_epochs=1 + dirpath="./model", + filename=f"model_{JOB_ID}", + monitor="val_loss", + verbose=True, + mode=direction, + every_n_epochs=1, ) callbacks.append(checkpoint) early_stopping = EarlyStopping( - monitor="val_loss", - mode=direction, - min_delta=0.0, - patience=parameters.get("EarlyStopping_patience",10), - strict=True, - verbose=False, - ) + monitor="val_loss", + mode=direction, + min_delta=0.0, + 
patience=parameters.get("EarlyStopping_patience", 10), + strict=True, + verbose=False, + ) callbacks.append(early_stopping) - + def configure_optimizer(self): - optimizer = instantiate(config.optimizer,lr=parameters.get("lr"),parameters=self.parameters()) + optimizer = instantiate( + config.optimizer, + lr=parameters.get("lr"), + parameters=self.parameters(), + ) scheduler = ReduceLROnPlateau( optimizer=optimizer, mode=direction, - factor=parameters.get("ReduceLr_factor",0.1), + factor=parameters.get("ReduceLr_factor", 0.1), verbose=True, - min_lr=parameters.get("min_lr",1e-6), - patience=parameters.get("ReduceLr_patience",3) + min_lr=parameters.get("min_lr", 1e-6), + patience=parameters.get("ReduceLr_patience", 3), ) - return {"optimizer":optimizer, "lr_scheduler":scheduler} + return {"optimizer": optimizer, "lr_scheduler": scheduler} - model.configure_parameters = MethodType(configure_optimizer,model) + model.configure_parameters = MethodType(configure_optimizer, model) - trainer = instantiate(config.trainer,logger=logger,callbacks=callbacks) + trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks) trainer.fit(model) - saved_location = os.path.join(trainer.default_root_dir,"model",f"model_{JOB_ID}.ckpt") + saved_location = os.path.join( + trainer.default_root_dir, "model", f"model_{JOB_ID}.ckpt" + ) if os.path.isfile(saved_location): - logger.experiment.log_artifact(logger.run_id,saved_location) + logger.experiment.log_artifact(logger.run_id, saved_location) - -if __name__=="__main__": - main() \ No newline at end of file +if __name__ == "__main__": + main() From 8a07cb8712a437faeff8bd7475d33e7c445fbb28 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Wed, 5 Oct 2022 15:54:54 +0530 Subject: [PATCH 17/38] tests --- tests/loss_function_test.py | 20 +++++++++-------- tests/models/demucs_test.py | 36 ++++++++++++++---------------- tests/models/test_waveunet.py | 36 ++++++++++++++---------------- tests/test_inference.py | 24 +++++++++++--------- 
tests/utils_test.py | 42 ++++++++++++++++++++--------------- 5 files changed, 83 insertions(+), 75 deletions(-) diff --git a/tests/loss_function_test.py b/tests/loss_function_test.py index fbc982c..a4fdc62 100644 --- a/tests/loss_function_test.py +++ b/tests/loss_function_test.py @@ -6,26 +6,28 @@ from enhancer.loss import mean_absolute_error, mean_squared_error loss_functions = [mean_absolute_error(), mean_squared_error()] + def check_loss_shapes_compatibility(loss_fun): batch_size = 4 - shape = (1,1000) - loss_fun(torch.rand(batch_size,*shape),torch.rand(batch_size,*shape)) + shape = (1, 1000) + loss_fun(torch.rand(batch_size, *shape), torch.rand(batch_size, *shape)) with pytest.raises(TypeError): - loss_fun(torch.rand(4,*shape),torch.rand(6,*shape)) + loss_fun(torch.rand(4, *shape), torch.rand(6, *shape)) -@pytest.mark.parametrize("loss",loss_functions) +@pytest.mark.parametrize("loss", loss_functions) def test_loss_input_shapes(loss): check_loss_shapes_compatibility(loss) -@pytest.mark.parametrize("loss",loss_functions) + +@pytest.mark.parametrize("loss", loss_functions) def test_loss_output_type(loss): batch_size = 4 - prediction, target = torch.rand(batch_size,1,1000),torch.rand(batch_size,1,1000) + prediction, target = torch.rand(batch_size, 1, 1000), torch.rand( + batch_size, 1, 1000 + ) loss_value = loss(prediction, target) - assert isinstance(loss_value.item(),float) - - + assert isinstance(loss_value.item(), float) diff --git a/tests/models/demucs_test.py b/tests/models/demucs_test.py index a59fa04..6660888 100644 --- a/tests/models/demucs_test.py +++ b/tests/models/demucs_test.py @@ -10,37 +10,35 @@ from enhancer.data.dataset import EnhancerDataset @pytest.fixture def vctk_dataset(): root_dir = "tests/data/vctk" - files = Files(train_clean="clean_testset_wav",train_noisy="noisy_testset_wav", - test_clean="clean_testset_wav", test_noisy="noisy_testset_wav") - dataset = EnhancerDataset(name="vctk",root_dir=root_dir,files=files) + files = Files( + 
train_clean="clean_testset_wav", + train_noisy="noisy_testset_wav", + test_clean="clean_testset_wav", + test_noisy="noisy_testset_wav", + ) + dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files) return dataset - -@pytest.mark.parametrize("batch_size,samples",[(1,1000)]) -def test_forward(batch_size,samples): +@pytest.mark.parametrize("batch_size,samples", [(1, 1000)]) +def test_forward(batch_size, samples): model = Demucs() model.eval() - data = torch.rand(batch_size,1,samples,requires_grad=False) + data = torch.rand(batch_size, 1, samples, requires_grad=False) with torch.no_grad(): _ = model(data) - data = torch.rand(batch_size,2,samples,requires_grad=False) + data = torch.rand(batch_size, 2, samples, requires_grad=False) with torch.no_grad(): with pytest.raises(TypeError): _ = model(data) -@pytest.mark.parametrize("dataset,channels,loss", - [(pytest.lazy_fixture("vctk_dataset"),1,["mae","mse"])]) -def test_demucs_init(dataset,channels,loss): +@pytest.mark.parametrize( + "dataset,channels,loss", + [(pytest.lazy_fixture("vctk_dataset"), 1, ["mae", "mse"])], +) +def test_demucs_init(dataset, channels, loss): with torch.no_grad(): - model = Demucs(num_channels=channels,dataset=dataset,loss=loss) - - - - - - - + model = Demucs(num_channels=channels, dataset=dataset, loss=loss) diff --git a/tests/models/test_waveunet.py b/tests/models/test_waveunet.py index 43fd14d..c83966b 100644 --- a/tests/models/test_waveunet.py +++ b/tests/models/test_waveunet.py @@ -10,37 +10,35 @@ from enhancer.data.dataset import EnhancerDataset @pytest.fixture def vctk_dataset(): root_dir = "tests/data/vctk" - files = Files(train_clean="clean_testset_wav",train_noisy="noisy_testset_wav", - test_clean="clean_testset_wav", test_noisy="noisy_testset_wav") - dataset = EnhancerDataset(name="vctk",root_dir=root_dir,files=files) + files = Files( + train_clean="clean_testset_wav", + train_noisy="noisy_testset_wav", + test_clean="clean_testset_wav", + test_noisy="noisy_testset_wav", + 
) + dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files) return dataset - -@pytest.mark.parametrize("batch_size,samples",[(1,1000)]) -def test_forward(batch_size,samples): +@pytest.mark.parametrize("batch_size,samples", [(1, 1000)]) +def test_forward(batch_size, samples): model = WaveUnet() model.eval() - data = torch.rand(batch_size,1,samples,requires_grad=False) + data = torch.rand(batch_size, 1, samples, requires_grad=False) with torch.no_grad(): _ = model(data) - data = torch.rand(batch_size,2,samples,requires_grad=False) + data = torch.rand(batch_size, 2, samples, requires_grad=False) with torch.no_grad(): with pytest.raises(TypeError): _ = model(data) -@pytest.mark.parametrize("dataset,channels,loss", - [(pytest.lazy_fixture("vctk_dataset"),1,["mae","mse"])]) -def test_demucs_init(dataset,channels,loss): +@pytest.mark.parametrize( + "dataset,channels,loss", + [(pytest.lazy_fixture("vctk_dataset"), 1, ["mae", "mse"])], +) +def test_demucs_init(dataset, channels, loss): with torch.no_grad(): - model = WaveUnet(num_channels=channels,dataset=dataset,loss=loss) - - - - - - - + model = WaveUnet(num_channels=channels, dataset=dataset, loss=loss) diff --git a/tests/test_inference.py b/tests/test_inference.py index 5eb7442..a6e2423 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -4,22 +4,26 @@ import torch from enhancer.inference import Inference -@pytest.mark.parametrize("audio",["tests/data/vctk/clean_testset_wav/p257_166.wav",torch.rand(1,2,48000)]) +@pytest.mark.parametrize( + "audio", + ["tests/data/vctk/clean_testset_wav/p257_166.wav", torch.rand(1, 2, 48000)], +) def test_read_input(audio): - read_audio = Inference.read_input(audio,48000,16000) - assert isinstance(read_audio,torch.Tensor) + read_audio = Inference.read_input(audio, 48000, 16000) + assert isinstance(read_audio, torch.Tensor) assert read_audio.shape[0] == 1 + def test_batchify(): - rand = torch.rand(1,1000) - batched_rand = Inference.batchify(rand, 
window_size = 100, step_size=100) + rand = torch.rand(1, 1000) + batched_rand = Inference.batchify(rand, window_size=100, step_size=100) assert batched_rand.shape[0] == 12 + def test_aggregate(): - rand = torch.rand(12,1,100) - agg_rand = Inference.aggreagate(data=rand,window_size=100,total_frames=1000,step_size=100) + rand = torch.rand(12, 1, 100) + agg_rand = Inference.aggreagate( + data=rand, window_size=100, total_frames=1000, step_size=100 + ) assert agg_rand.shape[-1] == 1000 - - - \ No newline at end of file diff --git a/tests/utils_test.py b/tests/utils_test.py index 413bfac..93a9094 100644 --- a/tests/utils_test.py +++ b/tests/utils_test.py @@ -7,40 +7,46 @@ from enhancer.utils.io import Audio from enhancer.utils.config import Files from enhancer.data.fileprocessor import Fileprocessor + def test_io_channel(): - input_audio = np.random.rand(2,32000) - audio = Audio(mono=True,return_tensor=False) + input_audio = np.random.rand(2, 32000) + audio = Audio(mono=True, return_tensor=False) output_audio = audio(input_audio) assert output_audio.shape[0] == 1 + def test_io_resampling(): - input_audio = np.random.rand(1,32000) - resampled_audio = Audio.resample_audio(input_audio,16000,8000) + input_audio = np.random.rand(1, 32000) + resampled_audio = Audio.resample_audio(input_audio, 16000, 8000) - input_audio = torch.rand(1,32000) - resampled_audio_pt = Audio.resample_audio(input_audio,16000,8000) + input_audio = torch.rand(1, 32000) + resampled_audio_pt = Audio.resample_audio(input_audio, 16000, 8000) assert resampled_audio.shape[1] == resampled_audio_pt.size(1) == 16000 + def test_fileprocessor_vctk(): - fp = Fileprocessor.from_name("vctk","tests/data/vctk/clean_testset_wav", - "tests/data/vctk/noisy_testset_wav",48000) + fp = Fileprocessor.from_name( + "vctk", + "tests/data/vctk/clean_testset_wav", + "tests/data/vctk/noisy_testset_wav", + 48000, + ) matching_dict = fp.prepare_matching_dict() - assert len(matching_dict)==2 + assert len(matching_dict) == 2 
-@pytest.mark.parametrize("dataset_name",["vctk","dns-2020"]) + +@pytest.mark.parametrize("dataset_name", ["vctk", "dns-2020"]) def test_fileprocessor_names(dataset_name): - fp = Fileprocessor.from_name(dataset_name,"clean_dir","noisy_dir",16000) - assert hasattr(fp.matching_function, '__call__') + fp = Fileprocessor.from_name(dataset_name, "clean_dir", "noisy_dir", 16000) + assert hasattr(fp.matching_function, "__call__") + def test_fileprocessor_invaliname(): with pytest.raises(ValueError): - fp = Fileprocessor.from_name("undefined","clean_dir","noisy_dir",16000).prepare_matching_dict() - - - - - + fp = Fileprocessor.from_name( + "undefined", "clean_dir", "noisy_dir", 16000 + ).prepare_matching_dict() From 24f4c25a1b5cbc9fa92b0f532449cb4ff28613bb Mon Sep 17 00:00:00 2001 From: shahules786 Date: Wed, 5 Oct 2022 17:00:41 +0530 Subject: [PATCH 18/38] requirements --- requirements.txt | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/requirements.txt b/requirements.txt index e7fcd24..8373578 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,15 +1,14 @@ -joblib==1.1.0 -numpy==1.19.5 -librosa==0.9.1 -numpy==1.19.5 +joblib==1.2.0 +librosa==0.9.2 +numpy==1.23.3 hydra-core==1.2.0 -scikit-learn==0.24.2 -scipy==1.5.4 -torch==1.10.2 -tqdm==4.64.0 -mlflow==1.23.1 -protobuf==3.19.3 -boto3==1.23.9 -torchaudio==0.10.2 -huggingface-hub==0.4.0 -pytorch-lightning==1.5.10 +scikit-learn==1.1.2 +scipy==1.9.1 +torch==1.12.1 +tqdm==4.64.1 +mlflow==1.29.0 +protobuf==3.19.6 +boto3==1.24.86 +torchaudio==0.12.1 +huggingface-hu==0.10.0 +pytorch-lightning==1.7.7 \ No newline at end of file From 83fe8a29c008dad0ef2295e83996b286a7c2aa4b Mon Sep 17 00:00:00 2001 From: shahules786 Date: Wed, 5 Oct 2022 17:16:26 +0530 Subject: [PATCH 19/38] add flake --- requirements.txt | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 8373578..b2f3638 100644 --- a/requirements.txt +++ 
b/requirements.txt @@ -11,4 +11,6 @@ protobuf==3.19.6 boto3==1.24.86 torchaudio==0.12.1 huggingface-hu==0.10.0 -pytorch-lightning==1.7.7 \ No newline at end of file +pytorch-lightning==1.7.7 +flake8==5.0.4 +black==22.8.0 \ No newline at end of file From b53f9d5f9ef31010f909f58df9c76f8a421d6ce7 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Wed, 5 Oct 2022 17:40:49 +0530 Subject: [PATCH 20/38] pre-commit --- .pre-commit-config.yaml | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 .pre-commit-config.yaml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..5721482 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,39 @@ + +repos: + # # Clean Notebooks + # - repo: https://github.com/kynan/nbstripout + # rev: master + # hooks: + # - id: nbstripout + # Format Code + - repo: https://github.com/ambv/black + rev: 22.3.0 + hooks: + - id: black + + # Sort imports + - repo: https://github.com/PyCQA/isort + rev: 5.10.1 + hooks: + - id: isort + args: ["--profile", "black"] + + # Formatting, Whitespace, etc + - repo: git://github.com/pre-commit/pre-commit-hooks + rev: v2.2.3 + hooks: + - id: trailing-whitespace + - id: check-added-large-files + args: ['--maxkb=1000'] + - id: check-ast + - id: check-json + - id: check-merge-conflict + - id: check-xml + - id: check-yaml + - id: debug-statements + - id: end-of-file-fixer + - id: requirements-txt-fixer + - id: mixed-line-ending + args: ['--fix=no'] + - id: flake8 + args: ['--ignore=E203,E501,F811,E712,W503'] From 9adb915447d7467fb79cef9273268a5cd7bd3a05 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Wed, 5 Oct 2022 17:47:40 +0530 Subject: [PATCH 21/38] pre-commit conf --- .pre-commit-config.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5721482..8eac0a2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,7 +7,7 @@ repos: # - id: 
nbstripout # Format Code - repo: https://github.com/ambv/black - rev: 22.3.0 + rev: 22.8.0 hooks: - id: black @@ -20,7 +20,7 @@ repos: # Formatting, Whitespace, etc - repo: git://github.com/pre-commit/pre-commit-hooks - rev: v2.2.3 + rev: v2.20.0 hooks: - id: trailing-whitespace - id: check-added-large-files From d20b7a166fdfcaee6ab0fcbd6c58c5d66401f8a4 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Wed, 5 Oct 2022 20:35:17 +0530 Subject: [PATCH 22/38] tests --- tests/loss_function_test.py | 3 ++- tests/models/demucs_test.py | 6 +++--- tests/models/test_waveunet.py | 6 +++--- tests/utils_test.py | 7 ++++--- 4 files changed, 12 insertions(+), 10 deletions(-) diff --git a/tests/loss_function_test.py b/tests/loss_function_test.py index a4fdc62..cd60177 100644 --- a/tests/loss_function_test.py +++ b/tests/loss_function_test.py @@ -1,6 +1,7 @@ from asyncio import base_tasks -import torch + import pytest +import torch from enhancer.loss import mean_absolute_error, mean_squared_error diff --git a/tests/models/demucs_test.py b/tests/models/demucs_test.py index 6660888..1ea50c5 100644 --- a/tests/models/demucs_test.py +++ b/tests/models/demucs_test.py @@ -1,10 +1,10 @@ import pytest import torch -from enhancer import data -from enhancer.utils.config import Files -from enhancer.models import Demucs +from enhancer import data from enhancer.data.dataset import EnhancerDataset +from enhancer.models import Demucs +from enhancer.utils.config import Files @pytest.fixture diff --git a/tests/models/test_waveunet.py b/tests/models/test_waveunet.py index c83966b..798ed5d 100644 --- a/tests/models/test_waveunet.py +++ b/tests/models/test_waveunet.py @@ -1,10 +1,10 @@ import pytest import torch -from enhancer import data -from enhancer.utils.config import Files -from enhancer.models import WaveUnet +from enhancer import data from enhancer.data.dataset import EnhancerDataset +from enhancer.models import WaveUnet +from enhancer.utils.config import Files @pytest.fixture diff --git 
a/tests/utils_test.py b/tests/utils_test.py index 93a9094..1cc171a 100644 --- a/tests/utils_test.py +++ b/tests/utils_test.py @@ -1,11 +1,12 @@ from logging import root + +import numpy as np import pytest import torch -import numpy as np -from enhancer.utils.io import Audio -from enhancer.utils.config import Files from enhancer.data.fileprocessor import Fileprocessor +from enhancer.utils.config import Files +from enhancer.utils.io import Audio def test_io_channel(): From e5d9eb7e95737565066ddc3ee3ff7c73c6e36d88 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Wed, 5 Oct 2022 20:35:33 +0530 Subject: [PATCH 23/38] models --- enhancer/models/__init__.py | 2 +- enhancer/models/demucs.py | 9 +++++---- enhancer/models/model.py | 22 +++++++++++----------- enhancer/models/waveunet.py | 5 +++-- 4 files changed, 20 insertions(+), 18 deletions(-) diff --git a/enhancer/models/__init__.py b/enhancer/models/__init__.py index 368a9d7..2d97568 100644 --- a/enhancer/models/__init__.py +++ b/enhancer/models/__init__.py @@ -1,3 +1,3 @@ from enhancer.models.demucs import Demucs -from enhancer.models.waveunet import WaveUnet from enhancer.models.model import Model +from enhancer.models.waveunet import WaveUnet diff --git a/enhancer/models/demucs.py b/enhancer/models/demucs.py index 76a0bf7..65f119d 100644 --- a/enhancer/models/demucs.py +++ b/enhancer/models/demucs.py @@ -1,11 +1,12 @@ import logging -from typing import Optional, Union, List -from torch import nn -import torch.nn.functional as F import math +from typing import List, Optional, Union + +import torch.nn.functional as F +from torch import nn -from enhancer.models.model import Model from enhancer.data.dataset import EnhancerDataset +from enhancer.models.model import Model from enhancer.utils.io import Audio as audio from enhancer.utils.utils import merge_dict diff --git a/enhancer/models/model.py b/enhancer/models/model.py index 56f24db..39dbe80 100644 --- a/enhancer/models/model.py +++ b/enhancer/models/model.py @@ -1,20 
+1,20 @@ -from importlib import import_module -from huggingface_hub import cached_download, hf_hub_url -import numpy as np import os -from typing import Optional, Union, List, Text, Dict, Any -from torch.optim import Adam -import torch -import pytorch_lightning as pl -from pytorch_lightning.utilities.cloud_io import load as pl_load -from urllib.parse import urlparse +from importlib import import_module from pathlib import Path +from typing import Any, Dict, List, Optional, Text, Union +from urllib.parse import urlparse +import numpy as np +import pytorch_lightning as pl +import torch +from huggingface_hub import cached_download, hf_hub_url +from pytorch_lightning.utilities.cloud_io import load as pl_load +from torch.optim import Adam from enhancer import __version__ from enhancer.data.dataset import EnhancerDataset -from enhancer.loss import Avergeloss from enhancer.inference import Inference +from enhancer.loss import Avergeloss CACHE_DIR = "" HF_TORCH_WEIGHTS = "" @@ -293,7 +293,7 @@ class Model(pl.LightningModule): with torch.no_grad(): for batch_id in range(0, batch.shape[0], batch_size): - batch_data = batch[batch_id: batch_id + batch_size, :, :].to( + batch_data = batch[batch_id : batch_id + batch_size, :, :].to( self.device ) prediction = self(batch_data) diff --git a/enhancer/models/waveunet.py b/enhancer/models/waveunet.py index 4d5cc0a..ebb4b1f 100644 --- a/enhancer/models/waveunet.py +++ b/enhancer/models/waveunet.py @@ -1,11 +1,12 @@ import logging +from typing import List, Optional, Union + import torch import torch.nn as nn import torch.nn.functional as F -from typing import Optional, Union, List -from enhancer.models.model import Model from enhancer.data.dataset import EnhancerDataset +from enhancer.models.model import Model class WavenetDecoder(nn.Module): From 64d61e25a422d0a81028e9a397cd4559c3154a89 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Wed, 5 Oct 2022 20:35:53 +0530 Subject: [PATCH 24/38] cli --- enhancer/cli/train.py | 5 +++-- 
enhancer/cli/train_config/config.yaml | 2 +- enhancer/cli/train_config/dataset/DNS-2020.yaml | 1 - enhancer/cli/train_config/dataset/Vctk.yaml | 3 --- enhancer/cli/train_config/dataset/Vctk_local.yaml | 2 +- enhancer/cli/train_config/hyperparameters/default.yaml | 1 - enhancer/cli/train_config/mlflow/experiment.yaml | 2 +- enhancer/cli/train_config/model/Demucs.yaml | 2 -- 8 files changed, 6 insertions(+), 12 deletions(-) diff --git a/enhancer/cli/train.py b/enhancer/cli/train.py index 814fa0f..cb3c7c1 100644 --- a/enhancer/cli/train.py +++ b/enhancer/cli/train.py @@ -1,11 +1,12 @@ import os from types import MethodType + import hydra from hydra.utils import instantiate from omegaconf import DictConfig -from torch.optim.lr_scheduler import ReduceLROnPlateau -from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping +from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint from pytorch_lightning.loggers import MLFlowLogger +from torch.optim.lr_scheduler import ReduceLROnPlateau os.environ["HYDRA_FULL_ERROR"] = "1" JOB_ID = os.environ.get("SLURM_JOBID", "0") diff --git a/enhancer/cli/train_config/config.yaml b/enhancer/cli/train_config/config.yaml index 61551bd..c0b2cf6 100644 --- a/enhancer/cli/train_config/config.yaml +++ b/enhancer/cli/train_config/config.yaml @@ -4,4 +4,4 @@ defaults: - optimizer : Adam - hyperparameters : default - trainer : default - - mlflow : experiment \ No newline at end of file + - mlflow : experiment diff --git a/enhancer/cli/train_config/dataset/DNS-2020.yaml b/enhancer/cli/train_config/dataset/DNS-2020.yaml index f59cb2b..3bd0e67 100644 --- a/enhancer/cli/train_config/dataset/DNS-2020.yaml +++ b/enhancer/cli/train_config/dataset/DNS-2020.yaml @@ -10,4 +10,3 @@ files: test_clean : clean_test_wav train_noisy : clean_test_wav test_noisy : clean_test_wav - diff --git a/enhancer/cli/train_config/dataset/Vctk.yaml b/enhancer/cli/train_config/dataset/Vctk.yaml index 129d9a8..5c19320 100644 --- 
a/enhancer/cli/train_config/dataset/Vctk.yaml +++ b/enhancer/cli/train_config/dataset/Vctk.yaml @@ -10,6 +10,3 @@ files: test_clean : clean_testset_wav train_noisy : noisy_trainset_28spk_wav test_noisy : noisy_testset_wav - - - diff --git a/enhancer/cli/train_config/dataset/Vctk_local.yaml b/enhancer/cli/train_config/dataset/Vctk_local.yaml index b792b71..ba44597 100644 --- a/enhancer/cli/train_config/dataset/Vctk_local.yaml +++ b/enhancer/cli/train_config/dataset/Vctk_local.yaml @@ -10,4 +10,4 @@ files: train_clean : clean_testset_wav test_clean : clean_testset_wav train_noisy : noisy_testset_wav - test_noisy : noisy_testset_wav \ No newline at end of file + test_noisy : noisy_testset_wav diff --git a/enhancer/cli/train_config/hyperparameters/default.yaml b/enhancer/cli/train_config/hyperparameters/default.yaml index 82ac3c2..7e4cda3 100644 --- a/enhancer/cli/train_config/hyperparameters/default.yaml +++ b/enhancer/cli/train_config/hyperparameters/default.yaml @@ -5,4 +5,3 @@ ReduceLr_patience : 5 ReduceLr_factor : 0.1 min_lr : 0.000001 EarlyStopping_factor : 10 - diff --git a/enhancer/cli/train_config/mlflow/experiment.yaml b/enhancer/cli/train_config/mlflow/experiment.yaml index 2995c60..e8893f6 100644 --- a/enhancer/cli/train_config/mlflow/experiment.yaml +++ b/enhancer/cli/train_config/mlflow/experiment.yaml @@ -1,2 +1,2 @@ experiment_name : shahules/enhancer -run_name : baseline \ No newline at end of file +run_name : baseline diff --git a/enhancer/cli/train_config/model/Demucs.yaml b/enhancer/cli/train_config/model/Demucs.yaml index 1006e71..d91b5ff 100644 --- a/enhancer/cli/train_config/model/Demucs.yaml +++ b/enhancer/cli/train_config/model/Demucs.yaml @@ -14,5 +14,3 @@ encoder_decoder: lstm: bidirectional: False num_layers: 2 - - From 761e6492bbcccb1e8ec5586a913abbf9bf68a348 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Wed, 5 Oct 2022 20:36:12 +0530 Subject: [PATCH 25/38] utils --- enhancer/utils/__init__.py | 4 ++-- enhancer/utils/io.py | 3 ++- 
enhancer/utils/random.py | 1 + enhancer/utils/utils.py | 1 + 4 files changed, 6 insertions(+), 3 deletions(-) diff --git a/enhancer/utils/__init__.py b/enhancer/utils/__init__.py index c9f5438..de0db9f 100644 --- a/enhancer/utils/__init__.py +++ b/enhancer/utils/__init__.py @@ -1,3 +1,3 @@ -from enhancer.utils.utils import check_files -from enhancer.utils.io import Audio from enhancer.utils.config import Files +from enhancer.utils.io import Audio +from enhancer.utils.utils import check_files diff --git a/enhancer/utils/io.py b/enhancer/utils/io.py index 3703285..9e9ce32 100644 --- a/enhancer/utils/io.py +++ b/enhancer/utils/io.py @@ -1,7 +1,8 @@ import os -import librosa from pathlib import Path from typing import Optional, Union + +import librosa import numpy as np import torch import torchaudio diff --git a/enhancer/utils/random.py b/enhancer/utils/random.py index 51e09c0..dd9395a 100644 --- a/enhancer/utils/random.py +++ b/enhancer/utils/random.py @@ -1,5 +1,6 @@ import os import random + import torch diff --git a/enhancer/utils/utils.py b/enhancer/utils/utils.py index ebb41b4..ad45139 100644 --- a/enhancer/utils/utils.py +++ b/enhancer/utils/utils.py @@ -1,5 +1,6 @@ import os from typing import Optional + from enhancer.utils.config import Files From 459c927f0b2ca57d5778b5e80cee0cf42ff7cf55 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Wed, 5 Oct 2022 20:36:24 +0530 Subject: [PATCH 26/38] inference --- enhancer/inference.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/enhancer/inference.py b/enhancer/inference.py index 1abd8bb..ae399f1 100644 --- a/enhancer/inference.py +++ b/enhancer/inference.py @@ -1,11 +1,12 @@ -import numpy as np -from scipy.signal import get_window -from scipy.io import wavfile +from pathlib import Path from typing import Optional, Union + +import numpy as np import torch import torch.nn.functional as F -from pathlib import Path from librosa import load as load_audio +from scipy.io import wavfile +from 
scipy.signal import get_window from enhancer.utils import Audio From d0f8ad2f37bacaac294729bfa619a24a07456d8e Mon Sep 17 00:00:00 2001 From: shahules786 Date: Wed, 5 Oct 2022 20:36:50 +0530 Subject: [PATCH 27/38] dataset --- enhancer/data/dataset.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py index d194167..95c73a1 100644 --- a/enhancer/data/dataset.py +++ b/enhancer/data/dataset.py @@ -1,16 +1,17 @@ -import multiprocessing import math +import multiprocessing import os -import pytorch_lightning as pl -from torch.utils.data import IterableDataset, DataLoader, Dataset -import torch.nn.functional as F from typing import Optional +import pytorch_lightning as pl +import torch.nn.functional as F +from torch.utils.data import DataLoader, Dataset, IterableDataset + from enhancer.data.fileprocessor import Fileprocessor -from enhancer.utils.random import create_unique_rng -from enhancer.utils.io import Audio from enhancer.utils import check_files from enhancer.utils.config import Files +from enhancer.utils.io import Audio +from enhancer.utils.random import create_unique_rng class TrainDataset(IterableDataset): From 09e800392e557e49c2a482a280bddba23cfdb6e7 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Wed, 5 Oct 2022 20:37:30 +0530 Subject: [PATCH 28/38] fileprocessor --- enhancer/data/fileprocessor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/enhancer/data/fileprocessor.py b/enhancer/data/fileprocessor.py index 5cc9b31..66d4d75 100644 --- a/enhancer/data/fileprocessor.py +++ b/enhancer/data/fileprocessor.py @@ -1,5 +1,6 @@ import glob import os + import numpy as np from scipy.io import wavfile From ec76466b00045bc2d86fd8c434a6dc10a08791f9 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Wed, 5 Oct 2022 20:37:39 +0530 Subject: [PATCH 29/38] readme --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index e462afa..743a823 100644 
--- a/README.md +++ b/README.md @@ -1 +1 @@ -# enhancer \ No newline at end of file +# enhancer From ab0805e1ac4ad77a320fcb9cd58c397fd06e1035 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Wed, 5 Oct 2022 20:42:09 +0530 Subject: [PATCH 30/38] flake8 tests --- tests/loss_function_test.py | 2 -- tests/models/demucs_test.py | 3 +-- tests/models/test_waveunet.py | 3 +-- tests/utils_test.py | 5 +---- 4 files changed, 3 insertions(+), 10 deletions(-) diff --git a/tests/loss_function_test.py b/tests/loss_function_test.py index cd60177..4d14871 100644 --- a/tests/loss_function_test.py +++ b/tests/loss_function_test.py @@ -1,5 +1,3 @@ -from asyncio import base_tasks - import pytest import torch diff --git a/tests/models/demucs_test.py b/tests/models/demucs_test.py index 1ea50c5..f5a0ec4 100644 --- a/tests/models/demucs_test.py +++ b/tests/models/demucs_test.py @@ -1,7 +1,6 @@ import pytest import torch -from enhancer import data from enhancer.data.dataset import EnhancerDataset from enhancer.models import Demucs from enhancer.utils.config import Files @@ -41,4 +40,4 @@ def test_forward(batch_size, samples): ) def test_demucs_init(dataset, channels, loss): with torch.no_grad(): - model = Demucs(num_channels=channels, dataset=dataset, loss=loss) + _ = Demucs(num_channels=channels, dataset=dataset, loss=loss) diff --git a/tests/models/test_waveunet.py b/tests/models/test_waveunet.py index 798ed5d..9c4dd96 100644 --- a/tests/models/test_waveunet.py +++ b/tests/models/test_waveunet.py @@ -1,7 +1,6 @@ import pytest import torch -from enhancer import data from enhancer.data.dataset import EnhancerDataset from enhancer.models import WaveUnet from enhancer.utils.config import Files @@ -41,4 +40,4 @@ def test_forward(batch_size, samples): ) def test_demucs_init(dataset, channels, loss): with torch.no_grad(): - model = WaveUnet(num_channels=channels, dataset=dataset, loss=loss) + _ = WaveUnet(num_channels=channels, dataset=dataset, loss=loss) diff --git a/tests/utils_test.py 
b/tests/utils_test.py index 1cc171a..65c723d 100644 --- a/tests/utils_test.py +++ b/tests/utils_test.py @@ -1,11 +1,8 @@ -from logging import root - import numpy as np import pytest import torch from enhancer.data.fileprocessor import Fileprocessor -from enhancer.utils.config import Files from enhancer.utils.io import Audio @@ -48,6 +45,6 @@ def test_fileprocessor_names(dataset_name): def test_fileprocessor_invaliname(): with pytest.raises(ValueError): - fp = Fileprocessor.from_name( + _ = Fileprocessor.from_name( "undefined", "clean_dir", "noisy_dir", 16000 ).prepare_matching_dict() From c0623db75d5961339de06b0992f49171bf372d87 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Wed, 5 Oct 2022 20:42:18 +0530 Subject: [PATCH 31/38] setup --- setup.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.sh b/setup.sh index adad46c..43adc89 100644 --- a/setup.sh +++ b/setup.sh @@ -10,4 +10,4 @@ conda env create -f environment.yml || conda env update -f environment.yml source activate enhancer echo "copying files" -# cp /scratch/$USER/TIMIT/.* /deep-transcriber \ No newline at end of file +# cp /scratch/$USER/TIMIT/.* /deep-transcriber From 71603d78dbf8d0a0e2588e13974598cc0f81a13c Mon Sep 17 00:00:00 2001 From: shahules786 Date: Wed, 5 Oct 2022 20:42:26 +0530 Subject: [PATCH 32/38] requirements --- requirements.txt | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/requirements.txt b/requirements.txt index b2f3638..afa3641 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,16 +1,16 @@ -joblib==1.2.0 -librosa==0.9.2 -numpy==1.23.3 -hydra-core==1.2.0 -scikit-learn==1.1.2 -scipy==1.9.1 -torch==1.12.1 -tqdm==4.64.1 -mlflow==1.29.0 -protobuf==3.19.6 -boto3==1.24.86 -torchaudio==0.12.1 -huggingface-hu==0.10.0 -pytorch-lightning==1.7.7 -flake8==5.0.4 -black==22.8.0 \ No newline at end of file +black>=22.8.0 +boto3>=1.24.86 +flake8>=5.0.4 +huggingface-hu>=0.10.0 +hydra-core>=1.2.0 +joblib>=1.2.0 
+librosa>=0.9.2 +mlflow>=1.29.0 +numpy>=1.23.3 +protobuf>=3.19.6 +pytorch-lightning>=1.7.7 +scikit-learn>=1.1.2 +scipy>=1.9.1 +torch>=1.12.1 +torchaudio>=0.12.1 +tqdm>=4.64.1 From e127987e3a24c66ef104f3c8d445efa88ef40c23 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Wed, 5 Oct 2022 20:42:37 +0530 Subject: [PATCH 33/38] pre commit hooks --- .pre-commit-config.yaml | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8eac0a2..807429c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -18,9 +18,15 @@ repos: - id: isort args: ["--profile", "black"] + - repo: https://gitlab.com/pycqa/flake8 + rev: 5.0.4 + hooks: + - id: flake8 + args: ['--ignore=E203,E501,F811,E712,W503'] + # Formatting, Whitespace, etc - - repo: git://github.com/pre-commit/pre-commit-hooks - rev: v2.20.0 + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v3.2.0 hooks: - id: trailing-whitespace - id: check-added-large-files @@ -35,5 +41,3 @@ repos: - id: requirements-txt-fixer - id: mixed-line-ending args: ['--fix=no'] - - id: flake8 - args: ['--ignore=E203,E501,F811,E712,W503'] From d4f1087c4565fd593dbfc6b9d20b31ffe0e52b02 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Wed, 5 Oct 2022 20:42:51 +0530 Subject: [PATCH 34/38] toml --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 8f12f30..b3e5d7c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,4 +12,4 @@ exclude = ''' | \.venv )/ ) -''' \ No newline at end of file +''' From 61741c528fce6081cf1ac7f794661d8ee279da39 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Wed, 5 Oct 2022 20:43:04 +0530 Subject: [PATCH 35/38] flake8 --- .flake8 | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.flake8 b/.flake8 index 861f69a..abbbc73 100644 --- a/.flake8 +++ b/.flake8 @@ -6,4 +6,4 @@ ignore = E203, E266, E501, W503 max-line-length = 80 max-complexity = 18 
select = B,C,E,F,W,T4,B9 -exclude = tools/kaldi_decoder \ No newline at end of file +exclude = tools/kaldi_decoder From 95a998d824c0cd476f7603622a3af43409743367 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Wed, 5 Oct 2022 20:43:25 +0530 Subject: [PATCH 36/38] environment --- environment.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/environment.yml b/environment.yml index 4f211bf..8da22e1 100644 --- a/environment.yml +++ b/environment.yml @@ -5,4 +5,4 @@ dependencies: - python=3.8 - pip: - -r requirements.txt - - --find-links https://download.pytorch.org/whl/cu113/torch_stable.html \ No newline at end of file + - --find-links https://download.pytorch.org/whl/cu113/torch_stable.html From b82ba4e1bba9b92acb9903941b4290dc94763721 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Wed, 5 Oct 2022 20:43:33 +0530 Subject: [PATCH 37/38] hawk --- hpc_entrypoint.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hpc_entrypoint.sh b/hpc_entrypoint.sh index 7372eb9..6d6a3a0 100644 --- a/hpc_entrypoint.sh +++ b/hpc_entrypoint.sh @@ -33,7 +33,7 @@ mkdir temp pwd #python transcriber/tasks/embeddings/timit.py --directory /scratch/$USER/TIMIT/data/lisa/data/timit/raw/TIMIT/TRAIN --output ./data/train -#python transcriber/tasks/embeddings/timit.py --directory /scratch/$USER/TIMIT/data/lisa/data/timit/raw/TIMIT/TEST --output ./data/test +#python transcriber/tasks/embeddings/timit.py --directory /scratch/$USER/TIMIT/data/lisa/data/timit/raw/TIMIT/TEST --output ./data/test echo "Start Training..." python cli/train.py From d8a4d664a0d7e77dca2a2f3e4f0238f9b21ae666 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Wed, 5 Oct 2022 21:01:46 +0530 Subject: [PATCH 38/38] update readme --- README.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/README.md b/README.md index 743a823..97e0c7e 100644 --- a/README.md +++ b/README.md @@ -1 +1,6 @@ # enhancer +Enhancer is a Pytorch-based opensource toolkit for speech enhancement. 
It is designed to save time for audio researchers. It provides easy-to-use pretrained audio enhancement models and facilitates highly customisable custom model training. Enhancer provides: + +* Various pretrained models nicely integrated with huggingface that users can select and use without any hassle. +* Ability to train and validate your own custom speech enhancement models with just under 10 lines of code! +* A command line tool that facilitates training of highly customisable speech enhancement models from the terminal itself! \ No newline at end of file