# Prasham1710's picture
# added ui
# 9fd90c1
from __future__ import annotations
from abc import ABC, abstractmethod
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Generic, Optional, TypeVar
import httpx
from fastapi import Body, FastAPI, HTTPException
from fastapi.responses import HTMLResponse
from pydantic import BaseModel, ConfigDict, Field
# Base model for actions sent to an environment.
# Unknown fields are rejected, assignments are re-validated, and arbitrary
# (non-pydantic) field types are permitted.
class Action(BaseModel):
    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        validate_assignment=True,
        extra="forbid",
    )

    # Free-form side-channel data; every instance gets its own empty dict.
    metadata: Dict[str, Any] = Field(default_factory=dict)
# Base model for observations returned by an environment.
# Unknown fields are rejected, assignments are re-validated, and arbitrary
# (non-pydantic) field types are permitted.
class Observation(BaseModel):
    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        validate_assignment=True,
        extra="forbid",
    )

    # Episode-finished flag; defaults to "not done".
    done: bool = False
    # Reward attached to this observation, or None when there is none.
    reward: float | None = None
    # Free-form side-channel data; every instance gets its own empty dict.
    metadata: Dict[str, Any] = Field(default_factory=dict)
# Base model for an environment's state snapshot.
# Unlike Action/Observation, extra fields are allowed here, so additional
# keys are accepted alongside the two declared ones.
class State(BaseModel):
    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        validate_assignment=True,
        extra="allow",
    )

    # Identifier of the current episode, if one has been assigned.
    episode_id: Optional[str] = None
    # Step counter; validation rejects negative values (ge=0).
    step_count: int = Field(default=0, ge=0)
# Type parameters used by the generic classes in this module.
# Action type: any pydantic model (bound is BaseModel, not Action).
ActT = TypeVar("ActT", bound=BaseModel)
# Observation type: Observation or a subclass.
ObsT = TypeVar("ObsT", bound=Observation)
# State type: State or a subclass.
StateT = TypeVar("StateT", bound=State)
@dataclass
class StepResult(Generic[ObsT]):
    """Client-side result of a reset or step call.

    Attributes:
        observation: The observation object produced for this call.
        reward: Reward value, or ``None`` when none was supplied.
        done: Whether the episode has finished.
        info: Extra key/value data; each instance gets its own dict.
    """

    observation: ObsT
    reward: Optional[float] = None
    done: bool = False
    info: Dict[str, Any] = field(default_factory=dict)
class Environment(ABC, Generic[ActT, ObsT, StateT]):
    """Abstract synchronous environment interface.

    Concrete environments implement ``reset``, ``step`` and the ``state``
    property; ``close`` is an optional hook that does nothing by default.
    """

    def __init__(self) -> None:
        super().__init__()

    @abstractmethod
    def reset(
        self,
        seed: int | None = None,
        episode_id: str | None = None,
        **kwargs: Any,
    ) -> ObsT:
        """Start a new episode and return an observation.

        Args:
            seed: Optional integer seed.
            episode_id: Optional identifier for the new episode.
            **kwargs: Additional implementation-specific options.
        """
        raise NotImplementedError

    @abstractmethod
    def step(
        self,
        action: ActT,
        timeout_s: float | None = None,
        **kwargs: Any,
    ) -> ObsT:
        """Apply ``action`` and return the resulting observation.

        Args:
            action: The action to apply.
            timeout_s: Optional timeout in seconds; how it is honoured is up
                to the implementation.
            **kwargs: Additional implementation-specific options.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def state(self) -> StateT:
        """Return the environment's current state object."""
        raise NotImplementedError

    def close(self) -> None:
        """Release resources held by the environment (default: nothing to do)."""
class EnvClient(ABC, Generic[ActT, ObsT, StateT]):
    """Abstract async HTTP client for an environment server.

    Sends requests to the ``/reset``, ``/step`` and ``/state`` paths and
    delegates payload encoding and response decoding to three abstract hooks
    (``_step_payload``, ``_parse_result``, ``_parse_state``) that subclasses
    must implement. Usable as an async context manager.
    """

    def __init__(
        self,
        base_url: str,
        *,
        timeout_s: float = 30.0,
        async_client: httpx.AsyncClient | None = None,
    ) -> None:
        """Create the client.

        Args:
            base_url: Server root URL; trailing slashes are stripped.
            timeout_s: Timeout given to the internally created HTTP client.
                Not applied to a caller-supplied ``async_client``.
            async_client: Optional pre-built client. It is used as-is and is
                never closed by this object. Requests use relative paths
                (``/reset`` etc.), so the supplied client presumably needs its
                own ``base_url`` configured -- confirm against httpx usage.
        """
        self.base_url = base_url.rstrip("/")
        # Ownership and client selection use the same ``is None`` test so they
        # cannot disagree (a truthiness test could discard a supplied client
        # while still marking it as not owned, leaking the replacement).
        self._owns_client = async_client is None
        if async_client is None:
            self._client = httpx.AsyncClient(
                base_url=self.base_url,
                timeout=timeout_s,
            )
        else:
            self._client = async_client

    async def __aenter__(self) -> "EnvClient[ActT, ObsT, StateT]":
        """Enter the async context; returns this client."""
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        tb: Any,
    ) -> None:
        """Leave the async context, closing the client if it is owned."""
        await self.aclose()

    async def aclose(self) -> None:
        """Close the underlying HTTP client only if this object created it."""
        if self._owns_client:
            await self._client.aclose()

    async def close(self) -> None:
        """Alias for :meth:`aclose`."""
        await self.aclose()

    async def reset(
        self,
        seed: int | None = None,
        episode_id: str | None = None,
        **kwargs: Any,
    ) -> StepResult[ObsT]:
        """POST to ``/reset`` and return the parsed result.

        ``seed`` and ``episode_id`` are included in the JSON body only when
        not ``None``; ``kwargs`` are merged into the same body.

        Raises:
            httpx.HTTPStatusError: If the server answers with an error status.
        """
        payload: Dict[str, Any] = {}
        if seed is not None:
            payload["seed"] = seed
        if episode_id is not None:
            payload["episode_id"] = episode_id
        payload.update(kwargs)
        response = await self._client.post("/reset", json=payload)
        response.raise_for_status()
        return self._parse_result(response.json())

    async def step(self, action: ActT, **kwargs: Any) -> StepResult[ObsT]:
        """POST to ``/step`` and return the parsed result.

        The body is ``{"action": <_step_payload(action)>}`` with ``kwargs``
        merged in at the top level.

        Raises:
            httpx.HTTPStatusError: If the server answers with an error status.
        """
        payload: Dict[str, Any] = {"action": self._step_payload(action)}
        payload.update(kwargs)
        response = await self._client.post("/step", json=payload)
        response.raise_for_status()
        return self._parse_result(response.json())

    async def state(self) -> StateT:
        """GET ``/state`` and return the parsed state.

        Raises:
            httpx.HTTPStatusError: If the server answers with an error status.
        """
        response = await self._client.get("/state")
        response.raise_for_status()
        return self._parse_state(response.json())

    @abstractmethod
    def _step_payload(self, action: ActT) -> Dict[str, Any]:
        """Convert ``action`` into the JSON-serializable dict sent as ``action``."""
        raise NotImplementedError

    @abstractmethod
    def _parse_result(self, data: Dict[str, Any]) -> StepResult[ObsT]:
        """Build a :class:`StepResult` from a ``/reset`` or ``/step`` response body."""
        raise NotImplementedError

    @abstractmethod
    def _parse_state(self, data: Dict[str, Any]) -> StateT:
        """Build the state object from a ``/state`` response body."""
        raise NotImplementedError
def _serialize_observation(observation: Observation) -> Dict[str, Any]:
return {
"observation": observation.model_dump(
exclude={"reward", "done", "metadata"},
mode="json",
),
"reward": observation.reward,
"done": observation.done,
}
def _root_page_html(env_name: str) -> str:
return f"""<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>{env_name}</title>
<style>
:root {{
color-scheme: light;
--bg: #f6f4ef;
--panel: #fffdf8;
--ink: #18222d;
--muted: #556270;
--accent: #0f766e;
--accent-strong: #115e59;
--line: #d7d1c7;
--warn: #9a3412;
--code: #f2efe8;
}}
* {{ box-sizing: border-box; }}
body {{
margin: 0;
font-family: ui-sans-serif, system-ui, -apple-system, BlinkMacSystemFont, "Segoe UI", sans-serif;
background:
radial-gradient(circle at top right, rgba(15, 118, 110, 0.16), transparent 26%),
linear-gradient(180deg, #faf8f2 0%, var(--bg) 100%);
color: var(--ink);
}}
main {{
max-width: 1120px;
margin: 0 auto;
padding: 32px 20px 48px;
}}
.hero {{
display: grid;
gap: 14px;
margin-bottom: 24px;
}}
h1 {{
margin: 0;
font-size: clamp(2rem, 4vw, 3.4rem);
line-height: 1.05;
}}
p {{
margin: 0;
color: var(--muted);
max-width: 74ch;
}}
.grid {{
display: grid;
grid-template-columns: repeat(auto-fit, minmax(300px, 1fr));
gap: 16px;
}}
.panel {{
background: var(--panel);
border: 1px solid var(--line);
border-radius: 18px;
padding: 18px;
box-shadow: 0 10px 30px rgba(24, 34, 45, 0.06);
}}
.panel h2 {{
margin: 0 0 12px;
font-size: 1.05rem;
}}
.row {{
display: flex;
flex-wrap: wrap;
gap: 10px;
margin-bottom: 12px;
}}
label {{
display: grid;
gap: 6px;
font-size: 0.92rem;
color: var(--muted);
width: 100%;
}}
select, button, textarea {{
font: inherit;
}}
select, textarea {{
width: 100%;
border: 1px solid var(--line);
border-radius: 12px;
background: white;
padding: 12px 14px;
color: var(--ink);
}}
textarea {{
min-height: 220px;
resize: vertical;
font-family: ui-monospace, SFMono-Regular, Menlo, Consolas, monospace;
background: var(--code);
}}
button {{
border: 0;
border-radius: 999px;
padding: 11px 16px;
background: var(--accent);
color: white;
cursor: pointer;
font-weight: 600;
}}
button:hover {{
background: var(--accent-strong);
}}
.secondary {{
background: #dbe9e7;
color: var(--accent-strong);
}}
.stats {{
display: grid;
grid-template-columns: repeat(auto-fit, minmax(140px, 1fr));
gap: 10px;
}}
.stat {{
border: 1px solid var(--line);
border-radius: 14px;
padding: 12px;
background: rgba(255, 255, 255, 0.75);
}}
.stat strong {{
display: block;
font-size: 1.15rem;
margin-top: 4px;
}}
.muted {{
color: var(--muted);
}}
.warning {{
color: var(--warn);
}}
code {{
background: var(--code);
padding: 2px 6px;
border-radius: 8px;
}}
</style>
</head>
<body>
<main>
<section class="hero">
<h1>Enterprise Finance OpenEnv</h1>
<p>
Interactive console for testing <code>/reset</code>, <code>/step</code>, <code>/state</code>,
and score outputs directly from the deployed environment.
</p>
<p class="muted">
Use the API endpoints for agent evaluation. This page is a manual playground for humans.
</p>
</section>
<section class="grid">
<div class="panel">
<h2>Episode Controls</h2>
<div class="row">
<label>
Difficulty
<select id="difficulty">
<option value="easy">easy</option>
<option value="medium">medium</option>
<option value="hard">hard</option>
</select>
</label>
</div>
<div class="row">
<button id="reset-btn" type="button">Reset Episode</button>
<button id="state-btn" class="secondary" type="button">Refresh State</button>
</div>
<div class="stats">
<div class="stat"><span class="muted">Difficulty</span><strong id="stat-difficulty">-</strong></div>
<div class="stat"><span class="muted">Step Count</span><strong id="stat-step-count">-</strong></div>
<div class="stat"><span class="muted">Open Balance</span><strong id="stat-open-balance">-</strong></div>
<div class="stat"><span class="muted">Score</span><strong id="stat-score">-</strong></div>
<div class="stat"><span class="muted">Total Reward</span><strong id="stat-total-reward">-</strong></div>
<div class="stat"><span class="muted">Unmatched</span><strong id="stat-unmatched">-</strong></div>
</div>
</div>
<div class="panel">
<h2>Submit Action</h2>
<p class="muted">
Paste a JSON action body matching the environment schema.
</p>
<label>
Action JSON
<textarea id="action-json">{{
"type": "query_subledger",
"entity": "PARENT_US",
"account_code": "IC_AR",
"date_range": ["2026-01-01", "2026-01-31"]
}}</textarea>
</label>
<div class="row">
<button id="step-btn" type="button">POST /step</button>
</div>
<p id="step-error" class="warning"></p>
</div>
</section>
<section class="grid" style="margin-top: 16px;">
<div class="panel">
<h2>Latest API Response</h2>
<label>
Response JSON
<textarea id="response-json" readonly></textarea>
</label>
</div>
<div class="panel">
<h2>Current State</h2>
<label>
State JSON
<textarea id="state-json" readonly></textarea>
</label>
</div>
</section>
</main>
<script>
const responseBox = document.getElementById("response-json");
const stateBox = document.getElementById("state-json");
const errorBox = document.getElementById("step-error");
function setJson(target, payload) {{
target.value = JSON.stringify(payload, null, 2);
}}
function setStat(id, value) {{
document.getElementById(id).textContent = value ?? "-";
}}
function updateScoreboard(state) {{
setStat("stat-difficulty", state.difficulty);
setStat("stat-step-count", state.step_count);
setStat("stat-open-balance", state.open_abs_balance);
setStat("stat-score", state.final_score ?? "in progress");
setStat("stat-total-reward", state.total_reward);
setStat("stat-unmatched", state.unmatched_valid_count);
setJson(stateBox, state);
}}
async function fetchJson(path, options = undefined) {{
const response = await fetch(path, options);
const payload = await response.json();
if (!response.ok) {{
throw new Error(JSON.stringify(payload));
}}
return payload;
}}
async function refreshState() {{
const state = await fetchJson("/state");
updateScoreboard(state);
return state;
}}
async function resetEpisode() {{
errorBox.textContent = "";
const difficulty = document.getElementById("difficulty").value;
const payload = await fetchJson("/reset", {{
method: "POST",
headers: {{ "Content-Type": "application/json" }},
body: JSON.stringify({{ difficulty }})
}});
setJson(responseBox, payload);
await refreshState();
}}
async function postAction() {{
errorBox.textContent = "";
let action;
try {{
action = JSON.parse(document.getElementById("action-json").value);
}} catch (error) {{
errorBox.textContent = "Action JSON is invalid: " + error.message;
return;
}}
try {{
const payload = await fetchJson("/step", {{
method: "POST",
headers: {{ "Content-Type": "application/json" }},
body: JSON.stringify({{ action }})
}});
setJson(responseBox, payload);
await refreshState();
}} catch (error) {{
errorBox.textContent = String(error.message || error);
}}
}}
document.getElementById("reset-btn").addEventListener("click", resetEpisode);
document.getElementById("state-btn").addEventListener("click", refreshState);
document.getElementById("step-btn").addEventListener("click", postAction);
refreshState().catch((error) => {{
errorBox.textContent = String(error.message || error);
}});
</script>
</body>
</html>"""
def create_app(
    env_factory: Callable[[], Environment[Any, Observation, State]]
    | type[Environment[Any, Observation, State]],
    action_cls: type[BaseModel],
    observation_cls: type[Observation],
    env_name: str = "environment",
) -> FastAPI:
    """Build a FastAPI application that serves one environment over HTTP.

    A single environment instance is created up front and shared by every
    request. Routes registered:

    * ``GET /``       -- HTML playground page.
    * ``GET /health`` -- ``{"status": "healthy"}``.
    * ``GET /schema`` -- JSON schemas for the action, observation and state models.
    * ``GET /state``  -- the environment's current state, dumped in JSON mode.
    * ``POST /reset`` -- calls ``env.reset``; body fields other than ``seed``
      and ``episode_id`` are forwarded as keyword arguments.
    * ``POST /step``  -- validates ``action`` with ``action_cls`` and calls
      ``env.step``; body fields other than ``action`` and ``timeout_s`` are
      forwarded as keyword arguments.

    Args:
        env_factory: A zero-argument callable (or environment class) that is
            called once to create the environment. A non-callable value is
            used directly as the environment instance.
        action_cls: Pydantic model used to validate the ``action`` payload.
        observation_cls: Observation model, used only for ``/schema``.
        env_name: Application title, also used as the playground page title.

    Returns:
        The configured ``FastAPI`` application. ``env.close()`` is called when
        the application's lifespan ends.
    """
    # Function-scope import keeps this block independent of the module's
    # import list; pydantic is already a dependency of this file.
    from pydantic import ValidationError

    env = env_factory() if callable(env_factory) else env_factory

    @asynccontextmanager
    async def lifespan(_: FastAPI):
        # try/finally so the environment is closed even when the lifespan is
        # torn down by an exception rather than a clean shutdown.
        try:
            yield
        finally:
            env.close()

    app = FastAPI(title=env_name, lifespan=lifespan)

    # Request envelope for /reset; extra keys are kept and forwarded to env.reset.
    class ResetRequest(BaseModel):
        model_config = ConfigDict(extra="allow")

        seed: int | None = None
        episode_id: str | None = None

    # Request envelope for /step; extra keys are kept and forwarded to env.step.
    class StepRequest(BaseModel):
        model_config = ConfigDict(extra="allow")

        action: Dict[str, Any]
        timeout_s: float | None = None

    def _validate_request(model_cls: type[BaseModel], data: Dict[str, Any]) -> BaseModel:
        """Validate a request body, reporting failures as HTTP 422.

        The handlers accept a raw dict, so FastAPI does not validate these
        models itself; an uncaught ValidationError here would become a 500.
        """
        try:
            return model_cls.model_validate(data)
        except ValidationError as exc:
            raise HTTPException(status_code=422, detail=str(exc)) from exc

    @app.get("/", response_class=HTMLResponse)
    async def home() -> HTMLResponse:
        return HTMLResponse(_root_page_html(env_name))

    @app.get("/health")
    async def health() -> Dict[str, str]:
        return {"status": "healthy"}

    @app.get("/schema")
    async def schema() -> Dict[str, Any]:
        # The state schema comes from the concrete class of the live state object.
        state_cls = env.state.__class__
        return {
            "action": action_cls.model_json_schema(),
            "observation": observation_cls.model_json_schema(),
            "state": state_cls.model_json_schema(),
        }

    @app.get("/state")
    async def state() -> Dict[str, Any]:
        return env.state.model_dump(mode="json")

    @app.post("/reset")
    async def reset(request: Dict[str, Any] = Body(default_factory=dict)) -> Dict[str, Any]:
        validated = _validate_request(ResetRequest, request)
        # exclude_none drops unset seed/episode_id (and any None-valued extras).
        payload = validated.model_dump(exclude_none=True)
        seed = payload.pop("seed", None)
        episode_id = payload.pop("episode_id", None)
        observation = env.reset(seed=seed, episode_id=episode_id, **payload)
        return _serialize_observation(observation)

    @app.post("/step")
    async def step(request: Dict[str, Any] = Body(...)) -> Dict[str, Any]:
        validated = _validate_request(StepRequest, request)
        try:
            action = action_cls.model_validate(validated.action)
        except Exception as exc:
            # Kept broad, as before: any failure while building the action is
            # reported to the client as a 422 with the error text.
            raise HTTPException(status_code=422, detail=str(exc)) from exc
        payload = validated.model_dump(exclude_none=True)
        timeout_s = payload.pop("timeout_s", None)
        # The action was already parsed above; do not forward the raw dict.
        payload.pop("action", None)
        observation = env.step(action, timeout_s=timeout_s, **payload)
        return _serialize_observation(observation)

    return app