"""
MT-GRPO training script for CricketCaptain-LLM.

Two-stage curriculum (ToolRL-style):
  Stage 1: tool-call mastery — emphasize valid, phase-legal tool usage
  Stage 2: strategic behavior — full environment-backed reward (result + cricket + behavior + validity)

Design:
  - Training uses TRL GRPO with environment_factory=CricketCaptainToolEnv
  - The model interacts with live CricketEnvironment instances over multi-turn tool calls
  - Rewards are collected from the environment (environment_reward), not only from stateless prompt parsing
  - The opponent policy is part of the environment: heuristic/cricsheet/llm_live/llm_cached
  - Plain TRL + Transformers + bitsandbytes + PEFT (LoRA adapters for 4-bit models)

Usage (canonical Qwen3 setup):
  python train.py train --config configs/cricket_train_qwen3_warmup.yaml   # warmup
  python train.py train --config configs/cricket_train_qwen3.yaml          # main 5-over

Legacy Qwen3.5 configs live in configs/extras/.
"""
| import argparse | |
| import copy | |
| import datetime | |
| import json | |
| import os | |
| import random | |
| import re | |
| import sys | |
| import threading | |
| import time | |
| from pathlib import Path | |
| from typing import Any | |
| # ------------------------------------------------------------------ # | |
| # Optional training imports # | |
| # ------------------------------------------------------------------ # | |
| try: | |
| import torch | |
| from datasets import Dataset | |
| from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training | |
| from trl import GRPOConfig, GRPOTrainer | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig | |
| _TRAIN_IMPORTS_AVAILABLE = True | |
| except ImportError: | |
| torch = None | |
| Dataset = None | |
| LoraConfig = None | |
| get_peft_model = None | |
| prepare_model_for_kbit_training = None | |
| GRPOConfig = None | |
| GRPOTrainer = None | |
| AutoModelForCausalLM = None | |
| AutoTokenizer = None | |
| BitsAndBytesConfig = None | |
| _TRAIN_IMPORTS_AVAILABLE = False | |
| try: | |
| from server.cricket_environment import CricketEnvironment | |
| from server.coherence_grader import aggression_match, phase_appropriate, rationale_specificity | |
| from server.markov_engine import SHOT_AGGRESSION | |
| from server.player_roster import build_playing_xi, load_team_roster | |
| from models import CricketAction | |
| from config_yaml import get_game_constants, get_reward_weights | |
| except ImportError: | |
| from cricket_captain.server.cricket_environment import CricketEnvironment | |
| from cricket_captain.server.coherence_grader import aggression_match, phase_appropriate, rationale_specificity | |
| from cricket_captain.server.markov_engine import SHOT_AGGRESSION | |
| from cricket_captain.server.player_roster import build_playing_xi, load_team_roster | |
| from cricket_captain.models import CricketAction | |
| from cricket_captain.config_yaml import get_game_constants, get_reward_weights | |
| # Load game knowledge once at import time (cached in config_yaml). | |
| _GK = get_game_constants() | |
| _RW = get_reward_weights() | |
| # ------------------------------------------------------------------ # | |
| # Prompt parsing (stateless — reads from rendered observation text) # | |
| # ------------------------------------------------------------------ # | |
| _STRATEGY_RE = re.compile(r"Batting Strategy:\s*(.+)$", re.MULTILINE) | |
| _AGGRESSION_RE = re.compile(r"aggression[\"'=:\s]+([0-9.]+)", re.IGNORECASE) | |
| _PHASE_RE = re.compile(r"Phase:\s+(POWERPLAY|MIDDLE|DEATH)", re.IGNORECASE) | |
| _VALID_TOOLS = { | |
| "call_toss", | |
| "select_batter", | |
| "set_strategy", | |
| "plan_shot", | |
| "play_delivery", | |
| "choose_bowler", | |
| "set_bowling_strategy", | |
| "plan_delivery", | |
| "set_field_setting", | |
| "bowl_delivery", | |
| "reflect_after_ball", | |
| "analyze_situation", | |
| "set_match_plan", | |
| "update_match_plan", | |
| } | |
| def extract_strategy_from_prompt(prompt: str) -> dict | None: | |
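    r"""Parse the declared batting strategy from the rendered observation, if any.

    Illustrative doctest (input string is a made-up observation fragment):

    >>> extract_strategy_from_prompt("Batting Strategy: attack\naggression: 0.7")
    {'phase_intent': 'attack', 'aggression': 0.7, 'rationale': 'attack'}
    """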
| m = _STRATEGY_RE.search(prompt) | |
| if not m or m.group(1).strip().lower().startswith("none"): | |
| return None | |
| agg_m = _AGGRESSION_RE.search(prompt) | |
| agg = float(agg_m.group(1)) if agg_m else 0.5 | |
| return {"phase_intent": m.group(1).strip(), "aggression": agg, "rationale": m.group(1).strip()} | |
| def extract_phase_from_prompt(prompt: str) -> str: | |
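    """Pull the match phase out of the rendered observation; defaults to "middle".

    >>> extract_phase_from_prompt("Phase: DEATH | Strategic turn: PRE_BALL")
    'death'
    """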
| m = _PHASE_RE.search(prompt) | |
| return m.group(1).lower() if m else "middle" | |
| # ------------------------------------------------------------------ # | |
| # Per-turn reward components (all stateless) # | |
| # ------------------------------------------------------------------ # | |
| _XML_FN_RE = re.compile(r"<function\s*=?\s*([^>\s]+)\s*>", re.IGNORECASE) | |
| _XML_PARAM_RE = re.compile(r"<parameter\s*=\s*([^>\s]+)\s*>(.*?)</parameter>", re.IGNORECASE | re.DOTALL) | |
| def _parse_completion(raw: str) -> dict | None: | |
| """Parse a tool-call from the raw completion into our canonical {tool, arguments} dict. | |
| Handles four common model output patterns: | |
| 1. Plain JSON (ideal). | |
| 2. Markdown code block (```json ... ```). | |
| 3. Thinking-model preamble: <think>...</think> followed by JSON. | |
| Qwen3/Qwen3.5 in default mode emits reasoning inside <think> tags; | |
| we strip everything up to and including the closing </think> tag. | |
| 4. XML function-call format that Qwen3.5 was trained on: | |
| <function=tool_name><parameter=foo>bar</parameter>...</function> | |
| Empirically (see logs/run_2026-04-25_21-08-45) every Stage-1 completion | |
| emitted this XML form instead of JSON — so we extract it as a fallback | |
| to give GRPO a non-zero gradient before the model has been trained | |
| onto the JSON contract. | |
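    Illustrative doctests of the two fallback routes described above:

    >>> _parse_completion('<think>rotate strike</think>{"tool": "play_delivery", "arguments": {"shot_intent": "single"}}')
    {'tool': 'play_delivery', 'arguments': {'shot_intent': 'single'}}
    >>> _parse_completion('<function=play_delivery><parameter=shot_intent>single</parameter></function>')
    {'tool': 'play_delivery', 'arguments': {'shot_intent': 'single'}}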
| """ | |
| raw = raw.strip() | |
| # Strip <think>...</think> preamble emitted by thinking-mode models. | |
| if "<think>" in raw: | |
| think_end = raw.rfind("</think>") | |
| if think_end != -1: | |
| raw = raw[think_end + len("</think>"):].strip() | |
| if raw.startswith("```"): | |
| lines = raw.split("\n") | |
| raw = "\n".join(lines[1:-1]) if len(lines) > 2 else raw | |
| # Try parsing the whole string, then fall back to the first {...} block. | |
| try: | |
| return json.loads(raw) | |
| except (json.JSONDecodeError, ValueError): | |
| pass | |
| start = raw.find("{") | |
| end = raw.rfind("}") | |
| if start != -1 and end > start: | |
| try: | |
| return json.loads(raw[start : end + 1]) | |
| except (json.JSONDecodeError, ValueError): | |
| pass | |
| # XML function-call fallback (Qwen3.5 default tool-call emission style). | |
| fn_match = _XML_FN_RE.search(raw) | |
| if fn_match: | |
| tool = fn_match.group(1).strip().strip("\"'") | |
| arguments: dict[str, Any] = {} | |
| for pname, pval in _XML_PARAM_RE.findall(raw): | |
| v = pval.strip() | |
| # Coerce numeric/bool literals so downstream validators accept them. | |
| try: | |
| arguments[pname] = json.loads(v) | |
| except (json.JSONDecodeError, ValueError): | |
| arguments[pname] = v | |
| return {"tool": tool, "arguments": arguments} | |
| return None | |
| # Bounded LRU-ish cache. Each snapshot is a deepcopy of CricketEnvironment | |
| # (~1 MB) and only used by the LEGACY single-turn r_environment_rollout path, | |
| # not by the multi-turn environment_factory training path. Cap at 4096 entries | |
| # (~4 GB worst case) so a long collect_prompts call can't blow up host RAM. | |
| _PROMPT_ENV_SNAPSHOTS: dict[str, CricketEnvironment] = {} | |
| _PROMPT_SNAPSHOT_CAP = 4096 | |
| _ENV_REWARD_ROLLOUT_STEPS = 12 | |
| def _remember_prompt(obs_text: str, env: CricketEnvironment) -> str: | |
| """Format an observation and keep the exact env state for rollout reward.""" | |
| prompt = _format_prompt(obs_text) | |
| if len(_PROMPT_ENV_SNAPSHOTS) >= _PROMPT_SNAPSHOT_CAP: | |
| # Evict oldest insertion (dict preserves insertion order in py3.7+). | |
| oldest_key = next(iter(_PROMPT_ENV_SNAPSHOTS)) | |
| del _PROMPT_ENV_SNAPSHOTS[oldest_key] | |
| _PROMPT_ENV_SNAPSHOTS[prompt] = copy.deepcopy(env) | |
| return prompt | |
| def r_environment_rollout(prompt: str, completion: str) -> float | None: | |
| """Env-backed score for a generated tool call plus short continuation. | |
| Returns None when the prompt was not collected from an env snapshot, allowing | |
    callers to fall back to stateless scoring. Otherwise returns a value in [0, 1], where 0
| means invalid JSON/tool-for-state and higher values reflect the env reward. | |
| """ | |
| snapshot = _PROMPT_ENV_SNAPSHOTS.get(prompt) | |
| if snapshot is None: | |
| return None | |
| data = _parse_completion(completion) | |
| if data is None: | |
| return 0.0 | |
| tool = data.get("tool", "") | |
| args = data.get("arguments", {}) | |
| if not isinstance(args, dict): | |
| return 0.0 | |
| env = copy.deepcopy(snapshot) | |
| if tool not in env._get_available_tools(): | |
| return 0.0 | |
| try: | |
| obs = env.step(CricketAction(tool=tool, arguments=args)) | |
| except Exception: | |
| return 0.0 | |
| reward = float(obs.reward or 0.0) | |
| rng = random.Random(hash(prompt + completion) & 0xFFFFFFFF) | |
| roster = build_playing_xi(getattr(env, "_agent_roster", [])) | |
| for _ in range(_ENV_REWARD_ROLLOUT_STEPS): | |
| if obs.done: | |
| break | |
| action = _random_action( | |
| rng, | |
| obs.game_state, | |
| obs.available_tools, | |
| obs.current_bowler.get("type") if obs.current_bowler else None, | |
| roster, | |
| ) | |
| obs = env.step(action) | |
| reward += float(obs.reward or 0.0) | |
| if obs.done and env.state.reward_breakdown: | |
| reward += float(env.state.reward_breakdown.get("composite", 0.0)) | |
    # Map the rollout reward into [0, 1] around a 0.5 neutral baseline; negative
    # env rewards land below 0.5 (invalid/illegal tool calls already returned 0.0 above).
| return round(max(0.0, min(1.0, 0.5 + reward)), 4) | |
| def r_validity(completion: str) -> float: | |
| """Schema reward for tool calling. | |
| Exact env-executable calls receive 1.0. Malformed but parseable JSON gets a | |
| small shaping signal so early GRPO has non-zero variance before the model has | |
| learned the strict `{"tool": ..., "arguments": {...}}` contract. | |
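    Illustrative doctests of the shaping ladder:

    >>> r_validity("not a tool call")
    0.0
    >>> r_validity('{"tool": "not_a_tool", "arguments": {}}')
    0.25
    >>> r_validity('{"tool": "set_field_setting", "arguments": {"setting": "Balanced"}}')
    1.0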
| """ | |
| data = _parse_completion(completion) | |
| if data is None: | |
| return 0.0 | |
| if not isinstance(data, dict): | |
| return 0.05 | |
| tool = data.get("tool", "") | |
| args = data.get("arguments", {}) | |
| if "tool" not in data or "arguments" not in data: | |
| return 0.15 | |
| if tool not in _VALID_TOOLS: | |
| return 0.25 | |
| if not isinstance(args, dict): | |
| return 0.35 | |
| if tool == "play_delivery" and args.get("shot_intent") not in SHOT_AGGRESSION: | |
| return 0.5 | |
| if tool == "set_strategy": | |
| agg = args.get("aggression") | |
| if not isinstance(agg, (int, float)): | |
| return 0.5 | |
| if tool == "plan_shot" and args.get("shot_intent") not in SHOT_AGGRESSION: | |
| return 0.5 | |
| if tool in {"choose_bowler", "set_bowling_strategy", "plan_delivery"}: | |
| if args.get("bowler_type") not in (None, "pace", "spin"): | |
| return 0.5 | |
| return 1.0 | |
# Kept for backward compatibility with the smoke test.
| r_format = r_validity | |
| def _bowling_phase_fit(delivery_type: str, phase: str) -> float: | |
| """Return 1.0 if delivery_type fits the phase, else 0.4. Loaded from game_knowledge.yaml.""" | |
| valid = _GK.bowling_phase_delivery.get(phase, []) | |
| return 1.0 if delivery_type in valid else 0.4 | |
| def _field_phase_fit(setting: str, phase: str) -> float: | |
| """Return phase-fit score for a field preset. Loaded from game_knowledge.yaml.""" | |
| return float(_GK.field_phase_fit.get(setting, {}).get(phase, 0.5)) | |
| def r_behavior_stateless(prompt: str, completion: str) -> float: | |
| """ | |
| r_behavior: plan-action coherence score, covering ALL tool types. | |
| Previously only graded play_delivery, leaving ~60% of decision points | |
| unscored. Now each tool family gets an appropriate coherence signal: | |
| - play_delivery / plan_shot : aggression-match + rationale + phase fit | |
| - set_strategy : phase appropriateness + rationale quality | |
| - set_bowling_strategy / plan_delivery : delivery-phase fit + rationale | |
| - set_field_setting : field-phase fit | |
| - reflect_after_ball : rationale specificity | |
| - choose_bowler / select_batter : rationale specificity | |
| - bowl_delivery / call_toss / analyze_situation : not graded (no plan) | |
| """ | |
| data = _parse_completion(completion) | |
| if data is None: | |
| return 0.0 | |
| tool = data.get("tool", "") | |
| args = data.get("arguments", {}) | |
| strategy = extract_strategy_from_prompt(prompt) | |
| phase = extract_phase_from_prompt(prompt) | |
| # Base reward for any valid structured action — ensures GRPO always has a | |
| # positive gradient to reinforce correct tool use even when coherence can't | |
| # be fully scored (no declared strategy, unscored tool types, etc.). | |
| _BASE = 0.10 | |
| if tool == "play_delivery": | |
| shot = args.get("shot_intent", "") | |
| if shot not in SHOT_AGGRESSION: | |
| return 0.0 | |
| if strategy is None: | |
| return _BASE # valid shot, no declared strategy to align against | |
| agg = strategy["aggression"] | |
| a_match = aggression_match(agg, shot) | |
| r_spec = rationale_specificity(strategy.get("rationale", "")) | |
| p_approp = phase_appropriate(agg, phase) | |
| return round(_BASE + (1 - _BASE) * (0.50 * a_match + 0.30 * r_spec + 0.20 * p_approp), 4) | |
| if tool == "plan_shot": | |
| shot = args.get("shot_intent", "") | |
| if shot not in SHOT_AGGRESSION: | |
| return 0.0 | |
| if strategy is None: | |
| return _BASE # valid plan structure, no context to grade against | |
| agg = strategy["aggression"] | |
| a_match = aggression_match(agg, shot) | |
| r_spec = rationale_specificity(args.get("rationale", "")) | |
| return round(_BASE + (1 - _BASE) * (0.60 * a_match + 0.40 * r_spec), 4) | |
| if tool == "set_strategy": | |
| agg = args.get("aggression", 0.5) | |
| try: | |
| agg = float(agg) | |
| except (TypeError, ValueError): | |
| agg = 0.5 | |
| r_spec = rationale_specificity(args.get("rationale", "")) | |
| p_approp = phase_appropriate(agg, phase) | |
| return round(0.50 * p_approp + 0.50 * r_spec, 4) | |
| if tool in {"set_bowling_strategy", "plan_delivery"}: | |
| delivery_type = args.get("delivery_type", "") | |
| r_spec = rationale_specificity(args.get("rationale", "")) | |
| p_fit = _bowling_phase_fit(delivery_type, phase) | |
| return round(0.50 * r_spec + 0.50 * p_fit, 4) | |
| if tool == "set_field_setting": | |
| setting = args.get("setting", "Balanced") | |
| return round(_field_phase_fit(setting, phase), 4) | |
| if tool == "reflect_after_ball": | |
| return round(rationale_specificity(args.get("reflection", "")), 4) | |
| if tool in {"choose_bowler", "select_batter"}: | |
| return round(rationale_specificity(args.get("rationale", "")), 4) | |
| if tool == "set_match_plan": | |
| # Score completeness + rationale quality | |
| fields = ["powerplay_intent", "middle_intent", "death_intent", "risk_budget", "trigger_conditions"] | |
| completeness = sum(1 for f in fields if args.get(f, "").strip()) / len(fields) | |
| r_spec = rationale_specificity(args.get("rationale", "")) | |
| return round(0.6 * completeness + 0.4 * r_spec, 4) | |
| if tool == "update_match_plan": | |
| # Score whether update is justified by a match-state trigger | |
| reason = args.get("reason", args.get("rationale", "")) | |
| triggers = ["wicket", "target", "rrr", "phase", "field", "rate", "pressure", "boundary", "dot"] | |
| hits = sum(1 for t in triggers if t in reason.lower()) | |
| r_spec = rationale_specificity(reason) | |
| return round(min(1.0, 0.5 * r_spec + 0.5 * min(hits / 3, 1.0)), 4) | |
| # bowl_delivery, call_toss, analyze_situation — structurally valid but no | |
| # coherence plan to grade. Return a small base so GRPO distinguishes these | |
| # from invalid JSON (which scores 0.0) without over-weighting them. | |
| if tool in {"bowl_delivery", "call_toss", "analyze_situation"}: | |
| return 0.15 | |
| return 0.0 | |
| def r_adaptation_stateless(prompt: str, completion: str) -> float: | |
| if r_validity(completion) == 0.0: | |
| return 0.0 # don't reward invalid tool calls for context-matching | |
| data = _parse_completion(completion) | |
| if data is None: | |
| return 0.0 | |
| text = json.dumps(data.get("arguments", {})).lower() | |
| phase = extract_phase_from_prompt(prompt) | |
| score = 0.0 | |
| if phase in text: | |
| score += 0.25 | |
| if any(word in prompt.lower() for word in ("target:", "death", "wicket", "opponent last plan")): | |
| score += 0.25 | |
| if any(word in text for word in ("adjust", "target", "field", "phase", "wicket", "pressure", "matchup")): | |
| score += 0.25 | |
| if data.get("tool") in {"plan_shot", "plan_delivery", "reflect_after_ball", "choose_bowler", "select_batter"}: | |
| score += 0.25 | |
| return round(min(score, 1.0), 4) | |
| def r_opponent_awareness_stateless(prompt: str, completion: str) -> float: | |
| if r_validity(completion) == 0.0: | |
| return 0.0 # don't reward invalid tool calls for context-matching | |
| data = _parse_completion(completion) | |
| if data is None: | |
| return 0.0 | |
| text = json.dumps(data.get("arguments", {})).lower() | |
| prompt_l = prompt.lower() | |
| hits = 0 | |
| for word in ("opponent", "field", "bowler", "batter", "spin", "pace", "aggressive", "defensive"): | |
| if word in prompt_l and word in text: | |
| hits += 1 | |
| return round(min(hits / 3, 1.0), 4) | |
| # ------------------------------------------------------------------ # | |
| # Composite reward function — TRL 0.24 signature # | |
| # ------------------------------------------------------------------ # | |
| def make_reward_fn(curriculum_stage: int): | |
| """ | |
| Returns reward_fn(prompts, completions, **kwargs) → list[float]. | |
| Weights align with compute_episode_reward in reward_calculator.py: | |
| r_env — one-step env rollout reward when prompt snapshot exists | |
| r_behavior — stateless tactical/tool coherence | |
| r_validity — JSON/tool schema validity | |
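    Wiring sketch (kwarg names may vary slightly across TRL releases):

        reward_fn = make_reward_fn(curriculum_stage=2)
        trainer = GRPOTrainer(model=model, reward_funcs=[reward_fn],
                              args=grpo_config, train_dataset=build_dataset(prompts))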
| """ | |
| # Minimum reward for any structurally valid completion — ensures GRPO has a | |
| # positive gradient to reinforce valid tool use even for unscored tool types. | |
| _VALID_FLOOR = 0.05 | |
| def reward_fn(prompts: list[str], completions: list[str], **kwargs) -> list[float]: | |
| rewards = [] | |
| for prompt, completion in zip(prompts, completions): | |
| fmt = r_validity(completion) | |
| env_score = r_environment_rollout(prompt, completion) | |
| if curriculum_stage == 1: | |
| # Length-efficiency penalty: a valid JSON tool call is ≤400 chars. | |
| # Models with thinking mode (Qwen3/3.5) generate 800-2000 char | |
| # preambles before the JSON; penalise that verbosity so GRPO | |
| # learns to emit short, direct JSON. The penalty scales from | |
| # 1.0 at ≤400 chars to 0.0 at ≥2400 chars (linear). | |
| _JSON_TARGET = 400 | |
| _RAMP_RANGE = 2000 | |
| length_eff = max(0.0, 1.0 - max(0, len(completion) - _JSON_TARGET) / _RAMP_RANGE) | |
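                # Worked example: a 1400-char completion scores 1 - (1400 - 400) / 2000 = 0.5.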
| base = 0.5 * fmt + 0.5 * (env_score if env_score is not None else fmt) | |
| rewards.append(round(length_eff * base, 4)) | |
| continue | |
| behavior = r_behavior_stateless(prompt, completion) | |
| adapt = r_adaptation_stateless(prompt, completion) | |
| aware = r_opponent_awareness_stateless(prompt, completion) | |
| r_beh = ( | |
| _RW.behavior_coherence * behavior | |
| + _RW.behavior_adaptation * adapt | |
| + _RW.behavior_opponent_awareness * aware | |
| ) | |
| if env_score is None: | |
| reward = _RW.training_behavior * r_beh + _RW.training_validity * fmt | |
| else: | |
| reward = 0.45 * env_score + 0.40 * r_beh + 0.15 * fmt | |
| # Floor: valid JSON should always beat invalid JSON (reward=0) | |
| if fmt > 0.0 and (env_score is None or env_score > 0.0): | |
| reward = max(reward, _VALID_FLOOR) | |
| rewards.append(round(reward, 4)) | |
| return rewards | |
| reward_fn.__name__ = f"stage{curriculum_stage}_reward" | |
| return reward_fn | |
| # ------------------------------------------------------------------ # | |
| # Prompt collection (direct env instantiation — no server needed) # | |
| # ------------------------------------------------------------------ # | |
| SYSTEM_PROMPT = ( | |
| "You are an expert adaptive cricket captain. Each turn you receive a scorecard " | |
| "and must choose exactly one cricket captaincy tool call.\n\n" | |
| "EXECUTE FIRST — strict rule:\n" | |
| " - The match only progresses when you call `play_delivery` (batting) or\n" | |
| " `bowl_delivery` (bowling). Every other tool is overhead.\n" | |
| " - Default action on EVERY ball: call `play_delivery` / `bowl_delivery` with\n" | |
| " plan args INLINE: e.g. `play_delivery(shot_intent='single', risk='low', rationale='rotate')`\n" | |
| " or `bowl_delivery(line='outside_off', length='good', delivery_type='stock')`.\n" | |
| " - Use `set_match_plan` ONCE at the very start of an innings to declare strategy.\n" | |
| " - Use `set_strategy` / `set_bowling_strategy` ONCE per phase boundary.\n" | |
| " - DO NOT call `plan_shot` or `plan_delivery` (deprecated) — they only add a\n" | |
| " wasted turn. Pass the same parameters to play_delivery / bowl_delivery directly.\n" | |
| " - SKIP `reflect_after_ball` unless the previous ball was a wicket or boundary.\n" | |
| " - You are scored on MATCH OUTCOMES, not on philosophical depth. Bloated\n" | |
| " pre-ball planning truncates the episode and you forfeit the result reward.\n\n" | |
| "THINKING BUDGET — HARD LIMIT:\n" | |
| " - Per turn: ONE sentence of reasoning, max 30 tokens, inside <think>...</think>.\n" | |
| " - Do NOT enumerate options, restate the scorecard, or re-derive the plan.\n" | |
| " - Bad: '<think>This is the first ball, the field is balanced, Kohli is on strike at 0.45 aggression, I should consider...'\n" | |
| " - Good: '<think>Powerplay, balanced field — single to rotate.</think>'\n" | |
| " - Token budget per rollout is finite. Long thinking = match truncated = ZERO result reward.\n" | |
| " - The plan you set at the start carries the strategy; do not re-derive it every ball.\n\n" | |
| "Emit exactly one tool call wrapped in <tool_call>...</tool_call> XML tags. " | |
| "Bare JSON without the wrapper is NOT recognized and will end the rollout.\n" | |
| 'Example: <tool_call>{"name": "play_delivery", "arguments": {"shot_intent": "single", "explanation": "rotate strike"}}</tool_call>\n\n' | |
| "Available tools:\n" | |
| " call_toss — Call heads/tails and choose bat/bowl\n" | |
| " select_batter — Choose batter profile for the match situation\n" | |
| " set_strategy — Declare batting intent (aggression 0–1, rationale)\n" | |
| " plan_shot — Pre-ball batting plan\n" | |
| " play_delivery — Choose a shot and advance the game\n" | |
| " choose_bowler — Choose bowler profile for the situation\n" | |
| " set_bowling_strategy — Declare bowling line/length/type/rationale\n" | |
| " plan_delivery — Pre-ball bowling plan\n" | |
| " set_field_setting — Aggressive/Balanced/Defensive field\n" | |
| " bowl_delivery — Execute the delivery\n" | |
| " reflect_after_ball — Adapt after the previous ball\n" | |
| " analyze_situation — Query pitch/bowler/field info\n\n" | |
| "Shot intents: leave | defensive | single | rotate | boundary | six\n\n" | |
| "PRIORITIES (in order): (1) finish the match, (2) win the match, (3) score well per ball.\n" | |
| "Verbose reasoning forfeits all three. Decide fast, act, move on." | |
| ) | |
| def get_system_prompt(stage: int = 2) -> str: | |
| return SYSTEM_PROMPT | |
| _RANDOM_SHOTS = list(SHOT_AGGRESSION.keys()) | |
| _RANDOM_QUERIES = ["pitch_conditions", "bowler_info", "field_setting", "match_situation"] | |
| _RANDOM_ZONES = ["cover", "point", "straight", "midwicket", "square_leg", "fine_leg", "long_on", "long_off"] | |
| def _training_roster(agent_team: str | None = None) -> list[dict]: | |
| team = agent_team or os.environ.get("CRICKET_AGENT_TEAM") | |
| if not team: | |
| raise ValueError("Roster-backed training requires --agent-team or CRICKET_AGENT_TEAM.") | |
| roster = load_team_roster(team) | |
| if not roster: | |
| raise ValueError(f"No player profile roster found for agent team '{team}'.") | |
| playing_xi = build_playing_xi(roster) | |
| if len(playing_xi) < 11: | |
| raise ValueError(f"Player profile roster for '{team}' could not produce a playing XI.") | |
| return playing_xi | |
| def _sample_batter(rng: random.Random, roster: list[dict]) -> dict: | |
| batters = [p for p in roster if p.get("role") != "bowler"] or roster | |
| if not batters: | |
| raise ValueError("Roster-backed training requires at least one batting-capable player.") | |
| return rng.choice(batters) | |
| def _sample_bowler(rng: random.Random, roster: list[dict]) -> dict: | |
| bowlers = [p for p in roster if p.get("bowler_type")] | |
| if not bowlers: | |
| raise ValueError("Roster-backed training requires at least one bowling-capable player.") | |
| return rng.choice(bowlers) | |
| def _random_action( | |
| rng: random.Random, | |
| game_state: str = "batting", | |
| available_tools: list[str] | None = None, | |
| current_bowler_type: str | None = None, | |
| roster: list[dict] | None = None, | |
| ) -> CricketAction: | |
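    """Sample a legal action for the given game state and available tools.

    Used both for prompt collection (collect_prompts) and for the short random
    continuation in r_environment_rollout; raises ValueError if no legal action
    can be built from the available tools.
    """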
| legal = set(available_tools or []) | |
| def can(tool: str) -> bool: | |
| return available_tools is None or tool in legal | |
| def match_plan_action() -> CricketAction: | |
| return CricketAction(tool="set_match_plan", arguments={ | |
| "powerplay_intent": "Use roster strengths to establish tempo while protecting wickets", | |
| "middle_intent": "Rotate strike, attack favorable matchups, and preserve finishers", | |
| "death_intent": "Commit boundary options with wickets and target pressure in mind", | |
| "risk_budget": "Escalate only when phase, target, and wickets justify the risk", | |
| "trigger_conditions": "Review after wicket clusters, phase changes, target pressure, or repeated boundary/dot outcomes", | |
| "rationale": "Create a long-horizon plan before choosing ball-by-ball tactics", | |
| }) | |
| if game_state == "toss": | |
| return CricketAction( | |
| tool="call_toss", | |
| arguments={"call": rng.choice(["heads", "tails"]), "decision": rng.choice(["bat", "bowl"])}, | |
| ) | |
| if can("set_match_plan") and rng.random() < 0.12: | |
| return match_plan_action() | |
| if can("update_match_plan") and rng.random() < 0.08: | |
| return CricketAction(tool="update_match_plan", arguments={ | |
| "reason": "Adjust plan after phase, score pressure, wickets, and field information", | |
| "risk_budget": "Shift risk based on current target pressure and wickets in hand", | |
| }) | |
| if game_state == "bowling": | |
| choice = rng.random() | |
| if choice < 0.15 and can("choose_bowler"): | |
| bowler = _sample_bowler(rng, roster or []) | |
| return CricketAction( | |
| tool="choose_bowler", | |
| arguments={ | |
| "name": bowler["name"], | |
| "bowler_type": bowler["bowler_type"], | |
| "style": bowler.get("bowl_style", bowler.get("style", "stock")), | |
| "rationale": "Match roster bowler to phase, batter matchup, and remaining overs", | |
| }, | |
| ) | |
| if choice < 0.35 and can("plan_delivery"): | |
| return CricketAction( | |
| tool="plan_delivery", | |
| arguments={ | |
| "bowler_type": current_bowler_type or rng.choice(["pace", "spin"]), | |
| "line": rng.choice(["stumps", "outside off", "wide"]), | |
| "length": rng.choice(["good length", "full", "short", "yorker"]), | |
| "delivery_type": rng.choice(["stock", "yorker", "bouncer", "slower ball"]), | |
| "rationale": "Use field and batter style to control scoring zones", | |
| }, | |
| ) | |
| if choice < 0.5 and can("set_field_setting"): | |
| return CricketAction(tool="set_field_setting", arguments={"setting": rng.choice(["Aggressive", "Balanced", "Defensive"])}) | |
| if choice < 0.6 and can("reflect_after_ball"): | |
| return CricketAction(tool="reflect_after_ball", arguments={"reflection": "Adjust line and field after the last ball"}) | |
| if can("bowl_delivery"): | |
| return CricketAction(tool="bowl_delivery", arguments={}) | |
| if can("set_bowling_strategy"): | |
| return CricketAction(tool="set_bowling_strategy", arguments={ | |
| "bowler_type": current_bowler_type or "pace", | |
| "line": "outside off", | |
| "length": "good length", | |
| "delivery_type": "stock", | |
| "rationale": "Set a legal bowling plan before executing the delivery", | |
| }) | |
| raise ValueError(f"No legal bowling action available from tools={available_tools}") | |
| choice = rng.random() | |
| if choice < 0.15 and can("select_batter"): | |
| batter = _sample_batter(rng, roster or []) | |
| return CricketAction( | |
| tool="select_batter", | |
| arguments={ | |
| "name": batter["name"], | |
| "style": batter.get("style", "balanced"), | |
| "aggression": round(float(batter["aggression"]), 2), | |
| "rationale": "Select batter based on phase, wickets, and target pressure", | |
| }, | |
| ) | |
| if choice < 0.3 and can("set_strategy"): | |
| return CricketAction( | |
| tool="set_strategy", | |
| arguments={ | |
| "phase_intent": rng.choice(["attack", "consolidate", "rotate"]), | |
| "aggression": round(rng.uniform(0.1, 0.9), 2), | |
| "rationale": "Align roster strengths with phase, target pressure, and wickets", | |
| }, | |
| ) | |
| if choice < 0.45 and can("plan_shot"): | |
| return CricketAction( | |
| tool="plan_shot", | |
| arguments={ | |
| "shot_intent": rng.choice(_RANDOM_SHOTS), | |
| "target_area": rng.choice(_RANDOM_ZONES), | |
| "trajectory": rng.choice(["ground", "lofted", "aerial"]), | |
| "risk": rng.choice(["low", "balanced", "high"]), | |
| "rationale": "Plan shot against bowler, field, and required rate", | |
| }, | |
| ) | |
| if choice < 0.55 and can("analyze_situation"): | |
| return CricketAction( | |
| tool="analyze_situation", | |
| arguments={"query_type": rng.choice(_RANDOM_QUERIES)}, | |
| ) | |
| if choice < 0.65 and can("reflect_after_ball"): | |
| return CricketAction(tool="reflect_after_ball", arguments={"reflection": "Revise risk after previous ball"}) | |
| if can("play_delivery"): | |
| return CricketAction( | |
| tool="play_delivery", | |
| arguments={"shot_intent": rng.choice(_RANDOM_SHOTS), "explanation": "Advance the innings according to the current plan"}, | |
| ) | |
| raise ValueError(f"No legal batting action available from tools={available_tools}") | |
| def collect_prompts( | |
| n_prompts: int, | |
| task: str = "stage2_full", | |
| seed: int = 42, | |
| agent_team: str | None = None, | |
| opponent_mode: str = "heuristic", | |
| ) -> list[str]: | |
| """ | |
| Collect game-state prompts by running episodes with random actions. | |
| Returns a list of prompt strings (one per game state observation). | |
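    Example (illustrative; requires roster data for the named agent team):

    >>> prompts = collect_prompts(8, seed=0, agent_team="India")  # doctest: +SKIP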
| """ | |
| rng = random.Random(seed) | |
| roster = _training_roster(agent_team) | |
| _PROMPT_ENV_SNAPSHOTS.clear() | |
| prompts: list[str] = [] | |
| ep_count = 0 | |
| while len(prompts) < n_prompts: | |
| env = CricketEnvironment() | |
| obs = env.reset(seed=rng.randint(0, 99999), options={ | |
| "task": task, | |
| "random_start": True, | |
| "agent_team": agent_team or os.environ.get("CRICKET_AGENT_TEAM"), | |
| "opponent_mode": opponent_mode, | |
| }) | |
| prompts.append(_remember_prompt(obs.prompt_text, env)) | |
| steps = 0 | |
| while not obs.done and steps < 80: | |
| action = _random_action( | |
| rng, | |
| obs.game_state, | |
| obs.available_tools, | |
| obs.current_bowler.get("type") if obs.current_bowler else None, | |
| roster, | |
| ) | |
| obs = env.step(action) | |
| if not obs.done: | |
| prompts.append(_remember_prompt(obs.prompt_text, env)) | |
| steps += 1 | |
| ep_count += 1 | |
| if ep_count % 10 == 0: | |
| print(f" Collected {len(prompts)} prompts from {ep_count} episodes …", flush=True) | |
| print(f"Collected {len(prompts)} prompts from {ep_count} episodes.") | |
| return prompts[:n_prompts] | |
| def _format_prompt(obs_text: str) -> str: | |
| """Wrap the observation in a chat-style user message.""" | |
| return f"<|im_start|>system\n{SYSTEM_PROMPT}<|im_end|>\n<|im_start|>user\n{obs_text}<|im_end|>\n<|im_start|>assistant\n" | |
| def build_dataset(prompts: list[str]) -> Dataset: | |
| if Dataset is None: | |
| raise ImportError("datasets is required for training. Install with: pip install '.[train]'") | |
| return Dataset.from_dict({"prompt": prompts}) | |
| class CricketCaptainToolEnv: | |
| """TRL environment wrapper exposing CricketCaptain actions as real tools.""" | |
| _stats_lock = threading.Lock() | |
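    # Illustrative local usage (outside TRL, which normally drives this via environment_factory;
    # kwargs mirror reset() below):
    #   env = CricketCaptainToolEnv()
    #   prompt = env.reset(seed=0, task="stage2_full", max_overs=2)
    #   prompt = env.call_toss(call="heads", decision="bat")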
| def __init__(self): | |
| self.env = CricketEnvironment() | |
| self.reward = 0.0 | |
| self.done = False | |
| self.final_reward = 0.0 | |
| self._episode_seed: int | None = None | |
| self._episode_started = False | |
| self._max_tool_iters: int | None = None | |
| self._episode_had_step = False | |
| self._episode_logged = False | |
| def _maybe_log_episode_end(self, termination_reason: str): | |
| # Avoid double-logging the same episode (e.g. once at termination, again on reset()). | |
| if self._episode_logged: | |
| return | |
| stats_path = os.environ.get("CRICKET_EPISODE_STATS_PATH") | |
| if not stats_path: | |
| return | |
| state = getattr(self.env, "state", None) | |
| payload = { | |
| "ts": datetime.datetime.now().isoformat(), | |
| "seed": self._episode_seed, | |
| "done": bool(self.done), | |
| "termination_reason": termination_reason, | |
| "reward_running_sum": float(self.reward), | |
| "final_reward_bonus": float(self.final_reward), | |
| } | |
| if state is not None: | |
| # ---- match config / context ---- | |
| payload["max_overs"] = getattr(state, "max_overs", None) | |
| payload["opponent_mode"] = getattr(state, "opponent_mode", None) | |
| payload["agent_team"] = getattr(state, "eval_pack_id", None) or getattr(state, "agent_team", None) | |
| payload["innings_type"] = getattr(state, "innings_type", None) | |
| payload["game_state"] = getattr(state, "game_state", None) | |
| # ---- match outcome ---- | |
| payload["overs_played"] = getattr(state, "over", None) | |
| payload["balls_played"] = getattr(state, "ball", None) | |
| payload["agent_score"] = getattr(state, "total_score", None) | |
| payload["wickets_lost"] = getattr(state, "wickets_lost", None) | |
| payload["first_innings_score"] = getattr(state, "first_innings_score", None) | |
| payload["target"] = getattr(state, "target", None) | |
| payload["match_result"] = getattr(state, "match_result", None) or None | |
| # ---- tool calls ---- | |
| tool_calls_made = int(getattr(state, "tool_calls_made", 0) or 0) | |
| payload["tool_calls"] = tool_calls_made | |
| tool_history = getattr(state, "tool_history", None) or [] | |
| tool_breakdown: dict[str, int] = {} | |
| for c in tool_history: | |
| t = c.get("tool", "unknown") | |
| tool_breakdown[t] = tool_breakdown.get(t, 0) + 1 | |
| payload["tool_breakdown"] = tool_breakdown | |
| payload["analyze_calls"] = len(getattr(state, "analyze_calls", []) or []) | |
| # ---- per-turn rubric averages (mean across the full episode) ---- | |
| def _mean(xs): | |
| xs = list(xs or []) | |
| return round(sum(xs) / len(xs), 4) if xs else None | |
| payload["mean_coherence"] = _mean(getattr(state, "coherence_scores", None)) | |
| payload["mean_adaptation"] = _mean(getattr(state, "adaptation_scores", None)) | |
| payload["mean_opponent_awareness"] = _mean(getattr(state, "opponent_awareness_scores", None)) | |
| payload["mean_regret"] = _mean(getattr(state, "regret_scores", None)) | |
| payload["mean_plan_commitment"] = _mean(getattr(state, "plan_commitment_scores", None)) | |
| payload["mean_plan_freshness"] = _mean(getattr(state, "plan_freshness_scores", None)) | |
| payload["strategy_changes"] = getattr(state, "strategy_changes", None) | |
| payload["plan_version"] = getattr(state, "plan_version", None) | |
| # ---- composite + per-rubric reward (already computed in reward_calculator) ---- | |
| if getattr(state, "reward_breakdown", None): | |
| payload["reward_breakdown"] = dict(state.reward_breakdown) | |
| with self._stats_lock: | |
| with open(stats_path, "a", encoding="utf-8") as f: | |
| f.write(json.dumps(payload, ensure_ascii=False) + "\n") | |
| f.flush() | |
| self._episode_logged = True | |
| def reset(self, **kwargs) -> str: | |
| # If the previous episode ended because the trainer hit the tool-iteration cap, | |
| # TRL will stop calling tools and then call reset() for the next scenario. | |
| # In that case, self.done will still be False, but tool_calls_made will be at/near the cap. | |
| if self._episode_started and self._episode_had_step and not self._episode_logged: | |
| prev_calls = getattr(getattr(self.env, "state", None), "tool_calls_made", None) | |
| if self.done: | |
| self._maybe_log_episode_end("natural") | |
| elif self._max_tool_iters and prev_calls is not None and int(prev_calls) >= int(self._max_tool_iters): | |
| self._maybe_log_episode_end("cap") | |
| # Otherwise: trainer reset the env mid-episode (e.g. generation bookkeeping). | |
| # Don't log — it would skew the termination distribution. | |
| self.reward = 0.0 | |
| self.done = False | |
| self.final_reward = 0.0 | |
| self._episode_seed = kwargs.get("seed") | |
| self._episode_started = True | |
| self._episode_had_step = False | |
| self._episode_logged = False | |
| self._max_tool_iters = ( | |
| int(kwargs["max_tool_calling_iterations"]) | |
| if "max_tool_calling_iterations" in kwargs and kwargs["max_tool_calling_iterations"] is not None | |
| else (int(os.environ["CRICKET_MAX_TOOL_ITERS"]) if os.environ.get("CRICKET_MAX_TOOL_ITERS") else None) | |
| ) | |
| obs = self.env.reset(seed=kwargs.get("seed"), options={ | |
| "task": kwargs.get("task", "stage2_full"), | |
| "random_start": bool(kwargs.get("random_start", False)), | |
| "max_overs": int(kwargs.get("max_overs", 5)), | |
| "eval_pack_id": kwargs.get("eval_pack_id", "adaptive_t20_v1"), | |
| "opponent_mode": kwargs.get("opponent_mode", "heuristic"), | |
| "opponent_cache_path": kwargs.get("opponent_cache_path"), | |
| "agent_team": kwargs.get("agent_team"), | |
| }) | |
| return obs.prompt_text | |
| def _apply(self, tool: str, arguments: dict[str, Any]) -> str: | |
| if self.done: | |
| raise ValueError("Match is already finished.") | |
| self._episode_had_step = True | |
        available = self.env._get_available_tools() if self.env.state.game_state else []
| if tool not in available: | |
| self.reward -= 0.2 | |
| raise ValueError(f"Tool '{tool}' is not available. Available tools: {available}") | |
| obs = self.env.step(CricketAction(tool=tool, arguments=arguments)) | |
| self.done = bool(obs.done) | |
| self.reward += float(obs.reward or 0.0) | |
| if obs.done and self.env.state.reward_breakdown: | |
| self.final_reward = float(self.env.state.reward_breakdown.get("composite", 0.0)) | |
| self.reward += self.final_reward | |
| # Log at the time of termination (do not wait for reset()) so the file appears promptly. | |
| if self.done: | |
| self._maybe_log_episode_end("natural") | |
| # Also log cap termination as soon as we hit it, so runs always get stats even if TRL delays reset(). | |
| elif self._max_tool_iters: | |
| state = getattr(self.env, "state", None) | |
| calls = getattr(state, "tool_calls_made", None) if state is not None else None | |
| if calls is not None and int(calls) >= int(self._max_tool_iters): | |
| self._maybe_log_episode_end("cap") | |
| return obs.prompt_text | |
| def call_toss(self, call: str, decision: str) -> str: | |
| """ | |
| Call the coin toss and choose whether to bat or bowl if the toss is won. | |
| Args: | |
| call: Coin call, either "heads" or "tails". | |
| decision: Preferred decision, either "bat" or "bowl". | |
| Returns: | |
| Updated match observation after the toss. | |
| """ | |
| return self._apply("call_toss", {"call": call, "decision": decision}) | |
| def set_match_plan( | |
| self, | |
| powerplay_intent: str, | |
| middle_intent: str, | |
| death_intent: str, | |
| risk_budget: str, | |
| trigger_conditions: str, | |
| rationale: str, | |
| ) -> str: | |
| """ | |
| Establish the long-horizon plan for the innings. | |
| Args: | |
| powerplay_intent: Plan for overs in the powerplay. | |
| middle_intent: Plan for middle overs. | |
| death_intent: Plan for death overs. | |
| risk_budget: How wickets, overs, and target pressure affect risk. | |
| trigger_conditions: Match-state changes that should trigger a plan update. | |
| rationale: Why this plan fits the roster and match situation. | |
| Returns: | |
| Updated match observation after setting the plan. | |
| """ | |
| return self._apply("set_match_plan", { | |
| "powerplay_intent": powerplay_intent, | |
| "middle_intent": middle_intent, | |
| "death_intent": death_intent, | |
| "risk_budget": risk_budget, | |
| "trigger_conditions": trigger_conditions, | |
| "rationale": rationale, | |
| }) | |
| def update_match_plan(self, reason: str, risk_budget: str = "", trigger_conditions: str = "") -> str: | |
| """ | |
| Update the long-horizon plan after a meaningful match-state change. | |
| Args: | |
| reason: Specific reason for updating the plan. | |
| risk_budget: Optional revised risk budget. | |
| trigger_conditions: Optional revised trigger conditions. | |
| Returns: | |
| Updated match observation after revising the plan. | |
| """ | |
| args = {"reason": reason} | |
| if risk_budget: | |
| args["risk_budget"] = risk_budget | |
| if trigger_conditions: | |
| args["trigger_conditions"] = trigger_conditions | |
| return self._apply("update_match_plan", args) | |
| def select_batter(self, name: str, style: str, aggression: float, rationale: str) -> str: | |
| """ | |
| Select the next batter from the configured roster. | |
| Args: | |
| name: Player name from the team roster. | |
| style: Batter style from the roster or tactical role. | |
| aggression: Batting aggression from 0.0 to 1.0. | |
| rationale: Why this batter fits the phase, wickets, and target. | |
| Returns: | |
| Updated match observation after selecting the batter. | |
| """ | |
| return self._apply("select_batter", { | |
| "name": name, | |
| "style": style, | |
| "aggression": aggression, | |
| "rationale": rationale, | |
| }) | |
| def set_strategy(self, phase_intent: str, aggression: float, rationale: str) -> str: | |
| """ | |
| Set batting strategy for the current phase. | |
| Args: | |
| phase_intent: Tactical batting intent for this phase. | |
| aggression: Batting aggression from 0.0 to 1.0. | |
| rationale: Why the strategy fits score, wickets, target, and field. | |
| Returns: | |
| Updated match observation after setting batting strategy. | |
| """ | |
| return self._apply("set_strategy", { | |
| "phase_intent": phase_intent, | |
| "aggression": aggression, | |
| "rationale": rationale, | |
| }) | |
| def plan_shot(self, shot_intent: str, target_area: str, risk: str, trajectory: str, rationale: str) -> str: | |
| """DEPRECATED — pass these args inline to play_delivery() instead. | |
| Args: | |
| shot_intent: leave|defensive|single|rotate|boundary|six. | |
| target_area: scoring area. | |
| risk: low|balanced|high. | |
| trajectory: ground|lofted|aerial. | |
| rationale: one-line reason. | |
| Returns: | |
| Updated observation. | |
| """ | |
| return self._apply("plan_shot", { | |
| "shot_intent": shot_intent, | |
| "target_area": target_area, | |
| "risk": risk, | |
| "trajectory": trajectory, | |
| "rationale": rationale, | |
| }) | |
| def play_delivery( | |
| self, | |
| shot_intent: str = "", | |
| target_area: str = "", | |
| risk: str = "", | |
| trajectory: str = "", | |
| rationale: str = "", | |
| ) -> str: | |
| """ | |
| Execute the ball. Pass shot params inline to skip plan_shot. | |
| Args: | |
| shot_intent: leave|defensive|single|rotate|boundary|six. | |
| target_area: optional scoring area. | |
| risk: optional low|balanced|high. | |
| trajectory: optional ground|lofted|aerial. | |
| rationale: optional one-line reason. | |
| Returns: | |
| Updated observation after the ball outcome. | |
| """ | |
| args: dict[str, Any] = {} | |
| if shot_intent: args["shot_intent"] = shot_intent | |
| if target_area: args["target_area"] = target_area | |
| if risk: args["risk"] = risk | |
| if trajectory: args["trajectory"] = trajectory | |
| if rationale: args["rationale"] = rationale | |
| return self._apply("play_delivery", args) | |
| def choose_bowler(self, name: str, bowler_type: str, style: str, rationale: str) -> str: | |
| """ | |
| Choose the bowler at the start of an over from the configured roster. | |
| Args: | |
| name: Player name from the team roster. | |
| bowler_type: Bowler type, either pace or spin. | |
| style: Bowling style or role. | |
| rationale: Why this bowler fits phase, matchup, and remaining overs. | |
| Returns: | |
| Updated match observation after choosing the bowler. | |
| """ | |
| return self._apply("choose_bowler", { | |
| "name": name, | |
| "bowler_type": bowler_type, | |
| "style": style, | |
| "rationale": rationale, | |
| }) | |
| def set_bowling_strategy(self, bowler_type: str, line: str, length: str, delivery_type: str, rationale: str) -> str: | |
| """ | |
| Set bowling strategy for the current bowler. | |
| Args: | |
| bowler_type: Current bowler type, either pace or spin. | |
| line: Intended line. | |
| length: Intended length. | |
| delivery_type: Variation or stock delivery type. | |
| rationale: Why this plan fits batter, field, phase, and target. | |
| Returns: | |
| Updated match observation after setting bowling strategy. | |
| """ | |
| return self._apply("set_bowling_strategy", { | |
| "bowler_type": bowler_type, | |
| "line": line, | |
| "length": length, | |
| "delivery_type": delivery_type, | |
| "rationale": rationale, | |
| }) | |
| def plan_delivery(self, bowler_type: str, line: str, length: str, delivery_type: str, rationale: str) -> str: | |
| """DEPRECATED — pass these args inline to bowl_delivery() instead. | |
| Args: | |
| bowler_type: pace|spin. | |
| line: line. | |
| length: length. | |
| delivery_type: variation or stock. | |
| rationale: one-line reason. | |
| Returns: | |
| Updated observation. | |
| """ | |
| return self._apply("plan_delivery", { | |
| "bowler_type": bowler_type, | |
| "line": line, | |
| "length": length, | |
| "delivery_type": delivery_type, | |
| "rationale": rationale, | |
| }) | |
| def set_field_setting(self, setting: str) -> str: | |
| """ | |
| Set the field preset. | |
| Args: | |
| setting: One of Aggressive, Balanced, or Defensive. | |
| Returns: | |
| Updated match observation after setting the field. | |
| """ | |
| return self._apply("set_field_setting", {"setting": setting}) | |
| def bowl_delivery( | |
| self, | |
| line: str = "", | |
| length: str = "", | |
| delivery_type: str = "", | |
| rationale: str = "", | |
| ) -> str: | |
| """ | |
| Execute the delivery. Pass plan params inline to skip plan_delivery. | |
| Args: | |
| line: optional line. | |
| length: optional length. | |
| delivery_type: optional variation or stock. | |
| rationale: optional one-line reason. | |
| Returns: | |
| Updated observation after the ball outcome. | |
| """ | |
| args: dict[str, Any] = {} | |
| if line: args["line"] = line | |
| if length: args["length"] = length | |
| if delivery_type: args["delivery_type"] = delivery_type | |
| if rationale: args["rationale"] = rationale | |
| return self._apply("bowl_delivery", args) | |
| def reflect_after_ball(self, reflection: str) -> str: | |
| """ | |
| Reflect after the previous ball and adapt the plan. | |
| Args: | |
| reflection: Specific tactical lesson from the previous ball. | |
| Returns: | |
| Updated match observation after recording reflection. | |
| """ | |
| return self._apply("reflect_after_ball", {"reflection": reflection}) | |
| def analyze_situation(self, query_type: str) -> str: | |
| """ | |
| Analyze part of the match context. | |
| Args: | |
| query_type: One of pitch_conditions, bowler_info, field_setting, or match_situation. | |
| Returns: | |
| Updated observation containing the analysis result. | |
| """ | |
| return self._apply("analyze_situation", {"query_type": query_type}) | |
| def build_agent_dataset(n_examples: int, args) -> Dataset: | |
| if Dataset is None: | |
| raise ImportError("datasets is required for training. Install with: pip install '.[train]'") | |
| rows = [] | |
| rng = random.Random(args.seed) | |
| stage_prompt = get_system_prompt(args.stage) | |
| # Curriculum distribution. If --max-overs is set, use it as a fixed format. | |
| # Otherwise sample per-scenario from a T2-heavy distribution that tapers to T5. | |
| # Rationale: T2 episodes (~72 tool calls) actually COMPLETE within our token | |
| # budget so r_result fires; T5 episodes (~180) train the model on its | |
| # eval distribution. Heavy weight on short formats early so the policy | |
| # escapes the "planning loop" before tackling longer matches. | |
| overs_distribution = getattr(args, "overs_distribution", None) | |
| fixed_overs = args.max_overs if args.max_overs and args.max_overs > 0 else None | |
| if fixed_overs is None and not overs_distribution: | |
        # Default curriculum: roughly 45% T2, 27% T3, 18% T4, 9% T5 (5/3/2/1 of 11 entries).
        overs_distribution = [2, 2, 2, 2, 2, 3, 3, 3, 4, 4, 5]
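        # e.g. overs_distribution=[2]*10 + [3]*6 + [4]*3 + [5] would give an exact 50/30/15/5 split.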
| for idx in range(n_examples): | |
| scenario_overs = fixed_overs if fixed_overs is not None else rng.choice(overs_distribution) | |
| rows.append({ | |
| "prompt": [ | |
| {"role": "system", "content": stage_prompt}, | |
| {"role": "user", "content": ""}, | |
| ], | |
| "seed": rng.randint(0, 999999), | |
| "task": "stage1_format" if args.stage == 1 else "stage2_full", | |
| "random_start": False, | |
| "max_overs": scenario_overs, | |
| "eval_pack_id": args.eval_pack_id, | |
| "opponent_mode": args.opponent_mode, | |
| "opponent_cache_path": getattr(args, "opponent_cache_path", None), | |
| "agent_team": args.agent_team, | |
| "scenario_id": idx, | |
| }) | |
| return Dataset.from_list(rows) | |
| def environment_reward(environments, **kwargs) -> list[float]: | |
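    """Episode-level reward for the multi-turn environment_factory path.

    For each rollout: the running per-turn reward accumulated by the env wrapper,
    plus the terminal composite from reward_calculator, plus light shaping for
    plan commitment (0.1x) and JSON validity (0.05x). Also pushes per-step
    aggregates to WandB when a run is active.
    """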
| rewards = [] | |
| # Aggregate metrics across all envs in this gradient step for WandB logging. | |
| agg = { | |
| "r_result": [], "r_cricket": [], "r_behavior": [], "r_validity": [], | |
| "r_coherence": [], "r_adaptation": [], "r_opponent_awareness": [], "r_regret": [], | |
| "composite": [], "tool_calls": [], "wickets_lost": [], "agent_score": [], | |
| "matches_completed": 0, "n": 0, | |
| } | |
| tool_freq: dict[str, int] = {} | |
| for env in environments: | |
| state = env.env.state | |
| breakdown = state.reward_breakdown or {} | |
| terminal = float(breakdown.get("composite", 0.0)) | |
| plan_score = (sum(state.plan_commitment_scores) / len(state.plan_commitment_scores)) if state.plan_commitment_scores else 0.0 | |
| validity = 1.0 - min(1.0, len([c for c in state.tool_history if c.get("tool") == "invalid_json"]) / max(state.step_count, 1)) | |
| reward = env.reward + terminal + 0.1 * plan_score + 0.05 * validity | |
        # Reward clip removed: when rollouts complete naturally the composite reward
        # easily saturates a [-1, 1] clip, driving the GRPO group std toward 0 and
        # killing the gradient signal. Let GRPO standardize the advantage itself.
| rewards.append(round(reward, 4)) | |
| # Collect for aggregate logging. | |
| agg["n"] += 1 | |
| if env.done: | |
| agg["matches_completed"] += 1 | |
| for k in ("r_result", "r_cricket", "r_behavior", "r_validity", | |
| "r_coherence", "r_adaptation", "r_opponent_awareness", | |
| "r_regret", "composite"): | |
| v = breakdown.get(k) | |
| if v is not None: | |
| agg[k].append(float(v)) | |
| agg["tool_calls"].append(int(getattr(state, "tool_calls_made", 0) or 0)) | |
| agg["wickets_lost"].append(int(getattr(state, "wickets_lost", 0) or 0)) | |
| agg["agent_score"].append(int(getattr(state, "total_score", 0) or 0)) | |
| for c in (state.tool_history or []): | |
| t = c.get("tool", "unknown") | |
| tool_freq[t] = tool_freq.get(t, 0) + 1 | |
| # WandB log — only if wandb is initialised in this process. | |
| try: | |
| import wandb | |
| if wandb.run is not None and agg["n"] > 0: | |
| log_dict: dict[str, float] = { | |
| "rollout/n_episodes": agg["n"], | |
| "rollout/matches_completed": agg["matches_completed"], | |
| "rollout/match_completion_rate": agg["matches_completed"] / agg["n"], | |
| } | |
| for k in ("r_result", "r_cricket", "r_behavior", "r_validity", | |
| "r_coherence", "r_adaptation", "r_opponent_awareness", | |
| "r_regret", "composite"): | |
| if agg[k]: | |
| log_dict[f"reward/{k}_mean"] = sum(agg[k]) / len(agg[k]) | |
| log_dict[f"reward/{k}_max"] = max(agg[k]) | |
| log_dict[f"reward/{k}_min"] = min(agg[k]) | |
| for k in ("tool_calls", "wickets_lost", "agent_score"): | |
| if agg[k]: | |
| log_dict[f"episode/{k}_mean"] = sum(agg[k]) / len(agg[k]) | |
| log_dict[f"episode/{k}_max"] = max(agg[k]) | |
| # Tool usage breakdown — frequency per tool name across this step. | |
| total_tools = sum(tool_freq.values()) or 1 | |
| for t, n in tool_freq.items(): | |
| log_dict[f"tools/freq_{t}"] = n / total_tools | |
| wandb.log(log_dict) | |
| except Exception: | |
| # Never let logging break training. | |
| pass | |
| return rewards | |
| def generate_sft_examples(out_path: str, n_examples: int = 240, seed: int = 42, agent_team: str | None = None): | |
| """Stage 0 bootstrap data: valid tool JSON for every tool family.""" | |
| rng = random.Random(seed) | |
| roster = _training_roster(agent_team) | |
| examples = [] | |
| for _ in range(n_examples): | |
| game_state = rng.choice(["toss", "batting", "bowling"]) | |
| action = _random_action(rng, game_state, roster=roster) | |
| prompt = ( | |
| f"{SYSTEM_PROMPT}\n\n" | |
| f"[CricketCaptain] {game_state.upper()} | Example adaptive scenario\n" | |
| "Phase: MIDDLE | Strategic turn: PRE_BALL\n" | |
| "Opponent last plan: {'field_setting': 'Defensive', 'shot_intent': 'boundary'}\n" | |
| ) | |
| examples.append({ | |
| "messages": [ | |
| {"role": "system", "content": SYSTEM_PROMPT}, | |
| {"role": "user", "content": prompt}, | |
| {"role": "assistant", "content": json.dumps({"tool": action.tool, "arguments": action.arguments})}, | |
| ] | |
| }) | |
| out = Path(out_path) | |
| out.parent.mkdir(parents=True, exist_ok=True) | |
| with out.open("w") as f: | |
| for ex in examples: | |
| f.write(json.dumps(ex) + "\n") | |
| print(f"Wrote {len(examples)} SFT examples -> {out}") | |
| # ------------------------------------------------------------------ # | |
| # Model loading (plain transformers + bitsandbytes 4-bit) # | |
| # ------------------------------------------------------------------ # | |
| def load_model(model_name: str, *, use_vllm: bool = False, resume_adapter_from: str | None = None): | |
| """Load base + LoRA. When use_vllm=True, base is loaded in bf16 (vLLM | |
| does not support 4-bit BNB inference); otherwise 4-bit NF4. | |
| resume_adapter_from: optional path to a PEFT adapter directory (e.g. a previous | |
| checkpoint dir). If provided, loads the adapter weights instead of initializing | |
| a fresh LoRA. The base model is still loaded from `model_name`. The adapter's | |
| LoraConfig is preserved (so you can resume even if r= or alpha= drift between runs).""" | |
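    # Example call (model name illustrative): load_model("Qwen/Qwen3-4B-Instruct-2507", use_vllm=False)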
| if not _TRAIN_IMPORTS_AVAILABLE: | |
| raise ImportError("Training dependencies are missing. Install with: pip install '.[train]'") | |
| print(f"Loading {model_name} … (use_vllm={use_vllm}, dtype={'bf16' if use_vllm else '4-bit NF4'})") | |
| tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) | |
| if tokenizer.pad_token is None: | |
| tokenizer.pad_token = tokenizer.eos_token | |
| try: | |
| import flash_attn # noqa: F401 | |
| attn_impl = "flash_attention_2" | |
| except ImportError: | |
| attn_impl = "sdpa" | |
| load_kwargs = dict( | |
| device_map="auto", | |
| trust_remote_code=True, | |
| torch_dtype=torch.bfloat16, | |
| attn_implementation=attn_impl, | |
| ) | |
| if not use_vllm: | |
| load_kwargs["quantization_config"] = BitsAndBytesConfig( | |
| load_in_4bit=True, | |
| bnb_4bit_compute_dtype=torch.bfloat16, | |
| bnb_4bit_use_double_quant=True, | |
| bnb_4bit_quant_type="nf4", | |
| ) | |
| model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs) | |
| if not use_vllm: | |
| model = prepare_model_for_kbit_training(model) | |
| if resume_adapter_from: | |
| # Resume from a previous PEFT adapter checkpoint (e.g. warmup output). | |
| # PeftModel.from_pretrained reads the adapter_config.json from the dir, | |
| # so any r/alpha/target_modules saved with the warmup run is preserved. | |
| from peft import PeftModel | |
| adapter_path = Path(resume_adapter_from) | |
| if not adapter_path.exists(): | |
| raise FileNotFoundError(f"resume_adapter_from path does not exist: {adapter_path}") | |
| print(f"Resuming LoRA adapter from {adapter_path}") | |
| model = PeftModel.from_pretrained(model, str(adapter_path), is_trainable=True) | |
| else: | |
| lora_cfg = LoraConfig( | |
| r=64, | |
| lora_alpha=128, | |
| lora_dropout=0.05, | |
| bias="none", | |
| task_type="CAUSAL_LM", | |
| target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], | |
| ) | |
| model = get_peft_model(model, lora_cfg) | |
| print(f"Loaded. Parameters: {model.num_parameters():,}") | |
| model.print_trainable_parameters() | |
| return model, tokenizer | |
| # ------------------------------------------------------------------ # | |
| # Training # | |
| # ------------------------------------------------------------------ # | |
| def train(args): | |
| if not _TRAIN_IMPORTS_AVAILABLE: | |
| raise ImportError("Training dependencies are missing. Install with: pip install '.[train]'") | |
| if args.opponent_mode == "llm_live": | |
| if args.opponent_model: | |
| os.environ["CRICKET_OPPONENT_MODEL"] = args.opponent_model | |
| if args.opponent_api_base: | |
| os.environ["CRICKET_OPPONENT_API_BASE"] = args.opponent_api_base | |
| if args.opponent_api_key: | |
| os.environ["CRICKET_OPPONENT_API_KEY"] = args.opponent_api_key | |
| task = "stage1_format" if args.stage == 1 else "stage2_full" | |
| # CRICKET_CKPT_ROOT lets a side-by-side run write checkpoints to a different | |
| # tree (e.g. ./checkpoints_smoke) without trampling an active production run. | |
| # Default unchanged: ./checkpoints/. | |
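| # e.g. (illustrative invocation): CRICKET_CKPT_ROOT=./checkpoints_smoke python train.py train --config <your-config>.yaml | |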
| ckpt_root = os.environ.get("CRICKET_CKPT_ROOT", "./checkpoints").rstrip("/") | |
| out_dir = f"{ckpt_root}/stage{args.stage}" | |
| save_dir = f"{ckpt_root}/stage{args.stage}_final" | |
| ts = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") | |
| log_dir = Path(f"./logs/run_{ts}_stage{args.stage}_{args.opponent_mode}") | |
| log_dir.mkdir(parents=True, exist_ok=True) | |
| # Make episode termination stats available to the environment wrapper. | |
| # This lets us distinguish natural terminations from tool-iteration cap truncations. | |
| stats_path = log_dir / "episode_stats.jsonl" | |
| os.environ["CRICKET_EPISODE_STATS_PATH"] = str(stats_path) | |
| os.environ["CRICKET_MAX_TOOL_ITERS"] = str(args.max_tool_calling_iterations) | |
| # Create the file immediately so users can find/tail it even before the first termination. | |
| stats_path.touch(exist_ok=True) | |
| print(f"\n=== Stage {args.stage} Training ===") | |
| print(f"Task: {task} | Prompts: {args.prompts} | Steps: {args.steps}") | |
| print(f"Logs: {log_dir}/ | Checkpoints: {out_dir}/") | |
| print(f"max_tool_calling_iterations={args.max_tool_calling_iterations} (full 5-over match needs ~180; 20-over needs ~720)") | |
| (log_dir / "metadata.json").write_text(json.dumps({ | |
| "stage": args.stage, "model": args.model, "agent_team": args.agent_team, | |
| "max_overs": args.max_overs, "opponent_mode": args.opponent_mode, | |
| "prompts": args.prompts, "steps": args.steps, | |
| "batch_size": args.batch_size, "grad_accum": args.grad_accum, | |
| "num_generations": args.num_generations, | |
| "max_completion_length": args.max_completion_length, | |
| "max_tool_calling_iterations": args.max_tool_calling_iterations, | |
| "logging_steps": args.logging_steps, | |
| "timestamp": ts, | |
| }, indent=2)) | |
| # Build scenario seeds. TRL's environment_factory performs the actual | |
| # multi-turn rollout and tool execution during training. | |
| print("\nBuilding environment scenarios …") | |
| dataset = build_agent_dataset(args.prompts, args) | |
| # Load model — bf16 if vLLM is on (vLLM rejects 4-bit BNB) or --bf16-base, else 4-bit NF4. | |
| # If resume_from is set, load the LoRA adapter from that path instead of fresh init. | |
| bf16_base = getattr(args, "use_vllm", False) or getattr(args, "bf16_base", False) | |
| resume_from = getattr(args, "resume_from", None) | |
| model, tokenizer = load_model(args.model, use_vllm=bf16_base, resume_adapter_from=resume_from) | |
| # GRPO config | |
| # | |
| # Qwen3 / Qwen3.5 ship with hybrid thinking ENABLED by default. Empirically | |
| # (see logs/run_2026-04-25_21-08-45 completions parquet) every generation | |
| # spent ~1200 chars inside <think>...</think> and then emitted XML-style | |
| # <function><parameter> tags instead of the JSON tool call we asked for. | |
| # That meant 0/32 generations were parseable, _apply() never advanced the | |
| # match, and episodes always hit max_tool_calling_iterations before any | |
| # innings finished — so r_result (55% of the composite) was never earned. | |
| # | |
| chat_template_kwargs = {} | |
| generation_kwargs = {} | |
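| # Both dicts are left empty here: the Instruct-2507 style releases dropped the thinking | |
| # block from their chat template (see the response_schema note below), so there is | |
| # nothing to switch off. For a thinking-enabled Qwen3 build, a plausible sketch is | |
| # chat_template_kwargs = {"enable_thinking": False}; treat that flag name as an | |
| # assumption about the build's chat template rather than a guarantee. | |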
| completion_len = max(args.max_completion_length, 2048) | |
| use_vllm = getattr(args, "use_vllm", False) | |
| vllm_kwargs = {} | |
| if use_vllm: | |
| # vllm_model_impl: None (default) → vLLM picks its native class. Use this for | |
| # Qwen3-* (Qwen3ForCausalLM is registered, native path with full LoRA support). | |
| # Set to "transformers" only for Qwen3.5-* where vLLM has no text-only class | |
| # registered and the native path tries to load a vision tower. | |
| vllm_kwargs = dict( | |
| use_vllm=True, | |
| vllm_mode="colocate", | |
| vllm_gpu_memory_utilization=getattr(args, "vllm_gpu_memory", 0.5), | |
| vllm_max_model_length=completion_len + 2048, | |
| ) | |
| vllm_impl = getattr(args, "vllm_model_impl", None) | |
| if vllm_impl: | |
| vllm_kwargs["vllm_model_impl"] = vllm_impl | |
| # Resolve hyperparameters from YAML/CLI with sensible fallbacks. | |
| lr = args.learning_rate if getattr(args, "learning_rate", None) is not None \ | |
| else (2e-5 if args.stage == 1 else 1e-5) | |
| grpo_beta = getattr(args, "beta", None) | |
| grpo_temp = getattr(args, "temperature", None) or 0.8 | |
| grpo_top_p = getattr(args, "top_p", None) | |
| grad_ckpt = getattr(args, "gradient_checkpointing", None) | |
| grad_ckpt_kwargs = None | |
| if grad_ckpt and getattr(args, "gradient_checkpointing_use_reentrant", None) is not None: | |
| grad_ckpt_kwargs = {"use_reentrant": bool(args.gradient_checkpointing_use_reentrant)} | |
| optional_cfg = {} | |
| if grpo_beta is not None: | |
| optional_cfg["beta"] = grpo_beta | |
| if grpo_top_p is not None: | |
| optional_cfg["top_p"] = grpo_top_p | |
| if grad_ckpt is not None: | |
| optional_cfg["gradient_checkpointing"] = bool(grad_ckpt) | |
| if grad_ckpt_kwargs is not None: | |
| optional_cfg["gradient_checkpointing_kwargs"] = grad_ckpt_kwargs | |
| if getattr(args, "dataloader_pin_memory", None) is not None: | |
| optional_cfg["dataloader_pin_memory"] = bool(args.dataloader_pin_memory) | |
| if getattr(args, "dataloader_num_workers", None) is not None: | |
| optional_cfg["dataloader_num_workers"] = int(args.dataloader_num_workers) | |
| config = GRPOConfig( | |
| output_dir=out_dir, | |
| logging_dir=str(log_dir / "tensorboard"), | |
| num_train_epochs=1, | |
| max_steps=args.steps, | |
| per_device_train_batch_size=args.batch_size, | |
| gradient_accumulation_steps=args.grad_accum, | |
| learning_rate=lr, | |
| warmup_ratio=0.05, | |
| lr_scheduler_type="cosine", | |
| logging_steps=args.logging_steps, | |
| save_steps=getattr(args, "save_steps", None) or 10, | |
| save_total_limit=getattr(args, "save_total_limit", None) or 20, | |
| bf16=True, | |
| max_completion_length=completion_len, | |
| num_generations=args.num_generations, | |
| max_tool_calling_iterations=args.max_tool_calling_iterations, | |
| temperature=grpo_temp, | |
| report_to=args.report_to, | |
| run_name=args.run_name, | |
| log_completions=True, | |
| seed=args.seed, | |
| chat_template_kwargs=chat_template_kwargs, | |
| generation_kwargs=generation_kwargs, | |
| **optional_cfg, | |
| **vllm_kwargs, | |
| ) | |
| # TRL's add_response_schema pattern-matches tokenizer.chat_template against | |
| # a fixed list and raises "Unrecognized chat template" if no match. Some | |
| # newer Qwen3 builds (e.g. Qwen3-4B-Instruct-2507, Aug 2025) ship a | |
| # template that differs from TRL's stored string (the Instruct release | |
| # dropped the enable_thinking block) — but the tool-call format | |
| # (<tool_call>…</tool_call>) is identical, so the appropriate schema still | |
| # parses correctly. We assign it manually before GRPOTrainer init; TRL | |
| # checks `response_schema is None` first so this is a safe override. | |
| if getattr(tokenizer, "response_schema", None) is None: | |
| try: | |
| from trl.chat_template_utils import qwen3_schema, qwen3_5_schema | |
| m = args.model.lower() | |
| if "qwen3.5" in m or "qwen3_5" in m: | |
| tokenizer.response_schema = qwen3_5_schema | |
| print("Set tokenizer.response_schema = qwen3_5_schema (manual override).") | |
| elif "qwen3" in m: | |
| tokenizer.response_schema = qwen3_schema | |
| print("Set tokenizer.response_schema = qwen3_schema (manual override).") | |
| except ImportError: | |
| pass | |
| trainer = GRPOTrainer( | |
| model=model, | |
| reward_funcs=environment_reward, | |
| args=config, | |
| train_dataset=dataset, | |
| processing_class=tokenizer, | |
| environment_factory=CricketCaptainToolEnv, | |
| ) | |
| print(f"\nStarting training ({args.steps} steps, {len(dataset)} prompts) …") | |
| trainer.train() | |
| model.save_pretrained(save_dir) | |
| tokenizer.save_pretrained(save_dir) | |
| print(f"\nSaved → {save_dir}") | |
| # ------------------------------------------------------------------ # | |
| # Quick eval: run N episodes with the trained model # | |
| # ------------------------------------------------------------------ # | |
| def evaluate(args): | |
| """Run N episodes and print coherence + score stats.""" | |
| # Mirror the module-level import fallback (repo checkout vs. installed package namespace). | |
| try: | |
| from server.reward_calculator import compute_episode_reward, get_dls_par | |
| except ImportError: | |
| from cricket_captain.server.reward_calculator import compute_episode_reward, get_dls_par | |
| model, tokenizer = load_model(args.model) | |
| model.eval() | |
| def generate(prompt: str) -> str: | |
| inputs = tokenizer(prompt, return_tensors="pt").to(model.device) | |
| with torch.no_grad(): | |
| out = model.generate( | |
| **inputs, max_new_tokens=200, temperature=0.7, do_sample=True, | |
| pad_token_id=tokenizer.eos_token_id, | |
| ) | |
| return tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True) | |
| rng = random.Random(args.seed) | |
| all_coh, all_scores = [], [] | |
| for ep in range(args.eval_episodes): | |
| env = CricketEnvironment() | |
| obs = env.reset(seed=rng.randint(0, 99999), options={ | |
| "task": "stage2_full", | |
| "random_start": False, | |
| "agent_team": args.agent_team, | |
| }) | |
| steps = 0 | |
| while not obs.done and steps < 150: | |
| prompt = _format_prompt(obs.prompt_text) | |
| raw = generate(prompt) | |
| data = _parse_completion(raw) | |
| if data: | |
| action = CricketAction(tool=data["tool"], arguments=data.get("arguments", {})) | |
| else: | |
| action = CricketAction(tool="invalid_json", arguments={}) | |
| obs = env.step(action) | |
| steps += 1 | |
| state = env.state | |
| avg_coh = sum(state.coherence_scores) / len(state.coherence_scores) if state.coherence_scores else 0 | |
| all_coh.append(avg_coh) | |
| all_scores.append(state.total_score) | |
| print(f" Ep {ep+1}: {state.total_score}/{state.wickets_lost} coh={avg_coh:.3f}") | |
| print(f"\nAvg coherence: {sum(all_coh)/len(all_coh):.3f}") | |
| print(f"Avg score: {sum(all_scores)/len(all_scores):.1f}") | |
| def _make_run_folder(prefix: str, model: str | None, opponent_mode: str | None, max_overs: int | None) -> Path: | |
| """Create a timestamped illustrations folder, return its path.""" | |
| import datetime | |
| ts = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M") | |
| model_short = (model or "heuristic").split("/")[-1][:20] if model else "heuristic" | |
| overs_str = f"_{max_overs}ov" if max_overs else "" | |
| opp_str = f"_{opponent_mode}" if opponent_mode else "" | |
| folder_name = f"exp_{ts}_{prefix}{overs_str}{opp_str}_{model_short}" | |
| run_dir = Path(__file__).parent / "illustrations" / folder_name | |
| run_dir.mkdir(parents=True, exist_ok=True) | |
| return run_dir | |
| def train_smoke(args): | |
| """Run short direct-environment training rollouts without loading a model.""" | |
| rng = random.Random(args.seed) | |
| roster = _training_roster(args.agent_team) | |
| # Auto-create run folder unless --output explicitly given | |
| if args.output: | |
| output_path = Path(args.output) | |
| output_path.parent.mkdir(parents=True, exist_ok=True) | |
| run_dir = output_path.parent | |
| else: | |
| model_hint = getattr(args, "model", None) | |
| run_dir = _make_run_folder("train_smoke", model_hint, args.opponent_mode, args.max_overs) | |
| output_path = run_dir / "run_output.txt" | |
| # Create the file (empty) immediately so it exists while the run is in progress; | |
| # log() below writes the header to both stdout and the file exactly once. | |
| output_path.write_text("") | |
| def log(msg: str): | |
| print(msg) | |
| with open(output_path, "a") as _f: | |
| _f.write(msg + "\n") | |
| log("# Training smoke: direct CricketEnvironment rollout") | |
| log(f"matches={args.matches} max_overs={args.max_overs} opponent_mode={args.opponent_mode}") | |
| log("purpose=verify one short training-style match rollout, prompt collection, tool stepping, and terminal reward") | |
| for match_idx in range(args.matches): | |
| match_start = time.perf_counter() | |
| last_step_time = match_start | |
| env = CricketEnvironment() | |
| obs = env.reset(seed=rng.randint(0, 99999), options={ | |
| "task": "stage2_full", | |
| "random_start": False, | |
| "max_overs": args.max_overs, | |
| "eval_pack_id": args.eval_pack_id, | |
| "opponent_mode": args.opponent_mode, | |
| "opponent_cache_path": args.opponent_cache_path, | |
| "agent_team": args.agent_team, | |
| }) | |
| prompts = [_format_prompt(obs.prompt_text)] | |
| total_reward = 0.0 | |
| steps = 0 | |
| log(f"\n--- match {match_idx + 1} reset ---") | |
| log(f"initial_state={obs.game_state} phase={obs.strategic_phase} t_elapsed=0.000s tools={obs.available_tools}") | |
| while not obs.done and steps < args.max_steps: | |
| step_start = time.perf_counter() | |
| previous_game_state = obs.game_state | |
| previous_innings = obs.innings_type | |
| previous_available_tools = obs.available_tools | |
| action = _random_action( | |
| rng, | |
| obs.game_state, | |
| obs.available_tools, | |
| obs.current_bowler.get("type") if obs.current_bowler else None, | |
| roster, | |
| ) | |
| obs = env.step(action) | |
| step_end = time.perf_counter() | |
| step_dt = step_end - step_start | |
| elapsed = step_end - match_start | |
| since_prev = step_end - last_step_time | |
| last_step_time = step_end | |
| total_reward += obs.reward or 0.0 | |
| if not obs.done: | |
| prompts.append(_format_prompt(obs.prompt_text)) | |
| changed_context = ( | |
| obs.done | |
| or obs.game_state != previous_game_state | |
| or obs.innings_type != previous_innings | |
| or obs.available_tools != previous_available_tools | |
| ) | |
| if steps < args.log_steps or changed_context: | |
| over = obs.game_context.get("over", 0) or 0 | |
| ball = obs.game_context.get("ball", 0) or 0 | |
| score = obs.game_context.get("score", 0) or 0 | |
| balls_used = int(over) * 6 + int(ball) | |
| balls_left = max(args.max_overs * 6 - balls_used, 0) | |
| current_rr = score / (balls_used / 6) if balls_used > 0 else 0.0 | |
| if obs.target is not None: | |
| runs_required = max(obs.target - score, 0) | |
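| # The max(..., 1/6) floor keeps the denominator at one ball's worth of overs, avoiding division by zero on the final ball. | |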
| required_rr = runs_required / max(balls_left / 6, 1 / 6) | |
| chase_context = f"need={runs_required} balls_left={balls_left} rrr={required_rr:.2f}" | |
| else: | |
| chase_context = "need=None balls_left=None rrr=None" | |
| outcome_meta = (obs.last_outcome or {}).get("metadata", {}) | |
| tactical_context = "" | |
| if outcome_meta and action.tool in {"bowl_delivery", "play_delivery"}: | |
| delivery = outcome_meta.get("delivery_features", {}) | |
| tactical_context = ( | |
| f" event={outcome_meta.get('event_type')} zone={outcome_meta.get('target_area')} " | |
| f"traj={outcome_meta.get('trajectory')} field_effect={outcome_meta.get('fielder_effect')} " | |
| f"fit={outcome_meta.get('shot_delivery_fit')} field_pressure={outcome_meta.get('field_pressure')} " | |
| f"line={delivery.get('line')} length={delivery.get('length')} variation={delivery.get('variation')}" | |
| ) | |
| log( | |
| f"step={steps:03d} t_elapsed={elapsed:.3f}s step_dt={step_dt:.4f}s since_prev={since_prev:.4f}s " | |
| f"tool={action.tool} reward={(obs.reward or 0.0):.3f} " | |
| f"state={obs.game_state}/{obs.innings_type} phase={obs.strategic_phase} " | |
| f"over={obs.game_context.get('over')}.{obs.game_context.get('ball')} " | |
| f"score={obs.game_context.get('score')}/{obs.game_context.get('wickets')} " | |
| f"target={obs.target} rr={current_rr:.2f} {chase_context} " | |
| f"{tactical_context} " | |
| f"tools={obs.available_tools} " | |
| f"last={obs.last_ball_result[:140]!r}" | |
| ) | |
| steps += 1 | |
| state = env.state | |
| match_elapsed = time.perf_counter() - match_start | |
| log(f"\n--- match {match_idx + 1} final ---") | |
| log(f"done={obs.done} steps={steps} prompts_collected={len(prompts)} rollout_reward_sum={total_reward:.3f} match_elapsed={match_elapsed:.3f}s avg_step_dt={(match_elapsed / max(steps, 1)):.4f}s") | |
| log(f"score={state.total_score}/{state.wickets_lost} over={state.over}.{state.ball} target={state.target} game_state={state.game_state}") | |
| log(f"last_outcome={state.last_outcome}") | |
| log(f"match_result={state.match_result} reward_breakdown={state.reward_breakdown}") | |
| log(f"innings_rewards={state.innings_rewards}") | |
| log(f"tool_calls={state.tool_calls_made} dream11_scores={state.dream11_scores}") | |
| log(f"mean_coherence={(sum(state.coherence_scores) / len(state.coherence_scores)) if state.coherence_scores else 0.0:.3f}") | |
| log(f"mean_adaptation={(sum(state.adaptation_scores) / len(state.adaptation_scores)) if state.adaptation_scores else 0.0:.3f}") | |
| log(f"mean_opponent_awareness={(sum(state.opponent_awareness_scores) / len(state.opponent_awareness_scores)) if state.opponent_awareness_scores else 0.0:.3f}") | |
| print(f"\nwrote={output_path}") | |
| # Write README for the run | |
| readme_path = run_dir / "README.md" | |
| model_str = getattr(args, "model", None) or "heuristic (random actions)" | |
| readme_path.write_text( | |
| f"## Train-Smoke Run: {run_dir.name}\n\n" | |
| f"**Date**: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M')}\n\n" | |
| f"**Config**: `{getattr(args, 'config', None) or 'defaults'}`\n\n" | |
| f"| Setting | Value |\n|---|---|\n" | |
| f"| Matches | {args.matches} |\n" | |
| f"| Max overs | {args.max_overs} |\n" | |
| f"| Opponent mode | {args.opponent_mode} |\n" | |
| f"| Model (train target) | `{model_str}` |\n\n" | |
| f"See `run_output.txt` for full step-by-step rollout log, reward breakdowns, and coherence scores.\n" | |
| ) | |
| print(f"wrote={readme_path}") | |
| # ------------------------------------------------------------------ # | |
| # CLI # | |
| # ------------------------------------------------------------------ # | |
| def _apply_yaml_defaults(args, cfg: dict) -> None: | |
| """Merge YAML config values into args, CLI args take precedence.""" | |
| captain = cfg.get("captain", {}) or {} | |
| opponent = cfg.get("opponent", {}) or {} | |
| env_cfg = cfg.get("env", {}) or {} | |
| train_cfg = cfg.get("train", {}) or {} | |
| def _set(attr, val): | |
| if val is not None and getattr(args, attr, None) is None: | |
| setattr(args, attr, val) | |
| if getattr(args, "cmd", None) == "train": | |
| _set("model", train_cfg.get("model")) | |
| else: | |
| _set("model", captain.get("model")) | |
| _set("api_base", captain.get("api_base")) | |
| _set("api_key", os.environ.get(captain.get("api_key_env", "")) or None) | |
| _set("eval_pack_id", env_cfg.get("eval_pack_id")) | |
| _set("opponent_mode", opponent.get("mode")) | |
| _set("opponent_cache_path", opponent.get("cache_path")) | |
| _set("opponent_model", opponent.get("model")) | |
| _set("opponent_api_base", opponent.get("api_base")) | |
| api_key_env = opponent.get("api_key_env") | |
| _set("opponent_api_key", os.environ.get(api_key_env, "") if api_key_env else None) | |
| _set("max_overs", env_cfg.get("max_overs")) | |
| _set("agent_team", env_cfg.get("agent_team")) | |
| _set("steps", train_cfg.get("steps")) | |
| _set("prompts", train_cfg.get("prompts")) | |
| _set("batch_size", train_cfg.get("batch_size")) | |
| _set("grad_accum", train_cfg.get("grad_accum")) | |
| _set("stage", train_cfg.get("stage")) | |
| _set("num_generations", train_cfg.get("num_generations")) | |
| _set("max_completion_length", train_cfg.get("max_completion_length")) | |
| _set("max_tool_calling_iterations", train_cfg.get("max_tool_calling_iterations")) | |
| _set("logging_steps", train_cfg.get("logging_steps")) | |
| _set("report_to", train_cfg.get("report_to")) | |
| _set("run_name", train_cfg.get("run_name")) | |
| _set("learning_rate", train_cfg.get("learning_rate")) | |
| _set("beta", train_cfg.get("beta")) | |
| _set("temperature", train_cfg.get("temperature")) | |
| _set("top_p", train_cfg.get("top_p")) | |
| _set("gradient_checkpointing", train_cfg.get("gradient_checkpointing")) | |
| _set("gradient_checkpointing_use_reentrant", train_cfg.get("gradient_checkpointing_use_reentrant")) | |
| _set("dataloader_pin_memory", train_cfg.get("dataloader_pin_memory")) | |
| _set("dataloader_num_workers", train_cfg.get("dataloader_num_workers")) | |
| _set("bf16_base", train_cfg.get("bf16_base")) | |
| _set("save_steps", train_cfg.get("save_steps")) | |
| _set("save_total_limit", train_cfg.get("save_total_limit")) | |
| _set("resume_from", train_cfg.get("resume_from")) | |
| _set("overs_distribution", train_cfg.get("overs_distribution")) | |
| _set("use_vllm", train_cfg.get("use_vllm")) | |
| _set("vllm_gpu_memory", train_cfg.get("vllm_gpu_memory")) | |
| _set("vllm_model_impl", train_cfg.get("vllm_model_impl")) | |
| def main(): | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument("--config", default=None, help="YAML config path (sets defaults for all subcommands)") | |
| sub = parser.add_subparsers(dest="cmd") | |
| # train | |
| t = sub.add_parser("train", help="Run GRPO training") | |
| t.add_argument("--config", default=None) | |
| t.add_argument("--stage", type=int, default=None, choices=[1, 2]) | |
| t.add_argument("--model", default=None) | |
| t.add_argument("--prompts", type=int, default=None, help="Game state prompts to collect") | |
| t.add_argument("--steps", type=int, default=None, help="GRPOTrainer max_steps") | |
| t.add_argument("--batch-size", type=int, default=None, dest="batch_size") | |
| t.add_argument("--grad-accum", type=int, default=None, dest="grad_accum") | |
| t.add_argument("--num-generations", type=int, default=None, dest="num_generations") | |
| t.add_argument("--agent-team", default=None, dest="agent_team") | |
| t.add_argument("--opponent-mode", default=None, choices=["heuristic", "llm_live", "llm_cached", "cricsheet"], dest="opponent_mode") | |
| t.add_argument("--opponent-model", default=None, dest="opponent_model") | |
| t.add_argument("--opponent-api-base", default=None, dest="opponent_api_base") | |
| t.add_argument("--opponent-api-key", default=None, dest="opponent_api_key") | |
| t.add_argument("--max-overs", type=int, default=None, dest="max_overs") | |
| t.add_argument("--eval-pack-id", default=None, dest="eval_pack_id") | |
| t.add_argument("--opponent-cache-path", default=None, dest="opponent_cache_path") | |
| t.add_argument("--max-completion-length", type=int, default=None, dest="max_completion_length") | |
| t.add_argument("--max-tool-calling-iterations", type=int, default=None, dest="max_tool_calling_iterations") | |
| t.add_argument("--logging-steps", type=int, default=None, dest="logging_steps") | |
| t.add_argument("--report-to", default=None, dest="report_to") | |
| t.add_argument("--run-name", default=None, dest="run_name") | |
| t.add_argument("--seed", type=int, default=42) | |
| t.add_argument("--resume-from", default=None, dest="resume_from", | |
| help="Path to a previous LoRA adapter dir (e.g. ./checkpoints/stage2_final). " | |
| "When set, the adapter is loaded on top of the base model instead of a fresh init.") | |
| t.add_argument("--save-steps", type=int, default=None, dest="save_steps") | |
| t.add_argument("--save-total-limit", type=int, default=None, dest="save_total_limit") | |
| t.add_argument("--learning-rate", type=float, default=None, dest="learning_rate") | |
| t.add_argument("--beta", type=float, default=None, dest="beta", | |
| help="GRPO KL coefficient. Lower = more exploration.") | |
| t.add_argument("--temperature", type=float, default=None, dest="temperature") | |
| t.add_argument("--top-p", type=float, default=None, dest="top_p") | |
| t.add_argument("--gradient-checkpointing", action="store_true", dest="gradient_checkpointing", default=None) | |
| t.add_argument("--no-gradient-checkpointing", action="store_false", dest="gradient_checkpointing") | |
| t.add_argument("--gradient-checkpointing-use-reentrant", action="store_true", | |
| dest="gradient_checkpointing_use_reentrant", default=None) | |
| t.add_argument("--dataloader-pin-memory", action="store_true", dest="dataloader_pin_memory", default=None) | |
| t.add_argument("--dataloader-num-workers", type=int, default=None, dest="dataloader_num_workers") | |
| t.add_argument("--use-vllm", action="store_true", dest="use_vllm", default=None, | |
| help="Use vLLM-backed rollouts (colocate). Forces bf16 base.") | |
| t.add_argument("--bf16-base", action="store_true", dest="bf16_base", default=None, | |
| help="Load base model in bf16 instead of 4-bit NF4. Faster matmul on H200 since 4B fits in 8GB.") | |
| t.add_argument("--vllm-gpu-memory", type=float, default=0.5, dest="vllm_gpu_memory", | |
| help="Fraction of GPU memory reserved for vLLM (colocate). Default 0.5.") | |
| t.add_argument("--vllm-model-impl", default=None, dest="vllm_model_impl", | |
| choices=["transformers", "vllm"], | |
| help="vLLM model backend. None (default) = native vLLM class (e.g. Qwen3ForCausalLM); " | |
| "'transformers' = HF transformers backend (workaround for Qwen3.5 — flaky with LoRA).") | |
| # eval | |
| e = sub.add_parser("eval", help="Evaluate a checkpoint") | |
| e.add_argument("--config", default=None) | |
| e.add_argument("--model", default=None) | |
| e.add_argument("--eval-episodes", type=int, default=10, dest="eval_episodes") | |
| e.add_argument("--agent-team", default=None, dest="agent_team") | |
| e.add_argument("--seed", type=int, default=0) | |
| # quick test (no GPU needed) | |
| sub.add_parser("test", help="Smoke-test reward functions") | |
| smoke = sub.add_parser("train-smoke", help="Run short direct-env training rollouts without loading a model") | |
| smoke.add_argument("--config", default=None) | |
| smoke.add_argument("--matches", type=int, default=1) | |
| smoke.add_argument("--max-overs", type=int, default=None, dest="max_overs") | |
| smoke.add_argument("--max-steps", type=int, default=240, dest="max_steps") | |
| smoke.add_argument("--log-steps", type=int, default=30, dest="log_steps") | |
| smoke.add_argument("--eval-pack-id", default=None, dest="eval_pack_id") | |
| smoke.add_argument("--opponent-mode", default=None, choices=["heuristic", "llm_live", "llm_cached", "cricsheet"], dest="opponent_mode") | |
| smoke.add_argument("--opponent-cache-path", default=None, dest="opponent_cache_path") | |
| smoke.add_argument("--agent-team", default=None, dest="agent_team") | |
| smoke.add_argument("--output", default=None) | |
| smoke.add_argument("--seed", type=int, default=42) | |
| sft = sub.add_parser("sft-data", help="Generate Stage 0 supervised tool-format examples") | |
| sft.add_argument("--output", default="./data/training/tool_sft_examples.jsonl") | |
| sft.add_argument("--examples", type=int, default=240) | |
| sft.add_argument("--agent-team", default=None, dest="agent_team") | |
| sft.add_argument("--seed", type=int, default=42) | |
| args = parser.parse_args() | |
| # Apply YAML config (subcommand --config overrides top-level --config) | |
| config_path = getattr(args, "config", None) or getattr(parser.parse_known_args()[0], "config", None) | |
| if config_path: | |
| try: | |
| from config_yaml import load_config | |
| except ImportError: | |
| from cricket_captain.config_yaml import load_config | |
| _apply_yaml_defaults(args, load_config(config_path)) | |
| # Set safe defaults after YAML merge | |
| if getattr(args, "stage", None) is None: | |
| args.stage = 1 | |
| if getattr(args, "model", None) is None: | |
| args.model = "Qwen/Qwen3.5-4B" | |
| if getattr(args, "steps", None) is None: | |
| args.steps = 200 | |
| if getattr(args, "prompts", None) is None: | |
| args.prompts = 500 | |
| if getattr(args, "batch_size", None) is None: | |
| args.batch_size = 2 | |
| if getattr(args, "grad_accum", None) is None: | |
| args.grad_accum = 4 | |
| if getattr(args, "eval_pack_id", None) is None: | |
| args.eval_pack_id = "adaptive_t20_v1" | |
| if getattr(args, "opponent_mode", None) is None: | |
| args.opponent_mode = "llm_live" | |
| if getattr(args, "max_overs", None) is None: | |
| args.max_overs = 5 | |
| if getattr(args, "agent_team", None) is None: | |
| args.agent_team = os.environ.get("CRICKET_AGENT_TEAM") | |
| if getattr(args, "max_tool_calling_iterations", None) is None: | |
| args.max_tool_calling_iterations = 200 | |
| if getattr(args, "logging_steps", None) is None: | |
| args.logging_steps = 1 | |
| if getattr(args, "report_to", None) is None: | |
| args.report_to = "none" | |
| if args.cmd == "train": | |
| train(args) | |
| elif args.cmd == "eval": | |
| evaluate(args) | |
| elif args.cmd == "test": | |
| _smoke_test(args.agent_team, args.opponent_mode) | |
| elif args.cmd == "train-smoke": | |
| train_smoke(args) | |
| elif args.cmd == "sft-data": | |
| generate_sft_examples(args.output, args.examples, args.seed, args.agent_team) | |
| else: | |
| parser.print_help() | |
| def _smoke_test(agent_team: str | None, opponent_mode: str): | |
| """Verify reward functions work correctly.""" | |
| cases = [ | |
| ( | |
| "[CricketCaptain] BATTING | SECOND INNINGS\nPhase: MIDDLE | Bowler: SPIN\n" | |
| "Batting Strategy: consolidate aggression=0.30 rotate strike against spin in middle overs", | |
| '{"tool": "play_delivery", "arguments": {"shot_intent": "single", "explanation": "rotating"}}', | |
| "high", | |
| ), | |
| ( | |
| "[CricketCaptain] BATTING | SECOND INNINGS\nPhase: MIDDLE | Bowler: SPIN\n" | |
| "Batting Strategy: consolidate aggression=0.30 rotate strike against spin in middle overs", | |
| '{"tool": "play_delivery", "arguments": {"shot_intent": "six", "explanation": "going big"}}', | |
| "low", | |
| ), | |
| ( | |
| "[CricketCaptain] BATTING | FIRST INNINGS\nPhase: POWERPLAY\nBatting Strategy: None", | |
| '{"tool": "play_delivery", "arguments": {"shot_intent": "boundary", "explanation": "attack"}}', | |
| "zero", | |
| ), | |
| ( | |
| "[CricketCaptain] BATTING | SECOND INNINGS\nPhase: DEATH\nBatting Strategy: None", | |
| "not valid json at all", | |
| "zero", | |
| ), | |
| ] | |
| print("Reward function smoke test:\n") | |
| for prompt, completion, expected in cases: | |
| fmt = r_validity(completion) | |
| coh = r_behavior_stateless(prompt, completion) | |
| print(f" expected={expected:4s} | fmt={fmt:.0f} | coh={coh:.3f} | {completion[:60]}") | |
| print("\nPrompt collection test (5 prompts):") | |
| p = collect_prompts(5, task="stage1_format", seed=1, agent_team=agent_team, opponent_mode=opponent_mode) | |
| for i, pp in enumerate(p): | |
| print(f" [{i}] {pp[:80].strip()} …") | |
| if __name__ == "__main__": | |
| main() | |