"""
Policy network and REINFORCE trainer.
The policy is trained using the *learned* reward model as a surrogate signal.
Rollouts are generated inside the *real* CartPole environment so that we don't
compound dynamics-model errors during policy optimisation.
Evaluation always uses the true environment reward.
"""
import torch
import torch.nn as nn
import numpy as np
import gymnasium as gym
from pathlib import Path
STATE_DIM = 4
ACTION_DIM = 2
class PolicyNetwork(nn.Module):
    """Two-hidden-layer tanh MLP mapping a CartPole state to a categorical
    distribution over the two discrete actions."""
    def __init__(self, hidden_dim: int = 64):
        super().__init__()
        # Attribute is named `net` so state_dict keys stay stable for save/load.
        layers = [
            nn.Linear(STATE_DIM, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, ACTION_DIM),
        ]
        self.net = nn.Sequential(*layers)
    def forward(self, state_t: torch.Tensor) -> torch.distributions.Categorical:
        """Return the action distribution for a batch of states."""
        return torch.distributions.Categorical(logits=self.net(state_t))
    def select_action(self, state: np.ndarray) -> tuple[int, torch.Tensor]:
        """Sample an action for a single (unbatched) state.

        Returns the action index and its log-probability — a tensor that
        still carries gradients, as REINFORCE requires.
        """
        batched = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
        dist = self(batched)
        choice = dist.sample()
        return int(choice.item()), dist.log_prob(choice)
class REINFORCETrainer:
    """
    REINFORCE (Williams 1992) using the learned reward model as reward signal.
    Rollouts are collected in the real CartPole environment (not the dynamics
    model) so we avoid compounding prediction errors. The learned reward model
    replaces the true reward at training time — the algorithm never sees it.
    """
    def __init__(
        self,
        policy: "PolicyNetwork",
        reward_model,
        lr: float = 1e-3,
        gamma: float = 0.99,
    ):
        """
        Args:
            policy: network optimised in place.
            reward_model: object exposing ``step_reward(state, action) -> float``
                (interface assumed from usage; defined elsewhere in the project).
            lr: Adam learning rate.
            gamma: discount factor for the return computation.
        """
        self.policy = policy
        self.reward_model = reward_model
        self.optimizer = torch.optim.Adam(policy.parameters(), lr=lr)
        self.gamma = gamma
    def _collect_episode(self, env) -> tuple[list, list]:
        """Collect one episode using learned reward, return (log_probs, returns).

        ``log_probs`` are per-step action log-probability tensors (with grad);
        ``returns`` are discounted returns G_t built from the *learned* rewards.
        """
        state, _ = env.reset()
        log_probs, rewards = [], []
        done = False
        while not done:
            action, log_prob = self.policy.select_action(state)
            next_state, _, terminated, truncated, _ = env.step(action)
            # Learned reward replaces the env reward during training.
            r = self.reward_model.step_reward(state, action)
            log_probs.append(log_prob)
            rewards.append(r)
            state = next_state
            done = terminated or truncated
        # Discounted returns computed backwards: G_t = r_t + gamma * G_{t+1}.
        G, returns = 0.0, []
        for r in reversed(rewards):
            G = r + self.gamma * G
            returns.insert(0, G)
        return log_probs, returns
    def _policy_loss(self, log_probs: list, returns: list) -> torch.Tensor:
        """REINFORCE loss: -sum_t log pi(a_t|s_t) * G_t.

        Returns are mean/std-normalised for variance reduction. Normalisation
        is skipped for single-step episodes: ``Tensor.std()`` (unbiased) is
        NaN for one element and would otherwise poison the gradients.
        """
        returns_t = torch.tensor(returns, dtype=torch.float32)
        if returns_t.numel() > 1:
            returns_t = (returns_t - returns_t.mean()) / (returns_t.std() + 1e-8)
        return -torch.stack([lp * R for lp, R in zip(log_probs, returns_t)]).sum()
    def train(self, n_episodes: int = 50) -> None:
        """Optimise the policy for ``n_episodes`` episodes of real CartPole."""
        env = gym.make("CartPole-v1")
        self.policy.train()
        for _ in range(n_episodes):
            log_probs, returns = self._collect_episode(env)
            loss = self._policy_loss(log_probs, returns)
            self.optimizer.zero_grad()
            loss.backward()
            # Clip gradients to stabilise early, high-variance updates.
            nn.utils.clip_grad_norm_(self.policy.parameters(), 1.0)
            self.optimizer.step()
        env.close()
        self.policy.eval()
    # ── Persistence ───────────────────────────────────────────────────────
    def save(self, path: str) -> None:
        """Save policy weights, creating parent directories as needed."""
        Path(path).parent.mkdir(parents=True, exist_ok=True)
        torch.save(self.policy.state_dict(), path)
    def load(self, path: str) -> None:
        """Load policy weights saved by :meth:`save` (CPU, weights only)."""
        self.policy.load_state_dict(torch.load(path, map_location="cpu", weights_only=True))
def evaluate_policy(policy: PolicyNetwork, n_episodes: int = 20) -> tuple[float, float]:
    """
    Evaluate using the TRUE CartPole environment reward.

    This is only for measurement — the algorithm never calls this during
    training. Returns the (mean, std) of the total per-episode return.
    """
    env = gym.make("CartPole-v1")
    policy.eval()
    episode_returns = []
    for _ in range(n_episodes):
        obs, _ = env.reset()
        episode_total = 0.0
        while True:
            # No gradients needed — evaluation only.
            with torch.no_grad():
                act, _ = policy.select_action(obs)
            obs, rew, terminated, truncated, _ = env.step(act)
            episode_total += rew
            if terminated or truncated:
                break
        episode_returns.append(episode_total)
    env.close()
    return float(np.mean(episode_returns)), float(np.std(episode_returns))