from __future__ import annotations import json from typing import Any import gradio as gr import pandas as pd from fastapi import FastAPI from llmserve_env.models import QuantizationTier, ServeAction, ServeObservation from llmserve_env.task_catalog import get_task_catalog from server.llmserve_environment import LLMServeEnvironment from server.session_manager import SessionManager def create_web_app(app: FastAPI, session_manager: SessionManager, fallback_env: LLMServeEnvironment) -> FastAPI: blocks = build_web_ui(session_manager, fallback_env) return gr.mount_gradio_app(app, blocks, path="/web") def build_web_ui(session_manager: SessionManager, fallback_env: LLMServeEnvironment) -> gr.Blocks: task_ids = [task["id"] for task in get_task_catalog()] def _empty_state_json() -> str: return json.dumps( { "episode_id": "", "step_count": 0, "task_id": "uninitialized", "total_requests_served": 0, "total_slo_violations": 0, "cumulative_reward": 0.0, "elapsed_simulated_time_s": 0.0, "workload_phase": "warmup", "done": False, }, indent=2, ) def _history_frame(env: LLMServeEnvironment | None = None) -> pd.DataFrame: active_env = env or fallback_env rows = [ { "step_index": observation.step_index, "reward": observation.reward, "p99_ttft_ms": observation.p99_ttft_ms, "slo_compliance_rate": observation.slo_compliance_rate, "throughput_tps": observation.throughput_tps, } for observation in active_env.observations ] if not rows: rows = [ { "step_index": 0, "reward": 0.0, "p99_ttft_ms": 0.0, "slo_compliance_rate": 1.0, "throughput_tps": 0.0, } ] return pd.DataFrame(rows) def _session_json(env: LLMServeEnvironment | None = None) -> str: active_env = env or fallback_env backend = active_env.backend.describe() payload = { "active_task_id": active_env.state.task_id, "episode_id": active_env.state.episode_id, "step_count": active_env.state.step_count, "mode": backend.get("mode", active_env.backend.mode), "backend": backend, "done": active_env.state.done, } return json.dumps(payload, indent=2) def _response_json(observation: ServeObservation) -> str: payload = { "observation": observation.model_dump(mode="json"), "reward": observation.reward, "done": observation.done, "metadata": observation.metadata, } return json.dumps(payload, indent=2) def _state_json(env: LLMServeEnvironment | None = None) -> str: if env is None: return _empty_state_json() return json.dumps(env.state.model_dump(mode="json"), indent=2) def _get_env(session_id: str | None) -> LLMServeEnvironment | None: if not session_id: return None try: return session_manager.get(session_id) except KeyError: return None def _ui_payload( observation: ServeObservation, status_message: str, session_id: str, env: LLMServeEnvironment, ) -> tuple[str, str, str, str, pd.DataFrame, str]: return ( status_message, _session_json(env), _response_json(observation), _state_json(env), _history_frame(env), session_id, ) def reset_env(current_session_id: str | None, task_id: str, seed: int) -> tuple[str, str, str, str, pd.DataFrame, str]: try: if current_session_id: session_manager.remove(current_session_id) session_id, env = session_manager.create(task_id=task_id, seed=int(seed)) observation = env.observations[-1] return _ui_payload( observation, f"Environment reset for task `{task_id}`. Active episode now uses `{env.state.task_id}`.", session_id, env, ) except Exception as exc: return (f"Error: {exc}", _session_json(), "", _state_json(), _history_frame(), current_session_id or "") def step_env( session_id: str | None, batch_cap: int, kv_budget_fraction: float, speculation_depth: int, quantization_tier: str, prefill_decode_split: bool, priority_routing: bool, ) -> tuple[str, str, str, str, pd.DataFrame, str]: try: env = _get_env(session_id) if env is None: raise RuntimeError("No active session found. Click Reset before stepping.") action = ServeAction( batch_cap=int(batch_cap), kv_budget_fraction=float(kv_budget_fraction), speculation_depth=int(speculation_depth), quantization_tier=quantization_tier, prefill_decode_split=bool(prefill_decode_split), priority_routing=bool(priority_routing), ) observation = env.step(action) return _ui_payload( observation, f"Step complete for active task `{env.state.task_id}` in `{env.backend.mode}` mode.", session_id or env.state.episode_id, env, ) except Exception as exc: active_env = _get_env(session_id) return ( f"Error: {exc}", _session_json(active_env), "", _state_json(active_env), _history_frame(active_env), session_id or "", ) def get_state(session_id: str | None) -> tuple[str, pd.DataFrame, str]: try: env = _get_env(session_id) if env is None: raise RuntimeError("No active session found. Click Reset to start an episode.") return _state_json(env), _history_frame(env), session_id or "" except Exception as exc: return f"Error: {exc}", _history_frame(), session_id or "" with gr.Blocks(title="LLMServeEnv") as demo: gr.Markdown( """ # LLMServeEnv Reset an episode, then control the serving policy with bounded inputs only. The web UI now keeps a dedicated backend session per browser tab so repeated Step clicks continue the same episode reliably in Docker. Numeric controls use sliders, categorical controls use fixed choices. """ ) session_id_state = gr.State(value="") with gr.Row(): with gr.Column(scale=1): task_id = gr.Dropdown( choices=task_ids, value=task_ids[0], allow_custom_value=False, label="Task", ) seed = gr.Slider(0, 1000, value=42, step=1, label="Seed") reset_btn = gr.Button("Reset", variant="secondary") gr.Markdown("## Action Controls") batch_cap = gr.Slider(1, 512, value=32, step=1, label="Batch Cap") kv_budget_fraction = gr.Slider(0.1, 1.0, value=1.0, step=0.05, label="KV Budget Fraction") speculation_depth = gr.Slider(0, 8, value=0, step=1, label="Speculation Depth") quantization_tier = gr.Radio( choices=[tier.value for tier in QuantizationTier], value=QuantizationTier.FP16.value, label="Quantization Tier", ) prefill_decode_split = gr.Checkbox(value=False, label="Prefill Decode Split") priority_routing = gr.Checkbox(value=False, label="Priority Routing") with gr.Row(): step_btn = gr.Button("Step", variant="primary") state_btn = gr.Button("Get state", variant="secondary") status = gr.Textbox(label="Status", interactive=False) session_json = gr.Code( label="Active Session", language="json", value=_session_json(), interactive=False, ) with gr.Column(scale=2): response_json = gr.Code(label="Observation / Step Response", language="json", interactive=False) state_json = gr.Code(label="Current State", language="json", value=_empty_state_json(), interactive=False) history_table = gr.Dataframe( value=_history_frame(), headers=["step_index", "reward", "p99_ttft_ms", "slo_compliance_rate", "throughput_tps"], label="Episode Metrics History", interactive=False, ) reset_btn.click( fn=reset_env, inputs=[session_id_state, task_id, seed], outputs=[status, session_json, response_json, state_json, history_table, session_id_state], ) task_id.change( fn=reset_env, inputs=[session_id_state, task_id, seed], outputs=[status, session_json, response_json, state_json, history_table, session_id_state], ) step_btn.click( fn=step_env, inputs=[ session_id_state, batch_cap, kv_budget_fraction, speculation_depth, quantization_tier, prefill_decode_split, priority_routing, ], outputs=[status, session_json, response_json, state_json, history_table, session_id_state], ) state_btn.click( fn=get_state, inputs=[session_id_state], outputs=[state_json, history_table, session_id_state], ) return demo