diff --git a/enhancer/models/model.py b/enhancer/models/model.py index 8e607ed..b1bdd86 100644 --- a/enhancer/models/model.py +++ b/enhancer/models/model.py @@ -1,3 +1,7 @@ +try: + from functools import cached_property +except ImportError: + from backports.cached_property import cached_property from importlib import import_module from huggingface_hub import cached_download, hf_hub_url import logging @@ -42,7 +46,34 @@ class Model(pl.LightningModule): if self.logger: self.logger.experiment.log_dict(dict(self.hparams),"hyperparameters.json") - + self.loss = loss + self.metric = metric + + @property + def loss(self): + return self._loss + + @loss.setter + def loss(self,loss): + + if isinstance(loss,str): + losses = [loss] + + self._loss = Avergeloss(losses) + + @property + def metric(self): + return self._metric + + @metric.setter + def metric(self,metric): + + if isinstance(metric,str): + metric = [metric] + + self._metric = Avergeloss(metric) + + @property def dataset(self): return self._dataset @@ -55,16 +86,7 @@ class Model(pl.LightningModule): if stage == "fit": self.dataset.setup(stage) self.dataset.model = self - self.loss = self.setup_loss(self.hparams.loss) - self.metric = self.setup_loss(self.hparams.metric) - - def setup_loss(self,loss): - - if isinstance(loss,str): - losses = [loss] - - return Avergeloss(losses) - + def train_dataloader(self): return self.dataset.train_dataloader() @@ -224,7 +246,12 @@ class Model(pl.LightningModule): Inference.write_output(waveform,audio,model_sampling_rate) else: - return waveform + return waveform + + @property + def valid_monitor(self): + + return "max" if self.loss.higher_better else "min"