try:
    from functools import cached_property
except ImportError:
    from backports.cached_property import cached_property

from importlib import import_module
from huggingface_hub import cached_download, hf_hub_url
import logging
import numpy as np
import os
from typing import Optional, Union, List, Text, Dict, Any
from torch.optim import Adam
import torch
from torch.nn.functional import pad
import pytorch_lightning as pl
from pytorch_lightning.utilities.cloud_io import load as pl_load
from urllib.parse import urlparse
from pathlib import Path

from enhancer import __version__
from enhancer.data.dataset import EnhancerDataset
from enhancer.utils.io import Audio
from enhancer.loss import Avergeloss
from enhancer.inference import Inference

CACHE_DIR = ""
HF_TORCH_WEIGHTS = ""
DEFAULT_DEVICE = "cpu"


class Model(pl.LightningModule):
    """
    Base class for all models

    parameters:
    num_channels : int, default to 1
        number of channels in input audio
    sampling_rate : int, default 16000 (16 kHz)
        audio sampling rate
    lr : float, optional
        learning rate for model training
    dataset : EnhancerDataset, optional
        Enhancer dataset used for training/validation
    duration : float, optional
        duration used for training/inference
    loss : string or List of strings, default to "mse"
        loss functions to be used. Available ("mse", "mae", "Si-SDR")
    metric : string or List of strings, default to "mse"
        metric logged as "val_metric" during validation. Same options as loss.
    """

    def __init__(
        self,
        num_channels: int = 1,
        sampling_rate: int = 16000,
        lr: float = 1e-3,
        dataset: Optional[EnhancerDataset] = None,
        duration: Optional[float] = None,
        loss: Union[str, List] = "mse",
        metric: Union[str, List] = "mse",
    ):
        super().__init__()
        assert (
            num_channels == 1
        ), "Enhancer only supports mono-channel models"
        self.dataset = dataset
        self.save_hyperparameters(
            "num_channels", "sampling_rate", "lr", "loss", "metric", "duration"
        )
        if self.logger:
            # MLflow-style experiment logging (see log_metric calls below)
            self.logger.experiment.log_dict(
                dict(self.hparams), "hyperparameters.json"
            )
        self.loss = loss
        self.metric = metric

    @property
    def loss(self):
        return self._loss

    @loss.setter
    def loss(self, loss):
        # accept a single loss name or a list of names
        if isinstance(loss, str):
            loss = [loss]
        self._loss = Avergeloss(loss)

    @property
    def metric(self):
        return self._metric

    @metric.setter
    def metric(self, metric):
        if isinstance(metric, str):
            metric = [metric]
        self._metric = Avergeloss(metric)

    @property
    def dataset(self):
        return self._dataset

    @dataset.setter
    def dataset(self, dataset):
        self._dataset = dataset

    def setup(self, stage: Optional[str] = None):
        if stage == "fit":
            self.dataset.setup(stage)
            self.dataset.model = self

    def train_dataloader(self):
        return self.dataset.train_dataloader()

    def val_dataloader(self):
        return self.dataset.val_dataloader()

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=self.hparams.lr)

    def training_step(self, batch, batch_idx: int):
        mixed_waveform = batch["noisy"]
        target = batch["clean"]
        prediction = self(mixed_waveform)

        loss = self.loss(prediction, target)
        if self.logger:
            self.logger.experiment.log_metric(
                run_id=self.logger.run_id,
                key="train_loss",
                value=loss.item(),
                step=self.global_step,
            )
        self.log("train_loss", loss.item())
        return {"loss": loss}

    def validation_step(self, batch, batch_idx: int):
        mixed_waveform = batch["noisy"]
        target = batch["clean"]
        prediction = self(mixed_waveform)

        metric_val = self.metric(prediction, target)
        loss_val = self.loss(prediction, target)
        self.log("val_metric", metric_val.item())
        self.log("val_loss", loss_val.item())
        if self.logger:
            self.logger.experiment.log_metric(
                run_id=self.logger.run_id,
                key="val_loss",
                value=loss_val.item(),
                step=self.global_step,
            )
            self.logger.experiment.log_metric(
                run_id=self.logger.run_id,
                key="val_metric",
                value=metric_val.item(),
                step=self.global_step,
            )
        return {"loss": loss_val}
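
    # Training usage sketch. ``MyEnhancer`` is a placeholder for any concrete
    # subclass of ``Model`` (this base class defines no forward pass itself):
    #
    #     dataset = EnhancerDataset(...)      # paired noisy/clean audio
    #     model = MyEnhancer(dataset=dataset, loss=["mse", "mae"])
    #     trainer = pl.Trainer(max_epochs=10)
    #     trainer.fit(model)                  # drives the *_step hooks above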

    def on_save_checkpoint(self, checkpoint):
        # stash enough metadata to re-instantiate the correct subclass later
        checkpoint["enhancer"] = {
            "version": {"enhancer": __version__, "pytorch": torch.__version__},
            "architecture": {
                "module": self.__class__.__module__,
                "class": self.__class__.__name__,
            },
        }

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]):
        pass

    @classmethod
    def from_pretrained(
        cls,
        checkpoint: Union[Path, Text],
        map_location=None,
        hparams_file: Optional[Union[Path, Text]] = None,
        strict: bool = True,
        use_auth_token: Union[Text, None] = None,
        cached_dir: Union[Path, Text] = CACHE_DIR,
        **kwargs,
    ):
        """
        Load pretrained model

        parameters:
        checkpoint : Path or str
            Path to checkpoint, or a remote URL, or a model identifier from
            the huggingface.co model hub.
        map_location : optional
            Same role as in torch.load().
            Defaults to `torch.device(DEFAULT_DEVICE)` when unset.
        hparams_file : Path or str, optional
            Path to a .yaml file with hierarchical structure as in this example:
                drop_prob: 0.2
                dataloader:
                    batch_size: 32
            You most likely won't need this since Lightning will always save
            the hyperparameters to the checkpoint. However, if your checkpoint
            weights do not have the hyperparameters saved, use this method to
            pass in a .yaml file with the hparams you would like to use. These
            will be converted into a dict and passed into your Model for use.
        strict : bool, optional
            Whether to strictly enforce that the keys in checkpoint match the
            keys returned by this module's state dict. Defaults to True.
        use_auth_token : str, optional
            When loading a private huggingface.co model, set `use_auth_token`
            to True or to a string containing your huggingface.co
            authentication token, which can be obtained by running
            `huggingface-cli login`.
        cached_dir : Path or str, optional
            Path to the model cache directory. Defaults to CACHE_DIR.
        kwargs : optional
            Any extra keyword args needed to init the model.
            Can also be used to override saved hyperparameter values.

        Returns
        -------
        model : Model
            Loaded pretrained model

        See also
        --------
        torch.load
        """
        checkpoint = str(checkpoint)
        if hparams_file is not None:
            hparams_file = str(hparams_file)

        # resolve the checkpoint to something loadable: a local path, a direct
        # http(s) URL, or a huggingface.co model id with an optional
        # "@revision" suffix
        if os.path.isfile(checkpoint):
            model_path_pl = checkpoint
        elif urlparse(checkpoint).scheme in ("http", "https"):
            model_path_pl = checkpoint
        else:
            if "@" in checkpoint:
                model_id = checkpoint.split("@")[0]
                revision_id = checkpoint.split("@")[1]
            else:
                model_id = checkpoint
                revision_id = None

            url = hf_hub_url(
                model_id, filename=HF_TORCH_WEIGHTS, revision=revision_id
            )
            model_path_pl = cached_download(
                url=url,
                library_name="enhancer",
                library_version=__version__,
                cache_dir=cached_dir,
                use_auth_token=use_auth_token,
            )

        if map_location is None:
            map_location = torch.device(DEFAULT_DEVICE)

        # recover the concrete architecture recorded by on_save_checkpoint
        loaded_checkpoint = pl_load(model_path_pl, map_location)
        module_name = loaded_checkpoint["enhancer"]["architecture"]["module"]
        class_name = loaded_checkpoint["enhancer"]["architecture"]["class"]
        module = import_module(module_name)
        Klass = getattr(module, class_name)

        try:
            model = Klass.load_from_checkpoint(
                checkpoint_path=model_path_pl,
                map_location=map_location,
                hparams_file=hparams_file,
                strict=strict,
                **kwargs,
            )
        except Exception as e:
            # re-raise instead of returning an undefined ``model``
            logging.error(e)
            raise

        return model
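
    # Loading usage sketch; both identifiers below are placeholders, not
    # checkpoints shipped with enhancer:
    #
    #     model = Model.from_pretrained("/path/to/last.ckpt")
    #     model = Model.from_pretrained(
    #         "some-user/some-model@main", use_auth_token="hf_..."
    #     )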

    def infer(self, batch: torch.Tensor, batch_size: int = 32):
        """
        Perform model inference

        parameters:
        batch : torch.Tensor
            input data
        batch_size : int, default 32
            batch size for inference
        """
        assert (
            batch.ndim == 3
        ), f"Expected batch with 3 dimensions (batch, channels, samples), got {batch.ndim}"

        batch_predictions = []
        self.eval().to(self.device)
        # forward pass over the chunks, batch_size chunks at a time
        with torch.no_grad():
            for batch_id in range(0, batch.shape[0], batch_size):
                batch_data = batch[batch_id : batch_id + batch_size, :, :].to(
                    self.device
                )
                prediction = self(batch_data)
                batch_predictions.append(prediction)

        return torch.vstack(batch_predictions)

    def enhance(
        self,
        audio: Union[Path, np.ndarray, torch.Tensor],
        sampling_rate: Optional[int] = None,
        batch_size: int = 32,
        save_output: bool = False,
        duration: Optional[float] = None,
        step_size: Optional[int] = None,
    ):
        """
        Enhance audio using the loaded pretrained model.

        parameters:
        audio : Path to audio file, numpy array, or torch tensor
            single input audio
        sampling_rate : int, optional
            sampling rate of the input audio; required when `audio` is a
            numpy array or torch tensor
        batch_size : int, default 32
            input audio is split into multiple chunks. Inference is done on
            batches of these chunks according to the given batch size.
        save_output : bool, default False
            whether to save the output to a file
        duration : float, optional
            chunk duration in seconds, defaults to the duration of the loaded
            pretrained model.
        step_size : int, optional
            step size between consecutive chunks, defaults to 50% of duration
        """
        model_sampling_rate = self.hparams["sampling_rate"]
        if duration is None:
            duration = self.hparams["duration"]
        waveform = Inference.read_input(
            audio, sampling_rate, model_sampling_rate
        )
        waveform = waveform.to(self.device)

        # split the waveform into fixed-size (possibly overlapping) chunks,
        # run batched inference, then stitch the chunks back together
        window_size = round(duration * model_sampling_rate)
        batched_waveform = Inference.batchify(
            waveform, window_size, step_size=step_size
        )
        batch_prediction = self.infer(batched_waveform, batch_size=batch_size)
        waveform = Inference.aggreagate(
            batch_prediction,
            window_size,
            waveform.shape[-1],
            step_size,
        )
        if save_output and isinstance(audio, (str, Path)):
            Inference.write_output(waveform, audio, model_sampling_rate)
        else:
            waveform = Inference.prepare_output(
                waveform, model_sampling_rate, audio, sampling_rate
            )
            return waveform

    @property
    def valid_monitor(self):
        # direction the monitored validation quantity should move for
        # checkpointing/early-stopping callbacks
        return "max" if self.loss.higher_better else "min"
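

# Enhancement usage sketch ("noisy.wav" and the model id are placeholders):
#
#     model = Model.from_pretrained("some-user/some-model")
#     model.enhance("noisy.wav", save_output=True)   # writes enhanced audio
#     denoised = model.enhance("noisy.wav")          # or get the waveform back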