From 4e033d2ab5cc3dd1ccdb331f942a765aa8959a5a Mon Sep 17 00:00:00 2001 From: shahules786 Date: Thu, 29 Sep 2022 17:08:49 +0530 Subject: [PATCH] log metric --- enhancer/models/model.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/enhancer/models/model.py b/enhancer/models/model.py index 980c583..64bf201 100644 --- a/enhancer/models/model.py +++ b/enhancer/models/model.py @@ -38,9 +38,6 @@ class Model(pl.LightningModule): super().__init__() assert num_channels ==1 , "Enhancer only support for mono channel models" self.dataset = dataset - if self.dataset is not None: - sampling_rate = self.dataset.sampling_rate - logging.warn("Setting model sampling rate same as dataset sampling rate") self.save_hyperparameters("num_channels","sampling_rate","lr","loss","metric","duration") if self.logger: self.logger.experiment.log_dict(dict(self.hparams),"hyperparameters.json") @@ -86,7 +83,7 @@ class Model(pl.LightningModule): loss = self.loss(prediction, target) if self.logger: - self.logger.experiment.log_metrics({"train_loss":loss.item()}, step=self.global_step) + self.logger.experiment.log_metric("train_loss",loss.item(), step=self.global_step) return {"loss":loss} @@ -98,7 +95,7 @@ class Model(pl.LightningModule): loss = self.metric(prediction, target) if self.logger: - self.logger.experiment.log_metrics({"val_loss":loss.item()}, step=self.global_step) + self.logger.experiment.log_metric("val_loss",loss.item(), step=self.global_step) return {"loss":loss}