load from pretrained
This commit is contained in:
parent
c341b23ff8
commit
f7c09d8045
|
|
@ -1,12 +1,21 @@
|
|||
from typing import Optional, Union, List
|
||||
from importlib import import_module
|
||||
from huggingface_hub import cached_download, hf_hub_url
|
||||
import os
|
||||
from typing import Optional, Union, List, Path, Text
|
||||
from torch.optim import Adam
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.utilities.cloud_io import load as pl_load
|
||||
from urllib.parse import urlparse
|
||||
|
||||
|
||||
from enhancer import __version__
|
||||
from enhancer.data.dataset import Dataset
|
||||
from enhancer.utils.loss import Avergeloss
|
||||
|
||||
CACHE_DIR = ""
|
||||
HF_TORCH_WEIGHTS = ""
|
||||
DEFAULT_DEVICE = "cpu"
|
||||
|
||||
class Model(pl.LightningModule):
|
||||
|
||||
|
|
@ -91,5 +100,66 @@ class Model(pl.LightningModule):
|
|||
}
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls,):
|
||||
pass
|
||||
def from_pretrained(
|
||||
cls,
|
||||
checkpoint: Union[Path, Text],
|
||||
map_location = None,
|
||||
hparams_file: Union[Path, Text] = None,
|
||||
strict: bool = True,
|
||||
use_auth_token: Union[Text, None] = None,
|
||||
cached_dir: Union[Path, Text]=CACHE_DIR,
|
||||
**kwargs
|
||||
):
|
||||
|
||||
checkpoint = str(checkpoint)
|
||||
if hparams_file is not None:
|
||||
hparams_file = str(hparams_file)
|
||||
|
||||
if os.path.isfile(checkpoint):
|
||||
model_path_pl = checkpoint
|
||||
elif urlparse(checkpoint).scheme in ("http","https"):
|
||||
model_path_pl = checkpoint
|
||||
else:
|
||||
|
||||
if "@" in checkpoint:
|
||||
model_id = checkpoint.split("@")[0]
|
||||
revision_id = checkpoint.split("@")[1]
|
||||
else:
|
||||
model_id = checkpoint
|
||||
revision_id = None
|
||||
|
||||
url = hf_hub_url(
|
||||
model_id,filename=HF_TORCH_WEIGHTS,revision=revision_id
|
||||
)
|
||||
model_path_pl = cached_download(
|
||||
url=url,library_name="enhancer",library_version=__version__,
|
||||
cache_dir=cached_dir,use_auth_token=use_auth_token
|
||||
)
|
||||
|
||||
if map_location is None:
|
||||
map_location = torch.device(DEFAULT_DEVICE)
|
||||
|
||||
loaded_checkpoint = pl_load(model_path_pl,map_location)
|
||||
module_name = loaded_checkpoint["architecture"]["module"]
|
||||
class_name = loaded_checkpoint["architecture"]["class"]
|
||||
module = import_module(module_name)
|
||||
Klass = getattr(module, class_name)
|
||||
|
||||
try:
|
||||
model = Klass.load_from_checkpoint(
|
||||
checkpoint_path = model_path_pl,
|
||||
map_location = map_location,
|
||||
hparams_file = hparams_file,
|
||||
strict = strict,
|
||||
**kwargs
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue