Ward 2025 Stage B: trained dictionaries (TXC, SAE, TSAE)

13 curated dictionary checkpoints from the Stage B paper-budget reproduction of Ward et al. (2025). Each checkpoint is a sparse dictionary trained on layer-10 activations of Llama-3.1-8B (residual, attention, or pre-LN) and used to steer DeepSeek-R1-Distill-Llama-8B toward emitting backtracking tokens.

Companion to:

Checkpoints

checkpoints/
  txc__resid_L10__k16__s42.pt            # B1 Sonnet primary winner (s42)
  txc__resid_L10__k16__s7.pt             # multi-seed verification
  txc__resid_L10__k16__s11.pt
  txc__resid_L10__k16__s23.pt
  txc_h13__resid_L10__k16__s42.pt        # Han matryoshka × MD contrastive
  txc_h13__resid_L10__k16__s7.pt
  txc_h13__resid_L10__k16__s11.pt
  txc_h13__resid_L10__k16__s23.pt
  txc_h8__resid_L10__k16__s42.pt         # Han multi-distance contrastive
  topk_sae__ln1_L10__k64__s42.pt         # best non-TXC SAE under Sonnet
  stacked_sae__resid_L10__k16__s42.pt
  tsae__resid_L10__k32__s42.pt           # Han's TSAE (TopK variant)
  tsae_paper__resid_L10__k32__s42.pt     # Bhalla 2025 paper-faithful

architectures.py                          # build_arch / arch_forward dispatch
cell_id.py                                # cell-id parser/serializer
config.yaml                               # arch_kwargs + training config
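
To pull all checkpoints plus the bundled helper files in one call, a snapshot download works; the allow_patterns filter below is only a convenience and can be dropped to mirror the whole repo.

from huggingface_hub import snapshot_download

# Mirror the checkpoints and helper modules locally (filter is optional).
local_dir = snapshot_download(
    repo_id="aniketdesh/ward-stage-b-dictionaries",
    allow_patterns=["checkpoints/*.pt", "architectures.py", "cell_id.py", "config.yaml"],
)
print(local_dir)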

Loading

import torch, yaml
from huggingface_hub import hf_hub_download
from architectures import build_arch  # also bundled in this repo
from cell_id import Cell

cell_id = "txc__resid_L10__k16__s42"

# Fetch the checkpoint and the shared training config from the Hub.
ckpt_path = hf_hub_download(
    repo_id="aniketdesh/ward-stage-b-dictionaries",
    filename=f"checkpoints/{cell_id}.pt",
)
config_path = hf_hub_download(
    repo_id="aniketdesh/ward-stage-b-dictionaries",
    filename="config.yaml",
)
with open(config_path) as f:
    cfg = yaml.safe_load(f)

# Rebuild the architecture from the parsed cell id plus config.yaml, then load weights.
cell = Cell.from_id(cell_id)
arch_kw = cfg["txc"].get("arch_kwargs", {}).get(cell.arch, {})
model = build_arch(
    arch=cell.arch,
    d_in=cfg["txc"]["d_model"],
    d_sae=cfg["txc"]["d_sae"],
    T=cfg["txc"]["T"],
    k=cell.k_per_position,
    **arch_kw,
)
state = torch.load(ckpt_path, map_location="cpu", weights_only=False)
model.load_state_dict(state["state_dict"])
model.eval()
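
A minimal smoke test of the loaded dictionary, assuming arch_forward (the dispatch helper named above) accepts the model plus a (batch, T, d_model) activation window; check architectures.py for the exact signature and return values.

from architectures import arch_forward  # dispatch helper bundled with build_arch

# Placeholder activations: batch of 1, T temporal slots, d_model residual dims.
acts = torch.randn(1, cfg["txc"]["T"], cfg["txc"]["d_model"])
with torch.no_grad():
    out = arch_forward(model, acts)  # assumed call shape; see architectures.py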

Cell ID convention

<arch>__<hookpoint>__k<k>__s<seed>

  • arch: txc (TemporalCrosscoder), txc_h13 / txc_h8 (Han's contrastive variants), topk_sae (per-position TopK), stacked_sae (matryoshka H/L recon), tsae (Han's TemporalSAE w/ TopK), tsae_paper (Bhalla 2025 ReLU+L1).
  • hookpoint: resid_L10 (Ward's layer-10 residual), attn_L10, ln1_L10 (pre-LN, captured via a forward pre-hook).
  • k: TopK target per offset slot (k_per_position), so the window L0 is k × T = k × 6.
  • seed: training seed; this release covers s42, s7, s11, and s23.
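
For example, the bundled parser recovers the fields used by the loading snippet above (only the arch and k_per_position attributes are confirmed by that snippet; see cell_id.py for the full set):

from cell_id import Cell

cell = Cell.from_id("txc_h13__resid_L10__k16__s7")
print(cell.arch)             # expected: "txc_h13"
print(cell.k_per_position)   # expected: 16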

Steering vector extraction

Mined feature decoder rows are in the companion dataset under features/<cell_id>.npz:

import numpy as np
z = np.load("features/txc__resid_L10__k16__s42.npz", allow_pickle=True)
top_features = z["top_features"]      # (k_for_steering,) ranked by D+/D-
decoder_pos0 = z["decoder_at_pos0"]   # (k, d_model): single-T-slot direction
decoder_union = z["decoder_union"]    # (k, d_model): averaged across T slots

For the headline cell (txc__resid_L10__k16__s42), the winning steering direction is decoder_at_pos0[idx], where top_features[idx] == 14621.
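
A minimal sketch of applying that direction as an additive steering vector on the layer-10 residual stream of DeepSeek-R1-Distill-Llama-8B, assuming the standard transformers Llama layer layout; the layer index, hook placement, and the alpha scale below are illustrative and must be matched to the Stage B setup.

import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

z = np.load("features/txc__resid_L10__k16__s42.npz", allow_pickle=True)
idx = int(np.where(z["top_features"] == 14621)[0][0])
direction = torch.tensor(z["decoder_at_pos0"][idx])

name = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
tok = AutoTokenizer.from_pretrained(name)
lm = AutoModelForCausalLM.from_pretrained(name, torch_dtype=torch.bfloat16)

alpha = 4.0  # illustrative steering strength, not a value from the paper

def add_direction(module, inputs, output):
    # Llama decoder layers return a tuple whose first element is the hidden states.
    hidden = output[0] + alpha * direction.to(output[0].dtype).to(output[0].device)
    return (hidden,) + output[1:]

handle = lm.model.layers[10].register_forward_hook(add_direction)
prompt = tok("Solve: 12 * 17 = ?", return_tensors="pt")
out = lm.generate(**prompt, max_new_tokens=64)
print(tok.decode(out[0], skip_special_tokens=True))
handle.remove()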

Caveat

The steering directions induce backtracking text behavior (a Sonnet 4.6 behavioral judge confirms that ~93% of keyword tokens reflect genuine text-level course-corrections), but they do not improve MATH-500 answer correctness; see the B3 results in results_b_behavioral.md. Treat these checkpoints as research artifacts for studying the linguistic surface form of induced backtracking, not as a tool for boosting reasoning performance.
