try:
    from functools import cached_property
except ImportError:
    from backports.cached_property import cached_property

import logging
import os
from importlib import import_module
from pathlib import Path
from typing import Any, Dict, List, Optional, Text, Union
from urllib.parse import urlparse

import numpy as np
import pytorch_lightning as pl
import torch
from huggingface_hub import cached_download, hf_hub_url
from pytorch_lightning.utilities.cloud_io import load as pl_load
from torch.nn.functional import pad
from torch.optim import Adam

from enhancer import __version__
from enhancer.data.dataset import EnhancerDataset
from enhancer.inference import Inference
from enhancer.loss import Avergeloss
from enhancer.utils.io import Audio

CACHE_DIR = ""
HF_TORCH_WEIGHTS = ""
DEFAULT_DEVICE = "cpu"


class Model(pl.LightningModule):
    """
    Base class for all enhancer models.

    parameters:
        num_channels : int, defaults to 1
            Number of channels in the input audio.
        sampling_rate : int, defaults to 16000
            Audio sampling rate in Hz.
        lr : float, optional
            Learning rate for model training.
        dataset : EnhancerDataset, optional
            Enhancer dataset used for training/validation.
        duration : float, optional
            Audio chunk duration used for training/inference.
        loss : str or list of str, defaults to "mse"
            Loss function(s) to be used. Available: "mse", "mae", "Si-SDR".
        metric : str or list of str, defaults to "mse"
            Metric(s) to be reported during validation.
    """

    def __init__(
        self,
        num_channels: int = 1,
        sampling_rate: int = 16000,
        lr: float = 1e-3,
        dataset: Optional[EnhancerDataset] = None,
        duration: Optional[float] = None,
        loss: Union[str, List] = "mse",
        metric: Union[str, List] = "mse",
    ):
        super().__init__()
        assert (
            num_channels == 1
        ), "Enhancer only supports mono (single channel) models"
        self.dataset = dataset
        self.save_hyperparameters(
            "num_channels", "sampling_rate", "lr", "loss", "metric", "duration"
        )
        if self.logger:
            self.logger.experiment.log_dict(
                dict(self.hparams), "hyperparameters.json"
            )

        self.loss = loss
        self.metric = metric

    @property
    def loss(self):
        return self._loss

    @loss.setter
    def loss(self, loss):
        # Accept either a single loss name or a list of names.
        # (Reassigning `loss` itself avoids a NameError when a list is passed.)
        if isinstance(loss, str):
            loss = [loss]

        self._loss = Avergeloss(loss)

    @property
    def metric(self):
        return self._metric

    @metric.setter
    def metric(self, metric):
        if isinstance(metric, str):
            metric = [metric]

        self._metric = Avergeloss(metric)

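    # Usage sketch: both setters accept a single name or a list of names,
    # which Avergeloss combines into one averaged objective (names as listed
    # in the class docstring):
    #
    #     model.loss = ["mse", "mae"]
    #     model.metric = "Si-SDR"
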
    @property
    def dataset(self):
        return self._dataset

    @dataset.setter
    def dataset(self, dataset):
        self._dataset = dataset

    def setup(self, stage: Optional[str] = None):
        if stage == "fit":
            self.dataset.setup(stage)
            self.dataset.model = self

    def train_dataloader(self):
        return self.dataset.train_dataloader()

    def val_dataloader(self):
        return self.dataset.val_dataloader()

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=self.hparams.lr)

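    # End-to-end training sketch (hypothetical arguments; `Demucs` stands in
    # for any concrete subclass of `Model` that implements forward()):
    #
    #     dataset = EnhancerDataset(...)   # matched clean/noisy pairs
    #     model = Demucs(dataset=dataset, loss=["mse", "mae"])
    #     pl.Trainer(max_epochs=10).fit(model)
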
    def training_step(self, batch, batch_idx: int):
        mixed_waveform = batch["noisy"]
        target = batch["clean"]
        prediction = self(mixed_waveform)

        loss = self.loss(prediction, target)

        if self.logger:
            self.logger.experiment.log_metric(
                run_id=self.logger.run_id,
                key="train_loss",
                value=loss.item(),
                step=self.global_step,
            )
        self.log("train_loss", loss.item())
        return {"loss": loss}

    def validation_step(self, batch, batch_idx: int):
        mixed_waveform = batch["noisy"]
        target = batch["clean"]
        prediction = self(mixed_waveform)

        metric_val = self.metric(prediction, target)
        loss_val = self.loss(prediction, target)
        self.log("val_metric", metric_val.item())
        self.log("val_loss", loss_val.item())

        if self.logger:
            self.logger.experiment.log_metric(
                run_id=self.logger.run_id,
                key="val_loss",
                value=loss_val.item(),
                step=self.global_step,
            )
            self.logger.experiment.log_metric(
                run_id=self.logger.run_id,
                key="val_metric",
                value=metric_val.item(),
                step=self.global_step,
            )

        return {"loss": loss_val}

    def on_save_checkpoint(self, checkpoint):
        # Record library versions and the concrete architecture so that
        # `from_pretrained` can re-instantiate the right subclass later.
        checkpoint["enhancer"] = {
            "version": {"enhancer": __version__, "pytorch": torch.__version__},
            "architecture": {
                "module": self.__class__.__module__,
                "class": self.__class__.__name__,
            },
        }

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]):
        pass

    @classmethod
    def from_pretrained(
        cls,
        checkpoint: Union[Path, Text],
        map_location=None,
        hparams_file: Optional[Union[Path, Text]] = None,
        strict: bool = True,
        use_auth_token: Union[Text, None] = None,
        cached_dir: Union[Path, Text] = CACHE_DIR,
        **kwargs,
    ):
        """
        Load a pretrained model.

        parameters:
            checkpoint : Path or str
                Path to a local checkpoint, a remote URL, or a model identifier
                from the huggingface.co model hub.
            map_location : optional
                Same role as in torch.load().
                Defaults to the CPU (`torch.device("cpu")`).
            hparams_file : Path or str, optional
                Path to a .yaml file with hierarchical structure as in this example:
                    drop_prob: 0.2
                    dataloader:
                        batch_size: 32
                You most likely won't need this since Lightning will always save the
                hyperparameters to the checkpoint. However, if your checkpoint weights
                do not have the hyperparameters saved, use this argument to pass in a
                .yaml file with the hparams you would like to use. These will be
                converted into a dict and passed into your Model for use.
            strict : bool, optional
                Whether to strictly enforce that the keys in the checkpoint match
                the keys returned by this module's state dict. Defaults to True.
            use_auth_token : str, optional
                When loading a private huggingface.co model, set `use_auth_token`
                to True or to a string containing your huggingface.co authentication
                token, which can be obtained by running `huggingface-cli login`.
            cached_dir : Path or str, optional
                Path to the model cache directory. Defaults to CACHE_DIR.
            kwargs : optional
                Any extra keyword args needed to init the model.
                Can also be used to override saved hyperparameter values.

        Returns
        -------
        model : Model
            The loaded model.

        See also
        --------
        torch.load
        """

        checkpoint = str(checkpoint)
        if hparams_file is not None:
            hparams_file = str(hparams_file)

        if os.path.isfile(checkpoint):
            model_path_pl = checkpoint
        elif urlparse(checkpoint).scheme in ("http", "https"):
            model_path_pl = checkpoint
        else:
            # Treat the checkpoint as a huggingface.co model id, optionally
            # pinned to a revision with the "model_id@revision" syntax.
            if "@" in checkpoint:
                model_id = checkpoint.split("@")[0]
                revision_id = checkpoint.split("@")[1]
            else:
                model_id = checkpoint
                revision_id = None

            url = hf_hub_url(
                model_id, filename=HF_TORCH_WEIGHTS, revision=revision_id
            )
            model_path_pl = cached_download(
                url=url,
                library_name="enhancer",
                library_version=__version__,
                cache_dir=cached_dir,
                use_auth_token=use_auth_token,
            )

        if map_location is None:
            map_location = torch.device(DEFAULT_DEVICE)

        # Recover the concrete model class recorded by on_save_checkpoint,
        # so subclasses can be restored through the base class.
        loaded_checkpoint = pl_load(model_path_pl, map_location)
        module_name = loaded_checkpoint["enhancer"]["architecture"]["module"]
        class_name = loaded_checkpoint["enhancer"]["architecture"]["class"]
        module = import_module(module_name)
        Klass = getattr(module, class_name)

        try:
            model = Klass.load_from_checkpoint(
                checkpoint_path=model_path_pl,
                map_location=map_location,
                hparams_file=hparams_file,
                strict=strict,
                **kwargs,
            )
        except Exception:
            # Swallowing the exception would leave `model` undefined below,
            # so log it and re-raise instead of printing.
            logging.exception("Failed to load checkpoint %s", model_path_pl)
            raise

        return model

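    # Usage sketch for `from_pretrained` (the paths and model id below are
    # hypothetical placeholders):
    #
    #     model = Model.from_pretrained("path/to/last.ckpt")
    #     model = Model.from_pretrained(
    #         "org/model-id@main", use_auth_token="hf_..."
    #     )
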
    def infer(self, batch: torch.Tensor, batch_size: int = 32):
        """
        Perform batched model inference.

        parameters:
            batch : torch.Tensor
                Input data of shape (batch, channels, samples).
            batch_size : int, default 32
                Batch size used for inference.
        """
        assert (
            batch.ndim == 3
        ), f"Expected batch with 3 dimensions (batch, channels, samples), got {batch.ndim}"
        batch_predictions = []
        # The model is already on self.device; just switch to eval mode.
        self.eval()

        with torch.no_grad():
            for batch_id in range(0, batch.shape[0], batch_size):
                batch_data = batch[batch_id : batch_id + batch_size, :, :].to(
                    self.device
                )
                prediction = self(batch_data)
                batch_predictions.append(prediction)

        return torch.vstack(batch_predictions)

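    # Example (hypothetical shapes): enhance ten one-second mono chunks
    # sampled at 16 kHz; the output keeps the input shape.
    #
    #     chunks = torch.randn(10, 1, 16000)
    #     enhanced = model.infer(chunks, batch_size=4)  # -> (10, 1, 16000)
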
    def enhance(
        self,
        audio: Union[Path, np.ndarray, torch.Tensor],
        sampling_rate: Optional[int] = None,
        batch_size: int = 32,
        save_output: bool = False,
        duration: Optional[float] = None,
        step_size: Optional[int] = None,
    ):
        """
        Enhance audio using the loaded pretrained model.

        parameters:
            audio : Path or np.ndarray or torch.Tensor
                A single input audio.
            sampling_rate : int, optional
                Sampling rate of the input; required when the input is an
                array or tensor rather than a file path.
            batch_size : int, default 32
                The input audio is split into multiple chunks. Inference is
                done on batches of these chunks according to the given batch size.
            save_output : bool, default False
                Whether to save the output to a file.
            duration : float, optional
                Chunk duration in seconds; defaults to the duration of the
                loaded pretrained model.
            step_size : int, optional
                Step size between consecutive chunks; defaults to 50% of duration.
        """
        model_sampling_rate = self.hparams["sampling_rate"]
        if duration is None:
            duration = self.hparams["duration"]
        waveform = Inference.read_input(
            audio, sampling_rate, model_sampling_rate
        )
        # Tensor.to() is not in-place; keep the returned tensor.
        waveform = waveform.to(self.device)
        window_size = round(duration * model_sampling_rate)
        batched_waveform = Inference.batchify(
            waveform, window_size, step_size=step_size
        )
        batch_prediction = self.infer(batched_waveform, batch_size=batch_size)
        waveform = Inference.aggreagate(
            batch_prediction,
            window_size,
            waveform.shape[-1],
            step_size,
        )

        if save_output and isinstance(audio, (str, Path)):
            Inference.write_output(waveform, audio, model_sampling_rate)

        else:
            waveform = Inference.prepare_output(
                waveform, model_sampling_rate, audio, sampling_rate
            )
            return waveform

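    # Usage sketch ("noisy.wav" is a hypothetical input file):
    #
    #     model.enhance("noisy.wav", save_output=True)  # writes the enhanced file
    #     clean = model.enhance("noisy.wav")            # returns the enhanced signal
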
    @property
    def valid_monitor(self):
        # Direction in which the monitored validation quantity improves,
        # for checkpointing and early-stopping callbacks.
        return "max" if self.loss.higher_better else "min"
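
# Example: wiring `valid_monitor` into a checkpoint callback (a sketch using
# the standard pytorch_lightning ModelCheckpoint API):
#
#     from pytorch_lightning.callbacks import ModelCheckpoint
#
#     checkpoint_cb = ModelCheckpoint(monitor="val_loss", mode=model.valid_monitor)
#     trainer = pl.Trainer(callbacks=[checkpoint_cb])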