Spaces:
Running
Running
File size: 2,006 Bytes
11dbbc6 fa350cc 11dbbc6 fa350cc 11dbbc6 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 | import torch
import os
from typing import List, Dict
from src.interpretability.acdc import ACDCDiscovery
class EvolutionaryScanner:
    """
    Analyzes how circuits evolve across different training checkpoints.

    The scanner builds a fresh model for every checkpoint file, loads the
    saved weights into it and runs ACDC circuit discovery on the same inputs,
    so the resulting circuits can be compared along the training timeline.
    """

    def __init__(self, model_class, state_dim: int, action_dim: int):
        """
        Store what is needed to rebuild the model for each checkpoint.

        Args:
            model_class: Class exposing ``from_config(state_dim, action_dim,
                **kwargs)``; the returned model must provide
                ``load_state_dict``, ``eval`` and ``transformer.cfg.device``.
            state_dim: Forwarded as the first argument of ``from_config``.
            action_dim: Forwarded as the second argument of ``from_config``.
        """
        self.model_class = model_class
        self.state_dim = state_dim
        self.action_dim = action_dim

    @staticmethod
    def _checkpoint_sort_key(filename: str):
        """Return a natural-order sort key (digit runs compare as integers)."""
        # Function-scope import keeps this block self-contained.
        import re

        # With a capturing group, re.split alternates text / digit-run
        # pieces and always starts with a (possibly empty) text piece, so odd
        # positions are exactly the digit runs. Keys therefore only ever
        # compare str with str and int with int.
        pieces = re.split(r"(\d+)", filename)
        natural = tuple(
            int(piece) if index % 2 else piece
            for index, piece in enumerate(pieces)
        )
        # The raw name breaks ties such as "step_01.pt" vs "step_1.pt" so the
        # order does not depend on what os.listdir happens to return first.
        return natural, filename

    def scan_checkpoints(
        self,
        checkpoint_dir: str,
        inputs: Dict[str, torch.Tensor],
        target_action: int,
        threshold: float = 0.1,
        **model_kwargs
    ) -> List[Dict]:
        """
        Runs ACDC on all checkpoints in a directory and returns the results.

        Checkpoints are processed in natural filename order, so ``step_2.pt``
        is analyzed before ``step_10.pt``. A plain lexicographic sort would
        interleave them and scramble the training timeline.

        Args:
            checkpoint_dir: Directory containing ``.pt`` / ``.pth`` files.
            inputs: Passed unchanged to ``ACDCDiscovery.run``.
            target_action: Passed unchanged to ``ACDCDiscovery.run``.
            threshold: Passed to the ``ACDCDiscovery`` constructor.
            **model_kwargs: Extra keyword arguments for
                ``model_class.from_config``.

        Returns:
            One dict per checkpoint, in processing order: whatever
            ``ACDCDiscovery.run`` returned, with an added ``"checkpoint"``
            key holding the file name. Empty if no checkpoint files exist.

        Raises:
            OSError: If ``checkpoint_dir`` cannot be listed.
        """
        results = []
        ckpt_files = sorted(
            (
                name
                for name in os.listdir(checkpoint_dir)
                if name.endswith((".pt", ".pth"))
            ),
            key=self._checkpoint_sort_key,
        )

        for ckpt in ckpt_files:
            ckpt_path = os.path.join(checkpoint_dir, ckpt)
            print(f"Analyzing checkpoint: {ckpt}")

            # Load model
            model = self.model_class.from_config(
                self.state_dim, self.action_dim, **model_kwargs
            )
            # SECURITY NOTE(review): torch.load unpickles the file, which can
            # execute arbitrary code. Only scan directories with trusted
            # checkpoints, or pass weights_only=True if the minimum supported
            # torch version allows it.
            state_dict = torch.load(
                ckpt_path, map_location=model.transformer.cfg.device
            )
            model.load_state_dict(state_dict)
            model.eval()

            # Run ACDC
            acdc = ACDCDiscovery(model, threshold=threshold)
            circuit = acdc.run(inputs, target_action)
            # Tag the result so callers can map it back to its checkpoint.
            circuit["checkpoint"] = ckpt
            results.append(circuit)

        return results

    def detect_phase_transition(self, scan_results: List[Dict]) -> int:
        """
        Locates the first checkpoint that looks like a learned circuit.

        Simple heuristic: the first entry whose ``"final_perf"`` is strictly
        greater than 0.5 and whose ``"active_heads"`` is non-empty. Circuit
        size stability across checkpoints is not measured.

        Args:
            scan_results: Results as returned by ``scan_checkpoints``; every
                entry must contain ``"final_perf"`` and ``"active_heads"``.

        Returns:
            The index into ``scan_results`` (not a training step number) of
            the first matching entry, or -1 if no entry matches.
        """
        for i, res in enumerate(scan_results):
            if res["final_perf"] > 0.5 and len(res["active_heads"]) > 0:
                return i
        return -1
|