# math-under-llm / core / layer_profile.py
# Alex W.
# feat: restructure into modular architecture with auto-inference engine
# 81d60af
# core/layer_profile.py
"""
Automatically infer the structure of every layer from safetensors headers:
- head_dim (k_norm/q_norm shape first, then config, finally enumeration)
- K=V sharing detection (whether a V key exists)
- automatic separation of component prefixes
- zero hard coding
"""
import re
from dataclasses import dataclass, field
# ─────────────────────────────────────────────
# QKV suffix classification
# ─────────────────────────────────────────────
# Exact exclusion list: a suffix containing any of these substrings is never
# treated as a main Q/K/V weight.
_EXCLUDE_PATTERNS = [
    "norm",  # layernorm, k_norm, q_norm, etc.
    "rope",  # rotary embedding
    "lm_head",
    "o_proj",  # output projection
    "out_proj",
    "post",  # audio tower post linear
    "relative",  # audio tower relative_k_proj
    "per_dim",  # audio tower per_dim_scale
    "scalar",
    "gate_proj",  # FFN
    "up_proj",
    "down_proj",
    "ffw_layer",  # audio FFN
    "depthwise",
    "conv",
    "linear_start",
    "linear_end",
    "per_layer",
    "embed",
    "input_max",  # audio quantization statistics
    "input_min",
    "output_max",
    "output_min",
]
# Substrings that mark a suffix as a query / key / value weight respectively.
_Q_PATTERNS = ["q_proj", "wq", "query", "q_a", "q_b"]
_K_PATTERNS = ["k_proj", "wk", "key", "k_a", "k_b"]
_V_PATTERNS = ["v_proj", "wv", "value", "v_a", "v_b"]
# k_norm / q_norm: used to infer head_dim, not Q/K/V weights themselves
_NORM_KEYS = ["k_norm", "q_norm"]
def classify_qkv_suffix(suffix: str) -> str | None:
    """
    Map the part of a tensor key that follows ``layers.{N}.`` to a role.

    Returns ``'q'``, ``'k'`` or ``'v'`` for a main attention projection
    weight, and ``None`` for everything else.

    Supported layouts:
        standard: self_attn.q_proj.weight
        nested:   self_attn.q_proj.linear.weight (audio/vision tower)
    """
    # Only ``*.weight`` tensors can be Q/K/V projections.
    if not suffix.endswith(".weight"):
        return None

    lowered = suffix.lower()

    # Anything that matches the exclusion list is not a Q/K/V weight.
    for excluded in _EXCLUDE_PATTERNS:
        if excluded in lowered:
            return None

    # Roles are tried in q -> k -> v order; the first role with a hit wins.
    role_table = (
        ("q", _Q_PATTERNS),
        ("k", _K_PATTERNS),
        ("v", _V_PATTERNS),
    )
    for role, patterns in role_table:
        for pattern in patterns:
            if pattern in lowered:
                return role
    return None
def is_norm_key(suffix: str) -> bool:
    """Tell whether ``suffix`` is a k_norm/q_norm weight (used to infer head_dim)."""
    lowered = suffix.lower()
    for marker in _NORM_KEYS:
        if marker in lowered:
            # A norm marker alone is not enough: it must be the weight tensor.
            return suffix.endswith(".weight")
    return False
# ─────────────────────────────────────────────
# LayerProfile data structures
# ─────────────────────────────────────────────
@dataclass
class QKVKey:
    """Location information for a single Q/K/V weight."""
    shard: str  # name of the shard file that holds the weight
    key: str  # full tensor key name
    shape: list  # weight shape
@dataclass
class LayerProfile:
    """
    Complete structural information for one (prefix, layer_idx) slot.

    Every field is meant to be inferred from the weight files rather than
    hard-coded.
    """
    prefix: str
    layer_idx: int
    # Q/K/V locations
    q: QKVKey | None = None
    k: QKVKey | None = None
    v: QKVKey | None = None  # None = V is shared with K
    # Inferred dimensions
    head_dim: int = 0
    n_q_heads: int = 0
    n_kv_heads: int = 0
    d_model: int = 0  # = q_shape[1]
    # Flags
    kv_shared: bool = False  # whether V reuses K
    complete: bool = False  # only complete when both Q and K exist
    infer_ok: bool = False  # head_dim inference succeeded
    # Where head_dim came from (for debugging)
    head_dim_source: str = ""  # "k_norm" / "q_norm" / "config" / "enum"
    # Raw norm shapes (used to infer head_dim)
    k_norm_shape: list = field(default_factory=list)
    q_norm_shape: list = field(default_factory=list)

    def summary(self) -> str:
        """Return a one-line, fixed-width description of this layer."""
        # The tag text was mis-encoded in the source; restored to the
        # intended "[K=V shared]" marker.
        kv_tag = "[K=V共享]" if self.kv_shared else ""
        return (
            f"Layer {self.layer_idx:3d} | "
            f"d_model={self.d_model:5d} | "
            f"head_dim={self.head_dim:4d}({self.head_dim_source}) | "
            f"n_q={self.n_q_heads:3d} n_kv={self.n_kv_heads:3d} | "
            f"{kv_tag}"
        )
# ─────────────────────────────────────────────
# Core: automatic head_dim inference
# ─────────────────────────────────────────────
def _infer_head_dim(
    q_shape: list,
    k_shape: list,
    k_norm_shape: list,
    q_norm_shape: list,
    config_params: dict,
) -> tuple[int, str]:
    """
    Infer head_dim and report where it came from.

    Sources are tried in priority order and the first candidate that is
    positive and divides both the Q and the K row count is returned:
        1. k_norm.shape[0]  (1-D norm weight)           -> "k_norm"
        2. q_norm.shape[0]  (1-D norm weight)           -> "q_norm"
        3. config["head_dim"]                           -> "config"
        4. config hidden_size // num_attention_heads    -> "config_calc"
        5. largest fitting value of a fixed candidate list -> "enum"

    A missing/empty Q or K shape counts as 0 rows, which every candidate
    divides, so sources 1-4 still work without shape information.

    Args:
        q_shape: shape of the Q weight; only element 0 (rows) is used.
        k_shape: shape of the K weight; only element 0 (rows) is used.
        k_norm_shape: shape of the k_norm weight, or an empty list.
        q_norm_shape: shape of the q_norm weight, or an empty list.
        config_params: config values; may be empty or None.

    Returns:
        (head_dim, source), or (0, "failed") when nothing fits.
    """
    q_rows = q_shape[0] if q_shape else 0
    k_rows = k_shape[0] if k_shape else 0

    def fits(d) -> bool:
        # Checking K as well as Q keeps the K head count an exact integer.
        return d > 0 and q_rows % d == 0 and k_rows % d == 0

    # 1. / 2. a 1-D norm weight has exactly head_dim elements
    for norm_shape, source in ((k_norm_shape, "k_norm"), (q_norm_shape, "q_norm")):
        if norm_shape and len(norm_shape) == 1 and fits(norm_shape[0]):
            return norm_shape[0], source

    if config_params:
        # 3. explicit head_dim from config
        d = config_params.get("head_dim")
        if d and fits(d):
            return d, "config"
        # 4. hidden_size / num_attention_heads
        hs = config_params.get("hidden_size") or 0
        nh = config_params.get("num_attention_heads") or 0
        if hs and nh:
            d = hs // nh
            if fits(d):
                return d, "config_calc"

    # 5. enumeration - only meaningful with at least one real row count,
    #    otherwise 0 % d == 0 would "confirm" the first candidate (512).
    if q_rows > 0 or k_rows > 0:
        for d in (512, 256, 128, 96, 80, 64, 48, 40, 32, 16):
            if fits(d):
                return d, "enum"
    return 0, "failed"
# ─────────────────────────────────────────────
# Main scan function
# ─────────────────────────────────────────────
def scan_model_structure(
    all_shard_headers: dict[str, tuple[dict, int]],
    config_params: dict | None = None,
) -> dict[tuple[str, int], LayerProfile]:
    """
    Scan every shard header and build the LayerProfile dictionary.

    Args:
        all_shard_headers: maps a shard name to a ``(header, _)`` pair; the
            header maps tensor keys to info dicts whose ``"shape"`` entry is
            read. The second tuple element is ignored here.
        config_params: optional config values forwarded to the head_dim
            inference; ``None`` is treated as an empty dict.

    Returns:
        ``{(prefix, layer_idx): LayerProfile, ...}`` where ``prefix`` is the
        part of the key before ``layers.{N}.``. Slots without both a Q and a
        K weight are omitted.

    Notes:
        - The same layer index under different prefixes is kept separate.
        - For each slot the first tensor seen per role (q/k/v, k_norm,
          q_norm) wins; later matches are ignored.
        - A slot without a V weight is flagged as K=V shared.
    """
    config_params = config_params or {}
    layer_re = re.compile(r'layers\.(\d+)\.')

    # -- Pass 1: collect raw information per slot ------------------------
    # slot -> {"q"/"k"/"v": QKVKey, "k_norm_shape"/"q_norm_shape": list}
    raw: dict[tuple[str, int], dict] = {}
    for shard_name, (header, _) in all_shard_headers.items():
        for key, info in header.items():
            m = layer_re.search(key)
            if not m:
                continue
            slot = (key[:m.start()], int(m.group(1)))
            suffix = key[m.end():]
            entry = raw.setdefault(slot, {})
            shape = info.get("shape", [])

            role = classify_qkv_suffix(suffix)
            if role:
                if role not in entry:
                    entry[role] = QKVKey(shard=shard_name, key=key, shape=shape)
                continue

            # Norm shapes are only kept to help head_dim inference.
            if is_norm_key(suffix):
                s = suffix.lower()
                if "k_norm" in s and "k_norm_shape" not in entry:
                    entry["k_norm_shape"] = shape
                elif "q_norm" in s and "q_norm_shape" not in entry:
                    entry["q_norm_shape"] = shape

    # -- Pass 2: build the LayerProfile objects --------------------------
    profiles: dict[tuple[str, int], LayerProfile] = {}
    for slot, data in raw.items():
        prefix, layer_idx = slot
        q = data.get("q")
        k = data.get("k")
        v = data.get("v")
        # A slot is only meaningful when both Q and K exist.
        if q is None or k is None:
            continue

        k_norm_shape = data.get("k_norm_shape", [])
        q_norm_shape = data.get("q_norm_shape", [])
        head_dim, source = _infer_head_dim(
            q_shape=q.shape,
            k_shape=k.shape,
            k_norm_shape=k_norm_shape,
            q_norm_shape=q_norm_shape,
            config_params=config_params,
        )

        q_rows = q.shape[0] if q.shape else 0
        k_rows = k.shape[0] if k.shape else 0
        # head_dim must divide BOTH row counts; otherwise the head counts
        # below would be silently truncated by the integer division.
        infer_ok = (
            head_dim > 0
            and q_rows % head_dim == 0
            and k_rows % head_dim == 0
        )
        n_q_heads = q_rows // head_dim if infer_ok else 0
        n_kv_heads = k_rows // head_dim if infer_ok else 0
        d_model = q.shape[1] if q.shape and len(q.shape) > 1 else 0

        profiles[slot] = LayerProfile(
            prefix=prefix,
            layer_idx=layer_idx,
            q=q,
            k=k,
            v=v,
            head_dim=head_dim,
            n_q_heads=n_q_heads,
            n_kv_heads=n_kv_heads,
            d_model=d_model,
            kv_shared=(v is None),  # no V weight -> V reuses K
            complete=infer_ok and n_q_heads > 0 and n_kv_heads > 0,
            infer_ok=infer_ok,
            head_dim_source=source,
            k_norm_shape=k_norm_shape,
            q_norm_shape=q_norm_shape,
        )
    return profiles
# ─────────────────────────────────────────────
# Structure overview (for display in Tab1)
# ─────────────────────────────────────────────
def summarize_structure(
profiles: dict[tuple[str, int], LayerProfile]
) -> str:
"""็”Ÿๆˆไบบ็ฑปๅฏ่ฏป็š„็ป“ๆž„ๆฆ‚่งˆๆ–‡ๆœฌ"""
if not profiles:
return "โš ๏ธ ๆœชๅ‘็Žฐไปปไฝ•ๆœ‰ๆ•ˆๅฑ‚\n"
# ๆŒ‰ prefix ๅˆ†็ป„
by_prefix: dict[str, list[LayerProfile]] = {}
for (prefix, _), prof in profiles.items():
by_prefix.setdefault(prefix, []).append(prof)
lines = []
for prefix in sorted(by_prefix):
profs = sorted(by_prefix[prefix], key=lambda p: p.layer_idx)
layer_idxs = [p.layer_idx for p in profs]
complete = [p for p in profs if p.complete]
kv_shared = [p for p in profs if p.kv_shared]
# ๆฃ€ๆต‹ๅผ‚ๆž„ head_dim
head_dims = sorted(set(p.head_dim for p in complete))
lines.append(f"\n{'โ”€'*70}")
lines.append(f"็ป„ไปถ๏ผš'{prefix}'")
lines.append(
f" ๅฑ‚ๆ•ฐ๏ผš{len(profs)} "
f"่Œƒๅ›ด๏ผš{layer_idxs[0]}~{layer_idxs[-1]} "
f"ๅฎŒๆ•ดๅฑ‚๏ผš{len(complete)}"
)
lines.append(f" head_dim๏ผš{head_dims}")
if kv_shared:
lines.append(
f" K=Vๅ…ฑไบซๅฑ‚๏ผš{[p.layer_idx for p in kv_shared]}"
)
# ๅผ‚ๆž„ๅฑ‚่ฏฆๆƒ…
if len(head_dims) > 1:
lines.append(" โš ๏ธ ๅผ‚ๆž„ head_dim ๆฃ€ๆต‹ๅˆฐ๏ผš")
for d in head_dims:
idxs = [p.layer_idx for p in complete if p.head_dim == d]
lines.append(f" head_dim={d:4d} โ†’ ๅฑ‚ {idxs}")
# ๆฏๅฑ‚ไธ€่กŒ็ฎ€่ฆไฟกๆฏ
lines.append("")
for p in profs:
if p.complete:
lines.append(f" {p.summary()}")
else:
lines.append(
f" Layer {p.layer_idx:3d} | "
f"โš ๏ธ ไธๅฎŒๆ•ด "
f"(head_dimๆŽจๆ–ญ:{p.head_dim_source})"
)
lines.append(f"\n{'โ”€'*70}")
return "\n".join(lines)
# ─────────────────────────────────────────────
# config parsing (compatible with Gemma4 text_config)
# ─────────────────────────────────────────────
def extract_config_params(config: dict) -> dict:
    """
    Pull the fields needed for inference out of a config.json dict.

    Handles both layouts:
    - standard: fields at the top level
    - Gemma4: fields nested under ``text_config``

    A top-level value wins over the nested one; a field found in neither
    place is reported as ``None``. An empty or missing config yields ``{}``.
    """
    if not config:
        return {}

    nested = config.get("text_config", {}) or {}
    wanted = (
        "model_type",
        "hidden_size",
        "num_attention_heads",
        "num_key_value_heads",
        "head_dim",
    )

    params = {}
    for name in wanted:
        value = config.get(name)
        if value is None:
            # Fall back to the nested text config (None when absent there too).
            value = nested.get(name)
        params[name] = value
    return params