"""WebSocket client for the CricketCaptain environment."""

from typing import Any, Dict

from openenv.core import EnvClient
from openenv.core.client_types import StepResult

try:
    from .models import CricketAction, CricketObservation, CricketState
except ImportError:
    # Fallback for when this module is imported outside its package.
    from models import CricketAction, CricketObservation, CricketState


class CricketCaptainEnv(EnvClient[CricketAction, CricketObservation, CricketState]):
    """Typed ``EnvClient`` for CricketCaptain.

    Converts ``CricketAction`` objects into step payloads and converts
    response payloads back into ``CricketObservation`` / ``CricketState``
    instances, filling in a default for every key the payload omits.
    """

    def _step_payload(self, action: CricketAction) -> Dict[str, Any]:
        """Build the step request body from an action.

        Args:
            action: The action to send; only its ``tool`` and ``arguments``
                attributes are used.

        Returns:
            A dict with exactly the keys ``"tool"`` and ``"arguments"``.
        """
        return dict(tool=action.tool, arguments=action.arguments)

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult[CricketObservation]:
        """Turn a step response payload into a ``StepResult``.

        Observation fields are taken from ``payload["observation"]`` when that
        key exists, otherwise from ``payload`` itself. ``done`` and ``reward``
        are always taken from the top level of ``payload``. Absent keys fall
        back to the defaults written below.

        Args:
            payload: The decoded response payload.

        Returns:
            A ``StepResult`` wrapping the parsed observation, with the same
            ``reward`` and ``done`` values that were placed on the observation.
        """
        obs = payload.get("observation", payload)
        read = obs.get

        # Read the top-level step outcome once; it is reused on both the
        # observation and the StepResult.
        is_done = payload.get("done", False)
        step_reward = payload.get("reward")

        observation = CricketObservation(
            game_context=read("game_context", {}),
            declared_strategy=read("declared_strategy", {}),
            bowling_strategy=read("bowling_strategy", {}),
            field_setting=read("field_setting", "Balanced"),
            strategic_phase=read("strategic_phase", "pre_ball"),
            current_batter=read("current_batter", {}),
            current_bowler=read("current_bowler", {}),
            opponent_context=read("opponent_context", {}),
            opponent_plan=read("opponent_plan", {}),
            last_outcome=read("last_outcome", {}),
            eval_pack_id=read("eval_pack_id", "default"),
            game_state=read("game_state", "batting"),
            available_tools=read("available_tools", []),
            tool_history=read("tool_history", []),
            last_ball_result=read("last_ball_result", ""),
            prompt_text=read("prompt_text", ""),
            target=read("target"),
            innings_type=read("innings_type", "first"),
            curriculum_stage=read("curriculum_stage", 1),
            done=is_done,
            reward=step_reward,
            metadata=read("metadata", {}),
        )
        return StepResult(observation=observation, reward=step_reward, done=is_done)

    def _parse_state(self, payload: Dict[str, Any]) -> CricketState:
        """Turn a state payload into a ``CricketState``.

        Every field is read from the top level of ``payload``. A key that is
        absent takes the default written beside it (``None`` where no default
        is given). Container defaults are fresh objects on every call.

        Args:
            payload: The decoded state payload.

        Returns:
            The populated ``CricketState``.
        """
        get = payload.get
        return CricketState(
            episode_id=get("episode_id"),
            step_count=get("step_count", 0),
            game_state=get("game_state", "batting"),
            phase=get("phase", "powerplay"),
            total_score=get("total_score", 0),
            wickets_lost=get("wickets_lost", 0),
            over=get("over", 0),
            ball=get("ball", 0),
            target=get("target"),
            match_result=get("match_result", ""),
            reward_breakdown=get("reward_breakdown", {}),
            innings_rewards=get("innings_rewards", []),
            toss_winner=get("toss_winner"),
            toss_decision=get("toss_decision"),
            innings_type=get("innings_type", "first"),
            coherence_scores=get("coherence_scores", []),
            adaptation_scores=get("adaptation_scores", []),
            opponent_awareness_scores=get("opponent_awareness_scores", []),
            regret_scores=get("regret_scores", []),
            tool_calls_made=get("tool_calls_made", 0),
            analyze_calls=get("analyze_calls", []),
            tool_history=get("tool_history", []),
            transcript=get("transcript", []),
            dls_par=get("dls_par", 250.0),
            strategy_changes=get("strategy_changes", 0),
            last_strategy_set_over=get("last_strategy_set_over", -1),
            last_reflection=get("last_reflection", ""),
            eval_pack_id=get("eval_pack_id", "default"),
            opponent_mode=get("opponent_mode", "heuristic"),
            strategic_phase=get("strategic_phase", "pre_ball"),
            shot_plan=get("shot_plan", {}),
            delivery_plan=get("delivery_plan", {}),
            opponent_plan=get("opponent_plan", {}),
            last_outcome=get("last_outcome", {}),
            current_batter=get("current_batter", {}),
            current_bowler=get("current_bowler", {}),
            is_done=get("is_done", False),
            curriculum_stage=get("curriculum_stage", 1),
            max_overs=get("max_overs", 50),
            match_plan=get("match_plan", {}),
            plan_commitment_scores=get("plan_commitment_scores", []),
            plan_staleness_penalties=get("plan_staleness_penalties", []),
            plan_freshness_scores=get("plan_freshness_scores", []),
        )