"""
compare_eval.py — Baseline vs trained head-to-head evaluation.

Plays N full cricket matches with the BASELINE model (untrained Qwen3-4B-Instruct-2507)
and the TRAINED model (same base + LoRA adapter from a training checkpoint), then
dumps a comparison table:

    win_rate, mean_agent_score, mean_opp_score, mean_wickets, match_completion_rate,
    mean_tool_calls_per_episode, validity_rate, plus a few illustrative transcripts.

Why this is the right eval for our setup
----------------------------------------
Training caps rollouts at the warmup/main token budgets (16k / 24k), which means
warmup rollouts run short formats and main rollouts run 5-over. At EVAL time we
lift the cap — the model gets unlimited context and can play full T20s. This is
the same pattern coding-agent RL papers use: train on partial windows, eval on
full task completion. The trained policy generalizes because it learned good
per-state decisions, not a specific trajectory length.

Usage
-----
# Baseline (untrained Qwen3-4B-Instruct-2507 base)
python compare_eval.py \\
    --model Qwen/Qwen3-4B-Instruct-2507 \\
    --label baseline \\
    --episodes 20 --max-overs 5 \\
    --output eval_results/baseline.json

# Trained (warmup + main checkpoint)
python compare_eval.py \\
    --model Qwen/Qwen3-4B-Instruct-2507 \\
    --adapter ./checkpoints/stage2_final \\
    --label trained \\
    --episodes 20 --max-overs 5 \\
    --output eval_results/trained.json

# Side-by-side comparison
python compare_eval.py --compare eval_results/baseline.json eval_results/trained.json
"""
import argparse
import json
import time
from collections import Counter
from pathlib import Path

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

from server.cricket_environment import CricketEnvironment
from models import CricketAction
import train as train_module  # reuse SYSTEM_PROMPT and _parse_completion

# ----------------------------------------------------------------------------
# Model loading
# ----------------------------------------------------------------------------
def load_model_for_eval(model_name: str, adapter_path: str | None = None):
    """Load base model in bf16; optionally apply a LoRA adapter on top."""
    print(f"Loading base model: {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    if adapter_path:
        print(f"Loading LoRA adapter: {adapter_path}")
        model = PeftModel.from_pretrained(model, adapter_path, is_trainable=False)
    model.eval()
    return model, tokenizer
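
# Minimal usage sketch (paths are illustrative; ./checkpoints/stage2_final is
# the checkpoint name from the docstring above, not a guaranteed artifact):
#
#   model, tok = load_model_for_eval("Qwen/Qwen3-4B-Instruct-2507",
#                                    adapter_path="./checkpoints/stage2_final")
#
# device_map="auto" lets accelerate place layers across available devices;
# the LoRA adapter stays un-merged, which is fine for eval-only use.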

# ----------------------------------------------------------------------------
# Single-episode rollout (no token cap — let matches actually complete)
# ----------------------------------------------------------------------------
def play_one_episode(
    *,
    model,
    tokenizer,
    max_overs: int,
    opponent_mode: str,
    agent_team: str,
    eval_pack_id: str,
    seed: int,
    max_tool_calls: int = 800,
    max_completion_per_turn: int = 256,  # per-turn (NOT per-rollout) — eval is turn-by-turn
    temperature: float = 0.3,  # deterministic-ish at eval
    verbose: bool = False,
) -> dict:
    """Run one full match. Returns per-episode stats."""
    env = CricketEnvironment()
    obs = env.reset(seed=seed, options={
        "task": "stage2_full",
        "random_start": False,
        "max_overs": max_overs,
        "eval_pack_id": eval_pack_id,
        "opponent_mode": opponent_mode,
        "agent_team": agent_team,
    })

    # Build the message log progressively. Each turn appends model output + tool response.
    messages = [
        {"role": "system", "content": train_module.SYSTEM_PROMPT},
        {"role": "user", "content": obs.prompt_text},
    ]
    tool_calls_made = 0
    tool_breakdown: Counter = Counter()
    parse_failures = 0
    illegal_tool_attempts = 0
    start_t = time.time()

    # Cap total model turns (applied tool calls + failed attempts), not just
    # successful calls — otherwise a model stuck emitting unparseable output
    # would loop forever, since failures never advance tool_calls_made.
    while not obs.done and (
        tool_calls_made + parse_failures + illegal_tool_attempts
    ) < max_tool_calls:
        # Render chat using the model's tool template
        try:
            inputs = tokenizer.apply_chat_template(
                messages,
                tokenize=True,
                add_generation_prompt=True,
                return_tensors="pt",
            ).to(model.device)
        except Exception as e:
            print(f" apply_chat_template error: {e}")
            break
        with torch.no_grad():
            out = model.generate(
                inputs,
                max_new_tokens=max_completion_per_turn,
                do_sample=(temperature > 0),
                temperature=max(temperature, 1e-5),
                pad_token_id=tokenizer.pad_token_id,
            )
        gen_ids = out[0, inputs.shape[1]:]
        completion = tokenizer.decode(gen_ids, skip_special_tokens=False)

        # Parse the tool call
        parsed = train_module._parse_completion(completion)
        if parsed is None:
            parse_failures += 1
            if verbose:
                print(f" PARSE FAIL: {completion[:200]}...")
            messages.append({"role": "assistant", "content": completion})
            messages.append({"role": "user", "content": "Your previous output was not parseable. Please emit exactly one tool call."})
            continue
        tool_name = parsed.get("tool", "")
        tool_args = parsed.get("arguments", {}) or {}
        tool_breakdown[tool_name] += 1

        # Apply to env
        try:
            obs = env.step(CricketAction(tool=tool_name, arguments=tool_args))
            tool_calls_made += 1
        except Exception as e:
            illegal_tool_attempts += 1
            if verbose:
                print(f" ILLEGAL TOOL: {tool_name} → {e}")
            messages.append({"role": "assistant", "content": completion})
            messages.append({"role": "user", "content": f"Tool error: {e}. Try a different tool."})
            continue
        messages.append({"role": "assistant", "content": completion})
        messages.append({"role": "user", "content": obs.prompt_text})
    elapsed = time.time() - start_t
    state = env.state
    breakdown = state.reward_breakdown or {}

    # Determine match result
    is_complete = bool(obs.done)
    agent_score = int(state.total_score or 0)
    opp_score = int(state.first_innings_score or 0) if state.innings_type == "second" else None
    target = state.target
    won = None
    if is_complete:
        # Crude win check; the env's match_result string is the canonical source.
        result_str = (state.match_result or "").lower()
        if "won" in result_str and "agent" in result_str:
            won = True
        elif "lost" in result_str or "won" in result_str:
            # "won" without "agent" means the opponent won.
            won = False
        else:
            won = None
    return {
        "seed": seed,
        "max_overs": max_overs,
        "opponent_mode": opponent_mode,
        "tool_calls_made": tool_calls_made,
        "match_complete": is_complete,
        "won": won,
        "agent_score": agent_score,
        "opponent_first_innings_score": opp_score,
        "target": target,
        "wickets_lost": int(state.wickets_lost or 0),
        "match_result": state.match_result or "",
        "tool_breakdown": dict(tool_breakdown),
        "parse_failures": parse_failures,
        "illegal_tool_attempts": illegal_tool_attempts,
        "validity_rate": round(
            1.0
            - (parse_failures + illegal_tool_attempts)
            / max(tool_calls_made + parse_failures + illegal_tool_attempts, 1),
            4,
        ),
        "reward_breakdown": dict(breakdown),
        "elapsed_seconds": round(elapsed, 1),
    }
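
# validity_rate is computed over every model turn, applied or not. Illustrative
# numbers (not real results): 130 applied tool calls, 2 parse failures and
# 1 illegal attempt give validity_rate = 130 / 133 ≈ 0.9774.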

# ----------------------------------------------------------------------------
# Run N episodes
# ----------------------------------------------------------------------------
def run_n_episodes(
    *, model, tokenizer, episodes: int, max_overs: int, opponent_mode: str,
    agent_team: str, eval_pack_id: str, seed_base: int, max_tool_calls: int,
    max_completion_per_turn: int, temperature: float, verbose: bool,
) -> dict:
    results = []
    for i in range(episodes):
        seed = seed_base + i
        print(f" [{i+1}/{episodes}] seed={seed} …", end="", flush=True)
        try:
            res = play_one_episode(
                model=model, tokenizer=tokenizer,
                max_overs=max_overs, opponent_mode=opponent_mode,
                agent_team=agent_team, eval_pack_id=eval_pack_id, seed=seed,
                max_tool_calls=max_tool_calls,
                max_completion_per_turn=max_completion_per_turn,
                temperature=temperature, verbose=verbose,
            )
            print(f" {res['tool_calls_made']} tool calls, "
                  f"{'COMPLETE' if res['match_complete'] else 'truncated'}, "
                  f"score {res['agent_score']}/{res['wickets_lost']}, "
                  f"{res['elapsed_seconds']}s")
            results.append(res)
        except Exception as e:
            print(f" FAILED: {e}")
            results.append({"seed": seed, "error": str(e)})

    # Aggregate
    valid = [r for r in results if "error" not in r]
    n = len(valid)
    if n == 0:
        return {"results": results, "summary": {"n": 0, "error": "all episodes failed"}}
    completed = [r for r in valid if r["match_complete"]]
    won = [r for r in completed if r.get("won") is True]
    summary = {
        "n_episodes": n,
        "match_completion_rate": round(len(completed) / n, 4),
        "win_rate_among_completed": round(len(won) / max(len(completed), 1), 4),
        "win_rate_overall": round(len(won) / n, 4),
        "mean_agent_score": round(sum(r["agent_score"] for r in valid) / n, 2),
        "mean_wickets_lost": round(sum(r["wickets_lost"] for r in valid) / n, 2),
        "mean_tool_calls": round(sum(r["tool_calls_made"] for r in valid) / n, 1),
        "mean_validity_rate": round(sum(r["validity_rate"] for r in valid) / n, 4),
        "mean_composite_reward": round(sum(r["reward_breakdown"].get("composite", 0.0) for r in valid) / n, 4),
        "mean_r_result": round(sum(r["reward_breakdown"].get("r_result", 0.0) for r in valid) / n, 4),
        "mean_r_cricket": round(sum(r["reward_breakdown"].get("r_cricket", 0.0) for r in valid) / n, 4),
        "mean_r_behavior": round(sum(r["reward_breakdown"].get("r_behavior", 0.0) for r in valid) / n, 4),
        "mean_r_validity": round(sum(r["reward_breakdown"].get("r_validity", 0.0) for r in valid) / n, 4),
        "tool_freq": {},
    }

    # Aggregate tool frequencies
    all_tools: Counter = Counter()
    for r in valid:
        for t, c in (r.get("tool_breakdown") or {}).items():
            all_tools[t] += c
    total = sum(all_tools.values()) or 1
    summary["tool_freq"] = {t: round(c / total, 3) for t, c in all_tools.most_common()}
    return {"results": results, "summary": summary}

# ----------------------------------------------------------------------------
# Comparison printer
# ----------------------------------------------------------------------------
def print_comparison(baseline_path: str, trained_path: str):
    with open(baseline_path) as f:
        b = json.load(f)
    with open(trained_path) as f:
        t = json.load(f)
    bs = b["summary"]
    ts = t["summary"]

    def row(label, key, fmt="{:.4f}"):
        bv = bs.get(key)
        tv = ts.get(key)
        b_str = fmt.format(bv) if bv is not None else "-"
        t_str = fmt.format(tv) if tv is not None else "-"
        delta = ""
        if isinstance(bv, (int, float)) and isinstance(tv, (int, float)):
            d = tv - bv
            delta = f" ({'+' if d >= 0 else ''}{d:.3f})"
        print(f" {label:<32} {b_str:>12} {t_str:>12}{delta}")

    print(f"\n{'='*80}")
    print(f"BASELINE vs TRAINED — {bs['n_episodes']} episodes each")
    print(f" baseline label: {b.get('label')} | trained label: {t.get('label')}")
    print(f"{'='*80}")
    print(f" {'metric':<32} {'baseline':>12} {'trained':>12}")
    print(f" {'-'*32} {'-'*12} {'-'*12}")
    row("match_completion_rate", "match_completion_rate")
    row("win_rate_overall", "win_rate_overall")
    row("win_rate_among_completed", "win_rate_among_completed")
    row("mean_agent_score", "mean_agent_score", "{:.2f}")
    row("mean_wickets_lost", "mean_wickets_lost", "{:.2f}")
    row("mean_tool_calls", "mean_tool_calls", "{:.1f}")
    row("mean_validity_rate", "mean_validity_rate")
    row("mean_composite_reward", "mean_composite_reward")
    row("mean_r_result", "mean_r_result")
    row("mean_r_cricket", "mean_r_cricket")
    row("mean_r_behavior", "mean_r_behavior")
    row("mean_r_validity", "mean_r_validity")
    print(f"{'='*80}\n")

# ----------------------------------------------------------------------------
# Main
# ----------------------------------------------------------------------------
def main():
    parser = argparse.ArgumentParser(description="Baseline vs trained eval for CricketCaptain.")
    parser.add_argument("--model", default="Qwen/Qwen3-4B-Instruct-2507", help="Base HF model id")
    parser.add_argument("--adapter", default=None, help="Optional LoRA adapter directory")
    parser.add_argument("--label", default="run", help="Label for this run (used in output)")
    parser.add_argument("--episodes", type=int, default=10)
    parser.add_argument("--max-overs", type=int, default=5)
    parser.add_argument("--opponent-mode", default="heuristic",
                        choices=["heuristic", "llm_live", "llm_cached", "cricsheet"])
    parser.add_argument("--agent-team", default="india")
    parser.add_argument("--eval-pack-id", default="adaptive_t20_v1")
    parser.add_argument("--seed-base", type=int, default=10000)
    parser.add_argument("--max-tool-calls", type=int, default=800)
    parser.add_argument("--max-completion-per-turn", type=int, default=256)
    parser.add_argument("--temperature", type=float, default=0.3)
    parser.add_argument("--output", default=None, help="JSON output path")
    parser.add_argument("--verbose", action="store_true")
    parser.add_argument("--compare", nargs=2, default=None, metavar=("BASELINE_JSON", "TRAINED_JSON"),
                        help="Skip eval; just print comparison from two existing JSON files")
    args = parser.parse_args()

    if args.compare:
        print_comparison(args.compare[0], args.compare[1])
        return

    print(f"\nCricketCaptain compare-eval — label='{args.label}'")
    print(f" model={args.model} adapter={args.adapter or '(none)'}")
    print(f" {args.episodes} episodes × {args.max_overs} overs vs {args.opponent_mode} opponent\n")

    model, tokenizer = load_model_for_eval(args.model, args.adapter)
    out = run_n_episodes(
        model=model, tokenizer=tokenizer,
        episodes=args.episodes, max_overs=args.max_overs,
        opponent_mode=args.opponent_mode,
        agent_team=args.agent_team, eval_pack_id=args.eval_pack_id,
        seed_base=args.seed_base, max_tool_calls=args.max_tool_calls,
        max_completion_per_turn=args.max_completion_per_turn,
        temperature=args.temperature, verbose=args.verbose,
    )
    out["label"] = args.label
    out["model"] = args.model
    out["adapter"] = args.adapter
    out["config"] = {
        "episodes": args.episodes, "max_overs": args.max_overs,
        "opponent_mode": args.opponent_mode, "agent_team": args.agent_team,
        "max_tool_calls": args.max_tool_calls,
        "max_completion_per_turn": args.max_completion_per_turn,
        "temperature": args.temperature,
    }

    print("\n=== SUMMARY ===")
    print(json.dumps(out["summary"], indent=2))
    if args.output:
        out_path = Path(args.output)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        with out_path.open("w") as f:
            json.dump(out, f, indent=2)
        print(f"\nResults → {out_path}")


if __name__ == "__main__":
    main()