# math-under-llm / ui/tab_inspect.py
# Author: Alex W.
# Last commit: refactor(ui): translate tab_inspect.py to English
# Commit: 8d21309
# ui/tab_inspect.py
"""
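Tab 1: Model Structure Inspection
- Read the safetensors headers of every shard (no weights are downloaded)
- Auto-build a LayerProfile per (component prefix, layer) and display the inferred results
- When no Q/K/V layers are detected, show the first raw tensor keys instead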
"""
import gradio as gr
import requests
import pandas as pd
from core.fetcher import (
load_all_shard_headers,
get_file_url,
check_quantization,
http_error_msg,
)
from core.layer_profile import (
scan_model_structure,
summarize_structure,
extract_config_params,
)
# Markdown shown in the sidebar of the Inspect tab: suggested model IDs plus
# a short guide to how layer indices are reported.
SIDEBAR_MD = """
### ✅ Recommended Models
- google/gemma-4-e2b
- google/gemma-4-e4b-it
- google/gemma-4-31b-it
- Qwen/Qwen2.5-14B-Instruct
- deepseek-ai/DeepSeek-R1-Distill-Qwen-14B
- meta-llama/Meta-Llama-3-8B (gated: access request required)

---

### Layer Index
- Layer index = **N** in `layers.{N}` of safetensors keys
- Raw index, **not re-numbered per component**
- Multi-modal models (e.g. Gemma-4):
  - `layers.0~11` may contain audio / vision / text layers
  - All components are output separately, distinguished by prefix

### Example: Gemma-4-E2B
| Component | Layer Range |
|-----------|-------------|
| audio_tower | 0 ~ 11 |
| language_model | 0 ~ 34 |
| vision_tower | 0 ~ 15 |

### Example: Gemma-4-31B
| Component | Layer Range |
|-----------|-------------|
| language (local) | 0 ~ 59 |
| language (global) | 5, 11, 17 … 59 |
| vision_tower | 0 ~ 26 |
"""
def _fetch_config_params(model_id: str, token: "str | None", log: "list[str]") -> dict:
    """Best-effort read of the model's ``config.json`` from the Hugging Face Hub.

    A short summary of the extracted parameters, or the reason the read
    failed, is appended to *log*.

    Returns:
        The dict produced by ``extract_config_params``, or ``{}`` when the
        file cannot be fetched or parsed. Failure is not fatal: the caller
        still scans the shard headers, just without config hints.
    """
    url = f"https://huggingface.co/{model_id}/resolve/main/config.json"
    headers = {"Authorization": f"Bearer {token}"} if token else {}
    try:
        r = requests.get(url, headers=headers, timeout=15)
        if r.status_code != 200:
            # Surface non-200 responses (gated repo, missing file, ...) instead
            # of silently continuing with an empty config.
            log.append(f"⚠️ Could not read config.json: HTTP {r.status_code}\n")
            return {}
        params = extract_config_params(r.json())
    except Exception as e:  # network error, invalid JSON, unexpected schema
        log.append(f"⚠️ Could not read config.json: {e}\n")
        return {}
    log.append(
        f"📋 Config:\n"
        f" model_type = {params.get('model_type')}\n"
        f" hidden = {params.get('hidden_size')}\n"
        f" n_heads = {params.get('num_attention_heads')}\n"
        f" n_kv = {params.get('num_key_value_heads')}\n"
        f" head_dim = {params.get('head_dim')}\n"
        f"{'─' * 80}\n"
    )
    return params


def _profiles_to_dataframe(profiles: dict) -> pd.DataFrame:
    """Flatten ``{(prefix, layer_idx): profile}`` into one row per profile.

    Rows are ordered by component prefix, then layer index. ``v_shape`` is
    ``"K=V"`` when the profile has no V tensor (presumably a K=V shared layer).
    """
    def _shape(tensor, missing: str = "") -> str:
        # A missing projection is falsy (presumably None) on the profile.
        return str(tensor.shape) if tensor else missing

    rows = [
        {
            "prefix": prefix,
            "layer": layer_idx,
            "d_model": p.d_model,
            "head_dim": p.head_dim,
            "dim_source": p.head_dim_source,
            "n_q": p.n_q_heads,
            "n_kv": p.n_kv_heads,
            "kv_shared": p.kv_shared,
            "complete": p.complete,
            "q_shape": _shape(p.q),
            "k_shape": _shape(p.k),
            "v_shape": _shape(p.v, "K=V"),
        }
        for (prefix, layer_idx), p in sorted(profiles.items(), key=lambda item: item[0])
    ]
    return pd.DataFrame(rows)


def inspect_model(
    model_id: str,
    hf_token: str,
    progress=gr.Progress()
) -> "tuple[str, pd.DataFrame | None]":
    """Inspect a Hub model's attention-layer layout from safetensors headers only.

    Args:
        model_id: Hugging Face repo ID, e.g. ``"google/gemma-4-e2b"``.
            Surrounding whitespace is ignored.
        hf_token: Access token for gated/private repos; empty for public ones.
        progress: Gradio progress tracker (``gr.Progress``).

    Returns:
        ``(log_text, df)``. *log_text* is the human-readable inspection log.
        *df* has one row per (component prefix, layer) profile. It is
        ``None`` when inspection stops early: no model ID, the quantization
        check blocks the model, shard headers fail to load, or no Q/K/V
        layers are found.
    """
    # Strip once and reuse, so stray whitespace never ends up in URLs.
    model_id = (model_id or "").strip()
    if not model_id:
        return "❌ Please enter a model ID.", None
    token = (hf_token or "").strip() or None

    log = [f"🔬 Structure Inspection: {model_id}\n{'═' * 80}\n"]

    # ── Quantization check ────────────────────────────────────────────────────
    progress(0.05, desc="Checking quantization...")
    try:
        blocked, qmsg = check_quantization(model_id, token)
    except requests.exceptions.HTTPError as e:
        return "".join(log) + http_error_msg(e, model_id), None
    log.append(f"[Quantization Check]\n{qmsg}\n{'─' * 80}\n")
    if blocked:
        return "".join(log), None

    # ── config.json (optional hints for head_dim / head counts) ───────────────
    progress(0.10, desc="Reading config...")
    config_params = _fetch_config_params(model_id, token, log)

    # ── Load all shard headers ────────────────────────────────────────────────
    progress(0.20, desc="Loading shard headers...")
    try:
        all_headers = load_all_shard_headers(model_id, token)
    except requests.exceptions.HTTPError as e:
        # Keep the log gathered so far, like the generic branch below.
        return "".join(log) + http_error_msg(e, model_id), None
    except Exception as e:
        return "".join(log) + f"❌ Failed to load headers: {e}\n", None

    # Each value unpacks as (header_dict, <other>); only the header is used here.
    total_keys = sum(len(h) for h, _ in all_headers.values())
    log.append(
        f"📦 Shards: {len(all_headers)} "
        f"Total keys: {total_keys}\n"
        f"{'─' * 80}\n"
    )

    # ── Scan layer structure ──────────────────────────────────────────────────
    progress(0.50, desc="Scanning layer structure...")
    profiles = scan_model_structure(all_headers, config_params)
    if not profiles:
        # Show raw keys from the first shard so the user can see how the
        # tensors are actually named.
        first = next(iter(all_headers.values()), None)
        sample = list(first[0].keys())[:30] if first is not None else []
        return (
            "".join(log)
            + "⚠️ No Q/K/V layers found. First 30 keys:\n"
            + "\n".join(sample),
            None,
        )

    # ── Report + overview table ───────────────────────────────────────────────
    progress(0.80, desc="Generating report...")
    log.append(summarize_structure(profiles))
    df = _profiles_to_dataframe(profiles)
    progress(1.0, desc="Done")
    return "".join(log), df
# ─────────────────────────────────────────────
# Tab1 UI
# ─────────────────────────────────────────────
def build_tab_inspect():
    """Build the "Inspect" tab (Tab 1) and wire its button to ``inspect_model``.

    Expected to be called inside an enclosing ``gr.Blocks`` context, because
    ``gr.Tab`` is a layout block.

    Returns:
        ``(inspect_model_id, inspect_token)``: the model-ID and token
        textboxes. Presumably other tabs (e.g. Analyze) reuse these inputs;
        confirm with the caller.
    """
    with gr.Tab("🔬 Inspect"):
        gr.Markdown("""
        **Step 1: Inspect model structure** — auto-detect components, head_dim, and K=V shared layers.
        Results are used by the **Analyze** tab.
        > No weights are downloaded — structure is inferred from safetensors headers only.
        """)
        with gr.Row():
            # Left: inputs and action button.
            with gr.Column(scale=3):
                inspect_model_id = gr.Textbox(
                    label="HuggingFace Model ID",
                    placeholder="google/gemma-4-e2b",
                    value="google/gemma-4-e2b"
                )
                inspect_token = gr.Textbox(
                    label="HF Access Token (leave empty for public models)",
                    type="password"
                )
                inspect_btn = gr.Button("🔍 Inspect Structure", variant="secondary")
            # Right: static help / recommended models.
            with gr.Column(scale=1):
                gr.Markdown(SIDEBAR_MD)
        inspect_log = gr.Textbox(
            label="Inspection Log",
            lines=30, max_lines=200
        )
        # Column names must match the row keys produced by inspect_model.
        inspect_table = gr.Dataframe(
            label="Layer Structure Overview",
            headers=[
                "prefix", "layer", "d_model", "head_dim", "dim_source",
                "n_q", "n_kv", "kv_shared", "complete",
                "q_shape", "k_shape", "v_shape"
            ]
        )
        inspect_btn.click(
            fn=inspect_model,
            inputs=[inspect_model_id, inspect_token],
            outputs=[inspect_log, inspect_table]
        )
    return inspect_model_id, inspect_token