{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "# InterpGPT: Reproducing the Head-Swap Finding\n\nTwo 23.5M-parameter transformers trained on the same architecture with the same\nrecipe; the only difference is the distribution of their training data.\n**Standard** was trained on plain task decompositions; **ADHD** was trained on\ndecompositions with smaller steps and interleaved micro-regulation actions.\n\nThe Phase 1 headline finding: a *step-layout broadcast* head that persistently\nattends to preceding step-boundary tokens exists in **both** models, implementing\nthe same function \u2014 but it lives at **L3H0** in the standard model and **L3H5**\nin the ADHD model. Cross-model per-position attention profile cosine similarity\nat the matched pair is **0.997**; same-index baseline is **0.66**. This notebook\nreproduces that comparison end-to-end in under 15 minutes on Colab free tier.\n\n**Runtime**: CPU is fine. GPU optional.\n\n**Source**: https://github.com/cwklurks/interpgpt\n"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "## 1. Install dependencies"
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": "!pip install -q \\\n    torch \\\n    transformer_lens==2.4.1 \\\n    huggingface_hub \\\n    tokenizers \\\n    matplotlib\n"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "## 2. Configuration\n\nSet your HuggingFace org/user if you're loading a fork. The defaults point at\nthe canonical InterpGPT release.\n"
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": "ORG = \"connaaa\"\nSTANDARD_REPO = f\"{ORG}/interpgpt-standard-23M\"\nADHD_REPO     = f\"{ORG}/interpgpt-adhd-23M\"\n\nimport torch\nDEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\nprint(f\"device: {DEVICE}\")\n"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "## 3. Load both models into TransformerLens\n\nThe HF repos ship a TransformerLens-compatible bundle (`hooked_transformer.pt`)\nalongside the HF `config.json` / `model.safetensors` pair. We use the\nTransformerLens bundle directly \u2014 it's the format the Phase 1 analyses were run\nagainst.\n"
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": "from huggingface_hub import hf_hub_download\nfrom transformer_lens import HookedTransformer, HookedTransformerConfig\nimport torch\n\ndef load_tl(repo_id: str) -> HookedTransformer:\n    path = hf_hub_download(repo_id, \"hooked_transformer.pt\")\n    blob = torch.load(path, map_location=\"cpu\", weights_only=False)\n    cfg_dict = blob[\"config\"]\n    drop = {\"dtype\", \"device\", \"attention_dir\"}\n    cfg_keep = {\n        k: v for k, v in cfg_dict.items()\n        if k in HookedTransformerConfig.__dataclass_fields__\n           and not (k in drop and isinstance(v, str))\n           and not (isinstance(v, str) and v.startswith(\"torch.\"))\n    }\n    model = HookedTransformer(HookedTransformerConfig(**cfg_keep)).to(DEVICE)\n    model.load_state_dict(blob[\"model_state_dict\"])\n    model.eval()\n    return model\n\nstd_model  = load_tl(STANDARD_REPO)\nadhd_model = load_tl(ADHD_REPO)\nprint(std_model.cfg)\n"
  },
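  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "A quick sanity check, added for this walkthrough: the head-swap comparison only\nmakes sense if the two models really do share an architecture, so assert that\nthe relevant config fields match before going further.\n"
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": "# Sanity check (added for this walkthrough): the head-swap comparison assumes\n# a shared architecture. Field names are standard HookedTransformerConfig\n# attributes.\nfor field in [\"n_layers\", \"n_heads\", \"d_model\", \"d_head\", \"n_ctx\", \"d_vocab\"]:\n    a, b = getattr(std_model.cfg, field), getattr(adhd_model.cfg, field)\n    assert a == b, f\"{field} differs: {a} vs {b}\"\n    print(f\"{field:>8}: {a}\")\n"
  },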
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "## 4. Tokenizer and paired prompts\n\nWe use the Phase 1 BPE tokenizer (shipped in the HF repo as `tokenizer.json`)\nand a small paired-task set. Each task appears in both the standard format and\nthe ADHD format; that's what 'paired' means here.\n"
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": "from tokenizers import Tokenizer\n\ntok_path = hf_hub_download(STANDARD_REPO, \"tokenizer.json\")\ntok = Tokenizer.from_file(tok_path)\n\nSPECIAL_NAMES = [\"<|task|>\", \"<|steps|>\", \"<|sep|>\", \"<|end|>\", \"<|pad|>\"]\nspecial_ids = {n: tok.token_to_id(n) for n in SPECIAL_NAMES}\nprint(\"specials:\", special_ids)\n"
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": "# 12 paired tasks, hand-crafted to mirror the test-set distribution.\n# Each has a 'standard' variant (short steps, no regulation) and an\n# 'adhd' variant (smaller steps + interleaved regulation tokens).\nPAIRED = [\n    {\n        \"task\": \"Clean the kitchen\",\n        \"standard\": {\n            \"task\": \"Clean the kitchen\",\n            \"steps\": [\"Wash dishes\", \"Wipe counters\", \"Sweep floor\", \"Take out trash\"],\n        },\n        \"adhd\": {\n            \"task\": \"Clean the kitchen\",\n            \"steps\": [\n                \"Stack dirty dishes\", \"sip water\", \"Rinse plates\", \"deep breath\",\n                \"Scrub pans\", \"quick stretch\", \"Wipe counters\", \"pause\",\n                \"Sweep floor\", \"close eyes briefly\", \"Take out trash\",\n            ],\n        },\n    },\n    {\n        \"task\": \"Write a grocery list\",\n        \"standard\": {\n            \"task\": \"Write a grocery list\",\n            \"steps\": [\"Check pantry\", \"Plan meals\", \"Write items\", \"Sort by aisle\"],\n        },\n        \"adhd\": {\n            \"task\": \"Write a grocery list\",\n            \"steps\": [\n                \"Grab paper\", \"sip water\", \"Open pantry\", \"deep breath\",\n                \"Jot missing items\", \"pause\", \"Plan three meals\", \"quick stretch\",\n                \"Sort items by aisle\",\n            ],\n        },\n    },\n    {\n        \"task\": \"Do laundry\",\n        \"standard\": {\n            \"task\": \"Do laundry\",\n            \"steps\": [\"Sort clothes\", \"Load washer\", \"Start cycle\", \"Move to dryer\", \"Fold\"],\n        },\n        \"adhd\": {\n            \"task\": \"Do laundry\",\n            \"steps\": [\n                \"Gather hamper\", \"sip water\", \"Sort lights darks\", \"deep breath\",\n                \"Load washer\", \"Add detergent\", \"pause\", \"Start cycle\",\n                \"quick stretch\", \"Transfer to dryer\", \"close eyes briefly\", \"Fold clean clothes\",\n            ],\n        },\n    },\n    {\n        \"task\": \"Prepare breakfast\",\n        \"standard\": {\n            \"task\": \"Prepare breakfast\",\n            \"steps\": [\"Pick recipe\", \"Gather ingredients\", \"Cook\", \"Plate food\"],\n        },\n        \"adhd\": {\n            \"task\": \"Prepare breakfast\",\n            \"steps\": [\n                \"Pick simple recipe\", \"sip water\", \"Open fridge\", \"deep breath\",\n                \"Take out eggs\", \"pause\", \"Crack eggs in bowl\", \"quick stretch\",\n                \"Cook on pan\", \"Plate breakfast\",\n            ],\n        },\n    },\n    {\n        \"task\": \"Pay the bills\",\n        \"standard\": {\n            \"task\": \"Pay the bills\",\n            \"steps\": [\"Open statements\", \"Total amount\", \"Log in bank\", \"Pay each\"],\n        },\n        \"adhd\": {\n            \"task\": \"Pay the bills\",\n            \"steps\": [\n                \"Stack bills\", \"sip water\", \"Open first statement\", \"deep breath\",\n                \"Write due amount\", \"pause\", \"Log in to bank\", \"close eyes briefly\",\n                \"Pay each bill\", \"quick stretch\", \"File statements\",\n            ],\n        },\n    },\n    {\n        \"task\": \"Water the plants\",\n        \"standard\": {\n            \"task\": \"Water the plants\",\n            \"steps\": [\"Fill watering can\", \"Water each plant\", \"Check soil\"],\n        },\n        \"adhd\": {\n            \"task\": \"Water the plants\",\n            \"steps\": 
[\n                \"Find watering can\", \"sip water\", \"Fill at sink\", \"deep breath\",\n                \"Check first pot soil\", \"pause\", \"Pour slowly\", \"quick stretch\",\n                \"Move to next plant\", \"Repeat for each\",\n            ],\n        },\n    },\n    {\n        \"task\": \"Organize the closet\",\n        \"standard\": {\n            \"task\": \"Organize the closet\",\n            \"steps\": [\"Empty closet\", \"Sort items\", \"Donate\", \"Rearrange by type\"],\n        },\n        \"adhd\": {\n            \"task\": \"Organize the closet\",\n            \"steps\": [\n                \"Empty top shelf\", \"sip water\", \"Sort by keep donate\", \"deep breath\",\n                \"Bag donations\", \"pause\", \"Empty bottom shelf\", \"quick stretch\",\n                \"Sort again\", \"close eyes briefly\", \"Rearrange by type\",\n            ],\n        },\n    },\n    {\n        \"task\": \"Send a work email\",\n        \"standard\": {\n            \"task\": \"Send a work email\",\n            \"steps\": [\"Draft message\", \"Check recipients\", \"Attach files\", \"Send\"],\n        },\n        \"adhd\": {\n            \"task\": \"Send a work email\",\n            \"steps\": [\n                \"Open email client\", \"sip water\", \"Draft subject\", \"deep breath\",\n                \"Write body paragraph\", \"pause\", \"Reread for tone\", \"quick stretch\",\n                \"Add recipients\", \"Attach files\", \"close eyes briefly\", \"Click send\",\n            ],\n        },\n    },\n    {\n        \"task\": \"Study for an exam\",\n        \"standard\": {\n            \"task\": \"Study for an exam\",\n            \"steps\": [\"Review notes\", \"Do practice problems\", \"Flashcards\", \"Sleep early\"],\n        },\n        \"adhd\": {\n            \"task\": \"Study for an exam\",\n            \"steps\": [\n                \"Open notes\", \"sip water\", \"Read first section\", \"deep breath\",\n                \"Summarize aloud\", \"pause\", \"Do two practice problems\", \"quick stretch\",\n                \"Make flashcards\", \"close eyes briefly\", \"Review flashcards\", \"Sleep early\",\n            ],\n        },\n    },\n    {\n        \"task\": \"Make the bed\",\n        \"standard\": {\n            \"task\": \"Make the bed\",\n            \"steps\": [\"Pull sheets\", \"Fluff pillows\", \"Smooth comforter\"],\n        },\n        \"adhd\": {\n            \"task\": \"Make the bed\",\n            \"steps\": [\n                \"Pull sheets up\", \"sip water\", \"Smooth bottom sheet\", \"deep breath\",\n                \"Fluff pillows\", \"pause\", \"Pull comforter\", \"quick stretch\",\n                \"Smooth top\",\n            ],\n        },\n    },\n    {\n        \"task\": \"Take a walk\",\n        \"standard\": {\n            \"task\": \"Take a walk\",\n            \"steps\": [\"Put on shoes\", \"Grab keys\", \"Walk 20 minutes\"],\n        },\n        \"adhd\": {\n            \"task\": \"Take a walk\",\n            \"steps\": [\n                \"Put on shoes\", \"sip water\", \"Grab keys\", \"deep breath\",\n                \"Step outside\", \"Walk five minutes\", \"pause\",\n                \"Walk five more\", \"quick stretch\", \"Return home\",\n            ],\n        },\n    },\n    {\n        \"task\": \"Cook dinner\",\n        \"standard\": {\n            \"task\": \"Cook dinner\",\n            \"steps\": [\"Pick recipe\", \"Prep ingredients\", \"Cook\", \"Plate\"],\n        },\n        \"adhd\": {\n            \"task\": \"Cook dinner\",\n 
           \"steps\": [\n                \"Pick recipe\", \"sip water\", \"Wash hands\", \"deep breath\",\n                \"Chop vegetables\", \"pause\", \"Measure spices\", \"quick stretch\",\n                \"Start cooking\", \"close eyes briefly\", \"Taste\", \"Plate dinner\",\n            ],\n        },\n    },\n]\nprint(f\"paired tasks: {len(PAIRED)}\")\n"
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": "def encode_example(ex):\n    \"\"\"Encode one (task, steps) record into the training input format.\"\"\"\n    ids = [special_ids[\"<|task|>\"]]\n    ids += tok.encode(ex[\"task\"]).ids\n    ids += [special_ids[\"<|steps|>\"]]\n    for i, step in enumerate(ex[\"steps\"]):\n        if i > 0:\n            ids.append(special_ids[\"<|sep|>\"])\n        ids += tok.encode(step).ids\n    ids += [special_ids[\"<|end|>\"]]\n    return ids\n\ndef token_roles(ids):\n    steps_id = special_ids[\"<|steps|>\"]\n    roles = {\"task_range\": [], \"special_positions\": []}\n    hit_steps = False\n    for i, t in enumerate(ids):\n        if t == steps_id:\n            hit_steps = True\n            roles[\"special_positions\"].append(i)\n            continue\n        if t in set(special_ids.values()):\n            roles[\"special_positions\"].append(i)\n        elif not hit_steps:\n            roles[\"task_range\"].append(i)\n    return roles\n"
  },
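  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "Before caching attention it helps to eyeball one encoded sequence. The cell\nbelow is a small illustration added for this notebook, not part of the original\npipeline: it encodes the first standard-variant task and prints the token\nstream alongside the roles `token_roles` recovers.\n"
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": "# Illustration (added for this notebook): encode the first paired task and\n# inspect the roles token_roles() recovers.\nexample_ids = encode_example(PAIRED[0][\"standard\"])\nexample_roles = token_roles(example_ids)\nprint(\"tokens:\", [tok.id_to_token(i) for i in example_ids])\nprint(\"special positions:\", example_roles[\"special_positions\"])\nprint(\"task-token positions:\", example_roles[\"task_range\"])\n"
  },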
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "## 5. Per-position attention profile\n\nFor every paired task we teacher-force the standard-variant sequence through\n**both** models and cache attention patterns. Then for a given (layer, head) we\ncompute, across query positions binned to a normalized `[0, 1]` axis, how much\nattention mass goes to *step-structure specials* (task/sep/steps/end markers).\nThat signal is the fingerprint of the step-layout-broadcast head.\n"
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": "import numpy as np\n\ndef cache_attn(model, ids):\n    x = torch.tensor([ids], dtype=torch.long, device=DEVICE)\n    with torch.no_grad():\n        _, cache = model.run_with_cache(x, return_type=None)\n    L, H = model.cfg.n_layers, model.cfg.n_heads\n    T = x.shape[1]\n    out = torch.zeros(L, H, T, T)\n    for layer in range(L):\n        out[layer] = cache[f\"blocks.{layer}.attn.hook_pattern\"][0].to(\"cpu\")\n    return out\n\nrecords = []\nfor pair in PAIRED:\n    ids = encode_example(pair[\"standard\"])\n    if len(ids) < 5 or len(ids) > std_model.cfg.n_ctx:\n        continue\n    roles = token_roles(ids)\n    records.append({\n        \"ids\": ids,\n        \"roles\": roles,\n        \"std_attn\": cache_attn(std_model, ids),\n        \"adhd_attn\": cache_attn(adhd_model, ids),\n    })\nprint(f\"cached attention on {len(records)} paired prompts\")\n"
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": "NBINS = 20\n\ndef profile(records, layer, head, which):\n    bins_struct = np.zeros(NBINS)\n    bins_count  = np.zeros(NBINS)\n    for r in records:\n        attn = r[f\"{which}_attn\"][layer, head]\n        T = attn.shape[0]\n        spec = list(r[\"roles\"][\"special_positions\"])\n        if not spec:\n            continue\n        for q in range(1, T):\n            b = min(int(q / T * NBINS), NBINS - 1)\n            bins_struct[b] += attn[q, spec].sum().item()\n            bins_count[b]  += 1\n    return bins_struct / np.maximum(bins_count, 1)\n\ndef cosine(a, b):\n    na, nb = np.linalg.norm(a), np.linalg.norm(b)\n    if na == 0 or nb == 0:\n        return 0.0\n    return float(np.dot(a, b) / (na * nb))\n"
  },
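  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "The matched pair is taken from the Phase 1 analysis, but you can rediscover it\nhere. The scan below is a convenience added for this notebook: it ranks every\nlayer-3 head in each model by its mean attention mass onto the step-structure\nspecials. If the finding holds on these 12 prompts, **L3H0** should sit at or\nnear the top for the standard model and **L3H5** for the ADHD model.\n"
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": "# Convenience scan (added for this notebook): rank layer-3 heads by mean\n# attention mass onto step-structure specials, per model.\nLAYER = 3\nfor which, model in [(\"std\", std_model), (\"adhd\", adhd_model)]:\n    masses = [profile(records, LAYER, h, which).mean()\n              for h in range(model.cfg.n_heads)]\n    ranked = sorted(enumerate(masses), key=lambda t: -t[1])\n    top3 = \", \".join(f\"H{h}={m:.3f}\" for h, m in ranked[:3])\n    print(f\"{which:>4} L{LAYER} top heads by specials mass: {top3}\")\n"
  },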
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "## 6. Cross-model cosine: the 0.997 vs 0.66 comparison\n\n- **Matched pair** (function-to-function): standard L3H0 vs ADHD L3H5.\n- **Same-index baseline**: standard L3H0 vs ADHD L3H0 (and ADHD L3H5 vs standard L3H5).\n\nOn the full held-out test set the matched-pair cosine was **0.997** and the\nsame-index baselines were **0.663** and **0.643**. This notebook uses only\n12 hand-crafted paired prompts for speed, so the absolute numbers are softer\n(matched \u2248 **0.99**, baselines around **0.7\u20130.9**) \u2014 but the ordering is\nalways the same: the matched pair cosine is strictly higher than either\nsame-index baseline. That ordering is the claim.\n"
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": "prof_std_L3H0  = profile(records, 3, 0, \"std\")\nprof_std_L3H5  = profile(records, 3, 5, \"std\")\nprof_adhd_L3H0 = profile(records, 3, 0, \"adhd\")\nprof_adhd_L3H5 = profile(records, 3, 5, \"adhd\")\n\nmatched   = cosine(prof_std_L3H0, prof_adhd_L3H5)\nbaseline1 = cosine(prof_std_L3H0, prof_adhd_L3H0)\nbaseline2 = cosine(prof_std_L3H5, prof_adhd_L3H5)\n\nprint(f\"matched (std L3H0 ~ adhd L3H5):   {matched:.3f}\")\nprint(f\"baseline (std L3H0 ~ adhd L3H0):  {baseline1:.3f}\")\nprint(f\"baseline (std L3H5 ~ adhd L3H5):  {baseline2:.3f}\")\n"
  },
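  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "Since the claim is ordinal rather than about absolute values, we can check it\nprogrammatically. This guard is an addition for this walkthrough, not part of\nthe original analysis scripts.\n"
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": "# The load-bearing claim: the matched-pair cosine strictly exceeds both\n# same-index baselines on any reasonable prompt set.\nassert matched > max(baseline1, baseline2), \"head-swap ordering did not reproduce\"\nprint(f\"ordering holds: matched {matched:.3f} > baselines {max(baseline1, baseline2):.3f}\")\n"
  },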
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "## 7. Visualize the matched profiles\n\nThe matched pair should trace almost-identical curves; the same-index pair\nshould visibly diverge.\n"
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": "import matplotlib.pyplot as plt\n\nx = np.linspace(0, 1, NBINS)\nfig, axes = plt.subplots(1, 2, figsize=(11, 4), sharey=True)\n\naxes[0].plot(x, prof_std_L3H0, label=\"standard L3H0\", lw=2)\naxes[0].plot(x, prof_adhd_L3H5, label=\"ADHD L3H5\", lw=2)\naxes[0].set_title(f\"matched pair (cos = {matched:.3f})\")\naxes[0].set_xlabel(\"normalized query position\")\naxes[0].set_ylabel(\"attention mass \u2192 step-structure specials\")\naxes[0].legend(); axes[0].grid(alpha=0.3)\n\naxes[1].plot(x, prof_std_L3H0, label=\"standard L3H0\", lw=2)\naxes[1].plot(x, prof_adhd_L3H0, label=\"ADHD L3H0\", lw=2)\naxes[1].set_title(f\"same-index baseline (cos = {baseline1:.3f})\")\naxes[1].set_xlabel(\"normalized query position\")\naxes[1].legend(); axes[1].grid(alpha=0.3)\n\nplt.tight_layout(); plt.show()\n"
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": "# Side-by-side (seq, seq) attention heatmaps on a single paired prompt.\nr = records[0]\nimport matplotlib.pyplot as plt\nfig, axes = plt.subplots(1, 2, figsize=(10, 4.5))\naxes[0].imshow(r[\"std_attn\"][3, 0], aspect=\"auto\", cmap=\"viridis\")\naxes[0].set_title(\"standard L3H0 \u2014 step-layout broadcast\")\naxes[1].imshow(r[\"adhd_attn\"][3, 5], aspect=\"auto\", cmap=\"viridis\")\naxes[1].set_title(\"ADHD L3H5 \u2014 same function, different index\")\nfor ax in axes:\n    ax.set_xlabel(\"key position\")\naxes[0].set_ylabel(\"query position\")\nplt.tight_layout(); plt.show()\n"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "## 8. Short ablation demo\n\nJust an illustration; the full 5-seed multi-prompt causal ablation (which drops\nSpearman(task_complexity \u00d7 step_count) 0.83 \u2192 0.78 in the ADHD model, median \u0394\n= -0.055) lives in `phase4_ablation_multiseed.py`. Here we zero out the L3H5\nattention pattern on one prompt and show that the output distribution at\nstep-onset shifts. This is a sniff-test, not the load-bearing causal claim.\n"
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": "from transformer_lens.hook_points import HookPoint\n\ndef zero_head_hook(layer, head):\n    name = f\"blocks.{layer}.attn.hook_pattern\"\n    def hook(pattern, hook_point: HookPoint):\n        pattern = pattern.clone()\n        pattern[:, head] = 0.0\n        return pattern\n    return name, hook\n\nids = records[0][\"ids\"]\nx = torch.tensor([ids], dtype=torch.long, device=DEVICE)\n\nwith torch.no_grad():\n    base_logits = adhd_model(x, return_type=\"logits\")[0, -1]\n    name, hook  = zero_head_hook(3, 5)\n    ablated_logits = adhd_model.run_with_hooks(x, return_type=\"logits\",\n                                                fwd_hooks=[(name, hook)])[0, -1]\n\ndef top_tokens(logits, k=5):\n    probs = torch.softmax(logits.float(), dim=-1)\n    vals, idx = probs.topk(k)\n    return [(tok.id_to_token(int(i)), float(v)) for v, i in zip(vals, idx)]\n\nprint(\"baseline top-5:\", top_tokens(base_logits))\nprint(\"ablated top-5:\", top_tokens(ablated_logits))\n"
  },
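  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "Top-5 lists are easy to eyeball but hard to compare. As a slightly more\nquantitative sniff-test (added here; not part of `phase4_ablation_multiseed.py`),\ncompute the KL divergence between the baseline and ablated next-token\ndistributions.\n"
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": "import torch.nn.functional as F\n\n# KL(base || ablated) over the next-token distribution at the final position.\n# A clearly non-zero value means the ablated head's output mattered here; the\n# real multi-seed causal numbers live in phase4_ablation_multiseed.py.\nlog_p = F.log_softmax(base_logits.float(), dim=-1)\nlog_q = F.log_softmax(ablated_logits.float(), dim=-1)\nkl = torch.sum(log_p.exp() * (log_p - log_q)).item()\nprint(f\"KL(base || ablated) at final position: {kl:.4f} nats\")\n"
  },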
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "## 9. Where to go next\n\n- Full 48-head swap inventory: see AGENT-2's `migration_map.md` + swap heatmap\n  figure in the Phase 2 results bundle.\n- SAE features + null-steering feature 2504: see the companion\n  [`connaaa/interpgpt-sae-phase5`](https://huggingface.co/connaaa/interpgpt-sae-phase5)\n  repo and the `phase5_*.py` scripts.\n- Phase 1 writeup: `interpgpt-writeup-draft.md` in the main repo.\n"
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  },
  "colab": {
   "provenance": []
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}