"""Subtext Arena client.""" from __future__ import annotations from typing import Dict from openenv.core import EnvClient from openenv.core.client_types import StepResult from openenv.core.env_server.types import State from .models import SubtextArenaAction, SubtextArenaObservation class SubtextArenaEnv( EnvClient[SubtextArenaAction, SubtextArenaObservation, State] ): """Client for Subtext Arena. Example: >>> with SubtextArenaEnv(base_url="http://localhost:8000") as env: ... obs = env.reset().observation ... # Inspect transcript first ... obs = env.step(SubtextArenaAction(tool="get_transcript")).observation ... # Then check prosody on the full clip ... obs = env.step(SubtextArenaAction( ... tool="get_prosody_features", ... tool_args={}, ... )).observation ... # Submit a belief ... result = env.step(SubtextArenaAction( ... tool="submit_belief", ... tool_args={"label": "sarcastic", "confidence": 0.85}, ... )) ... print("done:", result.done, "reward:", result.reward) """ def _step_payload(self, action: SubtextArenaAction) -> Dict: return { "tool": action.tool, "tool_args": action.tool_args or {}, } def _parse_result(self, payload: Dict) -> StepResult[SubtextArenaObservation]: obs_data = payload.get("observation", {}) or {} observation = SubtextArenaObservation( clip_id=obs_data.get("clip_id", ""), speaker=obs_data.get("speaker", ""), duration_s=float(obs_data.get("duration_s", 0.0) or 0.0), is_pivot=bool(obs_data.get("is_pivot", False)), tool_used=obs_data.get("tool_used", ""), tool_output=obs_data.get("tool_output", ""), step=int(obs_data.get("step", 0) or 0), max_steps=int(obs_data.get("max_steps", 6) or 6), audio_calls_so_far=int(obs_data.get("audio_calls_so_far", 0) or 0), error=obs_data.get("error"), done=payload.get("done", False), reward=payload.get("reward"), metadata=obs_data.get("metadata", {}) or {}, ) return StepResult( observation=observation, reward=payload.get("reward"), done=payload.get("done", False), ) def _parse_state(self, payload: Dict) -> State: return State( episode_id=payload.get("episode_id"), step_count=int(payload.get("step_count", 0) or 0), )