Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| from typing import Any | |
| try: | |
| from openenv.core.client_types import StepResult | |
| except ImportError: | |
| from enterprise_finance_env._compat import StepResult | |
| from enterprise_finance_env._compat import EnvClient | |
| from enterprise_finance_env.models import ( | |
| ActionLike, | |
| EnterpriseFinanceActionPayload, | |
| EnterpriseFinanceObservation, | |
| EnterpriseFinanceState, | |
| ) | |
class EnterpriseFinanceClient(
    EnvClient[
        EnterpriseFinanceActionPayload,
        EnterpriseFinanceObservation,
        EnterpriseFinanceState,
    ]
):
    """Client for the enterprise finance environment server.

    Specializes the generic ``EnvClient`` with the enterprise-finance action
    payload, observation, and state types. ``step`` accepts either a
    ready-made ``EnterpriseFinanceActionPayload`` or a raw ``ActionLike``
    value, which it wraps before delegating to the base client.
    """

    def __init__(self, base_url: str = "http://localhost:7860", **kwargs: Any) -> None:
        """Create a client targeting ``base_url``.

        Args:
            base_url: Root URL of the environment server. Defaults to
                ``"http://localhost:7860"``.
            **kwargs: Forwarded unchanged to ``EnvClient.__init__``.
        """
        super().__init__(base_url=base_url, **kwargs)

    async def step(
        self,
        action: ActionLike,
        **kwargs: Any,
    ) -> StepResult[EnterpriseFinanceObservation]:
        """Send one action to the environment and return the step result.

        Args:
            action: An ``EnterpriseFinanceActionPayload`` (sent as-is), or any
                other action value, which is wrapped as the payload's ``root``.
            **kwargs: Forwarded unchanged to ``EnvClient.step``.

        Returns:
            Whatever ``EnvClient.step`` returns for the wrapped payload.
        """
        # Normalize raw actions so the base client always receives the payload model.
        payload = (
            action
            if isinstance(action, EnterpriseFinanceActionPayload)
            else EnterpriseFinanceActionPayload(root=action)
        )
        return await super().step(payload, **kwargs)

    def _step_payload(self, action: EnterpriseFinanceActionPayload) -> dict[str, Any]:
        """Serialize ``action`` into the JSON-compatible body for a step request.

        Args:
            action: The payload model to serialize.

        Returns:
            The result of ``action.model_dump(mode="json")``.

        Raises:
            TypeError: If the serialized payload is not a dict.
        """
        payload = action.model_dump(mode="json")
        if not isinstance(payload, dict):
            raise TypeError("Expected action payload to serialize to a dictionary")
        return payload

    def _parse_result(self, data: dict[str, Any]) -> StepResult[EnterpriseFinanceObservation]:
        """Build a ``StepResult`` from a decoded step response.

        Args:
            data: Decoded response body. Must contain an ``"observation"``
                mapping. ``"reward"``, ``"done"`` and ``"info"`` are optional.

        Returns:
            A ``StepResult`` whose ``done`` is always a bool (a missing or
            null ``"done"`` becomes ``False``). If the result object exposes
            an ``info`` attribute, it is set to the response's ``"info"``
            value, or ``{}`` when that value is missing or null.

        Raises:
            KeyError: If ``data`` has no ``"observation"`` key.
        """
        observation = EnterpriseFinanceObservation(**data["observation"])
        result = StepResult(
            observation=observation,
            reward=data.get("reward"),
            # An explicit JSON null must not leak through as ``done=None``.
            done=bool(data.get("done", False)),
        )
        # Not every StepResult implementation necessarily declares ``info``
        # (the imported class may come from a fallback module), so only set
        # it when the attribute exists. setattr is presumably used so static
        # type checkers do not flag an undeclared attribute.
        if hasattr(result, "info"):
            setattr(result, "info", data.get("info") or {})
        return result

    def _parse_state(self, data: dict[str, Any]) -> EnterpriseFinanceState:
        """Build an ``EnterpriseFinanceState`` from a decoded state response.

        Args:
            data: Decoded response body, passed as keyword arguments to
                ``EnterpriseFinanceState``.
        """
        return EnterpriseFinanceState(**data)