refactor model.py
This commit is contained in:
		
							parent
							
								
									96c6108ec6
								
							
						
					
					
						commit
						2cf9803ed1
					
				|  | @ -27,24 +27,48 @@ CACHE_DIR = "" | ||||||
| HF_TORCH_WEIGHTS = "" | HF_TORCH_WEIGHTS = "" | ||||||
| DEFAULT_DEVICE = "cpu" | DEFAULT_DEVICE = "cpu" | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
| class Model(pl.LightningModule): | class Model(pl.LightningModule): | ||||||
|  |     """ | ||||||
|  |     Base class for all models | ||||||
|  |     parameters: | ||||||
|  |         num_channels: int, default to 1 | ||||||
|  |             number of channels in input audio | ||||||
|  |         sampling_rate : int, default 16khz | ||||||
|  |             audio sampling rate | ||||||
|  |         lr: float, optional | ||||||
|  |             learning rate for model training | ||||||
|  |         dataset: EnhancerDataset, optional | ||||||
|  |             Enhancer dataset used for training/validation | ||||||
|  |         duration: float, optional | ||||||
|  |             duration used for training/inference | ||||||
|  |         loss : string or List of strings, default to "mse" | ||||||
|  |             loss functions to be used. Available ("mse","mae","Si-SDR") | ||||||
|  | 
 | ||||||
|  |     """ | ||||||
| 
 | 
 | ||||||
|     def __init__( |     def __init__( | ||||||
|         self, |         self, | ||||||
|         num_channels:int=1, |         num_channels: int = 1, | ||||||
|         sampling_rate:int=16000, |         sampling_rate: int = 16000, | ||||||
|         lr:float=1e-3, |         lr: float = 1e-3, | ||||||
|         dataset:Optional[EnhancerDataset]=None, |         dataset: Optional[EnhancerDataset] = None, | ||||||
|         duration:Optional[float]=None, |         duration: Optional[float] = None, | ||||||
|         loss: Union[str, List] = "mse", |         loss: Union[str, List] = "mse", | ||||||
|         metric:Union[str,List] = "mse" |         metric: Union[str, List] = "mse", | ||||||
|     ): |     ): | ||||||
|         super().__init__() |         super().__init__() | ||||||
|         assert num_channels ==1 , "Enhancer only support for mono channel models" |         assert ( | ||||||
|  |             num_channels == 1 | ||||||
|  |         ), "Enhancer only support for mono channel models" | ||||||
|         self.dataset = dataset |         self.dataset = dataset | ||||||
|         self.save_hyperparameters("num_channels","sampling_rate","lr","loss","metric","duration") |         self.save_hyperparameters( | ||||||
|  |             "num_channels", "sampling_rate", "lr", "loss", "metric", "duration" | ||||||
|  |         ) | ||||||
|         if self.logger: |         if self.logger: | ||||||
|             self.logger.experiment.log_dict(dict(self.hparams),"hyperparameters.json") |             self.logger.experiment.log_dict( | ||||||
|  |                 dict(self.hparams), "hyperparameters.json" | ||||||
|  |             ) | ||||||
| 
 | 
 | ||||||
|         self.loss = loss |         self.loss = loss | ||||||
|         self.metric = metric |         self.metric = metric | ||||||
|  | @ -54,10 +78,10 @@ class Model(pl.LightningModule): | ||||||
|         return self._loss |         return self._loss | ||||||
| 
 | 
 | ||||||
|     @loss.setter |     @loss.setter | ||||||
|     def loss(self,loss): |     def loss(self, loss): | ||||||
| 
 | 
 | ||||||
|         if isinstance(loss,str): |         if isinstance(loss, str): | ||||||
|                 losses = [loss] |             losses = [loss] | ||||||
| 
 | 
 | ||||||
|         self._loss = Avergeloss(losses) |         self._loss = Avergeloss(losses) | ||||||
| 
 | 
 | ||||||
|  | @ -66,23 +90,22 @@ class Model(pl.LightningModule): | ||||||
|         return self._metric |         return self._metric | ||||||
| 
 | 
 | ||||||
|     @metric.setter |     @metric.setter | ||||||
|     def metric(self,metric): |     def metric(self, metric): | ||||||
| 
 | 
 | ||||||
|         if isinstance(metric,str): |         if isinstance(metric, str): | ||||||
|                 metric = [metric] |             metric = [metric] | ||||||
| 
 | 
 | ||||||
|         self._metric = Avergeloss(metric) |         self._metric = Avergeloss(metric) | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
|     @property |     @property | ||||||
|     def dataset(self): |     def dataset(self): | ||||||
|         return self._dataset |         return self._dataset | ||||||
| 
 | 
 | ||||||
|     @dataset.setter |     @dataset.setter | ||||||
|     def dataset(self,dataset): |     def dataset(self, dataset): | ||||||
|         self._dataset = dataset |         self._dataset = dataset | ||||||
| 
 | 
 | ||||||
|     def setup(self,stage:Optional[str]=None): |     def setup(self, stage: Optional[str] = None): | ||||||
|         if stage == "fit": |         if stage == "fit": | ||||||
|             self.dataset.setup(stage) |             self.dataset.setup(stage) | ||||||
|             self.dataset.model = self |             self.dataset.model = self | ||||||
|  | @ -94,9 +117,9 @@ class Model(pl.LightningModule): | ||||||
|         return self.dataset.val_dataloader() |         return self.dataset.val_dataloader() | ||||||
| 
 | 
 | ||||||
|     def configure_optimizers(self): |     def configure_optimizers(self): | ||||||
|         return Adam(self.parameters(), lr = self.hparams.lr) |         return Adam(self.parameters(), lr=self.hparams.lr) | ||||||
| 
 | 
 | ||||||
|     def training_step(self,batch, batch_idx:int): |     def training_step(self, batch, batch_idx: int): | ||||||
| 
 | 
 | ||||||
|         mixed_waveform = batch["noisy"] |         mixed_waveform = batch["noisy"] | ||||||
|         target = batch["clean"] |         target = batch["clean"] | ||||||
|  | @ -105,13 +128,16 @@ class Model(pl.LightningModule): | ||||||
|         loss = self.loss(prediction, target) |         loss = self.loss(prediction, target) | ||||||
| 
 | 
 | ||||||
|         if self.logger: |         if self.logger: | ||||||
|             self.logger.experiment.log_metric(run_id=self.logger.run_id,  |             self.logger.experiment.log_metric( | ||||||
|                     key="train_loss", value=loss.item(),  |                 run_id=self.logger.run_id, | ||||||
|                     step=self.global_step) |                 key="train_loss", | ||||||
|         self.log("train_loss",loss.item()) |                 value=loss.item(), | ||||||
|         return {"loss":loss} |                 step=self.global_step, | ||||||
|  |             ) | ||||||
|  |         self.log("train_loss", loss.item()) | ||||||
|  |         return {"loss": loss} | ||||||
| 
 | 
 | ||||||
|     def validation_step(self,batch,batch_idx:int): |     def validation_step(self, batch, batch_idx: int): | ||||||
| 
 | 
 | ||||||
|         mixed_waveform = batch["noisy"] |         mixed_waveform = batch["noisy"] | ||||||
|         target = batch["clean"] |         target = batch["clean"] | ||||||
|  | @ -119,48 +145,92 @@ class Model(pl.LightningModule): | ||||||
| 
 | 
 | ||||||
|         metric_val = self.metric(prediction, target) |         metric_val = self.metric(prediction, target) | ||||||
|         loss_val = self.loss(prediction, target) |         loss_val = self.loss(prediction, target) | ||||||
|         self.log("val_metric",metric_val.item()) |         self.log("val_metric", metric_val.item()) | ||||||
|         self.log("val_loss",loss_val.item()) |         self.log("val_loss", loss_val.item()) | ||||||
| 
 | 
 | ||||||
|         if self.logger: |         if self.logger: | ||||||
|             self.logger.experiment.log_metric(run_id=self.logger.run_id,  |             self.logger.experiment.log_metric( | ||||||
|                         key="val_loss",value=loss_val.item(), |                 run_id=self.logger.run_id, | ||||||
|                         step=self.global_step) |                 key="val_loss", | ||||||
|             self.logger.experiment.log_metric(run_id=self.logger.run_id,  |                 value=loss_val.item(), | ||||||
|                         key="val_metric",value=metric_val.item(), |                 step=self.global_step, | ||||||
|                         step=self.global_step) |             ) | ||||||
|  |             self.logger.experiment.log_metric( | ||||||
|  |                 run_id=self.logger.run_id, | ||||||
|  |                 key="val_metric", | ||||||
|  |                 value=metric_val.item(), | ||||||
|  |                 step=self.global_step, | ||||||
|  |             ) | ||||||
| 
 | 
 | ||||||
|         return {"loss":loss_val} |         return {"loss": loss_val} | ||||||
| 
 | 
 | ||||||
|     def on_save_checkpoint(self, checkpoint): |     def on_save_checkpoint(self, checkpoint): | ||||||
| 
 | 
 | ||||||
|         checkpoint["enhancer"] = { |         checkpoint["enhancer"] = { | ||||||
|             "version": { |             "version": {"enhancer": __version__, "pytorch": torch.__version__}, | ||||||
|                 "enhancer":__version__, |             "architecture": { | ||||||
|                 "pytorch":torch.__version__ |                 "module": self.__class__.__module__, | ||||||
|  |                 "class": self.__class__.__name__, | ||||||
|             }, |             }, | ||||||
|             "architecture":{ |  | ||||||
|                 "module":self.__class__.__module__, |  | ||||||
|                 "class":self.__class__.__name__ |  | ||||||
|             } |  | ||||||
| 
 |  | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|     def on_load_checkpoint(self, checkpoint: Dict[str, Any]): |     def on_load_checkpoint(self, checkpoint: Dict[str, Any]): | ||||||
|         pass |         pass | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     def from_pretrained( |     def from_pretrained( | ||||||
|         cls, |         cls, | ||||||
|         checkpoint: Union[Path, Text], |         checkpoint: Union[Path, Text], | ||||||
|         map_location = None, |         map_location=None, | ||||||
|         hparams_file: Union[Path, Text] = None, |         hparams_file: Union[Path, Text] = None, | ||||||
|         strict: bool = True, |         strict: bool = True, | ||||||
|         use_auth_token: Union[Text, None] = None, |         use_auth_token: Union[Text, None] = None, | ||||||
|         cached_dir: Union[Path, Text]=CACHE_DIR, |         cached_dir: Union[Path, Text] = CACHE_DIR, | ||||||
|         **kwargs |         **kwargs, | ||||||
|     ): |     ): | ||||||
|  |         """ | ||||||
|  |         Load Pretrained model | ||||||
|  | 
 | ||||||
|  |         parameters: | ||||||
|  |         checkpoint : Path or str | ||||||
|  |             Path to checkpoint, or a remote URL, or a model identifier from | ||||||
|  |             the huggingface.co model hub. | ||||||
|  |         map_location: optional | ||||||
|  |             Same role as in torch.load(). | ||||||
|  |             Defaults to `lambda storage, loc: storage`. | ||||||
|  |         hparams_file : Path or str, optional | ||||||
|  |             Path to a .yaml file with hierarchical structure as in this example: | ||||||
|  |                 drop_prob: 0.2 | ||||||
|  |                 dataloader: | ||||||
|  |                     batch_size: 32 | ||||||
|  |             You most likely won’t need this since Lightning will always save the | ||||||
|  |             hyperparameters to the checkpoint. However, if your checkpoint weights | ||||||
|  |             do not have the hyperparameters saved, use this method to pass in a .yaml | ||||||
|  |             file with the hparams you would like to use. These will be converted | ||||||
|  |             into a dict and passed into your Model for use. | ||||||
|  |         strict : bool, optional | ||||||
|  |             Whether to strictly enforce that the keys in checkpoint match | ||||||
|  |             the keys returned by this module’s state dict. Defaults to True. | ||||||
|  |         use_auth_token : str, optional | ||||||
|  |             When loading a private huggingface.co model, set `use_auth_token` | ||||||
|  |             to True or to a string containing your hugginface.co authentication | ||||||
|  |             token that can be obtained by running `huggingface-cli login` | ||||||
|  |         cache_dir: Path or str, optional | ||||||
|  |             Path to model cache directory. Defaults to content of PYANNOTE_CACHE | ||||||
|  |             environment variable, or "~/.cache/torch/pyannote" when unset. | ||||||
|  |         kwargs: optional | ||||||
|  |             Any extra keyword args needed to init the model. | ||||||
|  |             Can also be used to override saved hyperparameter values. | ||||||
|  | 
 | ||||||
|  |         Returns | ||||||
|  |         ------- | ||||||
|  |         model : Model | ||||||
|  |             Model | ||||||
|  | 
 | ||||||
|  |         See also | ||||||
|  |         -------- | ||||||
|  |         torch.load | ||||||
|  |         """ | ||||||
| 
 | 
 | ||||||
|         checkpoint = str(checkpoint) |         checkpoint = str(checkpoint) | ||||||
|         if hparams_file is not None: |         if hparams_file is not None: | ||||||
|  | @ -168,7 +238,7 @@ class Model(pl.LightningModule): | ||||||
| 
 | 
 | ||||||
|         if os.path.isfile(checkpoint): |         if os.path.isfile(checkpoint): | ||||||
|             model_path_pl = checkpoint |             model_path_pl = checkpoint | ||||||
|         elif urlparse(checkpoint).scheme in ("http","https"): |         elif urlparse(checkpoint).scheme in ("http", "https"): | ||||||
|             model_path_pl = checkpoint |             model_path_pl = checkpoint | ||||||
|         else: |         else: | ||||||
| 
 | 
 | ||||||
|  | @ -180,45 +250,59 @@ class Model(pl.LightningModule): | ||||||
|                 revision_id = None |                 revision_id = None | ||||||
| 
 | 
 | ||||||
|             url = hf_hub_url( |             url = hf_hub_url( | ||||||
|                 model_id,filename=HF_TORCH_WEIGHTS,revision=revision_id |                 model_id, filename=HF_TORCH_WEIGHTS, revision=revision_id | ||||||
|             ) |             ) | ||||||
|             model_path_pl = cached_download( |             model_path_pl = cached_download( | ||||||
|                 url=url,library_name="enhancer",library_version=__version__, |                 url=url, | ||||||
|                 cache_dir=cached_dir,use_auth_token=use_auth_token |                 library_name="enhancer", | ||||||
|  |                 library_version=__version__, | ||||||
|  |                 cache_dir=cached_dir, | ||||||
|  |                 use_auth_token=use_auth_token, | ||||||
|             ) |             ) | ||||||
| 
 | 
 | ||||||
|         if map_location is None: |         if map_location is None: | ||||||
|             map_location = torch.device(DEFAULT_DEVICE) |             map_location = torch.device(DEFAULT_DEVICE) | ||||||
| 
 | 
 | ||||||
|         loaded_checkpoint = pl_load(model_path_pl,map_location) |         loaded_checkpoint = pl_load(model_path_pl, map_location) | ||||||
|         module_name = loaded_checkpoint["enhancer"]["architecture"]["module"] |         module_name = loaded_checkpoint["enhancer"]["architecture"]["module"] | ||||||
|         class_name =  loaded_checkpoint["enhancer"]["architecture"]["class"] |         class_name = loaded_checkpoint["enhancer"]["architecture"]["class"] | ||||||
|         module = import_module(module_name) |         module = import_module(module_name) | ||||||
|         Klass = getattr(module, class_name) |         Klass = getattr(module, class_name) | ||||||
| 
 | 
 | ||||||
|         try: |         try: | ||||||
|             model = Klass.load_from_checkpoint( |             model = Klass.load_from_checkpoint( | ||||||
|                 checkpoint_path = model_path_pl, |                 checkpoint_path=model_path_pl, | ||||||
|                 map_location = map_location, |                 map_location=map_location, | ||||||
|                 hparams_file = hparams_file, |                 hparams_file=hparams_file, | ||||||
|                 strict = strict, |                 strict=strict, | ||||||
|                 **kwargs |                 **kwargs, | ||||||
|             ) |             ) | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|             print(e) |             print(e) | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
|         return model |         return model | ||||||
| 
 | 
 | ||||||
|     def infer(self,batch:torch.Tensor,batch_size:int=32): |     def infer(self, batch: torch.Tensor, batch_size: int = 32): | ||||||
|  |         """ | ||||||
|  |         perform model inference | ||||||
|  |         parameters: | ||||||
|  |             batch : torch.Tensor | ||||||
|  |                 input data | ||||||
|  |             batch_size : int, default 32 | ||||||
|  |                 batch size for inference | ||||||
|  |         """ | ||||||
| 
 | 
 | ||||||
|         assert batch.ndim == 3, f"Expected batch with 3 dimensions (batch,channels,samples) got only {batch.ndim}" |         assert ( | ||||||
|  |             batch.ndim == 3 | ||||||
|  |         ), f"Expected batch with 3 dimensions (batch,channels,samples) got only {batch.ndim}" | ||||||
|         batch_predictions = [] |         batch_predictions = [] | ||||||
|         self.eval().to(self.device) |         self.eval().to(self.device) | ||||||
| 
 | 
 | ||||||
|         with torch.no_grad(): |         with torch.no_grad(): | ||||||
|             for batch_id in range(0,batch.shape[0],batch_size): |             for batch_id in range(0, batch.shape[0], batch_size): | ||||||
|                 batch_data = batch[batch_id:batch_id+batch_size,:,:].to(self.device) |                 batch_data = batch[batch_id : batch_id + batch_size, :, :].to( | ||||||
|  |                     self.device | ||||||
|  |                 ) | ||||||
|                 prediction = self(batch_data) |                 prediction = self(batch_data) | ||||||
|                 batch_predictions.append(prediction) |                 batch_predictions.append(prediction) | ||||||
| 
 | 
 | ||||||
|  | @ -226,46 +310,61 @@ class Model(pl.LightningModule): | ||||||
| 
 | 
 | ||||||
|     def enhance( |     def enhance( | ||||||
|         self, |         self, | ||||||
|         audio:Union[Path,np.ndarray,torch.Tensor], |         audio: Union[Path, np.ndarray, torch.Tensor], | ||||||
|         sampling_rate:Optional[int]=None, |         sampling_rate: Optional[int] = None, | ||||||
|         batch_size:int=32, |         batch_size: int = 32, | ||||||
|         save_output:bool=False, |         save_output: bool = False, | ||||||
|         duration:Optional[int]=None, |         duration: Optional[int] = None, | ||||||
|         step_size:Optional[int]=None,): |         step_size: Optional[int] = None, | ||||||
|  |     ): | ||||||
|  |         """ | ||||||
|  |         Enhance audio using loaded pretained model. | ||||||
|  | 
 | ||||||
|  |         parameters: | ||||||
|  |             audio: Path to audio file or numpy array or torch tensor | ||||||
|  |                 single input audio | ||||||
|  |             sampling_rate: int, optional incase input is path | ||||||
|  |                 sampling rate of input | ||||||
|  |             batch_size: int, default 32 | ||||||
|  |                 input audio is split into multiple chunks. Inference is done on batches | ||||||
|  |                 of these chunks according to given batch size. | ||||||
|  |             save_output : bool, default False | ||||||
|  |                 weather to save output to file | ||||||
|  |             duration : float, optional | ||||||
|  |                 chunk duration in seconds, defaults to duration of loaded pretrained model. | ||||||
|  |             step_size: int, optional | ||||||
|  |                 step size between consecutive durations, defaults to 50% of duration | ||||||
|  |         """ | ||||||
| 
 | 
 | ||||||
|         model_sampling_rate = self.hparams["sampling_rate"] |         model_sampling_rate = self.hparams["sampling_rate"] | ||||||
|         if duration is None: |         if duration is None: | ||||||
|             duration = self.hparams["duration"] |             duration = self.hparams["duration"] | ||||||
|         waveform = Inference.read_input(audio,sampling_rate,model_sampling_rate) |         waveform = Inference.read_input( | ||||||
|  |             audio, sampling_rate, model_sampling_rate | ||||||
|  |         ) | ||||||
|         waveform.to(self.device) |         waveform.to(self.device) | ||||||
|         window_size = round(duration * model_sampling_rate) |         window_size = round(duration * model_sampling_rate) | ||||||
|         batched_waveform = Inference.batchify(waveform,window_size,step_size=step_size) |         batched_waveform = Inference.batchify( | ||||||
|         batch_prediction = self.infer(batched_waveform,batch_size=batch_size) |             waveform, window_size, step_size=step_size | ||||||
|         waveform = Inference.aggreagate(batch_prediction,window_size,waveform.shape[-1],step_size,) |         ) | ||||||
|  |         batch_prediction = self.infer(batched_waveform, batch_size=batch_size) | ||||||
|  |         waveform = Inference.aggreagate( | ||||||
|  |             batch_prediction, | ||||||
|  |             window_size, | ||||||
|  |             waveform.shape[-1], | ||||||
|  |             step_size, | ||||||
|  |         ) | ||||||
| 
 | 
 | ||||||
|         if save_output and isinstance(audio,(str,Path)): |         if save_output and isinstance(audio, (str, Path)): | ||||||
|             Inference.write_output(waveform,audio,model_sampling_rate) |             Inference.write_output(waveform, audio, model_sampling_rate) | ||||||
| 
 | 
 | ||||||
|         else: |         else: | ||||||
|             waveform = Inference.prepare_output(waveform, model_sampling_rate, |             waveform = Inference.prepare_output( | ||||||
|                                 audio, sampling_rate)     |                 waveform, model_sampling_rate, audio, sampling_rate | ||||||
|  |             ) | ||||||
|             return waveform |             return waveform | ||||||
|  | 
 | ||||||
|     @property |     @property | ||||||
|     def valid_monitor(self): |     def valid_monitor(self): | ||||||
| 
 | 
 | ||||||
|         return "max" if self.loss.higher_better else "min" |         return "max" if self.loss.higher_better else "min" | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|          |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|         |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|              |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|          |  | ||||||
		Loading…
	
		Reference in New Issue
	
	 shahules786
						shahules786