"""
app.py — Interactive Sim-OPRL demo (Gradio).
The professor clicks which of two CartPole trajectories she prefers.
Each click updates the Bradley-Terry reward model.
Every 5 clicks, the policy is retrained using REINFORCE on the learned reward.
The agent's performance (true CartPole reward) is plotted live.
Deploy: gradio app.py or python app.py
HuggingFace Spaces: push this repo; set app.py as the entrypoint.
"""
import os
import pickle
import random
import tempfile
import numpy as np
import torch
import gymnasium as gym
import imageio
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import gradio as gr
from pathlib import Path
from simoprl.collect_data import collect_offline_dataset, load_dataset
from simoprl.dynamics_model import EnsembleDynamicsModel
from simoprl.reward_model import EnsembleRewardModel
from simoprl.preference_elicitation import SimOPRL
from simoprl.policy import PolicyNetwork, REINFORCETrainer, evaluate_policy
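# ── Illustrative sketch: Bradley-Terry preference loss (not used by the app) ──
# The docstring above says each click updates a Bradley-Terry reward model; the
# real update lives in simoprl.reward_model (add_preference / update). The
# helper below is only a minimal sketch of that loss, assuming a hypothetical
# `reward_fn(state, action) -> scalar torch tensor` and trajectories given as
# (state, action) lists, as elsewhere in this file.
def _bradley_terry_loss_sketch(reward_fn, traj_a, traj_b, label: int) -> torch.Tensor:
    """Negative log-likelihood of the observed preference under Bradley-Terry."""
    # Trajectory "returns" under the current reward model.
    ret_a = torch.stack([reward_fn(s, a) for s, a in traj_a]).sum()
    ret_b = torch.stack([reward_fn(s, a) for s, a in traj_b]).sum()
    # P(A preferred) = sigmoid(ret_a - ret_b); label 0 means A was clicked.
    logits = torch.stack([ret_a, ret_b]).unsqueeze(0)
    return torch.nn.functional.cross_entropy(logits, torch.tensor([label]))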
# ── Paths ──────────────────────────────────────────────────────────────────────
DATA_PATH = "data/offline_dataset.pkl"
DYN_MODEL_PATH = "models/dynamics_model.pt"
RESULTS_PATH = "results/experiment_results.pkl"
# ── Global mutable state (single-user demo) ────────────────────────────────────
class _State:
dynamics_model: EnsembleDynamicsModel = None
reward_model: EnsembleRewardModel = None
policy: PolicyNetwork = None
trainer: REINFORCETrainer = None
elicitor: SimOPRL = None
dataset: list = None
query_count: int = 0
return_history: list = [] # [(n_queries, mean_return)]
current_traj1: list = None
current_traj2: list = None
initialized: bool = False
S = _State()
# ── Setup ──────────────────────────────────────────────────────────────────────
def _setup():
"""Train / load all components. Called once at startup."""
if S.initialized:
return
# 1. Dataset
if Path(DATA_PATH).exists():
S.dataset = load_dataset(DATA_PATH)
print(f"Dataset loaded: {len(S.dataset)} trajectories")
else:
print("Collecting offline dataset β¦")
S.dataset = collect_offline_dataset(n_trajectories=800, save_path=DATA_PATH)
# 2. Dynamics model (pre-trained; central to Sim-OPRL)
S.dynamics_model = EnsembleDynamicsModel(n_models=5)
if Path(DYN_MODEL_PATH).exists():
S.dynamics_model.load(DYN_MODEL_PATH)
else:
print("Training dynamics model (first run β this takes a few minutes) β¦")
S.dynamics_model.train(S.dataset, n_epochs=100)
S.dynamics_model.save(DYN_MODEL_PATH)
# 3. Reward model — starts blank; shaped entirely by the professor's clicks
S.reward_model = EnsembleRewardModel(n_models=3)
# 4. Policy — starts random; improves as reward model learns
S.policy = PolicyNetwork()
S.trainer = REINFORCETrainer(S.policy, S.reward_model, lr=1e-3)
# 5. Sim-OPRL elicitor
S.elicitor = SimOPRL(S.dataset, S.dynamics_model, horizon=50, n_simulated=40, lambda_=1.0)
S.initialized = True
print("Setup complete.")
# ── Trajectory simulation & rendering ──────────────────────────────────────────
def _current_policy_fn(state: np.ndarray) -> int:
if S.query_count < 5:
return np.random.randint(2)
action, _ = S.policy.select_action(state)
return action
def _render_trajectory_to_gif(trajectory, path, fps=20) -> str:
"""
Render a (state, action) trajectory to a GIF using CartPole's rgb_array renderer.
For simulated trajectories the env state is set at each step.
"""
env = gym.make("CartPole-v1", render_mode="rgb_array")
env.reset()
frames = []
for state_arr, action in trajectory:
# Clip to renderable range (dynamics model may predict slightly OOB states)
clipped = np.array([
np.clip(state_arr[0], -4.8, 4.8),
np.clip(state_arr[1], -10.0, 10.0),
np.clip(state_arr[2], -0.5, 0.5),
np.clip(state_arr[3], -10.0, 10.0),
], dtype=np.float64)
env.unwrapped.state = clipped
frames.append(env.render())
env.close()
duration = 1.0 / fps
imageio.mimwrite(path, frames, format="GIF", duration=duration, loop=0)
return path
def _generate_and_render_pair() -> tuple[str, str]:
"""Ask Sim-OPRL for the next query pair and render both as GIFs."""
traj1, traj2 = S.elicitor.get_query_pair(S.reward_model, _current_policy_fn)
S.current_traj1 = traj1
S.current_traj2 = traj2
path1 = _render_trajectory_to_gif(traj1, "/tmp/simoprl_traj_A.gif")
path2 = _render_trajectory_to_gif(traj2, "/tmp/simoprl_traj_B.gif")
return path1, path2
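# ── Illustrative sketch: Sim-OPRL acquisition score (not used by the app) ──────
# get_query_pair above chooses which pair to show; the idea is to favour high
# reward-model disagreement and low dynamics-model disagreement. The real
# scoring lives in simoprl.preference_elicitation; this sketch assumes the
# per-member return estimates and per-step dynamics variances have already been
# computed as numpy arrays (hypothetical inputs, not the package's actual API).
def _acquisition_score_sketch(returns_a: np.ndarray, returns_b: np.ndarray,
                              dyn_var_a: np.ndarray, dyn_var_b: np.ndarray,
                              lambda_: float = 1.0) -> float:
    """Score = reward uncertainty − λ · transition uncertainty (higher is better)."""
    # How much the reward ensemble disagrees about which trajectory is better.
    reward_uncertainty = float(np.std(returns_a - returns_b))
    # How unreliable the simulated rollouts are under the dynamics ensemble.
    transition_uncertainty = float(np.mean(dyn_var_a) + np.mean(dyn_var_b))
    return reward_uncertainty - lambda_ * transition_uncertainty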
# ── Plot ───────────────────────────────────────────────────────────────────────
def _make_return_plot() -> plt.Figure:
fig, ax = plt.subplots(figsize=(9, 4))
ax.set_facecolor("#f5f5f5")
fig.patch.set_facecolor("white")
ax.axhline(y=21, color="#aaa", linestyle=":", linewidth=1.2, label="Random policy (~21 steps)")
ax.axhline(y=500, color="#2ca02c", linestyle="--", linewidth=1, alpha=0.5, label="Max return (500)")
if S.return_history:
qs = [x[0] for x in S.return_history]
means = np.array([x[1] for x in S.return_history])
ax.plot(qs, means, "o-", color="#1f77b4", linewidth=2.5, markersize=7,
label="Sim-OPRL (your preferences)")
ax.fill_between(qs, means * 0.85, means * 1.15, alpha=0.15, color="#1f77b4")
ax.set_xlabel("Number of Preference Queries", fontsize=12)
ax.set_ylabel("Policy Return (True Reward)", fontsize=12)
ax.set_title("How your preferences shape the agent", fontsize=13, fontweight="bold")
ax.set_ylim(0, 530)
ax.legend(fontsize=10, framealpha=0.9)
ax.grid(True, alpha=0.3, linestyle="--")
plt.tight_layout()
return fig
def _make_comparison_plot() -> plt.Figure:
"""Show pre-computed baseline comparison if results exist."""
if not Path(RESULTS_PATH).exists():
fig, ax = plt.subplots(figsize=(9, 3))
ax.text(0.5, 0.5, "Run python train.py to generate comparison figure",
ha="center", va="center", transform=ax.transAxes, fontsize=12, color="gray")
ax.axis("off")
return fig
with open(RESULTS_PATH, "rb") as f:
data = pickle.load(f)
results = data["results"]
checkpoints = sorted(data["checkpoints"])
colors = {"uniform": "#d62728", "uncertainty": "#ff7f0e", "simoprl": "#1f77b4"}
labels = {"uniform": "Uniform OPRL", "uncertainty": "Uncertainty OPRL", "simoprl": "Sim-OPRL (paper)"}
fig, ax = plt.subplots(figsize=(9, 4))
ax.set_facecolor("#f5f5f5")
for method in ["uniform", "uncertainty", "simoprl"]:
if method not in results:
continue
seed_results = results[method]
qs = checkpoints
means = np.array([np.mean([r.get(q, np.nan) for r in seed_results]) for q in qs])
stds = np.array([np.std([r.get(q, np.nan) for r in seed_results]) for q in qs])
ax.plot(qs, means, "-o", color=colors[method], linewidth=2 if method == "simoprl" else 1.5,
markersize=5, label=labels[method])
ax.fill_between(qs, means - stds, means + stds, alpha=0.12, color=colors[method])
ax.axhline(y=500, color="green", linestyle="--", linewidth=1, alpha=0.5)
ax.set_xlabel("Preference Queries", fontsize=11)
ax.set_ylabel("Policy Return", fontsize=11)
ax.set_title("Sim-OPRL vs baselines (oracle preferences, 5 seeds)", fontsize=12, fontweight="bold")
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)
plt.tight_layout()
return fig
# ── Gradio handlers ────────────────────────────────────────────────────────────
def on_load():
_setup()
gif1, gif2 = _generate_and_render_pair()
plot = _make_return_plot()
status = "Ready β click which trajectory keeps the pole balanced longer."
return gif1, gif2, plot, status, _make_comparison_plot()
def on_preference(preferred: str):
"""Called when professor clicks 'Prefer A' or 'Prefer B'."""
if S.current_traj1 is None:
return on_load()
label = 0 if preferred == "A" else 1
S.reward_model.add_preference(S.current_traj1, S.current_traj2, label)
S.reward_model.update(n_epochs=15)
S.query_count += 1
status = f"Query {S.query_count}: you preferred {'A' if label == 0 else 'B'}."
# Retrain policy every 5 queries
if S.query_count % 5 == 0:
status += " Updating policy β¦"
S.trainer.train(n_episodes=40)
mean_ret, _ = evaluate_policy(S.policy, n_episodes=15)
S.return_history.append((S.query_count, mean_ret))
status += f" Policy return: {mean_ret:.1f}"
gif1, gif2 = _generate_and_render_pair()
return gif1, gif2, _make_return_plot(), status, _make_comparison_plot()
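# ── Illustrative sketch: REINFORCE on the learned reward (not used by the app) ─
# Every 5 clicks the handler above calls S.trainer.train(); the real trainer
# lives in simoprl.policy. Below is a minimal single-episode sketch, assuming
# the policy's select_action also returns the action's log-probability as a
# torch tensor (it returns a second value above) and a hypothetical
# `reward_model.predict_reward(state, action) -> float`.
def _reinforce_episode_sketch(policy, reward_model, optimizer, gamma: float = 0.99):
    env = gym.make("CartPole-v1")
    state, _ = env.reset()
    log_probs, rewards, done = [], [], False
    while not done:
        action, log_prob = policy.select_action(state)
        next_state, _, terminated, truncated, _ = env.step(action)
        # Train on the learned (preference-based) reward, not the env reward.
        rewards.append(reward_model.predict_reward(state, action))
        log_probs.append(log_prob)
        state, done = next_state, terminated or truncated
    env.close()
    # Discounted returns-to-go, then the standard REINFORCE policy-gradient step.
    returns, g = [], 0.0
    for r in reversed(rewards):
        g = r + gamma * g
        returns.append(g)
    returns = torch.tensor(returns[::-1], dtype=torch.float32)
    loss = -(torch.stack(log_probs) * returns).sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()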
# ── Gradio UI ──────────────────────────────────────────────────────────────────
with gr.Blocks(title="Sim-OPRL Demo", theme=gr.themes.Soft()) as demo:
gr.Markdown("""
# Sim-OPRL: Preference Elicitation for Offline RL
### Pace · Schölkopf · Rätsch · Ramponi — ICLR 2025
Two CartPole trajectories are simulated by a learned **dynamics model** and selected by the
**Sim-OPRL** acquisition strategy: high reward uncertainty (we learn the most from them)
and low transition uncertainty (the dynamics model is reliable for them).
**Click which run keeps the pole balanced longer.**
Your preferences directly train the reward model via the Bradley-Terry loss.
Every 5 clicks, the policy is re-optimised with REINFORCE on the learned reward.
""")
with gr.Row(equal_height=True):
with gr.Column():
vid_A = gr.Image(label="Trajectory A", type="filepath")
btn_A = gr.Button("⬅ Prefer A", variant="primary", size="lg")
with gr.Column():
vid_B = gr.Image(label="Trajectory B", type="filepath")
btn_B = gr.Button("Prefer B ➡", variant="primary", size="lg")
status_box = gr.Textbox(label="Status", interactive=False, lines=1)
with gr.Tabs():
with gr.Tab("Live: Your Preferences β Agent Return"):
live_plot = gr.Plot(label="Return vs Queries (updates every 5 clicks)")
with gr.Tab("Baseline Comparison (from train.py)"):
comparison_plot = gr.Plot(label="Sim-OPRL vs Uniform OPRL vs Uncertainty OPRL")
gr.Markdown("""
---
### How Sim-OPRL works
| Step | What happens |
|------|--------------|
| 1 | Collect an unlabelled offline dataset (no rewards) |
| 2 | Train an **ensemble dynamics model** on the dataset |
| 3 | For each query: simulate trajectories, score by `reward_uncertainty − λ · transition_uncertainty` |
| 4 | Ask for a preference on the highest-scoring pair |
| 5 | Update the **Bradley-Terry reward model** with the preference |
| 6 | Re-optimise the policy with REINFORCE on the learned reward |
Sim-OPRL reaches higher returns with **fewer queries** than naïve baselines
by asking *informative* questions, not random ones.
""")
# Wire up
btn_A.click(
fn=lambda: on_preference("A"),
inputs=[],
outputs=[vid_A, vid_B, live_plot, status_box, comparison_plot],
)
btn_B.click(
fn=lambda: on_preference("B"),
inputs=[],
outputs=[vid_A, vid_B, live_plot, status_box, comparison_plot],
)
demo.load(
fn=on_load,
inputs=[],
outputs=[vid_A, vid_B, live_plot, status_box, comparison_plot],
)
if __name__ == "__main__":
demo.launch(share=False)