change save name

This commit is contained in:
shahules786 2022-11-14 16:30:14 +05:30
parent f8a44f823a
commit 12cde1b0ab
1 changed files with 5 additions and 4 deletions

View File

@ -24,6 +24,7 @@ CACHE_DIR = os.getenv(
)
HF_TORCH_WEIGHTS = "pytorch_model.ckpt"
DEFAULT_DEVICE = "cpu"
SAVE_NAME = "enhancer"
class Model(pl.LightningModule):
@ -233,8 +234,8 @@ class Model(pl.LightningModule):
def on_save_checkpoint(self, checkpoint):
checkpoint["mayavoz"] = {
"version": {"mayavoz": __version__, "pytorch": torch.__version__},
checkpoint[SAVE_NAME] = {
"version": {SAVE_NAME: __version__, "pytorch": torch.__version__},
"architecture": {
"module": self.__class__.__module__,
"class": self.__class__.__name__,
@ -327,8 +328,8 @@ class Model(pl.LightningModule):
map_location = torch.device(DEFAULT_DEVICE)
loaded_checkpoint = pl_load(model_path_pl, map_location)
module_name = loaded_checkpoint["mayavoz"]["architecture"]["module"]
class_name = loaded_checkpoint["mayavoz"]["architecture"]["class"]
module_name = loaded_checkpoint[SAVE_NAME]["architecture"]["module"]
class_name = loaded_checkpoint[SAVE_NAME]["architecture"]["class"]
module = import_module(module_name)
Klass = getattr(module, class_name)