File size: 2,250 Bytes
11dbbc6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8577352
 
 
11dbbc6
 
 
8577352
11dbbc6
8577352
 
11dbbc6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8577352
11dbbc6
 
8577352
 
11dbbc6
 
 
 
8577352
11dbbc6
 
8577352
11dbbc6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import pytest
import torch
from src.models.hooked_dt import HookedDT
from src.interpretability.acdc import ACDCDiscovery
from src.interpretability.path_patching import PathPatchingEngine
from src.interpretability.evolution import EvolutionaryScanner
import os
import json

@pytest.fixture
def model():
    """Provide a small ``HookedDT`` built through ``HookedDT.from_config``.

    The configuration is kept tiny (2 layers, 2 heads, ``d_model=32``) and
    uses ``state_dim=10`` / ``action_dim=3``.
    """
    tiny_config = dict(
        state_dim=10,
        action_dim=3,
        n_layers=2,
        n_heads=2,
        d_model=32,
    )
    return HookedDT.from_config(**tiny_config)

@pytest.fixture
def sample_inputs():
    """Provide a dict of input tensors keyed for the model's forward call.

    Contains random ``states`` of shape (1, 5, 10), all-zero ``actions`` of
    shape (1, 5, 3) and all-one ``returns_to_go`` of shape (1, 5, 1).
    """
    # Presumably (batch, timesteps, feature) -- confirm against HookedDT.
    lead_dims = (1, 5)
    states = torch.randn(*lead_dims, 10)
    actions = torch.zeros(*lead_dims, 3)
    returns_to_go = torch.ones(*lead_dims, 1)
    return {
        "states": states,
        "actions": actions,
        "returns_to_go": returns_to_go,
    }

def test_acdc_discovery(model, sample_inputs):
    """Verifies that ACDC can prune heads and save a manifest.

    Runs ``ACDCDiscovery`` (threshold 0.5) on the sample inputs with
    ``target_action=1``, checks that the returned circuit exposes the
    ``active_heads`` and ``initial_perf`` keys, then saves the manifest and
    checks that the resulting JSON file contains ``active_heads``.

    The manifest is written inside a temporary directory so that a failing
    assertion or an exception cannot leave a stray ``circuit_manifest.json``
    behind in the working directory, and concurrent runs cannot collide on it.
    """
    import tempfile  # local import: only this test needs it

    model.eval()
    acdc = ACDCDiscovery(model, threshold=0.5)
    circuit = acdc.run(sample_inputs, target_action=1)

    assert "active_heads" in circuit
    assert "initial_perf" in circuit

    # The directory (and the manifest in it) is removed when the ``with``
    # block exits, whether or not the assertions inside it pass.
    with tempfile.TemporaryDirectory() as scratch_dir:
        manifest_path = os.path.join(scratch_dir, "circuit_manifest.json")
        acdc.save_manifest(manifest_path)
        assert os.path.exists(manifest_path)

        with open(manifest_path, 'r') as f:
            data = json.load(f)

    assert "active_heads" in data

def test_path_patching_ablation(model, sample_inputs):
    """Verifies that ablating a head changes the model output.

    Compares the plain forward pass against the output of
    ``PathPatchingEngine.perform_edge_ablation`` with a ``"zero"`` ablation of
    layer 0 / head 0 and requires a non-zero maximum absolute difference.
    """
    # Put the model in eval mode, as test_acdc_discovery does. If the model
    # contains train-only stochastic layers (e.g. dropout), two forward passes
    # would differ on their own and the assertion below would pass vacuously.
    model.eval()
    engine = PathPatchingEngine(model)

    orig_output = model(**sample_inputs)
    ablated_output = engine.perform_edge_ablation(
        sample_inputs, layer=0, head_index=0, ablation_type="zero"
    )

    diff = (orig_output - ablated_output).abs().max().item()
    assert diff > 0, "zero-ablating layer 0 / head 0 left the output unchanged"

def test_evolutionary_scanner_mock(model, sample_inputs, tmp_path):
    """Verifies scanning multiple checkpoints for circuit formation.

    Saves the fixture model's state dict twice (``step_100.pt`` and
    ``step_200.pt``) into a ``checkpoints`` directory under ``tmp_path``, runs
    ``EvolutionaryScanner.scan_checkpoints`` over that directory, and expects
    two results, the first of which contains ``active_heads``.
    """
    ckpt_root = tmp_path / "checkpoints"
    ckpt_root.mkdir()

    # Both checkpoints hold the same weights; only the file names differ.
    for ckpt_name in ("step_100.pt", "step_200.pt"):
        torch.save(model.state_dict(), ckpt_root / ckpt_name)

    scan_options = {"target_action": 1, "d_model": 32, "n_heads": 2}
    scanner = EvolutionaryScanner(HookedDT, state_dim=10, action_dim=3)
    results = scanner.scan_checkpoints(str(ckpt_root), sample_inputs, **scan_options)

    assert len(results) == 2
    assert "active_heads" in results[0]