Spaces:
Running
Running
# core/layer_profile.py
"""
Automatically infer the structure of every layer from safetensors headers:
- head_dim: prefer the k_norm/q_norm shape, then the config, and finally enumeration
- K=V sharing detection (whether a V key exists)
- automatic separation of component prefixes
- zero hard coding
"""
| import re | |
| from dataclasses import dataclass, field | |
# ─────────────────────────────────────────────
# QKV suffix classification
# ─────────────────────────────────────────────
# Exact exclusion list (anything that is not a main Q/K/V weight)
# Substrings that disqualify a (lower-cased) key suffix from being treated as
# a main Q/K/V weight. Matching is plain substring containment.
_EXCLUDE_PATTERNS = [
    "norm",  # layernorm, k_norm, q_norm, etc.
    "rope",  # rotary embedding
    "lm_head",
    "o_proj",  # output projection
    "out_proj",
    "post",  # post linear of the audio tower
    "relative",  # audio tower relative_k_proj
    "per_dim",  # audio tower per_dim_scale
    "scalar",
    "gate_proj",  # FFN
    "up_proj",
    "down_proj",
    "ffw_layer",  # audio FFN
    "depthwise",
    "conv",
    "linear_start",
    "linear_end",
    "per_layer",
    "embed",
    "input_max",  # audio quantization statistics
    "input_min",
    "output_max",
    "output_min",
]
# Substrings identifying query / key / value projection weights. They are
# tested in Q, K, V order, so a suffix matching several lists is reported as
# the first role that matches.
_Q_PATTERNS = ["q_proj", "wq", "query", "q_a", "q_b"]
_K_PATTERNS = ["k_proj", "wk", "key", "k_a", "k_b"]
_V_PATTERNS = ["v_proj", "wv", "value", "v_a", "v_b"]
# k_norm / q_norm: used to infer head_dim, they are not Q/K/V weights
_NORM_KEYS = ["k_norm", "q_norm"]
def classify_qkv_suffix(suffix: str) -> str | None:
    """
    Map the key suffix that follows ``layers.{N}.`` to 'q' / 'k' / 'v' / None.

    Supported layouts:
        standard: self_attn.q_proj.weight
        nested:   self_attn.q_proj.linear.weight  (audio/vision tower)

    Only suffixes ending in ``.weight`` are considered. Pattern matching is
    done by substring containment on the lower-cased suffix.
    """
    if not suffix.endswith(".weight"):
        return None

    lowered = suffix.lower()

    def mentions(patterns) -> bool:
        return any(token in lowered for token in patterns)

    # Anything on the exclusion list is never a main Q/K/V weight.
    if mentions(_EXCLUDE_PATTERNS):
        return None

    # Roles are probed in Q, K, V order; the first hit wins.
    for role, patterns in (
        ("q", _Q_PATTERNS),
        ("k", _K_PATTERNS),
        ("v", _V_PATTERNS),
    ):
        if mentions(patterns):
            return role
    return None
def is_norm_key(suffix: str) -> bool:
    """Return True when *suffix* is a k_norm/q_norm weight (used to infer head_dim)."""
    lowered = suffix.lower()
    for marker in _NORM_KEYS:
        if marker in lowered:
            # A norm marker is present; it only counts if this is a weight.
            return suffix.endswith(".weight")
    return False
# ─────────────────────────────────────────────
# LayerProfile data structures
# ─────────────────────────────────────────────
@dataclass
class QKVKey:
    """Location of a single Q/K/V weight tensor.

    Declared as a dataclass so that it can be constructed with keyword
    arguments (``QKVKey(shard=..., key=..., shape=...)``).
    """

    shard: str  # name of the shard file that contains the tensor
    key: str  # full tensor key
    shape: list  # weight shape
@dataclass
class LayerProfile:
    """
    Complete structural information for one (prefix, layer_idx) slot.

    All fields are meant to be inferred from the weight files, with zero
    hard coding. Declared as a dataclass so that it can be built with keyword
    arguments and so that the ``field(default_factory=list)`` defaults give
    every instance its own list.
    """

    prefix: str
    layer_idx: int

    # Q/K/V locations (annotations are quoted so the class does not need
    # QKVKey to be resolvable at class-creation time)
    q: "QKVKey | None" = None
    k: "QKVKey | None" = None
    v: "QKVKey | None" = None  # None = K=V sharing

    # Automatically inferred dimensions
    head_dim: int = 0
    n_q_heads: int = 0
    n_kv_heads: int = 0
    d_model: int = 0  # = q_shape[1]

    # Flags
    kv_shared: bool = False  # whether V reuses K
    complete: bool = False  # only complete when both Q and K exist
    infer_ok: bool = False  # head_dim inference succeeded

    # Where head_dim came from (for debugging),
    # e.g. "k_norm" / "q_norm" / "config" / "enum"
    head_dim_source: str = ""

    # Raw norm shapes (used to infer head_dim)
    k_norm_shape: list = field(default_factory=list)
    q_norm_shape: list = field(default_factory=list)

    def summary(self) -> str:
        """Return a fixed-width, one-line description of this layer."""
        # User-facing tag restored from its mis-encoded form ("K=V shared").
        kv_tag = "[K=V共享]" if self.kv_shared else ""
        return (
            f"Layer {self.layer_idx:3d} | "
            f"d_model={self.d_model:5d} | "
            f"head_dim={self.head_dim:4d}({self.head_dim_source}) | "
            f"n_q={self.n_q_heads:3d} n_kv={self.n_kv_heads:3d} | "
            f"{kv_tag}"
        )
# ─────────────────────────────────────────────
# Core: automatic head_dim inference
# ─────────────────────────────────────────────
| def _infer_head_dim( | |
| q_shape: list, | |
| k_shape: list, | |
| k_norm_shape: list, | |
| q_norm_shape: list, | |
| config_params: dict, | |
| ) -> tuple[int, str]: | |
| """ | |
| ๆจๆญ head_dim๏ผ่ฟๅ (head_dim, source) | |
| ไผๅ ็บง๏ผ | |
| 1. k_norm.shape[0] โ ๆๅฏ้ ๏ผGemma ็ณปๅ๏ผ | |
| 2. q_norm.shape[0] โ ๅค็จ | |
| 3. config head_dim | |
| 4. config hidden_size / num_attention_heads | |
| 5. ๆไธพๅ้ๅผ | |
| """ | |
| q_rows = q_shape[0] if q_shape else 0 | |
| k_rows = k_shape[0] if k_shape else 0 | |
| # 1. k_norm | |
| if k_norm_shape and len(k_norm_shape) == 1: | |
| d = k_norm_shape[0] | |
| if d > 0 and (q_rows == 0 or q_rows % d == 0): | |
| return d, "k_norm" | |
| # 2. q_norm | |
| if q_norm_shape and len(q_norm_shape) == 1: | |
| d = q_norm_shape[0] | |
| if d > 0 and (q_rows == 0 or q_rows % d == 0): | |
| return d, "q_norm" | |
| # 3. config head_dim | |
| if config_params: | |
| d = config_params.get("head_dim") | |
| if d and q_rows % d == 0 and k_rows % d == 0: | |
| return d, "config" | |
| # 4. config hidden_size / num_heads | |
| hs = config_params.get("hidden_size") or 0 | |
| nh = config_params.get("num_attention_heads") or 0 | |
| if hs and nh: | |
| d = hs // nh | |
| if d > 0 and q_rows % d == 0 and k_rows % d == 0: | |
| return d, "config_calc" | |
| # 5. ๆไธพ | |
| for d in [512, 256, 128, 96, 80, 64, 48, 40, 32, 16]: | |
| if q_rows % d == 0 and k_rows % d == 0: | |
| return d, "enum" | |
| return 0, "failed" | |
# ─────────────────────────────────────────────
# Main scan function
# ─────────────────────────────────────────────
def scan_model_structure(
    all_shard_headers: dict[str, tuple[dict, int]],
    config_params: dict | None = None,
) -> "dict[tuple[str, int], LayerProfile]":
    """
    Scan every shard header and build the complete LayerProfile dictionary.

    Args:
        all_shard_headers: maps shard name -> ``(header, <int>)``. Only the
            header is used; it maps tensor keys to info dicts whose
            ``"shape"`` entry is read (missing shape -> ``[]``).
        config_params: optional config values forwarded to head_dim
            inference; ``None`` is treated as an empty dict.

    Returns:
        {
            (prefix, layer_idx): LayerProfile,
            ...
        }
        Slots that lack a Q or a K weight are omitted.

    Features:
        - zero hard coding
        - automatic K=V sharing detection
        - automatic head_dim inference
        - same-numbered layers of different components (prefixes) are kept
          fully independent
    """
    config_params = config_params or {}
    layer_re = re.compile(r'layers\.(\d+)\.')

    # ── Pass 1: collect the raw information ─────────────────
    # slot -> {"q"/"k"/"v": QKVKey, "k_norm_shape"/"q_norm_shape": shape}
    raw: dict[tuple[str, int], dict] = {}
    for shard_name, (header, _) in all_shard_headers.items():
        for key, info in header.items():
            m = layer_re.search(key)
            if not m:
                continue
            layer_idx = int(m.group(1))
            prefix = key[:m.start()]  # everything before "layers.N."
            suffix = key[m.end():]
            entry = raw.setdefault((prefix, layer_idx), {})
            shape = info.get("shape", [])

            # Q/K/V classification; the first key seen for a role wins.
            role = classify_qkv_suffix(suffix)
            if role:
                if role not in entry:
                    entry[role] = QKVKey(
                        shard=shard_name,
                        key=key,
                        shape=shape,
                    )
                continue

            # Collect norm shapes (used for head_dim inference).
            if is_norm_key(suffix):
                s = suffix.lower()
                if "k_norm" in s and "k_norm_shape" not in entry:
                    entry["k_norm_shape"] = shape
                elif "q_norm" in s and "q_norm_shape" not in entry:
                    entry["q_norm_shape"] = shape

    # ── Pass 2: build the LayerProfile objects ──────────────
    profiles: dict[tuple[str, int], LayerProfile] = {}
    for slot, data in raw.items():
        prefix, layer_idx = slot
        q = data.get("q")
        k = data.get("k")
        v = data.get("v")

        # A slot is only meaningful when both Q and K exist.
        if q is None or k is None:
            continue

        # K=V sharing detection: no V key was found.
        kv_shared = v is None
        k_norm_shape = data.get("k_norm_shape", [])
        q_norm_shape = data.get("q_norm_shape", [])

        head_dim, source = _infer_head_dim(
            q_shape=q.shape,
            k_shape=k.shape,
            k_norm_shape=k_norm_shape,
            q_norm_shape=q_norm_shape,
            config_params=config_params,
        )

        q_rows = q.shape[0] if q.shape else 0
        k_rows = k.shape[0] if k.shape else 0

        # Validate divisibility for BOTH projections: a K row count that is
        # not a multiple of head_dim would silently truncate n_kv_heads.
        infer_ok = (
            head_dim > 0
            and q_rows % head_dim == 0
            and k_rows % head_dim == 0
        )
        n_q_heads = q_rows // head_dim if infer_ok else 0
        n_kv_heads = k_rows // head_dim if infer_ok else 0
        d_model = q.shape[1] if q.shape and len(q.shape) > 1 else 0

        profiles[slot] = LayerProfile(
            prefix=prefix,
            layer_idx=layer_idx,
            q=q,
            k=k,
            v=v,
            head_dim=head_dim,
            n_q_heads=n_q_heads,
            n_kv_heads=n_kv_heads,
            d_model=d_model,
            kv_shared=kv_shared,
            complete=infer_ok and n_q_heads > 0 and n_kv_heads > 0,
            infer_ok=infer_ok,
            head_dim_source=source,
            k_norm_shape=k_norm_shape,
            q_norm_shape=q_norm_shape,
        )
    return profiles
# ─────────────────────────────────────────────
# Structure overview (shown in Tab1)
# ─────────────────────────────────────────────
| def summarize_structure( | |
| profiles: dict[tuple[str, int], LayerProfile] | |
| ) -> str: | |
| """็ๆไบบ็ฑปๅฏ่ฏป็็ปๆๆฆ่งๆๆฌ""" | |
| if not profiles: | |
| return "โ ๏ธ ๆชๅ็ฐไปปไฝๆๆๅฑ\n" | |
| # ๆ prefix ๅ็ป | |
| by_prefix: dict[str, list[LayerProfile]] = {} | |
| for (prefix, _), prof in profiles.items(): | |
| by_prefix.setdefault(prefix, []).append(prof) | |
| lines = [] | |
| for prefix in sorted(by_prefix): | |
| profs = sorted(by_prefix[prefix], key=lambda p: p.layer_idx) | |
| layer_idxs = [p.layer_idx for p in profs] | |
| complete = [p for p in profs if p.complete] | |
| kv_shared = [p for p in profs if p.kv_shared] | |
| # ๆฃๆตๅผๆ head_dim | |
| head_dims = sorted(set(p.head_dim for p in complete)) | |
| lines.append(f"\n{'โ'*70}") | |
| lines.append(f"็ปไปถ๏ผ'{prefix}'") | |
| lines.append( | |
| f" ๅฑๆฐ๏ผ{len(profs)} " | |
| f"่ๅด๏ผ{layer_idxs[0]}~{layer_idxs[-1]} " | |
| f"ๅฎๆดๅฑ๏ผ{len(complete)}" | |
| ) | |
| lines.append(f" head_dim๏ผ{head_dims}") | |
| if kv_shared: | |
| lines.append( | |
| f" K=Vๅ ฑไบซๅฑ๏ผ{[p.layer_idx for p in kv_shared]}" | |
| ) | |
| # ๅผๆๅฑ่ฏฆๆ | |
| if len(head_dims) > 1: | |
| lines.append(" โ ๏ธ ๅผๆ head_dim ๆฃๆตๅฐ๏ผ") | |
| for d in head_dims: | |
| idxs = [p.layer_idx for p in complete if p.head_dim == d] | |
| lines.append(f" head_dim={d:4d} โ ๅฑ {idxs}") | |
| # ๆฏๅฑไธ่ก็ฎ่ฆไฟกๆฏ | |
| lines.append("") | |
| for p in profs: | |
| if p.complete: | |
| lines.append(f" {p.summary()}") | |
| else: | |
| lines.append( | |
| f" Layer {p.layer_idx:3d} | " | |
| f"โ ๏ธ ไธๅฎๆด " | |
| f"(head_dimๆจๆญ:{p.head_dim_source})" | |
| ) | |
| lines.append(f"\n{'โ'*70}") | |
| return "\n".join(lines) | |
# ─────────────────────────────────────────────
# Config parsing (compatible with Gemma4 text_config)
# ─────────────────────────────────────────────
def extract_config_params(config: dict) -> dict:
    """
    Pull the attention-related fields out of a model's config.json.

    Handles both layouts:
        - standard: fields at the top level
        - Gemma4:   fields nested under ``text_config``

    A top-level value takes precedence over the ``text_config`` one; a field
    found in neither place is reported as ``None``. An empty/falsy config
    yields ``{}``.
    """
    if not config:
        return {}

    nested = config.get("text_config") or {}

    def lookup(name):
        # Top level first, then text_config; only None counts as "missing".
        for source in (config, nested):
            value = source.get(name)
            if value is not None:
                return value
        return None

    wanted = (
        "model_type",
        "hidden_size",
        "num_attention_heads",
        "num_key_value_heads",
        "head_dim",
    )
    return {name: lookup(name) for name in wanted}