Spaces:
Running
Running
Commit ·
731ae64
1
Parent(s): 0346604
refactor: added logic as comments
Browse files
- README.md +15 -16
- scripts/train_dt.py +1 -9
- src/dashboard/app.py +1 -12
- src/interpretability/attribution.py +8 -17
- src/interpretability/induction_scan.py +7 -19
- src/interpretability/patching.py +0 -1
- src/interpretability/sae_manager.py +6 -9
- src/models/hooked_dt.py +9 -40
README.md
CHANGED
|
@@ -1,32 +1,31 @@
|
|
| 1 |
# DT-Circuits: Mechanistic Interpretability for Decision Transformers
|
| 2 |
|
| 3 |
-
DT-Circuits is a
|
| 4 |
|
| 5 |
-
The
|
| 6 |
|
| 7 |
## Core Capabilities
|
| 8 |
|
| 9 |
### 1. Circuit Foundation
|
| 10 |
-
- **Hooked-DT
|
| 11 |
-
- **Direct Logit Attribution (DLA)**:
|
| 12 |
-
- **Induction Head Discovery**:
|
| 13 |
|
| 14 |
### 2. Causal Interventions
|
| 15 |
-
- **Activation Patching**:
|
| 16 |
-
- **
|
| 17 |
-
- **Steering Library**: A persistent library of pre-calculated vectors (e.g., success_vector, exploration_vector) that can be injected at inference time to manipulate agent behavior without retraining.
|
| 18 |
|
| 19 |
-
### 3.
|
| 20 |
-
- **
|
| 21 |
-
- **
|
| 22 |
|
| 23 |
## Technical Architecture
|
| 24 |
|
| 25 |
-
The platform
|
| 26 |
-
- **Data Layer**: PPO Trajectory Harvester for
|
| 27 |
-
- **Model Layer**:
|
| 28 |
-
- **Interpretability Layer**:
|
| 29 |
-
- **Visualization Layer**:
|
| 30 |
|
| 31 |
## Getting Started
|
| 32 |
|
|
|
|
| 1 |
# DT-Circuits: Mechanistic Interpretability for Decision Transformers
|
| 2 |
|
| 3 |
+
DT-Circuits is a framework for mechanistic interpretability of Decision Transformers (DT). Using TransformerLens, it enables mapping neural circuits, decomposing activations with Sparse Autoencoders (SAEs), and performing causal interventions on agent decision-making.
|
| 4 |
|
| 5 |
+
The goal is to understand how Return-to-Go, State, and Action tokens are processed within the residual stream, moving beyond basic behavioral observation.
|
| 6 |
|
| 7 |
## Core Capabilities
|
| 8 |
|
| 9 |
### 1. Circuit Foundation
|
| 10 |
+
- **Hooked-DT**: A Decision Transformer implementation wrapped in TransformerLens for access to internal activations and weights.
|
| 11 |
+
- **Direct Logit Attribution (DLA)**: Quantifies the contribution of individual heads and MLP layers to action logits.
|
| 12 |
+
- **Induction Head Discovery**: Tools to identify heads responsible for temporal pattern recognition.
|
| 13 |
|
| 14 |
### 2. Causal Interventions
|
| 15 |
+
- **Activation Patching**: Replaces activations between clean and corrupted runs to identify causal paths.
|
| 16 |
+
- **Steering**: Generates and applies steering vectors (e.g., via Contrastive Activation Addition) to manipulate agent behavior at inference time.
|
|
|
|
| 17 |
|
| 18 |
+
### 3. SAEs & Safety
|
| 19 |
+
- **SAE Integration**: Tools to train and deploy SAEs on the residual stream to find monosemantic latents.
|
| 20 |
+
- **Anomaly Detection**: Uses SAE reconstruction error to detect out-of-distribution (OOD) states.
|
| 21 |
|
| 22 |
## Technical Architecture
|
| 23 |
|
| 24 |
+
The platform consists of:
|
| 25 |
+
- **Data Layer**: PPO Trajectory Harvester for collecting expert demonstrations (e.g., MiniGrid).
|
| 26 |
+
- **Model Layer**: HookedDT implementation.
|
| 27 |
+
- **Interpretability Layer**: Modules for attribution, patching, SAE management, and steering.
|
| 28 |
+
- **Visualization Layer**: Streamlit dashboard for real-time monitoring and intervention.
|
| 29 |
|
| 30 |
## Getting Started
|
| 31 |
|
scripts/train_dt.py
CHANGED
|
@@ -7,13 +7,11 @@ import numpy as np
|
|
| 7 |
from tqdm import tqdm
|
| 8 |
|
| 9 |
def train():
|
| 10 |
-
# 1. Collect Data
|
| 11 |
harvester = PPOHarvester(model_path="ppo_minigrid_teacher.zip")
|
| 12 |
trajectories = harvester.collect_trajectories(num_episodes=100)
|
| 13 |
|
| 14 |
-
# 2. Setup Model
|
| 15 |
state_dim = trajectories[0]["observations"].shape[1]
|
| 16 |
-
action_dim = 7 # MiniGrid
|
| 17 |
|
| 18 |
model = HookedDT.from_config(
|
| 19 |
state_dim=state_dim,
|
|
@@ -26,24 +24,18 @@ def train():
|
|
| 26 |
optimizer = optim.AdamW(model.parameters(), lr=1e-4)
|
| 27 |
criterion = nn.CrossEntropyLoss()
|
| 28 |
|
| 29 |
-
# 3. Training Loop (Simplified)
|
| 30 |
model.train()
|
| 31 |
for epoch in range(10):
|
| 32 |
total_loss = 0
|
| 33 |
for traj in tqdm(trajectories, desc=f"Epoch {epoch}"):
|
| 34 |
states = torch.from_numpy(traj["observations"]).float().unsqueeze(0)
|
| 35 |
actions = torch.from_numpy(traj["actions"]).long().unsqueeze(0)
|
| 36 |
-
# One-hot actions for input
|
| 37 |
actions_one_hot = torch.nn.functional.one_hot(actions, num_classes=action_dim).float()
|
| 38 |
returns = torch.from_numpy(traj["rewards"]).float().unsqueeze(0).unsqueeze(-1)
|
| 39 |
timesteps = torch.arange(states.shape[1]).unsqueeze(0)
|
| 40 |
|
| 41 |
-
# Mask (dummy for now)
|
| 42 |
-
|
| 43 |
action_preds, _, _ = model(states, actions_one_hot, returns, timesteps)
|
| 44 |
|
| 45 |
-
# Target actions (shifted by 1 for next action prediction)
|
| 46 |
-
# Standard DT predicts a_t from s_t
|
| 47 |
loss = criterion(action_preds.view(-1, action_dim), actions.view(-1))
|
| 48 |
|
| 49 |
optimizer.zero_grad()
|
|
|
|
| 7 |
from tqdm import tqdm
|
| 8 |
|
| 9 |
def train():
|
|
|
|
| 10 |
harvester = PPOHarvester(model_path="ppo_minigrid_teacher.zip")
|
| 11 |
trajectories = harvester.collect_trajectories(num_episodes=100)
|
| 12 |
|
|
|
|
| 13 |
state_dim = trajectories[0]["observations"].shape[1]
|
| 14 |
+
action_dim = 7 # MiniGrid
|
| 15 |
|
| 16 |
model = HookedDT.from_config(
|
| 17 |
state_dim=state_dim,
|
|
|
|
| 24 |
optimizer = optim.AdamW(model.parameters(), lr=1e-4)
|
| 25 |
criterion = nn.CrossEntropyLoss()
|
| 26 |
|
|
|
|
| 27 |
model.train()
|
| 28 |
for epoch in range(10):
|
| 29 |
total_loss = 0
|
| 30 |
for traj in tqdm(trajectories, desc=f"Epoch {epoch}"):
|
| 31 |
states = torch.from_numpy(traj["observations"]).float().unsqueeze(0)
|
| 32 |
actions = torch.from_numpy(traj["actions"]).long().unsqueeze(0)
|
|
|
|
| 33 |
actions_one_hot = torch.nn.functional.one_hot(actions, num_classes=action_dim).float()
|
| 34 |
returns = torch.from_numpy(traj["rewards"]).float().unsqueeze(0).unsqueeze(-1)
|
| 35 |
timesteps = torch.arange(states.shape[1]).unsqueeze(0)
|
| 36 |
|
|
|
|
|
|
|
| 37 |
action_preds, _, _ = model(states, actions_one_hot, returns, timesteps)
|
| 38 |
|
|
|
|
|
|
|
| 39 |
loss = criterion(action_preds.view(-1, action_dim), actions.view(-1))
|
| 40 |
|
| 41 |
optimizer.zero_grad()
|
src/dashboard/app.py
CHANGED
|
@@ -10,47 +10,36 @@ st.set_page_config(page_title="DT-Explorer", layout="wide")
|
|
| 10 |
|
| 11 |
st.title("DT-Explorer: Mechanistic Interpretability for Decision Transformers")
|
| 12 |
|
| 13 |
-
# Sidebar for controls
|
| 14 |
st.sidebar.header("Model Configuration")
|
| 15 |
n_layers = st.sidebar.slider("Layers", 1, 12, 1)
|
| 16 |
n_heads = st.sidebar.slider("Heads", 1, 8, 4)
|
| 17 |
|
| 18 |
-
# Load Model
|
| 19 |
@st.cache_resource
|
| 20 |
def load_model():
|
| 21 |
-
# Placeholder dimensions for MiniGrid
|
| 22 |
state_dim = 2739 # FlatObsWrapper for 8x8 MiniGrid
|
| 23 |
action_dim = 7
|
| 24 |
model = HookedDT.from_config(state_dim, action_dim, n_layers=n_layers, n_heads=n_heads)
|
| 25 |
-
# model.load_state_dict(torch.load("models/mini_dt.pt"))
|
| 26 |
return model
|
| 27 |
|
| 28 |
model = load_model()
|
| 29 |
|
| 30 |
-
# Dashboard Tabs
|
| 31 |
tab1, tab2, tab3 = st.tabs(["Circuit Mapping", "Causal Intervention", "SAE Explorer"])
|
| 32 |
|
| 33 |
with tab1:
|
| 34 |
st.header("Direct Logit Attribution")
|
| 35 |
-
# Simulate a forward pass
|
| 36 |
if st.button("Run Attribution Analysis"):
|
| 37 |
-
#
|
| 38 |
states = torch.randn(1, 10, model.state_dim)
|
| 39 |
actions = torch.randn(1, 10, model.action_dim)
|
| 40 |
returns = torch.randn(1, 10, 1)
|
| 41 |
timesteps = torch.arange(10).unsqueeze(0)
|
| 42 |
|
| 43 |
-
# Capture cache
|
| 44 |
logits, cache = model.transformer.run_with_cache(
|
| 45 |
-
# Need to handle DT's interleaved forward pass here
|
| 46 |
-
# For demo, we'll just show the UI structure
|
| 47 |
torch.randn(1, 30, model.cfg.d_model)
|
| 48 |
)
|
| 49 |
|
| 50 |
engine = LogitAttributionEngine(model)
|
| 51 |
-
# dla = engine.calculate_dla(cache, target_logit_index=0)
|
| 52 |
|
| 53 |
-
# Placeholder plot
|
| 54 |
fig, ax = plt.subplots()
|
| 55 |
dla_mock = np.random.randn(n_layers, n_heads)
|
| 56 |
im = ax.imshow(dla_mock, cmap="RdBu_r")
|
|
|
|
| 10 |
|
| 11 |
st.title("DT-Explorer: Mechanistic Interpretability for Decision Transformers")
|
| 12 |
|
|
|
|
| 13 |
st.sidebar.header("Model Configuration")
|
| 14 |
n_layers = st.sidebar.slider("Layers", 1, 12, 1)
|
| 15 |
n_heads = st.sidebar.slider("Heads", 1, 8, 4)
|
| 16 |
|
|
|
|
| 17 |
@st.cache_resource
|
| 18 |
def load_model():
|
|
|
|
| 19 |
state_dim = 2739 # FlatObsWrapper for 8x8 MiniGrid
|
| 20 |
action_dim = 7
|
| 21 |
model = HookedDT.from_config(state_dim, action_dim, n_layers=n_layers, n_heads=n_heads)
|
|
|
|
| 22 |
return model
|
| 23 |
|
| 24 |
model = load_model()
|
| 25 |
|
|
|
|
| 26 |
tab1, tab2, tab3 = st.tabs(["Circuit Mapping", "Causal Intervention", "SAE Explorer"])
|
| 27 |
|
| 28 |
with tab1:
|
| 29 |
st.header("Direct Logit Attribution")
|
|
|
|
| 30 |
if st.button("Run Attribution Analysis"):
|
| 31 |
+
# Mock data for demo
|
| 32 |
states = torch.randn(1, 10, model.state_dim)
|
| 33 |
actions = torch.randn(1, 10, model.action_dim)
|
| 34 |
returns = torch.randn(1, 10, 1)
|
| 35 |
timesteps = torch.arange(10).unsqueeze(0)
|
| 36 |
|
|
|
|
| 37 |
logits, cache = model.transformer.run_with_cache(
|
|
|
|
|
|
|
| 38 |
torch.randn(1, 30, model.cfg.d_model)
|
| 39 |
)
|
| 40 |
|
| 41 |
engine = LogitAttributionEngine(model)
|
|
|
|
| 42 |
|
|
|
|
| 43 |
fig, ax = plt.subplots()
|
| 44 |
dla_mock = np.random.randn(n_layers, n_heads)
|
| 45 |
im = ax.imshow(dla_mock, cmap="RdBu_r")
|
src/interpretability/attribution.py
CHANGED
|
@@ -18,33 +18,24 @@ class LogitAttributionEngine:
|
|
| 18 |
token_index: int = -1
|
| 19 |
) -> Dict[str, Float[torch.Tensor, "layer head"]]:
|
| 20 |
"""
|
| 21 |
-
Computes DLA for each head
|
| 22 |
-
Formula: DLA = Activation @ W_O @ W_U [target_logit]
|
| 23 |
"""
|
| 24 |
n_layers = self.model.cfg.n_layers
|
| 25 |
n_heads = self.model.cfg.n_heads
|
| 26 |
-
d_model = self.model.cfg.d_model
|
| 27 |
|
| 28 |
-
#
|
| 29 |
-
|
| 30 |
-
W_U = self.model.predict_action[0].weight[target_logit_index] # [d_model]
|
| 31 |
|
| 32 |
dla_results = torch.zeros((n_layers, n_heads))
|
| 33 |
|
| 34 |
for layer in range(n_layers):
|
| 35 |
-
#
|
| 36 |
-
|
| 37 |
-
head_outputs = cache[f"blocks.{layer}.attn.hook_result"] # [batch, pos, head, d_model]
|
| 38 |
|
| 39 |
-
#
|
| 40 |
-
|
| 41 |
-
# If we want the last predicted action, we look at the last state token's output
|
| 42 |
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
# Attribution: projection onto W_U
|
| 46 |
-
attribution = torch.matmul(last_token_output, W_U) # [head]
|
| 47 |
-
dla_results[layer] = attribution
|
| 48 |
|
| 49 |
return dla_results
|
| 50 |
|
|
|
|
| 18 |
token_index: int = -1
|
| 19 |
) -> Dict[str, Float[torch.Tensor, "layer head"]]:
|
| 20 |
"""
|
| 21 |
+
Computes DLA for each head: Activation @ W_O @ W_U [target_logit]
|
|
|
|
| 22 |
"""
|
| 23 |
n_layers = self.model.cfg.n_layers
|
| 24 |
n_heads = self.model.cfg.n_heads
|
|
|
|
| 25 |
|
| 26 |
+
# Action prediction unembedding
|
| 27 |
+
W_U = self.model.predict_action[0].weight[target_logit_index]
|
|
|
|
| 28 |
|
| 29 |
dla_results = torch.zeros((n_layers, n_heads))
|
| 30 |
|
| 31 |
for layer in range(n_layers):
|
| 32 |
+
# [batch, pos, head, d_model]
|
| 33 |
+
head_outputs = cache[f"blocks.{layer}.attn.hook_result"]
|
|
|
|
| 34 |
|
| 35 |
+
# S_t is at 3t + 1 in interleaved (R, S, A)
|
| 36 |
+
last_token_output = head_outputs[0, token_index]
|
|
|
|
| 37 |
|
| 38 |
+
dla_results[layer] = torch.matmul(last_token_output, W_U)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
|
| 40 |
return dla_results
|
| 41 |
|
src/interpretability/induction_scan.py
CHANGED
|
@@ -3,46 +3,34 @@ from typing import List, Tuple
|
|
| 3 |
|
| 4 |
class InductionScanner:
|
| 5 |
"""
|
| 6 |
-
|
| 7 |
-
Induction heads attend to the token that followed the current token's previous occurrence.
|
| 8 |
"""
|
| 9 |
def __init__(self, model):
|
| 10 |
self.model = model
|
| 11 |
|
| 12 |
def scan(self, cache, sequence: torch.Tensor) -> List[Tuple[int, int]]:
|
| 13 |
"""
|
| 14 |
-
Scans
|
| 15 |
-
Logic: For token S, find previous occurrence of S at index i.
|
| 16 |
-
Check if current token attends to token at i+1.
|
| 17 |
"""
|
| 18 |
n_layers = self.model.cfg.n_layers
|
| 19 |
n_heads = self.model.cfg.n_heads
|
| 20 |
-
seq_len = sequence.shape[1]
|
| 21 |
|
| 22 |
induction_heads = []
|
| 23 |
|
| 24 |
-
# Find repeated tokens
|
| 25 |
-
# For simplicity, we assume 'sequence' is the flattened list of tokens (or states)
|
| 26 |
-
# In DT, this is more complex due to interleaving.
|
| 27 |
-
# Let's look at state tokens specifically.
|
| 28 |
-
|
| 29 |
for layer in range(n_layers):
|
| 30 |
-
|
|
|
|
| 31 |
|
| 32 |
for head in range(n_heads):
|
| 33 |
score = self._calculate_induction_score(attn_pattern[0, head])
|
| 34 |
-
if score > 0.5:
|
| 35 |
induction_heads.append((layer, head))
|
| 36 |
|
| 37 |
return induction_heads
|
| 38 |
|
| 39 |
def _calculate_induction_score(self, pattern: torch.Tensor) -> float:
|
| 40 |
"""
|
| 41 |
-
|
| 42 |
-
Checks if the attention is shifted by 1 relative to a diagonal.
|
| 43 |
-
This is a heuristic; more robust methods exist in TransformerLens.
|
| 44 |
"""
|
| 45 |
-
#
|
| 46 |
-
# and check if the second A attends to B.
|
| 47 |
-
# Here we just return a placeholder logic for the scan structure.
|
| 48 |
return torch.diagonal(pattern, offset=-1).mean().item()
|
|
|
|
| 3 |
|
| 4 |
class InductionScanner:
|
| 5 |
"""
|
| 6 |
+
Identifies induction heads that attend to tokens following a previous occurrence.
|
|
|
|
| 7 |
"""
|
| 8 |
def __init__(self, model):
|
| 9 |
self.model = model
|
| 10 |
|
| 11 |
def scan(self, cache, sequence: torch.Tensor) -> List[Tuple[int, int]]:
|
| 12 |
"""
|
| 13 |
+
Scans heads for induction behavior.
|
|
|
|
|
|
|
| 14 |
"""
|
| 15 |
n_layers = self.model.cfg.n_layers
|
| 16 |
n_heads = self.model.cfg.n_heads
|
|
|
|
| 17 |
|
| 18 |
induction_heads = []
|
| 19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
for layer in range(n_layers):
|
| 21 |
+
# [batch, head, query_pos, key_pos]
|
| 22 |
+
attn_pattern = cache[f"blocks.{layer}.attn.hook_pattern"]
|
| 23 |
|
| 24 |
for head in range(n_heads):
|
| 25 |
score = self._calculate_induction_score(attn_pattern[0, head])
|
| 26 |
+
if score > 0.5:
|
| 27 |
induction_heads.append((layer, head))
|
| 28 |
|
| 29 |
return induction_heads
|
| 30 |
|
| 31 |
def _calculate_induction_score(self, pattern: torch.Tensor) -> float:
|
| 32 |
"""
|
| 33 |
+
Heuristic check for shifted diagonal attention.
|
|
|
|
|
|
|
| 34 |
"""
|
| 35 |
+
# Checks if attention is shifted by 1 relative to diagonal.
|
|
|
|
|
|
|
| 36 |
return torch.diagonal(pattern, offset=-1).mean().item()
|
src/interpretability/patching.py
CHANGED
|
@@ -28,7 +28,6 @@ class ActivationPatcher:
|
|
| 28 |
|
| 29 |
hook_name = f"blocks.{layer}.attn.hook_result"
|
| 30 |
|
| 31 |
-
# Run the model with the hook
|
| 32 |
with self.model.transformer.hooks(fwd_hooks=[(hook_name, patch_hook)]):
|
| 33 |
patched_outputs = self.model(**clean_inputs)
|
| 34 |
|
|
|
|
| 28 |
|
| 29 |
hook_name = f"blocks.{layer}.attn.hook_result"
|
| 30 |
|
|
|
|
| 31 |
with self.model.transformer.hooks(fwd_hooks=[(hook_name, patch_hook)]):
|
| 32 |
patched_outputs = self.model(**clean_inputs)
|
| 33 |
|
src/interpretability/sae_manager.py
CHANGED
|
@@ -7,8 +7,7 @@ from jaxtyping import Float
|
|
| 7 |
|
| 8 |
class SAEManager:
|
| 9 |
"""
|
| 10 |
-
|
| 11 |
-
Handles training, decomposition into monosemantic latents, and mechanistic anomaly detection.
|
| 12 |
"""
|
| 13 |
def __init__(self, model: nn.Module, sae_dir: str = "artifacts/saes"):
|
| 14 |
self.model = model
|
|
@@ -23,7 +22,7 @@ class SAEManager:
|
|
| 23 |
expansion_factor: int = 8,
|
| 24 |
) -> StandardSAE:
|
| 25 |
"""
|
| 26 |
-
Initializes an SAE for a specific hook point
|
| 27 |
"""
|
| 28 |
cfg = StandardSAEConfig(
|
| 29 |
d_in=d_model,
|
|
@@ -43,7 +42,7 @@ class SAEManager:
|
|
| 43 |
epochs: int = 10,
|
| 44 |
):
|
| 45 |
"""
|
| 46 |
-
Trains the SAE on
|
| 47 |
"""
|
| 48 |
if hook_point not in self.saes:
|
| 49 |
self.setup_sae(hook_point, activations.shape[-1])
|
|
@@ -63,7 +62,6 @@ class SAEManager:
|
|
| 63 |
|
| 64 |
optimizer.zero_grad()
|
| 65 |
|
| 66 |
-
# Manual forward pass for training
|
| 67 |
feature_acts = sae.encode(batch_acts)
|
| 68 |
sae_out = sae.decode(feature_acts)
|
| 69 |
|
|
@@ -83,7 +81,7 @@ class SAEManager:
|
|
| 83 |
activations: Float[torch.Tensor, "... d_model"]
|
| 84 |
) -> Float[torch.Tensor, "... d_sae"]:
|
| 85 |
"""
|
| 86 |
-
Decomposes activations into
|
| 87 |
"""
|
| 88 |
if hook_point not in self.saes:
|
| 89 |
raise ValueError(f"SAE for {hook_point} not found. Train or load it first.")
|
|
@@ -100,7 +98,7 @@ class SAEManager:
|
|
| 100 |
activations: Float[torch.Tensor, "... d_model"]
|
| 101 |
) -> Float[torch.Tensor, "... d_model"]:
|
| 102 |
"""
|
| 103 |
-
Reconstructs
|
| 104 |
"""
|
| 105 |
if hook_point not in self.saes:
|
| 106 |
raise ValueError(f"SAE for {hook_point} not found.")
|
|
@@ -118,8 +116,7 @@ class SAEManager:
|
|
| 118 |
activations: Float[torch.Tensor, "... d_model"]
|
| 119 |
) -> Float[torch.Tensor, "..."]:
|
| 120 |
"""
|
| 121 |
-
|
| 122 |
-
Formula: ||x - x_hat|| / ||x||
|
| 123 |
"""
|
| 124 |
if hook_point not in self.saes:
|
| 125 |
raise ValueError(f"SAE for {hook_point} not found.")
|
|
|
|
| 7 |
|
| 8 |
class SAEManager:
|
| 9 |
"""
|
| 10 |
+
Manages SAEs for Decision Transformers: training, latent decomposition, and anomaly detection.
|
|
|
|
| 11 |
"""
|
| 12 |
def __init__(self, model: nn.Module, sae_dir: str = "artifacts/saes"):
|
| 13 |
self.model = model
|
|
|
|
| 22 |
expansion_factor: int = 8,
|
| 23 |
) -> StandardSAE:
|
| 24 |
"""
|
| 25 |
+
Initializes an SAE for a specific hook point.
|
| 26 |
"""
|
| 27 |
cfg = StandardSAEConfig(
|
| 28 |
d_in=d_model,
|
|
|
|
| 42 |
epochs: int = 10,
|
| 43 |
):
|
| 44 |
"""
|
| 45 |
+
Trains the SAE on trajectory activations.
|
| 46 |
"""
|
| 47 |
if hook_point not in self.saes:
|
| 48 |
self.setup_sae(hook_point, activations.shape[-1])
|
|
|
|
| 62 |
|
| 63 |
optimizer.zero_grad()
|
| 64 |
|
|
|
|
| 65 |
feature_acts = sae.encode(batch_acts)
|
| 66 |
sae_out = sae.decode(feature_acts)
|
| 67 |
|
|
|
|
| 81 |
activations: Float[torch.Tensor, "... d_model"]
|
| 82 |
) -> Float[torch.Tensor, "... d_sae"]:
|
| 83 |
"""
|
| 84 |
+
Decomposes activations into features.
|
| 85 |
"""
|
| 86 |
if hook_point not in self.saes:
|
| 87 |
raise ValueError(f"SAE for {hook_point} not found. Train or load it first.")
|
|
|
|
| 98 |
activations: Float[torch.Tensor, "... d_model"]
|
| 99 |
) -> Float[torch.Tensor, "... d_model"]:
|
| 100 |
"""
|
| 101 |
+
Reconstructs original activations.
|
| 102 |
"""
|
| 103 |
if hook_point not in self.saes:
|
| 104 |
raise ValueError(f"SAE for {hook_point} not found.")
|
|
|
|
| 116 |
activations: Float[torch.Tensor, "... d_model"]
|
| 117 |
) -> Float[torch.Tensor, "..."]:
|
| 118 |
"""
|
| 119 |
+
Reconstruction error for anomaly detection: ||x - x_hat|| / ||x||
|
|
|
|
| 120 |
"""
|
| 121 |
if hook_point not in self.saes:
|
| 122 |
raise ValueError(f"SAE for {hook_point} not found.")
|
src/models/hooked_dt.py
CHANGED
|
@@ -54,51 +54,23 @@ class HookedDT(nn.Module):
|
|
| 54 |
):
|
| 55 |
batch_size, seq_len, _ = states.shape
|
| 56 |
|
| 57 |
-
# Embed tokens
|
| 58 |
state_embeddings = self.embed_state(states)
|
| 59 |
action_embeddings = self.embed_action(actions)
|
| 60 |
returns_embeddings = self.embed_return(returns_to_go)
|
| 61 |
|
| 62 |
-
#
|
| 63 |
-
# Sequence: (R1, S1, A1, R2, S2, A2, ...)
|
| 64 |
stacked_inputs = torch.stack(
|
| 65 |
(returns_embeddings, state_embeddings, action_embeddings), dim=2
|
| 66 |
).reshape(batch_size, 3 * seq_len, self.cfg.d_model)
|
| 67 |
|
| 68 |
stacked_inputs = self.embed_ln(stacked_inputs)
|
| 69 |
|
| 70 |
-
# Add positional embeddings manually or via HookedTransformer
|
| 71 |
-
# DT usually uses learned positional embeddings for timesteps
|
| 72 |
-
# HookedTransformer usually handles this via its own embed_pos
|
| 73 |
-
# We'll use the timestep info to get positional embeddings
|
| 74 |
-
|
| 75 |
-
# For simplicity, let's assume we can use HookedTransformer's forward
|
| 76 |
-
# but we need to handle the interleaved nature.
|
| 77 |
-
|
| 78 |
-
# We pass the stacked_inputs directly to the transformer blocks
|
| 79 |
-
# We use run_with_cache or standard forward based on whether we need the cache
|
| 80 |
-
# For TransformerLens, we need to specify that we are passing embeddings
|
| 81 |
-
|
| 82 |
-
# Note: HookedTransformer expects [batch, pos, d_model] if input is embeddings
|
| 83 |
-
# We need to set use_local_embeddings=True or similar if we want to bypass default embeds
|
| 84 |
-
|
| 85 |
-
# A better way is to use model.blocks directly or use the hook_embed to inject
|
| 86 |
-
|
| 87 |
def embed_hook(value, hook):
|
| 88 |
return stacked_inputs
|
| 89 |
|
| 90 |
-
#
|
| 91 |
-
# and pass a dummy tensor of the right shape to the transformer
|
| 92 |
dummy_input = torch.zeros((batch_size, 3 * seq_len), dtype=torch.long, device=stacked_inputs.device)
|
| 93 |
|
| 94 |
-
# We want the residual stream after the last block
|
| 95 |
-
# HookedTransformer.run_with_cache returns (output, cache)
|
| 96 |
-
# We can also use return_type="residual" or similar in some versions,
|
| 97 |
-
# but let's just use the cache or the direct output if we set it up correctly.
|
| 98 |
-
|
| 99 |
-
# In TransformerLens, the output of the forward pass is usually the logits.
|
| 100 |
-
# We want the 'hook_resid_post' of the last block.
|
| 101 |
-
|
| 102 |
last_block_hook = f"blocks.{self.cfg.n_layers - 1}.hook_resid_post"
|
| 103 |
|
| 104 |
with self.transformer.hooks(fwd_hooks=[("hook_embed", embed_hook)]):
|
|
@@ -108,15 +80,12 @@ class HookedDT(nn.Module):
|
|
| 108 |
)
|
| 109 |
|
| 110 |
transformer_outputs = cache[last_block_hook]
|
| 111 |
-
|
| 112 |
-
# Reshape back to (batch, seq, 3, d_model)
|
| 113 |
x = transformer_outputs.reshape(batch_size, seq_len, 3, self.cfg.d_model)
|
| 114 |
|
| 115 |
-
#
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
state_preds = self.predict_state(x[:, :, 2]) # predict next state from action
|
| 120 |
|
| 121 |
return action_preds, state_preds, return_preds
|
| 122 |
|
|
@@ -125,11 +94,11 @@ class HookedDT(nn.Module):
|
|
| 125 |
cfg = HookedTransformerConfig(
|
| 126 |
n_layers=n_layers,
|
| 127 |
d_model=d_model,
|
| 128 |
-
n_ctx=300,
|
| 129 |
d_head=d_model // n_heads,
|
| 130 |
n_heads=n_heads,
|
| 131 |
-
d_vocab=10,
|
| 132 |
-
act_fn="relu",
|
| 133 |
d_mlp=d_model * 4,
|
| 134 |
normalization_type="LN",
|
| 135 |
device="cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
| 54 |
):
|
| 55 |
batch_size, seq_len, _ = states.shape
|
| 56 |
|
|
|
|
| 57 |
state_embeddings = self.embed_state(states)
|
| 58 |
action_embeddings = self.embed_action(actions)
|
| 59 |
returns_embeddings = self.embed_return(returns_to_go)
|
| 60 |
|
| 61 |
+
# Interleave (R, S, A) sequence
|
|
|
|
| 62 |
stacked_inputs = torch.stack(
|
| 63 |
(returns_embeddings, state_embeddings, action_embeddings), dim=2
|
| 64 |
).reshape(batch_size, 3 * seq_len, self.cfg.d_model)
|
| 65 |
|
| 66 |
stacked_inputs = self.embed_ln(stacked_inputs)
|
| 67 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
def embed_hook(value, hook):
|
| 69 |
return stacked_inputs
|
| 70 |
|
| 71 |
+
# Inject interleaved embeddings into TransformerLens
|
|
|
|
| 72 |
dummy_input = torch.zeros((batch_size, 3 * seq_len), dtype=torch.long, device=stacked_inputs.device)
|
| 73 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
last_block_hook = f"blocks.{self.cfg.n_layers - 1}.hook_resid_post"
|
| 75 |
|
| 76 |
with self.transformer.hooks(fwd_hooks=[("hook_embed", embed_hook)]):
|
|
|
|
| 80 |
)
|
| 81 |
|
| 82 |
transformer_outputs = cache[last_block_hook]
|
|
|
|
|
|
|
| 83 |
x = transformer_outputs.reshape(batch_size, seq_len, 3, self.cfg.d_model)
|
| 84 |
|
| 85 |
+
# Action from state, return/state from action
|
| 86 |
+
action_preds = self.predict_action(x[:, :, 1])
|
| 87 |
+
return_preds = self.predict_return(x[:, :, 2])
|
| 88 |
+
state_preds = self.predict_state(x[:, :, 2])
|
|
|
|
| 89 |
|
| 90 |
return action_preds, state_preds, return_preds
|
| 91 |
|
|
|
|
| 94 |
cfg = HookedTransformerConfig(
|
| 95 |
n_layers=n_layers,
|
| 96 |
d_model=d_model,
|
| 97 |
+
n_ctx=300,
|
| 98 |
d_head=d_model // n_heads,
|
| 99 |
n_heads=n_heads,
|
| 100 |
+
d_vocab=10,
|
| 101 |
+
act_fn="relu",
|
| 102 |
d_mlp=d_model * 4,
|
| 103 |
normalization_type="LN",
|
| 104 |
device="cuda" if torch.cuda.is_available() else "cpu"
|