singhanshuman committed on
Commit f41070a · verified
1 Parent(s): 86207d3

Upload app.py with huggingface_hub

Files changed (1)
  1. app.py +309 -0
app.py ADDED
@@ -0,0 +1,309 @@
"""
app.py — Interactive Sim-OPRL demo (Gradio).

The professor clicks which of two CartPole trajectories she prefers.
Each click updates the Bradley-Terry reward model.
Every 5 clicks, the policy is retrained using REINFORCE on the learned reward.
The agent's performance (true CartPole reward) is plotted live.

Deploy: gradio app.py or python app.py
HuggingFace Spaces: push this repo; set app.py as the entrypoint.
"""
import pickle

import numpy as np
import torch
import gymnasium as gym
import imageio
import matplotlib
matplotlib.use("Agg")  # headless backend; must be set before importing pyplot
import matplotlib.pyplot as plt
import gradio as gr
from pathlib import Path

from simoprl.collect_data import collect_offline_dataset, load_dataset
from simoprl.dynamics_model import EnsembleDynamicsModel
from simoprl.reward_model import EnsembleRewardModel
from simoprl.preference_elicitation import SimOPRL
from simoprl.policy import PolicyNetwork, REINFORCETrainer, evaluate_policy

# ── Paths ─────────────────────────────────────────────────────────────────────
DATA_PATH = "data/offline_dataset.pkl"
DYN_MODEL_PATH = "models/dynamics_model.pt"
RESULTS_PATH = "results/experiment_results.pkl"

# ── Global mutable state (single-user demo) ───────────────────────────────────
class _State:
    dynamics_model: EnsembleDynamicsModel = None
    reward_model: EnsembleRewardModel = None
    policy: PolicyNetwork = None
    trainer: REINFORCETrainer = None
    elicitor: SimOPRL = None
    dataset: list = None
    query_count: int = 0
    return_history: list = []  # [(n_queries, mean_return)]
    current_traj1: list = None
    current_traj2: list = None
    initialized: bool = False

S = _State()


# ── Setup ─────────────────────────────────────────────────────────────────────

def _setup():
    """Train / load all components. Called once at startup."""
    if S.initialized:
        return

    # 1. Dataset
    if Path(DATA_PATH).exists():
        S.dataset = load_dataset(DATA_PATH)
        print(f"Dataset loaded: {len(S.dataset)} trajectories")
    else:
        print("Collecting offline dataset …")
        S.dataset = collect_offline_dataset(n_trajectories=800, save_path=DATA_PATH)

    # 2. Dynamics model (pre-trained; central to Sim-OPRL)
    S.dynamics_model = EnsembleDynamicsModel(n_models=5)
    if Path(DYN_MODEL_PATH).exists():
        S.dynamics_model.load(DYN_MODEL_PATH)
    else:
        print("Training dynamics model (first run — this takes a few minutes) …")
        S.dynamics_model.train(S.dataset, n_epochs=100)
        S.dynamics_model.save(DYN_MODEL_PATH)

    # 3. Reward model — starts blank; shaped entirely by the professor's clicks
    S.reward_model = EnsembleRewardModel(n_models=3)

    # 4. Policy — starts random; improves as reward model learns
    S.policy = PolicyNetwork()
    S.trainer = REINFORCETrainer(S.policy, S.reward_model, lr=1e-3)

    # 5. Sim-OPRL elicitor
    S.elicitor = SimOPRL(S.dataset, S.dynamics_model, horizon=50, n_simulated=40, lambda_=1.0)

    S.initialized = True
    print("Setup complete.")


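# ── Illustrative sketch (not called anywhere in this app) ────────────────────
# What `EnsembleRewardModel.add_preference` / `update` presumably optimise: the
# Bradley-Terry negative log-likelihood, P(traj1 ≻ traj2) = σ(R(traj1) − R(traj2)).
# The real implementation lives in simoprl.reward_model; the function below is
# only a minimal sketch under that assumption, and its name is hypothetical.

def _bradley_terry_nll_sketch(return1: torch.Tensor,
                              return2: torch.Tensor,
                              label: int) -> torch.Tensor:
    """NLL of one preference; label 0 means trajectory 1 was preferred."""
    logit = return1 - return2                  # log-odds that trajectory 1 wins
    target = torch.tensor(1.0 - float(label))  # 1.0 if traj1 preferred, else 0.0
    return torch.nn.functional.binary_cross_entropy_with_logits(logit, target)

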
# ── Trajectory simulation & rendering ────────────────────────────────────────

def _current_policy_fn(state: np.ndarray) -> int:
    if S.query_count < 5:
        return np.random.randint(2)
    action, _ = S.policy.select_action(state)
    return action


def _render_trajectory_to_gif(trajectory, path, fps=20) -> str:
    """
    Render a (state, action) trajectory to a GIF using CartPole's rgb_array renderer.
    For simulated trajectories the env state is set at each step.
    """
    env = gym.make("CartPole-v1", render_mode="rgb_array")
    env.reset()

    frames = []
    for state_arr, action in trajectory:
        # Clip to renderable range (dynamics model may predict slightly OOB states)
        clipped = np.array([
            np.clip(state_arr[0], -4.8, 4.8),
            np.clip(state_arr[1], -10.0, 10.0),
            np.clip(state_arr[2], -0.5, 0.5),
            np.clip(state_arr[3], -10.0, 10.0),
        ], dtype=np.float64)
        env.unwrapped.state = clipped
        frames.append(env.render())

    env.close()

    duration = 1.0 / fps
    imageio.mimwrite(path, frames, format="GIF", duration=duration, loop=0)
    return path


def _generate_and_render_pair() -> tuple[str, str]:
    """Ask Sim-OPRL for the next query pair and render both as GIFs."""
    traj1, traj2 = S.elicitor.get_query_pair(S.reward_model, _current_policy_fn)
    S.current_traj1 = traj1
    S.current_traj2 = traj2

    path1 = _render_trajectory_to_gif(traj1, "/tmp/simoprl_traj_A.gif")
    path2 = _render_trajectory_to_gif(traj2, "/tmp/simoprl_traj_B.gif")
    return path1, path2


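# ── Illustrative sketch (not called anywhere in this app) ────────────────────
# The acquisition rule `get_query_pair` presumably applies: among candidate
# simulated pairs, prefer high disagreement across the reward ensemble and low
# disagreement across the dynamics ensemble, i.e.
# score = reward_uncertainty − λ · transition_uncertainty.
# The real logic lives in simoprl.preference_elicitation; the names below are
# hypothetical.

def _select_query_pair_sketch(reward_std: np.ndarray,
                              transition_std: np.ndarray,
                              lambda_: float = 1.0) -> int:
    """Index of the candidate pair maximising the Sim-OPRL acquisition score."""
    scores = np.asarray(reward_std) - lambda_ * np.asarray(transition_std)
    return int(np.argmax(scores))

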
# ── Plot ──────────────────────────────────────────────────────────────────────

def _make_return_plot() -> plt.Figure:
    fig, ax = plt.subplots(figsize=(9, 4))
    ax.set_facecolor("#f5f5f5")
    fig.patch.set_facecolor("white")

    ax.axhline(y=21, color="#aaa", linestyle=":", linewidth=1.2, label="Random policy (~21 steps)")
    ax.axhline(y=500, color="#2ca02c", linestyle="--", linewidth=1, alpha=0.5, label="Max return (500)")

    if S.return_history:
        qs = [x[0] for x in S.return_history]
        means = np.array([x[1] for x in S.return_history])
        ax.plot(qs, means, "o-", color="#1f77b4", linewidth=2.5, markersize=7,
                label="Sim-OPRL (your preferences)")
        ax.fill_between(qs, means * 0.85, means * 1.15, alpha=0.15, color="#1f77b4")

    ax.set_xlabel("Number of Preference Queries", fontsize=12)
    ax.set_ylabel("Policy Return (True Reward)", fontsize=12)
    ax.set_title("How your preferences shape the agent", fontsize=13, fontweight="bold")
    ax.set_ylim(0, 530)
    ax.legend(fontsize=10, framealpha=0.9)
    ax.grid(True, alpha=0.3, linestyle="--")
    plt.tight_layout()
    return fig


def _make_comparison_plot() -> plt.Figure:
    """Show pre-computed baseline comparison if results exist."""
    if not Path(RESULTS_PATH).exists():
        fig, ax = plt.subplots(figsize=(9, 3))
        ax.text(0.5, 0.5, "Run python train.py to generate comparison figure",
                ha="center", va="center", transform=ax.transAxes, fontsize=12, color="gray")
        ax.axis("off")
        return fig

    with open(RESULTS_PATH, "rb") as f:
        data = pickle.load(f)

    results = data["results"]
    checkpoints = sorted(data["checkpoints"])
    colors = {"uniform": "#d62728", "uncertainty": "#ff7f0e", "simoprl": "#1f77b4"}
    labels = {"uniform": "Uniform OPRL", "uncertainty": "Uncertainty OPRL", "simoprl": "Sim-OPRL (paper)"}

    fig, ax = plt.subplots(figsize=(9, 4))
    ax.set_facecolor("#f5f5f5")
    for method in ["uniform", "uncertainty", "simoprl"]:
        if method not in results:
            continue
        seed_results = results[method]
        qs = checkpoints
        means = np.array([np.mean([r.get(q, np.nan) for r in seed_results]) for q in qs])
        stds = np.array([np.std([r.get(q, np.nan) for r in seed_results]) for q in qs])
        ax.plot(qs, means, "-o", color=colors[method], linewidth=2 if method == "simoprl" else 1.5,
                markersize=5, label=labels[method])
        ax.fill_between(qs, means - stds, means + stds, alpha=0.12, color=colors[method])

    ax.axhline(y=500, color="green", linestyle="--", linewidth=1, alpha=0.5)
    ax.set_xlabel("Preference Queries", fontsize=11)
    ax.set_ylabel("Policy Return", fontsize=11)
    ax.set_title("Sim-OPRL vs baselines (oracle preferences, 5 seeds)", fontsize=12, fontweight="bold")
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    return fig


# ── Gradio handlers ───────────────────────────────────────────────────────────

def on_load():
    _setup()
    gif1, gif2 = _generate_and_render_pair()
    plot = _make_return_plot()
    status = "Ready — click which trajectory keeps the pole balanced longer."
    return gif1, gif2, plot, status, _make_comparison_plot()


def on_preference(preferred: str):
    """Called when the professor clicks 'Prefer A' or 'Prefer B'."""
    if S.current_traj1 is None:
        return on_load()

    label = 0 if preferred == "A" else 1
    S.reward_model.add_preference(S.current_traj1, S.current_traj2, label)
    S.reward_model.update(n_epochs=15)
    S.query_count += 1

    status = f"Query {S.query_count}: you preferred {'A' if label == 0 else 'B'}."

    # Retrain policy every 5 queries
    if S.query_count % 5 == 0:
        status += " Updating policy …"
        S.trainer.train(n_episodes=40)
        mean_ret, _ = evaluate_policy(S.policy, n_episodes=15)
        S.return_history.append((S.query_count, mean_ret))
        status += f" Policy return: {mean_ret:.1f}"

    gif1, gif2 = _generate_and_render_pair()
    return gif1, gif2, _make_return_plot(), status, _make_comparison_plot()


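# ── Illustrative sketch (not called anywhere in this app) ────────────────────
# The update `REINFORCETrainer.train` presumably performs: vanilla REINFORCE,
# except that per-step rewards come from the learned reward model rather than
# the hidden environment reward. `log_probs` holds the action log-probabilities
# recorded during one rollout; all names here are hypothetical, not the simoprl
# API.

def _reinforce_episode_step_sketch(log_probs: list[torch.Tensor],
                                   learned_rewards: list[float],
                                   optimizer: torch.optim.Optimizer,
                                   gamma: float = 0.99) -> None:
    """One REINFORCE gradient step on a single episode."""
    returns, g = [], 0.0
    for r in reversed(learned_rewards):  # discounted return-to-go, back to front
        g = r + gamma * g
        returns.append(g)
    returns.reverse()
    returns_t = torch.tensor(returns)
    # Normalising returns acts as a crude baseline and stabilises training.
    returns_t = (returns_t - returns_t.mean()) / (returns_t.std() + 1e-8)
    loss = -(torch.stack(log_probs) * returns_t).sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

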
# ── Gradio UI ─────────────────────────────────────────────────────────────────

with gr.Blocks(title="Sim-OPRL Demo", theme=gr.themes.Soft()) as demo:

    gr.Markdown("""
# Sim-OPRL: Preference Elicitation for Offline RL
### Pace · Schölkopf · Rätsch · Ramponi — ICLR 2025

Two CartPole trajectories are simulated by a learned **dynamics model** and chosen by the
**Sim-OPRL** acquisition strategy: high reward uncertainty (we learn the most here)
and low transition uncertainty (the dynamics model is reliable here).

**Click which run keeps the pole balanced longer.**
Your preferences directly train the reward model via the Bradley-Terry loss.
Every 5 clicks, the policy is re-optimised with REINFORCE on the learned reward.
""")

    with gr.Row(equal_height=True):
        with gr.Column():
            vid_A = gr.Image(label="Trajectory A", type="filepath")
            btn_A = gr.Button("⬅ Prefer A", variant="primary", size="lg")
        with gr.Column():
            vid_B = gr.Image(label="Trajectory B", type="filepath")
            btn_B = gr.Button("Prefer B ➡", variant="primary", size="lg")

    status_box = gr.Textbox(label="Status", interactive=False, lines=1)

    with gr.Tabs():
        with gr.TabItem("Live: Your Preferences → Agent Return"):
            live_plot = gr.Plot(label="Return vs Queries (updates every 5 clicks)")
        with gr.TabItem("Baseline Comparison (from train.py)"):
            comparison_plot = gr.Plot(label="Sim-OPRL vs Uniform OPRL vs Uncertainty OPRL")

    gr.Markdown("""
---
### How Sim-OPRL works

| Step | What happens |
|------|--------------|
| 1 | Collect an unlabelled offline dataset (no rewards) |
| 2 | Train an **ensemble dynamics model** on the dataset |
| 3 | For each query: simulate trajectories, score by `reward_uncertainty − λ · transition_uncertainty` |
| 4 | Ask for a preference on the highest-scoring pair |
| 5 | Update the **Bradley-Terry reward model** with the preference |
| 6 | Re-optimise the policy with REINFORCE on the learned reward |

Sim-OPRL reaches higher returns with **fewer queries** than naïve baselines
by asking *informative* questions, not random ones.
""")

    # Wire up
    btn_A.click(
        fn=lambda: on_preference("A"),
        inputs=[],
        outputs=[vid_A, vid_B, live_plot, status_box, comparison_plot],
    )
    btn_B.click(
        fn=lambda: on_preference("B"),
        inputs=[],
        outputs=[vid_A, vid_B, live_plot, status_box, comparison_plot],
    )
    demo.load(
        fn=on_load,
        inputs=[],
        outputs=[vid_A, vid_B, live_plot, status_box, comparison_plot],
    )


if __name__ == "__main__":
    demo.launch(share=False)