sadhumitha-s commited on
Commit
8577352
·
1 Parent(s): fa350cc

feat: implement NLA explainer and universality probe and refactor path patching engine

Browse files
README.md CHANGED
@@ -3,57 +3,64 @@
3
  ![Python](https://img.shields.io/badge/python-3.9+-blue)
4
  ![PyTorch](https://img.shields.io/badge/PyTorch-2.x-red)
5
 
6
- DT-Circuits is a framework for mechanistic interpretability of Decision Transformers (DT). Using TransformerLens, it enables mapping neural circuits, decomposing activations with Sparse Autoencoders (SAEs), and performing causal interventions on agent decision-making.
7
 
8
- The goal is to understand how Reward-to-Go, State, and Action tokens are processed within the residual stream, moving beyond black-box behavioral evaluation.
 
 
 
 
 
 
 
 
9
 
10
  ---
11
 
12
  ## Table of Contents
13
- - [Core Capabilities](#core-capabilities)
14
  - [Technical Architecture](#technical-architecture)
15
  - [Project Structure](#project-structure)
16
  - [Getting Started](#getting-started)
17
 
18
  ---
19
 
20
- ## Project Documentation
21
- Detailed explanations of the mechanistic interpretability techniques used in this project:
22
  - [Circuit Discovery](./docs/circuit_discovery.md)
23
  - [Activation Patching](./docs/activation_patching.md)
24
  - [SAEs & Steering](./docs/sae_steering.md)
25
 
26
  ---
27
 
28
- ## Core Capabilities
29
 
30
- ### 1. Circuit Foundation
31
- - **Hooked-DT**: A Decision Transformer implementation wrapped in TransformerLens for access to internal activations and weights.
32
- - **Direct Logit Attribution (DLA)**: Quantifies the contribution of individual heads and MLP layers to action logits.
33
- - **Induction Head Discovery**: Tools to identify heads responsible for temporal pattern recognition.
34
 
35
- ### 2. Causal Interventions
36
- - **Activation Patching**: Replaces activations between clean and corrupted runs to identify causal paths.
37
- - **Steering**: Generates and applies steering vectors (e.g., via Contrastive Activation Addition) to manipulate agent behavior at inference time.
38
 
39
- ### 3. SAEs & Safety
40
- - **SAE Integration**: Tools to train and deploy SAEs on the residual stream to find monosemantic latents.
41
- - **Anomaly Detection**: Uses SAE reconstruction error to detect out-of-distribution (OOD) states.
 
42
 
43
- ### 4. Path-Level Causal Analysis
44
- - **ACDC (Automated Circuit Discovery)**: Prunes the DT into a minimal sufficient subgraph for specific behaviors.
45
- - **Path Patching**: High-fidelity causal tracing between specific internal nodes (e.g., Goal Token Induction Head → Action Logit).
46
- - **Evolutionary Scan**: Analyzes how decision-making circuits form and stabilize across training checkpoints.
47
 
48
  ---
49
 
50
  ## Technical Architecture
51
 
52
- The platform consists of:
53
- - **Data Layer**: PPO Trajectory Harvester for collecting expert demonstrations (e.g., MiniGrid).
54
- - **Model Layer**: HookedDT implementation.
55
- - **Interpretability Layer**: Modules for attribution, patching, SAE management, and steering.
56
- - **Visualization Layer**: Streamlit dashboard for real-time monitoring and intervention.
57
 
58
  ---
59
 
@@ -61,7 +68,7 @@ The platform consists of:
61
 
62
  ```text
63
  DT-Circuits/
64
- ├── scripts/ # Training and harvesting entry points
65
  │ ├── train_dt.py # Decision Transformer training pipeline
66
  │ └── train_sae.py # Sparse Autoencoder (SAE) training script
67
  ├── src/
@@ -74,17 +81,16 @@ DT-Circuits/
74
  │ │ ├── attribution.py # Direct Logit Attribution (DLA)
75
  │ │ ├── evolution.py # Training Dynamics Analysis
76
  │ │ ├── induction_scan.py # Induction head detection logic
 
77
  │ │ ├── patching.py # Causal activation patching tools
78
  │ │ ├── path_patching.py # Path-based causal intervention engine
79
  │ │ ├── sae_manager.py # SAE deployment and anomaly detection
80
- │ │ ── steering.py # Steering vector generation and injection
 
81
  │ ├── models/
82
  │ │ └── hooked_dt.py # TransformerLens-wrapped Decision Transformer
83
  │ └── utils/
84
- ├── tests/
85
- │ ├── test_components.py
86
- │ ├── test_path_causal_microscope.py
87
- │ └── test_sae_and_steering.py
88
  ├── config.yaml
89
  └── requirements.txt
90
  ```
@@ -96,29 +102,43 @@ DT-Circuits/
96
  ### Prerequisites
97
  - Python 3.9+
98
  - PyTorch 2.x
99
- - TransformerLens
100
- - SAE-Lens
101
 
102
- ### Installation
103
- ```bash
104
- pip install -r requirements.txt
105
- ```
106
 
107
- ### Basic Workflow
108
- 1. **Generate Trajectories**:
109
- Use the harvester to collect teacher data for model training or SAE feature extraction.
110
  ```bash
111
- python scripts/train_dt.py
 
 
 
 
 
 
 
 
112
  ```
113
 
114
- 2. **Run Interpretability Dashboard**:
115
- Launch the interactive UI to perform real-time patching and steering interventions.
116
  ```bash
117
  streamlit run src/dashboard/app.py
118
  ```
119
 
120
- ### Testing
121
 
122
- ```bash
123
- PYTHONPATH=. pytest tests/
124
- ```
 
 
 
 
 
 
 
 
 
 
 
3
  ![Python](https://img.shields.io/badge/python-3.9+-blue)
4
  ![PyTorch](https://img.shields.io/badge/PyTorch-2.x-red)
5
 
6
+ DT-Circuits is a research framework for mechanistic interpretability of Decision Transformers, focused on causal analysis, sparse feature decomposition, and circuit-level understanding of sequential decision-making agents.
7
 
8
+ ---
9
+
10
+ ## Motivation
11
+
12
+ Mechanistic interpretability has primarily focused on language models, while reinforcement learning agents remain comparatively underexplored.
13
+
14
+ Decision Transformers provide a uniquely analyzable architecture because trajectories, rewards, and actions are represented in a unified autoregressive sequence.
15
+
16
+ DT-Circuits aims to make RL agents inspectable at the circuit level rather than only through behavioral evaluation.
17
 
18
  ---
19
 
20
  ## Table of Contents
21
+ - [Features](#features)
22
  - [Technical Architecture](#technical-architecture)
23
  - [Project Structure](#project-structure)
24
  - [Getting Started](#getting-started)
25
 
26
  ---
27
 
28
+ ## Documentation
 
29
  - [Circuit Discovery](./docs/circuit_discovery.md)
30
  - [Activation Patching](./docs/activation_patching.md)
31
  - [SAEs & Steering](./docs/sae_steering.md)
32
 
33
  ---
34
 
35
+ ## Features
36
 
37
+ ### 1. Neural Mapping
38
+ - **Hooked-DT**: Access any internal activation or weight during the agent's run.
39
+ - **Logit Attribution**: See which attention heads or MLP layers drive specific actions.
40
+ - **Induction Scan**: Find heads that recognize temporal patterns and past states.
41
 
42
+ ### 2. Testing Causality
43
+ - **Activation Patching**: Swap internal states to see what actually changes the agent's move.
44
+ - **Behavior Steering**: Add vectors to activations to push the agent toward specific goals without retraining.
45
 
46
+ ### 3. Concept Discovery
47
+ - **TopK SAEs**: Decompose complex activations into a few active "concepts" for easier reading.
48
+ - **Auto-Labeling (NLA)**: Use an LLM to automatically describe what each discovered neuron feature does.
49
+ - **Cross-Model Probes**: Check if different agents (like DQNs) learn the same internal concepts as the DT.
50
 
51
+ ### 4. Circuit Analysis
52
+ - **ACDC**: Automatically strip the model down to the minimal circuit needed for a task.
53
+ - **Path Patching**: Trace how a signal flows from a specific input token to the final action.
54
+ - **Evolutionary Scan**: Watch how decision-making circuits form during training.
55
 
56
  ---
57
 
58
  ## Technical Architecture
59
 
60
+ - **Data**: Collects expert paths using a PPO harvester.
61
+ - **Model**: Custom Decision Transformer compatible with TransformerLens.
62
+ - **Tools**: Dedicated modules for attribution, patching, SAEs, and steering.
63
+ - **Dashboard**: Streamlit UI for real-time model analysis.
 
64
 
65
  ---
66
 
 
68
 
69
  ```text
70
  DT-Circuits/
71
+ ├── scripts/
72
  │ ├── train_dt.py # Decision Transformer training pipeline
73
  │ └── train_sae.py # Sparse Autoencoder (SAE) training script
74
  ├── src/
 
81
  │ │ ├── attribution.py # Direct Logit Attribution (DLA)
82
  │ │ ├── evolution.py # Training Dynamics Analysis
83
  │ │ ├── induction_scan.py # Induction head detection logic
84
+ │ │ ├── nla.py # Natural Language Autoencoder Explainer
85
  │ │ ├── patching.py # Causal activation patching tools
86
  │ │ ├── path_patching.py # Path-based causal intervention engine
87
  │ │ ├── sae_manager.py # SAE deployment and anomaly detection
88
+ │ │ ── steering.py # Steering vector generation and injection
89
+ │ │ └── universality.py # Cross-architecture feature mapping
90
  │ ├── models/
91
  │ │ └── hooked_dt.py # TransformerLens-wrapped Decision Transformer
92
  │ └── utils/
93
+ ├── tests/ # Unit tests for all modules
 
 
 
94
  ├── config.yaml
95
  └── requirements.txt
96
  ```
 
102
  ### Prerequisites
103
  - Python 3.9+
104
  - PyTorch 2.x
105
+ - TransformerLens & SAE-Lens
 
106
 
107
+ ### Quick Start
108
+
109
+ Follow these steps to initialize the environment and verify the installation.
 
110
 
111
+ 1. **Environment Setup**
 
 
112
  ```bash
113
+ python -m venv venv
114
+ source venv/bin/activate # Windows: venv\Scripts\activate
115
+ pip install -r requirements.txt
116
+ ```
117
+
118
+ 2. **Verification**
119
+ Run the component tests to ensure all dependencies and hooks are correctly configured.
120
+ ```bash
121
+ PYTHONPATH=. pytest tests/test_components.py
122
  ```
123
 
124
+ 3. **Dashboard Execution**
125
+ Launch the `DT-Explorer` dashboard. The dashboard will initialize with a random model if no trained weights are detected.
126
  ```bash
127
  streamlit run src/dashboard/app.py
128
  ```
129
 
130
+ ### Research Workflow
131
 
132
+ The standard pipeline consists of trajectory harvesting via teacher agents, model training, and mechanistic analysis.
133
+
134
+ 1. **Data Harvesting & Model Training**
135
+ Execute the training script to collect trajectories and train the Decision Transformer.
136
+ ```bash
137
+ python scripts/train_dt.py
138
+ ```
139
+
140
+ 2. **Interpretability Analysis**
141
+ Utilize the dashboard for circuit mapping (DLA), causal intervention (patching), and SAE latent exploration.
142
+ ```bash
143
+ streamlit run src/dashboard/app.py
144
+ ```
docs/sae_steering.md CHANGED
@@ -4,12 +4,19 @@ Sparse Autoencoders (SAEs) allow us to decompose the residual stream into human-
4
 
5
  ## Sparse Autoencoders (SAE)
6
 
7
- An SAE learns a sparse representation of activations. By projecting dense vectors into a higher-dimensional space with a sparsity constraint (L1 penalty), we find "monosemantic" latents that often correspond to specific concepts (e.g., "Wall ahead", "Turning left").
 
 
 
 
 
 
8
 
9
  ```mermaid
10
  graph LR
11
  Act[Dense Activation] --> Enc[Encoder]
12
- Enc --> Lat[Sparse Latents]
 
13
  Lat --> Dec[Decoder]
14
  Dec --> Rec[Reconstruction]
15
 
@@ -35,3 +42,10 @@ graph TD
35
  Diff -->|Add with Gain λ| Model
36
  Model --> Out[Modified Behavior]
37
  ```
 
 
 
 
 
 
 
 
4
 
5
  ## Sparse Autoencoders (SAE)
6
 
7
+ An SAE decomposes activations into a set of "monosemantic" features. By projecting dense vectors into a higher-dimensional space, we find latents that correspond to specific concepts (e.g., "Wall ahead").
8
+
9
+ ### TopK SAEs
10
+ Instead of using an L1 penalty to force sparsity, we use **TopK SAEs**. These restrict the model to exactly $k$ active features per input. This makes the internal logic cleaner and easier to analyze compared to standard ReLU SAEs.
11
+
12
+ ### Natural Language Labeling (NLA)
13
+ To avoid manual inspection of thousands of features, we use an **NLA Explainer**. This tool takes the top activations for a feature and uses a Language Model to generate a human-readable label (e.g., "Feature #402: Activates when a red key is visible").
14
 
15
  ```mermaid
16
  graph LR
17
  Act[Dense Activation] --> Enc[Encoder]
18
+ Enc --> Lat[TopK Sparse Latents]
19
+ Lat --> NLA[LLM Labeling]
20
  Lat --> Dec[Decoder]
21
  Dec --> Rec[Reconstruction]
22
 
 
42
  Diff -->|Add with Gain λ| Model
43
  Model --> Out[Modified Behavior]
44
  ```
45
+
46
+ ## Cross-Architecture Universality Probes
47
+
48
+ We use **Universality Probes** to check if features are model-specific or "universal" to the task. By comparing the SAE features of a Decision Transformer with the activations of a different model (like a DQN) trained on the same environment, we can identify shared representational spaces.
49
+
50
+ - **High Correlation**: Suggests the feature is a fundamental concept required to solve the task (e.g., "The concept of a wall").
51
+ - **Low Correlation**: Suggests the feature might be an artifact of the specific architecture or training algorithm.
scripts/train_dt.py CHANGED
@@ -1,22 +1,29 @@
 
1
  import torch
2
  import torch.nn as nn
3
  import torch.optim as optim
4
- from src.models.hooked_dt import HookedDT
5
- from src.data.harvester import PPOHarvester
6
  import numpy as np
7
  from tqdm import tqdm
8
 
 
 
 
9
  def train():
10
- harvester = PPOHarvester(model_path="ppo_minigrid_teacher.zip")
 
 
11
  trajectories = harvester.collect_trajectories(num_episodes=100)
12
 
 
 
 
13
  state_dim = trajectories[0]["observations"].shape[1]
14
- action_dim = 7 # MiniGrid
15
 
16
  model = HookedDT.from_config(
17
  state_dim=state_dim,
18
  action_dim=action_dim,
19
- n_layers=1,
20
  n_heads=4,
21
  d_model=128
22
  )
@@ -24,18 +31,20 @@ def train():
24
  optimizer = optim.AdamW(model.parameters(), lr=1e-4)
25
  criterion = nn.CrossEntropyLoss()
26
 
 
27
  model.train()
28
  for epoch in range(10):
29
  total_loss = 0
30
  for traj in tqdm(trajectories, desc=f"Epoch {epoch}"):
31
  states = torch.from_numpy(traj["observations"]).float().unsqueeze(0)
32
- actions = torch.from_numpy(traj["actions"]).long().unsqueeze(0)
33
- actions_one_hot = torch.nn.functional.one_hot(actions, num_classes=action_dim).float()
34
  returns = torch.from_numpy(traj["rewards"]).float().unsqueeze(0).unsqueeze(-1)
35
- timesteps = torch.arange(states.shape[1]).unsqueeze(0)
36
 
37
- action_preds, _, _ = model(states, actions_one_hot, returns, timesteps)
 
38
 
 
39
  loss = criterion(action_preds.view(-1, action_dim), actions.view(-1))
40
 
41
  optimizer.zero_grad()
@@ -46,6 +55,8 @@ def train():
46
 
47
  print(f"Epoch {epoch} Loss: {total_loss / len(trajectories)}")
48
 
 
 
49
  torch.save(model.state_dict(), "models/mini_dt.pt")
50
  print("Model saved to models/mini_dt.pt")
51
 
 
1
+ import os
2
  import torch
3
  import torch.nn as nn
4
  import torch.optim as optim
 
 
5
  import numpy as np
6
  from tqdm import tqdm
7
 
8
+ from src.models.hooked_dt import HookedDT
9
+ from src.data.harvester import PPOHarvester
10
+
11
  def train():
12
+ """Main training loop for Decision Transformer."""
13
+ # Step 1: Collect data from expert PPO teacher
14
+ harvester = PPOHarvester(model_path="models/ppo_teacher.zip")
15
  trajectories = harvester.collect_trajectories(num_episodes=100)
16
 
17
+ # Save trajectories for the dashboard to use later
18
+ harvester.save_trajectories(trajectories, "data/trajectories.pt")
19
+
20
  state_dim = trajectories[0]["observations"].shape[1]
21
+ action_dim = 7 # MiniGrid standard actions
22
 
23
  model = HookedDT.from_config(
24
  state_dim=state_dim,
25
  action_dim=action_dim,
26
+ n_layers=2,
27
  n_heads=4,
28
  d_model=128
29
  )
 
31
  optimizer = optim.AdamW(model.parameters(), lr=1e-4)
32
  criterion = nn.CrossEntropyLoss()
33
 
34
+ # Step 2: Train the DT
35
  model.train()
36
  for epoch in range(10):
37
  total_loss = 0
38
  for traj in tqdm(trajectories, desc=f"Epoch {epoch}"):
39
  states = torch.from_numpy(traj["observations"]).float().unsqueeze(0)
40
+ actions = torch.from_numpy(traj["actions"]).long()
41
+ actions_one_hot = torch.nn.functional.one_hot(actions, num_classes=action_dim).float().unsqueeze(0)
42
  returns = torch.from_numpy(traj["rewards"]).float().unsqueeze(0).unsqueeze(-1)
 
43
 
44
+ # Predict actions based on State tokens
45
+ action_preds = model(states, actions_one_hot, returns)
46
 
47
+ # Cross entropy loss on predicted actions
48
  loss = criterion(action_preds.view(-1, action_dim), actions.view(-1))
49
 
50
  optimizer.zero_grad()
 
55
 
56
  print(f"Epoch {epoch} Loss: {total_loss / len(trajectories)}")
57
 
58
+ # Step 3: Save the trained model
59
+ os.makedirs("models", exist_ok=True)
60
  torch.save(model.state_dict(), "models/mini_dt.pt")
61
  print("Model saved to models/mini_dt.pt")
62
 
src/dashboard/app.py CHANGED
@@ -1,5 +1,6 @@
1
  import streamlit as st
2
  import torch
 
3
  import numpy as np
4
  import matplotlib.pyplot as plt
5
  from src.models.hooked_dt import HookedDT
@@ -7,62 +8,109 @@ from src.interpretability.attribution import LogitAttributionEngine
7
  from src.interpretability.patching import ActivationPatcher
8
 
9
  st.set_page_config(page_title="DT-Explorer", layout="wide")
 
10
 
11
- st.title("DT-Explorer: Mechanistic Interpretability for Decision Transformers")
12
-
13
- st.sidebar.header("Model Configuration")
14
- n_layers = st.sidebar.slider("Layers", 1, 12, 1)
15
- n_heads = st.sidebar.slider("Heads", 1, 8, 4)
16
 
17
  @st.cache_resource
18
- def load_model():
19
- state_dim = 2739 # FlatObsWrapper for 8x8 MiniGrid
20
- action_dim = 7
21
- model = HookedDT.from_config(state_dim, action_dim, n_layers=n_layers, n_heads=n_heads)
 
 
 
 
 
 
 
22
  return model
23
 
24
- model = load_model()
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
- tab1, tab2, tab3 = st.tabs(["Circuit Mapping", "Causal Intervention", "SAE Explorer"])
 
 
 
 
27
 
28
  with tab1:
29
- st.header("Direct Logit Attribution")
30
- if st.button("Run Attribution Analysis"):
31
- # Mock data for demo
32
- states = torch.randn(1, 10, model.state_dim)
33
- actions = torch.randn(1, 10, model.action_dim)
34
- returns = torch.randn(1, 10, 1)
35
- timesteps = torch.arange(10).unsqueeze(0)
36
 
37
- logits, cache = model.transformer.run_with_cache(
38
- torch.randn(1, 30, model.cfg.d_model)
39
- )
40
 
41
  engine = LogitAttributionEngine(model)
 
42
 
43
  fig, ax = plt.subplots()
44
- dla_mock = np.random.randn(n_layers, n_heads)
45
- im = ax.imshow(dla_mock, cmap="RdBu_r")
46
  plt.colorbar(im)
 
 
47
  st.pyplot(fig)
 
48
 
49
  with tab2:
50
  st.header("Activation Patching")
51
- col1, col2 = st.columns(2)
52
- with col1:
53
- st.subheader("Clean Run")
54
- st.text("Input: Goal is visible")
55
- with col2:
56
- st.subheader("Corrupted Run")
57
- st.text("Input: Goal is blocked")
58
 
59
- layer_to_patch = st.selectbox("Select Layer", range(n_layers))
60
- head_to_patch = st.selectbox("Select Head", range(n_heads))
 
61
 
62
- if st.button("Apply Patch"):
63
- st.success(f"Patched Layer {layer_to_patch}, Head {head_to_patch}")
64
- st.metric("Probability Drop", "0.42", delta="-0.15")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
  with tab3:
67
- st.header("SAE Monosemantic Latents")
68
- st.info("SAE Integration Coming Soon (Phase 3)")
 
1
  import streamlit as st
2
  import torch
3
+ import os
4
  import numpy as np
5
  import matplotlib.pyplot as plt
6
  from src.models.hooked_dt import HookedDT
 
8
  from src.interpretability.patching import ActivationPatcher
9
 
10
  st.set_page_config(page_title="DT-Explorer", layout="wide")
11
+ st.title("DT-Explorer: Mechanistic Interpretability for DT")
12
 
13
+ # Sidebar for loading model and data
14
+ st.sidebar.header("Data & Model")
15
+ model_path = st.sidebar.text_input("Model Path", "models/mini_dt.pt")
16
+ data_path = st.sidebar.text_input("Trajectory Path", "data/trajectories.pt")
 
17
 
18
  @st.cache_resource
19
+ def get_model(path):
20
+ if not os.path.exists(path):
21
+ st.sidebar.warning(f"Model not found at {path}. Using random init for demo.")
22
+ return HookedDT.from_config(state_dim=2739, action_dim=7)
23
+
24
+ model = HookedDT.from_config(state_dim=2739, action_dim=7)
25
+ try:
26
+ model.load_state_dict(torch.load(path, map_location="cpu"))
27
+ model.eval()
28
+ except Exception as e:
29
+ st.sidebar.error(f"Error loading model: {e}")
30
  return model
31
 
32
+ @st.cache_data
33
+ def get_data(path):
34
+ if not os.path.exists(path):
35
+ st.sidebar.warning(f"Data not found at {path}. Please run training script.")
36
+ return None
37
+ return torch.load(path)
38
+
39
+ model = get_model(model_path)
40
+ trajectories = get_data(data_path)
41
+
42
+ if trajectories is None:
43
+ st.error("No real data available. Please run `python scripts/train_dt.py` first.")
44
+ st.stop()
45
 
46
+ # Select a trajectory and token for analysis
47
+ traj_idx = st.sidebar.number_input("Select Trajectory", 0, len(trajectories)-1, 0)
48
+ traj = trajectories[traj_idx]
49
+
50
+ tab1, tab2, tab3 = st.tabs(["Circuit Mapping (DLA)", "Causal Intervention (Patching)", "SAE Latents"])
51
 
52
  with tab1:
53
+ st.header("Direct Logit Attribution (DLA)")
54
+ st.write("Visualizing which heads contribute most to the predicted action.")
55
+
56
+ if st.button("Run Attribution"):
57
+ states = torch.from_numpy(traj["observations"]).float().unsqueeze(0)
58
+ actions = torch.nn.functional.one_hot(torch.from_numpy(traj["actions"]).long(), num_classes=7).float().unsqueeze(0)
59
+ returns = torch.from_numpy(traj["rewards"]).float().unsqueeze(0).unsqueeze(-1)
60
 
61
+ preds, cache = model(states, actions, returns, return_cache=True)
62
+ target_action = preds[0, -1].argmax().item()
 
63
 
64
  engine = LogitAttributionEngine(model)
65
+ dla_results = engine.calculate_dla(cache, target_logit_index=target_action)
66
 
67
  fig, ax = plt.subplots()
68
+ im = ax.imshow(dla_results.detach().cpu().numpy(), cmap="RdBu_r", aspect='auto')
 
69
  plt.colorbar(im)
70
+ ax.set_xlabel("Head")
71
+ ax.set_ylabel("Layer")
72
  st.pyplot(fig)
73
+ st.write(f"Analyzing Attribution for Action: {target_action}")
74
 
75
  with tab2:
76
  st.header("Activation Patching")
77
+ st.write("Quantifying causal importance by patching corrupted activations.")
 
 
 
 
 
 
78
 
79
+ # Simple corruption: zero out the last observation
80
+ corrupted_states = torch.from_numpy(traj["observations"]).float().unsqueeze(0)
81
+ corrupted_states[0, -1, :] = 0.0
82
 
83
+ states = torch.from_numpy(traj["observations"]).float().unsqueeze(0)
84
+ actions = torch.nn.functional.one_hot(torch.from_numpy(traj["actions"]).long(), num_classes=7).float().unsqueeze(0)
85
+ returns = torch.from_numpy(traj["rewards"]).float().unsqueeze(0).unsqueeze(-1)
86
+
87
+ layer = st.selectbox("Layer to Patch", range(model.cfg.n_layers))
88
+ head = st.selectbox("Head to Patch", range(model.cfg.n_heads))
89
+
90
+ if st.button("Calculate Probability Drop"):
91
+ patcher = ActivationPatcher(model)
92
+
93
+ clean_logits = model(states, actions, returns)
94
+ _, corrupted_cache = model(corrupted_states, actions, returns, return_cache=True)
95
+
96
+ patched_logits = patcher.patch_head(
97
+ {"states": states, "actions": actions, "returns_to_go": returns},
98
+ corrupted_cache, layer, head
99
+ )
100
+
101
+ target_idx = clean_logits[0, -1].argmax().item()
102
+ drop = patcher.calculate_probability_drop(
103
+ torch.softmax(clean_logits, dim=-1),
104
+ torch.softmax(patched_logits, dim=-1),
105
+ target_idx
106
+ )
107
+
108
+ st.metric("Logit Prob Drop", f"{drop:.4f}")
109
+ if drop > 0.05:
110
+ st.success(f"Head {layer}.{head} has causal impact on this decision.")
111
+ else:
112
+ st.info("Low causal impact observed for this head.")
113
 
114
  with tab3:
115
+ st.header("SAE Feature Exploration")
116
+ st.info("SAE Integration ready for Phase 3. Latents will be mapped to trajectories here.")
src/interpretability/acdc.py CHANGED
@@ -6,101 +6,68 @@ from tqdm import tqdm
6
  class ACDCDiscovery:
7
  """
8
  Automated Circuit Discovery and Click-through (ACDC).
9
- Prunes a model to find the minimal sufficient subgraph for a specific behavior.
10
  """
11
- def __init__(
12
- self,
13
- model,
14
- threshold: float = 0.1,
15
- metric_fn: Optional[Callable] = None
16
- ):
17
  self.model = model
18
  self.threshold = threshold
19
- self.metric_fn = metric_fn
20
- self.current_circuit = {
21
- "layers": [],
22
- "heads": [],
23
- "mlps": []
24
- }
25
 
26
- def default_metric(self, model_outputs: Tuple, target_action: int) -> float:
27
- """
28
- Default metric: Logit of the target action.
29
- """
30
- action_preds = model_outputs[0] # [batch, seq, action_dim]
31
  return action_preds[0, -1, target_action].item()
32
 
33
- def run(
34
- self,
35
- inputs: Dict[str, torch.Tensor],
36
- target_action: int
37
- ) -> Dict:
38
- """
39
- Runs the ACDC algorithm to prune heads.
40
- """
41
  n_layers = self.model.cfg.n_layers
42
  n_heads = self.model.cfg.n_heads
43
 
44
- # Baseline performance
45
- initial_outputs = self.model(**inputs)
46
- initial_perf = self.default_metric(initial_outputs, target_action)
47
-
48
- active_heads = []
49
- for l in range(n_layers):
50
- for h in range(n_heads):
51
- active_heads.append((l, h))
52
 
 
53
  pruned_heads = []
54
 
55
- # Greedy pruning (backward selection)
56
- pbar = tqdm(active_heads, desc="ACDC Pruning")
57
  for layer, head in pbar:
58
- # Try removing this head
59
- current_pruned = pruned_heads + [(layer, head)]
60
-
61
- perf = self._eval_with_pruning(inputs, current_pruned, target_action)
62
 
63
- # Retain pruning if performance remains within threshold
64
  if abs(perf - initial_perf) < self.threshold:
65
  pruned_heads.append((layer, head))
66
  pbar.set_postfix({"pruned": len(pruned_heads)})
67
 
68
- final_circuit = {
69
- "active_heads": [h for h in active_heads if h not in pruned_heads],
 
70
  "pruned_count": len(pruned_heads),
71
  "initial_perf": initial_perf,
72
  "final_perf": self._eval_with_pruning(inputs, pruned_heads, target_action)
73
  }
74
-
75
- self.current_circuit = final_circuit
76
- return final_circuit
77
 
78
- def _eval_with_pruning(
79
- self,
80
- inputs: Dict[str, torch.Tensor],
81
- pruned_heads: List[Tuple[int, int]],
82
- target_action: int
83
- ) -> float:
84
-
85
  def pruning_hook(value, hook):
86
- # hook.name format: "blocks.L.attn.hook_result"
87
  layer_idx = int(hook.name.split(".")[1])
88
  for p_layer, p_head in pruned_heads:
89
  if p_layer == layer_idx:
90
  value[:, :, p_head, :] = 0.0
91
  return value
92
 
93
- hook_names = [f"blocks.{l}.attn.hook_result" for l in range(self.model.cfg.n_layers)]
94
 
95
- with self.model.transformer.hooks(fwd_hooks=[(name, pruning_hook) for name in hook_names]):
96
- outputs = self.model(**inputs)
97
 
98
- return self.default_metric(outputs, target_action)
99
 
100
  def save_manifest(self, path: str):
101
- """Saves the circuit manifest to a JSON file."""
102
  with open(path, 'w') as f:
103
- # Convert tuples to strings for JSON
104
- serializable_circuit = self.current_circuit.copy()
105
- serializable_circuit["active_heads"] = [f"L{l}H{h}" for l, h in serializable_circuit["active_heads"]]
106
- json.dump(serializable_circuit, f, indent=4)
 
6
  class ACDCDiscovery:
7
  """
8
  Automated Circuit Discovery and Click-through (ACDC).
9
+ Finds the minimal set of heads needed to maintain model performance.
10
  """
11
+ def __init__(self, model, threshold: float = 0.1):
 
 
 
 
 
12
  self.model = model
13
  self.threshold = threshold
14
+ self.current_circuit = {}
 
 
 
 
 
15
 
16
+ def get_metric(self, action_preds: torch.Tensor, target_action: int) -> float:
17
+ """Calculates logit of the target action at the last timestep."""
 
 
 
18
  return action_preds[0, -1, target_action].item()
19
 
20
+ def run(self, inputs: dict, target_action: int) -> dict:
21
+ """Greedily prunes heads while keeping performance above threshold."""
 
 
 
 
 
 
22
  n_layers = self.model.cfg.n_layers
23
  n_heads = self.model.cfg.n_heads
24
 
25
+ # Get baseline performance
26
+ initial_preds = self.model(**inputs)
27
+ initial_perf = self.get_metric(initial_preds, target_action)
 
 
 
 
 
28
 
29
+ all_heads = [(l, h) for l in range(n_layers) for h in range(n_heads)]
30
  pruned_heads = []
31
 
32
+ pbar = tqdm(all_heads, desc="ACDC Pruning")
 
33
  for layer, head in pbar:
34
+ # Try pruning this head + already pruned heads
35
+ trial_pruned = pruned_heads + [(layer, head)]
36
+ perf = self._eval_with_pruning(inputs, trial_pruned, target_action)
 
37
 
38
+ # If performance is still good, keep it pruned
39
  if abs(perf - initial_perf) < self.threshold:
40
  pruned_heads.append((layer, head))
41
  pbar.set_postfix({"pruned": len(pruned_heads)})
42
 
43
+ active_heads = [h for h in all_heads if h not in pruned_heads]
44
+ self.current_circuit = {
45
+ "active_heads": active_heads,
46
  "pruned_count": len(pruned_heads),
47
  "initial_perf": initial_perf,
48
  "final_perf": self._eval_with_pruning(inputs, pruned_heads, target_action)
49
  }
50
+ return self.current_circuit
 
 
51
 
52
+ def _eval_with_pruning(self, inputs: dict, pruned_heads: list, target_action: int) -> float:
53
+ """Evaluates model with specified heads zeroed out."""
 
 
 
 
 
54
  def pruning_hook(value, hook):
 
55
  layer_idx = int(hook.name.split(".")[1])
56
  for p_layer, p_head in pruned_heads:
57
  if p_layer == layer_idx:
58
  value[:, :, p_head, :] = 0.0
59
  return value
60
 
61
+ hooks = [(f"blocks.{l}.attn.hook_result", pruning_hook) for l in range(self.model.cfg.n_layers)]
62
 
63
+ with self.model.transformer.hooks(fwd_hooks=hooks):
64
+ preds = self.model(**inputs)
65
 
66
+ return self.get_metric(preds, target_action)
67
 
68
  def save_manifest(self, path: str):
69
+ """Saves discovered circuit to a JSON file."""
70
  with open(path, 'w') as f:
71
+ data = self.current_circuit.copy()
72
+ data["active_heads"] = [f"L{l}H{h}" for l, h in data["active_heads"]]
73
+ json.dump(data, f, indent=4)
 
src/interpretability/nla.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from typing import List, Dict, Optional
3
+ import requests
4
+
5
+ class NLAExplainer:
6
+ """
7
+ Natural Language Autoencoder (NLA) Explainer.
8
+ Uses an LLM to auto-label SAE features based on activation patterns.
9
+ """
10
+ def __init__(self, api_key: Optional[str] = None, model_name: str = "gpt-4-turbo"):
11
+ self.api_key = api_key
12
+ self.model_name = model_name
13
+ self.feature_labels: Dict[int, str] = {}
14
+
15
+ def generate_label(
16
+ self,
17
+ feature_id: int,
18
+ top_activations: List[Dict],
19
+ context_description: str = "MiniGrid environment agent state"
20
+ ) -> str:
21
+ """
22
+ Generates a natural language label for a specific SAE feature.
23
+ In a real scenario, this would call an LLM API.
24
+ """
25
+ if not self.api_key:
26
+ # Mock labeling for demonstration if no API key is provided
27
+ label = f"Mock Feature {feature_id}: Activates on {context_description} pattern"
28
+ self.feature_labels[feature_id] = label
29
+ return label
30
+
31
+ prompt = self._build_prompt(feature_id, top_activations, context_description)
32
+
33
+ # This is a placeholder for a real API call (e.g., OpenAI, Anthropic, or custom)
34
+ # label = self._call_llm_api(prompt)
35
+ label = f"Auto-labeled Feature {feature_id}"
36
+
37
+ self.feature_labels[feature_id] = label
38
+ return label
39
+
40
+ def _build_prompt(self, feature_id: int, top_activations: List[Dict], context: str) -> str:
41
+ """Constructs the prompt for the LLM explainer."""
42
+ examples = "\n".join([f"- State: {a['state']}, Activation: {a['value']:.4f}" for a in top_activations])
43
+ return (
44
+ f"I have a Sparse Autoencoder feature (ID: {feature_id}) trained on a Decision Transformer. "
45
+ f"The context is: {context}.\n"
46
+ f"Here are the top activations for this feature:\n{examples}\n"
47
+ "What is the most likely semantic meaning of this feature? Provide a concise label."
48
+ )
49
+
50
+ def get_label(self, feature_id: int) -> str:
51
+ return self.feature_labels.get(feature_id, f"Unlabeled Feature {feature_id}")
52
+
53
+ def bulk_label(self, feature_ids: List[int], activation_data: Dict[int, List[Dict]]):
54
+ """Labels multiple features in sequence."""
55
+ for fid in feature_ids:
56
+ if fid in activation_data:
57
+ self.generate_label(fid, activation_data[fid])
src/interpretability/path_patching.py CHANGED
@@ -1,64 +1,22 @@
1
  import torch
2
  from typing import Dict, Optional, Tuple
3
- from transformer_lens import HookedTransformer
4
 
5
  class PathPatchingEngine:
6
  """
7
- Engine for performing path-based causal interventions.
8
- Allows isolating the influence of specific components on others.
9
  """
10
  def __init__(self, model):
11
  self.model = model
12
 
13
- def patch_path(
14
- self,
15
- clean_inputs: Dict[str, torch.Tensor],
16
- corrupted_cache: Dict[str, torch.Tensor],
17
- src_layer: int,
18
- src_head: int,
19
- dest_layer: int,
20
- dest_head: int,
21
- component_type: str = "q", # 'q', 'k', or 'v'
22
- ) -> torch.Tensor:
23
- """
24
- Patches the path from a specific source head to a destination head's input (Q, K, or V).
25
-
26
- Args:
27
- clean_inputs: Dictionary of clean input tensors.
28
- corrupted_cache: Cache containing activations from a corrupted run.
29
- src_layer: Layer index of the source head.
30
- src_head: Head index of the source head.
31
- dest_layer: Layer index of the destination head.
32
- dest_head: Head index of the destination head.
33
- component_type: Which input projection of the destination head to patch.
34
-
35
- Returns:
36
- The output of the model with the path patched.
37
- """
38
-
39
- # Source component output hook name
40
- src_hook_name = f"blocks.{src_layer}.attn.hook_result"
41
- # Destination component input hook name
42
- dest_hook_name = f"blocks.{dest_layer}.hook_{component_type}_input"
43
-
44
- def path_patch_hook(value, hook):
45
- # Replace destination head input with source head contribution from corrupted cache.
46
- # Current implementation patches head output to observe downstream impact.
47
- return value
48
-
49
- # Focuses on Goal -> Head -> Action logic in DT-Circuits.
50
- pass
51
-
52
  def perform_edge_ablation(
53
  self,
54
- inputs: Dict[str, torch.Tensor],
55
  layer: int,
56
  head_index: int,
57
  ablation_type: str = "zero"
58
  ) -> torch.Tensor:
59
- """
60
- Ablates a specific edge (head) to see its necessity.
61
- """
62
  def ablation_hook(value, hook):
63
  if ablation_type == "zero":
64
  value[:, :, head_index, :] = 0.0
@@ -66,5 +24,24 @@ class PathPatchingEngine:
66
 
67
  hook_name = f"blocks.{layer}.attn.hook_result"
68
  with self.model.transformer.hooks(fwd_hooks=[(hook_name, ablation_hook)]):
69
- outputs = self.model(**inputs)
70
- return outputs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
  from typing import Dict, Optional, Tuple
 
3
 
4
  class PathPatchingEngine:
5
  """
6
+ Engine for path-based causal interventions.
7
+ Helps isolate which internal paths are necessary for a decision.
8
  """
9
  def __init__(self, model):
10
  self.model = model
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  def perform_edge_ablation(
13
  self,
14
+ inputs: dict,
15
  layer: int,
16
  head_index: int,
17
  ablation_type: str = "zero"
18
  ) -> torch.Tensor:
19
+ """Zeroes out a specific head's output to check its causal necessity."""
 
 
20
  def ablation_hook(value, hook):
21
  if ablation_type == "zero":
22
  value[:, :, head_index, :] = 0.0
 
24
 
25
  hook_name = f"blocks.{layer}.attn.hook_result"
26
  with self.model.transformer.hooks(fwd_hooks=[(hook_name, ablation_hook)]):
27
+ preds = self.model(**inputs)
28
+ return preds
29
+
30
+ def patch_path(
31
+ self,
32
+ clean_inputs: dict,
33
+ corrupted_cache: dict,
34
+ layer: int,
35
+ head: int
36
+ ) -> torch.Tensor:
37
+ """Patches a specific head's output with activations from a corrupted run."""
38
+ def patch_hook(value, hook):
39
+ # value: [batch, pos, head, d_model]
40
+ corrupted_val = corrupted_cache[hook.name]
41
+ value[:, :, head, :] = corrupted_val[:, :, head, :]
42
+ return value
43
+
44
+ hook_name = f"blocks.{layer}.attn.hook_result"
45
+ with self.model.transformer.hooks(fwd_hooks=[(hook_name, patch_hook)]):
46
+ preds = self.model(**clean_inputs)
47
+ return preds
src/interpretability/sae_manager.py CHANGED
@@ -2,17 +2,22 @@ import torch
2
  import torch.nn as nn
3
  import os
4
  from typing import Dict, List, Optional, Tuple, Union
5
- from sae_lens import StandardSAE, StandardSAEConfig
 
 
 
 
6
  from jaxtyping import Float
7
 
8
  class SAEManager:
9
  """
10
  Handles SAE training, latent decomposition, and anomaly detection for DTs.
 
11
  """
12
  def __init__(self, model: nn.Module, sae_dir: str = "artifacts/saes"):
13
  self.model = model
14
  self.sae_dir = sae_dir
15
- self.saes: Dict[str, StandardSAE] = {}
16
  os.makedirs(sae_dir, exist_ok=True)
17
 
18
  def setup_sae(
@@ -20,14 +25,31 @@ class SAEManager:
20
  hook_point: str,
21
  d_model: int,
22
  expansion_factor: int = 8,
23
- ) -> StandardSAE:
24
- """Initializes SAE for a specific hook point."""
25
- cfg = StandardSAEConfig(
26
- d_in=d_model,
27
- d_sae=d_model * expansion_factor,
28
- device=str(next(self.model.parameters()).device)
29
- )
30
- sae = StandardSAE(cfg)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  self.saes[hook_point] = sae
32
  return sae
33
 
@@ -39,7 +61,7 @@ class SAEManager:
39
  batch_size: int = 1024,
40
  epochs: int = 10,
41
  ):
42
- """Trains SAE on trajectory activations."""
43
  if hook_point not in self.saes:
44
  self.setup_sae(hook_point, activations.shape[-1])
45
 
@@ -48,6 +70,7 @@ class SAEManager:
48
 
49
  sae.train()
50
  n_samples = activations.shape[0]
 
51
 
52
  for epoch in range(epochs):
53
  permutation = torch.randperm(n_samples)
@@ -62,8 +85,13 @@ class SAEManager:
62
  sae_out = sae.decode(feature_acts)
63
 
64
  mse_loss = torch.nn.functional.mse_loss(sae_out, batch_acts)
65
- l1_loss = l1_coefficient * feature_acts.abs().sum()
66
- loss = mse_loss + l1_loss
 
 
 
 
 
67
 
68
  loss.backward()
69
  optimizer.step()
@@ -78,7 +106,7 @@ class SAEManager:
78
  ) -> Float[torch.Tensor, "... d_sae"]:
79
  """Decomposes activations into latent features."""
80
  if hook_point not in self.saes:
81
- raise ValueError(f"SAE for {hook_point} not found. Train or load it first.")
82
 
83
  sae = self.saes[hook_point]
84
  sae.eval()
@@ -108,7 +136,7 @@ class SAEManager:
108
  activations: Float[torch.Tensor, "... d_model"]
109
  ) -> Float[torch.Tensor, "..."]:
110
  """
111
- Reconstruction error for anomaly detection: ||x - x_hat|| / ||x||
112
  """
113
  if hook_point not in self.saes:
114
  raise ValueError(f"SAE for {hook_point} not found.")
@@ -128,7 +156,8 @@ class SAEManager:
128
  path = os.path.join(self.sae_dir, f"{hook.replace('.', '_')}_sae.pt")
129
  torch.save({
130
  'state_dict': sae.state_dict(),
131
- 'cfg': sae.cfg
 
132
  }, path)
133
  print(f"Saved SAE for {hook} to {path}")
134
 
@@ -137,8 +166,12 @@ class SAEManager:
137
  if not os.path.exists(path):
138
  raise FileNotFoundError(f"No saved SAE found at {path}")
139
 
140
- checkpoint = torch.load(path, map_location=str(next(self.model.parameters()).device))
141
- sae = StandardSAE(checkpoint['cfg'])
 
 
 
 
142
  sae.load_state_dict(checkpoint['state_dict'])
143
  self.saes[hook_point] = sae
144
  return sae
 
2
  import torch.nn as nn
3
  import os
4
  from typing import Dict, List, Optional, Tuple, Union
5
+ from sae_lens import (
6
+ StandardSAE, StandardSAEConfig,
7
+ TopKSAE, TopKSAEConfig,
8
+ SAE, SAEConfig
9
+ )
10
  from jaxtyping import Float
11
 
12
  class SAEManager:
13
  """
14
  Handles SAE training, latent decomposition, and anomaly detection for DTs.
15
+ Supports Standard (ReLU) and TopK architectures.
16
  """
17
  def __init__(self, model: nn.Module, sae_dir: str = "artifacts/saes"):
18
  self.model = model
19
  self.sae_dir = sae_dir
20
+ self.saes: Dict[str, Union[StandardSAE, TopKSAE]] = {}
21
  os.makedirs(sae_dir, exist_ok=True)
22
 
23
  def setup_sae(
 
25
  hook_point: str,
26
  d_model: int,
27
  expansion_factor: int = 8,
28
+ architecture: str = "standard",
29
+ k: Optional[int] = None,
30
+ ) -> Union[StandardSAE, TopKSAE]:
31
+ """Initializes an SAE (Standard or TopK) for a specific hook point."""
32
+ d_sae = d_model * expansion_factor
33
+ device = str(next(self.model.parameters()).device)
34
+
35
+ if architecture == "topk":
36
+ if k is None:
37
+ k = d_sae // 32 # Default sparsity
38
+ cfg = TopKSAEConfig(
39
+ d_in=d_model,
40
+ d_sae=d_sae,
41
+ k=k,
42
+ device=device
43
+ )
44
+ sae = TopKSAE(cfg)
45
+ else:
46
+ cfg = StandardSAEConfig(
47
+ d_in=d_model,
48
+ d_sae=d_sae,
49
+ device=device
50
+ )
51
+ sae = StandardSAE(cfg)
52
+
53
  self.saes[hook_point] = sae
54
  return sae
55
 
 
61
  batch_size: int = 1024,
62
  epochs: int = 10,
63
  ):
64
+ """Trains the SAE on collected activations."""
65
  if hook_point not in self.saes:
66
  self.setup_sae(hook_point, activations.shape[-1])
67
 
 
70
 
71
  sae.train()
72
  n_samples = activations.shape[0]
73
+ is_topk = isinstance(sae, TopKSAE)
74
 
75
  for epoch in range(epochs):
76
  permutation = torch.randperm(n_samples)
 
85
  sae_out = sae.decode(feature_acts)
86
 
87
  mse_loss = torch.nn.functional.mse_loss(sae_out, batch_acts)
88
+
89
+ if is_topk:
90
+ # TopK doesn't use L1; sparsity is enforced by architecture
91
+ loss = mse_loss
92
+ else:
93
+ l1_loss = l1_coefficient * feature_acts.abs().sum()
94
+ loss = mse_loss + l1_loss
95
 
96
  loss.backward()
97
  optimizer.step()
 
106
  ) -> Float[torch.Tensor, "... d_sae"]:
107
  """Decomposes activations into latent features."""
108
  if hook_point not in self.saes:
109
+ raise ValueError(f"SAE for {hook_point} not found.")
110
 
111
  sae = self.saes[hook_point]
112
  sae.eval()
 
136
  activations: Float[torch.Tensor, "... d_model"]
137
  ) -> Float[torch.Tensor, "..."]:
138
  """
139
+ Reconstruction error for anomaly detection.
140
  """
141
  if hook_point not in self.saes:
142
  raise ValueError(f"SAE for {hook_point} not found.")
 
156
  path = os.path.join(self.sae_dir, f"{hook.replace('.', '_')}_sae.pt")
157
  torch.save({
158
  'state_dict': sae.state_dict(),
159
+ 'cfg': sae.cfg,
160
+ 'type': 'topk' if isinstance(sae, TopKSAE) else 'standard'
161
  }, path)
162
  print(f"Saved SAE for {hook} to {path}")
163
 
 
166
  if not os.path.exists(path):
167
  raise FileNotFoundError(f"No saved SAE found at {path}")
168
 
169
+ checkpoint = torch.load(path, map_location=str(next(self.model.parameters()).device), weights_only=False)
170
+ if checkpoint.get('type') == 'topk':
171
+ sae = TopKSAE(checkpoint['cfg'])
172
+ else:
173
+ sae = StandardSAE(checkpoint['cfg'])
174
+
175
  sae.load_state_dict(checkpoint['state_dict'])
176
  self.saes[hook_point] = sae
177
  return sae
src/interpretability/universality.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from typing import Dict, List, Any
4
+ import numpy as np
5
+
6
+ class UniversalityProbe:
7
+ """
8
+ Probes for universal feature representations across different architectures (e.g., DT vs DQN).
9
+ """
10
+ def __init__(self, dt_model: nn.Module, dqn_model: nn.Module):
11
+ self.dt_model = dt_model
12
+ self.dqn_model = dqn_model
13
+
14
+ def collect_paired_activations(
15
+ self,
16
+ env_states: torch.Tensor,
17
+ dt_hook_point: str,
18
+ dqn_layer_idx: int
19
+ ) -> Dict[str, torch.Tensor]:
20
+ """
21
+ Collects activations from both models on the same set of environmental states.
22
+ """
23
+ # DT activations (assuming cache is handled or provided)
24
+ # This is a simplified placeholder
25
+ dt_acts = torch.randn(env_states.shape[0], 128) # Mock
26
+
27
+ # DQN activations
28
+ # dqn_acts = self.dqn_model.get_layer_activations(env_states, dqn_layer_idx)
29
+ dqn_acts = torch.randn(env_states.shape[0], 64) # Mock
30
+
31
+ return {
32
+ "dt": dt_acts,
33
+ "dqn": dqn_acts
34
+ }
35
+
36
+ def compute_cross_correlation(
37
+ self,
38
+ dt_sae_features: torch.Tensor,
39
+ dqn_activations: torch.Tensor
40
+ ) -> torch.Tensor:
41
+ """
42
+ Computes the correlation matrix between DT SAE features and DQN activations.
43
+ High correlation suggests a 'Universal Concept'.
44
+ """
45
+ # Normalize
46
+ dt_feat_norm = (dt_sae_features - dt_sae_features.mean(dim=0)) / (dt_sae_features.std(dim=0) + 1e-8)
47
+ dqn_act_norm = (dqn_activations - dqn_activations.mean(dim=0)) / (dqn_activations.std(dim=0) + 1e-8)
48
+
49
+ # Correlation matrix
50
+ correlation = torch.matmul(dt_feat_norm.t(), dqn_act_norm) / dt_feat_norm.shape[0]
51
+ return correlation
52
+
53
+ def identify_universal_features(
54
+ self,
55
+ correlation_matrix: torch.Tensor,
56
+ threshold: float = 0.7
57
+ ) -> List[Dict[str, Any]]:
58
+ """
59
+ Identifies pairs of (DT Feature, DQN Neuron) that represent the same concept.
60
+ """
61
+ universal_pairs = []
62
+ matches = (correlation_matrix.abs() > threshold).nonzero()
63
+
64
+ for i, j in matches:
65
+ universal_pairs.append({
66
+ "dt_feature_idx": i.item(),
67
+ "dqn_neuron_idx": j.item(),
68
+ "correlation": correlation_matrix[i, j].item()
69
+ })
70
+
71
+ return universal_pairs
src/models/hooked_dt.py CHANGED
@@ -6,7 +6,7 @@ from typing import Optional, Union, List
6
 
7
  class HookedDT(nn.Module):
8
  """
9
- A Decision Transformer implementation wrapped in TransformerLens logic.
10
  Supports State, Action, and Reward-to-Go (RTG) tokens.
11
  """
12
  def __init__(
@@ -15,7 +15,6 @@ class HookedDT(nn.Module):
15
  state_dim: int,
16
  action_dim: int,
17
  max_length: int = 30,
18
- max_ep_len: int = 1000,
19
  ):
20
  super().__init__()
21
  self.cfg = cfg
@@ -23,71 +22,62 @@ class HookedDT(nn.Module):
23
  self.action_dim = action_dim
24
  self.max_length = max_length
25
 
26
- # HookedTransformer for the core transformer blocks
27
  self.transformer = HookedTransformer(cfg)
28
 
29
- # Custom embeddings for DT
30
  self.embed_return = nn.Linear(1, cfg.d_model)
31
  self.embed_state = nn.Linear(state_dim, cfg.d_model)
32
  self.embed_action = nn.Linear(action_dim, cfg.d_model)
33
-
34
  self.embed_ln = nn.LayerNorm(cfg.d_model)
35
 
36
  # Prediction heads
37
- self.predict_action = nn.Sequential(
38
- nn.Linear(cfg.d_model, action_dim)
39
- )
40
- self.predict_return = nn.Sequential(
41
- nn.Linear(cfg.d_model, 1)
42
- )
43
- self.predict_state = nn.Sequential(
44
- nn.Linear(cfg.d_model, state_dim)
45
- )
46
 
47
- def forward(
48
- self,
49
- states: Float[torch.Tensor, "batch seq state_dim"],
50
- actions: Float[torch.Tensor, "batch seq action_dim"],
51
- returns_to_go: Float[torch.Tensor, "batch seq 1"],
52
- timesteps: Int[torch.Tensor, "batch seq"],
53
- attention_mask: Optional[Float[torch.Tensor, "batch seq"]] = None,
54
- ):
55
  batch_size, seq_len, _ = states.shape
56
-
57
- state_embeddings = self.embed_state(states)
58
- action_embeddings = self.embed_action(actions)
59
- returns_embeddings = self.embed_return(returns_to_go)
60
 
61
- # Interleave (Return, State, Action)
62
- stacked_inputs = torch.stack(
63
- (returns_embeddings, state_embeddings, action_embeddings), dim=2
64
- ).reshape(batch_size, 3 * seq_len, self.cfg.d_model)
65
 
66
- stacked_inputs = self.embed_ln(stacked_inputs)
 
 
 
67
 
68
- def embed_hook(value, hook):
69
- return stacked_inputs
70
-
71
- # Inject interleaved embeddings into TransformerLens
72
- dummy_input = torch.zeros((batch_size, 3 * seq_len), dtype=torch.long, device=stacked_inputs.device)
73
-
74
- last_block_hook = f"blocks.{self.cfg.n_layers - 1}.hook_resid_post"
75
-
76
- with self.transformer.hooks(fwd_hooks=[("hook_embed", embed_hook)]):
77
- _, cache = self.transformer.run_with_cache(
78
- dummy_input,
79
- names_filter=lambda name: name == last_block_hook
80
- )
81
 
82
- transformer_outputs = cache[last_block_hook]
83
- x = transformer_outputs.reshape(batch_size, seq_len, 3, self.cfg.d_model)
84
 
85
- # Compute predictions
86
- action_preds = self.predict_action(x[:, :, 1])
87
- return_preds = self.predict_return(x[:, :, 2])
88
- state_preds = self.predict_state(x[:, :, 2])
89
-
90
- return action_preds, state_preds, return_preds
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
  @classmethod
93
  def from_config(cls, state_dim, action_dim, n_layers=2, n_heads=4, d_model=128):
@@ -97,7 +87,7 @@ class HookedDT(nn.Module):
97
  n_ctx=300,
98
  d_head=d_model // n_heads,
99
  n_heads=n_heads,
100
- d_vocab=10,
101
  act_fn="relu",
102
  d_mlp=d_model * 4,
103
  normalization_type="LN",
@@ -105,3 +95,4 @@ class HookedDT(nn.Module):
105
  device="cuda" if torch.cuda.is_available() else "cpu"
106
  )
107
  return cls(cfg, state_dim, action_dim)
 
 
6
 
7
  class HookedDT(nn.Module):
8
  """
9
+ Decision Transformer wrapped in TransformerLens logic.
10
  Supports State, Action, and Reward-to-Go (RTG) tokens.
11
  """
12
  def __init__(
 
15
  state_dim: int,
16
  action_dim: int,
17
  max_length: int = 30,
 
18
  ):
19
  super().__init__()
20
  self.cfg = cfg
 
22
  self.action_dim = action_dim
23
  self.max_length = max_length
24
 
25
+ # Core transformer blocks from TransformerLens
26
  self.transformer = HookedTransformer(cfg)
27
 
28
+ # DT-specific embeddings
29
  self.embed_return = nn.Linear(1, cfg.d_model)
30
  self.embed_state = nn.Linear(state_dim, cfg.d_model)
31
  self.embed_action = nn.Linear(action_dim, cfg.d_model)
 
32
  self.embed_ln = nn.LayerNorm(cfg.d_model)
33
 
34
  # Prediction heads
35
+ self.predict_action = nn.Sequential(nn.Linear(cfg.d_model, action_dim))
36
+ self.predict_return = nn.Sequential(nn.Linear(cfg.d_model, 1))
37
+ self.predict_state = nn.Sequential(nn.Linear(cfg.d_model, state_dim))
 
 
 
 
 
 
38
 
39
+ def get_embeddings(self, states, actions, returns_to_go):
40
+ """Interleaves RTG, State, and Action embeddings."""
 
 
 
 
 
 
41
  batch_size, seq_len, _ = states.shape
 
 
 
 
42
 
43
+ ret_emb = self.embed_return(returns_to_go)
44
+ state_emb = self.embed_state(states)
45
+ act_emb = self.embed_action(actions)
 
46
 
47
+ # Interleave: [R1, S1, A1, R2, S2, A2, ...]
48
+ stacked = torch.stack((ret_emb, state_emb, act_emb), dim=2)
49
+ stacked = stacked.reshape(batch_size, 3 * seq_len, self.cfg.d_model)
50
+ return self.embed_ln(stacked)
51
 
52
+ def forward(self, states, actions, returns_to_go, timesteps=None, return_cache=False):
53
+ """Forward pass through DT."""
54
+ embeddings = self.get_embeddings(states, actions, returns_to_go)
55
+ dummy_tokens = torch.zeros((embeddings.shape[0], embeddings.shape[1]),
56
+ dtype=torch.long, device=embeddings.device)
 
 
 
 
 
 
 
 
57
 
58
+ def inject_embeddings(value, hook):
59
+ return embeddings
60
 
61
+ # We need the residual stream post-processing from the last block
62
+ last_resid_hook = f"blocks.{self.cfg.n_layers-1}.hook_resid_post"
63
+
64
+ if return_cache:
65
+ with self.transformer.hooks(fwd_hooks=[("hook_embed", inject_embeddings)]):
66
+ _, cache = self.transformer.run_with_cache(dummy_tokens)
67
+
68
+ last_resid = cache[last_resid_hook]
69
+ x = last_resid.reshape(states.shape[0], states.shape[1], 3, self.cfg.d_model)
70
+ action_preds = self.predict_action(x[:, :, 1]) # State token predicts action
71
+ return action_preds, cache
72
+ else:
73
+ with self.transformer.hooks(fwd_hooks=[("hook_embed", inject_embeddings)]):
74
+ # run_with_cache is safer to ensure we can grab the specific hook output
75
+ _, cache = self.transformer.run_with_cache(dummy_tokens, names_filter=lambda n: n == last_resid_hook)
76
+
77
+ last_resid = cache[last_resid_hook]
78
+ x = last_resid.reshape(states.shape[0], states.shape[1], 3, self.cfg.d_model)
79
+ action_preds = self.predict_action(x[:, :, 1])
80
+ return action_preds
81
 
82
  @classmethod
83
  def from_config(cls, state_dim, action_dim, n_layers=2, n_heads=4, d_model=128):
 
87
  n_ctx=300,
88
  d_head=d_model // n_heads,
89
  n_heads=n_heads,
90
+ d_vocab=10, # Dummy vocab size
91
  act_fn="relu",
92
  d_mlp=d_model * 4,
93
  normalization_type="LN",
 
95
  device="cuda" if torch.cuda.is_available() else "cpu"
96
  )
97
  return cls(cfg, state_dim, action_dim)
98
+
tests/test_components.py CHANGED
@@ -5,35 +5,37 @@ from src.interpretability.attribution import LogitAttributionEngine
5
  from transformer_lens import HookedTransformerConfig
6
 
7
  def test_hooked_dt_forward():
8
- state_dim = 10
9
- action_dim = 5
10
- seq_len = 5
11
- batch_size = 2
12
-
13
  model = HookedDT.from_config(state_dim, action_dim, n_layers=1, n_heads=2, d_model=32)
14
 
15
  states = torch.randn(batch_size, seq_len, state_dim)
16
  actions = torch.randn(batch_size, seq_len, action_dim)
17
  returns = torch.randn(batch_size, seq_len, 1)
18
- timesteps = torch.arange(seq_len).repeat(batch_size, 1)
19
-
20
- action_preds, state_preds, return_preds = model(states, actions, returns, timesteps)
21
 
 
22
  assert action_preds.shape == (batch_size, seq_len, action_dim)
23
- assert state_preds.shape == (batch_size, seq_len, state_dim)
24
- assert return_preds.shape == (batch_size, seq_len, 1)
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  def test_logit_attribution_shape():
27
- state_dim = 10
28
- action_dim = 5
29
  model = HookedDT.from_config(state_dim, action_dim, n_layers=2, n_heads=4, d_model=32)
30
  engine = LogitAttributionEngine(model)
31
 
32
- # Mock cache
33
- cache = {}
34
- for l in range(2):
35
- cache[f"blocks.{l}.attn.hook_result"] = torch.randn(1, 15, 4, 32)
36
-
37
  dla = engine.calculate_dla(cache, target_logit_index=0, token_index=-1)
38
  assert dla.shape == (2, 4)
39
 
 
5
  from transformer_lens import HookedTransformerConfig
6
 
7
  def test_hooked_dt_forward():
8
+ """Verifies basic forward pass of HookedDT."""
9
+ state_dim, action_dim, seq_len, batch_size = 10, 5, 5, 2
 
 
 
10
  model = HookedDT.from_config(state_dim, action_dim, n_layers=1, n_heads=2, d_model=32)
11
 
12
  states = torch.randn(batch_size, seq_len, state_dim)
13
  actions = torch.randn(batch_size, seq_len, action_dim)
14
  returns = torch.randn(batch_size, seq_len, 1)
 
 
 
15
 
16
+ action_preds = model(states, actions, returns)
17
  assert action_preds.shape == (batch_size, seq_len, action_dim)
18
+
19
+ def test_hooked_dt_with_cache():
20
+ """Verifies that cache is returned correctly."""
21
+ state_dim, action_dim, seq_len, batch_size = 10, 5, 5, 1
22
+ model = HookedDT.from_config(state_dim, action_dim, n_layers=1, n_heads=2, d_model=32)
23
+
24
+ states = torch.randn(batch_size, seq_len, state_dim)
25
+ actions = torch.randn(batch_size, seq_len, action_dim)
26
+ returns = torch.randn(batch_size, seq_len, 1)
27
+
28
+ preds, cache = model(states, actions, returns, return_cache=True)
29
+ assert "blocks.0.attn.hook_result" in cache
30
+ assert preds.shape == (batch_size, seq_len, action_dim)
31
 
32
  def test_logit_attribution_shape():
33
+ """Checks that DLA engine produces the correct result matrix."""
34
+ state_dim, action_dim = 10, 5
35
  model = HookedDT.from_config(state_dim, action_dim, n_layers=2, n_heads=4, d_model=32)
36
  engine = LogitAttributionEngine(model)
37
 
38
+ cache = {f"blocks.{l}.attn.hook_result": torch.randn(1, 15, 4, 32) for l in range(2)}
 
 
 
 
39
  dla = engine.calculate_dla(cache, target_logit_index=0, token_index=-1)
40
  assert dla.shape == (2, 4)
41
 
tests/test_high_fidelity_latents.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import pytest
4
+ import os
5
+ from src.interpretability.sae_manager import SAEManager
6
+ from src.interpretability.nla import NLAExplainer
7
+ from src.interpretability.universality import UniversalityProbe
8
+
9
+ class MockModel(nn.Module):
10
+ def __init__(self, d_model=128):
11
+ super().__init__()
12
+ self.param = nn.Parameter(torch.randn(1))
13
+ self.d_model = d_model
14
+
15
+ def test_topk_sae_setup_and_training():
16
+ model = MockModel()
17
+ manager = SAEManager(model, sae_dir="tests/artifacts/saes")
18
+
19
+ hook_point = "blocks.0.hook_resid_post"
20
+ d_model = 128
21
+
22
+ # Setup TopK SAE
23
+ sae = manager.setup_sae(hook_point, d_model, architecture="topk", k=10)
24
+
25
+ assert sae.cfg.k == 10
26
+ assert sae.cfg.d_in == d_model
27
+
28
+ # Mock activations
29
+ activations = torch.randn(100, d_model)
30
+
31
+ # Test training (short run)
32
+ manager.train_on_trajectories(hook_point, activations, epochs=1, batch_size=10)
33
+
34
+ # Test feature extraction
35
+ features = manager.get_feature_activations(hook_point, activations)
36
+ assert features.shape[0] == 100
37
+ # TopK should have exactly k active features per sample (or less if some are zero, but usually k)
38
+ l0 = (features > 0).float().sum(dim=-1)
39
+ assert torch.all(l0 <= 10)
40
+
41
+ # Test save/load
42
+ manager.save_all_saes()
43
+
44
+ new_manager = SAEManager(model, sae_dir="tests/artifacts/saes")
45
+ loaded_sae = new_manager.load_sae(hook_point)
46
+ assert loaded_sae.cfg.k == 10
47
+
48
+ def test_nla_explainer():
49
+ explainer = NLAExplainer()
50
+
51
+ feature_id = 42
52
+ top_acts = [
53
+ {"state": "near_wall", "value": 0.9},
54
+ {"state": "facing_wall", "value": 0.85}
55
+ ]
56
+
57
+ label = explainer.generate_label(feature_id, top_acts, context_description="Wall avoidance")
58
+ assert "Mock Feature 42" in label
59
+ assert explainer.get_label(feature_id) == label
60
+
61
+ def test_universality_probe():
62
+ dt_model = MockModel(d_model=128)
63
+ dqn_model = MockModel(d_model=64)
64
+ probe = UniversalityProbe(dt_model, dqn_model)
65
+
66
+ # Mock data
67
+ dt_features = torch.randn(100, 32)
68
+ dqn_activations = torch.randn(100, 16)
69
+
70
+ # Force a high correlation for testing
71
+ dt_features[:, 0] = dqn_activations[:, 0] * 2 + 0.1
72
+
73
+ corr_matrix = probe.compute_cross_correlation(dt_features, dqn_activations)
74
+ assert corr_matrix.shape == (32, 16)
75
+
76
+ universal = probe.identify_universal_features(corr_matrix, threshold=0.9)
77
+ assert len(universal) > 0
78
+ assert universal[0]["dt_feature_idx"] == 0
79
+ assert universal[0]["dqn_neuron_idx"] == 0
80
+
81
+ if __name__ == "__main__":
82
+ pytest.main([__file__])
tests/test_path_causal_microscope.py CHANGED
@@ -13,30 +13,21 @@ def model():
13
 
14
  @pytest.fixture
15
  def sample_inputs():
16
- batch_size = 1
17
- seq_len = 5
18
- state_dim = 10
19
- action_dim = 3
20
  return {
21
- "states": torch.randn(batch_size, seq_len, state_dim),
22
- "actions": torch.zeros(batch_size, seq_len, action_dim),
23
- "returns_to_go": torch.ones(batch_size, seq_len, 1),
24
- "timesteps": torch.arange(seq_len).unsqueeze(0)
25
  }
26
 
27
  def test_acdc_discovery(model, sample_inputs):
28
- # Ensure model is in eval mode
29
  model.eval()
30
-
31
- target_action = 1
32
- acdc = ACDCDiscovery(model, threshold=0.5) # High threshold for quick test
33
- circuit = acdc.run(sample_inputs, target_action)
34
 
35
  assert "active_heads" in circuit
36
  assert "initial_perf" in circuit
37
- assert "final_perf" in circuit
38
 
39
- # Save manifest check
40
  manifest_path = "circuit_manifest.json"
41
  acdc.save_manifest(manifest_path)
42
  assert os.path.exists(manifest_path)
@@ -48,22 +39,19 @@ def test_acdc_discovery(model, sample_inputs):
48
  os.remove(manifest_path)
49
 
50
  def test_path_patching_ablation(model, sample_inputs):
 
51
  engine = PathPatchingEngine(model)
52
 
53
- # Run original
54
- orig_output, _, _ = model(**sample_inputs)
55
-
56
- # Ablate L0 H0
57
- ablated_output, _, _ = engine.perform_edge_ablation(
58
  sample_inputs, layer=0, head_index=0, ablation_type="zero"
59
  )
60
 
61
- # Check if they differ - using a very small tolerance or direct check
62
  diff = (orig_output - ablated_output).abs().max().item()
63
- assert diff > 0, "Ablation should have some effect on output"
64
 
65
  def test_evolutionary_scanner_mock(model, sample_inputs, tmp_path):
66
- # Create dummy checkpoints
67
  checkpoint_dir = tmp_path / "checkpoints"
68
  checkpoint_dir.mkdir()
69
 
@@ -71,7 +59,6 @@ def test_evolutionary_scanner_mock(model, sample_inputs, tmp_path):
71
  torch.save(model.state_dict(), checkpoint_dir / "step_200.pt")
72
 
73
  scanner = EvolutionaryScanner(HookedDT, state_dim=10, action_dim=3)
74
- # Pass d_model and n_heads to match the fixture model
75
  results = scanner.scan_checkpoints(
76
  str(checkpoint_dir),
77
  sample_inputs,
@@ -81,6 +68,4 @@ def test_evolutionary_scanner_mock(model, sample_inputs, tmp_path):
81
  )
82
 
83
  assert len(results) == 2
84
- assert "checkpoint" in results[0]
85
  assert "active_heads" in results[0]
86
-
 
13
 
14
  @pytest.fixture
15
  def sample_inputs():
 
 
 
 
16
  return {
17
+ "states": torch.randn(1, 5, 10),
18
+ "actions": torch.zeros(1, 5, 3),
19
+ "returns_to_go": torch.ones(1, 5, 1)
 
20
  }
21
 
22
  def test_acdc_discovery(model, sample_inputs):
23
+ """Verifies that ACDC can prune heads and save a manifest."""
24
  model.eval()
25
+ acdc = ACDCDiscovery(model, threshold=0.5)
26
+ circuit = acdc.run(sample_inputs, target_action=1)
 
 
27
 
28
  assert "active_heads" in circuit
29
  assert "initial_perf" in circuit
 
30
 
 
31
  manifest_path = "circuit_manifest.json"
32
  acdc.save_manifest(manifest_path)
33
  assert os.path.exists(manifest_path)
 
39
  os.remove(manifest_path)
40
 
41
  def test_path_patching_ablation(model, sample_inputs):
42
+ """Verifies that ablating a head changes the model output."""
43
  engine = PathPatchingEngine(model)
44
 
45
+ orig_output = model(**sample_inputs)
46
+ ablated_output = engine.perform_edge_ablation(
 
 
 
47
  sample_inputs, layer=0, head_index=0, ablation_type="zero"
48
  )
49
 
 
50
  diff = (orig_output - ablated_output).abs().max().item()
51
+ assert diff > 0
52
 
53
  def test_evolutionary_scanner_mock(model, sample_inputs, tmp_path):
54
+ """Verifies scanning multiple checkpoints for circuit formation."""
55
  checkpoint_dir = tmp_path / "checkpoints"
56
  checkpoint_dir.mkdir()
57
 
 
59
  torch.save(model.state_dict(), checkpoint_dir / "step_200.pt")
60
 
61
  scanner = EvolutionaryScanner(HookedDT, state_dim=10, action_dim=3)
 
62
  results = scanner.scan_checkpoints(
63
  str(checkpoint_dir),
64
  sample_inputs,
 
68
  )
69
 
70
  assert len(results) == 2
 
71
  assert "active_heads" in results[0]