Merge pull request #7 from shahules786/dev-reformat

refactor model.py
Shahul ES 2022-10-05 12:37:57 +05:30 committed by GitHub
commit 07e6aa6e24
1 changed file with 204 additions and 105 deletions


@@ -27,66 +27,89 @@ CACHE_DIR = ""
 HF_TORCH_WEIGHTS = ""
 DEFAULT_DEVICE = "cpu"


 class Model(pl.LightningModule):
     """
     Base class for all models

     parameters:
         num_channels : int, defaults to 1
             number of channels in the input audio
         sampling_rate : int, defaults to 16000 (16 kHz)
             audio sampling rate
         lr : float, optional
             learning rate for model training
         dataset : EnhancerDataset, optional
             Enhancer dataset used for training/validation
         duration : float, optional
             duration used for training/inference
         loss : str or list of str, defaults to "mse"
             loss functions to be used; available: "mse", "mae", "Si-SDR"
         metric : str or list of str, defaults to "mse"
             validation metric(s) to be used
     """

     def __init__(
         self,
-        num_channels:int=1,
-        sampling_rate:int=16000,
-        lr:float=1e-3,
-        dataset:Optional[EnhancerDataset]=None,
-        duration:Optional[float]=None,
+        num_channels: int = 1,
+        sampling_rate: int = 16000,
+        lr: float = 1e-3,
+        dataset: Optional[EnhancerDataset] = None,
+        duration: Optional[float] = None,
         loss: Union[str, List] = "mse",
-        metric:Union[str,List] = "mse"
+        metric: Union[str, List] = "mse",
     ):
         super().__init__()
-        assert num_channels ==1 , "Enhancer only support for mono channel models"
+        assert (
+            num_channels == 1
+        ), "Enhancer only support for mono channel models"
         self.dataset = dataset
-        self.save_hyperparameters("num_channels","sampling_rate","lr","loss","metric","duration")
+        self.save_hyperparameters(
+            "num_channels", "sampling_rate", "lr", "loss", "metric", "duration"
+        )
         if self.logger:
-            self.logger.experiment.log_dict(dict(self.hparams),"hyperparameters.json")
+            self.logger.experiment.log_dict(
+                dict(self.hparams), "hyperparameters.json"
+            )
         self.loss = loss
         self.metric = metric

     @property
     def loss(self):
         return self._loss

-    @loss.setter
-    def loss(self,loss):
-        if isinstance(loss,str):
-            losses = [loss]
+    @loss.setter
+    def loss(self, loss):
+        if isinstance(loss, str):
+            losses = [loss]
         self._loss = Avergeloss(losses)

     @property
     def metric(self):
         return self._metric

     @metric.setter
-    def metric(self,metric):
-        if isinstance(metric,str):
-            metric = [metric]
+    def metric(self, metric):
+        if isinstance(metric, str):
+            metric = [metric]
         self._metric = Avergeloss(metric)

     @property
     def dataset(self):
         return self._dataset

     @dataset.setter
-    def dataset(self,dataset):
+    def dataset(self, dataset):
         self._dataset = dataset

-    def setup(self,stage:Optional[str]=None):
+    def setup(self, stage: Optional[str] = None):
         if stage == "fit":
             self.dataset.setup(stage)
             self.dataset.model = self

     def train_dataloader(self):
         return self.dataset.train_dataloader()
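For orientation, here is how the constructor and the loss/metric setters above fit together, as a minimal sketch; the `Demucs` subclass and the `EnhancerDataset` arguments are assumptions for illustration, not part of this diff:

    # Sketch only: configuring a hypothetical concrete subclass of Model.
    from enhancer.data import EnhancerDataset
    from enhancer.models import Demucs  # assumed concrete subclass

    dataset = EnhancerDataset(name="vctk", root_dir="/data/vctk", duration=1.0)  # hypothetical args
    model = Demucs(
        num_channels=1,       # must be 1; the assert above enforces mono audio
        sampling_rate=16000,
        lr=1e-3,
        dataset=dataset,
        loss=["mse", "mae"],  # a plain str is wrapped into a list by the setter
        metric="mse",         # both then pass through Avergeloss
    )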
@@ -94,9 +117,9 @@ class Model(pl.LightningModule):
         return self.dataset.val_dataloader()

     def configure_optimizers(self):
-        return Adam(self.parameters(), lr = self.hparams.lr)
+        return Adam(self.parameters(), lr=self.hparams.lr)

-    def training_step(self,batch, batch_idx:int):
+    def training_step(self, batch, batch_idx: int):
         mixed_waveform = batch["noisy"]
         target = batch["clean"]
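`configure_optimizers` returns a bare Adam optimizer. If a subclass wanted a learning-rate scheduler, the standard Lightning pattern would look like the following sketch (not part of this commit; note how `valid_monitor`, defined at the bottom of the file, supplies the mode):

    # Sketch only: an override adding an LR scheduler in a subclass.
    from torch.optim import Adam
    from torch.optim.lr_scheduler import ReduceLROnPlateau

    def configure_optimizers(self):
        optimizer = Adam(self.parameters(), lr=self.hparams.lr)
        scheduler = ReduceLROnPlateau(optimizer, mode=self.valid_monitor)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "monitor": "val_loss"},
        }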
@@ -105,13 +128,16 @@ class Model(pl.LightningModule):
         loss = self.loss(prediction, target)

         if self.logger:
-            self.logger.experiment.log_metric(run_id=self.logger.run_id,
-                                              key="train_loss", value=loss.item(),
-                                              step=self.global_step)
-        self.log("train_loss",loss.item())
-        return {"loss":loss}
+            self.logger.experiment.log_metric(
+                run_id=self.logger.run_id,
+                key="train_loss",
+                value=loss.item(),
+                step=self.global_step,
+            )
+        self.log("train_loss", loss.item())
+        return {"loss": loss}

-    def validation_step(self,batch,batch_idx:int):
+    def validation_step(self, batch, batch_idx: int):
         mixed_waveform = batch["noisy"]
         target = batch["clean"]
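The `log_metric` calls above follow the MlflowClient signature (`log_metric(run_id, key, value, step=...)`), which is what `self.logger.experiment` resolves to when Lightning's MLFlowLogger is attached. A sketch of that wiring (experiment name and tracking URI are assumptions):

    # Sketch only: attaching MLFlowLogger so self.logger.experiment is an MlflowClient.
    import pytorch_lightning as pl
    from pytorch_lightning.loggers import MLFlowLogger

    mlf_logger = MLFlowLogger(experiment_name="enhancer", tracking_uri="file:./mlruns")
    trainer = pl.Trainer(max_epochs=1, logger=mlf_logger)
    # trainer.fit(model)  # metrics then flow both through self.log and log_metric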
@@ -119,48 +145,92 @@ class Model(pl.LightningModule):
         metric_val = self.metric(prediction, target)
         loss_val = self.loss(prediction, target)
-        self.log("val_metric",metric_val.item())
-        self.log("val_loss",loss_val.item())
+        self.log("val_metric", metric_val.item())
+        self.log("val_loss", loss_val.item())

         if self.logger:
-            self.logger.experiment.log_metric(run_id=self.logger.run_id,
-                                              key="val_loss",value=loss_val.item(),
-                                              step=self.global_step)
-            self.logger.experiment.log_metric(run_id=self.logger.run_id,
-                                              key="val_metric",value=metric_val.item(),
-                                              step=self.global_step)
+            self.logger.experiment.log_metric(
+                run_id=self.logger.run_id,
+                key="val_loss",
+                value=loss_val.item(),
+                step=self.global_step,
+            )
+            self.logger.experiment.log_metric(
+                run_id=self.logger.run_id,
+                key="val_metric",
+                value=metric_val.item(),
+                step=self.global_step,
+            )

-        return {"loss":loss_val}
+        return {"loss": loss_val}

     def on_save_checkpoint(self, checkpoint):
         checkpoint["enhancer"] = {
-            "version": {
-                "enhancer":__version__,
-                "pytorch":torch.__version__
+            "version": {"enhancer": __version__, "pytorch": torch.__version__},
+            "architecture": {
+                "module": self.__class__.__module__,
+                "class": self.__class__.__name__,
             },
-            "architecture":{
-                "module":self.__class__.__module__,
-                "class":self.__class__.__name__
-            }
         }

     def on_load_checkpoint(self, checkpoint: Dict[str, Any]):
         pass

     @classmethod
     def from_pretrained(
         cls,
         checkpoint: Union[Path, Text],
-        map_location = None,
+        map_location=None,
         hparams_file: Union[Path, Text] = None,
         strict: bool = True,
         use_auth_token: Union[Text, None] = None,
-        cached_dir: Union[Path, Text]=CACHE_DIR,
-        **kwargs
+        cached_dir: Union[Path, Text] = CACHE_DIR,
+        **kwargs,
     ):
         """
         Load a pretrained model

         parameters:
             checkpoint : Path or str
                 Path to checkpoint, or a remote URL, or a model identifier from
                 the huggingface.co model hub.
             map_location : optional
                 Same role as in torch.load().
                 If None, defaults to torch.device(DEFAULT_DEVICE).
             hparams_file : Path or str, optional
                 Path to a .yaml file with hierarchical structure as in this example:
                     drop_prob: 0.2
                     dataloader:
                         batch_size: 32
                 You most likely won't need this since Lightning will always save the
                 hyperparameters to the checkpoint. However, if your checkpoint weights
                 do not have the hyperparameters saved, use this argument to pass in a
                 .yaml file with the hparams you would like to use. These will be
                 converted into a dict and passed into your Model for use.
             strict : bool, optional
                 Whether to strictly enforce that the keys in the checkpoint match
                 the keys returned by this module's state dict. Defaults to True.
             use_auth_token : str, optional
                 When loading a private huggingface.co model, set `use_auth_token`
                 to True or to a string containing your huggingface.co authentication
                 token, which can be obtained by running `huggingface-cli login`.
             cached_dir : Path or str, optional
                 Path to the model cache directory. Defaults to CACHE_DIR.
             kwargs : optional
                 Any extra keyword args needed to init the model.
                 Can also be used to override saved hyperparameter values.

         Returns
         -------
         model : Model
             the loaded model

         See also
         --------
         torch.load
         """
         checkpoint = str(checkpoint)
         if hparams_file is not None:
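Per the docstring, `from_pretrained` accepts a local path, an http(s) URL, or a hub identifier with an optional `@revision` suffix, as the branching below implements. A sketch of all three forms (the URL and hub identifier are hypothetical):

    # Sketch only: the three checkpoint forms from_pretrained accepts.
    model = Model.from_pretrained("last.ckpt")                        # local file
    model = Model.from_pretrained("https://example.com/last.ckpt")    # remote URL
    model = Model.from_pretrained("someuser/some-model@main")         # hypothetical hub id@revision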
@@ -168,104 +238,133 @@ class Model(pl.LightningModule):
         if os.path.isfile(checkpoint):
             model_path_pl = checkpoint
-        elif urlparse(checkpoint).scheme in ("http","https"):
+        elif urlparse(checkpoint).scheme in ("http", "https"):
             model_path_pl = checkpoint
         else:
             if "@" in checkpoint:
                 model_id = checkpoint.split("@")[0]
                 revision_id = checkpoint.split("@")[1]
             else:
                 model_id = checkpoint
                 revision_id = None

             url = hf_hub_url(
-                model_id,filename=HF_TORCH_WEIGHTS,revision=revision_id
+                model_id, filename=HF_TORCH_WEIGHTS, revision=revision_id
             )
             model_path_pl = cached_download(
-                url=url,library_name="enhancer",library_version=__version__,
-                cache_dir=cached_dir,use_auth_token=use_auth_token
+                url=url,
+                library_name="enhancer",
+                library_version=__version__,
+                cache_dir=cached_dir,
+                use_auth_token=use_auth_token,
             )

         if map_location is None:
             map_location = torch.device(DEFAULT_DEVICE)

-        loaded_checkpoint = pl_load(model_path_pl,map_location)
+        loaded_checkpoint = pl_load(model_path_pl, map_location)
         module_name = loaded_checkpoint["enhancer"]["architecture"]["module"]
-        class_name = loaded_checkpoint["enhancer"]["architecture"]["class"]
+        class_name = loaded_checkpoint["enhancer"]["architecture"]["class"]
         module = import_module(module_name)
         Klass = getattr(module, class_name)

         try:
             model = Klass.load_from_checkpoint(
-                checkpoint_path = model_path_pl,
-                map_location = map_location,
-                hparams_file = hparams_file,
-                strict = strict,
-                **kwargs
+                checkpoint_path=model_path_pl,
+                map_location=map_location,
+                hparams_file=hparams_file,
+                strict=strict,
+                **kwargs,
             )
         except Exception as e:
             print(e)

-        return model
+        return model

-    def infer(self,batch:torch.Tensor,batch_size:int=32):
+    def infer(self, batch: torch.Tensor, batch_size: int = 32):
         """
         Perform model inference

         parameters:
             batch : torch.Tensor
                 input data
             batch_size : int, default 32
                 batch size for inference
         """
-        assert batch.ndim == 3, f"Expected batch with 3 dimensions (batch,channels,samples) got only {batch.ndim}"
+        assert (
+            batch.ndim == 3
+        ), f"Expected batch with 3 dimensions (batch,channels,samples) got only {batch.ndim}"
         batch_predictions = []
         self.eval().to(self.device)
         with torch.no_grad():
-            for batch_id in range(0,batch.shape[0],batch_size):
-                batch_data = batch[batch_id:batch_id+batch_size,:,:].to(self.device)
+            for batch_id in range(0, batch.shape[0], batch_size):
+                batch_data = batch[batch_id : batch_id + batch_size, :, :].to(
+                    self.device
+                )
                 prediction = self(batch_data)
                 batch_predictions.append(prediction)

         return torch.vstack(batch_predictions)
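A quick sanity check of `infer`'s contract, as a sketch with random data (shapes chosen for illustration):

    # Sketch only: infer expects (batch, channels, samples) and restacks predictions.
    import torch

    batch = torch.randn(8, 1, 16000)             # ndim must be 3, mono channel
    enhanced = model.infer(batch, batch_size=4)  # processed in chunks of 4
    assert enhanced.shape[0] == 8                # vstack restores the batch dimension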
     def enhance(
         self,
-        audio:Union[Path,np.ndarray,torch.Tensor],
-        sampling_rate:Optional[int]=None,
-        batch_size:int=32,
-        save_output:bool=False,
-        duration:Optional[int]=None,
-        step_size:Optional[int]=None,):
+        audio: Union[Path, np.ndarray, torch.Tensor],
+        sampling_rate: Optional[int] = None,
+        batch_size: int = 32,
+        save_output: bool = False,
+        duration: Optional[int] = None,
+        step_size: Optional[int] = None,
+    ):
         """
         Enhance audio using the loaded pretrained model.

         parameters:
             audio : Path, np.ndarray, or torch.Tensor
                 single input audio
             sampling_rate : int, optional
                 sampling rate of the input; required in case the input is a path
             batch_size : int, default 32
                 input audio is split into multiple chunks; inference is done on
                 batches of these chunks according to the given batch size
             save_output : bool, default False
                 whether to save the output to a file
             duration : float, optional
                 chunk duration in seconds; defaults to the duration of the loaded
                 pretrained model
             step_size : int, optional
                 step size between consecutive chunks; defaults to 50% of duration
         """
         model_sampling_rate = self.hparams["sampling_rate"]
         if duration is None:
             duration = self.hparams["duration"]
-        waveform = Inference.read_input(audio,sampling_rate,model_sampling_rate)
+        waveform = Inference.read_input(
+            audio, sampling_rate, model_sampling_rate
+        )
         waveform.to(self.device)
         window_size = round(duration * model_sampling_rate)
-        batched_waveform = Inference.batchify(waveform,window_size,step_size=step_size)
-        batch_prediction = self.infer(batched_waveform,batch_size=batch_size)
-        waveform = Inference.aggreagate(batch_prediction,window_size,waveform.shape[-1],step_size,)
-        if save_output and isinstance(audio,(str,Path)):
-            Inference.write_output(waveform,audio,model_sampling_rate)
+        batched_waveform = Inference.batchify(
+            waveform, window_size, step_size=step_size
+        )
+        batch_prediction = self.infer(batched_waveform, batch_size=batch_size)
+        waveform = Inference.aggreagate(
+            batch_prediction,
+            window_size,
+            waveform.shape[-1],
+            step_size,
+        )
+        if save_output and isinstance(audio, (str, Path)):
+            Inference.write_output(waveform, audio, model_sampling_rate)
         else:
-            waveform = Inference.prepare_output(waveform, model_sampling_rate,
-                                                audio, sampling_rate)
+            waveform = Inference.prepare_output(
+                waveform, model_sampling_rate, audio, sampling_rate
+            )
         return waveform

     @property
     def valid_monitor(self):
-        return "max" if self.loss.higher_better else "min"
+        return "max" if self.loss.higher_better else "min"