import torch from typing import Callable, List, Optional from transformer_lens import HookedTransformer class ActivationPatcher: """ Interface for causal interventions via activation patching. """ def __init__(self, model): self.model = model def patch_head( self, clean_inputs: dict, corrupted_cache: dict, layer: int, head_index: int, target_token_index: int = -2 ): """Patches head output with values from a corrupted run.""" def patch_hook(value, hook): # value: [batch, pos, head, d_model] corrupted_value = corrupted_cache[hook.name] value[:, target_token_index, head_index, :] = corrupted_value[:, target_token_index, head_index, :] return value hook_name = f"blocks.{layer}.attn.hook_result" with self.model.transformer.hooks(fwd_hooks=[(hook_name, patch_hook)]): patched_outputs = self.model(**clean_inputs) return patched_outputs def calculate_probability_drop( self, clean_probs: torch.Tensor, patched_probs: torch.Tensor, correct_action_index: int ) -> float: """Calculates impact of patching on target action probability.""" clean_val = clean_probs[0, -1, correct_action_index].item() patched_val = patched_probs[0, -1, correct_action_index].item() return clean_val - patched_val