Spaces:
Running
Running
| # db/writer.py | |
| """ | |
| ๆฐๆฎๅบๅๅ ฅๆจกๅ | |
| - ๅๅ ฅๅๆ็ปๆๅฐ layer_head_metrics | |
| - ่ฎก็ฎๅนถๅๅ ฅ model_summary๏ผpseudo-bulk ไธคๆญฅ่ๅ๏ผ้ฟๅ GQA ไผช้ๅค๏ผ | |
| - ๆฏๆๆญ็น็ปญไผ ๏ผไปฅ prefix+layer ไธบ็ฒๅบฆ๏ผ | |
| - ๅๅ ฅๆ้้ช่ฏ | |
| - ็บง่ๅ ้คๆจกๅ | |
| """ | |
| import os | |
| import sqlite3 | |
| import numpy as np | |
| from collections import defaultdict | |
| from datetime import datetime | |
| from db.schema import get_connection, init_db | |
| # โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ | |
| # ๆจๆญๅฝๆฐ๏ผlayer_type ๅ modality | |
| # โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ | |
| def infer_layer_type(kv_shared: bool) -> str: | |
| return "global" if kv_shared else "standard" | |
| def infer_modality(prefix: str) -> str: | |
| p = prefix.lower() | |
| if "vision" in p or "visual" in p or "image" in p: | |
| return "vision" | |
| if "audio" in p or "speech" in p or "acoustic" in p: | |
| return "audio" | |
| return "language" | |
| # โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ | |
| # ๅๅ ฅๆ้้ช่ฏ | |
| # โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ | |
| def check_write_permission(admin_token: str) -> bool: | |
| server_token = os.environ.get("WRITE_TOKEN", "") | |
| if not server_token: | |
| return False | |
| return admin_token.strip() == server_token | |
| # โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ | |
| # ๆญ็น็ปญไผ | |
| # โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ | |
| def get_analyzed_layers( | |
| conn: sqlite3.Connection, | |
| model_id: str, | |
| prefix: str, | |
| ) -> set: | |
| cur = conn.cursor() | |
| cur.execute( | |
| """SELECT DISTINCT layer FROM layer_head_metrics | |
| WHERE model_id = ? AND prefix = ?""", | |
| (model_id, prefix) | |
| ) | |
| return {row[0] for row in cur.fetchall()} | |
| def is_layer_complete( | |
| conn: sqlite3.Connection, | |
| model_id: str, | |
| prefix: str, | |
| layer: int, | |
| expected_records: int, | |
| ) -> bool: | |
| cur = conn.cursor() | |
| cur.execute( | |
| """SELECT COUNT(*) FROM layer_head_metrics | |
| WHERE model_id = ? AND prefix = ? AND layer = ?""", | |
| (model_id, prefix, layer) | |
| ) | |
| return cur.fetchone()[0] >= expected_records | |
| # โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ | |
| # ๅๅ ฅๆจกๅๅ ๆฐๆฎ | |
| # โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ | |
| def upsert_model( | |
| conn: sqlite3.Connection, | |
| model_id: str, | |
| model_type: str = None, | |
| notes: str = None, | |
| ): | |
| conn.execute( | |
| """INSERT INTO models(model_id, model_type, analyzed_at, notes) | |
| VALUES(?, ?, ?, ?) | |
| ON CONFLICT(model_id) DO UPDATE SET | |
| model_type = excluded.model_type, | |
| analyzed_at = excluded.analyzed_at, | |
| notes = excluded.notes""", | |
| (model_id, model_type, datetime.utcnow().isoformat(), notes) | |
| ) | |
| conn.commit() | |
| def upsert_component( | |
| conn: sqlite3.Connection, | |
| model_id: str, | |
| prefix: str, | |
| n_layers: int, | |
| head_dim_min: int, | |
| head_dim_max: int, | |
| has_kv_shared: bool, | |
| has_global: bool, | |
| d_model: int, | |
| ): | |
| modality = infer_modality(prefix) | |
| conn.execute( | |
| """INSERT INTO components( | |
| model_id, prefix, modality, n_layers, | |
| head_dim_min, head_dim_max, | |
| has_kv_shared, has_global, d_model | |
| ) | |
| VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?) | |
| ON CONFLICT(model_id, prefix) DO UPDATE SET | |
| modality = excluded.modality, | |
| n_layers = excluded.n_layers, | |
| head_dim_min = excluded.head_dim_min, | |
| head_dim_max = excluded.head_dim_max, | |
| has_kv_shared = excluded.has_kv_shared, | |
| has_global = excluded.has_global, | |
| d_model = excluded.d_model""", | |
| ( | |
| model_id, prefix, modality, n_layers, | |
| head_dim_min, head_dim_max, | |
| int(has_kv_shared), int(has_global), d_model | |
| ) | |
| ) | |
| conn.commit() | |
| # โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ | |
| # ๅๅ ฅ้ๅคดๆๆ | |
| # โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ | |
| def write_layer_records( | |
| conn: sqlite3.Connection, | |
| model_id: str, | |
| records: list[dict], | |
| ): | |
| """ๆน้ๅๅ ฅไธๅฑ็้ๅคดๆๆ ๏ผๅน็ญ""" | |
| if not records: | |
| return | |
| rows = [] | |
| for r in records: | |
| layer_type = infer_layer_type(bool(r.get("kv_shared", False))) | |
| modality = infer_modality(r["prefix"]) | |
| rows.append(( | |
| model_id, | |
| r["prefix"], | |
| r["layer"], | |
| layer_type, | |
| modality, | |
| r["kv_head"], | |
| r["q_head"], | |
| int(r.get("kv_shared", False)), | |
| r.get("head_dim"), | |
| r.get("d_model"), | |
| r.get("n_q_heads"), | |
| r.get("n_kv_heads"), | |
| r.get("pearson_QK"), r.get("spearman_QK"), | |
| r.get("pearson_QV"), r.get("pearson_KV"), | |
| r.get("ssr_QK"), r.get("ssr_QV"), r.get("ssr_KV"), | |
| r.get("sigma_max_Q"), r.get("sigma_min_Q"), r.get("cond_Q"), | |
| r.get("sigma_max_K"), r.get("sigma_min_K"), r.get("cond_K"), | |
| r.get("sigma_max_V"), r.get("sigma_min_V"), r.get("cond_V"), | |
| r.get("cosU_QK"), r.get("cosU_QV"), r.get("cosU_KV"), | |
| r.get("cosV_QK"), r.get("cosV_QV"), r.get("cosV_KV"), | |
| r.get("alpha_QK"), r.get("alpha_res_QK"), | |
| r.get("alpha_QV"), r.get("alpha_res_QV"), | |
| r.get("alpha_KV"), r.get("alpha_res_KV"), | |
| )) | |
| conn.executemany( | |
| """INSERT OR REPLACE INTO layer_head_metrics( | |
| model_id, prefix, layer, layer_type, modality, | |
| kv_head, q_head, kv_shared, | |
| head_dim, d_model, n_q_heads, n_kv_heads, | |
| pearson_QK, spearman_QK, pearson_QV, pearson_KV, | |
| ssr_QK, ssr_QV, ssr_KV, | |
| sigma_max_Q, sigma_min_Q, cond_Q, | |
| sigma_max_K, sigma_min_K, cond_K, | |
| sigma_max_V, sigma_min_V, cond_V, | |
| cosU_QK, cosU_QV, cosU_KV, | |
| cosV_QK, cosV_QV, cosV_KV, | |
| alpha_QK, alpha_res_QK, | |
| alpha_QV, alpha_res_QV, | |
| alpha_KV, alpha_res_KV | |
| ) VALUES ( | |
| ?,?,?,?,?,?,?,?,?,?,?,?, | |
| ?,?,?,?,?,?,?, | |
| ?,?,?,?,?,?,?,?,?, | |
| ?,?,?,?,?,?, | |
| ?,?,?,?,?,? | |
| )""", | |
| rows | |
| ) | |
| conn.commit() | |
| # โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ | |
| # Pseudo-bulk ่ๅๆ ธๅฟๅฝๆฐ | |
| # โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ | |
| def _pseudobulk(rows, col_name: str) -> np.ndarray: | |
| """ | |
| Pseudo-bulk two-step aggregation (Nature Comms 2021). | |
| Avoids GQA pseudoreplication (e.g. 4Q:1K โ 4 correlated records per KV head). | |
| Step 1: median within each (layer, kv_head) group | |
| โ one value per KV-head per layer | |
| Step 2: return flat array of Step-1 values | |
| โ caller computes final median / mean / quantile | |
| Works with both sqlite3.Row objects and plain dicts. | |
| """ | |
| groups: dict[tuple, list] = defaultdict(list) | |
| for r in rows: | |
| try: | |
| v = r["ssr_QK"] if col_name == "ssr_QK" else r[col_name] | |
| layer = int(r["layer"]) | |
| kv_head = int(r["kv_head"]) if r["kv_head"] is not None else 0 | |
| except (KeyError, TypeError, IndexError): | |
| continue | |
| if v is None: | |
| continue | |
| groups[(layer, kv_head)].append(float(v)) | |
| if not groups: | |
| return np.array([]) | |
| # Step 1: median within each (layer, kv_head) group | |
| return np.array([float(np.median(vals)) for vals in groups.values()]) | |
| def _pseudobulk_col(rows, col_name: str) -> np.ndarray: | |
| """Generic version of _pseudobulk for any column name.""" | |
| groups: dict[tuple, list] = defaultdict(list) | |
| for r in rows: | |
| try: | |
| v = r[col_name] | |
| layer = int(r["layer"]) | |
| kv_head = int(r["kv_head"]) if r["kv_head"] is not None else 0 | |
| except (KeyError, TypeError, IndexError): | |
| continue | |
| if v is None: | |
| continue | |
| groups[(layer, kv_head)].append(float(v)) | |
| if not groups: | |
| return np.array([]) | |
| return np.array([float(np.median(vals)) for vals in groups.values()]) | |
| # โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ | |
| # ่ฎก็ฎๅนถๅๅ ฅ model_summary | |
| # โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ | |
| def _calc_summary_row( | |
| rows, | |
| model_id: str, | |
| prefix: str, | |
| layer_type: str, | |
| ) -> dict | None: | |
| if not rows: | |
| return None | |
| def pb(col): | |
| return _pseudobulk_col(rows, col) | |
| def med(arr): return float(np.median(arr)) if len(arr) > 0 else None | |
| def avg(arr): return float(np.mean(arr)) if len(arr) > 0 else None | |
| ssr_qk = pb("ssr_QK") | |
| wang_score = float(1 - np.median(ssr_qk)) if len(ssr_qk) > 0 else None | |
| n_layers = len(set(r["layer"] for r in rows)) | |
| n_records = len(rows) | |
| return { | |
| "model_id": model_id, | |
| "prefix": prefix, | |
| "layer_type": layer_type, | |
| "median_pearson_QK": med(pb("pearson_QK")), | |
| "mean_pearson_QK": avg(pb("pearson_QK")), | |
| "median_ssr_QK": med(ssr_qk), | |
| "mean_ssr_QK": avg(ssr_qk), | |
| "median_ssr_QV": med(pb("ssr_QV")), | |
| "mean_ssr_QV": avg(pb("ssr_QV")), | |
| "median_cond_Q": med(pb("cond_Q")), | |
| "mean_cond_Q": avg(pb("cond_Q")), | |
| "median_cosU_QK": med(pb("cosU_QK")), | |
| "median_cosU_QV": med(pb("cosU_QV")), | |
| "median_cosV_QK": med(pb("cosV_QK")), | |
| "median_cosV_QV": med(pb("cosV_QV")), | |
| "wang_score": wang_score, | |
| "n_layers": n_layers, | |
| "n_records": n_records, | |
| "updated_at": datetime.utcnow().isoformat(), | |
| } | |
| def update_model_summary( | |
| conn: sqlite3.Connection, | |
| model_id: str, | |
| prefix: str, | |
| ): | |
| """ | |
| ้ๆฐ่ฎก็ฎๅนถๅๅ ฅ model_summary๏ผall / standard / global ไธ่ก๏ผใ | |
| wang_score ็ปไธ็จ standard ๅฑ pseudo-bulk median(SSR_QK) ่ฎก็ฎใ | |
| """ | |
| cur = conn.cursor() | |
| cur.row_factory = sqlite3.Row | |
| # โโ Wang Score: standard ๅฑ pseudo-bulk โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ | |
| cur.execute( | |
| """SELECT layer, kv_head, ssr_QK FROM layer_head_metrics | |
| WHERE model_id = ? AND prefix = ? AND layer_type = 'standard' | |
| AND kv_shared = 0""", | |
| (model_id, prefix) | |
| ) | |
| std_ssr = _pseudobulk_col(cur.fetchall(), "ssr_QK") | |
| std_wang_score = float(1 - np.median(std_ssr)) if len(std_ssr) > 0 else None | |
| for layer_type in ["all", "standard", "global"]: | |
| if layer_type == "all": | |
| cur.execute( | |
| "SELECT * FROM layer_head_metrics WHERE model_id=? AND prefix=?", | |
| (model_id, prefix) | |
| ) | |
| else: | |
| cur.execute( | |
| """SELECT * FROM layer_head_metrics | |
| WHERE model_id=? AND prefix=? AND layer_type=?""", | |
| (model_id, prefix, layer_type) | |
| ) | |
| rows = cur.fetchall() | |
| summary = _calc_summary_row(rows, model_id, prefix, layer_type) | |
| if summary is None: | |
| continue | |
| summary["wang_score"] = std_wang_score # always from standard pseudo-bulk | |
| conn.execute( | |
| """INSERT OR REPLACE INTO model_summary( | |
| model_id, prefix, layer_type, | |
| median_pearson_QK, mean_pearson_QK, | |
| median_ssr_QK, mean_ssr_QK, | |
| median_ssr_QV, mean_ssr_QV, | |
| median_cond_Q, mean_cond_Q, | |
| median_cosU_QK, median_cosU_QV, | |
| median_cosV_QK, median_cosV_QV, | |
| wang_score, n_layers, n_records, updated_at | |
| ) VALUES ( | |
| :model_id, :prefix, :layer_type, | |
| :median_pearson_QK, :mean_pearson_QK, | |
| :median_ssr_QK, :mean_ssr_QK, | |
| :median_ssr_QV, :mean_ssr_QV, | |
| :median_cond_Q, :mean_cond_Q, | |
| :median_cosU_QK, :median_cosU_QV, | |
| :median_cosV_QK, :median_cosV_QV, | |
| :wang_score, :n_layers, :n_records, :updated_at | |
| )""", | |
| summary | |
| ) | |
| conn.commit() | |
| # โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ | |
| # ๆน้ๅทๆฐๆๆๆจกๅ็ model_summary | |
| # โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ | |
| def refresh_all_summaries(conn: sqlite3.Connection) -> int: | |
| """ | |
| Re-compute model_summary for every (model_id, prefix) in the DB. | |
| Called by Tab 3 Refresh button to migrate historical data to pseudo-bulk. | |
| Returns number of (model_id, prefix) pairs refreshed. | |
| """ | |
| cur = conn.cursor() | |
| cur.execute( | |
| "SELECT DISTINCT model_id, prefix FROM layer_head_metrics" | |
| ) | |
| pairs = cur.fetchall() | |
| for model_id, prefix in pairs: | |
| update_model_summary(conn, model_id, prefix) | |
| return len(pairs) | |
| # โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ | |
| # ๅ ้คๆจกๅ๏ผ็บง่ๆธ ้คๆๆ็ธๅ ณๆฐๆฎ๏ผ | |
| # โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ | |
| def delete_model( | |
| conn: sqlite3.Connection, | |
| model_id: str, | |
| admin_token: str, | |
| ) -> tuple[bool, str]: | |
| """ | |
| ็บง่ๅ ้คไธไธชๆจกๅ็ๆๆๆฐๆฎใ | |
| ๅ ้ค้กบๅบ๏ผlayer_head_metrics โ model_summary โ components โ models | |
| ้่ฆ WRITE_TOKEN ้ช่ฏใ | |
| ่ฟๅ (success, message) | |
| """ | |
| if not check_write_permission(admin_token): | |
| return False, "โ Permission denied: invalid or missing Admin Write Token." | |
| model_id = model_id.strip() | |
| if not model_id: | |
| return False, "โ Model ID is empty." | |
| cur = conn.cursor() | |
| # ๅ ็กฎ่ฎคๆจกๅๆฏๅฆๅญๅจ | |
| cur.execute("SELECT model_id FROM models WHERE model_id = ?", (model_id,)) | |
| if cur.fetchone() is None: | |
| return False, f"โ Model not found in DB: '{model_id}'" | |
| # ็ป่ฎกๅ่กจๅฐ่ขซๅ ้ค็่กๆฐ๏ผ็จไบๆฅๅฟ๏ผ | |
| stats = {} | |
| for table in ["layer_head_metrics", "model_summary", "components"]: | |
| cur.execute(f"SELECT COUNT(*) FROM {table} WHERE model_id = ?", (model_id,)) | |
| stats[table] = cur.fetchone()[0] | |
| # ็บง่ๅ ้ค๏ผๅญ่กจๅ ๅ ๏ผๆๅๅ models๏ผ | |
| conn.execute("DELETE FROM layer_head_metrics WHERE model_id = ?", (model_id,)) | |
| conn.execute("DELETE FROM model_summary WHERE model_id = ?", (model_id,)) | |
| conn.execute("DELETE FROM components WHERE model_id = ?", (model_id,)) | |
| conn.execute("DELETE FROM models WHERE model_id = ?", (model_id,)) | |
| conn.commit() | |
| msg = ( | |
| f"โ Deleted: '{model_id}'\n" | |
| f" layer_head_metrics : {stats['layer_head_metrics']} rows\n" | |
| f" model_summary : {stats['model_summary']} rows\n" | |
| f" components : {stats['components']} rows\n" | |
| f" models : 1 row" | |
| ) | |
| return True, msg |