Commit b7ddfc6 · Parent(s): 14d2c06
refactor: implement centralized configuration, upgrade SAE training to multi-layer TopK, and optimize dashboard attribution UX
- .gitignore +2 -0
- Makefile +24 -0
- README.md +147 -76
- config.yaml +1 -0
- requirements.txt +1 -0
- scripts/train_dt.py +15 -11
- scripts/train_sae.py +102 -24
- src/config.py +73 -0
- src/dashboard/app.py +19 -19
- src/models/hooked_dt.py +3 -3
- tests/test_mechanistic.py +63 -0
.gitignore
CHANGED
@@ -56,3 +56,5 @@ static/
 
 /PRD.md
 artifacts/
+scratch/
+*.log
Makefile
ADDED
@@ -0,0 +1,24 @@
+.PHONY: setup train dashboard test clean
+
+# Setup environment
+setup:
+	pip install -r requirements.txt
+
+# Run full pipeline: Harvesting -> DT Training -> SAE Training
+train:
+	python3 scripts/train_dt.py
+	python3 scripts/train_sae.py
+
+# Launch the explorer dashboard
+dashboard:
+	streamlit run src/dashboard/app.py
+
+# Run unit tests
+test:
+	PYTHONPATH=. pytest tests/
+
+# Remove artifacts and cached files
+clean:
+	rm -rf data/*.pt models/*.pt artifacts/saes/*.pt
+	find . -type d -name "__pycache__" -exec rm -rf {} +
+	find . -type d -name ".pytest_cache" -exec rm -rf {} +
README.md
CHANGED
|
@@ -1,66 +1,86 @@
|
|
| 1 |
# DT-Circuits: Mechanistic Interpretability for Decision Transformers
|
| 2 |
|
| 3 |
-

|
| 4 |
-

|
|
|
|
|
|
|
| 5 |
|
| 6 |
DT-Circuits is a research framework for mechanistic interpretability of Decision Transformers, focused on causal analysis, sparse feature decomposition, and circuit-level understanding of sequential decision-making agents.
|
| 7 |
|
| 8 |
---
|
| 9 |
|
| 10 |
-
## Motivation
|
| 11 |
-
|
| 12 |
-
Mechanistic interpretability has primarily focused on language models, while reinforcement learning agents remain comparatively underexplored.
|
| 13 |
-
|
| 14 |
-
Decision Transformers provide a uniquely analyzable architecture because trajectories, rewards, and actions are represented in a unified autoregressive sequence.
|
| 15 |
-
|
| 16 |
-
DT-Circuits aims to make RL agents inspectable at the circuit level rather than only through behavioral evaluation.
|
| 17 |
-
|
| 18 |
-
---
|
| 19 |
-
|
| 20 |
## Table of Contents
|
| 21 |
-
- [
|
| 22 |
-
- [Technical
|
|
|
|
| 23 |
- [Project Structure](#project-structure)
|
| 24 |
-
- [
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
-
---
|
| 27 |
-
|
| 28 |
-
## Documentation
|
| 29 |
-
- [Circuit Discovery](./docs/circuit_discovery.md)
|
| 30 |
-
- [Activation Patching](./docs/activation_patching.md)
|
| 31 |
-
- [SAEs & Steering](./docs/sae_steering.md)
|
| 32 |
|
| 33 |
-
|
| 34 |
|
| 35 |
-
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
-
|
| 38 |
-
- **Hooked-DT**: Access any internal activation or weight during the agent's run.
|
| 39 |
-
- **Logit Attribution**: See which attention heads or MLP layers drive specific actions.
|
| 40 |
-
- **Induction Scan**: Find heads that recognize temporal patterns and past states.
|
| 41 |
|
| 42 |
-
##
|
| 43 |
-
|
| 44 |
-
|
| 45 |
|
| 46 |
-
|
| 47 |
-
- **TopK SAEs**: Decompose complex activations into a few active "concepts" for easier reading.
|
| 48 |
-
- **Auto-Labeling (NLA)**: Use an LLM to automatically describe what each discovered neuron feature does.
|
| 49 |
-
- **Cross-Model Probes**: Check if different agents (like DQNs) learn the same internal concepts as the DT.
|
| 50 |
|
| 51 |
-
##
|
| 52 |
-
- **ACDC**: Automatically strip the model down to the minimal circuit needed for a task.
|
| 53 |
-
- **Path Patching**: Trace how a signal flows from a specific input token to the final action.
|
| 54 |
-
- **Evolutionary Scan**: Watch how decision-making circuits form during training.
|
| 55 |
|
| 56 |
-
|
|
|
|
|
|
|
|
|
|
| 57 |
|
| 58 |
-
##
|
|
|
|
|
|
|
|
|
|
| 59 |
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
- **Dashboard**: Streamlit UI for real-time model analysis.
|
| 64 |
|
| 65 |
---
|
| 66 |
|
|
@@ -68,10 +88,7 @@ DT-Circuits aims to make RL agents inspectable at the circuit level rather than
|
|
| 68 |
|
| 69 |
```text
|
| 70 |
DT-Circuits/
|
| 71 |
-
├──
|
| 72 |
-
│ ├── train_dt.py # Decision Transformer training pipeline
|
| 73 |
-
│ └── train_sae.py # Sparse Autoencoder (SAE) training script
|
| 74 |
-
├── src/
|
| 75 |
│ ├── dashboard/
|
| 76 |
│ │ └── app.py # Streamlit-based visualization UI
|
| 77 |
│ ├── data/
|
|
@@ -89,44 +106,58 @@ DT-Circuits/
|
|
| 89 |
│ │ └── universality.py # Cross-architecture feature mapping
|
| 90 |
│ ├── models/
|
| 91 |
│ │ └── hooked_dt.py # TransformerLens-wrapped Decision Transformer
|
|
|
|
| 92 |
│ └── utils/
|
| 93 |
├── tests/ # Unit tests for all modules
|
| 94 |
-
├── config.yaml
|
| 95 |
-
|
|
|
|
| 96 |
```
|
| 97 |
|
| 98 |
-
---
|
| 99 |
|
| 100 |
-
##
|
| 101 |
|
| 102 |
-
|
| 103 |
-
- Python 3.9+
|
| 104 |
-
- PyTorch 2.x
|
| 105 |
-
- TransformerLens
|
| 106 |
-
- SAE-Lens
|
| 107 |
|
| 108 |
-
|
|
|
|
| 109 |
|
| 110 |
-
|
| 111 |
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
|
| 131 |
### Workflow
|
| 132 |
|
|
@@ -135,7 +166,47 @@ Follow these steps to initialize the environment and verify the installation.
|
|
| 135 |
python scripts/train_dt.py
|
| 136 |
```
|
| 137 |
|
| 138 |
-
2. **
|
| 139 |
```bash
|
| 140 |
streamlit run src/dashboard/app.py
|
| 141 |
```
|
| 1 |
# DT-Circuits: Mechanistic Interpretability for Decision Transformers
|
| 2 |
|
| 3 |
+
[](https://www.python.org/downloads/)
|
| 4 |
+
[](https://pytorch.org/)
|
| 5 |
+
[](https://opensource.org/licenses/Apache-2.0)
|
| 6 |
+
[](https://github.com/TransformerLensOrg/TransformerLens)
|
| 7 |
|
| 8 |
DT-Circuits is a research framework for mechanistic interpretability of Decision Transformers, focused on causal analysis, sparse feature decomposition, and circuit-level understanding of sequential decision-making agents.
|
| 9 |
|
| 10 |
---
|
| 11 |
|
| 12 |
## Table of Contents
|
| 13 |
+
- [Core Objectives](#core-objectives)
|
| 14 |
+
- [Technical Overview](#technical-overview)
|
| 15 |
+
- [Capabilities](#capabilities)
|
| 16 |
- [Project Structure](#project-structure)
|
| 17 |
+
- [Installation and Usage](#installation-and-usage)
|
| 18 |
+
- [Documentation](#documentation)
|
| 19 |
+
- [Citation](#citation)
|
| 20 |
+
- [License](#license)
|
| 21 |
|
| 22 |
+
---
|
| 23 |
|
| 24 |
+
## Core Objectives
|
| 25 |
|
| 26 |
+
1. **Map Information Flow**: Quantify how input tokens (State, Action, Reward-to-Go) contribute to the output action logits.
|
| 27 |
+
2. **Causal Verification**: Use intervention techniques to identify the minimal set of model components required for specific behaviors.
|
| 28 |
+
3. **Feature Decomposition**: Use Sparse Autoencoders (SAEs) to identify monosemantic features within the model's residual stream.
|
| 29 |
+
4. **Behavioral Control**: Modify agent decisions at inference time by manipulating internal activations.
|
| 30 |
|
| 31 |
+
---
|
|
|
|
|
|
|
|
|
|
| 32 |
|
| 33 |
+
## Technical Overview
|
| 34 |
+
|
| 35 |
+
The framework centers around `HookedDT`, a Decision Transformer implementation that allows for activation hooking and cache management.
|
| 36 |
+
|
| 37 |
+
### Information Flow Diagram
|
| 38 |
+
|
| 39 |
+
```mermaid
|
| 40 |
+
graph TD
|
| 41 |
+
subgraph Input_Sequence
|
| 42 |
+
S[State Tokens]
|
| 43 |
+
A[Action Tokens]
|
| 44 |
+
RTG[Reward-to-Go Tokens]
|
| 45 |
+
end
|
| 46 |
+
|
| 47 |
+
Input_Sequence --> Embed[Embedding Layers]
|
| 48 |
+
Embed --> Hooks[Activation Hooks]
|
| 49 |
+
|
| 50 |
+
subgraph Transformer_Block
|
| 51 |
+
Hooks --> Attn[Multi-Head Attention]
|
| 52 |
+
Attn --> MLP[MLP Layers]
|
| 53 |
+
MLP --> Res[Residual Stream]
|
| 54 |
+
end
|
| 55 |
+
|
| 56 |
+
Res --> DLA[Direct Logit Attribution]
|
| 57 |
+
Res --> SAE[Sparse Autoencoder]
|
| 58 |
+
Res --> Output[Action Logits]
|
| 59 |
+
|
| 60 |
+
subgraph Interpretability_Modules
|
| 61 |
+
DLA -.-> Analysis
|
| 62 |
+
SAE -.-> Features
|
| 63 |
+
Intervention[Activation Patching] -.-> Hooks
|
| 64 |
+
end
|
| 65 |
+
```
|
| 66 |
|
| 67 |
+
---
|
|
|
|
|
|
|
|
|
|
| 68 |
|
| 69 |
+
## Capabilities
|
|
|
|
|
|
|
|
|
|
| 70 |
|
| 71 |
+
### Causal Mediation and Attribution
|
| 72 |
+
* **Direct Logit Attribution (DLA)**: Measures the direct contribution of individual attention heads and MLP layers to the final logit output.
|
| 73 |
+
* **Activation Patching**: Substitutes internal activations from different runs to isolate the causal effect of specific inputs on model behavior.
|
| 74 |
+
* **Path Patching**: Traces how information flows through specific connections between model components.
|
| 75 |
|
| 76 |
+
### Feature Discovery and Analysis
|
| 77 |
+
* **Sparse Autoencoders (SAEs)**: Decomposes the residual stream into a set of sparse features, helping to resolve polysemanticity.
|
| 78 |
+
* **Induction Scanning**: Identifies attention heads that perform pattern-matching and temporal sequence recognition.
|
| 79 |
+
* **Automated Circuit Discovery (ACDC)**: Prunes the model to identify the smallest functional subgraph sufficient to perform a specific task.
|
| 80 |
|
| 81 |
+
### Behavioral Steering
|
| 82 |
+
* **Activation Steering**: Injects specific vectors into the residual stream to bias the agent's decision-making without retraining the weights.
|
| 83 |
+
* **Safety Auditing**: Monitors SAE reconstruction error and feature activation to detect anomalous or out-of-distribution internal states.
|
|
|
|
| 84 |
|
| 85 |
---
|
| 86 |
|
|
|
|
| 88 |
|
| 89 |
```text
|
| 90 |
DT-Circuits/
|
| 91 |
+
├── src/
|
|
|
|
|
|
|
|
|
|
| 92 |
│ ├── dashboard/
|
| 93 |
│ │ └── app.py # Streamlit-based visualization UI
|
| 94 |
│ ├── data/
|
|
|
|
| 106 |
│ │ └── universality.py # Cross-architecture feature mapping
|
| 107 |
│ ├── models/
|
| 108 |
│ │ └── hooked_dt.py # TransformerLens-wrapped Decision Transformer
|
| 109 |
+
│ ├── config.py # Centralized hyperparameter management
|
| 110 |
│ └── utils/
|
| 111 |
├── tests/ # Unit tests for all modules
|
| 112 |
+
├── config.yaml # External hyperparameter storage
|
| 113 |
+
├── requirements.txt
|
| 114 |
+
└── docs/
|
| 115 |
```
|
| 116 |
|
| 117 |
+
---
|
| 118 |
|
| 119 |
+
## Configuration
|
| 120 |
|
| 121 |
+
Hyperparameters are managed through a dual-system for both ease of use and research reproducibility:
|
| 122 |
|
| 123 |
+
1. **`config.yaml`**: The primary interface for users. You can modify model dimensions, training epochs, and environment settings here without touching the code.
|
| 124 |
+
2. **`src/config.py`**: Defines the underlying structure using Python dataclasses. It automatically loads overrides from `config.yaml` at runtime.
|
| 125 |
|
| 126 |
+
### Key Configuration Sections
|
| 127 |
|
| 128 |
+
| Section | Description | Key Parameters |
|
| 129 |
+
| :--- | :--- | :--- |
|
| 130 |
+
| **`model`** | Architecture settings for the Decision Transformer | `n_layers`, `d_model`, `n_heads`, `max_length` |
|
| 131 |
+
| **`data`** | Settings for expert trajectory collection | `env_id`, `num_episodes` (for DT training) |
|
| 132 |
+
| **`train`** | DT training hyperparameters | `lr`, `epochs`, `seed` |
|
| 133 |
+
| **`sae`** | Sparse Autoencoder training hyperparameters | `expansion_factor`, `k`, `num_episodes` (SAE specific) |
|
| 134 |
|
| 135 |
+
**Example: Independent Data Control**
|
| 136 |
+
You can control the amount of data used for general training vs. interpretability separately:
|
| 137 |
+
```yaml
|
| 138 |
+
data:
|
| 139 |
+
num_episodes: 1000 # Episodes for training the DT teacher
|
| 140 |
|
| 141 |
+
sae:
|
| 142 |
+
num_episodes: 500 # Episodes for extracting SAE activations
|
| 143 |
+
```
|
| 144 |
+
|
| 145 |
+
---
|
| 146 |
+
|
| 147 |
+
## Installation and Usage
|
| 148 |
+
|
| 149 |
+
### Setup
|
| 150 |
+
```bash
|
| 151 |
+
python -m venv venv
|
| 152 |
+
source venv/bin/activate
|
| 153 |
+
pip install -r requirements.txt
|
| 154 |
+
```
|
| 155 |
+
|
| 156 |
+
### Dashboard Execution
|
| 157 |
+
Launch the `DT-Explorer` dashboard. The dashboard will initialize with a random model if no trained weights are detected.
|
| 158 |
+
```bash
|
| 159 |
+
streamlit run src/dashboard/app.py
|
| 160 |
+
```
|
| 161 |
|
| 162 |
### Workflow
|
| 163 |
|
|
|
|
| 166 |
python scripts/train_dt.py
|
| 167 |
```
|
| 168 |
|
| 169 |
+
2. **SAE Training**
|
| 170 |
+
```bash
|
| 171 |
+
python scripts/train_sae.py
|
| 172 |
+
```
|
| 173 |
+
|
| 174 |
+
3. **Interpretability Analysis**
|
| 175 |
```bash
|
| 176 |
streamlit run src/dashboard/app.py
|
| 177 |
```
|
| 178 |
+
|
| 179 |
+
### Alternative: Makefile
|
| 180 |
+
Common tasks can also be executed via `make`:
|
| 181 |
+
```bash
|
| 182 |
+
make setup # Install dependencies
|
| 183 |
+
make train # Run full training pipeline (DT + SAE)
|
| 184 |
+
make dashboard # Launch DT-Explorer
|
| 185 |
+
```
|
| 186 |
+
|
| 187 |
+
---
|
| 188 |
+
|
| 189 |
+
## Documentation
|
| 190 |
+
|
| 191 |
+
Detailed technical documentation for specific modules:
|
| 192 |
+
* [Circuit Discovery](./docs/circuit_discovery.md)
|
| 193 |
+
* [Causal Intervention](./docs/activation_patching.md)
|
| 194 |
+
* [SAEs and Steering](./docs/sae_steering.md)
|
| 195 |
+
|
| 196 |
+
---
|
| 197 |
+
|
| 198 |
+
## Citation
|
| 199 |
+
|
| 200 |
+
```bibtex
|
| 201 |
+
@software{dt_circuits2026,
|
| 202 |
+
author = {Sadhumitha S.},
|
| 203 |
+
title = {DT-Circuits: Mechanistic Interpretability for Decision Transformers},
|
| 204 |
+
year = {2026},
|
| 205 |
+
url = {https://github.com/sadhumitha-s/DT-Circuits}
|
| 206 |
+
}
|
| 207 |
+
```
|
| 208 |
+
|
| 209 |
+
---
|
| 210 |
+
|
| 211 |
+
## License
|
| 212 |
+
Apache 2.0
|
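The README's Activation Steering entry describes adding a vector to the residual stream at inference time. A minimal sketch of that arithmetic follows, using plain Python lists in place of tensors. The names `steer`, `direction`, and `alpha` are illustrative assumptions and not part of the repository's API.

```python
from typing import List


def steer(resid: List[float], direction: List[float], alpha: float) -> List[float]:
    """Return resid + alpha * direction, element by element."""
    if len(resid) != len(direction):
        raise ValueError("resid and direction must have the same length")
    return [h + alpha * d for h, d in zip(resid, direction)]


# alpha scales how strongly the steering direction is applied; 0 is a no-op.
steered = steer([1.0, 2.0], [0.5, -1.0], alpha=2.0)
unchanged = steer([1.0, 2.0], [0.5, -1.0], alpha=0.0)
```

In a hooked model the same addition would presumably run inside a forward hook on a residual-stream activation. Which hook points the project uses for steering is not shown in this commit.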
config.yaml
CHANGED
@@ -16,3 +16,4 @@ interpretability:
 sae:
   expansion_factor: 8
   l1_coeff: 0.0005
+  num_episodes: 100
requirements.txt
CHANGED
@@ -14,3 +14,4 @@ pytest
 stable-baselines3
 shimmy
 seaborn
+torchvision
scripts/train_dt.py
CHANGED
@@ -14,39 +14,43 @@ if root_path not in sys.path:
 
 from src.models.hooked_dt import HookedDT
 from src.data.harvester import PPOHarvester
+from src.config import cfg
 
 def train():
     """Main training loop for Decision Transformer."""
     # Step 1: Collect data from expert PPO teacher
-    harvester = PPOHarvester(model_path="models/ppo_teacher.zip")
-    trajectories = harvester.collect_trajectories(num_episodes=
+    harvester = PPOHarvester(env_id=cfg.data.env_id, model_path="models/ppo_teacher.zip")
+    trajectories = harvester.collect_trajectories(num_episodes=cfg.data.num_episodes)
 
     # Save trajectories for the dashboard to use later
     harvester.save_trajectories(trajectories, "data/trajectories.pt")
 
     state_dim = trajectories[0]["observations"].shape[1]
-    action_dim =
+    action_dim = cfg.model.action_dim
 
     model = HookedDT.from_config(
         state_dim=state_dim,
         action_dim=action_dim,
-        n_layers=
-        n_heads=
-        d_model=
+        n_layers=cfg.model.n_layers,
+        n_heads=cfg.model.n_heads,
+        d_model=cfg.model.d_model,
+        max_length=cfg.model.max_length
     )
 
-    optimizer = optim.AdamW(model.parameters(), lr=
+    optimizer = optim.AdamW(model.parameters(), lr=cfg.train.lr)
     criterion = nn.CrossEntropyLoss()
 
     # Step 2: Train the DT
     model.train()
-    for epoch in range(
+    for epoch in range(cfg.train.epochs):
         total_loss = 0
         for traj in tqdm(trajectories, desc=f"Epoch {epoch}"):
-
-
+            # Truncate to match model max_length
+            max_len = model.max_length
+            states = torch.from_numpy(traj["observations"]).float().unsqueeze(0)[:, -max_len:]
+            actions = torch.from_numpy(traj["actions"]).long()[-max_len:]
             actions_one_hot = torch.nn.functional.one_hot(actions, num_classes=action_dim).float().unsqueeze(0)
-            returns = torch.from_numpy(traj["rewards"]).float().unsqueeze(0).unsqueeze(-1)
+            returns = torch.from_numpy(traj["rewards"]).float().unsqueeze(0).unsqueeze(-1)[:, -max_len:]
 
             # Predict actions based on State tokens
             action_preds = model(states, actions_one_hot, returns)
|
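The training loop passes `traj["rewards"]` to the model as its `returns` input, while the README describes the third token type as Reward-to-Go. Whether `PPOHarvester` already stores returns-to-go under the `rewards` key is not visible in this diff. If it stores per-step rewards, a conversion along the following lines would be needed before building the `returns` tensor. The function name `returns_to_go` and the `gamma` parameter are assumptions made for illustration.

```python
from typing import List


def returns_to_go(rewards: List[float], gamma: float = 1.0) -> List[float]:
    """Return R_t = sum over k >= t of gamma**(k - t) * rewards[k] for every t."""
    rtg = [0.0] * len(rewards)
    running = 0.0
    # Walk backwards so each step reuses the sum already computed for t + 1.
    for t in range(len(rewards) - 1, -1, -1):
        running = rewards[t] + gamma * running
        rtg[t] = running
    return rtg


# A sparse-reward episode: every step sees the same undiscounted return.
sparse = returns_to_go([0.0, 0.0, 1.0])
```

With `gamma=1.0` this is a reverse cumulative sum, which matches the undiscounted return conditioning commonly used with Decision Transformers.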
scripts/train_sae.py
CHANGED
@@ -1,34 +1,112 @@
+import sys
+from pathlib import Path
 import torch
-from sae_lens import
+from sae_lens import TopKSAEConfig, TopKSAE
+
+# Add project root to path for absolute imports
+root_path = str(Path(__file__).resolve().parent.parent)
+if root_path not in sys.path:
+    sys.path.append(root_path)
+
+import random
+import numpy as np
 from src.models.hooked_dt import HookedDT
+from src.interpretability.sae_manager import SAEManager
+from src.config import cfg
+
+def set_seed(seed: int = 42):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(seed)
+        torch.backends.cudnn.deterministic = True
+        torch.backends.cudnn.benchmark = False
 
 def train_sae():
-    #
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    # 0. Set seed for reproducibility
+    set_seed(cfg.train.seed)
+
+    # 1. Load Trajectories to get dimensions
+    traj_path = "data/trajectories.pt"
+    if not Path(traj_path).exists():
+        print(f"Error: {traj_path} not found. Please run scripts/train_dt.py first.")
+        return
+
+    trajectories = torch.load(traj_path, weights_only=False)
+    print(f"Loaded {len(trajectories)} trajectories.")
+
+    # 2. Initialize Model
+    state_dim = trajectories[0]["observations"].shape[1]
+    action_dim = cfg.model.action_dim
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    model = HookedDT.from_config(
+        state_dim=state_dim,
+        action_dim=action_dim,
+        n_layers=cfg.model.n_layers,
+        n_heads=cfg.model.n_heads,
+        d_model=cfg.model.d_model
     )
+    model.to(device)
+
+    # Check for trained DT checkpoint
+    checkpoint_path = "models/mini_dt.pt"
+    if Path(checkpoint_path).exists():
+        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
+        print(f"Loaded DT weights from {checkpoint_path}")
+    else:
+        print(f"Warning: {checkpoint_path} not found. Training SAE on random weights.")
 
-
+    # 3. & 4. Train SAEs for ALL layers
+    manager = SAEManager(model, sae_dir="artifacts/saes")
 
-
-
-
+    for layer in range(model.cfg.n_layers):
+        hook_point = f"blocks.{layer}.hook_resid_post"
+        all_activations = []
+
+        print(f"\n--- Processing Layer {layer} ({hook_point}) ---")
+
+        # Extract Activations
+        model.eval()
+        print(f"Extracting activations...")
+
+        # Number of trajectories from config
+        num_trajs_to_use = min(len(trajectories), cfg.sae.num_episodes)
+
+        with torch.no_grad():
+            for traj in trajectories[:num_trajs_to_use]:
+                states = torch.from_numpy(traj["observations"]).float().to(device).unsqueeze(0)
+                actions = torch.from_numpy(traj["actions"]).long().to(device)
+                actions_one_hot = torch.nn.functional.one_hot(actions, num_classes=action_dim).float().unsqueeze(0)
+                returns = torch.from_numpy(traj["rewards"]).float().to(device).unsqueeze(0).unsqueeze(-1)
+
+                _, cache = model(states, actions_one_hot, returns, return_cache=True)
+                all_activations.append(cache[hook_point].squeeze(0).cpu())
+
+        activations = torch.cat(all_activations, dim=0)
+        print(f"Collected {activations.shape[0]} activation vectors.")
+
+        # Setup and Train
+        print(f"Starting TopK SAE training...")
+        manager.setup_sae(
+            hook_point=hook_point,
+            d_model=cfg.model.d_model,
+            architecture="topk",
+            k=cfg.sae.k
+        )
+
+        manager.train_on_trajectories(
+            hook_point=hook_point,
+            activations=activations,
+            epochs=cfg.sae.epochs,
+            batch_size=cfg.sae.batch_size
+        )
+
+    # Save all SAEs once training is complete for all layers
+    manager.save_all_saes()
+
+    print(f"\nSAE Training Complete for all {model.cfg.n_layers} layers. Results saved to artifacts/saes/")
 
 if __name__ == "__main__":
     train_sae()
|
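The script requests `architecture="topk"` with `k=cfg.sae.k`. In a TopK SAE the sparsity constraint is enforced by the activation function itself: only the `k` largest latent pre-activations survive and the rest are set to zero. The sketch below shows that step on a plain Python list. It is an illustration of the general technique, not the `sae_lens` or `SAEManager` implementation, and `topk_activation` is a hypothetical name.

```python
from typing import List


def topk_activation(pre_acts: List[float], k: int) -> List[float]:
    """Keep the k largest pre-activations and zero out all others."""
    if k <= 0:
        return [0.0] * len(pre_acts)
    # Rank indices by value, largest first; ties go to the lower index.
    ranked = sorted(range(len(pre_acts)), key=lambda i: (-pre_acts[i], i))
    keep = set(ranked[:k])
    return [v if i in keep else 0.0 for i, v in enumerate(pre_acts)]


# With k=2, only the two strongest latents stay active.
sparse_latents = topk_activation([0.1, 2.0, -1.0, 0.7], k=2)
```

With the defaults in `src/config.py` (`d_model=128`, `expansion_factor=8`, `k=32`), and assuming the dictionary width is `expansion_factor * d_model`, this would leave 32 of 1024 latents active per token.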
src/config.py
ADDED
@@ -0,0 +1,73 @@
+from dataclasses import dataclass, field
+from typing import Any, Dict, Optional
+import yaml
+from pathlib import Path
+
+@dataclass
+class ModelConfig:
+    n_layers: int = 2
+    n_heads: int = 4
+    d_model: int = 128
+    max_length: int = 30
+    state_dim: Optional[int] = None
+    action_dim: int = 7
+
+@dataclass
+class DataConfig:
+    env_id: str = "MiniGrid-Empty-8x8-v0"
+    num_episodes: int = 1000
+    collection_method: str = "PPO-Teacher"
+
+@dataclass
+class TrainConfig:
+    lr: float = 1e-4
+    epochs: int = 10
+    seed: int = 42
+
+@dataclass
+class SAEConfig:
+    expansion_factor: int = 8
+    k: int = 32
+    l1_coeff: float = 0.0005
+    lr: float = 3e-4
+    epochs: int = 5
+    batch_size: int = 1024
+    num_episodes: int = 100
+
+@dataclass
+class Config:
+    model: ModelConfig = field(default_factory=ModelConfig)
+    data: DataConfig = field(default_factory=DataConfig)
+    train: TrainConfig = field(default_factory=TrainConfig)
+    sae: SAEConfig = field(default_factory=SAEConfig)
+
+    @classmethod
+    def load_from_yaml(cls, yaml_path: str = "config.yaml") -> "Config":
+        """Loads configuration from a YAML file, overriding defaults."""
+        path = Path(yaml_path)
+        if not path.exists():
+            return cls()
+
+        with open(path, "r") as f:
+            data = yaml.safe_load(f)
+
+        # Helper to safely update dataclass from dict
+        def update_dataclass(dc_obj, dc_dict):
+            for key, value in dc_dict.items():
+                if hasattr(dc_obj, key):
+                    setattr(dc_obj, key, value)
+
+        config = cls()
+        if "model" in data:
+            update_dataclass(config.model, data["model"])
+        if "data" in data:
+            update_dataclass(config.data, data["data"])
+        if "train" in data:
+            update_dataclass(config.train, data["train"])
+        if "sae" in data:
+            update_dataclass(config.sae, data["sae"])
+
+        return config
+
+# Global config instance for easy access
+cfg = Config.load_from_yaml()
|
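The override logic in `load_from_yaml` copies a YAML key onto a dataclass only when an attribute of that name already exists, so misspelled keys and sections nested under another key are skipped without any warning. The sketch below reproduces that behavior on an already-parsed mapping, so it needs no YAML dependency. `SAESection`, `DemoConfig`, `apply_overrides`, and `load` are hypothetical stand-ins for the classes in `src/config.py`.

```python
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class SAESection:
    expansion_factor: int = 8
    k: int = 32
    num_episodes: int = 100


@dataclass
class DemoConfig:
    sae: SAESection = field(default_factory=SAESection)


def apply_overrides(section: Any, overrides: Dict[str, Any]) -> None:
    """Copy known keys onto the dataclass; unknown keys are skipped silently."""
    for key, value in overrides.items():
        if hasattr(section, key):
            setattr(section, key, value)


def load(parsed_yaml: Dict[str, Any]) -> DemoConfig:
    """Build a config from a mapping shaped like the parsed config.yaml."""
    config = DemoConfig()
    if "sae" in parsed_yaml:
        apply_overrides(config.sae, parsed_yaml["sae"])
    return config


# The correctly spelled key is applied; the misspelled one is dropped.
cfg_demo = load({"sae": {"num_episodes": 500, "num_epsiodes": 7}})
```

One related edge case: `yaml.safe_load` returns `None` for an empty file, so the `"model" in data` checks in `load_from_yaml` would raise a `TypeError` on an empty `config.yaml`. A guard such as `data = yaml.safe_load(f) or {}` is one way to handle that.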
src/dashboard/app.py
CHANGED
@@ -70,25 +70,25 @@ with tab1:
     st.header("Direct Logit Attribution (DLA)")
     st.write("Visualizing which heads contribute most to the predicted action.")
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    # Run automatically for better UX when changing trajectories
+    states = torch.from_numpy(traj["observations"]).float().unsqueeze(0)
+    actions = torch.nn.functional.one_hot(torch.from_numpy(traj["actions"]).long(), num_classes=7).float().unsqueeze(0)
+    returns = torch.from_numpy(traj["rewards"]).float().unsqueeze(0).unsqueeze(-1)
+
+    preds, cache = model(states, actions, returns, return_cache=True)
+    target_action = preds[0, -1].argmax().item()
+
+    engine = LogitAttributionEngine(model)
+    # Use token index -2 to target the state token which predicts the action
+    dla_results = engine.calculate_dla(cache, target_logit_index=target_action, token_index=-2)
+
+    fig, ax = plt.subplots()
+    im = ax.imshow(dla_results.detach().cpu().numpy(), cmap="RdBu_r", aspect='auto')
+    plt.colorbar(im)
+    ax.set_xlabel("Head")
+    ax.set_ylabel("Layer")
+    st.pyplot(fig)
+    st.write(f"Analyzing Attribution for Action: {target_action} (at State token)")
 
 with tab2:
     st.header("Activation Patching")
|
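The dashboard heatmap plots one attribution score per (layer, head). Under the usual definition of direct logit attribution, each score is the dot product between a head's output at the chosen token and the unembedding direction of the target action. The sketch below assumes that definition and ignores any final normalization layer. `LogitAttributionEngine.calculate_dla` is not part of this diff, so `direct_logit_attribution` is a hypothetical stand-in rather than the project's implementation.

```python
from typing import List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    """Plain dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def direct_logit_attribution(
    head_outputs: List[List[Vector]], unembed_direction: Vector
) -> List[List[float]]:
    """Score each head by projecting its output onto the target logit direction.

    head_outputs[layer][head] is that head's contribution to the residual
    stream at one token position; the result has shape (n_layers, n_heads).
    """
    return [[dot(h, unembed_direction) for h in layer] for layer in head_outputs]


# Two layers with two heads each, in a toy 2-dimensional residual stream.
outputs = [[[1.0, 0.0], [0.0, 2.0]], [[1.0, 1.0], [2.0, -1.0]]]
scores = direct_logit_attribution(outputs, [3.0, -1.0])
```

Because the projection is linear, the per-head scores sum to the logit contribution of the summed head outputs, which is what makes the heatmap an exact decomposition under these assumptions.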
src/models/hooked_dt.py
CHANGED
@@ -85,11 +85,11 @@ class HookedDT(nn.Module):
         return action_preds
 
     @classmethod
-    def from_config(cls, state_dim, action_dim, n_layers=2, n_heads=4, d_model=128):
+    def from_config(cls, state_dim, action_dim, n_layers=2, n_heads=4, d_model=128, max_length=30):
         cfg = HookedTransformerConfig(
             n_layers=n_layers,
             d_model=d_model,
-            n_ctx=
+            n_ctx=3 * max_length,
             d_head=d_model // n_heads,
             n_heads=n_heads,
             d_vocab=10, # Dummy vocab size
@@ -99,5 +99,5 @@ class HookedDT(nn.Module):
             use_attn_result=True,
             device="cuda" if torch.cuda.is_available() else "cpu"
         )
-        return cls(cfg, state_dim, action_dim)
+        return cls(cfg, state_dim, action_dim, max_length=max_length)
 
|
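Setting `n_ctx=3 * max_length` reflects that each timestep occupies three token positions, and the dashboard's use of `token_index=-2` for "the state token which predicts the action" is consistent with an interleaving of return-to-go, state, action per step. The sketch below makes that index arithmetic explicit. The exact ordering inside `HookedDT.forward` is not shown in this diff, so the `TOKEN_KINDS` order is an assumption, and `context_length` and `token_kind` are hypothetical helpers.

```python
from typing import Tuple

TOKENS_PER_STEP = 3
# Assumed per-timestep ordering; chosen so that index -2 is a state token.
TOKEN_KINDS = ("rtg", "state", "action")


def context_length(max_length: int) -> int:
    """Number of token positions needed for max_length timesteps."""
    return TOKENS_PER_STEP * max_length


def token_kind(index: int, n_tokens: int) -> Tuple[int, str]:
    """Resolve a (possibly negative) token index to (timestep, kind)."""
    if index < 0:
        index += n_tokens
    if not 0 <= index < n_tokens:
        raise IndexError("token index out of range")
    return index // TOKENS_PER_STEP, TOKEN_KINDS[index % TOKENS_PER_STEP]


# A 5-step trajectory has 15 tokens; index -2 is the last step's state token.
last_state = token_kind(-2, context_length(5))
```

Under this layout the default `max_length=30` yields a 90-position context, and trajectories longer than `max_length` steps must be truncated before the forward pass, as `scripts/train_dt.py` does.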
tests/test_mechanistic.py
ADDED
@@ -0,0 +1,63 @@
+import torch
+import numpy as np
+import pytest
+from src.models.hooked_dt import HookedDT
+from src.interpretability.attribution import LogitAttributionEngine
+from src.interpretability.patching import ActivationPatcher
+
+@pytest.fixture
+def model():
+    return HookedDT.from_config(state_dim=10, action_dim=7, n_layers=1, n_heads=2, d_model=32)
+
+@pytest.fixture
+def mock_data():
+    batch_size = 1
+    seq_len = 5
+    state_dim = 10
+    action_dim = 7
+
+    states = torch.randn(batch_size, seq_len, state_dim)
+    actions = torch.nn.functional.one_hot(torch.randint(0, action_dim, (batch_size, seq_len)), num_classes=action_dim).float()
+    returns = torch.randn(batch_size, seq_len, 1)
+
+    return {"states": states, "actions": actions, "returns_to_go": returns}
+
+def test_logit_attribution(model, mock_data):
+    engine = LogitAttributionEngine(model)
+    preds, cache = model(**mock_data, return_cache=True)
+    target_action = preds[0, -1].argmax().item()
+
+    dla = engine.calculate_dla(cache, target_logit_index=target_action, token_index=-2)
+
+    assert dla.shape == (model.cfg.n_layers, model.cfg.n_heads)
+    assert not torch.isnan(dla).any()
+
+def test_activation_patching(model, mock_data):
+    patcher = ActivationPatcher(model)
+
+    # Clean run
+    clean_preds, clean_cache = model(**mock_data, return_cache=True)
+    clean_probs = torch.softmax(clean_preds, dim=-1)
+    target_action = clean_preds[0, -1].argmax().item()
+
+    # Create corrupted run (zeroed states)
+    corrupted_data = mock_data.copy()
+    corrupted_data["states"] = torch.zeros_like(mock_data["states"])
+    _, corrupted_cache = model(**corrupted_data, return_cache=True)
+
+    # Patch head 0 of layer 0
+    patched_logits = patcher.patch_head(
+        mock_data,
+        corrupted_cache,
+        layer=0,
+        head_index=0,
+        target_token_index=-2
+    )
+    patched_probs = torch.softmax(patched_logits, dim=-1)
+
+    drop = patcher.calculate_probability_drop(clean_probs, patched_probs, target_action)
+
+    assert isinstance(drop, float)
+    # Patching with corrupted (zeros) should generally decrease performance/probability
+    # but at minimum we check it returns a valid number.
+    assert not np.isnan(drop)
|
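The patching test only checks that `calculate_probability_drop` returns a finite float. One plausible definition of that metric is the clean probability of the target action minus its probability after patching, so a positive value means the patched component mattered. The sketch below works on a single probability distribution rather than the batched tensors used in the test. Both functions are illustrative assumptions, since `ActivationPatcher` is not included in this diff.

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax over one list of logits."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def probability_drop(
    clean_probs: List[float], patched_probs: List[float], target_action: int
) -> float:
    """Clean minus patched probability of the target action."""
    return clean_probs[target_action] - patched_probs[target_action]


# Patching lowers the target action's probability from 0.7 to 0.4.
drop = probability_drop([0.7, 0.3], [0.4, 0.6], target_action=0)
```

A stricter test than the one in the commit could assert the sign of the drop for a head known to matter, but that requires a trained model rather than the randomly initialized fixture.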