import torch
from typing import Callable, List, Optional
from transformer_lens import HookedTransformer
class ActivationPatcher:
    """
    Helper for causal interventions via activation patching.

    Wraps a model and exposes one method that re-runs the model with a
    single attention head's output overwritten from a cached corrupted
    run, and one method that measures the resulting probability change.
    """

    def __init__(self, model):
        # The model is stored as-is; nothing is validated or copied here.
        self.model = model

    def patch_head(
        self,
        clean_inputs: dict,
        corrupted_cache: dict,
        layer: int,
        head_index: int,
        target_token_index: int = -2
    ):
        """Run the model on clean inputs with one head patched from a corrupted run.

        Args:
            clean_inputs: Keyword arguments forwarded to ``self.model(**clean_inputs)``.
            corrupted_cache: Mapping from hook name to the activation recorded
                on the corrupted run; looked up with the hook's own name.
            layer: Block index used to build the hook name
                ``blocks.{layer}.attn.hook_result``.
            head_index: Index of the head to overwrite (third axis).
            target_token_index: Position to overwrite (second axis); the
                default ``-2`` selects the second-to-last position.

        Returns:
            Whatever ``self.model`` returns for the patched forward pass.
        """
        hook_point = f"blocks.{layer}.attn.hook_result"
        # One index tuple shared by source and destination so both sides
        # always address the same (position, head) slice.
        # Activation layout is assumed to be [batch, pos, head, d_model]
        # -- TODO confirm against the model's hook_result shape.
        selector = (slice(None), target_token_index, head_index, slice(None))

        def overwrite_from_corrupted(activation, hook):
            donor = corrupted_cache[hook.name]
            # In-place write: the activation tensor handed to the hook is
            # modified and then returned.
            activation[selector] = donor[selector]
            return activation

        # The hook is only active inside this context manager.
        with self.model.transformer.hooks(
            fwd_hooks=[(hook_point, overwrite_from_corrupted)]
        ):
            return self.model(**clean_inputs)

    def calculate_probability_drop(
        self,
        clean_probs: torch.Tensor,
        patched_probs: torch.Tensor,
        correct_action_index: int
    ) -> float:
        """Return how much the target action's probability fell after patching.

        Both tensors are read at batch element 0 and the last position
        (index ``[0, -1, correct_action_index]``). A positive result means
        the patched run assigned less probability than the clean run.
        """
        before, after = (
            probs[0, -1, correct_action_index].item()
            for probs in (clean_probs, patched_probs)
        )
        return before - after
|