log average metrics

shahules786 2022-10-11 16:48:49 +05:30
parent fdce8bb601
commit abcdc29309
1 changed file with 47 additions and 4 deletions


@@ -1,7 +1,8 @@
 import os
+from collections import defaultdict
 from importlib import import_module
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Text, Union
+from typing import List, Optional, Text, Union
 from urllib.parse import urlparse

 import numpy as np
@@ -192,6 +193,51 @@ class Model(pl.LightningModule):
         return metric_dict

+    def training_epoch_end(self, outputs):
+        train_mean_loss = 0.0
+        for output in outputs:
+            train_mean_loss += output["loss"]
+        train_mean_loss /= len(outputs)
+        if self.logger:
+            self.logger.experiment.log_metric(
+                run_id=self.logger.run_id,
+                key="train_loss_epoch",
+                value=train_mean_loss,
+                step=self.current_epoch,
+            )
+
+    def validation_epoch_end(self, outputs):
+        valid_mean_loss = 0.0
+        for output in outputs:
+            valid_mean_loss += output["loss"]
+        valid_mean_loss /= len(outputs)
+        if self.logger:
+            self.logger.experiment.log_metric(
+                run_id=self.logger.run_id,
+                key="valid_loss_epoch",
+                value=valid_mean_loss,
+                step=self.current_epoch,
+            )
+
+    def test_epoch_end(self, outputs):
+        test_mean_metrics = defaultdict(int)
+        for output in outputs:
+            for metric, value in output.items():
+                test_mean_metrics[metric] += value.item()
+        for metric in test_mean_metrics.keys():
+            test_mean_metrics[metric] /= len(outputs)
+        for k, v in test_mean_metrics.items():
+            self.logger.experiment.log_metric(
+                run_id=self.logger.run_id,
+                key=k,
+                value=v,
+                step=self.current_epoch,
+            )
+
     def on_save_checkpoint(self, checkpoint):
         checkpoint["enhancer"] = {
@@ -202,9 +248,6 @@ class Model(pl.LightningModule):
             },
         }

-    def on_load_checkpoint(self, checkpoint: Dict[str, Any]):
-        pass
-
     @classmethod
     def from_pretrained(
         cls,
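
For context, the sketch below mirrors the per-epoch averaging these new hooks perform, stripped of the Lightning and MLflow machinery. It only assumes the usual PyTorch Lightning contract that the *_epoch_end hooks receive the list of per-step output dicts; the helper names and the metric names in the usage example are made up for illustration, not part of this repository.

# Minimal, standalone sketch of the averaging added in this commit.
# Plain floats stand in for the tensors Lightning actually passes to the hooks.
from collections import defaultdict
from typing import Dict, List


def mean_loss(step_outputs: List[Dict[str, float]]) -> float:
    """Average the per-step 'loss' values, as training/validation_epoch_end do."""
    total = 0.0
    for output in step_outputs:
        total += output["loss"]
    return total / len(step_outputs)


def mean_metrics(step_outputs: List[Dict[str, float]]) -> Dict[str, float]:
    """Average every metric key across steps, as test_epoch_end does."""
    sums = defaultdict(float)
    for output in step_outputs:
        for metric, value in output.items():
            sums[metric] += value
    return {metric: total / len(step_outputs) for metric, total in sums.items()}


if __name__ == "__main__":
    outputs = [{"loss": 0.4, "stoi": 0.80}, {"loss": 0.2, "stoi": 0.90}]
    print(mean_loss(outputs))     # ~0.3, the value logged as train/valid_loss_epoch
    print(mean_metrics(outputs))  # {'loss': ~0.3, 'stoi': 0.85}, one log_metric call per key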