# DT-Explorer: src/interpretability/path_patching.py
# Author: sadhumitha-s
# Commit 8577352: feat: implement NLA explainer and universality probe and refactor path patching engine
import torch
from typing import Dict, Optional, Tuple
class PathPatchingEngine:
    """
    Engine for path-based causal interventions.

    Helps isolate which internal paths are necessary for a decision by
    editing one attention head's slice of the per-head attention output
    (the ``blocks.{layer}.attn.hook_result`` hook point) during a forward pass.
    """

    # Ablation strategies accepted by ``perform_edge_ablation``.
    SUPPORTED_ABLATION_TYPES = ("zero",)

    def __init__(self, model):
        # NOTE(review): ``model`` must expose ``model.transformer.hooks(fwd_hooks=...)``
        # as a context manager and be callable with keyword inputs. The hook-point
        # naming suggests a TransformerLens HookedTransformer underneath — confirm.
        self.model = model

    @staticmethod
    def _result_hook_name(layer: int) -> str:
        """Return the hook-point name carrying per-head attention outputs of ``layer``."""
        return f"blocks.{layer}.attn.hook_result"

    def _run_with_hook(self, inputs: dict, hook_name: str, hook_fn) -> torch.Tensor:
        """Run ``self.model(**inputs)`` with ``hook_fn`` attached at ``hook_name``.

        The hook is only active inside the ``hooks`` context, so it is removed
        again once the forward pass finishes (or raises).
        """
        with self.model.transformer.hooks(fwd_hooks=[(hook_name, hook_fn)]):
            return self.model(**inputs)

    def perform_edge_ablation(
        self,
        inputs: dict,
        layer: int,
        head_index: int,
        ablation_type: str = "zero",
    ) -> torch.Tensor:
        """Zero out a specific head's output to check its causal necessity.

        Despite the name, this ablates the head's whole output at
        ``hook_result``, i.e. every downstream path leaving that head.

        Args:
            inputs: Keyword arguments forwarded to ``self.model``.
            layer: Index of the block containing the head.
            head_index: Index of the attention head to ablate.
            ablation_type: Ablation strategy; one of ``SUPPORTED_ABLATION_TYPES``.

        Returns:
            The model output of the ablated forward pass.

        Raises:
            ValueError: If ``ablation_type`` is not supported (instead of
                silently running an un-ablated forward pass).
        """
        if ablation_type not in self.SUPPORTED_ABLATION_TYPES:
            raise ValueError(
                f"Unsupported ablation_type {ablation_type!r}; "
                f"expected one of {self.SUPPORTED_ABLATION_TYPES}"
            )

        def ablation_hook(value, hook):
            # Assumes value is [batch, pos, head, d_model] (the layout noted
            # for this hook point in patch_path) — head axis is dim 2.
            value[:, :, head_index, :] = 0.0
            return value

        return self._run_with_hook(inputs, self._result_hook_name(layer), ablation_hook)

    def patch_path(
        self,
        clean_inputs: dict,
        corrupted_cache: dict,
        layer: int,
        head: int,
    ) -> torch.Tensor:
        """Patch a specific head's output with activations from a corrupted run.

        Args:
            clean_inputs: Keyword arguments forwarded to ``self.model`` for the
                clean forward pass.
            corrupted_cache: Mapping from hook-point name to cached activation
                from the corrupted run (a dict or anything indexable by name).
            layer: Index of the block containing the head.
            head: Index of the attention head whose output is patched.

        Returns:
            The model output of the clean forward pass with the head patched.

        Raises:
            KeyError: If ``corrupted_cache`` has no activation for the
                ``hook_result`` point of ``layer``.
        """
        hook_name = self._result_hook_name(layer)
        # Resolve the cached activation before running the model so a missing
        # entry fails fast with a clear message rather than mid-forward.
        try:
            corrupted_val = corrupted_cache[hook_name]
        except KeyError as err:
            raise KeyError(
                f"corrupted_cache has no activation for {hook_name!r}"
            ) from err

        def patch_hook(value, hook):
            # value: [batch, pos, head, d_model]
            # Match the clean activation's device/dtype so a cache stored
            # elsewhere (e.g. on CPU or in another precision) still patches.
            value[:, :, head, :] = corrupted_val[:, :, head, :].to(
                device=value.device, dtype=value.dtype
            )
            return value

        return self._run_with_hook(clean_inputs, hook_name, patch_hook)