File size: 3,048 Bytes
11dbbc6
 
 
 
 
 
 
 
8577352
11dbbc6
8577352
11dbbc6
 
8577352
11dbbc6
8577352
 
11dbbc6
 
8577352
 
11dbbc6
 
 
8577352
 
 
11dbbc6
8577352
11dbbc6
 
8577352
11dbbc6
8577352
 
 
11dbbc6
8577352
11dbbc6
 
 
 
8577352
 
 
11dbbc6
 
 
 
8577352
11dbbc6
8577352
 
11dbbc6
 
 
 
 
 
 
8577352
11dbbc6
8577352
 
11dbbc6
8577352
11dbbc6
 
8577352
11dbbc6
8577352
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import torch
import json
from typing import Dict, List, Callable, Optional, Tuple
from tqdm import tqdm

class ACDCDiscovery:
    """
    Head-level greedy variant of Automatic Circuit DisCovery (ACDC).

    Attention heads are zero-ablated one at a time; a head stays ablated
    whenever the target-action logit remains within ``threshold`` of the
    un-ablated baseline. The heads that survive form the discovered circuit.

    The wrapped ``model`` must provide everything this class touches:
    ``model.cfg.n_layers``, ``model.cfg.n_heads``, a
    ``model.transformer.hooks(fwd_hooks=...)`` context manager, and
    ``model(**inputs)`` returning a tensor of predictions.
    """

    def __init__(self, model, threshold: float = 0.1):
        self.model = model
        # Maximum tolerated absolute change of the metric versus the baseline.
        self.threshold = threshold
        # Filled by run(); stays empty until a circuit has been discovered.
        self.current_circuit = {}

    def get_metric(self, action_preds: torch.Tensor, target_action: int) -> float:
        """Return the logit of ``target_action`` at the last timestep.

        Only batch element 0 is read.
        NOTE(review): assumes ``action_preds`` is laid out as
        (batch, time, n_actions) -- confirm against the model's output.
        """
        return action_preds[0, -1, target_action].item()

    def run(self, inputs: dict, target_action: int) -> dict:
        """Greedily prune heads while the metric stays close to the baseline.

        Heads are visited in (layer, head) order. Each candidate is ablated
        together with all heads already pruned; it stays pruned when
        ``abs(perf - initial_perf) < threshold``.

        Args:
            inputs: Keyword arguments forwarded to the model as ``model(**inputs)``.
            target_action: Index of the action whose logit is the metric.

        Returns:
            The circuit description, also stored in ``self.current_circuit``,
            with keys ``active_heads`` (list of ``(layer, head)`` tuples),
            ``pruned_count``, ``initial_perf`` and ``final_perf``.
        """
        n_layers = self.model.cfg.n_layers
        n_heads = self.model.cfg.n_heads

        # Baseline performance; only a scalar is read, so skip autograd bookkeeping.
        with torch.no_grad():
            initial_preds = self.model(**inputs)
        initial_perf = self.get_metric(initial_preds, target_action)

        all_heads = [
            (layer, head) for layer in range(n_layers) for head in range(n_heads)
        ]
        pruned_heads = []

        pbar = tqdm(all_heads, desc="ACDC Pruning")
        for layer, head in pbar:
            # Try pruning this head on top of the heads already pruned.
            trial_pruned = pruned_heads + [(layer, head)]
            perf = self._eval_with_pruning(inputs, trial_pruned, target_action)

            # If the metric barely moved, the head is not needed: keep it pruned.
            if abs(perf - initial_perf) < self.threshold:
                pruned_heads.append((layer, head))
                pbar.set_postfix({"pruned": len(pruned_heads)})

        # Set membership keeps this filter linear in the number of heads.
        pruned_set = set(pruned_heads)
        active_heads = [h for h in all_heads if h not in pruned_set]
        self.current_circuit = {
            "active_heads": active_heads,
            "pruned_count": len(pruned_heads),
            "initial_perf": initial_perf,
            "final_perf": self._eval_with_pruning(inputs, pruned_heads, target_action),
        }
        return self.current_circuit

    def _eval_with_pruning(self, inputs: dict, pruned_heads: list, target_action: int) -> float:
        """Evaluate the metric with the given ``(layer, head)`` pairs zeroed out.

        NOTE(review): in TransformerLens, ``hook_result`` is only populated when
        the model computes per-head outputs (``use_attn_result``). If the wrapped
        model does not, these hooks never fire and every head looks prunable --
        confirm the model configuration.
        """
        # Group heads by layer once so each hook call only visits its own heads.
        heads_by_layer = {}
        for p_layer, p_head in pruned_heads:
            heads_by_layer.setdefault(p_layer, []).append(p_head)

        def pruning_hook(value, hook):
            # Hook names look like "blocks.<layer>.attn.hook_result".
            layer_idx = int(hook.name.split(".")[1])
            for p_head in heads_by_layer.get(layer_idx, ()):
                # Heads are indexed on dim 2 of the hooked tensor.
                value[:, :, p_head, :] = 0.0
            return value

        hooks = [
            (f"blocks.{layer}.attn.hook_result", pruning_hook)
            for layer in range(self.model.cfg.n_layers)
        ]

        # Inference only: no_grad avoids building a graph for every trial pass.
        with torch.no_grad(), self.model.transformer.hooks(fwd_hooks=hooks):
            preds = self.model(**inputs)

        return self.get_metric(preds, target_action)

    def save_manifest(self, path: str):
        """Save the discovered circuit to a JSON file.

        Active heads are written as ``"L<layer>H<head>"`` strings;
        ``self.current_circuit`` itself is not modified.

        Raises:
            KeyError: If no circuit is available yet (``run`` was not called).
                Nothing is written to ``path`` in that case.
        """
        if "active_heads" not in self.current_circuit:
            raise KeyError("active_heads: no circuit to save; call run() first")

        data = self.current_circuit.copy()
        data["active_heads"] = [f"L{l}H{h}" for l, h in data["active_heads"]]
        # Serialize before opening so a failure never creates or truncates the file.
        payload = json.dumps(data, indent=4)
        with open(path, 'w', encoding="utf-8") as f:
            f.write(payload)