Merge pull request #7 from shahules786/dev-reformat

refactor model.py
Shahul ES 2022-10-05 12:37:57 +05:30 committed by GitHub
commit 07e6aa6e24
1 changed file with 204 additions and 105 deletions


@@ -27,7 +27,25 @@ CACHE_DIR = ""
HF_TORCH_WEIGHTS = ""
DEFAULT_DEVICE = "cpu"


class Model(pl.LightningModule):
    """
    Base class for all models

    parameters:
    num_channels: int, defaults to 1
        number of channels in input audio
    sampling_rate : int, defaults to 16 kHz
        audio sampling rate
    lr: float, optional
        learning rate for model training
    dataset: EnhancerDataset, optional
        Enhancer dataset used for training/validation
    duration: float, optional
        duration used for training/inference
    loss : string or List of strings, defaults to "mse"
        loss functions to be used. Available ("mse", "mae", "Si-SDR")
    """

    def __init__(
        self,
@@ -37,14 +55,20 @@ class Model(pl.LightningModule):
        dataset: Optional[EnhancerDataset] = None,
        duration: Optional[float] = None,
        loss: Union[str, List] = "mse",
        metric: Union[str, List] = "mse",
    ):
        super().__init__()
        assert (
            num_channels == 1
        ), "Enhancer only supports mono channel models"
        self.dataset = dataset
        self.save_hyperparameters(
            "num_channels", "sampling_rate", "lr", "loss", "metric", "duration"
        )
        if self.logger:
            self.logger.experiment.log_dict(
                dict(self.hparams), "hyperparameters.json"
            )

        self.loss = loss
        self.metric = metric
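
For orientation, here is a minimal sketch of how a concrete subclass might be constructed with these hyperparameters. The `Demucs` import path and the dataset arguments are assumptions for illustration, not shown in this diff:

    from enhancer.data import EnhancerDataset
    from enhancer.models import Demucs  # hypothetical concrete subclass of Model

    dataset = EnhancerDataset(...)  # construct as appropriate; arguments omitted here
    model = Demucs(
        num_channels=1,           # must be 1, per the assert above
        sampling_rate=16000,
        lr=1e-3,
        dataset=dataset,
        duration=4.0,
        loss=["mse", "mae"],      # a single string or a list of loss names
        metric="mse",
    )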
@@ -73,7 +97,6 @@ class Model(pl.LightningModule):
        self._metric = Avergeloss(metric)

    @property
    def dataset(self):
        return self._dataset
@@ -105,9 +128,12 @@ class Model(pl.LightningModule):
        loss = self.loss(prediction, target)
        if self.logger:
            self.logger.experiment.log_metric(
                run_id=self.logger.run_id,
                key="train_loss",
                value=loss.item(),
                step=self.global_step,
            )
        self.log("train_loss", loss.item())
        return {"loss": loss}
@@ -123,33 +149,34 @@ class Model(pl.LightningModule):
        self.log("val_loss", loss_val.item())
        if self.logger:
            self.logger.experiment.log_metric(
                run_id=self.logger.run_id,
                key="val_loss",
                value=loss_val.item(),
                step=self.global_step,
            )
            self.logger.experiment.log_metric(
                run_id=self.logger.run_id,
                key="val_metric",
                value=metric_val.item(),
                step=self.global_step,
            )

        return {"loss": loss_val}

    def on_save_checkpoint(self, checkpoint):
        checkpoint["enhancer"] = {
            "version": {"enhancer": __version__, "pytorch": torch.__version__},
            "architecture": {
                "module": self.__class__.__module__,
                "class": self.__class__.__name__,
            },
        }

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]):
        pass

    @classmethod
    def from_pretrained(
        cls,
@@ -159,8 +186,51 @@ class Model(pl.LightningModule):
        strict: bool = True,
        use_auth_token: Union[Text, None] = None,
        cached_dir: Union[Path, Text] = CACHE_DIR,
        **kwargs,
    ):
        """
        Load pretrained model

        parameters:
        checkpoint : Path or str
            Path to checkpoint, or a remote URL, or a model identifier from
            the huggingface.co model hub.
        map_location: optional
            Same role as in torch.load().
            Defaults to `lambda storage, loc: storage`.
        hparams_file : Path or str, optional
            Path to a .yaml file with hierarchical structure as in this example:
                drop_prob: 0.2
                dataloader:
                    batch_size: 32
            You most likely won't need this since Lightning will always save the
            hyperparameters to the checkpoint. However, if your checkpoint weights
            do not have the hyperparameters saved, use this method to pass in a .yaml
            file with the hparams you would like to use. These will be converted
            into a dict and passed into your Model for use.
        strict : bool, optional
            Whether to strictly enforce that the keys in checkpoint match
            the keys returned by this module's state dict. Defaults to True.
        use_auth_token : str, optional
            When loading a private huggingface.co model, set `use_auth_token`
            to True or to a string containing your huggingface.co authentication
            token, which can be obtained by running `huggingface-cli login`.
        cached_dir: Path or str, optional
            Path to model cache directory. Defaults to content of PYANNOTE_CACHE
            environment variable, or "~/.cache/torch/pyannote" when unset.
        kwargs: optional
            Any extra keyword args needed to init the model.
            Can also be used to override saved hyperparameter values.

        Returns
        -------
        model : Model
            Model

        See also
        --------
        torch.load
        """
        checkpoint = str(checkpoint)
        if hparams_file is not None:
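
A usage sketch for the loader; the local path and hub identifier below are placeholders, not published checkpoints:

    # load from a local Lightning checkpoint
    model = Model.from_pretrained("path/to/model.ckpt")

    # or from the huggingface.co hub, with an auth token for private models
    model = Model.from_pretrained(
        "username/enhancer-model",  # hypothetical model identifier
        use_auth_token="hf_xxx",    # placeholder token from `huggingface-cli login`
    )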
@@ -183,8 +253,11 @@ class Model(pl.LightningModule):
                model_id, filename=HF_TORCH_WEIGHTS, revision=revision_id
            )
            model_path_pl = cached_download(
                url=url,
                library_name="enhancer",
                library_version=__version__,
                cache_dir=cached_dir,
                use_auth_token=use_auth_token,
            )

        if map_location is None:
@@ -202,23 +275,34 @@ class Model(pl.LightningModule):
                map_location=map_location,
                hparams_file=hparams_file,
                strict=strict,
                **kwargs,
            )
        except Exception as e:
            print(e)

        return model

    def infer(self, batch: torch.Tensor, batch_size: int = 32):
        """
        Perform model inference

        parameters:
        batch : torch.Tensor
            input data
        batch_size : int, default 32
            batch size for inference
        """
        assert (
            batch.ndim == 3
        ), f"Expected batch with 3 dimensions (batch, channels, samples), got only {batch.ndim}"
        batch_predictions = []
        self.eval().to(self.device)
        with torch.no_grad():
            for batch_id in range(0, batch.shape[0], batch_size):
                batch_data = batch[batch_id : batch_id + batch_size, :, :].to(
                    self.device
                )
                prediction = self(batch_data)
                batch_predictions.append(prediction)
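
The assert above pins the expected input layout; a small sketch with illustrative shapes:

    import torch

    # 10 mono chunks of 64000 samples each: (batch, channels, samples)
    chunks = torch.randn(10, 1, 64000)
    enhanced = model.infer(chunks, batch_size=4)  # 3 forward passes: ceil(10 / 4)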
@@ -231,41 +315,56 @@ class Model(pl.LightningModule):
        batch_size: int = 32,
        save_output: bool = False,
        duration: Optional[int] = None,
        step_size: Optional[int] = None,
    ):
        """
        Enhance audio using the loaded pretrained model.

        parameters:
        audio: Path to audio file, numpy array, or torch tensor
            single input audio
        sampling_rate: int, optional when the input is a file path
            sampling rate of input
        batch_size: int, default 32
            input audio is split into multiple chunks. Inference is done on batches
            of these chunks according to the given batch size.
        save_output : bool, default False
            whether to save output to file
        duration : float, optional
            chunk duration in seconds, defaults to duration of loaded pretrained model.
        step_size: int, optional
            step size between consecutive chunks, defaults to 50% of duration
        """
        model_sampling_rate = self.hparams["sampling_rate"]
        if duration is None:
            duration = self.hparams["duration"]
        waveform = Inference.read_input(
            audio, sampling_rate, model_sampling_rate
        )
        waveform.to(self.device)
        window_size = round(duration * model_sampling_rate)
        batched_waveform = Inference.batchify(
            waveform, window_size, step_size=step_size
        )
        batch_prediction = self.infer(batched_waveform, batch_size=batch_size)
        waveform = Inference.aggreagate(
            batch_prediction,
            window_size,
            waveform.shape[-1],
            step_size,
        )
        if save_output and isinstance(audio, (str, Path)):
            Inference.write_output(waveform, audio, model_sampling_rate)
        else:
            waveform = Inference.prepare_output(
                waveform, model_sampling_rate, audio, sampling_rate
            )

        return waveform

    @property
    def valid_monitor(self):
        return "max" if self.loss.higher_better else "min"