aniketdesh committed · verified · Commit 246fdbe · Parent: e6d3b08

Upload README.md with huggingface_hub

Files changed (1): README.md (+120 −0)

README.md ADDED
---
license: mit
language:
- en
tags:
- mechanistic-interpretability
- sparse-autoencoder
- temporal-crosscoder
- reasoning
- backtracking
- llama
---

# Ward 2025 Stage B — trained dictionaries (TXC, SAE, TSAE)

This repository contains 13 curated dictionary checkpoints from the
Stage B paper-budget reproduction of Ward et al. 2025. Each checkpoint
is a sparse dictionary trained on Llama-3.1-8B residual, attention, or
pre-LN activations at layer 10, and is used to steer
DeepSeek-R1-Distill-Llama-8B into emitting backtracking tokens.

Companion to:
- **Code:** [chainik1125/temp_xc, branch `aniket-ward-stage-b`](https://github.com/chainik1125/temp_xc/tree/aniket-ward-stage-b)
- **Activation cache + B1 results:** [aniketdesh/ward-stage-b-cache](https://huggingface.co/datasets/aniketdesh/ward-stage-b-cache)
- **Writeup:** [`results_b.md`](https://github.com/chainik1125/temp_xc/blob/aniket-ward-stage-b/docs/aniket/experiments/ward_backtracking/results_b.md), [`results_b_behavioral.md`](https://github.com/chainik1125/temp_xc/blob/aniket-ward-stage-b/docs/aniket/experiments/ward_backtracking/results_b_behavioral.md)

## Checkpoints

```text
checkpoints/
  txc__resid_L10__k16__s42.pt         # B1 Sonnet primary winner (s42)
  txc__resid_L10__k16__s7.pt          # multi-seed verification
  txc__resid_L10__k16__s11.pt
  txc__resid_L10__k16__s23.pt
  txc_h13__resid_L10__k16__s42.pt     # Han matryoshka × MD contrastive
  txc_h13__resid_L10__k16__s7.pt
  txc_h13__resid_L10__k16__s11.pt
  txc_h13__resid_L10__k16__s23.pt
  txc_h8__resid_L10__k16__s42.pt      # Han multi-distance contrastive
  topk_sae__ln1_L10__k64__s42.pt      # best non-TXC SAE under Sonnet
  stacked_sae__resid_L10__k16__s42.pt
  tsae__resid_L10__k32__s42.pt        # Han's TSAE (TopK variant)
  tsae_paper__resid_L10__k32__s42.pt  # Bhalla 2025 paper-faithful

architectures.py   # build_arch / arch_forward dispatch
cell_id.py         # cell-id parser/serializer
config.yaml        # arch_kwargs + training config
```
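
To mirror the whole layout above in one call (all checkpoints plus the
three helper files), `huggingface_hub.snapshot_download` works; a
minimal sketch, where the `local_dir` target is an arbitrary choice:

```python
from huggingface_hub import snapshot_download

# Pulls checkpoints/ plus architectures.py, cell_id.py, config.yaml.
local_path = snapshot_download(
    repo_id="aniketdesh/ward-stage-b-dictionaries",
    local_dir="ward-stage-b",
)
print(local_path)
```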

## Loading

```python
import torch
import yaml
from huggingface_hub import hf_hub_download

REPO = "aniketdesh/ward-stage-b-dictionaries"

# architectures.py and cell_id.py are bundled in this repo; fetch them
# into the working directory first so the imports below resolve.
for fname in ("architectures.py", "cell_id.py"):
    hf_hub_download(repo_id=REPO, filename=fname, local_dir=".")

from architectures import build_arch
from cell_id import Cell

cell_id = "txc__resid_L10__k16__s42"
ckpt_path = hf_hub_download(repo_id=REPO, filename=f"checkpoints/{cell_id}.pt")
config_path = hf_hub_download(repo_id=REPO, filename="config.yaml")
with open(config_path) as f:
    cfg = yaml.safe_load(f)

# Rebuild the architecture from the training config, then load weights.
cell = Cell.from_id(cell_id)
arch_kw = cfg["txc"].get("arch_kwargs", {}).get(cell.arch, {})
model = build_arch(
    arch=cell.arch,
    d_in=cfg["txc"]["d_model"],
    d_sae=cfg["txc"]["d_sae"],
    T=cfg["txc"]["T"],
    k=cell.k_per_position,
    **arch_kw,
)
state = torch.load(ckpt_path, map_location="cpu", weights_only=False)
model.load_state_dict(state["state_dict"])
model.eval()
```
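
As a quick smoke test after loading, a parameter count touches only
generic `nn.Module` state, so it works for every architecture in this
repo regardless of its internals (the printed numbers vary by cell and
are not from the source):

```python
# Smoke test: the dictionary should behave as a plain nn.Module.
n_params = sum(p.numel() for p in model.parameters())
dtypes = {p.dtype for p in model.parameters()}
print(f"{cell_id}: {n_params / 1e6:.1f}M params, dtypes={dtypes}")
```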

## Cell ID convention

`<arch>__<hookpoint>__k<k>__s<seed>` (parsed in the sketch after the field list)

- **arch**: `txc` (TemporalCrosscoder), `txc_h13` / `txc_h8` (Han's contrastive variants), `topk_sae` (per-position TopK), `stacked_sae` (matryoshka H/L recon), `tsae` (Han's TemporalSAE with TopK), `tsae_paper` (Bhalla 2025 ReLU+L1).
- **hookpoint**: `resid_L10` (Ward's layer-10 residual), `attn_L10`, `ln1_L10` (pre-LN, captured via a forward pre-hook).
- **k_per_position**: TopK target per offset slot, so the window-level L0 is k × T = k × 6.
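
Since no component contains a double underscore, a dependency-free
parse is a single split; a minimal sketch (the bundled `Cell.from_id`
in `cell_id.py` is the canonical parser, and its field names may
differ):

```python
cell_id = "txc_h13__resid_L10__k16__s7"
arch, hookpoint, k_tag, seed_tag = cell_id.split("__")
k = int(k_tag.removeprefix("k"))        # 16, TopK per offset slot
seed = int(seed_tag.removeprefix("s"))  # 7, training seed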
```

## Steering vector extraction

Mined feature decoder rows are in the [companion dataset](https://huggingface.co/datasets/aniketdesh/ward-stage-b-cache)
under `features/<cell_id>.npz`:

```python
import numpy as np

z = np.load("features/txc__resid_L10__k16__s42.npz", allow_pickle=True)
top_features = z["top_features"]    # (k_for_steering,) ranked by D+/D-
decoder_pos0 = z["decoder_at_pos0"] # (k, d_model) — single-T-slot direction
decoder_union = z["decoder_union"]  # (k, d_model) — averaged across T slots
```

For the headline cell (`txc__resid_L10__k16__s42`), the winning steering
direction is `decoder_at_pos0[idx]` where `top_features[idx] == 14621`.
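
Concretely, recovering that direction from the arrays loaded above (a
sketch; `steering_vector` is a name introduced here, not from the
source):

```python
import numpy as np

z = np.load("features/txc__resid_L10__k16__s42.npz", allow_pickle=True)
idx = int(np.flatnonzero(z["top_features"] == 14621)[0])
steering_vector = z["decoder_at_pos0"][idx]  # shape: (d_model,)
```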

## Caveat

The steering directions induce backtracking *text behavior* (a Sonnet
4.6 behavioral judge confirms that ~93% of keyword tokens reflect
genuine text-level course-corrections), but they do **not** improve
MATH-500 answer correctness — see the B3 results in
`results_b_behavioral.md`. Treat these checkpoints as research
artifacts for studying the linguistic surface form of induced
backtracking, not as a tool for boosting reasoning performance.