"""
Evaluation and visualisation for CricketCaptain-LLM.

Produces:
    1. Reward curves — r_cric, r_coherence, r_tools, r_format vs training steps
    2. Coherence heatmap — episode × turn matrix (diagonal banding = learning)
    3. Tool usage timeline — scatter: over number → tool called
    4. Strategy text samples — early vs late training qualitative comparison
    5. Before/after canonical — fixed state (Over 35, 180/3) response comparison

Usage:
    export CRICKET_CAPTAIN_ENV_URL="ws://<reachable-host>/ws"
    python eval.py --checkpoint ./checkpoints/stage2_final --episodes 50 --task eval_50over
"""
import argparse
import asyncio
import json
import os
import statistics
from pathlib import Path
from typing import Any
try:
    import matplotlib

    matplotlib.use("Agg")  # headless backend: write files, never open a window
    import matplotlib.pyplot as plt
    import numpy as np
    from matplotlib.lines import Line2D
    from matplotlib.patches import Patch

    _PLOT_AVAILABLE = True
except ImportError:
    _PLOT_AVAILABLE = False
# Support both flat-layout and installed-package imports.
try:
    from client import CricketCaptainEnv
    from models import CricketAction
    from inference import _parse_action, RandomAgent
except ImportError:
    from cricket_captain.client import CricketCaptainEnv
    from cricket_captain.models import CricketAction
    from cricket_captain.inference import _parse_action, RandomAgent
# ------------------------------------------------------------------ #
# Data collection                                                     #
# ------------------------------------------------------------------ #
async def collect_eval_episodes(
    env_url: str,
    agent,
    n_episodes: int,
    task: str,
    eval_pack_id: str = "default",
    opponent_mode: str = "heuristic",
    max_overs: int | None = None,
) -> list[dict[str, Any]]:
    """Run n_episodes and return raw episode data for visualisation."""
    episodes = []
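    # Shape of each dict appended to `episodes` (assembled at the end of the loop):
    #   {
    #       "episode": int,
    #       "steps": [{"turn", "over", "tool", "shot_intent", "strategic_phase",
    #                  "opponent_plan", "rationale", "raw_response", "reward",
    #                  "parse_error"}, ...],
    #       "coherence_scores" / "adaptation_scores" /
    #       "opponent_awareness_scores" / "regret_scores": list[float],
    #       "total_score", "wickets_lost", "tool_calls", "transcript": env-state fields,
    #   }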
    async with CricketCaptainEnv(env_url) as env:
        for ep in range(n_episodes):
            # The OpenEnv server routes reset parameters via `options`.
            result = await env.reset(options={
                "task": task,
                "random_start": False,
                "eval_pack_id": eval_pack_id,
                "opponent_mode": opponent_mode,
                "max_overs": max_overs,
            })
            obs = result.observation
            step_data = []
            turn = 0
            while not result.done and turn < 600:  # hard cap guards against runaway episodes
                messages = [{"role": "user", "content": obs.prompt_text}]
                raw = agent(messages)
                action, err = _parse_action(raw)
                if err or action is None:
                    # Unparseable response: fall back to a safe action for the current state.
                    if obs.game_state == "bowling":
                        action = CricketAction(tool="bowl_delivery", arguments={})
                    elif obs.game_state == "toss":
                        action = CricketAction(tool="call_toss", arguments={"call": "heads", "decision": "bat"})
                    else:
                        action = CricketAction(
                            tool="play_delivery",
                            arguments={"shot_intent": "defensive", "explanation": "fallback"},
                        )
                ctx = obs.game_context.copy()
                step_data.append({
                    "turn": turn,
                    "over": ctx.get("over", 0),
                    "tool": action.tool,
                    "shot_intent": action.arguments.get("shot_intent", ""),
                    "strategic_phase": obs.strategic_phase,
                    "opponent_plan": obs.opponent_plan,
                    "rationale": obs.declared_strategy.get("rationale", ""),
                    "raw_response": raw,
                    "reward": result.reward or 0.0,  # reward carried over from the previous transition
                    "parse_error": err,
                })
                result = await env.step(action)
                obs = result.observation
                turn += 1
            state = await env.state()
            episodes.append({
                "episode": ep,
                "steps": step_data,
                "coherence_scores": state.coherence_scores,
                "adaptation_scores": state.adaptation_scores,
                "opponent_awareness_scores": state.opponent_awareness_scores,
                "regret_scores": state.regret_scores,
                "total_score": state.total_score,
                "wickets_lost": state.wickets_lost,
                "tool_calls": state.tool_calls_made,
                "transcript": state.transcript,
            })
            mean_coh = statistics.mean(state.coherence_scores) if state.coherence_scores else 0.0
            print(f"  Episode {ep + 1}/{n_episodes} | "
                  f"Score: {state.total_score}/{state.wickets_lost} | "
                  f"Mean coherence: {mean_coh:.3f}")
    return episodes
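
# Minimal usage sketch (assumes an environment server is reachable at this URL;
# RandomAgent is the baseline agent imported above):
#
#   agent = RandomAgent()
#   episodes = asyncio.run(
#       collect_eval_episodes("ws://localhost:8000", agent, n_episodes=5, task="medium")
#   )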
# ------------------------------------------------------------------ #
# Visualisations                                                      #
# ------------------------------------------------------------------ #
def plot_coherence_heatmap(episodes: list[dict], out_dir: Path):
    if not _PLOT_AVAILABLE:
        print("matplotlib not available — skipping coherence heatmap")
        return
    max_turns = max(len(ep["coherence_scores"]) for ep in episodes) if episodes else 1
    matrix = []
    for ep in episodes:
        # Pad short episodes with NaN so imshow renders them as blank cells.
        row = ep["coherence_scores"][:max_turns]
        row += [float("nan")] * (max_turns - len(row))
        matrix.append(row)
    arr = np.array(matrix, dtype=float)
    fig, ax = plt.subplots(figsize=(min(max_turns // 3 + 2, 20), max(len(episodes) // 4, 4)))
    im = ax.imshow(arr, aspect="auto", cmap="YlOrRd", vmin=0.0, vmax=1.0)
    plt.colorbar(im, ax=ax, label="Coherence score")
    ax.set_xlabel("Turn (delivery number)")
    ax.set_ylabel("Episode")
    ax.set_title("Coherence Heatmap\n(diagonal banding → strategic consistency across episode)")
    plt.tight_layout()
    path = out_dir / "coherence_heatmap.png"
    plt.savefig(path, dpi=150)
    plt.close(fig)
    print(f"Saved: {path}")
def plot_tool_usage_timeline(episodes: list[dict], out_dir: Path):
    if not _PLOT_AVAILABLE:
        print("matplotlib not available — skipping tool timeline")
        return
    tool_colors = {
        "set_strategy": "steelblue",
        "plan_shot": "royalblue",
        "choose_bowler": "seagreen",
        "plan_delivery": "mediumseagreen",
        "analyze_situation": "darkorange",
        "play_delivery": "lightgray",
        "bowl_delivery": "lightgray",
        "reflect_after_ball": "purple",
    }
    fig, ax = plt.subplots(figsize=(14, 5))
    for ep_idx, ep in enumerate(episodes[:20]):  # cap at 20 episodes for readability
        for step in ep["steps"]:
            tool = step["tool"]
            if tool in ("play_delivery", "bowl_delivery"):
                continue  # too many — skip for clarity
            ax.scatter(step["over"], ep_idx,
                       color=tool_colors.get(tool, "black"),
                       marker="D" if tool == "set_strategy" else "o",
                       s=60, alpha=0.8, zorder=3)
    # Dashed verticals mark the phase-transition overs.
    for over in (6, 16, 36):
        ax.axvline(over, color="gray", linestyle="--", alpha=0.4)
    # Legend is built from explicit handles, so the phase-transition marker
    # needs its own Line2D entry (an axvline label alone would never render).
    legend_elements = [
        Patch(color="steelblue", label="set_strategy"),
        Patch(color="royalblue", label="plan_shot"),
        Patch(color="seagreen", label="choose_bowler"),
        Patch(color="mediumseagreen", label="plan_delivery"),
        Patch(color="darkorange", label="analyze_situation"),
        Patch(color="purple", label="reflect_after_ball"),
        Line2D([0], [0], color="gray", linestyle="--", label="phase transition"),
    ]
    ax.legend(handles=legend_elements, loc="upper right")
    ax.set_xlabel("Over number")
    ax.set_ylabel("Episode")
    ax.set_title("Tool Usage Timeline\n(ideal: strategy + analysis cluster at phase transitions)")
    plt.tight_layout()
    path = out_dir / "tool_usage_timeline.png"
    plt.savefig(path, dpi=150)
    plt.close(fig)
    print(f"Saved: {path}")
def plot_reward_curve(log_file: str | None, out_dir: Path):
    """Plot reward curves from a JSONL training log."""
    if not _PLOT_AVAILABLE:
        print("matplotlib not available — skipping reward curve")
        return
    if not log_file or not os.path.exists(log_file):
        print(f"No log file found at {log_file} — skipping reward curve")
        return
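    # Expected per-line schema, inferred from the keys read below; the
    # numeric values here are illustrative only:
    #   {"step": 120, "r_cric": 0.42, "r_coherence": 0.71,
    #    "r_tools": 0.55, "r_format": 0.90, "composite": 0.63}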
    steps, r_cric, r_coh, r_tools, r_fmt, composite = [], [], [], [], [], []
    with open(log_file) as f:
        for line in f:
            try:
                d = json.loads(line)
                if "r_cric" not in d:
                    continue
                # Read every field before appending anything, so a line with
                # a missing key cannot leave the series misaligned.
                row = (d.get("step", len(steps)), d["r_cric"], d["r_coherence"],
                       d["r_tools"], d["r_format"], d["composite"])
            except (json.JSONDecodeError, KeyError):
                continue
            for series, value in zip((steps, r_cric, r_coh, r_tools, r_fmt, composite), row):
                series.append(value)
    if not steps:
        print("No reward data in log file")
        return
    fig, axes = plt.subplots(2, 3, figsize=(16, 8))
    for ax, (label, vals) in zip(
        axes.flat,
        [("r_cric", r_cric), ("r_coherence", r_coh), ("r_tools", r_tools),
         ("r_format", r_fmt), ("composite", composite)],
    ):
        ax.plot(steps, vals, linewidth=1.5)
        ax.set_title(label)
        ax.set_xlabel("Training step")
        ax.set_ylabel("Reward")
        ax.grid(True, alpha=0.3)
    axes.flat[-1].axis("off")  # 2×3 grid, five series: blank the spare panel
    plt.suptitle("CricketCaptain-LLM Training Reward Curves")
    plt.tight_layout()
    path = out_dir / "reward_curves.png"
    plt.savefig(path, dpi=150)
    plt.close(fig)
    print(f"Saved: {path}")
def print_strategy_samples(episodes: list[dict], label: str = ""):
    """Print sample strategy declarations for qualitative analysis."""
    print(f"\n=== Strategy Samples {label} ===")
    seen = set()
    for ep in episodes[:10]:
        for step in ep["steps"]:
            r = step.get("rationale", "").strip()
            if r and r not in seen and step["tool"] == "set_strategy":
                seen.add(r)
                print(f"  [{ep['episode']}:{step['over']}] {r}")
                if len(seen) >= 8:
                    return
def print_summary(episodes: list[dict]):
    all_coherence = [s for ep in episodes for s in ep["coherence_scores"]]
    all_adaptation = [s for ep in episodes for s in ep.get("adaptation_scores", [])]
    all_awareness = [s for ep in episodes for s in ep.get("opponent_awareness_scores", [])]
    all_regret = [s for ep in episodes for s in ep.get("regret_scores", [])]
    all_scores = [ep["total_score"] for ep in episodes]
    all_wickets = [ep["wickets_lost"] for ep in episodes]
    score_sd = statistics.stdev(all_scores) if len(all_scores) > 1 else 0.0
    coh_sd = statistics.stdev(all_coherence) if len(all_coherence) > 1 else 0.0
    print("\n=== Evaluation Summary ===")
    print(f"  Episodes:          {len(episodes)}")
    print(f"  Mean score:        {statistics.mean(all_scores):.1f} ± {score_sd:.1f}")
    print(f"  Mean wickets:      {statistics.mean(all_wickets):.1f}")
    print(f"  Mean coherence:    {statistics.mean(all_coherence) if all_coherence else 0:.4f}")
    print(f"  Mean adaptation:   {statistics.mean(all_adaptation) if all_adaptation else 0:.4f}")
    print(f"  Mean opp aware:    {statistics.mean(all_awareness) if all_awareness else 0:.4f}")
    print(f"  Mean regret score: {statistics.mean(all_regret) if all_regret else 0:.4f}")
    print(f"  Coherence stdev:   {coh_sd:.4f}")
# ------------------------------------------------------------------ #
# CLI                                                                 #
# ------------------------------------------------------------------ #
async def _run_eval(args):
    # Default agent is the random baseline; --checkpoint is accepted but unused here.
    agent = RandomAgent()
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    print(f"Collecting {args.episodes} evaluation episodes...")
    episodes = await collect_eval_episodes(
        args.env_url, agent, args.episodes, args.task, args.eval_pack_id, args.opponent_mode, args.max_overs
    )
    print_summary(episodes)
    print_strategy_samples(episodes, label=f"(task={args.task})")
    if _PLOT_AVAILABLE:
        plot_coherence_heatmap(episodes, out_dir)
        plot_tool_usage_timeline(episodes, out_dir)
        plot_reward_curve(args.log_file, out_dir)
        print(f"\nPlots saved to {out_dir}/")
    else:
        print("Install matplotlib + numpy for visualisations: pip install matplotlib numpy")
    # Dump raw episode data for offline analysis.
    raw_path = out_dir / "eval_episodes.jsonl"
    with open(raw_path, "w") as f:
        for ep in episodes:
            f.write(json.dumps(ep) + "\n")
    print(f"Raw data saved to {raw_path}")
def main():
    parser = argparse.ArgumentParser(description="CricketCaptain-LLM Evaluation")
    parser.add_argument("--config", default=None, help="YAML config path (runner defaults).")
    parser.add_argument("--episodes", type=int, default=20)
    parser.add_argument("--task", default="medium",
                        choices=["easy", "medium", "hard", "stage2_full", "eval_50over"])
    parser.add_argument("--env-url", default=os.environ.get("CRICKET_CAPTAIN_ENV_URL", "ws://localhost:8000"))
    parser.add_argument("--eval-pack-id", default=os.environ.get("CRICKET_EVAL_PACK_ID", "default"))
    parser.add_argument("--opponent-mode", default=os.environ.get("CRICKET_OPPONENT_MODE", "heuristic"),
                        choices=["heuristic", "llm_live", "llm_cached"])
    parser.add_argument("--max-overs", type=int, default=None,
                        help="Limit innings length for fast experiments (e.g. 5).")
    parser.add_argument("--out-dir", default="./eval_output")
    parser.add_argument("--log-file", default=None,
                        help="Path to JSONL training log for reward curves")
    parser.add_argument("--checkpoint", default=None,
                        help="Checkpoint path (unused by default — agent is random baseline)")
    args = parser.parse_args()
    if args.config:
        try:
            from config_yaml import load_config, apply_runner_config_defaults
        except ImportError:
            from cricket_captain.config_yaml import load_config, apply_runner_config_defaults
        defaults = apply_runner_config_defaults(load_config(args.config))
        # Config values fill in only where the CLI/env value is still the built-in default.
        if args.env_url == os.environ.get("CRICKET_CAPTAIN_ENV_URL", "ws://localhost:8000") and defaults.env_url:
            args.env_url = defaults.env_url
        if args.eval_pack_id == os.environ.get("CRICKET_EVAL_PACK_ID", "default") and defaults.eval_pack_id:
            args.eval_pack_id = defaults.eval_pack_id
        if args.opponent_mode == os.environ.get("CRICKET_OPPONENT_MODE", "heuristic") and defaults.opponent_mode:
            args.opponent_mode = defaults.opponent_mode
        if args.max_overs is None and defaults.max_overs is not None:
            args.max_overs = int(defaults.max_overs)
    asyncio.run(_run_eval(args))


if __name__ == "__main__":
    main()
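
# Example invocations (flag values are illustrative; see the argparse flags above):
#   python eval.py --episodes 20 --task medium
#   python eval.py --task eval_50over --episodes 50 --log-file ./train_log.jsonl --out-dir ./eval_output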