| import os |
| from pathlib import Path |
| from typing import Dict, Optional, Union |
|
|
| from huggingface_hub import PyTorchModelHubMixin |
| from huggingface_hub.constants import (PYTORCH_WEIGHTS_NAME, |
| SAFETENSORS_SINGLE_FILE) |
| from huggingface_hub.file_download import hf_hub_download |
| from huggingface_hub.utils import EntryNotFoundError, is_torch_available |
|
|
|
|
| if is_torch_available(): |
| import torch |
|
|
|
|
class CompatiblePyTorchModelHubMixin(PyTorchModelHubMixin):
    """Mixin class to load Pytorch models from the Hub.

    Saving always writes a pickled ``state_dict`` to ``PYTORCH_WEIGHTS_NAME``.
    Loading prefers a safetensors checkpoint (``SAFETENSORS_SINGLE_FILE``) and
    falls back to the pickled ``PYTORCH_WEIGHTS_NAME`` checkpoint when no
    safetensors file is present, so checkpoints in either format can be read.
    """

    def _save_pretrained(self, save_directory: Path) -> None:
        """Save weights from a Pytorch model to a local directory.

        Args:
            save_directory: Existing directory in which the pickled
                ``state_dict`` is written as ``PYTORCH_WEIGHTS_NAME``. A plain
                ``str`` path is accepted as well.
        """
        # If the object keeps the real model under ``.module`` (presumably a
        # DataParallel-style wrapper), save that inner model's state_dict.
        model_to_save = self.module if hasattr(self, "module") else self
        torch.save(
            model_to_save.state_dict(),
            Path(save_directory) / PYTORCH_WEIGHTS_NAME,
        )

    @classmethod
    def _from_pretrained(
        cls,
        *,
        model_id: str,
        revision: Optional[str],
        cache_dir: Optional[Union[str, Path]],
        force_download: bool,
        proxies: Optional[Dict],
        resume_download: Optional[bool],
        local_files_only: bool,
        token: Union[str, bool, None],
        map_location: str = "cpu",
        strict: bool = False,
        **model_kwargs,
    ):
        """Load Pytorch pretrained weights and return the loaded model.

        ``model_id`` is either a local directory or a Hub repo id. In both cases
        the safetensors checkpoint is tried first and the pickled
        ``PYTORCH_WEIGHTS_NAME`` checkpoint is used as a fallback.

        Args:
            model_id: Local directory path, or Hub repository id.
            revision, cache_dir, force_download, proxies, resume_download,
            local_files_only, token: Forwarded to ``hf_hub_download``; ignored
                when ``model_id`` is a local directory.
            map_location: Device passed to the inherited weight loaders.
            strict: Passed to the inherited weight loaders.
            **model_kwargs: Passed to ``cls(...)`` to build the model instance.

        Returns:
            A new ``cls(**model_kwargs)`` instance with the checkpoint loaded.

        Raises:
            FileNotFoundError: ``model_id`` is a directory that contains neither
                checkpoint file.
            EntryNotFoundError: The Hub repo contains neither checkpoint file.
        """
        model = cls(**model_kwargs)

        if os.path.isdir(model_id):
            print("Loading weights from local directory")
            # Pick the format explicitly rather than catching FileNotFoundError
            # from the safetensors loader: catching would also hide unrelated
            # FileNotFoundErrors and yields a chained, confusing traceback when
            # the fallback file is missing too.
            safetensors_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
            if os.path.isfile(safetensors_file):
                return cls._load_as_safetensor(
                    model, safetensors_file, map_location, strict
                )
            pickle_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME)
            if not os.path.isfile(pickle_file):
                raise FileNotFoundError(
                    f"No model weights found in directory {model_id!r}: expected "
                    f"{SAFETENSORS_SINGLE_FILE!r} or {PYTORCH_WEIGHTS_NAME!r}."
                )
            return cls._load_as_pickle(model, pickle_file, map_location, strict)

        # Arguments shared by both download attempts; only ``filename`` differs.
        download_kwargs = dict(
            repo_id=model_id,
            revision=revision,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            token=token,
            local_files_only=local_files_only,
        )
        try:
            model_file = hf_hub_download(
                filename=SAFETENSORS_SINGLE_FILE, **download_kwargs
            )
        except EntryNotFoundError:
            # The repo has no safetensors checkpoint: fall back to the pickled
            # weights. If that file is missing as well, hf_hub_download raises.
            # NOTE(review): pickled checkpoints from untrusted repos can execute
            # code when loaded unless the inherited ``_load_as_pickle`` uses
            # ``torch.load(..., weights_only=True)`` — confirm.
            model_file = hf_hub_download(
                filename=PYTORCH_WEIGHTS_NAME, **download_kwargs
            )
            return cls._load_as_pickle(model, model_file, map_location, strict)
        else:
            return cls._load_as_safetensor(model, model_file, map_location, strict)