log average metrics
This commit is contained in:
parent
fdce8bb601
commit
abcdc29309
|
|
@ -1,7 +1,8 @@
|
||||||
import os
|
import os
|
||||||
|
from collections import defaultdict
|
||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional, Text, Union
|
from typing import List, Optional, Text, Union
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
@ -192,6 +193,51 @@ class Model(pl.LightningModule):
|
||||||
|
|
||||||
return metric_dict
|
return metric_dict
|
||||||
|
|
||||||
|
def training_epoch_end(self, outputs):
|
||||||
|
train_mean_loss = 0.0
|
||||||
|
for output in outputs:
|
||||||
|
train_mean_loss += output["loss"]
|
||||||
|
train_mean_loss /= len(outputs)
|
||||||
|
|
||||||
|
if self.logger:
|
||||||
|
self.logger.experiment.log_metric(
|
||||||
|
run_id=self.logger.run_id,
|
||||||
|
key="train_loss_epoch",
|
||||||
|
value=train_mean_loss,
|
||||||
|
step=self.current_epoch,
|
||||||
|
)
|
||||||
|
|
||||||
|
def validation_epoch_end(self, outputs):
|
||||||
|
valid_mean_loss = 0.0
|
||||||
|
for output in outputs:
|
||||||
|
valid_mean_loss += output["loss"]
|
||||||
|
valid_mean_loss /= len(outputs)
|
||||||
|
|
||||||
|
if self.logger:
|
||||||
|
self.logger.experiment.log_metric(
|
||||||
|
run_id=self.logger.run_id,
|
||||||
|
key="valid_loss_epoch",
|
||||||
|
value=valid_mean_loss,
|
||||||
|
step=self.current_epoch,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_epoch_end(self, outputs):
|
||||||
|
|
||||||
|
test_mean_metrics = defaultdict(int)
|
||||||
|
for output in outputs:
|
||||||
|
for metric, value in output.items():
|
||||||
|
test_mean_metrics[metric] += value.item()
|
||||||
|
for metric in test_mean_metrics.keys():
|
||||||
|
test_mean_metrics[metric] /= len(outputs)
|
||||||
|
|
||||||
|
for k, v in test_mean_metrics.items():
|
||||||
|
self.logger.experiment.log_metric(
|
||||||
|
run_id=self.logger.run_id,
|
||||||
|
key=k,
|
||||||
|
value=v,
|
||||||
|
step=self.current_epoch,
|
||||||
|
)
|
||||||
|
|
||||||
def on_save_checkpoint(self, checkpoint):
|
def on_save_checkpoint(self, checkpoint):
|
||||||
|
|
||||||
checkpoint["enhancer"] = {
|
checkpoint["enhancer"] = {
|
||||||
|
|
@ -202,9 +248,6 @@ class Model(pl.LightningModule):
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def on_load_checkpoint(self, checkpoint: Dict[str, Any]):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(
|
def from_pretrained(
|
||||||
cls,
|
cls,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue