import os
import re
from typing import Dict, List

import torch

from src.interpretability.acdc import ACDCDiscovery

# File extensions treated as checkpoints when scanning a directory.
_CHECKPOINT_SUFFIXES = (".pt", ".pth")

# Splits a file name into alternating text / digit-run chunks.
_DIGIT_RUN = re.compile(r"([0-9]+)")


def _natural_sort_key(name: str) -> list:
    """Return a sort key that orders embedded numbers by value.

    ``re.split`` with a capturing group always yields text chunks at even
    positions and digit runs at odd positions, so two keys never compare an
    ``int`` against a ``str`` at the same index.
    """
    chunks = _DIGIT_RUN.split(name)
    return [int(chunk) if idx % 2 else chunk for idx, chunk in enumerate(chunks)]


class EvolutionaryScanner:
    """
    Analyzes how circuits evolve across different training checkpoints.

    The scanner rebuilds a model from ``model_class`` for every checkpoint
    file, loads the stored weights and runs ACDC circuit discovery on it.
    """

    def __init__(self, model_class, state_dim: int, action_dim: int):
        """
        Store what is needed to rebuild a model for each checkpoint.

        Args:
            model_class: Class exposing
                ``from_config(state_dim, action_dim, **kwargs)``.
            state_dim: Forwarded as the first argument of ``from_config``.
            action_dim: Forwarded as the second argument of ``from_config``.
        """
        self.model_class = model_class
        self.state_dim = state_dim
        self.action_dim = action_dim

    def scan_checkpoints(
        self,
        checkpoint_dir: str,
        inputs: Dict[str, torch.Tensor],
        target_action: int,
        threshold: float = 0.1,
        **model_kwargs
    ) -> List[Dict]:
        """
        Runs ACDC on all checkpoints in a directory and returns the results.

        Checkpoints are the entries of ``checkpoint_dir`` whose names end in
        ``.pt`` or ``.pth``. They are processed in natural order, so
        ``step_200.pt`` comes before ``step_1000.pt`` even without zero
        padding.

        Args:
            checkpoint_dir: Directory containing the checkpoint files.
            inputs: Passed unchanged to ``ACDCDiscovery.run``.
            target_action: Passed unchanged to ``ACDCDiscovery.run``.
            threshold: Passed to the ``ACDCDiscovery`` constructor.
            **model_kwargs: Extra keyword arguments for
                ``model_class.from_config``.

        Returns:
            One ACDC result per checkpoint, in processing order, each with an
            added ``"checkpoint"`` entry holding the checkpoint file name.
        """
        ckpt_files = sorted(
            (
                entry
                for entry in os.listdir(checkpoint_dir)
                if entry.endswith(_CHECKPOINT_SUFFIXES)
            ),
            # The raw name breaks ties such as "ckpt_01.pt" vs "ckpt_1.pt",
            # keeping the order independent of os.listdir()'s ordering.
            key=lambda entry: (_natural_sort_key(entry), entry),
        )

        results = []
        for ckpt in ckpt_files:
            ckpt_path = os.path.join(checkpoint_dir, ckpt)
            print(f"Analyzing checkpoint: {ckpt}")

            # Build a fresh model so no state leaks between checkpoints.
            model = self.model_class.from_config(
                self.state_dim, self.action_dim, **model_kwargs
            )

            # NOTE(review): torch.load unpickles arbitrary objects. Only scan
            # directories with trusted checkpoints, or pass weights_only=True
            # once every checkpoint is known to be a plain state dict.
            state_dict = torch.load(
                ckpt_path, map_location=model.transformer.cfg.device
            )
            model.load_state_dict(state_dict)
            model.eval()

            # Run ACDC and tag the result with the file it came from.
            acdc = ACDCDiscovery(model, threshold=threshold)
            circuit = acdc.run(inputs, target_action)
            circuit["checkpoint"] = ckpt
            results.append(circuit)

        return results

    def detect_phase_transition(self, scan_results: List[Dict]) -> int:
        """
        Returns the index of the first checkpoint that looks "post-transition".

        Simple heuristic: the first entry whose ``"final_perf"`` exceeds 0.5
        and whose ``"active_heads"`` is non-empty. Circuit stability across
        neighbouring checkpoints is not measured.

        Args:
            scan_results: Results as returned by ``scan_checkpoints``; each
                entry must provide ``"final_perf"`` and ``"active_heads"``.

        Returns:
            The position in ``scan_results`` of the first matching entry (an
            index, not a training step), or ``-1`` if none matches.
        """
        for i, res in enumerate(scan_results):
            if res["final_perf"] > 0.5 and len(res["active_heads"]) > 0:
                return i
        return -1