# DT-Explorer/tests/test_interactive_surgery.py
# sadhumitha-s's picture
# feat: implement interactive circuit surgery engine, dashboard integration, and Neuronpedia export functionality
# 33a0021
import pytest
import torch
import os
import json
from src.models.hooked_dt import HookedDT
from src.interpretability.circuit_surgeon import CircuitSurgeon
from src.interpretability.neuronpedia import NeuronpediaExporter
@pytest.fixture
def test_model():
    """Build a small HookedDT model with a fixed RNG seed for testing.

    The global PyTorch RNG is seeded before construction so that any
    parameter initialization drawing from it produces the same values on
    every run (assumes ``HookedDT.from_config`` initializes weights through
    PyTorch's default generator -- TODO confirm).

    Returns:
        Whatever ``HookedDT.from_config`` returns for a 2-layer, 2-head
        configuration with ``state_dim=3``, ``action_dim=7``,
        ``d_model=16`` and ``max_length=5``.
    """
    # Without a fixed seed the "reproducible" promise above would not hold:
    # every test would receive differently initialized weights.
    torch.manual_seed(0)
    return HookedDT.from_config(
        state_dim=3,
        action_dim=7,
        n_layers=2,
        n_heads=2,
        d_model=16,
        max_length=5,
    )
def test_circuit_surgeon_registration(test_model):
    """Check that node and edge ablation targets can be added, removed and cleared.

    Exercises ``add_node_ablation`` / ``remove_node_ablation``,
    ``add_edge_ablation`` / ``remove_edge_ablation`` and ``clear_ablations``
    and inspects ``ablated_nodes`` / ``ablated_edges`` after each step.
    """
    surgeon = CircuitSurgeon(test_model)

    # --- Node targets: register two, then drop only the first -------------
    head_node, mlp_node = "L0H0", "L1MLP"
    for node in (head_node, mlp_node):
        surgeon.add_node_ablation(node)
    for node in (head_node, mlp_node):
        assert node in surgeon.ablated_nodes

    surgeon.remove_node_ablation(head_node)
    assert head_node not in surgeon.ablated_nodes
    # Removing one node must leave the other registration untouched.
    assert mlp_node in surgeon.ablated_nodes

    # --- Edge targets: stored as (source, destination) tuples -------------
    edge = ("L0H1", "L1H0")
    surgeon.add_edge_ablation(*edge)
    assert edge in surgeon.ablated_edges

    surgeon.remove_edge_ablation(*edge)
    assert edge not in surgeon.ablated_edges

    # --- Clearing wipes both collections -----------------------------------
    surgeon.clear_ablations()
    assert len(surgeon.ablated_nodes) == 0
    assert len(surgeon.ablated_edges) == 0
def test_node_parsing(test_model):
    """Check that ``parse_node`` splits a node label into (layer, head).

    An attention-head label such as ``"L0H1"`` is expected to give integer
    layer and head indices; an MLP label such as ``"L2MLP"`` is expected to
    give the layer index and ``None`` for the head.
    """
    surgeon = CircuitSurgeon(test_model)

    # Attention head label -> both indices populated.
    attn_layer, attn_head = surgeon.parse_node("L0H1")
    assert attn_layer == 0
    assert attn_head == 1

    # MLP label -> layer index only; there is no head component.
    mlp_layer, mlp_head = surgeon.parse_node("L2MLP")
    assert mlp_layer == 2
    assert mlp_head is None
def test_surgeon_hooks_compilation(test_model):
    """Check which hooks ``get_ablation_hooks`` emits for the registered nodes.

    With nothing registered the hook list must be empty. After registering
    one attention head and one MLP, exactly two hooks are expected, each
    identified by the hook-point name in its first element.
    """
    surgeon = CircuitSurgeon(test_model)
    baseline_cache = {}

    # No ablations registered -> nothing to hook.
    initial_hooks = surgeon.get_ablation_hooks(baseline_cache)
    assert len(initial_hooks) == 0

    # One attention-head node plus one MLP node.
    for node in ("L0H1", "L1MLP"):
        surgeon.add_node_ablation(node)
    compiled_hooks = surgeon.get_ablation_hooks(baseline_cache)

    # One result hook for L0H1 and one mlp_out hook for L1MLP.
    assert len(compiled_hooks) == 2
    registered_names = [entry[0] for entry in compiled_hooks]
    for expected_name in ("blocks.0.attn.hook_result", "blocks.1.hook_mlp_out"):
        assert expected_name in registered_names
def test_circuit_surgeon_forward_pass(test_model):
    """Run clean and ablated forward passes end to end and check the outputs.

    Steps:
        1. Baseline pass through the unmodified model.
        2. Pass with the layer-0 MLP ablated; the predictions must keep the
           baseline shape but differ in value.
        3. Pass with a single head-to-head edge ablated; only the output
           shape is checked.

    Inputs are drawn from a locally seeded generator so every run feeds the
    model the same tensors, keeping the "outputs differ" assertion
    repeatable with respect to the inputs.
    """
    surgeon = CircuitSurgeon(test_model)

    # A dedicated generator makes the inputs deterministic without
    # touching the global RNG state shared with other tests.
    rng = torch.Generator().manual_seed(0)
    batch_size, seq_len = 1, 3
    # Feature sizes match the fixture's state_dim=3 / action_dim=7.
    states = torch.randn(batch_size, seq_len, 3, generator=rng)
    actions = torch.randn(batch_size, seq_len, 7, generator=rng)
    returns = torch.randn(batch_size, seq_len, 1, generator=rng)
    expected_shape = (batch_size, seq_len, 7)

    # 1. Run baseline
    preds_clean = test_model(states, actions, returns)
    assert preds_clean.shape == expected_shape

    # 2. Add MLP layer ablation and run
    surgeon.add_node_ablation("L0MLP")
    preds_ablated = surgeon.compute_ablated_forward(states, actions, returns)
    assert preds_ablated.shape == expected_shape
    # Ablating a whole MLP should move the outputs away from the baseline.
    assert not torch.allclose(preds_clean, preds_ablated, atol=1e-4)

    # 3. Add path/edge ablation and run
    surgeon.clear_ablations()
    surgeon.add_edge_ablation("L0H0", "L1H1")
    preds_edge_ablated = surgeon.compute_ablated_forward(states, actions, returns)
    assert preds_edge_ablated.shape == expected_shape
def test_neuronpedia_exporter_local(tmp_path):
    """Check that a circuit manifest is exported to a local JSON file.

    Calls ``NeuronpediaExporter.export_circuit`` with a ``local_path``
    inside pytest's temporary directory, then reloads the written file and
    checks the top-level identifiers and the nested ``circuit`` payload.
    """
    exporter = NeuronpediaExporter()
    # The exporter receives a plain string path, as in os.path.join output.
    export_path = str(tmp_path / "test_export.json")

    circuit_manifest = {
        "active_heads": ["L0H0", "L1H1"],
        "pruned_count": 2,
        "initial_perf": 0.95,
        "final_perf": 0.88,
        "ablated_paths": ["L0H1 -> L1H0"],
        "ablated_nodes": ["L0H1", "L1MLP"],
        "state_dim": 3,
        "action_dim": 7,
        "n_layers": 2,
        "n_heads": 2,
    }

    result = exporter.export_circuit(
        "mini_dt", circuit_manifest, local_path=export_path
    )
    assert result["status"] == "success_local"
    assert os.path.exists(export_path)

    # Reload the file and verify the serialized schema.
    with open(export_path) as handle:
        payload = json.load(handle)

    assert payload["model_id"] == "mini_dt"
    assert payload["source"] == "DT-Circuits"
    assert "circuit" in payload

    circuit_section = payload["circuit"]
    assert circuit_section["pruned_count"] == 2
    assert "L0H0" in circuit_section["active_heads"]