From f7c09d80451d232531aa686df4bb0f2834dfc7f8 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Thu, 15 Sep 2022 21:09:49 +0530
Subject: [PATCH] load from pretrained

---
 enhancer/models/model.py | 74 ++++++++++++++++++++++++++++++++++++++---
 1 file changed, 70 insertions(+), 4 deletions(-)

diff --git a/enhancer/models/model.py b/enhancer/models/model.py
index d697fcd..07281c9 100644
--- a/enhancer/models/model.py
+++ b/enhancer/models/model.py
@@ -1,12 +1,22 @@
-from typing import Optional, Union, List
+from importlib import import_module
+from huggingface_hub import cached_download, hf_hub_url
+import os
+from pathlib import Path
+from typing import Optional, Union, List, Text
 from torch.optim import Adam
-import pytorch_lightning as pl
 import torch
+import pytorch_lightning as pl
+from pytorch_lightning.utilities.cloud_io import load as pl_load
+from urllib.parse import urlparse
 
 from enhancer import __version__
 from enhancer.data.dataset import Dataset
 from enhancer.utils.loss import Avergeloss
 
+CACHE_DIR = ""
+HF_TORCH_WEIGHTS = ""
+DEFAULT_DEVICE = "cpu"
+
 
 class Model(pl.LightningModule):
 
@@ -91,5 +101,61 @@ class Model(pl.LightningModule):
         }
 
     @classmethod
-    def from_pretrained(cls,):
-        pass
\ No newline at end of file
+    def from_pretrained(
+        cls,
+        checkpoint: Union[Path, Text],
+        map_location = None,
+        hparams_file: Union[Path, Text] = None,
+        strict: bool = True,
+        use_auth_token: Union[Text, None] = None,
+        cached_dir: Union[Path, Text] = CACHE_DIR,
+        **kwargs
+    ):
+
+        checkpoint = str(checkpoint)
+        if hparams_file is not None:
+            hparams_file = str(hparams_file)
+
+        # Local file or direct URL: pass it straight to pytorch-lightning.
+        if os.path.isfile(checkpoint):
+            model_path_pl = checkpoint
+        elif urlparse(checkpoint).scheme in ("http", "https"):
+            model_path_pl = checkpoint
+        else:
+            # Otherwise treat it as a Hugging Face Hub model id, optionally
+            # pinned to a revision with "model_id@revision".
+            if "@" in checkpoint:
+                model_id = checkpoint.split("@")[0]
+                revision_id = checkpoint.split("@")[1]
+            else:
+                model_id = checkpoint
+                revision_id = None
+
+            url = hf_hub_url(
+                model_id, filename=HF_TORCH_WEIGHTS, revision=revision_id
+            )
+            model_path_pl = cached_download(
+                url=url, library_name="enhancer", library_version=__version__,
+                cache_dir=cached_dir, use_auth_token=use_auth_token
+            )
+
+        if map_location is None:
+            map_location = torch.device(DEFAULT_DEVICE)
+
+        # Read the architecture recorded in the checkpoint and resolve the
+        # concrete model class before delegating to load_from_checkpoint.
+        loaded_checkpoint = pl_load(model_path_pl, map_location)
+        module_name = loaded_checkpoint["architecture"]["module"]
+        class_name = loaded_checkpoint["architecture"]["class"]
+        module = import_module(module_name)
+        Klass = getattr(module, class_name)
+
+        try:
+            model = Klass.load_from_checkpoint(
+                checkpoint_path=model_path_pl,
+                map_location=map_location,
+                hparams_file=hparams_file,
+                strict=strict,
+                **kwargs
+            )
+        except Exception as e:
+            print(e)
+            raise
+
+        return model
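
Usage sketch (illustrative, not part of the committed diff): how the new Model.from_pretrained entry point could be called once CACHE_DIR and HF_TORCH_WEIGHTS are filled in. The local checkpoint path, the Hub model id, and the token below are hypothetical placeholders.

    from enhancer.models.model import Model

    # Load from a local Lightning checkpoint file (hypothetical path).
    model = Model.from_pretrained("/path/to/enhancer.ckpt")

    # Load from a Hugging Face Hub model id, optionally pinned to a revision
    # with "@", with a token for private repositories (values are placeholders).
    model = Model.from_pretrained(
        "username/enhancer-model@main",
        use_auth_token="hf_xxxxxxxx",
    )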