File size: 4,327 Bytes
e2614dc
 
 
 
 
 
 
 
8577352
e2614dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8577352
e2614dc
 
8577352
e2614dc
 
 
 
 
 
8577352
 
 
e2614dc
8577352
 
e2614dc
 
8577352
 
 
e2614dc
8577352
 
 
 
e2614dc
8577352
 
4aa19e7
 
 
 
 
8577352
 
 
e2614dc
8577352
 
e2614dc
8577352
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e2614dc
 
b7ddfc6
e2614dc
 
 
b7ddfc6
e2614dc
 
8577352
731ae64
e2614dc
 
11dbbc6
e2614dc
 
b7ddfc6
8577352
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import torch
import torch.nn as nn
from transformer_lens import HookedTransformer, HookedTransformerConfig
from jaxtyping import Float, Int
from typing import Optional, Union, List

class HookedDT(nn.Module):
    """
    Decision Transformer wrapped in TransformerLens logic.
    Supports State, Action, and Reward-to-Go (RTG) tokens.

    Each timestep contributes three tokens, interleaved as
    ``[R1, S1, A1, R2, S2, A2, ...]``. The tokens are embedded by
    DT-specific linear layers and injected into the wrapped
    ``HookedTransformer`` at its ``hook_embed`` hook point, replacing the
    output of the transformer's own token embedding. Actions are read out
    from the last block's residual stream at the state-token positions.
    """
    def __init__(
        self,
        cfg: "HookedTransformerConfig",
        state_dim: int,
        action_dim: int,
        max_length: int = 30,
    ):
        """
        Args:
            cfg: TransformerLens config for the wrapped transformer.
            state_dim: Size of the last dimension of the state input.
            action_dim: Size of the last dimension of the action input.
            max_length: Maximum number of timesteps kept by ``forward``;
                each timestep uses three token positions.
        """
        super().__init__()
        self.cfg = cfg
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.max_length = max_length

        # Core transformer blocks from TransformerLens
        self.transformer = HookedTransformer(cfg)

        # DT-specific embeddings
        self.embed_return = nn.Linear(1, cfg.d_model)
        self.embed_state = nn.Linear(state_dim, cfg.d_model)
        self.embed_action = nn.Linear(action_dim, cfg.d_model)
        self.embed_ln = nn.LayerNorm(cfg.d_model)

        # Prediction heads (only predict_action is used by forward)
        self.predict_action = nn.Sequential(nn.Linear(cfg.d_model, action_dim))
        self.predict_return = nn.Sequential(nn.Linear(cfg.d_model, 1))
        self.predict_state = nn.Sequential(nn.Linear(cfg.d_model, state_dim))

        # HookedTransformer moves itself to cfg.device on construction (its
        # default), whereas plain nn layers are created on the CPU. Put the
        # DT-specific layers on the same device so the injected embeddings
        # and the readout heads match the transformer. Only these layers are
        # moved, leaving the transformer's own placement untouched.
        device = getattr(cfg, "device", None)
        if device is not None:
            for layer in (
                self.embed_return,
                self.embed_state,
                self.embed_action,
                self.embed_ln,
                self.predict_action,
                self.predict_return,
                self.predict_state,
            ):
                layer.to(device)

    def get_embeddings(self, states, actions, returns_to_go):
        """Interleaves RTG, State, and Action embeddings.

        Args:
            states: Tensor of shape (batch, seq_len, state_dim).
            actions: Tensor whose last dimension is action_dim; must share
                the batch and sequence dimensions of ``states``.
            returns_to_go: Tensor whose last dimension is 1; must share
                the batch and sequence dimensions of ``states``.

        Returns:
            Layer-normalised embeddings of shape
            (batch, 3 * seq_len, d_model).
        """
        batch_size, seq_len, _ = states.shape

        ret_emb = self.embed_return(returns_to_go)
        state_emb = self.embed_state(states)
        act_emb = self.embed_action(actions)

        # Interleave: [R1, S1, A1, R2, S2, A2, ...]
        # Stacking on a new axis 2 and then flattening axes 1-2 places the
        # three tokens of each timestep next to each other.
        stacked = torch.stack((ret_emb, state_emb, act_emb), dim=2)
        stacked = stacked.reshape(batch_size, 3 * seq_len, self.cfg.d_model)
        return self.embed_ln(stacked)

    def _state_token_action_preds(self, last_resid, batch_size, seq_len):
        """Apply the action head to the state-token positions of ``last_resid``."""
        x = last_resid.reshape(batch_size, seq_len, 3, self.cfg.d_model)
        # Index 1 of the (R, S, A) triple: the state token predicts the action.
        return self.predict_action(x[:, :, 1])

    def forward(self, states, actions, returns_to_go, timesteps=None, return_cache=False):
        """Forward pass through DT.

        Args:
            states: Tensor of shape (batch, seq_len, state_dim).
            actions: Tensor of shape (batch, seq_len, action_dim).
            returns_to_go: Tensor of shape (batch, seq_len, 1).
            timesteps: Accepted for interface compatibility but not used
                here; positions are handled by the wrapped transformer.
            return_cache: When True, also return the full activation cache
                of the wrapped transformer.

        Returns:
            ``action_preds`` of shape (batch, kept_len, action_dim), where
            ``kept_len`` is at most ``max_length``; or the tuple
            ``(action_preds, cache)`` when ``return_cache`` is True.
        """
        # Truncate to max_length to fit within transformer context
        states = states[:, -self.max_length:]
        actions = actions[:, -self.max_length:]
        returns_to_go = returns_to_go[:, -self.max_length:]

        embeddings = self.get_embeddings(states, actions, returns_to_go)
        batch_size, n_tokens = embeddings.shape[:2]
        # Token ids are placeholders: they only give the transformer the
        # right (batch, position) shape, since hook_embed's output is replaced.
        dummy_tokens = torch.zeros(
            (batch_size, n_tokens), dtype=torch.long, device=embeddings.device
        )

        def inject_embeddings(value, hook):
            return embeddings

        # We need the residual stream post-processing from the last block
        last_resid_hook = f"blocks.{self.cfg.n_layers - 1}.hook_resid_post"

        # Cache everything when the caller wants the cache; otherwise only
        # the single activation needed for the action readout.
        if return_cache:
            names_filter = None
        else:
            def names_filter(name):
                return name == last_resid_hook

        with self.transformer.hooks(fwd_hooks=[("hook_embed", inject_embeddings)]):
            # return_type=None: the logits over the dummy vocab are never
            # used, so skip computing them.
            _, cache = self.transformer.run_with_cache(
                dummy_tokens,
                names_filter=names_filter,
                return_type=None,
            )

        action_preds = self._state_token_action_preds(
            cache[last_resid_hook], batch_size, states.shape[1]
        )
        if return_cache:
            return action_preds, cache
        return action_preds

    @classmethod
    def from_config(cls, state_dim, action_dim, n_layers=2, n_heads=4, d_model=128, max_length=30):
        """Build a HookedDT with a freshly created HookedTransformerConfig.

        The context length is set to ``3 * max_length`` (three tokens per
        timestep) and the device to CUDA when available, otherwise CPU.

        Args:
            state_dim: Size of the last dimension of the state input.
            action_dim: Size of the last dimension of the action input.
            n_layers: Number of transformer blocks.
            n_heads: Number of attention heads; ``d_head`` is
                ``d_model // n_heads``.
            d_model: Residual stream width.
            max_length: Maximum number of timesteps per forward pass.

        Returns:
            A new ``HookedDT`` instance.
        """
        cfg = HookedTransformerConfig(
            n_layers=n_layers,
            d_model=d_model,
            n_ctx=3 * max_length,
            d_head=d_model // n_heads,
            n_heads=n_heads,
            d_vocab=10,  # Dummy vocab size
            act_fn="relu",
            d_mlp=d_model * 4,
            normalization_type="LN",
            use_attn_result=True,
            device="cuda" if torch.cuda.is_available() else "cpu",
        )
        return cls(cfg, state_dim, action_dim, max_length=max_length)