"""
app.py: Interactive Sim-OPRL demo (Gradio).

The professor clicks which of two CartPole trajectories she prefers.
Each click updates the Bradley-Terry reward model.
Every 5 clicks, the policy is retrained using REINFORCE on the learned reward.
The agent's performance (true CartPole reward) is plotted live.

Deploy: gradio app.py   or   python app.py
HuggingFace Spaces: push this repo; set app.py as the entrypoint.
"""
import pickle
import numpy as np
import torch
import gymnasium as gym
import imageio
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import gradio as gr
from pathlib import Path

from simoprl.collect_data import collect_offline_dataset, load_dataset
from simoprl.dynamics_model import EnsembleDynamicsModel
from simoprl.reward_model import EnsembleRewardModel
from simoprl.preference_elicitation import SimOPRL
from simoprl.policy import PolicyNetwork, REINFORCETrainer, evaluate_policy

# ── Paths ─────────────────────────────────────────────────────────────────────
DATA_PATH = "data/offline_dataset.pkl"
DYN_MODEL_PATH = "models/dynamics_model.pt"
RESULTS_PATH = "results/experiment_results.pkl"

# ── Global mutable state (single-user demo) ───────────────────────────────────
class _State:
    dynamics_model: EnsembleDynamicsModel = None
    reward_model: EnsembleRewardModel = None
    policy: PolicyNetwork = None
    trainer: REINFORCETrainer = None
    elicitor: SimOPRL = None
    dataset: list = None
    query_count: int = 0
    return_history: list = []           # [(n_queries, mean_return)]
    current_traj1: list = None
    current_traj2: list = None
    initialized: bool = False

S = _State()


# ── Setup ─────────────────────────────────────────────────────────────────────

def _setup():
    """Train / load all components. Called once at startup."""
    if S.initialized:
        return

    # 1. Dataset
    if Path(DATA_PATH).exists():
        S.dataset = load_dataset(DATA_PATH)
        print(f"Dataset loaded: {len(S.dataset)} trajectories")
    else:
        print("Collecting offline dataset …")
        S.dataset = collect_offline_dataset(n_trajectories=800, save_path=DATA_PATH)

    # 2. Dynamics model (pre-trained; central to Sim-OPRL)
    S.dynamics_model = EnsembleDynamicsModel(n_models=5)
    if Path(DYN_MODEL_PATH).exists():
        S.dynamics_model.load(DYN_MODEL_PATH)
    else:
        print("Training dynamics model (first run β€” this takes a few minutes) …")
        S.dynamics_model.train(S.dataset, n_epochs=100)
        S.dynamics_model.save(DYN_MODEL_PATH)

    # 3. Reward model: starts blank; shaped entirely by the professor's clicks
    S.reward_model = EnsembleRewardModel(n_models=3)

    # 4. Policy: starts random; improves as the reward model learns
    S.policy = PolicyNetwork()
    S.trainer = REINFORCETrainer(S.policy, S.reward_model, lr=1e-3)

    # 5. Sim-OPRL elicitor
    S.elicitor = SimOPRL(S.dataset, S.dynamics_model, horizon=50, n_simulated=40, lambda_=1.0)

    S.initialized = True
    print("Setup complete.")


# ── Trajectory simulation & rendering ────────────────────────────────────────

def _current_policy_fn(state: np.ndarray) -> int:
    # Until the first policy update (after 5 queries), roll out random actions.
    if S.query_count < 5:
        return int(np.random.randint(2))
    action, _ = S.policy.select_action(state)
    return action


def _render_trajectory_to_gif(trajectory, path, fps=20) -> str:
    """
    Render a (state, action) trajectory to a GIF using CartPole's rgb_array renderer.
    For simulated trajectories the env state is set at each step.
    """
    env = gym.make("CartPole-v1", render_mode="rgb_array")
    env.reset()

    frames = []
    for state_arr, action in trajectory:
        # Clip to renderable range (dynamics model may predict slightly OOB states)
        clipped = np.array([
            np.clip(state_arr[0], -4.8, 4.8),
            np.clip(state_arr[1], -10.0, 10.0),
            np.clip(state_arr[2], -0.5, 0.5),
            np.clip(state_arr[3], -10.0, 10.0),
        ], dtype=np.float64)
        env.unwrapped.state = clipped
        frames.append(env.render())

    env.close()

    duration = 1.0 / fps
    imageio.mimwrite(path, frames, format="GIF", duration=duration, loop=0)
    return path


def _generate_and_render_pair() -> tuple[str, str]:
    """Ask Sim-OPRL for the next query pair and render both as GIFs."""
    traj1, traj2 = S.elicitor.get_query_pair(S.reward_model, _current_policy_fn)
    S.current_traj1 = traj1
    S.current_traj2 = traj2

    path1 = _render_trajectory_to_gif(traj1, "/tmp/simoprl_traj_A.gif")
    path2 = _render_trajectory_to_gif(traj2, "/tmp/simoprl_traj_B.gif")
    return path1, path2
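

# ── Illustration: the Sim-OPRL acquisition score ─────────────────────────────
# A minimal sketch (not called by the demo) of the pair-selection rule described
# in the UI text below: prefer candidate pairs whose simulated trajectories have
# high reward-model disagreement (informative to label) and low dynamics-model
# disagreement (trustworthy rollouts). The array inputs are assumptions; the
# actual logic lives in simoprl.preference_elicitation.SimOPRL.

def _acquisition_score_sketch(reward_uncertainty: np.ndarray,
                              transition_uncertainty: np.ndarray,
                              lambda_: float = 1.0) -> int:
    """Index of the best candidate pair under reward_unc - lambda * transition_unc."""
    scores = reward_uncertainty - lambda_ * transition_uncertainty
    return int(np.argmax(scores))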


# ── Plot ──────────────────────────────────────────────────────────────────────

def _make_return_plot() -> plt.Figure:
    fig, ax = plt.subplots(figsize=(9, 4))
    ax.set_facecolor("#f5f5f5")
    fig.patch.set_facecolor("white")

    ax.axhline(y=21, color="#aaa", linestyle=":", linewidth=1.2, label="Random policy (~21 steps)")
    ax.axhline(y=500, color="#2ca02c", linestyle="--", linewidth=1, alpha=0.5, label="Max return (500)")

    if S.return_history:
        qs = [x[0] for x in S.return_history]
        means = np.array([x[1] for x in S.return_history])
        ax.plot(qs, means, "o-", color="#1f77b4", linewidth=2.5, markersize=7,
                label="Sim-OPRL (your preferences)")
        ax.fill_between(qs, means * 0.85, means * 1.15, alpha=0.15, color="#1f77b4")

    ax.set_xlabel("Number of Preference Queries", fontsize=12)
    ax.set_ylabel("Policy Return (True Reward)", fontsize=12)
    ax.set_title("How your preferences shape the agent", fontsize=13, fontweight="bold")
    ax.set_ylim(0, 530)
    ax.legend(fontsize=10, framealpha=0.9)
    ax.grid(True, alpha=0.3, linestyle="--")
    plt.tight_layout()
    return fig


def _make_comparison_plot() -> plt.Figure:
    """Show pre-computed baseline comparison if results exist."""
    if not Path(RESULTS_PATH).exists():
        fig, ax = plt.subplots(figsize=(9, 3))
        ax.text(0.5, 0.5, "Run python train.py to generate comparison figure",
                ha="center", va="center", transform=ax.transAxes, fontsize=12, color="gray")
        ax.axis("off")
        return fig

    with open(RESULTS_PATH, "rb") as f:
        data = pickle.load(f)

    results = data["results"]
    checkpoints = sorted(data["checkpoints"])
    colors = {"uniform": "#d62728", "uncertainty": "#ff7f0e", "simoprl": "#1f77b4"}
    labels = {"uniform": "Uniform OPRL", "uncertainty": "Uncertainty OPRL", "simoprl": "Sim-OPRL (paper)"}

    fig, ax = plt.subplots(figsize=(9, 4))
    ax.set_facecolor("#f5f5f5")
    for method in ["uniform", "uncertainty", "simoprl"]:
        if method not in results:
            continue
        seed_results = results[method]
        qs = checkpoints
        # nan-aware statistics: a seed may be missing a checkpoint
        means = np.array([np.nanmean([r.get(q, np.nan) for r in seed_results]) for q in qs])
        stds = np.array([np.nanstd([r.get(q, np.nan) for r in seed_results]) for q in qs])
        ax.plot(qs, means, "-o", color=colors[method], linewidth=2 if method == "simoprl" else 1.5,
                markersize=5, label=labels[method])
        ax.fill_between(qs, means - stds, means + stds, alpha=0.12, color=colors[method])

    ax.axhline(y=500, color="green", linestyle="--", linewidth=1, alpha=0.5)
    ax.set_xlabel("Preference Queries", fontsize=11)
    ax.set_ylabel("Policy Return", fontsize=11)
    ax.set_title("Sim-OPRL vs baselines (oracle preferences, 5 seeds)", fontsize=12, fontweight="bold")
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    return fig


# ── Gradio handlers ───────────────────────────────────────────────────────────

def on_load():
    _setup()
    gif1, gif2 = _generate_and_render_pair()
    plot = _make_return_plot()
    status = "Ready β€” click which trajectory keeps the pole balanced longer."
    return gif1, gif2, plot, status, _make_comparison_plot()


def on_preference(preferred: str):
    """Called when professor clicks 'Prefer A' or 'Prefer B'."""
    if S.current_traj1 is None:
        return on_load()

    label = 0 if preferred == "A" else 1
    S.reward_model.add_preference(S.current_traj1, S.current_traj2, label)
    S.reward_model.update(n_epochs=15)
    S.query_count += 1

    status = f"Query {S.query_count}: you preferred {'A' if label == 0 else 'B'}."

    # Retrain policy every 5 queries
    if S.query_count % 5 == 0:
        status += " Updating policy …"
        S.trainer.train(n_episodes=40)
        mean_ret, _ = evaluate_policy(S.policy, n_episodes=15)
        S.return_history.append((S.query_count, mean_ret))
        status += f" Policy return: {mean_ret:.1f}"

    gif1, gif2 = _generate_and_render_pair()
    return gif1, gif2, _make_return_plot(), status, _make_comparison_plot()
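

# ── Illustration: REINFORCE on the learned reward ────────────────────────────
# A minimal sketch (not called by the demo) of the update REINFORCETrainer is
# assumed to perform every 5 queries: roll out the policy, score each step with
# the *learned* reward model instead of the true environment reward, and take a
# return-weighted log-likelihood gradient step. Signatures are illustrative.

def _reinforce_step_sketch(log_probs: list[torch.Tensor],
                           learned_rewards: list[float],
                           optimizer: torch.optim.Optimizer,
                           gamma: float = 0.99) -> None:
    """One policy-gradient step from a single episode's action log-probs and learned rewards."""
    returns, g = [], 0.0
    for r in reversed(learned_rewards):                  # discounted returns-to-go
        g = r + gamma * g
        returns.insert(0, g)
    returns_t = torch.tensor(returns)
    returns_t = (returns_t - returns_t.mean()) / (returns_t.std() + 1e-8)  # variance reduction
    loss = -(torch.stack(log_probs) * returns_t).sum()   # maximise return-weighted log-likelihood
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()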


# ── Gradio UI ─────────────────────────────────────────────────────────────────

with gr.Blocks(title="Sim-OPRL Demo", theme=gr.themes.Soft()) as demo:

    gr.Markdown("""
    # Sim-OPRL: Preference Elicitation for Offline RL
    ### Pace · Schölkopf · Rätsch · Ramponi (ICLR 2025)

    Two CartPole trajectories are simulated by a learned **dynamics model** and selected by the
    **Sim-OPRL** acquisition strategy: pairs with high reward uncertainty (where your answer is
    most informative) and low transition uncertainty (where the dynamics model is reliable).

    **Click which run keeps the pole balanced longer.**
    Your preferences directly train the reward model via the Bradley-Terry loss.
    Every 5 clicks, the policy is re-optimised with REINFORCE on the learned reward.
    """)

    with gr.Row(equal_height=True):
        with gr.Column():
            vid_A = gr.Image(label="Trajectory A", type="filepath")
            btn_A = gr.Button("⬅  Prefer A", variant="primary", size="lg")
        with gr.Column():
            vid_B = gr.Image(label="Trajectory B", type="filepath")
            btn_B = gr.Button("Prefer B  ➡", variant="primary", size="lg")

    status_box = gr.Textbox(label="Status", interactive=False, lines=1)

    with gr.Tabs():
        with gr.Tab("Live: Your Preferences β†’ Agent Return"):
            live_plot = gr.Plot(label="Return vs Queries (updates every 5 clicks)")
        with gr.Tab("Baseline Comparison (from train.py)"):
            comparison_plot = gr.Plot(label="Sim-OPRL vs Uniform OPRL vs Uncertainty OPRL")

    gr.Markdown("""
    ---
    ### How Sim-OPRL works

    | Step | What happens |
    |------|--------------|
    | 1 | Collect an unlabelled offline dataset (no rewards) |
    | 2 | Train an **ensemble dynamics model** on the dataset |
    | 3 | For each query: simulate trajectories, score by `reward_uncertainty − λ · transition_uncertainty` |
    | 4 | Ask for a preference on the highest-scoring pair |
    | 5 | Update the **Bradley-Terry reward model** with the preference |
    | 6 | Re-optimise the policy with REINFORCE on the learned reward |

    Sim-OPRL reaches higher returns with **fewer queries** than naïve baselines
    by asking *informative* questions, not random ones.
    """)

    # Wire up
    btn_A.click(
        fn=lambda: on_preference("A"),
        inputs=[],
        outputs=[vid_A, vid_B, live_plot, status_box, comparison_plot],
    )
    btn_B.click(
        fn=lambda: on_preference("B"),
        inputs=[],
        outputs=[vid_A, vid_B, live_plot, status_box, comparison_plot],
    )
    demo.load(
        fn=on_load,
        inputs=[],
        outputs=[vid_A, vid_B, live_plot, status_box, comparison_plot],
    )


if __name__ == "__main__":
    demo.launch(share=False)