import torch
import json
from typing import Dict, List, Callable, Optional, Tuple
from tqdm import tqdm


class ACDCDiscovery:
    """
    Greedy attention-head pruning, named after ACDC (Automatic Circuit DisCovery).

    Visits every (layer, head) pair once, in layer-major order, and permanently
    zero-ablates a head whenever doing so (together with all heads already
    pruned) keeps the target-action logit within ``threshold`` of its
    unablated baseline. The heads that survive form the discovered circuit.

    Because this is a single greedy pass over whole heads, the resulting circuit
    depends on the visiting order.
    """

    def __init__(self, model, threshold: float = 0.1):
        """
        Args:
            model: Object exposing ``cfg.n_layers`` / ``cfg.n_heads``, a
                ``transformer.hooks(fwd_hooks=...)`` context manager, and a
                forward pass ``model(**inputs)`` returning action logits.
                (Presumably a TransformerLens-backed wrapper — TODO confirm.)
            threshold: A head is pruned only if the absolute change in the
                target-action logit stays strictly below this value.
        """
        self.model = model
        self.threshold = threshold
        # Populated by run(); consumed by save_manifest().
        self.current_circuit: Dict = {}

    def get_metric(self, action_preds: torch.Tensor, target_action: int) -> float:
        """Return the logit of ``target_action`` at the last timestep of batch item 0.

        ``action_preds`` is indexed with three axes; presumably
        (batch, timestep, action) — TODO confirm. Only the first batch element
        contributes to the metric.
        """
        return action_preds[0, -1, target_action].item()

    def run(self, inputs: dict, target_action: int) -> dict:
        """Greedily prune heads while the metric stays within ``threshold``.

        Args:
            inputs: Keyword arguments forwarded verbatim to ``self.model``.
            target_action: Index of the action whose logit is tracked.

        Returns:
            The circuit dict (also stored on ``self.current_circuit``) with keys
            ``active_heads`` (``(layer, head)`` tuples that were kept),
            ``pruned_count``, ``initial_perf`` and ``final_perf``.
        """
        n_layers = self.model.cfg.n_layers
        n_heads = self.model.cfg.n_heads

        # Only a scalar metric is read, so no autograd graph is needed.
        with torch.no_grad():
            initial_preds = self.model(**inputs)
        initial_perf = self.get_metric(initial_preds, target_action)

        all_heads = [(l, h) for l in range(n_layers) for h in range(n_heads)]
        pruned_heads: List[Tuple[int, int]] = []

        with tqdm(all_heads, desc="ACDC Pruning") as pbar:
            for layer, head in pbar:
                # Ablate this head on top of every head already pruned.
                trial_pruned = pruned_heads + [(layer, head)]
                perf = self._eval_with_pruning(inputs, trial_pruned, target_action)

                # Keep the head pruned only if the metric barely moved.
                if abs(perf - initial_perf) < self.threshold:
                    pruned_heads.append((layer, head))
                    pbar.set_postfix({"pruned": len(pruned_heads)})

        pruned_set = set(pruned_heads)  # O(1) membership for the filter below
        active_heads = [h for h in all_heads if h not in pruned_set]

        self.current_circuit = {
            "active_heads": active_heads,
            "pruned_count": len(pruned_heads),
            "initial_perf": initial_perf,
            "final_perf": self._eval_with_pruning(inputs, pruned_heads, target_action),
        }
        return self.current_circuit

    def _eval_with_pruning(
        self, inputs: dict, pruned_heads: List[Tuple[int, int]], target_action: int
    ) -> float:
        """Return the metric with the given ``(layer, head)`` outputs zeroed.

        Each pruned head is ablated by zeroing its slice of
        ``blocks.{layer}.attn.hook_result`` during a forward pass.
        NOTE(review): in TransformerLens this hook only fires when
        ``cfg.use_attn_result`` is enabled — confirm the model sets it.
        """
        # Group heads by layer once, instead of rescanning the full list in
        # every hook call.
        heads_by_layer: Dict[int, List[int]] = {}
        for p_layer, p_head in pruned_heads:
            heads_by_layer.setdefault(p_layer, []).append(p_head)

        def make_hook(heads: List[int]) -> Callable:
            # Bind this layer's head list now, avoiding late-binding closures
            # and avoiding parsing the layer index out of hook.name.
            def pruning_hook(value, hook):
                # Head index is on dim 2 (TransformerLens layout is presumably
                # [batch, pos, head_index, d_model]).
                value[:, :, heads, :] = 0.0
                return value

            return pruning_hook

        # Layers with no pruned heads would get a no-op hook, so skip them.
        hooks = [
            (f"blocks.{layer}.attn.hook_result", make_hook(heads))
            for layer, heads in heads_by_layer.items()
        ]
        with torch.no_grad(), self.model.transformer.hooks(fwd_hooks=hooks):
            preds = self.model(**inputs)
        return self.get_metric(preds, target_action)

    def save_manifest(self, path: str):
        """Write the circuit from the last run() call to ``path`` as JSON.

        Heads in ``active_heads`` are serialized as ``"L{layer}H{head}"``
        strings; other keys are written unchanged.

        Raises:
            RuntimeError: if run() has not produced a circuit yet.
        """
        if "active_heads" not in self.current_circuit:
            raise RuntimeError("No circuit to save; call run() before save_manifest().")

        data = dict(self.current_circuit)
        data["active_heads"] = [f"L{l}H{h}" for l, h in data["active_heads"]]
        # Serialize before opening, so a failure cannot leave a truncated file.
        payload = json.dumps(data, indent=4)
        with open(path, "w", encoding="utf-8") as f:
            f.write(payload)