""" Budget Router — Gradio Visualization Dashboard Run: python app_gradio.py (launches on http://localhost:7860) """ from __future__ import annotations import math import time from typing import Dict, Optional, Tuple import gradio as gr from budget_router.environment import BudgetRouterEnv from budget_router.models import Action, ActionType from budget_router.tasks import TASK_PRESETS from gradio_ui.config import MAX_STEPS as _MAX_STEPS, POLICY_CHOICES, SCENARIOS from gradio_ui.policies import get_policy_runner from gradio_ui.renderers import ( _kpi_grid, render_incident_timeline, render_side_panel, render_grader_plot, MISSION_SCORE_HELP, MISSION_SCORE_LABEL, _GRADER_PENDING, _PROVIDER_EMPTY, render_history_table_compare, ) from gradio_ui.state import fresh_side_state, _observation_to_dict, record_step from gradio_ui.theme import LIGHT_CSS, THEME MAX_STEPS = _MAX_STEPS # Compatibility: preserve module-level MAX_STEPS for callers. # ─── UI Build ───────────────────────────────────────────────────────────────── def build_app() -> gr.Blocks: def _normalize_seed(seed: object, default: int = 42) -> int: if seed is None: return default try: val = float(seed) # type: ignore[arg-type] except Exception: return default if math.isnan(val) or math.isinf(val): return default try: return int(val) except Exception: return default with gr.Blocks(title="Budget Router — Policy Comparison", theme=THEME, css=LIGHT_CSS) as demo: left_state = gr.State(fresh_side_state()) right_state = gr.State(fresh_side_state()) run_state = gr.State({"running": False, "scenario": "easy", "seed": 42, "step": 0}) gr.Markdown( "# Budget Router — Policy Comparison\n" "_Select 2 policies · start episode · step or finish episode · compare outcomes_" ) with gr.Row(): with gr.Column(scale=1): left_title = gr.Markdown("## Policy A") left_policy = gr.Dropdown(choices=POLICY_CHOICES, value=None, label="Select policy") left_status = gr.Textbox(label="Status", interactive=False, lines=2) left_providers = gr.HTML(_PROVIDER_EMPTY()) left_budget = gr.HTML("") left_kpis = gr.HTML( _kpi_grid( [ ("Step", "—"), ("Last action", "—"), ("Latency (ms)", "—"), ("Budget remaining", "—"), ("Reward", "—"), ("Adaptation", "—"), ] ) ) left_badges = gr.HTML("") left_summary = gr.HTML( _kpi_grid( [ ("Failed %", "—"), ("SLA breach %", "—"), ("Avg latency (ms)", "—"), ] ) ) with gr.Column(scale=1): right_title = gr.Markdown("## Policy B") right_policy = gr.Dropdown(choices=POLICY_CHOICES, value=None, label="Select policy") right_status = gr.Textbox(label="Status", interactive=False, lines=2) right_providers = gr.HTML(_PROVIDER_EMPTY()) right_budget = gr.HTML("") right_kpis = gr.HTML( _kpi_grid( [ ("Step", "—"), ("Last action", "—"), ("Latency (ms)", "—"), ("Budget remaining", "—"), ("Reward", "—"), ("Adaptation", "—"), ] ) ) right_badges = gr.HTML("") right_summary = gr.HTML( _kpi_grid( [ ("Failed %", "—"), ("SLA breach %", "—"), ("Avg latency (ms)", "—"), ] ) ) with gr.Row(): with gr.Column(scale=2): gr.Markdown("### Episode Controls") scenario_sel = gr.Radio(SCENARIOS, value="easy", label="Scenario") seed_inp = gr.Number(value=42, label="Seed", precision=0) start_btn = gr.Button("▶ Start Episode", variant="primary", interactive=False) with gr.Row(): step_btn = gr.Button("→ Step", variant="secondary", interactive=False) fast_btn = gr.Button("⚡ Fast-forward", interactive=False) finish_btn = gr.Button("⏩ Finish Episode", interactive=False) gr.Markdown(f"### {MISSION_SCORE_LABEL} (comparison)\n_{MISSION_SCORE_HELP}_") grader_plot = gr.Plot() with gr.Row(elem_classes=["episode-history-row"]): with gr.Column(scale=1): left_history_title = gr.Markdown("### Step History — Policy A") left_history_tbl = gr.HTML(render_history_table_compare([]), elem_classes=["episode-history-table"]) with gr.Column(scale=1): right_history_title = gr.Markdown("### Step History — Policy B") right_history_tbl = gr.HTML(render_history_table_compare([]), elem_classes=["episode-history-table"]) with gr.Row(): with gr.Column(scale=1): left_grade_title = gr.Markdown(f"### {MISSION_SCORE_LABEL} — Policy A") left_grade = gr.HTML(_GRADER_PENDING()) with gr.Column(scale=1): right_grade_title = gr.Markdown(f"### {MISSION_SCORE_LABEL} — Policy B") right_grade = gr.HTML(_GRADER_PENDING()) gr.Markdown("### Incident Timeline") incidents_html = gr.HTML(render_incident_timeline("easy")) def _render_side(side: Dict, run: Dict, scenario_name: str) -> Tuple[str, str, str, str, str, str, str, str]: return render_side_panel(side, run, scenario_name) def _render_all(ls: Dict, rs: Dict, run: Dict) -> tuple: scenario_name = str(run.get("scenario", "easy") or "easy") l_out = _render_side(ls, run, scenario_name) r_out = _render_side(rs, run, scenario_name) plot = render_grader_plot( ls.get("history", []) or [], rs.get("history", []) or [], left_name=str(ls.get("policy_name") or ""), right_name=str(rs.get("policy_name") or ""), ) incidents = render_incident_timeline(scenario_name) running = bool(run.get("running", False)) btn_update = gr.update(interactive=running) config_update = gr.update(interactive=(not running)) return ( ls, rs, run, l_out[0], l_out[1], l_out[2], l_out[3], l_out[4], l_out[5], r_out[0], r_out[1], r_out[2], r_out[3], r_out[4], r_out[5], l_out[6], r_out[6], l_out[7], r_out[7], plot, incidents, config_update, config_update, config_update, config_update, config_update, btn_update, btn_update, btn_update, ) OUTPUTS = [ left_state, right_state, run_state, left_status, left_providers, left_budget, left_kpis, left_badges, left_summary, right_status, right_providers, right_budget, right_kpis, right_badges, right_summary, left_history_tbl, right_history_tbl, left_grade, right_grade, grader_plot, incidents_html, left_policy, right_policy, scenario_sel, seed_inp, start_btn, step_btn, fast_btn, finish_btn, ] GRADER_PLOT_IDX = OUTPUTS.index(grader_plot) def _update_start_enabled(p1: Optional[str], p2: Optional[str], run: Dict): left_name = str(p1 or "Policy A") right_name = str(p2 or "Policy B") running = bool((run or {}).get("running", False)) ok = (bool(p1) and bool(p2)) and (not running) return ( gr.update(interactive=ok), f"## {left_name}", f"## {right_name}", f"### Step History — {left_name}", f"### Step History — {right_name}", f"### {MISSION_SCORE_LABEL} — {left_name}", f"### {MISSION_SCORE_LABEL} — {right_name}", ) left_policy.change( _update_start_enabled, inputs=[left_policy, right_policy, run_state], outputs=[start_btn, left_title, right_title, left_history_title, right_history_title, left_grade_title, right_grade_title], ) right_policy.change( _update_start_enabled, inputs=[left_policy, right_policy, run_state], outputs=[start_btn, left_title, right_title, left_history_title, right_history_title, left_grade_title, right_grade_title], ) scenario_sel.change(lambda s: render_incident_timeline(s), inputs=[scenario_sel], outputs=[incidents_html]) def do_start(p1: str, p2: str, scenario: str, seed: Optional[float], _ls: Dict, _rs: Dict, _run: Dict): ls = fresh_side_state() rs = fresh_side_state() seed_int = _normalize_seed(seed, default=42) if not p1 or not p2: run = {"running": False, "scenario": scenario, "seed": seed_int, "step": 0} ls["status"] = "Select both policies to start." rs["status"] = "Select both policies to start." return _render_all(ls, rs, run) runner_l, err_l = get_policy_runner(p1) runner_r, err_r = get_policy_runner(p2) if err_l or err_r or runner_l is None or runner_r is None: ls["status"] = f"❌ {err_l}" if err_l else "" rs["status"] = f"❌ {err_r}" if err_r else "" run = {"running": False, "scenario": scenario, "seed": seed_int, "step": 0} return _render_all(ls, rs, run) env_l = BudgetRouterEnv() env_r = BudgetRouterEnv() obs_l = env_l.reset(seed=seed_int, scenario=scenario) obs_r = env_r.reset(seed=seed_int, scenario=scenario) try: runner_l.reset(scenario) except Exception: pass try: runner_r.reset(scenario) except Exception: pass ls.update( { "env": env_l, "policy_name": p1, "policy_runner": runner_l, "obs": _observation_to_dict(obs_l), "status": f"✅ Running · {p1}", } ) rs.update( { "env": env_r, "policy_name": p2, "policy_runner": runner_r, "obs": _observation_to_dict(obs_r), "status": f"✅ Running · {p2}", } ) run = {"running": True, "scenario": scenario, "seed": seed_int, "step": 0} return _render_all(ls, rs, run) def _apply_local_step(side: Dict, scenario_name: str, global_step: int) -> Dict: if side.get("done"): return side env = side.get("env") runner = side.get("policy_runner") if env is None or runner is None: side["done"] = True side["status"] = "❌ Not initialized" return side try: action_str = runner.choose_action(side.get("obs", {}) or {}) except Exception as exc: side["done"] = True side["status"] = f"❌ Policy error: {exc}" return side pre_obs = dict(side.get("obs", {}) or {}) obs_obj = env.step(Action(action_type=ActionType(action_str))) obs = _observation_to_dict(obs_obj) reward = float(obs.get("reward", 0.0) or 0.0) meta = dict(obs.get("metadata", {}) or {}) done = bool(obs.get("done", False)) side["history"].append(record_step(global_step, action_str, obs, reward, meta, health_obs=pre_obs)) side["obs"] = obs side["cumulative_reward"] = float(side.get("cumulative_reward", 0.0) or 0.0) + reward side["done"] = done side["status"] = "✅ Done" if done else str(side.get("status", "")) return side def do_step(ls: Dict, rs: Dict, run: Dict): if not bool(run.get("running", False)): return _render_all(ls, rs, run) if int(run.get("step", 0) or 0) >= MAX_STEPS: run["running"] = False return _render_all(ls, rs, run) next_step = int(run.get("step", 0) or 0) + 1 scenario = str(run.get("scenario", "easy") or "easy") ls = _apply_local_step(ls, scenario, next_step) rs = _apply_local_step(rs, scenario, next_step) run["step"] = next_step if next_step >= MAX_STEPS or (ls.get("done") and rs.get("done")): run["running"] = False return _render_all(ls, rs, run) def _stream_to_end(ls: Dict, rs: Dict, run: Dict): if not bool(run.get("running", False)): yield _render_all(ls, rs, run) return frozen = _render_all(ls, rs, run) frozen_grader_plot = frozen[GRADER_PLOT_IDX] while bool(run.get("running", False)) and int(run.get("step", 0) or 0) < MAX_STEPS: out = do_step(ls, rs, run) ls, rs, run = out[0], out[1], out[2] out_list = list(out) out_list[GRADER_PLOT_IDX] = frozen_grader_plot yield tuple(out_list) time.sleep(0.12) if not bool(run.get("running", False)): break yield _render_all(ls, rs, run) def do_fast_forward(ls: Dict, rs: Dict, run: Dict): yield from _stream_to_end(ls, rs, run) def do_finish(ls: Dict, rs: Dict, run: Dict): yield from _stream_to_end(ls, rs, run) start_btn.click(do_start, inputs=[left_policy, right_policy, scenario_sel, seed_inp, left_state, right_state, run_state], outputs=OUTPUTS) step_btn.click(do_step, inputs=[left_state, right_state, run_state], outputs=OUTPUTS) fast_btn.click(do_fast_forward, inputs=[left_state, right_state, run_state], outputs=OUTPUTS) finish_btn.click(do_finish, inputs=[left_state, right_state, run_state], outputs=OUTPUTS) return demo if __name__ == "__main__": app = build_app() app.queue() app.launch(server_port=7860)