JackYoung27 commited on
Commit
3fcb6f7
·
verified ·
1 Parent(s): f1850af

Upload README.md with huggingface_hub

Browse files
Files changed (1) hide show
  1. README.md +53 -47
README.md CHANGED
@@ -19,72 +19,80 @@ language:
19
  pipeline_tag: feature-extraction
20
  ---
21
 
22
- # WriteSAE
23
 
24
- **WriteSAE: Sparse Autoencoders for Recurrent State**
25
 
26
- Jack Young
 
 
27
 
28
- [Paper](https://arxiv.org/abs/2605.12770) | [Website](https://www.jackyoung.io/research/writesae) | [Code](https://github.com/JackYoung27/writesae)
29
 
30
- WriteSAE factors each decoder atom as the rank-1 outer product **vᵢwᵢᵀ**, matching the native **kₜvₜᵀ** write that Gated DeltaNet, Mamba-2, and RWKV-7 install into a **dₖ × dᵥ** matrix cache. Residual SAEs cannot reach that write site; WriteSAE can. Atom substitution beats matched-Frobenius-norm ablation on **92.4%** of *n*=4,851 firings at Qwen3.5-0.8B L9 H4, the closed form predicts measured logit shifts at **R² = 0.98**, and sustained three-position installs lift midrank target-in-continuation from 33.3% to **100%** under greedy decoding. Cross-architecture: GDN rank-1 atoms transfer to Mamba-2-370M at 88.1% over 2,500 firings, with sharpness ordering GDN > RWKV-7 > Mamba-2.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
  ## Quick start
33
 
34
  ```python
35
  from huggingface_hub import snapshot_download
36
- import torch
37
-
38
- ckpt_dir = snapshot_download(
39
- "JackYoung27/writesae-ckpts",
40
- allow_patterns=["writesae/qwen0p8b/L9_H4/*"],
41
- )
42
-
43
- ckpt = torch.load(
44
- f"{ckpt_dir}/writesae/qwen0p8b/L9_H4/best.pt",
45
- weights_only=False,
46
- map_location="cpu",
47
- )
48
-
49
- # Decoder atom 412 — the paper's ERASE example.
50
- v_412 = ckpt["sae"].decoder.v[412] # (d_k,)
51
- w_412 = ckpt["sae"].decoder.w[412] # (d_v,)
52
- atom = torch.outer(v_412, w_412) # (d_k, d_v)
53
  ```
54
 
55
- Standalone runnable in [`LOAD_EXAMPLE.py`](LOAD_EXAMPLE.py).
56
 
57
- ## Variants
 
 
 
 
 
 
58
 
59
- | variant | encoder | decoder | role |
60
- |---|---|---|---|
61
- | **WriteSAE** | bilinear vᵢᵀ S wᵢ | rank-1 vᵢwᵢᵀ | All headline numbers |
62
- | FlatSAE | linear on vec(S) | flat | Architectural-prior comparison |
63
- | MatrixSAE | linear on vec(S) | full-rank | Ablation |
64
- | BilinearSAE | bilinear | bilinear | Ablation |
65
 
66
- ## Base models covered
 
 
 
 
67
 
68
- Qwen3.5-0.8B (primary), Qwen3.5-4B, Qwen3.5-27B, Mamba-2-370M, RWKV-7-1.5B, DeltaNet-1.3B, GLA-1.3B. See [`MODEL_CARD.md`](MODEL_CARD.md) for full layer / head coverage and training details.
69
 
70
- ## Repository layout
71
 
72
- ```text
73
- writesae-ckpts/
74
- README.md
75
- MODEL_CARD.md
76
- manifest.json
77
- LOAD_EXAMPLE.py
78
- LICENSE
79
 
80
- writesae/<base-model>/<layer>_<head>/best.pt # primary cells
81
- flat_baseline/<base-model>_<layer>_<head>/best.pt # FlatSAE controls
82
- results/<test-name>/ # JSON outputs per paper claim
83
- ```
84
 
85
  ## Limitations
86
 
87
- The closed-form factorization predicts well only on Gated DeltaNet (R² = 0.98 at L9 H4); applied to Mamba-2 or Qwen3.5-4B, it returns negative R². The substitution test itself transfers to Mamba-2 (88.1%); the analytical coefficient does not. Per-atom identity varies across SAE seeds; the class-level register / bundle partition reproduces at CV 4–12%.
 
 
 
 
 
 
88
 
89
  ## Citation
90
 
@@ -97,5 +105,3 @@ The closed-form factorization predicts well only on Gated DeltaNet (R² = 0.98 a
97
  url = {https://github.com/JackYoung27/writesae}
98
  }
99
  ```
100
-
101
- MIT license. Base models retain their upstream licenses; no base-model weights are redistributed.
 
19
  pipeline_tag: feature-extraction
20
  ---
21
 
22
+ # WriteSAE: Sparse Autoencoders for Recurrent State
23
 
24
+ A sparse autoencoder for the matrix updates that Gated DeltaNet, Mamba-2, and RWKV-7 write into their recurrent cache each token. WriteSAE atoms are rank-1 matrices with the same shape as the model's own write, so a single atom can replace one native write at one position. Companion checkpoints for the paper *WriteSAE: Sparse Autoencoders for Recurrent State* ([arXiv:2605.12770](https://arxiv.org/abs/2605.12770)).
25
 
26
+ - **Code:** [github.com/JackYoung27/writesae](https://github.com/JackYoung27/writesae)
27
+ - **Project page:** [jackyoung.io/research/writesae](https://www.jackyoung.io/research/writesae)
28
+ - **Author:** [Jack Young](https://www.jackyoung.io), Indiana University ([youngjh@iu.edu](mailto:youngjh@iu.edu), ORCID [0009-0004-6785-303X](https://orcid.org/0009-0004-6785-303X)).
29
 
30
+ ## Headline result
31
 
32
+ At a single Gated DeltaNet layer-head on Qwen3.5-0.8B, the WriteSAE atom yields a closer final token distribution than deleting the write on **92.4%** of evaluated positions; averaged per atom, the rate is **89.8%**. A closed-form expression in the forget gate, read query, and output embedding predicts the per-firing logit change at **R²=0.98**. The same replacement test transfers to Mamba-2-370M at **88.1%**. In generation, writing the formula's chosen direction into three consecutive cache positions at the norm of the model's write makes tokens initially ranked 100–1000 by the unmodified model appear in **100%** of continuations, up from 33.3%. To our knowledge this is the first cache-level steering intervention in a state-space or hybrid recurrent layer.
33
+
34
+ ## Variants
35
+
36
+ | variant | encoder | decoder |
37
+ | --- | --- | --- |
38
+ | **WriteSAE** | $v_i^\top S w_i$ | $v_i w_i^\top$ (rank-1) |
39
+ | FlatSAE | linear on vec($S$) | flat |
40
+ | MatrixSAE | linear on vec($S$) | full-rank |
41
+ | BilinearSAE | $v_i^\top S w_i$ | bilinear |
42
+
43
+ WriteSAE is the primary artifact and supports all main-text results.
44
+
45
+ ## Base models covered
46
+
47
+ - Qwen3.5-0.8B (primary)
48
+ - Qwen3.5-4B (scale replication)
49
+ - Qwen3.5-27B (scale replication)
50
+ - Cross-architecture: DeltaNet 1.3B, GLA 1.3B, Mamba-2 2.8B, RWKV-7
51
 
52
  ## Quick start
53
 
54
  ```python
55
  from huggingface_hub import snapshot_download
56
+
57
+ ckpt_dir = snapshot_download("JackYoung27/writesae-ckpts", local_dir="ckpts")
58
+ # ckpts/manifest.json maps tags to SHA256 and metadata.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  ```
60
 
61
+ Load and run with the companion code:
62
 
63
+ ```bash
64
+ git clone https://github.com/JackYoung27/writesae && cd matrix-sae
65
+ pip install -e .
66
+ python -m experiments.analysis.analyze \
67
+ --sae_checkpoint ckpts/writesae/qwen3p5-0p8b/L9_H4/best.pt \
68
+ --data_dir states --layer 9 --head 4 --output_dir out
69
+ ```
70
 
71
+ ## Training details
 
 
 
 
 
72
 
73
+ - Architecture: rank-1 decoder atoms $v_i w_i^\top$, bilinear encoder.
74
+ - Dictionary size: 16384 features (configurable).
75
+ - Sparsity: TopK activation; BatchTopK supported.
76
+ - Training data: OpenWebText (`Skylion007/openwebtext`, streaming), tokenized with the Qwen3.5 tokenizer.
77
+ - Training compute: ~180 H100-hours single-GPU total across variants (paper App. B.3).
78
 
79
+ ## Intended use
80
 
81
+ Interpretability research on matrix-recurrent and linear-attention model internals: decomposing register/bundle structure, validating cross-architecture transfer, and testing causal substitution experiments at the cache write site.
82
 
83
+ ## Out of scope
 
 
 
 
 
 
84
 
85
+ Production model editing, safety interventions without independent validation, or claims about individual atom identity. Atoms reproduce class-level structure; the basis is SAE-run specific (paper section 6).
 
 
 
86
 
87
  ## Limitations
88
 
89
+ - Single primary architecture (GatedDeltaNet); Mamba-2 and GLA are confirmed negative class.
90
+ - Small-model primary (0.8B); 4B and 27B replications supplement but do not replace the main evidence base.
91
+ - Mechanism claims are class-granular, not per-atom.
92
+
93
+ ## License
94
+
95
+ MIT.
96
 
97
  ## Citation
98
 
 
105
  url = {https://github.com/JackYoung27/writesae}
106
  }
107
  ```