singhanshuman committed on
Commit
80c82ec
·
verified ·
1 Parent(s): 248472d

Upload simoprl/reward_model.py with huggingface_hub
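(For context: the commit message indicates the file was pushed with the huggingface_hub client. A minimal sketch of such an upload — the target repo_id is not shown on this page, so a placeholder is used — might look like:

from huggingface_hub import upload_file

# Placeholder repo_id; the actual destination repository is an assumption.
upload_file(
    path_or_fileobj="simoprl/reward_model.py",
    path_in_repo="simoprl/reward_model.py",
    repo_id="your-username/your-repo",
)
)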

Files changed (1)
simoprl/reward_model.py +140 -0
simoprl/reward_model.py ADDED
@@ -0,0 +1,140 @@
import torch
import torch.nn as nn
import numpy as np
from pathlib import Path

STATE_DIM = 4
ACTION_DIM = 2
STEP_INPUT_DIM = STATE_DIM + ACTION_DIM  # 6


class _StepRewardNet(nn.Module):
    """
    Predicts a per-step reward scalar from (state, action).
    Trajectory reward = Σ_t r(s_t, a_t).
    """

    def __init__(self, hidden_dim: int = 256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(STEP_INPUT_DIM, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, sa: torch.Tensor) -> torch.Tensor:
        """sa: (batch, 6) → (batch,)"""
        return self.net(sa).squeeze(-1)

    def trajectory_return(self, trajectory: list) -> torch.Tensor:
        """
        trajectory: list of (state: np.ndarray, action: int)
        Returns scalar tensor (summed per-step reward).
        """
        sa_pairs = []
        for s, a in trajectory:
            action_oh = np.eye(ACTION_DIM, dtype=np.float32)[int(a)]
            sa_pairs.append(np.concatenate([s.astype(np.float32), action_oh]))
        sa_t = torch.from_numpy(np.stack(sa_pairs))  # (T, 6)
        return self.forward(sa_t).sum()  # scalar


class EnsembleRewardModel:
    """
    Ensemble of step-reward networks trained with the Bradley-Terry loss on
    human (or oracle) trajectory preferences.

    Bradley-Terry: P(τ1 ≻ τ2) = σ(R(τ1) − R(τ2))
    Loss = −log σ(R(preferred) − R(rejected))

    Uncertainty = std of ensemble return predictions — high → reward model
    is uncertain → this pair is informative to query.
    """

    def __init__(self, n_models: int = 3, hidden_dim: int = 256, lr: float = 3e-4):
        self.n_models = n_models
        self.models = [_StepRewardNet(hidden_dim) for _ in range(n_models)]
        self.optimizers = [torch.optim.Adam(m.parameters(), lr=lr, weight_decay=1e-4)
                           for m in self.models]
        self.preference_buffer: list[tuple] = []  # (traj1, traj2, label)

    # ── Preference buffer ──────────────────────────────────────────────────

    def add_preference(self, traj1: list, traj2: list, label: int) -> None:
        """
        label = 0 → traj1 preferred
        label = 1 → traj2 preferred
        """
        self.preference_buffer.append((traj1, traj2, int(label)))

    # ── Training ──────────────────────────────────────────────────────────

    def update(self, n_epochs: int = 20) -> float:
        """Re-train all ensemble members on the current preference buffer."""
        if len(self.preference_buffer) < 2:
            return float("nan")

        total_loss = 0.0
        for model, opt in zip(self.models, self.optimizers):
            for _ in range(n_epochs):
                perm = np.random.permutation(len(self.preference_buffer))
                epoch_loss = 0.0
                for idx in perm:
                    traj1, traj2, label = self.preference_buffer[idx]
                    r1 = model.trajectory_return(traj1)
                    r2 = model.trajectory_return(traj2)

                    # Bradley-Terry loss
                    if label == 0:
                        loss = -torch.log(torch.sigmoid(r1 - r2) + 1e-8)
                    else:
                        loss = -torch.log(torch.sigmoid(r2 - r1) + 1e-8)

                    opt.zero_grad()
                    loss.backward()
                    nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    opt.step()
                    epoch_loss += loss.item()
                total_loss += epoch_loss

        return total_loss / (self.n_models * n_epochs * len(self.preference_buffer))

    # ── Inference ─────────────────────────────────────────────────────────

    def predict_return(self, trajectory: list[tuple]) -> tuple[float, float]:
        """
        Returns (mean_return, reward_uncertainty).
        reward_uncertainty = std of ensemble predictions.
        """
        returns = []
        with torch.no_grad():
            for model in self.models:
                r = model.trajectory_return(trajectory)
                returns.append(r.item())
        return float(np.mean(returns)), float(np.std(returns))

    def step_reward(self, state: np.ndarray, action: int) -> float:
        """Mean per-step reward across the ensemble (used during policy training)."""
        action_oh = np.eye(ACTION_DIM, dtype=np.float32)[int(action)]
        sa = torch.from_numpy(np.concatenate([state.astype(np.float32), action_oh])).unsqueeze(0)
        with torch.no_grad():
            rewards = [m(sa).item() for m in self.models]
        return float(np.mean(rewards))

    # ── Persistence ───────────────────────────────────────────────────────

    def save(self, path: str) -> None:
        Path(path).parent.mkdir(parents=True, exist_ok=True)
        payload = {
            "state_dicts": [m.state_dict() for m in self.models],
            "preference_buffer": self.preference_buffer,
        }
        torch.save(payload, path)

    def load(self, path: str) -> None:
        payload = torch.load(path, map_location="cpu", weights_only=False)
        for model, sd in zip(self.models, payload["state_dicts"]):
            model.load_state_dict(sd)
        self.preference_buffer = payload.get("preference_buffer", [])
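
For orientation (not part of the commit), here is a minimal usage sketch of EnsembleRewardModel. It assumes the repo root is on sys.path so the file imports as simoprl.reward_model, and it uses two hypothetical stand-ins — make_trajectory for an environment rollout and oracle_label for a human labeller — matching the module's 4-dimensional states and 2 discrete actions.

import numpy as np

from simoprl.reward_model import EnsembleRewardModel


def make_trajectory(length: int) -> list:
    """Hypothetical helper: random (state, action) pairs with the module's
    dims (state: (4,) float32, action in {0, 1})."""
    return [(np.random.randn(4).astype(np.float32), int(np.random.randint(2)))
            for _ in range(length)]


def oracle_label(t1: list, t2: list) -> int:
    """Toy oracle standing in for a human: prefer the longer trajectory."""
    return 0 if len(t1) >= len(t2) else 1


rm = EnsembleRewardModel(n_models=3)

# Collect a few labelled preference pairs.
for _ in range(8):
    t1 = make_trajectory(int(np.random.randint(5, 30)))
    t2 = make_trajectory(int(np.random.randint(5, 30)))
    rm.add_preference(t1, t2, oracle_label(t1, t2))

print("avg Bradley-Terry loss:", rm.update(n_epochs=10))

# Active querying: among candidate pairs, the pair whose returns the ensemble
# disagrees about most (largest summed std) is the most informative to label next.
candidates = [(make_trajectory(10), make_trajectory(10)) for _ in range(4)]
scores = [rm.predict_return(a)[1] + rm.predict_return(b)[1] for a, b in candidates]
most_informative = candidates[int(np.argmax(scores))]

# Per-step reward for policy training, plus a save/load round trip.
print("r(s, a):", rm.step_reward(np.zeros(4, dtype=np.float32), action=1))
rm.save("checkpoints/reward_model.pt")
rm.load("checkpoints/reward_model.pt")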