load from pretrained

shahules786 2022-09-15 21:09:49 +05:30
parent c341b23ff8
commit f7c09d8045
1 changed file with 74 additions and 4 deletions


@@ -1,12 +1,21 @@
from importlib import import_module
from huggingface_hub import cached_download, hf_hub_url
import os
from typing import Optional, Union, List, Text
from pathlib import Path
from torch.optim import Adam
import torch
import pytorch_lightning as pl
from pytorch_lightning.utilities.cloud_io import load as pl_load
from urllib.parse import urlparse
from enhancer import __version__
from enhancer.data.dataset import Dataset
from enhancer.utils.loss import Avergeloss

CACHE_DIR = ""
HF_TORCH_WEIGHTS = ""
DEFAULT_DEVICE = "cpu"
class Model(pl.LightningModule):
@@ -91,5 +100,66 @@ class Model(pl.LightningModule):
        }
    @classmethod
    def from_pretrained(
        cls,
        checkpoint: Union[Path, Text],
        map_location=None,
        hparams_file: Optional[Union[Path, Text]] = None,
        strict: bool = True,
        use_auth_token: Union[Text, None] = None,
        cached_dir: Union[Path, Text] = CACHE_DIR,
        **kwargs,
    ):
        checkpoint = str(checkpoint)
        if hparams_file is not None:
            hparams_file = str(hparams_file)

        # Resolve the checkpoint: a local file or an http(s) URL is used as-is;
        # anything else is treated as a Hugging Face Hub model id, optionally
        # suffixed with "@<revision>".
        if os.path.isfile(checkpoint):
            model_path_pl = checkpoint
        elif urlparse(checkpoint).scheme in ("http", "https"):
            model_path_pl = checkpoint
        else:
            if "@" in checkpoint:
                model_id = checkpoint.split("@")[0]
                revision_id = checkpoint.split("@")[1]
            else:
                model_id = checkpoint
                revision_id = None

            url = hf_hub_url(
                model_id, filename=HF_TORCH_WEIGHTS, revision=revision_id
            )
            model_path_pl = cached_download(
                url=url,
                library_name="enhancer",
                library_version=__version__,
                cache_dir=cached_dir,
                use_auth_token=use_auth_token,
            )
        if map_location is None:
            map_location = torch.device(DEFAULT_DEVICE)

        # The checkpoint is expected to store the module and class name of the
        # architecture it was trained with, so the concrete model class can be
        # imported dynamically before Lightning restores the weights.
        loaded_checkpoint = pl_load(model_path_pl, map_location)
        module_name = loaded_checkpoint["architecture"]["module"]
        class_name = loaded_checkpoint["architecture"]["class"]
        module = import_module(module_name)
        Klass = getattr(module, class_name)

        try:
            model = Klass.load_from_checkpoint(
                checkpoint_path=model_path_pl,
                map_location=map_location,
                hparams_file=hparams_file,
                strict=strict,
                **kwargs,
            )
        except Exception as e:
            print(e)
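
The hunk ends here, so the method's return statement is not visible in this diff. Assuming from_pretrained eventually returns the loaded model, a minimal usage sketch could look like the following; the import path enhancer.models.model, the model id user/enhancer-demo, and the local checkpoint path are placeholders, not values from this commit:

# Hypothetical usage of the new loader (paths and model ids are placeholders).
from enhancer.models.model import Model  # assumed location of the Model base class

# 1. Load directly from a local Lightning checkpoint file.
model = Model.from_pretrained("/path/to/checkpoint.ckpt")

# 2. Load from the Hugging Face Hub, pinning a revision with the "@" syntax
#    handled in from_pretrained, and caching into a custom directory.
model = Model.from_pretrained(
    "user/enhancer-demo@v1.0",
    cached_dir="~/.cache/enhancer",
)

For the Hub path to work, HF_TORCH_WEIGHTS would need to name the actual checkpoint file inside the model repository, and the downloaded checkpoint must contain the "architecture" metadata read above; both are left empty or implicit in this hunk.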