set cost property
This commit is contained in:
parent
1f4947103f
commit
e2f570a8d1
|
|
@ -1,3 +1,7 @@
|
||||||
|
try:
|
||||||
|
from functools import cached_property
|
||||||
|
except ImportError:
|
||||||
|
from backports.cached_property import cached_property
|
||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
from huggingface_hub import cached_download, hf_hub_url
|
from huggingface_hub import cached_download, hf_hub_url
|
||||||
import logging
|
import logging
|
||||||
|
|
@ -42,7 +46,34 @@ class Model(pl.LightningModule):
|
||||||
if self.logger:
|
if self.logger:
|
||||||
self.logger.experiment.log_dict(dict(self.hparams),"hyperparameters.json")
|
self.logger.experiment.log_dict(dict(self.hparams),"hyperparameters.json")
|
||||||
|
|
||||||
|
self.loss = loss
|
||||||
|
self.metric = metric
|
||||||
|
|
||||||
|
@property
|
||||||
|
def loss(self):
|
||||||
|
return self._loss
|
||||||
|
|
||||||
|
@loss.setter
|
||||||
|
def loss(self,loss):
|
||||||
|
|
||||||
|
if isinstance(loss,str):
|
||||||
|
losses = [loss]
|
||||||
|
|
||||||
|
self._loss = Avergeloss(losses)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def metric(self):
|
||||||
|
return self._metric
|
||||||
|
|
||||||
|
@metric.setter
|
||||||
|
def metric(self,metric):
|
||||||
|
|
||||||
|
if isinstance(metric,str):
|
||||||
|
metric = [metric]
|
||||||
|
|
||||||
|
self._metric = Avergeloss(metric)
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dataset(self):
|
def dataset(self):
|
||||||
return self._dataset
|
return self._dataset
|
||||||
|
|
@ -55,16 +86,7 @@ class Model(pl.LightningModule):
|
||||||
if stage == "fit":
|
if stage == "fit":
|
||||||
self.dataset.setup(stage)
|
self.dataset.setup(stage)
|
||||||
self.dataset.model = self
|
self.dataset.model = self
|
||||||
self.loss = self.setup_loss(self.hparams.loss)
|
|
||||||
self.metric = self.setup_loss(self.hparams.metric)
|
|
||||||
|
|
||||||
def setup_loss(self,loss):
|
|
||||||
|
|
||||||
if isinstance(loss,str):
|
|
||||||
losses = [loss]
|
|
||||||
|
|
||||||
return Avergeloss(losses)
|
|
||||||
|
|
||||||
def train_dataloader(self):
|
def train_dataloader(self):
|
||||||
return self.dataset.train_dataloader()
|
return self.dataset.train_dataloader()
|
||||||
|
|
||||||
|
|
@ -224,7 +246,12 @@ class Model(pl.LightningModule):
|
||||||
Inference.write_output(waveform,audio,model_sampling_rate)
|
Inference.write_output(waveform,audio,model_sampling_rate)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
return waveform
|
return waveform
|
||||||
|
|
||||||
|
@property
|
||||||
|
def valid_monitor(self):
|
||||||
|
|
||||||
|
return "max" if self.loss.higher_better else "min"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue