""" Evaluation and visualisation for CricketCaptain-LLM. Produces: 1. Reward curves — r_cric, r_coherence, r_tools, r_format vs training steps 2. Coherence heatmap — episode × turn matrix (diagonal banding = learning) 3. Tool usage timeline — scatter: over number → tool called 4. Strategy text samples — early vs late training qualitative comparison 5. Before/after canonical — fixed state (Over 35, 180/3) response comparison Usage: export CRICKET_CAPTAIN_ENV_URL="ws:///ws" python eval.py --checkpoint ./checkpoints/stage2_final --episodes 50 --task eval_50over """ import argparse import asyncio import json import os import statistics from pathlib import Path from typing import Any try: import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt import matplotlib.colors as mcolors import numpy as np _PLOT_AVAILABLE = True except ImportError: _PLOT_AVAILABLE = False try: from client import CricketCaptainEnv from models import CricketAction from inference import _parse_action, SHOT_AGGRESSION_ORDER, RandomAgent except ImportError: from cricket_captain.client import CricketCaptainEnv from cricket_captain.models import CricketAction from cricket_captain.inference import _parse_action, SHOT_AGGRESSION_ORDER, RandomAgent # ------------------------------------------------------------------ # # Data collection # # ------------------------------------------------------------------ # async def collect_eval_episodes( env_url: str, agent, n_episodes: int, task: str, eval_pack_id: str = "default", opponent_mode: str = "heuristic", max_overs: int | None = None, ) -> list[dict[str, Any]]: """Run n_episodes and return raw episode data for visualisation.""" episodes = [] async with CricketCaptainEnv(env_url) as env: for ep in range(n_episodes): # OpenEnv server routes reset params via `options`. result = await env.reset(options={ "task": task, "random_start": False, "eval_pack_id": eval_pack_id, "opponent_mode": opponent_mode, "max_overs": max_overs, }) obs = result.observation history = [] step_data = [] turn = 0 while not result.done and turn < 600: messages = [{"role": "user", "content": obs.prompt_text}] raw = agent(messages) action, err = _parse_action(raw) if err or action is None: if obs.game_state == "bowling": action = CricketAction(tool="bowl_delivery", arguments={}) elif obs.game_state == "toss": action = CricketAction(tool="call_toss", arguments={"call": "heads", "decision": "bat"}) else: action = CricketAction( tool="play_delivery", arguments={"shot_intent": "defensive", "explanation": "fallback"}, ) ctx = obs.game_context.copy() step_data.append({ "turn": turn, "over": ctx.get("over", 0), "tool": action.tool, "shot_intent": action.arguments.get("shot_intent", ""), "strategic_phase": obs.strategic_phase, "opponent_plan": obs.opponent_plan, "rationale": obs.declared_strategy.get("rationale", ""), "raw_response": raw, "reward": result.reward or 0.0, "parse_error": err, }) result = await env.step(action) obs = result.observation turn += 1 state = await env.state() episodes.append({ "episode": ep, "steps": step_data, "coherence_scores": state.coherence_scores, "adaptation_scores": state.adaptation_scores, "opponent_awareness_scores": state.opponent_awareness_scores, "regret_scores": state.regret_scores, "total_score": state.total_score, "wickets_lost": state.wickets_lost, "tool_calls": state.tool_calls_made, "transcript": state.transcript, }) print(f" Episode {ep+1}/{n_episodes} | " f"Score: {state.total_score}/{state.wickets_lost} | " f"Mean coherence: {statistics.mean(state.coherence_scores) if state.coherence_scores else 0:.3f}") return episodes # ------------------------------------------------------------------ # # Visualisations # # ------------------------------------------------------------------ # def plot_coherence_heatmap(episodes: list[dict], out_dir: Path): if not _PLOT_AVAILABLE: print("matplotlib not available — skipping coherence heatmap") return max_turns = max(len(ep["coherence_scores"]) for ep in episodes) if episodes else 1 matrix = [] for ep in episodes: row = ep["coherence_scores"][:max_turns] row += [float("nan")] * (max_turns - len(row)) matrix.append(row) arr = np.array(matrix, dtype=float) fig, ax = plt.subplots(figsize=(min(max_turns // 3 + 2, 20), max(len(episodes) // 4, 4))) im = ax.imshow(arr, aspect="auto", cmap="YlOrRd", vmin=0.0, vmax=1.0) plt.colorbar(im, ax=ax, label="Coherence score") ax.set_xlabel("Turn (delivery number)") ax.set_ylabel("Episode") ax.set_title("Coherence Heatmap\n(diagonal banding → strategic consistency across episode)") plt.tight_layout() path = out_dir / "coherence_heatmap.png" plt.savefig(path, dpi=150) plt.close() print(f"Saved: {path}") def plot_tool_usage_timeline(episodes: list[dict], out_dir: Path): if not _PLOT_AVAILABLE: print("matplotlib not available — skipping tool timeline") return tool_colors = { "set_strategy": "steelblue", "plan_shot": "royalblue", "choose_bowler": "seagreen", "plan_delivery": "mediumseagreen", "analyze_situation": "darkorange", "play_delivery": "lightgray", "bowl_delivery": "lightgray", "reflect_after_ball": "purple", } fig, ax = plt.subplots(figsize=(14, 5)) for ep_idx, ep in enumerate(episodes[:20]): # cap at 20 episodes for readability for step in ep["steps"]: tool = step["tool"] if tool in ("play_delivery", "bowl_delivery"): continue # too many — skip for clarity ax.scatter(step["over"], ep_idx, color=tool_colors.get(tool, "black"), marker="D" if tool == "set_strategy" else "o", s=60, alpha=0.8, zorder=3) from matplotlib.patches import Patch legend_elements = [ Patch(color="steelblue", label="set_strategy"), Patch(color="royalblue", label="plan_shot"), Patch(color="seagreen", label="choose_bowler"), Patch(color="mediumseagreen", label="plan_delivery"), Patch(color="darkorange", label="analyze_situation"), Patch(color="purple", label="reflect_after_ball"), ] ax.legend(handles=legend_elements, loc="upper right") ax.axvline(6, color="gray", linestyle="--", alpha=0.4, label="Phase transition") ax.axvline(16, color="gray", linestyle="--", alpha=0.4) ax.axvline(36, color="gray", linestyle="--", alpha=0.4) ax.set_xlabel("Over number") ax.set_ylabel("Episode") ax.set_title("Tool Usage Timeline\n(ideal: strategy + analysis cluster at phase transitions)") plt.tight_layout() path = out_dir / "tool_usage_timeline.png" plt.savefig(path, dpi=150) plt.close() print(f"Saved: {path}") def plot_reward_curve(log_file: str | None, out_dir: Path): """Plot reward curves from a JSONL training log.""" if not _PLOT_AVAILABLE: print("matplotlib not available — skipping reward curve") return if not log_file or not os.path.exists(log_file): print(f"No log file found at {log_file} — skipping reward curve") return steps, r_cric, r_coh, r_tools, r_fmt, composite = [], [], [], [], [], [] with open(log_file) as f: for line in f: try: d = json.loads(line) if "r_cric" in d: steps.append(d.get("step", len(steps))) r_cric.append(d["r_cric"]) r_coh.append(d["r_coherence"]) r_tools.append(d["r_tools"]) r_fmt.append(d["r_format"]) composite.append(d["composite"]) except (json.JSONDecodeError, KeyError): continue if not steps: print("No reward data in log file") return fig, axes = plt.subplots(2, 3, figsize=(16, 8)) for ax, (label, vals) in zip( axes.flat, [("r_cric", r_cric), ("r_coherence", r_coh), ("r_tools", r_tools), ("r_format", r_fmt), ("composite", composite)], ): ax.plot(steps, vals, linewidth=1.5) ax.set_title(label) ax.set_xlabel("Training step") ax.set_ylabel("Reward") ax.grid(True, alpha=0.3) axes.flat[-1].axis("off") plt.suptitle("CricketCaptain-LLM Training Reward Curves") plt.tight_layout() path = out_dir / "reward_curves.png" plt.savefig(path, dpi=150) plt.close() print(f"Saved: {path}") def print_strategy_samples(episodes: list[dict], label: str = ""): """Print sample strategy declarations for qualitative analysis.""" print(f"\n=== Strategy Samples {label} ===") seen = set() for ep in episodes[:10]: for step in ep["steps"]: r = step.get("rationale", "").strip() if r and r not in seen and step["tool"] == "set_strategy": seen.add(r) print(f" [{ep['episode']}:{step['over']}] {r}") if len(seen) >= 8: return def print_summary(episodes: list[dict]): all_coherence = [s for ep in episodes for s in ep["coherence_scores"]] all_adaptation = [s for ep in episodes for s in ep.get("adaptation_scores", [])] all_awareness = [s for ep in episodes for s in ep.get("opponent_awareness_scores", [])] all_regret = [s for ep in episodes for s in ep.get("regret_scores", [])] all_scores = [ep["total_score"] for ep in episodes] all_wickets = [ep["wickets_lost"] for ep in episodes] print("\n=== Evaluation Summary ===") print(f" Episodes: {len(episodes)}") print(f" Mean score: {statistics.mean(all_scores):.1f} ± {statistics.stdev(all_scores) if len(all_scores)>1 else 0:.1f}") print(f" Mean wickets: {statistics.mean(all_wickets):.1f}") print(f" Mean coherence: {statistics.mean(all_coherence) if all_coherence else 0:.4f}") print(f" Mean adaptation: {statistics.mean(all_adaptation) if all_adaptation else 0:.4f}") print(f" Mean opp aware: {statistics.mean(all_awareness) if all_awareness else 0:.4f}") print(f" Mean regret score: {statistics.mean(all_regret) if all_regret else 0:.4f}") print(f" Coherence stdev: {statistics.stdev(all_coherence) if len(all_coherence)>1 else 0:.4f}") # ------------------------------------------------------------------ # # CLI # # ------------------------------------------------------------------ # async def _run_eval(args): agent = RandomAgent() out_dir = Path(args.out_dir) out_dir.mkdir(parents=True, exist_ok=True) print(f"Collecting {args.episodes} evaluation episodes...") episodes = await collect_eval_episodes( args.env_url, agent, args.episodes, args.task, args.eval_pack_id, args.opponent_mode, args.max_overs ) print_summary(episodes) print_strategy_samples(episodes, label=f"(task={args.task})") if _PLOT_AVAILABLE: plot_coherence_heatmap(episodes, out_dir) plot_tool_usage_timeline(episodes, out_dir) plot_reward_curve(args.log_file, out_dir) print(f"\nPlots saved to {out_dir}/") else: print("Install matplotlib+numpy for visualisations: pip install matplotlib numpy") # Dump raw data raw_path = out_dir / "eval_episodes.jsonl" with open(raw_path, "w") as f: for ep in episodes: f.write(json.dumps(ep) + "\n") print(f"Raw data saved to {raw_path}") def main(): parser = argparse.ArgumentParser(description="CricketCaptain-LLM Evaluation") parser.add_argument("--config", default=None, help="YAML config path (runner defaults).") parser.add_argument("--episodes", type=int, default=20) parser.add_argument("--task", default="medium", choices=["easy", "medium", "hard", "stage2_full", "eval_50over"]) parser.add_argument("--env-url", default=os.environ.get("CRICKET_CAPTAIN_ENV_URL", "ws://localhost:8000")) parser.add_argument("--eval-pack-id", default=os.environ.get("CRICKET_EVAL_PACK_ID", "default")) parser.add_argument("--opponent-mode", default=os.environ.get("CRICKET_OPPONENT_MODE", "heuristic"), choices=["heuristic", "llm_live", "llm_cached"]) parser.add_argument("--max-overs", type=int, default=None, help="Limit innings length for fast experiments (e.g. 5).") parser.add_argument("--out-dir", default="./eval_output") parser.add_argument("--log-file", default=None, help="Path to JSONL training log for reward curves") parser.add_argument("--checkpoint", default=None, help="Checkpoint path (unused by default — agent is random baseline)") args = parser.parse_args() if args.config: try: from config_yaml import load_config, apply_runner_config_defaults except ImportError: from cricket_captain.config_yaml import load_config, apply_runner_config_defaults defaults = apply_runner_config_defaults(load_config(args.config)) if args.env_url == os.environ.get("CRICKET_CAPTAIN_ENV_URL", "ws://localhost:8000") and defaults.env_url: args.env_url = defaults.env_url if args.eval_pack_id == os.environ.get("CRICKET_EVAL_PACK_ID", "default") and defaults.eval_pack_id: args.eval_pack_id = defaults.eval_pack_id if args.opponent_mode == os.environ.get("CRICKET_OPPONENT_MODE", "heuristic") and defaults.opponent_mode: args.opponent_mode = defaults.opponent_mode if args.max_overs is None and defaults.max_overs is not None: args.max_overs = int(defaults.max_overs) asyncio.run(_run_eval(args)) if __name__ == "__main__": main()