sync: today's source updates (XML-only prompt, reward unclip, neg-reward on loss, pinned versions, configs reorg)
- .gitignore +3 -0
- PRD.md +2 -0
- README.md +279 -68
- compare_eval.py +395 -0
- configs/cricket_train_qwen3.yaml +99 -0
- configs/cricket_train_qwen3_smoke.yaml +99 -0
- configs/cricket_train_qwen3_warmup.yaml +98 -0
- configs/extras/cached_eval.yaml +18 -0
- configs/extras/cricket_train.yaml +125 -0
- configs/extras/cricket_train_warmup.yaml +95 -0
- configs/extras/default.yaml +59 -0
- configs/game_knowledge.yaml +16 -5
- docs/benchmark_explainer.md +190 -451
- docs/experiment_workflow.md +161 -355
- docs/slides.html +36 -31
- openenv.yaml +13 -12
- pyproject.toml +15 -4
- scripts/eval_all_checkpoints.sh +95 -0
- scripts/generate_training_plots.py +248 -0
- scripts/run_full_pipeline.sh +84 -0
- scripts/run_warmup_then_main.sh +43 -0
- server/coherence_grader.py +27 -20
- server/cricket_environment.py +57 -5
- server/markov_engine.py +24 -13
- server/reward_calculator.py +31 -11
- train.py +1178 -147
.gitignore CHANGED

```diff
@@ -14,9 +14,12 @@ checkpoints/
 training_curves.png
 training_summary.json
 wandb/
+logs/
 .env
 unsloth_compiled_cache/
 .ipynb_checkpoints/
 .DS_Store
 *.log
 *.zip
+checkpoints_smoke/
+.venv-qwen3/
```
PRD.md CHANGED

```diff
@@ -128,6 +128,8 @@ Each step returns a `CricketObservation` containing:
 
 The top-level objective remains long-horizon match success over many simulated matches. Dream11-style reward is auxiliary shaping, not the primary benchmark target.
 
+**Tool budget (operational constraint during play and training):** per over, the environment allows **3 no-fine “overhead” tool calls** among `set_strategy`, `set_bowling_strategy`, `plan_delivery`, `reflect_after_ball`, and `analyze_situation`. Each additional overhead call in that over applies a **−0.04** step reward. `plan_shot`, `set_match_plan`, `update_match_plan`, and ball-advancing tools do **not** count against this limit. Training via `train.py` (TRL GRPO with `CricketEnvironment`) uses the same rule, so the policy learns to ration analysis and re-planning across a full innings without a separate ad-hoc budget in the trainer.
+
 ### 5.4 Curriculum Stages
 
 | Stage | Episodes | Active Rubrics | Objective |
```
README.md CHANGED
@@ -12,12 +12,22 @@ license: mit

# CricketCaptain-LLM

**An OpenEnv benchmark for long-horizon, multi-turn agentic RL: train an LLM to captain a full cricket match.**

🔗 **Live Hugging Face Space:** https://huggingface.co/spaces/pratinavseth/cricket-captain-llm
📊 **Training runs (W&B):** https://wandb.ai/ptnv-s-research/huggingface
📝 **Mini-blog / video:** _coming soon — placeholder, will link before submission deadline_

**Hackathon theme alignment — Theme #2: (Super) Long-Horizon Planning & Instruction Following.** A T20 match is up to **240 legal balls of strategic decision-making across both offensive and defensive roles** — exactly the regime where current LLM agents struggle: deep multi-step reasoning, sparse terminal reward (win/loss arrives 100+ turns after the early decisions that caused it), and recovery from early mistakes (a wicket in over 1 should reshape the plan for over 19).

CricketCaptain is a serious test of an LLM's ability to handle **long-horizon planning under sparse, delayed reward** with **opponent modeling** and **multi-action-type credit assignment**. One match is ~180 sequential tool calls across batting AND bowling phases, with 13 state-conditioned tools, a real Markov ball-outcome engine trained on **1.65M Cricsheet deliveries**, and a composite reward signal that scores plan coherence, tactical adaptation, opponent awareness, regret, and the final match outcome.

This is the agentic-RL equivalent of training a coding agent — partial trajectory rollouts, dense intermediate rewards, sparse terminal outcome — but in a domain where the strategy space is genuinely **two-sided** (you don't just write code; you also *anticipate the opposing captain*) and the action distribution is **mixed discrete-continuous** (categorical tool choice + numeric aggression + free-text rationale).

The Hugging Face Space exposes the OpenEnv server and a Gradio demo UI at `/web`.

**Why this is novel:** no prior agentic-RL benchmark covers strategic, two-sided sports captaincy with phase-aware tool gating, real-data outcome simulation, and a composable rubric system. Cricket is also one of the few domains where *the same agent must alternate between offensive and defensive policies* within a single episode — a capability gap most multi-turn agent benchmarks (SWE-bench, WebArena, AgentBench) don't touch.

---

## The Problem
@@ -36,6 +46,41 @@ CricketCaptain evaluates:

---

## Train → Eval Story (the framing)

This is **multi-turn agentic RL with partial-trajectory training and full-task eval generalization** — the same pattern coding-agent RL papers (SWE-RL, AgentR) use.

| | TRAINING (warmup) | TRAINING (main) | EVALUATION |
|---|---|---|---|
| max overs | 2–3 (curriculum) | 5 (end-to-end) | full T20 (20 overs) |
| steps | 30 | 100 | n/a |
| token budget per rollout | 16k (≈ 80–180 turns) | 24k (≈ 120–240 turns) | unlimited (full match plays out) |
| reward signal | composite (`r_result` + `r_cricket` + `r_behavior` + `r_validity`) — see `server/reward_calculator.py` | same | **headline metric: win rate vs heuristic baseline** |
| what the model learns | format mastery + per-state decisions on short formats | full-match strategic depth on 5-over matches | composes per-state decisions into match-winning trajectories |
| script | `python train.py train --config configs/cricket_train_qwen3_warmup.yaml` | `python train.py train --config configs/cricket_train_qwen3.yaml` | `compare_eval.py` (baseline + trained, prints comparison) |

Training operates on shorter formats than evaluation, but the trained policy generalizes to full matches at inference because it learns good per-state decisions, not specific trajectory lengths. The whole chain runs via `bash scripts/run_warmup_then_main.sh`.

## Results

**Live W&B project:** https://wandb.ai/ptnv-s-research/huggingface

Training stack: Qwen3-4B-Instruct-2507 + LoRA r=64 + TRL GRPO, vLLM colocate, 1× H200. Warmup runs on a 2–3 over curriculum (30 steps); the main run trains on 5-over end-to-end matches (100 steps) and resumes from the warmup adapter.

Static plots will land here once the chain completes:
- **Training reward curve:** [docs/plots/training_reward_over_steps.png](docs/plots/training_reward_over_steps.png)
- **Per-rubric breakdown:** [docs/plots/per_rubric_breakdown.png](docs/plots/per_rubric_breakdown.png)
- **Tool-call execution frequency:** [docs/plots/tool_call_frequency.png](docs/plots/tool_call_frequency.png)
- **Before/after comparison:** [docs/plots/before_after_comparison.png](docs/plots/before_after_comparison.png)

### Key engineering findings (documented in commit history)

| Issue surfaced | Fix | Effect |
|---|---|---|
| Only ~19% of rollouts reached `done` naturally | Restrict the system prompt to a single `<tool_call>...</tool_call>` XML format (the prompt previously also advertised bare JSON, which TRL's response-schema parser rejects, ending the rollout) | tools/call_frequency 9 → 73; rollouts 5–8× longer; matches actually play out |
| GRPO group std collapsing to 0 once matches completed | Remove the `[-1, 1]` reward clip — let GRPO standardize the advantage itself | reward std 0.0 → 1.5; gradient signal restored |
| Composite reward stayed positive even on dominant losses | Add an explicit `outcome_bonus = -1.0` for losses (was 0.0); reduce the always-positive `progress_bonus` cap | composite now spans negative AND positive — the model has a real reason to win |
## Read This First

- [`docs/benchmark_explainer.md`](docs/benchmark_explainer.md): full explanation of the problem statement, OpenEnv architecture, environment loop, rewards, data curation, training, and competition compliance.
```
|
| 104 |
|
| 105 |
### 2. Batting (when agent bats)
|
| 106 |
+
Choose a real roster batter, set/update the long-horizon plan, plan the shot, then play each delivery.
|
| 107 |
```json
|
| 108 |
+
{"tool": "set_match_plan", "arguments": {"powerplay_intent": "Use V Kohli and NT Tilak Varma to build a stable platform.", "middle_intent": "Rotate against spin and attack weak matchups.", "death_intent": "Use finishers for boundary options.", "risk_budget": "Escalate only with wickets in hand or target pressure.", "trigger_conditions": "Review after wickets, phase changes, or repeated dots/boundaries.", "rationale": "Roster-aware plan for a short chase."}}
|
| 109 |
+
{"tool": "select_batter", "arguments": {"name": "V Kohli", "style": "balanced", "aggression": 0.45, "rationale": "Reliable top-order batter to control risk early."}}
|
| 110 |
{"tool": "set_strategy", "arguments": {"phase_intent": "consolidate", "aggression": 0.35, "rationale": "Middle overs against spin — rotate strike and preserve wickets."}}
|
| 111 |
{"tool": "plan_shot", "arguments": {"shot_intent": "single", "target_area": "midwicket", "risk": "low", "trajectory": "ground", "rationale": "Field is spread, so take the easy gap."}}
|
| 112 |
{"tool": "play_delivery", "arguments": {"shot_intent": "single", "explanation": "Working into the gap at mid-wicket."}}
|
| 113 |
```
|
| 114 |
|
| 115 |
### 3. Bowling & Fielding (when agent bowls)
|
| 116 |
+
Choose a real roster bowler, set a delivery/field plan, then bowl each delivery against an opponent policy.
|
| 117 |
```json
|
| 118 |
+
{"tool": "choose_bowler", "arguments": {"name": "BB Sran", "bowler_type": "pace", "style": "economy", "rationale": "Use a roster pacer in the powerplay with a new ball."}}
|
| 119 |
{"tool": "set_bowling_strategy", "arguments": {"bowler_type": "pace", "delivery_type": "yorker", "line": "stumps", "length": "full"}}
|
| 120 |
{"tool": "set_field_setting", "arguments": {"setting": "Aggressive"}}
|
| 121 |
{"tool": "plan_delivery", "arguments": {"bowler_type": "pace", "delivery_type": "yorker", "line": "stumps", "length": "full", "rationale": "Limit swing room against an aggressive finisher."}}
|
|
|
|
| 154 |
| `bowl_delivery` | Bowling | Bowl the next delivery |
|
| 155 |
| `reflect_after_ball` | Bat/Bowl | Record post-ball tactical adjustment |
|
| 156 |
| `analyze_situation` | Any | Query pitch, bowler, field, or match situation |
|
| 157 |
+
| `set_match_plan` | Bat/Bowl | Establish powerplay/middle/death plan, risk budget, and triggers |
|
| 158 |
+
| `update_match_plan` | Bat/Bowl | Revise match plan with a match-state reason |
|
| 159 |
|
| 160 |
`plan_shot.target_area` is normalized into cricket zones such as `cover`, `point`, `straight`, `midwicket`, `square_leg`, `fine_leg`, `long_on`, and `long_off`. `plan_shot.trajectory` can be `ground`, `lofted`, or `aerial`. Delivery plans normalize line (`outside_off`, `stumps`, `pads`, `wide`), length (`yorker`, `full`, `good`, `short`, `bouncer`), and variation (`stock`, `swing`, `seam`, `slower`, `yorker`, `bouncer`, `off_spin`, `leg_spin`, `googly`).
|
| 161 |
|
|
|
|
@@ -115,22 +163,32 @@ Query match intel at a small reward cost.

## Reward Architecture

A **composable 4-rubric composite** following the SWE-RL recipe: most of the weight sits on dense intermediate signals, with a smaller share on the terminal outcome (roughly 80% intermediate / 20% terminal, per the table below) — chosen because partial-trajectory training (where most episodes truncate before completion) needs gradient signal that actually fires. Putting most weight on the rare terminal reward washes out learning.

| Rubric | Weight | When | What |
|--------|--------|------|------|
| `r_cricket` | **45%** | Per ball | Dream11-style proxy — runs, wickets, dots, boundaries, economy, milestones |
| `r_behavior` | **25%** | Every turn | Coherence (50%) + adaptation (20%) + opponent_awareness (20%) + regret (10%) |
| `r_result` | **20%** | Innings/episode end | Match outcome: chase progress, defense margin, win bonus, DLS par |
| `r_validity` | **10%** | Every turn | Valid tool-call structure and legal phase-gated tool use |

Plus a **progress bonus** added to `r_result`: `min(0.25, tool_calls_made / 40.0)` — caps at +0.25 once the agent makes ≥10 tool calls. It directly rewards escaping the "planning loop" trap (where the policy maxes out overhead tools without ever calling `play_delivery`).

`r_tools` is computed and logged for analysis but excluded from the composite — tool discipline is measured through outcome and behavior instead.

**Why this weighting works:** in partial-trajectory training every single turn produces a reward (validity + behavior + per-ball Dream11 whenever a ball is bowled); the terminal `r_result` only fires when an episode actually completes. The weights are calibrated to put gradient on the dense signals that fire most often, mirroring the SWE-RL / coding-agent-RL recipe.
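Putting the table and the bonus together, one scored step composes roughly as follows. This is an illustrative sketch only: the real computation lives in `server/reward_calculator.py` and `configs/game_knowledge.yaml`, and `composite_reward` is a made-up name.

```python
# Illustrative composite, not the repo's reward_calculator. Weights come from the
# table above, the progress bonus from the formula above; everything else is hypothetical.
def composite_reward(r_cricket, r_behavior, r_result, r_validity, tool_calls_made):
    progress_bonus = min(0.25, tool_calls_made / 40.0)   # caps at +0.25 after 10 tool calls
    r_result = r_result + progress_bonus                 # bonus is folded into the terminal component
    return (0.45 * r_cricket      # per-ball Dream11-style proxy
          + 0.25 * r_behavior     # coherence / adaptation / opponent awareness / regret
          + 0.20 * r_result       # match outcome (+ progress bonus)
          + 0.10 * r_validity)    # well-formed, phase-legal tool calls

# A dominated loss now drags the composite below zero:
print(composite_reward(r_cricket=0.0, r_behavior=0.1, r_result=-1.0,
                       r_validity=0.5, tool_calls_made=4))   # ≈ -0.105
```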
**Single-stage training (full composite reward from step 0):**
Qwen3-4B-Instruct-2507 emits `<tool_call>...</tool_call>` natively, so we skip the legacy "format mastery" warm-up and run the full composite reward (`r_result + r_cricket + r_behavior + r_validity`) from step 0. The internal `curriculum_stage` field is still set to `2` for code-path compatibility — it just means "full reward".

**Two-config workflow:**
- [`configs/cricket_train_qwen3_warmup.yaml`](configs/cricket_train_qwen3_warmup.yaml) — short 2–3 over curriculum, 30 steps. Bootstraps the LoRA adapter on a fast format.
- [`configs/cricket_train_qwen3.yaml`](configs/cricket_train_qwen3.yaml) — 5-over end-to-end, 100 steps. `resume_from: ./checkpoints/stage2_final` picks up the warmup adapter.

**Innings-specific scoring:**
- **1st Innings (batting):** Score vs DLS par baseline
@@ -189,7 +247,7 @@ cricket_captain/

```
├── inference.py              # Random + LLM agent evaluation
├── client.py                 # OpenEnv WebSocket client (CricketCaptainEnv)
├── models.py                 # GameState, CricketAction, CricketObservation, CricketState
├── train.py                  # TRL GRPO agent training with environment_factory tool calls
├── eval.py                   # Coherence heatmaps, reward curves
├── scripts/
│   └── curate_transitions.py # Cricsheet → Markov transition table pipeline
```
@@ -234,9 +292,19 @@

`server/player_roster.py` loads team profiles from `data/player_profiles/` (10 T20I squads: India, Australia, England, Pakistan, South Africa, New Zealand, West Indies, Sri Lanka, Bangladesh, Afghanistan). When the agent calls `select_batter` or `choose_bowler` with a player name, the roster performs fuzzy lookup (exact → surname → word-overlap) and fills in real aggression, batting/bowling style, and phase strengths.

### Tool budget (per over)

The simulator counts **strategic / analysis** tools that do not advance the ball. The constants live in `CricketEnvironment` as `TOOL_BUDGET_PER_OVER` (3) and `TOOL_FINE_PER_EXCESS` (0.04).

**Overhead tools (count toward the 3 per over):** `set_strategy`, `set_bowling_strategy`, `plan_delivery`, `reflect_after_ball`, `analyze_situation`.

**Not overhead (no per-over counter):** `plan_shot`, `set_match_plan`, `update_match_plan`, `select_batter`, `choose_bowler`, `set_field_setting`, `play_delivery`, `bowl_delivery`, `call_toss`, and other execution tools.

The first **three** overhead calls in a given over are free; each additional overhead call in that over applies an immediate **−0.04** step reward. The prompt shows `Tool budget: N/3 overhead calls used this over` so the model can learn to ration reflection and re-planning.

### Tool budget and training

`train.py train` rollouts are full environment episodes: the same fines apply on every `step` the policy takes. Over a long match, repeatedly burning the budget (for example, `analyze_situation` or `reflect_after_ball` on most balls) **accumulates** many small penalties and competes with match-outcome and behavior rewards. GRPO therefore sees a direct signal to use overhead tools when they change decisions, not as padding. Long-horizon **match plans** (`set_match_plan` / `update_match_plan`) are not charged against this overhead budget, so the agent can state multi-phase intent without spending the three "slots" on raw analysis calls.
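Schematically, the accounting above looks like this. Only the constant names and values are taken from this README; the actual logic lives in `server/cricket_environment.py`, so treat the wiring as an illustration:

```python
# Illustrative per-over overhead accounting (not the environment's real code).
TOOL_BUDGET_PER_OVER = 3
TOOL_FINE_PER_EXCESS = 0.04
OVERHEAD_TOOLS = {"set_strategy", "set_bowling_strategy", "plan_delivery",
                  "reflect_after_ball", "analyze_situation"}

overhead_used_this_over = 0   # reset whenever a new over starts

def overhead_fine(tool_name: str) -> float:
    """Return the step-reward penalty (0.0 or -0.04) for one tool call."""
    global overhead_used_this_over
    if tool_name not in OVERHEAD_TOOLS:
        return 0.0                      # plan_shot, set_match_plan, play_delivery, ... are never fined
    overhead_used_this_over += 1
    if overhead_used_this_over <= TOOL_BUDGET_PER_OVER:
        return 0.0                      # first three overhead calls in the over are free
    return -TOOL_FINE_PER_EXCESS        # every further overhead call this over costs -0.04

print([overhead_fine("analyze_situation") for _ in range(5)])
# -> [0.0, 0.0, 0.0, -0.04, -0.04]
```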
### Markov Engine
@@ -250,81 +318,201 @@

Bowler rotation mirrors real cricket: pace-heavy powerplay (90/10), spin-heavy middle overs (45/55), pace-heavy death (80/20). Each bowler has a 10-over cap before rotation is forced.

### GRPO Agent Training

`train.py train` uses TRL `GRPOTrainer` with `environment_factory=CricketCaptainToolEnv`. The trainer creates live `CricketEnvironment` instances, exposes captaincy actions as tool methods, and lets the model interact over multiple tool-calling turns instead of scoring isolated prompt/completion strings.

The dataset is a set of seeded cricket scenarios. Each rollout resets the environment with `agent_team`, `opponent_mode`, `max_overs`, and optional eval-pack/cache settings. Rewards come back from the environment after real state transitions, so training sees wickets, targets, role swaps, plan updates, and terminal match results.
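Stripped of the trainer, a single rollout reduces to the loop below. The environment API matches how `compare_eval.py` drives it; the fixed `play_delivery` action is a stand-in for the policy, and the `reward` attribute read off the observation is an assumption for illustration:

```python
# One rollout against the live environment, with a stub in place of the policy.
# During training the policy is the LLM emitting one <tool_call> per turn, and the
# per-step rewards feed GRPO's group-relative advantages.
from server.cricket_environment import CricketEnvironment
from models import CricketAction

env = CricketEnvironment()
obs = env.reset(seed=0, options={"max_overs": 2, "opponent_mode": "heuristic", "agent_team": "india"})

total = 0.0
for _ in range(200):                    # hard cap so the sketch always terminates
    if obs.done:
        break
    # stand-in policy; a real rollout picks a phase-legal tool from the 13 available
    obs = env.step(CricketAction(tool="play_delivery",
                                 arguments={"shot_intent": "single", "explanation": "rotate strike"}))
    total += getattr(obs, "reward", 0.0)   # assumption: the observation carries the step reward

print("episode reward:", total)
```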
The training model is configured separately from the inference/demo models:

- `train.model`: the model being optimized with GRPO, currently `Qwen/Qwen3-4B-Instruct-2507` (256k native context, native `Qwen3ForCausalLM` in vLLM, no thinking blocks).
- `opponent.model`: the live opponent model used when `opponent.mode=llm_live`, currently `google/gemma-4-26B-A4B-it`.
- `captain.model`: the inference/evaluation captain model used by `inference.py` and `eval.py`.
---

## Quickstart

### 1. Install (uv)

System requirements: Python 3.10+, CUDA 12.x for training. Inference, random baselines, and the Gradio UI work CPU-only.

This project is managed with [uv](https://docs.astral.sh/uv/). Versions in [pyproject.toml](pyproject.toml) are pinned to a known-working set (transformers 5.6.2 + trl 1.2.0 + vllm 0.19.1 + torch 2.10.0) — the lowest combination that supports TRL multi-turn `environment_factory` AND vLLM colocate AND transformers v5 chat templates. Earlier vLLM pins (<0.19) force transformers <5 and break multi-turn training.

```bash
# Clone and enter
git clone <this-repo> cricket-captain-llm
cd cricket-captain-llm

# Create a venv and install core + training extras + eval plots
uv venv .venv --python 3.10
uv pip install --python .venv/bin/python -e ".[train,eval]"

# Activate
source .venv/bin/activate

# HuggingFace login (only needed for gated models / model downloads)
huggingface-cli login
# or: export HF_TOKEN=hf_...
```

Inference-only (no GPU) install:

```bash
uv venv .venv --python 3.10
uv pip install --python .venv/bin/python -e .
```

### 2. YAML config (single source of truth)

All commands read defaults from a YAML — pass `--config configs/cricket_train_qwen3.yaml` (or one of the others under `configs/`) and only override what you need on the CLI. Four config groups (a quick way to inspect them is sketched after the list below):

| Group | Keys | Used by |
|---|---|---|
| `env.*` | `agent_team`, `max_overs`, `eval_pack_id`, `env_url` | server, inference, training |
| `opponent.*` | `mode`, `model`, `api_base`, `api_key_env`, `cache_path` | server (heuristic / cricsheet / llm_live / llm_cached) |
| `captain.*` | `model`, `api_base`, `api_key_env` | inference / eval (when using a live captain LLM) |
| `train.*` | `model`, `resume_from`, `stage`, `prompts`, `steps`, `batch_size`, `grad_accum`, `num_generations`, `max_completion_length`, `max_tool_calling_iterations`, `learning_rate`, `beta`, `temperature`, `top_p`, `gradient_checkpointing`, `gradient_checkpointing_use_reentrant`, `dataloader_pin_memory`, `dataloader_num_workers`, `bf16_base`, `save_steps`, `save_total_limit`, `report_to`, `run_name` | GRPO trainer |

- `configs/cricket_train_qwen3_warmup.yaml` — **GRPO warmup** (2–3 over curriculum, 30 steps). Run first.
- `configs/cricket_train_qwen3.yaml` — **GRPO main** (5-over end-to-end, 100 steps). Resumes from the warmup adapter via `resume_from:`.
- `configs/cricket_train_qwen3_smoke.yaml` — 2-step infrastructure smoke test.
- `configs/game_knowledge.yaml` — reward weights and game constants (loaded at import time).
- `configs/extras/` — legacy Qwen3.5 configs and `default.yaml`, kept for reference.
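To see what a given run will actually use, the config can be loaded directly. The group and key names below follow the table above; whichever keys a particular YAML omits simply won't appear in the dict:

```python
# Inspect a config's role groups before launching a run (PyYAML).
import yaml

with open("configs/cricket_train_qwen3.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg.get("env", {}).get("max_overs"), cfg.get("env", {}).get("agent_team"))
print(cfg.get("opponent", {}).get("mode"))
print(cfg.get("train", {}).get("model"), cfg.get("train", {}).get("steps"))
```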
### 3. Run the environment server (for inference / Gradio)

```bash
cd cricket-captain-llm
PYTHONPATH=. python server/app.py --port 8000 --config configs/extras/default.yaml
```

The server exposes the OpenEnv WebSocket at `ws://localhost:8000/ws` and a Gradio UI at `http://localhost:8000/web`.

### 4. Inference baselines (no training required)

```bash
export CRICKET_CAPTAIN_ENV_URL="ws://localhost:8000"

# Random agent (no API key needed) — fast sanity check
python inference.py --model random --episodes 5 --opponent-mode heuristic --max-overs 5

# Live HF Gemma captain baseline using config defaults
export HF_TOKEN=hf_...
python inference.py --config configs/default.yaml --episodes 1
```
### 5. GRPO training

Training does **not** need the server — it instantiates `CricketEnvironment` directly and runs it through TRL `GRPOTrainer` with `environment_factory`. All rollouts are simulated live (no static dataset).

The recommended workflow is **warmup → main run**, both controlled entirely from YAML.

**Single-command chain** (warmup → main, auto-resume from warmup adapter):
```bash
bash scripts/run_warmup_then_main.sh
# Logs: /tmp/train_warmup.log (then /tmp/train_main.log on success)
# Final adapter: ./checkpoints/stage2_final/
```

The chain is what we run to produce a trained model. Internally:

**Step 1 — Warmup (2–3 over curriculum, 30 steps, ~50–60 min on a single H200):**
- Curriculum-distributed `max_overs` (heavy on 2-over, tail to 3-over) so episodes complete inside the token budget and `r_result` fires reliably.
- Bootstraps the LoRA adapter from base Qwen3-4B-Instruct-2507 → saves to `./checkpoints/stage2_final/`.

**Step 2 — Main (5-over end-to-end, 100 steps, ~5–7 hrs):**
- Resumes the warmup adapter via `resume_from: ./checkpoints/stage2_final`.
- Trains on full 5-over matches with the `r_result` outcome signal as the dominant gradient driver.
- Final adapter at `./checkpoints/stage2_final/` (overwritten — that's the deliverable).

**Run components individually:**
```bash
# Warmup only
PYTORCH_ALLOC_CONF=expandable_segments:True \
python train.py train --config configs/cricket_train_qwen3_warmup.yaml

# Main only (assumes ./checkpoints/stage2_final/ exists)
PYTORCH_ALLOC_CONF=expandable_segments:True \
python train.py train --config configs/cricket_train_qwen3.yaml

# Main without resuming (fresh adapter)
python train.py train --config configs/cricket_train_qwen3.yaml --resume-from ""
```

The opponent is `heuristic` (rule-based) by default for fast iteration. Switch to `mode: llm_live` in `configs/cricket_train_qwen3.yaml` (and set `HF_TOKEN`) to train against the live Gemma adversary.
### 6. Evaluating the trained model

The eval-time story is the headline. Training caps rollouts at the warmup/main token budgets (16k / 24k), so warmup rollouts run 2–3 overs and main rollouts run 5 overs. At inference there is no token cap — full T20 matches play out. This is the same pattern coding-agent RL papers (SWE-RL, AgentR) use: train on partial windows, evaluate on full task completion.

```bash
# Baseline (untrained Qwen3-4B-Instruct-2507)
python compare_eval.py --model Qwen/Qwen3-4B-Instruct-2507 \
    --label baseline --episodes 20 --max-overs 5 \
    --output eval_results/baseline.json

# Trained (uses LoRA adapter)
python compare_eval.py --model Qwen/Qwen3-4B-Instruct-2507 \
    --adapter ./checkpoints/stage2_final \
    --label trained --episodes 20 --max-overs 5 \
    --output eval_results/trained.json

# Side-by-side comparison table
python compare_eval.py --compare \
    eval_results/baseline.json \
    eval_results/trained.json
```

Each `--episodes 20` run takes ~30–45 min depending on match length. The comparison prints a side-by-side table of match completion rate, win rate, mean agent score, mean wickets lost, mean tool calls, validity rate, and a per-rubric reward breakdown.
### 7. Tuning batch size for your GPU

The dominant memory consumer during GRPO is **attention prefill of the input prompts** (game state ≈ 1300 tokens) for `generation_batch_size = batch_size × grad_accum` simultaneous sequences.

| GPU VRAM | flash-attn? | Recommended `batch_size` × `grad_accum` | `max_completion_length` |
|---|---|---|---|
| 24 GB (A10/3090) | required | 1 × 4 = 4 | 512 |
| 46 GB (L40S) | recommended | 4 × 4 = 16 | 512 |
| 46 GB (L40S) | not installed | 2 × 4 = 8 | 512 |
| 80 GB (A100/H100) | required | 8 × 4 = 32 | 1024 |

Cricket tool calls are short JSON objects (~20–300 tokens), so `max_completion_length: 512` is plenty. Without flash-attn, SDPA allocates `O(seq_len²)` attention matrices — keep batches small or install flash-attn.

If you OOM, halve `batch_size` first. If you OOM on prefill specifically (during generation, not the gradient step), the culprit is prompt length × generation_batch_size — install flash-attn or shrink `grad_accum`.
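For intuition, the quadratic term warned about above scales roughly as follows. This is back-of-envelope only: the head count and dtype are assumptions, it counts only the attention score matrices, and real usage adds weights, KV cache, and activations on top:

```python
# Rough size of the attention score matrices SDPA materializes during prefill,
# for generation_batch_size = batch_size * grad_accum concurrent sequences.
# 32 heads and 2-byte (bf16) scores are assumptions for illustration.
def prefill_scores_gib(batch_size, grad_accum, prompt_tokens=1300, n_heads=32, bytes_per=2):
    generation_batch_size = batch_size * grad_accum
    return generation_batch_size * n_heads * prompt_tokens ** 2 * bytes_per / 1024 ** 3

print(f"{prefill_scores_gib(2, 4):.1f} GiB for 8 sequences")    # ~0.8 GiB of scores alone
print(f"{prefill_scores_gib(8, 4):.1f} GiB for 32 sequences")   # ~3.2 GiB, before anything else
```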
### 8. Where to find logs and outputs

| Path | What |
|---|---|
| stdout | Per-step loss, reward, lr (every `logging_steps=10`) plus full sampled completions |
| `checkpoints/stage{1,2}/` | HF Trainer state, intermediate LoRA checkpoints (every ~80 steps) |
| `checkpoints/stage{1,2}_final/` | Final LoRA + tokenizer |
| `illustrations/exp_*/run_output.txt` | Per-step environment trace from `inference.py` and `train.py train-smoke` |

For a TensorBoard dashboard, set `train.report_to: tensorboard` in the YAML, then:
```bash
tensorboard --logdir checkpoints/
```
### 9. Smoke test (no model load)

Verify the environment + opponent + tool budget end-to-end without loading any model:
```bash
PYTHONPATH=. python train.py train-smoke \
    --config configs/extras/default.yaml \
    --matches 1 --max-overs 2 \
    --opponent-mode heuristic
```
This runs one short match with random actions and writes a full step log to `illustrations/exp_*/`.

### Lightning / Remote Runtime Notes

`localhost:8000` only works when the agent process and server process are in the same network namespace. On Lightning, expose the server port and pass the resulting WebSocket URL via:
@@ -351,10 +539,9 @@ For fast iteration, start with short 5-over runs before full 20-over evaluation:

1. Random baseline with heuristic opponent.
2. Base/untrained LLM baseline.
3. GRPO warmup (`configs/cricket_train_qwen3_warmup.yaml`) for format mastery on short matches.
4. GRPO main (`configs/cricket_train_qwen3.yaml`) for full 5-over strategic depth, resuming the warmup adapter.
5. Eval with `adaptive_t20_v1` and cached opponent decisions.

See [`docs/experiment_workflow.md`](docs/experiment_workflow.md) for exact commands and rationale.
@@ -494,6 +681,30 @@

---

## Hackathon Submission Materials

| Material | Link |
|---|---|
| **Live HF Space** (env runs here) | https://huggingface.co/spaces/pratinavseth/cricket-captain-llm |
| **GitHub repo** | https://github.com/pratinavseth/cricket-captain-llm |
| **W&B project** (training runs) | https://wandb.ai/ptnv-s-research/huggingface |
| **Mini-blog (HF blog)** | _placeholder — will be added before submission_ |
| **Demo video (≤2 min, YouTube)** | _placeholder — will be added before submission_ |
| **Slide deck** | _optional — TBD_ |

### Submission checklist (per hackathon guidelines)

- [x] Use OpenEnv (latest release): `openenv-core[core]>=0.2.2` ([pyproject.toml](pyproject.toml))
- [x] Working training script (TRL GRPO): [train.py](train.py)
- [x] Reward + training pipeline coherent: 4-rubric composite (`r_result`, `r_cricket`, `r_behavior`, `r_validity`) with documented signal flow
- [x] HF Space pushed and discoverable
- [ ] Loss + reward plots from a real run (in progress — generated after the main run completes)
- [ ] Mini-blog or ≤2 min video (in progress)
- [x] README motivates the problem, explains the env, links materials
- [x] OpenEnv compliance: `Environment` base class, valid `openenv.yaml`, no reserved tool names

---

## Citation
compare_eval.py ADDED

@@ -0,0 +1,395 @@
```python
"""
compare_eval.py — Baseline vs trained head-to-head evaluation.

Plays N full cricket matches with the BASELINE model (untrained Qwen3.5-4B)
and the TRAINED model (Qwen3.5-4B + LoRA adapter from a checkpoint), then
dumps a comparison table:

    win_rate, mean_agent_score, mean_opp_score, mean_wickets, match_completion_rate,
    mean_tool_calls_per_episode, validity_rate, plus a few illustrative transcripts.

Why this is the right eval for our setup
----------------------------------------
Training uses a token budget per rollout (~4096 tokens, ~16 turns) which truncates
most matches. At EVAL time we lift that cap entirely — the model gets unlimited
context and can actually play full matches. This is the same pattern coding-agent
RL papers use: train on partial windows, eval on full task completion. The trained
policy generalizes because it learned good per-state decisions, not a specific
trajectory length.

Usage
-----
# Baseline (untrained Qwen3.5-4B base)
python compare_eval.py \\
    --model Qwen/Qwen3.5-4B \\
    --label baseline \\
    --episodes 20 --max-overs 5 \\
    --output eval_results/baseline.json

# Trained (warmup + main checkpoint)
python compare_eval.py \\
    --model Qwen/Qwen3.5-4B \\
    --adapter ./checkpoints/stage2_final \\
    --label trained \\
    --episodes 20 --max-overs 5 \\
    --output eval_results/trained.json

# Side-by-side comparison
python compare_eval.py --compare eval_results/baseline.json eval_results/trained.json
"""

import argparse
import json
import os
import sys
import time
from collections import Counter
from pathlib import Path

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

from server.cricket_environment import CricketEnvironment
from models import CricketAction
import train as train_module  # reuse SYSTEM_PROMPT and _parse_completion


# ----------------------------------------------------------------------------
# Model loading
# ----------------------------------------------------------------------------

def load_model_for_eval(model_name: str, adapter_path: str | None = None):
    """Load base model in bf16; optionally apply a LoRA adapter on top."""
    print(f"Loading base model: {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    if adapter_path:
        print(f"Loading LoRA adapter: {adapter_path}")
        model = PeftModel.from_pretrained(model, adapter_path, is_trainable=False)
    model.eval()
    return model, tokenizer


# ----------------------------------------------------------------------------
# Single-episode rollout (no token cap — let matches actually complete)
# ----------------------------------------------------------------------------

def play_one_episode(
    *,
    model,
    tokenizer,
    max_overs: int,
    opponent_mode: str,
    agent_team: str,
    eval_pack_id: str,
    seed: int,
    max_tool_calls: int = 800,
    max_completion_per_turn: int = 256,  # per-turn (NOT per-rollout) — eval is turn-by-turn
    temperature: float = 0.3,            # deterministic-ish at eval
    verbose: bool = False,
) -> dict:
    """Run one full match. Returns per-episode stats."""
    env = CricketEnvironment()
    obs = env.reset(seed=seed, options={
        "task": "stage2_full",
        "random_start": False,
        "max_overs": max_overs,
        "eval_pack_id": eval_pack_id,
        "opponent_mode": opponent_mode,
        "agent_team": agent_team,
    })

    # Build the message log progressively. Each turn appends model output + tool response.
    system_prompt = train_module.SYSTEM_PROMPT
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": obs.prompt_text},
    ]

    tool_calls_made = 0
    tool_breakdown: Counter = Counter()
    parse_failures = 0
    illegal_tool_attempts = 0
    start_t = time.time()

    while not obs.done and tool_calls_made < max_tool_calls:
        # Render chat using model's tool template
        try:
            inputs = tokenizer.apply_chat_template(
                messages,
                tokenize=True,
                add_generation_prompt=True,
                return_tensors="pt",
            ).to(model.device)
        except Exception as e:
            print(f"  apply_chat_template error: {e}")
            break

        with torch.no_grad():
            out = model.generate(
                inputs,
                max_new_tokens=max_completion_per_turn,
                do_sample=(temperature > 0),
                temperature=max(temperature, 1e-5),
                pad_token_id=tokenizer.pad_token_id,
            )
        gen_ids = out[0, inputs.shape[1]:]
        completion = tokenizer.decode(gen_ids, skip_special_tokens=False)

        # Parse the tool call
        parsed = train_module._parse_completion(completion)
        if parsed is None:
            parse_failures += 1
            if verbose:
                print(f"  PARSE FAIL: {completion[:200]}...")
            messages.append({"role": "assistant", "content": completion})
            messages.append({"role": "user", "content": "Your previous output was not parseable. Please emit exactly one tool call."})
            continue

        tool_name = parsed.get("tool", "")
        tool_args = parsed.get("arguments", {}) or {}
        tool_breakdown[tool_name] += 1

        # Apply to env
        try:
            obs = env.step(CricketAction(tool=tool_name, arguments=tool_args))
            tool_calls_made += 1
        except Exception as e:
            illegal_tool_attempts += 1
            if verbose:
                print(f"  ILLEGAL TOOL: {tool_name} → {e}")
            messages.append({"role": "assistant", "content": completion})
            messages.append({"role": "user", "content": f"Tool error: {e}. Try a different tool."})
            continue

        messages.append({"role": "assistant", "content": completion})
        messages.append({"role": "user", "content": obs.prompt_text})

    elapsed = time.time() - start_t
    state = env.state
    breakdown = state.reward_breakdown or {}

    # Determine match result
    is_complete = bool(obs.done)
    agent_score = int(state.total_score or 0)
    opp_score = int(state.first_innings_score or 0) if state.innings_type == "second" else None
    target = state.target
    won = None
    if is_complete:
        # Crude win check; env's match_result string is the canonical source
        result_str = (state.match_result or "").lower()
        if "won" in result_str and "agent" in result_str:
            won = True
        elif "lost" in result_str or "won" in result_str:
            won = False
        else:
            won = None

    return {
        "seed": seed,
        "max_overs": max_overs,
        "opponent_mode": opponent_mode,
        "tool_calls_made": tool_calls_made,
        "match_complete": is_complete,
        "won": won,
        "agent_score": agent_score,
        "opponent_first_innings_score": opp_score,
        "target": target,
        "wickets_lost": int(state.wickets_lost or 0),
        "match_result": state.match_result or "",
        "tool_breakdown": dict(tool_breakdown),
        "parse_failures": parse_failures,
        "illegal_tool_attempts": illegal_tool_attempts,
        "validity_rate": round(1.0 - (parse_failures + illegal_tool_attempts) / max(tool_calls_made + parse_failures + illegal_tool_attempts, 1), 4),
        "reward_breakdown": dict(breakdown),
        "elapsed_seconds": round(elapsed, 1),
    }


# ----------------------------------------------------------------------------
# Run N episodes
# ----------------------------------------------------------------------------

def run_n_episodes(
    *, model, tokenizer, episodes: int, max_overs: int, opponent_mode: str,
    agent_team: str, eval_pack_id: str, seed_base: int, max_tool_calls: int,
    max_completion_per_turn: int, temperature: float, verbose: bool,
) -> dict:
    results = []
    for i in range(episodes):
        seed = seed_base + i
        print(f"  [{i+1}/{episodes}] seed={seed} …", end="", flush=True)
        try:
            res = play_one_episode(
                model=model, tokenizer=tokenizer,
                max_overs=max_overs, opponent_mode=opponent_mode,
                agent_team=agent_team, eval_pack_id=eval_pack_id, seed=seed,
                max_tool_calls=max_tool_calls,
                max_completion_per_turn=max_completion_per_turn,
                temperature=temperature, verbose=verbose,
            )
            print(f" {res['tool_calls_made']} tool calls, "
                  f"{'COMPLETE' if res['match_complete'] else 'truncated'}, "
                  f"score {res['agent_score']}/{res['wickets_lost']}, "
                  f"{res['elapsed_seconds']}s")
            results.append(res)
        except Exception as e:
            print(f" FAILED: {e}")
            results.append({"seed": seed, "error": str(e)})

    # Aggregate
    valid = [r for r in results if "error" not in r]
    n = len(valid)
    if n == 0:
        return {"results": results, "summary": {"n": 0, "error": "all episodes failed"}}

    completed = [r for r in valid if r["match_complete"]]
    won = [r for r in completed if r.get("won") is True]
    summary = {
        "n_episodes": n,
        "match_completion_rate": round(len(completed) / n, 4),
        "win_rate_among_completed": round(len(won) / max(len(completed), 1), 4),
        "win_rate_overall": round(len(won) / n, 4),
        "mean_agent_score": round(sum(r["agent_score"] for r in valid) / n, 2),
        "mean_wickets_lost": round(sum(r["wickets_lost"] for r in valid) / n, 2),
        "mean_tool_calls": round(sum(r["tool_calls_made"] for r in valid) / n, 1),
        "mean_validity_rate": round(sum(r["validity_rate"] for r in valid) / n, 4),
        "mean_composite_reward": round(sum(r["reward_breakdown"].get("composite", 0.0) for r in valid) / n, 4),
        "mean_r_result": round(sum(r["reward_breakdown"].get("r_result", 0.0) for r in valid) / n, 4),
        "mean_r_cricket": round(sum(r["reward_breakdown"].get("r_cricket", 0.0) for r in valid) / n, 4),
        "mean_r_behavior": round(sum(r["reward_breakdown"].get("r_behavior", 0.0) for r in valid) / n, 4),
        "mean_r_validity": round(sum(r["reward_breakdown"].get("r_validity", 0.0) for r in valid) / n, 4),
        "tool_freq": {},
    }
    # Aggregate tool frequencies
    all_tools: Counter = Counter()
    for r in valid:
        for t, c in (r.get("tool_breakdown") or {}).items():
            all_tools[t] += c
    total = sum(all_tools.values()) or 1
    summary["tool_freq"] = {t: round(c / total, 3) for t, c in all_tools.most_common()}

    return {"results": results, "summary": summary}


# ----------------------------------------------------------------------------
# Comparison printer
# ----------------------------------------------------------------------------

def print_comparison(baseline_path: str, trained_path: str):
    with open(baseline_path) as f:
        b = json.load(f)
    with open(trained_path) as f:
        t = json.load(f)
    bs = b["summary"]
    ts = t["summary"]

    def row(label, key, fmt="{:.4f}"):
        bv = bs.get(key)
        tv = ts.get(key)
        b_str = fmt.format(bv) if bv is not None else "-"
        t_str = fmt.format(tv) if tv is not None else "-"
        delta = ""
        if isinstance(bv, (int, float)) and isinstance(tv, (int, float)):
            d = tv - bv
            delta = f" ({'+' if d >= 0 else ''}{d:.3f})"
        print(f"  {label:<32} {b_str:>12} {t_str:>12}{delta}")

    print(f"\n{'='*80}")
    print(f"BASELINE vs TRAINED — {bs['n_episodes']} episodes each")
    print(f"  baseline label: {b.get('label')} | trained label: {t.get('label')}")
    print(f"{'='*80}")
    print(f"  {'metric':<32} {'baseline':>12} {'trained':>12}")
    print(f"  {'-'*32} {'-'*12} {'-'*12}")
    row("match_completion_rate", "match_completion_rate")
    row("win_rate_overall", "win_rate_overall")
    row("win_rate_among_completed", "win_rate_among_completed")
    row("mean_agent_score", "mean_agent_score", "{:.2f}")
    row("mean_wickets_lost", "mean_wickets_lost", "{:.2f}")
    row("mean_tool_calls", "mean_tool_calls", "{:.1f}")
    row("mean_validity_rate", "mean_validity_rate")
    row("mean_composite_reward", "mean_composite_reward")
    row("mean_r_result", "mean_r_result")
    row("mean_r_cricket", "mean_r_cricket")
    row("mean_r_behavior", "mean_r_behavior")
    row("mean_r_validity", "mean_r_validity")
    print(f"{'='*80}\n")


# ----------------------------------------------------------------------------
# Main
# ----------------------------------------------------------------------------

def main():
    parser = argparse.ArgumentParser(description="Baseline vs trained eval for CricketCaptain.")
    parser.add_argument("--model", default="Qwen/Qwen3.5-4B", help="Base HF model id")
    parser.add_argument("--adapter", default=None, help="Optional LoRA adapter directory")
    parser.add_argument("--label", default="run", help="Label for this run (used in output)")
    parser.add_argument("--episodes", type=int, default=10)
    parser.add_argument("--max-overs", type=int, default=5)
    parser.add_argument("--opponent-mode", default="heuristic",
                        choices=["heuristic", "llm_live", "llm_cached", "cricsheet"])
    parser.add_argument("--agent-team", default="india")
    parser.add_argument("--eval-pack-id", default="adaptive_t20_v1")
    parser.add_argument("--seed-base", type=int, default=10000)
    parser.add_argument("--max-tool-calls", type=int, default=800)
    parser.add_argument("--max-completion-per-turn", type=int, default=256)
    parser.add_argument("--temperature", type=float, default=0.3)
    parser.add_argument("--output", default=None, help="JSON output path")
    parser.add_argument("--verbose", action="store_true")

    parser.add_argument("--compare", nargs=2, default=None, metavar=("BASELINE_JSON", "TRAINED_JSON"),
                        help="Skip eval; just print comparison from two existing JSON files")
    args = parser.parse_args()

    if args.compare:
        print_comparison(args.compare[0], args.compare[1])
        return

    print(f"\nCricketCaptain compare-eval — label='{args.label}'")
    print(f"  model={args.model}  adapter={args.adapter or '(none)'}")
    print(f"  {args.episodes} episodes × {args.max_overs} overs vs {args.opponent_mode} opponent\n")

    model, tokenizer = load_model_for_eval(args.model, args.adapter)

    out = run_n_episodes(
        model=model, tokenizer=tokenizer,
        episodes=args.episodes, max_overs=args.max_overs,
        opponent_mode=args.opponent_mode,
        agent_team=args.agent_team, eval_pack_id=args.eval_pack_id,
        seed_base=args.seed_base, max_tool_calls=args.max_tool_calls,
```
+
max_completion_per_turn=args.max_completion_per_turn,
|
| 370 |
+
temperature=args.temperature, verbose=args.verbose,
|
| 371 |
+
)
|
| 372 |
+
out["label"] = args.label
|
| 373 |
+
out["model"] = args.model
|
| 374 |
+
out["adapter"] = args.adapter
|
| 375 |
+
out["config"] = {
|
| 376 |
+
"episodes": args.episodes, "max_overs": args.max_overs,
|
| 377 |
+
"opponent_mode": args.opponent_mode, "agent_team": args.agent_team,
|
| 378 |
+
"max_tool_calls": args.max_tool_calls,
|
| 379 |
+
"max_completion_per_turn": args.max_completion_per_turn,
|
| 380 |
+
"temperature": args.temperature,
|
| 381 |
+
}
|
| 382 |
+
|
| 383 |
+
print("\n=== SUMMARY ===")
|
| 384 |
+
print(json.dumps(out["summary"], indent=2))
|
| 385 |
+
|
| 386 |
+
if args.output:
|
| 387 |
+
out_path = Path(args.output)
|
| 388 |
+
out_path.parent.mkdir(parents=True, exist_ok=True)
|
| 389 |
+
with out_path.open("w") as f:
|
| 390 |
+
json.dump(out, f, indent=2)
|
| 391 |
+
print(f"\nResults → {out_path}")
|
| 392 |
+
|
| 393 |
+
|
| 394 |
+
if __name__ == "__main__":
|
| 395 |
+
main()
|
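For downstream analysis (plots, notebooks), the saved JSON can also be read directly instead of going through `--compare`. A minimal sketch, assuming two files written via `--output`; the paths are placeholders:

```python
import json

# Placeholder paths: whatever you passed to --output for the two runs.
with open("results/baseline_qwen3.json") as f:
    baseline = json.load(f)
with open("results/trained_qwen3.json") as f:
    trained = json.load(f)

for key in ("match_completion_rate", "win_rate_overall", "mean_composite_reward"):
    b, t = baseline["summary"].get(key), trained["summary"].get(key)
    if isinstance(b, (int, float)) and isinstance(t, (int, float)):
        print(f"{key:<28} {b:.4f} -> {t:.4f}  (delta {t - b:+.4f})")
```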
configs/cricket_train_qwen3.yaml
ADDED
|
@@ -0,0 +1,99 @@
| 1 |
+
# =============================================================================
|
| 2 |
+
# CricketCaptain-LLM — Qwen3 MAIN run (5-over format, end-to-end matches)
|
| 3 |
+
#
|
| 4 |
+
# Differences from cricket_train.yaml:
|
| 5 |
+
# - Model: Qwen3-4B-Instruct-2507 (vs Qwen3.5-4B which has no vLLM class)
|
| 6 |
+
# - vLLM colocate enabled (3-5× throughput gain)
|
| 7 |
+
# - 5-over format instead of 20-over (20-over end-to-end is infeasible on 1× H200)
|
| 8 |
+
# - Resumes Qwen3 warmup adapter, not Qwen3.5 adapter
|
| 9 |
+
# - Reward clip removed at train.py:1240 (separately) — let GRPO standardize advantage
|
| 10 |
+
# - More steps (100 vs 30) since vLLM makes them affordable
|
| 11 |
+
#
|
| 12 |
+
# Stack: TRL 1.2.0 + transformers 5.6.2 + vLLM + PEFT (LoRA) + bf16 base
|
| 13 |
+
# Target hardware: 1× H200 (144 GB), single node
|
| 14 |
+
# =============================================================================
|
| 15 |
+
|
| 16 |
+
env:
|
| 17 |
+
eval_pack_id: adaptive_t20_v1
|
| 18 |
+
agent_team: india
|
| 19 |
+
max_overs: 5 # 5-over end-to-end. ~180 tool calls per match.
|
| 20 |
+
env_url: ws://localhost:8000
|
| 21 |
+
|
| 22 |
+
opponent:
|
| 23 |
+
mode: heuristic # llm_live optional but slows training ~2-3×
|
| 24 |
+
|
| 25 |
+
train:
|
| 26 |
+
# ---- Model & adapter ----
|
| 27 |
+
model: Qwen/Qwen3-4B-Instruct-2507
|
| 28 |
+
# Resume from the Qwen3 warmup adapter (NOT from Qwen3.5 stage2_final).
|
| 29 |
+
# Comment this line out to start the main run with a fresh adapter.
|
| 30 |
+
resume_from: ./checkpoints/stage2_final
|
| 31 |
+
stage: 2
|
| 32 |
+
|
| 33 |
+
# ---- Dataset ----
|
| 34 |
+
prompts: 256
|
| 35 |
+
|
| 36 |
+
# ---- Schedule ----
|
| 37 |
+
# 5-over rollouts ≈ ~180 tool calls each. With vLLM colocate at B=1 sims=4,
|
| 38 |
+
# roughly ~3-4 min/step. 100 steps ≈ ~5-7 hrs.
|
| 39 |
+
steps: 100
|
| 40 |
+
logging_steps: 1
|
| 41 |
+
save_steps: 20
|
| 42 |
+
save_total_limit: 5
|
| 43 |
+
|
| 44 |
+
# ---- Throughput knobs ----
|
| 45 |
+
# B=1 × grad_accum=4 × G=4 → gen_batch=4, 4 sims in flight per step.
|
| 46 |
+
# Matches the working warmup config. Warmup at B=2 / 24k OOM'd on backward
|
| 47 |
+
# due to gradient-accum buffer churn — at 32k completion it would OOM worse.
|
| 48 |
+
# KV cache ≈ 4 × 36k × 144 KB ≈ 21 GB, comfortable in vLLM 0.55 pool (78 GB).
|
| 49 |
+
batch_size: 1
|
| 50 |
+
grad_accum: 4
|
| 51 |
+
num_generations: 4
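The KV-cache figure quoted in the comment above is plain arithmetic. A back-of-envelope sketch using the same ~144 KB/token constant; the helper is illustrative, not part of train.py:

```python
KV_BYTES_PER_TOKEN = 144 * 1024  # per-token KV footprint quoted in the comment above


def kv_cache_gb(sims_in_flight: int, context_tokens: int) -> float:
    """Rough KV-cache size for the vLLM colocate pool."""
    return sims_in_flight * context_tokens * KV_BYTES_PER_TOKEN / 1e9


# batch_size=1 x num_generations=4 -> 4 sims in flight; ~36k tokens of context each.
print(f"{kv_cache_gb(4, 36_000):.1f} GB")   # ~21 GB, well inside the 0.55 pool (~78 GB)
print(f"{kv_cache_gb(4, 16_384):.1f} GB")   # smoke-config case, under 10 GB
```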
|
| 52 |
+
|
| 53 |
+
# ---- Length budget ----
|
| 54 |
+
# 24k completion. At 32k we OOM'd at step 7 of main (140+ GB used). 24k gives
|
| 55 |
+
# ~7 GB headroom for backward-pass activation memory. Full 5-over match needs
|
| 56 |
+
# ~9k tokens of model output, so 24k is 2.5x headroom — plenty.
|
| 57 |
+
max_completion_length: 24576
|
| 58 |
+
max_tool_calling_iterations: 240
|
| 59 |
+
|
| 60 |
+
# ---- Optimizer ----
|
| 61 |
+
learning_rate: 5.0e-6
|
| 62 |
+
|
| 63 |
+
# ---- GRPO ----
|
| 64 |
+
# beta=0.0: no reference model. With 24k completion and 4 sims in flight, the 8 GB
|
| 65 |
+
# ref model would push us past the H200's 144 GB. Format penalty in reward
|
| 66 |
+
# is the soft anchor instead.
|
| 67 |
+
beta: 0.0
|
| 68 |
+
temperature: 0.9
|
| 69 |
+
top_p: 0.95
|
| 70 |
+
|
| 71 |
+
# ---- Memory ----
|
| 72 |
+
gradient_checkpointing: true
|
| 73 |
+
gradient_checkpointing_use_reentrant: false
|
| 74 |
+
bf16_base: true
|
| 75 |
+
|
| 76 |
+
# ---- vLLM ----
|
| 77 |
+
use_vllm: true
|
| 78 |
+
vllm_gpu_memory: 0.55
|
| 79 |
+
|
| 80 |
+
# ---- Dataloader ----
|
| 81 |
+
dataloader_pin_memory: true
|
| 82 |
+
dataloader_num_workers: 4
|
| 83 |
+
|
| 84 |
+
# ---- Logging ----
|
| 85 |
+
report_to: wandb
|
| 86 |
+
run_name: cricket_qwen3_main
|
| 87 |
+
|
| 88 |
+
# =============================================================================
|
| 89 |
+
# Reward composition (defined in server/reward_calculator.py):
|
| 90 |
+
# composite = 0.35·r_result + 0.30·r_cricket + 0.25·r_behavior + 0.10·r_validity
|
| 91 |
+
#
|
| 92 |
+
# WATCH during training:
|
| 93 |
+
# - rollout/match_completion_rate ≥ 0.7 within ~50 steps
|
| 94 |
+
# → if not, episodes are still hitting cap; tighten per-turn schema or drop max_overs
|
| 95 |
+
# - reward/r_result_mean separate from composite
|
| 96 |
+
# → if r_result stays at 0 while composite rises, you're optimizing format only
|
| 97 |
+
# - episode/tool_calls_mean
|
| 98 |
+
# → should be ~150-200 for 5-over; >220 means truncation events are common
|
| 99 |
+
# =============================================================================
|
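Besides wandb, the per-episode log can be scanned directly to spot-check the numbers listed under WATCH. A rough sketch; the `episode_stats.jsonl` location and the `tool_calls` field name are assumptions (only `termination_reason` is documented in these configs), so adjust to whatever train.py actually writes:

```python
import json
from pathlib import Path

stats_path = Path("checkpoints/episode_stats.jsonl")  # assumed location
episodes = [json.loads(line) for line in stats_path.read_text().splitlines() if line.strip()]

completed = [e for e in episodes if e.get("termination_reason") == "natural"]
tool_calls = [e["tool_calls"] for e in episodes if "tool_calls" in e]  # assumed field name

print(f"episodes:              {len(episodes)}")
print(f"match_completion_rate: {len(completed) / max(len(episodes), 1):.2f}")
if tool_calls:
    print(f"tool_calls_mean:       {sum(tool_calls) / len(tool_calls):.1f}")
```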
configs/cricket_train_qwen3_smoke.yaml
ADDED
|
@@ -0,0 +1,99 @@
| 1 |
+
# =============================================================================
|
| 2 |
+
# CricketCaptain-LLM — Qwen3 SMOKE TEST (validate vLLM colocate works)
|
| 3 |
+
#
|
| 4 |
+
# Purpose: shortest possible run to confirm:
|
| 5 |
+
# 1. Qwen3-4B-Instruct-2507 loads cleanly into vLLM colocate
|
| 6 |
+
# (no Qwen3.5 architecture-class registration error)
|
| 7 |
+
# 2. TRL multi-turn environment_factory steps execute without crashing
|
| 8 |
+
# 3. At least one episode reaches `done=True` so `r_result` fires
|
| 9 |
+
#
|
| 10 |
+
# Expected runtime: ~10-15 min on a single H200.
|
| 11 |
+
# Run before kicking off cricket_train_qwen3_warmup.yaml.
|
| 12 |
+
#
|
| 13 |
+
# Stack target: conda cloudspace (torch 2.10, transformers 5.6.2, trl 1.2.0)
|
| 14 |
+
# + vllm + flash-attn installed via .venv-qwen3
|
| 15 |
+
# =============================================================================
|
| 16 |
+
|
| 17 |
+
env:
|
| 18 |
+
eval_pack_id: adaptive_t20_v1
|
| 19 |
+
agent_team: india
|
| 20 |
+
max_overs: 2 # smallest format — match must complete in ~70 turns
|
| 21 |
+
env_url: ws://localhost:8000
|
| 22 |
+
|
| 23 |
+
opponent:
|
| 24 |
+
mode: heuristic # deterministic-ish, no API costs
|
| 25 |
+
|
| 26 |
+
train:
|
| 27 |
+
# ---- Model ----
|
| 28 |
+
model: Qwen/Qwen3-4B-Instruct-2507
|
| 29 |
+
# 256k native context, no <think> blocks, native Qwen3ForCausalLM in vLLM.
|
| 30 |
+
# Fresh adapter — do NOT load Qwen3.5-trained weights into Qwen3 base.
|
| 31 |
+
# resume_from intentionally omitted.
|
| 32 |
+
|
| 33 |
+
stage: 2
|
| 34 |
+
|
| 35 |
+
# ---- Dataset ----
|
| 36 |
+
prompts: 16 # tiny — smoke only
|
| 37 |
+
|
| 38 |
+
# ---- Schedule ----
|
| 39 |
+
steps: 2 # absolute minimum to test gradient + save
|
| 40 |
+
logging_steps: 1
|
| 41 |
+
save_steps: 2
|
| 42 |
+
save_total_limit: 1
|
| 43 |
+
|
| 44 |
+
# ---- Throughput knobs ----
|
| 45 |
+
# bs=1 + grad_accum=4 + G=4 → generation_batch_size=4 divides G cleanly
|
| 46 |
+
# (TRL 1.2 GRPOConfig requires bs*grad_accum divisible by num_generations).
|
| 47 |
+
# 4 sim episodes in flight. KV cache ≈ 4 × 16k × 144 KB ≈ 9.5 GB → tiny.
|
| 48 |
+
batch_size: 1
|
| 49 |
+
grad_accum: 4
|
| 50 |
+
num_generations: 4
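The divisibility rule described in the comment above is easy to break when tweaking these three knobs together. A standalone sanity check (not something train.py ships):

```python
def check_grpo_batching(batch_size: int, grad_accum: int, num_generations: int) -> None:
    """Verify the batching arithmetic spelled out in the comments above."""
    generation_batch = batch_size * grad_accum
    if generation_batch % num_generations != 0:
        raise ValueError(
            f"batch_size*grad_accum={generation_batch} is not divisible by "
            f"num_generations={num_generations}"
        )
    print(f"generation batch {generation_batch}, sims in flight {batch_size * num_generations}")


check_grpo_batching(batch_size=1, grad_accum=4, num_generations=4)  # this config: OK
```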
|
| 51 |
+
|
| 52 |
+
# ---- Length budget ----
|
| 53 |
+
# 2-over needs ~70 tool calls. At <120 tok/turn this fits in ~8k.
|
| 54 |
+
# Generous 16k completion to catch any per-turn bloat.
|
| 55 |
+
max_completion_length: 16384
|
| 56 |
+
max_tool_calling_iterations: 120
|
| 57 |
+
|
| 58 |
+
# ---- Optimizer ----
|
| 59 |
+
learning_rate: 5.0e-6
|
| 60 |
+
|
| 61 |
+
# ---- GRPO ----
|
| 62 |
+
beta: 0.0 # no reference model (saves ~8 GB VRAM)
|
| 63 |
+
temperature: 0.9
|
| 64 |
+
top_p: 0.95
|
| 65 |
+
|
| 66 |
+
# ---- Memory ----
|
| 67 |
+
gradient_checkpointing: true
|
| 68 |
+
gradient_checkpointing_use_reentrant: false
|
| 69 |
+
bf16_base: true
|
| 70 |
+
|
| 71 |
+
# ---- vLLM colocate (THE thing being tested) ----
|
| 72 |
+
use_vllm: true
|
| 73 |
+
vllm_gpu_memory: 0.50
|
| 74 |
+
# vllm_model_impl omitted → vLLM picks Qwen3ForCausalLM natively.
|
| 75 |
+
# If you fall back to Qwen3.5-4B for some reason, set this to "transformers".
|
| 76 |
+
|
| 77 |
+
# ---- Dataloader ----
|
| 78 |
+
dataloader_pin_memory: true
|
| 79 |
+
dataloader_num_workers: 2
|
| 80 |
+
|
| 81 |
+
# ---- Logging ----
|
| 82 |
+
report_to: wandb
|
| 83 |
+
run_name: cricket_qwen3_smoke
|
| 84 |
+
|
| 85 |
+
# =============================================================================
|
| 86 |
+
# Run with:
|
| 87 |
+
# source .venv-qwen3/bin/activate
|
| 88 |
+
# python train.py train --config configs/cricket_train_qwen3_smoke.yaml
|
| 89 |
+
#
|
| 90 |
+
# Pass criteria:
|
| 91 |
+
# - 2 gradient steps complete without OOM
|
| 92 |
+
# - logs show `rollout/match_completion_rate > 0`
|
| 93 |
+
# - at least one episode in episode_stats.jsonl has `termination_reason: natural`
|
| 94 |
+
#
|
| 95 |
+
# If pass → run cricket_train_qwen3_warmup.yaml
|
| 96 |
+
# If fail with "Qwen3_5..." → wrong model name, check spelling
|
| 97 |
+
# If fail with vLLM class error → vLLM build doesn't include Qwen3 support, upgrade
|
| 98 |
+
# If fail with LoRA error → known TRL+vLLM+LoRA issue, set vllm_model_impl: transformers
|
| 99 |
+
# =============================================================================
|
configs/cricket_train_qwen3_warmup.yaml
ADDED
|
@@ -0,0 +1,98 @@
| 1 |
+
# =============================================================================
|
| 2 |
+
# CricketCaptain-LLM — Qwen3 WARMUP (2-3 over curriculum, fast iterations)
|
| 3 |
+
#
|
| 4 |
+
# Differences from cricket_train_warmup.yaml:
|
| 5 |
+
# - Model: Qwen3-4B-Instruct-2507 (256k native, no <think>, native vLLM class)
|
| 6 |
+
# - vLLM colocate enabled (works because Qwen3ForCausalLM is a registered class)
|
| 7 |
+
# - Fresh adapter (do NOT resume Qwen3.5 LoRA on Qwen3 base — incompatible)
|
| 8 |
+
# - Slightly tighter schedule given vLLM throughput gain (~3-5x)
|
| 9 |
+
#
|
| 10 |
+
# Stack: TRL 1.2.0 + transformers 5.6.2 + vLLM + PEFT (LoRA) + bf16 base
|
| 11 |
+
# Target hardware: 1× H200 (144 GB)
|
| 12 |
+
# =============================================================================
|
| 13 |
+
|
| 14 |
+
env:
|
| 15 |
+
eval_pack_id: adaptive_t20_v1
|
| 16 |
+
agent_team: india
|
| 17 |
+
max_overs: 0 # 0 = use overs_distribution below
|
| 18 |
+
env_url: ws://localhost:8000
|
| 19 |
+
|
| 20 |
+
opponent:
|
| 21 |
+
mode: heuristic # fast iteration; switch to llm_live for final eval
|
| 22 |
+
|
| 23 |
+
train:
|
| 24 |
+
# ---- Model & adapter ----
|
| 25 |
+
model: Qwen/Qwen3-4B-Instruct-2507
|
| 26 |
+
stage: 2 # full composite reward
|
| 27 |
+
# No resume_from — start fresh on Qwen3 base.
|
| 28 |
+
|
| 29 |
+
# ---- Dataset ----
|
| 30 |
+
prompts: 64
|
| 31 |
+
|
| 32 |
+
# ---- Schedule ----
|
| 33 |
+
# 2-3 over rollouts ≈ ~70-110 tool calls each. With vLLM colocate, ~2-3 min/step
|
| 34 |
+
# at 16 sim episodes. 30 steps ≈ ~1.5 hrs total.
|
| 35 |
+
steps: 30
|
| 36 |
+
logging_steps: 1
|
| 37 |
+
save_steps: 5
|
| 38 |
+
save_total_limit: 5
|
| 39 |
+
|
| 40 |
+
# ---- Curriculum (per-scenario max_overs) ----
|
| 41 |
+
# Heavier on T2 (cleanly completes in token budget), tail to T3.
|
| 42 |
+
# Skip T4/T5 in warmup — those go in the main run.
|
| 43 |
+
overs_distribution: [2, 2, 2, 2, 2, 2, 3, 3, 3]
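A sketch of the sampling semantics these comments imply (a fixed positive `max_overs` wins; otherwise each scenario draws its format from the list). The helper is illustrative; the real hook lives in train.py and may cycle rather than sample:

```python
import random

overs_distribution = [2, 2, 2, 2, 2, 2, 3, 3, 3]  # this config's warmup curriculum


def pick_max_overs(dist: list[int], max_overs: int, rng: random.Random) -> int:
    """Per-scenario format: a positive max_overs overrides, otherwise draw from the list."""
    return max_overs if max_overs > 0 else rng.choice(dist)


rng = random.Random(0)
print([pick_max_overs(overs_distribution, 0, rng) for _ in range(9)])  # mostly 2s, some 3s
```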
|
| 44 |
+
|
| 45 |
+
# ---- Throughput knobs ----
|
| 46 |
+
# B=1 × grad_accum=4 × G=4 → gen_batch=4, 4 sims in flight per step.
|
| 47 |
+
# Matches the smoke config that ran cleanly. Slower per step (~30-40s) than
|
| 48 |
+
# B=2 (~55s), but B=2 OOM'd on backward of step 2 due to gradient-accumulation
|
| 49 |
+
# micro-batch buffer churn even with expandable_segments. 30 steps ≈ ~20 min.
|
| 50 |
+
batch_size: 1
|
| 51 |
+
grad_accum: 4
|
| 52 |
+
num_generations: 4
|
| 53 |
+
|
| 54 |
+
# ---- Length budget ----
|
| 55 |
+
# 24k completion = ~130 tok/turn × 180 turns. Generous for 2-3 over format,
|
| 56 |
+
# matches Qwen3-4B-Instruct-2507 recommendation of ≥32k output for most queries
|
| 57 |
+
# (per-rollout cumulative across multi-turn).
|
| 58 |
+
max_completion_length: 24576
|
| 59 |
+
max_tool_calling_iterations: 240
|
| 60 |
+
|
| 61 |
+
# ---- Optimizer ----
|
| 62 |
+
learning_rate: 5.0e-6
|
| 63 |
+
|
| 64 |
+
# ---- GRPO ----
|
| 65 |
+
# beta=0.0: no reference model (saves ~8 GB VRAM, lets G=4 fit at 24k completion).
|
| 66 |
+
# Reward shaping has format penalty as soft anchor.
|
| 67 |
+
beta: 0.0
|
| 68 |
+
temperature: 0.9
|
| 69 |
+
top_p: 0.95
|
| 70 |
+
|
| 71 |
+
# ---- Memory ----
|
| 72 |
+
gradient_checkpointing: true
|
| 73 |
+
gradient_checkpointing_use_reentrant: false
|
| 74 |
+
bf16_base: true
|
| 75 |
+
|
| 76 |
+
# ---- vLLM (THE big change vs Qwen3.5 config) ----
|
| 77 |
+
use_vllm: true
|
| 78 |
+
vllm_gpu_memory: 0.55
|
| 79 |
+
# vllm_model_impl omitted → vLLM picks Qwen3ForCausalLM natively.
|
| 80 |
+
|
| 81 |
+
# ---- Dataloader ----
|
| 82 |
+
dataloader_pin_memory: true
|
| 83 |
+
dataloader_num_workers: 4
|
| 84 |
+
|
| 85 |
+
# ---- Logging ----
|
| 86 |
+
report_to: wandb
|
| 87 |
+
run_name: cricket_qwen3_warmup
|
| 88 |
+
|
| 89 |
+
# =============================================================================
|
| 90 |
+
# Workflow:
|
| 91 |
+
# 1. Smoke test (10-15 min):
|
| 92 |
+
# python train.py train --config configs/cricket_train_qwen3_smoke.yaml
|
| 93 |
+
# 2. Warmup (1-2 hrs):
|
| 94 |
+
# python train.py train --config configs/cricket_train_qwen3_warmup.yaml
|
| 95 |
+
# 3. Main run (5-7 hrs):
|
| 96 |
+
# python train.py train --config configs/cricket_train_qwen3.yaml
|
| 97 |
+
# (resumes from ./checkpoints/stage2_final saved by step 2)
|
| 98 |
+
# =============================================================================
|
configs/extras/cached_eval.yaml
ADDED
|
@@ -0,0 +1,18 @@
| 1 |
+
env:
|
| 2 |
+
# Used by server + runners for reproducible comparison runs.
|
| 3 |
+
eval_pack_id: adaptive_t20_v1
|
| 4 |
+
max_overs: 5
|
| 5 |
+
env_url: ws://localhost:8000
|
| 6 |
+
|
| 7 |
+
opponent:
|
| 8 |
+
# llm_cached does not call `model` live. It replays pre-generated decisions
|
| 9 |
+
# from cache_path so every captain model faces the same opponent behavior.
|
| 10 |
+
mode: llm_cached
|
| 11 |
+
cache_path: data/opponent_cache/adaptive_t20_v1_official_gemma2b.jsonl
|
| 12 |
+
|
| 13 |
+
captain:
|
| 14 |
+
# Captain still calls HF router live in this config.
|
| 15 |
+
model: google/gemma-4-26B-A4B-it
|
| 16 |
+
api_base: https://router.huggingface.co/v1
|
| 17 |
+
api_key_env: HF_TOKEN
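Note that `api_key_env` names an environment variable, never a raw token. A minimal sketch of how a runner might resolve this block, assuming PyYAML; the loader below is illustrative, not the repo's actual config handling:

```python
import os

import yaml  # PyYAML

with open("configs/extras/cached_eval.yaml") as f:
    cfg = yaml.safe_load(f)

captain = cfg["captain"]
api_key = os.environ.get(captain["api_key_env"])  # e.g. HF_TOKEN
if api_key is None:
    raise RuntimeError(f"Set {captain['api_key_env']} before running a live-captain eval.")

print(captain["model"], captain["api_base"])
```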
|
| 18 |
+
|
configs/extras/cricket_train.yaml
ADDED
|
@@ -0,0 +1,125 @@
| 1 |
+
# =============================================================================
|
| 2 |
+
# CricketCaptain-LLM — dedicated training config
|
| 3 |
+
# Stack: TRL 1.2.0 GRPO + Transformers 5.6.2 + PEFT (LoRA) + bf16 base
|
| 4 |
+
# NOT using vLLM (incompatible with our env), NOT using Unsloth (incompatible
|
| 5 |
+
# with TRL's multi-turn environment_factory).
|
| 6 |
+
# Target run: 10-hour budget on a single H200 (144 GB VRAM), Qwen3.5-4B base.
|
| 7 |
+
# =============================================================================
|
| 8 |
+
|
| 9 |
+
env:
|
| 10 |
+
# eval_pack_id selects scenario distribution + opponent rosters
|
| 11 |
+
eval_pack_id: adaptive_t20_v1
|
| 12 |
+
agent_team: india
|
| 13 |
+
# Match length per episode. 20 = full T20.
|
| 14 |
+
# 5-over (~180 tool calls) → ~5 min/step
|
| 15 |
+
# 20-over (~720 tool calls) → ~15-18 min/step ← current choice
|
| 16 |
+
max_overs: 20
|
| 17 |
+
env_url: ws://localhost:8000
|
| 18 |
+
|
| 19 |
+
opponent:
|
| 20 |
+
# heuristic | llm_live | llm_cached | cricsheet
|
| 21 |
+
# heuristic: rule-based, fast, deterministic-ish — best for fast iteration.
|
| 22 |
+
# llm_live: adversarial Gemma via HF router — realistic but slow + costs API.
|
| 23 |
+
mode: heuristic
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
train:
|
| 27 |
+
# ---- Model & adapter ----
|
| 28 |
+
model: Qwen/Qwen3.5-4B
|
| 29 |
+
# Resume LoRA from the warmup checkpoint. When this run starts, base model
|
| 30 |
+
# loads from Qwen/Qwen3.5-4B above + LoRA adapter loads from this path.
|
| 31 |
+
# Comment out (with #) to start with a fresh adapter instead.
|
| 32 |
+
resume_from: ./checkpoints/stage2_final
|
| 33 |
+
# Single-stage training. Code uses curriculum_stage=2 internally to mean
|
| 34 |
+
# "full composite reward". Stage 1 (validity-only warm-up) was dropped because
|
| 35 |
+
# Qwen3.5-4B already knows tool calling natively (XML+JSON both accepted).
|
| 36 |
+
stage: 2
|
| 37 |
+
|
| 38 |
+
# ---- Dataset ----
|
| 39 |
+
prompts: 256 # number of unique scenarios; trainer cycles through them
|
| 40 |
+
|
| 41 |
+
# ---- Schedule ----
|
| 42 |
+
# 30 steps × ~15 min/step = ~7-8 hrs for the main run. Total chain (warmup +
|
| 43 |
+
# main) fits in ~10-hr budget. Bump to 100+ for a longer training run if
|
| 44 |
+
# compute is unconstrained — the resume_from path lets you continue cleanly.
|
| 45 |
+
steps: 30
|
| 46 |
+
logging_steps: 1 # log every step (set 10 for less wandb noise)
|
| 47 |
+
|
| 48 |
+
# ---- Throughput knobs (H200, 144 GB VRAM) ----
|
| 49 |
+
# batch_size × grad_accum = effective rollout per gradient update.
|
| 50 |
+
# batch_size × num_generations = simultaneous in-flight episodes (memory driver).
|
| 51 |
+
# 4 × 4 = 16 sim ≈ ~80 GB peak at 4096 max_completion — needed since 32 sim @ 4096
|
| 52 |
+
# OOMs at ~110 GB. grad_accum=4 keeps effective batch healthy.
|
| 53 |
+
batch_size: 4
|
| 54 |
+
grad_accum: 4
|
| 55 |
+
num_generations: 4 # GRPO requires ≥2 for group advantage; 4 is the sweet spot
|
| 56 |
+
|
| 57 |
+
# ---- Length budget ----
|
| 58 |
+
# 4096 = ~16 turns per rollout at ~250 tok/turn. Larger budget than warmup
|
| 59 |
+
# because the main run keeps only 16 sim episodes in flight (batch=4 × num_gen=4) so
|
| 60 |
+
# KV cache fits comfortably.
|
| 61 |
+
max_completion_length: 4096
|
| 62 |
+
# Hard cap on tool calls per episode. 5-over needs ~180; 20-over needs ~720.
|
| 63 |
+
# 800 leaves slack for extra-balls (no-balls/wides) without truncating matches.
|
| 64 |
+
max_tool_calling_iterations: 800
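The ~180 and ~720 figures above follow from two innings of six legal balls per over with roughly three tool calls per ball; that per-ball rate is the implied assumption, not a documented constant. A quick sketch of the arithmetic:

```python
def estimated_tool_calls(max_overs: int, calls_per_ball: int = 3) -> int:
    """Two innings x six legal balls per over x ~3 tool calls per ball (assumed rate)."""
    return 2 * max_overs * 6 * calls_per_ball


print(estimated_tool_calls(5))   # ~180, 5-over format
print(estimated_tool_calls(20))  # ~720, full T20; 800 leaves slack for extras
```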
|
| 65 |
+
|
| 66 |
+
# ---- Optimizer ----
|
| 67 |
+
# Bumped 5e-6 → 1e-5 for the main run. Warmup at 5e-6 showed flat reward and
|
| 68 |
+
# tiny loss magnitudes (~0.02) with grad_norm 0.35 — gradients flowing but
|
| 69 |
+
# weights barely moving. 2× LR doubles step size without re-entering instability
|
| 70 |
+
# territory (still well below the 5e-5 commonly used for r=64 LoRA SFT).
|
| 71 |
+
learning_rate: 1.0e-5
|
| 72 |
+
|
| 73 |
+
# ---- GRPO knobs ----
|
| 74 |
+
# KL coefficient. Default in TRL 1.2 is 0.0 — and 0.0 specifically means the
|
| 75 |
+
# reference model is NOT loaded (saves ~8 GB weights + ref-forward-pass mem).
|
| 76 |
+
# Any beta > 0 loads a frozen copy of the base model for KL anchoring.
|
| 77 |
+
# We use 0.0 because: (a) memory budget is tight at batch=16 + 20-over,
|
| 78 |
+
# (b) reward is well-shaped (composite with format penalty), so format collapse
|
| 79 |
+
# is unlikely, (c) cricket strategy isn't in the base distribution — we want drift.
|
| 80 |
+
beta: 0.0
|
| 81 |
+
# Sampling — slight bump from 0.8 for GRPO group diversity. top_p=0.95 trims
|
| 82 |
+
# the long tail (rare bad tokens during 720-turn rollouts).
|
| 83 |
+
temperature: 0.9
|
| 84 |
+
top_p: 0.95
|
| 85 |
+
|
| 86 |
+
# ---- Memory ----
|
| 87 |
+
# Trade ~30% extra backward-pass compute for big activation memory savings
|
| 88 |
+
# (>20 GB). Required to fit 16 in-flight episodes (batch_size=4 × num_gen=4) at 4096 completion in 144 GB.
|
| 89 |
+
# use_reentrant=False is the modern path, more stable with LoRA than the legacy True.
|
| 90 |
+
gradient_checkpointing: true
|
| 91 |
+
gradient_checkpointing_use_reentrant: false
|
| 92 |
+
|
| 93 |
+
# ---- Dataloader ----
|
| 94 |
+
# Cheap micro-opts — pin host memory + a few worker threads so CPU isn't a
|
| 95 |
+
# bottleneck feeding prompts. Tiny win since our dataset is only 256 prompts.
|
| 96 |
+
dataloader_pin_memory: true
|
| 97 |
+
dataloader_num_workers: 4
|
| 98 |
+
|
| 99 |
+
# ---- Checkpointing ----
|
| 100 |
+
# Save every 10 steps; keep only the 5 most recent on disk to cap usage.
|
| 101 |
+
save_steps: 10
|
| 102 |
+
save_total_limit: 5
|
| 103 |
+
|
| 104 |
+
# ---- LoRA (currently hardcoded in train.py load_model()) ----
|
| 105 |
+
# r=64, alpha=128, dropout=0.05, targets q/k/v/o/gate/up/down
|
| 106 |
+
# → 85M trainable params (1.98% of 4.2B base)
|
| 107 |
+
|
| 108 |
+
# ---- Precision ----
|
| 109 |
+
# bf16 base + bf16 LoRA adapter. NO 4-bit quant — H200 has the VRAM and bf16
|
| 110 |
+
# is 15-20% faster than 4-bit dequant on every forward pass.
|
| 111 |
+
bf16_base: true
|
| 112 |
+
|
| 113 |
+
# ---- Logging ----
|
| 114 |
+
report_to: wandb # or "tensorboard" for local-only
|
| 115 |
+
run_name: cricket_captain_v10
|
| 116 |
+
|
| 117 |
+
# =============================================================================
|
| 118 |
+
# Reward composition (weights live in configs/game_knowledge.yaml `reward:` block)
|
| 119 |
+
# composite = 0.35·r_result + 0.30·r_cricket + 0.25·r_behavior + 0.10·r_validity
|
| 120 |
+
#
|
| 121 |
+
# r_result — match outcome (chase margin, defense margin, win/tie bonus)
|
| 122 |
+
# r_cricket — Dream11 fantasy proxy as dense per-ball signal
|
| 123 |
+
# r_behavior — coherence (50%) + adaptation (20%) + opponent_awareness (20%) + regret (10%)
|
| 124 |
+
# r_validity — fraction of legal tool calls
|
| 125 |
+
# =============================================================================
|
configs/extras/cricket_train_warmup.yaml
ADDED
|
@@ -0,0 +1,95 @@
| 1 |
+
# =============================================================================
|
| 2 |
+
# CricketCaptain-LLM — WARMUP training config (5-over format, fast iterations)
|
| 3 |
+
# Stack: TRL 1.2.0 GRPO + Transformers 5.6.2 + PEFT (LoRA) + bf16 base
|
| 4 |
+
# Purpose: short matches + bigger GRPO group = fast format mastery before
|
| 5 |
+
# moving to the full 20-over training run (configs/cricket_train.yaml).
|
| 6 |
+
#
|
| 7 |
+
# Why bigger num_generations here?
|
| 8 |
+
# 5-over rollouts are ~4× faster per step than 20-over, so we have memory +
|
| 9 |
+
# wall-clock budget for G=8. Bigger group → more stable GRPO advantage signal,
|
| 10 |
+
# especially useful when the model is just learning format/tactics.
|
| 11 |
+
# =============================================================================
|
| 12 |
+
|
| 13 |
+
env:
|
| 14 |
+
eval_pack_id: adaptive_t20_v1
|
| 15 |
+
agent_team: india
|
| 16 |
+
# max_overs: 0 unlocks the per-scenario curriculum distribution (see train.overs_distribution).
|
| 17 |
+
# Set to a positive integer to lock all scenarios to that single format.
|
| 18 |
+
max_overs: 0
|
| 19 |
+
env_url: ws://localhost:8000
|
| 20 |
+
|
| 21 |
+
opponent:
|
| 22 |
+
# Heuristic for fast iteration. Switch to llm_live for final eval, not warmup.
|
| 23 |
+
mode: heuristic
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
train:
|
| 27 |
+
# ---- Model & adapter ----
|
| 28 |
+
model: Qwen/Qwen3.5-4B
|
| 29 |
+
# Single-stage. curriculum_stage=2 internally = full composite reward.
|
| 30 |
+
stage: 2
|
| 31 |
+
|
| 32 |
+
# ---- Dataset ----
|
| 33 |
+
prompts: 64 # smaller dataset since we're only doing 25 steps
|
| 34 |
+
|
| 35 |
+
# ---- Schedule ----
|
| 36 |
+
# 5-over ≈ ~180 tool calls / episode → ~4-5 min / step at 32 sim episodes.
|
| 37 |
+
# 25 steps ≈ ~2 hrs total. Quick warmup before kicking off the full
|
| 38 |
+
# 20-over run with cricket_train.yaml.
|
| 39 |
+
steps: 25
|
| 40 |
+
logging_steps: 1
|
| 41 |
+
# Save every 5 steps → 5 checkpoints across 25 steps. Keep only 5 on disk.
|
| 42 |
+
save_steps: 5
|
| 43 |
+
save_total_limit: 5
|
| 44 |
+
|
| 45 |
+
# ---- Curriculum (per-scenario max_overs) ----
|
| 46 |
+
# Heavy on T2 (the only format that actually completes inside our token budget),
|
| 47 |
+
# tail to T5 (target eval distribution). Activated when env.max_overs is 0 or unset.
|
| 48 |
+
# Frequencies: ~45% T2, ~27% T3, ~18% T4, ~9% T5.
|
| 49 |
+
overs_distribution: [2, 2, 2, 2, 2, 3, 3, 3, 4, 4, 5]
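The percentages quoted above are just the frequency of each entry in the list; a one-line check:

```python
from collections import Counter

dist = [2, 2, 2, 2, 2, 3, 3, 3, 4, 4, 5]
print({k: round(v / len(dist), 2) for k, v in sorted(Counter(dist).items())})
# {2: 0.45, 3: 0.27, 4: 0.18, 5: 0.09}
```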
|
| 50 |
+
|
| 51 |
+
# ---- Throughput knobs ----
|
| 52 |
+
# 4 × 4 = 16 sim episodes — needed because 4096 max_completion at 32 sim OOMs
|
| 53 |
+
# (~110 GB observed). At 16 sim: ~80 GB, fits with margin. grad_accum=4 keeps
|
| 54 |
+
# the effective gradient batch size where G=4 still gives stable advantages.
|
| 55 |
+
batch_size: 4
|
| 56 |
+
grad_accum: 4
|
| 57 |
+
num_generations: 4
|
| 58 |
+
|
| 59 |
+
# ---- Length budget ----
|
| 60 |
+
# 4096 = ~16 turns per rollout at ~250 tok/turn. Lets T2 episodes get further
|
| 61 |
+
# toward completion (~5-8 balls per innings vs ~3-5 at 3072).
|
| 62 |
+
max_completion_length: 4096
|
| 63 |
+
# Cap above what 5-over needs (~180) so no truncation from extras.
|
| 64 |
+
max_tool_calling_iterations: 240
|
| 65 |
+
|
| 66 |
+
# ---- Optimizer ----
|
| 67 |
+
learning_rate: 5.0e-6
|
| 68 |
+
|
| 69 |
+
# ---- GRPO knobs ----
|
| 70 |
+
# beta=0.0 → no reference model loaded → saves ~8 GB VRAM and lets the GRPO group fit.
|
| 71 |
+
beta: 0.0
|
| 72 |
+
temperature: 0.9
|
| 73 |
+
top_p: 0.95
|
| 74 |
+
|
| 75 |
+
# ---- Memory ----
|
| 76 |
+
gradient_checkpointing: true
|
| 77 |
+
gradient_checkpointing_use_reentrant: false
|
| 78 |
+
|
| 79 |
+
# ---- Dataloader ----
|
| 80 |
+
dataloader_pin_memory: true
|
| 81 |
+
dataloader_num_workers: 4
|
| 82 |
+
|
| 83 |
+
# ---- Precision ----
|
| 84 |
+
bf16_base: true
|
| 85 |
+
|
| 86 |
+
# ---- Logging ----
|
| 87 |
+
report_to: wandb
|
| 88 |
+
run_name: cricket_captain_warmup_5over
|
| 89 |
+
|
| 90 |
+
# =============================================================================
|
| 91 |
+
# Workflow:
|
| 92 |
+
# 1. Train warmup: python train.py train --config configs/cricket_train_warmup.yaml
|
| 93 |
+
# 2. Train main run: python train.py train --config configs/cricket_train.yaml \
|
| 94 |
+
# --model ./checkpoints/stage2_final (resume from warmup)
|
| 95 |
+
# =============================================================================
|
configs/extras/default.yaml
ADDED
|
@@ -0,0 +1,59 @@
| 1 |
+
env:
|
| 2 |
+
# Used by server + runners
|
| 3 |
+
eval_pack_id: adaptive_t20_v1
|
| 4 |
+
max_overs: 20
|
| 5 |
+
agent_team: india
|
| 6 |
+
env_url: ws://localhost:8000
|
| 7 |
+
|
| 8 |
+
opponent:
|
| 9 |
+
# heuristic | llm_live | llm_cached
|
| 10 |
+
# llm_live is the adversarial opponent. The captain being trained still runs
|
| 11 |
+
# locally through TRL/Transformers, not through the HF inference endpoint.
|
| 12 |
+
# For reproducible cached evaluation, use configs/cached_eval.yaml instead.
|
| 13 |
+
# H200 default: use the live LLM opponent for realism.
|
| 14 |
+
mode: llm_live
|
| 15 |
+
model: google/gemma-4-26B-A4B-it
|
| 16 |
+
api_base: https://router.huggingface.co/v1
|
| 17 |
+
api_key_env: HF_TOKEN
|
| 18 |
+
|
| 19 |
+
captain:
|
| 20 |
+
# For inference/eval runner when using an API model (OpenAI-compatible).
|
| 21 |
+
# You can still pass --model random for baseline runs.
|
| 22 |
+
model: google/gemma-4-26B-A4B-it
|
| 23 |
+
api_base: https://router.huggingface.co/v1
|
| 24 |
+
api_key_env: HF_TOKEN
|
| 25 |
+
|
| 26 |
+
train:
|
| 27 |
+
model: Qwen/Qwen3.5-4B
|
| 28 |
+
# SINGLE-STAGE training. The two-stage curriculum (Stage 1 = format-only)
|
| 29 |
+
# was dropped — Qwen3.5-4B already knows tool calling natively (XML+JSON),
|
| 30 |
+
# so we run the full composite reward (r_result + r_cricket + r_behavior + r_validity)
|
| 31 |
+
# from step 0. Code still uses curriculum_stage=2 internally to mean "full reward".
|
| 32 |
+
stage: 2
|
| 33 |
+
prompts: 256
|
| 34 |
+
# 10-hr H200 budget. ~150 steps × ~4-5 min each = ~10 hrs.
|
| 35 |
+
steps: 150
|
| 36 |
+
# H200 (144GB) v8 config: bigger micro-batch, fewer grad accum steps, more updates.
|
| 37 |
+
# Effective rollout per gradient update = batch_size * grad_accum = 32 prompts
|
| 38 |
+
# → with num_generations=4 that's 128 episodes per gradient update.
|
| 39 |
+
batch_size: 16
|
| 40 |
+
grad_accum: 2
|
| 41 |
+
num_generations: 4 # GRPO requires >= 2 for group advantage; 4 is the standard sweet spot
|
| 42 |
+
# 3072 = enough headroom for thinking (v5 mean was 1115 tokens) without the
|
| 43 |
+
# KV-cache bloat of 4096. Cuts memory for the rollout buffer ~25%.
|
| 44 |
+
max_completion_length: 3072
|
| 45 |
+
# Episodes terminate either when match finishes or this cap hits.
|
| 46 |
+
# 5-over match needs ~180 tool calls, 20-over T20 needs ~720. Don't truncate.
|
| 47 |
+
max_tool_calling_iterations: 800
|
| 48 |
+
# Learning rate — was hardcoded to 1e-5; lowered to 5e-6 because of LoRA r=64
|
| 49 |
+
# (4× more trainable params than r=16) and partially-sparse outcome reward.
|
| 50 |
+
# Smaller LR = more stable updates when r_result is binary at innings end.
|
| 51 |
+
learning_rate: 5e-6
|
| 52 |
+
# Trade ~30% extra compute per backward pass for big activation memory savings
|
| 53 |
+
# (>20 GB freed). Lets us push micro-batch / num_generations with safety margin.
|
| 54 |
+
gradient_checkpointing: true
|
| 55 |
+
logging_steps: 1 # log loss/reward every step (set 10 for less noise)
|
| 56 |
+
# Switch to `tensorboard` if you prefer local-only logging.
|
| 57 |
+
report_to: wandb
|
| 58 |
+
run_name: cricket_captain_v10
|
| 59 |
+
|
configs/game_knowledge.yaml
CHANGED
|
@@ -29,11 +29,22 @@ transition_overs: [6, 16]
|
|
| 29 |
# Reward weights (must sum to 1.0)
|
| 30 |
# ---------------------------------------------------------------------------
|
| 31 |
reward:
|
| 32 |
-
# Episode-level composite
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
|
| 38 |
# Within r_behavior
|
| 39 |
behavior:
|
|
|
|
| 29 |
# Reward weights (must sum to 1.0)
|
| 30 |
# ---------------------------------------------------------------------------
|
| 31 |
reward:
|
| 32 |
+
# Episode-level composite — REBALANCED for partial-trajectory training.
|
| 33 |
+
# Original 55/25/15/5 was textbook "outcome-dominated" but assumed matches
|
| 34 |
+
# would actually complete in the token budget. Reality: episodes truncate at
|
| 35 |
+
# ~25% of T2 and r_result almost never fires. Putting 55% weight on a signal
|
| 36 |
+
# that fires <5% of the time washes out the gradient.
|
| 37 |
+
#
|
| 38 |
+
# New weights match the SWE-RL / coding-agent-RL recipe. Re-rebalanced for
|
| 39 |
+
# the main 20-over run: at full T20 length r_result actually fires (matches
|
| 40 |
+
# complete), so we shift 0.15 weight back from r_cricket → r_result to give
|
| 41 |
+
# the outcome signal real gradient pull. r_cricket alone was producing flat,
|
| 42 |
+
# bouncy reward curves in warmup — per-ball Dream11 is too noisy as 45% of
|
| 43 |
+
# the gradient when group size is only 4.
|
| 44 |
+
r_result: 0.35 # match outcome — fires reliably at 20 overs, biggest signal
|
| 45 |
+
r_cricket: 0.30 # dense per-ball Dream11 (analog: partial test-pass)
|
| 46 |
+
r_behavior: 0.25 # per-turn coherence/adaptation (analog: lint/quality)
|
| 47 |
+
r_validity: 0.10 # tool format (analog: compile success)
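For reference, the weighted sum these four numbers feed is easy to reproduce; a minimal sketch (the real computation lives in server/reward_calculator.py, and the helper name here is illustrative):

```python
WEIGHTS = {"r_result": 0.35, "r_cricket": 0.30, "r_behavior": 0.25, "r_validity": 0.10}
assert abs(sum(WEIGHTS.values()) - 1.0) < 1e-9  # "must sum to 1.0"


def composite(components: dict[str, float]) -> float:
    return sum(w * components.get(name, 0.0) for name, w in WEIGHTS.items())


# A truncated episode where the outcome signal never fired:
print(composite({"r_result": 0.0, "r_cricket": 0.4, "r_behavior": 0.6, "r_validity": 1.0}))  # 0.37
```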
|
| 48 |
|
| 49 |
# Within r_behavior
|
| 50 |
behavior:
|
docs/benchmark_explainer.md
CHANGED
|
@@ -30,14 +30,13 @@ The original motivation came from strategic coherence: LLMs often say one thing
|
|
| 30 |
|
| 31 |
## 2. Fit With OpenEnv Competition Themes
|
| 32 |
|
| 33 |
-
The environment aligns with multiple OpenEnv hackathon themes.
|
| 34 |
-
|
| 35 |
### Multi-Agent Interactions
|
| 36 |
|
| 37 |
The submitted captain agent plays against an opponent policy. The opponent can be:
|
| 38 |
|
| 39 |
-
- `heuristic`: fast
|
| 40 |
-
- `
|
|
|
|
| 41 |
- `llm_cached`: replayed opponent decisions for reproducible evaluation.
|
| 42 |
|
| 43 |
This tests whether the agent can reason about another actor's incentives, field settings, and likely plans.
|
|
@@ -48,29 +47,11 @@ A full match has many decisions across innings, phases, wickets, and pressure st
|
|
| 48 |
|
| 49 |
### World Modeling
|
| 50 |
|
| 51 |
-
The agent observes a partially summarized cricket world:
|
| 52 |
-
|
| 53 |
-
- score,
|
| 54 |
-
- over/ball,
|
| 55 |
-
- wickets,
|
| 56 |
-
- target,
|
| 57 |
-
- phase,
|
| 58 |
-
- field,
|
| 59 |
-
- batter profile,
|
| 60 |
-
- bowler profile,
|
| 61 |
-
- opponent plan,
|
| 62 |
-
- previous outcome.
|
| 63 |
-
|
| 64 |
-
It must maintain an internal model of what is happening and update that model after every ball.
|
| 65 |
|
| 66 |
### Self-Improvement
|
| 67 |
|
| 68 |
-
The same environment can support
|
| 69 |
-
|
| 70 |
-
- heuristic curriculum training,
|
| 71 |
-
- cached-opponent official evaluation,
|
| 72 |
-
- live LLM opponent self-play,
|
| 73 |
-
- future agent-vs-agent training.
|
| 74 |
|
| 75 |
## 3. Environment Flow
|
| 76 |
|
|
@@ -86,8 +67,6 @@ Within each batting or bowling phase, the tactical loop is:
|
|
| 86 |
PRE_OVER -> PRE_BALL -> BALL_RESOLUTION -> POST_BALL -> next decision
|
| 87 |
```
|
| 88 |
|
| 89 |
-
The captain can use tools at different points in that loop.
|
| 90 |
-
|
| 91 |
### Toss
|
| 92 |
|
| 93 |
```json
|
|
@@ -97,40 +76,21 @@ The captain can use tools at different points in that loop.
|
|
| 97 |
### Batting Tools
|
| 98 |
|
| 99 |
```json
|
| 100 |
-
{"tool": "select_batter", "arguments": {"name": "
|
| 101 |
-
```
|
| 102 |
-
|
| 103 |
-
```json
|
| 104 |
{"tool": "set_strategy", "arguments": {"phase_intent": "consolidate", "aggression": 0.35, "rationale": "Rotate strike against spin and keep wickets in hand."}}
|
| 105 |
-
```
|
| 106 |
-
|
| 107 |
-
```json
|
| 108 |
{"tool": "plan_shot", "arguments": {"shot_intent": "single", "target_area": "midwicket", "risk": "low", "trajectory": "ground", "rationale": "Field is spread, so take the easy single."}}
|
| 109 |
-
```
|
| 110 |
-
|
| 111 |
-
```json
|
| 112 |
{"tool": "play_delivery", "arguments": {"shot_intent": "single", "explanation": "Work into the gap."}}
|
| 113 |
```
|
| 114 |
|
| 115 |
-
|
| 116 |
|
| 117 |
-
|
| 118 |
-
{"tool": "choose_bowler", "arguments": {"name": "Death Specialist", "bowler_type": "pace", "style": "yorker", "rationale": "Attack the stumps at the death."}}
|
| 119 |
-
```
|
| 120 |
|
| 121 |
```json
|
|
|
|
| 122 |
{"tool": "set_bowling_strategy", "arguments": {"bowler_type": "pace", "line": "stumps", "length": "full", "delivery_type": "yorker", "rationale": "Limit swing room."}}
|
| 123 |
-
```
|
| 124 |
-
|
| 125 |
-
```json
|
| 126 |
{"tool": "set_field_setting", "arguments": {"setting": "Defensive"}}
|
| 127 |
-
```
|
| 128 |
-
|
| 129 |
-
```json
|
| 130 |
{"tool": "plan_delivery", "arguments": {"bowler_type": "pace", "line": "stumps", "length": "full", "delivery_type": "yorker", "rationale": "Protect boundaries and force a low-percentage shot."}}
|
| 131 |
-
```
|
| 132 |
-
|
| 133 |
-
```json
|
| 134 |
{"tool": "bowl_delivery", "arguments": {}}
|
| 135 |
```
|
| 136 |
|
|
@@ -146,309 +106,197 @@ The captain can use tools at different points in that loop.
|
|
| 146 |
{"tool": "analyze_situation", "arguments": {"query_type": "match_situation"}}
|
| 147 |
```
|
| 148 |
|
| 149 |
-
## 4.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
|
| 151 |
-
The
|
|
|
|
|
|
|
|
|
|
|
|
|
| 152 |
|
| 153 |
```text
|
| 154 |
LLM Agent / Evaluator
|
| 155 |
|
|
| 156 |
-
| WebSocket
|
| 157 |
v
|
| 158 |
-
|
| 159 |
|
|
| 160 |
v
|
| 161 |
-
CricketEnvironment
|
| 162 |
|
|
| 163 |
-
+--> MarkovCricketEngine
|
| 164 |
-
+-->
|
| 165 |
-
+-->
|
| 166 |
-
+-->
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
- `server/app.py`: creates the OpenEnv server using `create_app(...)`.
|
| 172 |
-
- `server/cricket_environment.py`: implements `reset`, `step`, and `state`.
|
| 173 |
-
- `models.py`: defines `CricketAction`, `CricketObservation`, and `CricketState`.
|
| 174 |
-
- `client.py`: defines the WebSocket client `CricketCaptainEnv`.
|
| 175 |
-
- `inference.py`: runs random or OpenAI-compatible agents against the server.
|
| 176 |
-
- `eval.py`: runs evaluation episodes and saves plots/raw logs.
|
| 177 |
-
|
| 178 |
-
The standard OpenEnv lifecycle is:
|
| 179 |
-
|
| 180 |
-
```python
|
| 181 |
-
obs = env.reset(...)
|
| 182 |
-
obs = env.step(CricketAction(tool="...", arguments={...}))
|
| 183 |
-
state = env.state
|
| 184 |
-
```
|
| 185 |
-
|
| 186 |
-
This matters for competition compliance because clients do not need to import server internals. They interact through the OpenEnv API.
|
| 187 |
-
|
| 188 |
-
## 5. What The Observation Contains
|
| 189 |
-
|
| 190 |
-
Each step returns a `CricketObservation` with fields like:
|
| 191 |
-
|
| 192 |
-
- `game_state`: toss / batting / bowling / finished.
|
| 193 |
-
- `strategic_phase`: pre_over / pre_ball / ball_resolution / post_ball.
|
| 194 |
-
- `game_context`: score, wickets, over, ball, target, phase.
|
| 195 |
-
- `declared_strategy`: current batting strategy.
|
| 196 |
-
- `bowling_strategy`: current bowling plan.
|
| 197 |
-
- `field_setting`: Aggressive / Balanced / Defensive.
|
| 198 |
-
- `current_batter`: batter profile.
|
| 199 |
-
- `current_bowler`: bowler profile.
|
| 200 |
-
- `opponent_plan`: last visible opponent policy decision.
|
| 201 |
-
- `last_outcome`: previous ball outcome plus tactical metadata such as event type, shot zone, delivery features, field pressure, and fielder effect.
|
| 202 |
-
- `available_tools`: legal tools for current state.
|
| 203 |
-
- `prompt_text`: rendered prompt for the LLM.
|
| 204 |
-
|
| 205 |
-
The LLM sees enough information to reason tactically, but not the entire simulator internals.
|
| 206 |
-
|
| 207 |
-
## 6. Opponent Policies
|
| 208 |
-
|
| 209 |
-
Opponent behavior lives in `server/opponent_policy.py`.
|
| 210 |
-
|
| 211 |
-
There are three modes:
|
| 212 |
-
|
| 213 |
-
### `heuristic`
|
| 214 |
-
|
| 215 |
-
Fast local policy. Useful for:
|
| 216 |
-
|
| 217 |
-
- tests,
|
| 218 |
-
- development,
|
| 219 |
-
- cheap training rollouts,
|
| 220 |
-
- baseline comparison.
|
| 221 |
-
|
| 222 |
-
### `llm_live`
|
| 223 |
-
|
| 224 |
-
Calls an OpenAI-compatible LLM with a fixed prompt. Useful for:
|
| 225 |
-
|
| 226 |
-
- demos,
|
| 227 |
-
- realistic opponent behavior,
|
| 228 |
-
- self-play-style experiments.
|
| 229 |
-
|
| 230 |
-
The current default live opponent/captain model is:
|
| 231 |
-
|
| 232 |
-
```text
|
| 233 |
-
google/gemma-4-26B-A4B-it via https://router.huggingface.co/v1
|
| 234 |
```
|
| 235 |
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
### `llm_cached`
|
| 239 |
-
|
| 240 |
-
Reads pre-recorded opponent decisions from JSONL. Useful for:
|
| 241 |
|
| 242 |
-
|
| 243 |
-
-
|
| 244 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 245 |
|
| 246 |
-
|
| 247 |
|
| 248 |
-
|
| 249 |
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
It supports:
|
| 257 |
-
|
| 258 |
-
1. Synthetic transition probabilities from `data/transition_probs.json`.
|
| 259 |
-
2. Cricsheet-derived transition tables from `data/processed/cricket_transitions_v1.pkl`.
|
| 260 |
-
|
| 261 |
-
The upgraded ball resolver uses both sides' plans:
|
| 262 |
-
|
| 263 |
-
```text
|
| 264 |
-
outcome ~ P(
|
| 265 |
-
shot_plan,
|
| 266 |
-
delivery_plan,
|
| 267 |
-
batter_profile,
|
| 268 |
-
bowler_profile,
|
| 269 |
-
field_setting,
|
| 270 |
-
phase,
|
| 271 |
-
score,
|
| 272 |
-
wickets,
|
| 273 |
-
target_pressure
|
| 274 |
-
)
|
| 275 |
-
```
|
| 276 |
|
| 277 |
-
The
|
|
|
|
|
|
|
|
|
|
|
|
|
| 278 |
|
| 279 |
-
|
| 280 |
-
- anchor benefits low-risk rotation,
|
| 281 |
-
- yorker/death specialist suppresses boundaries,
|
| 282 |
-
- shot target zones (`cover`, `point`, `midwicket`, `long_on`, etc.) are matched against delivery line/length/variation,
|
| 283 |
-
- field presets expand `Aggressive`, `Balanced`, and `Defensive` into named fielder zones,
|
| 284 |
-
- boundary riders can cut off fours/sixes and inner ring fielders can save singles,
|
| 285 |
-
- close catchers/slips/gully can convert edges into wickets,
|
| 286 |
-
- wides/no-balls, drops, misfields, overthrows, run-outs, bowled/LBW routes, and caught-in-zone events add bounded stochastic noise,
|
| 287 |
-
- high chase pressure makes defensive batting less useful.
|
| 288 |
|
| 289 |
-
|
| 290 |
|
| 291 |
-
|
| 292 |
|
| 293 |
-
|
| 294 |
|
| 295 |
-
|
| 296 |
|
| 297 |
-
|
| 298 |
|
| 299 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 300 |
|
| 301 |
-
|
| 302 |
-
- score vs DLS/par,
|
| 303 |
-
- chase success,
|
| 304 |
-
- defense success,
|
| 305 |
-
- wickets preserved or taken.
|
| 306 |
|
| 307 |
-
|
| 308 |
|
| 309 |
-
|
| 310 |
|
| 311 |
-
|
| 312 |
|
| 313 |
-
-
|
| 314 |
-
- bowling wickets, dots, economy,
|
| 315 |
-
- milestone and dismissal bonuses.
|
| 316 |
|
| 317 |
-
|
| 318 |
|
| 319 |
-
|
|
|
|
|
|
|
| 320 |
|
| 321 |
-
|
| 322 |
|
| 323 |
-
|
| 324 |
|
| 325 |
-
-
|
| 326 |
-
- Declared aggression `0.30` plus `six` is less coherent.
|
| 327 |
|
| 328 |
-
###
|
| 329 |
|
| 330 |
-
|
| 331 |
|
| 332 |
-
|
| 333 |
-
- target pressure,
|
| 334 |
-
- wickets down,
|
| 335 |
-
- previous reflection,
|
| 336 |
-
- opponent behavior.
|
| 337 |
|
| 338 |
-
|
| 339 |
|
| 340 |
-
|
|
|
|
|
|
|
| 341 |
|
| 342 |
-
|
| 343 |
-
-
|
| 344 |
-
-
|
| 345 |
-
-
|
| 346 |
-
-
|
|
|
|
| 347 |
|
| 348 |
-
##
|
| 349 |
|
| 350 |
-
|
| 351 |
|
| 352 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 353 |
|
| 354 |
-
|
| 355 |
|
| 356 |
-
|
| 357 |
|
| 358 |
-
|
| 359 |
|
| 360 |
-
|
| 361 |
|
| 362 |
-
```text
|
| 363 |
-
25% result quality / match outcome
|
| 364 |
-
10% Dream11 dense cricket proxy
|
| 365 |
-
30% strategy bundle
|
| 366 |
-
20% tool efficiency
|
| 367 |
-
15% format validity
|
| 368 |
```
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
The data pipeline is designed to keep training and environment behavior aligned.
|
| 375 |
-
|
| 376 |
-
### Step 1: Curate Ball Outcomes
|
| 377 |
-
|
| 378 |
-
Script:
|
| 379 |
-
|
| 380 |
-
```bash
|
| 381 |
-
python scripts/curate_transitions.py --format t20
|
| 382 |
```
|
| 383 |
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
- `data/processed/ball_outcomes_t20_v1.pkl`
|
| 387 |
-
- `data/processed/cricket_transitions_v1.pkl`
|
| 388 |
|
| 389 |
-
The
|
|
|
|
|
|
|
| 390 |
|
| 391 |
-
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
- batter name,
|
| 396 |
-
- bowler name,
|
| 397 |
-
- bowler type,
|
| 398 |
-
- dismissal type,
|
| 399 |
-
- phase,
|
| 400 |
-
- score before ball,
|
| 401 |
-
- wickets before ball,
|
| 402 |
-
- runs and wicket outcome.
|
| 403 |
|
| 404 |
-
##
|
| 405 |
|
| 406 |
-
|
| 407 |
|
| 408 |
```bash
|
| 409 |
-
python scripts/
|
|
|
|
| 410 |
```
|
| 411 |
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
- `data/processed/
|
| 415 |
-
|
| 416 |
-
Profiles include:
|
| 417 |
-
|
| 418 |
-
- batter style: anchor, balanced, hitter, finisher,
|
| 419 |
-
- bowler style: pace, spin, death specialist, economy, wicket-taker,
|
| 420 |
-
- phase strengths,
|
| 421 |
-
- economy,
|
| 422 |
-
- strike rate,
|
| 423 |
-
- dot rate,
|
| 424 |
-
- Dream11-style pressure proxy.
|
| 425 |
|
| 426 |
-
### Step
|
| 427 |
-
|
| 428 |
-
Script:
|
| 429 |
|
| 430 |
```bash
|
| 431 |
python scripts/build_eval_pack.py --eval-pack-id adaptive_t20_v1
|
| 432 |
```
|
| 433 |
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
- `data/eval_packs/adaptive_t20_v1.json`
|
| 437 |
-
|
| 438 |
-
The pack has:
|
| 439 |
-
|
| 440 |
-
- dev scenarios,
|
| 441 |
-
- official scenarios,
|
| 442 |
-
- chase states,
|
| 443 |
-
- defense states,
|
| 444 |
-
- collapse states,
|
| 445 |
-
- death-over states,
|
| 446 |
-
- matchup states,
|
| 447 |
-
- opponent config.
|
| 448 |
-
|
| 449 |
-
### Step 4: Generate Opponent Cache
|
| 450 |
-
|
| 451 |
-
Script:
|
| 452 |
|
| 453 |
```bash
|
| 454 |
python scripts/generate_opponent_cache.py \
|
|
@@ -458,199 +306,90 @@ python scripts/generate_opponent_cache.py \
|
|
| 458 |
--output data/opponent_cache/adaptive_t20_v1.jsonl
|
| 459 |
```
|
| 460 |
|
| 461 |
-
|
| 462 |
-
|
| 463 |
-
## 10. Training Pipeline
|
| 464 |
-
|
| 465 |
-
The training plan uses SFT only as a bootstrapping stage. The main optimization path remains GRPO.
|
| 466 |
-
|
| 467 |
-
### Stage 0: SFT Tool Warmup
|
| 468 |
|
| 469 |
-
|
| 470 |
|
| 471 |
```bash
|
| 472 |
-
|
|
|
|
|
|
|
| 473 |
```
|
| 474 |
|
| 475 |
-
|
| 476 |
-
|
| 477 |
-
- teach valid JSON,
|
| 478 |
-
- teach tool names,
|
| 479 |
-
- teach argument schemas,
|
| 480 |
-
- reduce parse errors before RL.
|
| 481 |
-
|
| 482 |
-
This does not replace RL. It makes RL less wasteful.
|
| 483 |
-
|
| 484 |
-
### Stage 1: GRPO Format / Tool Correctness
|
| 485 |
|
| 486 |
-
|
| 487 |
|
| 488 |
```bash
|
| 489 |
-
|
|
|
|
| 490 |
```
|
| 491 |
|
| 492 |
-
|
| 493 |
-
|
| 494 |
-
- train valid tool calls,
|
| 495 |
-
- reduce invalid JSON,
|
| 496 |
-
- stabilize action format.
|
| 497 |
-
|
| 498 |
-
### Stage 2: GRPO Strategic Behavior
|
| 499 |
-
|
| 500 |
-
Command:
|
| 501 |
|
| 502 |
```bash
|
| 503 |
-
|
|
|
|
| 504 |
```
|
| 505 |
|
| 506 |
-
|
| 507 |
-
|
| 508 |
-
- improve coherence,
|
| 509 |
-
- improve adaptation,
|
| 510 |
-
- improve opponent awareness,
|
| 511 |
-
- improve tool efficiency,
|
| 512 |
-
- improve match result quality.
|
| 513 |
-
|
| 514 |
-
### Evaluation
|
| 515 |
-
|
| 516 |
-
Command:
|
| 517 |
|
| 518 |
```bash
|
| 519 |
-
python
|
| 520 |
```
|
| 521 |
|
| 522 |
-
|
| 523 |
-
|
| 524 |
-
- random/untrained baseline,
|
| 525 |
-
- SFT-warmed model,
|
| 526 |
-
- GRPO-trained model.
|
| 527 |
-
|
| 528 |
-
For the competition, we should produce plots showing:
|
| 529 |
-
|
| 530 |
-
- reward over training,
|
| 531 |
-
- parse error rate,
|
| 532 |
-
- coherence,
|
| 533 |
-
- adaptation,
|
| 534 |
-
- opponent awareness,
|
| 535 |
-
- score/chase/defense metrics.
|
| 536 |
-
|
| 537 |
-
## 11. How This Complies With The Competition Instructions
|
| 538 |
-
|
| 539 |
-
The competition requires:
|
| 540 |
-
|
| 541 |
-
### Use OpenEnv
|
| 542 |
-
|
| 543 |
-
Implemented through:
|
| 544 |
-
|
| 545 |
-
- `server/app.py`
|
| 546 |
-
- `server/cricket_environment.py`
|
| 547 |
-
- `models.py`
|
| 548 |
-
- `client.py`
|
| 549 |
-
|
| 550 |
-
The environment follows `reset`, `step`, and `state`.
|
| 551 |
-
|
| 552 |
-
### Training Script With HF TRL / Unsloth
|
| 553 |
-
|
| 554 |
-
Implemented through:
|
| 555 |
-
|
| 556 |
-
- `train.py`
|
| 557 |
-
|
| 558 |
-
It uses Hugging Face TRL GRPO paths when training dependencies are installed.
|
| 559 |
-
|
| 560 |
-
### Hosted Environment
|
| 561 |
-
|
| 562 |
-
The repo has Hugging Face Spaces metadata in `README.md` and a Docker-based app path. The server binds to `0.0.0.0`, and remote clients should use `CRICKET_CAPTAIN_ENV_URL`.
|
| 563 |
-
|
| 564 |
-
### README With Problem / Environment / Results
|
| 565 |
-
|
| 566 |
-
The README now explains:
|
| 567 |
-
|
| 568 |
-
- problem statement,
|
| 569 |
-
- tools,
|
| 570 |
-
- reward architecture,
|
| 571 |
-
- environment design,
|
| 572 |
-
- data pipeline,
|
| 573 |
-
- Lightning/HF runtime notes.
|
| 574 |
-
|
| 575 |
-
Still needed for final submission:
|
| 576 |
-
|
| 577 |
-
- actual HF Space URL,
|
| 578 |
-
- training result plots,
|
| 579 |
-
- mini-blog/video link.
|
| 580 |
-
|
| 581 |
-
### Show Improvement
|
| 582 |
-
|
| 583 |
-
The environment and scripts support this, but the final artifact still needs a real training run with plots.
|
| 584 |
-
|
| 585 |
-
Minimum evidence to add:
|
| 586 |
-
|
| 587 |
-
- random baseline metrics,
|
| 588 |
-
- trained model metrics,
|
| 589 |
-
- reward curve,
|
| 590 |
-
- parse error curve,
|
| 591 |
-
- example before/after decisions.
|
| 592 |
-
|
| 593 |
-
## 12. Recommended Demo Story
|
| 594 |
-
|
| 595 |
-
A simple judge-friendly demo:
|
| 596 |
-
|
| 597 |
-
1. Show a late chase scenario:
|
| 598 |
-
|
| 599 |
-
```text
|
| 600 |
-
Over 16.0, 128/5, target 172
|
| 601 |
-
```
|
| 602 |
|
| 603 |
-
|
| 604 |
|
| 605 |
-
|
| 606 |
-
- may attack blindly,
|
| 607 |
-
- may ignore field/opponent.
|
| 608 |
|
| 609 |
-
|
| 610 |
|
| 611 |
-
|
| 612 |
-
-
|
| 613 |
-
|
| 614 |
-
-
|
| 615 |
-
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 616 |
|
| 617 |
-
|
| 618 |
|
| 619 |
-
-
|
| 620 |
-
- adaptation up,
|
| 621 |
-
- opponent awareness up,
|
| 622 |
-
- reward up.
|
| 623 |
|
| 624 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 625 |
|
| 626 |
-
|
| 627 |
|
| 628 |
-
|
|
|
|
|
|
|
|
|
|
| 629 |
|
| 630 |
-
|
| 631 |
|
| 632 |
-
|
| 633 |
-
- Rich strategic tool surface.
|
| 634 |
-
- Opponent policies.
|
| 635 |
-
- Adaptive eval pack.
|
| 636 |
-
- T20 data curation script.
|
| 637 |
-
- Player profile builder.
|
| 638 |
-
- Opponent cache generator.
|
| 639 |
-
- GRPO training script path.
|
| 640 |
-
- SFT bootstrap data generator.
|
| 641 |
-
- Eval and plotting scripts.
|
| 642 |
|
| 643 |
-
|
| 644 |
|
| 645 |
-
|
| 646 |
-
- Lint checks.
|
| 647 |
-
- `train.py test`.
|
| 648 |
-
- `train.py sft-data`.
|
| 649 |
-
- opponent cache generation.
|
| 650 |
-
- server startup.
|
| 651 |
-
- random inference run.
|
| 652 |
-
- eval run with plots.
|
| 653 |
|
| 654 |
-
|
| 655 |
|
| 656 |
-
|
|
|
|
| 30 |
|
| 31 |
## 2. Fit With OpenEnv Competition Themes
|
| 32 |
|
|
|
|
|
|
|
| 33 |
### Multi-Agent Interactions
|
| 34 |
|
| 35 |
The submitted captain agent plays against an opponent policy. The opponent can be:
|
| 36 |
|
| 37 |
+
- `heuristic`: fast format-aware cricket logic (T5/T20/ODI rules).
|
| 38 |
+
- `cricsheet`: real Cricsheet ball-by-ball match data sampled by game context.
|
| 39 |
+
- `llm_live`: live OpenAI-compatible LLM opponent (google/gemma-4-26B-A4B-it via HF Router).
|
| 40 |
- `llm_cached`: replayed opponent decisions for reproducible evaluation.
|
| 41 |
|
| 42 |
This tests whether the agent can reason about another actor's incentives, field settings, and likely plans.
|
|
|
|
| 47 |
|
| 48 |
### World Modeling
|
| 49 |
|
| 50 |
+
The agent observes a partially summarized cricket world: score, over/ball, wickets, target, phase, field, batter profile, bowler profile, previous outcome. It must maintain an internal model of what is happening and update that model after every ball.
|
| 51 |
|
| 52 |
### Self-Improvement
|
| 53 |
|
| 54 |
+
The same environment can support heuristic curriculum training, cached-opponent official evaluation, live LLM opponent self-play, and future agent-vs-agent training.
|
| 55 |
|
| 56 |
## 3. Environment Flow
|
| 57 |
|
|
|
|
| 67 |
PRE_OVER -> PRE_BALL -> BALL_RESOLUTION -> POST_BALL -> next decision
|
| 68 |
```
|
| 69 |
|
|
|
|
|
|
|
| 70 |
### Toss
|
| 71 |
|
| 72 |
```json
|
|
|
|
| 76 |
### Batting Tools
|
| 77 |
|
| 78 |
```json
|
| 79 |
+
{"tool": "select_batter", "arguments": {"name": "Virat Kohli", "style": "anchor", "aggression": 0.35, "rationale": "Preserve wickets in the middle overs."}}
|
|
|
|
|
|
|
|
|
|
| 80 |
{"tool": "set_strategy", "arguments": {"phase_intent": "consolidate", "aggression": 0.35, "rationale": "Rotate strike against spin and keep wickets in hand."}}
|
|
|
|
|
|
|
|
|
|
| 81 |
{"tool": "plan_shot", "arguments": {"shot_intent": "single", "target_area": "midwicket", "risk": "low", "trajectory": "ground", "rationale": "Field is spread, so take the easy single."}}
|
|
|
|
|
|
|
|
|
|
| 82 |
{"tool": "play_delivery", "arguments": {"shot_intent": "single", "explanation": "Work into the gap."}}
|
| 83 |
```
|
| 84 |
|
| 85 |
+
`plan_shot` is **not** an overhead tool. Only `set_strategy`, `set_bowling_strategy`, `plan_delivery`, `reflect_after_ball`, and `analyze_situation` count against the 3-per-over limit (see Tool budget).
|
| 86 |
|
| 87 |
+
### Bowling Tools
|
|
|
|
|
|
|
| 88 |
|
| 89 |
```json
|
| 90 |
+
{"tool": "choose_bowler", "arguments": {"name": "Jasprit Bumrah", "bowler_type": "pace", "style": "yorker", "rationale": "Attack the stumps at the death."}}
|
| 91 |
{"tool": "set_bowling_strategy", "arguments": {"bowler_type": "pace", "line": "stumps", "length": "full", "delivery_type": "yorker", "rationale": "Limit swing room."}}
|
|
|
|
|
|
|
|
|
|
| 92 |
{"tool": "set_field_setting", "arguments": {"setting": "Defensive"}}
|
|
|
|
|
|
|
|
|
|
| 93 |
{"tool": "plan_delivery", "arguments": {"bowler_type": "pace", "line": "stumps", "length": "full", "delivery_type": "yorker", "rationale": "Protect boundaries and force a low-percentage shot."}}
|
|
|
|
|
|
|
|
|
|
| 94 |
{"tool": "bowl_delivery", "arguments": {}}
|
| 95 |
```
|
| 96 |
|
|
|
|
| 106 |
{"tool": "analyze_situation", "arguments": {"query_type": "match_situation"}}
|
| 107 |
```
|
| 108 |
|
| 109 |
+
## 4. Tool budget
|
| 110 |
+
|
| 111 |
+
The environment enforces a **3-call overhead budget per over** (see `CricketEnvironment.TOOL_BUDGET_PER_OVER` and `TOOL_FINE_PER_EXCESS` in `server/cricket_environment.py`).
|
| 112 |
+
|
| 113 |
+
**Overhead tools** (increment the per-over counter; the 4th+ in the same over are fined):
|
| 114 |
+
`set_strategy`, `set_bowling_strategy`, `plan_delivery`, `reflect_after_ball`, `analyze_situation`
|
| 115 |
+
|
| 116 |
+
**Not overhead** (these do not consume the 3 free “slots”):
|
| 117 |
+
`play_delivery`, `bowl_delivery`, `plan_shot`, `call_toss`, `select_batter`, `choose_bowler`, `set_field_setting`, `set_match_plan`, `update_match_plan`
|
| 118 |
|
| 119 |
+
Each overhead call **beyond the third in that over** incurs an immediate **−0.04** step reward. The prompt shows `Tool budget: N/3 overhead calls used this over`.
|
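A minimal sketch of this rule, assuming the counter resets at the start of each over (constant and tool names mirror the README; the real bookkeeping lives in `server/cricket_environment.py`):

```python
# Sketch only — mirrors the stated rule, not the exact implementation.
TOOL_BUDGET_PER_OVER = 3
TOOL_FINE_PER_EXCESS = 0.04
OVERHEAD_TOOLS = {"set_strategy", "set_bowling_strategy", "plan_delivery",
                  "reflect_after_ball", "analyze_situation"}

def overhead_fine(tool: str, used_this_over: int) -> tuple[float, int]:
    """Return (step_reward_delta, updated_counter) for one tool call."""
    if tool not in OVERHEAD_TOOLS:
        return 0.0, used_this_over              # ball-advancing / planning tools are free
    used_this_over += 1
    fine = -TOOL_FINE_PER_EXCESS if used_this_over > TOOL_BUDGET_PER_OVER else 0.0
    return fine, used_this_over
```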
| 120 |
+
|
| 121 |
+
**Training connection:** `train.py train` uses real `CricketEnvironment` steps, so these fines are part of the return GRPO optimizes. That keeps long-horizon training aligned with the benchmark: agents must choose when to pay for `analyze_situation` and `reflect_after_ball`, while `set_match_plan` / `update_match_plan` let them carry structure across overs without spending overhead budget.
|
| 122 |
+
|
| 123 |
+
## 5. OpenEnv Architecture
|
| 124 |
|
| 125 |
```text
|
| 126 |
LLM Agent / Evaluator
|
| 127 |
|
|
| 128 |
+
| WebSocket (OpenEnv)
|
| 129 |
v
|
| 130 |
+
FastAPI server (server/app.py)
|
| 131 |
|
|
| 132 |
v
|
| 133 |
+
CricketEnvironment (server/cricket_environment.py)
|
| 134 |
|
|
| 135 |
+
+--> MarkovCricketEngine (server/markov_engine.py)
|
| 136 |
+
+--> FormatMapper (server/format_mapper.py)
|
| 137 |
+
+--> OpponentPolicy (server/opponent_policy.py)
|
| 138 |
+
+--> PlayerRoster (server/player_roster.py)
|
| 139 |
+
+--> CoherenceGrader (server/coherence_grader.py)
|
| 140 |
+
+--> RewardCalculator (server/reward_calculator.py)
|
| 141 |
+
+--> FieldModel (server/field_model.py)
|
| 142 |
```
|
| 143 |
|
| 144 |
+
Key files:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 145 |
|
| 146 |
+
| File | Role |
|
| 147 |
+
|------|------|
|
| 148 |
+
| `server/app.py` | OpenEnv server entry point |
|
| 149 |
+
| `server/cricket_environment.py` | `reset`, `step`, `state` implementation |
|
| 150 |
+
| `server/format_mapper.py` | T5/T20/ODI closest-format selector; phase-aware shot weights, batter/bowler roles |
|
| 151 |
+
| `server/opponent_policy.py` | Heuristic, Cricsheet, live LLM, cached LLM opponent policies |
|
| 152 |
+
| `server/player_roster.py` | Fuzzy player lookup; batter/bowler profile extractor |
|
| 153 |
+
| `models.py` | `CricketAction`, `CricketObservation`, `CricketState` |
|
| 154 |
+
| `client.py` | WebSocket client `CricketCaptainEnv` |
|
| 155 |
+
| `inference.py` | Random + LLM agent evaluation |
|
| 156 |
+
| `train.py` | MT-GRPO + SFT training pipeline |
|
| 157 |
+
| `eval.py` | Coherence heatmaps, reward curves, tool analytics |
|
| 158 |
|
| 159 |
+
## 6. Format-Aware Rules
|
| 160 |
|
| 161 |
+
`server/format_mapper.py` auto-selects T5 / T20 / ODI rules by `|max_overs − format_overs|`:
|
| 162 |
|
| 163 |
+
| Format | max_overs | Key differences |
|
| 164 |
+
|--------|-----------|-----------------|
|
| 165 |
+
| T5 | ≤ 7 | High-aggression throughout, powerplay dominates all overs |
|
| 166 |
+
| T20 | 8–35 | Three phases (PP/Middle/Death); spin-heavy middle |
|
| 167 |
+
| ODI | > 35 | Four phases (PP/Middle-early/Middle-late/Death); anchor roles |
|
| 168 |
|
| 169 |
+
The format mapper provides:
|
| 170 |
+
- **Phase-aware shot weights**: boundary/six probability rises sharply in death overs
|
| 171 |
+
- **Batter roles** with `overs_active` windows (opener, anchor, middle_order, finisher)
|
| 172 |
+
- **Bowler roles** with `preferred_phases` (pace_opener, spin_controller, death_specialist)
|
| 173 |
+
- **Bowling strategy** per phase (line, length, delivery_type, field_setting)
|
| 174 |
|
| 175 |
+
Both the heuristic opponent and the `select_batter` / `choose_bowler` tools draw from these tables.
|
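A hedged sketch of the closest-format selection above: pick the format whose canonical over count minimizes `|max_overs − format_overs|`. The canonical counts (5/20/50) are assumptions; the table's thresholds come from `server/format_mapper.py` and may include extra tie-breaking.

```python
FORMAT_OVERS = {"t5": 5, "t20": 20, "odi": 50}   # assumed canonical over counts

def closest_format(max_overs: int) -> str:
    # format whose canonical over count is nearest to the configured max_overs
    return min(FORMAT_OVERS, key=lambda fmt: abs(max_overs - FORMAT_OVERS[fmt]))

closest_format(5)   # 't5'
closest_format(20)  # 't20'
closest_format(50)  # 'odi'
```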
| 176 |
|
| 177 |
+
## 7. Player Rosters
|
| 178 |
|
| 179 |
+
`server/player_roster.py` loads team profiles from `data/player_profiles/` — 10 T20I squads: India, Australia, England, Pakistan, South Africa, New Zealand, West Indies, Sri Lanka, Bangladesh, Afghanistan.
|
| 180 |
|
| 181 |
+
When the agent calls `select_batter` or `choose_bowler` with a player name, the roster performs **fuzzy lookup** (exact → surname → word-overlap) and fills in real aggression, batting/bowling style, and phase strengths from the profile.
|
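An illustrative version of that fallback chain (the function name, roster shape, and overlap scoring are assumptions; the real matcher lives in `server/player_roster.py`):

```python
def lookup_player(name: str, roster: dict[str, dict]) -> dict | None:
    """Resolve a player name: exact match, then surname, then word overlap."""
    if name in roster:                                        # 1. exact
        return roster[name]
    surname = name.split()[-1].lower()
    for full_name, profile in roster.items():                 # 2. surname
        if full_name.split()[-1].lower() == surname:
            return profile
    query = set(name.lower().split())                         # 3. word overlap
    best = max(roster, key=lambda f: len(query & set(f.lower().split())), default=None)
    if best and query & set(best.lower().split()):
        return roster[best]
    return None                                               # caller falls back to a generic profile
```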
| 182 |
|
| 183 |
+
## 8. What The Observation Contains
|
| 184 |
|
| 185 |
+
Each step returns a `CricketObservation` with:
|
| 186 |
|
| 187 |
+
- `game_state`: toss / batting / bowling / finished
|
| 188 |
+
- `strategic_phase`: pre_over / pre_ball / ball_resolution / post_ball
|
| 189 |
+
- `game_context`: score, wickets, over, ball, target, phase, run_rate, req_rate
|
| 190 |
+
- `declared_strategy`: current batting strategy (aggression, intent, rationale)
|
| 191 |
+
- `bowling_strategy`: current bowling plan
|
| 192 |
+
- `field_setting`: Aggressive / Balanced / Defensive
|
| 193 |
+
- `current_batter`: batter profile (style, aggression, phase strengths)
|
| 194 |
+
- `current_bowler`: bowler profile
|
| 195 |
+
- `last_outcome`: ball outcome + tactical metadata (event type, shot zone, delivery features, field pressure, fielder effect)
|
| 196 |
+
- `available_tools`: legal tools for current state (phase-gated)
|
| 197 |
+
- `tool_budget`: overhead calls used this over vs 3-call limit
|
| 198 |
+
- `prompt_text`: rendered prompt for the LLM
|
| 199 |
|
| 200 |
+
The LLM sees enough information to reason tactically, but not simulator internals.
|
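An illustrative payload with made-up values (field names follow the list above; the real dataclass is `CricketObservation` in `models.py`):

```python
observation = {
    "game_state": "batting",
    "strategic_phase": "pre_ball",
    "game_context": {"score": 128, "wickets": 5, "over": 16, "ball": 0,
                     "target": 172, "phase": "death", "run_rate": 8.0, "req_rate": 11.0},
    "declared_strategy": {"phase_intent": "accelerate", "aggression": 0.7},
    "field_setting": "Defensive",
    "current_batter": {"name": "finisher", "style": "power", "aggression": 0.8},
    "available_tools": ["plan_shot", "play_delivery", "analyze_situation"],
    "tool_budget": "1/3 overhead calls used this over",
}
```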
| 201 |
|
| 202 |
+
## 9. Opponent Policies
|
| 203 |
|
| 204 |
+
Four modes in `server/opponent_policy.py`:
|
| 205 |
|
| 206 |
+
### `heuristic`
|
| 207 |
|
| 208 |
+
Format-aware local policy using T5/T20/ODI rules from `format_mapper.py`. Picks shot intent from phase-weighted distributions, adjusts for wicket pressure (shifts toward conservative shots when 7+ wickets are down), and selects batter/bowler roles by current over and format. Fast, no API key needed.
|
| 209 |
|
| 210 |
+
### `cricsheet`
|
| 211 |
|
| 212 |
+
Samples real Cricsheet ball-by-ball deliveries indexed by `(phase, wickets_band, innings_type)`. Automatically selects T20 or ODI data based on `max_overs`:
|
| 213 |
+
- ≤ 25 overs → `ball_outcomes_t20_v1.pkl` (1.17M T20 deliveries from 5,176 matches)
|
| 214 |
+
- > 25 overs → `ball_outcomes_odi_v1.pkl` (1.65M ODI deliveries from 3,116 matches)
|
| 215 |
|
| 216 |
+
Progressive fallback widening (drop innings_type → drop wickets_band → any phase record) ensures no dead buckets. Heuristic fallback if data file absent.
|
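A sketch of that key-widening order, assuming the deliveries are indexed in a dict keyed by progressively coarser tuples (the real index and sampling live in `server/opponent_policy.py`):

```python
import random

def sample_delivery(index: dict, phase: str, wickets_band: str, innings_type: str,
                    rng: random.Random):
    """Try the narrowest bucket first, then widen until a non-empty one is found."""
    for key in ((phase, wickets_band, innings_type),   # full context
                (phase, wickets_band, None),           # drop innings_type
                (phase, None, None)):                  # any record for this phase
        bucket = index.get(key)
        if bucket:
            return rng.choice(bucket)
    return None                                        # caller falls back to the heuristic policy
```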
| 217 |
|
| 218 |
+
### `llm_live`
|
| 219 |
|
| 220 |
+
Calls `google/gemma-4-26B-A4B-it` via HF Router (or any OpenAI-compatible API). Graceful heuristic fallback when no API key is present, so local development never breaks.
|
|
|
|
| 221 |
|
| 222 |
+
### `llm_cached`
|
| 223 |
|
| 224 |
+
Replays pre-recorded opponent decisions from JSONL. Does **not** call the configured model live. Use for official leaderboard-style evaluation where every compared captain faces identical opponent decisions.
|
| 225 |
|
| 226 |
+
## 10. Ball Physics And Markov Engine
|
| 227 |
|
| 228 |
+
The simulation uses `server/markov_engine.py` plus field/zone definitions in `server/field_model.py`.
|
| 229 |
|
| 230 |
+
Ball transition tables keyed by `(over, wickets, score_band, phase, bowler_type)`:
|
| 231 |
+
1. **Cricsheet-derived**: `data/processed/cricket_transitions_v1.pkl` when available
|
| 232 |
+
2. **Calibrated synthetic**: `data/transition_probs.json` as fallback
|
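A minimal lookup sketch against such a table; the key values and bucket labels below are invented for illustration, and the real draw (including fallbacks) is in `server/markov_engine.py`:

```python
import pickle
import random

with open("data/processed/cricket_transitions_v1.pkl", "rb") as f:
    transitions = pickle.load(f)              # assumed shape: {key: {outcome: probability}}

key = (12, 3, "medium", "middle", "spin")     # (over, wickets, score_band, phase, bowler_type)
dist = transitions.get(key)                   # e.g. {"0": 0.38, "1": 0.31, "4": 0.08, "W": 0.04, ...}
if dist:
    outcome = random.choices(list(dist), weights=list(dist.values()))[0]
else:
    outcome = "1"                             # fall back to the calibrated synthetic table
```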
| 233 |
|
| 234 |
+
After the base Markov draw, a **hybrid tactical layer** applies:
|
| 235 |
+
- Shot target zones (`cover`, `point`, `midwicket`, `long_on`, …) matched against delivery line/length/variation
|
| 236 |
+
- Field presets (`Aggressive`, `Balanced`, `Defensive`) expand into named fielder zones
|
| 237 |
+
- Boundary riders cut off fours/sixes; inner-ring fielders save singles; slips/gully convert edges
|
| 238 |
+
- Wides/no-balls, drops, misfields, overthrows, run-outs, caught-in-zone events add bounded stochastic noise
|
| 239 |
+
- High chase pressure makes defensive batting less useful
|
| 240 |
|
| 241 |
+
## 11. Reward Design
|
| 242 |
|
| 243 |
+
Four-rubric composite reward:
|
| 244 |
|
| 245 |
+
| Rubric | Weight | Frequency | Measures |
|
| 246 |
+
|--------|--------|-----------|----------|
|
| 247 |
+
| `r_cricket` | **45%** | Per ball | Dream11 proxy: runs, wickets, dots, milestones, economy, strike rate |
|
| 248 |
+
| `r_behavior` | **25%** | Every turn | Coherence (50%) + adaptation (20%) + opponent_awareness (20%) + regret (10%) |
|
| 249 |
+
| `r_result` | **20%** | Innings/episode end | Win/loss vs DLS par, target margin, wickets |
|
| 250 |
+
| `r_validity` | **10%** | Every turn | Valid tool-call structure and legal phase-gated tool use |
|
| 251 |
|
| 252 |
+
Plus a **progress bonus** added to `r_result`: `min(0.25, tool_calls_made / 40.0)` — caps at +0.25 once the agent makes ≥10 tool calls. Directly rewards escaping the planning-loop trap (where the policy maxes overhead tools without ever calling `play_delivery`).
|
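Putting the table and the bonus together, a hedged sketch of the composite (weights as stated above; the exact aggregation in `server/reward_calculator.py` may differ in detail):

```python
def composite_reward(r_cricket: float, r_behavior: float, r_result: float,
                     r_validity: float, tool_calls_made: int) -> float:
    # progress bonus: min(0.25, tool_calls_made / 40.0), folded into r_result
    r_result = r_result + min(0.25, tool_calls_made / 40.0)
    return 0.45 * r_cricket + 0.25 * r_behavior + 0.20 * r_result + 0.10 * r_validity
```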
| 253 |
|
| 254 |
+
**Why these weights** (rebalanced from the original 55/25/15/5): partial-trajectory training means `r_result` rarely fires (episodes truncate before completion). Putting 55% weight on a signal that fires <5% of the time washes out the gradient. The new 45/25/20/10 split mirrors the SWE-RL recipe (60% intermediate / 40% terminal) and matches what working coding-agent RL setups actually use.
|
| 255 |
|
| 256 |
+
`r_tools` is computed and logged but excluded from the composite — tool discipline is measured through outcomes.
|
| 257 |
|
| 258 |
+
### Coherence Scoring (batting)
|
| 259 |
|
|
| 260 |
```
|
| 261 |
+
coherence = aggression_match × rationale_specificity × phase_appropriate
|
| 262 |
+
aggression_match = 1 − |declared_aggression − shot_aggression_proxy|
|
| 263 |
+
rationale_specificity = (word_count_score + cricket_keyword_density) / 2
|
| 264 |
+
phase_appropriate = 1 − |declared_aggression − phase_baseline|
|
| 265 |
+
phase_baselines: powerplay=0.55, middle=0.35, death=0.75
|
| 266 |
```
|
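Worked example with illustrative numbers: a declared aggression of 0.35 in the middle overs (phase baseline 0.35), a shot whose aggression proxy is 0.30, and a rationale scoring 0.6 on specificity gives `aggression_match = 1 − 0.05 = 0.95`, `phase_appropriate = 1.0`, so `coherence ≈ 0.95 × 0.6 × 1.0 = 0.57`.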
| 267 |
|
| 268 |
+
### Single-Stage Training with Format Curriculum
|
| 269 |
|
| 270 |
+
The original two-stage (format → strategy) curriculum was collapsed because Qwen3.5-4B
|
| 271 |
+
already does tool calling natively (XML+JSON via `_parse_completion`). The full composite
|
| 272 |
+
reward fires from step 0.
|
| 273 |
|
| 274 |
+
What remains is a **format-length curriculum within the warmup config**: per-scenario
|
| 275 |
+
`max_overs` is sampled from `[2, 2, 2, 2, 2, 3, 3, 3, 4, 4, 5]` (heavy on T2-T3 so episodes
|
| 276 |
+
actually complete inside the token budget). The main run then trains on full T20 (20-over)
|
| 277 |
+
matches, optionally resuming from the warmup adapter.
|
| 278 |
|
| 279 |
+
## 12. Data Curation Pipeline
|
| 280 |
|
| 281 |
+
### Step 1: Curate Ball Outcomes
|
| 282 |
|
| 283 |
```bash
|
| 284 |
+
python scripts/curate_transitions.py --format t20 # → ball_outcomes_t20_v1.pkl
|
| 285 |
+
python scripts/curate_transitions.py --format odi # → ball_outcomes_odi_v1.pkl
|
| 286 |
```
|
| 287 |
|
| 288 |
+
Both files already generated:
|
| 289 |
+
- `data/processed/ball_outcomes_t20_v1.pkl` — 1.17M T20 deliveries, 5,176 matches
|
| 290 |
+
- `data/processed/ball_outcomes_odi_v1.pkl` — 1.65M ODI deliveries, 3,116 matches
|
| 291 |
+
- `data/processed/cricket_transitions_v1.pkl` — 5,138 Markov keys, 2,878 high-confidence
|
| 292 |
|
| 293 |
+
### Step 2: Build Evaluation Pack
|
|
|
|
|
|
|
| 294 |
|
| 295 |
```bash
|
| 296 |
python scripts/build_eval_pack.py --eval-pack-id adaptive_t20_v1
|
| 297 |
```
|
| 298 |
|
| 299 |
+
### Step 3: Generate Opponent Cache
|
| 300 |
|
| 301 |
```bash
|
| 302 |
python scripts/generate_opponent_cache.py \
|
|
|
|
| 306 |
--output data/opponent_cache/adaptive_t20_v1.jsonl
|
| 307 |
```
|
| 308 |
|
| 309 |
+
## 13. Training Pipeline
|
| 310 |
|
| 311 |
+
### Recommended: Single-Command Chain
|
| 312 |
|
| 313 |
```bash
|
| 314 |
+
# Warmup (5-over curriculum, 25 steps) → Main (20-over T20, 100 steps).
|
| 315 |
+
# Main auto-resumes from warmup adapter at ./checkpoints/stage2_final.
|
| 316 |
+
bash scripts/run_warmup_then_main.sh
|
| 317 |
```
|
| 318 |
|
| 319 |
+
### Run Components Individually
|
| 320 |
|
| 321 |
+
**Warmup only — short curriculum, bootstraps the LoRA adapter:**
|
| 322 |
|
| 323 |
```bash
|
| 324 |
+
PYTORCH_ALLOC_CONF=expandable_segments:True \
|
| 325 |
+
python train.py train --config configs/cricket_train_warmup.yaml
|
| 326 |
```
|
| 327 |
|
| 328 |
+
**Main only — full T20, resumes warmup adapter (or fresh if `resume_from` is empty):**
|
| 329 |
|
| 330 |
```bash
|
| 331 |
+
PYTORCH_ALLOC_CONF=expandable_segments:True \
|
| 332 |
+
python train.py train --config configs/cricket_train.yaml
|
| 333 |
```
|
| 334 |
|
| 335 |
+
**Optional SFT bootstrap** (legacy, not needed for Qwen3.5-4B which has native tool calling):
|
| 336 |
|
| 337 |
```bash
|
| 338 |
+
python train.py sft-data --output data/training/tool_sft_examples.jsonl
|
| 339 |
```
|
| 340 |
|
| 341 |
+
`train.py train` uses TRL `GRPOTrainer` with `environment_factory=CricketCaptainToolEnv`. The captain being trained is loaded locally by Transformers/TRL and interacts with live environment instances through tool methods. `opponent-mode llm_live` affects only the adversary; it does not mean the trained captain is served through the HF inference endpoint.
|
| 342 |
|
| 343 |
+
The default training model is `Qwen/Qwen3.5-4B`. The default live opponent model is `google/gemma-4-26B-A4B-it`. Roster-backed training requires `--agent-team` or `env.agent_team` in YAML so `select_batter` and `choose_bowler` use real player profiles instead of generic names.
|
| 344 |
|
| 345 |
+
## 14. Current Status (2026-04-25)
|
|
|
|
|
|
|
| 346 |
|
| 347 |
+
### Implemented and verified
|
| 348 |
|
| 349 |
+
| Component | Status |
|
| 350 |
+
|-----------|--------|
|
| 351 |
+
| OpenEnv server + client | ✅ |
|
| 352 |
+
| 14-tool strategic surface | ✅ |
|
| 353 |
+
| 4-rubric reward system | ✅ |
|
| 354 |
+
| Tool budget system (3/over, −0.04 fine) | ✅ |
|
| 355 |
+
| Format mapper (T5/T20/ODI) | ✅ |
|
| 356 |
+
| Player rosters (10 T20I teams, fuzzy lookup) | ✅ |
|
| 357 |
+
| Cricsheet T20 data (1.17M deliveries) | ✅ |
|
| 358 |
+
| Cricsheet ODI data (1.65M deliveries) | ✅ |
|
| 359 |
+
| Heuristic opponent (format-aware) | ✅ |
|
| 360 |
+
| Cricsheet opponent (T20+ODI, context-indexed) | ✅ |
|
| 361 |
+
| LLM live opponent (HF Router / OpenAI-compatible API) | ✅ |
|
| 362 |
+
| LLM cached opponent | ✅ |
|
| 363 |
+
| GRPO training script (`environment_factory` agent rollouts) | ✅ |
|
| 364 |
+
| SFT data generator | ✅ |
|
| 365 |
+
| Gradio demo UI | ✅ |
|
| 366 |
+
| Colab training notebook | ✅ |
|
| 367 |
|
| 368 |
+
### Verified end-to-end (2026-04-25)
|
| 369 |
|
| 370 |
+
All 3 opponent modes verified at 5-over inference + train-smoke:
|
|
|
|
|
|
|
|
|
|
| 371 |
|
| 372 |
+
| Mode | inference parse_err | train-smoke r_validity | coherence |
|
| 373 |
+
|------|--------------------|-----------------------|-----------|
|
| 374 |
+
| heuristic | 0% | 1.0 | 0.556 |
|
| 375 |
+
| cricsheet | 0% | 1.0 | 0.620 |
|
| 376 |
+
| llm_live | 0% | 1.0 | 0.537 |
|
| 377 |
|
| 378 |
+
### Pending for submission
|
| 379 |
|
| 380 |
+
- Real GRPO training run with reward curves (requires HF compute)
|
| 381 |
+
- HF Space deployment URL
|
| 382 |
+
- Training-vs-baseline comparison plots
|
| 383 |
+
- Mini-blog / video
|
| 384 |
|
| 385 |
+
## 15. Recommended Demo Story
|
| 386 |
|
| 387 |
+
1. **Show a late chase scenario**: Over 16.0, 128/5, target 172
|
| 388 |
|
| 389 |
+
2. **Random/untrained model**: invalid tools, blind aggression, ignores field/opponent
|
| 390 |
|
| 391 |
+
3. **Trained model**: checks target pressure → selects finisher → plans boundary zones → responds after wicket → changes risk level
|
| 392 |
|
| 393 |
+
4. **Show metrics**: parse errors ↓, coherence ↑, adaptation ↑, opponent_awareness ↑, reward ↑
|
| 394 |
|
| 395 |
+
> The model learned to captain, not just emit a valid tool-call object.
|
docs/experiment_workflow.md
CHANGED
|
@@ -1,12 +1,12 @@
|
|
| 1 |
# Experiment Workflow: Baselines, Opponents, Short Runs, and Training
|
| 2 |
|
| 3 |
-
This document explains how to run CricketCaptain experiments in a practical order:
|
| 4 |
|
| 5 |
-
## 1. Why Start With
|
| 6 |
|
| 7 |
-
A full T20 innings is 20 overs. That is useful for final evaluation
|
| 8 |
|
| 9 |
-
For early
|
| 10 |
|
| 11 |
- Is the OpenEnv server working?
|
| 12 |
- Is the client connecting correctly?
|
|
@@ -19,461 +19,267 @@ For early code-path experiments, 2-over smoke runs are better because they quick
|
|
| 19 |
The workflow should be:
|
| 20 |
|
| 21 |
```text
|
| 22 |
-
|
| 23 |
```
|
| 24 |
|
| 25 |
-
Do not start with full 20-over training unless the
|
| 26 |
|
| 27 |
-
## 2.
|
| 28 |
|
| 29 |
-
|
| 30 |
|
| 31 |
-
|
| 32 |
|
| 33 |
-
## 2.1 Heuristic Opponent
|
| 34 |
|
| 35 |
```bash
|
| 36 |
-
|
| 37 |
```
|
| 38 |
|
| 39 |
-
|
|
|
|
|
|
|
|
|
|
| 40 |
|
| 41 |
-
|
|
|
|
| 42 |
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
- fast local tests,
|
| 46 |
-
- smoke tests,
|
| 47 |
-
- cheap training rollouts,
|
| 48 |
-
- deterministic-ish baselines.
|
| 49 |
-
|
| 50 |
-
Pros:
|
| 51 |
-
|
| 52 |
-
- cheap,
|
| 53 |
-
- fast,
|
| 54 |
-
- no API key,
|
| 55 |
-
- stable enough for debugging.
|
| 56 |
-
|
| 57 |
-
Cons:
|
| 58 |
-
|
| 59 |
-
- less realistic than an LLM opponent,
|
| 60 |
-
- less diverse.
|
| 61 |
-
|
| 62 |
-
## 2.2 Live LLM Opponent
|
| 63 |
|
| 64 |
```bash
|
| 65 |
-
|
| 66 |
-
export CRICKET_OPPONENT_MODEL=google/gemma-4-26B-A4B-it
|
| 67 |
-
export CRICKET_OPPONENT_API_BASE=https://router.huggingface.co/v1
|
| 68 |
-
export HF_TOKEN=...
|
| 69 |
```
|
| 70 |
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
```text
|
| 74 |
-
google/gemma-4-26B-A4B-it
|
| 75 |
-
```
|
| 76 |
-
|
| 77 |
-
This mode calls an OpenAI-compatible API from `LLMOpponentPolicy`.
|
| 78 |
-
|
| 79 |
-
Use it for:
|
| 80 |
|
| 81 |
-
-
|
| 82 |
-
-
|
| 83 |
-
-
|
| 84 |
-
- future self-play-style experiments.
|
| 85 |
|
| 86 |
-
|
|
|
|
| 87 |
|
| 88 |
-
|
| 89 |
-
- can react with natural tactical reasoning,
|
| 90 |
-
- good for storytelling/demo.
|
| 91 |
-
|
| 92 |
-
Cons:
|
| 93 |
-
|
| 94 |
-
- costs API calls,
|
| 95 |
-
- can be non-deterministic,
|
| 96 |
-
- not ideal for official evaluation directly.
|
| 97 |
-
|
| 98 |
-
## 2.3 Cached LLM Opponent
|
| 99 |
|
| 100 |
```bash
|
| 101 |
-
export
|
| 102 |
-
|
| 103 |
```
|
| 104 |
|
| 105 |
-
|
| 106 |
|
| 107 |
-
Use
|
| 108 |
|
| 109 |
-
|
| 110 |
-
- leaderboard-style comparison,
|
| 111 |
-
- reproducible experiments.
|
| 112 |
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
llm_live once -> save opponent decisions -> llm_cached for all model comparisons
|
| 117 |
-
```
|
| 118 |
-
|
| 119 |
-
This gives the benefit of an LLM opponent while ensuring every model faces the same opponent decisions.
|
| 120 |
-
|
| 121 |
-
## 3. What Model Is The Opposite Team?
|
| 122 |
-
|
| 123 |
-
Currently, the default live opponent is:
|
| 124 |
-
|
| 125 |
-
```text
|
| 126 |
-
google/gemma-4-26B-A4B-it via https://router.huggingface.co/v1
|
| 127 |
```
|
| 128 |
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
```text
|
| 132 |
-
llm_live -> calls the configured model during the run
|
| 133 |
-
llm_cached -> ignores live model calls and replays cache_path
|
| 134 |
-
heuristic -> uses local rule-based cricket policy
|
| 135 |
-
```
|
| 136 |
|
| 137 |
-
|
| 138 |
|
| 139 |
```bash
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
For example:
|
| 145 |
|
| 146 |
-
|
| 147 |
-
|
| 148 |
```
|
| 149 |
|
| 150 |
-
|
| 151 |
|
|
|
|
| 152 |
```bash
|
| 153 |
-
export
|
| 154 |
-
export CRICKET_OPPONENT_MODEL=<local-model-name>
|
| 155 |
```
|
| 156 |
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
Before testing any trained model, run the random baseline.
|
| 160 |
-
|
| 161 |
-
Start the server:
|
| 162 |
-
|
| 163 |
```bash
|
| 164 |
-
|
| 165 |
```
|
| 166 |
|
| 167 |
-
|
|
|
|
|
|
|
| 168 |
|
| 169 |
```bash
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
--opponent-mode heuristic \
|
| 175 |
-
--eval-pack-id adaptive_t20_v1
|
| 176 |
```
|
| 177 |
|
| 178 |
-
|
| 179 |
|
| 180 |
-
|
| 181 |
-
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
|
| 186 |
-
|
| 187 |
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
Next, evaluate a base model without training.
|
| 191 |
-
|
| 192 |
-
Example:
|
| 193 |
|
| 194 |
```bash
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
--env-url "$CRICKET_CAPTAIN_ENV_URL" \
|
| 199 |
-
--opponent-mode heuristic \
|
| 200 |
-
--eval-pack-id adaptive_t20_v1
|
| 201 |
```
|
| 202 |
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
Expected weaknesses:
|
| 206 |
-
|
| 207 |
-
- may output invalid JSON,
|
| 208 |
-
- may choose wrong tools,
|
| 209 |
-
- may ignore opponent plan,
|
| 210 |
-
- may be verbose instead of tool-only,
|
| 211 |
-
- may not adapt after bad outcomes.
|
| 212 |
-
|
| 213 |
-
## 6. Why SFT Exists If We Use GRPO
|
| 214 |
-
|
| 215 |
-
SFT is not the main training objective. It is a warmup.
|
| 216 |
-
|
| 217 |
-
GRPO should optimize strategic behavior, but if the model cannot produce valid tool JSON, GRPO wastes rollouts learning syntax.
|
| 218 |
-
|
| 219 |
-
SFT helps the model learn:
|
| 220 |
-
|
| 221 |
-
- valid JSON shape,
|
| 222 |
-
- available tools,
|
| 223 |
-
- argument schemas,
|
| 224 |
-
- when tools are legal,
|
| 225 |
-
- one-tool-call responses.
|
| 226 |
-
|
| 227 |
-
Then GRPO can focus on:
|
| 228 |
-
|
| 229 |
-
- coherence,
|
| 230 |
-
- adaptation,
|
| 231 |
-
- opponent awareness,
|
| 232 |
-
- match result quality.
|
| 233 |
|
| 234 |
-
|
| 235 |
|
| 236 |
-
```
|
| 237 |
-
|
| 238 |
```
|
| 239 |
|
| 240 |
-
##
|
| 241 |
|
| 242 |
```bash
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
--
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 246 |
```
|
| 247 |
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
- analysis.
|
| 255 |
-
|
| 256 |
-
These examples are useful for quick tool-format finetuning.
|
| 257 |
|
| 258 |
-
##
|
| 259 |
|
| 260 |
```bash
|
| 261 |
-
python train.py
|
| 262 |
-
--stage 1 \
|
| 263 |
-
--steps 100 \
|
| 264 |
-
--prompts 200 \
|
| 265 |
-
--model Qwen/Qwen2.5-7B-Instruct
|
| 266 |
```
|
| 267 |
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
- reduce parse errors,
|
| 271 |
-
- make tool calls valid,
|
| 272 |
-
- stabilize action format.
|
| 273 |
-
|
| 274 |
-
Metrics to watch:
|
| 275 |
|
| 276 |
-
- format
|
| 277 |
-
- parse error rate,
|
| 278 |
-
- invalid tool rate.
|
| 279 |
|
| 280 |
-
|
|
|
|
|
|
|
| 281 |
|
| 282 |
```bash
|
| 283 |
-
|
| 284 |
-
--
|
| 285 |
-
--steps 200 \
|
| 286 |
-
--prompts 300 \
|
| 287 |
-
--model ./checkpoints/stage1_final
|
| 288 |
```
|
| 289 |
|
| 290 |
-
|
|
|
|
| 291 |
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
- improve tool efficiency,
|
| 296 |
-
- improve cricket result quality.
|
| 297 |
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
- total reward,
|
| 301 |
-
- coherence,
|
| 302 |
-
- adaptation,
|
| 303 |
-
- opponent awareness,
|
| 304 |
-
- regret-style score,
|
| 305 |
-
- score/wickets,
|
| 306 |
-
- chase/defense success.
|
| 307 |
-
|
| 308 |
-
## 10. 5-Over vs 20-Over Evaluation
|
| 309 |
-
|
| 310 |
-
### 5-Over Evaluation
|
| 311 |
-
|
| 312 |
-
Use for:
|
| 313 |
-
|
| 314 |
-
- debugging,
|
| 315 |
-
- model sanity checks,
|
| 316 |
-
- comparing before/after quickly,
|
| 317 |
-
- cheap experiments.
|
| 318 |
-
|
| 319 |
-
Both `inference.py` and `eval.py` support `--max-overs`, and the YAML configs set `max_overs: 5` by default for quick iteration.
|
| 320 |
-
|
| 321 |
-
Random captain sanity check:
|
| 322 |
|
| 323 |
```bash
|
| 324 |
-
|
| 325 |
-
--
|
| 326 |
-
--episodes 5 \
|
| 327 |
-
--max-overs 5 \
|
| 328 |
-
--env-url "$CRICKET_CAPTAIN_ENV_URL" \
|
| 329 |
-
--eval-pack-id adaptive_t20_v1 \
|
| 330 |
-
--opponent-mode llm_cached
|
| 331 |
```
|
| 332 |
|
| 333 |
-
|
|
|
|
| 334 |
|
| 335 |
-
``
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
--model google/gemma-4-26B-A4B-it \
|
| 339 |
-
--api-base https://router.huggingface.co/v1 \
|
| 340 |
-
--api-key "$HF_TOKEN" \
|
| 341 |
-
--episodes 1 \
|
| 342 |
-
--max-overs 5 \
|
| 343 |
-
--env-url "$CRICKET_CAPTAIN_ENV_URL" \
|
| 344 |
-
--eval-pack-id adaptive_t20_v1 \
|
| 345 |
-
--opponent-mode llm_cached
|
| 346 |
```
|
|
|
|
|
|
|
| 347 |
|
| 348 |
-
|
|
|
|
|
|
|
| 349 |
|
| 350 |
-
|
|
|
|
| 351 |
|
| 352 |
-
-
|
| 353 |
-
- README numbers,
|
| 354 |
-
- competition evidence,
|
| 355 |
-
- trained-vs-baseline comparison.
|
| 356 |
|
| 357 |
-
|
| 358 |
|
| 359 |
```bash
|
| 360 |
-
|
| 361 |
--episodes 20 \
|
| 362 |
-
--env-url
|
| 363 |
-
--eval-pack-id adaptive_t20_v1
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
## 11. Evaluation Pack
|
| 367 |
-
|
| 368 |
-
The main adaptive pack is:
|
| 369 |
-
|
| 370 |
-
```text
|
| 371 |
-
data/eval_packs/adaptive_t20_v1.json
|
| 372 |
```
|
| 373 |
|
| 374 |
-
|
| 375 |
|
| 376 |
-
|
| 377 |
-
- 60 official scenarios,
|
| 378 |
-
- chase states,
|
| 379 |
-
- defend states,
|
| 380 |
-
- death-over states,
|
| 381 |
-
- collapse states,
|
| 382 |
-
- matchup states.
|
| 383 |
|
| 384 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 385 |
|
| 386 |
-
|
| 387 |
|
| 388 |
-
|
| 389 |
|
| 390 |
-
```
|
| 391 |
-
1. Random baseline, 5-over
|
| 392 |
-
2. Base LLM baseline, 5-over
|
| 393 |
-
3. Training-side rollout smoke, 1 match / 5 overs
|
| 394 |
-
4. SFT warmup
|
| 395 |
-
5. GRPO stage 1, short
|
| 396 |
-
6. GRPO stage 2, short
|
| 397 |
-
7. Trained eval, 5-over
|
| 398 |
-
8. Trained eval, 20-over
|
| 399 |
-
9. Cached LLM opponent official eval
|
| 400 |
-
10. Add plots and before/after examples to README
|
| 401 |
-
```
|
| 402 |
|
| 403 |
-
|
| 404 |
|
| 405 |
-
```
|
| 406 |
-
python train.py train-smoke \
|
| 407 |
-
--matches 1 \
|
| 408 |
-
--max-overs 2 \
|
| 409 |
-
--max-steps 240 \
|
| 410 |
-
--log-steps 90 \
|
| 411 |
-
--eval-pack-id adaptive_t20_v1 \
|
| 412 |
-
--opponent-mode heuristic
|
| 413 |
-
```
|
| 414 |
|
| 415 |
-
|
| 416 |
|
| 417 |
-
|
| 418 |
|
| 419 |
-
|
| 420 |
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
| Base LLM baseline | Shows general LLM behavior before training |
|
| 425 |
-
| Trained model metrics | Shows improvement |
|
| 426 |
-
| Reward curve | Shows learning progress |
|
| 427 |
-
| Parse error curve | Shows tool-use improvement |
|
| 428 |
-
| Before/after examples | Makes the story clear |
|
| 429 |
-
| Eval against cached opponent | Shows fairness/reproducibility |
|
| 430 |
|
| 431 |
-
|
|
|
|
|
|
|
| 432 |
|
| 433 |
-
|
| 434 |
-
- parse error rate,
|
| 435 |
-
- coherence,
|
| 436 |
-
- adaptation,
|
| 437 |
-
- opponent awareness,
|
| 438 |
-
- score/wickets,
|
| 439 |
-
- chase or defense success rate.
|
| 440 |
|
| 441 |
-
## 14. Latest
|
| 442 |
|
| 443 |
-
|
| 444 |
|
| 445 |
-
```
|
| 446 |
-
Random
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
coherence
|
| 458 |
-
adaptation: 0.502
|
| 459 |
-
opponent awareness: 0.750
|
| 460 |
-
parse errors: 0.0%
|
| 461 |
-
|
| 462 |
-
Latest captured training-side smoke, 1 match / 5 overs:
|
| 463 |
-
first innings: opponent 30/6, target 31
|
| 464 |
-
first-innings reward: +0.170 from par/run-rate/wicket context
|
| 465 |
-
chase: 26/1 in 5 overs
|
| 466 |
-
match result: loss
|
| 467 |
-
terminal reward: 0.634 (r_cric=0.759, r_dream11=1.317, r_strategy=0.536)
|
| 468 |
-
tactical events: deep-cover save, edge/catch chances, no-ball, misfield, caught-in-zone
|
| 469 |
```
|
| 470 |
|
| 471 |
-
|
| 472 |
-
|
| 473 |
-
## 15. Immediate Next Engineering Improvement
|
| 474 |
-
|
| 475 |
-
Next useful work:
|
| 476 |
|
| 477 |
-
|
| 478 |
-
|
| 479 |
-
|
|
|
|
|
|
| 1 |
# Experiment Workflow: Baselines, Opponents, Short Runs, and Training
|
| 2 |
|
| 3 |
+
This document explains how to run CricketCaptain experiments in a practical order: smoke checks, 5-over baselines, training, then longer evaluation.
|
| 4 |
|
| 5 |
+
## 1. Why Start With 5-Over Smoke + Baselines?
|
| 6 |
|
| 7 |
+
A full T20 innings is 20 overs. That is useful for final evaluation but slow for debugging.
|
| 8 |
|
| 9 |
+
For early experiments, 5-over runs are better because they quickly answer:
|
| 10 |
|
| 11 |
- Is the OpenEnv server working?
|
| 12 |
- Is the client connecting correctly?
|
|
|
|
| 19 |
The workflow should be:
|
| 20 |
|
| 21 |
```text
|
| 22 |
+
5-over smoke → 5-over untrained baseline → short training → 5-over trained eval → 20-over final eval
|
| 23 |
```
|
| 24 |
|
| 25 |
+
Do not start with full 20-over training unless the 5-over smoke loop is stable.
|
| 26 |
|
| 27 |
+
## 2. Opponent Modes
|
| 28 |
|
| 29 |
+
Four modes in `server/opponent_policy.py`. Controlled via `--opponent-mode`, `CRICKET_OPPONENT_MODE`, or `configs/default.yaml`.
|
| 30 |
|
| 31 |
+
**Default is `llm_live`** in `configs/default.yaml` so training can face a real LLM opponent when credentials are present. For cheap/local checks, explicitly pass `--opponent-mode heuristic`.
|
| 32 |
|
| 33 |
+
### 2.1 Heuristic Opponent
|
| 34 |
|
| 35 |
```bash
|
| 36 |
+
--opponent-mode heuristic
|
| 37 |
```
|
| 38 |
|
| 39 |
+
Format-aware local policy. Uses T5/T20/ODI rules from `server/format_mapper.py`:
|
| 40 |
+
- Phase-weighted shot distributions (powerplay/middle/death per format)
|
| 41 |
+
- Wicket-pressure shift (heavier weight toward defensive shots when 7+ down)
|
| 42 |
+
- Batter/bowler roles selected from `data/format_rules.json`
|
| 43 |
|
| 44 |
+
Use for: fast local tests, cheap training rollouts, deterministic-ish baselines.
|
| 45 |
+
**No API key needed.**
|
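A toy illustration of the phase-weighted pick with the wicket-pressure shift (shot names and weights are invented; the real tables live in `data/format_rules.json`):

```python
import random

def pick_shot(phase_weights: dict[str, float], wickets: int) -> str:
    weights = dict(phase_weights)
    if wickets >= 7:                         # 7+ down: lean toward low-risk options
        for shot in ("defend", "single"):
            weights[shot] = weights.get(shot, 0.0) * 1.5
    shots, w = zip(*weights.items())
    return random.choices(shots, weights=w)[0]

pick_shot({"defend": 0.2, "single": 0.4, "boundary": 0.3, "six": 0.1}, wickets=8)
```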
| 46 |
|
| 47 |
+
### 2.2 Cricsheet Opponent
|
| 48 |
|
| 49 |
```bash
|
| 50 |
+
--opponent-mode cricsheet
|
|
|
|
|
|
|
|
|
|
| 51 |
```
|
| 52 |
|
| 53 |
+
Samples real Cricsheet ball-by-ball deliveries, indexed by `(phase, wickets_band, innings_type)`.
|
| 54 |
|
| 55 |
+
Auto-selects data by format:
|
| 56 |
+
- `max_overs ≤ 25` → `ball_outcomes_t20_v1.pkl` (1.17M T20 deliveries)
|
| 57 |
+
- `max_overs > 25` → `ball_outcomes_odi_v1.pkl` (1.65M ODI deliveries)
|
|
|
|
| 58 |
|
| 59 |
+
Progressive fallback: drop innings_type → drop wickets_band → any phase record.
|
| 60 |
+
**No API key needed.** Data files must be present under `data/processed/`.
|
| 61 |
|
| 62 |
+
### 2.3 Live LLM Opponent
|
| 63 |
|
| 64 |
```bash
|
| 65 |
+
export HF_TOKEN=hf_...
|
| 66 |
+
--opponent-mode llm_live
|
| 67 |
```
|
| 68 |
|
| 69 |
+
Calls `google/gemma-4-26B-A4B-it` via HF Router (or any OpenAI-compatible endpoint). Set `HF_TOKEN` or pass `--opponent-api-key`; otherwise use `--opponent-mode heuristic` for local runs.
|
| 70 |
|
| 71 |
+
Use for: demos, realistic opponent behavior, self-play experiments.
|
| 72 |
|
| 73 |
+
### 2.4 Cached LLM Opponent
|
|
|
|
|
|
|
| 74 |
|
| 75 |
+
```bash
|
| 76 |
+
--opponent-mode llm_cached
|
| 77 |
+
export CRICKET_OPPONENT_CACHE=data/opponent_cache/adaptive_t20_v1.jsonl
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
```
|
| 79 |
|
| 80 |
+
Replays pre-recorded decisions. Does **not** call any live model. Use for official/reproducible eval — every compared captain faces identical opponent decisions.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
|
| 82 |
+
## 3. Starting The Server
|
| 83 |
|
| 84 |
```bash
|
| 85 |
+
# Recommended (uvicorn auto-reload)
|
| 86 |
+
cd cricket_captain
|
| 87 |
+
python -m uvicorn server.app:app --port 8001
|
|
|
|
|
|
|
| 88 |
|
| 89 |
+
# Or via app.py directly
|
| 90 |
+
PYTHONPATH=. python server/app.py --port 8001
|
| 91 |
```
|
| 92 |
|
| 93 |
+
Health check: `curl http://localhost:8001/health` → `{"status":"healthy"}`
|
| 94 |
|
| 95 |
+
Set the URL for runners:
|
| 96 |
```bash
|
| 97 |
+
export CRICKET_CAPTAIN_ENV_URL=http://localhost:8001
|
|
|
|
| 98 |
```
|
| 99 |
|
| 100 |
+
On Lightning / remote runtimes, expose the port and pass the external URL:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
```bash
|
| 102 |
+
export CRICKET_CAPTAIN_ENV_URL=ws://<lightning-exposed-host>/ws
|
| 103 |
```
|
| 104 |
|
| 105 |
+
## 4. Step 1: Random Baseline (all 3 local modes)
|
| 106 |
+
|
| 107 |
+
No API key, no GPU needed. Verify the full loop works.
|
| 108 |
|
| 109 |
```bash
|
| 110 |
+
# Run all 3 modes in parallel
|
| 111 |
+
python inference.py --model random --episodes 5 --max-overs 5 --opponent-mode heuristic --env-url http://localhost:8001
|
| 112 |
+
python inference.py --model random --episodes 5 --max-overs 5 --opponent-mode cricsheet --env-url http://localhost:8001
|
| 113 |
+
python inference.py --model random --episodes 5 --max-overs 5 --opponent-mode llm_live --env-url http://localhost:8001
|
|
|
|
|
|
|
| 114 |
```
|
| 115 |
|
| 116 |
+
**Verified baselines (2026-04-25):**
|
| 117 |
|
| 118 |
+
| Opponent | score | coherence | reward | parse_err |
|
| 119 |
+
|----------|-------|-----------|--------|-----------|
|
| 120 |
+
| heuristic | 20.8 | 0.556 | −0.826 | 0% |
|
| 121 |
+
| cricsheet | 28.0 | 0.527 | −0.410 | 0% |
|
| 122 |
+
| llm_live | 27.4 | 0.537 | −0.723 | 0% |
|
| 123 |
|
| 124 |
+
## 5. Step 2: Train-Smoke (verify reward signals, no GPU)
|
| 125 |
|
| 126 |
+
`train.py train-smoke` runs direct `CricketEnvironment` rollouts — **no server needed**.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
|
| 128 |
```bash
|
| 129 |
+
python train.py train-smoke --matches 3 --max-overs 5 --opponent-mode heuristic
|
| 130 |
+
python train.py train-smoke --matches 3 --max-overs 5 --opponent-mode cricsheet
|
| 131 |
+
python train.py train-smoke --matches 3 --max-overs 5 --opponent-mode llm_live
|
|
|
|
|
|
|
|
|
|
| 132 |
```
|
| 133 |
|
| 134 |
+
**Verified train-smoke baselines (2026-04-25):** r_validity=1.0 on all 9 matches (3 modes × 3 matches). All reward signals active: coherence, adaptation, opponent_awareness, plan_commitment, staleness, regret.
|
| 135 |
|
| 136 |
+
Quick 2-over smoke for CI-style checks:
|
| 137 |
|
| 138 |
+
```bash
|
| 139 |
+
python train.py train-smoke --matches 1 --max-overs 2 --max-steps 240 --log-steps 90 --opponent-mode heuristic
|
| 140 |
```
|
| 141 |
|
| 142 |
+
## 6. Step 3: Untrained LLM Baseline (requires HF token)
|
| 143 |
|
| 144 |
```bash
|
| 145 |
+
export HF_TOKEN=hf_...
|
| 146 |
+
python inference.py \
|
| 147 |
+
--model google/gemma-4-26B-A4B-it \
|
| 148 |
+
--api-base https://router.huggingface.co/v1 \
|
| 149 |
+
--api-key "$HF_TOKEN" \
|
| 150 |
+
--episodes 3 --max-overs 5 \
|
| 151 |
+
--opponent-mode llm_live \
|
| 152 |
+
--env-url http://localhost:8001
|
| 153 |
```
|
| 154 |
|
| 155 |
+
**Verified LLM captain run (2026-04-25):**
|
| 156 |
+
```
|
| 157 |
+
model: google/gemma-4-26B-A4B-it via HF Router
|
| 158 |
+
coherence: 0.657 | adaptation: 0.502 | opp_aware: 0.750
|
| 159 |
+
parse errors: 0.0% | reward: +0.168
|
| 160 |
+
```
|
|
|
|
|
|
|
|
|
|
| 161 |
|
| 162 |
+
## 7. Step 4: SFT Tool Warmup
|
| 163 |
|
| 164 |
```bash
|
| 165 |
+
python train.py sft-data --output data/training/tool_sft_examples.jsonl --examples 500
|
|
|
|
|
|
|
|
|
|
|
|
|
| 166 |
```
|
| 167 |
|
| 168 |
+
Teaches tool-call shape, tool names, and argument schemas before RL. Not the main objective — just reduces wasted GRPO rollouts on syntax/tool-selection errors.
|
| 169 |
|
| 170 |
+
## 8. Step 5: GRPO Warmup (5-over format)
|
|
|
|
|
|
|
| 171 |
|
| 172 |
+
The two-stage curriculum was collapsed into a single-stage full-reward run because Qwen3.5-4B
|
| 173 |
+
already supports tool calling natively (XML+JSON both accepted). What used to be "Stage 1" is
|
| 174 |
+
now a fast 5-over warmup, controlled entirely from YAML:
|
| 175 |
|
| 176 |
```bash
|
| 177 |
+
PYTORCH_ALLOC_CONF=expandable_segments:True \
|
| 178 |
+
python train.py train --config configs/cricket_train_warmup.yaml
|
|
|
|
|
|
|
|
|
|
| 179 |
```
|
| 180 |
|
| 181 |
+
Config: `max_overs=5`, `steps=25`, `num_generations=8`, `batch_size=8`, `max_completion_length=3072`,
|
| 182 |
+
`save_steps=5`, full composite reward from step 0. Approx 1.5–2 hrs on H200.
|
| 183 |
|
| 184 |
+
Goal: bootstrap the LoRA adapter on cheap short matches before the longer 20-over run.
|
| 185 |
+
Watch in WandB: `reward/composite_mean`, `tools/freq_*`, `rollout/match_completion_rate`,
|
| 186 |
+
`completions/clipped_ratio`.
|
|
|
|
|
|
|
| 187 |
|
| 188 |
+
## 9. Step 6: GRPO Main Run (20-over T20)
|
| 189 |
|
| 190 |
```bash
|
| 191 |
+
PYTORCH_ALLOC_CONF=expandable_segments:True \
|
| 192 |
+
python train.py train --config configs/cricket_train.yaml
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
```
|
| 194 |
|
| 195 |
+
Config: `max_overs=20`, `steps=100`, `num_generations=4`, `batch_size=8`, `max_completion_length=3072`,
|
| 196 |
+
`save_steps=10`, `beta=0.0` (no reference model). Approx 15-18 min/step (~30-40 steps fit a 10-hr budget).
|
| 197 |
|
| 198 |
+
To **resume from the warmup adapter**, uncomment in `configs/cricket_train.yaml`:
|
| 199 |
+
```yaml
|
| 200 |
+
resume_from: ./checkpoints/stage2_final
|
| 201 |
```
|
| 202 |
+
or pass `--resume-from ./checkpoints/stage2_final` on the CLI. The base model still loads from
|
| 203 |
+
`Qwen/Qwen3.5-4B`; only the LoRA weights resume.
|
| 204 |
|
| 205 |
+
Goal: improve coherence, adaptation, opponent awareness, match outcomes on full T20s.
|
| 206 |
+
Watch in WandB: `reward/r_result_mean` (sparse outcome), `reward/r_coherence_mean`,
|
| 207 |
+
`reward/r_adaptation_mean`, `episode/tool_calls_mean` (should approach 720), `episode/agent_score_mean`.
|
| 208 |
|
| 209 |
+
Switch the opponent in YAML (`opponent.mode: llm_live`) and set `HF_TOKEN` for adversarial
|
| 210 |
+
training against live Gemma. Use `cricsheet` or `heuristic` for cheaper ablations.
|
| 211 |
|
| 212 |
+
`train.py train` uses TRL `GRPOTrainer(environment_factory=CricketCaptainToolEnv)`, so the model interacts with a live environment over multiple tool-calling turns. This is not inference through the HF Router; the trained captain model is loaded locally by Transformers/TRL, with LoRA adapters when using quantized weights.
|
|
|
|
|
|
|
|
|
|
| 213 |
|
| 214 |
+
## 10. Step 7: Evaluation
|
| 215 |
|
| 216 |
```bash
|
| 217 |
+
python eval.py \
|
| 218 |
--episodes 20 \
|
| 219 |
+
--env-url http://localhost:8001 \
|
| 220 |
+
--eval-pack-id adaptive_t20_v1 \
|
| 221 |
+
--opponent-mode llm_cached
|
| 222 |
```
|
| 223 |
|
| 224 |
+
Compare: random baseline → untrained Qwen3.5-4B → trained Qwen3.5-4B (warmup + main adapter via `compare_eval.py`).
|
| 225 |
|
| 226 |
+
## 11. Format Comparison
|
| 227 |
|
| 228 |
+
| max_overs | Format selected | Data used by cricsheet | Typical target |
|
| 229 |
+
|-----------|----------------|------------------------|---------------|
|
| 230 |
+
| 5 | T5 | T20 pkl (closest) | ~47 runs |
|
| 231 |
+
| 7 | T5 | T20 pkl | ~66 runs |
|
| 232 |
+
| 20 | T20 | T20 pkl | ~160 runs |
|
| 233 |
+
| 50 | ODI | ODI pkl | ~290 runs |
|
| 234 |
|
| 235 |
+
All formats work with all opponent modes. Use `--max-overs N` with any runner.
|
| 236 |
|
| 237 |
+
## 12. Tool budget and training
|
| 238 |
|
| 239 |
+
Implemented in `server/cricket_environment.py` (`TOOL_BUDGET_PER_OVER=3`, `TOOL_FINE_PER_EXCESS=0.04`).
|
| 240 |
|
| 241 |
+
**Overhead tools (only these increment the per-over counter):** `set_strategy`, `set_bowling_strategy`, `plan_delivery`, `reflect_after_ball`, `analyze_situation`.
|
| 242 |
|
| 243 |
+
**Not overhead:** `plan_shot`, `set_match_plan`, `update_match_plan`, `select_batter`, `choose_bowler`, `set_field_setting`, `play_delivery`, `bowl_delivery`, `call_toss`, and other tools that advance or directly execute the ball.
|
| 244 |
|
| 245 |
+
**Rule:** the first 3 overhead calls in each over are not fined; each further overhead call in that over adds **−0.04** to the step reward. The prompt includes `Tool budget: N/3 overhead calls used this over`.
|
| 246 |
|
| 247 |
+
**Why this matters for GRPO:** training uses the same environment as inference. Fines are part of the reward the trainer optimizes, so the policy learns to use reflection and `analyze_situation` when they matter, and to lean on `plan_shot` plus match-level planning (`set_match_plan` / `update_match_plan`) for routine structure without spending the 3 free overhead calls every ball.
|
| 248 |
|
| 249 |
+
## 13. Configs
|
| 250 |
|
| 251 |
+
```bash
|
| 252 |
+
# Start server with default config (llm_live opponent, 5-over default)
|
| 253 |
+
PYTHONPATH=. python server/app.py --port 8001 --config configs/default.yaml
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 254 |
|
| 255 |
+
# Start with reproducible cached eval
|
| 256 |
+
PYTHONPATH=. python server/app.py --port 8001 --config configs/cached_eval.yaml
|
| 257 |
+
```
|
| 258 |
|
| 259 |
+
Config controls: `env.agent_team`, `env.max_overs`, `env.eval_pack_id`, `train.model`, `train.max_completion_length`, `train.max_tool_calling_iterations`, `opponent.mode`, `opponent.model`, `opponent.api_base`, `captain.model`, and `captain.api_base`.
|
| 260 |
|
| 261 |
+
## 14. Latest Verified Run Results (2026-04-25)
|
| 262 |
|
| 263 |
+
All runs in `illustrations/`. Zero parse errors across all 14 inference runs. r_validity=1.0 across all train-smoke matches.
|
| 264 |
|
| 265 |
+
```
|
| 266 |
+
Random agent — 5-over, heuristic: score=20.8 coherence=0.556 reward=−0.826
|
| 267 |
+
Random agent — 5-over, cricsheet: score=28.0 coherence=0.527 reward=−0.410
|
| 268 |
+
Random agent — 5-over, llm_live: score=27.4 coherence=0.537 reward=−0.723
|
| 269 |
+
Random agent — 20-over, cricsheet: score=63.6 coherence=0.568 reward=−5.632
|
| 270 |
+
Random agent — 20-over, heuristic: score=82.4 coherence=0.545 reward=−8.174
|
| 271 |
+
|
| 272 |
+
Train-smoke — 5-over, heuristic: r_validity=1.0 coherence=0.596 3W/0L
|
| 273 |
+
Train-smoke — 5-over, cricsheet: r_validity=1.0 coherence=0.620 2W/1L
|
| 274 |
+
Train-smoke — 5-over, llm_live: r_validity=1.0 coherence=0.552 2W/1L
|
| 275 |
+
|
| 276 |
+
LLM captain (gemma-4-26B) — 3-over, llm_live:
|
| 277 |
+
coherence=0.657 adaptation=0.502 opp_aware=0.750 parse_err=0%
|
| 278 |
```
|
| 279 |
|
| 280 |
+
## 15. Immediate Next Steps
|
| 281 |
|
| 282 |
+
1. **Run GRPO training** via `bash scripts/run_warmup_then_main.sh` (warmup curriculum + main run + auto-resume) to produce reward curves.
|
| 283 |
+
2. **Deploy HF Space** for live Gradio demo — Dockerfile present, just needs HF push.
|
| 284 |
+
3. **Generate opponent cache** using `llm_live` for reproducible official eval.
|
| 285 |
+
4. **Produce training plots** — coherence heatmap, reward curve, tool-usage timeline.
|
docs/slides.html
CHANGED
|
@@ -350,35 +350,36 @@
|
|
| 350 |
<tr>
|
| 351 |
<th>Rubric</th><th>Weight</th><th>Frequency</th><th>Measures</th><th>Key Sub-signals</th>
|
| 352 |
</tr>
|
| 353 |
-
<tr>
|
| 354 |
-
<td><code>r_result</code></td>
|
| 355 |
-
<td><strong>55%</strong></td>
|
| 356 |
-
<td>Episode end</td>
|
| 357 |
-
<td>Win/loss vs DLS par, target margin</td>
|
| 358 |
-
<td>score/par, wickets_remaining, lead/deficit</td>
|
| 359 |
-
</tr>
|
| 360 |
<tr>
|
| 361 |
<td><code>r_cricket</code></td>
|
| 362 |
-
<td><strong>
|
| 363 |
-
<td>
|
| 364 |
<td>Dream11 proxy: runs, wickets, milestones</td>
|
| 365 |
-
<td>dot%, boundary%, 50s/100s, maiden overs</td>
|
| 366 |
</tr>
|
| 367 |
<tr>
|
| 368 |
<td><code>r_behavior</code></td>
|
| 369 |
-
<td><strong>
|
| 370 |
-
<td>Every
|
| 371 |
<td>Declaration-execution alignment</td>
|
| 372 |
<td>coherence (50%) + adaptation (20%) + opponent_awareness (20%) + regret (10%)</td>
|
| 373 |
</tr>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 374 |
<tr>
|
| 375 |
<td><code>r_validity</code></td>
|
| 376 |
-
<td><strong>
|
| 377 |
<td>Every turn</td>
|
| 378 |
-
<td>Parseable JSON tool call</td>
|
| 379 |
<td>Format gate; 0 = parse fail, 1 = valid</td>
|
| 380 |
</tr>
|
| 381 |
</table>
|
|
|
|
| 382 |
<div class="two-col" style="margin-top:18px;">
|
| 383 |
<div>
|
| 384 |
<h3>Coherence Score Formula (per delivery)</h3>
|
|
@@ -389,12 +390,12 @@
|
|
| 389 |
)</pre>
|
| 390 |
</div>
|
| 391 |
<div>
|
| 392 |
-
<h3>
|
| 393 |
<ul>
|
| 394 |
-
<li><strong>
|
| 395 |
-
<li><strong>
|
| 396 |
-
<li>
|
| 397 |
-
<li>GRPO group size =
|
| 398 |
</ul>
|
| 399 |
</div>
|
| 400 |
</div>
|
|
@@ -445,15 +446,19 @@ python inference.py \
|
|
| 445 |
--config configs/default.yaml \
|
| 446 |
--max-overs 3 --opponent-mode llm_live
|
| 447 |
|
| 448 |
-
<span class="dim"># 4.
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
--
|
|
|
|
|
|
|
|
|
|
|
|
|
| 457 |
<div class="wn" style="font-size:0.84rem;">
|
| 458 |
All model / API / env settings live in <code>configs/default.yaml</code>. Zero hardcoding.
|
| 459 |
</div>
|
|
@@ -489,8 +494,8 @@ python train.py train \
|
|
| 489 |
<div>
|
| 490 |
<h3>What training should produce (target)</h3>
|
| 491 |
<ul>
|
| 492 |
-
<li>r_validity: 0.70 → 0.98+ after
|
| 493 |
-
<li>Coherence: ~0.52 (random) → 0.75+ after
|
| 494 |
<li>analyze_situation calls cluster at over 6, 16, 36 transitions</li>
|
| 495 |
<li>Strategy declarations become more specific (>15 word rationales)</li>
|
| 496 |
<li>Shot choices match declared aggression level >80% of deliveries</li>
|
|
@@ -570,7 +575,7 @@ python train.py train \
|
|
| 570 |
<div>
|
| 571 |
<h3>🔴 Critical Path (on-site, Day 1–2)</h3>
|
| 572 |
<ul>
|
| 573 |
-
<li>Run Colab notebook
|
| 574 |
<li>Export: reward_curves.png, coherence_heatmap.png, tool_timeline.png</li>
|
| 575 |
<li>Deploy to HuggingFace Spaces → live interactive Gradio demo URL</li>
|
| 576 |
<li>Add HF Space URL + plot images to README</li>
|
|
|
|
| 350 |
<tr>
|
| 351 |
<th>Rubric</th><th>Weight</th><th>Frequency</th><th>Measures</th><th>Key Sub-signals</th>
|
| 352 |
</tr>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 353 |
<tr>
|
| 354 |
<td><code>r_cricket</code></td>
|
| 355 |
+
<td><strong>45%</strong></td>
|
| 356 |
+
<td>Per ball</td>
|
| 357 |
<td>Dream11 proxy: runs, wickets, milestones</td>
|
| 358 |
+
<td>dot%, boundary%, 50s/100s, maiden overs, economy</td>
|
| 359 |
</tr>
|
| 360 |
<tr>
|
| 361 |
<td><code>r_behavior</code></td>
|
| 362 |
+
<td><strong>25%</strong></td>
|
| 363 |
+
<td>Every turn</td>
|
| 364 |
<td>Declaration-execution alignment</td>
|
| 365 |
<td>coherence (50%) + adaptation (20%) + opponent_awareness (20%) + regret (10%)</td>
|
| 366 |
</tr>
|
| 367 |
+
<tr>
|
| 368 |
+
<td><code>r_result</code></td>
|
| 369 |
+
<td><strong>20%</strong></td>
|
| 370 |
+
<td>Innings/episode end</td>
|
| 371 |
+
<td>Win/loss vs DLS par, target margin</td>
|
| 372 |
+
<td>score/par, wickets_remaining, lead/deficit, +0.25 progress bonus</td>
|
| 373 |
+
</tr>
|
| 374 |
<tr>
|
| 375 |
<td><code>r_validity</code></td>
|
| 376 |
+
<td><strong>10%</strong></td>
|
| 377 |
<td>Every turn</td>
|
| 378 |
+
<td>Parseable XML/JSON tool call</td>
|
| 379 |
<td>Format gate; 0 = parse fail, 1 = valid</td>
|
| 380 |
</tr>
|
| 381 |
</table>
|
| 382 |
+
<p style="margin-top:8px;font-size:0.9em;color:#888">Rebalanced from 55/25/15/5 → 45/25/20/10 to match the SWE-RL recipe (60% intermediate / 40% terminal). Reasoning: partial-trajectory training rarely fires <code>r_result</code>; weighting it 55% wastes gradient on a near-zero signal.</p>
|
| 383 |
<div class="two-col" style="margin-top:18px;">
|
| 384 |
<div>
|
| 385 |
<h3>Coherence Score Formula (per delivery)</h3>
|
|
|
|
| 390 |
)</pre>
|
| 391 |
</div>
|
| 392 |
<div>
|
| 393 |
+
<h3>Single-Stage Training + Format Curriculum</h3>
|
| 394 |
<ul>
|
| 395 |
+
<li><strong>Warmup (5-over curriculum):</strong> per-scenario <code>max_overs</code> sampled from [2,2,2,2,2,3,3,3,4,4,5] so episodes complete in budget and <code>r_result</code> can fire</li>
|
| 396 |
+
<li><strong>Main run (full T20):</strong> resumes warmup adapter, trains on target eval distribution</li>
|
| 397 |
+
<li>Qwen3.5-4B already does tool calling natively (XML+JSON) — no Stage 1 SFT needed</li>
|
| 398 |
+
<li>GRPO group size = 4; full episode advantages (TRL <code>environment_factory</code>)</li>
|
| 399 |
</ul>
|
| 400 |
</div>
|
| 401 |
</div>
|
|
|
|
| 446 |
--config configs/default.yaml \
|
| 447 |
--max-overs 3 --opponent-mode llm_live
|
| 448 |
|
| 449 |
+
<span class="dim"># 4. Warmup → Main chained run (auto-resumes adapter)</span>
|
| 450 |
+
bash scripts/run_warmup_then_main.sh
|
| 451 |
+
|
| 452 |
+
<span class="dim"># 5. Eval: untrained vs trained head-to-head</span>
|
| 453 |
+
python compare_eval.py --model Qwen/Qwen3.5-4B \
|
| 454 |
+
--label baseline --episodes 20 --max-overs 5 \
|
| 455 |
+
--output eval_results/baseline.json
|
| 456 |
+
python compare_eval.py --model Qwen/Qwen3.5-4B \
|
| 457 |
+
--adapter ./checkpoints/stage2_final \
|
| 458 |
+
--label trained --episodes 20 --max-overs 5 \
|
| 459 |
+
--output eval_results/trained.json
|
| 460 |
+
python compare_eval.py --compare \
|
| 461 |
+
eval_results/baseline.json eval_results/trained.json</pre>
|
| 462 |
<div class="wn" style="font-size:0.84rem;">
|
| 463 |
All model / API / env settings live in <code>configs/default.yaml</code>. Zero hardcoding.
|
| 464 |
</div>
|
|
|
|
| 494 |
<div>
|
| 495 |
<h3>What training should produce (target)</h3>
|
| 496 |
<ul>
|
| 497 |
+
<li>r_validity: 0.70 → 0.98+ after warmup (25 steps)</li>
|
| 498 |
+
<li>Coherence: ~0.52 (random) → 0.75+ after main run</li>
|
| 499 |
<li>analyze_situation calls cluster at over 6, 16, 36 transitions</li>
|
| 500 |
<li>Strategy declarations become more specific (>15 word rationales)</li>
|
| 501 |
<li>Shot choices match declared aggression level >80% of deliveries</li>
|
|
|
|
| 575 |
<div>
|
| 576 |
<h3>🔴 Critical Path (on-site, Day 1–2)</h3>
|
| 577 |
<ul>
|
| 578 |
+
<li>Run Colab notebook (notebooks/colab_train_minimal.ipynb) → warmup → main chained training</li>
|
| 579 |
<li>Export: reward_curves.png, coherence_heatmap.png, tool_timeline.png</li>
|
| 580 |
<li>Deploy to HuggingFace Spaces → live interactive Gradio demo URL</li>
|
| 581 |
<li>Add HF Space URL + plot images to README</li>
|
openenv.yaml
CHANGED
|
@@ -6,23 +6,24 @@ app: server.app:app
|
|
| 6 |
port: 8000
|
| 7 |
|
| 8 |
description: >
|
| 9 |
-
CricketCaptain-LLM
|
| 10 |
-
|
| 11 |
-
The agent uses 14 tools (toss, match
|
| 12 |
-
|
| 13 |
-
|
|
|
|
|
|
|
| 14 |
Two-sided: a live or heuristic LLM opponent plays the opposing team.
|
| 15 |
-
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
tasks:
|
| 18 |
-
- name: stage1_format
|
| 19 |
-
description: "5-over mini-match. r_validity only — teaches valid JSON tool-call structure (ToolRL Stage 1)."
|
| 20 |
-
difficulty: easy
|
| 21 |
-
|
| 22 |
- name: stage2_full
|
| 23 |
-
description: "
|
| 24 |
difficulty: medium
|
| 25 |
|
| 26 |
- name: eval_50over
|
| 27 |
-
description: "Full 50-over ODI. Evaluation benchmark — measures trained
|
| 28 |
difficulty: hard
|
|
|
|
| 6 |
port: 8000
|
| 7 |
|
| 8 |
description: >
|
| 9 |
+
CricketCaptain-LLM is a multi-turn agentic RL benchmark: train an LLM to captain
|
| 10 |
+
a cricket match end-to-end, alternating between batting and bowling phases across
|
| 11 |
+
~180 sequential tool calls. The agent uses 14 phase-gated tools (toss, match plan,
|
| 12 |
+
batting, bowling, fielding, reflection, analysis) and is scored by a composite
|
| 13 |
+
4-rubric reward — Dream11-style per-ball cricket-position signal (45%), per-turn
|
| 14 |
+
behavioral coherence/adaptation/opponent-awareness/regret (25%), match outcome
|
| 15 |
+
with DLS par + win bonus + progress bonus (20%), and tool-call validity (10%).
|
| 16 |
Two-sided: a live or heuristic LLM opponent plays the opposing team.
|
| 17 |
+
Real Markov outcome engine trained on 1.65M cricsheet deliveries.
|
| 18 |
+
Single-stage GRPO training with format-length curriculum (T2 → T5 in warmup,
|
| 19 |
+
full T20 in main run). Partial-trajectory training generalizes to full match
|
| 20 |
+
completion at eval time (same recipe as SWE-RL coding agents).
|
| 21 |
|
| 22 |
tasks:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
- name: stage2_full
|
| 24 |
+
description: "Full T20 match (configurable max_overs 2-50). All 4 rubrics active. Composite reward: r_cricket (45%) + r_behavior (25%) + r_result (20%) + r_validity (10%)."
|
| 25 |
difficulty: medium
|
| 26 |
|
| 27 |
- name: eval_50over
|
| 28 |
+
description: "Full 50-over ODI. Evaluation benchmark — measures trained captaincy across the longest format (DLS par + chase pressure)."
|
| 29 |
difficulty: hard
|
pyproject.toml
CHANGED
|
@@ -15,15 +15,26 @@ dependencies = [
|
|
| 15 |
"openai>=1.0.0",
|
| 16 |
]
|
| 17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
[project.optional-dependencies]
|
| 19 |
train = [
|
| 20 |
-
"
|
| 21 |
-
"transformers
|
|
|
|
|
|
|
|
|
|
| 22 |
"accelerate>=1.0.0",
|
| 23 |
"datasets>=4.0.0",
|
| 24 |
"bitsandbytes>=0.43.0",
|
| 25 |
-
"
|
| 26 |
-
|
|
|
|
| 27 |
]
|
| 28 |
eval = [
|
| 29 |
"matplotlib>=3.8.0",
|
|
|
|
| 15 |
"openai>=1.0.0",
|
| 16 |
]
|
| 17 |
|
| 18 |
+
# Training extras — these are the versions that actually reconcile in 2026:
|
| 19 |
+
# transformers 5.6.2 ─┐
|
| 20 |
+
# trl 1.2.0 ├─ TRL multi-turn environment_factory needs transformers >=5.2,
|
| 21 |
+
# vllm 0.19.1 ┘ vLLM 0.19+ supports transformers 5, vLLM 0.18 does not.
|
| 22 |
+
# Earlier we tried vllm 0.11.x — it pinned transformers <5 and broke environment_factory.
|
| 23 |
+
# mergekit removed: pinned pydantic <2.11 which conflicts with openenv-core 0.2.3 (>=2.11.7).
|
| 24 |
+
# Not used by training anyway.
|
| 25 |
[project.optional-dependencies]
|
| 26 |
train = [
|
| 27 |
+
"torch==2.10.0",
|
| 28 |
+
"transformers==5.6.2",
|
| 29 |
+
"trl==1.2.0",
|
| 30 |
+
"vllm==0.19.1",
|
| 31 |
+
"peft>=0.13.0,<0.20.0",
|
| 32 |
"accelerate>=1.0.0",
|
| 33 |
"datasets>=4.0.0",
|
| 34 |
"bitsandbytes>=0.43.0",
|
| 35 |
+
"wandb>=0.16",
|
| 36 |
+
# flash-attn is optional — vLLM has its own attention backends; uncomment if you want it:
|
| 37 |
+
# "flash-attn>=2.5.0",
|
| 38 |
]
|
| 39 |
eval = [
|
| 40 |
"matplotlib>=3.8.0",
|
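A tiny sanity check (a sketch, assuming the [train] extras are already installed) to confirm the pinned stack resolved as intended before launching a long run:

```python
from importlib.metadata import version

# The combination the comment above calls out: transformers 5.x needs vllm >= 0.19,
# and TRL's multi-turn environment_factory needs transformers >= 5.2.
PINS = {"torch": "2.10.0", "transformers": "5.6.2", "trl": "1.2.0", "vllm": "0.19.1"}
for pkg, pin in PINS.items():
    installed = version(pkg)
    status = "OK" if installed == pin else f"MISMATCH (pinned {pin})"
    print(f"{pkg:12s} {installed:12s} {status}")
```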
scripts/eval_all_checkpoints.sh
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
# Evaluate every saved checkpoint and find the best one.
|
| 3 |
+
#
|
| 4 |
+
# RL training is bumpy — the FINAL checkpoint isn't always the best policy.
|
| 5 |
+
# This script runs compare_eval.py against the baseline AND each checkpoint
|
| 6 |
+
# in ./checkpoints/, then picks whichever has the highest mean composite
|
| 7 |
+
# reward to be the "best trained" submission.
|
| 8 |
+
#
|
| 9 |
+
# Usage:
|
| 10 |
+
# bash scripts/eval_all_checkpoints.sh # 10 episodes each
|
| 11 |
+
# EVAL_EPISODES=20 bash scripts/eval_all_checkpoints.sh # more confidence
|
| 12 |
+
|
| 13 |
+
cd "$(dirname "$0")/.."
|
| 14 |
+
mkdir -p eval_results
|
| 15 |
+
|
| 16 |
+
EVAL_EPISODES="${EVAL_EPISODES:-10}"
|
| 17 |
+
EVAL_MAX_OVERS="${EVAL_MAX_OVERS:-5}"
|
| 18 |
+
BASE_MODEL="${BASE_MODEL:-Qwen/Qwen3.5-4B}"
|
| 19 |
+
|
| 20 |
+
# 1. Baseline (untrained)
|
| 21 |
+
if [ ! -f eval_results/baseline.json ]; then
|
| 22 |
+
echo "=== [baseline] running ==="
|
| 23 |
+
python compare_eval.py \
|
| 24 |
+
--model "$BASE_MODEL" \
|
| 25 |
+
--label baseline \
|
| 26 |
+
--episodes "$EVAL_EPISODES" \
|
| 27 |
+
--max-overs "$EVAL_MAX_OVERS" \
|
| 28 |
+
--opponent-mode heuristic \
|
| 29 |
+
--output eval_results/baseline.json
|
| 30 |
+
fi
|
| 31 |
+
|
| 32 |
+
# 2. Each checkpoint
|
| 33 |
+
for ckpt in ./checkpoints/stage*/checkpoint-* ./checkpoints/stage*_final; do
|
| 34 |
+
if [ -d "$ckpt" ]; then
|
| 35 |
+
# Verify it's a PEFT adapter dir (has adapter_config.json)
|
| 36 |
+
if [ ! -f "$ckpt/adapter_config.json" ]; then
|
| 37 |
+
echo "=== [$ckpt] no adapter_config.json — skip ==="
|
| 38 |
+
continue
|
| 39 |
+
fi
|
| 40 |
+
# Use a safe filename derived from path
|
| 41 |
+
label=$(echo "$ckpt" | sed 's|[/.]|_|g' | sed 's|^_||')
|
| 42 |
+
out="eval_results/${label}.json"
|
| 43 |
+
if [ -f "$out" ]; then
|
| 44 |
+
echo "=== [$ckpt] already evaluated, reading $out ==="
|
| 45 |
+
continue
|
| 46 |
+
fi
|
| 47 |
+
echo "=== [$ckpt] evaluating ==="
|
| 48 |
+
python compare_eval.py \
|
| 49 |
+
--model "$BASE_MODEL" \
|
| 50 |
+
--adapter "$ckpt" \
|
| 51 |
+
--label "$label" \
|
| 52 |
+
--episodes "$EVAL_EPISODES" \
|
| 53 |
+
--max-overs "$EVAL_MAX_OVERS" \
|
| 54 |
+
--opponent-mode heuristic \
|
| 55 |
+
--output "$out"
|
| 56 |
+
fi
|
| 57 |
+
done
|
| 58 |
+
|
| 59 |
+
# 3. Pick the best one and run a final compare against baseline
|
| 60 |
+
echo ""
|
| 61 |
+
echo "=== picking best checkpoint by mean composite reward ==="
|
| 62 |
+
python - <<'PYEOF'
|
| 63 |
+
import json, glob
|
| 64 |
+
|
| 65 |
+
best = None
|
| 66 |
+
for path in glob.glob("eval_results/*.json"):
|
| 67 |
+
if path.endswith("baseline.json"):
|
| 68 |
+
continue
|
| 69 |
+
try:
|
| 70 |
+
with open(path) as f:
|
| 71 |
+
data = json.load(f)
|
| 72 |
+
score = data["summary"].get("mean_composite_reward", 0.0) or 0.0
|
| 73 |
+
win_rate = data["summary"].get("win_rate_overall", 0.0) or 0.0
|
| 74 |
+
# Composite ranking: composite reward + 0.5 * win_rate
|
| 75 |
+
composite_score = score + 0.5 * win_rate
|
| 76 |
+
print(f" {path:50s} composite={score:.4f} win_rate={win_rate:.4f} ranking={composite_score:.4f}")
|
| 77 |
+
if best is None or composite_score > best[0]:
|
| 78 |
+
best = (composite_score, path, data)
|
| 79 |
+
except Exception as e:
|
| 80 |
+
print(f" {path}: skip ({e})")
|
| 81 |
+
|
| 82 |
+
if best:
|
| 83 |
+
print(f"\nBEST: {best[1]} (composite_score={best[0]:.4f}, label={best[2].get('label')})")
|
| 84 |
+
print(f" adapter: {best[2].get('adapter')}")
|
| 85 |
+
|
| 86 |
+
# Save a symlink/copy as 'best.json' for easy reference
|
| 87 |
+
import shutil
|
| 88 |
+
shutil.copy(best[1], "eval_results/best_trained.json")
|
| 89 |
+
print(f"\nCopied to eval_results/best_trained.json")
|
| 90 |
+
PYEOF
|
| 91 |
+
|
| 92 |
+
echo ""
|
| 93 |
+
echo "=== final comparison: baseline vs best trained ==="
|
| 94 |
+
python compare_eval.py --compare eval_results/baseline.json eval_results/best_trained.json \
|
| 95 |
+
| tee eval_results/final_comparison.txt
|
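Once the script has run, the winning checkpoint can be inspected directly; a minimal sketch assuming the JSON layout written above (label, adapter, summary):

```python
import json

with open("eval_results/best_trained.json") as f:
    best = json.load(f)

summary = best["summary"]
print("label:  ", best.get("label"))
print("adapter:", best.get("adapter"))
print("mean composite reward:", summary.get("mean_composite_reward", 0.0))
print("overall win rate:     ", summary.get("win_rate_overall", 0.0))
```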
scripts/generate_training_plots.py
ADDED
|
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Generate labeled PNG plots for the README from a WandB run OR from local
|
| 3 |
+
episode_stats.jsonl files.
|
| 4 |
+
|
| 5 |
+
Usage:
|
| 6 |
+
# From a WandB run id (preferred — uses the per-step rebalanced metrics)
|
| 7 |
+
python scripts/generate_training_plots.py \\
|
| 8 |
+
--wandb-run ptnv-s-research/huggingface/<RUN_ID> \\
|
| 9 |
+
--output-dir docs/plots/
|
| 10 |
+
|
| 11 |
+
# From local episode_stats.jsonl (faster, no API call)
|
| 12 |
+
python scripts/generate_training_plots.py \\
|
| 13 |
+
--jsonl logs/run_*/episode_stats.jsonl \\
|
| 14 |
+
--output-dir docs/plots/
|
| 15 |
+
|
| 16 |
+
Generates (with axis labels + units):
|
| 17 |
+
docs/plots/training_reward_over_steps.png
|
| 18 |
+
docs/plots/per_rubric_breakdown.png
|
| 19 |
+
docs/plots/tool_call_frequency.png
|
| 20 |
+
docs/plots/match_completion_rate.png
|
| 21 |
+
docs/plots/before_after_comparison.png (if --compare given)
|
| 22 |
+
"""
|
| 23 |
+
import argparse
|
| 24 |
+
import glob
|
| 25 |
+
import json
|
| 26 |
+
import os
|
| 27 |
+
from pathlib import Path
|
| 28 |
+
from typing import Any
|
| 29 |
+
|
| 30 |
+
import matplotlib
|
| 31 |
+
matplotlib.use("Agg") # headless
|
| 32 |
+
import matplotlib.pyplot as plt
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _load_jsonl(path: str) -> list[dict[str, Any]]:
|
| 36 |
+
rows = []
|
| 37 |
+
paths = glob.glob(path) if "*" in path else [path]
|
| 38 |
+
for p in paths:
|
| 39 |
+
with open(p) as f:
|
| 40 |
+
for line in f:
|
| 41 |
+
line = line.strip()
|
| 42 |
+
if line:
|
| 43 |
+
try:
|
| 44 |
+
rows.append(json.loads(line))
|
| 45 |
+
except json.JSONDecodeError:
|
| 46 |
+
continue
|
| 47 |
+
return rows
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _load_wandb(run_path: str) -> tuple[list[dict[str, Any]], dict[str, Any]]:
|
| 51 |
+
"""Returns (history, config). Requires `pip install wandb` and login."""
|
| 52 |
+
try:
|
| 53 |
+
import wandb
|
| 54 |
+
except ImportError:
|
| 55 |
+
raise RuntimeError("wandb not installed. pip install wandb")
|
| 56 |
+
api = wandb.Api()
|
| 57 |
+
run = api.run(run_path)
|
| 58 |
+
history = list(run.history(samples=10000))
|
| 59 |
+
return history, run.config
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def plot_training_reward(history, out_dir: Path, label: str):
|
| 63 |
+
steps, rewards = [], []
|
| 64 |
+
for row in history:
|
| 65 |
+
if "rewards/environment_reward/mean" in row and row["rewards/environment_reward/mean"] is not None:
|
| 66 |
+
steps.append(row.get("_step", row.get("step", len(steps))))
|
| 67 |
+
rewards.append(row["rewards/environment_reward/mean"])
|
| 68 |
+
if not rewards:
|
| 69 |
+
print(" no environment_reward/mean found, skipping")
|
| 70 |
+
return
|
| 71 |
+
fig, ax = plt.subplots(figsize=(8, 4.5))
|
| 72 |
+
ax.plot(steps, rewards, marker="o", linewidth=1.5, markersize=4, color="#0066cc")
|
| 73 |
+
ax.set_xlabel("Training step (gradient updates)")
|
| 74 |
+
ax.set_ylabel("Mean environment reward (composite)")
|
| 75 |
+
ax.set_title(f"GRPO training reward over time — {label}")
|
| 76 |
+
ax.grid(alpha=0.3)
|
| 77 |
+
fig.tight_layout()
|
| 78 |
+
out_path = out_dir / "training_reward_over_steps.png"
|
| 79 |
+
fig.savefig(out_path, dpi=130)
|
| 80 |
+
plt.close(fig)
|
| 81 |
+
print(f" → {out_path}")
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def plot_per_rubric_breakdown(history, out_dir: Path, label: str):
|
| 85 |
+
"""Plot the per-step means of all 4 rubrics on one axes."""
|
| 86 |
+
rubrics = ("reward/composite_mean", "reward/r_result_mean",
|
| 87 |
+
"reward/r_cricket_mean", "reward/r_behavior_mean",
|
| 88 |
+
"reward/r_validity_mean")
|
| 89 |
+
series = {r: [] for r in rubrics}
|
| 90 |
+
steps_per = {r: [] for r in rubrics}
|
| 91 |
+
for row in history:
|
| 92 |
+
for r in rubrics:
|
| 93 |
+
if r in row and row[r] is not None:
|
| 94 |
+
series[r].append(row[r])
|
| 95 |
+
steps_per[r].append(row.get("_step", row.get("step", len(series[r]))))
|
| 96 |
+
if not any(series.values()):
|
| 97 |
+
print(" no per-rubric metrics found, skipping")
|
| 98 |
+
return
|
| 99 |
+
fig, ax = plt.subplots(figsize=(9, 5))
|
| 100 |
+
colors = {"reward/composite_mean": "#000",
|
| 101 |
+
"reward/r_result_mean": "#cc0000",
|
| 102 |
+
"reward/r_cricket_mean": "#0066cc",
|
| 103 |
+
"reward/r_behavior_mean": "#009900",
|
| 104 |
+
"reward/r_validity_mean": "#9900cc"}
|
| 105 |
+
for r in rubrics:
|
| 106 |
+
if series[r]:
|
| 107 |
+
ax.plot(steps_per[r], series[r], marker="o", markersize=3, linewidth=1.3,
|
| 108 |
+
label=r.replace("reward/", "").replace("_mean", ""),
|
| 109 |
+
color=colors[r])
|
| 110 |
+
ax.set_xlabel("Training step (gradient updates)")
|
| 111 |
+
ax.set_ylabel("Mean reward")
|
| 112 |
+
ax.set_title(f"Per-rubric reward breakdown — {label}")
|
| 113 |
+
ax.legend(loc="best", fontsize=9)
|
| 114 |
+
ax.grid(alpha=0.3)
|
| 115 |
+
fig.tight_layout()
|
| 116 |
+
out_path = out_dir / "per_rubric_breakdown.png"
|
| 117 |
+
fig.savefig(out_path, dpi=130)
|
| 118 |
+
plt.close(fig)
|
| 119 |
+
print(f" → {out_path}")
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def plot_tool_call_frequency(history, out_dir: Path, label: str):
|
| 123 |
+
steps, freq = [], []
|
| 124 |
+
for row in history:
|
| 125 |
+
if "tools/call_frequency" in row and row["tools/call_frequency"] is not None:
|
| 126 |
+
steps.append(row.get("_step", row.get("step", len(steps))))
|
| 127 |
+
freq.append(row["tools/call_frequency"])
|
| 128 |
+
if not freq:
|
| 129 |
+
print(" no tools/call_frequency found, skipping")
|
| 130 |
+
return
|
| 131 |
+
fig, ax = plt.subplots(figsize=(8, 4.5))
|
| 132 |
+
ax.plot(steps, freq, marker="o", linewidth=1.5, markersize=4, color="#cc6600")
|
| 133 |
+
ax.set_xlabel("Training step (gradient updates)")
|
| 134 |
+
ax.set_ylabel("Mean tool calls per rollout")
|
| 135 |
+
ax.set_title(f"Tool-call execution frequency (proxy for match progress) — {label}")
|
| 136 |
+
ax.grid(alpha=0.3)
|
| 137 |
+
fig.tight_layout()
|
| 138 |
+
out_path = out_dir / "tool_call_frequency.png"
|
| 139 |
+
fig.savefig(out_path, dpi=130)
|
| 140 |
+
plt.close(fig)
|
| 141 |
+
print(f" → {out_path}")
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def plot_completion_rate(history, out_dir: Path, label: str):
|
| 145 |
+
steps, rate = [], []
|
| 146 |
+
for row in history:
|
| 147 |
+
if "rollout/match_completion_rate" in row and row["rollout/match_completion_rate"] is not None:
|
| 148 |
+
steps.append(row.get("_step", row.get("step", len(steps))))
|
| 149 |
+
rate.append(row["rollout/match_completion_rate"])
|
| 150 |
+
if not rate:
|
| 151 |
+
print(" no match_completion_rate found, skipping")
|
| 152 |
+
return
|
| 153 |
+
fig, ax = plt.subplots(figsize=(8, 4.5))
|
| 154 |
+
ax.plot(steps, rate, marker="o", linewidth=1.5, markersize=4, color="#009966")
|
| 155 |
+
ax.set_xlabel("Training step (gradient updates)")
|
| 156 |
+
ax.set_ylabel("Match completion rate")
|
| 157 |
+
ax.set_ylim(0, 1.05)
|
| 158 |
+
ax.set_title(f"Fraction of rollouts that completed the full match — {label}")
|
| 159 |
+
ax.grid(alpha=0.3)
|
| 160 |
+
fig.tight_layout()
|
| 161 |
+
out_path = out_dir / "match_completion_rate.png"
|
| 162 |
+
fig.savefig(out_path, dpi=130)
|
| 163 |
+
plt.close(fig)
|
| 164 |
+
print(f" → {out_path}")
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def plot_before_after(baseline_json: str, trained_json: str, out_dir: Path):
|
| 168 |
+
"""Bar chart comparing baseline vs trained on key eval metrics."""
|
| 169 |
+
with open(baseline_json) as f:
|
| 170 |
+
b = json.load(f)
|
| 171 |
+
with open(trained_json) as f:
|
| 172 |
+
t = json.load(f)
|
| 173 |
+
bs, ts = b["summary"], t["summary"]
|
| 174 |
+
metrics = [
|
| 175 |
+
("match_completion_rate", "Match\ncompletion rate"),
|
| 176 |
+
("win_rate_overall", "Overall\nwin rate"),
|
| 177 |
+
("mean_validity_rate", "Mean\nvalidity rate"),
|
| 178 |
+
("mean_composite_reward", "Mean composite\nreward (scaled)"),
|
| 179 |
+
]
|
| 180 |
+
bvals = [bs.get(k, 0) or 0 for k, _ in metrics]
|
| 181 |
+
tvals = [ts.get(k, 0) or 0 for k, _ in metrics]
|
| 182 |
+
labels = [lbl for _, lbl in metrics]
|
| 183 |
+
|
| 184 |
+
x = range(len(metrics))
|
| 185 |
+
fig, ax = plt.subplots(figsize=(9, 5))
|
| 186 |
+
width = 0.35
|
| 187 |
+
bars_b = ax.bar([xi - width/2 for xi in x], bvals, width, label="baseline (untrained)", color="#999")
|
| 188 |
+
bars_t = ax.bar([xi + width/2 for xi in x], tvals, width, label="trained (LoRA r=64)", color="#0066cc")
|
| 189 |
+
|
| 190 |
+
for bars in (bars_b, bars_t):
|
| 191 |
+
for bar in bars:
|
| 192 |
+
h = bar.get_height()
|
| 193 |
+
ax.text(bar.get_x() + bar.get_width()/2, h + 0.01,
|
| 194 |
+
f"{h:.2f}", ha="center", fontsize=8)
|
| 195 |
+
|
| 196 |
+
ax.set_xticks(list(x))
|
| 197 |
+
ax.set_xticklabels(labels)
|
| 198 |
+
ax.set_ylabel("Metric value")
|
| 199 |
+
ax.set_title(f"Before vs After training — {bs['n_episodes']} eval matches each")
|
| 200 |
+
ax.legend()
|
| 201 |
+
ax.grid(axis="y", alpha=0.3)
|
| 202 |
+
fig.tight_layout()
|
| 203 |
+
out_path = out_dir / "before_after_comparison.png"
|
| 204 |
+
fig.savefig(out_path, dpi=130)
|
| 205 |
+
plt.close(fig)
|
| 206 |
+
print(f" → {out_path}")
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def main():
|
| 210 |
+
p = argparse.ArgumentParser()
|
| 211 |
+
p.add_argument("--wandb-run", default=None,
|
| 212 |
+
help="WandB run path: entity/project/run_id (e.g. ptnv-s-research/huggingface/abc123)")
|
| 213 |
+
p.add_argument("--jsonl", default=None,
|
| 214 |
+
help="Local episode_stats.jsonl path (or glob)")
|
| 215 |
+
p.add_argument("--output-dir", default="docs/plots",
|
| 216 |
+
help="Output directory for PNGs (default: docs/plots/)")
|
| 217 |
+
p.add_argument("--label", default="warmup", help="Label suffix for plot titles")
|
| 218 |
+
p.add_argument("--compare", nargs=2, metavar=("BASELINE_JSON", "TRAINED_JSON"),
|
| 219 |
+
help="Also generate before/after bar chart from two compare_eval JSON files")
|
| 220 |
+
args = p.parse_args()
|
| 221 |
+
|
| 222 |
+
out_dir = Path(args.output_dir)
|
| 223 |
+
out_dir.mkdir(parents=True, exist_ok=True)
|
| 224 |
+
|
| 225 |
+
history = []
|
| 226 |
+
if args.wandb_run:
|
| 227 |
+
print(f"Loading WandB run: {args.wandb_run}")
|
| 228 |
+
history, _ = _load_wandb(args.wandb_run)
|
| 229 |
+
print(f" {len(history)} history rows")
|
| 230 |
+
elif args.jsonl:
|
| 231 |
+
print(f"Loading local jsonl: {args.jsonl}")
|
| 232 |
+
history = _load_jsonl(args.jsonl)
|
| 233 |
+
print(f" {len(history)} rows")
|
| 234 |
+
|
| 235 |
+
if history:
|
| 236 |
+
plot_training_reward(history, out_dir, args.label)
|
| 237 |
+
plot_per_rubric_breakdown(history, out_dir, args.label)
|
| 238 |
+
plot_tool_call_frequency(history, out_dir, args.label)
|
| 239 |
+
plot_completion_rate(history, out_dir, args.label)
|
| 240 |
+
|
| 241 |
+
if args.compare:
|
| 242 |
+
plot_before_after(args.compare[0], args.compare[1], out_dir)
|
| 243 |
+
|
| 244 |
+
print(f"\nDone — PNGs in {out_dir}/")
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
if __name__ == "__main__":
|
| 248 |
+
main()
|
scripts/run_full_pipeline.sh
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
# Full pipeline: warmup → main → baseline eval → trained eval → comparison.
|
| 3 |
+
#
|
| 4 |
+
# This is the end-to-end deliverable:
|
| 5 |
+
# 1. Warmup (5-over curriculum, 25 steps) ~2 hr
|
| 6 |
+
# 2. Main run (20-over T20, 100 steps) ~6-8 hr
|
| 7 |
+
# 3. Baseline eval (untrained Qwen3.5-4B) ~30 min
|
| 8 |
+
# 4. Trained eval (warmup+main checkpoint) ~30 min
|
| 9 |
+
# 5. Print comparison table ~instant
|
| 10 |
+
#
|
| 11 |
+
# Outputs:
|
| 12 |
+
# - ./checkpoints/stage2_final/ trained LoRA adapter
|
| 13 |
+
# - eval_results/baseline.json baseline match stats
|
| 14 |
+
# - eval_results/trained.json trained match stats
|
| 15 |
+
# - /tmp/train_warmup.log warmup training log
|
| 16 |
+
# - /tmp/train_main.log main run training log
|
| 17 |
+
# - /tmp/eval_baseline.log baseline eval log
|
| 18 |
+
# - /tmp/eval_trained.log trained eval log
|
| 19 |
+
#
|
| 20 |
+
# Usage: bash scripts/run_full_pipeline.sh
|
| 21 |
+
# (or run individual stages as documented in the README)
|
| 22 |
+
|
| 23 |
+
# NOTE: deliberately NOT using `set -e`. We want to inspect each stage's exit
|
| 24 |
+
# code and decide whether to continue, not abort on first non-zero return.
|
| 25 |
+
cd "$(dirname "$0")/.."
|
| 26 |
+
|
| 27 |
+
export PYTORCH_ALLOC_CONF=expandable_segments:True
|
| 28 |
+
|
| 29 |
+
EVAL_EPISODES="${EVAL_EPISODES:-15}"
|
| 30 |
+
EVAL_MAX_OVERS="${EVAL_MAX_OVERS:-5}"
|
| 31 |
+
|
| 32 |
+
# --------------------------------------------------------------------------
|
| 33 |
+
# Stage 1+2: warmup → main (chained)
|
| 34 |
+
# --------------------------------------------------------------------------
|
| 35 |
+
echo "=== [$(date '+%H:%M:%S')] FULL PIPELINE: warmup → main → eval ==="
|
| 36 |
+
bash scripts/run_warmup_then_main.sh
|
| 37 |
+
PIPE_STATUS=$?
|
| 38 |
+
if [ $PIPE_STATUS -ne 0 ]; then
|
| 39 |
+
echo "!!! Training pipeline failed (exit $PIPE_STATUS). Skipping eval."
|
| 40 |
+
exit $PIPE_STATUS
|
| 41 |
+
fi
|
| 42 |
+
|
| 43 |
+
# --------------------------------------------------------------------------
|
| 44 |
+
# Stage 3: baseline eval (untrained Qwen3.5-4B)
|
| 45 |
+
# --------------------------------------------------------------------------
|
| 46 |
+
mkdir -p eval_results
|
| 47 |
+
echo "=== [$(date '+%H:%M:%S')] EVAL: baseline (untrained Qwen3.5-4B) ==="
|
| 48 |
+
python compare_eval.py \
|
| 49 |
+
--model Qwen/Qwen3.5-4B \
|
| 50 |
+
--label baseline \
|
| 51 |
+
--episodes "$EVAL_EPISODES" \
|
| 52 |
+
--max-overs "$EVAL_MAX_OVERS" \
|
| 53 |
+
--opponent-mode heuristic \
|
| 54 |
+
--output eval_results/baseline.json \
|
| 55 |
+
> /tmp/eval_baseline.log 2>&1
|
| 56 |
+
|
| 57 |
+
# --------------------------------------------------------------------------
|
| 58 |
+
# Stage 4: trained eval (warmup + main adapter)
|
| 59 |
+
# --------------------------------------------------------------------------
|
| 60 |
+
echo "=== [$(date '+%H:%M:%S')] EVAL: trained (LoRA from ./checkpoints/stage2_final) ==="
|
| 61 |
+
python compare_eval.py \
|
| 62 |
+
--model Qwen/Qwen3.5-4B \
|
| 63 |
+
--adapter ./checkpoints/stage2_final \
|
| 64 |
+
--label trained \
|
| 65 |
+
--episodes "$EVAL_EPISODES" \
|
| 66 |
+
--max-overs "$EVAL_MAX_OVERS" \
|
| 67 |
+
--opponent-mode heuristic \
|
| 68 |
+
--output eval_results/trained.json \
|
| 69 |
+
> /tmp/eval_trained.log 2>&1
|
| 70 |
+
|
| 71 |
+
# --------------------------------------------------------------------------
|
| 72 |
+
# Stage 5: comparison
|
| 73 |
+
# --------------------------------------------------------------------------
|
| 74 |
+
echo "=== [$(date '+%H:%M:%S')] COMPARISON ==="
|
| 75 |
+
python compare_eval.py --compare eval_results/baseline.json eval_results/trained.json \
|
| 76 |
+
| tee eval_results/comparison.txt
|
| 77 |
+
|
| 78 |
+
echo ""
|
| 79 |
+
echo "=== [$(date '+%H:%M:%S')] DONE ==="
|
| 80 |
+
echo "Trained adapter: ./checkpoints/stage2_final/"
|
| 81 |
+
echo "Eval JSON: eval_results/{baseline,trained}.json"
|
| 82 |
+
echo "Comparison: eval_results/comparison.txt"
|
| 83 |
+
echo "Training logs: /tmp/train_warmup.log, /tmp/train_main.log"
|
| 84 |
+
echo "Eval logs: /tmp/eval_baseline.log, /tmp/eval_trained.log"
|
scripts/run_warmup_then_main.sh
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
# Chained run: warmup (2-3 over curriculum) → main (5-over T20).
|
| 3 |
+
# Main auto-resumes from warmup's LoRA adapter at ./checkpoints/stage2_final
|
| 4 |
+
# (set via configs/cricket_train_qwen3.yaml resume_from).
|
| 5 |
+
#
|
| 6 |
+
# Active configs target Qwen3-4B-Instruct-2507 + vLLM colocate. The legacy
|
| 7 |
+
# Qwen3.5-4B configs are archived in configs/extras/.
|
| 8 |
+
#
|
| 9 |
+
# Usage: bash scripts/run_warmup_then_main.sh
|
| 10 |
+
# Logs: /tmp/train_warmup.log then /tmp/train_main.log
|
| 11 |
+
|
| 12 |
+
# NOTE: deliberately NOT using `set -e`. We want to inspect the warmup exit
|
| 13 |
+
# code and decide whether to continue, not abort on first non-zero return.
|
| 14 |
+
cd "$(dirname "$0")/.."
|
| 15 |
+
|
| 16 |
+
export PYTORCH_ALLOC_CONF=expandable_segments:True
|
| 17 |
+
export TRL_EXPERIMENTAL_SILENCE=1
|
| 18 |
+
|
| 19 |
+
echo "=== [$(date '+%H:%M:%S')] WARMUP starting (2-3 over curriculum, 30 steps) ==="
|
| 20 |
+
python train.py train --config configs/cricket_train_qwen3_warmup.yaml \
|
| 21 |
+
> /tmp/train_warmup.log 2>&1
|
| 22 |
+
|
| 23 |
+
WARMUP_EXIT=$?
|
| 24 |
+
if [ $WARMUP_EXIT -ne 0 ]; then
|
| 25 |
+
echo "!!! WARMUP failed with exit $WARMUP_EXIT — see /tmp/train_warmup.log"
|
| 26 |
+
echo " Skipping main run."
|
| 27 |
+
exit $WARMUP_EXIT
|
| 28 |
+
fi
|
| 29 |
+
|
| 30 |
+
# Sanity: confirm checkpoint exists before launching main.
|
| 31 |
+
if [ ! -d ./checkpoints/stage2_final ]; then
|
| 32 |
+
echo "!!! Expected ./checkpoints/stage2_final not found after warmup."
|
| 33 |
+
echo " Main run wants to resume from there — aborting."
|
| 34 |
+
exit 1
|
| 35 |
+
fi
|
| 36 |
+
|
| 37 |
+
echo "=== [$(date '+%H:%M:%S')] WARMUP done — adapter saved to ./checkpoints/stage2_final ==="
|
| 38 |
+
echo "=== [$(date '+%H:%M:%S')] MAIN starting (5-over, 100 steps, resuming from warmup) ==="
|
| 39 |
+
|
| 40 |
+
python train.py train --config configs/cricket_train_qwen3.yaml \
|
| 41 |
+
> /tmp/train_main.log 2>&1
|
| 42 |
+
|
| 43 |
+
echo "=== [$(date '+%H:%M:%S')] MAIN done ==="
|
server/coherence_grader.py
CHANGED
|
@@ -64,37 +64,44 @@ def bowling_coherence_score(
|
|
| 64 |
) -> float:
|
| 65 |
"""
|
| 66 |
Grade bowling strategy coherence.
|
| 67 |
-
Weights: 40% rationale
|
|
|
|
|
|
|
|
|
|
| 68 |
"""
|
| 69 |
if not bowling_strategy:
|
| 70 |
return 0.0
|
| 71 |
-
|
| 72 |
rationale = bowling_strategy.get("rationale", "")
|
| 73 |
r_spec = rationale_specificity(rationale)
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
if field_setting == "Aggressive":
|
| 81 |
-
if line in
|
| 82 |
-
logic_score = 1.0
|
| 83 |
elif field_setting == "Defensive":
|
| 84 |
-
if line in
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
# Phase appropriate (e.g., spin in middle overs)
|
| 90 |
-
p_approp = 1.0
|
| 91 |
bowler_type = bowling_strategy.get("bowler_type", "pace")
|
| 92 |
if phase == "middle" and bowler_type == "spin":
|
| 93 |
p_approp = 1.0
|
| 94 |
-
elif phase
|
| 95 |
p_approp = 1.0
|
| 96 |
-
|
| 97 |
-
|
|
|
|
|
|
|
| 98 |
return round(score, 4)
|
| 99 |
|
| 100 |
|
|
|
|
| 64 |
) -> float:
|
| 65 |
"""
|
| 66 |
Grade bowling strategy coherence.
|
| 67 |
+
Weights (from game_knowledge.yaml): 40% rationale + 30% field logic + 30% phase fit.
|
| 68 |
+
|
| 69 |
+
Line/length values must already be normalized (normalize_line / normalize_length
|
| 70 |
+
from field_model.py) — e.g. "pads" not "on pads", "outside_off" not "outside off".
|
| 71 |
"""
|
| 72 |
if not bowling_strategy:
|
| 73 |
return 0.0
|
| 74 |
+
|
| 75 |
rationale = bowling_strategy.get("rationale", "")
|
| 76 |
r_spec = rationale_specificity(rationale)
|
| 77 |
+
|
| 78 |
+
line = bowling_strategy.get("line", "outside_off")
|
| 79 |
+
length = bowling_strategy.get("length", "good")
|
| 80 |
+
|
| 81 |
+
# Attacking plan: attack the stumps/pads with short/full threatening lengths
|
| 82 |
+
_ATTACKING_LINES = {"stumps", "pads"}
|
| 83 |
+
_ATTACKING_LENGTHS = {"bouncer", "short", "yorker"}
|
| 84 |
+
# Containing plan: bowl wide or full to restrict scoring
|
| 85 |
+
_DEFENSIVE_LINES = {"outside_off", "wide"}
|
| 86 |
+
_DEFENSIVE_LENGTHS = {"yorker", "full"}
|
| 87 |
+
|
| 88 |
if field_setting == "Aggressive":
|
| 89 |
+
logic_score = 1.0 if (line in _ATTACKING_LINES or length in _ATTACKING_LENGTHS) else 0.5
|
|
|
|
| 90 |
elif field_setting == "Defensive":
|
| 91 |
+
logic_score = 1.0 if (line in _DEFENSIVE_LINES or length in _DEFENSIVE_LENGTHS) else 0.5
|
| 92 |
+
else: # Balanced
|
| 93 |
+
logic_score = 0.8
|
| 94 |
+
|
| 95 |
+
# Phase appropriateness: spin in middle, pace in powerplay/death
|
|
|
|
|
|
|
| 96 |
bowler_type = bowling_strategy.get("bowler_type", "pace")
|
| 97 |
if phase == "middle" and bowler_type == "spin":
|
| 98 |
p_approp = 1.0
|
| 99 |
+
elif phase in {"powerplay", "death"} and bowler_type == "pace":
|
| 100 |
p_approp = 1.0
|
| 101 |
+
else:
|
| 102 |
+
p_approp = 0.6
|
| 103 |
+
|
| 104 |
+
score = 0.40 * r_spec + 0.30 * logic_score + 0.30 * p_approp
|
| 105 |
return round(score, 4)
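A worked example of the 40/30/30 weighting for a single aggressive over (the 0.8 rationale-specificity value is assumed, not computed here):

```python
r_spec = 0.8        # rationale_specificity(...) for a reasonably specific rationale (assumed)
logic_score = 1.0   # line "stumps" under an Aggressive field setting -> attacking plan
p_approp = 1.0      # pace bowler in the powerplay -> phase appropriate

score = 0.40 * r_spec + 0.30 * logic_score + 0.30 * p_approp
print(round(score, 4))  # 0.92
```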
|
| 106 |
|
| 107 |
|
server/cricket_environment.py
CHANGED
|
@@ -163,7 +163,7 @@ class CricketEnvironment(Environment):
|
|
| 163 |
start_wickets = self._rng.randint(0, 9)
|
| 164 |
start_score = int(start_over * self._rng.uniform(5.5, 8.5))
|
| 165 |
|
| 166 |
-
start_phase = over_to_phase(start_over)
|
| 167 |
start_bowler = sample_bowler_type(start_phase, self._rng)
|
| 168 |
|
| 169 |
self._state = CricketState(
|
|
@@ -206,6 +206,20 @@ class CricketEnvironment(Environment):
|
|
| 206 |
# Load roster for the agent's team
|
| 207 |
agent_team = options.get("agent_team", os.environ.get("CRICKET_AGENT_TEAM", "india"))
|
| 208 |
self._agent_roster = load_team_roster(agent_team)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
# Reset tool budget
|
| 210 |
self._overhead_calls_this_over = 0
|
| 211 |
self._total_tool_fines = 0.0
|
|
@@ -427,6 +441,23 @@ class CricketEnvironment(Environment):
|
|
| 427 |
if shot_intent not in VALID_SHOT_INTENTS:
|
| 428 |
self._format_violations += 1
|
| 429 |
shot_intent = "defensive"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 430 |
context = self._context_for_policy()
|
| 431 |
self._opponent_plan = self._opponent.bowling_plan(context)
|
| 432 |
self._state.opponent_plan = self._opponent_plan
|
|
@@ -559,6 +590,23 @@ class CricketEnvironment(Environment):
|
|
| 559 |
return self._build_obs(last_ball=f"Field set to {setting}.")
|
| 560 |
|
| 561 |
def _handle_bowl_delivery(self, args: dict) -> CricketObservation:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 562 |
context = self._context_for_policy()
|
| 563 |
self._opponent_plan = self._opponent.batting_plan(context)
|
| 564 |
self._state.opponent_plan = self._opponent_plan
|
|
@@ -599,6 +647,7 @@ class CricketEnvironment(Environment):
|
|
| 599 |
field_setting=self._field_setting,
|
| 600 |
wickets=self._state.wickets_lost,
|
| 601 |
score=self._state.total_score,
|
|
|
|
| 602 |
)
|
| 603 |
metadata = {"event_type": "base_outcome", "target_area": normalize_target_area("", shot_intent)}
|
| 604 |
return self._process_delivery(runs, wicket, extra, shot_intent, dismissal_type, metadata)
|
|
@@ -843,7 +892,7 @@ class CricketEnvironment(Environment):
|
|
| 843 |
return round(max(-0.15, min(0.15, reward)), 4)
|
| 844 |
|
| 845 |
def _update_phase(self):
|
| 846 |
-
new_phase = over_to_phase(self._state.over)
|
| 847 |
if new_phase != self._state.phase:
|
| 848 |
self._state.phase = new_phase
|
| 849 |
self._bowler_type = sample_bowler_type(new_phase, self._rng)
|
|
@@ -1083,8 +1132,8 @@ class CricketEnvironment(Environment):
|
|
| 1083 |
if wicket:
|
| 1084 |
return 0.0
|
| 1085 |
if self._state.game_state == GameState.BATTING:
|
| 1086 |
-
expected = self._engine.expected_runs(self._state.over, shot_intent, self._bowler_type)
|
| 1087 |
-
baseline = max(self._engine.expected_runs(self._state.over, shot, self._bowler_type) for shot in SHOT_INTENTS)
|
| 1088 |
return round(min(1.0, expected / max(baseline, 0.1)), 4)
|
| 1089 |
# For bowling, lower conceded runs are better.
|
| 1090 |
return round(max(0.0, 1.0 - (runs / 6.0)), 4)
|
|
@@ -1189,7 +1238,10 @@ def _render_prompt(ctx, batting_strat, bowling_strat, shot_plan, delivery_plan,
|
|
| 1189 |
budget = 3
|
| 1190 |
budget_str = f"Tool budget: {overhead_used}/{budget} overhead calls used this over"
|
| 1191 |
if overhead_used >= budget:
|
| 1192 |
-
budget_str +=
|
|
|
|
|
|
|
|
|
|
| 1193 |
lines.append(budget_str)
|
| 1194 |
# Opponent plan intentionally NOT shown — agent must infer via analyze_situation
|
| 1195 |
|
|
|
|
| 163 |
start_wickets = self._rng.randint(0, 9)
|
| 164 |
start_score = int(start_over * self._rng.uniform(5.5, 8.5))
|
| 165 |
|
| 166 |
+
start_phase = over_to_phase(start_over, max_overs)
|
| 167 |
start_bowler = sample_bowler_type(start_phase, self._rng)
|
| 168 |
|
| 169 |
self._state = CricketState(
|
|
|
|
| 206 |
# Load roster for the agent's team
|
| 207 |
agent_team = options.get("agent_team", os.environ.get("CRICKET_AGENT_TEAM", "india"))
|
| 208 |
self._agent_roster = load_team_roster(agent_team)
|
| 209 |
+
if self._agent_roster:
|
| 210 |
+
playing_xi = build_playing_xi(self._agent_roster)
|
| 211 |
+
batters = [p for p in playing_xi if p.get("role") != "bowler"]
|
| 212 |
+
bowlers = [p for p in playing_xi if p.get("bowler_type")]
|
| 213 |
+
if len(batters) >= 2:
|
| 214 |
+
self._current_batter = batter_profile_from_player(batters[0])
|
| 215 |
+
self._non_striker = batter_profile_from_player(batters[1])
|
| 216 |
+
elif len(playing_xi) >= 2:
|
| 217 |
+
self._current_batter = batter_profile_from_player(playing_xi[0])
|
| 218 |
+
self._non_striker = batter_profile_from_player(playing_xi[1])
|
| 219 |
+
matching_bowler = next((p for p in bowlers if p.get("bowler_type") == self._bowler_type), None)
|
| 220 |
+
if matching_bowler or bowlers:
|
| 221 |
+
self._current_bowler = bowler_profile_from_player(matching_bowler or bowlers[0])
|
| 222 |
+
self._bowler_type = self._current_bowler["type"]
|
| 223 |
# Reset tool budget
|
| 224 |
self._overhead_calls_this_over = 0
|
| 225 |
self._total_tool_fines = 0.0
|
|
|
|
| 441 |
if shot_intent not in VALID_SHOT_INTENTS:
|
| 442 |
self._format_violations += 1
|
| 443 |
shot_intent = "defensive"
|
| 444 |
+
# Inline shot-plan capture: if the model passes target_area/risk/trajectory/
|
| 445 |
+
# rationale directly to play_delivery (the new "execute-first" pattern that
|
| 446 |
+
# collapses plan_shot+play_delivery into a single turn), update the shot plan
|
| 447 |
+
# and score adaptation/opp_awareness here so we don't lose the reward signal.
|
| 448 |
+
has_plan_args = any(args.get(k) for k in ("target_area", "risk", "trajectory", "rationale", "explanation"))
|
| 449 |
+
if has_plan_args:
|
| 450 |
+
target_area = normalize_target_area(args.get("target_area", "gaps"), shot_intent)
|
| 451 |
+
risk = str(args.get("risk", "balanced")).lower()
|
| 452 |
+
self._shot_plan = {
|
| 453 |
+
"shot_intent": shot_intent,
|
| 454 |
+
"target_area": target_area,
|
| 455 |
+
"trajectory": infer_trajectory(shot_intent, risk, args.get("trajectory")),
|
| 456 |
+
"risk": risk,
|
| 457 |
+
"rationale": str(args.get("rationale", args.get("explanation", ""))),
|
| 458 |
+
}
|
| 459 |
+
self._score_adaptation(self._shot_plan)
|
| 460 |
+
self._score_opponent_awareness(self._shot_plan)
|
| 461 |
context = self._context_for_policy()
|
| 462 |
self._opponent_plan = self._opponent.bowling_plan(context)
|
| 463 |
self._state.opponent_plan = self._opponent_plan
|
|
|
|
| 590 |
return self._build_obs(last_ball=f"Field set to {setting}.")
|
| 591 |
|
| 592 |
def _handle_bowl_delivery(self, args: dict) -> CricketObservation:
|
| 593 |
+
# Inline delivery-plan capture: if the model passes line/length/delivery_type
|
| 594 |
+
# directly to bowl_delivery (the new "execute-first" pattern that collapses
|
| 595 |
+
# plan_delivery+bowl_delivery into a single turn), update the delivery plan
|
| 596 |
+
# and score adaptation/opp_awareness here so we don't lose the reward signal.
|
| 597 |
+
has_plan_args = any(args.get(k) for k in ("line", "length", "delivery_type", "rationale"))
|
| 598 |
+
if has_plan_args:
|
| 599 |
+
current_type = str(self._current_bowler.get("type", self._bowler_type)).lower()
|
| 600 |
+
self._delivery_plan = {
|
| 601 |
+
"bowler_type": current_type,
|
| 602 |
+
"line": normalize_line(args.get("line", "outside off")),
|
| 603 |
+
"length": normalize_length(args.get("length", "good length")),
|
| 604 |
+
"delivery_type": normalize_variation(args.get("delivery_type", "stock"), current_type),
|
| 605 |
+
"rationale": str(args.get("rationale", "")),
|
| 606 |
+
}
|
| 607 |
+
self._bowling_strategy = dict(self._delivery_plan)
|
| 608 |
+
self._score_adaptation(self._delivery_plan)
|
| 609 |
+
self._score_opponent_awareness(self._delivery_plan)
|
| 610 |
context = self._context_for_policy()
|
| 611 |
self._opponent_plan = self._opponent.batting_plan(context)
|
| 612 |
self._state.opponent_plan = self._opponent_plan
|
|
|
|
| 647 |
field_setting=self._field_setting,
|
| 648 |
wickets=self._state.wickets_lost,
|
| 649 |
score=self._state.total_score,
|
| 650 |
+
max_overs=self._state.max_overs,
|
| 651 |
)
|
| 652 |
metadata = {"event_type": "base_outcome", "target_area": normalize_target_area("", shot_intent)}
|
| 653 |
return self._process_delivery(runs, wicket, extra, shot_intent, dismissal_type, metadata)
|
|
|
|
| 892 |
return round(max(-0.15, min(0.15, reward)), 4)
|
| 893 |
|
| 894 |
def _update_phase(self):
|
| 895 |
+
new_phase = over_to_phase(self._state.over, self._state.max_overs)
|
| 896 |
if new_phase != self._state.phase:
|
| 897 |
self._state.phase = new_phase
|
| 898 |
self._bowler_type = sample_bowler_type(new_phase, self._rng)
|
|
|
|
| 1132 |
if wicket:
|
| 1133 |
return 0.0
|
| 1134 |
if self._state.game_state == GameState.BATTING:
|
| 1135 |
+
expected = self._engine.expected_runs(self._state.over, shot_intent, self._bowler_type, self._state.max_overs)
|
| 1136 |
+
baseline = max(self._engine.expected_runs(self._state.over, shot, self._bowler_type, self._state.max_overs) for shot in SHOT_INTENTS)
|
| 1137 |
return round(min(1.0, expected / max(baseline, 0.1)), 4)
|
| 1138 |
# For bowling, lower conceded runs are better.
|
| 1139 |
return round(max(0.0, 1.0 - (runs / 6.0)), 4)
|
|
|
|
| 1238 |
budget = 3
|
| 1239 |
budget_str = f"Tool budget: {overhead_used}/{budget} overhead calls used this over"
|
| 1240 |
if overhead_used >= budget:
|
| 1241 |
+
budget_str += (
|
| 1242 |
+
" ⚠ BUDGET EXHAUSTED — further set_strategy, set_bowling_strategy, plan_delivery, "
|
| 1243 |
+
"analyze_situation, or reflect_after_ball calls will be FINED"
|
| 1244 |
+
)
|
| 1245 |
lines.append(budget_str)
|
| 1246 |
# Opponent plan intentionally NOT shown — agent must infer via analyze_situation
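For illustration, one collapsed "execute-first" turn that the inline shot-plan capture above accepts — the envelope follows the {"tool": ..., "arguments": ...} shape the trainer parses; the specific values are only an example:

```python
# Plans and executes in a single turn — no separate plan_shot call, so it does
# not spend any of the over's 3 overhead tool calls.
action = {
    "tool": "play_delivery",
    "arguments": {
        "shot_intent": "rotate",        # must be one of the valid shot intents
        "target_area": "midwicket",     # inline plan fields picked up by the env
        "risk": "balanced",
        "rationale": "Spinner into the pads; work it square and keep the strike rotating.",
    },
}
```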
|
| 1247 |
|
server/markov_engine.py
CHANGED
|
@@ -65,12 +65,19 @@ BOWLER_TYPES = ["pace", "spin"]
|
|
| 65 |
EXTRAS_RATE = 0.05
|
| 66 |
|
| 67 |
|
| 68 |
-
def over_to_phase(over: int) -> str:
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
|
| 75 |
|
| 76 |
def sample_bowler_type(phase: str, rng: random.Random) -> str:
|
|
@@ -118,6 +125,7 @@ class MarkovCricketEngine:
|
|
| 118 |
score: int = 0,
|
| 119 |
bowler_type: str = "pace",
|
| 120 |
field_setting: str = "Balanced",
|
|
|
|
| 121 |
) -> tuple[int, bool, bool, str]:
|
| 122 |
"""Sample an outcome for one delivery.
|
| 123 |
|
|
@@ -131,7 +139,7 @@ class MarkovCricketEngine:
|
|
| 131 |
if self._rng.random() < EXTRAS_RATE:
|
| 132 |
return 1, False, True, ""
|
| 133 |
|
| 134 |
-
phase = over_to_phase(over)
|
| 135 |
|
| 136 |
if self._cricsheet is not None:
|
| 137 |
runs, wicket = self._cricsheet_step(over, wickets, score, phase, bowler_type, shot_intent)
|
|
@@ -214,6 +222,7 @@ class MarkovCricketEngine:
|
|
| 214 |
score=score,
|
| 215 |
bowler_type=bowler_type,
|
| 216 |
field_setting="Balanced",
|
|
|
|
| 217 |
)
|
| 218 |
if extra:
|
| 219 |
metadata.update({"event_type": "wide", "fielder_effect": "none", "base_runs": runs, "base_wicket": wicket})
|
|
@@ -274,12 +283,13 @@ class MarkovCricketEngine:
|
|
| 274 |
field_setting: str = "Balanced",
|
| 275 |
wickets: int = 0,
|
| 276 |
score: int = 0,
|
|
|
|
| 277 |
) -> tuple[int, bool, bool, str]:
|
| 278 |
"""Simulate an AI batter faced with the agent's bowling/fielding.
|
| 279 |
|
| 280 |
Returns (runs, wicket, extra, shot_intent).
|
| 281 |
"""
|
| 282 |
-
phase = over_to_phase(over)
|
| 283 |
|
| 284 |
# Decide AI batter's shot intent based on state and phase
|
| 285 |
# Aggression increases in death overs or with wickets in hand
|
|
@@ -311,15 +321,16 @@ class MarkovCricketEngine:
|
|
| 311 |
wickets=wickets,
|
| 312 |
score=score,
|
| 313 |
bowler_type=bowler_type,
|
| 314 |
-
field_setting=field_setting
|
|
|
|
| 315 |
)
|
| 316 |
|
| 317 |
return runs, wicket, extra, shot_intent, dismissal_type
|
| 318 |
|
| 319 |
-
def expected_runs(self, over: int, shot_intent: str, bowler_type: str = "pace") -> float:
|
| 320 |
if shot_intent not in SHOT_AGGRESSION:
|
| 321 |
return 0.0
|
| 322 |
-
phase = over_to_phase(over)
|
| 323 |
if self._cricsheet:
|
| 324 |
dist = self._get_cricsheet_dist(over, 3, 15, phase, bowler_type, shot_intent)
|
| 325 |
if dist:
|
|
@@ -327,10 +338,10 @@ class MarkovCricketEngine:
|
|
| 327 |
dist = self._synthetic[shot_intent][phase]
|
| 328 |
return sum(r * p for r, _, p in dist)
|
| 329 |
|
| 330 |
-
def wicket_probability(self, over: int, shot_intent: str, bowler_type: str = "pace") -> float:
|
| 331 |
if shot_intent not in SHOT_AGGRESSION:
|
| 332 |
return 0.0
|
| 333 |
-
phase = over_to_phase(over)
|
| 334 |
if self._cricsheet:
|
| 335 |
dist = self._get_cricsheet_dist(over, 3, 15, phase, bowler_type, shot_intent)
|
| 336 |
if dist:
|
|
|
|
| 65 |
EXTRAS_RATE = 0.05
|
| 66 |
|
| 67 |
|
| 68 |
+
def over_to_phase(over: int, max_overs: int | None = None) -> str:
|
| 69 |
+
"""Return the phase label for a given over, respecting the match format.
|
| 70 |
+
|
| 71 |
+
Without max_overs the old hardcoded thresholds (designed for ODI) would
|
| 72 |
+
leave T20 overs 16-19 classified as "middle" instead of "death". We now
|
| 73 |
+
delegate to format_mapper.get_phase which reads the correct phase windows
|
| 74 |
+
from data/format_rules.json.
|
| 75 |
+
"""
|
| 76 |
+
try:
|
| 77 |
+
from server.format_mapper import get_phase
|
| 78 |
+
except ImportError:
|
| 79 |
+
from .format_mapper import get_phase
|
| 80 |
+
return get_phase(over, max_overs)
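A sketch of the behaviour expected from get_phase (the real windows live in data/format_rules.json; the thresholds below are conventional cricket ones and are assumptions, not a copy of that file):

```python
def get_phase_sketch(over: int, max_overs: int | None = None) -> str:
    """Roughly: first ~30% of overs = powerplay, last ~20% = death (assumed windows)."""
    total = max_overs or 50                        # fall back to ODI length
    if over < max(1, round(total * 0.3)):
        return "powerplay"
    if over >= total - max(1, round(total * 0.2)):
        return "death"
    return "middle"

print(get_phase_sketch(17, 20))  # "death"  — the T20 case the docstring calls out
print(get_phase_sketch(17, 50))  # "middle" — same over number in an ODI
```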
|
| 81 |
|
| 82 |
|
| 83 |
def sample_bowler_type(phase: str, rng: random.Random) -> str:
|
|
|
|
| 125 |
score: int = 0,
|
| 126 |
bowler_type: str = "pace",
|
| 127 |
field_setting: str = "Balanced",
|
| 128 |
+
max_overs: int | None = None,
|
| 129 |
) -> tuple[int, bool, bool, str]:
|
| 130 |
"""Sample an outcome for one delivery.
|
| 131 |
|
|
|
|
| 139 |
if self._rng.random() < EXTRAS_RATE:
|
| 140 |
return 1, False, True, ""
|
| 141 |
|
| 142 |
+
phase = over_to_phase(over, max_overs)
|
| 143 |
|
| 144 |
if self._cricsheet is not None:
|
| 145 |
runs, wicket = self._cricsheet_step(over, wickets, score, phase, bowler_type, shot_intent)
|
|
|
|
| 222 |
score=score,
|
| 223 |
bowler_type=bowler_type,
|
| 224 |
field_setting="Balanced",
|
| 225 |
+
max_overs=max_overs,
|
| 226 |
)
|
| 227 |
if extra:
|
| 228 |
metadata.update({"event_type": "wide", "fielder_effect": "none", "base_runs": runs, "base_wicket": wicket})
|
|
|
|
| 283 |
field_setting: str = "Balanced",
|
| 284 |
wickets: int = 0,
|
| 285 |
score: int = 0,
|
| 286 |
+
max_overs: int | None = None,
|
| 287 |
) -> tuple[int, bool, bool, str]:
|
| 288 |
"""Simulate an AI batter faced with the agent's bowling/fielding.
|
| 289 |
|
| 290 |
Returns (runs, wicket, extra, shot_intent).
|
| 291 |
"""
|
| 292 |
+
phase = over_to_phase(over, max_overs)
|
| 293 |
|
| 294 |
# Decide AI batter's shot intent based on state and phase
|
| 295 |
# Aggression increases in death overs or with wickets in hand
|
|
|
|
| 321 |
wickets=wickets,
|
| 322 |
score=score,
|
| 323 |
bowler_type=bowler_type,
|
| 324 |
+
field_setting=field_setting,
|
| 325 |
+
max_overs=max_overs,
|
| 326 |
)
|
| 327 |
|
| 328 |
return runs, wicket, extra, shot_intent, dismissal_type
|
| 329 |
|
| 330 |
+
def expected_runs(self, over: int, shot_intent: str, bowler_type: str = "pace", max_overs: int | None = None) -> float:
|
| 331 |
if shot_intent not in SHOT_AGGRESSION:
|
| 332 |
return 0.0
|
| 333 |
+
phase = over_to_phase(over, max_overs)
|
| 334 |
if self._cricsheet:
|
| 335 |
dist = self._get_cricsheet_dist(over, 3, 15, phase, bowler_type, shot_intent)
|
| 336 |
if dist:
|
|
|
|
| 338 |
dist = self._synthetic[shot_intent][phase]
|
| 339 |
return sum(r * p for r, _, p in dist)
|
| 340 |
|
| 341 |
+
def wicket_probability(self, over: int, shot_intent: str, bowler_type: str = "pace", max_overs: int | None = None) -> float:
|
| 342 |
if shot_intent not in SHOT_AGGRESSION:
|
| 343 |
return 0.0
|
| 344 |
+
phase = over_to_phase(over, max_overs)
|
| 345 |
if self._cricsheet:
|
| 346 |
dist = self._get_cricsheet_dist(over, 3, 15, phase, bowler_type, shot_intent)
|
| 347 |
if dist:
|
server/reward_calculator.py
CHANGED
|
@@ -130,36 +130,47 @@ def compute_episode_reward(
|
|
| 130 |
chase_progress = total_score / max(target, 1)
|
| 131 |
wicket_penalty = wickets_lost * 0.08
|
| 132 |
if total_score >= target:
|
| 133 |
-
outcome_bonus = 1.0
|
| 134 |
elif total_score == target - 1:
|
| 135 |
-
outcome_bonus = 0.
|
| 136 |
else:
|
| 137 |
-
outcome_bonus =
|
| 138 |
r_cric = chase_progress + outcome_bonus - wicket_penalty
|
| 139 |
else:
|
| 140 |
# Bowling to defend: reward keeping opponent below target.
|
| 141 |
defense_margin = max(target - total_score, 0) / max(target, 1)
|
| 142 |
wicket_pressure = wickets_lost * 0.08
|
| 143 |
if total_score < target - 1:
|
| 144 |
-
outcome_bonus = 1.0
|
| 145 |
elif total_score == target - 1:
|
| 146 |
-
outcome_bonus = 0.
|
| 147 |
else:
|
| 148 |
-
outcome_bonus =
|
| 149 |
r_cric = defense_margin + wicket_pressure + outcome_bonus
|
| 150 |
elif game_state == "batting":
|
| 151 |
r_cric = (total_score / max(dls_par, 1.0)) - (wickets_lost * 0.08)
|
| 152 |
else:
|
| 153 |
# Bowling first innings: reward conceding fewer runs than DLS par.
|
| 154 |
-
|
| 155 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
|
| 157 |
# r_cricket: dense per-ball position signal via Dream11 proxy.
|
| 158 |
# Normalised per innings then averaged so two-innings totals stay in [0, 1].
|
|
|
|
|
|
|
|
|
|
| 159 |
if dream11_scores:
|
| 160 |
r_dream11 = mean(normalize_dream11(s) for s in dream11_scores)
|
|
|
|
| 161 |
else:
|
| 162 |
r_dream11 = 0.0
|
|
|
|
| 163 |
|
| 164 |
# Load weights from game_knowledge.yaml (cached after first load).
|
| 165 |
w = get_reward_weights() if get_reward_weights is not None else None
|
|
@@ -179,12 +190,21 @@ def compute_episode_reward(
|
|
| 179 |
r_tools = compute_tool_efficiency(tool_calls_made, analyze_calls, overs_played)
|
| 180 |
|
| 181 |
eff_behavior_w = (w.r_behavior if w else 0.15) * coherence_weight_ramp
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
|
| 183 |
composite = (
|
| 184 |
-
|
| 185 |
-
+
|
| 186 |
+ eff_behavior_w * r_strategy
|
| 187 |
-
+
|
| 188 |
)
|
| 189 |
|
| 190 |
return {
|
|
|
|
| 130 |
chase_progress = total_score / max(target, 1)
|
| 131 |
wicket_penalty = wickets_lost * 0.08
|
| 132 |
if total_score >= target:
|
| 133 |
+
outcome_bonus = 1.0 # win
|
| 134 |
elif total_score == target - 1:
|
| 135 |
+
outcome_bonus = 0.0 # tie — neutral, no consolation
|
| 136 |
else:
|
| 137 |
+
outcome_bonus = -1.0 # loss — explicit negative signal
|
| 138 |
r_cric = chase_progress + outcome_bonus - wicket_penalty
|
| 139 |
else:
|
| 140 |
# Bowling to defend: reward keeping opponent below target.
|
| 141 |
defense_margin = max(target - total_score, 0) / max(target, 1)
|
| 142 |
wicket_pressure = wickets_lost * 0.08
|
| 143 |
if total_score < target - 1:
|
| 144 |
+
outcome_bonus = 1.0 # win (defended)
|
| 145 |
elif total_score == target - 1:
|
| 146 |
+
outcome_bonus = 0.0 # tie
|
| 147 |
else:
|
| 148 |
+
outcome_bonus = -1.0 # loss (target chased)
|
| 149 |
r_cric = defense_margin + wicket_pressure + outcome_bonus
|
| 150 |
elif game_state == "batting":
|
| 151 |
r_cric = (total_score / max(dls_par, 1.0)) - (wickets_lost * 0.08)
|
| 152 |
else:
|
| 153 |
# Bowling first innings: reward conceding fewer runs than DLS par.
|
| 154 |
+
# Allow negative when conceding above par (was clamped to ≥0; now signed).
|
| 155 |
+
r_cric = (dls_par - total_score) / max(dls_par, 1.0)
|
| 156 |
+
# Progress bonus: small reward for actually executing balls instead of getting
|
| 157 |
+
# stuck in a planning loop. Reduced cap (0.25 → 0.10) so it doesn't drown out
|
| 158 |
+
# the loss penalty. Caps at +0.10 once the agent makes >=10 tool calls.
|
| 159 |
+
progress_bonus = min(0.10, tool_calls_made / 100.0)
|
| 160 |
+
r_cric = r_cric + progress_bonus
|
| 161 |
+
r_cric = max(-2.0, min(2.5, r_cric))
|
| 162 |
|
| 163 |
# r_cricket: dense per-ball position signal via Dream11 proxy.
|
| 164 |
# Normalised per innings then averaged so two-innings totals stay in [0, 1].
|
| 165 |
+
# When no innings has completed yet, dream11_scores is empty: redistribute
|
| 166 |
+
# its weight to r_result so the composite stays in the same [0, 1] range
|
| 167 |
+
# instead of silently capping at 0.75.
|
| 168 |
if dream11_scores:
|
| 169 |
r_dream11 = mean(normalize_dream11(s) for s in dream11_scores)
|
| 170 |
+
r_cricket_available = True
|
| 171 |
else:
|
| 172 |
r_dream11 = 0.0
|
| 173 |
+
r_cricket_available = False
|
| 174 |
|
| 175 |
# Load weights from game_knowledge.yaml (cached after first load).
|
| 176 |
w = get_reward_weights() if get_reward_weights is not None else None
|
|
|
|
| 190 |
r_tools = compute_tool_efficiency(tool_calls_made, analyze_calls, overs_played)
|
| 191 |
|
| 192 |
eff_behavior_w = (w.r_behavior if w else 0.15) * coherence_weight_ramp
|
| 193 |
+
r_result_w = w.r_result if w else 0.55
|
| 194 |
+
r_cricket_w = w.r_cricket if w else 0.25
|
| 195 |
+
r_validity_w = w.r_validity if w else 0.05
|
| 196 |
+
|
| 197 |
+
# If no innings has completed, fold r_cricket weight into r_result so the
|
| 198 |
+
# composite ceiling stays at 1.0 and gradients are not systematically suppressed.
|
| 199 |
+
if not r_cricket_available:
|
| 200 |
+
r_result_w = r_result_w + r_cricket_w
|
| 201 |
+
r_cricket_w = 0.0
|
| 202 |
|
| 203 |
composite = (
|
| 204 |
+
r_result_w * r_cric
|
| 205 |
+
+ r_cricket_w * r_dream11
|
| 206 |
+ eff_behavior_w * r_strategy
|
| 207 |
+
+ r_validity_w * r_format
|
| 208 |
)
|
| 209 |
|
| 210 |
return {
|
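A worked sketch of the weight redistribution above (weights shown are the rebalanced defaults — r_result 20%, r_cricket 45%, r_behavior 25%, r_validity 10%; all reward values are illustrative):

```python
r_result_w, r_cricket_w, r_behavior_w, r_validity_w = 0.20, 0.45, 0.25, 0.10

r_cricket_available = False          # no innings has completed yet
if not r_cricket_available:
    r_result_w += r_cricket_w        # 0.20 -> 0.65: fold r_cricket's weight into r_result
    r_cricket_w = 0.0

r_cric, r_dream11, r_strategy, r_format = 0.4, 0.0, 0.6, 1.0
composite = (r_result_w * r_cric + r_cricket_w * r_dream11
             + r_behavior_w * r_strategy + r_validity_w * r_format)
print(round(composite, 4))           # 0.51
```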
train.py
CHANGED
|
@@ -2,26 +2,32 @@
|
|
| 2 |
MT-GRPO training script for CricketCaptain-LLM.
|
| 3 |
|
| 4 |
Two-stage curriculum (ToolRL-style):
|
| 5 |
-
Stage 1
|
| 6 |
-
Stage 2
|
| 7 |
|
| 8 |
Design:
|
| 9 |
-
-
|
| 10 |
-
-
|
| 11 |
-
-
|
| 12 |
-
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
python train.py
|
|
|
|
|
|
|
|
|
|
| 17 |
"""
|
| 18 |
|
| 19 |
import argparse
|
|
|
|
|
|
|
| 20 |
import json
|
| 21 |
import os
|
| 22 |
import random
|
| 23 |
import re
|
| 24 |
import sys
|
|
|
|
| 25 |
import time
|
| 26 |
from pathlib import Path
|
| 27 |
from typing import Any
|
|
@@ -32,12 +38,16 @@ from typing import Any
|
|
| 32 |
try:
|
| 33 |
import torch
|
| 34 |
from datasets import Dataset
|
|
|
|
| 35 |
from trl import GRPOConfig, GRPOTrainer
|
| 36 |
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
| 37 |
_TRAIN_IMPORTS_AVAILABLE = True
|
| 38 |
except ImportError:
|
| 39 |
torch = None
|
| 40 |
Dataset = None
|
|
|
|
|
|
|
|
|
|
| 41 |
GRPOConfig = None
|
| 42 |
GRPOTrainer = None
|
| 43 |
AutoModelForCausalLM = None
|
|
@@ -49,12 +59,14 @@ try:
|
|
| 49 |
from server.cricket_environment import CricketEnvironment
|
| 50 |
from server.coherence_grader import aggression_match, phase_appropriate, rationale_specificity
|
| 51 |
from server.markov_engine import SHOT_AGGRESSION
|
|
|
|
| 52 |
from models import CricketAction
|
| 53 |
from config_yaml import get_game_constants, get_reward_weights
|
| 54 |
except ImportError:
|
| 55 |
from cricket_captain.server.cricket_environment import CricketEnvironment
|
| 56 |
from cricket_captain.server.coherence_grader import aggression_match, phase_appropriate, rationale_specificity
|
| 57 |
from cricket_captain.server.markov_engine import SHOT_AGGRESSION
|
|
|
|
| 58 |
from cricket_captain.models import CricketAction
|
| 59 |
from cricket_captain.config_yaml import get_game_constants, get_reward_weights
|
| 60 |
|
|
@@ -106,37 +118,172 @@ def extract_phase_from_prompt(prompt: str) -> str:
# Per-turn reward components (all stateless) #
# ------------------------------------------------------------------ #

def _parse_completion(raw: str) -> dict | None:
    raw = raw.strip()
    if raw.startswith("```"):
        lines = raw.split("\n")
        raw = "\n".join(lines[1:-1]) if len(lines) > 2 else raw
    try:
        return json.loads(raw)
    except (json.JSONDecodeError, ValueError):
        return None


def r_validity(completion: str) -> float:
-     """
    data = _parse_completion(completion)
    if data is None:
        return 0.0
    tool = data.get("tool", "")
    args = data.get("arguments", {})
    if tool not in _VALID_TOOLS:
-         return 0.0
    if tool == "play_delivery" and args.get("shot_intent") not in SHOT_AGGRESSION:
-         return 0.0
    if tool == "set_strategy":
        agg = args.get("aggression")
        if not isinstance(agg, (int, float)):
-             return 0.0
    if tool == "plan_shot" and args.get("shot_intent") not in SHOT_AGGRESSION:
-         return 0.0
    if tool in {"choose_bowler", "set_bowling_strategy", "plan_delivery"}:
        if args.get("bowler_type") not in (None, "pace", "spin"):
-             return 0.0
    return 1.0

@@ -299,13 +446,9 @@ def make_reward_fn(curriculum_stage: int):
    Returns reward_fn(prompts, completions, **kwargs) → list[float].

    Weights align with compute_episode_reward in reward_calculator.py:
-         r_validity 5% → 0.25 of stateless composite
-     Scaling: behaviour=0.75, validity=0.25 preserves the relative 15:5 ratio
-         from the episode-level rubric while using only what's available per-turn.
    """
    # Minimum reward for any structurally valid completion — ensures GRPO has a
    # positive gradient to reinforce valid tool use even for unscored tool types.

@@ -315,8 +458,18 @@ def make_reward_fn(curriculum_stage: int):
        rewards = []
        for prompt, completion in zip(prompts, completions):
            fmt = r_validity(completion)
            if curriculum_stage == 1:
                continue

            behavior = r_behavior_stateless(prompt, completion)

@@ -327,9 +480,12 @@ def make_reward_fn(curriculum_stage: int):
                + _RW.behavior_adaptation * adapt
                + _RW.behavior_opponent_awareness * aware
            )
            # Floor: valid JSON should always beat invalid JSON (reward=0)
-             if fmt > 0.0:
                reward = max(reward, _VALID_FLOOR)
            rewards.append(round(reward, 4))
        return rewards

@@ -344,66 +500,136 @@ def make_reward_fn(curriculum_stage: int):

SYSTEM_PROMPT = (
    "You are an expert adaptive cricket captain. Each turn you receive a scorecard "
-     "and must
    "Available tools:\n"
-     " call_toss
-     " select_batter
-     " set_strategy
-     " plan_shot
-     " play_delivery
-     " choose_bowler
-     " set_bowling_strategy
-     " plan_delivery
-     " set_field_setting
-     " bowl_delivery
-     " reflect_after_ball
-     " analyze_situation
    "Shot intents: leave | defensive | single | rotate | boundary | six\n\n"
)

_RANDOM_SHOTS = list(SHOT_AGGRESSION.keys())
_RANDOM_QUERIES = ["pitch_conditions", "bowler_info", "field_setting", "match_situation"]
_RANDOM_ZONES = ["cover", "point", "straight", "midwicket", "square_leg", "fine_leg", "long_on", "long_off"]


def _random_action(
    rng: random.Random,
    game_state: str = "batting",
    available_tools: list[str] | None = None,
    current_bowler_type: str | None = None,
) -> CricketAction:

    if game_state == "toss":
-         return
            tool="call_toss",
            arguments={"call": rng.choice(["heads", "tails"]), "decision": rng.choice(["bat", "bowl"])},
-         )
    if game_state == "bowling":
        choice = rng.random()
-         if choice < 0.15:
                tool="choose_bowler",
                arguments={
-                     "name":
-                     "bowler_type":
-                     "style":
-                     "rationale": "Match bowler to phase
                },
-             )
-         if choice < 0.35:
-             return
                tool="plan_delivery",
                arguments={
                    "bowler_type": current_bowler_type or rng.choice(["pace", "spin"]),

@@ -412,74 +638,96 @@ def _random_action(
                    "delivery_type": rng.choice(["stock", "yorker", "bouncer", "slower ball"]),
                    "rationale": "Use field and batter style to control scoring zones",
                },
-             )
-         if choice < 0.5:
-             return
-         if choice < 0.6:
-             return

    choice = rng.random()
-     if choice < 0.15:
            tool="select_batter",
            arguments={
-                 "name":
-                 "style":
-                 "aggression": round(
                "rationale": "Select batter based on phase, wickets, and target pressure",
            },
-         )
-     if choice < 0.3:
-         return
            tool="set_strategy",
            arguments={
                "phase_intent": rng.choice(["attack", "consolidate", "rotate"]),
                "aggression": round(rng.uniform(0.1, 0.9), 2),
-                 "rationale": "
            },
-         )
-     if choice < 0.45:
-         return
            tool="plan_shot",
            arguments={
                "shot_intent": rng.choice(_RANDOM_SHOTS),
                "risk": rng.choice(["low", "balanced", "high"]),
                "rationale": "Plan shot against bowler, field, and required rate",
            },
-         )
-     if choice < 0.55:
-         return
            tool="analyze_situation",
            arguments={"query_type": rng.choice(_RANDOM_QUERIES)},
-         )
-     if choice < 0.65:
-         return


def collect_prompts(
    n_prompts: int,
    task: str = "stage2_full",
    seed: int = 42,
) -> list[str]:
    """
    Collect game-state prompts by running episodes with random actions.
    Returns a list of prompt strings (one per game state observation).
    """
    rng = random.Random(seed)
    prompts: list[str] = []
    ep_count = 0

    while len(prompts) < n_prompts:
        env = CricketEnvironment()
-         obs = env.reset(seed=rng.randint(0, 99999), options={
        steps = 0

        while not obs.done and steps < 80:

@@ -488,10 +736,11 @@ def collect_prompts(
                obs.game_state,
                obs.available_tools,
                obs.current_bowler.get("type") if obs.current_bowler else None,
            )
            obs = env.step(action)
            if not obs.done:
-                 prompts.append(
            steps += 1

        ep_count += 1

@@ -513,13 +762,548 @@ def build_dataset(prompts: list[str]) -> Dataset:
    return Dataset.from_dict({"prompt": prompts})


-     def generate_sft_examples(out_path: str, n_examples: int = 240, seed: int = 42):
    """Stage 0 bootstrap data: valid tool JSON for every tool family."""
    rng = random.Random(seed)
    examples = []
    for _ in range(n_examples):
        game_state = rng.choice(["toss", "batting", "bowling"])
-         action = _random_action(rng, game_state)
        prompt = (
            f"{SYSTEM_PROMPT}\n\n"
            f"[CricketCaptain] {game_state.upper()} | Example adaptive scenario\n"

@@ -546,28 +1330,68 @@ def generate_sft_examples(out_path: str, n_examples: int = 240, seed: int = 42):
# Model loading (plain transformers + bitsandbytes 4-bit) #
# ------------------------------------------------------------------ #

-     def load_model(model_name: str):
    if not _TRAIN_IMPORTS_AVAILABLE:
        raise ImportError("Training dependencies are missing. Install with: pip install '.[train]'")
-     print(f"Loading {model_name} …")
-     bnb_cfg = BitsAndBytesConfig(
-         load_in_4bit=True,
-         bnb_4bit_compute_dtype=torch.bfloat16,
-         bnb_4bit_use_double_quant=True,
-         bnb_4bit_quant_type="nf4",
-     )
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

        device_map="auto",
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    )
    print(f"Loaded. Parameters: {model.num_parameters():,}")
    return model, tokenizer

@@ -578,52 +1402,173 @@ def load_model(model_name: str):
def train(args):
    if not _TRAIN_IMPORTS_AVAILABLE:
        raise ImportError("Training dependencies are missing. Install with: pip install '.[train]'")
    task = "stage1_format" if args.stage == 1 else "stage2_full"

    print(f"\n=== Stage {args.stage} Training ===")
    print(f"Task: {task} | Prompts: {args.prompts} | Steps: {args.steps}")

    # GRPO config
    config = GRPOConfig(
        output_dir=out_dir,
        num_train_epochs=1,
        max_steps=args.steps,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
-         learning_rate=
        warmup_ratio=0.05,
        lr_scheduler_type="cosine",
-         logging_steps=
-         save_steps=
-         save_total_limit=
        bf16=True,
-         max_completion_length=256,
        num_generations=args.num_generations,
        log_completions=True,
        seed=args.seed,
    )

    trainer = GRPOTrainer(
        model=model,
-         reward_funcs=
        args=config,
        train_dataset=dataset,
        processing_class=tokenizer,
    )

    print(f"\nStarting training ({args.steps} steps, {len(dataset)} prompts) …")

@@ -659,7 +1604,11 @@ def evaluate(args):

    for ep in range(args.eval_episodes):
        env = CricketEnvironment()
-         obs = env.reset(seed=rng.randint(0, 99999), options={
        steps = 0

        while not obs.done and steps < 150:

@@ -670,12 +1619,7 @@ def evaluate(args):
            if data:
                action = CricketAction(tool=data["tool"], arguments=data.get("arguments", {}))
            else:
-                 action = CricketAction(tool="bowl_delivery", arguments={})
-             elif obs.game_state == "toss":
-                 action = CricketAction(tool="call_toss", arguments={"call": "heads", "decision": "bat"})
-             else:
-                 action = CricketAction(tool="play_delivery", arguments={"shot_intent": "defensive", "explanation": "fallback"})

            obs = env.step(action)
            steps += 1

@@ -706,6 +1650,7 @@ def _make_run_folder(prefix: str, model: str | None, opponent_mode: str | None,
def train_smoke(args):
    """Run short direct-environment training rollouts without loading a model."""
    rng = random.Random(args.seed)

    # Auto-create run folder unless --output explicitly given
    if args.output:

@@ -746,6 +1691,7 @@ def train_smoke(args):
        "eval_pack_id": args.eval_pack_id,
        "opponent_mode": args.opponent_mode,
        "opponent_cache_path": args.opponent_cache_path,
    })
    prompts = [_format_prompt(obs.prompt_text)]
    total_reward = 0.0

@@ -763,6 +1709,7 @@ def train_smoke(args):
            obs.game_state,
            obs.available_tools,
            obs.current_bowler.get("type") if obs.current_bowler else None,
        )
        obs = env.step(action)
        step_end = time.perf_counter()

@@ -855,6 +1802,7 @@ def train_smoke(args):
def _apply_yaml_defaults(args, cfg: dict) -> None:
    """Merge YAML config values into args, CLI args take precedence."""
    captain = cfg.get("captain", {}) or {}
    env_cfg = cfg.get("env", {}) or {}
    train_cfg = cfg.get("train", {}) or {}

@@ -862,17 +1810,48 @@ def _apply_yaml_defaults(args, cfg: dict) -> None:
        if val is not None and getattr(args, attr, None) is None:
            setattr(args, attr, val)

    _set("api_base", captain.get("api_base"))
    _set("api_key", os.environ.get(captain.get("api_key_env", "")) or None)
    _set("eval_pack_id", env_cfg.get("eval_pack_id"))
-     _set("opponent_mode",
-     _set("opponent_cache_path",
    _set("max_overs", env_cfg.get("max_overs"))
    _set("steps", train_cfg.get("steps"))
    _set("prompts", train_cfg.get("prompts"))
    _set("batch_size", train_cfg.get("batch_size"))
    _set("stage", train_cfg.get("stage"))


def main():

@@ -888,15 +1867,55 @@ def main():
    t.add_argument("--prompts", type=int, default=None, help="Game state prompts to collect")
    t.add_argument("--steps", type=int, default=None, help="GRPOTrainer max_steps")
    t.add_argument("--batch-size", type=int, default=None, dest="batch_size")
-     t.add_argument("--grad-accum", type=int, default=
-     t.add_argument("--num-generations", type=int, default=
    t.add_argument("--seed", type=int, default=42)

    # eval
    e = sub.add_parser("eval", help="Evaluate a checkpoint")
    e.add_argument("--config", default=None)
    e.add_argument("--model", default=None)
    e.add_argument("--eval-episodes", type=int, default=10, dest="eval_episodes")
    e.add_argument("--seed", type=int, default=0)

    # quick test (no GPU needed)

@@ -911,12 +1930,14 @@ def main():
    smoke.add_argument("--eval-pack-id", default=None, dest="eval_pack_id")
    smoke.add_argument("--opponent-mode", default=None, choices=["heuristic", "llm_live", "llm_cached", "cricsheet"], dest="opponent_mode")
    smoke.add_argument("--opponent-cache-path", default=None, dest="opponent_cache_path")
    smoke.add_argument("--output", default=None)
    smoke.add_argument("--seed", type=int, default=42)

    sft = sub.add_parser("sft-data", help="Generate Stage 0 supervised tool-format examples")
    sft.add_argument("--output", default="./data/training/tool_sft_examples.jsonl")
    sft.add_argument("--examples", type=int, default=240)
    sft.add_argument("--seed", type=int, default=42)

    args = parser.parse_args()

@@ -934,35 +1955,45 @@ def main():
    if getattr(args, "stage", None) is None:
        args.stage = 1
    if getattr(args, "model", None) is None:
-         args.model = "Qwen/
    if getattr(args, "steps", None) is None:
        args.steps = 200
    if getattr(args, "prompts", None) is None:
        args.prompts = 500
    if getattr(args, "batch_size", None) is None:
        args.batch_size = 2
    if getattr(args, "eval_pack_id", None) is None:
        args.eval_pack_id = "adaptive_t20_v1"
    if getattr(args, "opponent_mode", None) is None:
        args.opponent_mode = "llm_live"
    if getattr(args, "max_overs", None) is None:
        args.max_overs = 5

    if args.cmd == "train":
        train(args)
    elif args.cmd == "eval":
        evaluate(args)
    elif args.cmd == "test":
-         _smoke_test()
    elif args.cmd == "train-smoke":
        train_smoke(args)
    elif args.cmd == "sft-data":
-         generate_sft_examples(args.output, args.examples, args.seed)
    else:
        parser.print_help()


-     def _smoke_test():
    """Verify reward functions work correctly."""
    cases = [
        (

@@ -990,12 +2021,12 @@ def _smoke_test():
    ]
    print("Reward function smoke test:\n")
    for prompt, completion, expected in cases:
-         fmt =
-         coh =
        print(f" expected={expected:4s} | fmt={fmt:.0f} | coh={coh:.3f} | {completion[:60]}")

    print("\nPrompt collection test (5 prompts):")
-     p = collect_prompts(5, task="stage1_format", seed=1)
    for i, pp in enumerate(p):
        print(f" [{i}] {pp[:80].strip()} …")

| 2 |
MT-GRPO training script for CricketCaptain-LLM.
|
| 3 |
|
| 4 |
Two-stage curriculum (ToolRL-style):
|
| 5 |
+
Stage 1: tool-call mastery — emphasize valid, phase-legal tool usage
|
| 6 |
+
Stage 2: strategic behavior — full environment-backed reward (result + cricket + behavior + validity)
|
| 7 |
|
| 8 |
Design:
|
| 9 |
+
- Training uses TRL GRPO with environment_factory=CricketCaptainToolEnv
|
| 10 |
+
- The model interacts with live CricketEnvironment instances over multi-turn tool calls
|
| 11 |
+
- Rewards are collected from the environment (environment_reward), not only from stateless prompt parsing
|
| 12 |
+
- The opponent policy is part of the environment: heuristic/cricsheet/llm_live/llm_cached
|
| 13 |
+
- Plain TRL + Transformers + bitsandbytes + PEFT (LoRA adapters for 4-bit models)
|
| 14 |
+
|
| 15 |
+
Usage (canonical Qwen3 setup):
|
| 16 |
+
python train.py train --config configs/cricket_train_qwen3_warmup.yaml # warmup
|
| 17 |
+
python train.py train --config configs/cricket_train_qwen3.yaml # main 5-over
|
| 18 |
+
|
| 19 |
+
Legacy Qwen3.5 configs live in configs/extras/.
|
| 20 |
"""
|
| 21 |
|
| 22 |
import argparse
|
| 23 |
+
import copy
|
| 24 |
+
import datetime
|
| 25 |
import json
|
| 26 |
import os
|
| 27 |
import random
|
| 28 |
import re
|
| 29 |
import sys
|
| 30 |
+
import threading
|
| 31 |
import time
|
| 32 |
from pathlib import Path
|
| 33 |
from typing import Any
|
|
|
|
| 38 |
try:
|
| 39 |
import torch
|
| 40 |
from datasets import Dataset
|
| 41 |
+
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
|
| 42 |
from trl import GRPOConfig, GRPOTrainer
|
| 43 |
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
| 44 |
_TRAIN_IMPORTS_AVAILABLE = True
|
| 45 |
except ImportError:
|
| 46 |
torch = None
|
| 47 |
Dataset = None
|
| 48 |
+
LoraConfig = None
|
| 49 |
+
get_peft_model = None
|
| 50 |
+
prepare_model_for_kbit_training = None
|
| 51 |
GRPOConfig = None
|
| 52 |
GRPOTrainer = None
|
| 53 |
AutoModelForCausalLM = None
|
|
|
|
| 59 |
from server.cricket_environment import CricketEnvironment
|
| 60 |
from server.coherence_grader import aggression_match, phase_appropriate, rationale_specificity
|
| 61 |
from server.markov_engine import SHOT_AGGRESSION
|
| 62 |
+
from server.player_roster import build_playing_xi, load_team_roster
|
| 63 |
from models import CricketAction
|
| 64 |
from config_yaml import get_game_constants, get_reward_weights
|
| 65 |
except ImportError:
|
| 66 |
from cricket_captain.server.cricket_environment import CricketEnvironment
|
| 67 |
from cricket_captain.server.coherence_grader import aggression_match, phase_appropriate, rationale_specificity
|
| 68 |
from cricket_captain.server.markov_engine import SHOT_AGGRESSION
|
| 69 |
+
from cricket_captain.server.player_roster import build_playing_xi, load_team_roster
|
| 70 |
from cricket_captain.models import CricketAction
|
| 71 |
from cricket_captain.config_yaml import get_game_constants, get_reward_weights
|
| 72 |
|
|
|
|
| 118 |
# Per-turn reward components (all stateless) #
|
| 119 |
# ------------------------------------------------------------------ #
|
| 120 |
|
| 121 |
+
_XML_FN_RE = re.compile(r"<function\s*=?\s*([^>\s]+)\s*>", re.IGNORECASE)
|
| 122 |
+
_XML_PARAM_RE = re.compile(r"<parameter\s*=\s*([^>\s]+)\s*>(.*?)</parameter>", re.IGNORECASE | re.DOTALL)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
def _parse_completion(raw: str) -> dict | None:
|
| 126 |
+
"""Parse a tool-call from the raw completion into our canonical {tool, arguments} dict.
|
| 127 |
+
|
| 128 |
+
Handles four common model output patterns:
|
| 129 |
+
1. Plain JSON (ideal).
|
| 130 |
+
2. Markdown code block (```json ... ```).
|
| 131 |
+
3. Thinking-model preamble: <think>...</think> followed by JSON.
|
| 132 |
+
Qwen3/Qwen3.5 in default mode emits reasoning inside <think> tags;
|
| 133 |
+
we strip everything up to and including the closing </think> tag.
|
| 134 |
+
4. XML function-call format that Qwen3.5 was trained on:
|
| 135 |
+
<function=tool_name><parameter=foo>bar</parameter>...</function>
|
| 136 |
+
Empirically (see logs/run_2026-04-25_21-08-45) every Stage-1 completion
|
| 137 |
+
emitted this XML form instead of JSON — so we extract it as a fallback
|
| 138 |
+
to give GRPO a non-zero gradient before the model has been trained
|
| 139 |
+
onto the JSON contract.
|
| 140 |
+
"""
|
| 141 |
raw = raw.strip()
|
| 142 |
+
|
| 143 |
+
# Strip <think>...</think> preamble emitted by thinking-mode models.
|
| 144 |
+
if "<think>" in raw:
|
| 145 |
+
think_end = raw.rfind("</think>")
|
| 146 |
+
if think_end != -1:
|
| 147 |
+
raw = raw[think_end + len("</think>"):].strip()
|
| 148 |
+
|
| 149 |
if raw.startswith("```"):
|
| 150 |
lines = raw.split("\n")
|
| 151 |
raw = "\n".join(lines[1:-1]) if len(lines) > 2 else raw
|
| 152 |
+
|
| 153 |
+
# Try parsing the whole string, then fall back to the first {...} block.
|
| 154 |
try:
|
| 155 |
return json.loads(raw)
|
| 156 |
except (json.JSONDecodeError, ValueError):
|
| 157 |
+
pass
|
| 158 |
+
|
| 159 |
+
start = raw.find("{")
|
| 160 |
+
end = raw.rfind("}")
|
| 161 |
+
if start != -1 and end > start:
|
| 162 |
+
try:
|
| 163 |
+
return json.loads(raw[start : end + 1])
|
| 164 |
+
except (json.JSONDecodeError, ValueError):
|
| 165 |
+
pass
|
| 166 |
+
|
| 167 |
+
# XML function-call fallback (Qwen3.5 default tool-call emission style).
|
| 168 |
+
fn_match = _XML_FN_RE.search(raw)
|
| 169 |
+
if fn_match:
|
| 170 |
+
tool = fn_match.group(1).strip().strip("\"'")
|
| 171 |
+
arguments: dict[str, Any] = {}
|
| 172 |
+
for pname, pval in _XML_PARAM_RE.findall(raw):
|
| 173 |
+
v = pval.strip()
|
| 174 |
+
# Coerce numeric/bool literals so downstream validators accept them.
|
| 175 |
+
try:
|
| 176 |
+
arguments[pname] = json.loads(v)
|
| 177 |
+
except (json.JSONDecodeError, ValueError):
|
| 178 |
+
arguments[pname] = v
|
| 179 |
+
return {"tool": tool, "arguments": arguments}
|
| 180 |
+
|
| 181 |
+
return None
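
For reference, a minimal sketch of how this fallback behaves on a Qwen-style completion. It assumes `train.py` is importable from the repo root; the completion text itself is made up for illustration.

```python
# Illustrative only: exercises the <think>-stripping and XML-fallback paths above.
from train import _parse_completion

sample = (
    "<think>Powerplay, rotate strike.</think>\n"
    "<function=play_delivery>"
    "<parameter=shot_intent>single</parameter>"
    "<parameter=explanation>rotate strike</parameter>"
    "</function>"
)
print(_parse_completion(sample))
# {'tool': 'play_delivery', 'arguments': {'shot_intent': 'single', 'explanation': 'rotate strike'}}
```
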
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
# Bounded LRU-ish cache. Each snapshot is a deepcopy of CricketEnvironment
|
| 185 |
+
# (~1 MB) and only used by the LEGACY single-turn r_environment_rollout path,
|
| 186 |
+
# not by the multi-turn environment_factory training path. Cap at 4096 entries
|
| 187 |
+
# (~4 GB worst case) so a long collect_prompts call can't blow up host RAM.
|
| 188 |
+
_PROMPT_ENV_SNAPSHOTS: dict[str, CricketEnvironment] = {}
|
| 189 |
+
_PROMPT_SNAPSHOT_CAP = 4096
|
| 190 |
+
_ENV_REWARD_ROLLOUT_STEPS = 12
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def _remember_prompt(obs_text: str, env: CricketEnvironment) -> str:
|
| 194 |
+
"""Format an observation and keep the exact env state for rollout reward."""
|
| 195 |
+
prompt = _format_prompt(obs_text)
|
| 196 |
+
if len(_PROMPT_ENV_SNAPSHOTS) >= _PROMPT_SNAPSHOT_CAP:
|
| 197 |
+
# Evict oldest insertion (dict preserves insertion order in py3.7+).
|
| 198 |
+
oldest_key = next(iter(_PROMPT_ENV_SNAPSHOTS))
|
| 199 |
+
del _PROMPT_ENV_SNAPSHOTS[oldest_key]
|
| 200 |
+
_PROMPT_ENV_SNAPSHOTS[prompt] = copy.deepcopy(env)
|
| 201 |
+
return prompt
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def r_environment_rollout(prompt: str, completion: str) -> float | None:
|
| 205 |
+
"""Env-backed score for a generated tool call plus short continuation.
|
| 206 |
+
|
| 207 |
+
Returns None when the prompt was not collected from an env snapshot, allowing
|
| 208 |
+
callers to fall back to stateless scoring. Otherwise returns [0, 1], where 0
|
| 209 |
+
means invalid JSON/tool-for-state and higher values reflect the env reward.
|
| 210 |
+
"""
|
| 211 |
+
snapshot = _PROMPT_ENV_SNAPSHOTS.get(prompt)
|
| 212 |
+
if snapshot is None:
|
| 213 |
return None
|
| 214 |
|
| 215 |
+
data = _parse_completion(completion)
|
| 216 |
+
if data is None:
|
| 217 |
+
return 0.0
|
| 218 |
+
|
| 219 |
+
tool = data.get("tool", "")
|
| 220 |
+
args = data.get("arguments", {})
|
| 221 |
+
if not isinstance(args, dict):
|
| 222 |
+
return 0.0
|
| 223 |
+
|
| 224 |
+
env = copy.deepcopy(snapshot)
|
| 225 |
+
if tool not in env._get_available_tools():
|
| 226 |
+
return 0.0
|
| 227 |
+
|
| 228 |
+
try:
|
| 229 |
+
obs = env.step(CricketAction(tool=tool, arguments=args))
|
| 230 |
+
except Exception:
|
| 231 |
+
return 0.0
|
| 232 |
+
|
| 233 |
+
reward = float(obs.reward or 0.0)
|
| 234 |
+
rng = random.Random(hash(prompt + completion) & 0xFFFFFFFF)
|
| 235 |
+
roster = build_playing_xi(getattr(env, "_agent_roster", []))
|
| 236 |
+
for _ in range(_ENV_REWARD_ROLLOUT_STEPS):
|
| 237 |
+
if obs.done:
|
| 238 |
+
break
|
| 239 |
+
action = _random_action(
|
| 240 |
+
rng,
|
| 241 |
+
obs.game_state,
|
| 242 |
+
obs.available_tools,
|
| 243 |
+
obs.current_bowler.get("type") if obs.current_bowler else None,
|
| 244 |
+
roster,
|
| 245 |
+
)
|
| 246 |
+
obs = env.step(action)
|
| 247 |
+
reward += float(obs.reward or 0.0)
|
| 248 |
+
|
| 249 |
+
if obs.done and env.state.reward_breakdown:
|
| 250 |
+
reward += float(env.state.reward_breakdown.get("composite", 0.0))
|
| 251 |
+
|
| 252 |
+
# Map rollout reward into [0,1] while preserving penalties for bad tool choices.
|
| 253 |
+
return round(max(0.0, min(1.0, 0.5 + reward)), 4)
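
The final line recentres the accumulated rollout reward around 0.5 and clips it to [0, 1]; a rough sketch of that mapping (input values are illustrative):

```python
# Illustrative: how the rollout reward sum maps onto the [0, 1] score.
def squash(rollout_sum: float) -> float:
    return round(max(0.0, min(1.0, 0.5 + rollout_sum)), 4)

print(squash(-0.8))  # 0.0  - heavily penalised rollout, floored
print(squash(-0.2))  # 0.3
print(squash(0.35))  # 0.85
print(squash(0.9))   # 1.0  - clipped at the top
```
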
|
| 254 |
+
|
| 255 |
|
| 256 |
def r_validity(completion: str) -> float:
|
| 257 |
+
"""Schema reward for tool calling.
|
| 258 |
+
|
| 259 |
+
Exact env-executable calls receive 1.0. Malformed but parseable JSON gets a
|
| 260 |
+
small shaping signal so early GRPO has non-zero variance before the model has
|
| 261 |
+
learned the strict `{"tool": ..., "arguments": {...}}` contract.
|
| 262 |
+
"""
|
| 263 |
data = _parse_completion(completion)
|
| 264 |
if data is None:
|
| 265 |
return 0.0
|
| 266 |
+
if not isinstance(data, dict):
|
| 267 |
+
return 0.05
|
| 268 |
tool = data.get("tool", "")
|
| 269 |
args = data.get("arguments", {})
|
| 270 |
+
if "tool" not in data or "arguments" not in data:
|
| 271 |
+
return 0.15
|
| 272 |
if tool not in _VALID_TOOLS:
|
| 273 |
+
return 0.25
|
| 274 |
+
if not isinstance(args, dict):
|
| 275 |
+
return 0.35
|
| 276 |
if tool == "play_delivery" and args.get("shot_intent") not in SHOT_AGGRESSION:
|
| 277 |
+
return 0.5
|
| 278 |
if tool == "set_strategy":
|
| 279 |
agg = args.get("aggression")
|
| 280 |
if not isinstance(agg, (int, float)):
|
| 281 |
+
return 0.5
|
| 282 |
if tool == "plan_shot" and args.get("shot_intent") not in SHOT_AGGRESSION:
|
| 283 |
+
return 0.5
|
| 284 |
if tool in {"choose_bowler", "set_bowling_strategy", "plan_delivery"}:
|
| 285 |
if args.get("bowler_type") not in (None, "pace", "spin"):
|
| 286 |
+
return 0.5
|
| 287 |
return 1.0
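
As a quick sanity check of the graded schedule, a minimal sketch (assumes `train.py` is importable from the repo root; the example strings are illustrative):

```python
from train import r_validity

print(r_validity("sorry, I cannot decide"))                      # 0.0  - nothing parseable
print(r_validity('{"tool": "made_up_tool", "arguments": {}}'))   # 0.25 - well-formed JSON, unknown tool
print(r_validity('{"tool": "play_delivery", '
                 '"arguments": {"shot_intent": "single"}}'))     # 1.0  - env-executable call
```
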
|
| 288 |
|
| 289 |
|
|
|
|
| 446 |
Returns reward_fn(prompts, completions, **kwargs) → list[float].
|
| 447 |
|
| 448 |
Weights align with compute_episode_reward in reward_calculator.py:
|
| 449 |
+
r_env — one-step env rollout reward when prompt snapshot exists
|
| 450 |
+
r_behavior — stateless tactical/tool coherence
|
| 451 |
+
r_validity — JSON/tool schema validity
|
|
|
|
|
|
|
|
|
|
|
|
|
| 452 |
"""
|
| 453 |
# Minimum reward for any structurally valid completion — ensures GRPO has a
|
| 454 |
# positive gradient to reinforce valid tool use even for unscored tool types.
|
|
|
|
| 458 |
rewards = []
|
| 459 |
for prompt, completion in zip(prompts, completions):
|
| 460 |
fmt = r_validity(completion)
|
| 461 |
+
env_score = r_environment_rollout(prompt, completion)
|
| 462 |
if curriculum_stage == 1:
|
| 463 |
+
# Length-efficiency penalty: a valid JSON tool call is ≤400 chars.
|
| 464 |
+
# Models with thinking mode (Qwen3/3.5) generate 800-2000 char
|
| 465 |
+
# preambles before the JSON; penalise that verbosity so GRPO
|
| 466 |
+
# learns to emit short, direct JSON. The penalty scales from
|
| 467 |
+
# 1.0 at ≤400 chars to 0.0 at ≥2400 chars (linear).
|
| 468 |
+
_JSON_TARGET = 400
|
| 469 |
+
_RAMP_RANGE = 2000
|
| 470 |
+
length_eff = max(0.0, 1.0 - max(0, len(completion) - _JSON_TARGET) / _RAMP_RANGE)
|
| 471 |
+
base = 0.5 * fmt + 0.5 * (env_score if env_score is not None else fmt)
|
| 472 |
+
rewards.append(round(length_eff * base, 4))
|
| 473 |
continue
|
| 474 |
|
| 475 |
behavior = r_behavior_stateless(prompt, completion)
|
|
|
|
| 480 |
+ _RW.behavior_adaptation * adapt
|
| 481 |
+ _RW.behavior_opponent_awareness * aware
|
| 482 |
)
|
| 483 |
+
if env_score is None:
|
| 484 |
+
reward = _RW.training_behavior * r_beh + _RW.training_validity * fmt
|
| 485 |
+
else:
|
| 486 |
+
reward = 0.45 * env_score + 0.40 * r_beh + 0.15 * fmt
|
| 487 |
# Floor: valid JSON should always beat invalid JSON (reward=0)
|
| 488 |
+
if fmt > 0.0 and (env_score is None or env_score > 0.0):
|
| 489 |
reward = max(reward, _VALID_FLOOR)
|
| 490 |
rewards.append(round(reward, 4))
|
| 491 |
return rewards
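
A rough sketch of the per-turn scores this produces. The 0.45/0.40/0.15 split and the length ramp are taken from the hunk above; the 0.75/0.25 stateless split and the 0.05 floor are assumptions standing in for `_RW.training_behavior`/`_RW.training_validity` and `_VALID_FLOOR`.

```python
# Illustrative arithmetic only - mirrors the Stage-1 and Stage-2 branches above.
def stage1_reward(fmt: float, env_score: float | None, completion_len: int) -> float:
    length_eff = max(0.0, 1.0 - max(0, completion_len - 400) / 2000)
    base = 0.5 * fmt + 0.5 * (env_score if env_score is not None else fmt)
    return round(length_eff * base, 4)

def stage2_reward(fmt: float, r_beh: float, env_score: float | None) -> float:
    if env_score is None:
        reward = 0.75 * r_beh + 0.25 * fmt          # assumed stateless split
    else:
        reward = 0.45 * env_score + 0.40 * r_beh + 0.15 * fmt
    if fmt > 0.0 and (env_score is None or env_score > 0.0):
        reward = max(reward, 0.05)                   # assumed _VALID_FLOOR
    return round(reward, 4)

print(stage1_reward(1.0, None, 300))    # 1.0   - short, valid JSON
print(stage1_reward(1.0, None, 1400))   # 0.5   - same call buried in ~1000 chars of thinking
print(stage2_reward(1.0, 0.5, 0.62))    # 0.629 - env-backed mix
```
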
|
|
|
|
| 500 |
|
| 501 |
SYSTEM_PROMPT = (
|
| 502 |
"You are an expert adaptive cricket captain. Each turn you receive a scorecard "
|
| 503 |
+
"and must choose exactly one cricket captaincy tool call.\n\n"
|
| 504 |
+
"EXECUTE FIRST — strict rule:\n"
|
| 505 |
+
" - The match only progresses when you call `play_delivery` (batting) or\n"
|
| 506 |
+
" `bowl_delivery` (bowling). Every other tool is overhead.\n"
|
| 507 |
+
" - Default action on EVERY ball: call `play_delivery` / `bowl_delivery` with\n"
|
| 508 |
+
" plan args INLINE: e.g. `play_delivery(shot_intent='single', risk='low', rationale='rotate')`\n"
|
| 509 |
+
" or `bowl_delivery(line='outside_off', length='good', delivery_type='stock')`.\n"
|
| 510 |
+
" - Use `set_match_plan` ONCE at the very start of an innings to declare strategy.\n"
|
| 511 |
+
" - Use `set_strategy` / `set_bowling_strategy` ONCE per phase boundary.\n"
|
| 512 |
+
" - DO NOT call `plan_shot` or `plan_delivery` (deprecated) — they only add a\n"
|
| 513 |
+
" wasted turn. Pass the same parameters to play_delivery / bowl_delivery directly.\n"
|
| 514 |
+
" - SKIP `reflect_after_ball` unless the previous ball was a wicket or boundary.\n"
|
| 515 |
+
" - You are scored on MATCH OUTCOMES, not on philosophical depth. Bloated\n"
|
| 516 |
+
" pre-ball planning truncates the episode and you forfeit the result reward.\n\n"
|
| 517 |
+
"THINKING BUDGET — HARD LIMIT:\n"
|
| 518 |
+
" - Per turn: ONE sentence of reasoning, max 30 tokens, inside <think>...</think>.\n"
|
| 519 |
+
" - Do NOT enumerate options, restate the scorecard, or re-derive the plan.\n"
|
| 520 |
+
" - Bad: '<think>This is the first ball, the field is balanced, Kohli is on strike at 0.45 aggression, I should consider...'\n"
|
| 521 |
+
" - Good: '<think>Powerplay, balanced field — single to rotate.</think>'\n"
|
| 522 |
+
" - Token budget per rollout is finite. Long thinking = match truncated = ZERO result reward.\n"
|
| 523 |
+
" - The plan you set at the start carries the strategy; do not re-derive it every ball.\n\n"
|
| 524 |
+
"Emit exactly one tool call wrapped in <tool_call>...</tool_call> XML tags. "
|
| 525 |
+
"Bare JSON without the wrapper is NOT recognized and will end the rollout.\n"
|
| 526 |
+
'Example: <tool_call>{"name": "play_delivery", "arguments": {"shot_intent": "single", "explanation": "rotate strike"}}</tool_call>\n\n'
|
| 527 |
"Available tools:\n"
|
| 528 |
+
" call_toss — Call heads/tails and choose bat/bowl\n"
|
| 529 |
+
" select_batter — Choose batter profile for the match situation\n"
|
| 530 |
+
" set_strategy — Declare batting intent (aggression 0–1, rationale)\n"
|
| 531 |
+
" plan_shot — Pre-ball batting plan\n"
|
| 532 |
+
" play_delivery — Choose a shot and advance the game\n"
|
| 533 |
+
" choose_bowler — Choose bowler profile for the situation\n"
|
| 534 |
+
" set_bowling_strategy — Declare bowling line/length/type/rationale\n"
|
| 535 |
+
" plan_delivery — Pre-ball bowling plan\n"
|
| 536 |
+
" set_field_setting — Aggressive/Balanced/Defensive field\n"
|
| 537 |
+
" bowl_delivery — Execute the delivery\n"
|
| 538 |
+
" reflect_after_ball — Adapt after the previous ball\n"
|
| 539 |
+
" analyze_situation — Query pitch/bowler/field info\n\n"
|
| 540 |
"Shot intents: leave | defensive | single | rotate | boundary | six\n\n"
|
| 541 |
+
"PRIORITIES (in order): (1) finish the match, (2) win the match, (3) score well per ball.\n"
|
| 542 |
+
"Verbose reasoning forfeits all three. Decide fast, act, move on."
|
| 543 |
)
|
| 544 |
|
| 545 |
+
|
| 546 |
+
def get_system_prompt(stage: int = 2) -> str:
|
| 547 |
+
return SYSTEM_PROMPT
|
| 548 |
+
|
| 549 |
_RANDOM_SHOTS = list(SHOT_AGGRESSION.keys())
|
| 550 |
_RANDOM_QUERIES = ["pitch_conditions", "bowler_info", "field_setting", "match_situation"]
|
| 551 |
_RANDOM_ZONES = ["cover", "point", "straight", "midwicket", "square_leg", "fine_leg", "long_on", "long_off"]
|
| 552 |
|
| 553 |
|
| 554 |
+
def _training_roster(agent_team: str | None = None) -> list[dict]:
|
| 555 |
+
team = agent_team or os.environ.get("CRICKET_AGENT_TEAM")
|
| 556 |
+
if not team:
|
| 557 |
+
raise ValueError("Roster-backed training requires --agent-team or CRICKET_AGENT_TEAM.")
|
| 558 |
+
roster = load_team_roster(team)
|
| 559 |
+
if not roster:
|
| 560 |
+
raise ValueError(f"No player profile roster found for agent team '{team}'.")
|
| 561 |
+
playing_xi = build_playing_xi(roster)
|
| 562 |
+
if len(playing_xi) < 11:
|
| 563 |
+
raise ValueError(f"Player profile roster for '{team}' could not produce a playing XI.")
|
| 564 |
+
return playing_xi
|
| 565 |
+
|
| 566 |
+
|
| 567 |
+
def _sample_batter(rng: random.Random, roster: list[dict]) -> dict:
|
| 568 |
+
batters = [p for p in roster if p.get("role") != "bowler"] or roster
|
| 569 |
+
if not batters:
|
| 570 |
+
raise ValueError("Roster-backed training requires at least one batting-capable player.")
|
| 571 |
+
return rng.choice(batters)
|
| 572 |
+
|
| 573 |
+
|
| 574 |
+
def _sample_bowler(rng: random.Random, roster: list[dict]) -> dict:
|
| 575 |
+
bowlers = [p for p in roster if p.get("bowler_type")]
|
| 576 |
+
if not bowlers:
|
| 577 |
+
raise ValueError("Roster-backed training requires at least one bowling-capable player.")
|
| 578 |
+
return rng.choice(bowlers)
|
| 579 |
+
|
| 580 |
+
|
| 581 |
def _random_action(
|
| 582 |
rng: random.Random,
|
| 583 |
game_state: str = "batting",
|
| 584 |
available_tools: list[str] | None = None,
|
| 585 |
current_bowler_type: str | None = None,
|
| 586 |
+
roster: list[dict] | None = None,
|
| 587 |
) -> CricketAction:
|
| 588 |
+
legal = set(available_tools or [])
|
| 589 |
+
|
| 590 |
+
def can(tool: str) -> bool:
|
| 591 |
+
return available_tools is None or tool in legal
|
| 592 |
+
|
| 593 |
+
def match_plan_action() -> CricketAction:
|
| 594 |
+
return CricketAction(tool="set_match_plan", arguments={
|
| 595 |
+
"powerplay_intent": "Use roster strengths to establish tempo while protecting wickets",
|
| 596 |
+
"middle_intent": "Rotate strike, attack favorable matchups, and preserve finishers",
|
| 597 |
+
"death_intent": "Commit boundary options with wickets and target pressure in mind",
|
| 598 |
+
"risk_budget": "Escalate only when phase, target, and wickets justify the risk",
|
| 599 |
+
"trigger_conditions": "Review after wicket clusters, phase changes, target pressure, or repeated boundary/dot outcomes",
|
| 600 |
+
"rationale": "Create a long-horizon plan before choosing ball-by-ball tactics",
|
| 601 |
+
})
|
| 602 |
|
| 603 |
if game_state == "toss":
|
| 604 |
+
return CricketAction(
|
| 605 |
tool="call_toss",
|
| 606 |
arguments={"call": rng.choice(["heads", "tails"]), "decision": rng.choice(["bat", "bowl"])},
|
| 607 |
+
)
|
| 608 |
+
|
| 609 |
+
if can("set_match_plan") and rng.random() < 0.12:
|
| 610 |
+
return match_plan_action()
|
| 611 |
+
|
| 612 |
+
if can("update_match_plan") and rng.random() < 0.08:
|
| 613 |
+
return CricketAction(tool="update_match_plan", arguments={
|
| 614 |
+
"reason": "Adjust plan after phase, score pressure, wickets, and field information",
|
| 615 |
+
"risk_budget": "Shift risk based on current target pressure and wickets in hand",
|
| 616 |
+
})
|
| 617 |
+
|
| 618 |
if game_state == "bowling":
|
| 619 |
choice = rng.random()
|
| 620 |
+
if choice < 0.15 and can("choose_bowler"):
|
| 621 |
+
bowler = _sample_bowler(rng, roster or [])
|
| 622 |
+
return CricketAction(
|
| 623 |
tool="choose_bowler",
|
| 624 |
arguments={
|
| 625 |
+
"name": bowler["name"],
|
| 626 |
+
"bowler_type": bowler["bowler_type"],
|
| 627 |
+
"style": bowler.get("bowl_style", bowler.get("style", "stock")),
|
| 628 |
+
"rationale": "Match roster bowler to phase, batter matchup, and remaining overs",
|
| 629 |
},
|
| 630 |
+
)
|
| 631 |
+
if choice < 0.35 and can("plan_delivery"):
|
| 632 |
+
return CricketAction(
|
| 633 |
tool="plan_delivery",
|
| 634 |
arguments={
|
| 635 |
"bowler_type": current_bowler_type or rng.choice(["pace", "spin"]),
|
|
|
|
| 638 |
"delivery_type": rng.choice(["stock", "yorker", "bouncer", "slower ball"]),
|
| 639 |
"rationale": "Use field and batter style to control scoring zones",
|
| 640 |
},
|
| 641 |
+
)
|
| 642 |
+
if choice < 0.5 and can("set_field_setting"):
|
| 643 |
+
return CricketAction(tool="set_field_setting", arguments={"setting": rng.choice(["Aggressive", "Balanced", "Defensive"])})
|
| 644 |
+
if choice < 0.6 and can("reflect_after_ball"):
|
| 645 |
+
return CricketAction(tool="reflect_after_ball", arguments={"reflection": "Adjust line and field after the last ball"})
|
| 646 |
+
if can("bowl_delivery"):
|
| 647 |
+
return CricketAction(tool="bowl_delivery", arguments={})
|
| 648 |
+
if can("set_bowling_strategy"):
|
| 649 |
+
return CricketAction(tool="set_bowling_strategy", arguments={
|
| 650 |
+
"bowler_type": current_bowler_type or "pace",
|
| 651 |
+
"line": "outside off",
|
| 652 |
+
"length": "good length",
|
| 653 |
+
"delivery_type": "stock",
|
| 654 |
+
"rationale": "Set a legal bowling plan before executing the delivery",
|
| 655 |
+
})
|
| 656 |
+
raise ValueError(f"No legal bowling action available from tools={available_tools}")
|
| 657 |
|
| 658 |
choice = rng.random()
|
| 659 |
+
if choice < 0.15 and can("select_batter"):
|
| 660 |
+
batter = _sample_batter(rng, roster or [])
|
| 661 |
+
return CricketAction(
|
| 662 |
tool="select_batter",
|
| 663 |
arguments={
|
| 664 |
+
"name": batter["name"],
|
| 665 |
+
"style": batter.get("style", "balanced"),
|
| 666 |
+
"aggression": round(float(batter["aggression"]), 2),
|
| 667 |
"rationale": "Select batter based on phase, wickets, and target pressure",
|
| 668 |
},
|
| 669 |
+
)
|
| 670 |
+
if choice < 0.3 and can("set_strategy"):
|
| 671 |
+
return CricketAction(
|
| 672 |
tool="set_strategy",
|
| 673 |
arguments={
|
| 674 |
"phase_intent": rng.choice(["attack", "consolidate", "rotate"]),
|
| 675 |
"aggression": round(rng.uniform(0.1, 0.9), 2),
|
| 676 |
+
"rationale": "Align roster strengths with phase, target pressure, and wickets",
|
| 677 |
},
|
| 678 |
+
)
|
| 679 |
+
if choice < 0.45 and can("plan_shot"):
|
| 680 |
+
return CricketAction(
|
| 681 |
tool="plan_shot",
|
| 682 |
arguments={
|
| 683 |
"shot_intent": rng.choice(_RANDOM_SHOTS),
|
| 684 |
+
"target_area": rng.choice(_RANDOM_ZONES),
|
| 685 |
+
"trajectory": rng.choice(["ground", "lofted", "aerial"]),
|
| 686 |
"risk": rng.choice(["low", "balanced", "high"]),
|
| 687 |
"rationale": "Plan shot against bowler, field, and required rate",
|
| 688 |
},
|
| 689 |
+
)
|
| 690 |
+
if choice < 0.55 and can("analyze_situation"):
|
| 691 |
+
return CricketAction(
|
| 692 |
tool="analyze_situation",
|
| 693 |
arguments={"query_type": rng.choice(_RANDOM_QUERIES)},
|
| 694 |
+
)
|
| 695 |
+
if choice < 0.65 and can("reflect_after_ball"):
|
| 696 |
+
return CricketAction(tool="reflect_after_ball", arguments={"reflection": "Revise risk after previous ball"})
|
| 697 |
+
if can("play_delivery"):
|
| 698 |
+
return CricketAction(
|
| 699 |
+
tool="play_delivery",
|
| 700 |
+
arguments={"shot_intent": rng.choice(_RANDOM_SHOTS), "explanation": "Advance the innings according to the current plan"},
|
| 701 |
+
)
|
| 702 |
+
raise ValueError(f"No legal batting action available from tools={available_tools}")
|
| 703 |
|
| 704 |
|
| 705 |
def collect_prompts(
|
| 706 |
n_prompts: int,
|
| 707 |
task: str = "stage2_full",
|
| 708 |
seed: int = 42,
|
| 709 |
+
agent_team: str | None = None,
|
| 710 |
+
opponent_mode: str = "heuristic",
|
| 711 |
) -> list[str]:
|
| 712 |
"""
|
| 713 |
Collect game-state prompts by running episodes with random actions.
|
| 714 |
Returns a list of prompt strings (one per game state observation).
|
| 715 |
"""
|
| 716 |
rng = random.Random(seed)
|
| 717 |
+
roster = _training_roster(agent_team)
|
| 718 |
+
_PROMPT_ENV_SNAPSHOTS.clear()
|
| 719 |
prompts: list[str] = []
|
| 720 |
ep_count = 0
|
| 721 |
|
| 722 |
while len(prompts) < n_prompts:
|
| 723 |
env = CricketEnvironment()
|
| 724 |
+
obs = env.reset(seed=rng.randint(0, 99999), options={
|
| 725 |
+
"task": task,
|
| 726 |
+
"random_start": True,
|
| 727 |
+
"agent_team": agent_team or os.environ.get("CRICKET_AGENT_TEAM"),
|
| 728 |
+
"opponent_mode": opponent_mode,
|
| 729 |
+
})
|
| 730 |
+
prompts.append(_remember_prompt(obs.prompt_text, env))
|
| 731 |
steps = 0
|
| 732 |
|
| 733 |
while not obs.done and steps < 80:
|
|
|
|
| 736 |
obs.game_state,
|
| 737 |
obs.available_tools,
|
| 738 |
obs.current_bowler.get("type") if obs.current_bowler else None,
|
| 739 |
+
roster,
|
| 740 |
)
|
| 741 |
obs = env.step(action)
|
| 742 |
if not obs.done:
|
| 743 |
+
prompts.append(_remember_prompt(obs.prompt_text, env))
|
| 744 |
steps += 1
|
| 745 |
|
| 746 |
ep_count += 1
|
|
|
|
| 762 |
return Dataset.from_dict({"prompt": prompts})
|
| 763 |
|
| 764 |
|
| 765 |
+
class CricketCaptainToolEnv:
|
| 766 |
+
"""TRL environment wrapper exposing CricketCaptain actions as real tools."""
|
| 767 |
+
|
| 768 |
+
_stats_lock = threading.Lock()
|
| 769 |
+
|
| 770 |
+
def __init__(self):
|
| 771 |
+
self.env = CricketEnvironment()
|
| 772 |
+
self.reward = 0.0
|
| 773 |
+
self.done = False
|
| 774 |
+
self.final_reward = 0.0
|
| 775 |
+
self._episode_seed: int | None = None
|
| 776 |
+
self._episode_started = False
|
| 777 |
+
self._max_tool_iters: int | None = None
|
| 778 |
+
self._episode_had_step = False
|
| 779 |
+
self._episode_logged = False
|
| 780 |
+
|
| 781 |
+
def _maybe_log_episode_end(self, termination_reason: str):
|
| 782 |
+
# Avoid double-logging the same episode (e.g. once at termination, again on reset()).
|
| 783 |
+
if self._episode_logged:
|
| 784 |
+
return
|
| 785 |
+
stats_path = os.environ.get("CRICKET_EPISODE_STATS_PATH")
|
| 786 |
+
if not stats_path:
|
| 787 |
+
return
|
| 788 |
+
|
| 789 |
+
state = getattr(self.env, "state", None)
|
| 790 |
+
|
| 791 |
+
payload = {
|
| 792 |
+
"ts": datetime.datetime.now().isoformat(),
|
| 793 |
+
"seed": self._episode_seed,
|
| 794 |
+
"done": bool(self.done),
|
| 795 |
+
"termination_reason": termination_reason,
|
| 796 |
+
"reward_running_sum": float(self.reward),
|
| 797 |
+
"final_reward_bonus": float(self.final_reward),
|
| 798 |
+
}
|
| 799 |
+
|
| 800 |
+
if state is not None:
|
| 801 |
+
# ---- match config / context ----
|
| 802 |
+
payload["max_overs"] = getattr(state, "max_overs", None)
|
| 803 |
+
payload["opponent_mode"] = getattr(state, "opponent_mode", None)
|
| 804 |
+
payload["agent_team"] = getattr(state, "eval_pack_id", None) or getattr(state, "agent_team", None)
|
| 805 |
+
payload["innings_type"] = getattr(state, "innings_type", None)
|
| 806 |
+
payload["game_state"] = getattr(state, "game_state", None)
|
| 807 |
+
|
| 808 |
+
# ---- match outcome ----
|
| 809 |
+
payload["overs_played"] = getattr(state, "over", None)
|
| 810 |
+
payload["balls_played"] = getattr(state, "ball", None)
|
| 811 |
+
payload["agent_score"] = getattr(state, "total_score", None)
|
| 812 |
+
payload["wickets_lost"] = getattr(state, "wickets_lost", None)
|
| 813 |
+
payload["first_innings_score"] = getattr(state, "first_innings_score", None)
|
| 814 |
+
payload["target"] = getattr(state, "target", None)
|
| 815 |
+
payload["match_result"] = getattr(state, "match_result", None) or None
|
| 816 |
+
|
| 817 |
+
# ---- tool calls ----
|
| 818 |
+
tool_calls_made = int(getattr(state, "tool_calls_made", 0) or 0)
|
| 819 |
+
payload["tool_calls"] = tool_calls_made
|
| 820 |
+
tool_history = getattr(state, "tool_history", None) or []
|
| 821 |
+
tool_breakdown: dict[str, int] = {}
|
| 822 |
+
for c in tool_history:
|
| 823 |
+
t = c.get("tool", "unknown")
|
| 824 |
+
tool_breakdown[t] = tool_breakdown.get(t, 0) + 1
|
| 825 |
+
payload["tool_breakdown"] = tool_breakdown
|
| 826 |
+
payload["analyze_calls"] = len(getattr(state, "analyze_calls", []) or [])
|
| 827 |
+
|
| 828 |
+
# ---- per-turn rubric averages (mean across the full episode) ----
|
| 829 |
+
def _mean(xs):
|
| 830 |
+
xs = list(xs or [])
|
| 831 |
+
return round(sum(xs) / len(xs), 4) if xs else None
|
| 832 |
+
payload["mean_coherence"] = _mean(getattr(state, "coherence_scores", None))
|
| 833 |
+
payload["mean_adaptation"] = _mean(getattr(state, "adaptation_scores", None))
|
| 834 |
+
payload["mean_opponent_awareness"] = _mean(getattr(state, "opponent_awareness_scores", None))
|
| 835 |
+
payload["mean_regret"] = _mean(getattr(state, "regret_scores", None))
|
| 836 |
+
payload["mean_plan_commitment"] = _mean(getattr(state, "plan_commitment_scores", None))
|
| 837 |
+
payload["mean_plan_freshness"] = _mean(getattr(state, "plan_freshness_scores", None))
|
| 838 |
+
payload["strategy_changes"] = getattr(state, "strategy_changes", None)
|
| 839 |
+
payload["plan_version"] = getattr(state, "plan_version", None)
|
| 840 |
+
|
| 841 |
+
# ---- composite + per-rubric reward (already computed in reward_calculator) ----
|
| 842 |
+
if getattr(state, "reward_breakdown", None):
|
| 843 |
+
payload["reward_breakdown"] = dict(state.reward_breakdown)
|
| 844 |
+
|
| 845 |
+
with self._stats_lock:
|
| 846 |
+
with open(stats_path, "a", encoding="utf-8") as f:
|
| 847 |
+
f.write(json.dumps(payload, ensure_ascii=False) + "\n")
|
| 848 |
+
f.flush()
|
| 849 |
+
self._episode_logged = True
|
| 850 |
+
|
| 851 |
+
def reset(self, **kwargs) -> str:
|
| 852 |
+
# If the previous episode ended because the trainer hit the tool-iteration cap,
|
| 853 |
+
# TRL will stop calling tools and then call reset() for the next scenario.
|
| 854 |
+
# In that case, self.done will still be False, but tool_calls_made will be at/near the cap.
|
| 855 |
+
if self._episode_started and self._episode_had_step and not self._episode_logged:
|
| 856 |
+
prev_calls = getattr(getattr(self.env, "state", None), "tool_calls_made", None)
|
| 857 |
+
if self.done:
|
| 858 |
+
self._maybe_log_episode_end("natural")
|
| 859 |
+
elif self._max_tool_iters and prev_calls is not None and int(prev_calls) >= int(self._max_tool_iters):
|
| 860 |
+
self._maybe_log_episode_end("cap")
|
| 861 |
+
# Otherwise: trainer reset the env mid-episode (e.g. generation bookkeeping).
|
| 862 |
+
# Don't log — it would skew the termination distribution.
|
| 863 |
+
|
| 864 |
+
self.reward = 0.0
|
| 865 |
+
self.done = False
|
| 866 |
+
self.final_reward = 0.0
|
| 867 |
+
|
| 868 |
+
self._episode_seed = kwargs.get("seed")
|
| 869 |
+
self._episode_started = True
|
| 870 |
+
self._episode_had_step = False
|
| 871 |
+
self._episode_logged = False
|
| 872 |
+
self._max_tool_iters = (
|
| 873 |
+
int(kwargs["max_tool_calling_iterations"])
|
| 874 |
+
if "max_tool_calling_iterations" in kwargs and kwargs["max_tool_calling_iterations"] is not None
|
| 875 |
+
else (int(os.environ["CRICKET_MAX_TOOL_ITERS"]) if os.environ.get("CRICKET_MAX_TOOL_ITERS") else None)
|
| 876 |
+
)
|
| 877 |
+
|
| 878 |
+
obs = self.env.reset(seed=kwargs.get("seed"), options={
|
| 879 |
+
"task": kwargs.get("task", "stage2_full"),
|
| 880 |
+
"random_start": bool(kwargs.get("random_start", False)),
|
| 881 |
+
"max_overs": int(kwargs.get("max_overs", 5)),
|
| 882 |
+
"eval_pack_id": kwargs.get("eval_pack_id", "adaptive_t20_v1"),
|
| 883 |
+
"opponent_mode": kwargs.get("opponent_mode", "heuristic"),
|
| 884 |
+
"opponent_cache_path": kwargs.get("opponent_cache_path"),
|
| 885 |
+
"agent_team": kwargs.get("agent_team"),
|
| 886 |
+
})
|
| 887 |
+
return obs.prompt_text
|
| 888 |
+
|
| 889 |
+
def _apply(self, tool: str, arguments: dict[str, Any]) -> str:
|
| 890 |
+
if self.done:
|
| 891 |
+
raise ValueError("Match is already finished.")
|
| 892 |
+
self._episode_had_step = True
|
| 893 |
+
available = self.env.state.game_state and self.env._get_available_tools()
|
| 894 |
+
if tool not in available:
|
| 895 |
+
self.reward -= 0.2
|
| 896 |
+
raise ValueError(f"Tool '{tool}' is not available. Available tools: {available}")
|
| 897 |
+
obs = self.env.step(CricketAction(tool=tool, arguments=arguments))
|
| 898 |
+
self.done = bool(obs.done)
|
| 899 |
+
self.reward += float(obs.reward or 0.0)
|
| 900 |
+
if obs.done and self.env.state.reward_breakdown:
|
| 901 |
+
self.final_reward = float(self.env.state.reward_breakdown.get("composite", 0.0))
|
| 902 |
+
self.reward += self.final_reward
|
| 903 |
+
# Log at the time of termination (do not wait for reset()) so the file appears promptly.
|
| 904 |
+
if self.done:
|
| 905 |
+
self._maybe_log_episode_end("natural")
|
| 906 |
+
# Also log cap termination as soon as we hit it, so runs always get stats even if TRL delays reset().
|
| 907 |
+
elif self._max_tool_iters:
|
| 908 |
+
state = getattr(self.env, "state", None)
|
| 909 |
+
calls = getattr(state, "tool_calls_made", None) if state is not None else None
|
| 910 |
+
if calls is not None and int(calls) >= int(self._max_tool_iters):
|
| 911 |
+
self._maybe_log_episode_end("cap")
|
| 912 |
+
return obs.prompt_text
|
| 913 |
+
|
| 914 |
+
def call_toss(self, call: str, decision: str) -> str:
|
| 915 |
+
"""
|
| 916 |
+
Call the coin toss and choose whether to bat or bowl if the toss is won.
|
| 917 |
+
|
| 918 |
+
Args:
|
| 919 |
+
call: Coin call, either "heads" or "tails".
|
| 920 |
+
decision: Preferred decision, either "bat" or "bowl".
|
| 921 |
+
|
| 922 |
+
Returns:
|
| 923 |
+
Updated match observation after the toss.
|
| 924 |
+
"""
|
| 925 |
+
return self._apply("call_toss", {"call": call, "decision": decision})
|
| 926 |
+
|
| 927 |
+
def set_match_plan(
|
| 928 |
+
self,
|
| 929 |
+
powerplay_intent: str,
|
| 930 |
+
middle_intent: str,
|
| 931 |
+
death_intent: str,
|
| 932 |
+
risk_budget: str,
|
| 933 |
+
trigger_conditions: str,
|
| 934 |
+
rationale: str,
|
| 935 |
+
) -> str:
|
| 936 |
+
"""
|
| 937 |
+
Establish the long-horizon plan for the innings.
|
| 938 |
+
|
| 939 |
+
Args:
|
| 940 |
+
powerplay_intent: Plan for overs in the powerplay.
|
| 941 |
+
middle_intent: Plan for middle overs.
|
| 942 |
+
death_intent: Plan for death overs.
|
| 943 |
+
risk_budget: How wickets, overs, and target pressure affect risk.
|
| 944 |
+
trigger_conditions: Match-state changes that should trigger a plan update.
|
| 945 |
+
rationale: Why this plan fits the roster and match situation.
|
| 946 |
+
|
| 947 |
+
Returns:
|
| 948 |
+
Updated match observation after setting the plan.
|
| 949 |
+
"""
|
| 950 |
+
return self._apply("set_match_plan", {
|
| 951 |
+
"powerplay_intent": powerplay_intent,
|
| 952 |
+
"middle_intent": middle_intent,
|
| 953 |
+
"death_intent": death_intent,
|
| 954 |
+
"risk_budget": risk_budget,
|
| 955 |
+
"trigger_conditions": trigger_conditions,
|
| 956 |
+
"rationale": rationale,
|
| 957 |
+
})
|
| 958 |
+
|
| 959 |
+
def update_match_plan(self, reason: str, risk_budget: str = "", trigger_conditions: str = "") -> str:
|
| 960 |
+
"""
|
| 961 |
+
Update the long-horizon plan after a meaningful match-state change.
|
| 962 |
+
|
| 963 |
+
Args:
|
| 964 |
+
reason: Specific reason for updating the plan.
|
| 965 |
+
risk_budget: Optional revised risk budget.
|
| 966 |
+
trigger_conditions: Optional revised trigger conditions.
|
| 967 |
+
|
| 968 |
+
Returns:
|
| 969 |
+
Updated match observation after revising the plan.
|
| 970 |
+
"""
|
| 971 |
+
args = {"reason": reason}
|
| 972 |
+
if risk_budget:
|
| 973 |
+
args["risk_budget"] = risk_budget
|
| 974 |
+
if trigger_conditions:
|
| 975 |
+
args["trigger_conditions"] = trigger_conditions
|
| 976 |
+
return self._apply("update_match_plan", args)
|
| 977 |
+
|
| 978 |
+
def select_batter(self, name: str, style: str, aggression: float, rationale: str) -> str:
|
| 979 |
+
"""
|
| 980 |
+
Select the next batter from the configured roster.
|
| 981 |
+
|
| 982 |
+
Args:
|
| 983 |
+
name: Player name from the team roster.
|
| 984 |
+
style: Batter style from the roster or tactical role.
|
| 985 |
+
aggression: Batting aggression from 0.0 to 1.0.
|
| 986 |
+
rationale: Why this batter fits the phase, wickets, and target.
|
| 987 |
+
|
| 988 |
+
Returns:
|
| 989 |
+
Updated match observation after selecting the batter.
|
| 990 |
+
"""
|
| 991 |
+
return self._apply("select_batter", {
|
| 992 |
+
"name": name,
|
| 993 |
+
"style": style,
|
| 994 |
+
"aggression": aggression,
|
| 995 |
+
"rationale": rationale,
|
| 996 |
+
})
|
| 997 |
+
|
| 998 |
+
def set_strategy(self, phase_intent: str, aggression: float, rationale: str) -> str:
|
| 999 |
+
"""
|
| 1000 |
+
Set batting strategy for the current phase.
|
| 1001 |
+
|
| 1002 |
+
Args:
|
| 1003 |
+
phase_intent: Tactical batting intent for this phase.
|
| 1004 |
+
aggression: Batting aggression from 0.0 to 1.0.
|
| 1005 |
+
rationale: Why the strategy fits score, wickets, target, and field.
|
| 1006 |
+
|
| 1007 |
+
Returns:
|
| 1008 |
+
Updated match observation after setting batting strategy.
|
| 1009 |
+
"""
|
| 1010 |
+
return self._apply("set_strategy", {
|
| 1011 |
+
"phase_intent": phase_intent,
|
| 1012 |
+
"aggression": aggression,
|
| 1013 |
+
"rationale": rationale,
|
| 1014 |
+
})
|
| 1015 |
+
|
| 1016 |
+
def plan_shot(self, shot_intent: str, target_area: str, risk: str, trajectory: str, rationale: str) -> str:
|
| 1017 |
+
"""DEPRECATED — pass these args inline to play_delivery() instead.
|
| 1018 |
+
|
| 1019 |
+
Args:
|
| 1020 |
+
shot_intent: leave|defensive|single|rotate|boundary|six.
|
| 1021 |
+
target_area: scoring area.
|
| 1022 |
+
risk: low|balanced|high.
|
| 1023 |
+
trajectory: ground|lofted|aerial.
|
| 1024 |
+
rationale: one-line reason.
|
| 1025 |
+
|
| 1026 |
+
Returns:
|
| 1027 |
+
Updated observation.
|
| 1028 |
+
"""
|
| 1029 |
+
return self._apply("plan_shot", {
|
| 1030 |
+
"shot_intent": shot_intent,
|
| 1031 |
+
"target_area": target_area,
|
| 1032 |
+
"risk": risk,
|
| 1033 |
+
"trajectory": trajectory,
|
| 1034 |
+
"rationale": rationale,
|
| 1035 |
+
})
|
| 1036 |
+
|
| 1037 |
+
def play_delivery(
|
| 1038 |
+
self,
|
| 1039 |
+
shot_intent: str = "",
|
| 1040 |
+
target_area: str = "",
|
| 1041 |
+
risk: str = "",
|
| 1042 |
+
trajectory: str = "",
|
| 1043 |
+
rationale: str = "",
|
| 1044 |
+
) -> str:
|
| 1045 |
+
"""
|
| 1046 |
+
Execute the ball. Pass shot params inline to skip plan_shot.
|
| 1047 |
+
|
| 1048 |
+
Args:
|
| 1049 |
+
shot_intent: leave|defensive|single|rotate|boundary|six.
|
| 1050 |
+
target_area: optional scoring area.
|
| 1051 |
+
risk: optional low|balanced|high.
|
| 1052 |
+
trajectory: optional ground|lofted|aerial.
|
| 1053 |
+
rationale: optional one-line reason.
|
| 1054 |
+
|
| 1055 |
+
Returns:
|
| 1056 |
+
Updated observation after the ball outcome.
|
| 1057 |
+
"""
|
| 1058 |
+
args: dict[str, Any] = {}
|
| 1059 |
+
if shot_intent: args["shot_intent"] = shot_intent
|
| 1060 |
+
if target_area: args["target_area"] = target_area
|
| 1061 |
+
if risk: args["risk"] = risk
|
| 1062 |
+
if trajectory: args["trajectory"] = trajectory
|
| 1063 |
+
if rationale: args["rationale"] = rationale
|
| 1064 |
+
return self._apply("play_delivery", args)
|
| 1065 |
+
|
| 1066 |
+
def choose_bowler(self, name: str, bowler_type: str, style: str, rationale: str) -> str:
|
| 1067 |
+
"""
|
| 1068 |
+
Choose the bowler at the start of an over from the configured roster.
|
| 1069 |
+
|
| 1070 |
+
Args:
|
| 1071 |
+
name: Player name from the team roster.
|
| 1072 |
+
bowler_type: Bowler type, either pace or spin.
|
| 1073 |
+
style: Bowling style or role.
|
| 1074 |
+
rationale: Why this bowler fits phase, matchup, and remaining overs.
|
| 1075 |
+
|
| 1076 |
+
Returns:
|
| 1077 |
+
Updated match observation after choosing the bowler.
|
| 1078 |
+
"""
|
| 1079 |
+
return self._apply("choose_bowler", {
|
| 1080 |
+
"name": name,
|
| 1081 |
+
"bowler_type": bowler_type,
|
| 1082 |
+
"style": style,
|
| 1083 |
+
"rationale": rationale,
|
| 1084 |
+
})
|
| 1085 |
+
|
| 1086 |
+
def set_bowling_strategy(self, bowler_type: str, line: str, length: str, delivery_type: str, rationale: str) -> str:
|
| 1087 |
+
"""
|
| 1088 |
+
Set bowling strategy for the current bowler.
|
| 1089 |
+
|
| 1090 |
+
Args:
|
| 1091 |
+
bowler_type: Current bowler type, either pace or spin.
|
| 1092 |
+
line: Intended line.
|
| 1093 |
+
length: Intended length.
|
| 1094 |
+
delivery_type: Variation or stock delivery type.
|
| 1095 |
+
rationale: Why this plan fits batter, field, phase, and target.
|
| 1096 |
+
|
| 1097 |
+
Returns:
|
| 1098 |
+
Updated match observation after setting bowling strategy.
|
| 1099 |
+
"""
|
| 1100 |
+
return self._apply("set_bowling_strategy", {
|
| 1101 |
+
"bowler_type": bowler_type,
|
| 1102 |
+
"line": line,
|
| 1103 |
+
"length": length,
|
| 1104 |
+
"delivery_type": delivery_type,
|
| 1105 |
+
"rationale": rationale,
|
| 1106 |
+
})
|
| 1107 |
+
|
| 1108 |
+
def plan_delivery(self, bowler_type: str, line: str, length: str, delivery_type: str, rationale: str) -> str:
|
| 1109 |
+
"""DEPRECATED — pass these args inline to bowl_delivery() instead.
|
| 1110 |
+
|
| 1111 |
+
Args:
|
| 1112 |
+
bowler_type: pace|spin.
|
| 1113 |
+
line: line.
|
| 1114 |
+
length: length.
|
| 1115 |
+
delivery_type: variation or stock.
|
| 1116 |
+
rationale: one-line reason.
|
| 1117 |
+
|
| 1118 |
+
Returns:
|
| 1119 |
+
Updated observation.
|
| 1120 |
+
"""
|
| 1121 |
+
return self._apply("plan_delivery", {
|
| 1122 |
+
"bowler_type": bowler_type,
|
| 1123 |
+
"line": line,
|
| 1124 |
+
"length": length,
|
| 1125 |
+
"delivery_type": delivery_type,
|
| 1126 |
+
"rationale": rationale,
|
| 1127 |
+
})
|
| 1128 |
+
|
| 1129 |
+
def set_field_setting(self, setting: str) -> str:
|
| 1130 |
+
"""
|
| 1131 |
+
Set the field preset.
|
| 1132 |
+
|
| 1133 |
+
Args:
|
| 1134 |
+
setting: One of Aggressive, Balanced, or Defensive.
|
| 1135 |
+
|
| 1136 |
+
Returns:
|
| 1137 |
+
Updated match observation after setting the field.
|
| 1138 |
+
"""
|
| 1139 |
+
return self._apply("set_field_setting", {"setting": setting})
|
| 1140 |
+
|
| 1141 |
+
def bowl_delivery(
|
| 1142 |
+
self,
|
| 1143 |
+
line: str = "",
|
| 1144 |
+
length: str = "",
|
| 1145 |
+
delivery_type: str = "",
|
| 1146 |
+
rationale: str = "",
|
| 1147 |
+
) -> str:
|
| 1148 |
+
"""
|
| 1149 |
+
Execute the delivery. Pass plan params inline to skip plan_delivery.
|
| 1150 |
+
|
| 1151 |
+
Args:
|
| 1152 |
+
line: optional line.
|
| 1153 |
+
length: optional length.
|
| 1154 |
+
delivery_type: optional variation or stock.
|
| 1155 |
+
rationale: optional one-line reason.
|
| 1156 |
+
|
| 1157 |
+
Returns:
|
| 1158 |
+
Updated observation after the ball outcome.
|
| 1159 |
+
"""
|
| 1160 |
+
args: dict[str, Any] = {}
|
| 1161 |
+
if line: args["line"] = line
|
| 1162 |
+
if length: args["length"] = length
|
| 1163 |
+
if delivery_type: args["delivery_type"] = delivery_type
|
| 1164 |
+
if rationale: args["rationale"] = rationale
|
| 1165 |
+
return self._apply("bowl_delivery", args)
|
| 1166 |
+
|
| 1167 |
+
def reflect_after_ball(self, reflection: str) -> str:
|
| 1168 |
+
"""
|
| 1169 |
+
Reflect after the previous ball and adapt the plan.
|
| 1170 |
+
|
| 1171 |
+
Args:
|
| 1172 |
+
reflection: Specific tactical lesson from the previous ball.
|
| 1173 |
+
|
| 1174 |
+
Returns:
|
| 1175 |
+
Updated match observation after recording reflection.
|
| 1176 |
+
"""
|
| 1177 |
+
return self._apply("reflect_after_ball", {"reflection": reflection})
|
| 1178 |
+
|
| 1179 |
+
def analyze_situation(self, query_type: str) -> str:
|
| 1180 |
+
"""
|
| 1181 |
+
Analyze part of the match context.
|
| 1182 |
+
|
| 1183 |
+
Args:
|
| 1184 |
+
query_type: One of pitch_conditions, bowler_info, field_setting, or match_situation.
|
| 1185 |
+
|
| 1186 |
+
Returns:
|
| 1187 |
+
Updated observation containing the analysis result.
|
| 1188 |
+
"""
|
| 1189 |
+
return self._apply("analyze_situation", {"query_type": query_type})
|
| 1190 |
+
|
| 1191 |
+
|
| 1192 |
+
def build_agent_dataset(n_examples: int, args) -> Dataset:
|
| 1193 |
+
if Dataset is None:
|
| 1194 |
+
raise ImportError("datasets is required for training. Install with: pip install '.[train]'")
|
| 1195 |
+
rows = []
|
| 1196 |
+
rng = random.Random(args.seed)
|
| 1197 |
+
stage_prompt = get_system_prompt(args.stage)
|
| 1198 |
+
# Curriculum distribution. If --max-overs is set, use it as a fixed format.
|
| 1199 |
+
# Otherwise sample per-scenario from a T2-heavy distribution that tapers to T5.
|
| 1200 |
+
# Rationale: T2 episodes (~72 tool calls) actually COMPLETE within our token
|
| 1201 |
+
# budget so r_result fires; T5 episodes (~180) train the model on its
|
| 1202 |
+
# eval distribution. Heavy weight on short formats early so the policy
|
| 1203 |
+
# escapes the "planning loop" before tackling longer matches.
|
| 1204 |
+
overs_distribution = getattr(args, "overs_distribution", None)
|
| 1205 |
+
fixed_overs = args.max_overs if args.max_overs and args.max_overs > 0 else None
|
| 1206 |
+
if fixed_overs is None and not overs_distribution:
|
| 1207 |
+
# default curriculum (5:3:2:1 weighting): ~45% T2, ~27% T3, ~18% T4, ~9% T5 (see the sketch after this function)
|
| 1208 |
+
overs_distribution = [2, 2, 2, 2, 2, 3, 3, 3, 4, 4, 5]
|
| 1209 |
+
for idx in range(n_examples):
|
| 1210 |
+
scenario_overs = fixed_overs if fixed_overs is not None else rng.choice(overs_distribution)
|
| 1211 |
+
rows.append({
|
| 1212 |
+
"prompt": [
|
| 1213 |
+
{"role": "system", "content": stage_prompt},
|
| 1214 |
+
{"role": "user", "content": ""},
|
| 1215 |
+
],
|
| 1216 |
+
"seed": rng.randint(0, 999999),
|
| 1217 |
+
"task": "stage1_format" if args.stage == 1 else "stage2_full",
|
| 1218 |
+
"random_start": False,
|
| 1219 |
+
"max_overs": scenario_overs,
|
| 1220 |
+
"eval_pack_id": args.eval_pack_id,
|
| 1221 |
+
"opponent_mode": args.opponent_mode,
|
| 1222 |
+
"opponent_cache_path": getattr(args, "opponent_cache_path", None),
|
| 1223 |
+
"agent_team": args.agent_team,
|
| 1224 |
+
"scenario_id": idx,
|
| 1225 |
+
})
|
| 1226 |
+
return Dataset.from_list(rows)
|
| 1227 |
+
|
| 1228 |
+
|
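As referenced in the curriculum comment above, the default `overs_distribution` implies the following sampling proportions; a quick standalone tally:

```python
# Sampling proportions implied by the default curriculum list in build_agent_dataset.
from collections import Counter

overs_distribution = [2, 2, 2, 2, 2, 3, 3, 3, 4, 4, 5]
counts = Counter(overs_distribution)
for overs, n in sorted(counts.items()):
    print(f"T{overs}: {n}/{len(overs_distribution)} = {n / len(overs_distribution):.0%}")
# T2: 5/11 = 45%, T3: 3/11 = 27%, T4: 2/11 = 18%, T5: 1/11 = 9%
```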
| 1229 |
+
def environment_reward(environments, **kwargs) -> list[float]:
|
| 1230 |
+
rewards = []
|
| 1231 |
+
# Aggregate metrics across all envs in this gradient step for WandB logging.
|
| 1232 |
+
agg = {
|
| 1233 |
+
"r_result": [], "r_cricket": [], "r_behavior": [], "r_validity": [],
|
| 1234 |
+
"r_coherence": [], "r_adaptation": [], "r_opponent_awareness": [], "r_regret": [],
|
| 1235 |
+
"composite": [], "tool_calls": [], "wickets_lost": [], "agent_score": [],
|
| 1236 |
+
"matches_completed": 0, "n": 0,
|
| 1237 |
+
}
|
| 1238 |
+
tool_freq: dict[str, int] = {}
|
| 1239 |
+
for env in environments:
|
| 1240 |
+
state = env.env.state
|
| 1241 |
+
breakdown = state.reward_breakdown or {}
|
| 1242 |
+
terminal = float(breakdown.get("composite", 0.0))
|
| 1243 |
+
plan_score = (sum(state.plan_commitment_scores) / len(state.plan_commitment_scores)) if state.plan_commitment_scores else 0.0
|
| 1244 |
+
validity = 1.0 - min(1.0, len([c for c in state.tool_history if c.get("tool") == "invalid_json"]) / max(state.step_count, 1))
|
| 1245 |
+
reward = env.reward + terminal + 0.1 * plan_score + 0.05 * validity
|
| 1246 |
+
# Reward clip removed: when rollouts complete naturally, the composite
|
| 1247 |
+
# reward easily saturates [-1, 1], causing GRPO group-std → 0 and
|
| 1248 |
+
# killing the gradient signal. Let GRPO standardize the advantage itself.
|
| 1249 |
+
rewards.append(round(reward, 4))
|
| 1250 |
+
|
| 1251 |
+
# Collect for aggregate logging.
|
| 1252 |
+
agg["n"] += 1
|
| 1253 |
+
if env.done:
|
| 1254 |
+
agg["matches_completed"] += 1
|
| 1255 |
+
for k in ("r_result", "r_cricket", "r_behavior", "r_validity",
|
| 1256 |
+
"r_coherence", "r_adaptation", "r_opponent_awareness",
|
| 1257 |
+
"r_regret", "composite"):
|
| 1258 |
+
v = breakdown.get(k)
|
| 1259 |
+
if v is not None:
|
| 1260 |
+
agg[k].append(float(v))
|
| 1261 |
+
agg["tool_calls"].append(int(getattr(state, "tool_calls_made", 0) or 0))
|
| 1262 |
+
agg["wickets_lost"].append(int(getattr(state, "wickets_lost", 0) or 0))
|
| 1263 |
+
agg["agent_score"].append(int(getattr(state, "total_score", 0) or 0))
|
| 1264 |
+
for c in (state.tool_history or []):
|
| 1265 |
+
t = c.get("tool", "unknown")
|
| 1266 |
+
tool_freq[t] = tool_freq.get(t, 0) + 1
|
| 1267 |
+
|
| 1268 |
+
# WandB log — only if wandb is initialised in this process.
|
| 1269 |
+
try:
|
| 1270 |
+
import wandb
|
| 1271 |
+
if wandb.run is not None and agg["n"] > 0:
|
| 1272 |
+
log_dict: dict[str, float] = {
|
| 1273 |
+
"rollout/n_episodes": agg["n"],
|
| 1274 |
+
"rollout/matches_completed": agg["matches_completed"],
|
| 1275 |
+
"rollout/match_completion_rate": agg["matches_completed"] / agg["n"],
|
| 1276 |
+
}
|
| 1277 |
+
for k in ("r_result", "r_cricket", "r_behavior", "r_validity",
|
| 1278 |
+
"r_coherence", "r_adaptation", "r_opponent_awareness",
|
| 1279 |
+
"r_regret", "composite"):
|
| 1280 |
+
if agg[k]:
|
| 1281 |
+
log_dict[f"reward/{k}_mean"] = sum(agg[k]) / len(agg[k])
|
| 1282 |
+
log_dict[f"reward/{k}_max"] = max(agg[k])
|
| 1283 |
+
log_dict[f"reward/{k}_min"] = min(agg[k])
|
| 1284 |
+
for k in ("tool_calls", "wickets_lost", "agent_score"):
|
| 1285 |
+
if agg[k]:
|
| 1286 |
+
log_dict[f"episode/{k}_mean"] = sum(agg[k]) / len(agg[k])
|
| 1287 |
+
log_dict[f"episode/{k}_max"] = max(agg[k])
|
| 1288 |
+
# Tool usage breakdown — frequency per tool name across this step.
|
| 1289 |
+
total_tools = sum(tool_freq.values()) or 1
|
| 1290 |
+
for t, n in tool_freq.items():
|
| 1291 |
+
log_dict[f"tools/freq_{t}"] = n / total_tools
|
| 1292 |
+
wandb.log(log_dict)
|
| 1293 |
+
except Exception:
|
| 1294 |
+
# Never let logging break training.
|
| 1295 |
+
pass
|
| 1296 |
+
return rewards
|
| 1297 |
+
|
| 1298 |
+
|
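The clip-removal note inside `environment_reward` is easiest to see with numbers: GRPO standardizes rewards within each generation group, so if every rollout is pre-clipped onto the same boundary value, the group standard deviation collapses and all advantages vanish. A dependency-free sketch of that effect (a simplified version of the group standardization, not TRL's exact code):

```python
# Why pre-clipping to [-1, 1] can kill the GRPO signal: group-standardized
# advantages go to zero when every reward in the group hits the same clip boundary.
def group_advantages(rewards, eps=1e-6):
    mean = sum(rewards) / len(rewards)
    var = sum((r - mean) ** 2 for r in rewards) / len(rewards)
    std = var ** 0.5
    return [(r - mean) / (std + eps) for r in rewards]

raw = [1.8, 2.3, 1.2, 2.9]                       # unclipped composite + shaping
clipped = [max(-1.0, min(1.0, r)) for r in raw]  # old clipping behaviour

print(group_advantages(raw))      # distinct advantages -> usable gradient
print(group_advantages(clipped))  # all 1.0 -> std ~ 0 -> advantages ~ 0
```

With unclipped rewards the differences between rollouts survive standardization, which is exactly the signal the trainer needs.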
| 1299 |
+
def generate_sft_examples(out_path: str, n_examples: int = 240, seed: int = 42, agent_team: str | None = None):
|
| 1300 |
"""Stage 0 bootstrap data: valid tool JSON for every tool family."""
|
| 1301 |
rng = random.Random(seed)
|
| 1302 |
+
roster = _training_roster(agent_team)
|
| 1303 |
examples = []
|
| 1304 |
for _ in range(n_examples):
|
| 1305 |
game_state = rng.choice(["toss", "batting", "bowling"])
|
| 1306 |
+
action = _random_action(rng, game_state, roster=roster)
|
| 1307 |
prompt = (
|
| 1308 |
f"{SYSTEM_PROMPT}\n\n"
|
| 1309 |
f"[CricketCaptain] {game_state.upper()} | Example adaptive scenario\n"
|
|
|
|
| 1330 |
# Model loading (plain transformers + bitsandbytes 4-bit) #
|
| 1331 |
# ------------------------------------------------------------------ #
|
| 1332 |
|
| 1333 |
+
def load_model(model_name: str, *, use_vllm: bool = False, resume_adapter_from: str | None = None):
|
| 1334 |
+
"""Load base + LoRA. When use_vllm=True, base is loaded in bf16 (vLLM
|
| 1335 |
+
does not support 4-bit BNB inference); otherwise 4-bit NF4.
|
| 1336 |
+
|
| 1337 |
+
resume_adapter_from: optional path to a PEFT adapter directory (e.g. a previous
|
| 1338 |
+
checkpoint dir). If provided, loads the adapter weights instead of initializing
|
| 1339 |
+
a fresh LoRA. The base model is still loaded from `model_name`. The adapter's
|
| 1340 |
+
LoraConfig is preserved (so you can resume even if r= or alpha= drift between runs)."""
|
| 1341 |
if not _TRAIN_IMPORTS_AVAILABLE:
|
| 1342 |
raise ImportError("Training dependencies are missing. Install with: pip install '.[train]'")
|
| 1343 |
+
print(f"Loading {model_name} … (use_vllm={use_vllm}, dtype={'bf16' if use_vllm else '4-bit NF4'})")
|
|
| 1344 |
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
| 1345 |
if tokenizer.pad_token is None:
|
| 1346 |
tokenizer.pad_token = tokenizer.eos_token
|
| 1347 |
|
| 1348 |
+
try:
|
| 1349 |
+
import flash_attn # noqa: F401
|
| 1350 |
+
attn_impl = "flash_attention_2"
|
| 1351 |
+
except ImportError:
|
| 1352 |
+
attn_impl = "sdpa"
|
| 1353 |
+
|
| 1354 |
+
load_kwargs = dict(
|
| 1355 |
device_map="auto",
|
| 1356 |
trust_remote_code=True,
|
| 1357 |
torch_dtype=torch.bfloat16,
|
| 1358 |
+
attn_implementation=attn_impl,
|
| 1359 |
)
|
| 1360 |
+
if not use_vllm:
|
| 1361 |
+
load_kwargs["quantization_config"] = BitsAndBytesConfig(
|
| 1362 |
+
load_in_4bit=True,
|
| 1363 |
+
bnb_4bit_compute_dtype=torch.bfloat16,
|
| 1364 |
+
bnb_4bit_use_double_quant=True,
|
| 1365 |
+
bnb_4bit_quant_type="nf4",
|
| 1366 |
+
)
|
| 1367 |
+
|
| 1368 |
+
model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)
|
| 1369 |
+
if not use_vllm:
|
| 1370 |
+
model = prepare_model_for_kbit_training(model)
|
| 1371 |
+
|
| 1372 |
+
if resume_adapter_from:
|
| 1373 |
+
# Resume from a previous PEFT adapter checkpoint (e.g. warmup output).
|
| 1374 |
+
# PeftModel.from_pretrained reads the adapter_config.json from the dir,
|
| 1375 |
+
# so any r/alpha/target_modules saved with the warmup run is preserved.
|
| 1376 |
+
from peft import PeftModel
|
| 1377 |
+
adapter_path = Path(resume_adapter_from)
|
| 1378 |
+
if not adapter_path.exists():
|
| 1379 |
+
raise FileNotFoundError(f"resume_adapter_from path does not exist: {adapter_path}")
|
| 1380 |
+
print(f"Resuming LoRA adapter from {adapter_path}")
|
| 1381 |
+
model = PeftModel.from_pretrained(model, str(adapter_path), is_trainable=True)
|
| 1382 |
+
else:
|
| 1383 |
+
lora_cfg = LoraConfig(
|
| 1384 |
+
r=64,
|
| 1385 |
+
lora_alpha=128,
|
| 1386 |
+
lora_dropout=0.05,
|
| 1387 |
+
bias="none",
|
| 1388 |
+
task_type="CAUSAL_LM",
|
| 1389 |
+
target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
|
| 1390 |
+
)
|
| 1391 |
+
model = get_peft_model(model, lora_cfg)
|
| 1392 |
+
|
| 1393 |
print(f"Loaded. Parameters: {model.num_parameters():,}")
|
| 1394 |
+
model.print_trainable_parameters()
|
| 1395 |
return model, tokenizer
|
| 1396 |
|
| 1397 |
|
|
|
|
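`load_model()` can also be exercised on its own when debugging VRAM or adapter-resume behaviour; a brief sketch using only the signature above (the model id and checkpoint path are illustrative):

```python
# Fresh LoRA on a 4-bit NF4 base (default path, no vLLM):
model, tok = load_model("Qwen/Qwen3-4B-Instruct-2507")

# Resume a previously trained adapter on a bf16 base for vLLM-backed rollouts.
# Any directory containing adapter_config.json works; this path is illustrative.
model, tok = load_model(
    "Qwen/Qwen3-4B-Instruct-2507",
    use_vllm=True,
    resume_adapter_from="./checkpoints/stage2_final",
)
```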
| 1402 |
def train(args):
|
| 1403 |
if not _TRAIN_IMPORTS_AVAILABLE:
|
| 1404 |
raise ImportError("Training dependencies are missing. Install with: pip install '.[train]'")
|
| 1405 |
+
if args.opponent_mode == "llm_live":
|
| 1406 |
+
if args.opponent_model:
|
| 1407 |
+
os.environ["CRICKET_OPPONENT_MODEL"] = args.opponent_model
|
| 1408 |
+
if args.opponent_api_base:
|
| 1409 |
+
os.environ["CRICKET_OPPONENT_API_BASE"] = args.opponent_api_base
|
| 1410 |
+
if args.opponent_api_key:
|
| 1411 |
+
os.environ["CRICKET_OPPONENT_API_KEY"] = args.opponent_api_key
|
| 1412 |
task = "stage1_format" if args.stage == 1 else "stage2_full"
|
| 1413 |
+
# CRICKET_CKPT_ROOT lets a side-by-side run write checkpoints to a different
|
| 1414 |
+
# tree (e.g. ./checkpoints_smoke) without trampling an active production run.
|
| 1415 |
+
# Default unchanged: ./checkpoints/.
|
| 1416 |
+
ckpt_root = os.environ.get("CRICKET_CKPT_ROOT", "./checkpoints").rstrip("/")
|
| 1417 |
+
out_dir = f"{ckpt_root}/stage{args.stage}"
|
| 1418 |
+
save_dir = f"{ckpt_root}/stage{args.stage}_final"
|
| 1419 |
+
|
| 1420 |
+
ts = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
| 1421 |
+
log_dir = Path(f"./logs/run_{ts}_stage{args.stage}_{args.opponent_mode}")
|
| 1422 |
+
log_dir.mkdir(parents=True, exist_ok=True)
|
| 1423 |
+
|
| 1424 |
+
# Make episode termination stats available to the environment wrapper.
|
| 1425 |
+
# This lets us distinguish natural terminations from tool-iteration cap truncations.
|
| 1426 |
+
stats_path = log_dir / "episode_stats.jsonl"
|
| 1427 |
+
os.environ["CRICKET_EPISODE_STATS_PATH"] = str(stats_path)
|
| 1428 |
+
os.environ["CRICKET_MAX_TOOL_ITERS"] = str(args.max_tool_calling_iterations)
|
| 1429 |
+
# Create the file immediately so users can find/tail it even before the first termination.
|
| 1430 |
+
stats_path.touch(exist_ok=True)
|
| 1431 |
|
| 1432 |
print(f"\n=== Stage {args.stage} Training ===")
|
| 1433 |
print(f"Task: {task} | Prompts: {args.prompts} | Steps: {args.steps}")
|
| 1434 |
+
print(f"Logs: {log_dir}/ | Checkpoints: {out_dir}/")
|
| 1435 |
+
print(f"max_tool_calling_iterations={args.max_tool_calling_iterations} (full 5-over match needs ~180; 20-over needs ~720)")
|
| 1436 |
+
|
| 1437 |
+
(log_dir / "metadata.json").write_text(json.dumps({
|
| 1438 |
+
"stage": args.stage, "model": args.model, "agent_team": args.agent_team,
|
| 1439 |
+
"max_overs": args.max_overs, "opponent_mode": args.opponent_mode,
|
| 1440 |
+
"prompts": args.prompts, "steps": args.steps,
|
| 1441 |
+
"batch_size": args.batch_size, "grad_accum": args.grad_accum,
|
| 1442 |
+
"num_generations": args.num_generations,
|
| 1443 |
+
"max_completion_length": args.max_completion_length,
|
| 1444 |
+
"max_tool_calling_iterations": args.max_tool_calling_iterations,
|
| 1445 |
+
"logging_steps": args.logging_steps,
|
| 1446 |
+
"timestamp": ts,
|
| 1447 |
+
}, indent=2))
|
| 1448 |
+
|
| 1449 |
+
# Build scenario seeds. TRL's environment_factory performs the actual
|
| 1450 |
+
# multi-turn rollout and tool execution during training.
|
| 1451 |
+
print("\nBuilding environment scenarios …")
|
| 1452 |
+
dataset = build_agent_dataset(args.prompts, args)
|
| 1453 |
+
|
| 1454 |
+
# Load model — bf16 if vLLM is on (vLLM rejects 4-bit BNB) or --bf16-base, else 4-bit NF4.
|
| 1455 |
+
# If resume_from is set, load the LoRA adapter from that path instead of fresh init.
|
| 1456 |
+
bf16_base = getattr(args, "use_vllm", False) or getattr(args, "bf16_base", False)
|
| 1457 |
+
resume_from = getattr(args, "resume_from", None)
|
| 1458 |
+
model, tokenizer = load_model(args.model, use_vllm=bf16_base, resume_adapter_from=resume_from)
|
| 1459 |
|
| 1460 |
# GRPO config
|
| 1461 |
+
#
|
| 1462 |
+
# Qwen3 / Qwen3.5 ship with hybrid thinking ENABLED by default. Empirically
|
| 1463 |
+
# (see logs/run_2026-04-25_21-08-45 completions parquet) every generation
|
| 1464 |
+
# spent ~1200 chars inside <think>...</think> and then emitted XML-style
|
| 1465 |
+
# <function><parameter> tags instead of the JSON tool call we asked for.
|
| 1466 |
+
# That meant 0/32 generations were parseable, _apply() never advanced the
|
| 1467 |
+
# match, and episodes always hit max_tool_calling_iterations before any
|
| 1468 |
+
# innings finished — so r_result (55% of the composite) was never earned.
|
| 1469 |
+
#
|
| 1470 |
+
chat_template_kwargs = {}
|
| 1471 |
+
generation_kwargs = {}
|
| 1472 |
+
|
| 1473 |
+
completion_len = max(args.max_completion_length, 2048)
|
| 1474 |
+
use_vllm = getattr(args, "use_vllm", False)
|
| 1475 |
+
vllm_kwargs = {}
|
| 1476 |
+
if use_vllm:
|
| 1477 |
+
# vllm_model_impl: None (default) → vLLM picks its native class. Use this for
|
| 1478 |
+
# Qwen3-* (Qwen3ForCausalLM is registered, native path with full LoRA support).
|
| 1479 |
+
# Set to "transformers" only for Qwen3.5-* where vLLM has no text-only class
|
| 1480 |
+
# registered and the native path tries to load a vision tower.
|
| 1481 |
+
vllm_kwargs = dict(
|
| 1482 |
+
use_vllm=True,
|
| 1483 |
+
vllm_mode="colocate",
|
| 1484 |
+
vllm_gpu_memory_utilization=getattr(args, "vllm_gpu_memory", 0.5),
|
| 1485 |
+
vllm_max_model_length=completion_len + 2048,
|
| 1486 |
+
)
|
| 1487 |
+
vllm_impl = getattr(args, "vllm_model_impl", None)
|
| 1488 |
+
if vllm_impl:
|
| 1489 |
+
vllm_kwargs["vllm_model_impl"] = vllm_impl
|
| 1490 |
+
|
| 1491 |
+
# Resolve hyperparameters from YAML/CLI with sensible fallbacks.
|
| 1492 |
+
lr = args.learning_rate if getattr(args, "learning_rate", None) is not None \
|
| 1493 |
+
else (2e-5 if args.stage == 1 else 1e-5)
|
| 1494 |
+
grpo_beta = getattr(args, "beta", None)
|
| 1495 |
+
grpo_temp = getattr(args, "temperature", None) or 0.8
|
| 1496 |
+
grpo_top_p = getattr(args, "top_p", None)
|
| 1497 |
+
grad_ckpt = getattr(args, "gradient_checkpointing", None)
|
| 1498 |
+
grad_ckpt_kwargs = None
|
| 1499 |
+
if grad_ckpt and getattr(args, "gradient_checkpointing_use_reentrant", None) is not None:
|
| 1500 |
+
grad_ckpt_kwargs = {"use_reentrant": bool(args.gradient_checkpointing_use_reentrant)}
|
| 1501 |
+
|
| 1502 |
+
optional_cfg = {}
|
| 1503 |
+
if grpo_beta is not None:
|
| 1504 |
+
optional_cfg["beta"] = grpo_beta
|
| 1505 |
+
if grpo_top_p is not None:
|
| 1506 |
+
optional_cfg["top_p"] = grpo_top_p
|
| 1507 |
+
if grad_ckpt is not None:
|
| 1508 |
+
optional_cfg["gradient_checkpointing"] = bool(grad_ckpt)
|
| 1509 |
+
if grad_ckpt_kwargs is not None:
|
| 1510 |
+
optional_cfg["gradient_checkpointing_kwargs"] = grad_ckpt_kwargs
|
| 1511 |
+
if getattr(args, "dataloader_pin_memory", None) is not None:
|
| 1512 |
+
optional_cfg["dataloader_pin_memory"] = bool(args.dataloader_pin_memory)
|
| 1513 |
+
if getattr(args, "dataloader_num_workers", None) is not None:
|
| 1514 |
+
optional_cfg["dataloader_num_workers"] = int(args.dataloader_num_workers)
|
| 1515 |
+
|
| 1516 |
config = GRPOConfig(
|
| 1517 |
output_dir=out_dir,
|
| 1518 |
+
logging_dir=str(log_dir / "tensorboard"),
|
| 1519 |
num_train_epochs=1,
|
| 1520 |
max_steps=args.steps,
|
| 1521 |
per_device_train_batch_size=args.batch_size,
|
| 1522 |
gradient_accumulation_steps=args.grad_accum,
|
| 1523 |
+
learning_rate=lr,
|
| 1524 |
warmup_ratio=0.05,
|
| 1525 |
lr_scheduler_type="cosine",
|
| 1526 |
+
logging_steps=args.logging_steps,
|
| 1527 |
+
save_steps=getattr(args, "save_steps", None) or 10,
|
| 1528 |
+
save_total_limit=getattr(args, "save_total_limit", None) or 20,
|
| 1529 |
bf16=True,
|
| 1530 |
+
max_completion_length=completion_len,
|
|
|
|
| 1531 |
num_generations=args.num_generations,
|
| 1532 |
+
max_tool_calling_iterations=args.max_tool_calling_iterations,
|
| 1533 |
+
temperature=grpo_temp,
|
| 1534 |
+
report_to=args.report_to,
|
| 1535 |
+
run_name=args.run_name,
|
| 1536 |
log_completions=True,
|
| 1537 |
seed=args.seed,
|
| 1538 |
+
chat_template_kwargs=chat_template_kwargs,
|
| 1539 |
+
generation_kwargs=generation_kwargs,
|
| 1540 |
+
**optional_cfg,
|
| 1541 |
+
**vllm_kwargs,
|
| 1542 |
)
|
| 1543 |
|
| 1544 |
+
# TRL's add_response_schema pattern-matches tokenizer.chat_template against
|
| 1545 |
+
# a fixed list and raises "Unrecognized chat template" if no match. Some
|
| 1546 |
+
# newer Qwen3 builds (e.g. Qwen3-4B-Instruct-2507, Aug 2025) ship a
|
| 1547 |
+
# template that differs from TRL's stored string (the Instruct release
|
| 1548 |
+
# dropped the enable_thinking block) — but the tool-call format
|
| 1549 |
+
# (<tool_call>…</tool_call>) is identical, so the appropriate schema still
|
| 1550 |
+
# parses correctly. We assign it manually before GRPOTrainer init; TRL
|
| 1551 |
+
# checks `response_schema is None` first so this is a safe override.
|
| 1552 |
+
if getattr(tokenizer, "response_schema", None) is None:
|
| 1553 |
+
try:
|
| 1554 |
+
from trl.chat_template_utils import qwen3_schema, qwen3_5_schema
|
| 1555 |
+
m = args.model.lower()
|
| 1556 |
+
if "qwen3.5" in m or "qwen3_5" in m:
|
| 1557 |
+
tokenizer.response_schema = qwen3_5_schema
|
| 1558 |
+
print("Set tokenizer.response_schema = qwen3_5_schema (manual override).")
|
| 1559 |
+
elif "qwen3" in m:
|
| 1560 |
+
tokenizer.response_schema = qwen3_schema
|
| 1561 |
+
print("Set tokenizer.response_schema = qwen3_schema (manual override).")
|
| 1562 |
+
except ImportError:
|
| 1563 |
+
pass
|
| 1564 |
|
| 1565 |
trainer = GRPOTrainer(
|
| 1566 |
model=model,
|
| 1567 |
+
reward_funcs=environment_reward,
|
| 1568 |
args=config,
|
| 1569 |
train_dataset=dataset,
|
| 1570 |
processing_class=tokenizer,
|
| 1571 |
+
environment_factory=CricketCaptainToolEnv,
|
| 1572 |
)
|
| 1573 |
|
| 1574 |
print(f"\nStarting training ({args.steps} steps, {len(dataset)} prompts) …")
|
|
|
|
| 1604 |
|
| 1605 |
for ep in range(args.eval_episodes):
|
| 1606 |
env = CricketEnvironment()
|
| 1607 |
+
obs = env.reset(seed=rng.randint(0, 99999), options={
|
| 1608 |
+
"task": "stage2_full",
|
| 1609 |
+
"random_start": False,
|
| 1610 |
+
"agent_team": args.agent_team,
|
| 1611 |
+
})
|
| 1612 |
steps = 0
|
| 1613 |
|
| 1614 |
while not obs.done and steps < 150:
|
|
|
|
| 1619 |
if data:
|
| 1620 |
action = CricketAction(tool=data["tool"], arguments=data.get("arguments", {}))
|
| 1621 |
else:
|
| 1622 |
+
action = CricketAction(tool="invalid_json", arguments={})
|
| 1623 |
|
| 1624 |
obs = env.step(action)
|
| 1625 |
steps += 1
|
|
|
|
| 1650 |
def train_smoke(args):
|
| 1651 |
"""Run short direct-environment training rollouts without loading a model."""
|
| 1652 |
rng = random.Random(args.seed)
|
| 1653 |
+
roster = _training_roster(args.agent_team)
|
| 1654 |
|
| 1655 |
# Auto-create run folder unless --output explicitly given
|
| 1656 |
if args.output:
|
|
|
|
| 1691 |
"eval_pack_id": args.eval_pack_id,
|
| 1692 |
"opponent_mode": args.opponent_mode,
|
| 1693 |
"opponent_cache_path": args.opponent_cache_path,
|
| 1694 |
+
"agent_team": args.agent_team,
|
| 1695 |
})
|
| 1696 |
prompts = [_format_prompt(obs.prompt_text)]
|
| 1697 |
total_reward = 0.0
|
|
|
|
| 1709 |
obs.game_state,
|
| 1710 |
obs.available_tools,
|
| 1711 |
obs.current_bowler.get("type") if obs.current_bowler else None,
|
| 1712 |
+
roster,
|
| 1713 |
)
|
| 1714 |
obs = env.step(action)
|
| 1715 |
step_end = time.perf_counter()
|
|
|
|
| 1802 |
def _apply_yaml_defaults(args, cfg: dict) -> None:
|
| 1803 |
"""Merge YAML config values into args, CLI args take precedence."""
|
| 1804 |
captain = cfg.get("captain", {}) or {}
|
| 1805 |
+
opponent = cfg.get("opponent", {}) or {}
|
| 1806 |
env_cfg = cfg.get("env", {}) or {}
|
| 1807 |
train_cfg = cfg.get("train", {}) or {}
|
| 1808 |
|
|
|
|
| 1810 |
if val is not None and getattr(args, attr, None) is None:
|
| 1811 |
setattr(args, attr, val)
|
| 1812 |
|
| 1813 |
+
if getattr(args, "cmd", None) == "train":
|
| 1814 |
+
_set("model", train_cfg.get("model"))
|
| 1815 |
+
else:
|
| 1816 |
+
_set("model", captain.get("model"))
|
| 1817 |
_set("api_base", captain.get("api_base"))
|
| 1818 |
_set("api_key", os.environ.get(captain.get("api_key_env", "")) or None)
|
| 1819 |
_set("eval_pack_id", env_cfg.get("eval_pack_id"))
|
| 1820 |
+
_set("opponent_mode", opponent.get("mode"))
|
| 1821 |
+
_set("opponent_cache_path", opponent.get("cache_path"))
|
| 1822 |
+
_set("opponent_model", opponent.get("model"))
|
| 1823 |
+
_set("opponent_api_base", opponent.get("api_base"))
|
| 1824 |
+
api_key_env = opponent.get("api_key_env")
|
| 1825 |
+
_set("opponent_api_key", os.environ.get(api_key_env, "") if api_key_env else None)
|
| 1826 |
_set("max_overs", env_cfg.get("max_overs"))
|
| 1827 |
+
_set("agent_team", env_cfg.get("agent_team"))
|
| 1828 |
_set("steps", train_cfg.get("steps"))
|
| 1829 |
_set("prompts", train_cfg.get("prompts"))
|
| 1830 |
_set("batch_size", train_cfg.get("batch_size"))
|
| 1831 |
+
_set("grad_accum", train_cfg.get("grad_accum"))
|
| 1832 |
_set("stage", train_cfg.get("stage"))
|
| 1833 |
+
_set("num_generations", train_cfg.get("num_generations"))
|
| 1834 |
+
_set("max_completion_length", train_cfg.get("max_completion_length"))
|
| 1835 |
+
_set("max_tool_calling_iterations", train_cfg.get("max_tool_calling_iterations"))
|
| 1836 |
+
_set("logging_steps", train_cfg.get("logging_steps"))
|
| 1837 |
+
_set("report_to", train_cfg.get("report_to"))
|
| 1838 |
+
_set("run_name", train_cfg.get("run_name"))
|
| 1839 |
+
_set("learning_rate", train_cfg.get("learning_rate"))
|
| 1840 |
+
_set("beta", train_cfg.get("beta"))
|
| 1841 |
+
_set("temperature", train_cfg.get("temperature"))
|
| 1842 |
+
_set("top_p", train_cfg.get("top_p"))
|
| 1843 |
+
_set("gradient_checkpointing", train_cfg.get("gradient_checkpointing"))
|
| 1844 |
+
_set("gradient_checkpointing_use_reentrant", train_cfg.get("gradient_checkpointing_use_reentrant"))
|
| 1845 |
+
_set("dataloader_pin_memory", train_cfg.get("dataloader_pin_memory"))
|
| 1846 |
+
_set("dataloader_num_workers", train_cfg.get("dataloader_num_workers"))
|
| 1847 |
+
_set("bf16_base", train_cfg.get("bf16_base"))
|
| 1848 |
+
_set("save_steps", train_cfg.get("save_steps"))
|
| 1849 |
+
_set("save_total_limit", train_cfg.get("save_total_limit"))
|
| 1850 |
+
_set("resume_from", train_cfg.get("resume_from"))
|
| 1851 |
+
_set("overs_distribution", train_cfg.get("overs_distribution"))
|
| 1852 |
+
_set("use_vllm", train_cfg.get("use_vllm"))
|
| 1853 |
+
_set("vllm_gpu_memory", train_cfg.get("vllm_gpu_memory"))
|
| 1854 |
+
_set("vllm_model_impl", train_cfg.get("vllm_model_impl"))
|
| 1855 |
|
| 1856 |
|
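Because `_set()` only fills attributes that are still `None`, CLI flags always take precedence over YAML values. A self-contained sketch of that precedence, using only config keys read above (the model id is illustrative):

```python
# CLI-over-YAML precedence: _apply_yaml_defaults fills only attrs left as None.
import argparse

cfg = {
    "train": {"model": "Qwen/Qwen3-4B-Instruct-2507", "steps": 200,
              "learning_rate": 1e-5, "use_vllm": True},
    "env": {"max_overs": 5},
    "opponent": {"mode": "heuristic"},
}
args = argparse.Namespace(cmd="train", model=None, steps=50, learning_rate=None)
_apply_yaml_defaults(args, cfg)
print(args.steps)          # 50     (CLI value wins)
print(args.learning_rate)  # 1e-05  (filled from YAML)
print(args.max_overs)      # 5      (attribute created from YAML)
```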
| 1857 |
def main():
|
|
|
|
| 1867 |
t.add_argument("--prompts", type=int, default=None, help="Game state prompts to collect")
|
| 1868 |
t.add_argument("--steps", type=int, default=None, help="GRPOTrainer max_steps")
|
| 1869 |
t.add_argument("--batch-size", type=int, default=None, dest="batch_size")
|
| 1870 |
+
t.add_argument("--grad-accum", type=int, default=None, dest="grad_accum")
|
| 1871 |
+
t.add_argument("--num-generations", type=int, default=None, dest="num_generations")
|
| 1872 |
+
t.add_argument("--agent-team", default=None, dest="agent_team")
|
| 1873 |
+
t.add_argument("--opponent-mode", default=None, choices=["heuristic", "llm_live", "llm_cached", "cricsheet"], dest="opponent_mode")
|
| 1874 |
+
t.add_argument("--opponent-model", default=None, dest="opponent_model")
|
| 1875 |
+
t.add_argument("--opponent-api-base", default=None, dest="opponent_api_base")
|
| 1876 |
+
t.add_argument("--opponent-api-key", default=None, dest="opponent_api_key")
|
| 1877 |
+
t.add_argument("--max-overs", type=int, default=None, dest="max_overs")
|
| 1878 |
+
t.add_argument("--eval-pack-id", default=None, dest="eval_pack_id")
|
| 1879 |
+
t.add_argument("--opponent-cache-path", default=None, dest="opponent_cache_path")
|
| 1880 |
+
t.add_argument("--max-completion-length", type=int, default=None, dest="max_completion_length")
|
| 1881 |
+
t.add_argument("--max-tool-calling-iterations", type=int, default=None, dest="max_tool_calling_iterations")
|
| 1882 |
+
t.add_argument("--logging-steps", type=int, default=None, dest="logging_steps")
|
| 1883 |
+
t.add_argument("--report-to", default=None, dest="report_to")
|
| 1884 |
+
t.add_argument("--run-name", default=None, dest="run_name")
|
| 1885 |
t.add_argument("--seed", type=int, default=42)
|
| 1886 |
+
t.add_argument("--resume-from", default=None, dest="resume_from",
|
| 1887 |
+
help="Path to a previous LoRA adapter dir (e.g. ./checkpoints/stage2_final). "
|
| 1888 |
+
"When set, the adapter is loaded on top of the base model instead of a fresh init.")
|
| 1889 |
+
t.add_argument("--save-steps", type=int, default=None, dest="save_steps")
|
| 1890 |
+
t.add_argument("--save-total-limit", type=int, default=None, dest="save_total_limit")
|
| 1891 |
+
t.add_argument("--learning-rate", type=float, default=None, dest="learning_rate")
|
| 1892 |
+
t.add_argument("--beta", type=float, default=None, dest="beta",
|
| 1893 |
+
help="GRPO KL coefficient. Lower = more exploration.")
|
| 1894 |
+
t.add_argument("--temperature", type=float, default=None, dest="temperature")
|
| 1895 |
+
t.add_argument("--top-p", type=float, default=None, dest="top_p")
|
| 1896 |
+
t.add_argument("--gradient-checkpointing", action="store_true", dest="gradient_checkpointing", default=None)
|
| 1897 |
+
t.add_argument("--no-gradient-checkpointing", action="store_false", dest="gradient_checkpointing")
|
| 1898 |
+
t.add_argument("--gradient-checkpointing-use-reentrant", action="store_true",
|
| 1899 |
+
dest="gradient_checkpointing_use_reentrant", default=None)
|
| 1900 |
+
t.add_argument("--dataloader-pin-memory", action="store_true", dest="dataloader_pin_memory", default=None)
|
| 1901 |
+
t.add_argument("--dataloader-num-workers", type=int, default=None, dest="dataloader_num_workers")
|
| 1902 |
+
t.add_argument("--use-vllm", action="store_true", dest="use_vllm", default=None,
|
| 1903 |
+
help="Use vLLM-backed rollouts (colocate). Forces bf16 base.")
|
| 1904 |
+
t.add_argument("--bf16-base", action="store_true", dest="bf16_base", default=None,
|
| 1905 |
+
help="Load base model in bf16 instead of 4-bit NF4. Faster matmul on H200 since 4B fits in 8GB.")
|
| 1906 |
+
t.add_argument("--vllm-gpu-memory", type=float, default=0.5, dest="vllm_gpu_memory",
|
| 1907 |
+
help="Fraction of GPU memory reserved for vLLM (colocate). Default 0.5.")
|
| 1908 |
+
t.add_argument("--vllm-model-impl", default=None, dest="vllm_model_impl",
|
| 1909 |
+
choices=["transformers", "vllm"],
|
| 1910 |
+
help="vLLM model backend. None (default) = native vLLM class (e.g. Qwen3ForCausalLM); "
|
| 1911 |
+
"'transformers' = HF transformers backend (workaround for Qwen3.5 — flaky with LoRA).")
|
| 1912 |
|
| 1913 |
# eval
|
| 1914 |
e = sub.add_parser("eval", help="Evaluate a checkpoint")
|
| 1915 |
e.add_argument("--config", default=None)
|
| 1916 |
e.add_argument("--model", default=None)
|
| 1917 |
e.add_argument("--eval-episodes", type=int, default=10, dest="eval_episodes")
|
| 1918 |
+
e.add_argument("--agent-team", default=None, dest="agent_team")
|
| 1919 |
e.add_argument("--seed", type=int, default=0)
|
| 1920 |
|
| 1921 |
# quick test (no GPU needed)
|
|
|
|
| 1930 |
smoke.add_argument("--eval-pack-id", default=None, dest="eval_pack_id")
|
| 1931 |
smoke.add_argument("--opponent-mode", default=None, choices=["heuristic", "llm_live", "llm_cached", "cricsheet"], dest="opponent_mode")
|
| 1932 |
smoke.add_argument("--opponent-cache-path", default=None, dest="opponent_cache_path")
|
| 1933 |
+
smoke.add_argument("--agent-team", default=None, dest="agent_team")
|
| 1934 |
smoke.add_argument("--output", default=None)
|
| 1935 |
smoke.add_argument("--seed", type=int, default=42)
|
| 1936 |
|
| 1937 |
sft = sub.add_parser("sft-data", help="Generate Stage 0 supervised tool-format examples")
|
| 1938 |
sft.add_argument("--output", default="./data/training/tool_sft_examples.jsonl")
|
| 1939 |
sft.add_argument("--examples", type=int, default=240)
|
| 1940 |
+
sft.add_argument("--agent-team", default=None, dest="agent_team")
|
| 1941 |
sft.add_argument("--seed", type=int, default=42)
|
| 1942 |
|
| 1943 |
args = parser.parse_args()
|
|
|
|
| 1955 |
if getattr(args, "stage", None) is None:
|
| 1956 |
args.stage = 1
|
| 1957 |
if getattr(args, "model", None) is None:
|
| 1958 |
+
args.model = "Qwen/Qwen3.5-4B"
|
| 1959 |
if getattr(args, "steps", None) is None:
|
| 1960 |
args.steps = 200
|
| 1961 |
if getattr(args, "prompts", None) is None:
|
| 1962 |
args.prompts = 500
|
| 1963 |
if getattr(args, "batch_size", None) is None:
|
| 1964 |
args.batch_size = 2
|
| 1965 |
+
if getattr(args, "grad_accum", None) is None:
|
| 1966 |
+
args.grad_accum = 4
|
| 1967 |
if getattr(args, "eval_pack_id", None) is None:
|
| 1968 |
args.eval_pack_id = "adaptive_t20_v1"
|
| 1969 |
if getattr(args, "opponent_mode", None) is None:
|
| 1970 |
args.opponent_mode = "llm_live"
|
| 1971 |
if getattr(args, "max_overs", None) is None:
|
| 1972 |
args.max_overs = 5
|
| 1973 |
+
if getattr(args, "agent_team", None) is None:
|
| 1974 |
+
args.agent_team = os.environ.get("CRICKET_AGENT_TEAM")
|
| 1975 |
+
if getattr(args, "max_tool_calling_iterations", None) is None:
|
| 1976 |
+
args.max_tool_calling_iterations = 200
|
| 1977 |
+
if getattr(args, "logging_steps", None) is None:
|
| 1978 |
+
args.logging_steps = 1
|
| 1979 |
+
if getattr(args, "report_to", None) is None:
|
| 1980 |
+
args.report_to = "none"
|
| 1981 |
|
| 1982 |
if args.cmd == "train":
|
| 1983 |
train(args)
|
| 1984 |
elif args.cmd == "eval":
|
| 1985 |
evaluate(args)
|
| 1986 |
elif args.cmd == "test":
|
| 1987 |
+
_smoke_test(args.agent_team, args.opponent_mode)
|
| 1988 |
elif args.cmd == "train-smoke":
|
| 1989 |
train_smoke(args)
|
| 1990 |
elif args.cmd == "sft-data":
|
| 1991 |
+
generate_sft_examples(args.output, args.examples, args.seed, args.agent_team)
|
| 1992 |
else:
|
| 1993 |
parser.print_help()
|
| 1994 |
|
| 1995 |
|
| 1996 |
+
def _smoke_test(agent_team: str | None, opponent_mode: str):
|
| 1997 |
"""Verify reward functions work correctly."""
|
| 1998 |
cases = [
|
| 1999 |
(
|
|
|
|
| 2021 |
]
|
| 2022 |
print("Reward function smoke test:\n")
|
| 2023 |
for prompt, completion, expected in cases:
|
| 2024 |
+
fmt = r_validity(completion)
|
| 2025 |
+
coh = r_behavior_stateless(prompt, completion)
|
| 2026 |
print(f" expected={expected:4s} | fmt={fmt:.0f} | coh={coh:.3f} | {completion[:60]}")
|
| 2027 |
|
| 2028 |
print("\nPrompt collection test (5 prompts):")
|
| 2029 |
+
p = collect_prompts(5, task="stage1_format", seed=1, agent_team=agent_team, opponent_mode=opponent_mode)
|
| 2030 |
for i, pp in enumerate(p):
|
| 2031 |
print(f" [{i}] {pp[:80].strip()} …")
|
| 2032 |
|