set cost property

This commit is contained in:
shahules786 2022-09-30 15:19:43 +05:30
parent 1f4947103f
commit e2f570a8d1
1 changed files with 39 additions and 12 deletions

View File

@ -1,3 +1,7 @@
try:
from functools import cached_property
except ImportError:
from backports.cached_property import cached_property
from importlib import import_module
from huggingface_hub import cached_download, hf_hub_url
import logging
@ -42,7 +46,34 @@ class Model(pl.LightningModule):
if self.logger:
self.logger.experiment.log_dict(dict(self.hparams),"hyperparameters.json")
self.loss = loss
self.metric = metric
@property
def loss(self):
return self._loss
@loss.setter
def loss(self,loss):
if isinstance(loss,str):
losses = [loss]
self._loss = Avergeloss(losses)
@property
def metric(self):
return self._metric
@metric.setter
def metric(self,metric):
if isinstance(metric,str):
metric = [metric]
self._metric = Avergeloss(metric)
@property
def dataset(self):
return self._dataset
@ -55,16 +86,7 @@ class Model(pl.LightningModule):
if stage == "fit":
self.dataset.setup(stage)
self.dataset.model = self
self.loss = self.setup_loss(self.hparams.loss)
self.metric = self.setup_loss(self.hparams.metric)
def setup_loss(self,loss):
if isinstance(loss,str):
losses = [loss]
return Avergeloss(losses)
def train_dataloader(self):
return self.dataset.train_dataloader()
@ -224,7 +246,12 @@ class Model(pl.LightningModule):
Inference.write_output(waveform,audio,model_sampling_rate)
else:
return waveform
return waveform
@property
def valid_monitor(self):
return "max" if self.loss.higher_better else "min"