Spaces:
Running
Running
File size: 4,327 Bytes
e2614dc 8577352 e2614dc 8577352 e2614dc 8577352 e2614dc 8577352 e2614dc 8577352 e2614dc 8577352 e2614dc 8577352 e2614dc 8577352 4aa19e7 8577352 e2614dc 8577352 e2614dc 8577352 e2614dc b7ddfc6 e2614dc b7ddfc6 e2614dc 8577352 731ae64 e2614dc 11dbbc6 e2614dc b7ddfc6 8577352 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 | import torch
import torch.nn as nn
from transformer_lens import HookedTransformer, HookedTransformerConfig
from jaxtyping import Float, Int
from typing import Optional, Union, List
class HookedDT(nn.Module):
    """
    Decision Transformer wrapped in TransformerLens logic.

    Supports State, Action, and Reward-to-Go (RTG) tokens. Each timestep
    contributes three tokens, interleaved as
    ``[R_1, S_1, A_1, R_2, S_2, A_2, ...]``. The token embeddings come from
    DT-specific linear layers and are injected into the wrapped
    ``HookedTransformer`` at its ``hook_embed`` hook point. This means the usual
    TransformerLens hook / cache machinery can be used on the model.

    Only ``predict_action`` is used by ``forward``. ``predict_return`` and
    ``predict_state`` are constructed but are not called anywhere in this
    class.
    """

    def __init__(
        self,
        cfg: HookedTransformerConfig,
        state_dim: int,
        action_dim: int,
        max_length: int = 30,
    ):
        """
        Args:
            cfg: Config for the wrapped ``HookedTransformer``. ``cfg.n_ctx``
                must be at least ``3 * max_length`` (three tokens per
                timestep).
            state_dim: Size of the last dimension of ``states``.
            action_dim: Size of the last dimension of ``actions`` and of the
                predicted actions.
            max_length: Maximum number of timesteps fed to the transformer.
                Longer inputs are truncated to their last ``max_length``
                steps in ``forward``.
        """
        super().__init__()
        self.cfg = cfg
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.max_length = max_length

        # Core transformer blocks from TransformerLens
        self.transformer = HookedTransformer(cfg)

        # DT-specific embeddings
        self.embed_return = nn.Linear(1, cfg.d_model)
        self.embed_state = nn.Linear(state_dim, cfg.d_model)
        self.embed_action = nn.Linear(action_dim, cfg.d_model)
        self.embed_ln = nn.LayerNorm(cfg.d_model)

        # Prediction heads
        self.predict_action = nn.Sequential(nn.Linear(cfg.d_model, action_dim))
        self.predict_return = nn.Sequential(nn.Linear(cfg.d_model, 1))
        self.predict_state = nn.Sequential(nn.Linear(cfg.d_model, state_dim))

        # HookedTransformer puts its own weights on cfg.device by default,
        # while the plain nn layers above are created on CPU. Move them so the
        # model is usable on GPU right after construction (e.g. via
        # from_config). Only these submodules are moved, so a multi-device
        # transformer layout is left alone.
        if getattr(cfg, "device", None) is not None:
            for module in (
                self.embed_return,
                self.embed_state,
                self.embed_action,
                self.embed_ln,
                self.predict_action,
                self.predict_return,
                self.predict_state,
            ):
                module.to(cfg.device)

    def get_embeddings(self, states, actions, returns_to_go):
        """
        Embed RTG, state and action inputs and interleave them per timestep.

        Args:
            states: ``(batch, seq, state_dim)``.
            actions: ``(batch, seq, action_dim)``.
            returns_to_go: ``(batch, seq, 1)`` or ``(batch, seq)``. A missing
                trailing feature axis is added.

        Returns:
            ``(batch, 3 * seq, d_model)`` tensor ordered
            ``[R_1, S_1, A_1, R_2, S_2, A_2, ...]``, after ``embed_ln``.
        """
        batch_size, seq_len, _ = states.shape
        if returns_to_go.dim() == 2:
            # embed_return is nn.Linear(1, d_model) and needs a feature axis.
            returns_to_go = returns_to_go.unsqueeze(-1)
        ret_emb = self.embed_return(returns_to_go)
        state_emb = self.embed_state(states)
        act_emb = self.embed_action(actions)
        # Stacking on dim=2 gives (batch, seq, 3, d_model). Flattening seq and
        # the token axis then puts the three tokens of each timestep next to
        # each other: [R1, S1, A1, R2, S2, A2, ...].
        stacked = torch.stack((ret_emb, state_emb, act_emb), dim=2)
        stacked = stacked.reshape(batch_size, 3 * seq_len, self.cfg.d_model)
        return self.embed_ln(stacked)

    def forward(self, states, actions, returns_to_go, timesteps=None, return_cache=False):
        """
        Predict one action per timestep from a (possibly truncated) trajectory.

        Args:
            states, actions, returns_to_go: See ``get_embeddings``. Only the
                last ``max_length`` timesteps are used.
            timesteps: Accepted for API compatibility but currently unused.
                No timestep embedding is applied.
            return_cache: If True, also return the TransformerLens activation
                cache of this run. Cached activations are detached copies, but
                the returned ``action_preds`` stay connected to the autograd
                graph in both modes.

        Returns:
            ``action_preds`` of shape ``(batch, seq, action_dim)`` (``seq``
            after truncation), predicted from each state token. Returns
            ``(action_preds, cache)`` instead when ``return_cache`` is True.
        """
        # Truncate to max_length to fit within the transformer context.
        states = states[:, -self.max_length:]
        actions = actions[:, -self.max_length:]
        returns_to_go = returns_to_go[:, -self.max_length:]
        batch_size, seq_len = states.shape[0], states.shape[1]

        embeddings = self.get_embeddings(states, actions, returns_to_go)
        # The token ids are irrelevant: their embeddings are replaced at
        # hook_embed. They only fix the (batch, position) shape of the run.
        dummy_tokens = torch.zeros(
            embeddings.shape[:2], dtype=torch.long, device=embeddings.device
        )

        def inject_embeddings(value, hook):
            # Replace the token embedding with the DT embeddings.
            return embeddings

        # Residual stream after the final transformer block.
        last_resid_hook = f"blocks.{self.cfg.n_layers - 1}.hook_resid_post"
        captured = {}

        def capture_resid(value, hook):
            # Keep the live (non-detached) activation so the action loss can
            # backpropagate through the transformer and the DT embeddings.
            # Reading it from run_with_cache instead would give a detached
            # copy and cut the graph.
            captured["resid"] = value
            return value

        fwd_hooks = [
            ("hook_embed", inject_embeddings),
            (last_resid_hook, capture_resid),
        ]
        cache = None
        with self.transformer.hooks(fwd_hooks=fwd_hooks):
            if return_cache:
                _, cache = self.transformer.run_with_cache(dummy_tokens)
            else:
                self.transformer(dummy_tokens)

        x = captured["resid"].reshape(batch_size, seq_len, 3, self.cfg.d_model)
        # Index 1 of the token axis is the state token; it predicts the action.
        action_preds = self.predict_action(x[:, :, 1])
        if return_cache:
            return action_preds, cache
        return action_preds

    @classmethod
    def from_config(cls, state_dim, action_dim, n_layers=2, n_heads=4, d_model=128, max_length=30):
        """
        Build a HookedDT with a small default TransformerLens config.

        The context length is ``3 * max_length`` (three tokens per timestep).
        ``d_vocab`` is a dummy value, because ``forward`` only feeds token id
        0 and replaces its embedding at ``hook_embed``. ``d_head`` is
        ``d_model // n_heads``. The device is CUDA when available, else CPU.

        Args:
            state_dim: Size of the state vectors.
            action_dim: Size of the action vectors.
            n_layers: Number of transformer blocks.
            n_heads: Number of attention heads per block.
            d_model: Residual stream width.
            max_length: Maximum number of timesteps per forward pass.

        Returns:
            A new ``HookedDT`` instance.
        """
        cfg = HookedTransformerConfig(
            n_layers=n_layers,
            d_model=d_model,
            n_ctx=3 * max_length,
            d_head=d_model // n_heads,
            n_heads=n_heads,
            d_vocab=10,  # Dummy vocab size
            act_fn="relu",
            d_mlp=d_model * 4,
            normalization_type="LN",
            use_attn_result=True,
            device="cuda" if torch.cuda.is_available() else "cpu",
        )
        return cls(cfg, state_dim, action_dim, max_length=max_length)
|