# cricket-captain-llm/inference.py
"""
Baseline inference script for CricketCaptain-LLM.
Runs an LLM agent against all task difficulties (easy=T5, medium=T20, hard=ODI).
Emits [START], [STEP], [END] stdout lines per the OpenEnv benchmark spec.
Required environment variables (Round 1 spec):
API_BASE_URL LLM endpoint (default: https://router.huggingface.co/v1)
MODEL_NAME Model identifier (default: 'random' baseline)
HF_TOKEN HF API key (also accepted as API_KEY)
LOCAL_IMAGE_NAME Optional Docker image for from_docker_image() runs
Designed to run on vCPU=2, 8 GB RAM — uses HF Router by default so no local
model load is required.
Optional spectator-bus integration (`--publish-bus`) live-streams every
episode to the cockpit at `{bus-url}/custom`. This is purely additive: the
benchmark stdout markers and grader output are unaffected.
Usage:
# Random baseline (no API key)
python inference.py --model random --episodes 5 --task easy
# API model via HF Router (defaults to API_BASE_URL env var)
MODEL_NAME=meta-llama/Llama-3.1-8B-Instruct HF_TOKEN=hf_... \
python inference.py --episodes 10 --task medium
# Trained checkpoint served via vLLM at a custom endpoint
python inference.py --model ./checkpoints/stage2_final --episodes 20 \
--api-base http://localhost:8080/v1
# Live-stream a real-LLM run to the cockpit UI (server on :8000)
python inference.py --model gpt-4o-mini --episodes 1 --max-overs 5 \
--publish-bus --bus-url http://127.0.0.1:8000
"""
import argparse
import asyncio
import datetime
import json
import os
import random
import statistics
import time
import uuid
from pathlib import Path
from typing import Any
try:
import openai
_OPENAI_AVAILABLE = True
except ImportError:
_OPENAI_AVAILABLE = False
try:
from client import CricketCaptainEnv
from models import CricketAction
except ImportError:
from cricket_captain.client import CricketCaptainEnv
from cricket_captain.models import CricketAction
try:
import httpx
_HTTPX_AVAILABLE = True
except ImportError:
_HTTPX_AVAILABLE = False
SYSTEM_PROMPT = """You are an expert cricket captain. You must manage the team through the Toss, Batting, and Bowling phases.
Your goal is to win the match. You receive a scorecard and must respond with a SINGLE valid JSON tool call from the available tools.
Batting Tools:
select_batter — Choose batter profile for the situation
set_strategy — Declare batting intent (aggression, rationale)
plan_shot — Plan target area, risk, and shot intent before execution
play_delivery — Choose a shot and advance the game
Bowling Tools:
choose_bowler — Choose bowler profile for the situation
set_bowling_strategy — Set bowler type, line, length, and delivery type
plan_delivery — Plan line, length, and variation before execution
set_field_setting — Set field preset (Aggressive, Balanced, Defensive)
bowl_delivery — Advance the game during bowling phase
Common Tools:
call_toss — Call heads/tails and make a decision (bat/bowl)
analyze_situation — Query match context
reflect_after_ball — Briefly update plan after the previous ball
Always respond with exactly one JSON object on a single line, no markdown."""
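
# Illustrative single-line reply matching the format SYSTEM_PROMPT requests
# (argument values are made up):
#   {"tool": "play_delivery", "arguments": {"shot_intent": "single", "explanation": "rotate strike"}}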
SHOT_AGGRESSION_ORDER = ["leave", "defensive", "single", "rotate", "boundary", "six"]
def _coerce_aggression(value: Any, default: float = 0.5) -> float:
if isinstance(value, (int, float)):
return max(0.0, min(1.0, float(value)))
text = str(value).strip().lower()
word_map = {
"very low": 0.15,
"low": 0.25,
"conservative": 0.25,
"defensive": 0.25,
"moderate": 0.5,
"medium": 0.5,
"balanced": 0.5,
"normal": 0.5,
"high": 0.75,
"aggressive": 0.75,
"very high": 0.9,
"attack": 0.8,
"attacking": 0.8,
}
try:
return max(0.0, min(1.0, float(text)))
except ValueError:
return word_map.get(text, default)
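
# Illustrative behaviour of _coerce_aggression (values follow the mapping above):
#   _coerce_aggression(1.4)        -> 1.0   (numeric, clamped to [0, 1])
#   _coerce_aggression("High")     -> 0.75  (word map, case-insensitive)
#   _coerce_aggression("0.3")      -> 0.3   (numeric string)
#   _coerce_aggression("unknown")  -> 0.5   (falls back to the default)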
def _normalize_action_args(tool: str, args: dict[str, Any]) -> dict[str, Any]:
"""Normalize common LLM variants before sending to the server."""
normalized = dict(args)
if tool in ("set_strategy", "select_batter") and "aggression" in normalized:
normalized["aggression"] = _coerce_aggression(normalized["aggression"])
if tool == "plan_shot" and str(normalized.get("risk", "")).lower() == "moderate":
normalized["risk"] = "balanced"
if tool == "call_toss":
call = str(normalized.get("call", "heads")).lower()
decision = str(normalized.get("decision", "bat")).lower()
normalized["call"] = call if call in ("heads", "tails") else "heads"
normalized["decision"] = decision if decision in ("bat", "bowl") else "bat"
return normalized
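
# Examples (illustrative) of common LLM variants and their normalized forms:
#   _normalize_action_args("set_strategy", {"aggression": "high"})
#       -> {"aggression": 0.75}
#   _normalize_action_args("call_toss", {"call": "HEADS", "decision": "field"})
#       -> {"call": "heads", "decision": "bat"}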
class RandomAgent:
"""Baseline: random valid tool calls based on availability."""
def __call__(self, messages: list[dict]) -> str:
prompt = messages[-1]["content"] if messages else ""
# In a real scenario, we'd parse available_tools from the prompt/observation.
# Here we'll just check some keywords in the prompt to guess the phase.
if "TOSS" in prompt:
return json.dumps({
"tool": "call_toss",
"arguments": {"call": random.choice(["heads", "tails"]), "decision": random.choice(["bat", "bowl"])}
})
if "BOWLING" in prompt:
roll = random.random()
if roll < 0.12:
return json.dumps({
"tool": "choose_bowler",
"arguments": {
"name": random.choice(["Strike Pacer", "Control Spinner", "Death Specialist"]),
"bowler_type": random.choice(["pace", "spin"]),
"style": random.choice(["swing", "economy", "yorker"]),
"rationale": "Match bowler type to phase and batter style.",
},
})
            elif roll < 0.28:
return json.dumps({
"tool": "set_bowling_strategy",
"arguments": {
"bowler_type": random.choice(["pace", "spin"]),
"line": random.choice(["stumps", "outside off", "on pads"]),
"length": random.choice(["good length", "full", "short"]),
"delivery_type": "stock",
"rationale": "Random bowling strategy."
}
})
elif roll < 0.45:
return json.dumps({
"tool": "plan_delivery",
"arguments": {
"bowler_type": random.choice(["pace", "spin"]),
"line": random.choice(["stumps", "outside off", "wide"]),
"length": random.choice(["good length", "full", "short", "yorker"]),
"delivery_type": random.choice(["stock", "slower ball", "yorker", "bouncer"]),
"rationale": "Plan delivery against current batter and field.",
},
})
elif roll < 0.58:
return json.dumps({
"tool": "set_field_setting",
"arguments": {"setting": random.choice(["Aggressive", "Balanced", "Defensive"])}
})
elif roll < 0.65:
return json.dumps({
"tool": "reflect_after_ball",
"arguments": {"reflection": "Adjust based on the previous ball outcome and match pressure."},
})
else:
return json.dumps({"tool": "bowl_delivery", "arguments": {}})
# Default Batting logic
roll = random.random()
if roll < 0.1:
return json.dumps({
"tool": "select_batter",
"arguments": {
"name": random.choice(["Opener", "Anchor", "Finisher"]),
"style": random.choice(["balanced", "anchor", "aggressive"]),
"aggression": round(random.uniform(0.2, 0.8), 2),
"rationale": "Choose batter profile for phase, target, and wicket context.",
},
})
        elif roll < 0.25:
agg = round(random.uniform(0.1, 0.9), 2)
return json.dumps({
"tool": "set_strategy",
"arguments": {
"phase_intent": random.choice(["consolidate", "attack", "rotate"]),
"aggression": agg,
"rationale": "Random strategy selection.",
}
})
elif roll < 0.4:
return json.dumps({
"tool": "plan_shot",
"arguments": {
"shot_intent": random.choice(SHOT_AGGRESSION_ORDER),
"target_area": random.choice(["off-side gap", "leg-side gap", "straight", "boundary"]),
"risk": random.choice(["low", "balanced", "high"]),
"rationale": "Match shot plan to bowler, field, and required rate.",
},
})
elif roll < 0.5:
return json.dumps({
"tool": "analyze_situation",
"arguments": {"query_type": random.choice(["pitch_conditions", "match_situation"])},
})
elif roll < 0.58:
return json.dumps({
"tool": "reflect_after_ball",
"arguments": {"reflection": "Revise risk after the previous delivery and current target pressure."},
})
else:
return json.dumps({
"tool": "play_delivery",
"arguments": {
"shot_intent": random.choice(SHOT_AGGRESSION_ORDER),
"explanation": "Random shot selection.",
},
})
class OpenAIAgent:
def __init__(
self,
model: str,
api_base: str | None = None,
api_key: str | None = None,
timeout: float = 30.0,
temperature: float = 0.2,
):
if not _OPENAI_AVAILABLE:
raise ImportError("openai not installed. Run: pip install openai")
self._client = openai.OpenAI(
base_url=api_base,
api_key=api_key or "dummy",
timeout=timeout,
)
self._model = model
self._temperature = temperature
def __call__(self, messages: list[dict]) -> str:
resp = self._client.chat.completions.create(
model=self._model,
messages=messages,
temperature=self._temperature,
max_tokens=300,
)
        # content can be None (e.g. empty or tool-call-only replies); coerce to "".
        return (resp.choices[0].message.content or "").strip()
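
# Usage sketch (endpoint and model name are illustrative; any OpenAI-compatible
# /v1 server works):
#   agent = OpenAIAgent(
#       "meta-llama/Llama-3.1-8B-Instruct",
#       api_base="https://router.huggingface.co/v1",
#       api_key=os.environ.get("HF_TOKEN"),
#   )
#   reply = agent([{"role": "system", "content": SYSTEM_PROMPT},
#                  {"role": "user", "content": "TOSS: make your call."}])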
def _parse_action(raw: str) -> tuple[CricketAction | None, bool]:
    """Parse a raw LLM reply into (action, parse_error); the error flag is a bool."""
    raw = raw.strip()
if raw.startswith("```"):
lines = raw.split("\n")
raw = "\n".join(lines[1:-1]) if len(lines) > 2 else raw
try:
if not raw.startswith("{"):
start = raw.find("{")
if start >= 0:
raw = raw[start:]
data, _ = json.JSONDecoder().raw_decode(raw)
valid_tools = (
"set_strategy", "analyze_situation", "play_delivery",
"call_toss", "bowl_delivery", "set_bowling_strategy", "set_field_setting",
"choose_bowler", "select_batter", "plan_delivery", "plan_shot", "reflect_after_ball",
"set_match_plan", "update_match_plan",
)
if "tool" not in data and len(data) == 1:
maybe_tool, maybe_args = next(iter(data.items()))
if maybe_tool in valid_tools and isinstance(maybe_args, dict):
data = {"tool": maybe_tool, "arguments": maybe_args}
tool = data.get("tool", "")
if tool not in valid_tools:
return None, True
return CricketAction(tool=tool, arguments=_normalize_action_args(tool, data.get("arguments", {}))), False
except Exception:
return None, True
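
# Both reply shapes below parse to the same action; the second is the
# single-key shorthand some models emit (examples are illustrative):
#   {"tool": "plan_shot", "arguments": {"shot_intent": "single", "risk": "low"}}
#   {"plan_shot": {"shot_intent": "single", "risk": "low"}}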
def _action_short(action: CricketAction) -> str:
"""Compact one-line action string for [STEP] markers."""
args = action.arguments or {}
primary = (
args.get("shot_intent")
or args.get("delivery_type")
or args.get("phase_intent")
or args.get("setting")
or args.get("call")
)
return f"{action.tool}({primary})" if primary else action.tool
async def run_episode(
env: CricketCaptainEnv,
agent,
task: str = "stage2_full",
max_steps: int = 600,
verbose: bool = False,
eval_pack_id: str = "default",
opponent_mode: str = "heuristic",
max_overs: int | None = None,
step_log=None, # callable(str) — called after every delivery for live log
benchmark_stdout: bool = True, # emit [START]/[STEP]/[END] markers
model_name: str = "random", # threaded into the [START] line
observer=None, # optional spectator.publisher.BusObserver
) -> dict[str, Any]:
# All reset params must go inside options={} — EnvClient.reset(**kwargs) only
# forwards recognised signature params (seed, options); bare kwargs are dropped.
result = await env.reset(options={
"task": task,
"random_start": False,
"eval_pack_id": eval_pack_id,
"opponent_mode": opponent_mode,
"max_overs": max_overs,
})
obs = result.observation
if observer is not None:
await observer.start(obs)
history: list[dict] = []
rewards: list[float] = []
parse_errors = 0
deliveries = 0
turn = 0
_AGENT_TIMEOUT = 45.0
if benchmark_stdout:
print(f"[START] task={task} env=cricket_captain model={model_name}", flush=True)
while not result.done and turn < max_steps:
messages = [{"role": "system", "content": SYSTEM_PROMPT}] + history[-10:] + [
{"role": "user", "content": obs.prompt_text}
]
try:
raw = await asyncio.wait_for(
                asyncio.get_running_loop().run_in_executor(None, agent, messages),
timeout=_AGENT_TIMEOUT,
)
except asyncio.TimeoutError:
raw = ""
action, err = _parse_action(raw)
if err:
parse_errors += 1
if "BOWLING" in obs.prompt_text:
action = CricketAction(tool="bowl_delivery", arguments={})
elif "TOSS" in obs.prompt_text:
action = CricketAction(tool="call_toss", arguments={"call": "heads", "decision": "bat"})
else:
action = CricketAction(tool="play_delivery", arguments={"shot_intent": "defensive", "explanation": "fallback"})
action_dict = {"tool": action.tool, "arguments": action.arguments}
pre = None
if observer is not None:
pre = await observer.before_step(
action_dict, obs, turn=turn, transcript_text=raw,
)
try:
result = await env.step(action)
except Exception as exc: # noqa: BLE001
if observer is not None:
await observer.emit_error(f"env.step failed: {exc}", turn=turn, tool=action.tool)
raise
obs = result.observation
r = result.reward or 0.0
rewards.append(r)
turn += 1
is_delivery = action.tool in ("play_delivery", "bowl_delivery")
if is_delivery:
deliveries += 1
if benchmark_stdout:
            # _parse_action returns a boolean error flag, not a message string.
            err_field = "parse_error" if err else "null"
print(
f"[STEP] step={turn} action={_action_short(action)} "
f"reward={r:.2f} done={'true' if result.done else 'false'} error={err_field}",
flush=True,
)
if observer is not None:
await observer.after_step(
action_dict, pre or {}, obs, r,
turn=turn,
fetch_state=env.state,
)
if is_delivery and step_log:
ctx = obs.game_context
step_log(
f" over={ctx.get('over', '?')}.{ctx.get('ball', '?')} "
f"score={ctx.get('score', '?')}/{ctx.get('wickets', '?')} "
f"tool={action.tool} r={r:.3f} | {obs.last_ball_result[:60]}"
)
elif verbose:
print(f" [{turn}] [{obs.game_state.upper()}] {raw[:60]} → r={r:.3f}")
history.append({"role": "assistant", "content": raw})
history.append({"role": "user", "content": obs.last_ball_result})
state = await env.state()
if observer is not None:
await observer.end(state)
match_result = getattr(state, "match_result", None) or ""
success = (match_result == "win")
# Score in [0, 1]: win=1.0, tie=0.5, loss=0.0. Matches Round 1 grader-output spec.
if match_result == "win":
score = 1.0
elif match_result == "tie":
score = 0.5
else:
score = 0.0
if benchmark_stdout:
rewards_str = ",".join(f"{r:.2f}" for r in rewards)
print(
f"[END] success={'true' if success else 'false'} steps={turn} "
f"score={score:.2f} rewards={rewards_str}",
flush=True,
)
coherence = state.coherence_scores
return {
"total_score": state.total_score,
"wickets_lost": state.wickets_lost,
"over": state.over,
"total_reward": sum(rewards),
"mean_coherence": statistics.mean(coherence) if coherence else 0.0,
"parse_error_rate": parse_errors / max(turn, 1),
"tool_calls": state.tool_calls_made,
"adaptation": statistics.mean(state.adaptation_scores) if state.adaptation_scores else 0.0,
"opponent_awareness": statistics.mean(state.opponent_awareness_scores) if state.opponent_awareness_scores else 0.0,
"regret": statistics.mean(state.regret_scores) if state.regret_scores else 0.0,
"deliveries": deliveries,
"game_state": state.game_state,
"target": state.target,
}
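
# Minimal standalone sketch of run_episode (assumes a CricketCaptain server
# listening on the default ws://localhost:8000; RandomAgent needs no API key):
#
#   async def _demo():
#       async with CricketCaptainEnv("ws://localhost:8000") as env:
#           summary = await run_episode(env, RandomAgent(), task="easy",
#                                       max_overs=5, model_name="random")
#           print(summary["total_reward"], summary["parse_error_rate"])
#
#   asyncio.run(_demo())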
def _make_inference_run_folder(model: str, opponent_mode: str, max_overs: int | None) -> Path:
    """Create a timestamped run folder under illustrations/ and return its path."""
    ts = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M")
model_short = model.split("/")[-1][:20] if model != "random" else "random"
overs_str = f"_{max_overs}ov" if max_overs else ""
opp_str = f"_{opponent_mode}"
folder_name = f"exp_{ts}_inference{overs_str}{opp_str}_{model_short}"
run_dir = Path(__file__).parent / "illustrations" / folder_name
run_dir.mkdir(parents=True, exist_ok=True)
return run_dir
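
# Example of a folder name this produces (timestamp illustrative):
#   illustrations/exp_2025-01-01_12-00_inference_5ov_heuristic_gpt-4o-mini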
async def evaluate(args):
agent: Any
if args.model == "random":
agent = RandomAgent()
print("Using RandomAgent baseline")
else:
agent = OpenAIAgent(args.model, api_base=args.api_base, api_key=args.api_key)
print(f"Using OpenAI-compatible agent: {args.model}")
run_dir = _make_inference_run_folder(args.model, args.opponent_mode, args.max_overs)
log_file = run_dir / "run_output.txt"
# Write header immediately so the file exists and is readable while running
header = "\n".join([
f"# Inference run: {run_dir.name}",
f"timestamp_utc: {datetime.datetime.utcnow().isoformat()}",
f"model: {args.model}",
f"api_base: {args.api_base}",
f"opponent_mode: {args.opponent_mode}",
f"max_overs: {args.max_overs}",
f"episodes: {args.episodes}",
f"task: {args.task}",
f"eval_pack_id: {args.eval_pack_id}",
"",
])
log_file.write_text(header)
def _log(msg: str):
print(msg)
with open(log_file, "a") as f:
f.write(msg + "\n")
results = []
bus_http: Any = None
if args.publish_bus:
if not _HTTPX_AVAILABLE:
raise RuntimeError("--publish-bus requires httpx. Run: pip install httpx")
bus_http = httpx.AsyncClient(timeout=15.0)
try:
async with CricketCaptainEnv(args.env_url) as env:
for ep in range(args.episodes):
_log(f"\n--- Episode {ep+1}/{args.episodes} ---")
observer = None
if args.publish_bus:
from spectator.publisher import BusObserver # lazy: only if asked
episode_id = (
f"{args.bus_episode_prefix}-{int(time.time())}-{uuid.uuid4().hex[:6]}"
)
observer = BusObserver(
bus_http,
args.bus_url,
episode_id,
mode=args.model,
task=args.task,
opponent_mode=args.opponent_mode,
eval_pack_id=args.eval_pack_id,
emit_transcript=args.bus_transcript,
)
_log(f" [bus] watch live at {args.bus_url}/custom (ep id: {episode_id})")
ep_result = await run_episode(
env,
agent,
task=args.task,
verbose=args.verbose,
eval_pack_id=args.eval_pack_id,
opponent_mode=args.opponent_mode,
max_overs=args.max_overs,
step_log=_log,
model_name=args.model,
observer=observer,
)
results.append(ep_result)
line = (
f"Episode {ep+1:>3}/{args.episodes} | "
f"Score: {ep_result['total_score']:>3}/{ep_result['wickets_lost']} "
f"({ep_result['over']} ov) | "
f"Reward: {ep_result['total_reward']:>6.3f} | "
f"Coherence: {ep_result['mean_coherence']:.3f} | "
f"Adapt: {ep_result['adaptation']:.3f} | "
f"ParseErr: {ep_result['parse_error_rate']:.1%}"
)
_log(line)
finally:
if bus_http is not None:
await bus_http.aclose()
_log("\n=== Summary ===")
summary_lines = []
for key in ["total_score", "wickets_lost", "total_reward", "mean_coherence", "parse_error_rate"]:
vals = [r[key] for r in results]
summary_lines.append(f" {key:20s}: mean={statistics.mean(vals):.3f} std={statistics.stdev(vals) if len(vals)>1 else 0:.3f}")
_log(summary_lines[-1])
# Write README (always at end — has final summary)
(run_dir / "README.md").write_text(
f"## Inference Run: {run_dir.name}\n\n"
f"**Date**: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M')}\n\n"
f"| Setting | Value |\n|---|---|\n"
f"| Model | `{args.model}` |\n"
f"| API base | `{args.api_base or 'N/A'}` |\n"
f"| Opponent mode | `{args.opponent_mode}` |\n"
f"| Max overs | {args.max_overs} |\n"
f"| Episodes | {args.episodes} |\n"
f"| Task | `{args.task}` |\n\n"
f"### Results\n\n```\n" + "\n".join(summary_lines) + "\n```\n\n"
f"See `run_output.txt` for full verbose episode log.\n"
)
print(f"\nRun saved → {run_dir}")
def main():
parser = argparse.ArgumentParser(description="CricketCaptain-LLM Baseline Inference")
parser.add_argument("--config", default=None, help="YAML config path (runner defaults).")
parser.add_argument("--model", default="random", help="'random' or OpenAI model name")
parser.add_argument("--episodes", type=int, default=5)
parser.add_argument("--task", default="medium",
choices=["easy", "medium", "hard", "stage2_full", "eval_50over"],
help="Task difficulty: easy=5-over, medium=T20, hard=50-over ODI. "
"stage2_full / eval_50over are legacy aliases.")
parser.add_argument("--max-overs", type=int, default=None,
help="Limit innings length for fast experiments (e.g. 5).")
parser.add_argument("--env-url", default=os.environ.get("CRICKET_CAPTAIN_ENV_URL", "ws://localhost:8000"))
parser.add_argument("--eval-pack-id", default=os.environ.get("CRICKET_EVAL_PACK_ID", "default"))
parser.add_argument("--opponent-mode", default=os.environ.get("CRICKET_OPPONENT_MODE", "heuristic"),
choices=["heuristic", "llm_live", "llm_cached", "cricsheet"])
# Round-1 spec env-var contract: API_BASE_URL / MODEL_NAME / HF_TOKEN.
# CLI flags --api-base / --api-key / --model override; otherwise we fall
# back to those env vars (with HF Router defaults so the script runs on
# vCPU=2, 8 GB RAM without local model loading).
parser.add_argument("--api-base", default=os.environ.get("API_BASE_URL"))
parser.add_argument("--api-key", default=os.environ.get("HF_TOKEN") or os.environ.get("API_KEY"))
parser.add_argument("--verbose", action="store_true")
# Spectator-bus integration. Off by default; turn on to live-stream this
# run to the cockpit at {bus-url}/custom.
parser.add_argument("--publish-bus", action="store_true",
help="Stream events to the spectator UI bus at /custom/publish.")
parser.add_argument("--bus-url", default=os.environ.get("CRICKET_UI_BASE_URL", "http://localhost:8000"),
help="HTTP base URL of the FastAPI server hosting /custom/publish.")
parser.add_argument("--bus-episode-prefix", default="inf",
help="Prefix used to name episode IDs when publishing.")
parser.add_argument("--bus-transcript", action="store_true",
help="Also emit raw LLM replies as `transcript` events (debug-only).")
args = parser.parse_args()
    # Prefer the MODEL_NAME env var when --model wasn't set explicitly; this
    # runs before (and therefore wins over) any YAML captain_model default.
if args.model == "random" and os.environ.get("MODEL_NAME"):
args.model = os.environ["MODEL_NAME"]
# Default to HF Router when running with an OpenAI-compatible model and no
# explicit api-base set — keeps inference CPU-only as required.
if args.model != "random" and not args.api_base:
args.api_base = "https://router.huggingface.co/v1"
if args.config:
try:
from config_yaml import load_config, apply_runner_config_defaults
except ImportError:
from cricket_captain.config_yaml import load_config, apply_runner_config_defaults
defaults = apply_runner_config_defaults(load_config(args.config))
if args.env_url == os.environ.get("CRICKET_CAPTAIN_ENV_URL", "ws://localhost:8000") and defaults.env_url:
args.env_url = defaults.env_url
if args.eval_pack_id == os.environ.get("CRICKET_EVAL_PACK_ID", "default") and defaults.eval_pack_id:
args.eval_pack_id = defaults.eval_pack_id
if args.opponent_mode == os.environ.get("CRICKET_OPPONENT_MODE", "heuristic") and defaults.opponent_mode:
args.opponent_mode = defaults.opponent_mode
if args.max_overs is None and defaults.max_overs is not None:
args.max_overs = int(defaults.max_overs)
if args.model == "random" and defaults.captain_model:
args.model = defaults.captain_model
if args.api_base is None and defaults.captain_api_base:
args.api_base = defaults.captain_api_base
if args.api_key is None and defaults.captain_api_key:
args.api_key = defaults.captain_api_key
asyncio.run(evaluate(args))
if __name__ == "__main__":
main()