File size: 15,149 Bytes
6c96042
 
 
 
 
 
 
 
 
 
 
12fc392
 
6c96042
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12fc392
 
4d4439f
6c96042
 
 
 
 
4d4439f
 
 
 
 
 
 
 
6c96042
 
 
 
 
12fc392
6c96042
 
 
 
12fc392
 
 
 
 
 
 
 
 
6c96042
 
 
 
 
 
 
12fc392
 
6c96042
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12fc392
 
 
6c96042
 
 
12fc392
6c96042
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12fc392
 
 
6c96042
 
12fc392
 
6c96042
 
 
 
 
12fc392
6c96042
 
 
 
 
 
 
 
 
12fc392
 
 
6c96042
12fc392
6c96042
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12fc392
 
 
6c96042
 
 
 
 
 
 
 
12fc392
 
 
6c96042
 
 
 
 
 
 
 
 
 
 
 
 
 
4d4439f
6c96042
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d8f3415
6c96042
e70c305
 
12fc392
 
 
 
4d4439f
 
6c96042
 
 
 
 
 
d8f3415
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6c96042
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
"""
Evaluation and visualisation for CricketCaptain-LLM.

Produces:
  1. Reward curves          — r_cric, r_coherence, r_tools, r_format vs training steps
  2. Coherence heatmap      — episode × turn matrix (diagonal banding = learning)
  3. Tool usage timeline    — scatter: over number → tool called
  4. Strategy text samples  — early vs late training qualitative comparison
  5. Before/after canonical — fixed state (Over 35, 180/3) response comparison

Usage:
  export CRICKET_CAPTAIN_ENV_URL="ws://<reachable-host>/ws"
  python eval.py --checkpoint ./checkpoints/stage2_final --episodes 50 --task eval_50over
"""

import argparse
import asyncio
import json
import os
import statistics
from pathlib import Path
from typing import Any

try:
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    import matplotlib.colors as mcolors
    import numpy as np
    _PLOT_AVAILABLE = True
except ImportError:
    _PLOT_AVAILABLE = False

try:
    from client import CricketCaptainEnv
    from models import CricketAction
    from inference import _parse_action, SHOT_AGGRESSION_ORDER, RandomAgent
except ImportError:
    from cricket_captain.client import CricketCaptainEnv
    from cricket_captain.models import CricketAction
    from cricket_captain.inference import _parse_action, SHOT_AGGRESSION_ORDER, RandomAgent


# ------------------------------------------------------------------ #
# Data collection                                                      #
# ------------------------------------------------------------------ #

async def collect_eval_episodes(
    env_url: str,
    agent,
    n_episodes: int,
    task: str,
    eval_pack_id: str = "default",
    opponent_mode: str = "heuristic",
    max_overs: int | None = None,
) -> list[dict[str, Any]]:
    """Run n_episodes against the environment and return raw episode data.

    Each episode: reset the env with the given task/pack/opponent settings,
    query the agent turn-by-turn until the episode ends (or a 600-turn
    safety cap is hit), record a per-turn step dict, then pull the env's
    final per-episode scores via ``env.state()``.

    Args:
        env_url: WebSocket URL of the CricketCaptain environment server.
        agent: Callable taking a chat-message list and returning raw text,
            which is parsed into a ``CricketAction`` by ``_parse_action``.
        n_episodes: Number of episodes to collect.
        task: Task name forwarded to the env reset.
        eval_pack_id: Evaluation pack id forwarded to the env reset.
        opponent_mode: Opponent behaviour mode forwarded to the env reset.
        max_overs: Optional innings-length cap forwarded to the env reset.

    Returns:
        One dict per episode containing ``steps`` (per-turn records) plus
        the env's final score/coherence/adaptation/regret summaries.
    """
    episodes = []
    async with CricketCaptainEnv(env_url) as env:
        for ep in range(n_episodes):
            # OpenEnv server routes reset params via `options`.
            result = await env.reset(options={
                "task": task,
                "random_start": False,
                "eval_pack_id": eval_pack_id,
                "opponent_mode": opponent_mode,
                "max_overs": max_overs,
            })
            obs = result.observation
            history = []
            step_data = []
            turn = 0

            # Hard 600-turn cap guards against episodes that never set `done`.
            while not result.done and turn < 600:
                messages = [{"role": "user", "content": obs.prompt_text}]
                raw = agent(messages)
                action, err = _parse_action(raw)
                # Unparseable responses fall back to a safe default action
                # appropriate to the current game state, so the episode can
                # continue instead of stalling.
                if err or action is None:
                    if obs.game_state == "bowling":
                        action = CricketAction(tool="bowl_delivery", arguments={})
                    elif obs.game_state == "toss":
                        action = CricketAction(tool="call_toss", arguments={"call": "heads", "decision": "bat"})
                    else:
                        action = CricketAction(
                            tool="play_delivery",
                            arguments={"shot_intent": "defensive", "explanation": "fallback"},
                        )

                ctx = obs.game_context.copy()
                # NOTE(review): `result` here is still the value from the
                # previous step (or reset), so "reward" lags this action by
                # one turn — confirm this is the intended attribution.
                step_data.append({
                    "turn": turn,
                    "over": ctx.get("over", 0),
                    "tool": action.tool,
                    "shot_intent": action.arguments.get("shot_intent", ""),
                    "strategic_phase": obs.strategic_phase,
                    "opponent_plan": obs.opponent_plan,
                    "rationale": obs.declared_strategy.get("rationale", ""),
                    "raw_response": raw,
                    "reward": result.reward or 0.0,
                    "parse_error": err,
                })

                result = await env.step(action)
                obs = result.observation
                turn += 1

            # Final env-side summary scores for this episode.
            state = await env.state()
            episodes.append({
                "episode": ep,
                "steps": step_data,
                "coherence_scores": state.coherence_scores,
                "adaptation_scores": state.adaptation_scores,
                "opponent_awareness_scores": state.opponent_awareness_scores,
                "regret_scores": state.regret_scores,
                "total_score": state.total_score,
                "wickets_lost": state.wickets_lost,
                "tool_calls": state.tool_calls_made,
                "transcript": state.transcript,
            })
            print(f"  Episode {ep+1}/{n_episodes} | "
                  f"Score: {state.total_score}/{state.wickets_lost} | "
                  f"Mean coherence: {statistics.mean(state.coherence_scores) if state.coherence_scores else 0:.3f}")
    return episodes


# ------------------------------------------------------------------ #
# Visualisations                                                       #
# ------------------------------------------------------------------ #

def plot_coherence_heatmap(episodes: list[dict], out_dir: Path):
    """Render an episode × turn heatmap of per-turn coherence scores.

    Rows are episodes and columns are turns; shorter episodes are padded
    with NaN so they render as blanks. Saves ``coherence_heatmap.png``
    into *out_dir*.
    """
    if not _PLOT_AVAILABLE:
        print("matplotlib not available — skipping coherence heatmap")
        return

    # Width of the grid is the longest episode (1 when there are none).
    width = max((len(ep["coherence_scores"]) for ep in episodes), default=1)
    rows = []
    for ep in episodes:
        scores = list(ep["coherence_scores"][:width])
        rows.append(scores + [float("nan")] * (width - len(scores)))

    grid = np.array(rows, dtype=float)
    fig_w = min(width // 3 + 2, 20)
    fig_h = max(len(episodes) // 4, 4)
    fig, ax = plt.subplots(figsize=(fig_w, fig_h))
    image = ax.imshow(grid, aspect="auto", cmap="YlOrRd", vmin=0.0, vmax=1.0)
    plt.colorbar(image, ax=ax, label="Coherence score")
    ax.set_xlabel("Turn (delivery number)")
    ax.set_ylabel("Episode")
    ax.set_title("Coherence Heatmap\n(diagonal banding → strategic consistency across episode)")
    plt.tight_layout()
    path = out_dir / "coherence_heatmap.png"
    plt.savefig(path, dpi=150)
    plt.close()
    print(f"Saved: {path}")


def plot_tool_usage_timeline(episodes: list[dict], out_dir: Path):
    """Scatter-plot which non-delivery tools were called at which over.

    One row per episode (first 20 only, for readability). Per-ball
    delivery tools (``play_delivery``/``bowl_delivery``) are omitted
    because they occur every ball and would swamp the plot. Dashed
    vertical lines mark the phase-transition overs (6, 16, 36). Saves
    ``tool_usage_timeline.png`` into *out_dir*.
    """
    if not _PLOT_AVAILABLE:
        print("matplotlib not available — skipping tool timeline")
        return

    from matplotlib.patches import Patch

    # Single source of truth for tool → colour; the legend is derived from
    # this mapping so the two can never drift apart.
    tool_colors = {
        "set_strategy": "steelblue",
        "plan_shot": "royalblue",
        "choose_bowler": "seagreen",
        "plan_delivery": "mediumseagreen",
        "analyze_situation": "darkorange",
        "play_delivery": "lightgray",
        "bowl_delivery": "lightgray",
        "reflect_after_ball": "purple",
    }
    delivery_tools = ("play_delivery", "bowl_delivery")

    fig, ax = plt.subplots(figsize=(14, 5))
    for ep_idx, ep in enumerate(episodes[:20]):  # cap at 20 episodes for readability
        for step in ep["steps"]:
            tool = step["tool"]
            if tool in delivery_tools:
                continue  # too many — skip for clarity
            ax.scatter(step["over"], ep_idx,
                       color=tool_colors.get(tool, "black"),
                       marker="D" if tool == "set_strategy" else "o",
                       s=60, alpha=0.8, zorder=3)

    # Legend built from tool_colors (was a hand-maintained duplicate list
    # that could fall out of sync with the colour mapping above).
    legend_elements = [
        Patch(color=color, label=tool)
        for tool, color in tool_colors.items()
        if tool not in delivery_tools
    ]
    ax.legend(handles=legend_elements, loc="upper right")
    # Phase-transition markers. A `label=` kwarg previously set here was
    # dead code (the legend uses explicit handles) and has been dropped.
    for over in (6, 16, 36):
        ax.axvline(over, color="gray", linestyle="--", alpha=0.4)
    ax.set_xlabel("Over number")
    ax.set_ylabel("Episode")
    ax.set_title("Tool Usage Timeline\n(ideal: strategy + analysis cluster at phase transitions)")
    plt.tight_layout()
    path = out_dir / "tool_usage_timeline.png"
    plt.savefig(path, dpi=150)
    plt.close()
    print(f"Saved: {path}")


def plot_reward_curve(log_file: str | None, out_dir: Path):
    """Plot reward curves from a JSONL training log.

    Each log line is expected to be a JSON object containing ``r_cric``,
    ``r_coherence``, ``r_tools``, ``r_format`` and ``composite`` (plus an
    optional ``step``). Malformed or incomplete lines are skipped whole.
    Saves ``reward_curves.png`` into *out_dir*.

    Args:
        log_file: Path to the JSONL training log, or None to skip.
        out_dir: Directory the PNG is written into.
    """
    if not _PLOT_AVAILABLE:
        print("matplotlib not available — skipping reward curve")
        return
    if not log_file or not os.path.exists(log_file):
        print(f"No log file found at {log_file} — skipping reward curve")
        return

    steps, r_cric, r_coh, r_tools, r_fmt, composite = [], [], [], [], [], []
    with open(log_file) as f:
        for line in f:
            try:
                d = json.loads(line)
            except json.JSONDecodeError:
                continue
            # A valid JSON line may still be a scalar/list; previously
            # `"r_cric" in d` raised an uncaught TypeError for those.
            if not isinstance(d, dict) or "r_cric" not in d:
                continue
            # Read every field BEFORE appending anything: the previous
            # version appended to `steps` first, so a record missing e.g.
            # "r_coherence" left the series misaligned and skewed the plot.
            try:
                values = (d["r_cric"], d["r_coherence"], d["r_tools"],
                          d["r_format"], d["composite"])
            except KeyError:
                continue
            steps.append(d.get("step", len(steps)))
            r_cric.append(values[0])
            r_coh.append(values[1])
            r_tools.append(values[2])
            r_fmt.append(values[3])
            composite.append(values[4])

    if not steps:
        print("No reward data in log file")
        return

    fig, axes = plt.subplots(2, 3, figsize=(16, 8))
    for ax, (label, vals) in zip(
        axes.flat,
        [("r_cric", r_cric), ("r_coherence", r_coh), ("r_tools", r_tools),
         ("r_format", r_fmt), ("composite", composite)],
    ):
        ax.plot(steps, vals, linewidth=1.5)
        ax.set_title(label)
        ax.set_xlabel("Training step")
        ax.set_ylabel("Reward")
        ax.grid(True, alpha=0.3)
    axes.flat[-1].axis("off")  # 5 series on a 2×3 grid — blank the spare axis
    plt.suptitle("CricketCaptain-LLM Training Reward Curves")
    plt.tight_layout()
    path = out_dir / "reward_curves.png"
    plt.savefig(path, dpi=150)
    plt.close()
    print(f"Saved: {path}")


def print_strategy_samples(episodes: list[dict], label: str = ""):
    """Print up to eight unique ``set_strategy`` rationales.

    Scans the first ten episodes in order and echoes each previously
    unseen, non-empty rationale with its episode/over position, for
    qualitative before/after comparison.
    """
    print(f"\n=== Strategy Samples {label} ===")
    printed: set[str] = set()
    for ep in episodes[:10]:
        for step in ep["steps"]:
            if step["tool"] != "set_strategy":
                continue
            rationale = step.get("rationale", "").strip()
            if not rationale or rationale in printed:
                continue
            printed.add(rationale)
            print(f"  [{ep['episode']}:{step['over']}] {rationale}")
            if len(printed) >= 8:
                return


def print_summary(episodes: list[dict]):
    """Print aggregate statistics across the collected eval episodes.

    Coherence/adaptation/awareness/regret means default to 0 when no
    scores were recorded; score/wicket means assume at least one episode.
    """
    def _flatten(key: str) -> list:
        # Concatenate one per-turn score list per episode (missing → []).
        return [v for ep in episodes for v in ep.get(key, [])]

    def _mean(vals) -> float:
        return statistics.mean(vals) if vals else 0

    def _stdev(vals) -> float:
        return statistics.stdev(vals) if len(vals) > 1 else 0

    coherence = [v for ep in episodes for v in ep["coherence_scores"]]
    adaptation = _flatten("adaptation_scores")
    awareness = _flatten("opponent_awareness_scores")
    regret = _flatten("regret_scores")
    scores = [ep["total_score"] for ep in episodes]
    wickets = [ep["wickets_lost"] for ep in episodes]

    print("\n=== Evaluation Summary ===")
    print(f"  Episodes:          {len(episodes)}")
    print(f"  Mean score:        {statistics.mean(scores):.1f} ± {_stdev(scores):.1f}")
    print(f"  Mean wickets:      {statistics.mean(wickets):.1f}")
    print(f"  Mean coherence:    {_mean(coherence):.4f}")
    print(f"  Mean adaptation:   {_mean(adaptation):.4f}")
    print(f"  Mean opp aware:    {_mean(awareness):.4f}")
    print(f"  Mean regret score: {_mean(regret):.4f}")
    print(f"  Coherence stdev:   {_stdev(coherence):.4f}")


# ------------------------------------------------------------------ #
# CLI                                                                  #
# ------------------------------------------------------------------ #

async def _run_eval(args):
    """Collect eval episodes, print summaries, render plots, dump raw data.

    Uses a random-baseline agent (the --checkpoint flag is accepted but
    unused by default), writes plots into ``args.out_dir`` when matplotlib
    is available, and always dumps ``eval_episodes.jsonl``.
    """
    agent = RandomAgent()
    output_dir = Path(args.out_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"Collecting {args.episodes} evaluation episodes...")
    episodes = await collect_eval_episodes(
        args.env_url,
        agent,
        args.episodes,
        args.task,
        args.eval_pack_id,
        args.opponent_mode,
        args.max_overs,
    )

    print_summary(episodes)
    print_strategy_samples(episodes, label=f"(task={args.task})")

    if not _PLOT_AVAILABLE:
        print("Install matplotlib+numpy for visualisations: pip install matplotlib numpy")
    else:
        plot_coherence_heatmap(episodes, output_dir)
        plot_tool_usage_timeline(episodes, output_dir)
        plot_reward_curve(args.log_file, output_dir)
        print(f"\nPlots saved to {output_dir}/")

    # Dump raw data, one JSON object per episode.
    raw_path = output_dir / "eval_episodes.jsonl"
    with open(raw_path, "w") as fh:
        fh.writelines(json.dumps(ep) + "\n" for ep in episodes)
    print(f"Raw data saved to {raw_path}")


def main():
    """CLI entry point: parse args, layer optional YAML defaults, run eval."""
    # Env-var-backed defaults, computed once so the "did the user override
    # this on the CLI?" checks below compare against the same values.
    env_url_default = os.environ.get("CRICKET_CAPTAIN_ENV_URL", "ws://localhost:8000")
    pack_default = os.environ.get("CRICKET_EVAL_PACK_ID", "default")
    opponent_default = os.environ.get("CRICKET_OPPONENT_MODE", "heuristic")

    parser = argparse.ArgumentParser(description="CricketCaptain-LLM Evaluation")
    parser.add_argument("--config", default=None, help="YAML config path (runner defaults).")
    parser.add_argument("--episodes", type=int, default=20)
    parser.add_argument("--task", default="medium",
                        choices=["easy", "medium", "hard", "stage2_full", "eval_50over"])
    parser.add_argument("--env-url", default=env_url_default)
    parser.add_argument("--eval-pack-id", default=pack_default)
    parser.add_argument("--opponent-mode", default=opponent_default,
                        choices=["heuristic", "llm_live", "llm_cached"])
    parser.add_argument("--max-overs", type=int, default=None,
                        help="Limit innings length for fast experiments (e.g. 5).")
    parser.add_argument("--out-dir", default="./eval_output")
    parser.add_argument("--log-file", default=None,
                        help="Path to JSONL training log for reward curves")
    parser.add_argument("--checkpoint", default=None,
                        help="Checkpoint path (unused by default — agent is random baseline)")
    args = parser.parse_args()

    if args.config:
        try:
            from config_yaml import load_config, apply_runner_config_defaults
        except ImportError:
            from cricket_captain.config_yaml import load_config, apply_runner_config_defaults
        defaults = apply_runner_config_defaults(load_config(args.config))
        # Config values fill in only settings the user left at their
        # default; explicit CLI/env values always win.
        if args.env_url == env_url_default and defaults.env_url:
            args.env_url = defaults.env_url
        if args.eval_pack_id == pack_default and defaults.eval_pack_id:
            args.eval_pack_id = defaults.eval_pack_id
        if args.opponent_mode == opponent_default and defaults.opponent_mode:
            args.opponent_mode = defaults.opponent_mode
        if args.max_overs is None and defaults.max_overs is not None:
            args.max_overs = int(defaults.max_overs)

    asyncio.run(_run_eval(args))


if __name__ == "__main__":
    main()