from __future__ import annotations
import concurrent.futures
import os
from pathlib import Path
import time
from collections import OrderedDict
import logging
from server.road_router import get_router
from ev_grid_oracle.traffic import TrafficModel
import hashlib
try:
from openenv.core.env_server.http_server import create_app
except ImportError as e: # pragma: no cover
raise ImportError("openenv-core required. Install deps from pyproject.") from e
from typing import Any, Literal, cast
from uuid import uuid4
from pydantic import BaseModel, Field
from fastapi import Body, HTTPException, Query, Request
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from ev_grid_oracle.city_graph import build_city_graph
import networkx as nx
from ev_grid_oracle.env import EVGridCore, _build_prompt
from ev_grid_oracle.models import (
ActionType,
EVRequest,
EVGridAction,
EVGridObservation,
GridDirective,
MultiAgentStepRequest,
NegotiationMessage,
)
from ev_grid_oracle.oracle_agent import OracleAgent
from ev_grid_oracle.policies import baseline_policy
from ev_grid_oracle.parsing import parse_simulation
from ev_grid_oracle.reward import split_role_rewards
from ev_grid_oracle.scenarios import ScenarioName
from ev_grid_oracle.world_model_verifier import rollout_deterministic_5ticks, score_prediction
from ev_grid_oracle.multi_agent import MultiAgentSession
from server.ev_grid_environment import EVGridEnvironment
from server.ev_grid_road_environment import EVGridRoadEnvironment
from ev_grid_oracle.road_models import RoadAction, RoadObservation
from server.role_metrics import compute_role_kpis, compute_role_reward_breakdown, summarize_action
from server.road_router import haversine_m
# Module logger shared by every endpoint below.
# NOTE(review): the `extra={...}` fields passed to log calls in this module are only
# rendered if a formatter references them; the default `logging.basicConfig` format
# does not. Confirm whether a structured formatter is installed elsewhere.
log = logging.getLogger("ev-grid-oracle")
if not log.handlers:
    # Configure root logging when this named logger has no handlers of its own.
    # `basicConfig` is itself a no-op if the root logger already has handlers.
    # LOG_LEVEL defaults to INFO.
    logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"))
def _request_id(req: Request) -> str:
rid = (req.headers.get("x-request-id") or req.headers.get("x-amzn-trace-id") or "").strip()
return rid or uuid4().hex
def _oracle_skip_llm_env() -> bool:
return os.getenv("ORACLE_SKIP_LLM", "").strip() not in ("", "0", "false", "False")
_RATE_BUCKET: dict[str, list[float]] = {}
def _rate_limit(req: Request, *, key: str, limit: int, window_sec: int) -> None:
ip = (req.client.host if req.client else "unknown") + ":" + key
now = time.time()
xs = _RATE_BUCKET.get(ip, [])
xs = [t for t in xs if now - t < window_sec]
if len(xs) >= limit:
raise HTTPException(status_code=429, detail=f"Rate limit exceeded ({key}). Please wait and retry.")
xs.append(now)
_RATE_BUCKET[ip] = xs
def _demo_oracle_act_with_guard(
    *,
    st: Any,
    core: EVGridCore,
    oracle_lora_repo: str,
) -> tuple[EVGridAction, str, bool, bool, bool]:
    """
    Run oracle policy with CPU-Space-safe guards.

    Returns: action, oracle_text, oracle_llm_active, oracle_timed_out, oracle_skipped_env

    - ORACLE_SKIP_LLM set: run OracleAgent with no LoRA repo and report ``skipped``.
    - Blank ``oracle_lora_repo``: run OracleAgent inline with no LoRA repo.
    - Otherwise: run on a shared single-worker executor, bounded by
      DEMO_ORACLE_INFERENCE_TIMEOUT_SEC (default 90). On timeout, fall back to
      ``baseline_policy`` and report ``timed_out``.

    Exceptions raised by the agent itself propagate to the caller.
    """
    if _oracle_skip_llm_env():
        a, t = OracleAgent(lora_repo_id=None).act_with_text(st, _build_prompt(st), core.city_graph)
        return a, t, False, False, True
    repo = (oracle_lora_repo or "").strip() or None
    if not repo:
        agent = OracleAgent(lora_repo_id=None)
        action, text = agent.act_with_text(st, _build_prompt(st), core.city_graph)
        return action, text, bool(agent.is_active), False, False
    timeout = float(os.getenv("DEMO_ORACLE_INFERENCE_TIMEOUT_SEC", "90"))
    # Reuse a single executor to avoid spawning threads repeatedly.
    # Note: cancellation does not reliably stop model load once started, so we keep the timeout
    # as a *response guard* only. The model cache in OracleAgent prevents repeated cold-loads.
    global _ORACLE_EXEC
    try:
        _ORACLE_EXEC
    except NameError:
        _ORACLE_EXEC = concurrent.futures.ThreadPoolExecutor(max_workers=1)

    def run() -> tuple[EVGridAction, str, bool]:
        agent = OracleAgent(lora_repo_id=repo)
        action, text = agent.act_with_text(st, _build_prompt(st), core.city_graph)
        return action, text, bool(agent.is_active)

    fut = _ORACLE_EXEC.submit(run)
    try:
        action, text, active = fut.result(timeout=timeout)
        return action, text, active, False, False
    except concurrent.futures.TimeoutError:
        # With a single worker, this job may still be queued behind an earlier slow
        # inference. Cancel it so abandoned requests do not pile up. A job that has
        # already started cannot be cancelled and runs to completion in the background.
        fut.cancel()
        log.warning("oracle_inference_timeout", extra={"repo": repo, "timeout_sec": timeout})
        return baseline_policy(st, core.city_graph), "[timeout] baseline fallback (oracle too slow)", False, True, False
# Primary OpenEnv app for the station-graph environment. `max_concurrent_envs` is
# passed straight to openenv's create_app; its semantics are defined there.
app = create_app(EVGridEnvironment, EVGridAction, EVGridObservation, env_name="ev-grid-oracle", max_concurrent_envs=1)
# Mount a separate “real road graph” RL environment under /road/.
road_app = create_app(EVGridRoadEnvironment, RoadAction, RoadObservation, env_name="ev-grid-oracle-road", max_concurrent_envs=1)
app.mount("/road", road_app)
# Frontend bundle location: <dir above this file's directory>/web/dist.
_WEB_DIST = (Path(__file__).resolve().parents[1] / "web" / "dist").resolve()
if _WEB_DIST.exists():
    # Serve Phaser UI at /ui (built by Docker during Space build)
    # html=True makes StaticFiles serve index.html for directory requests.
    app.mount("/ui", StaticFiles(directory=str(_WEB_DIST), html=True), name="ui")
@app.get("/", response_class=HTMLResponse)
def root() -> str:
    """Serve the landing page at ``/``.

    When the web bundle directory exists, returns an effectively empty document.
    Otherwise returns a short text overview of the demo API endpoints.
    """
    # HF Spaces loads / by default; redirect to the Phaser UI if present.
    # NOTE(review): the literal below contains no redirect markup (e.g. a meta refresh
    # or script pointing at /ui/). Confirm the intended redirect is actually emitted.
    if _WEB_DIST.exists():
        return """
"""
    return """\
EV Grid Oracle (OpenEnv)
EV Grid Oracle — OpenEnv Environment
This Space hosts the FastAPI server for the OpenEnv environment + a small demo API.
Demo API (for the Phaser pixel-map client)
POST /demo/new
POST /demo/step
GET /demo/state
If the Phaser UI is built into this Space, it will be available at
/ui/.
"""
@app.get("/healthz")
def healthz(req: Request) -> dict[str, Any]:
    """
    Liveness probe suitable for HF Spaces cold starts.

    Always reports ``ok``. ``router_ok`` is False when obtaining the road router
    raised. Keep this cheap: avoid real routing work here.
    """
    request_id = _request_id(req)

    def _router_available() -> bool:
        # get_router() is expected to hand back a cached instance after first use.
        try:
            get_router()
        except Exception:
            return False
        return True

    router_ok = _router_available()
    return {
        "ok": True,
        "request_id": request_id,
        "sim_version": _SIM_VERSION,
        "web_ui": bool(_WEB_DIST.exists()),
        "demo_sessions": len(_demo_sessions),
        "router_ok": router_ok,
        "schema_version": "traffic-v1",
    }
# -----------------------------
# Demo API (Phaser frontend)
# -----------------------------
# Idle lifetime of a demo session in seconds, measured from its last touch (env DEMO_SESSION_TTL_SEC).
_DEMO_SESSION_TTL_SEC = int(os.getenv("DEMO_SESSION_TTL_SEC", "3600")) # 1h
# Upper bound on live demo sessions; the oldest are evicted first (env DEMO_MAX_SESSIONS).
_DEMO_MAX_SESSIONS = int(os.getenv("DEMO_MAX_SESSIONS", "64"))
# Ordered for deterministic eviction of oldest sessions.
# Maps session_id -> (last-touched time.time(), core).
_demo_sessions: "OrderedDict[str, tuple[float, EVGridCore]]" = OrderedDict()
# City graph built once at import time and shared by every demo and multi-agent session core.
_demo_graph = build_city_graph()
# Simulation version string echoed in healthz and the session-creation/state responses.
_SIM_VERSION = "2026-04-26.1"
def _osm_route_polyline(
    *,
    src_lat: float,
    src_lng: float,
    dst_lat: float,
    dst_lng: float,
    traffic: TrafficModel | None = None,
    tick: int | None = None,
) -> tuple[list[list[float]], list[int]] | None:
    """Best-effort road route between two points via the shared OSM router.

    Returns whatever ``get_router().route_polyline`` returns: per the annotation,
    a ``(polyline, per-segment traffic data)`` pair. Returns ``None`` if routing
    fails for any reason, so callers can fall back to a graph-path or
    straight-line polyline.
    """
    try:
        return get_router().route_polyline(
            src_lat=src_lat, src_lng=src_lng, dst_lat=dst_lat, dst_lng=dst_lng, traffic=traffic, tick=tick
        )
    except Exception:
        # Deliberately best-effort (every caller has a fallback), but record why
        # so that a persistently broken router is diagnosable instead of silent.
        log.debug(
            "osm_route_polyline_failed",
            exc_info=True,
            extra={"src": (src_lat, src_lng), "dst": (dst_lat, dst_lng), "tick": tick},
        )
        return None
def _graph_route_polyline(core: EVGridCore, *, src_station_id: str, dst_station_id: str) -> list[list[float]]:
"""
Return a render-friendly polyline (lat/lng pairs) along the station graph.
v0 fallback was a straight line; this produces a multi-point path so the UI reads like navigation.
"""
if src_station_id == dst_station_id:
n = core.city_graph.nodes[src_station_id]
return [[float(n["lat"]), float(n["lng"])]]
try:
path = cast(list[str], nx.shortest_path(core.city_graph, src_station_id, dst_station_id, weight="weight_minutes"))
except Exception:
# Fallback: direct
a = core.city_graph.nodes[src_station_id]
b = core.city_graph.nodes[dst_station_id]
return [[float(a["lat"]), float(a["lng"])], [float(b["lat"]), float(b["lng"])]]
out: list[list[float]] = []
for sid in path:
n = core.city_graph.nodes[sid]
out.append([float(n["lat"]), float(n["lng"])])
return out
def _spawn_road_point_away_from_stations(
    *,
    core: EVGridCore,
    min_station_dist_m: float,
    seed_key: str,
    attempts: int = 80,
) -> tuple[float, float]:
    """
    Pick a road-graph node location (lat, lng) that is at least
    ``min_station_dist_m`` metres (haversine) from every station.

    Candidate nodes are probed in a fixed order derived from SHA-1(seed_key), so
    the result is deterministic for a given seed_key and router graph.

    Raises:
        ValueError: in any of these cases:
            - the core has no grid state;
            - there are no stations;
            - the router graph has no nodes;
            - no qualifying node was found within ``attempts`` probes.
    """
    router = get_router()
    st = core._grid_state
    if st is None:
        raise ValueError("core not initialized")
    stations = st.stations
    if not stations:
        raise ValueError("no stations")
    n = len(router.nodes)
    if n == 0:
        # Without this guard, router.nodes[0] below would raise IndexError.
        raise ValueError("no road nodes")
    min_dist = float(min_station_dist_m)
    station_points = [(float(s.lat), float(s.lng)) for s in stations]
    base = int.from_bytes(hashlib.sha1(seed_key.encode("utf-8")).digest()[:4], "big")
    for k in range(attempts):
        # 9973 is prime, so probes visit distinct nodes (up to n of them) unless n is a multiple of it.
        idx = (base + k * 9973) % n
        raw_lat, raw_lng = router.nodes[idx]
        lat, lng = float(raw_lat), float(raw_lng)
        if not any(haversine_m(lat, lng, s_lat, s_lng) < min_dist for s_lat, s_lng in station_points):
            return lat, lng
    raise ValueError("could_not_find_spawn_point")
def _demo_session_gc(now: float | None = None) -> None:
    """Evict expired demo sessions, then trim the registry to the size cap.

    A session expires once more than ``_DEMO_SESSION_TTL_SEC`` seconds have passed
    since it was last touched. After that, the oldest entries (front of the
    OrderedDict) are dropped until at most ``_DEMO_MAX_SESSIONS`` remain.
    """
    current = float(time.time() if now is None else now)
    ttl = float(_DEMO_SESSION_TTL_SEC)
    stale_ids = [sid for sid, (touched_at, _core) in _demo_sessions.items() if current - float(touched_at) > ttl]
    for sid in stale_ids:
        _demo_sessions.pop(sid, None)
    cap = int(_DEMO_MAX_SESSIONS)
    while len(_demo_sessions) > cap:
        _demo_sessions.popitem(last=False)
def _demo_session_get(session_id: str) -> EVGridCore | None:
    """Look up a demo session's core, refreshing its TTL and LRU position.

    Garbage collection runs first, so an expired session is reported as missing.
    Returns ``None`` for unknown ids.
    """
    _demo_session_gc()
    entry = _demo_sessions.get(session_id)
    if entry is None:
        return None
    core = entry[1]
    # Mark as most recently used: move to the tail, then stamp a fresh touch time.
    _demo_sessions.move_to_end(session_id, last=True)
    _demo_sessions[session_id] = (time.time(), core)
    return core
class DemoNewRequest(BaseModel):
    """Request body for ``POST /demo/new``."""

    # Seed forwarded to EVGridCore.reset; bounded to [0, 1_000_000].
    seed: int = Field(123, ge=0, le=1_000_000)
    # Scenario forwarded to EVGridCore.reset.
    scenario: ScenarioName = Field("baseline")
    # NOTE(review): free-form string that is not validated here; it is cast and passed
    # to EVGridCore.reset as-is. Confirm that reset rejects unknown modes.
    fleet_mode: str = Field("mixed", description="Fleet persona mix: mixed|taxi|corporate|delivery|private|emergency")
# -----------------------------
# Multi-agent demo API (Theme #1)
# -----------------------------
# Idle lifetime of a multi-agent session in seconds, measured from its last touch (env MA_SESSION_TTL_SEC).
_MA_SESSION_TTL_SEC = int(os.getenv("MA_SESSION_TTL_SEC", "3600"))
# Upper bound on live multi-agent sessions; the oldest are evicted first (env MA_MAX_SESSIONS).
_MA_MAX_SESSIONS = int(os.getenv("MA_MAX_SESSIONS", "64"))
# session_id -> (last-touched time.time(), session); dict order is least- to most-recently used.
_ma_sessions: "OrderedDict[str, tuple[float, MultiAgentSession]]" = OrderedDict()
def _ma_gc(now: float | None = None) -> None:
    """Drop multi-agent sessions idle longer than the TTL, then enforce the size cap.

    Once more than ``_MA_MAX_SESSIONS`` remain, the oldest entries (front of the
    OrderedDict) are evicted first.
    """
    reference = float(now if now is not None else time.time())
    expired = [
        sid
        for sid, (last_seen, _session) in _ma_sessions.items()
        if reference - float(last_seen) > float(_MA_SESSION_TTL_SEC)
    ]
    for sid in expired:
        _ma_sessions.pop(sid, None)
    while len(_ma_sessions) > int(_MA_MAX_SESSIONS):
        _ma_sessions.popitem(last=False)
def _ma_get(session_id: str) -> MultiAgentSession | None:
    """Return the live multi-agent session for *session_id*, or ``None``.

    Expired sessions are collected first. On a hit, the session is moved to the
    most-recently-used end and its last-touched time is refreshed.
    """
    _ma_gc()
    if (entry := _ma_sessions.get(session_id)) is None:
        return None
    sess = entry[1]
    _ma_sessions.move_to_end(session_id, last=True)
    _ma_sessions[session_id] = (time.time(), sess)
    return sess
class MANewRequest(BaseModel):
    """Request body for ``POST /ma/new``."""

    # Seed forwarded to EVGridCore.reset; bounded to [0, 1_000_000].
    seed: int = Field(123, ge=0, le=1_000_000)
    # Scenario forwarded to EVGridCore.reset.
    scenario: ScenarioName = Field("baseline")
    # NOTE(review): free-form string that is not validated here; it is cast and passed
    # to EVGridCore.reset as-is. Confirm that reset rejects unknown modes.
    fleet_mode: str = Field("mixed", description="Fleet persona mix: mixed|taxi|corporate|delivery|private|emergency")
@app.post("/ma/new")
def ma_new(req: Request, payload: MANewRequest = Body(...)) -> dict[str, Any]:
    """Create a multi-agent session and return its initial observation.

    The session is registered only after the response has been fully built. A
    failure while serializing therefore cannot leave an unreachable session
    occupying a slot until TTL eviction.

    Raises:
        HTTPException: 429 when rate limited; 500 (logged) on unexpected errors.
    """
    _rate_limit(req, key="ma_new", limit=30, window_sec=60)
    t0 = time.time()
    rid = _request_id(req)
    try:
        _ma_gc()
        sid = str(uuid4())
        core = EVGridCore(city_graph=_demo_graph)
        obs = core.reset(seed=payload.seed, scenario=cast(ScenarioName, payload.scenario), fleet_mode=cast(Any, payload.fleet_mode))
        sess = MultiAgentSession(core=core)
        response = {
            "request_id": rid,
            "session_id": sid,
            "obs": _obs_to_jsonable(obs),
            "station_nodes": _station_nodes(core),
            "scenario": core.scenario,
            "seed": payload.seed,
            "sim_version": _SIM_VERSION,
            "messages": [],
            "grid_directive": GridDirective().model_dump(mode="json"),
        }
        _ma_sessions[sid] = (time.time(), sess)
        log.info(
            "ma_new",
            extra={"rid": rid, "sid": sid, "seed": payload.seed, "scenario": str(core.scenario), "ms": int((time.time() - t0) * 1000)},
        )
        return response
    except HTTPException:
        raise
    except Exception as e:
        log.exception("ma_new_error", extra={"rid": rid, "ms": int((time.time() - t0) * 1000)})
        raise HTTPException(status_code=500, detail=f"ma_new_error: {type(e).__name__}: {e}") from e
def _grid_policy(st) -> tuple[GridDirective, NegotiationMessage]:
    """Deterministic grid-operator policy for the multi-agent demo.

    Tightens the grid-load budget as ``st.peak_risk`` rises:

    - 0.92 by default;
    - 0.88 at "high";
    - 0.84 at "critical".

    Blacklists the two stations with the highest slot occupancy (ties broken by
    queue length) to prevent local overload.

    Returns the directive together with a human-readable "grid" negotiation message.
    """
    budget_by_risk = {"high": 0.88, "critical": 0.84}
    peak = getattr(st, "peak_risk", None)
    max_grid = 0.92
    if peak:
        max_grid = budget_by_risk.get(str(peak.value), max_grid)

    def load_key(station):
        return (station.occupied_slots / max(1, station.total_slots), station.queue_length)

    busiest = sorted(getattr(st, "stations", []) or [], key=load_key, reverse=True)[:2]
    blacklist = [station.station_id for station in busiest]
    directive = GridDirective(max_grid_load_pct=float(max_grid), station_blacklist=blacklist, price_mult=1.0)
    avoid = ", ".join(blacklist) if blacklist else "none"
    message = NegotiationMessage(role="grid", text=f"Keep grid<= {max_grid:.2f}. Avoid {avoid}.")
    return directive, message
class MAAutoStepRequest(BaseModel):
    """Request body for ``POST /ma/auto_step``."""

    # Multi-agent session to advance (as returned by /ma/new).
    session_id: str
    # Fleet-role policy: heuristic baseline or the guarded oracle.
    fleet_policy: Literal["baseline", "oracle"] = "baseline"
    # Optional LoRA repo id for the oracle. Blank means OracleAgent runs without a LoRA repo.
    oracle_lora_repo: str = ""
@app.post("/ma/auto_step")
def ma_auto_step(req: Request, payload: MAAutoStepRequest = Body(...)) -> dict[str, Any]:
    """
    Convenience endpoint for the demo UI: server computes both roles' actions/messages
    while still using the explicit multi-agent protocol internally.

    The grid role always uses the deterministic ``_grid_policy``. The fleet role uses
    ``baseline_policy`` or the guarded oracle, depending on ``payload.fleet_policy``.

    Raises:
        HTTPException: in these cases:
            - 429 when rate limited;
            - 404 for an unknown session_id;
            - 400 when the session has no grid state;
            - 500 (logged) on unexpected errors.
    """
    _rate_limit(req, key="ma_auto_step", limit=120, window_sec=60)
    t0 = time.time()
    rid = _request_id(req)
    sess = _ma_get(payload.session_id)
    if sess is None:
        log.info("ma_auto_step_miss", extra={"rid": rid, "sid": payload.session_id, "ms": int((time.time() - t0) * 1000)})
        raise HTTPException(status_code=404, detail="Unknown session_id")
    st = sess.core._grid_state
    if st is None:
        raise HTTPException(status_code=400, detail="Session not initialized")
    try:
        directive, grid_msg = _grid_policy(st)
        if payload.fleet_policy == "baseline":
            fleet_action = baseline_policy(st, sess.core.city_graph)
            fleet_msg = NegotiationMessage(role="fleet", text="Routing using heuristic baseline under grid constraints.")
        else:
            fleet_action, _oracle_text, oracle_active, _timed_out, _skipped = _demo_oracle_act_with_guard(
                st=st, core=sess.core, oracle_lora_repo=payload.oracle_lora_repo
            )
            tag = "LLM" if oracle_active else "fallback"
            fleet_msg = NegotiationMessage(role="fleet", text=f"Routing with oracle ({tag}).")
        obs = sess.step(
            grid_directive=directive,
            fleet_action=fleet_action,
            grid_message=grid_msg,
            fleet_message=fleet_msg,
        )
        directive_ok = len(sess.last_violations) == 0
        # Both server-generated messages above always carry text, so they count as meaningful.
        role_rewards = split_role_rewards(obs.reward_breakdown, grid_directive_ok=directive_ok, has_meaningful_messages=True)
        fleet_action_json = fleet_action.model_dump(mode="json")
        resolved = sess.last_resolved_action
        response = {
            "request_id": rid,
            "session_id": payload.session_id,
            "obs": obs.model_dump(mode="json"),
            "tick": sess.core.step_count,
            "scenario": sess.core.scenario,
            "grid_directive": directive.model_dump(mode="json"),
            "fleet_action": fleet_action_json,
            "resolved_action": resolved.model_dump(mode="json") if resolved else fleet_action_json,
            "violations": list(sess.last_violations),
            "messages": [m.model_dump(mode="json") for m in sess.messages[-50:]],
            "role_rewards": role_rewards,
        }
    except HTTPException:
        raise
    except Exception as e:
        log.exception(
            "ma_auto_step_error",
            extra={"rid": rid, "sid": payload.session_id, "ms": int((time.time() - t0) * 1000)},
        )
        raise HTTPException(status_code=500, detail=f"ma_auto_step_error: {type(e).__name__}: {e}") from e
    log.info(
        "ma_auto_step",
        extra={
            "rid": rid,
            "sid": payload.session_id,
            "policy": payload.fleet_policy,
            "tick": int(sess.core.step_count),
            "viol": ",".join(sess.last_violations),
            "ms": int((time.time() - t0) * 1000),
        },
    )
    return response
@app.get("/ma/state")
def ma_state(req: Request, session_id: str = Query(...)) -> dict[str, Any]:
    """Return a snapshot of a multi-agent session without stepping it.

    Includes the observation, station metadata and the last 50 negotiation
    messages. Also includes the last grid directive, the violations and the
    resolved action.

    Raises:
        HTTPException: 429 when rate limited; 404 for an unknown or expired session_id.
    """
    _rate_limit(req, key="ma_state", limit=120, window_sec=60)
    t0 = time.time()
    sess = _ma_get(session_id)
    if sess is None:
        log.info("ma_state_miss", extra={"sid": session_id, "ms": int((time.time() - t0) * 1000)})
        raise HTTPException(status_code=404, detail="Unknown session_id")
    core = sess.core
    st = core._grid_state
    if st is None:
        # NOTE(review): despite being a GET, this resets the core when it has no grid state.
        # It uses a fixed seed (123) and passes no fleet_mode. Confirm this is intended.
        obs = core.reset(seed=123, scenario=core.scenario)
    else:
        # Re-wrap the current state. Reward and anti-cheat fields are empty because
        # no step is taken here.
        obs = EVGridObservation(
            prompt=_build_prompt(st),
            state=st,
            done=False,
            reward_breakdown={},
            anti_cheat_flags=[],
            anti_cheat_details={},
        )
    return {
        "session_id": session_id,
        "obs": _obs_to_jsonable(obs),
        "station_nodes": _station_nodes(core),
        "scenario": core.scenario,
        "tick": core.step_count,
        "messages": [m.model_dump(mode="json") for m in sess.messages[-50:]],
        "grid_directive": sess.last_directive.model_dump(mode="json"),
        "violations": list(sess.last_violations),
        "resolved_action": sess.last_resolved_action.model_dump(mode="json") if sess.last_resolved_action else None,
    }
@app.post("/ma/step")
def ma_step(req: Request, payload: MultiAgentStepRequest = Body(...)) -> dict[str, Any]:
    """Advance a multi-agent session one tick using client-supplied actions for both roles.

    The client provides the grid directive, the fleet action and optional
    negotiation messages; each message's role must match its slot. Role rewards
    are split from the observation's reward breakdown using two flags:

    - whether the step produced no directive violations;
    - whether either message has non-blank text.

    Raises:
        HTTPException: 429 when rate limited; 404 for an unknown session_id; 400 when
            a message carries the wrong role.
    """
    _rate_limit(req, key="ma_step", limit=120, window_sec=60)
    t0 = time.time()
    rid = _request_id(req)
    sess = _ma_get(payload.session_id)
    if sess is None:
        log.info("ma_step_miss", extra={"rid": rid, "sid": payload.session_id, "ms": int((time.time() - t0) * 1000)})
        raise HTTPException(status_code=404, detail="Unknown session_id")
    gm = payload.grid_message
    fm = payload.fleet_message
    # Each optional message must be authored by the role whose slot it occupies.
    if gm is not None and gm.role != "grid":
        raise HTTPException(status_code=400, detail="grid_message.role must be 'grid'")
    if fm is not None and fm.role != "fleet":
        raise HTTPException(status_code=400, detail="fleet_message.role must be 'fleet'")
    # NOTE(review): unlike ma_new, failures inside sess.step are not caught or logged here.
    # They surface as FastAPI's default 500.
    obs = sess.step(
        grid_directive=payload.grid_directive,
        fleet_action=payload.fleet_action,
        grid_message=gm,
        fleet_message=fm,
    )
    out_obs = obs.model_dump(mode="json")
    directive_ok = len(sess.last_violations) == 0
    # A message counts as meaningful when it has non-blank text.
    meaningful = (gm is not None and gm.text.strip() != "") or (fm is not None and fm.text.strip() != "")
    role_rewards = split_role_rewards(
        # model_dump returns a dict, so the isinstance guard is purely defensive.
        out_obs.get("reward_breakdown", {}) if isinstance(out_obs, dict) else {},
        grid_directive_ok=directive_ok,
        has_meaningful_messages=meaningful,
    )
    log.info(
        "ma_step",
        extra={
            "rid": rid,
            "sid": payload.session_id,
            "tick": int(sess.core.step_count),
            "viol": ",".join(sess.last_violations),
            "ms": int((time.time() - t0) * 1000),
        },
    )
    return {
        "request_id": rid,
        "session_id": payload.session_id,
        "obs": out_obs,
        "tick": sess.core.step_count,
        "scenario": sess.core.scenario,
        "grid_directive": payload.grid_directive.model_dump(mode="json"),
        "fleet_action": payload.fleet_action.model_dump(mode="json"),
        # Fall back to the submitted fleet action when the session recorded no resolved action.
        "resolved_action": sess.last_resolved_action.model_dump(mode="json") if sess.last_resolved_action else payload.fleet_action.model_dump(mode="json"),
        "violations": list(sess.last_violations),
        "messages": [m.model_dump(mode="json") for m in sess.messages[-50:]],
        "role_rewards": role_rewards,
    }
def _obs_to_jsonable(obs: EVGridObservation) -> dict[str, Any]:
# Pydantic v2 BaseModel: use model_dump for JSONable dicts
return obs.model_dump()
def _station_nodes(core: EVGridCore) -> list[dict[str, Any]]:
st = core._grid_state
if st is None:
return []
return [
{
"station_id": s.station_id,
"name": s.neighborhood_name,
"slug": s.neighborhood_slug,
"lat": s.lat,
"lng": s.lng,
"total_slots": s.total_slots,
}
for s in st.stations
]
@app.post("/demo/new")
def demo_new(req: Request, payload: DemoNewRequest = Body(...)) -> dict[str, Any]:
    """Create a demo session; return its initial observation, map metadata and scenario schedule.

    The session is registered only after the response has been fully built. A
    failure while serializing therefore cannot leave an unreachable session
    occupying a slot until TTL eviction.

    Raises:
        HTTPException: 429 when rate limited; 500 (logged) on unexpected errors.
    """
    _rate_limit(req, key="demo_new", limit=30, window_sec=60)
    t0 = time.time()
    rid = _request_id(req)
    try:
        from ev_grid_oracle.scenarios import scenario_schedule

        _demo_session_gc()
        session_id = str(uuid4())
        core = EVGridCore(city_graph=_demo_graph)
        obs = core.reset(seed=payload.seed, scenario=cast(ScenarioName, payload.scenario), fleet_mode=cast(Any, payload.fleet_mode))
        response = {
            "request_id": rid,
            "session_id": session_id,
            "obs": _obs_to_jsonable(obs),
            "station_nodes": _station_nodes(core),
            "scenario": core.scenario,
            "seed": payload.seed,
            "sim_version": _SIM_VERSION,
            "scenario_schedule": scenario_schedule(core.scenario),
        }
        _demo_sessions[session_id] = (time.time(), core)
        log.info(
            "demo_new",
            extra={"rid": rid, "sid": session_id, "seed": payload.seed, "scenario": str(core.scenario), "ms": int((time.time() - t0) * 1000)},
        )
        return response
    except HTTPException:
        raise
    except Exception as e:
        log.exception("demo_new_error", extra={"rid": rid, "ms": int((time.time() - t0) * 1000)})
        raise HTTPException(status_code=500, detail=f"demo_new_error: {type(e).__name__}: {e}") from e
@app.get("/demo/state")
def demo_state(req: Request, session_id: str = Query(...)) -> dict[str, Any]:
    """Return a snapshot of a demo session's current state without stepping it.

    Raises:
        HTTPException: 429 when rate limited; 404 for an unknown or expired session_id.
    """
    _rate_limit(req, key="demo_state", limit=120, window_sec=60)
    t0 = time.time()
    rid = _request_id(req)
    core = _demo_session_get(session_id)
    if core is None:
        log.info("demo_state_miss", extra={"rid": rid, "sid": session_id, "ms": int((time.time()-t0)*1000)})
        raise HTTPException(status_code=404, detail="Unknown session_id")
    st = core._grid_state
    if st is None:
        # NOTE(review): despite being a GET, this resets the core when it has no grid state.
        # It uses a fixed seed (123) and passes no fleet_mode. Confirm this is intended.
        obs = core.reset(seed=123, scenario=core.scenario)
    else:
        # Re-wrap the current state. Reward and anti-cheat fields are empty because
        # no step is taken here.
        obs = EVGridObservation(
            prompt=_build_prompt(st),
            state=st,
            done=False,
            reward_breakdown={},
            anti_cheat_flags=[],
            anti_cheat_details={},
        )
    from ev_grid_oracle.scenarios import scenario_schedule
    log.info("demo_state", extra={"rid": rid, "sid": session_id, "tick": int(core.step_count), "ms": int((time.time()-t0)*1000)})
    return {
        "request_id": rid,
        "session_id": session_id,
        "obs": _obs_to_jsonable(obs),
        "station_nodes": _station_nodes(core),
        "scenario": core.scenario,
        "sim_version": _SIM_VERSION,
        "scenario_schedule": scenario_schedule(core.scenario),
    }
class DemoSpawnVehicleRequest(BaseModel):
    """Request body for ``POST /demo/spawn_vehicle``."""

    # Demo session to spawn into (as returned by /demo/new).
    session_id: str
    # Minimum haversine distance, in metres, between the spawn point and any station.
    min_station_dist_m: float = Field(250.0, ge=50.0, le=3000.0)
    # Reference battery level (percent) from which the endpoint derives the spawned
    # EV's low battery. The derived value is capped at 25%.
    battery_threshold_pct: float = Field(30.0, ge=1.0, le=80.0)
@app.post("/demo/spawn_vehicle")
def demo_spawn_vehicle(req: Request, payload: DemoSpawnVehicleRequest = Body(...)) -> dict[str, Any]:
    """
    Spawn a new EV at a valid road location (away from stations) and immediately compute
    an assignment + route event for the frontend.

    The new EV is inserted at the head of the session's pending queue. The response
    ``event`` is one of:

    - a ``route`` event: OSM polyline, with a straight-line fallback;
    - a ``no_station`` event explaining why no route was produced.

    Route assignments always target a station with a free slot.

    Raises:
        HTTPException: in these cases:
            - 429 when rate limited;
            - 404 for an unknown session_id;
            - 400 when the session has no grid state;
            - 409 when no valid spawn point could be found.
    """
    _rate_limit(req, key="demo_spawn_vehicle", limit=80, window_sec=60)
    t0 = time.time()
    rid = _request_id(req)
    core = _demo_session_get(payload.session_id)
    if core is None:
        log.info("demo_spawn_vehicle_miss", extra={"rid": rid, "sid": payload.session_id, "ms": int((time.time()-t0)*1000)})
        raise HTTPException(status_code=404, detail="Unknown session_id")
    st = core._grid_state
    if st is None:
        raise HTTPException(status_code=400, detail="Session not initialized")
    # The EV id is random, so the seed key, and with it the spawn point, differs on every
    # call. The key only fixes the probe order for this particular id.
    ev_id = f"SPAWN-{uuid4().hex[:10]}"
    seed_key = f"{core._seed_for_bescom}|{core.scenario}|{core.step_count}|{len(st.pending_evs)}|{ev_id}"
    try:
        lat, lng = _spawn_road_point_away_from_stations(
            core=core,
            min_station_dist_m=float(payload.min_station_dist_m),
            seed_key=seed_key,
        )
    except ValueError as e:
        raise HTTPException(status_code=409, detail=str(e)) from e

    def dist_to(station: Any) -> float:
        return haversine_m(lat, lng, float(station.lat), float(station.lng))

    # Pick nearest station just to fill neighborhood fields (used by policies + prompt).
    nearest = min(st.stations, key=dist_to)
    # Force a low battery so the EV needs charging:
    # - threshold - 1, capped at 25%;
    # - half the threshold when that would fall below 2%;
    # - never below 1%.
    battery = min(float(payload.battery_threshold_pct) - 1.0, 25.0)
    if battery < 2.0:
        battery = float(payload.battery_threshold_pct) * 0.5
    battery = max(1.0, battery)
    spawned = EVRequest(
        ev_id=ev_id,
        battery_pct_0_100=round(float(battery), 1),
        urgency=round(0.9 if battery < 15.0 else 0.7, 2),
        persona="PrivateOwner",
        price_sensitivity=0.35,
        neighborhood_slug=str(nearest.neighborhood_slug),
        neighborhood_name=str(nearest.neighborhood_name),
        target_charge_pct_0_100=90.0,
        max_wait_minutes=30,
    )
    st.pending_evs.insert(0, spawned)
    # Assignment: re-use the baseline scoring (distance + wait + stress + price),
    # and require capacity (avoid full stations). If none, respond gracefully.
    candidates = [s for s in st.stations if int(s.occupied_slots) < int(s.total_slots)]
    if not candidates:
        return {
            "request_id": rid,
            "session_id": payload.session_id,
            "spawned_ev": spawned.model_dump(mode="json"),
            "assignment": None,
            "event": {"type": "no_station", "reason": "all_stations_full"},
            "ms": int((time.time() - t0) * 1000),
        }

    def route_to_nearest_open() -> tuple[EVGridAction, Any]:
        # Route the spawned EV to the closest station that still has a free slot.
        best = min(candidates, key=dist_to)
        action = EVGridAction(action_type=ActionType.route, ev_id=ev_id, station_id=str(best.station_id), defer_minutes=0)
        return action, best

    try:
        action = baseline_policy(st, core.city_graph)
    except Exception:
        # Last resort: pick the nearest non-full station.
        log.warning("demo_spawn_vehicle_policy_error", exc_info=True, extra={"rid": rid, "sid": payload.session_id})
        action, dst = route_to_nearest_open()
    else:
        assigned_station_id = getattr(action, "station_id", None)
        dst = next((s for s in st.stations if assigned_station_id and s.station_id == assigned_station_id), None)
        # NOTE(review): the policy's action may target a different pending EV than the
        # spawned one. Confirm baseline_policy acts on the head of pending_evs.
        if action.action_type == ActionType.route and dst is not None and int(dst.occupied_slots) >= int(dst.total_slots):
            # Enforce the capacity requirement for policy-chosen routes as well.
            action, dst = route_to_nearest_open()
    if action.action_type != ActionType.route or dst is None:
        return {
            "request_id": rid,
            "session_id": payload.session_id,
            "spawned_ev": spawned.model_dump(mode="json"),
            "assignment": action.model_dump(mode="json"),
            "event": {"type": "no_station", "reason": "policy_defer_or_invalid"},
            "ms": int((time.time() - t0) * 1000),
        }
    traffic = TrafficModel(seed=int(core._seed_for_bescom), scenario=str(core.scenario))
    routed = _osm_route_polyline(src_lat=lat, src_lng=lng, dst_lat=float(dst.lat), dst_lng=float(dst.lng), traffic=traffic, tick=int(core.step_count))
    poly, seg_m_q = routed if routed is not None else ([], None)
    event = {
        "type": "route",
        "ev_id": ev_id,
        "from": {"station_id": "ROAD", "lat": lat, "lng": lng},
        "to": {"station_id": dst.station_id, "lat": dst.lat, "lng": dst.lng},
        # Straight segment when the OSM router produced no polyline.
        "polyline": (poly or [[lat, lng], [float(dst.lat), float(dst.lng)]]),
        "traffic_seg_m_q": seg_m_q,
        "reroute_reason": "spawn",
    }
    log.info(
        "demo_spawn_vehicle",
        extra={"rid": rid, "sid": payload.session_id, "ev_id": ev_id, "to": str(dst.station_id), "ms": int((time.time() - t0) * 1000)},
    )
    return {
        "request_id": rid,
        "session_id": payload.session_id,
        "spawned_ev": spawned.model_dump(mode="json"),
        "assignment": action.model_dump(mode="json"),
        "event": event,
        "ms": int((time.time() - t0) * 1000),
    }
@app.post("/demo/step")
def demo_step(
    req: Request,
    session_id: str = Body(...),
    mode: Literal["baseline", "oracle"] = Body("baseline"),
    oracle_lora_repo: str = Body("", embed=True),
    forced_action: dict[str, Any] | None = Body(None),
) -> dict[str, Any]:
    """Advance a demo session by one tick; return the observation plus UI/diagnostic metadata.

    The action is chosen in this priority order:

    1. ``forced_action`` (validated as an EVGridAction, 422 on failure), used for replays.
    2. A fixed ``load_shift`` action for "EV-000" when there is no grid state or no
       pending EV.
    3. ``baseline_policy`` (mode="baseline") or the guarded oracle (mode="oracle").
       In oracle mode, a parsable simulation block in the oracle text is scored
       against a deterministic 5-tick rollout (the ``dream_*`` fields).

    The animation ``event`` is computed from the pre-step state. When no route event
    results, a deterministic "ambient" trip between two stations is emitted instead.

    Raises:
        HTTPException: in these cases:
            - 429 when rate limited;
            - 404 for an unknown session_id;
            - 422 for an invalid forced_action;
            - 500 (logged) on unexpected errors.
    """
    _rate_limit(req, key="demo_step", limit=120, window_sec=60)
    t0 = time.time()
    rid = _request_id(req)
    core = _demo_session_get(session_id)
    if core is None:
        log.info("demo_step_miss", extra={"rid": rid, "sid": session_id, "ms": int((time.time()-t0)*1000)})
        raise HTTPException(status_code=404, detail="Unknown session_id")
    try:
        # Pre-step grid state; every event/dream computation below reads this snapshot.
        st = core._grid_state
        # Response defaults. Only the oracle branch overwrites the oracle/dream fields.
        oracle_llm_active = False
        oracle_text = ""
        oracle_timed_out = False
        oracle_skipped_env = False
        dream_score = None
        dream_breakdown: dict[str, float] = {}
        dream_pred = None
        dream_true = None
        event: dict[str, Any] = {"type": "noop"}
        forced = forced_action is not None
        if forced_action is not None:
            # Replay path: validate the client-supplied action instead of running a policy.
            try:
                action = EVGridAction.model_validate(forced_action)
            except Exception as ve:
                issues = ve.errors() if hasattr(ve, "errors") else [{"msg": str(ve)}]
                # Pydantic v2 can include non-JSON-serializable objects under `ctx` (e.g., ValueError instances).
                for it in issues:
                    if isinstance(it, dict) and "ctx" in it:
                        try:
                            ctx = it.get("ctx") or {}
                            if isinstance(ctx, dict):
                                it["ctx"] = {str(k): str(v) for k, v in ctx.items()}
                            else:
                                it["ctx"] = str(ctx)
                        except Exception:
                            it.pop("ctx", None)
                raise HTTPException(status_code=422, detail={"error": "invalid_forced_action", "issues": issues})
            # These repeat the defaults set above; kept explicit for the replay path.
            oracle_llm_active = False
            oracle_text = ""
            dream_score = None
            dream_breakdown = {}
            dream_pred = None
            dream_true = None
            # Keep animation useful even when replaying stored actions.
            if st is not None:
                # Prefer the EV the action names; fall back to the head of the pending queue.
                ev = next((e for e in st.pending_evs if e.ev_id == action.ev_id), st.pending_evs[0] if st.pending_evs else None)
                src = next((x for x in st.stations if ev is not None and x.neighborhood_slug == ev.neighborhood_slug), None)
                dst = next((x for x in st.stations if action.station_id and x.station_id == action.station_id), None)
                if action.action_type == ActionType.route and ev is not None and src is not None and dst is not None:
                    traffic = TrafficModel(seed=int(core._seed_for_bescom), scenario=str(core.scenario))
                    routed = _osm_route_polyline(
                        src_lat=float(src.lat),
                        src_lng=float(src.lng),
                        dst_lat=float(dst.lat),
                        dst_lng=float(dst.lng),
                        traffic=traffic,
                        tick=int(core.step_count),
                    )
                    poly, seg_m_q = routed if routed is not None else ([], None)
                    event = {
                        "type": "route",
                        "ev_id": ev.ev_id,
                        "from": {"station_id": src.station_id, "lat": src.lat, "lng": src.lng},
                        "to": {"station_id": dst.station_id, "lat": dst.lat, "lng": dst.lng},
                        # OSM road polyline when available, otherwise the station-graph path.
                        "polyline": (poly or _graph_route_polyline(core, src_station_id=src.station_id, dst_station_id=dst.station_id)),
                        "traffic_seg_m_q": seg_m_q,
                        "reroute_reason": "periodic" if (int(core.step_count) % 6 == 0) else None,
                    }
                else:
                    event = {"type": "forced_action", "action_type": str(action.action_type.value)}
            else:
                event = {"type": "forced_action"}
        elif st is None or not st.pending_evs:
            # Nothing to dispatch: submit a fixed load_shift action (ev_id "EV-000").
            action = EVGridAction(action_type=ActionType.load_shift, ev_id="EV-000", defer_minutes=0)
            event = {"type": "idle"}
        else:
            if mode == "baseline":
                action = baseline_policy(st, core.city_graph)
            else:
                action, oracle_text, oracle_llm_active, oracle_timed_out, oracle_skipped_env = _demo_oracle_act_with_guard(
                    st=st, core=core, oracle_lora_repo=oracle_lora_repo
                )
                # If oracle produced a block, score it against a deterministic T+5 rollout.
                pred = parse_simulation(oracle_text) if oracle_text else None
                if pred is not None:
                    ps = score_prediction(st, action, pred)
                    dream_score = ps.score_0_1
                    dream_breakdown = ps.breakdown
                    dream_pred = pred.model_dump(mode="json")
                    t5 = rollout_deterministic_5ticks(st, action)
                    # summarize true top3
                    # (ranked by slot occupancy ratio, highest first)
                    top3 = sorted(
                        [
                            (
                                s.station_id,
                                s.occupied_slots / max(1, s.total_slots),
                                s.queue_length,
                            )
                            for s in t5.stations
                        ],
                        key=lambda x: x[1],
                        reverse=True,
                    )[:3]
                    dream_true = {
                        "t5_grid_load_pct": float(t5.grid_load_pct),
                        "t5_renewable_pct": float(t5.renewable_pct),
                        "t5_top_stations": [
                            {"station_id": sid, "load_pct": float(load), "queue": int(q)} for sid, load, q in top3
                        ],
                    }
        # Render-friendly event for frontend animation (skip overwriting if replaying forced_action).
        # v0: polyline path is station-to-station graph path (lat/lng pairs).
        if not forced and st is not None and st.pending_evs:
            # NOTE(review): this branch animates pending_evs[0], while the replay branch
            # matches action.ev_id. Confirm the policies always act on the head of the queue.
            ev = st.pending_evs[0]
            src = next((x for x in st.stations if x.neighborhood_slug == ev.neighborhood_slug), None)
            dst = next((x for x in st.stations if x.station_id == action.station_id), None)
            if action.action_type == ActionType.route and src is not None and dst is not None:
                traffic = TrafficModel(seed=int(core._seed_for_bescom), scenario=str(core.scenario))
                routed = _osm_route_polyline(
                    src_lat=float(src.lat),
                    src_lng=float(src.lng),
                    dst_lat=float(dst.lat),
                    dst_lng=float(dst.lng),
                    traffic=traffic,
                    tick=int(core.step_count),
                )
                poly, seg_m_q = routed if routed is not None else ([], None)
                event = {
                    "type": "route",
                    "ev_id": ev.ev_id,
                    "from": {"station_id": src.station_id, "lat": src.lat, "lng": src.lng},
                    "to": {"station_id": dst.station_id, "lat": dst.lat, "lng": dst.lng},
                    "polyline": (poly or _graph_route_polyline(core, src_station_id=src.station_id, dst_station_id=dst.station_id)),
                    "traffic_seg_m_q": seg_m_q,
                    "reroute_reason": "periodic" if (int(core.step_count) % 6 == 0) else None,
                }
            else:
                event = {"type": action.action_type.value}
        # Ensure the map always looks alive: if action isn't a route, emit a deterministic
        # ambient trip for UI motion (does not affect env dynamics or rewards).
        if (event.get("type") != "route") and st is not None and len(st.stations) >= 2:
            tick_i = int(core.step_count)
            seed_i = int(core._seed_for_bescom)
            scen = str(core.scenario)
            # Station pair derived from SHA-1 of (seed, scenario, mode, tick): stable across replays.
            h = hashlib.sha1(f"{seed_i}|{scen}|{mode}|ambient|{tick_i}".encode("utf-8")).digest()
            a_i = int.from_bytes(h[:2], "big") % len(st.stations)
            b_i = int.from_bytes(h[2:4], "big") % len(st.stations)
            if b_i == a_i:
                b_i = (b_i + 1) % len(st.stations)
            src2 = st.stations[a_i]
            dst2 = st.stations[b_i]
            traffic2 = TrafficModel(seed=seed_i, scenario=scen)
            routed2 = _osm_route_polyline(
                src_lat=float(src2.lat),
                src_lng=float(src2.lng),
                dst_lat=float(dst2.lat),
                dst_lng=float(dst2.lng),
                traffic=traffic2,
                tick=tick_i,
            )
            poly2, seg2 = routed2 if routed2 is not None else ([], None)
            event = {
                "type": "route",
                "ev_id": f"AMBIENT-{mode}-{a_i}-{b_i}",
                "from": {"station_id": src2.station_id, "lat": src2.lat, "lng": src2.lng},
                "to": {"station_id": dst2.station_id, "lat": dst2.lat, "lng": dst2.lng},
                "polyline": (poly2 or _graph_route_polyline(core, src_station_id=src2.station_id, dst_station_id=dst2.station_id)),
                "traffic_seg_m_q": seg2,
                "reroute_reason": "ambient",
            }
        # Only now advance the environment; everything above used the pre-step state.
        obs = core.step(action)
        anti_flags = obs.anti_cheat_flags
        anti_details = obs.anti_cheat_details
        role_kpis = compute_role_kpis(obs)
        role_reward_breakdown = compute_role_reward_breakdown(obs)
        out = {
            "request_id": rid,
            "obs": _obs_to_jsonable(obs),
            "event": event,
            "scenario": core.scenario,
            "scenario_events_at_tick": core.last_scenario_events,
            "tick": core.step_count,
            "tick_dt_s": float(core.step_minutes) * 60.0,
            "schema_version": "traffic-v1",
            "sim_version": _SIM_VERSION,
            "anti_cheat_flags": anti_flags,
            "anti_cheat_details": anti_details,
            "role_kpis": role_kpis,
            "role_reward_breakdown": role_reward_breakdown,
            "mode": mode,
            "oracle_lora_repo": (oracle_lora_repo or "").strip(),
            "oracle_llm_active": oracle_llm_active,
            "oracle_timed_out": oracle_timed_out,
            "oracle_skipped_env": oracle_skipped_env,
            "action": summarize_action(action),
            "oracle_text": oracle_text,
            "dream_score": dream_score,
            "dream_breakdown": dream_breakdown,
            "dream_pred": dream_pred,
            "dream_true": dream_true,
            "forced_action": forced_action is not None,
        }
        log.info(
            "demo_step",
            extra={
                "rid": rid,
                "sid": session_id,
                "mode": mode,
                "tick": int(core.step_count),
                "oracle_active": bool(oracle_llm_active),
                "oracle_timeout": bool(oracle_timed_out),
                "oracle_skipped": bool(oracle_skipped_env),
                "forced": bool(forced_action is not None),
                "ms": int((time.time() - t0) * 1000),
            },
        )
        return out
    except HTTPException:
        raise
    except Exception as e:
        log.exception("demo_step_error", extra={"rid": rid, "sid": session_id, "mode": mode, "ms": int((time.time() - t0) * 1000)})
        raise HTTPException(status_code=500, detail=f"demo_step_error: {type(e).__name__}: {e}")
def main(host: str = "0.0.0.0", port: int = 8000):
    """Serve ``app`` with uvicorn on *host*:*port*; blocks until the server stops."""
    from uvicorn import run as serve

    serve(app, host=host, port=port)


if __name__ == "__main__":
    main()