Spaces:
Running
Running
File size: 1,555 Bytes
import torch
from typing import Dict, Optional, Tuple
class PathPatchingEngine:
    """
    Engine for path-based causal interventions.

    Helps isolate which internal paths are necessary for a decision by
    running the wrapped model while a forward hook rewrites one attention
    head's slice of a layer's ``hook_result`` activation.

    The wrapped model must be callable with keyword inputs and must expose
    ``model.transformer.hooks(fwd_hooks=[(name, fn), ...])`` as a context
    manager; both are used exactly that way below.
    """

    # Ablation modes understood by ``perform_edge_ablation``.
    _SUPPORTED_ABLATIONS = ("zero",)

    def __init__(self, model):
        """Store the model whose forward pass will be intervened on."""
        self.model = model

    @staticmethod
    def _result_hook_name(layer: int) -> str:
        """Return the hook-point name of the attention result for ``layer``."""
        return f"blocks.{layer}.attn.hook_result"

    def _run_with_hook(self, inputs: dict, hook_name: str, hook_fn) -> torch.Tensor:
        """Run the model on ``inputs`` with ``hook_fn`` installed at ``hook_name``."""
        with self.model.transformer.hooks(fwd_hooks=[(hook_name, hook_fn)]):
            preds = self.model(**inputs)
        return preds

    def perform_edge_ablation(
        self,
        inputs: dict,
        layer: int,
        head_index: int,
        ablation_type: str = "zero"
    ) -> torch.Tensor:
        """Zero out a specific head's output to check its causal necessity.

        Args:
            inputs: Keyword arguments forwarded to the model call.
            layer: Index of the block whose attention result is hooked.
            head_index: Index along the third axis of the hooked activation.
            ablation_type: Ablation mode; only ``"zero"`` is supported.

        Returns:
            Whatever the model returns for ``inputs`` with the hook active.

        Raises:
            ValueError: If ``ablation_type`` is not a supported mode.
        """
        # Fail fast: an unrecognised mode used to install a hook that changed
        # nothing, so the "ablated" predictions were silently the clean ones.
        if ablation_type not in self._SUPPORTED_ABLATIONS:
            raise ValueError(
                f"Unsupported ablation_type {ablation_type!r}; "
                f"expected one of {self._SUPPORTED_ABLATIONS}"
            )

        def ablation_hook(value, hook):
            # Assumes value is laid out as [batch, pos, head, d_model], as in
            # patch_path -- TODO confirm. The activation is edited in place
            # and the same tensor is handed back to the hook machinery.
            value[:, :, head_index, :] = 0.0
            return value

        return self._run_with_hook(
            inputs, self._result_hook_name(layer), ablation_hook
        )

    def patch_path(
        self,
        clean_inputs: dict,
        corrupted_cache: dict,
        layer: int,
        head: int
    ) -> torch.Tensor:
        """Patch a specific head's output with activations from a corrupted run.

        Args:
            clean_inputs: Keyword arguments forwarded to the model call.
            corrupted_cache: Mapping from hook name to the activation recorded
                on the corrupted run; it is indexed with the hook's ``name``.
            layer: Index of the block whose attention result is hooked.
            head: Index along the third axis of the hooked activation.

        Returns:
            Whatever the model returns for ``clean_inputs`` with the hook active.

        Raises:
            KeyError: If ``corrupted_cache`` is a dict without an entry for
                the hooked activation (raised from inside the hook).
        """
        def patch_hook(value, hook):
            # value: [batch, pos, head, d_model]
            corrupted_val = corrupted_cache[hook.name]
            # Only the selected head is overwritten; the other heads keep
            # their clean-run activations.
            value[:, :, head, :] = corrupted_val[:, :, head, :]
            return value

        return self._run_with_hook(
            clean_inputs, self._result_hook_name(layer), patch_hook
        )
|