From b55e12d15c7e8e054d383b168682c15bb08cee3b Mon Sep 17 00:00:00 2001 From: shahules786 Date: Tue, 27 Sep 2022 11:31:40 +0530 Subject: [PATCH] add logging --- enhancer/models/model.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/enhancer/models/model.py b/enhancer/models/model.py index 65946a2..c4be077 100644 --- a/enhancer/models/model.py +++ b/enhancer/models/model.py @@ -38,6 +38,8 @@ class Model(pl.LightningModule): assert num_channels ==1 , "Enhancer only support for mono channel models" self.dataset = dataset self.save_hyperparameters("num_channels","sampling_rate","lr","loss","metric","duration") + if self.logger: + self.logger.experiment.log_dict(dict(self.hparams),"hyperparameters.json") @property @@ -79,6 +81,9 @@ class Model(pl.LightningModule): loss = self.loss(prediction, target) + if self.logger: + self.logger.experiment.log_metrics({"train_loss":loss.item()}, step=self.global_step) + return {"loss":loss} def validation_step(self,batch,batch_idx:int): @@ -88,6 +93,8 @@ class Model(pl.LightningModule): prediction = self(mixed_waveform) loss = self.metric(prediction, target) + if self.logger: + self.logger.experiment.log_metrics({"val_loss":loss.item()}, step=self.global_step) return {"loss":loss}