commit 07e6aa6e24
@@ -27,7 +27,25 @@ CACHE_DIR = ""
 HF_TORCH_WEIGHTS = ""
 DEFAULT_DEVICE = "cpu"


 class Model(pl.LightningModule):
+    """
+    Base class for all models
+
+    parameters:
+        num_channels: int, default to 1
+            number of channels in input audio
+        sampling_rate : int, default 16khz
+            audio sampling rate
+        lr: float, optional
+            learning rate for model training
+        dataset: EnhancerDataset, optional
+            Enhancer dataset used for training/validation
+        duration: float, optional
+            duration used for training/inference
+        loss : string or List of strings, default to "mse"
+            loss functions to be used. Available ("mse","mae","Si-SDR")
+
+    """
+
     def __init__(
         self,
@@ -37,14 +55,20 @@ class Model(pl.LightningModule):
         dataset: Optional[EnhancerDataset] = None,
         duration: Optional[float] = None,
         loss: Union[str, List] = "mse",
-        metric:Union[str,List] = "mse"
+        metric: Union[str, List] = "mse",
     ):
         super().__init__()
-        assert num_channels ==1 , "Enhancer only support for mono channel models"
+        assert (
+            num_channels == 1
+        ), "Enhancer only support for mono channel models"
         self.dataset = dataset
-        self.save_hyperparameters("num_channels","sampling_rate","lr","loss","metric","duration")
+        self.save_hyperparameters(
+            "num_channels", "sampling_rate", "lr", "loss", "metric", "duration"
+        )
         if self.logger:
-            self.logger.experiment.log_dict(dict(self.hparams),"hyperparameters.json")
+            self.logger.experiment.log_dict(
+                dict(self.hparams), "hyperparameters.json"
+            )

         self.loss = loss
         self.metric = metric
@@ -73,7 +97,6 @@ class Model(pl.LightningModule):

         self._metric = Avergeloss(metric)
-

     @property
     def dataset(self):
         return self._dataset
@@ -105,9 +128,12 @@ class Model(pl.LightningModule):
         loss = self.loss(prediction, target)

         if self.logger:
-            self.logger.experiment.log_metric(run_id=self.logger.run_id,
-                                              key="train_loss", value=loss.item(),
-                                              step=self.global_step)
+            self.logger.experiment.log_metric(
+                run_id=self.logger.run_id,
+                key="train_loss",
+                value=loss.item(),
+                step=self.global_step,
+            )
         self.log("train_loss", loss.item())
         return {"loss": loss}

@@ -123,33 +149,34 @@ class Model(pl.LightningModule):
         self.log("val_loss", loss_val.item())

         if self.logger:
-            self.logger.experiment.log_metric(run_id=self.logger.run_id,
-                                              key="val_loss",value=loss_val.item(),
-                                              step=self.global_step)
-            self.logger.experiment.log_metric(run_id=self.logger.run_id,
-                                              key="val_metric",value=metric_val.item(),
-                                              step=self.global_step)
+            self.logger.experiment.log_metric(
+                run_id=self.logger.run_id,
+                key="val_loss",
+                value=loss_val.item(),
+                step=self.global_step,
+            )
+            self.logger.experiment.log_metric(
+                run_id=self.logger.run_id,
+                key="val_metric",
+                value=metric_val.item(),
+                step=self.global_step,
+            )

         return {"loss": loss_val}

     def on_save_checkpoint(self, checkpoint):

         checkpoint["enhancer"] = {
-            "version": {
-                "enhancer":__version__,
-                "pytorch":torch.__version__
-            },
+            "version": {"enhancer": __version__, "pytorch": torch.__version__},
             "architecture": {
                 "module": self.__class__.__module__,
-                "class":self.__class__.__name__
-            }
-
+                "class": self.__class__.__name__,
+            },
         }

     def on_load_checkpoint(self, checkpoint: Dict[str, Any]):
         pass

     @classmethod
     def from_pretrained(
         cls,
@@ -159,8 +186,51 @@ class Model(pl.LightningModule):
         strict: bool = True,
         use_auth_token: Union[Text, None] = None,
         cached_dir: Union[Path, Text] = CACHE_DIR,
-        **kwargs
+        **kwargs,
     ):
+        """
+        Load Pretrained model
+
+        parameters:
+            checkpoint : Path or str
+                Path to checkpoint, or a remote URL, or a model identifier from
+                the huggingface.co model hub.
+            map_location: optional
+                Same role as in torch.load().
+                Defaults to `lambda storage, loc: storage`.
+            hparams_file : Path or str, optional
+                Path to a .yaml file with hierarchical structure as in this example:
+                    drop_prob: 0.2
+                    dataloader:
+                        batch_size: 32
+                You most likely won’t need this since Lightning will always save the
+                hyperparameters to the checkpoint. However, if your checkpoint weights
+                do not have the hyperparameters saved, use this method to pass in a .yaml
+                file with the hparams you would like to use. These will be converted
+                into a dict and passed into your Model for use.
+            strict : bool, optional
+                Whether to strictly enforce that the keys in checkpoint match
+                the keys returned by this module’s state dict. Defaults to True.
+            use_auth_token : str, optional
+                When loading a private huggingface.co model, set `use_auth_token`
+                to True or to a string containing your huggingface.co authentication
+                token that can be obtained by running `huggingface-cli login`
+            cache_dir: Path or str, optional
+                Path to model cache directory. Defaults to content of PYANNOTE_CACHE
+                environment variable, or "~/.cache/torch/pyannote" when unset.
+            kwargs: optional
+                Any extra keyword args needed to init the model.
+                Can also be used to override saved hyperparameter values.
+
+        Returns
+        -------
+        model : Model
+            Model
+
+        See also
+        --------
+        torch.load
+        """
+
         checkpoint = str(checkpoint)
         if hparams_file is not None:
@@ -183,8 +253,11 @@ class Model(pl.LightningModule):
                 model_id, filename=HF_TORCH_WEIGHTS, revision=revision_id
             )
             model_path_pl = cached_download(
-                url=url,library_name="enhancer",library_version=__version__,
-                cache_dir=cached_dir,use_auth_token=use_auth_token
+                url=url,
+                library_name="enhancer",
+                library_version=__version__,
+                cache_dir=cached_dir,
+                use_auth_token=use_auth_token,
             )

         if map_location is None:
@@ -202,23 +275,34 @@ class Model(pl.LightningModule):
                 map_location=map_location,
                 hparams_file=hparams_file,
                 strict=strict,
-                **kwargs
+                **kwargs,
             )
         except Exception as e:
             print(e)

         return model

     def infer(self, batch: torch.Tensor, batch_size: int = 32):
         """
         perform model inference
         parameters:
             batch : torch.Tensor
                 input data
             batch_size : int, default 32
                 batch size for inference
         """

-        assert batch.ndim == 3, f"Expected batch with 3 dimensions (batch,channels,samples) got only {batch.ndim}"
+        assert (
+            batch.ndim == 3
+        ), f"Expected batch with 3 dimensions (batch,channels,samples) got only {batch.ndim}"
         batch_predictions = []
         self.eval().to(self.device)

         with torch.no_grad():
             for batch_id in range(0, batch.shape[0], batch_size):
-                batch_data = batch[batch_id:batch_id+batch_size,:,:].to(self.device)
+                batch_data = batch[batch_id : batch_id + batch_size, :, :].to(
+                    self.device
+                )
                 prediction = self(batch_data)
                 batch_predictions.append(prediction)

@@ -231,41 +315,56 @@ class Model(pl.LightningModule):
         batch_size: int = 32,
         save_output: bool = False,
         duration: Optional[int] = None,
-        step_size:Optional[int]=None,):
+        step_size: Optional[int] = None,
+    ):
         """
         Enhance audio using loaded pretrained model.

         parameters:
             audio: Path to audio file or numpy array or torch tensor
                 single input audio
             sampling_rate: int, optional in case input is a path
                 sampling rate of input
             batch_size: int, default 32
                 input audio is split into multiple chunks. Inference is done on batches
                 of these chunks according to given batch size.
             save_output : bool, default False
                 whether to save output to file
             duration : float, optional
                 chunk duration in seconds, defaults to duration of loaded pretrained model.
             step_size: int, optional
                 step size between consecutive durations, defaults to 50% of duration
         """

         model_sampling_rate = self.hparams["sampling_rate"]
         if duration is None:
             duration = self.hparams["duration"]
-        waveform = Inference.read_input(audio,sampling_rate,model_sampling_rate)
+        waveform = Inference.read_input(
+            audio, sampling_rate, model_sampling_rate
+        )
         waveform.to(self.device)
         window_size = round(duration * model_sampling_rate)
-        batched_waveform = Inference.batchify(waveform,window_size,step_size=step_size)
+        batched_waveform = Inference.batchify(
+            waveform, window_size, step_size=step_size
+        )
         batch_prediction = self.infer(batched_waveform, batch_size=batch_size)
-        waveform = Inference.aggreagate(batch_prediction,window_size,waveform.shape[-1],step_size,)
+        waveform = Inference.aggreagate(
+            batch_prediction,
+            window_size,
+            waveform.shape[-1],
+            step_size,
+        )

         if save_output and isinstance(audio, (str, Path)):
             Inference.write_output(waveform, audio, model_sampling_rate)

         else:
-            waveform = Inference.prepare_output(waveform, model_sampling_rate,
-                                                audio, sampling_rate)
+            waveform = Inference.prepare_output(
+                waveform, model_sampling_rate, audio, sampling_rate
+            )
             return waveform

     @property
     def valid_monitor(self):

         return "max" if self.loss.higher_better else "min"
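The enhance() path computes window_size = round(duration * model_sampling_rate) and, per its docstring, defaults step_size to 50% of the window. A standalone sketch of that chunking arithmetic (an illustration only, not the library's Inference.batchify implementation, which is not part of this diff):

    from typing import Optional

    import torch

    def chunk_waveform(
        waveform: torch.Tensor, window_size: int, step_size: Optional[int] = None
    ) -> torch.Tensor:
        """Split (channels, samples) into overlapping (num_chunks, channels, window_size)."""
        if step_size is None:
            step_size = window_size // 2  # 50% overlap, matching the documented default
        # unfold produces sliding windows along the sample axis
        chunks = waveform.unfold(-1, window_size, step_size)  # (channels, num_chunks, window)
        return chunks.permute(1, 0, 2)  # batch-first, the layout infer() asserts on

    waveform = torch.randn(1, 160000)        # 10 s of mono audio at 16 kHz
    batch = chunk_waveform(waveform, 16000)  # 1 s windows with 0.5 s hops
    print(batch.shape)                       # torch.Size([19, 1, 16000])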