# Ward 2025 Stage B: trained dictionaries (TXC, SAE, TSAE)
13 curated dictionary checkpoints from the Stage B paper-budget reproduction of Ward et al. 2025. Each checkpoint is a sparse dictionary trained on Llama-3.1-8B residual / attention / pre-LN activations at layer 10, used to steer DeepSeek-R1-Distill-Llama-8B into emitting backtracking tokens.
Companion to:

- Code: `chainik1125/temp_xc`, branch `aniket-ward-stage-b`
- Activation cache + B1 results: `aniketdesh/ward-stage-b-cache`
- Writeup: `results_b.md`, `results_b_behavioral.md`
## Checkpoints

```text
checkpoints/
  txc__resid_L10__k16__s42.pt          # B1 Sonnet primary winner (s42)
  txc__resid_L10__k16__s7.pt           # multi-seed verification
  txc__resid_L10__k16__s11.pt
  txc__resid_L10__k16__s23.pt
  txc_h13__resid_L10__k16__s42.pt      # Han matryoshka × MD contrastive
  txc_h13__resid_L10__k16__s7.pt
  txc_h13__resid_L10__k16__s11.pt
  txc_h13__resid_L10__k16__s23.pt
  txc_h8__resid_L10__k16__s42.pt       # Han multi-distance contrastive
  topk_sae__ln1_L10__k64__s42.pt       # best non-TXC SAE under Sonnet
  stacked_sae__resid_L10__k16__s42.pt
  tsae__resid_L10__k32__s42.pt         # Han's TSAE (TopK variant)
  tsae_paper__resid_L10__k32__s42.pt   # Bhalla 2025 paper-faithful
architectures.py                       # build_arch / arch_forward dispatch
cell_id.py                             # cell-id parser/serializer
config.yaml                            # arch_kwargs + training config
```
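To see which cells are present without cloning the repo, the checkpoint list can be pulled straight from the Hub. A minimal sketch using `huggingface_hub.list_repo_files` (the filtering logic and the `cell_ids` name are just illustration):

```python
from huggingface_hub import list_repo_files

# Enumerate every bundled dictionary checkpoint without downloading anything.
files = list_repo_files("aniketdesh/ward-stage-b-dictionaries")
cell_ids = sorted(
    f.removeprefix("checkpoints/").removesuffix(".pt")
    for f in files
    if f.startswith("checkpoints/") and f.endswith(".pt")
)
print(len(cell_ids), "checkpoints")  # expected: 13
```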
## Loading

```python
import torch
import yaml
from huggingface_hub import hf_hub_download

from architectures import build_arch  # also bundled in this repo
from cell_id import Cell

cell_id = "txc__resid_L10__k16__s42"

ckpt_path = hf_hub_download(
    repo_id="aniketdesh/ward-stage-b-dictionaries",
    filename=f"checkpoints/{cell_id}.pt",
)
config_path = hf_hub_download(
    repo_id="aniketdesh/ward-stage-b-dictionaries",
    filename="config.yaml",
)

with open(config_path) as f:
    cfg = yaml.safe_load(f)

cell = Cell.from_id(cell_id)
arch_kw = cfg["txc"].get("arch_kwargs", {}).get(cell.arch, {})

model = build_arch(
    arch=cell.arch,
    d_in=cfg["txc"]["d_model"],
    d_sae=cfg["txc"]["d_sae"],
    T=cfg["txc"]["T"],
    k=cell.k_per_position,
    **arch_kw,
)

state = torch.load(ckpt_path, map_location="cpu", weights_only=False)
model.load_state_dict(state["state_dict"])
model.eval()
```
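As a quick post-load sanity check, parameter counts and shapes can be printed. This only assumes `build_arch` returns a standard `torch.nn.Module`, which the `load_state_dict` / `eval()` calls above already imply:

```python
# Report the total parameter count and per-tensor shapes of the loaded dictionary.
n_params = sum(p.numel() for p in model.parameters())
print(f"{cell_id}: {n_params / 1e6:.2f}M parameters")
for name, p in model.named_parameters():
    print(f"  {name}: {tuple(p.shape)}")
```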
## Cell ID convention

```text
<arch>__<hookpoint>__k<k>__s<seed>
```

- arch: `txc` (TemporalCrosscoder), `txc_h13` / `txc_h8` (Han's contrastive variants), `topk_sae` (per-position TopK), `stacked_sae` (matryoshka H/L recon), `tsae` (Han's TemporalSAE w/ TopK), `tsae_paper` (Bhalla 2025 ReLU+L1).
- hookpoint: `resid_L10` (Ward's layer-10 residual), `attn_L10`, `ln1_L10` (pre-LN, captured via a forward-pre-hook).
- k (`k_per_position`): TopK target per offset slot (window-level L0 = k × T = k × 6).
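Because the four fields are joined with double underscores and never contain double underscores themselves, a cell id can also be split without the bundled `cell_id.Cell` helper. A stand-alone sketch (the `parse_cell_id` name is illustrative; `Cell.from_id` is the canonical parser):

```python
def parse_cell_id(cell_id: str) -> dict:
    """Split '<arch>__<hookpoint>__k<k>__s<seed>' into its four fields."""
    arch, hookpoint, k_part, seed_part = cell_id.split("__")
    return {
        "arch": arch,                       # e.g. "txc_h13"
        "hookpoint": hookpoint,             # e.g. "resid_L10"
        "k_per_position": int(k_part[1:]),  # strip the leading "k"
        "seed": int(seed_part[1:]),         # strip the leading "s"
    }

parse_cell_id("txc_h13__resid_L10__k16__s42")
# {'arch': 'txc_h13', 'hookpoint': 'resid_L10', 'k_per_position': 16, 'seed': 42}
```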
## Steering vector extraction

Mined feature decoder rows are in the companion dataset under `features/<cell_id>.npz`:

```python
import numpy as np

z = np.load("features/txc__resid_L10__k16__s42.npz", allow_pickle=True)
top_features = z["top_features"]     # (k_for_steering,) ranked by D+/D-
decoder_pos0 = z["decoder_at_pos0"]  # (k, d_model) - single-T-slot direction
decoder_union = z["decoder_union"]   # (k, d_model) - averaged across T slots
```
For the headline cell (`txc__resid_L10__k16__s42`), the winning steering direction is `decoder_at_pos0[idx]`, where `top_features[idx] == 14621`.
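Concretely, the headline direction can be pulled out of that file as follows (a sketch; the unit-normalisation is a common steering convention, not something the writeups specify):

```python
import numpy as np

z = np.load("features/txc__resid_L10__k16__s42.npz", allow_pickle=True)
idx = int(np.where(z["top_features"] == 14621)[0][0])  # headline feature 14621
steer_dir = z["decoder_at_pos0"][idx]                   # (d_model,)
steer_dir = steer_dir / np.linalg.norm(steer_dir)       # unit-norm (assumption)
```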
## Caveat

The steering directions induce backtracking text behavior (a Sonnet 4.6 behavioral judge confirms that ~93% of keyword tokens reflect genuine text-level course-corrections), but they **do not** improve MATH-500 answer correctness; see the B3 results in `results_b_behavioral.md`. Treat these checkpoints as research artifacts for studying the linguistic surface form of induced backtracking, not as a tool for boosting reasoning performance.
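For completeness, this is roughly how such a direction could be injected into the distilled model at layer 10 with a forward hook. A heavily hedged sketch, not the protocol used in the writeups: the hook location (output of decoder layer 10), the additive form, and the `scale` value are all illustrative assumptions.

```python
import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
tok = AutoTokenizer.from_pretrained(model_id)
lm = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)

# Direction from the extraction sketch above (feature 14621, position-0 decoder row).
z = np.load("features/txc__resid_L10__k16__s42.npz", allow_pickle=True)
idx = int(np.where(z["top_features"] == 14621)[0][0])
direction = torch.tensor(z["decoder_at_pos0"][idx], dtype=lm.dtype, device=lm.device)
direction = direction / direction.norm()

scale = 4.0  # illustrative; sweep this yourself

def add_direction(module, inputs, output):
    # Llama decoder layers return either a tuple whose first element is the
    # hidden states or the hidden-states tensor directly, depending on version.
    if isinstance(output, tuple):
        return (output[0] + scale * direction,) + output[1:]
    return output + scale * direction

handle = lm.model.layers[10].register_forward_hook(add_direction)
try:
    ids = tok("Solve step by step: what is 17 * 23?", return_tensors="pt").to(lm.device)
    out = lm.generate(**ids, max_new_tokens=200)
    print(tok.decode(out[0], skip_special_tokens=True))
finally:
    handle.remove()
```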