change save name

This commit is contained in:
shahules786 2022-11-14 16:30:14 +05:30
parent f8a44f823a
commit 12cde1b0ab
1 changed files with 5 additions and 4 deletions

View File

@ -24,6 +24,7 @@ CACHE_DIR = os.getenv(
) )
HF_TORCH_WEIGHTS = "pytorch_model.ckpt" HF_TORCH_WEIGHTS = "pytorch_model.ckpt"
DEFAULT_DEVICE = "cpu" DEFAULT_DEVICE = "cpu"
SAVE_NAME = "enhancer"
class Model(pl.LightningModule): class Model(pl.LightningModule):
@ -233,8 +234,8 @@ class Model(pl.LightningModule):
def on_save_checkpoint(self, checkpoint): def on_save_checkpoint(self, checkpoint):
checkpoint["mayavoz"] = { checkpoint[SAVE_NAME] = {
"version": {"mayavoz": __version__, "pytorch": torch.__version__}, "version": {SAVE_NAME: __version__, "pytorch": torch.__version__},
"architecture": { "architecture": {
"module": self.__class__.__module__, "module": self.__class__.__module__,
"class": self.__class__.__name__, "class": self.__class__.__name__,
@ -327,8 +328,8 @@ class Model(pl.LightningModule):
map_location = torch.device(DEFAULT_DEVICE) map_location = torch.device(DEFAULT_DEVICE)
loaded_checkpoint = pl_load(model_path_pl, map_location) loaded_checkpoint = pl_load(model_path_pl, map_location)
module_name = loaded_checkpoint["mayavoz"]["architecture"]["module"] module_name = loaded_checkpoint[SAVE_NAME]["architecture"]["module"]
class_name = loaded_checkpoint["mayavoz"]["architecture"]["class"] class_name = loaded_checkpoint[SAVE_NAME]["architecture"]["class"]
module = import_module(module_name) module = import_module(module_name)
Klass = getattr(module, class_name) Klass = getattr(module, class_name)