add logging

This commit is contained in:
shahules786 2022-09-27 11:31:40 +05:30
parent 34755f33aa
commit b55e12d15c
1 changed files with 7 additions and 0 deletions

View File

@ -38,6 +38,8 @@ class Model(pl.LightningModule):
assert num_channels ==1 , "Enhancer only support for mono channel models"
self.dataset = dataset
self.save_hyperparameters("num_channels","sampling_rate","lr","loss","metric","duration")
if self.logger:
self.logger.experiment.log_dict(dict(self.hparams),"hyperparameters.json")
@property
@ -79,6 +81,9 @@ class Model(pl.LightningModule):
loss = self.loss(prediction, target)
if self.logger:
self.logger.experiment.log_metrics({"train_loss":loss.item()}, step=self.global_step)
return {"loss":loss}
def validation_step(self,batch,batch_idx:int):
@ -88,6 +93,8 @@ class Model(pl.LightningModule):
prediction = self(mixed_waveform)
loss = self.metric(prediction, target)
if self.logger:
self.logger.experiment.log_metrics({"val_loss":loss.item()}, step=self.global_step)
return {"loss":loss}