269 lines
7.8 KiB
Python
269 lines
7.8 KiB
Python
try:
|
|
from functools import cached_property
|
|
except ImportError:
|
|
from backports.cached_property import cached_property
|
|
from importlib import import_module
|
|
from huggingface_hub import cached_download, hf_hub_url
|
|
import logging
|
|
import numpy as np
|
|
import os
|
|
from typing import Optional, Union, List, Text, Dict, Any
|
|
from torch.optim import Adam
|
|
import torch
|
|
from torch.nn.functional import pad
|
|
import pytorch_lightning as pl
|
|
from pytorch_lightning.utilities.cloud_io import load as pl_load
|
|
from urllib.parse import urlparse
|
|
from pathlib import Path
|
|
|
|
|
|
from enhancer import __version__
|
|
from enhancer.data.dataset import EnhancerDataset
|
|
from enhancer.utils.io import Audio
|
|
from enhancer.loss import Avergeloss
|
|
from enhancer.inference import Inference
|
|
|
|
# Local directory used by huggingface_hub's cached_download; empty string means
# the hub's default cache location. TODO confirm intended value.
CACHE_DIR = ""
# Filename of the torch weights inside a HF model repo (passed to hf_hub_url).
# NOTE(review): currently empty — presumably filled in by packaging; verify.
HF_TORCH_WEIGHTS = ""
# Device used for map_location when none is given to from_pretrained.
DEFAULT_DEVICE = "cpu"
|
|
|
|
class Model(pl.LightningModule):
    """Base LightningModule for speech-enhancement models.

    Handles hyperparameter bookkeeping, loss/metric wiring, training and
    validation steps, checkpoint metadata, loading pretrained weights from
    a local path / URL / Hugging Face Hub, and batched inference.

    Args:
        num_channels: number of audio channels; only mono (1) is supported.
        sampling_rate: audio sampling rate the model operates at.
        lr: learning rate for the Adam optimizer.
        dataset: optional EnhancerDataset providing train/val dataloaders.
        duration: chunk duration (seconds) used to window audio at inference.
        loss: loss name or list of loss names, combined via Avergeloss.
        metric: metric name or list of metric names, combined via Avergeloss.
    """

    def __init__(
        self,
        num_channels: int = 1,
        sampling_rate: int = 16000,
        lr: float = 1e-3,
        dataset: Optional[EnhancerDataset] = None,
        duration: Optional[float] = None,
        loss: Union[str, List] = "mse",
        metric: Union[str, List] = "mse",
    ):
        super().__init__()
        assert num_channels == 1, "Enhancer only support for mono channel models"
        self.dataset = dataset
        self.save_hyperparameters(
            "num_channels", "sampling_rate", "lr", "loss", "metric", "duration"
        )
        if self.logger:
            self.logger.experiment.log_dict(
                dict(self.hparams), "hyperparameters.json"
            )

        self.loss = loss
        self.metric = metric

    @property
    def loss(self):
        return self._loss

    @loss.setter
    def loss(self, loss):
        # Accept a single loss name or a list of names; always hand a list
        # to Avergeloss.  (Previously a list argument left the local unbound
        # and raised NameError.)
        if isinstance(loss, str):
            loss = [loss]
        self._loss = Avergeloss(loss)

    @property
    def metric(self):
        return self._metric

    @metric.setter
    def metric(self, metric):
        # Same normalization as the loss setter.
        if isinstance(metric, str):
            metric = [metric]
        self._metric = Avergeloss(metric)

    @property
    def dataset(self):
        return self._dataset

    @dataset.setter
    def dataset(self, dataset):
        self._dataset = dataset

    def setup(self, stage: Optional[str] = None):
        # Defer dataset preparation until training actually starts; the
        # dataset needs a back-reference to the model (e.g. for windowing).
        if stage == "fit":
            self.dataset.setup(stage)
            self.dataset.model = self

    def train_dataloader(self):
        return self.dataset.train_dataloader()

    def val_dataloader(self):
        return self.dataset.val_dataloader()

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=self.hparams.lr)

    def training_step(self, batch, batch_idx: int):
        """Compute and log the training loss for one batch."""
        mixed_waveform = batch["noisy"]
        target = batch["clean"]
        prediction = self(mixed_waveform)

        loss = self.loss(prediction, target)

        # The experiment logger here is MLflow-style (log_metric with run_id);
        # self.log feeds Lightning's own progress bar / callbacks.
        if self.logger:
            self.logger.experiment.log_metric(
                run_id=self.logger.run_id,
                key="train_loss",
                value=loss.item(),
                step=self.global_step,
            )
        self.log("train_loss", loss.item())
        return {"loss": loss}

    def validation_step(self, batch, batch_idx: int):
        """Compute and log validation loss and metric for one batch."""
        mixed_waveform = batch["noisy"]
        target = batch["clean"]
        prediction = self(mixed_waveform)

        metric_val = self.metric(prediction, target)
        loss_val = self.loss(prediction, target)
        self.log("val_metric", metric_val.item())
        self.log("val_loss", loss_val.item())

        if self.logger:
            self.logger.experiment.log_metric(
                run_id=self.logger.run_id,
                key="val_loss",
                value=loss_val.item(),
                step=self.global_step,
            )
            self.logger.experiment.log_metric(
                run_id=self.logger.run_id,
                key="val_metric",
                value=metric_val.item(),
                step=self.global_step,
            )

        return {"loss": loss_val}

    def on_save_checkpoint(self, checkpoint):
        # Stamp the checkpoint with version and architecture info so that
        # from_pretrained can re-instantiate the right class later.
        checkpoint["enhancer"] = {
            "version": {
                "enhancer": __version__,
                "pytorch": torch.__version__,
            },
            "architecture": {
                "module": self.__class__.__module__,
                "class": self.__class__.__name__,
            },
        }

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]):
        pass

    @classmethod
    def from_pretrained(
        cls,
        checkpoint: Union[Path, Text],
        map_location=None,
        hparams_file: Union[Path, Text] = None,
        strict: bool = True,
        use_auth_token: Union[Text, None] = None,
        cached_dir: Union[Path, Text] = CACHE_DIR,
        **kwargs,
    ):
        """Load a pretrained model from a local path, URL or HF Hub id.

        Args:
            checkpoint: filesystem path, http(s) URL, or "model_id[@revision]"
                pointing at a Hugging Face Hub repository.
            map_location: torch device mapping; defaults to CPU.
            hparams_file: optional hparams yaml passed to Lightning.
            strict: strict state-dict loading.
            use_auth_token: HF Hub auth token for private repos.
            cached_dir: local cache directory for hub downloads.
            **kwargs: forwarded to ``load_from_checkpoint``.

        Returns:
            An instance of the concrete Model subclass recorded in the
            checkpoint's "enhancer" metadata.
        """
        checkpoint = str(checkpoint)
        if hparams_file is not None:
            hparams_file = str(hparams_file)

        if os.path.isfile(checkpoint):
            model_path_pl = checkpoint
        elif urlparse(checkpoint).scheme in ("http", "https"):
            model_path_pl = checkpoint
        else:
            # Hub id, optionally suffixed with "@<revision>".
            if "@" in checkpoint:
                model_id, revision_id = checkpoint.split("@", 1)
            else:
                model_id, revision_id = checkpoint, None

            url = hf_hub_url(
                model_id, filename=HF_TORCH_WEIGHTS, revision=revision_id
            )
            model_path_pl = cached_download(
                url=url,
                library_name="enhancer",
                library_version=__version__,
                cache_dir=cached_dir,
                use_auth_token=use_auth_token,
            )

        if map_location is None:
            map_location = torch.device(DEFAULT_DEVICE)

        loaded_checkpoint = pl_load(model_path_pl, map_location)
        # Architecture metadata is stored under the "enhancer" key by
        # on_save_checkpoint (the old code read checkpoint["architecture"]
        # directly, which always raised KeyError).
        architecture = loaded_checkpoint["enhancer"]["architecture"]
        module = import_module(architecture["module"])
        Klass = getattr(module, architecture["class"])

        # Let load failures propagate: the previous try/except printed the
        # error and then raised UnboundLocalError on "return model".
        model = Klass.load_from_checkpoint(
            checkpoint_path=model_path_pl,
            map_location=map_location,
            hparams_file=hparams_file,
            strict=strict,
            **kwargs,
        )
        return model

    def infer(self, batch: torch.Tensor, batch_size: int = 32):
        """Run the model over ``batch`` in chunks of ``batch_size``.

        Args:
            batch: tensor of shape (num_chunks, channels, samples).
            batch_size: number of chunks per forward pass.

        Returns:
            Stacked predictions with the same leading dimension as ``batch``.
        """
        assert batch.ndim == 3, (
            f"Expected batch with 3 dimensions (batch,channels,samples) "
            f"got only {batch.ndim}"
        )
        batch_predictions = []
        self.eval().to(self.device)

        with torch.no_grad():
            # Step through the batch dimension; the old range(shape[0],
            # batch_size) skipped (or mis-stepped) almost every chunk.
            for batch_id in range(0, batch.shape[0], batch_size):
                batch_data = batch[batch_id : batch_id + batch_size, :, :].to(
                    self.device
                )
                prediction = self(batch_data)
                batch_predictions.append(prediction)

        return torch.vstack(batch_predictions)

    def enhance(
        self,
        audio: Union[Path, np.ndarray, torch.Tensor],
        sampling_rate: Optional[int] = None,
        batch_size: int = 32,
        save_output: bool = False,
        duration: Optional[int] = None,
        step_size: Optional[int] = None,
    ):
        """Enhance an audio file / array / tensor end-to-end.

        The input is resampled to the model's sampling rate, windowed into
        fixed-size chunks, enhanced in batches, and stitched back together.

        Args:
            audio: path to an audio file, numpy array, or torch tensor.
            sampling_rate: sampling rate of ``audio`` (required for arrays).
            batch_size: chunks per forward pass.
            save_output: if True and ``audio`` is a path, write the result
                next to the input instead of returning it.
            duration: window duration in seconds; defaults to the value the
                model was trained with.
            step_size: hop between successive windows, in samples.

        Returns:
            The enhanced waveform, unless ``save_output`` wrote it to disk.
        """
        # hparams is attribute-style on LightningModule; the old code called
        # self.model.hprams(...) which failed on three counts (no .model,
        # misspelling, hparams is not callable).
        model_sampling_rate = self.hparams.sampling_rate
        if duration is None:
            duration = self.hparams.duration
        waveform = Inference.read_input(
            audio, sampling_rate, model_sampling_rate
        )
        # Tensor.to is not in-place; the old code discarded its result.
        waveform = waveform.to(self.device)
        window_size = round(duration * model_sampling_rate)
        batched_waveform = Inference.batchify(
            waveform, window_size, step_size=step_size
        )
        batch_prediction = self.infer(batched_waveform, batch_size=batch_size)
        waveform = Inference.aggreagate(
            batch_prediction, window_size, waveform.shape[-1], step_size,
        )

        if save_output and isinstance(audio, (str, Path)):
            Inference.write_output(waveform, audio, model_sampling_rate)
        else:
            return waveform

    @property
    def valid_monitor(self):
        # Direction for checkpoint/early-stopping callbacks to monitor.
        return "max" if self.loss.higher_better else "min"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|