sim-oprl / app.py
singhanshuman's picture
Upload app.py with huggingface_hub
fbcd20f verified
"""
app.py — Interactive Sim-OPRL demo (Gradio).
The professor clicks which of two CartPole trajectories she prefers.
Each click updates the Bradley-Terry reward model.
Every 5 clicks, the policy is retrained using REINFORCE on the learned reward.
The agent's performance (true CartPole reward) is plotted live.
Deploy: gradio app.py or python app.py
HuggingFace Spaces: push this repo; set app.py as the entrypoint.
"""
import os
import pickle
import random
import tempfile
import numpy as np
import torch
import gymnasium as gym
import imageio
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import gradio as gr
from pathlib import Path
from simoprl.collect_data import collect_offline_dataset, load_dataset
from simoprl.dynamics_model import EnsembleDynamicsModel
from simoprl.reward_model import EnsembleRewardModel
from simoprl.preference_elicitation import SimOPRL
from simoprl.policy import PolicyNetwork, REINFORCETrainer, evaluate_policy
# ── Paths ─────────────────────────────────────────────────────────────────────
# Offline trajectory dataset: loaded by _setup() when the file exists, otherwise
# a fresh dataset is collected with this path passed as its save location.
DATA_PATH = "data/offline_dataset.pkl"
# Dynamics-model checkpoint: loaded by _setup() when the file exists, otherwise
# the model is trained and then saved to this path.
DYN_MODEL_PATH = "models/dynamics_model.pt"
# Pickled baseline results (a dict with "results" and "checkpoints" keys) read by
# _make_comparison_plot(); a placeholder figure is shown when the file is absent.
RESULTS_PATH = "results/experiment_results.pkl"
# ── Global mutable state (single-user demo) ───────────────────────────────────
class _State:
    """Mutable application state shared by every Gradio callback.

    All fields are created per instance in ``__init__`` so that mutable
    containers (``return_history``) are never shared through the class object.
    The module keeps a single instance, ``S``; the demo is single-user.
    """

    def __init__(self) -> None:
        # Components built by _setup(); None until setup has run.
        self.dynamics_model: EnsembleDynamicsModel = None
        self.reward_model: EnsembleRewardModel = None
        self.policy: PolicyNetwork = None
        self.trainer: REINFORCETrainer = None
        self.elicitor: SimOPRL = None
        self.dataset: list = None
        # Number of preference clicks recorded so far.
        self.query_count: int = 0
        # Evaluation history for the live plot: [(n_queries, mean_return)]
        self.return_history: list = []
        # The two trajectories currently displayed, awaiting a preference.
        self.current_traj1: list = None
        self.current_traj2: list = None
        # Set to True once _setup() has completed.
        self.initialized: bool = False


S = _State()
# ── Setup ─────────────────────────────────────────────────────────────────────
def _setup():
    """Train / load all components. Called once at startup.

    Populates the global state ``S`` with the dataset, dynamics model, reward
    model, policy, trainer and elicitor. Subsequent calls return immediately
    once ``S.initialized`` is set.

    On a first run (no cached dataset or dynamics checkpoint on disk) the
    parent directories of the save paths are created before anything is
    written, so a fresh checkout without ``data/`` or ``models/`` still works.
    """
    if S.initialized:
        return

    # 1. Dataset
    if Path(DATA_PATH).exists():
        S.dataset = load_dataset(DATA_PATH)
        print(f"Dataset loaded: {len(S.dataset)} trajectories")
    else:
        print("Collecting offline dataset …")
        # Make sure the target directory exists before the dataset is saved.
        Path(DATA_PATH).parent.mkdir(parents=True, exist_ok=True)
        S.dataset = collect_offline_dataset(n_trajectories=800, save_path=DATA_PATH)

    # 2. Dynamics model (pre-trained; central to Sim-OPRL)
    S.dynamics_model = EnsembleDynamicsModel(n_models=5)
    if Path(DYN_MODEL_PATH).exists():
        S.dynamics_model.load(DYN_MODEL_PATH)
    else:
        print("Training dynamics model (first run — this takes a few minutes) …")
        S.dynamics_model.train(S.dataset, n_epochs=100)
        # Make sure the checkpoint directory exists before saving.
        Path(DYN_MODEL_PATH).parent.mkdir(parents=True, exist_ok=True)
        S.dynamics_model.save(DYN_MODEL_PATH)

    # 3. Reward model — starts blank; shaped entirely by the professor's clicks
    S.reward_model = EnsembleRewardModel(n_models=3)

    # 4. Policy — starts random; improves as reward model learns
    S.policy = PolicyNetwork()
    S.trainer = REINFORCETrainer(S.policy, S.reward_model, lr=1e-3)

    # 5. Sim-OPRL elicitor
    S.elicitor = SimOPRL(S.dataset, S.dynamics_model, horizon=50, n_simulated=40, lambda_=1.0)

    S.initialized = True
    print("Setup complete.")
# ── Trajectory simulation & rendering ────────────────────────────────────────
def _current_policy_fn(state: np.ndarray) -> int:
    """Pick an action for ``state`` with the policy currently in use.

    While fewer than 5 preference queries have been recorded, a uniformly
    random action in {0, 1} is returned instead of consulting ``S.policy``
    (on_preference first retrains the policy at the 5th query). From then on
    the action comes from ``S.policy.select_action``; the second element of
    its return value is discarded.
    """
    warmed_up = S.query_count >= 5
    if warmed_up:
        chosen, _unused = S.policy.select_action(state)
        return chosen
    return np.random.randint(2)
def _render_trajectory_to_gif(trajectory, path, fps=20) -> str:
    """
    Render a (state, action) trajectory to a GIF using CartPole's rgb_array renderer.

    The environment is never stepped: for every trajectory entry the env state
    is overwritten with the (clipped) given state and one frame is rendered.

    Args:
        trajectory: iterable of (state, action) pairs. Only the first four
            entries of each state are read; the action is ignored.
        path: output file path for the GIF.
        fps: frames per second, used to derive the per-frame duration.

    Returns:
        ``path``, unchanged.
    """
    env = gym.make("CartPole-v1", render_mode="rgb_array")
    frames = []
    try:
        env.reset()
        for state_arr, _action in trajectory:
            # Clip to renderable range (dynamics model may predict slightly OOB states)
            clipped = np.array([
                np.clip(state_arr[0], -4.8, 4.8),
                np.clip(state_arr[1], -10.0, 10.0),
                np.clip(state_arr[2], -0.5, 0.5),
                np.clip(state_arr[3], -10.0, 10.0),
            ], dtype=np.float64)
            env.unwrapped.state = clipped
            frames.append(env.render())
    finally:
        # Always release the renderer, even if reset/render raised.
        env.close()

    # NOTE(review): the unit imageio expects for `duration` (seconds vs
    # milliseconds) differs between imageio versions/plugins — confirm against
    # the pinned imageio version if GIF playback speed looks wrong.
    duration = 1.0 / fps
    imageio.mimwrite(path, frames, format="GIF", duration=duration, loop=0)
    return path
def _generate_and_render_pair() -> tuple[str, str]:
    """Ask Sim-OPRL for the next query pair and render both as GIFs.

    The chosen trajectories are stored on ``S`` (``current_traj1`` /
    ``current_traj2``) so the next preference click can be attributed to them.

    Returns:
        Paths of the GIFs for trajectory A and trajectory B, in that order.
    """
    traj1, traj2 = S.elicitor.get_query_pair(S.reward_model, _current_policy_fn)
    S.current_traj1 = traj1
    S.current_traj2 = traj2

    # Use the platform's temp directory instead of a hard-coded "/tmp", which
    # does not exist on every OS. The file names are fixed, so each new query
    # overwrites the previous pair (single-user demo).
    tmp_dir = tempfile.gettempdir()
    path1 = _render_trajectory_to_gif(traj1, os.path.join(tmp_dir, "simoprl_traj_A.gif"))
    path2 = _render_trajectory_to_gif(traj2, os.path.join(tmp_dir, "simoprl_traj_B.gif"))
    return path1, path2
# ── Plot ──────────────────────────────────────────────────────────────────────
def _make_return_plot() -> plt.Figure:
    """Build the live "return vs. number of queries" figure from ``S.return_history``.

    Two horizontal reference lines are always drawn (y=21 and y=500). When at
    least one evaluation has been recorded, the mean returns are plotted with a
    shaded band of +/-15% around the mean; the band is a fixed visual margin,
    not a measured standard deviation.

    Returns:
        A new matplotlib Figure. It is created outside pyplot's figure
        registry, so it is garbage-collected once the caller drops it.
    """
    # This function runs on every click. Figures made with plt.subplots() stay
    # registered in pyplot until explicitly closed, which leaks memory in a
    # long-running server; a directly constructed Figure avoids that.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(9, 4))
    ax = fig.subplots()
    ax.set_facecolor("#f5f5f5")
    fig.patch.set_facecolor("white")

    ax.axhline(y=21, color="#aaa", linestyle=":", linewidth=1.2, label="Random policy (~21 steps)")
    ax.axhline(y=500, color="#2ca02c", linestyle="--", linewidth=1, alpha=0.5, label="Max return (500)")

    if S.return_history:
        qs = [x[0] for x in S.return_history]
        means = np.array([x[1] for x in S.return_history])
        ax.plot(qs, means, "o-", color="#1f77b4", linewidth=2.5, markersize=7,
                label="Sim-OPRL (your preferences)")
        ax.fill_between(qs, means * 0.85, means * 1.15, alpha=0.15, color="#1f77b4")

    ax.set_xlabel("Number of Preference Queries", fontsize=12)
    ax.set_ylabel("Policy Return (True Reward)", fontsize=12)
    ax.set_title("How your preferences shape the agent", fontsize=13, fontweight="bold")
    ax.set_ylim(0, 530)
    ax.legend(fontsize=10, framealpha=0.9)
    ax.grid(True, alpha=0.3, linestyle="--")
    fig.tight_layout()
    return fig
def _make_comparison_plot() -> plt.Figure:
    """Show pre-computed baseline comparison if results exist.

    Reads ``RESULTS_PATH``, a pickle holding ``"results"`` (method name ->
    list of per-seed dicts mapping checkpoint -> return) and ``"checkpoints"``.
    For each known method the per-checkpoint mean is plotted with a +/-1 std
    band. Seeds that lack a checkpoint are ignored for that checkpoint rather
    than turning the whole point into NaN.

    Returns:
        A new matplotlib Figure; a text placeholder when the results file is
        missing. Figures are created outside pyplot's registry so repeated
        calls do not accumulate open figures.
    """
    # Called on every click: avoid pyplot-managed figures, which are kept
    # alive in pyplot's registry until explicitly closed.
    from matplotlib.figure import Figure

    if not Path(RESULTS_PATH).exists():
        fig = Figure(figsize=(9, 3))
        ax = fig.subplots()
        ax.text(0.5, 0.5, "Run python train.py to generate comparison figure",
                ha="center", va="center", transform=ax.transAxes, fontsize=12, color="gray")
        ax.axis("off")
        return fig

    # NOTE(security): pickle.load can execute arbitrary code. RESULTS_PATH must
    # only ever point at a locally generated, trusted file.
    with open(RESULTS_PATH, "rb") as f:
        data = pickle.load(f)
    results = data["results"]
    checkpoints = sorted(data["checkpoints"])

    colors = {"uniform": "#d62728", "uncertainty": "#ff7f0e", "simoprl": "#1f77b4"}
    labels = {"uniform": "Uniform OPRL", "uncertainty": "Uncertainty OPRL", "simoprl": "Sim-OPRL (paper)"}

    fig = Figure(figsize=(9, 4))
    ax = fig.subplots()
    ax.set_facecolor("#f5f5f5")

    for method in ["uniform", "uncertainty", "simoprl"]:
        if method not in results:
            continue
        seed_results = results[method]
        qs = checkpoints
        # One list of per-seed returns per checkpoint; NaN marks a seed that
        # has no entry for that checkpoint.
        per_checkpoint = [[r.get(q, np.nan) for r in seed_results] for q in qs]
        # nan-aware statistics: a missing seed must not poison the aggregate.
        # (A checkpoint with no data at all still yields NaN, with a NumPy
        # RuntimeWarning, and is simply not drawn.)
        means = np.array([np.nanmean(vals) for vals in per_checkpoint])
        stds = np.array([np.nanstd(vals) for vals in per_checkpoint])
        ax.plot(qs, means, "-o", color=colors[method], linewidth=2 if method == "simoprl" else 1.5,
                markersize=5, label=labels[method])
        ax.fill_between(qs, means - stds, means + stds, alpha=0.12, color=colors[method])

    ax.axhline(y=500, color="green", linestyle="--", linewidth=1, alpha=0.5)
    ax.set_xlabel("Preference Queries", fontsize=11)
    ax.set_ylabel("Policy Return", fontsize=11)
    ax.set_title("Sim-OPRL vs baselines (oracle preferences, 5 seeds)", fontsize=12, fontweight="bold")
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    return fig
# ── Gradio handlers ───────────────────────────────────────────────────────────
def on_load():
    """Gradio page-load handler: initialise everything and show the first query.

    Runs ``_setup()`` (a no-op after the first call), generates and renders a
    fresh trajectory pair, and builds both plots.

    Returns:
        ``(gif_a_path, gif_b_path, live_plot, status_text, comparison_plot)``,
        matching the order of the output components wired up in the UI.
    """
    _setup()
    gif1, gif2 = _generate_and_render_pair()
    plot = _make_return_plot()
    # User-facing text: uses a real em dash (the previous literal was mis-encoded).
    status = "Ready — click which trajectory keeps the pole balanced longer."
    return gif1, gif2, plot, status, _make_comparison_plot()
def on_preference(preferred: str):
    """Called when professor clicks 'Prefer A' or 'Prefer B'.

    Records the preference for the pair currently stored on ``S``, updates the
    reward model, and on every 5th query retrains and evaluates the policy.
    A new query pair is then generated and rendered.

    Args:
        preferred: ``"A"`` selects the first trajectory (label 0); any other
            value is treated as the second trajectory (label 1).

    Returns:
        ``(gif_a_path, gif_b_path, live_plot, status_text, comparison_plot)``.
    """
    # No query pair has been generated yet: fall back to the initial-load path.
    if S.current_traj1 is None:
        return on_load()

    label = 0 if preferred == "A" else 1
    S.reward_model.add_preference(S.current_traj1, S.current_traj2, label)
    S.reward_model.update(n_epochs=15)
    S.query_count += 1

    # Status text is assembled from fragments and joined at the end.
    fragments = [f"Query {S.query_count}: you preferred {'A' if label == 0 else 'B'}."]

    # Retrain policy every 5 queries
    retrain_due = S.query_count % 5 == 0
    if retrain_due:
        fragments.append(" Updating policy …")
        S.trainer.train(n_episodes=40)
        mean_ret, _ = evaluate_policy(S.policy, n_episodes=15)
        S.return_history.append((S.query_count, mean_ret))
        fragments.append(f" Policy return: {mean_ret:.1f}")

    next_a, next_b = _generate_and_render_pair()
    live_fig = _make_return_plot()
    baseline_fig = _make_comparison_plot()
    return next_a, next_b, live_fig, "".join(fragments), baseline_fig
# ── Gradio UI ─────────────────────────────────────────────────────────────────
# Page layout: intro text, the two trajectory panels with their preference
# buttons, a status line, the two plot tabs, and an explanatory footer.
# Non-ASCII characters in the user-visible strings (dashes, arrows, accents,
# Greek letters) are written as real Unicode; they were previously mis-encoded.
with gr.Blocks(title="Sim-OPRL Demo", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # Sim-OPRL: Preference Elicitation for Offline RL
    ### Pace · Schölkopf · Rätsch · Ramponi — ICLR 2025

    Two CartPole trajectories are simulated by a learned **dynamics model**, chosen by the
    **Sim-OPRL** acquisition strategy: high reward uncertainty (we learn the most here)
    and low transition uncertainty (the dynamics model is reliable here).

    **Click which run keeps the pole balanced longer.**
    Your preferences directly train the reward model via the Bradley-Terry loss.
    Every 5 clicks, the policy is re-optimised with REINFORCE on the learned reward.
    """)

    with gr.Row(equal_height=True):
        with gr.Column():
            vid_A = gr.Image(label="Trajectory A", type="filepath")
            btn_A = gr.Button("⬅ Prefer A", variant="primary", size="lg")
        with gr.Column():
            vid_B = gr.Image(label="Trajectory B", type="filepath")
            btn_B = gr.Button("Prefer B ➡", variant="primary", size="lg")

    status_box = gr.Textbox(label="Status", interactive=False, lines=1)

    with gr.Tabs():
        with gr.Tab("Live: Your Preferences → Agent Return"):
            live_plot = gr.Plot(label="Return vs Queries (updates every 5 clicks)")
        with gr.Tab("Baseline Comparison (from train.py)"):
            comparison_plot = gr.Plot(label="Sim-OPRL vs Uniform OPRL vs Uncertainty OPRL")

    gr.Markdown("""
    ---
    ### How Sim-OPRL works

    | Step | What happens |
    |------|--------------|
    | 1 | Collect an unlabelled offline dataset (no rewards) |
    | 2 | Train an **ensemble dynamics model** on the dataset |
    | 3 | For each query: simulate trajectories, score by `reward_uncertainty − λ · transition_uncertainty` |
    | 4 | Ask for a preference on the highest-scoring pair |
    | 5 | Update the **Bradley-Terry reward model** with the preference |
    | 6 | Re-optimise the policy with REINFORCE on the learned reward |

    Sim-OPRL reaches higher returns with **fewer queries** than naïve baselines
    by asking *informative* questions, not random ones.
    """)

    # Wire up: every handler returns the same five values, in this order.
    handler_outputs = [vid_A, vid_B, live_plot, status_box, comparison_plot]

    btn_A.click(
        fn=lambda: on_preference("A"),
        inputs=[],
        outputs=handler_outputs,
    )
    btn_B.click(
        fn=lambda: on_preference("B"),
        inputs=[],
        outputs=handler_outputs,
    )
    demo.load(
        fn=on_load,
        inputs=[],
        outputs=handler_outputs,
    )
# Script entry point: when run directly (python app.py), launch the Gradio app
# with share=False.
if __name__ == "__main__":
    demo.launch(share=False)