Ex0bit committed on
Commit f2ae1f5 · verified · 1 parent: b87ded2

initial qwen-scope-live deploy

Dockerfile ADDED
@@ -0,0 +1,40 @@
+ # Qwen-Scope Live SAE Feature Steering — HF Space Docker SDK image.
+ # Free tier (CPU, ~16GB RAM) — locked to Qwen3-1.7B-Base only.
+ FROM python:3.11-slim
+
+ # Avoid interactive prompts during installs
+ ENV DEBIAN_FRONTEND=noninteractive \
+     PYTHONUNBUFFERED=1 \
+     PYTHONDONTWRITEBYTECODE=1 \
+     PIP_NO_CACHE_DIR=1 \
+     HF_HUB_DISABLE_TELEMETRY=1 \
+     TRANSFORMERS_NO_ADVISORY_WARNINGS=1
+
+ # HF Spaces convention: /home/user is writable, expects PORT=7860
+ ENV HOME=/home/user \
+     PATH=/home/user/.local/bin:$PATH \
+     HF_HOME=/home/user/.cache/huggingface \
+     PORT=7860
+
+ # Add a non-root user (HF Spaces best practice)
+ RUN useradd -m -u 1000 user && \
+     apt-get update && apt-get install -y --no-install-recommends \
+     build-essential ca-certificates && \
+     rm -rf /var/lib/apt/lists/*
+
+ USER user
+ WORKDIR $HOME/app
+
+ # Install Python deps (CPU-only torch from PyTorch index)
+ COPY --chown=user:user requirements.txt .
+ RUN pip install --user --no-cache-dir -r requirements.txt
+
+ # Copy application files
+ COPY --chown=user:user qwen_scope_steer.py qwen_scope_obs.py server.py index.html ./
+
+ # Pre-create cache dir (writable for the cached SAE positions JSON)
+ RUN mkdir -p $HOME/app/feature_positions
+
+ EXPOSE 7860
+
+ CMD ["python", "server.py"]
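
The "CPU-only torch from PyTorch index" comment implies `requirements.txt` points pip at the CPU wheel index; the actual pins are not part of this commit, so the following is only a plausible sketch of that pattern (package names and the absence of version pins are assumptions):

```text
--extra-index-url https://download.pytorch.org/whl/cpu
torch
transformers
safetensors
huggingface_hub
```

pip honors `--extra-index-url` inside a requirements file, which keeps the Dockerfile's single `pip install -r requirements.txt` line unchanged while pulling the smaller CPU-only torch wheel.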
README.md CHANGED
@@ -1,10 +1,57 @@
  ---
- title: Qwen Scope Live
- emoji:
+ title: "Qwen-Scope: Decoding Intelligence, Unleashing Potential"
+ emoji: 🔬
  colorFrom: blue
- colorTo: gray
+ colorTo: pink
  sdk: docker
+ app_port: 7860
  pinned: false
+ license: apache-2.0
+ short_description: Live SAE feature steering for Qwen3-1.7B
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Qwen-Scope: Decoding Intelligence, Unleashing Potential
+
+ Live, interactive demo of the [Qwen-Scope](https://huggingface.co/Qwen/SAE-Res-Qwen3-1.7B-Base-W32K-L0_50) sparse-autoencoder release. Type a prompt, see which SAE features fire, drag a slider to steer them, and watch the model's actual generated text change in real time — all running in your browser against `Qwen/Qwen3-1.7B-Base` + the layer-14 (and layer-selectable) Qwen-Scope SAE.
+
+ This Space implements **three of the four pillars** from the Qwen-Scope paper:
+
+ - **Steering** — interpretable feature-level control at inference time (per-token heatmap, position-selective steering, output-only steering, per-token probability panel with click-to-pin top-K candidates)
+ - **Evaluation** — capability-aware corpus analysis (paste a prompt set, encode each, see which features fire across the corpus; compare two prompt sets for distinguishing features)
+ - **Data-centric** — feature-guided curation (filter docs by feature signature) and steered synthesis (bulk-generate steered completions from seed prompts)
+
+ (The fourth pillar — post-training pathology diagnosis — is not yet implemented.)
+
+ ## How to use
+
+ 1. **Wait ~3-5 minutes** on first load — the container has to download the model (~3.4GB) and SAE (~537MB) on cold start.
+ 2. **Steering tab**: type a prompt, click *Encode*, click *steer* on any top feature, drag the slider, click *Generate*. The output panels show baseline vs. steered side-by-side as colored token chips — click any chip to see top-K next-token candidates. A per-token feature heatmap appears in the right column showing which features fire on which tokens.
+ 3. **Evaluation tab**: paste many prompts (one per line), encode the corpus, see per-sample top features, a signature heatmap, and a Compare panel for finding features that distinguish set A from set B.
+ 4. **Data-centric tab**: paste a corpus, filter it by feature signature (include/exclude docs where feature X fires), and run bulk steered synthesis to produce N targeted training examples.
+ 5. **Layer selector** in the header: change the SAE layer (0–27) to see how features differ across model depth. The first swap to a layer takes ~20s; subsequent swaps are <0.1s thanks to an in-memory LRU cache.
+
+ ## Hardware constraint
+
+ This Space is **locked to Qwen3-1.7B-Base** because that's what fits in a free-tier HF Space (CPU, 16GB RAM). The full local version supports the entire Qwen-Scope catalog including Qwen3-8B, Qwen3.5-9B/27B, Qwen3.6-27B/35B-A3B, and Qwen3-30B-A3B — each with its matching SAE — but those need GPU hardware and persistent storage. See the source repository for instructions.
+
+ ## What's verified
+
+ This deployment is the same code that was [verified end-to-end on macOS / Apple Silicon MPS](https://huggingface.co/Qwen/SAE-Res-Qwen3-1.7B-Base-W32K-L0_50) — TopK SAE encode, residual hook, additive steering, per-token probability extraction, position-selective hooks across prefill + decode steps. The math is identical; only the device (CPU instead of MPS) and dtype (fp32 instead of bf16) differ in the deployment build.
+
+ ## License
+
+ Apache 2.0 (this code). The Qwen-Scope SAE weights and Qwen3 model weights are governed by their respective Qwen licenses.
+
+ ## Citation
+
+ If you use this in your work, cite the Qwen-Scope paper:
+
+ ```bibtex
+ @misc{qwen_scope,
+   title = {{Qwen-Scope}: Turning Sparse Features into Development Tools for Large Language Models},
+   url = {https://qianwen-res.oss-accelerate.aliyuncs.com/qwen-scope/Qwen_Scope.pdf},
+   author = {{Qwen Team}},
+   month = {April},
+   year = {2026}
+ }
+ ```
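
The steering rule the README describes — adding `α · W_dec[:, feat]` to the residual stream during generation — can be sketched as a PyTorch forward hook. This is a minimal illustration, not the Space's actual implementation in `qwen_scope_steer.py` (the helper name `make_steering_hook` and the toy `Identity` layer are made up for the example):

```python
import torch

def make_steering_hook(w_dec_col: torch.Tensor, alpha: float):
    """Forward hook that adds alpha * (an SAE decoder direction) to a layer's output."""
    def hook(module, inputs, output):
        # Transformer decoder layers often return a tuple; hidden states come first.
        hidden = output[0] if isinstance(output, tuple) else output
        steered = hidden + alpha * w_dec_col.to(hidden.dtype)
        if isinstance(output, tuple):
            return (steered,) + output[1:]
        return steered
    return hook

# Toy check on a stand-in "layer": an Identity module with d_model = 4.
layer = torch.nn.Identity()
direction = torch.tensor([1.0, 0.0, 0.0, 0.0])   # stand-in for W_dec[:, feat]
handle = layer.register_forward_hook(make_steering_hook(direction, alpha=2.0))
out = layer(torch.zeros(1, 3, 4))                # (batch, seq, d_model)
handle.remove()                                  # always detach hooks when done
```

Because the hook returns a value, PyTorch substitutes it for the layer's output, so the same hook applies during both prefill and decode steps of `model.generate` without touching model code.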
__pycache__/qwen_scope_steer.cpython-314.pyc ADDED
Binary file (19.9 kB).
 
__pycache__/server.cpython-314.pyc ADDED
Binary file (46.5 kB).
 
index.html ADDED
@@ -0,0 +1,1564 @@
+ <!doctype html>
+ <html lang="en">
+ <head>
+ <meta charset="utf-8" />
+ <meta name="viewport" content="width=device-width, initial-scale=1" />
+ <title>Qwen-Scope · Live SAE Feature Steering</title>
+ <script src="https://cdnjs.cloudflare.com/ajax/libs/three.js/r128/three.min.js"></script>
+ <style>
+ :root {
+ --bg: #08080d;
+ --bg-panel: rgba(18, 18, 26, 0.78);
+ --bg-panel-strong: rgba(24, 24, 34, 0.92);
+ --border: rgba(255,255,255,0.08);
+ --border-strong: rgba(255,255,255,0.16);
+ --fg: #ececf1;
+ --fg-dim: #9b9ba8;
+ --fg-faint: #5e5e6e;
+ --accent: #7df9ff;
+ --accent-2: #ff7df9;
+ --warn: #ffb84d;
+ --good: #74e2a3;
+ --bad: #ff7d7d;
+ --mono: ui-monospace, "SF Mono", "JetBrains Mono", Menlo, monospace;
+ --sans: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
+ }
+ * { box-sizing: border-box; }
+ html, body { margin:0; padding:0; height:100%; background:var(--bg); color:var(--fg); font-family:var(--sans); overflow:hidden; }
+ #scene { position:fixed; inset:0; z-index:0; }
+ #ui { position:fixed; inset:0; z-index:10; pointer-events:none; display:grid;
+ grid-template-columns: 380px 1fr 420px;
+ grid-template-rows: 1fr auto;
+ gap:14px; padding:14px; }
+ #ui > section { pointer-events:auto; }
+
+ .panel { background:var(--bg-panel); border:1px solid var(--border);
+ border-radius:14px; padding:14px; backdrop-filter: blur(14px) saturate(140%);
+ -webkit-backdrop-filter: blur(14px) saturate(140%); overflow:hidden;
+ transition: padding .15s ease, max-height .25s ease; }
+ .panel h2 { margin:0 0 10px; font-size:11px; font-weight:600;
+ text-transform:uppercase; letter-spacing:0.14em; color:var(--fg-dim);
+ display:flex; justify-content:space-between; align-items:center; gap:8px; }
+ .panel h3 { margin:0 0 6px; font-size:13px; font-weight:600; color:var(--fg); }
+
+ /* Minimize button on each panel header */
+ .min-btn { background:transparent; border:1px solid var(--border-strong);
+ color:var(--fg-dim); border-radius:5px; padding:1px 8px; font-size:14px;
+ font-family:var(--mono); cursor:pointer; line-height:1; min-width:0;
+ font-weight:400; letter-spacing:0; }
+ .min-btn:hover { background:var(--bg-panel-strong); color:var(--fg);
+ border-color:var(--accent); transform:none; box-shadow:none; }
+ .panel.collapsed { padding:8px 14px; }
+ .panel.collapsed > h2 { margin-bottom:0; }
+ .panel.collapsed > :not(h2) { display:none; }
+
+ /* Header strip */
+ #header { position:fixed; top:14px; left:50%; transform:translateX(-50%);
+ z-index:20; pointer-events:auto;
+ background:var(--bg-panel-strong); border:1px solid var(--border-strong);
+ border-radius:999px; padding:8px 18px; display:flex; gap:14px;
+ align-items:center; font-family:var(--mono); font-size:12px; color:var(--fg-dim);
+ backdrop-filter: blur(20px); }
+ #header .dot { width:8px; height:8px; border-radius:50%; background:var(--bad); transition:background .3s; }
+ #header.live .dot { background:var(--good); box-shadow:0 0 12px var(--good); }
+ #header.loading .dot { background:var(--warn); box-shadow:0 0 12px var(--warn); animation:pulse 1.5s ease-in-out infinite; }
+ @keyframes pulse { 0%,100% { opacity:1; } 50% { opacity:0.4; } }
+ #header b { color:var(--fg); font-weight:600; }
+ #model-select {
+ background:transparent; color:var(--fg); border:1px solid var(--border-strong);
+ border-radius:6px; padding:3px 8px; font-family:var(--mono); font-size:11px;
+ outline:none; cursor:pointer; max-width:280px;
+ }
+ #model-select:focus { border-color:var(--accent); }
+ #model-select option { background:#16161e; color:var(--fg); }
+
+ /* Tab strip */
+ #tabs { position:fixed; top:62px; left:50%; transform:translateX(-50%); z-index:18;
+ display:flex; gap:6px; pointer-events:auto; background:var(--bg-panel-strong);
+ border:1px solid var(--border-strong); border-radius:999px;
+ padding:4px; backdrop-filter: blur(20px); }
+ .tab { background:transparent; border:none; color:var(--fg-dim);
+ padding:6px 16px; font-size:11px; font-weight:600;
+ text-transform:uppercase; letter-spacing:0.08em; border-radius:999px;
+ cursor:pointer; transition: all 0.15s ease; }
+ .tab:hover { color:var(--fg); background:rgba(255,255,255,0.04); }
+ .tab.active { background: linear-gradient(135deg, rgba(125,249,255,0.18), rgba(255,125,249,0.12));
+ color:var(--fg); border:1px solid var(--border-strong); }
+
+ /* Tab visibility — hide panels not matching the current tab */
+ body[data-tab="steering"] [data-show-tab]:not([data-show-tab~="steering"]) { display:none !important; }
+ body[data-tab="evaluation"] [data-show-tab]:not([data-show-tab~="evaluation"]) { display:none !important; }
+ body[data-tab="datacentric"] [data-show-tab]:not([data-show-tab~="datacentric"]) { display:none !important; }
+
+ /* Loading overlay shown during model swap */
+ #loading-overlay {
+ position:fixed; inset:0; z-index:50; display:none;
+ background:rgba(8,8,13,0.78); backdrop-filter: blur(6px);
+ align-items:center; justify-content:center;
+ }
+ #loading-overlay.visible { display:flex; }
+ #loading-overlay .lo-card {
+ background:var(--bg-panel-strong); border:1px solid var(--border-strong);
+ border-radius:14px; padding:24px 32px; text-align:center; min-width:300px;
+ }
+
+ /* Left: prompt + controls (top-left of screen, below the header) */
+ #left { display:flex; flex-direction:column; gap:14px; min-height:0;
+ grid-column: 1; grid-row: 1; align-self: start;
+ padding-top: 110px; }
+ textarea, input[type="text"], input[type="number"] {
+ width:100%; background:rgba(0,0,0,0.35); border:1px solid var(--border-strong);
+ color:var(--fg); border-radius:8px; padding:10px 12px; font-family:var(--mono);
+ font-size:13px; outline:none; resize:vertical;
+ }
+ textarea:focus, input:focus { border-color:var(--accent); box-shadow:0 0 0 3px rgba(125,249,255,0.12); }
+ textarea { min-height:72px; }
+ .row { display:flex; gap:8px; align-items:center; }
+ .row > * { flex:1; }
+ .row > .grow0 { flex:0; }
+
+ button { background: linear-gradient(135deg, rgba(125,249,255,0.18), rgba(255,125,249,0.12));
+ border:1px solid var(--border-strong); color:var(--fg); border-radius:8px;
+ padding:9px 14px; font-family:var(--sans); font-size:13px; font-weight:600;
+ cursor:pointer; transition:all .15s; letter-spacing:0.02em; }
+ button:hover:not(:disabled) { border-color:var(--accent); transform:translateY(-1px); box-shadow:0 4px 14px rgba(125,249,255,0.18); }
+ button:disabled { opacity:0.4; cursor:not-allowed; }
+ button.primary { background: linear-gradient(135deg, #7df9ff, #ff7df9); color:#0a0a14; border-color:transparent; }
+ button.primary:hover:not(:disabled) { box-shadow:0 4px 18px rgba(255,125,249,0.32); }
+ button.ghost { background:transparent; }
+
+ /* Middle: nothing (3D scene shows through) */
+ #middle { grid-column: 2; grid-row: 1; pointer-events:none; }
+
+ /* Bottom: output band spans all columns; sits below cloud area */
+ #bottom { grid-column: 1 / -1; grid-row: 2; pointer-events:auto;
+ display:grid; grid-template-columns: minmax(0, 1fr) minmax(0, 1fr);
+ gap:14px;
+ max-height: 38vh; min-height: 0;
+ margin: 0 0 0 0; }
+ #bottom .panel { display:flex; flex-direction:column;
+ min-width: 0; min-height: 0; overflow:hidden; }
+ #bottom .panel.collapsed { padding:8px 14px; }
+ #bottom .panel.collapsed > .panel-body { display:none; }
+ .panel-body { flex: 1 1 auto; overflow:auto; min-height:0; min-width:0; }
+ #right .panel { min-width:0; max-width:100%; }
+ #right .panel > div { max-width:100%; overflow:auto; }
+
+ /* Right: top features panel only — output moved to bottom band */
+ #right { display:flex; flex-direction:column; gap:14px; min-height:0;
+ grid-column: 3; grid-row: 1; max-height: calc(100vh - 28px); padding-top:110px; }
+
+ #features { flex: 1 1 auto; min-height: 120px; overflow-y:auto; }
+ #features::-webkit-scrollbar { width:6px; }
+ #features::-webkit-scrollbar-thumb { background:var(--border-strong); border-radius:3px; }
+
+ .feat { padding:10px; border:1px solid var(--border); border-radius:10px;
+ margin-bottom:8px; background:rgba(0,0,0,0.25); transition: border-color .2s, background .2s; }
+ .feat:hover { border-color:var(--accent); }
+ .feat.steered { border-color:var(--accent-2); background:rgba(255,125,249,0.06); }
+ .feat-head { display:flex; justify-content:space-between; align-items:center; margin-bottom:6px; }
+ .feat-id { font-family:var(--mono); font-size:12px; color:var(--fg); }
+ .feat-act { font-family:var(--mono); font-size:11px; color:var(--good); }
+ .feat-act.steered { color:var(--accent-2); }
+ .slider-row { display:flex; align-items:center; gap:8px; margin-top:4px; }
+ .slider-row input[type=range] { flex:1; -webkit-appearance:none; height:4px; background:var(--border-strong); border-radius:2px; outline:none; }
+ .slider-row input[type=range]::-webkit-slider-thumb {
+ -webkit-appearance:none; width:14px; height:14px; border-radius:50%;
+ background: linear-gradient(135deg, #7df9ff, #ff7df9); cursor:pointer;
+ box-shadow:0 0 8px rgba(125,249,255,0.5); border:none;
+ }
+ .slider-row input[type=range]::-moz-range-thumb {
+ width:14px; height:14px; border-radius:50%;
+ background: linear-gradient(135deg, #7df9ff, #ff7df9); cursor:pointer; border:none;
+ }
+ .slider-row .alpha-val { font-family:var(--mono); font-size:11px; color:var(--fg-dim); width:48px; text-align:right; }
+ .feat-tools { display:flex; gap:6px; }
+ .feat-tools button { padding:3px 8px; font-size:10px; font-weight:500; }
+
+ /* Output area */
+ #output-wrap { flex:1; overflow-y:auto; }
+ #output-wrap::-webkit-scrollbar { width:6px; }
+ #output-wrap::-webkit-scrollbar-thumb { background:var(--border-strong); border-radius:3px; }
+ .out-block { font-family:var(--mono); font-size:12px; line-height:1.55;
+ background:rgba(0,0,0,0.4); border:1px solid var(--border);
+ border-radius:8px; padding:10px; margin-bottom:8px; white-space:pre-wrap;
+ word-break: break-word; color:var(--fg); min-height:36px; }
+ .out-block.baseline { border-color:rgba(125,249,255,0.3); }
+ .out-block.steered { border-color:rgba(255,125,249,0.3); }
+ .out-label { display:flex; justify-content:space-between; align-items:center; margin-bottom:4px;
+ font-size:10px; text-transform:uppercase; letter-spacing:0.12em; color:var(--fg-dim); }
+ .verifier { font-family:var(--mono); font-size:10px; color:var(--fg-dim); margin-top:6px; padding:6px 8px;
+ border-left:2px solid var(--border-strong); }
+ .verifier .delta-up { color:var(--good); }
+ .verifier .delta-down { color:var(--accent-2); }
+
+ /* Hover tooltip for selected feature on 3D cloud */
+ #tooltip { position:fixed; z-index:25; pointer-events:none;
+ background:var(--bg-panel-strong); border:1px solid var(--border-strong);
+ border-radius:8px; padding:8px 12px; font-family:var(--mono); font-size:11px;
+ line-height:1.45;
+ color:var(--fg); display:none; transform: translate(-50%, calc(-100% - 12px));
+ min-width:180px; max-width:280px; backdrop-filter: blur(20px); }
+
+ /* Loader */
+ .loader { display:inline-block; width:12px; height:12px; border:2px solid var(--border-strong);
+ border-top-color:var(--accent); border-radius:50%; animation:spin 0.8s linear infinite; }
+ @keyframes spin { to { transform: rotate(360deg); } }
+
+ .empty { color:var(--fg-faint); font-size:12px; text-align:center; padding:24px 8px; }
+
+ .hint { font-size:11px; color:var(--fg-faint); margin-top:6px; line-height:1.4; }
+
+ .footer { position:fixed; bottom:8px; left:14px; z-index:20; font-family:var(--mono);
+ font-size:10px; color:var(--fg-faint); letter-spacing:0.04em; }
+
+ .legend { font-family:var(--mono); font-size:10px; color:var(--fg-faint); display:flex; gap:12px; margin-top:6px; }
+ .legend span::before { content:''; display:inline-block; width:8px; height:8px; border-radius:50%; margin-right:5px; vertical-align:middle; }
+ .legend .top::before { background:var(--accent); box-shadow:0 0 6px var(--accent); }
+ .legend .pick::before { background:var(--accent-2); box-shadow:0 0 6px var(--accent-2); }
+ .legend .dim::before { background:var(--fg-faint); }
+
+ /* small-screen graceful fallback */
+ @media (max-width: 1100px) {
+ #ui { grid-template-columns: 1fr; grid-template-rows: auto auto; }
+ #left { grid-column:1; grid-row:1; max-height:none; }
+ #right { grid-column:1; grid-row:2; padding-top:14px; max-height:none; }
+ #middle { display:none; }
+ }
+ </style>
+ </head>
+ <body>
+ <canvas id="scene"></canvas>
+
+ <div id="header">
+ <span class="dot"></span>
+ <span>QWEN-SCOPE LIVE</span>
+ <span style="opacity:.4">·</span>
+ <select id="model-select" title="Switch the loaded model + SAE pair">
+ <option>connecting…</option>
+ </select>
+ <span style="opacity:.4">·</span>
+ <span id="hdr-layer-wrap" title="Click to change SAE layer">
+ layer <input type="number" id="layer-input" value="?" min="0" max="0"
+ style="width:48px; padding:1px 4px; background:transparent; color:var(--accent);
+ border:1px solid var(--border-strong); border-radius:4px;
+ font-family:var(--mono); font-size:12px; text-align:center;" />
+ <span id="hdr-layer-meta">· mps · bfloat16</span>
+ </span>
+ <span style="opacity:.4">·</span>
+ <span id="hdr-features">— features</span>
+ </div>
+
+ <div id="tabs">
+ <button class="tab active" data-tab="steering">Steering</button>
+ <button class="tab" data-tab="evaluation">Evaluation</button>
+ <button class="tab" data-tab="datacentric">Data-centric</button>
+ </div>
+
+ <div id="loading-overlay">
+ <div class="lo-card">
+ <div class="loader" style="width:24px;height:24px;border-width:3px;"></div>
+ <div id="lo-msg" style="margin-top:14px; font-size:13px; color:var(--fg);">Loading…</div>
+ <div id="lo-detail" style="margin-top:6px; font-size:11px; color:var(--fg-faint); font-family:var(--mono);"></div>
+ </div>
+ </div>
+
+ <div id="ui">
+ <section id="left">
+ <div class="panel" data-pid="prompt" data-show-tab="steering">
+ <h2>1 · Prompt <button class="min-btn" data-min="prompt" title="Minimize panel">−</button></h2>
+ <textarea id="prompt" placeholder="The capital of France is">The capital of France is</textarea>
+ <div class="row" style="margin-top:8px;">
+ <label style="font-size:11px; color:var(--fg-dim); flex:0 0 auto;">top K</label>
+ <input type="number" id="top-k" value="20" min="1" max="500" style="flex:0 0 80px;" />
+ <button class="primary" id="btn-encode" style="flex:1;">Encode &amp; show top features</button>
+ </div>
+ <div class="hint">
+ Encodes the prompt's last-token residual at the SAE's layer and ranks the firing features.
+ TopK SAE has at most K=50 nonzero features per token, so request up to ~50.
+ </div>
+ </div>
+ <div class="panel" data-pid="generate" data-show-tab="steering">
+ <h2>2 · Generate <button class="min-btn" data-min="generate" title="Minimize panel">−</button></h2>
+ <div class="row">
+ <label style="font-size:11px; color:var(--fg-dim); flex:0 0 auto;">tokens</label>
+ <input type="number" id="max-tokens" value="40" min="5" max="200" style="flex:0 0 80px;" />
+ <button class="primary" id="btn-generate" disabled>Generate baseline + steered</button>
+ </div>
+ <div class="hint">
+ Runs <code>model.generate</code> twice: once with no hooks, once with all active sliders applied
+ as residual-stream additions <code>h ← h + α · W_dec[:, feat]</code>.
+ </div>
+ </div>
+ <div class="panel" data-pid="viz" data-show-tab="steering evaluation datacentric" style="flex:0 0 auto;">
+ <h2>3 · Visualization <button class="min-btn" data-min="viz" title="Minimize panel">−</button></h2>
+ <div class="legend">
+ <span class="dim">all 32K features</span>
+ <span class="top">top firing</span>
+ <span class="pick">steered</span>
+ </div>
+ <div class="hint">
+ Each point is one SAE feature; positions are the top-3 PCA components of <code>W_enc</code>.
+ Click a point to add it as a steering target.
+ </div>
+ </div>
+ </section>
+
+ <section id="middle"></section>
+
+ <section id="right">
+ <div class="panel" id="features-panel" data-pid="features" data-show-tab="steering"
+ style="flex: 0 1 auto; max-height: 38vh; display:flex; flex-direction:column; min-height:0;">
+ <h2>Top features <span id="feat-count" style="font-family:var(--mono); color:var(--fg-faint);"></span>
+ <button class="min-btn" data-min="features" title="Minimize panel">−</button></h2>
+ <div id="features" style="overflow-y:auto; flex:1 1 auto; min-height:0;"><div class="empty">Encode a prompt to populate features.</div></div>
+ </div>
+
+ <div class="panel" id="heatmap-panel" data-show-tab="steering"
+ style="flex: 1 1 auto; display:flex; flex-direction:column; min-height:200px;">
+ <h2>Per-token feature heatmap
+ <span style="display:flex; gap:8px; align-items:center;">
+ <label style="font-size:10px; color:var(--fg-faint);">
+ <input type="checkbox" id="heatmap-skip-first" style="vertical-align:middle;" /> skip first
+ </label>
+ <button class="min-btn" data-min="heatmap-panel" title="Minimize panel">−</button>
+ </span>
+ </h2>
+ <div id="heatmap-grid" style="overflow:auto; flex:1 1 auto; min-height:0;">
+ <span class="empty">Encode a prompt — heatmap fills automatically.</span>
+ </div>
+ </div>
+
+ <!-- Evaluation: corpus encoding + signature heatmap -->
+ <div class="panel" data-show-tab="evaluation" id="eval-corpus">
+ <h2>Evaluation · corpus encode
+ <button class="min-btn" data-min="eval-corpus" title="Minimize panel">−</button>
+ </h2>
+ <textarea id="eval-prompts" rows="6" placeholder="One prompt per line.&#10;Example:&#10;The capital of France is&#10;Bonjour comment allez-vous&#10;def fibonacci(n):&#10;The mitochondria is the powerhouse"></textarea>
+ <div class="row" style="margin-top:8px;">
+ <button class="primary" id="btn-eval-encode" style="flex:1;">Encode corpus</button>
+ </div>
+ <div class="hint">
+ Encodes each prompt's last-token residual through the SAE. Returns
+ per-sample top features and corpus-level firing rates.
+ </div>
+ </div>
+
+ <div class="panel" data-show-tab="evaluation" id="eval-corpus-features">
+ <h2>Corpus features <span id="eval-stats" style="font-family:var(--mono); color:var(--fg-faint);"></span>
+ <button class="min-btn" data-min="eval-corpus-features" title="Minimize panel">−</button>
+ </h2>
+ <div id="eval-features-list" style="overflow-y:auto; max-height:30vh;">
+ <div class="empty">Encode a corpus to see feature firing rates.</div>
+ </div>
+ </div>
+
+ <div class="panel" data-show-tab="evaluation" id="eval-compare">
+ <h2>Compare two prompt sets
+ <button class="min-btn" data-min="eval-compare" title="Minimize panel">−</button>
+ </h2>
+ <textarea id="cmp-a" rows="3" placeholder="Set A — one prompt per line"></textarea>
+ <textarea id="cmp-b" rows="3" placeholder="Set B — one prompt per line" style="margin-top:6px;"></textarea>
+ <div class="row" style="margin-top:6px;">
+ <button class="primary" id="btn-cmp" style="flex:1;">Find distinguishing features</button>
+ </div>
+ <div class="hint">
+ Encodes both sets and ranks features by |fire_rate(A) − fire_rate(B)|.
+ Shows which features distinguish A from B.
+ </div>
+ </div>
+
+ <!-- Data-centric: filter + steered synthesis -->
+ <div class="panel" data-show-tab="datacentric" id="dc-corpus">
+ <h2>Data-centric · corpus
+ <button class="min-btn" data-min="dc-corpus" title="Minimize panel">−</button>
+ </h2>
+ <textarea id="dc-prompts" rows="5" placeholder="One prompt per line — same shape as Evaluation."></textarea>
+ <div class="row" style="margin-top:8px;">
+ <button class="primary" id="btn-dc-encode" style="flex:1;">Encode for filtering</button>
+ </div>
+ <div class="hint">
+ Encode docs, then filter by feature signature.
+ </div>
+ </div>
+
+ <div class="panel" data-show-tab="datacentric" id="dc-filter">
+ <h2>Filter
+ <button class="min-btn" data-min="dc-filter" title="Minimize panel">−</button>
+ </h2>
+ <div class="row">
+ <label style="flex:0 0 auto; font-size:11px; color:var(--fg-dim);">feature</label>
+ <input type="number" id="dc-filter-id" placeholder="feature id" min="0" max="199999" style="flex:1;" />
+ <select id="dc-filter-mode" style="flex:0 0 auto; background:rgba(0,0,0,0.35); color:var(--fg); border:1px solid var(--border-strong); border-radius:6px; padding:6px 8px; font-family:var(--mono); font-size:11px;">
+ <option value="include">includes</option>
+ <option value="exclude">excludes</option>
+ </select>
+ </div>
+ <div class="row" style="margin-top:8px;">
+ <button id="btn-dc-filter" style="flex:1;">Apply filter</button>
+ <button id="btn-dc-clear" class="ghost" style="flex:0 0 auto;">clear</button>
+ </div>
+ <div class="hint">
+ "Includes" = keep only docs where this feature fired (any activation).
+ "Excludes" = drop docs where it fired.
+ </div>
+ </div>
+
+ <div class="panel" data-show-tab="datacentric" id="dc-synth">
+ <h2>Steered synthesis
+ <button class="min-btn" data-min="dc-synth" title="Minimize panel">−</button>
+ </h2>
+ <div class="row">
+ <label style="flex:0 0 auto; font-size:11px; color:var(--fg-dim);">feature</label>
+ <input type="number" id="dc-synth-id" placeholder="feature id" min="0" max="199999" style="flex:1;" />
+ <label style="flex:0 0 auto; font-size:11px; color:var(--fg-dim);">α</label>
+ <input type="number" id="dc-synth-alpha" value="50" min="-200" max="200" style="flex:0 0 60px;" />
+ </div>
+ <div class="row" style="margin-top:6px;">
+ <label style="flex:0 0 auto; font-size:11px; color:var(--fg-dim);">tokens</label>
+ <input type="number" id="dc-synth-tokens" value="30" min="5" max="200" style="flex:0 0 70px;" />
+ <button class="primary" id="btn-dc-synth" style="flex:1;">Synthesize from corpus</button>
+ </div>
+ <div class="hint">
+ For each prompt in the corpus above, generate a steered completion with
+ <code>h ← h + α · W_dec[:, feature]</code>. Useful for producing
+ targeted training examples that fire a specific feature.
+ </div>
+ </div>
+ </section>
+
+ <!-- Bottom band -->
+ <section id="bottom">
+ <!-- STEERING bottom: baseline / steered / heatmap -->
+ <div class="panel" data-show-tab="steering" data-pid="baseline">
+ <h2>Baseline output (α=0)
+ <span style="display:flex; gap:8px; align-items:center;">
+ <span id="base-time" style="font-family:var(--mono); color:var(--fg-faint);"></span>
+ <button class="min-btn" data-min="baseline" title="Minimize panel">−</button>
+ </span>
+ </h2>
+ <div class="panel-body">
+ <div class="out-block baseline" id="out-baseline"><span class="empty">(no run yet)</span></div>
+ </div>
+ </div>
+ <div class="panel" data-show-tab="steering" data-pid="steered">
+ <h2>Steered output
+ <span style="display:flex; gap:8px; align-items:center;">
+ <span id="steered-time" style="font-family:var(--mono); color:var(--fg-faint);"></span>
+ <button class="min-btn" data-min="steered" title="Minimize panel">−</button>
+ </span>
+ </h2>
+ <div class="panel-body">
+ <div class="out-block steered" id="out-steered"><span class="empty">(no run yet)</span></div>
+ <div class="verifier" id="verifier"></div>
+ </div>
+ </div>
+
+ <!-- EVALUATION bottom: per-sample table + signature heatmap -->
458
+ <div class="panel" data-show-tab="evaluation" id="eval-samples-panel" style="grid-column: 1 / 2;">
459
+ <h2>Per-sample top features
460
+ <button class="min-btn" data-min="eval-samples-panel" title="Minimize panel">−</button>
461
+ </h2>
462
+ <div class="panel-body">
463
+ <div id="eval-samples"><span class="empty">Encode a corpus to populate.</span></div>
464
+ </div>
465
+ </div>
466
+ <div class="panel" data-show-tab="evaluation" id="eval-heatmap-panel" style="grid-column: 2 / 3;">
467
+ <h2>Signature heatmap / Compare results
468
+ <span style="display:flex; gap:6px; align-items:center;">
469
+ <button id="btn-eval-view-heatmap" class="ghost" style="padding:2px 8px; font-size:10px;">heatmap</button>
470
+ <button id="btn-eval-view-compare" class="ghost" style="padding:2px 8px; font-size:10px;">compare</button>
471
+ <button class="min-btn" data-min="eval-heatmap-panel" title="Minimize panel">−</button>
472
+ </span>
473
+ </h2>
474
+ <div class="panel-body">
475
+ <div id="eval-heatmap"><span class="empty">Encode a corpus to populate.</span></div>
476
+ <div id="eval-compare-results" style="display:none;"><span class="empty">Run Compare to populate.</span></div>
477
+ </div>
478
+ </div>
479
+
480
+ <!-- DATA-CENTRIC bottom: filtered list + synth output -->
481
+ <div class="panel" data-show-tab="datacentric" id="dc-filtered-panel">
482
+ <h2>Filtered docs <span id="dc-filter-stats" style="font-family:var(--mono); color:var(--fg-faint);"></span>
483
+ <button class="min-btn" data-min="dc-filtered-panel" title="Minimize panel">−</button>
484
+ </h2>
485
+ <div class="panel-body">
486
+ <div id="dc-filtered-list"><span class="empty">Encode + apply filter to populate.</span></div>
487
+ </div>
488
+ </div>
489
+ <div class="panel" data-show-tab="datacentric" id="dc-synth-panel">
490
+ <h2>Synthesis output <span id="dc-synth-stats" style="font-family:var(--mono); color:var(--fg-faint);"></span>
491
+ <button class="min-btn" data-min="dc-synth-panel" title="Minimize panel">−</button>
492
+ </h2>
493
+ <div class="panel-body">
494
+ <div id="dc-synth-list"><span class="empty">Run "Synthesize from corpus" to populate.</span></div>
495
+ </div>
496
+ </div>
497
+ </section>
498
+
499
+ </div>
500
+
501
+ <div id="tooltip"></div>
502
+
503
+ <div class="footer">
504
+ <span id="status">idle</span>
505
+ </div>
506
+
507
+ <script>
+ const API = ""; // same-origin: HF Space serves API + HTML from one process
+ const N_FEATURES = 32768;
+
+ // --- State ---
+ const state = {
+ positions: null, // Float32Array length N*3
+ topFeatures: [], // [{id, act}, ...] from /encode
+ steering: new Map(), // feature_id -> alpha
+ picked: null, // currently hovered feature_id
+ };
+
+ // --- Three.js scene ---
+ const scene = new THREE.Scene();
+ scene.fog = new THREE.FogExp2(0x08080d, 0.08);
+ const camera = new THREE.PerspectiveCamera(55, window.innerWidth / window.innerHeight, 0.01, 60);
+ camera.position.set(0, 0, 3.4);
+ const renderer = new THREE.WebGLRenderer({
+ canvas: document.getElementById("scene"),
+ antialias: true,
+ alpha: true,
+ });
+ renderer.setPixelRatio(Math.min(window.devicePixelRatio, 2));
+ renderer.setSize(window.innerWidth, window.innerHeight);
+ renderer.setClearColor(0x000000, 0);
+
+ // Subtle starry haze background — many tiny dim points
+ function addStarHaze() {
+ const N = 600;
+ const g = new THREE.BufferGeometry();
+ const pos = new Float32Array(N*3);
+ for (let i=0; i<N; i++) {
+ const r = 18 + Math.random()*6;
+ const theta = Math.random()*Math.PI*2;
+ const phi = Math.acos(2*Math.random()-1);
+ pos[i*3] = r*Math.sin(phi)*Math.cos(theta);
+ pos[i*3+1] = r*Math.sin(phi)*Math.sin(theta);
+ pos[i*3+2] = r*Math.cos(phi);
+ }
+ g.setAttribute("position", new THREE.BufferAttribute(pos, 3));
+ const m = new THREE.PointsMaterial({ color: 0x2a2a3a, size: 0.06, sizeAttenuation: true, transparent:true, opacity:0.5 });
+ scene.add(new THREE.Points(g, m));
+ }
+ addStarHaze();
+
+ // Feature point cloud — built once positions arrive
+ let featurePoints = null, featureGeometry = null;
+ function buildFeatureCloud(positions) {
+ featureGeometry = new THREE.BufferGeometry();
+ const N = positions.length / 3;
+ featureGeometry.setAttribute("position", new THREE.BufferAttribute(positions, 3));
+ const colors = new Float32Array(N*3);
+ const sizes = new Float32Array(N);
+ for (let i=0; i<N; i++) {
+ colors[i*3] = 1.0; colors[i*3+1] = 1.0; colors[i*3+2] = 1.0;
+ sizes[i] = 0.5;
+ }
+ featureGeometry.setAttribute("color", new THREE.BufferAttribute(colors, 3));
+ featureGeometry.setAttribute("size", new THREE.BufferAttribute(sizes, 1));
+
+ const material = new THREE.ShaderMaterial({
+ uniforms: { uPixelRatio: { value: renderer.getPixelRatio() } },
+ vertexShader: `
+ attribute float size;
+ varying vec3 vColor;
+ varying float vIsActive;
+ varying float vSize;
+ uniform float uPixelRatio;
+ void main() {
+ vColor = color;
+ vIsActive = step(1.0, size); // active iff size >= 1.0 (base dots use 0.5; active sizes start at 1.5)
+ vec4 mv = modelViewMatrix * vec4(position, 1.0);
+ gl_Position = projectionMatrix * mv;
+ // Base features: small fixed-pixel hard dots (~3 device px, so 6 retina px).
+ // Active features: size attribute -> px directly, with mild perspective
+ // (closer = larger), capped so they never balloon out.
+ float basePx = 3.0 * uPixelRatio;
+ float scaledPx = size * 3.0 * (3.4 / -mv.z) * uPixelRatio;
+ scaledPx = clamp(scaledPx, basePx, 26.0 * uPixelRatio);
+ gl_PointSize = mix(basePx, scaledPx, vIsActive);
+ vSize = gl_PointSize;
+ }`,
+ fragmentShader: `
+ varying vec3 vColor;
+ varying float vIsActive;
+ varying float vSize;
+ void main() {
+ vec2 uv = gl_PointCoord - 0.5;
+ float d = length(uv);
+ if (d > 0.5) discard;
+ // Anti-alias the rim with a fixed 1-pixel band, regardless of point size,
+ // so even tiny points stay solid and crisp instead of going gaussian.
+ float aa = 1.0 / max(vSize, 2.0);
+ float core = 1.0 - smoothstep(0.5 - aa, 0.5, d);
+ // Halo only on active features
+ float glow = vIsActive * (1.0 - smoothstep(0.18, 0.5, d)) * 0.35;
+ vec3 col = mix(vColor, vColor * 1.5, vIsActive);
+ float alpha = clamp(core + glow, 0.0, 1.0);
+ gl_FragColor = vec4(col, alpha);
+ }`,
+ transparent: true,
+ depthWrite: false,
+ vertexColors: true,
+ blending: THREE.NormalBlending,
+ });
+ featurePoints = new THREE.Points(featureGeometry, material);
+ scene.add(featurePoints);
+ }
+
+ // Lazy orbit-like rotation (no heavy controls dep)
+ let userInteracting = false, autoRot = 0.06;
+ let dragging = false, lastX = 0, lastY = 0;
+ let yaw = 0.4, pitch = 0.05, dist = 3.4;
+ const cnv = document.getElementById("scene");
+ cnv.addEventListener("pointerdown", e => { dragging = true; userInteracting = true; lastX = e.clientX; lastY = e.clientY; });
+ window.addEventListener("pointerup", () => { dragging = false; });
+ window.addEventListener("pointermove", e => {
+ if (dragging) {
+ yaw += (e.clientX - lastX) * 0.005;
+ pitch += (e.clientY - lastY) * 0.005;
+ pitch = Math.max(-1.4, Math.min(1.4, pitch));
+ lastX = e.clientX; lastY = e.clientY;
+ }
+ });
+ cnv.addEventListener("wheel", e => {
+ dist *= (1 + e.deltaY * 0.0012);
+ dist = Math.max(1.3, Math.min(8, dist));
+ e.preventDefault();
+ }, { passive:false });
+
+ window.addEventListener("resize", () => {
+ camera.aspect = window.innerWidth / window.innerHeight;
+ camera.updateProjectionMatrix();
+ renderer.setSize(window.innerWidth, window.innerHeight);
+ });
+
+ function tick(t) {
+ if (!userInteracting) yaw += 0.0015;
+ camera.position.x = dist * Math.sin(yaw) * Math.cos(pitch);
+ camera.position.y = dist * Math.sin(pitch);
+ camera.position.z = dist * Math.cos(yaw) * Math.cos(pitch);
+ camera.lookAt(0, 0, 0);
+ renderer.render(scene, camera);
+ requestAnimationFrame(tick);
+ }
+ requestAnimationFrame(tick);
+
+ // --- Picking (raycasting) ---
+ const raycaster = new THREE.Raycaster();
+ raycaster.params.Points = { threshold: 0.04 };
+ const mouse = new THREE.Vector2();
+ const tooltip = document.getElementById("tooltip");
+ cnv.addEventListener("mousemove", e => {
+ mouse.x = (e.clientX / window.innerWidth) * 2 - 1;
+ mouse.y = -(e.clientY / window.innerHeight) * 2 + 1;
+ if (!featurePoints) return;
+ raycaster.setFromCamera(mouse, camera);
+ const hits = raycaster.intersectObject(featurePoints);
+ if (hits.length > 0) {
+ const id = hits[0].index;
+ state.picked = id;
+ const top = state.topFeatures.find(t => t.id === id);
+ const isSteered = state.steering.has(id);
+ const alpha = isSteered ? state.steering.get(id) : null;
+ let html = `<div style="color:var(--fg); font-weight:600;">feature ${id}</div>`;
+ if (top) {
+ const rank = state.topFeatures.indexOf(top);
+ html += `<div style="margin-top:3px; color:var(--accent);">rank #${rank} of top firing · activation ${top.act.toFixed(3)}</div>`;
+ } else {
+ html += `<div style="margin-top:3px; color:var(--fg-faint);">not in top firing for current prompt</div>`;
+ }
+ if (isSteered) {
+ const sign = alpha >= 0 ? "+" : "";
+ html += `<div style="margin-top:3px; color:var(--accent-2);">steered · α=${sign}${alpha.toFixed(0)}</div>`;
+ }
+ html += `<div style="margin-top:5px; color:var(--fg-faint); font-size:10px;">click to ${isSteered ? "edit slider" : "add steering slider"}</div>`;
+ tooltip.innerHTML = html;
+ tooltip.style.display = "block";
+ tooltip.style.left = e.clientX + "px";
+ tooltip.style.top = e.clientY + "px";
+ } else {
+ state.picked = null;
+ tooltip.style.display = "none";
+ }
+ });
+ cnv.addEventListener("click", () => {
+ if (state.picked != null) {
+ addSteerSlot(state.picked, 0);
+ }
+ });
+
+ // --- Update particle attributes after encode ---
+ function repaintCloud() {
+ if (!featureGeometry) return;
+ const colors = featureGeometry.attributes.color.array;
+ const sizes = featureGeometry.attributes.size.array;
+ const N = sizes.length;
+ for (let i=0; i<N; i++) {
+ colors[i*3] = 1.0; colors[i*3+1] = 1.0; colors[i*3+2] = 1.0;
+ sizes[i] = 0.5;
+ }
+ // Top firing — cyan, larger
+ const maxAct = state.topFeatures[0]?.act || 1;
+ for (const f of state.topFeatures) {
+ const i = f.id;
+ const t = Math.min(1, Math.abs(f.act) / maxAct);
+ colors[i*3] = 0.49 * t + 1.0 * (1-t);
+ colors[i*3+1] = 0.97 * t + 1.0 * (1-t);
+ colors[i*3+2] = 1.00;
+ sizes[i] = 1.5 + 4.5 * t;
+ }
+ // Steered — magenta override
+ for (const [id, slot] of state.steering) {
+ const alpha = (typeof slot === "object") ? slot.alpha : slot;
+ if (alpha === 0) continue;
+ const t = Math.min(1, Math.abs(alpha) / 100);
+ colors[id*3] = 1.00;
+ colors[id*3+1] = 0.49 * (1-t) + 0.4 * t;
+ colors[id*3+2] = 0.97;
+ sizes[id] = 4 + 5 * t;
+ }
+ featureGeometry.attributes.color.needsUpdate = true;
+ featureGeometry.attributes.size.needsUpdate = true;
+ }
+
+ // --- API helpers ---
+ function setStatus(msg, busy=false) {
+ const el = document.getElementById("status");
+ el.innerHTML = busy ? `<span class="loader"></span> &nbsp; ${msg}` : msg;
+ }
+ async function api(path, body=null) {
+ const opts = body ? {method:"POST", headers:{"Content-Type":"application/json"}, body: JSON.stringify(body)}
+ : {method:"GET"};
+ const r = await fetch(API + path, opts);
+ if (!r.ok) throw new Error(`${path} -> ${r.status} ${r.statusText}`);
+ return await r.json();
+ }
+
+ // --- Boot ---
+ function applyHealthToHeader(hp) {
+ const layerInput = document.getElementById("layer-input");
+ layerInput.value = hp.layer;
+ layerInput.max = (hp.n_layers || 1) - 1;
+ layerInput.title = `Layer 0..${hp.n_layers - 1}`;
+ document.getElementById("hdr-layer-meta").textContent = ` · ${hp.device} · ${hp.dtype}`;
+ document.getElementById("hdr-features").textContent = `${hp.n_features.toLocaleString()} features`;
+ const hdr = document.getElementById("header");
+ hdr.classList.remove("loading");
+ hdr.classList.add("live");
+ if (hp.transferred && hp.note) {
+ setStatus("⚠ transferred SAE: " + hp.note);
+ }
+ }
+
+ // Layer hot-swap: change SAE layer without reloading model
+ let layerSwapPending = null;
+ document.getElementById("layer-input").addEventListener("change", async (ev) => {
+ const newLayer = parseInt(ev.target.value);
+ const hp = await api("/health");
+ if (newLayer === hp.layer) return;
+ if (isNaN(newLayer) || newLayer < 0 || newLayer >= hp.n_layers) {
+ ev.target.value = hp.layer; return;
+ }
+ // Clear stale UI state — features and outputs are layer-specific
+ state.topFeatures = []; state.steering.clear();
+ document.getElementById("features").innerHTML = `<div class="empty">Encode a prompt to populate features.</div>`;
+ document.getElementById("feat-count").textContent = "";
+ document.getElementById("out-baseline").innerHTML = `<span class="empty">(no run yet)</span>`;
+ document.getElementById("out-steered").innerHTML = `<span class="empty">(no run yet)</span>`;
+ document.getElementById("verifier").innerHTML = "";
+ document.getElementById("heatmap-grid").innerHTML = `<span class="empty">Encode a prompt — heatmap fills automatically.</span>`;
+ document.getElementById("btn-generate").disabled = true;
+
+ showLoading(`Switching to layer ${newLayer}…`,
+ "First time may take ~20s (download + SVD); subsequent switches are <0.1s.");
+ try {
+ const r = await api("/set_layer", {layer: newLayer});
+ const hp2 = await api("/health");
+ applyHealthToHeader(hp2);
+ applyPositionsToCloud(r.positions);
+ setStatus(`layer ${newLayer}` + (r.from_cache ? " (cached)" : ""));
+ } catch (e) {
+ setStatus(`layer swap error: ${e.message}`);
+ ev.target.value = hp.layer;
+ } finally {
+ hideLoading();
+ }
+ });
+
+ function applyPositionsToCloud(positions) {
+ const flat = new Float32Array(positions.length * 3);
+ for (let i=0; i<positions.length; i++) {
+ flat[i*3] = positions[i][0] * 2.2;
+ flat[i*3+1] = positions[i][1] * 2.2;
+ flat[i*3+2] = positions[i][2] * 2.2;
+ }
+ state.positions = flat;
+ if (featurePoints) {
+ scene.remove(featurePoints);
+ featureGeometry.dispose();
+ featurePoints.material.dispose();
+ featurePoints = null; featureGeometry = null;
+ }
+ buildFeatureCloud(flat);
+ }
+
+ function fillModelDropdown(models, currentModel) {
+ const sel = document.getElementById("model-select");
+ sel.innerHTML = "";
+ for (const m of models) {
+ const opt = document.createElement("option");
+ opt.value = m.model;
+ const xferTag = m.transferred ? " ⚠" : "";
+ opt.textContent = `${m.model.replace("Qwen/","")} (${m.approx_size_gb}GB · ${m.n_features.toLocaleString()}f)${xferTag}`;
+ if (m.model === currentModel) opt.selected = true;
+ sel.appendChild(opt);
+ }
+ // If only one model in the catalog (HF Space deployment), make it
+ // visually informational rather than interactive.
+ if (models.length <= 1) {
+ sel.disabled = true;
+ sel.title = "Locked to Qwen3-1.7B-Base on HF Space (free CPU). Run locally to swap models.";
+ sel.style.cursor = "default";
+ sel.style.opacity = "0.85";
+ }
+ }
+
+ function showLoading(msg, detail="") {
+ document.getElementById("lo-msg").textContent = msg;
+ document.getElementById("lo-detail").textContent = detail;
+ document.getElementById("loading-overlay").classList.add("visible");
+ document.getElementById("header").classList.remove("live");
+ document.getElementById("header").classList.add("loading");
+ document.getElementById("btn-encode").disabled = true;
+ document.getElementById("btn-generate").disabled = true;
+ }
+ function hideLoading() {
+ document.getElementById("loading-overlay").classList.remove("visible");
+ document.getElementById("btn-encode").disabled = false;
+ }
+
+ async function boot() {
+ setStatus("loading positions…", true);
+ try {
+ const [hp, ph, ml] = await Promise.all([
+ api("/health"),
+ api("/positions"),
+ api("/list_models"),
+ ]);
+ fillModelDropdown(ml.models, hp.model);
+ applyHealthToHeader(hp);
+ applyPositionsToCloud(ph.positions);
+ setStatus("ready");
+ if (hp.transferred && hp.note) setStatus("⚠ " + hp.note);
+ } catch (e) {
+ setStatus(`server unreachable at ${API}: ${e.message}`);
+ document.getElementById("model-select").innerHTML = `<option style="color:var(--bad)">server offline</option>`;
+ }
+ }
+ boot();
+
+ // Tab switching
+ document.body.dataset.tab = "steering";
+ document.querySelectorAll(".tab").forEach(btn => {
+ btn.addEventListener("click", () => {
+ document.querySelectorAll(".tab").forEach(b => b.classList.remove("active"));
+ btn.classList.add("active");
+ document.body.dataset.tab = btn.dataset.tab;
+ });
+ });
+
+ // Minimize toggle on every panel header
+ document.querySelectorAll(".min-btn").forEach(btn => {
+ btn.addEventListener("click", (ev) => {
+ ev.stopPropagation();
+ const panel = btn.closest(".panel");
+ if (!panel) return;
+ panel.classList.toggle("collapsed");
+ btn.textContent = panel.classList.contains("collapsed") ? "+" : "−";
+ btn.title = panel.classList.contains("collapsed") ? "Expand panel" : "Minimize panel";
+ });
+ });
+
+ // Model swap on dropdown change
+ document.getElementById("model-select").addEventListener("change", async (ev) => {
+ const newModel = ev.target.value;
+ const ml = await api("/list_models");
+ const entry = ml.models.find(m => m.model === newModel);
+ if (!entry) return;
+ const ok = confirm(
+ `Load ${newModel}?\n\n` +
+ `~${entry.approx_size_gb}GB download (or cached if seen before).\n` +
+ `${entry.n_features.toLocaleString()} SAE features, ${entry.n_layers} layers.\n` +
+ (entry.transferred ? `⚠ TRANSFERRED SAE: ${entry.note}\n` : ``) +
+ `\nThis blocks the server until ready.`
+ );
+ if (!ok) {
+ // revert dropdown to current
+ const hp = await api("/health");
+ ev.target.value = hp.model;
+ return;
+ }
+ // Reset client-side state
+ state.topFeatures = []; state.steering.clear();
+ document.getElementById("features").innerHTML = `<div class="empty">Encode a prompt to populate features.</div>`;
+ document.getElementById("feat-count").textContent = "";
+ document.getElementById("out-baseline").innerHTML = `<span class="empty">(no run yet)</span>`;
+ document.getElementById("out-steered").innerHTML = `<span class="empty">(no run yet)</span>`;
+ document.getElementById("verifier").innerHTML = "";
+ document.getElementById("btn-generate").disabled = true;
+
+ showLoading(`Loading ${newModel}…`,
+ `~${entry.approx_size_gb}GB · this is real, watch the server log`);
+ try {
+ const r = await fetch(API + "/load_model", {
+ method:"POST", headers:{"Content-Type":"application/json"},
+ body: JSON.stringify({model:newModel})
+ });
+ if (!r.ok) {
+ const txt = await r.text();
+ throw new Error(`${r.status}: ${txt}`);
+ }
+ const data = await r.json();
+ const hp = await api("/health");
+ applyHealthToHeader(hp);
+ applyPositionsToCloud(data.positions);
+ setStatus(`loaded ${newModel}`);
+ if (hp.transferred && hp.note) setStatus("⚠ " + hp.note);
+ } catch (e) {
+ setStatus(`load failed: ${e.message}`);
+ alert(`Load failed:\n\n${e.message}\n\nCheck the server log.`);
+ // Try to revert dropdown
+ try {
+ const hp = await api("/health");
+ ev.target.value = hp.model;
+ } catch {}
+ } finally {
+ hideLoading();
+ }
+ });
+
+ // --- Encode action ---
+ document.getElementById("btn-encode").addEventListener("click", async () => {
+ const prompt = document.getElementById("prompt").value;
+ if (!prompt.trim()) return;
+ setStatus("encoding…", true);
+ try {
+ const t0 = performance.now();
+ const top_n = Math.max(1, parseInt(document.getElementById("top-k").value || "20"));
+ const r = await api("/encode", {prompt, top_n});
+ const dt = ((performance.now() - t0)/1000).toFixed(2);
+ state.topFeatures = r.top;
+ document.getElementById("feat-count").textContent = `(K=${r.top.length} of ${r.n_features.toLocaleString()})`;
+ renderFeatures();
+ repaintCloud();
+ document.getElementById("btn-generate").disabled = false;
+ setStatus(`encoded in ${dt}s`);
+ // Auto-render the per-token heatmap too
+ renderHeatmap(prompt).catch(e => console.error("heatmap:", e));
+ } catch (e) {
+ setStatus(`encode error: ${e.message}`);
+ }
+ });
+
+ document.getElementById("heatmap-skip-first").addEventListener("change", () => {
+ if (state._lastHeatmapPrompt) renderHeatmap(state._lastHeatmapPrompt);
+ });
+
+ // =====================================================================
+ // PILLAR 2 — EVALUATION
+ // =====================================================================
+ function parsePrompts(textareaId) {
+ return document.getElementById(textareaId).value
+ .split("\n").map(s => s.trim()).filter(s => s.length > 0);
+ }
+
+ document.getElementById("btn-eval-encode").addEventListener("click", async () => {
+ const prompts = parsePrompts("eval-prompts");
+ if (prompts.length === 0) { setStatus("paste prompts first"); return; }
+ setStatus(`encoding ${prompts.length} prompts…`, true);
+ document.getElementById("eval-features-list").innerHTML = `<span class="loader"></span>`;
+ document.getElementById("eval-samples").innerHTML = `<span class="loader"></span>`;
+ document.getElementById("eval-heatmap").innerHTML = `<span class="loader"></span>`;
+ try {
+ const t0 = performance.now();
+ const r = await api("/encode_batch", {prompts, top_n: 10});
+ const dt = ((performance.now()-t0)/1000).toFixed(2);
+ state.evalResult = r;
+ document.getElementById("eval-stats").textContent =
+ `(${r.n_samples} samples · ${r.corpus_features.length} features fired)`;
+ renderEvalCorpus(r);
+ renderEvalSamples(r);
+ renderEvalHeatmap(r);
+ setStatus(`corpus encoded in ${dt}s`);
+ } catch (e) {
+ setStatus(`eval encode error: ${e.message}`);
+ }
+ });
+
+ function renderEvalCorpus(r) {
+ const div = document.getElementById("eval-features-list");
+ if (!r.corpus_features.length) {
+ div.innerHTML = `<div class="empty">No features fired.</div>`;
+ return;
+ }
+ // Top firing features ranked by fire_rate, with bar
+ const max = r.corpus_features[0].fire_rate;
+ div.innerHTML = r.corpus_features.slice(0, 60).map(f => {
+ const w = Math.max(2, 100 * f.fire_rate / max);
+ return `<div style="padding:5px 8px; border-bottom:1px solid var(--border); font-family:var(--mono); font-size:11px;">
+ <div style="display:flex; justify-content:space-between; gap:10px; align-items:center;">
+ <span style="color:var(--accent);">feat ${f.id}</span>
+ <span style="color:var(--fg-faint);">${(f.fire_rate*100).toFixed(0)}% · μ=${f.mean_act.toFixed(2)}</span>
+ </div>
+ <div style="height:3px; background:rgba(125,249,255,0.55); width:${w}%; margin-top:4px; border-radius:2px;"></div>
+ </div>`;
+ }).join("");
+ }
+
+ function renderEvalSamples(r) {
+ const div = document.getElementById("eval-samples");
+ if (!r.per_sample.length) { div.innerHTML = `<span class="empty">no samples</span>`; return; }
+ div.innerHTML = r.per_sample.map(s => {
+ const top = s.top.slice(0,5).map(t =>
+ `<span style="display:inline-block; padding:1px 6px; margin:1px; border:1px solid var(--border-strong); border-radius:4px; font-family:var(--mono); font-size:10px; color:var(--accent);">${t.id}<span style="color:var(--fg-faint);">·${t.act.toFixed(1)}</span></span>`
+ ).join("");
+ return `<div style="padding:8px; border-bottom:1px solid var(--border);">
+ <div style="font-family:var(--mono); font-size:11px; color:var(--fg); margin-bottom:4px;">[${s.i}] ${escapeHtml(s.preview)}</div>
+ <div>${top}</div>
+ </div>`;
+ }).join("");
+ }
+
+ function renderEvalHeatmap(r) {
+ const div = document.getElementById("eval-heatmap");
+ // Pick top 20 features that fire most across samples
+ const topFeats = r.corpus_features.slice(0, 20);
+ if (!topFeats.length || !r.per_sample.length) { div.innerHTML = `<span class="empty">no data</span>`; return; }
+ // Build sample-id × feature_id activation matrix
+ const sampleActs = r.per_sample.map(s => {
+ const m = {};
+ for (const t of s.top) m[t.id] = t.act;
+ return m;
+ });
+ const max = Math.max(1e-6, ...sampleActs.flatMap(m => Object.values(m)));
+ // Render: rows = samples, cols = features
+ const header = `<th style="padding:3px 6px; font-size:10px; color:var(--fg-dim); background:rgba(0,0,0,0.4); position:sticky; left:0; z-index:2;">sample</th>` +
+ topFeats.map(f => `<th title="feat ${f.id} (${(f.fire_rate*100).toFixed(0)}% rate)" style="padding:3px 4px; font-size:9px; color:var(--accent); border-bottom:1px solid var(--border-strong);">${f.id}</th>`).join("");
+ const rows = r.per_sample.map((s, si) => {
+ const cells = topFeats.map(f => {
+ const v = sampleActs[si][f.id] || 0;
+ const t = Math.min(1, v/max);
+ const r = Math.round(255 - 200*t), g = 255, b = Math.round(255 - 50*t);
+ return `<td title="sample ${si} · feat ${f.id} · ${v.toFixed(2)}" style="background:rgb(${r},${g},${b}); width:22px; height:18px; border:1px solid rgba(0,0,0,0.3);"></td>`;
+ }).join("");
+ return `<tr><td style="padding:2px 6px; font-family:var(--mono); font-size:10px; color:var(--fg-dim); background:rgba(0,0,0,0.3); position:sticky; left:0; z-index:1; max-width:80px; overflow:hidden; text-overflow:ellipsis; white-space:nowrap;">${si}: ${escapeHtml(s.preview.slice(0,16))}</td>${cells}</tr>`;
+ }).join("");
+ div.innerHTML = `<table style="border-collapse:collapse; font-family:var(--mono);"><thead><tr>${header}</tr></thead><tbody>${rows}</tbody></table>`;
+ }
+
+ function escapeHtml(s) {
+ return String(s).replace(/&/g,"&amp;").replace(/</g,"&lt;").replace(/>/g,"&gt;").replace(/"/g,"&quot;");
+ }
+
+ // View toggle in Evaluation bottom-right panel: heatmap vs compare
+ document.getElementById("btn-eval-view-heatmap").addEventListener("click", () => {
+ document.getElementById("eval-heatmap").style.display = "block";
+ document.getElementById("eval-compare-results").style.display = "none";
+ });
+ document.getElementById("btn-eval-view-compare").addEventListener("click", () => {
+ document.getElementById("eval-heatmap").style.display = "none";
+ document.getElementById("eval-compare-results").style.display = "block";
+ });
+
+ // Differential feature mining
+ document.getElementById("btn-cmp").addEventListener("click", async () => {
+ const a = parsePrompts("cmp-a");
+ const b = parsePrompts("cmp-b");
+ if (a.length === 0 || b.length === 0) {
+ setStatus("paste both sets first"); return;
+ }
+ setStatus(`comparing ${a.length} vs ${b.length} prompts…`, true);
+ // Auto-switch the bottom-right panel to compare view
+ document.getElementById("eval-heatmap").style.display = "none";
+ document.getElementById("eval-compare-results").style.display = "block";
+ document.getElementById("eval-compare-results").innerHTML = `<span class="loader"></span> &nbsp; encoding…`;
+ try {
+ const t0 = performance.now();
+ const r = await api("/compare_batch", {prompts_a: a, prompts_b: b, top_n: 30});
+ const dt = ((performance.now()-t0)/1000).toFixed(2);
+ renderCompareResults(r, a, b);
+ setStatus(`compared in ${dt}s`);
+ } catch (e) {
+ setStatus(`compare error: ${e.message}`);
+ document.getElementById("eval-compare-results").innerHTML = `<span class="empty">error: ${e.message}</span>`;
+ }
+ });
+
+ function renderCompareResults(r, setA, setB) {
+ const div = document.getElementById("eval-compare-results");
+ if (!r.top_diff || !r.top_diff.length) {
+ div.innerHTML = `<span class="empty">no distinguishing features found</span>`;
+ return;
+ }
+ const max = r.top_diff[0].diff || 1;
+ const header = `
+ <div style="font-size:10px; color:var(--fg-faint); padding:4px 0; display:flex; gap:14px; flex-wrap:wrap; border-bottom:1px solid var(--border-strong); margin-bottom:6px;">
+ <span><span style="color:#7df9ff; font-weight:700;">■ A</span> ${r.n_a} prompts</span>
+ <span><span style="color:#ff7df9; font-weight:700;">■ B</span> ${r.n_b} prompts</span>
+ <span>top ${r.top_diff.length} by |Δ rate|</span>
+ </div>`;
+ const rows = r.top_diff.map((f, i) => {
+ const wA = Math.max(2, 60 * f.rate_a);
+ const wB = Math.max(2, 60 * f.rate_b);
+ const winner = f.winner === "a" ? "A" : "B";
+ const winnerColor = f.winner === "a" ? "#7df9ff" : "#ff7df9";
+ return `<tr style="border-bottom:1px solid var(--border);">
+ <td style="padding:3px 6px; color:var(--fg-faint); font-size:10px;">${i+1}</td>
+ <td style="padding:3px 6px; color:var(--accent); font-family:var(--mono);">${f.id}</td>
+ <td style="padding:3px 6px; text-align:right; font-family:var(--mono);">${(f.rate_a*100).toFixed(0)}%</td>
+ <td style="padding:3px 6px;">
+ <div style="display:flex; gap:2px; align-items:center;">
+ <div style="width:${wA}px; height:6px; background:#7df9ff;"></div>
+ <div style="width:${wB}px; height:6px; background:#ff7df9;"></div>
+ </div>
+ </td>
+ <td style="padding:3px 6px; text-align:right; font-family:var(--mono);">${(f.rate_b*100).toFixed(0)}%</td>
+ <td style="padding:3px 6px; font-family:var(--mono); font-size:10px; color:${winnerColor};">${winner} ▲</td>
+ <td style="padding:3px 6px; text-align:right; font-family:var(--mono); color:var(--fg);">${(f.diff*100).toFixed(0)}%</td>
+ </tr>`;
+ }).join("");
+ div.innerHTML = `
+ <div>
+ ${header}
+ <div style="overflow:auto; max-height:32vh;">
+ <table style="border-collapse:collapse; width:100%; font-size:11px;">
+ <thead>
+ <tr style="background:rgba(0,0,0,0.4); position:sticky; top:0; z-index:1;">
+ <th style="padding:4px 6px; font-size:10px; color:var(--fg-dim); text-align:left;">#</th>
+ <th style="padding:4px 6px; font-size:10px; color:var(--fg-dim); text-align:left;">feat</th>
+ <th style="padding:4px 6px; font-size:10px; color:var(--fg-dim); text-align:right;">rate A</th>
+ <th style="padding:4px 6px; font-size:10px; color:var(--fg-dim); text-align:left;">A vs B</th>
1149
+ <th style="padding:4px 6px; font-size:10px; color:var(--fg-dim); text-align:right;">rate B</th>
1150
+ <th style="padding:4px 6px; font-size:10px; color:var(--fg-dim); text-align:left;">winner</th>
1151
+ <th style="padding:4px 6px; font-size:10px; color:var(--fg-dim); text-align:right;">|Δ|</th>
1152
+ </tr>
1153
+ </thead>
1154
+ <tbody>${rows}</tbody>
1155
+ </table>
1156
+ </div>
1157
+ </div>`;
1158
+ }
1159
+
1160
+ // =====================================================================
1161
+ // PILLAR 3 — DATA-CENTRIC
1162
+ // =====================================================================
1163
+ document.getElementById("btn-dc-encode").addEventListener("click", async () => {
1164
+ const prompts = parsePrompts("dc-prompts");
1165
+ if (prompts.length === 0) { setStatus("paste prompts first"); return; }
1166
+ setStatus(`encoding ${prompts.length} prompts…`, true);
1167
+ try {
1168
+ const r = await api("/encode_batch", {prompts, top_n: 50});
1169
+ state.dcResult = r;
1170
+ state.dcAllPrompts = prompts;
1171
+ state.dcFiltered = r.per_sample;
1172
+ renderDcFiltered(state.dcFiltered, prompts, null);
1173
+ setStatus(`encoded ${r.n_samples} prompts`);
1174
+ } catch (e) {
1175
+ setStatus(`dc encode error: ${e.message}`);
1176
+ }
1177
+ });
1178
+
1179
+ document.getElementById("btn-dc-filter").addEventListener("click", () => {
1180
+ if (!state.dcResult) { setStatus("encode first"); return; }
1181
+ const fid = parseInt(document.getElementById("dc-filter-id").value);
1182
+ const mode = document.getElementById("dc-filter-mode").value;
1183
+ if (isNaN(fid)) { setStatus("enter a feature id"); return; }
1184
+ const fired = new Set();
1185
+ for (const s of state.dcResult.per_sample) {
1186
+ if (s.top.some(t => t.id === fid)) fired.add(s.i);
1187
+ }
1188
+ const filtered = state.dcResult.per_sample.filter(s =>
1189
+ mode === "include" ? fired.has(s.i) : !fired.has(s.i)
1190
+ );
1191
+ state.dcFiltered = filtered;
1192
+ renderDcFiltered(filtered, state.dcAllPrompts, {id: fid, mode});
1193
+ });
1194
+
1195
+ document.getElementById("btn-dc-clear").addEventListener("click", () => {
1196
+ if (!state.dcResult) return;
1197
+ state.dcFiltered = state.dcResult.per_sample;
1198
+ renderDcFiltered(state.dcFiltered, state.dcAllPrompts, null);
1199
+ });
1200
+
1201
+ function renderDcFiltered(samples, prompts, filt) {
1202
+ const div = document.getElementById("dc-filtered-list");
1203
+ const stats = document.getElementById("dc-filter-stats");
1204
+ if (filt) {
1205
+ stats.textContent = `(filter: ${filt.mode === "include" ? "+" : "−"}feat ${filt.id} · ${samples.length} of ${state.dcResult.n_samples})`;
1206
+ } else {
1207
+ stats.textContent = `(${samples.length} docs)`;
1208
+ }
1209
+ if (!samples.length) { div.innerHTML = `<span class="empty">no docs match.</span>`; return; }
1210
+ div.innerHTML = samples.map(s => {
1211
+ const fullPrompt = prompts[s.i] || s.preview;
1212
+ const top = s.top.slice(0,3).map(t =>
1213
+ `<span style="display:inline-block; padding:1px 5px; margin:1px; border:1px solid var(--border-strong); border-radius:3px; font-family:var(--mono); font-size:10px; color:var(--accent);">${t.id}</span>`
1214
+ ).join("");
1215
+ return `<div style="padding:6px 8px; border-bottom:1px solid var(--border); font-family:var(--mono); font-size:11px;">
1216
+ <div style="color:var(--fg-faint); font-size:9px;">[${s.i}]</div>
1217
+ <div style="color:var(--fg);">${escapeHtml(fullPrompt)}</div>
1218
+ <div style="margin-top:3px;">${top}</div>
1219
+ </div>`;
1220
+ }).join("");
1221
+ }
1222
+
+ document.getElementById("btn-dc-synth").addEventListener("click", async () => {
+ const prompts = state.dcAllPrompts || parsePrompts("dc-prompts");
+ if (!prompts.length) { setStatus("paste seeds first"); return; }
+ const fid = parseInt(document.getElementById("dc-synth-id").value);
+ const alpha = parseFloat(document.getElementById("dc-synth-alpha").value);
+ const max_new_tokens = parseInt(document.getElementById("dc-synth-tokens").value);
+ if (isNaN(fid)) { setStatus("enter feature id"); return; }
+ setStatus(`synthesizing ${prompts.length} steered completions…`, true);
+ document.getElementById("dc-synth-list").innerHTML = `<span class="loader"></span> &nbsp; running…`;
+ try {
+ const r = await api("/synth_batch", {
+ seed_prompts: prompts,
+ steering: [{id: fid, alpha}],
+ max_new_tokens,
+ });
+ document.getElementById("dc-synth-stats").textContent =
+ `(${r.results.length} · feat ${fid} α=${alpha})`;
+ document.getElementById("dc-synth-list").innerHTML = r.results.map((res, i) =>
+ `<div style="padding:8px; border-bottom:1px solid var(--border); font-family:var(--mono); font-size:11px;">
+ <div style="color:var(--fg-faint); font-size:9px;">[${i}] seed</div>
+ <div style="color:var(--fg-dim);">${escapeHtml(res.seed)}</div>
+ <div style="color:var(--fg-faint); font-size:9px; margin-top:4px;">→ steered</div>
+ <div style="color:var(--fg);">${escapeHtml(res.text)}</div>
+ </div>`
+ ).join("");
+ setStatus("synthesis done");
+ } catch (e) {
+ setStatus(`synth error: ${e.message}`);
+ document.getElementById("dc-synth-list").innerHTML = `<span class="empty">error: ${e.message}</span>`;
+ }
+ });
+
+ async function renderHeatmap(prompt) {
+ const container = document.getElementById("heatmap-grid");
+ state._lastHeatmapPrompt = prompt;
+ container.innerHTML = `<span class="loader"></span> &nbsp; computing…`;
+ try {
+ const r = await api("/encode_full", {prompt, top_n: 16});
+ let tokens = r.tokens, grid = r.grid, ids = r.feature_ids;
+ const skipFirst = document.getElementById("heatmap-skip-first").checked;
+ if (skipFirst && tokens.length > 1) {
+ tokens = tokens.slice(1);
+ grid = grid.map(row => row.slice(1));
+ }
+ if (tokens.length === 0) {
+ container.innerHTML = `<span class="empty">No tokens (after skip).</span>`;
+ return;
+ }
+ // Build HTML table
+ const rowMax = grid.map(row => Math.max(...row.map(Math.abs), 1e-6));
+ const tokHeader = tokens.map((t,i) => {
+ const safe = (t.replace(/\n/g,"↵").replace(/</g,"&lt;").replace(/>/g,"&gt;")).slice(0,8);
+ return `<th title="pos ${i}: ${t.replace(/\n/g,'\\n').replace(/"/g,'&quot;')}" style="padding:3px 4px; font-size:10px; color:var(--fg-dim); border-bottom:1px solid var(--border-strong); white-space:nowrap;">${safe || `[${i}]`}</th>`;
+ }).join("");
+ const rows = grid.map((row, fi) => {
+ const m = rowMax[fi];
+ const cells = row.map((v, pi) => {
+ const t = Math.min(1, Math.abs(v) / m);
+ const r = 255, g = Math.round(255 - 200*t), b = Math.round(255 - 220*t);
+ return `<td title="feat ${ids[fi]} · pos ${pi} · act=${v.toFixed(3)}" style="background:rgb(${r},${g},${b}); width:30px; height:22px; border:1px solid rgba(0,0,0,0.3);"></td>`;
+ }).join("");
+ return `<tr><td style="font-family:var(--mono); font-size:10px; padding:3px 6px; color:var(--accent); white-space:nowrap; background:rgba(0,0,0,0.3); position:sticky; left:0; z-index:1;">#${ids[fi]}</td>${cells}</tr>`;
+ }).join("");
+ container.innerHTML = `<table style="border-collapse:collapse; font-family:var(--mono);"><thead><tr><th style="padding:3px 6px; font-size:10px; color:var(--fg-dim); position:sticky; left:0; z-index:2; background:rgba(0,0,0,0.4);">feat</th>${tokHeader}</tr></thead><tbody>${rows}</tbody></table>`;
+ } catch (e) {
+ container.innerHTML = `<span class="empty">heatmap error: ${e.message}</span>`;
+ }
+ }
+
+ // --- Feature card render ---
+ function renderFeatures() {
+ const div = document.getElementById("features");
+ if (state.topFeatures.length === 0) {
+ div.innerHTML = `<div class="empty">Encode a prompt to populate features.</div>`;
+ return;
+ }
+ div.innerHTML = "";
+ for (const f of state.topFeatures) {
+ div.appendChild(buildFeatureCard(f.id, f.act, /*topRanked=*/true));
+ }
+ // Also render any steered features that weren't in top-K
+ for (const [id, alpha] of state.steering) {
+ if (!state.topFeatures.find(t => t.id === id)) {
+ div.appendChild(buildFeatureCard(id, null, /*topRanked=*/false));
+ }
+ }
+ }
+ function buildFeatureCard(id, act, topRanked) {
+ const wrap = document.createElement("div");
+ wrap.className = "feat" + (state.steering.has(id) ? " steered" : "");
+ const alpha = state.steering.has(id) ? state.steering.get(id) : 0;
+
+ const head = document.createElement("div");
+ head.className = "feat-head";
+ const left = document.createElement("div");
+ left.innerHTML = `<span class="feat-id">feat ${id}</span>` +
+ (act != null ? ` &nbsp;<span class="feat-act">act ${act.toFixed(3)}</span>` : ` &nbsp;<span class="feat-act" style="color:var(--fg-faint)">(picked)</span>`);
+ const tools = document.createElement("div");
+ tools.className = "feat-tools";
+
+ if (state.steering.has(id)) {
+ const reset = document.createElement("button");
+ reset.textContent = "reset";
+ reset.title = "Set α back to 0 (no intervention) but keep slider visible";
+ reset.addEventListener("click", () => {
+ state.steering.set(id, 0);
+ renderFeatures();
+ repaintCloud();
+ });
+ tools.appendChild(reset);
+ const off = document.createElement("button");
+ off.textContent = "remove";
+ off.title = "Remove this slider entirely";
+ off.addEventListener("click", () => {
+ state.steering.delete(id);
+ renderFeatures();
+ repaintCloud();
+ });
+ tools.appendChild(off);
+ } else {
+ const add = document.createElement("button");
+ add.textContent = "steer";
+ add.addEventListener("click", () => addSteerSlot(id, 0));
+ tools.appendChild(add);
+ }
+ head.appendChild(left);
+ head.appendChild(tools);
+ wrap.appendChild(head);
+
+ if (state.steering.has(id)) {
+ const cur = state.steering.get(id);
+ const curAlpha = (typeof cur === "object") ? cur.alpha : cur;
+ const curPositions = (typeof cur === "object") ? (cur.positions || "") : "";
+ const curOutOnly = (typeof cur === "object") ? !!cur.output_only : false;
+
+ const sl = document.createElement("div");
+ sl.className = "slider-row";
+ const range = document.createElement("input");
+ range.type = "range"; range.min = -100; range.max = 100; range.step = 1;
+ range.value = curAlpha;
+ const val = document.createElement("span");
+ val.className = "alpha-val";
+ val.textContent = (curAlpha >= 0 ? "+" : "") + curAlpha.toFixed(0);
+ range.addEventListener("input", () => {
+ const a = parseFloat(range.value);
+ const slot = state.steering.get(id);
+ const next = (typeof slot === "object") ? {...slot, alpha:a} : {alpha:a, positions:"", output_only:false};
+ state.steering.set(id, next);
+ val.textContent = (a >= 0 ? "+" : "") + a.toFixed(0);
+ repaintCloud();
+ });
+ sl.appendChild(range);
+ sl.appendChild(val);
+ wrap.appendChild(sl);
+
+ // Position-selective steering + output-only toggle
+ const adv = document.createElement("div");
+ adv.style.cssText = "margin-top:6px; display:flex; gap:6px; align-items:center; flex-wrap:wrap;";
+ const posLbl = document.createElement("span");
+ posLbl.textContent = "positions";
+ posLbl.style.cssText = "font-size:10px; color:var(--fg-faint); flex:0 0 auto;";
+ const posInput = document.createElement("input");
+ posInput.type = "text";
+ posInput.placeholder = "all";
+ posInput.value = curPositions;
+ posInput.style.cssText = "flex:1; min-width:60px; font-size:11px; padding:3px 6px; background:rgba(0,0,0,0.3); border:1px solid var(--border-strong); color:var(--fg); border-radius:4px; font-family:var(--mono);";
+ posInput.title = "Token positions to steer at: 'all' or '3-7' or '0,2,5-8'. Empty = all.";
+ posInput.addEventListener("change", () => {
+ const slot = state.steering.get(id);
+ const next = (typeof slot === "object") ? {...slot, positions: posInput.value} : {alpha: slot, positions: posInput.value, output_only: false};
+ state.steering.set(id, next);
+ });
+ const outLbl = document.createElement("label");
+ outLbl.style.cssText = "font-size:10px; color:var(--fg-faint); display:inline-flex; align-items:center; gap:3px; cursor:pointer; flex:0 0 auto;";
+ const outChk = document.createElement("input");
+ outChk.type = "checkbox";
+ outChk.checked = curOutOnly;
+ outChk.style.cssText = "margin:0;";
+ outChk.addEventListener("change", () => {
+ const slot = state.steering.get(id);
+ const next = (typeof slot === "object") ? {...slot, output_only: outChk.checked} : {alpha: slot, positions: "", output_only: outChk.checked};
+ state.steering.set(id, next);
+ });
+ outLbl.appendChild(outChk);
+ outLbl.appendChild(document.createTextNode("output only"));
+ adv.appendChild(posLbl);
+ adv.appendChild(posInput);
+ adv.appendChild(outLbl);
+ wrap.appendChild(adv);
+
+ const lbl = document.createElement("div");
+ lbl.style.cssText = "font-size:10px; color:var(--fg-faint); margin-top:4px;";
+ lbl.textContent = "α: −100 ← suppress … +100 → amplify";
+ wrap.appendChild(lbl);
+ }
+ return wrap;
+ }
1420
+
+ function addSteerSlot(id, alpha=0) {
+ state.steering.set(id, alpha);
+ renderFeatures();
+ repaintCloud();
+ }
+
+ // --- Generate action ---
+ document.getElementById("btn-generate").addEventListener("click", async () => {
+ const prompt = document.getElementById("prompt").value;
+ const max_new_tokens = parseInt(document.getElementById("max-tokens").value || "40");
+ const steering = [];
+ for (const [id, slot] of state.steering) {
+ const alpha = (typeof slot === "object") ? slot.alpha : slot;
+ if (alpha === 0) continue;
+ const positions = (typeof slot === "object") ? (slot.positions || null) : null;
+ const output_only = (typeof slot === "object") ? !!slot.output_only : false;
+ steering.push({id, alpha, positions, output_only});
+ }
+ setStatus("generating baseline…", true);
+ document.getElementById("out-baseline").innerHTML = `<span class="loader"></span>`;
+ document.getElementById("out-steered").innerHTML = `<span class="loader"></span>`;
+
+ try {
+ const t0 = performance.now();
+ const baseline = await api("/generate", {prompt, steering: [], max_new_tokens, return_probs: true, topk_display: 8});
+ const t1 = performance.now();
+ renderTokenChips("out-baseline", baseline, "blue");
+ document.getElementById("base-time").textContent = ((t1-t0)/1000).toFixed(2) + "s";
+
+ if (steering.length === 0) {
+ document.getElementById("out-steered").innerHTML = `<span class="empty">(no sliders engaged — steered = baseline)</span>`;
+ document.getElementById("steered-time").textContent = "";
+ document.getElementById("verifier").innerHTML = "";
+ setStatus("done");
+ return;
+ }
+ setStatus("generating steered…", true);
+ const t2 = performance.now();
+ const steered = await api("/generate", {prompt, steering, max_new_tokens, return_probs: true, topk_display: 8});
+ const t3 = performance.now();
+ renderTokenChips("out-steered", steered, "magenta");
+ document.getElementById("steered-time").textContent = ((t3-t2)/1000).toFixed(2) + "s";
+
+ // Verifier
+ const v = document.getElementById("verifier");
+ v.innerHTML = "";
+ for (const row of steered.verifier) {
+ const d = row.steered - row.base;
+ const cls = d > 0 ? "delta-up" : (d < 0 ? "delta-down" : "");
+ const sign = d >= 0 ? "+" : "";
+ const ext = [];
+ if (row.positions) ext.push(`pos=${row.positions}`);
+ if (row.output_only) ext.push("output-only");
+ const extStr = ext.length ? ` · ${ext.join(" · ")}` : "";
+ const div = document.createElement("div");
+ div.innerHTML = `feat ${row.id} · α=${row.alpha.toFixed(0)}${extStr} · base=${row.base.toFixed(2)} → steered=${row.steered.toFixed(2)} <span class="${cls}">(Δ ${sign}${d.toFixed(2)})</span>`;
+ v.appendChild(div);
+ }
+ setStatus("done");
+ } catch (e) {
+ setStatus(`generate error: ${e.message}`);
+ document.getElementById("out-baseline").textContent = "(error)";
+ document.getElementById("out-steered").textContent = "(error)";
+ }
+ });
+
+ // --- Per-token probability chip strip ---
+ function renderTokenChips(containerId, gen, theme) {
+ const container = document.getElementById(containerId);
+ if (!gen.tokens || !gen.tokens.length) {
+ container.textContent = gen.text || "(empty)";
+ return;
+ }
+ // Theme: colors for chip background
+ const themeFns = {
+ blue: (p) => {
+ const t = Math.max(0, Math.min(1, p));
+ const r = Math.round(255 * (1 - t*0.85));
+ const g = Math.round(255 * (1 - t*0.55));
+ return [r, g, 255, t < 0.5 ? "#1e3a8a" : "#fff"];
+ },
+ magenta: (p) => {
+ const t = Math.max(0, Math.min(1, p));
+ const r = 255;
+ const g = Math.round(255 * (1 - t*0.7));
+ const b = Math.round(255 * (1 - t*0.5));
+ return [r, g, b, t < 0.5 ? "#7f1d1d" : "#fff"];
+ },
+ };
+ const colorize = themeFns[theme] || themeFns.blue;
+ const chips = gen.tokens.map((row, i) => {
+ const [r,g,b,fg] = colorize(row.prob);
+ const tokDisp = row.tok.replace(/\n/g,"↵").replace(/\t/g,"→");
+ const safe = escapeHtml(tokDisp);
+ const panelHtml = topkPanelHtml(row.topk);
+ return `<span class="tok-chip" data-panel='${escapeAttr(panelHtml)}' style="background:rgb(${r},${g},${b}); color:${fg}; padding:2px 6px; margin:1px; border-radius:4px; cursor:pointer; display:inline-block; font-family:var(--mono); font-size:11px; white-space:nowrap;">${safe}<sub style="opacity:.65; font-size:8px; margin-left:3px;">${(row.prob*100).toFixed(1)}%</sub></span>`;
+ }).join("");
+ container.innerHTML = `
+ <div class="tok-strip" data-prob-root style="line-height:2.4;">
+ <div style="font-size:10px; color:var(--fg-faint); margin-bottom:4px; font-style:italic;">click any token to see top-K candidates</div>
+ <div>${chips}</div>
+ <div data-topk-panel style="display:none; margin-top:6px; padding:6px; background:rgba(0,0,0,0.35); border:1px solid var(--border-strong); border-radius:6px; font-family:var(--mono); font-size:10px;"></div>
+ </div>`;
+ // Wire click-to-pin
+ container.querySelectorAll(".tok-chip").forEach(chip => {
+ chip.addEventListener("click", () => {
+ const root = chip.closest("[data-prob-root]");
+ const panel = root.querySelector("[data-topk-panel]");
+ const wasSelected = chip.dataset.selected === "1";
+ root.querySelectorAll(".tok-chip").forEach(c => {
+ c.dataset.selected = "0"; c.style.outline = "";
+ });
+ if (wasSelected) {
+ panel.innerHTML = ""; panel.style.display = "none";
+ } else {
+ chip.dataset.selected = "1";
+ chip.style.outline = "2px solid #94a3b8";
+ chip.style.outlineOffset = "-1px";
+ panel.innerHTML = chip.dataset.panel;
+ panel.style.display = "block";
+ }
+ });
+ });
+ }
+
+ function topkPanelHtml(topk) {
+ const rows = topk.map((c, idx) => {
+ const safe = escapeHtml(c.tok.replace(/\n/g,"↵").replace(/\t/g,"→"));
+ const bg = c.is_chosen ? "background:rgba(125,249,255,0.12);" : "";
+ const fw = c.is_chosen ? "font-weight:700;" : "";
+ const mark = c.is_chosen ? " ✓" : "";
+ return `<tr style="${bg}">
+ <td style="padding:2px 6px; color:var(--fg-faint); text-align:right;">${idx+1}</td>
+ <td style="padding:2px 6px; color:var(--fg); ${fw}">${safe}${mark}</td>
+ <td style="padding:2px 6px; text-align:right; color:var(--accent);">${(c.prob*100).toFixed(2)}%</td>
+ </tr>`;
+ }).join("");
+ return `<table style="border-collapse:collapse; width:100%;"><tbody>${rows}</tbody></table>`;
+ }
+
+ function escapeAttr(s) { return s.replace(/'/g, "&apos;").replace(/"/g, "&quot;"); }
+ </script>
+ </body>
+ </html>
qwen_scope_obs.py ADDED
@@ -0,0 +1,129 @@
+ """Observation-only tooling for Qwen-Scope SAE features.
+
+ Four utilities — none of them steer or modify generation. They exist to let a
+ reviewer interrogate what the SAE features encode, before any intervention:
+
+ * encode_prompts(model, tokenizer, sae, prompts, layer_idx)
+ For each prompt, returns the last-token sparse feature code (shape:
+ len(prompts) x n_features).
+
+ * top_features_for_prompt(...)
+ The N strongest-firing features for a single prompt, with activation
+ values. Equivalent to the read-only path in qwen_scope_steer.
+
+ * differential_features(pos_codes, neg_codes, top_n)
+ Given two stacks of feature codes, returns features ranked by
+ (mean activation on positive set) - (mean activation on negative set).
+ Useful for "which features distinguish concept A from concept B?"
+
+ * scan_prompts_for_feature(codes, feature_id)
+ Given a stack of codes and a feature id, returns the per-prompt
+ activation values (zero where the feature didn't make TopK).
+
+ Nothing here generates steered text. Wire it up to the steering hooks in
+ qwen_scope_steer.py only when intervention is the explicit experimental goal,
+ and document the intervention separately.
+ """
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from typing import Sequence
+
+ import torch
+
+ from qwen_scope_steer import SAE, capture_residual
+
+
+ @dataclass
+ class FeatureRanking:
+ feature_id: int
+ score: float # (pos_mean - neg_mean) for differential, or activation for top-features
+ pos_mean: float | None = None
+ neg_mean: float | None = None
+ pos_fire_rate: float | None = None
+ neg_fire_rate: float | None = None
+
+
+ def encode_prompts(model, tokenizer, sae: SAE, prompts: Sequence[str],
+ layer_idx: int) -> torch.Tensor:
+ """Encode the last-token residual of each prompt through the SAE.
+
+ Returns codes of shape (len(prompts), n_features) on CPU float32 for
+ stable downstream stats. No generation is performed.
+ """
+ codes = []
+ for p in prompts:
+ inputs = tokenizer(p, return_tensors="pt").to(model.device)
+ with torch.no_grad(), capture_residual(model, layer_idx) as bucket:
+ model(**inputs)
+ h_last = bucket["h"][0, -1].unsqueeze(0)
+ z = sae.encode(h_last)[0]
+ codes.append(z.detach().to("cpu", torch.float32))
+ return torch.stack(codes, dim=0)
+
+
+ def top_features_for_prompt(model, tokenizer, sae: SAE, prompt: str,
+ layer_idx: int, top_n: int = 10) -> list[FeatureRanking]:
+ codes = encode_prompts(model, tokenizer, sae, [prompt], layer_idx)[0]
+ nz = codes.nonzero(as_tuple=False).flatten()
+ vals = codes[nz]
+ order = vals.argsort(descending=True)[:top_n]
+ return [FeatureRanking(feature_id=int(nz[i]), score=float(vals[i]))
+ for i in order]
+
+
+ def differential_features(pos_codes: torch.Tensor, neg_codes: torch.Tensor,
+ top_n: int = 20) -> list[FeatureRanking]:
+ """Rank features by their differential firing across two prompt sets.
+
+ pos_codes: (P, F) feature codes for the "positive" prompt set
+ neg_codes: (N, F) feature codes for the "negative" prompt set
+
+ Returns top_n features by (pos_mean - neg_mean), with both means and
+ per-set fire rates (fraction of prompts where the feature fired).
+ Read-only — no generation, no steering.
+ """
+ if pos_codes.shape[1] != neg_codes.shape[1]:
+ raise ValueError(f"feature dim mismatch: {pos_codes.shape} vs {neg_codes.shape}")
+ pos_mean = pos_codes.mean(dim=0)
+ neg_mean = neg_codes.mean(dim=0)
+ diff = pos_mean - neg_mean
+ pos_fire = (pos_codes != 0).float().mean(dim=0)
+ neg_fire = (neg_codes != 0).float().mean(dim=0)
+ order = diff.argsort(descending=True)[:top_n]
+ return [
+ FeatureRanking(
+ feature_id=int(i),
+ score=float(diff[i]),
+ pos_mean=float(pos_mean[i]),
+ neg_mean=float(neg_mean[i]),
+ pos_fire_rate=float(pos_fire[i]),
+ neg_fire_rate=float(neg_fire[i]),
+ )
+ for i in order
+ ]
+
+
+ def scan_prompts_for_feature(codes: torch.Tensor, feature_id: int) -> torch.Tensor:
+ """Per-prompt activation vector for a single feature (zero where it didn't make TopK)."""
+ return codes[:, feature_id]
+
+
+ def fire_rate(codes: torch.Tensor, feature_id: int) -> float:
+ """Fraction of prompts on which the feature fired (was in TopK)."""
+ return float((codes[:, feature_id] != 0).float().mean())
+
+
+ def pretty_ranking(rs: list[FeatureRanking]) -> str:
+ out = []
+ for r in rs:
+ if r.pos_mean is not None:
+ out.append(
+ f" feat {r.feature_id:>6d} "
+ f"diff={r.score:+8.4f} "
+ f"pos_mean={r.pos_mean:+8.4f} neg_mean={r.neg_mean:+8.4f} "
+ f"pos_fire={r.pos_fire_rate:.2f} neg_fire={r.neg_fire_rate:.2f}"
+ )
+ else:
+ out.append(f" feat {r.feature_id:>6d} act={r.score:+8.4f}")
+ return "\n".join(out)
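
The ranking rule in `differential_features` — mean activation on the positive set minus mean activation on the negative set, plus per-set fire rates — is simple enough to sketch without torch. A minimal pure-Python version for intuition (illustrative only; names like `rank_differential` are not part of the module, and the tensor version above is what the Space actually runs):

```python
def rank_differential(pos, neg, top_n=3):
    """Rank features by mean(pos) - mean(neg) activation.

    pos, neg: lists of equal-length per-prompt activation lists.
    Returns (feature_id, diff, pos_fire_rate, neg_fire_rate) tuples,
    highest diff first -- mirroring FeatureRanking above.
    """
    n_feat = len(pos[0])
    rows = []
    for f in range(n_feat):
        pos_vals = [p[f] for p in pos]
        neg_vals = [q[f] for q in neg]
        diff = sum(pos_vals) / len(pos_vals) - sum(neg_vals) / len(neg_vals)
        # fire rate = fraction of prompts where the feature was nonzero
        pos_fire = sum(v != 0 for v in pos_vals) / len(pos_vals)
        neg_fire = sum(v != 0 for v in neg_vals) / len(neg_vals)
        rows.append((f, diff, pos_fire, neg_fire))
    rows.sort(key=lambda r: r[1], reverse=True)
    return rows[:top_n]

# Two prompts per set, three features; feature 1 fires only on the positive set.
pos = [[0.0, 2.0, 0.5], [0.0, 1.0, 0.0]]
neg = [[0.0, 0.0, 0.4], [1.0, 0.0, 0.6]]
print(rank_differential(pos, neg, top_n=2))
```

Feature 1 comes out on top (diff +1.5, pos_fire 1.0, neg_fire 0.0) — exactly the "which features distinguish A from B?" signal the compare panel renders.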
qwen_scope_steer.py ADDED
@@ -0,0 +1,279 @@
+ """Qwen-Scope SAE feature reading + steering for transformers.
2
+
3
+ End-to-end demo:
4
+ 1. Loads a base Qwen3 model and a matching Qwen-Scope TopK SAE checkpoint.
5
+ 2. Captures the residual-stream output of a chosen decoder layer.
6
+ 3. Encodes it through the SAE -> top-K firing features.
7
+ 4. Generates a baseline completion.
8
+ 5. Re-generates with feature steering: residual h <- h + alpha * W_dec[:, feat]
9
+ applied via register_forward_hook on every forward pass.
10
+
11
+ Verified against:
12
+ * Qwen/SAE-Res-Qwen3-1.7B-Base-W32K-L0_50 (W_enc 32768x2048, W_dec 2048x32768,
13
+ b_enc 32768, b_dec 2048, all float32, K=50)
14
+ * Qwen/Qwen3-1.7B-Base (28 Qwen3DecoderLayer, hidden_size=2048, layer forward
15
+ returns bare torch.Tensor under transformers >= 5).
16
+ """
17
+ from __future__ import annotations
18
+
19
+ import argparse
20
+ import contextlib
21
+ from dataclasses import dataclass
22
+ from pathlib import Path
23
+
24
+ import torch
25
+ import torch.nn.functional as F
26
+ from huggingface_hub import hf_hub_download
27
+ from transformers import AutoModelForCausalLM, AutoTokenizer
28
+
29
+
30
+ # ---------------------------------------------------------------------------
+ # SAE
+ # ---------------------------------------------------------------------------
+ @dataclass
+ class SAE:
+     W_enc: torch.Tensor  # (n_features, d_model)
+     W_dec: torch.Tensor  # (d_model, n_features)
+     b_enc: torch.Tensor  # (n_features,)
+     b_dec: torch.Tensor  # (d_model,)
+     k: int               # TopK
+     layer: int           # layer index this SAE belongs to
+
+     @classmethod
+     def from_repo(cls, repo: str, layer: int, k: int, device: str = "cpu",
+                   dtype: torch.dtype = torch.float32) -> "SAE":
+         path = hf_hub_download(repo, f"layer{layer}.sae.pt")
+         return cls.from_path(path, layer=layer, k=k, device=device, dtype=dtype)
+
+     @classmethod
+     def from_path(cls, path: str | Path, layer: int, k: int,
+                   device: str = "cpu", dtype: torch.dtype = torch.float32) -> "SAE":
+         sd = torch.load(str(path), map_location=device, weights_only=True)
+         for key in ("W_enc", "W_dec", "b_enc", "b_dec"):
+             if key not in sd:
+                 raise KeyError(f"SAE checkpoint at {path} missing key {key!r}; "
+                                f"got {list(sd.keys())}")
+         return cls(
+             W_enc=sd["W_enc"].to(device=device, dtype=dtype),
+             W_dec=sd["W_dec"].to(device=device, dtype=dtype),
+             b_enc=sd["b_enc"].to(device=device, dtype=dtype),
+             b_dec=sd["b_dec"].to(device=device, dtype=dtype),
+             k=k, layer=layer,
+         )
+
+     @property
+     def n_features(self) -> int:
+         return self.W_enc.shape[0]
+
+     @property
+     def d_model(self) -> int:
+         return self.W_enc.shape[1]
+
+     def encode(self, x: torch.Tensor) -> torch.Tensor:
+         """Encode residual-stream activations -> sparse feature codes (TopK)."""
+         x = x.to(device=self.W_enc.device, dtype=self.W_enc.dtype)
+         pre = F.linear(x, self.W_enc, self.b_enc)  # (..., n_features)
+         topk_vals, topk_idx = pre.topk(self.k, dim=-1)
+         z = torch.zeros_like(pre)
+         z.scatter_(-1, topk_idx, topk_vals)
+         return z
+
+     def decode(self, z: torch.Tensor) -> torch.Tensor:
+         z = z.to(device=self.W_dec.device, dtype=self.W_dec.dtype)
+         return F.linear(z, self.W_dec, self.b_dec)
+
+     def steering_vector(self, feature_id: int) -> torch.Tensor:
+         return self.W_dec[:, feature_id].clone()
+
+
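The TopK sparsification inside `SAE.encode` keeps only the `k` largest pre-activations per position and zeroes the rest. A NumPy sketch of just that step (illustrative only; the real method runs on torch tensors and the function name here is invented):

```python
import numpy as np

def topk_sparsify(pre: np.ndarray, k: int) -> np.ndarray:
    """Zero everything except the k largest entries along the last axis."""
    idx = np.argsort(pre, axis=-1)[..., -k:]  # indices of the k largest entries
    z = np.zeros_like(pre)
    np.put_along_axis(z, idx, np.take_along_axis(pre, idx, axis=-1), axis=-1)
    return z

pre = np.array([0.1, 2.0, -1.0, 0.5])
print(topk_sparsify(pre, 2))  # only 2.0 and 0.5 survive; the rest become 0
```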
+ # ---------------------------------------------------------------------------
+ # Hook helpers
+ # ---------------------------------------------------------------------------
+ def _layer_output_to_tensor(out):
+     """Qwen3DecoderLayer returns a bare torch.Tensor in transformers >= 5,
+     a tuple (hidden_states, ...) in transformers < 5. Handle both."""
+     if isinstance(out, tuple):
+         return out[0], out
+     return out, None
+
+
+ def _rebuild_layer_output(new_h: torch.Tensor, original_out):
+     if original_out is None:
+         return new_h
+     return (new_h, *original_out[1:])
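These two helpers form a round-trip: unwrap the hidden state from whichever shape the layer returned, then re-wrap a modified state in the same shape. A stdlib-only sketch with plain objects standing in for tensors:

```python
def layer_output_to_tensor(out):
    # Return (hidden_states, original_tuple_or_None) for either output shape.
    if isinstance(out, tuple):
        return out[0], out
    return out, None

def rebuild_layer_output(new_h, original_out):
    # Re-wrap a modified hidden state in the layer's original output shape.
    if original_out is None:
        return new_h
    return (new_h, *original_out[1:])

h, orig = layer_output_to_tensor(("hidden", "kv_cache"))
print(rebuild_layer_output("steered", orig))   # ('steered', 'kv_cache')
h2, orig2 = layer_output_to_tensor("hidden")
print(rebuild_layer_output("steered", orig2))  # 'steered'
```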
+
+
+ @contextlib.contextmanager
+ def capture_residual(model, layer_idx: int):
+     """Capture the residual-stream output of model.model.layers[layer_idx]."""
+     bucket: dict = {}
+     layer = model.model.layers[layer_idx]
+
+     def hook(_module, _inp, out):
+         h, _ = _layer_output_to_tensor(out)
+         bucket["h"] = h.detach()
+         return out
+
+     handle = layer.register_forward_hook(hook)
+     try:
+         yield bucket
+     finally:
+         handle.remove()
+
+
+ @contextlib.contextmanager
+ def steer(model, layer_idx: int, direction: torch.Tensor, alpha: float):
+     """Add `alpha * direction` to the residual-stream output of layer_idx
+     on every forward pass while the context is active."""
+     layer = model.model.layers[layer_idx]
+     direction = direction.detach()
+
+     def hook(_module, _inp, out):
+         h, original = _layer_output_to_tensor(out)
+         d = direction.to(device=h.device, dtype=h.dtype)
+         new_h = h + alpha * d
+         return _rebuild_layer_output(new_h, original)
+
+     handle = layer.register_forward_hook(hook)
+     try:
+         yield
+     finally:
+         handle.remove()
+
+
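Both context managers follow the same register/yield/always-remove pattern, so a hook can never outlive its `with` block even if generation raises. A stdlib sketch of that pattern with a toy registry (all names here are invented for illustration):

```python
import contextlib

class Layer:
    def __init__(self):
        self.hooks = []
    def register(self, fn):
        self.hooks.append(fn)
        return lambda: self.hooks.remove(fn)  # acts as the "handle"

@contextlib.contextmanager
def temporary_hook(layer, fn):
    remove = layer.register(fn)
    try:
        yield
    finally:
        remove()  # runs on normal exit *and* on exceptions

layer = Layer()
with temporary_hook(layer, print):
    print(len(layer.hooks))  # 1 while active
print(len(layer.hooks))      # 0 after the context closes
```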
+ # ---------------------------------------------------------------------------
+ # Pipeline
+ # ---------------------------------------------------------------------------
+ def read_top_features(model, tokenizer, sae: SAE, prompt: str,
+                       layer_idx: int, top_n: int = 10):
+     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+     with torch.no_grad(), capture_residual(model, layer_idx) as bucket:
+         model(**inputs)
+     h = bucket["h"]                 # (1, T, d_model) on model.device
+     h_last = h[0, -1].unsqueeze(0)  # (1, d_model) — encode() handles device/dtype
+     z = sae.encode(h_last)[0]
+     nonzero = z.nonzero(as_tuple=False).flatten()
+     vals = z[nonzero]
+     order = vals.argsort(descending=True)
+     top = nonzero[order][:top_n]
+     return [(int(f.item()), float(z[f].item())) for f in top]
+
+
+ def generate(model, tokenizer, prompt: str, max_new_tokens: int = 40):
+     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+     with torch.no_grad():
+         out = model.generate(
+             **inputs,
+             max_new_tokens=max_new_tokens,
+             do_sample=False,  # deterministic for A/B comparison
+             pad_token_id=tokenizer.eos_token_id,
+         )
+     return tokenizer.decode(out[0], skip_special_tokens=True)
+
+
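`read_top_features` only ranks the nonzero entries of the sparse code, which is cheap because a TopK SAE leaves at most `k` of them. A NumPy sketch of that ranking step (the torch version uses the same `nonzero` then `argsort(descending)` idea):

```python
import numpy as np

z = np.array([0.0, 1.2, 0.0, 3.4, 0.0, 0.7])  # sparse TopK feature codes
nonzero = np.flatnonzero(z)                    # indices of firing features: [1, 3, 5]
order = np.argsort(z[nonzero])[::-1]           # sort firing features, descending
top = nonzero[order][:2]                       # keep the top 2

print([(int(f), float(z[f])) for f in top])    # [(3, 3.4), (1, 1.2)]
```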
+ # ---------------------------------------------------------------------------
+ # CLI
+ # ---------------------------------------------------------------------------
+ def parse_topk_from_repo(repo: str) -> int:
+     # e.g. "Qwen/SAE-Res-Qwen3-1.7B-Base-W32K-L0_50" -> 50
+     parts = repo.rsplit("L0_", 1)
+     if len(parts) == 2 and parts[1].isdigit():
+         return int(parts[1])
+     return 50
+
+
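The repo-name parsing can be exercised standalone; it falls back to 50 whenever the `L0_<digits>` suffix is absent or malformed (the function is reproduced here verbatim so the sketch is self-contained):

```python
def parse_topk_from_repo(repo: str) -> int:
    # "...-L0_50" -> 50; fall back to 50 when the suffix is absent/malformed.
    parts = repo.rsplit("L0_", 1)
    if len(parts) == 2 and parts[1].isdigit():
        return int(parts[1])
    return 50

print(parse_topk_from_repo("Qwen/SAE-Res-Qwen3-1.7B-Base-W32K-L0_50"))  # 50
print(parse_topk_from_repo("some/other-repo"))                          # 50 (fallback)
print(parse_topk_from_repo("org/sae-L0_128"))                           # 128
```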
+ def main():
+     ap = argparse.ArgumentParser()
+     ap.add_argument("--model", default="Qwen/Qwen3-1.7B-Base")
+     ap.add_argument("--sae-repo", default="Qwen/SAE-Res-Qwen3-1.7B-Base-W32K-L0_50")
+     ap.add_argument("--layer", type=int, default=14)
+     ap.add_argument("--prompt", default="The capital of France is")
+     ap.add_argument("--max-new-tokens", type=int, default=40)
+     ap.add_argument("--alpha", type=float, default=-10.0,
+                     help="Steering magnitude. Negative suppresses, positive amplifies.")
+     ap.add_argument("--suppress-rank", type=int, default=0,
+                     help="Which top-firing feature (0 = strongest) to steer.")
+     ap.add_argument("--feature-id", type=int, default=None,
+                     help="Override: steer this exact feature instead of a top-rank pick.")
+     ap.add_argument("--topk", type=int, default=None,
+                     help="Override SAE TopK (auto-detected from repo name).")
+     ap.add_argument("--device", default=None,
+                     help="cuda | mps | cpu (auto if omitted)")
+     ap.add_argument("--dtype", default="bfloat16",
+                     choices=["bfloat16", "float16", "float32"])
+     args = ap.parse_args()
+
+     if args.device is None:
+         if torch.cuda.is_available():
+             args.device = "cuda"
+         elif torch.backends.mps.is_available():
+             args.device = "mps"
+         else:
+             args.device = "cpu"
+     dtype = {"bfloat16": torch.bfloat16, "float16": torch.float16,
+              "float32": torch.float32}[args.dtype]
+
+     print(f"[load] model={args.model} device={args.device} dtype={args.dtype}")
+     tokenizer = AutoTokenizer.from_pretrained(args.model)
+     model = AutoModelForCausalLM.from_pretrained(
+         args.model, dtype=dtype, device_map=args.device,
+     )
+     model.eval()
+
+     n_layers = len(model.model.layers)
+     if not (0 <= args.layer < n_layers):
+         raise ValueError(f"--layer {args.layer} out of range; model has {n_layers} layers")
+     hidden = model.config.hidden_size
+     print(f"[load] {type(model).__name__}: {n_layers} layers, hidden={hidden}")
+
+     k = args.topk or parse_topk_from_repo(args.sae_repo)
+     print(f"[load] SAE repo={args.sae_repo} layer={args.layer} K={k}")
+     sae = SAE.from_repo(args.sae_repo, layer=args.layer, k=k,
+                         device=args.device, dtype=dtype)
+     if sae.d_model != hidden:
+         raise ValueError(f"SAE d_model={sae.d_model} != model hidden_size={hidden}; "
+                          f"this SAE doesn't match this model.")
+
+     # 1. Top features for the prompt
+     print(f"\n[features] top firing at layer {args.layer} for prompt: {args.prompt!r}")
+     top = read_top_features(model, tokenizer, sae, args.prompt, args.layer, top_n=10)
+     for rank, (fid, act) in enumerate(top):
+         print(f"  rank {rank:2d}  feature {fid:>6d}  act={act:+.4f}")
+
+     # Pick steering target
+     if args.feature_id is not None:
+         target_id = args.feature_id
+     else:
+         target_id = top[args.suppress_rank][0]
+
+     # 2. Baseline generation
+     print("\n[baseline] generating (no steering)...")
+     baseline = generate(model, tokenizer, args.prompt, args.max_new_tokens)
+     print(f"  >>> {baseline!r}")
+
+     # 3. Steered generation
+     print(f"\n[steer] feature {target_id} at layer {args.layer} with alpha={args.alpha}")
+     direction = sae.steering_vector(target_id)
+     with steer(model, args.layer, direction, args.alpha):
+         steered = generate(model, tokenizer, args.prompt, args.max_new_tokens)
+     print(f"  >>> {steered!r}")
+
+     # 4. Verify the steering actually moved the feature
+     inputs = tokenizer(args.prompt, return_tensors="pt").to(model.device)
+     with torch.no_grad(), capture_residual(model, args.layer) as bucket:
+         model(**inputs)
+     base_act = sae.encode(bucket["h"][0, -1].unsqueeze(0))[0, target_id].item()
+     with torch.no_grad(), steer(model, args.layer, direction, args.alpha), \
+          capture_residual(model, args.layer) as bucket:
+         model(**inputs)
+     steered_act = sae.encode(bucket["h"][0, -1].unsqueeze(0))[0, target_id].item()
+     print(f"\n[verify] feature {target_id} activation: baseline={base_act:+.4f} "
+           f"steered={steered_act:+.4f} delta={steered_act - base_act:+.4f}")
+     if args.alpha > 0 and steered_act <= base_act:
+         print("  WARN: alpha>0 but activation didn't go up — unexpected.")
+     if args.alpha < 0 and steered_act >= base_act:
+         print("  WARN: alpha<0 but activation didn't go down — unexpected.")
+
+
+ if __name__ == "__main__":
+     main()
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ --extra-index-url https://download.pytorch.org/whl/cpu
+ torch>=2.5.0
+ transformers>=4.51.0
+ accelerate>=1.0
+ huggingface_hub>=0.25
+ safetensors>=0.4
+ numpy>=1.26
+ fastapi>=0.110
+ uvicorn[standard]>=0.30
+ pydantic>=2.5
server.py ADDED
@@ -0,0 +1,817 @@
+ """FastAPI backend for the Qwen-Scope HF Space deployment.
+
+ Locked to Qwen3-1.7B-Base + the W32K-L0_50 SAE so it fits inside a
+ free-tier HF Space (CPU, ~16GB RAM). Layer is still selectable.
+ """
+ from __future__ import annotations
+
+ import gc
+ import json
+ import os
+ import threading
+ from contextlib import asynccontextmanager
+ from pathlib import Path
+
+ import numpy as np
+ import torch
+ from fastapi import FastAPI, HTTPException
+ from fastapi.middleware.cors import CORSMiddleware
+ from fastapi.responses import FileResponse
+ from pydantic import BaseModel
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ from qwen_scope_steer import SAE, capture_residual, steer
+
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+ DTYPE = torch.float32  # bf16 on CPU is slow + flaky on free-tier hardware
+
+ POSITIONS_DIR = Path(os.environ.get(
+     "POSITIONS_DIR",
+     str(Path(__file__).parent / "feature_positions"),
+ ))
+ POSITIONS_DIR.mkdir(exist_ok=True, parents=True)
+
+
+ # ---------------------------------------------------------------------------
+ # Catalog of supported model + SAE pairs.
+ # Verified against the Qwen org HF listing. For Qwen3.6 (no native SAE yet)
+ # we point at the Qwen3.5 SAE that matches dimensions; this is a best-effort
+ # fallback flagged as transferred=True in the response.
+ # ---------------------------------------------------------------------------
+ MODEL_CATALOG = [
+     {
+         "model": "Qwen/Qwen3-1.7B-Base",
+         "sae_repo": "Qwen/SAE-Res-Qwen3-1.7B-Base-W32K-L0_50",
+         "default_layer": 14, "n_layers": 28, "n_features": 32768,
+         "approx_size_gb": 3.4, "k": 50, "transferred": False,
+     },
+ ]
+
+
+ # ---------------------------------------------------------------------------
+ # State and locks
+ # ---------------------------------------------------------------------------
+ state: dict = {}
+ load_lock = threading.Lock()
+
+
+ def _find_decoder_layers(model):
+     """Return (layers_module_list, dotted_path) for any qwen3 / qwen3_5 model.
+
+     Handles:
+       * model.model.layers (standard Qwen3*ForCausalLM)
+       * model.language_model.model.layers (multimodal Qwen3_5ForConditionalGeneration)
+     """
+     for path in (("model", "model", "layers"),
+                  ("model", "layers"),
+                  ("language_model", "model", "layers"),
+                  ("model", "language_model", "model", "layers")):
+         obj = model
+         ok = True
+         for p in path:
+             if not hasattr(obj, p):
+                 ok = False
+                 break
+             obj = getattr(obj, p)
+         if ok and hasattr(obj, "__len__") and len(obj) > 0:
+             return obj, ".".join(path)
+     raise RuntimeError(f"could not locate decoder layers on "
+                        f"{type(model).__name__}")
+
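The probe above is plain attribute-path walking: try each candidate dotted path and keep the first that resolves to a non-empty container. A stdlib sketch of the same idea (the helper name and the toy namespace are invented):

```python
from types import SimpleNamespace

def walk_path(obj, path):
    # Follow a tuple of attribute names; return None if any link is missing.
    for name in path:
        if not hasattr(obj, name):
            return None
        obj = getattr(obj, name)
    return obj

m = SimpleNamespace(model=SimpleNamespace(layers=["L0", "L1"]))
print(walk_path(m, ("model", "layers")))           # ['L0', 'L1']
print(walk_path(m, ("language_model", "layers")))  # None (path doesn't exist)
```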
+
+ # ---------------------------------------------------------------------------
+ # Position computation (cached per SAE).
+ # Uses randomized (truncated) SVD via NumPy power iteration for the
+ # 80K-feature SAE, economy SVD for smaller ones. Good enough for
+ # visualization layout.
+ # ---------------------------------------------------------------------------
+ def _safe_filename(s: str) -> str:
+     return s.replace("/", "__")
+
+
+ def _positions_path(sae_repo: str, layer: int | None = None) -> Path:
+     if layer is None:
+         return POSITIONS_DIR / f"{_safe_filename(sae_repo)}.json"
+     return POSITIONS_DIR / f"{_safe_filename(sae_repo)}__L{layer}.json"
+
+
+ def compute_positions(W_enc: torch.Tensor) -> list[list[float]]:
+     X = W_enc.detach().to("cpu", torch.float32).numpy()  # (n_features, d_model)
+     X = X - X.mean(axis=0, keepdims=True)
+     n, d = X.shape
+     if n * d <= 32768 * 4096:
+         # Economy SVD is fine for the smaller SAEs.
+         _, _, Vt = np.linalg.svd(X, full_matrices=False)
+         pos = X @ Vt[:3].T
+     else:
+         # Randomized SVD for very large SAEs (e.g. 80K x 5120).
+         rng = np.random.default_rng(0)
+         Q = rng.standard_normal((d, 8)).astype(np.float32)
+         for _ in range(3):  # power iterations
+             Q = X.T @ (X @ Q)
+             Q, _ = np.linalg.qr(Q)
+         Y = X @ Q  # (n, 8)
+         _, _, Vt2 = np.linalg.svd(Y, full_matrices=False)
+         pos = Y @ Vt2[:3].T
+     pos = pos / max(abs(pos.min()), abs(pos.max()))
+     return pos.tolist()
+
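In the economy-SVD branch the 3D layout is just the top-3 principal-component scores of the centered encoder matrix, rescaled into [-1, 1]. A tiny self-contained check of that branch on random data (the shapes here are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((50, 6)).astype(np.float32)
X = X - X.mean(axis=0, keepdims=True)          # center, as compute_positions does

_, _, Vt = np.linalg.svd(X, full_matrices=False)
pos = X @ Vt[:3].T                             # top-3 principal-component scores
pos = pos / max(abs(pos.min()), abs(pos.max()))

print(pos.shape)                 # (50, 3)
print(float(np.abs(pos).max()))  # 1.0 (normalized into [-1, 1])
```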
+
+ def load_or_compute_positions(W_enc: torch.Tensor, sae_repo: str,
+                               layer: int | None = None) -> list[list[float]]:
+     # Try the layer-specific cache first; fall back to the legacy
+     # SAE-repo-only cache so existing files don't go stale.
+     p_layer = _positions_path(sae_repo, layer)
+     p_legacy = _positions_path(sae_repo)
+     for p in (p_layer, p_legacy):
+         if p.exists():
+             try:
+                 return json.loads(p.read_text())["positions"]
+             except Exception:
+                 pass
+     pos = compute_positions(W_enc)
+     p_layer.write_text(json.dumps({"positions": pos}))
+     return pos
+
+
+ # ---------------------------------------------------------------------------
+ # Model + SAE loading (called both at startup and on /load_model)
+ # ---------------------------------------------------------------------------
+ def _free_current_state():
+     """Release the currently loaded model + SAE so a new one can fit."""
+     for k in ("model", "tokenizer", "sae", "layers"):
+         if k in state:
+             del state[k]
+     gc.collect()
+     if hasattr(torch, "mps") and torch.backends.mps.is_available():
+         try:
+             torch.mps.empty_cache()
+         except Exception:
+             pass
+     if torch.cuda.is_available():
+         try:
+             torch.cuda.empty_cache()
+         except Exception:
+             pass
+
+
+ def _catalog_entry(model_name: str, sae_repo: str | None) -> dict:
+     """Find the catalog row that matches model_name (and optionally sae_repo)."""
+     for row in MODEL_CATALOG:
+         if row["model"] == model_name and (sae_repo is None or row["sae_repo"] == sae_repo):
+             return row
+     raise HTTPException(status_code=400,
+                         detail=f"unknown model/sae combination: {model_name} / {sae_repo}")
+
+
+ def load_state(model_name: str, sae_repo: str | None = None,
+                layer: int | None = None, k: int = 50) -> dict:
+     """Replace the loaded model + SAE + layer with the requested one."""
+     entry = _catalog_entry(model_name, sae_repo)
+     sae_repo = entry["sae_repo"]
+     layer = entry["default_layer"] if layer is None else int(layer)
+     k = entry.get("k", k)
+
+     print(f"[load] {model_name} ({entry['approx_size_gb']:.0f}GB) "
+           f"+ SAE {sae_repo} layer {layer} on {DEVICE}")
+
+     _free_current_state()
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+     model = AutoModelForCausalLM.from_pretrained(
+         model_name, dtype=DTYPE, device_map=DEVICE,
+     )
+     model.eval()
+
+     layers, layers_path = _find_decoder_layers(model)
+     n_layers = len(layers)
+     if not (0 <= layer < n_layers):
+         layer = min(max(0, layer), n_layers - 1)
+
+     print(f"[load] model loaded: {type(model).__name__}, layers at "
+           f"'{layers_path}', n={n_layers}")
+
+     sae = SAE.from_repo(sae_repo, layer=layer, k=k, device=DEVICE, dtype=DTYPE)
+     print(f"[load] SAE loaded: n_features={sae.n_features}, d_model={sae.d_model}")
+
+     print("[load] computing/loading 3D feature positions")
+     positions = load_or_compute_positions(sae.W_enc, sae_repo, layer)
+     _sae_cache_put(sae_repo, layer, sae)
+
+     state.update(
+         model=model, tokenizer=tokenizer, sae=sae,
+         layers=layers, layers_path=layers_path,
+         positions=positions, n_layers=n_layers,
+         current_model=model_name, current_sae=sae_repo,
+         current_layer=layer, current_k=k,
+         catalog_entry=entry,
+     )
+     print("[load] ready")
+     return state
+
+
+ # ---------------------------------------------------------------------------
+ # Hook helpers — work against state["layers"], not model.model.layers
+ # ---------------------------------------------------------------------------
+ import contextlib
+
+
+ @contextlib.contextmanager
+ def _capture_at(layer_module):
+     bucket = {}
+
+     def hook(_m, _i, out):
+         h = out[0] if isinstance(out, tuple) else out
+         bucket["h"] = h.detach()
+         return out
+
+     handle = layer_module.register_forward_hook(hook)
+     try:
+         yield bucket
+     finally:
+         handle.remove()
+
+
+ @contextlib.contextmanager
+ def _steer_at(layer_module, direction, alpha, *,
+               positions=None, output_only=False, prompt_len=None):
+     """Hook adds α·direction to the layer residual, with position/decode controls.
+
+     positions   : None or "all" → every token; list[int] → only those absolute
+                   token indices (works across prefill + decode).
+     output_only : if True, only steer during decode (skip prefill entirely).
+     prompt_len  : length of the prompt; needed to map the decode-step counter
+                   to an absolute position when positions is a list.
+     """
+     direction = direction.detach()
+     counter = [0]
+     pos_set = set(positions) if isinstance(positions, (list, set)) else None
+
+     def hook(_m, _i, out):
+         h = out[0] if isinstance(out, tuple) else out
+         d = direction.to(device=h.device, dtype=h.dtype)
+         cur = counter[0]
+         counter[0] += 1
+         new_h = h
+         is_prefill = (cur == 0)
+
+         if is_prefill:
+             seq = h.shape[1]
+             if output_only:
+                 pass  # leave the prompt untouched
+             elif pos_set is None:
+                 new_h = h + alpha * d
+             else:
+                 new_h = h.clone()
+                 for p in pos_set:
+                     if 0 <= p < seq:
+                         new_h[:, p, :] = new_h[:, p, :] + alpha * d
+         else:
+             # Decode step — h is [batch, 1, hidden] (one new token)
+             cur_pos = (prompt_len or 0) + cur - 1
+             if pos_set is None or output_only or (cur_pos in pos_set):
+                 new_h = h + alpha * d
+
+         return (new_h, *out[1:]) if isinstance(out, tuple) else new_h
+
+     handle = layer_module.register_forward_hook(hook)
+     try:
+         yield
+     finally:
+         handle.remove()
+
+
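The subtle part of `_steer_at` is mapping the forward-pass counter to token positions: pass 0 is prefill over the whole prompt, and decode pass `c` writes absolute position `prompt_len + c - 1`. A stdlib sketch that enumerates which positions a given spec would steer (the function name is invented; it mirrors the hook's branching, not its tensor math):

```python
def steered_positions(prompt_len, n_new, positions=None, output_only=False):
    """Return the absolute token positions that would receive alpha * direction."""
    steered = []
    for cur in range(n_new + 1):
        if cur == 0:                              # prefill: all prompt slots at once
            if output_only:
                continue
            for p in range(prompt_len):
                if positions is None or p in positions:
                    steered.append(p)
        else:                                     # decode: one new token per pass
            cur_pos = prompt_len + cur - 1
            if positions is None or output_only or cur_pos in positions:
                steered.append(cur_pos)
    return steered

print(steered_positions(3, 2))                    # [0, 1, 2, 3, 4] — everything
print(steered_positions(3, 2, output_only=True))  # [3, 4] — decode only
print(steered_positions(3, 2, positions={1, 3}))  # [1, 3] — explicit slots
```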
+ def _parse_positions(s: str | None):
+     """Parse '3', '3-7', '0,2,5-8', 'all', or None into a position spec.
+     Returns 'all' or a list[int] (or None if the input is empty/None)."""
+     if s is None or not str(s).strip():
+         return None
+     s = str(s).strip().lower()
+     if s == "all":
+         return "all"
+     out: list[int] = []
+     for part in s.split(","):
+         part = part.strip()
+         if not part:
+             continue
+         if "-" in part:
+             try:
+                 lo, hi = part.split("-", 1)
+                 out.extend(range(int(lo), int(hi) + 1))
+             except ValueError:
+                 continue
+         else:
+             try:
+                 out.append(int(part))
+             except ValueError:
+                 continue
+     return sorted(set(out)) if out else None
+
+
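The spec grammar accepts single indices, inclusive ranges, comma lists, and the keyword `all`. A compact standalone variant showing the expected parses (this sketch mirrors the server's behavior but is not the same code; it uses `isdigit` checks instead of try/except):

```python
def parse_positions(s):
    # '', None -> None; 'all' -> 'all'; comma list with ranges -> sorted unique ints.
    if s is None or not str(s).strip():
        return None
    s = str(s).strip().lower()
    if s == "all":
        return "all"
    out = []
    for part in s.split(","):
        part = part.strip()
        if "-" in part:
            lo, _, hi = part.partition("-")
            if lo.strip().isdigit() and hi.strip().isdigit():
                out.extend(range(int(lo), int(hi) + 1))
        elif part.isdigit():
            out.append(int(part))
    return sorted(set(out)) or None

print(parse_positions("0,2,5-8"))  # [0, 2, 5, 6, 7, 8]
print(parse_positions("all"))      # all
print(parse_positions("  "))       # None
```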
+ def _hook_stack(layer_module, sae, specs, prompt_len=None):
+     from contextlib import ExitStack
+     stack = ExitStack()
+     for s in specs:
+         d = sae.steering_vector(s.id)
+         positions = _parse_positions(getattr(s, "positions", None))
+         output_only = bool(getattr(s, "output_only", False))
+         # "all" and None both mean "every position" inside _steer_at — pass None.
+         eff_positions = None if (positions is None or positions == "all") else positions
+         stack.enter_context(_steer_at(
+             layer_module, d, s.alpha,
+             positions=eff_positions,
+             output_only=output_only,
+             prompt_len=prompt_len,
+         ))
+     return stack
+
+
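Stacking several steering hooks at once relies on `contextlib.ExitStack`, which unwinds its contexts in reverse order on exit, so every hook is removed even if one of them raises. A stdlib sketch of that composition (names here are illustrative):

```python
from contextlib import ExitStack, contextmanager

log = []

@contextmanager
def tagged(name):
    log.append(f"enter {name}")
    try:
        yield
    finally:
        log.append(f"exit {name}")

with ExitStack() as stack:
    for name in ("feat_12", "feat_907"):
        stack.enter_context(tagged(name))
print(log)  # ['enter feat_12', 'enter feat_907', 'exit feat_907', 'exit feat_12']
```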
+ # ---------------------------------------------------------------------------
+ # Lifespan + app
+ # ---------------------------------------------------------------------------
+ @asynccontextmanager
+ async def lifespan(app: FastAPI):
+     # Default startup: small model so the demo is interactive immediately.
+     load_state("Qwen/Qwen3-1.7B-Base")
+     yield
+     state.clear()
+
+
+ app = FastAPI(lifespan=lifespan)
+ app.add_middleware(CORSMiddleware, allow_origins=["*"],
+                    allow_methods=["*"], allow_headers=["*"])
+
+
+ # ---------------------------------------------------------------------------
+ # Request models
+ # ---------------------------------------------------------------------------
+ class EncodeRequest(BaseModel):
+     prompt: str
+     top_n: int = 20
+
+
+ class SteerSpec(BaseModel):
+     id: int
+     alpha: float
+     positions: str | None = None  # "all" | "3-7" | "0,2,5" | None (= all)
+     output_only: bool = False     # if True, steer only during decode, not the prompt
+
+
+ class GenerateRequest(BaseModel):
+     prompt: str
+     steering: list[SteerSpec] = []
+     max_new_tokens: int = 40
+     return_probs: bool = False  # if True, return per-token softmax + top-K candidates
+     topk_display: int = 8       # number of candidate tokens to expose per step
+
+
+ class LoadModelRequest(BaseModel):
+     model: str
+     sae_repo: str | None = None
+     layer: int | None = None
+
+
+ class SetLayerRequest(BaseModel):
+     layer: int
+
+
+ # ---------------------------------------------------------------------------
+ # Routes
+ # ---------------------------------------------------------------------------
+ @app.get("/")
+ def index():
+     return FileResponse(Path(__file__).parent / "index.html")
+
+
+ @app.get("/health")
+ def health():
+     sae = state.get("sae")
+     return {
+         "ok": True,
+         "model": state.get("current_model"),
+         "sae": state.get("current_sae"),
+         "layer": state.get("current_layer"),
+         "device": DEVICE,
+         "dtype": str(DTYPE).replace("torch.", ""),
+         "n_features": sae.n_features if sae else None,
+         "n_layers": state.get("n_layers"),
+         "transferred": state.get("catalog_entry", {}).get("transferred", False),
+         "note": state.get("catalog_entry", {}).get("note", ""),
+     }
+
+
+ @app.get("/list_models")
+ def list_models():
+     return {"models": MODEL_CATALOG}
+
+
+ @app.post("/load_model")
+ def load_model(req: LoadModelRequest):
+     with load_lock:
+         try:
+             load_state(req.model, req.sae_repo, req.layer)
+         except HTTPException:
+             raise
+         except Exception as e:
+             raise HTTPException(status_code=500, detail=f"load failed: {e}")
+     sae = state["sae"]
+     return {
+         "ok": True,
+         "model": state["current_model"],
+         "sae": state["current_sae"],
+         "layer": state["current_layer"],
+         "n_features": sae.n_features,
+         "n_layers": state["n_layers"],
+         "transferred": state["catalog_entry"].get("transferred", False),
+         "note": state["catalog_entry"].get("note", ""),
+         "positions": state["positions"],
+     }
+
+
+ # In-memory LRU cache of recently-used SAE checkpoints, keyed by
+ # (sae_repo, layer). Each SAE for the 1.7B model is ~537 MB on disk and
+ # similar in RAM at fp32; for the 27B SAE it's ~3.3 GB. Cap conservatively.
+ _sae_lru: "OrderedDict[tuple[str, int], SAE] | None" = None  # initialized lazily
+ SAE_LRU_MAX = 6
+
+
+ def _sae_cache_get(sae_repo: str, layer: int):
+     global _sae_lru
+     if _sae_lru is None:
+         from collections import OrderedDict
+         _sae_lru = OrderedDict()
+     key = (sae_repo, layer)
+     if key in _sae_lru:
+         _sae_lru.move_to_end(key)
+         return _sae_lru[key]
+     return None
+
+
+ def _sae_cache_put(sae_repo: str, layer: int, sae: SAE):
+     global _sae_lru
+     if _sae_lru is None:
+         from collections import OrderedDict
+         _sae_lru = OrderedDict()
+     key = (sae_repo, layer)
+     _sae_lru[key] = sae
+     _sae_lru.move_to_end(key)
+     while len(_sae_lru) > SAE_LRU_MAX:
+         _sae_lru.popitem(last=False)
+
+
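The LRU logic is the standard `OrderedDict` idiom: `move_to_end` on every hit or insert marks recency, and `popitem(last=False)` evicts the stalest entry. A tiny sketch of the eviction order (keys and cap are illustrative):

```python
from collections import OrderedDict

lru: "OrderedDict[str, int]" = OrderedDict()
MAX = 2

def put(key, value):
    lru[key] = value
    lru.move_to_end(key)         # mark most-recently-used
    while len(lru) > MAX:
        lru.popitem(last=False)  # evict least-recently-used

put("L10", 0)
put("L14", 1)
lru.move_to_end("L10")           # a cache hit refreshes recency
put("L20", 2)                    # evicts "L14", the stalest entry
print(list(lru))  # ['L10', 'L20']
```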
+ @app.post("/set_layer")
+ def set_layer(req: SetLayerRequest):
+     """Hot-swap the active SAE to a different layer of the same SAE repo.
+
+     Keeps the model loaded; just downloads (or fetches from cache) the new
+     layer's SAE checkpoint. Recomputes 3D positions for the new SAE
+     (cached on disk per SAE-repo + layer).
+     """
+     if "model" not in state:
+         raise HTTPException(status_code=400, detail="no model loaded")
+     n_layers = state["n_layers"]
+     layer = int(req.layer)
+     if not (0 <= layer < n_layers):
+         raise HTTPException(status_code=400,
+                             detail=f"layer must be in [0, {n_layers-1}]")
+     sae_repo = state["current_sae"]
+     if layer == state["current_layer"]:
+         return {"ok": True, "unchanged": True,
+                 "layer": layer, "n_features": state["sae"].n_features,
+                 "positions": state["positions"]}
+
+     with load_lock:
+         # 1. SAE itself — try the LRU first
+         cached = _sae_cache_get(sae_repo, layer)
+         if cached is not None:
+             sae = cached
+             print(f"[layer-swap] SAE {sae_repo} layer {layer} from LRU cache")
+         else:
+             print(f"[layer-swap] loading SAE {sae_repo} layer {layer}")
+             k = state["catalog_entry"].get("k", 50)
+             sae = SAE.from_repo(sae_repo, layer=layer, k=k,
+                                 device=DEVICE, dtype=DTYPE)
+             _sae_cache_put(sae_repo, layer, sae)
+
+         # 2. Positions — per-layer cache file on disk
+         positions_key = f"{sae_repo}__L{layer}"
+         p = POSITIONS_DIR / f"{_safe_filename(positions_key)}.json"
+         if p.exists():
+             try:
+                 positions = json.loads(p.read_text())["positions"]
+             except Exception:
+                 positions = compute_positions(sae.W_enc)
+                 p.write_text(json.dumps({"positions": positions}))
+         else:
+             print(f"[layer-swap] computing positions for layer {layer}")
+             positions = compute_positions(sae.W_enc)
+             p.write_text(json.dumps({"positions": positions}))
+
+         state["sae"] = sae
+         state["current_layer"] = layer
+         state["positions"] = positions
+
+     return {
+         "ok": True,
+         "layer": layer,
+         "n_features": sae.n_features,
+         "positions": positions,
+         "from_cache": cached is not None,
+     }
+
+
+ @app.get("/positions")
+ def positions():
+     return {"positions": state["positions"]}
+
+
+ @app.post("/encode")
+ def encode(req: EncodeRequest):
+     model, tokenizer, sae = state["model"], state["tokenizer"], state["sae"]
+     layer_module = state["layers"][state["current_layer"]]
+     inputs = tokenizer(req.prompt, return_tensors="pt").to(model.device)
+     with torch.no_grad(), _capture_at(layer_module) as bucket:
+         model(**inputs)
+     h_last = bucket["h"][0, -1].unsqueeze(0)
+     z = sae.encode(h_last)[0]
+     nz = z.nonzero(as_tuple=False).flatten()
+     vals = z[nz]
+     order = vals.argsort(descending=True)[:req.top_n]
+     top = [{"id": int(nz[i].item()), "act": float(vals[i].item())} for i in order]
+     return {"top": top, "n_features": sae.n_features}
+
+
+ class EncodeFullRequest(BaseModel):
+     prompt: str
+     top_n: int = 16  # number of feature ROWS to return in the heatmap
+
+
+ @app.post("/encode_full")
+ def encode_full(req: EncodeFullRequest):
+     """Return a per-token feature-activation grid for a single prompt.
+
+     Picks the top_n features ranked by *mean activation across all token
+     positions* (matches the official app.py heatmap definition), then returns
+     each feature's activation at every token position. Activations that
+     didn't make TopK at a given position are zero.
+     """
+     model, tokenizer, sae = state["model"], state["tokenizer"], state["sae"]
+     layer_module = state["layers"][state["current_layer"]]
+     inputs = tokenizer(req.prompt, return_tensors="pt").to(model.device)
+     with torch.no_grad(), _capture_at(layer_module) as bucket:
+         model(**inputs)
+     h = bucket["h"][0]  # (seq_len, d_model)
+     z = sae.encode(h)   # (seq_len, n_features) sparse TopK
+     seq_len = z.shape[0]
+     # Token strings for column headers
+     ids = inputs["input_ids"][0].tolist()
+     tokens = [tokenizer.decode([t], skip_special_tokens=False) for t in ids]
+
+     # Rank features by mean activation across all positions
+     mean_per_feat = z.mean(dim=0)
+     top_vals, top_idx = mean_per_feat.topk(min(int(req.top_n), sae.n_features))
+     grid = z[:, top_idx]  # (seq_len, top_n)
+     return {
+         "tokens": tokens,
+         "feature_ids": [int(i.item()) for i in top_idx],
+         "mean_acts": [float(v.item()) for v in top_vals],
+         # grid: outer list = features, inner list = positions (transposed for
+         # natural row-per-feature rendering in the UI)
+         "grid": [[float(grid[p, f].item()) for p in range(seq_len)]
+                  for f in range(grid.shape[1])],
+         "seq_len": seq_len,
+         "n_features": sae.n_features,
+     }
+
+
+ class EncodeBatchRequest(BaseModel):
+     prompts: list[str]
+     top_n: int = 20  # top features per prompt to return individually
+
+
+ @app.post("/encode_batch")
+ def encode_batch(req: EncodeBatchRequest):
+     """Encode N prompts and return per-sample top features + corpus-level stats.
+
+     For each prompt: encode the last-token residual through the SAE and return
+     its top_n firing features. Corpus level: the union of features that fired
+     at all, with per-feature firing rate (fraction of prompts where the
+     feature appeared) and mean activation.
+     """
+     if not req.prompts:
+         return {"per_sample": [], "corpus_features": [], "n_features": state["sae"].n_features}
+     model, tokenizer, sae = state["model"], state["tokenizer"], state["sae"]
+     layer_module = state["layers"][state["current_layer"]]
+
+     per_sample = []
+     union_act_sum: dict[int, float] = {}
+     union_count: dict[int, int] = {}
+
+     for idx, p in enumerate(req.prompts):
+         inputs = tokenizer(p, return_tensors="pt").to(model.device)
+         with torch.no_grad(), _capture_at(layer_module) as bucket:
+             model(**inputs)
+         h_last = bucket["h"][0, -1].unsqueeze(0)
+         z = sae.encode(h_last)[0]
+         nz = z.nonzero(as_tuple=False).flatten()
+         vals = z[nz]
+         order = vals.argsort(descending=True)
+         top_idx = nz[order][:req.top_n]
+         top = [{"id": int(top_idx[i].item()), "act": float(z[top_idx[i]].item())}
+                for i in range(len(top_idx))]
+         per_sample.append({
+             "i": idx,
+             "preview": p[:80] + ("…" if len(p) > 80 else ""),
+             "len": len(p),
+             "top": top,
+             "n_active": int(len(nz)),
+         })
+         # Union stats over ALL nonzero features, not just the top_n
+         for fid, v in zip(nz.tolist(), vals.tolist()):
+             union_count[fid] = union_count.get(fid, 0) + 1
+             union_act_sum[fid] = union_act_sum.get(fid, 0.0) + float(v)
+
+     n = len(req.prompts)
+     corpus = []
+     for fid, cnt in union_count.items():
+         corpus.append({
+             "id": fid,
+             "fire_rate": cnt / n,
+             "mean_act": union_act_sum[fid] / cnt,
+             "n_samples": cnt,
+         })
+     # Sort by fire_rate desc, then mean_act desc
+     corpus.sort(key=lambda r: (-r["fire_rate"], -r["mean_act"]))
+     return {
+         "per_sample": per_sample,
+         "corpus_features": corpus[:200],  # cap to the 200 most frequent
+         "n_features": sae.n_features,
+         "n_samples": n,
+     }
+
+
647
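The corpus-level aggregation above reduces to two running dictionaries. A minimal sketch using toy sparse activations (feature id → value dicts standing in for the SAE's nonzero outputs):

```python
# Toy per-prompt sparse activations: {feature_id: activation}.
samples = [
    {7: 1.0, 12: 0.5},
    {7: 3.0},
    {12: 0.5, 99: 2.0},
]

union_count: dict = {}
union_act_sum: dict = {}
for z in samples:
    for fid, v in z.items():
        union_count[fid] = union_count.get(fid, 0) + 1
        union_act_sum[fid] = union_act_sum.get(fid, 0.0) + v

n = len(samples)
corpus = [
    {"id": fid, "fire_rate": cnt / n, "mean_act": union_act_sum[fid] / cnt}
    for fid, cnt in union_count.items()
]
# Same tie-break as the endpoint: fire_rate desc, then mean_act desc.
corpus.sort(key=lambda r: (-r["fire_rate"], -r["mean_act"]))

print([r["id"] for r in corpus])  # [7, 12, 99]
```

Feature 7 and 12 both fire in 2 of 3 prompts, so the mean-activation tie-break puts 7 (mean 2.0) ahead of 12 (mean 0.5).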
+class CompareBatchRequest(BaseModel):
+    prompts_a: list[str]
+    prompts_b: list[str]
+    top_n: int = 30
+
+
+@app.post("/compare_batch")
+def compare_batch(req: CompareBatchRequest):
+    """Differential feature mining between two prompt sets.
+
+    For each set: encode all prompts, compute per-feature firing rate
+    (fraction of prompts where the feature fires) and mean activation.
+    Rank features by |fire_rate_A − fire_rate_B|.
+    Returns the top features that distinguish A from B.
+    """
+    model, tokenizer, sae = state["model"], state["tokenizer"], state["sae"]
+    layer_module = state["layers"][state["current_layer"]]
+
+    def _encode_set(prompts):
+        n_feats = sae.n_features
+        rate = torch.zeros(n_feats, dtype=torch.float32)
+        acts = torch.zeros(n_feats, dtype=torch.float32)
+        for p in prompts:
+            inputs = tokenizer(p, return_tensors="pt").to(model.device)
+            with torch.no_grad(), _capture_at(layer_module) as bucket:
+                model(**inputs)
+            h_last = bucket["h"][0, -1].unsqueeze(0)
+            z = sae.encode(h_last)[0].detach().to("cpu", torch.float32)
+            rate += (z != 0).float()
+            acts += z
+        if prompts:
+            rate /= len(prompts)
+            acts /= len(prompts)
+        return rate, acts
+
+    rate_a, acts_a = _encode_set(req.prompts_a)
+    rate_b, acts_b = _encode_set(req.prompts_b)
+    diff = (rate_a - rate_b).abs()
+    top_vals, top_idx = diff.topk(min(int(req.top_n), sae.n_features))
+    rows = []
+    for v, fid in zip(top_vals.tolist(), top_idx.tolist()):
+        rows.append({
+            "id": int(fid),
+            "diff": float(v),
+            "rate_a": float(rate_a[fid]),
+            "rate_b": float(rate_b[fid]),
+            "act_a": float(acts_a[fid]),
+            "act_b": float(acts_b[fid]),
+            "winner": "a" if rate_a[fid] >= rate_b[fid] else "b",
+        })
+    return {"top_diff": rows, "n_a": len(req.prompts_a), "n_b": len(req.prompts_b),
+            "n_features": sae.n_features}
+
+
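The ranking step above is just an argsort over absolute rate differences. A list-based sketch with invented firing rates (four toy features, no tensors):

```python
# Toy per-feature firing rates for prompt sets A and B.
rate_a = [0.9, 0.1, 0.5, 0.0]
rate_b = [0.2, 0.1, 0.9, 0.6]

diff = [abs(a - b) for a, b in zip(rate_a, rate_b)]
# Feature ids ranked by |rate_a - rate_b|, descending.
ranked = sorted(range(len(diff)), key=lambda i: -diff[i])
rows = [
    {"id": i, "diff": diff[i],
     "winner": "a" if rate_a[i] >= rate_b[i] else "b"}
    for i in ranked
]

print([r["id"] for r in rows])  # [0, 3, 2, 1]
```

Feature 0 tops the list because it fires in 90% of set A but only 20% of set B; feature 1 fires equally in both, so it ranks last.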
+class SynthRequest(BaseModel):
+    seed_prompts: list[str]
+    steering: list[SteerSpec] = []
+    max_new_tokens: int = 40
+
+
+@app.post("/synth_batch")
+def synth_batch(req: SynthRequest):
+    """Bulk steered synthesis: run steered generation over N seed prompts.
+
+    Useful for the data-centric synthesis workflow: produce K examples
+    that fire feature F at strength α, for downstream training data.
+    """
+    if not req.seed_prompts:
+        return {"results": []}
+    model, tokenizer, sae = state["model"], state["tokenizer"], state["sae"]
+    layer_module = state["layers"][state["current_layer"]]
+    results = []
+    for seed in req.seed_prompts:
+        inputs = tokenizer(seed, return_tensors="pt").to(model.device)
+        with _hook_stack(layer_module, sae, req.steering):
+            with torch.no_grad():
+                out = model.generate(
+                    **inputs,
+                    max_new_tokens=req.max_new_tokens,
+                    do_sample=False,
+                    pad_token_id=tokenizer.eos_token_id,
+                )
+        text = tokenizer.decode(out[0], skip_special_tokens=True)
+        results.append({"seed": seed, "text": text})
+    return {"results": results}
+
+
+def _extract_per_token_probs(gen_out, prompt_len, tokenizer, top_k):
+    """Build per-step probabilities + top-K candidate strings."""
+    new_ids = gen_out.sequences[0][prompt_len:].tolist()
+    if not new_ids:
+        return []
+    rows = []
+    for step, score_t in enumerate(gen_out.scores):
+        probs = torch.softmax(score_t[0].float(), dim=-1)
+        chosen_id = new_ids[step]
+        chosen_prob = float(probs[chosen_id].item())
+        top_vals, top_ids = probs.topk(min(top_k, probs.shape[0]))
+        top_ids_list = top_ids.tolist()
+        # Decode the chosen token once and the top-K candidates in a single
+        # batch_decode call to limit tokenizer overhead
+        decoded_chosen = tokenizer.decode([chosen_id], skip_special_tokens=False)
+        decoded_top = tokenizer.batch_decode([[t] for t in top_ids_list], skip_special_tokens=False)
+        topk = []
+        for tid, tv, ts in zip(top_ids_list, top_vals.tolist(), decoded_top):
+            topk.append({"tok": ts, "prob": float(tv), "is_chosen": tid == chosen_id})
+        # If the chosen token wasn't in the top-K, append it explicitly
+        if chosen_id not in top_ids_list:
+            topk.append({"tok": decoded_chosen, "prob": chosen_prob, "is_chosen": True})
+        rows.append({"tok": decoded_chosen, "prob": chosen_prob, "topk": topk})
+    return rows
+
+
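The chosen-token fallback above (appending the sampled token when it falls outside the top-K) can be demonstrated without a model. The probabilities here are made up, and token ids stand in for decoded strings:

```python
def topk_with_chosen(probs, chosen_id, k):
    # Keep the top-k candidates, but always include the chosen token
    # so the UI can highlight it even when it ranked below k.
    order = sorted(range(len(probs)), key=lambda i: -probs[i])[:k]
    rows = [{"id": i, "prob": probs[i], "is_chosen": i == chosen_id}
            for i in order]
    if chosen_id not in order:
        rows.append({"id": chosen_id, "prob": probs[chosen_id], "is_chosen": True})
    return rows

# Chosen token 3 is outside the top-2, so it gets appended explicitly.
rows = topk_with_chosen([0.5, 0.3, 0.15, 0.05], chosen_id=3, k=2)
print([r["id"] for r in rows])  # [0, 1, 3]
```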
+@app.post("/generate")
+def generate(req: GenerateRequest):
+    model, tokenizer, sae = state["model"], state["tokenizer"], state["sae"]
+    layer_module = state["layers"][state["current_layer"]]
+    inputs = tokenizer(req.prompt, return_tensors="pt").to(model.device)
+    prompt_len = int(inputs["input_ids"].shape[1])
+
+    base_acts = {}
+    if req.steering:
+        with torch.no_grad(), _capture_at(layer_module) as bucket:
+            model(**inputs)
+        z_base = sae.encode(bucket["h"][0, -1].unsqueeze(0))[0]
+        for s in req.steering:
+            base_acts[s.id] = float(z_base[s.id].item())
+
+    gen_kwargs = dict(
+        **inputs,
+        max_new_tokens=req.max_new_tokens,
+        do_sample=False,
+        pad_token_id=tokenizer.eos_token_id,
+    )
+    if req.return_probs:
+        gen_kwargs["return_dict_in_generate"] = True
+        gen_kwargs["output_scores"] = True
+
+    with _hook_stack(layer_module, sae, req.steering, prompt_len=prompt_len):
+        with torch.no_grad():
+            out = model.generate(**gen_kwargs)
+    seq = out.sequences[0] if req.return_probs else out[0]
+    text = tokenizer.decode(seq, skip_special_tokens=True)
+    per_token_probs = (_extract_per_token_probs(out, prompt_len, tokenizer, req.topk_display)
+                       if req.return_probs else None)
+
+ if req.steering:
794
+ with torch.no_grad(), _capture_at(layer_module) as bucket:
795
+ model(**inputs)
796
+ z_steered = sae.encode(bucket["h"][0, -1].unsqueeze(0))[0]
797
+ for s in req.steering:
798
+ steered_acts[s.id] = float(z_steered[s.id].item())
799
+
800
+ verifier = [
801
+ {"id": s.id, "alpha": s.alpha,
802
+ "positions": s.positions,
803
+ "output_only": s.output_only,
804
+ "base": base_acts.get(s.id, 0.0),
805
+ "steered": steered_acts.get(s.id, 0.0)}
806
+ for s in req.steering
807
+ ]
808
+ resp = {"text": text, "verifier": verifier, "prompt_len": prompt_len}
809
+ if per_token_probs is not None:
810
+ resp["tokens"] = per_token_probs
811
+ return resp
812
+
813
+
814
+if __name__ == "__main__":
+    import uvicorn
+
+    port = int(os.environ.get("PORT", 7860))
+    uvicorn.run(app, host="0.0.0.0", port=port)