{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": "# InterpGPT: Reproducing the Head-Swap Finding\n\nTwo 23.5M-parameter transformers trained on the same architecture with the same\nrecipe; the only difference is the distribution of their training data.\n**Standard** was trained on plain task decompositions; **ADHD** was trained on\ndecompositions with smaller steps and interleaved micro-regulation actions.\n\nThe Phase 1 headline finding: a *step-layout broadcast* head that persistently\nattends to preceding step-boundary tokens exists in **both** models, implementing\nthe same function \u2014 but it lives at **L3H0** in the standard model and **L3H5**\nin the ADHD model. Cross-model per-position attention profile cosine similarity\nat the matched pair is **0.997**; same-index baseline is **0.66**. This notebook\nreproduces that comparison end-to-end in under 15 minutes on Colab free tier.\n\n**Runtime**: CPU is fine. GPU optional.\n\n**Source**: https://github.com/cwklurks/interpgpt\n" }, { "cell_type": "markdown", "metadata": {}, "source": "## 1. Install dependencies" }, { "cell_type": "code", "metadata": {}, "execution_count": null, "outputs": [], "source": "!pip install -q \\\n torch \\\n transformer_lens==2.4.1 \\\n huggingface_hub \\\n tokenizers \\\n matplotlib\n" }, { "cell_type": "markdown", "metadata": {}, "source": "## 2. Configuration\n\nSet your HuggingFace org/user if you're loading a fork. The defaults point at\nthe canonical InterpGPT release.\n" }, { "cell_type": "code", "metadata": {}, "execution_count": null, "outputs": [], "source": "ORG = \"connaaa\"\nSTANDARD_REPO = f\"{ORG}/interpgpt-standard-23M\"\nADHD_REPO = f\"{ORG}/interpgpt-adhd-23M\"\n\nimport torch\nDEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\nprint(f\"device: {DEVICE}\")\n" }, { "cell_type": "markdown", "metadata": {}, "source": "## 3. Load both models into TransformerLens\n\nThe HF repos ship a TransformerLens-compatible bundle (`hooked_transformer.pt`)\nalongside the HF `config.json` / `model.safetensors` pair. We use the\nTransformerLens bundle directly \u2014 it's the format the Phase 1 analyses were run\nagainst.\n" }, { "cell_type": "code", "metadata": {}, "execution_count": null, "outputs": [], "source": "from huggingface_hub import hf_hub_download\nfrom transformer_lens import HookedTransformer, HookedTransformerConfig\nimport torch\n\ndef load_tl(repo_id: str) -> HookedTransformer:\n path = hf_hub_download(repo_id, \"hooked_transformer.pt\")\n blob = torch.load(path, map_location=\"cpu\", weights_only=False)\n cfg_dict = blob[\"config\"]\n drop = {\"dtype\", \"device\", \"attention_dir\"}\n cfg_keep = {\n k: v for k, v in cfg_dict.items()\n if k in HookedTransformerConfig.__dataclass_fields__\n and not (k in drop and isinstance(v, str))\n and not (isinstance(v, str) and v.startswith(\"torch.\"))\n }\n model = HookedTransformer(HookedTransformerConfig(**cfg_keep)).to(DEVICE)\n model.load_state_dict(blob[\"model_state_dict\"])\n model.eval()\n return model\n\nstd_model = load_tl(STANDARD_REPO)\nadhd_model = load_tl(ADHD_REPO)\nprint(std_model.cfg)\n" }, { "cell_type": "markdown", "metadata": {}, "source": "## 4. Tokenizer and paired prompts\n\nWe use the Phase 1 BPE tokenizer (shipped in the HF repo as `tokenizer.json`)\nand a small paired-task set. 
Each task appears in both the standard format and\nthe ADHD format; that's what 'paired' means here.\n" }, { "cell_type": "code", "metadata": {}, "execution_count": null, "outputs": [], "source": "from tokenizers import Tokenizer\n\ntok_path = hf_hub_download(STANDARD_REPO, \"tokenizer.json\")\ntok = Tokenizer.from_file(tok_path)\n\nSPECIAL_NAMES = [\"<|task|>\", \"<|steps|>\", \"<|sep|>\", \"<|end|>\", \"<|pad|>\"]\nspecial_ids = {n: tok.token_to_id(n) for n in SPECIAL_NAMES}\nprint(\"specials:\", special_ids)\n" }, { "cell_type": "code", "metadata": {}, "execution_count": null, "outputs": [], "source": "# 12 paired tasks, hand-crafted to mirror the test-set distribution.\n# Each has a 'standard' variant (short steps, no regulation) and an\n# 'adhd' variant (smaller steps + interleaved regulation tokens).\nPAIRED = [\n {\n \"task\": \"Clean the kitchen\",\n \"standard\": {\n \"task\": \"Clean the kitchen\",\n \"steps\": [\"Wash dishes\", \"Wipe counters\", \"Sweep floor\", \"Take out trash\"],\n },\n \"adhd\": {\n \"task\": \"Clean the kitchen\",\n \"steps\": [\n \"Stack dirty dishes\", \"sip water\", \"Rinse plates\", \"deep breath\",\n \"Scrub pans\", \"quick stretch\", \"Wipe counters\", \"pause\",\n \"Sweep floor\", \"close eyes briefly\", \"Take out trash\",\n ],\n },\n },\n {\n \"task\": \"Write a grocery list\",\n \"standard\": {\n \"task\": \"Write a grocery list\",\n \"steps\": [\"Check pantry\", \"Plan meals\", \"Write items\", \"Sort by aisle\"],\n },\n \"adhd\": {\n \"task\": \"Write a grocery list\",\n \"steps\": [\n \"Grab paper\", \"sip water\", \"Open pantry\", \"deep breath\",\n \"Jot missing items\", \"pause\", \"Plan three meals\", \"quick stretch\",\n \"Sort items by aisle\",\n ],\n },\n },\n {\n \"task\": \"Do laundry\",\n \"standard\": {\n \"task\": \"Do laundry\",\n \"steps\": [\"Sort clothes\", \"Load washer\", \"Start cycle\", \"Move to dryer\", \"Fold\"],\n },\n \"adhd\": {\n \"task\": \"Do laundry\",\n \"steps\": [\n \"Gather hamper\", \"sip water\", \"Sort lights darks\", \"deep breath\",\n \"Load washer\", \"Add detergent\", \"pause\", \"Start cycle\",\n \"quick stretch\", \"Transfer to dryer\", \"close eyes briefly\", \"Fold clean clothes\",\n ],\n },\n },\n {\n \"task\": \"Prepare breakfast\",\n \"standard\": {\n \"task\": \"Prepare breakfast\",\n \"steps\": [\"Pick recipe\", \"Gather ingredients\", \"Cook\", \"Plate food\"],\n },\n \"adhd\": {\n \"task\": \"Prepare breakfast\",\n \"steps\": [\n \"Pick simple recipe\", \"sip water\", \"Open fridge\", \"deep breath\",\n \"Take out eggs\", \"pause\", \"Crack eggs in bowl\", \"quick stretch\",\n \"Cook on pan\", \"Plate breakfast\",\n ],\n },\n },\n {\n \"task\": \"Pay the bills\",\n \"standard\": {\n \"task\": \"Pay the bills\",\n \"steps\": [\"Open statements\", \"Total amount\", \"Log in bank\", \"Pay each\"],\n },\n \"adhd\": {\n \"task\": \"Pay the bills\",\n \"steps\": [\n \"Stack bills\", \"sip water\", \"Open first statement\", \"deep breath\",\n \"Write due amount\", \"pause\", \"Log in to bank\", \"close eyes briefly\",\n \"Pay each bill\", \"quick stretch\", \"File statements\",\n ],\n },\n },\n {\n \"task\": \"Water the plants\",\n \"standard\": {\n \"task\": \"Water the plants\",\n \"steps\": [\"Fill watering can\", \"Water each plant\", \"Check soil\"],\n },\n \"adhd\": {\n \"task\": \"Water the plants\",\n \"steps\": [\n \"Find watering can\", \"sip water\", \"Fill at sink\", \"deep breath\",\n \"Check first pot soil\", \"pause\", \"Pour slowly\", \"quick stretch\",\n \"Move to next plant\", 
\"Repeat for each\",\n ],\n },\n },\n {\n \"task\": \"Organize the closet\",\n \"standard\": {\n \"task\": \"Organize the closet\",\n \"steps\": [\"Empty closet\", \"Sort items\", \"Donate\", \"Rearrange by type\"],\n },\n \"adhd\": {\n \"task\": \"Organize the closet\",\n \"steps\": [\n \"Empty top shelf\", \"sip water\", \"Sort by keep donate\", \"deep breath\",\n \"Bag donations\", \"pause\", \"Empty bottom shelf\", \"quick stretch\",\n \"Sort again\", \"close eyes briefly\", \"Rearrange by type\",\n ],\n },\n },\n {\n \"task\": \"Send a work email\",\n \"standard\": {\n \"task\": \"Send a work email\",\n \"steps\": [\"Draft message\", \"Check recipients\", \"Attach files\", \"Send\"],\n },\n \"adhd\": {\n \"task\": \"Send a work email\",\n \"steps\": [\n \"Open email client\", \"sip water\", \"Draft subject\", \"deep breath\",\n \"Write body paragraph\", \"pause\", \"Reread for tone\", \"quick stretch\",\n \"Add recipients\", \"Attach files\", \"close eyes briefly\", \"Click send\",\n ],\n },\n },\n {\n \"task\": \"Study for an exam\",\n \"standard\": {\n \"task\": \"Study for an exam\",\n \"steps\": [\"Review notes\", \"Do practice problems\", \"Flashcards\", \"Sleep early\"],\n },\n \"adhd\": {\n \"task\": \"Study for an exam\",\n \"steps\": [\n \"Open notes\", \"sip water\", \"Read first section\", \"deep breath\",\n \"Summarize aloud\", \"pause\", \"Do two practice problems\", \"quick stretch\",\n \"Make flashcards\", \"close eyes briefly\", \"Review flashcards\", \"Sleep early\",\n ],\n },\n },\n {\n \"task\": \"Make the bed\",\n \"standard\": {\n \"task\": \"Make the bed\",\n \"steps\": [\"Pull sheets\", \"Fluff pillows\", \"Smooth comforter\"],\n },\n \"adhd\": {\n \"task\": \"Make the bed\",\n \"steps\": [\n \"Pull sheets up\", \"sip water\", \"Smooth bottom sheet\", \"deep breath\",\n \"Fluff pillows\", \"pause\", \"Pull comforter\", \"quick stretch\",\n \"Smooth top\",\n ],\n },\n },\n {\n \"task\": \"Take a walk\",\n \"standard\": {\n \"task\": \"Take a walk\",\n \"steps\": [\"Put on shoes\", \"Grab keys\", \"Walk 20 minutes\"],\n },\n \"adhd\": {\n \"task\": \"Take a walk\",\n \"steps\": [\n \"Put on shoes\", \"sip water\", \"Grab keys\", \"deep breath\",\n \"Step outside\", \"Walk five minutes\", \"pause\",\n \"Walk five more\", \"quick stretch\", \"Return home\",\n ],\n },\n },\n {\n \"task\": \"Cook dinner\",\n \"standard\": {\n \"task\": \"Cook dinner\",\n \"steps\": [\"Pick recipe\", \"Prep ingredients\", \"Cook\", \"Plate\"],\n },\n \"adhd\": {\n \"task\": \"Cook dinner\",\n \"steps\": [\n \"Pick recipe\", \"sip water\", \"Wash hands\", \"deep breath\",\n \"Chop vegetables\", \"pause\", \"Measure spices\", \"quick stretch\",\n \"Start cooking\", \"close eyes briefly\", \"Taste\", \"Plate dinner\",\n ],\n },\n },\n]\nprint(f\"paired tasks: {len(PAIRED)}\")\n" }, { "cell_type": "code", "metadata": {}, "execution_count": null, "outputs": [], "source": "def encode_example(ex):\n \"\"\"Encode one (task, steps) record into the training input format.\"\"\"\n ids = [special_ids[\"<|task|>\"]]\n ids += tok.encode(ex[\"task\"]).ids\n ids += [special_ids[\"<|steps|>\"]]\n for i, step in enumerate(ex[\"steps\"]):\n if i > 0:\n ids.append(special_ids[\"<|sep|>\"])\n ids += tok.encode(step).ids\n ids += [special_ids[\"<|end|>\"]]\n return ids\n\ndef token_roles(ids):\n steps_id = special_ids[\"<|steps|>\"]\n roles = {\"task_range\": [], \"special_positions\": []}\n hit_steps = False\n for i, t in enumerate(ids):\n if t == steps_id:\n hit_steps = True\n 
roles[\"special_positions\"].append(i)\n continue\n if t in set(special_ids.values()):\n roles[\"special_positions\"].append(i)\n elif not hit_steps:\n roles[\"task_range\"].append(i)\n return roles\n" }, { "cell_type": "markdown", "metadata": {}, "source": "## 5. Per-position attention profile\n\nFor every paired task we teacher-force the standard-variant sequence through\n**both** models and cache attention patterns. Then for a given (layer, head) we\ncompute, across query positions binned to a normalized `[0, 1]` axis, how much\nattention mass goes to *step-structure specials* (task/sep/steps/end markers).\nThat signal is the fingerprint of the step-layout-broadcast head.\n" }, { "cell_type": "code", "metadata": {}, "execution_count": null, "outputs": [], "source": "import numpy as np\n\ndef cache_attn(model, ids):\n x = torch.tensor([ids], dtype=torch.long, device=DEVICE)\n with torch.no_grad():\n _, cache = model.run_with_cache(x, return_type=None)\n L, H = model.cfg.n_layers, model.cfg.n_heads\n T = x.shape[1]\n out = torch.zeros(L, H, T, T)\n for layer in range(L):\n out[layer] = cache[f\"blocks.{layer}.attn.hook_pattern\"][0].to(\"cpu\")\n return out\n\nrecords = []\nfor pair in PAIRED:\n ids = encode_example(pair[\"standard\"])\n if len(ids) < 5 or len(ids) > std_model.cfg.n_ctx:\n continue\n roles = token_roles(ids)\n records.append({\n \"ids\": ids,\n \"roles\": roles,\n \"std_attn\": cache_attn(std_model, ids),\n \"adhd_attn\": cache_attn(adhd_model, ids),\n })\nprint(f\"cached attention on {len(records)} paired prompts\")\n" }, { "cell_type": "code", "metadata": {}, "execution_count": null, "outputs": [], "source": "NBINS = 20\n\ndef profile(records, layer, head, which):\n bins_struct = np.zeros(NBINS)\n bins_count = np.zeros(NBINS)\n for r in records:\n attn = r[f\"{which}_attn\"][layer, head]\n T = attn.shape[0]\n spec = list(r[\"roles\"][\"special_positions\"])\n if not spec:\n continue\n for q in range(1, T):\n b = min(int(q / T * NBINS), NBINS - 1)\n bins_struct[b] += attn[q, spec].sum().item()\n bins_count[b] += 1\n return bins_struct / np.maximum(bins_count, 1)\n\ndef cosine(a, b):\n na, nb = np.linalg.norm(a), np.linalg.norm(b)\n if na == 0 or nb == 0:\n return 0.0\n return float(np.dot(a, b) / (na * nb))\n" }, { "cell_type": "markdown", "metadata": {}, "source": "## 6. Cross-model cosine: the 0.997 vs 0.66 comparison\n\n- **Matched pair** (function-to-function): standard L3H0 vs ADHD L3H5.\n- **Same-index baseline**: standard L3H0 vs ADHD L3H0 (and ADHD L3H5 vs standard L3H5).\n\nOn the full held-out test set the matched-pair cosine was **0.997** and the\nsame-index baselines were **0.663** and **0.643**. This notebook uses only\n12 hand-crafted paired prompts for speed, so the absolute numbers are softer\n(matched \u2248 **0.99**, baselines around **0.7\u20130.9**) \u2014 but the ordering is\nalways the same: the matched pair cosine is strictly higher than either\nsame-index baseline. 
That ordering is the claim.\n" }, { "cell_type": "code", "metadata": {}, "execution_count": null, "outputs": [], "source": "prof_std_L3H0 = profile(records, 3, 0, \"std\")\nprof_std_L3H5 = profile(records, 3, 5, \"std\")\nprof_adhd_L3H0 = profile(records, 3, 0, \"adhd\")\nprof_adhd_L3H5 = profile(records, 3, 5, \"adhd\")\n\nmatched = cosine(prof_std_L3H0, prof_adhd_L3H5)\nbaseline1 = cosine(prof_std_L3H0, prof_adhd_L3H0)\nbaseline2 = cosine(prof_std_L3H5, prof_adhd_L3H5)\n\nprint(f\"matched (std L3H0 ~ adhd L3H5): {matched:.3f}\")\nprint(f\"baseline (std L3H0 ~ adhd L3H0): {baseline1:.3f}\")\nprint(f\"baseline (std L3H5 ~ adhd L3H5): {baseline2:.3f}\")\n" }, { "cell_type": "markdown", "metadata": {}, "source": "## 7. Visualize the matched profiles\n\nThe matched pair should trace almost-identical curves; the same-index pair\nshould visibly diverge.\n" }, { "cell_type": "code", "metadata": {}, "execution_count": null, "outputs": [], "source": "import matplotlib.pyplot as plt\n\nx = np.linspace(0, 1, NBINS)\nfig, axes = plt.subplots(1, 2, figsize=(11, 4), sharey=True)\n\naxes[0].plot(x, prof_std_L3H0, label=\"standard L3H0\", lw=2)\naxes[0].plot(x, prof_adhd_L3H5, label=\"ADHD L3H5\", lw=2)\naxes[0].set_title(f\"matched pair (cos = {matched:.3f})\")\naxes[0].set_xlabel(\"normalized query position\")\naxes[0].set_ylabel(\"attention mass \u2192 step-structure specials\")\naxes[0].legend(); axes[0].grid(alpha=0.3)\n\naxes[1].plot(x, prof_std_L3H0, label=\"standard L3H0\", lw=2)\naxes[1].plot(x, prof_adhd_L3H0, label=\"ADHD L3H0\", lw=2)\naxes[1].set_title(f\"same-index baseline (cos = {baseline1:.3f})\")\naxes[1].set_xlabel(\"normalized query position\")\naxes[1].legend(); axes[1].grid(alpha=0.3)\n\nplt.tight_layout(); plt.show()\n" }, { "cell_type": "code", "metadata": {}, "execution_count": null, "outputs": [], "source": "# Side-by-side (seq, seq) attention heatmaps on a single paired prompt.\nr = records[0]\nfig, axes = plt.subplots(1, 2, figsize=(10, 4.5))\naxes[0].imshow(r[\"std_attn\"][3, 0], aspect=\"auto\", cmap=\"viridis\")\naxes[0].set_title(\"standard L3H0 \u2014 step-layout broadcast\")\naxes[1].imshow(r[\"adhd_attn\"][3, 5], aspect=\"auto\", cmap=\"viridis\")\naxes[1].set_title(\"ADHD L3H5 \u2014 same function, different index\")\nfor ax in axes:\n    ax.set_xlabel(\"key position\")\naxes[0].set_ylabel(\"query position\")\nplt.tight_layout(); plt.show()\n" }
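, { "cell_type": "markdown", "metadata": {}, "source": "As an extra sanity check (an addition in this notebook, not part of the original\nPhase 1 pipeline), the next cell ranks every ADHD layer-3 head by profile cosine\nagainst standard L3H0. With only 12 prompts the values are noisy, but the\nmatched head should still come out on top.\n" }, { "cell_type": "code", "metadata": {}, "execution_count": null, "outputs": [], "source": "# Rank ADHD layer-3 heads by profile cosine against standard L3H0.\n# Noisy with 12 prompts; on the full test set the Phase 1 match is L3H5.\nsims = {h: cosine(prof_std_L3H0, profile(records, 3, h, \"adhd\"))\n        for h in range(adhd_model.cfg.n_heads)}\nfor h, s in sorted(sims.items(), key=lambda kv: -kv[1]):\n    print(f\"ADHD L3H{h}: cos = {s:.3f}\")\n" }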
, { "cell_type": "markdown", "metadata": {}, "source": "## 8. Short ablation demo\n\nJust an illustration; the full 5-seed multi-prompt causal ablation (which drops\nSpearman(task_complexity \u00d7 step_count) 0.83 \u2192 0.78 in the ADHD model, median\n\u0394 = -0.055) lives in `phase4_ablation_multiseed.py`. Here we zero out the ADHD\nmodel's L3H5 attention pattern on one prompt and show that the next-token\ndistribution at the final position shifts. This is a sniff-test, not the\nload-bearing causal claim.\n" }, { "cell_type": "code", "metadata": {}, "execution_count": null, "outputs": [], "source": "from transformer_lens.hook_points import HookPoint\n\ndef zero_head_hook(layer, head):\n    name = f\"blocks.{layer}.attn.hook_pattern\"\n    # TransformerLens passes the hook point as the `hook` keyword argument,\n    # so the hook function's second parameter must be named `hook`.\n    def zero_pattern(pattern, hook: HookPoint):\n        # Zeroing the rows removes this head's output entirely (zero ablation);\n        # the pattern rows intentionally no longer sum to 1.\n        pattern = pattern.clone()\n        pattern[:, head] = 0.0\n        return pattern\n    return name, zero_pattern\n\nids = records[0][\"ids\"]\nx = torch.tensor([ids], dtype=torch.long, device=DEVICE)\n\nwith torch.no_grad():\n    base_logits = adhd_model(x, return_type=\"logits\")[0, -1]\n    name, hook = zero_head_hook(3, 5)\n    ablated_logits = adhd_model.run_with_hooks(x, return_type=\"logits\",\n                                               fwd_hooks=[(name, hook)])[0, -1]\n\ndef top_tokens(logits, k=5):\n    probs = torch.softmax(logits.float(), dim=-1)\n    vals, idx = probs.topk(k)\n    return [(tok.id_to_token(int(i)), float(v)) for v, i in zip(vals, idx)]\n\nprint(\"baseline top-5:\", top_tokens(base_logits))\nprint(\"ablated top-5:\", top_tokens(ablated_logits))\n" }, { "cell_type": "markdown", "metadata": {}, "source": "## 9. Where to go next\n\n- Full 48-head swap inventory: see AGENT-2's `migration_map.md` + swap heatmap\n  figure in the Phase 2 results bundle.\n- SAE features + null-steering feature 2504: see the companion\n  [`connaaa/interpgpt-sae-phase5`](https://huggingface.co/connaaa/interpgpt-sae-phase5)\n  repo and the `phase5_*.py` scripts.\n- Phase 1 writeup: `interpgpt-writeup-draft.md` in the main repo.\n" } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python" }, "colab": { "provenance": [] } }, "nbformat": 4, "nbformat_minor": 5 }