From 9ef6665b84463f596c0336ba690d77b8e617bd6b Mon Sep 17 00:00:00 2001 From: shahules786 Date: Mon, 19 Sep 2022 22:35:21 +0530 Subject: [PATCH] add enhance function --- enhancer/models/model.py | 66 ++++++++++++++++++++++++++++++++++++++-- 1 file changed, 64 insertions(+), 2 deletions(-) diff --git a/enhancer/models/model.py b/enhancer/models/model.py index 07281c9..f2339a2 100644 --- a/enhancer/models/model.py +++ b/enhancer/models/model.py @@ -1,9 +1,15 @@ +from asyncore import write from importlib import import_module +from lib2to3.pgen2.token import OP +import wave +from xmlrpc.client import boolean from huggingface_hub import cached_download, hf_hub_url +import numpy as np import os -from typing import Optional, Union, List, Path, Text +from typing import Optional, Union, List, Path, Text, Dict, Any from torch.optim import Adam import torch +from torch.nn.functional import pad import pytorch_lightning as pl from pytorch_lightning.utilities.cloud_io import load as pl_load from urllib.parse import urlparse @@ -11,7 +17,9 @@ from urllib.parse import urlparse from enhancer import __version__ from enhancer.data.dataset import Dataset +from enhancer.utils.io import Audio from enhancer.utils.loss import Avergeloss +from enhancer.inference import Inference CACHE_DIR = "" HF_TORCH_WEIGHTS = "" @@ -30,8 +38,8 @@ class Model(pl.LightningModule): ): super().__init__() assert num_channels ==1 , "Enhancer only support for mono channel models" - self.save_hyperparameters("num_channels","sampling_rate","lr","loss","metric") self.dataset = dataset + self.save_hyperparameters("num_channels","sampling_rate","lr","loss","metric") @property @@ -40,6 +48,8 @@ class Model(pl.LightningModule): @dataset.setter def dataset(self,dataset): + if dataset is not None: + self.save_hyperparameters("duration",self.dataset.duration) self._dataset = dataset def setup(self,stage:Optional[str]=None): @@ -99,6 +109,10 @@ class Model(pl.LightningModule): } + def on_load_checkpoint(self, checkpoint: Dict[str, Any]): + pass + + @classmethod def from_pretrained( cls, @@ -157,7 +171,55 @@ class Model(pl.LightningModule): print(e) + return model + + def infer_batch(self,batch,batch_size): + assert batch.ndim == 3, f"Expected batch with 3 dimensions (batch,channels,samples) got only {batch.ndim}" + batch_predictions = [] + self.eval().to(self.device) + + for batch_id in range(batch.shape[0],batch_size): + batch_data = batch[batch_id:batch_id+batch_size,:,:].to(self.device) + prediction = self(batch_data) + batch_predictions.append(prediction) + + return torch.vstack(batch_predictions) + + def enhance( + self, + audio:Union[Path,np.ndarray,torch.Tensor], + sampling_rate:Optional[int]=None, + batch_size:int=32, + save_output:boolean=False, + duration:Optional[int]=None, + step_size:Optional[int]=None,): + + model_sampling_rate = self.model.hprams("sampling_rate") + if duration is None: + duration = self.model.hparams("duration") + waveform = Inference.read_input(audio,sampling_rate,model_sampling_rate) + waveform.to(self.device) + window_size = round(duration * model_sampling_rate) + batched_waveform = Inference.batchify(waveform,window_size,step_size=step_size) + batch_prediction = self.infer_batch(batched_waveform,batch_size=batch_size) + waveform = Inference.aggreagate(batch_prediction,window_size,step_size) + + if save_output and isinstance(audio,(str,Path)): + Inference.write_output(waveform,audio,model_sampling_rate) + + else: + return waveform + + + + + + + + + +