add enhance function

shahules786 2022-09-19 22:35:21 +05:30
parent 872303642f
commit 9ef6665b84
1 changed file with 64 additions and 2 deletions

@ -1,9 +1,15 @@
from importlib import import_module
from pathlib import Path
from huggingface_hub import cached_download, hf_hub_url
import numpy as np
import os
from typing import Optional, Union, List, Text
from typing import Optional, Union, List, Text, Dict, Any
from torch.optim import Adam
import torch
from torch.nn.functional import pad
import pytorch_lightning as pl
from pytorch_lightning.utilities.cloud_io import load as pl_load
from urllib.parse import urlparse
@ -11,7 +17,9 @@ from urllib.parse import urlparse
from enhancer import __version__
from enhancer.data.dataset import Dataset
from enhancer.utils.io import Audio
from enhancer.utils.loss import Avergeloss
from enhancer.inference import Inference
CACHE_DIR = ""
HF_TORCH_WEIGHTS = ""
@ -30,8 +38,8 @@ class Model(pl.LightningModule):
    ):
        super().__init__()
        assert num_channels == 1, "Enhancer only supports mono-channel models"
        self.save_hyperparameters("num_channels","sampling_rate","lr","loss","metric")
        self.dataset = dataset
        self.save_hyperparameters("num_channels","sampling_rate","lr","loss","metric")

    @property
@ -40,6 +48,8 @@ class Model(pl.LightningModule):
    @dataset.setter
    def dataset(self, dataset):
        if dataset is not None:
            self.save_hyperparameters({"duration": dataset.duration})
        self._dataset = dataset

    def setup(self, stage: Optional[str] = None):
@ -99,6 +109,10 @@ class Model(pl.LightningModule):
        }

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]):
        pass

    @classmethod
    def from_pretrained(
        cls,
@ -157,7 +171,55 @@ class Model(pl.LightningModule):
            print(e)

        return model
    def infer_batch(self, batch, batch_size):
        """Run the model on a batched tensor in chunks of batch_size."""
        assert batch.ndim == 3, f"Expected batch with 3 dimensions (batch, channels, samples), got {batch.ndim}"
        batch_predictions = []
        self.eval().to(self.device)
        # iterate over the batch dimension in steps of batch_size
        for batch_id in range(0, batch.shape[0], batch_size):
            batch_data = batch[batch_id : batch_id + batch_size, :, :].to(self.device)
            prediction = self(batch_data)
            batch_predictions.append(prediction)

        return torch.vstack(batch_predictions)
    def enhance(
        self,
        audio: Union[Path, np.ndarray, torch.Tensor],
        sampling_rate: Optional[int] = None,
        batch_size: int = 32,
        save_output: bool = False,
        duration: Optional[int] = None,
        step_size: Optional[int] = None,
    ):
        model_sampling_rate = self.hparams.sampling_rate
        if duration is None:
            duration = self.hparams.duration
        waveform = Inference.read_input(audio, sampling_rate, model_sampling_rate)
        waveform = waveform.to(self.device)
        # split the waveform into fixed-size windows for batched inference
        window_size = round(duration * model_sampling_rate)
        batched_waveform = Inference.batchify(waveform, window_size, step_size=step_size)
        batch_prediction = self.infer_batch(batched_waveform, batch_size=batch_size)
        # stitch the per-window predictions back into a single waveform
        waveform = Inference.aggreagate(batch_prediction, window_size, step_size)
        if save_output and isinstance(audio, (str, Path)):
            Inference.write_output(waveform, audio, model_sampling_rate)
        else:
            return waveform
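
For reference, a minimal usage sketch of the new enhance method (not part of this commit). The import path, checkpoint identifier, and file name below are placeholder assumptions, and from_pretrained is assumed to accept a checkpoint path or hub identifier as its first argument:

# hypothetical usage sketch -- names below are placeholders, not real artifacts
from enhancer.models.model import Model  # assumed module path for the Model class

model = Model.from_pretrained("path/to/checkpoint.ckpt")  # assumed: takes a checkpoint path or hub id

# enhance a noisy file and get the denoised waveform back as a tensor
clean_waveform = model.enhance("noisy.wav", batch_size=32)

# or write the enhanced audio to disk instead of returning it
model.enhance("noisy.wav", save_output=True)

Note that save_output=True only writes a file when the input itself was a path; ndarray and tensor inputs always return the enhanced waveform.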