from __future__ import annotations

from abc import ABC, abstractmethod
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Generic, Optional, TypeVar

import httpx
from fastapi import Body, FastAPI, HTTPException
from fastapi.responses import HTMLResponse
from pydantic import BaseModel, ConfigDict, Field, ValidationError


# Base class for environment actions. Unknown fields are rejected
# (extra="forbid") and attribute assignments are re-validated.
# NOTE: documented with comments rather than a docstring because pydantic
# copies class docstrings into the JSON schema served by /schema.
class Action(BaseModel):
    model_config = ConfigDict(
        extra="forbid",
        validate_assignment=True,
        arbitrary_types_allowed=True,
    )

    # Free-form auxiliary data attached to the action.
    metadata: Dict[str, Any] = Field(default_factory=dict)


# Base class for observations returned by Environment.reset/step.
# `done` and `reward` are lifted to the top level of HTTP responses by
# _serialize_observation, which also drops `metadata` from the payload.
class Observation(BaseModel):
    model_config = ConfigDict(
        extra="forbid",
        validate_assignment=True,
        arbitrary_types_allowed=True,
    )

    done: bool = Field(default=False)
    reward: float | None = Field(default=None)
    metadata: Dict[str, Any] = Field(default_factory=dict)


# Base class for environment state served by /state. Unlike Action and
# Observation, extra fields are allowed.
class State(BaseModel):
    model_config = ConfigDict(
        extra="allow",
        validate_assignment=True,
        arbitrary_types_allowed=True,
    )

    episode_id: Optional[str] = Field(default=None)
    # Number of steps taken; validated to be non-negative.
    step_count: int = Field(default=0, ge=0)


ActT = TypeVar("ActT", bound=BaseModel)
ObsT = TypeVar("ObsT", bound=Observation)
StateT = TypeVar("StateT", bound=State)


@dataclass
class StepResult(Generic[ObsT]):
    """Client-side result of a ``reset``/``step`` call.

    Attributes:
        observation: The parsed observation.
        reward: Reward for the transition, or ``None`` if not provided.
        done: Whether the episode has ended.
        info: Free-form extra data.
    """

    observation: ObsT
    reward: float | None = None
    done: bool = False
    info: Dict[str, Any] = field(default_factory=dict)


class Environment(ABC, Generic[ActT, ObsT, StateT]):
    """Abstract server-side environment served over HTTP by ``create_app``.

    Subclasses implement ``reset``, ``step`` and the ``state`` property;
    ``close`` is an optional hook and does nothing by default.
    """

    def __init__(self) -> None:
        super().__init__()

    @abstractmethod
    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> ObsT:
        """Start a new episode and return its initial observation.

        Extra keyword arguments are whatever extra keys the ``/reset``
        request body carried.
        """
        raise NotImplementedError

    @abstractmethod
    def step(
        self,
        action: ActT,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> ObsT:
        """Apply ``action`` and return the resulting observation.

        Extra keyword arguments are whatever extra keys the ``/step``
        request body carried.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def state(self) -> StateT:
        """Current environment state (served by ``/state``)."""
        raise NotImplementedError

    def close(self) -> None:
        """Release resources; called when the app's lifespan ends. No-op by default."""
        return None


class EnvClient(ABC, Generic[ActT, ObsT, StateT]):
    """Async HTTP client for a server built with ``create_app``.

    Subclasses implement the (de)serialization hooks ``_step_payload``,
    ``_parse_result`` and ``_parse_state``. Instances are async context
    managers; the underlying ``httpx.AsyncClient`` is closed on exit only if
    this instance created it.

    Args:
        base_url: Server root URL; a trailing ``/`` is stripped.
        timeout_s: Request timeout for the internally created client (ignored
            when ``async_client`` is given).
        async_client: Optional pre-configured client to use instead of
            creating one. It is not closed by ``aclose``.
    """

    def __init__(
        self,
        base_url: str,
        *,
        timeout_s: float = 30.0,
        async_client: httpx.AsyncClient | None = None,
    ) -> None:
        self.base_url = base_url.rstrip("/")
        self._owns_client = async_client is None
        if async_client is None:
            async_client = httpx.AsyncClient(
                base_url=self.base_url,
                timeout=timeout_s,
            )
        self._client = async_client

    def _url(self, path: str) -> str:
        """Return the request URL for ``path`` under ``self.base_url``.

        Building it explicitly keeps ``base_url`` effective even when an
        injected ``async_client`` was created without one. If ``base_url`` is
        empty the result stays relative, so httpx falls back to the client's
        own ``base_url``.
        """
        return f"{self.base_url}{path}"

    async def __aenter__(self) -> "EnvClient[ActT, ObsT, StateT]":
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        await self.aclose()

    async def aclose(self) -> None:
        """Close the HTTP client if this instance created it."""
        if self._owns_client:
            await self._client.aclose()

    async def close(self) -> None:
        """Alias for ``aclose``."""
        await self.aclose()

    async def reset(
        self,
        seed: int | None = None,
        episode_id: str | None = None,
        **kwargs: Any,
    ) -> StepResult[ObsT]:
        """POST ``/reset`` and parse the response via ``_parse_result``.

        ``seed``/``episode_id`` are sent only when not ``None``; extra keyword
        arguments are merged into the JSON body.

        Raises:
            httpx.HTTPStatusError: If the server returns an error status.
        """
        payload: Dict[str, Any] = {}
        if seed is not None:
            payload["seed"] = seed
        if episode_id is not None:
            payload["episode_id"] = episode_id
        payload.update(kwargs)
        response = await self._client.post(self._url("/reset"), json=payload)
        response.raise_for_status()
        return self._parse_result(response.json())

    async def step(self, action: ActT, **kwargs: Any) -> StepResult[ObsT]:
        """POST ``/step`` with the serialized ``action`` and parse the response.

        Extra keyword arguments (e.g. ``timeout_s``) are merged into the body.

        Raises:
            httpx.HTTPStatusError: If the server returns an error status.
        """
        payload: Dict[str, Any] = {"action": self._step_payload(action)}
        payload.update(kwargs)
        response = await self._client.post(self._url("/step"), json=payload)
        response.raise_for_status()
        return self._parse_result(response.json())

    async def state(self) -> StateT:
        """GET ``/state`` and parse the response via ``_parse_state``.

        Raises:
            httpx.HTTPStatusError: If the server returns an error status.
        """
        response = await self._client.get(self._url("/state"))
        response.raise_for_status()
        return self._parse_state(response.json())

    @abstractmethod
    def _step_payload(self, action: ActT) -> Dict[str, Any]:
        """Convert ``action`` into the JSON object sent as ``"action"``."""
        raise NotImplementedError

    @abstractmethod
    def _parse_result(self, data: Dict[str, Any]) -> StepResult[ObsT]:
        """Convert a ``/reset`` or ``/step`` JSON response into a ``StepResult``.

        A ``create_app`` server responds with ``observation``, ``reward`` and
        ``done`` keys (see ``_serialize_observation``).
        """
        raise NotImplementedError

    @abstractmethod
    def _parse_state(self, data: Dict[str, Any]) -> StateT:
        """Convert a ``/state`` JSON response into a state object."""
        raise NotImplementedError


def _serialize_observation(observation: Observation) -> Dict[str, Any]:
    """Convert an observation into the JSON response body for /reset and /step.

    ``reward`` and ``done`` are moved to the top level; ``metadata`` is
    omitted entirely.
    """
    return {
        "observation": observation.model_dump(
            exclude={"reward", "done", "metadata"},
            mode="json",
        ),
        "reward": observation.reward,
        "done": observation.done,
    }


def _root_page_html(env_name: str) -> str:
    """Return the content of the manual-testing page served at ``/``."""
    return f""" {env_name}

Enterprise Finance OpenEnv

Interactive console for testing /reset, /step, /state, and score outputs directly from the deployed environment.

Use the API endpoints for agent evaluation. This page is a manual playground for humans.

Episode Controls

Difficulty-
Step Count-
Open Balance-
Score-
Total Reward-
Unmatched-

Submit Action

Paste a JSON action body matching the environment schema.

Latest API Response

Current State

"""


def create_app(
    env_factory: Callable[[], Environment[Any, Observation, State]]
    | type[Environment[Any, Observation, State]],
    action_cls: type[BaseModel],
    observation_cls: type[Observation],
    env_name: str = "environment",
) -> FastAPI:
    """Build a FastAPI app that serves one shared environment instance.

    Args:
        env_factory: Zero-argument callable (or ``Environment`` subclass)
            invoked once, at app-construction time, to create the environment.
            A non-callable value is used as the environment directly.
        action_cls: Model used to validate the ``action`` object of ``/step``.
        observation_cls: Model whose JSON schema is reported by ``/schema``.
        env_name: App title; also interpolated into the ``/`` HTML page.

    Returns:
        The configured application. ``env.close()`` runs when the app's
        lifespan ends.

    Request bodies that fail validation are answered with HTTP 422 rather
    than escaping the handler as an unhandled ``ValidationError`` (HTTP 500).
    """
    env = env_factory() if callable(env_factory) else env_factory

    @asynccontextmanager
    async def lifespan(_: FastAPI):
        try:
            yield
        finally:
            # Release environment resources even if shutdown is abnormal.
            env.close()

    app = FastAPI(title=env_name, lifespan=lifespan)

    # Body of POST /reset; extra keys are forwarded to env.reset as kwargs.
    class ResetRequest(BaseModel):
        model_config = ConfigDict(extra="allow")

        seed: int | None = None
        episode_id: str | None = None

    # Body of POST /step; extra keys are forwarded to env.step as kwargs.
    class StepRequest(BaseModel):
        model_config = ConfigDict(extra="allow")

        action: Dict[str, Any]
        timeout_s: float | None = None

    def _validate(model_cls: type[BaseModel], data: Any) -> Any:
        """Validate ``data`` against ``model_cls``, mapping failures to HTTP 422."""
        try:
            return model_cls.model_validate(data)
        except ValidationError as exc:
            raise HTTPException(status_code=422, detail=str(exc)) from exc

    # NOTE: the handlers below are `async def` and call the synchronous env
    # methods directly, so a slow env.reset/env.step blocks the event loop
    # for other requests while it runs.

    # Manual-testing HTML page.
    @app.get("/", response_class=HTMLResponse)
    async def home() -> HTMLResponse:
        return HTMLResponse(_root_page_html(env_name))

    # Liveness probe.
    @app.get("/health")
    async def health() -> Dict[str, str]:
        return {"status": "healthy"}

    # JSON schemas for the action, observation and current state models.
    @app.get("/schema")
    async def schema() -> Dict[str, Any]:
        state_cls = env.state.__class__
        return {
            "action": action_cls.model_json_schema(),
            "observation": observation_cls.model_json_schema(),
            "state": state_cls.model_json_schema(),
        }

    # Current environment state as JSON.
    @app.get("/state")
    async def state() -> Dict[str, Any]:
        return env.state.model_dump(mode="json")

    # Start a new episode; an omitted body is treated as {}.
    @app.post("/reset")
    async def reset(request: Dict[str, Any] = Body(default_factory=dict)) -> Dict[str, Any]:
        validated = _validate(ResetRequest, request)
        # exclude_none drops unset seed/episode_id as well as None-valued extras.
        payload = validated.model_dump(exclude_none=True)
        seed = payload.pop("seed", None)
        episode_id = payload.pop("episode_id", None)
        observation = env.reset(seed=seed, episode_id=episode_id, **payload)
        return _serialize_observation(observation)

    # Apply one action; the body must contain an "action" object.
    @app.post("/step")
    async def step(request: Dict[str, Any] = Body(...)) -> Dict[str, Any]:
        validated = _validate(StepRequest, request)
        action = _validate(action_cls, validated.action)
        payload = validated.model_dump(exclude_none=True)
        timeout_s = payload.pop("timeout_s", None)
        # The validated action is passed positionally; drop the raw dict.
        payload.pop("action", None)
        observation = env.step(action, timeout_s=timeout_s, **payload)
        return _serialize_observation(observation)

    return app