Spaces:
Running
Running
File size: 2,250 Bytes
11dbbc6 8577352 11dbbc6 8577352 11dbbc6 8577352 11dbbc6 8577352 11dbbc6 8577352 11dbbc6 8577352 11dbbc6 8577352 11dbbc6 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 | import pytest
import torch
from src.models.hooked_dt import HookedDT
from src.interpretability.acdc import ACDCDiscovery
from src.interpretability.path_patching import PathPatchingEngine
from src.interpretability.evolution import EvolutionaryScanner
import os
import json
@pytest.fixture
def model():
    """Build a small HookedDT with reproducible initial weights.

    The configuration (state_dim=10, action_dim=3, n_layers=2, n_heads=2,
    d_model=32) must stay in sync with the tensor sizes in ``sample_inputs``
    and with the arguments passed to ``EvolutionaryScanner`` /
    ``scan_checkpoints`` in ``test_evolutionary_scanner_mock``.
    """
    # Seed inside a forked RNG scope so weight initialisation (assuming
    # from_config draws from torch's RNG — TODO confirm) is identical on every
    # run, without leaking a fixed seed into the global RNG used by other tests.
    with torch.random.fork_rng():
        torch.manual_seed(0)
        return HookedDT.from_config(
            state_dim=10, action_dim=3, n_layers=2, n_heads=2, d_model=32
        )
@pytest.fixture
def sample_inputs():
    """Return a single batch of keyword inputs for the ``model`` fixture.

    Tensor shapes: states (1, 5, 10), actions (1, 5, 3),
    returns_to_go (1, 5, 1). The last dimensions match the model's
    state_dim=10 / action_dim=3; the layout is presumably
    (batch, timesteps, features) — confirm against HookedDT.forward.
    """
    # A private, seeded generator makes the random states reproducible across
    # runs while leaving the global RNG state untouched.
    gen = torch.Generator().manual_seed(0)
    return {
        "states": torch.randn(1, 5, 10, generator=gen),
        "actions": torch.zeros(1, 5, 3),
        "returns_to_go": torch.ones(1, 5, 1),
    }
def test_acdc_discovery(model, sample_inputs, tmp_path):
    """Verifies that ACDC can prune heads and save a manifest."""
    model.eval()
    acdc = ACDCDiscovery(model, threshold=0.5)
    circuit = acdc.run(sample_inputs, target_action=1)
    assert "active_heads" in circuit
    assert "initial_perf" in circuit

    # Write into pytest's per-test temp dir rather than the CWD: the manifest
    # never pollutes the working tree, cannot collide between parallel runs,
    # and is cleaned up even when an assertion below fails.
    manifest_path = tmp_path / "circuit_manifest.json"
    acdc.save_manifest(str(manifest_path))
    assert manifest_path.exists()
    data = json.loads(manifest_path.read_text(encoding="utf-8"))
    assert "active_heads" in data
def test_path_patching_ablation(model, sample_inputs):
    """Verifies that zero-ablating head 0 of layer 0 changes the model output."""
    # Eval mode (matching test_acdc_discovery) disables any train-time
    # randomness such as dropout, if the model has it. A nonzero diff must come
    # from the ablation, not from two stochastic forward passes disagreeing.
    model.eval()
    engine = PathPatchingEngine(model)
    orig_output = model(**sample_inputs)
    ablated_output = engine.perform_edge_ablation(
        sample_inputs, layer=0, head_index=0, ablation_type="zero"
    )
    # Broadcasting between mismatched shapes could yield a spurious nonzero
    # diff, so require the two outputs to line up exactly first.
    assert orig_output.shape == ablated_output.shape
    diff = (orig_output - ablated_output).abs().max().item()
    assert diff > 0, "zero-ablating layer 0 / head 0 left the output unchanged"
def test_evolutionary_scanner_mock(model, sample_inputs, tmp_path):
    """Scan a directory of two checkpoints and inspect the per-checkpoint results.

    Both checkpoint files hold the same weights. The test checks that the scan
    returns two results and that the first reports ``active_heads``.
    """
    ckpt_root = tmp_path / "checkpoints"
    ckpt_root.mkdir()
    for step in (100, 200):
        torch.save(model.state_dict(), ckpt_root / f"step_{step}.pt")

    scanner = EvolutionaryScanner(HookedDT, state_dim=10, action_dim=3)
    # d_model / n_heads mirror the configuration of the `model` fixture so the
    # saved state dicts can be loaded back by the scanner.
    results = scanner.scan_checkpoints(
        str(ckpt_root),
        sample_inputs,
        target_action=1,
        d_model=32,
        n_heads=2,
    )

    assert len(results) == 2
    first = results[0]
    assert "active_heads" in first
|