"""HuggingFace model download and loading: downloading is a separate step; loading only considers local files."""
import json
import os
from typing import Callable, TypeVar
# Generic type of whatever the loader callable passed to resolve_and_load returns.
T = TypeVar("T")
# Checkpoint file names, kept identical to the names transformers uses.
# Single-file and sharded (index) variants, safetensors format:
_SAFE_WEIGHTS = "model.safetensors"
_SAFE_WEIGHTS_INDEX = "model.safetensors.index.json"
# Single-file and sharded (index) variants, PyTorch .bin format:
_WEIGHTS = "pytorch_model.bin"
_WEIGHTS_INDEX = "pytorch_model.bin.index.json"
def _is_model_cache_complete(local_path: str) -> bool:
    """Return True when *local_path* contains a complete set of model weights.

    The lookup order is intended to mirror transformers'
    ``_get_resolved_checkpoint_files``: for the safetensors format first and
    the PyTorch ``.bin`` format second, a single weights file counts as
    complete; otherwise, if an index file exists, the answer is decided by
    that index alone (every shard it lists must be present).

    Args:
        local_path: Directory expected to hold the checkpoint files.

    Returns:
        True if a single weights file exists, or an index file exists and all
        shards named in its ``weight_map`` exist next to it. False otherwise,
        including when the index file is unreadable, is not valid JSON, or
        lists no shards at all.
    """

    def _in_dir(name: str) -> str:
        """Return the path of *name* inside the checked directory."""
        return os.path.join(local_path, name)

    def _all_shards_present(index_path: str) -> bool:
        """Return True when every shard listed in the index file exists."""
        try:
            with open(index_path, encoding="utf-8") as fh:
                index = json.load(fh)
        except (OSError, ValueError):
            # A truncated or corrupt index (JSONDecodeError and
            # UnicodeDecodeError are ValueErrors) means the cache is not
            # usable; report "incomplete" instead of crashing the check.
            return False

        weight_map = index.get("weight_map") if isinstance(index, dict) else None
        if not isinstance(weight_map, dict) or not weight_map:
            # An index that names no shards cannot describe usable weights;
            # all() over an empty collection would wrongly say "complete".
            return False

        shard_names = list(weight_map.values())
        if not all(isinstance(name, str) for name in shard_names):
            return False
        # Many tensors map to the same shard, so check each file only once.
        return all(os.path.isfile(_in_dir(name)) for name in set(shard_names))

    for weights_name, index_name in (
        (_SAFE_WEIGHTS, _SAFE_WEIGHTS_INDEX),
        (_WEIGHTS, _WEIGHTS_INDEX),
    ):
        if os.path.isfile(_in_dir(weights_name)):
            return True
        index_path = _in_dir(index_name)
        if os.path.isfile(index_path):
            # Once an index exists the verdict is final for this directory;
            # later formats are deliberately not consulted.
            return _all_shards_present(index_path)
    return False
def ensure_model_local(model_path: str, *, force_download: bool = False) -> str:
    """Make sure the model is available on disk and return its local path.

    Resolution rules:
    - An existing local directory is returned unchanged.
    - A string containing "/" that is not an existing path is treated as a
      HuggingFace repo ID. The local cache is consulted first without network
      access; ``snapshot_download(model_path)`` is called when that lookup
      fails or the cached weights are incomplete.
    - Anything else is returned unchanged.

    Args:
        model_path: Local directory or HuggingFace repo ID.
        force_download: When true, skip the offline cache lookup and call
            ``snapshot_download(model_path)`` right away. The flag itself is
            not forwarded to ``snapshot_download``.

    Returns:
        The local path of the model.

    Raises:
        Whatever the online ``snapshot_download`` call raises; it is attempted
        exactly once and its errors are not swallowed.
    """
    if os.path.isdir(model_path):
        return model_path
    if "/" not in model_path or os.path.exists(model_path):
        return model_path

    # Imported lazily so purely local usage does not need huggingface_hub.
    from huggingface_hub import snapshot_download

    if not force_download:
        try:
            cached_path = snapshot_download(model_path, local_files_only=True)
            if _is_model_cache_complete(cached_path):
                return cached_path
        except Exception:
            # Deliberate best-effort: any failure of the offline lookup or of
            # the completeness check simply means "go online". Only the
            # offline steps live in this try block, so a failing download
            # below is not caught and retried a second time.
            pass

    return snapshot_download(model_path)
def resolve_and_load(model_path: str, loader: Callable[[str, bool], T]) -> T:
    """Resolve the model to a local path, then hand it to *loader*.

    The loader is called as ``loader(local_path, True)``; the second argument
    is always True and stands for ``local_files_only``, so the load step
    itself is asked to use local files only.

    Args:
        model_path: Local directory or HuggingFace repo ID, resolved through
            ``ensure_model_local``.
        loader: Callable taking the local path and the local-files-only flag.

    Returns:
        Whatever *loader* returns.
    """
    return loader(ensure_model_local(model_path), True)
|