File size: 1,156 Bytes
f1850af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
#!/usr/bin/env python3
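"""Self-contained example: load a WriteSAE decoder atom from the primary cell SAE.

Downloads the SAE checkpoint, builds the rank-1 atom for feature 412 and prints
its shape and Frobenius norm. It does not run the cache-slot substitution
itself: to reproduce the paper's headline 92.4% number on a single firing of
feature 412, run scripts/clean_amplified_kl.py from the code repo.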
"""
from pathlib import Path

import torch
import torch.nn.functional as F
from huggingface_hub import snapshot_download

# Hub location of the primary cell SAE. The download filter and the load path
# are both derived from CKPT_SUBDIR, so they cannot drift apart.
REPO_ID = "JackYoung27/writesae-ckpts"
CKPT_SUBDIR = "writesae/qwen0p8b/L9_H4"
FEATURE = 412  # paper ERASE exemplar

# 1. Download the primary cell SAE (5 MB).
ckpt_dir = snapshot_download(REPO_ID, allow_patterns=[f"{CKPT_SUBDIR}/*"])

# SECURITY NOTE: weights_only=False unpickles arbitrary Python objects, which can
# execute code on load. It is required here because the checkpoint stores the SAE
# module itself (ckpt["sae"]), so only load checkpoints from a trusted repo.
ckpt = torch.load(
    Path(ckpt_dir) / CKPT_SUBDIR / "best.pt",
    weights_only=False,
    map_location="cpu",
)

# 2. Pick atom F412 (paper ERASE exemplar). The atom is the rank-1 outer product
#    v_i w_i^T of the decoder's per-feature vectors v[i] and w[i].
decoder = ckpt["sae"].decoder
# The decoder vectors may be trainable parameters (assumption, not visible here).
# We only inspect them, so skip building an autograd graph.
with torch.no_grad():
    v_i = decoder.v[FEATURE]
    w_i = decoder.w[FEATURE]
    atom = torch.outer(v_i, w_i)
# Tensor.norm() on a 2-D tensor defaults to the Frobenius norm.
print(
    f"Atom F{FEATURE}: shape {tuple(atom.shape)}, "
    f"||atom||_F = {atom.norm():.4f}"
)

# 3. To run the actual substitution test (atom replaces native cache write at one
#    firing position, measure forward KL), see scripts/clean_amplified_kl.py in the
#    code repo: https://anonymous.4open.science/r/WriteSAE-6158
print(
    "\nNext: clone the code repo and run "
    f"scripts/clean_amplified_kl.py --feature {FEATURE}"
)