refactor model.py
This commit is contained in:
		
							parent
							
								
									96c6108ec6
								
							
						
					
					
						commit
						2cf9803ed1
					
				|  | @ -27,24 +27,48 @@ CACHE_DIR = "" | |||
| HF_TORCH_WEIGHTS = "" | ||||
| DEFAULT_DEVICE = "cpu" | ||||
| 
 | ||||
| 
 | ||||
| class Model(pl.LightningModule): | ||||
|     """ | ||||
|     Base class for all models | ||||
|     parameters: | ||||
|         num_channels: int, default to 1 | ||||
|             number of channels in input audio | ||||
|         sampling_rate : int, default 16khz | ||||
|             audio sampling rate | ||||
|         lr: float, optional | ||||
|             learning rate for model training | ||||
|         dataset: EnhancerDataset, optional | ||||
|             Enhancer dataset used for training/validation | ||||
|         duration: float, optional | ||||
|             duration used for training/inference | ||||
|         loss : string or List of strings, default to "mse" | ||||
|             loss functions to be used. Available ("mse","mae","Si-SDR") | ||||
| 
 | ||||
|     """ | ||||
| 
 | ||||
|     def __init__( | ||||
|         self, | ||||
|         num_channels:int=1, | ||||
|         sampling_rate:int=16000, | ||||
|         lr:float=1e-3, | ||||
|         dataset:Optional[EnhancerDataset]=None, | ||||
|         duration:Optional[float]=None, | ||||
|         num_channels: int = 1, | ||||
|         sampling_rate: int = 16000, | ||||
|         lr: float = 1e-3, | ||||
|         dataset: Optional[EnhancerDataset] = None, | ||||
|         duration: Optional[float] = None, | ||||
|         loss: Union[str, List] = "mse", | ||||
|         metric:Union[str,List] = "mse" | ||||
|         metric: Union[str, List] = "mse", | ||||
|     ): | ||||
|         super().__init__() | ||||
|         assert num_channels ==1 , "Enhancer only support for mono channel models" | ||||
|         assert ( | ||||
|             num_channels == 1 | ||||
|         ), "Enhancer only support for mono channel models" | ||||
|         self.dataset = dataset | ||||
|         self.save_hyperparameters("num_channels","sampling_rate","lr","loss","metric","duration") | ||||
|         self.save_hyperparameters( | ||||
|             "num_channels", "sampling_rate", "lr", "loss", "metric", "duration" | ||||
|         ) | ||||
|         if self.logger: | ||||
|             self.logger.experiment.log_dict(dict(self.hparams),"hyperparameters.json") | ||||
|             self.logger.experiment.log_dict( | ||||
|                 dict(self.hparams), "hyperparameters.json" | ||||
|             ) | ||||
| 
 | ||||
|         self.loss = loss | ||||
|         self.metric = metric | ||||
|  | @ -54,9 +78,9 @@ class Model(pl.LightningModule): | |||
|         return self._loss | ||||
| 
 | ||||
|     @loss.setter | ||||
|     def loss(self,loss): | ||||
|     def loss(self, loss): | ||||
| 
 | ||||
|         if isinstance(loss,str): | ||||
|         if isinstance(loss, str): | ||||
|             losses = [loss] | ||||
| 
 | ||||
|         self._loss = Avergeloss(losses) | ||||
|  | @ -66,23 +90,22 @@ class Model(pl.LightningModule): | |||
|         return self._metric | ||||
| 
 | ||||
|     @metric.setter | ||||
|     def metric(self,metric): | ||||
|     def metric(self, metric): | ||||
| 
 | ||||
|         if isinstance(metric,str): | ||||
|         if isinstance(metric, str): | ||||
|             metric = [metric] | ||||
| 
 | ||||
|         self._metric = Avergeloss(metric) | ||||
| 
 | ||||
| 
 | ||||
|     @property | ||||
|     def dataset(self): | ||||
|         return self._dataset | ||||
| 
 | ||||
|     @dataset.setter | ||||
|     def dataset(self,dataset): | ||||
|     def dataset(self, dataset): | ||||
|         self._dataset = dataset | ||||
| 
 | ||||
|     def setup(self,stage:Optional[str]=None): | ||||
|     def setup(self, stage: Optional[str] = None): | ||||
|         if stage == "fit": | ||||
|             self.dataset.setup(stage) | ||||
|             self.dataset.model = self | ||||
|  | @ -94,9 +117,9 @@ class Model(pl.LightningModule): | |||
|         return self.dataset.val_dataloader() | ||||
| 
 | ||||
|     def configure_optimizers(self): | ||||
|         return Adam(self.parameters(), lr = self.hparams.lr) | ||||
|         return Adam(self.parameters(), lr=self.hparams.lr) | ||||
| 
 | ||||
|     def training_step(self,batch, batch_idx:int): | ||||
|     def training_step(self, batch, batch_idx: int): | ||||
| 
 | ||||
|         mixed_waveform = batch["noisy"] | ||||
|         target = batch["clean"] | ||||
|  | @ -105,13 +128,16 @@ class Model(pl.LightningModule): | |||
|         loss = self.loss(prediction, target) | ||||
| 
 | ||||
|         if self.logger: | ||||
|             self.logger.experiment.log_metric(run_id=self.logger.run_id,  | ||||
|                     key="train_loss", value=loss.item(),  | ||||
|                     step=self.global_step) | ||||
|         self.log("train_loss",loss.item()) | ||||
|         return {"loss":loss} | ||||
|             self.logger.experiment.log_metric( | ||||
|                 run_id=self.logger.run_id, | ||||
|                 key="train_loss", | ||||
|                 value=loss.item(), | ||||
|                 step=self.global_step, | ||||
|             ) | ||||
|         self.log("train_loss", loss.item()) | ||||
|         return {"loss": loss} | ||||
| 
 | ||||
|     def validation_step(self,batch,batch_idx:int): | ||||
|     def validation_step(self, batch, batch_idx: int): | ||||
| 
 | ||||
|         mixed_waveform = batch["noisy"] | ||||
|         target = batch["clean"] | ||||
|  | @ -119,48 +145,92 @@ class Model(pl.LightningModule): | |||
| 
 | ||||
|         metric_val = self.metric(prediction, target) | ||||
|         loss_val = self.loss(prediction, target) | ||||
|         self.log("val_metric",metric_val.item()) | ||||
|         self.log("val_loss",loss_val.item()) | ||||
|         self.log("val_metric", metric_val.item()) | ||||
|         self.log("val_loss", loss_val.item()) | ||||
| 
 | ||||
|         if self.logger: | ||||
|             self.logger.experiment.log_metric(run_id=self.logger.run_id,  | ||||
|                         key="val_loss",value=loss_val.item(), | ||||
|                         step=self.global_step) | ||||
|             self.logger.experiment.log_metric(run_id=self.logger.run_id,  | ||||
|                         key="val_metric",value=metric_val.item(), | ||||
|                         step=self.global_step) | ||||
|             self.logger.experiment.log_metric( | ||||
|                 run_id=self.logger.run_id, | ||||
|                 key="val_loss", | ||||
|                 value=loss_val.item(), | ||||
|                 step=self.global_step, | ||||
|             ) | ||||
|             self.logger.experiment.log_metric( | ||||
|                 run_id=self.logger.run_id, | ||||
|                 key="val_metric", | ||||
|                 value=metric_val.item(), | ||||
|                 step=self.global_step, | ||||
|             ) | ||||
| 
 | ||||
|         return {"loss":loss_val} | ||||
|         return {"loss": loss_val} | ||||
| 
 | ||||
|     def on_save_checkpoint(self, checkpoint): | ||||
| 
 | ||||
|         checkpoint["enhancer"] = { | ||||
|             "version": { | ||||
|                 "enhancer":__version__, | ||||
|                 "pytorch":torch.__version__ | ||||
|             "version": {"enhancer": __version__, "pytorch": torch.__version__}, | ||||
|             "architecture": { | ||||
|                 "module": self.__class__.__module__, | ||||
|                 "class": self.__class__.__name__, | ||||
|             }, | ||||
|             "architecture":{ | ||||
|                 "module":self.__class__.__module__, | ||||
|                 "class":self.__class__.__name__ | ||||
|             } | ||||
| 
 | ||||
|         } | ||||
| 
 | ||||
|     def on_load_checkpoint(self, checkpoint: Dict[str, Any]): | ||||
|         pass | ||||
| 
 | ||||
| 
 | ||||
|     @classmethod | ||||
|     def from_pretrained( | ||||
|         cls, | ||||
|         checkpoint: Union[Path, Text], | ||||
|         map_location = None, | ||||
|         map_location=None, | ||||
|         hparams_file: Union[Path, Text] = None, | ||||
|         strict: bool = True, | ||||
|         use_auth_token: Union[Text, None] = None, | ||||
|         cached_dir: Union[Path, Text]=CACHE_DIR, | ||||
|         **kwargs | ||||
|         cached_dir: Union[Path, Text] = CACHE_DIR, | ||||
|         **kwargs, | ||||
|     ): | ||||
|         """ | ||||
|         Load Pretrained model | ||||
| 
 | ||||
|         parameters: | ||||
|         checkpoint : Path or str | ||||
|             Path to checkpoint, or a remote URL, or a model identifier from | ||||
|             the huggingface.co model hub. | ||||
|         map_location: optional | ||||
|             Same role as in torch.load(). | ||||
|             Defaults to `torch.device(DEFAULT_DEVICE)`. | ||||
|         hparams_file : Path or str, optional | ||||
|             Path to a .yaml file with hierarchical structure as in this example: | ||||
|                 drop_prob: 0.2 | ||||
|                 dataloader: | ||||
|                     batch_size: 32 | ||||
|             You most likely won’t need this since Lightning will always save the | ||||
|             hyperparameters to the checkpoint. However, if your checkpoint weights | ||||
|             do not have the hyperparameters saved, use this method to pass in a .yaml | ||||
|             file with the hparams you would like to use. These will be converted | ||||
|             into a dict and passed into your Model for use. | ||||
|         strict : bool, optional | ||||
|             Whether to strictly enforce that the keys in checkpoint match | ||||
|             the keys returned by this module’s state dict. Defaults to True. | ||||
|         use_auth_token : str, optional | ||||
|             When loading a private huggingface.co model, set `use_auth_token` | ||||
|             to True or to a string containing your huggingface.co authentication | ||||
|             token that can be obtained by running `huggingface-cli login` | ||||
|         cached_dir: Path or str, optional | ||||
|             Path to model cache directory. Defaults to the module-level | ||||
|             CACHE_DIR constant. | ||||
|         kwargs: optional | ||||
|             Any extra keyword args needed to init the model. | ||||
|             Can also be used to override saved hyperparameter values. | ||||
| 
 | ||||
|         Returns | ||||
|         ------- | ||||
|         model : Model | ||||
|             Model | ||||
| 
 | ||||
|         See also | ||||
|         -------- | ||||
|         torch.load | ||||
|         """ | ||||
| 
 | ||||
|         checkpoint = str(checkpoint) | ||||
|         if hparams_file is not None: | ||||
|  | @ -168,7 +238,7 @@ class Model(pl.LightningModule): | |||
| 
 | ||||
|         if os.path.isfile(checkpoint): | ||||
|             model_path_pl = checkpoint | ||||
|         elif urlparse(checkpoint).scheme in ("http","https"): | ||||
|         elif urlparse(checkpoint).scheme in ("http", "https"): | ||||
|             model_path_pl = checkpoint | ||||
|         else: | ||||
| 
 | ||||
|  | @ -180,17 +250,20 @@ class Model(pl.LightningModule): | |||
|                 revision_id = None | ||||
| 
 | ||||
|             url = hf_hub_url( | ||||
|                 model_id,filename=HF_TORCH_WEIGHTS,revision=revision_id | ||||
|                 model_id, filename=HF_TORCH_WEIGHTS, revision=revision_id | ||||
|             ) | ||||
|             model_path_pl = cached_download( | ||||
|                 url=url,library_name="enhancer",library_version=__version__, | ||||
|                 cache_dir=cached_dir,use_auth_token=use_auth_token | ||||
|                 url=url, | ||||
|                 library_name="enhancer", | ||||
|                 library_version=__version__, | ||||
|                 cache_dir=cached_dir, | ||||
|                 use_auth_token=use_auth_token, | ||||
|             ) | ||||
| 
 | ||||
|         if map_location is None: | ||||
|             map_location = torch.device(DEFAULT_DEVICE) | ||||
| 
 | ||||
|         loaded_checkpoint = pl_load(model_path_pl,map_location) | ||||
|         loaded_checkpoint = pl_load(model_path_pl, map_location) | ||||
|         module_name = loaded_checkpoint["enhancer"]["architecture"]["module"] | ||||
|         class_name = loaded_checkpoint["enhancer"]["architecture"]["class"] | ||||
|         module = import_module(module_name) | ||||
|  | @ -198,27 +271,38 @@ class Model(pl.LightningModule): | |||
| 
 | ||||
|         try: | ||||
|             model = Klass.load_from_checkpoint( | ||||
|                 checkpoint_path = model_path_pl, | ||||
|                 map_location = map_location, | ||||
|                 hparams_file = hparams_file, | ||||
|                 strict = strict, | ||||
|                 **kwargs | ||||
|                 checkpoint_path=model_path_pl, | ||||
|                 map_location=map_location, | ||||
|                 hparams_file=hparams_file, | ||||
|                 strict=strict, | ||||
|                 **kwargs, | ||||
|             ) | ||||
|         except Exception as e: | ||||
|             print(e) | ||||
| 
 | ||||
| 
 | ||||
|         return model | ||||
| 
 | ||||
|     def infer(self,batch:torch.Tensor,batch_size:int=32): | ||||
|     def infer(self, batch: torch.Tensor, batch_size: int = 32): | ||||
|         """ | ||||
|         perform model inference | ||||
|         parameters: | ||||
|             batch : torch.Tensor | ||||
|                 input data | ||||
|             batch_size : int, default 32 | ||||
|                 batch size for inference | ||||
|         """ | ||||
| 
 | ||||
|         assert batch.ndim == 3, f"Expected batch with 3 dimensions (batch,channels,samples) got only {batch.ndim}" | ||||
|         assert ( | ||||
|             batch.ndim == 3 | ||||
|         ), f"Expected batch with 3 dimensions (batch,channels,samples) got only {batch.ndim}" | ||||
|         batch_predictions = [] | ||||
|         self.eval().to(self.device) | ||||
| 
 | ||||
|         with torch.no_grad(): | ||||
|             for batch_id in range(0,batch.shape[0],batch_size): | ||||
|                 batch_data = batch[batch_id:batch_id+batch_size,:,:].to(self.device) | ||||
|             for batch_id in range(0, batch.shape[0], batch_size): | ||||
|                 batch_data = batch[batch_id : batch_id + batch_size, :, :].to( | ||||
|                     self.device | ||||
|                 ) | ||||
|                 prediction = self(batch_data) | ||||
|                 batch_predictions.append(prediction) | ||||
| 
 | ||||
|  | @ -226,46 +310,61 @@ class Model(pl.LightningModule): | |||
| 
 | ||||
|     def enhance( | ||||
|         self, | ||||
|         audio:Union[Path,np.ndarray,torch.Tensor], | ||||
|         sampling_rate:Optional[int]=None, | ||||
|         batch_size:int=32, | ||||
|         save_output:bool=False, | ||||
|         duration:Optional[int]=None, | ||||
|         step_size:Optional[int]=None,): | ||||
|         audio: Union[Path, np.ndarray, torch.Tensor], | ||||
|         sampling_rate: Optional[int] = None, | ||||
|         batch_size: int = 32, | ||||
|         save_output: bool = False, | ||||
|         duration: Optional[int] = None, | ||||
|         step_size: Optional[int] = None, | ||||
|     ): | ||||
|         """ | ||||
|         Enhance audio using loaded pretrained model. | ||||
| 
 | ||||
|         parameters: | ||||
|             audio: Path to audio file or numpy array or torch tensor | ||||
|                 single input audio | ||||
|             sampling_rate: int, optional in case input is path | ||||
|                 sampling rate of input | ||||
|             batch_size: int, default 32 | ||||
|                 input audio is split into multiple chunks. Inference is done on batches | ||||
|                 of these chunks according to given batch size. | ||||
|             save_output : bool, default False | ||||
|                 whether to save output to file | ||||
|             duration : float, optional | ||||
|                 chunk duration in seconds, defaults to duration of loaded pretrained model. | ||||
|             step_size: int, optional | ||||
|                 step size between consecutive durations, defaults to 50% of duration | ||||
|         """ | ||||
| 
 | ||||
|         model_sampling_rate = self.hparams["sampling_rate"] | ||||
|         if duration is None: | ||||
|             duration = self.hparams["duration"] | ||||
|         waveform = Inference.read_input(audio,sampling_rate,model_sampling_rate) | ||||
|         waveform = Inference.read_input( | ||||
|             audio, sampling_rate, model_sampling_rate | ||||
|         ) | ||||
|         waveform.to(self.device) | ||||
|         window_size = round(duration * model_sampling_rate) | ||||
|         batched_waveform = Inference.batchify(waveform,window_size,step_size=step_size) | ||||
|         batch_prediction = self.infer(batched_waveform,batch_size=batch_size) | ||||
|         waveform = Inference.aggreagate(batch_prediction,window_size,waveform.shape[-1],step_size,) | ||||
|         batched_waveform = Inference.batchify( | ||||
|             waveform, window_size, step_size=step_size | ||||
|         ) | ||||
|         batch_prediction = self.infer(batched_waveform, batch_size=batch_size) | ||||
|         waveform = Inference.aggreagate( | ||||
|             batch_prediction, | ||||
|             window_size, | ||||
|             waveform.shape[-1], | ||||
|             step_size, | ||||
|         ) | ||||
| 
 | ||||
|         if save_output and isinstance(audio,(str,Path)): | ||||
|             Inference.write_output(waveform,audio,model_sampling_rate) | ||||
|         if save_output and isinstance(audio, (str, Path)): | ||||
|             Inference.write_output(waveform, audio, model_sampling_rate) | ||||
| 
 | ||||
|         else: | ||||
|             waveform = Inference.prepare_output(waveform, model_sampling_rate, | ||||
|                                 audio, sampling_rate)     | ||||
|             waveform = Inference.prepare_output( | ||||
|                 waveform, model_sampling_rate, audio, sampling_rate | ||||
|             ) | ||||
|             return waveform | ||||
| 
 | ||||
|     @property | ||||
|     def valid_monitor(self): | ||||
| 
 | ||||
|         return "max" if self.loss.higher_better else "min" | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|          | ||||
| 
 | ||||
| 
 | ||||
|         | ||||
| 
 | ||||
| 
 | ||||
|              | ||||
| 
 | ||||
| 
 | ||||
|          | ||||
		Loading…
	
		Reference in New Issue
	
	 shahules786
						shahules786