# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""FastAPI application for the LogiFlow-RL OpenEnv environment."""
import json
import os
from pathlib import Path
from typing import Any, Dict, Literal, Optional
from fastapi import FastAPI, HTTPException
from fastapi.responses import HTMLResponse, RedirectResponse
from pydantic import BaseModel
try:
from openenv.core.env_server.types import (
EnvironmentMetadata,
HealthResponse,
HealthStatus,
ResetRequest,
ResetResponse,
SchemaResponse,
StepRequest,
StepResponse,
)
from ..models import (
CrisisLogisticsAction,
CrisisLogisticsObservation,
CrisisLogisticsState,
)
from .crisis_logistics_env_environment import (
CrisisLogisticsEnvironment,
choose_resilient_action,
)
except ImportError:
from openenv.core.env_server.types import (
EnvironmentMetadata,
HealthResponse,
HealthStatus,
ResetRequest,
ResetResponse,
SchemaResponse,
StepRequest,
StepResponse,
)
from models import (
CrisisLogisticsAction,
CrisisLogisticsObservation,
CrisisLogisticsState,
)
from server.crisis_logistics_env_environment import (
CrisisLogisticsEnvironment,
choose_resilient_action,
)
# FastAPI application object; every route in this module is registered on it.
app = FastAPI(
    title="OpenEnv Environment HTTP API",
    version="1.1.0",
    description=(
        "HTTP API for interacting with the LogiFlow-RL environment through "
        "a standardized OpenEnv-style interface."
    ),
)
# Single module-level environment instance used by the request handlers below.
env = CrisisLogisticsEnvironment()
# Location of the visualizer page: two directory levels up from this file's
# resolved path, then visualisation/logiflow_visualizer.html.
VISUALIZER_PATH = Path(__file__).resolve().parent.parent / "visualisation" / "logiflow_visualizer.html"
class PolicyStepRequest(BaseModel):
    """Request body accepted by the ``/policy_step`` endpoint."""

    # Which policy selects the action; defaults to the heuristic policy.
    mode: Literal["heuristic", "llm"] = "heuristic"
    # Passed through as ``timeout_s`` to ``env.step``; ``None`` when omitted.
    timeout_s: Optional[float] = None
class PolicyStepResponse(BaseModel):
    """Response body returned by the ``/policy_step`` endpoint."""

    # ``model_dump()`` of the observation returned by ``env.step``.
    observation: Dict[str, Any]
    # Reward of that observation, coerced to float (0.0 when it is falsy).
    reward: float
    # Whether that observation reports the episode as finished.
    done: bool
    # Policy mode that was requested for this step.
    policy_mode: Literal["heuristic", "llm"]
    # Which policy actually produced the executed action.
    action_source: Literal["heuristic", "llm"]
    # ``model_dump(exclude_none=True)`` of the executed action.
    action: Dict[str, Any]
    # Model name used for the LLM call; ``None`` in heuristic mode.
    llm_model: Optional[str] = None
    # Raw text returned by the LLM; ``None`` in heuristic mode.
    llm_raw_output: Optional[str] = None
class LlmStatusResponse(BaseModel):
    """Response body returned by the ``/llm_status`` endpoint."""

    # Both flags are filled from the same "is an API key configured" check.
    llm_ready: bool
    has_api_key: bool
    # Value of MODEL_NAME, or the built-in default model when unset.
    model_name: str
    # Value of API_BASE_URL, or the built-in default router URL when unset.
    api_base_url: str
def _build_policy_prompt(observation: CrisisLogisticsObservation, title: str) -> str:
    """Render the user prompt describing the current state for the LLM policy.

    Args:
        observation: Observation whose fields are embedded in the prompt.
        title: Task title placed on the first line.

    Returns:
        The prompt text: one line per observation field, followed by the
        instruction to answer with a single JSON action object.
    """
    obs = observation
    prompt_lines = [
        f"Task: {title}",
        f"Objective: {obs.objective}",
        # Steps are shown 1-based, i.e. the step about to be taken.
        f"Step: {obs.step_count + 1}/{obs.max_steps}",
        f"Visible nodes: {obs.visible_node_ids}",
        f"Observed node loads: {obs.observed_node_loads}",
        f"Node capacities: {obs.node_capacities}",
        f"Visible connectivity: {obs.visible_connectivity}",
        f"Active disruptions: {obs.active_disruptions}",
        # Only the first eight in-transit shipments are listed.
        f"In-transit shipments: {obs.in_transit_shipments[:8]}",
        f"Incoming shipment: source={obs.pending_source_node}, volume={obs.incoming_load}",
        f"Traffic event: {obs.event_label}",
        f"Dynamic pressure: {obs.dynamic_pressure}",
        f"Priority target: {obs.priority_target_name} (node {obs.priority_target_node})",
        "Return exactly one JSON object with keys: reasoning, source_node, dest_node, shipment_volume.",
    ]
    return "\n".join(prompt_lines)
def _extract_json_payload(text: str) -> Dict[str, Any]:
    """Pull the most plausible action object out of free-form LLM output.

    Every ``{`` in *text* is tried as the start of a JSON object, so objects
    surrounded by prose, and objects nested inside other objects, are all
    considered.

    Selection order:
      1. The last decodable object (by start position) that contains all of
         ``reasoning``, ``source_node``, ``dest_node`` and ``shipment_volume``.
      2. Otherwise the last *outermost* decodable object, i.e. one that is
         not nested inside an earlier decoded object. Falling back to a
         nested fragment would hand the caller a piece of the answer rather
         than the answer itself.

    Args:
        text: Raw model output.

    Returns:
        The selected dict, or ``{}`` when no JSON object can be decoded.
    """
    decoder = json.JSONDecoder()
    required = {"reasoning", "source_node", "dest_node", "shipment_volume"}
    candidates = []  # every decodable object, ordered by start position
    outermost = []  # subset of candidates not contained in an earlier one
    covered_until = 0  # end offset of the most recent outermost object

    start = text.find("{")
    while start != -1:
        try:
            # Decode in place from ``start``; avoids copying the tail of the
            # text for every brace.
            payload, end = decoder.raw_decode(text, start)
        except (ValueError, RecursionError):
            # Not valid JSON from this brace (JSONDecodeError is a
            # ValueError) or pathologically deep nesting: try the next one.
            payload, end = None, start
        if isinstance(payload, dict):
            candidates.append(payload)
            if start >= covered_until:
                outermost.append(payload)
                covered_until = end
        start = text.find("{", start + 1)

    for payload in reversed(candidates):
        if required.issubset(payload):
            return payload
    return outermost[-1] if outermost else {}
def _resolve_llm_action(observation: CrisisLogisticsObservation) -> tuple[CrisisLogisticsAction, str, str]:
    """Ask the configured chat model for the next action.

    Configuration is read from the environment on every call:
    ``HF_TOKEN`` / ``OPENAI_API_KEY`` / ``API_KEY`` (first one that is
    non-empty after stripping), ``API_BASE_URL`` and ``MODEL_NAME``.

    Args:
        observation: Current observation, rendered into the user prompt.

    Returns:
        ``(action, model_name, raw_text)`` where ``raw_text`` is the stripped
        model output the action was parsed from.

    Raises:
        HTTPException: 503 when no API key is configured; 500 when the
            ``openai`` package cannot be imported or the client cannot be
            created; 502 when the request fails or returns no choices; 422
            when the output holds no usable action JSON.
    """
    # Strip each candidate before choosing, so a blank or whitespace-only
    # variable does not shadow a valid key in a later variable.
    api_key = ""
    for env_name in ("HF_TOKEN", "OPENAI_API_KEY", "API_KEY"):
        candidate = (os.getenv(env_name) or "").strip()
        if candidate:
            api_key = candidate
            break
    if not api_key:
        raise HTTPException(
            status_code=503,
            detail="LLM mode needs HF_TOKEN or OPENAI_API_KEY set in Space secrets.",
        )

    base_url = (os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1").strip()
    model_name = (os.getenv("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct").strip()

    # Imported lazily so the module still loads when openai is not installed.
    try:
        from openai import OpenAI
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"openai client import failed: {exc}") from exc

    prompt = _build_policy_prompt(observation, env.task.title)
    system_prompt = (
        "You are a logistics routing policy for a crisis supply chain environment. "
        "Always return exactly one JSON object with keys: reasoning, source_node, dest_node, shipment_volume."
    )

    try:
        client = OpenAI(api_key=api_key, base_url=base_url)
    except Exception as exc:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to initialize OpenAI client: {type(exc).__name__}: {exc}",
        ) from exc

    try:
        response = client.chat.completions.create(
            model=model_name,
            temperature=0.0,
            max_tokens=180,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": prompt},
            ],
        )
    except Exception as exc:
        raise HTTPException(status_code=502, detail=f"LLM request failed: {exc}") from exc

    # An empty or missing ``choices`` list is an upstream failure; report it
    # as 502 instead of letting an IndexError escape.
    choices = getattr(response, "choices", None) or []
    if not choices:
        raise HTTPException(status_code=502, detail="LLM response contained no choices.")

    raw_text = (choices[0].message.content or "").strip()
    payload = _extract_json_payload(raw_text)
    if not payload:
        raise HTTPException(
            status_code=422,
            detail=f"LLM output did not contain valid action JSON. Raw output: {raw_text[:600]}",
        )

    try:
        action = CrisisLogisticsAction(**payload)
    except Exception as exc:
        raise HTTPException(
            status_code=422,
            detail=f"LLM output JSON could not be parsed as action: {exc}. Raw output: {raw_text[:600]}",
        ) from exc
    return action, model_name, raw_text
def _read_visualizer_html() -> str:
    """Load the standalone visualizer HTML bundled with the project.

    Returns:
        The UTF-8 text of ``VISUALIZER_PATH``, or a small fallback page when
        that path is missing or is not a regular file.
    """
    # Read first and handle failure, not exists()-then-read: this closes the
    # gap between the check and the read, and also covers the case where the
    # path exists but is a directory.
    try:
        return VISUALIZER_PATH.read_text(encoding="utf-8")
    except (FileNotFoundError, IsADirectoryError, NotADirectoryError):
        return """
<html>
<head><title>LogiFlow-RL Visualizer Missing</title></head>
<body style="font-family: Arial, sans-serif; margin: 40px;">
<h1>Visualizer file not found</h1>
<p>Expected to find the visualizer at <code>/visualisation/logiflow_visualizer.html</code>.</p>
</body>
</html>
"""
@app.get("/", response_class=HTMLResponse, tags=["Environment Info"])
async def root() -> HTMLResponse:
    """Serve the landing page linking to the docs, health, schema and visualizer."""
    landing_page = """
<html>
<head><title>LogiFlow-RL</title></head>
<body style="font-family: Arial, sans-serif; margin: 40px;">
<h1>LogiFlow-RL</h1>
<p>OpenEnv benchmark for dynamic logistics routing.</p>
<ul>
<li><a href="/docs">API docs</a></li>
<li><a href="/health">Health</a></li>
<li><a href="/schema">Schema</a></li>
<li><a href="/web">Live visualizer</a></li>
</ul>
</body>
</html>
"""
    return HTMLResponse(landing_page)
@app.get("/web", response_class=HTMLResponse, tags=["Environment Info"])
async def web_landing() -> HTMLResponse:
    """Serve the visualizer page at ``/web``."""
    page = _read_visualizer_html()
    return HTMLResponse(page)
@app.get("/web/", response_class=HTMLResponse, tags=["Environment Info"])
async def web_landing_slash() -> HTMLResponse:
    """Serve the visualizer page at ``/web/`` (trailing-slash variant)."""
    page = _read_visualizer_html()
    return HTMLResponse(page)
@app.get("/server", include_in_schema=False)
async def server_compat() -> RedirectResponse:
    """Redirect ``/server`` to ``/web``; kept for deployment templates that use it."""
    redirect_target = "/web"
    return RedirectResponse(url=redirect_target)
@app.get("/visualizer", response_class=HTMLResponse, tags=["Environment Info"])
async def visualizer() -> HTMLResponse:
    """Serve the visualizer page at ``/visualizer``."""
    page = _read_visualizer_html()
    return HTMLResponse(page)
@app.post("/reset", response_model=ResetResponse, tags=["Environment Control"])
async def reset_environment(request: Optional[ResetRequest] = None) -> ResetResponse:
    """Reset the shared environment and return its first observation.

    A missing body is treated as a default ``ResetRequest``. When the request
    has no truthy ``task_id``, the ``"easy"`` task is used.
    """
    if not request:
        request = ResetRequest()
    requested_task = getattr(request, "task_id", None)
    task_id = requested_task if requested_task else "easy"

    first_observation = env.reset(
        seed=request.seed,
        episode_id=request.episode_id,
        task_id=task_id,
    )
    response_fields = {
        "observation": first_observation.model_dump(),
        # A falsy reward (e.g. None) is reported as 0.0.
        "reward": float(first_observation.reward or 0.0),
        "done": first_observation.done,
    }
    return ResetResponse(**response_fields)
@app.post("/step", response_model=StepResponse, tags=["Environment Control"])
async def step_environment(request: StepRequest) -> StepResponse:
    """Apply one client-supplied action to the shared environment.

    Args:
        request: Carries the raw action mapping and an optional ``timeout_s``
            that is passed through to ``env.step``.

    Returns:
        The resulting observation (dumped to a dict), its reward and the
        ``done`` flag.

    Raises:
        HTTPException: 422 when ``request.action`` cannot be turned into a
            ``CrisisLogisticsAction``.
    """
    # A malformed action is a client error: answer 422, not an unhandled 500.
    # pydantic's ValidationError is a ValueError; TypeError covers bad keyword
    # arguments.
    try:
        action = CrisisLogisticsAction(**request.action)
    except (TypeError, ValueError) as exc:
        raise HTTPException(status_code=422, detail=f"Invalid action payload: {exc}") from exc

    observation = env.step(action, timeout_s=request.timeout_s)
    return StepResponse(
        observation=observation.model_dump(),
        # A falsy reward (e.g. None) is reported as 0.0.
        reward=float(observation.reward or 0.0),
        done=observation.done,
    )
@app.post("/policy_step", response_model=PolicyStepResponse, tags=["Environment Control"])
async def policy_step(request: PolicyStepRequest) -> PolicyStepResponse:
    """Pick an action with the requested policy and apply it as one step.

    ``HTTPException`` raised while choosing or applying the action is passed
    through unchanged; any other exception is reported as a 500 carrying the
    exception type and message.
    """
    try:
        # Snapshot of the current state that the policy decides on.
        snapshot = env._get_observation("Policy evaluation snapshot.")
        selected_mode = request.mode

        # LLM details stay None unless the LLM branch fills them in.
        llm_model = None
        llm_raw_output = None
        if selected_mode == "heuristic":
            action = choose_resilient_action(snapshot)
        elif selected_mode == "llm":
            action, llm_model, llm_raw_output = _resolve_llm_action(snapshot)
        else:
            raise HTTPException(status_code=400, detail=f"Unsupported policy mode: {request.mode}")

        outcome = env.step(action, timeout_s=request.timeout_s)
        return PolicyStepResponse(
            observation=outcome.model_dump(),
            reward=float(outcome.reward or 0.0),
            done=outcome.done,
            # Both fields report the mode of the branch that was taken.
            policy_mode=selected_mode,
            action_source=selected_mode,
            action=action.model_dump(exclude_none=True),
            llm_model=llm_model,
            llm_raw_output=llm_raw_output,
        )
    except HTTPException:
        raise
    except Exception as exc:
        raise HTTPException(
            status_code=500,
            detail=f"policy_step unexpected error: {type(exc).__name__}: {exc}",
        ) from exc
@app.get("/llm_status", response_model=LlmStatusResponse, tags=["Environment Info"])
async def llm_status() -> LlmStatusResponse:
    """Report whether LLM policy mode is configured, without calling the LLM.

    Returns:
        Flags telling whether a usable API key is present, plus the model
        name and API base URL taken from the environment (or their defaults).
    """
    # A key counts only if it is non-empty after stripping, and a blank
    # variable must not shadow a valid one later in the list. Otherwise this
    # endpoint would report "ready" for a whitespace-only secret.
    has_key = any(
        (os.getenv(env_name) or "").strip()
        for env_name in ("HF_TOKEN", "OPENAI_API_KEY", "API_KEY")
    )
    model_name = (os.getenv("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct").strip()
    api_base_url = (os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1").strip()
    return LlmStatusResponse(
        llm_ready=has_key,
        has_api_key=has_key,
        model_name=model_name,
        api_base_url=api_base_url,
    )
@app.get("/state", response_model=CrisisLogisticsState, tags=["State Management"])
async def get_state() -> CrisisLogisticsState:
    """Return the shared environment's ``state`` attribute."""
    current_state = env.state
    return current_state
@app.get("/metadata", response_model=EnvironmentMetadata, tags=["Environment Info"])
async def get_metadata() -> EnvironmentMetadata:
    """Return the environment's static name, description and version."""
    summary = (
        "Adaptive logistics routing benchmark with dynamic pressure, priority demand, "
        "and multi-component verifiable rewards."
    )
    return EnvironmentMetadata(name="LogiFlow-RL", description=summary, version="1.1.0")
@app.get("/schema", response_model=SchemaResponse, tags=["Schema"])
async def get_schema() -> SchemaResponse:
    """Return the JSON schemas of the action, observation and state models."""
    schemas = {
        "action": CrisisLogisticsAction.model_json_schema(),
        "observation": CrisisLogisticsObservation.model_json_schema(),
        "state": CrisisLogisticsState.model_json_schema(),
    }
    return SchemaResponse(**schemas)
@app.get("/health", response_model=HealthResponse, tags=["Health"])
async def health() -> HealthResponse:
    """Health probe; always answers with the HEALTHY status."""
    status = HealthStatus.HEALTHY
    return HealthResponse(status=status)
def main(host: str = "0.0.0.0", port: int = 8000) -> None:
    """Run the API under uvicorn.

    Args:
        host: Interface to bind. The default listens on all interfaces.
        port: TCP port to listen on.

    Calling ``main()`` with no arguments behaves as before.
    """
    # Imported here so importing this module does not require uvicorn.
    import uvicorn

    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    main()
|