commit
07e6aa6e24
|
|
@ -27,24 +27,48 @@ CACHE_DIR = ""
|
||||||
HF_TORCH_WEIGHTS = ""
|
HF_TORCH_WEIGHTS = ""
|
||||||
DEFAULT_DEVICE = "cpu"
|
DEFAULT_DEVICE = "cpu"
|
||||||
|
|
||||||
|
|
||||||
class Model(pl.LightningModule):
|
class Model(pl.LightningModule):
|
||||||
|
"""
|
||||||
|
Base class for all models
|
||||||
|
parameters:
|
||||||
|
num_channels: int, default to 1
|
||||||
|
number of channels in input audio
|
||||||
|
sampling_rate : int, default 16khz
|
||||||
|
audio sampling rate
|
||||||
|
lr: float, optional
|
||||||
|
learning rate for model training
|
||||||
|
dataset: EnhancerDataset, optional
|
||||||
|
Enhancer dataset used for training/validation
|
||||||
|
duration: float, optional
|
||||||
|
duration used for training/inference
|
||||||
|
loss : string or List of strings, default to "mse"
|
||||||
|
loss functions to be used. Available ("mse","mae","Si-SDR")
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
num_channels:int=1,
|
num_channels: int = 1,
|
||||||
sampling_rate:int=16000,
|
sampling_rate: int = 16000,
|
||||||
lr:float=1e-3,
|
lr: float = 1e-3,
|
||||||
dataset:Optional[EnhancerDataset]=None,
|
dataset: Optional[EnhancerDataset] = None,
|
||||||
duration:Optional[float]=None,
|
duration: Optional[float] = None,
|
||||||
loss: Union[str, List] = "mse",
|
loss: Union[str, List] = "mse",
|
||||||
metric:Union[str,List] = "mse"
|
metric: Union[str, List] = "mse",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert num_channels ==1 , "Enhancer only support for mono channel models"
|
assert (
|
||||||
|
num_channels == 1
|
||||||
|
), "Enhancer only support for mono channel models"
|
||||||
self.dataset = dataset
|
self.dataset = dataset
|
||||||
self.save_hyperparameters("num_channels","sampling_rate","lr","loss","metric","duration")
|
self.save_hyperparameters(
|
||||||
|
"num_channels", "sampling_rate", "lr", "loss", "metric", "duration"
|
||||||
|
)
|
||||||
if self.logger:
|
if self.logger:
|
||||||
self.logger.experiment.log_dict(dict(self.hparams),"hyperparameters.json")
|
self.logger.experiment.log_dict(
|
||||||
|
dict(self.hparams), "hyperparameters.json"
|
||||||
|
)
|
||||||
|
|
||||||
self.loss = loss
|
self.loss = loss
|
||||||
self.metric = metric
|
self.metric = metric
|
||||||
|
|
@ -54,9 +78,9 @@ class Model(pl.LightningModule):
|
||||||
return self._loss
|
return self._loss
|
||||||
|
|
||||||
@loss.setter
|
@loss.setter
|
||||||
def loss(self,loss):
|
def loss(self, loss):
|
||||||
|
|
||||||
if isinstance(loss,str):
|
if isinstance(loss, str):
|
||||||
losses = [loss]
|
losses = [loss]
|
||||||
|
|
||||||
self._loss = Avergeloss(losses)
|
self._loss = Avergeloss(losses)
|
||||||
|
|
@ -66,23 +90,22 @@ class Model(pl.LightningModule):
|
||||||
return self._metric
|
return self._metric
|
||||||
|
|
||||||
@metric.setter
|
@metric.setter
|
||||||
def metric(self,metric):
|
def metric(self, metric):
|
||||||
|
|
||||||
if isinstance(metric,str):
|
if isinstance(metric, str):
|
||||||
metric = [metric]
|
metric = [metric]
|
||||||
|
|
||||||
self._metric = Avergeloss(metric)
|
self._metric = Avergeloss(metric)
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dataset(self):
|
def dataset(self):
|
||||||
return self._dataset
|
return self._dataset
|
||||||
|
|
||||||
@dataset.setter
|
@dataset.setter
|
||||||
def dataset(self,dataset):
|
def dataset(self, dataset):
|
||||||
self._dataset = dataset
|
self._dataset = dataset
|
||||||
|
|
||||||
def setup(self,stage:Optional[str]=None):
|
def setup(self, stage: Optional[str] = None):
|
||||||
if stage == "fit":
|
if stage == "fit":
|
||||||
self.dataset.setup(stage)
|
self.dataset.setup(stage)
|
||||||
self.dataset.model = self
|
self.dataset.model = self
|
||||||
|
|
@ -94,9 +117,9 @@ class Model(pl.LightningModule):
|
||||||
return self.dataset.val_dataloader()
|
return self.dataset.val_dataloader()
|
||||||
|
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
return Adam(self.parameters(), lr = self.hparams.lr)
|
return Adam(self.parameters(), lr=self.hparams.lr)
|
||||||
|
|
||||||
def training_step(self,batch, batch_idx:int):
|
def training_step(self, batch, batch_idx: int):
|
||||||
|
|
||||||
mixed_waveform = batch["noisy"]
|
mixed_waveform = batch["noisy"]
|
||||||
target = batch["clean"]
|
target = batch["clean"]
|
||||||
|
|
@ -105,13 +128,16 @@ class Model(pl.LightningModule):
|
||||||
loss = self.loss(prediction, target)
|
loss = self.loss(prediction, target)
|
||||||
|
|
||||||
if self.logger:
|
if self.logger:
|
||||||
self.logger.experiment.log_metric(run_id=self.logger.run_id,
|
self.logger.experiment.log_metric(
|
||||||
key="train_loss", value=loss.item(),
|
run_id=self.logger.run_id,
|
||||||
step=self.global_step)
|
key="train_loss",
|
||||||
self.log("train_loss",loss.item())
|
value=loss.item(),
|
||||||
return {"loss":loss}
|
step=self.global_step,
|
||||||
|
)
|
||||||
|
self.log("train_loss", loss.item())
|
||||||
|
return {"loss": loss}
|
||||||
|
|
||||||
def validation_step(self,batch,batch_idx:int):
|
def validation_step(self, batch, batch_idx: int):
|
||||||
|
|
||||||
mixed_waveform = batch["noisy"]
|
mixed_waveform = batch["noisy"]
|
||||||
target = batch["clean"]
|
target = batch["clean"]
|
||||||
|
|
@ -119,48 +145,92 @@ class Model(pl.LightningModule):
|
||||||
|
|
||||||
metric_val = self.metric(prediction, target)
|
metric_val = self.metric(prediction, target)
|
||||||
loss_val = self.loss(prediction, target)
|
loss_val = self.loss(prediction, target)
|
||||||
self.log("val_metric",metric_val.item())
|
self.log("val_metric", metric_val.item())
|
||||||
self.log("val_loss",loss_val.item())
|
self.log("val_loss", loss_val.item())
|
||||||
|
|
||||||
if self.logger:
|
if self.logger:
|
||||||
self.logger.experiment.log_metric(run_id=self.logger.run_id,
|
self.logger.experiment.log_metric(
|
||||||
key="val_loss",value=loss_val.item(),
|
run_id=self.logger.run_id,
|
||||||
step=self.global_step)
|
key="val_loss",
|
||||||
self.logger.experiment.log_metric(run_id=self.logger.run_id,
|
value=loss_val.item(),
|
||||||
key="val_metric",value=metric_val.item(),
|
step=self.global_step,
|
||||||
step=self.global_step)
|
)
|
||||||
|
self.logger.experiment.log_metric(
|
||||||
|
run_id=self.logger.run_id,
|
||||||
|
key="val_metric",
|
||||||
|
value=metric_val.item(),
|
||||||
|
step=self.global_step,
|
||||||
|
)
|
||||||
|
|
||||||
return {"loss":loss_val}
|
return {"loss": loss_val}
|
||||||
|
|
||||||
def on_save_checkpoint(self, checkpoint):
|
def on_save_checkpoint(self, checkpoint):
|
||||||
|
|
||||||
checkpoint["enhancer"] = {
|
checkpoint["enhancer"] = {
|
||||||
"version": {
|
"version": {"enhancer": __version__, "pytorch": torch.__version__},
|
||||||
"enhancer":__version__,
|
"architecture": {
|
||||||
"pytorch":torch.__version__
|
"module": self.__class__.__module__,
|
||||||
|
"class": self.__class__.__name__,
|
||||||
},
|
},
|
||||||
"architecture":{
|
|
||||||
"module":self.__class__.__module__,
|
|
||||||
"class":self.__class__.__name__
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def on_load_checkpoint(self, checkpoint: Dict[str, Any]):
|
def on_load_checkpoint(self, checkpoint: Dict[str, Any]):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(
|
def from_pretrained(
|
||||||
cls,
|
cls,
|
||||||
checkpoint: Union[Path, Text],
|
checkpoint: Union[Path, Text],
|
||||||
map_location = None,
|
map_location=None,
|
||||||
hparams_file: Union[Path, Text] = None,
|
hparams_file: Union[Path, Text] = None,
|
||||||
strict: bool = True,
|
strict: bool = True,
|
||||||
use_auth_token: Union[Text, None] = None,
|
use_auth_token: Union[Text, None] = None,
|
||||||
cached_dir: Union[Path, Text]=CACHE_DIR,
|
cached_dir: Union[Path, Text] = CACHE_DIR,
|
||||||
**kwargs
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Load Pretrained model
|
||||||
|
|
||||||
|
parameters:
|
||||||
|
checkpoint : Path or str
|
||||||
|
Path to checkpoint, or a remote URL, or a model identifier from
|
||||||
|
the huggingface.co model hub.
|
||||||
|
map_location: optional
|
||||||
|
Same role as in torch.load().
|
||||||
|
Defaults to `lambda storage, loc: storage`.
|
||||||
|
hparams_file : Path or str, optional
|
||||||
|
Path to a .yaml file with hierarchical structure as in this example:
|
||||||
|
drop_prob: 0.2
|
||||||
|
dataloader:
|
||||||
|
batch_size: 32
|
||||||
|
You most likely won’t need this since Lightning will always save the
|
||||||
|
hyperparameters to the checkpoint. However, if your checkpoint weights
|
||||||
|
do not have the hyperparameters saved, use this method to pass in a .yaml
|
||||||
|
file with the hparams you would like to use. These will be converted
|
||||||
|
into a dict and passed into your Model for use.
|
||||||
|
strict : bool, optional
|
||||||
|
Whether to strictly enforce that the keys in checkpoint match
|
||||||
|
the keys returned by this module’s state dict. Defaults to True.
|
||||||
|
use_auth_token : str, optional
|
||||||
|
When loading a private huggingface.co model, set `use_auth_token`
|
||||||
|
to True or to a string containing your hugginface.co authentication
|
||||||
|
token that can be obtained by running `huggingface-cli login`
|
||||||
|
cache_dir: Path or str, optional
|
||||||
|
Path to model cache directory. Defaults to content of PYANNOTE_CACHE
|
||||||
|
environment variable, or "~/.cache/torch/pyannote" when unset.
|
||||||
|
kwargs: optional
|
||||||
|
Any extra keyword args needed to init the model.
|
||||||
|
Can also be used to override saved hyperparameter values.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
model : Model
|
||||||
|
Model
|
||||||
|
|
||||||
|
See also
|
||||||
|
--------
|
||||||
|
torch.load
|
||||||
|
"""
|
||||||
|
|
||||||
checkpoint = str(checkpoint)
|
checkpoint = str(checkpoint)
|
||||||
if hparams_file is not None:
|
if hparams_file is not None:
|
||||||
|
|
@ -168,7 +238,7 @@ class Model(pl.LightningModule):
|
||||||
|
|
||||||
if os.path.isfile(checkpoint):
|
if os.path.isfile(checkpoint):
|
||||||
model_path_pl = checkpoint
|
model_path_pl = checkpoint
|
||||||
elif urlparse(checkpoint).scheme in ("http","https"):
|
elif urlparse(checkpoint).scheme in ("http", "https"):
|
||||||
model_path_pl = checkpoint
|
model_path_pl = checkpoint
|
||||||
else:
|
else:
|
||||||
|
|
||||||
|
|
@ -180,17 +250,20 @@ class Model(pl.LightningModule):
|
||||||
revision_id = None
|
revision_id = None
|
||||||
|
|
||||||
url = hf_hub_url(
|
url = hf_hub_url(
|
||||||
model_id,filename=HF_TORCH_WEIGHTS,revision=revision_id
|
model_id, filename=HF_TORCH_WEIGHTS, revision=revision_id
|
||||||
)
|
)
|
||||||
model_path_pl = cached_download(
|
model_path_pl = cached_download(
|
||||||
url=url,library_name="enhancer",library_version=__version__,
|
url=url,
|
||||||
cache_dir=cached_dir,use_auth_token=use_auth_token
|
library_name="enhancer",
|
||||||
|
library_version=__version__,
|
||||||
|
cache_dir=cached_dir,
|
||||||
|
use_auth_token=use_auth_token,
|
||||||
)
|
)
|
||||||
|
|
||||||
if map_location is None:
|
if map_location is None:
|
||||||
map_location = torch.device(DEFAULT_DEVICE)
|
map_location = torch.device(DEFAULT_DEVICE)
|
||||||
|
|
||||||
loaded_checkpoint = pl_load(model_path_pl,map_location)
|
loaded_checkpoint = pl_load(model_path_pl, map_location)
|
||||||
module_name = loaded_checkpoint["enhancer"]["architecture"]["module"]
|
module_name = loaded_checkpoint["enhancer"]["architecture"]["module"]
|
||||||
class_name = loaded_checkpoint["enhancer"]["architecture"]["class"]
|
class_name = loaded_checkpoint["enhancer"]["architecture"]["class"]
|
||||||
module = import_module(module_name)
|
module = import_module(module_name)
|
||||||
|
|
@ -198,27 +271,38 @@ class Model(pl.LightningModule):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
model = Klass.load_from_checkpoint(
|
model = Klass.load_from_checkpoint(
|
||||||
checkpoint_path = model_path_pl,
|
checkpoint_path=model_path_pl,
|
||||||
map_location = map_location,
|
map_location=map_location,
|
||||||
hparams_file = hparams_file,
|
hparams_file=hparams_file,
|
||||||
strict = strict,
|
strict=strict,
|
||||||
**kwargs
|
**kwargs,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
|
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def infer(self,batch:torch.Tensor,batch_size:int=32):
|
def infer(self, batch: torch.Tensor, batch_size: int = 32):
|
||||||
|
"""
|
||||||
|
perform model inference
|
||||||
|
parameters:
|
||||||
|
batch : torch.Tensor
|
||||||
|
input data
|
||||||
|
batch_size : int, default 32
|
||||||
|
batch size for inference
|
||||||
|
"""
|
||||||
|
|
||||||
assert batch.ndim == 3, f"Expected batch with 3 dimensions (batch,channels,samples) got only {batch.ndim}"
|
assert (
|
||||||
|
batch.ndim == 3
|
||||||
|
), f"Expected batch with 3 dimensions (batch,channels,samples) got only {batch.ndim}"
|
||||||
batch_predictions = []
|
batch_predictions = []
|
||||||
self.eval().to(self.device)
|
self.eval().to(self.device)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for batch_id in range(0,batch.shape[0],batch_size):
|
for batch_id in range(0, batch.shape[0], batch_size):
|
||||||
batch_data = batch[batch_id:batch_id+batch_size,:,:].to(self.device)
|
batch_data = batch[batch_id : batch_id + batch_size, :, :].to(
|
||||||
|
self.device
|
||||||
|
)
|
||||||
prediction = self(batch_data)
|
prediction = self(batch_data)
|
||||||
batch_predictions.append(prediction)
|
batch_predictions.append(prediction)
|
||||||
|
|
||||||
|
|
@ -226,46 +310,61 @@ class Model(pl.LightningModule):
|
||||||
|
|
||||||
def enhance(
|
def enhance(
|
||||||
self,
|
self,
|
||||||
audio:Union[Path,np.ndarray,torch.Tensor],
|
audio: Union[Path, np.ndarray, torch.Tensor],
|
||||||
sampling_rate:Optional[int]=None,
|
sampling_rate: Optional[int] = None,
|
||||||
batch_size:int=32,
|
batch_size: int = 32,
|
||||||
save_output:bool=False,
|
save_output: bool = False,
|
||||||
duration:Optional[int]=None,
|
duration: Optional[int] = None,
|
||||||
step_size:Optional[int]=None,):
|
step_size: Optional[int] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Enhance audio using loaded pretained model.
|
||||||
|
|
||||||
|
parameters:
|
||||||
|
audio: Path to audio file or numpy array or torch tensor
|
||||||
|
single input audio
|
||||||
|
sampling_rate: int, optional incase input is path
|
||||||
|
sampling rate of input
|
||||||
|
batch_size: int, default 32
|
||||||
|
input audio is split into multiple chunks. Inference is done on batches
|
||||||
|
of these chunks according to given batch size.
|
||||||
|
save_output : bool, default False
|
||||||
|
weather to save output to file
|
||||||
|
duration : float, optional
|
||||||
|
chunk duration in seconds, defaults to duration of loaded pretrained model.
|
||||||
|
step_size: int, optional
|
||||||
|
step size between consecutive durations, defaults to 50% of duration
|
||||||
|
"""
|
||||||
|
|
||||||
model_sampling_rate = self.hparams["sampling_rate"]
|
model_sampling_rate = self.hparams["sampling_rate"]
|
||||||
if duration is None:
|
if duration is None:
|
||||||
duration = self.hparams["duration"]
|
duration = self.hparams["duration"]
|
||||||
waveform = Inference.read_input(audio,sampling_rate,model_sampling_rate)
|
waveform = Inference.read_input(
|
||||||
|
audio, sampling_rate, model_sampling_rate
|
||||||
|
)
|
||||||
waveform.to(self.device)
|
waveform.to(self.device)
|
||||||
window_size = round(duration * model_sampling_rate)
|
window_size = round(duration * model_sampling_rate)
|
||||||
batched_waveform = Inference.batchify(waveform,window_size,step_size=step_size)
|
batched_waveform = Inference.batchify(
|
||||||
batch_prediction = self.infer(batched_waveform,batch_size=batch_size)
|
waveform, window_size, step_size=step_size
|
||||||
waveform = Inference.aggreagate(batch_prediction,window_size,waveform.shape[-1],step_size,)
|
)
|
||||||
|
batch_prediction = self.infer(batched_waveform, batch_size=batch_size)
|
||||||
|
waveform = Inference.aggreagate(
|
||||||
|
batch_prediction,
|
||||||
|
window_size,
|
||||||
|
waveform.shape[-1],
|
||||||
|
step_size,
|
||||||
|
)
|
||||||
|
|
||||||
if save_output and isinstance(audio,(str,Path)):
|
if save_output and isinstance(audio, (str, Path)):
|
||||||
Inference.write_output(waveform,audio,model_sampling_rate)
|
Inference.write_output(waveform, audio, model_sampling_rate)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
waveform = Inference.prepare_output(waveform, model_sampling_rate,
|
waveform = Inference.prepare_output(
|
||||||
audio, sampling_rate)
|
waveform, model_sampling_rate, audio, sampling_rate
|
||||||
|
)
|
||||||
return waveform
|
return waveform
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def valid_monitor(self):
|
def valid_monitor(self):
|
||||||
|
|
||||||
return "max" if self.loss.higher_better else "min"
|
return "max" if self.loss.higher_better else "min"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Loading…
Reference in New Issue