sadhumitha-s commited on
Commit
b7ddfc6
·
1 Parent(s): 14d2c06

refactor: implement centralized configuration, upgrade SAE training to multi-layer TopK, and optimize dashboard attribution UX

Browse files
.gitignore CHANGED
@@ -56,3 +56,5 @@ static/
56
 
57
  /PRD.md
58
  artifacts/
 
 
 
56
 
57
  /PRD.md
58
  artifacts/
59
+ scratch/
60
+ *.log
Makefile ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .PHONY: setup train dashboard test clean
2
+
3
+ # Setup environment
4
+ setup:
5
+ pip install -r requirements.txt
6
+
7
+ # Run full pipeline: Harvesting -> DT Training -> SAE Training
8
+ train:
9
+ python3 scripts/train_dt.py
10
+ python3 scripts/train_sae.py
11
+
12
+ # Launch the explorer dashboard
13
+ dashboard:
14
+ streamlit run src/dashboard/app.py
15
+
16
+ # Run unit tests
17
+ test:
18
+ PYTHONPATH=. pytest tests/
19
+
20
+ # Remove artifacts and cached files
21
+ clean:
22
+ rm -rf data/*.pt models/*.pt artifacts/saes/*.pt
23
+ find . -type d -name "__pycache__" -exec rm -rf {} +
24
+ find . -type d -name ".pytest_cache" -exec rm -rf {} +
README.md CHANGED
@@ -1,66 +1,86 @@
1
  # DT-Circuits: Mechanistic Interpretability for Decision Transformers
2
 
3
- ![Python](https://img.shields.io/badge/python-3.9+-blue)
4
- ![PyTorch](https://img.shields.io/badge/PyTorch-2.x-red)
 
 
5
 
6
  DT-Circuits is a research framework for mechanistic interpretability of Decision Transformers, focused on causal analysis, sparse feature decomposition, and circuit-level understanding of sequential decision-making agents.
7
 
8
  ---
9
 
10
- ## Motivation
11
-
12
- Mechanistic interpretability has primarily focused on language models, while reinforcement learning agents remain comparatively underexplored.
13
-
14
- Decision Transformers provide a uniquely analyzable architecture because trajectories, rewards, and actions are represented in a unified autoregressive sequence.
15
-
16
- DT-Circuits aims to make RL agents inspectable at the circuit level rather than only through behavioral evaluation.
17
-
18
- ---
19
-
20
  ## Table of Contents
21
- - [Features](#features)
22
- - [Technical Architecture](#technical-architecture)
 
23
  - [Project Structure](#project-structure)
24
- - [Getting Started](#getting-started)
 
 
 
25
 
26
- ---
27
-
28
- ## Documentation
29
- - [Circuit Discovery](./docs/circuit_discovery.md)
30
- - [Activation Patching](./docs/activation_patching.md)
31
- - [SAEs & Steering](./docs/sae_steering.md)
32
 
33
- ---
34
 
35
- ## Features
 
 
 
36
 
37
- ### 1. Neural Mapping
38
- - **Hooked-DT**: Access any internal activation or weight during the agent's run.
39
- - **Logit Attribution**: See which attention heads or MLP layers drive specific actions.
40
- - **Induction Scan**: Find heads that recognize temporal patterns and past states.
41
 
42
- ### 2. Testing Causality
43
- - **Activation Patching**: Swap internal states to see what actually changes the agent's move.
44
- - **Behavior Steering**: Add vectors to activations to push the agent toward specific goals without retraining.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
- ### 3. Concept Discovery
47
- - **TopK SAEs**: Decompose complex activations into a few active "concepts" for easier reading.
48
- - **Auto-Labeling (NLA)**: Use an LLM to automatically describe what each discovered neuron feature does.
49
- - **Cross-Model Probes**: Check if different agents (like DQNs) learn the same internal concepts as the DT.
50
 
51
- ### 4. Circuit Analysis
52
- - **ACDC**: Automatically strip the model down to the minimal circuit needed for a task.
53
- - **Path Patching**: Trace how a signal flows from a specific input token to the final action.
54
- - **Evolutionary Scan**: Watch how decision-making circuits form during training.
55
 
56
- ---
 
 
 
57
 
58
- ## Technical Architecture
 
 
 
59
 
60
- - **Data**: Collects expert paths using a PPO harvester.
61
- - **Model**: Custom Decision Transformer compatible with TransformerLens.
62
- - **Tools**: Dedicated modules for attribution, patching, SAEs, and steering.
63
- - **Dashboard**: Streamlit UI for real-time model analysis.
64
 
65
  ---
66
 
@@ -68,10 +88,7 @@ DT-Circuits aims to make RL agents inspectable at the circuit level rather than
68
 
69
  ```text
70
  DT-Circuits/
71
- ├── scripts/
72
- │ ├── train_dt.py # Decision Transformer training pipeline
73
- │ └── train_sae.py # Sparse Autoencoder (SAE) training script
74
- ├── src/
75
  │ ├── dashboard/
76
  │ │ └── app.py # Streamlit-based visualization UI
77
  │ ├── data/
@@ -89,44 +106,58 @@ DT-Circuits/
89
  │ │ └── universality.py # Cross-architecture feature mapping
90
  │ ├── models/
91
  │ │ └── hooked_dt.py # TransformerLens-wrapped Decision Transformer
 
92
  │ └── utils/
93
  ├── tests/ # Unit tests for all modules
94
- ├── config.yaml
95
- ── requirements.txt
 
96
  ```
97
 
98
- ---
99
 
100
- ## Getting Started
101
 
102
- ### Prerequisites
103
- - Python 3.9+
104
- - PyTorch 2.x
105
- - TransformerLens
106
- - SAE-Lens
107
 
108
- ### Quick Start
 
109
 
110
- Follow these steps to initialize the environment and verify the installation.
111
 
112
- 1. **Environment Setup**
113
- ```bash
114
- python -m venv venv
115
- source venv/bin/activate
116
- pip install -r requirements.txt
117
- ```
118
 
119
- 2. **Verification**
120
- Run the component tests to ensure all dependencies and hooks are correctly configured.
121
- ```bash
122
- PYTHONPATH=. pytest tests/test_components.py
123
- ```
124
 
125
- 3. **Dashboard Execution**
126
- Launch the `DT-Explorer` dashboard. The dashboard will initialize with a random model if no trained weights are detected.
127
- ```bash
128
- streamlit run src/dashboard/app.py
129
- ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
  ### Workflow
132
 
@@ -135,7 +166,47 @@ Follow these steps to initialize the environment and verify the installation.
135
  python scripts/train_dt.py
136
  ```
137
 
138
- 2. **Interpretability Analysis**
 
 
 
 
 
139
  ```bash
140
  streamlit run src/dashboard/app.py
141
  ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # DT-Circuits: Mechanistic Interpretability for Decision Transformers
2
 
3
+ [![Python 3.9+](https://img.shields.io/badge/python-3.9+-blue.svg)](https://www.python.org/downloads/)
4
+ [![PyTorch 2.x](https://img.shields.io/badge/PyTorch-2.x-red.svg)](https://pytorch.org/)
5
+ [![License: Apache 2.0](https://img.shields.io/badge/License-Apache%202.0-green.svg)](https://opensource.org/licenses/Apache-2.0)
6
+ [![Framework: TransformerLens](https://img.shields.io/badge/Framework-TransformerLens-orange.svg)](https://github.com/TransformerLensOrg/TransformerLens)
7
 
8
  DT-Circuits is a research framework for mechanistic interpretability of Decision Transformers, focused on causal analysis, sparse feature decomposition, and circuit-level understanding of sequential decision-making agents.
9
 
10
  ---
11
 
 
 
 
 
 
 
 
 
 
 
12
  ## Table of Contents
13
+ - [Core Objectives](#core-objectives)
14
+ - [Technical Overview](#technical-overview)
15
+ - [Capabilities](#capabilities)
16
  - [Project Structure](#project-structure)
17
+ - [Installation and Usage](#installation-and-usage)
18
+ - [Documentation](#documentation)
19
+ - [Citation](#citation)
20
+ - [License](#license)
21
 
22
+ ---
 
 
 
 
 
23
 
24
+ ## Core Objectives
25
 
26
+ 1. **Map Information Flow**: Quantify how input tokens (State, Action, Reward-to-Go) contribute to the output action logits.
27
+ 2. **Causal Verification**: Use intervention techniques to identify the minimal set of model components required for specific behaviors.
28
+ 3. **Feature Decomposition**: Use Sparse Autoencoders (SAEs) to identify monosemantic features within the model's residual stream.
29
+ 4. **Behavioral Control**: Modify agent decisions at inference time by manipulating internal activations.
30
 
31
+ ---
 
 
 
32
 
33
+ ## Technical Overview
34
+
35
+ The framework centers around `HookedDT`, a Decision Transformer implementation that allows for activation hooking and cache management.
36
+
37
+ ### Information Flow Diagram
38
+
39
+ ```mermaid
40
+ graph TD
41
+ subgraph Input_Sequence
42
+ S[State Tokens]
43
+ A[Action Tokens]
44
+ RTG[Reward-to-Go Tokens]
45
+ end
46
+
47
+ Input_Sequence --> Embed[Embedding Layers]
48
+ Embed --> Hooks[Activation Hooks]
49
+
50
+ subgraph Transformer_Block
51
+ Hooks --> Attn[Multi-Head Attention]
52
+ Attn --> MLP[MLP Layers]
53
+ MLP --> Res[Residual Stream]
54
+ end
55
+
56
+ Res --> DLA[Direct Logit Attribution]
57
+ Res --> SAE[Sparse Autoencoder]
58
+ Res --> Output[Action Logits]
59
+
60
+ subgraph Interpretability_Modules
61
+ DLA -.-> Analysis
62
+ SAE -.-> Features
63
+ Intervention[Activation Patching] -.-> Hooks
64
+ end
65
+ ```
66
 
67
+ ---
 
 
 
68
 
69
+ ## Capabilities
 
 
 
70
 
71
+ ### Causal Mediation and Attribution
72
+ * **Direct Logit Attribution (DLA)**: Measures the direct contribution of individual attention heads and MLP layers to the final logit output.
73
+ * **Activation Patching**: Substitutes internal activations from different runs to isolate the causal effect of specific inputs on model behavior.
74
+ * **Path Patching**: Traces how information flows through specific connections between model components.
75
 
76
+ ### Feature Discovery and Analysis
77
+ * **Sparse Autoencoders (SAEs)**: Decomposes the residual stream into a set of sparse features, helping to resolve polysemanticity.
78
+ * **Induction Scanning**: Identifies attention heads that perform pattern-matching and temporal sequence recognition.
79
+ * **Automated Circuit Discovery (ACDC)**: Prunes the model to identify the smallest functional subgraph sufficient to perform a specific task.
80
 
81
+ ### Behavioral Steering
82
+ * **Activation Steering**: Injects specific vectors into the residual stream to bias the agent's decision-making without retraining the weights.
83
+ * **Safety Auditing**: Monitors SAE reconstruction error and feature activation to detect anomalous or out-of-distribution internal states.
 
84
 
85
  ---
86
 
 
88
 
89
  ```text
90
  DT-Circuits/
91
+ ├── src/
 
 
 
92
  │ ├── dashboard/
93
  │ │ └── app.py # Streamlit-based visualization UI
94
  │ ├── data/
 
106
  │ │ └── universality.py # Cross-architecture feature mapping
107
  │ ├── models/
108
  │ │ └── hooked_dt.py # TransformerLens-wrapped Decision Transformer
109
+ │ ├── config.py # Centralized hyperparameter management
110
  │ └── utils/
111
  ├── tests/ # Unit tests for all modules
112
+ ├── config.yaml # External hyperparameter storage
113
+ ── requirements.txt
114
+ └── docs/
115
  ```
116
 
117
+ ---
118
 
119
+ ## Configuration
120
 
121
+ Hyperparameters are managed through a dual-system for both ease of use and research reproducibility:
 
 
 
 
122
 
123
+ 1. **`config.yaml`**: The primary interface for users. You can modify model dimensions, training epochs, and environment settings here without touching the code.
124
+ 2. **`src/config.py`**: Defines the underlying structure using Python dataclasses. It automatically loads overrides from `config.yaml` at runtime.
125
 
126
+ ### Key Configuration Sections
127
 
128
+ | Section | Description | Key Parameters |
129
+ | :--- | :--- | :--- |
130
+ | **`model`** | Architecture settings for the Decision Transformer | `n_layers`, `d_model`, `n_heads`, `max_length` |
131
+ | **`data`** | Settings for expert trajectory collection | `env_id`, `num_episodes` (for DT training) |
132
+ | **`train`** | DT training hyperparameters | `lr`, `epochs`, `seed` |
133
+ | **`sae`** | Sparse Autoencoder training hyperparameters | `expansion_factor`, `k`, `num_episodes` (SAE specific) |
134
 
135
+ **Example: Independent Data Control**
136
+ You can control the amount of data used for general training vs. interpretability separately:
137
+ ```yaml
138
+ data:
139
+ num_episodes: 1000 # Episodes for training the DT teacher
140
 
141
+ sae:
142
+ num_episodes: 500 # Episodes for extracting SAE activations
143
+ ```
144
+
145
+ ---
146
+
147
+ ## Installation and Usage
148
+
149
+ ### Setup
150
+ ```bash
151
+ python -m venv venv
152
+ source venv/bin/activate
153
+ pip install -r requirements.txt
154
+ ```
155
+
156
+ ### Dashboard Execution
157
+ Launch the `DT-Explorer` dashboard. The dashboard will initialize with a random model if no trained weights are detected.
158
+ ```bash
159
+ streamlit run src/dashboard/app.py
160
+ ```
161
 
162
  ### Workflow
163
 
 
166
  python scripts/train_dt.py
167
  ```
168
 
169
+ 2. **SAE Training**
170
+ ```bash
171
+ python scripts/train_sae.py
172
+ ```
173
+
174
+ 3. **Interpretability Analysis**
175
  ```bash
176
  streamlit run src/dashboard/app.py
177
  ```
178
+
179
+ ### Alternative: Makefile
180
+ Common tasks can also be executed via `make`:
181
+ ```bash
182
+ make setup # Install dependencies
183
+ make train # Run full training pipeline (DT + SAE)
184
+ make dashboard # Launch DT-Explorer
185
+ ```
186
+
187
+ ---
188
+
189
+ ## Documentation
190
+
191
+ Detailed technical documentation for specific modules:
192
+ * [Circuit Discovery](./docs/circuit_discovery.md)
193
+ * [Causal Intervention](./docs/activation_patching.md)
194
+ * [SAEs and Steering](./docs/sae_steering.md)
195
+
196
+ ---
197
+
198
+ ## Citation
199
+
200
+ ```bibtex
201
+ @software{dt_circuits2026,
202
+ author = {Sadhumitha S.},
203
+ title = {DT-Circuits: Mechanistic Interpretability for Decision Transformers},
204
+ year = {2026},
205
+ url = {https://github.com/sadhumitha-s/DT-Circuits}
206
+ }
207
+ ```
208
+
209
+ ---
210
+
211
+ ## License
212
+ Apache 2.0
config.yaml CHANGED
@@ -16,3 +16,4 @@ interpretability:
16
  sae:
17
  expansion_factor: 8
18
  l1_coeff: 0.0005
 
 
16
  sae:
17
  expansion_factor: 8
18
  l1_coeff: 0.0005
19
+ num_episodes: 100
requirements.txt CHANGED
@@ -14,3 +14,4 @@ pytest
14
  stable-baselines3
15
  shimmy
16
  seaborn
 
 
14
  stable-baselines3
15
  shimmy
16
  seaborn
17
+ torchvision
scripts/train_dt.py CHANGED
@@ -14,39 +14,43 @@ if root_path not in sys.path:
14
 
15
  from src.models.hooked_dt import HookedDT
16
  from src.data.harvester import PPOHarvester
 
17
 
18
  def train():
19
  """Main training loop for Decision Transformer."""
20
  # Step 1: Collect data from expert PPO teacher
21
- harvester = PPOHarvester(model_path="models/ppo_teacher.zip")
22
- trajectories = harvester.collect_trajectories(num_episodes=100)
23
 
24
  # Save trajectories for the dashboard to use later
25
  harvester.save_trajectories(trajectories, "data/trajectories.pt")
26
 
27
  state_dim = trajectories[0]["observations"].shape[1]
28
- action_dim = 7 # MiniGrid standard actions
29
 
30
  model = HookedDT.from_config(
31
  state_dim=state_dim,
32
  action_dim=action_dim,
33
- n_layers=2,
34
- n_heads=4,
35
- d_model=128
 
36
  )
37
 
38
- optimizer = optim.AdamW(model.parameters(), lr=1e-4)
39
  criterion = nn.CrossEntropyLoss()
40
 
41
  # Step 2: Train the DT
42
  model.train()
43
- for epoch in range(10):
44
  total_loss = 0
45
  for traj in tqdm(trajectories, desc=f"Epoch {epoch}"):
46
- states = torch.from_numpy(traj["observations"]).float().unsqueeze(0)
47
- actions = torch.from_numpy(traj["actions"]).long()
 
 
48
  actions_one_hot = torch.nn.functional.one_hot(actions, num_classes=action_dim).float().unsqueeze(0)
49
- returns = torch.from_numpy(traj["rewards"]).float().unsqueeze(0).unsqueeze(-1)
50
 
51
  # Predict actions based on State tokens
52
  action_preds = model(states, actions_one_hot, returns)
 
14
 
15
  from src.models.hooked_dt import HookedDT
16
  from src.data.harvester import PPOHarvester
17
+ from src.config import cfg
18
 
19
  def train():
20
  """Main training loop for Decision Transformer."""
21
  # Step 1: Collect data from expert PPO teacher
22
+ harvester = PPOHarvester(env_id=cfg.data.env_id, model_path="models/ppo_teacher.zip")
23
+ trajectories = harvester.collect_trajectories(num_episodes=cfg.data.num_episodes)
24
 
25
  # Save trajectories for the dashboard to use later
26
  harvester.save_trajectories(trajectories, "data/trajectories.pt")
27
 
28
  state_dim = trajectories[0]["observations"].shape[1]
29
+ action_dim = cfg.model.action_dim
30
 
31
  model = HookedDT.from_config(
32
  state_dim=state_dim,
33
  action_dim=action_dim,
34
+ n_layers=cfg.model.n_layers,
35
+ n_heads=cfg.model.n_heads,
36
+ d_model=cfg.model.d_model,
37
+ max_length=cfg.model.max_length
38
  )
39
 
40
+ optimizer = optim.AdamW(model.parameters(), lr=cfg.train.lr)
41
  criterion = nn.CrossEntropyLoss()
42
 
43
  # Step 2: Train the DT
44
  model.train()
45
+ for epoch in range(cfg.train.epochs):
46
  total_loss = 0
47
  for traj in tqdm(trajectories, desc=f"Epoch {epoch}"):
48
+ # Truncate to match model max_length
49
+ max_len = model.max_length
50
+ states = torch.from_numpy(traj["observations"]).float().unsqueeze(0)[:, -max_len:]
51
+ actions = torch.from_numpy(traj["actions"]).long()[-max_len:]
52
  actions_one_hot = torch.nn.functional.one_hot(actions, num_classes=action_dim).float().unsqueeze(0)
53
+ returns = torch.from_numpy(traj["rewards"]).float().unsqueeze(0).unsqueeze(-1)[:, -max_len:]
54
 
55
  # Predict actions based on State tokens
56
  action_preds = model(states, actions_one_hot, returns)
scripts/train_sae.py CHANGED
@@ -1,34 +1,112 @@
 
 
1
  import torch
2
- from sae_lens import SAEConfig, SAE
 
 
 
 
 
 
 
 
3
  from src.models.hooked_dt import HookedDT
 
 
 
 
 
 
 
 
 
 
 
4
 
5
  def train_sae():
6
- # Load DT
7
- state_dim = 2739
8
- action_dim = 7
9
- model = HookedDT.from_config(state_dim, action_dim)
10
- # model.load_state_dict(torch.load("models/mini_dt.pt"))
11
-
12
- # Configure SAE
13
- cfg = SAEConfig(
14
- d_in=128, # d_model
15
- d_sae=128 * 8, # Expansion factor
16
- hook_point="blocks.0.hook_resid_post",
17
- hook_point_layer=0,
18
- architecture="standard",
19
- activation_fn="relu",
20
- expansion_factor=8,
21
- l1_coefficient=5e-4,
22
- lr=3e-4,
23
- train_batch_size=4096,
24
- context_size=30, # Sequence length
 
 
 
 
25
  )
 
 
 
 
 
 
 
 
 
26
 
27
- sae = SAE(cfg)
 
28
 
29
- # Training logic would go here, using activations from the DT
30
- print("SAE Configured for DT-Explorer.")
31
- print(f"Hooking into: {cfg.hook_point}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
  if __name__ == "__main__":
34
  train_sae()
 
1
+ import sys
2
+ from pathlib import Path
3
  import torch
4
+ from sae_lens import TopKSAEConfig, TopKSAE
5
+
6
+ # Add project root to path for absolute imports
7
+ root_path = str(Path(__file__).resolve().parent.parent)
8
+ if root_path not in sys.path:
9
+ sys.path.append(root_path)
10
+
11
+ import random
12
+ import numpy as np
13
  from src.models.hooked_dt import HookedDT
14
+ from src.interpretability.sae_manager import SAEManager
15
+ from src.config import cfg
16
+
17
+ def set_seed(seed: int = 42):
18
+ random.seed(seed)
19
+ np.random.seed(seed)
20
+ torch.manual_seed(seed)
21
+ if torch.cuda.is_available():
22
+ torch.cuda.manual_seed_all(seed)
23
+ torch.backends.cudnn.deterministic = True
24
+ torch.backends.cudnn.benchmark = False
25
 
26
  def train_sae():
27
+ # 0. Set seed for reproducibility
28
+ set_seed(cfg.train.seed)
29
+
30
+ # 1. Load Trajectories to get dimensions
31
+ traj_path = "data/trajectories.pt"
32
+ if not Path(traj_path).exists():
33
+ print(f"Error: {traj_path} not found. Please run scripts/train_dt.py first.")
34
+ return
35
+
36
+ trajectories = torch.load(traj_path, weights_only=False)
37
+ print(f"Loaded {len(trajectories)} trajectories.")
38
+
39
+ # 2. Initialize Model
40
+ state_dim = trajectories[0]["observations"].shape[1]
41
+ action_dim = cfg.model.action_dim
42
+ device = "cuda" if torch.cuda.is_available() else "cpu"
43
+
44
+ model = HookedDT.from_config(
45
+ state_dim=state_dim,
46
+ action_dim=action_dim,
47
+ n_layers=cfg.model.n_layers,
48
+ n_heads=cfg.model.n_heads,
49
+ d_model=cfg.model.d_model
50
  )
51
+ model.to(device)
52
+
53
+ # Check for trained DT checkpoint
54
+ checkpoint_path = "models/mini_dt.pt"
55
+ if Path(checkpoint_path).exists():
56
+ model.load_state_dict(torch.load(checkpoint_path, map_location=device))
57
+ print(f"Loaded DT weights from {checkpoint_path}")
58
+ else:
59
+ print(f"Warning: {checkpoint_path} not found. Training SAE on random weights.")
60
 
61
+ # 3. & 4. Train SAEs for ALL layers
62
+ manager = SAEManager(model, sae_dir="artifacts/saes")
63
 
64
+ for layer in range(model.cfg.n_layers):
65
+ hook_point = f"blocks.{layer}.hook_resid_post"
66
+ all_activations = []
67
+
68
+ print(f"\n--- Processing Layer {layer} ({hook_point}) ---")
69
+
70
+ # Extract Activations
71
+ model.eval()
72
+ print(f"Extracting activations...")
73
+
74
+ # Number of trajectories from config
75
+ num_trajs_to_use = min(len(trajectories), cfg.sae.num_episodes)
76
+
77
+ with torch.no_grad():
78
+ for traj in trajectories[:num_trajs_to_use]:
79
+ states = torch.from_numpy(traj["observations"]).float().to(device).unsqueeze(0)
80
+ actions = torch.from_numpy(traj["actions"]).long().to(device)
81
+ actions_one_hot = torch.nn.functional.one_hot(actions, num_classes=action_dim).float().unsqueeze(0)
82
+ returns = torch.from_numpy(traj["rewards"]).float().to(device).unsqueeze(0).unsqueeze(-1)
83
+
84
+ _, cache = model(states, actions_one_hot, returns, return_cache=True)
85
+ all_activations.append(cache[hook_point].squeeze(0).cpu())
86
+
87
+ activations = torch.cat(all_activations, dim=0)
88
+ print(f"Collected {activations.shape[0]} activation vectors.")
89
+
90
+ # Setup and Train
91
+ print(f"Starting TopK SAE training...")
92
+ manager.setup_sae(
93
+ hook_point=hook_point,
94
+ d_model=cfg.model.d_model,
95
+ architecture="topk",
96
+ k=cfg.sae.k
97
+ )
98
+
99
+ manager.train_on_trajectories(
100
+ hook_point=hook_point,
101
+ activations=activations,
102
+ epochs=cfg.sae.epochs,
103
+ batch_size=cfg.sae.batch_size
104
+ )
105
+
106
+ # Save all SAEs once training is complete for all layers
107
+ manager.save_all_saes()
108
+
109
+ print(f"\nSAE Training Complete for all {model.cfg.n_layers} layers. Results saved to artifacts/saes/")
110
 
111
  if __name__ == "__main__":
112
  train_sae()
src/config.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from typing import Any, Dict, Optional
3
+ import yaml
4
+ from pathlib import Path
5
+
6
+ @dataclass
7
+ class ModelConfig:
8
+ n_layers: int = 2
9
+ n_heads: int = 4
10
+ d_model: int = 128
11
+ max_length: int = 30
12
+ state_dim: Optional[int] = None
13
+ action_dim: int = 7
14
+
15
+ @dataclass
16
+ class DataConfig:
17
+ env_id: str = "MiniGrid-Empty-8x8-v0"
18
+ num_episodes: int = 1000
19
+ collection_method: str = "PPO-Teacher"
20
+
21
+ @dataclass
22
+ class TrainConfig:
23
+ lr: float = 1e-4
24
+ epochs: int = 10
25
+ seed: int = 42
26
+
27
+ @dataclass
28
+ class SAEConfig:
29
+ expansion_factor: int = 8
30
+ k: int = 32
31
+ l1_coeff: float = 0.0005
32
+ lr: float = 3e-4
33
+ epochs: int = 5
34
+ batch_size: int = 1024
35
+ num_episodes: int = 100
36
+
37
+ @dataclass
38
+ class Config:
39
+ model: ModelConfig = field(default_factory=ModelConfig)
40
+ data: DataConfig = field(default_factory=DataConfig)
41
+ train: TrainConfig = field(default_factory=TrainConfig)
42
+ sae: SAEConfig = field(default_factory=SAEConfig)
43
+
44
+ @classmethod
45
+ def load_from_yaml(cls, yaml_path: str = "config.yaml") -> "Config":
46
+ """Loads configuration from a YAML file, overriding defaults."""
47
+ path = Path(yaml_path)
48
+ if not path.exists():
49
+ return cls()
50
+
51
+ with open(path, "r") as f:
52
+ data = yaml.safe_load(f)
53
+
54
+ # Helper to safely update dataclass from dict
55
+ def update_dataclass(dc_obj, dc_dict):
56
+ for key, value in dc_dict.items():
57
+ if hasattr(dc_obj, key):
58
+ setattr(dc_obj, key, value)
59
+
60
+ config = cls()
61
+ if "model" in data:
62
+ update_dataclass(config.model, data["model"])
63
+ if "data" in data:
64
+ update_dataclass(config.data, data["data"])
65
+ if "train" in data:
66
+ update_dataclass(config.train, data["train"])
67
+ if "sae" in data:
68
+ update_dataclass(config.sae, data["sae"])
69
+
70
+ return config
71
+
72
+ # Global config instance for easy access
73
+ cfg = Config.load_from_yaml()
src/dashboard/app.py CHANGED
@@ -70,25 +70,25 @@ with tab1:
70
  st.header("Direct Logit Attribution (DLA)")
71
  st.write("Visualizing which heads contribute most to the predicted action.")
72
 
73
- if st.button("Run Attribution"):
74
- states = torch.from_numpy(traj["observations"]).float().unsqueeze(0)
75
- actions = torch.nn.functional.one_hot(torch.from_numpy(traj["actions"]).long(), num_classes=7).float().unsqueeze(0)
76
- returns = torch.from_numpy(traj["rewards"]).float().unsqueeze(0).unsqueeze(-1)
77
-
78
- preds, cache = model(states, actions, returns, return_cache=True)
79
- target_action = preds[0, -1].argmax().item()
80
-
81
- engine = LogitAttributionEngine(model)
82
- # Use token index -2 to target the state token which predicts the action
83
- dla_results = engine.calculate_dla(cache, target_logit_index=target_action, token_index=-2)
84
-
85
- fig, ax = plt.subplots()
86
- im = ax.imshow(dla_results.detach().cpu().numpy(), cmap="RdBu_r", aspect='auto')
87
- plt.colorbar(im)
88
- ax.set_xlabel("Head")
89
- ax.set_ylabel("Layer")
90
- st.pyplot(fig)
91
- st.write(f"Analyzing Attribution for Action: {target_action} (at State token)")
92
 
93
  with tab2:
94
  st.header("Activation Patching")
 
70
  st.header("Direct Logit Attribution (DLA)")
71
  st.write("Visualizing which heads contribute most to the predicted action.")
72
 
73
+ # Run automatically for better UX when changing trajectories
74
+ states = torch.from_numpy(traj["observations"]).float().unsqueeze(0)
75
+ actions = torch.nn.functional.one_hot(torch.from_numpy(traj["actions"]).long(), num_classes=7).float().unsqueeze(0)
76
+ returns = torch.from_numpy(traj["rewards"]).float().unsqueeze(0).unsqueeze(-1)
77
+
78
+ preds, cache = model(states, actions, returns, return_cache=True)
79
+ target_action = preds[0, -1].argmax().item()
80
+
81
+ engine = LogitAttributionEngine(model)
82
+ # Use token index -2 to target the state token which predicts the action
83
+ dla_results = engine.calculate_dla(cache, target_logit_index=target_action, token_index=-2)
84
+
85
+ fig, ax = plt.subplots()
86
+ im = ax.imshow(dla_results.detach().cpu().numpy(), cmap="RdBu_r", aspect='auto')
87
+ plt.colorbar(im)
88
+ ax.set_xlabel("Head")
89
+ ax.set_ylabel("Layer")
90
+ st.pyplot(fig)
91
+ st.write(f"Analyzing Attribution for Action: {target_action} (at State token)")
92
 
93
  with tab2:
94
  st.header("Activation Patching")
src/models/hooked_dt.py CHANGED
@@ -85,11 +85,11 @@ class HookedDT(nn.Module):
85
  return action_preds
86
 
87
  @classmethod
88
- def from_config(cls, state_dim, action_dim, n_layers=2, n_heads=4, d_model=128):
89
  cfg = HookedTransformerConfig(
90
  n_layers=n_layers,
91
  d_model=d_model,
92
- n_ctx=300,
93
  d_head=d_model // n_heads,
94
  n_heads=n_heads,
95
  d_vocab=10, # Dummy vocab size
@@ -99,5 +99,5 @@ class HookedDT(nn.Module):
99
  use_attn_result=True,
100
  device="cuda" if torch.cuda.is_available() else "cpu"
101
  )
102
- return cls(cfg, state_dim, action_dim)
103
 
 
85
  return action_preds
86
 
87
  @classmethod
88
+ def from_config(cls, state_dim, action_dim, n_layers=2, n_heads=4, d_model=128, max_length=30):
89
  cfg = HookedTransformerConfig(
90
  n_layers=n_layers,
91
  d_model=d_model,
92
+ n_ctx=3 * max_length,
93
  d_head=d_model // n_heads,
94
  n_heads=n_heads,
95
  d_vocab=10, # Dummy vocab size
 
99
  use_attn_result=True,
100
  device="cuda" if torch.cuda.is_available() else "cpu"
101
  )
102
+ return cls(cfg, state_dim, action_dim, max_length=max_length)
103
 
tests/test_mechanistic.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import pytest
4
+ from src.models.hooked_dt import HookedDT
5
+ from src.interpretability.attribution import LogitAttributionEngine
6
+ from src.interpretability.patching import ActivationPatcher
7
+
8
+ @pytest.fixture
9
+ def model():
10
+ return HookedDT.from_config(state_dim=10, action_dim=7, n_layers=1, n_heads=2, d_model=32)
11
+
12
+ @pytest.fixture
13
+ def mock_data():
14
+ batch_size = 1
15
+ seq_len = 5
16
+ state_dim = 10
17
+ action_dim = 7
18
+
19
+ states = torch.randn(batch_size, seq_len, state_dim)
20
+ actions = torch.nn.functional.one_hot(torch.randint(0, action_dim, (batch_size, seq_len)), num_classes=action_dim).float()
21
+ returns = torch.randn(batch_size, seq_len, 1)
22
+
23
+ return {"states": states, "actions": actions, "returns_to_go": returns}
24
+
25
+ def test_logit_attribution(model, mock_data):
26
+ engine = LogitAttributionEngine(model)
27
+ preds, cache = model(**mock_data, return_cache=True)
28
+ target_action = preds[0, -1].argmax().item()
29
+
30
+ dla = engine.calculate_dla(cache, target_logit_index=target_action, token_index=-2)
31
+
32
+ assert dla.shape == (model.cfg.n_layers, model.cfg.n_heads)
33
+ assert not torch.isnan(dla).any()
34
+
35
+ def test_activation_patching(model, mock_data):
36
+ patcher = ActivationPatcher(model)
37
+
38
+ # Clean run
39
+ clean_preds, clean_cache = model(**mock_data, return_cache=True)
40
+ clean_probs = torch.softmax(clean_preds, dim=-1)
41
+ target_action = clean_preds[0, -1].argmax().item()
42
+
43
+ # Create corrupted run (zeroed states)
44
+ corrupted_data = mock_data.copy()
45
+ corrupted_data["states"] = torch.zeros_like(mock_data["states"])
46
+ _, corrupted_cache = model(**corrupted_data, return_cache=True)
47
+
48
+ # Patch head 0 of layer 0
49
+ patched_logits = patcher.patch_head(
50
+ mock_data,
51
+ corrupted_cache,
52
+ layer=0,
53
+ head_index=0,
54
+ target_token_index=-2
55
+ )
56
+ patched_probs = torch.softmax(patched_logits, dim=-1)
57
+
58
+ drop = patcher.calculate_probability_drop(clean_probs, patched_probs, target_action)
59
+
60
+ assert isinstance(drop, float)
61
+ # Patching with corrupted (zeros) should generally decrease performance/probability
62
+ # but at minimum we check it returns a valid number.
63
+ assert not np.isnan(drop)