Spaces:
Sleeping
Sleeping
File size: 1,907 Bytes
4ccc966 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 | from __future__ import annotations
from typing import Any
try:
from openenv.core.client_types import StepResult
except ImportError:
from enterprise_finance_env._compat import StepResult
from enterprise_finance_env._compat import EnvClient
from enterprise_finance_env.models import (
ActionLike,
EnterpriseFinanceActionPayload,
EnterpriseFinanceObservation,
EnterpriseFinanceState,
)
class EnterpriseFinanceClient(
    EnvClient[
        EnterpriseFinanceActionPayload,
        EnterpriseFinanceObservation,
        EnterpriseFinanceState,
    ]
):
    """Typed client for the enterprise-finance environment server.

    Specialises the generic ``EnvClient`` with this environment's action
    payload, observation and state models.
    """

    def __init__(self, base_url: str = "http://localhost:7860", **kwargs: Any) -> None:
        """Create a client targeting *base_url*.

        Args:
            base_url: Server address; defaults to ``http://localhost:7860``.
            **kwargs: Forwarded unchanged to ``EnvClient.__init__``.
        """
        super().__init__(base_url=base_url, **kwargs)

    async def step(
        self,
        action: ActionLike,
        **kwargs: Any,
    ) -> StepResult[EnterpriseFinanceObservation]:
        """Send one action to the environment and return the parsed result.

        Args:
            action: Either a ready-made ``EnterpriseFinanceActionPayload`` or
                anything accepted as that payload's ``root`` value.
            **kwargs: Forwarded unchanged to ``EnvClient.step``.

        Returns:
            The ``StepResult`` produced by the base client.
        """
        # Normalise loose action values into the payload wrapper the base
        # client is parameterised on; pre-wrapped payloads pass straight through.
        if not isinstance(action, EnterpriseFinanceActionPayload):
            action = EnterpriseFinanceActionPayload(root=action)
        return await super().step(action, **kwargs)

    def _step_payload(self, action: EnterpriseFinanceActionPayload) -> dict[str, Any]:
        """Serialise *action* to a JSON-compatible dict for the request body.

        Raises:
            TypeError: If the JSON-mode dump of the payload is not a dict.
        """
        serialized = action.model_dump(mode="json")
        if isinstance(serialized, dict):
            return serialized
        raise TypeError("Expected action payload to serialize to a dictionary")

    def _parse_result(self, data: dict[str, Any]) -> StepResult[EnterpriseFinanceObservation]:
        """Build a ``StepResult`` from a step response dict.

        Reads ``data["observation"]`` (required) plus the optional ``reward``
        (default ``None``) and ``done`` (default ``False``) keys.
        """
        parsed = StepResult(
            observation=EnterpriseFinanceObservation(**data["observation"]),
            reward=data.get("reward"),
            done=data.get("done", False),
        )
        # Only attach ``info`` when the StepResult type in use exposes that
        # attribute; the imported StepResult may come from either of two modules.
        if hasattr(parsed, "info"):
            setattr(parsed, "info", data.get("info", {}))
        return parsed

    def _parse_state(self, data: dict[str, Any]) -> EnterpriseFinanceState:
        """Build an ``EnterpriseFinanceState`` by passing *data*'s keys as fields."""
        return EnterpriseFinanceState(**data)
|