Spaces:
Running
Running
Commit ·
e2614dc
0
Parent(s):
First commit
Browse files- .gitignore +57 -0
- LICENSE +21 -0
- README.md +44 -0
- config.yaml +18 -0
- requirements.txt +15 -0
- scripts/train_dt.py +61 -0
- scripts/train_sae.py +34 -0
- src/dashboard/app.py +79 -0
- src/interpretability/attribution.py +57 -0
- src/interpretability/induction_scan.py +48 -0
- src/interpretability/patching.py +48 -0
- src/interpretability/steering.py +43 -0
- src/models/hooked_dt.py +137 -0
- tests/test_components.py +41 -0
.gitignore
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
*.so
|
| 6 |
+
.Python
|
| 7 |
+
env/
|
| 8 |
+
build/
|
| 9 |
+
develop-eggs/
|
| 10 |
+
dist/
|
| 11 |
+
downloads/
|
| 12 |
+
eggs/
|
| 13 |
+
.eggs/
|
| 14 |
+
lib/
|
| 15 |
+
lib64/
|
| 16 |
+
parts/
|
| 17 |
+
sdist/
|
| 18 |
+
var/
|
| 19 |
+
wheels/
|
| 20 |
+
*.egg-info/
|
| 21 |
+
.installed.cfg
|
| 22 |
+
*.egg
|
| 23 |
+
|
| 24 |
+
# Virtual Environment
|
| 25 |
+
venv/
|
| 26 |
+
ENV/
|
| 27 |
+
|
| 28 |
+
# Data and Models
|
| 29 |
+
data/
|
| 30 |
+
models/*.pt
|
| 31 |
+
models/*.pth
|
| 32 |
+
*.zip
|
| 33 |
+
*.h5
|
| 34 |
+
*.pt
|
| 35 |
+
|
| 36 |
+
# Experiment Tracking
|
| 37 |
+
wandb/
|
| 38 |
+
|
| 39 |
+
# IDEs
|
| 40 |
+
.vscode/
|
| 41 |
+
.idea/
|
| 42 |
+
.DS_Store
|
| 43 |
+
|
| 44 |
+
# Testing
|
| 45 |
+
.pytest_cache/
|
| 46 |
+
.coverage
|
| 47 |
+
htmlcov/
|
| 48 |
+
|
| 49 |
+
# Streamlit
|
| 50 |
+
.streamlit/
|
| 51 |
+
static/
|
| 52 |
+
|
| 53 |
+
# Environment Variables
|
| 54 |
+
.env
|
| 55 |
+
.venv
|
| 56 |
+
|
| 57 |
+
/PRD.md
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2025 Sadhumitha S.
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# DT-Explorer
|
| 2 |
+
|
| 3 |
+
A research-grade platform for the mechanistic interpretability of Decision Transformers.
|
| 4 |
+
|
| 5 |
+
## Architecture
|
| 6 |
+
- **Data**: PPO Trajectory Harvester for high-quality teacher data.
|
| 7 |
+
- **Model**: `HookedDT` - A custom Decision Transformer wrapped in `TransformerLens` for full activation visibility.
|
| 8 |
+
- **Interpretability**: Tools for Direct Logit Attribution (DLA), Activation Patching, and Induction Head detection.
|
| 9 |
+
- **Dashboard**: Streamlit-based UI for real-time causal interventions.
|
| 10 |
+
|
| 11 |
+
## Quick Start
|
| 12 |
+
|
| 13 |
+
### 1. Install Dependencies
|
| 14 |
+
```bash
|
| 15 |
+
pip install -r requirements.txt
|
| 16 |
+
```
|
| 17 |
+
|
| 18 |
+
### 2. Collect Data & Train Mini-DT
|
| 19 |
+
```bash
|
| 20 |
+
python scripts/train_dt.py
|
| 21 |
+
```
|
| 22 |
+
|
| 23 |
+
### 3. Run Interpretation Dashboard
|
| 24 |
+
```bash
|
| 25 |
+
streamlit run src/dashboard/app.py
|
| 26 |
+
```
|
| 27 |
+
|
| 28 |
+
## Testing
|
| 29 |
+
Run the test suite to ensure system integrity:
|
| 30 |
+
```bash
|
| 31 |
+
pytest tests/
|
| 32 |
+
```
|
| 33 |
+
|
| 34 |
+
## Components
|
| 35 |
+
- `src/data/harvester.py`: Collects trajectories from MiniGrid.
|
| 36 |
+
- `src/models/hooked_dt.py`: Hookable transformer implementation.
|
| 37 |
+
- `src/interpretability/`:
|
| 38 |
+
- `attribution.py`: Direct Logit Attribution logic.
|
| 39 |
+
- `patching.py`: Activation patching interface.
|
| 40 |
+
- `induction_scan.py`: Automated circuit discovery.
|
| 41 |
+
|
| 42 |
+
## License
|
| 43 |
+
|
| 44 |
+
MIT
|
config.yaml
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model:
|
| 2 |
+
n_layers: 2
|
| 3 |
+
n_heads: 4
|
| 4 |
+
d_model: 128
|
| 5 |
+
max_length: 30
|
| 6 |
+
|
| 7 |
+
data:
|
| 8 |
+
env_id: "MiniGrid-Empty-8x8-v0"
|
| 9 |
+
num_episodes: 1000
|
| 10 |
+
collection_method: "PPO-Teacher"
|
| 11 |
+
|
| 12 |
+
interpretability:
|
| 13 |
+
dla_threshold: 0.1
|
| 14 |
+
patching_metric: "logit_diff"
|
| 15 |
+
|
| 16 |
+
sae:
|
| 17 |
+
expansion_factor: 8
|
| 18 |
+
l1_coeff: 0.0005
|
requirements.txt
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
transformer_lens
|
| 3 |
+
gymnasium
|
| 4 |
+
minigrid
|
| 5 |
+
sae-lens
|
| 6 |
+
wandb
|
| 7 |
+
streamlit
|
| 8 |
+
numpy
|
| 9 |
+
matplotlib
|
| 10 |
+
tqdm
|
| 11 |
+
einops
|
| 12 |
+
jaxtyping
|
| 13 |
+
pytest
|
| 14 |
+
stable-baselines3
|
| 15 |
+
shimmy
|
scripts/train_dt.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.optim as optim
|
| 4 |
+
from src.models.hooked_dt import HookedDT
|
| 5 |
+
from src.data.harvester import PPOHarvester
|
| 6 |
+
import numpy as np
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
|
| 9 |
+
def train():
|
| 10 |
+
# 1. Collect Data
|
| 11 |
+
harvester = PPOHarvester(model_path="ppo_minigrid_teacher.zip")
|
| 12 |
+
trajectories = harvester.collect_trajectories(num_episodes=100)
|
| 13 |
+
|
| 14 |
+
# 2. Setup Model
|
| 15 |
+
state_dim = trajectories[0]["observations"].shape[1]
|
| 16 |
+
action_dim = 7 # MiniGrid has 7 actions
|
| 17 |
+
|
| 18 |
+
model = HookedDT.from_config(
|
| 19 |
+
state_dim=state_dim,
|
| 20 |
+
action_dim=action_dim,
|
| 21 |
+
n_layers=1,
|
| 22 |
+
n_heads=4,
|
| 23 |
+
d_model=128
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
optimizer = optim.AdamW(model.parameters(), lr=1e-4)
|
| 27 |
+
criterion = nn.CrossEntropyLoss()
|
| 28 |
+
|
| 29 |
+
# 3. Training Loop (Simplified)
|
| 30 |
+
model.train()
|
| 31 |
+
for epoch in range(10):
|
| 32 |
+
total_loss = 0
|
| 33 |
+
for traj in tqdm(trajectories, desc=f"Epoch {epoch}"):
|
| 34 |
+
states = torch.from_numpy(traj["observations"]).float().unsqueeze(0)
|
| 35 |
+
actions = torch.from_numpy(traj["actions"]).long().unsqueeze(0)
|
| 36 |
+
# One-hot actions for input
|
| 37 |
+
actions_one_hot = torch.nn.functional.one_hot(actions, num_classes=action_dim).float()
|
| 38 |
+
returns = torch.from_numpy(traj["rewards"]).float().unsqueeze(0).unsqueeze(-1)
|
| 39 |
+
timesteps = torch.arange(states.shape[1]).unsqueeze(0)
|
| 40 |
+
|
| 41 |
+
# Mask (dummy for now)
|
| 42 |
+
|
| 43 |
+
action_preds, _, _ = model(states, actions_one_hot, returns, timesteps)
|
| 44 |
+
|
| 45 |
+
# Target actions (shifted by 1 for next action prediction)
|
| 46 |
+
# Standard DT predicts a_t from s_t
|
| 47 |
+
loss = criterion(action_preds.view(-1, action_dim), actions.view(-1))
|
| 48 |
+
|
| 49 |
+
optimizer.zero_grad()
|
| 50 |
+
loss.backward()
|
| 51 |
+
optimizer.step()
|
| 52 |
+
|
| 53 |
+
total_loss += loss.item()
|
| 54 |
+
|
| 55 |
+
print(f"Epoch {epoch} Loss: {total_loss / len(trajectories)}")
|
| 56 |
+
|
| 57 |
+
torch.save(model.state_dict(), "models/mini_dt.pt")
|
| 58 |
+
print("Model saved to models/mini_dt.pt")
|
| 59 |
+
|
| 60 |
+
if __name__ == "__main__":
|
| 61 |
+
train()
|
scripts/train_sae.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from sae_lens import SAEConfig, SAE
|
| 3 |
+
from src.models.hooked_dt import HookedDT
|
| 4 |
+
|
| 5 |
+
def train_sae():
|
| 6 |
+
# Load DT
|
| 7 |
+
state_dim = 2739
|
| 8 |
+
action_dim = 7
|
| 9 |
+
model = HookedDT.from_config(state_dim, action_dim)
|
| 10 |
+
# model.load_state_dict(torch.load("models/mini_dt.pt"))
|
| 11 |
+
|
| 12 |
+
# Configure SAE
|
| 13 |
+
cfg = SAEConfig(
|
| 14 |
+
d_in=128, # d_model
|
| 15 |
+
d_sae=128 * 8, # Expansion factor
|
| 16 |
+
hook_point="blocks.0.hook_resid_post",
|
| 17 |
+
hook_point_layer=0,
|
| 18 |
+
architecture="standard",
|
| 19 |
+
activation_fn="relu",
|
| 20 |
+
expansion_factor=8,
|
| 21 |
+
l1_coefficient=5e-4,
|
| 22 |
+
lr=3e-4,
|
| 23 |
+
train_batch_size=4096,
|
| 24 |
+
context_size=30, # Sequence length
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
sae = SAE(cfg)
|
| 28 |
+
|
| 29 |
+
# Training logic would go here, using activations from the DT
|
| 30 |
+
print("SAE Configured for DT-Explorer.")
|
| 31 |
+
print(f"Hooking into: {cfg.hook_point}")
|
| 32 |
+
|
| 33 |
+
if __name__ == "__main__":
|
| 34 |
+
train_sae()
|
src/dashboard/app.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
from src.models.hooked_dt import HookedDT
|
| 6 |
+
from src.interpretability.attribution import LogitAttributionEngine
|
| 7 |
+
from src.interpretability.patching import ActivationPatcher
|
| 8 |
+
|
| 9 |
+
st.set_page_config(page_title="DT-Explorer", layout="wide")
|
| 10 |
+
|
| 11 |
+
st.title("DT-Explorer: Mechanistic Interpretability for Decision Transformers")
|
| 12 |
+
|
| 13 |
+
# Sidebar for controls
|
| 14 |
+
st.sidebar.header("Model Configuration")
|
| 15 |
+
n_layers = st.sidebar.slider("Layers", 1, 12, 1)
|
| 16 |
+
n_heads = st.sidebar.slider("Heads", 1, 8, 4)
|
| 17 |
+
|
| 18 |
+
# Load Model
|
| 19 |
+
@st.cache_resource
|
| 20 |
+
def load_model():
|
| 21 |
+
# Placeholder dimensions for MiniGrid
|
| 22 |
+
state_dim = 2739 # FlatObsWrapper for 8x8 MiniGrid
|
| 23 |
+
action_dim = 7
|
| 24 |
+
model = HookedDT.from_config(state_dim, action_dim, n_layers=n_layers, n_heads=n_heads)
|
| 25 |
+
# model.load_state_dict(torch.load("models/mini_dt.pt"))
|
| 26 |
+
return model
|
| 27 |
+
|
| 28 |
+
model = load_model()
|
| 29 |
+
|
| 30 |
+
# Dashboard Tabs
|
| 31 |
+
tab1, tab2, tab3 = st.tabs(["Circuit Mapping", "Causal Intervention", "SAE Explorer"])
|
| 32 |
+
|
| 33 |
+
with tab1:
|
| 34 |
+
st.header("Direct Logit Attribution")
|
| 35 |
+
# Simulate a forward pass
|
| 36 |
+
if st.button("Run Attribution Analysis"):
|
| 37 |
+
# Dummy data for demo
|
| 38 |
+
states = torch.randn(1, 10, model.state_dim)
|
| 39 |
+
actions = torch.randn(1, 10, model.action_dim)
|
| 40 |
+
returns = torch.randn(1, 10, 1)
|
| 41 |
+
timesteps = torch.arange(10).unsqueeze(0)
|
| 42 |
+
|
| 43 |
+
# Capture cache
|
| 44 |
+
logits, cache = model.transformer.run_with_cache(
|
| 45 |
+
# Need to handle DT's interleaved forward pass here
|
| 46 |
+
# For demo, we'll just show the UI structure
|
| 47 |
+
torch.randn(1, 30, model.cfg.d_model)
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
engine = LogitAttributionEngine(model)
|
| 51 |
+
# dla = engine.calculate_dla(cache, target_logit_index=0)
|
| 52 |
+
|
| 53 |
+
# Placeholder plot
|
| 54 |
+
fig, ax = plt.subplots()
|
| 55 |
+
dla_mock = np.random.randn(n_layers, n_heads)
|
| 56 |
+
im = ax.imshow(dla_mock, cmap="RdBu_r")
|
| 57 |
+
plt.colorbar(im)
|
| 58 |
+
st.pyplot(fig)
|
| 59 |
+
|
| 60 |
+
with tab2:
|
| 61 |
+
st.header("Activation Patching")
|
| 62 |
+
col1, col2 = st.columns(2)
|
| 63 |
+
with col1:
|
| 64 |
+
st.subheader("Clean Run")
|
| 65 |
+
st.text("Input: Goal is visible")
|
| 66 |
+
with col2:
|
| 67 |
+
st.subheader("Corrupted Run")
|
| 68 |
+
st.text("Input: Goal is blocked")
|
| 69 |
+
|
| 70 |
+
layer_to_patch = st.selectbox("Select Layer", range(n_layers))
|
| 71 |
+
head_to_patch = st.selectbox("Select Head", range(n_heads))
|
| 72 |
+
|
| 73 |
+
if st.button("Apply Patch"):
|
| 74 |
+
st.success(f"Patched Layer {layer_to_patch}, Head {head_to_patch}")
|
| 75 |
+
st.metric("Probability Drop", "0.42", delta="-0.15")
|
| 76 |
+
|
| 77 |
+
with tab3:
|
| 78 |
+
st.header("SAE Monosemantic Latents")
|
| 79 |
+
st.info("SAE Integration Coming Soon (Phase 3)")
|
src/interpretability/attribution.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from jaxtyping import Float
|
| 3 |
+
from typing import Dict, List
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
import seaborn as sns
|
| 6 |
+
|
| 7 |
+
class LogitAttributionEngine:
|
| 8 |
+
"""
|
| 9 |
+
Calculates the Direct Logit Attribution (DLA) of transformer components.
|
| 10 |
+
"""
|
| 11 |
+
def __init__(self, model):
|
| 12 |
+
self.model = model
|
| 13 |
+
|
| 14 |
+
def calculate_dla(
|
| 15 |
+
self,
|
| 16 |
+
cache,
|
| 17 |
+
target_logit_index: int,
|
| 18 |
+
token_index: int = -1
|
| 19 |
+
) -> Dict[str, Float[torch.Tensor, "layer head"]]:
|
| 20 |
+
"""
|
| 21 |
+
Computes DLA for each head in the model.
|
| 22 |
+
Formula: DLA = Activation @ W_O @ W_U [target_logit]
|
| 23 |
+
"""
|
| 24 |
+
n_layers = self.model.cfg.n_layers
|
| 25 |
+
n_heads = self.model.cfg.n_heads
|
| 26 |
+
d_model = self.model.cfg.d_model
|
| 27 |
+
|
| 28 |
+
# Get the unembedding matrix for the action prediction head
|
| 29 |
+
# In our HookedDT, the prediction head is a Linear layer: self.predict_action[0].weight
|
| 30 |
+
W_U = self.model.predict_action[0].weight[target_logit_index] # [d_model]
|
| 31 |
+
|
| 32 |
+
dla_results = torch.zeros((n_layers, n_heads))
|
| 33 |
+
|
| 34 |
+
for layer in range(n_layers):
|
| 35 |
+
# Head outputs from cache: [batch, pos, head, d_model]
|
| 36 |
+
# For HookedTransformer, it's usually 'blocks.{layer}.attn.hook_result'
|
| 37 |
+
head_outputs = cache[f"blocks.{layer}.attn.hook_result"] # [batch, pos, head, d_model]
|
| 38 |
+
|
| 39 |
+
# We take the token_index (usually the last state token)
|
| 40 |
+
# In interleaved (R, S, A), S_t is at 3t + 1
|
| 41 |
+
# If we want the last predicted action, we look at the last state token's output
|
| 42 |
+
|
| 43 |
+
last_token_output = head_outputs[0, token_index] # [head, d_model]
|
| 44 |
+
|
| 45 |
+
# Attribution: projection onto W_U
|
| 46 |
+
attribution = torch.matmul(last_token_output, W_U) # [head]
|
| 47 |
+
dla_results[layer] = attribution
|
| 48 |
+
|
| 49 |
+
return dla_results
|
| 50 |
+
|
| 51 |
+
def plot_dla(self, dla_results: torch.Tensor, title="Direct Logit Attribution"):
|
| 52 |
+
plt.figure(figsize=(10, 6))
|
| 53 |
+
sns.heatmap(dla_results.detach().cpu().numpy(), annot=True, fmt=".2f", cmap="RdBu_r", center=0)
|
| 54 |
+
plt.xlabel("Head")
|
| 55 |
+
plt.ylabel("Layer")
|
| 56 |
+
plt.title(title)
|
| 57 |
+
plt.show()
|
src/interpretability/induction_scan.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from typing import List, Tuple
|
| 3 |
+
|
| 4 |
+
class InductionScanner:
|
| 5 |
+
"""
|
| 6 |
+
Automated scan for Induction Heads.
|
| 7 |
+
Induction heads attend to the token that followed the current token's previous occurrence.
|
| 8 |
+
"""
|
| 9 |
+
def __init__(self, model):
|
| 10 |
+
self.model = model
|
| 11 |
+
|
| 12 |
+
def scan(self, cache, sequence: torch.Tensor) -> List[Tuple[int, int]]:
|
| 13 |
+
"""
|
| 14 |
+
Scans all heads for 'Induction' behavior on a given sequence.
|
| 15 |
+
Logic: For token S, find previous occurrence of S at index i.
|
| 16 |
+
Check if current token attends to token at i+1.
|
| 17 |
+
"""
|
| 18 |
+
n_layers = self.model.cfg.n_layers
|
| 19 |
+
n_heads = self.model.cfg.n_heads
|
| 20 |
+
seq_len = sequence.shape[1]
|
| 21 |
+
|
| 22 |
+
induction_heads = []
|
| 23 |
+
|
| 24 |
+
# Find repeated tokens
|
| 25 |
+
# For simplicity, we assume 'sequence' is the flattened list of tokens (or states)
|
| 26 |
+
# In DT, this is more complex due to interleaving.
|
| 27 |
+
# Let's look at state tokens specifically.
|
| 28 |
+
|
| 29 |
+
for layer in range(n_layers):
|
| 30 |
+
attn_pattern = cache[f"blocks.{layer}.attn.hook_pattern"] # [batch, head, query_pos, key_pos]
|
| 31 |
+
|
| 32 |
+
for head in range(n_heads):
|
| 33 |
+
score = self._calculate_induction_score(attn_pattern[0, head])
|
| 34 |
+
if score > 0.5: # Threshold for induction
|
| 35 |
+
induction_heads.append((layer, head))
|
| 36 |
+
|
| 37 |
+
return induction_heads
|
| 38 |
+
|
| 39 |
+
def _calculate_induction_score(self, pattern: torch.Tensor) -> float:
|
| 40 |
+
"""
|
| 41 |
+
Simplified induction score.
|
| 42 |
+
Checks if the attention is shifted by 1 relative to a diagonal.
|
| 43 |
+
This is a heuristic; more robust methods exist in TransformerLens.
|
| 44 |
+
"""
|
| 45 |
+
# In a real scenario, we'd use a sequence like [A, B, C, ..., A]
|
| 46 |
+
# and check if the second A attends to B.
|
| 47 |
+
# Here we just return a placeholder logic for the scan structure.
|
| 48 |
+
return torch.diagonal(pattern, offset=-1).mean().item()
|
src/interpretability/patching.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from typing import Callable, List, Optional
|
| 3 |
+
from transformer_lens import HookedTransformer
|
| 4 |
+
|
| 5 |
+
class ActivationPatcher:
|
| 6 |
+
"""
|
| 7 |
+
Interface for causal interventions via activation patching.
|
| 8 |
+
"""
|
| 9 |
+
def __init__(self, model):
|
| 10 |
+
self.model = model
|
| 11 |
+
|
| 12 |
+
def patch_head(
|
| 13 |
+
self,
|
| 14 |
+
clean_inputs: dict,
|
| 15 |
+
corrupted_cache: dict,
|
| 16 |
+
layer: int,
|
| 17 |
+
head_index: int,
|
| 18 |
+
target_token_index: int = -1
|
| 19 |
+
):
|
| 20 |
+
"""
|
| 21 |
+
Replaces the output of a specific head in a clean run with values from a corrupted run.
|
| 22 |
+
"""
|
| 23 |
+
def patch_hook(value, hook):
|
| 24 |
+
# value: [batch, pos, head, d_model]
|
| 25 |
+
corrupted_value = corrupted_cache[hook.name]
|
| 26 |
+
value[:, target_token_index, head_index, :] = corrupted_value[:, target_token_index, head_index, :]
|
| 27 |
+
return value
|
| 28 |
+
|
| 29 |
+
hook_name = f"blocks.{layer}.attn.hook_result"
|
| 30 |
+
|
| 31 |
+
# Run the model with the hook
|
| 32 |
+
with self.model.transformer.hooks(fwd_hooks=[(hook_name, patch_hook)]):
|
| 33 |
+
patched_outputs = self.model(**clean_inputs)
|
| 34 |
+
|
| 35 |
+
return patched_outputs
|
| 36 |
+
|
| 37 |
+
def calculate_probability_drop(
|
| 38 |
+
self,
|
| 39 |
+
clean_probs: torch.Tensor,
|
| 40 |
+
patched_probs: torch.Tensor,
|
| 41 |
+
correct_action_index: int
|
| 42 |
+
) -> float:
|
| 43 |
+
"""
|
| 44 |
+
Measures the impact of patching on the target action probability.
|
| 45 |
+
"""
|
| 46 |
+
clean_val = clean_probs[0, -1, correct_action_index].item()
|
| 47 |
+
patched_val = patched_probs[0, -1, correct_action_index].item()
|
| 48 |
+
return clean_val - patched_val
|
src/interpretability/steering.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from typing import Optional
|
| 4 |
+
|
| 5 |
+
class RTGSteerer:
|
| 6 |
+
"""
|
| 7 |
+
Enables 'Behavioral Steering' by manipulating Reward-to-Go (RTG) tokens.
|
| 8 |
+
"""
|
| 9 |
+
def __init__(self, model):
|
| 10 |
+
self.model = model
|
| 11 |
+
|
| 12 |
+
def steer(
|
| 13 |
+
self,
|
| 14 |
+
states: torch.Tensor,
|
| 15 |
+
actions: torch.Tensor,
|
| 16 |
+
base_rtg: torch.Tensor,
|
| 17 |
+
steering_vector: torch.Tensor,
|
| 18 |
+
alpha: float = 1.0
|
| 19 |
+
):
|
| 20 |
+
"""
|
| 21 |
+
Adds a steering vector to the RTG embeddings.
|
| 22 |
+
RTG_new = RTG_base + alpha * steering_vector
|
| 23 |
+
"""
|
| 24 |
+
# Embed base RTG
|
| 25 |
+
with torch.no_grad():
|
| 26 |
+
rtg_emb = self.model.embed_return(base_rtg)
|
| 27 |
+
|
| 28 |
+
# Apply steering
|
| 29 |
+
steered_rtg_emb = rtg_emb + alpha * steering_vector
|
| 30 |
+
|
| 31 |
+
# Hook the model to use the steered RTG
|
| 32 |
+
# This requires a slightly more complex hook in HookedDT
|
| 33 |
+
# For now, we returns the steered embedding to be used in a custom forward pass
|
| 34 |
+
return steered_rtg_emb
|
| 35 |
+
|
| 36 |
+
def find_success_vector(self, high_reward_cache, low_reward_cache):
|
| 37 |
+
"""
|
| 38 |
+
Identifies the 'Success Vector' by comparing high vs low reward activations.
|
| 39 |
+
Vector = Mean(High Reward Residual) - Mean(Low Reward Residual)
|
| 40 |
+
"""
|
| 41 |
+
high_res = high_reward_cache["blocks.0.hook_resid_post"].mean(dim=(0, 1))
|
| 42 |
+
low_res = low_reward_cache["blocks.0.hook_resid_post"].mean(dim=(0, 1))
|
| 43 |
+
return high_res - low_res
|
src/models/hooked_dt.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from transformer_lens import HookedTransformer, HookedTransformerConfig
|
| 4 |
+
from jaxtyping import Float, Int
|
| 5 |
+
from typing import Optional, Union, List
|
| 6 |
+
|
| 7 |
+
class HookedDT(nn.Module):
|
| 8 |
+
"""
|
| 9 |
+
A Decision Transformer implementation wrapped in TransformerLens logic.
|
| 10 |
+
Supports State, Action, and Reward-to-Go (RTG) tokens.
|
| 11 |
+
"""
|
| 12 |
+
def __init__(
|
| 13 |
+
self,
|
| 14 |
+
cfg: HookedTransformerConfig,
|
| 15 |
+
state_dim: int,
|
| 16 |
+
action_dim: int,
|
| 17 |
+
max_length: int = 30,
|
| 18 |
+
max_ep_len: int = 1000,
|
| 19 |
+
):
|
| 20 |
+
super().__init__()
|
| 21 |
+
self.cfg = cfg
|
| 22 |
+
self.state_dim = state_dim
|
| 23 |
+
self.action_dim = action_dim
|
| 24 |
+
self.max_length = max_length
|
| 25 |
+
|
| 26 |
+
# HookedTransformer for the core transformer blocks
|
| 27 |
+
self.transformer = HookedTransformer(cfg)
|
| 28 |
+
|
| 29 |
+
# Custom embeddings for DT
|
| 30 |
+
self.embed_return = nn.Linear(1, cfg.d_model)
|
| 31 |
+
self.embed_state = nn.Linear(state_dim, cfg.d_model)
|
| 32 |
+
self.embed_action = nn.Linear(action_dim, cfg.d_model)
|
| 33 |
+
|
| 34 |
+
self.embed_ln = nn.LayerNorm(cfg.d_model)
|
| 35 |
+
|
| 36 |
+
# Prediction heads
|
| 37 |
+
self.predict_action = nn.Sequential(
|
| 38 |
+
nn.Linear(cfg.d_model, action_dim)
|
| 39 |
+
)
|
| 40 |
+
self.predict_return = nn.Sequential(
|
| 41 |
+
nn.Linear(cfg.d_model, 1)
|
| 42 |
+
)
|
| 43 |
+
self.predict_state = nn.Sequential(
|
| 44 |
+
nn.Linear(cfg.d_model, state_dim)
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
def forward(
|
| 48 |
+
self,
|
| 49 |
+
states: Float[torch.Tensor, "batch seq state_dim"],
|
| 50 |
+
actions: Float[torch.Tensor, "batch seq action_dim"],
|
| 51 |
+
returns_to_go: Float[torch.Tensor, "batch seq 1"],
|
| 52 |
+
timesteps: Int[torch.Tensor, "batch seq"],
|
| 53 |
+
attention_mask: Optional[Float[torch.Tensor, "batch seq"]] = None,
|
| 54 |
+
):
|
| 55 |
+
batch_size, seq_len, _ = states.shape
|
| 56 |
+
|
| 57 |
+
# Embed tokens
|
| 58 |
+
state_embeddings = self.embed_state(states)
|
| 59 |
+
action_embeddings = self.embed_action(actions)
|
| 60 |
+
returns_embeddings = self.embed_return(returns_to_go)
|
| 61 |
+
|
| 62 |
+
# In DT, we interleave (R, S, A)
|
| 63 |
+
# Sequence: (R1, S1, A1, R2, S2, A2, ...)
|
| 64 |
+
stacked_inputs = torch.stack(
|
| 65 |
+
(returns_embeddings, state_embeddings, action_embeddings), dim=2
|
| 66 |
+
).reshape(batch_size, 3 * seq_len, self.cfg.d_model)
|
| 67 |
+
|
| 68 |
+
stacked_inputs = self.embed_ln(stacked_inputs)
|
| 69 |
+
|
| 70 |
+
# Add positional embeddings manually or via HookedTransformer
|
| 71 |
+
# DT usually uses learned positional embeddings for timesteps
|
| 72 |
+
# HookedTransformer usually handles this via its own embed_pos
|
| 73 |
+
# We'll use the timestep info to get positional embeddings
|
| 74 |
+
|
| 75 |
+
# For simplicity, let's assume we can use HookedTransformer's forward
|
| 76 |
+
# but we need to handle the interleaved nature.
|
| 77 |
+
|
| 78 |
+
# We pass the stacked_inputs directly to the transformer blocks
|
| 79 |
+
# We use run_with_cache or standard forward based on whether we need the cache
|
| 80 |
+
# For TransformerLens, we need to specify that we are passing embeddings
|
| 81 |
+
|
| 82 |
+
# Note: HookedTransformer expects [batch, pos, d_model] if input is embeddings
|
| 83 |
+
# We need to set use_local_embeddings=True or similar if we want to bypass default embeds
|
| 84 |
+
|
| 85 |
+
# A better way is to use model.blocks directly or use the hook_embed to inject
|
| 86 |
+
|
| 87 |
+
def embed_hook(value, hook):
|
| 88 |
+
return stacked_inputs
|
| 89 |
+
|
| 90 |
+
# We inject our interleaved embeddings into the 'hook_embed'
|
| 91 |
+
# and pass a dummy tensor of the right shape to the transformer
|
| 92 |
+
dummy_input = torch.zeros((batch_size, 3 * seq_len), dtype=torch.long, device=stacked_inputs.device)
|
| 93 |
+
|
| 94 |
+
# We want the residual stream after the last block
|
| 95 |
+
# HookedTransformer.run_with_cache returns (output, cache)
|
| 96 |
+
# We can also use return_type="residual" or similar in some versions,
|
| 97 |
+
# but let's just use the cache or the direct output if we set it up correctly.
|
| 98 |
+
|
| 99 |
+
# In TransformerLens, the output of the forward pass is usually the logits.
|
| 100 |
+
# We want the 'hook_resid_post' of the last block.
|
| 101 |
+
|
| 102 |
+
last_block_hook = f"blocks.{self.cfg.n_layers - 1}.hook_resid_post"
|
| 103 |
+
|
| 104 |
+
with self.transformer.hooks(fwd_hooks=[("hook_embed", embed_hook)]):
|
| 105 |
+
_, cache = self.transformer.run_with_cache(
|
| 106 |
+
dummy_input,
|
| 107 |
+
names_filter=lambda name: name == last_block_hook
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
transformer_outputs = cache[last_block_hook]
|
| 111 |
+
|
| 112 |
+
# Reshape back to (batch, seq, 3, d_model)
|
| 113 |
+
x = transformer_outputs.reshape(batch_size, seq_len, 3, self.cfg.d_model)
|
| 114 |
+
|
| 115 |
+
# Predict (A from S, S from A, R from S?)
|
| 116 |
+
# Standard DT: Action is predicted from State token
|
| 117 |
+
action_preds = self.predict_action(x[:, :, 1]) # predict next action from state
|
| 118 |
+
return_preds = self.predict_return(x[:, :, 2]) # predict next return from action
|
| 119 |
+
state_preds = self.predict_state(x[:, :, 2]) # predict next state from action
|
| 120 |
+
|
| 121 |
+
return action_preds, state_preds, return_preds
|
| 122 |
+
|
| 123 |
+
@classmethod
|
| 124 |
+
def from_config(cls, state_dim, action_dim, n_layers=2, n_heads=4, d_model=128):
|
| 125 |
+
cfg = HookedTransformerConfig(
|
| 126 |
+
n_layers=n_layers,
|
| 127 |
+
d_model=d_model,
|
| 128 |
+
n_ctx=300, # Max sequence length * 3
|
| 129 |
+
d_head=d_model // n_heads,
|
| 130 |
+
n_heads=n_heads,
|
| 131 |
+
d_vocab=10, # Dummy value, we use custom embeddings
|
| 132 |
+
act_fn="relu", # DT original uses ReLU or GeLU
|
| 133 |
+
d_mlp=d_model * 4,
|
| 134 |
+
normalization_type="LN",
|
| 135 |
+
device="cuda" if torch.cuda.is_available() else "cpu"
|
| 136 |
+
)
|
| 137 |
+
return cls(cfg, state_dim, action_dim)
|
tests/test_components.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytest
|
| 2 |
+
import torch
|
| 3 |
+
from src.models.hooked_dt import HookedDT
|
| 4 |
+
from src.interpretability.attribution import LogitAttributionEngine
|
| 5 |
+
from transformer_lens import HookedTransformerConfig
|
| 6 |
+
|
| 7 |
+
def test_hooked_dt_forward():
|
| 8 |
+
state_dim = 10
|
| 9 |
+
action_dim = 5
|
| 10 |
+
seq_len = 5
|
| 11 |
+
batch_size = 2
|
| 12 |
+
|
| 13 |
+
model = HookedDT.from_config(state_dim, action_dim, n_layers=1, n_heads=2, d_model=32)
|
| 14 |
+
|
| 15 |
+
states = torch.randn(batch_size, seq_len, state_dim)
|
| 16 |
+
actions = torch.randn(batch_size, seq_len, action_dim)
|
| 17 |
+
returns = torch.randn(batch_size, seq_len, 1)
|
| 18 |
+
timesteps = torch.arange(seq_len).repeat(batch_size, 1)
|
| 19 |
+
|
| 20 |
+
action_preds, state_preds, return_preds = model(states, actions, returns, timesteps)
|
| 21 |
+
|
| 22 |
+
assert action_preds.shape == (batch_size, seq_len, action_dim)
|
| 23 |
+
assert state_preds.shape == (batch_size, seq_len, state_dim)
|
| 24 |
+
assert return_preds.shape == (batch_size, seq_len, 1)
|
| 25 |
+
|
| 26 |
+
def test_logit_attribution_shape():
|
| 27 |
+
state_dim = 10
|
| 28 |
+
action_dim = 5
|
| 29 |
+
model = HookedDT.from_config(state_dim, action_dim, n_layers=2, n_heads=4, d_model=32)
|
| 30 |
+
engine = LogitAttributionEngine(model)
|
| 31 |
+
|
| 32 |
+
# Mock cache
|
| 33 |
+
cache = {}
|
| 34 |
+
for l in range(2):
|
| 35 |
+
cache[f"blocks.{l}.attn.hook_result"] = torch.randn(1, 15, 4, 32)
|
| 36 |
+
|
| 37 |
+
dla = engine.calculate_dla(cache, target_logit_index=0, token_index=-1)
|
| 38 |
+
assert dla.shape == (2, 4)
|
| 39 |
+
|
| 40 |
+
if __name__ == "__main__":
|
| 41 |
+
pytest.main([__file__])
|