cricket-captain-llm / client.py
pratinavseth's picture
feat: all-12-tool UI, auto-play, timestamped run folders, model fix
86a4911
"""WebSocket client for the CricketCaptain environment."""
from typing import Any, Dict
from openenv.core import EnvClient
from openenv.core.client_types import StepResult
try:
from .models import CricketAction, CricketObservation, CricketState
except ImportError:
from models import CricketAction, CricketObservation, CricketState
class CricketCaptainEnv(EnvClient[CricketAction, CricketObservation, CricketState]):
    """Client for the CricketCaptain environment.

    Converts between the typed action / observation / state models and the
    plain-dict payloads handled by the underlying ``EnvClient``.
    """

    def _step_payload(self, action: CricketAction) -> Dict[str, Any]:
        """Return the dict form of ``action``: its tool name and arguments."""
        body: Dict[str, Any] = {}
        body["tool"] = action.tool
        body["arguments"] = action.arguments
        return body

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult[CricketObservation]:
        """Build a ``StepResult`` from a step/reset response payload.

        Observation fields are read from ``payload["observation"]`` when that
        key is present, otherwise from ``payload`` itself. ``done`` and
        ``reward`` are always read from the top level of ``payload``.
        Missing fields fall back to the defaults listed below.
        """
        source = payload.get("observation", payload)

        # Rebuilt on every call so that each observation gets its own fresh
        # dict/list defaults rather than sharing one mutable object.
        fallbacks: Dict[str, Any] = {
            "game_context": {},
            "declared_strategy": {},
            "bowling_strategy": {},
            "field_setting": "Balanced",
            "strategic_phase": "pre_ball",
            "current_batter": {},
            "current_bowler": {},
            "opponent_context": {},
            "opponent_plan": {},
            "last_outcome": {},
            "eval_pack_id": "default",
            "game_state": "batting",
            "available_tools": [],
            "tool_history": [],
            "last_ball_result": "",
            "prompt_text": "",
            "target": None,
            "innings_type": "first",
            "curriculum_stage": 1,
            "metadata": {},
        }
        fields = {
            key: source.get(key, fallback) for key, fallback in fallbacks.items()
        }

        finished = payload.get("done", False)
        score = payload.get("reward")

        observation = CricketObservation(done=finished, reward=score, **fields)
        return StepResult(observation=observation, reward=score, done=finished)

    def _parse_state(self, payload: Dict[str, Any]) -> CricketState:
        """Build a ``CricketState`` from a state response payload.

        Every field is read from the top level of ``payload``; a missing key
        falls back to the default listed below.
        """
        # Rebuilt on every call so mutable defaults are never shared between
        # two state objects.
        fallbacks: Dict[str, Any] = {
            "episode_id": None,
            "step_count": 0,
            "game_state": "batting",
            "phase": "powerplay",
            "total_score": 0,
            "wickets_lost": 0,
            "over": 0,
            "ball": 0,
            "target": None,
            "match_result": "",
            "reward_breakdown": {},
            "innings_rewards": [],
            "toss_winner": None,
            "toss_decision": None,
            "innings_type": "first",
            "coherence_scores": [],
            "adaptation_scores": [],
            "opponent_awareness_scores": [],
            "regret_scores": [],
            "tool_calls_made": 0,
            "analyze_calls": [],
            "tool_history": [],
            "transcript": [],
            "dls_par": 250.0,
            "strategy_changes": 0,
            "last_strategy_set_over": -1,
            "last_reflection": "",
            "eval_pack_id": "default",
            "opponent_mode": "heuristic",
            "strategic_phase": "pre_ball",
            "shot_plan": {},
            "delivery_plan": {},
            "opponent_plan": {},
            "last_outcome": {},
            "current_batter": {},
            "current_bowler": {},
            "is_done": False,
            "curriculum_stage": 1,
            "max_overs": 50,
            "match_plan": {},
            "plan_commitment_scores": [],
            "plan_staleness_penalties": [],
            "plan_freshness_scores": [],
        }
        fields = {
            key: payload.get(key, fallback) for key, fallback in fallbacks.items()
        }
        return CricketState(**fields)