Alex W. committed on
Commit
0ff8a89
·
1 Parent(s): c8aec40

feat(db+ui): add modality dimension & migrate legacy data

Introduce `modality` (language/vision/audio) as a first-class dimension
alongside the existing `layer_type` (standard/global), fixing misleading
total_layers display for multi-modal models (Gemma-4 series).

---

1. `total_layers` mixed audio+vision+language into one number —
meaningless for multi-modal models like gemma-4-31b-it.
2. Layer-type dropdown showed "standard/global" (structural concept)
where users expected "language/vision/audio" (modality concept).
3. These two concepts (structure vs modality) were conflated in one field.

---

- Add `modality TEXT DEFAULT 'language'` to `layer_head_metrics`
- Add `modality TEXT DEFAULT 'language'` to `components`
- Add `_migrate_add_modality()`: idempotent ALTER TABLE migration
  - Runs on every `init_db()` startup
  - Detects missing column via `PRAGMA table_info()`
  - Backfills legacy rows via keyword matching on `prefix`
    - vision: LIKE '%vision%' OR '%visual%' OR '%image%'
    - audio: LIKE '%audio%' OR '%speech%' OR '%acoustic%'
    - language: DEFAULT (covers pure-text models e.g. LLaMA, Qwen)
- Add indexes: `idx_metrics_modality`, `idx_components_modality`
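
The migration steps listed above can be sketched as follows. This is a minimal, self-contained illustration on an in-memory database, not the code in `db/schema.py`: the helper name `add_modality_column`, the `KEYWORDS` table, and the toy rows (including the layer counts) are assumptions made for the example.

```python
import sqlite3

# Keyword lists copied from the commit message; the dict itself is hypothetical.
KEYWORDS = {
    "vision": ("vision", "visual", "image"),
    "audio": ("audio", "speech", "acoustic"),
}


def add_modality_column(conn: sqlite3.Connection, table: str) -> bool:
    """Add and backfill `modality` on `table`; return True if work was done."""
    cols = [row[1] for row in conn.execute(f"PRAGMA table_info({table})")]
    if "modality" in cols:
        return False  # already migrated, so the call is a no-op
    conn.execute(
        f"ALTER TABLE {table} ADD COLUMN modality TEXT DEFAULT 'language'"
    )
    # Existing rows now read back as 'language'; override by keyword match.
    for modality, words in KEYWORDS.items():
        where = " OR ".join("prefix LIKE ?" for _ in words)
        conn.execute(
            f"UPDATE {table} SET modality = ? WHERE {where}",
            (modality, *[f"%{w}%" for w in words]),
        )
    conn.commit()
    return True


# Toy legacy table without the modality column (made-up rows).
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE components (model_id TEXT, prefix TEXT, n_layers INTEGER)"
)
conn.executemany(
    "INSERT INTO components VALUES (?, ?, ?)",
    [
        ("m", "model.language_model.", 60),
        ("m", "model.vision_tower.", 27),
        ("m", "model.audio_tower.", 12),
    ],
)

first = add_modality_column(conn, "components")   # performs the migration
second = add_modality_column(conn, "components")  # detects the column, skips
result = dict(conn.execute("SELECT prefix, modality FROM components"))
```

Because the column check runs before the `ALTER TABLE`, calling the helper on every startup is safe; the second call returns without touching the data.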

- Add `infer_modality(prefix: str) -> str`
  - Keyword match on lowercased prefix, no model-name hard-coding
  - Default → 'language' (covers "model." prefix of LLaMA/Qwen)
- `write_layer_records()`: fill `modality` column on every insert
- `upsert_component()`: fill `modality` column on every insert
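
A possible shape for `infer_modality`, assuming it uses the same keyword lists as the migration backfill. The precedence when a prefix matches both lists is not stated in the commit message; this sketch lets audio win, mirroring the order of the backfill `UPDATE` statements, and that choice is an assumption.

```python
def infer_modality(prefix: str) -> str:
    """Map a weight-name prefix to 'language', 'vision' or 'audio'."""
    p = prefix.lower()
    # Assumption: audio is checked first so the result matches the SQL
    # backfill, where the audio UPDATE runs after the vision UPDATE.
    if any(k in p for k in ("audio", "speech", "acoustic")):
        return "audio"
    if any(k in p for k in ("vision", "visual", "image")):
        return "vision"
    # Anything unrecognized, including a bare "model." prefix, is language.
    return "language"


examples = {
    prefix: infer_modality(prefix)
    for prefix in (
        "model.",
        "model.language_model.",
        "model.vision_tower.",
        "model.audio_tower.",
    )
}
```

The prefixes in `examples` are illustrative only; real checkpoints may name their towers differently.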

- `get_analyzed_models()`:
  - Remove `total_layers`
  - Add `language_layers`, `vision_layers`, `audio_layers`
    via CASE WHEN aggregation (auto-includes standard+global)
- Add `get_model_components()`:
  - Returns raw components rows for a model (Plan B detail view)
- `get_layer_metrics()`:
  - Add `modality` filter parameter (independent of `layer_type`)
  - Both filters composable: e.g. modality='language' + layer_type='global'
- `get_leaderboard()`:
  - Replace `prefix_filter` text param with `modality` dropdown param
  - Default modality='language' (leaderboard targets text reasoning)
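
To illustrate the CASE WHEN aggregation used by `get_analyzed_models()`, here is a small runnable sketch against an in-memory database. The two-column `models` table and all rows are made-up stand-ins for the real schema.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE models (model_id TEXT PRIMARY KEY);
    CREATE TABLE components (
        model_id TEXT, prefix TEXT, modality TEXT, n_layers INTEGER
    );
    -- Toy data: one multi-modal model and one text-only model.
    INSERT INTO models VALUES ('multi'), ('text-only');
    INSERT INTO components VALUES
        ('multi', 'model.language_model.', 'language', 60),
        ('multi', 'model.vision_tower.',   'vision',   27),
        ('text-only', 'model.',            'language', 32);
    """
)

# One SUM(CASE ...) per modality replaces the single mixed total_layers.
rows = conn.execute(
    """
    SELECT
        m.model_id,
        SUM(CASE WHEN c.modality = 'language'
                 THEN c.n_layers ELSE 0 END) AS language_layers,
        SUM(CASE WHEN c.modality = 'vision'
                 THEN c.n_layers ELSE 0 END) AS vision_layers,
        SUM(CASE WHEN c.modality = 'audio'
                 THEN c.n_layers ELSE 0 END) AS audio_layers
    FROM models m
    LEFT JOIN components c ON m.model_id = c.model_id
    GROUP BY m.model_id
    ORDER BY m.model_id
    """
).fetchall()
```

Each model gets one row with three separate counts, so a text-only model reports 0 for vision and audio instead of folding everything into one number.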

---

- Model list (Plan A): show language/vision/audio layer counts separately
  - vision/audio show "" when 0 (cleaner display)
- Model detail (Plan B): add `components_table` showing raw prefix rows
  - Expandable detail alongside summary stats
- Raw data query: split single dropdown into two independent dropdowns
  - Modality: [all | language | vision | audio]
  - Layer Type: [all | standard | global]
  - info text explains each option in EN + Chinese

- Replace `prefix_filter` Textbox with `modality` Dropdown
  - Choices: [language | vision | audio | all]
  - Default: 'language' (standard use case)
- Add `modality` column to leaderboard display table

---

- `total_layers` removed: language_layers = SUM(n_layers) per modality,
naturally includes all layer_types under same prefix
- `layer_type` (standard/global) retained: orthogonal structural dimension
- Unrecognized prefixes (future component types) → default to 'language' (no schema change needed)
- `model_summary` table unchanged: leaderboard filters via components JOIN
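
Since `model_summary` has no `modality` column, the leaderboard has to pick it up from `components`. The sketch below shows one way that could work; the exact JOIN condition is elided in the diff, so joining on `(model_id, prefix)` is an assumption, and the function name `leaderboard`, the trimmed-down tables, and the scores are illustrative only.

```python
import sqlite3


def leaderboard(conn, modality="language", layer_type="standard", limit=100):
    """Rank summary rows by wang_score, optionally filtered by modality."""
    sql = """
        SELECT s.model_id, s.prefix, s.wang_score, c.modality
        FROM model_summary s
        JOIN components c
          ON c.model_id = s.model_id AND c.prefix = s.prefix  -- assumed join
        WHERE s.layer_type = ?
    """
    params = [layer_type]
    if modality != "all":  # 'all' disables the modality filter
        sql += " AND c.modality = ?"
        params.append(modality)
    sql += " ORDER BY s.wang_score DESC LIMIT ?"
    params.append(limit)
    return conn.execute(sql, params).fetchall()


conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE components (model_id TEXT, prefix TEXT, modality TEXT);
    CREATE TABLE model_summary (
        model_id TEXT, prefix TEXT, layer_type TEXT, wang_score REAL
    );
    -- Toy rows with made-up scores.
    INSERT INTO components VALUES
        ('a', 'model.language_model.', 'language'),
        ('a', 'model.vision_tower.',   'vision'),
        ('b', 'model.',                'language');
    INSERT INTO model_summary VALUES
        ('a', 'model.language_model.', 'standard', 0.9),
        ('a', 'model.vision_tower.',   'standard', 0.95),
        ('b', 'model.',                'standard', 0.8);
    """
)

language_only = leaderboard(conn)
everything = leaderboard(conn, modality="all")
```

With the default `modality='language'`, the vision tower is excluded even though it has the highest score in the toy data, which is the behavior the commit describes for a text-reasoning leaderboard.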

---

- Zero breaking changes to core/ and tab_analyze.py
- Migration is idempotent: safe to deploy on existing DB
- New DB: modality column present from creation, migration is no-op

Files changed (5)
  1. db/reader.py +98 -66
  2. db/schema.py +114 -83
  3. db/writer.py +168 -227
  4. ui/tab_database.py +140 -89
  5. ui/tab_leaderboard.py +56 -78
db/reader.py CHANGED
@@ -2,7 +2,8 @@
2
  """
3
  Database query module
4
  - Leaderboard queries
5
- - Model detail queries
 
6
  - Resume (checkpoint) status queries
7
  """
8
 
@@ -17,13 +18,13 @@ from db.schema import get_connection, init_db
17
 
18
  def get_leaderboard(
19
  conn: sqlite3.Connection,
20
- prefix_filter: str = None, # restrict to one component; None = all
21
  layer_type: str = "standard",
22
- limit: int = 50,
23
  ) -> pd.DataFrame:
24
  """
25
- Leaderboard query
26
- sorted by wang_score, descending
27
  """
28
  sql = """
29
  SELECT
@@ -41,7 +42,7 @@ def get_leaderboard(
41
  s.n_layers,
42
  s.n_records,
43
  s.updated_at,
44
- -- component info
45
  c.head_dim_min,
46
  c.head_dim_max,
47
  c.has_kv_shared,
@@ -54,9 +55,9 @@ def get_leaderboard(
54
  """
55
  params = [layer_type]
56
 
57
- if prefix_filter:
58
- sql += " AND s.prefix LIKE ?"
59
- params.append(f"%{prefix_filter}%")
60
 
61
  sql += " ORDER BY s.wang_score DESC LIMIT ?"
62
  params.append(limit)
@@ -64,18 +65,80 @@ def get_leaderboard(
64
  cur = conn.cursor()
65
  cur.execute(sql, params)
66
  rows = cur.fetchall()
67
-
68
  if not rows:
69
  return pd.DataFrame()
70
 
71
  cols = [d[0] for d in cur.description]
72
  return pd.DataFrame([dict(zip(cols, row)) for row in rows])
73
 
74
 
75
  # ─────────────────────────────────────────────
76
- # Model details
77
  # ─────────────────────────────────────────────
78
 
79
  def get_model_summary(
80
  conn: sqlite3.Connection,
81
  model_id: str,
@@ -83,11 +146,9 @@ def get_model_summary(
83
  """Get summary statistics for all components of a model"""
84
  cur = conn.cursor()
85
  cur.execute(
86
- """
87
- SELECT * FROM model_summary
88
- WHERE model_id = ?
89
- ORDER BY prefix, layer_type
90
- """,
91
  (model_id,)
92
  )
93
  rows = cur.fetchall()
@@ -97,17 +158,22 @@ def get_model_summary(
97
  return pd.DataFrame([dict(zip(cols, row)) for row in rows])
98
 
99
 
 
 
 
 
100
  def get_layer_metrics(
101
- conn: sqlite3.Connection,
102
- model_id: str,
103
- prefix: str = None,
104
- layer_type: str = None,
105
- start_layer:int = None,
106
- end_layer: int = None,
 
107
  ) -> pd.DataFrame:
108
  """
109
- Query raw per-head data
110
- Supports filtering by prefix / layer_type / layer range
111
  """
112
  sql = "SELECT * FROM layer_head_metrics WHERE model_id = ?"
113
  params = [model_id]
@@ -115,6 +181,9 @@ def get_layer_metrics(
115
  if prefix:
116
  sql += " AND prefix = ?"
117
  params.append(prefix)
 
 
 
118
  if layer_type:
119
  sql += " AND layer_type = ?"
120
  params.append(layer_type)
@@ -130,32 +199,6 @@ def get_layer_metrics(
130
  cur = conn.cursor()
131
  cur.execute(sql, params)
132
  rows = cur.fetchall()
133
-
134
- if not rows:
135
- return pd.DataFrame()
136
- cols = [d[0] for d in cur.description]
137
- return pd.DataFrame([dict(zip(cols, row)) for row in rows])
138
-
139
-
140
- def get_analyzed_models(conn: sqlite3.Connection) -> pd.DataFrame:
141
- """Get the list of all analyzed models"""
142
- cur = conn.cursor()
143
- cur.execute(
144
- """
145
- SELECT
146
- m.model_id,
147
- m.model_type,
148
- m.analyzed_at,
149
- m.analyze_sec,
150
- COUNT(DISTINCT c.prefix) as n_components,
151
- SUM(c.n_layers) as total_layers
152
- FROM models m
153
- LEFT JOIN components c ON m.model_id = c.model_id
154
- GROUP BY m.model_id
155
- ORDER BY m.analyzed_at DESC
156
- """
157
- )
158
- rows = cur.fetchall()
159
  if not rows:
160
  return pd.DataFrame()
161
  cols = [d[0] for d in cur.description]
@@ -171,29 +214,18 @@ def get_resume_status(
171
  model_id: str,
172
  prefix: str,
173
  ) -> dict:
174
- """
175
- Query the resume status for a (model_id, prefix) pair
176
- Returns the set of completed layer indices plus statistics
177
- """
178
  cur = conn.cursor()
179
-
180
- # completed layers
181
  cur.execute(
182
- """
183
- SELECT DISTINCT layer, COUNT(*) as n_heads
184
- FROM layer_head_metrics
185
- WHERE model_id = ? AND prefix = ?
186
- GROUP BY layer
187
- ORDER BY layer
188
- """,
189
  (model_id, prefix)
190
  )
191
  rows = cur.fetchall()
192
-
193
  done_layers = {r[0]: r[1] for r in rows}
194
-
195
  return {
196
  "done_layers": set(done_layers.keys()),
197
- "layer_detail": done_layers, # layer → n_heads
198
  "total_done": len(done_layers),
199
  }
 
2
  """
3
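  Database query module
4
  - Leaderboard queries
5
+ - Model detail queries (Plan A: aggregate by modality; Plan B: raw components rows)
6
+ - Raw per-head data queries
7
  - Resume (checkpoint) status queries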
8
  """
9
 
 
18
 
19
  def get_leaderboard(
20
  conn: sqlite3.Connection,
21
+ modality: str = "language", # language/vision/audio/all
22
  layer_type: str = "standard",
23
+ limit: int = 100,
24
  ) -> pd.DataFrame:
25
  """
26
+ Leaderboard query, sorted by wang_score in descending order.
27
+ The modality filter is applied by joining the components table on prefix.
28
  """
29
  sql = """
30
  SELECT
 
42
  s.n_layers,
43
  s.n_records,
44
  s.updated_at,
45
+ c.modality,
46
  c.head_dim_min,
47
  c.head_dim_max,
48
  c.has_kv_shared,
 
55
  """
56
  params = [layer_type]
57
 
58
+ if modality != "all":
59
+ sql += " AND c.modality = ?"
60
+ params.append(modality)
61
 
62
  sql += " ORDER BY s.wang_score DESC LIMIT ?"
63
  params.append(limit)
 
65
  cur = conn.cursor()
66
  cur.execute(sql, params)
67
  rows = cur.fetchall()
 
68
  if not rows:
69
  return pd.DataFrame()
70
+ cols = [d[0] for d in cur.description]
71
+ return pd.DataFrame([dict(zip(cols, row)) for row in rows])
72
+
73
+
74
+ # ─────────────────────────────────────────────
75
+ # Model list (Plan A: aggregate by modality)
76
+ # ─────────────────────────────────────────────
77
 
78
+ def get_analyzed_models(conn: sqlite3.Connection) -> pd.DataFrame:
79
+ """
80
+ Model list with layer counts aggregated by modality.
81
+ language_layers = SUM(n_layers) WHERE modality='language'
82
+ Automatically includes standard + global layers (under the same prefix).
83
+ """
84
+ cur = conn.cursor()
85
+ cur.execute(
86
+ """
87
+ SELECT
88
+ m.model_id,
89
+ m.model_type,
90
+ m.analyzed_at,
91
+ m.analyze_sec,
92
+ COUNT(DISTINCT c.prefix) AS n_components,
93
+ SUM(CASE WHEN c.modality = 'language'
94
+ THEN c.n_layers ELSE 0 END) AS language_layers,
95
+ SUM(CASE WHEN c.modality = 'vision'
96
+ THEN c.n_layers ELSE 0 END) AS vision_layers,
97
+ SUM(CASE WHEN c.modality = 'audio'
98
+ THEN c.n_layers ELSE 0 END) AS audio_layers
99
+ FROM models m
100
+ LEFT JOIN components c ON m.model_id = c.model_id
101
+ GROUP BY m.model_id
102
+ ORDER BY m.analyzed_at DESC
103
+ """
104
+ )
105
+ rows = cur.fetchall()
106
+ if not rows:
107
+ return pd.DataFrame()
108
  cols = [d[0] for d in cur.description]
109
  return pd.DataFrame([dict(zip(cols, row)) for row in rows])
110
 
111
 
112
  # ─────────────────────────────────────────────
113
+ # Model details (Plan B: raw components rows)
114
  # ─────────────────────────────────────────────
115
 
116
+ def get_model_components(
117
+ conn: sqlite3.Connection,
118
+ model_id: str,
119
+ ) -> pd.DataFrame:
120
+ """
121
+ Return the raw components rows for a model (used by the Plan B detail view).
122
+ One row per prefix, with modality/n_layers/head_dim etc.
123
+ """
124
+ cur = conn.cursor()
125
+ cur.execute(
126
+ """SELECT
127
+ prefix, modality, n_layers,
128
+ head_dim_min, head_dim_max,
129
+ has_kv_shared, has_global, d_model
130
+ FROM components
131
+ WHERE model_id = ?
132
+ ORDER BY modality, prefix""",
133
+ (model_id,)
134
+ )
135
+ rows = cur.fetchall()
136
+ if not rows:
137
+ return pd.DataFrame()
138
+ cols = [d[0] for d in cur.description]
139
+ return pd.DataFrame([dict(zip(cols, row)) for row in rows])
140
+
141
+
142
  def get_model_summary(
143
  conn: sqlite3.Connection,
144
  model_id: str,
 
146
  """Get summary statistics for all components of a model"""
147
  cur = conn.cursor()
148
  cur.execute(
149
+ """SELECT * FROM model_summary
150
+ WHERE model_id = ?
151
+ ORDER BY prefix, layer_type""",
 
 
152
  (model_id,)
153
  )
154
  rows = cur.fetchall()
 
158
  return pd.DataFrame([dict(zip(cols, row)) for row in rows])
159
 
160
 
161
+ # ─────────────────────────────────────────────
162
+ # Raw per-head data
163
+ # ─────────────────────────────────────────────
164
+
165
  def get_layer_metrics(
166
+ conn: sqlite3.Connection,
167
+ model_id: str,
168
+ prefix: str = None,
169
+ modality: str = None, # language/vision/audio
170
+ layer_type: str = None, # standard/global
171
+ start_layer: int = None,
172
+ end_layer: int = None,
173
  ) -> pd.DataFrame:
174
  """
175
+ Raw per-head data query.
176
+ modality and layer_type are two independent dimensions and can be combined as filters.
177
  """
178
  sql = "SELECT * FROM layer_head_metrics WHERE model_id = ?"
179
  params = [model_id]
 
181
  if prefix:
182
  sql += " AND prefix = ?"
183
  params.append(prefix)
184
+ if modality:
185
+ sql += " AND modality = ?"
186
+ params.append(modality)
187
  if layer_type:
188
  sql += " AND layer_type = ?"
189
  params.append(layer_type)
 
199
  cur = conn.cursor()
200
  cur.execute(sql, params)
201
  rows = cur.fetchall()
202
  if not rows:
203
  return pd.DataFrame()
204
  cols = [d[0] for d in cur.description]
 
214
  model_id: str,
215
  prefix: str,
216
  ) -> dict:
 
 
 
 
217
  cur = conn.cursor()
 
 
218
  cur.execute(
219
+ """SELECT DISTINCT layer, COUNT(*) as n_heads
220
+ FROM layer_head_metrics
221
+ WHERE model_id = ? AND prefix = ?
222
+ GROUP BY layer ORDER BY layer""",
 
 
 
223
  (model_id, prefix)
224
  )
225
  rows = cur.fetchall()
 
226
  done_layers = {r[0]: r[1] for r in rows}
 
227
  return {
228
  "done_layers": set(done_layers.keys()),
229
+ "layer_detail": done_layers,
230
  "total_done": len(done_layers),
231
  }
db/schema.py CHANGED
@@ -8,11 +8,6 @@ import sqlite3
8
  import os
9
  from datetime import datetime
10
 
11
- # ─────────────────────────────────────────────
12
- # Database path
13
- # /data is the HF Space bucket mount point; data persists across restarts
14
- # Falls back to the current directory for local development
15
- # ─────────────────────────────────────────────
16
 
17
  def get_db_path() -> str:
18
  if os.path.exists("/data"):
@@ -21,9 +16,8 @@ def get_db_path() -> str:
21
 
22
 
23
  def get_connection() -> sqlite3.Connection:
24
- """Get a database connection with WAL mode enabled for better concurrency"""
25
  conn = sqlite3.connect(get_db_path(), check_same_thread=False)
26
- conn.row_factory = sqlite3.Row # allow access by column name
27
  conn.execute("PRAGMA journal_mode=WAL")
28
  conn.execute("PRAGMA foreign_keys=ON")
29
  return conn
@@ -36,10 +30,10 @@ def get_connection() -> sqlite3.Connection:
36
  SQL_CREATE_MODELS = """
37
  CREATE TABLE IF NOT EXISTS models (
38
  model_id TEXT PRIMARY KEY,
39
- model_type TEXT, -- gemma4 / llama / qwen2, etc.
40
  analyzed_at TIMESTAMP,
41
- analyze_sec REAL, -- analysis time (seconds)
42
- notes TEXT -- notes
43
  );
44
  """
45
 
@@ -47,13 +41,14 @@ SQL_CREATE_COMPONENTS = """
47
  CREATE TABLE IF NOT EXISTS components (
48
  id INTEGER PRIMARY KEY AUTOINCREMENT,
49
  model_id TEXT NOT NULL,
50
- prefix TEXT NOT NULL, -- e.g. model.language_model.
51
- n_layers INTEGER, -- total number of layers in this component
52
- head_dim_min INTEGER, -- minimum head_dim (for heterogeneous layers)
53
- head_dim_max INTEGER, -- maximum head_dim
54
- has_kv_shared INTEGER DEFAULT 0, -- whether there are K=V shared layers (global layers)
55
- has_global INTEGER DEFAULT 0, -- whether there are global layers
56
- d_model INTEGER, -- input dimension
 
57
  UNIQUE(model_id, prefix),
58
  FOREIGN KEY(model_id) REFERENCES models(model_id)
59
  );
@@ -65,48 +60,32 @@ CREATE TABLE IF NOT EXISTS layer_head_metrics (
65
  model_id TEXT NOT NULL,
66
  prefix TEXT NOT NULL,
67
  layer INTEGER NOT NULL,
68
- layer_type TEXT DEFAULT 'standard', -- standard / global
 
69
  kv_head INTEGER NOT NULL,
70
  q_head INTEGER NOT NULL,
71
- kv_shared INTEGER DEFAULT 0, -- 1 = K=V shared (theoretical), 0 = independent V
72
  head_dim INTEGER,
73
  d_model INTEGER,
74
  n_q_heads INTEGER,
75
  n_kv_heads INTEGER,
76
- -- Law 1: spectral linear alignment
77
- pearson_QK REAL,
78
- spearman_QK REAL,
79
- pearson_QV REAL,
80
- pearson_KV REAL,
81
- -- Law 2: spectral shape residual
82
- ssr_QK REAL,
83
- ssr_QV REAL,
84
- ssr_KV REAL,
85
- -- Law 3: condition number
86
- sigma_max_Q REAL,
87
- sigma_min_Q REAL,
88
- cond_Q REAL,
89
- sigma_max_K REAL,
90
- sigma_min_K REAL,
91
- cond_K REAL,
92
- sigma_max_V REAL,
93
- sigma_min_V REAL,
94
- cond_V REAL,
95
- -- Law 4: left singular vector alignment (output subspace)
96
- cosU_QK REAL,
97
- cosU_QV REAL,
98
- cosU_KV REAL,
99
- -- Law 5: right singular vector alignment (input subspace)
100
- cosV_QK REAL,
101
- cosV_QV REAL,
102
- cosV_KV REAL,
103
- -- scale factor + least-squares residual
104
- alpha_QK REAL,
105
- alpha_res_QK REAL,
106
- alpha_QV REAL,
107
- alpha_res_QV REAL,
108
- alpha_KV REAL,
109
- alpha_res_KV REAL,
110
 
111
  UNIQUE(model_id, prefix, layer, kv_head, q_head),
112
  FOREIGN KEY(model_id) REFERENCES models(model_id)
@@ -117,29 +96,23 @@ SQL_CREATE_MODEL_SUMMARY = """
117
  CREATE TABLE IF NOT EXISTS model_summary (
118
  model_id TEXT NOT NULL,
119
  prefix TEXT NOT NULL,
120
- layer_type TEXT NOT NULL DEFAULT 'all', -- all / standard / global
121
  -- Law 1
122
- median_pearson_QK REAL,
123
- mean_pearson_QK REAL,
124
- -- Law 2 (core of the Wang score)
125
- median_ssr_QK REAL,
126
- mean_ssr_QK REAL,
127
- median_ssr_QV REAL,
128
- mean_ssr_QV REAL,
129
  -- Law 3
130
- median_cond_Q REAL,
131
- mean_cond_Q REAL,
132
  -- Law 4
133
- median_cosU_QK REAL,
134
- median_cosU_QV REAL,
135
  -- Law 5
136
- median_cosV_QK REAL,
137
- median_cosV_QV REAL,
138
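- -- Wang score (for now = 1 - median_ssr_QK, based on standard layers)
139
  wang_score REAL,
140
  -- statistics scope
141
- n_layers INTEGER, -- number of layers included in the statistics
142
- n_records INTEGER, -- number of records included in the statistics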
143
  updated_at TIMESTAMP,
144
 
145
  PRIMARY KEY(model_id, prefix, layer_type),
@@ -147,32 +120,91 @@ CREATE TABLE IF NOT EXISTS model_summary (
147
  );
148
  """
149
 
150
- # Indexes: speed up common queries
151
  SQL_CREATE_INDEXES = [
152
- # query layer data by model + component
153
  """CREATE INDEX IF NOT EXISTS idx_metrics_model_prefix
154
  ON layer_head_metrics(model_id, prefix)""",
155
- # query by layer range
156
  """CREATE INDEX IF NOT EXISTS idx_metrics_layer
157
  ON layer_head_metrics(model_id, prefix, layer)""",
158
- # leaderboard query
 
159
  """CREATE INDEX IF NOT EXISTS idx_summary_wang_score
160
  ON model_summary(wang_score DESC)""",
161
- # resume support: quickly check whether a layer has been analyzed
162
  """CREATE INDEX IF NOT EXISTS idx_metrics_resume
163
  ON layer_head_metrics(model_id, prefix, layer, kv_head, q_head)""",
 
 
164
  ]
165
 
166
 
167
  # ─────────────────────────────────────────────
168
- # Initialization functions
169
  # ─────────────────────────────────────────────
170
 
171
  def init_db() -> sqlite3.Connection:
172
  """
173
- Initialize the database: create tables + create indexes
174
  Idempotent; safe to call repeatedly
175
- Returns the database connection
176
  """
177
  conn = get_connection()
178
  cur = conn.cursor()
@@ -186,23 +218,22 @@ def init_db() -> sqlite3.Connection:
186
  cur.execute(sql)
187
 
188
  conn.commit()
 
 
 
 
189
  return conn
190
 
191
 
192
  def get_db_stats(conn: sqlite3.Connection) -> dict:
193
- """Get database statistics"""
194
  cur = conn.cursor()
195
  stats = {}
196
-
197
  for table in ["models", "components", "layer_head_metrics", "model_summary"]:
198
  cur.execute(f"SELECT COUNT(*) FROM {table}")
199
  stats[table] = cur.fetchone()[0]
200
-
201
- # database file size
202
  db_path = get_db_path()
203
  if os.path.exists(db_path):
204
  stats["db_size_mb"] = round(os.path.getsize(db_path) / 1024 / 1024, 2)
205
  else:
206
  stats["db_size_mb"] = 0
207
-
208
  return stats
 
8
  import os
9
  from datetime import datetime
10
 
 
 
 
 
 
11
 
12
  def get_db_path() -> str:
13
  if os.path.exists("/data"):
 
16
 
17
 
18
  def get_connection() -> sqlite3.Connection:
 
19
  conn = sqlite3.connect(get_db_path(), check_same_thread=False)
20
+ conn.row_factory = sqlite3.Row
21
  conn.execute("PRAGMA journal_mode=WAL")
22
  conn.execute("PRAGMA foreign_keys=ON")
23
  return conn
 
30
  SQL_CREATE_MODELS = """
31
  CREATE TABLE IF NOT EXISTS models (
32
  model_id TEXT PRIMARY KEY,
33
+ model_type TEXT,
34
  analyzed_at TIMESTAMP,
35
+ analyze_sec REAL,
36
+ notes TEXT
37
  );
38
  """
39
 
 
41
  CREATE TABLE IF NOT EXISTS components (
42
  id INTEGER PRIMARY KEY AUTOINCREMENT,
43
  model_id TEXT NOT NULL,
44
+ prefix TEXT NOT NULL,
45
+ modality TEXT DEFAULT 'language', -- language/vision/audio
46
+ n_layers INTEGER,
47
+ head_dim_min INTEGER,
48
+ head_dim_max INTEGER,
49
+ has_kv_shared INTEGER DEFAULT 0,
50
+ has_global INTEGER DEFAULT 0,
51
+ d_model INTEGER,
52
  UNIQUE(model_id, prefix),
53
  FOREIGN KEY(model_id) REFERENCES models(model_id)
54
  );
 
60
  model_id TEXT NOT NULL,
61
  prefix TEXT NOT NULL,
62
  layer INTEGER NOT NULL,
63
+ layer_type TEXT DEFAULT 'standard', -- standard/global
64
+ modality TEXT DEFAULT 'language', -- language/vision/audio
65
  kv_head INTEGER NOT NULL,
66
  q_head INTEGER NOT NULL,
67
+ kv_shared INTEGER DEFAULT 0,
68
  head_dim INTEGER,
69
  d_model INTEGER,
70
  n_q_heads INTEGER,
71
  n_kv_heads INTEGER,
72
+ -- Law 1
73
+ pearson_QK REAL, spearman_QK REAL,
74
+ pearson_QV REAL, pearson_KV REAL,
75
+ -- Law 2
76
+ ssr_QK REAL, ssr_QV REAL, ssr_KV REAL,
77
+ -- Law 3
78
+ sigma_max_Q REAL, sigma_min_Q REAL, cond_Q REAL,
79
+ sigma_max_K REAL, sigma_min_K REAL, cond_K REAL,
80
+ sigma_max_V REAL, sigma_min_V REAL, cond_V REAL,
81
+ -- Law 4
82
+ cosU_QK REAL, cosU_QV REAL, cosU_KV REAL,
83
+ -- Law 5
84
+ cosV_QK REAL, cosV_QV REAL, cosV_KV REAL,
85
+ -- scale factor
86
+ alpha_QK REAL, alpha_res_QK REAL,
87
+ alpha_QV REAL, alpha_res_QV REAL,
88
+ alpha_KV REAL, alpha_res_KV REAL,
89
 
90
  UNIQUE(model_id, prefix, layer, kv_head, q_head),
91
  FOREIGN KEY(model_id) REFERENCES models(model_id)
 
96
  CREATE TABLE IF NOT EXISTS model_summary (
97
  model_id TEXT NOT NULL,
98
  prefix TEXT NOT NULL,
99
+ layer_type TEXT NOT NULL DEFAULT 'all',
100
  -- Law 1
101
+ median_pearson_QK REAL, mean_pearson_QK REAL,
102
+ -- Law 2
103
+ median_ssr_QK REAL, mean_ssr_QK REAL,
104
+ median_ssr_QV REAL, mean_ssr_QV REAL,
 
 
 
105
  -- Law 3
106
+ median_cond_Q REAL, mean_cond_Q REAL,
 
107
  -- Law 4
108
+ median_cosU_QK REAL, median_cosU_QV REAL,
 
109
  -- Law 5
110
+ median_cosV_QK REAL, median_cosV_QV REAL,
111
+ -- Wang score
 
112
  wang_score REAL,
113
  -- statistics scope
114
+ n_layers INTEGER,
115
+ n_records INTEGER,
116
  updated_at TIMESTAMP,
117
 
118
  PRIMARY KEY(model_id, prefix, layer_type),
 
120
  );
121
  """
122
 
 
123
  SQL_CREATE_INDEXES = [
 
124
  """CREATE INDEX IF NOT EXISTS idx_metrics_model_prefix
125
  ON layer_head_metrics(model_id, prefix)""",
 
126
  """CREATE INDEX IF NOT EXISTS idx_metrics_layer
127
  ON layer_head_metrics(model_id, prefix, layer)""",
128
+ """CREATE INDEX IF NOT EXISTS idx_metrics_modality
129
+ ON layer_head_metrics(model_id, modality)""",
130
  """CREATE INDEX IF NOT EXISTS idx_summary_wang_score
131
  ON model_summary(wang_score DESC)""",
 
132
  """CREATE INDEX IF NOT EXISTS idx_metrics_resume
133
  ON layer_head_metrics(model_id, prefix, layer, kv_head, q_head)""",
134
+ """CREATE INDEX IF NOT EXISTS idx_components_modality
135
+ ON components(model_id, modality)""",
136
  ]
137
 
138
 
139
  # ─────────────────────────────────────────────
140
+ # Migration: add the modality column to legacy databases
141
+ # ─────────────────────────────────────────────
142
+
143
+ def _migrate_add_modality(conn: sqlite3.Connection):
144
+ """
145
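+ Idempotent migration: add the modality column to legacy tables and backfill it.
146
+ On a freshly created database the column is already in the CREATE TABLE SQL; the PRAGMA check detects it and skips the migration.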
147
+ """
148
+ cur = conn.cursor()
149
+
150
+ # ── layer_head_metrics ────────────────────
151
+ cur.execute("PRAGMA table_info(layer_head_metrics)")
152
+ lhm_cols = [row[1] for row in cur.fetchall()]
153
+
154
+ if "modality" not in lhm_cols:
155
+ cur.execute(
156
+ "ALTER TABLE layer_head_metrics "
157
+ "ADD COLUMN modality TEXT DEFAULT 'language'"
158
+ )
159
+ # backfill vision
160
+ cur.execute(
161
+ """UPDATE layer_head_metrics SET modality = 'vision'
162
+ WHERE prefix LIKE '%vision%'
163
+ OR prefix LIKE '%visual%'
164
+ OR prefix LIKE '%image%'"""
165
+ )
166
+ # backfill audio
167
+ cur.execute(
168
+ """UPDATE layer_head_metrics SET modality = 'audio'
169
+ WHERE prefix LIKE '%audio%'
170
+ OR prefix LIKE '%speech%'
171
+ OR prefix LIKE '%acoustic%'"""
172
+ )
173
+ # language is already covered by DEFAULT 'language'; no extra update needed
174
+
175
+ # ── components ────────────────────────────
176
+ cur.execute("PRAGMA table_info(components)")
177
+ comp_cols = [row[1] for row in cur.fetchall()]
178
+
179
+ if "modality" not in comp_cols:
180
+ cur.execute(
181
+ "ALTER TABLE components "
182
+ "ADD COLUMN modality TEXT DEFAULT 'language'"
183
+ )
184
+ cur.execute(
185
+ """UPDATE components SET modality = 'vision'
186
+ WHERE prefix LIKE '%vision%'
187
+ OR prefix LIKE '%visual%'
188
+ OR prefix LIKE '%image%'"""
189
+ )
190
+ cur.execute(
191
+ """UPDATE components SET modality = 'audio'
192
+ WHERE prefix LIKE '%audio%'
193
+ OR prefix LIKE '%speech%'
194
+ OR prefix LIKE '%acoustic%'"""
195
+ )
196
+
197
+ conn.commit()
198
+
199
+
200
+ # ─────────────────────────────────────────────
201
+ # Initialization
202
  # ─────────────────────────────────────────────
203
 
204
  def init_db() -> sqlite3.Connection:
205
  """
206
+ Initialize the database: create tables + create indexes + migrate legacy data
207
  Idempotent; safe to call repeatedly
 
208
  """
209
  conn = get_connection()
210
  cur = conn.cursor()
 
218
  cur.execute(sql)
219
 
220
  conn.commit()
221
+
222
+ # Migrate legacy databases (a no-op on a new database)
223
+ _migrate_add_modality(conn)
224
+
225
  return conn
226
 
227
 
228
  def get_db_stats(conn: sqlite3.Connection) -> dict:
 
229
  cur = conn.cursor()
230
  stats = {}
 
231
  for table in ["models", "components", "layer_head_metrics", "model_summary"]:
232
  cur.execute(f"SELECT COUNT(*) FROM {table}")
233
  stats[table] = cur.fetchone()[0]
 
 
234
  db_path = get_db_path()
235
  if os.path.exists(db_path):
236
  stats["db_size_mb"] = round(os.path.getsize(db_path) / 1024 / 1024, 2)
237
  else:
238
  stats["db_size_mb"] = 0
 
239
  return stats
db/writer.py CHANGED
@@ -4,77 +4,98 @@
4
  - Write analysis results to layer_head_metrics
5
  - Compute and write model_summary
6
  - Support resumable runs (at prefix+layer granularity)
 
7
  """
8
 
 
9
  import sqlite3
10
  import numpy as np
11
  from datetime import datetime
12
  from db.schema import get_connection, init_db
13
- import os
14
 
15
 
16
  # ─────────────────────────────────────────────
17
- # layer_type inference
18
  # ─────────────────────────────────────────────
19
 
20
  def infer_layer_type(kv_shared: bool) -> str:
21
  """
22
- Infer the layer type from kv_shared
23
- kv_shared=True → 'global' (K=V shared, e.g. Gemma-4-31B global layers)
24
  kv_shared=False → 'standard'
25
- No hard-coding; inferred purely from structural features
26
  """
27
  return "global" if kv_shared else "standard"
28
 
29
 
30
  # ─────────────────────────────────────────────
31
- # Resume support: check completed layers
32
  # ─────────────────────────────────────────────
33
 
34
  def get_analyzed_layers(
35
  conn: sqlite3.Connection,
36
  model_id: str,
37
  prefix: str,
38
- ) -> set[int]:
39
- """
40
- Return the set of layer indices that have already been analyzed
41
- Used for resuming: skip layers that already have data
42
- Granularity: (model_id, prefix, layer)
43
- """
44
  cur = conn.cursor()
45
  cur.execute(
46
- """
47
- SELECT DISTINCT layer
48
- FROM layer_head_metrics
49
- WHERE model_id = ? AND prefix = ?
50
- """,
51
  (model_id, prefix)
52
  )
53
  return {row[0] for row in cur.fetchall()}
54
 
55
 
56
  def is_layer_complete(
57
- conn: sqlite3.Connection,
58
- model_id: str,
59
- prefix: str,
60
- layer: int,
61
  expected_records: int,
62
  ) -> bool:
63
- """
64
- Check whether a layer has been fully written
65
- expected_records = n_q_heads (the number of records the layer should have)
66
- """
67
  cur = conn.cursor()
68
  cur.execute(
69
- """
70
- SELECT COUNT(*)
71
- FROM layer_head_metrics
72
- WHERE model_id = ? AND prefix = ? AND layer = ?
73
- """,
74
  (model_id, prefix, layer)
75
  )
76
- actual = cur.fetchone()[0]
77
- return actual >= expected_records
78
 
79
 
80
  # ─────────────────────────────────────────────
@@ -87,51 +108,47 @@ def upsert_model(
87
  model_type: str = None,
88
  notes: str = None,
89
  ):
90
- """Insert or update basic model information"""
91
  conn.execute(
92
- """
93
- INSERT INTO models(model_id, model_type, analyzed_at, notes)
94
- VALUES(?, ?, ?, ?)
95
- ON CONFLICT(model_id) DO UPDATE SET
96
- model_type = excluded.model_type,
97
- analyzed_at = excluded.analyzed_at,
98
- notes = excluded.notes
99
- """,
100
  (model_id, model_type, datetime.utcnow().isoformat(), notes)
101
  )
102
  conn.commit()
103
 
104
 
105
  def upsert_component(
106
- conn: sqlite3.Connection,
107
- model_id: str,
108
- prefix: str,
109
- n_layers: int,
110
- head_dim_min: int,
111
- head_dim_max: int,
112
- has_kv_shared:bool,
113
- has_global: bool,
114
- d_model: int,
115
  ):
116
- """Insert or update component information"""
117
  conn.execute(
118
- """
119
- INSERT INTO components(
120
- model_id, prefix, n_layers,
121
- head_dim_min, head_dim_max,
122
- has_kv_shared, has_global, d_model
123
- )
124
- VALUES(?, ?, ?, ?, ?, ?, ?, ?)
125
- ON CONFLICT(model_id, prefix) DO UPDATE SET
126
- n_layers = excluded.n_layers,
127
- head_dim_min = excluded.head_dim_min,
128
- head_dim_max = excluded.head_dim_max,
129
- has_kv_shared = excluded.has_kv_shared,
130
- has_global = excluded.has_global,
131
- d_model = excluded.d_model
132
- """,
133
  (
134
- model_id, prefix, n_layers,
135
  head_dim_min, head_dim_max,
136
  int(has_kv_shared), int(has_global), d_model
137
  )
@@ -148,21 +165,20 @@ def write_layer_records(
148
  model_id: str,
149
  records: list[dict],
150
  ):
151
- """
152
- Bulk-write the per-head metrics for one layer
153
- Uses INSERT OR REPLACE for idempotent writes
154
- """
155
  if not records:
156
  return
157
 
158
  rows = []
159
  for r in records:
160
  layer_type = infer_layer_type(bool(r.get("kv_shared", False)))
 
161
  rows.append((
162
  model_id,
163
  r["prefix"],
164
  r["layer"],
165
  layer_type,
 
166
  r["kv_head"],
167
  r["q_head"],
168
  int(r.get("kv_shared", False)),
@@ -170,66 +186,41 @@ def write_layer_records(
170
  r.get("d_model"),
171
  r.get("n_q_heads"),
172
  r.get("n_kv_heads"),
173
- # Law 1
174
- r.get("pearson_QK"),
175
- r.get("spearman_QK"),
176
- r.get("pearson_QV"),
177
- r.get("pearson_KV"),
178
- # Law 2
179
- r.get("ssr_QK"),
180
- r.get("ssr_QV"),
181
- r.get("ssr_KV"),
182
- # Law 3
183
- r.get("sigma_max_Q"),
184
- r.get("sigma_min_Q"),
185
- r.get("cond_Q"),
186
- r.get("sigma_max_K"),
187
- r.get("sigma_min_K"),
188
- r.get("cond_K"),
189
- r.get("sigma_max_V"),
190
- r.get("sigma_min_V"),
191
- r.get("cond_V"),
192
- # Law 4
193
- r.get("cosU_QK"),
194
- r.get("cosU_QV"),
195
- r.get("cosU_KV"),
196
- # Law 5
197
- r.get("cosV_QK"),
198
- r.get("cosV_QV"),
199
- r.get("cosV_KV"),
200
- # scale factor
201
- r.get("alpha_QK"),
202
- r.get("alpha_res_QK"),
203
- r.get("alpha_QV"),
204
- r.get("alpha_res_QV"),
205
- r.get("alpha_KV"),
206
- r.get("alpha_res_KV"),
207
  ))
208
 
209
  conn.executemany(
210
- """
211
- INSERT OR REPLACE INTO layer_head_metrics(
212
- model_id, prefix, layer, layer_type,
213
- kv_head, q_head, kv_shared,
214
- head_dim, d_model, n_q_heads, n_kv_heads,
215
- pearson_QK, spearman_QK, pearson_QV, pearson_KV,
216
- ssr_QK, ssr_QV, ssr_KV,
217
- sigma_max_Q, sigma_min_Q, cond_Q,
218
- sigma_max_K, sigma_min_K, cond_K,
219
- sigma_max_V, sigma_min_V, cond_V,
220
- cosU_QK, cosU_QV, cosU_KV,
221
- cosV_QK, cosV_QV, cosV_KV,
222
- alpha_QK, alpha_res_QK,
223
- alpha_QV, alpha_res_QV,
224
- alpha_KV, alpha_res_KV
225
- ) VALUES (
226
- ?,?,?,?,?,?,?,?,?,?,?,
227
- ?,?,?,?,?,?,?,
228
- ?,?,?,?,?,?,?,?,?,
229
- ?,?,?,?,?,?,
230
- ?,?,?,?,?,?
231
- )
232
- """,
233
  rows
234
  )
235
  conn.commit()
@@ -240,34 +231,25 @@ def write_layer_records(
240
  # ─────────────────────────────────────────────
241
 
242
  def _calc_summary_row(
243
- rows: list[sqlite3.Row],
244
- model_id: str,
245
- prefix: str,
246
  layer_type: str,
247
  ) -> dict | None:
248
- """
249
- Compute summary statistics from a set of layer_head_metrics rows
250
- Returns one row of model_summary
251
- """
252
  if not rows:
253
  return None
254
 
255
  def col(name):
256
  vals = [r[name] for r in rows if r[name] is not None]
257
- return np.array(vals) if vals else np.array([])
258
-
259
- def med(arr):
260
- return float(np.median(arr)) if len(arr) > 0 else None
261
 
262
- def avg(arr):
263
- return float(np.mean(arr)) if len(arr) > 0 else None
264
 
265
- ssr_qk = col("ssr_QK")
266
  wang_score = float(1 - np.median(ssr_qk)) if len(ssr_qk) > 0 else None
267
-
268
- # count layers (deduplicated)
269
- n_layers = len(set(r["layer"] for r in rows))
270
- n_records = len(rows)
271
 
272
  return {
273
  "model_id": model_id,
@@ -298,106 +280,65 @@ def update_model_summary(
298
  prefix: str,
299
  ):
300
  """
301
- Recompute and write model_summary
302
- For each (model_id, prefix), generate three rows:
303
- - layer_type='all'
304
- - layer_type='standard'
305
- - layer_type='global'
306
- The Wang score is always computed from the standard layers
307
  """
308
  cur = conn.cursor()
309
 
310
  for layer_type in ["all", "standard", "global"]:
311
- # query the matching rows
312
  if layer_type == "all":
313
  cur.execute(
314
- """
315
- SELECT * FROM layer_head_metrics
316
- WHERE model_id = ? AND prefix = ?
317
- """,
318
  (model_id, prefix)
319
  )
320
  else:
321
  cur.execute(
322
- """
323
- SELECT * FROM layer_head_metrics
324
- WHERE model_id = ? AND prefix = ? AND layer_type = ?
325
- """,
326
  (model_id, prefix, layer_type)
327
  )
328
 
329
- rows = cur.fetchall()
330
  summary = _calc_summary_row(rows, model_id, prefix, layer_type)
331
-
332
  if summary is None:
333
  continue
334
 
335
- # The Wang score always uses the standard layers (for all/global, re-fetch the standard-layer ssr)
336
- if layer_type != "standard":
337
- cur.execute(
338
- """
339
- SELECT ssr_QK FROM layer_head_metrics
340
- WHERE model_id = ? AND prefix = ? AND layer_type = 'standard'
341
- """,
342
- (model_id, prefix)
343
- )
344
- std_rows = cur.fetchall()
345
- if std_rows:
346
- std_ssr = np.array([r[0] for r in std_rows if r[0] is not None])
347
- summary["wang_score"] = float(1 - np.median(std_ssr)) if len(std_ssr) > 0 else None
348
 
349
  conn.execute(
350
- """
351
- INSERT OR REPLACE INTO model_summary(
352
- model_id, prefix, layer_type,
353
- median_pearson_QK, mean_pearson_QK,
354
- median_ssr_QK, mean_ssr_QK,
355
- median_ssr_QV, mean_ssr_QV,
356
- median_cond_Q, mean_cond_Q,
357
- median_cosU_QK, median_cosU_QV,
358
- median_cosV_QK, median_cosV_QV,
359
- wang_score,
360
- n_layers, n_records, updated_at
361
- ) VALUES (
362
- :model_id, :prefix, :layer_type,
363
- :median_pearson_QK, :mean_pearson_QK,
364
- :median_ssr_QK, :mean_ssr_QK,
365
- :median_ssr_QV, :mean_ssr_QV,
366
- :median_cond_Q, :mean_cond_Q,
367
- :median_cosU_QK, :median_cosU_QV,
368
- :median_cosV_QK, :median_cosV_QV,
369
- :wang_score,
370
- :n_layers, :n_records, :updated_at
371
- )
372
- """,
373
  summary
374
  )
375
 
376
- conn.commit()
377
-
378
- # Appended at the end of db/writer.py
-
-
-
- # ─────────────────────────────────────────────
- # Write-permission check
- # ─────────────────────────────────────────────
-
- def check_write_permission(admin_token: str) -> bool:
-     """
-     Verify the administrator's write permission.
-
-     How it works:
-     - WRITE_TOKEN is stored in HF Space Secrets (encrypted, never committed to the git repo)
-     - HF injects it as an environment variable at runtime
-     - It is compared only on the server and never returned to the frontend
-
-     Returns:
-     - True = write access granted
-     - False = read-only mode (analysis can run, but results are not written to the DB)
-     """
-     server_token = os.environ.get("WRITE_TOKEN", "")
-     if not server_token:
-         # WRITE_TOKEN is not configured on the server → reject all writes
-         return False
-     return admin_token.strip() == server_token
 
  - Write analysis results to layer_head_metrics
  - Compute and write model_summary
  - Support resumable runs (at prefix+layer granularity)
+ - Verify write permission
  """

+ import os
  import sqlite3
  import numpy as np
  from datetime import datetime
  from db.schema import get_connection, init_db
 
15
 
16
 
17
  # ─────────────────────────────────────────────
+ # Inference helpers: layer_type and modality
  # ─────────────────────────────────────────────

  def infer_layer_type(kv_shared: bool) -> str:
      """
+     Infer the layer type from a structural feature.
+     kv_shared=True  → 'global' (K and V are shared, e.g. Gemma global layers)
      kv_shared=False → 'standard'
      """
      return "global" if kv_shared else "standard"


+ def infer_modality(prefix: str) -> str:
+     """
+     Infer the modality from a component prefix.
+     Pure keyword matching; no model names are hard-coded.
+     No keyword matched → default 'language'
+     (covers pure language models, e.g. LLaMA/Qwen with the "model." prefix).
+     """
+     p = prefix.lower()
+     if "vision" in p or "visual" in p or "image" in p:
+         return "vision"
+     if "audio" in p or "speech" in p or "acoustic" in p:
+         return "audio"
+     return "language"
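To make the behaviour of the keyword matching in `infer_modality` concrete, here is a minimal, self-contained sketch. The prefixes in `examples` are hypothetical and chosen only to exercise each branch; they are not taken from the commit.

```python
def infer_modality(prefix: str) -> str:
    """Map a component prefix to 'vision', 'audio', or 'language' by keyword."""
    p = prefix.lower()
    if any(k in p for k in ("vision", "visual", "image")):
        return "vision"
    if any(k in p for k in ("audio", "speech", "acoustic")):
        return "audio"
    # No keyword matched: treat the component as a language model.
    return "language"


# Hypothetical prefixes, used only for illustration.
examples = [
    "model.",
    "model.language_model.",
    "model.vision_tower.",
    "model.audio_tower.",
]
results = {p: infer_modality(p) for p in examples}

for prefix, modality in results.items():
    print(f"{prefix!r:28} -> {modality}")
```

Because the vision keywords are tested first, a prefix that contains both a vision and an audio keyword resolves to `vision`. Whether that ordering is the desired tie-break is a design decision the commit does not discuss.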
43
+
44
+
45
  # ─────────────────────────────────────────────
+ # Write-permission check
+ # ─────────────────────────────────────────────
+
+ def check_write_permission(admin_token: str) -> bool:
+     """
+     Verify the administrator's write permission.
+     WRITE_TOKEN is stored in HF Space Secrets (encrypted, never committed to the git repo).
+     HF injects it as an environment variable at runtime; it is compared only on the
+     server and never returned to the frontend.
+
+     Returns:
+         True  = write access granted
+         False = read-only mode (analysis can run, but results are not written to the DB)
+     """
+     server_token = os.environ.get("WRITE_TOKEN", "")
+     if not server_token:
+         return False
+     return admin_token.strip() == server_token
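The token check above compares strings with `==` and assumes `admin_token` is always a string. One possible hardening, shown here only as a sketch and not as part of the commit, is to reject non-string input and compare in constant time with the standard library's `hmac.compare_digest`. The function name `check_write_permission_safe` is hypothetical.

```python
import hmac
import os


def check_write_permission_safe(admin_token) -> bool:
    """Hypothetical variant: constant-time comparison, tolerant of None input."""
    server_token = os.environ.get("WRITE_TOKEN", "")
    # No token configured on the server, or no usable token supplied: read-only.
    if not server_token or not isinstance(admin_token, str):
        return False
    # Compare as bytes so non-ASCII tokens are handled as well.
    return hmac.compare_digest(
        admin_token.strip().encode("utf-8"),
        server_token.encode("utf-8"),
    )


# Usage with a throwaway token set only for this demonstration.
os.environ["WRITE_TOKEN"] = "s3cret"
granted = check_write_permission_safe("  s3cret ")
denied = check_write_permission_safe("wrong")
missing = check_write_permission_safe(None)
print(granted, denied, missing)
```

The behaviour for valid string input is the same as in the diff: surrounding whitespace is stripped from the supplied token, and an unset `WRITE_TOKEN` rejects every write.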
63
+
64
+
65
+ # ─────────────────────────────────────────────
+ # Resumable analysis (checkpointing)
  # ─────────────────────────────────────────────

  def get_analyzed_layers(
      conn: sqlite3.Connection,
      model_id: str,
      prefix: str,
+ ) -> set:
+     """Return the set of layer indices that have already been analyzed."""
      cur = conn.cursor()
      cur.execute(
+         """SELECT DISTINCT layer FROM layer_head_metrics
+         WHERE model_id = ? AND prefix = ?""",
          (model_id, prefix)
      )
      return {row[0] for row in cur.fetchall()}


  def is_layer_complete(
+     conn: sqlite3.Connection,
+     model_id: str,
+     prefix: str,
+     layer: int,
      expected_records: int,
  ) -> bool:
+     """Check whether a layer has been fully written."""
      cur = conn.cursor()
      cur.execute(
+         """SELECT COUNT(*) FROM layer_head_metrics
+         WHERE model_id = ? AND prefix = ? AND layer = ?""",
          (model_id, prefix, layer)
      )
+     return cur.fetchone()[0] >= expected_records
 
99
 
100
 
101
  # ─────────────────────────────────────────────
 
108
  model_type: str = None,
109
  notes: str = None,
110
  ):
 
111
  conn.execute(
112
+ """INSERT INTO models(model_id, model_type, analyzed_at, notes)
113
+ VALUES(?, ?, ?, ?)
114
+ ON CONFLICT(model_id) DO UPDATE SET
115
+ model_type = excluded.model_type,
116
+ analyzed_at = excluded.analyzed_at,
117
+ notes = excluded.notes""",
 
 
118
  (model_id, model_type, datetime.utcnow().isoformat(), notes)
119
  )
120
  conn.commit()
121
 
122
 
123
  def upsert_component(
124
+ conn: sqlite3.Connection,
125
+ model_id: str,
126
+ prefix: str,
127
+ n_layers: int,
128
+ head_dim_min: int,
129
+ head_dim_max: int,
130
+ has_kv_shared: bool,
131
+ has_global: bool,
132
+ d_model: int,
133
  ):
134
+ modality = infer_modality(prefix)
135
  conn.execute(
136
+ """INSERT INTO components(
137
+ model_id, prefix, modality, n_layers,
138
+ head_dim_min, head_dim_max,
139
+ has_kv_shared, has_global, d_model
140
+ )
141
+ VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?)
142
+ ON CONFLICT(model_id, prefix) DO UPDATE SET
143
+ modality = excluded.modality,
144
+ n_layers = excluded.n_layers,
145
+ head_dim_min = excluded.head_dim_min,
146
+ head_dim_max = excluded.head_dim_max,
147
+ has_kv_shared = excluded.has_kv_shared,
148
+ has_global = excluded.has_global,
149
+ d_model = excluded.d_model""",
 
150
  (
151
+ model_id, prefix, modality, n_layers,
152
  head_dim_min, head_dim_max,
153
  int(has_kv_shared), int(has_global), d_model
154
  )
 
165
  model_id: str,
166
  records: list[dict],
167
  ):
+ """Write one layer's per-head metrics in a batch; idempotent."""
 
 
 
169
  if not records:
170
  return
171
 
172
  rows = []
173
  for r in records:
174
  layer_type = infer_layer_type(bool(r.get("kv_shared", False)))
175
+ modality = infer_modality(r["prefix"])
176
  rows.append((
177
  model_id,
178
  r["prefix"],
179
  r["layer"],
180
  layer_type,
181
+ modality,
182
  r["kv_head"],
183
  r["q_head"],
184
  int(r.get("kv_shared", False)),
 
186
  r.get("d_model"),
187
  r.get("n_q_heads"),
188
  r.get("n_kv_heads"),
189
+ r.get("pearson_QK"), r.get("spearman_QK"),
190
+ r.get("pearson_QV"), r.get("pearson_KV"),
191
+ r.get("ssr_QK"), r.get("ssr_QV"), r.get("ssr_KV"),
192
+ r.get("sigma_max_Q"), r.get("sigma_min_Q"), r.get("cond_Q"),
193
+ r.get("sigma_max_K"), r.get("sigma_min_K"), r.get("cond_K"),
194
+ r.get("sigma_max_V"), r.get("sigma_min_V"), r.get("cond_V"),
195
+ r.get("cosU_QK"), r.get("cosU_QV"), r.get("cosU_KV"),
196
+ r.get("cosV_QK"), r.get("cosV_QV"), r.get("cosV_KV"),
197
+ r.get("alpha_QK"), r.get("alpha_res_QK"),
198
+ r.get("alpha_QV"), r.get("alpha_res_QV"),
199
+ r.get("alpha_KV"), r.get("alpha_res_KV"),
  ))
201
 
202
  conn.executemany(
203
+ """INSERT OR REPLACE INTO layer_head_metrics(
204
+ model_id, prefix, layer, layer_type, modality,
205
+ kv_head, q_head, kv_shared,
206
+ head_dim, d_model, n_q_heads, n_kv_heads,
207
+ pearson_QK, spearman_QK, pearson_QV, pearson_KV,
208
+ ssr_QK, ssr_QV, ssr_KV,
209
+ sigma_max_Q, sigma_min_Q, cond_Q,
210
+ sigma_max_K, sigma_min_K, cond_K,
211
+ sigma_max_V, sigma_min_V, cond_V,
212
+ cosU_QK, cosU_QV, cosU_KV,
213
+ cosV_QK, cosV_QV, cosV_KV,
214
+ alpha_QK, alpha_res_QK,
215
+ alpha_QV, alpha_res_QV,
216
+ alpha_KV, alpha_res_KV
217
+ ) VALUES (
218
+ ?,?,?,?,?,?,?,?,?,?,?,?,
219
+ ?,?,?,?,?,?,?,
220
+ ?,?,?,?,?,?,?,?,?,
221
+ ?,?,?,?,?,?,
222
+ ?,?,?,?,?,?
223
+ )""",
 
 
224
  rows
225
  )
226
  conn.commit()
 
231
  # ─────────────────────────────────────────────
232
 
233
  def _calc_summary_row(
234
+ rows: list,
235
+ model_id: str,
236
+ prefix: str,
237
  layer_type: str,
238
  ) -> dict | None:
 
 
 
 
239
  if not rows:
240
  return None
241
 
242
  def col(name):
243
  vals = [r[name] for r in rows if r[name] is not None]
244
+ return np.array(vals, dtype=float) if vals else np.array([])
 
 
 
245
 
246
+ def med(arr): return float(np.median(arr)) if len(arr) > 0 else None
247
+ def avg(arr): return float(np.mean(arr)) if len(arr) > 0 else None
248
 
249
+ ssr_qk = col("ssr_QK")
250
  wang_score = float(1 - np.median(ssr_qk)) if len(ssr_qk) > 0 else None
251
+ n_layers = len(set(r["layer"] for r in rows))
252
+ n_records = len(rows)
 
 
253
 
254
  return {
255
  "model_id": model_id,
 
280
  prefix: str,
281
  ):
282
  """
283
+ Recompute and write model_summary (three rows: all / standard / global).
+ wang_score is always computed from the standard layers.
 
 
 
 
285
  """
286
  cur = conn.cursor()
287
 
288
+ # Prefetch ssr_QK of the standard layers (wang_score always uses these)
289
+ cur.execute(
290
+ """SELECT ssr_QK FROM layer_head_metrics
291
+ WHERE model_id = ? AND prefix = ? AND layer_type = 'standard'""",
292
+ (model_id, prefix)
293
+ )
294
+ std_ssr_rows = cur.fetchall()
295
+ std_ssr = np.array(
296
+ [r[0] for r in std_ssr_rows if r[0] is not None], dtype=float
297
+ )
298
+ std_wang_score = float(1 - np.median(std_ssr)) if len(std_ssr) > 0 else None
299
+
300
  for layer_type in ["all", "standard", "global"]:
 
301
  if layer_type == "all":
302
  cur.execute(
303
+ "SELECT * FROM layer_head_metrics WHERE model_id=? AND prefix=?",
 
 
 
304
  (model_id, prefix)
305
  )
306
  else:
307
  cur.execute(
308
+ """SELECT * FROM layer_head_metrics
309
+ WHERE model_id=? AND prefix=? AND layer_type=?""",
 
 
310
  (model_id, prefix, layer_type)
311
  )
312
 
313
+ rows = cur.fetchall()
314
  summary = _calc_summary_row(rows, model_id, prefix, layer_type)
 
315
  if summary is None:
316
  continue
317
 
318
+ # wang_score always comes from the standard layers
319
+ summary["wang_score"] = std_wang_score
 
 
 
 
 
 
 
 
 
 
 
320
 
321
  conn.execute(
322
+ """INSERT OR REPLACE INTO model_summary(
323
+ model_id, prefix, layer_type,
324
+ median_pearson_QK, mean_pearson_QK,
325
+ median_ssr_QK, mean_ssr_QK,
326
+ median_ssr_QV, mean_ssr_QV,
327
+ median_cond_Q, mean_cond_Q,
328
+ median_cosU_QK, median_cosU_QV,
329
+ median_cosV_QK, median_cosV_QV,
330
+ wang_score, n_layers, n_records, updated_at
331
+ ) VALUES (
332
+ :model_id, :prefix, :layer_type,
333
+ :median_pearson_QK, :mean_pearson_QK,
334
+ :median_ssr_QK, :mean_ssr_QK,
335
+ :median_ssr_QV, :mean_ssr_QV,
336
+ :median_cond_Q, :mean_cond_Q,
337
+ :median_cosU_QK, :median_cosU_QV,
338
+ :median_cosV_QK, :median_cosV_QV,
339
+ :wang_score, :n_layers, :n_records, :updated_at
340
+ )""",
 
 
 
 
341
  summary
342
  )
343
 
344
+ conn.commit()
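The rule that `wang_score` is always `1 − median(ssr_QK)` over the `standard` layers, whichever summary row is being written, can be checked in isolation. The sketch below makes several assumptions: the table is reduced to a hypothetical subset of columns, the rows are invented, and `statistics.median` stands in for `np.median` so that only the standard library is needed.

```python
import sqlite3
from statistics import median


def wang_score(conn, model_id, prefix):
    """Return 1 - median(ssr_QK) over standard layers, or None if no data."""
    rows = conn.execute(
        "SELECT ssr_QK FROM layer_head_metrics "
        "WHERE model_id = ? AND prefix = ? AND layer_type = 'standard'",
        (model_id, prefix),
    ).fetchall()
    # NULL metrics are skipped, as in the diff.
    vals = [r[0] for r in rows if r[0] is not None]
    return 1.0 - median(vals) if vals else None


# In-memory table with a reduced, hypothetical column set.
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE layer_head_metrics ("
    "model_id TEXT, prefix TEXT, layer INTEGER, layer_type TEXT, ssr_QK REAL)"
)
conn.executemany(
    "INSERT INTO layer_head_metrics VALUES (?, ?, ?, ?, ?)",
    [
        ("demo/model", "model.", 0, "standard", 0.10),
        ("demo/model", "model.", 1, "standard", 0.30),
        ("demo/model", "model.", 2, "standard", None),  # ignored: NULL
        ("demo/model", "model.", 3, "global", 0.90),    # ignored: not standard
    ],
)

score = wang_score(conn, "demo/model", "model.")
empty = wang_score(conn, "demo/model", "missing.")
print(score, empty)
```

Computing the value once before the loop, as the new `update_model_summary` does, avoids running the same query again for the `all` and `global` rows.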
ui/tab_database.py CHANGED
@@ -1,9 +1,10 @@
1
  # ui/tab_database.py
2
  """
3
- Tab4: Database browser
- - List analyzed models
- - View a model's raw per-layer data
- - Database statistics
 
7
  """
8
 
9
  import gradio as gr
@@ -12,6 +13,7 @@ import pandas as pd
12
  from db.schema import init_db, get_db_stats
13
  from db.reader import (
14
  get_analyzed_models,
 
15
  get_model_summary,
16
  get_layer_metrics,
17
  get_resume_status,
@@ -19,92 +21,112 @@ from db.reader import (
19
 
20
 
21
  def load_db_stats() -> str:
22
- """获取数据库统计信息"""
23
  conn = init_db()
24
  stats = get_db_stats(conn)
25
  return (
26
- f"📊 数据库统计\n"
27
  f"{'─'*40}\n"
28
- f" 模型数: {stats.get('models', 0)}\n"
29
- f" 组件数: {stats.get('components', 0)}\n"
30
- f" 层头记录数: {stats.get('layer_head_metrics', 0)}\n"
31
- f" 汇总行数: {stats.get('model_summary', 0)}\n"
32
- f" 数据库大小: {stats.get('db_size_mb', 0)} MB\n"
33
  )
34
 
35
 
36
  def load_model_list() -> pd.DataFrame:
37
- """加载已分析模型列表"""
 
 
 
 
38
  conn = init_db()
39
  df = get_analyzed_models(conn)
40
  if df.empty:
41
- return pd.DataFrame(
42
- columns=["model_id", "model_type", "analyzed_at",
43
- "analyze_sec", "n_components", "total_layers"]
44
- )
 
 
 
45
  return df
46
 
47
 
48
- def load_model_detail(model_id: str) -> tuple[pd.DataFrame, pd.DataFrame, str]:
 
 
49
  """
50
- 加载模型详情
51
- 返回 (summary_df, 断点续传状态文本)
 
 
52
  """
53
  if not model_id.strip():
54
- return pd.DataFrame(), pd.DataFrame(), "请输入模型 ID"
55
 
56
  conn = init_db()
 
 
 
 
57
 
58
  # 汇总统计
59
- summary_df = get_model_summary(conn, model_id.strip())
60
 
61
- # 断点续传状态(按前缀)
62
- status_lines = [f"📍 断点续传状态:{model_id}\n{'─'*50}\n"]
63
- if not summary_df.empty:
64
- for pfx in summary_df["prefix"].unique():
65
- rs = get_resume_status(conn, model_id.strip(), pfx)
66
  status_lines.append(
67
  f" [{pfx}]\n"
68
- f" 已完成层数:{rs['total_done']}\n"
69
- f" 层号:{sorted(rs['done_layers'])}\n"
70
  )
71
  else:
72
- status_lines.append(" 暂无数据\n")
73
 
74
- return summary_df, "".join(status_lines)
75
 
76
 
77
  def load_layer_data(
78
- model_id: str,
79
- prefix: str,
80
- layer_type: str,
81
- start_layer: int,
82
- end_layer: int,
83
  ) -> tuple[pd.DataFrame, str]:
84
- """加载逐头原始数据"""
 
 
 
85
  if not model_id.strip():
86
- return pd.DataFrame(), "请输入模型 ID"
87
 
88
  conn = init_db()
 
89
  lt = layer_type if layer_type != "all" else None
90
- pfx = prefix.strip() or None
91
 
92
  df = get_layer_metrics(
93
  conn,
94
  model_id = model_id.strip(),
95
- prefix = pfx,
96
  layer_type = lt,
97
  start_layer = int(start_layer),
98
  end_layer = int(end_layer),
99
  )
100
 
101
  if df.empty:
102
- return pd.DataFrame(), f"⚠️ 无数据:model={model_id} prefix={pfx} layer_type={lt}"
 
 
 
103
 
104
  status = (
105
- f"✅ {len(df)} 条记录 "
106
- f"| {df['layer'].min()}~{df['layer'].max()} "
107
- f"| prefix={pfx or '全部'}"
108
  )
109
  return df, status
110
 
@@ -114,120 +136,149 @@ def load_layer_data(
114
  # ─────────────────────────────────────────────
115
 
116
  def build_tab_database():
117
- with gr.Tab("🗄️ 数据库"):
118
- gr.Markdown("## 数据库浏览 \n查看已分析模型的原始数据和汇总统计。")
 
 
 
 
119
 
120
- # ── 数据库统计 ──────────────────────────
121
  with gr.Row():
122
  stats_text = gr.Textbox(
123
- label="数据库统计",
124
- value="点击刷新",
125
  lines=7,
126
  interactive=False,
127
  scale=2,
128
  )
129
  refresh_stats_btn = gr.Button(
130
- "🔄 刷新统计", scale=1, variant="secondary"
131
  )
132
-
133
- refresh_stats_btn.click(
134
- fn=load_db_stats,
135
- outputs=stats_text,
136
- )
137
 
138
  gr.Markdown("---")
139
 
140
- # ── 已分析模型列表 ──────────────────────
141
- gr.Markdown("### 已分析模型")
142
- with gr.Row():
143
- refresh_models_btn = gr.Button(
144
- "🔄 刷新模型列表", variant="secondary"
145
- )
146
-
 
 
 
147
  models_table = gr.Dataframe(
148
- label="已分析模型",
 
 
 
 
149
  interactive=False,
150
  )
151
-
152
- refresh_models_btn.click(
153
- fn=load_model_list,
154
- outputs=models_table,
155
- )
156
 
157
  gr.Markdown("---")
158
 
159
- # ── 模型详情 ────────────────────────────
160
- gr.Markdown("### 模型详情 & 断点续传状态")
 
 
 
 
161
  with gr.Row():
162
  detail_model_id = gr.Textbox(
163
- label="模型 ID",
164
  placeholder="google/gemma-4-e2b",
165
  scale=3,
166
  )
167
  load_detail_btn = gr.Button(
168
- "📋 查看详情", variant="secondary", scale=1
169
  )
170
 
171
  resume_status_text = gr.Textbox(
172
- label="断点续传状态",
173
  lines=8,
174
  interactive=False,
175
  )
 
 
 
 
 
 
 
 
 
 
176
  summary_table = gr.Dataframe(
177
- label="模型汇总统计(all/standard/global 三行)",
178
  interactive=False,
179
  )
180
 
181
  load_detail_btn.click(
182
  fn=load_model_detail,
183
  inputs=[detail_model_id],
184
- outputs=[summary_table, resume_status_text],
185
  )
186
 
187
  gr.Markdown("---")
188
 
189
- # ── 逐头原始数据 ────────────────────────
190
- gr.Markdown("### 逐头原始数据查询")
 
 
 
 
191
  with gr.Row():
192
  raw_model_id = gr.Textbox(
193
- label="模型 ID",
194
  placeholder="google/gemma-4-e2b",
195
  scale=2,
196
  )
197
- raw_prefix = gr.Textbox(
198
- label="组件前缀(留空=全部)",
199
- placeholder="model.language_model.",
200
- scale=2,
 
 
201
  )
202
  raw_layer_type = gr.Dropdown(
203
- label="层类型",
204
  choices=["all", "standard", "global"],
205
  value="all",
206
  scale=1,
 
 
 
 
207
  )
208
  with gr.Row():
209
  raw_start = gr.Number(
210
- label="起始层号", value=0, precision=0, scale=1
211
  )
212
  raw_end = gr.Number(
213
- label="结束层号", value=10, precision=0, scale=1
214
  )
215
  load_raw_btn = gr.Button(
216
- "🔍 查询数据", variant="secondary", scale=1
217
  )
218
 
219
  raw_status = gr.Textbox(
220
- label="查询状态", lines=1, interactive=False
221
  )
222
  raw_table = gr.Dataframe(
223
- label="逐头原始数据",
224
  interactive=False,
225
  wrap=False,
226
  )
227
 
228
  load_raw_btn.click(
229
  fn=load_layer_data,
230
- inputs=[raw_model_id, raw_prefix, raw_layer_type,
231
- raw_start, raw_end],
 
 
232
  outputs=[raw_table, raw_status],
233
  )
 
1
  # ui/tab_database.py
2
  """
3
+ Tab4: Database Browser
4
+ - Model list (Plan A: aggregated by modality)
5
+ - Model detail (Plan B: raw components rows, expandable)
6
+ - Per-head raw data query (modality + layer_type as two independent filters)
7
+ - DB stats
8
  """
9
 
10
  import gradio as gr
 
13
  from db.schema import init_db, get_db_stats
14
  from db.reader import (
15
  get_analyzed_models,
16
+ get_model_components,
17
  get_model_summary,
18
  get_layer_metrics,
19
  get_resume_status,
 
21
 
22
 
23
  def load_db_stats() -> str:
 
24
  conn = init_db()
25
  stats = get_db_stats(conn)
26
  return (
27
+ f"Database Statistics\n"
28
  f"{'─'*40}\n"
29
+ f" Models: {stats.get('models', 0)}\n"
30
+ f" Components: {stats.get('components', 0)}\n"
31
+ f" Layer-head records: {stats.get('layer_head_metrics', 0)}\n"
32
+ f" Summary rows: {stats.get('model_summary', 0)}\n"
33
+ f" DB size: {stats.get('db_size_mb', 0)} MB\n"
34
  )
35
 
36
 
37
  def load_model_list() -> pd.DataFrame:
+     """
+     Plan A: aggregate layer counts by modality.
+     language_layers includes standard + global (all layers under the same prefix).
+     vision/audio counts of 0 are shown as empty cells.
+     """
      conn = init_db()
      df = get_analyzed_models(conn)
      if df.empty:
+         return pd.DataFrame(columns=[
+             "model_id", "model_type", "analyzed_at", "analyze_sec",
+             "n_components", "language_layers", "vision_layers", "audio_layers"
+         ])
+     # Replace 0 with an empty string for vision/audio so the table reads more cleanly
+     for col in ["vision_layers", "audio_layers"]:
+         df[col] = df[col].apply(lambda x: "" if x == 0 else x)
      return df
54
 
55
 
56
+ def load_model_detail(
+     model_id: str
+ ) -> tuple[pd.DataFrame, pd.DataFrame, str]:
      """
+     Returns:
+     1. Plan B: raw components rows (prefix / modality / n_layers / head_dim, etc.)
+     2. model_summary aggregate statistics
+     3. Resume-status text
      """
      if not model_id.strip():
+         return pd.DataFrame(), pd.DataFrame(), "Please enter a model ID."

      conn = init_db()
+     mid = model_id.strip()
+
+     # Plan B: raw components
+     comp_df = get_model_components(conn, mid)

      # Aggregate statistics
+     summary_df = get_model_summary(conn, mid)

+     # Resume status
+     status_lines = [f"Resume Status: {mid}\n{'─'*50}\n"]
+     if not comp_df.empty:
+         for pfx in comp_df["prefix"].tolist():
+             rs = get_resume_status(conn, mid, pfx)
              status_lines.append(
                  f" [{pfx}]\n"
+                 f" Done layers : {rs['total_done']}\n"
+                 f" Layer index : {sorted(rs['done_layers'])}\n"
              )
      else:
+         status_lines.append(" No data yet.\n")

+     return comp_df, summary_df, "".join(status_lines)
91
 
92
 
93
  def load_layer_data(
+     model_id: str,
+     modality: str,
+     layer_type: str,
+     start_layer: int,
+     end_layer: int,
  ) -> tuple[pd.DataFrame, str]:
+     """
+     Query raw per-head data.
+     modality and layer_type are filtered independently.
+     """
      if not model_id.strip():
+         return pd.DataFrame(), "Please enter a model ID."

      conn = init_db()
+     mod = modality if modality != "all" else None
      lt = layer_type if layer_type != "all" else None
 
110
 
111
  df = get_layer_metrics(
112
  conn,
113
  model_id = model_id.strip(),
114
+ modality = mod,
115
  layer_type = lt,
116
  start_layer = int(start_layer),
117
  end_layer = int(end_layer),
118
  )
119
 
120
  if df.empty:
121
+ return pd.DataFrame(), (
122
+ f"No data found: model={model_id} "
123
+ f"modality={mod or 'all'} layer_type={lt or 'all'}"
124
+ )
125
 
126
  status = (
127
+ f"✅ {len(df)} records "
128
+ f"| layers {df['layer'].min()}~{df['layer'].max()} "
129
+ f"| modality={mod or 'all'} layer_type={lt or 'all'}"
130
  )
131
  return df, status
132
 
 
136
  # ─────────────────────────────────────────────
137
 
138
  def build_tab_database():
139
+ with gr.Tab("🗄️ Database"):
140
+ gr.Markdown(
141
+ "## Database Browser\n"
142
+ "View analyzed models, raw per-head data, and resume status.\n\n"
143
+ "> 查看已分析模型、逐头原始数据及断点续传状态。"
144
+ )
145
 
146
+ # ── DB Stats ────────────────────────────────────────
147
  with gr.Row():
148
  stats_text = gr.Textbox(
149
+ label="Database Statistics",
150
+ value="Click Refresh to load.",
151
  lines=7,
152
  interactive=False,
153
  scale=2,
154
  )
155
  refresh_stats_btn = gr.Button(
156
+ "🔄 Refresh Stats", scale=1, variant="secondary"
157
  )
158
+ refresh_stats_btn.click(fn=load_db_stats, outputs=stats_text)
 
 
 
 
159
 
160
  gr.Markdown("---")
161
 
162
+ # ── Model List (Plan A) ─────────────────────────────
163
+ gr.Markdown(
164
+ "### Analyzed Models\n"
165
+ "Layers are split by modality. "
166
+ "`language_layers` includes both standard and global layers.\n\n"
167
+ "> 层数按模态拆分。`language_layers` 含 standard 和 global 层。"
168
+ )
169
+ refresh_models_btn = gr.Button(
170
+ "🔄 Refresh Model List", variant="secondary"
171
+ )
172
  models_table = gr.Dataframe(
173
+ label="Analyzed Models",
174
+ headers=[
175
+ "model_id", "model_type", "analyzed_at", "analyze_sec",
176
+ "n_components", "language_layers", "vision_layers", "audio_layers"
177
+ ],
178
  interactive=False,
179
  )
180
+ refresh_models_btn.click(fn=load_model_list, outputs=models_table)
 
 
 
 
181
 
182
  gr.Markdown("---")
183
 
184
+ # ── Model Detail (Plan B, expanded) ─────────────────
185
+ gr.Markdown(
186
+ "### Model Detail & Resume Status\n"
187
+ "Expand raw component rows and check which layers are done.\n\n"
188
+ "> 查看原始组件信息及断点续传进度。"
189
+ )
190
  with gr.Row():
191
  detail_model_id = gr.Textbox(
192
+ label="Model ID",
193
  placeholder="google/gemma-4-e2b",
194
  scale=3,
195
  )
196
  load_detail_btn = gr.Button(
197
+ "📋 Load Detail", variant="secondary", scale=1
198
  )
199
 
200
  resume_status_text = gr.Textbox(
201
+ label="Resume Status",
202
  lines=8,
203
  interactive=False,
204
  )
205
+ # Plan B: raw components rows
206
+ components_table = gr.Dataframe(
207
+ label="Components (raw) — prefix / modality / n_layers / head_dim",
208
+ headers=[
209
+ "prefix", "modality", "n_layers",
210
+ "head_dim_min", "head_dim_max",
211
+ "has_kv_shared", "has_global", "d_model"
212
+ ],
213
+ interactive=False,
214
+ )
215
  summary_table = gr.Dataframe(
216
+ label="Model Summary (all / standard / global)",
217
  interactive=False,
218
  )
219
 
220
  load_detail_btn.click(
221
  fn=load_model_detail,
222
  inputs=[detail_model_id],
223
+ outputs=[components_table, summary_table, resume_status_text],
224
  )
225
 
226
  gr.Markdown("---")
227
 
228
+ # ── Raw Data Query ──────────────────────────────────
229
+ gr.Markdown(
230
+ "### Per-head Raw Data Query\n"
231
+ "`Modality` and `Layer Type` are two independent filter dimensions.\n\n"
232
+ "> Modality(模态)和 Layer Type(层结构类型)是两个独立过滤维度,可组合使用。"
233
+ )
234
  with gr.Row():
235
  raw_model_id = gr.Textbox(
236
+ label="Model ID",
237
  placeholder="google/gemma-4-e2b",
238
  scale=2,
239
  )
240
+ raw_modality = gr.Dropdown(
241
+ label="Modality",
242
+ choices=["all", "language", "vision", "audio"],
243
+ value="language",
244
+ scale=1,
245
+ info="Filter by component modality | 按模态过滤",
246
  )
247
  raw_layer_type = gr.Dropdown(
248
+ label="Layer Type",
249
  choices=["all", "standard", "global"],
250
  value="all",
251
  scale=1,
252
+ info=(
253
+ "standard = normal layers | "
254
+ "global = K=V shared layers (e.g. Gemma global)"
255
+ ),
256
  )
257
  with gr.Row():
258
  raw_start = gr.Number(
259
+ label="Start Layer", value=0, precision=0, scale=1
260
  )
261
  raw_end = gr.Number(
262
+ label="End Layer", value=10, precision=0, scale=1
263
  )
264
  load_raw_btn = gr.Button(
265
+ "🔍 Query Data", variant="secondary", scale=1
266
  )
267
 
268
  raw_status = gr.Textbox(
269
+ label="Query Status", lines=1, interactive=False
270
  )
271
  raw_table = gr.Dataframe(
272
+ label="Per-head Raw Data",
273
  interactive=False,
274
  wrap=False,
275
  )
276
 
277
  load_raw_btn.click(
278
  fn=load_layer_data,
279
+ inputs=[
280
+ raw_model_id, raw_modality, raw_layer_type,
281
+ raw_start, raw_end
282
+ ],
283
  outputs=[raw_table, raw_status],
284
  )
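The UI above passes `modality` and `layer_type` as two independent, optional filters, with `None` standing for the "all" choice. The real `get_layer_metrics` lives in `db/reader.py` and is not shown in this part of the diff, so the following is only a sketch of how such composable filters might be turned into a parameterized query. The function name `query_layer_metrics`, the reduced table schema, and the sample rows are all assumptions.

```python
import sqlite3


def query_layer_metrics(conn, model_id, modality=None, layer_type=None):
    """Hypothetical reader: each filter is applied only when it is not None."""
    clauses = ["model_id = ?"]
    params = [model_id]
    if modality is not None:
        clauses.append("modality = ?")
        params.append(modality)
    if layer_type is not None:
        clauses.append("layer_type = ?")
        params.append(layer_type)
    sql = (
        "SELECT layer, modality, layer_type FROM layer_head_metrics WHERE "
        + " AND ".join(clauses)
        + " ORDER BY modality, layer"
    )
    # Values are always bound as parameters, never interpolated into the SQL.
    return conn.execute(sql, params).fetchall()


conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE layer_head_metrics ("
    "model_id TEXT, layer INTEGER, modality TEXT, layer_type TEXT)"
)
conn.executemany(
    "INSERT INTO layer_head_metrics VALUES (?, ?, ?, ?)",
    [
        ("m", 0, "language", "standard"),
        ("m", 1, "language", "global"),
        ("m", 0, "vision", "standard"),
        ("m", 0, "audio", "standard"),
    ],
)

everything = query_layer_metrics(conn, "m")
language_only = query_layer_metrics(conn, "m", modality="language")
language_global = query_layer_metrics(conn, "m", "language", "global")
standard_only = query_layer_metrics(conn, "m", layer_type="standard")
print(len(everything), len(language_only), language_global, len(standard_only))
```

Keeping the two filters as separate `WHERE` terms is what allows combinations such as `modality='language'` with `layer_type='global'`, which a single conflated field could not express.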
ui/tab_leaderboard.py CHANGED
@@ -1,9 +1,9 @@
1
  # ui/tab_leaderboard.py
2
  """
3
- Tab3: Wang Score leaderboard
- - Read from model_summary, sorted by wang_score in descending order
- - Filter by component (language_model / vision_tower / all)
- - Filter by layer_type (standard / global / all)
7
  """
8
 
9
  import gradio as gr
@@ -14,123 +14,104 @@ from db.schema import init_db
14
  from db.reader import get_leaderboard
15
 
16
 
17
- # ─────────────────────────────────────────────
18
- # 排行榜列格式化
19
- # ─────────────────────────────────────────────
20
-
21
  def _format_leaderboard(df: pd.DataFrame) -> pd.DataFrame:
22
- """格式化排行榜显示列"""
23
  if df.empty:
24
  return df
25
 
26
- # 提取可读的模型名(去掉 org 前缀)
27
  df = df.copy()
28
  df["model_name"] = df["model_id"].apply(
29
  lambda x: x.split("/")[-1] if "/" in x else x
30
  )
31
-
32
- # 王氏评分百分制(便于直觉理解)
33
  df["wang_score_pct"] = df["wang_score"].apply(
34
  lambda x: f"{x*100:.3f}" if pd.notna(x) else "N/A"
35
  )
36
-
37
- # 格式化关键指标
38
  for col in ["median_pearson_QK", "median_ssr_QK", "mean_ssr_QK"]:
39
  if col in df.columns:
40
  df[col] = df[col].apply(
41
  lambda x: f"{x:.6f}" if pd.notna(x) else "N/A"
42
  )
43
 
44
- # 选择展示列
45
  display_cols = [
46
- "model_name",
47
- "prefix",
48
- "layer_type",
49
  "wang_score_pct",
50
- "median_pearson_QK",
51
- "median_ssr_QK",
52
- "mean_ssr_QK",
53
- "median_cosU_QK",
54
- "median_cosU_QV",
55
- "median_cosV_QK",
56
- "n_layers",
57
- "n_records",
58
- "model_id", # 完整 ID 放最后
59
  ]
60
  existing = [c for c in display_cols if c in df.columns]
61
  return df[existing]
62
 
63
 
64
  def load_leaderboard(
65
- prefix_filter: str,
66
- layer_type: str,
67
  ) -> tuple[pd.DataFrame, str]:
68
- """
69
- 加载排行榜数据
70
- 返回 (DataFrame, 状态文本)
71
- """
72
  conn = init_db()
 
 
73
 
74
- # Empty prefix_filter string → None (no filtering)
75
- pfx = prefix_filter.strip() or None
76
- lt = layer_type if layer_type != "all" else "standard"
77
-
78
- df = get_leaderboard(conn, prefix_filter=pfx, layer_type=lt, limit=100)
79
 
80
  if df.empty:
81
  return pd.DataFrame(), (
82
- "📭 排行榜暂无数据\n"
83
- "请先在「分析」Tab 分析至少一个模型的完整层。\n"
84
- f"(当前过滤:prefix='{pfx}', layer_type='{lt}')"
85
  )
86
 
87
  formatted = _format_leaderboard(df)
88
  status = (
89
- f"✅ {len(formatted)} 条记录 "
90
- f"| layer_type={lt} "
91
- f"| prefix_filter='{pfx or '全部'}'"
92
  )
93
  return formatted, status
94
 
95
 
96
- # ─────────────────────────────────────────────
97
- # Tab3 UI
98
- # ─────────────────────────────────────────────
99
-
100
  def build_tab_leaderboard():
101
- with gr.Tab("🏆 排行榜"):
102
  gr.Markdown("""
103
- ## 王氏评分排行榜
104
- **Wang Score = 1 − median(SSR_QK)**,越高越好(理论极值 = 1)
105
- 基于 `standard` 层计算(排除 K=V 共享的全局层干扰)。
 
 
 
 
106
  """)
107
 
108
  with gr.Row():
109
- prefix_input = gr.Textbox(
110
- label="组件过滤(含关键词即匹配,留空=全部)",
111
- placeholder="language_model",
112
- value="",
113
- scale=3,
 
114
  )
115
  layer_type_input = gr.Dropdown(
116
- label="层类型",
117
  choices=["standard", "global", "all"],
118
  value="standard",
119
  scale=1,
 
 
 
 
 
 
 
120
  )
121
- refresh_btn = gr.Button("🔄 刷新排行榜", variant="primary", scale=1)
122
 
123
  status_text = gr.Textbox(
124
- label="状态",
125
- value="点击「刷新排行榜」加载数据",
126
  lines=1,
127
  interactive=False,
128
  )
129
 
130
  leaderboard_table = gr.Dataframe(
131
- label="王氏评分排行榜(按 Wang Score 降序)",
132
  headers=[
133
- "model_name", "prefix", "layer_type",
134
  "wang_score_pct",
135
  "median_pearson_QK", "median_ssr_QK", "mean_ssr_QK",
136
  "median_cosU_QK", "median_cosU_QV", "median_cosV_QK",
@@ -141,23 +122,20 @@ def build_tab_leaderboard():
141
  )
142
 
143
  gr.Markdown("""
144
- ### Metric reference
- | Metric | Meaning | Better |
- |------|------|------|
- | Wang Score | 1 − median(SSR_QK), overall reasoning-ability score | ↑ higher |
- | median_pearson_QK | Median Pearson correlation of the Q/K singular-value spectra (Law 1) | ↑ |
- | median_ssr_QK | Median normalized spectral mismatch of Q/K (Law 2) | ↓ |
- | median_cosU_QK | Q/K output-subspace alignment (Law 4, ≈ random orthogonal) | 1/√d |
- | median_cosU_QV | Q/V output subspace (Law 4, super-orthogonal) | ↓ |
- | median_cosV_QK | Q/K input subspace (Law 5, ≈ random orthogonal) | 1/√D |
 
153
  """)
154
 
155
- # Event binding
156
  refresh_btn.click(
157
  fn=load_leaderboard,
158
- inputs=[prefix_input, layer_type_input],
159
  outputs=[leaderboard_table, status_text],
160
- )
161
-
162
- # Auto-load on startup
163
- leaderboard_table.change(fn=None)
 
1
  # ui/tab_leaderboard.py
2
  """
3
+ Tab3: Wang's Five Laws Leaderboard
4
+ - Ranked by wang_score (= 1 − median SSR_QK, standard layers only)
5
+ - Filter by modality (default: language)
6
+ - Filter by layer_type (default: standard)
7
  """
8
 
9
  import gradio as gr
 
14
  from db.reader import get_leaderboard
15
 
16
 
 
 
 
 
17
  def _format_leaderboard(df: pd.DataFrame) -> pd.DataFrame:
 
18
  if df.empty:
19
  return df
20
 
 
21
  df = df.copy()
22
  df["model_name"] = df["model_id"].apply(
23
  lambda x: x.split("/")[-1] if "/" in x else x
24
  )
 
 
25
  df["wang_score_pct"] = df["wang_score"].apply(
26
  lambda x: f"{x*100:.3f}" if pd.notna(x) else "N/A"
27
  )
 
 
28
  for col in ["median_pearson_QK", "median_ssr_QK", "mean_ssr_QK"]:
29
  if col in df.columns:
30
  df[col] = df[col].apply(
31
  lambda x: f"{x:.6f}" if pd.notna(x) else "N/A"
32
  )
33
 
 
34
  display_cols = [
35
+ "model_name", "modality", "layer_type",
 
 
36
  "wang_score_pct",
37
+ "median_pearson_QK", "median_ssr_QK", "mean_ssr_QK",
38
+ "median_cosU_QK", "median_cosU_QV", "median_cosV_QK",
39
+ "n_layers", "n_records", "model_id",
 
 
 
 
 
 
40
  ]
41
  existing = [c for c in display_cols if c in df.columns]
42
  return df[existing]
43
 
44
 
45
  def load_leaderboard(
46
+ modality: str,
47
+ layer_type: str,
48
  ) -> tuple[pd.DataFrame, str]:
 
 
 
 
49
  conn = init_db()
50
+ lt = layer_type if layer_type != "all" else "standard"
51
+ mod = modality
52
 
53
+ df = get_leaderboard(conn, modality=mod, layer_type=lt, limit=100)
 
 
 
 
54
 
55
  if df.empty:
56
  return pd.DataFrame(), (
57
+ f"No data yet. Please analyze at least one model first.\n"
58
+ f"(modality='{mod}', layer_type='{lt}')\n\n"
59
+ f"暂无数据,请先在「Analyze」Tab 分析至少一个模型。"
60
  )
61
 
62
  formatted = _format_leaderboard(df)
63
  status = (
64
+ f"✅ {len(formatted)} entries "
65
+ f"| modality={mod} layer_type={lt}"
 
66
  )
67
  return formatted, status
68
 
69
 
 
 
 
 
70
  def build_tab_leaderboard():
71
+ with gr.Tab("🏆 Leaderboard"):
72
  gr.Markdown("""
73
+ ## Wang's Five Laws — Model Leaderboard
74
+
75
+ **Wang Score = 1 − median(SSR\_QK)** Higher is better. Theoretical max = 1.
76
+ Computed from `standard` layers only (global/KV-shared layers excluded).
77
+
78
+ > 王氏评分 = 1 − median(SSR_QK),越高越好,理论极值=1。
79
+ > 仅基于 standard 层计算(排除 K=V 共享的全局层干扰)。
80
  """)
81
 
82
  with gr.Row():
83
+ modality_input = gr.Dropdown(
84
+ label="Modality",
85
+ choices=["language", "vision", "audio", "all"],
86
+ value="language",
87
+ scale=1,
88
+ info="language = text LLM components | 通常选 language",
89
  )
90
  layer_type_input = gr.Dropdown(
91
+ label="Layer Type",
92
  choices=["standard", "global", "all"],
93
  value="standard",
94
  scale=1,
95
+ info=(
96
+ "standard = normal layers | "
97
+ "global = K=V shared (Gemma global layers)"
98
+ ),
99
+ )
100
+ refresh_btn = gr.Button(
101
+ "🔄 Refresh Leaderboard", variant="primary", scale=1
102
  )
 
103
 
104
  status_text = gr.Textbox(
105
+ label="Status",
106
+ value="Click Refresh to load leaderboard.",
107
  lines=1,
108
  interactive=False,
109
  )
110
 
111
  leaderboard_table = gr.Dataframe(
112
+ label="Wang Score Leaderboard (sorted by Wang Score ↓)",
113
  headers=[
114
+ "model_name", "modality", "layer_type",
115
  "wang_score_pct",
116
  "median_pearson_QK", "median_ssr_QK", "mean_ssr_QK",
117
  "median_cosU_QK", "median_cosU_QV", "median_cosV_QK",
 
122
  )
123
 
124
  gr.Markdown("""
125
+ ### Metric Reference | 指标说明
+
+ | Metric | Description | Better |
+ |--------|-------------|--------|
+ | Wang Score | 1 − median(SSR\_QK), overall reasoning-ability score | ↑ Higher |
+ | median\_pearson\_QK | Q/K spectral Pearson correlation (Law 1) | ↑ Higher |
+ | median\_ssr\_QK | Q/K normalized spectral mismatch (Law 2) | ↓ Lower |
+ | median\_cosU\_QK | Q/K output subspace alignment (Law 4, ≈ random orthogonal) | ≈ 1/√d |
+ | median\_cosU\_QV | Q/V output subspace (Law 4, super-orthogonal) | ↓ Lower |
+ | median\_cosV\_QK | Q/K input subspace (Law 5, ≈ random orthogonal) | ≈ 1/√D |
135
  """)
136
 
 
137
  refresh_btn.click(
138
  fn=load_leaderboard,
139
+ inputs=[modality_input, layer_type_input],
140
  outputs=[leaderboard_table, status_text],
141
+ )