File size: 4,626 Bytes
33a0021
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import pytest
import torch
import os
import json
from src.models.hooked_dt import HookedDT
from src.interpretability.circuit_surgeon import CircuitSurgeon
from src.interpretability.neuronpedia import NeuronpediaExporter

@pytest.fixture
def test_model():
    """Build a small HookedDT model with repeatable construction for the tests.

    The global torch RNG is seeded immediately before construction so that
    any random parameter initialisation performed by ``HookedDT.from_config``
    is the same on every run. Without the seed, assertions that compare clean
    and ablated outputs could depend on the particular weight draw.

    Returns:
        A ``HookedDT`` with 3-dim states, 7-dim actions, 2 layers, 2 heads,
        ``d_model=16`` and ``max_length=5``.
    """
    torch.manual_seed(0)
    return HookedDT.from_config(
        state_dim=3,
        action_dim=7,
        n_layers=2,
        n_heads=2,
        d_model=16,
        max_length=5,
    )

def test_circuit_surgeon_registration(test_model):
    """Check that node and edge ablation targets can be added, removed and cleared.

    Exercises the registration API of ``CircuitSurgeon``: node targets (an
    attention head and an MLP sublayer) and a single head-to-head edge are
    registered, removed individually, and finally wiped with
    ``clear_ablations``.
    """
    surgeon = CircuitSurgeon(test_model)

    # Node targets: one attention head and one MLP sublayer.
    head_node, mlp_node = "L0H0", "L1MLP"
    for node in (head_node, mlp_node):
        surgeon.add_node_ablation(node)
    assert all(node in surgeon.ablated_nodes for node in (head_node, mlp_node))

    # Removing one node must leave the other registered.
    surgeon.remove_node_ablation(head_node)
    assert head_node not in surgeon.ablated_nodes
    assert mlp_node in surgeon.ablated_nodes

    # Edge targets are stored as (sender, receiver) pairs.
    edge = ("L0H1", "L1H0")
    surgeon.add_edge_ablation(*edge)
    assert edge in surgeon.ablated_edges

    surgeon.remove_edge_ablation(*edge)
    assert edge not in surgeon.ablated_edges

    # Clearing wipes both registries.
    surgeon.clear_ablations()
    assert len(surgeon.ablated_nodes) == len(surgeon.ablated_edges) == 0

def test_node_parsing(test_model):
    """Check that node labels parse into (layer, head) pairs.

    Attention-head labels of the form ``L<layer>H<head>`` yield both indices.
    MLP labels of the form ``L<layer>MLP`` yield the layer index and ``None``
    for the head. Every label used here refers to a component that actually
    exists in the fixture model (layers 0-1, heads 0-1).
    """
    surgeon = CircuitSurgeon(test_model)

    layer, head = surgeon.parse_node("L0H1")
    assert layer == 0
    assert head == 1

    # A head in the second layer, so both indices differ from the case above.
    layer, head = surgeon.parse_node("L1H0")
    assert layer == 1
    assert head == 0

    # MLP components carry no head index.
    layer_mlp, head_mlp = surgeon.parse_node("L1MLP")
    assert layer_mlp == 1
    assert head_mlp is None

def test_surgeon_hooks_compilation(test_model):
    """Check which forward-hook entries ``get_ablation_hooks`` produces.

    With no targets registered the hook list must be empty. After registering
    one attention head and one MLP sublayer, exactly one hook per target is
    expected. Each hook entry's first element is its hook-point name.
    """
    surgeon = CircuitSurgeon(test_model)
    baseline_cache = {}

    # Nothing registered yet, so nothing to hook.
    assert len(surgeon.get_ablation_hooks(baseline_cache)) == 0

    for node in ("L0H1", "L1MLP"):
        surgeon.add_node_ablation(node)
    hooks = surgeon.get_ablation_hooks(baseline_cache)

    # One attention-result hook for L0H1 and one MLP-output hook for L1MLP.
    assert len(hooks) == 2
    hooked_points = {entry[0] for entry in hooks}
    for expected in ("blocks.0.attn.hook_result", "blocks.1.hook_mlp_out"):
        assert expected in hooked_points

def test_circuit_surgeon_forward_pass(test_model):
    """Run clean and ablated forward passes and check shapes and ablation effect.

    Inputs are drawn from a locally seeded ``torch.Generator``. The test is
    therefore deterministic regardless of global RNG state or test execution
    order, which matters for the "outputs differ after ablation" check.
    """
    surgeon = CircuitSurgeon(test_model)

    # Dimensions must agree with the ``test_model`` fixture (state_dim=3, action_dim=7).
    batch, seq_len, state_dim, action_dim = 1, 3, 3, 7
    expected_shape = (batch, seq_len, action_dim)

    rng = torch.Generator().manual_seed(0)
    states = torch.randn(batch, seq_len, state_dim, generator=rng)
    actions = torch.randn(batch, seq_len, action_dim, generator=rng)
    returns = torch.randn(batch, seq_len, 1, generator=rng)

    # 1. Baseline (no ablations).
    preds_clean = test_model(states, actions, returns)
    assert preds_clean.shape == expected_shape

    # 2. Ablate an entire MLP sublayer.
    surgeon.add_node_ablation("L0MLP")
    preds_ablated = surgeon.compute_ablated_forward(states, actions, returns)
    assert preds_ablated.shape == expected_shape
    # Removing a whole sublayer should measurably change the predictions.
    assert not torch.allclose(preds_clean, preds_ablated, atol=1e-4)

    # 3. Ablate a single head-to-head edge; only shape is checked because the
    #    effect of one path on the output may be arbitrarily small.
    surgeon.clear_ablations()
    surgeon.add_edge_ablation("L0H0", "L1H1")
    preds_edge_ablated = surgeon.compute_ablated_forward(states, actions, returns)
    assert preds_edge_ablated.shape == expected_shape

def test_neuronpedia_exporter_local(tmp_path):
    """Check that a circuit manifest is exported to a local JSON backup.

    The exporter is given an explicit ``local_path`` inside pytest's
    per-test temporary directory. The test checks the reported status and
    then re-reads the written file to verify its top-level schema
    (``model_id``, ``source``, ``circuit``) and a sample of the circuit fields.
    """
    exporter = NeuronpediaExporter()
    local_file = tmp_path / "test_export.json"

    manifest = {
        "active_heads": ["L0H0", "L1H1"],
        "pruned_count": 2,
        "initial_perf": 0.95,
        "final_perf": 0.88,
        "ablated_paths": ["L0H1 -> L1H0"],
        "ablated_nodes": ["L0H1", "L1MLP"],
        "state_dim": 3,
        "action_dim": 7,
        "n_layers": 2,
        "n_heads": 2,
    }

    # Pass a plain string path, as the original call site did.
    res = exporter.export_circuit("mini_dt", manifest, local_path=str(local_file))
    assert res["status"] == "success_local"
    assert local_file.is_file()

    # Read the file fully (explicit encoding) before asserting on its contents.
    data = json.loads(local_file.read_text(encoding="utf-8"))
    assert data["model_id"] == "mini_dt"
    assert data["source"] == "DT-Circuits"
    assert "circuit" in data
    assert data["circuit"]["pruned_count"] == 2
    assert "L0H0" in data["circuit"]["active_heads"]