From 12cde1b0abbd8ef8bf7b99b5393c843d8579232c Mon Sep 17 00:00:00 2001 From: shahules786 Date: Mon, 14 Nov 2022 16:30:14 +0530 Subject: [PATCH] change save name --- mayavoz/models/model.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/mayavoz/models/model.py b/mayavoz/models/model.py index 2957e5b..d82c5c5 100644 --- a/mayavoz/models/model.py +++ b/mayavoz/models/model.py @@ -24,6 +24,7 @@ CACHE_DIR = os.getenv( ) HF_TORCH_WEIGHTS = "pytorch_model.ckpt" DEFAULT_DEVICE = "cpu" +SAVE_NAME = "enhancer" class Model(pl.LightningModule): @@ -233,8 +234,8 @@ class Model(pl.LightningModule): def on_save_checkpoint(self, checkpoint): - checkpoint["mayavoz"] = { - "version": {"mayavoz": __version__, "pytorch": torch.__version__}, + checkpoint[SAVE_NAME] = { + "version": {SAVE_NAME: __version__, "pytorch": torch.__version__}, "architecture": { "module": self.__class__.__module__, "class": self.__class__.__name__, @@ -327,8 +328,8 @@ class Model(pl.LightningModule): map_location = torch.device(DEFAULT_DEVICE) loaded_checkpoint = pl_load(model_path_pl, map_location) - module_name = loaded_checkpoint["mayavoz"]["architecture"]["module"] - class_name = loaded_checkpoint["mayavoz"]["architecture"]["class"] + module_name = loaded_checkpoint[SAVE_NAME]["architecture"]["module"] + class_name = loaded_checkpoint[SAVE_NAME]["architecture"]["class"] module = import_module(module_name) Klass = getattr(module, class_name)