File size: 1,555 Bytes
11dbbc6
 
 
 
 
8577352
 
11dbbc6
 
 
 
 
 
8577352
11dbbc6
 
 
 
8577352
11dbbc6
 
 
 
 
 
 
8577352
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import torch
from typing import Dict, Optional, Tuple

class PathPatchingEngine:
    """
    Engine for path-based causal interventions.
    Helps isolate which internal paths are necessary for a decision.

    The wrapped ``model`` must be callable with keyword inputs and expose
    ``model.transformer.hooks(fwd_hooks=[(name, fn), ...])`` as a context
    manager that installs forward hooks for the duration of the ``with``
    block. Hook functions are called as ``fn(value, hook)`` and their return
    value is handed back to the hooking machinery.
    NOTE(review): this looks like a TransformerLens-style hooked transformer;
    confirm against the model wrapper actually used.
    """

    # Ablation modes perform_edge_ablation knows how to apply. Anything else
    # is rejected instead of silently running an un-ablated forward pass.
    _SUPPORTED_ABLATIONS = ("zero",)

    def __init__(self, model):
        self.model = model

    def perform_edge_ablation(
        self,
        inputs: dict,
        layer: int,
        head_index: int,
        ablation_type: str = "zero"
    ) -> torch.Tensor:
        """Zeroes out a specific head's output to check its causal necessity.

        Args:
            inputs: Keyword arguments forwarded to the model as ``**inputs``.
            layer: Block index used to build the hook name
                ``blocks.{layer}.attn.hook_result``.
            head_index: Index along axis 2 of the hooked activation to ablate.
            ablation_type: Ablation mode. Only ``"zero"`` is supported.

        Returns:
            Whatever the model returns for ``inputs`` with the hook installed.

        Raises:
            ValueError: If ``ablation_type`` is not a supported mode. Raised
                before any forward pass so an unsupported mode can never be
                mistaken for "ablation had no effect".
        """
        if ablation_type not in self._SUPPORTED_ABLATIONS:
            raise ValueError(
                f"Unsupported ablation_type {ablation_type!r}; "
                f"expected one of {self._SUPPORTED_ABLATIONS}"
            )

        def ablation_hook(value, hook):
            # The activation is modified in place and then returned.
            # assumes value is [batch, pos, head, d_model] -- TODO confirm
            value[:, :, head_index, :] = 0.0
            return value

        hook_name = f"blocks.{layer}.attn.hook_result"
        with self.model.transformer.hooks(fwd_hooks=[(hook_name, ablation_hook)]):
            preds = self.model(**inputs)
        return preds

    def patch_path(
        self,
        clean_inputs: dict,
        corrupted_cache: dict,
        layer: int,
        head: int
    ) -> torch.Tensor:
        """Patches a specific head's output with activations from a corrupted run.

        Args:
            clean_inputs: Keyword arguments forwarded to the model as
                ``**clean_inputs``.
            corrupted_cache: Mapping from hook name to the activation recorded
                at that hook; it is indexed with ``hook.name`` inside the hook.
            layer: Block index used to build the hook name
                ``blocks.{layer}.attn.hook_result``.
            head: Index along axis 2 of the activation to overwrite.

        Returns:
            Whatever the model returns for ``clean_inputs`` with the patch
            hook installed.

        Raises:
            KeyError: If ``corrupted_cache`` is a plain dict with no entry for
                the hook's name (raised from inside the hook during the
                forward pass).
        """
        def patch_hook(value, hook):
            # value: presumably [batch, pos, head, d_model] -- TODO confirm.
            # Only the selected head's slice is replaced, in place; every
            # other head keeps its clean-run activation.
            corrupted_val = corrupted_cache[hook.name]
            value[:, :, head, :] = corrupted_val[:, :, head, :]
            return value

        hook_name = f"blocks.{layer}.attn.hook_result"
        with self.model.transformer.hooks(fwd_hooks=[(hook_name, patch_hook)]):
            preds = self.model(**clean_inputs)
        return preds