"""
Baseline inference script for CricketCaptain-LLM.

Runs an LLM agent against all task difficulties (easy=T5, medium=T20, hard=ODI).
Emits [START], [STEP], [END] stdout lines per the OpenEnv benchmark spec.

Required environment variables (Round 1 spec):
    API_BASE_URL       LLM endpoint (default: https://router.huggingface.co/v1)
    MODEL_NAME         Model identifier (default: 'random' baseline)
    HF_TOKEN           HF API key (also accepted as API_KEY)
    LOCAL_IMAGE_NAME   Optional Docker image for from_docker_image() runs

Designed to run on vCPU=2, 8 GB RAM — uses HF Router by default so no local
model load is required.

Optional spectator-bus integration (`--publish-bus`) live-streams every
episode to the cockpit at `{bus-url}/custom`. This is purely additive: the
benchmark stdout markers and grader output are unaffected.

Usage:
    # Random baseline (no API key)
    python inference.py --model random --episodes 5 --task easy

    # API model via HF Router (defaults to the API_BASE_URL env var)
    MODEL_NAME=meta-llama/Llama-3.1-8B-Instruct HF_TOKEN=hf_... \
        python inference.py --episodes 10 --task medium

    # Trained checkpoint served via vLLM at a custom endpoint
    python inference.py --model ./checkpoints/stage2_final --episodes 20 \
        --api-base http://localhost:8080/v1

    # Live-stream a real-LLM run to the cockpit UI (server on :8000)
    python inference.py --model gpt-4o-mini --episodes 1 --max-overs 5 \
        --publish-bus --bus-url http://127.0.0.1:8000
"""

import argparse
import asyncio
import datetime
import json
import os
import random
import statistics
import time
import uuid
from pathlib import Path
from typing import Any

try:
    import openai
    _OPENAI_AVAILABLE = True
except ImportError:
    _OPENAI_AVAILABLE = False

try:
    from client import CricketCaptainEnv
    from models import CricketAction
except ImportError:
    from cricket_captain.client import CricketCaptainEnv
    from cricket_captain.models import CricketAction

try:
    import httpx
    _HTTPX_AVAILABLE = True
except ImportError:
    _HTTPX_AVAILABLE = False

SYSTEM_PROMPT = """You are an expert cricket captain. You must manage the team through the Toss, Batting, and Bowling phases.
Your goal is to win the match. You receive a scorecard and must respond with a SINGLE valid JSON tool call from the available tools.

Batting Tools:
  select_batter — Choose batter profile for the situation
  set_strategy — Declare batting intent (aggression, rationale)
  plan_shot — Plan target area, risk, and shot intent before execution
  play_delivery — Choose a shot and advance the game

Bowling Tools:
  choose_bowler — Choose bowler profile for the situation
  set_bowling_strategy — Set bowler type, line, length, and delivery type
  plan_delivery — Plan line, length, and variation before execution
  set_field_setting — Set field preset (Aggressive, Balanced, Defensive)
  bowl_delivery — Advance the game during bowling phase

Common Tools:
  call_toss — Call heads/tails and make a decision (bat/bowl)
  analyze_situation — Query match context
  reflect_after_ball — Briefly update plan after the previous ball

Always respond with exactly one JSON object on a single line, no markdown."""
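
# A conforming reply is one JSON object on a single line, e.g.:
#   {"tool": "play_delivery", "arguments": {"shot_intent": "single", "explanation": "rotate strike"}}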

SHOT_AGGRESSION_ORDER = ["leave", "defensive", "single", "rotate", "boundary", "six"]


def _coerce_aggression(value: Any, default: float = 0.5) -> float:
    """Coerce a numeric or verbal aggression value to a float in [0, 1]."""
    if isinstance(value, (int, float)):
        return max(0.0, min(1.0, float(value)))
    text = str(value).strip().lower()
    word_map = {
        "very low": 0.15,
        "low": 0.25,
        "conservative": 0.25,
        "defensive": 0.25,
        "moderate": 0.5,
        "medium": 0.5,
        "balanced": 0.5,
        "normal": 0.5,
        "high": 0.75,
        "aggressive": 0.75,
        "very high": 0.9,
        "attack": 0.8,
        "attacking": 0.8,
    }
    try:
        return max(0.0, min(1.0, float(text)))
    except ValueError:
        return word_map.get(text, default)
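
# Illustrative behaviour: "high" -> 0.75, "0.7" -> 0.7, 1.4 -> 1.0 (clamped),
# and unrecognised words fall back to `default`.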

def _normalize_action_args(tool: str, args: dict[str, Any]) -> dict[str, Any]:
    """Normalize common LLM variants before sending to the server."""
    normalized = dict(args)
    if tool in ("set_strategy", "select_batter") and "aggression" in normalized:
        normalized["aggression"] = _coerce_aggression(normalized["aggression"])
    if tool == "plan_shot" and str(normalized.get("risk", "")).lower() == "moderate":
        normalized["risk"] = "balanced"
    if tool == "call_toss":
        call = str(normalized.get("call", "heads")).lower()
        decision = str(normalized.get("decision", "bat")).lower()
        normalized["call"] = call if call in ("heads", "tails") else "heads"
        normalized["decision"] = decision if decision in ("bat", "bowl") else "bat"
    return normalized
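
# e.g. {"aggression": "high"} becomes {"aggression": 0.75}, a plan_shot risk of
# "moderate" becomes "balanced", and invalid toss args fall back to heads/bat.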

class RandomAgent:
    """Baseline: random valid tool calls based on availability."""

    def __call__(self, messages: list[dict]) -> str:
        prompt = messages[-1]["content"] if messages else ""
        # In a real scenario, we'd parse available_tools from the prompt/observation.
        # Here we just check some keywords in the prompt to guess the phase.
        if "TOSS" in prompt:
            return json.dumps({
                "tool": "call_toss",
                "arguments": {"call": random.choice(["heads", "tails"]), "decision": random.choice(["bat", "bowl"])},
            })
        if "BOWLING" in prompt:
            roll = random.random()
            if roll < 0.12:
                return json.dumps({
                    "tool": "choose_bowler",
                    "arguments": {
                        "name": random.choice(["Strike Pacer", "Control Spinner", "Death Specialist"]),
                        "bowler_type": random.choice(["pace", "spin"]),
                        "style": random.choice(["swing", "economy", "yorker"]),
                        "rationale": "Match bowler type to phase and batter style.",
                    },
                })
            elif roll < 0.28:
                return json.dumps({
                    "tool": "set_bowling_strategy",
                    "arguments": {
                        "bowler_type": random.choice(["pace", "spin"]),
                        "line": random.choice(["stumps", "outside off", "on pads"]),
                        "length": random.choice(["good length", "full", "short"]),
                        "delivery_type": "stock",
                        "rationale": "Random bowling strategy.",
                    },
                })
            elif roll < 0.45:
                return json.dumps({
                    "tool": "plan_delivery",
                    "arguments": {
                        "bowler_type": random.choice(["pace", "spin"]),
                        "line": random.choice(["stumps", "outside off", "wide"]),
                        "length": random.choice(["good length", "full", "short", "yorker"]),
                        "delivery_type": random.choice(["stock", "slower ball", "yorker", "bouncer"]),
                        "rationale": "Plan delivery against current batter and field.",
                    },
                })
            elif roll < 0.58:
                return json.dumps({
                    "tool": "set_field_setting",
                    "arguments": {"setting": random.choice(["Aggressive", "Balanced", "Defensive"])},
                })
            elif roll < 0.65:
                return json.dumps({
                    "tool": "reflect_after_ball",
                    "arguments": {"reflection": "Adjust based on the previous ball outcome and match pressure."},
                })
            else:
                return json.dumps({"tool": "bowl_delivery", "arguments": {}})
        # Default batting logic.
        roll = random.random()
        if roll < 0.1:
            return json.dumps({
                "tool": "select_batter",
                "arguments": {
                    "name": random.choice(["Opener", "Anchor", "Finisher"]),
                    "style": random.choice(["balanced", "anchor", "aggressive"]),
                    "aggression": round(random.uniform(0.2, 0.8), 2),
                    "rationale": "Choose batter profile for phase, target, and wicket context.",
                },
            })
        elif roll < 0.25:
            agg = round(random.uniform(0.1, 0.9), 2)
            return json.dumps({
                "tool": "set_strategy",
                "arguments": {
                    "phase_intent": random.choice(["consolidate", "attack", "rotate"]),
                    "aggression": agg,
                    "rationale": "Random strategy selection.",
                },
            })
        elif roll < 0.4:
            return json.dumps({
                "tool": "plan_shot",
                "arguments": {
                    "shot_intent": random.choice(SHOT_AGGRESSION_ORDER),
                    "target_area": random.choice(["off-side gap", "leg-side gap", "straight", "boundary"]),
                    "risk": random.choice(["low", "balanced", "high"]),
                    "rationale": "Match shot plan to bowler, field, and required rate.",
                },
            })
        elif roll < 0.5:
            return json.dumps({
                "tool": "analyze_situation",
                "arguments": {"query_type": random.choice(["pitch_conditions", "match_situation"])},
            })
        elif roll < 0.58:
            return json.dumps({
                "tool": "reflect_after_ball",
                "arguments": {"reflection": "Revise risk after the previous delivery and current target pressure."},
            })
        else:
            return json.dumps({
                "tool": "play_delivery",
                "arguments": {
                    "shot_intent": random.choice(SHOT_AGGRESSION_ORDER),
                    "explanation": "Random shot selection.",
                },
            })

class OpenAIAgent:
    def __init__(
        self,
        model: str,
        api_base: str | None = None,
        api_key: str | None = None,
        timeout: float = 30.0,
        temperature: float = 0.2,
    ):
        if not _OPENAI_AVAILABLE:
            raise ImportError("openai not installed. Run: pip install openai")
        self._client = openai.OpenAI(
            base_url=api_base,
            api_key=api_key or "dummy",
            timeout=timeout,
        )
        self._model = model
        self._temperature = temperature

    def __call__(self, messages: list[dict]) -> str:
        resp = self._client.chat.completions.create(
            model=self._model,
            messages=messages,
            temperature=self._temperature,
            max_tokens=300,
        )
        return resp.choices[0].message.content.strip()
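
# Works against any OpenAI-compatible /v1 endpoint: the HF Router default or a
# local vLLM server serving a trained checkpoint (see module docstring).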

def _parse_action(raw: str) -> tuple[CricketAction | None, bool]:
    raw = raw.strip()
    if raw.startswith("```"):
        lines = raw.split("\n")
        raw = "\n".join(lines[1:-1]) if len(lines) > 2 else raw
    try:
        if not raw.startswith("{"):
            start = raw.find("{")
            if start >= 0:
                raw = raw[start:]
        data, _ = json.JSONDecoder().raw_decode(raw)
        valid_tools = (
            "set_strategy", "analyze_situation", "play_delivery",
            "call_toss", "bowl_delivery", "set_bowling_strategy", "set_field_setting",
            "choose_bowler", "select_batter", "plan_delivery", "plan_shot", "reflect_after_ball",
            "set_match_plan", "update_match_plan",
        )
        if "tool" not in data and len(data) == 1:
            maybe_tool, maybe_args = next(iter(data.items()))
            if maybe_tool in valid_tools and isinstance(maybe_args, dict):
                data = {"tool": maybe_tool, "arguments": maybe_args}
        tool = data.get("tool", "")
        if tool not in valid_tools:
            return None, True
        return CricketAction(tool=tool, arguments=_normalize_action_args(tool, data.get("arguments", {}))), False
    except Exception:
        return None, True
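
# Note: _parse_action tolerates common LLM formatting quirks: fenced ``` blocks,
# leading prose before the first "{", trailing text after the object (raw_decode),
# and the {"tool_name": {...}} shorthand in place of {"tool": ..., "arguments": ...}.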

def _action_short(action: CricketAction) -> str:
    """Compact one-line action string for [STEP] markers."""
    args = action.arguments or {}
    primary = (
        args.get("shot_intent")
        or args.get("delivery_type")
        or args.get("phase_intent")
        or args.get("setting")
        or args.get("call")
    )
    return f"{action.tool}({primary})" if primary else action.tool

async def run_episode(
    env: CricketCaptainEnv,
    agent,
    task: str = "stage2_full",
    max_steps: int = 600,
    verbose: bool = False,
    eval_pack_id: str = "default",
    opponent_mode: str = "heuristic",
    max_overs: int | None = None,
    step_log=None,  # callable(str) — called after every delivery for live log
    benchmark_stdout: bool = True,  # emit [START]/[STEP]/[END] markers
    model_name: str = "random",  # threaded into the [START] line
    observer=None,  # optional spectator.publisher.BusObserver
) -> dict[str, Any]:
    # All reset params must go inside options={} — EnvClient.reset(**kwargs) only
    # forwards recognised signature params (seed, options); bare kwargs are dropped.
    result = await env.reset(options={
        "task": task,
        "random_start": False,
        "eval_pack_id": eval_pack_id,
        "opponent_mode": opponent_mode,
        "max_overs": max_overs,
    })
    obs = result.observation
    if observer is not None:
        await observer.start(obs)

    history: list[dict] = []
    rewards: list[float] = []
    parse_errors = 0
    deliveries = 0
    turn = 0
    _AGENT_TIMEOUT = 45.0

    if benchmark_stdout:
        print(f"[START] task={task} env=cricket_captain model={model_name}", flush=True)

    while not result.done and turn < max_steps:
        messages = [{"role": "system", "content": SYSTEM_PROMPT}] + history[-10:] + [
            {"role": "user", "content": obs.prompt_text}
        ]
        try:
            # get_running_loop(): we are inside a coroutine, where get_event_loop()
            # is deprecated.
            raw = await asyncio.wait_for(
                asyncio.get_running_loop().run_in_executor(None, agent, messages),
                timeout=_AGENT_TIMEOUT,
            )
        except asyncio.TimeoutError:
            raw = ""
        action, err = _parse_action(raw)
        if err:
            parse_errors += 1
            if "BOWLING" in obs.prompt_text:
                action = CricketAction(tool="bowl_delivery", arguments={})
            elif "TOSS" in obs.prompt_text:
                action = CricketAction(tool="call_toss", arguments={"call": "heads", "decision": "bat"})
            else:
                action = CricketAction(tool="play_delivery", arguments={"shot_intent": "defensive", "explanation": "fallback"})
        action_dict = {"tool": action.tool, "arguments": action.arguments}
        pre = None
        if observer is not None:
            pre = await observer.before_step(
                action_dict, obs, turn=turn, transcript_text=raw,
            )
        try:
            result = await env.step(action)
        except Exception as exc:  # noqa: BLE001
            if observer is not None:
                await observer.emit_error(f"env.step failed: {exc}", turn=turn, tool=action.tool)
            raise
        obs = result.observation
        r = result.reward or 0.0
        rewards.append(r)
        turn += 1
        is_delivery = action.tool in ("play_delivery", "bowl_delivery")
        if is_delivery:
            deliveries += 1
        if benchmark_stdout:
            # `err` is a bool from _parse_action, so report a fixed token in the
            # error field rather than calling str methods on it.
            err_field = "parse_error" if err else "null"
            print(
                f"[STEP] step={turn} action={_action_short(action)} "
                f"reward={r:.2f} done={'true' if result.done else 'false'} error={err_field}",
                flush=True,
            )
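            # e.g. [STEP] step=12 action=play_delivery(single) reward=0.10 done=false error=null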
        if observer is not None:
            await observer.after_step(
                action_dict, pre or {}, obs, r,
                turn=turn,
                fetch_state=env.state,
            )
        if is_delivery and step_log:
            ctx = obs.game_context
            step_log(
                f" over={ctx.get('over', '?')}.{ctx.get('ball', '?')} "
                f"score={ctx.get('score', '?')}/{ctx.get('wickets', '?')} "
                f"tool={action.tool} r={r:.3f} | {obs.last_ball_result[:60]}"
            )
        elif verbose:
            print(f" [{turn}] [{obs.game_state.upper()}] {raw[:60]} → r={r:.3f}")
        history.append({"role": "assistant", "content": raw})
        history.append({"role": "user", "content": obs.last_ball_result})

    state = await env.state()
    if observer is not None:
        await observer.end(state)

    match_result = getattr(state, "match_result", None) or ""
    success = (match_result == "win")
    # Score in [0, 1]: win=1.0, tie=0.5, loss=0.0. Matches Round 1 grader-output spec.
    if match_result == "win":
        score = 1.0
    elif match_result == "tie":
        score = 0.5
    else:
        score = 0.0

    if benchmark_stdout:
        rewards_str = ",".join(f"{r:.2f}" for r in rewards)
        print(
            f"[END] success={'true' if success else 'false'} steps={turn} "
            f"score={score:.2f} rewards={rewards_str}",
            flush=True,
        )

    coherence = state.coherence_scores
    return {
        "total_score": state.total_score,
        "wickets_lost": state.wickets_lost,
        "over": state.over,
        "total_reward": sum(rewards),
        "mean_coherence": statistics.mean(coherence) if coherence else 0.0,
        "parse_error_rate": parse_errors / max(turn, 1),
        "tool_calls": state.tool_calls_made,
        "adaptation": statistics.mean(state.adaptation_scores) if state.adaptation_scores else 0.0,
        "opponent_awareness": statistics.mean(state.opponent_awareness_scores) if state.opponent_awareness_scores else 0.0,
        "regret": statistics.mean(state.regret_scores) if state.regret_scores else 0.0,
        "deliveries": deliveries,
        "game_state": state.game_state,
        "target": state.target,
    }

def _make_inference_run_folder(model: str, opponent_mode: str, max_overs: int | None) -> Path:
    ts = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M")
    model_short = model.split("/")[-1][:20] if model != "random" else "random"
    overs_str = f"_{max_overs}ov" if max_overs else ""
    opp_str = f"_{opponent_mode}"
    folder_name = f"exp_{ts}_inference{overs_str}{opp_str}_{model_short}"
    run_dir = Path(__file__).parent / "illustrations" / folder_name
    run_dir.mkdir(parents=True, exist_ok=True)
    return run_dir
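
# Illustrative folder name: illustrations/exp_2025-01-01_12-00_inference_5ov_heuristic_gpt-4o-mini
# (timestamp, overs, opponent mode, and model short-name vary per run).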

async def evaluate(args):
    agent: Any
    if args.model == "random":
        agent = RandomAgent()
        print("Using RandomAgent baseline")
    else:
        agent = OpenAIAgent(args.model, api_base=args.api_base, api_key=args.api_key)
        print(f"Using OpenAI-compatible agent: {args.model}")

    run_dir = _make_inference_run_folder(args.model, args.opponent_mode, args.max_overs)
    log_file = run_dir / "run_output.txt"
    # Write header immediately so the file exists and is readable while running.
    header = "\n".join([
        f"# Inference run: {run_dir.name}",
        f"timestamp_utc: {datetime.datetime.now(datetime.timezone.utc).isoformat()}",
        f"model: {args.model}",
        f"api_base: {args.api_base}",
        f"opponent_mode: {args.opponent_mode}",
        f"max_overs: {args.max_overs}",
        f"episodes: {args.episodes}",
        f"task: {args.task}",
        f"eval_pack_id: {args.eval_pack_id}",
        "",
    ])
    log_file.write_text(header)

    def _log(msg: str):
        print(msg)
        with open(log_file, "a") as f:
            f.write(msg + "\n")

    results = []
    bus_http: Any = None
    if args.publish_bus:
        if not _HTTPX_AVAILABLE:
            raise RuntimeError("--publish-bus requires httpx. Run: pip install httpx")
        bus_http = httpx.AsyncClient(timeout=15.0)
| async with CricketCaptainEnv(args.env_url) as env: | |
| for ep in range(args.episodes): | |
| _log(f"\n--- Episode {ep+1}/{args.episodes} ---") | |
| observer = None | |
| if args.publish_bus: | |
| from spectator.publisher import BusObserver # lazy: only if asked | |
| episode_id = ( | |
| f"{args.bus_episode_prefix}-{int(time.time())}-{uuid.uuid4().hex[:6]}" | |
| ) | |
| observer = BusObserver( | |
| bus_http, | |
| args.bus_url, | |
| episode_id, | |
| mode=args.model, | |
| task=args.task, | |
| opponent_mode=args.opponent_mode, | |
| eval_pack_id=args.eval_pack_id, | |
| emit_transcript=args.bus_transcript, | |
| ) | |
| _log(f" [bus] watch live at {args.bus_url}/custom (ep id: {episode_id})") | |
| ep_result = await run_episode( | |
| env, | |
| agent, | |
| task=args.task, | |
| verbose=args.verbose, | |
| eval_pack_id=args.eval_pack_id, | |
| opponent_mode=args.opponent_mode, | |
| max_overs=args.max_overs, | |
| step_log=_log, | |
| model_name=args.model, | |
| observer=observer, | |
| ) | |
| results.append(ep_result) | |
| line = ( | |
| f"Episode {ep+1:>3}/{args.episodes} | " | |
| f"Score: {ep_result['total_score']:>3}/{ep_result['wickets_lost']} " | |
| f"({ep_result['over']} ov) | " | |
| f"Reward: {ep_result['total_reward']:>6.3f} | " | |
| f"Coherence: {ep_result['mean_coherence']:.3f} | " | |
| f"Adapt: {ep_result['adaptation']:.3f} | " | |
| f"ParseErr: {ep_result['parse_error_rate']:.1%}" | |
| ) | |
| _log(line) | |
| finally: | |
| if bus_http is not None: | |
| await bus_http.aclose() | |
| _log("\n=== Summary ===") | |
| summary_lines = [] | |
| for key in ["total_score", "wickets_lost", "total_reward", "mean_coherence", "parse_error_rate"]: | |
| vals = [r[key] for r in results] | |
| summary_lines.append(f" {key:20s}: mean={statistics.mean(vals):.3f} std={statistics.stdev(vals) if len(vals)>1 else 0:.3f}") | |
| _log(summary_lines[-1]) | |
| # Write README (always at end — has final summary) | |
| (run_dir / "README.md").write_text( | |
| f"## Inference Run: {run_dir.name}\n\n" | |
| f"**Date**: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M')}\n\n" | |
| f"| Setting | Value |\n|---|---|\n" | |
| f"| Model | `{args.model}` |\n" | |
| f"| API base | `{args.api_base or 'N/A'}` |\n" | |
| f"| Opponent mode | `{args.opponent_mode}` |\n" | |
| f"| Max overs | {args.max_overs} |\n" | |
| f"| Episodes | {args.episodes} |\n" | |
| f"| Task | `{args.task}` |\n\n" | |
| f"### Results\n\n```\n" + "\n".join(summary_lines) + "\n```\n\n" | |
| f"See `run_output.txt` for full verbose episode log.\n" | |
| ) | |
| print(f"\nRun saved → {run_dir}") | |

def main():
    parser = argparse.ArgumentParser(description="CricketCaptain-LLM Baseline Inference")
    parser.add_argument("--config", default=None, help="YAML config path (runner defaults).")
    parser.add_argument("--model", default="random", help="'random' or OpenAI model name")
    parser.add_argument("--episodes", type=int, default=5)
    parser.add_argument("--task", default="medium",
                        choices=["easy", "medium", "hard", "stage2_full", "eval_50over"],
                        help="Task difficulty: easy=5-over, medium=T20, hard=50-over ODI. "
                             "stage2_full / eval_50over are legacy aliases.")
    parser.add_argument("--max-overs", type=int, default=None,
                        help="Limit innings length for fast experiments (e.g. 5).")
    parser.add_argument("--env-url", default=os.environ.get("CRICKET_CAPTAIN_ENV_URL", "ws://localhost:8000"))
    parser.add_argument("--eval-pack-id", default=os.environ.get("CRICKET_EVAL_PACK_ID", "default"))
    parser.add_argument("--opponent-mode", default=os.environ.get("CRICKET_OPPONENT_MODE", "heuristic"),
                        choices=["heuristic", "llm_live", "llm_cached", "cricsheet"])
    # Round-1 spec env-var contract: API_BASE_URL / MODEL_NAME / HF_TOKEN.
    # CLI flags --api-base / --api-key / --model override; otherwise we fall
    # back to those env vars (with HF Router defaults so the script runs on
    # vCPU=2, 8 GB RAM without local model loading).
    parser.add_argument("--api-base", default=os.environ.get("API_BASE_URL"))
    parser.add_argument("--api-key", default=os.environ.get("HF_TOKEN") or os.environ.get("API_KEY"))
    parser.add_argument("--verbose", action="store_true")
    # Spectator-bus integration. Off by default; turn on to live-stream this
    # run to the cockpit at {bus-url}/custom.
    parser.add_argument("--publish-bus", action="store_true",
                        help="Stream events to the spectator UI bus at /custom/publish.")
    parser.add_argument("--bus-url", default=os.environ.get("CRICKET_UI_BASE_URL", "http://localhost:8000"),
                        help="HTTP base URL of the FastAPI server hosting /custom/publish.")
    parser.add_argument("--bus-episode-prefix", default="inf",
                        help="Prefix used to name episode IDs when publishing.")
    parser.add_argument("--bus-transcript", action="store_true",
                        help="Also emit raw LLM replies as `transcript` events (debug-only).")
    args = parser.parse_args()

    # If neither --model nor any config overrides it, prefer MODEL_NAME env var.
    if args.model == "random" and os.environ.get("MODEL_NAME"):
        args.model = os.environ["MODEL_NAME"]
    # Default to HF Router when running with an OpenAI-compatible model and no
    # explicit api-base set — keeps inference CPU-only as required.
    if args.model != "random" and not args.api_base:
        args.api_base = "https://router.huggingface.co/v1"

    if args.config:
        try:
            from config_yaml import load_config, apply_runner_config_defaults
        except ImportError:
            from cricket_captain.config_yaml import load_config, apply_runner_config_defaults
        defaults = apply_runner_config_defaults(load_config(args.config))
        if args.env_url == os.environ.get("CRICKET_CAPTAIN_ENV_URL", "ws://localhost:8000") and defaults.env_url:
            args.env_url = defaults.env_url
        if args.eval_pack_id == os.environ.get("CRICKET_EVAL_PACK_ID", "default") and defaults.eval_pack_id:
            args.eval_pack_id = defaults.eval_pack_id
        if args.opponent_mode == os.environ.get("CRICKET_OPPONENT_MODE", "heuristic") and defaults.opponent_mode:
            args.opponent_mode = defaults.opponent_mode
        if args.max_overs is None and defaults.max_overs is not None:
            args.max_overs = int(defaults.max_overs)
        if args.model == "random" and defaults.captain_model:
            args.model = defaults.captain_model
        if args.api_base is None and defaults.captain_api_base:
            args.api_base = defaults.captain_api_base
        if args.api_key is None and defaults.captain_api_key:
            args.api_key = defaults.captain_api_key

    asyncio.run(evaluate(args))


if __name__ == "__main__":
    main()