"""
Evaluation and visualisation for CricketCaptain-LLM.
Produces:
1. Reward curves — r_cric, r_coherence, r_tools, r_format vs training steps
2. Coherence heatmap — episode × turn matrix (diagonal banding = learning)
3. Tool usage timeline — scatter: over number → tool called
4. Strategy text samples — early vs late training qualitative comparison
5. Before/after canonical — response comparison on a fixed state (Over 35, 180/3)
Usage:
export CRICKET_CAPTAIN_ENV_URL="ws://<reachable-host>/ws"
python eval.py --checkpoint ./checkpoints/stage2_final --episodes 50 --task eval_50over
"""
import argparse
import asyncio
import json
import os
import statistics
from pathlib import Path
from typing import Any
try:
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np
_PLOT_AVAILABLE = True
except ImportError:
_PLOT_AVAILABLE = False
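# Dual import path below: supports running this file directly as a script and
# as part of the cricket_captain package.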
try:
from client import CricketCaptainEnv
from models import CricketAction
    from inference import _parse_action, RandomAgent
except ImportError:
from cricket_captain.client import CricketCaptainEnv
from cricket_captain.models import CricketAction
    from cricket_captain.inference import _parse_action, RandomAgent
# ------------------------------------------------------------------ #
# Data collection #
# ------------------------------------------------------------------ #
async def collect_eval_episodes(
env_url: str,
agent,
n_episodes: int,
task: str,
eval_pack_id: str = "default",
opponent_mode: str = "heuristic",
max_overs: int | None = None,
) -> list[dict[str, Any]]:
"""Run n_episodes and return raw episode data for visualisation."""
episodes = []
async with CricketCaptainEnv(env_url) as env:
for ep in range(n_episodes):
# OpenEnv server routes reset params via `options`.
result = await env.reset(options={
"task": task,
"random_start": False,
"eval_pack_id": eval_pack_id,
"opponent_mode": opponent_mode,
"max_overs": max_overs,
})
obs = result.observation
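            # Observation fields consumed below: prompt_text (fed to the agent),
            # game_state, game_context, strategic_phase, opponent_plan,
            # declared_strategy.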
step_data = []
turn = 0
            while not result.done and turn < 600:  # hard cap on episode length as a safety net
messages = [{"role": "user", "content": obs.prompt_text}]
raw = agent(messages)
action, err = _parse_action(raw)
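                # Unparseable model output → fall back to a safe default action
                # for the current game state so the episode can continue.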
if err or action is None:
if obs.game_state == "bowling":
action = CricketAction(tool="bowl_delivery", arguments={})
elif obs.game_state == "toss":
action = CricketAction(tool="call_toss", arguments={"call": "heads", "decision": "bat"})
else:
action = CricketAction(
tool="play_delivery",
arguments={"shot_intent": "defensive", "explanation": "fallback"},
)
                ctx = obs.game_context.copy()
                # Step the environment first so the logged reward corresponds
                # to this action rather than the preceding one.
                result = await env.step(action)
                step_data.append({
                    "turn": turn,
                    "over": ctx.get("over", 0),
                    "tool": action.tool,
                    "shot_intent": action.arguments.get("shot_intent", ""),
                    "strategic_phase": obs.strategic_phase,
                    "opponent_plan": obs.opponent_plan,
                    "rationale": obs.declared_strategy.get("rationale", ""),
                    "raw_response": raw,
                    "reward": result.reward or 0.0,
                    "parse_error": err,
                })
                obs = result.observation
                turn += 1
state = await env.state()
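            # Final env state carries the per-turn score series (coherence,
            # adaptation, opponent awareness, regret) plus match totals.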
episodes.append({
"episode": ep,
"steps": step_data,
"coherence_scores": state.coherence_scores,
"adaptation_scores": state.adaptation_scores,
"opponent_awareness_scores": state.opponent_awareness_scores,
"regret_scores": state.regret_scores,
"total_score": state.total_score,
"wickets_lost": state.wickets_lost,
"tool_calls": state.tool_calls_made,
"transcript": state.transcript,
})
print(f" Episode {ep+1}/{n_episodes} | "
f"Score: {state.total_score}/{state.wickets_lost} | "
f"Mean coherence: {statistics.mean(state.coherence_scores) if state.coherence_scores else 0:.3f}")
return episodes
# ------------------------------------------------------------------ #
# Visualisations #
# ------------------------------------------------------------------ #
def plot_coherence_heatmap(episodes: list[dict], out_dir: Path):
if not _PLOT_AVAILABLE:
print("matplotlib not available — skipping coherence heatmap")
return
max_turns = max(len(ep["coherence_scores"]) for ep in episodes) if episodes else 1
matrix = []
for ep in episodes:
row = ep["coherence_scores"][:max_turns]
row += [float("nan")] * (max_turns - len(row))
matrix.append(row)
arr = np.array(matrix, dtype=float)
fig, ax = plt.subplots(figsize=(min(max_turns // 3 + 2, 20), max(len(episodes) // 4, 4)))
im = ax.imshow(arr, aspect="auto", cmap="YlOrRd", vmin=0.0, vmax=1.0)
plt.colorbar(im, ax=ax, label="Coherence score")
ax.set_xlabel("Turn (delivery number)")
ax.set_ylabel("Episode")
ax.set_title("Coherence Heatmap\n(diagonal banding → strategic consistency across episode)")
plt.tight_layout()
path = out_dir / "coherence_heatmap.png"
plt.savefig(path, dpi=150)
plt.close()
print(f"Saved: {path}")
def plot_tool_usage_timeline(episodes: list[dict], out_dir: Path):
if not _PLOT_AVAILABLE:
print("matplotlib not available — skipping tool timeline")
return
tool_colors = {
"set_strategy": "steelblue",
"plan_shot": "royalblue",
"choose_bowler": "seagreen",
"plan_delivery": "mediumseagreen",
"analyze_situation": "darkorange",
"play_delivery": "lightgray",
"bowl_delivery": "lightgray",
"reflect_after_ball": "purple",
}
fig, ax = plt.subplots(figsize=(14, 5))
for ep_idx, ep in enumerate(episodes[:20]): # cap at 20 episodes for readability
for step in ep["steps"]:
tool = step["tool"]
if tool in ("play_delivery", "bowl_delivery"):
continue # too many — skip for clarity
ax.scatter(step["over"], ep_idx,
color=tool_colors.get(tool, "black"),
marker="D" if tool == "set_strategy" else "o",
s=60, alpha=0.8, zorder=3)
    from matplotlib.lines import Line2D
    from matplotlib.patches import Patch
    legend_elements = [
        Patch(color="steelblue", label="set_strategy"),
        Patch(color="royalblue", label="plan_shot"),
        Patch(color="seagreen", label="choose_bowler"),
        Patch(color="mediumseagreen", label="plan_delivery"),
        Patch(color="darkorange", label="analyze_situation"),
        Patch(color="purple", label="reflect_after_ball"),
        Line2D([0], [0], color="gray", linestyle="--", label="phase transition"),
    ]
    # Dashed verticals mark the phase-transition overs; draw them before the
    # legend so the handle above actually represents them.
    for boundary in (6, 16, 36):
        ax.axvline(boundary, color="gray", linestyle="--", alpha=0.4)
    ax.legend(handles=legend_elements, loc="upper right")
ax.set_xlabel("Over number")
ax.set_ylabel("Episode")
ax.set_title("Tool Usage Timeline\n(ideal: strategy + analysis cluster at phase transitions)")
plt.tight_layout()
path = out_dir / "tool_usage_timeline.png"
plt.savefig(path, dpi=150)
plt.close()
print(f"Saved: {path}")
def plot_reward_curve(log_file: str | None, out_dir: Path):
"""Plot reward curves from a JSONL training log."""
if not _PLOT_AVAILABLE:
print("matplotlib not available — skipping reward curve")
return
if not log_file or not os.path.exists(log_file):
print(f"No log file found at {log_file} — skipping reward curve")
return
steps, r_cric, r_coh, r_tools, r_fmt, composite = [], [], [], [], [], []
with open(log_file) as f:
for line in f:
            try:
                d = json.loads(line)
                if "r_cric" not in d:
                    continue
                record = (d.get("step", len(steps)), d["r_cric"], d["r_coherence"],
                          d["r_tools"], d["r_format"], d["composite"])
            except (json.JSONDecodeError, KeyError):
                continue
            # Append only after the whole record parses, so a partial log line
            # cannot leave the series misaligned.
            for series, value in zip((steps, r_cric, r_coh, r_tools, r_fmt, composite), record):
                series.append(value)
if not steps:
print("No reward data in log file")
return
fig, axes = plt.subplots(2, 3, figsize=(16, 8))
for ax, (label, vals) in zip(
axes.flat,
[("r_cric", r_cric), ("r_coherence", r_coh), ("r_tools", r_tools),
("r_format", r_fmt), ("composite", composite)],
):
ax.plot(steps, vals, linewidth=1.5)
ax.set_title(label)
ax.set_xlabel("Training step")
ax.set_ylabel("Reward")
ax.grid(True, alpha=0.3)
axes.flat[-1].axis("off")
plt.suptitle("CricketCaptain-LLM Training Reward Curves")
plt.tight_layout()
path = out_dir / "reward_curves.png"
plt.savefig(path, dpi=150)
plt.close()
print(f"Saved: {path}")
def print_strategy_samples(episodes: list[dict], label: str = ""):
"""Print sample strategy declarations for qualitative analysis."""
print(f"\n=== Strategy Samples {label} ===")
seen = set()
for ep in episodes[:10]:
for step in ep["steps"]:
r = step.get("rationale", "").strip()
if r and r not in seen and step["tool"] == "set_strategy":
seen.add(r)
print(f" [{ep['episode']}:{step['over']}] {r}")
if len(seen) >= 8:
return
def print_summary(episodes: list[dict]):
all_coherence = [s for ep in episodes for s in ep["coherence_scores"]]
all_adaptation = [s for ep in episodes for s in ep.get("adaptation_scores", [])]
all_awareness = [s for ep in episodes for s in ep.get("opponent_awareness_scores", [])]
all_regret = [s for ep in episodes for s in ep.get("regret_scores", [])]
all_scores = [ep["total_score"] for ep in episodes]
all_wickets = [ep["wickets_lost"] for ep in episodes]
print("\n=== Evaluation Summary ===")
print(f" Episodes: {len(episodes)}")
print(f" Mean score: {statistics.mean(all_scores):.1f} ± {statistics.stdev(all_scores) if len(all_scores)>1 else 0:.1f}")
print(f" Mean wickets: {statistics.mean(all_wickets):.1f}")
print(f" Mean coherence: {statistics.mean(all_coherence) if all_coherence else 0:.4f}")
print(f" Mean adaptation: {statistics.mean(all_adaptation) if all_adaptation else 0:.4f}")
print(f" Mean opp aware: {statistics.mean(all_awareness) if all_awareness else 0:.4f}")
print(f" Mean regret score: {statistics.mean(all_regret) if all_regret else 0:.4f}")
print(f" Coherence stdev: {statistics.stdev(all_coherence) if len(all_coherence)>1 else 0:.4f}")
# ------------------------------------------------------------------ #
# CLI #
# ------------------------------------------------------------------ #
async def _run_eval(args):
agent = RandomAgent()
out_dir = Path(args.out_dir)
out_dir.mkdir(parents=True, exist_ok=True)
print(f"Collecting {args.episodes} evaluation episodes...")
episodes = await collect_eval_episodes(
args.env_url, agent, args.episodes, args.task, args.eval_pack_id, args.opponent_mode, args.max_overs
)
print_summary(episodes)
print_strategy_samples(episodes, label=f"(task={args.task})")
if _PLOT_AVAILABLE:
plot_coherence_heatmap(episodes, out_dir)
plot_tool_usage_timeline(episodes, out_dir)
plot_reward_curve(args.log_file, out_dir)
print(f"\nPlots saved to {out_dir}/")
else:
print("Install matplotlib+numpy for visualisations: pip install matplotlib numpy")
# Dump raw data
raw_path = out_dir / "eval_episodes.jsonl"
with open(raw_path, "w") as f:
for ep in episodes:
f.write(json.dumps(ep) + "\n")
print(f"Raw data saved to {raw_path}")
def main():
parser = argparse.ArgumentParser(description="CricketCaptain-LLM Evaluation")
parser.add_argument("--config", default=None, help="YAML config path (runner defaults).")
parser.add_argument("--episodes", type=int, default=20)
parser.add_argument("--task", default="medium",
choices=["easy", "medium", "hard", "stage2_full", "eval_50over"])
parser.add_argument("--env-url", default=os.environ.get("CRICKET_CAPTAIN_ENV_URL", "ws://localhost:8000"))
parser.add_argument("--eval-pack-id", default=os.environ.get("CRICKET_EVAL_PACK_ID", "default"))
parser.add_argument("--opponent-mode", default=os.environ.get("CRICKET_OPPONENT_MODE", "heuristic"),
choices=["heuristic", "llm_live", "llm_cached"])
parser.add_argument("--max-overs", type=int, default=None,
help="Limit innings length for fast experiments (e.g. 5).")
parser.add_argument("--out-dir", default="./eval_output")
parser.add_argument("--log-file", default=None,
help="Path to JSONL training log for reward curves")
parser.add_argument("--checkpoint", default=None,
help="Checkpoint path (unused by default — agent is random baseline)")
args = parser.parse_args()
if args.config:
try:
from config_yaml import load_config, apply_runner_config_defaults
except ImportError:
from cricket_captain.config_yaml import load_config, apply_runner_config_defaults
defaults = apply_runner_config_defaults(load_config(args.config))
if args.env_url == os.environ.get("CRICKET_CAPTAIN_ENV_URL", "ws://localhost:8000") and defaults.env_url:
args.env_url = defaults.env_url
if args.eval_pack_id == os.environ.get("CRICKET_EVAL_PACK_ID", "default") and defaults.eval_pack_id:
args.eval_pack_id = defaults.eval_pack_id
if args.opponent_mode == os.environ.get("CRICKET_OPPONENT_MODE", "heuristic") and defaults.opponent_mode:
args.opponent_mode = defaults.opponent_mode
if args.max_overs is None and defaults.max_overs is not None:
args.max_overs = int(defaults.max_overs)
asyncio.run(_run_eval(args))
if __name__ == "__main__":
main()