save checkpoint attributes
This commit is contained in:
parent 381067e4ef
commit c341b23ff8
@@ -1,9 +1,11 @@
 from typing import Optional, Union, List
 from torch.optim import Adam
 import pytorch_lightning as pl
+import torch
 
+from enhancer import __version__
 from enhancer.data.dataset import Dataset
-from enhancer.utils.loss import LOSS_MAP, Avergeloss
+from enhancer.utils.loss import Avergeloss
 
 
 class Model(pl.LightningModule):
@@ -74,6 +76,20 @@ class Model(pl.LightningModule):
 
         return {"loss":loss}
 
+    def on_save_checkpoint(self, checkpoint):
+
+        checkpoint["enhancer"] = {
+            "version": {
+                "enhancer":__version__,
+                "pytorch":torch.__version__
+            },
+            "architecture":{
+                "module":self.__class__.__module__,
+                "class":self.__class__.__name__
+            }
+
+        }
+
     @classmethod
     def from_pretrained(cls,):
         pass
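
The new on_save_checkpoint hook embeds version and architecture metadata in every saved checkpoint. A minimal sketch, not part of this commit, of reading that metadata back: a Lightning .ckpt file is an ordinary torch.save archive, so torch.load can inspect it directly ("model.ckpt" is a placeholder path).

# Minimal sketch (assumption: a checkpoint saved by this Model exists at model.ckpt)
import torch

checkpoint = torch.load("model.ckpt", map_location="cpu")
meta = checkpoint["enhancer"]

print(meta["version"]["enhancer"])     # enhancer package version at save time
print(meta["version"]["pytorch"])      # torch version at save time
print(meta["architecture"]["module"])  # dotted module path of the model class
print(meta["architecture"]["class"])   # class name of the saved model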
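
from_pretrained is a pre-existing stub here, and the commit does not implement it. One plausible way the stored architecture metadata could drive it, sketched as a standalone hypothetical helper (load_from_checkpoint_metadata and its logic are assumptions, not the author's implementation):

# Hypothetical sketch only; the commit leaves from_pretrained unimplemented.
import importlib

import torch


def load_from_checkpoint_metadata(path):
    # Resolve the class recorded by on_save_checkpoint...
    ckpt = torch.load(path, map_location="cpu")
    arch = ckpt["enhancer"]["architecture"]
    module = importlib.import_module(arch["module"])
    model_cls = getattr(module, arch["class"])
    # ...then delegate to Lightning's standard restore path, which
    # rebuilds the module from saved hyperparameters and loads its weights.
    return model_cls.load_from_checkpoint(path)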