Spaces:
Running
Running
File size: 10,294 Bytes
33a0021 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 | import torch
import torch.nn as nn
from typing import Dict, List, Set, Tuple, Union, Optional
class CircuitSurgeon:
    """
    Manages interactive node and path ablations on the Decision Transformer.

    Enables zeroing out specific components (Heads, MLPs) or cutting
    specific communication paths between components.

    Nodes are named ``"L<layer>H<head>"`` for an attention head and
    ``"L<layer>MLP"`` for an MLP sublayer; edges are ``(source, destination)``
    pairs of node names. Hooks are addressed by names such as
    ``blocks.<l>.attn.hook_result`` and installed through
    ``model.transformer.hooks(fwd_hooks=...)``; presumably
    ``model.transformer`` is a TransformerLens-style hooked transformer
    (TODO confirm).
    """

    def __init__(self, model):
        self.model = model
        # Ablated components, e.g. "L0H1" or "L1MLP".
        self.ablated_nodes: Set[str] = set()
        # Ablated paths as (source_node, destination_node), e.g. ("L0H1", "L2H2").
        self.ablated_edges: Set[Tuple[str, str]] = set()

    def add_node_ablation(self, node: str):
        """Add a component node (e.g. 'L0H1' or 'L1MLP') to the ablated set."""
        self.ablated_nodes.add(node)

    def remove_node_ablation(self, node: str):
        """Remove a component node from the ablated set (no-op if absent)."""
        self.ablated_nodes.discard(node)

    def add_edge_ablation(self, from_node: str, to_node: str):
        """Add the communication path ``from_node -> to_node`` to the ablated set."""
        self.ablated_edges.add((from_node, to_node))

    def remove_edge_ablation(self, from_node: str, to_node: str):
        """Remove a communication path from the ablated set (no-op if absent)."""
        self.ablated_edges.discard((from_node, to_node))

    def clear_ablations(self):
        """Clear all registered node and edge ablations."""
        self.ablated_nodes.clear()
        self.ablated_edges.clear()

    def parse_node(self, node: str) -> Tuple[int, Optional[int]]:
        """Parse a node name into ``(layer, head)``; ``head`` is None for MLPs.

        Raises:
            ValueError: if ``node`` is not of the form ``L<layer>H<head>``
                or ``L<layer>MLP``.
        """
        try:
            if "MLP" in node:
                return int(node.replace("MLP", "").replace("L", "")), None
            # Unpacking fails (ValueError) when there is no head index.
            layer_part, head_part = node.split("H", 1)
            return int(layer_part.replace("L", "")), int(head_part)
        except ValueError as err:
            raise ValueError(
                f"Malformed node name {node!r}; expected 'L<layer>H<head>' or 'L<layer>MLP'"
            ) from err

    @staticmethod
    def _source_key(layer: int, head: Optional[int]) -> str:
        """Return the hook/cache name holding a component's output."""
        if head is None:
            return f"blocks.{layer}.hook_mlp_out"
        return f"blocks.{layer}.attn.hook_result"

    @staticmethod
    def _feeds_into(src_layer: int, src_head: Optional[int],
                    dst_layer: int, dst_head: Optional[int]) -> bool:
        """Return True if the source component's output can reach the destination's input.

        Any component of an earlier layer is upstream. Within one layer, only
        an attention head can feed the MLP. This assumes a sequential
        attention-then-MLP block (TODO confirm: not true for parallel blocks).
        """
        if src_layer < dst_layer:
            return True
        return src_layer == dst_layer and src_head is not None and dst_head is None

    def _resolve_sources(self, from_nodes: List[str],
                         baseline_cache: Dict[str, torch.Tensor]) -> List[torch.Tensor]:
        """Look up the cached baseline output of each source node.

        Head outputs are sliced to that head's per-head result. Sources whose
        hook is absent from the cache are skipped (best-effort, as before).
        """
        outputs = []
        for from_node in from_nodes:
            layer, head = self.parse_node(from_node)
            key = self._source_key(layer, head)
            if key not in baseline_cache:
                continue
            src_out = baseline_cache[key]
            if head is not None:
                src_out = src_out[:, :, head, :]
            outputs.append(src_out)
        return outputs

    @staticmethod
    def _make_head_zero_hook(heads: List[int]):
        """Build a hook zeroing the per-head results of ``heads``."""
        def attn_hook(value, hook):
            for h in heads:
                value[:, :, h, :] = 0.0
            return value
        return attn_hook

    @staticmethod
    def _make_mlp_zero_hook():
        """Build a hook zeroing an MLP's output."""
        def mlp_hook(value, hook):
            value[:, :, :] = 0.0
            return value
        return mlp_hook

    def _make_path_attn_hook(self, layer: int, head: int, weight_name: str,
                             sources: List[torch.Tensor]):
        """Build a hook removing ``sources`` from one head's Q, K or V.

        ``weight_name`` is the projection attribute on ``attn`` ("W_Q", "W_K"
        or "W_V").
        """
        def hook_fn(value, hook):
            block = self.model.transformer.blocks[layer]
            weight = getattr(block.attn, weight_name)[head]
            for src_out in sources:
                # NOTE(review): LayerNorm is non-linear. ln1 here normalises the
                # isolated component by its own scale and adds ln1's bias, so
                # this only approximates the component's true contribution.
                # If ln1 contains hook points, calling it here may also re-fire
                # them during the ablated pass. TODO confirm.
                value[:, :, head, :] -= block.ln1(src_out) @ weight
            return value
        return hook_fn

    @staticmethod
    def _make_path_mlp_hook(sources: List[torch.Tensor]):
        """Build a hook removing ``sources`` from an MLP's input.

        ``hook_mlp_in`` is assumed to carry the residual stream *before* ln2
        (TransformerLens convention, TODO confirm). Subtracting the raw
        source output therefore removes it exactly, and ln2 then
        re-normalises the patched residual.
        """
        def mlp_in_hook(value, hook):
            for src_out in sources:
                value -= src_out
            return value
        return mlp_in_hook

    def get_ablation_hooks(self, baseline_cache: Dict[str, torch.Tensor]) -> List[Tuple[str, callable]]:
        """
        Generate forward hooks for the registered node and edge ablations.

        Args:
            baseline_cache: activations from an un-ablated run, keyed by hook
                name. Only edge ablations read it.

        Returns:
            A list of ``(hook_name, hook_fn)`` pairs. Each ``hook_fn(value, hook)``
            edits ``value`` in place and returns it.

        Edges are skipped when either endpoint is already node-ablated
        (subsumed), or when the source cannot reach the destination.
        Destinations whose sources are all missing from the cache get no hook.
        """
        hooks = []

        # 1. Node ablations, grouped by layer so each layer gets one hook.
        heads_by_layer: Dict[int, List[int]] = {}
        mlp_layers: Set[int] = set()
        for node in self.ablated_nodes:
            layer, head = self.parse_node(node)
            if head is None:
                mlp_layers.add(layer)
            else:
                heads_by_layer.setdefault(layer, []).append(head)

        for layer, heads in heads_by_layer.items():
            hooks.append((f"blocks.{layer}.attn.hook_result", self._make_head_zero_hook(heads)))
        for layer in mlp_layers:
            hooks.append((f"blocks.{layer}.hook_mlp_out", self._make_mlp_zero_hook()))

        # 2. Edge ablations, grouped by destination to avoid redundant hooks.
        sources_by_dest: Dict[str, List[str]] = {}
        for from_node, to_node in self.ablated_edges:
            if from_node in self.ablated_nodes or to_node in self.ablated_nodes:
                continue
            src_layer, src_head = self.parse_node(from_node)
            dst_layer, dst_head = self.parse_node(to_node)
            # Subtracting a contribution the destination never received would
            # corrupt its activations, so drop non-causal edges.
            if not self._feeds_into(src_layer, src_head, dst_layer, dst_head):
                continue
            sources_by_dest.setdefault(to_node, []).append(from_node)

        for to_node, from_nodes in sources_by_dest.items():
            dst_layer, dst_head = self.parse_node(to_node)
            sources = self._resolve_sources(from_nodes, baseline_cache)
            if not sources:
                continue
            if dst_head is None:
                hooks.append((f"blocks.{dst_layer}.hook_mlp_in", self._make_path_mlp_hook(sources)))
            else:
                for proj in ("q", "k", "v"):
                    hooks.append((
                        f"blocks.{dst_layer}.attn.hook_{proj}",
                        self._make_path_attn_hook(dst_layer, dst_head, f"W_{proj.upper()}", sources),
                    ))
        return hooks

    def compute_ablated_forward(
        self,
        states: torch.Tensor,
        actions: torch.Tensor,
        returns_to_go: torch.Tensor,
        return_cache: bool = False
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
        """
        Run the model with all registered ablations applied.

        A baseline (un-ablated) pass is run only when edge ablations are
        registered, because only they need cached source activations.

        Args:
            states, actions, returns_to_go: forwarded unchanged to ``self.model``.
            return_cache: if True, also return the cache of the ablated run.

        Returns:
            The model's predictions, or ``(predictions, cache)`` when
            ``return_cache`` is True.
        """
        baseline_cache: Dict[str, torch.Tensor] = {}
        if self.ablated_edges:
            _, baseline_cache = self.model(states, actions, returns_to_go, return_cache=True)

        hooks = self.get_ablation_hooks(baseline_cache)

        def run():
            if return_cache:
                preds, cache = self.model(states, actions, returns_to_go, return_cache=True)
                return preds, cache
            return self.model(states, actions, returns_to_go)

        if not hooks:
            return run()
        # Hooks are installed only for the duration of the ablated pass.
        with self.model.transformer.hooks(fwd_hooks=hooks):
            return run()
|