initial qwen-scope-live deploy

Files changed:

- Dockerfile (+40, -0)
- README.md (+51, -4)
- __pycache__/qwen_scope_steer.cpython-314.pyc (binary)
- __pycache__/server.cpython-314.pyc (binary)
- index.html (+1564, -0)
- qwen_scope_obs.py (+129, -0)
- qwen_scope_steer.py (+279, -0)
- requirements.txt (+10, -0)
- server.py (+817, -0)
Dockerfile
ADDED
@@ -0,0 +1,40 @@

```dockerfile
# Qwen-Scope Live SAE Feature Steering — HF Space Docker SDK image.
# Free tier (CPU, ~16GB RAM) — locked to Qwen3-1.7B-Base only.
FROM python:3.11-slim

# Avoid interactive prompts during installs
ENV DEBIAN_FRONTEND=noninteractive \
    PYTHONUNBUFFERED=1 \
    PYTHONDONTWRITEBYTECODE=1 \
    PIP_NO_CACHE_DIR=1 \
    HF_HUB_DISABLE_TELEMETRY=1 \
    TRANSFORMERS_NO_ADVISORY_WARNINGS=1

# HF Spaces convention: /home/user is writable, expects PORT=7860
ENV HOME=/home/user \
    PATH=/home/user/.local/bin:$PATH \
    HF_HOME=/home/user/.cache/huggingface \
    PORT=7860

# Add a non-root user (HF Spaces best practice)
RUN useradd -m -u 1000 user && \
    apt-get update && apt-get install -y --no-install-recommends \
    build-essential ca-certificates && \
    rm -rf /var/lib/apt/lists/*

USER user
WORKDIR $HOME/app

# Install Python deps (CPU-only torch from PyTorch index)
COPY --chown=user:user requirements.txt .
RUN pip install --user --no-cache-dir -r requirements.txt

# Copy application files
COPY --chown=user:user qwen_scope_steer.py qwen_scope_obs.py server.py index.html ./

# Pre-create cache dir (writable for the cached SAE positions JSON)
RUN mkdir -p $HOME/app/feature_positions

EXPOSE 7860

CMD ["python", "server.py"]
```
README.md
CHANGED
@@ -1,10 +1,57 @@

---
title: "Qwen-Scope: Decoding Intelligence, Unleashing Potential"
emoji: 🔬
colorFrom: blue
colorTo: pink
sdk: docker
app_port: 7860
pinned: false
license: apache-2.0
short_description: Live SAE feature steering for Qwen3-1.7B
---

# Qwen-Scope: Decoding Intelligence, Unleashing Potential

Live, interactive demo of the [Qwen-Scope](https://huggingface.co/Qwen/SAE-Res-Qwen3-1.7B-Base-W32K-L0_50) sparse-autoencoder release. Type a prompt, see which SAE features fire, drag a slider to steer them, and watch the model's actual generated text change in real time — all served live from `Qwen/Qwen3-1.7B-Base` plus the Qwen-Scope SAE (layer 14 by default; any layer 0–27 is selectable).

This Space implements **three of the four pillars** from the Qwen-Scope paper:

- **Steering** — interpretable feature-level control at inference time (per-token heatmap, position-selective steering, output-only steering, per-token probability panel with click-to-pin top-K candidates)
- **Evaluation** — capability-aware corpus analysis (paste a prompt set, encode each prompt, see which features fire across the corpus; compare two prompt sets to find distinguishing features)
- **Data-centric** — feature-guided curation (filter docs by feature signature) and steered synthesis (bulk-generate steered completions from seed prompts)

(The fourth pillar — post-training pathology diagnosis — is not yet implemented.)
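The "L0_50" in the SAE repo name indicates a TopK autoencoder: at most 50 features are nonzero per token. A minimal sketch of that encode step — the tensor names, shapes, and exact activation rule here are illustrative assumptions, not the release's actual API:

```python
import torch

def topk_sae_encode(resid: torch.Tensor, W_enc: torch.Tensor,
                    b_enc: torch.Tensor, k: int = 50) -> torch.Tensor:
    """Keep only the k largest pre-activations per token; zero the rest."""
    pre = resid @ W_enc + b_enc                   # (..., n_features) pre-activations
    vals, idx = pre.topk(k, dim=-1)               # k strongest candidates per token
    acts = torch.zeros_like(pre)
    acts.scatter_(-1, idx, vals.clamp_min(0.0))   # ReLU the survivors
    return acts

# Illustrative sizes: d_model=2048, 32K features (the "W32K" in the repo name)
acts = topk_sae_encode(torch.randn(4, 2048), torch.randn(2048, 32768),
                       torch.zeros(32768))
```

Every row of `acts` then has at most 50 nonzero entries, which is why the UI caps the "top K" request around 50.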
## How to use

1. **Wait ~3–5 minutes** on first load — the container downloads the model (~3.4GB) and SAE (~537MB) on a cold start.
2. **Steering tab**: type a prompt, click *Encode*, click *steer* on any top feature, drag the slider, click *Generate*. The output panels show baseline vs. steered side-by-side as colored token chips — click any chip to see its top-K next-token candidates. A per-token feature heatmap in the right column shows which features fire on which tokens.
3. **Evaluation tab**: paste many prompts (one per line), encode the corpus, and inspect per-sample top features, the signature heatmap, and a Compare panel for finding features that distinguish set A from set B.
4. **Data-centric tab**: paste a corpus, filter it by feature signature (include/exclude docs where feature X fires), and run bulk steered synthesis to produce N targeted training examples.
5. **Layer selector** in the header: change the SAE layer (0–27) to see how features differ across model depth. The first swap to a layer takes ~20s; subsequent swaps are <0.1s thanks to an in-memory LRU cache.
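The sub-0.1s layer swaps come from keeping recently used SAEs resident in memory. The idea can be sketched with `functools.lru_cache` — the loader name and return type are hypothetical; the server may implement its cache differently:

```python
from functools import lru_cache

@lru_cache(maxsize=8)          # keep the most recently used SAE layers resident
def load_sae(layer: int) -> dict:
    # Hypothetical stand-in for the real ~20s load from disk/Hub.
    return {"layer": layer, "weights": f"sae-L{layer}"}

first = load_sae(14)    # slow path: actually loads
second = load_sae(14)   # fast path: same object, straight from the cache
```

Evicting least-recently-used layers bounds memory while keeping the common back-and-forth comparison of two or three layers instant.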
## Hardware constraint

This Space is **locked to Qwen3-1.7B-Base** because that is what fits in a free-tier HF Space (CPU, 16GB RAM). The full local version supports the entire Qwen-Scope catalog — Qwen3-8B, Qwen3.5-9B/27B, Qwen3.6-27B/35B-A3B, and Qwen3-30B-A3B, each with its matching SAE — but those need GPU hardware and persistent storage. See the source repository for instructions.
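The download sizes quoted above are consistent with a bf16 model checkpoint and an fp32 SAE, assuming d_model 2048 and 32K features — those two dimensions are an inference from the model and repo names, not stated in this README:

```python
model_bytes = 1.7e9 * 2              # 1.7B params, 2 bytes each (bf16 checkpoint)
sae_bytes = 2 * 2048 * 32768 * 4     # encoder + decoder matrices, fp32 (assumed dims)

print(round(model_bytes / 1e9, 1))   # -> 3.4  (GB, matches "~3.4GB")
print(round(sae_bytes / 1e6))        # -> 537  (MB, matches "~537MB")
# Held in RAM as fp32, the model weights alone are ~6.8GB — inside the 16GB tier,
# but large enough that bigger catalog models cannot fit.
```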
## What's verified

This deployment is the same code that was [verified end-to-end on macOS / Apple Silicon MPS](https://huggingface.co/Qwen/SAE-Res-Qwen3-1.7B-Base-W32K-L0_50) — TopK SAE encode, residual hook, additive steering, per-token probability extraction, position-selective hooks across prefill + decode steps. The math is identical; only the device (CPU instead of MPS) and dtype (fp32 instead of bf16) differ in the deployment build.
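Additive steering itself is a one-line intervention on the residual stream: add `alpha` times a unit-norm feature direction at the hooked layer. A toy sketch using a PyTorch forward hook — the module, sizes, and random direction here are illustrative stand-ins, not the Space's actual code, where the direction would be an SAE decoder row and the hook sits on a Qwen3 decoder block:

```python
import torch

torch.manual_seed(0)
d_model = 8
direction = torch.randn(d_model)
direction = direction / direction.norm()   # unit-norm feature direction
alpha = 5.0                                # steering strength (the UI slider)

def steering_hook(module, inputs, output):
    # Transformer blocks often return tuples; handle both shapes.
    hidden = output[0] if isinstance(output, tuple) else output
    hidden = hidden + alpha * direction.to(hidden.dtype)
    return (hidden, *output[1:]) if isinstance(output, tuple) else hidden

layer = torch.nn.Linear(d_model, d_model)  # stand-in for a decoder block
x = torch.zeros(1, 3, d_model)             # (batch, seq, d_model)
baseline = layer(x)
handle = layer.register_forward_hook(steering_hook)
steered = layer(x)
handle.remove()
# steered - baseline == alpha * direction at every token position
```

Position-selective steering then amounts to masking which sequence positions receive the addition inside the hook.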
## License

Apache 2.0 (this code). The Qwen-Scope SAE weights and Qwen3 model weights are governed by their respective Qwen licenses.

## Citation

If you use this in your work, cite the Qwen-Scope paper:

```bibtex
@misc{qwen_scope,
  title  = {{Qwen-Scope}: Turning Sparse Features into Development Tools for Large Language Models},
  url    = {https://qianwen-res.oss-accelerate.aliyuncs.com/qwen-scope/Qwen_Scope.pdf},
  author = {{Qwen Team}},
  month  = {April},
  year   = {2026}
}
```
__pycache__/qwen_scope_steer.cpython-314.pyc
ADDED
Binary file (19.9 kB).

__pycache__/server.cpython-314.pyc
ADDED
Binary file (46.5 kB).
index.html
ADDED
|
@@ -0,0 +1,1564 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!doctype html>
|
| 2 |
+
<html lang="en">
|
| 3 |
+
<head>
|
| 4 |
+
<meta charset="utf-8" />
|
| 5 |
+
<meta name="viewport" content="width=device-width, initial-scale=1" />
|
| 6 |
+
<title>Qwen-Scope · Live SAE Feature Steering</title>
|
| 7 |
+
<script src="https://cdnjs.cloudflare.com/ajax/libs/three.js/r128/three.min.js"></script>
|
| 8 |
+
<style>
|
| 9 |
+
:root {
|
| 10 |
+
--bg: #08080d;
|
| 11 |
+
--bg-panel: rgba(18, 18, 26, 0.78);
|
| 12 |
+
--bg-panel-strong: rgba(24, 24, 34, 0.92);
|
| 13 |
+
--border: rgba(255,255,255,0.08);
|
| 14 |
+
--border-strong: rgba(255,255,255,0.16);
|
| 15 |
+
--fg: #ececf1;
|
| 16 |
+
--fg-dim: #9b9ba8;
|
| 17 |
+
--fg-faint: #5e5e6e;
|
| 18 |
+
--accent: #7df9ff;
|
| 19 |
+
--accent-2: #ff7df9;
|
| 20 |
+
--warn: #ffb84d;
|
| 21 |
+
--good: #74e2a3;
|
| 22 |
+
--bad: #ff7d7d;
|
| 23 |
+
--mono: ui-monospace, "SF Mono", "JetBrains Mono", Menlo, monospace;
|
| 24 |
+
--sans: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
|
| 25 |
+
}
|
| 26 |
+
* { box-sizing: border-box; }
|
| 27 |
+
html, body { margin:0; padding:0; height:100%; background:var(--bg); color:var(--fg); font-family:var(--sans); overflow:hidden; }
|
| 28 |
+
#scene { position:fixed; inset:0; z-index:0; }
|
| 29 |
+
#ui { position:fixed; inset:0; z-index:10; pointer-events:none; display:grid;
|
| 30 |
+
grid-template-columns: 380px 1fr 420px;
|
| 31 |
+
grid-template-rows: 1fr auto;
|
| 32 |
+
gap:14px; padding:14px; }
|
| 33 |
+
#ui > section { pointer-events:auto; }
|
| 34 |
+
|
| 35 |
+
.panel { background:var(--bg-panel); border:1px solid var(--border);
|
| 36 |
+
border-radius:14px; padding:14px; backdrop-filter: blur(14px) saturate(140%);
|
| 37 |
+
-webkit-backdrop-filter: blur(14px) saturate(140%); overflow:hidden;
|
| 38 |
+
transition: padding .15s ease, max-height .25s ease; }
|
| 39 |
+
.panel h2 { margin:0 0 10px; font-size:11px; font-weight:600;
|
| 40 |
+
text-transform:uppercase; letter-spacing:0.14em; color:var(--fg-dim);
|
| 41 |
+
display:flex; justify-content:space-between; align-items:center; gap:8px; }
|
| 42 |
+
.panel h3 { margin:0 0 6px; font-size:13px; font-weight:600; color:var(--fg); }
|
| 43 |
+
|
| 44 |
+
/* Minimize button on each panel header */
|
| 45 |
+
.min-btn { background:transparent; border:1px solid var(--border-strong);
|
| 46 |
+
color:var(--fg-dim); border-radius:5px; padding:1px 8px; font-size:14px;
|
| 47 |
+
font-family:var(--mono); cursor:pointer; line-height:1; min-width:0;
|
| 48 |
+
font-weight:400; letter-spacing:0; }
|
| 49 |
+
.min-btn:hover { background:var(--bg-panel-strong); color:var(--fg);
|
| 50 |
+
border-color:var(--accent); transform:none; box-shadow:none; }
|
| 51 |
+
.panel.collapsed { padding:8px 14px; }
|
| 52 |
+
.panel.collapsed > h2 { margin-bottom:0; }
|
| 53 |
+
.panel.collapsed > :not(h2) { display:none; }
|
| 54 |
+
|
| 55 |
+
/* Header strip */
|
| 56 |
+
#header { position:fixed; top:14px; left:50%; transform:translateX(-50%);
|
| 57 |
+
z-index:20; pointer-events:auto;
|
| 58 |
+
background:var(--bg-panel-strong); border:1px solid var(--border-strong);
|
| 59 |
+
border-radius:999px; padding:8px 18px; display:flex; gap:14px;
|
| 60 |
+
align-items:center; font-family:var(--mono); font-size:12px; color:var(--fg-dim);
|
| 61 |
+
backdrop-filter: blur(20px); }
|
| 62 |
+
#header .dot { width:8px; height:8px; border-radius:50%; background:var(--bad); transition:background .3s; }
|
| 63 |
+
#header.live .dot { background:var(--good); box-shadow:0 0 12px var(--good); }
|
| 64 |
+
#header.loading .dot { background:var(--warn); box-shadow:0 0 12px var(--warn); animation:pulse 1.5s ease-in-out infinite; }
|
| 65 |
+
@keyframes pulse { 0%,100% { opacity:1; } 50% { opacity:0.4; } }
|
| 66 |
+
#header b { color:var(--fg); font-weight:600; }
|
| 67 |
+
#model-select {
|
| 68 |
+
background:transparent; color:var(--fg); border:1px solid var(--border-strong);
|
| 69 |
+
border-radius:6px; padding:3px 8px; font-family:var(--mono); font-size:11px;
|
| 70 |
+
outline:none; cursor:pointer; max-width:280px;
|
| 71 |
+
}
|
| 72 |
+
#model-select:focus { border-color:var(--accent); }
|
| 73 |
+
#model-select option { background:#16161e; color:var(--fg); }
|
| 74 |
+
|
| 75 |
+
/* Tab strip */
|
| 76 |
+
#tabs { position:fixed; top:62px; left:50%; transform:translateX(-50%); z-index:18;
|
| 77 |
+
display:flex; gap:6px; pointer-events:auto; background:var(--bg-panel-strong);
|
| 78 |
+
border:1px solid var(--border-strong); border-radius:999px;
|
| 79 |
+
padding:4px; backdrop-filter: blur(20px); }
|
| 80 |
+
.tab { background:transparent; border:none; color:var(--fg-dim);
|
| 81 |
+
padding:6px 16px; font-size:11px; font-weight:600;
|
| 82 |
+
text-transform:uppercase; letter-spacing:0.08em; border-radius:999px;
|
| 83 |
+
cursor:pointer; transition: all 0.15s ease; }
|
| 84 |
+
.tab:hover { color:var(--fg); background:rgba(255,255,255,0.04); }
|
| 85 |
+
.tab.active { background: linear-gradient(135deg, rgba(125,249,255,0.18), rgba(255,125,249,0.12));
|
| 86 |
+
color:var(--fg); border:1px solid var(--border-strong); }
|
| 87 |
+
|
| 88 |
+
/* Tab visibility — hide panels not matching the current tab */
|
| 89 |
+
body[data-tab="steering"] [data-show-tab]:not([data-show-tab~="steering"]) { display:none !important; }
|
| 90 |
+
body[data-tab="evaluation"] [data-show-tab]:not([data-show-tab~="evaluation"]) { display:none !important; }
|
| 91 |
+
body[data-tab="datacentric"] [data-show-tab]:not([data-show-tab~="datacentric"]) { display:none !important; }
|
| 92 |
+
|
| 93 |
+
/* Loading overlay shown during model swap */
|
| 94 |
+
#loading-overlay {
|
| 95 |
+
position:fixed; inset:0; z-index:50; display:none;
|
| 96 |
+
background:rgba(8,8,13,0.78); backdrop-filter: blur(6px);
|
| 97 |
+
align-items:center; justify-content:center;
|
| 98 |
+
}
|
| 99 |
+
#loading-overlay.visible { display:flex; }
|
| 100 |
+
#loading-overlay .lo-card {
|
| 101 |
+
background:var(--bg-panel-strong); border:1px solid var(--border-strong);
|
| 102 |
+
border-radius:14px; padding:24px 32px; text-align:center; min-width:300px;
|
| 103 |
+
}
|
| 104 |
+
|
| 105 |
+
/* Left: prompt + controls (top-left of screen, below the header) */
|
| 106 |
+
#left { display:flex; flex-direction:column; gap:14px; min-height:0;
|
| 107 |
+
grid-column: 1; grid-row: 1; align-self: start;
|
| 108 |
+
padding-top: 110px; }
|
| 109 |
+
textarea, input[type="text"], input[type="number"] {
|
| 110 |
+
width:100%; background:rgba(0,0,0,0.35); border:1px solid var(--border-strong);
|
| 111 |
+
color:var(--fg); border-radius:8px; padding:10px 12px; font-family:var(--mono);
|
| 112 |
+
font-size:13px; outline:none; resize:vertical;
|
| 113 |
+
}
|
| 114 |
+
textarea:focus, input:focus { border-color:var(--accent); box-shadow:0 0 0 3px rgba(125,249,255,0.12); }
|
| 115 |
+
textarea { min-height:72px; }
|
| 116 |
+
.row { display:flex; gap:8px; align-items:center; }
|
| 117 |
+
.row > * { flex:1; }
|
| 118 |
+
.row > .grow0 { flex:0; }
|
| 119 |
+
|
| 120 |
+
button { background: linear-gradient(135deg, rgba(125,249,255,0.18), rgba(255,125,249,0.12));
|
| 121 |
+
border:1px solid var(--border-strong); color:var(--fg); border-radius:8px;
|
| 122 |
+
padding:9px 14px; font-family:var(--sans); font-size:13px; font-weight:600;
|
| 123 |
+
cursor:pointer; transition:all .15s; letter-spacing:0.02em; }
|
| 124 |
+
button:hover:not(:disabled) { border-color:var(--accent); transform:translateY(-1px); box-shadow:0 4px 14px rgba(125,249,255,0.18); }
|
| 125 |
+
button:disabled { opacity:0.4; cursor:not-allowed; }
|
| 126 |
+
button.primary { background: linear-gradient(135deg, #7df9ff, #ff7df9); color:#0a0a14; border-color:transparent; }
|
| 127 |
+
button.primary:hover:not(:disabled) { box-shadow:0 4px 18px rgba(255,125,249,0.32); }
|
| 128 |
+
button.ghost { background:transparent; }
|
| 129 |
+
|
| 130 |
+
/* Middle: nothing (3D scene shows through) */
|
| 131 |
+
#middle { grid-column: 2; grid-row: 1; pointer-events:none; }
|
| 132 |
+
|
| 133 |
+
/* Bottom: output band spans all columns; sits below cloud area */
|
| 134 |
+
#bottom { grid-column: 1 / -1; grid-row: 2; pointer-events:auto;
|
| 135 |
+
display:grid; grid-template-columns: minmax(0, 1fr) minmax(0, 1fr);
|
| 136 |
+
gap:14px;
|
| 137 |
+
max-height: 38vh; min-height: 0;
|
| 138 |
+
margin: 0 0 0 0; }
|
| 139 |
+
#bottom .panel { display:flex; flex-direction:column;
|
| 140 |
+
min-width: 0; min-height: 0; overflow:hidden; }
|
| 141 |
+
#bottom .panel.collapsed { padding:8px 14px; }
|
| 142 |
+
#bottom .panel.collapsed > .panel-body { display:none; }
|
| 143 |
+
.panel-body { flex: 1 1 auto; overflow:auto; min-height:0; min-width:0; }
|
| 144 |
+
#right .panel { min-width:0; max-width:100%; }
|
| 145 |
+
#right .panel > div { max-width:100%; overflow:auto; }
|
| 146 |
+
|
| 147 |
+
/* Right: top features panel only — output moved to bottom band */
|
| 148 |
+
#right { display:flex; flex-direction:column; gap:14px; min-height:0;
|
| 149 |
+
grid-column: 3; grid-row: 1; max-height: calc(100vh - 28px); padding-top:110px; }
|
| 150 |
+
|
| 151 |
+
#features { flex: 1 1 auto; min-height: 120px; overflow-y:auto; }
|
| 152 |
+
#features::-webkit-scrollbar { width:6px; }
|
| 153 |
+
#features::-webkit-scrollbar-thumb { background:var(--border-strong); border-radius:3px; }
|
| 154 |
+
|
| 155 |
+
.feat { padding:10px; border:1px solid var(--border); border-radius:10px;
|
| 156 |
+
margin-bottom:8px; background:rgba(0,0,0,0.25); transition: border-color .2s, background .2s; }
|
| 157 |
+
.feat:hover { border-color:var(--accent); }
|
| 158 |
+
.feat.steered { border-color:var(--accent-2); background:rgba(255,125,249,0.06); }
|
| 159 |
+
.feat-head { display:flex; justify-content:space-between; align-items:center; margin-bottom:6px; }
|
| 160 |
+
.feat-id { font-family:var(--mono); font-size:12px; color:var(--fg); }
|
| 161 |
+
.feat-act { font-family:var(--mono); font-size:11px; color:var(--good); }
|
| 162 |
+
.feat-act.steered { color:var(--accent-2); }
|
| 163 |
+
.slider-row { display:flex; align-items:center; gap:8px; margin-top:4px; }
|
| 164 |
+
.slider-row input[type=range] { flex:1; -webkit-appearance:none; height:4px; background:var(--border-strong); border-radius:2px; outline:none; }
|
| 165 |
+
.slider-row input[type=range]::-webkit-slider-thumb {
|
| 166 |
+
-webkit-appearance:none; width:14px; height:14px; border-radius:50%;
|
| 167 |
+
background: linear-gradient(135deg, #7df9ff, #ff7df9); cursor:pointer;
|
| 168 |
+
box-shadow:0 0 8px rgba(125,249,255,0.5); border:none;
|
| 169 |
+
}
|
| 170 |
+
.slider-row input[type=range]::-moz-range-thumb {
|
| 171 |
+
width:14px; height:14px; border-radius:50%;
|
| 172 |
+
background: linear-gradient(135deg, #7df9ff, #ff7df9); cursor:pointer; border:none;
|
| 173 |
+
}
|
| 174 |
+
.slider-row .alpha-val { font-family:var(--mono); font-size:11px; color:var(--fg-dim); width:48px; text-align:right; }
|
| 175 |
+
.feat-tools { display:flex; gap:6px; }
|
| 176 |
+
.feat-tools button { padding:3px 8px; font-size:10px; font-weight:500; }
|
| 177 |
+
|
| 178 |
+
/* Output area */
|
| 179 |
+
#output-wrap { flex:1; overflow-y:auto; }
|
| 180 |
+
#output-wrap::-webkit-scrollbar { width:6px; }
|
| 181 |
+
#output-wrap::-webkit-scrollbar-thumb { background:var(--border-strong); border-radius:3px; }
|
| 182 |
+
.out-block { font-family:var(--mono); font-size:12px; line-height:1.55;
|
| 183 |
+
background:rgba(0,0,0,0.4); border:1px solid var(--border);
|
| 184 |
+
border-radius:8px; padding:10px; margin-bottom:8px; white-space:pre-wrap;
|
| 185 |
+
word-break: break-word; color:var(--fg); min-height:36px; }
|
| 186 |
+
.out-block.baseline { border-color:rgba(125,249,255,0.3); }
|
| 187 |
+
.out-block.steered { border-color:rgba(255,125,249,0.3); }
|
| 188 |
+
.out-label { display:flex; justify-content:space-between; align-items:center; margin-bottom:4px;
|
| 189 |
+
font-size:10px; text-transform:uppercase; letter-spacing:0.12em; color:var(--fg-dim); }
|
| 190 |
+
.verifier { font-family:var(--mono); font-size:10px; color:var(--fg-dim); margin-top:6px; padding:6px 8px;
|
| 191 |
+
border-left:2px solid var(--border-strong); }
|
| 192 |
+
.verifier .delta-up { color:var(--good); }
|
| 193 |
+
.verifier .delta-down { color:var(--accent-2); }
|
| 194 |
+
|
| 195 |
+
/* Hover tooltip for selected feature on 3D cloud */
|
| 196 |
+
#tooltip { position:fixed; z-index:25; pointer-events:none;
|
| 197 |
+
background:var(--bg-panel-strong); border:1px solid var(--border-strong);
|
| 198 |
+
border-radius:8px; padding:8px 12px; font-family:var(--mono); font-size:11px;
|
| 199 |
+
line-height:1.45;
|
| 200 |
+
color:var(--fg); display:none; transform: translate(-50%, calc(-100% - 12px));
|
| 201 |
+
min-width:180px; max-width:280px; backdrop-filter: blur(20px); }
|
| 202 |
+
|
| 203 |
+
/* Loader */
|
| 204 |
+
.loader { display:inline-block; width:12px; height:12px; border:2px solid var(--border-strong);
|
| 205 |
+
border-top-color:var(--accent); border-radius:50%; animation:spin 0.8s linear infinite; }
|
| 206 |
+
@keyframes spin { to { transform: rotate(360deg); } }
|
| 207 |
+
|
| 208 |
+
.empty { color:var(--fg-faint); font-size:12px; text-align:center; padding:24px 8px; }
|
| 209 |
+
|
| 210 |
+
.hint { font-size:11px; color:var(--fg-faint); margin-top:6px; line-height:1.4; }
|
| 211 |
+
|
| 212 |
+
.footer { position:fixed; bottom:8px; left:14px; z-index:20; font-family:var(--mono);
|
| 213 |
+
font-size:10px; color:var(--fg-faint); letter-spacing:0.04em; }
|
| 214 |
+
|
| 215 |
+
.legend { font-family:var(--mono); font-size:10px; color:var(--fg-faint); display:flex; gap:12px; margin-top:6px; }
|
| 216 |
+
.legend span::before { content:''; display:inline-block; width:8px; height:8px; border-radius:50%; margin-right:5px; vertical-align:middle; }
|
| 217 |
+
.legend .top::before { background:var(--accent); box-shadow:0 0 6px var(--accent); }
|
| 218 |
+
.legend .pick::before { background:var(--accent-2); box-shadow:0 0 6px var(--accent-2); }
|
| 219 |
+
.legend .dim::before { background:var(--fg-faint); }
|
| 220 |
+
|
| 221 |
+
/* small-screen graceful fallback */
|
| 222 |
+
@media (max-width: 1100px) {
|
| 223 |
+
#ui { grid-template-columns: 1fr; grid-template-rows: auto auto; }
|
| 224 |
+
#left { grid-column:1; grid-row:1; max-height:none; }
|
| 225 |
+
#right { grid-column:1; grid-row:2; padding-top:14px; max-height:none; }
|
| 226 |
+
#middle { display:none; }
|
| 227 |
+
}
|
| 228 |
+
</style>
|
| 229 |
+
</head>
|
| 230 |
+
<body>
<canvas id="scene"></canvas>

<div id="header">
  <span class="dot"></span>
  <span>QWEN-SCOPE LIVE</span>
  <span style="opacity:.4">·</span>
  <select id="model-select" title="Switch the loaded model + SAE pair">
    <option>connecting…</option>
  </select>
  <span style="opacity:.4">·</span>
  <span id="hdr-layer-wrap" title="Click to change SAE layer">
    layer <input type="number" id="layer-input" value="?" min="0" max="0"
      style="width:48px; padding:1px 4px; background:transparent; color:var(--accent);
             border:1px solid var(--border-strong); border-radius:4px;
             font-family:var(--mono); font-size:12px; text-align:center;" />
    <span id="hdr-layer-meta">· — · —</span> <!-- placeholder; replaced with the real device/dtype from /health -->
  </span>
  <span style="opacity:.4">·</span>
  <span id="hdr-features">— features</span>
</div>

<div id="tabs">
  <button class="tab active" data-tab="steering">Steering</button>
  <button class="tab" data-tab="evaluation">Evaluation</button>
  <button class="tab" data-tab="datacentric">Data-centric</button>
</div>

<div id="loading-overlay">
  <div class="lo-card">
    <div class="loader" style="width:24px;height:24px;border-width:3px;"></div>
    <div id="lo-msg" style="margin-top:14px; font-size:13px; color:var(--fg);">Loading…</div>
    <div id="lo-detail" style="margin-top:6px; font-size:11px; color:var(--fg-faint); font-family:var(--mono);"></div>
  </div>
</div>

<div id="ui">
  <section id="left">
    <div class="panel" data-pid="prompt" data-show-tab="steering">
      <h2>1 · Prompt <button class="min-btn" data-min="prompt" title="Minimize panel">−</button></h2>
      <textarea id="prompt" placeholder="The capital of France is">The capital of France is</textarea>
      <div class="row" style="margin-top:8px;">
        <label style="font-size:11px; color:var(--fg-dim); flex:0 0 auto;">top K</label>
        <input type="number" id="top-k" value="20" min="1" max="500" style="flex:0 0 80px;" />
        <button class="primary" id="btn-encode" style="flex:1;">Encode &amp; show top features</button>
      </div>
      <div class="hint">
        Encodes the prompt's last-token residual at the SAE's layer and ranks the firing features.
        The TopK SAE keeps at most K=50 nonzero features per token, so requesting more than ~50 returns nothing extra.
      </div>
    </div>
    <div class="panel" data-pid="generate" data-show-tab="steering">
      <h2>2 · Generate <button class="min-btn" data-min="generate" title="Minimize panel">−</button></h2>
      <div class="row">
        <label style="font-size:11px; color:var(--fg-dim); flex:0 0 auto;">tokens</label>
        <input type="number" id="max-tokens" value="40" min="5" max="200" style="flex:0 0 80px;" />
        <button class="primary" id="btn-generate" disabled>Generate baseline + steered</button>
      </div>
      <div class="hint">
        Runs <code>model.generate</code> twice: once with no hooks, once with all active sliders applied
        as residual-stream additions <code>h ← h + α · W_dec[:, feat]</code>.
      </div>
    </div>
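<!-- Source note: the steering update described in the hint above happens
     server-side (qwen_scope_steer.py). A minimal, illustrative NumPy sketch
     of that update; the names resid and w_dec are assumptions, not the
     server's actual identifiers:

```python
import numpy as np

def apply_steering(resid, w_dec, feat, alpha):
    """Add alpha times decoder column `feat` to a residual vector.

    resid: (d_model,) residual-stream vector at the hooked layer.
    w_dec: (d_model, n_features) SAE decoder matrix.
    Implements the formula shown in the UI hint: h gets h + alpha * W_dec[:, feat].
    """
    return resid + alpha * w_dec[:, feat]
```

     During generation the server would apply this at every decode step; the
     baseline pass simply skips it. -->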
    <div class="panel" data-pid="viz" data-show-tab="steering evaluation datacentric" style="flex:0 0 auto;">
      <h2>3 · Visualization <button class="min-btn" data-min="viz" title="Minimize panel">−</button></h2>
      <div class="legend">
        <span class="dim">all 32K features</span>
        <span class="top">top firing</span>
        <span class="pick">steered</span>
      </div>
      <div class="hint">
        Each point is one SAE feature; positions are the top-3 PCA components of <code>W_enc</code>.
        Click a point to add it as a steering target.
      </div>
    </div>
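<!-- Source note: the 3D layout in the hint above is computed server-side.
     An illustrative NumPy sketch of "top-3 PCA of W_enc"; the
     (n_features, d_model) orientation and the function name are assumptions:

```python
import numpy as np

def pca_positions(w_enc, k=3):
    """Project each SAE feature (one row of w_enc) onto its top-k PCA axes."""
    x = w_enc - w_enc.mean(axis=0, keepdims=True)  # center the rows
    # Right singular vectors of the centered matrix are the principal axes.
    _, _, vt = np.linalg.svd(x, full_matrices=False)
    return x @ vt[:k].T  # (n_features, k) coordinates
```

     The client just rescales coordinates like these into the point cloud. -->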
  </section>

  <section id="middle"></section>

  <section id="right">
    <div class="panel" id="features-panel" data-pid="features" data-show-tab="steering"
         style="flex: 0 1 auto; max-height: 38vh; display:flex; flex-direction:column; min-height:0;">
      <h2>Top features <span id="feat-count" style="font-family:var(--mono); color:var(--fg-faint);"></span>
        <button class="min-btn" data-min="features" title="Minimize panel">−</button></h2>
      <div id="features" style="overflow-y:auto; flex:1 1 auto; min-height:0;"><div class="empty">Encode a prompt to populate features.</div></div>
    </div>

    <div class="panel" id="heatmap-panel" data-show-tab="steering"
         style="flex: 1 1 auto; display:flex; flex-direction:column; min-height:200px;">
      <h2>Per-token feature heatmap
        <span style="display:flex; gap:8px; align-items:center;">
          <label style="font-size:10px; color:var(--fg-faint);">
            <input type="checkbox" id="heatmap-skip-first" style="vertical-align:middle;" /> skip first
          </label>
          <button class="min-btn" data-min="heatmap-panel" title="Minimize panel">−</button>
        </span>
      </h2>
      <div id="heatmap-grid" style="overflow:auto; flex:1 1 auto; min-height:0;">
        <span class="empty">Encode a prompt — heatmap fills automatically.</span>
      </div>
    </div>

    <!-- Evaluation: corpus encoding + signature heatmap -->
    <div class="panel" data-show-tab="evaluation" id="eval-corpus">
      <h2>Evaluation · corpus encode
        <button class="min-btn" data-min="eval-corpus" title="Minimize panel">−</button>
      </h2>
      <textarea id="eval-prompts" rows="6" placeholder="One prompt per line, e.g.:&#10;The capital of France is&#10;Bonjour comment allez-vous&#10;def fibonacci(n):&#10;The mitochondria is the powerhouse"></textarea>
      <div class="row" style="margin-top:8px;">
        <button class="primary" id="btn-eval-encode" style="flex:1;">Encode corpus</button>
      </div>
      <div class="hint">
        Encodes each prompt's last-token residual through the SAE. Returns
        per-sample top features and corpus-level firing rates.
      </div>
    </div>

    <div class="panel" data-show-tab="evaluation" id="eval-corpus-features">
      <h2>Corpus features <span id="eval-stats" style="font-family:var(--mono); color:var(--fg-faint);"></span>
        <button class="min-btn" data-min="eval-corpus-features" title="Minimize panel">−</button>
      </h2>
      <div id="eval-features-list" style="overflow-y:auto; max-height:30vh;">
        <div class="empty">Encode a corpus to see feature firing rates.</div>
      </div>
    </div>

    <div class="panel" data-show-tab="evaluation" id="eval-compare">
      <h2>Compare two prompt sets
        <button class="min-btn" data-min="eval-compare" title="Minimize panel">−</button>
      </h2>
      <textarea id="cmp-a" rows="3" placeholder="Set A — one prompt per line"></textarea>
      <textarea id="cmp-b" rows="3" placeholder="Set B — one prompt per line" style="margin-top:6px;"></textarea>
      <div class="row" style="margin-top:6px;">
        <button class="primary" id="btn-cmp" style="flex:1;">Find distinguishing features</button>
      </div>
      <div class="hint">
        Encodes both sets and ranks features by |fire_rate(A) − fire_rate(B)|.
        Shows which features distinguish A from B.
      </div>
    </div>
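<!-- Source note: the ranking in the hint above, as an illustrative
     pure-Python sketch. The dict-of-activations shape and the function name
     are assumptions about what /encode returns, not the server's actual API:

```python
def distinguishing_features(acts_a, acts_b, top_n=5):
    """Rank features by |fire_rate(A) - fire_rate(B)| across two prompt sets.

    acts_a / acts_b: one {feature_id: activation} dict per prompt.
    """
    def fire_rates(acts):
        counts = {}
        for sample in acts:
            for fid, act in sample.items():
                if act > 0:  # "fired" means any nonzero activation
                    counts[fid] = counts.get(fid, 0) + 1
        return {fid: c / len(acts) for fid, c in counts.items()}

    ra, rb = fire_rates(acts_a), fire_rates(acts_b)
    feats = sorted(set(ra) | set(rb),
                   key=lambda f: abs(ra.get(f, 0.0) - rb.get(f, 0.0)),
                   reverse=True)
    return feats[:top_n]
```

     Features near the top fire often in one set and rarely in the other. -->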

    <!-- Data-centric: filter + steered synthesis -->
    <div class="panel" data-show-tab="datacentric" id="dc-corpus">
      <h2>Data-centric · corpus
        <button class="min-btn" data-min="dc-corpus" title="Minimize panel">−</button>
      </h2>
      <textarea id="dc-prompts" rows="5" placeholder="One prompt per line — same shape as Evaluation."></textarea>
      <div class="row" style="margin-top:8px;">
        <button class="primary" id="btn-dc-encode" style="flex:1;">Encode for filtering</button>
      </div>
      <div class="hint">
        Encode docs, then filter by feature signature.
      </div>
    </div>

    <div class="panel" data-show-tab="datacentric" id="dc-filter">
      <h2>Filter
        <button class="min-btn" data-min="dc-filter" title="Minimize panel">−</button>
      </h2>
      <div class="row">
        <label style="flex:0 0 auto; font-size:11px; color:var(--fg-dim);">feature</label>
        <input type="number" id="dc-filter-id" placeholder="feature id" min="0" max="32767" style="flex:1;" />
        <select id="dc-filter-mode" style="flex:0 0 auto; background:rgba(0,0,0,0.35); color:var(--fg); border:1px solid var(--border-strong); border-radius:6px; padding:6px 8px; font-family:var(--mono); font-size:11px;">
          <option value="include">includes</option>
          <option value="exclude">excludes</option>
        </select>
      </div>
      <div class="row" style="margin-top:8px;">
        <button id="btn-dc-filter" style="flex:1;">Apply filter</button>
        <button id="btn-dc-clear" class="ghost" style="flex:0 0 auto;">clear</button>
      </div>
      <div class="hint">
        "Includes" = keep only docs where this feature fired (any activation).
        "Excludes" = drop docs where it fired.
      </div>
    </div>
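<!-- Source note: the include/exclude semantics described in the Filter hint,
     as a tiny illustrative sketch; the doc/signature shapes and the function
     name are assumptions:

```python
def filter_docs(docs, fired, feature_id, mode="include"):
    """Keep or drop docs based on whether `feature_id` fired for them.

    docs: list of strings; fired: parallel list of sets of fired feature ids.
    mode "include" keeps docs where the feature fired; "exclude" drops them.
    """
    keep = [(feature_id in s) == (mode == "include") for s in fired]
    return [d for d, k in zip(docs, keep) if k]
```
-->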

    <div class="panel" data-show-tab="datacentric" id="dc-synth">
      <h2>Steered synthesis
        <button class="min-btn" data-min="dc-synth" title="Minimize panel">−</button>
      </h2>
      <div class="row">
        <label style="flex:0 0 auto; font-size:11px; color:var(--fg-dim);">feature</label>
        <input type="number" id="dc-synth-id" placeholder="feature id" min="0" max="32767" style="flex:1;" />
        <label style="flex:0 0 auto; font-size:11px; color:var(--fg-dim);">α</label>
        <input type="number" id="dc-synth-alpha" value="50" min="-200" max="200" style="flex:0 0 60px;" />
      </div>
      <div class="row" style="margin-top:6px;">
        <label style="flex:0 0 auto; font-size:11px; color:var(--fg-dim);">tokens</label>
        <input type="number" id="dc-synth-tokens" value="30" min="5" max="200" style="flex:0 0 70px;" />
        <button class="primary" id="btn-dc-synth" style="flex:1;">Synthesize from corpus</button>
      </div>
      <div class="hint">
        For each prompt in the corpus above, generate a steered completion with
        <code>h ← h + α · W_dec[:, feature]</code>. Useful for producing
        targeted training examples that fire a specific feature.
      </div>
    </div>
  </section>

  <!-- Bottom band -->
  <section id="bottom">
    <!-- STEERING bottom: baseline / steered / heatmap -->
    <div class="panel" data-show-tab="steering" data-pid="baseline">
      <h2>Baseline output (α=0)
        <span style="display:flex; gap:8px; align-items:center;">
          <span id="base-time" style="font-family:var(--mono); color:var(--fg-faint);"></span>
          <button class="min-btn" data-min="baseline" title="Minimize panel">−</button>
        </span>
      </h2>
      <div class="panel-body">
        <div class="out-block baseline" id="out-baseline"><span class="empty">(no run yet)</span></div>
      </div>
    </div>
    <div class="panel" data-show-tab="steering" data-pid="steered">
      <h2>Steered output
        <span style="display:flex; gap:8px; align-items:center;">
          <span id="steered-time" style="font-family:var(--mono); color:var(--fg-faint);"></span>
          <button class="min-btn" data-min="steered" title="Minimize panel">−</button>
        </span>
      </h2>
      <div class="panel-body">
        <div class="out-block steered" id="out-steered"><span class="empty">(no run yet)</span></div>
        <div class="verifier" id="verifier"></div>
      </div>
    </div>

    <!-- EVALUATION bottom: per-sample table + signature heatmap -->
    <div class="panel" data-show-tab="evaluation" id="eval-samples-panel" style="grid-column: 1 / 2;">
      <h2>Per-sample top features
        <button class="min-btn" data-min="eval-samples-panel" title="Minimize panel">−</button>
      </h2>
      <div class="panel-body">
        <div id="eval-samples"><span class="empty">Encode a corpus to populate.</span></div>
      </div>
    </div>
    <div class="panel" data-show-tab="evaluation" id="eval-heatmap-panel" style="grid-column: 2 / 3;">
      <h2>Signature heatmap / Compare results
        <span style="display:flex; gap:6px; align-items:center;">
          <button id="btn-eval-view-heatmap" class="ghost" style="padding:2px 8px; font-size:10px;">heatmap</button>
          <button id="btn-eval-view-compare" class="ghost" style="padding:2px 8px; font-size:10px;">compare</button>
          <button class="min-btn" data-min="eval-heatmap-panel" title="Minimize panel">−</button>
        </span>
      </h2>
      <div class="panel-body">
        <div id="eval-heatmap"><span class="empty">Encode a corpus to populate.</span></div>
        <div id="eval-compare-results" style="display:none;"><span class="empty">Run Compare to populate.</span></div>
      </div>
    </div>

    <!-- DATA-CENTRIC bottom: filtered list + synth output -->
    <div class="panel" data-show-tab="datacentric" id="dc-filtered-panel">
      <h2>Filtered docs <span id="dc-filter-stats" style="font-family:var(--mono); color:var(--fg-faint);"></span>
        <button class="min-btn" data-min="dc-filtered-panel" title="Minimize panel">−</button>
      </h2>
      <div class="panel-body">
        <div id="dc-filtered-list"><span class="empty">Encode + apply filter to populate.</span></div>
      </div>
    </div>
    <div class="panel" data-show-tab="datacentric" id="dc-synth-panel">
      <h2>Synthesis output <span id="dc-synth-stats" style="font-family:var(--mono); color:var(--fg-faint);"></span>
        <button class="min-btn" data-min="dc-synth-panel" title="Minimize panel">−</button>
      </h2>
      <div class="panel-body">
        <div id="dc-synth-list"><span class="empty">Run "Synthesize from corpus" to populate.</span></div>
      </div>
    </div>
  </section>

</div>

<div id="tooltip"></div>

<div class="footer">
  <span id="status">idle</span>
</div>

<script>
const API = ""; // same-origin: HF Space serves API + HTML from one process
const N_FEATURES = 32768;

// --- State ---
const state = {
  positions: null,      // Float32Array length N*3
  topFeatures: [],      // [{id, act}, ...] from /encode
  steering: new Map(),  // feature_id -> alpha (or an {alpha, ...} slot object)
  picked: null,         // currently hovered feature_id
};

// --- Three.js scene ---
const scene = new THREE.Scene();
scene.fog = new THREE.FogExp2(0x08080d, 0.08);
const camera = new THREE.PerspectiveCamera(55, window.innerWidth / window.innerHeight, 0.01, 60);
camera.position.set(0, 0, 3.4);
const renderer = new THREE.WebGLRenderer({
  canvas: document.getElementById("scene"),
  antialias: true,
  alpha: true,
});
renderer.setPixelRatio(Math.min(window.devicePixelRatio, 2));
renderer.setSize(window.innerWidth, window.innerHeight);
renderer.setClearColor(0x000000, 0);

// Subtle starry haze background — many tiny dim points
function addStarHaze() {
  const N = 600;
  const g = new THREE.BufferGeometry();
  const pos = new Float32Array(N*3);
  for (let i=0; i<N; i++) {
    const r = 18 + Math.random()*6;
    const theta = Math.random()*Math.PI*2;
    const phi = Math.acos(2*Math.random()-1);
    pos[i*3]   = r*Math.sin(phi)*Math.cos(theta);
    pos[i*3+1] = r*Math.sin(phi)*Math.sin(theta);
    pos[i*3+2] = r*Math.cos(phi);
  }
  g.setAttribute("position", new THREE.BufferAttribute(pos, 3));
  const m = new THREE.PointsMaterial({ color: 0x2a2a3a, size: 0.06, sizeAttenuation: true, transparent:true, opacity:0.5 });
  scene.add(new THREE.Points(g, m));
}
addStarHaze();

// Feature point cloud — built once positions arrive
let featurePoints = null, featureGeometry = null;
function buildFeatureCloud(positions) {
  featureGeometry = new THREE.BufferGeometry();
  const N = positions.length / 3;
  featureGeometry.setAttribute("position", new THREE.BufferAttribute(positions, 3));
  const colors = new Float32Array(N*3);
  const sizes = new Float32Array(N);
  for (let i=0; i<N; i++) {
    colors[i*3] = 1.0; colors[i*3+1] = 1.0; colors[i*3+2] = 1.0;
    sizes[i] = 0.5;
  }
  featureGeometry.setAttribute("color", new THREE.BufferAttribute(colors, 3));
  featureGeometry.setAttribute("size", new THREE.BufferAttribute(sizes, 1));

  const material = new THREE.ShaderMaterial({
    uniforms: { uPixelRatio: { value: renderer.getPixelRatio() } },
    vertexShader: `
      attribute float size;
      varying vec3 vColor;
      varying float vIsActive;
      varying float vSize;
      uniform float uPixelRatio;
      void main() {
        vColor = color;
        vIsActive = step(1.0, size); // base dots carry size 0.5; active features get >= 1.5
        vec4 mv = modelViewMatrix * vec4(position, 1.0);
        gl_Position = projectionMatrix * mv;
        // Base features: small fixed-size hard dots (~3 CSS px, i.e. 6 device px on retina).
        // Active features: size attribute -> px directly, with mild perspective
        // (closer = larger), capped so they never balloon out.
        float basePx = 3.0 * uPixelRatio;
        float scaledPx = size * 3.0 * (3.4 / -mv.z) * uPixelRatio;
        scaledPx = clamp(scaledPx, basePx, 26.0 * uPixelRatio);
        gl_PointSize = mix(basePx, scaledPx, vIsActive);
        vSize = gl_PointSize;
      }`,
    fragmentShader: `
      varying vec3 vColor;
      varying float vIsActive;
      varying float vSize;
      void main() {
        vec2 uv = gl_PointCoord - 0.5;
        float d = length(uv);
        if (d > 0.5) discard;
        // Anti-alias the rim with a fixed 1-pixel band, regardless of point size,
        // so even tiny points stay solid and crisp instead of going gaussian.
        float aa = 1.0 / max(vSize, 2.0);
        float core = 1.0 - smoothstep(0.5 - aa, 0.5, d);
        // Halo only on active features
        float glow = vIsActive * (1.0 - smoothstep(0.18, 0.5, d)) * 0.35;
        vec3 col = mix(vColor, vColor * 1.5, vIsActive);
        float alpha = clamp(core + glow, 0.0, 1.0);
        gl_FragColor = vec4(col, alpha);
      }`,
    transparent: true,
    depthWrite: false,
    vertexColors: true,
    blending: THREE.NormalBlending,
  });
  featurePoints = new THREE.Points(featureGeometry, material);
  scene.add(featurePoints);
}

// Lazy orbit-like rotation (no heavy controls dep)
let userInteracting = false;
let dragging = false, lastX = 0, lastY = 0;
let yaw = 0.4, pitch = 0.05, dist = 3.4;
const cnv = document.getElementById("scene");
cnv.addEventListener("pointerdown", e => { dragging = true; userInteracting = true; lastX = e.clientX; lastY = e.clientY; });
window.addEventListener("pointerup", () => { dragging = false; });
window.addEventListener("pointermove", e => {
  if (dragging) {
    yaw += (e.clientX - lastX) * 0.005;
    pitch += (e.clientY - lastY) * 0.005;
    pitch = Math.max(-1.4, Math.min(1.4, pitch));
    lastX = e.clientX; lastY = e.clientY;
  }
});
cnv.addEventListener("wheel", e => {
  dist *= (1 + e.deltaY * 0.0012);
  dist = Math.max(1.3, Math.min(8, dist));
  e.preventDefault();
}, { passive:false });

window.addEventListener("resize", () => {
  camera.aspect = window.innerWidth / window.innerHeight;
  camera.updateProjectionMatrix();
  renderer.setSize(window.innerWidth, window.innerHeight);
});

function tick(t) {
  if (!userInteracting) yaw += 0.0015; // auto-rotate until the first drag
  camera.position.x = dist * Math.sin(yaw) * Math.cos(pitch);
  camera.position.y = dist * Math.sin(pitch);
  camera.position.z = dist * Math.cos(yaw) * Math.cos(pitch);
  camera.lookAt(0, 0, 0);
  renderer.render(scene, camera);
  requestAnimationFrame(tick);
}
requestAnimationFrame(tick);

// --- Picking (raycasting) ---
const raycaster = new THREE.Raycaster();
raycaster.params.Points = { threshold: 0.04 };
const mouse = new THREE.Vector2();
const tooltip = document.getElementById("tooltip");
cnv.addEventListener("mousemove", e => {
  mouse.x = (e.clientX / window.innerWidth) * 2 - 1;
  mouse.y = -(e.clientY / window.innerHeight) * 2 + 1;
  if (!featurePoints) return;
  raycaster.setFromCamera(mouse, camera);
  const hits = raycaster.intersectObject(featurePoints);
  if (hits.length > 0) {
    const id = hits[0].index;
    state.picked = id;
    const top = state.topFeatures.find(t => t.id === id);
    const isSteered = state.steering.has(id);
    // Slots may be a bare alpha or an {alpha, ...} object (repaintCloud
    // handles both); normalize before formatting.
    const slot = isSteered ? state.steering.get(id) : null;
    const alpha = (slot !== null && typeof slot === "object") ? slot.alpha : slot;
    let html = `<div style="color:var(--fg); font-weight:600;">feature ${id}</div>`;
    if (top) {
      const rank = state.topFeatures.indexOf(top);
      html += `<div style="margin-top:3px; color:var(--accent);">rank #${rank + 1} of top firing · activation ${top.act.toFixed(3)}</div>`;
    } else {
      html += `<div style="margin-top:3px; color:var(--fg-faint);">not in top firing for current prompt</div>`;
    }
    if (isSteered) {
      const sign = alpha >= 0 ? "+" : "";
      html += `<div style="margin-top:3px; color:var(--accent-2);">steered · α=${sign}${alpha.toFixed(0)}</div>`;
    }
    html += `<div style="margin-top:5px; color:var(--fg-faint); font-size:10px;">click to ${isSteered ? "edit slider" : "add steering slider"}</div>`;
    tooltip.innerHTML = html;
    tooltip.style.display = "block";
    tooltip.style.left = e.clientX + "px";
    tooltip.style.top = e.clientY + "px";
  } else {
    state.picked = null;
    tooltip.style.display = "none";
  }
});
cnv.addEventListener("click", () => {
  if (state.picked != null) {
    addSteerSlot(state.picked, 0);
  }
});

// --- Update particle attributes after encode ---
function repaintCloud() {
  if (!featureGeometry) return;
  const colors = featureGeometry.attributes.color.array;
  const sizes = featureGeometry.attributes.size.array;
  const N = sizes.length;
  for (let i=0; i<N; i++) {
    colors[i*3] = 1.0; colors[i*3+1] = 1.0; colors[i*3+2] = 1.0;
    sizes[i] = 0.5;
  }
  // Top firing — cyan, larger
  const maxAct = state.topFeatures[0]?.act || 1;
  for (const f of state.topFeatures) {
    const i = f.id;
    const t = Math.min(1, Math.abs(f.act) / maxAct);
    colors[i*3]   = 0.49 * t + 1.0 * (1-t);
    colors[i*3+1] = 0.97 * t + 1.0 * (1-t);
    colors[i*3+2] = 1.00;
    sizes[i] = 1.5 + 4.5 * t;
  }
  // Steered — magenta override
  for (const [id, slot] of state.steering) {
    const alpha = (typeof slot === "object") ? slot.alpha : slot;
    if (alpha === 0) continue;
    const t = Math.min(1, Math.abs(alpha) / 100);
    colors[id*3]   = 1.00;
    colors[id*3+1] = 0.49 * (1-t) + 0.4 * t;
    colors[id*3+2] = 0.97;
    sizes[id] = 4 + 5 * t;
  }
  featureGeometry.attributes.color.needsUpdate = true;
  featureGeometry.attributes.size.needsUpdate = true;
}

// --- API helpers ---
function setStatus(msg, busy=false) {
  const el = document.getElementById("status");
  el.innerHTML = busy ? `<span class="loader"></span> ${msg}` : msg;
}
async function api(path, body=null) {
  const opts = body ? {method:"POST", headers:{"Content-Type":"application/json"}, body: JSON.stringify(body)}
                    : {method:"GET"};
  const r = await fetch(API + path, opts);
  if (!r.ok) throw new Error(`${path} -> ${r.status} ${r.statusText}`);
  return await r.json();
}

// --- Boot ---
function applyHealthToHeader(hp) {
  const layerInput = document.getElementById("layer-input");
  layerInput.value = hp.layer;
  layerInput.max = (hp.n_layers || 1) - 1;
  layerInput.title = `Layer 0..${hp.n_layers - 1}`;
  document.getElementById("hdr-layer-meta").textContent = ` · ${hp.device} · ${hp.dtype}`;
  document.getElementById("hdr-features").textContent = `${hp.n_features.toLocaleString()} features`;
  const hdr = document.getElementById("header");
  hdr.classList.remove("loading");
  hdr.classList.add("live");
  if (hp.transferred && hp.note) {
    setStatus("⚠ transferred SAE: " + hp.note);
  }
}

// Layer hot-swap: change SAE layer without reloading model
let layerSwapPending = null;
document.getElementById("layer-input").addEventListener("change", async (ev) => {
  const newLayer = parseInt(ev.target.value, 10);
  const hp = await api("/health");
  if (newLayer === hp.layer) return;
  if (isNaN(newLayer) || newLayer < 0 || newLayer >= hp.n_layers) {
    ev.target.value = hp.layer; return;
  }
  // Clear stale UI state — features and outputs are layer-specific
  state.topFeatures = []; state.steering.clear();
  document.getElementById("features").innerHTML = `<div class="empty">Encode a prompt to populate features.</div>`;
  document.getElementById("feat-count").textContent = "";
  document.getElementById("out-baseline").innerHTML = `<span class="empty">(no run yet)</span>`;
  document.getElementById("out-steered").innerHTML = `<span class="empty">(no run yet)</span>`;
  document.getElementById("verifier").innerHTML = "";
  document.getElementById("heatmap-grid").innerHTML = `<span class="empty">Encode a prompt — heatmap fills automatically.</span>`;
  document.getElementById("btn-generate").disabled = true;

  showLoading(`Switching to layer ${newLayer}…`,
              "First time may take ~20s (download + SVD); subsequent switches are <0.1s.");
  try {
    const r = await api("/set_layer", {layer: newLayer});
    const hp2 = await api("/health");
    applyHealthToHeader(hp2);
    applyPositionsToCloud(r.positions);
    setStatus(`layer ${newLayer}` + (r.from_cache ? " (cached)" : ""));
  } catch (e) {
    setStatus(`layer swap error: ${e.message}`);
    ev.target.value = hp.layer;
  } finally {
    hideLoading();
  }
});

function applyPositionsToCloud(positions) {
  const flat = new Float32Array(positions.length * 3);
  for (let i=0; i<positions.length; i++) {
    flat[i*3]   = positions[i][0] * 2.2;
    flat[i*3+1] = positions[i][1] * 2.2;
    flat[i*3+2] = positions[i][2] * 2.2;
  }
  state.positions = flat;
  if (featurePoints) {
    scene.remove(featurePoints);
    featureGeometry.dispose();
    featurePoints.material.dispose();
    featurePoints = null; featureGeometry = null;
  }
  buildFeatureCloud(flat);
}

function fillModelDropdown(models, currentModel) {
  const sel = document.getElementById("model-select");
  sel.innerHTML = "";
  for (const m of models) {
    const opt = document.createElement("option");
    opt.value = m.model;
    const xferTag = m.transferred ? " ⚠" : "";
    opt.textContent = `${m.model.replace("Qwen/","")} (${m.approx_size_gb}GB · ${m.n_features.toLocaleString()}f)${xferTag}`;
    if (m.model === currentModel) opt.selected = true;
    sel.appendChild(opt);
  }
  // If only one model in the catalog (HF Space deployment), make it
  // visually informational rather than interactive.
  if (models.length <= 1) {
    sel.disabled = true;
    sel.title = "Locked to Qwen3-1.7B-Base on HF Space (free CPU). Run locally to swap models.";
    sel.style.cursor = "default";
    sel.style.opacity = "0.85";
  }
}

function showLoading(msg, detail="") {
  document.getElementById("lo-msg").textContent = msg;
  document.getElementById("lo-detail").textContent = detail;
  document.getElementById("loading-overlay").classList.add("visible");
  document.getElementById("header").classList.remove("live");
  document.getElementById("header").classList.add("loading");
  document.getElementById("btn-encode").disabled = true;
  document.getElementById("btn-generate").disabled = true;
}
function hideLoading() {
  document.getElementById("loading-overlay").classList.remove("visible");
  document.getElementById("btn-encode").disabled = false;
}

async function boot() {
  setStatus("loading positions…", true);
  try {
    const [hp, ph, ml] = await Promise.all([
      api("/health"),
      api("/positions"),
      api("/list_models"),
    ]);
    fillModelDropdown(ml.models, hp.model);
    applyHealthToHeader(hp);
    applyPositionsToCloud(ph.positions);
    setStatus("ready");
    if (hp.transferred && hp.note) setStatus("⚠ " + hp.note);
  } catch (e) {
    setStatus(`server unreachable${API ? ` at ${API}` : ""}: ${e.message}`);
    document.getElementById("model-select").innerHTML = `<option style="color:var(--bad)">server offline</option>`;
  }
}
boot();

// Tab switching
document.body.dataset.tab = "steering";
document.querySelectorAll(".tab").forEach(btn => {
  btn.addEventListener("click", () => {
    document.querySelectorAll(".tab").forEach(b => b.classList.remove("active"));
    btn.classList.add("active");
    document.body.dataset.tab = btn.dataset.tab;
  });
});

// Minimize toggle on every panel header
document.querySelectorAll(".min-btn").forEach(btn => {
  btn.addEventListener("click", (ev) => {
    ev.stopPropagation();
    const panel = btn.closest(".panel");
    if (!panel) return;
    panel.classList.toggle("collapsed");
    btn.textContent = panel.classList.contains("collapsed") ? "+" : "−";
    btn.title = panel.classList.contains("collapsed") ? "Expand panel" : "Minimize panel";
  });
});
// Model swap on dropdown change
document.getElementById("model-select").addEventListener("change", async (ev) => {
  const newModel = ev.target.value;
  const ml = await api("/list_models");
  const entry = ml.models.find(m => m.model === newModel);
  if (!entry) return;
  const ok = confirm(
    `Load ${newModel}?\n\n` +
    `~${entry.approx_size_gb}GB download (or cached if seen before).\n` +
    `${entry.n_features.toLocaleString()} SAE features, ${entry.n_layers} layers.\n` +
    (entry.transferred ? `⚠ TRANSFERRED SAE: ${entry.note}\n` : ``) +
    `\nThis blocks the server until ready.`
  );
  if (!ok) {
    // revert dropdown to current
    const hp = await api("/health");
    ev.target.value = hp.model;
    return;
  }
  // Reset client-side state
  state.topFeatures = []; state.steering.clear();
  document.getElementById("features").innerHTML = `<div class="empty">Encode a prompt to populate features.</div>`;
  document.getElementById("feat-count").textContent = "";
  document.getElementById("out-baseline").innerHTML = `<span class="empty">(no run yet)</span>`;
  document.getElementById("out-steered").innerHTML = `<span class="empty">(no run yet)</span>`;
  document.getElementById("verifier").innerHTML = "";
  document.getElementById("btn-generate").disabled = true;

  showLoading(`Loading ${newModel}…`,
    `~${entry.approx_size_gb}GB · this is real, watch the server log`);
  try {
    const r = await fetch(API + "/load_model", {
      method:"POST", headers:{"Content-Type":"application/json"},
      body: JSON.stringify({model:newModel})
    });
    if (!r.ok) {
      const txt = await r.text();
      throw new Error(`${r.status}: ${txt}`);
    }
    const data = await r.json();
    const hp = await api("/health");
    applyHealthToHeader(hp);
    applyPositionsToCloud(data.positions);
    setStatus(`loaded ${newModel}`);
    if (hp.transferred && hp.note) setStatus("⚠ " + hp.note);
  } catch (e) {
    setStatus(`load failed: ${e.message}`);
    alert(`Load failed:\n\n${e.message}\n\nCheck the server log.`);
    // Try to revert dropdown
    try {
      const hp = await api("/health");
      ev.target.value = hp.model;
    } catch {}
  } finally {
    hideLoading();
  }
});
// --- Encode action ---
document.getElementById("btn-encode").addEventListener("click", async () => {
  const prompt = document.getElementById("prompt").value;
  if (!prompt.trim()) return;
  setStatus("encoding…", true);
  try {
    const t0 = performance.now();
    const top_n = Math.max(1, parseInt(document.getElementById("top-k").value || "20"));
    const r = await api("/encode", {prompt, top_n});
    const dt = ((performance.now() - t0)/1000).toFixed(2);
    state.topFeatures = r.top;
    document.getElementById("feat-count").textContent = `(K=${r.top.length} of ${r.n_features.toLocaleString()})`;
    renderFeatures();
    repaintCloud();
    document.getElementById("btn-generate").disabled = false;
    setStatus(`encoded in ${dt}s`);
    // Auto-render the per-token heatmap too
    renderHeatmap(prompt).catch(e => console.error("heatmap:", e));
  } catch (e) {
    setStatus(`encode error: ${e.message}`);
  }
});

document.getElementById("heatmap-skip-first").addEventListener("change", () => {
  if (state._lastHeatmapPrompt) renderHeatmap(state._lastHeatmapPrompt);
});
// =====================================================================
// PILLAR 2 — EVALUATION
// =====================================================================
function parsePrompts(textareaId) {
  return document.getElementById(textareaId).value
    .split("\n").map(s => s.trim()).filter(s => s.length > 0);
}

document.getElementById("btn-eval-encode").addEventListener("click", async () => {
  const prompts = parsePrompts("eval-prompts");
  if (prompts.length === 0) { setStatus("paste prompts first"); return; }
  setStatus(`encoding ${prompts.length} prompts…`, true);
  document.getElementById("eval-features-list").innerHTML = `<span class="loader"></span>`;
  document.getElementById("eval-samples").innerHTML = `<span class="loader"></span>`;
  document.getElementById("eval-heatmap").innerHTML = `<span class="loader"></span>`;
  try {
    const t0 = performance.now();
    const r = await api("/encode_batch", {prompts, top_n: 10});
    const dt = ((performance.now()-t0)/1000).toFixed(2);
    state.evalResult = r;
    document.getElementById("eval-stats").textContent =
      `(${r.n_samples} samples · ${r.corpus_features.length} features fired)`;
    renderEvalCorpus(r);
    renderEvalSamples(r);
    renderEvalHeatmap(r);
    setStatus(`corpus encoded in ${dt}s`);
  } catch (e) {
    setStatus(`eval encode error: ${e.message}`);
  }
});
function renderEvalCorpus(r) {
  const div = document.getElementById("eval-features-list");
  if (!r.corpus_features.length) {
    div.innerHTML = `<div class="empty">No features fired.</div>`;
    return;
  }
  // Top firing features ranked by fire_rate, with bar
  const max = r.corpus_features[0].fire_rate;
  div.innerHTML = r.corpus_features.slice(0, 60).map(f => {
    const w = Math.max(2, 100 * f.fire_rate / max);
    return `<div style="padding:5px 8px; border-bottom:1px solid var(--border); font-family:var(--mono); font-size:11px;">
      <div style="display:flex; justify-content:space-between; gap:10px; align-items:center;">
        <span style="color:var(--accent);">feat ${f.id}</span>
        <span style="color:var(--fg-faint);">${(f.fire_rate*100).toFixed(0)}% · μ=${f.mean_act.toFixed(2)}</span>
      </div>
      <div style="height:3px; background:rgba(125,249,255,0.55); width:${w}%; margin-top:4px; border-radius:2px;"></div>
    </div>`;
  }).join("");
}

function renderEvalSamples(r) {
  const div = document.getElementById("eval-samples");
  if (!r.per_sample.length) { div.innerHTML = `<span class="empty">no samples</span>`; return; }
  div.innerHTML = r.per_sample.map(s => {
    const top = s.top.slice(0,5).map(t =>
      `<span style="display:inline-block; padding:1px 6px; margin:1px; border:1px solid var(--border-strong); border-radius:4px; font-family:var(--mono); font-size:10px; color:var(--accent);">${t.id}<span style="color:var(--fg-faint);">·${t.act.toFixed(1)}</span></span>`
    ).join("");
    return `<div style="padding:8px; border-bottom:1px solid var(--border);">
      <div style="font-family:var(--mono); font-size:11px; color:var(--fg); margin-bottom:4px;">[${s.i}] ${escapeHtml(s.preview)}</div>
      <div>${top}</div>
    </div>`;
  }).join("");
}
function renderEvalHeatmap(r) {
  const div = document.getElementById("eval-heatmap");
  // Pick top 20 features that fire most across samples
  const topFeats = r.corpus_features.slice(0, 20);
  if (!topFeats.length || !r.per_sample.length) { div.innerHTML = `<span class="empty">no data</span>`; return; }
  // Build sample-id × feature_id activation matrix
  const sampleActs = r.per_sample.map(s => {
    const m = {};
    for (const t of s.top) m[t.id] = t.act;
    return m;
  });
  const max = Math.max(1e-6, ...sampleActs.flatMap(m => Object.values(m)));
  // Render: rows = samples, cols = features
  const header = `<th style="padding:3px 6px; font-size:10px; color:var(--fg-dim); background:rgba(0,0,0,0.4); position:sticky; left:0; z-index:2;">sample</th>` +
    topFeats.map(f => `<th title="feat ${f.id} (${(f.fire_rate*100).toFixed(0)}% rate)" style="padding:3px 4px; font-size:9px; color:var(--accent); border-bottom:1px solid var(--border-strong);">${f.id}</th>`).join("");
  const rows = r.per_sample.map((s, si) => {
    const cells = topFeats.map(f => {
      const v = sampleActs[si][f.id] || 0;
      const t = Math.min(1, v/max);
      const r = Math.round(255 - 200*t), g = 255, b = Math.round(255 - 50*t);
      return `<td title="sample ${si} · feat ${f.id} · ${v.toFixed(2)}" style="background:rgb(${r},${g},${b}); width:22px; height:18px; border:1px solid rgba(0,0,0,0.3);"></td>`;
    }).join("");
    return `<tr><td style="padding:2px 6px; font-family:var(--mono); font-size:10px; color:var(--fg-dim); background:rgba(0,0,0,0.3); position:sticky; left:0; z-index:1; max-width:80px; overflow:hidden; text-overflow:ellipsis; white-space:nowrap;">${si}: ${escapeHtml(s.preview.slice(0,16))}</td>${cells}</tr>`;
  }).join("");
  div.innerHTML = `<table style="border-collapse:collapse; font-family:var(--mono);"><thead><tr>${header}</tr></thead><tbody>${rows}</tbody></table>`;
}
function escapeHtml(s) {
  // Escape the characters that are significant in HTML so user text is inert in innerHTML.
  return String(s).replace(/&/g, "&amp;").replace(/</g, "&lt;").replace(/>/g, "&gt;").replace(/"/g, "&quot;");
}
// View toggle in Evaluation bottom-right panel: heatmap vs compare
document.getElementById("btn-eval-view-heatmap").addEventListener("click", () => {
  document.getElementById("eval-heatmap").style.display = "block";
  document.getElementById("eval-compare-results").style.display = "none";
});
document.getElementById("btn-eval-view-compare").addEventListener("click", () => {
  document.getElementById("eval-heatmap").style.display = "none";
  document.getElementById("eval-compare-results").style.display = "block";
});

// Differential feature mining
document.getElementById("btn-cmp").addEventListener("click", async () => {
  const a = parsePrompts("cmp-a");
  const b = parsePrompts("cmp-b");
  if (a.length === 0 || b.length === 0) {
    setStatus("paste both sets first"); return;
  }
  setStatus(`comparing ${a.length} vs ${b.length} prompts…`, true);
  // Auto-switch the bottom-right panel to compare view
  document.getElementById("eval-heatmap").style.display = "none";
  document.getElementById("eval-compare-results").style.display = "block";
  document.getElementById("eval-compare-results").innerHTML = `<span class="loader"></span> encoding…`;
  try {
    const t0 = performance.now();
    const r = await api("/compare_batch", {prompts_a: a, prompts_b: b, top_n: 30});
    const dt = ((performance.now()-t0)/1000).toFixed(2);
    renderCompareResults(r, a, b);
    setStatus(`compared in ${dt}s`);
  } catch (e) {
    setStatus(`compare error: ${e.message}`);
    document.getElementById("eval-compare-results").innerHTML = `<span class="empty">error: ${e.message}</span>`;
  }
});
function renderCompareResults(r, setA, setB) {
  const div = document.getElementById("eval-compare-results");
  if (!r.top_diff || !r.top_diff.length) {
    div.innerHTML = `<span class="empty">no distinguishing features found</span>`;
    return;
  }
  const max = r.top_diff[0].diff || 1;
  const header = `
    <div style="font-size:10px; color:var(--fg-faint); padding:4px 0; display:flex; gap:14px; flex-wrap:wrap; border-bottom:1px solid var(--border-strong); margin-bottom:6px;">
      <span><span style="color:#7df9ff; font-weight:700;">■ A</span> ${r.n_a} prompts</span>
      <span><span style="color:#ff7df9; font-weight:700;">■ B</span> ${r.n_b} prompts</span>
      <span>top ${r.top_diff.length} by |Δ rate|</span>
    </div>`;
  const rows = r.top_diff.map((f, i) => {
    const wA = Math.max(2, 60 * f.rate_a);
    const wB = Math.max(2, 60 * f.rate_b);
    const winner = f.winner === "a" ? "A" : "B";
    const winnerColor = f.winner === "a" ? "#7df9ff" : "#ff7df9";
    return `<tr style="border-bottom:1px solid var(--border);">
      <td style="padding:3px 6px; color:var(--fg-faint); font-size:10px;">${i+1}</td>
      <td style="padding:3px 6px; color:var(--accent); font-family:var(--mono);">${f.id}</td>
      <td style="padding:3px 6px; text-align:right; font-family:var(--mono);">${(f.rate_a*100).toFixed(0)}%</td>
      <td style="padding:3px 6px;">
        <div style="display:flex; gap:2px; align-items:center;">
          <div style="width:${wA}px; height:6px; background:#7df9ff;"></div>
          <div style="width:${wB}px; height:6px; background:#ff7df9;"></div>
        </div>
      </td>
      <td style="padding:3px 6px; text-align:right; font-family:var(--mono);">${(f.rate_b*100).toFixed(0)}%</td>
      <td style="padding:3px 6px; font-family:var(--mono); font-size:10px; color:${winnerColor};">${winner} ▲</td>
      <td style="padding:3px 6px; text-align:right; font-family:var(--mono); color:var(--fg);">${(f.diff*100).toFixed(0)}%</td>
    </tr>`;
  }).join("");
  div.innerHTML = `
    <div>
      ${header}
      <div style="overflow:auto; max-height:32vh;">
        <table style="border-collapse:collapse; width:100%; font-size:11px;">
          <thead>
            <tr style="background:rgba(0,0,0,0.4); position:sticky; top:0; z-index:1;">
              <th style="padding:4px 6px; font-size:10px; color:var(--fg-dim); text-align:left;">#</th>
              <th style="padding:4px 6px; font-size:10px; color:var(--fg-dim); text-align:left;">feat</th>
              <th style="padding:4px 6px; font-size:10px; color:var(--fg-dim); text-align:right;">rate A</th>
              <th style="padding:4px 6px; font-size:10px; color:var(--fg-dim); text-align:left;">A vs B</th>
              <th style="padding:4px 6px; font-size:10px; color:var(--fg-dim); text-align:right;">rate B</th>
              <th style="padding:4px 6px; font-size:10px; color:var(--fg-dim); text-align:left;">winner</th>
              <th style="padding:4px 6px; font-size:10px; color:var(--fg-dim); text-align:right;">|Δ|</th>
            </tr>
          </thead>
          <tbody>${rows}</tbody>
        </table>
      </div>
    </div>`;
}
// =====================================================================
// PILLAR 3 — DATA-CENTRIC
// =====================================================================
document.getElementById("btn-dc-encode").addEventListener("click", async () => {
  const prompts = parsePrompts("dc-prompts");
  if (prompts.length === 0) { setStatus("paste prompts first"); return; }
  setStatus(`encoding ${prompts.length} prompts…`, true);
  try {
    const r = await api("/encode_batch", {prompts, top_n: 50});
    state.dcResult = r;
    state.dcAllPrompts = prompts;
    state.dcFiltered = r.per_sample;
    renderDcFiltered(state.dcFiltered, prompts, null);
    setStatus(`encoded ${r.n_samples} prompts`);
  } catch (e) {
    setStatus(`dc encode error: ${e.message}`);
  }
});
document.getElementById("btn-dc-filter").addEventListener("click", () => {
  if (!state.dcResult) { setStatus("encode first"); return; }
  const fid = parseInt(document.getElementById("dc-filter-id").value);
  const mode = document.getElementById("dc-filter-mode").value;
  if (isNaN(fid)) { setStatus("enter a feature id"); return; }
  const fired = new Set();
  for (const s of state.dcResult.per_sample) {
    if (s.top.some(t => t.id === fid)) fired.add(s.i);
  }
  const filtered = state.dcResult.per_sample.filter(s =>
    mode === "include" ? fired.has(s.i) : !fired.has(s.i)
  );
  state.dcFiltered = filtered;
  renderDcFiltered(filtered, state.dcAllPrompts, {id: fid, mode});
});

document.getElementById("btn-dc-clear").addEventListener("click", () => {
  if (!state.dcResult) return;
  state.dcFiltered = state.dcResult.per_sample;
  renderDcFiltered(state.dcFiltered, state.dcAllPrompts, null);
});
function renderDcFiltered(samples, prompts, filt) {
  const div = document.getElementById("dc-filtered-list");
  const stats = document.getElementById("dc-filter-stats");
  if (filt) {
    stats.textContent = `(filter: ${filt.mode === "include" ? "+" : "−"}feat ${filt.id} · ${samples.length} of ${state.dcResult.n_samples})`;
  } else {
    stats.textContent = `(${samples.length} docs)`;
  }
  if (!samples.length) { div.innerHTML = `<span class="empty">no docs match.</span>`; return; }
  div.innerHTML = samples.map(s => {
    const fullPrompt = prompts[s.i] || s.preview;
    const top = s.top.slice(0,3).map(t =>
      `<span style="display:inline-block; padding:1px 5px; margin:1px; border:1px solid var(--border-strong); border-radius:3px; font-family:var(--mono); font-size:10px; color:var(--accent);">${t.id}</span>`
    ).join("");
    return `<div style="padding:6px 8px; border-bottom:1px solid var(--border); font-family:var(--mono); font-size:11px;">
      <div style="color:var(--fg-faint); font-size:9px;">[${s.i}]</div>
      <div style="color:var(--fg);">${escapeHtml(fullPrompt)}</div>
      <div style="margin-top:3px;">${top}</div>
    </div>`;
  }).join("");
}
document.getElementById("btn-dc-synth").addEventListener("click", async () => {
  const prompts = state.dcAllPrompts || parsePrompts("dc-prompts");
  if (!prompts.length) { setStatus("paste seeds first"); return; }
  const fid = parseInt(document.getElementById("dc-synth-id").value);
  const alpha = parseFloat(document.getElementById("dc-synth-alpha").value);
  const max_new_tokens = parseInt(document.getElementById("dc-synth-tokens").value);
  if (isNaN(fid)) { setStatus("enter feature id"); return; }
  setStatus(`synthesizing ${prompts.length} steered completions…`, true);
  document.getElementById("dc-synth-list").innerHTML = `<span class="loader"></span> running…`;
  try {
    const r = await api("/synth_batch", {
      seed_prompts: prompts,
      steering: [{id: fid, alpha}],
      max_new_tokens,
    });
    document.getElementById("dc-synth-stats").textContent =
      `(${r.results.length} · feat ${fid} α=${alpha})`;
    document.getElementById("dc-synth-list").innerHTML = r.results.map((res, i) =>
      `<div style="padding:8px; border-bottom:1px solid var(--border); font-family:var(--mono); font-size:11px;">
        <div style="color:var(--fg-faint); font-size:9px;">[${i}] seed</div>
        <div style="color:var(--fg-dim);">${escapeHtml(res.seed)}</div>
        <div style="color:var(--fg-faint); font-size:9px; margin-top:4px;">→ steered</div>
        <div style="color:var(--fg);">${escapeHtml(res.text)}</div>
      </div>`
    ).join("");
    setStatus("synthesis done");
  } catch (e) {
    setStatus(`synth error: ${e.message}`);
    document.getElementById("dc-synth-list").innerHTML = `<span class="empty">error: ${e.message}</span>`;
  }
});
async function renderHeatmap(prompt) {
  const container = document.getElementById("heatmap-grid");
  state._lastHeatmapPrompt = prompt;
  container.innerHTML = `<span class="loader"></span> computing…`;
  try {
    const r = await api("/encode_full", {prompt, top_n: 16});
    let tokens = r.tokens, grid = r.grid, ids = r.feature_ids;
    const skipFirst = document.getElementById("heatmap-skip-first").checked;
    if (skipFirst && tokens.length > 1) {
      tokens = tokens.slice(1);
      grid = grid.map(row => row.slice(1));
    }
    if (tokens.length === 0) {
      container.innerHTML = `<span class="empty">No tokens (after skip).</span>`;
      return;
    }
    // Build HTML table
    const rowMax = grid.map(row => Math.max(...row.map(Math.abs), 1e-6));
    const tokHeader = tokens.map((t,i) => {
      const safe = (t.replace(/\n/g,"↵").replace(/</g,"&lt;").replace(/>/g,"&gt;")).slice(0,8);
      return `<th title="pos ${i}: ${t.replace(/\n/g,'\\n').replace(/"/g,'&quot;')}" style="padding:3px 4px; font-size:10px; color:var(--fg-dim); border-bottom:1px solid var(--border-strong); white-space:nowrap;">${safe || `[${i}]`}</th>`;
    }).join("");
    const rows = grid.map((row, fi) => {
      const m = rowMax[fi];
      const cells = row.map((v, pi) => {
        const t = Math.min(1, Math.abs(v) / m);
        const r = 255, g = Math.round(255 - 200*t), b = Math.round(255 - 220*t);
        return `<td title="feat ${ids[fi]} · pos ${pi} · act=${v.toFixed(3)}" style="background:rgb(${r},${g},${b}); width:30px; height:22px; border:1px solid rgba(0,0,0,0.3);"></td>`;
      }).join("");
      return `<tr><td style="font-family:var(--mono); font-size:10px; padding:3px 6px; color:var(--accent); white-space:nowrap; background:rgba(0,0,0,0.3); position:sticky; left:0; z-index:1;">#${ids[fi]}</td>${cells}</tr>`;
    }).join("");
    container.innerHTML = `<table style="border-collapse:collapse; font-family:var(--mono);"><thead><tr><th style="padding:3px 6px; font-size:10px; color:var(--fg-dim); position:sticky; left:0; z-index:2; background:rgba(0,0,0,0.4);">feat</th>${tokHeader}</tr></thead><tbody>${rows}</tbody></table>`;
  } catch (e) {
    container.innerHTML = `<span class="empty">heatmap error: ${e.message}</span>`;
  }
}
// --- Feature card render ---
function renderFeatures() {
  const div = document.getElementById("features");
  if (state.topFeatures.length === 0) {
    div.innerHTML = `<div class="empty">Encode a prompt to populate features.</div>`;
    return;
  }
  div.innerHTML = "";
  for (const f of state.topFeatures) {
    div.appendChild(buildFeatureCard(f.id, f.act, /*topRanked=*/true));
  }
  // Also render any steered features that weren't in top-K
  for (const [id, alpha] of state.steering) {
    if (!state.topFeatures.find(t => t.id === id)) {
      div.appendChild(buildFeatureCard(id, null, /*topRanked=*/false));
    }
  }
}
function buildFeatureCard(id, act, topRanked) {
  const wrap = document.createElement("div");
  wrap.className = "feat" + (state.steering.has(id) ? " steered" : "");
  const alpha = state.steering.has(id) ? state.steering.get(id) : 0;

  const head = document.createElement("div");
  head.className = "feat-head";
  const left = document.createElement("div");
  left.innerHTML = `<span class="feat-id">feat ${id}</span>` +
    (act != null ? ` <span class="feat-act">act ${act.toFixed(3)}</span>` : ` <span class="feat-act" style="color:var(--fg-faint)">(picked)</span>`);
  const tools = document.createElement("div");
  tools.className = "feat-tools";

  if (state.steering.has(id)) {
    const reset = document.createElement("button");
    reset.textContent = "reset";
    reset.title = "Set α back to 0 (no intervention) but keep slider visible";
    reset.addEventListener("click", () => {
      state.steering.set(id, 0);
      renderFeatures();
      repaintCloud();
    });
    tools.appendChild(reset);
    const off = document.createElement("button");
    off.textContent = "remove";
    off.title = "Remove this slider entirely";
    off.addEventListener("click", () => {
      state.steering.delete(id);
      renderFeatures();
      repaintCloud();
    });
    tools.appendChild(off);
  } else {
    const add = document.createElement("button");
    add.textContent = "steer";
    add.addEventListener("click", () => addSteerSlot(id, 0));
    tools.appendChild(add);
  }
  head.appendChild(left);
  head.appendChild(tools);
  wrap.appendChild(head);

  if (state.steering.has(id)) {
    const cur = state.steering.get(id);
    const curAlpha = (typeof cur === "object") ? cur.alpha : cur;
    const curPositions = (typeof cur === "object") ? (cur.positions || "") : "";
    const curOutOnly = (typeof cur === "object") ? !!cur.output_only : false;

    const sl = document.createElement("div");
    sl.className = "slider-row";
    const range = document.createElement("input");
    range.type = "range"; range.min = -100; range.max = 100; range.step = 1;
    range.value = curAlpha;
    const val = document.createElement("span");
    val.className = "alpha-val";
    val.textContent = (curAlpha >= 0 ? "+" : "") + curAlpha.toFixed(0);
    range.addEventListener("input", () => {
      const a = parseFloat(range.value);
      const slot = state.steering.get(id);
      const next = (typeof slot === "object") ? {...slot, alpha:a} : {alpha:a, positions:"", output_only:false};
      state.steering.set(id, next);
      val.textContent = (a >= 0 ? "+" : "") + a.toFixed(0);
      repaintCloud();
    });
    sl.appendChild(range);
    sl.appendChild(val);
    wrap.appendChild(sl);

    // Position-selective steering + output-only toggle
    const adv = document.createElement("div");
    adv.style.cssText = "margin-top:6px; display:flex; gap:6px; align-items:center; flex-wrap:wrap;";
    const posLbl = document.createElement("span");
    posLbl.textContent = "positions";
    posLbl.style.cssText = "font-size:10px; color:var(--fg-faint); flex:0 0 auto;";
    const posInput = document.createElement("input");
    posInput.type = "text";
    posInput.placeholder = "all";
    posInput.value = curPositions;
    posInput.style.cssText = "flex:1; min-width:60px; font-size:11px; padding:3px 6px; background:rgba(0,0,0,0.3); border:1px solid var(--border-strong); color:var(--fg); border-radius:4px; font-family:var(--mono);";
    posInput.title = "Token positions to steer at: 'all' or '3-7' or '0,2,5-8'. Empty = all.";
    posInput.addEventListener("change", () => {
      const slot = state.steering.get(id);
      const next = (typeof slot === "object") ? {...slot, positions: posInput.value} : {alpha: slot, positions: posInput.value, output_only: false};
      state.steering.set(id, next);
    });
    const outLbl = document.createElement("label");
    outLbl.style.cssText = "font-size:10px; color:var(--fg-faint); display:inline-flex; align-items:center; gap:3px; cursor:pointer; flex:0 0 auto;";
    const outChk = document.createElement("input");
    outChk.type = "checkbox";
    outChk.checked = curOutOnly;
    outChk.style.cssText = "margin:0;";
    outChk.addEventListener("change", () => {
      const slot = state.steering.get(id);
      const next = (typeof slot === "object") ? {...slot, output_only: outChk.checked} : {alpha: slot, positions: "", output_only: outChk.checked};
      state.steering.set(id, next);
    });
    outLbl.appendChild(outChk);
    outLbl.appendChild(document.createTextNode("output only"));
    adv.appendChild(posLbl);
    adv.appendChild(posInput);
    adv.appendChild(outLbl);
    wrap.appendChild(adv);

    const lbl = document.createElement("div");
    lbl.style.cssText = "font-size:10px; color:var(--fg-faint); margin-top:4px;";
    lbl.textContent = "α: −100 ← suppress … +100 → amplify";
    wrap.appendChild(lbl);
  }
  return wrap;
}

function addSteerSlot(id, alpha=0) {
  state.steering.set(id, alpha);
  renderFeatures();
  repaintCloud();
}
// --- Generate action ---
document.getElementById("btn-generate").addEventListener("click", async () => {
  const prompt = document.getElementById("prompt").value;
  const max_new_tokens = parseInt(document.getElementById("max-tokens").value || "40");
  const steering = [];
  for (const [id, slot] of state.steering) {
    const alpha = (typeof slot === "object") ? slot.alpha : slot;
    if (alpha === 0) continue;
    const positions = (typeof slot === "object") ? (slot.positions || null) : null;
    const output_only = (typeof slot === "object") ? !!slot.output_only : false;
    steering.push({id, alpha, positions, output_only});
  }
  setStatus("generating baseline…", true);
  document.getElementById("out-baseline").innerHTML = `<span class="loader"></span>`;
  document.getElementById("out-steered").innerHTML = `<span class="loader"></span>`;

  try {
    const t0 = performance.now();
    const baseline = await api("/generate", {prompt, steering: [], max_new_tokens, return_probs: true, topk_display: 8});
    const t1 = performance.now();
    renderTokenChips("out-baseline", baseline, "blue");
    document.getElementById("base-time").textContent = ((t1-t0)/1000).toFixed(2) + "s";

    if (steering.length === 0) {
      document.getElementById("out-steered").innerHTML = `<span class="empty">(no sliders engaged — steered = baseline)</span>`;
      document.getElementById("steered-time").textContent = "";
      document.getElementById("verifier").innerHTML = "";
      setStatus("done");
      return;
    }
    setStatus("generating steered…", true);
    const t2 = performance.now();
    const steered = await api("/generate", {prompt, steering, max_new_tokens, return_probs: true, topk_display: 8});
    const t3 = performance.now();
    renderTokenChips("out-steered", steered, "magenta");
    document.getElementById("steered-time").textContent = ((t3-t2)/1000).toFixed(2) + "s";

    // Verifier
+
const v = document.getElementById("verifier");
|
| 1466 |
+
v.innerHTML = "";
|
| 1467 |
+
for (const row of steered.verifier) {
|
| 1468 |
+
const d = row.steered - row.base;
|
| 1469 |
+
const cls = d > 0 ? "delta-up" : (d < 0 ? "delta-down" : "");
|
| 1470 |
+
const sign = d >= 0 ? "+" : "";
|
| 1471 |
+
const ext = [];
|
| 1472 |
+
if (row.positions) ext.push(`pos=${row.positions}`);
|
| 1473 |
+
if (row.output_only) ext.push("output-only");
|
| 1474 |
+
const extStr = ext.length ? ` · ${ext.join(" · ")}` : "";
|
| 1475 |
+
const div = document.createElement("div");
|
| 1476 |
+
div.innerHTML = `feat ${row.id} · α=${row.alpha.toFixed(0)}${extStr} · base=${row.base.toFixed(2)} → steered=${row.steered.toFixed(2)} <span class="${cls}">(Δ ${sign}${d.toFixed(2)})</span>`;
|
| 1477 |
+
v.appendChild(div);
|
| 1478 |
+
}
|
| 1479 |
+
setStatus("done");
|
| 1480 |
+
} catch (e) {
|
| 1481 |
+
setStatus(`generate error: ${e.message}`);
|
| 1482 |
+
document.getElementById("out-baseline").textContent = "(error)";
|
| 1483 |
+
document.getElementById("out-steered").textContent = "(error)";
|
| 1484 |
+
}
|
| 1485 |
+
});
|
| 1486 |
+
|
| 1487 |
+
// --- Per-token probability chip strip ---
|
| 1488 |
+
function renderTokenChips(containerId, gen, theme) {
|
| 1489 |
+
const container = document.getElementById(containerId);
|
| 1490 |
+
if (!gen.tokens || !gen.tokens.length) {
|
| 1491 |
+
container.textContent = gen.text || "(empty)";
|
| 1492 |
+
return;
|
| 1493 |
+
}
|
| 1494 |
+
// Theme: colors for chip background
|
| 1495 |
+
const themeFns = {
|
| 1496 |
+
blue: (p) => {
|
| 1497 |
+
const t = Math.max(0, Math.min(1, p));
|
| 1498 |
+
const r = Math.round(255 * (1 - t*0.85));
|
| 1499 |
+
const g = Math.round(255 * (1 - t*0.55));
|
| 1500 |
+
return [r, g, 255, t < 0.5 ? "#1e3a8a" : "#fff"];
|
| 1501 |
+
},
|
| 1502 |
+
magenta: (p) => {
|
| 1503 |
+
const t = Math.max(0, Math.min(1, p));
|
| 1504 |
+
const r = 255;
|
| 1505 |
+
const g = Math.round(255 * (1 - t*0.7));
|
| 1506 |
+
const b = Math.round(255 * (1 - t*0.5));
|
| 1507 |
+
return [r, g, b, t < 0.5 ? "#7f1d1d" : "#fff"];
|
| 1508 |
+
},
|
| 1509 |
+
};
|
| 1510 |
+
const colorize = themeFns[theme] || themeFns.blue;
|
| 1511 |
+
const chips = gen.tokens.map((row, i) => {
|
| 1512 |
+
const [r,g,b,fg] = colorize(row.prob);
|
| 1513 |
+
const tokDisp = row.tok.replace(/\n/g,"↵").replace(/\t/g,"→");
|
| 1514 |
+
const safe = escapeHtml(tokDisp);
|
| 1515 |
+
const panelHtml = topkPanelHtml(row.topk);
|
| 1516 |
+
return `<span class="tok-chip" data-panel='${escapeAttr(panelHtml)}' style="background:rgb(${r},${g},${b}); color:${fg}; padding:2px 6px; margin:1px; border-radius:4px; cursor:pointer; display:inline-block; font-family:var(--mono); font-size:11px; white-space:nowrap;">${safe}<sub style="opacity:.65; font-size:8px; margin-left:3px;">${(row.prob*100).toFixed(1)}%</sub></span>`;
|
| 1517 |
+
}).join("");
|
| 1518 |
+
container.innerHTML = `
|
| 1519 |
+
<div class="tok-strip" data-prob-root style="line-height:2.4;">
|
| 1520 |
+
<div style="font-size:10px; color:var(--fg-faint); margin-bottom:4px; font-style:italic;">click any token to see top-K candidates</div>
|
| 1521 |
+
<div>${chips}</div>
|
| 1522 |
+
<div data-topk-panel style="display:none; margin-top:6px; padding:6px; background:rgba(0,0,0,0.35); border:1px solid var(--border-strong); border-radius:6px; font-family:var(--mono); font-size:10px;"></div>
|
| 1523 |
+
</div>`;
|
| 1524 |
+
// Wire click-to-pin
|
| 1525 |
+
container.querySelectorAll(".tok-chip").forEach(chip => {
|
| 1526 |
+
chip.addEventListener("click", () => {
|
| 1527 |
+
const root = chip.closest("[data-prob-root]");
|
| 1528 |
+
const panel = root.querySelector("[data-topk-panel]");
|
| 1529 |
+
const wasSelected = chip.dataset.selected === "1";
|
| 1530 |
+
root.querySelectorAll(".tok-chip").forEach(c => {
|
| 1531 |
+
c.dataset.selected = "0"; c.style.outline = "";
|
| 1532 |
+
});
|
| 1533 |
+
if (wasSelected) {
|
| 1534 |
+
panel.innerHTML = ""; panel.style.display = "none";
|
| 1535 |
+
} else {
|
| 1536 |
+
chip.dataset.selected = "1";
|
| 1537 |
+
chip.style.outline = "2px solid #94a3b8";
|
| 1538 |
+
chip.style.outlineOffset = "-1px";
|
| 1539 |
+
panel.innerHTML = chip.dataset.panel;
|
| 1540 |
+
panel.style.display = "block";
|
| 1541 |
+
}
|
| 1542 |
+
});
|
| 1543 |
+
});
|
| 1544 |
+
}
|
| 1545 |
+
|
| 1546 |
+
function topkPanelHtml(topk) {
|
| 1547 |
+
const rows = topk.map((c, idx) => {
|
| 1548 |
+
const safe = escapeHtml(c.tok.replace(/\n/g,"↵").replace(/\t/g,"→"));
|
| 1549 |
+
const bg = c.is_chosen ? "background:rgba(125,249,255,0.12);" : "";
|
| 1550 |
+
const fw = c.is_chosen ? "font-weight:700;" : "";
|
| 1551 |
+
const mark = c.is_chosen ? " ✓" : "";
|
| 1552 |
+
return `<tr style="${bg}">
|
| 1553 |
+
<td style="padding:2px 6px; color:var(--fg-faint); text-align:right;">${idx+1}</td>
|
| 1554 |
+
<td style="padding:2px 6px; color:var(--fg); ${fw}">${safe}${mark}</td>
|
| 1555 |
+
<td style="padding:2px 6px; text-align:right; color:var(--accent);">${(c.prob*100).toFixed(2)}%</td>
|
| 1556 |
+
</tr>`;
|
| 1557 |
+
}).join("");
|
| 1558 |
+
return `<table style="border-collapse:collapse; width:100%;"><tbody>${rows}</tbody></table>`;
|
| 1559 |
+
}
|
| 1560 |
+
|
| 1561 |
+
function escapeAttr(s) { return s.replace(/'/g, "'").replace(/"/g, """); }
|
| 1562 |
+
</script>
|
| 1563 |
+
</body>
|
| 1564 |
+
</html>
|
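The chip coloring above maps each token's probability to a background color. As a sanity check on the math, here is a pure-Python mirror of the "blue" theme function (a re-derivation for illustration, not code shared with the page): higher probability gives a deeper blue, and the foreground flips to white past p = 0.5 for contrast.

```python
def blue_chip_color(p):
    # Clamp probability to [0, 1], then fade red/green channels toward blue.
    t = max(0.0, min(1.0, p))
    r = round(255 * (1 - t * 0.85))
    g = round(255 * (1 - t * 0.55))
    # Dark-blue text on light chips, white text on saturated chips.
    return (r, g, 255, "#1e3a8a" if t < 0.5 else "#fff")

print(blue_chip_color(0.0))  # near-white chip for unlikely tokens
print(blue_chip_color(1.0))  # saturated blue chip for near-certain tokens
```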
qwen_scope_obs.py
ADDED
|
@@ -0,0 +1,129 @@
"""Observation-only tooling for Qwen-Scope SAE features.

Four utilities — none of them steer or modify generation. They exist to let a
reviewer interrogate what the SAE features encode, before any intervention:

* encode_prompts(model, tokenizer, sae, prompts, layer)
    For each prompt, returns the last-token sparse feature code (shape:
    len(prompts) x n_features).

* top_features_for_prompt(...)
    The N strongest-firing features for a single prompt, with activation
    values. Equivalent to the read-only path in qwen_scope_steer.

* differential_features(pos_codes, neg_codes, top_n)
    Given two stacks of feature codes, returns features ranked by
    (mean activation on positive set) - (mean activation on negative set).
    Useful for "which features distinguish concept A from concept B?"

* scan_prompts_for_feature(codes, feature_id)
    Given a stack of codes and a feature id, returns the per-prompt
    activation values (zero where the feature didn't make TopK).

Nothing here generates steered text. Wire it up to the steering hooks in
qwen_scope_steer.py only when intervention is the explicit experimental goal,
and document the intervention separately.
"""
from __future__ import annotations

from dataclasses import dataclass
from typing import Sequence

import torch

from qwen_scope_steer import SAE, capture_residual


@dataclass
class FeatureRanking:
    feature_id: int
    score: float  # (pos_mean - neg_mean) for differential, or activation for top-features
    pos_mean: float | None = None
    neg_mean: float | None = None
    pos_fire_rate: float | None = None
    neg_fire_rate: float | None = None


def encode_prompts(model, tokenizer, sae: SAE, prompts: Sequence[str],
                   layer_idx: int) -> torch.Tensor:
    """Encode the last-token residual of each prompt through the SAE.

    Returns codes of shape (len(prompts), n_features) on CPU float32 for
    stable downstream stats. No generation is performed.
    """
    codes = []
    for p in prompts:
        inputs = tokenizer(p, return_tensors="pt").to(model.device)
        with torch.no_grad(), capture_residual(model, layer_idx) as bucket:
            model(**inputs)
        h_last = bucket["h"][0, -1].unsqueeze(0)
        z = sae.encode(h_last)[0]
        codes.append(z.detach().to("cpu", torch.float32))
    return torch.stack(codes, dim=0)


def top_features_for_prompt(model, tokenizer, sae: SAE, prompt: str,
                            layer_idx: int, top_n: int = 10) -> list[FeatureRanking]:
    codes = encode_prompts(model, tokenizer, sae, [prompt], layer_idx)[0]
    nz = codes.nonzero(as_tuple=False).flatten()
    vals = codes[nz]
    order = vals.argsort(descending=True)[:top_n]
    return [FeatureRanking(feature_id=int(nz[i]), score=float(vals[i]))
            for i in order]


def differential_features(pos_codes: torch.Tensor, neg_codes: torch.Tensor,
                          top_n: int = 20) -> list[FeatureRanking]:
    """Rank features by their differential firing across two prompt sets.

    pos_codes: (P, F) feature codes for the "positive" prompt set
    neg_codes: (N, F) feature codes for the "negative" prompt set

    Returns top_n features by (pos_mean - neg_mean), with both means and
    per-set fire rates (fraction of prompts where the feature fired).
    Read-only — no generation, no steering.
    """
    if pos_codes.shape[1] != neg_codes.shape[1]:
        raise ValueError(f"feature dim mismatch: {pos_codes.shape} vs {neg_codes.shape}")
    pos_mean = pos_codes.mean(dim=0)
    neg_mean = neg_codes.mean(dim=0)
    diff = pos_mean - neg_mean
    pos_fire = (pos_codes != 0).float().mean(dim=0)
    neg_fire = (neg_codes != 0).float().mean(dim=0)
    order = diff.argsort(descending=True)[:top_n]
    return [
        FeatureRanking(
            feature_id=int(i),
            score=float(diff[i]),
            pos_mean=float(pos_mean[i]),
            neg_mean=float(neg_mean[i]),
            pos_fire_rate=float(pos_fire[i]),
            neg_fire_rate=float(neg_fire[i]),
        )
        for i in order
    ]


def scan_prompts_for_feature(codes: torch.Tensor, feature_id: int) -> torch.Tensor:
    """Per-prompt activation vector for a single feature (zero where it didn't make TopK)."""
    return codes[:, feature_id]


def fire_rate(codes: torch.Tensor, feature_id: int) -> float:
    """Fraction of prompts on which the feature fired (was in TopK)."""
    return float((codes[:, feature_id] != 0).float().mean())


def pretty_ranking(rs: list[FeatureRanking]) -> str:
    out = []
    for r in rs:
        if r.pos_mean is not None:
            out.append(
                f"  feat {r.feature_id:>6d}  "
                f"diff={r.score:+8.4f}  "
                f"pos_mean={r.pos_mean:+8.4f} neg_mean={r.neg_mean:+8.4f}  "
                f"pos_fire={r.pos_fire_rate:.2f} neg_fire={r.neg_fire_rate:.2f}"
            )
        else:
            out.append(f"  feat {r.feature_id:>6d}  act={r.score:+8.4f}")
    return "\n".join(out)
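The core ranking idea in `differential_features` — sort features by mean activation on one prompt set minus mean activation on another — can be sketched without torch. This is an illustrative pure-Python miniature, not the module's code; `rank_differential` and its list-of-lists input format are inventions for the example.

```python
def rank_differential(pos_codes, neg_codes, top_n=3):
    """Rank feature indices by (mean over pos prompts) - (mean over neg prompts)."""
    n_feat = len(pos_codes[0])

    def col_mean(col, codes):
        return sum(row[col] for row in codes) / len(codes)

    diffs = [(f, col_mean(f, pos_codes) - col_mean(f, neg_codes))
             for f in range(n_feat)]
    diffs.sort(key=lambda t: t[1], reverse=True)
    return diffs[:top_n]

# Feature 1 fires only on the "positive" prompts, so it should rank first.
pos = [[0.0, 2.0, 0.5], [0.0, 1.0, 0.5]]
neg = [[0.0, 0.0, 0.5], [1.0, 0.0, 0.5]]
print(rank_differential(pos, neg, top_n=2))  # → [(1, 1.5), (2, 0.0)]
```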
qwen_scope_steer.py
ADDED
|
@@ -0,0 +1,279 @@
"""Qwen-Scope SAE feature reading + steering for transformers.

End-to-end demo:
1. Loads a base Qwen3 model and a matching Qwen-Scope TopK SAE checkpoint.
2. Captures the residual-stream output of a chosen decoder layer.
3. Encodes it through the SAE -> top-K firing features.
4. Generates a baseline completion.
5. Re-generates with feature steering: residual h <- h + alpha * W_dec[:, feat]
   applied via register_forward_hook on every forward pass.

Verified against:
* Qwen/SAE-Res-Qwen3-1.7B-Base-W32K-L0_50 (W_enc 32768x2048, W_dec 2048x32768,
  b_enc 32768, b_dec 2048, all float32, K=50)
* Qwen/Qwen3-1.7B-Base (28 Qwen3DecoderLayer, hidden_size=2048, layer forward
  returns bare torch.Tensor under transformers >= 5).
"""
from __future__ import annotations

import argparse
import contextlib
from dataclasses import dataclass
from pathlib import Path

import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from transformers import AutoModelForCausalLM, AutoTokenizer


# ---------------------------------------------------------------------------
# SAE
# ---------------------------------------------------------------------------
@dataclass
class SAE:
    W_enc: torch.Tensor  # (n_features, d_model)
    W_dec: torch.Tensor  # (d_model, n_features)
    b_enc: torch.Tensor  # (n_features,)
    b_dec: torch.Tensor  # (d_model,)
    k: int               # TopK
    layer: int           # layer index this SAE belongs to

    @classmethod
    def from_repo(cls, repo: str, layer: int, k: int, device: str = "cpu",
                  dtype: torch.dtype = torch.float32) -> "SAE":
        path = hf_hub_download(repo, f"layer{layer}.sae.pt")
        return cls.from_path(path, layer=layer, k=k, device=device, dtype=dtype)

    @classmethod
    def from_path(cls, path: str | Path, layer: int, k: int,
                  device: str = "cpu", dtype: torch.dtype = torch.float32) -> "SAE":
        sd = torch.load(str(path), map_location=device, weights_only=True)
        for key in ("W_enc", "W_dec", "b_enc", "b_dec"):
            if key not in sd:
                raise KeyError(f"SAE checkpoint at {path} missing key {key!r}; "
                               f"got {list(sd.keys())}")
        return cls(
            W_enc=sd["W_enc"].to(device=device, dtype=dtype),
            W_dec=sd["W_dec"].to(device=device, dtype=dtype),
            b_enc=sd["b_enc"].to(device=device, dtype=dtype),
            b_dec=sd["b_dec"].to(device=device, dtype=dtype),
            k=k, layer=layer,
        )

    @property
    def n_features(self) -> int:
        return self.W_enc.shape[0]

    @property
    def d_model(self) -> int:
        return self.W_enc.shape[1]

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode residual stream activations -> sparse feature codes (TopK)."""
        x = x.to(device=self.W_enc.device, dtype=self.W_enc.dtype)
        pre = F.linear(x, self.W_enc, self.b_enc)  # (..., n_features)
        topk_vals, topk_idx = pre.topk(self.k, dim=-1)
        z = torch.zeros_like(pre)
        z.scatter_(-1, topk_idx, topk_vals)
        return z

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        z = z.to(device=self.W_dec.device, dtype=self.W_dec.dtype)
        return F.linear(z, self.W_dec, self.b_dec)

    def steering_vector(self, feature_id: int) -> torch.Tensor:
        return self.W_dec[:, feature_id].clone()


# ---------------------------------------------------------------------------
# Hook helpers
# ---------------------------------------------------------------------------
def _layer_output_to_tensor(out):
    """Qwen3DecoderLayer returns torch.Tensor in transformers >= 5,
    a tuple (hidden_states, ...) in transformers < 5. Handle both."""
    if isinstance(out, tuple):
        return out[0], out
    return out, None


def _rebuild_layer_output(new_h: torch.Tensor, original_out):
    if original_out is None:
        return new_h
    return (new_h, *original_out[1:])


@contextlib.contextmanager
def capture_residual(model, layer_idx: int):
    """Capture the residual-stream output of model.model.layers[layer_idx]."""
    bucket: dict = {}
    layer = model.model.layers[layer_idx]

    def hook(_module, _inp, out):
        h, _ = _layer_output_to_tensor(out)
        bucket["h"] = h.detach()
        return out

    handle = layer.register_forward_hook(hook)
    try:
        yield bucket
    finally:
        handle.remove()


@contextlib.contextmanager
def steer(model, layer_idx: int, direction: torch.Tensor, alpha: float):
    """Add `alpha * direction` to the residual stream output of layer_idx
    on every forward pass while the context is active."""
    layer = model.model.layers[layer_idx]
    direction = direction.detach()

    def hook(_module, _inp, out):
        h, original = _layer_output_to_tensor(out)
        d = direction.to(device=h.device, dtype=h.dtype)
        new_h = h + alpha * d
        return _rebuild_layer_output(new_h, original)

    handle = layer.register_forward_hook(hook)
    try:
        yield
    finally:
        handle.remove()


# ---------------------------------------------------------------------------
# Pipeline
# ---------------------------------------------------------------------------
def read_top_features(model, tokenizer, sae: SAE, prompt: str,
                      layer_idx: int, top_n: int = 10):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad(), capture_residual(model, layer_idx) as bucket:
        model(**inputs)
    h = bucket["h"]                 # (1, T, d_model) on model.device
    h_last = h[0, -1].unsqueeze(0)  # (1, d_model) — encode() handles device/dtype
    z = sae.encode(h_last)[0]
    nonzero = z.nonzero(as_tuple=False).flatten()
    vals = z[nonzero]
    order = vals.argsort(descending=True)
    top = nonzero[order][:top_n]
    return [(int(f.item()), float(z[f].item())) for f in top]


def generate(model, tokenizer, prompt: str, max_new_tokens: int = 40):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # deterministic for A/B comparison
            pad_token_id=tokenizer.eos_token_id,
        )
    return tokenizer.decode(out[0], skip_special_tokens=True)


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def parse_topk_from_repo(repo: str) -> int:
    # e.g. "Qwen/SAE-Res-Qwen3-1.7B-Base-W32K-L0_50" -> 50
    parts = repo.rsplit("L0_", 1)
    if len(parts) == 2 and parts[1].isdigit():
        return int(parts[1])
    return 50


def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", default="Qwen/Qwen3-1.7B-Base")
    ap.add_argument("--sae-repo", default="Qwen/SAE-Res-Qwen3-1.7B-Base-W32K-L0_50")
    ap.add_argument("--layer", type=int, default=14)
    ap.add_argument("--prompt", default="The capital of France is")
    ap.add_argument("--max-new-tokens", type=int, default=40)
    ap.add_argument("--alpha", type=float, default=-10.0,
                    help="Steering magnitude. Negative suppresses, positive amplifies.")
    ap.add_argument("--suppress-rank", type=int, default=0,
                    help="Which top-firing feature (0 = strongest) to steer.")
    ap.add_argument("--feature-id", type=int, default=None,
                    help="Override: steer this exact feature instead of a top-rank pick.")
    ap.add_argument("--topk", type=int, default=None,
                    help="Override SAE TopK (auto-detected from repo name).")
    ap.add_argument("--device", default=None,
                    help="cuda | mps | cpu (auto if omitted)")
    ap.add_argument("--dtype", default="bfloat16",
                    choices=["bfloat16", "float16", "float32"])
    args = ap.parse_args()

    if args.device is None:
        if torch.cuda.is_available():
            args.device = "cuda"
        elif torch.backends.mps.is_available():
            args.device = "mps"
        else:
            args.device = "cpu"
    dtype = {"bfloat16": torch.bfloat16, "float16": torch.float16,
             "float32": torch.float32}[args.dtype]

    print(f"[load] model={args.model} device={args.device} dtype={args.dtype}")
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    model = AutoModelForCausalLM.from_pretrained(
        args.model, dtype=dtype, device_map=args.device,
    )
    model.eval()

    n_layers = len(model.model.layers)
    if not (0 <= args.layer < n_layers):
        raise ValueError(f"--layer {args.layer} out of range; model has {n_layers} layers")
    hidden = model.config.hidden_size
    print(f"[load] {type(model).__name__}: {n_layers} layers, hidden={hidden}")

    k = args.topk or parse_topk_from_repo(args.sae_repo)
    print(f"[load] SAE repo={args.sae_repo} layer={args.layer} K={k}")
    sae = SAE.from_repo(args.sae_repo, layer=args.layer, k=k,
                        device=args.device, dtype=dtype)
    if sae.d_model != hidden:
        raise ValueError(f"SAE d_model={sae.d_model} != model hidden_size={hidden}; "
                         f"this SAE doesn't match this model.")

    # 1. Top features for the prompt
    print(f"\n[features] top firing at layer {args.layer} for prompt: {args.prompt!r}")
    top = read_top_features(model, tokenizer, sae, args.prompt, args.layer, top_n=10)
    for rank, (fid, act) in enumerate(top):
        print(f"  rank {rank:2d}  feature {fid:>6d}  act={act:+.4f}")

    # Pick steering target
    if args.feature_id is not None:
        target_id = args.feature_id
    else:
        target_id = top[args.suppress_rank][0]

    # 2. Baseline generation
    print("\n[baseline] generating (no steering)...")
    baseline = generate(model, tokenizer, args.prompt, args.max_new_tokens)
    print(f"  >>> {baseline!r}")

    # 3. Steered generation
    print(f"\n[steer] feature {target_id} at layer {args.layer} with alpha={args.alpha}")
    direction = sae.steering_vector(target_id)
    with steer(model, args.layer, direction, args.alpha):
        steered = generate(model, tokenizer, args.prompt, args.max_new_tokens)
    print(f"  >>> {steered!r}")

    # 4. Verify the steering actually moved the feature
    inputs = tokenizer(args.prompt, return_tensors="pt").to(model.device)
    with torch.no_grad(), capture_residual(model, args.layer) as bucket:
        model(**inputs)
    base_act = sae.encode(bucket["h"][0, -1].unsqueeze(0))[0, target_id].item()
    with torch.no_grad(), steer(model, args.layer, direction, args.alpha), \
         capture_residual(model, args.layer) as bucket:
        model(**inputs)
    steered_act = sae.encode(bucket["h"][0, -1].unsqueeze(0))[0, target_id].item()
    print(f"\n[verify] feature {target_id} activation: baseline={base_act:+.4f} "
          f"steered={steered_act:+.4f} delta={steered_act - base_act:+.4f}")
    if args.alpha > 0 and steered_act <= base_act:
        print("  WARN: alpha>0 but activation didn't go up — unexpected.")
    if args.alpha < 0 and steered_act >= base_act:
        print("  WARN: alpha<0 but activation didn't go down — unexpected.")


if __name__ == "__main__":
    main()
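The TopK sparsification inside `SAE.encode` above (keep the K largest pre-activations, zero the rest, here done with `tensor.topk` + `scatter_`) can be sketched in pure Python. `topk_sparsify` is an illustrative stand-in, not part of the module:

```python
def topk_sparsify(pre, k):
    """Keep the k largest pre-activations, zero everything else."""
    # Indices of the k largest values (ties broken by position order).
    keep = set(sorted(range(len(pre)), key=lambda i: pre[i], reverse=True)[:k])
    return [v if i in keep else 0.0 for i, v in enumerate(pre)]

print(topk_sparsify([0.2, 3.1, -0.4, 1.7, 0.9], k=2))
# → [0.0, 3.1, 0.0, 1.7, 0.0] — only the two strongest features survive
```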
requirements.txt
ADDED
|
@@ -0,0 +1,10 @@
--extra-index-url https://download.pytorch.org/whl/cpu
torch>=2.5.0
transformers>=4.51.0
accelerate>=1.0
huggingface_hub>=0.25
safetensors>=0.4
numpy>=1.26
fastapi>=0.110
uvicorn[standard]>=0.30
pydantic>=2.5
server.py
ADDED
|
@@ -0,0 +1,817 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""FastAPI backend for the Qwen-Scope HF Space deployment.
|
| 2 |
+
|
| 3 |
+
Locked to Qwen3-1.7B-Base + the W32K-L0_50 SAE so it fits inside a
|
| 4 |
+
free-tier HF Space (CPU, ~16GB RAM). Layer is still selectable.
|
| 5 |
+
"""
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
import gc
|
| 9 |
+
import json
|
| 10 |
+
import os
|
| 11 |
+
import threading
|
| 12 |
+
from contextlib import asynccontextmanager
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
|
| 15 |
+
import numpy as np
|
| 16 |
+
import torch
|
| 17 |
+
from fastapi import FastAPI, HTTPException
|
| 18 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 19 |
+
from fastapi.responses import FileResponse
|
| 20 |
+
from pydantic import BaseModel
|
| 21 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 22 |
+
|
| 23 |
+
from qwen_scope_steer import SAE, capture_residual, steer
|
| 24 |
+
|
| 25 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 26 |
+
DTYPE = torch.float32 # bf16 on CPU is slow + flaky on free-tier hardware
|
| 27 |
+
|
| 28 |
+
POSITIONS_DIR = Path(os.environ.get(
|
| 29 |
+
"POSITIONS_DIR",
|
| 30 |
+
str(Path(__file__).parent / "feature_positions"),
|
| 31 |
+
))
|
| 32 |
+
POSITIONS_DIR.mkdir(exist_ok=True, parents=True)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# ---------------------------------------------------------------------------
|
| 36 |
+
# Catalog of supported model + SAE pairs.
|
| 37 |
+
# Verified against the Qwen org HF listing. For Qwen3.6 (no native SAE yet)
|
| 38 |
+
# we point at the Qwen3.5 SAE that matches dimensions; this is a best-effort
|
| 39 |
+
# fallback flagged as transferred=True in the response.
|
| 40 |
+
# ---------------------------------------------------------------------------
|
| 41 |
+
MODEL_CATALOG = [
|
| 42 |
+
{
|
| 43 |
+
"model": "Qwen/Qwen3-1.7B-Base",
|
| 44 |
+
"sae_repo": "Qwen/SAE-Res-Qwen3-1.7B-Base-W32K-L0_50",
|
| 45 |
+
"default_layer": 14, "n_layers": 28, "n_features": 32768,
|
| 46 |
+
"approx_size_gb": 3.4, "k": 50, "transferred": False,
|
| 47 |
+
},
|
| 48 |
+
]
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# ---------------------------------------------------------------------------
|
| 52 |
+
# State and locks
|
| 53 |
+
# ---------------------------------------------------------------------------
|
| 54 |
+
state: dict = {}
|
| 55 |
+
load_lock = threading.Lock()
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def _find_decoder_layers(model):
|
| 59 |
+
"""Return (layers_module_list, dotted_path) for any qwen3 / qwen3_5 model.
|
| 60 |
+
|
| 61 |
+
Handles:
|
| 62 |
+
* model.model.layers (standard Qwen3*ForCausalLM)
|
| 63 |
+
* model.language_model.model.layers (multimodal Qwen3_5ForConditionalGeneration)
|
| 64 |
+
"""
|
| 65 |
+
for path in (("model", "model", "layers"),
|
| 66 |
+
("model", "layers"),
|
| 67 |
+
("language_model", "model", "layers"),
|
| 68 |
+
("model", "language_model", "model", "layers")):
|
| 69 |
+
obj = model
|
| 70 |
+
ok = True
|
| 71 |
+
for p in path:
|
| 72 |
+
if not hasattr(obj, p):
|
| 73 |
+
ok = False; break
|
| 74 |
+
obj = getattr(obj, p)
|
| 75 |
+
if ok and hasattr(obj, "__len__") and len(obj) > 0:
|
| 76 |
+
return obj, ".".join(path)
|
| 77 |
+
raise RuntimeError(f"could not locate decoder layers on "
|
| 78 |
+
f"{type(model).__name__}")
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
# ---------------------------------------------------------------------------
|
| 82 |
+
# Position computation (cached per SAE).
|
| 83 |
+
# Uses TruncatedSVD via numpy power-iteration for the 80K feature SAE,
|
| 84 |
+
# economy SVD for smaller ones. Good enough for visualization layout.
|
| 85 |
+
# ---------------------------------------------------------------------------
|
| 86 |
+
def _safe_filename(s: str) -> str:
|
| 87 |
+
return s.replace("/", "__")
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def _positions_path(sae_repo: str, layer: int | None = None) -> Path:
|
| 91 |
+
if layer is None:
|
| 92 |
+
return POSITIONS_DIR / f"{_safe_filename(sae_repo)}.json"
|
| 93 |
+
return POSITIONS_DIR / f"{_safe_filename(sae_repo)}__L{layer}.json"
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def compute_positions(W_enc: torch.Tensor) -> list[list[float]]:
|
| 97 |
+
X = W_enc.detach().to("cpu", torch.float32).numpy() # (n_features, d_model)
|
| 98 |
+
X = X - X.mean(axis=0, keepdims=True)
|
| 99 |
+
n, d = X.shape
|
| 100 |
+
if n * d <= 32768 * 4096:
|
| 101 |
+
# Economy SVD is fine for the smaller SAEs.
|
| 102 |
+
_, _, Vt = np.linalg.svd(X, full_matrices=False)
|
| 103 |
+
pos = X @ Vt[:3].T
|
| 104 |
+
else:
|
| 105 |
+
# Randomized SVD for very large SAEs (e.g. 80K * 5120).
|
| 106 |
+
rng = np.random.default_rng(0)
|
| 107 |
+
Q = rng.standard_normal((d, 8)).astype(np.float32)
|
| 108 |
+
for _ in range(3): # power iterations
|
| 109 |
+
Q = X.T @ (X @ Q)
|
| 110 |
+
Q, _ = np.linalg.qr(Q)
|
| 111 |
+
Y = X @ Q # (n, 8)
|
| 112 |
+
_, _, Vt2 = np.linalg.svd(Y, full_matrices=False)
|
| 113 |
+
pos = Y @ Vt2[:3].T
|
| 114 |
+
pos = pos / max(abs(pos.min()), abs(pos.max()))
|
| 115 |
+
return pos.tolist()
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def load_or_compute_positions(W_enc: torch.Tensor, sae_repo: str,
|
| 119 |
+
layer: int | None = None) -> list[list[float]]:
|
| 120 |
+
# Try layer-specific cache first; fall back to legacy SAE-repo-only cache
|
| 121 |
+
# so existing files don't go stale.
|
| 122 |
+
p_layer = _positions_path(sae_repo, layer)
|
| 123 |
+
p_legacy = _positions_path(sae_repo)
|
| 124 |
+
for p in (p_layer, p_legacy):
|
| 125 |
+
if p.exists():
|
| 126 |
+
try:
|
| 127 |
+
return json.loads(p.read_text())["positions"]
|
| 128 |
+
except Exception:
|
| 129 |
+
pass
|
| 130 |
+
pos = compute_positions(W_enc)
|
| 131 |
+
p_layer.write_text(json.dumps({"positions": pos}))
|
| 132 |
+
return pos
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
# ---------------------------------------------------------------------------
|
| 136 |
+
# Model + SAE loading (called both at startup and on /load_model)
|
| 137 |
+
# ---------------------------------------------------------------------------
|
| 138 |
+
def _free_current_state():
|
| 139 |
+
"""Release the currently loaded model + SAE so a new one can fit."""
|
| 140 |
+
for k in ("model", "tokenizer", "sae", "layers"):
|
| 141 |
+
if k in state:
|
| 142 |
+
del state[k]
|
| 143 |
+
gc.collect()
|
| 144 |
+
if hasattr(torch, "mps") and torch.backends.mps.is_available():
|
| 145 |
+
try:
|
| 146 |
+
torch.mps.empty_cache()
|
| 147 |
+
except Exception:
|
| 148 |
+
pass
|
| 149 |
+
if torch.cuda.is_available():
|
| 150 |
+
try:
|
| 151 |
+
torch.cuda.empty_cache()
|
| 152 |
+
except Exception:
|
| 153 |
+
pass
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def _catalog_entry(model_name: str, sae_repo: str | None) -> dict:
|
| 157 |
+
"""Find the catalog row that matches model_name (and optionally sae_repo)."""
|
| 158 |
+
for row in MODEL_CATALOG:
|
| 159 |
+
if row["model"] == model_name and (sae_repo is None or row["sae_repo"] == sae_repo):
|
| 160 |
+
return row
|
| 161 |
+
raise HTTPException(status_code=400,
|
| 162 |
+
detail=f"unknown model/sae combination: {model_name} / {sae_repo}")
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def load_state(model_name: str, sae_repo: str | None = None,
|
| 166 |
+
layer: int | None = None, k: int = 50) -> dict:
|
| 167 |
+
"""Replace the loaded model+SAE+layer with the requested one."""
|
| 168 |
+
entry = _catalog_entry(model_name, sae_repo)
|
| 169 |
+
sae_repo = entry["sae_repo"]
|
| 170 |
+
layer = entry["default_layer"] if layer is None else int(layer)
|
| 171 |
+
k = entry.get("k", k)
|
| 172 |
+
|
| 173 |
+
print(f"[load] {model_name} ({entry['approx_size_gb']:.0f}GB) "
|
| 174 |
+
f"+ SAE {sae_repo} layer {layer} on {DEVICE}")
|
| 175 |
+
|
| 176 |
+
_free_current_state()
|
| 177 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 178 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 179 |
+
model_name, dtype=DTYPE, device_map=DEVICE,
|
| 180 |
+
)
|
| 181 |
+
model.eval()
|
| 182 |
+
|
| 183 |
+
layers, layers_path = _find_decoder_layers(model)
|
| 184 |
+
n_layers = len(layers)
|
| 185 |
+
if not (0 <= layer < n_layers):
|
| 186 |
+
layer = min(max(0, layer), n_layers - 1)
|
| 187 |
+
|
| 188 |
+
print(f"[load] model loaded: {type(model).__name__}, layers at "
|
| 189 |
+
f"'{layers_path}', n={n_layers}")
|
| 190 |
+
|
| 191 |
+
sae = SAE.from_repo(sae_repo, layer=layer, k=k, device=DEVICE, dtype=DTYPE)
|
| 192 |
+
print(f"[load] SAE loaded: n_features={sae.n_features}, d_model={sae.d_model}")
|
| 193 |
+
|
| 194 |
+
print("[load] computing/loading 3D feature positions")
|
| 195 |
+
positions = load_or_compute_positions(sae.W_enc, sae_repo, layer)
|
| 196 |
+
_sae_cache_put(sae_repo, layer, sae)
|
| 197 |
+
|
| 198 |
+
state.update(
|
| 199 |
+
model=model, tokenizer=tokenizer, sae=sae,
|
| 200 |
+
layers=layers, layers_path=layers_path,
|
| 201 |
+
positions=positions, n_layers=n_layers,
|
| 202 |
+
current_model=model_name, current_sae=sae_repo,
|
| 203 |
+
current_layer=layer, current_k=k,
|
| 204 |
+
catalog_entry=entry,
|
| 205 |
+
)
|
| 206 |
+
print("[load] ready")
|
| 207 |
+
return state
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
# ---------------------------------------------------------------------------
|
| 211 |
+
# Hook helpers — work against state["layers"] not model.model.layers
|
| 212 |
+
# ---------------------------------------------------------------------------
|
| 213 |
+
import contextlib
|
| 214 |
+
|
| 215 |
+
@contextlib.contextmanager
|
| 216 |
+
def _capture_at(layer_module):
|
| 217 |
+
bucket = {}
|
| 218 |
+
def hook(_m, _i, out):
|
| 219 |
+
h = out[0] if isinstance(out, tuple) else out
|
| 220 |
+
bucket["h"] = h.detach()
|
| 221 |
+
return out
|
| 222 |
+
handle = layer_module.register_forward_hook(hook)
|
| 223 |
+
try:
|
| 224 |
+
yield bucket
|
| 225 |
+
finally:
|
| 226 |
+
handle.remove()
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
@contextlib.contextmanager
|
| 230 |
+
def _steer_at(layer_module, direction, alpha, *,
|
| 231 |
+
positions=None, output_only=False, prompt_len=None):
|
| 232 |
+
"""Hook adds α·direction to layer residual, with position/decode controls.
|
| 233 |
+
|
| 234 |
+
positions : None or "all" → every token; list[int] → only those absolute
|
| 235 |
+
token indices (works across prefill + decode).
|
| 236 |
+
output_only : if True, only steer during decode (skip prefill entirely).
|
| 237 |
+
prompt_len : length of the prompt; needed to map decode-step counter
|
| 238 |
+
to absolute position when positions is a list.
|
| 239 |
+
"""
|
| 240 |
+
direction = direction.detach()
|
| 241 |
+
counter = [0]
|
| 242 |
+
pos_set = set(positions) if isinstance(positions, (list, set)) else None
|
| 243 |
+
|
| 244 |
+
def hook(_m, _i, out):
|
| 245 |
+
h = out[0] if isinstance(out, tuple) else out
|
| 246 |
+
d = direction.to(device=h.device, dtype=h.dtype)
|
| 247 |
+
cur = counter[0]
|
| 248 |
+
counter[0] += 1
|
| 249 |
+
new_h = h
|
| 250 |
+
is_prefill = (cur == 0)
|
| 251 |
+
|
| 252 |
+
if is_prefill:
|
| 253 |
+
seq = h.shape[1]
|
| 254 |
+
if output_only:
|
| 255 |
+
pass # leave prompt untouched
|
| 256 |
+
elif pos_set is None:
|
| 257 |
+
new_h = h + alpha * d
|
| 258 |
+
else:
|
| 259 |
+
new_h = h.clone()
|
| 260 |
+
for p in pos_set:
|
| 261 |
+
if 0 <= p < seq:
|
| 262 |
+
new_h[:, p, :] = new_h[:, p, :] + alpha * d
|
| 263 |
+
else:
|
| 264 |
+
# Decode step — h is [batch, 1, hidden] (one new token)
|
| 265 |
+
cur_pos = (prompt_len or 0) + cur - 1
|
| 266 |
+
if pos_set is None or output_only or (cur_pos in pos_set):
|
| 267 |
+
new_h = h + alpha * d
|
| 268 |
+
|
| 269 |
+
return (new_h, *out[1:]) if isinstance(out, tuple) else new_h
|
| 270 |
+
|
| 271 |
+
handle = layer_module.register_forward_hook(hook)
|
| 272 |
+
try:
|
| 273 |
+
yield
|
| 274 |
+
finally:
|
| 275 |
+
handle.remove()
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
def _parse_positions(s: str | None):
|
| 279 |
+
"""Parse '3', '3-7', '0,2,5-8', 'all', or None into a position spec.
|
| 280 |
+
Returns 'all' or a list[int] (or None if input is empty/None)."""
|
| 281 |
+
if s is None or not str(s).strip():
|
| 282 |
+
return None
|
| 283 |
+
s = str(s).strip().lower()
|
| 284 |
+
if s == "all":
|
| 285 |
+
return "all"
|
| 286 |
+
out: list[int] = []
|
| 287 |
+
for part in s.split(","):
|
| 288 |
+
part = part.strip()
|
| 289 |
+
if not part:
|
| 290 |
+
continue
|
| 291 |
+
if "-" in part:
|
| 292 |
+
try:
|
| 293 |
+
lo, hi = part.split("-", 1)
|
| 294 |
+
out.extend(range(int(lo), int(hi) + 1))
|
| 295 |
+
except ValueError:
|
| 296 |
+
continue
|
| 297 |
+
else:
|
| 298 |
+
try:
|
| 299 |
+
out.append(int(part))
|
| 300 |
+
except ValueError:
|
| 301 |
+
continue
|
| 302 |
+
return sorted(set(out)) if out else None
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
def _hook_stack(layer_module, sae, specs, prompt_len=None):
|
| 306 |
+
from contextlib import ExitStack
|
| 307 |
+
stack = ExitStack()
|
| 308 |
+
for s in specs:
|
| 309 |
+
d = sae.steering_vector(s.id)
|
| 310 |
+
positions = _parse_positions(getattr(s, "positions", None))
|
| 311 |
+
output_only = bool(getattr(s, "output_only", False))
|
| 312 |
+
# "all" or None both mean "every position" inside _steer_at — pass None.
|
| 313 |
+
eff_positions = None if (positions is None or positions == "all") else positions
|
| 314 |
+
stack.enter_context(_steer_at(
|
| 315 |
+
layer_module, d, s.alpha,
|
| 316 |
+
positions=eff_positions,
|
| 317 |
+
output_only=output_only,
|
| 318 |
+
prompt_len=prompt_len,
|
| 319 |
+
))
|
| 320 |
+
return stack
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
# ---------------------------------------------------------------------------
|
| 324 |
+
# Lifespan + app
|
| 325 |
+
# ---------------------------------------------------------------------------
|
| 326 |
+
@asynccontextmanager
|
| 327 |
+
async def lifespan(app: FastAPI):
|
| 328 |
+
# Default startup: small model so the demo is interactive immediately.
|
| 329 |
+
load_state("Qwen/Qwen3-1.7B-Base")
|
| 330 |
+
yield
|
| 331 |
+
state.clear()
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
app = FastAPI(lifespan=lifespan)
|
| 335 |
+
app.add_middleware(CORSMiddleware, allow_origins=["*"],
|
| 336 |
+
allow_methods=["*"], allow_headers=["*"])
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
# ---------------------------------------------------------------------------
|
| 340 |
+
# Request models
|
| 341 |
+
# ---------------------------------------------------------------------------
|
| 342 |
+
class EncodeRequest(BaseModel):
|
| 343 |
+
prompt: str
|
| 344 |
+
top_n: int = 20
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
class SteerSpec(BaseModel):
|
| 348 |
+
id: int
|
| 349 |
+
alpha: float
|
| 350 |
+
positions: str | None = None # "all" | "3-7" | "0,2,5" | None (= all)
|
| 351 |
+
output_only: bool = False # if True, steer only during decode, not prompt
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
class GenerateRequest(BaseModel):
|
| 355 |
+
prompt: str
|
| 356 |
+
steering: list[SteerSpec] = []
|
| 357 |
+
max_new_tokens: int = 40
|
| 358 |
+
return_probs: bool = False # if True, return per-token softmax + top-K candidates
|
| 359 |
+
topk_display: int = 8 # number of candidate tokens to expose per step
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
class LoadModelRequest(BaseModel):
|
| 363 |
+
model: str
|
| 364 |
+
sae_repo: str | None = None
|
| 365 |
+
layer: int | None = None
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
class SetLayerRequest(BaseModel):
|
| 369 |
+
layer: int
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
# ---------------------------------------------------------------------------
|
| 373 |
+
# Routes
|
| 374 |
+
# ---------------------------------------------------------------------------
|
| 375 |
+
@app.get("/")
|
| 376 |
+
def index():
|
| 377 |
+
return FileResponse(Path(__file__).parent / "index.html")
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
@app.get("/health")
|
| 381 |
+
def health():
|
| 382 |
+
sae = state.get("sae")
|
| 383 |
+
return {
|
| 384 |
+
"ok": True,
|
| 385 |
+
"model": state.get("current_model"),
|
| 386 |
+
"sae": state.get("current_sae"),
|
| 387 |
+
"layer": state.get("current_layer"),
|
| 388 |
+
"device": DEVICE,
|
| 389 |
+
"dtype": str(DTYPE).replace("torch.", ""),
|
| 390 |
+
"n_features": sae.n_features if sae else None,
|
| 391 |
+
"n_layers": state.get("n_layers"),
|
| 392 |
+
"transferred": state.get("catalog_entry", {}).get("transferred", False),
|
| 393 |
+
"note": state.get("catalog_entry", {}).get("note", ""),
|
| 394 |
+
}
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
@app.get("/list_models")
|
| 398 |
+
def list_models():
|
| 399 |
+
return {"models": MODEL_CATALOG}
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
@app.post("/load_model")
|
| 403 |
+
def load_model(req: LoadModelRequest):
|
| 404 |
+
with load_lock:
|
| 405 |
+
try:
|
| 406 |
+
load_state(req.model, req.sae_repo, req.layer)
|
| 407 |
+
except HTTPException:
|
| 408 |
+
raise
|
| 409 |
+
except Exception as e:
|
| 410 |
+
raise HTTPException(status_code=500, detail=f"load failed: {e}")
|
| 411 |
+
sae = state["sae"]
|
| 412 |
+
return {
|
| 413 |
+
"ok": True,
|
| 414 |
+
"model": state["current_model"],
|
| 415 |
+
"sae": state["current_sae"],
|
| 416 |
+
"layer": state["current_layer"],
|
| 417 |
+
"n_features": sae.n_features,
|
| 418 |
+
"n_layers": state["n_layers"],
|
| 419 |
+
"transferred": state["catalog_entry"].get("transferred", False),
|
| 420 |
+
"note": state["catalog_entry"].get("note", ""),
|
| 421 |
+
"positions": state["positions"],
|
| 422 |
+
}
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
# In-memory LRU cache of recently-used SAE checkpoints, keyed by
|
| 426 |
+
# (sae_repo, layer). Each SAE for the 1.7B model is ~537 MB on disk and
|
| 427 |
+
# similar in RAM at fp32; for the 27B SAE it's ~3.3 GB. Cap conservatively.
|
| 428 |
+
_sae_lru: "OrderedDict[tuple[str,int], SAE]" = None # initialized lazily
|
| 429 |
+
SAE_LRU_MAX = 6
|
| 430 |
+
|
| 431 |
+
|
| 432 |
+
def _sae_cache_get(sae_repo: str, layer: int):
|
| 433 |
+
global _sae_lru
|
| 434 |
+
if _sae_lru is None:
|
| 435 |
+
from collections import OrderedDict
|
| 436 |
+
_sae_lru = OrderedDict()
|
| 437 |
+
key = (sae_repo, layer)
|
| 438 |
+
if key in _sae_lru:
|
| 439 |
+
_sae_lru.move_to_end(key)
|
| 440 |
+
return _sae_lru[key]
|
| 441 |
+
return None
|
| 442 |
+
|
| 443 |
+
|
| 444 |
+
def _sae_cache_put(sae_repo: str, layer: int, sae: SAE):
|
| 445 |
+
global _sae_lru
|
| 446 |
+
if _sae_lru is None:
|
| 447 |
+
from collections import OrderedDict
|
| 448 |
+
_sae_lru = OrderedDict()
|
| 449 |
+
key = (sae_repo, layer)
|
| 450 |
+
_sae_lru[key] = sae
|
| 451 |
+
_sae_lru.move_to_end(key)
|
| 452 |
+
while len(_sae_lru) > SAE_LRU_MAX:
|
| 453 |
+
_sae_lru.popitem(last=False)
|
| 454 |
+
|
| 455 |
+
|
| 456 |
+
@app.post("/set_layer")
|
| 457 |
+
def set_layer(req: SetLayerRequest):
|
| 458 |
+
"""Hot-swap the active SAE to a different layer of the same SAE repo.
|
| 459 |
+
|
| 460 |
+
Keeps the model loaded; just downloads (or fetches from cache) the new
|
| 461 |
+
layer's SAE checkpoint. Recomputes 3D positions for the new SAE
|
| 462 |
+
(cached on disk per SAE-repo+layer).
|
| 463 |
+
"""
|
| 464 |
+
if "model" not in state:
|
| 465 |
+
raise HTTPException(status_code=400, detail="no model loaded")
|
| 466 |
+
n_layers = state["n_layers"]
|
| 467 |
+
layer = int(req.layer)
|
| 468 |
+
if not (0 <= layer < n_layers):
|
| 469 |
+
raise HTTPException(status_code=400,
|
| 470 |
+
detail=f"layer must be in [0, {n_layers-1}]")
|
| 471 |
+
sae_repo = state["current_sae"]
|
| 472 |
+
if layer == state["current_layer"]:
|
| 473 |
+
return {"ok": True, "unchanged": True,
|
| 474 |
+
"layer": layer, "n_features": state["sae"].n_features,
|
| 475 |
+
"positions": state["positions"]}
|
| 476 |
+
|
| 477 |
+
with load_lock:
|
| 478 |
+
# 1. SAE itself — try LRU first
|
| 479 |
+
cached = _sae_cache_get(sae_repo, layer)
|
| 480 |
+
if cached is not None:
|
| 481 |
+
sae = cached
|
| 482 |
+
print(f"[layer-swap] SAE {sae_repo} layer {layer} from LRU cache")
|
| 483 |
+
else:
|
| 484 |
+
print(f"[layer-swap] loading SAE {sae_repo} layer {layer}")
|
| 485 |
+
k = state["catalog_entry"].get("k", 50)
|
| 486 |
+
sae = SAE.from_repo(sae_repo, layer=layer, k=k,
|
| 487 |
+
device=DEVICE, dtype=DTYPE)
|
| 488 |
+
_sae_cache_put(sae_repo, layer, sae)
|
| 489 |
+
|
| 490 |
+
# 2. Positions — per-layer cache file on disk
|
| 491 |
+
positions_key = f"{sae_repo}__L{layer}"
|
| 492 |
+
p = POSITIONS_DIR / f"{_safe_filename(positions_key)}.json"
|
| 493 |
+
if p.exists():
|
| 494 |
+
try:
|
| 495 |
+
positions = json.loads(p.read_text())["positions"]
|
| 496 |
+
except Exception:
|
| 497 |
+
positions = compute_positions(sae.W_enc)
|
| 498 |
+
p.write_text(json.dumps({"positions": positions}))
|
| 499 |
+
else:
|
| 500 |
+
print(f"[layer-swap] computing positions for layer {layer}")
|
| 501 |
+
positions = compute_positions(sae.W_enc)
|
| 502 |
+
p.write_text(json.dumps({"positions": positions}))
|
| 503 |
+
|
| 504 |
+
state["sae"] = sae
|
| 505 |
+
state["current_layer"] = layer
|
| 506 |
+
state["positions"] = positions
|
| 507 |
+
|
| 508 |
+
return {
|
| 509 |
+
"ok": True,
|
| 510 |
+
"layer": layer,
|
| 511 |
+
"n_features": sae.n_features,
|
| 512 |
+
"positions": positions,
|
| 513 |
+
"from_cache": cached is not None,
|
| 514 |
+
}
|
| 515 |
+
|
| 516 |
+
|
| 517 |
+
@app.get("/positions")
|
| 518 |
+
def positions():
|
| 519 |
+
return {"positions": state["positions"]}
|
| 520 |
+
|
| 521 |
+
|
| 522 |
+
@app.post("/encode")
|
| 523 |
+
def encode(req: EncodeRequest):
|
| 524 |
+
model, tokenizer, sae = state["model"], state["tokenizer"], state["sae"]
|
| 525 |
+
layer_module = state["layers"][state["current_layer"]]
|
| 526 |
+
inputs = tokenizer(req.prompt, return_tensors="pt").to(model.device)
|
| 527 |
+
with torch.no_grad(), _capture_at(layer_module) as bucket:
|
| 528 |
+
model(**inputs)
|
| 529 |
+
h_last = bucket["h"][0, -1].unsqueeze(0)
|
| 530 |
+
z = sae.encode(h_last)[0]
|
| 531 |
+
nz = z.nonzero(as_tuple=False).flatten()
|
| 532 |
+
vals = z[nz]
|
| 533 |
+
order = vals.argsort(descending=True)[:req.top_n]
|
| 534 |
+
top = [{"id": int(nz[i].item()), "act": float(vals[i].item())} for i in order]
|
| 535 |
+
return {"top": top, "n_features": sae.n_features}
|
| 536 |
+
|
| 537 |
+
|
| 538 |
+
class EncodeFullRequest(BaseModel):
|
| 539 |
+
prompt: str
|
| 540 |
+
top_n: int = 16 # number of feature ROWS to return in the heatmap
|
| 541 |
+
|
| 542 |
+
|
| 543 |
+
@app.post("/encode_full")
|
| 544 |
+
def encode_full(req: EncodeFullRequest):
|
| 545 |
+
"""Return a per-token feature activation grid for a single prompt.
|
| 546 |
+
|
| 547 |
+
Picks the top_n features ranked by *mean activation across all token
|
| 548 |
+
positions* (matches the official app.py heatmap definition), then returns
|
| 549 |
+
each feature's activation at every token position. Activations that
|
| 550 |
+
didn't make TopK at a given position are zero.
|
| 551 |
+
"""
|
| 552 |
+
model, tokenizer, sae = state["model"], state["tokenizer"], state["sae"]
|
| 553 |
+
layer_module = state["layers"][state["current_layer"]]
|
| 554 |
+
inputs = tokenizer(req.prompt, return_tensors="pt").to(model.device)
|
| 555 |
+
with torch.no_grad(), _capture_at(layer_module) as bucket:
|
| 556 |
+
model(**inputs)
|
| 557 |
+
h = bucket["h"][0] # (seq_len, d_model)
|
| 558 |
+
z = sae.encode(h) # (seq_len, n_features) sparse TopK
|
| 559 |
+
seq_len = z.shape[0]
|
| 560 |
+
# Token strings for column headers
|
| 561 |
+
ids = inputs["input_ids"][0].tolist()
|
| 562 |
+
tokens = [tokenizer.decode([t], skip_special_tokens=False) for t in ids]
|
| 563 |
+
|
| 564 |
+
# Rank features by mean activation across all positions
|
| 565 |
+
mean_per_feat = z.mean(dim=0)
|
| 566 |
+
top_vals, top_idx = mean_per_feat.topk(min(int(req.top_n), sae.n_features))
|
| 567 |
+
grid = z[:, top_idx] # (seq_len, top_n)
|
| 568 |
+
return {
|
| 569 |
+
"tokens": tokens,
|
| 570 |
+
"feature_ids": [int(i.item()) for i in top_idx],
|
| 571 |
+
"mean_acts": [float(v.item()) for v in top_vals],
|
| 572 |
+
# grid: outer list = features, inner list = positions (transposed for
|
| 573 |
+
# natural row-per-feature rendering in the UI)
|
| 574 |
+
"grid": [[float(grid[p, f].item()) for p in range(seq_len)]
|
| 575 |
+
for f in range(grid.shape[1])],
|
| 576 |
+
"seq_len": seq_len,
|
| 577 |
+
"n_features": sae.n_features,
|
| 578 |
+
}
|
| 579 |
+
|
| 580 |
+
|
| 581 |
+
class EncodeBatchRequest(BaseModel):
|
| 582 |
+
prompts: list[str]
|
| 583 |
+
top_n: int = 20 # top features per prompt to return individually
|
| 584 |
+
|
| 585 |
+
|
| 586 |
+
@app.post("/encode_batch")
|
| 587 |
+
def encode_batch(req: EncodeBatchRequest):
|
| 588 |
+
"""Encode N prompts and return per-sample top features + corpus-level stats.
|
| 589 |
+
|
| 590 |
+
For each prompt: encode the last-token residual through the SAE, return
|
| 591 |
+
its top_n firing features. Corpus-level: union of features that fired
|
| 592 |
+
at all, with per-feature firing rate (fraction of prompts where it
|
| 593 |
+
appeared) and mean activation.
|
| 594 |
+
"""
|
| 595 |
+
if not req.prompts:
|
| 596 |
+
return {"per_sample": [], "corpus_features": [], "n_features": state["sae"].n_features}
|
| 597 |
+
model, tokenizer, sae = state["model"], state["tokenizer"], state["sae"]
|
| 598 |
+
layer_module = state["layers"][state["current_layer"]]
|
| 599 |
+
|
| 600 |
+
per_sample = []
|
| 601 |
+
union_act_sum: dict[int, float] = {}
|
| 602 |
+
union_count: dict[int, int] = {}
|
| 603 |
+
|
| 604 |
+
for idx, p in enumerate(req.prompts):
|
| 605 |
+
inputs = tokenizer(p, return_tensors="pt").to(model.device)
|
| 606 |
+
with torch.no_grad(), _capture_at(layer_module) as bucket:
|
| 607 |
+
model(**inputs)
|
| 608 |
+
h_last = bucket["h"][0, -1].unsqueeze(0)
|
| 609 |
+
z = sae.encode(h_last)[0]
|
| 610 |
+
nz = z.nonzero(as_tuple=False).flatten()
|
| 611 |
+
vals = z[nz]
|
| 612 |
+
order = vals.argsort(descending=True)
|
| 613 |
+
top_idx = nz[order][:req.top_n]
|
| 614 |
+
top = [{"id": int(top_idx[i].item()), "act": float(z[top_idx[i]].item())}
|
| 615 |
+
for i in range(len(top_idx))]
|
| 616 |
+
per_sample.append({
|
| 617 |
+
"i": idx,
|
| 618 |
+
"preview": p[:80] + ("…" if len(p) > 80 else ""),
|
| 619 |
+
"len": len(p),
|
| 620 |
+
"top": top,
|
| 621 |
+
"n_active": int(len(nz)),
|
| 622 |
+
})
|
| 623 |
+
# Union stats over ALL nonzero features, not just top
|
| 624 |
+
for fid, v in zip(nz.tolist(), vals.tolist()):
|
| 625 |
+
union_count[fid] = union_count.get(fid, 0) + 1
|
| 626 |
+
union_act_sum[fid] = union_act_sum.get(fid, 0.0) + float(v)
|
| 627 |
+
|
| 628 |
+
n = len(req.prompts)
|
| 629 |
+
corpus = []
|
| 630 |
+
for fid, cnt in union_count.items():
|
| 631 |
+
corpus.append({
|
| 632 |
+
"id": fid,
|
| 633 |
+
"fire_rate": cnt / n,
|
| 634 |
+
"mean_act": union_act_sum[fid] / cnt,
|
| 635 |
+
"n_samples": cnt,
|
| 636 |
+
})
|
| 637 |
+
# Sort by fire_rate desc then mean_act desc
|
| 638 |
+
corpus.sort(key=lambda r: (-r["fire_rate"], -r["mean_act"]))
|
| 639 |
+
return {
|
| 640 |
+
"per_sample": per_sample,
|
| 641 |
+
"corpus_features": corpus[:200], # cap to 200 most frequent
|
| 642 |
+
"n_features": sae.n_features,
|
| 643 |
+
"n_samples": n,
|
| 644 |
+
}
|
| 645 |
+
|
| 646 |
+
|
| 647 |
+
class CompareBatchRequest(BaseModel):
|
| 648 |
+
prompts_a: list[str]
|
| 649 |
+
prompts_b: list[str]
|
| 650 |
+
top_n: int = 30
|
| 651 |
+
|
| 652 |
+
|
| 653 |
+
@app.post("/compare_batch")
|
| 654 |
+
def compare_batch(req: CompareBatchRequest):
|
| 655 |
+
"""Differential feature mining between two prompt sets.
|
| 656 |
+
|
| 657 |
+
For each set: encode all prompts, compute per-feature firing rate
|
| 658 |
+
(fraction of prompts where the feature fires) and mean activation.
|
| 659 |
+
Rank features by |fire_rate_A − fire_rate_B|.
|
| 660 |
+
Returns top features that distinguish A from B.
|
| 661 |
+
"""
|
| 662 |
+
model, tokenizer, sae = state["model"], state["tokenizer"], state["sae"]
|
| 663 |
+
layer_module = state["layers"][state["current_layer"]]
|
| 664 |
+
|
| 665 |
+
def _encode_set(prompts):
|
| 666 |
+
n_feats = sae.n_features
|
| 667 |
+
rate = torch.zeros(n_feats, dtype=torch.float32)
|
| 668 |
+
acts = torch.zeros(n_feats, dtype=torch.float32)
|
| 669 |
+
for p in prompts:
|
| 670 |
+
inputs = tokenizer(p, return_tensors="pt").to(model.device)
|
| 671 |
+
with torch.no_grad(), _capture_at(layer_module) as bucket:
|
| 672 |
+
model(**inputs)
|
| 673 |
+
h_last = bucket["h"][0, -1].unsqueeze(0)
|
| 674 |
+
z = sae.encode(h_last)[0].detach().to("cpu", torch.float32)
|
| 675 |
+
rate += (z != 0).float()
|
| 676 |
+
acts += z
|
| 677 |
+
if prompts:
|
| 678 |
+
rate /= len(prompts)
|
| 679 |
+
acts /= len(prompts)
|
| 680 |
+
return rate, acts
|
| 681 |
+
|
| 682 |
+
rate_a, acts_a = _encode_set(req.prompts_a)
|
| 683 |
+
rate_b, acts_b = _encode_set(req.prompts_b)
|
| 684 |
+
diff = (rate_a - rate_b).abs()
|
| 685 |
+
top_vals, top_idx = diff.topk(min(int(req.top_n), sae.n_features))
|
| 686 |
+
rows = []
|
| 687 |
+
for v, fid in zip(top_vals.tolist(), top_idx.tolist()):
|
| 688 |
+
rows.append({
|
| 689 |
+
"id": int(fid),
|
| 690 |
+
"diff": float(v),
|
| 691 |
+
"rate_a": float(rate_a[fid]),
|
| 692 |
+
"rate_b": float(rate_b[fid]),
|
| 693 |
+
"act_a": float(acts_a[fid]),
|
| 694 |
+
"act_b": float(acts_b[fid]),
|
| 695 |
+
"winner": "a" if rate_a[fid] >= rate_b[fid] else "b",
|
| 696 |
+
})
|
| 697 |
+
return {"top_diff": rows, "n_a": len(req.prompts_a), "n_b": len(req.prompts_b),
|
| 698 |
+
"n_features": sae.n_features}
|
| 699 |
+
|
| 700 |
+
|
| 701 |
+
class SynthRequest(BaseModel):
    seed_prompts: list[str]
    steering: list[SteerSpec] = []
    max_new_tokens: int = 40


@app.post("/synth_batch")
def synth_batch(req: SynthRequest):
    """Bulk steered synthesis: run steered generation over N seed prompts.

    Useful for the data-centric synthesis workflow: produce K examples
    that fire feature F at strength α, for downstream training data.
    """
    if not req.seed_prompts:
        return {"results": []}
    model, tokenizer, sae = state["model"], state["tokenizer"], state["sae"]
    layer_module = state["layers"][state["current_layer"]]
    results = []
    for seed in req.seed_prompts:
        inputs = tokenizer(seed, return_tensors="pt").to(model.device)
        with _hook_stack(layer_module, sae, req.steering):
            with torch.no_grad():
                out = model.generate(
                    **inputs,
                    max_new_tokens=req.max_new_tokens,
                    do_sample=False,
                    pad_token_id=tokenizer.eos_token_id,
                )
        text = tokenizer.decode(out[0], skip_special_tokens=True)
        results.append({"seed": seed, "text": text})
    return {"results": results}

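As a usage sketch, a client body for `/synth_batch` might look like the following. Field names come from `SynthRequest`; the feature id `1234` and strength `6.0` are placeholder example values, and any `SteerSpec` fields not shown take their defaults:

```python
import json

# Example /synth_batch request body. Field names follow SynthRequest;
# feature id 1234 and alpha 6.0 are made-up placeholders, not real features.
payload = {
    "seed_prompts": ["The weather today is", "My favorite food is"],
    "steering": [{"id": 1234, "alpha": 6.0}],
    "max_new_tokens": 40,
}
body = json.dumps(payload)
```

POSTing this body returns `{"results": [...]}` with one `{"seed", "text"}` pair per seed prompt.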
def _extract_per_token_probs(gen_out, prompt_len, tokenizer, top_k):
    """Build per-step probabilities + top-K candidate strings."""
    new_ids = gen_out.sequences[0][prompt_len:].tolist()
    if not new_ids:
        return []
    rows = []
    for step, score_t in enumerate(gen_out.scores):
        probs = torch.softmax(score_t[0].float(), dim=-1)
        chosen_id = new_ids[step]
        chosen_prob = float(probs[chosen_id].item())
        top_vals, top_ids = probs.topk(min(top_k, probs.shape[0]))
        top_ids_list = top_ids.tolist()
        # Decode one batch (chosen + top-K) to limit tokenizer overhead
        decoded_chosen = tokenizer.decode([chosen_id], skip_special_tokens=False)
        decoded_top = tokenizer.batch_decode([[t] for t in top_ids_list], skip_special_tokens=False)
        topk = []
        for tid, tv, ts in zip(top_ids_list, top_vals.tolist(), decoded_top):
            topk.append({"tok": ts, "prob": float(tv), "is_chosen": tid == chosen_id})
        # If the chosen token wasn't in the top-K, append it explicitly
        if chosen_id not in top_ids_list:
            topk.append({"tok": decoded_chosen, "prob": chosen_prob, "is_chosen": True})
        rows.append({"tok": decoded_chosen, "prob": chosen_prob, "topk": topk})
    return rows

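The softmax-then-top-K step this helper applies at each generation step can be illustrated without torch. The toy 4-token vocabulary and logit values below are invented; `torch.softmax` and `Tensor.topk` do the equivalent on the real score tensors:

```python
import math

# Pure-Python stand-in for the per-step softmax + top-k calls above,
# run on a toy 4-token vocabulary with invented logits.
def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def top_k(probs, k):
    # (prob, token_id) pairs, highest probability first
    return sorted(((p, i) for i, p in enumerate(probs)), reverse=True)[:k]

probs = softmax([2.0, 1.0, 0.0, -1.0])
best = top_k(probs, 2)  # token 0 first, then token 1
```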
@app.post("/generate")
def generate(req: GenerateRequest):
    model, tokenizer, sae = state["model"], state["tokenizer"], state["sae"]
    layer_module = state["layers"][state["current_layer"]]
    inputs = tokenizer(req.prompt, return_tensors="pt").to(model.device)
    prompt_len = int(inputs["input_ids"].shape[1])

    base_acts = {}
    if req.steering:
        with torch.no_grad(), _capture_at(layer_module) as bucket:
            model(**inputs)
        z_base = sae.encode(bucket["h"][0, -1].unsqueeze(0))[0]
        for s in req.steering:
            base_acts[s.id] = float(z_base[s.id].item())

    gen_kwargs = dict(
        **inputs,
        max_new_tokens=req.max_new_tokens,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
    )
    if req.return_probs:
        gen_kwargs["return_dict_in_generate"] = True
        gen_kwargs["output_scores"] = True

    with _hook_stack(layer_module, sae, req.steering, prompt_len=prompt_len):
        with torch.no_grad():
            out = model.generate(**gen_kwargs)
    seq = out.sequences[0] if req.return_probs else out[0]
    text = tokenizer.decode(seq, skip_special_tokens=True)
    per_token_probs = (_extract_per_token_probs(out, prompt_len, tokenizer, req.topk_display)
                       if req.return_probs else None)

    steered_acts = {}
    if req.steering:
        # Re-run the prompt with the steering hooks active; without them this
        # forward pass would just reproduce base_acts.
        with _hook_stack(layer_module, sae, req.steering):
            with torch.no_grad(), _capture_at(layer_module) as bucket:
                model(**inputs)
        z_steered = sae.encode(bucket["h"][0, -1].unsqueeze(0))[0]
        for s in req.steering:
            steered_acts[s.id] = float(z_steered[s.id].item())

    verifier = [
        {"id": s.id, "alpha": s.alpha,
         "positions": s.positions,
         "output_only": s.output_only,
         "base": base_acts.get(s.id, 0.0),
         "steered": steered_acts.get(s.id, 0.0)}
        for s in req.steering
    ]
    resp = {"text": text, "verifier": verifier, "prompt_len": prompt_len}
    if per_token_probs is not None:
        resp["tokens"] = per_token_probs
    return resp

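The `verifier` entries returned by `/generate` are what a client uses to confirm steering took hold: for each steered feature, `base` is the feature's activation on the prompt with no hooks and `steered` is the activation with the steering hooks installed. A minimal client-side check, with made-up numbers:

```python
# Made-up verifier payload of the shape /generate returns; a client can
# flag any steered feature whose activation did not move from its base.
verifier = [
    {"id": 1234, "alpha": 6.0, "base": 0.0, "steered": 6.0},
    {"id": 42, "alpha": -3.0, "base": 2.5, "steered": 2.5},
]
unmoved = [v["id"] for v in verifier if v["steered"] == v["base"]]
# Feature 42 did not move, so its steering spec deserves a second look.
```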
if __name__ == "__main__":
    import uvicorn

    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)