singhanshuman committed
Commit 0390c67 · verified · 1 Parent(s): e7d0ac5

Upload simoprl/policy.py with huggingface_hub

Files changed (1)
  1. simoprl/policy.py +134 -0
simoprl/policy.py ADDED
@@ -0,0 +1,134 @@
+ """
+ Policy network and REINFORCE trainer.
+
+ The policy is trained using the *learned* reward model as a surrogate signal.
+ Rollouts are generated inside the *real* CartPole environment so that we don't
+ compound dynamics-model errors during policy optimisation.
+ Evaluation always uses the true environment reward.
+ """
+ import torch
+ import torch.nn as nn
+ import numpy as np
+ import gymnasium as gym
+ from pathlib import Path
+
+ STATE_DIM = 4
+ ACTION_DIM = 2
+
+
+ class PolicyNetwork(nn.Module):
+     def __init__(self, hidden_dim: int = 64):
+         super().__init__()
+         self.net = nn.Sequential(
+             nn.Linear(STATE_DIM, hidden_dim),
+             nn.Tanh(),
+             nn.Linear(hidden_dim, hidden_dim),
+             nn.Tanh(),
+             nn.Linear(hidden_dim, ACTION_DIM),
+         )
+
+     def forward(self, state_t: torch.Tensor) -> torch.distributions.Categorical:
+         logits = self.net(state_t)
+         return torch.distributions.Categorical(logits=logits)
+
+     def select_action(self, state: np.ndarray) -> tuple[int, torch.Tensor]:
+         s = torch.from_numpy(state.astype(np.float32)).unsqueeze(0)
+         dist = self(s)
+         action = dist.sample()
+         return int(action.item()), dist.log_prob(action)
+
+
+ class REINFORCETrainer:
+     """
+     REINFORCE (Williams, 1992) using the learned reward model as the reward signal.
+
+     Rollouts are collected in the real CartPole environment (not the dynamics
+     model) so we avoid compounding prediction errors. The learned reward model
+     replaces the true reward at training time; the policy never sees the true reward.
+     """
+
+     def __init__(
+         self,
+         policy: PolicyNetwork,
+         reward_model,
+         lr: float = 1e-3,
+         gamma: float = 0.99,
+     ):
+         self.policy = policy
+         self.reward_model = reward_model
+         self.optimizer = torch.optim.Adam(policy.parameters(), lr=lr)
+         self.gamma = gamma
+
+     def _collect_episode(self, env) -> tuple[list, list]:
+         """Collect one episode using the learned reward; return (log_probs, returns)."""
+         state, _ = env.reset()
+         log_probs, rewards = [], []
+
+         done = False
+         while not done:
+             action, log_prob = self.policy.select_action(state)
+             next_state, _, terminated, truncated, _ = env.step(action)
+             r = self.reward_model.step_reward(state, action)  # learned reward
+             log_probs.append(log_prob)
+             rewards.append(r)
+             state = next_state
+             done = terminated or truncated
+
+         # Discounted returns: G_t = r_t + gamma * G_{t+1}, computed backwards
+         G, returns = 0.0, []
+         for r in reversed(rewards):
+             G = r + self.gamma * G
+             returns.insert(0, G)
+
+         return log_probs, returns
+
+     def train(self, n_episodes: int = 50) -> None:
+         env = gym.make("CartPole-v1")
+         self.policy.train()
+
+         for ep in range(n_episodes):
+             log_probs, returns = self._collect_episode(env)
+             returns_t = torch.tensor(returns, dtype=torch.float32)
+             returns_t = (returns_t - returns_t.mean()) / (returns_t.std() + 1e-8)  # variance reduction
+
+             loss = -torch.stack([lp * R for lp, R in zip(log_probs, returns_t)]).sum()
+             self.optimizer.zero_grad()
+             loss.backward()
+             nn.utils.clip_grad_norm_(self.policy.parameters(), 1.0)
+             self.optimizer.step()
+
+         env.close()
+         self.policy.eval()
+
+     # ── Persistence ───────────────────────────────────────────────────────
+
+     def save(self, path: str) -> None:
+         Path(path).parent.mkdir(parents=True, exist_ok=True)
+         torch.save(self.policy.state_dict(), path)
+
+     def load(self, path: str) -> None:
+         self.policy.load_state_dict(torch.load(path, map_location="cpu", weights_only=True))
+
+
+ def evaluate_policy(policy: PolicyNetwork, n_episodes: int = 20) -> tuple[float, float]:
+     """
+     Evaluate using the TRUE CartPole environment reward.
+     This is only for measurement; it is never called during training.
+     """
+     env = gym.make("CartPole-v1")
+     policy.eval()
+     returns = []
+
+     for _ in range(n_episodes):
+         state, _ = env.reset()
+         total, done = 0.0, False
+         while not done:
+             with torch.no_grad():
+                 action, _ = policy.select_action(state)
+             state, reward, terminated, truncated, _ = env.step(action)
+             total += reward
+             done = terminated or truncated
+         returns.append(total)
+
+     env.close()
+     return float(np.mean(returns)), float(np.std(returns))
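
For context, a minimal usage sketch of the uploaded module (not part of the commit). It assumes only what the file itself requires: a reward-model object exposing step_reward(state, action) -> float, the sole method REINFORCETrainer calls on it. ConstantReward and the checkpoint path below are hypothetical placeholders.

    from simoprl.policy import PolicyNetwork, REINFORCETrainer, evaluate_policy

    class ConstantReward:
        # Hypothetical stand-in for the learned reward model. Returning 1.0
        # mimics CartPole's true per-step reward, just to exercise the loop.
        def step_reward(self, state, action) -> float:
            return 1.0

    policy = PolicyNetwork(hidden_dim=64)
    trainer = REINFORCETrainer(policy, reward_model=ConstantReward(), lr=1e-3, gamma=0.99)
    trainer.train(n_episodes=50)           # optimises against the surrogate reward
    trainer.save("checkpoints/policy.pt")  # hypothetical output path

    mean_ret, std_ret = evaluate_policy(policy, n_episodes=20)  # true env reward
    print(f"true return: {mean_ret:.1f} +/- {std_ret:.1f}")

Note that evaluate_policy samples actions from the policy distribution rather than taking the argmax, so the reported mean return measures the same stochastic behaviour that training optimises.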