Alex W. committed on
Commit
7ecfc20
·
1 Parent(s): a6ecb93

Three UX issues reported on Tab 5 Interactive view:

1. **Layout unusable**: the 4×3 grid compressed into a fixed Plotly canvas made
each subplot too small to read.
2. **Width overflow**: matplotlib's physical canvas size (18×20 in) was
inherited by `mpl_to_plotly`, causing the figure to exceed the page width.
3. **Slow response**: `mpl_to_plotly` rendered matplotlib first (server-side),
then converted the entire figure to Plotly JSON. Double rendering plus the large
polygon count from the IQR bands made interaction sluggish.

Root cause: all three issues trace back to a single architectural mistake,
using `mpl_to_plotly` as a bridge instead of two independent rendering paths.

---

**Throw away the bridge. Two completely independent engines.**

```
Interactive → core/plotter_plotly.py (native Plotly)
Export      → core/plotter.py        (matplotlib, unchanged)
```

No conversion between them. Each optimized for its own purpose.

---

`core/plotter_plotly.py` (new): native Plotly figure generation. Zero matplotlib dependency.

**Layout: 12×1 vertical stack**
Each of the 12 panels is a full-width subplot, 280 px tall:

```
 0  pearson_QK       Law 1   - Spectral Linear Alignment
 1  ssr_QK           Law 2   - Spectral Shape Fidelity
 2  alpha_QK         Law 1+2 - Scale Factor α
 3  sigma_max_Q      Law 3   - Max Singular Value (Q)
 4  sigma_max_K      Law 3   - Max Singular Value (K)
 5  cond_Q & cond_K  Law 3   - Condition Number κ (dual line, same panel)
 6  cosU_QK          Law 4   - Output Subspace Q–K
 7  cosU_QV          Law 4   - Output Subspace Q–V [super-orthogonal]
 8  cosU_KV          Law 4   - Output Subspace K–V [super-orthogonal]
 9  cosV_QK          Law 5   - Input Subspace Q–K
10  cosV_QV          Law 5   - Input Subspace Q–V
11  cosV_KV          Law 5   - Input Subspace K–V
```
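For a sense of scale, the total figure height follows directly from the per-panel height. The sketch below uses the 280 px panel height stated above and the 120 px header allowance that appears in `core/plotter_plotly.py` in this commit; the names `figure_height` and `header_px` are illustrative only and do not exist in the repository.

```python
def figure_height(n_panels: int, panel_px: int = 280, header_px: int = 120) -> int:
    """Total figure height in pixels for a vertical stack of equal panels."""
    if n_panels < 1:
        raise ValueError("n_panels must be at least 1")
    # One fixed-height slot per panel, plus room for the title and legend.
    return panel_px * n_panels + header_px


print(figure_height(12))
```

With 12 panels the page becomes a long scrolling figure, which is the intended trade-off against the cramped 4×3 grid.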

**Key design decisions:**
- `width=None` + `autosize=True` → fills the browser container, no overflow
- `hovermode="x unified"` → all 12 panels show values simultaneously on hover
- IQR band = `go.Scatter` with `fill="toself"`: a single trace per band, fast
- cosU panels 6–8 share a Y range; cosV panels 9–11 share a Y range
  (same logic as the matplotlib version, for consistent effect-size reading)
- Random baseline lines (1/√d_head, 1/√d_model) drawn as dashed `go.Scatter`
- Global (K=V shared) layers marked with `add_vline` dotted gray lines
- Legend horizontal at top, 10pt font
**Public API:**
```python
plotly_single(df, model_name, show_band=True) -> go.Figure
plotly_compare(df_a, df_b, name_a, name_b,
show_band=True, show_delta=True) -> go.Figure
```

**Comparison figure extras:**
- Model A: solid lines. Model B: dashed lines.
- Δ (B − A): gray `fill="toself"` area between the Δ curve and zero, computed on the layers common to both models
- 4 lines in cond panel (Q/K × A/B), color-coded by matrix type
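The comparison code in this commit computes Δ only on layers present in both models (it uses `np.intersect1d` for this). The sketch below shows the same idea with plain dictionaries as a stand-in; `delta_on_common_layers` is a hypothetical name and the inputs are made-up sample medians.

```python
def delta_on_common_layers(layers_a, med_a, layers_b, med_b):
    """Return (common_layers, delta) where delta = median_B - median_A."""
    a = dict(zip(layers_a, med_a))
    b = dict(zip(layers_b, med_b))
    # Only layers that exist in both models can be compared.
    common = sorted(set(a) & set(b))
    delta = [b[layer] - a[layer] for layer in common]
    return common, delta


common, delta = delta_on_common_layers(
    [0, 1, 2], [1.0, 2.0, 3.0],
    [1, 2, 3], [2.5, 2.0, 9.0],
)
print(common, delta)
```

Restricting to the intersection avoids subtracting medians from mismatched layers when the two models have different depths.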

**Internal helpers:**
```
_agg(df, col)             → (layers, median, q25, q75) per-layer aggregation
_band_traces(...)         → (fill_trace, line_trace)
_hline_trace(...)         → reference / baseline line
_hex_to_rgba(hex, alpha)  → "rgba(r,g,b,a)" string
_sync_yrange(...)         → force identical Y range across a set of rows
_sync_yrange_compare(...) → same, for two-model figures
_global_layers(df)        → list of kv_shared layer indices
_infer_dims(df)           → (head_dim, d_model) from DataFrame
```
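As a rough illustration of what the aggregation and Y-range helpers do, here is a standard-library sketch. It is not the repository code (which uses pandas `groupby` and NumPy); `aggregate_per_layer` and `shared_y_range` are hypothetical names, and the `"inclusive"` quantile method is chosen here because it uses linear interpolation between data points.

```python
from collections import defaultdict
from statistics import median, quantiles


def aggregate_per_layer(rows):
    """rows: iterable of (layer, value). Returns (layers, med, q25, q75)."""
    by_layer = defaultdict(list)
    for layer, value in rows:
        by_layer[layer].append(value)
    layers = sorted(by_layer)
    med, q25, q75 = [], [], []
    for layer in layers:
        vals = by_layer[layer]
        if len(vals) < 2:
            # A single observation is treated as its own quartiles.
            lo = hi = vals[0]
        else:
            lo, _, hi = quantiles(vals, n=4, method="inclusive")
        med.append(median(vals))
        q25.append(lo)
        q75.append(hi)
    return layers, med, q25, q75


def shared_y_range(quartile_lists, pad=0.08):
    """One (lo, hi) range covering every quartile value, floored at zero."""
    vals = [v for lst in quartile_lists for v in lst]
    if not vals:
        return None
    return max(0.0, min(vals) * (1 - pad)), max(vals) * (1 + pad)


rows = [(0, 1.0), (0, 2.0), (0, 3.0), (0, 4.0), (1, 10.0)]
layers, med, q25, q75 = aggregate_per_layer(rows)
print(layers, med, q25, q75)
print(shared_y_range([q25, q75], pad=0.0))
```

Applying one shared range to the three cosU rows (and another to the three cosV rows) is what makes effect sizes visually comparable across those panels.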

`core/plotter.py`: deleted `fig_to_plotly()`, the `mpl_to_plotly` wrapper that caused all three
problems. Everything else is unchanged. matplotlib remains the sole export engine.

```python
def fig_to_plotly(fig_mpl):
    import plotly.tools as tls
    return tls.mpl_to_plotly(fig_mpl)  # ← the root cause
```

`ui/tab_plot.py`: the two rendering paths now have separate buttons, separate outputs, and
separate handler functions. No state is shared between them.

**Handler separation:**
```
gen_single_plotly()  → calls plotly_single()       → gr.Plot()
gen_single_export()  → calls plot_single_model()   → gr.Image + gr.File ×3 + ZIP

gen_compare_plotly() → calls plotly_compare()      → gr.Plot()
gen_compare_export() → calls plot_compare_models() → gr.Image + gr.File ×3 + ZIP
```

**UI structure:**
```
Shared controls: modality / start layer / end layer / IQR band toggle

Single Model accordion:
  Tabs:
    "📉 Interactive (Plotly)"  ← ⚡ fast button → gr.Plot (full-width)
    "🖨️ Export (PNG/PDF/SVG)"  ← 🖨 slow button → preview + 4 download files

Two-Model Comparison accordion:
  Model A dropdown + Model B dropdown + Δ fill toggle
  Tabs: same structure as above
```

---

```
core/plotter.py          matplotlib only
  plot_single_model()    → plt.Figure (4×3, 18×20 in, 300 dpi)
  plot_compare_models()  → plt.Figure
  save_figure()          → [.png, .pdf, .svg]
  fig_to_png_bytes()     → bytes

core/plotter_plotly.py   Plotly only
  plotly_single()        → go.Figure (12×1, autosize, full-width)
  plotly_compare()       → go.Figure

ui/tab_plot.py           calls both, never converts between them
```

---

| Metric | Before (mpl_to_plotly) | After (native Plotly) |
|--------|------------------------|----------------------|
| Render path | matplotlib → JSON conversion | Direct DataFrame → JSON |
| IQR band | ~500 polygon vertices per panel | 1 `fill="toself"` trace |
| Width control | Inherited 18 in physical size | `autosize=True` |
| Layout | 4×3 (cramped) | 12×1 (readable) |
| Hover | Broken / misaligned | `x unified`, all panels sync |
| Interactive speed | Slow (double render) | Fast (single pass) |

---

- `core/plotter.py` matplotlib export logic (4×3 paper layout): untouched
- `db/` layer: untouched
- `app.py`: untouched
- `requirements.txt`: `plotly` already added in previous commit

---

Structural verification (network-disabled sandbox):
- `ast.parse()` on all three files ✅
- 12 panels confirmed in `PANELS` list ✅
- `mpl_to_plotly` string absent from all files ✅
- `gen_single_plotly`, `gen_compare_plotly` wired correctly in tab ✅
- `width=None`, `autosize=True` confirmed in both figure functions ✅
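
A structural check of this kind can be scripted without running the app. The following is a minimal sketch under stated assumptions: `verify_sources` is a hypothetical helper, and the two in-memory sources stand in for the real files, which are not read here.

```python
import ast


def verify_sources(sources, forbidden="mpl_to_plotly"):
    """sources: mapping of filename -> source text. Returns a list of problems."""
    problems = []
    for name, text in sources.items():
        try:
            # Parsing only checks syntax; nothing is imported or executed.
            ast.parse(text, filename=name)
        except SyntaxError as exc:
            problems.append(f"{name}: syntax error at line {exc.lineno}")
            continue
        if forbidden in text:
            problems.append(f"{name}: contains {forbidden!r}")
    return problems


demo = {
    "ok.py": "def plotly_single(df):\n    return df\n",
    "bad.py": "import plotly.tools as tls\nfig = tls.mpl_to_plotly(None)\n",
}
print(verify_sources(demo))
```

Because only `ast.parse` and a substring test are involved, the check works in a sandbox where Plotly, Gradio, and the network are unavailable.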

Files changed (3)
  1. core/plotter.py +1 -10
  2. core/plotter_plotly.py +485 -0
  3. ui/tab_plot.py +228 -285
core/plotter.py CHANGED
@@ -473,13 +473,4 @@ def fig_to_png_bytes(fig: plt.Figure) -> bytes:
     return buf.read()
 
 
-def fig_to_plotly(fig_mpl: plt.Figure):
-    """
-    Convert matplotlib Figure to a Plotly figure via mpl_to_plotly.
-    Requires plotly installed. Falls back gracefully.
-    """
-    try:
-        import plotly.tools as tls
-        return tls.mpl_to_plotly(fig_mpl)
-    except Exception:
-        return None
+# fig_to_plotly removed — use core/plotter_plotly.py for native Plotly figures.
core/plotter_plotly.py ADDED
@@ -0,0 +1,485 @@
+# core/plotter_plotly.py
+"""
+Native Plotly interactive figures for Wang's Five Laws.
+12 subplots stacked vertically (12×1), full browser width.
+Fast: data aggregated once, drawn directly — no matplotlib conversion.
+
+Layout (top → bottom):
+  0  pearson_QK       Law 1    Spectral Linear Alignment
+  1  ssr_QK           Law 2    Spectral Shape Fidelity
+  2  alpha_QK         Law 1+2  Scale Factor α
+  3  sigma_max_Q      Law 3    Max Singular Value (Q)
+  4  sigma_max_K      Law 3    Max Singular Value (K)
+  5  cond_Q + cond_K  Law 3    Condition Number κ (dual line)
+  6  cosU_QK          Law 4    Output Subspace Q–K
+  7  cosU_QV          Law 4    Output Subspace Q–V
+  8  cosU_KV          Law 4    Output Subspace K–V
+  9  cosV_QK          Law 5    Input Subspace Q–K
+ 10  cosV_QV          Law 5    Input Subspace Q–V
+ 11  cosV_KV          Law 5    Input Subspace K–V
+"""
+
+import numpy as np
+import pandas as pd
+import plotly.graph_objects as go
+from plotly.subplots import make_subplots
+
+# ── Color palette (identical to plotter.py) ───────────────────────────────────
+C = {
+    "Q": "#2166AC",
+    "K": "#D6604D",
+    "V": "#4DAC26",
+    "QK": "#762A83",
+    "QV": "#01665E",
+    "KV": "#E08214",
+    "ref": "#888888",
+}
+
+BAND_ALPHA = 0.15  # opacity for IQR band fill
+
+# ── Panel definitions ─────────────────────────────────────────────────────────
+# (col, color_key, y_label, title, ideal_value or None)
+PANELS = [
+    ("pearson_QK", "QK", "Pearson r", "Law 1 — Spectral Linear Alignment (Pearson r Q–K)", 1.0),
+    ("ssr_QK", "QK", "SSR", "Law 2 — Spectral Shape Fidelity (SSR Q–K)", 0.0),
+    ("alpha_QK", "QK", "α", "Law 1+2 — Scale Factor α (Q–K)", 1.0),
+    ("sigma_max_Q", "Q", "σ_max", "Law 3 — Max Singular Value σ_max (Q)", None),
+    ("sigma_max_K", "K", "σ_max", "Law 3 — Max Singular Value σ_max (K)", None),
+    ("cond_dual", None, "κ", "Law 3 — Condition Number κ (Q & K)", None),
+    ("cosU_QK", "QK", "cosU", "Law 4 — Output Subspace cosU (Q–K)", None),
+    ("cosU_QV", "QV", "cosU", "Law 4 — Output Subspace cosU (Q–V) [super-orth]", None),
+    ("cosU_KV", "KV", "cosU", "Law 4 — Output Subspace cosU (K–V) [super-orth]", None),
+    ("cosV_QK", "QK", "cosV", "Law 5 — Input Subspace cosV (Q–K)", None),
+    ("cosV_QV", "QV", "cosV", "Law 5 — Input Subspace cosV (Q–V)", None),
+    ("cosV_KV", "KV", "cosV", "Law 5 — Input Subspace cosV (K–V)", None),
+]
+
+SUBPLOT_HEIGHT = 280  # px per subplot
+TOTAL_HEIGHT = SUBPLOT_HEIGHT * len(PANELS) + 120  # +header
+
+
+# ─────────────────────────────────────────────────────────────────────────────
+# Data helpers
+# ─────────────────────────────────────────────────────────────────────────────
+
+def _agg(df: pd.DataFrame, col: str):
+    """Per-layer median + IQR. Excludes kv_shared rows for KV metrics."""
+    kv_cols = {"ssr_KV", "pearson_KV", "cosU_KV", "cosV_KV", "alpha_KV"}
+    d = df[df["kv_shared"] == 0] if col in kv_cols and "kv_shared" in df.columns else df
+    grp = d.groupby("layer")[col]
+    layers = np.array(sorted(d["layer"].unique()), dtype=int)
+    med = grp.median().reindex(layers).values.astype(float)
+    q25 = grp.quantile(0.25).reindex(layers).values.astype(float)
+    q75 = grp.quantile(0.75).reindex(layers).values.astype(float)
+    return layers, med, q25, q75
+
+
+def _global_layers(df: pd.DataFrame) -> list[int]:
+    if "kv_shared" not in df.columns:
+        return []
+    return sorted(df[df["kv_shared"] == 1]["layer"].unique().tolist())
+
+
+def _infer_dims(df: pd.DataFrame) -> tuple[int, int]:
+    head_dim = int(df["head_dim"].dropna().median()) if "head_dim" in df.columns and df["head_dim"].notna().any() else 128
+    d_model = int(df["d_model"].dropna().median()) if "d_model" in df.columns and df["d_model"].notna().any() else 5120
+    return head_dim, d_model
+
+
+# ─────────────────────────────────────────────────────────────────────────────
+# Trace builders
+# ─────────────────────────────────────────────────────────────────────────────
+
+def _band_traces(layers, med, q25, q75, color, name,
+                 row, dash="solid", show_legend=True):
+    """Returns (band_trace, line_trace) for one series."""
+    rgba_fill = _hex_to_rgba(color, BAND_ALPHA)
+
+    band = go.Scatter(
+        x=np.concatenate([layers, layers[::-1]]).tolist(),
+        y=np.concatenate([q75, q25[::-1]]).tolist(),
+        fill="toself",
+        fillcolor=rgba_fill,
+        line=dict(color="rgba(0,0,0,0)"),
+        hoverinfo="skip",
+        showlegend=False,
+        legendgroup=name,
+    )
+    line = go.Scatter(
+        x=layers.tolist(),
+        y=med.tolist(),
+        mode="lines",
+        name=name,
+        line=dict(color=color, width=2, dash=dash),
+        hovertemplate=f"Layer %{{x}}<br>{name}: %{{y:.5f}}<extra></extra>",
+        showlegend=show_legend,
+        legendgroup=name,
+    )
+    return band, line
+
+
+def _hline_trace(layers, y_val, label, color=None, row=None):
+    color = color or C["ref"]
+    return go.Scatter(
+        x=[layers[0], layers[-1]],
+        y=[y_val, y_val],
+        mode="lines",
+        name=label,
+        line=dict(color=color, width=1.2, dash="dash"),
+        hoverinfo="skip",
+        showlegend=True,
+        legendgroup=label,
+    )
+
+
+def _hex_to_rgba(hex_color: str, alpha: float) -> str:
+    h = hex_color.lstrip("#")
+    r, g, b = int(h[0:2], 16), int(h[2:4], 16), int(h[4:6], 16)
+    return f"rgba({r},{g},{b},{alpha})"
+
+
+def _vlines(fig, global_layers, row, x_range):
+    for gl in global_layers:
+        fig.add_vline(
+            x=gl, row=row, col=1,
+            line=dict(color="#AAAAAA", width=1, dash="dot"),
+            annotation=dict(
+                text=f"G{gl}", font=dict(size=8, color="#999999"),
+                showarrow=False, yref="paper",
+            ) if row == 1 else None,
+        )
+
+
+# ─────────────────────────────────────────────────────────────────────────────
+# Single-model native Plotly figure
+# ─────────────────────────────────────────────────────────────────────────────
+
+def plotly_single(
+    df: pd.DataFrame,
+    model_name: str,
+    show_band: bool = True,
+) -> go.Figure:
+    """
+    12×1 stacked subplots, full browser width.
+    Each subplot: median line + IQR band + reference lines + global-layer markers.
+    """
+    n_panels = len(PANELS)
+    head_dim, d_model = _infer_dims(df)
+    baseline_U = 1.0 / np.sqrt(head_dim)
+    baseline_V = 1.0 / np.sqrt(d_model)
+    gl = _global_layers(df)
+
+    subtitles = [p[3] for p in PANELS]
+    fig = make_subplots(
+        rows=n_panels, cols=1,
+        subplot_titles=subtitles,
+        shared_xaxes=False,
+        vertical_spacing=0.03,
+    )
+
+    for row_idx, (col, color_key, ylabel, title, ideal) in enumerate(PANELS, start=1):
+        color = C[color_key] if color_key else C["Q"]
+
+        # ── special case: cond_dual ──────────────────────────────────────────
+        if col == "cond_dual":
+            for c_col, c_key, c_name in [
+                ("cond_Q", "Q", "κ(Q)"),
+                ("cond_K", "K", "κ(K)"),
+            ]:
+                layers, med, q25, q75 = _agg(df, c_col)
+                if len(layers) == 0:
+                    continue
+                band, line = _band_traces(
+                    layers, med, q25, q75, C[c_key], c_name,
+                    row=row_idx, show_legend=True
+                )
+                if show_band:
+                    fig.add_trace(band, row=row_idx, col=1)
+                fig.add_trace(line, row=row_idx, col=1)
+            layers_ref = _agg(df, "cond_Q")[0]
+
+        else:
+            layers, med, q25, q75 = _agg(df, col)
+            if len(layers) == 0:
+                continue
+            band, line = _band_traces(
+                layers, med, q25, q75, color,
+                model_name, row=row_idx, show_legend=(row_idx == 1)
+            )
+            if show_band:
+                fig.add_trace(band, row=row_idx, col=1)
+            fig.add_trace(line, row=row_idx, col=1)
+            layers_ref = layers
+
+            # ── ideal / baseline reference lines ─────────────────────────────
+            if ideal is not None and len(layers_ref):
+                fig.add_trace(
+                    _hline_trace(layers_ref, ideal, f"Ideal={ideal}",
+                                 color=C["ref"]),
+                    row=row_idx, col=1
+                )
+
+        # ── random baselines for cosU / cosV ────────────────────────────────
+        if col.startswith("cosU_") and len(layers_ref):
+            fig.add_trace(
+                _hline_trace(layers_ref, baseline_U,
+                             f"Random 1/√d_h ≈ {baseline_U:.4f}",
+                             color="#E07B39"),
+                row=row_idx, col=1
+            )
+        if col.startswith("cosV_") and len(layers_ref):
+            fig.add_trace(
+                _hline_trace(layers_ref, baseline_V,
+                             f"Random 1/√D ≈ {baseline_V:.4f}",
+                             color="#E07B39"),
+                row=row_idx, col=1
+            )
+
+        # ── global layer vertical markers ────────────────────────────────────
+        for gl_idx in gl:
+            fig.add_vline(
+                x=gl_idx, row=row_idx, col=1,
+                line=dict(color="#BBBBBB", width=1, dash="dot"),
+            )
+
+        # ── y-axis label ─────────────────────────────────────────────────────
+        fig.update_yaxes(title_text=ylabel, row=row_idx, col=1,
+                         title_font=dict(size=11))
+        fig.update_xaxes(title_text="Layer index", row=row_idx, col=1,
+                         title_font=dict(size=11))
+
+    # ── shared Y for cosU row (panels 6,7,8) ─────────────────────────────────
+    _sync_yrange(fig, df, ["cosU_QK", "cosU_QV", "cosU_KV"],
+                 rows=[7, 8, 9], pad=0.08)
+    # ── shared Y for cosV row (panels 9,10,11) ───────────────────────────────
+    _sync_yrange(fig, df, ["cosV_QK", "cosV_QV", "cosV_KV"],
+                 rows=[10, 11, 12], pad=0.08)
+
+    # ── layout ───────────────────────────────────────────────────────────────
+    fig.update_layout(
+        title=dict(
+            text=f"<b>Wang's Five Laws — {model_name}</b>",
+            font=dict(size=16),
+            x=0.5, xanchor="center",
+        ),
+        height=TOTAL_HEIGHT,
+        width=None,  # full browser width
+        autosize=True,
+        showlegend=True,
+        legend=dict(
+            orientation="h",
+            yanchor="bottom", y=1.01,
+            xanchor="right", x=1,
+            font=dict(size=10),
+        ),
+        margin=dict(l=70, r=30, t=80, b=40),
+        paper_bgcolor="white",
+        plot_bgcolor="#FAFAFA",
+        font=dict(family="Arial, sans-serif", size=11),
+        hovermode="x unified",
+    )
+    fig.update_annotations(font_size=11)
+
+    return fig
+
+
+# ─────────────────────────────────────────────────────────────────────────────
+# Two-model comparison native Plotly figure
+# ─────────────────────────────────────────────────────────────────────────────
+
+def plotly_compare(
+    df_a: pd.DataFrame,
+    df_b: pd.DataFrame,
+    name_a: str,
+    name_b: str,
+    show_band: bool = True,
+    show_delta: bool = True,
+) -> go.Figure:
+    """
+    12×1 stacked subplots.
+    Model A: solid lines. Model B: dashed lines.
+    Δ = B − A shown as light gray fill when show_delta=True.
+    """
+    n_panels = len(PANELS)
+    head_dim_a, d_model_a = _infer_dims(df_a)
+    head_dim_b, d_model_b = _infer_dims(df_b)
+    head_dim = (head_dim_a + head_dim_b) // 2
+    d_model = (d_model_a + d_model_b) // 2
+    baseline_U = 1.0 / np.sqrt(head_dim)
+    baseline_V = 1.0 / np.sqrt(d_model)
+    gl = sorted(set(_global_layers(df_a)) | set(_global_layers(df_b)))
+
+    subtitles = [p[3] for p in PANELS]
+    fig = make_subplots(
+        rows=n_panels, cols=1,
+        subplot_titles=subtitles,
+        shared_xaxes=False,
+        vertical_spacing=0.03,
+    )
+
+    for row_idx, (col, color_key, ylabel, title, ideal) in enumerate(PANELS, start=1):
+        color = C[color_key] if color_key else C["Q"]
+
+        if col == "cond_dual":
+            for c_col, c_key, c_name in [
+                ("cond_Q", "Q", "κ(Q)"),
+                ("cond_K", "K", "κ(K)"),
+            ]:
+                for df_, nm, dash in [(df_a, name_a, "solid"),
+                                      (df_b, name_b, "dash")]:
+                    layers, med, q25, q75 = _agg(df_, c_col)
+                    if len(layers) == 0:
+                        continue
+                    label = f"{c_name} {nm}"
+                    band, line = _band_traces(
+                        layers, med, q25, q75, C[c_key], label,
+                        row=row_idx, dash=dash, show_legend=True
+                    )
+                    if show_band:
+                        fig.add_trace(band, row=row_idx, col=1)
+                    fig.add_trace(line, row=row_idx, col=1)
+            layers_ref = _agg(df_a, "cond_Q")[0]
+
+        else:
+            layers_a, med_a, q25_a, q75_a = _agg(df_a, col)
+            layers_b, med_b, q25_b, q75_b = _agg(df_b, col)
+
+            for layers, med, q25, q75, nm, dash in [
+                (layers_a, med_a, q25_a, q75_a, name_a, "solid"),
+                (layers_b, med_b, q25_b, q75_b, name_b, "dash"),
+            ]:
+                if len(layers) == 0:
+                    continue
+                show_leg = (row_idx == 1)
+                band, line = _band_traces(
+                    layers, med, q25, q75, color, nm,
+                    row=row_idx, dash=dash, show_legend=show_leg
+                )
+                if show_band:
+                    fig.add_trace(band, row=row_idx, col=1)
+                fig.add_trace(line, row=row_idx, col=1)
+
+            # Delta fill
+            if show_delta and len(layers_a) and len(layers_b):
+                common = np.intersect1d(layers_a, layers_b)
+                if len(common) > 1:
+                    idx_a = np.isin(layers_a, common)
+                    idx_b = np.isin(layers_b, common)
+                    delta = med_b[idx_b] - med_a[idx_a]
+                    zero = np.zeros_like(delta)
+                    fig.add_trace(go.Scatter(
+                        x=np.concatenate([common, common[::-1]]).tolist(),
+                        y=np.concatenate([delta, zero[::-1]]).tolist(),
+                        fill="toself",
+                        fillcolor="rgba(160,160,160,0.20)",
+                        line=dict(color="rgba(0,0,0,0)"),
+                        hoverinfo="skip",
+                        showlegend=(row_idx == 1),
+                        name=f"Δ ({name_b}−{name_a})",
+                        legendgroup="delta",
+                    ), row=row_idx, col=1)
+
+            layers_ref = layers_a if len(layers_a) else layers_b
+
+        # Reference lines
+        if ideal is not None and len(layers_ref):
+            fig.add_trace(
+                _hline_trace(layers_ref, ideal, f"Ideal={ideal}", C["ref"]),
+                row=row_idx, col=1
+            )
+
+        if col.startswith("cosU_") and len(layers_ref):
+            fig.add_trace(
+                _hline_trace(layers_ref, baseline_U,
+                             f"Random 1/√d_h ≈ {baseline_U:.4f}", "#E07B39"),
+                row=row_idx, col=1
+            )
+        if col.startswith("cosV_") and len(layers_ref):
+            fig.add_trace(
+                _hline_trace(layers_ref, baseline_V,
+                             f"Random 1/√D ≈ {baseline_V:.4f}", "#E07B39"),
+                row=row_idx, col=1
+            )
+
+        for gl_idx in gl:
+            fig.add_vline(
+                x=gl_idx, row=row_idx, col=1,
+                line=dict(color="#BBBBBB", width=1, dash="dot"),
+            )
+
+        fig.update_yaxes(title_text=ylabel, row=row_idx, col=1,
+                         title_font=dict(size=11))
+        fig.update_xaxes(title_text="Layer index", row=row_idx, col=1,
+                         title_font=dict(size=11))
+
+    _sync_yrange_compare(fig, df_a, df_b,
+                         ["cosU_QK", "cosU_QV", "cosU_KV"], [7, 8, 9])
+    _sync_yrange_compare(fig, df_a, df_b,
+                         ["cosV_QK", "cosV_QV", "cosV_KV"], [10, 11, 12])
+
+    fig.update_layout(
+        title=dict(
+            text=f"<b>Wang's Five Laws — {name_a} vs {name_b}</b>",
+            font=dict(size=16),
+            x=0.5, xanchor="center",
+        ),
+        height=TOTAL_HEIGHT,
+        width=None,
+        autosize=True,
+        showlegend=True,
+        legend=dict(
+            orientation="h",
+            yanchor="bottom", y=1.01,
+            xanchor="right", x=1,
+            font=dict(size=10),
+        ),
+        margin=dict(l=70, r=30, t=80, b=40),
+        paper_bgcolor="white",
+        plot_bgcolor="#FAFAFA",
+        font=dict(family="Arial, sans-serif", size=11),
+        hovermode="x unified",
+    )
+    fig.update_annotations(font_size=11)
+
+    return fig
+
+
+# ─────────────────────────────────────────────────────────────────────────────
+# Shared Y-axis helpers
+# ─────────────────────────────────────────────────────────────────────────────
+
+def _sync_yrange(fig, df, cols, rows, pad=0.08):
+    """Force identical y-range for a set of rows (single model)."""
+    vals = []
+    for col in cols:
+        try:
+            _, med, q25, q75 = _agg(df, col)
+            vals.extend(q25[~np.isnan(q25)].tolist())
+            vals.extend(q75[~np.isnan(q75)].tolist())
+        except Exception:
+            pass
+    if not vals:
+        return
+    lo = max(0.0, min(vals) * (1 - pad))
+    hi = max(vals) * (1 + pad)
+    for r in rows:
+        fig.update_yaxes(range=[lo, hi], row=r, col=1)
+
+
+def _sync_yrange_compare(fig, df_a, df_b, cols, rows, pad=0.08):
+    """Force identical y-range for a set of rows (two-model comparison)."""
+    vals = []
+    for col in cols:
+        for df_ in [df_a, df_b]:
+            try:
+                _, med, q25, q75 = _agg(df_, col)
+                vals.extend(q25[~np.isnan(q25)].tolist())
+                vals.extend(q75[~np.isnan(q75)].tolist())
+            except Exception:
+                pass
+    if not vals:
+        return
+    lo = max(0.0, min(vals) * (1 - pad))
+    hi = max(vals) * (1 + pad)
+    for r in rows:
+        fig.update_yaxes(range=[lo, hi], row=r, col=1)
ui/tab_plot.py CHANGED
@@ -1,30 +1,22 @@
 # ui/tab_plot.py
 """
 Tab5: Plot — Publication-quality figure generation
-Data pulled from SQLite DB.
-Supports: single model (4×3) and two-model comparison (4×3).
-Export: PNG (300 dpi) / PDF / SVG.
-Engine: matplotlib (publication) + optional Plotly (interactive).
 """
 
 import os
-import tempfile
 import zipfile
 
 import gradio as gr
 import pandas as pd
-import numpy as np
 
 from db.schema import init_db
 from db.reader import get_layer_metrics, get_analyzed_models
-from core.plotter import (
-    plot_single_model,
-    plot_compare_models,
-    save_figure,
-    fig_to_plotly,
-)
-
-# ── Output directory ──────────────────────────────────────────────────────────
 _OUT_DIR = "/tmp/wang_plots"
 os.makedirs(_OUT_DIR, exist_ok=True)
 
@@ -37,17 +29,14 @@ def _get_model_choices() -> list[str]:
     try:
         conn = init_db()
         df = get_analyzed_models(conn)
-        if df.empty:
-            return []
-        return df["model_id"].tolist()
     except Exception:
         return []
 
 
-def _load_df(model_id: str, modality: str,
-             start_layer: int, end_layer: int) -> pd.DataFrame:
     conn = init_db()
-    df = get_layer_metrics(
         conn,
         model_id = model_id,
         modality = modality if modality != "all" else None,
@@ -55,170 +44,157 @@ def _load_df(model_id: str, modality: str,
         start_layer = int(start_layer),
         end_layer = int(end_layer),
     )
-    return df
 
 
 def _infer_dims(df: pd.DataFrame) -> tuple[int, int]:
-    """Try to read head_dim and d_model from the dataframe."""
     head_dim = 128
     d_model = 5120
     if not df.empty:
-        if "head_dim" in df.columns:
-            v = df["head_dim"].dropna()
-            if len(v):
-                head_dim = int(v.median())
-        if "d_model" in df.columns:
-            v = df["d_model"].dropna()
-            if len(v):
-                d_model = int(v.median())
     return head_dim, d_model
 
 
-def _short_name(model_id: str) -> str:
     return model_id.split("/")[-1] if "/" in model_id else model_id
 
 
-def _safe_base_path(name: str) -> str:
-    safe = name.replace("/", "_").replace(" ", "_")
-    return os.path.join(_OUT_DIR, safe)
 
 
 # ─────────────────────────────────────────────────────────────────────────────
-# Main generation functions
 # ─────────────────────────────────────────────────────────────────────────────
 
-def generate_single(
-    model_id: str,
-    modality: str,
-    start_layer: int,
-    end_layer: int,
-    show_band: bool,
-    progress=gr.Progress()
-) -> tuple:
-    """
-    Returns: (status_str, png_path, [png_path, pdf_path, svg_path], plotly_fig)
-    """
-    if not model_id or not model_id.strip():
-        return "❌ Please select a model.", None, None, None
-
-    progress(0.1, desc="Loading data from DB...")
-    df = _load_df(model_id, modality, start_layer, end_layer)
-
     if df.empty:
-        return (
-            f"❌ No data found for {model_id} "
-            f"(modality={modality}, layers {start_layer}~{end_layer}).\n"
-            f"Please run analysis first in Tab 2.",
-            None, None, None
-        )
-
-    progress(0.35, desc="Inferring dimensions...")
-    head_dim, d_model = _infer_dims(df)
-    n_layers = df["layer"].nunique()
-    n_records = len(df)
-
-    progress(0.50, desc="Generating matplotlib figure...")
-    name = _short_name(model_id)
-    fig = plot_single_model(
-        df, model_name=name,
-        show_band=show_band,
-        head_dim=head_dim,
-        d_model=d_model,
     )
 
-    progress(0.75, desc="Saving PNG / PDF / SVG...")
-    base = _safe_base_path(f"single_{name}_L{start_layer}-{end_layer}")
-    paths = save_figure(fig, base)
-
-    progress(0.90, desc="Generating Plotly preview...")
-    plotly_fig = fig_to_plotly(fig)
 
     import matplotlib.pyplot as plt
     plt.close(fig)
-
     status = (
-        f"✅ {model_id} | modality={modality} "
-        f"| layers {start_layer}~{end_layer} "
-        f"| {n_layers} layers {n_records} head-records\n"
-        f" head_dim={head_dim} d_model={d_model}\n"
-        f" Saved: {', '.join(os.path.basename(p) for p in paths)}"
     )
-    png_path = paths[0]
-    return status, png_path, paths, plotly_fig
-
-
-def generate_compare(
-    model_a: str,
-    model_b: str,
-    modality: str,
-    start_layer: int,
-    end_layer: int,
-    show_band: bool,
-    show_delta: bool,
-    progress=gr.Progress()
-) -> tuple:
-    if not model_a or not model_b:
-        return "❌ Please select both models.", None, None, None
-    if model_a == model_b:
-        return "❌ Please select two different models.", None, None, None
 
-    progress(0.10, desc="Loading Model A from DB...")
-    df_a = _load_df(model_a, modality, start_layer, end_layer)
-    progress(0.25, desc="Loading Model B from DB...")
-    df_b = _load_df(model_b, modality, start_layer, end_layer)
 
     if df_a.empty:
-        return f"❌ No data for Model A ({model_a}).", None, None, None
     if df_b.empty:
-        return f"❌ No data for Model B ({model_b}).", None, None, None
-
     head_dim_a, d_model_a = _infer_dims(df_a)
     head_dim_b, d_model_b = _infer_dims(df_b)
-    head_dim = int((head_dim_a + head_dim_b) / 2)
-    d_model = int((d_model_a + d_model_b) / 2)
-
-    progress(0.50, desc="Generating comparison figure...")
-    name_a = _short_name(model_a)
-    name_b = _short_name(model_b)
     fig = plot_compare_models(
-        df_a, df_b,
-        name_a=name_a, name_b=name_b,
-        show_band=show_band,
-        show_delta=show_delta,
-        head_dim=head_dim,
-        d_model=d_model,
     )
-
-    progress(0.80, desc="Saving PNG / PDF / SVG...")
-    base = _safe_base_path(f"compare_{name_a}_vs_{name_b}_L{start_layer}-{end_layer}")
     paths = save_figure(fig, base)
-
-    progress(0.92, desc="Generating Plotly preview...")
-    plotly_fig = fig_to_plotly(fig)
-
-    import matplotlib.pyplot as plt
     plt.close(fig)
-
     status = (
-        f"✅ {name_a} vs {name_b}\n"
-        f" modality={modality} layers {start_layer}~{end_layer}\n"
-        f" Model A: {len(df_a)} records | Model B: {len(df_b)} records\n"
-        f" head_dim≈{head_dim} d_model≈{d_model}\n"
-        f" Saved: {', '.join(os.path.basename(p) for p in paths)}"
     )
-    return status, paths[0], paths, plotly_fig
-
-
-def make_zip(file_paths: list) -> str | None:
-    """Bundle all exported files into a single ZIP for download."""
-    if not file_paths:
-        return None
-    zip_path = os.path.join(_OUT_DIR, "wang_laws_figures.zip")
-    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
-        for p in file_paths:
-            if p and os.path.exists(p):
-                zf.write(p, os.path.basename(p))
-    return zip_path
 
 
 # ─────────────────────────────────────────────────────────────────────────────
@@ -228,185 +204,152 @@ def make_zip(file_paths: list) -> str | None:
228
  def build_tab_plot():
229
  with gr.Tab("πŸ“ˆ Plot"):
230
  gr.Markdown("""
231
- ## Wang's Five Laws β€” Publication-Quality Figures
232
- Data is loaded directly from the SQLite database (Tab 2 must be run first).
233
-
234
- **4Γ—3 grid layout** (12 subplots, one figure):
235
- | Row | Content | Laws |
236
- |-----|---------|------|
237
- | 1 | pearson_QK · SSR_QK · α_QK | Law 1 & 2 |
237
- | 2 | σ_max(Q) · σ_max(K) · κ(Q) & κ(K) | Law 3 |
238
- | 3 | cosU QK · QV · KV + random baseline | Law 4 |
239
- | 4 | cosV QK · QV · KV + random baseline | Law 5 |
241
-
242
- Export: **PNG 300 dpi** · **PDF (vector)** · **SVG (vector)**
243
  """)
244
 
245
  # ── Shared controls ───────────────────────────────────────────────────
246
  with gr.Row():
247
  modality_sel = gr.Dropdown(
248
- label="Modality",
249
- choices=["language", "vision", "audio", "all"],
250
- value="language",
251
- scale=1,
252
- )
253
- start_l = gr.Number(
254
- label="Start Layer", value=0, precision=0, scale=1
255
- )
256
- end_l = gr.Number(
257
- label="End Layer", value=47, precision=0, scale=1
258
  )
 
 
259
  show_band_chk = gr.Checkbox(
260
- label="Show 25%–75% band (head consistency)",
261
- value=True, scale=1
262
  )
263
 
264
  gr.Markdown("---")
265
 
266
- # ══ Mode 1: Single model ══════════════════════════════════════════════
267
  with gr.Accordion("πŸ“Š Single Model", open=True):
268
- with gr.Row():
269
- model_choices = _get_model_choices()
270
- single_model = gr.Dropdown(
271
- label="Model",
272
- choices=model_choices,
273
- value=model_choices[0] if model_choices else None,
274
- allow_custom_value=True,
275
- scale=3,
276
- info="Refresh the page after analyzing new models to update this list."
277
- )
278
- single_btn = gr.Button(
279
- "🎨 Generate Figure", variant="primary", scale=1
280
- )
281
-
282
- single_status = gr.Textbox(
283
- label="Status", lines=3, interactive=False
284
  )
285
 
286
  with gr.Tabs():
287
- with gr.Tab("🖼️ Preview (PNG)"):
288
- single_img = gr.Image(
289
- label="Figure preview",
290
- type="filepath",
291
- height=600,
292
- )
293
  with gr.Tab("📉 Interactive (Plotly)"):
294
- single_plotly = gr.Plot(label="Plotly interactive")
 
 
 
 
 
 
295
 
296
- with gr.Row():
297
- dl_single_png = gr.File(label="⬇ PNG (300 dpi)")
298
- dl_single_pdf = gr.File(label="⬇ PDF (vector)")
299
- dl_single_svg = gr.File(label="⬇ SVG (vector)")
300
- dl_single_zip = gr.File(label="⬇ ZIP (all formats)")
 
 
 
 
 
 
 
 
 
 
 
 
301
 
302
  gr.Markdown("---")
303
 
304
- # ══ Mode 2: Two-model comparison ══════════════════════════════════════
305
  with gr.Accordion("πŸ“Š Two-Model Comparison", open=False):
306
  with gr.Row():
307
  model_a = gr.Dropdown(
308
- label="Model A (solid line)",
309
- choices=model_choices,
310
- value=model_choices[0] if len(model_choices) > 0 else None,
311
  allow_custom_value=True,
312
- scale=2,
313
  )
314
  model_b = gr.Dropdown(
315
- label="Model B (dashed line)",
316
- choices=model_choices,
317
- value=model_choices[1] if len(model_choices) > 1 else None,
318
  allow_custom_value=True,
319
- scale=2,
320
  )
321
  show_delta_chk = gr.Checkbox(
322
- label="Show Δ (B − A) fill",
323
- value=True, scale=1
324
- )
325
- compare_btn = gr.Button(
326
- "🎨 Generate Comparison", variant="primary", scale=1
327
  )
328
 
329
- compare_status = gr.Textbox(
330
- label="Status", lines=3, interactive=False
331
- )
332
-
333
  with gr.Tabs():
334
- with gr.Tab("🖼️ Preview (PNG)"):
335
- compare_img = gr.Image(
336
- label="Comparison figure preview",
337
- type="filepath",
338
- height=600,
339
- )
340
  with gr.Tab("📉 Interactive (Plotly)"):
341
- compare_plotly = gr.Plot(label="Plotly interactive")
 
 
 
 
 
 
342
 
343
- with gr.Row():
344
- dl_cmp_png = gr.File(label="⬇ PNG (300 dpi)")
345
- dl_cmp_pdf = gr.File(label="⬇ PDF (vector)")
346
- dl_cmp_svg = gr.File(label="⬇ SVG (vector)")
347
- dl_cmp_zip = gr.File(label="⬇ ZIP (all formats)")
 
 
 
 
 
 
 
 
 
 
348
 
349
  gr.Markdown("""
350
  ---
351
- **Tips**
352
- - Band = 25%–75% quantile across attention heads per layer.
353
- Narrow band → heads behave consistently → model is "well-organized".
354
- - Vertical dotted lines mark **global layers** (K=V shared, e.g. Gemma-4).
355
- - Dashed horizontal lines = theoretical ideals or random baselines.
356
- - For Law 4 & 5 panels, Q–V and K–V cosU values **below** the random baseline
357
- indicate **super-orthogonality** — a key signature of pretraining convergence.
 
358
  """)
359
 
360
- # ── Wire up single model ──────────────────────────────────────────────
361
- _single_file_state = gr.State([])
362
-
363
- def _run_single(model_id, modality, start, end, band, progress=gr.Progress()):
364
- status, png, paths, plotly_fig = generate_single(
365
- model_id, modality, int(start), int(end), band, progress
366
- )
367
- if paths is None:
368
- return status, None, None, None, None, None, None, []
369
- zip_p = make_zip(paths)
370
- png_p = paths[0] if len(paths) > 0 else None
371
- pdf_p = paths[1] if len(paths) > 1 else None
372
- svg_p = paths[2] if len(paths) > 2 else None
373
- return (status, png, plotly_fig,
374
- png_p, pdf_p, svg_p, zip_p, paths)
375
-
376
- single_btn.click(
377
- fn=_run_single,
378
  inputs=[single_model, modality_sel, start_l, end_l, show_band_chk],
379
- outputs=[
380
- single_status, single_img, single_plotly,
381
- dl_single_png, dl_single_pdf, dl_single_svg, dl_single_zip,
382
- _single_file_state,
383
- ]
384
  )
385
-
386
- # ── Wire up comparison ────────────────────────────────────────────────
387
- _compare_file_state = gr.State([])
388
-
389
- def _run_compare(ma, mb, modality, start, end, band, delta,
390
- progress=gr.Progress()):
391
- status, png, paths, plotly_fig = generate_compare(
392
- ma, mb, modality, int(start), int(end), band, delta, progress
393
- )
394
- if paths is None:
395
- return status, None, None, None, None, None, None, []
396
- zip_p = make_zip(paths)
397
- png_p = paths[0] if len(paths) > 0 else None
398
- pdf_p = paths[1] if len(paths) > 1 else None
399
- svg_p = paths[2] if len(paths) > 2 else None
400
- return (status, png, plotly_fig,
401
- png_p, pdf_p, svg_p, zip_p, paths)
402
-
403
- compare_btn.click(
404
- fn=_run_compare,
405
  inputs=[model_a, model_b, modality_sel,
406
  start_l, end_l, show_band_chk, show_delta_chk],
407
- outputs=[
408
- compare_status, compare_img, compare_plotly,
409
- dl_cmp_png, dl_cmp_pdf, dl_cmp_svg, dl_cmp_zip,
410
- _compare_file_state,
411
- ]
412
  )
 
1
  # ui/tab_plot.py
2
  """
3
  Tab5: Plot — Publication-quality figure generation
4
+ - matplotlib → PNG / PDF / SVG export (300 dpi, paper-ready)
5
+ - Plotly native → interactive browser preview (12×1, full-width, fast)
6
+ Two completely independent rendering paths, no conversion between them.
 
7
  """
8
 
9
  import os
 
10
  import zipfile
11
 
12
  import gradio as gr
13
  import pandas as pd
 
14
 
15
  from db.schema import init_db
16
  from db.reader import get_layer_metrics, get_analyzed_models
17
+ from core.plotter import plot_single_model, plot_compare_models, save_figure
18
+ from core.plotter_plotly import plotly_single, plotly_compare
19
+
 
 
 
 
 
20
  _OUT_DIR = "/tmp/wang_plots"
21
  os.makedirs(_OUT_DIR, exist_ok=True)
22
 
 
29
  try:
30
  conn = init_db()
31
  df = get_analyzed_models(conn)
32
+ return df["model_id"].tolist() if not df.empty else []
 
 
33
  except Exception:
34
  return []
35
 
36
 
37
+ def _load_df(model_id, modality, start_layer, end_layer) -> pd.DataFrame:
 
38
  conn = init_db()
39
+ return get_layer_metrics(
40
  conn,
41
  model_id = model_id,
42
  modality = modality if modality != "all" else None,
 
44
  start_layer = int(start_layer),
45
  end_layer = int(end_layer),
46
  )
 
47
 
48
 
49
  def _infer_dims(df: pd.DataFrame) -> tuple[int, int]:
 
50
  head_dim = 128
51
  d_model = 5120
52
  if not df.empty:
53
+ if "head_dim" in df.columns and df["head_dim"].notna().any():
54
+ head_dim = int(df["head_dim"].dropna().median())
55
+ if "d_model" in df.columns and df["d_model"].notna().any():
56
+ d_model = int(df["d_model"].dropna().median())
 
 
 
 
57
  return head_dim, d_model
58
 
59
 
60
+ def _short(model_id: str) -> str:
61
  return model_id.split("/")[-1] if "/" in model_id else model_id
62
 
63
 
64
+ def _safe_path(tag: str) -> str:
65
+ return os.path.join(_OUT_DIR, tag.replace("/", "_").replace(" ", "_"))
66
+
67
+
68
+ def _make_zip(paths: list) -> str | None:
69
+ valid = [p for p in paths if p and os.path.exists(p)]
70
+ if not valid:
71
+ return None
72
+ zp = os.path.join(_OUT_DIR, "wang_laws_figures.zip")
73
+ with zipfile.ZipFile(zp, "w", zipfile.ZIP_DEFLATED) as zf:
74
+ for p in valid:
75
+ zf.write(p, os.path.basename(p))
76
+ return zp
77
 
78
 
79
  # ─────────────────────────────────────────────────────────────────────────────
80
+ # Single-model handlers
81
  # ─────────────────────────────────────────────────────────────────────────────
82
 
83
+ def gen_single_plotly(model_id, modality, start_l, end_l, show_band,
84
+ progress=gr.Progress()):
85
+ """Fast path: native Plotly, no matplotlib involved."""
86
+ if not model_id:
87
+ return None, "Please select a model."
88
+ progress(0.2, desc="Loading data from DB...")
89
+ df = _load_df(model_id, modality, start_l, end_l)
 
 
 
 
 
 
 
 
 
 
90
  if df.empty:
91
+ return None, f"No data for {model_id}. Run Tab 2 analysis first."
92
+ progress(0.7, desc="Building interactive figure...")
93
+ fig = plotly_single(df, _short(model_id), show_band=show_band)
94
+ status = (
95
+ f"✅ {model_id} | {df['layer'].nunique()} layers "
96
+ f"{len(df)} head-records | modality={modality}"
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  )
98
+ progress(1.0)
99
+ return fig, status
100
 
 
 
 
 
 
 
101
 
102
+ def gen_single_export(model_id, modality, start_l, end_l, show_band,
103
+ progress=gr.Progress()):
104
+ """Export path: matplotlib → PNG/PDF/SVG."""
105
+ if not model_id:
106
+ return "Please select a model.", None, None, None, None, None  # one value per wired output (6)
107
+ progress(0.15, desc="Loading data from DB...")
108
+ df = _load_df(model_id, modality, start_l, end_l)
109
+ if df.empty:
110
+ return f"No data for {model_id}.", None, None, None, None, None  # one value per wired output (6)
111
+ head_dim, d_model = _infer_dims(df)
112
+ progress(0.40, desc="Rendering matplotlib figure (18x20 in, 300 dpi)...")
113
  import matplotlib.pyplot as plt
114
+ fig = plot_single_model(df, _short(model_id),
115
+ show_band=show_band,
116
+ head_dim=head_dim, d_model=d_model)
117
+ progress(0.75, desc="Saving PNG / PDF / SVG...")
118
+ base = _safe_path(f"single_{_short(model_id)}_L{int(start_l)}-{int(end_l)}")
119
+ paths = save_figure(fig, base)
120
  plt.close(fig)
121
+ zip_p = _make_zip(paths)
122
  status = (
123
+ f"Exported: {', '.join(os.path.basename(p) for p in paths)}\n"
124
+ f"head_dim={head_dim} d_model={d_model}"
 
 
 
125
  )
126
+ progress(1.0)
127
+ png = paths[0] if len(paths) > 0 else None
128
+ pdf = paths[1] if len(paths) > 1 else None
129
+ svg = paths[2] if len(paths) > 2 else None
130
+ return status, png, png, pdf, svg, zip_p  # status, preview image, PNG, PDF, SVG, ZIP
 
 
 
 
 
 
 
 
 
 
 
 
 
131
 
 
 
 
 
132
 
133
+ # ─────────────────────────────────────────────────────────────────────────────
134
+ # Compare handlers
135
+ # ─────────────────────────────────────────────────────────────────────────────
136
+
137
+ def gen_compare_plotly(model_a, model_b, modality, start_l, end_l,
138
+ show_band, show_delta, progress=gr.Progress()):
139
+ if not model_a or not model_b:
140
+ return None, "Please select both models."
141
+ if model_a == model_b:
142
+ return None, "Please select two different models."
143
+ progress(0.15, desc="Loading Model A...")
144
+ df_a = _load_df(model_a, modality, start_l, end_l)
145
+ progress(0.35, desc="Loading Model B...")
146
+ df_b = _load_df(model_b, modality, start_l, end_l)
147
  if df_a.empty:
148
+ return None, f"No data for Model A ({model_a})."
149
  if df_b.empty:
150
+ return None, f"No data for Model B ({model_b})."
151
+ progress(0.65, desc="Building interactive comparison figure...")
152
+ fig = plotly_compare(df_a, df_b, _short(model_a), _short(model_b),
153
+ show_band=show_band, show_delta=show_delta)
154
+ status = (
155
+ f"✅ {_short(model_a)} vs {_short(model_b)} | "
156
+ f"A: {len(df_a)} records B: {len(df_b)} records | modality={modality}"
157
+ )
158
+ progress(1.0)
159
+ return fig, status
160
+
161
+
162
+ def gen_compare_export(model_a, model_b, modality, start_l, end_l,
163
+ show_band, show_delta, progress=gr.Progress()):
164
+ if not model_a or not model_b or model_a == model_b:
165
+ return "Select two different models.", None, None, None, None, None  # one value per wired output (6)
166
+ progress(0.10, desc="Loading data...")
167
+ df_a = _load_df(model_a, modality, start_l, end_l)
168
+ df_b = _load_df(model_b, modality, start_l, end_l)
169
+ if df_a.empty or df_b.empty:
170
+ return "Missing data for one or both models.", None, None, None, None, None  # one value per wired output (6)
171
  head_dim_a, d_model_a = _infer_dims(df_a)
172
  head_dim_b, d_model_b = _infer_dims(df_b)
173
+ head_dim = (head_dim_a + head_dim_b) // 2
174
+ d_model = (d_model_a + d_model_b) // 2
175
+ progress(0.40, desc="Rendering matplotlib figure...")
176
+ import matplotlib.pyplot as plt
 
 
177
  fig = plot_compare_models(
178
+ df_a, df_b, _short(model_a), _short(model_b),
179
+ show_band=show_band, show_delta=show_delta,
180
+ head_dim=head_dim, d_model=d_model,
181
+ )
182
+ progress(0.78, desc="Saving PNG / PDF / SVG...")
183
+ base = _safe_path(
184
+ f"compare_{_short(model_a)}_vs_{_short(model_b)}_L{int(start_l)}-{int(end_l)}"
185
  )
 
 
 
186
  paths = save_figure(fig, base)
 
 
 
 
 
187
  plt.close(fig)
188
+ zip_p = _make_zip(paths)
189
  status = (
190
+ f"Exported: {', '.join(os.path.basename(p) for p in paths)}\n"
191
+ f"head_dim≈{head_dim} d_model≈{d_model}"
 
 
 
192
  )
193
+ progress(1.0)
194
+ png = paths[0] if len(paths) > 0 else None
195
+ pdf = paths[1] if len(paths) > 1 else None
196
+ svg = paths[2] if len(paths) > 2 else None
197
+ return status, png, png, pdf, svg, zip_p  # status, preview image, PNG, PDF, SVG, ZIP
 
 
 
 
 
 
 
 
198
 
199
 
200
  # ─────────────────────────────────────────────────────────────────────────────
 
204
  def build_tab_plot():
205
  with gr.Tab("📈 Plot"):
206
  gr.Markdown("""
207
+ ## Wang's Five Laws — Figures
208
+
209
+ Two independent rendering paths:
210
+ | Path | Engine | Speed | Use for |
211
+ |------|--------|-------|---------|
212
+ | **Interactive** | Native Plotly, 12×1 full-width | Fast | Exploration, hover, zoom |
213
+ | **Export** | Matplotlib 18×20 in @ 300 dpi | Slower | Paper submission (PNG/PDF/SVG) |
214
+
215
+ > Run **Tab 2 (Analyze)** first to populate the database.
 
 
 
216
  """)
217
 
218
  # ── Shared controls ───────────────────────────────────────────────────
219
  with gr.Row():
220
  modality_sel = gr.Dropdown(
221
+ ["language", "vision", "audio", "all"],
222
+ value="language", label="Modality", scale=1,
 
 
 
 
 
 
 
 
223
  )
224
+ start_l = gr.Number(value=0, precision=0, label="Start Layer", scale=1)
225
+ end_l = gr.Number(value=47, precision=0, label="End Layer", scale=1)
226
  show_band_chk = gr.Checkbox(
227
+ value=True, label="Show IQR band (head consistency)", scale=1
 
228
  )
229
 
230
  gr.Markdown("---")
231
 
232
+ # ══ Single model ══════════════════════════════════════════════════════
233
  with gr.Accordion("πŸ“Š Single Model", open=True):
234
+ choices = _get_model_choices()
235
+ single_model = gr.Dropdown(
236
+ choices=choices,
237
+ value=choices[0] if choices else None,
238
+ allow_custom_value=True,
239
+ label="Model",
240
+ info="Refresh page after new analysis to update this list.",
 
 
 
 
 
 
 
 
 
241
  )
242
 
243
  with gr.Tabs():
244
+ # ── Interactive ───────────────────────────────────────────────
 
 
 
 
 
245
  with gr.Tab("📉 Interactive (Plotly)"):
246
+ single_plotly_btn = gr.Button(
247
+ "⚡ Generate Interactive Figure", variant="primary"
248
+ )
249
+ single_plotly_status = gr.Textbox(
250
+ lines=1, interactive=False, label="Status"
251
+ )
252
+ single_plotly_fig = gr.Plot(label=None)
253
 
254
+ # ── Export ────────────────────────────────────────────────────
255
+ with gr.Tab("πŸ–¨οΈ Export (PNG / PDF / SVG)"):
256
+ single_export_btn = gr.Button(
257
+ "πŸ–¨οΈ Render & Export (paper-quality, ~30 s)",
258
+ variant="secondary"
259
+ )
260
+ single_export_status = gr.Textbox(
261
+ lines=2, interactive=False, label="Export status"
262
+ )
263
+ single_preview = gr.Image(
264
+ type="filepath", label="PNG preview", height=400
265
+ )
266
+ with gr.Row():
267
+ dl_s_png = gr.File(label="⬇ PNG (300 dpi)")
268
+ dl_s_pdf = gr.File(label="⬇ PDF (vector)")
269
+ dl_s_svg = gr.File(label="⬇ SVG (vector)")
270
+ dl_s_zip = gr.File(label="⬇ ZIP (all)")
271
 
272
  gr.Markdown("---")
273
 
274
+ # ══ Two-model comparison ══════════════════════════════════════════════
275
  with gr.Accordion("πŸ“Š Two-Model Comparison", open=False):
276
  with gr.Row():
277
  model_a = gr.Dropdown(
278
+ choices=choices,
279
+ value=choices[0] if len(choices) > 0 else None,
 
280
  allow_custom_value=True,
281
+ label="Model A (solid line)", scale=2,
282
  )
283
  model_b = gr.Dropdown(
284
+ choices=choices,
285
+ value=choices[1] if len(choices) > 1 else None,
 
286
  allow_custom_value=True,
287
+ label="Model B (dashed line)", scale=2,
288
  )
289
  show_delta_chk = gr.Checkbox(
290
+ value=True, label="Show Δ fill (B − A)", scale=1
 
 
 
 
291
  )
292
 
 
 
 
 
293
  with gr.Tabs():
 
 
 
 
 
 
294
  with gr.Tab("📉 Interactive (Plotly)"):
295
+ cmp_plotly_btn = gr.Button(
296
+ "⚡ Generate Interactive Comparison", variant="primary"
297
+ )
298
+ cmp_plotly_status = gr.Textbox(
299
+ lines=1, interactive=False, label="Status"
300
+ )
301
+ cmp_plotly_fig = gr.Plot(label=None)
302
 
303
+ with gr.Tab("🖨️ Export (PNG / PDF / SVG)"):
304
+ cmp_export_btn = gr.Button(
305
+ "🖨️ Render & Export", variant="secondary"
306
+ )
307
+ cmp_export_status = gr.Textbox(
308
+ lines=2, interactive=False, label="Export status"
309
+ )
310
+ cmp_preview = gr.Image(
311
+ type="filepath", label="PNG preview", height=400
312
+ )
313
+ with gr.Row():
314
+ dl_c_png = gr.File(label="⬇ PNG (300 dpi)")
315
+ dl_c_pdf = gr.File(label="⬇ PDF (vector)")
316
+ dl_c_svg = gr.File(label="⬇ SVG (vector)")
317
+ dl_c_zip = gr.File(label="⬇ ZIP (all)")
318
 
319
  gr.Markdown("""
320
  ---
321
+ **Reading the figures**
322
+ - **IQR band** = 25%–75% quantile across attention heads per layer.
323
+ Narrow band → heads behave consistently → model is well-organized.
324
+ - **Dotted vertical lines** = global (K=V shared) layers (Gemma-4 only).
325
+ - **Dashed horizontal lines** = theoretical ideals (r=1, SSR=0, α=1)
326
+ or random baselines (cosU: 1/√d_head · cosV: 1/√d_model).
327
+ - **Super-orthogonality** (Law 4): cosU(Q–V) and cosU(K–V) sit *below*
328
+ the random baseline — pretraining actively pushes V away from Q/K.
329
  """)
330
 
331
+ # ── Wiring ────────────────────────────────────────────────────────────
332
+ single_plotly_btn.click(
333
+ fn=gen_single_plotly,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
334
  inputs=[single_model, modality_sel, start_l, end_l, show_band_chk],
335
+ outputs=[single_plotly_fig, single_plotly_status],
 
 
 
 
336
  )
337
+ single_export_btn.click(
338
+ fn=gen_single_export,
339
+ inputs=[single_model, modality_sel, start_l, end_l, show_band_chk],
340
+ outputs=[single_export_status, single_preview,
341
+ dl_s_png, dl_s_pdf, dl_s_svg, dl_s_zip],
342
+ )
343
+ cmp_plotly_btn.click(
344
+ fn=gen_compare_plotly,
345
+ inputs=[model_a, model_b, modality_sel,
346
+ start_l, end_l, show_band_chk, show_delta_chk],
347
+ outputs=[cmp_plotly_fig, cmp_plotly_status],
348
+ )
349
+ cmp_export_btn.click(
350
+ fn=gen_compare_export,
 
 
 
 
 
 
351
  inputs=[model_a, model_b, modality_sel,
352
  start_l, end_l, show_band_chk, show_delta_chk],
353
+ outputs=[cmp_export_status, cmp_preview,
354
+ dl_c_png, dl_c_pdf, dl_c_svg, dl_c_zip],
 
 
 
355
  )
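
The commit notes describe the IQR band as a single `go.Scatter` trace with `fill="toself"`, but `core/plotter_plotly.py` is not part of this diff. The sketch below only illustrates how such a closed band polygon could be assembled per panel. The helper name `iqr_band_polygon`, its inputs, and the use of `statistics.quantiles` with linear interpolation are assumptions for illustration, not the repository's actual code.

```python
from statistics import quantiles


def iqr_band_polygon(layers, values_per_layer):
    """Return (x, y) tracing the 25%-75% band as one closed polygon.

    layers: layer indices in ascending order.
    values_per_layer: one list of per-head metric values per layer
                      (hypothetical input shape, at least 2 values each).
    """
    lower, upper = [], []
    for vals in values_per_layer:
        # n=4 yields the three quartile cut points: Q1, median, Q3.
        q1, _median, q3 = quantiles(vals, n=4, method="inclusive")
        lower.append(q1)
        upper.append(q3)
    # Walk the upper edge left to right, then the lower edge right to left,
    # so a single trace can fill the enclosed area.
    x = list(layers) + list(reversed(layers))
    y = upper + list(reversed(lower))
    return x, y


x, y = iqr_band_polygon([0, 1, 2], [[1, 2, 3, 4, 5], [2, 2, 2, 2], [0, 10]])
print(x)
print(y)
```

The resulting `x` and `y` lists could then be passed to one `go.Scatter(x=x, y=y, fill="toself", mode="lines")` trace per panel, so the band costs one trace rather than a separate upper and lower trace.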