add logging
This commit is contained in:
parent
34755f33aa
commit
b55e12d15c
|
|
@ -38,6 +38,8 @@ class Model(pl.LightningModule):
|
||||||
assert num_channels ==1 , "Enhancer only support for mono channel models"
|
assert num_channels ==1 , "Enhancer only support for mono channel models"
|
||||||
self.dataset = dataset
|
self.dataset = dataset
|
||||||
self.save_hyperparameters("num_channels","sampling_rate","lr","loss","metric","duration")
|
self.save_hyperparameters("num_channels","sampling_rate","lr","loss","metric","duration")
|
||||||
|
if self.logger:
|
||||||
|
self.logger.experiment.log_dict(dict(self.hparams),"hyperparameters.json")
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|
@ -79,6 +81,9 @@ class Model(pl.LightningModule):
|
||||||
|
|
||||||
loss = self.loss(prediction, target)
|
loss = self.loss(prediction, target)
|
||||||
|
|
||||||
|
if self.logger:
|
||||||
|
self.logger.experiment.log_metrics({"train_loss":loss.item()}, step=self.global_step)
|
||||||
|
|
||||||
return {"loss":loss}
|
return {"loss":loss}
|
||||||
|
|
||||||
def validation_step(self,batch,batch_idx:int):
|
def validation_step(self,batch,batch_idx:int):
|
||||||
|
|
@ -88,6 +93,8 @@ class Model(pl.LightningModule):
|
||||||
prediction = self(mixed_waveform)
|
prediction = self(mixed_waveform)
|
||||||
|
|
||||||
loss = self.metric(prediction, target)
|
loss = self.metric(prediction, target)
|
||||||
|
if self.logger:
|
||||||
|
self.logger.experiment.log_metrics({"val_loss":loss.item()}, step=self.global_step)
|
||||||
|
|
||||||
return {"loss":loss}
|
return {"loss":loss}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue