Spaces:
Sleeping
Sleeping
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the BSD-style license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """FastAPI application for the LogiFlow-RL OpenEnv environment.""" | |
| import json | |
| import os | |
| from pathlib import Path | |
| from typing import Any, Dict, Literal, Optional | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.responses import HTMLResponse, RedirectResponse | |
| from pydantic import BaseModel | |
| try: | |
| from openenv.core.env_server.types import ( | |
| EnvironmentMetadata, | |
| HealthResponse, | |
| HealthStatus, | |
| ResetRequest, | |
| ResetResponse, | |
| SchemaResponse, | |
| StepRequest, | |
| StepResponse, | |
| ) | |
| from ..models import ( | |
| CrisisLogisticsAction, | |
| CrisisLogisticsObservation, | |
| CrisisLogisticsState, | |
| ) | |
| from .crisis_logistics_env_environment import ( | |
| CrisisLogisticsEnvironment, | |
| choose_resilient_action, | |
| ) | |
| except ImportError: | |
| from openenv.core.env_server.types import ( | |
| EnvironmentMetadata, | |
| HealthResponse, | |
| HealthStatus, | |
| ResetRequest, | |
| ResetResponse, | |
| SchemaResponse, | |
| StepRequest, | |
| StepResponse, | |
| ) | |
| from models import ( | |
| CrisisLogisticsAction, | |
| CrisisLogisticsObservation, | |
| CrisisLogisticsState, | |
| ) | |
| from server.crisis_logistics_env_environment import ( | |
| CrisisLogisticsEnvironment, | |
| choose_resilient_action, | |
| ) | |
# ASGI application object for this module; the metadata below is passed straight to FastAPI.
app = FastAPI(
    description=(
        "HTTP API for interacting with the LogiFlow-RL environment through "
        "a standardized OpenEnv-style interface."
    ),
    version="1.1.0",
    title="OpenEnv Environment HTTP API",
)

# One module-level environment instance; the handlers in this module all operate on it.
env = CrisisLogisticsEnvironment()

# Visualizer page location: <parent of this file's directory>/visualisation/logiflow_visualizer.html
VISUALIZER_PATH = Path(__file__).resolve().parent.parent.joinpath(
    "visualisation",
    "logiflow_visualizer.html",
)
# Request body for a policy-driven step.
# (Kept as comments rather than a class docstring so the generated JSON schema is unchanged.)
class PolicyStepRequest(BaseModel):
    # Which policy picks the action; only these two literals are accepted, default "heuristic".
    mode: Literal["heuristic", "llm"] = "heuristic"
    # Forwarded to `env.step(..., timeout_s=...)` by `policy_step`; None when omitted.
    timeout_s: Optional[float] = None
# Response body for a policy-driven step: the step result plus what was chosen and by whom.
# (Kept as comments rather than a class docstring so the generated JSON schema is unchanged.)
class PolicyStepResponse(BaseModel):
    # Step result, filled from the post-step observation by `policy_step`.
    observation: Dict[str, Any]
    reward: float
    done: bool

    # Policy that was requested and the policy that actually produced the action.
    policy_mode: Literal["heuristic", "llm"]
    action_source: Literal["heuristic", "llm"]

    # The action that was applied, as a plain dict.
    action: Dict[str, Any]

    # Populated only for LLM-driven steps; left as None otherwise.
    llm_model: Optional[str] = None
    llm_raw_output: Optional[str] = None
# Response body describing the LLM configuration as reported by `llm_status`.
# (Kept as comments rather than a class docstring so the generated JSON schema is unchanged.)
class LlmStatusResponse(BaseModel):
    # Both flags are derived from the presence of an API key in the environment.
    llm_ready: bool
    has_api_key: bool
    # Effective model and endpoint (environment override or built-in default).
    model_name: str
    api_base_url: str
| def _build_policy_prompt(observation: CrisisLogisticsObservation, title: str) -> str: | |
| return ( | |
| f"Task: {title}\n" | |
| f"Objective: {observation.objective}\n" | |
| f"Step: {observation.step_count + 1}/{observation.max_steps}\n" | |
| f"Visible nodes: {observation.visible_node_ids}\n" | |
| f"Observed node loads: {observation.observed_node_loads}\n" | |
| f"Node capacities: {observation.node_capacities}\n" | |
| f"Visible connectivity: {observation.visible_connectivity}\n" | |
| f"Active disruptions: {observation.active_disruptions}\n" | |
| f"In-transit shipments: {observation.in_transit_shipments[:8]}\n" | |
| f"Incoming shipment: source={observation.pending_source_node}, volume={observation.incoming_load}\n" | |
| f"Traffic event: {observation.event_label}\n" | |
| f"Dynamic pressure: {observation.dynamic_pressure}\n" | |
| f"Priority target: {observation.priority_target_name} (node {observation.priority_target_node})\n" | |
| "Return exactly one JSON object with keys: reasoning, source_node, dest_node, shipment_volume." | |
| ) | |
| def _extract_json_payload(text: str) -> Dict[str, Any]: | |
| decoder = json.JSONDecoder() | |
| candidates = [] | |
| for index, char in enumerate(text): | |
| if char != "{": | |
| continue | |
| try: | |
| payload, _ = decoder.raw_decode(text[index:]) | |
| except Exception: | |
| continue | |
| if isinstance(payload, dict): | |
| candidates.append(payload) | |
| if not candidates: | |
| return {} | |
| required = {"reasoning", "source_node", "dest_node", "shipment_volume"} | |
| for payload in reversed(candidates): | |
| if required.issubset(payload.keys()): | |
| return payload | |
| return candidates[-1] | |
def _resolve_llm_action(observation: CrisisLogisticsObservation) -> tuple[CrisisLogisticsAction, str, str]:
    """Ask the configured chat-completion endpoint for the next action.

    The API key is the first of ``HF_TOKEN``, ``OPENAI_API_KEY``, ``API_KEY`` that is
    non-empty after stripping whitespace. The endpoint and model come from
    ``API_BASE_URL`` and ``MODEL_NAME``, falling back to the defaults below.

    Args:
        observation: Current observation; rendered into the user prompt.

    Returns:
        ``(action, model_name, raw_text)`` where ``raw_text`` is the stripped
        reply content the action was parsed from.

    Raises:
        HTTPException: 503 when no API key is configured; 500 when the ``openai``
            package cannot be imported or the client cannot be constructed; 502
            when the request fails or the reply carries no choices; 422 when the
            reply contains no JSON object or the object is not a valid action.
    """
    # Strip each candidate before testing it, so a blank or whitespace-only variable
    # cannot hide a usable key stored in a later one.
    api_key = ""
    for env_var in ("HF_TOKEN", "OPENAI_API_KEY", "API_KEY"):
        api_key = (os.getenv(env_var) or "").strip()
        if api_key:
            break
    if not api_key:
        raise HTTPException(
            status_code=503,
            detail="LLM mode needs HF_TOKEN or OPENAI_API_KEY set in Space secrets.",
        )

    base_url = (os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1").strip()
    model_name = (os.getenv("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct").strip()

    # Imported lazily so the module loads even when the client package is absent.
    try:
        from openai import OpenAI
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"openai client import failed: {exc}") from exc

    prompt = _build_policy_prompt(observation, env.task.title)
    system_prompt = (
        "You are a logistics routing policy for a crisis supply chain environment. "
        "Always return exactly one JSON object with keys: reasoning, source_node, dest_node, shipment_volume."
    )

    try:
        client = OpenAI(api_key=api_key, base_url=base_url)
    except Exception as exc:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to initialize OpenAI client: {type(exc).__name__}: {exc}",
        ) from exc

    try:
        response = client.chat.completions.create(
            model=model_name,
            temperature=0.0,
            max_tokens=180,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": prompt},
            ],
        )
    except Exception as exc:
        raise HTTPException(status_code=502, detail=f"LLM request failed: {exc}") from exc

    # A reply without choices is an upstream failure; report it as such instead of
    # letting the indexing below raise IndexError.
    choices = getattr(response, "choices", None)
    if not choices:
        raise HTTPException(
            status_code=502,
            detail="LLM request failed: response contained no choices.",
        )

    # `content` may be None, hence the fallback to an empty string.
    raw_text = (choices[0].message.content or "").strip()
    payload = _extract_json_payload(raw_text)
    if not payload:
        raise HTTPException(
            status_code=422,
            detail=f"LLM output did not contain valid action JSON. Raw output: {raw_text[:600]}",
        )

    try:
        action = CrisisLogisticsAction(**payload)
    except Exception as exc:
        raise HTTPException(
            status_code=422,
            detail=f"LLM output JSON could not be parsed as action: {exc}. Raw output: {raw_text[:600]}",
        ) from exc

    return action, model_name, raw_text
def _read_visualizer_html() -> str:
    """Return the visualizer page markup, or a placeholder page when the file is absent.

    Returns:
        The UTF-8 text of ``VISUALIZER_PATH`` if that path exists; otherwise a
        small HTML page saying the visualizer file was not found.
    """
    if not VISUALIZER_PATH.exists():
        # NOTE(review): the markup is kept flush-left exactly as it appears in this
        # view; any indentation the literal originally had could not be recovered.
        return """
<html>
<head><title>LogiFlow-RL Visualizer Missing</title></head>
<body style="font-family: Arial, sans-serif; margin: 40px;">
<h1>Visualizer file not found</h1>
<p>Expected to find the visualizer at <code>/visualisation/logiflow_visualizer.html</code>.</p>
</body>
</html>
"""
    return VISUALIZER_PATH.read_text(encoding="utf-8")
async def root() -> HTMLResponse:
    # Landing page: a static list of links to /docs, /health, /schema and /web.
    # NOTE(review): no route decorator is visible in this view; confirm how this
    # handler is registered on `app`.
    # NOTE(review): the markup is kept flush-left exactly as it appears in this view;
    # any indentation the literal originally had could not be recovered.
    landing_page = """
<html>
<head><title>LogiFlow-RL</title></head>
<body style="font-family: Arial, sans-serif; margin: 40px;">
<h1>LogiFlow-RL</h1>
<p>OpenEnv benchmark for dynamic logistics routing.</p>
<ul>
<li><a href="/docs">API docs</a></li>
<li><a href="/health">Health</a></li>
<li><a href="/schema">Schema</a></li>
<li><a href="/web">Live visualizer</a></li>
</ul>
</body>
</html>
"""
    return HTMLResponse(landing_page)
async def web_landing() -> HTMLResponse:
    # Serves the visualizer page (or its "not found" placeholder) as HTML.
    page = _read_visualizer_html()
    return HTMLResponse(page)
async def web_landing_slash() -> HTMLResponse:
    # Same response as `web_landing`; presumably bound to the trailing-slash
    # variant of that path — the route decorator is not visible here, TODO confirm.
    page = _read_visualizer_html()
    return HTMLResponse(page)
async def server_compat() -> RedirectResponse:
    """Compatibility route used by some deployment templates."""
    # Always redirects the client to the visualizer at /web.
    target = "/web"
    return RedirectResponse(url=target)
async def visualizer() -> HTMLResponse:
    # Serves the visualizer page (or its "not found" placeholder) as HTML.
    page = _read_visualizer_html()
    return HTMLResponse(page)
async def reset_environment(request: Optional[ResetRequest] = None) -> ResetResponse:
    # Starts a new episode on the shared environment and returns its first observation.
    # A missing (or falsy) body is replaced by a default ResetRequest.
    params = request if request else ResetRequest()

    # `task_id` is read defensively with getattr; absent or empty means "easy".
    task_id = getattr(params, "task_id", None)
    if not task_id:
        task_id = "easy"

    first_obs = env.reset(seed=params.seed, episode_id=params.episode_id, task_id=task_id)

    fields = {
        "observation": first_obs.model_dump(),
        # A reward of None is reported as 0.0.
        "reward": float(first_obs.reward or 0.0),
        "done": first_obs.done,
    }
    return ResetResponse(**fields)
async def step_environment(request: StepRequest) -> StepResponse:
    # Applies a caller-supplied action to the shared environment.
    # `request.action` is unpacked as keyword arguments into the action model.
    chosen = CrisisLogisticsAction(**request.action)
    result = env.step(chosen, timeout_s=request.timeout_s)

    fields = {
        "observation": result.model_dump(),
        # A reward of None is reported as 0.0.
        "reward": float(result.reward or 0.0),
        "done": result.done,
    }
    return StepResponse(**fields)
async def policy_step(request: PolicyStepRequest) -> PolicyStepResponse:
    """Execute one environment step using either heuristic or strict LLM policy mode."""
    try:
        # Snapshot of the current observation that the selected policy decides on.
        snapshot = env._get_observation("Policy evaluation snapshot.")

        mode = request.mode
        # These stay None unless the LLM policy produced the action.
        model_used = None
        raw_reply = None

        if mode == "llm":
            chosen, model_used, raw_reply = _resolve_llm_action(snapshot)
        elif mode == "heuristic":
            chosen = choose_resilient_action(snapshot)
        else:
            # Defensive: the request model already restricts `mode` to the two literals.
            raise HTTPException(status_code=400, detail=f"Unsupported policy mode: {request.mode}")

        outcome = env.step(chosen, timeout_s=request.timeout_s)
        return PolicyStepResponse(
            observation=outcome.model_dump(),
            reward=float(outcome.reward or 0.0),
            done=outcome.done,
            # The requested mode is also the action source: there is no cross-mode fallback.
            policy_mode=mode,
            action_source=mode,
            action=chosen.model_dump(exclude_none=True),
            llm_model=model_used,
            llm_raw_output=raw_reply,
        )
    except HTTPException:
        # Deliberate HTTP errors (for example from the LLM path) pass through untouched.
        raise
    except Exception as exc:
        # Anything else is surfaced as a 500 carrying the exception type and message.
        raise HTTPException(
            status_code=500,
            detail=f"policy_step unexpected error: {type(exc).__name__}: {exc}",
        ) from exc
async def llm_status() -> LlmStatusResponse:
    # Reports whether LLM policy mode is usable and which model/endpoint it would call.
    # A key counts as present only if it is non-empty after stripping whitespace.
    # `_resolve_llm_action` strips the key before use and rejects an empty result,
    # so an unstripped check here would report "ready" for a key that is then refused.
    key_vars = ("HF_TOKEN", "OPENAI_API_KEY", "API_KEY")
    has_key = any((os.getenv(name) or "").strip() for name in key_vars)

    model_name = (os.getenv("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct").strip()
    api_base_url = (os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1").strip()

    return LlmStatusResponse(
        # Readiness is defined purely by key presence; no request is made to the endpoint.
        llm_ready=has_key,
        has_api_key=has_key,
        model_name=model_name,
        api_base_url=api_base_url,
    )
async def get_state() -> CrisisLogisticsState:
    # Returns the shared environment's `state` attribute as-is.
    current_state = env.state
    return current_state
async def get_metadata() -> EnvironmentMetadata:
    # Static descriptive metadata for this environment: name, description, version.
    summary = (
        "Adaptive logistics routing benchmark with dynamic pressure, priority demand, "
        "and multi-component verifiable rewards."
    )
    return EnvironmentMetadata(name="LogiFlow-RL", description=summary, version="1.1.0")
async def get_schema() -> SchemaResponse:
    # Returns the JSON schemas of the action, observation and state models.
    action_schema = CrisisLogisticsAction.model_json_schema()
    observation_schema = CrisisLogisticsObservation.model_json_schema()
    state_schema = CrisisLogisticsState.model_json_schema()
    return SchemaResponse(
        action=action_schema,
        observation=observation_schema,
        state=state_schema,
    )
async def health() -> HealthResponse:
    # Unconditionally reports HEALTHY; no checks are performed on the environment.
    status = HealthStatus.HEALTHY
    return HealthResponse(status=status)
def main() -> None:
    """Run ``app`` under uvicorn, bound to 0.0.0.0 on port 8000."""
    # Imported here so that importing this module does not require uvicorn.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=8000)


if __name__ == "__main__":
    main()