"""
MT-GRPO training script for CricketCaptain-LLM.
Two-stage curriculum (ToolRL-style):
Stage 1: tool-call mastery — emphasize valid, phase-legal tool usage
Stage 2: strategic behavior — full environment-backed reward (result + cricket + behavior + validity)
Design:
- Training uses TRL GRPO with environment_factory=CricketCaptainToolEnv
- The model interacts with live CricketEnvironment instances over multi-turn tool calls
- Rewards are collected from the environment (environment_reward), not only from stateless prompt parsing
- The opponent policy is part of the environment: heuristic/cricsheet/llm_live/llm_cached
- Plain TRL + Transformers + bitsandbytes + PEFT (LoRA adapters for 4-bit models)
Usage (canonical Qwen3 setup):
python train.py train --config configs/cricket_train_qwen3_warmup.yaml # warmup
python train.py train --config configs/cricket_train_qwen3.yaml # main 5-over
Legacy Qwen3.5 configs live in configs/extras/.
"""
import argparse
import copy
import datetime
import json
import os
import random
import re
import sys
import threading
import time
from pathlib import Path
from typing import Any
# ------------------------------------------------------------------ #
# Optional training imports #
# ------------------------------------------------------------------ #
try:
import torch
from datasets import Dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from trl import GRPOConfig, GRPOTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
_TRAIN_IMPORTS_AVAILABLE = True
except ImportError:
torch = None
Dataset = None
LoraConfig = None
get_peft_model = None
prepare_model_for_kbit_training = None
GRPOConfig = None
GRPOTrainer = None
AutoModelForCausalLM = None
AutoTokenizer = None
BitsAndBytesConfig = None
_TRAIN_IMPORTS_AVAILABLE = False
try:
from server.cricket_environment import CricketEnvironment
from server.coherence_grader import aggression_match, phase_appropriate, rationale_specificity
from server.markov_engine import SHOT_AGGRESSION
from server.player_roster import build_playing_xi, load_team_roster
from models import CricketAction
from config_yaml import get_game_constants, get_reward_weights
except ImportError:
from cricket_captain.server.cricket_environment import CricketEnvironment
from cricket_captain.server.coherence_grader import aggression_match, phase_appropriate, rationale_specificity
from cricket_captain.server.markov_engine import SHOT_AGGRESSION
from cricket_captain.server.player_roster import build_playing_xi, load_team_roster
from cricket_captain.models import CricketAction
from cricket_captain.config_yaml import get_game_constants, get_reward_weights
# Load game knowledge once at import time (cached in config_yaml).
_GK = get_game_constants()
_RW = get_reward_weights()
# ------------------------------------------------------------------ #
# Prompt parsing (stateless — reads from rendered observation text) #
# ------------------------------------------------------------------ #
_STRATEGY_RE = re.compile(r"Batting Strategy:\s*(.+)$", re.MULTILINE)
_AGGRESSION_RE = re.compile(r"aggression[\"'=:\s]+([0-9.]+)", re.IGNORECASE)
_PHASE_RE = re.compile(r"Phase:\s+(POWERPLAY|MIDDLE|DEATH)", re.IGNORECASE)
_VALID_TOOLS = {
"call_toss",
"select_batter",
"set_strategy",
"plan_shot",
"play_delivery",
"choose_bowler",
"set_bowling_strategy",
"plan_delivery",
"set_field_setting",
"bowl_delivery",
"reflect_after_ball",
"analyze_situation",
"set_match_plan",
"update_match_plan",
}
def extract_strategy_from_prompt(prompt: str) -> dict | None:
m = _STRATEGY_RE.search(prompt)
if not m or m.group(1).strip().lower().startswith("none"):
return None
agg_m = _AGGRESSION_RE.search(prompt)
agg = float(agg_m.group(1)) if agg_m else 0.5
return {"phase_intent": m.group(1).strip(), "aggression": agg, "rationale": m.group(1).strip()}
def extract_phase_from_prompt(prompt: str) -> str:
m = _PHASE_RE.search(prompt)
return m.group(1).lower() if m else "middle"
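# e.g. a prompt containing "Phase: DEATH | Strategic turn: PRE_BALL" yields "death";
# one containing "Batting Strategy: attack (aggression: 0.7)" yields
# {"phase_intent": "attack (aggression: 0.7)", "aggression": 0.7, "rationale": ...}.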
# ------------------------------------------------------------------ #
# Per-turn reward components (all stateless) #
# ------------------------------------------------------------------ #
_XML_FN_RE = re.compile(r"<function\s*=?\s*([^>\s]+)\s*>", re.IGNORECASE)
_XML_PARAM_RE = re.compile(r"<parameter\s*=\s*([^>\s]+)\s*>(.*?)</parameter>", re.IGNORECASE | re.DOTALL)
def _parse_completion(raw: str) -> dict | None:
"""Parse a tool-call from the raw completion into our canonical {tool, arguments} dict.
Handles four common model output patterns:
1. Plain JSON (ideal).
2. Markdown code block (```json ... ```).
3. Thinking-model preamble: <think>...</think> followed by JSON.
       Qwen3/Qwen3.5 in default mode emit reasoning inside <think> tags;
we strip everything up to and including the closing </think> tag.
4. XML function-call format that Qwen3.5 was trained on:
<function=tool_name><parameter=foo>bar</parameter>...</function>
Empirically (see logs/run_2026-04-25_21-08-45) every Stage-1 completion
emitted this XML form instead of JSON — so we extract it as a fallback
to give GRPO a non-zero gradient before the model has been trained
onto the JSON contract.
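
    Illustrative round-trips (shown for reference, not run as doctests):
        >>> _parse_completion('{"tool": "play_delivery", "arguments": {"shot_intent": "single"}}')
        {'tool': 'play_delivery', 'arguments': {'shot_intent': 'single'}}
        >>> _parse_completion('<function=play_delivery><parameter=shot_intent>single</parameter></function>')
        {'tool': 'play_delivery', 'arguments': {'shot_intent': 'single'}}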
"""
raw = raw.strip()
# Strip <think>...</think> preamble emitted by thinking-mode models.
if "<think>" in raw:
think_end = raw.rfind("</think>")
if think_end != -1:
raw = raw[think_end + len("</think>"):].strip()
if raw.startswith("```"):
lines = raw.split("\n")
raw = "\n".join(lines[1:-1]) if len(lines) > 2 else raw
# Try parsing the whole string, then fall back to the first {...} block.
try:
return json.loads(raw)
except (json.JSONDecodeError, ValueError):
pass
start = raw.find("{")
end = raw.rfind("}")
if start != -1 and end > start:
try:
return json.loads(raw[start : end + 1])
except (json.JSONDecodeError, ValueError):
pass
# XML function-call fallback (Qwen3.5 default tool-call emission style).
fn_match = _XML_FN_RE.search(raw)
if fn_match:
tool = fn_match.group(1).strip().strip("\"'")
arguments: dict[str, Any] = {}
for pname, pval in _XML_PARAM_RE.findall(raw):
v = pval.strip()
# Coerce numeric/bool literals so downstream validators accept them.
try:
arguments[pname] = json.loads(v)
except (json.JSONDecodeError, ValueError):
arguments[pname] = v
return {"tool": tool, "arguments": arguments}
return None
# Bounded FIFO cache (evicts oldest insertion). Each snapshot is a deepcopy of CricketEnvironment
# (~1 MB) and only used by the LEGACY single-turn r_environment_rollout path,
# not by the multi-turn environment_factory training path. Cap at 4096 entries
# (~4 GB worst case) so a long collect_prompts call can't blow up host RAM.
_PROMPT_ENV_SNAPSHOTS: dict[str, CricketEnvironment] = {}
_PROMPT_SNAPSHOT_CAP = 4096
_ENV_REWARD_ROLLOUT_STEPS = 12
def _remember_prompt(obs_text: str, env: CricketEnvironment) -> str:
"""Format an observation and keep the exact env state for rollout reward."""
prompt = _format_prompt(obs_text)
if len(_PROMPT_ENV_SNAPSHOTS) >= _PROMPT_SNAPSHOT_CAP:
# Evict oldest insertion (dict preserves insertion order in py3.7+).
oldest_key = next(iter(_PROMPT_ENV_SNAPSHOTS))
del _PROMPT_ENV_SNAPSHOTS[oldest_key]
_PROMPT_ENV_SNAPSHOTS[prompt] = copy.deepcopy(env)
return prompt
def r_environment_rollout(prompt: str, completion: str) -> float | None:
"""Env-backed score for a generated tool call plus short continuation.
    Returns None when the prompt was not collected from an env snapshot, letting
    callers fall back to stateless scoring. Otherwise returns a value in [0, 1],
    where 0 means invalid JSON or an illegal tool for the state, and higher
    values track the env reward.
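
    E.g. a legal call followed by a roughly reward-neutral random rollout maps
    to ~0.5; a net rollout reward of +0.5 or more saturates at 1.0.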
"""
snapshot = _PROMPT_ENV_SNAPSHOTS.get(prompt)
if snapshot is None:
return None
data = _parse_completion(completion)
if data is None:
return 0.0
tool = data.get("tool", "")
args = data.get("arguments", {})
if not isinstance(args, dict):
return 0.0
env = copy.deepcopy(snapshot)
if tool not in env._get_available_tools():
return 0.0
try:
obs = env.step(CricketAction(tool=tool, arguments=args))
except Exception:
return 0.0
reward = float(obs.reward or 0.0)
rng = random.Random(hash(prompt + completion) & 0xFFFFFFFF)
roster = build_playing_xi(getattr(env, "_agent_roster", []))
for _ in range(_ENV_REWARD_ROLLOUT_STEPS):
if obs.done:
break
action = _random_action(
rng,
obs.game_state,
obs.available_tools,
obs.current_bowler.get("type") if obs.current_bowler else None,
roster,
)
obs = env.step(action)
reward += float(obs.reward or 0.0)
if obs.done and env.state.reward_breakdown:
reward += float(env.state.reward_breakdown.get("composite", 0.0))
# Map rollout reward into [0,1] while preserving penalties for bad tool choices.
return round(max(0.0, min(1.0, 0.5 + reward)), 4)
def r_validity(completion: str) -> float:
"""Schema reward for tool calling.
Exact env-executable calls receive 1.0. Malformed but parseable JSON gets a
small shaping signal so early GRPO has non-zero variance before the model has
learned the strict `{"tool": ..., "arguments": {...}}` contract.
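
    Illustrative ladder (assumes "single" is a key in SHOT_AGGRESSION):
        'not json'                                                       -> 0.00
        '{"tool": "play_delivery"}'                                      -> 0.15  (missing "arguments")
        '{"tool": "teleport", "arguments": {}}'                          -> 0.25  (unknown tool)
        '{"tool": "play_delivery", "arguments": {"shot_intent": "zap"}}' -> 0.50  (bad shot_intent)
        '{"tool": "play_delivery", "arguments": {"shot_intent": "single"}}' -> 1.00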
"""
data = _parse_completion(completion)
if data is None:
return 0.0
if not isinstance(data, dict):
return 0.05
tool = data.get("tool", "")
args = data.get("arguments", {})
if "tool" not in data or "arguments" not in data:
return 0.15
if tool not in _VALID_TOOLS:
return 0.25
if not isinstance(args, dict):
return 0.35
if tool == "play_delivery" and args.get("shot_intent") not in SHOT_AGGRESSION:
return 0.5
if tool == "set_strategy":
agg = args.get("aggression")
if not isinstance(agg, (int, float)):
return 0.5
if tool == "plan_shot" and args.get("shot_intent") not in SHOT_AGGRESSION:
return 0.5
if tool in {"choose_bowler", "set_bowling_strategy", "plan_delivery"}:
if args.get("bowler_type") not in (None, "pace", "spin"):
return 0.5
return 1.0
# Kept for backward compatibility with the smoke test.
r_format = r_validity
def _bowling_phase_fit(delivery_type: str, phase: str) -> float:
"""Return 1.0 if delivery_type fits the phase, else 0.4. Loaded from game_knowledge.yaml."""
valid = _GK.bowling_phase_delivery.get(phase, [])
return 1.0 if delivery_type in valid else 0.4
def _field_phase_fit(setting: str, phase: str) -> float:
"""Return phase-fit score for a field preset. Loaded from game_knowledge.yaml."""
return float(_GK.field_phase_fit.get(setting, {}).get(phase, 0.5))
def r_behavior_stateless(prompt: str, completion: str) -> float:
"""
r_behavior: plan-action coherence score, covering ALL tool types.
Previously only graded play_delivery, leaving ~60% of decision points
unscored. Now each tool family gets an appropriate coherence signal:
- play_delivery / plan_shot : aggression-match + rationale + phase fit
- set_strategy : phase appropriateness + rationale quality
- set_bowling_strategy / plan_delivery : delivery-phase fit + rationale
- set_field_setting : field-phase fit
- reflect_after_ball : rationale specificity
- choose_bowler / select_batter : rationale specificity
- bowl_delivery / call_toss / analyze_situation : not graded (no plan)
"""
data = _parse_completion(completion)
if data is None:
return 0.0
tool = data.get("tool", "")
args = data.get("arguments", {})
strategy = extract_strategy_from_prompt(prompt)
phase = extract_phase_from_prompt(prompt)
# Base reward for any valid structured action — ensures GRPO always has a
# positive gradient to reinforce correct tool use even when coherence can't
# be fully scored (no declared strategy, unscored tool types, etc.).
_BASE = 0.10
if tool == "play_delivery":
shot = args.get("shot_intent", "")
if shot not in SHOT_AGGRESSION:
return 0.0
if strategy is None:
return _BASE # valid shot, no declared strategy to align against
agg = strategy["aggression"]
a_match = aggression_match(agg, shot)
r_spec = rationale_specificity(strategy.get("rationale", ""))
p_approp = phase_appropriate(agg, phase)
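        # e.g. a_match=1.0, r_spec=0.5, p_approp=1.0 -> 0.10 + 0.90*(0.50 + 0.15 + 0.20) = 0.865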
return round(_BASE + (1 - _BASE) * (0.50 * a_match + 0.30 * r_spec + 0.20 * p_approp), 4)
if tool == "plan_shot":
shot = args.get("shot_intent", "")
if shot not in SHOT_AGGRESSION:
return 0.0
if strategy is None:
return _BASE # valid plan structure, no context to grade against
agg = strategy["aggression"]
a_match = aggression_match(agg, shot)
r_spec = rationale_specificity(args.get("rationale", ""))
return round(_BASE + (1 - _BASE) * (0.60 * a_match + 0.40 * r_spec), 4)
if tool == "set_strategy":
agg = args.get("aggression", 0.5)
try:
agg = float(agg)
except (TypeError, ValueError):
agg = 0.5
r_spec = rationale_specificity(args.get("rationale", ""))
p_approp = phase_appropriate(agg, phase)
return round(0.50 * p_approp + 0.50 * r_spec, 4)
if tool in {"set_bowling_strategy", "plan_delivery"}:
delivery_type = args.get("delivery_type", "")
r_spec = rationale_specificity(args.get("rationale", ""))
p_fit = _bowling_phase_fit(delivery_type, phase)
return round(0.50 * r_spec + 0.50 * p_fit, 4)
if tool == "set_field_setting":
setting = args.get("setting", "Balanced")
return round(_field_phase_fit(setting, phase), 4)
if tool == "reflect_after_ball":
return round(rationale_specificity(args.get("reflection", "")), 4)
if tool in {"choose_bowler", "select_batter"}:
return round(rationale_specificity(args.get("rationale", "")), 4)
if tool == "set_match_plan":
# Score completeness + rationale quality
fields = ["powerplay_intent", "middle_intent", "death_intent", "risk_budget", "trigger_conditions"]
completeness = sum(1 for f in fields if args.get(f, "").strip()) / len(fields)
r_spec = rationale_specificity(args.get("rationale", ""))
return round(0.6 * completeness + 0.4 * r_spec, 4)
if tool == "update_match_plan":
# Score whether update is justified by a match-state trigger
reason = args.get("reason", args.get("rationale", ""))
triggers = ["wicket", "target", "rrr", "phase", "field", "rate", "pressure", "boundary", "dot"]
hits = sum(1 for t in triggers if t in reason.lower())
r_spec = rationale_specificity(reason)
return round(min(1.0, 0.5 * r_spec + 0.5 * min(hits / 3, 1.0)), 4)
# bowl_delivery, call_toss, analyze_situation — structurally valid but no
# coherence plan to grade. Return a small base so GRPO distinguishes these
# from invalid JSON (which scores 0.0) without over-weighting them.
if tool in {"bowl_delivery", "call_toss", "analyze_situation"}:
return 0.15
return 0.0
def r_adaptation_stateless(prompt: str, completion: str) -> float:
if r_validity(completion) == 0.0:
return 0.0 # don't reward invalid tool calls for context-matching
data = _parse_completion(completion)
if data is None:
return 0.0
text = json.dumps(data.get("arguments", {})).lower()
phase = extract_phase_from_prompt(prompt)
score = 0.0
if phase in text:
score += 0.25
if any(word in prompt.lower() for word in ("target:", "death", "wicket", "opponent last plan")):
score += 0.25
if any(word in text for word in ("adjust", "target", "field", "phase", "wicket", "pressure", "matchup")):
score += 0.25
if data.get("tool") in {"plan_shot", "plan_delivery", "reflect_after_ball", "choose_bowler", "select_batter"}:
score += 0.25
return round(min(score, 1.0), 4)
def r_opponent_awareness_stateless(prompt: str, completion: str) -> float:
if r_validity(completion) == 0.0:
return 0.0 # don't reward invalid tool calls for context-matching
data = _parse_completion(completion)
if data is None:
return 0.0
text = json.dumps(data.get("arguments", {})).lower()
prompt_l = prompt.lower()
hits = 0
for word in ("opponent", "field", "bowler", "batter", "spin", "pace", "aggressive", "defensive"):
if word in prompt_l and word in text:
hits += 1
return round(min(hits / 3, 1.0), 4)
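# e.g. if the prompt mentions "spin" and "field" and the completion's arguments
# echo both words: hits=2 -> round(min(2/3, 1.0), 4) = 0.6667.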
# ------------------------------------------------------------------ #
# Composite reward function — TRL 0.24 signature #
# ------------------------------------------------------------------ #
def make_reward_fn(curriculum_stage: int):
"""
Returns reward_fn(prompts, completions, **kwargs) → list[float].
Weights align with compute_episode_reward in reward_calculator.py:
r_env — one-step env rollout reward when prompt snapshot exists
r_behavior — stateless tactical/tool coherence
r_validity — JSON/tool schema validity
"""
# Minimum reward for any structurally valid completion — ensures GRPO has a
# positive gradient to reinforce valid tool use even for unscored tool types.
_VALID_FLOOR = 0.05
def reward_fn(prompts: list[str], completions: list[str], **kwargs) -> list[float]:
rewards = []
for prompt, completion in zip(prompts, completions):
fmt = r_validity(completion)
env_score = r_environment_rollout(prompt, completion)
if curriculum_stage == 1:
# Length-efficiency penalty: a valid JSON tool call is ≤400 chars.
# Models with thinking mode (Qwen3/3.5) generate 800-2000 char
# preambles before the JSON; penalise that verbosity so GRPO
# learns to emit short, direct JSON. The penalty scales from
# 1.0 at ≤400 chars to 0.0 at ≥2400 chars (linear).
_JSON_TARGET = 400
_RAMP_RANGE = 2000
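                # e.g. len(completion)=1400 -> 1 - (1400-400)/2000 = 0.5; >=2400 chars -> 0.0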
length_eff = max(0.0, 1.0 - max(0, len(completion) - _JSON_TARGET) / _RAMP_RANGE)
base = 0.5 * fmt + 0.5 * (env_score if env_score is not None else fmt)
rewards.append(round(length_eff * base, 4))
continue
behavior = r_behavior_stateless(prompt, completion)
adapt = r_adaptation_stateless(prompt, completion)
aware = r_opponent_awareness_stateless(prompt, completion)
r_beh = (
_RW.behavior_coherence * behavior
+ _RW.behavior_adaptation * adapt
+ _RW.behavior_opponent_awareness * aware
)
if env_score is None:
reward = _RW.training_behavior * r_beh + _RW.training_validity * fmt
else:
reward = 0.45 * env_score + 0.40 * r_beh + 0.15 * fmt
# Floor: valid JSON should always beat invalid JSON (reward=0)
if fmt > 0.0 and (env_score is None or env_score > 0.0):
reward = max(reward, _VALID_FLOOR)
rewards.append(round(reward, 4))
return rewards
reward_fn.__name__ = f"stage{curriculum_stage}_reward"
return reward_fn
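# Usage sketch for this legacy stateless path (the multi-turn path below uses
# environment_reward + environment_factory instead). Wiring is illustrative and
# the team name is a placeholder:
#
#   prompts = collect_prompts(512, task="stage2_full", agent_team="<your-team>")
#   trainer = GRPOTrainer(
#       model=model,
#       reward_funcs=make_reward_fn(curriculum_stage=2),
#       args=grpo_config,
#       train_dataset=build_dataset(prompts),
#       processing_class=tokenizer,
#   )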
# ------------------------------------------------------------------ #
# Prompt collection (direct env instantiation — no server needed) #
# ------------------------------------------------------------------ #
SYSTEM_PROMPT = (
"You are an expert adaptive cricket captain. Each turn you receive a scorecard "
"and must choose exactly one cricket captaincy tool call.\n\n"
"EXECUTE FIRST — strict rule:\n"
" - The match only progresses when you call `play_delivery` (batting) or\n"
" `bowl_delivery` (bowling). Every other tool is overhead.\n"
" - Default action on EVERY ball: call `play_delivery` / `bowl_delivery` with\n"
" plan args INLINE: e.g. `play_delivery(shot_intent='single', risk='low', rationale='rotate')`\n"
" or `bowl_delivery(line='outside_off', length='good', delivery_type='stock')`.\n"
" - Use `set_match_plan` ONCE at the very start of an innings to declare strategy.\n"
" - Use `set_strategy` / `set_bowling_strategy` ONCE per phase boundary.\n"
" - DO NOT call `plan_shot` or `plan_delivery` (deprecated) — they only add a\n"
" wasted turn. Pass the same parameters to play_delivery / bowl_delivery directly.\n"
" - SKIP `reflect_after_ball` unless the previous ball was a wicket or boundary.\n"
" - You are scored on MATCH OUTCOMES, not on philosophical depth. Bloated\n"
" pre-ball planning truncates the episode and you forfeit the result reward.\n\n"
"THINKING BUDGET — HARD LIMIT:\n"
" - Per turn: ONE sentence of reasoning, max 30 tokens, inside <think>...</think>.\n"
" - Do NOT enumerate options, restate the scorecard, or re-derive the plan.\n"
" - Bad: '<think>This is the first ball, the field is balanced, Kohli is on strike at 0.45 aggression, I should consider...'\n"
" - Good: '<think>Powerplay, balanced field — single to rotate.</think>'\n"
" - Token budget per rollout is finite. Long thinking = match truncated = ZERO result reward.\n"
" - The plan you set at the start carries the strategy; do not re-derive it every ball.\n\n"
"Emit exactly one tool call wrapped in <tool_call>...</tool_call> XML tags. "
"Bare JSON without the wrapper is NOT recognized and will end the rollout.\n"
'Example: <tool_call>{"name": "play_delivery", "arguments": {"shot_intent": "single", "explanation": "rotate strike"}}</tool_call>\n\n'
"Available tools:\n"
" call_toss — Call heads/tails and choose bat/bowl\n"
" select_batter — Choose batter profile for the match situation\n"
" set_strategy — Declare batting intent (aggression 0–1, rationale)\n"
" plan_shot — Pre-ball batting plan\n"
" play_delivery — Choose a shot and advance the game\n"
" choose_bowler — Choose bowler profile for the situation\n"
" set_bowling_strategy — Declare bowling line/length/type/rationale\n"
" plan_delivery — Pre-ball bowling plan\n"
" set_field_setting — Aggressive/Balanced/Defensive field\n"
" bowl_delivery — Execute the delivery\n"
" reflect_after_ball — Adapt after the previous ball\n"
" analyze_situation — Query pitch/bowler/field info\n\n"
"Shot intents: leave | defensive | single | rotate | boundary | six\n\n"
"PRIORITIES (in order): (1) finish the match, (2) win the match, (3) score well per ball.\n"
"Verbose reasoning forfeits all three. Decide fast, act, move on."
)
def get_system_prompt(stage: int = 2) -> str:
return SYSTEM_PROMPT
_RANDOM_SHOTS = list(SHOT_AGGRESSION.keys())
_RANDOM_QUERIES = ["pitch_conditions", "bowler_info", "field_setting", "match_situation"]
_RANDOM_ZONES = ["cover", "point", "straight", "midwicket", "square_leg", "fine_leg", "long_on", "long_off"]
def _training_roster(agent_team: str | None = None) -> list[dict]:
team = agent_team or os.environ.get("CRICKET_AGENT_TEAM")
if not team:
raise ValueError("Roster-backed training requires --agent-team or CRICKET_AGENT_TEAM.")
roster = load_team_roster(team)
if not roster:
raise ValueError(f"No player profile roster found for agent team '{team}'.")
playing_xi = build_playing_xi(roster)
if len(playing_xi) < 11:
raise ValueError(f"Player profile roster for '{team}' could not produce a playing XI.")
return playing_xi
def _sample_batter(rng: random.Random, roster: list[dict]) -> dict:
batters = [p for p in roster if p.get("role") != "bowler"] or roster
if not batters:
raise ValueError("Roster-backed training requires at least one batting-capable player.")
return rng.choice(batters)
def _sample_bowler(rng: random.Random, roster: list[dict]) -> dict:
bowlers = [p for p in roster if p.get("bowler_type")]
if not bowlers:
raise ValueError("Roster-backed training requires at least one bowling-capable player.")
return rng.choice(bowlers)
def _random_action(
rng: random.Random,
game_state: str = "batting",
available_tools: list[str] | None = None,
current_bowler_type: str | None = None,
roster: list[dict] | None = None,
) -> CricketAction:
legal = set(available_tools or [])
def can(tool: str) -> bool:
return available_tools is None or tool in legal
def match_plan_action() -> CricketAction:
return CricketAction(tool="set_match_plan", arguments={
"powerplay_intent": "Use roster strengths to establish tempo while protecting wickets",
"middle_intent": "Rotate strike, attack favorable matchups, and preserve finishers",
"death_intent": "Commit boundary options with wickets and target pressure in mind",
"risk_budget": "Escalate only when phase, target, and wickets justify the risk",
"trigger_conditions": "Review after wicket clusters, phase changes, target pressure, or repeated boundary/dot outcomes",
"rationale": "Create a long-horizon plan before choosing ball-by-ball tactics",
})
if game_state == "toss":
return CricketAction(
tool="call_toss",
arguments={"call": rng.choice(["heads", "tails"]), "decision": rng.choice(["bat", "bowl"])},
)
if can("set_match_plan") and rng.random() < 0.12:
return match_plan_action()
if can("update_match_plan") and rng.random() < 0.08:
return CricketAction(tool="update_match_plan", arguments={
"reason": "Adjust plan after phase, score pressure, wickets, and field information",
"risk_budget": "Shift risk based on current target pressure and wickets in hand",
})
if game_state == "bowling":
choice = rng.random()
if choice < 0.15 and can("choose_bowler"):
bowler = _sample_bowler(rng, roster or [])
return CricketAction(
tool="choose_bowler",
arguments={
"name": bowler["name"],
"bowler_type": bowler["bowler_type"],
"style": bowler.get("bowl_style", bowler.get("style", "stock")),
"rationale": "Match roster bowler to phase, batter matchup, and remaining overs",
},
)
if choice < 0.35 and can("plan_delivery"):
return CricketAction(
tool="plan_delivery",
arguments={
"bowler_type": current_bowler_type or rng.choice(["pace", "spin"]),
"line": rng.choice(["stumps", "outside off", "wide"]),
"length": rng.choice(["good length", "full", "short", "yorker"]),
"delivery_type": rng.choice(["stock", "yorker", "bouncer", "slower ball"]),
"rationale": "Use field and batter style to control scoring zones",
},
)
if choice < 0.5 and can("set_field_setting"):
return CricketAction(tool="set_field_setting", arguments={"setting": rng.choice(["Aggressive", "Balanced", "Defensive"])})
if choice < 0.6 and can("reflect_after_ball"):
return CricketAction(tool="reflect_after_ball", arguments={"reflection": "Adjust line and field after the last ball"})
if can("bowl_delivery"):
return CricketAction(tool="bowl_delivery", arguments={})
if can("set_bowling_strategy"):
return CricketAction(tool="set_bowling_strategy", arguments={
"bowler_type": current_bowler_type or "pace",
"line": "outside off",
"length": "good length",
"delivery_type": "stock",
"rationale": "Set a legal bowling plan before executing the delivery",
})
raise ValueError(f"No legal bowling action available from tools={available_tools}")
choice = rng.random()
if choice < 0.15 and can("select_batter"):
batter = _sample_batter(rng, roster or [])
return CricketAction(
tool="select_batter",
arguments={
"name": batter["name"],
"style": batter.get("style", "balanced"),
"aggression": round(float(batter["aggression"]), 2),
"rationale": "Select batter based on phase, wickets, and target pressure",
},
)
if choice < 0.3 and can("set_strategy"):
return CricketAction(
tool="set_strategy",
arguments={
"phase_intent": rng.choice(["attack", "consolidate", "rotate"]),
"aggression": round(rng.uniform(0.1, 0.9), 2),
"rationale": "Align roster strengths with phase, target pressure, and wickets",
},
)
if choice < 0.45 and can("plan_shot"):
return CricketAction(
tool="plan_shot",
arguments={
"shot_intent": rng.choice(_RANDOM_SHOTS),
"target_area": rng.choice(_RANDOM_ZONES),
"trajectory": rng.choice(["ground", "lofted", "aerial"]),
"risk": rng.choice(["low", "balanced", "high"]),
"rationale": "Plan shot against bowler, field, and required rate",
},
)
if choice < 0.55 and can("analyze_situation"):
return CricketAction(
tool="analyze_situation",
arguments={"query_type": rng.choice(_RANDOM_QUERIES)},
)
if choice < 0.65 and can("reflect_after_ball"):
return CricketAction(tool="reflect_after_ball", arguments={"reflection": "Revise risk after previous ball"})
if can("play_delivery"):
return CricketAction(
tool="play_delivery",
arguments={"shot_intent": rng.choice(_RANDOM_SHOTS), "explanation": "Advance the innings according to the current plan"},
)
raise ValueError(f"No legal batting action available from tools={available_tools}")
def collect_prompts(
n_prompts: int,
task: str = "stage2_full",
seed: int = 42,
agent_team: str | None = None,
opponent_mode: str = "heuristic",
) -> list[str]:
"""
Collect game-state prompts by running episodes with random actions.
Returns a list of prompt strings (one per game state observation).
"""
rng = random.Random(seed)
roster = _training_roster(agent_team)
_PROMPT_ENV_SNAPSHOTS.clear()
prompts: list[str] = []
ep_count = 0
while len(prompts) < n_prompts:
env = CricketEnvironment()
obs = env.reset(seed=rng.randint(0, 99999), options={
"task": task,
"random_start": True,
"agent_team": agent_team or os.environ.get("CRICKET_AGENT_TEAM"),
"opponent_mode": opponent_mode,
})
prompts.append(_remember_prompt(obs.prompt_text, env))
steps = 0
while not obs.done and steps < 80:
action = _random_action(
rng,
obs.game_state,
obs.available_tools,
obs.current_bowler.get("type") if obs.current_bowler else None,
roster,
)
obs = env.step(action)
if not obs.done:
prompts.append(_remember_prompt(obs.prompt_text, env))
steps += 1
ep_count += 1
if ep_count % 10 == 0:
print(f" Collected {len(prompts)} prompts from {ep_count} episodes …", flush=True)
print(f"Collected {len(prompts)} prompts from {ep_count} episodes.")
return prompts[:n_prompts]
def _format_prompt(obs_text: str) -> str:
"""Wrap the observation in a chat-style user message."""
return f"<|im_start|>system\n{SYSTEM_PROMPT}<|im_end|>\n<|im_start|>user\n{obs_text}<|im_end|>\n<|im_start|>assistant\n"
def build_dataset(prompts: list[str]) -> Dataset:
if Dataset is None:
raise ImportError("datasets is required for training. Install with: pip install '.[train]'")
return Dataset.from_dict({"prompt": prompts})
class CricketCaptainToolEnv:
"""TRL environment wrapper exposing CricketCaptain actions as real tools."""
_stats_lock = threading.Lock()
def __init__(self):
self.env = CricketEnvironment()
self.reward = 0.0
self.done = False
self.final_reward = 0.0
self._episode_seed: int | None = None
self._episode_started = False
self._max_tool_iters: int | None = None
self._episode_had_step = False
self._episode_logged = False
def _maybe_log_episode_end(self, termination_reason: str):
# Avoid double-logging the same episode (e.g. once at termination, again on reset()).
if self._episode_logged:
return
stats_path = os.environ.get("CRICKET_EPISODE_STATS_PATH")
if not stats_path:
return
state = getattr(self.env, "state", None)
payload = {
"ts": datetime.datetime.now().isoformat(),
"seed": self._episode_seed,
"done": bool(self.done),
"termination_reason": termination_reason,
"reward_running_sum": float(self.reward),
"final_reward_bonus": float(self.final_reward),
}
if state is not None:
# ---- match config / context ----
payload["max_overs"] = getattr(state, "max_overs", None)
payload["opponent_mode"] = getattr(state, "opponent_mode", None)
payload["agent_team"] = getattr(state, "eval_pack_id", None) or getattr(state, "agent_team", None)
payload["innings_type"] = getattr(state, "innings_type", None)
payload["game_state"] = getattr(state, "game_state", None)
# ---- match outcome ----
payload["overs_played"] = getattr(state, "over", None)
payload["balls_played"] = getattr(state, "ball", None)
payload["agent_score"] = getattr(state, "total_score", None)
payload["wickets_lost"] = getattr(state, "wickets_lost", None)
payload["first_innings_score"] = getattr(state, "first_innings_score", None)
payload["target"] = getattr(state, "target", None)
payload["match_result"] = getattr(state, "match_result", None) or None
# ---- tool calls ----
tool_calls_made = int(getattr(state, "tool_calls_made", 0) or 0)
payload["tool_calls"] = tool_calls_made
tool_history = getattr(state, "tool_history", None) or []
tool_breakdown: dict[str, int] = {}
for c in tool_history:
t = c.get("tool", "unknown")
tool_breakdown[t] = tool_breakdown.get(t, 0) + 1
payload["tool_breakdown"] = tool_breakdown
payload["analyze_calls"] = len(getattr(state, "analyze_calls", []) or [])
# ---- per-turn rubric averages (mean across the full episode) ----
def _mean(xs):
xs = list(xs or [])
return round(sum(xs) / len(xs), 4) if xs else None
payload["mean_coherence"] = _mean(getattr(state, "coherence_scores", None))
payload["mean_adaptation"] = _mean(getattr(state, "adaptation_scores", None))
payload["mean_opponent_awareness"] = _mean(getattr(state, "opponent_awareness_scores", None))
payload["mean_regret"] = _mean(getattr(state, "regret_scores", None))
payload["mean_plan_commitment"] = _mean(getattr(state, "plan_commitment_scores", None))
payload["mean_plan_freshness"] = _mean(getattr(state, "plan_freshness_scores", None))
payload["strategy_changes"] = getattr(state, "strategy_changes", None)
payload["plan_version"] = getattr(state, "plan_version", None)
# ---- composite + per-rubric reward (already computed in reward_calculator) ----
if getattr(state, "reward_breakdown", None):
payload["reward_breakdown"] = dict(state.reward_breakdown)
with self._stats_lock:
with open(stats_path, "a", encoding="utf-8") as f:
f.write(json.dumps(payload, ensure_ascii=False) + "\n")
f.flush()
self._episode_logged = True
def reset(self, **kwargs) -> str:
# If the previous episode ended because the trainer hit the tool-iteration cap,
# TRL will stop calling tools and then call reset() for the next scenario.
# In that case, self.done will still be False, but tool_calls_made will be at/near the cap.
if self._episode_started and self._episode_had_step and not self._episode_logged:
prev_calls = getattr(getattr(self.env, "state", None), "tool_calls_made", None)
if self.done:
self._maybe_log_episode_end("natural")
elif self._max_tool_iters and prev_calls is not None and int(prev_calls) >= int(self._max_tool_iters):
self._maybe_log_episode_end("cap")
# Otherwise: trainer reset the env mid-episode (e.g. generation bookkeeping).
# Don't log — it would skew the termination distribution.
self.reward = 0.0
self.done = False
self.final_reward = 0.0
self._episode_seed = kwargs.get("seed")
self._episode_started = True
self._episode_had_step = False
self._episode_logged = False
self._max_tool_iters = (
int(kwargs["max_tool_calling_iterations"])
if "max_tool_calling_iterations" in kwargs and kwargs["max_tool_calling_iterations"] is not None
else (int(os.environ["CRICKET_MAX_TOOL_ITERS"]) if os.environ.get("CRICKET_MAX_TOOL_ITERS") else None)
)
obs = self.env.reset(seed=kwargs.get("seed"), options={
"task": kwargs.get("task", "stage2_full"),
"random_start": bool(kwargs.get("random_start", False)),
"max_overs": int(kwargs.get("max_overs", 5)),
"eval_pack_id": kwargs.get("eval_pack_id", "adaptive_t20_v1"),
"opponent_mode": kwargs.get("opponent_mode", "heuristic"),
"opponent_cache_path": kwargs.get("opponent_cache_path"),
"agent_team": kwargs.get("agent_team"),
})
return obs.prompt_text
def _apply(self, tool: str, arguments: dict[str, Any]) -> str:
if self.done:
raise ValueError("Match is already finished.")
self._episode_had_step = True
        # Guard: if game_state is unset (e.g. pre-reset), treat no tools as available
        # rather than doing a substring/None membership test below.
        available = self.env._get_available_tools() if self.env.state.game_state else []
if tool not in available:
self.reward -= 0.2
raise ValueError(f"Tool '{tool}' is not available. Available tools: {available}")
obs = self.env.step(CricketAction(tool=tool, arguments=arguments))
self.done = bool(obs.done)
self.reward += float(obs.reward or 0.0)
if obs.done and self.env.state.reward_breakdown:
self.final_reward = float(self.env.state.reward_breakdown.get("composite", 0.0))
self.reward += self.final_reward
# Log at the time of termination (do not wait for reset()) so the file appears promptly.
if self.done:
self._maybe_log_episode_end("natural")
# Also log cap termination as soon as we hit it, so runs always get stats even if TRL delays reset().
elif self._max_tool_iters:
state = getattr(self.env, "state", None)
calls = getattr(state, "tool_calls_made", None) if state is not None else None
if calls is not None and int(calls) >= int(self._max_tool_iters):
self._maybe_log_episode_end("cap")
return obs.prompt_text
def call_toss(self, call: str, decision: str) -> str:
"""
Call the coin toss and choose whether to bat or bowl if the toss is won.
Args:
call: Coin call, either "heads" or "tails".
decision: Preferred decision, either "bat" or "bowl".
Returns:
Updated match observation after the toss.
"""
return self._apply("call_toss", {"call": call, "decision": decision})
def set_match_plan(
self,
powerplay_intent: str,
middle_intent: str,
death_intent: str,
risk_budget: str,
trigger_conditions: str,
rationale: str,
) -> str:
"""
Establish the long-horizon plan for the innings.
Args:
powerplay_intent: Plan for overs in the powerplay.
middle_intent: Plan for middle overs.
death_intent: Plan for death overs.
risk_budget: How wickets, overs, and target pressure affect risk.
trigger_conditions: Match-state changes that should trigger a plan update.
rationale: Why this plan fits the roster and match situation.
Returns:
Updated match observation after setting the plan.
"""
return self._apply("set_match_plan", {
"powerplay_intent": powerplay_intent,
"middle_intent": middle_intent,
"death_intent": death_intent,
"risk_budget": risk_budget,
"trigger_conditions": trigger_conditions,
"rationale": rationale,
})
def update_match_plan(self, reason: str, risk_budget: str = "", trigger_conditions: str = "") -> str:
"""
Update the long-horizon plan after a meaningful match-state change.
Args:
reason: Specific reason for updating the plan.
risk_budget: Optional revised risk budget.
trigger_conditions: Optional revised trigger conditions.
Returns:
Updated match observation after revising the plan.
"""
args = {"reason": reason}
if risk_budget:
args["risk_budget"] = risk_budget
if trigger_conditions:
args["trigger_conditions"] = trigger_conditions
return self._apply("update_match_plan", args)
def select_batter(self, name: str, style: str, aggression: float, rationale: str) -> str:
"""
Select the next batter from the configured roster.
Args:
name: Player name from the team roster.
style: Batter style from the roster or tactical role.
aggression: Batting aggression from 0.0 to 1.0.
rationale: Why this batter fits the phase, wickets, and target.
Returns:
Updated match observation after selecting the batter.
"""
return self._apply("select_batter", {
"name": name,
"style": style,
"aggression": aggression,
"rationale": rationale,
})
def set_strategy(self, phase_intent: str, aggression: float, rationale: str) -> str:
"""
Set batting strategy for the current phase.
Args:
phase_intent: Tactical batting intent for this phase.
aggression: Batting aggression from 0.0 to 1.0.
rationale: Why the strategy fits score, wickets, target, and field.
Returns:
Updated match observation after setting batting strategy.
"""
return self._apply("set_strategy", {
"phase_intent": phase_intent,
"aggression": aggression,
"rationale": rationale,
})
def plan_shot(self, shot_intent: str, target_area: str, risk: str, trajectory: str, rationale: str) -> str:
"""DEPRECATED — pass these args inline to play_delivery() instead.
Args:
shot_intent: leave|defensive|single|rotate|boundary|six.
target_area: scoring area.
risk: low|balanced|high.
trajectory: ground|lofted|aerial.
rationale: one-line reason.
Returns:
Updated observation.
"""
return self._apply("plan_shot", {
"shot_intent": shot_intent,
"target_area": target_area,
"risk": risk,
"trajectory": trajectory,
"rationale": rationale,
})
def play_delivery(
self,
shot_intent: str = "",
target_area: str = "",
risk: str = "",
trajectory: str = "",
rationale: str = "",
) -> str:
"""
Execute the ball. Pass shot params inline to skip plan_shot.
Args:
shot_intent: leave|defensive|single|rotate|boundary|six.
target_area: optional scoring area.
risk: optional low|balanced|high.
trajectory: optional ground|lofted|aerial.
rationale: optional one-line reason.
Returns:
Updated observation after the ball outcome.
"""
args: dict[str, Any] = {}
if shot_intent: args["shot_intent"] = shot_intent
if target_area: args["target_area"] = target_area
if risk: args["risk"] = risk
if trajectory: args["trajectory"] = trajectory
if rationale: args["rationale"] = rationale
return self._apply("play_delivery", args)
def choose_bowler(self, name: str, bowler_type: str, style: str, rationale: str) -> str:
"""
Choose the bowler at the start of an over from the configured roster.
Args:
name: Player name from the team roster.
bowler_type: Bowler type, either pace or spin.
style: Bowling style or role.
rationale: Why this bowler fits phase, matchup, and remaining overs.
Returns:
Updated match observation after choosing the bowler.
"""
return self._apply("choose_bowler", {
"name": name,
"bowler_type": bowler_type,
"style": style,
"rationale": rationale,
})
def set_bowling_strategy(self, bowler_type: str, line: str, length: str, delivery_type: str, rationale: str) -> str:
"""
Set bowling strategy for the current bowler.
Args:
bowler_type: Current bowler type, either pace or spin.
line: Intended line.
length: Intended length.
delivery_type: Variation or stock delivery type.
rationale: Why this plan fits batter, field, phase, and target.
Returns:
Updated match observation after setting bowling strategy.
"""
return self._apply("set_bowling_strategy", {
"bowler_type": bowler_type,
"line": line,
"length": length,
"delivery_type": delivery_type,
"rationale": rationale,
})
def plan_delivery(self, bowler_type: str, line: str, length: str, delivery_type: str, rationale: str) -> str:
"""DEPRECATED — pass these args inline to bowl_delivery() instead.
Args:
bowler_type: pace|spin.
line: line.
length: length.
delivery_type: variation or stock.
rationale: one-line reason.
Returns:
Updated observation.
"""
return self._apply("plan_delivery", {
"bowler_type": bowler_type,
"line": line,
"length": length,
"delivery_type": delivery_type,
"rationale": rationale,
})
def set_field_setting(self, setting: str) -> str:
"""
Set the field preset.
Args:
setting: One of Aggressive, Balanced, or Defensive.
Returns:
Updated match observation after setting the field.
"""
return self._apply("set_field_setting", {"setting": setting})
def bowl_delivery(
self,
line: str = "",
length: str = "",
delivery_type: str = "",
rationale: str = "",
) -> str:
"""
Execute the delivery. Pass plan params inline to skip plan_delivery.
Args:
line: optional line.
length: optional length.
delivery_type: optional variation or stock.
rationale: optional one-line reason.
Returns:
Updated observation after the ball outcome.
"""
args: dict[str, Any] = {}
if line: args["line"] = line
if length: args["length"] = length
if delivery_type: args["delivery_type"] = delivery_type
if rationale: args["rationale"] = rationale
return self._apply("bowl_delivery", args)
def reflect_after_ball(self, reflection: str) -> str:
"""
Reflect after the previous ball and adapt the plan.
Args:
reflection: Specific tactical lesson from the previous ball.
Returns:
Updated match observation after recording reflection.
"""
return self._apply("reflect_after_ball", {"reflection": reflection})
def analyze_situation(self, query_type: str) -> str:
"""
Analyze part of the match context.
Args:
query_type: One of pitch_conditions, bowler_info, field_setting, or match_situation.
Returns:
Updated observation containing the analysis result.
"""
return self._apply("analyze_situation", {"query_type": query_type})
def build_agent_dataset(n_examples: int, args) -> Dataset:
if Dataset is None:
raise ImportError("datasets is required for training. Install with: pip install '.[train]'")
rows = []
rng = random.Random(args.seed)
stage_prompt = get_system_prompt(args.stage)
# Curriculum distribution. If --max-overs is set, use it as a fixed format.
# Otherwise sample per-scenario from a T2-heavy distribution that tapers to T5.
# Rationale: T2 episodes (~72 tool calls) actually COMPLETE within our token
# budget so r_result fires; T5 episodes (~180) train the model on its
# eval distribution. Heavy weight on short formats early so the policy
# escapes the "planning loop" before tackling longer matches.
overs_distribution = getattr(args, "overs_distribution", None)
fixed_overs = args.max_overs if args.max_overs and args.max_overs > 0 else None
if fixed_overs is None and not overs_distribution:
        # default curriculum: 5/11 T2, 3/11 T3, 2/11 T4, 1/11 T5 (≈45/27/18/9%)
overs_distribution = [2, 2, 2, 2, 2, 3, 3, 3, 4, 4, 5]
for idx in range(n_examples):
scenario_overs = fixed_overs if fixed_overs is not None else rng.choice(overs_distribution)
rows.append({
"prompt": [
{"role": "system", "content": stage_prompt},
{"role": "user", "content": ""},
],
"seed": rng.randint(0, 999999),
"task": "stage1_format" if args.stage == 1 else "stage2_full",
"random_start": False,
"max_overs": scenario_overs,
"eval_pack_id": args.eval_pack_id,
"opponent_mode": args.opponent_mode,
"opponent_cache_path": getattr(args, "opponent_cache_path", None),
"agent_team": args.agent_team,
"scenario_id": idx,
})
return Dataset.from_list(rows)
def environment_reward(environments, **kwargs) -> list[float]:
rewards = []
# Aggregate metrics across all envs in this gradient step for WandB logging.
agg = {
"r_result": [], "r_cricket": [], "r_behavior": [], "r_validity": [],
"r_coherence": [], "r_adaptation": [], "r_opponent_awareness": [], "r_regret": [],
"composite": [], "tool_calls": [], "wickets_lost": [], "agent_score": [],
"matches_completed": 0, "n": 0,
}
tool_freq: dict[str, int] = {}
for env in environments:
state = env.env.state
breakdown = state.reward_breakdown or {}
terminal = float(breakdown.get("composite", 0.0))
plan_score = (sum(state.plan_commitment_scores) / len(state.plan_commitment_scores)) if state.plan_commitment_scores else 0.0
validity = 1.0 - min(1.0, len([c for c in state.tool_history if c.get("tool") == "invalid_json"]) / max(state.step_count, 1))
reward = env.reward + terminal + 0.1 * plan_score + 0.05 * validity
# Reward clip removed: when rollouts complete naturally, the composite
# reward easily saturates [-1, 1], causing GRPO group-std → 0 and
# killing the gradient signal. Let GRPO standardize the advantage itself.
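        # e.g. env.reward=1.3, composite=0.8, plan_score=0.6, validity=1.0 -> 1.3 + 0.8 + 0.06 + 0.05 = 2.21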
rewards.append(round(reward, 4))
# Collect for aggregate logging.
agg["n"] += 1
if env.done:
agg["matches_completed"] += 1
for k in ("r_result", "r_cricket", "r_behavior", "r_validity",
"r_coherence", "r_adaptation", "r_opponent_awareness",
"r_regret", "composite"):
v = breakdown.get(k)
if v is not None:
agg[k].append(float(v))
agg["tool_calls"].append(int(getattr(state, "tool_calls_made", 0) or 0))
agg["wickets_lost"].append(int(getattr(state, "wickets_lost", 0) or 0))
agg["agent_score"].append(int(getattr(state, "total_score", 0) or 0))
for c in (state.tool_history or []):
t = c.get("tool", "unknown")
tool_freq[t] = tool_freq.get(t, 0) + 1
# WandB log — only if wandb is initialised in this process.
try:
import wandb
if wandb.run is not None and agg["n"] > 0:
log_dict: dict[str, float] = {
"rollout/n_episodes": agg["n"],
"rollout/matches_completed": agg["matches_completed"],
"rollout/match_completion_rate": agg["matches_completed"] / agg["n"],
}
for k in ("r_result", "r_cricket", "r_behavior", "r_validity",
"r_coherence", "r_adaptation", "r_opponent_awareness",
"r_regret", "composite"):
if agg[k]:
log_dict[f"reward/{k}_mean"] = sum(agg[k]) / len(agg[k])
log_dict[f"reward/{k}_max"] = max(agg[k])
log_dict[f"reward/{k}_min"] = min(agg[k])
for k in ("tool_calls", "wickets_lost", "agent_score"):
if agg[k]:
log_dict[f"episode/{k}_mean"] = sum(agg[k]) / len(agg[k])
log_dict[f"episode/{k}_max"] = max(agg[k])
# Tool usage breakdown — frequency per tool name across this step.
total_tools = sum(tool_freq.values()) or 1
for t, n in tool_freq.items():
log_dict[f"tools/freq_{t}"] = n / total_tools
wandb.log(log_dict)
except Exception:
# Never let logging break training.
pass
return rewards
def generate_sft_examples(out_path: str, n_examples: int = 240, seed: int = 42, agent_team: str | None = None):
"""Stage 0 bootstrap data: valid tool JSON for every tool family."""
rng = random.Random(seed)
roster = _training_roster(agent_team)
examples = []
for _ in range(n_examples):
game_state = rng.choice(["toss", "batting", "bowling"])
action = _random_action(rng, game_state, roster=roster)
prompt = (
f"{SYSTEM_PROMPT}\n\n"
f"[CricketCaptain] {game_state.upper()} | Example adaptive scenario\n"
"Phase: MIDDLE | Strategic turn: PRE_BALL\n"
"Opponent last plan: {'field_setting': 'Defensive', 'shot_intent': 'boundary'}\n"
)
examples.append({
"messages": [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": prompt},
{"role": "assistant", "content": json.dumps({"tool": action.tool, "arguments": action.arguments})},
]
})
out = Path(out_path)
out.parent.mkdir(parents=True, exist_ok=True)
with out.open("w") as f:
for ex in examples:
f.write(json.dumps(ex) + "\n")
print(f"Wrote {len(examples)} SFT examples -> {out}")
# ------------------------------------------------------------------ #
# Model loading (plain transformers + bitsandbytes 4-bit) #
# ------------------------------------------------------------------ #
def load_model(model_name: str, *, use_vllm: bool = False, resume_adapter_from: str | None = None):
"""Load base + LoRA. When use_vllm=True, base is loaded in bf16 (vLLM
does not support 4-bit BNB inference); otherwise 4-bit NF4.
    resume_adapter_from: optional path to a PEFT adapter directory (e.g. a previous
    checkpoint dir). If provided, loads those adapter weights instead of initializing
    a fresh LoRA. The base model is still loaded from `model_name`, and the adapter's
    saved LoraConfig is preserved, so a run can resume even if r/alpha drifted between runs."""
if not _TRAIN_IMPORTS_AVAILABLE:
raise ImportError("Training dependencies are missing. Install with: pip install '.[train]'")
print(f"Loading {model_name} … (use_vllm={use_vllm}, dtype={'bf16' if use_vllm else '4-bit NF4'})")
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
try:
import flash_attn # noqa: F401
attn_impl = "flash_attention_2"
except ImportError:
attn_impl = "sdpa"
load_kwargs = dict(
device_map="auto",
trust_remote_code=True,
torch_dtype=torch.bfloat16,
attn_implementation=attn_impl,
)
if not use_vllm:
load_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
)
model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)
if not use_vllm:
model = prepare_model_for_kbit_training(model)
if resume_adapter_from:
# Resume from a previous PEFT adapter checkpoint (e.g. warmup output).
# PeftModel.from_pretrained reads the adapter_config.json from the dir,
# so any r/alpha/target_modules saved with the warmup run is preserved.
from peft import PeftModel
adapter_path = Path(resume_adapter_from)
if not adapter_path.exists():
raise FileNotFoundError(f"resume_adapter_from path does not exist: {adapter_path}")
print(f"Resuming LoRA adapter from {adapter_path}")
model = PeftModel.from_pretrained(model, str(adapter_path), is_trainable=True)
else:
lora_cfg = LoraConfig(
r=64,
lora_alpha=128,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
)
model = get_peft_model(model, lora_cfg)
print(f"Loaded. Parameters: {model.num_parameters():,}")
model.print_trainable_parameters()
return model, tokenizer
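# Example invocations (model id and adapter path are illustrative):
#
#   model, tok = load_model("Qwen/Qwen3-4B-Instruct-2507")                 # 4-bit NF4 + fresh LoRA
#   model, tok = load_model("Qwen/Qwen3-4B-Instruct-2507", use_vllm=True)  # bf16 for colocated vLLM
#   model, tok = load_model(
#       "Qwen/Qwen3-4B-Instruct-2507",
#       resume_adapter_from="./checkpoints/stage1_final",  # e.g. warmup output
#   )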
# ------------------------------------------------------------------ #
# Training #
# ------------------------------------------------------------------ #
def train(args):
if not _TRAIN_IMPORTS_AVAILABLE:
raise ImportError("Training dependencies are missing. Install with: pip install '.[train]'")
if args.opponent_mode == "llm_live":
if args.opponent_model:
os.environ["CRICKET_OPPONENT_MODEL"] = args.opponent_model
if args.opponent_api_base:
os.environ["CRICKET_OPPONENT_API_BASE"] = args.opponent_api_base
if args.opponent_api_key:
os.environ["CRICKET_OPPONENT_API_KEY"] = args.opponent_api_key
task = "stage1_format" if args.stage == 1 else "stage2_full"
# CRICKET_CKPT_ROOT lets a side-by-side run write checkpoints to a different
# tree (e.g. ./checkpoints_smoke) without trampling an active production run.
# Default unchanged: ./checkpoints/.
ckpt_root = os.environ.get("CRICKET_CKPT_ROOT", "./checkpoints").rstrip("/")
out_dir = f"{ckpt_root}/stage{args.stage}"
save_dir = f"{ckpt_root}/stage{args.stage}_final"
ts = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
log_dir = Path(f"./logs/run_{ts}_stage{args.stage}_{args.opponent_mode}")
log_dir.mkdir(parents=True, exist_ok=True)
# Make episode termination stats available to the environment wrapper.
# This lets us distinguish natural terminations from tool-iteration cap truncations.
stats_path = log_dir / "episode_stats.jsonl"
os.environ["CRICKET_EPISODE_STATS_PATH"] = str(stats_path)
os.environ["CRICKET_MAX_TOOL_ITERS"] = str(args.max_tool_calling_iterations)
# Create the file immediately so users can find/tail it even before the first termination.
stats_path.touch(exist_ok=True)
print(f"\n=== Stage {args.stage} Training ===")
print(f"Task: {task} | Prompts: {args.prompts} | Steps: {args.steps}")
print(f"Logs: {log_dir}/ | Checkpoints: {out_dir}/")
print(f"max_tool_calling_iterations={args.max_tool_calling_iterations} (full 5-over match needs ~180; 20-over needs ~720)")
(log_dir / "metadata.json").write_text(json.dumps({
"stage": args.stage, "model": args.model, "agent_team": args.agent_team,
"max_overs": args.max_overs, "opponent_mode": args.opponent_mode,
"prompts": args.prompts, "steps": args.steps,
"batch_size": args.batch_size, "grad_accum": args.grad_accum,
"num_generations": args.num_generations,
"max_completion_length": args.max_completion_length,
"max_tool_calling_iterations": args.max_tool_calling_iterations,
"logging_steps": args.logging_steps,
"timestamp": ts,
}, indent=2))
# Build scenario seeds. TRL's environment_factory performs the actual
# multi-turn rollout and tool execution during training.
print("\nBuilding environment scenarios …")
dataset = build_agent_dataset(args.prompts, args)
# Load model — bf16 if vLLM is on (vLLM rejects 4-bit BNB) or --bf16-base, else 4-bit NF4.
# If resume_from is set, load the LoRA adapter from that path instead of fresh init.
bf16_base = getattr(args, "use_vllm", False) or getattr(args, "bf16_base", False)
resume_from = getattr(args, "resume_from", None)
model, tokenizer = load_model(args.model, use_vllm=bf16_base, resume_adapter_from=resume_from)
# GRPO config
#
# Qwen3 / Qwen3.5 ship with hybrid thinking ENABLED by default. Empirically
# (see logs/run_2026-04-25_21-08-45 completions parquet) every generation
# spent ~1200 chars inside <think>...</think> and then emitted XML-style
# <function><parameter> tags instead of the JSON tool call we asked for.
# That meant 0/32 generations were parseable, _apply() never advanced the
# match, and episodes always hit max_tool_calling_iterations before any
# innings finished — so r_result (55% of the composite) was never earned.
#
chat_template_kwargs = {}
generation_kwargs = {}
completion_len = max(args.max_completion_length, 2048)
use_vllm = getattr(args, "use_vllm", False)
vllm_kwargs = {}
if use_vllm:
# vllm_model_impl: None (default) → vLLM picks its native class. Use this for
# Qwen3-* (Qwen3ForCausalLM is registered, native path with full LoRA support).
# Set to "transformers" only for Qwen3.5-* where vLLM has no text-only class
# registered and the native path tries to load a vision tower.
vllm_kwargs = dict(
use_vllm=True,
vllm_mode="colocate",
vllm_gpu_memory_utilization=getattr(args, "vllm_gpu_memory", 0.5),
vllm_max_model_length=completion_len + 2048,
)
vllm_impl = getattr(args, "vllm_model_impl", None)
if vllm_impl:
vllm_kwargs["vllm_model_impl"] = vllm_impl
# Resolve hyperparameters from YAML/CLI with sensible fallbacks.
lr = args.learning_rate if getattr(args, "learning_rate", None) is not None \
else (2e-5 if args.stage == 1 else 1e-5)
grpo_beta = getattr(args, "beta", None)
grpo_temp = getattr(args, "temperature", None) or 0.8
grpo_top_p = getattr(args, "top_p", None)
grad_ckpt = getattr(args, "gradient_checkpointing", None)
grad_ckpt_kwargs = None
if grad_ckpt and getattr(args, "gradient_checkpointing_use_reentrant", None) is not None:
grad_ckpt_kwargs = {"use_reentrant": bool(args.gradient_checkpointing_use_reentrant)}
optional_cfg = {}
if grpo_beta is not None:
optional_cfg["beta"] = grpo_beta
if grpo_top_p is not None:
optional_cfg["top_p"] = grpo_top_p
if grad_ckpt is not None:
optional_cfg["gradient_checkpointing"] = bool(grad_ckpt)
if grad_ckpt_kwargs is not None:
optional_cfg["gradient_checkpointing_kwargs"] = grad_ckpt_kwargs
if getattr(args, "dataloader_pin_memory", None) is not None:
optional_cfg["dataloader_pin_memory"] = bool(args.dataloader_pin_memory)
if getattr(args, "dataloader_num_workers", None) is not None:
optional_cfg["dataloader_num_workers"] = int(args.dataloader_num_workers)
config = GRPOConfig(
output_dir=out_dir,
logging_dir=str(log_dir / "tensorboard"),
num_train_epochs=1,
max_steps=args.steps,
per_device_train_batch_size=args.batch_size,
gradient_accumulation_steps=args.grad_accum,
learning_rate=lr,
warmup_ratio=0.05,
lr_scheduler_type="cosine",
logging_steps=args.logging_steps,
save_steps=getattr(args, "save_steps", None) or 10,
save_total_limit=getattr(args, "save_total_limit", None) or 20,
bf16=True,
max_completion_length=completion_len,
num_generations=args.num_generations,
max_tool_calling_iterations=args.max_tool_calling_iterations,
temperature=grpo_temp,
report_to=args.report_to,
run_name=args.run_name,
log_completions=True,
seed=args.seed,
chat_template_kwargs=chat_template_kwargs,
generation_kwargs=generation_kwargs,
**optional_cfg,
**vllm_kwargs,
)
# TRL's add_response_schema pattern-matches tokenizer.chat_template against
# a fixed list and raises "Unrecognized chat template" if no match. Some
# newer Qwen3 builds (e.g. Qwen3-4B-Instruct-2507, Aug 2025) ship a
# template that differs from TRL's stored string (the Instruct release
# dropped the enable_thinking block) — but the tool-call format
# (<tool_call>…</tool_call>) is identical, so the appropriate schema still
# parses correctly. We assign it manually before GRPOTrainer init; TRL
# checks `response_schema is None` first so this is a safe override.
if getattr(tokenizer, "response_schema", None) is None:
try:
from trl.chat_template_utils import qwen3_schema, qwen3_5_schema
m = args.model.lower()
if "qwen3.5" in m or "qwen3_5" in m:
tokenizer.response_schema = qwen3_5_schema
print("Set tokenizer.response_schema = qwen3_5_schema (manual override).")
elif "qwen3" in m:
tokenizer.response_schema = qwen3_schema
print("Set tokenizer.response_schema = qwen3_schema (manual override).")
except ImportError:
pass
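    # For reference, both schemas parse the Qwen-style tool-call block, e.g.
    # (whitespace illustrative; exact rendering depends on template version):
    #
    #   <tool_call>
    #   {"name": "play_delivery", "arguments": {"shot_intent": "single"}}
    #   </tool_call>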
trainer = GRPOTrainer(
model=model,
reward_funcs=environment_reward,
args=config,
train_dataset=dataset,
processing_class=tokenizer,
environment_factory=CricketCaptainToolEnv,
)
print(f"\nStarting training ({args.steps} steps, {len(dataset)} prompts) …")
trainer.train()
model.save_pretrained(save_dir)
tokenizer.save_pretrained(save_dir)
print(f"\nSaved → {save_dir}")
# ------------------------------------------------------------------ #
# Quick eval: run N episodes with the trained model #
# ------------------------------------------------------------------ #
def evaluate(args):
"""Run N episodes and print coherence + score stats."""
from server.reward_calculator import compute_episode_reward, get_dls_par
model, tokenizer = load_model(args.model)
model.eval()
def generate(prompt: str) -> str:
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
out = model.generate(
**inputs, max_new_tokens=200, temperature=0.7, do_sample=True,
pad_token_id=tokenizer.eos_token_id,
)
return tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
rng = random.Random(args.seed)
all_coh, all_scores = [], []
for ep in range(args.eval_episodes):
env = CricketEnvironment()
obs = env.reset(seed=rng.randint(0, 99999), options={
"task": "stage2_full",
"random_start": False,
"agent_team": args.agent_team,
})
steps = 0
while not obs.done and steps < 150:
prompt = _format_prompt(obs.prompt_text)
raw = generate(prompt)
data = _parse_completion(raw)
if data:
action = CricketAction(tool=data["tool"], arguments=data.get("arguments", {}))
else:
action = CricketAction(tool="invalid_json", arguments={})
obs = env.step(action)
steps += 1
state = env.state
avg_coh = sum(state.coherence_scores) / len(state.coherence_scores) if state.coherence_scores else 0
all_coh.append(avg_coh)
all_scores.append(state.total_score)
print(f" Ep {ep+1}: {state.total_score}/{state.wickets_lost} coh={avg_coh:.3f}")
print(f"\nAvg coherence: {sum(all_coh)/len(all_coh):.3f}")
print(f"Avg score: {sum(all_scores)/len(all_scores):.1f}")
def _make_run_folder(prefix: str, model: str | None, opponent_mode: str | None, max_overs: int | None) -> Path:
"""Create a timestamped illustrations folder, return its path."""
ts = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M")
    model_short = model.split("/")[-1][:20] if model else "heuristic"
overs_str = f"_{max_overs}ov" if max_overs else ""
opp_str = f"_{opponent_mode}" if opponent_mode else ""
folder_name = f"exp_{ts}_{prefix}{overs_str}{opp_str}_{model_short}"
run_dir = Path(__file__).parent / "illustrations" / folder_name
run_dir.mkdir(parents=True, exist_ok=True)
return run_dir
def train_smoke(args):
"""Run short direct-environment training rollouts without loading a model."""
rng = random.Random(args.seed)
roster = _training_roster(args.agent_team)
# Auto-create run folder unless --output explicitly given
if args.output:
output_path = Path(args.output)
output_path.parent.mkdir(parents=True, exist_ok=True)
run_dir = output_path.parent
else:
model_hint = getattr(args, "model", None)
run_dir = _make_run_folder("train_smoke", model_hint, args.opponent_mode, args.max_overs)
output_path = run_dir / "run_output.txt"
# Write header immediately so the file exists while the run is in progress
header_lines = [
"# Training smoke: direct CricketEnvironment rollout",
f"matches={args.matches} max_overs={args.max_overs} opponent_mode={args.opponent_mode}",
"purpose=verify one short training-style match rollout, prompt collection, tool stepping, and terminal reward",
"",
]
output_path.write_text("\n".join(header_lines))
def log(msg: str):
print(msg)
with open(output_path, "a") as _f:
_f.write(msg + "\n")
log("# Training smoke: direct CricketEnvironment rollout")
log(f"matches={args.matches} max_overs={args.max_overs} opponent_mode={args.opponent_mode}")
log("purpose=verify one short training-style match rollout, prompt collection, tool stepping, and terminal reward")
for match_idx in range(args.matches):
match_start = time.perf_counter()
last_step_time = match_start
env = CricketEnvironment()
obs = env.reset(seed=rng.randint(0, 99999), options={
"task": "stage2_full",
"random_start": False,
"max_overs": args.max_overs,
"eval_pack_id": args.eval_pack_id,
"opponent_mode": args.opponent_mode,
"opponent_cache_path": args.opponent_cache_path,
"agent_team": args.agent_team,
})
prompts = [_format_prompt(obs.prompt_text)]
total_reward = 0.0
steps = 0
log(f"\n--- match {match_idx + 1} reset ---")
log(f"initial_state={obs.game_state} phase={obs.strategic_phase} t_elapsed=0.000s tools={obs.available_tools}")
while not obs.done and steps < args.max_steps:
step_start = time.perf_counter()
previous_game_state = obs.game_state
previous_innings = obs.innings_type
previous_available_tools = obs.available_tools
action = _random_action(
rng,
obs.game_state,
obs.available_tools,
obs.current_bowler.get("type") if obs.current_bowler else None,
roster,
)
obs = env.step(action)
step_end = time.perf_counter()
step_dt = step_end - step_start
elapsed = step_end - match_start
since_prev = step_end - last_step_time
last_step_time = step_end
total_reward += obs.reward or 0.0
if not obs.done:
prompts.append(_format_prompt(obs.prompt_text))
changed_context = (
obs.done
or obs.game_state != previous_game_state
or obs.innings_type != previous_innings
or obs.available_tools != previous_available_tools
)
if steps < args.log_steps or changed_context:
over = obs.game_context.get("over", 0) or 0
ball = obs.game_context.get("ball", 0) or 0
score = obs.game_context.get("score", 0) or 0
balls_used = int(over) * 6 + int(ball)
balls_left = max(args.max_overs * 6 - balls_used, 0)
current_rr = score / (balls_used / 6) if balls_used > 0 else 0.0
if obs.target is not None:
runs_required = max(obs.target - score, 0)
required_rr = runs_required / max(balls_left / 6, 1 / 6)
chase_context = f"need={runs_required} balls_left={balls_left} rrr={required_rr:.2f}"
else:
chase_context = "need=None balls_left=None rrr=None"
outcome_meta = (obs.last_outcome or {}).get("metadata", {})
tactical_context = ""
if outcome_meta and action.tool in {"bowl_delivery", "play_delivery"}:
delivery = outcome_meta.get("delivery_features", {})
tactical_context = (
f" event={outcome_meta.get('event_type')} zone={outcome_meta.get('target_area')} "
f"traj={outcome_meta.get('trajectory')} field_effect={outcome_meta.get('fielder_effect')} "
f"fit={outcome_meta.get('shot_delivery_fit')} field_pressure={outcome_meta.get('field_pressure')} "
f"line={delivery.get('line')} length={delivery.get('length')} variation={delivery.get('variation')}"
)
log(
f"step={steps:03d} t_elapsed={elapsed:.3f}s step_dt={step_dt:.4f}s since_prev={since_prev:.4f}s "
f"tool={action.tool} reward={(obs.reward or 0.0):.3f} "
f"state={obs.game_state}/{obs.innings_type} phase={obs.strategic_phase} "
f"over={obs.game_context.get('over')}.{obs.game_context.get('ball')} "
f"score={obs.game_context.get('score')}/{obs.game_context.get('wickets')} "
f"target={obs.target} rr={current_rr:.2f} {chase_context} "
f"{tactical_context} "
f"tools={obs.available_tools} "
f"last={obs.last_ball_result[:140]!r}"
)
steps += 1
state = env.state
match_elapsed = time.perf_counter() - match_start
log(f"\n--- match {match_idx + 1} final ---")
log(f"done={obs.done} steps={steps} prompts_collected={len(prompts)} rollout_reward_sum={total_reward:.3f} match_elapsed={match_elapsed:.3f}s avg_step_dt={(match_elapsed / max(steps, 1)):.4f}s")
log(f"score={state.total_score}/{state.wickets_lost} over={state.over}.{state.ball} target={state.target} game_state={state.game_state}")
log(f"last_outcome={state.last_outcome}")
log(f"match_result={state.match_result} reward_breakdown={state.reward_breakdown}")
log(f"innings_rewards={state.innings_rewards}")
log(f"tool_calls={state.tool_calls_made} dream11_scores={state.dream11_scores}")
log(f"mean_coherence={(sum(state.coherence_scores) / len(state.coherence_scores)) if state.coherence_scores else 0.0:.3f}")
log(f"mean_adaptation={(sum(state.adaptation_scores) / len(state.adaptation_scores)) if state.adaptation_scores else 0.0:.3f}")
log(f"mean_opponent_awareness={(sum(state.opponent_awareness_scores) / len(state.opponent_awareness_scores)) if state.opponent_awareness_scores else 0.0:.3f}")
print(f"\nwrote={output_path}")
# Write README for the run
readme_path = run_dir / "README.md"
model_str = getattr(args, "model", None) or "heuristic (random actions)"
readme_path.write_text(
f"## Train-Smoke Run: {run_dir.name}\n\n"
f"**Date**: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M')}\n\n"
f"**Config**: `{getattr(args, 'config', None) or 'defaults'}`\n\n"
f"| Setting | Value |\n|---|---|\n"
f"| Matches | {args.matches} |\n"
f"| Max overs | {args.max_overs} |\n"
f"| Opponent mode | {args.opponent_mode} |\n"
f"| Model (train target) | `{model_str}` |\n\n"
f"See `run_output.txt` for full step-by-step rollout log, reward breakdowns, and coherence scores.\n"
)
print(f"wrote={readme_path}")
# ------------------------------------------------------------------ #
# CLI #
# ------------------------------------------------------------------ #
def _apply_yaml_defaults(args, cfg: dict) -> None:
"""Merge YAML config values into args, CLI args take precedence."""
captain = cfg.get("captain", {}) or {}
opponent = cfg.get("opponent", {}) or {}
env_cfg = cfg.get("env", {}) or {}
train_cfg = cfg.get("train", {}) or {}
def _set(attr, val):
if val is not None and getattr(args, attr, None) is None:
setattr(args, attr, val)
if getattr(args, "cmd", None) == "train":
_set("model", train_cfg.get("model"))
else:
_set("model", captain.get("model"))
_set("api_base", captain.get("api_base"))
_set("api_key", os.environ.get(captain.get("api_key_env", "")) or None)
_set("eval_pack_id", env_cfg.get("eval_pack_id"))
_set("opponent_mode", opponent.get("mode"))
_set("opponent_cache_path", opponent.get("cache_path"))
_set("opponent_model", opponent.get("model"))
_set("opponent_api_base", opponent.get("api_base"))
api_key_env = opponent.get("api_key_env")
_set("opponent_api_key", os.environ.get(api_key_env, "") if api_key_env else None)
_set("max_overs", env_cfg.get("max_overs"))
_set("agent_team", env_cfg.get("agent_team"))
_set("steps", train_cfg.get("steps"))
_set("prompts", train_cfg.get("prompts"))
_set("batch_size", train_cfg.get("batch_size"))
_set("grad_accum", train_cfg.get("grad_accum"))
_set("stage", train_cfg.get("stage"))
_set("num_generations", train_cfg.get("num_generations"))
_set("max_completion_length", train_cfg.get("max_completion_length"))
_set("max_tool_calling_iterations", train_cfg.get("max_tool_calling_iterations"))
_set("logging_steps", train_cfg.get("logging_steps"))
_set("report_to", train_cfg.get("report_to"))
_set("run_name", train_cfg.get("run_name"))
_set("learning_rate", train_cfg.get("learning_rate"))
_set("beta", train_cfg.get("beta"))
_set("temperature", train_cfg.get("temperature"))
_set("top_p", train_cfg.get("top_p"))
_set("gradient_checkpointing", train_cfg.get("gradient_checkpointing"))
_set("gradient_checkpointing_use_reentrant", train_cfg.get("gradient_checkpointing_use_reentrant"))
_set("dataloader_pin_memory", train_cfg.get("dataloader_pin_memory"))
_set("dataloader_num_workers", train_cfg.get("dataloader_num_workers"))
_set("bf16_base", train_cfg.get("bf16_base"))
_set("save_steps", train_cfg.get("save_steps"))
_set("save_total_limit", train_cfg.get("save_total_limit"))
_set("resume_from", train_cfg.get("resume_from"))
_set("overs_distribution", train_cfg.get("overs_distribution"))
_set("use_vllm", train_cfg.get("use_vllm"))
_set("vllm_gpu_memory", train_cfg.get("vllm_gpu_memory"))
_set("vllm_model_impl", train_cfg.get("vllm_model_impl"))
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--config", default=None, help="YAML config path (sets defaults for all subcommands)")
sub = parser.add_subparsers(dest="cmd")
# train
t = sub.add_parser("train", help="Run GRPO training")
t.add_argument("--config", default=None)
t.add_argument("--stage", type=int, default=None, choices=[1, 2])
t.add_argument("--model", default=None)
t.add_argument("--prompts", type=int, default=None, help="Game state prompts to collect")
t.add_argument("--steps", type=int, default=None, help="GRPOTrainer max_steps")
t.add_argument("--batch-size", type=int, default=None, dest="batch_size")
t.add_argument("--grad-accum", type=int, default=None, dest="grad_accum")
t.add_argument("--num-generations", type=int, default=None, dest="num_generations")
t.add_argument("--agent-team", default=None, dest="agent_team")
t.add_argument("--opponent-mode", default=None, choices=["heuristic", "llm_live", "llm_cached", "cricsheet"], dest="opponent_mode")
t.add_argument("--opponent-model", default=None, dest="opponent_model")
t.add_argument("--opponent-api-base", default=None, dest="opponent_api_base")
t.add_argument("--opponent-api-key", default=None, dest="opponent_api_key")
t.add_argument("--max-overs", type=int, default=None, dest="max_overs")
t.add_argument("--eval-pack-id", default=None, dest="eval_pack_id")
t.add_argument("--opponent-cache-path", default=None, dest="opponent_cache_path")
t.add_argument("--max-completion-length", type=int, default=None, dest="max_completion_length")
t.add_argument("--max-tool-calling-iterations", type=int, default=None, dest="max_tool_calling_iterations")
t.add_argument("--logging-steps", type=int, default=None, dest="logging_steps")
t.add_argument("--report-to", default=None, dest="report_to")
t.add_argument("--run-name", default=None, dest="run_name")
t.add_argument("--seed", type=int, default=42)
t.add_argument("--resume-from", default=None, dest="resume_from",
help="Path to a previous LoRA adapter dir (e.g. ./checkpoints/stage2_final). "
"When set, the adapter is loaded on top of the base model instead of a fresh init.")
t.add_argument("--save-steps", type=int, default=None, dest="save_steps")
t.add_argument("--save-total-limit", type=int, default=None, dest="save_total_limit")
t.add_argument("--learning-rate", type=float, default=None, dest="learning_rate")
t.add_argument("--beta", type=float, default=None, dest="beta",
help="GRPO KL coefficient. Lower = more exploration.")
t.add_argument("--temperature", type=float, default=None, dest="temperature")
t.add_argument("--top-p", type=float, default=None, dest="top_p")
t.add_argument("--gradient-checkpointing", action="store_true", dest="gradient_checkpointing", default=None)
t.add_argument("--no-gradient-checkpointing", action="store_false", dest="gradient_checkpointing")
t.add_argument("--gradient-checkpointing-use-reentrant", action="store_true",
dest="gradient_checkpointing_use_reentrant", default=None)
t.add_argument("--dataloader-pin-memory", action="store_true", dest="dataloader_pin_memory", default=None)
t.add_argument("--dataloader-num-workers", type=int, default=None, dest="dataloader_num_workers")
t.add_argument("--use-vllm", action="store_true", dest="use_vllm", default=None,
help="Use vLLM-backed rollouts (colocate). Forces bf16 base.")
t.add_argument("--bf16-base", action="store_true", dest="bf16_base", default=None,
help="Load base model in bf16 instead of 4-bit NF4. Faster matmul on H200 since 4B fits in 8GB.")
t.add_argument("--vllm-gpu-memory", type=float, default=0.5, dest="vllm_gpu_memory",
help="Fraction of GPU memory reserved for vLLM (colocate). Default 0.5.")
t.add_argument("--vllm-model-impl", default=None, dest="vllm_model_impl",
choices=["transformers", "vllm"],
help="vLLM model backend. None (default) = native vLLM class (e.g. Qwen3ForCausalLM); "
"'transformers' = HF transformers backend (workaround for Qwen3.5 — flaky with LoRA).")
# eval
e = sub.add_parser("eval", help="Evaluate a checkpoint")
e.add_argument("--config", default=None)
e.add_argument("--model", default=None)
e.add_argument("--eval-episodes", type=int, default=10, dest="eval_episodes")
e.add_argument("--agent-team", default=None, dest="agent_team")
e.add_argument("--seed", type=int, default=0)
# quick test (no GPU needed)
sub.add_parser("test", help="Smoke-test reward functions")
smoke = sub.add_parser("train-smoke", help="Run short direct-env training rollouts without loading a model")
smoke.add_argument("--config", default=None)
smoke.add_argument("--matches", type=int, default=1)
smoke.add_argument("--max-overs", type=int, default=None, dest="max_overs")
smoke.add_argument("--max-steps", type=int, default=240, dest="max_steps")
smoke.add_argument("--log-steps", type=int, default=30, dest="log_steps")
smoke.add_argument("--eval-pack-id", default=None, dest="eval_pack_id")
smoke.add_argument("--opponent-mode", default=None, choices=["heuristic", "llm_live", "llm_cached", "cricsheet"], dest="opponent_mode")
smoke.add_argument("--opponent-cache-path", default=None, dest="opponent_cache_path")
smoke.add_argument("--agent-team", default=None, dest="agent_team")
smoke.add_argument("--output", default=None)
smoke.add_argument("--seed", type=int, default=42)
sft = sub.add_parser("sft-data", help="Generate Stage 0 supervised tool-format examples")
sft.add_argument("--output", default="./data/training/tool_sft_examples.jsonl")
sft.add_argument("--examples", type=int, default=240)
sft.add_argument("--agent-team", default=None, dest="agent_team")
sft.add_argument("--seed", type=int, default=42)
args = parser.parse_args()
# Apply YAML config (subcommand --config overrides top-level --config)
config_path = getattr(args, "config", None) or getattr(parser.parse_known_args()[0], "config", None)
if config_path:
try:
from config_yaml import load_config
except ImportError:
from cricket_captain.config_yaml import load_config
_apply_yaml_defaults(args, load_config(config_path))
# Set safe defaults after YAML merge
if getattr(args, "stage", None) is None:
args.stage = 1
if getattr(args, "model", None) is None:
args.model = "Qwen/Qwen3.5-4B"
if getattr(args, "steps", None) is None:
args.steps = 200
if getattr(args, "prompts", None) is None:
args.prompts = 500
if getattr(args, "batch_size", None) is None:
args.batch_size = 2
if getattr(args, "grad_accum", None) is None:
args.grad_accum = 4
if getattr(args, "eval_pack_id", None) is None:
args.eval_pack_id = "adaptive_t20_v1"
if getattr(args, "opponent_mode", None) is None:
args.opponent_mode = "llm_live"
if getattr(args, "max_overs", None) is None:
args.max_overs = 5
if getattr(args, "agent_team", None) is None:
args.agent_team = os.environ.get("CRICKET_AGENT_TEAM")
if getattr(args, "max_tool_calling_iterations", None) is None:
args.max_tool_calling_iterations = 200
if getattr(args, "logging_steps", None) is None:
args.logging_steps = 1
if getattr(args, "report_to", None) is None:
args.report_to = "none"
if args.cmd == "train":
train(args)
elif args.cmd == "eval":
evaluate(args)
elif args.cmd == "test":
_smoke_test(args.agent_team, args.opponent_mode)
elif args.cmd == "train-smoke":
train_smoke(args)
elif args.cmd == "sft-data":
generate_sft_examples(args.output, args.examples, args.seed, args.agent_team)
else:
parser.print_help()
def _smoke_test(agent_team: str | None, opponent_mode: str):
"""Verify reward functions work correctly."""
cases = [
(
"[CricketCaptain] BATTING | SECOND INNINGS\nPhase: MIDDLE | Bowler: SPIN\n"
"Batting Strategy: consolidate aggression=0.30 rotate strike against spin in middle overs",
'{"tool": "play_delivery", "arguments": {"shot_intent": "single", "explanation": "rotating"}}',
"high",
),
(
"[CricketCaptain] BATTING | SECOND INNINGS\nPhase: MIDDLE | Bowler: SPIN\n"
"Batting Strategy: consolidate aggression=0.30 rotate strike against spin in middle overs",
'{"tool": "play_delivery", "arguments": {"shot_intent": "six", "explanation": "going big"}}',
"low",
),
(
"[CricketCaptain] BATTING | FIRST INNINGS\nPhase: POWERPLAY\nBatting Strategy: None",
'{"tool": "play_delivery", "arguments": {"shot_intent": "boundary", "explanation": "attack"}}',
"zero",
),
(
"[CricketCaptain] BATTING | SECOND INNINGS\nPhase: DEATH\nBatting Strategy: None",
"not valid json at all",
"zero",
),
]
print("Reward function smoke test:\n")
for prompt, completion, expected in cases:
fmt = r_validity(completion)
coh = r_behavior_stateless(prompt, completion)
print(f" expected={expected:4s} | fmt={fmt:.0f} | coh={coh:.3f} | {completion[:60]}")
print("\nPrompt collection test (5 prompts):")
p = collect_prompts(5, task="stage1_format", seed=1, agent_team=agent_team, opponent_mode=opponent_mode)
for i, pp in enumerate(p):
print(f" [{i}] {pp[:80].strip()} …")
if __name__ == "__main__":
main()