diff --git a/enhancer/models/model.py b/enhancer/models/model.py index 275027b..d697fcd 100644 --- a/enhancer/models/model.py +++ b/enhancer/models/model.py @@ -1,9 +1,11 @@ from typing import Optional, Union, List from torch.optim import Adam import pytorch_lightning as pl +import torch +from enhancer import __version__ from enhancer.data.dataset import Dataset -from enhancer.utils.loss import LOSS_MAP, Avergeloss +from enhancer.utils.loss import Avergeloss class Model(pl.LightningModule): @@ -74,6 +76,20 @@ class Model(pl.LightningModule): return {"loss":loss} + def on_save_checkpoint(self, checkpoint): + + checkpoint["enhancer"] = { + "version": { + "enhancer":__version__, + "pytorch":torch.__version__ + }, + "architecture":{ + "module":self.__class__.__module__, + "class":self.__class__.__name__ + } + + } + @classmethod def from_pretrained(cls,): pass \ No newline at end of file