# simoprl/policy.py
"""
Policy network and REINFORCE trainer.
The policy is trained using the *learned* reward model as a surrogate signal.
Rollouts are generated inside the *real* CartPole environment so that we don't
compound dynamics-model errors during policy optimisation.
Evaluation always uses the true environment reward.
"""
import torch
import torch.nn as nn
import numpy as np
import gymnasium as gym
from pathlib import Path
STATE_DIM = 4
ACTION_DIM = 2
class PolicyNetwork(nn.Module):
def __init__(self, hidden_dim: int = 64):
super().__init__()
self.net = nn.Sequential(
nn.Linear(STATE_DIM, hidden_dim),
nn.Tanh(),
nn.Linear(hidden_dim, hidden_dim),
nn.Tanh(),
nn.Linear(hidden_dim, ACTION_DIM),
)
def forward(self, state_t: torch.Tensor) -> torch.distributions.Categorical:
logits = self.net(state_t)
return torch.distributions.Categorical(logits=logits)
def select_action(self, state: np.ndarray) -> tuple[int, torch.Tensor]:
s = torch.from_numpy(state.astype(np.float32)).unsqueeze(0)
dist = self(s)
action = dist.sample()
return int(action.item()), dist.log_prob(action)
class REINFORCETrainer:
"""
REINFORCE (Williams 1992) using the learned reward model as reward signal.
Rollouts are collected in the real CartPole environment (not the dynamics
model) so we avoid compounding prediction errors. The learned reward model
    replaces the true reward at training time; the algorithm never sees the true reward.
"""
def __init__(
self,
policy: PolicyNetwork,
reward_model,
lr: float = 1e-3,
gamma: float = 0.99,
):
self.policy = policy
self.reward_model = reward_model
self.optimizer = torch.optim.Adam(policy.parameters(), lr=lr)
self.gamma = gamma
def _collect_episode(self, env) -> tuple[list, list]:
"""Collect one episode using learned reward, return (log_probs, returns)."""
state, _ = env.reset()
log_probs, rewards = [], []
done = False
while not done:
action, log_prob = self.policy.select_action(state)
next_state, _, terminated, truncated, _ = env.step(action)
r = self.reward_model.step_reward(state, action) # learned reward
log_probs.append(log_prob)
rewards.append(r)
state = next_state
done = terminated or truncated
# Discounted returns
G, returns = 0.0, []
for r in reversed(rewards):
G = r + self.gamma * G
returns.insert(0, G)
return log_probs, returns
def train(self, n_episodes: int = 50) -> None:
env = gym.make("CartPole-v1")
self.policy.train()
for ep in range(n_episodes):
log_probs, returns = self._collect_episode(env)
returns_t = torch.tensor(returns, dtype=torch.float32)
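            # Normalize returns within the episode (mean baseline, unit variance) to reduce gradient variance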
returns_t = (returns_t - returns_t.mean()) / (returns_t.std() + 1e-8)
loss = -torch.stack([lp * R for lp, R in zip(log_probs, returns_t)]).sum()
self.optimizer.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(self.policy.parameters(), 1.0)
self.optimizer.step()
env.close()
self.policy.eval()
# ── Persistence ───────────────────────────────────────────────────────
def save(self, path: str) -> None:
Path(path).parent.mkdir(parents=True, exist_ok=True)
torch.save(self.policy.state_dict(), path)
def load(self, path: str) -> None:
self.policy.load_state_dict(torch.load(path, map_location="cpu", weights_only=True))
def evaluate_policy(policy: PolicyNetwork, n_episodes: int = 20) -> tuple[float, float]:
"""
Evaluate using the TRUE CartPole environment reward.
    This is only for measurement; the training loop never calls this function.
"""
env = gym.make("CartPole-v1")
policy.eval()
returns = []
for _ in range(n_episodes):
state, _ = env.reset()
total, done = 0.0, False
while not done:
with torch.no_grad():
action, _ = policy.select_action(state)
state, reward, terminated, truncated, _ = env.step(action)
total += reward
done = terminated or truncated
returns.append(total)
env.close()
return float(np.mean(returns)), float(np.std(returns))
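# ── Usage sketch (illustrative only) ──────────────────────────────────
# A minimal end-to-end example of how the pieces above fit together. The real
# project supplies a learned reward model elsewhere; the ConstantReward stub
# below is a hypothetical stand-in that exists only to make this sketch
# self-contained and runnable. Hyperparameters are illustrative.
if __name__ == "__main__":

    class ConstantReward:
        """Placeholder for a learned reward model: always predicts reward 1.0."""

        def step_reward(self, state, action) -> float:
            return 1.0

    policy = PolicyNetwork()
    trainer = REINFORCETrainer(policy, reward_model=ConstantReward())
    trainer.train(n_episodes=20)  # optimize against the (stub) learned reward
    mean_ret, std_ret = evaluate_policy(policy, n_episodes=5)  # measure with the true reward
    print(f"True-env return: {mean_ret:.1f} ± {std_ret:.1f}")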