log metric

shahules786 2022-09-29 17:08:49 +05:30
parent c1b67c1e3a
commit 4e033d2ab5
1 changed file with 2 additions and 5 deletions

@@ -38,9 +38,6 @@ class Model(pl.LightningModule):
         super().__init__()
         assert num_channels == 1, "Enhancer only supports mono channel models"
         self.dataset = dataset
-        if self.dataset is not None:
-            sampling_rate = self.dataset.sampling_rate
-            logging.warn("Setting model sampling rate same as dataset sampling rate")
         self.save_hyperparameters("num_channels", "sampling_rate", "lr", "loss", "metric", "duration")
         if self.logger:
             self.logger.experiment.log_dict(dict(self.hparams), "hyperparameters.json")
@@ -86,7 +83,7 @@ class Model(pl.LightningModule):
         loss = self.loss(prediction, target)
         if self.logger:
-            self.logger.experiment.log_metrics({"train_loss": loss.item()}, step=self.global_step)
+            self.logger.experiment.log_metric("train_loss", loss.item(), step=self.global_step)
         return {"loss": loss}
@@ -98,7 +95,7 @@ class Model(pl.LightningModule):
         loss = self.metric(prediction, target)
         if self.logger:
-            self.logger.experiment.log_metrics({"val_loss": loss.item()}, step=self.global_step)
+            self.logger.experiment.log_metric("val_loss", loss.item(), step=self.global_step)
         return {"loss": loss}