sadhumitha-s commited on
Commit
731ae64
·
1 Parent(s): 0346604

refactor: added logic as comments

Browse files
README.md CHANGED
@@ -1,32 +1,31 @@
1
  # DT-Circuits: Mechanistic Interpretability for Decision Transformers
2
 
3
- DT-Circuits is a research-grade framework designed for the rigorous mechanistic interpretability of Decision Transformers (DT). By leveraging the TransformerLens paradigm, this platform enables researchers to map internal neural circuits, decompose activations using Sparse Autoencoders, and perform causal interventions on agent decision-making.
4
 
5
- The primary objective is to move beyond behavioral observation and saliency maps toward a quantitative understanding of how Reward-to-Go, State, and Action tokens are processed within the residual stream.
6
 
7
  ## Core Capabilities
8
 
9
  ### 1. Circuit Foundation
10
- - **Hooked-DT Architecture**: A custom Decision Transformer implementation wrapped in TransformerLens, providing full access to internal activations, weights, and the residual stream.
11
- - **Direct Logit Attribution (DLA)**: Quantitative mapping of individual attention heads and MLP layers to the final action logits.
12
- - **Induction Head Discovery**: Automated scanning tools to identify heads responsible for temporal pattern recognition and "memory" in RL tasks.
13
 
14
  ### 2. Causal Interventions
15
- - **Activation Patching**: Surgical replacement of activations between "clean" and "corrupted" runs to identify bottleneck features and causal paths.
16
- - **Contrastive Activation Addition (CAA)**: Generation of steering vectors by calculating the mean difference between positive and negative activation sets.
17
- - **Steering Library**: A persistent library of pre-calculated vectors (e.g., success_vector, exploration_vector) that can be injected at inference time to manipulate agent behavior without retraining.
18
 
19
- ### 3. Deep Discovery & Safety
20
- - **Sparse Autoencoder (SAE) Integration**: Tools to train and deploy SAEs on the residual stream, decomposing polysemantic neurons into monosemantic latents.
21
- - **Mechanistic Anomaly Detection**: Utilizing SAE reconstruction error as a high-fidelity proxy for detecting out-of-distribution (OOD) states.
22
 
23
  ## Technical Architecture
24
 
25
- The platform is divided into four primary layers:
26
- - **Data Layer**: PPO Trajectory Harvester for generating high-quality expert demonstrations in Gymnasium environments (e.g., MiniGrid).
27
- - **Model Layer**: The HookedDT implementation which maintains compatibility with standard DT architectures while adding hook-based visibility.
28
- - **Interpretability Layer**: A suite of modules for attribution, patching, SAE management, and steering.
29
- - **Visualization Layer**: A Streamlit-based dashboard for real-time activation monitoring and interactive steering.
30
 
31
  ## Getting Started
32
 
 
1
  # DT-Circuits: Mechanistic Interpretability for Decision Transformers
2
 
3
+ DT-Circuits is a framework for mechanistic interpretability of Decision Transformers (DT). Using TransformerLens, it enables mapping neural circuits, decomposing activations with Sparse Autoencoders (SAEs), and performing causal interventions on agent decision-making.
4
 
5
+ The goal is to understand how Reward-to-Go, State, and Action tokens are processed within the residual stream, moving beyond basic behavioral observation.
6
 
7
  ## Core Capabilities
8
 
9
  ### 1. Circuit Foundation
10
+ - **Hooked-DT**: A Decision Transformer implementation wrapped in TransformerLens for access to internal activations and weights.
11
+ - **Direct Logit Attribution (DLA)**: Quantifies the contribution of individual heads and MLP layers to action logits.
12
+ - **Induction Head Discovery**: Tools to identify heads responsible for temporal pattern recognition.
13
 
14
  ### 2. Causal Interventions
15
+ - **Activation Patching**: Replaces activations between clean and corrupted runs to identify causal paths.
16
+ - **Steering**: Generates and applies steering vectors (e.g., via Contrastive Activation Addition) to manipulate agent behavior at inference time.
 
17
 
18
+ ### 3. SAEs & Safety
19
+ - **SAE Integration**: Tools to train and deploy SAEs on the residual stream to find monosemantic latents.
20
+ - **Anomaly Detection**: Uses SAE reconstruction error to detect out-of-distribution (OOD) states.
21
 
22
  ## Technical Architecture
23
 
24
+ The platform consists of:
25
+ - **Data Layer**: PPO Trajectory Harvester for collecting expert demonstrations (e.g., MiniGrid).
26
+ - **Model Layer**: HookedDT implementation.
27
+ - **Interpretability Layer**: Modules for attribution, patching, SAE management, and steering.
28
+ - **Visualization Layer**: Streamlit dashboard for real-time monitoring and intervention.
29
 
30
  ## Getting Started
31
 
scripts/train_dt.py CHANGED
@@ -7,13 +7,11 @@ import numpy as np
7
  from tqdm import tqdm
8
 
9
  def train():
10
- # 1. Collect Data
11
  harvester = PPOHarvester(model_path="ppo_minigrid_teacher.zip")
12
  trajectories = harvester.collect_trajectories(num_episodes=100)
13
 
14
- # 2. Setup Model
15
  state_dim = trajectories[0]["observations"].shape[1]
16
- action_dim = 7 # MiniGrid has 7 actions
17
 
18
  model = HookedDT.from_config(
19
  state_dim=state_dim,
@@ -26,24 +24,18 @@ def train():
26
  optimizer = optim.AdamW(model.parameters(), lr=1e-4)
27
  criterion = nn.CrossEntropyLoss()
28
 
29
- # 3. Training Loop (Simplified)
30
  model.train()
31
  for epoch in range(10):
32
  total_loss = 0
33
  for traj in tqdm(trajectories, desc=f"Epoch {epoch}"):
34
  states = torch.from_numpy(traj["observations"]).float().unsqueeze(0)
35
  actions = torch.from_numpy(traj["actions"]).long().unsqueeze(0)
36
- # One-hot actions for input
37
  actions_one_hot = torch.nn.functional.one_hot(actions, num_classes=action_dim).float()
38
  returns = torch.from_numpy(traj["rewards"]).float().unsqueeze(0).unsqueeze(-1)
39
  timesteps = torch.arange(states.shape[1]).unsqueeze(0)
40
 
41
- # Mask (dummy for now)
42
-
43
  action_preds, _, _ = model(states, actions_one_hot, returns, timesteps)
44
 
45
- # Target actions (shifted by 1 for next action prediction)
46
- # Standard DT predicts a_t from s_t
47
  loss = criterion(action_preds.view(-1, action_dim), actions.view(-1))
48
 
49
  optimizer.zero_grad()
 
7
  from tqdm import tqdm
8
 
9
  def train():
 
10
  harvester = PPOHarvester(model_path="ppo_minigrid_teacher.zip")
11
  trajectories = harvester.collect_trajectories(num_episodes=100)
12
 
 
13
  state_dim = trajectories[0]["observations"].shape[1]
14
+ action_dim = 7 # MiniGrid
15
 
16
  model = HookedDT.from_config(
17
  state_dim=state_dim,
 
24
  optimizer = optim.AdamW(model.parameters(), lr=1e-4)
25
  criterion = nn.CrossEntropyLoss()
26
 
 
27
  model.train()
28
  for epoch in range(10):
29
  total_loss = 0
30
  for traj in tqdm(trajectories, desc=f"Epoch {epoch}"):
31
  states = torch.from_numpy(traj["observations"]).float().unsqueeze(0)
32
  actions = torch.from_numpy(traj["actions"]).long().unsqueeze(0)
 
33
  actions_one_hot = torch.nn.functional.one_hot(actions, num_classes=action_dim).float()
34
  returns = torch.from_numpy(traj["rewards"]).float().unsqueeze(0).unsqueeze(-1)
35
  timesteps = torch.arange(states.shape[1]).unsqueeze(0)
36
 
 
 
37
  action_preds, _, _ = model(states, actions_one_hot, returns, timesteps)
38
 
 
 
39
  loss = criterion(action_preds.view(-1, action_dim), actions.view(-1))
40
 
41
  optimizer.zero_grad()
src/dashboard/app.py CHANGED
@@ -10,47 +10,36 @@ st.set_page_config(page_title="DT-Explorer", layout="wide")
10
 
11
  st.title("DT-Explorer: Mechanistic Interpretability for Decision Transformers")
12
 
13
- # Sidebar for controls
14
  st.sidebar.header("Model Configuration")
15
  n_layers = st.sidebar.slider("Layers", 1, 12, 1)
16
  n_heads = st.sidebar.slider("Heads", 1, 8, 4)
17
 
18
- # Load Model
19
  @st.cache_resource
20
  def load_model():
21
- # Placeholder dimensions for MiniGrid
22
  state_dim = 2739 # FlatObsWrapper for 8x8 MiniGrid
23
  action_dim = 7
24
  model = HookedDT.from_config(state_dim, action_dim, n_layers=n_layers, n_heads=n_heads)
25
- # model.load_state_dict(torch.load("models/mini_dt.pt"))
26
  return model
27
 
28
  model = load_model()
29
 
30
- # Dashboard Tabs
31
  tab1, tab2, tab3 = st.tabs(["Circuit Mapping", "Causal Intervention", "SAE Explorer"])
32
 
33
  with tab1:
34
  st.header("Direct Logit Attribution")
35
- # Simulate a forward pass
36
  if st.button("Run Attribution Analysis"):
37
- # Dummy data for demo
38
  states = torch.randn(1, 10, model.state_dim)
39
  actions = torch.randn(1, 10, model.action_dim)
40
  returns = torch.randn(1, 10, 1)
41
  timesteps = torch.arange(10).unsqueeze(0)
42
 
43
- # Capture cache
44
  logits, cache = model.transformer.run_with_cache(
45
- # Need to handle DT's interleaved forward pass here
46
- # For demo, we'll just show the UI structure
47
  torch.randn(1, 30, model.cfg.d_model)
48
  )
49
 
50
  engine = LogitAttributionEngine(model)
51
- # dla = engine.calculate_dla(cache, target_logit_index=0)
52
 
53
- # Placeholder plot
54
  fig, ax = plt.subplots()
55
  dla_mock = np.random.randn(n_layers, n_heads)
56
  im = ax.imshow(dla_mock, cmap="RdBu_r")
 
10
 
11
  st.title("DT-Explorer: Mechanistic Interpretability for Decision Transformers")
12
 
 
13
  st.sidebar.header("Model Configuration")
14
  n_layers = st.sidebar.slider("Layers", 1, 12, 1)
15
  n_heads = st.sidebar.slider("Heads", 1, 8, 4)
16
 
 
17
  @st.cache_resource
18
  def load_model():
 
19
  state_dim = 2739 # FlatObsWrapper for 8x8 MiniGrid
20
  action_dim = 7
21
  model = HookedDT.from_config(state_dim, action_dim, n_layers=n_layers, n_heads=n_heads)
 
22
  return model
23
 
24
  model = load_model()
25
 
 
26
  tab1, tab2, tab3 = st.tabs(["Circuit Mapping", "Causal Intervention", "SAE Explorer"])
27
 
28
  with tab1:
29
  st.header("Direct Logit Attribution")
 
30
  if st.button("Run Attribution Analysis"):
31
+ # Mock data for demo
32
  states = torch.randn(1, 10, model.state_dim)
33
  actions = torch.randn(1, 10, model.action_dim)
34
  returns = torch.randn(1, 10, 1)
35
  timesteps = torch.arange(10).unsqueeze(0)
36
 
 
37
  logits, cache = model.transformer.run_with_cache(
 
 
38
  torch.randn(1, 30, model.cfg.d_model)
39
  )
40
 
41
  engine = LogitAttributionEngine(model)
 
42
 
 
43
  fig, ax = plt.subplots()
44
  dla_mock = np.random.randn(n_layers, n_heads)
45
  im = ax.imshow(dla_mock, cmap="RdBu_r")
src/interpretability/attribution.py CHANGED
@@ -18,33 +18,24 @@ class LogitAttributionEngine:
18
  token_index: int = -1
19
  ) -> Dict[str, Float[torch.Tensor, "layer head"]]:
20
  """
21
- Computes DLA for each head in the model.
22
- Formula: DLA = Activation @ W_O @ W_U [target_logit]
23
  """
24
  n_layers = self.model.cfg.n_layers
25
  n_heads = self.model.cfg.n_heads
26
- d_model = self.model.cfg.d_model
27
 
28
- # Get the unembedding matrix for the action prediction head
29
- # In our HookedDT, the prediction head is a Linear layer: self.predict_action[0].weight
30
- W_U = self.model.predict_action[0].weight[target_logit_index] # [d_model]
31
 
32
  dla_results = torch.zeros((n_layers, n_heads))
33
 
34
  for layer in range(n_layers):
35
- # Head outputs from cache: [batch, pos, head, d_model]
36
- # For HookedTransformer, it's usually 'blocks.{layer}.attn.hook_result'
37
- head_outputs = cache[f"blocks.{layer}.attn.hook_result"] # [batch, pos, head, d_model]
38
 
39
- # We take the token_index (usually the last state token)
40
- # In interleaved (R, S, A), S_t is at 3t + 1
41
- # If we want the last predicted action, we look at the last state token's output
42
 
43
- last_token_output = head_outputs[0, token_index] # [head, d_model]
44
-
45
- # Attribution: projection onto W_U
46
- attribution = torch.matmul(last_token_output, W_U) # [head]
47
- dla_results[layer] = attribution
48
 
49
  return dla_results
50
 
 
18
  token_index: int = -1
19
  ) -> Dict[str, Float[torch.Tensor, "layer head"]]:
20
  """
21
+ Computes DLA for each head: Activation @ W_O @ W_U [target_logit]
 
22
  """
23
  n_layers = self.model.cfg.n_layers
24
  n_heads = self.model.cfg.n_heads
 
25
 
26
+ # Action prediction unembedding
27
+ W_U = self.model.predict_action[0].weight[target_logit_index]
 
28
 
29
  dla_results = torch.zeros((n_layers, n_heads))
30
 
31
  for layer in range(n_layers):
32
+ # [batch, pos, head, d_model]
33
+ head_outputs = cache[f"blocks.{layer}.attn.hook_result"]
 
34
 
35
+ # S_t is at 3t + 1 in interleaved (R, S, A)
36
+ last_token_output = head_outputs[0, token_index]
 
37
 
38
+ dla_results[layer] = torch.matmul(last_token_output, W_U)
 
 
 
 
39
 
40
  return dla_results
41
 
src/interpretability/induction_scan.py CHANGED
@@ -3,46 +3,34 @@ from typing import List, Tuple
3
 
4
  class InductionScanner:
5
  """
6
- Automated scan for Induction Heads.
7
- Induction heads attend to the token that followed the current token's previous occurrence.
8
  """
9
  def __init__(self, model):
10
  self.model = model
11
 
12
  def scan(self, cache, sequence: torch.Tensor) -> List[Tuple[int, int]]:
13
  """
14
- Scans all heads for 'Induction' behavior on a given sequence.
15
- Logic: For token S, find previous occurrence of S at index i.
16
- Check if current token attends to token at i+1.
17
  """
18
  n_layers = self.model.cfg.n_layers
19
  n_heads = self.model.cfg.n_heads
20
- seq_len = sequence.shape[1]
21
 
22
  induction_heads = []
23
 
24
- # Find repeated tokens
25
- # For simplicity, we assume 'sequence' is the flattened list of tokens (or states)
26
- # In DT, this is more complex due to interleaving.
27
- # Let's look at state tokens specifically.
28
-
29
  for layer in range(n_layers):
30
- attn_pattern = cache[f"blocks.{layer}.attn.hook_pattern"] # [batch, head, query_pos, key_pos]
 
31
 
32
  for head in range(n_heads):
33
  score = self._calculate_induction_score(attn_pattern[0, head])
34
- if score > 0.5: # Threshold for induction
35
  induction_heads.append((layer, head))
36
 
37
  return induction_heads
38
 
39
  def _calculate_induction_score(self, pattern: torch.Tensor) -> float:
40
  """
41
- Simplified induction score.
42
- Checks if the attention is shifted by 1 relative to a diagonal.
43
- This is a heuristic; more robust methods exist in TransformerLens.
44
  """
45
- # In a real scenario, we'd use a sequence like [A, B, C, ..., A]
46
- # and check if the second A attends to B.
47
- # Here we just return a placeholder logic for the scan structure.
48
  return torch.diagonal(pattern, offset=-1).mean().item()
 
3
 
4
  class InductionScanner:
5
  """
6
+ Identifies induction heads that attend to tokens following a previous occurrence.
 
7
  """
8
  def __init__(self, model):
9
  self.model = model
10
 
11
  def scan(self, cache, sequence: torch.Tensor) -> List[Tuple[int, int]]:
12
  """
13
+ Scans heads for induction behavior.
 
 
14
  """
15
  n_layers = self.model.cfg.n_layers
16
  n_heads = self.model.cfg.n_heads
 
17
 
18
  induction_heads = []
19
 
 
 
 
 
 
20
  for layer in range(n_layers):
21
+ # [batch, head, query_pos, key_pos]
22
+ attn_pattern = cache[f"blocks.{layer}.attn.hook_pattern"]
23
 
24
  for head in range(n_heads):
25
  score = self._calculate_induction_score(attn_pattern[0, head])
26
+ if score > 0.5:
27
  induction_heads.append((layer, head))
28
 
29
  return induction_heads
30
 
31
  def _calculate_induction_score(self, pattern: torch.Tensor) -> float:
32
  """
33
+ Heuristic check for shifted diagonal attention.
 
 
34
  """
35
+ # Checks if attention is shifted by 1 relative to diagonal.
 
 
36
  return torch.diagonal(pattern, offset=-1).mean().item()
src/interpretability/patching.py CHANGED
@@ -28,7 +28,6 @@ class ActivationPatcher:
28
 
29
  hook_name = f"blocks.{layer}.attn.hook_result"
30
 
31
- # Run the model with the hook
32
  with self.model.transformer.hooks(fwd_hooks=[(hook_name, patch_hook)]):
33
  patched_outputs = self.model(**clean_inputs)
34
 
 
28
 
29
  hook_name = f"blocks.{layer}.attn.hook_result"
30
 
 
31
  with self.model.transformer.hooks(fwd_hooks=[(hook_name, patch_hook)]):
32
  patched_outputs = self.model(**clean_inputs)
33
 
src/interpretability/sae_manager.py CHANGED
@@ -7,8 +7,7 @@ from jaxtyping import Float
7
 
8
  class SAEManager:
9
  """
10
- Research-grade manager for Sparse Autoencoders (SAEs) integrated with Decision Transformers.
11
- Handles training, decomposition into monosemantic latents, and mechanistic anomaly detection.
12
  """
13
  def __init__(self, model: nn.Module, sae_dir: str = "artifacts/saes"):
14
  self.model = model
@@ -23,7 +22,7 @@ class SAEManager:
23
  expansion_factor: int = 8,
24
  ) -> StandardSAE:
25
  """
26
- Initializes an SAE for a specific hook point in the transformer.
27
  """
28
  cfg = StandardSAEConfig(
29
  d_in=d_model,
@@ -43,7 +42,7 @@ class SAEManager:
43
  epochs: int = 10,
44
  ):
45
  """
46
- Trains the SAE on collected trajectory activations.
47
  """
48
  if hook_point not in self.saes:
49
  self.setup_sae(hook_point, activations.shape[-1])
@@ -63,7 +62,6 @@ class SAEManager:
63
 
64
  optimizer.zero_grad()
65
 
66
- # Manual forward pass for training
67
  feature_acts = sae.encode(batch_acts)
68
  sae_out = sae.decode(feature_acts)
69
 
@@ -83,7 +81,7 @@ class SAEManager:
83
  activations: Float[torch.Tensor, "... d_model"]
84
  ) -> Float[torch.Tensor, "... d_sae"]:
85
  """
86
- Decomposes activations into monosemantic features.
87
  """
88
  if hook_point not in self.saes:
89
  raise ValueError(f"SAE for {hook_point} not found. Train or load it first.")
@@ -100,7 +98,7 @@ class SAEManager:
100
  activations: Float[torch.Tensor, "... d_model"]
101
  ) -> Float[torch.Tensor, "... d_model"]:
102
  """
103
- Reconstructs the original activations using the SAE.
104
  """
105
  if hook_point not in self.saes:
106
  raise ValueError(f"SAE for {hook_point} not found.")
@@ -118,8 +116,7 @@ class SAEManager:
118
  activations: Float[torch.Tensor, "... d_model"]
119
  ) -> Float[torch.Tensor, "..."]:
120
  """
121
- Calculates reconstruction error as a proxy for mechanistic anomaly detection.
122
- Formula: ||x - x_hat|| / ||x||
123
  """
124
  if hook_point not in self.saes:
125
  raise ValueError(f"SAE for {hook_point} not found.")
 
7
 
8
  class SAEManager:
9
  """
10
+ Manages SAEs for Decision Transformers: training, latent decomposition, and anomaly detection.
 
11
  """
12
  def __init__(self, model: nn.Module, sae_dir: str = "artifacts/saes"):
13
  self.model = model
 
22
  expansion_factor: int = 8,
23
  ) -> StandardSAE:
24
  """
25
+ Initializes an SAE for a specific hook point.
26
  """
27
  cfg = StandardSAEConfig(
28
  d_in=d_model,
 
42
  epochs: int = 10,
43
  ):
44
  """
45
+ Trains the SAE on trajectory activations.
46
  """
47
  if hook_point not in self.saes:
48
  self.setup_sae(hook_point, activations.shape[-1])
 
62
 
63
  optimizer.zero_grad()
64
 
 
65
  feature_acts = sae.encode(batch_acts)
66
  sae_out = sae.decode(feature_acts)
67
 
 
81
  activations: Float[torch.Tensor, "... d_model"]
82
  ) -> Float[torch.Tensor, "... d_sae"]:
83
  """
84
+ Decomposes activations into features.
85
  """
86
  if hook_point not in self.saes:
87
  raise ValueError(f"SAE for {hook_point} not found. Train or load it first.")
 
98
  activations: Float[torch.Tensor, "... d_model"]
99
  ) -> Float[torch.Tensor, "... d_model"]:
100
  """
101
+ Reconstructs original activations.
102
  """
103
  if hook_point not in self.saes:
104
  raise ValueError(f"SAE for {hook_point} not found.")
 
116
  activations: Float[torch.Tensor, "... d_model"]
117
  ) -> Float[torch.Tensor, "..."]:
118
  """
119
+ Reconstruction error for anomaly detection: ||x - x_hat|| / ||x||
 
120
  """
121
  if hook_point not in self.saes:
122
  raise ValueError(f"SAE for {hook_point} not found.")
src/models/hooked_dt.py CHANGED
@@ -54,51 +54,23 @@ class HookedDT(nn.Module):
54
  ):
55
  batch_size, seq_len, _ = states.shape
56
 
57
- # Embed tokens
58
  state_embeddings = self.embed_state(states)
59
  action_embeddings = self.embed_action(actions)
60
  returns_embeddings = self.embed_return(returns_to_go)
61
 
62
- # In DT, we interleave (R, S, A)
63
- # Sequence: (R1, S1, A1, R2, S2, A2, ...)
64
  stacked_inputs = torch.stack(
65
  (returns_embeddings, state_embeddings, action_embeddings), dim=2
66
  ).reshape(batch_size, 3 * seq_len, self.cfg.d_model)
67
 
68
  stacked_inputs = self.embed_ln(stacked_inputs)
69
 
70
- # Add positional embeddings manually or via HookedTransformer
71
- # DT usually uses learned positional embeddings for timesteps
72
- # HookedTransformer usually handles this via its own embed_pos
73
- # We'll use the timestep info to get positional embeddings
74
-
75
- # For simplicity, let's assume we can use HookedTransformer's forward
76
- # but we need to handle the interleaved nature.
77
-
78
- # We pass the stacked_inputs directly to the transformer blocks
79
- # We use run_with_cache or standard forward based on whether we need the cache
80
- # For TransformerLens, we need to specify that we are passing embeddings
81
-
82
- # Note: HookedTransformer expects [batch, pos, d_model] if input is embeddings
83
- # We need to set use_local_embeddings=True or similar if we want to bypass default embeds
84
-
85
- # A better way is to use model.blocks directly or use the hook_embed to inject
86
-
87
  def embed_hook(value, hook):
88
  return stacked_inputs
89
 
90
- # We inject our interleaved embeddings into the 'hook_embed'
91
- # and pass a dummy tensor of the right shape to the transformer
92
  dummy_input = torch.zeros((batch_size, 3 * seq_len), dtype=torch.long, device=stacked_inputs.device)
93
 
94
- # We want the residual stream after the last block
95
- # HookedTransformer.run_with_cache returns (output, cache)
96
- # We can also use return_type="residual" or similar in some versions,
97
- # but let's just use the cache or the direct output if we set it up correctly.
98
-
99
- # In TransformerLens, the output of the forward pass is usually the logits.
100
- # We want the 'hook_resid_post' of the last block.
101
-
102
  last_block_hook = f"blocks.{self.cfg.n_layers - 1}.hook_resid_post"
103
 
104
  with self.transformer.hooks(fwd_hooks=[("hook_embed", embed_hook)]):
@@ -108,15 +80,12 @@ class HookedDT(nn.Module):
108
  )
109
 
110
  transformer_outputs = cache[last_block_hook]
111
-
112
- # Reshape back to (batch, seq, 3, d_model)
113
  x = transformer_outputs.reshape(batch_size, seq_len, 3, self.cfg.d_model)
114
 
115
- # Predict (A from S, S from A, R from S?)
116
- # Standard DT: Action is predicted from State token
117
- action_preds = self.predict_action(x[:, :, 1]) # predict next action from state
118
- return_preds = self.predict_return(x[:, :, 2]) # predict next return from action
119
- state_preds = self.predict_state(x[:, :, 2]) # predict next state from action
120
 
121
  return action_preds, state_preds, return_preds
122
 
@@ -125,11 +94,11 @@ class HookedDT(nn.Module):
125
  cfg = HookedTransformerConfig(
126
  n_layers=n_layers,
127
  d_model=d_model,
128
- n_ctx=300, # Max sequence length * 3
129
  d_head=d_model // n_heads,
130
  n_heads=n_heads,
131
- d_vocab=10, # Dummy value, we use custom embeddings
132
- act_fn="relu", # DT original uses ReLU or GeLU
133
  d_mlp=d_model * 4,
134
  normalization_type="LN",
135
  device="cuda" if torch.cuda.is_available() else "cpu"
 
54
  ):
55
  batch_size, seq_len, _ = states.shape
56
 
 
57
  state_embeddings = self.embed_state(states)
58
  action_embeddings = self.embed_action(actions)
59
  returns_embeddings = self.embed_return(returns_to_go)
60
 
61
+ # Interleave (R, S, A) sequence
 
62
  stacked_inputs = torch.stack(
63
  (returns_embeddings, state_embeddings, action_embeddings), dim=2
64
  ).reshape(batch_size, 3 * seq_len, self.cfg.d_model)
65
 
66
  stacked_inputs = self.embed_ln(stacked_inputs)
67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  def embed_hook(value, hook):
69
  return stacked_inputs
70
 
71
+ # Inject interleaved embeddings into TransformerLens
 
72
  dummy_input = torch.zeros((batch_size, 3 * seq_len), dtype=torch.long, device=stacked_inputs.device)
73
 
 
 
 
 
 
 
 
 
74
  last_block_hook = f"blocks.{self.cfg.n_layers - 1}.hook_resid_post"
75
 
76
  with self.transformer.hooks(fwd_hooks=[("hook_embed", embed_hook)]):
 
80
  )
81
 
82
  transformer_outputs = cache[last_block_hook]
 
 
83
  x = transformer_outputs.reshape(batch_size, seq_len, 3, self.cfg.d_model)
84
 
85
+ # Action from state, return/state from action
86
+ action_preds = self.predict_action(x[:, :, 1])
87
+ return_preds = self.predict_return(x[:, :, 2])
88
+ state_preds = self.predict_state(x[:, :, 2])
 
89
 
90
  return action_preds, state_preds, return_preds
91
 
 
94
  cfg = HookedTransformerConfig(
95
  n_layers=n_layers,
96
  d_model=d_model,
97
+ n_ctx=300,
98
  d_head=d_model // n_heads,
99
  n_heads=n_heads,
100
+ d_vocab=10,
101
+ act_fn="relu",
102
  d_mlp=d_model * 4,
103
  normalization_type="LN",
104
  device="cuda" if torch.cuda.is_available() else "cpu"