set cost property

This commit is contained in:
shahules786 2022-09-30 15:19:43 +05:30
parent 1f4947103f
commit e2f570a8d1
1 changed files with 39 additions and 12 deletions

View File

@ -1,3 +1,7 @@
try:
from functools import cached_property
except ImportError:
from backports.cached_property import cached_property
from importlib import import_module from importlib import import_module
from huggingface_hub import cached_download, hf_hub_url from huggingface_hub import cached_download, hf_hub_url
import logging import logging
@ -42,6 +46,33 @@ class Model(pl.LightningModule):
if self.logger: if self.logger:
self.logger.experiment.log_dict(dict(self.hparams),"hyperparameters.json") self.logger.experiment.log_dict(dict(self.hparams),"hyperparameters.json")
self.loss = loss
self.metric = metric
@property
def loss(self):
return self._loss
@loss.setter
def loss(self,loss):
if isinstance(loss,str):
losses = [loss]
self._loss = Avergeloss(losses)
@property
def metric(self):
return self._metric
@metric.setter
def metric(self,metric):
if isinstance(metric,str):
metric = [metric]
self._metric = Avergeloss(metric)
@property @property
def dataset(self): def dataset(self):
@ -55,15 +86,6 @@ class Model(pl.LightningModule):
if stage == "fit": if stage == "fit":
self.dataset.setup(stage) self.dataset.setup(stage)
self.dataset.model = self self.dataset.model = self
self.loss = self.setup_loss(self.hparams.loss)
self.metric = self.setup_loss(self.hparams.metric)
def setup_loss(self,loss):
if isinstance(loss,str):
losses = [loss]
return Avergeloss(losses)
def train_dataloader(self): def train_dataloader(self):
return self.dataset.train_dataloader() return self.dataset.train_dataloader()
@ -226,6 +248,11 @@ class Model(pl.LightningModule):
else: else:
return waveform return waveform
@property
def valid_monitor(self):
return "max" if self.loss.higher_better else "min"