Upload folder using huggingface_hub
Browse files- README.md +306 -0
- layer12/sae_instruct_base_layer12.pt +3 -0
- layer12/sae_ppo_step100_layer12.pt +3 -0
- layer12/sae_ppo_step10_layer12.pt +3 -0
- layer12/sae_ppo_step140_layer12.pt +3 -0
- layer12/sae_ppo_step180_layer12.pt +3 -0
- layer12/sae_ppo_step200_layer12.pt +3 -0
- layer12/sae_ppo_step30_layer12.pt +3 -0
- layer18/sae_instruct_base_layer18.pt +3 -0
- layer18/sae_ppo_step100_layer18.pt +3 -0
- layer18/sae_ppo_step10_layer18.pt +3 -0
- layer18/sae_ppo_step140_layer18.pt +3 -0
- layer18/sae_ppo_step180_layer18.pt +3 -0
- layer18/sae_ppo_step200_layer18.pt +3 -0
- layer18/sae_ppo_step30_layer18.pt +3 -0
- layer23/sae_instruct_base_layer23.pt +3 -0
- layer23/sae_ppo_step100_layer23.pt +3 -0
- layer23/sae_ppo_step10_layer23.pt +3 -0
- layer23/sae_ppo_step140_layer23.pt +3 -0
- layer23/sae_ppo_step180_layer23.pt +3 -0
- layer23/sae_ppo_step200_layer23.pt +3 -0
- layer23/sae_ppo_step30_layer23.pt +3 -0
- layer6/sae_instruct_base_layer6.pt +3 -0
- layer6/sae_ppo_step100_layer6.pt +3 -0
- layer6/sae_ppo_step10_layer6.pt +3 -0
- layer6/sae_ppo_step140_layer6.pt +3 -0
- layer6/sae_ppo_step180_layer6.pt +3 -0
- layer6/sae_ppo_step200_layer6.pt +3 -0
- layer6/sae_ppo_step30_layer6.pt +3 -0
- loader.py +53 -0
README.md
ADDED
|
@@ -0,0 +1,306 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
+
tags:
|
| 4 |
+
- sparse-autoencoder
|
| 5 |
+
- interpretability
|
| 6 |
+
- topk-sae
|
| 7 |
+
- qwen2.5
|
| 8 |
+
- ppo
|
| 9 |
+
- rlhf
|
| 10 |
+
base_model: Qwen/Qwen2.5-0.5B-Instruct
|
| 11 |
+
library_name: pytorch
|
| 12 |
+
---
|
| 13 |
+
|
| 14 |
+
# SAE-RL-Qwen0.5B-bylayers
|
| 15 |
+
|
| 16 |
+
Sparse autoencoders (TopK SAEs) trained on the residual stream of
|
| 17 |
+
`Qwen/Qwen2.5-0.5B-Instruct` and a set of PPO-finetuned checkpoints derived
|
| 18 |
+
from it. Each SAE is trained per-layer per-stage, so the repository contains
|
| 19 |
+
one file per `(training_stage, layer)` pair.
|
| 20 |
+
|
| 21 |
+
This release covers **layers 6, 12, 18, and 23** across seven training
|
| 22 |
+
stages (`instruct_base` plus PPO steps 10, 30, 100, 140, 180, 200) — 28 SAEs
|
| 23 |
+
in total.
|
| 24 |
+
|
| 25 |
+
---
|
| 26 |
+
|
| 27 |
+
## Repository layout
|
| 28 |
+
|
| 29 |
+
```
|
| 30 |
+
layer6/ sae_instruct_base_layer6.pt
|
| 31 |
+
sae_ppo_step{10,30,100,140,180,200}_layer6.pt
|
| 32 |
+
layer12/ sae_instruct_base_layer12.pt
|
| 33 |
+
sae_ppo_step{10,30,100,140,180,200}_layer12.pt
|
| 34 |
+
layer18/ sae_instruct_base_layer18.pt
|
| 35 |
+
sae_ppo_step{10,30,100,140,180,200}_layer18.pt
|
| 36 |
+
layer23/ sae_instruct_base_layer23.pt
|
| 37 |
+
sae_ppo_step{10,30,100,140,180,200}_layer23.pt
|
| 38 |
+
loader.py Minimal TopKSAE class + load() helper
|
| 39 |
+
```
|
| 40 |
+
|
| 41 |
+
Each checkpoint is a dict:
|
| 42 |
+
|
| 43 |
+
```python
|
| 44 |
+
{
|
| 45 |
+
"state_dict": {...}, # TopKSAE parameters
|
| 46 |
+
"config": {"d_model", "d_sae", "k", "source"},
|
| 47 |
+
}
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
## Quickstart
|
| 51 |
+
|
| 52 |
+
```python
|
| 53 |
+
import torch
|
| 54 |
+
from loader import load_sae # provided in this repo
|
| 55 |
+
|
| 56 |
+
sae, cfg = load_sae("layer6/sae_instruct_base_layer6.pt", device="cuda")
|
| 57 |
+
x = ... # (N, d_model=896) residual-stream activations
|
| 58 |
+
x_hat, z = sae(x) # reconstruction, sparse code (k non-zeros/row)
|
| 59 |
+
```
|
| 60 |
+
|
| 61 |
+
---
|
| 62 |
+
|
| 63 |
+
## Model provenance
|
| 64 |
+
|
| 65 |
+
- **Base model**: `Qwen/Qwen2.5-0.5B-Instruct` (24 decoder layers, `d_model = 896`)
|
| 66 |
+
- **PPO checkpoints**: PPO-without-SFT, trained on GSM8k with a
|
| 67 |
+
reward-model-based signal. Merged LoRA adapters into dense checkpoints
|
| 68 |
+
before activation collection. Steps released here: 10, 30, 100, 140, 180, 200.
|
| 69 |
+
- **Activation data**: 500k real (non-padding) tokens per `(stage, layer)`,
|
| 70 |
+
collected on GSM8k *train* prompts with `max_length=512`. Padding positions
|
| 71 |
+
were stripped before caching, so SAEs are trained only on real-token
|
| 72 |
+
activations.
|
| 73 |
+
|
| 74 |
+
---
|
| 75 |
+
|
| 76 |
+
## SAE architecture (TopK)
|
| 77 |
+
|
| 78 |
+
```
|
| 79 |
+
b_pre ∈ R^{d_model} # pre-encoder centering; init to data mean
|
| 80 |
+
encoder: R^{d_model} → R^{d_sae} (Linear, bias=True)
|
| 81 |
+
decoder: R^{d_sae} → R^{d_model} (Linear, bias=True; cols unit-normed after every step)
|
| 82 |
+
|
| 83 |
+
encode(x):
|
| 84 |
+
z = encoder(x - b_pre)
|
| 85 |
+
keep top-k entries of z along the last dim, zero the rest → z_sparse
|
| 86 |
+
return z_sparse
|
| 87 |
+
|
| 88 |
+
forward(x):
|
| 89 |
+
z_sparse = encode(x); return decoder(z_sparse), z_sparse
|
| 90 |
+
```
|
| 91 |
+
|
| 92 |
+
Training loss = `MSE(x, x_hat) + aux_coeff · MSE(x, x_hat_dead)`, where
|
| 93 |
+
`x_hat_dead` is an auxiliary reconstruction that activates only dead encoder
|
| 94 |
+
rows on the current batch (revival term from
|
| 95 |
+
[Gao et al. 2024, Scaling and Evaluating Sparse Autoencoders](https://arxiv.org/abs/2406.04093)).
|
| 96 |
+
|
| 97 |
+
Dead features (fraction of active batches below `dead_threshold = 1e-4`) are
|
| 98 |
+
periodically **resampled** toward high-reconstruction-error tokens:
|
| 99 |
+
- pick a random token from the top-25% residual-error quartile;
|
| 100 |
+
- set that token's normalised activation as the dead row of the encoder
|
| 101 |
+
and the corresponding column of the decoder;
|
| 102 |
+
- reset the encoder bias to 0;
|
| 103 |
+
- **do not** reset the shared decoder bias (it encodes the learned residual-
|
| 104 |
+
stream mean and resetting it causes catastrophic loss spikes).
|
| 105 |
+
|
| 106 |
+
Decoder columns are re-projected to unit norm after every optimizer step.
|
| 107 |
+
The best-loss epoch is restored at the end of training to guard against
|
| 108 |
+
late-epoch spikes caused by the unit-norm projection fighting Adam's moments.
|
| 109 |
+
|
| 110 |
+
---
|
| 111 |
+
|
| 112 |
+
## Hyperparameters used per layer
|
| 113 |
+
|
| 114 |
+
| Layer | d_model | d_sae | expansion | k | batch_size | lr | epochs | resample every | aux_coeff |
|
| 115 |
+
|------:|--------:|------:|----------:|----:|-----------:|------:|-------:|---------------:|----------:|
|
| 116 |
+
| 6 | 896 | 7168 | 8× | 64 | 256 | 3e-4 | 40 | 5 epochs | 1/32 |
|
| 117 |
+
| 12 | 896 | 14336 | 16× | 96 | 256 | 3e-4 | 40 | 5 epochs | 1/32 |
|
| 118 |
+
| 18 | 896 | 14336 | 16× | 128 | 256 | 3e-4 | 40 | 5 epochs | 1/32 |
|
| 119 |
+
| 23 | 896 | 28672 | 32× | 128 | 256 | 3e-4 | 40 | 5 epochs | 1/32 |
|
| 120 |
+
|
| 121 |
+
Shared across all SAEs:
|
| 122 |
+
- Optimizer: Adam.
|
| 123 |
+
- LR schedule: cosine decay to `lr / 10` over `epochs × steps_per_epoch`.
|
| 124 |
+
- Gradient clipping: max-norm 1.0.
|
| 125 |
+
- Dead-feature threshold: mean firing frequency `< 1e-4` per epoch.
|
| 126 |
+
|
| 127 |
+
---
|
| 128 |
+
|
| 129 |
+
## Evaluation metrics
|
| 130 |
+
|
| 131 |
+
All metrics are computed on **held-out data**:
|
| 132 |
+
- *Reconstruction metrics* use the last 20% of cached activations for each
|
| 133 |
+
`(stage, layer)` pair.
|
| 134 |
+
- *CE-loss metrics* use the GSM8k **test** split (200 prompts, `max_length=256`).
|
| 135 |
+
|
| 136 |
+
### 1. Reconstruction MSE
|
| 137 |
+
|
| 138 |
+
Mean squared error per element, averaged over the held-out activation slice:
|
| 139 |
+
|
| 140 |
+
```
|
| 141 |
+
MSE = mean_{n, d} (x_{n,d} − x̂_{n,d})²
|
| 142 |
+
```
|
| 143 |
+
|
| 144 |
+
### 2. Fraction of Variance Explained (FVE)
|
| 145 |
+
|
| 146 |
+
```
|
| 147 |
+
FVE = 1 − MSE / Var(x)
|
| 148 |
+
```
|
| 149 |
+
|
| 150 |
+
`Var(x)` is computed over all elements of the held-out slice. Because
|
| 151 |
+
per-layer activation variance differs by an order of magnitude (notably low
|
| 152 |
+
at layer 23, where the residual stream is dominated by one direction near
|
| 153 |
+
the final layernorm), raw MSE is not comparable across layers — **use FVE
|
| 154 |
+
for cross-layer comparisons**.
|
| 155 |
+
|
| 156 |
+
### 3. Mean L0
|
| 157 |
+
|
| 158 |
+
Average number of non-zero entries per reconstructed token. For a TopK SAE
|
| 159 |
+
this is `k` by construction (fewer only if a selected pre-activation is exactly zero); it is reported empirically as a sanity check
|
| 160 |
+
for the stored checkpoint config.
|
| 161 |
+
|
| 162 |
+
### 4. Padding-safe model ΔCE with mean-ablation reference
|
| 163 |
+
|
| 164 |
+
The raw "splice the SAE and subtract" metric is unstable because
|
| 165 |
+
(i) padding positions are counted by `CausalLM` loss by default and
|
| 166 |
+
(ii) even on real tokens, a lossy reconstruction can sharpen the logits toward
|
| 167 |
+
high-prior tokens and artificially lower CE. Both bias the result. We fix
|
| 168 |
+
this with three changes:
|
| 169 |
+
|
| 170 |
+
1. Mask padding tokens in the loss: `labels[attention_mask == 0] = -100`
|
| 171 |
+
so padding positions do not contribute to CE in any run.
|
| 172 |
+
2. Splice only real-token positions in the forward hook:
|
| 173 |
+
`patched = where(attention_mask, sae(hidden), hidden)`. This matches the
|
| 174 |
+
training distribution, which excluded padding.
|
| 175 |
+
3. Compare against a mean-ablation arm, not the raw baseline.
|
| 176 |
+
|
| 177 |
+
Let `layer_idx` be the decoder layer we intervene on.
|
| 178 |
+
|
| 179 |
+
**Three CE losses are measured per prompt batch**:
|
| 180 |
+
- `L_baseline`: no intervention.
|
| 181 |
+
- `L_sae`: at `layer_idx`, replace real-token hidden states with the SAE
|
| 182 |
+
reconstruction. Padding is left untouched.
|
| 183 |
+
- `L_mean`: at `layer_idx`, replace real-token hidden states with a fixed
|
| 184 |
+
dataset-mean vector (estimated on 32 warm-up prompts' real-token
|
| 185 |
+
hidden states at the same layer).
|
| 186 |
+
|
| 187 |
+
The headline metric is:
|
| 188 |
+
|
| 189 |
+
```
|
| 190 |
+
frac_loss_recovered = (L_mean − L_sae) / (L_mean − L_baseline)
|
| 191 |
+
```
|
| 192 |
+
|
| 193 |
+
**Interpretation**:
|
| 194 |
+
- `1.0` = the SAE reconstruction preserves downstream CE perfectly.
|
| 195 |
+
- `0.0` = the SAE reconstruction is no better than collapsing the layer to
|
| 196 |
+
its mean vector.
|
| 197 |
+
- Values *above* 1 or *below* 0 flag measurement artifacts (e.g.
|
| 198 |
+
unintended smoothing that makes the logits *more* peaked on common
|
| 199 |
+
next-tokens than the baseline distribution).
|
| 200 |
+
|
| 201 |
+
This is the metric to cite when comparing SAEs — it is normalised (0 = mean ablation, 1 = unmodified model; values outside [0, 1] flag artifacts), interpretable,
|
| 202 |
+
and insensitive to the per-layer variance differences that inflate or
|
| 203 |
+
deflate raw MSE.
|
| 204 |
+
|
| 205 |
+
### Why two CSVs?
|
| 206 |
+
|
| 207 |
+
The first evaluation pass (see `eval_report.json` / legacy
|
| 208 |
+
`sae_eval_metrics.csv` not shipped here) used a naive splice that counted
|
| 209 |
+
padding in the loss and replaced hidden states at all positions including
|
| 210 |
+
padding. That produced *negative* ΔCE (SAE "improves" the model), which is
|
| 211 |
+
a known artifact. The numbers in the table below come from the corrected
|
| 212 |
+
padding-safe evaluation only.
|
| 213 |
+
|
| 214 |
+
---
|
| 215 |
+
|
| 216 |
+
## Evaluation results (this release)
|
| 217 |
+
|
| 218 |
+
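CE-loss columns (`L_base`, `L_sae`, `L_mean`, `frac_loss_recovered`) are computed on GSM8k test, 200 prompts; MSE, FVE, and mean L0 are computed on the held-out last 20% of cached activations.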
|
| 219 |
+
|
| 220 |
+
### Layer 6 (d_sae=7168, k=64)
|
| 221 |
+
|
| 222 |
+
| Stage | MSE | FVE | mean L0 | L_base | L_sae | L_mean | frac_loss_recovered |
|
| 223 |
+
|---------------:|---------:|-------:|--------:|-------:|------:|--------:|--------------------:|
|
| 224 |
+
| instruct_base | 0.021983 | 0.9995 | 64.00 | 2.4445 | 2.6010| 11.6041 | 0.9829 |
|
| 225 |
+
| ppo_step10 | 0.021669 | 0.9995 | 64.00 | 2.4658 | 2.6620| 11.6256 | 0.9786 |
|
| 226 |
+
| ppo_step30 | 0.022306 | 0.9995 | 64.00 | 2.6018 | 2.7759| 11.6498 | 0.9808 |
|
| 227 |
+
| ppo_step100 | 0.022328 | 0.9995 | 64.00 | 3.0376 | 3.2190| 11.7095 | 0.9791 |
|
| 228 |
+
| ppo_step140 | 0.021771 | 0.9995 | 64.00 | 3.1451 | 3.4048| 11.6460 | 0.9694 |
|
| 229 |
+
| ppo_step180 | 0.023316 | 0.9995 | 64.00 | 3.2228 | 3.4790| 11.8361 | 0.9703 |
|
| 230 |
+
| ppo_step200 | 0.028901 | 0.9994 | 64.00 | 3.2032 | 3.4738| 11.8848 | 0.9688 |
|
| 231 |
+
|
| 232 |
+
### Layer 12 (d_sae=14336, k=96)
|
| 233 |
+
|
| 234 |
+
| Stage | MSE | FVE | mean L0 | L_base | L_sae | L_mean | frac_loss_recovered |
|
| 235 |
+
|---------------:|---------:|-------:|--------:|-------:|------:|--------:|--------------------:|
|
| 236 |
+
| instruct_base | 0.031009 | 0.9993 | 96.00 | 2.4445 | 2.6575| 10.1495 | 0.9724 |
|
| 237 |
+
| ppo_step10 | 0.031459 | 0.9993 | 96.00 | 2.4658 | 2.6713| 10.1530 | 0.9733 |
|
| 238 |
+
| ppo_step30 | 0.030694 | 0.9994 | 96.00 | 2.6018 | 2.8250| 10.2270 | 0.9707 |
|
| 239 |
+
| ppo_step100 | 0.032453 | 0.9993 | 96.00 | 3.0376 | 3.4182| 10.5807 | 0.9495 |
|
| 240 |
+
| ppo_step140 | 0.037343 | 0.9992 | 96.00 | 3.1451 | 3.5767| 10.6014 | 0.9421 |
|
| 241 |
+
| ppo_step180 | 0.034286 | 0.9993 | 96.00 | 3.2228 | 3.6946| 10.7429 | 0.9373 |
|
| 242 |
+
| ppo_step200 | 0.037648 | 0.9992 | 96.00 | 3.2032 | 3.7207| 10.7912 | 0.9318 |
|
| 243 |
+
|
| 244 |
+
### Layer 18 (d_sae=14336, k=128)
|
| 245 |
+
|
| 246 |
+
| Stage | MSE | FVE | mean L0 | L_base | L_sae | L_mean | frac_loss_recovered |
|
| 247 |
+
|---------------:|---------:|-------:|--------:|-------:|------:|--------:|--------------------:|
|
| 248 |
+
| instruct_base | 0.132713 | 0.9973 | 128.00 | 2.4445 | 2.7164| 10.7944 | 0.9674 |
|
| 249 |
+
| ppo_step10 | 0.125266 | 0.9974 | 128.00 | 2.4658 | 2.7106| 10.8126 | 0.9707 |
|
| 250 |
+
| ppo_step30 | 0.131541 | 0.9973 | 128.00 | 2.6018 | 2.8919| 10.9570 | 0.9653 |
|
| 251 |
+
| ppo_step100 | 0.127065 | 0.9974 | 128.00 | 3.0376 | 3.4449| 11.2926 | 0.9507 |
|
| 252 |
+
| ppo_step140 | 0.135698 | 0.9972 | 128.00 | 3.1451 | 3.6207| 11.4038 | 0.9424 |
|
| 253 |
+
| ppo_step180 | 0.134804 | 0.9972 | 128.00 | 3.2228 | 3.6742| 11.4629 | 0.9452 |
|
| 254 |
+
| ppo_step200 | 0.128425 | 0.9973 | 128.00 | 3.2032 | 3.6708| 11.4725 | 0.9435 |
|
| 255 |
+
|
| 256 |
+
### Layer 23 (d_sae=28672, k=128)
|
| 257 |
+
|
| 258 |
+
| Stage | MSE | FVE | mean L0 | L_base | L_sae | L_mean | frac_loss_recovered |
|
| 259 |
+
|---------------:|---------:|-------:|--------:|-------:|------:|--------:|--------------------:|
|
| 260 |
+
| instruct_base | 0.440665 | 0.8560 | 128.00 | 2.4445 | 2.7846| 15.6202 | 0.9742 |
|
| 261 |
+
| ppo_step10 | 0.443813 | 0.8545 | 128.00 | 2.4658 | 2.8314| 15.9043 | 0.9728 |
|
| 262 |
+
| ppo_step30 | 0.447968 | 0.8501 | 128.00 | 2.6018 | 3.0068| 17.3308 | 0.9725 |
|
| 263 |
+
| ppo_step100 | 0.447669 | 0.8454 | 128.00 | 3.0376 | 3.5106| 20.0197 | 0.9721 |
|
| 264 |
+
| ppo_step140 | 0.441266 | 0.8461 | 128.00 | 3.1451 | 3.6356| 20.2853 | 0.9714 |
|
| 265 |
+
| ppo_step180 | 0.436758 | 0.8467 | 128.00 | 3.2228 | 3.7186| 20.6281 | 0.9715 |
|
| 266 |
+
| ppo_step200 | 0.430823 | 0.8482 | 128.00 | 3.2032 | 3.6975| 20.4202 | 0.9713 |
|
| 267 |
+
|
| 268 |
+
---
|
| 269 |
+
|
| 270 |
+
## Caveats for downstream users
|
| 271 |
+
|
| 272 |
+
- **Do not judge layer 23 by its FVE.** Layer 23's FVE of ~0.85 looks much
|
| 273 |
+
worse than layers 6/12/18 (all >0.997), but its `frac_loss_recovered` is
|
| 274 |
+
~0.97 — comparable to layer 6 and *better* than layers 12 and 18 at late
|
| 275 |
+
PPO stages. The low FVE reflects layer 23's unusual activation geometry
|
| 276 |
+
(very low variance around a dominant direction near the final layernorm),
|
| 277 |
+
so most of the missing variance is noise that does not flow into the
|
| 278 |
+
logits. `frac_loss_recovered` is the metric to trust for downstream
|
| 279 |
+
usability.
|
| 280 |
+
- **`frac_loss_recovered` degrades across PPO stages for mid-network layers.**
|
| 281 |
+
Layer 12: 0.972 → 0.932. Layer 18: 0.967 → 0.944. Layers 6 and 23 are
|
| 282 |
+
roughly flat. If you are running feature analyses that compare early vs.
|
| 283 |
+
late PPO checkpoints at layers 12/18, expect higher reconstruction noise
|
| 284 |
+
at later stages. This is a likely interpretability signal (mid-network
|
| 285 |
+
features restructured by RL), not a training artifact.
|
| 286 |
+
- **`L_baseline` climbs with PPO steps** (2.44 → 3.20). The PPO model is
|
| 287 |
+
drifting from the GSM8k prompt-LM distribution as expected for PPO
|
| 288 |
+
without a KL anchor. Keep this in mind when comparing raw CE across
|
| 289 |
+
stages.
|
| 290 |
+
- SAEs were trained only on real tokens. Do **not** splice the SAE over
|
| 291 |
+
padding positions when using it at inference — replicate the
|
| 292 |
+
`where(attention_mask, sae(x), x)` pattern from the eval script.
|
| 293 |
+
|
| 294 |
+
---
|
| 295 |
+
|
| 296 |
+
## Reproducing
|
| 297 |
+
|
| 298 |
+
Scripts live in the training repo:
|
| 299 |
+
- `04_collect_activations.py`: cache per-layer residual-stream activations.
|
| 300 |
+
- `05_train_sae.py`: train one TopK SAE per activation file.
|
| 301 |
+
- `07_maskeval_sae_metrics.py`: run the padding-safe evaluation with
|
| 302 |
+
mean-ablation reference used to produce the numbers above.
|
| 303 |
+
|
| 304 |
+
## License
|
| 305 |
+
|
| 306 |
+
Apache-2.0, matching the base model.
|
layer12/sae_instruct_base_layer12.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5d7a114279329ef46f161cbd34b377856d960dda43f224a3d65d641e5f26dc75
|
| 3 |
+
size 102828191
|
layer12/sae_ppo_step100_layer12.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8c2a9bc67e5e40aa33ae88ce9bdfce1e704f78b56a4a4bc5085d9a66d3df6eca
|
| 3 |
+
size 102828105
|
layer12/sae_ppo_step10_layer12.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:30782b140234d172811f3a83198646f277cc9c12c3b3a35badd0ae4bd2167012
|
| 3 |
+
size 102828094
|
layer12/sae_ppo_step140_layer12.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:dc1e15b4027969eb20597c7224bbcd09275b6d2c6a68f1d66727709b91764c05
|
| 3 |
+
size 102828105
|
layer12/sae_ppo_step180_layer12.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ba4986cd631cee9bfea3808e9dba5111664d6ef9e3928aae4eaece0b3d3e148b
|
| 3 |
+
size 102828105
|
layer12/sae_ppo_step200_layer12.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1ec19ddd6914b1baafdb109bbe3b5c03d8ade30764a19cd795b22c3fcb294c41
|
| 3 |
+
size 102828105
|
layer12/sae_ppo_step30_layer12.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9726cad1804fcfc0598e6022a124d295ab86bc5f2776615ec0b640dedb9dc6dc
|
| 3 |
+
size 102828094
|
layer18/sae_instruct_base_layer18.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e87177dbfe03aea89d8f794959ded884e9d30fad81c8c0d12fe9df56047ba02f
|
| 3 |
+
size 102828191
|
layer18/sae_ppo_step100_layer18.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2ac615714831296d73983e7b3bf3934207a5624e7fadec11f84456c4d48fbef4
|
| 3 |
+
size 102828105
|
layer18/sae_ppo_step10_layer18.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:324f27f862651df1e16ffd0878fc676313290b9c2dc4c9d5467706e6cc469af4
|
| 3 |
+
size 102828094
|
layer18/sae_ppo_step140_layer18.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c4c9fa0d943324e33801b269ffce90f22aa742c22500141c74d92e9632fd6a6e
|
| 3 |
+
size 102828105
|
layer18/sae_ppo_step180_layer18.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1f5090c735eee8cb8e374db9e8c7d668e9574991594271e7417436f2c2aca724
|
| 3 |
+
size 102828105
|
layer18/sae_ppo_step200_layer18.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:919a5e991ef15813ec36c1b1c38b97a8679d167c5547058ba247ac44326b0657
|
| 3 |
+
size 102828105
|
layer18/sae_ppo_step30_layer18.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:45e49be205263c3eae4d5042c4516bcfff351c1ee43b6ff3009c305ea15e3100
|
| 3 |
+
size 102828094
|
layer23/sae_instruct_base_layer23.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:77dca5a3336a1565741658fdd53cf2539c1dcb1919ffb1372f8909815a9fb275
|
| 3 |
+
size 205645983
|
layer23/sae_ppo_step100_layer23.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:22505d691fca236bf7e03e72416be6f3eb4fc592dbff314019f4faf082277385
|
| 3 |
+
size 205645897
|
layer23/sae_ppo_step10_layer23.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:893570dc800833388cd1c9c1c8fbf6acca242a09a4f1d3cf6329ad94db95e6c9
|
| 3 |
+
size 205645886
|
layer23/sae_ppo_step140_layer23.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fa6f7a73ca08b2c744d65aac50beff0c41337c6a046fd264c35f3fe23cd89c9d
|
| 3 |
+
size 205645897
|
layer23/sae_ppo_step180_layer23.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b5ddbdc64da6ab7c3c327380998841c30a05a820b9c86665bf6504e02abb9c93
|
| 3 |
+
size 205645897
|
layer23/sae_ppo_step200_layer23.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2fe8cb7b3455ef1a9b6263ed93b56e22d448b7de1ed73c6e16feede4585fa92a
|
| 3 |
+
size 205645897
|
layer23/sae_ppo_step30_layer23.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1abb910db4e376c0a1e0dba401cbf82012970bcecaf60f0edff63519f81cde4a
|
| 3 |
+
size 205645886
|
layer6/sae_instruct_base_layer6.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:60d32c83a744c76b6c00bb903dd6500d9e8ca5ef37bd785a36b7e9e43316f927
|
| 3 |
+
size 51419220
|
layer6/sae_ppo_step100_layer6.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4e71797aeeb1980351616a737e0b52cc68bafd9c9f166113867a4aadbf2275a2
|
| 3 |
+
size 51419198
|
layer6/sae_ppo_step10_layer6.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0458a2ee3d552a92238d243da34e52ca98238f2a6f2f3df48b78fd82e6afa8c5
|
| 3 |
+
size 51419123
|
layer6/sae_ppo_step140_layer6.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5c018dfc57813768e4b221a3aca38dd3c62718703e8222dfd6c8eed78022991e
|
| 3 |
+
size 51419198
|
layer6/sae_ppo_step180_layer6.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7c6ed0d416a186549eeba0ab2b96541c9f7b6cc05419e52fb7a1c0dff255ac35
|
| 3 |
+
size 51419198
|
layer6/sae_ppo_step200_layer6.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:929a7eafae57602446acbd98dbf3a77e700fac670a0e9450dc58b7a85dbd892d
|
| 3 |
+
size 51419198
|
layer6/sae_ppo_step30_layer6.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fed47f4fca909b40675bf2f7d56cd64e94b507041192a7bdb2d62e31a50ec4cb
|
| 3 |
+
size 51419123
|
loader.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Minimal loader for the TopK SAEs in this repository.
|
| 2 |
+
|
| 3 |
+
Usage:
|
| 4 |
+
from loader import load_sae
|
| 5 |
+
sae, cfg = load_sae("layer6/sae_instruct_base_layer6.pt", device="cuda")
|
| 6 |
+
x_hat, z = sae(x) # x: (N, d_model=896)
|
| 7 |
+
|
| 8 |
+
The `sae(x)` forward returns:
|
| 9 |
+
x_hat: (N, d_model) reconstruction
|
| 10 |
+
z_sparse: (N, d_sae) sparse code, exactly `k` non-zeros per row
|
| 11 |
+
|
| 12 |
+
When splicing into the base model's residual stream, only replace
|
| 13 |
+
real-token positions (see README for the rationale):
|
| 14 |
+
patched = torch.where(mask.unsqueeze(-1).bool(), sae(h)[0], h)
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class TopKSAE(nn.Module):
    """TopK sparse autoencoder over ``d_model``-dimensional activations.

    The encoder maps the centred input ``x - b_pre`` to ``d_sae``
    pre-activations; only the ``k`` largest entries along the last dimension
    are kept and the rest are zeroed. The decoder maps that sparse code back
    to ``d_model``.

    Attributes:
        k: Number of entries kept per row by the top-k selection.
        d_model: Width of the input / reconstruction.
        d_sae: Width of the sparse code.
        b_pre: Learnable centring vector subtracted before encoding
            (initialised to zeros here; real values come from a checkpoint).
        encoder: ``nn.Linear(d_model, d_sae)`` with bias.
        decoder: ``nn.Linear(d_sae, d_model)`` with bias.
    """

    def __init__(self, d_model: int, d_sae: int, k: int):
        """Build the module.

        Raises:
            ValueError: If ``k`` is negative or larger than ``d_sae``. Such a
                value would otherwise only fail later, inside ``torch.topk``,
                with a much less helpful message.
        """
        super().__init__()
        if not 0 <= k <= d_sae:
            raise ValueError(
                f"k must satisfy 0 <= k <= d_sae; got k={k}, d_sae={d_sae}"
            )
        self.k = k
        self.d_model = d_model
        self.d_sae = d_sae
        self.b_pre = nn.Parameter(torch.zeros(d_model))
        self.encoder = nn.Linear(d_model, d_sae, bias=True)
        self.decoder = nn.Linear(d_sae, d_model, bias=True)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Return the sparse code for ``x`` (last dim must be ``d_model``).

        The output has the same leading dimensions as ``x`` and a last
        dimension of ``d_sae``. No non-linearity is applied, so the kept
        values are the raw top-k pre-activations and may be negative.
        """
        z = self.encoder(x - self.b_pre)
        topk_values, topk_indices = torch.topk(z, self.k, dim=-1)
        # Scatter the selected values into an all-zero tensor so every
        # non-selected position is exactly 0.
        z_sparse = torch.zeros_like(z)
        z_sparse.scatter_(-1, topk_indices, topk_values)
        return z_sparse

    def forward(self, x: torch.Tensor):
        """Return ``(x_hat, z_sparse)``: the reconstruction and the sparse code."""
        z_sparse = self.encode(x)
        return self.decoder(z_sparse), z_sparse
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def load_sae(path: str, device: str = "cpu"):
    """Load a checkpoint saved by the training pipeline.

    Checkpoint format:
        {"state_dict": ..., "config": {"d_model", "d_sae", "k", "source"}}

    Args:
        path: Path to a ``.pt`` checkpoint file.
        device: Device passed to ``torch.load(map_location=...)`` and to which
            the module is moved.

    Returns:
        ``(sae, cfg)`` where ``sae`` is a ``TopKSAE`` in eval mode on
        ``device`` and ``cfg`` is the checkpoint's ``"config"`` dict.

    Raises:
        KeyError: If the checkpoint lacks ``"config"``/``"state_dict"`` or the
            config lacks ``"d_model"``, ``"d_sae"`` or ``"k"``.
    """
    import warnings

    # SECURITY NOTE(review): weights_only=False runs the full pickle
    # unpickler, which can execute arbitrary code. Only load checkpoints from
    # a source you trust. Switching to weights_only=True is preferable if the
    # stored "config" contains only plain Python/tensor types — TODO confirm.
    ckpt = torch.load(path, map_location=device, weights_only=False)
    cfg = ckpt["config"]
    sae = TopKSAE(cfg["d_model"], cfg["d_sae"], cfg["k"])

    # strict=False is kept so loading stays best-effort, but a mismatch means
    # some parameters keep their random initialisation — surface that instead
    # of silently returning a partially initialised model.
    load_result = sae.load_state_dict(ckpt["state_dict"], strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        warnings.warn(
            f"Checkpoint {path!r} does not fully match TopKSAE: "
            f"missing keys {list(load_result.missing_keys)}, "
            f"unexpected keys {list(load_result.unexpected_keys)}",
            stacklevel=2,
        )
    return sae.to(device).eval(), cfg
|