Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| from abc import ABC, abstractmethod | |
| from contextlib import asynccontextmanager | |
| from dataclasses import dataclass, field | |
| from typing import Any, Callable, Dict, Generic, Optional, TypeVar | |
| import httpx | |
| from fastapi import Body, FastAPI, HTTPException | |
| from fastapi.responses import HTMLResponse | |
| from pydantic import BaseModel, ConfigDict, Field | |
class Action(BaseModel):
    # Base model for actions submitted to an environment.
    #
    # Documented with comments rather than a docstring on purpose: pydantic
    # copies a model's docstring into the "description" of its JSON schema,
    # which would change what model_json_schema() returns.
    #
    # Config: unknown fields are rejected, assignments are re-validated, and
    # arbitrary (non-pydantic) types are permitted as field types.
    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        validate_assignment=True,
        extra="forbid",
    )

    # Free-form extra data attached to the action; a fresh empty dict per instance.
    metadata: Dict[str, Any] = Field(default_factory=dict)
class Observation(BaseModel):
    # Base model for what an environment returns from reset() and step().
    #
    # Documented with comments rather than a docstring on purpose: pydantic
    # copies a model's docstring into the "description" of its JSON schema,
    # which would change what model_json_schema() returns.
    #
    # Config: unknown fields are rejected, assignments are re-validated, and
    # arbitrary (non-pydantic) types are permitted as field types.
    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        validate_assignment=True,
        extra="forbid",
    )

    # Episode-finished flag; False unless set.
    done: bool = False
    # Reward attached to this observation; None when no reward was given.
    reward: Optional[float] = None
    # Free-form extra data; a fresh empty dict per instance.
    metadata: Dict[str, Any] = Field(default_factory=dict)
class State(BaseModel):
    # Base model for an environment's inspectable state.
    #
    # Documented with comments rather than a docstring on purpose: pydantic
    # copies a model's docstring into the "description" of its JSON schema,
    # which would change what model_json_schema() returns.
    #
    # Config: unlike Action/Observation, unknown fields are *kept*
    # (extra="allow"); assignments are re-validated and arbitrary types are
    # permitted as field types.
    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        validate_assignment=True,
        extra="allow",
    )

    # Identifier of the current episode; None until one is assigned.
    episode_id: str | None = None
    # Number of steps taken; validated to be non-negative.
    step_count: int = Field(0, ge=0)
# Type parameters shared by the generic classes in this module.
# Action type: any pydantic model.
ActT = TypeVar(
    "ActT",
    bound=BaseModel,
)
# Observation type: Observation or a subclass.
ObsT = TypeVar(
    "ObsT",
    bound=Observation,
)
# State type: State or a subclass.
StateT = TypeVar(
    "StateT",
    bound=State,
)
@dataclass
class StepResult(Generic[ObsT]):
    """Client-side result of a reset or step call.

    The class body uses ``dataclasses.field`` for ``info``, which only has
    meaning on a dataclass; without the ``@dataclass`` decorator there is no
    generated ``__init__`` and ``info`` would be a raw ``Field`` object shared
    by every instance.

    Attributes:
        observation: The observation produced by the environment.
        reward: Reward for the transition, or ``None`` when none was given.
        done: Whether the episode has finished.
        info: Free-form auxiliary data; a fresh empty dict per instance.
    """

    observation: ObsT
    reward: float | None = None
    done: bool = False
    info: Dict[str, Any] = field(default_factory=dict)
class Environment(ABC, Generic[ActT, ObsT, StateT]):
    """Abstract server-side environment.

    Subclasses must implement ``reset``, ``step`` and the ``state`` property;
    because these are marked abstract, an incomplete subclass fails when it is
    instantiated instead of on the first call. ``close`` is an optional hook.

    ``state`` is a property (not a method) because ``create_app`` in this
    module reads it as an attribute (``env.state.model_dump(...)``).
    """

    def __init__(self) -> None:
        super().__init__()

    @abstractmethod
    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> ObsT:
        """Start a new episode and return its first observation.

        Args:
            seed: Optional seed for the new episode.
            episode_id: Optional identifier for the new episode.
            **kwargs: Additional implementation-specific reset options.

        Returns:
            The initial observation.
        """
        raise NotImplementedError

    @abstractmethod
    def step(
        self,
        action: ActT,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> ObsT:
        """Apply ``action`` and return the resulting observation.

        Args:
            action: The action to apply.
            timeout_s: Optional time limit for the step, in seconds.
            **kwargs: Additional implementation-specific step options.

        Returns:
            The observation after the action.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def state(self) -> StateT:
        """Current environment state."""
        raise NotImplementedError

    def close(self) -> None:
        """Release any resources held by the environment (no-op by default)."""
        return None
class EnvClient(ABC, Generic[ActT, ObsT, StateT]):
    """Abstract async HTTP client for a served environment.

    Requests go to the relative paths ``/reset``, ``/step`` and ``/state``.
    Subclasses supply the three conversion hooks (``_step_payload``,
    ``_parse_result``, ``_parse_state``); they are marked abstract so an
    incomplete subclass fails at construction rather than on the first request.
    """

    def __init__(
        self,
        base_url: str,
        *,
        timeout_s: float = 30.0,
        async_client: httpx.AsyncClient | None = None,
    ) -> None:
        """Create the client.

        Args:
            base_url: Server root URL; any trailing ``/`` is stripped.
            timeout_s: Timeout for the internally created HTTP client.
            async_client: Optional pre-built client. When given it is used
                as-is and is NOT closed by this object; ``base_url`` and
                ``timeout_s`` are not applied to it, so it must already be
                configured to resolve the relative request paths.
        """
        self.base_url = base_url.rstrip("/")
        # Only close on exit what this object created itself.
        self._owns_client = async_client is None
        if async_client is None:
            async_client = httpx.AsyncClient(
                base_url=self.base_url,
                timeout=timeout_s,
            )
        self._client = async_client

    async def __aenter__(self) -> "EnvClient[ActT, ObsT, StateT]":
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        await self.aclose()

    async def aclose(self) -> None:
        """Close the underlying HTTP client if this object created it."""
        if self._owns_client:
            await self._client.aclose()

    async def close(self) -> None:
        """Alias for :meth:`aclose`."""
        await self.aclose()

    async def reset(
        self,
        seed: int | None = None,
        episode_id: str | None = None,
        **kwargs: Any,
    ) -> StepResult[ObsT]:
        """POST ``/reset`` and return the parsed result.

        ``seed`` and ``episode_id`` are sent only when not ``None``; ``kwargs``
        are merged into the JSON body afterwards (and so win on key clashes).

        Raises:
            httpx.HTTPStatusError: If the server answers with an error status.
        """
        payload: Dict[str, Any] = {}
        if seed is not None:
            payload["seed"] = seed
        if episode_id is not None:
            payload["episode_id"] = episode_id
        payload.update(kwargs)
        response = await self._client.post("/reset", json=payload)
        response.raise_for_status()
        return self._parse_result(response.json())

    async def step(self, action: ActT, **kwargs: Any) -> StepResult[ObsT]:
        """POST ``/step`` with the serialized action and return the parsed result.

        The body is ``{"action": <_step_payload(action)>}`` with ``kwargs``
        merged in afterwards.

        Raises:
            httpx.HTTPStatusError: If the server answers with an error status.
        """
        payload: Dict[str, Any] = {"action": self._step_payload(action)}
        payload.update(kwargs)
        response = await self._client.post("/step", json=payload)
        response.raise_for_status()
        return self._parse_result(response.json())

    async def state(self) -> StateT:
        """GET ``/state`` and return the parsed state.

        Raises:
            httpx.HTTPStatusError: If the server answers with an error status.
        """
        response = await self._client.get("/state")
        response.raise_for_status()
        return self._parse_state(response.json())

    @abstractmethod
    def _step_payload(self, action: ActT) -> Dict[str, Any]:
        """Convert ``action`` into the JSON-serializable ``"action"`` value."""
        raise NotImplementedError

    @abstractmethod
    def _parse_result(self, data: Dict[str, Any]) -> StepResult[ObsT]:
        """Convert a decoded ``/reset`` or ``/step`` response into a StepResult."""
        raise NotImplementedError

    @abstractmethod
    def _parse_state(self, data: Dict[str, Any]) -> StateT:
        """Convert a decoded ``/state`` response into a state object."""
        raise NotImplementedError
def _serialize_observation(observation: Observation) -> Dict[str, Any]:
    """Build the response envelope for a reset/step observation.

    ``reward`` and ``done`` are lifted to the top level; the remaining fields
    (minus ``metadata``) are dumped in JSON mode under ``"observation"``.
    """
    lifted_or_dropped = {"reward", "done", "metadata"}
    body = observation.model_dump(exclude=lifted_or_dropped, mode="json")
    envelope: Dict[str, Any] = {"observation": body}
    envelope["reward"] = observation.reward
    envelope["done"] = observation.done
    return envelope
def _root_page_html(env_name: str) -> str:
    """Return the HTML for the manual-testing console page.

    The page's script calls ``/reset``, ``/step`` and ``/state`` on the same
    origin and renders the responses. ``env_name`` is used only for the
    ``<title>`` and is HTML-escaped so names containing ``&``, ``<`` or
    quotes cannot break the markup.

    Args:
        env_name: Human-readable environment name.

    Returns:
        The complete HTML document as a string.
    """
    from html import escape

    page_title = escape(env_name)
    # NOTE: this is an f-string, so every literal CSS/JS/JSON brace below is
    # doubled ({{ / }}); only {page_title} is an actual substitution.
    return f"""<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>{page_title}</title>
<style>
:root {{
color-scheme: light;
--bg: #f6f4ef;
--panel: #fffdf8;
--ink: #18222d;
--muted: #556270;
--accent: #0f766e;
--accent-strong: #115e59;
--line: #d7d1c7;
--warn: #9a3412;
--code: #f2efe8;
}}
* {{ box-sizing: border-box; }}
body {{
margin: 0;
font-family: ui-sans-serif, system-ui, -apple-system, BlinkMacSystemFont, "Segoe UI", sans-serif;
background:
radial-gradient(circle at top right, rgba(15, 118, 110, 0.16), transparent 26%),
linear-gradient(180deg, #faf8f2 0%, var(--bg) 100%);
color: var(--ink);
}}
main {{
max-width: 1120px;
margin: 0 auto;
padding: 32px 20px 48px;
}}
.hero {{
display: grid;
gap: 14px;
margin-bottom: 24px;
}}
h1 {{
margin: 0;
font-size: clamp(2rem, 4vw, 3.4rem);
line-height: 1.05;
}}
p {{
margin: 0;
color: var(--muted);
max-width: 74ch;
}}
.grid {{
display: grid;
grid-template-columns: repeat(auto-fit, minmax(300px, 1fr));
gap: 16px;
}}
.panel {{
background: var(--panel);
border: 1px solid var(--line);
border-radius: 18px;
padding: 18px;
box-shadow: 0 10px 30px rgba(24, 34, 45, 0.06);
}}
.panel h2 {{
margin: 0 0 12px;
font-size: 1.05rem;
}}
.row {{
display: flex;
flex-wrap: wrap;
gap: 10px;
margin-bottom: 12px;
}}
label {{
display: grid;
gap: 6px;
font-size: 0.92rem;
color: var(--muted);
width: 100%;
}}
select, button, textarea {{
font: inherit;
}}
select, textarea {{
width: 100%;
border: 1px solid var(--line);
border-radius: 12px;
background: white;
padding: 12px 14px;
color: var(--ink);
}}
textarea {{
min-height: 220px;
resize: vertical;
font-family: ui-monospace, SFMono-Regular, Menlo, Consolas, monospace;
background: var(--code);
}}
button {{
border: 0;
border-radius: 999px;
padding: 11px 16px;
background: var(--accent);
color: white;
cursor: pointer;
font-weight: 600;
}}
button:hover {{
background: var(--accent-strong);
}}
.secondary {{
background: #dbe9e7;
color: var(--accent-strong);
}}
.stats {{
display: grid;
grid-template-columns: repeat(auto-fit, minmax(140px, 1fr));
gap: 10px;
}}
.stat {{
border: 1px solid var(--line);
border-radius: 14px;
padding: 12px;
background: rgba(255, 255, 255, 0.75);
}}
.stat strong {{
display: block;
font-size: 1.15rem;
margin-top: 4px;
}}
.muted {{
color: var(--muted);
}}
.warning {{
color: var(--warn);
}}
code {{
background: var(--code);
padding: 2px 6px;
border-radius: 8px;
}}
</style>
</head>
<body>
<main>
<section class="hero">
<h1>Enterprise Finance OpenEnv</h1>
<p>
Interactive console for testing <code>/reset</code>, <code>/step</code>, <code>/state</code>,
and score outputs directly from the deployed environment.
</p>
<p class="muted">
Use the API endpoints for agent evaluation. This page is a manual playground for humans.
</p>
</section>
<section class="grid">
<div class="panel">
<h2>Episode Controls</h2>
<div class="row">
<label>
Difficulty
<select id="difficulty">
<option value="easy">easy</option>
<option value="medium">medium</option>
<option value="hard">hard</option>
</select>
</label>
</div>
<div class="row">
<button id="reset-btn" type="button">Reset Episode</button>
<button id="state-btn" class="secondary" type="button">Refresh State</button>
</div>
<div class="stats">
<div class="stat"><span class="muted">Difficulty</span><strong id="stat-difficulty">-</strong></div>
<div class="stat"><span class="muted">Step Count</span><strong id="stat-step-count">-</strong></div>
<div class="stat"><span class="muted">Open Balance</span><strong id="stat-open-balance">-</strong></div>
<div class="stat"><span class="muted">Score</span><strong id="stat-score">-</strong></div>
<div class="stat"><span class="muted">Total Reward</span><strong id="stat-total-reward">-</strong></div>
<div class="stat"><span class="muted">Unmatched</span><strong id="stat-unmatched">-</strong></div>
</div>
</div>
<div class="panel">
<h2>Submit Action</h2>
<p class="muted">
Paste a JSON action body matching the environment schema.
</p>
<label>
Action JSON
<textarea id="action-json">{{
"type": "query_subledger",
"entity": "PARENT_US",
"account_code": "IC_AR",
"date_range": ["2026-01-01", "2026-01-31"]
}}</textarea>
</label>
<div class="row">
<button id="step-btn" type="button">POST /step</button>
</div>
<p id="step-error" class="warning"></p>
</div>
</section>
<section class="grid" style="margin-top: 16px;">
<div class="panel">
<h2>Latest API Response</h2>
<label>
Response JSON
<textarea id="response-json" readonly></textarea>
</label>
</div>
<div class="panel">
<h2>Current State</h2>
<label>
State JSON
<textarea id="state-json" readonly></textarea>
</label>
</div>
</section>
</main>
<script>
const responseBox = document.getElementById("response-json");
const stateBox = document.getElementById("state-json");
const errorBox = document.getElementById("step-error");
function setJson(target, payload) {{
target.value = JSON.stringify(payload, null, 2);
}}
function setStat(id, value) {{
document.getElementById(id).textContent = value ?? "-";
}}
function updateScoreboard(state) {{
setStat("stat-difficulty", state.difficulty);
setStat("stat-step-count", state.step_count);
setStat("stat-open-balance", state.open_abs_balance);
setStat("stat-score", state.final_score ?? "in progress");
setStat("stat-total-reward", state.total_reward);
setStat("stat-unmatched", state.unmatched_valid_count);
setJson(stateBox, state);
}}
async function fetchJson(path, options = undefined) {{
const response = await fetch(path, options);
const payload = await response.json();
if (!response.ok) {{
throw new Error(JSON.stringify(payload));
}}
return payload;
}}
async function refreshState() {{
const state = await fetchJson("/state");
updateScoreboard(state);
return state;
}}
async function resetEpisode() {{
errorBox.textContent = "";
const difficulty = document.getElementById("difficulty").value;
const payload = await fetchJson("/reset", {{
method: "POST",
headers: {{ "Content-Type": "application/json" }},
body: JSON.stringify({{ difficulty }})
}});
setJson(responseBox, payload);
await refreshState();
}}
async function postAction() {{
errorBox.textContent = "";
let action;
try {{
action = JSON.parse(document.getElementById("action-json").value);
}} catch (error) {{
errorBox.textContent = "Action JSON is invalid: " + error.message;
return;
}}
try {{
const payload = await fetchJson("/step", {{
method: "POST",
headers: {{ "Content-Type": "application/json" }},
body: JSON.stringify({{ action }})
}});
setJson(responseBox, payload);
await refreshState();
}} catch (error) {{
errorBox.textContent = String(error.message || error);
}}
}}
document.getElementById("reset-btn").addEventListener("click", resetEpisode);
document.getElementById("state-btn").addEventListener("click", refreshState);
document.getElementById("step-btn").addEventListener("click", postAction);
refreshState().catch((error) => {{
errorBox.textContent = String(error.message || error);
}});
</script>
</body>
</html>"""
def create_app(
    env_factory: Callable[[], Environment[Any, Observation, State]]
    | type[Environment[Any, Observation, State]],
    action_cls: type[BaseModel],
    observation_cls: type[Observation],
    env_name: str = "environment",
) -> FastAPI:
    """Build a FastAPI app that serves one environment instance over HTTP.

    A single environment is created up front and shared by every request; it
    is closed when the application shuts down. The environment's ``reset`` and
    ``step`` are called synchronously inside the async handlers.

    Routes:
        GET  /        HTML console page.
        GET  /health  Liveness probe.
        GET  /schema  JSON schemas for action, observation and state.
        GET  /state   Current environment state.
        POST /reset   Start a new episode.
        POST /step    Apply one action.

    ``/reset``, ``/step`` and ``/state`` match the paths used by ``EnvClient``
    and by the console page; ``/health`` and ``/schema`` are named after their
    handlers.

    Args:
        env_factory: Zero-argument callable (or environment class) producing
            the environment; a non-callable value is used as the instance.
        action_cls: Pydantic model used to validate the ``action`` of ``/step``.
        observation_cls: Observation model, used for ``/schema``.
        env_name: Application title and console page title.

    Returns:
        The configured FastAPI application.
    """
    from pydantic import ValidationError

    env = env_factory() if callable(env_factory) else env_factory

    @asynccontextmanager
    async def lifespan(_: FastAPI):
        # `finally` so the environment is released even if shutdown is
        # triggered by an error or cancellation.
        try:
            yield
        finally:
            env.close()

    app = FastAPI(title=env_name, lifespan=lifespan)

    class ResetRequest(BaseModel):
        # Body of POST /reset; unknown keys are kept and forwarded to env.reset().
        model_config = ConfigDict(extra="allow")
        seed: int | None = None
        episode_id: str | None = None

    class StepRequest(BaseModel):
        # Body of POST /step; unknown keys are kept and forwarded to env.step().
        model_config = ConfigDict(extra="allow")
        action: Dict[str, Any]
        timeout_s: float | None = None

    # NOTE(review): handlers take a plain dict body and validate by hand,
    # presumably because these locally defined models cannot be resolved from
    # postponed (string) annotations - confirm before switching to typed bodies.

    @app.get("/", response_class=HTMLResponse)
    async def home() -> HTMLResponse:
        return HTMLResponse(_root_page_html(env_name))

    @app.get("/health")
    async def health() -> Dict[str, str]:
        return {"status": "healthy"}

    @app.get("/schema")
    async def schema() -> Dict[str, Any]:
        state_cls = env.state.__class__
        return {
            "action": action_cls.model_json_schema(),
            "observation": observation_cls.model_json_schema(),
            "state": state_cls.model_json_schema(),
        }

    @app.get("/state")
    async def state() -> Dict[str, Any]:
        return env.state.model_dump(mode="json")

    @app.post("/reset")
    async def reset(request: Dict[str, Any] = Body(default_factory=dict)) -> Dict[str, Any]:
        try:
            validated = ResetRequest.model_validate(request)
        except ValidationError as exc:
            # A malformed body is a client error, not a server failure.
            raise HTTPException(status_code=422, detail=str(exc)) from exc
        payload = validated.model_dump(exclude_none=True)
        seed = payload.pop("seed", None)
        episode_id = payload.pop("episode_id", None)
        # Whatever is left are extra keys supplied by the caller.
        observation = env.reset(seed=seed, episode_id=episode_id, **payload)
        return _serialize_observation(observation)

    @app.post("/step")
    async def step(request: Dict[str, Any] = Body(...)) -> Dict[str, Any]:
        try:
            validated = StepRequest.model_validate(request)
        except ValidationError as exc:
            # A malformed body is a client error, not a server failure.
            raise HTTPException(status_code=422, detail=str(exc)) from exc
        try:
            action = action_cls.model_validate(validated.action)
        except Exception as exc:  # pragma: no cover - FastAPI will surface this in real path.
            raise HTTPException(status_code=422, detail=str(exc)) from exc
        payload = validated.model_dump(exclude_none=True)
        timeout_s = payload.pop("timeout_s", None)
        # The action was already converted above; don't forward the raw dict.
        payload.pop("action", None)
        observation = env.step(action, timeout_s=timeout_s, **payload)
        return _serialize_observation(observation)

    return app