File size: 2,006 Bytes
11dbbc6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fa350cc
11dbbc6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fa350cc
11dbbc6
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import os
import re
from typing import List, Dict

import torch

from src.interpretability.acdc import ACDCDiscovery

class EvolutionaryScanner:
    """
    Analyzes how circuits evolve across different training checkpoints.

    For every checkpoint a fresh model is built with
    ``model_class.from_config(state_dim, action_dim, **model_kwargs)``, the
    checkpoint's state dict is loaded into it, and ACDC discovery is run.
    """

    # File-name suffixes that identify a checkpoint file.
    _CKPT_SUFFIXES = (".pt", ".pth")

    def __init__(self, model_class, state_dim: int, action_dim: int):
        """
        Args:
            model_class: Class exposing ``from_config(state_dim, action_dim, **kwargs)``.
            state_dim: Forwarded as the first argument of ``from_config``.
            action_dim: Forwarded as the second argument of ``from_config``.
        """
        self.model_class = model_class
        self.state_dim = state_dim
        self.action_dim = action_dim

    @staticmethod
    def _natural_key(name: str) -> list:
        """Return a sort key that compares runs of ASCII digits numerically."""
        # re.split with a capturing group alternates text / digits / text ...,
        # so odd positions are always digit runs; using position parity keeps
        # str and int from ever being compared against each other.
        parts = re.split(r"([0-9]+)", name)
        return [int(tok) if pos % 2 else tok for pos, tok in enumerate(parts)]

    @classmethod
    def _list_checkpoints(cls, checkpoint_dir: str) -> List[str]:
        """Return checkpoint file names in ``checkpoint_dir`` in natural order."""
        names = [f for f in os.listdir(checkpoint_dir) if f.endswith(cls._CKPT_SUFFIXES)]
        # A plain lexicographic sort would put "step_10.pt" before "step_2.pt";
        # the raw name is a tie-breaker so the order is deterministic even when
        # two names share a natural key (e.g. "step_01.pt" vs "step_1.pt").
        return sorted(names, key=lambda n: (cls._natural_key(n), n))

    def scan_checkpoints(
        self, 
        checkpoint_dir: str, 
        inputs: Dict[str, torch.Tensor],
        target_action: int,
        threshold: float = 0.1,
        **model_kwargs
    ) -> List[Dict]:
        """
        Runs ACDC on all checkpoints in a directory and returns the results.

        Args:
            checkpoint_dir: Directory searched (non-recursively) for ``.pt`` / ``.pth`` files.
            inputs: Passed unchanged to ``ACDCDiscovery.run``.
            target_action: Passed unchanged to ``ACDCDiscovery.run``.
            threshold: Passed to the ``ACDCDiscovery`` constructor.
            **model_kwargs: Extra keyword arguments for ``model_class.from_config``.

        Returns:
            One result per checkpoint, in natural file-name order (numeric
            parts compared as numbers). Each result is whatever
            ``ACDCDiscovery.run`` returned, with a ``"checkpoint"`` entry set
            to the checkpoint's file name.
        """
        results = []

        for ckpt in self._list_checkpoints(checkpoint_dir):
            ckpt_path = os.path.join(checkpoint_dir, ckpt)
            print(f"Analyzing checkpoint: {ckpt}")

            # Load model
            model = self.model_class.from_config(self.state_dim, self.action_dim, **model_kwargs)
            # NOTE(security): torch.load unpickles the file; only scan checkpoint
            # directories from a trusted source (or consider weights_only=True
            # if the supported torch version allows it).
            state_dict = torch.load(ckpt_path, map_location=model.transformer.cfg.device)
            model.load_state_dict(state_dict)
            model.eval()

            # Run ACDC
            acdc = ACDCDiscovery(model, threshold=threshold)
            circuit = acdc.run(inputs, target_action)
            # Tag the result so callers can map it back to its checkpoint file.
            circuit["checkpoint"] = ckpt

            results.append(circuit)

        return results

    def detect_phase_transition(self, scan_results: List[Dict]) -> int:
        """
        Identifies where a usable circuit first appears in a checkpoint scan.

        Args:
            scan_results: Results as returned by ``scan_checkpoints``; each
                entry must provide ``"final_perf"`` and ``"active_heads"``.

        Returns:
            The index into ``scan_results`` (not a training step number) of
            the first entry whose ``"final_perf"`` exceeds 0.5 and whose
            ``"active_heads"`` is non-empty, or ``-1`` if there is none.
        """
        # Simple heuristic: first checkpoint with performance > 0.5 and at
        # least one active head. Stability across checkpoints is not checked.
        for idx, res in enumerate(scan_results):
            if res["final_perf"] > 0.5 and len(res["active_heads"]) > 0:
                return idx
        return -1