Spaces:
Running
Running
# core/fetcher.py
"""
Read safetensors weights via HTTP Range requests.
Zero download: tensors are read directly from the HuggingFace remote.
"""
| import struct | |
| import json | |
| import requests | |
| import torch | |
| from huggingface_hub import list_repo_files | |
| from core.debug import dprint | |
# ─────────────────────────────────────────────
# dtype mapping
# ─────────────────────────────────────────────
# safetensors dtype tag -> (torch dtype, size of one element in bytes).
DTYPE_MAP = dict(
    F32=(torch.float32, 4),
    F16=(torch.float16, 2),
    BF16=(torch.bfloat16, 2),
    F64=(torch.float64, 8),
    I32=(torch.int32, 4),
    I64=(torch.int64, 8),
    I8=(torch.int8, 1),
    U8=(torch.uint8, 1),
)

# The float8 dtypes are registered one after the other; the first attribute
# that this torch build does not expose stops the registration silently.
try:
    DTYPE_MAP.update(F8_E4M3=(torch.float8_e4m3fn, 1))
    DTYPE_MAP.update(F8_E5M2=(torch.float8_e5m2, 1))
except AttributeError:
    pass

# dtype tags that are known to the map but rejected for SVD analysis.
UNSUPPORTED_SVD_DTYPES = set(
    ("I8", "U8", "I32", "I64", "F8_E4M3", "F8_E5M2")
)

# Substrings of tensor names that mark a checkpoint as quantized.
QUANTIZED_KEY_SIGNATURES = "qweight qzeros scales g_idx packed_weight".split()
# ─────────────────────────────────────────────
# URL utilities
# ─────────────────────────────────────────────
def get_file_url(model_id: str, filename: str) -> str:
    """Build the ``resolve/main`` download URL for a file in a model repo.

    Args:
        model_id: Repository id, e.g. ``"org/model"``.
        filename: Path of the file inside the repository; ``/`` separators
            are kept as path separators.

    Returns:
        The absolute ``https://huggingface.co/...`` URL for the file.
    """
    from urllib.parse import quote

    # Percent-encode the path so that characters such as '#', '?' or spaces
    # are not interpreted as a fragment/query or sent raw; '/' is left intact
    # by quote()'s default ``safe`` set.
    return f"https://huggingface.co/{model_id}/resolve/main/{quote(filename)}"
def http_error_msg(e: requests.exceptions.HTTPError, model_id: str) -> str:
    """Turn an HTTP error into a user-facing message.

    Args:
        e: The HTTP error; its ``response.status_code`` selects the message.
        model_id: Repository id interpolated into the 403 and 404 messages.

    Returns:
        A dedicated message for 401, 403 and 404, otherwise a generic one
        containing the status code and the error text.
    """
    status = e.response.status_code
    known_messages = {
        401: "โ 401 ๆชๆๆ๏ผ่ฏทๅกซๅๆๆ็ HF Access Token",
        403: f"โ 403 ็ฆๆญข่ฎฟ้ฎ๏ผ่ฏทๅ ๆฅๅ {model_id} ็ไฝฟ็จๅ่ฎฎ",
        404: f"โ 404 ๆชๆพๅฐ๏ผๆจกๅ {model_id} ไธๅญๅจ",
    }
    if status in known_messages:
        return known_messages[status]
    # Any other status falls back to the generic form.
    return f"โ HTTP {status}๏ผ{e}"
# ─────────────────────────────────────────────
# safetensors header reading
# ─────────────────────────────────────────────
def read_safetensors_header(url: str, token: str | None = None) -> tuple[dict, int]:
    """Fetch and parse the JSON header of a remote safetensors file.

    Two ranged GET requests are issued: bytes 0-7 hold the header length as a
    little-endian unsigned 64-bit integer, and the following ``header_size``
    bytes hold the JSON header itself.

    Args:
        url: Direct URL of the ``.safetensors`` file.
        token: Optional access token, sent as a ``Bearer`` authorization
            header when truthy.

    Returns:
        ``(header_dict, header_size)`` where ``header_dict`` is the parsed
        JSON with the ``"__metadata__"`` entry removed and ``header_size`` is
        the JSON length in bytes (excluding the 8-byte prefix).

    Raises:
        requests.exceptions.HTTPError: If either request returns an error
            status.
        ValueError: If the server returns the wrong number of bytes, the
            declared header size is implausible, or the header is not a JSON
            object.
    """
    # Sanity cap so that a corrupt length prefix cannot trigger a huge request.
    max_header_size = 100_000_000
    auth = {"Authorization": f"Bearer {token}"} if token else {}

    def fetch_range(first: int, last: int) -> bytes:
        """Return exactly the bytes ``first..last`` (inclusive) of the file."""
        resp = requests.get(
            url,
            headers={**auth, "Range": f"bytes={first}-{last}"},
            timeout=30,
        )
        resp.raise_for_status()
        payload = resp.content
        if resp.status_code != 206:
            # The Range header was ignored and the body starts at byte 0 of
            # the file, so cut the requested window out of it.
            payload = payload[first:last + 1]
        wanted = last - first + 1
        if len(payload) != wanted:
            raise ValueError(
                f"expected {wanted} bytes for range {first}-{last}, "
                f"got {len(payload)}"
            )
        return payload

    (header_size,) = struct.unpack("<Q", fetch_range(0, 7))
    if not 0 < header_size <= max_header_size:
        raise ValueError(f"implausible safetensors header size: {header_size}")

    raw = json.loads(fetch_range(8, 8 + header_size - 1))
    if not isinstance(raw, dict):
        raise ValueError("safetensors header is not a JSON object")
    # "__metadata__" is free-form metadata, not a tensor entry.
    raw.pop("__metadata__", None)
    return raw, header_size
| def load_tensor_remote( | |
| url: str, | |
| tensor_name: str, | |
| header: dict, | |
| header_size: int, | |
| token: str = None | |
| ) -> torch.Tensor | None: | |
| if tensor_name not in header: | |
| return None | |
| info = header[tensor_name] | |
| dtype_str = info["dtype"] | |
| shape = info["shape"] | |
| offsets = info["data_offsets"] | |
| if dtype_str not in DTYPE_MAP: | |
| raise ValueError(f"ๆช็ฅ dtype: {dtype_str}") | |
| if dtype_str in UNSUPPORTED_SVD_DTYPES: | |
| raise ValueError(f"dtype={dtype_str} ไธบ้ๅๆ ผๅผ๏ผๆ ๆณ SVD") | |
| torch_dtype, bytes_per_elem = DTYPE_MAP[dtype_str] | |
| abs_start = 8 + header_size + offsets[0] | |
| abs_end = 8 + header_size + offsets[1] - 1 | |
| # โโ ่ฐ่ฏ๏ผๆๅฐๅ็งปไฟกๆฏ โโโโโโโโโโโโโโโโโโโโโโโโ | |
| expected_bytes = offsets[1] - offsets[0] | |
| expected_elems = 1 | |
| for d in shape: | |
| expected_elems *= d | |
| dprint( | |
| f"[FETCH] {tensor_name}\n" | |
| f" shape={shape} dtype={dtype_str}\n" | |
| f" data_offsets={offsets}\n" | |
| f" abs_start={abs_start} abs_end={abs_end}\n" | |
| f" expected_bytes={expected_bytes} " | |
| f"expected_elems={expected_elems} " | |
| f"bytes_per_elem={bytes_per_elem}\n" | |
| f" check: {expected_elems * bytes_per_elem} == {expected_bytes} " | |
| f"{'โ ' if expected_elems * bytes_per_elem == expected_bytes else 'โ ไธๅน้ !'}\n" | |
| ) | |
| req_headers = {"Range": f"bytes={abs_start}-{abs_end}"} | |
| if token: | |
| req_headers["Authorization"] = f"Bearer {token}" | |
| r = requests.get(url, headers=req_headers, timeout=120) | |
| r.raise_for_status() | |
| # โโ ่ฐ่ฏ๏ผๆๅฐๅฎ้ ๆถๅฐ็ๅญ่ๆฐ โโโโโโโโโโโโโโโโ | |
| actual_bytes = len(r.content) | |
| dprint( | |
| f" actual_bytes={actual_bytes} " | |
| f"{'โ ' if actual_bytes == expected_bytes else 'โ ๅญ่ๆฐไธๅน้ !'}\n" | |
| f" ๅ8ๅญ่(hex)={r.content[:8].hex()}\n" | |
| ) | |
| if torch_dtype == torch.bfloat16: | |
| tensor = torch.frombuffer( | |
| bytearray(r.content), dtype=torch.int16 | |
| ).view(torch.bfloat16) | |
| else: | |
| tensor = torch.frombuffer(bytearray(r.content), dtype=torch_dtype) | |
| result = tensor.reshape(shape).float() | |
| # โโ ่ฐ่ฏ๏ผๆๅฐ็ปๆ้ฆ่ก โโโโโโโโโโโโโโโโโโโโโโโโ | |
| dprint(f" result[0,:5]={result[0,:5].tolist()}\n") | |
| return result | |
# ─────────────────────────────────────────────
# File listing
# ─────────────────────────────────────────────
def get_safetensor_files(model_id: str, token: str = None) -> list[str]:
    """Return the sorted names of all ``.safetensors`` files in a repo.

    The token is forwarded to ``list_repo_files`` only when it is truthy.
    """
    if token:
        repo_files = list_repo_files(model_id, token=token)
    else:
        repo_files = list_repo_files(model_id)
    matches = [name for name in repo_files if name.endswith(".safetensors")]
    matches.sort()
    return matches
def find_index_file(model_id: str, token: str = None) -> dict | None:
    """Fetch ``model.safetensors.index.json`` from the repo.

    Returns the parsed JSON when the request answers with status 200 and
    ``None`` for every other status.
    """
    index_url = f"https://huggingface.co/{model_id}/resolve/main/model.safetensors.index.json"
    auth = {}
    if token:
        auth["Authorization"] = f"Bearer {token}"
    response = requests.get(index_url, headers=auth, timeout=15)
    if response.status_code != 200:
        return None
    return response.json()
def get_all_shard_files(model_id: str, token: str = None) -> list[str]:
    """Return the sorted list of shard filenames of a model.

    When an index file is available, the distinct values of its
    ``"weight_map"`` are used; otherwise the repo's ``.safetensors`` files
    are listed instead.
    """
    index = find_index_file(model_id, token)
    if not index:
        return get_safetensor_files(model_id, token)
    unique_shards = set(index["weight_map"].values())
    return sorted(unique_shards)
def load_all_shard_headers(
    model_id: str,
    token: str = None
) -> dict[str, tuple[dict, int]]:
    """Read the header of every shard of a model.

    Returns:
        ``{shard_filename: (header_dict, header_size)}``, in the order the
        shard names are returned by ``get_all_shard_files``.
    """
    headers_by_shard = {}
    for shard_name in get_all_shard_files(model_id, token):
        shard_url = get_file_url(model_id, shard_name)
        header_dict, header_len = read_safetensors_header(shard_url, token)
        headers_by_shard[shard_name] = (header_dict, header_len)
    return headers_by_shard
# ─────────────────────────────────────────────
# Quantization detection
# ─────────────────────────────────────────────
def check_quantization(model_id: str, token: str | None = None) -> tuple[bool, str]:
    """Run a series of quantization checks against a model repo.

    The checks are, in order: the quantization fields of ``config.json``,
    keywords in the model id, the repo's file list, and the tensor names and
    dtypes in the header of the first shard. A failing check that is not
    conclusive only adds a warning line.

    Args:
        model_id: Repository id, e.g. ``"org/model"``.
        token: Optional access token used for all requests.

    Returns:
        ``(is_blocked, message)``. ``is_blocked`` is ``True`` as soon as one
        check identifies an unsupported model; otherwise ``message`` holds
        the collected warning lines, or a success message when there are
        none.
    """
    hdrs = {"Authorization": f"Bearer {token}"} if token else {}
    warnings = []

    # Check 1: config.json
    try:
        r = requests.get(
            f"https://huggingface.co/{model_id}/resolve/main/config.json",
            headers=hdrs, timeout=15
        )
        if r.status_code == 200:
            cfg = r.json()
            qcfg = cfg.get("quantization_config", {}) or {}
            raw_qt = (
                qcfg.get("quant_type", "") or
                qcfg.get("quant_method", "") or
                cfg.get("quantization", "")
            )
            # The value may be None or a non-string (e.g. a nested dict), so
            # coerce it before lowering; otherwise the AttributeError would
            # be reported below as an unreadable config.json.
            qt = str(raw_qt or "").lower()
            if "gptq" in qt:
                return True, f"โ GPTQ {qcfg.get('bits','?')}bit๏ผ่ฏท็จๅๅง BF16 ็ๆฌใ"
            if "awq" in qt:
                return True, "โ AWQ ้ๅ๏ผ่ฏท็จๅๅง BF16 ็ๆฌใ"
            if "bitsandbytes" in qt or "bnb" in qt:
                warnings.append("โ ๏ธ bitsandbytes ้ๅ๏ผ็ปๆๅฏ่ฝๅคฑ็")
    except Exception:
        # Best effort: a failure here must not abort the remaining checks.
        warnings.append("โ ๏ธ ๆ ๆณ่ฏปๅ config.json")

    # Check 2: keywords in the model id
    for kw in ["gptq", "awq", "gguf"]:
        if kw in model_id.lower():
            return True, f"โ ๆจกๅๅๅซ '{kw.upper()}'๏ผ่ฏทไฝฟ็จๅๅง BF16 ็ๆฌใ"

    # Check 3: file level
    try:
        all_files = list(list_repo_files(model_id, token=token))
        if any(f.endswith(".gguf") for f in all_files):
            return True, "โ ๆฃๆตๅฐ .gguf ๆไปถ๏ผไธๆฏๆ่ฏฅๆ ผๅผใ"
        if not any(f.endswith(".safetensors") for f in all_files):
            return True, "โ ๆชๆพๅฐ .safetensors ๆไปถใ"
    except Exception as e:
        warnings.append(f"โ ๏ธ ๆไปถๅ่กจๆฃๆตๅคฑ่ดฅ๏ผ{e}")

    # Check 4: header content of the first shard
    try:
        shard_files = get_all_shard_files(model_id, token)
        hdr, _ = read_safetensors_header(
            get_file_url(model_id, shard_files[0]), token
        )
        bad = [k for k in hdr if any(s in k for s in QUANTIZED_KEY_SIGNATURES)]
        if bad:
            return True, f"โ ้ๅ key๏ผ{bad[:3]}"
        # Only the first 20 entries are sampled for their dtype.
        good = {hdr[k].get("dtype", "") for k in list(hdr)[:20]} - UNSUPPORTED_SVD_DTYPES
        if good:
            warnings.append(f"โ ๆ้ๆ ผๅผ๏ผ{good}")
    except Exception as e:
        warnings.append(f"โ ๏ธ header ๆฃๆตๅคฑ่ดฅ๏ผ{e}")

    return False, "\n".join(warnings) if warnings else "โ ๆชๆฃๆตๅฐ้ๅ๏ผๅฏไปฅๆญฃๅธธๅๆ"