Spaces:
Running
Running
| import pytest | |
| import torch | |
| from src.models.hooked_dt import HookedDT | |
| from src.interpretability.acdc import ACDCDiscovery | |
| from src.interpretability.path_patching import PathPatchingEngine | |
| from src.interpretability.evolution import EvolutionaryScanner | |
| import os | |
| import json | |
@pytest.fixture
def model():
    """Build the small HookedDT instance shared by the tests in this file.

    Registered as a pytest fixture because the test functions below request
    ``model`` by parameter name; without the decorator pytest cannot resolve
    that argument and every test errors out during setup.
    """
    return HookedDT.from_config(
        state_dim=10,
        action_dim=3,
        n_layers=2,
        n_heads=2,
        d_model=32,
    )
@pytest.fixture
def sample_inputs():
    """Return the keyword arguments fed to the model in each test.

    The dict holds three tensors keyed ``"states"``, ``"actions"`` and
    ``"returns_to_go"`` with shapes (1, 5, 10), (1, 5, 3) and (1, 5, 1);
    states are random, actions are all zeros and returns-to-go are all ones.
    The axes are presumably (batch, timestep, feature) -- confirm against
    HookedDT's forward signature.

    Registered as a pytest fixture because the tests request ``sample_inputs``
    by parameter name.
    """
    return {
        "states": torch.randn(1, 5, 10),
        "actions": torch.zeros(1, 5, 3),
        "returns_to_go": torch.ones(1, 5, 1),
    }
def test_acdc_discovery(model, sample_inputs, tmp_path):
    """Verifies that ACDC can prune heads and save a manifest.

    Runs ACDCDiscovery on the sample inputs, checks the returned circuit
    exposes ``"active_heads"`` and ``"initial_perf"``, then saves the manifest
    and checks the JSON on disk also contains ``"active_heads"``.
    """
    model.eval()
    acdc = ACDCDiscovery(model, threshold=0.5)
    circuit = acdc.run(sample_inputs, target_action=1)
    assert "active_heads" in circuit
    assert "initial_perf" in circuit

    # Write into pytest's per-test temporary directory instead of the current
    # working directory: a failing assertion then cannot leave a stray
    # manifest behind, and concurrent test runs cannot clobber each other.
    manifest_path = tmp_path / "circuit_manifest.json"
    # Passed as str, matching how paths are handed to the project APIs
    # elsewhere in this file.
    acdc.save_manifest(str(manifest_path))
    assert manifest_path.exists()

    with open(manifest_path, "r") as f:
        data = json.load(f)
    assert "active_heads" in data
def test_path_patching_ablation(model, sample_inputs):
    """Verifies that ablating a head changes the model output.

    Compares the plain forward pass with the output produced after
    zero-ablating head 0 of layer 0 and requires a non-zero maximum absolute
    difference between the two.
    """
    # Put the model in inference mode, as test_acdc_discovery does. If the
    # model contains train-only stochastic layers (e.g. dropout -- not visible
    # from this file, TODO confirm), two forward passes would differ on their
    # own and the assertion below would pass without the ablation doing
    # anything.
    model.eval()

    engine = PathPatchingEngine(model)
    orig_output = model(**sample_inputs)
    ablated_output = engine.perform_edge_ablation(
        sample_inputs, layer=0, head_index=0, ablation_type="zero"
    )

    diff = (orig_output - ablated_output).abs().max().item()
    assert diff > 0
def test_evolutionary_scanner_mock(model, sample_inputs, tmp_path):
    """Verifies scanning multiple checkpoints for circuit formation.

    Two copies of the fixture model's state dict are saved as ``step_100.pt``
    and ``step_200.pt``; the scanner must return one result per checkpoint,
    and the first result must contain ``"active_heads"``.
    """
    ckpt_root = tmp_path / "checkpoints"
    ckpt_root.mkdir()

    # Both fake "training steps" are snapshots of the same weights.
    for step in (100, 200):
        torch.save(model.state_dict(), ckpt_root / f"step_{step}.pt")

    scan_kwargs = {"target_action": 1, "d_model": 32, "n_heads": 2}
    scanner = EvolutionaryScanner(HookedDT, state_dim=10, action_dim=3)
    results = scanner.scan_checkpoints(str(ckpt_root), sample_inputs, **scan_kwargs)

    assert len(results) == 2
    assert "active_heads" in results[0]