""" MT-GRPO training script for CricketCaptain-LLM. Two-stage curriculum (ToolRL-style): Stage 1: tool-call mastery — emphasize valid, phase-legal tool usage Stage 2: strategic behavior — full environment-backed reward (result + cricket + behavior + validity) Design: - Training uses TRL GRPO with environment_factory=CricketCaptainToolEnv - The model interacts with live CricketEnvironment instances over multi-turn tool calls - Rewards are collected from the environment (environment_reward), not only from stateless prompt parsing - The opponent policy is part of the environment: heuristic/cricsheet/llm_live/llm_cached - Plain TRL + Transformers + bitsandbytes + PEFT (LoRA adapters for 4-bit models) Usage (canonical Qwen3 setup): python train.py train --config configs/cricket_train_qwen3_warmup.yaml # warmup python train.py train --config configs/cricket_train_qwen3.yaml # main 5-over Legacy Qwen3.5 configs live in configs/extras/. """ import argparse import copy import datetime import json import os import random import re import sys import threading import time from pathlib import Path from typing import Any # ------------------------------------------------------------------ # # Optional training imports # # ------------------------------------------------------------------ # try: import torch from datasets import Dataset from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training from trl import GRPOConfig, GRPOTrainer from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig _TRAIN_IMPORTS_AVAILABLE = True except ImportError: torch = None Dataset = None LoraConfig = None get_peft_model = None prepare_model_for_kbit_training = None GRPOConfig = None GRPOTrainer = None AutoModelForCausalLM = None AutoTokenizer = None BitsAndBytesConfig = None _TRAIN_IMPORTS_AVAILABLE = False try: from server.cricket_environment import CricketEnvironment from server.coherence_grader import aggression_match, phase_appropriate, rationale_specificity from server.markov_engine import SHOT_AGGRESSION from server.player_roster import build_playing_xi, load_team_roster from models import CricketAction from config_yaml import get_game_constants, get_reward_weights except ImportError: from cricket_captain.server.cricket_environment import CricketEnvironment from cricket_captain.server.coherence_grader import aggression_match, phase_appropriate, rationale_specificity from cricket_captain.server.markov_engine import SHOT_AGGRESSION from cricket_captain.server.player_roster import build_playing_xi, load_team_roster from cricket_captain.models import CricketAction from cricket_captain.config_yaml import get_game_constants, get_reward_weights # Load game knowledge once at import time (cached in config_yaml). 
_GK = get_game_constants()
_RW = get_reward_weights()

# ------------------------------------------------------------------ #
# Prompt parsing (stateless — reads from rendered observation text)  #
# ------------------------------------------------------------------ #

_STRATEGY_RE = re.compile(r"Batting Strategy:\s*(.+)$", re.MULTILINE)
_AGGRESSION_RE = re.compile(r"aggression[\"'=:\s]+([0-9.]+)", re.IGNORECASE)
_PHASE_RE = re.compile(r"Phase:\s+(POWERPLAY|MIDDLE|DEATH)", re.IGNORECASE)

_VALID_TOOLS = {
    "call_toss",
    "select_batter",
    "set_strategy",
    "plan_shot",
    "play_delivery",
    "choose_bowler",
    "set_bowling_strategy",
    "plan_delivery",
    "set_field_setting",
    "bowl_delivery",
    "reflect_after_ball",
    "analyze_situation",
    "set_match_plan",
    "update_match_plan",
}


def extract_strategy_from_prompt(prompt: str) -> dict | None:
    m = _STRATEGY_RE.search(prompt)
    if not m or m.group(1).strip().lower().startswith("none"):
        return None
    agg_m = _AGGRESSION_RE.search(prompt)
    agg = float(agg_m.group(1)) if agg_m else 0.5
    return {"phase_intent": m.group(1).strip(), "aggression": agg, "rationale": m.group(1).strip()}


def extract_phase_from_prompt(prompt: str) -> str:
    m = _PHASE_RE.search(prompt)
    return m.group(1).lower() if m else "middle"


# ------------------------------------------------------------------ #
# Per-turn reward components (all stateless)                         #
# ------------------------------------------------------------------ #

_XML_FN_RE = re.compile(r"<function\s*=\s*([^<>\s]+)\s*>", re.IGNORECASE)
_XML_PARAM_RE = re.compile(r"<parameter\s*=\s*([^<>\s]+)\s*>(.*?)</parameter>", re.IGNORECASE | re.DOTALL)


def _parse_completion(raw: str) -> dict | None:
    """Parse a tool-call from the raw completion into our canonical {tool, arguments} dict.

    Handles four common model output patterns:
      1. Plain JSON (ideal).
      2. Markdown code block (```json ... ```).
      3. Thinking-model preamble: <think>...</think> followed by JSON. Qwen3/Qwen3.5 in
         default mode emits reasoning inside <think> tags; we strip everything up to and
         including the closing </think> tag.
      4. XML function-call format that Qwen3.5 was trained on:
         <function=foo><parameter=param>bar</parameter></function>
         Empirically (see logs/run_2026-04-25_21-08-45) every Stage-1 completion emitted
         this XML form instead of JSON — so we extract it as a fallback to give GRPO a
         non-zero gradient before the model has been trained onto the JSON contract.
    """
    raw = raw.strip()
    # Strip the <think>...</think> preamble emitted by thinking-mode models.
    if "</think>" in raw:
        think_end = raw.rfind("</think>")
        if think_end != -1:
            raw = raw[think_end + len("</think>"):].strip()
    if raw.startswith("```"):
        lines = raw.split("\n")
        raw = "\n".join(lines[1:-1]) if len(lines) > 2 else raw
    # Try parsing the whole string, then fall back to the first {...} block.
    try:
        return json.loads(raw)
    except (json.JSONDecodeError, ValueError):
        pass
    start = raw.find("{")
    end = raw.rfind("}")
    if start != -1 and end > start:
        try:
            return json.loads(raw[start : end + 1])
        except (json.JSONDecodeError, ValueError):
            pass
    # XML function-call fallback (Qwen3.5 default tool-call emission style).
    fn_match = _XML_FN_RE.search(raw)
    if fn_match:
        tool = fn_match.group(1).strip().strip("\"'")
        arguments: dict[str, Any] = {}
        for pname, pval in _XML_PARAM_RE.findall(raw):
            v = pval.strip()
            # Coerce numeric/bool literals so downstream validators accept them.
            try:
                arguments[pname] = json.loads(v)
            except (json.JSONDecodeError, ValueError):
                arguments[pname] = v
        return {"tool": tool, "arguments": arguments}
    return None


# Bounded LRU-ish cache.
Each snapshot is a deepcopy of CricketEnvironment # (~1 MB) and only used by the LEGACY single-turn r_environment_rollout path, # not by the multi-turn environment_factory training path. Cap at 4096 entries # (~4 GB worst case) so a long collect_prompts call can't blow up host RAM. _PROMPT_ENV_SNAPSHOTS: dict[str, CricketEnvironment] = {} _PROMPT_SNAPSHOT_CAP = 4096 _ENV_REWARD_ROLLOUT_STEPS = 12 def _remember_prompt(obs_text: str, env: CricketEnvironment) -> str: """Format an observation and keep the exact env state for rollout reward.""" prompt = _format_prompt(obs_text) if len(_PROMPT_ENV_SNAPSHOTS) >= _PROMPT_SNAPSHOT_CAP: # Evict oldest insertion (dict preserves insertion order in py3.7+). oldest_key = next(iter(_PROMPT_ENV_SNAPSHOTS)) del _PROMPT_ENV_SNAPSHOTS[oldest_key] _PROMPT_ENV_SNAPSHOTS[prompt] = copy.deepcopy(env) return prompt def r_environment_rollout(prompt: str, completion: str) -> float | None: """Env-backed score for a generated tool call plus short continuation. Returns None when the prompt was not collected from an env snapshot, allowing callers to fall back to stateless scoring. Otherwise returns [0, 1], where 0 means invalid JSON/tool-for-state and higher values reflect the env reward. """ snapshot = _PROMPT_ENV_SNAPSHOTS.get(prompt) if snapshot is None: return None data = _parse_completion(completion) if data is None: return 0.0 tool = data.get("tool", "") args = data.get("arguments", {}) if not isinstance(args, dict): return 0.0 env = copy.deepcopy(snapshot) if tool not in env._get_available_tools(): return 0.0 try: obs = env.step(CricketAction(tool=tool, arguments=args)) except Exception: return 0.0 reward = float(obs.reward or 0.0) rng = random.Random(hash(prompt + completion) & 0xFFFFFFFF) roster = build_playing_xi(getattr(env, "_agent_roster", [])) for _ in range(_ENV_REWARD_ROLLOUT_STEPS): if obs.done: break action = _random_action( rng, obs.game_state, obs.available_tools, obs.current_bowler.get("type") if obs.current_bowler else None, roster, ) obs = env.step(action) reward += float(obs.reward or 0.0) if obs.done and env.state.reward_breakdown: reward += float(env.state.reward_breakdown.get("composite", 0.0)) # Map rollout reward into [0,1] while preserving penalties for bad tool choices. return round(max(0.0, min(1.0, 0.5 + reward)), 4) def r_validity(completion: str) -> float: """Schema reward for tool calling. Exact env-executable calls receive 1.0. Malformed but parseable JSON gets a small shaping signal so early GRPO has non-zero variance before the model has learned the strict `{"tool": ..., "arguments": {...}}` contract. """ data = _parse_completion(completion) if data is None: return 0.0 if not isinstance(data, dict): return 0.05 tool = data.get("tool", "") args = data.get("arguments", {}) if "tool" not in data or "arguments" not in data: return 0.15 if tool not in _VALID_TOOLS: return 0.25 if not isinstance(args, dict): return 0.35 if tool == "play_delivery" and args.get("shot_intent") not in SHOT_AGGRESSION: return 0.5 if tool == "set_strategy": agg = args.get("aggression") if not isinstance(agg, (int, float)): return 0.5 if tool == "plan_shot" and args.get("shot_intent") not in SHOT_AGGRESSION: return 0.5 if tool in {"choose_bowler", "set_bowling_strategy", "plan_delivery"}: if args.get("bowler_type") not in (None, "pace", "spin"): return 0.5 return 1.0 # Kept for backward compatibility with smoke test r_format = r_validity def _bowling_phase_fit(delivery_type: str, phase: str) -> float: """Return 1.0 if delivery_type fits the phase, else 0.4. 
Loaded from game_knowledge.yaml.""" valid = _GK.bowling_phase_delivery.get(phase, []) return 1.0 if delivery_type in valid else 0.4 def _field_phase_fit(setting: str, phase: str) -> float: """Return phase-fit score for a field preset. Loaded from game_knowledge.yaml.""" return float(_GK.field_phase_fit.get(setting, {}).get(phase, 0.5)) def r_behavior_stateless(prompt: str, completion: str) -> float: """ r_behavior: plan-action coherence score, covering ALL tool types. Previously only graded play_delivery, leaving ~60% of decision points unscored. Now each tool family gets an appropriate coherence signal: - play_delivery / plan_shot : aggression-match + rationale + phase fit - set_strategy : phase appropriateness + rationale quality - set_bowling_strategy / plan_delivery : delivery-phase fit + rationale - set_field_setting : field-phase fit - reflect_after_ball : rationale specificity - choose_bowler / select_batter : rationale specificity - bowl_delivery / call_toss / analyze_situation : not graded (no plan) """ data = _parse_completion(completion) if data is None: return 0.0 tool = data.get("tool", "") args = data.get("arguments", {}) strategy = extract_strategy_from_prompt(prompt) phase = extract_phase_from_prompt(prompt) # Base reward for any valid structured action — ensures GRPO always has a # positive gradient to reinforce correct tool use even when coherence can't # be fully scored (no declared strategy, unscored tool types, etc.). _BASE = 0.10 if tool == "play_delivery": shot = args.get("shot_intent", "") if shot not in SHOT_AGGRESSION: return 0.0 if strategy is None: return _BASE # valid shot, no declared strategy to align against agg = strategy["aggression"] a_match = aggression_match(agg, shot) r_spec = rationale_specificity(strategy.get("rationale", "")) p_approp = phase_appropriate(agg, phase) return round(_BASE + (1 - _BASE) * (0.50 * a_match + 0.30 * r_spec + 0.20 * p_approp), 4) if tool == "plan_shot": shot = args.get("shot_intent", "") if shot not in SHOT_AGGRESSION: return 0.0 if strategy is None: return _BASE # valid plan structure, no context to grade against agg = strategy["aggression"] a_match = aggression_match(agg, shot) r_spec = rationale_specificity(args.get("rationale", "")) return round(_BASE + (1 - _BASE) * (0.60 * a_match + 0.40 * r_spec), 4) if tool == "set_strategy": agg = args.get("aggression", 0.5) try: agg = float(agg) except (TypeError, ValueError): agg = 0.5 r_spec = rationale_specificity(args.get("rationale", "")) p_approp = phase_appropriate(agg, phase) return round(0.50 * p_approp + 0.50 * r_spec, 4) if tool in {"set_bowling_strategy", "plan_delivery"}: delivery_type = args.get("delivery_type", "") r_spec = rationale_specificity(args.get("rationale", "")) p_fit = _bowling_phase_fit(delivery_type, phase) return round(0.50 * r_spec + 0.50 * p_fit, 4) if tool == "set_field_setting": setting = args.get("setting", "Balanced") return round(_field_phase_fit(setting, phase), 4) if tool == "reflect_after_ball": return round(rationale_specificity(args.get("reflection", "")), 4) if tool in {"choose_bowler", "select_batter"}: return round(rationale_specificity(args.get("rationale", "")), 4) if tool == "set_match_plan": # Score completeness + rationale quality fields = ["powerplay_intent", "middle_intent", "death_intent", "risk_budget", "trigger_conditions"] completeness = sum(1 for f in fields if args.get(f, "").strip()) / len(fields) r_spec = rationale_specificity(args.get("rationale", "")) return round(0.6 * completeness + 0.4 * r_spec, 4) if tool == 
"update_match_plan": # Score whether update is justified by a match-state trigger reason = args.get("reason", args.get("rationale", "")) triggers = ["wicket", "target", "rrr", "phase", "field", "rate", "pressure", "boundary", "dot"] hits = sum(1 for t in triggers if t in reason.lower()) r_spec = rationale_specificity(reason) return round(min(1.0, 0.5 * r_spec + 0.5 * min(hits / 3, 1.0)), 4) # bowl_delivery, call_toss, analyze_situation — structurally valid but no # coherence plan to grade. Return a small base so GRPO distinguishes these # from invalid JSON (which scores 0.0) without over-weighting them. if tool in {"bowl_delivery", "call_toss", "analyze_situation"}: return 0.15 return 0.0 def r_adaptation_stateless(prompt: str, completion: str) -> float: if r_validity(completion) == 0.0: return 0.0 # don't reward invalid tool calls for context-matching data = _parse_completion(completion) if data is None: return 0.0 text = json.dumps(data.get("arguments", {})).lower() phase = extract_phase_from_prompt(prompt) score = 0.0 if phase in text: score += 0.25 if any(word in prompt.lower() for word in ("target:", "death", "wicket", "opponent last plan")): score += 0.25 if any(word in text for word in ("adjust", "target", "field", "phase", "wicket", "pressure", "matchup")): score += 0.25 if data.get("tool") in {"plan_shot", "plan_delivery", "reflect_after_ball", "choose_bowler", "select_batter"}: score += 0.25 return round(min(score, 1.0), 4) def r_opponent_awareness_stateless(prompt: str, completion: str) -> float: if r_validity(completion) == 0.0: return 0.0 # don't reward invalid tool calls for context-matching data = _parse_completion(completion) if data is None: return 0.0 text = json.dumps(data.get("arguments", {})).lower() prompt_l = prompt.lower() hits = 0 for word in ("opponent", "field", "bowler", "batter", "spin", "pace", "aggressive", "defensive"): if word in prompt_l and word in text: hits += 1 return round(min(hits / 3, 1.0), 4) # ------------------------------------------------------------------ # # Composite reward function — TRL 0.24 signature # # ------------------------------------------------------------------ # def make_reward_fn(curriculum_stage: int): """ Returns reward_fn(prompts, completions, **kwargs) → list[float]. Weights align with compute_episode_reward in reward_calculator.py: r_env — one-step env rollout reward when prompt snapshot exists r_behavior — stateless tactical/tool coherence r_validity — JSON/tool schema validity """ # Minimum reward for any structurally valid completion — ensures GRPO has a # positive gradient to reinforce valid tool use even for unscored tool types. _VALID_FLOOR = 0.05 def reward_fn(prompts: list[str], completions: list[str], **kwargs) -> list[float]: rewards = [] for prompt, completion in zip(prompts, completions): fmt = r_validity(completion) env_score = r_environment_rollout(prompt, completion) if curriculum_stage == 1: # Length-efficiency penalty: a valid JSON tool call is ≤400 chars. # Models with thinking mode (Qwen3/3.5) generate 800-2000 char # preambles before the JSON; penalise that verbosity so GRPO # learns to emit short, direct JSON. The penalty scales from # 1.0 at ≤400 chars to 0.0 at ≥2400 chars (linear). 
                _JSON_TARGET = 400
                _RAMP_RANGE = 2000
                length_eff = max(0.0, 1.0 - max(0, len(completion) - _JSON_TARGET) / _RAMP_RANGE)
                base = 0.5 * fmt + 0.5 * (env_score if env_score is not None else fmt)
                rewards.append(round(length_eff * base, 4))
                continue
            behavior = r_behavior_stateless(prompt, completion)
            adapt = r_adaptation_stateless(prompt, completion)
            aware = r_opponent_awareness_stateless(prompt, completion)
            r_beh = (
                _RW.behavior_coherence * behavior
                + _RW.behavior_adaptation * adapt
                + _RW.behavior_opponent_awareness * aware
            )
            if env_score is None:
                reward = _RW.training_behavior * r_beh + _RW.training_validity * fmt
            else:
                reward = 0.45 * env_score + 0.40 * r_beh + 0.15 * fmt
            # Floor: valid JSON should always beat invalid JSON (reward=0)
            if fmt > 0.0 and (env_score is None or env_score > 0.0):
                reward = max(reward, _VALID_FLOOR)
            rewards.append(round(reward, 4))
        return rewards

    reward_fn.__name__ = f"stage{curriculum_stage}_reward"
    return reward_fn


# ------------------------------------------------------------------ #
# Prompt collection (direct env instantiation — no server needed)    #
# ------------------------------------------------------------------ #

SYSTEM_PROMPT = (
    "You are an expert adaptive cricket captain. Each turn you receive a scorecard "
    "and must choose exactly one cricket captaincy tool call.\n\n"
    "EXECUTE FIRST — strict rule:\n"
    " - The match only progresses when you call `play_delivery` (batting) or\n"
    " `bowl_delivery` (bowling). Every other tool is overhead.\n"
    " - Default action on EVERY ball: call `play_delivery` / `bowl_delivery` with\n"
    " plan args INLINE: e.g. `play_delivery(shot_intent='single', risk='low', rationale='rotate')`\n"
    " or `bowl_delivery(line='outside_off', length='good', delivery_type='stock')`.\n"
    " - Use `set_match_plan` ONCE at the very start of an innings to declare strategy.\n"
    " - Use `set_strategy` / `set_bowling_strategy` ONCE per phase boundary.\n"
    " - DO NOT call `plan_shot` or `plan_delivery` (deprecated) — they only add a\n"
    " wasted turn. Pass the same parameters to play_delivery / bowl_delivery directly.\n"
    " - SKIP `reflect_after_ball` unless the previous ball was a wicket or boundary.\n"
    " - You are scored on MATCH OUTCOMES, not on philosophical depth. Bloated\n"
    " pre-ball planning truncates the episode and you forfeit the result reward.\n\n"
    "THINKING BUDGET — HARD LIMIT:\n"
    " - Per turn: ONE sentence of reasoning, max 30 tokens, inside <think>...</think>.\n"
    " - Do NOT enumerate options, restate the scorecard, or re-derive the plan.\n"
    " - Bad: 'This is the first ball, the field is balanced, Kohli is on strike at 0.45 aggression, I should consider...'\n"
    " - Good: 'Powerplay, balanced field — single to rotate.'\n"
    " - Token budget per rollout is finite. Long thinking = match truncated = ZERO result reward.\n"
    " - The plan you set at the start carries the strategy; do not re-derive it every ball.\n\n"
    "Emit exactly one tool call wrapped in <tool_call>...</tool_call> XML tags. "
" "Bare JSON without the wrapper is NOT recognized and will end the rollout.\n" 'Example: {"name": "play_delivery", "arguments": {"shot_intent": "single", "explanation": "rotate strike"}}\n\n' "Available tools:\n" " call_toss — Call heads/tails and choose bat/bowl\n" " select_batter — Choose batter profile for the match situation\n" " set_strategy — Declare batting intent (aggression 0–1, rationale)\n" " plan_shot — Pre-ball batting plan\n" " play_delivery — Choose a shot and advance the game\n" " choose_bowler — Choose bowler profile for the situation\n" " set_bowling_strategy — Declare bowling line/length/type/rationale\n" " plan_delivery — Pre-ball bowling plan\n" " set_field_setting — Aggressive/Balanced/Defensive field\n" " bowl_delivery — Execute the delivery\n" " reflect_after_ball — Adapt after the previous ball\n" " analyze_situation — Query pitch/bowler/field info\n\n" "Shot intents: leave | defensive | single | rotate | boundary | six\n\n" "PRIORITIES (in order): (1) finish the match, (2) win the match, (3) score well per ball.\n" "Verbose reasoning forfeits all three. Decide fast, act, move on." ) def get_system_prompt(stage: int = 2) -> str: return SYSTEM_PROMPT _RANDOM_SHOTS = list(SHOT_AGGRESSION.keys()) _RANDOM_QUERIES = ["pitch_conditions", "bowler_info", "field_setting", "match_situation"] _RANDOM_ZONES = ["cover", "point", "straight", "midwicket", "square_leg", "fine_leg", "long_on", "long_off"] def _training_roster(agent_team: str | None = None) -> list[dict]: team = agent_team or os.environ.get("CRICKET_AGENT_TEAM") if not team: raise ValueError("Roster-backed training requires --agent-team or CRICKET_AGENT_TEAM.") roster = load_team_roster(team) if not roster: raise ValueError(f"No player profile roster found for agent team '{team}'.") playing_xi = build_playing_xi(roster) if len(playing_xi) < 11: raise ValueError(f"Player profile roster for '{team}' could not produce a playing XI.") return playing_xi def _sample_batter(rng: random.Random, roster: list[dict]) -> dict: batters = [p for p in roster if p.get("role") != "bowler"] or roster if not batters: raise ValueError("Roster-backed training requires at least one batting-capable player.") return rng.choice(batters) def _sample_bowler(rng: random.Random, roster: list[dict]) -> dict: bowlers = [p for p in roster if p.get("bowler_type")] if not bowlers: raise ValueError("Roster-backed training requires at least one bowling-capable player.") return rng.choice(bowlers) def _random_action( rng: random.Random, game_state: str = "batting", available_tools: list[str] | None = None, current_bowler_type: str | None = None, roster: list[dict] | None = None, ) -> CricketAction: legal = set(available_tools or []) def can(tool: str) -> bool: return available_tools is None or tool in legal def match_plan_action() -> CricketAction: return CricketAction(tool="set_match_plan", arguments={ "powerplay_intent": "Use roster strengths to establish tempo while protecting wickets", "middle_intent": "Rotate strike, attack favorable matchups, and preserve finishers", "death_intent": "Commit boundary options with wickets and target pressure in mind", "risk_budget": "Escalate only when phase, target, and wickets justify the risk", "trigger_conditions": "Review after wicket clusters, phase changes, target pressure, or repeated boundary/dot outcomes", "rationale": "Create a long-horizon plan before choosing ball-by-ball tactics", }) if game_state == "toss": return CricketAction( tool="call_toss", arguments={"call": rng.choice(["heads", "tails"]), 
"decision": rng.choice(["bat", "bowl"])}, ) if can("set_match_plan") and rng.random() < 0.12: return match_plan_action() if can("update_match_plan") and rng.random() < 0.08: return CricketAction(tool="update_match_plan", arguments={ "reason": "Adjust plan after phase, score pressure, wickets, and field information", "risk_budget": "Shift risk based on current target pressure and wickets in hand", }) if game_state == "bowling": choice = rng.random() if choice < 0.15 and can("choose_bowler"): bowler = _sample_bowler(rng, roster or []) return CricketAction( tool="choose_bowler", arguments={ "name": bowler["name"], "bowler_type": bowler["bowler_type"], "style": bowler.get("bowl_style", bowler.get("style", "stock")), "rationale": "Match roster bowler to phase, batter matchup, and remaining overs", }, ) if choice < 0.35 and can("plan_delivery"): return CricketAction( tool="plan_delivery", arguments={ "bowler_type": current_bowler_type or rng.choice(["pace", "spin"]), "line": rng.choice(["stumps", "outside off", "wide"]), "length": rng.choice(["good length", "full", "short", "yorker"]), "delivery_type": rng.choice(["stock", "yorker", "bouncer", "slower ball"]), "rationale": "Use field and batter style to control scoring zones", }, ) if choice < 0.5 and can("set_field_setting"): return CricketAction(tool="set_field_setting", arguments={"setting": rng.choice(["Aggressive", "Balanced", "Defensive"])}) if choice < 0.6 and can("reflect_after_ball"): return CricketAction(tool="reflect_after_ball", arguments={"reflection": "Adjust line and field after the last ball"}) if can("bowl_delivery"): return CricketAction(tool="bowl_delivery", arguments={}) if can("set_bowling_strategy"): return CricketAction(tool="set_bowling_strategy", arguments={ "bowler_type": current_bowler_type or "pace", "line": "outside off", "length": "good length", "delivery_type": "stock", "rationale": "Set a legal bowling plan before executing the delivery", }) raise ValueError(f"No legal bowling action available from tools={available_tools}") choice = rng.random() if choice < 0.15 and can("select_batter"): batter = _sample_batter(rng, roster or []) return CricketAction( tool="select_batter", arguments={ "name": batter["name"], "style": batter.get("style", "balanced"), "aggression": round(float(batter["aggression"]), 2), "rationale": "Select batter based on phase, wickets, and target pressure", }, ) if choice < 0.3 and can("set_strategy"): return CricketAction( tool="set_strategy", arguments={ "phase_intent": rng.choice(["attack", "consolidate", "rotate"]), "aggression": round(rng.uniform(0.1, 0.9), 2), "rationale": "Align roster strengths with phase, target pressure, and wickets", }, ) if choice < 0.45 and can("plan_shot"): return CricketAction( tool="plan_shot", arguments={ "shot_intent": rng.choice(_RANDOM_SHOTS), "target_area": rng.choice(_RANDOM_ZONES), "trajectory": rng.choice(["ground", "lofted", "aerial"]), "risk": rng.choice(["low", "balanced", "high"]), "rationale": "Plan shot against bowler, field, and required rate", }, ) if choice < 0.55 and can("analyze_situation"): return CricketAction( tool="analyze_situation", arguments={"query_type": rng.choice(_RANDOM_QUERIES)}, ) if choice < 0.65 and can("reflect_after_ball"): return CricketAction(tool="reflect_after_ball", arguments={"reflection": "Revise risk after previous ball"}) if can("play_delivery"): return CricketAction( tool="play_delivery", arguments={"shot_intent": rng.choice(_RANDOM_SHOTS), "explanation": "Advance the innings according to the current plan"}, ) raise 
ValueError(f"No legal batting action available from tools={available_tools}") def collect_prompts( n_prompts: int, task: str = "stage2_full", seed: int = 42, agent_team: str | None = None, opponent_mode: str = "heuristic", ) -> list[str]: """ Collect game-state prompts by running episodes with random actions. Returns a list of prompt strings (one per game state observation). """ rng = random.Random(seed) roster = _training_roster(agent_team) _PROMPT_ENV_SNAPSHOTS.clear() prompts: list[str] = [] ep_count = 0 while len(prompts) < n_prompts: env = CricketEnvironment() obs = env.reset(seed=rng.randint(0, 99999), options={ "task": task, "random_start": True, "agent_team": agent_team or os.environ.get("CRICKET_AGENT_TEAM"), "opponent_mode": opponent_mode, }) prompts.append(_remember_prompt(obs.prompt_text, env)) steps = 0 while not obs.done and steps < 80: action = _random_action( rng, obs.game_state, obs.available_tools, obs.current_bowler.get("type") if obs.current_bowler else None, roster, ) obs = env.step(action) if not obs.done: prompts.append(_remember_prompt(obs.prompt_text, env)) steps += 1 ep_count += 1 if ep_count % 10 == 0: print(f" Collected {len(prompts)} prompts from {ep_count} episodes …", flush=True) print(f"Collected {len(prompts)} prompts from {ep_count} episodes.") return prompts[:n_prompts] def _format_prompt(obs_text: str) -> str: """Wrap the observation in a chat-style user message.""" return f"<|im_start|>system\n{SYSTEM_PROMPT}<|im_end|>\n<|im_start|>user\n{obs_text}<|im_end|>\n<|im_start|>assistant\n" def build_dataset(prompts: list[str]) -> Dataset: if Dataset is None: raise ImportError("datasets is required for training. Install with: pip install '.[train]'") return Dataset.from_dict({"prompt": prompts}) class CricketCaptainToolEnv: """TRL environment wrapper exposing CricketCaptain actions as real tools.""" _stats_lock = threading.Lock() def __init__(self): self.env = CricketEnvironment() self.reward = 0.0 self.done = False self.final_reward = 0.0 self._episode_seed: int | None = None self._episode_started = False self._max_tool_iters: int | None = None self._episode_had_step = False self._episode_logged = False def _maybe_log_episode_end(self, termination_reason: str): # Avoid double-logging the same episode (e.g. once at termination, again on reset()). 
if self._episode_logged: return stats_path = os.environ.get("CRICKET_EPISODE_STATS_PATH") if not stats_path: return state = getattr(self.env, "state", None) payload = { "ts": datetime.datetime.now().isoformat(), "seed": self._episode_seed, "done": bool(self.done), "termination_reason": termination_reason, "reward_running_sum": float(self.reward), "final_reward_bonus": float(self.final_reward), } if state is not None: # ---- match config / context ---- payload["max_overs"] = getattr(state, "max_overs", None) payload["opponent_mode"] = getattr(state, "opponent_mode", None) payload["agent_team"] = getattr(state, "eval_pack_id", None) or getattr(state, "agent_team", None) payload["innings_type"] = getattr(state, "innings_type", None) payload["game_state"] = getattr(state, "game_state", None) # ---- match outcome ---- payload["overs_played"] = getattr(state, "over", None) payload["balls_played"] = getattr(state, "ball", None) payload["agent_score"] = getattr(state, "total_score", None) payload["wickets_lost"] = getattr(state, "wickets_lost", None) payload["first_innings_score"] = getattr(state, "first_innings_score", None) payload["target"] = getattr(state, "target", None) payload["match_result"] = getattr(state, "match_result", None) or None # ---- tool calls ---- tool_calls_made = int(getattr(state, "tool_calls_made", 0) or 0) payload["tool_calls"] = tool_calls_made tool_history = getattr(state, "tool_history", None) or [] tool_breakdown: dict[str, int] = {} for c in tool_history: t = c.get("tool", "unknown") tool_breakdown[t] = tool_breakdown.get(t, 0) + 1 payload["tool_breakdown"] = tool_breakdown payload["analyze_calls"] = len(getattr(state, "analyze_calls", []) or []) # ---- per-turn rubric averages (mean across the full episode) ---- def _mean(xs): xs = list(xs or []) return round(sum(xs) / len(xs), 4) if xs else None payload["mean_coherence"] = _mean(getattr(state, "coherence_scores", None)) payload["mean_adaptation"] = _mean(getattr(state, "adaptation_scores", None)) payload["mean_opponent_awareness"] = _mean(getattr(state, "opponent_awareness_scores", None)) payload["mean_regret"] = _mean(getattr(state, "regret_scores", None)) payload["mean_plan_commitment"] = _mean(getattr(state, "plan_commitment_scores", None)) payload["mean_plan_freshness"] = _mean(getattr(state, "plan_freshness_scores", None)) payload["strategy_changes"] = getattr(state, "strategy_changes", None) payload["plan_version"] = getattr(state, "plan_version", None) # ---- composite + per-rubric reward (already computed in reward_calculator) ---- if getattr(state, "reward_breakdown", None): payload["reward_breakdown"] = dict(state.reward_breakdown) with self._stats_lock: with open(stats_path, "a", encoding="utf-8") as f: f.write(json.dumps(payload, ensure_ascii=False) + "\n") f.flush() self._episode_logged = True def reset(self, **kwargs) -> str: # If the previous episode ended because the trainer hit the tool-iteration cap, # TRL will stop calling tools and then call reset() for the next scenario. # In that case, self.done will still be False, but tool_calls_made will be at/near the cap. if self._episode_started and self._episode_had_step and not self._episode_logged: prev_calls = getattr(getattr(self.env, "state", None), "tool_calls_made", None) if self.done: self._maybe_log_episode_end("natural") elif self._max_tool_iters and prev_calls is not None and int(prev_calls) >= int(self._max_tool_iters): self._maybe_log_episode_end("cap") # Otherwise: trainer reset the env mid-episode (e.g. generation bookkeeping). 
# Don't log — it would skew the termination distribution. self.reward = 0.0 self.done = False self.final_reward = 0.0 self._episode_seed = kwargs.get("seed") self._episode_started = True self._episode_had_step = False self._episode_logged = False self._max_tool_iters = ( int(kwargs["max_tool_calling_iterations"]) if "max_tool_calling_iterations" in kwargs and kwargs["max_tool_calling_iterations"] is not None else (int(os.environ["CRICKET_MAX_TOOL_ITERS"]) if os.environ.get("CRICKET_MAX_TOOL_ITERS") else None) ) obs = self.env.reset(seed=kwargs.get("seed"), options={ "task": kwargs.get("task", "stage2_full"), "random_start": bool(kwargs.get("random_start", False)), "max_overs": int(kwargs.get("max_overs", 5)), "eval_pack_id": kwargs.get("eval_pack_id", "adaptive_t20_v1"), "opponent_mode": kwargs.get("opponent_mode", "heuristic"), "opponent_cache_path": kwargs.get("opponent_cache_path"), "agent_team": kwargs.get("agent_team"), }) return obs.prompt_text def _apply(self, tool: str, arguments: dict[str, Any]) -> str: if self.done: raise ValueError("Match is already finished.") self._episode_had_step = True available = self.env.state.game_state and self.env._get_available_tools() if tool not in available: self.reward -= 0.2 raise ValueError(f"Tool '{tool}' is not available. Available tools: {available}") obs = self.env.step(CricketAction(tool=tool, arguments=arguments)) self.done = bool(obs.done) self.reward += float(obs.reward or 0.0) if obs.done and self.env.state.reward_breakdown: self.final_reward = float(self.env.state.reward_breakdown.get("composite", 0.0)) self.reward += self.final_reward # Log at the time of termination (do not wait for reset()) so the file appears promptly. if self.done: self._maybe_log_episode_end("natural") # Also log cap termination as soon as we hit it, so runs always get stats even if TRL delays reset(). elif self._max_tool_iters: state = getattr(self.env, "state", None) calls = getattr(state, "tool_calls_made", None) if state is not None else None if calls is not None and int(calls) >= int(self._max_tool_iters): self._maybe_log_episode_end("cap") return obs.prompt_text def call_toss(self, call: str, decision: str) -> str: """ Call the coin toss and choose whether to bat or bowl if the toss is won. Args: call: Coin call, either "heads" or "tails". decision: Preferred decision, either "bat" or "bowl". Returns: Updated match observation after the toss. """ return self._apply("call_toss", {"call": call, "decision": decision}) def set_match_plan( self, powerplay_intent: str, middle_intent: str, death_intent: str, risk_budget: str, trigger_conditions: str, rationale: str, ) -> str: """ Establish the long-horizon plan for the innings. Args: powerplay_intent: Plan for overs in the powerplay. middle_intent: Plan for middle overs. death_intent: Plan for death overs. risk_budget: How wickets, overs, and target pressure affect risk. trigger_conditions: Match-state changes that should trigger a plan update. rationale: Why this plan fits the roster and match situation. Returns: Updated match observation after setting the plan. """ return self._apply("set_match_plan", { "powerplay_intent": powerplay_intent, "middle_intent": middle_intent, "death_intent": death_intent, "risk_budget": risk_budget, "trigger_conditions": trigger_conditions, "rationale": rationale, }) def update_match_plan(self, reason: str, risk_budget: str = "", trigger_conditions: str = "") -> str: """ Update the long-horizon plan after a meaningful match-state change. 
Args: reason: Specific reason for updating the plan. risk_budget: Optional revised risk budget. trigger_conditions: Optional revised trigger conditions. Returns: Updated match observation after revising the plan. """ args = {"reason": reason} if risk_budget: args["risk_budget"] = risk_budget if trigger_conditions: args["trigger_conditions"] = trigger_conditions return self._apply("update_match_plan", args) def select_batter(self, name: str, style: str, aggression: float, rationale: str) -> str: """ Select the next batter from the configured roster. Args: name: Player name from the team roster. style: Batter style from the roster or tactical role. aggression: Batting aggression from 0.0 to 1.0. rationale: Why this batter fits the phase, wickets, and target. Returns: Updated match observation after selecting the batter. """ return self._apply("select_batter", { "name": name, "style": style, "aggression": aggression, "rationale": rationale, }) def set_strategy(self, phase_intent: str, aggression: float, rationale: str) -> str: """ Set batting strategy for the current phase. Args: phase_intent: Tactical batting intent for this phase. aggression: Batting aggression from 0.0 to 1.0. rationale: Why the strategy fits score, wickets, target, and field. Returns: Updated match observation after setting batting strategy. """ return self._apply("set_strategy", { "phase_intent": phase_intent, "aggression": aggression, "rationale": rationale, }) def plan_shot(self, shot_intent: str, target_area: str, risk: str, trajectory: str, rationale: str) -> str: """DEPRECATED — pass these args inline to play_delivery() instead. Args: shot_intent: leave|defensive|single|rotate|boundary|six. target_area: scoring area. risk: low|balanced|high. trajectory: ground|lofted|aerial. rationale: one-line reason. Returns: Updated observation. """ return self._apply("plan_shot", { "shot_intent": shot_intent, "target_area": target_area, "risk": risk, "trajectory": trajectory, "rationale": rationale, }) def play_delivery( self, shot_intent: str = "", target_area: str = "", risk: str = "", trajectory: str = "", rationale: str = "", ) -> str: """ Execute the ball. Pass shot params inline to skip plan_shot. Args: shot_intent: leave|defensive|single|rotate|boundary|six. target_area: optional scoring area. risk: optional low|balanced|high. trajectory: optional ground|lofted|aerial. rationale: optional one-line reason. Returns: Updated observation after the ball outcome. """ args: dict[str, Any] = {} if shot_intent: args["shot_intent"] = shot_intent if target_area: args["target_area"] = target_area if risk: args["risk"] = risk if trajectory: args["trajectory"] = trajectory if rationale: args["rationale"] = rationale return self._apply("play_delivery", args) def choose_bowler(self, name: str, bowler_type: str, style: str, rationale: str) -> str: """ Choose the bowler at the start of an over from the configured roster. Args: name: Player name from the team roster. bowler_type: Bowler type, either pace or spin. style: Bowling style or role. rationale: Why this bowler fits phase, matchup, and remaining overs. Returns: Updated match observation after choosing the bowler. """ return self._apply("choose_bowler", { "name": name, "bowler_type": bowler_type, "style": style, "rationale": rationale, }) def set_bowling_strategy(self, bowler_type: str, line: str, length: str, delivery_type: str, rationale: str) -> str: """ Set bowling strategy for the current bowler. Args: bowler_type: Current bowler type, either pace or spin. line: Intended line. 
length: Intended length. delivery_type: Variation or stock delivery type. rationale: Why this plan fits batter, field, phase, and target. Returns: Updated match observation after setting bowling strategy. """ return self._apply("set_bowling_strategy", { "bowler_type": bowler_type, "line": line, "length": length, "delivery_type": delivery_type, "rationale": rationale, }) def plan_delivery(self, bowler_type: str, line: str, length: str, delivery_type: str, rationale: str) -> str: """DEPRECATED — pass these args inline to bowl_delivery() instead. Args: bowler_type: pace|spin. line: line. length: length. delivery_type: variation or stock. rationale: one-line reason. Returns: Updated observation. """ return self._apply("plan_delivery", { "bowler_type": bowler_type, "line": line, "length": length, "delivery_type": delivery_type, "rationale": rationale, }) def set_field_setting(self, setting: str) -> str: """ Set the field preset. Args: setting: One of Aggressive, Balanced, or Defensive. Returns: Updated match observation after setting the field. """ return self._apply("set_field_setting", {"setting": setting}) def bowl_delivery( self, line: str = "", length: str = "", delivery_type: str = "", rationale: str = "", ) -> str: """ Execute the delivery. Pass plan params inline to skip plan_delivery. Args: line: optional line. length: optional length. delivery_type: optional variation or stock. rationale: optional one-line reason. Returns: Updated observation after the ball outcome. """ args: dict[str, Any] = {} if line: args["line"] = line if length: args["length"] = length if delivery_type: args["delivery_type"] = delivery_type if rationale: args["rationale"] = rationale return self._apply("bowl_delivery", args) def reflect_after_ball(self, reflection: str) -> str: """ Reflect after the previous ball and adapt the plan. Args: reflection: Specific tactical lesson from the previous ball. Returns: Updated match observation after recording reflection. """ return self._apply("reflect_after_ball", {"reflection": reflection}) def analyze_situation(self, query_type: str) -> str: """ Analyze part of the match context. Args: query_type: One of pitch_conditions, bowler_info, field_setting, or match_situation. Returns: Updated observation containing the analysis result. """ return self._apply("analyze_situation", {"query_type": query_type}) def build_agent_dataset(n_examples: int, args) -> Dataset: if Dataset is None: raise ImportError("datasets is required for training. Install with: pip install '.[train]'") rows = [] rng = random.Random(args.seed) stage_prompt = get_system_prompt(args.stage) # Curriculum distribution. If --max-overs is set, use it as a fixed format. # Otherwise sample per-scenario from a T2-heavy distribution that tapers to T5. # Rationale: T2 episodes (~72 tool calls) actually COMPLETE within our token # budget so r_result fires; T5 episodes (~180) train the model on its # eval distribution. Heavy weight on short formats early so the policy # escapes the "planning loop" before tackling longer matches. 
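    # Illustrative override (assuming the value is plumbed through the arg parser / YAML
    # config as `overs_distribution`): [2, 2, 3] would make the rng.choice below draw T2
    # scenarios twice as often as T3 and never schedule T4/T5.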
overs_distribution = getattr(args, "overs_distribution", None) fixed_overs = args.max_overs if args.max_overs and args.max_overs > 0 else None if fixed_overs is None and not overs_distribution: # default curriculum: 50% T2, 30% T3, 15% T4, 5% T5 overs_distribution = [2, 2, 2, 2, 2, 3, 3, 3, 4, 4, 5] for idx in range(n_examples): scenario_overs = fixed_overs if fixed_overs is not None else rng.choice(overs_distribution) rows.append({ "prompt": [ {"role": "system", "content": stage_prompt}, {"role": "user", "content": ""}, ], "seed": rng.randint(0, 999999), "task": "stage1_format" if args.stage == 1 else "stage2_full", "random_start": False, "max_overs": scenario_overs, "eval_pack_id": args.eval_pack_id, "opponent_mode": args.opponent_mode, "opponent_cache_path": getattr(args, "opponent_cache_path", None), "agent_team": args.agent_team, "scenario_id": idx, }) return Dataset.from_list(rows) def environment_reward(environments, **kwargs) -> list[float]: rewards = [] # Aggregate metrics across all envs in this gradient step for WandB logging. agg = { "r_result": [], "r_cricket": [], "r_behavior": [], "r_validity": [], "r_coherence": [], "r_adaptation": [], "r_opponent_awareness": [], "r_regret": [], "composite": [], "tool_calls": [], "wickets_lost": [], "agent_score": [], "matches_completed": 0, "n": 0, } tool_freq: dict[str, int] = {} for env in environments: state = env.env.state breakdown = state.reward_breakdown or {} terminal = float(breakdown.get("composite", 0.0)) plan_score = (sum(state.plan_commitment_scores) / len(state.plan_commitment_scores)) if state.plan_commitment_scores else 0.0 validity = 1.0 - min(1.0, len([c for c in state.tool_history if c.get("tool") == "invalid_json"]) / max(state.step_count, 1)) reward = env.reward + terminal + 0.1 * plan_score + 0.05 * validity # Reward clip removed: when rollouts complete naturally, the composite # reward easily saturates [-1, 1], causing GRPO group-std → 0 and # killing the gradient signal. Let GRPO standardize the advantage itself. rewards.append(round(reward, 4)) # Collect for aggregate logging. agg["n"] += 1 if env.done: agg["matches_completed"] += 1 for k in ("r_result", "r_cricket", "r_behavior", "r_validity", "r_coherence", "r_adaptation", "r_opponent_awareness", "r_regret", "composite"): v = breakdown.get(k) if v is not None: agg[k].append(float(v)) agg["tool_calls"].append(int(getattr(state, "tool_calls_made", 0) or 0)) agg["wickets_lost"].append(int(getattr(state, "wickets_lost", 0) or 0)) agg["agent_score"].append(int(getattr(state, "total_score", 0) or 0)) for c in (state.tool_history or []): t = c.get("tool", "unknown") tool_freq[t] = tool_freq.get(t, 0) + 1 # WandB log — only if wandb is initialised in this process. try: import wandb if wandb.run is not None and agg["n"] > 0: log_dict: dict[str, float] = { "rollout/n_episodes": agg["n"], "rollout/matches_completed": agg["matches_completed"], "rollout/match_completion_rate": agg["matches_completed"] / agg["n"], } for k in ("r_result", "r_cricket", "r_behavior", "r_validity", "r_coherence", "r_adaptation", "r_opponent_awareness", "r_regret", "composite"): if agg[k]: log_dict[f"reward/{k}_mean"] = sum(agg[k]) / len(agg[k]) log_dict[f"reward/{k}_max"] = max(agg[k]) log_dict[f"reward/{k}_min"] = min(agg[k]) for k in ("tool_calls", "wickets_lost", "agent_score"): if agg[k]: log_dict[f"episode/{k}_mean"] = sum(agg[k]) / len(agg[k]) log_dict[f"episode/{k}_max"] = max(agg[k]) # Tool usage breakdown — frequency per tool name across this step. 
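            # e.g. a step where 40% of tool calls were play_delivery would log
            # "tools/freq_play_delivery": 0.4 (illustrative value; key format from the f-string below).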
total_tools = sum(tool_freq.values()) or 1 for t, n in tool_freq.items(): log_dict[f"tools/freq_{t}"] = n / total_tools wandb.log(log_dict) except Exception: # Never let logging break training. pass return rewards def generate_sft_examples(out_path: str, n_examples: int = 240, seed: int = 42, agent_team: str | None = None): """Stage 0 bootstrap data: valid tool JSON for every tool family.""" rng = random.Random(seed) roster = _training_roster(agent_team) examples = [] for _ in range(n_examples): game_state = rng.choice(["toss", "batting", "bowling"]) action = _random_action(rng, game_state, roster=roster) prompt = ( f"{SYSTEM_PROMPT}\n\n" f"[CricketCaptain] {game_state.upper()} | Example adaptive scenario\n" "Phase: MIDDLE | Strategic turn: PRE_BALL\n" "Opponent last plan: {'field_setting': 'Defensive', 'shot_intent': 'boundary'}\n" ) examples.append({ "messages": [ {"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": prompt}, {"role": "assistant", "content": json.dumps({"tool": action.tool, "arguments": action.arguments})}, ] }) out = Path(out_path) out.parent.mkdir(parents=True, exist_ok=True) with out.open("w") as f: for ex in examples: f.write(json.dumps(ex) + "\n") print(f"Wrote {len(examples)} SFT examples -> {out}") # ------------------------------------------------------------------ # # Model loading (plain transformers + bitsandbytes 4-bit) # # ------------------------------------------------------------------ # def load_model(model_name: str, *, use_vllm: bool = False, resume_adapter_from: str | None = None): """Load base + LoRA. When use_vllm=True, base is loaded in bf16 (vLLM does not support 4-bit BNB inference); otherwise 4-bit NF4. resume_adapter_from: optional path to a PEFT adapter directory (e.g. a previous checkpoint dir). If provided, loads the adapter weights instead of initializing a fresh LoRA. The base model is still loaded from `model_name`. The adapter's LoraConfig is preserved (so you can resume even if r= or alpha= drift between runs).""" if not _TRAIN_IMPORTS_AVAILABLE: raise ImportError("Training dependencies are missing. Install with: pip install '.[train]'") print(f"Loading {model_name} … (use_vllm={use_vllm}, dtype={'bf16' if use_vllm else '4-bit NF4'})") tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token try: import flash_attn # noqa: F401 attn_impl = "flash_attention_2" except ImportError: attn_impl = "sdpa" load_kwargs = dict( device_map="auto", trust_remote_code=True, torch_dtype=torch.bfloat16, attn_implementation=attn_impl, ) if not use_vllm: load_kwargs["quantization_config"] = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", ) model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs) if not use_vllm: model = prepare_model_for_kbit_training(model) if resume_adapter_from: # Resume from a previous PEFT adapter checkpoint (e.g. warmup output). # PeftModel.from_pretrained reads the adapter_config.json from the dir, # so any r/alpha/target_modules saved with the warmup run is preserved. 
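        # Typical value (illustrative): resume_adapter_from="./checkpoints/stage1_final",
        # i.e. the directory a previous run wrote via model.save_pretrained(); the root can
        # be redirected with CRICKET_CKPT_ROOT.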
from peft import PeftModel adapter_path = Path(resume_adapter_from) if not adapter_path.exists(): raise FileNotFoundError(f"resume_adapter_from path does not exist: {adapter_path}") print(f"Resuming LoRA adapter from {adapter_path}") model = PeftModel.from_pretrained(model, str(adapter_path), is_trainable=True) else: lora_cfg = LoraConfig( r=64, lora_alpha=128, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM", target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], ) model = get_peft_model(model, lora_cfg) print(f"Loaded. Parameters: {model.num_parameters():,}") model.print_trainable_parameters() return model, tokenizer # ------------------------------------------------------------------ # # Training # # ------------------------------------------------------------------ # def train(args): if not _TRAIN_IMPORTS_AVAILABLE: raise ImportError("Training dependencies are missing. Install with: pip install '.[train]'") if args.opponent_mode == "llm_live": if args.opponent_model: os.environ["CRICKET_OPPONENT_MODEL"] = args.opponent_model if args.opponent_api_base: os.environ["CRICKET_OPPONENT_API_BASE"] = args.opponent_api_base if args.opponent_api_key: os.environ["CRICKET_OPPONENT_API_KEY"] = args.opponent_api_key task = "stage1_format" if args.stage == 1 else "stage2_full" # CRICKET_CKPT_ROOT lets a side-by-side run write checkpoints to a different # tree (e.g. ./checkpoints_smoke) without trampling an active production run. # Default unchanged: ./checkpoints/. ckpt_root = os.environ.get("CRICKET_CKPT_ROOT", "./checkpoints").rstrip("/") out_dir = f"{ckpt_root}/stage{args.stage}" save_dir = f"{ckpt_root}/stage{args.stage}_final" ts = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") log_dir = Path(f"./logs/run_{ts}_stage{args.stage}_{args.opponent_mode}") log_dir.mkdir(parents=True, exist_ok=True) # Make episode termination stats available to the environment wrapper. # This lets us distinguish natural terminations from tool-iteration cap truncations. stats_path = log_dir / "episode_stats.jsonl" os.environ["CRICKET_EPISODE_STATS_PATH"] = str(stats_path) os.environ["CRICKET_MAX_TOOL_ITERS"] = str(args.max_tool_calling_iterations) # Create the file immediately so users can find/tail it even before the first termination. stats_path.touch(exist_ok=True) print(f"\n=== Stage {args.stage} Training ===") print(f"Task: {task} | Prompts: {args.prompts} | Steps: {args.steps}") print(f"Logs: {log_dir}/ | Checkpoints: {out_dir}/") print(f"max_tool_calling_iterations={args.max_tool_calling_iterations} (full 5-over match needs ~180; 20-over needs ~720)") (log_dir / "metadata.json").write_text(json.dumps({ "stage": args.stage, "model": args.model, "agent_team": args.agent_team, "max_overs": args.max_overs, "opponent_mode": args.opponent_mode, "prompts": args.prompts, "steps": args.steps, "batch_size": args.batch_size, "grad_accum": args.grad_accum, "num_generations": args.num_generations, "max_completion_length": args.max_completion_length, "max_tool_calling_iterations": args.max_tool_calling_iterations, "logging_steps": args.logging_steps, "timestamp": ts, }, indent=2)) # Build scenario seeds. TRL's environment_factory performs the actual # multi-turn rollout and tool execution during training. print("\nBuilding environment scenarios …") dataset = build_agent_dataset(args.prompts, args) # Load model — bf16 if vLLM is on (vLLM rejects 4-bit BNB) or --bf16-base, else 4-bit NF4. # If resume_from is set, load the LoRA adapter from that path instead of fresh init. 
bf16_base = getattr(args, "use_vllm", False) or getattr(args, "bf16_base", False) resume_from = getattr(args, "resume_from", None) model, tokenizer = load_model(args.model, use_vllm=bf16_base, resume_adapter_from=resume_from) # GRPO config # # Qwen3 / Qwen3.5 ship with hybrid thinking ENABLED by default. Empirically # (see logs/run_2026-04-25_21-08-45 completions parquet) every generation # spent ~1200 chars inside ... and then emitted XML-style # tags instead of the JSON tool call we asked for. # That meant 0/32 generations were parseable, _apply() never advanced the # match, and episodes always hit max_tool_calling_iterations before any # innings finished — so r_result (55% of the composite) was never earned. # chat_template_kwargs = {} generation_kwargs = {} completion_len = max(args.max_completion_length, 2048) use_vllm = getattr(args, "use_vllm", False) vllm_kwargs = {} if use_vllm: # vllm_model_impl: None (default) → vLLM picks its native class. Use this for # Qwen3-* (Qwen3ForCausalLM is registered, native path with full LoRA support). # Set to "transformers" only for Qwen3.5-* where vLLM has no text-only class # registered and the native path tries to load a vision tower. vllm_kwargs = dict( use_vllm=True, vllm_mode="colocate", vllm_gpu_memory_utilization=getattr(args, "vllm_gpu_memory", 0.5), vllm_max_model_length=completion_len + 2048, ) vllm_impl = getattr(args, "vllm_model_impl", None) if vllm_impl: vllm_kwargs["vllm_model_impl"] = vllm_impl # Resolve hyperparameters from YAML/CLI with sensible fallbacks. lr = args.learning_rate if getattr(args, "learning_rate", None) is not None \ else (2e-5 if args.stage == 1 else 1e-5) grpo_beta = getattr(args, "beta", None) grpo_temp = getattr(args, "temperature", None) or 0.8 grpo_top_p = getattr(args, "top_p", None) grad_ckpt = getattr(args, "gradient_checkpointing", None) grad_ckpt_kwargs = None if grad_ckpt and getattr(args, "gradient_checkpointing_use_reentrant", None) is not None: grad_ckpt_kwargs = {"use_reentrant": bool(args.gradient_checkpointing_use_reentrant)} optional_cfg = {} if grpo_beta is not None: optional_cfg["beta"] = grpo_beta if grpo_top_p is not None: optional_cfg["top_p"] = grpo_top_p if grad_ckpt is not None: optional_cfg["gradient_checkpointing"] = bool(grad_ckpt) if grad_ckpt_kwargs is not None: optional_cfg["gradient_checkpointing_kwargs"] = grad_ckpt_kwargs if getattr(args, "dataloader_pin_memory", None) is not None: optional_cfg["dataloader_pin_memory"] = bool(args.dataloader_pin_memory) if getattr(args, "dataloader_num_workers", None) is not None: optional_cfg["dataloader_num_workers"] = int(args.dataloader_num_workers) config = GRPOConfig( output_dir=out_dir, logging_dir=str(log_dir / "tensorboard"), num_train_epochs=1, max_steps=args.steps, per_device_train_batch_size=args.batch_size, gradient_accumulation_steps=args.grad_accum, learning_rate=lr, warmup_ratio=0.05, lr_scheduler_type="cosine", logging_steps=args.logging_steps, save_steps=getattr(args, "save_steps", None) or 10, save_total_limit=getattr(args, "save_total_limit", None) or 20, bf16=True, max_completion_length=completion_len, num_generations=args.num_generations, max_tool_calling_iterations=args.max_tool_calling_iterations, temperature=grpo_temp, report_to=args.report_to, run_name=args.run_name, log_completions=True, seed=args.seed, chat_template_kwargs=chat_template_kwargs, generation_kwargs=generation_kwargs, **optional_cfg, **vllm_kwargs, ) # TRL's add_response_schema pattern-matches tokenizer.chat_template against # a fixed list and raises 
"Unrecognized chat template" if no match. Some # newer Qwen3 builds (e.g. Qwen3-4B-Instruct-2507, Aug 2025) ship a # template that differs from TRL's stored string (the Instruct release # dropped the enable_thinking block) — but the tool-call format # () is identical, so the appropriate schema still # parses correctly. We assign it manually before GRPOTrainer init; TRL # checks `response_schema is None` first so this is a safe override. if getattr(tokenizer, "response_schema", None) is None: try: from trl.chat_template_utils import qwen3_schema, qwen3_5_schema m = args.model.lower() if "qwen3.5" in m or "qwen3_5" in m: tokenizer.response_schema = qwen3_5_schema print("Set tokenizer.response_schema = qwen3_5_schema (manual override).") elif "qwen3" in m: tokenizer.response_schema = qwen3_schema print("Set tokenizer.response_schema = qwen3_schema (manual override).") except ImportError: pass trainer = GRPOTrainer( model=model, reward_funcs=environment_reward, args=config, train_dataset=dataset, processing_class=tokenizer, environment_factory=CricketCaptainToolEnv, ) print(f"\nStarting training ({args.steps} steps, {len(dataset)} prompts) …") trainer.train() model.save_pretrained(save_dir) tokenizer.save_pretrained(save_dir) print(f"\nSaved → {save_dir}") # ------------------------------------------------------------------ # # Quick eval: run N episodes with the trained model # # ------------------------------------------------------------------ # def evaluate(args): """Run N episodes and print coherence + score stats.""" from server.reward_calculator import compute_episode_reward, get_dls_par model, tokenizer = load_model(args.model) model.eval() def generate(prompt: str) -> str: inputs = tokenizer(prompt, return_tensors="pt").to(model.device) with torch.no_grad(): out = model.generate( **inputs, max_new_tokens=200, temperature=0.7, do_sample=True, pad_token_id=tokenizer.eos_token_id, ) return tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True) rng = random.Random(args.seed) all_coh, all_scores = [], [] for ep in range(args.eval_episodes): env = CricketEnvironment() obs = env.reset(seed=rng.randint(0, 99999), options={ "task": "stage2_full", "random_start": False, "agent_team": args.agent_team, }) steps = 0 while not obs.done and steps < 150: prompt = _format_prompt(obs.prompt_text) raw = generate(prompt) data = _parse_completion(raw) if data: action = CricketAction(tool=data["tool"], arguments=data.get("arguments", {})) else: action = CricketAction(tool="invalid_json", arguments={}) obs = env.step(action) steps += 1 state = env.state avg_coh = sum(state.coherence_scores) / len(state.coherence_scores) if state.coherence_scores else 0 all_coh.append(avg_coh) all_scores.append(state.total_score) print(f" Ep {ep+1}: {state.total_score}/{state.wickets_lost} coh={avg_coh:.3f}") print(f"\nAvg coherence: {sum(all_coh)/len(all_coh):.3f}") print(f"Avg score: {sum(all_scores)/len(all_scores):.1f}") def _make_run_folder(prefix: str, model: str | None, opponent_mode: str | None, max_overs: int | None) -> Path: """Create a timestamped illustrations folder, return its path.""" import datetime ts = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M") model_short = (model or "heuristic").split("/")[-1][:20] if model else "heuristic" overs_str = f"_{max_overs}ov" if max_overs else "" opp_str = f"_{opponent_mode}" if opponent_mode else "" folder_name = f"exp_{ts}_{prefix}{overs_str}{opp_str}_{model_short}" run_dir = Path(__file__).parent / "illustrations" / folder_name 
    trainer = GRPOTrainer(
        model=model,
        reward_funcs=environment_reward,
        args=config,
        train_dataset=dataset,
        processing_class=tokenizer,
        environment_factory=CricketCaptainToolEnv,
    )
    print(f"\nStarting training ({args.steps} steps, {len(dataset)} prompts) …")
    trainer.train()
    model.save_pretrained(save_dir)
    tokenizer.save_pretrained(save_dir)
    print(f"\nSaved → {save_dir}")


# ------------------------------------------------------------------ #
# Quick eval: run N episodes with the trained model                   #
# ------------------------------------------------------------------ #
def evaluate(args):
    """Run N episodes and print coherence + score stats."""
    from server.reward_calculator import compute_episode_reward, get_dls_par

    model, tokenizer = load_model(args.model)
    model.eval()

    def generate(prompt: str) -> str:
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            out = model.generate(
                **inputs,
                max_new_tokens=200,
                temperature=0.7,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
            )
        return tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)

    rng = random.Random(args.seed)
    all_coh, all_scores = [], []
    for ep in range(args.eval_episodes):
        env = CricketEnvironment()
        obs = env.reset(seed=rng.randint(0, 99999), options={
            "task": "stage2_full",
            "random_start": False,
            "agent_team": args.agent_team,
        })
        steps = 0
        while not obs.done and steps < 150:
            prompt = _format_prompt(obs.prompt_text)
            raw = generate(prompt)
            data = _parse_completion(raw)
            if data:
                action = CricketAction(tool=data["tool"], arguments=data.get("arguments", {}))
            else:
                action = CricketAction(tool="invalid_json", arguments={})
            obs = env.step(action)
            steps += 1
        state = env.state
        avg_coh = sum(state.coherence_scores) / len(state.coherence_scores) if state.coherence_scores else 0
        all_coh.append(avg_coh)
        all_scores.append(state.total_score)
        print(f" Ep {ep+1}: {state.total_score}/{state.wickets_lost} coh={avg_coh:.3f}")
    print(f"\nAvg coherence: {sum(all_coh)/len(all_coh):.3f}")
    print(f"Avg score: {sum(all_scores)/len(all_scores):.1f}")


def _make_run_folder(prefix: str, model: str | None, opponent_mode: str | None, max_overs: int | None) -> Path:
    """Create a timestamped illustrations folder and return its path."""
    ts = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M")
    model_short = model.split("/")[-1][:20] if model else "heuristic"
    overs_str = f"_{max_overs}ov" if max_overs else ""
    opp_str = f"_{opponent_mode}" if opponent_mode else ""
    folder_name = f"exp_{ts}_{prefix}{overs_str}{opp_str}_{model_short}"
    run_dir = Path(__file__).parent / "illustrations" / folder_name
    run_dir.mkdir(parents=True, exist_ok=True)
    return run_dir
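# Worked example of the naming scheme above (timestamp and model are illustrative):
#   _make_run_folder("train_smoke", "Qwen/Qwen3-4B", "heuristic", 5)
#   → illustrations/exp_2026-04-25_21-08_train_smoke_5ov_heuristic_Qwen3-4B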
def train_smoke(args):
    """Run short direct-environment training rollouts without loading a model."""
    rng = random.Random(args.seed)
    roster = _training_roster(args.agent_team)

    # Auto-create a run folder unless --output is explicitly given.
    if args.output:
        output_path = Path(args.output)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        run_dir = output_path.parent
    else:
        model_hint = getattr(args, "model", None)
        run_dir = _make_run_folder("train_smoke", model_hint, args.opponent_mode, args.max_overs)
        output_path = run_dir / "run_output.txt"

    # Write the header immediately so the file exists (and is truncated) while
    # the run is still in progress.
    header_lines = [
        "# Training smoke: direct CricketEnvironment rollout",
        f"matches={args.matches} max_overs={args.max_overs} opponent_mode={args.opponent_mode}",
        "purpose=verify one short training-style match rollout, prompt collection, tool stepping, and terminal reward",
        "",
    ]
    output_path.write_text("\n".join(header_lines))

    def log(msg: str):
        print(msg)
        with open(output_path, "a") as _f:
            _f.write(msg + "\n")

    # The header is already in the file; echo it to the console only so it is
    # not appended a second time.
    for line in header_lines[:-1]:
        print(line)

    for match_idx in range(args.matches):
        match_start = time.perf_counter()
        last_step_time = match_start
        env = CricketEnvironment()
        obs = env.reset(seed=rng.randint(0, 99999), options={
            "task": "stage2_full",
            "random_start": False,
            "max_overs": args.max_overs,
            "eval_pack_id": args.eval_pack_id,
            "opponent_mode": args.opponent_mode,
            "opponent_cache_path": args.opponent_cache_path,
            "agent_team": args.agent_team,
        })
        prompts = [_format_prompt(obs.prompt_text)]
        total_reward = 0.0
        steps = 0
        log(f"\n--- match {match_idx + 1} reset ---")
        log(f"initial_state={obs.game_state} phase={obs.strategic_phase} t_elapsed=0.000s tools={obs.available_tools}")

        while not obs.done and steps < args.max_steps:
            step_start = time.perf_counter()
            previous_game_state = obs.game_state
            previous_innings = obs.innings_type
            previous_available_tools = obs.available_tools
            action = _random_action(
                rng,
                obs.game_state,
                obs.available_tools,
                obs.current_bowler.get("type") if obs.current_bowler else None,
                roster,
            )
            obs = env.step(action)
            step_end = time.perf_counter()
            step_dt = step_end - step_start
            elapsed = step_end - match_start
            since_prev = step_end - last_step_time
            last_step_time = step_end
            total_reward += obs.reward or 0.0
            if not obs.done:
                prompts.append(_format_prompt(obs.prompt_text))
            changed_context = (
                obs.done
                or obs.game_state != previous_game_state
                or obs.innings_type != previous_innings
                or obs.available_tools != previous_available_tools
            )
            if steps < args.log_steps or changed_context:
                over = obs.game_context.get("over", 0) or 0
                ball = obs.game_context.get("ball", 0) or 0
                score = obs.game_context.get("score", 0) or 0
                balls_used = int(over) * 6 + int(ball)
                balls_left = max(args.max_overs * 6 - balls_used, 0)
                current_rr = score / (balls_used / 6) if balls_used > 0 else 0.0
                if obs.target is not None:
                    runs_required = max(obs.target - score, 0)
                    required_rr = runs_required / max(balls_left / 6, 1 / 6)
                    chase_context = f"need={runs_required} balls_left={balls_left} rrr={required_rr:.2f}"
                else:
                    chase_context = "need=None balls_left=None rrr=None"
                outcome_meta = (obs.last_outcome or {}).get("metadata", {})
                tactical_context = ""
                if outcome_meta and action.tool in {"bowl_delivery", "play_delivery"}:
                    delivery = outcome_meta.get("delivery_features", {})
                    tactical_context = (
                        f" event={outcome_meta.get('event_type')} zone={outcome_meta.get('target_area')} "
                        f"traj={outcome_meta.get('trajectory')} field_effect={outcome_meta.get('fielder_effect')} "
                        f"fit={outcome_meta.get('shot_delivery_fit')} field_pressure={outcome_meta.get('field_pressure')} "
                        f"line={delivery.get('line')} length={delivery.get('length')} variation={delivery.get('variation')}"
                    )
                log(
                    f"step={steps:03d} t_elapsed={elapsed:.3f}s step_dt={step_dt:.4f}s since_prev={since_prev:.4f}s "
                    f"tool={action.tool} reward={(obs.reward or 0.0):.3f} "
                    f"state={obs.game_state}/{obs.innings_type} phase={obs.strategic_phase} "
                    f"over={obs.game_context.get('over')}.{obs.game_context.get('ball')} "
                    f"score={obs.game_context.get('score')}/{obs.game_context.get('wickets')} "
                    f"target={obs.target} rr={current_rr:.2f} {chase_context} "
                    f"{tactical_context} "
                    f"tools={obs.available_tools} "
                    f"last={obs.last_ball_result[:140]!r}"
                )
            steps += 1

        state = env.state
        match_elapsed = time.perf_counter() - match_start
        log(f"\n--- match {match_idx + 1} final ---")
        log(f"done={obs.done} steps={steps} prompts_collected={len(prompts)} rollout_reward_sum={total_reward:.3f} match_elapsed={match_elapsed:.3f}s avg_step_dt={(match_elapsed / max(steps, 1)):.4f}s")
        log(f"score={state.total_score}/{state.wickets_lost} over={state.over}.{state.ball} target={state.target} game_state={state.game_state}")
        log(f"last_outcome={state.last_outcome}")
        log(f"match_result={state.match_result} reward_breakdown={state.reward_breakdown}")
        log(f"innings_rewards={state.innings_rewards}")
        log(f"tool_calls={state.tool_calls_made} dream11_scores={state.dream11_scores}")
        log(f"mean_coherence={(sum(state.coherence_scores) / len(state.coherence_scores)) if state.coherence_scores else 0.0:.3f}")
        log(f"mean_adaptation={(sum(state.adaptation_scores) / len(state.adaptation_scores)) if state.adaptation_scores else 0.0:.3f}")
        log(f"mean_opponent_awareness={(sum(state.opponent_awareness_scores) / len(state.opponent_awareness_scores)) if state.opponent_awareness_scores else 0.0:.3f}")

    print(f"\nwrote={output_path}")

    # Write a README for the run.
    readme_path = run_dir / "README.md"
    model_str = getattr(args, "model", None) or "heuristic (random actions)"
    readme_path.write_text(
        f"## Train-Smoke Run: {run_dir.name}\n\n"
        f"**Date**: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M')}\n\n"
        f"**Config**: `{getattr(args, 'config', None) or 'defaults'}`\n\n"
        f"| Setting | Value |\n|---|---|\n"
        f"| Matches | {args.matches} |\n"
        f"| Max overs | {args.max_overs} |\n"
        f"| Opponent mode | {args.opponent_mode} |\n"
        f"| Model (train target) | `{model_str}` |\n\n"
        f"See `run_output.txt` for the full step-by-step rollout log, reward breakdowns, and coherence scores.\n"
    )
    print(f"wrote={readme_path}")
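# Typical invocation (illustrative; the flags match the `train-smoke` subparser
# defined in main() below):
#
#   python train.py train-smoke --matches 1 --max-overs 5 --opponent-mode heuristic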
# ------------------------------------------------------------------ #
# CLI                                                                 #
# ------------------------------------------------------------------ #
def _apply_yaml_defaults(args, cfg: dict) -> None:
    """Merge YAML config values into args; explicit CLI args take precedence."""
    captain = cfg.get("captain", {}) or {}
    opponent = cfg.get("opponent", {}) or {}
    env_cfg = cfg.get("env", {}) or {}
    train_cfg = cfg.get("train", {}) or {}

    def _set(attr, val):
        if val is not None and getattr(args, attr, None) is None:
            setattr(args, attr, val)

    if getattr(args, "cmd", None) == "train":
        _set("model", train_cfg.get("model"))
    else:
        _set("model", captain.get("model"))
    _set("api_base", captain.get("api_base"))
    _set("api_key", os.environ.get(captain.get("api_key_env", "")) or None)
    _set("eval_pack_id", env_cfg.get("eval_pack_id"))
    _set("opponent_mode", opponent.get("mode"))
    _set("opponent_cache_path", opponent.get("cache_path"))
    _set("opponent_model", opponent.get("model"))
    _set("opponent_api_base", opponent.get("api_base"))
    api_key_env = opponent.get("api_key_env")
    # Mirror the captain-side handling: an unset env var becomes None, not "".
    _set("opponent_api_key", (os.environ.get(api_key_env) or None) if api_key_env else None)
    _set("max_overs", env_cfg.get("max_overs"))
    _set("agent_team", env_cfg.get("agent_team"))
    _set("steps", train_cfg.get("steps"))
    _set("prompts", train_cfg.get("prompts"))
    _set("batch_size", train_cfg.get("batch_size"))
    _set("grad_accum", train_cfg.get("grad_accum"))
    _set("stage", train_cfg.get("stage"))
    _set("num_generations", train_cfg.get("num_generations"))
    _set("max_completion_length", train_cfg.get("max_completion_length"))
    _set("max_tool_calling_iterations", train_cfg.get("max_tool_calling_iterations"))
    _set("logging_steps", train_cfg.get("logging_steps"))
    _set("report_to", train_cfg.get("report_to"))
    _set("run_name", train_cfg.get("run_name"))
    _set("learning_rate", train_cfg.get("learning_rate"))
    _set("beta", train_cfg.get("beta"))
    _set("temperature", train_cfg.get("temperature"))
    _set("top_p", train_cfg.get("top_p"))
    _set("gradient_checkpointing", train_cfg.get("gradient_checkpointing"))
    _set("gradient_checkpointing_use_reentrant", train_cfg.get("gradient_checkpointing_use_reentrant"))
    _set("dataloader_pin_memory", train_cfg.get("dataloader_pin_memory"))
    _set("dataloader_num_workers", train_cfg.get("dataloader_num_workers"))
    _set("bf16_base", train_cfg.get("bf16_base"))
    _set("save_steps", train_cfg.get("save_steps"))
    _set("save_total_limit", train_cfg.get("save_total_limit"))
    _set("resume_from", train_cfg.get("resume_from"))
    _set("overs_distribution", train_cfg.get("overs_distribution"))
    _set("use_vllm", train_cfg.get("use_vllm"))
    _set("vllm_gpu_memory", train_cfg.get("vllm_gpu_memory"))
    _set("vllm_model_impl", train_cfg.get("vllm_model_impl"))
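# Illustrative shape of the YAML this consumes (key names taken from the lookups
# above; the values and the CAPTAIN_API_KEY env-var name are made up, real
# configs live under configs/):
#
#   captain:
#     model: Qwen/Qwen3-4B
#     api_key_env: CAPTAIN_API_KEY
#   opponent:
#     mode: heuristic
#   env:
#     max_overs: 5
#     eval_pack_id: adaptive_t20_v1
#   train:
#     stage: 1
#     steps: 200
#     num_generations: 8
#     use_vllm: true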
t.add_argument("--max-tool-calling-iterations", type=int, default=None, dest="max_tool_calling_iterations") t.add_argument("--logging-steps", type=int, default=None, dest="logging_steps") t.add_argument("--report-to", default=None, dest="report_to") t.add_argument("--run-name", default=None, dest="run_name") t.add_argument("--seed", type=int, default=42) t.add_argument("--resume-from", default=None, dest="resume_from", help="Path to a previous LoRA adapter dir (e.g. ./checkpoints/stage2_final). " "When set, the adapter is loaded on top of the base model instead of a fresh init.") t.add_argument("--save-steps", type=int, default=None, dest="save_steps") t.add_argument("--save-total-limit", type=int, default=None, dest="save_total_limit") t.add_argument("--learning-rate", type=float, default=None, dest="learning_rate") t.add_argument("--beta", type=float, default=None, dest="beta", help="GRPO KL coefficient. Lower = more exploration.") t.add_argument("--temperature", type=float, default=None, dest="temperature") t.add_argument("--top-p", type=float, default=None, dest="top_p") t.add_argument("--gradient-checkpointing", action="store_true", dest="gradient_checkpointing", default=None) t.add_argument("--no-gradient-checkpointing", action="store_false", dest="gradient_checkpointing") t.add_argument("--gradient-checkpointing-use-reentrant", action="store_true", dest="gradient_checkpointing_use_reentrant", default=None) t.add_argument("--dataloader-pin-memory", action="store_true", dest="dataloader_pin_memory", default=None) t.add_argument("--dataloader-num-workers", type=int, default=None, dest="dataloader_num_workers") t.add_argument("--use-vllm", action="store_true", dest="use_vllm", default=None, help="Use vLLM-backed rollouts (colocate). Forces bf16 base.") t.add_argument("--bf16-base", action="store_true", dest="bf16_base", default=None, help="Load base model in bf16 instead of 4-bit NF4. Faster matmul on H200 since 4B fits in 8GB.") t.add_argument("--vllm-gpu-memory", type=float, default=0.5, dest="vllm_gpu_memory", help="Fraction of GPU memory reserved for vLLM (colocate). Default 0.5.") t.add_argument("--vllm-model-impl", default=None, dest="vllm_model_impl", choices=["transformers", "vllm"], help="vLLM model backend. None (default) = native vLLM class (e.g. 
Qwen3ForCausalLM); " "'transformers' = HF transformers backend (workaround for Qwen3.5 — flaky with LoRA).") # eval e = sub.add_parser("eval", help="Evaluate a checkpoint") e.add_argument("--config", default=None) e.add_argument("--model", default=None) e.add_argument("--eval-episodes", type=int, default=10, dest="eval_episodes") e.add_argument("--agent-team", default=None, dest="agent_team") e.add_argument("--seed", type=int, default=0) # quick test (no GPU needed) sub.add_parser("test", help="Smoke-test reward functions") smoke = sub.add_parser("train-smoke", help="Run short direct-env training rollouts without loading a model") smoke.add_argument("--config", default=None) smoke.add_argument("--matches", type=int, default=1) smoke.add_argument("--max-overs", type=int, default=None, dest="max_overs") smoke.add_argument("--max-steps", type=int, default=240, dest="max_steps") smoke.add_argument("--log-steps", type=int, default=30, dest="log_steps") smoke.add_argument("--eval-pack-id", default=None, dest="eval_pack_id") smoke.add_argument("--opponent-mode", default=None, choices=["heuristic", "llm_live", "llm_cached", "cricsheet"], dest="opponent_mode") smoke.add_argument("--opponent-cache-path", default=None, dest="opponent_cache_path") smoke.add_argument("--agent-team", default=None, dest="agent_team") smoke.add_argument("--output", default=None) smoke.add_argument("--seed", type=int, default=42) sft = sub.add_parser("sft-data", help="Generate Stage 0 supervised tool-format examples") sft.add_argument("--output", default="./data/training/tool_sft_examples.jsonl") sft.add_argument("--examples", type=int, default=240) sft.add_argument("--agent-team", default=None, dest="agent_team") sft.add_argument("--seed", type=int, default=42) args = parser.parse_args() # Apply YAML config (subcommand --config overrides top-level --config) config_path = getattr(args, "config", None) or getattr(parser.parse_known_args()[0], "config", None) if config_path: try: from config_yaml import load_config except ImportError: from cricket_captain.config_yaml import load_config _apply_yaml_defaults(args, load_config(config_path)) # Set safe defaults after YAML merge if getattr(args, "stage", None) is None: args.stage = 1 if getattr(args, "model", None) is None: args.model = "Qwen/Qwen3.5-4B" if getattr(args, "steps", None) is None: args.steps = 200 if getattr(args, "prompts", None) is None: args.prompts = 500 if getattr(args, "batch_size", None) is None: args.batch_size = 2 if getattr(args, "grad_accum", None) is None: args.grad_accum = 4 if getattr(args, "eval_pack_id", None) is None: args.eval_pack_id = "adaptive_t20_v1" if getattr(args, "opponent_mode", None) is None: args.opponent_mode = "llm_live" if getattr(args, "max_overs", None) is None: args.max_overs = 5 if getattr(args, "agent_team", None) is None: args.agent_team = os.environ.get("CRICKET_AGENT_TEAM") if getattr(args, "max_tool_calling_iterations", None) is None: args.max_tool_calling_iterations = 200 if getattr(args, "logging_steps", None) is None: args.logging_steps = 1 if getattr(args, "report_to", None) is None: args.report_to = "none" if args.cmd == "train": train(args) elif args.cmd == "eval": evaluate(args) elif args.cmd == "test": _smoke_test(args.agent_team, args.opponent_mode) elif args.cmd == "train-smoke": train_smoke(args) elif args.cmd == "sft-data": generate_sft_examples(args.output, args.examples, args.seed, args.agent_team) else: parser.print_help() def _smoke_test(agent_team: str | None, opponent_mode: str): """Verify reward functions 
work correctly.""" cases = [ ( "[CricketCaptain] BATTING | SECOND INNINGS\nPhase: MIDDLE | Bowler: SPIN\n" "Batting Strategy: consolidate aggression=0.30 rotate strike against spin in middle overs", '{"tool": "play_delivery", "arguments": {"shot_intent": "single", "explanation": "rotating"}}', "high", ), ( "[CricketCaptain] BATTING | SECOND INNINGS\nPhase: MIDDLE | Bowler: SPIN\n" "Batting Strategy: consolidate aggression=0.30 rotate strike against spin in middle overs", '{"tool": "play_delivery", "arguments": {"shot_intent": "six", "explanation": "going big"}}', "low", ), ( "[CricketCaptain] BATTING | FIRST INNINGS\nPhase: POWERPLAY\nBatting Strategy: None", '{"tool": "play_delivery", "arguments": {"shot_intent": "boundary", "explanation": "attack"}}', "zero", ), ( "[CricketCaptain] BATTING | SECOND INNINGS\nPhase: DEATH\nBatting Strategy: None", "not valid json at all", "zero", ), ] print("Reward function smoke test:\n") for prompt, completion, expected in cases: fmt = r_validity(completion) coh = r_behavior_stateless(prompt, completion) print(f" expected={expected:4s} | fmt={fmt:.0f} | coh={coh:.3f} | {completion[:60]}") print("\nPrompt collection test (5 prompts):") p = collect_prompts(5, task="stage1_format", seed=1, agent_team=agent_team, opponent_mode=opponent_mode) for i, pp in enumerate(p): print(f" [{i}] {pp[:80].strip()} …") if __name__ == "__main__": main()