Alex W. committed on
Commit
38fc6ed
·
1 Parent(s): 9319cc8

feat: write the five laws' data into SQLite.

Browse files
Files changed (4) hide show
  1. db/__init__.py +0 -0
  2. db/reader.py +199 -0
  3. db/schema.py +208 -0
  4. db/writer.py +375 -0
db/__init__.py ADDED
File without changes
db/reader.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # db/reader.py
2
+ """
3
+ 数据库查询模块
4
+ - 排行榜查询
5
+ - 模型详情查询
6
+ - 断点续传状态查询
7
+ """
8
+
9
+ import sqlite3
10
+ import pandas as pd
11
+ from db.schema import get_connection, init_db
12
+
13
+
14
+ # ─────────────────────────────────────────────
15
+ # 排行榜
16
+ # ─────────────────────────────────────────────
17
+
18
def get_leaderboard(
    conn: sqlite3.Connection,
    prefix_filter: str = None,   # restrict to one component; None = all
    layer_type: str = "standard",
    limit: int = 50,
) -> pd.DataFrame:
    """
    Return the leaderboard: ``model_summary`` rows joined with ``components``.

    Rows are restricted to ``layer_type``, optionally to prefixes that contain
    ``prefix_filter`` as a literal substring, ordered by ``wang_score``
    descending and capped at ``limit`` rows.

    Args:
        conn: Open SQLite connection.
        prefix_filter: Substring the component prefix must contain. A falsy
            value (``None`` or ``""``) disables the filter.
        layer_type: Value matched against ``model_summary.layer_type``.
        limit: Maximum number of rows returned.

    Returns:
        A DataFrame with one row per matching summary row. When nothing
        matches, the DataFrame is empty but still carries the column names.
    """
    sql = """
        SELECT
            s.model_id,
            s.prefix,
            s.layer_type,
            s.wang_score,
            s.median_pearson_QK,
            s.median_ssr_QK,
            s.mean_ssr_QK,
            s.median_cosU_QK,
            s.median_cosU_QV,
            s.median_cosV_QK,
            s.median_cond_Q,
            s.n_layers,
            s.n_records,
            s.updated_at,
            -- 组件信息
            c.head_dim_min,
            c.head_dim_max,
            c.has_kv_shared,
            c.has_global,
            c.d_model
        FROM model_summary s
        LEFT JOIN components c
            ON s.model_id = c.model_id AND s.prefix = c.prefix
        WHERE s.layer_type = ?
    """
    params = [layer_type]

    if prefix_filter:
        # "_" and "%" are LIKE wildcards; prefixes such as
        # "model.language_model." contain "_", so escape them to get a
        # literal substring match.
        escaped = (
            prefix_filter
            .replace("\\", "\\\\")
            .replace("%", "\\%")
            .replace("_", "\\_")
        )
        sql += " AND s.prefix LIKE ? ESCAPE '\\'"
        params.append(f"%{escaped}%")

    sql += " ORDER BY s.wang_score DESC LIMIT ?"
    params.append(limit)

    cur = conn.cursor()
    cur.execute(sql, params)
    rows = cur.fetchall()

    # cur.description is populated even for zero rows, so an empty result
    # keeps its column names.
    cols = [d[0] for d in cur.description]
    return pd.DataFrame([dict(zip(cols, row)) for row in rows], columns=cols)
73
+
74
+
75
+ # ─────────────────────────────────────────────
76
+ # 模型详情
77
+ # ─────────────────────────────────────────────
78
+
79
def get_model_summary(
    conn: sqlite3.Connection,
    model_id: str,
) -> pd.DataFrame:
    """
    Return every ``model_summary`` row stored for one model.

    Args:
        conn: Open SQLite connection.
        model_id: Model whose summary rows are requested.

    Returns:
        A DataFrame ordered by ``prefix`` then ``layer_type``. When the model
        has no rows, the DataFrame is empty but still carries the table's
        column names.
    """
    cur = conn.cursor()
    cur.execute(
        """
        SELECT * FROM model_summary
        WHERE model_id = ?
        ORDER BY prefix, layer_type
        """,
        (model_id,),
    )
    rows = cur.fetchall()
    # The description is available for zero rows too, which keeps the schema
    # of an empty result intact.
    cols = [d[0] for d in cur.description]
    return pd.DataFrame([dict(zip(cols, row)) for row in rows], columns=cols)
98
+
99
+
100
def get_layer_metrics(
    conn: sqlite3.Connection,
    model_id: str,
    prefix: str = None,
    layer_type: str = None,
    start_layer: int = None,
    end_layer: int = None,
) -> pd.DataFrame:
    """
    Return the raw per-head rows of ``layer_head_metrics`` for one model.

    Args:
        conn: Open SQLite connection.
        model_id: Model whose rows are requested.
        prefix: Exact component prefix; a falsy value disables the filter.
        layer_type: Exact layer type; a falsy value disables the filter.
        start_layer: Inclusive lower bound on ``layer``; ``None`` = unbounded.
        end_layer: Inclusive upper bound on ``layer``; ``None`` = unbounded.

    Returns:
        A DataFrame ordered by ``prefix, layer, kv_head, q_head``. When no row
        matches, the DataFrame is empty but still carries the column names.
    """
    sql = "SELECT * FROM layer_head_metrics WHERE model_id = ?"
    params = [model_id]

    if prefix:
        sql += " AND prefix = ?"
        params.append(prefix)
    if layer_type:
        sql += " AND layer_type = ?"
        params.append(layer_type)
    # Layer bounds are compared against None so that layer 0 is a valid bound.
    if start_layer is not None:
        sql += " AND layer >= ?"
        params.append(start_layer)
    if end_layer is not None:
        sql += " AND layer <= ?"
        params.append(end_layer)

    sql += " ORDER BY prefix, layer, kv_head, q_head"

    cur = conn.cursor()
    cur.execute(sql, params)
    rows = cur.fetchall()

    # Build the frame from the cursor description so an empty result still
    # exposes the table's columns.
    cols = [d[0] for d in cur.description]
    return pd.DataFrame([dict(zip(cols, row)) for row in rows], columns=cols)
138
+
139
+
140
def get_analyzed_models(conn: sqlite3.Connection) -> pd.DataFrame:
    """
    List every row of ``models`` together with per-model component totals.

    Args:
        conn: Open SQLite connection.

    Returns:
        A DataFrame with columns ``model_id``, ``model_type``, ``analyzed_at``,
        ``analyze_sec``, ``n_components`` (distinct component prefixes) and
        ``total_layers`` (sum of ``components.n_layers``), ordered by
        ``analyzed_at`` descending. With no models, the DataFrame is empty but
        still carries those column names.
    """
    cur = conn.cursor()
    cur.execute(
        """
        SELECT
            m.model_id,
            m.model_type,
            m.analyzed_at,
            m.analyze_sec,
            COUNT(DISTINCT c.prefix) as n_components,
            SUM(c.n_layers) as total_layers
        FROM models m
        LEFT JOIN components c ON m.model_id = c.model_id
        GROUP BY m.model_id
        ORDER BY m.analyzed_at DESC
        """
    )
    rows = cur.fetchall()
    # Column names come from the cursor, so they survive an empty result.
    cols = [d[0] for d in cur.description]
    return pd.DataFrame([dict(zip(cols, row)) for row in rows], columns=cols)
163
+
164
+
165
+ # ─────────────────────────────────────────────
166
+ # 断点续传状态
167
+ # ─────────────────────────────────────────────
168
+
169
def get_resume_status(
    conn: sqlite3.Connection,
    model_id: str,
    prefix: str,
) -> dict:
    """
    Report which layers of a (model_id, prefix) pair already have rows.

    Returns a dict with:
        ``done_layers``  - set of layer numbers present in layer_head_metrics
        ``layer_detail`` - mapping layer -> number of stored rows for it
        ``total_done``   - number of distinct layers found
    """
    query = """
        SELECT DISTINCT layer, COUNT(*) as n_heads
        FROM layer_head_metrics
        WHERE model_id = ? AND prefix = ?
        GROUP BY layer
        ORDER BY layer
        """
    cursor = conn.cursor()
    cursor.execute(query, (model_id, prefix))

    # One entry per layer: layer number -> stored row count.
    heads_per_layer = {}
    for record in cursor.fetchall():
        heads_per_layer[record[0]] = record[1]

    status = {}
    status["done_layers"] = set(heads_per_layer)
    status["layer_detail"] = heads_per_layer
    status["total_done"] = len(heads_per_layer)
    return status
db/schema.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # db/schema.py
2
+ """
3
+ 数据库表结构定义与初始化
4
+ SQLite 存储在 /data/wang_laws.db(HF Space bucket 持久化)
5
+ """
6
+
7
+ import sqlite3
8
+ import os
9
+ from datetime import datetime
10
+
11
+ # ─────────────────────────────────────────────
12
+ # 数据库路径
13
+ # /data 是 HF Space bucket 挂载点,重启后数据不丢失
14
+ # 本地开发时自动回退到当前目录
15
+ # ─────────────────────────────────────────────
16
+
17
def get_db_path() -> str:
    """
    Return the SQLite file path.

    Uses ``/data/wang_laws.db`` when a ``/data`` path exists, otherwise the
    relative path ``wang_laws.db`` in the current working directory.
    """
    has_data_mount = os.path.exists("/data")
    return "/data/wang_laws.db" if has_data_mount else "wang_laws.db"
21
+
22
+
23
def get_connection() -> sqlite3.Connection:
    """
    Open a connection to the database file returned by ``get_db_path()``.

    The connection is created with ``check_same_thread=False``, yields
    ``sqlite3.Row`` objects (columns addressable by name), and has WAL
    journaling and foreign-key enforcement switched on.
    """
    connection = sqlite3.connect(get_db_path(), check_same_thread=False)
    connection.row_factory = sqlite3.Row
    # Per-connection settings, applied in the same order as before.
    for pragma in ("PRAGMA journal_mode=WAL", "PRAGMA foreign_keys=ON"):
        connection.execute(pragma)
    return connection
30
+
31
+
32
+ # ─────────────────────────────────────────────
33
+ # 建表 SQL
34
+ # ─────────────────────────────────────────────
35
+
36
# DDL for the `models` table: one row per analysed model, keyed by model_id.
# Columns: model_type (architecture family), analyzed_at (timestamp),
# analyze_sec (analysis duration in seconds) and free-form notes.
SQL_CREATE_MODELS = """
CREATE TABLE IF NOT EXISTS models (
    model_id TEXT PRIMARY KEY,
    model_type TEXT, -- gemma4 / llama / qwen2 等
    analyzed_at TIMESTAMP,
    analyze_sec REAL, -- 分析耗时(秒)
    notes TEXT -- 备注
);
"""
45
+
46
# DDL for the `components` table: one row per (model_id, prefix) pair, where
# prefix names a component of the model (e.g. "model.language_model.").
# Stores the component's layer count, min/max head_dim, two 0/1 flags
# (has_kv_shared, has_global) and d_model. model_id references models.
SQL_CREATE_COMPONENTS = """
CREATE TABLE IF NOT EXISTS components (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    model_id TEXT NOT NULL,
    prefix TEXT NOT NULL, -- 如 model.language_model.
    n_layers INTEGER, -- 该组件完整层数
    head_dim_min INTEGER, -- 最小 head_dim(异构层用)
    head_dim_max INTEGER, -- 最大 head_dim
    has_kv_shared INTEGER DEFAULT 0, -- 是否有 K=V 共享层(全局层)
    has_global INTEGER DEFAULT 0, -- 是否有 global 层
    d_model INTEGER, -- 输入维度
    UNIQUE(model_id, prefix),
    FOREIGN KEY(model_id) REFERENCES models(model_id)
);
"""
61
+
62
# DDL for the `layer_head_metrics` table: the raw per-head measurements.
# A row is identified by (model_id, prefix, layer, kv_head, q_head), which is
# enforced by the UNIQUE constraint. The metric columns are grouped, per the
# in-SQL comments, into the five "laws":
#   1. spectral linear alignment (pearson_* / spearman_*)
#   2. spectral shape residual (ssr_*)
#   3. condition numbers (sigma_max_* / sigma_min_* / cond_*)
#   4. left singular-vector alignment (cosU_*)
#   5. right singular-vector alignment (cosV_*)
# followed by scale factors and least-squares residuals (alpha_*).
SQL_CREATE_LAYER_HEAD_METRICS = """
CREATE TABLE IF NOT EXISTS layer_head_metrics (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    model_id TEXT NOT NULL,
    prefix TEXT NOT NULL,
    layer INTEGER NOT NULL,
    layer_type TEXT DEFAULT 'standard', -- standard / global
    kv_head INTEGER NOT NULL,
    q_head INTEGER NOT NULL,
    kv_shared INTEGER DEFAULT 0, -- 1=K=V共享(理论值),0=独立V
    head_dim INTEGER,
    d_model INTEGER,
    n_q_heads INTEGER,
    n_kv_heads INTEGER,
    -- 第一定律:谱线性对齐
    pearson_QK REAL,
    spearman_QK REAL,
    pearson_QV REAL,
    pearson_KV REAL,
    -- 第二定律:谱形状残差
    ssr_QK REAL,
    ssr_QV REAL,
    ssr_KV REAL,
    -- 第三定律:条件数
    sigma_max_Q REAL,
    sigma_min_Q REAL,
    cond_Q REAL,
    sigma_max_K REAL,
    sigma_min_K REAL,
    cond_K REAL,
    sigma_max_V REAL,
    sigma_min_V REAL,
    cond_V REAL,
    -- 第四定律:左奇异向量对齐(输出子空间)
    cosU_QK REAL,
    cosU_QV REAL,
    cosU_KV REAL,
    -- 第五定律:右奇异向量对齐(输入子空间)
    cosV_QK REAL,
    cosV_QV REAL,
    cosV_KV REAL,
    -- 尺度因子 + 最小二乘残差
    alpha_QK REAL,
    alpha_res_QK REAL,
    alpha_QV REAL,
    alpha_res_QV REAL,
    alpha_KV REAL,
    alpha_res_KV REAL,

    UNIQUE(model_id, prefix, layer, kv_head, q_head),
    FOREIGN KEY(model_id) REFERENCES models(model_id)
);
"""
115
+
116
# DDL for the `model_summary` table: aggregated statistics, one row per
# (model_id, prefix, layer_type), with layer_type being 'all', 'standard' or
# 'global'. Holds medians/means of selected per-head metrics, the wang_score
# (per the in-SQL comment: currently 1 - median_ssr_QK, based on standard
# layers), the number of layers and records aggregated, and updated_at.
SQL_CREATE_MODEL_SUMMARY = """
CREATE TABLE IF NOT EXISTS model_summary (
    model_id TEXT NOT NULL,
    prefix TEXT NOT NULL,
    layer_type TEXT NOT NULL DEFAULT 'all', -- all / standard / global
    -- 第一定律
    median_pearson_QK REAL,
    mean_pearson_QK REAL,
    -- 第二定律(王氏评分核心)
    median_ssr_QK REAL,
    mean_ssr_QK REAL,
    median_ssr_QV REAL,
    mean_ssr_QV REAL,
    -- 第三定律
    median_cond_Q REAL,
    mean_cond_Q REAL,
    -- 第四定律
    median_cosU_QK REAL,
    median_cosU_QV REAL,
    -- 第五定律
    median_cosV_QK REAL,
    median_cosV_QV REAL,
    -- 王氏评分(暂时 = 1 - median_ssr_QK,基于 standard 层)
    wang_score REAL,
    -- 统计范围
    n_layers INTEGER, -- 参与统计的层数
    n_records INTEGER, -- 参与统计的记录数
    updated_at TIMESTAMP,

    PRIMARY KEY(model_id, prefix, layer_type),
    FOREIGN KEY(model_id) REFERENCES models(model_id)
);
"""
149
+
150
# Secondary indexes that speed up common queries.
#
# layer_head_metrics needs no explicit index here: its
# UNIQUE(model_id, prefix, layer, kv_head, q_head) constraint is backed by an
# automatic index, and SQLite can use any left prefix of it. That already
# covers lookups by (model_id, prefix), by (model_id, prefix, layer) and by
# the full key, so separate indexes on those column lists would only add
# write and storage cost.
SQL_CREATE_INDEXES = [
    # Leaderboard query: ORDER BY wang_score DESC.
    """CREATE INDEX IF NOT EXISTS idx_summary_wang_score
    ON model_summary(wang_score DESC)""",
]
165
+
166
+
167
+ # ─────────────────────────────────────────────
168
+ # 初始化函数
169
+ # ─────────────────────────────────────────────
170
+
171
def init_db() -> sqlite3.Connection:
    """
    Create all tables and indexes, then return the open connection.

    Every statement uses ``IF NOT EXISTS``, so calling this repeatedly is
    safe. The work is committed before the connection is returned.
    """
    connection = get_connection()
    cursor = connection.cursor()

    # Tables first, in the original order, then the indexes on top of them.
    ddl_statements = [
        SQL_CREATE_MODELS,
        SQL_CREATE_COMPONENTS,
        SQL_CREATE_LAYER_HEAD_METRICS,
        SQL_CREATE_MODEL_SUMMARY,
    ]
    ddl_statements.extend(SQL_CREATE_INDEXES)
    for statement in ddl_statements:
        cursor.execute(statement)

    connection.commit()
    return connection
190
+
191
+
192
def get_db_stats(conn: sqlite3.Connection) -> dict:
    """
    Collect row counts for the four tables plus the database file size.

    Returns a dict mapping each table name to its row count, and
    ``db_size_mb`` to the size of the file at ``get_db_path()`` in MiB
    rounded to two decimals (``0`` when that file does not exist).
    """
    cursor = conn.cursor()
    table_names = ("models", "components", "layer_head_metrics", "model_summary")

    stats = {}
    for table in table_names:
        # Table names come from the fixed tuple above, never from user input.
        cursor.execute(f"SELECT COUNT(*) FROM {table}")
        stats[table] = cursor.fetchone()[0]

    path = get_db_path()
    if not os.path.exists(path):
        stats["db_size_mb"] = 0
        return stats

    stats["db_size_mb"] = round(os.path.getsize(path) / 1024 / 1024, 2)
    return stats
db/writer.py ADDED
@@ -0,0 +1,375 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # db/writer.py
2
+ """
3
+ 数据库写入模块
4
+ - 写入分析结果到 layer_head_metrics
5
+ - 计算并写入 model_summary
6
+ - 支持断点续传(以 prefix+layer 为粒度)
7
+ """
8
+
9
+ import sqlite3
10
+ import numpy as np
11
+ from datetime import datetime
12
+ from db.schema import get_connection, init_db
13
+
14
+
15
+ # ─────────────────────────────────────────────
16
+ # layer_type 推断
17
+ # ─────────────────────────────────────────────
18
+
19
def infer_layer_type(kv_shared: bool) -> str:
    """
    Map the ``kv_shared`` flag to a layer type label.

    A truthy flag yields ``'global'``; a falsy flag yields ``'standard'``.
    """
    if kv_shared:
        return "global"
    return "standard"
27
+
28
+
29
+ # ─────────────────────────────────────────────
30
+ # 断点续传:检查已完成的层
31
+ # ─────────────────────────────────────────────
32
+
33
def get_analyzed_layers(
    conn: sqlite3.Connection,
    model_id: str,
    prefix: str,
) -> set[int]:
    """
    Return the set of layer numbers that already have rows stored.

    Intended for resuming an interrupted run: any layer of the given
    (model_id, prefix) with at least one row in ``layer_head_metrics`` is
    reported.
    """
    query = """
        SELECT DISTINCT layer
        FROM layer_head_metrics
        WHERE model_id = ? AND prefix = ?
        """
    cursor = conn.cursor()
    cursor.execute(query, (model_id, prefix))

    stored_layers = set()
    for record in cursor.fetchall():
        stored_layers.add(record[0])
    return stored_layers
53
+
54
+
55
def is_layer_complete(
    conn: sqlite3.Connection,
    model_id: str,
    prefix: str,
    layer: int,
    expected_records: int,
) -> bool:
    """
    Tell whether a layer has at least ``expected_records`` rows stored.

    ``expected_records`` is the number of rows the layer should have when it
    has been written in full (the original docstring equates it with
    ``n_q_heads``).
    """
    query = """
        SELECT COUNT(*)
        FROM layer_head_metrics
        WHERE model_id = ? AND prefix = ? AND layer = ?
        """
    cursor = conn.cursor()
    cursor.execute(query, (model_id, prefix, layer))
    stored_count = cursor.fetchone()[0]
    return expected_records <= stored_count
77
+
78
+
79
+ # ─────────────────────────────────────────────
80
+ # 写入模型元数据
81
+ # ─────────────────────────────────────────────
82
+
83
def upsert_model(
    conn: sqlite3.Connection,
    model_id: str,
    model_type: str = None,
    notes: str = None,
    analyze_sec: float = None,
):
    """
    Insert a row into ``models`` or update the existing row, then commit.

    Args:
        conn: Open SQLite connection.
        model_id: Primary key of the model.
        model_type: Stored as given; overwrites the previous value on update.
        notes: Stored as given; overwrites the previous value on update.
        analyze_sec: Analysis duration in seconds. When ``None`` (the
            default), an existing row keeps the value it already has.

    ``analyzed_at`` is always set to the current UTC time, written as a
    naive ISO-8601 string.
    """
    # Function-scope import: the module only imports ``datetime`` itself.
    from datetime import timezone

    # datetime.utcnow() is deprecated; this yields the same naive-UTC text.
    analyzed_at = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()

    conn.execute(
        """
        INSERT INTO models(model_id, model_type, analyzed_at, analyze_sec, notes)
        VALUES(?, ?, ?, ?, ?)
        ON CONFLICT(model_id) DO UPDATE SET
            model_type  = excluded.model_type,
            analyzed_at = excluded.analyzed_at,
            analyze_sec = COALESCE(excluded.analyze_sec, analyze_sec),
            notes       = excluded.notes
        """,
        (model_id, model_type, analyzed_at, analyze_sec, notes),
    )
    conn.commit()
102
+
103
+
104
def upsert_component(
    conn: sqlite3.Connection,
    model_id: str,
    prefix: str,
    n_layers: int,
    head_dim_min: int,
    head_dim_max: int,
    has_kv_shared: bool,
    has_global: bool,
    d_model: int,
):
    """
    Insert a ``components`` row or overwrite the existing one, then commit.

    The row is identified by (model_id, prefix). On conflict every other
    column is replaced with the new value. The two boolean flags are stored
    as integers.
    """
    statement = """
        INSERT INTO components(
            model_id, prefix, n_layers,
            head_dim_min, head_dim_max,
            has_kv_shared, has_global, d_model
        )
        VALUES(?, ?, ?, ?, ?, ?, ?, ?)
        ON CONFLICT(model_id, prefix) DO UPDATE SET
            n_layers = excluded.n_layers,
            head_dim_min = excluded.head_dim_min,
            head_dim_max = excluded.head_dim_max,
            has_kv_shared = excluded.has_kv_shared,
            has_global = excluded.has_global,
            d_model = excluded.d_model
        """
    kv_shared_flag = int(has_kv_shared)
    global_flag = int(has_global)
    values = (
        model_id,
        prefix,
        n_layers,
        head_dim_min,
        head_dim_max,
        kv_shared_flag,
        global_flag,
        d_model,
    )
    conn.execute(statement, values)
    conn.commit()
139
+
140
+
141
+ # ─────────────────────────────────────────────
142
+ # 写入逐头指标
143
+ # ─────────────────────────────────────────────
144
+
145
# Metric columns copied verbatim from each record dict (missing keys -> NULL).
# Defined once so the column list, the values and the placeholders cannot
# drift apart.
_RECORD_VALUE_COLUMNS = (
    "head_dim", "d_model", "n_q_heads", "n_kv_heads",
    # Law 1
    "pearson_QK", "spearman_QK", "pearson_QV", "pearson_KV",
    # Law 2
    "ssr_QK", "ssr_QV", "ssr_KV",
    # Law 3
    "sigma_max_Q", "sigma_min_Q", "cond_Q",
    "sigma_max_K", "sigma_min_K", "cond_K",
    "sigma_max_V", "sigma_min_V", "cond_V",
    # Law 4
    "cosU_QK", "cosU_QV", "cosU_KV",
    # Law 5
    "cosV_QK", "cosV_QV", "cosV_KV",
    # Scale factors and residuals
    "alpha_QK", "alpha_res_QK",
    "alpha_QV", "alpha_res_QV",
    "alpha_KV", "alpha_res_KV",
)


def _to_sqlite_value(value):
    """Turn a NumPy scalar into the native Python value sqlite3 can bind."""
    return value.item() if isinstance(value, np.generic) else value


def write_layer_records(
    conn: sqlite3.Connection,
    model_id: str,
    records: list[dict],
):
    """
    Bulk-write per-head metric records into ``layer_head_metrics`` and commit.

    Uses ``INSERT OR REPLACE``, so writing the same
    (model_id, prefix, layer, kv_head, q_head) again replaces the stored row.

    Args:
        conn: Open SQLite connection.
        model_id: Model the records belong to.
        records: One dict per head. ``prefix``, ``layer``, ``kv_head`` and
            ``q_head`` are required keys; ``kv_shared`` and every metric key
            are optional (absent metrics are stored as NULL). An empty list
            is a no-op.

    Raises:
        KeyError: If a record lacks one of the required keys.
    """
    if not records:
        return

    columns = (
        "model_id", "prefix", "layer", "layer_type",
        "kv_head", "q_head", "kv_shared",
    ) + _RECORD_VALUE_COLUMNS

    rows = []
    for r in records:
        # Normalise once: an explicit None (or a NumPy bool) must not break
        # the int() conversion below.
        kv_shared = bool(r.get("kv_shared", False))
        values = [
            model_id,
            r["prefix"],
            r["layer"],
            infer_layer_type(kv_shared),
            r["kv_head"],
            r["q_head"],
            int(kv_shared),
        ]
        values.extend(r.get(name) for name in _RECORD_VALUE_COLUMNS)
        # sqlite3 cannot bind NumPy scalars such as float32 / int64.
        rows.append(tuple(_to_sqlite_value(v) for v in values))

    placeholders = ",".join("?" * len(columns))
    conn.executemany(
        f"INSERT OR REPLACE INTO layer_head_metrics({', '.join(columns)}) "
        f"VALUES ({placeholders})",
        rows,
    )
    conn.commit()
235
+
236
+
237
+ # ─────────────────────────────────────────────
238
+ # 计算并写入 model_summary
239
+ # ─────────────────────────────────────────────
240
+
241
+ def _calc_summary_row(
242
+ rows: list[sqlite3.Row],
243
+ model_id: str,
244
+ prefix: str,
245
+ layer_type: str,
246
+ ) -> dict | None:
247
+ """
248
+ 从一组 layer_head_metrics 行计算汇总统计
249
+ 返回 model_summary 的一行
250
+ """
251
+ if not rows:
252
+ return None
253
+
254
+ def col(name):
255
+ vals = [r[name] for r in rows if r[name] is not None]
256
+ return np.array(vals) if vals else np.array([])
257
+
258
+ def med(arr):
259
+ return float(np.median(arr)) if len(arr) > 0 else None
260
+
261
+ def avg(arr):
262
+ return float(np.mean(arr)) if len(arr) > 0 else None
263
+
264
+ ssr_qk = col("ssr_QK")
265
+ wang_score = float(1 - np.median(ssr_qk)) if len(ssr_qk) > 0 else None
266
+
267
+ # 统计层数(去重)
268
+ n_layers = len(set(r["layer"] for r in rows))
269
+ n_records = len(rows)
270
+
271
+ return {
272
+ "model_id": model_id,
273
+ "prefix": prefix,
274
+ "layer_type": layer_type,
275
+ "median_pearson_QK": med(col("pearson_QK")),
276
+ "mean_pearson_QK": avg(col("pearson_QK")),
277
+ "median_ssr_QK": med(ssr_qk),
278
+ "mean_ssr_QK": avg(ssr_qk),
279
+ "median_ssr_QV": med(col("ssr_QV")),
280
+ "mean_ssr_QV": avg(col("ssr_QV")),
281
+ "median_cond_Q": med(col("cond_Q")),
282
+ "mean_cond_Q": avg(col("cond_Q")),
283
+ "median_cosU_QK": med(col("cosU_QK")),
284
+ "median_cosU_QV": med(col("cosU_QV")),
285
+ "median_cosV_QK": med(col("cosV_QK")),
286
+ "median_cosV_QV": med(col("cosV_QV")),
287
+ "wang_score": wang_score,
288
+ "n_layers": n_layers,
289
+ "n_records": n_records,
290
+ "updated_at": datetime.utcnow().isoformat(),
291
+ }
292
+
293
+
294
def update_model_summary(
    conn: sqlite3.Connection,
    model_id: str,
    prefix: str,
):
    """
    Recompute the ``model_summary`` rows of one (model_id, prefix) and commit.

    Up to three rows are written, one per ``layer_type`` in
    ``'all'``, ``'standard'``, ``'global'``; a layer type without any
    ``layer_head_metrics`` rows is skipped.

    ``wang_score`` is taken from the standard layers for every row: when
    standard rows exist, the ``'all'`` and ``'global'`` rows get
    ``1 - median(ssr_QK)`` of the standard rows (``None`` if all of those
    values are NULL). When no standard rows exist, each row keeps the score
    computed from its own rows.
    """
    cur = conn.cursor()

    # The standard-layer score does not depend on the loop variable, so
    # fetch it once instead of once per non-standard layer type.
    cur.execute(
        """
        SELECT ssr_QK FROM layer_head_metrics
        WHERE model_id = ? AND prefix = ? AND layer_type = 'standard'
        """,
        (model_id, prefix),
    )
    std_rows = cur.fetchall()
    has_standard = bool(std_rows)
    std_ssr = np.array([r[0] for r in std_rows if r[0] is not None])
    standard_score = float(1 - np.median(std_ssr)) if len(std_ssr) > 0 else None

    for layer_type in ("all", "standard", "global"):
        if layer_type == "all":
            cur.execute(
                """
                SELECT * FROM layer_head_metrics
                WHERE model_id = ? AND prefix = ?
                """,
                (model_id, prefix),
            )
        else:
            cur.execute(
                """
                SELECT * FROM layer_head_metrics
                WHERE model_id = ? AND prefix = ? AND layer_type = ?
                """,
                (model_id, prefix, layer_type),
            )

        summary = _calc_summary_row(cur.fetchall(), model_id, prefix, layer_type)
        if summary is None:
            continue

        # Override only when standard rows exist; otherwise the score from
        # this layer type's own rows is kept.
        if layer_type != "standard" and has_standard:
            summary["wang_score"] = standard_score

        conn.execute(
            """
            INSERT OR REPLACE INTO model_summary(
                model_id, prefix, layer_type,
                median_pearson_QK, mean_pearson_QK,
                median_ssr_QK, mean_ssr_QK,
                median_ssr_QV, mean_ssr_QV,
                median_cond_Q, mean_cond_Q,
                median_cosU_QK, median_cosU_QV,
                median_cosV_QK, median_cosV_QV,
                wang_score,
                n_layers, n_records, updated_at
            ) VALUES (
                :model_id, :prefix, :layer_type,
                :median_pearson_QK, :mean_pearson_QK,
                :median_ssr_QK, :mean_ssr_QK,
                :median_ssr_QV, :mean_ssr_QV,
                :median_cond_Q, :mean_cond_Q,
                :median_cosU_QK, :median_cosU_QV,
                :median_cosV_QK, :median_cosV_QV,
                :wang_score,
                :n_layers, :n_records, :updated_at
            )
            """,
            summary,
        )

    conn.commit()