singhanshuman committed
Commit e7d0ac5 · verified · 1 Parent(s): 80c82ec

Upload simoprl/preference_elicitation.py with huggingface_hub

Files changed (1)
  1. simoprl/preference_elicitation.py +181 -0
simoprl/preference_elicitation.py ADDED
@@ -0,0 +1,181 @@
"""
Three preference elicitation strategies from the Sim-OPRL paper.

UniformOPRL     - random pairs from the offline dataset (naïve baseline)
UncertaintyOPRL - pairs where the reward model is most uncertain (active baseline)
SimOPRL         - simulate new trajectories with the dynamics model;
                  optimistic on reward uncertainty, pessimistic on
                  transition uncertainty (the paper's contribution)
"""
import random

import numpy as np

from .reward_model import EnsembleRewardModel
from .dynamics_model import EnsembleDynamicsModel

# CartPole termination thresholds (from the Gymnasium source)
_X_THRESH = 2.4
_THETA_THRESH = 12 * np.pi / 180  # ~0.2094 rad


def _cartpole_quality(trajectory: list) -> int:
    """
    Physics-based quality score for a CartPole trajectory (real or simulated).

    Returns the number of steps the pole stays within CartPole's valid bounds.
    This is the TRUE return for CartPole (+1 per surviving step), evaluated
    correctly even for simulated trajectories, whose length is always `horizon`.

    Using len(trajectory) would give identical scores to all simulated
    trajectories, because they are all rolled out to the same fixed horizon,
    making oracle labels meaningless. This function avoids that.
    """
    count = 0
    for s, a in trajectory:
        x, _, theta, _ = s
        if abs(x) > _X_THRESH or abs(theta) > _THETA_THRESH:
            break
        count += 1
    return count
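
# Example: for a three-step trajectory whose second state already violates the
# angle bound, _cartpole_quality returns 1 -- only the first step counts as
# surviving.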


def oracle_preference(traj1: list, traj2: list, stochastic: bool = False) -> int:
    """
    Simulated oracle that uses the true CartPole physics to label preferences.

    label = 0 -> traj1 preferred
    label = 1 -> traj2 preferred
    """
    r1 = _cartpole_quality(traj1)
    r2 = _cartpole_quality(traj2)

    if r1 == r2:
        return random.randint(0, 1)

    if stochastic:
        # Logistic (Bradley-Terry) preference probability from the return gap
        p = 1.0 / (1.0 + np.exp(-(r1 - r2)))
        return 0 if np.random.random() < p else 1
    return 0 if r1 > r2 else 1
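
# Illustrative arithmetic (not from the paper): with stochastic=True and
# returns r1 = 30, r2 = 28, P(traj1 preferred) = 1 / (1 + e^-(30 - 28))
# = sigmoid(2) ~= 0.88, so the oracle mislabels this pair roughly 12% of the
# time under the logistic noise model above.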


# ─────────────────────────────────────────────────────────────────────────────

class UniformOPRL:
    """Randomly sample trajectory pairs from the offline dataset."""

    def __init__(self, dataset: list):
        self.trajectories = [[(s, a) for s, a, ns in traj] for traj in dataset]

    def get_query_pair(self, reward_model=None, policy_fn=None):
        idx1, idx2 = random.sample(range(len(self.trajectories)), 2)
        return self.trajectories[idx1], self.trajectories[idx2]
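
# Usage sketch (illustrative, not part of the module's API): `dataset` is a
# list of trajectories of (state, action, next_state) tuples, as consumed by
# __init__ above, and `n_queries` is a budget chosen by the caller.
#
#     elicitor = UniformOPRL(dataset)
#     labelled = []
#     for _ in range(n_queries):
#         t1, t2 = elicitor.get_query_pair()
#         labelled.append((t1, t2, oracle_preference(t1, t2)))
#     # `labelled` then supplies Bradley-Terry targets for reward-model training.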


# ─────────────────────────────────────────────────────────────────────────────

class UncertaintyOPRL:
    """
    Sample from the offline dataset; prioritise pairs with high reward-model
    uncertainty (disagreement across ensemble members).
    """

    def __init__(self, dataset: list, n_candidates: int = 64):
        self.trajectories = [[(s, a) for s, a, ns in traj] for traj in dataset]
        self.n_candidates = n_candidates

    def get_query_pair(self, reward_model, policy_fn=None):
        candidates = random.sample(
            self.trajectories, min(self.n_candidates, len(self.trajectories))
        )
        uncs = np.array([reward_model.predict_return(t)[1] for t in candidates])
        top2 = np.argsort(uncs)[-2:]
        return candidates[top2[0]], candidates[top2[1]]
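
# Design note: taking the two individually most-uncertain trajectories is a
# cheap proxy for the most informative pair; scoring all O(n^2) candidate
# pairs through the ensemble would be more faithful but far slower.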


# ─────────────────────────────────────────────────────────────────────────────

class SimOPRL:
    """
    Core contribution of the paper.

    Instead of querying the offline dataset directly, the agent:
      1. Samples starting states from the offline dataset (preferring upright
         pole angles so simulated trajectories are long enough to differentiate).
      2. Simulates new trajectories using the learned dynamics model, with a
         mix of current-policy and random actions for diversity.
      3. Scores each trajectory by:

             score = reward_uncertainty - lambda_ * transition_uncertainty

         reward_uncertainty     (optimistic)  -> query where we learn the most
         transition_uncertainty (pessimistic) -> avoid OOD regions

    The pair with the highest score is the most informative query to label.
    """

    def __init__(
        self,
        dataset: list,
        dynamics_model: EnsembleDynamicsModel,
        horizon: int = 40,
        n_simulated: int = 50,
        lambda_: float = 0.5,
        epsilon: float = 0.3,  # exploration rate in simulated rollouts
    ):
        self.dynamics_model = dynamics_model
        self.horizon = horizon
        self.n_simulated = n_simulated
        self.lambda_ = lambda_
        self.epsilon = epsilon

        # Prefer near-upright starting states: they produce longer, more
        # informative trajectories. Using near-failure states means simulated
        # trajectories all score 0 and the oracle can't distinguish them.
        all_states = [s.copy() for traj in dataset for s, a, ns in traj]
        upright = [s for s in all_states if abs(s[2]) < _THETA_THRESH * 0.7]
        self.start_states = upright if len(upright) > 20 else all_states

    def _simulate_trajectory(self, start_state, policy_fn):
        """
        Roll out one trajectory from start_state using the dynamics model.
        Actions are chosen by policy_fn with epsilon-greedy exploration.

        Returns (trajectory, avg_transition_uncertainty).
        """
        state = start_state.copy()
        trajectory = []
        total_trans_unc = 0.0

        for _ in range(self.horizon):
            # Epsilon-greedy: explore with random actions for diversity
            if np.random.random() < self.epsilon:
                action = np.random.randint(2)
            else:
                action = int(policy_fn(state))

            next_state, trans_unc = self.dynamics_model.predict(state, action)
            trajectory.append((state.copy(), action))
            total_trans_unc += trans_unc
            state = next_state

            # Early stop if the predicted state is clearly out of CartPole
            # bounds (avoids accumulating dynamics errors past the point of
            # no return)
            if abs(state[0]) > _X_THRESH * 1.5 or abs(state[2]) > _THETA_THRESH * 2:
                break

        avg_trans_unc = total_trans_unc / max(len(trajectory), 1)
        return trajectory, avg_trans_unc
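
    # (Averaging rather than summing the transition uncertainty keeps the
    # pessimism penalty comparable between rollouts that stop early and
    # rollouts that run the full horizon.)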

    def get_query_pair(self, reward_model, policy_fn):
        """
        Generate n_simulated candidate trajectories and return the two
        highest-scoring ones as the query pair.
        """
        candidates = []
        for _ in range(self.n_simulated):
            start = random.choice(self.start_states)
            traj, trans_unc = self._simulate_trajectory(start, policy_fn)
            _, reward_unc = reward_model.predict_return(traj)

            # Sim-OPRL acquisition score: optimistic in reward uncertainty,
            # pessimistic in transition uncertainty
            score = reward_unc - self.lambda_ * trans_unc
            candidates.append((traj, score))

        candidates.sort(key=lambda x: x[1], reverse=True)
        return candidates[0][0], candidates[1][0]
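

if __name__ == "__main__":
    # Minimal smoke test (an illustrative addition, not part of the paper's
    # pipeline); run with `python -m simoprl.preference_elicitation`. The stub
    # models below only mimic the two interfaces this module actually calls:
    # reward_model.predict_return(traj) -> (mean, uncertainty) and
    # dynamics_model.predict(state, action) -> (next_state, uncertainty);
    # real ensembles come from the sibling modules imported above.
    class _StubRewardModel:
        def predict_return(self, traj):
            return float(len(traj)), float(np.random.random())

    class _StubDynamicsModel:
        def predict(self, state, action):
            # Random-walk "dynamics" with a made-up uncertainty estimate
            next_state = state + np.random.normal(0.0, 0.01, size=4)
            return next_state, float(np.random.random() * 0.1)

    # Fake offline dataset: trajectories of (state, action, next_state)
    rng = np.random.default_rng(0)
    dataset = []
    for _ in range(10):
        traj, s = [], rng.normal(0.0, 0.05, size=4)
        for _ in range(30):
            a = int(rng.integers(2))
            ns = s + rng.normal(0.0, 0.01, size=4)
            traj.append((s, a, ns))
            s = ns
        dataset.append(traj)

    t1, t2 = UniformOPRL(dataset).get_query_pair()
    print("uniform label:", oracle_preference(t1, t2))

    unc = UncertaintyOPRL(dataset, n_candidates=8)
    t1, t2 = unc.get_query_pair(_StubRewardModel())
    print("uncertainty label:", oracle_preference(t1, t2))

    sim = SimOPRL(dataset, _StubDynamicsModel(), horizon=20, n_simulated=8)
    t1, t2 = sim.get_query_pair(_StubRewardModel(), policy_fn=lambda s: 0)
    print("sim-oprl label:", oracle_preference(t1, t2))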