save checkpoint attributes
This commit is contained in:
parent
381067e4ef
commit
c341b23ff8
|
|
@ -1,9 +1,11 @@
|
||||||
from typing import Optional, Union, List
|
from typing import Optional, Union, List
|
||||||
from torch.optim import Adam
|
from torch.optim import Adam
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from enhancer import __version__
|
||||||
from enhancer.data.dataset import Dataset
|
from enhancer.data.dataset import Dataset
|
||||||
from enhancer.utils.loss import LOSS_MAP, Avergeloss
|
from enhancer.utils.loss import Avergeloss
|
||||||
|
|
||||||
|
|
||||||
class Model(pl.LightningModule):
|
class Model(pl.LightningModule):
|
||||||
|
|
@ -74,6 +76,20 @@ class Model(pl.LightningModule):
|
||||||
|
|
||||||
return {"loss":loss}
|
return {"loss":loss}
|
||||||
|
|
||||||
|
def on_save_checkpoint(self, checkpoint):
|
||||||
|
|
||||||
|
checkpoint["enhancer"] = {
|
||||||
|
"version": {
|
||||||
|
"enhancer":__version__,
|
||||||
|
"pytorch":torch.__version__
|
||||||
|
},
|
||||||
|
"architecture":{
|
||||||
|
"module":self.__class__.__module__,
|
||||||
|
"class":self.__class__.__name__
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls,):
|
def from_pretrained(cls,):
|
||||||
pass
|
pass
|
||||||
Loading…
Reference in New Issue