test-rl-hackathon-budget / app_gradio.py
Akshay Babbar
chore: HF Space export (size filter)
98a5a8c
"""
Budget Router — Gradio Visualization Dashboard
Run: python app_gradio.py (launches on http://localhost:7860)
"""
from __future__ import annotations
import math
import time
from typing import Dict, Optional, Tuple
import gradio as gr
from budget_router.environment import BudgetRouterEnv
from budget_router.models import Action, ActionType
from budget_router.tasks import TASK_PRESETS
from gradio_ui.config import MAX_STEPS as _MAX_STEPS, POLICY_CHOICES, SCENARIOS
from gradio_ui.policies import get_policy_runner
from gradio_ui.renderers import (
_kpi_grid,
render_incident_timeline,
render_side_panel,
render_grader_plot,
MISSION_SCORE_HELP,
MISSION_SCORE_LABEL,
_GRADER_PENDING,
_PROVIDER_EMPTY,
render_history_table_compare,
)
from gradio_ui.state import fresh_side_state, _observation_to_dict, record_step
from gradio_ui.theme import LIGHT_CSS, THEME
MAX_STEPS = _MAX_STEPS
# Compatibility: preserve module-level MAX_STEPS for callers. The value itself
# comes from gradio_ui.config, imported above under the private alias _MAX_STEPS.
# ─── UI Build ─────────────────────────────────────────────────────────────────
# build_app constructs the two-policy comparison dashboard: it creates the
# widgets inside a gr.Blocks context, defines the start/step/stream callbacks
# as closures over those widgets, wires the callbacks to the buttons, and
# returns the gr.Blocks instance (bound to `demo` below).
def build_app() -> gr.Blocks:
def _normalize_seed(seed: object, default: int = 42) -> int:
if seed is None:
return default
try:
val = float(seed) # type: ignore[arg-type]
except Exception:
return default
if math.isnan(val) or math.isinf(val):
return default
try:
return int(val)
except Exception:
return default
# Root Blocks context; the instance is bound to `demo` and returned by build_app.
with gr.Blocks(title="Budget Router — Policy Comparison", theme=THEME, css=LIGHT_CSS) as demo:
# State holders passed in and out of every callback: one dict per policy side
# (initialised from fresh_side_state()) plus a run dict tracking whether an
# episode is running, its scenario, its seed and the shared step counter.
left_state = gr.State(fresh_side_state())
right_state = gr.State(fresh_side_state())
run_state = gr.State({"running": False, "scenario": "easy", "seed": 42, "step": 0})
# Page heading and the one-line usage hint.
gr.Markdown(
"# Budget Router — Policy Comparison\n"
"_Select 2 policies · start episode · step or finish episode · compare outcomes_"
)
# Two policy panels (A = left_*, B = right_*) built from the same widget set.
# All KPI cells start with an em-dash placeholder; the callbacks replace the
# HTML once an episode is started.
with gr.Row():
# ── Policy A panel ──
with gr.Column(scale=1):
# Heading text is rewritten by _update_start_enabled when a policy is chosen.
left_title = gr.Markdown("## Policy A")
left_policy = gr.Dropdown(choices=POLICY_CHOICES, value=None, label="Select policy")
left_status = gr.Textbox(label="Status", interactive=False, lines=2)
left_providers = gr.HTML(_PROVIDER_EMPTY())
left_budget = gr.HTML("")
# Per-step KPI placeholders.
left_kpis = gr.HTML(
_kpi_grid(
[
("Step", "—"),
("Last action", "—"),
("Latency (ms)", "—"),
("Budget remaining", "—"),
("Reward", "—"),
("Adaptation", "—"),
]
)
)
left_badges = gr.HTML("")
# Episode-level summary placeholders.
left_summary = gr.HTML(
_kpi_grid(
[
("Failed %", "—"),
("SLA breach %", "—"),
("Avg latency (ms)", "—"),
]
)
)
# ── Policy B panel (same widgets as Policy A) ──
with gr.Column(scale=1):
right_title = gr.Markdown("## Policy B")
right_policy = gr.Dropdown(choices=POLICY_CHOICES, value=None, label="Select policy")
right_status = gr.Textbox(label="Status", interactive=False, lines=2)
right_providers = gr.HTML(_PROVIDER_EMPTY())
right_budget = gr.HTML("")
# Per-step KPI placeholders.
right_kpis = gr.HTML(
_kpi_grid(
[
("Step", "—"),
("Last action", "—"),
("Latency (ms)", "—"),
("Budget remaining", "—"),
("Reward", "—"),
("Adaptation", "—"),
]
)
)
right_badges = gr.HTML("")
# Episode-level summary placeholders.
right_summary = gr.HTML(
_kpi_grid(
[
("Failed %", "—"),
("SLA breach %", "—"),
("Avg latency (ms)", "—"),
]
)
)
# Episode controls plus the comparison plot.
with gr.Row():
with gr.Column(scale=2):
gr.Markdown("### Episode Controls")
scenario_sel = gr.Radio(SCENARIOS, value="easy", label="Scenario")
# The raw widget value is normalised to an int by _normalize_seed in do_start.
seed_inp = gr.Number(value=42, label="Seed", precision=0)
# Starts disabled; _update_start_enabled enables it once both policies are
# selected and no episode is running.
start_btn = gr.Button("▶ Start Episode", variant="primary", interactive=False)
# Stepping buttons start disabled; _render_all sets their interactivity to
# the run's "running" flag on every render.
with gr.Row():
step_btn = gr.Button("→ Step", variant="secondary", interactive=False)
fast_btn = gr.Button("⚡ Fast-forward", interactive=False)
finish_btn = gr.Button("⏩ Finish Episode", interactive=False)
gr.Markdown(f"### {MISSION_SCORE_LABEL} (comparison)\n_{MISSION_SCORE_HELP}_")
# Filled by _render_all with the result of render_grader_plot.
grader_plot = gr.Plot()
# Per-policy step history tables, rendered empty until steps are recorded.
with gr.Row(elem_classes=["episode-history-row"]):
with gr.Column(scale=1):
left_history_title = gr.Markdown("### Step History — Policy A")
left_history_tbl = gr.HTML(render_history_table_compare([]), elem_classes=["episode-history-table"])
with gr.Column(scale=1):
right_history_title = gr.Markdown("### Step History — Policy B")
right_history_tbl = gr.HTML(render_history_table_compare([]), elem_classes=["episode-history-table"])
# Per-policy grade panels, showing the pending placeholder at build time.
with gr.Row():
with gr.Column(scale=1):
left_grade_title = gr.Markdown(f"### {MISSION_SCORE_LABEL} — Policy A")
left_grade = gr.HTML(_GRADER_PENDING())
with gr.Column(scale=1):
right_grade_title = gr.Markdown(f"### {MISSION_SCORE_LABEL} — Policy B")
right_grade = gr.HTML(_GRADER_PENDING())
# Incident timeline for the selected scenario; "easy" matches scenario_sel's
# initial value.
gr.Markdown("### Incident Timeline")
incidents_html = gr.HTML(render_incident_timeline("easy"))
def _render_side(side: Dict, run: Dict, scenario_name: str) -> Tuple[str, str, str, str, str, str, str, str]:
    """Render one policy panel.

    Delegates to ``render_side_panel`` and returns its result unchanged. The
    annotation declares eight strings; ``_render_all`` reads indices 0-7.
    """
    panel_parts = render_side_panel(side, run, scenario_name)
    return panel_parts
def _render_all(ls: Dict, rs: Dict, run: Dict) -> tuple:
    """Build the full 29-slot callback result, ordered to match ``OUTPUTS``.

    Slots: the three state dicts, six rendered pieces per side (indices 0-5 of
    each ``_render_side`` result), then indices 6 and 7 of each side
    interleaved left/right, the comparison plot, the incident timeline, five
    updates for the configuration widgets, and three for the stepping buttons.
    """
    scenario_name = str(run.get("scenario", "easy") or "easy")
    left_parts = _render_side(ls, run, scenario_name)
    right_parts = _render_side(rs, run, scenario_name)

    comparison_plot = render_grader_plot(
        ls.get("history", []) or [],
        rs.get("history", []) or [],
        left_name=str(ls.get("policy_name") or ""),
        right_name=str(rs.get("policy_name") or ""),
    )
    timeline_html = render_incident_timeline(scenario_name)

    is_running = bool(run.get("running", False))
    # Stepping buttons are usable only during a run; the configuration
    # widgets (policies, scenario, seed, start) only outside of one.
    stepping_update = gr.update(interactive=is_running)
    config_update = gr.update(interactive=(not is_running))

    return (
        ls,
        rs,
        run,
        *left_parts[:6],
        *right_parts[:6],
        left_parts[6],
        right_parts[6],
        left_parts[7],
        right_parts[7],
        comparison_plot,
        timeline_html,
        *([config_update] * 5),
        *([stepping_update] * 3),
    )
# Output components shared by every button callback. The order is positional:
# slot i here receives element i of the tuple returned by _render_all, so the
# two must be changed together.
OUTPUTS = [
# 0-2: state holders
left_state,
right_state,
run_state,
# 3-8: left panel (indices 0-5 of the left _render_side result)
left_status,
left_providers,
left_budget,
left_kpis,
left_badges,
left_summary,
# 9-14: right panel (indices 0-5 of the right _render_side result)
right_status,
right_providers,
right_budget,
right_kpis,
right_badges,
right_summary,
# 15-16: step history tables (index 6 of each side)
left_history_tbl,
right_history_tbl,
# 17-18: grade panels (index 7 of each side)
left_grade,
right_grade,
# 19-20: shared comparison plot and incident timeline
grader_plot,
incidents_html,
# 21-25: configuration widgets, interactive only while no episode is running
left_policy,
right_policy,
scenario_sel,
seed_inp,
start_btn,
# 26-28: stepping buttons, interactive only while an episode is running
step_btn,
fast_btn,
finish_btn,
]
# Position of the plot slot; _stream_to_end uses it to overwrite that slot in
# intermediate frames.
GRADER_PLOT_IDX = OUTPUTS.index(grader_plot)
def _update_start_enabled(p1: Optional[str], p2: Optional[str], run: Dict):
    """React to a policy dropdown change.

    Args:
        p1: Policy selected on the left, or a falsy value if none.
        p2: Policy selected on the right, or a falsy value if none.
        run: Run-state dict; only its "running" flag is read. ``None`` is
            treated as "not running".

    Returns:
        A 7-tuple: an update enabling the start button only when both
        policies are selected and no episode is running, followed by the
        left/right panel headings, step-history headings and grade headings.
        Unselected sides fall back to "Policy A" / "Policy B".
    """
    left_name = str(p1 or "Policy A")
    right_name = str(p2 or "Policy B")
    running = bool((run or {}).get("running", False))
    can_start = bool(p1) and bool(p2) and not running
    return (
        gr.update(interactive=can_start),
        f"## {left_name}",
        f"## {right_name}",
        f"### Step History — {left_name}",
        f"### Step History — {right_name}",
        # Keep the " — " separator used by the initial grade headings; without
        # it the label and the policy name run together.
        f"### {MISSION_SCORE_LABEL} — {left_name}",
        f"### {MISSION_SCORE_LABEL} — {right_name}",
    )
# Either dropdown changing re-evaluates the start button and refreshes all six
# headings, so both register the same handler with the same inputs/outputs.
for _policy_dropdown in (left_policy, right_policy):
    _policy_dropdown.change(
        _update_start_enabled,
        inputs=[left_policy, right_policy, run_state],
        outputs=[start_btn, left_title, right_title, left_history_title, right_history_title, left_grade_title, right_grade_title],
    )
# Switching scenario re-renders the incident timeline straight away.
scenario_sel.change(
    lambda s: render_incident_timeline(s),
    inputs=[scenario_sel],
    outputs=[incidents_html],
)
def do_start(p1: str, p2: str, scenario: str, seed: Optional[float], _ls: Dict, _rs: Dict, _run: Dict):
    """Start a new paired episode and return the first full render.

    Args:
        p1: Name of the policy for the left side.
        p2: Name of the policy for the right side.
        scenario: Scenario name passed to each environment's ``reset`` and to
            each policy runner's ``reset``.
        seed: Raw value of the seed widget; normalised with
            ``_normalize_seed`` (fallback 42). Both environments get the same
            seed.
        _ls, _rs, _run: Previous state values. Ignored: every start builds
            fresh side states and a fresh run dict.

    Returns:
        The tuple built by ``_render_all``. If a policy is missing or fails to
        load, the run is left not running and the statuses explain why.
    """
    # Function-scope import keeps this block self-contained; the module does
    # not otherwise import logging.
    import logging

    def _load_failure_status(err, runner) -> str:
        # Status line for a side when policy loading failed on either side.
        if err:
            return f"❌ {err}"
        if runner is None:
            # No runner and no error text: still tell the user something.
            return "❌ Policy unavailable"
        return ""

    def _reset_runner(runner, policy_name: str) -> None:
        # Best-effort reset: a failure must not block the episode, but it is
        # logged instead of being silently discarded.
        try:
            reset = getattr(runner, "reset", None)
            if callable(reset):
                reset(scenario)
        except Exception:
            logging.getLogger(__name__).warning(
                "Policy %r failed to reset for scenario %r; continuing without reset",
                policy_name,
                scenario,
                exc_info=True,
            )

    def _running_side(env, policy_name: str, runner, first_obs) -> Dict:
        # Fields that turn a fresh side state into a running one.
        return {
            "env": env,
            "policy_name": policy_name,
            "policy_runner": runner,
            "obs": _observation_to_dict(first_obs),
            "status": f"✅ Running · {policy_name}",
        }

    ls = fresh_side_state()
    rs = fresh_side_state()
    seed_int = _normalize_seed(seed, default=42)

    if not p1 or not p2:
        run = {"running": False, "scenario": scenario, "seed": seed_int, "step": 0}
        ls["status"] = "Select both policies to start."
        rs["status"] = "Select both policies to start."
        return _render_all(ls, rs, run)

    runner_l, err_l = get_policy_runner(p1)
    runner_r, err_r = get_policy_runner(p2)
    if err_l or err_r or runner_l is None or runner_r is None:
        ls["status"] = _load_failure_status(err_l, runner_l)
        rs["status"] = _load_failure_status(err_r, runner_r)
        run = {"running": False, "scenario": scenario, "seed": seed_int, "step": 0}
        return _render_all(ls, rs, run)

    # One environment per side, reset with the same seed and scenario.
    env_l = BudgetRouterEnv()
    env_r = BudgetRouterEnv()
    obs_l = env_l.reset(seed=seed_int, scenario=scenario)
    obs_r = env_r.reset(seed=seed_int, scenario=scenario)

    _reset_runner(runner_l, p1)
    _reset_runner(runner_r, p2)

    ls.update(_running_side(env_l, p1, runner_l, obs_l))
    rs.update(_running_side(env_r, p2, runner_r, obs_r))
    run = {"running": True, "scenario": scenario, "seed": seed_int, "step": 0}
    return _render_all(ls, rs, run)
def _apply_local_step(side: Dict, scenario_name: str, global_step: int) -> Dict:
if side.get("done"):
return side
env = side.get("env")
runner = side.get("policy_runner")
if env is None or runner is None:
side["done"] = True
side["status"] = "❌ Not initialized"
return side
try:
action_str = runner.choose_action(side.get("obs", {}) or {})
except Exception as exc:
side["done"] = True
side["status"] = f"❌ Policy error: {exc}"
return side
pre_obs = dict(side.get("obs", {}) or {})
obs_obj = env.step(Action(action_type=ActionType(action_str)))
obs = _observation_to_dict(obs_obj)
reward = float(obs.get("reward", 0.0) or 0.0)
meta = dict(obs.get("metadata", {}) or {})
done = bool(obs.get("done", False))
side["history"].append(record_step(global_step, action_str, obs, reward, meta, health_obs=pre_obs))
side["obs"] = obs
side["cumulative_reward"] = float(side.get("cumulative_reward", 0.0) or 0.0) + reward
side["done"] = done
side["status"] = "✅ Done" if done else str(side.get("status", ""))
return side
def do_step(ls: Dict, rs: Dict, run: Dict):
    """Advance both sides by one step and return the full render.

    Does nothing but re-render when no episode is running. When the step
    counter has already reached ``MAX_STEPS`` the run is stopped without
    stepping. Otherwise both sides take one step under the same step number,
    and the run stops once the cap is reached or both sides report done.
    ``run`` is mutated in place.
    """
    if not bool(run.get("running", False)):
        return _render_all(ls, rs, run)

    steps_taken = int(run.get("step", 0) or 0)
    if steps_taken >= MAX_STEPS:
        run["running"] = False
        return _render_all(ls, rs, run)

    next_step = steps_taken + 1
    scenario = str(run.get("scenario", "easy") or "easy")
    ls = _apply_local_step(ls, scenario, next_step)
    rs = _apply_local_step(rs, scenario, next_step)
    run["step"] = next_step

    hit_step_cap = next_step >= MAX_STEPS
    if hit_step_cap or (ls.get("done") and rs.get("done")):
        run["running"] = False
    return _render_all(ls, rs, run)
def _stream_to_end(ls: Dict, rs: Dict, run: Dict):
    """Generator: step the episode to completion, yielding a render per step.

    If no episode is running, yields a single render and stops. Otherwise
    every intermediate frame carries the plot rendered before the first step
    (presumably to avoid redrawing the plot on each frame); only the last
    yielded frame contains a freshly rendered plot. Pauses 0.12 s after each
    intermediate frame.
    """
    if not bool(run.get("running", False)):
        yield _render_all(ls, rs, run)
        return

    # Plot as it looks before any step of this stream.
    plot_snapshot = _render_all(ls, rs, run)[GRADER_PLOT_IDX]

    # The loop condition re-checks "running" after each step, so no separate
    # break is needed when do_step stops the run.
    while bool(run.get("running", False)) and int(run.get("step", 0) or 0) < MAX_STEPS:
        frame = list(do_step(ls, rs, run))
        ls, rs, run = frame[0], frame[1], frame[2]
        frame[GRADER_PLOT_IDX] = plot_snapshot
        yield tuple(frame)
        time.sleep(0.12)

    # Final frame with the up-to-date plot.
    yield _render_all(ls, rs, run)
def do_fast_forward(ls: Dict, rs: Dict, run: Dict):
    """Fast-forward button handler: relay every frame from ``_stream_to_end``."""
    for frame in _stream_to_end(ls, rs, run):
        yield frame
def do_finish(ls: Dict, rs: Dict, run: Dict):
    """Finish-episode button handler: relay every frame from ``_stream_to_end``."""
    for frame in _stream_to_end(ls, rs, run):
        yield frame
start_btn.click(do_start, inputs=[left_policy, right_policy, scenario_sel, seed_inp, left_state, right_state, run_state], outputs=OUTPUTS)
step_btn.click(do_step, inputs=[left_state, right_state, run_state], outputs=OUTPUTS)
fast_btn.click(do_fast_forward, inputs=[left_state, right_state, run_state], outputs=OUTPUTS)
finish_btn.click(do_finish, inputs=[left_state, right_state, run_state], outputs=OUTPUTS)
return demo
if __name__ == "__main__":
    # Script entry point: build the UI, enable the queue, then serve on port
    # 7860 (the address given in the module docstring). queue() returns the
    # Blocks instance, so the calls can be chained.
    app = build_app().queue()
    app.launch(server_port=7860)