save checkpoint attributes

This commit is contained in:
shahules786 2022-09-15 19:08:08 +05:30
parent 381067e4ef
commit c341b23ff8
1 changed files with 17 additions and 1 deletions

View File

@ -1,9 +1,11 @@
from typing import Optional, Union, List
from torch.optim import Adam
import pytorch_lightning as pl
import torch
from enhancer import __version__
from enhancer.data.dataset import Dataset
from enhancer.utils.loss import Avergeloss
class Model(pl.LightningModule):
@ -74,6 +76,20 @@ class Model(pl.LightningModule):
return {"loss":loss}
def on_save_checkpoint(self, checkpoint):
    """Lightning hook: stamp enhancer metadata into the checkpoint dict.

    Records the library versions (enhancer and pytorch) and the fully
    qualified class of this model under the "enhancer" key, so a saved
    checkpoint can later be matched back to the code that produced it.

    Parameters
    ----------
    checkpoint : dict
        Mutable checkpoint state supplied by pytorch-lightning; modified
        in place.
    """
    # Fully qualified architecture identity of the concrete model subclass.
    architecture_info = {
        "module": self.__class__.__module__,
        "class": self.__class__.__name__,
    }
    version_info = {
        "enhancer": __version__,
        "pytorch": torch.__version__,
    }
    checkpoint["enhancer"] = {
        "version": version_info,
        "architecture": architecture_info,
    }
@classmethod
def from_pretrained(cls,):
    pass