load from pretrained
This commit is contained in:
parent
c341b23ff8
commit
f7c09d8045
|
|
@ -1,12 +1,21 @@
|
||||||
from typing import Optional, Union, List
|
from importlib import import_module
|
||||||
|
from huggingface_hub import cached_download, hf_hub_url
|
||||||
|
import os
|
||||||
|
from typing import Optional, Union, List, Path, Text
|
||||||
from torch.optim import Adam
|
from torch.optim import Adam
|
||||||
import pytorch_lightning as pl
|
|
||||||
import torch
|
import torch
|
||||||
|
import pytorch_lightning as pl
|
||||||
|
from pytorch_lightning.utilities.cloud_io import load as pl_load
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
|
||||||
from enhancer import __version__
|
from enhancer import __version__
|
||||||
from enhancer.data.dataset import Dataset
|
from enhancer.data.dataset import Dataset
|
||||||
from enhancer.utils.loss import Avergeloss
|
from enhancer.utils.loss import Avergeloss
|
||||||
|
|
||||||
|
CACHE_DIR = ""
|
||||||
|
HF_TORCH_WEIGHTS = ""
|
||||||
|
DEFAULT_DEVICE = "cpu"
|
||||||
|
|
||||||
class Model(pl.LightningModule):
|
class Model(pl.LightningModule):
|
||||||
|
|
||||||
|
|
@ -91,5 +100,66 @@ class Model(pl.LightningModule):
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls,):
|
def from_pretrained(
|
||||||
pass
|
cls,
|
||||||
|
checkpoint: Union[Path, Text],
|
||||||
|
map_location = None,
|
||||||
|
hparams_file: Union[Path, Text] = None,
|
||||||
|
strict: bool = True,
|
||||||
|
use_auth_token: Union[Text, None] = None,
|
||||||
|
cached_dir: Union[Path, Text]=CACHE_DIR,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
|
||||||
|
checkpoint = str(checkpoint)
|
||||||
|
if hparams_file is not None:
|
||||||
|
hparams_file = str(hparams_file)
|
||||||
|
|
||||||
|
if os.path.isfile(checkpoint):
|
||||||
|
model_path_pl = checkpoint
|
||||||
|
elif urlparse(checkpoint).scheme in ("http","https"):
|
||||||
|
model_path_pl = checkpoint
|
||||||
|
else:
|
||||||
|
|
||||||
|
if "@" in checkpoint:
|
||||||
|
model_id = checkpoint.split("@")[0]
|
||||||
|
revision_id = checkpoint.split("@")[1]
|
||||||
|
else:
|
||||||
|
model_id = checkpoint
|
||||||
|
revision_id = None
|
||||||
|
|
||||||
|
url = hf_hub_url(
|
||||||
|
model_id,filename=HF_TORCH_WEIGHTS,revision=revision_id
|
||||||
|
)
|
||||||
|
model_path_pl = cached_download(
|
||||||
|
url=url,library_name="enhancer",library_version=__version__,
|
||||||
|
cache_dir=cached_dir,use_auth_token=use_auth_token
|
||||||
|
)
|
||||||
|
|
||||||
|
if map_location is None:
|
||||||
|
map_location = torch.device(DEFAULT_DEVICE)
|
||||||
|
|
||||||
|
loaded_checkpoint = pl_load(model_path_pl,map_location)
|
||||||
|
module_name = loaded_checkpoint["architecture"]["module"]
|
||||||
|
class_name = loaded_checkpoint["architecture"]["class"]
|
||||||
|
module = import_module(module_name)
|
||||||
|
Klass = getattr(module, class_name)
|
||||||
|
|
||||||
|
try:
|
||||||
|
model = Klass.load_from_checkpoint(
|
||||||
|
checkpoint_path = model_path_pl,
|
||||||
|
map_location = map_location,
|
||||||
|
hparams_file = hparams_file,
|
||||||
|
strict = strict,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Loading…
Reference in New Issue