"""Lightweight PPO implementation for InferenceGym.
No external RL library dependency — just PyTorch.
Supports mixed action spaces via the PolicyNetwork heads.
Designed to train on CPU in <10 minutes for Task 1.
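
A minimal usage sketch. The env id and the GymEnvWrapper / PolicyNetwork
constructor arguments below are illustrative assumptions, not part of this
module's API; PPOTrainer itself is used exactly as defined in this file:

    env = GymEnvWrapper("CartPole-v1")
    policy = PolicyNetwork(env.observation_space, env.action_space)
    trainer = PPOTrainer(env, policy, rollout_length=512)
    trainer.train(total_steps=50_000, checkpoint_path="ppo.pt")
    trainer.save("ppo_final.pt")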
"""
from __future__ import annotations
import time
from dataclasses import dataclass, field
from typing import Any
import numpy as np
import torch
import torch.nn as nn
from rl.env_wrapper import GymEnvWrapper
from rl.policy_network import PolicyNetwork, action_dict_to_tensors, batch_action_tensors
@dataclass
class RolloutBuffer:
"""Stores one rollout of experience for PPO update."""
observations: list[np.ndarray] = field(default_factory=list)
actions: list[dict[str, Any]] = field(default_factory=list)
log_probs: list[torch.Tensor] = field(default_factory=list)
rewards: list[float] = field(default_factory=list)
dones: list[bool] = field(default_factory=list)
values: list[float] = field(default_factory=list)
def clear(self) -> None:
self.observations.clear()
self.actions.clear()
self.log_probs.clear()
self.rewards.clear()
self.dones.clear()
self.values.clear()
def __len__(self) -> int:
return len(self.rewards)
class PPOTrainer:
"""Proximal Policy Optimisation trainer."""
def __init__(
self,
env: GymEnvWrapper,
policy: PolicyNetwork,
*,
lr: float = 3e-4,
gamma: float = 0.99,
lam: float = 0.95,
clip_eps: float = 0.2,
entropy_coef: float = 0.01,
value_coef: float = 0.5,
max_grad_norm: float = 0.5,
rollout_length: int = 512,
ppo_epochs: int = 4,
minibatch_size: int = 64,
) -> None:
self.env = env
self.policy = policy
self.optimizer = torch.optim.Adam(policy.parameters(), lr=lr)
self.gamma = gamma
self.lam = lam
self.clip_eps = clip_eps
self.entropy_coef = entropy_coef
self.value_coef = value_coef
self.max_grad_norm = max_grad_norm
self.rollout_length = rollout_length
self.ppo_epochs = ppo_epochs
self.minibatch_size = minibatch_size
# State
self._obs: np.ndarray | None = None
self._total_steps = 0
self._episodes_done = 0
self._episode_reward = 0.0
def collect_rollout(self, buffer: RolloutBuffer) -> dict[str, float]:
"""Run self.rollout_length steps in the environment, filling the buffer."""
buffer.clear()
self.policy.eval()
episode_rewards: list[float] = []
if self._obs is None:
self._obs = self.env.reset()
self._episode_reward = 0.0
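        # Act without tracking gradients; the log-probs stored here serve as the
        # "old" policy probabilities in the PPO ratio during update().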
with torch.no_grad():
for _ in range(self.rollout_length):
obs_t = torch.from_numpy(self._obs).unsqueeze(0)
sample = self.policy.sample_action(obs_t)
_, value = self.policy.get_distributions(obs_t)
next_obs, reward, done, info = self.env.step(sample.action_dict)
buffer.observations.append(self._obs.copy())
buffer.actions.append(sample.action_dict)
buffer.log_probs.append(sample.log_prob.squeeze())
buffer.rewards.append(reward)
buffer.dones.append(done)
buffer.values.append(value.item())
self._obs = next_obs
self._total_steps += 1
self._episode_reward += reward
if done:
episode_rewards.append(self._episode_reward)
self._episodes_done += 1
self._obs = self.env.reset()
self._episode_reward = 0.0
# Bootstrap value for incomplete episode
with torch.no_grad():
obs_t = torch.from_numpy(self._obs).unsqueeze(0)
_, last_value = self.policy.get_distributions(obs_t)
last_value = last_value.item()
stats = {
"mean_reward": float(np.mean(episode_rewards)) if episode_rewards else 0.0,
"episodes": len(episode_rewards),
"total_steps": self._total_steps,
}
# Compute GAE
self._compute_gae(buffer, last_value)
return stats
def _compute_gae(self, buffer: RolloutBuffer, last_value: float) -> None:
"""Compute generalized advantage estimates in-place."""
n = len(buffer)
advantages = np.zeros(n, dtype=np.float32)
returns = np.zeros(n, dtype=np.float32)
gae = 0.0
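        # Backward recursion: delta_t = r_t + gamma * V(s_{t+1}) * mask - V(s_t),
        # A_t = delta_t + gamma * lam * mask * A_{t+1}. The mask (0 when dones[t]
        # is set) cuts both the bootstrap and the advantage flow at episode ends,
        # so the value of a freshly reset observation is never used.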
        for t in reversed(range(n)):
            next_value = last_value if t == n - 1 else buffer.values[t + 1]
            mask = 0.0 if buffer.dones[t] else 1.0
            delta = buffer.rewards[t] + self.gamma * next_value * mask - buffer.values[t]
            gae = delta + self.gamma * self.lam * mask * gae
            advantages[t] = gae
            returns[t] = gae + buffer.values[t]
# Store as attributes for update
buffer._advantages = advantages # type: ignore[attr-defined]
buffer._returns = returns # type: ignore[attr-defined]
def update(self, buffer: RolloutBuffer) -> dict[str, float]:
"""Run PPO update on the collected rollout buffer."""
self.policy.train()
n = len(buffer)
# Prepare tensors
obs_batch = torch.from_numpy(np.stack(buffer.observations))
old_log_probs = torch.stack(buffer.log_probs).detach()
action_tensors = batch_action_tensors(
[action_dict_to_tensors(a) for a in buffer.actions]
)
advantages = torch.from_numpy(buffer._advantages) # type: ignore[attr-defined]
returns = torch.from_numpy(buffer._returns) # type: ignore[attr-defined]
# Normalise advantages
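        # (zero-mean / unit-variance scaling of the whole rollout keeps the
        # policy-gradient magnitude roughly independent of the reward scale)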
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
total_pg_loss = 0.0
total_vf_loss = 0.0
total_entropy = 0.0
num_updates = 0
for _ in range(self.ppo_epochs):
# Create random minibatch indices
indices = np.random.permutation(n)
for start in range(0, n, self.minibatch_size):
end = min(start + self.minibatch_size, n)
idx = indices[start:end]
idx_t = torch.from_numpy(idx).long()
mb_obs = obs_batch[idx_t]
mb_old_log_probs = old_log_probs[idx_t]
mb_advantages = advantages[idx_t]
mb_returns = returns[idx_t]
mb_actions = {k: v[idx_t] for k, v in action_tensors.items()}
new_log_probs, entropy, values = self.policy.evaluate_actions(mb_obs, mb_actions)
# PPO clipped objective
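                # ratio = pi_new(a|s) / pi_old(a|s), formed in log space for stability;
                # the clipped surrogate bounds how far one update can move the policy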
ratio = torch.exp(new_log_probs - mb_old_log_probs)
surr1 = ratio * mb_advantages
surr2 = torch.clamp(ratio, 1.0 - self.clip_eps, 1.0 + self.clip_eps) * mb_advantages
pg_loss = -torch.min(surr1, surr2).mean()
# Value loss
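                # (critic regressed toward the GAE returns from _compute_gae)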
                # squeeze a possible (B, 1) critic output so it matches mb_returns
                vf_loss = nn.functional.mse_loss(values.squeeze(-1), mb_returns)
# Entropy bonus
entropy_loss = -entropy.mean()
loss = pg_loss + self.value_coef * vf_loss + self.entropy_coef * entropy_loss
self.optimizer.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
self.optimizer.step()
total_pg_loss += pg_loss.item()
total_vf_loss += vf_loss.item()
total_entropy += entropy.mean().item()
num_updates += 1
return {
"pg_loss": total_pg_loss / max(num_updates, 1),
"vf_loss": total_vf_loss / max(num_updates, 1),
"entropy": total_entropy / max(num_updates, 1),
}
def train(
self,
total_steps: int,
log_interval: int = 2000,
checkpoint_interval: int = 10000,
checkpoint_path: str | None = None,
) -> list[dict[str, float]]:
"""Main training loop. Returns history of stats per rollout."""
history: list[dict[str, float]] = []
buffer = RolloutBuffer()
start_time = time.time()
last_log_step = 0
while self._total_steps < total_steps:
rollout_stats = self.collect_rollout(buffer)
update_stats = self.update(buffer)
combined = {**rollout_stats, **update_stats}
history.append(combined)
# Log progress
if self._total_steps - last_log_step >= log_interval:
elapsed = time.time() - start_time
sps = self._total_steps / max(elapsed, 1.0)
print(
f"[TRAIN] steps={self._total_steps:>7d}/{total_steps} "
f"episodes={self._episodes_done:>4d} "
f"mean_reward={combined['mean_reward']:>7.3f} "
f"pg_loss={combined['pg_loss']:.4f} "
f"entropy={combined['entropy']:.2f} "
f"sps={sps:.0f}"
)
last_log_step = self._total_steps
# Checkpoint
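            # (steps advance in rollout_length chunks, so this fires roughly
            # once every checkpoint_interval environment steps)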
if checkpoint_path and self._total_steps % checkpoint_interval < self.rollout_length:
self.save(checkpoint_path.replace(".pt", f"_step{self._total_steps}.pt"))
elapsed = time.time() - start_time
print(f"[TRAIN] Done. Total steps: {self._total_steps}, Time: {elapsed:.1f}s")
return history
def save(self, path: str) -> None:
"""Save policy weights and normalizer state."""
state = {"policy": self.policy.state_dict()}
if self.env.normalizer is not None:
state["normalizer"] = self.env.normalizer.state_dict()
torch.save(state, path)
print(f"[SAVE] Weights saved to {path}")
def load(self, path: str) -> None:
"""Load policy weights and normalizer state."""
state = torch.load(path, map_location="cpu", weights_only=False)
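        # weights_only=False permits non-tensor objects (e.g. normalizer state)
        # in the checkpoint; only load files from a trusted source.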
self.policy.load_state_dict(state["policy"])
if "normalizer" in state and self.env.normalizer is not None:
self.env.normalizer.load_state_dict(state["normalizer"])
print(f"[LOAD] Weights loaded from {path}")