Spaces:
Sleeping
Sleeping
| """Session-based REST router for browser-driven episodes.""" | |
| from __future__ import annotations | |
| import logging | |
| import threading | |
| import time | |
| import uuid | |
| from typing import Optional | |
| import numpy as np | |
| from fastapi import APIRouter, HTTPException | |
| from pydantic import BaseModel, ConfigDict, Field | |
| from physix.models import ( | |
| DEFAULT_MAX_TURNS, | |
| PhysiXAction, | |
| PhysiXObservation, | |
| ) | |
| from physix.server.environment import PhysiXEnvironment | |
| from physix.server.providers import ( | |
| LlmModelInfo, | |
| LlmModelsLister, | |
| LlmModelsResponse, | |
| LlmPolicyFactory, | |
| LlmStepRequest, | |
| default_ollama_models_lister, | |
| default_openai_compat_policy_factory, | |
| ) | |
| from physix.systems import list_supported_systems, list_systems | |
| from physix.systems.base import PhysicalSystem, TrajectoryData | |
| from physix.training.prompt import build_prompt, parse_completion | |
| from physix.verifier.parser import parse_equation | |
| from physix.verifier.simulator import simulate_hypothesis | |
# Public surface of this module. ``LlmModelInfo``, ``LlmModelsResponse`` and
# ``LlmStepRequest`` are re-exported from ``physix.server.providers``; the
# remaining names are defined below.
__all__ = [
    "InteractiveSessionStore",
    "LlmModelInfo",
    "LlmModelsResponse",
    "LlmStepRequest",
    "LlmStepResponse",
    "build_interactive_router",
]

# Module logger; in this module it only emits debug messages from
# ``_safe_predict`` when the prediction overlay cannot be computed.
_log = logging.getLogger(name=__name__)
class InteractiveResetRequest(BaseModel):
    """Request body for opening a new interactive session.

    Unknown keys are rejected.
    """

    model_config = ConfigDict(extra="forbid")

    # Explicit system to load; omitted means one is sampled at random.
    system_id: Optional[str] = Field(
        None,
        description="Force a specific system. None = sample at random.",
    )
    # Seed forwarded to the environment and to the router's random system
    # choice when ``system_id`` is omitted.
    seed: Optional[int] = Field(default=None)
    # Turn budget for the episode, validated to lie in [1, 32].
    max_turns: int = Field(DEFAULT_MAX_TURNS, ge=1, le=32)
class SystemDescriptor(BaseModel, frozen=True):
    """Immutable public description of a physical system.

    Pairs the system identifier with the ordered names of its state
    variables.
    """

    system_id: str
    state_variables: tuple[str, ...]
class InteractiveStartResponse(BaseModel):
    """Response returned when a new interactive session is opened."""

    # Identifier (a uuid4 hex string) used to address the session afterwards.
    session_id: str
    # Observation produced by resetting the session's environment.
    observation: PhysiXObservation
    # The system the session was started on.
    system: SystemDescriptor
    # Turn budget for the episode, as requested when the session was opened.
    max_turns: int
class LlmStepResponse(BaseModel):
    """Result of one LLM-driven step in an interactive session."""

    # Observation returned by the environment after applying ``action``.
    observation: PhysiXObservation
    # Forward simulation of the proposed equation: one dict per timestamp with
    # key "t" plus one key per state variable; empty when unavailable.
    predicted_trajectory: list[dict[str, float]] = Field(default_factory=list)
    # Action parsed from the model's completion.
    action: PhysiXAction
    # Unmodified completion text returned by the model.
    raw_completion: str
    # Elapsed time of the model call, in seconds.
    latency_s: float
    # Model name taken from the step request.
    model: str
class DirectStepRequest(BaseModel):
    """Submit an action directly, without going through an LLM.

    Used by the OpenEnv tab in the demo so visitors can drive the env
    by hand and see exactly what each call returns. Unknown keys are
    rejected.
    """

    model_config = ConfigDict(extra="forbid")

    # Action applied as-is to the session's environment.
    action: PhysiXAction
class DirectStepResponse(BaseModel):
    """Result of applying a user-supplied action to a session."""

    # Observation returned by the environment after the step.
    observation: PhysiXObservation
    # Best-effort forward simulation of the submitted equation (one dict per
    # timestamp); empty when it could not be computed.
    predicted_trajectory: list[dict[str, float]] = Field(default_factory=list)
class SessionSummary(BaseModel):
    """Snapshot of an interactive session's progress."""

    session_id: str
    # Identifier of the system the session was started on.
    system_id: str
    # Steps taken so far (the environment's step count).
    turn: int
    # Turn budget for the episode.
    max_turns: int
    # Convergence flag reported by the environment state.
    converged: bool
    # True once the episode has converged or used up its turn budget.
    done: bool
| class _Session: | |
| __slots__ = ("env", "system_id", "max_turns", "lock") | |
| def __init__(self, env: PhysiXEnvironment, system_id: str, max_turns: int) -> None: | |
| self.env = env | |
| self.system_id = system_id | |
| self.max_turns = max_turns | |
| self.lock = threading.Lock() | |
class InteractiveSessionStore:
    """In-memory registry of interactive sessions keyed by session id.

    Every read and write of the underlying dict happens under one lock.
    Sessions stay registered until :meth:`delete` is called; there is no
    expiry.
    """

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._sessions: dict[str, _Session] = {}

    def create(
        self,
        *,
        system_id: Optional[str],
        seed: Optional[int],
        max_turns: int,
    ) -> tuple[str, _Session, PhysiXObservation]:
        """Build and reset a fresh environment, then register it.

        The environment is constructed and reset before the store lock is
        taken, so that work never holds the lock.

        Returns:
            ``(session_id, session, initial_observation)``.
        """
        env = PhysiXEnvironment(seed=seed, max_turns=max_turns)
        first_obs = env.reset(seed=seed, system_id=system_id)
        new_session = _Session(
            env=env, system_id=env.state.system_id, max_turns=max_turns
        )
        new_id = uuid.uuid4().hex
        with self._lock:
            self._sessions[new_id] = new_session
        return new_id, new_session, first_obs

    def get(self, session_id: str) -> _Session:
        """Return the session registered under ``session_id``.

        Raises:
            HTTPException: 404 if the id is unknown.
        """
        with self._lock:
            found = self._sessions.get(session_id)
        if found is not None:
            return found
        raise HTTPException(status_code=404, detail="Unknown session_id.")

    def delete(self, session_id: str) -> None:
        """Forget ``session_id``; unknown ids are silently ignored."""
        with self._lock:
            self._sessions.pop(session_id, None)

    def __len__(self) -> int:
        """Number of sessions currently registered."""
        with self._lock:
            count = len(self._sessions)
        return count
def build_interactive_router(
    store: Optional[InteractiveSessionStore] = None,
    *,
    policy_factory: LlmPolicyFactory = default_openai_compat_policy_factory,
    models_lister: LlmModelsLister = default_ollama_models_lister,
) -> APIRouter:
    """Build the ``/interactive`` router that drives browser sessions.

    Args:
        store: Session registry to use; a fresh ``InteractiveSessionStore``
            is created when omitted.
        policy_factory: Called with each ``LlmStepRequest``; must return a
            callable mapping the built prompt to the model's raw completion.
        models_lister: Zero-argument callable whose result is served as-is
            by the model-listing endpoint.

    Returns:
        An ``APIRouter`` (prefix ``/interactive``) with every handler below
        registered on it.
    """
    sessions = store if store is not None else InteractiveSessionStore()
    router = APIRouter(prefix="/interactive", tags=["Interactive"])

    # NOTE(review): route paths are a REST layout chosen during review;
    # confirm they match what the browser client calls.

    def _describe(system_id: str) -> SystemDescriptor:
        """Build the public descriptor for ``system_id``."""
        # Imported at call time; NOTE(review): presumably to avoid an import
        # cycle with ``physix.systems`` — confirm before hoisting.
        from physix.systems import get_system

        system = get_system(system_id)
        return SystemDescriptor(
            system_id=system.system_id,
            state_variables=system.state_variables,
        )

    @router.get("/models")
    def list_local_models() -> LlmModelsResponse:
        """List the models the UI can pick from."""
        return models_lister()

    @router.get("/systems")
    def list_public_systems() -> list[SystemDescriptor]:
        """Describe every system returned by ``list_supported_systems``."""
        return [_describe(system_id) for system_id in list_supported_systems()]

    @router.post("/sessions")
    def start_session(payload: InteractiveResetRequest) -> InteractiveStartResponse:
        """Open a session on the requested (or a randomly sampled) system.

        Raises:
            HTTPException: 400 if ``system_id`` is not a known system.
        """
        if payload.system_id is not None and payload.system_id not in list_systems():
            raise HTTPException(
                status_code=400, detail=f"Unknown system_id {payload.system_id!r}."
            )

        chosen_system_id = payload.system_id
        if chosen_system_id is None:
            # Sample among the demo-supported systems. If that list is empty,
            # None is passed through and the environment's default applies.
            demo_ids = list_supported_systems()
            if demo_ids:
                # default_rng(None) draws fresh OS entropy, same as no seed.
                rng = np.random.default_rng(payload.seed)
                chosen_system_id = str(rng.choice(demo_ids))

        session_id, session, observation = sessions.create(
            system_id=chosen_system_id,
            seed=payload.seed,
            max_turns=payload.max_turns,
        )
        return InteractiveStartResponse(
            session_id=session_id,
            observation=observation,
            system=_describe(session.system_id),
            max_turns=session.max_turns,
        )

    @router.post("/sessions/{session_id}/llm_step")
    def llm_step_session(
        session_id: str, payload: LlmStepRequest
    ) -> LlmStepResponse:
        """Ask an LLM for the next action and apply it to the session.

        The session lock is held for the whole step, model call included,
        so steps on one session never interleave.

        Raises:
            HTTPException: 404 for an unknown session, 409 when the turn
                budget is exhausted, 500 if there is no current observation.
        """
        session = sessions.get(session_id)
        with session.lock:
            _ensure_budget(session)
            current_obs = session.env.current_observation()
            if current_obs is None:
                raise HTTPException(
                    status_code=500, detail="Session has no current observation."
                )
            policy = policy_factory(payload)
            t0 = time.perf_counter()
            raw_completion = policy(build_prompt(current_obs))
            latency_s = time.perf_counter() - t0
            action = parse_completion(raw_completion)
            observation = session.env.step(action)
            predicted = _safe_predict(session.env, action)
        return LlmStepResponse(
            observation=observation,
            predicted_trajectory=predicted,
            action=action,
            raw_completion=raw_completion,
            latency_s=latency_s,
            model=payload.model,
        )

    @router.post("/sessions/{session_id}/step")
    def direct_step_session(
        session_id: str, payload: DirectStepRequest
    ) -> DirectStepResponse:
        """Apply a user-supplied action without invoking an LLM.

        Mirrors the standard OpenEnv ``POST /step`` semantics but bound
        to a per-browser session so subsequent calls share state. The
        OpenEnv tab in the demo uses this to let visitors drive the env
        by hand.

        Raises:
            HTTPException: 404 for an unknown session, 409 when the turn
                budget is exhausted.
        """
        session = sessions.get(session_id)
        with session.lock:
            _ensure_budget(session)
            observation = session.env.step(payload.action)
            predicted = _safe_predict(session.env, payload.action)
        return DirectStepResponse(
            observation=observation, predicted_trajectory=predicted
        )

    @router.delete("/sessions/{session_id}")
    def end_session(session_id: str) -> None:
        """Discard a session; unknown ids are ignored."""
        sessions.delete(session_id)

    @router.get("/sessions/{session_id}")
    def get_session(session_id: str) -> SessionSummary:
        """Summarise a session's progress.

        Does not take the session lock, so it is not blocked by an
        in-flight step.
        """
        session = sessions.get(session_id)
        state = session.env.state
        return SessionSummary(
            session_id=session_id,
            system_id=session.system_id,
            turn=state.step_count,
            max_turns=session.max_turns,
            converged=state.converged,
            done=state.converged or state.step_count >= session.max_turns,
        )

    return router
| def _ensure_budget(session: _Session) -> None: | |
| if session.env.state.step_count >= session.max_turns: | |
| raise HTTPException( | |
| status_code=409, | |
| detail="Episode budget already exhausted; start a new session.", | |
| ) | |
| def _safe_predict( | |
| env: PhysiXEnvironment, action: PhysiXAction | |
| ) -> list[dict[str, float]]: | |
| """Forward-simulate the user's hypothesis for the UI overlay. | |
| Returns ``[]`` on parse / simulation failure — the env's reward is | |
| authoritative; this is best-effort visualisation only. | |
| """ | |
| system: Optional[PhysicalSystem] = env.current_system | |
| trajectory: Optional[TrajectoryData] = env.current_trajectory | |
| if system is None or trajectory is None: | |
| return [] | |
| parameter_names = frozenset(action.params or {}) | frozenset(system.parameters) | |
| try: | |
| parsed = parse_equation( | |
| action.equation, | |
| state_variables=system.state_variables, | |
| parameter_names=parameter_names, | |
| ) | |
| except Exception as exc: # noqa: BLE001 | |
| _log.debug("predict parse failed: %s", exc) | |
| return [] | |
| merged = {**system.parameters, **(action.params or {})} | |
| try: | |
| predicted = simulate_hypothesis( | |
| parsed, | |
| state_variables=system.state_variables, | |
| parameters=merged, | |
| initial_conditions=trajectory.initial_conditions, | |
| timestamps=trajectory.timestamps, | |
| ) | |
| except Exception as exc: # noqa: BLE001 | |
| _log.debug("predict simulate failed: %s", exc) | |
| return [] | |
| samples: list[dict[str, float]] = [] | |
| for i, t in enumerate(trajectory.timestamps): | |
| sample: dict[str, float] = {"t": round(float(t), 5)} | |
| for var in system.state_variables: | |
| value = predicted[var][i] | |
| if not np.isfinite(value): | |
| return [] | |
| sample[var] = round(float(value), 5) | |
| samples.append(sample) | |
| return samples | |