File size: 1,467 Bytes
e2614dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4aa19e7
e2614dc
11dbbc6
e2614dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11dbbc6
e2614dc
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import torch
from typing import Callable, List, Optional
from transformer_lens import HookedTransformer

class ActivationPatcher:
    """
    Interface for causal interventions via activation patching.

    Runs the wrapped model on clean inputs while overwriting one attention
    head's output, at one token position, with the value recorded during a
    corrupted run, and measures how much the probability of a target action
    changes as a result.

    NOTE(review): the wrapped model must be callable with keyword inputs
    (``model(**clean_inputs)``) and expose
    ``model.transformer.hooks(fwd_hooks=...)`` as a context manager --
    presumably a module that holds a TransformerLens ``HookedTransformer``
    in its ``transformer`` attribute; confirm against the caller.
    """

    def __init__(self, model):
        # Stored unchanged; only `model(**inputs)` and
        # `model.transformer.hooks(...)` are used by this class.
        self.model = model

    def patch_head(
        self,
        clean_inputs: dict,
        corrupted_cache: dict,
        layer: int,
        head_index: int,
        target_token_index: int = -2
    ):
        """Patches head output with values from a corrupted run.

        Args:
            clean_inputs: Keyword arguments forwarded to the model call.
            corrupted_cache: Mapping from hook name to the activation
                recorded on the corrupted run; it is looked up with the
                name of the hook that fires.
            layer: Block index used to build the hook name
                ``blocks.{layer}.attn.hook_result``.
            head_index: Index along the head axis of the slice to overwrite.
            target_token_index: Index along the position axis of the slice
                to overwrite (default ``-2``).

        Returns:
            Whatever the wrapped model returns for ``clean_inputs`` while
            the patch hook is active.

        Raises:
            KeyError: If ``corrupted_cache`` is a plain dict without an
                entry for the hook name.
        """
        def patch_hook(value, hook):
            # value is indexed as [batch, pos, head, d_model]
            corrupted_value = corrupted_cache[hook.name]
            # Overwrite only the selected (position, head) slice, in place;
            # every other position and head keeps its clean activation.
            value[:, target_token_index, head_index, :] = corrupted_value[:, target_token_index, head_index, :]
            return value

        # NOTE(review): in TransformerLens the per-head `hook_result`
        # activation is only computed when `cfg.use_attn_result` is enabled;
        # if it is off, the hook may never fire and the run is silently
        # unpatched -- confirm the wrapped model enables it.
        hook_name = f"blocks.{layer}.attn.hook_result"

        # The hook is registered only for the duration of this forward pass.
        with self.model.transformer.hooks(fwd_hooks=[(hook_name, patch_hook)]):
            patched_outputs = self.model(**clean_inputs)

        return patched_outputs

    def calculate_probability_drop(
        self,
        clean_probs: torch.Tensor,
        patched_probs: torch.Tensor,
        correct_action_index: int,
        *,
        batch_index: int = 0,
        token_index: int = -1
    ) -> float:
        """Calculates impact of patching on target action probability.

        Both tensors are indexed as
        ``[batch_index, token_index, correct_action_index]``, so they must
        have three dimensions.

        Args:
            clean_probs: Probabilities from the unpatched run.
            patched_probs: Probabilities from the patched run.
            correct_action_index: Index of the target action on the last axis.
            batch_index: Batch row to read. Defaults to ``0``, the row this
                method previously always read.
            token_index: Position to read. Defaults to ``-1``, the position
                this method previously always read.

        Returns:
            ``clean - patched`` as a Python float; positive when patching
            lowered the target action's probability.
        """
        clean_val = clean_probs[batch_index, token_index, correct_action_index].item()
        patched_val = patched_probs[batch_index, token_index, correct_action_index].item()
        return clean_val - patched_val