Upload README.md with huggingface_hub
README.md
CHANGED
|
@@ -19,72 +19,80 @@ language:
|
|
| 19 |
pipeline_tag: feature-extraction
|
| 20 |
---
|
| 21 |
|
| 22 |
-
# WriteSAE
|
| 23 |
|
| 24 |
-
*
|
| 25 |
|
| 26 |
-
|
| 27 |
|
| 28 |
-
|
| 29 |
|
| 30 |
-
|
| 31 |
|
| 32 |
## Quick start
|
| 33 |
|
| 34 |
```python
|
| 35 |
from huggingface_hub import snapshot_download
|
| 36 |
-
import torch
|
| 37 |
-
|
| 38 |
-
ckpt_dir = snapshot_download(
|
| 39 |
-
"JackYoung27/writesae-ckpts",
|
| 40 |
-
allow_patterns=["writesae/qwen0p8b/L9_H4/*"],
|
| 41 |
-
)
|
| 42 |
-
|
| 43 |
-
ckpt = torch.load(
|
| 44 |
-
f"{ckpt_dir}/writesae/qwen0p8b/L9_H4/best.pt",
|
| 45 |
-
weights_only=False,
|
| 46 |
-
map_location="cpu",
|
| 47 |
-
)
|
| 48 |
-
|
| 49 |
-
# Decoder atom 412 — the paper's ERASE example.
|
| 50 |
-
v_412 = ckpt["sae"].decoder.v[412] # (d_k,)
|
| 51 |
-
w_412 = ckpt["sae"].decoder.w[412] # (d_v,)
|
| 52 |
-
atom = torch.outer(v_412, w_412) # (d_k, d_v)
|
| 53 |
```
|
| 54 |
|
| 55 |
-
|
| 56 |
|
| 57 |
-
|
| 58 |
|
| 59 |
-
|
| 60 |
-
|---|---|---|---|
|
| 61 |
-
| **WriteSAE** | bilinear vᵢᵀ S wᵢ | rank-1 vᵢwᵢᵀ | All headline numbers |
|
| 62 |
-
| FlatSAE | linear on vec(S) | flat | Architectural-prior comparison |
|
| 63 |
-
| MatrixSAE | linear on vec(S) | full-rank | Ablation |
|
| 64 |
-
| BilinearSAE | bilinear | bilinear | Ablation |
|
| 65 |
|
| 66 |
-
|
| 67 |
|
| 68 |
-
|
| 69 |
|
| 70 |
-
|
| 71 |
|
| 72 |
-
|
| 73 |
-
writesae-ckpts/
|
| 74 |
-
README.md
|
| 75 |
-
MODEL_CARD.md
|
| 76 |
-
manifest.json
|
| 77 |
-
LOAD_EXAMPLE.py
|
| 78 |
-
LICENSE
|
| 79 |
|
| 80 |
-
|
| 81 |
-
flat_baseline/<base-model>_<layer>_<head>/best.pt # FlatSAE controls
|
| 82 |
-
results/<test-name>/ # JSON outputs per paper claim
|
| 83 |
-
```
|
| 84 |
|
| 85 |
## Limitations
|
| 86 |
|
| 87 |
-
|
| 88 |
|
| 89 |
## Citation
|
| 90 |
|
|
@@ -97,5 +105,3 @@ The closed-form factorization predicts well only on Gated DeltaNet (R² = 0.98 a
|
|
| 97 |
url = {https://github.com/JackYoung27/writesae}
|
| 98 |
}
|
| 99 |
```
|
| 100 |
-
|
| 101 |
-
MIT license. Base models retain their upstream licenses; no base-model weights are redistributed.
|
|
|
|
| 19 |
pipeline_tag: feature-extraction
|
| 20 |
---
|
| 21 |
|
| 22 |
+
# WriteSAE: Sparse Autoencoders for Recurrent State
|
| 23 |
|
| 24 |
+
A sparse autoencoder for the matrix updates that Gated DeltaNet, Mamba-2, and RWKV-7 write into their recurrent cache each token. WriteSAE atoms are rank-1 matrices with the same shape as the model's own write, so a single atom can replace one native write at one position. Companion checkpoints for the paper *WriteSAE: Sparse Autoencoders for Recurrent State* ([arXiv:2605.12770](https://arxiv.org/abs/2605.12770)).
|
| 25 |
|
| 26 |
+
- **Code:** [github.com/JackYoung27/writesae](https://github.com/JackYoung27/writesae)
|
| 27 |
+
- **Project page:** [jackyoung.io/research/writesae](https://www.jackyoung.io/research/writesae)
|
| 28 |
+
- **Author:** [Jack Young](https://www.jackyoung.io), Indiana University ([youngjh@iu.edu](mailto:youngjh@iu.edu), ORCID [0009-0004-6785-303X](https://orcid.org/0009-0004-6785-303X)).
|
| 29 |
|
| 30 |
+
## Headline result
|
| 31 |
|
| 32 |
+
At a single Gated DeltaNet layer-head on Qwen3.5-0.8B, the WriteSAE atom yields a closer final token distribution than deleting the write on **92.4%** of evaluated positions; averaged per atom, the rate is **89.8%**. A closed-form expression in the forget gate, read query, and output embedding predicts the per-firing logit change at **R²=0.98**. The same replacement test transfers to Mamba-2-370M at **88.1%**. In generation, writing the formula's chosen direction into three consecutive cache positions at 3× the norm of the model's write makes tokens initially ranked 100–1000 by the unmodified model appear in **100%** of continuations, up from 33.3%. To our knowledge this is the first cache-level steering intervention in a state-space or hybrid recurrent layer.
|
| 33 |
+
|
| 34 |
+
## Variants
|
| 35 |
+
|
| 36 |
+
| variant | encoder | decoder |
|
| 37 |
+
| --- | --- | --- |
|
| 38 |
+
| **WriteSAE** | $v_i^\top S w_i$ | $v_i w_i^\top$ (rank-1) |
|
| 39 |
+
| FlatSAE | linear on vec($S$) | flat |
|
| 40 |
+
| MatrixSAE | linear on vec($S$) | full-rank |
|
| 41 |
+
| BilinearSAE | $v_i^\top S w_i$ | bilinear |
|
| 42 |
+
|
| 43 |
+
WriteSAE is the primary artifact and supports all main-text results.
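
The WriteSAE row of the table can be sketched in a few lines of NumPy. This is a minimal illustration only, not the repository's implementation: the function names `encode` and `decode`, the toy dimensions, and the absence of any bias or normalization terms are all assumptions made for the example.

```python
import numpy as np

def encode(S, V, W):
    """Bilinear encoder: a_i = v_i^T S w_i for every atom i.

    S: (d_k, d_v) state write; V: (n, d_k); W: (n, d_v). Returns (n,).
    """
    return np.einsum("ik,kv,iv->i", V, S, W)

def decode(a, V, W):
    """Rank-1 decoder: S_hat = sum_i a_i * v_i w_i^T, shape (d_k, d_v)."""
    return np.einsum("i,ik,iv->kv", a, V, W)

# Toy sizes, chosen for illustration only.
rng = np.random.default_rng(0)
d_k, d_v, n = 4, 6, 32
V = rng.standard_normal((n, d_k))
W = rng.standard_normal((n, d_v))
S = rng.standard_normal((d_k, d_v))

a = encode(S, V, W)        # one scalar activation per atom
S_hat = decode(a, V, W)    # same shape as the write S
```

Each decoder atom `np.outer(V[i], W[i])` has the same `(d_k, d_v)` shape as `S`, which is what lets a single atom stand in for one native write.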
|
| 44 |
+
|
| 45 |
+
## Base models covered
|
| 46 |
+
|
| 47 |
+
- Qwen3.5-0.8B (primary)
|
| 48 |
+
- Qwen3.5-4B (scale replication)
|
| 49 |
+
- Qwen3.5-27B (scale replication)
|
| 50 |
+
- Cross-architecture: DeltaNet 1.3B, GLA 1.3B, Mamba-2 2.8B, RWKV-7
|
| 51 |
|
| 52 |
## Quick start
|
| 53 |
|
| 54 |
```python
|
| 55 |
from huggingface_hub import snapshot_download
|
| 56 |
+
|
| 57 |
+
ckpt_dir = snapshot_download("JackYoung27/writesae-ckpts", local_dir="ckpts")
|
| 58 |
+
# ckpts/manifest.json maps tags to SHA256 and metadata.
|
| 59 |
```
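
Since the manifest records a SHA256 per checkpoint, a download can be checked before loading. The sketch below uses only the standard library. The helper names `sha256_of` and `verify` are hypothetical, and the layout of `manifest.json` is not shown here, so looking up the expected digest is left to the reader.

```python
import hashlib

def sha256_of(path, chunk_size=1 << 20):
    """Return the hex SHA256 digest of a file, read in chunks."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

def verify(path, expected_hex):
    """True if the file's digest matches the expected hex string."""
    return sha256_of(path) == expected_hex.strip().lower()

# Usage (the expected digest must be read from manifest.json by hand,
# because its schema is an unknown in this sketch):
#   ok = verify("ckpts/writesae/qwen3p5-0p8b/L9_H4/best.pt", expected_hex)
```

Reading in chunks keeps memory use flat for large checkpoint files.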
|
| 60 |
|
| 61 |
+
Load and run with the companion code:
|
| 62 |
|
| 63 |
+
```bash
|
| 64 |
+
git clone https://github.com/JackYoung27/writesae && cd writesae
|
| 65 |
+
pip install -e .
|
| 66 |
+
python -m experiments.analysis.analyze \
|
| 67 |
+
--sae_checkpoint ckpts/writesae/qwen3p5-0p8b/L9_H4/best.pt \
|
| 68 |
+
--data_dir states --layer 9 --head 4 --output_dir out
|
| 69 |
+
```
|
| 70 |
|
| 71 |
+
## Training details
|
| 72 |
|
| 73 |
+
- Architecture: rank-1 decoder atoms $v_i w_i^\top$, bilinear encoder.
|
| 74 |
+
- Dictionary size: 16384 features (configurable).
|
| 75 |
+
- Sparsity: TopK activation; BatchTopK supported.
|
| 76 |
+
- Training data: OpenWebText (`Skylion007/openwebtext`, streaming), tokenized with the Qwen3.5 tokenizer.
|
| 77 |
+
- Training compute: ~180 H100-hours single-GPU total across variants (paper App. B.3).
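
To make the sparsity bullet concrete, the sketch below contrasts per-sample TopK with BatchTopK as those terms are commonly used in the SAE literature. It is an illustration under that assumption, not the training code from the companion repository; the function names and toy sizes are hypothetical.

```python
import numpy as np

def topk(acts, k):
    """Per-sample TopK: keep the k largest activations in each row."""
    out = np.zeros_like(acts)
    idx = np.argpartition(acts, -k, axis=1)[:, -k:]
    vals = np.take_along_axis(acts, idx, axis=1)
    np.put_along_axis(out, idx, vals, axis=1)
    return out

def batch_topk(acts, k):
    """BatchTopK: keep the k * batch_size largest activations across the
    whole batch, so individual rows may keep more or fewer than k."""
    b = acts.shape[0]
    flat = acts.ravel()
    idx = np.argpartition(flat, -k * b)[-k * b:]
    out = np.zeros_like(flat)
    out[idx] = flat[idx]
    return out.reshape(acts.shape)

# Toy batch: 4 samples, 16 features (illustrative sizes only).
rng = np.random.default_rng(0)
acts = rng.standard_normal((4, 16))
sparse_rows = topk(acts, k=2)
sparse_batch = batch_topk(acts, k=2)
```

Both variants keep the same total number of active features per batch; they differ only in whether the budget is fixed per sample or shared across samples.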
|
| 78 |
|
| 79 |
+
## Intended use
|
| 80 |
|
| 81 |
+
Interpretability research on matrix-recurrent and linear-attention model internals: decomposing register/bundle structure, validating cross-architecture transfer, and running causal substitution experiments at the cache write site.
|
| 82 |
|
| 83 |
+
## Out of scope
|
| 84 |
|
| 85 |
+
Production model editing, safety interventions without independent validation, or claims about individual atom identity. Atoms reproduce class-level structure; the basis is SAE-run specific (paper section 6).
|
| 86 |
|
| 87 |
## Limitations
|
| 88 |
|
| 89 |
+
- Single primary architecture (Gated DeltaNet); Mamba-2 and GLA form a confirmed negative class.
|
| 90 |
+
- Small-model primary (0.8B); 4B and 27B replications supplement but do not replace the main evidence base.
|
| 91 |
+
- Mechanism claims are class-granular, not per-atom.
|
| 92 |
+
|
| 93 |
+
## License
|
| 94 |
+
|
| 95 |
+
MIT.
|
| 96 |
|
| 97 |
## Citation
|
| 98 |
|
|
|
|
| 105 |
url = {https://github.com/JackYoung27/writesae}
|
| 106 |
}
|
| 107 |
```
|