from __future__ import annotations from typing import Any try: from openenv.core.client_types import StepResult except ImportError: from enterprise_finance_env._compat import StepResult from enterprise_finance_env._compat import EnvClient from enterprise_finance_env.models import ( ActionLike, EnterpriseFinanceActionPayload, EnterpriseFinanceObservation, EnterpriseFinanceState, ) class EnterpriseFinanceClient( EnvClient[ EnterpriseFinanceActionPayload, EnterpriseFinanceObservation, EnterpriseFinanceState, ] ): def __init__(self, base_url: str = "http://localhost:7860", **kwargs: Any) -> None: super().__init__(base_url=base_url, **kwargs) async def step(self, action: ActionLike, **kwargs: Any) -> StepResult[EnterpriseFinanceObservation]: payload = ( action if isinstance(action, EnterpriseFinanceActionPayload) else EnterpriseFinanceActionPayload(root=action) ) return await super().step(payload, **kwargs) def _step_payload(self, action: EnterpriseFinanceActionPayload) -> dict[str, Any]: payload = action.model_dump(mode="json") if not isinstance(payload, dict): raise TypeError("Expected action payload to serialize to a dictionary") return payload def _parse_result(self, data: dict[str, Any]) -> StepResult[EnterpriseFinanceObservation]: observation = EnterpriseFinanceObservation(**data["observation"]) result = StepResult( observation=observation, reward=data.get("reward"), done=data.get("done", False), ) if hasattr(result, "info"): setattr(result, "info", data.get("info", {})) return result def _parse_state(self, data: dict[str, Any]) -> EnterpriseFinanceState: return EnterpriseFinanceState(**data)