"""
Policy network and REINFORCE trainer.

The policy is trained using the *learned* reward model as a surrogate signal.
Rollouts are generated inside the *real* CartPole environment so that we don't
compound dynamics-model errors during policy optimisation.
Evaluation always uses the true environment reward.
"""
import torch
import torch.nn as nn
import numpy as np
import gymnasium as gym
from pathlib import Path

STATE_DIM = 4
ACTION_DIM = 2


class PolicyNetwork(nn.Module):
    def __init__(self, hidden_dim: int = 64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(STATE_DIM, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, ACTION_DIM),
        )

    def forward(self, state_t: torch.Tensor) -> torch.distributions.Categorical:
        logits = self.net(state_t)
        return torch.distributions.Categorical(logits=logits)

    def select_action(self, state: np.ndarray) -> tuple[int, torch.Tensor]:
        s = torch.from_numpy(state.astype(np.float32)).unsqueeze(0)
        dist = self(s)
        action = dist.sample()
        return int(action.item()), dist.log_prob(action)


class REINFORCETrainer:
    """
    REINFORCE (Williams, 1992) using the learned reward model as the reward signal.

    Rollouts are collected in the real CartPole environment (not the dynamics
    model) so we avoid compounding prediction errors. The learned reward model
    replaces the true reward at training time; the algorithm never sees the true
    reward during policy updates.
    """

    def __init__(
        self,
        policy: PolicyNetwork,
        reward_model,
        lr: float = 1e-3,
        gamma: float = 0.99,
    ):
        self.policy = policy
        self.reward_model = reward_model
        self.optimizer = torch.optim.Adam(policy.parameters(), lr=lr)
        self.gamma = gamma

    def _collect_episode(self, env) -> tuple[list, list]:
        """Collect one episode using learned reward, return (log_probs, returns)."""
        state, _ = env.reset()
        log_probs, rewards = [], []

        done = False
        while not done:
            action, log_prob = self.policy.select_action(state)
            next_state, _, terminated, truncated, _ = env.step(action)
            r = self.reward_model.step_reward(state, action)   # learned reward
            log_probs.append(log_prob)
            rewards.append(r)
            state = next_state
            done = terminated or truncated

        # Discounted returns
        G, returns = 0.0, []
        for r in reversed(rewards):
            G = r + self.gamma * G
            returns.insert(0, G)

        return log_probs, returns

    def train(self, n_episodes: int = 50) -> None:
        env = gym.make("CartPole-v1")
        self.policy.train()

        for ep in range(n_episodes):
            log_probs, returns = self._collect_episode(env)
            returns_t = torch.tensor(returns, dtype=torch.float32)
            returns_t = (returns_t - returns_t.mean()) / (returns_t.std() + 1e-8)

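            # REINFORCE loss: negative of sum_t log pi(a_t | s_t) * G_t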
            loss = -torch.stack([lp * R for lp, R in zip(log_probs, returns_t)]).sum()
            self.optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(self.policy.parameters(), 1.0)
            self.optimizer.step()

        env.close()
        self.policy.eval()

    # ── Persistence ───────────────────────────────────────────────────────

    def save(self, path: str) -> None:
        Path(path).parent.mkdir(parents=True, exist_ok=True)
        torch.save(self.policy.state_dict(), path)

    def load(self, path: str) -> None:
        self.policy.load_state_dict(torch.load(path, map_location="cpu", weights_only=True))


def evaluate_policy(policy: PolicyNetwork, n_episodes: int = 20) -> tuple[float, float]:
    """
    Evaluate using the TRUE CartPole environment reward.
    This is only for measurement; the algorithm never calls it during training.
    """
    env = gym.make("CartPole-v1")
    policy.eval()
    returns = []

    for _ in range(n_episodes):
        state, _ = env.reset()
        total, done = 0.0, False
        while not done:
            with torch.no_grad():
                action, _ = policy.select_action(state)
            state, reward, terminated, truncated, _ = env.step(action)
            total += reward
            done = terminated or truncated
        returns.append(total)

    env.close()
    return float(np.mean(returns)), float(np.std(returns))
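

if __name__ == "__main__":
    # Minimal usage sketch, not part of the training pipeline above. A real run
    # would pass the learned reward model trained elsewhere; ``_SurvivalReward``
    # below is a hypothetical stand-in that mimics CartPole's +1-per-step reward
    # so this file stays self-contained and runnable.
    class _SurvivalReward:
        def step_reward(self, state: np.ndarray, action: int) -> float:
            return 1.0  # placeholder for the learned per-step reward

    policy = PolicyNetwork()
    trainer = REINFORCETrainer(policy, reward_model=_SurvivalReward())
    trainer.train(n_episodes=50)

    mean_ret, std_ret = evaluate_policy(policy, n_episodes=20)
    print(f"True-environment return over 20 episodes: {mean_ret:.1f} ± {std_ret:.1f}")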