# sadhumitha-s's picture
# feat: implement NLA explainer and universality probe and refactor path patching engine
# 8577352
import torch
import json
from typing import Dict, List, Callable, Optional, Tuple
from tqdm import tqdm
class ACDCDiscovery:
    """Greedy, head-level circuit discovery in the spirit of ACDC.

    ACDC stands for Automatic Circuit DisCovery. This class visits every
    attention head ``(layer, head)`` once, zero-ablates it together with the
    heads already pruned, and keeps it pruned when the target-action logit
    stays within ``threshold`` of the un-ablated baseline. Heads that could
    not be pruned form the discovered circuit.

    The wrapped ``model`` is used through these members only:

    * ``model.cfg.n_layers`` and ``model.cfg.n_heads``
    * ``model(**inputs)``, whose result is indexed as
      ``[0, -1, target_action]``
    * ``model.transformer.hooks(fwd_hooks=...)`` as a context manager, with
      hook points named ``blocks.{layer}.attn.hook_result``
    """

    def __init__(self, model, threshold: float = 0.1):
        """Store the model and the pruning tolerance.

        Args:
            model: Model exposing the members listed in the class docstring.
            threshold: Largest absolute change of the metric (relative to the
                baseline) for which a head is still considered prunable.
        """
        self.model = model
        self.threshold = threshold
        # Result of the most recent ``run``; empty until ``run`` is called.
        self.current_circuit = {}

    def get_metric(self, action_preds: torch.Tensor, target_action: int) -> float:
        """Return the target action's logit at the last timestep.

        Only batch element 0 is read.

        Args:
            action_preds: Tensor indexable as ``[batch, timestep, action]``.
            target_action: Index of the action whose logit is the metric.

        Returns:
            The selected logit as a Python float.
        """
        return action_preds[0, -1, target_action].item()

    def run(self, inputs: dict, target_action: int) -> dict:
        """Greedily prune heads while the metric stays close to the baseline.

        Heads are visited in ``(layer, head)`` order. A head is pruned when
        ablating it together with all previously pruned heads changes the
        metric by less than ``self.threshold`` in absolute value; otherwise
        it stays active. The result therefore depends on the visiting order.

        Args:
            inputs: Keyword arguments forwarded to ``self.model``.
            target_action: Action index whose logit is tracked.

        Returns:
            A dict with keys ``"active_heads"`` (list of ``(layer, head)``
            tuples), ``"pruned_count"``, ``"initial_perf"`` and
            ``"final_perf"``. The same dict is stored in
            ``self.current_circuit``.
        """
        n_layers = self.model.cfg.n_layers
        n_heads = self.model.cfg.n_heads

        # Only scalar metrics are consumed, so autograd bookkeeping is waste.
        with torch.no_grad():
            baseline_preds = self.model(**inputs)
        initial_perf = self.get_metric(baseline_preds, target_action)

        all_heads = [
            (layer, head)
            for layer in range(n_layers)
            for head in range(n_heads)
        ]
        pruned_heads: List[Tuple[int, int]] = []
        active_heads: List[Tuple[int, int]] = []

        pbar = tqdm(all_heads, desc="ACDC Pruning")
        for candidate in pbar:
            # Evaluate the candidate on top of everything pruned so far.
            perf = self._eval_with_pruning(
                inputs, pruned_heads + [candidate], target_action
            )
            if abs(perf - initial_perf) < self.threshold:
                pruned_heads.append(candidate)
                pbar.set_postfix({"pruned": len(pruned_heads)})
            else:
                # Collected here to avoid a quadratic membership scan later.
                active_heads.append(candidate)

        self.current_circuit = {
            "active_heads": active_heads,
            "pruned_count": len(pruned_heads),
            "initial_perf": initial_perf,
            "final_perf": self._eval_with_pruning(
                inputs, pruned_heads, target_action
            ),
        }
        return self.current_circuit

    def _eval_with_pruning(
        self,
        inputs: dict,
        pruned_heads: List[Tuple[int, int]],
        target_action: int,
    ) -> float:
        """Return the metric with the given heads zero-ablated.

        Args:
            inputs: Keyword arguments forwarded to ``self.model``.
            pruned_heads: ``(layer, head)`` pairs whose output is zeroed.
            target_action: Action index whose logit is returned.

        Returns:
            The metric from ``get_metric`` under the ablation.
        """
        # Group once per evaluation so the hook does a dict lookup instead of
        # rescanning the whole pruned list on every layer.
        heads_by_layer: Dict[int, List[int]] = {}
        for p_layer, p_head in pruned_heads:
            heads_by_layer.setdefault(p_layer, []).append(p_head)

        def pruning_hook(value, hook):
            # Hook names look like "blocks.<layer>.attn.hook_result".
            layer_idx = int(hook.name.split(".")[1])
            layer_heads = heads_by_layer.get(layer_idx)
            if layer_heads:
                # Axis 2 is indexed as the head axis; the tensor is modified
                # in place and handed back to the hook machinery.
                value[:, :, layer_heads, :] = 0.0
            return value

        hooks = [
            (f"blocks.{layer}.attn.hook_result", pruning_hook)
            for layer in range(self.model.cfg.n_layers)
        ]
        # NOTE(review): presumably the underlying transformer must be
        # configured to materialise per-head results for these hook points
        # to fire - confirm against the model setup.
        with torch.no_grad(), self.model.transformer.hooks(fwd_hooks=hooks):
            preds = self.model(**inputs)
        return self.get_metric(preds, target_action)

    def save_manifest(self, path: str):
        """Save the discovered circuit to a JSON file.

        Active heads are written as strings of the form ``"L<layer>H<head>"``;
        ``self.current_circuit`` itself is left unmodified.

        Args:
            path: Destination file path; an existing file is overwritten.

        Raises:
            KeyError: If ``self.current_circuit`` has no ``"active_heads"``
                entry (for example because ``run`` has not been called). The
                destination file is not touched in that case.
        """
        # Build the payload before opening the file: opening in "w" mode
        # truncates immediately, so a failure here must happen first.
        data = self.current_circuit.copy()
        data["active_heads"] = [
            f"L{layer}H{head}" for layer, head in data["active_heads"]
        ]
        with open(path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=4)