commit
07e6aa6e24
|
|
@ -27,66 +27,89 @@ CACHE_DIR = ""
|
|||
HF_TORCH_WEIGHTS = ""
|
||||
DEFAULT_DEVICE = "cpu"
|
||||
|
||||
|
||||
class Model(pl.LightningModule):
|
||||
"""
|
||||
Base class for all models
|
||||
parameters:
|
||||
num_channels: int, default to 1
|
||||
number of channels in input audio
|
||||
sampling_rate : int, default 16khz
|
||||
audio sampling rate
|
||||
lr: float, optional
|
||||
learning rate for model training
|
||||
dataset: EnhancerDataset, optional
|
||||
Enhancer dataset used for training/validation
|
||||
duration: float, optional
|
||||
duration used for training/inference
|
||||
loss : string or List of strings, default to "mse"
|
||||
loss functions to be used. Available ("mse","mae","Si-SDR")
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_channels:int=1,
|
||||
sampling_rate:int=16000,
|
||||
lr:float=1e-3,
|
||||
dataset:Optional[EnhancerDataset]=None,
|
||||
duration:Optional[float]=None,
|
||||
num_channels: int = 1,
|
||||
sampling_rate: int = 16000,
|
||||
lr: float = 1e-3,
|
||||
dataset: Optional[EnhancerDataset] = None,
|
||||
duration: Optional[float] = None,
|
||||
loss: Union[str, List] = "mse",
|
||||
metric:Union[str,List] = "mse"
|
||||
metric: Union[str, List] = "mse",
|
||||
):
|
||||
super().__init__()
|
||||
assert num_channels ==1 , "Enhancer only support for mono channel models"
|
||||
assert (
|
||||
num_channels == 1
|
||||
), "Enhancer only support for mono channel models"
|
||||
self.dataset = dataset
|
||||
self.save_hyperparameters("num_channels","sampling_rate","lr","loss","metric","duration")
|
||||
self.save_hyperparameters(
|
||||
"num_channels", "sampling_rate", "lr", "loss", "metric", "duration"
|
||||
)
|
||||
if self.logger:
|
||||
self.logger.experiment.log_dict(dict(self.hparams),"hyperparameters.json")
|
||||
|
||||
self.logger.experiment.log_dict(
|
||||
dict(self.hparams), "hyperparameters.json"
|
||||
)
|
||||
|
||||
self.loss = loss
|
||||
self.metric = metric
|
||||
|
||||
@property
|
||||
def loss(self):
|
||||
return self._loss
|
||||
|
||||
@loss.setter
|
||||
def loss(self,loss):
|
||||
|
||||
if isinstance(loss,str):
|
||||
losses = [loss]
|
||||
|
||||
@loss.setter
|
||||
def loss(self, loss):
|
||||
|
||||
if isinstance(loss, str):
|
||||
losses = [loss]
|
||||
|
||||
self._loss = Avergeloss(losses)
|
||||
|
||||
@property
|
||||
def metric(self):
|
||||
return self._metric
|
||||
|
||||
|
||||
@metric.setter
|
||||
def metric(self,metric):
|
||||
def metric(self, metric):
|
||||
|
||||
if isinstance(metric, str):
|
||||
metric = [metric]
|
||||
|
||||
if isinstance(metric,str):
|
||||
metric = [metric]
|
||||
|
||||
self._metric = Avergeloss(metric)
|
||||
|
||||
|
||||
@property
|
||||
def dataset(self):
|
||||
return self._dataset
|
||||
|
||||
@dataset.setter
|
||||
def dataset(self,dataset):
|
||||
def dataset(self, dataset):
|
||||
self._dataset = dataset
|
||||
|
||||
def setup(self,stage:Optional[str]=None):
|
||||
def setup(self, stage: Optional[str] = None):
|
||||
if stage == "fit":
|
||||
self.dataset.setup(stage)
|
||||
self.dataset.model = self
|
||||
|
||||
|
||||
def train_dataloader(self):
|
||||
return self.dataset.train_dataloader()
|
||||
|
||||
|
|
@ -94,9 +117,9 @@ class Model(pl.LightningModule):
|
|||
return self.dataset.val_dataloader()
|
||||
|
||||
def configure_optimizers(self):
|
||||
return Adam(self.parameters(), lr = self.hparams.lr)
|
||||
return Adam(self.parameters(), lr=self.hparams.lr)
|
||||
|
||||
def training_step(self,batch, batch_idx:int):
|
||||
def training_step(self, batch, batch_idx: int):
|
||||
|
||||
mixed_waveform = batch["noisy"]
|
||||
target = batch["clean"]
|
||||
|
|
@ -105,13 +128,16 @@ class Model(pl.LightningModule):
|
|||
loss = self.loss(prediction, target)
|
||||
|
||||
if self.logger:
|
||||
self.logger.experiment.log_metric(run_id=self.logger.run_id,
|
||||
key="train_loss", value=loss.item(),
|
||||
step=self.global_step)
|
||||
self.log("train_loss",loss.item())
|
||||
return {"loss":loss}
|
||||
self.logger.experiment.log_metric(
|
||||
run_id=self.logger.run_id,
|
||||
key="train_loss",
|
||||
value=loss.item(),
|
||||
step=self.global_step,
|
||||
)
|
||||
self.log("train_loss", loss.item())
|
||||
return {"loss": loss}
|
||||
|
||||
def validation_step(self,batch,batch_idx:int):
|
||||
def validation_step(self, batch, batch_idx: int):
|
||||
|
||||
mixed_waveform = batch["noisy"]
|
||||
target = batch["clean"]
|
||||
|
|
@ -119,48 +145,92 @@ class Model(pl.LightningModule):
|
|||
|
||||
metric_val = self.metric(prediction, target)
|
||||
loss_val = self.loss(prediction, target)
|
||||
self.log("val_metric",metric_val.item())
|
||||
self.log("val_loss",loss_val.item())
|
||||
self.log("val_metric", metric_val.item())
|
||||
self.log("val_loss", loss_val.item())
|
||||
|
||||
if self.logger:
|
||||
self.logger.experiment.log_metric(run_id=self.logger.run_id,
|
||||
key="val_loss",value=loss_val.item(),
|
||||
step=self.global_step)
|
||||
self.logger.experiment.log_metric(run_id=self.logger.run_id,
|
||||
key="val_metric",value=metric_val.item(),
|
||||
step=self.global_step)
|
||||
self.logger.experiment.log_metric(
|
||||
run_id=self.logger.run_id,
|
||||
key="val_loss",
|
||||
value=loss_val.item(),
|
||||
step=self.global_step,
|
||||
)
|
||||
self.logger.experiment.log_metric(
|
||||
run_id=self.logger.run_id,
|
||||
key="val_metric",
|
||||
value=metric_val.item(),
|
||||
step=self.global_step,
|
||||
)
|
||||
|
||||
return {"loss":loss_val}
|
||||
return {"loss": loss_val}
|
||||
|
||||
def on_save_checkpoint(self, checkpoint):
|
||||
|
||||
checkpoint["enhancer"] = {
|
||||
"version": {
|
||||
"enhancer":__version__,
|
||||
"pytorch":torch.__version__
|
||||
"version": {"enhancer": __version__, "pytorch": torch.__version__},
|
||||
"architecture": {
|
||||
"module": self.__class__.__module__,
|
||||
"class": self.__class__.__name__,
|
||||
},
|
||||
"architecture":{
|
||||
"module":self.__class__.__module__,
|
||||
"class":self.__class__.__name__
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
def on_load_checkpoint(self, checkpoint: Dict[str, Any]):
|
||||
pass
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
checkpoint: Union[Path, Text],
|
||||
map_location = None,
|
||||
map_location=None,
|
||||
hparams_file: Union[Path, Text] = None,
|
||||
strict: bool = True,
|
||||
use_auth_token: Union[Text, None] = None,
|
||||
cached_dir: Union[Path, Text]=CACHE_DIR,
|
||||
**kwargs
|
||||
cached_dir: Union[Path, Text] = CACHE_DIR,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Load Pretrained model
|
||||
|
||||
parameters:
|
||||
checkpoint : Path or str
|
||||
Path to checkpoint, or a remote URL, or a model identifier from
|
||||
the huggingface.co model hub.
|
||||
map_location: optional
|
||||
Same role as in torch.load().
|
||||
Defaults to `lambda storage, loc: storage`.
|
||||
hparams_file : Path or str, optional
|
||||
Path to a .yaml file with hierarchical structure as in this example:
|
||||
drop_prob: 0.2
|
||||
dataloader:
|
||||
batch_size: 32
|
||||
You most likely won’t need this since Lightning will always save the
|
||||
hyperparameters to the checkpoint. However, if your checkpoint weights
|
||||
do not have the hyperparameters saved, use this method to pass in a .yaml
|
||||
file with the hparams you would like to use. These will be converted
|
||||
into a dict and passed into your Model for use.
|
||||
strict : bool, optional
|
||||
Whether to strictly enforce that the keys in checkpoint match
|
||||
the keys returned by this module’s state dict. Defaults to True.
|
||||
use_auth_token : str, optional
|
||||
When loading a private huggingface.co model, set `use_auth_token`
|
||||
to True or to a string containing your hugginface.co authentication
|
||||
token that can be obtained by running `huggingface-cli login`
|
||||
cache_dir: Path or str, optional
|
||||
Path to model cache directory. Defaults to content of PYANNOTE_CACHE
|
||||
environment variable, or "~/.cache/torch/pyannote" when unset.
|
||||
kwargs: optional
|
||||
Any extra keyword args needed to init the model.
|
||||
Can also be used to override saved hyperparameter values.
|
||||
|
||||
Returns
|
||||
-------
|
||||
model : Model
|
||||
Model
|
||||
|
||||
See also
|
||||
--------
|
||||
torch.load
|
||||
"""
|
||||
|
||||
checkpoint = str(checkpoint)
|
||||
if hparams_file is not None:
|
||||
|
|
@ -168,104 +238,133 @@ class Model(pl.LightningModule):
|
|||
|
||||
if os.path.isfile(checkpoint):
|
||||
model_path_pl = checkpoint
|
||||
elif urlparse(checkpoint).scheme in ("http","https"):
|
||||
elif urlparse(checkpoint).scheme in ("http", "https"):
|
||||
model_path_pl = checkpoint
|
||||
else:
|
||||
|
||||
|
||||
if "@" in checkpoint:
|
||||
model_id = checkpoint.split("@")[0]
|
||||
revision_id = checkpoint.split("@")[1]
|
||||
else:
|
||||
model_id = checkpoint
|
||||
revision_id = None
|
||||
|
||||
|
||||
url = hf_hub_url(
|
||||
model_id,filename=HF_TORCH_WEIGHTS,revision=revision_id
|
||||
model_id, filename=HF_TORCH_WEIGHTS, revision=revision_id
|
||||
)
|
||||
model_path_pl = cached_download(
|
||||
url=url,library_name="enhancer",library_version=__version__,
|
||||
cache_dir=cached_dir,use_auth_token=use_auth_token
|
||||
url=url,
|
||||
library_name="enhancer",
|
||||
library_version=__version__,
|
||||
cache_dir=cached_dir,
|
||||
use_auth_token=use_auth_token,
|
||||
)
|
||||
|
||||
if map_location is None:
|
||||
map_location = torch.device(DEFAULT_DEVICE)
|
||||
|
||||
loaded_checkpoint = pl_load(model_path_pl,map_location)
|
||||
loaded_checkpoint = pl_load(model_path_pl, map_location)
|
||||
module_name = loaded_checkpoint["enhancer"]["architecture"]["module"]
|
||||
class_name = loaded_checkpoint["enhancer"]["architecture"]["class"]
|
||||
class_name = loaded_checkpoint["enhancer"]["architecture"]["class"]
|
||||
module = import_module(module_name)
|
||||
Klass = getattr(module, class_name)
|
||||
|
||||
try:
|
||||
model = Klass.load_from_checkpoint(
|
||||
checkpoint_path = model_path_pl,
|
||||
map_location = map_location,
|
||||
hparams_file = hparams_file,
|
||||
strict = strict,
|
||||
**kwargs
|
||||
checkpoint_path=model_path_pl,
|
||||
map_location=map_location,
|
||||
hparams_file=hparams_file,
|
||||
strict=strict,
|
||||
**kwargs,
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
return model
|
||||
|
||||
return model
|
||||
def infer(self, batch: torch.Tensor, batch_size: int = 32):
|
||||
"""
|
||||
perform model inference
|
||||
parameters:
|
||||
batch : torch.Tensor
|
||||
input data
|
||||
batch_size : int, default 32
|
||||
batch size for inference
|
||||
"""
|
||||
|
||||
def infer(self,batch:torch.Tensor,batch_size:int=32):
|
||||
|
||||
assert batch.ndim == 3, f"Expected batch with 3 dimensions (batch,channels,samples) got only {batch.ndim}"
|
||||
assert (
|
||||
batch.ndim == 3
|
||||
), f"Expected batch with 3 dimensions (batch,channels,samples) got only {batch.ndim}"
|
||||
batch_predictions = []
|
||||
self.eval().to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
for batch_id in range(0,batch.shape[0],batch_size):
|
||||
batch_data = batch[batch_id:batch_id+batch_size,:,:].to(self.device)
|
||||
for batch_id in range(0, batch.shape[0], batch_size):
|
||||
batch_data = batch[batch_id : batch_id + batch_size, :, :].to(
|
||||
self.device
|
||||
)
|
||||
prediction = self(batch_data)
|
||||
batch_predictions.append(prediction)
|
||||
|
||||
|
||||
return torch.vstack(batch_predictions)
|
||||
|
||||
def enhance(
|
||||
self,
|
||||
audio:Union[Path,np.ndarray,torch.Tensor],
|
||||
sampling_rate:Optional[int]=None,
|
||||
batch_size:int=32,
|
||||
save_output:bool=False,
|
||||
duration:Optional[int]=None,
|
||||
step_size:Optional[int]=None,):
|
||||
audio: Union[Path, np.ndarray, torch.Tensor],
|
||||
sampling_rate: Optional[int] = None,
|
||||
batch_size: int = 32,
|
||||
save_output: bool = False,
|
||||
duration: Optional[int] = None,
|
||||
step_size: Optional[int] = None,
|
||||
):
|
||||
"""
|
||||
Enhance audio using loaded pretained model.
|
||||
|
||||
parameters:
|
||||
audio: Path to audio file or numpy array or torch tensor
|
||||
single input audio
|
||||
sampling_rate: int, optional incase input is path
|
||||
sampling rate of input
|
||||
batch_size: int, default 32
|
||||
input audio is split into multiple chunks. Inference is done on batches
|
||||
of these chunks according to given batch size.
|
||||
save_output : bool, default False
|
||||
weather to save output to file
|
||||
duration : float, optional
|
||||
chunk duration in seconds, defaults to duration of loaded pretrained model.
|
||||
step_size: int, optional
|
||||
step size between consecutive durations, defaults to 50% of duration
|
||||
"""
|
||||
|
||||
model_sampling_rate = self.hparams["sampling_rate"]
|
||||
if duration is None:
|
||||
duration = self.hparams["duration"]
|
||||
waveform = Inference.read_input(audio,sampling_rate,model_sampling_rate)
|
||||
waveform = Inference.read_input(
|
||||
audio, sampling_rate, model_sampling_rate
|
||||
)
|
||||
waveform.to(self.device)
|
||||
window_size = round(duration * model_sampling_rate)
|
||||
batched_waveform = Inference.batchify(waveform,window_size,step_size=step_size)
|
||||
batch_prediction = self.infer(batched_waveform,batch_size=batch_size)
|
||||
waveform = Inference.aggreagate(batch_prediction,window_size,waveform.shape[-1],step_size,)
|
||||
|
||||
if save_output and isinstance(audio,(str,Path)):
|
||||
Inference.write_output(waveform,audio,model_sampling_rate)
|
||||
batched_waveform = Inference.batchify(
|
||||
waveform, window_size, step_size=step_size
|
||||
)
|
||||
batch_prediction = self.infer(batched_waveform, batch_size=batch_size)
|
||||
waveform = Inference.aggreagate(
|
||||
batch_prediction,
|
||||
window_size,
|
||||
waveform.shape[-1],
|
||||
step_size,
|
||||
)
|
||||
|
||||
if save_output and isinstance(audio, (str, Path)):
|
||||
Inference.write_output(waveform, audio, model_sampling_rate)
|
||||
|
||||
else:
|
||||
waveform = Inference.prepare_output(waveform, model_sampling_rate,
|
||||
audio, sampling_rate)
|
||||
waveform = Inference.prepare_output(
|
||||
waveform, model_sampling_rate, audio, sampling_rate
|
||||
)
|
||||
return waveform
|
||||
|
||||
@property
|
||||
def valid_monitor(self):
|
||||
|
||||
return "max" if self.loss.higher_better else "min"
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
return "max" if self.loss.higher_better else "min"
|
||||
|
|
|
|||
Loading…
Reference in New Issue