change save name
This commit is contained in:
parent
f8a44f823a
commit
12cde1b0ab
|
|
@ -24,6 +24,7 @@ CACHE_DIR = os.getenv(
|
|||
)
|
||||
HF_TORCH_WEIGHTS = "pytorch_model.ckpt"
|
||||
DEFAULT_DEVICE = "cpu"
|
||||
SAVE_NAME = "enhancer"
|
||||
|
||||
|
||||
class Model(pl.LightningModule):
|
||||
|
|
@ -233,8 +234,8 @@ class Model(pl.LightningModule):
|
|||
|
||||
def on_save_checkpoint(self, checkpoint):
|
||||
|
||||
checkpoint["mayavoz"] = {
|
||||
"version": {"mayavoz": __version__, "pytorch": torch.__version__},
|
||||
checkpoint[SAVE_NAME] = {
|
||||
"version": {SAVE_NAME: __version__, "pytorch": torch.__version__},
|
||||
"architecture": {
|
||||
"module": self.__class__.__module__,
|
||||
"class": self.__class__.__name__,
|
||||
|
|
@ -327,8 +328,8 @@ class Model(pl.LightningModule):
|
|||
map_location = torch.device(DEFAULT_DEVICE)
|
||||
|
||||
loaded_checkpoint = pl_load(model_path_pl, map_location)
|
||||
module_name = loaded_checkpoint["mayavoz"]["architecture"]["module"]
|
||||
class_name = loaded_checkpoint["mayavoz"]["architecture"]["class"]
|
||||
module_name = loaded_checkpoint[SAVE_NAME]["architecture"]["module"]
|
||||
class_name = loaded_checkpoint[SAVE_NAME]["architecture"]["class"]
|
||||
module = import_module(module_name)
|
||||
Klass = getattr(module, class_name)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue