log metric
This commit is contained in:
parent
c1b67c1e3a
commit
4e033d2ab5
|
|
@ -38,9 +38,6 @@ class Model(pl.LightningModule):
|
|||
super().__init__()
|
||||
assert num_channels ==1 , "Enhancer only support for mono channel models"
|
||||
self.dataset = dataset
|
||||
if self.dataset is not None:
|
||||
sampling_rate = self.dataset.sampling_rate
|
||||
logging.warn("Setting model sampling rate same as dataset sampling rate")
|
||||
self.save_hyperparameters("num_channels","sampling_rate","lr","loss","metric","duration")
|
||||
if self.logger:
|
||||
self.logger.experiment.log_dict(dict(self.hparams),"hyperparameters.json")
|
||||
|
|
@ -86,7 +83,7 @@ class Model(pl.LightningModule):
|
|||
loss = self.loss(prediction, target)
|
||||
|
||||
if self.logger:
|
||||
self.logger.experiment.log_metrics({"train_loss":loss.item()}, step=self.global_step)
|
||||
self.logger.experiment.log_metric("train_loss",loss.item(), step=self.global_step)
|
||||
|
||||
return {"loss":loss}
|
||||
|
||||
|
|
@ -98,7 +95,7 @@ class Model(pl.LightningModule):
|
|||
|
||||
loss = self.metric(prediction, target)
|
||||
if self.logger:
|
||||
self.logger.experiment.log_metrics({"val_loss":loss.item()}, step=self.global_step)
|
||||
self.logger.experiment.log_metric("val_loss",loss.item(), step=self.global_step)
|
||||
|
||||
return {"loss":loss}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue