ronitraj committed
Commit 8ec915d · 1 Parent(s): f56abd2

Fix submission runner and sessionized API/UI flow
inference.py CHANGED
@@ -138,12 +138,6 @@ def _log_end(success: bool, steps: int, score: float, rewards: list[float]) -> None:
 
 
 def _run_task(task_id: str, client: OpenAI | None) -> bool:
-    env = LLMServeEnvironment(seed=DEFAULT_SEED, mode="sim")
-    grader = GraderEngine()
-    fallback_agent = _create_fallback_agent(task_id)
-    if hasattr(fallback_agent, "reset"):
-        fallback_agent.reset()
-
     model_label = MODEL_NAME if client is not None else "heuristic"
     _log_start(task=task_id, env_name=ENV_NAME, model=model_label)
 
@@ -151,10 +145,18 @@ def _run_task(task_id: str, client: OpenAI | None) -> bool:
     steps_taken = 0
     score = 0.0
     success = False
-    observation = None
     previous_action: dict[str, Any] | None = None
+    env: LLMServeEnvironment | None = None
+    grader: GraderEngine | None = None
+    fallback_agent: Any = None
 
     try:
+        env = LLMServeEnvironment(seed=DEFAULT_SEED, mode="sim")
+        grader = GraderEngine()
+        fallback_agent = _create_fallback_agent(task_id)
+        if hasattr(fallback_agent, "reset"):
+            fallback_agent.reset()
+
         observation = env.reset(seed=DEFAULT_SEED, task_id=task_id)
         task_cfg = env.task_config or {}
         configured_max_steps = int(task_cfg.get("max_steps", MAX_STEPS))
@@ -187,7 +189,7 @@ def _run_task(task_id: str, client: OpenAI | None) -> bool:
                 _log_step(step=step_idx, action=action_json, reward=0.0, done=True, error=_sanitize_error(exc))
                 break
 
-        grade = grader.grade(env.export_episode_log())
+        grade = grader.grade(env.export_episode_log()) if grader is not None else {"score": 0.0}
         score = float(grade.get("score", 0.0))
         score = max(0.0, min(1.0, score))
         success = score > 0.0
@@ -204,12 +206,22 @@ def _run_task(task_id: str, client: OpenAI | None) -> bool:
 
 
 def main() -> int:
-    client = _create_client()
-    all_success = True
+    try:
+        client = _create_client()
+    except Exception:
+        client = None
+
     for task_id in TASKS:
-        ok = _run_task(task_id=task_id, client=client)
-        all_success = all_success and ok
-    return 0 if all_success else 1
+        try:
+            _run_task(task_id=task_id, client=client)
+        except Exception as exc:
+            _log_start(task=task_id, env_name=ENV_NAME, model=MODEL_NAME if client is not None else "heuristic")
+            _log_step(step=1, action="{}", reward=0.0, done=True, error=_sanitize_error(exc))
+            _log_end(success=False, steps=1, score=0.0, rewards=[0.0])
+
+    # The validator treats non-zero exits as infrastructure failures, so we always
+    # return 0 after emitting structured episode logs for every task.
+    return 0
 
 
 if __name__ == "__main__":
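A quick way to exercise the new exit-code contract from the outside (a sketch; it assumes you run from the repo root and that the [START]/[END] markers land on stdout, as tests/test_inference.py asserts):

import subprocess
import sys

# Even when every episode fails to bootstrap, the runner exits 0 and still
# emits one [START]/[END] pair per task so the validator can grade each episode.
result = subprocess.run([sys.executable, "inference.py"], capture_output=True, text=True)
assert result.returncode == 0
print(result.stdout.count("[START]"), "episodes logged")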
llmserve_env/client.py CHANGED
@@ -10,6 +10,7 @@ from llmserve_env.models import EpisodeLog, ServeAction, ServeObservation, ServeState
 class LLMServeEnv:
     def __init__(self, base_url: str) -> None:
         self.base_url = base_url.rstrip("/")
+        self.session_id: str | None = None
 
     @classmethod
     def from_url(cls, base_url: str) -> "LLMServeEnv":
@@ -21,16 +22,21 @@ class LLMServeEnv:
 
     def reset(self, task_id: str, seed: int | None = None) -> ServeObservation:
         payload = self._post("/reset", {"task_id": task_id, "seed": seed})
+        self.session_id = payload.get("session_id")
         return self._parse_observation_payload(payload)
 
     def step(self, action: dict[str, Any] | ServeAction) -> tuple[ServeObservation, float, bool, dict[str, Any]]:
+        if self.session_id is None:
+            raise RuntimeError("reset() must be called before step() so the client has a session_id.")
         action_payload = action.model_dump(mode="json") if isinstance(action, ServeAction) else action
-        payload = self._post("/step", {"action": action_payload})
+        payload = self._post("/step", {"action": action_payload, "session_id": self.session_id})
         observation = self._parse_observation_payload(payload)
         return observation, float(payload["reward"]), bool(payload["done"]), observation.metadata
 
     def state(self) -> ServeState:
-        payload = self._get("/state")
+        if self.session_id is None:
+            raise RuntimeError("reset() must be called before state() so the client has a session_id.")
+        payload = self._get(f"/state?session_id={self.session_id}")
         return ServeState.model_validate(payload)
 
     def tasks(self) -> dict[str, Any]:
@@ -67,4 +73,3 @@ class LLMServeEnv:
         req = request.Request(f"{self.base_url}{path}", data=body, headers=headers, method="POST")
         with request.urlopen(req) as response:
             return json.loads(response.read().decode("utf-8"))
-
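Typical use of the sessionized client, sketched under the assumption that a server built from this commit is listening on localhost:8000 and that "static_workload" is a valid catalog task:

from llmserve_env.client import LLMServeEnv

client = LLMServeEnv.from_url("http://localhost:8000")
observation = client.reset(task_id="static_workload", seed=42)  # caches payload["session_id"]
obs, reward, done, metadata = client.step(
    {
        "batch_cap": 32,
        "kv_budget_fraction": 1.0,
        "speculation_depth": 0,
        "quantization_tier": "FP16",
        "prefill_decode_split": False,
        "priority_routing": False,
    }
)
state = client.state()  # issues GET /state?session_id=... for the same episode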
 
server/app.py CHANGED
@@ -3,7 +3,7 @@ from __future__ import annotations
 import os
 from pathlib import Path
 
-from fastapi import FastAPI, HTTPException
+from fastapi import FastAPI, HTTPException, Query
 from fastapi.responses import RedirectResponse
 from openenv.core import create_fastapi_app
 from dotenv import load_dotenv
@@ -13,7 +13,8 @@ from llmserve_env.task_catalog import get_action_schema, get_task_catalog
 from server.baseline_inference import create_local_runner, run_baseline_suite
 from server.grader import GraderEngine
 from server.llmserve_environment import LLMServeEnvironment
-from server.schemas import GraderRequest
+from server.schemas import GraderRequest, ResetRequest, StepRequest
+from server.session_manager import SessionManager
 from server.web_ui import create_web_app
 
 
@@ -29,6 +30,7 @@ def _build_shared_env() -> LLMServeEnvironment:
 
 shared_env = _build_shared_env()
 grader = GraderEngine()
+session_manager = SessionManager()
 
 
 def get_env() -> LLMServeEnvironment:
@@ -36,6 +38,14 @@ def get_env() -> LLMServeEnvironment:
 
 
 def _register_extra_routes(app: FastAPI) -> FastAPI:
+    def _resolve_env(session_id: str | None) -> LLMServeEnvironment:
+        if not session_id:
+            return shared_env
+        try:
+            return session_manager.get(session_id)
+        except KeyError as exc:
+            raise HTTPException(status_code=404, detail=str(exc)) from exc
+
     @app.get("/")
     def root() -> RedirectResponse:
         return RedirectResponse(url="/web", status_code=307)
@@ -50,8 +60,43 @@ def _register_extra_routes(app: FastAPI) -> FastAPI:
             "mode": shared_env.backend.mode,
             "backend": shared_env.backend.describe(),
             "seed": shared_env.seed,
+            "active_sessions": session_manager.count(),
+        }
+
+    @app.post("/reset")
+    def reset(payload: ResetRequest) -> dict[str, object]:
+        session_id, env = session_manager.create(
+            task_id=payload.task_id,
+            seed=payload.seed,
+            episode_id=payload.episode_id,
+        )
+        observation = env.observations[-1]
+
+        return {
+            "session_id": session_id,
+            "observation": observation.model_dump(mode="json"),
+            "reward": observation.reward,
+            "done": observation.done,
+            "metadata": observation.metadata,
+        }
+
+    @app.post("/step")
+    def step(payload: StepRequest) -> dict[str, object]:
+        env = _resolve_env(payload.session_id)
+        observation = env.step(payload.action)
+        return {
+            "session_id": payload.session_id or env.state.episode_id,
+            "observation": observation.model_dump(mode="json"),
+            "reward": observation.reward,
+            "done": observation.done,
+            "metadata": observation.metadata,
         }
 
+    @app.get("/state")
+    def state(session_id: str | None = Query(default=None)) -> dict[str, object]:
+        env = _resolve_env(session_id)
+        return env.state.model_dump(mode="json")
+
     @app.post("/grader")
     def grade(payload: GraderRequest | None = None) -> dict[str, object]:
         if payload and payload.episode_log is not None:
@@ -98,14 +143,13 @@ def _register_extra_routes(app: FastAPI) -> FastAPI:
 
 
 def create_application(enable_web: bool = True) -> FastAPI:
+    app = create_fastapi_app(
+        get_env,
+        ServeAction,
+        ServeObservation,
+    )
     if enable_web:
-        app = create_web_app(shared_env)
-    else:
-        app = create_fastapi_app(
-            get_env,
-            ServeAction,
-            ServeObservation,
-        )
+        app = create_web_app(app, session_manager, shared_env)
     return _register_extra_routes(app)
 
 
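The same flow over raw HTTP, using only the request and response shapes visible in this diff (host, port, and task id are assumptions):

import json
from urllib import request

BASE = "http://localhost:8000"

def post(path: str, payload: dict) -> dict:
    req = request.Request(
        f"{BASE}{path}",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with request.urlopen(req) as resp:
        return json.loads(resp.read())

reset = post("/reset", {"task_id": "static_workload", "seed": 42})
session_id = reset["session_id"]  # omitting it on /step falls back to shared_env
action = {"batch_cap": 32, "kv_budget_fraction": 1.0, "speculation_depth": 0,
          "quantization_tier": "FP16", "prefill_decode_split": False, "priority_routing": False}
step = post("/step", {"action": action, "session_id": session_id})
with request.urlopen(f"{BASE}/state?session_id={session_id}") as resp:
    state = json.loads(resp.read())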
server/replay_assets.py CHANGED
@@ -8,13 +8,44 @@ import pandas as pd
 
 ROOT_DIR = Path(__file__).resolve().parents[1]
 DATA_DIR = ROOT_DIR / "data"
+SERVER_DATA_DIR = ROOT_DIR / "server" / "data"
 
 
-def resolve_data_path(relative_path: str) -> Path:
+def _candidate_paths(relative_path: str) -> list[Path]:
     path = Path(relative_path)
     if path.is_absolute():
-        return path
-    return DATA_DIR / path
+        return [path]
+
+    candidates = [
+        DATA_DIR / path,
+        SERVER_DATA_DIR / path,
+    ]
+
+    if path.name == "latency_table.parquet":
+        serving_profile = path.with_name("serving_profile_table.parquet")
+        candidates.extend(
+            [
+                DATA_DIR / serving_profile,
+                SERVER_DATA_DIR / serving_profile,
+            ]
+        )
+
+    seen: set[Path] = set()
+    deduped: list[Path] = []
+    for candidate in candidates:
+        resolved = candidate.resolve()
+        if resolved not in seen:
+            seen.add(resolved)
+            deduped.append(candidate)
+    return deduped
+
+
+def resolve_data_path(relative_path: str) -> Path:
+    for candidate in _candidate_paths(relative_path):
+        if candidate.exists():
+            return candidate
+    searched = ", ".join(str(candidate) for candidate in _candidate_paths(relative_path))
+    raise FileNotFoundError(f"Could not locate required data asset '{relative_path}'. Searched: {searched}")
 
 
 @lru_cache(maxsize=None)
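In effect, a relative path is now probed against data/ and server/data/, and latency_table.parquet additionally falls back to the renamed serving_profile_table.parquet before a FileNotFoundError lists every location searched. A minimal sketch:

from server.replay_assets import resolve_data_path

# Returns the first existing candidate in preference order:
# data/<path>, server/data/<path>, then the serving-profile fallbacks.
path = resolve_data_path("lookup_tables/latency_table.parquet")
print(path)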
server/schemas.py CHANGED
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 from pydantic import BaseModel, ConfigDict
 
-from llmserve_env.models import EpisodeLog
+from llmserve_env.models import EpisodeLog, ServeAction
 
 
 class GraderRequest(BaseModel):
@@ -11,3 +11,18 @@ class GraderRequest(BaseModel):
     task_id: str | None = None
     episode_log: EpisodeLog | None = None
     actions_taken: int | None = None
+
+
+class ResetRequest(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+
+    task_id: str = "static_workload"
+    seed: int | None = None
+    episode_id: str | None = None
+
+
+class StepRequest(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+
+    action: ServeAction
+    session_id: str | None = None
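For reference, the request bodies these models accept (extra="forbid" rejects unknown keys; the action fields mirror ServeAction as exercised in tests/test_api.py). The session id value below is a placeholder:

POST /reset
{"task_id": "bursty_workload", "seed": 42, "episode_id": null}

POST /step
{
  "session_id": "<session_id from /reset>",
  "action": {
    "batch_cap": 32,
    "kv_budget_fraction": 1.0,
    "speculation_depth": 0,
    "quantization_tier": "FP16",
    "prefill_decode_split": false,
    "priority_routing": false
  }
}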
server/session_manager.py CHANGED
@@ -16,9 +16,14 @@ class SessionManager:
         self._sessions: OrderedDict[str, LLMServeEnvironment] = OrderedDict()
         self._max_sessions = max_sessions
 
-    def create(self, task_id: str, seed: int | None = None) -> tuple[str, LLMServeEnvironment]:
+    def create(
+        self,
+        task_id: str,
+        seed: int | None = None,
+        episode_id: str | None = None,
+    ) -> tuple[str, LLMServeEnvironment]:
         env = LLMServeEnvironment(seed=seed or 42)
-        env.reset(task_id=task_id, seed=seed)
+        env.reset(task_id=task_id, seed=seed, episode_id=episode_id)
         session_id = env.state.episode_id
 
         with self._lock:
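The lifecycle the routes rely on, sketched; get() raising KeyError for unknown ids is visible in the handlers above, while the OrderedDict plus _max_sessions suggests oldest-first eviction, which is an assumption since that code sits outside this hunk:

from server.session_manager import SessionManager

manager = SessionManager()
session_id, env = manager.create(task_id="static_workload", seed=7)
assert manager.get(session_id) is env  # raises KeyError for unknown/evicted ids
manager.remove(session_id)             # the web UI calls this before re-creating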
server/web_ui.py CHANGED
@@ -6,47 +6,39 @@ from typing import Any
 import gradio as gr
 import pandas as pd
 from fastapi import FastAPI
-from openenv.core import create_fastapi_app
 
 from llmserve_env.models import QuantizationTier, ServeAction, ServeObservation
 from llmserve_env.task_catalog import get_task_catalog
 from server.llmserve_environment import LLMServeEnvironment
+from server.session_manager import SessionManager
 
 
-def create_web_app(env: LLMServeEnvironment) -> FastAPI:
-    app = create_fastapi_app(lambda: env, ServeAction, ServeObservation)
-    blocks = build_web_ui(env)
+def create_web_app(app: FastAPI, session_manager: SessionManager, fallback_env: LLMServeEnvironment) -> FastAPI:
+    blocks = build_web_ui(session_manager, fallback_env)
     return gr.mount_gradio_app(app, blocks, path="/web")
 
 
-def build_web_ui(env: LLMServeEnvironment) -> gr.Blocks:
+def build_web_ui(session_manager: SessionManager, fallback_env: LLMServeEnvironment) -> gr.Blocks:
     task_ids = [task["id"] for task in get_task_catalog()]
 
-    def _state_json() -> str:
-        return json.dumps(env.state.model_dump(mode="json"), indent=2)
-
-    def _session_json() -> str:
-        backend = env.backend.describe()
-        payload = {
-            "active_task_id": env.state.task_id,
-            "episode_id": env.state.episode_id,
-            "step_count": env.state.step_count,
-            "mode": backend.get("mode", env.backend.mode),
-            "backend": backend,
-            "done": env.state.done,
-        }
-        return json.dumps(payload, indent=2)
-
-    def _response_json(observation: ServeObservation) -> str:
-        payload = {
-            "observation": observation.model_dump(mode="json"),
-            "reward": observation.reward,
-            "done": observation.done,
-            "metadata": observation.metadata,
-        }
-        return json.dumps(payload, indent=2)
+    def _empty_state_json() -> str:
+        return json.dumps(
+            {
+                "episode_id": "",
+                "step_count": 0,
+                "task_id": "uninitialized",
+                "total_requests_served": 0,
+                "total_slo_violations": 0,
+                "cumulative_reward": 0.0,
+                "elapsed_simulated_time_s": 0.0,
+                "workload_phase": "warmup",
+                "done": False,
+            },
+            indent=2,
+        )
 
-    def _history_frame() -> pd.DataFrame:
+    def _history_frame(env: LLMServeEnvironment | None = None) -> pd.DataFrame:
+        active_env = env or fallback_env
         rows = [
             {
                 "step_index": observation.step_index,
@@ -55,7 +47,7 @@ def build_web_ui(env: LLMServeEnvironment) -> gr.Blocks:
                 "slo_compliance_rate": observation.slo_compliance_rate,
                 "throughput_tps": observation.throughput_tps,
             }
-            for observation in env.observations
+            for observation in active_env.observations
         ]
         if not rows:
             rows = [
@@ -69,34 +61,84 @@ def build_web_ui(env: LLMServeEnvironment) -> gr.Blocks:
         ]
         return pd.DataFrame(rows)
 
-    def _ui_payload(observation: ServeObservation, status_message: str) -> tuple[str, str, str, str, pd.DataFrame]:
+    def _session_json(env: LLMServeEnvironment | None = None) -> str:
+        active_env = env or fallback_env
+        backend = active_env.backend.describe()
+        payload = {
+            "active_task_id": active_env.state.task_id,
+            "episode_id": active_env.state.episode_id,
+            "step_count": active_env.state.step_count,
+            "mode": backend.get("mode", active_env.backend.mode),
+            "backend": backend,
+            "done": active_env.state.done,
+        }
+        return json.dumps(payload, indent=2)
+
+    def _response_json(observation: ServeObservation) -> str:
+        payload = {
+            "observation": observation.model_dump(mode="json"),
+            "reward": observation.reward,
+            "done": observation.done,
+            "metadata": observation.metadata,
+        }
+        return json.dumps(payload, indent=2)
+
+    def _state_json(env: LLMServeEnvironment | None = None) -> str:
+        if env is None:
+            return _empty_state_json()
+        return json.dumps(env.state.model_dump(mode="json"), indent=2)
+
+    def _get_env(session_id: str | None) -> LLMServeEnvironment | None:
+        if not session_id:
+            return None
+        try:
+            return session_manager.get(session_id)
+        except KeyError:
+            return None
+
+    def _ui_payload(
+        observation: ServeObservation,
+        status_message: str,
+        session_id: str,
+        env: LLMServeEnvironment,
+    ) -> tuple[str, str, str, str, pd.DataFrame, str]:
         return (
             status_message,
-            _session_json(),
+            _session_json(env),
             _response_json(observation),
-            _state_json(),
-            _history_frame(),
+            _state_json(env),
+            _history_frame(env),
+            session_id,
         )
 
-    def reset_env(task_id: str, seed: int) -> tuple[str, str, str, str, pd.DataFrame]:
+    def reset_env(current_session_id: str | None, task_id: str, seed: int) -> tuple[str, str, str, str, pd.DataFrame, str]:
         try:
-            observation = env.reset(task_id=task_id, seed=int(seed))
+            if current_session_id:
+                session_manager.remove(current_session_id)
+            session_id, env = session_manager.create(task_id=task_id, seed=int(seed))
+            observation = env.observations[-1]
             return _ui_payload(
                 observation,
                 f"Environment reset for task `{task_id}`. Active episode now uses `{env.state.task_id}`.",
+                session_id,
+                env,
             )
         except Exception as exc:
-            return (f"Error: {exc}", _session_json(), "", _state_json(), _history_frame())
+            return (f"Error: {exc}", _session_json(), "", _state_json(), _history_frame(), current_session_id or "")
 
     def step_env(
+        session_id: str | None,
         batch_cap: int,
         kv_budget_fraction: float,
         speculation_depth: int,
         quantization_tier: str,
         prefill_decode_split: bool,
         priority_routing: bool,
-    ) -> tuple[str, str, str, str, pd.DataFrame]:
+    ) -> tuple[str, str, str, str, pd.DataFrame, str]:
         try:
+            env = _get_env(session_id)
+            if env is None:
+                raise RuntimeError("No active session found. Click Reset before stepping.")
             action = ServeAction(
                 batch_cap=int(batch_cap),
                 kv_budget_fraction=float(kv_budget_fraction),
@@ -109,15 +151,28 @@ def build_web_ui(env: LLMServeEnvironment) -> gr.Blocks:
             return _ui_payload(
                 observation,
                 f"Step complete for active task `{env.state.task_id}` in `{env.backend.mode}` mode.",
+                session_id or env.state.episode_id,
+                env,
             )
         except Exception as exc:
-            return (f"Error: {exc}", _session_json(), "", _state_json(), _history_frame())
+            active_env = _get_env(session_id)
+            return (
+                f"Error: {exc}",
+                _session_json(active_env),
+                "",
+                _state_json(active_env),
+                _history_frame(active_env),
+                session_id or "",
+            )
 
-    def get_state() -> tuple[str, pd.DataFrame]:
+    def get_state(session_id: str | None) -> tuple[str, pd.DataFrame, str]:
         try:
-            return _state_json(), _history_frame()
+            env = _get_env(session_id)
+            if env is None:
+                raise RuntimeError("No active session found. Click Reset to start an episode.")
+            return _state_json(env), _history_frame(env), session_id or ""
         except Exception as exc:
-            return f"Error: {exc}", _history_frame()
+            return f"Error: {exc}", _history_frame(), session_id or ""
 
     with gr.Blocks(title="LLMServeEnv") as demo:
         gr.Markdown(
@@ -125,10 +180,13 @@ def build_web_ui(env: LLMServeEnvironment) -> gr.Blocks:
             # LLMServeEnv
 
             Reset an episode, then control the serving policy with bounded inputs only.
+            The web UI now keeps a dedicated backend session per browser tab so repeated Step clicks continue the same episode reliably in Docker.
             Numeric controls use sliders, categorical controls use fixed choices.
             """
         )
 
+        session_id_state = gr.State(value="")
+
         with gr.Row():
             with gr.Column(scale=1):
                 task_id = gr.Dropdown(
@@ -166,7 +224,7 @@ def build_web_ui(env: LLMServeEnvironment) -> gr.Blocks:
 
             with gr.Column(scale=2):
                 response_json = gr.Code(label="Observation / Step Response", language="json", interactive=False)
-                state_json = gr.Code(label="Current State", language="json", interactive=False)
+                state_json = gr.Code(label="Current State", language="json", value=_empty_state_json(), interactive=False)
                 history_table = gr.Dataframe(
                     value=_history_frame(),
                     headers=["step_index", "reward", "p99_ttft_ms", "slo_compliance_rate", "throughput_tps"],
@@ -176,17 +234,18 @@ def build_web_ui(env: LLMServeEnvironment) -> gr.Blocks:
 
         reset_btn.click(
             fn=reset_env,
-            inputs=[task_id, seed],
-            outputs=[status, session_json, response_json, state_json, history_table],
+            inputs=[session_id_state, task_id, seed],
+            outputs=[status, session_json, response_json, state_json, history_table, session_id_state],
         )
         task_id.change(
             fn=reset_env,
-            inputs=[task_id, seed],
-            outputs=[status, session_json, response_json, state_json, history_table],
+            inputs=[session_id_state, task_id, seed],
+            outputs=[status, session_json, response_json, state_json, history_table, session_id_state],
         )
         step_btn.click(
             fn=step_env,
             inputs=[
+                session_id_state,
                 batch_cap,
                 kv_budget_fraction,
                 speculation_depth,
@@ -194,11 +253,12 @@ def build_web_ui(env: LLMServeEnvironment) -> gr.Blocks:
                 prefill_decode_split,
                 priority_routing,
             ],
-            outputs=[status, session_json, response_json, state_json, history_table],
+            outputs=[status, session_json, response_json, state_json, history_table, session_id_state],
         )
         state_btn.click(
             fn=get_state,
-            outputs=[state_json, history_table],
+            inputs=[session_id_state],
+            outputs=[state_json, history_table, session_id_state],
         )
 
     return demo
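The heart of the fix is the gr.State pattern: a per-browser-tab value threaded through every callback's inputs and outputs, so each tab pins its own backend session. A stripped-down sketch of the same pattern:

import gradio as gr

with gr.Blocks() as demo:
    session_id_state = gr.State(value="")  # per-tab value, survives across clicks
    out = gr.Textbox()
    btn = gr.Button("Step")

    def on_click(session_id: str) -> tuple[str, str]:
        session_id = session_id or "new-session"  # create on first use
        return f"stepping {session_id}", session_id

    # Returning the state as an output persists it for the next click.
    btn.click(fn=on_click, inputs=[session_id_state], outputs=[out, session_id_state])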
tests/test_api.py CHANGED
@@ -5,6 +5,7 @@ import pytest
 from fastapi import HTTPException
 
 from server.app import create_application, shared_env
+from server.schemas import ResetRequest, StepRequest
 
 
 def _route_map():
@@ -55,3 +56,32 @@ def test_baseline_endpoint_direct() -> None:
 def test_demo_redirects_to_web() -> None:
     response = _call(_route_map()["/demo"])
     assert response.headers["location"] == "/web"
+
+
+def test_http_session_advances_across_multiple_steps() -> None:
+    routes = _route_map()
+    reset_endpoint = routes["/reset"]
+    step_endpoint = routes["/step"]
+    state_endpoint = routes["/state"]
+
+    reset_payload = _call(reset_endpoint, ResetRequest(task_id="bursty_workload", seed=42))
+    session_id = reset_payload["session_id"]
+    assert session_id
+    assert reset_payload["observation"]["step_index"] == 0
+
+    action = {
+        "batch_cap": 32,
+        "kv_budget_fraction": 1.0,
+        "speculation_depth": 0,
+        "quantization_tier": "FP16",
+        "prefill_decode_split": False,
+        "priority_routing": False,
+    }
+    first_payload = _call(step_endpoint, StepRequest(session_id=session_id, action=action))
+    second_payload = _call(step_endpoint, StepRequest(session_id=session_id, action=action))
+    assert first_payload["observation"]["step_index"] == 1
+    assert second_payload["observation"]["step_index"] == 2
+    assert first_payload["reward"] != second_payload["reward"]
+
+    state_payload = _call(state_endpoint, session_id=session_id)
+    assert state_payload["step_count"] == 2
tests/test_inference.py ADDED
@@ -0,0 +1,29 @@
+from __future__ import annotations
+
+from pathlib import Path
+
+import inference
+from server import replay_assets
+
+
+def test_resolve_data_path_finds_lookup_table() -> None:
+    path = replay_assets.resolve_data_path("lookup_tables/latency_table.parquet")
+    assert path.exists()
+    assert path.name in {"latency_table.parquet", "serving_profile_table.parquet"}
+
+
+def test_main_returns_zero_when_env_init_fails(monkeypatch, capsys) -> None:
+    class BrokenEnv:
+        def __init__(self, *args, **kwargs) -> None:
+            raise RuntimeError("simulator bootstrap failed")
+
+    monkeypatch.setattr(inference, "LLMServeEnvironment", BrokenEnv)
+    monkeypatch.setattr(inference, "_create_client", lambda: None)
+
+    rc = inference.main()
+    output = capsys.readouterr().out
+
+    assert rc == 0
+    assert output.count("[START]") == len(inference.TASKS)
+    assert output.count("[END]") == len(inference.TASKS)
+    assert "simulator bootstrap failed" in output