File size: 18,074 Bytes
6f1ac4c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ba623bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6f1ac4c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ba623bd
 
 
 
6f1ac4c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ba623bd
6f1ac4c
 
 
 
 
ba623bd
6f1ac4c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ba623bd
 
 
6f1ac4c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ba623bd
 
 
6f1ac4c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ba623bd
 
 
 
 
 
6f1ac4c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ba623bd
6f1ac4c
ba623bd
6f1ac4c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
# core/table_gen.py
"""
Table generation for Wang's Five Laws — paper-ready output.
Pure computation layer: takes DataFrames from db/reader, returns DataFrames + formatted strings.
No UI, no DB, no side effects.

Tables:
  Table 1 — Cross-model summary (Law 1 & 2): Pearson r, SSR, Wang Score
  Table 2 — SSR layer-group trend (Law 2, RL effect): user-defined groups
  Table 3 — Output subspace cosU (Law 4): QK / QV / KV + random baseline
  Table 4 — Input subspace cosV (Law 5): QK / QV / KV + random baseline
  Table 5 — Condition number κ summary (Law 3): cond_Q, cond_K
  Table 6 — Wang Score leaderboard
"""

import numpy as np
import pandas as pd
from typing import Optional


# ─────────────────────────────────────────────────────────────────────────────
# Helpers
# ─────────────────────────────────────────────────────────────────────────────

def _med(series) -> Optional[float]:
    """Median of *series* with NaN removed, or ``None`` when nothing remains."""
    valid = series.dropna()
    if valid.empty:
        return None
    return float(valid.median())


def _mean(series) -> Optional[float]:
    """Arithmetic mean of *series* with NaN removed, or ``None`` when nothing remains."""
    valid = series.dropna()
    return None if valid.empty else float(valid.mean())


def _pseudobulk(df: pd.DataFrame, col: str) -> np.ndarray:
    """
    Pseudo-bulk two-step aggregation (Nature Comms 2021).

    Step 1: median of ``col`` across Q heads within each (layer, kv_head) group.
    Step 2: median of those group medians per layer.

    Returns a 1-D float array of per-layer medians in ascending layer order.
    Layers with no non-NaN values are omitted. An empty array is returned when
    ``df`` is empty or has no ``col`` column.

    Rows whose ``kv_head`` is missing (NaN/None) are pooled into one group per
    layer instead of being discarded, so an MHA table stored without kv_head
    values still reduces to a plain per-layer median. Without a ``kv_head``
    column the result is likewise the plain per-layer median.
    """
    if df.empty or col not in df.columns:
        return np.array([])
    # Coerce to numeric so None / object-dtype columns read from the DB become
    # NaN instead of breaking the group medians.
    values = pd.to_numeric(df[col], errors="coerce")
    if "kv_head" in df.columns:
        # dropna=False: pandas drops NaN group keys by default, which would
        # silently lose every head whose kv_head is not recorded.
        group_medians = values.groupby(
            [df["layer"], df["kv_head"]], dropna=False
        ).median()
        per_layer = group_medians.groupby(level=0).median()
    else:
        per_layer = values.groupby(df["layer"]).median()
    # groupby medians skip NaN; a layer with no usable values comes out as NaN.
    per_layer = per_layer.dropna().sort_index()
    return per_layer.to_numpy(dtype=float)


def _pb_med(df: pd.DataFrame, col: str) -> Optional[float]:
    """Median of the per-layer pseudo-bulk values of *col*; ``None`` if there are none."""
    per_layer = _pseudobulk(df, col)
    if per_layer.size == 0:
        return None
    return float(np.median(per_layer))


def _pb_mean(df: pd.DataFrame, col: str) -> Optional[float]:
    """Mean of the per-layer pseudo-bulk values of *col*; ``None`` if there are none."""
    per_layer = _pseudobulk(df, col)
    if per_layer.size == 0:
        return None
    return float(np.mean(per_layer))


def _fmt(x, decimals=6) -> str:
    """
    Format a number with ``decimals`` fixed decimal places for table output.

    ``None`` and NaN (Python float or any numpy floating type, e.g. float32)
    render as an em dash placeholder.
    """
    # np.floating covers float32/float16 NaN, which are not ``float`` subclasses
    # and would otherwise be printed as "nan".
    if x is None or (isinstance(x, (float, np.floating)) and np.isnan(x)):
        return "—"
    return f"{x:.{decimals}f}"


def _short(model_id: str) -> str:
    """Return the last ``/``-separated component of a model id (``org/name`` -> ``name``)."""
    return model_id.rsplit("/", 1)[-1]


def _standard_only(df: pd.DataFrame) -> pd.DataFrame:
    """
    Keep only standard layers (exclude global/KV-shared layers).

    Uses the ``kv_shared`` flag when present (0 = standard), otherwise
    ``layer_type == "standard"``; with neither column, *df* is returned as is.
    """
    columns = df.columns
    if "kv_shared" in columns:
        keep = df["kv_shared"] == 0
    elif "layer_type" in columns:
        keep = df["layer_type"] == "standard"
    else:
        return df
    return df[keep]


def _random_baseline_U(df: pd.DataFrame) -> float:
    """Random-baseline level for cosU: 1/sqrt(median head_dim), or NaN if head_dim is unavailable."""
    if "head_dim" not in df.columns:
        return float("nan")
    dims = df["head_dim"].dropna()
    if dims.empty:
        return float("nan")
    return 1.0 / np.sqrt(float(dims.median()))


def _random_baseline_V(df: pd.DataFrame) -> float:
    """Random-baseline level for cosV: 1/sqrt(median d_model), or NaN if d_model is unavailable."""
    if "d_model" not in df.columns:
        return float("nan")
    widths = df["d_model"].dropna()
    if widths.empty:
        return float("nan")
    return 1.0 / np.sqrt(float(widths.median()))


def _n_global(df: pd.DataFrame) -> int:
    """Number of distinct layers flagged ``kv_shared == 1``; 0 when the flag column is absent."""
    if "kv_shared" not in df.columns:
        return 0
    shared_layers = df.loc[df["kv_shared"] == 1, "layer"]
    return int(shared_layers.nunique())


# ─────────────────────────────────────────────────────────────────────────────
# LaTeX / Markdown helpers
# ─────────────────────────────────────────────────────────────────────────────

# Single-pass translation table for LaTeX text mode. Using str.translate means
# a replacement (e.g. "\\&") is never itself re-escaped.
_LATEX_SPECIALS = str.maketrans({
    "\\": r"\textbackslash{}",
    "&": r"\&",
    "%": r"\%",
    "$": r"\$",
    "#": r"\#",
    "_": r"\_",
    "{": r"\{",
    "}": r"\}",
    "~": r"\textasciitilde{}",
    "^": r"\textasciicircum{}",
    "κ": r"$\kappa$",
    "√": r"$\surd$",
})


def _latex_escape(text) -> str:
    """Return ``str(text)`` with LaTeX special characters escaped for text mode."""
    return str(text).translate(_LATEX_SPECIALS)


def df_to_latex(df: pd.DataFrame, caption: str, label: str, escape: bool = True) -> str:
    """
    Convert a DataFrame to a complete LaTeX ``table`` environment.

    Args:
        df: Table contents; the first column is left-aligned, the rest right-aligned.
        caption: Caption text.
        label: ``\\label`` key, inserted verbatim (never escaped).
        escape: When True (default), escape LaTeX specials (``& % $ # _ { } ~ ^ \\``)
            in the caption, headers and cells, and map κ / √ to math symbols.
            Without this, "%" in a header or cell comments out the rest of the
            row and "&" / "_" produce compile errors. Pass False for content
            that is already valid LaTeX.

    Returns:
        The LaTeX source. It uses booktabs rules (``\\toprule`` etc.), so the
        document must load the ``booktabs`` package.
    """
    esc = _latex_escape if escape else str
    headers = [esc(c) for c in df.columns]
    col_fmt = "l" + "r" * (len(headers) - 1)

    lines = [
        r"\begin{table}[htbp]",
        r"  \centering",
        f"  \\caption{{{esc(caption)}}}",
        f"  \\label{{{label}}}",
        f"  \\begin{{tabular}}{{{col_fmt}}}",
        r"    \toprule",
        "    " + " & ".join(headers) + r" \\",
        r"    \midrule",
    ]
    # itertuples keeps per-column dtypes; iterrows upcasts an all-numeric row
    # to a single dtype (ints would print as "12.0").
    for row in df.itertuples(index=False, name=None):
        lines.append("    " + " & ".join(esc(v) for v in row) + r" \\")
    lines += [
        r"    \bottomrule",
        r"  \end{tabular}",
        r"\end{table}",
    ]
    return "\n".join(lines)


def df_to_markdown(df: pd.DataFrame, caption: str) -> str:
    """
    Convert a DataFrame to a GitHub-flavored Markdown table under a bold caption.

    Literal ``|`` in headers or cells is escaped as ``\\|`` and embedded line
    breaks become spaces. Without this, such a value would split a cell or a
    row and break the table layout.
    """
    def cell(value) -> str:
        # Join lines first, then escape pipes so the cell stays one Markdown cell.
        return " ".join(str(value).splitlines()).replace("|", r"\|")

    header = "| " + " | ".join(cell(c) for c in df.columns) + " |"
    sep    = "| " + " | ".join("---" for _ in df.columns) + " |"
    rows   = [
        "| " + " | ".join(cell(v) for v in row) + " |"
        for row in df.itertuples(index=False, name=None)
    ]
    return "\n".join([f"**{caption}**", "", header, sep] + rows)


# ─────────────────────────────────────────────────────────────────────────────
# Table 1 — Cross-model summary (Law 1 & 2)
# ─────────────────────────────────────────────────────────────────────────────

def make_table1(
    model_dfs: dict[str, pd.DataFrame],  # {model_id: full_df from DB}
) -> pd.DataFrame:
    """
    Cross-model summary (Law 1 & 2), one row per model; standard layers only.

    Columns: Model | Std Layers | Global Layers | Median Pearson | Mean Pearson |
    Median SSR | Mean SSR | Wang Score

    Pearson / SSR are pseudo-bulk statistics of ``pearson_QK`` / ``ssr_QK``.
    Wang Score is ``1 - median SSR`` (the same definition Table 6 ranks by).
    Global Layers counts ``kv_shared == 1`` layers in the full frame and shows
    an em dash when there are none. Models with no standard layers are skipped.
    """
    rows = []
    for model_id, df in model_dfs.items():
        std = _standard_only(df)
        if std.empty:
            continue
        n_global = _n_global(df)
        med_ssr = _pb_med(std, "ssr_QK")
        wang_score = 1 - med_ssr if med_ssr is not None else None
        rows.append({
            "Model":          _short(model_id),
            "Std Layers":     std["layer"].nunique(),
            "Global Layers":  n_global if n_global > 0 else "—",
            "Median Pearson": _fmt(_pb_med(std, "pearson_QK"), 4),
            "Mean Pearson":   _fmt(_pb_mean(std, "pearson_QK"), 4),
            "Median SSR":     _fmt(med_ssr, 6),
            "Mean SSR":       _fmt(_pb_mean(std, "ssr_QK"), 6),
            "Wang Score":     _fmt(wang_score, 6),
        })
    return pd.DataFrame(rows)


# ─────────────────────────────────────────────────────────────────────────────
# Table 2 — SSR layer-group trend (Law 2, RL effect)
# ─────────────────────────────────────────────────────────────────────────────

def make_table2(
    df_a: pd.DataFrame,
    name_a: str,
    df_b: Optional[pd.DataFrame],
    name_b: Optional[str],
    group_bounds: list[tuple[int, int]],  # e.g. [(0,11),(12,23),(24,35),(36,47)]
) -> pd.DataFrame:
    """
    SSR layer-group trend (Law 2), one row per inclusive ``(lo, hi)`` range.

    Single model: Layer Group | <A> SSR.
    Two models (``df_b`` given and ``name_b`` non-empty): adds <B> SSR and
    Improvement (%) = (SSR_A - SSR_B) / SSR_A * 100, always shown with a sign.
    Improvement is an em dash when either SSR is missing or SSR_A is not
    positive. If A and B share the same short name, the two SSR columns are
    suffixed " (A)" / " (B)" so B does not overwrite A.

    SSR values are pseudo-bulk medians of ``ssr_QK`` over standard layers only.
    """
    std_a = _standard_only(df_a)
    std_b = _standard_only(df_b) if df_b is not None and name_b else None

    col_a = f"{_short(name_a)} SSR"
    col_b = f"{_short(name_b)} SSR" if std_b is not None else None
    if col_b == col_a:
        # e.g. "org1/model" vs "org2/model": identical keys would collapse
        # both SSR values into one column.
        col_a, col_b = f"{col_a} (A)", f"{col_b} (B)"

    rows = []
    for lo, hi in group_bounds:
        ssr_a = _pb_med(std_a[std_a["layer"].between(lo, hi)], "ssr_QK")
        row = {"Layer Group": f"{lo}–{hi}", col_a: _fmt(ssr_a, 6)}

        if std_b is not None:
            ssr_b = _pb_med(std_b[std_b["layer"].between(lo, hi)], "ssr_QK")
            row[col_b] = _fmt(ssr_b, 6)
            # Compare to None explicitly: an SSR of exactly 0.0 is a valid
            # value and must still yield an improvement figure.
            if ssr_a is not None and ssr_b is not None and ssr_a > 0:
                improvement = (ssr_a - ssr_b) / ssr_a * 100
                row["Improvement (%)"] = f"{improvement:+.2f}%"
            else:
                row["Improvement (%)"] = "—"

        rows.append(row)
    return pd.DataFrame(rows)


# ─────────────────────────────────────────────────────────────────────────────
# Table 3 — Output subspace cosU (Law 4)
# ─────────────────────────────────────────────────────────────────────────────

def make_table3(
    model_dfs: dict[str, pd.DataFrame],
) -> pd.DataFrame:
    """
    Output-subspace alignment cosU (Law 4), one row per model; standard layers only.

    Columns: Model | d_h | Random 1/√d_h | cosU(Q,K) | cosU(Q,V) | cosU(K,V)

    d_h is the median ``head_dim`` (em dash when unavailable), and the random
    baseline is 1/sqrt(d_h). The cosU entries are pseudo-bulk medians of
    ``cosU_QK`` / ``cosU_QV`` / ``cosU_KV``.
    """
    rows = []
    for model_id, df in model_dfs.items():
        std = _standard_only(df)
        if std.empty:
            continue
        baseline = _random_baseline_U(std)
        dims = std["head_dim"].dropna() if "head_dim" in std.columns else None
        head_dim = int(dims.median()) if dims is not None and not dims.empty else "—"
        rows.append({
            "Model":           _short(model_id),
            "d_h":             head_dim,
            "Random 1/√d_h":  _fmt(baseline, 4),
            "cosU(Q,K)":      _fmt(_pb_med(std, "cosU_QK"), 4),
            "cosU(Q,V)":      _fmt(_pb_med(std, "cosU_QV"), 4),
            "cosU(K,V)":      _fmt(_pb_med(std, "cosU_KV"), 4),
        })
    return pd.DataFrame(rows)


# ─────────────────────────────────────────────────────────────────────────────
# Table 4 — Input subspace cosV (Law 5)
# ─────────────────────────────────────────────────────────────────────────────

def make_table4(
    model_dfs: dict[str, pd.DataFrame],
) -> pd.DataFrame:
    """
    Input-subspace alignment cosV (Law 5), one row per model; standard layers only.

    Columns: Model | d_model | Random 1/√D | cosV(Q,K) | cosV(Q,V) | cosV(K,V)

    d_model is the median ``d_model`` value (em dash when unavailable), and the
    random baseline is 1/sqrt(d_model). The cosV entries are pseudo-bulk medians
    of ``cosV_QK`` / ``cosV_QV`` / ``cosV_KV``.
    """
    rows = []
    for model_id, df in model_dfs.items():
        std = _standard_only(df)
        if std.empty:
            continue
        baseline = _random_baseline_V(std)
        widths = std["d_model"].dropna() if "d_model" in std.columns else None
        d_model = int(widths.median()) if widths is not None and not widths.empty else "—"
        rows.append({
            "Model":           _short(model_id),
            "d_model":         d_model,
            "Random 1/√D":    _fmt(baseline, 4),
            "cosV(Q,K)":      _fmt(_pb_med(std, "cosV_QK"), 4),
            "cosV(Q,V)":      _fmt(_pb_med(std, "cosV_QV"), 4),
            "cosV(K,V)":      _fmt(_pb_med(std, "cosV_KV"), 4),
        })
    return pd.DataFrame(rows)


# ─────────────────────────────────────────────────────────────────────────────
# Table 5 — Condition number κ summary (Law 3)
# ─────────────────────────────────────────────────────────────────────────────

def make_table5(
    model_dfs: dict[str, pd.DataFrame],
) -> pd.DataFrame:
    """
    Condition-number κ summary (Law 3), one row per model; standard layers only.

    Columns: Model | Median κ(Q) all | Median κ(K) all | κ(Q) Layer 0 |
    κ(K) Layer 0 | Median κ(Q) deep | Median κ(K) deep

    "Layer 0" is the lowest-numbered standard layer. It typically has extreme κ,
    so it is reported on its own, and "deep" covers every later standard layer.
    All values are pseudo-bulk medians of ``cond_Q`` / ``cond_K``.
    """
    rows = []
    for model_id, df in model_dfs.items():
        std = _standard_only(df)
        if std.empty:
            continue
        first_layer = std["layer"].min()
        l0 = std[std["layer"] == first_layer]
        deep = std[std["layer"] > first_layer]
        rows.append({
            "Model":            _short(model_id),
            "Median κ(Q) all":  _fmt(_pb_med(std, "cond_Q"), 1),
            "Median κ(K) all":  _fmt(_pb_med(std, "cond_K"), 1),
            "κ(Q) Layer 0":     _fmt(_pb_med(l0, "cond_Q"), 1),
            "κ(K) Layer 0":     _fmt(_pb_med(l0, "cond_K"), 1),
            "Median κ(Q) deep": _fmt(_pb_med(deep, "cond_Q"), 1),
            "Median κ(K) deep": _fmt(_pb_med(deep, "cond_K"), 1),
        })
    return pd.DataFrame(rows)


# ─────────────────────────────────────────────────────────────────────────────
# Table 6 — Wang Score leaderboard
# ─────────────────────────────────────────────────────────────────────────────

def make_table6(
    model_dfs: dict[str, pd.DataFrame],
) -> pd.DataFrame:
    """
    Wang Score leaderboard, best score first.

    Wang Score = 1 - pseudo-bulk median ``ssr_QK`` over standard layers.
    Columns: Rank | Model | Std Layers | Median Pearson | Median SSR | Wang Score
    Models whose score cannot be computed sort last (pandas puts NaN last by
    default). All numeric columns except Rank / Std Layers come back as
    formatted strings.
    """
    records = []
    for model_id, df in model_dfs.items():
        std = _standard_only(df)
        if std.empty:
            continue
        ssr = _pb_med(std, "ssr_QK")
        records.append({
            "Model":          _short(model_id),
            "Std Layers":     std["layer"].nunique(),
            "Median Pearson": _fmt(_pb_med(std, "pearson_QK"), 4),
            "Median SSR":     _fmt(ssr, 6),
            # Keep the score numeric until after sorting, then format it.
            "Wang Score":     float("nan") if ssr is None else 1 - ssr,
        })

    board = pd.DataFrame(records)
    if board.empty:
        return board

    board = board.sort_values("Wang Score", ascending=False).reset_index(drop=True)
    board.insert(0, "Rank", range(1, len(board) + 1))
    board["Wang Score"] = [_fmt(score, 6) for score in board["Wang Score"]]
    return board


# ─────────────────────────────────────────────────────────────────────────────
# Master: generate all tables at once
# ─────────────────────────────────────────────────────────────────────────────

def generate_all_tables(
    model_dfs:    dict[str, pd.DataFrame],
    group_bounds: list[tuple[int, int]],
    name_a:       Optional[str] = None,
    name_b:       Optional[str] = None,
) -> dict[str, pd.DataFrame]:
    """
    Build all six tables and return them keyed "t1" … "t6".

    model_dfs: {model_id: per-head DataFrame from DB}
    group_bounds: inclusive layer ranges for Table 2, e.g. [(0,11),(12,23),(24,35),(36,47)]
    name_a / name_b: model IDs compared in Table 2. When name_a is unset or
    not a key of model_dfs, "t2" is a one-row placeholder note instead.
    """
    model_a = model_dfs.get(name_a) if name_a else None
    model_b = model_dfs.get(name_b) if name_b else None

    return {
        "t1": make_table1(model_dfs),
        "t2": (
            make_table2(model_a, name_a, model_b, name_b, group_bounds)
            if model_a is not None
            else pd.DataFrame({"Note": ["Select at least Model A for Table 2"]})
        ),
        "t3": make_table3(model_dfs),
        "t4": make_table4(model_dfs),
        "t5": make_table5(model_dfs),
        "t6": make_table6(model_dfs),
    }


# ─────────────────────────────────────────────────────────────────────────────
# Format all outputs
# ─────────────────────────────────────────────────────────────────────────────

# Table key -> (caption, LaTeX label). Captions previously held mis-encoded
# UTF-8 ("β€”" for an em dash, "ΞΊ" for kappa).
TABLE_META = {
    "t1": ("Table 1 — Cross-Model Summary (Law 1 & 2)",
           "tab:law12_summary"),
    "t2": ("Table 2 — SSR Layer-Group Trend (Law 2)",
           "tab:ssr_layergroup"),
    "t3": ("Table 3 — Output Subspace Alignment cosU (Law 4)",
           "tab:law4_cosU"),
    "t4": ("Table 4 — Input Subspace Alignment cosV (Law 5)",
           "tab:law5_cosV"),
    "t5": ("Table 5 — Condition Number κ Summary (Law 3)",
           "tab:law3_cond"),
    "t6": ("Table 6 — Wang Score Leaderboard",
           "tab:wang_score"),
}


def format_all_latex(tables: dict[str, pd.DataFrame]) -> str:
    """
    Render every table as a LaTeX ``table`` environment, separated by blank lines.

    Caption and label for each key are looked up in ``TABLE_META``; an unknown
    key raises KeyError.
    """
    return "\n\n".join(
        df_to_latex(df, *TABLE_META[key]) for key, df in tables.items()
    )


def format_all_markdown(tables: dict[str, pd.DataFrame]) -> str:
    """
    Render every table as Markdown, separated by horizontal rules.

    Captions come from ``TABLE_META``; an unknown key raises KeyError.
    """
    sections = [
        df_to_markdown(df, TABLE_META[key][0]) for key, df in tables.items()
    ]
    return "\n\n---\n\n".join(sections)