diff --git a/enhancer/models/model.py b/enhancer/models/model.py
index de2edab..071bbb6 100644
--- a/enhancer/models/model.py
+++ b/enhancer/models/model.py
@@ -27,66 +27,92 @@
 CACHE_DIR = ""
 HF_TORCH_WEIGHTS = ""
 DEFAULT_DEVICE = "cpu"
+
 class Model(pl.LightningModule):
+    """
+    Base class for all enhancement models.
+
+    parameters:
+    num_channels : int, defaults to 1
+        number of channels in the input audio
+    sampling_rate : int, defaults to 16000 (16 kHz)
+        audio sampling rate
+    lr : float, optional
+        learning rate used for training
+    dataset : EnhancerDataset, optional
+        Enhancer dataset used for training/validation
+    duration : float, optional
+        chunk duration used for training/inference
+    loss : str or list of str, defaults to "mse"
+        loss functions to be used. Available: "mse", "mae", "Si-SDR"
+    metric : str or list of str, defaults to "mse"
+        metric(s) used for validation
+
+    """
+
     def __init__(
         self,
-        num_channels:int=1,
-        sampling_rate:int=16000,
-        lr:float=1e-3,
-        dataset:Optional[EnhancerDataset]=None,
-        duration:Optional[float]=None,
+        num_channels: int = 1,
+        sampling_rate: int = 16000,
+        lr: float = 1e-3,
+        dataset: Optional[EnhancerDataset] = None,
+        duration: Optional[float] = None,
         loss: Union[str, List] = "mse",
-        metric:Union[str,List] = "mse"
+        metric: Union[str, List] = "mse",
     ):
         super().__init__()
-        assert num_channels ==1 , "Enhancer only support for mono channel models"
+        assert (
+            num_channels == 1
+        ), "Enhancer only supports mono-channel models"
         self.dataset = dataset
-        self.save_hyperparameters("num_channels","sampling_rate","lr","loss","metric","duration")
+        self.save_hyperparameters(
+            "num_channels", "sampling_rate", "lr", "loss", "metric", "duration"
+        )
         if self.logger:
-            self.logger.experiment.log_dict(dict(self.hparams),"hyperparameters.json")
-
+            self.logger.experiment.log_dict(
+                dict(self.hparams), "hyperparameters.json"
+            )
+
         self.loss = loss
         self.metric = metric
 
     @property
     def loss(self):
         return self._loss
-    
-    @loss.setter
-    def loss(self,loss):
-        if isinstance(loss,str):
-            losses = [loss]
-        
+
+    @loss.setter
+    def loss(self, loss):
+
+        if isinstance(loss, str):
+            loss = [loss]
+
-        self._loss = Avergeloss(losses)
+        self._loss = Avergeloss(loss)
 
     @property
     def metric(self):
         return self._metric
-    
+
     @metric.setter
-    def metric(self,metric):
+    def metric(self, metric):
+
+        if isinstance(metric, str):
+            metric = [metric]
 
-        if isinstance(metric,str):
-            metric = [metric]
-        
         self._metric = Avergeloss(metric)
 
     @property
     def dataset(self):
         return self._dataset
 
     @dataset.setter
-    def dataset(self,dataset):
+    def dataset(self, dataset):
         self._dataset = dataset
 
-    def setup(self,stage:Optional[str]=None):
+    def setup(self, stage: Optional[str] = None):
 
         if stage == "fit":
             self.dataset.setup(stage)
             self.dataset.model = self
-    
+
     def train_dataloader(self):
         return self.dataset.train_dataloader()
 
@@ -94,9 +120,9 @@ class Model(pl.LightningModule):
         return self.dataset.val_dataloader()
 
     def configure_optimizers(self):
-        return Adam(self.parameters(), lr = self.hparams.lr)
+        return Adam(self.parameters(), lr=self.hparams.lr)
 
-    def training_step(self,batch, batch_idx:int):
+    def training_step(self, batch, batch_idx: int):
 
         mixed_waveform = batch["noisy"]
         target = batch["clean"]
@@ -105,13 +131,16 @@
         loss = self.loss(prediction, target)
 
         if self.logger:
-            self.logger.experiment.log_metric(run_id=self.logger.run_id,
-                                    key="train_loss", value=loss.item(),
-                                    step=self.global_step)
-        self.log("train_loss",loss.item())
-        return {"loss":loss}
+            self.logger.experiment.log_metric(
+                run_id=self.logger.run_id,
+                key="train_loss",
+                value=loss.item(),
+                step=self.global_step,
+            )
+        self.log("train_loss", loss.item())
+        return {"loss": loss}
 
-    def validation_step(self,batch,batch_idx:int):
+    def validation_step(self, batch, batch_idx: int):
 
         mixed_waveform = batch["noisy"]
         target = batch["clean"]
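For reviewer context, training a concrete subclass of this base `Model` would look roughly like the sketch below. The `Demucs` subclass, the import paths, and the `EnhancerDataset` constructor arguments are assumptions for illustration, not the project's confirmed API:

```python
import pytorch_lightning as pl

from enhancer.data.dataset import EnhancerDataset  # assumed import path
from enhancer.models.demucs import Demucs  # hypothetical Model subclass

# Hypothetical dataset arguments; the base class wires `dataset` into
# train_dataloader()/val_dataloader() through the property above.
dataset = EnhancerDataset(name="vctk", root_dir="data/vctk", duration=1.0)
model = Demucs(sampling_rate=16000, dataset=dataset, loss=["mae", "mse"])

# Lightning drives configure_optimizers(), training_step(), validation_step().
trainer = pl.Trainer(max_epochs=10)
trainer.fit(model)
```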
@@ -119,48 +148,92 @@
         metric_val = self.metric(prediction, target)
         loss_val = self.loss(prediction, target)
 
-        self.log("val_metric",metric_val.item())
-        self.log("val_loss",loss_val.item())
+        self.log("val_metric", metric_val.item())
+        self.log("val_loss", loss_val.item())
 
         if self.logger:
-            self.logger.experiment.log_metric(run_id=self.logger.run_id,
-                                    key="val_loss",value=loss_val.item(),
-                                    step=self.global_step)
-            self.logger.experiment.log_metric(run_id=self.logger.run_id,
-                                    key="val_metric",value=metric_val.item(),
-                                    step=self.global_step)
+            self.logger.experiment.log_metric(
+                run_id=self.logger.run_id,
+                key="val_loss",
+                value=loss_val.item(),
+                step=self.global_step,
+            )
+            self.logger.experiment.log_metric(
+                run_id=self.logger.run_id,
+                key="val_metric",
+                value=metric_val.item(),
+                step=self.global_step,
+            )
 
-        return {"loss":loss_val}
+        return {"loss": loss_val}
 
     def on_save_checkpoint(self, checkpoint):
 
         checkpoint["enhancer"] = {
-            "version": {
-                "enhancer":__version__,
-                "pytorch":torch.__version__
+            "version": {"enhancer": __version__, "pytorch": torch.__version__},
+            "architecture": {
+                "module": self.__class__.__module__,
+                "class": self.__class__.__name__,
             },
-            "architecture":{
-                "module":self.__class__.__module__,
-                "class":self.__class__.__name__
-            }
-
         }
 
     def on_load_checkpoint(self, checkpoint: Dict[str, Any]):
         pass
 
-
     @classmethod
     def from_pretrained(
         cls,
         checkpoint: Union[Path, Text],
-        map_location = None,
+        map_location=None,
         hparams_file: Union[Path, Text] = None,
         strict: bool = True,
         use_auth_token: Union[Text, None] = None,
-        cached_dir: Union[Path, Text]=CACHE_DIR,
-        **kwargs
+        cached_dir: Union[Path, Text] = CACHE_DIR,
+        **kwargs,
     ):
+        """
+        Load a pretrained model.
+
+        parameters:
+        checkpoint : Path or str
+            Path to a checkpoint, or a remote URL, or a model identifier from
+            the huggingface.co model hub.
+        map_location : optional
+            Same role as in torch.load().
+            Defaults to the CPU device when not provided.
+        hparams_file : Path or str, optional
+            Path to a .yaml file with hierarchical structure as in this example:
+                drop_prob: 0.2
+                dataloader:
+                    batch_size: 32
+            You most likely won't need this since Lightning will always save the
+            hyperparameters to the checkpoint. However, if your checkpoint weights
+            do not have the hyperparameters saved, use this method to pass in a .yaml
+            file with the hparams you would like to use. These will be converted
+            into a dict and passed into your Model for use.
+        strict : bool, optional
+            Whether to strictly enforce that the keys in the checkpoint match
+            the keys returned by this module's state dict. Defaults to True.
+        use_auth_token : str, optional
+            When loading a private huggingface.co model, set `use_auth_token`
+            to True or to a string containing your huggingface.co authentication
+            token, which can be obtained by running `huggingface-cli login`.
+        cached_dir : Path or str, optional
+            Path to the model cache directory. Defaults to the module-level
+            CACHE_DIR constant when not provided.
+        kwargs : optional
+            Any extra keyword args needed to init the model.
+            Can also be used to override saved hyperparameter values.
+
+        Returns
+        -------
+        model : Model
+            The loaded model.
+
+        See also
+        --------
+        torch.load
+        """
         checkpoint = str(checkpoint)
 
         if hparams_file is not None:
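A usage sketch of the loading paths handled below; the hub identifier is a made-up placeholder, not a real model:

```python
import torch

# From a local Lightning checkpoint on disk.
model = Model.from_pretrained(
    "checkpoints/last.ckpt", map_location=torch.device("cpu")
)

# From the huggingface.co hub; "some-user/some-enhancer" is a placeholder
# identifier, and "@main" pins an optional revision.
model = Model.from_pretrained("some-user/some-enhancer@main")
```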
@@ -168,104 +241,135 @@
         if os.path.isfile(checkpoint):
             model_path_pl = checkpoint
-        elif urlparse(checkpoint).scheme in ("http","https"):
+        elif urlparse(checkpoint).scheme in ("http", "https"):
             model_path_pl = checkpoint
         else:
-
+
             if "@" in checkpoint:
                 model_id = checkpoint.split("@")[0]
                 revision_id = checkpoint.split("@")[1]
             else:
                 model_id = checkpoint
                 revision_id = None
-
+
             url = hf_hub_url(
-                model_id,filename=HF_TORCH_WEIGHTS,revision=revision_id
+                model_id, filename=HF_TORCH_WEIGHTS, revision=revision_id
             )
             model_path_pl = cached_download(
-                url=url,library_name="enhancer",library_version=__version__,
-                cache_dir=cached_dir,use_auth_token=use_auth_token
+                url=url,
+                library_name="enhancer",
+                library_version=__version__,
+                cache_dir=cached_dir,
+                use_auth_token=use_auth_token,
             )
 
         if map_location is None:
             map_location = torch.device(DEFAULT_DEVICE)
 
-        loaded_checkpoint = pl_load(model_path_pl,map_location)
+        loaded_checkpoint = pl_load(model_path_pl, map_location)
         module_name = loaded_checkpoint["enhancer"]["architecture"]["module"]
-        class_name = loaded_checkpoint["enhancer"]["architecture"]["class"] 
+        class_name = loaded_checkpoint["enhancer"]["architecture"]["class"]
         module = import_module(module_name)
         Klass = getattr(module, class_name)
 
         try:
             model = Klass.load_from_checkpoint(
-                checkpoint_path = model_path_pl,
-                map_location = map_location,
-                hparams_file = hparams_file,
-                strict = strict,
-                **kwargs
+                checkpoint_path=model_path_pl,
+                map_location=map_location,
+                hparams_file=hparams_file,
+                strict=strict,
+                **kwargs,
             )
         except Exception as e:
             print(e)
+            # re-raise instead of falling through to an unbound `model`
+            raise
+        return model
 
-        return model
+    def infer(self, batch: torch.Tensor, batch_size: int = 32):
+        """
+        Perform batched model inference.
+        parameters:
+        batch : torch.Tensor
+            input tensor of shape (batch, channels, samples)
+        batch_size : int, defaults to 32
+            batch size used for inference
+        """
 
-    def infer(self,batch:torch.Tensor,batch_size:int=32):
-        
-        assert batch.ndim == 3, f"Expected batch with 3 dimensions (batch,channels,samples) got only {batch.ndim}"
+        assert (
+            batch.ndim == 3
+        ), f"Expected batch with 3 dimensions (batch, channels, samples), got {batch.ndim}"
         batch_predictions = []
         self.eval().to(self.device)
         with torch.no_grad():
-            for batch_id in range(0,batch.shape[0],batch_size):
-                batch_data = batch[batch_id:batch_id+batch_size,:,:].to(self.device)
+            for batch_id in range(0, batch.shape[0], batch_size):
+                batch_data = batch[batch_id : batch_id + batch_size, :, :].to(
+                    self.device
+                )
                 prediction = self(batch_data)
                 batch_predictions.append(prediction)
-
+
         return torch.vstack(batch_predictions)
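Continuing from the loading example above, a minimal sketch of calling `infer` directly, assuming the concrete model maps a (batch, channels, samples) tensor to an output of the same shape:

```python
import torch

# Four mono clips of one second at 16 kHz: (batch=4, channels=1, samples=16000).
noisy = torch.randn(4, 1, 16000)

# Runs in two mini-batches of 2 with gradients disabled; assuming the model
# preserves the input shape, the stacked output is also (4, 1, 16000).
enhanced = model.infer(noisy, batch_size=2)
```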
     def enhance(
         self,
-        audio:Union[Path,np.ndarray,torch.Tensor],
-        sampling_rate:Optional[int]=None,
-        batch_size:int=32,
-        save_output:bool=False,
-        duration:Optional[int]=None,
-        step_size:Optional[int]=None,):
+        audio: Union[Path, np.ndarray, torch.Tensor],
+        sampling_rate: Optional[int] = None,
+        batch_size: int = 32,
+        save_output: bool = False,
+        duration: Optional[float] = None,
+        step_size: Optional[int] = None,
+    ):
+        """
+        Enhance audio using the loaded pretrained model.
+
+        parameters:
+        audio : Path or np.ndarray or torch.Tensor
+            single input audio
+        sampling_rate : int, optional
+            sampling rate of the input; may be omitted when the input is a file path
+        batch_size : int, defaults to 32
+            the input audio is split into multiple chunks; inference runs on
+            batches of these chunks of the given size.
+        save_output : bool, defaults to False
+            whether to save the output to a file
+        duration : float, optional
+            chunk duration in seconds, defaults to the duration of the loaded model.
+        step_size : int, optional
+            step size between consecutive chunks, defaults to 50% of the duration
+        """
 
         model_sampling_rate = self.hparams["sampling_rate"]
         if duration is None:
             duration = self.hparams["duration"]
-        waveform = Inference.read_input(audio,sampling_rate,model_sampling_rate)
-        waveform.to(self.device)
+        waveform = Inference.read_input(
+            audio, sampling_rate, model_sampling_rate
+        )
+        waveform = waveform.to(self.device)
         window_size = round(duration * model_sampling_rate)
-        batched_waveform = Inference.batchify(waveform,window_size,step_size=step_size)
-        batch_prediction = self.infer(batched_waveform,batch_size=batch_size)
-        waveform = Inference.aggreagate(batch_prediction,window_size,waveform.shape[-1],step_size,)
-
-        if save_output and isinstance(audio,(str,Path)):
-            Inference.write_output(waveform,audio,model_sampling_rate)
+        batched_waveform = Inference.batchify(
+            waveform, window_size, step_size=step_size
+        )
+        batch_prediction = self.infer(batched_waveform, batch_size=batch_size)
+        waveform = Inference.aggreagate(
+            batch_prediction,
+            window_size,
+            waveform.shape[-1],
+            step_size,
+        )
+
+        if save_output and isinstance(audio, (str, Path)):
+            Inference.write_output(waveform, audio, model_sampling_rate)
         else:
-            waveform = Inference.prepare_output(waveform, model_sampling_rate,
-                                        audio, sampling_rate)
+            waveform = Inference.prepare_output(
+                waveform, model_sampling_rate, audio, sampling_rate
+            )
 
         return waveform
 
+    @property
     def valid_monitor(self):
-        return "max" if self.loss.higher_better else "min"
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    
\ No newline at end of file
+        return "max" if self.loss.higher_better else "min"
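End-to-end usage of `enhance`; the file name and the mono array shape are illustrative assumptions about what `Inference.read_input` accepts:

```python
import numpy as np

# Enhance a file on disk; with save_output=True and a path input, the result
# is written out via Inference.write_output.
model.enhance("samples/noisy.wav", batch_size=16, save_output=True)

# Enhance an in-memory mono signal; sampling_rate must describe the input.
noisy = np.random.randn(1, 16000).astype("float32")
clean = model.enhance(noisy, sampling_rate=16000)
```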