# writesae-ckpts / LOAD_EXAMPLE.py
# Uploaded by JackYoung27 — commit f1850af ("Initial public release")
#!/usr/bin/env python3
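"""Self-contained example: download the WriteSAE checkpoint and build the rank-1 atom for feature 412.
The substitution test that reproduces the paper's headline 92.4% number (the atom replaces the native
cache write at a single firing of feature 412) is not run here; see scripts/clean_amplified_kl.py in the code repo.
"""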
import torch
import torch.nn.functional as F
from huggingface_hub import snapshot_download
# 1. Download only the primary-cell SAE checkpoint (~5 MB) from the Hub.
REPO_ID = "JackYoung27/writesae-ckpts"
CELL = "writesae/qwen0p8b/L9_H4"
FEATURE = 412  # paper's ERASE exemplar; change to inspect a different atom

ckpt_dir = snapshot_download(
    REPO_ID,
    allow_patterns=[f"{CELL}/*"],  # skip every other cell in the repo
)

# SECURITY: weights_only=False unpickles arbitrary Python objects. It is needed
# here because the checkpoint stores the SAE module object itself under "sae",
# but only load checkpoints from a source you trust. Unpickling also requires
# the SAE's class to be importable (presumably from the WriteSAE code repo --
# TODO confirm).
ckpt = torch.load(
    f"{ckpt_dir}/{CELL}/best.pt",
    weights_only=False,
    map_location="cpu",
)

# 2. Build the atom for FEATURE: the rank-1 outer product v_i w_i^T.
# NOTE(review): assumes decoder.v / decoder.w hold one 1-D row per feature
# (torch.outer requires 1-D inputs) -- TODO confirm against the SAE class.
sae = ckpt["sae"]
with torch.no_grad():  # inspection only; don't record an autograd graph
    v_feat = sae.decoder.v[FEATURE]
    w_feat = sae.decoder.w[FEATURE]
    atom = torch.outer(v_feat, w_feat)

# Tensor.norm() defaults to the Frobenius norm, which is what the label states.
print(f"Atom F{FEATURE}: shape {tuple(atom.shape)}, ||atom||_F = {atom.norm():.4f}")

# 3. The actual substitution test is not run here. In that test the atom
# replaces the native cache write at one firing position and forward KL is
# measured. It lives in scripts/clean_amplified_kl.py in the code repo:
# https://anonymous.4open.science/r/WriteSAE-6158
print(f"\nNext: clone the code repo and run scripts/clean_amplified_kl.py --feature {FEATURE}")