| """WebSocket client for the CricketCaptain environment.""" |
|
|
| from typing import Any, Dict |
|
|
| from openenv.core import EnvClient |
| from openenv.core.client_types import StepResult |
|
|
| try: |
| from .models import CricketAction, CricketObservation, CricketState |
| except ImportError: |
| from models import CricketAction, CricketObservation, CricketState |
|
|
|
|
class CricketCaptainEnv(EnvClient[CricketAction, CricketObservation, CricketState]):
    """Client for the CricketCaptain environment.

    Converts between the typed ``CricketAction`` / ``CricketObservation`` /
    ``CricketState`` models and the plain-dict payloads exchanged with the
    server. The three underscore-prefixed methods are the hooks this class
    supplies to ``EnvClient`` (presumably invoked by the base class when
    sending actions and decoding responses — confirm against openenv).
    """

    def _step_payload(self, action: CricketAction) -> Dict[str, Any]:
        """Serialize ``action`` into the request body for a step.

        Args:
            action: The tool invocation to send.

        Returns:
            A dict holding the action's ``tool`` name under ``"tool"`` and its
            ``arguments`` under ``"arguments"``.
        """
        return {"tool": action.tool, "arguments": action.arguments}

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult[CricketObservation]:
        """Decode a server response into a ``StepResult``.

        Observation fields are read from ``payload["observation"]`` when that
        key holds a non-null value; otherwise they are read from ``payload``
        itself (flat format). ``reward`` and ``done`` are always read from the
        top level of ``payload`` and are copied onto both the observation and
        the returned result. Any missing observation field falls back to the
        default given below.

        Args:
            payload: Decoded response dict from the server.

        Returns:
            A ``StepResult`` wrapping the parsed ``CricketObservation``.
        """
        obs_data = payload.get("observation")
        if obs_data is None:
            # Treat an absent key and an explicit ``null`` alike: fall back to
            # the flat payload. An explicit null would otherwise fail on .get().
            obs_data = payload

        # Read once so the observation and the StepResult always agree.
        reward = payload.get("reward")
        done = payload.get("done", False)

        observation = CricketObservation(
            game_context=obs_data.get("game_context", {}),
            declared_strategy=obs_data.get("declared_strategy", {}),
            bowling_strategy=obs_data.get("bowling_strategy", {}),
            field_setting=obs_data.get("field_setting", "Balanced"),
            strategic_phase=obs_data.get("strategic_phase", "pre_ball"),
            current_batter=obs_data.get("current_batter", {}),
            current_bowler=obs_data.get("current_bowler", {}),
            opponent_context=obs_data.get("opponent_context", {}),
            opponent_plan=obs_data.get("opponent_plan", {}),
            last_outcome=obs_data.get("last_outcome", {}),
            eval_pack_id=obs_data.get("eval_pack_id", "default"),
            game_state=obs_data.get("game_state", "batting"),
            available_tools=obs_data.get("available_tools", []),
            tool_history=obs_data.get("tool_history", []),
            last_ball_result=obs_data.get("last_ball_result", ""),
            prompt_text=obs_data.get("prompt_text", ""),
            target=obs_data.get("target"),
            innings_type=obs_data.get("innings_type", "first"),
            curriculum_stage=obs_data.get("curriculum_stage", 1),
            done=done,
            reward=reward,
            metadata=obs_data.get("metadata", {}),
        )
        return StepResult(
            observation=observation,
            reward=reward,
            done=done,
        )

    def _parse_state(self, payload: Dict[str, Any]) -> CricketState:
        """Build a ``CricketState`` from a flat state payload.

        Every field is read directly from the top level of ``payload``; any
        missing key falls back to the default given below (``None`` for
        ``episode_id``, ``target``, ``toss_winner`` and ``toss_decision``).

        Args:
            payload: Decoded state dict from the server.

        Returns:
            The populated ``CricketState``.
        """
        return CricketState(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
            game_state=payload.get("game_state", "batting"),
            phase=payload.get("phase", "powerplay"),
            total_score=payload.get("total_score", 0),
            wickets_lost=payload.get("wickets_lost", 0),
            over=payload.get("over", 0),
            ball=payload.get("ball", 0),
            target=payload.get("target"),
            match_result=payload.get("match_result", ""),
            reward_breakdown=payload.get("reward_breakdown", {}),
            innings_rewards=payload.get("innings_rewards", []),
            toss_winner=payload.get("toss_winner"),
            toss_decision=payload.get("toss_decision"),
            innings_type=payload.get("innings_type", "first"),
            coherence_scores=payload.get("coherence_scores", []),
            adaptation_scores=payload.get("adaptation_scores", []),
            opponent_awareness_scores=payload.get("opponent_awareness_scores", []),
            regret_scores=payload.get("regret_scores", []),
            tool_calls_made=payload.get("tool_calls_made", 0),
            analyze_calls=payload.get("analyze_calls", []),
            tool_history=payload.get("tool_history", []),
            transcript=payload.get("transcript", []),
            dls_par=payload.get("dls_par", 250.0),
            strategy_changes=payload.get("strategy_changes", 0),
            last_strategy_set_over=payload.get("last_strategy_set_over", -1),
            last_reflection=payload.get("last_reflection", ""),
            eval_pack_id=payload.get("eval_pack_id", "default"),
            opponent_mode=payload.get("opponent_mode", "heuristic"),
            strategic_phase=payload.get("strategic_phase", "pre_ball"),
            shot_plan=payload.get("shot_plan", {}),
            delivery_plan=payload.get("delivery_plan", {}),
            opponent_plan=payload.get("opponent_plan", {}),
            last_outcome=payload.get("last_outcome", {}),
            current_batter=payload.get("current_batter", {}),
            current_bowler=payload.get("current_bowler", {}),
            is_done=payload.get("is_done", False),
            curriculum_stage=payload.get("curriculum_stage", 1),
            max_overs=payload.get("max_overs", 50),
            match_plan=payload.get("match_plan", {}),
            plan_commitment_scores=payload.get("plan_commitment_scores", []),
            plan_staleness_penalties=payload.get("plan_staleness_penalties", []),
            plan_freshness_scores=payload.get("plan_freshness_scores", []),
        )
|
|