"""Session-based REST router for browser-driven episodes.""" from __future__ import annotations import logging import threading import time import uuid from typing import Optional import numpy as np from fastapi import APIRouter, HTTPException from pydantic import BaseModel, ConfigDict, Field from physix.models import ( DEFAULT_MAX_TURNS, PhysiXAction, PhysiXObservation, ) from physix.server.environment import PhysiXEnvironment from physix.server.providers import ( LlmModelInfo, LlmModelsLister, LlmModelsResponse, LlmPolicyFactory, LlmStepRequest, default_ollama_models_lister, default_openai_compat_policy_factory, ) from physix.systems import list_supported_systems, list_systems from physix.systems.base import PhysicalSystem, TrajectoryData from physix.training.prompt import build_prompt, parse_completion from physix.verifier.parser import parse_equation from physix.verifier.simulator import simulate_hypothesis __all__ = [ "InteractiveSessionStore", "LlmModelInfo", "LlmModelsResponse", "LlmStepRequest", "LlmStepResponse", "build_interactive_router", ] _log = logging.getLogger(__name__) class InteractiveResetRequest(BaseModel): model_config = ConfigDict(extra="forbid") system_id: Optional[str] = Field( default=None, description="Force a specific system. None = sample at random.", ) seed: Optional[int] = None max_turns: int = Field(default=DEFAULT_MAX_TURNS, ge=1, le=32) class SystemDescriptor(BaseModel): model_config = ConfigDict(frozen=True) system_id: str state_variables: tuple[str, ...] class InteractiveStartResponse(BaseModel): session_id: str observation: PhysiXObservation system: SystemDescriptor max_turns: int class LlmStepResponse(BaseModel): observation: PhysiXObservation predicted_trajectory: list[dict[str, float]] = Field(default_factory=list) action: PhysiXAction raw_completion: str latency_s: float model: str class DirectStepRequest(BaseModel): """Submit an action directly, without going through an LLM. Used by the OpenEnv tab in the demo so visitors can drive the env by hand and see exactly what each call returns. """ model_config = ConfigDict(extra="forbid") action: PhysiXAction class DirectStepResponse(BaseModel): observation: PhysiXObservation predicted_trajectory: list[dict[str, float]] = Field(default_factory=list) class SessionSummary(BaseModel): session_id: str system_id: str turn: int max_turns: int converged: bool done: bool class _Session: __slots__ = ("env", "system_id", "max_turns", "lock") def __init__(self, env: PhysiXEnvironment, system_id: str, max_turns: int) -> None: self.env = env self.system_id = system_id self.max_turns = max_turns self.lock = threading.Lock() class InteractiveSessionStore: """Threadsafe in-memory session map.""" def __init__(self) -> None: self._sessions: dict[str, _Session] = {} self._lock = threading.Lock() def create( self, *, system_id: Optional[str], seed: Optional[int], max_turns: int, ) -> tuple[str, _Session, PhysiXObservation]: env = PhysiXEnvironment(seed=seed, max_turns=max_turns) observation = env.reset(seed=seed, system_id=system_id) session = _Session(env=env, system_id=env.state.system_id, max_turns=max_turns) session_id = uuid.uuid4().hex with self._lock: self._sessions[session_id] = session return session_id, session, observation def get(self, session_id: str) -> _Session: with self._lock: session = self._sessions.get(session_id) if session is None: raise HTTPException(status_code=404, detail="Unknown session_id.") return session def delete(self, session_id: str) -> None: with self._lock: self._sessions.pop(session_id, None) def __len__(self) -> int: with self._lock: return len(self._sessions) def build_interactive_router( store: Optional[InteractiveSessionStore] = None, *, policy_factory: LlmPolicyFactory = default_openai_compat_policy_factory, models_lister: LlmModelsLister = default_ollama_models_lister, ) -> APIRouter: sessions = store if store is not None else InteractiveSessionStore() router = APIRouter(prefix="/interactive", tags=["Interactive"]) @router.get("/models", response_model=LlmModelsResponse) def list_local_models() -> LlmModelsResponse: return models_lister() @router.get("/systems", response_model=list[SystemDescriptor]) def list_public_systems() -> list[SystemDescriptor]: from physix.systems import get_system out: list[SystemDescriptor] = [] for system_id in list_supported_systems(): system = get_system(system_id) out.append( SystemDescriptor( system_id=system.system_id, state_variables=system.state_variables, ) ) return out @router.post("/sessions", response_model=InteractiveStartResponse) def start_session(payload: InteractiveResetRequest) -> InteractiveStartResponse: from physix.systems import get_system if payload.system_id is not None and payload.system_id not in list_systems(): raise HTTPException( status_code=400, detail=f"Unknown system_id {payload.system_id!r}." ) chosen_system_id = payload.system_id if chosen_system_id is None: demo_ids = list_supported_systems() if demo_ids: rng = ( np.random.default_rng(payload.seed) if payload.seed is not None else np.random.default_rng() ) chosen_system_id = str(rng.choice(demo_ids)) session_id, session, observation = sessions.create( system_id=chosen_system_id, seed=payload.seed, max_turns=payload.max_turns, ) system = get_system(session.system_id) return InteractiveStartResponse( session_id=session_id, observation=observation, system=SystemDescriptor( system_id=system.system_id, state_variables=system.state_variables, ), max_turns=session.max_turns, ) @router.post( "/sessions/{session_id}/llm-step", response_model=LlmStepResponse ) def llm_step_session( session_id: str, payload: LlmStepRequest ) -> LlmStepResponse: session = sessions.get(session_id) with session.lock: _ensure_budget(session) current_obs = session.env.current_observation() if current_obs is None: raise HTTPException( status_code=500, detail="Session has no current observation." ) policy = policy_factory(payload) t0 = time.perf_counter() raw_completion = policy(build_prompt(current_obs)) latency_s = time.perf_counter() - t0 action = parse_completion(raw_completion) observation = session.env.step(action) predicted = _safe_predict(session.env, action) return LlmStepResponse( observation=observation, predicted_trajectory=predicted, action=action, raw_completion=raw_completion, latency_s=latency_s, model=payload.model, ) @router.post( "/sessions/{session_id}/step", response_model=DirectStepResponse ) def direct_step_session( session_id: str, payload: DirectStepRequest ) -> DirectStepResponse: """Apply a user-supplied action without invoking an LLM. Mirrors the standard OpenEnv ``POST /step`` semantics but bound to a per-browser session so subsequent calls share state. The OpenEnv tab in the demo uses this to let visitors drive the env by hand. """ session = sessions.get(session_id) with session.lock: _ensure_budget(session) observation = session.env.step(payload.action) predicted = _safe_predict(session.env, payload.action) return DirectStepResponse( observation=observation, predicted_trajectory=predicted ) @router.delete("/sessions/{session_id}", status_code=204) def end_session(session_id: str) -> None: sessions.delete(session_id) @router.get("/sessions/{session_id}", response_model=SessionSummary) def get_session(session_id: str) -> SessionSummary: session = sessions.get(session_id) return SessionSummary( session_id=session_id, system_id=session.system_id, turn=session.env.state.step_count, max_turns=session.max_turns, converged=session.env.state.converged, done=( session.env.state.converged or session.env.state.step_count >= session.max_turns ), ) return router def _ensure_budget(session: _Session) -> None: if session.env.state.step_count >= session.max_turns: raise HTTPException( status_code=409, detail="Episode budget already exhausted; start a new session.", ) def _safe_predict( env: PhysiXEnvironment, action: PhysiXAction ) -> list[dict[str, float]]: """Forward-simulate the user's hypothesis for the UI overlay. Returns ``[]`` on parse / simulation failure — the env's reward is authoritative; this is best-effort visualisation only. """ system: Optional[PhysicalSystem] = env.current_system trajectory: Optional[TrajectoryData] = env.current_trajectory if system is None or trajectory is None: return [] parameter_names = frozenset(action.params or {}) | frozenset(system.parameters) try: parsed = parse_equation( action.equation, state_variables=system.state_variables, parameter_names=parameter_names, ) except Exception as exc: # noqa: BLE001 _log.debug("predict parse failed: %s", exc) return [] merged = {**system.parameters, **(action.params or {})} try: predicted = simulate_hypothesis( parsed, state_variables=system.state_variables, parameters=merged, initial_conditions=trajectory.initial_conditions, timestamps=trajectory.timestamps, ) except Exception as exc: # noqa: BLE001 _log.debug("predict simulate failed: %s", exc) return [] samples: list[dict[str, float]] = [] for i, t in enumerate(trajectory.timestamps): sample: dict[str, float] = {"t": round(float(t), 5)} for var in system.state_variables: value = predicted[var][i] if not np.isfinite(value): return [] sample[var] = round(float(value), 5) samples.append(sample) return samples