# Page-scrape header (not program text): Spaces: Running; File size: 2,689 Bytes
import os
import sys
from pathlib import Path
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from tqdm import tqdm
# Put the directory two levels above this file on sys.path (once) so that the
# absolute `src.*` imports that follow can be resolved.
root_path = str(Path(__file__).resolve().parents[1])
if root_path not in sys.path:
    sys.path.append(root_path)
from src.models.hooked_dt import HookedDT
from src.data.harvester import PPOHarvester
from src.config import cfg
def _build_batch(traj, max_len, action_dim):
    """Turn one trajectory dict into the tensors used for a single training step.

    Only the last ``max_len`` timesteps of each array are kept.

    Args:
        traj: Mapping with NumPy arrays under the keys ``"observations"``,
            ``"actions"`` and ``"rewards"``.
        max_len: Maximum number of trailing timesteps to keep.
        action_dim: Number of discrete actions (width of the one-hot encoding).

    Returns:
        Tuple ``(states, actions, actions_one_hot, returns)``. ``actions`` holds
        the integer action ids used as the loss target; the other three carry a
        leading batch dimension of size 1.
    """
    # presumably observations are (T, state_dim) — the caller reads shape[1]
    # as the state dimension; TODO confirm against PPOHarvester.
    states = torch.from_numpy(traj["observations"]).float().unsqueeze(0)[:, -max_len:]
    actions = torch.from_numpy(traj["actions"]).long()[-max_len:]
    actions_one_hot = (
        torch.nn.functional.one_hot(actions, num_classes=action_dim)
        .float()
        .unsqueeze(0)
    )
    # NOTE(review): the per-step rewards are passed to the model under the name
    # `returns`. If HookedDT expects returns-to-go, they would have to be
    # computed here — confirm against the model and the dashboard before changing.
    returns = (
        torch.from_numpy(traj["rewards"]).float().unsqueeze(0).unsqueeze(-1)[:, -max_len:]
    )
    return states, actions, actions_one_hot, returns


def train():
    """Main training loop for Decision Transformer.

    Steps:
        1. Collect ``cfg.data.num_episodes`` trajectories with ``PPOHarvester``
           and save them to ``data/trajectories.pt``.
        2. Run ``cfg.train.epochs`` epochs, taking one AdamW step per trajectory
           on the cross-entropy between predicted and recorded actions.
        3. Save the model's ``state_dict`` to ``models/mini_dt.pt``.

    Raises:
        ValueError: If the harvester returned no trajectories.
    """
    # Step 1: Collect data from expert PPO teacher
    harvester = PPOHarvester(env_id=cfg.data.env_id, model_path="models/ppo_teacher.zip")
    trajectories = harvester.collect_trajectories(num_episodes=cfg.data.num_episodes)

    # Save trajectories for the dashboard to use later
    harvester.save_trajectories(trajectories, "data/trajectories.pt")

    # Fail with a clear message instead of an IndexError on trajectories[0]
    # (or a ZeroDivisionError when averaging the loss) if nothing was collected.
    if len(trajectories) == 0:
        raise ValueError("No trajectories were collected; cannot train the model.")

    state_dim = trajectories[0]["observations"].shape[1]
    action_dim = cfg.model.action_dim

    model = HookedDT.from_config(
        state_dim=state_dim,
        action_dim=action_dim,
        n_layers=cfg.model.n_layers,
        n_heads=cfg.model.n_heads,
        d_model=cfg.model.d_model,
        max_length=cfg.model.max_length,
    )

    optimizer = optim.AdamW(model.parameters(), lr=cfg.train.lr)
    criterion = nn.CrossEntropyLoss()

    # Step 2: Train the DT
    model.train()
    # Truncation length is loop-invariant, so read it once.
    max_len = model.max_length

    for epoch in range(cfg.train.epochs):
        total_loss = 0.0
        for traj in tqdm(trajectories, desc=f"Epoch {epoch}"):
            states, actions, actions_one_hot, returns = _build_batch(
                traj, max_len, action_dim
            )

            # Predict actions based on State tokens
            action_preds = model(states, actions_one_hot, returns)

            # Cross entropy loss on predicted actions. reshape (not view) so a
            # non-contiguous model output does not raise.
            loss = criterion(
                action_preds.reshape(-1, action_dim), actions.reshape(-1)
            )

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print(f"Epoch {epoch} Loss: {total_loss / len(trajectories)}")

    # Step 3: Save the trained model
    os.makedirs("models", exist_ok=True)
    torch.save(model.state_dict(), "models/mini_dt.pt")
    print("Model saved to models/mini_dt.pt")
# Run training only when this file is executed directly, not when imported.
if __name__ == "__main__":
    train()