File size: 2,689 Bytes
8577352
4aa19e7
 
e2614dc
 
 
 
 
 
4aa19e7
 
 
 
 
8577352
 
b7ddfc6
8577352
e2614dc
8577352
 
b7ddfc6
 
e2614dc
8577352
 
 
e2614dc
b7ddfc6
e2614dc
 
 
 
b7ddfc6
 
 
 
e2614dc
 
b7ddfc6
e2614dc
 
8577352
e2614dc
b7ddfc6
e2614dc
 
b7ddfc6
 
 
 
8577352
b7ddfc6
e2614dc
8577352
 
e2614dc
8577352
e2614dc
 
 
 
 
 
 
 
 
 
8577352
 
e2614dc
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import os
import sys
from pathlib import Path
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from tqdm import tqdm

# Put the repository root (the parent of this file's directory) on sys.path so
# the absolute ``src.*`` imports that follow resolve when this file is run
# directly as a script; skip the append if it is already present.
root_path = str(Path(__file__).resolve().parents[1])
if root_path not in sys.path:
    sys.path.append(root_path)

from src.models.hooked_dt import HookedDT
from src.data.harvester import PPOHarvester
from src.config import cfg

def _returns_to_go(rewards: torch.Tensor) -> torch.Tensor:
    """Return the reverse cumulative sum of ``rewards`` along dim 0.

    Element ``t`` of the result is ``rewards[t:].sum()``, i.e. the
    return-to-go from step ``t`` to the end of the episode, which is the
    quantity a Decision Transformer conditions on.
    """
    # flip -> cumsum -> flip is a suffix sum; done in torch to avoid the
    # negative-stride arrays that numpy's [::-1] would hand to from_numpy.
    return torch.flip(torch.cumsum(torch.flip(rewards, dims=[0]), dim=0), dims=[0])


def train() -> None:
    """Main training loop for Decision Transformer.

    Pipeline:
      1. Roll out the PPO teacher loaded from ``models/ppo_teacher.zip`` for
         ``cfg.data.num_episodes`` episodes and save the trajectories to
         ``data/trajectories.pt``.
      2. Build a ``HookedDT`` from ``cfg.model`` and train it with AdamW
         (``cfg.train.lr``) and cross-entropy on the teacher's discrete
         actions, one trajectory per optimizer step, for
         ``cfg.train.epochs`` epochs.
      3. Save the trained ``state_dict`` to ``models/mini_dt.pt``.

    Raises:
        ValueError: if the harvester returns no trajectories.
    """
    # Step 1: Collect data from expert PPO teacher
    harvester = PPOHarvester(env_id=cfg.data.env_id, model_path="models/ppo_teacher.zip")
    trajectories = harvester.collect_trajectories(num_episodes=cfg.data.num_episodes)
    if not trajectories:
        # Fail loudly: otherwise trajectories[0] raises a bare IndexError and the
        # per-epoch average loss would divide by zero.
        raise ValueError("PPOHarvester returned no trajectories; nothing to train on.")

    # Save trajectories for the dashboard to use later
    os.makedirs("data", exist_ok=True)
    harvester.save_trajectories(trajectories, "data/trajectories.pt")

    state_dim = trajectories[0]["observations"].shape[1]
    action_dim = cfg.model.action_dim

    model = HookedDT.from_config(
        state_dim=state_dim,
        action_dim=action_dim,
        n_layers=cfg.model.n_layers,
        n_heads=cfg.model.n_heads,
        d_model=cfg.model.d_model,
        max_length=cfg.model.max_length,
    )

    optimizer = optim.AdamW(model.parameters(), lr=cfg.train.lr)
    criterion = nn.CrossEntropyLoss()

    # Context length is fixed for the whole run; read it once.
    max_len = model.max_length

    # Step 2: Train the DT
    model.train()
    for epoch in range(cfg.train.epochs):
        total_loss = 0.0
        for traj in tqdm(trajectories, desc=f"Epoch {epoch}"):
            # Truncate to match model max_length.
            # NOTE(review): only the final ``max_len`` steps of each episode are
            # trained on; earlier steps are never seen. Sampling random windows
            # would cover the whole episode.
            states = torch.from_numpy(traj["observations"]).float().unsqueeze(0)[:, -max_len:]
            actions = torch.from_numpy(traj["actions"]).long()[-max_len:]
            actions_one_hot = torch.nn.functional.one_hot(actions, num_classes=action_dim).float().unsqueeze(0)

            # Condition on return-to-go, not raw per-step rewards. It is computed
            # over the full episode and then truncated, so the values stay correct
            # for whatever window is kept. Assumes traj["rewards"] is 1-D (T,),
            # matching the original unsqueeze(0).unsqueeze(-1) -> (1, T, 1).
            rewards = torch.from_numpy(traj["rewards"]).float()
            returns = _returns_to_go(rewards).unsqueeze(0).unsqueeze(-1)[:, -max_len:]

            # Predict actions based on State tokens.
            # Assumes HookedDT masks attention causally so the action token for
            # step t cannot leak into the prediction of a_t; TODO confirm.
            action_preds = model(states, actions_one_hot, returns)

            # Cross entropy loss on predicted actions; reshape (not view) tolerates
            # non-contiguous model outputs.
            loss = criterion(action_preds.reshape(-1, action_dim), actions.reshape(-1))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print(f"Epoch {epoch} Loss: {total_loss / len(trajectories)}")

    # Step 3: Save the trained model
    os.makedirs("models", exist_ok=True)
    torch.save(model.state_dict(), "models/mini_dt.pt")
    print("Model saved to models/mini_dt.pt")


if __name__ == "__main__":
    train()