from __future__ import annotations

from abc import ABC, abstractmethod
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Generic, Optional, TypeVar

import httpx
from fastapi import Body, FastAPI, HTTPException
from fastapi.responses import HTMLResponse
from pydantic import BaseModel, ConfigDict, Field, ValidationError
class Action(BaseModel):
    """Base pydantic model for action payloads.

    The model configuration rejects unknown fields, re-validates values on
    attribute assignment, and permits arbitrary (non-pydantic) field types.
    """

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        validate_assignment=True,
        extra="forbid",
    )

    # Free-form extra data attached to the action; defaults to a new empty dict.
    metadata: Dict[str, Any] = Field(default_factory=dict)
class Observation(BaseModel):
    """Base pydantic model for observations.

    Carries the ``done`` flag and optional ``reward`` alongside a free-form
    ``metadata`` mapping. Unknown fields are rejected, assignments are
    re-validated, and arbitrary field types are permitted.
    """

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        validate_assignment=True,
        extra="forbid",
    )

    # Defaults to False.
    done: bool = False
    # No reward is reported unless one is set explicitly.
    reward: float | None = None
    # Free-form extra data; defaults to a new empty dict per instance.
    metadata: Dict[str, Any] = Field(default_factory=dict)
class State(BaseModel):
    """Base pydantic model for environment state.

    Unlike ``Action`` and ``Observation`` this model uses ``extra="allow"``,
    so fields beyond the two declared here are accepted. Assignments are
    re-validated and arbitrary field types are permitted.
    """

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        validate_assignment=True,
        extra="allow",
    )

    # Optional episode identifier; None when unset.
    episode_id: Optional[str] = None
    # Non-negative counter (validated with ge=0); starts at 0.
    step_count: int = Field(default=0, ge=0)
# Type parameters shared by the generic classes in this module.
# ActT is bound to BaseModel (not to Action), so any pydantic model
# satisfies it.
ActT = TypeVar("ActT", bound=BaseModel)
# ObsT must be Observation or a subclass of it.
ObsT = TypeVar("ObsT", bound=Observation)
# StateT must be State or a subclass of it.
StateT = TypeVar("StateT", bound=State)
@dataclass
class StepResult(Generic[ObsT]):
    """Client-side result of a reset or step call.

    Attributes:
        observation: The observation object for this result.
        reward: Reward value, or None when none was provided.
        done: Completion flag; defaults to False.
        info: Additional data; defaults to a new empty dict per instance.
    """

    observation: ObsT
    reward: Optional[float] = None
    done: bool = False
    info: Dict[str, Any] = field(default_factory=dict)
class Environment(ABC, Generic[ActT, ObsT, StateT]):
    """Abstract interface for an environment.

    Subclasses must implement ``reset``, ``step`` and the ``state`` property.
    ``close`` is an optional hook that does nothing by default.
    """

    def __init__(self) -> None:
        super().__init__()

    @abstractmethod
    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> ObsT:
        """Reset the environment and return an observation.

        Args:
            seed: Optional integer seed.
            episode_id: Optional episode identifier.
            **kwargs: Additional implementation-specific options.
        """
        raise NotImplementedError

    @abstractmethod
    def step(
        self,
        action: ActT,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> ObsT:
        """Apply ``action`` and return the resulting observation.

        Args:
            action: The action to apply.
            timeout_s: Optional timeout value passed through to the
                implementation.
            **kwargs: Additional implementation-specific options.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def state(self) -> StateT:
        """Return the environment's current state object."""
        raise NotImplementedError

    def close(self) -> None:
        """Release resources held by the environment; a no-op by default."""
class EnvClient(ABC, Generic[ActT, ObsT, StateT]):
    """Abstract async HTTP client for an environment server.

    Talks to the ``/reset``, ``/step`` and ``/state`` endpoints. Subclasses
    supply the (de)serialization hooks ``_step_payload``, ``_parse_result``
    and ``_parse_state``.

    The instance can be used as an async context manager; leaving the
    context calls ``aclose``.
    """

    def __init__(
        self,
        base_url: str,
        *,
        timeout_s: float = 30.0,
        async_client: httpx.AsyncClient | None = None,
    ) -> None:
        """Create the client.

        Args:
            base_url: Server base URL; a trailing ``/`` is stripped.
            timeout_s: Timeout used when this object builds its own
                ``httpx.AsyncClient``. Not applied to an injected client.
            async_client: Optional pre-built client to use. When given, it
                is NOT closed by ``aclose``; the caller keeps ownership.
        """
        self.base_url = base_url.rstrip("/")
        self._owns_client = async_client is None
        if async_client is None:
            async_client = httpx.AsyncClient(
                base_url=self.base_url,
                timeout=timeout_s,
            )
        self._client = async_client

    def _url(self, path: str) -> str:
        """Return the absolute URL for ``path`` under ``base_url``.

        Absolute URLs are used so that requests reach ``base_url`` even when
        an injected client was created without its own ``base_url``.
        """
        return f"{self.base_url}{path}"

    async def __aenter__(self) -> "EnvClient[ActT, ObsT, StateT]":
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        await self.aclose()

    async def aclose(self) -> None:
        """Close the underlying HTTP client if this object created it."""
        if self._owns_client:
            await self._client.aclose()

    async def close(self) -> None:
        """Alias for ``aclose``."""
        await self.aclose()

    async def reset(
        self,
        seed: int | None = None,
        episode_id: str | None = None,
        **kwargs: Any,
    ) -> StepResult[ObsT]:
        """POST to ``/reset`` and return the parsed result.

        ``seed`` and ``episode_id`` are only included in the JSON body when
        they are not None; ``kwargs`` are merged into the body afterwards.

        Raises:
            httpx.HTTPStatusError: If the server responds with an error
                status.
        """
        payload: Dict[str, Any] = {}
        if seed is not None:
            payload["seed"] = seed
        if episode_id is not None:
            payload["episode_id"] = episode_id
        payload.update(kwargs)
        response = await self._client.post(self._url("/reset"), json=payload)
        response.raise_for_status()
        return self._parse_result(response.json())

    async def step(self, action: ActT, **kwargs: Any) -> StepResult[ObsT]:
        """POST the serialized ``action`` to ``/step`` and parse the result.

        The body is ``{"action": <_step_payload(action)>}`` with ``kwargs``
        merged in on top.

        Raises:
            httpx.HTTPStatusError: If the server responds with an error
                status.
        """
        payload: Dict[str, Any] = {"action": self._step_payload(action)}
        payload.update(kwargs)
        response = await self._client.post(self._url("/step"), json=payload)
        response.raise_for_status()
        return self._parse_result(response.json())

    async def state(self) -> StateT:
        """GET ``/state`` and return the parsed state.

        Raises:
            httpx.HTTPStatusError: If the server responds with an error
                status.
        """
        response = await self._client.get(self._url("/state"))
        response.raise_for_status()
        return self._parse_state(response.json())

    @abstractmethod
    def _step_payload(self, action: ActT) -> Dict[str, Any]:
        """Convert ``action`` into the JSON-serializable ``action`` body."""
        raise NotImplementedError

    @abstractmethod
    def _parse_result(self, data: Dict[str, Any]) -> StepResult[ObsT]:
        """Build a ``StepResult`` from a decoded reset/step response."""
        raise NotImplementedError

    @abstractmethod
    def _parse_state(self, data: Dict[str, Any]) -> StateT:
        """Build a state object from a decoded ``/state`` response."""
        raise NotImplementedError
def _serialize_observation(observation: Observation) -> Dict[str, Any]:
return {
"observation": observation.model_dump(
exclude={"reward", "done", "metadata"},
mode="json",
),
"reward": observation.reward,
"done": observation.done,
}
def _root_page_html(env_name: str) -> str:
    """Return the text of the root landing page.

    ``env_name`` is interpolated into the page as-is (no escaping is applied
    here); the rest of the page is fixed text.
    """
    return f"""
{env_name}
Enterprise Finance OpenEnv
Interactive console for testing /reset, /step, /state,
and score outputs directly from the deployed environment.
Use the API endpoints for agent evaluation. This page is a manual playground for humans.
Episode Controls
Difficulty-
Step Count-
Open Balance-
Score-
Total Reward-
Unmatched-
Submit Action
Paste a JSON action body matching the environment schema.
Latest API Response
Current State
"""
def create_app(
    env_factory: Callable[[], Environment[Any, Observation, State]]
    | type[Environment[Any, Observation, State]],
    action_cls: type[BaseModel],
    observation_cls: type[Observation],
    env_name: str = "environment",
) -> FastAPI:
    """Build a FastAPI application that serves a single environment.

    One environment instance is created up front and shared by every
    request. Routes: ``GET /`` (landing page), ``GET /health``,
    ``GET /schema``, ``GET /state``, ``POST /reset`` and ``POST /step``.

    Args:
        env_factory: A zero-argument callable (or class) that is called once
            to produce the environment. A non-callable value is used as the
            environment directly.
        action_cls: Pydantic model used to validate the ``action`` body of
            ``/step`` and reported under ``/schema``.
        observation_cls: Observation model reported under ``/schema``.
        env_name: Application title and the name shown on the landing page.

    Returns:
        The configured FastAPI application. ``env.close()`` is called when
        the application's lifespan ends.
    """
    env = env_factory() if callable(env_factory) else env_factory

    def _validate_or_422(model_cls: type[BaseModel], data: Dict[str, Any]) -> Any:
        """Validate ``data`` against ``model_cls``, mapping failures to 422.

        A pydantic ``ValidationError`` raised inside a route handler would
        otherwise surface as a 500, even though the client sent a bad body.
        """
        try:
            return model_cls.model_validate(data)
        except ValidationError as exc:
            raise HTTPException(status_code=422, detail=str(exc)) from exc

    @asynccontextmanager
    async def lifespan(_: FastAPI):
        # try/finally so the environment is closed even when the lifespan
        # context is exited with an exception.
        try:
            yield
        finally:
            env.close()

    app = FastAPI(title=env_name, lifespan=lifespan)

    class ResetRequest(BaseModel):
        # Extra keys are allowed and forwarded to env.reset as keyword args.
        model_config = ConfigDict(extra="allow")

        seed: int | None = None
        episode_id: str | None = None

    class StepRequest(BaseModel):
        # Extra keys are allowed and forwarded to env.step as keyword args.
        model_config = ConfigDict(extra="allow")

        action: Dict[str, Any]
        timeout_s: float | None = None

    @app.get("/", response_class=HTMLResponse)
    async def home() -> HTMLResponse:
        return HTMLResponse(_root_page_html(env_name))

    @app.get("/health")
    async def health() -> Dict[str, str]:
        return {"status": "healthy"}

    @app.get("/schema")
    async def schema() -> Dict[str, Any]:
        # The state schema comes from the runtime type of the current state.
        state_cls = env.state.__class__
        return {
            "action": action_cls.model_json_schema(),
            "observation": observation_cls.model_json_schema(),
            "state": state_cls.model_json_schema(),
        }

    @app.get("/state")
    async def state() -> Dict[str, Any]:
        return env.state.model_dump(mode="json")

    @app.post("/reset")
    async def reset(request: Dict[str, Any] = Body(default_factory=dict)) -> Dict[str, Any]:
        validated = _validate_or_422(ResetRequest, request)
        # None-valued entries are dropped; whatever remains after removing
        # the named fields is passed through as extra keyword arguments.
        payload = validated.model_dump(exclude_none=True)
        seed = payload.pop("seed", None)
        episode_id = payload.pop("episode_id", None)
        observation = env.reset(seed=seed, episode_id=episode_id, **payload)
        return _serialize_observation(observation)

    @app.post("/step")
    async def step(request: Dict[str, Any] = Body(...)) -> Dict[str, Any]:
        validated = _validate_or_422(StepRequest, request)
        # The broad catch is kept on purpose: it preserves the existing
        # mapping of any failure while building the action to a 422.
        try:
            action = action_cls.model_validate(validated.action)
        except Exception as exc:  # pragma: no cover - FastAPI will surface this in real path.
            raise HTTPException(status_code=422, detail=str(exc)) from exc
        payload = validated.model_dump(exclude_none=True)
        timeout_s = payload.pop("timeout_s", None)
        # The action was already converted to a model above.
        payload.pop("action", None)
        observation = env.step(action, timeout_s=timeout_s, **payload)
        return _serialize_observation(observation)

    return app