log metric
parent c1b67c1e3a
commit 4e033d2ab5
@@ -38,9 +38,6 @@ class Model(pl.LightningModule):
         super().__init__()
         assert num_channels ==1 , "Enhancer only support for mono channel models"
         self.dataset = dataset
-        if self.dataset is not None:
-            sampling_rate = self.dataset.sampling_rate
-            logging.warn("Setting model sampling rate same as dataset sampling rate")
         self.save_hyperparameters("num_channels","sampling_rate","lr","loss","metric","duration")
         if self.logger:
             self.logger.experiment.log_dict(dict(self.hparams),"hyperparameters.json")
@@ -86,7 +83,7 @@ class Model(pl.LightningModule):
         loss = self.loss(prediction, target)

         if self.logger:
-            self.logger.experiment.log_metrics({"train_loss":loss.item()}, step=self.global_step)
+            self.logger.experiment.log_metric("train_loss",loss.item(), step=self.global_step)

         return {"loss":loss}

@@ -98,7 +95,7 @@ class Model(pl.LightningModule):

         loss = self.metric(prediction, target)
         if self.logger:
-            self.logger.experiment.log_metrics({"val_loss":loss.item()}, step=self.global_step)
+            self.logger.experiment.log_metric("val_loss",loss.item(), step=self.global_step)

         return {"loss":loss}

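Aside from dropping the dataset-driven sampling-rate override in __init__, the only behavioural change in the training and validation steps is the call shape used on the logger's experiment object: log_metrics took a dict of values, while log_metric takes a single name/value pair plus a step. The stub below is a hypothetical stand-in, not this project's actual logger API, and only sketches the two signatures exactly as they appear in the hunks above.

# Hypothetical stand-in for whatever object self.logger.experiment returns
# in this project; it only illustrates the two call shapes in the diff.
class ExperimentStub:
    def log_metrics(self, metrics, step=None):
        # old call shape: a dict mapping metric name -> value
        for name, value in metrics.items():
            print(f"step={step} {name}={value}")

    def log_metric(self, name, value, step=None):
        # new call shape: one metric name and one value per call
        print(f"step={step} {name}={value}")

exp = ExperimentStub()
exp.log_metrics({"train_loss": 0.42}, step=100)  # before this commit
exp.log_metric("train_loss", 0.42, step=100)     # after this commit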