# Prasham1710's picture
# first commit
# 4ccc966
from __future__ import annotations
from typing import Any
try:
from openenv.core.client_types import StepResult
except ImportError:
from enterprise_finance_env._compat import StepResult
from enterprise_finance_env._compat import EnvClient
from enterprise_finance_env.models import (
ActionLike,
EnterpriseFinanceActionPayload,
EnterpriseFinanceObservation,
EnterpriseFinanceState,
)
class EnterpriseFinanceClient(
    EnvClient[
        EnterpriseFinanceActionPayload,
        EnterpriseFinanceObservation,
        EnterpriseFinanceState,
    ]
):
    """Environment client specialised for the enterprise finance models.

    Parameterises ``EnvClient`` with ``EnterpriseFinanceActionPayload`` as
    the action type, ``EnterpriseFinanceObservation`` as the observation
    type and ``EnterpriseFinanceState`` as the state type.
    """

    def __init__(self, base_url: str = "http://localhost:7860", **kwargs: Any) -> None:
        """Initialise the client.

        Args:
            base_url: Address handed to the base client; defaults to
                ``http://localhost:7860``.
            **kwargs: Forwarded untouched to ``EnvClient.__init__``.
        """
        super().__init__(base_url=base_url, **kwargs)

    async def step(self, action: ActionLike, **kwargs: Any) -> StepResult[EnterpriseFinanceObservation]:
        """Submit one action through the base client's ``step``.

        Args:
            action: Either a ready-made ``EnterpriseFinanceActionPayload`` or
                a value that is wrapped into one via its ``root`` field.
            **kwargs: Forwarded untouched to the base ``step``.

        Returns:
            Whatever the base client's ``step`` coroutine resolves to.
        """
        if isinstance(action, EnterpriseFinanceActionPayload):
            wrapped = action
        else:
            # Anything that is not already a payload is wrapped as its root.
            wrapped = EnterpriseFinanceActionPayload(root=action)
        return await super().step(wrapped, **kwargs)

    def _step_payload(self, action: EnterpriseFinanceActionPayload) -> dict[str, Any]:
        """Serialise an action payload into a dictionary.

        Args:
            action: Payload whose ``model_dump(mode="json")`` output is used.

        Returns:
            The dumped payload.

        Raises:
            TypeError: If the dumped payload is not a ``dict``.
        """
        serialized = action.model_dump(mode="json")
        if isinstance(serialized, dict):
            return serialized
        raise TypeError("Expected action payload to serialize to a dictionary")

    def _parse_result(self, data: dict[str, Any]) -> StepResult[EnterpriseFinanceObservation]:
        """Build a ``StepResult`` from a response dictionary.

        Args:
            data: Response mapping. ``"observation"`` is required; ``"reward"``
                falls back to ``None``, ``"done"`` to ``False`` and ``"info"``
                to an empty dict when absent.

        Returns:
            The step result holding the parsed observation, reward and done
            flag.

        Raises:
            KeyError: If ``data`` has no ``"observation"`` entry.
        """
        obs = EnterpriseFinanceObservation(**data["observation"])
        reward = data.get("reward")
        done = data.get("done", False)
        step_result = StepResult(observation=obs, reward=reward, done=done)
        if not hasattr(step_result, "info"):
            # The imported StepResult variant has no ``info`` attribute;
            # leave the result as constructed.
            return step_result
        step_result.info = data.get("info", {})
        return step_result

    def _parse_state(self, data: dict[str, Any]) -> EnterpriseFinanceState:
        """Build an ``EnterpriseFinanceState`` from ``data``'s key/value pairs."""
        state = EnterpriseFinanceState(**data)
        return state