File size: 10,294 Bytes
33a0021
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
import warnings
from typing import Callable, Dict, List, Optional, Set, Tuple, Union

import torch
import torch.nn as nn

class CircuitSurgeon:
    """
    Manages interactive node and path ablations on the Decision Transformer.

    Enables zeroing out specific components (attention heads, MLPs) or cutting
    specific communication paths (edges) between components.

    Node names are ``"L<layer>H<head>"`` for an attention head and
    ``"L<layer>MLP"`` for an MLP layer. Hook names, ``attn.W_Q[head]`` weight
    indexing and ``transformer.hooks(fwd_hooks=...)`` follow TransformerLens
    conventions; presumably ``model.transformer`` is a HookedTransformer --
    TODO confirm.
    """

    def __init__(self, model):
        self.model = model
        # Ablated components, e.g. {"L0H1", "L1MLP"}.
        self.ablated_nodes: Set[str] = set()
        # Ablated (source, destination) paths, e.g. {("L0H1", "L2H2")}.
        self.ablated_edges: Set[Tuple[str, str]] = set()

    # ------------------------------------------------------------------
    # Registration
    # ------------------------------------------------------------------
    def add_node_ablation(self, node: str) -> None:
        """Add a component node (e.g. 'L0H1' or 'L1MLP') to the ablated set."""
        self.ablated_nodes.add(node)

    def remove_node_ablation(self, node: str) -> None:
        """Remove a component node from the ablated set (no-op if absent)."""
        self.ablated_nodes.discard(node)

    def add_edge_ablation(self, from_node: str, to_node: str) -> None:
        """Add the communication path ``from_node -> to_node`` to the ablated set."""
        self.ablated_edges.add((from_node, to_node))

    def remove_edge_ablation(self, from_node: str, to_node: str) -> None:
        """Remove a communication path from the ablated set (no-op if absent)."""
        self.ablated_edges.discard((from_node, to_node))

    def clear_ablations(self) -> None:
        """Clear all registered node and edge ablations."""
        self.ablated_nodes.clear()
        self.ablated_edges.clear()

    def parse_node(self, node: str) -> Tuple[int, Optional[int]]:
        """
        Parse a node name into ``(layer, head)``; ``head`` is None for an MLP.

        Raises:
            ValueError: if the layer or head part is not an integer.
        """
        if "MLP" in node:
            layer = int(node.replace("MLP", "").replace("L", ""))
            return layer, None
        parts = node.split("H")
        layer = int(parts[0].replace("L", ""))
        head = int(parts[1])
        return layer, head

    # ------------------------------------------------------------------
    # Hook construction
    # ------------------------------------------------------------------
    def get_ablation_hooks(
        self, baseline_cache: Dict[str, torch.Tensor]
    ) -> List[Tuple[str, Callable]]:
        """
        Build ``(hook_name, hook_fn)`` pairs for all registered ablations.

        Args:
            baseline_cache: activations of a baseline (un-ablated) pass, keyed
                by hook name. Only consulted for edge ablations; the outputs of
                cut source nodes are read from it.

        Returns:
            Node-ablation hooks followed by edge-ablation hooks.
        """
        return self._node_hooks() + self._edge_hooks(baseline_cache)

    def _node_hooks(self) -> List[Tuple[str, Callable]]:
        """Hooks that zero the output of every ablated node, grouped by layer."""
        heads_by_layer: Dict[int, List[int]] = {}
        mlp_layers: Set[int] = set()
        for node in self.ablated_nodes:
            layer, head = self.parse_node(node)
            if head is None:
                mlp_layers.add(layer)
            else:
                heads_by_layer.setdefault(layer, []).append(head)

        def zero_heads(heads: List[int]) -> Callable:
            def hook_fn(value, hook):
                # Indexing assumes hook_result is [batch, pos, head, d_model].
                for h in heads:
                    value[:, :, h, :] = 0.0
                return value
            return hook_fn

        def zero_mlp(value, hook):
            value[:, :, :] = 0.0
            return value

        hooks = [
            (f"blocks.{layer}.attn.hook_result", zero_heads(heads))
            for layer, heads in heads_by_layer.items()
        ]
        hooks += [(f"blocks.{layer}.hook_mlp_out", zero_mlp) for layer in mlp_layers]
        return hooks

    @staticmethod
    def _source_hook_name(layer: int, head: Optional[int]) -> str:
        """Name of the hook point whose activation is the node's output."""
        if head is None:
            return f"blocks.{layer}.hook_mlp_out"
        return f"blocks.{layer}.attn.hook_result"

    @staticmethod
    def _path_exists(
        src_layer: int, src_head: Optional[int], dst_layer: int, dst_head: Optional[int]
    ) -> bool:
        """Whether the source's output is part of the destination's input."""
        if dst_head is not None:
            # Attention reads the residual stream before its own layer's outputs.
            return src_layer < dst_layer
        if src_head is not None:
            # The MLP reads the residual stream after its own layer's attention.
            # (Not true for parallel attention/MLP blocks -- TODO confirm.)
            return src_layer <= dst_layer
        return src_layer < dst_layer

    def _edge_hooks(self, baseline_cache: Dict[str, torch.Tensor]) -> List[Tuple[str, Callable]]:
        """Hooks that remove each cut source's cached output from its destination's input."""
        # Removal is linear (LayerNorm scale is frozen), so the outputs of all
        # cut sources feeding one destination are summed and removed at once.
        removed_by_dest: Dict[str, torch.Tensor] = {}
        for from_node, to_node in self.ablated_edges:
            # Subsumed by a node ablation of either endpoint.
            if from_node in self.ablated_nodes or to_node in self.ablated_nodes:
                continue
            src_layer, src_head = self.parse_node(from_node)
            dst_layer, dst_head = self.parse_node(to_node)
            if not self._path_exists(src_layer, src_head, dst_layer, dst_head):
                # The destination never reads this source: cutting is a no-op.
                continue
            src_key = self._source_hook_name(src_layer, src_head)
            if src_key not in baseline_cache:
                warnings.warn(
                    f"Edge {from_node}->{to_node} not ablated: '{src_key}' is missing "
                    "from the baseline cache (is use_attn_result enabled?).",
                    stacklevel=3,
                )
                continue
            src_out = baseline_cache[src_key]
            if src_head is not None:
                src_out = src_out[:, :, src_head, :]
            previous = removed_by_dest.get(to_node)
            # Out-of-place add so cached tensors are never mutated.
            removed_by_dest[to_node] = src_out if previous is None else previous + src_out

        hooks: List[Tuple[str, Callable]] = []
        for to_node, removed in removed_by_dest.items():
            dst_layer, dst_head = self.parse_node(to_node)
            if dst_head is None:
                hooks.append((f"blocks.{dst_layer}.hook_mlp_in", self._make_mlp_in_hook(removed)))
            else:
                hooks.extend(
                    self._attn_input_hooks(dst_layer, dst_head, removed, baseline_cache, to_node)
                )
        return hooks

    @staticmethod
    def _make_mlp_in_hook(removed: torch.Tensor) -> Callable:
        """Hook that subtracts ``removed`` from the MLP's (pre-ln2) input."""
        def hook_fn(value, hook):
            # In TransformerLens hook_mlp_in fires *before* ln2, so the raw
            # source output is removed and ln2 then renormalizes as usual.
            return value - removed
        return hook_fn

    def _attn_input_hooks(
        self,
        layer: int,
        head: int,
        removed: torch.Tensor,
        baseline_cache: Dict[str, torch.Tensor],
        to_node: str,
    ) -> List[Tuple[str, Callable]]:
        """Hooks removing ``removed``'s contribution from one head's Q, K and V."""
        block = self.model.transformer.blocks[layer]
        scale_key = f"blocks.{layer}.ln1.hook_scale"
        if scale_key in baseline_cache:
            scale = baseline_cache[scale_key]
            if scale.dim() == removed.dim() + 1:
                # Split-QKV models presumably cache a per-head scale
                # [batch, pos, head, 1] -- TODO confirm.
                scale = scale[:, :, head, :]
        elif isinstance(block.ln1, nn.Identity):
            scale = None  # no normalization before attention
        else:
            warnings.warn(
                f"Edges into {to_node} not ablated: '{scale_key}' is missing from the "
                "baseline cache, so ln1 cannot be linearized.",
                stacklevel=3,
            )
            return []

        normalized = self._frozen_layernorm(block.ln1, removed, scale)
        projections = (
            ("hook_q", block.attn.W_Q),
            ("hook_k", block.attn.W_K),
            ("hook_v", block.attn.W_V),
        )
        return [
            (f"blocks.{layer}.attn.{suffix}", self._make_head_subtract_hook(head, normalized @ weight[head]))
            for suffix, weight in projections
        ]

    @staticmethod
    def _frozen_layernorm(ln, x: torch.Tensor, scale: Optional[torch.Tensor]) -> torch.Tensor:
        """
        Linear part of ``ln`` applied to one residual-stream component ``x``.

        LayerNorm is not linear: calling ``ln(x)`` on a single component would
        normalize by that component's own scale and add the bias once per
        component. Instead the scale observed on the full residual stream is
        frozen and the bias is omitted. Assumes TransformerLens-style norms
        (LayerNorm/LayerNormPre center, RMSNorm does not, optional gain ``w``)
        -- TODO confirm for this model.
        """
        if scale is None:
            return x
        if "RMS" not in type(ln).__name__:
            x = x - x.mean(dim=-1, keepdim=True)
        x = x / scale
        gain = getattr(ln, "w", None)
        if isinstance(gain, torch.Tensor):
            x = x * gain
        return x

    @staticmethod
    def _make_head_subtract_hook(head: int, delta: torch.Tensor) -> Callable:
        """Hook that subtracts ``delta`` from one head's slice of a Q/K/V activation."""
        def hook_fn(value, hook):
            # Indexing assumes hook_q/k/v are [batch, pos, head, d_head].
            value[:, :, head, :] -= delta
            return value
        return hook_fn

    # ------------------------------------------------------------------
    # Forward
    # ------------------------------------------------------------------
    def compute_ablated_forward(
        self,
        states: torch.Tensor,
        actions: torch.Tensor,
        returns_to_go: torch.Tensor,
        return_cache: bool = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
        """
        Run the model with every registered ablation applied.

        A baseline pass that supplies source activations is run first, but
        only when some edge ablation is not already subsumed by a node
        ablation (node ablations need no baseline).

        Returns:
            The model's predictions, or ``(predictions, cache)`` of the
            ablated pass when ``return_cache`` is True.
        """
        needs_baseline = any(
            src not in self.ablated_nodes and dst not in self.ablated_nodes
            for src, dst in self.ablated_edges
        )
        baseline_cache: Dict[str, torch.Tensor] = {}
        if needs_baseline:
            _, baseline_cache = self.model(states, actions, returns_to_go, return_cache=True)

        hooks = self.get_ablation_hooks(baseline_cache)
        if not hooks:
            if return_cache:
                return self.model(states, actions, returns_to_go, return_cache=True)
            return self.model(states, actions, returns_to_go)

        # Hooks are registered only for the duration of the ablated pass.
        with self.model.transformer.hooks(fwd_hooks=hooks):
            if return_cache:
                preds, cache = self.model(states, actions, returns_to_go, return_cache=True)
                return preds, cache
            return self.model(states, actions, returns_to_go)