# math-under-llm / core/fetcher.py
# Alex W.
# feat: add global debug switch and unified debug logging system
# 0105df7
# core/fetcher.py
"""
Read safetensors weights via HTTP Range requests.
Zero download: tensors are read directly from the HuggingFace remote.
"""
import struct
import json
import requests
import torch
from huggingface_hub import list_repo_files
from core.debug import dprint
# ─────────────────────────────────────────────
# dtype mapping
# ─────────────────────────────────────────────
# Maps a safetensors dtype tag to (torch dtype, size of one element in bytes).
DTYPE_MAP = dict(
    F32=(torch.float32, 4),
    F16=(torch.float16, 2),
    BF16=(torch.bfloat16, 2),
    F64=(torch.float64, 8),
    I32=(torch.int32, 4),
    I64=(torch.int64, 8),
    I8=(torch.int8, 1),
    U8=(torch.uint8, 1),
)

# float8 dtypes are registered only when the installed torch exposes them;
# on a torch without these attributes the entries are simply left out.
try:
    DTYPE_MAP["F8_E4M3"] = (torch.float8_e4m3fn, 1)
    DTYPE_MAP["F8_E5M2"] = (torch.float8_e5m2, 1)
except AttributeError:
    pass

# dtype tags that load_tensor_remote rejects as unusable for SVD.
UNSUPPORTED_SVD_DTYPES = {
    "I8",
    "U8",
    "I32",
    "I64",
    "F8_E4M3",
    "F8_E5M2",
}

# Substrings of tensor names that check_quantization treats as evidence of a
# quantized checkpoint.
QUANTIZED_KEY_SIGNATURES = [
    "qweight",
    "qzeros",
    "scales",
    "g_idx",
    "packed_weight",
]
# ─────────────────────────────────────────────
# URL utilities
# ─────────────────────────────────────────────
def get_file_url(model_id: str, filename: str) -> str:
    """Return the huggingface.co ``resolve/main`` URL for a file in a model repo."""
    template = "https://huggingface.co/{}/resolve/main/{}"
    return template.format(model_id, filename)
def http_error_msg(e: requests.exceptions.HTTPError, model_id: str) -> str:
code = e.response.status_code
if code == 401: return "โŒ 401 ๆœชๆŽˆๆƒ๏ผš่ฏทๅกซๅ†™ๆœ‰ๆ•ˆ็š„ HF Access Token"
if code == 403: return f"โŒ 403 ็ฆๆญข่ฎฟ้—ฎ๏ผš่ฏทๅ…ˆๆŽฅๅ— {model_id} ็š„ไฝฟ็”จๅ่ฎฎ"
if code == 404: return f"โŒ 404 ๆœชๆ‰พๅˆฐ๏ผšๆจกๅž‹ {model_id} ไธๅญ˜ๅœจ"
return f"โŒ HTTP {code}๏ผš{e}"
# ─────────────────────────────────────────────
# safetensors header reading
# ─────────────────────────────────────────────
def read_safetensors_header(url: str, token: str | None = None) -> tuple[dict, int]:
    """Read the JSON header of a remote safetensors file with HTTP Range requests.

    Two requests are made: bytes 0-7 (the little-endian unsigned 64-bit header
    length) and then the header itself, which starts at byte 8.

    Args:
        url: URL of the safetensors file.
        token: Optional token, sent as ``Authorization: Bearer <token>``.

    Returns:
        ``(header_dict, header_size)`` where ``header_dict`` is the parsed JSON
        header with the ``__metadata__`` entry removed and ``header_size`` is
        the header length in bytes.

    Raises:
        requests.exceptions.HTTPError: If either request returns an error status.
        struct.error: If fewer than 8 bytes are returned for the length prefix.
    """
    auth = {"Authorization": f"Bearer {token}"} if token else {}

    r = requests.get(url, headers={**auth, "Range": "bytes=0-7"}, timeout=30)
    r.raise_for_status()
    # A server is allowed to ignore Range and reply 200 with the full body;
    # only the first 8 bytes are the length prefix, so slice before unpacking.
    header_size = struct.unpack("<Q", r.content[:8])[0]

    r = requests.get(
        url,
        headers={**auth, "Range": f"bytes=8-{8 + header_size - 1}"},
        timeout=30,
    )
    r.raise_for_status()
    payload = r.content
    if r.status_code == 200:
        # Range was ignored: the body starts at byte 0 of the file, so cut the
        # header out of it instead of parsing the whole file as JSON.
        payload = payload[8:8 + header_size]

    raw = json.loads(payload)
    # __metadata__ is not a tensor entry; drop it so callers see tensors only.
    raw.pop("__metadata__", None)
    return raw, header_size
def load_tensor_remote(
url: str,
tensor_name: str,
header: dict,
header_size: int,
token: str = None
) -> torch.Tensor | None:
if tensor_name not in header:
return None
info = header[tensor_name]
dtype_str = info["dtype"]
shape = info["shape"]
offsets = info["data_offsets"]
if dtype_str not in DTYPE_MAP:
raise ValueError(f"ๆœช็Ÿฅ dtype: {dtype_str}")
if dtype_str in UNSUPPORTED_SVD_DTYPES:
raise ValueError(f"dtype={dtype_str} ไธบ้‡ๅŒ–ๆ ผๅผ๏ผŒๆ— ๆณ• SVD")
torch_dtype, bytes_per_elem = DTYPE_MAP[dtype_str]
abs_start = 8 + header_size + offsets[0]
abs_end = 8 + header_size + offsets[1] - 1
# โ”€โ”€ ่ฐƒ่ฏ•๏ผšๆ‰“ๅฐๅ็งปไฟกๆฏ โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
expected_bytes = offsets[1] - offsets[0]
expected_elems = 1
for d in shape:
expected_elems *= d
dprint(
f"[FETCH] {tensor_name}\n"
f" shape={shape} dtype={dtype_str}\n"
f" data_offsets={offsets}\n"
f" abs_start={abs_start} abs_end={abs_end}\n"
f" expected_bytes={expected_bytes} "
f"expected_elems={expected_elems} "
f"bytes_per_elem={bytes_per_elem}\n"
f" check: {expected_elems * bytes_per_elem} == {expected_bytes} "
f"{'โœ…' if expected_elems * bytes_per_elem == expected_bytes else 'โŒ ไธๅŒน้…!'}\n"
)
req_headers = {"Range": f"bytes={abs_start}-{abs_end}"}
if token:
req_headers["Authorization"] = f"Bearer {token}"
r = requests.get(url, headers=req_headers, timeout=120)
r.raise_for_status()
# โ”€โ”€ ่ฐƒ่ฏ•๏ผšๆ‰“ๅฐๅฎž้™…ๆ”ถๅˆฐ็š„ๅญ—่Š‚ๆ•ฐ โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
actual_bytes = len(r.content)
dprint(
f" actual_bytes={actual_bytes} "
f"{'โœ…' if actual_bytes == expected_bytes else 'โŒ ๅญ—่Š‚ๆ•ฐไธๅŒน้…!'}\n"
f" ๅ‰8ๅญ—่Š‚(hex)={r.content[:8].hex()}\n"
)
if torch_dtype == torch.bfloat16:
tensor = torch.frombuffer(
bytearray(r.content), dtype=torch.int16
).view(torch.bfloat16)
else:
tensor = torch.frombuffer(bytearray(r.content), dtype=torch_dtype)
result = tensor.reshape(shape).float()
# โ”€โ”€ ่ฐƒ่ฏ•๏ผšๆ‰“ๅฐ็ป“ๆžœ้ฆ–่กŒ โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
dprint(f" result[0,:5]={result[0,:5].tolist()}\n")
return result
# ─────────────────────────────────────────────
# File listing
# ─────────────────────────────────────────────
def get_safetensor_files(model_id: str, token: str = None) -> list[str]:
    """Return the sorted names of all ``.safetensors`` files in the repo.

    ``token`` is forwarded to ``list_repo_files`` only when it is truthy.
    """
    if token:
        repo_files = list_repo_files(model_id, token=token)
    else:
        repo_files = list_repo_files(model_id)
    matches = [name for name in repo_files if name.endswith(".safetensors")]
    matches.sort()
    return matches
def find_index_file(model_id: str, token: str = None) -> dict | None:
    """Fetch ``model.safetensors.index.json`` for the repo.

    Returns the parsed JSON when the response status is 200, otherwise None.
    """
    index_url = f"https://huggingface.co/{model_id}/resolve/main/model.safetensors.index.json"
    auth_headers = {}
    if token:
        auth_headers["Authorization"] = f"Bearer {token}"
    response = requests.get(index_url, headers=auth_headers, timeout=15)
    if response.status_code != 200:
        return None
    return response.json()
def get_all_shard_files(model_id: str, token: str = None) -> list[str]:
    """Return the sorted list of shard file names for the model.

    Uses the distinct values of ``weight_map`` from the index file when one is
    found; otherwise falls back to listing the repo's ``.safetensors`` files.
    """
    index = find_index_file(model_id, token)
    if not index:
        return get_safetensor_files(model_id, token)
    unique_shards = set(index["weight_map"].values())
    return sorted(unique_shards)
def load_all_shard_headers(
    model_id: str,
    token: str = None
) -> dict[str, tuple[dict, int]]:
    """
    Read the header of every shard of the model.
    Returns: { shard_filename: (header_dict, header_size) }
    """
    return {
        shard: read_safetensors_header(get_file_url(model_id, shard), token)
        for shard in get_all_shard_files(model_id, token)
    }
# ─────────────────────────────────────────────
# Quantization detection
# ─────────────────────────────────────────────
def check_quantization(model_id: str, token: str | None = None) -> tuple[bool, str]:
    """Run a sequence of quantization checks against a HuggingFace model repo.

    The checks run in order and the first blocking one returns immediately:
      1. quantization fields in ``config.json``
      2. quantization keywords in the model id
      3. the repo file listing (``.gguf`` present / no ``.safetensors``)
      4. tensor names and dtypes in the header of the first shard

    Failures of an individual check (network errors, unexpected data) are
    recorded as warnings rather than raised.

    Args:
        model_id: HuggingFace model id.
        token: Optional token, sent as ``Authorization: Bearer <token>``.

    Returns:
        ``(is_blocked, message)``. ``is_blocked`` is True when a check rules
        the model out; otherwise ``message`` is the newline-joined warnings
        or a default "no quantization detected" text.
    """
    hdrs = {"Authorization": f"Bearer {token}"} if token else {}
    warnings = []

    # Check 1: config.json
    try:
        r = requests.get(
            f"https://huggingface.co/{model_id}/resolve/main/config.json",
            headers=hdrs, timeout=15
        )
        if r.status_code == 200:
            cfg = r.json()
            qcfg = cfg.get("quantization_config", {}) or {}
            label = (
                qcfg.get("quant_type", "") or
                qcfg.get("quant_method", "") or
                cfg.get("quantization", "")
            )
            # These fields are not guaranteed to hold strings (the value may
            # be e.g. a nested dict or null). Calling .lower() on such a value
            # raised AttributeError, which the except below reported as an
            # unreadable config.json. Coerce to text before matching.
            qt = str(label or "").lower()
            if "gptq" in qt:
                return True, f"โŒ GPTQ {qcfg.get('bits','?')}bit๏ผŒ่ฏท็”จๅŽŸๅง‹ BF16 ็‰ˆๆœฌใ€‚"
            if "awq" in qt:
                return True, "โŒ AWQ ้‡ๅŒ–๏ผŒ่ฏท็”จๅŽŸๅง‹ BF16 ็‰ˆๆœฌใ€‚"
            if "bitsandbytes" in qt or "bnb" in qt:
                warnings.append("โš ๏ธ bitsandbytes ้‡ๅŒ–๏ผŒ็ป“ๆžœๅฏ่ƒฝๅคฑ็œŸ")
    except Exception:
        # Best effort: a failing config check must not abort the other checks.
        warnings.append("โš ๏ธ ๆ— ๆณ•่ฏปๅ– config.json")

    # Check 2: keywords in the model id
    model_id_lower = model_id.lower()
    for kw in ["gptq", "awq", "gguf"]:
        if kw in model_id_lower:
            return True, f"โŒ ๆจกๅž‹ๅๅซ '{kw.upper()}'๏ผŒ่ฏทไฝฟ็”จๅŽŸๅง‹ BF16 ็‰ˆๆœฌใ€‚"

    # Check 3: file level
    try:
        all_files = list(list_repo_files(model_id, token=token))
        if any(f.endswith(".gguf") for f in all_files):
            return True, "โŒ ๆฃ€ๆต‹ๅˆฐ .gguf ๆ–‡ไปถ๏ผŒไธๆ”ฏๆŒ่ฏฅๆ ผๅผใ€‚"
        if not any(f.endswith(".safetensors") for f in all_files):
            return True, "โŒ ๆœชๆ‰พๅˆฐ .safetensors ๆ–‡ไปถใ€‚"
    except Exception as e:
        warnings.append(f"โš ๏ธ ๆ–‡ไปถๅˆ—่กจๆฃ€ๆต‹ๅคฑ่ดฅ๏ผš{e}")

    # Check 4: header contents (first shard only)
    try:
        shard_files = get_all_shard_files(model_id, token)
        hdr, _ = read_safetensors_header(
            get_file_url(model_id, shard_files[0]), token
        )
        bad = [k for k in hdr if any(s in k for s in QUANTIZED_KEY_SIGNATURES)]
        if bad:
            return True, f"โŒ ้‡ๅŒ– key๏ผš{bad[:3]}"
        # dtypes of the first 20 header entries, minus the unsupported ones.
        good = {hdr[k].get("dtype", "") for k in list(hdr)[:20]} - UNSUPPORTED_SVD_DTYPES
        if good:
            warnings.append(f"โœ… ๆƒ้‡ๆ ผๅผ๏ผš{good}")
    except Exception as e:
        warnings.append(f"โš ๏ธ header ๆฃ€ๆต‹ๅคฑ่ดฅ๏ผš{e}")

    return False, "\n".join(warnings) if warnings else "โœ… ๆœชๆฃ€ๆต‹ๅˆฐ้‡ๅŒ–๏ผŒๅฏไปฅๆญฃๅธธๅˆ†ๆž"