SAE Γ— RL: Qwen2.5-0.5B on GSM8k (multi-condition)

Warm-start TopK SAEs trained on residual-stream activations of Qwen2.5-0.5B-Instruct across three PPO training conditions. The base model, training data, and SAE protocol are held fixed across all three; only the reward function and the KL coefficient vary. Purpose: measure what RL changes inside the model, and which of those changes depend on the reward signal, the optimization budget (KL), or the layer depth.

Three conditions

Chain     Reward                               KL coef  PPO stages saved                                   Final PPO step
flexible  last-number-correct (custom)         0.005    instruct_base, ppo_step{10,30,60,100,140,180,200}  200
strict    verl built-in GSM8k (#### N format)  0.005    instruct_base, ppo_step{10,30,50,80,116}           116
kl0p025   last-number-correct (custom)         0.025    instruct_base, ppo_step{10,30}                     58

Two A/B comparisons run on top of these:

  • flexible vs strict β€” same KL coef, different reward shape.
  • flexible vs kl0p025 β€” same reward, 5Γ— the KL coefficient.

What each chain tests

flexible is the reference run. The reward is permissive β€” credit is given if the last number anywhere in the response matches the answer, regardless of formatting. Baseline accuracy β‰ˆ 35%; the policy reaches β‰ˆ 70% by step 200. This chain is the longest (8 stages out to step 200) and is the canonical drift trajectory; everything else is compared to it. L23 was retrained at K=256 (the other layers stay at K=64) because L23's NMSE was unacceptably high at K=64; this is the only place the SAE protocol differs from the other chains.
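
For concreteness, a last-number-correct scorer can be as small as the sketch below. This is illustrative only: the actual reward_gsm8k_flexible.py is not reproduced here, and the two-argument signature is an assumption (verl's custom-reward hook passes additional arguments such as the data source).

import re

def compute_score(solution_str: str, ground_truth: str) -> float:
    # Credit iff the last number appearing anywhere in the response matches
    # the reference answer, regardless of formatting.
    numbers = re.findall(r"-?\d+(?:\.\d+)?", solution_str.replace(",", ""))
    if not numbers:
        return 0.0
    try:
        return 1.0 if float(numbers[-1]) == float(ground_truth) else 0.0
    except ValueError:
        return 0.0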

strict is the first A/B partner. The reward function is replaced with verl's built-in GSM8k scorer, which only credits responses ending in the canonical #### N format. Baseline β‰ˆ 14% (vs flexible's 35%): the same model on the same prompts satisfies this reward about 2.5Γ— less often, so the policy must change more aggressively in early training to reach nontrivial reward. The KL coefficient is held at 0.005 to match flexible. This chain isolates what changes when the reward is harder while the KL budget is identical. Trained to step 116 (further runs were deprioritized once the directional findings reproduced).

kl0p025 is the second A/B partner. The reward function is identical to flexible, but the KL coefficient is 5Γ— larger (0.025 vs 0.005). The KL term constrains how far the policy can move from the reference per step, so a larger coefficient is a stronger leash. This chain isolates what changes when the optimization budget tightens while the reward is identical. Trained to step 58; SAEs at instruct_base, ppo_step10, ppo_step30 are sufficient to compute a per-step drift rate for comparison against flexible.

The three-chain design lets each measured effect be classified: a phenomenon that holds across all three is structural (a property of PPO + this base model + this dataset, independent of the specific reward and KL settings); one that flips between flexible and strict is reward-driven; one that flips between flexible and kl0p025 is KL-budget-driven.

Repository layout

sae_flexible/                       canonical flexible chain
    sae_{stage}_layer{6,12,18,23}.pt
sae_strict/
    k64/                            full strict chain at k=64 across all layers
        sae_{stage}_layer{6,12,18,23}.pt
    l23_k256/                       L23 retrained at k=256 (matches flexible L23 capacity)
        sae_{stage}_layer23.pt
sae_kl0p025/                        5x KL-coef sweep
    sae_{stage}_layer{6,12,18,23}.pt
results/                            eval CSVs and analysis outputs (per-chain)
loader.py                           convenience SAE loader
README.md

Each stage's SAE is warm-start initialised from the previous stage's weights, so feature indices align across checkpoints within a chain. This is what makes per-feature decoder-cosine drift a meaningful quantity.
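
Concretely, per-feature drift between adjacent stages reduces to a row-wise cosine over decoder matrices. A minimal sketch, assuming the loaded SAE object exposes its decoder as a (d_sae, d_model) tensor named W_dec (the attribute name is an assumption; check the loader):

import torch.nn.functional as F
from loader import load_chain

sae_a, _ = load_chain("flexible", "ppo_step100", 12, root=".")
sae_b, _ = load_chain("flexible", "ppo_step140", 12, root=".")

# Warm-start init keeps row i of the decoder pointing at "the same" feature
# across stages, so a row-wise cosine is meaningful.
cos = F.cosine_similarity(sae_a.W_dec, sae_b.W_dec, dim=-1)   # (d_sae,)
drift = 1.0 - cos                                             # 0 = direction unchanged
print(f"mean drift {drift.mean().item():.4f}, max {drift.max().item():.4f}")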

Hyperparameters

Common across all chains (PPO)

Base model                    Qwen/Qwen2.5-0.5B-Instruct
Advantage estimator           GAE
Train / mini / micro batch    256 / 64 / 8
Actor LR / Critic LR          1e-6 / 1e-5
Entropy coeff                 0.001
use_kl_loss                   False (KL is in reward, not loss)
Rollout                       vLLM, n=8, temperature=1.0
Max prompt / response length  512 / 512
save_freq / test_freq         10 / 5
GPUs                          2

Per-chain (PPO)

Chain     Reward function                           KL coef  Total epochs
flexible  reward_gsm8k_flexible.py (compute_score)  0.005    4
strict    verl built-in openai/gsm8k scorer         0.005    unset (ran to step 116)
kl0p025   reward_gsm8k_flexible.py (compute_score)  0.025    2

SAE (TopK with pre-encoder centering, common across all chains)

Architecture         TopK SAE + b_pre (pre-encoder bias, init = data mean)
d_model / d_sae      896 / 7168
Expansion factor     8Γ—
Epochs / LR / Batch  20 / 1e-4 / 512
Optimizer            Adam
LR schedule          cosine annealing β†’ lr/10
Gradient clip        max_norm=1.0
Dead resample        every 10 epochs, threshold 1e-4
Aux-k loss coeff     1/32
Decoder constraint   unit-norm projection after every step
Warm-start init      previous stage's SAE weights
Train / val split    80 / 20 random shuffle, seed=0
Saved checkpoint     best-epoch val MSE
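
Put together, the configuration above implies a forward pass like the following sketch (parameter names are illustrative assumptions; each shipped checkpoint's stored config records the authoritative shapes):

import torch
import torch.nn as nn

class TopKSAE(nn.Module):
    def __init__(self, d_model: int = 896, d_sae: int = 7168, k: int = 64):
        super().__init__()
        self.k = k
        self.b_pre = nn.Parameter(torch.zeros(d_model))          # init = data mean during training
        self.W_enc = nn.Parameter(torch.randn(d_sae, d_model) * d_model ** -0.5)
        self.b_enc = nn.Parameter(torch.zeros(d_sae))
        self.W_dec = nn.Parameter(torch.randn(d_sae, d_model))   # rows kept unit-norm (see below)

    def forward(self, x):                                        # x: (..., d_model)
        pre = (x - self.b_pre) @ self.W_enc.T + self.b_enc       # center, then encode
        vals, idx = pre.topk(self.k, dim=-1)                     # keep the k largest pre-activations
        z = torch.zeros_like(pre).scatter_(-1, idx, vals)        # sparse code, (..., d_sae)
        x_hat = z @ self.W_dec + self.b_pre                      # decode and re-center
        return x_hat, z

    @torch.no_grad()
    def project_decoder(self):                                   # "unit-norm projection after every step"
        self.W_dec /= self.W_dec.norm(dim=-1, keepdim=True)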

K (top-k sparsity) per layer per chain

Chain                 L6   L12  L18  L23
sae_flexible/         64   64   64   256
sae_strict/k64/       64   64   64   64
sae_strict/l23_k256/  β€”    β€”    β€”    256
sae_kl0p025/          64   64   64   64

L23 is at K=256 in sae_flexible/ and sae_strict/l23_k256/; these are the two chains intended for apples-to-apples L23 comparison. The sae_strict/k64/ L23 files are kept for completeness, but their L23 NMSE is much higher (0.28–0.29 vs 0.15–0.18 at K=256) and they should not be used for cross-chain L23 analysis.

Activation collection (common)

Layers                     6, 12, 18, 23
Max sequence length        512
Batch size                 16
Tokens per (stage, layer)  2,000,000 (β‰ˆ 445,744 token positions)
Dataset                    GSM8k train split
Train / val split          356,596 / 89,148 rows, random shuffle, seed=0
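
The collection itself needs nothing beyond a forward hook on the target layer. A sketch, assuming the stock Hugging Face Qwen2 module layout (model.model.layers[i]); the repo's actual collection script is not reproduced here:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

name = "Qwen/Qwen2.5-0.5B-Instruct"
tok = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name).eval()

captured = {}
def grab(_module, _inputs, output):
    captured["h"] = output[0].detach()   # decoder layers return a tuple; hidden states come first

handle = model.model.layers[12].register_forward_hook(grab)
batch = tok(["Natalia sold clips to 48 of her friends."], return_tensors="pt",
            padding=True, truncation=True, max_length=512)
with torch.no_grad():
    model(**batch)
handle.remove()

acts = captured["h"][batch["attention_mask"].bool()]   # real-token positions only, (n, 896)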

Quality (per-chain, NMSE and CE-loss-recovery)

NMSE = MSE / Var(x), pooled over the val activations (89,148 rows per (stage, layer)). frac_rec = (L_mean βˆ’ L_sae) / (L_mean βˆ’ L_base), computed on 100 GSM8k test prompts: 1 means the SAE recovers all the CE loss a mean ablation would have lost; 0 means no better than the mean ablation. Both metrics use real-token positions only; padding positions pass through unmodified.
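
The NMSE side is cheap to recompute. A minimal sketch matching the definition above (pooled over every element of the val activations):

import torch

def nmse(x: torch.Tensor, x_hat: torch.Tensor) -> float:
    # x: (89148, 896) val activations for one (stage, layer); x_hat, _ = sae(x)
    return (((x - x_hat) ** 2).mean() / x.var()).item()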

flexible chain (kl=0.005, 8 stages)

Layer  K    NMSE range       frac_rec range
6      64   0.0004 – 0.0005  0.981 – 0.989
12     64   0.0007 – 0.0009  0.960 – 0.966
18     64   0.0034 – 0.0038  0.937 – 0.965
23     256  0.149 – 0.177    0.979 – 0.986

strict chain (kl=0.005, 6 stages)

sae_strict/k64/:

Layer  K   NMSE range       frac_rec range
6      64  0.0003 – 0.0005  0.985 – 0.989
12     64  0.0007 – 0.0008  0.961 – 0.972
18     64  0.0032 – 0.0036  0.950 – 0.967
23     64  0.281 – 0.290    0.946 – 0.957

sae_strict/l23_k256/ (L23 only):

Layer  K    NMSE range       frac_rec range
23     256  0.152 – 0.175    0.979 – 0.986

kl0p025 chain (kl=0.025, 3 stages)

Drift analysis is available (scripts/analyze_kl_sweep.py); full eval_sae.py quality numbers are pending. Updates to follow.

Cross-chain trends

Per-step decoder-cosine drift yields three observations that hold across both reward shapes at matched KL=0.005 (one candidate effective-count estimator is sketched after this list):

  β€’ L23 effective feature count collapses during PPO (5868 β†’ 4405 flexible, 5855 β†’ 4665 strict). The feature distribution concentrates onto fewer directions while every feature stays alive.
  β€’ L6/L12/L18 effective feature count expands during PPO (+30% to +58%). The feature distribution diffuses across more directions.
  β€’ L18 encoder/decoder cosines decouple at the cold-start transition (corr β‰ˆ βˆ’0.14 flexible, βˆ’0.10 strict). Decoder directions are preserved while encoder selectivity is rewired: features "die" when their old inputs no longer fire them and are "born" when encoder rewiring routes new inputs to existing decoder directions.

Across KL coefficients (flexible vs kl0p025), the prediction that per-step drift rate scales linearly with the KL coefficient is falsified: observed kl0p025/flexible ratios at the 0β†’10 and 10β†’30 transitions cluster in 0.6 – 2.3, never near the predicted 5.0. The drift response is layer-specific (L6 suppressed by the stronger KL, L12 amplified, L18 invariant), not set uniformly by the budget.

Loading SAEs

from loader import load_sae, load_chain

# Explicit path
sae, cfg = load_sae("sae_flexible/sae_ppo_step100_layer12.pt", device="cuda")

# Chain / stage / layer (convenience)
sae, cfg = load_chain("flexible", "ppo_step100", 12, root=".", device="cuda")
sae, cfg = load_chain("strict",   "ppo_step116", 23, k=256, root=".")  # β†’ sae_strict/l23_k256/
sae, cfg = load_chain("strict",   "ppo_step116", 23, k=64,  root=".")  # β†’ sae_strict/k64/
sae, cfg = load_chain("kl0p025",  "ppo_step30",  18, root=".")

# Forward
x_hat, z = sae(x)   # x: (N, 896)

When splicing SAE reconstruction into the residual stream of the base model for evaluation, replace only real-token positions:

# h: residual-stream activations (B, T, 896); mask: attention mask (B, T)
patched = torch.where(mask.unsqueeze(-1).bool(), sae(h)[0], h)
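
Wired into a forward hook for end-to-end CE evaluation, the splice looks roughly like this (a sketch assuming the stock Qwen2 module layout; model and batch come from the usual transformers loading path, as in the collection sketch above):

import torch

def make_splice_hook(sae, mask):
    # mask: (B, T) attention mask of the batch being evaluated
    def hook(_module, _inputs, output):
        h = output[0]                                  # (B, T, 896) residual stream
        x_hat, _ = sae(h)
        patched = torch.where(mask.unsqueeze(-1).bool(), x_hat, h)
        return (patched,) + tuple(output[1:])          # returned value replaces the layer output
    return hook

handle = model.model.layers[12].register_forward_hook(
    make_splice_hook(sae, batch["attention_mask"]))
with torch.no_grad():
    out = model(**batch, labels=batch["input_ids"])    # out.loss = CE under the splice (L_sae)
handle.remove()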

Files

sae_{stage}_layer{N}.pt β€” SAE state_dict + config (d_model, d_sae, k, source, source_kind, seed, init_from_stage, best_epoch, best_val_loss, selection_metric).

Base model

Qwen/Qwen2.5-0.5B-Instruct β€” d_model=896, 24 transformer layers.
