add enhance function
This commit is contained in:
		
							parent
							
								
									872303642f
								
							
						
					
					
						commit
						9ef6665b84
					
				|  | @ -1,9 +1,15 @@ | |||
| from asyncore import write | ||||
| from importlib import import_module | ||||
| from lib2to3.pgen2.token import OP | ||||
| import wave | ||||
| from xmlrpc.client import boolean | ||||
| from huggingface_hub import cached_download, hf_hub_url | ||||
| import numpy as np | ||||
| import os | ||||
| from typing import Optional, Union, List, Path, Text | ||||
| from typing import Optional, Union, List, Path, Text, Dict, Any | ||||
| from torch.optim import Adam | ||||
| import torch | ||||
| from torch.nn.functional import pad | ||||
| import pytorch_lightning as pl | ||||
| from pytorch_lightning.utilities.cloud_io import load as pl_load | ||||
| from urllib.parse import urlparse | ||||
|  | @ -11,7 +17,9 @@ from urllib.parse import urlparse | |||
| 
 | ||||
| from enhancer import __version__ | ||||
| from enhancer.data.dataset import Dataset | ||||
| from enhancer.utils.io import Audio | ||||
| from enhancer.utils.loss import Avergeloss | ||||
| from enhancer.inference import Inference | ||||
| 
 | ||||
| CACHE_DIR = "" | ||||
| HF_TORCH_WEIGHTS = "" | ||||
|  | @ -30,8 +38,8 @@ class Model(pl.LightningModule): | |||
|     ): | ||||
|         super().__init__() | ||||
|         assert num_channels ==1 , "Enhancer only support for mono channel models" | ||||
|         self.save_hyperparameters("num_channels","sampling_rate","lr","loss","metric") | ||||
|         self.dataset = dataset | ||||
|         self.save_hyperparameters("num_channels","sampling_rate","lr","loss","metric") | ||||
|          | ||||
|      | ||||
|     @property | ||||
|  | @ -40,6 +48,8 @@ class Model(pl.LightningModule): | |||
| 
 | ||||
|     @dataset.setter | ||||
|     def dataset(self,dataset): | ||||
|         if dataset is not None: | ||||
|             self.save_hyperparameters("duration",self.dataset.duration) | ||||
|         self._dataset = dataset | ||||
| 
 | ||||
|     def setup(self,stage:Optional[str]=None): | ||||
|  | @ -99,6 +109,10 @@ class Model(pl.LightningModule): | |||
| 
 | ||||
|         } | ||||
| 
 | ||||
|     def on_load_checkpoint(self, checkpoint: Dict[str, Any]): | ||||
|         pass | ||||
| 
 | ||||
| 
 | ||||
|     @classmethod | ||||
|     def from_pretrained( | ||||
|         cls, | ||||
|  | @ -157,6 +171,54 @@ class Model(pl.LightningModule): | |||
|             print(e) | ||||
| 
 | ||||
| 
 | ||||
|         return model  | ||||
| 
 | ||||
|     def infer_batch(self,batch,batch_size): | ||||
|          | ||||
|         assert batch.ndim == 3, f"Expected batch with 3 dimensions (batch,channels,samples) got only {batch.ndim}" | ||||
|         batch_predictions = [] | ||||
|         self.eval().to(self.device) | ||||
| 
 | ||||
|         for batch_id in range(batch.shape[0],batch_size): | ||||
|             batch_data = batch[batch_id:batch_id+batch_size,:,:].to(self.device) | ||||
|             prediction = self(batch_data) | ||||
|             batch_predictions.append(prediction) | ||||
|          | ||||
|         return torch.vstack(batch_predictions) | ||||
| 
 | ||||
|     def enhance( | ||||
|         self, | ||||
|         audio:Union[Path,np.ndarray,torch.Tensor], | ||||
|         sampling_rate:Optional[int]=None, | ||||
|         batch_size:int=32, | ||||
|         save_output:boolean=False, | ||||
|         duration:Optional[int]=None, | ||||
|         step_size:Optional[int]=None,): | ||||
| 
 | ||||
|         model_sampling_rate = self.model.hprams("sampling_rate") | ||||
|         if duration is None: | ||||
|             duration = self.model.hparams("duration") | ||||
|         waveform = Inference.read_input(audio,sampling_rate,model_sampling_rate) | ||||
|         waveform.to(self.device) | ||||
|         window_size = round(duration * model_sampling_rate) | ||||
|         batched_waveform = Inference.batchify(waveform,window_size,step_size=step_size) | ||||
|         batch_prediction = self.infer_batch(batched_waveform,batch_size=batch_size) | ||||
|         waveform = Inference.aggreagate(batch_prediction,window_size,step_size) | ||||
|          | ||||
|         if save_output and isinstance(audio,(str,Path)): | ||||
|             Inference.write_output(waveform,audio,model_sampling_rate) | ||||
| 
 | ||||
|         else: | ||||
|             return waveform             | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|          | ||||
| 
 | ||||
| 
 | ||||
|         | ||||
| 
 | ||||
| 
 | ||||
|              | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	 shahules786
						shahules786