commit
07e6aa6e24
|
|
@ -27,7 +27,25 @@ CACHE_DIR = ""
|
||||||
HF_TORCH_WEIGHTS = ""
|
HF_TORCH_WEIGHTS = ""
|
||||||
DEFAULT_DEVICE = "cpu"
|
DEFAULT_DEVICE = "cpu"
|
||||||
|
|
||||||
|
|
||||||
class Model(pl.LightningModule):
|
class Model(pl.LightningModule):
|
||||||
|
"""
|
||||||
|
Base class for all models
|
||||||
|
parameters:
|
||||||
|
num_channels: int, default to 1
|
||||||
|
number of channels in input audio
|
||||||
|
sampling_rate : int, default 16 kHz
|
||||||
|
audio sampling rate
|
||||||
|
lr: float, optional
|
||||||
|
learning rate for model training
|
||||||
|
dataset: EnhancerDataset, optional
|
||||||
|
Enhancer dataset used for training/validation
|
||||||
|
duration: float, optional
|
||||||
|
duration used for training/inference
|
||||||
|
loss : string or List of strings, default to "mse"
|
||||||
|
loss functions to be used. Available ("mse","mae","Si-SDR")
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -37,14 +55,20 @@ class Model(pl.LightningModule):
|
||||||
dataset: Optional[EnhancerDataset] = None,
|
dataset: Optional[EnhancerDataset] = None,
|
||||||
duration: Optional[float] = None,
|
duration: Optional[float] = None,
|
||||||
loss: Union[str, List] = "mse",
|
loss: Union[str, List] = "mse",
|
||||||
metric:Union[str,List] = "mse"
|
metric: Union[str, List] = "mse",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert num_channels ==1 , "Enhancer only support for mono channel models"
|
assert (
|
||||||
|
num_channels == 1
|
||||||
|
), "Enhancer only support for mono channel models"
|
||||||
self.dataset = dataset
|
self.dataset = dataset
|
||||||
self.save_hyperparameters("num_channels","sampling_rate","lr","loss","metric","duration")
|
self.save_hyperparameters(
|
||||||
|
"num_channels", "sampling_rate", "lr", "loss", "metric", "duration"
|
||||||
|
)
|
||||||
if self.logger:
|
if self.logger:
|
||||||
self.logger.experiment.log_dict(dict(self.hparams),"hyperparameters.json")
|
self.logger.experiment.log_dict(
|
||||||
|
dict(self.hparams), "hyperparameters.json"
|
||||||
|
)
|
||||||
|
|
||||||
self.loss = loss
|
self.loss = loss
|
||||||
self.metric = metric
|
self.metric = metric
|
||||||
|
|
@ -73,7 +97,6 @@ class Model(pl.LightningModule):
|
||||||
|
|
||||||
self._metric = Avergeloss(metric)
|
self._metric = Avergeloss(metric)
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dataset(self):
|
def dataset(self):
|
||||||
return self._dataset
|
return self._dataset
|
||||||
|
|
@ -105,9 +128,12 @@ class Model(pl.LightningModule):
|
||||||
loss = self.loss(prediction, target)
|
loss = self.loss(prediction, target)
|
||||||
|
|
||||||
if self.logger:
|
if self.logger:
|
||||||
self.logger.experiment.log_metric(run_id=self.logger.run_id,
|
self.logger.experiment.log_metric(
|
||||||
key="train_loss", value=loss.item(),
|
run_id=self.logger.run_id,
|
||||||
step=self.global_step)
|
key="train_loss",
|
||||||
|
value=loss.item(),
|
||||||
|
step=self.global_step,
|
||||||
|
)
|
||||||
self.log("train_loss", loss.item())
|
self.log("train_loss", loss.item())
|
||||||
return {"loss": loss}
|
return {"loss": loss}
|
||||||
|
|
||||||
|
|
@ -123,33 +149,34 @@ class Model(pl.LightningModule):
|
||||||
self.log("val_loss", loss_val.item())
|
self.log("val_loss", loss_val.item())
|
||||||
|
|
||||||
if self.logger:
|
if self.logger:
|
||||||
self.logger.experiment.log_metric(run_id=self.logger.run_id,
|
self.logger.experiment.log_metric(
|
||||||
key="val_loss",value=loss_val.item(),
|
run_id=self.logger.run_id,
|
||||||
step=self.global_step)
|
key="val_loss",
|
||||||
self.logger.experiment.log_metric(run_id=self.logger.run_id,
|
value=loss_val.item(),
|
||||||
key="val_metric",value=metric_val.item(),
|
step=self.global_step,
|
||||||
step=self.global_step)
|
)
|
||||||
|
self.logger.experiment.log_metric(
|
||||||
|
run_id=self.logger.run_id,
|
||||||
|
key="val_metric",
|
||||||
|
value=metric_val.item(),
|
||||||
|
step=self.global_step,
|
||||||
|
)
|
||||||
|
|
||||||
return {"loss": loss_val}
|
return {"loss": loss_val}
|
||||||
|
|
||||||
def on_save_checkpoint(self, checkpoint):
|
def on_save_checkpoint(self, checkpoint):
|
||||||
|
|
||||||
checkpoint["enhancer"] = {
|
checkpoint["enhancer"] = {
|
||||||
"version": {
|
"version": {"enhancer": __version__, "pytorch": torch.__version__},
|
||||||
"enhancer":__version__,
|
|
||||||
"pytorch":torch.__version__
|
|
||||||
},
|
|
||||||
"architecture": {
|
"architecture": {
|
||||||
"module": self.__class__.__module__,
|
"module": self.__class__.__module__,
|
||||||
"class":self.__class__.__name__
|
"class": self.__class__.__name__,
|
||||||
}
|
},
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def on_load_checkpoint(self, checkpoint: Dict[str, Any]):
|
def on_load_checkpoint(self, checkpoint: Dict[str, Any]):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(
|
def from_pretrained(
|
||||||
cls,
|
cls,
|
||||||
|
|
@ -159,8 +186,51 @@ class Model(pl.LightningModule):
|
||||||
strict: bool = True,
|
strict: bool = True,
|
||||||
use_auth_token: Union[Text, None] = None,
|
use_auth_token: Union[Text, None] = None,
|
||||||
cached_dir: Union[Path, Text] = CACHE_DIR,
|
cached_dir: Union[Path, Text] = CACHE_DIR,
|
||||||
**kwargs
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Load Pretrained model
|
||||||
|
|
||||||
|
parameters:
|
||||||
|
checkpoint : Path or str
|
||||||
|
Path to checkpoint, or a remote URL, or a model identifier from
|
||||||
|
the huggingface.co model hub.
|
||||||
|
map_location: optional
|
||||||
|
Same role as in torch.load().
|
||||||
|
Defaults to `lambda storage, loc: storage`.
|
||||||
|
hparams_file : Path or str, optional
|
||||||
|
Path to a .yaml file with hierarchical structure as in this example:
|
||||||
|
drop_prob: 0.2
|
||||||
|
dataloader:
|
||||||
|
batch_size: 32
|
||||||
|
You most likely won’t need this since Lightning will always save the
|
||||||
|
hyperparameters to the checkpoint. However, if your checkpoint weights
|
||||||
|
do not have the hyperparameters saved, use this method to pass in a .yaml
|
||||||
|
file with the hparams you would like to use. These will be converted
|
||||||
|
into a dict and passed into your Model for use.
|
||||||
|
strict : bool, optional
|
||||||
|
Whether to strictly enforce that the keys in checkpoint match
|
||||||
|
the keys returned by this module’s state dict. Defaults to True.
|
||||||
|
use_auth_token : str, optional
|
||||||
|
When loading a private huggingface.co model, set `use_auth_token`
|
||||||
|
to True or to a string containing your huggingface.co authentication
|
||||||
|
token that can be obtained by running `huggingface-cli login`
|
||||||
|
cached_dir: Path or str, optional
|
||||||
|
Path to model cache directory. Defaults to the module-level
|
||||||
|
CACHE_DIR constant.
|
||||||
|
kwargs: optional
|
||||||
|
Any extra keyword args needed to init the model.
|
||||||
|
Can also be used to override saved hyperparameter values.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
model : Model
|
||||||
|
Model
|
||||||
|
|
||||||
|
See also
|
||||||
|
--------
|
||||||
|
torch.load
|
||||||
|
"""
|
||||||
|
|
||||||
checkpoint = str(checkpoint)
|
checkpoint = str(checkpoint)
|
||||||
if hparams_file is not None:
|
if hparams_file is not None:
|
||||||
|
|
@ -183,8 +253,11 @@ class Model(pl.LightningModule):
|
||||||
model_id, filename=HF_TORCH_WEIGHTS, revision=revision_id
|
model_id, filename=HF_TORCH_WEIGHTS, revision=revision_id
|
||||||
)
|
)
|
||||||
model_path_pl = cached_download(
|
model_path_pl = cached_download(
|
||||||
url=url,library_name="enhancer",library_version=__version__,
|
url=url,
|
||||||
cache_dir=cached_dir,use_auth_token=use_auth_token
|
library_name="enhancer",
|
||||||
|
library_version=__version__,
|
||||||
|
cache_dir=cached_dir,
|
||||||
|
use_auth_token=use_auth_token,
|
||||||
)
|
)
|
||||||
|
|
||||||
if map_location is None:
|
if map_location is None:
|
||||||
|
|
@ -202,23 +275,34 @@ class Model(pl.LightningModule):
|
||||||
map_location=map_location,
|
map_location=map_location,
|
||||||
hparams_file=hparams_file,
|
hparams_file=hparams_file,
|
||||||
strict=strict,
|
strict=strict,
|
||||||
**kwargs
|
**kwargs,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
|
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def infer(self, batch: torch.Tensor, batch_size: int = 32):
|
def infer(self, batch: torch.Tensor, batch_size: int = 32):
|
||||||
|
"""
|
||||||
|
perform model inference
|
||||||
|
parameters:
|
||||||
|
batch : torch.Tensor
|
||||||
|
input data
|
||||||
|
batch_size : int, default 32
|
||||||
|
batch size for inference
|
||||||
|
"""
|
||||||
|
|
||||||
assert batch.ndim == 3, f"Expected batch with 3 dimensions (batch,channels,samples) got only {batch.ndim}"
|
assert (
|
||||||
|
batch.ndim == 3
|
||||||
|
), f"Expected batch with 3 dimensions (batch,channels,samples) got only {batch.ndim}"
|
||||||
batch_predictions = []
|
batch_predictions = []
|
||||||
self.eval().to(self.device)
|
self.eval().to(self.device)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for batch_id in range(0, batch.shape[0], batch_size):
|
for batch_id in range(0, batch.shape[0], batch_size):
|
||||||
batch_data = batch[batch_id:batch_id+batch_size,:,:].to(self.device)
|
batch_data = batch[batch_id : batch_id + batch_size, :, :].to(
|
||||||
|
self.device
|
||||||
|
)
|
||||||
prediction = self(batch_data)
|
prediction = self(batch_data)
|
||||||
batch_predictions.append(prediction)
|
batch_predictions.append(prediction)
|
||||||
|
|
||||||
|
|
@ -231,41 +315,56 @@ class Model(pl.LightningModule):
|
||||||
batch_size: int = 32,
|
batch_size: int = 32,
|
||||||
save_output: bool = False,
|
save_output: bool = False,
|
||||||
duration: Optional[int] = None,
|
duration: Optional[int] = None,
|
||||||
step_size:Optional[int]=None,):
|
step_size: Optional[int] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Enhance audio using loaded pretrained model.
|
||||||
|
|
||||||
|
parameters:
|
||||||
|
audio: Path to audio file or numpy array or torch tensor
|
||||||
|
single input audio
|
||||||
|
sampling_rate: int, optional in case input is a path
|
||||||
|
sampling rate of input
|
||||||
|
batch_size: int, default 32
|
||||||
|
input audio is split into multiple chunks. Inference is done on batches
|
||||||
|
of these chunks according to given batch size.
|
||||||
|
save_output : bool, default False
|
||||||
|
whether to save output to file
|
||||||
|
duration : float, optional
|
||||||
|
chunk duration in seconds, defaults to duration of loaded pretrained model.
|
||||||
|
step_size: int, optional
|
||||||
|
step size between consecutive durations, defaults to 50% of duration
|
||||||
|
"""
|
||||||
|
|
||||||
model_sampling_rate = self.hparams["sampling_rate"]
|
model_sampling_rate = self.hparams["sampling_rate"]
|
||||||
if duration is None:
|
if duration is None:
|
||||||
duration = self.hparams["duration"]
|
duration = self.hparams["duration"]
|
||||||
waveform = Inference.read_input(audio,sampling_rate,model_sampling_rate)
|
waveform = Inference.read_input(
|
||||||
|
audio, sampling_rate, model_sampling_rate
|
||||||
|
)
|
||||||
waveform.to(self.device)
|
waveform.to(self.device)
|
||||||
window_size = round(duration * model_sampling_rate)
|
window_size = round(duration * model_sampling_rate)
|
||||||
batched_waveform = Inference.batchify(waveform,window_size,step_size=step_size)
|
batched_waveform = Inference.batchify(
|
||||||
|
waveform, window_size, step_size=step_size
|
||||||
|
)
|
||||||
batch_prediction = self.infer(batched_waveform, batch_size=batch_size)
|
batch_prediction = self.infer(batched_waveform, batch_size=batch_size)
|
||||||
waveform = Inference.aggreagate(batch_prediction,window_size,waveform.shape[-1],step_size,)
|
waveform = Inference.aggreagate(
|
||||||
|
batch_prediction,
|
||||||
|
window_size,
|
||||||
|
waveform.shape[-1],
|
||||||
|
step_size,
|
||||||
|
)
|
||||||
|
|
||||||
if save_output and isinstance(audio, (str, Path)):
|
if save_output and isinstance(audio, (str, Path)):
|
||||||
Inference.write_output(waveform, audio, model_sampling_rate)
|
Inference.write_output(waveform, audio, model_sampling_rate)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
waveform = Inference.prepare_output(waveform, model_sampling_rate,
|
waveform = Inference.prepare_output(
|
||||||
audio, sampling_rate)
|
waveform, model_sampling_rate, audio, sampling_rate
|
||||||
|
)
|
||||||
return waveform
|
return waveform
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def valid_monitor(self):
|
def valid_monitor(self):
|
||||||
|
|
||||||
return "max" if self.loss.higher_better else "min"
|
return "max" if self.loss.higher_better else "min"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Loading…
Reference in New Issue