DT-Explorer / src /models /hooked_dt.py
sadhumitha-s's picture
refactor: implement centralized configuration, upgrade SAE training to multi-layer TopK, and optimize dashboard attribution UX
b7ddfc6
import torch
import torch.nn as nn
from transformer_lens import HookedTransformer, HookedTransformerConfig
from jaxtyping import Float, Int
from typing import Optional, Union, List
class HookedDT(nn.Module):
"""
Decision Transformer wrapped in TransformerLens logic.
Supports State, Action, and Reward-to-Go (RTG) tokens.
"""
def __init__(
self,
cfg: HookedTransformerConfig,
state_dim: int,
action_dim: int,
max_length: int = 30,
):
super().__init__()
self.cfg = cfg
self.state_dim = state_dim
self.action_dim = action_dim
self.max_length = max_length
# Core transformer blocks from TransformerLens
self.transformer = HookedTransformer(cfg)
# DT-specific embeddings
self.embed_return = nn.Linear(1, cfg.d_model)
self.embed_state = nn.Linear(state_dim, cfg.d_model)
self.embed_action = nn.Linear(action_dim, cfg.d_model)
self.embed_ln = nn.LayerNorm(cfg.d_model)
# Prediction heads
self.predict_action = nn.Sequential(nn.Linear(cfg.d_model, action_dim))
self.predict_return = nn.Sequential(nn.Linear(cfg.d_model, 1))
self.predict_state = nn.Sequential(nn.Linear(cfg.d_model, state_dim))
def get_embeddings(self, states, actions, returns_to_go):
"""Interleaves RTG, State, and Action embeddings."""
batch_size, seq_len, _ = states.shape
ret_emb = self.embed_return(returns_to_go)
state_emb = self.embed_state(states)
act_emb = self.embed_action(actions)
# Interleave: [R1, S1, A1, R2, S2, A2, ...]
stacked = torch.stack((ret_emb, state_emb, act_emb), dim=2)
stacked = stacked.reshape(batch_size, 3 * seq_len, self.cfg.d_model)
return self.embed_ln(stacked)
def forward(self, states, actions, returns_to_go, timesteps=None, return_cache=False):
"""Forward pass through DT."""
# Truncate to max_length to fit within transformer context
states = states[:, -self.max_length:]
actions = actions[:, -self.max_length:]
returns_to_go = returns_to_go[:, -self.max_length:]
embeddings = self.get_embeddings(states, actions, returns_to_go)
dummy_tokens = torch.zeros((embeddings.shape[0], embeddings.shape[1]),
dtype=torch.long, device=embeddings.device)
def inject_embeddings(value, hook):
return embeddings
# We need the residual stream post-processing from the last block
last_resid_hook = f"blocks.{self.cfg.n_layers-1}.hook_resid_post"
if return_cache:
with self.transformer.hooks(fwd_hooks=[("hook_embed", inject_embeddings)]):
_, cache = self.transformer.run_with_cache(dummy_tokens)
last_resid = cache[last_resid_hook]
x = last_resid.reshape(states.shape[0], states.shape[1], 3, self.cfg.d_model)
action_preds = self.predict_action(x[:, :, 1]) # State token predicts action
return action_preds, cache
else:
with self.transformer.hooks(fwd_hooks=[("hook_embed", inject_embeddings)]):
# run_with_cache is safer to ensure we can grab the specific hook output
_, cache = self.transformer.run_with_cache(dummy_tokens, names_filter=lambda n: n == last_resid_hook)
last_resid = cache[last_resid_hook]
x = last_resid.reshape(states.shape[0], states.shape[1], 3, self.cfg.d_model)
action_preds = self.predict_action(x[:, :, 1])
return action_preds
@classmethod
def from_config(cls, state_dim, action_dim, n_layers=2, n_heads=4, d_model=128, max_length=30):
cfg = HookedTransformerConfig(
n_layers=n_layers,
d_model=d_model,
n_ctx=3 * max_length,
d_head=d_model // n_heads,
n_heads=n_heads,
d_vocab=10, # Dummy vocab size
act_fn="relu",
d_mlp=d_model * 4,
normalization_type="LN",
use_attn_result=True,
device="cuda" if torch.cuda.is_available() else "cpu"
)
return cls(cfg, state_dim, action_dim, max_length=max_length)