Spaces:
Running
Running
Commit ·
11dbbc6
1
Parent(s): 731ae64
feat: implement path-causal microscopy
Browse files- README.md +25 -0
- docs/activation_patching.md +44 -0
- docs/circuit_discovery.md +42 -0
- docs/sae_steering.md +37 -0
- src/interpretability/acdc.py +106 -0
- src/interpretability/attribution.py +3 -3
- src/interpretability/evolution.py +55 -0
- src/interpretability/patching.py +2 -6
- src/interpretability/path_patching.py +71 -0
- src/interpretability/sae_manager.py +6 -16
- src/interpretability/steering.py +4 -12
- src/models/hooked_dt.py +6 -5
- tests/test_path_causal_microscope.py +86 -0
README.md
CHANGED
|
@@ -4,6 +4,22 @@ DT-Circuits is a framework for mechanistic interpretability of Decision Transfor
|
|
| 4 |
|
| 5 |
The goal is to understand how Reward-to-Go, State, and Action tokens are processed within the residual stream, moving beyond basic behavioral observation.
|
| 6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
## Core Capabilities
|
| 8 |
|
| 9 |
### 1. Circuit Foundation
|
|
@@ -19,6 +35,11 @@ The goal is to understand how Reward-to-Go, State, and Action tokens are process
|
|
| 19 |
- **SAE Integration**: Tools to train and deploy SAEs on the residual stream to find monosemantic latents.
|
| 20 |
- **Anomaly Detection**: Uses SAE reconstruction error to detect out-of-distribution (OOD) states.
|
| 21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
## Technical Architecture
|
| 23 |
|
| 24 |
The platform consists of:
|
|
@@ -72,9 +93,12 @@ DT-Circuits/
|
|
| 72 |
│ ├── data/
|
| 73 |
│ │ └── harvester.py # PPO-based expert trajectory harvester
|
| 74 |
│ ├── interpretability/
|
|
|
|
| 75 |
│ │ ├── attribution.py # Direct Logit Attribution (DLA)
|
|
|
|
| 76 |
│ │ ├── induction_scan.py # Induction head detection logic
|
| 77 |
│ │ ├── patching.py # Causal activation patching tools
|
|
|
|
| 78 |
│ │ ├── sae_manager.py # SAE deployment and anomaly detection
|
| 79 |
│ │ └── steering.py # Steering vector generation and injection
|
| 80 |
│ ├── models/
|
|
@@ -82,6 +106,7 @@ DT-Circuits/
|
|
| 82 |
│ └── utils/
|
| 83 |
├── tests/ # Unit and integration test suite
|
| 84 |
│ ├── test_components.py
|
|
|
|
| 85 |
│ └── test_sae_and_steering.py
|
| 86 |
├── config.yaml # Experiment and environment configuration
|
| 87 |
└── requirements.txt # Environment dependencies
|
|
|
|
| 4 |
|
| 5 |
The goal is to understand how Reward-to-Go, State, and Action tokens are processed within the residual stream, moving beyond basic behavioral observation.
|
| 6 |
|
| 7 |
+
## Table of Contents
|
| 8 |
+
- [Core Capabilities](#core-capabilities)
|
| 9 |
+
- [Technical Architecture](#technical-architecture)
|
| 10 |
+
- [Getting Started](#getting-started)
|
| 11 |
+
- [Project Documentation](#project-documentation)
|
| 12 |
+
- [Testing](#testing)
|
| 13 |
+
- [Project Structure](#project-structure)
|
| 14 |
+
|
| 15 |
+
## Project Documentation
|
| 16 |
+
Detailed explanations of the mechanistic interpretability techniques used in this project:
|
| 17 |
+
- [Circuit Discovery](./docs/circuit_discovery.md)
|
| 18 |
+
- [Activation Patching](./docs/activation_patching.md)
|
| 19 |
+
- [SAEs & Steering](./docs/sae_steering.md)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
|
| 23 |
## Core Capabilities
|
| 24 |
|
| 25 |
### 1. Circuit Foundation
|
|
|
|
| 35 |
- **SAE Integration**: Tools to train and deploy SAEs on the residual stream to find monosemantic latents.
|
| 36 |
- **Anomaly Detection**: Uses SAE reconstruction error to detect out-of-distribution (OOD) states.
|
| 37 |
|
| 38 |
+
### 4. Path-Causal Microscope
|
| 39 |
+
- **ACDC (Automated Circuit Discovery)**: Prunes the DT into a minimal sufficient subgraph for specific behaviors.
|
| 40 |
+
- **Path Patching**: High-fidelity causal tracing between specific internal nodes (e.g., Goal Token → Induction Head → Action Logit).
|
| 41 |
+
- **Evolutionary Scan**: Analyzes how decision-making circuits form and stabilize across training checkpoints.
|
| 42 |
+
|
| 43 |
## Technical Architecture
|
| 44 |
|
| 45 |
The platform consists of:
|
|
|
|
| 93 |
│ ├── data/
|
| 94 |
│ │ └── harvester.py # PPO-based expert trajectory harvester
|
| 95 |
│ ├── interpretability/
|
| 96 |
+
│ │ ├── acdc.py # Automated Circuit Discovery logic
|
| 97 |
│ │ ├── attribution.py # Direct Logit Attribution (DLA)
|
| 98 |
+
│ │ ├── evolution.py # Developmental/Evolutionary MI scan
|
| 99 |
│ │ ├── induction_scan.py # Induction head detection logic
|
| 100 |
│ │ ├── patching.py # Causal activation patching tools
|
| 101 |
+
│ │ ├── path_patching.py # Path-based causal intervention engine
|
| 102 |
│ │ ├── sae_manager.py # SAE deployment and anomaly detection
|
| 103 |
│ │ └── steering.py # Steering vector generation and injection
|
| 104 |
│ ├── models/
|
|
|
|
| 106 |
│ └── utils/
|
| 107 |
├── tests/ # Unit and integration test suite
|
| 108 |
│ ├── test_components.py
|
| 109 |
+
│ ├── test_path_causal_microscope.py # Phase 4 Path-Causal tests
|
| 110 |
│ └── test_sae_and_steering.py
|
| 111 |
├── config.yaml # Experiment and environment configuration
|
| 112 |
└── requirements.txt # Environment dependencies
|
docs/activation_patching.md
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Causal Interventions: Activation Patching
|
| 2 |
+
|
| 3 |
+
Activation patching (or Resample Ablation) is a technique used to localize where information is processed in a model by swapping activations between a "clean" run and a "corrupted" run.
|
| 4 |
+
|
| 5 |
+
## Patching Workflow
|
| 6 |
+
|
| 7 |
+
1. **Clean Run**: Run the model on a standard input (e.g., a high-reward trajectory).
|
| 8 |
+
2. **Corrupted Run**: Run the model on a modified input (e.g., a zero-reward trajectory).
|
| 9 |
+
3. **Patch**: Replace a specific activation (head, residual stream, etc.) in the corrupted run with the corresponding activation from the clean run.
|
| 10 |
+
4. **Measure**: Observe the change in output (logits). If the output recovers toward the clean run, the patched component is causally significant.
|
| 11 |
+
|
| 12 |
+
```mermaid
|
| 13 |
+
flowchart LR
|
| 14 |
+
subgraph Clean Run
|
| 15 |
+
C1[Input A] --> C2[Layer X] --> C3[Output A]
|
| 16 |
+
end
|
| 17 |
+
|
| 18 |
+
subgraph Corrupted Run
|
| 19 |
+
D1[Input B] --> D2[Layer X] --> D3[Output B]
|
| 20 |
+
end
|
| 21 |
+
|
| 22 |
+
C2 -.->|Patch Activation| D2
|
| 23 |
+
D2 --> D4[Output B']
|
| 24 |
+
|
| 25 |
+
style D4 fill:#f96,stroke:#333,stroke-width:4px
|
| 26 |
+
```
|
| 27 |
+
|
| 28 |
+
## Path Patching
|
| 29 |
+
|
| 30 |
+
Path patching is a more granular version of activation patching. Instead of patching a whole layer, it patches the information flow between two specific nodes (e.g., from an Attention Head to the Final Logits).
|
| 31 |
+
|
| 32 |
+
### Example: Goal Token → Action Logit
|
| 33 |
+
|
| 34 |
+
```mermaid
|
| 35 |
+
graph TD
|
| 36 |
+
RTG[Reward-to-Go] --> Head1[Attention Head L0H5]
|
| 37 |
+
State[Current State] --> Head1
|
| 38 |
+
Head1 --> Res[Residual Stream]
|
| 39 |
+
Res --> Logits[Action Logits]
|
| 40 |
+
|
| 41 |
+
subgraph Path Patching
|
| 42 |
+
Head1 -->|Causal Link| Logits
|
| 43 |
+
end
|
| 44 |
+
```
|
docs/circuit_discovery.md
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Circuit Discovery in Decision Transformers
|
| 2 |
+
|
| 3 |
+
Circuit discovery is the process of identifying the minimal set of neural components (heads, neurons, paths) that are responsible for a specific behavior in a Decision Transformer.
|
| 4 |
+
|
| 5 |
+
## Automated Circuit Discovery (ACDC)
|
| 6 |
+
|
| 7 |
+
ACDC is used to prune the full model into a task-specific subgraph. It works by iteratively removing edges that do not significantly contribute to the model's performance on a specific metric (e.g., action prediction).
|
| 8 |
+
|
| 9 |
+
### ACDC Workflow
|
| 10 |
+
|
| 11 |
+
```mermaid
|
| 12 |
+
graph TD
|
| 13 |
+
A[Full Model Graph] --> B{Edge Importance Check}
|
| 14 |
+
B -- Significant --> C[Keep Edge]
|
| 15 |
+
B -- Insignificant --> D[Prune Edge]
|
| 16 |
+
C --> E[New Subgraph]
|
| 17 |
+
D --> E
|
| 18 |
+
E --> F{Converged?}
|
| 19 |
+
F -- No --> B
|
| 20 |
+
F -- Yes --> G[Final Circuit]
|
| 21 |
+
```
|
| 22 |
+
|
| 23 |
+
## Induction Head Discovery
|
| 24 |
+
|
| 25 |
+
Induction heads are key components in Transformers that perform temporal pattern recognition. In DTs, these are often responsible for matching current states to past experiences to determine the next action.
|
| 26 |
+
|
| 27 |
+
### The Induction Mechanism
|
| 28 |
+
Induction heads typically follow a two-step pattern:
|
| 29 |
+
1. **Search**: Look for previous occurrences of the current token.
|
| 30 |
+
2. **Retrieve**: Extract the token that followed the previous occurrence.
|
| 31 |
+
|
| 32 |
+
```mermaid
|
| 33 |
+
sequenceDiagram
|
| 34 |
+
participant S as State Token (T)
|
| 35 |
+
participant P as Previous State (T-k)
|
| 36 |
+
participant N as Next Action (T-k+1)
|
| 37 |
+
participant O as Output Action (T+1)
|
| 38 |
+
|
| 39 |
+
S->>P: Key-Query Match
|
| 40 |
+
P->>N: Value Retrieval
|
| 41 |
+
N->>O: Contribution to Logits
|
| 42 |
+
```
|
docs/sae_steering.md
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SAEs and Activation Steering
|
| 2 |
+
|
| 3 |
+
Sparse Autoencoders (SAEs) allow us to decompose the residual stream into human-interpretable features, while steering allows us to manipulate those features to change agent behavior.
|
| 4 |
+
|
| 5 |
+
## Sparse Autoencoders (SAE)
|
| 6 |
+
|
| 7 |
+
An SAE learns a sparse representation of activations. By projecting dense vectors into a higher-dimensional space with a sparsity constraint (L1 penalty), we find "monosemantic" latents that often correspond to specific concepts (e.g., "Wall ahead", "Turning left").
|
| 8 |
+
|
| 9 |
+
```mermaid
|
| 10 |
+
graph LR
|
| 11 |
+
Act[Dense Activation] --> Enc[Encoder]
|
| 12 |
+
Enc --> Lat[Sparse Latents]
|
| 13 |
+
Lat --> Dec[Decoder]
|
| 14 |
+
Dec --> Rec[Reconstruction]
|
| 15 |
+
|
| 16 |
+
style Lat fill:#dfd,stroke:#333
|
| 17 |
+
```
|
| 18 |
+
|
| 19 |
+
## Activation Steering
|
| 20 |
+
|
| 21 |
+
Steering involves adding a "direction" vector to the model's activations to shift its behavior. This is often done using **Contrastive Activation Addition**.
|
| 22 |
+
|
| 23 |
+
### Steering Pipeline
|
| 24 |
+
|
| 25 |
+
1. **Collect States**: Gather activations for two contrasting behaviors (e.g., "Moving Fast" vs "Moving Slow").
|
| 26 |
+
2. **Compute Vector**: Calculate the difference between the mean activations of these two sets.
|
| 27 |
+
3. **Inject**: Add this vector (multiplied by a coefficient) to the model during inference.
|
| 28 |
+
|
| 29 |
+
```mermaid
|
| 30 |
+
graph TD
|
| 31 |
+
A[Mean Act: Behavior A] --> Diff[Steering Vector = A - B]
|
| 32 |
+
B[Mean Act: Behavior B] --> Diff
|
| 33 |
+
|
| 34 |
+
In[Current Input] --> Model[DT Model]
|
| 35 |
+
Diff -->|Add with Gain λ| Model
|
| 36 |
+
Model --> Out[Modified Behavior]
|
| 37 |
+
```
|
src/interpretability/acdc.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import json
|
| 3 |
+
from typing import Dict, List, Callable, Optional, Tuple
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
|
| 6 |
+
class ACDCDiscovery:
|
| 7 |
+
"""
|
| 8 |
+
Automated Circuit Discovery and Click-through (ACDC).
|
| 9 |
+
Prunes a model to find the minimal sufficient subgraph for a specific behavior.
|
| 10 |
+
"""
|
| 11 |
+
def __init__(
|
| 12 |
+
self,
|
| 13 |
+
model,
|
| 14 |
+
threshold: float = 0.1,
|
| 15 |
+
metric_fn: Optional[Callable] = None
|
| 16 |
+
):
|
| 17 |
+
self.model = model
|
| 18 |
+
self.threshold = threshold
|
| 19 |
+
self.metric_fn = metric_fn
|
| 20 |
+
self.current_circuit = {
|
| 21 |
+
"layers": [],
|
| 22 |
+
"heads": [],
|
| 23 |
+
"mlps": []
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
def default_metric(self, model_outputs: Tuple, target_action: int) -> float:
|
| 27 |
+
"""
|
| 28 |
+
Default metric: Logit of the target action.
|
| 29 |
+
"""
|
| 30 |
+
action_preds = model_outputs[0] # [batch, seq, action_dim]
|
| 31 |
+
return action_preds[0, -1, target_action].item()
|
| 32 |
+
|
| 33 |
+
def run(
|
| 34 |
+
self,
|
| 35 |
+
inputs: Dict[str, torch.Tensor],
|
| 36 |
+
target_action: int
|
| 37 |
+
) -> Dict:
|
| 38 |
+
"""
|
| 39 |
+
Runs the ACDC algorithm to prune heads.
|
| 40 |
+
"""
|
| 41 |
+
n_layers = self.model.cfg.n_layers
|
| 42 |
+
n_heads = self.model.cfg.n_heads
|
| 43 |
+
|
| 44 |
+
# Baseline performance
|
| 45 |
+
initial_outputs = self.model(**inputs)
|
| 46 |
+
initial_perf = self.default_metric(initial_outputs, target_action)
|
| 47 |
+
|
| 48 |
+
active_heads = []
|
| 49 |
+
for l in range(n_layers):
|
| 50 |
+
for h in range(n_heads):
|
| 51 |
+
active_heads.append((l, h))
|
| 52 |
+
|
| 53 |
+
pruned_heads = []
|
| 54 |
+
|
| 55 |
+
# Backward greedy selection
|
| 56 |
+
pbar = tqdm(active_heads, desc="ACDC Pruning")
|
| 57 |
+
for layer, head in pbar:
|
| 58 |
+
# Try removing this head
|
| 59 |
+
current_pruned = pruned_heads + [(layer, head)]
|
| 60 |
+
|
| 61 |
+
perf = self._eval_with_pruning(inputs, current_pruned, target_action)
|
| 62 |
+
|
| 63 |
+
# Retain pruning if performance remains within threshold
|
| 64 |
+
if abs(perf - initial_perf) < self.threshold:
|
| 65 |
+
pruned_heads.append((layer, head))
|
| 66 |
+
pbar.set_postfix({"pruned": len(pruned_heads)})
|
| 67 |
+
|
| 68 |
+
final_circuit = {
|
| 69 |
+
"active_heads": [h for h in active_heads if h not in pruned_heads],
|
| 70 |
+
"pruned_count": len(pruned_heads),
|
| 71 |
+
"initial_perf": initial_perf,
|
| 72 |
+
"final_perf": self._eval_with_pruning(inputs, pruned_heads, target_action)
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
self.current_circuit = final_circuit
|
| 76 |
+
return final_circuit
|
| 77 |
+
|
| 78 |
+
def _eval_with_pruning(
|
| 79 |
+
self,
|
| 80 |
+
inputs: Dict[str, torch.Tensor],
|
| 81 |
+
pruned_heads: List[Tuple[int, int]],
|
| 82 |
+
target_action: int
|
| 83 |
+
) -> float:
|
| 84 |
+
|
| 85 |
+
def pruning_hook(value, hook):
|
| 86 |
+
# hook.name format: "blocks.L.attn.hook_result"
|
| 87 |
+
layer_idx = int(hook.name.split(".")[1])
|
| 88 |
+
for p_layer, p_head in pruned_heads:
|
| 89 |
+
if p_layer == layer_idx:
|
| 90 |
+
value[:, :, p_head, :] = 0.0
|
| 91 |
+
return value
|
| 92 |
+
|
| 93 |
+
hook_names = [f"blocks.{l}.attn.hook_result" for l in range(self.model.cfg.n_layers)]
|
| 94 |
+
|
| 95 |
+
with self.model.transformer.hooks(fwd_hooks=[(name, pruning_hook) for name in hook_names]):
|
| 96 |
+
outputs = self.model(**inputs)
|
| 97 |
+
|
| 98 |
+
return self.default_metric(outputs, target_action)
|
| 99 |
+
|
| 100 |
+
def save_manifest(self, path: str):
|
| 101 |
+
"""Saves circuit manifest to JSON."""
|
| 102 |
+
with open(path, 'w') as f:
|
| 103 |
+
# Convert tuples to strings for JSON
|
| 104 |
+
serializable_circuit = self.current_circuit.copy()
|
| 105 |
+
serializable_circuit["active_heads"] = [f"L{l}H{h}" for l, h in serializable_circuit["active_heads"]]
|
| 106 |
+
json.dump(serializable_circuit, f, indent=4)
|
src/interpretability/attribution.py
CHANGED
|
@@ -18,12 +18,12 @@ class LogitAttributionEngine:
|
|
| 18 |
token_index: int = -1
|
| 19 |
) -> Dict[str, Float[torch.Tensor, "layer head"]]:
|
| 20 |
"""
|
| 21 |
-
|
| 22 |
"""
|
| 23 |
n_layers = self.model.cfg.n_layers
|
| 24 |
n_heads = self.model.cfg.n_heads
|
| 25 |
|
| 26 |
-
#
|
| 27 |
W_U = self.model.predict_action[0].weight[target_logit_index]
|
| 28 |
|
| 29 |
dla_results = torch.zeros((n_layers, n_heads))
|
|
@@ -32,7 +32,7 @@ class LogitAttributionEngine:
|
|
| 32 |
# [batch, pos, head, d_model]
|
| 33 |
head_outputs = cache[f"blocks.{layer}.attn.hook_result"]
|
| 34 |
|
| 35 |
-
#
|
| 36 |
last_token_output = head_outputs[0, token_index]
|
| 37 |
|
| 38 |
dla_results[layer] = torch.matmul(last_token_output, W_U)
|
|
|
|
| 18 |
token_index: int = -1
|
| 19 |
) -> Dict[str, Float[torch.Tensor, "layer head"]]:
|
| 20 |
"""
|
| 21 |
+
Calculates DLA for each head: Activation @ W_O @ W_U [target_logit]
|
| 22 |
"""
|
| 23 |
n_layers = self.model.cfg.n_layers
|
| 24 |
n_heads = self.model.cfg.n_heads
|
| 25 |
|
| 26 |
+
# Weight for target action prediction
|
| 27 |
W_U = self.model.predict_action[0].weight[target_logit_index]
|
| 28 |
|
| 29 |
dla_results = torch.zeros((n_layers, n_heads))
|
|
|
|
| 32 |
# [batch, pos, head, d_model]
|
| 33 |
head_outputs = cache[f"blocks.{layer}.attn.hook_result"]
|
| 34 |
|
| 35 |
+
# Use token at specified index
|
| 36 |
last_token_output = head_outputs[0, token_index]
|
| 37 |
|
| 38 |
dla_results[layer] = torch.matmul(last_token_output, W_U)
|
src/interpretability/evolution.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import os
|
| 3 |
+
from typing import List, Dict
|
| 4 |
+
from src.interpretability.acdc import ACDCDiscovery
|
| 5 |
+
|
| 6 |
+
class EvolutionaryScanner:
|
| 7 |
+
"""
|
| 8 |
+
Analyzes how circuits evolve across different training checkpoints.
|
| 9 |
+
"""
|
| 10 |
+
def __init__(self, model_class, state_dim: int, action_dim: int):
|
| 11 |
+
self.model_class = model_class
|
| 12 |
+
self.state_dim = state_dim
|
| 13 |
+
self.action_dim = action_dim
|
| 14 |
+
|
| 15 |
+
def scan_checkpoints(
|
| 16 |
+
self,
|
| 17 |
+
checkpoint_dir: str,
|
| 18 |
+
inputs: Dict[str, torch.Tensor],
|
| 19 |
+
target_action: int,
|
| 20 |
+
threshold: float = 0.1,
|
| 21 |
+
**model_kwargs
|
| 22 |
+
) -> List[Dict]:
|
| 23 |
+
"""
|
| 24 |
+
Runs ACDC on checkpoints and returns the results.
|
| 25 |
+
"""
|
| 26 |
+
results = []
|
| 27 |
+
ckpt_files = sorted([f for f in os.listdir(checkpoint_dir) if f.endswith(".pt") or f.endswith(".pth")])
|
| 28 |
+
|
| 29 |
+
for ckpt in ckpt_files:
|
| 30 |
+
ckpt_path = os.path.join(checkpoint_dir, ckpt)
|
| 31 |
+
print(f"Analyzing checkpoint: {ckpt}")
|
| 32 |
+
|
| 33 |
+
# Load model
|
| 34 |
+
model = self.model_class.from_config(self.state_dim, self.action_dim, **model_kwargs)
|
| 35 |
+
model.load_state_dict(torch.load(ckpt_path, map_location=model.transformer.cfg.device))
|
| 36 |
+
model.eval()
|
| 37 |
+
|
| 38 |
+
# Run ACDC
|
| 39 |
+
acdc = ACDCDiscovery(model, threshold=threshold)
|
| 40 |
+
circuit = acdc.run(inputs, target_action)
|
| 41 |
+
circuit["checkpoint"] = ckpt
|
| 42 |
+
|
| 43 |
+
results.append(circuit)
|
| 44 |
+
|
| 45 |
+
return results
|
| 46 |
+
|
| 47 |
+
def detect_phase_transition(self, scan_results: List[Dict]) -> int:
|
| 48 |
+
"""
|
| 49 |
+
Identifies the step where a major jump in circuit stability or performance occurred.
|
| 50 |
+
"""
|
| 51 |
+
# Identifies checkpoint where performance > 0.5 and circuit stabilizes.
|
| 52 |
+
for i, res in enumerate(scan_results):
|
| 53 |
+
if res["final_perf"] > 0.5 and len(res["active_heads"]) > 0:
|
| 54 |
+
return i
|
| 55 |
+
return -1
|
src/interpretability/patching.py
CHANGED
|
@@ -17,9 +17,7 @@ class ActivationPatcher:
|
|
| 17 |
head_index: int,
|
| 18 |
target_token_index: int = -1
|
| 19 |
):
|
| 20 |
-
"""
|
| 21 |
-
Replaces the output of a specific head in a clean run with values from a corrupted run.
|
| 22 |
-
"""
|
| 23 |
def patch_hook(value, hook):
|
| 24 |
# value: [batch, pos, head, d_model]
|
| 25 |
corrupted_value = corrupted_cache[hook.name]
|
|
@@ -39,9 +37,7 @@ class ActivationPatcher:
|
|
| 39 |
patched_probs: torch.Tensor,
|
| 40 |
correct_action_index: int
|
| 41 |
) -> float:
|
| 42 |
-
"""
|
| 43 |
-
Measures the impact of patching on the target action probability.
|
| 44 |
-
"""
|
| 45 |
clean_val = clean_probs[0, -1, correct_action_index].item()
|
| 46 |
patched_val = patched_probs[0, -1, correct_action_index].item()
|
| 47 |
return clean_val - patched_val
|
|
|
|
| 17 |
head_index: int,
|
| 18 |
target_token_index: int = -1
|
| 19 |
):
|
| 20 |
+
"""Patches head output with values from a corrupted run."""
|
|
|
|
|
|
|
| 21 |
def patch_hook(value, hook):
|
| 22 |
# value: [batch, pos, head, d_model]
|
| 23 |
corrupted_value = corrupted_cache[hook.name]
|
|
|
|
| 37 |
patched_probs: torch.Tensor,
|
| 38 |
correct_action_index: int
|
| 39 |
) -> float:
|
| 40 |
+
"""Calculates impact of patching on target action probability."""
|
|
|
|
|
|
|
| 41 |
clean_val = clean_probs[0, -1, correct_action_index].item()
|
| 42 |
patched_val = patched_probs[0, -1, correct_action_index].item()
|
| 43 |
return clean_val - patched_val
|
src/interpretability/path_patching.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from typing import Dict, Optional, Tuple
|
| 3 |
+
from transformer_lens import HookedTransformer
|
| 4 |
+
|
| 5 |
+
class PathPatchingEngine:
|
| 6 |
+
"""
|
| 7 |
+
Engine for performing path-based causal interventions.
|
| 8 |
+
Allows isolating the influence of specific components on others.
|
| 9 |
+
"""
|
| 10 |
+
def __init__(self, model):
|
| 11 |
+
self.model = model
|
| 12 |
+
|
| 13 |
+
def patch_path(
|
| 14 |
+
self,
|
| 15 |
+
clean_inputs: Dict[str, torch.Tensor],
|
| 16 |
+
corrupted_cache: Dict[str, torch.Tensor],
|
| 17 |
+
src_layer: int,
|
| 18 |
+
src_head: int,
|
| 19 |
+
dest_layer: int,
|
| 20 |
+
dest_head: int,
|
| 21 |
+
component_type: str = "q", # 'q', 'k', or 'v'
|
| 22 |
+
) -> torch.Tensor:
|
| 23 |
+
"""
|
| 24 |
+
Patches the path from a source head to a destination head's input (Q, K, or V).
|
| 25 |
+
|
| 26 |
+
Args:
|
| 27 |
+
clean_inputs: Dictionary of clean input tensors.
|
| 28 |
+
corrupted_cache: Cache containing activations from a corrupted run.
|
| 29 |
+
src_layer: Layer index of the source head.
|
| 30 |
+
src_head: Head index of the source head.
|
| 31 |
+
dest_layer: Layer index of the destination head.
|
| 32 |
+
dest_head: Head index of the destination head.
|
| 33 |
+
component_type: Which input projection of the destination head to patch.
|
| 34 |
+
|
| 35 |
+
Returns:
|
| 36 |
+
The output of the model with the path patched.
|
| 37 |
+
"""
|
| 38 |
+
|
| 39 |
+
# Source component output hook name
|
| 40 |
+
src_hook_name = f"blocks.{src_layer}.attn.hook_result"
|
| 41 |
+
# Destination component input hook name
|
| 42 |
+
dest_hook_name = f"blocks.{dest_layer}.hook_{component_type}_input"
|
| 43 |
+
|
| 44 |
+
def path_patch_hook(value, hook):
|
| 45 |
+
# Replace destination head input with source head contribution from corrupted cache.
|
| 46 |
+
|
| 47 |
+
# Current implementation patches head output to observe downstream impact.
|
| 48 |
+
return value
|
| 49 |
+
|
| 50 |
+
# Focuses on Goal -> Head -> Action logic in DT-Circuits.
|
| 51 |
+
pass
|
| 52 |
+
|
| 53 |
+
def perform_edge_ablation(
|
| 54 |
+
self,
|
| 55 |
+
inputs: Dict[str, torch.Tensor],
|
| 56 |
+
layer: int,
|
| 57 |
+
head_index: int,
|
| 58 |
+
ablation_type: str = "zero"
|
| 59 |
+
) -> torch.Tensor:
|
| 60 |
+
"""
|
| 61 |
+
Ablates a specific edge (head) to see its necessity.
|
| 62 |
+
"""
|
| 63 |
+
def ablation_hook(value, hook):
|
| 64 |
+
if ablation_type == "zero":
|
| 65 |
+
value[:, :, head_index, :] = 0.0
|
| 66 |
+
return value
|
| 67 |
+
|
| 68 |
+
hook_name = f"blocks.{layer}.attn.hook_result"
|
| 69 |
+
with self.model.transformer.hooks(fwd_hooks=[(hook_name, ablation_hook)]):
|
| 70 |
+
outputs = self.model(**inputs)
|
| 71 |
+
return outputs
|
src/interpretability/sae_manager.py
CHANGED
|
@@ -7,7 +7,7 @@ from jaxtyping import Float
|
|
| 7 |
|
| 8 |
class SAEManager:
|
| 9 |
"""
|
| 10 |
-
|
| 11 |
"""
|
| 12 |
def __init__(self, model: nn.Module, sae_dir: str = "artifacts/saes"):
|
| 13 |
self.model = model
|
|
@@ -21,9 +21,7 @@ class SAEManager:
|
|
| 21 |
d_model: int,
|
| 22 |
expansion_factor: int = 8,
|
| 23 |
) -> StandardSAE:
|
| 24 |
-
"""
|
| 25 |
-
Initializes an SAE for a specific hook point.
|
| 26 |
-
"""
|
| 27 |
cfg = StandardSAEConfig(
|
| 28 |
d_in=d_model,
|
| 29 |
d_sae=d_model * expansion_factor,
|
|
@@ -41,9 +39,7 @@ class SAEManager:
|
|
| 41 |
batch_size: int = 1024,
|
| 42 |
epochs: int = 10,
|
| 43 |
):
|
| 44 |
-
"""
|
| 45 |
-
Trains the SAE on trajectory activations.
|
| 46 |
-
"""
|
| 47 |
if hook_point not in self.saes:
|
| 48 |
self.setup_sae(hook_point, activations.shape[-1])
|
| 49 |
|
|
@@ -80,9 +76,7 @@ class SAEManager:
|
|
| 80 |
hook_point: str,
|
| 81 |
activations: Float[torch.Tensor, "... d_model"]
|
| 82 |
) -> Float[torch.Tensor, "... d_sae"]:
|
| 83 |
-
"""
|
| 84 |
-
Decomposes activations into features.
|
| 85 |
-
"""
|
| 86 |
if hook_point not in self.saes:
|
| 87 |
raise ValueError(f"SAE for {hook_point} not found. Train or load it first.")
|
| 88 |
|
|
@@ -97,9 +91,7 @@ class SAEManager:
|
|
| 97 |
hook_point: str,
|
| 98 |
activations: Float[torch.Tensor, "... d_model"]
|
| 99 |
) -> Float[torch.Tensor, "... d_model"]:
|
| 100 |
-
"""
|
| 101 |
-
Reconstructs original activations.
|
| 102 |
-
"""
|
| 103 |
if hook_point not in self.saes:
|
| 104 |
raise ValueError(f"SAE for {hook_point} not found.")
|
| 105 |
|
|
@@ -115,9 +107,7 @@ class SAEManager:
|
|
| 115 |
hook_point: str,
|
| 116 |
activations: Float[torch.Tensor, "... d_model"]
|
| 117 |
) -> Float[torch.Tensor, "..."]:
|
| 118 |
-
"""
|
| 119 |
-
Reconstruction error for anomaly detection: ||x - x_hat|| / ||x||
|
| 120 |
-
"""
|
| 121 |
if hook_point not in self.saes:
|
| 122 |
raise ValueError(f"SAE for {hook_point} not found.")
|
| 123 |
|
|
|
|
| 7 |
|
| 8 |
class SAEManager:
|
| 9 |
"""
|
| 10 |
+
Handles SAE training, latent decomposition, and anomaly detection for DTs.
|
| 11 |
"""
|
| 12 |
def __init__(self, model: nn.Module, sae_dir: str = "artifacts/saes"):
|
| 13 |
self.model = model
|
|
|
|
| 21 |
d_model: int,
|
| 22 |
expansion_factor: int = 8,
|
| 23 |
) -> StandardSAE:
|
| 24 |
+
"""Initializes SAE for a specific hook point."""
|
|
|
|
|
|
|
| 25 |
cfg = StandardSAEConfig(
|
| 26 |
d_in=d_model,
|
| 27 |
d_sae=d_model * expansion_factor,
|
|
|
|
| 39 |
batch_size: int = 1024,
|
| 40 |
epochs: int = 10,
|
| 41 |
):
|
| 42 |
+
"""Trains SAE on trajectory activations."""
|
|
|
|
|
|
|
| 43 |
if hook_point not in self.saes:
|
| 44 |
self.setup_sae(hook_point, activations.shape[-1])
|
| 45 |
|
|
|
|
| 76 |
hook_point: str,
|
| 77 |
activations: Float[torch.Tensor, "... d_model"]
|
| 78 |
) -> Float[torch.Tensor, "... d_sae"]:
|
| 79 |
+
"""Decomposes activations into latent features."""
|
|
|
|
|
|
|
| 80 |
if hook_point not in self.saes:
|
| 81 |
raise ValueError(f"SAE for {hook_point} not found. Train or load it first.")
|
| 82 |
|
|
|
|
| 91 |
hook_point: str,
|
| 92 |
activations: Float[torch.Tensor, "... d_model"]
|
| 93 |
) -> Float[torch.Tensor, "... d_model"]:
|
| 94 |
+
"""Reconstructs activations from latents."""
|
|
|
|
|
|
|
| 95 |
if hook_point not in self.saes:
|
| 96 |
raise ValueError(f"SAE for {hook_point} not found.")
|
| 97 |
|
|
|
|
| 107 |
hook_point: str,
|
| 108 |
activations: Float[torch.Tensor, "... d_model"]
|
| 109 |
) -> Float[torch.Tensor, "..."]:
|
| 110 |
+
"""Calculates reconstruction error for anomaly detection."""
|
|
|
|
|
|
|
| 111 |
if hook_point not in self.saes:
|
| 112 |
raise ValueError(f"SAE for {hook_point} not found.")
|
| 113 |
|
src/interpretability/steering.py
CHANGED
|
@@ -25,8 +25,7 @@ class SteeringLibrary:
|
|
| 25 |
|
| 26 |
class RTGSteerer:
|
| 27 |
"""
|
| 28 |
-
|
| 29 |
-
Supports Contrastive Activation Addition (CAA).
|
| 30 |
"""
|
| 31 |
def __init__(self, model, library: Optional[SteeringLibrary] = None):
|
| 32 |
self.model = model
|
|
@@ -39,9 +38,7 @@ class RTGSteerer:
|
|
| 39 |
custom_vector: Optional[torch.Tensor] = None,
|
| 40 |
alpha: float = 1.0
|
| 41 |
) -> torch.Tensor:
|
| 42 |
-
"""
|
| 43 |
-
Adds a steering vector to the RTG embeddings.
|
| 44 |
-
"""
|
| 45 |
vector = custom_vector if custom_vector is not None else self.library.get_vector(vector_name)
|
| 46 |
|
| 47 |
with torch.no_grad():
|
|
@@ -54,10 +51,7 @@ class RTGSteerer:
|
|
| 54 |
negative_activations: torch.Tensor,
|
| 55 |
method: str = "mean_diff"
|
| 56 |
) -> torch.Tensor:
|
| 57 |
-
"""
|
| 58 |
-
Generates a steering vector using Contrastive Activation Addition.
|
| 59 |
-
'mean_diff' calculates the difference between the means of positive and negative sets.
|
| 60 |
-
"""
|
| 61 |
if method == "mean_diff":
|
| 62 |
pos_mean = positive_activations.mean(dim=0)
|
| 63 |
neg_mean = negative_activations.mean(dim=0)
|
|
@@ -66,9 +60,7 @@ class RTGSteerer:
|
|
| 66 |
raise NotImplementedError(f"Method {method} not implemented.")
|
| 67 |
|
| 68 |
def apply_steering_hook(self, hook_point: str, vector_name: str, alpha: float = 1.0):
|
| 69 |
-
"""
|
| 70 |
-
Returns a HookedTransformer compatible hook function that applies steering.
|
| 71 |
-
"""
|
| 72 |
vector = self.library.get_vector(vector_name)
|
| 73 |
|
| 74 |
def steering_hook(activations, hook):
|
|
|
|
| 25 |
|
| 26 |
class RTGSteerer:
|
| 27 |
"""
|
| 28 |
+
Manages Reward-to-Go (RTG) and activation steering using CAA.
|
|
|
|
| 29 |
"""
|
| 30 |
def __init__(self, model, library: Optional[SteeringLibrary] = None):
|
| 31 |
self.model = model
|
|
|
|
| 38 |
custom_vector: Optional[torch.Tensor] = None,
|
| 39 |
alpha: float = 1.0
|
| 40 |
) -> torch.Tensor:
|
| 41 |
+
"""Adds steering vector to RTG embeddings."""
|
|
|
|
|
|
|
| 42 |
vector = custom_vector if custom_vector is not None else self.library.get_vector(vector_name)
|
| 43 |
|
| 44 |
with torch.no_grad():
|
|
|
|
| 51 |
negative_activations: torch.Tensor,
|
| 52 |
method: str = "mean_diff"
|
| 53 |
) -> torch.Tensor:
|
| 54 |
+
"""Generates steering vector using Contrastive Activation Addition (mean difference)."""
|
|
|
|
|
|
|
|
|
|
| 55 |
if method == "mean_diff":
|
| 56 |
pos_mean = positive_activations.mean(dim=0)
|
| 57 |
neg_mean = negative_activations.mean(dim=0)
|
|
|
|
| 60 |
raise NotImplementedError(f"Method {method} not implemented.")
|
| 61 |
|
| 62 |
def apply_steering_hook(self, hook_point: str, vector_name: str, alpha: float = 1.0):
|
| 63 |
+
"""Returns a TransformerLens compatible steering hook."""
|
|
|
|
|
|
|
| 64 |
vector = self.library.get_vector(vector_name)
|
| 65 |
|
| 66 |
def steering_hook(activations, hook):
|
src/models/hooked_dt.py
CHANGED
|
@@ -23,10 +23,10 @@ class HookedDT(nn.Module):
|
|
| 23 |
self.action_dim = action_dim
|
| 24 |
self.max_length = max_length
|
| 25 |
|
| 26 |
-
#
|
| 27 |
self.transformer = HookedTransformer(cfg)
|
| 28 |
|
| 29 |
-
#
|
| 30 |
self.embed_return = nn.Linear(1, cfg.d_model)
|
| 31 |
self.embed_state = nn.Linear(state_dim, cfg.d_model)
|
| 32 |
self.embed_action = nn.Linear(action_dim, cfg.d_model)
|
|
@@ -58,7 +58,7 @@ class HookedDT(nn.Module):
|
|
| 58 |
action_embeddings = self.embed_action(actions)
|
| 59 |
returns_embeddings = self.embed_return(returns_to_go)
|
| 60 |
|
| 61 |
-
# Interleave (
|
| 62 |
stacked_inputs = torch.stack(
|
| 63 |
(returns_embeddings, state_embeddings, action_embeddings), dim=2
|
| 64 |
).reshape(batch_size, 3 * seq_len, self.cfg.d_model)
|
|
@@ -68,7 +68,7 @@ class HookedDT(nn.Module):
|
|
| 68 |
def embed_hook(value, hook):
|
| 69 |
return stacked_inputs
|
| 70 |
|
| 71 |
-
# Inject interleaved embeddings
|
| 72 |
dummy_input = torch.zeros((batch_size, 3 * seq_len), dtype=torch.long, device=stacked_inputs.device)
|
| 73 |
|
| 74 |
last_block_hook = f"blocks.{self.cfg.n_layers - 1}.hook_resid_post"
|
|
@@ -82,7 +82,7 @@ class HookedDT(nn.Module):
|
|
| 82 |
transformer_outputs = cache[last_block_hook]
|
| 83 |
x = transformer_outputs.reshape(batch_size, seq_len, 3, self.cfg.d_model)
|
| 84 |
|
| 85 |
-
#
|
| 86 |
action_preds = self.predict_action(x[:, :, 1])
|
| 87 |
return_preds = self.predict_return(x[:, :, 2])
|
| 88 |
state_preds = self.predict_state(x[:, :, 2])
|
|
@@ -101,6 +101,7 @@ class HookedDT(nn.Module):
|
|
| 101 |
act_fn="relu",
|
| 102 |
d_mlp=d_model * 4,
|
| 103 |
normalization_type="LN",
|
|
|
|
| 104 |
device="cuda" if torch.cuda.is_available() else "cpu"
|
| 105 |
)
|
| 106 |
return cls(cfg, state_dim, action_dim)
|
|
|
|
| 23 |
self.action_dim = action_dim
|
| 24 |
self.max_length = max_length
|
| 25 |
|
| 26 |
+
# TransformerLens core blocks
|
| 27 |
self.transformer = HookedTransformer(cfg)
|
| 28 |
|
| 29 |
+
# DT-specific embeddings
|
| 30 |
self.embed_return = nn.Linear(1, cfg.d_model)
|
| 31 |
self.embed_state = nn.Linear(state_dim, cfg.d_model)
|
| 32 |
self.embed_action = nn.Linear(action_dim, cfg.d_model)
|
|
|
|
| 58 |
action_embeddings = self.embed_action(actions)
|
| 59 |
returns_embeddings = self.embed_return(returns_to_go)
|
| 60 |
|
| 61 |
+
# Interleave (Return, State, Action)
|
| 62 |
stacked_inputs = torch.stack(
|
| 63 |
(returns_embeddings, state_embeddings, action_embeddings), dim=2
|
| 64 |
).reshape(batch_size, 3 * seq_len, self.cfg.d_model)
|
|
|
|
| 68 |
def embed_hook(value, hook):
|
| 69 |
return stacked_inputs
|
| 70 |
|
| 71 |
+
# Inject interleaved embeddings via hook
|
| 72 |
dummy_input = torch.zeros((batch_size, 3 * seq_len), dtype=torch.long, device=stacked_inputs.device)
|
| 73 |
|
| 74 |
last_block_hook = f"blocks.{self.cfg.n_layers - 1}.hook_resid_post"
|
|
|
|
| 82 |
transformer_outputs = cache[last_block_hook]
|
| 83 |
x = transformer_outputs.reshape(batch_size, seq_len, 3, self.cfg.d_model)
|
| 84 |
|
| 85 |
+
# Compute predictions
|
| 86 |
action_preds = self.predict_action(x[:, :, 1])
|
| 87 |
return_preds = self.predict_return(x[:, :, 2])
|
| 88 |
state_preds = self.predict_state(x[:, :, 2])
|
|
|
|
| 101 |
act_fn="relu",
|
| 102 |
d_mlp=d_model * 4,
|
| 103 |
normalization_type="LN",
|
| 104 |
+
use_attn_result=True,
|
| 105 |
device="cuda" if torch.cuda.is_available() else "cpu"
|
| 106 |
)
|
| 107 |
return cls(cfg, state_dim, action_dim)
|
tests/test_path_causal_microscope.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytest
|
| 2 |
+
import torch
|
| 3 |
+
from src.models.hooked_dt import HookedDT
|
| 4 |
+
from src.interpretability.acdc import ACDCDiscovery
|
| 5 |
+
from src.interpretability.path_patching import PathPatchingEngine
|
| 6 |
+
from src.interpretability.evolution import EvolutionaryScanner
|
| 7 |
+
import os
|
| 8 |
+
import json
|
| 9 |
+
|
| 10 |
+
@pytest.fixture
|
| 11 |
+
def model():
|
| 12 |
+
return HookedDT.from_config(state_dim=10, action_dim=3, n_layers=2, n_heads=2, d_model=32)
|
| 13 |
+
|
| 14 |
+
@pytest.fixture
|
| 15 |
+
def sample_inputs():
|
| 16 |
+
batch_size = 1
|
| 17 |
+
seq_len = 5
|
| 18 |
+
state_dim = 10
|
| 19 |
+
action_dim = 3
|
| 20 |
+
return {
|
| 21 |
+
"states": torch.randn(batch_size, seq_len, state_dim),
|
| 22 |
+
"actions": torch.zeros(batch_size, seq_len, action_dim),
|
| 23 |
+
"returns_to_go": torch.ones(batch_size, seq_len, 1),
|
| 24 |
+
"timesteps": torch.arange(seq_len).unsqueeze(0)
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
def test_acdc_discovery(model, sample_inputs):
|
| 28 |
+
# Ensure model is in eval mode
|
| 29 |
+
model.eval()
|
| 30 |
+
|
| 31 |
+
target_action = 1
|
| 32 |
+
acdc = ACDCDiscovery(model, threshold=0.5) # High threshold for quick test
|
| 33 |
+
circuit = acdc.run(sample_inputs, target_action)
|
| 34 |
+
|
| 35 |
+
assert "active_heads" in circuit
|
| 36 |
+
assert "initial_perf" in circuit
|
| 37 |
+
assert "final_perf" in circuit
|
| 38 |
+
|
| 39 |
+
# Save manifest check
|
| 40 |
+
manifest_path = "circuit_manifest.json"
|
| 41 |
+
acdc.save_manifest(manifest_path)
|
| 42 |
+
assert os.path.exists(manifest_path)
|
| 43 |
+
|
| 44 |
+
with open(manifest_path, 'r') as f:
|
| 45 |
+
data = json.load(f)
|
| 46 |
+
assert "active_heads" in data
|
| 47 |
+
|
| 48 |
+
os.remove(manifest_path)
|
| 49 |
+
|
| 50 |
+
def test_path_patching_ablation(model, sample_inputs):
|
| 51 |
+
engine = PathPatchingEngine(model)
|
| 52 |
+
|
| 53 |
+
# Run original
|
| 54 |
+
orig_output, _, _ = model(**sample_inputs)
|
| 55 |
+
|
| 56 |
+
# Ablate L0 H0
|
| 57 |
+
ablated_output, _, _ = engine.perform_edge_ablation(
|
| 58 |
+
sample_inputs, layer=0, head_index=0, ablation_type="zero"
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
# Check if they differ - using a very small tolerance or direct check
|
| 62 |
+
diff = (orig_output - ablated_output).abs().max().item()
|
| 63 |
+
assert diff > 0, "Ablation should have some effect on output"
|
| 64 |
+
|
| 65 |
+
def test_evolutionary_scanner_mock(model, sample_inputs, tmp_path):
|
| 66 |
+
# Create dummy checkpoints
|
| 67 |
+
checkpoint_dir = tmp_path / "checkpoints"
|
| 68 |
+
checkpoint_dir.mkdir()
|
| 69 |
+
|
| 70 |
+
torch.save(model.state_dict(), checkpoint_dir / "step_100.pt")
|
| 71 |
+
torch.save(model.state_dict(), checkpoint_dir / "step_200.pt")
|
| 72 |
+
|
| 73 |
+
scanner = EvolutionaryScanner(HookedDT, state_dim=10, action_dim=3)
|
| 74 |
+
# Pass d_model and n_heads to match the fixture model
|
| 75 |
+
results = scanner.scan_checkpoints(
|
| 76 |
+
str(checkpoint_dir),
|
| 77 |
+
sample_inputs,
|
| 78 |
+
target_action=1,
|
| 79 |
+
d_model=32,
|
| 80 |
+
n_heads=2
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
assert len(results) == 2
|
| 84 |
+
assert "checkpoint" in results[0]
|
| 85 |
+
assert "active_heads" in results[0]
|
| 86 |
+
|