# sync: pull latest from main (model_server.py, captain LLM toggle in ui.py, 0.6B configs, SUBMISSION + RUNTIME_DURABILITY docs)
"""
Evaluation and visualisation for CricketCaptain-LLM.
Produces:
1. Reward curves — r_cric, r_coherence, r_tools, r_format vs training steps
2. Coherence heatmap — episode × turn matrix (diagonal banding = learning)
3. Tool usage timeline — scatter: over number → tool called
4. Strategy text samples — early vs late training qualitative comparison
5. Before/after canonical — fixed state (Over 35, 180/3) response comparison
Usage:
export CRICKET_CAPTAIN_ENV_URL="ws://<reachable-host>/ws"
python eval.py --checkpoint ./checkpoints/stage2_final --episodes 50 --task eval_50over
"""
import argparse
import asyncio
import json
import os
import statistics
from pathlib import Path
from typing import Any
try:
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np
_PLOT_AVAILABLE = True
except ImportError:
_PLOT_AVAILABLE = False
try:
from client import CricketCaptainEnv
from models import CricketAction
from inference import _parse_action, SHOT_AGGRESSION_ORDER, RandomAgent
except ImportError:
from cricket_captain.client import CricketCaptainEnv
from cricket_captain.models import CricketAction
from cricket_captain.inference import _parse_action, SHOT_AGGRESSION_ORDER, RandomAgent
# ------------------------------------------------------------------ #
# Data collection #
# ------------------------------------------------------------------ #
async def collect_eval_episodes(
    env_url: str,
    agent,
    n_episodes: int,
    task: str,
    eval_pack_id: str = "default",
    opponent_mode: str = "heuristic",
    max_overs: int | None = None,
) -> list[dict[str, Any]]:
    """Run n_episodes and return raw episode data for visualisation.

    Args:
        env_url: WebSocket URL of the CricketCaptain environment server.
        agent: Callable invoked as ``agent(messages)`` with a chat-message
            list; returns the raw response string fed to ``_parse_action``
            (e.g. ``RandomAgent`` from ``inference``).
        n_episodes: Number of episodes to play.
        task: Task/difficulty identifier forwarded to ``env.reset``.
        eval_pack_id: Evaluation pack id forwarded to the server.
        opponent_mode: Opponent behaviour mode forwarded to the server.
        max_overs: Optional innings-length cap for fast experiments.

    Returns:
        One dict per episode: per-step records plus end-of-episode state
        (coherence/adaptation/awareness/regret scores, score, transcript).
    """
    episodes = []
    async with CricketCaptainEnv(env_url) as env:
        for ep in range(n_episodes):
            # OpenEnv server routes reset params via `options`.
            result = await env.reset(options={
                "task": task,
                "random_start": False,
                "eval_pack_id": eval_pack_id,
                "opponent_mode": opponent_mode,
                "max_overs": max_overs,
            })
            obs = result.observation
            history = []  # NOTE(review): never appended to — apparently unused.
            step_data = []
            turn = 0
            # Hard cap of 600 turns guards against episodes that never
            # report `done` (e.g. a wedged server).
            while not result.done and turn < 600:
                messages = [{"role": "user", "content": obs.prompt_text}]
                raw = agent(messages)
                action, err = _parse_action(raw)
                if err or action is None:
                    # Unparseable response: substitute a safe fallback action
                    # matching the current game state so the episode continues.
                    if obs.game_state == "bowling":
                        action = CricketAction(tool="bowl_delivery", arguments={})
                    elif obs.game_state == "toss":
                        action = CricketAction(tool="call_toss", arguments={"call": "heads", "decision": "bat"})
                    else:
                        action = CricketAction(
                            tool="play_delivery",
                            arguments={"shot_intent": "defensive", "explanation": "fallback"},
                        )
                # Copy so later server-side mutation can't alias into our record.
                ctx = obs.game_context.copy()
                # NOTE(review): "reward" here comes from the PREVIOUS step's
                # result, recorded next to the action about to be taken —
                # confirm this off-by-one pairing is intended downstream.
                step_data.append({
                    "turn": turn,
                    "over": ctx.get("over", 0),
                    "tool": action.tool,
                    "shot_intent": action.arguments.get("shot_intent", ""),
                    "strategic_phase": obs.strategic_phase,
                    "opponent_plan": obs.opponent_plan,
                    "rationale": obs.declared_strategy.get("rationale", ""),
                    "raw_response": raw,
                    "reward": result.reward or 0.0,
                    "parse_error": err,
                })
                result = await env.step(action)
                obs = result.observation
                turn += 1
            # Final server-side episode state (scores, transcript, counters).
            state = await env.state()
            episodes.append({
                "episode": ep,
                "steps": step_data,
                "coherence_scores": state.coherence_scores,
                "adaptation_scores": state.adaptation_scores,
                "opponent_awareness_scores": state.opponent_awareness_scores,
                "regret_scores": state.regret_scores,
                "total_score": state.total_score,
                "wickets_lost": state.wickets_lost,
                "tool_calls": state.tool_calls_made,
                "transcript": state.transcript,
            })
            print(f" Episode {ep+1}/{n_episodes} | "
                  f"Score: {state.total_score}/{state.wickets_lost} | "
                  f"Mean coherence: {statistics.mean(state.coherence_scores) if state.coherence_scores else 0:.3f}")
    return episodes
# ------------------------------------------------------------------ #
# Visualisations #
# ------------------------------------------------------------------ #
def plot_coherence_heatmap(episodes: list[dict], out_dir: Path):
    """Render an episode x turn heatmap of per-turn coherence scores.

    Rows are episodes, columns are turns; shorter episodes are padded with
    NaN (rendered blank). Writes `coherence_heatmap.png` into `out_dir`.

    Args:
        episodes: Episode dicts from `collect_eval_episodes`.
        out_dir: Directory the PNG is written to.
    """
    if not _PLOT_AVAILABLE:
        print("matplotlib not available — skipping coherence heatmap")
        return
    if not episodes:
        # Fix: np.array([]) is 1-D and ax.imshow would raise — nothing to draw.
        print("No episodes — skipping coherence heatmap")
        return
    max_turns = max(len(ep["coherence_scores"]) for ep in episodes)
    matrix = []
    for ep in episodes:
        # Slice-copy first so the NaN padding never mutates the episode data.
        row = ep["coherence_scores"][:max_turns]
        row += [float("nan")] * (max_turns - len(row))
        matrix.append(row)
    arr = np.array(matrix, dtype=float)
    # Figure size scales with data but is capped at width 20 / floored at height 4.
    fig, ax = plt.subplots(figsize=(min(max_turns // 3 + 2, 20), max(len(episodes) // 4, 4)))
    im = ax.imshow(arr, aspect="auto", cmap="YlOrRd", vmin=0.0, vmax=1.0)
    plt.colorbar(im, ax=ax, label="Coherence score")
    ax.set_xlabel("Turn (delivery number)")
    ax.set_ylabel("Episode")
    ax.set_title("Coherence Heatmap\n(diagonal banding → strategic consistency across episode)")
    plt.tight_layout()
    path = out_dir / "coherence_heatmap.png"
    plt.savefig(path, dpi=150)
    plt.close()
    print(f"Saved: {path}")
def plot_tool_usage_timeline(episodes: list[dict], out_dir: Path):
    """Scatter each strategic tool call against the over it occurred in.

    Delivery tools (`play_delivery`/`bowl_delivery`) fire every ball and would
    swamp the plot, so they are skipped. Only the first 20 episodes are drawn.
    Writes `tool_usage_timeline.png` into `out_dir`.

    Args:
        episodes: Episode dicts from `collect_eval_episodes`.
        out_dir: Directory the PNG is written to.
    """
    if not _PLOT_AVAILABLE:
        print("matplotlib not available — skipping tool timeline")
        return
    # Local import: matplotlib is an optional dependency (see module top).
    from matplotlib.patches import Patch

    tool_colors = {
        "set_strategy": "steelblue",
        "plan_shot": "royalblue",
        "choose_bowler": "seagreen",
        "plan_delivery": "mediumseagreen",
        "analyze_situation": "darkorange",
        "play_delivery": "lightgray",
        "bowl_delivery": "lightgray",
        "reflect_after_ball": "purple",
    }
    delivery_tools = ("play_delivery", "bowl_delivery")
    fig, ax = plt.subplots(figsize=(14, 5))
    for ep_idx, ep in enumerate(episodes[:20]):  # cap at 20 episodes for readability
        for step in ep["steps"]:
            tool = step["tool"]
            if tool in delivery_tools:
                continue  # too many — skip for clarity
            ax.scatter(step["over"], ep_idx,
                       color=tool_colors.get(tool, "black"),
                       marker="D" if tool == "set_strategy" else "o",
                       s=60, alpha=0.8, zorder=3)
    # Fix: derive the legend from tool_colors so the two cannot drift apart
    # (the original maintained a second hard-coded list of Patch entries).
    legend_elements = [
        Patch(color=color, label=tool)
        for tool, color in tool_colors.items()
        if tool not in delivery_tools
    ]
    ax.legend(handles=legend_elements, loc="upper right")
    # Phase-transition markers at overs 6/16/36. No `label=` kwarg here: the
    # legend above is built from explicit handles, so the original's
    # label="Phase transition" was dead and never displayed.
    for over in (6, 16, 36):
        ax.axvline(over, color="gray", linestyle="--", alpha=0.4)
    ax.set_xlabel("Over number")
    ax.set_ylabel("Episode")
    ax.set_title("Tool Usage Timeline\n(ideal: strategy + analysis cluster at phase transitions)")
    plt.tight_layout()
    path = out_dir / "tool_usage_timeline.png"
    plt.savefig(path, dpi=150)
    plt.close()
    print(f"Saved: {path}")
def plot_reward_curve(log_file: str | None, out_dir: Path):
    """Plot reward curves from a JSONL training log.

    Each log line is expected to be a JSON object carrying `r_cric`,
    `r_coherence`, `r_tools`, `r_format`, `composite` and optionally `step`.
    Malformed or incomplete lines are skipped. Writes `reward_curves.png`
    into `out_dir`.

    Args:
        log_file: Path to the JSONL training log, or None to skip plotting.
        out_dir: Directory the PNG is written to.
    """
    if not _PLOT_AVAILABLE:
        print("matplotlib not available — skipping reward curve")
        return
    if not log_file or not os.path.exists(log_file):
        print(f"No log file found at {log_file} — skipping reward curve")
        return
    steps, r_cric, r_coh, r_tools, r_fmt, composite = [], [], [], [], [], []
    with open(log_file) as f:
        for line in f:
            try:
                d = json.loads(line)
            except json.JSONDecodeError:
                continue
            if "r_cric" not in d:
                continue
            try:
                # Fix: extract EVERY field before appending anything. The
                # original appended `step` first, so a record missing a later
                # key raised KeyError mid-append and left the series lists
                # misaligned (different lengths → broken plot).
                record = (d.get("step", len(steps)), d["r_cric"], d["r_coherence"],
                          d["r_tools"], d["r_format"], d["composite"])
            except KeyError:
                continue
            for series, value in zip((steps, r_cric, r_coh, r_tools, r_fmt, composite), record):
                series.append(value)
    if not steps:
        print("No reward data in log file")
        return
    fig, axes = plt.subplots(2, 3, figsize=(16, 8))
    for ax, (label, vals) in zip(
        axes.flat,
        [("r_cric", r_cric), ("r_coherence", r_coh), ("r_tools", r_tools),
         ("r_format", r_fmt), ("composite", composite)],
    ):
        ax.plot(steps, vals, linewidth=1.5)
        ax.set_title(label)
        ax.set_xlabel("Training step")
        ax.set_ylabel("Reward")
        ax.grid(True, alpha=0.3)
    axes.flat[-1].axis("off")  # six panels, five series — blank the spare
    plt.suptitle("CricketCaptain-LLM Training Reward Curves")
    plt.tight_layout()
    path = out_dir / "reward_curves.png"
    plt.savefig(path, dpi=150)
    plt.close()
    print(f"Saved: {path}")
def print_strategy_samples(episodes: list[dict], label: str = ""):
    """Print up to eight distinct set_strategy rationales for qualitative review.

    Scans at most the first ten episodes; blank and duplicate rationales are
    skipped. Output lines are tagged `[episode:over]`.

    Args:
        episodes: Episode dicts from `collect_eval_episodes`.
        label: Optional suffix for the section header.
    """
    print(f"\n=== Strategy Samples {label} ===")
    printed: set[str] = set()
    for episode in episodes[:10]:
        for step in episode["steps"]:
            rationale = step.get("rationale", "").strip()
            if not rationale or rationale in printed:
                continue
            if step["tool"] != "set_strategy":
                continue
            printed.add(rationale)
            print(f" [{episode['episode']}:{step['over']}] {rationale}")
            if len(printed) >= 8:
                return
def print_summary(episodes: list[dict]):
    """Print aggregate evaluation statistics across collected episodes.

    Robust to an empty episode list and to episodes missing score lists:
    missing lists contribute nothing and empty aggregates print as 0
    (the original raised StatisticsError on `mean([])` and KeyError on an
    episode without "coherence_scores").

    Args:
        episodes: Episode dicts from `collect_eval_episodes`; each must carry
            "total_score" and "wickets_lost".
    """
    def _mean(vals):
        # statistics.mean raises StatisticsError on empty input; report 0.
        return statistics.mean(vals) if vals else 0

    def _stdev(vals):
        # statistics.stdev needs at least two samples.
        return statistics.stdev(vals) if len(vals) > 1 else 0

    # Fix: use .get(...) consistently — the original indexed
    # ep["coherence_scores"] directly while the sibling lists used .get,
    # so one malformed episode crashed the whole summary.
    all_coherence = [s for ep in episodes for s in ep.get("coherence_scores", [])]
    all_adaptation = [s for ep in episodes for s in ep.get("adaptation_scores", [])]
    all_awareness = [s for ep in episodes for s in ep.get("opponent_awareness_scores", [])]
    all_regret = [s for ep in episodes for s in ep.get("regret_scores", [])]
    all_scores = [ep["total_score"] for ep in episodes]
    all_wickets = [ep["wickets_lost"] for ep in episodes]
    print("\n=== Evaluation Summary ===")
    print(f" Episodes: {len(episodes)}")
    print(f" Mean score: {_mean(all_scores):.1f} ± {_stdev(all_scores):.1f}")
    print(f" Mean wickets: {_mean(all_wickets):.1f}")
    print(f" Mean coherence: {_mean(all_coherence):.4f}")
    print(f" Mean adaptation: {_mean(all_adaptation):.4f}")
    print(f" Mean opp aware: {_mean(all_awareness):.4f}")
    print(f" Mean regret score: {_mean(all_regret):.4f}")
    print(f" Coherence stdev: {_stdev(all_coherence):.4f}")
# ------------------------------------------------------------------ #
# CLI #
# ------------------------------------------------------------------ #
async def _run_eval(args):
    """Collect evaluation episodes with the random-baseline agent, then
    print summaries, render plots (if matplotlib is installed), and dump
    raw episode data as JSONL.

    Args:
        args: Parsed CLI namespace from `main()` (env_url, episodes, task,
            eval_pack_id, opponent_mode, max_overs, out_dir, log_file).
    """
    # NOTE(review): `args.checkpoint` is accepted by the CLI but never read
    # here — the agent is always the random baseline, as the --checkpoint
    # help text states.
    agent = RandomAgent()
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    print(f"Collecting {args.episodes} evaluation episodes...")
    episodes = await collect_eval_episodes(
        args.env_url, agent, args.episodes, args.task, args.eval_pack_id, args.opponent_mode, args.max_overs
    )
    print_summary(episodes)
    print_strategy_samples(episodes, label=f"(task={args.task})")
    if _PLOT_AVAILABLE:
        plot_coherence_heatmap(episodes, out_dir)
        plot_tool_usage_timeline(episodes, out_dir)
        plot_reward_curve(args.log_file, out_dir)
        print(f"\nPlots saved to {out_dir}/")
    else:
        print("Install matplotlib+numpy for visualisations: pip install matplotlib numpy")
    # Dump raw data — one JSON object per episode (JSONL).
    raw_path = out_dir / "eval_episodes.jsonl"
    with open(raw_path, "w") as f:
        for ep in episodes:
            f.write(json.dumps(ep) + "\n")
    print(f"Raw data saved to {raw_path}")
def main() -> None:
    """Parse CLI arguments, layer in optional YAML config defaults, and run
    the evaluation loop via asyncio."""
    parser = argparse.ArgumentParser(description="CricketCaptain-LLM Evaluation")
    parser.add_argument("--config", default=None, help="YAML config path (runner defaults).")
    parser.add_argument("--episodes", type=int, default=20)
    parser.add_argument("--task", default="medium",
                        choices=["easy", "medium", "hard", "stage2_full", "eval_50over"])
    parser.add_argument("--env-url", default=os.environ.get("CRICKET_CAPTAIN_ENV_URL", "ws://localhost:8000"))
    parser.add_argument("--eval-pack-id", default=os.environ.get("CRICKET_EVAL_PACK_ID", "default"))
    parser.add_argument("--opponent-mode", default=os.environ.get("CRICKET_OPPONENT_MODE", "heuristic"),
                        choices=["heuristic", "llm_live", "llm_cached"])
    parser.add_argument("--max-overs", type=int, default=None,
                        help="Limit innings length for fast experiments (e.g. 5).")
    parser.add_argument("--out-dir", default="./eval_output")
    parser.add_argument("--log-file", default=None,
                        help="Path to JSONL training log for reward curves")
    parser.add_argument("--checkpoint", default=None,
                        help="Checkpoint path (unused by default — agent is random baseline)")
    args = parser.parse_args()
    if args.config:
        # Flat-layout / package-layout import fallback, mirroring the
        # try/except pattern at the top of the module.
        try:
            from config_yaml import load_config, apply_runner_config_defaults
        except ImportError:
            from cricket_captain.config_yaml import load_config, apply_runner_config_defaults
        defaults = apply_runner_config_defaults(load_config(args.config))
        # Config values only fill settings the user did NOT override on the
        # CLI: a parsed value still equal to its env/hard-coded default is
        # treated as "not explicitly set" and may be replaced. NOTE(review):
        # a user who explicitly passes the default value on the CLI is
        # indistinguishable from one who omitted the flag — confirm this
        # precedence is intended.
        if args.env_url == os.environ.get("CRICKET_CAPTAIN_ENV_URL", "ws://localhost:8000") and defaults.env_url:
            args.env_url = defaults.env_url
        if args.eval_pack_id == os.environ.get("CRICKET_EVAL_PACK_ID", "default") and defaults.eval_pack_id:
            args.eval_pack_id = defaults.eval_pack_id
        if args.opponent_mode == os.environ.get("CRICKET_OPPONENT_MODE", "heuristic") and defaults.opponent_mode:
            args.opponent_mode = defaults.opponent_mode
        if args.max_overs is None and defaults.max_overs is not None:
            args.max_overs = int(defaults.max_overs)
    asyncio.run(_run_eval(args))
if __name__ == "__main__":
main()