Spaces:
Running
Running
File size: 3,048 Bytes
11dbbc6 8577352 11dbbc6 8577352 11dbbc6 8577352 11dbbc6 8577352 11dbbc6 8577352 11dbbc6 8577352 11dbbc6 8577352 11dbbc6 8577352 11dbbc6 8577352 11dbbc6 8577352 11dbbc6 8577352 11dbbc6 8577352 11dbbc6 8577352 11dbbc6 8577352 11dbbc6 8577352 11dbbc6 8577352 11dbbc6 8577352 11dbbc6 8577352 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 | import torch
import json
from typing import Dict, List, Callable, Optional, Tuple
from tqdm import tqdm
class ACDCDiscovery:
    """
    Automatic Circuit DisCovery (ACDC), simplified to whole attention heads.

    Greedily zero-ablates attention heads one at a time. A head stays pruned
    whenever the target-action prediction remains within ``threshold`` of the
    unablated baseline. The heads that survive form the discovered circuit.

    NOTE(review): ``run`` reads ``self.model.cfg`` while ``_eval_with_pruning``
    registers hooks through ``self.model.transformer.hooks``; this assumes the
    wrapped model exposes both — confirm against the model class.
    """

    def __init__(self, model, threshold: float = 0.1):
        """
        Args:
            model: Model invoked as ``model(**inputs)``. Must expose
                ``cfg.n_layers``, ``cfg.n_heads`` and a ``transformer.hooks``
                context manager accepting ``fwd_hooks``.
            threshold: Largest absolute change in the target-action prediction
                tolerated when pruning a head.
        """
        self.model = model
        self.threshold = threshold
        # Result of the most recent run(); empty until run() has been called.
        self.current_circuit: Dict = {}

    def get_metric(self, action_preds: torch.Tensor, target_action: int) -> float:
        """
        Return the prediction for ``target_action`` at the last timestep.

        Reads ``action_preds[0, -1, target_action]``, so only the first batch
        element is used. Presumably ``action_preds`` holds action logits laid
        out as (batch, timestep, action) — confirm against the model output.
        """
        return action_preds[0, -1, target_action].item()

    def run(self, inputs: dict, target_action: int) -> dict:
        """
        Greedily prune heads while the metric stays close to the baseline.

        Heads are visited in (layer, head) order. Each candidate is ablated
        together with every head already pruned. If the metric moves by less
        than ``self.threshold`` from the unablated baseline, the head stays
        pruned.

        Args:
            inputs: Keyword arguments forwarded to ``self.model``.
            target_action: Index of the action whose prediction is tracked.

        Returns:
            Dict with ``active_heads`` (``(layer, head)`` tuples that were not
            pruned), ``pruned_count``, ``initial_perf`` and ``final_perf``.
            The same dict is stored on ``self.current_circuit``.
        """
        n_layers = self.model.cfg.n_layers
        n_heads = self.model.cfg.n_heads

        # Baseline with no heads ablated. Gradients are never needed here, so
        # skip building an autograd graph for the forward pass.
        with torch.no_grad():
            initial_preds = self.model(**inputs)
        initial_perf = self.get_metric(initial_preds, target_action)

        all_heads = [(layer, head) for layer in range(n_layers) for head in range(n_heads)]
        pruned_heads: List[Tuple[int, int]] = []

        pbar = tqdm(all_heads, desc="ACDC Pruning")
        for layer, head in pbar:
            # Ablate this head on top of everything already pruned.
            trial_pruned = pruned_heads + [(layer, head)]
            perf = self._eval_with_pruning(inputs, trial_pruned, target_action)
            # abs(): a large shift in either direction means the head matters.
            if abs(perf - initial_perf) < self.threshold:
                pruned_heads.append((layer, head))
                pbar.set_postfix({"pruned": len(pruned_heads)})

        # A set keeps the active-head filter linear in the number of heads.
        pruned_set = set(pruned_heads)
        active_heads = [h for h in all_heads if h not in pruned_set]

        self.current_circuit = {
            "active_heads": active_heads,
            "pruned_count": len(pruned_heads),
            "initial_perf": initial_perf,
            "final_perf": self._eval_with_pruning(inputs, pruned_heads, target_action),
        }
        return self.current_circuit

    def _eval_with_pruning(self, inputs: dict, pruned_heads: list, target_action: int) -> float:
        """
        Run the model with ``pruned_heads`` zero-ablated and return the metric.

        Zeroing is applied to each affected layer's
        ``blocks.{layer}.attn.hook_result`` activation, whose dimension 2 is
        indexed by head. Entries whose layer is outside
        ``range(cfg.n_layers)`` are ignored.

        NOTE(review): in TransformerLens, ``hook_result`` is only computed when
        ``cfg.use_attn_result`` is True. Otherwise these hooks never fire and
        nothing is ablated — confirm it is enabled on the wrapped transformer.
        """
        # Group pruned heads by layer once, instead of rescanning the whole
        # list inside every hook invocation.
        heads_by_layer: Dict[int, List[int]] = {}
        for p_layer, p_head in pruned_heads:
            heads_by_layer.setdefault(p_layer, []).append(p_head)

        def pruning_hook(value, hook):
            layer_idx = int(hook.name.split(".")[1])
            # In-place zeroing is fine: the forward runs under no_grad.
            value[:, :, heads_by_layer[layer_idx], :] = 0.0
            return value

        # Only hook layers that actually contain a pruned head.
        hooks = [
            (f"blocks.{layer}.attn.hook_result", pruning_hook)
            for layer in range(self.model.cfg.n_layers)
            if layer in heads_by_layer
        ]
        with torch.no_grad(), self.model.transformer.hooks(fwd_hooks=hooks):
            preds = self.model(**inputs)
        return self.get_metric(preds, target_action)

    def save_manifest(self, path: str):
        """
        Write the circuit from the last ``run()`` to ``path`` as JSON.

        Active heads are serialised as ``"L{layer}H{head}"`` strings. The other
        fields are written as stored, and ``self.current_circuit`` is left
        unmodified.

        Raises:
            RuntimeError: If no circuit has been computed yet. This is checked
                before opening ``path``, so an existing file is not truncated.
        """
        if "active_heads" not in self.current_circuit:
            raise RuntimeError("No circuit to save: call run() before save_manifest().")
        data = dict(self.current_circuit)
        data["active_heads"] = [f"L{layer}H{head}" for layer, head in data["active_heads"]]
        with open(path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=4)
|