""" Baseline inference script for CricketCaptain-LLM. Runs an LLM agent against all task difficulties (easy=T5, medium=T20, hard=ODI). Emits [START], [STEP], [END] stdout lines per the OpenEnv benchmark spec. Required environment variables (Round 1 spec): API_BASE_URL LLM endpoint (default: https://router.huggingface.co/v1) MODEL_NAME Model identifier (default: 'random' baseline) HF_TOKEN HF API key (also accepted as API_KEY) LOCAL_IMAGE_NAME Optional Docker image for from_docker_image() runs Designed to run on vCPU=2, 8 GB RAM — uses HF Router by default so no local model load is required. Optional spectator-bus integration (`--publish-bus`) live-streams every episode to the cockpit at `{bus-url}/custom`. This is purely additive: the benchmark stdout markers and grader output are unaffected. Usage: # Random baseline (no API key) python inference.py --model random --episodes 5 --task easy # API model via HF Router (defaults to API_BASE_URL env var) MODEL_NAME=meta-llama/Llama-3.1-8B-Instruct HF_TOKEN=hf_... \ python inference.py --episodes 10 --task medium # Trained checkpoint served via vLLM at a custom endpoint python inference.py --model ./checkpoints/stage2_final --episodes 20 \ --api-base http://localhost:8080/v1 # Live-stream a real-LLM run to the cockpit UI (server on :8000) python inference.py --model gpt-4o-mini --episodes 1 --max-overs 5 \ --publish-bus --bus-url http://127.0.0.1:8000 """ import argparse import asyncio import datetime import json import os import random import statistics import time import uuid from pathlib import Path from typing import Any try: import openai _OPENAI_AVAILABLE = True except ImportError: _OPENAI_AVAILABLE = False try: from client import CricketCaptainEnv from models import CricketAction except ImportError: from cricket_captain.client import CricketCaptainEnv from cricket_captain.models import CricketAction try: import httpx _HTTPX_AVAILABLE = True except ImportError: _HTTPX_AVAILABLE = False SYSTEM_PROMPT = """You are an expert cricket captain. You must manage the team through the Toss, Batting, and Bowling phases. Your goal is to win the match. You receive a scorecard and must respond with a SINGLE valid JSON tool call from the available tools. Batting Tools: select_batter — Choose batter profile for the situation set_strategy — Declare batting intent (aggression, rationale) plan_shot — Plan target area, risk, and shot intent before execution play_delivery — Choose a shot and advance the game Bowling Tools: choose_bowler — Choose bowler profile for the situation set_bowling_strategy — Set bowler type, line, length, and delivery type plan_delivery — Plan line, length, and variation before execution set_field_setting — Set field preset (Aggressive, Balanced, Defensive) bowl_delivery — Advance the game during bowling phase Common Tools: call_toss — Call heads/tails and make a decision (bat/bowl) analyze_situation — Query match context reflect_after_ball — Briefly update plan after the previous ball Always respond with exactly one JSON object on a single line, no markdown.""" SHOT_AGGRESSION_ORDER = ["leave", "defensive", "single", "rotate", "boundary", "six"] def _coerce_aggression(value: Any, default: float = 0.5) -> float: if isinstance(value, (int, float)): return max(0.0, min(1.0, float(value))) text = str(value).strip().lower() word_map = { "very low": 0.15, "low": 0.25, "conservative": 0.25, "defensive": 0.25, "moderate": 0.5, "medium": 0.5, "balanced": 0.5, "normal": 0.5, "high": 0.75, "aggressive": 0.75, "very high": 0.9, "attack": 0.8, "attacking": 0.8, } try: return max(0.0, min(1.0, float(text))) except ValueError: return word_map.get(text, default) def _normalize_action_args(tool: str, args: dict[str, Any]) -> dict[str, Any]: """Normalize common LLM variants before sending to the server.""" normalized = dict(args) if tool in ("set_strategy", "select_batter") and "aggression" in normalized: normalized["aggression"] = _coerce_aggression(normalized["aggression"]) if tool == "plan_shot" and str(normalized.get("risk", "")).lower() == "moderate": normalized["risk"] = "balanced" if tool == "call_toss": call = str(normalized.get("call", "heads")).lower() decision = str(normalized.get("decision", "bat")).lower() normalized["call"] = call if call in ("heads", "tails") else "heads" normalized["decision"] = decision if decision in ("bat", "bowl") else "bat" return normalized class RandomAgent: """Baseline: random valid tool calls based on availability.""" def __call__(self, messages: list[dict]) -> str: prompt = messages[-1]["content"] if messages else "" # In a real scenario, we'd parse available_tools from the prompt/observation. # Here we'll just check some keywords in the prompt to guess the phase. if "TOSS" in prompt: return json.dumps({ "tool": "call_toss", "arguments": {"call": random.choice(["heads", "tails"]), "decision": random.choice(["bat", "bowl"])} }) if "BOWLING" in prompt: roll = random.random() if roll < 0.12: return json.dumps({ "tool": "choose_bowler", "arguments": { "name": random.choice(["Strike Pacer", "Control Spinner", "Death Specialist"]), "bowler_type": random.choice(["pace", "spin"]), "style": random.choice(["swing", "economy", "yorker"]), "rationale": "Match bowler type to phase and batter style.", }, }) if roll < 0.28: return json.dumps({ "tool": "set_bowling_strategy", "arguments": { "bowler_type": random.choice(["pace", "spin"]), "line": random.choice(["stumps", "outside off", "on pads"]), "length": random.choice(["good length", "full", "short"]), "delivery_type": "stock", "rationale": "Random bowling strategy." } }) elif roll < 0.45: return json.dumps({ "tool": "plan_delivery", "arguments": { "bowler_type": random.choice(["pace", "spin"]), "line": random.choice(["stumps", "outside off", "wide"]), "length": random.choice(["good length", "full", "short", "yorker"]), "delivery_type": random.choice(["stock", "slower ball", "yorker", "bouncer"]), "rationale": "Plan delivery against current batter and field.", }, }) elif roll < 0.58: return json.dumps({ "tool": "set_field_setting", "arguments": {"setting": random.choice(["Aggressive", "Balanced", "Defensive"])} }) elif roll < 0.65: return json.dumps({ "tool": "reflect_after_ball", "arguments": {"reflection": "Adjust based on the previous ball outcome and match pressure."}, }) else: return json.dumps({"tool": "bowl_delivery", "arguments": {}}) # Default Batting logic roll = random.random() if roll < 0.1: return json.dumps({ "tool": "select_batter", "arguments": { "name": random.choice(["Opener", "Anchor", "Finisher"]), "style": random.choice(["balanced", "anchor", "aggressive"]), "aggression": round(random.uniform(0.2, 0.8), 2), "rationale": "Choose batter profile for phase, target, and wicket context.", }, }) if roll < 0.25: agg = round(random.uniform(0.1, 0.9), 2) return json.dumps({ "tool": "set_strategy", "arguments": { "phase_intent": random.choice(["consolidate", "attack", "rotate"]), "aggression": agg, "rationale": "Random strategy selection.", } }) elif roll < 0.4: return json.dumps({ "tool": "plan_shot", "arguments": { "shot_intent": random.choice(SHOT_AGGRESSION_ORDER), "target_area": random.choice(["off-side gap", "leg-side gap", "straight", "boundary"]), "risk": random.choice(["low", "balanced", "high"]), "rationale": "Match shot plan to bowler, field, and required rate.", }, }) elif roll < 0.5: return json.dumps({ "tool": "analyze_situation", "arguments": {"query_type": random.choice(["pitch_conditions", "match_situation"])}, }) elif roll < 0.58: return json.dumps({ "tool": "reflect_after_ball", "arguments": {"reflection": "Revise risk after the previous delivery and current target pressure."}, }) else: return json.dumps({ "tool": "play_delivery", "arguments": { "shot_intent": random.choice(SHOT_AGGRESSION_ORDER), "explanation": "Random shot selection.", }, }) class OpenAIAgent: def __init__( self, model: str, api_base: str | None = None, api_key: str | None = None, timeout: float = 30.0, temperature: float = 0.2, ): if not _OPENAI_AVAILABLE: raise ImportError("openai not installed. Run: pip install openai") self._client = openai.OpenAI( base_url=api_base, api_key=api_key or "dummy", timeout=timeout, ) self._model = model self._temperature = temperature def __call__(self, messages: list[dict]) -> str: resp = self._client.chat.completions.create( model=self._model, messages=messages, temperature=self._temperature, max_tokens=300, ) return resp.choices[0].message.content.strip() def _parse_action(raw: str) -> tuple[CricketAction | None, bool]: raw = raw.strip() if raw.startswith("```"): lines = raw.split("\n") raw = "\n".join(lines[1:-1]) if len(lines) > 2 else raw try: if not raw.startswith("{"): start = raw.find("{") if start >= 0: raw = raw[start:] data, _ = json.JSONDecoder().raw_decode(raw) valid_tools = ( "set_strategy", "analyze_situation", "play_delivery", "call_toss", "bowl_delivery", "set_bowling_strategy", "set_field_setting", "choose_bowler", "select_batter", "plan_delivery", "plan_shot", "reflect_after_ball", "set_match_plan", "update_match_plan", ) if "tool" not in data and len(data) == 1: maybe_tool, maybe_args = next(iter(data.items())) if maybe_tool in valid_tools and isinstance(maybe_args, dict): data = {"tool": maybe_tool, "arguments": maybe_args} tool = data.get("tool", "") if tool not in valid_tools: return None, True return CricketAction(tool=tool, arguments=_normalize_action_args(tool, data.get("arguments", {}))), False except Exception: return None, True def _action_short(action: CricketAction) -> str: """Compact one-line action string for [STEP] markers.""" args = action.arguments or {} primary = ( args.get("shot_intent") or args.get("delivery_type") or args.get("phase_intent") or args.get("setting") or args.get("call") ) return f"{action.tool}({primary})" if primary else action.tool async def run_episode( env: CricketCaptainEnv, agent, task: str = "stage2_full", max_steps: int = 600, verbose: bool = False, eval_pack_id: str = "default", opponent_mode: str = "heuristic", max_overs: int | None = None, step_log=None, # callable(str) — called after every delivery for live log benchmark_stdout: bool = True, # emit [START]/[STEP]/[END] markers model_name: str = "random", # threaded into the [START] line observer=None, # optional spectator.publisher.BusObserver ) -> dict[str, Any]: # All reset params must go inside options={} — EnvClient.reset(**kwargs) only # forwards recognised signature params (seed, options); bare kwargs are dropped. result = await env.reset(options={ "task": task, "random_start": False, "eval_pack_id": eval_pack_id, "opponent_mode": opponent_mode, "max_overs": max_overs, }) obs = result.observation if observer is not None: await observer.start(obs) history: list[dict] = [] rewards: list[float] = [] parse_errors = 0 deliveries = 0 turn = 0 _AGENT_TIMEOUT = 45.0 if benchmark_stdout: print(f"[START] task={task} env=cricket_captain model={model_name}", flush=True) while not result.done and turn < max_steps: messages = [{"role": "system", "content": SYSTEM_PROMPT}] + history[-10:] + [ {"role": "user", "content": obs.prompt_text} ] try: raw = await asyncio.wait_for( asyncio.get_event_loop().run_in_executor(None, agent, messages), timeout=_AGENT_TIMEOUT, ) except asyncio.TimeoutError: raw = "" action, err = _parse_action(raw) if err: parse_errors += 1 if "BOWLING" in obs.prompt_text: action = CricketAction(tool="bowl_delivery", arguments={}) elif "TOSS" in obs.prompt_text: action = CricketAction(tool="call_toss", arguments={"call": "heads", "decision": "bat"}) else: action = CricketAction(tool="play_delivery", arguments={"shot_intent": "defensive", "explanation": "fallback"}) action_dict = {"tool": action.tool, "arguments": action.arguments} pre = None if observer is not None: pre = await observer.before_step( action_dict, obs, turn=turn, transcript_text=raw, ) try: result = await env.step(action) except Exception as exc: # noqa: BLE001 if observer is not None: await observer.emit_error(f"env.step failed: {exc}", turn=turn, tool=action.tool) raise obs = result.observation r = result.reward or 0.0 rewards.append(r) turn += 1 is_delivery = action.tool in ("play_delivery", "bowl_delivery") if is_delivery: deliveries += 1 if benchmark_stdout: err_field = err.replace("\n", " ").replace(" ", "_")[:60] if err else "null" print( f"[STEP] step={turn} action={_action_short(action)} " f"reward={r:.2f} done={'true' if result.done else 'false'} error={err_field}", flush=True, ) if observer is not None: await observer.after_step( action_dict, pre or {}, obs, r, turn=turn, fetch_state=env.state, ) if is_delivery and step_log: ctx = obs.game_context step_log( f" over={ctx.get('over', '?')}.{ctx.get('ball', '?')} " f"score={ctx.get('score', '?')}/{ctx.get('wickets', '?')} " f"tool={action.tool} r={r:.3f} | {obs.last_ball_result[:60]}" ) elif verbose: print(f" [{turn}] [{obs.game_state.upper()}] {raw[:60]} → r={r:.3f}") history.append({"role": "assistant", "content": raw}) history.append({"role": "user", "content": obs.last_ball_result}) state = await env.state() if observer is not None: await observer.end(state) match_result = getattr(state, "match_result", None) or "" success = (match_result == "win") # Score in [0, 1]: win=1.0, tie=0.5, loss=0.0. Matches Round 1 grader-output spec. if match_result == "win": score = 1.0 elif match_result == "tie": score = 0.5 else: score = 0.0 if benchmark_stdout: rewards_str = ",".join(f"{r:.2f}" for r in rewards) print( f"[END] success={'true' if success else 'false'} steps={turn} " f"score={score:.2f} rewards={rewards_str}", flush=True, ) coherence = state.coherence_scores return { "total_score": state.total_score, "wickets_lost": state.wickets_lost, "over": state.over, "total_reward": sum(rewards), "mean_coherence": statistics.mean(coherence) if coherence else 0.0, "parse_error_rate": parse_errors / max(turn, 1), "tool_calls": state.tool_calls_made, "adaptation": statistics.mean(state.adaptation_scores) if state.adaptation_scores else 0.0, "opponent_awareness": statistics.mean(state.opponent_awareness_scores) if state.opponent_awareness_scores else 0.0, "regret": statistics.mean(state.regret_scores) if state.regret_scores else 0.0, "deliveries": deliveries, "game_state": state.game_state, "target": state.target, } def _make_inference_run_folder(model: str, opponent_mode: str, max_overs: int | None) -> Path: ts = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M") model_short = model.split("/")[-1][:20] if model != "random" else "random" overs_str = f"_{max_overs}ov" if max_overs else "" opp_str = f"_{opponent_mode}" folder_name = f"exp_{ts}_inference{overs_str}{opp_str}_{model_short}" run_dir = Path(__file__).parent / "illustrations" / folder_name run_dir.mkdir(parents=True, exist_ok=True) return run_dir async def evaluate(args): agent: Any if args.model == "random": agent = RandomAgent() print("Using RandomAgent baseline") else: agent = OpenAIAgent(args.model, api_base=args.api_base, api_key=args.api_key) print(f"Using OpenAI-compatible agent: {args.model}") run_dir = _make_inference_run_folder(args.model, args.opponent_mode, args.max_overs) log_file = run_dir / "run_output.txt" # Write header immediately so the file exists and is readable while running header = "\n".join([ f"# Inference run: {run_dir.name}", f"timestamp_utc: {datetime.datetime.utcnow().isoformat()}", f"model: {args.model}", f"api_base: {args.api_base}", f"opponent_mode: {args.opponent_mode}", f"max_overs: {args.max_overs}", f"episodes: {args.episodes}", f"task: {args.task}", f"eval_pack_id: {args.eval_pack_id}", "", ]) log_file.write_text(header) def _log(msg: str): print(msg) with open(log_file, "a") as f: f.write(msg + "\n") results = [] bus_http: Any = None if args.publish_bus: if not _HTTPX_AVAILABLE: raise RuntimeError("--publish-bus requires httpx. Run: pip install httpx") bus_http = httpx.AsyncClient(timeout=15.0) try: async with CricketCaptainEnv(args.env_url) as env: for ep in range(args.episodes): _log(f"\n--- Episode {ep+1}/{args.episodes} ---") observer = None if args.publish_bus: from spectator.publisher import BusObserver # lazy: only if asked episode_id = ( f"{args.bus_episode_prefix}-{int(time.time())}-{uuid.uuid4().hex[:6]}" ) observer = BusObserver( bus_http, args.bus_url, episode_id, mode=args.model, task=args.task, opponent_mode=args.opponent_mode, eval_pack_id=args.eval_pack_id, emit_transcript=args.bus_transcript, ) _log(f" [bus] watch live at {args.bus_url}/custom (ep id: {episode_id})") ep_result = await run_episode( env, agent, task=args.task, verbose=args.verbose, eval_pack_id=args.eval_pack_id, opponent_mode=args.opponent_mode, max_overs=args.max_overs, step_log=_log, model_name=args.model, observer=observer, ) results.append(ep_result) line = ( f"Episode {ep+1:>3}/{args.episodes} | " f"Score: {ep_result['total_score']:>3}/{ep_result['wickets_lost']} " f"({ep_result['over']} ov) | " f"Reward: {ep_result['total_reward']:>6.3f} | " f"Coherence: {ep_result['mean_coherence']:.3f} | " f"Adapt: {ep_result['adaptation']:.3f} | " f"ParseErr: {ep_result['parse_error_rate']:.1%}" ) _log(line) finally: if bus_http is not None: await bus_http.aclose() _log("\n=== Summary ===") summary_lines = [] for key in ["total_score", "wickets_lost", "total_reward", "mean_coherence", "parse_error_rate"]: vals = [r[key] for r in results] summary_lines.append(f" {key:20s}: mean={statistics.mean(vals):.3f} std={statistics.stdev(vals) if len(vals)>1 else 0:.3f}") _log(summary_lines[-1]) # Write README (always at end — has final summary) (run_dir / "README.md").write_text( f"## Inference Run: {run_dir.name}\n\n" f"**Date**: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M')}\n\n" f"| Setting | Value |\n|---|---|\n" f"| Model | `{args.model}` |\n" f"| API base | `{args.api_base or 'N/A'}` |\n" f"| Opponent mode | `{args.opponent_mode}` |\n" f"| Max overs | {args.max_overs} |\n" f"| Episodes | {args.episodes} |\n" f"| Task | `{args.task}` |\n\n" f"### Results\n\n```\n" + "\n".join(summary_lines) + "\n```\n\n" f"See `run_output.txt` for full verbose episode log.\n" ) print(f"\nRun saved → {run_dir}") def main(): parser = argparse.ArgumentParser(description="CricketCaptain-LLM Baseline Inference") parser.add_argument("--config", default=None, help="YAML config path (runner defaults).") parser.add_argument("--model", default="random", help="'random' or OpenAI model name") parser.add_argument("--episodes", type=int, default=5) parser.add_argument("--task", default="medium", choices=["easy", "medium", "hard", "stage2_full", "eval_50over"], help="Task difficulty: easy=5-over, medium=T20, hard=50-over ODI. " "stage2_full / eval_50over are legacy aliases.") parser.add_argument("--max-overs", type=int, default=None, help="Limit innings length for fast experiments (e.g. 5).") parser.add_argument("--env-url", default=os.environ.get("CRICKET_CAPTAIN_ENV_URL", "ws://localhost:8000")) parser.add_argument("--eval-pack-id", default=os.environ.get("CRICKET_EVAL_PACK_ID", "default")) parser.add_argument("--opponent-mode", default=os.environ.get("CRICKET_OPPONENT_MODE", "heuristic"), choices=["heuristic", "llm_live", "llm_cached", "cricsheet"]) # Round-1 spec env-var contract: API_BASE_URL / MODEL_NAME / HF_TOKEN. # CLI flags --api-base / --api-key / --model override; otherwise we fall # back to those env vars (with HF Router defaults so the script runs on # vCPU=2, 8 GB RAM without local model loading). parser.add_argument("--api-base", default=os.environ.get("API_BASE_URL")) parser.add_argument("--api-key", default=os.environ.get("HF_TOKEN") or os.environ.get("API_KEY")) parser.add_argument("--verbose", action="store_true") # Spectator-bus integration. Off by default; turn on to live-stream this # run to the cockpit at {bus-url}/custom. parser.add_argument("--publish-bus", action="store_true", help="Stream events to the spectator UI bus at /custom/publish.") parser.add_argument("--bus-url", default=os.environ.get("CRICKET_UI_BASE_URL", "http://localhost:8000"), help="HTTP base URL of the FastAPI server hosting /custom/publish.") parser.add_argument("--bus-episode-prefix", default="inf", help="Prefix used to name episode IDs when publishing.") parser.add_argument("--bus-transcript", action="store_true", help="Also emit raw LLM replies as `transcript` events (debug-only).") args = parser.parse_args() # If neither --model nor any config overrides it, prefer MODEL_NAME env var. if args.model == "random" and os.environ.get("MODEL_NAME"): args.model = os.environ["MODEL_NAME"] # Default to HF Router when running with an OpenAI-compatible model and no # explicit api-base set — keeps inference CPU-only as required. if args.model != "random" and not args.api_base: args.api_base = "https://router.huggingface.co/v1" if args.config: try: from config_yaml import load_config, apply_runner_config_defaults except ImportError: from cricket_captain.config_yaml import load_config, apply_runner_config_defaults defaults = apply_runner_config_defaults(load_config(args.config)) if args.env_url == os.environ.get("CRICKET_CAPTAIN_ENV_URL", "ws://localhost:8000") and defaults.env_url: args.env_url = defaults.env_url if args.eval_pack_id == os.environ.get("CRICKET_EVAL_PACK_ID", "default") and defaults.eval_pack_id: args.eval_pack_id = defaults.eval_pack_id if args.opponent_mode == os.environ.get("CRICKET_OPPONENT_MODE", "heuristic") and defaults.opponent_mode: args.opponent_mode = defaults.opponent_mode if args.max_overs is None and defaults.max_overs is not None: args.max_overs = int(defaults.max_overs) if args.model == "random" and defaults.captain_model: args.model = defaults.captain_model if args.api_base is None and defaults.captain_api_base: args.api_base = defaults.captain_api_base if args.api_key is None and defaults.captain_api_key: args.api_key = defaults.captain_api_key asyncio.run(evaluate(args)) if __name__ == "__main__": main()