OhhMoo commited on
Commit
0ef92d9
·
verified ·
1 Parent(s): b38c79d

Upload folder using huggingface_hub

Browse files
README.md ADDED
@@ -0,0 +1,306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ tags:
4
+ - sparse-autoencoder
5
+ - interpretability
6
+ - topk-sae
7
+ - qwen2.5
8
+ - ppo
9
+ - rlhf
10
+ base_model: Qwen/Qwen2.5-0.5B-Instruct
11
+ library_name: pytorch
12
+ ---
13
+
14
+ # SAE-RL-Qwen0.5B-bylayers
15
+
16
+ Sparse autoencoders (TopK SAEs) trained on the residual stream of
17
+ `Qwen/Qwen2.5-0.5B-Instruct` and a set of PPO-finetuned checkpoints derived
18
+ from it. Each SAE is trained per-layer per-stage, so the repository contains
19
+ one file per `(training_stage, layer)` pair.
20
+
21
+ This release covers **layers 6, 12, 18, and 23** across seven training
22
+ stages (`instruct_base` plus PPO steps 10, 30, 100, 140, 180, 200) — 28 SAEs
23
+ in total.
24
+
25
+ ---
26
+
27
+ ## Repository layout
28
+
29
+ ```
30
+ layer6/ sae_instruct_base_layer6.pt
31
+ sae_ppo_step{10,30,100,140,180,200}_layer6.pt
32
+ layer12/ sae_instruct_base_layer12.pt
33
+ sae_ppo_step{10,30,100,140,180,200}_layer12.pt
34
+ layer18/ sae_instruct_base_layer18.pt
35
+ sae_ppo_step{10,30,100,140,180,200}_layer18.pt
36
+ layer23/ sae_instruct_base_layer23.pt
37
+ sae_ppo_step{10,30,100,140,180,200}_layer23.pt
38
+ loader.py Minimal TopKSAE class + load() helper
39
+ ```
40
+
41
+ Each checkpoint is a dict:
42
+
43
+ ```python
44
+ {
45
+ "state_dict": {...}, # TopKSAE parameters
46
+ "config": {"d_model", "d_sae", "k", "source"},
47
+ }
48
+ ```
49
+
50
+ ## Quickstart
51
+
52
+ ```python
53
+ import torch
54
+ from loader import load_sae # provided in this repo
55
+
56
+ sae, cfg = load_sae("layer6/sae_instruct_base_layer6.pt", device="cuda")
57
+ x = ... # (N, d_model=896) residual-stream activations
58
+ x_hat, z = sae(x) # reconstruction, sparse code (k non-zeros/row)
59
+ ```
60
+
61
+ ---
62
+
63
+ ## Model provenance
64
+
65
+ - **Base model**: `Qwen/Qwen2.5-0.5B-Instruct` (24 decoder layers, `d_model = 896`)
66
+ - **PPO checkpoints**: PPO-without-SFT, trained on GSM8k with a
67
+ reward-model-based signal. Merged LoRA adapters into dense checkpoints
68
+ before activation collection. Steps released here: 10, 30, 100, 140, 180, 200.
69
+ - **Activation data**: 500k real (non-padding) tokens per `(stage, layer)`,
70
+ collected on GSM8k *train* prompts with `max_length=512`. Padding positions
71
+ were stripped before caching, so SAEs are trained only on real-token
72
+ activations.
73
+
74
+ ---
75
+
76
+ ## SAE architecture (TopK)
77
+
78
+ ```
79
+ b_pre ∈ R^{d_model} # pre-encoder centering; init to data mean
80
+ encoder: R^{d_model} → R^{d_sae} (Linear, bias=True)
81
+ decoder: R^{d_sae} → R^{d_model} (Linear, bias=True; cols unit-normed after every step)
82
+
83
+ encode(x):
84
+ z = encoder(x - b_pre)
85
+ keep top-k entries of z along the last dim, zero the rest → z_sparse
86
+ return z_sparse
87
+
88
+ forward(x):
89
+ return decoder(z_sparse), z_sparse
90
+ ```
91
+
92
+ Training loss = `MSE(x, x_hat) + aux_coeff · MSE(x, x_hat_dead)`, where
93
+ `x_hat_dead` is an auxiliary reconstruction that activates only dead encoder
94
+ rows on the current batch (revival term from
95
+ [Gao et al. 2024, Scaling and Evaluating Sparse Autoencoders](https://arxiv.org/abs/2406.04093)).
96
+
97
+ Dead features (fraction of active batches below `dead_threshold = 1e-4`) are
98
+ periodically **resampled** toward high-reconstruction-error tokens:
99
+ - pick a random token from the top-25% residual-error quartile;
100
+ - set that token's normalised activation as the dead row of the encoder
101
+ and the corresponding column of the decoder;
102
+ - reset the encoder bias to 0;
103
+ - **do not** reset the shared decoder bias (it encodes the learned residual-
104
+ stream mean and resetting it causes catastrophic loss spikes).
105
+
106
+ Decoder columns are re-projected to unit norm after every optimizer step.
107
+ The best-loss epoch is restored at the end of training to guard against
108
+ late-epoch spikes caused by the unit-norm projection fighting Adam's moments.
109
+
110
+ ---
111
+
112
+ ## Hyperparameters used per layer
113
+
114
+ | Layer | d_model | d_sae | expansion | k | batch_size | lr | epochs | resample every | aux_coeff |
115
+ |------:|--------:|------:|----------:|----:|-----------:|------:|-------:|---------------:|----------:|
116
+ | 6 | 896 | 7168 | 8× | 64 | 256 | 3e-4 | 40 | 5 epochs | 1/32 |
117
+ | 12 | 896 | 14336 | 16× | 96 | 256 | 3e-4 | 40 | 5 epochs | 1/32 |
118
+ | 18 | 896 | 14336 | 16× | 128 | 256 | 3e-4 | 40 | 5 epochs | 1/32 |
119
+ | 23 | 896 | 28672 | 32× | 128 | 256 | 3e-4 | 40 | 5 epochs | 1/32 |
120
+
121
+ Shared across all SAEs:
122
+ - Optimizer: Adam.
123
+ - LR schedule: cosine decay to `lr / 10` over `epochs × steps_per_epoch`.
124
+ - Gradient clipping: max-norm 1.0.
125
+ - Dead-feature threshold: mean firing frequency `< 1e-4` per epoch.
126
+
127
+ ---
128
+
129
+ ## Evaluation metrics
130
+
131
+ All metrics are computed on **held-out data**:
132
+ - *Reconstruction metrics* use the last 20% of cached activations for each
133
+ `(stage, layer)` pair.
134
+ - *CE-loss metrics* use the GSM8k **test** split (200 prompts, `max_length=256`).
135
+
136
+ ### 1. Reconstruction MSE
137
+
138
+ Mean squared error per element, averaged over the held-out activation slice:
139
+
140
+ ```
141
+ MSE = mean_{n, d} (x_{n,d} − x̂_{n,d})²
142
+ ```
143
+
144
+ ### 2. Fraction of Variance Explained (FVE)
145
+
146
+ ```
147
+ FVE = 1 − MSE / Var(x)
148
+ ```
149
+
150
+ `Var(x)` is computed over all elements of the held-out slice. Because
151
+ per-layer activation variance differs by an order of magnitude (notably low
152
+ at layer 23, where the residual stream is dominated by one direction near
153
+ the final layernorm), raw MSE is not comparable across layers — **use FVE
154
+ for cross-layer comparisons**.
155
+
156
+ ### 3. Mean L0
157
+
158
+ Average number of non-zero entries per reconstructed token. For a TopK SAE
159
+ this is exactly `k` in expectation; reported empirically as a sanity check
160
+ for the stored checkpoint config.
161
+
162
+ ### 4. Padding-safe model ΔCE with mean-ablation reference
163
+
164
+ The raw "splice the SAE and subtract" metric is unstable because
165
+ (i) padding positions are counted by `CausalLM` loss by default and
166
+ (ii) even on real tokens, a lossy reconstruction can peakify logits toward
167
+ high-prior tokens and artificially lower CE. Both bias the result. We fix
168
+ this with three changes:
169
+
170
+ 1. Mask padding tokens in the loss: `labels[attention_mask == 0] = -100`
171
+ so padding positions do not contribute to CE in any run.
172
+ 2. Splice only real-token positions in the forward hook:
173
+ `patched = where(attention_mask, sae(hidden), hidden)`. This matches the
174
+ training distribution, which excluded padding.
175
+ 3. Compare against a mean-ablation arm, not the raw baseline.
176
+
177
+ Let `layer_idx` be the decoder layer we intervene on.
178
+
179
+ **Three CE losses are measured per prompt batch**:
180
+ - `L_baseline`: no intervention.
181
+ - `L_sae`: at `layer_idx`, replace real-token hidden states with the SAE
182
+ reconstruction. Padding is left untouched.
183
+ - `L_mean`: at `layer_idx`, replace real-token hidden states with a fixed
184
+ dataset-mean vector (estimated on 32 warm-up prompts' real-token
185
+ hidden states at the same layer).
186
+
187
+ The headline metric is:
188
+
189
+ ```
190
+ frac_loss_recovered = (L_mean − L_sae) / (L_mean − L_baseline)
191
+ ```
192
+
193
+ **Interpretation**:
194
+ - `1.0` = the SAE reconstruction preserves downstream CE perfectly.
195
+ - `0.0` = the SAE reconstruction is no better than collapsing the layer to
196
+ its mean vector.
197
+ - Values *above* 1 or *below* 0 flag measurement artifacts (e.g.
198
+ unintended smoothing that makes the logits *more* peaked on common
199
+ next-tokens than the baseline distribution).
200
+
201
+ This is the metric to cite when comparing SAEs — it is bounded, interpretable,
202
+ and insensitive to the per-layer variance differences that inflate or
203
+ deflate raw MSE.
204
+
205
+ ### Why two CSVs?
206
+
207
+ The first evaluation pass (see `eval_report.json` / legacy
208
+ `sae_eval_metrics.csv` not shipped here) used a naive splice that counted
209
+ padding in the loss and replaced hidden states at all positions including
210
+ padding. That produced *negative* ΔCE (SAE "improves" the model), which is
211
+ a known artifact. The numbers in the table below come from the corrected
212
+ padding-safe evaluation only.
213
+
214
+ ---
215
+
216
+ ## Evaluation results (this release)
217
+
218
+ All values on GSM8k test, 200 prompts.
219
+
220
+ ### Layer 6 (d_sae=7168, k=64)
221
+
222
+ | Stage | MSE | FVE | mean L0 | L_base | L_sae | L_mean | frac_loss_recovered |
223
+ |---------------:|---------:|-------:|--------:|-------:|------:|--------:|--------------------:|
224
+ | instruct_base | 0.021983 | 0.9995 | 64.00 | 2.4445 | 2.6010| 11.6041 | 0.9829 |
225
+ | ppo_step10 | 0.021669 | 0.9995 | 64.00 | 2.4658 | 2.6620| 11.6256 | 0.9786 |
226
+ | ppo_step30 | 0.022306 | 0.9995 | 64.00 | 2.6018 | 2.7759| 11.6498 | 0.9808 |
227
+ | ppo_step100 | 0.022328 | 0.9995 | 64.00 | 3.0376 | 3.2190| 11.7095 | 0.9791 |
228
+ | ppo_step140 | 0.021771 | 0.9995 | 64.00 | 3.1451 | 3.4048| 11.6460 | 0.9694 |
229
+ | ppo_step180 | 0.023316 | 0.9995 | 64.00 | 3.2228 | 3.4790| 11.8361 | 0.9703 |
230
+ | ppo_step200 | 0.028901 | 0.9994 | 64.00 | 3.2032 | 3.4738| 11.8848 | 0.9688 |
231
+
232
+ ### Layer 12 (d_sae=14336, k=96)
233
+
234
+ | Stage | MSE | FVE | mean L0 | L_base | L_sae | L_mean | frac_loss_recovered |
235
+ |---------------:|---------:|-------:|--------:|-------:|------:|--------:|--------------------:|
236
+ | instruct_base | 0.031009 | 0.9993 | 96.00 | 2.4445 | 2.6575| 10.1495 | 0.9724 |
237
+ | ppo_step10 | 0.031459 | 0.9993 | 96.00 | 2.4658 | 2.6713| 10.1530 | 0.9733 |
238
+ | ppo_step30 | 0.030694 | 0.9994 | 96.00 | 2.6018 | 2.8250| 10.2270 | 0.9707 |
239
+ | ppo_step100 | 0.032453 | 0.9993 | 96.00 | 3.0376 | 3.4182| 10.5807 | 0.9495 |
240
+ | ppo_step140 | 0.037343 | 0.9992 | 96.00 | 3.1451 | 3.5767| 10.6014 | 0.9421 |
241
+ | ppo_step180 | 0.034286 | 0.9993 | 96.00 | 3.2228 | 3.6946| 10.7429 | 0.9373 |
242
+ | ppo_step200 | 0.037648 | 0.9992 | 96.00 | 3.2032 | 3.7207| 10.7912 | 0.9318 |
243
+
244
+ ### Layer 18 (d_sae=14336, k=128)
245
+
246
+ | Stage | MSE | FVE | mean L0 | L_base | L_sae | L_mean | frac_loss_recovered |
247
+ |---------------:|---------:|-------:|--------:|-------:|------:|--------:|--------------------:|
248
+ | instruct_base | 0.132713 | 0.9973 | 128.00 | 2.4445 | 2.7164| 10.7944 | 0.9674 |
249
+ | ppo_step10 | 0.125266 | 0.9974 | 128.00 | 2.4658 | 2.7106| 10.8126 | 0.9707 |
250
+ | ppo_step30 | 0.131541 | 0.9973 | 128.00 | 2.6018 | 2.8919| 10.9570 | 0.9653 |
251
+ | ppo_step100 | 0.127065 | 0.9974 | 128.00 | 3.0376 | 3.4449| 11.2926 | 0.9507 |
252
+ | ppo_step140 | 0.135698 | 0.9972 | 128.00 | 3.1451 | 3.6207| 11.4038 | 0.9424 |
253
+ | ppo_step180 | 0.134804 | 0.9972 | 128.00 | 3.2228 | 3.6742| 11.4629 | 0.9452 |
254
+ | ppo_step200 | 0.128425 | 0.9973 | 128.00 | 3.2032 | 3.6708| 11.4725 | 0.9435 |
255
+
256
+ ### Layer 23 (d_sae=28672, k=128)
257
+
258
+ | Stage | MSE | FVE | mean L0 | L_base | L_sae | L_mean | frac_loss_recovered |
259
+ |---------------:|---------:|-------:|--------:|-------:|------:|--------:|--------------------:|
260
+ | instruct_base | 0.440665 | 0.8560 | 128.00 | 2.4445 | 2.7846| 15.6202 | 0.9742 |
261
+ | ppo_step10 | 0.443813 | 0.8545 | 128.00 | 2.4658 | 2.8314| 15.9043 | 0.9728 |
262
+ | ppo_step30 | 0.447968 | 0.8501 | 128.00 | 2.6018 | 3.0068| 17.3308 | 0.9725 |
263
+ | ppo_step100 | 0.447669 | 0.8454 | 128.00 | 3.0376 | 3.5106| 20.0197 | 0.9721 |
264
+ | ppo_step140 | 0.441266 | 0.8461 | 128.00 | 3.1451 | 3.6356| 20.2853 | 0.9714 |
265
+ | ppo_step180 | 0.436758 | 0.8467 | 128.00 | 3.2228 | 3.7186| 20.6281 | 0.9715 |
266
+ | ppo_step200 | 0.430823 | 0.8482 | 128.00 | 3.2032 | 3.6975| 20.4202 | 0.9713 |
267
+
268
+ ---
269
+
270
+ ## Caveats for downstream users
271
+
272
+ - **Do not judge layer 23 by its FVE.** Layer 23's FVE of ~0.85 looks much
273
+ worse than layers 6/12/18 (all >0.997), but its `frac_loss_recovered` is
274
+ ~0.97 — comparable to layer 6 and *better* than layers 12 and 18 at late
275
+ PPO stages. The low FVE reflects layer 23's unusual activation geometry
276
+ (very low variance around a dominant direction near the final layernorm),
277
+ so most of the missing variance is noise that does not flow into the
278
+ logits. `frac_loss_recovered` is the metric to trust for downstream
279
+ usability.
280
+ - **`frac_loss_recovered` degrades across PPO stages for mid-network layers.**
281
+ Layer 12: 0.972 → 0.932. Layer 18: 0.967 → 0.944. Layer 6 and 23 are
282
+ roughly flat. If you are running feature analyses that compare early vs.
283
+ late PPO checkpoints at layers 12/18, expect higher reconstruction noise
284
+ at later stages. This is a likely interpretability signal (mid-network
285
+ features restructured by RL), not a training artifact.
286
+ - **`L_baseline` climbs with PPO steps** (2.44 → 3.20). The PPO model is
287
+ drifting from the GSM8k prompt-LM distribution as expected for PPO
288
+ without a KL anchor. Keep this in mind when comparing raw CE across
289
+ stages.
290
+ - SAEs were trained only on real tokens. Do **not** splice the SAE over
291
+ padding positions when using it at inference — replicate the
292
+ `where(attention_mask, sae(x), x)` pattern from the eval script.
293
+
294
+ ---
295
+
296
+ ## Reproducing
297
+
298
+ Scripts live in the training repo:
299
+ - `04_collect_activations.py`: cache per-layer residual-stream activations.
300
+ - `05_train_sae.py`: train one TopK SAE per activation file.
301
+ - `07_maskeval_sae_metrics.py`: run the padding-safe evaluation with
302
+ mean-ablation reference used to produce the numbers above.
303
+
304
+ ## License
305
+
306
+ Apache-2.0, matching the base model.
layer12/sae_instruct_base_layer12.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5d7a114279329ef46f161cbd34b377856d960dda43f224a3d65d641e5f26dc75
3
+ size 102828191
layer12/sae_ppo_step100_layer12.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8c2a9bc67e5e40aa33ae88ce9bdfce1e704f78b56a4a4bc5085d9a66d3df6eca
3
+ size 102828105
layer12/sae_ppo_step10_layer12.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:30782b140234d172811f3a83198646f277cc9c12c3b3a35badd0ae4bd2167012
3
+ size 102828094
layer12/sae_ppo_step140_layer12.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dc1e15b4027969eb20597c7224bbcd09275b6d2c6a68f1d66727709b91764c05
3
+ size 102828105
layer12/sae_ppo_step180_layer12.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ba4986cd631cee9bfea3808e9dba5111664d6ef9e3928aae4eaece0b3d3e148b
3
+ size 102828105
layer12/sae_ppo_step200_layer12.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1ec19ddd6914b1baafdb109bbe3b5c03d8ade30764a19cd795b22c3fcb294c41
3
+ size 102828105
layer12/sae_ppo_step30_layer12.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9726cad1804fcfc0598e6022a124d295ab86bc5f2776615ec0b640dedb9dc6dc
3
+ size 102828094
layer18/sae_instruct_base_layer18.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e87177dbfe03aea89d8f794959ded884e9d30fad81c8c0d12fe9df56047ba02f
3
+ size 102828191
layer18/sae_ppo_step100_layer18.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2ac615714831296d73983e7b3bf3934207a5624e7fadec11f84456c4d48fbef4
3
+ size 102828105
layer18/sae_ppo_step10_layer18.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:324f27f862651df1e16ffd0878fc676313290b9c2dc4c9d5467706e6cc469af4
3
+ size 102828094
layer18/sae_ppo_step140_layer18.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c4c9fa0d943324e33801b269ffce90f22aa742c22500141c74d92e9632fd6a6e
3
+ size 102828105
layer18/sae_ppo_step180_layer18.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1f5090c735eee8cb8e374db9e8c7d668e9574991594271e7417436f2c2aca724
3
+ size 102828105
layer18/sae_ppo_step200_layer18.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:919a5e991ef15813ec36c1b1c38b97a8679d167c5547058ba247ac44326b0657
3
+ size 102828105
layer18/sae_ppo_step30_layer18.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:45e49be205263c3eae4d5042c4516bcfff351c1ee43b6ff3009c305ea15e3100
3
+ size 102828094
layer23/sae_instruct_base_layer23.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:77dca5a3336a1565741658fdd53cf2539c1dcb1919ffb1372f8909815a9fb275
3
+ size 205645983
layer23/sae_ppo_step100_layer23.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:22505d691fca236bf7e03e72416be6f3eb4fc592dbff314019f4faf082277385
3
+ size 205645897
layer23/sae_ppo_step10_layer23.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:893570dc800833388cd1c9c1c8fbf6acca242a09a4f1d3cf6329ad94db95e6c9
3
+ size 205645886
layer23/sae_ppo_step140_layer23.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fa6f7a73ca08b2c744d65aac50beff0c41337c6a046fd264c35f3fe23cd89c9d
3
+ size 205645897
layer23/sae_ppo_step180_layer23.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b5ddbdc64da6ab7c3c327380998841c30a05a820b9c86665bf6504e02abb9c93
3
+ size 205645897
layer23/sae_ppo_step200_layer23.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2fe8cb7b3455ef1a9b6263ed93b56e22d448b7de1ed73c6e16feede4585fa92a
3
+ size 205645897
layer23/sae_ppo_step30_layer23.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1abb910db4e376c0a1e0dba401cbf82012970bcecaf60f0edff63519f81cde4a
3
+ size 205645886
layer6/sae_instruct_base_layer6.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:60d32c83a744c76b6c00bb903dd6500d9e8ca5ef37bd785a36b7e9e43316f927
3
+ size 51419220
layer6/sae_ppo_step100_layer6.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4e71797aeeb1980351616a737e0b52cc68bafd9c9f166113867a4aadbf2275a2
3
+ size 51419198
layer6/sae_ppo_step10_layer6.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0458a2ee3d552a92238d243da34e52ca98238f2a6f2f3df48b78fd82e6afa8c5
3
+ size 51419123
layer6/sae_ppo_step140_layer6.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5c018dfc57813768e4b221a3aca38dd3c62718703e8222dfd6c8eed78022991e
3
+ size 51419198
layer6/sae_ppo_step180_layer6.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7c6ed0d416a186549eeba0ab2b96541c9f7b6cc05419e52fb7a1c0dff255ac35
3
+ size 51419198
layer6/sae_ppo_step200_layer6.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:929a7eafae57602446acbd98dbf3a77e700fac670a0e9450dc58b7a85dbd892d
3
+ size 51419198
layer6/sae_ppo_step30_layer6.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fed47f4fca909b40675bf2f7d56cd64e94b507041192a7bdb2d62e31a50ec4cb
3
+ size 51419123
loader.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Minimal loader for the TopK SAEs in this repository.
2
+
3
+ Usage:
4
+ from loader import load_sae
5
+ sae, cfg = load_sae("layer6/sae_instruct_base_layer6.pt", device="cuda")
6
+ x_hat, z = sae(x) # x: (N, d_model=896)
7
+
8
+ The `sae(x)` forward returns:
9
+ x_hat: (N, d_model) reconstruction
10
+ z_sparse: (N, d_sae) sparse code, exactly `k` non-zeros per row
11
+
12
+ When splicing into the base model's residual stream, only replace
13
+ real-token positions (see README for the rationale):
14
+ patched = torch.where(mask.unsqueeze(-1).bool(), sae(h)[0], h)
15
+ """
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+
21
+ class TopKSAE(nn.Module):
22
+ def __init__(self, d_model: int, d_sae: int, k: int):
23
+ super().__init__()
24
+ self.k = k
25
+ self.d_model = d_model
26
+ self.d_sae = d_sae
27
+ self.b_pre = nn.Parameter(torch.zeros(d_model))
28
+ self.encoder = nn.Linear(d_model, d_sae, bias=True)
29
+ self.decoder = nn.Linear(d_sae, d_model, bias=True)
30
+
31
+ def encode(self, x: torch.Tensor) -> torch.Tensor:
32
+ z = self.encoder(x - self.b_pre)
33
+ topk_values, topk_indices = torch.topk(z, self.k, dim=-1)
34
+ z_sparse = torch.zeros_like(z)
35
+ z_sparse.scatter_(-1, topk_indices, topk_values)
36
+ return z_sparse
37
+
38
+ def forward(self, x: torch.Tensor):
39
+ z_sparse = self.encode(x)
40
+ return self.decoder(z_sparse), z_sparse
41
+
42
+
43
+ def load_sae(path: str, device: str = "cpu"):
44
+ """Load a checkpoint saved by the training pipeline.
45
+
46
+ Checkpoint format:
47
+ {"state_dict": ..., "config": {"d_model", "d_sae", "k", "source"}}
48
+ """
49
+ ckpt = torch.load(path, map_location=device, weights_only=False)
50
+ cfg = ckpt["config"]
51
+ sae = TopKSAE(cfg["d_model"], cfg["d_sae"], cfg["k"])
52
+ sae.load_state_dict(ckpt["state_dict"], strict=False)
53
+ return sae.to(device).eval(), cfg