add logging
This commit is contained in:
parent
34755f33aa
commit
b55e12d15c
|
|
@ -38,6 +38,8 @@ class Model(pl.LightningModule):
|
|||
assert num_channels ==1 , "Enhancer only support for mono channel models"
|
||||
self.dataset = dataset
|
||||
self.save_hyperparameters("num_channels","sampling_rate","lr","loss","metric","duration")
|
||||
if self.logger:
|
||||
self.logger.experiment.log_dict(dict(self.hparams),"hyperparameters.json")
|
||||
|
||||
|
||||
@property
|
||||
|
|
@ -79,6 +81,9 @@ class Model(pl.LightningModule):
|
|||
|
||||
loss = self.loss(prediction, target)
|
||||
|
||||
if self.logger:
|
||||
self.logger.experiment.log_metrics({"train_loss":loss.item()}, step=self.global_step)
|
||||
|
||||
return {"loss":loss}
|
||||
|
||||
def validation_step(self,batch,batch_idx:int):
|
||||
|
|
@ -88,6 +93,8 @@ class Model(pl.LightningModule):
|
|||
prediction = self(mixed_waveform)
|
||||
|
||||
loss = self.metric(prediction, target)
|
||||
if self.logger:
|
||||
self.logger.experiment.log_metrics({"val_loss":loss.item()}, step=self.global_step)
|
||||
|
||||
return {"loss":loss}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue