Spaces:
Running
Running
Alex W. committed on
Commit ·
38fc6ed
1
Parent(s): 9319cc8
feat: write the five laws' data into SQLite.
Browse files- db/__init__.py +0 -0
- db/reader.py +199 -0
- db/schema.py +208 -0
- db/writer.py +375 -0
db/__init__.py
ADDED
|
File without changes
|
db/reader.py
ADDED
|
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# db/reader.py
|
| 2 |
+
"""
|
| 3 |
+
数据库查询模块
|
| 4 |
+
- 排行榜查询
|
| 5 |
+
- 模型详情查询
|
| 6 |
+
- 断点续传状态查询
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import sqlite3
|
| 10 |
+
import pandas as pd
|
| 11 |
+
from db.schema import get_connection, init_db
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# ─────────────────────────────────────────────
|
| 15 |
+
# 排行榜
|
| 16 |
+
# ─────────────────────────────────────────────
|
| 17 |
+
|
| 18 |
+
def get_leaderboard(
|
| 19 |
+
conn: sqlite3.Connection,
|
| 20 |
+
prefix_filter: str = None, # 只看某个组件,None=全部
|
| 21 |
+
layer_type: str = "standard",
|
| 22 |
+
limit: int = 50,
|
| 23 |
+
) -> pd.DataFrame:
|
| 24 |
+
"""
|
| 25 |
+
排行榜查询
|
| 26 |
+
按 wang_score 降序排列
|
| 27 |
+
"""
|
| 28 |
+
sql = """
|
| 29 |
+
SELECT
|
| 30 |
+
s.model_id,
|
| 31 |
+
s.prefix,
|
| 32 |
+
s.layer_type,
|
| 33 |
+
s.wang_score,
|
| 34 |
+
s.median_pearson_QK,
|
| 35 |
+
s.median_ssr_QK,
|
| 36 |
+
s.mean_ssr_QK,
|
| 37 |
+
s.median_cosU_QK,
|
| 38 |
+
s.median_cosU_QV,
|
| 39 |
+
s.median_cosV_QK,
|
| 40 |
+
s.median_cond_Q,
|
| 41 |
+
s.n_layers,
|
| 42 |
+
s.n_records,
|
| 43 |
+
s.updated_at,
|
| 44 |
+
-- 组件信息
|
| 45 |
+
c.head_dim_min,
|
| 46 |
+
c.head_dim_max,
|
| 47 |
+
c.has_kv_shared,
|
| 48 |
+
c.has_global,
|
| 49 |
+
c.d_model
|
| 50 |
+
FROM model_summary s
|
| 51 |
+
LEFT JOIN components c
|
| 52 |
+
ON s.model_id = c.model_id AND s.prefix = c.prefix
|
| 53 |
+
WHERE s.layer_type = ?
|
| 54 |
+
"""
|
| 55 |
+
params = [layer_type]
|
| 56 |
+
|
| 57 |
+
if prefix_filter:
|
| 58 |
+
sql += " AND s.prefix LIKE ?"
|
| 59 |
+
params.append(f"%{prefix_filter}%")
|
| 60 |
+
|
| 61 |
+
sql += " ORDER BY s.wang_score DESC LIMIT ?"
|
| 62 |
+
params.append(limit)
|
| 63 |
+
|
| 64 |
+
cur = conn.cursor()
|
| 65 |
+
cur.execute(sql, params)
|
| 66 |
+
rows = cur.fetchall()
|
| 67 |
+
|
| 68 |
+
if not rows:
|
| 69 |
+
return pd.DataFrame()
|
| 70 |
+
|
| 71 |
+
cols = [d[0] for d in cur.description]
|
| 72 |
+
return pd.DataFrame([dict(zip(cols, row)) for row in rows])
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
# ─────────────────────────────────────────────
|
| 76 |
+
# 模型详情
|
| 77 |
+
# ─────────────────────────────────────────────
|
| 78 |
+
|
| 79 |
+
def get_model_summary(
    conn: sqlite3.Connection,
    model_id: str,
) -> pd.DataFrame:
    """Return every ``model_summary`` row stored for one model.

    Rows are ordered by ``prefix`` and then ``layer_type``. When the model has
    no summary rows the returned frame is empty but still carries the table's
    column names.
    """
    cur = conn.cursor()
    cur.execute(
        """
        SELECT * FROM model_summary
        WHERE model_id = ?
        ORDER BY prefix, layer_type
        """,
        (model_id,),
    )
    rows = cur.fetchall()

    # Passing the column names keeps the schema even when no row matched.
    cols = [d[0] for d in cur.description]
    return pd.DataFrame([tuple(row) for row in rows], columns=cols)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def get_layer_metrics(
|
| 101 |
+
conn: sqlite3.Connection,
|
| 102 |
+
model_id: str,
|
| 103 |
+
prefix: str = None,
|
| 104 |
+
layer_type: str = None,
|
| 105 |
+
start_layer:int = None,
|
| 106 |
+
end_layer: int = None,
|
| 107 |
+
) -> pd.DataFrame:
|
| 108 |
+
"""
|
| 109 |
+
查询逐头原始数据
|
| 110 |
+
支持按 prefix / layer_type / 层号范围过滤
|
| 111 |
+
"""
|
| 112 |
+
sql = "SELECT * FROM layer_head_metrics WHERE model_id = ?"
|
| 113 |
+
params = [model_id]
|
| 114 |
+
|
| 115 |
+
if prefix:
|
| 116 |
+
sql += " AND prefix = ?"
|
| 117 |
+
params.append(prefix)
|
| 118 |
+
if layer_type:
|
| 119 |
+
sql += " AND layer_type = ?"
|
| 120 |
+
params.append(layer_type)
|
| 121 |
+
if start_layer is not None:
|
| 122 |
+
sql += " AND layer >= ?"
|
| 123 |
+
params.append(start_layer)
|
| 124 |
+
if end_layer is not None:
|
| 125 |
+
sql += " AND layer <= ?"
|
| 126 |
+
params.append(end_layer)
|
| 127 |
+
|
| 128 |
+
sql += " ORDER BY prefix, layer, kv_head, q_head"
|
| 129 |
+
|
| 130 |
+
cur = conn.cursor()
|
| 131 |
+
cur.execute(sql, params)
|
| 132 |
+
rows = cur.fetchall()
|
| 133 |
+
|
| 134 |
+
if not rows:
|
| 135 |
+
return pd.DataFrame()
|
| 136 |
+
cols = [d[0] for d in cur.description]
|
| 137 |
+
return pd.DataFrame([dict(zip(cols, row)) for row in rows])
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def get_analyzed_models(conn: sqlite3.Connection) -> pd.DataFrame:
    """List every row of ``models`` with per-model component aggregates.

    Each row carries ``n_components`` (distinct component prefixes) and
    ``total_layers`` (sum of ``components.n_layers``); models without any
    component get 0 and NULL respectively because of the LEFT JOIN. Rows are
    ordered by ``analyzed_at``, newest first. An empty result still carries
    the column names.
    """
    cur = conn.cursor()
    cur.execute(
        """
        SELECT
            m.model_id,
            m.model_type,
            m.analyzed_at,
            m.analyze_sec,
            COUNT(DISTINCT c.prefix) as n_components,
            SUM(c.n_layers) as total_layers
        FROM models m
        LEFT JOIN components c ON m.model_id = c.model_id
        GROUP BY m.model_id
        ORDER BY m.analyzed_at DESC
        """
    )
    rows = cur.fetchall()

    # Passing the column names keeps the schema even when no row matched.
    cols = [d[0] for d in cur.description]
    return pd.DataFrame([tuple(row) for row in rows], columns=cols)
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
# ─────────────────────────────────────────────
|
| 166 |
+
# 断点续传状态
|
| 167 |
+
# ─────────────────────────────────────────────
|
| 168 |
+
|
| 169 |
+
def get_resume_status(
    conn: sqlite3.Connection,
    model_id: str,
    prefix: str,
) -> dict:
    """Report which layers of one ``(model_id, prefix)`` already have rows.

    Returns a dict with three keys:
        ``done_layers``: set of layer numbers that have at least one row,
        ``layer_detail``: mapping of layer number to its stored row count,
        ``total_done``: number of distinct layers found.

    A layer counts as done as soon as one row exists for it; completeness of
    the layer is not checked here.
    """
    cursor = conn.cursor()

    # GROUP BY already yields one row per layer.
    cursor.execute(
        """
        SELECT layer, COUNT(*) as n_heads
        FROM layer_head_metrics
        WHERE model_id = ? AND prefix = ?
        GROUP BY layer
        ORDER BY layer
        """,
        (model_id, prefix),
    )

    heads_per_layer = {}
    for layer, n_heads in cursor.fetchall():
        heads_per_layer[layer] = n_heads

    return {
        "done_layers": set(heads_per_layer),
        "layer_detail": heads_per_layer,  # layer -> number of stored rows
        "total_done": len(heads_per_layer),
    }
|
db/schema.py
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# db/schema.py
|
| 2 |
+
"""
|
| 3 |
+
数据库表结构定义与初始化
|
| 4 |
+
SQLite 存储在 /data/wang_laws.db(HF Space bucket 持久化)
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import sqlite3
|
| 8 |
+
import os
|
| 9 |
+
from datetime import datetime
|
| 10 |
+
|
| 11 |
+
# ─────────────────────────────────────────────
|
| 12 |
+
# 数据库路径
|
| 13 |
+
# /data 是 HF Space bucket 挂载点,重启后数据不丢失
|
| 14 |
+
# 本地开发时自动回退到当前目录
|
| 15 |
+
# ─────────────────────────────────────────────
|
| 16 |
+
|
| 17 |
+
def get_db_path() -> str:
    """Return the path of the SQLite database file.

    Uses ``/data/wang_laws.db`` when ``/data`` is a directory, otherwise
    falls back to ``wang_laws.db`` relative to the current working directory.
    """
    # isdir rather than exists: a regular file named /data cannot hold the DB.
    if os.path.isdir("/data"):
        return "/data/wang_laws.db"
    return "wang_laws.db"
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def get_connection() -> sqlite3.Connection:
    """Open a connection to the database returned by ``get_db_path()``.

    The connection is opened with ``check_same_thread=False``, uses
    ``sqlite3.Row`` as row factory (columns accessible by name), and has the
    WAL journal mode and foreign-key enforcement pragmas applied.

    Raises:
        sqlite3.Error: If connecting or applying a pragma fails. The
            connection is closed before the error propagates.
    """
    conn = sqlite3.connect(get_db_path(), check_same_thread=False)
    try:
        conn.row_factory = sqlite3.Row
        conn.execute("PRAGMA journal_mode=WAL")
        conn.execute("PRAGMA foreign_keys=ON")
    except sqlite3.Error:
        # Do not hand the caller's error path a leaked file handle.
        conn.close()
        raise
    return conn
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# ─────────────────────────────────────────────
|
| 33 |
+
# 建表 SQL
|
| 34 |
+
# ─────────────────────────────────────────────
|
| 35 |
+
|
| 36 |
+
SQL_CREATE_MODELS = """
|
| 37 |
+
CREATE TABLE IF NOT EXISTS models (
|
| 38 |
+
model_id TEXT PRIMARY KEY,
|
| 39 |
+
model_type TEXT, -- gemma4 / llama / qwen2 等
|
| 40 |
+
analyzed_at TIMESTAMP,
|
| 41 |
+
analyze_sec REAL, -- 分析耗时(秒)
|
| 42 |
+
notes TEXT -- 备注
|
| 43 |
+
);
|
| 44 |
+
"""
|
| 45 |
+
|
| 46 |
+
SQL_CREATE_COMPONENTS = """
|
| 47 |
+
CREATE TABLE IF NOT EXISTS components (
|
| 48 |
+
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
| 49 |
+
model_id TEXT NOT NULL,
|
| 50 |
+
prefix TEXT NOT NULL, -- 如 model.language_model.
|
| 51 |
+
n_layers INTEGER, -- 该组件完整层数
|
| 52 |
+
head_dim_min INTEGER, -- 最小 head_dim(异构层用)
|
| 53 |
+
head_dim_max INTEGER, -- 最大 head_dim
|
| 54 |
+
has_kv_shared INTEGER DEFAULT 0, -- 是否有 K=V 共享层(全局层)
|
| 55 |
+
has_global INTEGER DEFAULT 0, -- 是否有 global 层
|
| 56 |
+
d_model INTEGER, -- 输入维度
|
| 57 |
+
UNIQUE(model_id, prefix),
|
| 58 |
+
FOREIGN KEY(model_id) REFERENCES models(model_id)
|
| 59 |
+
);
|
| 60 |
+
"""
|
| 61 |
+
|
| 62 |
+
SQL_CREATE_LAYER_HEAD_METRICS = """
|
| 63 |
+
CREATE TABLE IF NOT EXISTS layer_head_metrics (
|
| 64 |
+
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
| 65 |
+
model_id TEXT NOT NULL,
|
| 66 |
+
prefix TEXT NOT NULL,
|
| 67 |
+
layer INTEGER NOT NULL,
|
| 68 |
+
layer_type TEXT DEFAULT 'standard', -- standard / global
|
| 69 |
+
kv_head INTEGER NOT NULL,
|
| 70 |
+
q_head INTEGER NOT NULL,
|
| 71 |
+
kv_shared INTEGER DEFAULT 0, -- 1=K=V共享(理论值),0=独立V
|
| 72 |
+
head_dim INTEGER,
|
| 73 |
+
d_model INTEGER,
|
| 74 |
+
n_q_heads INTEGER,
|
| 75 |
+
n_kv_heads INTEGER,
|
| 76 |
+
-- 第一定律:谱线性对齐
|
| 77 |
+
pearson_QK REAL,
|
| 78 |
+
spearman_QK REAL,
|
| 79 |
+
pearson_QV REAL,
|
| 80 |
+
pearson_KV REAL,
|
| 81 |
+
-- 第二定律:谱形状残差
|
| 82 |
+
ssr_QK REAL,
|
| 83 |
+
ssr_QV REAL,
|
| 84 |
+
ssr_KV REAL,
|
| 85 |
+
-- 第三定律:条件数
|
| 86 |
+
sigma_max_Q REAL,
|
| 87 |
+
sigma_min_Q REAL,
|
| 88 |
+
cond_Q REAL,
|
| 89 |
+
sigma_max_K REAL,
|
| 90 |
+
sigma_min_K REAL,
|
| 91 |
+
cond_K REAL,
|
| 92 |
+
sigma_max_V REAL,
|
| 93 |
+
sigma_min_V REAL,
|
| 94 |
+
cond_V REAL,
|
| 95 |
+
-- 第四定律:左奇异向量对齐(输出子空间)
|
| 96 |
+
cosU_QK REAL,
|
| 97 |
+
cosU_QV REAL,
|
| 98 |
+
cosU_KV REAL,
|
| 99 |
+
-- 第五定律:右奇异向量对齐(输入子空间)
|
| 100 |
+
cosV_QK REAL,
|
| 101 |
+
cosV_QV REAL,
|
| 102 |
+
cosV_KV REAL,
|
| 103 |
+
-- 尺度因子 + 最小二乘残差
|
| 104 |
+
alpha_QK REAL,
|
| 105 |
+
alpha_res_QK REAL,
|
| 106 |
+
alpha_QV REAL,
|
| 107 |
+
alpha_res_QV REAL,
|
| 108 |
+
alpha_KV REAL,
|
| 109 |
+
alpha_res_KV REAL,
|
| 110 |
+
|
| 111 |
+
UNIQUE(model_id, prefix, layer, kv_head, q_head),
|
| 112 |
+
FOREIGN KEY(model_id) REFERENCES models(model_id)
|
| 113 |
+
);
|
| 114 |
+
"""
|
| 115 |
+
|
| 116 |
+
SQL_CREATE_MODEL_SUMMARY = """
|
| 117 |
+
CREATE TABLE IF NOT EXISTS model_summary (
|
| 118 |
+
model_id TEXT NOT NULL,
|
| 119 |
+
prefix TEXT NOT NULL,
|
| 120 |
+
layer_type TEXT NOT NULL DEFAULT 'all', -- all / standard / global
|
| 121 |
+
-- 第一定律
|
| 122 |
+
median_pearson_QK REAL,
|
| 123 |
+
mean_pearson_QK REAL,
|
| 124 |
+
-- 第二定律(王氏评分核心)
|
| 125 |
+
median_ssr_QK REAL,
|
| 126 |
+
mean_ssr_QK REAL,
|
| 127 |
+
median_ssr_QV REAL,
|
| 128 |
+
mean_ssr_QV REAL,
|
| 129 |
+
-- 第三定律
|
| 130 |
+
median_cond_Q REAL,
|
| 131 |
+
mean_cond_Q REAL,
|
| 132 |
+
-- 第四定律
|
| 133 |
+
median_cosU_QK REAL,
|
| 134 |
+
median_cosU_QV REAL,
|
| 135 |
+
-- 第五定律
|
| 136 |
+
median_cosV_QK REAL,
|
| 137 |
+
median_cosV_QV REAL,
|
| 138 |
+
-- 王氏评分(暂时 = 1 - median_ssr_QK,基于 standard 层)
|
| 139 |
+
wang_score REAL,
|
| 140 |
+
-- 统计范围
|
| 141 |
+
n_layers INTEGER, -- 参与统计的层数
|
| 142 |
+
n_records INTEGER, -- 参与统计的记录数
|
| 143 |
+
updated_at TIMESTAMP,
|
| 144 |
+
|
| 145 |
+
PRIMARY KEY(model_id, prefix, layer_type),
|
| 146 |
+
FOREIGN KEY(model_id) REFERENCES models(model_id)
|
| 147 |
+
);
|
| 148 |
+
"""
|
| 149 |
+
|
| 150 |
+
# 索引:加速常用查询
|
| 151 |
+
SQL_CREATE_INDEXES = [
|
| 152 |
+
# 按模型+组件查询层数据
|
| 153 |
+
"""CREATE INDEX IF NOT EXISTS idx_metrics_model_prefix
|
| 154 |
+
ON layer_head_metrics(model_id, prefix)""",
|
| 155 |
+
# 按层号范围查询
|
| 156 |
+
"""CREATE INDEX IF NOT EXISTS idx_metrics_layer
|
| 157 |
+
ON layer_head_metrics(model_id, prefix, layer)""",
|
| 158 |
+
# 排行榜查询
|
| 159 |
+
"""CREATE INDEX IF NOT EXISTS idx_summary_wang_score
|
| 160 |
+
ON model_summary(wang_score DESC)""",
|
| 161 |
+
# 断点续传:快速判断某层是否已分析
|
| 162 |
+
"""CREATE INDEX IF NOT EXISTS idx_metrics_resume
|
| 163 |
+
ON layer_head_metrics(model_id, prefix, layer, kv_head, q_head)""",
|
| 164 |
+
]
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
# ─────────────────────────────────────────────
|
| 168 |
+
# 初始化函数
|
| 169 |
+
# ─────────────────────────────────────────────
|
| 170 |
+
|
| 171 |
+
def init_db() -> sqlite3.Connection:
    """Create all tables and indexes if they do not exist yet.

    Every statement uses ``IF NOT EXISTS``, so calling this repeatedly is
    safe.

    Returns:
        The open connection obtained from ``get_connection()``.

    Raises:
        sqlite3.Error: If any DDL statement fails. Pending changes are rolled
            back and the connection is closed before the error propagates.
    """
    conn = get_connection()
    try:
        # "with conn" commits on success and rolls back on an exception.
        with conn:
            cur = conn.cursor()
            cur.execute(SQL_CREATE_MODELS)
            cur.execute(SQL_CREATE_COMPONENTS)
            cur.execute(SQL_CREATE_LAYER_HEAD_METRICS)
            cur.execute(SQL_CREATE_MODEL_SUMMARY)
            for sql in SQL_CREATE_INDEXES:
                cur.execute(sql)
    except sqlite3.Error:
        conn.close()
        raise
    return conn
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def get_db_stats(conn: sqlite3.Connection) -> dict:
    """Return the row count of every table plus the database file size.

    The result maps each table name to its ``COUNT(*)`` and ``db_size_mb`` to
    the size of the file at ``get_db_path()`` in MiB, rounded to two decimals
    (0 when the file cannot be read). Only the main database file is
    measured.
    """
    cur = conn.cursor()
    stats = {}

    # Table names come from this fixed list, never from user input, so
    # formatting them into the statement is safe.
    for table in ["models", "components", "layer_head_metrics", "model_summary"]:
        cur.execute(f"SELECT COUNT(*) FROM {table}")
        stats[table] = cur.fetchone()[0]

    # Ask for the size directly instead of checking existence first, so a
    # file that disappears in between cannot raise.
    try:
        size_bytes = os.path.getsize(get_db_path())
    except OSError:
        stats["db_size_mb"] = 0
    else:
        stats["db_size_mb"] = round(size_bytes / 1024 / 1024, 2)

    return stats
|
db/writer.py
ADDED
|
@@ -0,0 +1,375 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# db/writer.py
|
| 2 |
+
"""
|
| 3 |
+
数据库写入模块
|
| 4 |
+
- 写入分析结果到 layer_head_metrics
|
| 5 |
+
- 计算并写入 model_summary
|
| 6 |
+
- 支持断点续传(以 prefix+layer 为粒度)
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import sqlite3
|
| 10 |
+
import numpy as np
|
| 11 |
+
from datetime import datetime
|
| 12 |
+
from db.schema import get_connection, init_db
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# ─────────────────────────────────────────────
|
| 16 |
+
# layer_type 推断
|
| 17 |
+
# ─────────────────────────────────────────────
|
| 18 |
+
|
| 19 |
+
def infer_layer_type(kv_shared: bool) -> str:
    """Derive the layer type label from the ``kv_shared`` flag.

    A truthy ``kv_shared`` (K and V share weights) yields ``'global'``;
    anything else yields ``'standard'``. No model names are hard-coded; the
    label depends only on this structural flag.
    """
    if kv_shared:
        return "global"
    return "standard"
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# ─────────────────────────────────────────────
|
| 30 |
+
# 断点续传:检查已完成的层
|
| 31 |
+
# ─────────────────────────────────────────────
|
| 32 |
+
|
| 33 |
+
def get_analyzed_layers(
    conn: sqlite3.Connection,
    model_id: str,
    prefix: str,
) -> set[int]:
    """Return the layer numbers that already have rows for a component.

    Used for resuming an interrupted run: layers in the returned set can be
    skipped. Granularity is ``(model_id, prefix, layer)``; a layer is
    included as soon as one row exists for it.
    """
    cursor = conn.cursor()
    cursor.execute(
        """
        SELECT DISTINCT layer
        FROM layer_head_metrics
        WHERE model_id = ? AND prefix = ?
        """,
        (model_id, prefix),
    )
    layers = set()
    for record in cursor.fetchall():
        layers.add(record[0])
    return layers
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def is_layer_complete(
    conn: sqlite3.Connection,
    model_id: str,
    prefix: str,
    layer: int,
    expected_records: int,
) -> bool:
    """Tell whether a layer has at least ``expected_records`` stored rows.

    ``expected_records`` is the number of rows the layer should have
    (documented by the original author as ``n_q_heads``).
    """
    cursor = conn.cursor()
    cursor.execute(
        """
        SELECT COUNT(*)
        FROM layer_head_metrics
        WHERE model_id = ? AND prefix = ? AND layer = ?
        """,
        (model_id, prefix, layer),
    )
    (stored,) = cursor.fetchone()
    return not stored < expected_records
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
# ─────────────────────────────────────────────
|
| 80 |
+
# 写入模型元数据
|
| 81 |
+
# ─────────────────────────────────────────────
|
| 82 |
+
|
| 83 |
+
def upsert_model(
    conn: sqlite3.Connection,
    model_id: str,
    model_type: str = None,
    notes: str = None,
):
    """Insert a row into ``models`` or update the existing one, then commit.

    ``analyzed_at`` is set to the current UTC time as a naive ISO-8601
    string.

    NOTE: on conflict, ``model_type`` and ``notes`` are overwritten with the
    values passed here, so omitting them replaces stored values with NULL.
    """
    from datetime import timezone

    # Same naive-UTC string as the deprecated datetime.utcnow() produced, so
    # the stored timestamp format does not change.
    now_utc = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()

    conn.execute(
        """
        INSERT INTO models(model_id, model_type, analyzed_at, notes)
        VALUES(?, ?, ?, ?)
        ON CONFLICT(model_id) DO UPDATE SET
            model_type = excluded.model_type,
            analyzed_at = excluded.analyzed_at,
            notes = excluded.notes
        """,
        (model_id, model_type, now_utc, notes),
    )
    conn.commit()
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def upsert_component(
    conn: sqlite3.Connection,
    model_id: str,
    prefix: str,
    n_layers: int,
    head_dim_min: int,
    head_dim_max: int,
    has_kv_shared: bool,
    has_global: bool,
    d_model: int,
):
    """Insert or update the ``components`` row for ``(model_id, prefix)``.

    The two boolean flags are stored as integers. On conflict every
    non-key column is overwritten with the new value. The change is
    committed before returning.
    """
    statement = """
        INSERT INTO components(
            model_id, prefix, n_layers,
            head_dim_min, head_dim_max,
            has_kv_shared, has_global, d_model
        )
        VALUES(?, ?, ?, ?, ?, ?, ?, ?)
        ON CONFLICT(model_id, prefix) DO UPDATE SET
            n_layers = excluded.n_layers,
            head_dim_min = excluded.head_dim_min,
            head_dim_max = excluded.head_dim_max,
            has_kv_shared = excluded.has_kv_shared,
            has_global = excluded.has_global,
            d_model = excluded.d_model
        """
    values = (
        model_id,
        prefix,
        n_layers,
        head_dim_min,
        head_dim_max,
        int(has_kv_shared),
        int(has_global),
        d_model,
    )
    conn.execute(statement, values)
    conn.commit()
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
# ─────────────────────────────────────────────
|
| 142 |
+
# 写入逐头指标
|
| 143 |
+
# ─────────────────────────────────────────────
|
| 144 |
+
|
| 145 |
+
# Optional per-head columns of layer_head_metrics, in insertion order. Each
# one is read from the record with dict.get(), so a missing key becomes NULL.
_LAYER_HEAD_OPTIONAL_COLUMNS = (
    "head_dim", "d_model", "n_q_heads", "n_kv_heads",
    # Law 1: spectral linear alignment
    "pearson_QK", "spearman_QK", "pearson_QV", "pearson_KV",
    # Law 2: spectral shape residual
    "ssr_QK", "ssr_QV", "ssr_KV",
    # Law 3: condition numbers
    "sigma_max_Q", "sigma_min_Q", "cond_Q",
    "sigma_max_K", "sigma_min_K", "cond_K",
    "sigma_max_V", "sigma_min_V", "cond_V",
    # Law 4: left singular vector alignment
    "cosU_QK", "cosU_QV", "cosU_KV",
    # Law 5: right singular vector alignment
    "cosV_QK", "cosV_QV", "cosV_KV",
    # Scale factors and least-squares residuals
    "alpha_QK", "alpha_res_QK",
    "alpha_QV", "alpha_res_QV",
    "alpha_KV", "alpha_res_KV",
)


def _to_sql_scalar(value):
    """Convert a numpy scalar to the matching Python scalar; pass others through."""
    return value.item() if isinstance(value, np.generic) else value


def write_layer_records(
    conn: sqlite3.Connection,
    model_id: str,
    records: list[dict],
):
    """Write the per-head metric rows of one layer in a single batch.

    Uses ``INSERT OR REPLACE``, so writing the same
    ``(model_id, prefix, layer, kv_head, q_head)`` again replaces the stored
    row instead of failing. The batch is committed as one transaction and
    rolled back if any row fails.

    Args:
        conn: Open SQLite connection.
        model_id: Model the records belong to.
        records: One dict per head. ``prefix``, ``layer``, ``kv_head`` and
            ``q_head`` are required; ``kv_shared`` and every metric key are
            optional. Numpy scalar values are converted to Python scalars
            because ``sqlite3`` cannot bind most numpy types.
    """
    if not records:
        return

    columns = (
        "model_id", "prefix", "layer", "layer_type",
        "kv_head", "q_head", "kv_shared",
    ) + _LAYER_HEAD_OPTIONAL_COLUMNS

    rows = []
    for r in records:
        # Evaluate the flag once so an explicit None is treated as False for
        # both the layer type and the stored integer.
        shared = bool(r.get("kv_shared", False))
        rows.append((
            model_id,
            r["prefix"],
            _to_sql_scalar(r["layer"]),
            infer_layer_type(shared),
            _to_sql_scalar(r["kv_head"]),
            _to_sql_scalar(r["q_head"]),
            int(shared),
            *(_to_sql_scalar(r.get(key)) for key in _LAYER_HEAD_OPTIONAL_COLUMNS),
        ))

    # Column names come from the constants above, so building the statement
    # from them is safe and keeps the placeholder count in sync.
    placeholders = ",".join("?" * len(columns))
    sql = (
        f"INSERT OR REPLACE INTO layer_head_metrics({', '.join(columns)}) "
        f"VALUES ({placeholders})"
    )

    # "with conn" commits on success and rolls back on an exception.
    with conn:
        conn.executemany(sql, rows)
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
# ─────────────────────────────────────────────
|
| 238 |
+
# 计算并写入 model_summary
|
| 239 |
+
# ─────────────────────────────────────────────
|
| 240 |
+
|
| 241 |
+
def _calc_summary_row(
|
| 242 |
+
rows: list[sqlite3.Row],
|
| 243 |
+
model_id: str,
|
| 244 |
+
prefix: str,
|
| 245 |
+
layer_type: str,
|
| 246 |
+
) -> dict | None:
|
| 247 |
+
"""
|
| 248 |
+
从一组 layer_head_metrics 行计算汇总统计
|
| 249 |
+
返回 model_summary 的一行
|
| 250 |
+
"""
|
| 251 |
+
if not rows:
|
| 252 |
+
return None
|
| 253 |
+
|
| 254 |
+
def col(name):
|
| 255 |
+
vals = [r[name] for r in rows if r[name] is not None]
|
| 256 |
+
return np.array(vals) if vals else np.array([])
|
| 257 |
+
|
| 258 |
+
def med(arr):
|
| 259 |
+
return float(np.median(arr)) if len(arr) > 0 else None
|
| 260 |
+
|
| 261 |
+
def avg(arr):
|
| 262 |
+
return float(np.mean(arr)) if len(arr) > 0 else None
|
| 263 |
+
|
| 264 |
+
ssr_qk = col("ssr_QK")
|
| 265 |
+
wang_score = float(1 - np.median(ssr_qk)) if len(ssr_qk) > 0 else None
|
| 266 |
+
|
| 267 |
+
# 统计层数(去重)
|
| 268 |
+
n_layers = len(set(r["layer"] for r in rows))
|
| 269 |
+
n_records = len(rows)
|
| 270 |
+
|
| 271 |
+
return {
|
| 272 |
+
"model_id": model_id,
|
| 273 |
+
"prefix": prefix,
|
| 274 |
+
"layer_type": layer_type,
|
| 275 |
+
"median_pearson_QK": med(col("pearson_QK")),
|
| 276 |
+
"mean_pearson_QK": avg(col("pearson_QK")),
|
| 277 |
+
"median_ssr_QK": med(ssr_qk),
|
| 278 |
+
"mean_ssr_QK": avg(ssr_qk),
|
| 279 |
+
"median_ssr_QV": med(col("ssr_QV")),
|
| 280 |
+
"mean_ssr_QV": avg(col("ssr_QV")),
|
| 281 |
+
"median_cond_Q": med(col("cond_Q")),
|
| 282 |
+
"mean_cond_Q": avg(col("cond_Q")),
|
| 283 |
+
"median_cosU_QK": med(col("cosU_QK")),
|
| 284 |
+
"median_cosU_QV": med(col("cosU_QV")),
|
| 285 |
+
"median_cosV_QK": med(col("cosV_QK")),
|
| 286 |
+
"median_cosV_QV": med(col("cosV_QV")),
|
| 287 |
+
"wang_score": wang_score,
|
| 288 |
+
"n_layers": n_layers,
|
| 289 |
+
"n_records": n_records,
|
| 290 |
+
"updated_at": datetime.utcnow().isoformat(),
|
| 291 |
+
}
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
def update_model_summary(
    conn: sqlite3.Connection,
    model_id: str,
    prefix: str,
):
    """Recompute and store the ``model_summary`` rows of one component.

    Up to three rows are written for ``(model_id, prefix)``, one per layer
    type that has data: ``'all'``, ``'standard'`` and ``'global'``.

    The Wang score of every row is taken from the standard layers
    (``1 - median(ssr_QK)``). When the component has no standard-layer
    ``ssr_QK`` values the score is stored as NULL, so it never silently
    reflects non-standard layers.

    All rows are written in one transaction. Requires a connection whose
    rows support access by column name (``sqlite3.Row``).
    """
    cur = conn.cursor()

    # Read the component once and partition in memory rather than scanning
    # the table again for every layer type.
    cur.execute(
        """
        SELECT * FROM layer_head_metrics
        WHERE model_id = ? AND prefix = ?
        """,
        (model_id, prefix),
    )
    all_rows = cur.fetchall()
    rows_by_type = {
        "all": all_rows,
        "standard": [r for r in all_rows if r["layer_type"] == "standard"],
        "global": [r for r in all_rows if r["layer_type"] == "global"],
    }

    std_ssr = [
        r["ssr_QK"] for r in rows_by_type["standard"] if r["ssr_QK"] is not None
    ]
    standard_score = float(1 - np.median(std_ssr)) if std_ssr else None

    summaries = []
    for layer_type, rows in rows_by_type.items():
        summary = _calc_summary_row(rows, model_id, prefix, layer_type)
        if summary is None:
            continue
        summary["wang_score"] = standard_score
        summaries.append(summary)

    # "with conn" commits on success and rolls back on an exception.
    with conn:
        conn.executemany(
            """
            INSERT OR REPLACE INTO model_summary(
                model_id, prefix, layer_type,
                median_pearson_QK, mean_pearson_QK,
                median_ssr_QK, mean_ssr_QK,
                median_ssr_QV, mean_ssr_QV,
                median_cond_Q, mean_cond_Q,
                median_cosU_QK, median_cosU_QV,
                median_cosV_QK, median_cosV_QV,
                wang_score,
                n_layers, n_records, updated_at
            ) VALUES (
                :model_id, :prefix, :layer_type,
                :median_pearson_QK, :mean_pearson_QK,
                :median_ssr_QK, :mean_ssr_QK,
                :median_ssr_QV, :mean_ssr_QV,
                :median_cond_Q, :mean_cond_Q,
                :median_cosU_QK, :median_cosU_QV,
                :median_cosV_QK, :median_cosV_QV,
                :wang_score,
                :n_layers, :n_records, :updated_at
            )
            """,
            summaries,
        )
|