diff --git a/enhancer/models/model.py b/enhancer/models/model.py index 65946a2..c4be077 100644 --- a/enhancer/models/model.py +++ b/enhancer/models/model.py @@ -38,6 +38,8 @@ class Model(pl.LightningModule): assert num_channels ==1 , "Enhancer only support for mono channel models" self.dataset = dataset self.save_hyperparameters("num_channels","sampling_rate","lr","loss","metric","duration") + if self.logger: + self.logger.experiment.log_dict(dict(self.hparams),"hyperparameters.json") @property @@ -79,6 +81,9 @@ class Model(pl.LightningModule): loss = self.loss(prediction, target) + if self.logger: + self.logger.experiment.log_metrics({"train_loss":loss.item()}, step=self.global_step) + return {"loss":loss} def validation_step(self,batch,batch_idx:int): @@ -88,6 +93,8 @@ class Model(pl.LightningModule): prediction = self(mixed_waveform) loss = self.metric(prediction, target) + if self.logger: + self.logger.experiment.log_metrics({"val_loss":loss.item()}, step=self.global_step) return {"loss":loss}