pratinavseth committed on
Commit 2fc50a9 · verified · 1 Parent(s): 9431040

sync: today's source updates (XML-only prompt, reward unclip, neg-reward on loss, pinned versions, configs reorg)

.gitignore CHANGED
@@ -14,9 +14,12 @@ checkpoints/
14
  training_curves.png
15
  training_summary.json
16
  wandb/
 
17
  .env
18
  unsloth_compiled_cache/
19
  .ipynb_checkpoints/
20
  .DS_Store
21
  *.log
22
  *.zip
 
 
 
14
  training_curves.png
15
  training_summary.json
16
  wandb/
17
+ logs/
18
  .env
19
  unsloth_compiled_cache/
20
  .ipynb_checkpoints/
21
  .DS_Store
22
  *.log
23
  *.zip
24
+ checkpoints_smoke/
25
+ .venv-qwen3/
PRD.md CHANGED
@@ -128,6 +128,8 @@ Each step returns a `CricketObservation` containing:
128
 
129
  The top-level objective remains long-horizon match success over many simulated matches. Dream11-style reward is auxiliary shaping, not the primary benchmark target.
130
 
 
 
131
  ### 5.4 Curriculum Stages
132
 
133
  | Stage | Episodes | Active Rubrics | Objective |
 
128
 
129
  The top-level objective remains long-horizon match success over many simulated matches. Dream11-style reward is auxiliary shaping, not the primary benchmark target.
130
 
131
+ **Tool budget (operational constraint during play and training):** per over, the environment allows **3 no-fine “overhead” tool calls** among `set_strategy`, `set_bowling_strategy`, `plan_delivery`, `reflect_after_ball`, and `analyze_situation`. Each additional overhead call in that over applies a **−0.04** step reward. `plan_shot`, `set_match_plan`, `update_match_plan`, and ball-advancing tools do **not** count against this limit. Training via `train.py` (TRL GRPO with `CricketEnvironment`) uses the same rule, so the policy learns to ration analysis and re-planning across a full innings without a separate ad-hoc budget in the trainer.
132
+
133
  ### 5.4 Curriculum Stages
134
 
135
  | Stage | Episodes | Active Rubrics | Objective |
README.md CHANGED
@@ -12,12 +12,22 @@ license: mit
12
 
13
  # CricketCaptain-LLM
14
 
15
- **An RL benchmark for adaptive, opponent-aware strategic decision-making.**
16
 
17
- CricketCaptain tests whether language-model agents can plan before a ball, act, observe the result, model the opponent, and revise tactics under changing match pressure. Cricket is the domain: overs, wickets, target pressure, player roles, field settings, and hundreds of tactical decisions per match.
 
 
 
 
 
 
 
 
18
 
19
  The Hugging Face Space exposes the OpenEnv server and a Gradio demo UI at `/web`.
20
 
 
 
21
  ---
22
 
23
  ## The Problem
@@ -36,6 +46,41 @@ CricketCaptain evaluates:
36
 
37
  ---
38
 
39
  ## Read This First
40
 
41
  - [`docs/benchmark_explainer.md`](docs/benchmark_explainer.md): full explanation of the problem statement, OpenEnv architecture, environment loop, rewards, data curation, training, and competition compliance.
@@ -58,18 +103,19 @@ Call the coin and decide whether to bat or bowl first.
58
  ```
59
 
60
  ### 2. Batting (when agent bats)
61
- Choose the batter/strategy, plan the shot, then play each delivery.
62
  ```json
63
- {"tool": "select_batter", "arguments": {"name": "Anchor", "style": "anchor", "aggression": 0.35, "rationale": "Middle overs need wicket preservation and strike rotation."}}
 
64
  {"tool": "set_strategy", "arguments": {"phase_intent": "consolidate", "aggression": 0.35, "rationale": "Middle overs against spin — rotate strike and preserve wickets."}}
65
  {"tool": "plan_shot", "arguments": {"shot_intent": "single", "target_area": "midwicket", "risk": "low", "trajectory": "ground", "rationale": "Field is spread, so take the easy gap."}}
66
  {"tool": "play_delivery", "arguments": {"shot_intent": "single", "explanation": "Working into the gap at mid-wicket."}}
67
  ```
68
 
69
  ### 3. Bowling & Fielding (when agent bowls)
70
- Choose the bowler, set a delivery/field plan, then bowl each delivery against an opponent LLM/heuristic batter.
71
  ```json
72
- {"tool": "choose_bowler", "arguments": {"name": "Death Specialist", "bowler_type": "pace", "style": "yorker", "rationale": "Target the stumps at the death."}}
73
  {"tool": "set_bowling_strategy", "arguments": {"bowler_type": "pace", "delivery_type": "yorker", "line": "stumps", "length": "full"}}
74
  {"tool": "set_field_setting", "arguments": {"setting": "Aggressive"}}
75
  {"tool": "plan_delivery", "arguments": {"bowler_type": "pace", "delivery_type": "yorker", "line": "stumps", "length": "full", "rationale": "Limit swing room against an aggressive finisher."}}
@@ -108,6 +154,8 @@ Query match intel at a small reward cost.
108
  | `bowl_delivery` | Bowling | Bowl the next delivery |
109
  | `reflect_after_ball` | Bat/Bowl | Record post-ball tactical adjustment |
110
  | `analyze_situation` | Any | Query pitch, bowler, field, or match situation |
 
 
111
 
112
  `plan_shot.target_area` is normalized into cricket zones such as `cover`, `point`, `straight`, `midwicket`, `square_leg`, `fine_leg`, `long_on`, and `long_off`. `plan_shot.trajectory` can be `ground`, `lofted`, or `aerial`. Delivery plans normalize line (`outside_off`, `stumps`, `pads`, `wide`), length (`yorker`, `full`, `good`, `short`, `bouncer`), and variation (`stock`, `swing`, `seam`, `slower`, `yorker`, `bouncer`, `off_spin`, `leg_spin`, `googly`).
113
 
@@ -115,22 +163,32 @@ Query match intel at a small reward cost.
115
 
116
  ## Reward Architecture
117
 
 
 
118
  | Rubric | Weight | When | What |
119
  |--------|--------|------|------|
120
- | `r_result` | 55% | Episode end | Match outcome: win/loss, target margin, DLS/par |
121
- | `r_cricket` | 25% | Innings end | Dense Cricket contribution proxy (Dream11-style: runs, wickets, dots, milestones) |
122
- | `r_behavior` | 15% | Every delivery | Plan-action coherence + adaptation + opponent awareness + counterfactual regret |
123
- | `r_validity` | 5% | Every turn | Valid JSON tool call structure (gate/penalty) |
 
 
124
 
125
  `r_tools` is computed and logged for analysis but excluded from the composite — tool discipline is measured through outcome and behavior instead.
126
 
127
- The primary objective is to **win or defend the match over a full long-horizon episode**. `r_cricket` provides dense per-ball feedback so training gets a gradient before the final win/loss result.
128
 
129
- The `r_behavior` bundle (15%) covers: plan-action coherence (50%), strategic adaptation (20%), opponent awareness (20%), counterfactual regret (10%).
 
 
 
 
130
 
131
- **Two-stage curriculum (ToolRL):**
132
- - Stage 1 (episodes 0–100): `r_format` only trains valid JSON
133
- - Stage 2 (episodes 100+): all rubrics trains full-match strategy
 
 
134
 
135
  **Innings-specific scoring:**
136
  - **1st Innings (batting):** Score vs DLS par baseline
@@ -189,7 +247,7 @@ cricket_captain/
189
  ├── inference.py # Random + LLM agent evaluation
190
  ├── client.py # OpenEnv WebSocket client (CricketCaptainEnv)
191
  ├── models.py # GameState, CricketAction, CricketObservation, CricketState
192
- ├── train.py # MT-GRPO training (stateless reward function)
193
  ├── eval.py # Coherence heatmaps, reward curves
194
  ├── scripts/
195
  │ └── curate_transitions.py # Cricsheet → Markov transition table pipeline
@@ -234,9 +292,19 @@ Both the heuristic opponent and the environment's `select_batter` / `choose_bowl
234
 
235
  `server/player_roster.py` loads team profiles from `data/player_profiles/` (10 T20I squads: India, Australia, England, Pakistan, South Africa, New Zealand, West Indies, Sri Lanka, Bangladesh, Afghanistan). When the agent calls `select_batter` or `choose_bowler` with a player name, the roster performs fuzzy lookup (exact → surname → word-overlap) and fills in real aggression, batting/bowling style, and phase strengths.
236
 
237
- ### Tool Budget
 
 
 
 
238
 
239
- The environment enforces a **3-call overhead budget per over**. Overhead tools are `analyze_situation`, `reflect_after_ball`, `plan_delivery`, `set_strategy`, and `set_bowling_strategy`. Each call beyond 3 incurs a **−0.04 reward fine**. `plan_shot` is explicitly excluded from the budget — shot planning is always free. This discourages LLMs from padding with low-information tool calls.
 
 
 
 
 
 
240
 
241
  ### Markov Engine
242
 
@@ -250,81 +318,201 @@ Ball outcomes are sampled from a 5-dimensional transition table keyed by `(over,
250
 
251
  Bowler rotation mirrors real cricket: pace-heavy powerplay (90/10), spin-heavy middle overs (45/55), pace-heavy death (80/20). Each bowler has a 10-over cap before rotation is forced.
252
 
253
- ### Stateless GRPO Reward
254
 
255
- The `reward_fn` passed to `GRPOTrainer` computes rewards purely from `(prompt, completion)` pairs no shared environment state. Strategy and phase are extracted from the rendered prompt text:
256
 
257
- ```python
258
- # Prompt always contains:
259
- # "[CricketCaptain] MIDDLE | FIRST INNINGS"
260
- # "Over 18.2 | Score: 145/4 | ..."
261
- # "Current Strategy: consolidate (aggression=0.30) — Rotate strike against spin..."
262
 
263
- def r_coherence_stateless(prompt: str, completion: str) -> float:
264
- strategy = extract_strategy_from_prompt(prompt)
265
- phase = extract_phase_from_prompt(prompt)
266
- shot = json.loads(completion)["arguments"]["shot_intent"]
267
- return coherence_score(strategy, shot, phase)
268
- ```
269
 
270
  ---
271
 
272
  ## Quickstart
273
 
274
- ### YAML config (recommended)
275
 
276
- Use YAML configs to control **both**:
277
- - **server defaults** (opponent mode/model/cache, eval pack id), and
278
- - **runner defaults** (`inference.py` / `eval.py`: env URL, max overs, captain model/API).
 
 
279
 
280
- Use `configs/default.yaml` when you want both the captain and live opponent to call HF router models. Use `configs/cached_eval.yaml` when you want a live captain model against replayed opponent decisions for reproducible comparison.
281
 
282
  ```bash
283
- # Start server with live HF opponent config
284
- cd cricket_captain
285
- PYTHONPATH=. python server/app.py --port 8001 --config configs/default.yaml
 
 
 
 
 
 
 
 
 
 
286
 
287
- # Run a short HF Gemma captain baseline using config defaults
288
- export CRICKET_CAPTAIN_ENV_URL="ws://localhost:8001"
289
- export HF_TOKEN="hf_..."
290
  python inference.py --config configs/default.yaml --episodes 1
291
  ```
292
 
293
- The default config uses the HF router-compatible model `google/gemma-4-26B-A4B-it` for captain-side inference and live opponent defaults. In `llm_live` mode, the opponent actually calls that model during the run. In `llm_cached` mode, the opponent does **not** call `model`; it replays `cache_path`.
294
 
295
- For fair/reproducible eval:
296
 
 
 
 
297
  ```bash
298
- PYTHONPATH=. python server/app.py --port 8001 --config configs/cached_eval.yaml
299
- export CRICKET_CAPTAIN_ENV_URL="ws://localhost:8001"
300
- export HF_TOKEN="hf_..."
301
- python inference.py --config configs/cached_eval.yaml --episodes 1
302
  ```
303
 
304
  ```bash
305
- # Install
306
- pip install openenv-core>=0.2.2 trl unsloth gradio fastapi uvicorn
 
307
 
308
- # Start environment server
309
- cd cricket_captain
310
- PYTHONPATH=. python server/app.py
311
 
312
- # Set this to your reachable WebSocket endpoint.
313
- # On Lightning, use the public/internal Lightning URL, not localhost from a remote runner.
314
- export CRICKET_CAPTAIN_ENV_URL="ws://<your-lightning-host>/ws"
315
 
316
- # Run random baseline agent (5 episodes)
317
- python inference.py --model random --episodes 5 --verbose --env-url "$CRICKET_CAPTAIN_ENV_URL"
 
318
 
319
- # Play interactively (Gradio UI)
320
- PYTHONPATH=. python server/ui.py # → http://localhost:7860
321
 
322
- # Train (requires GPU + train extras)
323
- python train.py sft-data --output data/training/tool_sft_examples.jsonl
324
- python train.py train --stage 1 --steps 200 --model Qwen/Qwen2.5-7B-Instruct
325
- python train.py train --stage 2 --steps 600 --model ./checkpoints/stage1_final
326
  ```
327
 
328
  ### Lightning / Remote Runtime Notes
329
 
330
  `localhost:8000` only works when the agent process and server process are in the same network namespace. On Lightning, expose the server port and pass the resulting WebSocket URL via:
@@ -351,10 +539,9 @@ For fast iteration, start with short 5-over runs before full 20-over evaluation:
351
 
352
  1. Random baseline with heuristic opponent.
353
  2. Base/untrained LLM baseline.
354
- 3. Optional SFT tool-format warmup.
355
- 4. GRPO Stage 1 for format/tool correctness.
356
- 5. GRPO Stage 2 for adaptive strategy.
357
- 6. Eval with `adaptive_t20_v1` and cached opponent decisions.
358
 
359
  See [`docs/experiment_workflow.md`](docs/experiment_workflow.md) for exact commands and rationale.
360
 
@@ -494,6 +681,30 @@ Still needed for final submission: real trained-vs-baseline plots, HF Space URL,
494
 
495
  ---
496
 
497
  ## Citation
498
 
499
  ```
 
12
 
13
  # CricketCaptain-LLM
14
 
15
+ **An OpenEnv benchmark for long-horizon multi-turn agentic RL: train an LLM to captain a full cricket match.**
16
 
17
+ 🔗 **Live Hugging Face Space:** https://huggingface.co/spaces/pratinavseth/cricket-captain-llm
18
+ 📊 **Training runs (W&B):** https://wandb.ai/ptnv-s-research/huggingface
19
+ 📝 **Mini-blog / video:** _coming soon — placeholder, will link before submission deadline_
20
+
21
+ **Hackathon theme alignment — Theme #2: (Super) Long-Horizon Planning & Instruction Following.** A T20 match is up to **240 legal balls of strategic decision-making across both offensive and defensive roles** — exactly the regime where current LLM agents struggle: deep multi-step reasoning, sparse terminal reward (win/loss arrives 100+ turns after the early decisions that caused it), and recovery from early mistakes (a wicket in over 1 should reshape the plan for over 19).
22
+
23
+ CricketCaptain is a serious test of an LLM's ability to handle **long-horizon planning under sparse, delayed reward** with **opponent modeling** and **multi-action-type credit assignment**. One match = ~180 sequential tool calls across batting AND bowling phases, with 13 state-conditioned tools, a real Markov ball-outcome engine trained on **1.65M cricsheet deliveries**, and a composite reward signal that scores plan coherence, tactical adaptation, opponent awareness, regret, and final match outcome.
24
+
25
+ This is the agentic-RL equivalent of training a coding agent — partial trajectory rollouts, dense intermediate rewards, sparse terminal outcome — but in a domain where the strategy space is genuinely **two-sided** (you don't just write code; you also *anticipate the opposing captain*) and the action distribution is **mixed-discrete-continuous** (categorical tool choice + numeric aggression + free-text rationale).
26
 
27
  The Hugging Face Space exposes the OpenEnv server and a Gradio demo UI at `/web`.
28
 
29
+ **Why this is novel:** No prior agentic-RL benchmark covers strategic two-sided sports captaincy with phase-aware tool gating, real-data outcome simulation, and a composable rubric system. Cricket is also one of the few domains where *the same agent must alternate between offensive and defensive policies* within a single episode — a capability gap most multi-turn agent benchmarks (SWE-bench, WebArena, AgentBench) don't touch.
30
+
31
  ---
32
 
33
  ## The Problem
 
46
 
47
  ---
48
 
49
+ ## Train → Eval Story (the framing)
50
+
51
+ This is **multi-turn agentic RL with partial-trajectory training and full-task eval generalization** — the same pattern coding-agent RL papers (SWE-RL, AgentR) use.
52
+
53
+ | | TRAINING (warmup) | TRAINING (main) | EVALUATION |
54
+ |---|---|---|---|
55
+ | max overs | 2-3 (curriculum) | 5 (end-to-end) | full T20 (20 overs) |
56
+ | steps | 30 | 100 | n/a |
57
+ | token budget per rollout | 16k (≈ 80–180 turns) | 24k (≈ 120–240 turns) | unlimited (full match plays out) |
58
+ | reward signal | composite (`r_result` + `r_cricket` + `r_behavior` + `r_validity`) — see `server/reward_calculator.py` | same | **headline metric: win rate vs heuristic baseline** |
59
+ | what model learns | format mastery + per-state decisions on short formats | full-match strategic depth on 5-over | composes per-state decisions into match-winning trajectories |
60
+ | script | `python train.py train --config configs/cricket_train_qwen3_warmup.yaml` | `python train.py train --config configs/cricket_train_qwen3.yaml` | `compare_eval.py` (baseline + trained, prints comparison) |
61
+
62
+ The training operates on shorter formats than evaluation, but the trained policy generalizes to full matches at inference because it learns good per-state decisions, not specific trajectory lengths. The whole chain runs via `bash scripts/run_warmup_then_main.sh`.
63
+
64
+ ## Results
65
+
66
+ **Live W&B project:** https://wandb.ai/ptnv-s-research/huggingface
67
+
68
+ Training stack: Qwen3-4B-Instruct-2507 + LoRA r=64 + TRL GRPO, vLLM colocate, 1× H200. Warmup is on a 2–3 over curriculum (30 steps); the main run trains on 5-over end-to-end matches (100 steps) and resumes from the warmup adapter.
69
+
70
+ Static plots will land here once the chain completes:
71
+ - **Training reward curve:** [docs/plots/training_reward_over_steps.png](docs/plots/training_reward_over_steps.png)
72
+ - **Per-rubric breakdown:** [docs/plots/per_rubric_breakdown.png](docs/plots/per_rubric_breakdown.png)
73
+ - **Tool-call execution frequency:** [docs/plots/tool_call_frequency.png](docs/plots/tool_call_frequency.png)
74
+ - **Before/after comparison:** [docs/plots/before_after_comparison.png](docs/plots/before_after_comparison.png)
75
+
76
+ ### Key engineering findings (documented in commit history)
77
+
78
+ | Issue surfaced | Fix | Effect |
79
+ |---|---|---|
80
+ | Only ~19% of rollouts reached `done` naturally | Restrict the system prompt to a single `<tool_call>...</tool_call>` XML format (the prompt previously also advertised bare JSON, which TRL's response-schema parser rejects, ending the rollout) | tools/call_frequency 9 → 73; rollouts 5–8× longer; matches actually play out |
81
+ | GRPO group std collapsing to 0 once matches completed | Remove the `[-1, 1]` reward clip — let GRPO standardize the advantage itself | reward std 0.0 → 1.5; gradient signal restored |
82
+ | Composite reward stayed positive even on dominant losses | Add explicit `outcome_bonus = -1.0` for losses (was 0.0); reduce the always-positive `progress_bonus` cap | composite now spans negative AND positive — model has a real reason to win |
83
+
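For reference, the contract behind that first fix is easy to state: a valid completion contains exactly one `<tool_call>...</tool_call>` block wrapping a JSON tool call. The sketch below is illustrative only; the actual parser is `_parse_completion` in `train.py` (reused by `compare_eval.py`), and the regex and helper name here are assumptions.

```python
# Minimal sketch (NOT the actual train.py _parse_completion) of parsing the
# XML-wrapped tool-call format described above. Assumes the completion holds a
# single <tool_call>{...}</tool_call> block with "tool"/"name" and "arguments".
import json
import re

TOOL_CALL_RE = re.compile(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", re.DOTALL)

def parse_tool_call(completion: str) -> dict | None:
    """Return {"tool": ..., "arguments": {...}} or None if unparseable."""
    match = TOOL_CALL_RE.search(completion)
    if match is None:
        return None
    try:
        payload = json.loads(match.group(1))
    except json.JSONDecodeError:
        return None
    tool = payload.get("tool") or payload.get("name")
    if not tool:
        return None
    return {"tool": tool, "arguments": payload.get("arguments") or {}}

# parse_tool_call('<tool_call>{"name": "play_delivery", "arguments": {"shot_intent": "single"}}</tool_call>')
# -> {"tool": "play_delivery", "arguments": {"shot_intent": "single"}}
```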
84
  ## Read This First
85
 
86
  - [`docs/benchmark_explainer.md`](docs/benchmark_explainer.md): full explanation of the problem statement, OpenEnv architecture, environment loop, rewards, data curation, training, and competition compliance.
 
103
  ```
104
 
105
  ### 2. Batting (when agent bats)
106
+ Choose a real roster batter, set/update the long-horizon plan, plan the shot, then play each delivery.
107
  ```json
108
+ {"tool": "set_match_plan", "arguments": {"powerplay_intent": "Use V Kohli and NT Tilak Varma to build a stable platform.", "middle_intent": "Rotate against spin and attack weak matchups.", "death_intent": "Use finishers for boundary options.", "risk_budget": "Escalate only with wickets in hand or target pressure.", "trigger_conditions": "Review after wickets, phase changes, or repeated dots/boundaries.", "rationale": "Roster-aware plan for a short chase."}}
109
+ {"tool": "select_batter", "arguments": {"name": "V Kohli", "style": "balanced", "aggression": 0.45, "rationale": "Reliable top-order batter to control risk early."}}
110
  {"tool": "set_strategy", "arguments": {"phase_intent": "consolidate", "aggression": 0.35, "rationale": "Middle overs against spin — rotate strike and preserve wickets."}}
111
  {"tool": "plan_shot", "arguments": {"shot_intent": "single", "target_area": "midwicket", "risk": "low", "trajectory": "ground", "rationale": "Field is spread, so take the easy gap."}}
112
  {"tool": "play_delivery", "arguments": {"shot_intent": "single", "explanation": "Working into the gap at mid-wicket."}}
113
  ```
114
 
115
  ### 3. Bowling & Fielding (when agent bowls)
116
+ Choose a real roster bowler, set a delivery/field plan, then bowl each delivery against an opponent policy.
117
  ```json
118
+ {"tool": "choose_bowler", "arguments": {"name": "BB Sran", "bowler_type": "pace", "style": "economy", "rationale": "Use a roster pacer in the powerplay with a new ball."}}
119
  {"tool": "set_bowling_strategy", "arguments": {"bowler_type": "pace", "delivery_type": "yorker", "line": "stumps", "length": "full"}}
120
  {"tool": "set_field_setting", "arguments": {"setting": "Aggressive"}}
121
  {"tool": "plan_delivery", "arguments": {"bowler_type": "pace", "delivery_type": "yorker", "line": "stumps", "length": "full", "rationale": "Limit swing room against an aggressive finisher."}}
 
154
  | `bowl_delivery` | Bowling | Bowl the next delivery |
155
  | `reflect_after_ball` | Bat/Bowl | Record post-ball tactical adjustment |
156
  | `analyze_situation` | Any | Query pitch, bowler, field, or match situation |
157
+ | `set_match_plan` | Bat/Bowl | Establish powerplay/middle/death plan, risk budget, and triggers |
158
+ | `update_match_plan` | Bat/Bowl | Revise match plan with a match-state reason |
159
 
160
  `plan_shot.target_area` is normalized into cricket zones such as `cover`, `point`, `straight`, `midwicket`, `square_leg`, `fine_leg`, `long_on`, and `long_off`. `plan_shot.trajectory` can be `ground`, `lofted`, or `aerial`. Delivery plans normalize line (`outside_off`, `stumps`, `pads`, `wide`), length (`yorker`, `full`, `good`, `short`, `bouncer`), and variation (`stock`, `swing`, `seam`, `slower`, `yorker`, `bouncer`, `off_spin`, `leg_spin`, `googly`).
161
 
 
163
 
164
  ## Reward Architecture
165
 
166
+ A **composable 4-rubric composite** following the SWE-RL recipe (60% intermediate / 40% terminal) — chosen because partial-trajectory training (where most episodes truncate before completion) needs gradient signal that actually fires. Putting most weight on the rare terminal reward washes out learning.
167
+
168
  | Rubric | Weight | When | What |
169
  |--------|--------|------|------|
170
+ | `r_cricket` | **45%** | Per ball | Dream11-style proxy — runs, wickets, dots, boundaries, economy, milestones |
171
+ | `r_behavior` | **25%** | Every turn | Coherence (50%) + adaptation (20%) + opponent_awareness (20%) + regret (10%) |
172
+ | `r_result` | **20%** | Innings/episode end | Match outcome: chase progress, defense margin, win bonus, DLS par |
173
+ | `r_validity` | **10%** | Every turn | Valid tool-call structure and legal phase-gated tool use |
174
+
175
+ Plus a **progress bonus** added to `r_result`: `min(0.25, tool_calls_made / 40.0)` — caps at +0.25 once the agent makes ≥10 tool calls. Directly rewards escaping the "planning loop" trap (where the policy maxes overhead tools without ever calling `play_delivery`).
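As a rough illustration of how the table and the progress bonus combine (the authoritative computation lives in `server/reward_calculator.py`; this sketch only applies the weights and formula quoted above):

```python
# Illustrative sketch only. Real logic: server/reward_calculator.py.
# Weights follow the rubric table above; the progress bonus uses the quoted formula.
def composite_reward(r_cricket: float, r_behavior: float, r_result: float,
                     r_validity: float, tool_calls_made: int) -> float:
    progress_bonus = min(0.25, tool_calls_made / 40.0)  # caps at +0.25 after 10 calls
    r_result = r_result + progress_bonus                # bonus is folded into r_result
    return (0.45 * r_cricket
            + 0.25 * r_behavior
            + 0.20 * r_result
            + 0.10 * r_validity)

# A truncated rollout with decent per-ball play but no terminal result:
# composite_reward(r_cricket=0.6, r_behavior=0.4, r_result=0.0,
#                  r_validity=1.0, tool_calls_made=30)  # = 0.52
```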
176
 
177
  `r_tools` is computed and logged for analysis but excluded from the composite — tool discipline is measured through outcome and behavior instead.
178
 
179
+ **Why this weighting works:** in partial-trajectory training every turn produces a reward (validity + behavior, plus the per-ball Dream11 signal whenever a ball is actually bowled); the terminal `r_result` only fires when an episode completes. The weights are calibrated to put gradient on the dense signals that fire most often, mirroring the SWE-RL / coding-agent-RL recipe.
180
 
181
+ **Single-stage training (full composite reward from step 0):**
182
+ Qwen3-4B-Instruct-2507 emits `<tool_call>...</tool_call>` natively, so we skip the legacy
183
+ "format mastery" warm-up and run the full composite reward (`r_result + r_cricket +
184
+ r_behavior + r_validity`) from step 0. The internal `curriculum_stage` field is still
185
+ set to `2` for code-path compatibility — it just means "full reward".
186
 
187
+ **Two-config workflow:**
188
+ - [`configs/cricket_train_qwen3_warmup.yaml`](configs/cricket_train_qwen3_warmup.yaml) — short
189
+ 2-3 over curriculum, 30 steps. Bootstraps the LoRA adapter on a fast format.
190
+ - [`configs/cricket_train_qwen3.yaml`](configs/cricket_train_qwen3.yaml) — 5-over end-to-end,
191
+ 100 steps. `resume_from: ./checkpoints/stage2_final` picks up the warmup adapter.
192
 
193
  **Innings-specific scoring:**
194
  - **1st Innings (batting):** Score vs DLS par baseline
 
247
  ├── inference.py # Random + LLM agent evaluation
248
  ├── client.py # OpenEnv WebSocket client (CricketCaptainEnv)
249
  ├── models.py # GameState, CricketAction, CricketObservation, CricketState
250
+ ├── train.py # TRL GRPO agent training with environment_factory tool calls
251
  ├── eval.py # Coherence heatmaps, reward curves
252
  ├── scripts/
253
  │ └── curate_transitions.py # Cricsheet → Markov transition table pipeline
 
292
 
293
  `server/player_roster.py` loads team profiles from `data/player_profiles/` (10 T20I squads: India, Australia, England, Pakistan, South Africa, New Zealand, West Indies, Sri Lanka, Bangladesh, Afghanistan). When the agent calls `select_batter` or `choose_bowler` with a player name, the roster performs fuzzy lookup (exact → surname → word-overlap) and fills in real aggression, batting/bowling style, and phase strengths.
294
 
295
+ ### Tool budget (per over)
296
+
297
+ The simulator counts **strategic / analysis** tools that do not advance the ball. Constants live in `CricketEnvironment` as `TOOL_BUDGET_PER_OVER` (3) and `TOOL_FINE_PER_EXCESS` (0.04).
298
+
299
+ **Overhead tools (count toward the 3 / over):** `set_strategy`, `set_bowling_strategy`, `plan_delivery`, `reflect_after_ball`, `analyze_situation`.
300
 
301
+ **Not overhead (no per-over counter):** `plan_shot`, `set_match_plan`, `update_match_plan`, `select_batter`, `choose_bowler`, `set_field_setting`, `play_delivery`, `bowl_delivery`, `call_toss`, and other execution tools.
302
+
303
+ The first **three** overhead calls in a given over are free of this fine. Each additional overhead call in that over applies an immediate **−0.04** step reward. The prompt shows `Tool budget: N/3 overhead calls used this over` so the model can learn to ration reflection and re-planning.
304
+
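A minimal sketch of the accounting this implies, assuming the constants named above; the helper is illustrative and not the actual `CricketEnvironment` code:

```python
# Sketch of the per-over overhead fine described above (illustrative names).
TOOL_BUDGET_PER_OVER = 3      # free overhead calls per over
TOOL_FINE_PER_EXCESS = 0.04   # step-reward penalty per extra overhead call

OVERHEAD_TOOLS = {
    "set_strategy", "set_bowling_strategy", "plan_delivery",
    "reflect_after_ball", "analyze_situation",
}

def overhead_fine(tool: str, overhead_calls_this_over: int) -> float:
    """Return the (negative) step-reward adjustment for this tool call."""
    if tool not in OVERHEAD_TOOLS:
        return 0.0                     # plan_shot, play_delivery, etc. are never fined
    if overhead_calls_this_over < TOOL_BUDGET_PER_OVER:
        return 0.0                     # first three overhead calls in the over are free
    return -TOOL_FINE_PER_EXCESS       # every additional one costs -0.04

# The 4th analyze_situation in the same over:
# overhead_fine("analyze_situation", overhead_calls_this_over=3)  # -> -0.04
```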
305
+ ### Tool budget and training
306
+
307
+ `train.py train` rollouts are full environment episodes: the same fines apply on every `step` the policy takes. Over a long match, repeatedly burning the budget (for example, `analyze_situation` or `reflect_after_ball` on most balls) **accumulates** many small penalties and competes with match outcome and behavior rewards. GRPO therefore sees a direct signal to use overhead tools when they change decisions, not as padding. Long-horizon **match plans** (`set_match_plan` / `update_match_plan`) are not charged against this overhead budget, so the agent can state multi-phase intent without spending the 3 “slots” on raw analysis calls.
308
 
309
  ### Markov Engine
310
 
 
318
 
319
  Bowler rotation mirrors real cricket: pace-heavy powerplay (90/10), spin-heavy middle overs (45/55), pace-heavy death (80/20). Each bowler has a 10-over cap before rotation is forced.
320
 
321
+ ### GRPO Agent Training
322
 
323
+ `train.py train` uses TRL `GRPOTrainer` with `environment_factory=CricketCaptainToolEnv`. The trainer creates live `CricketEnvironment` instances, exposes captaincy actions as tool methods, and lets the model interact over multiple tool-calling turns instead of scoring isolated prompt/completion strings.
324
 
325
+ The dataset is a set of seeded cricket scenarios. Each rollout resets the environment with `agent_team`, `opponent_mode`, `max_overs`, and optional eval-pack/cache settings. Rewards come back from the environment after real state transitions, so training sees wickets, targets, role swaps, plan updates, and terminal match results.
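A hedged sketch of what one seeded scenario looks like when turned into a live rollout; the reset options mirror the ones `compare_eval.py` passes, while the scenario dict itself is illustrative:

```python
# Sketch: one seeded scenario -> one live environment rollout.
# Reset options mirror compare_eval.py; the scenario dict is illustrative.
from server.cricket_environment import CricketEnvironment

scenario = {
    "seed": 1234,
    "max_overs": 5,
    "opponent_mode": "heuristic",
    "agent_team": "india",
    "eval_pack_id": "adaptive_t20_v1",
}

env = CricketEnvironment()
obs = env.reset(seed=scenario["seed"], options={
    "task": "stage2_full",
    "random_start": False,
    "max_overs": scenario["max_overs"],
    "eval_pack_id": scenario["eval_pack_id"],
    "opponent_mode": scenario["opponent_mode"],
    "agent_team": scenario["agent_team"],
})
# obs.prompt_text is the first user turn; the policy then alternates tool calls
# (CricketAction) with environment observations until obs.done.
```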
 
 
 
 
326
 
327
+ The training model is configured separately from inference/demo models:
328
+
329
+ - `train.model`: model being optimized with GRPO, currently `Qwen/Qwen3-4B-Instruct-2507` (256k native context, native `Qwen3ForCausalLM` in vLLM, no thinking blocks).
330
+ - `opponent.model`: live opponent model used when `opponent.mode=llm_live`, currently `google/gemma-4-26B-A4B-it`.
331
+ - `captain.model`: inference/evaluation captain model used by `inference.py` and `eval.py`.
 
332
 
333
  ---
334
 
335
  ## Quickstart
336
 
337
+ ### 1. Install (uv)
338
+
339
+ System requirements: Python 3.10+, CUDA 12.x for training. Inference / random baselines / Gradio UI work CPU-only.
340
+
341
+ This project is managed with [uv](https://docs.astral.sh/uv/). Versions in [pyproject.toml](pyproject.toml) are pinned to a known-working set (transformers 5.6.2 + trl 1.2.0 + vllm 0.19.1 + torch 2.10.0) — this is the lowest combination that supports TRL multi-turn `environment_factory` AND vLLM colocate AND transformers v5 chat templates. Earlier vLLM pins (<0.19) force transformers <5 and break multi-turn training.
342
+
343
+ ```bash
344
+ # Clone and enter
345
+ git clone <this-repo> cricket-captain-llm
346
+ cd cricket-captain-llm
347
+
348
+ # Create a venv and install core + training extras + eval plots
349
+ uv venv .venv --python 3.10
350
+ uv pip install --python .venv/bin/python -e ".[train,eval]"
351
+
352
+ # Activate
353
+ source .venv/bin/activate
354
+
355
+ # HuggingFace login (only needed for gated models / model downloads)
356
+ huggingface-cli login
357
+ # or: export HF_TOKEN=hf_...
358
+ ```
359
+
360
+ Inference-only (no GPU) install:
361
+
362
+ ```bash
363
+ uv venv .venv --python 3.10
364
+ uv pip install --python .venv/bin/python -e .
365
+ ```
366
+
367
+ ### 2. YAML config (single source of truth)
368
+
369
+ All commands read defaults from a YAML — pass `--config configs/cricket_train_qwen3.yaml` (or one of the others under `configs/`) and only override what you need on the CLI. Four config groups:
370
+
371
+ | Group | Keys | Used by |
372
+ |---|---|---|
373
+ | `env.*` | `agent_team`, `max_overs`, `eval_pack_id`, `env_url` | server, inference, training |
374
+ | `opponent.*` | `mode`, `model`, `api_base`, `api_key_env`, `cache_path` | server (heuristic / cricsheet / llm_live / llm_cached) |
375
+ | `captain.*` | `model`, `api_base`, `api_key_env` | inference / eval (when using a live captain LLM) |
376
+ | `train.*` | `model`, `resume_from`, `stage`, `prompts`, `steps`, `batch_size`, `grad_accum`, `num_generations`, `max_completion_length`, `max_tool_calling_iterations`, `learning_rate`, `beta`, `temperature`, `top_p`, `gradient_checkpointing`, `gradient_checkpointing_use_reentrant`, `dataloader_pin_memory`, `dataloader_num_workers`, `bf16_base`, `save_steps`, `save_total_limit`, `report_to`, `run_name` | GRPO trainer |
377
 
378
+ - `configs/cricket_train_qwen3_warmup.yaml` **GRPO warmup** (2-3 over curriculum, 30 steps). Run first.
379
+ - `configs/cricket_train_qwen3.yaml` — **GRPO main** (5-over end-to-end, 100 steps). Resumes from warmup adapter via `resume_from:`.
380
+ - `configs/cricket_train_qwen3_smoke.yaml` 2-step infrastructure smoke test.
381
+ - `configs/game_knowledge.yaml` — reward weights and game constants (loaded at import time).
382
+ - `configs/extras/` — legacy Qwen3.5 configs and `default.yaml`, kept for reference.
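The precedence rule in practice: the YAML supplies defaults and an explicit CLI flag wins. A minimal sketch (helper layout is illustrative, not the actual `train.py` / `inference.py` code):

```python
# Sketch of YAML-default / CLI-override precedence (illustrative, not train.py).
import argparse
import yaml  # pyyaml

parser = argparse.ArgumentParser()
parser.add_argument("--config", default="configs/cricket_train_qwen3.yaml")
parser.add_argument("--max-overs", type=int, default=None)  # None means "not overridden"
args = parser.parse_args()

with open(args.config) as f:
    cfg = yaml.safe_load(f)

# An explicit CLI flag wins; otherwise fall back to the YAML default.
max_overs = args.max_overs if args.max_overs is not None else cfg["env"]["max_overs"]
print(f"max_overs = {max_overs}")
```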
383
 
384
+ ### 3. Run the environment server (for inference / Gradio)
385
 
386
  ```bash
387
+ cd cricket-captain-llm
388
+ PYTHONPATH=. python server/app.py --port 8000 --config configs/extras/default.yaml
389
+ ```
390
+
391
+ The server exposes the OpenEnv WebSocket at `ws://localhost:8000/ws` and a Gradio UI at `http://localhost:8000/web`.
392
+
393
+ ### 4. Inference baselines (no training required)
394
+
395
+ ```bash
396
+ export CRICKET_CAPTAIN_ENV_URL="ws://localhost:8000"
397
+
398
+ # Random agent (no API key needed) — fast sanity check
399
+ python inference.py --model random --episodes 5 --opponent-mode heuristic --max-overs 5
400
 
401
+ # Live HF Gemma captain baseline using config defaults
402
+ export HF_TOKEN=hf_...
 
403
  python inference.py --config configs/default.yaml --episodes 1
404
  ```
405
 
406
+ ### 5. GRPO training
407
 
408
+ Training does **not** need the server — it instantiates `CricketEnvironment` directly and runs it through TRL `GRPOTrainer` with `environment_factory`. All rollouts are simulated live (no static dataset).
409
 
410
+ The recommended workflow is **warmup → main run**, both controlled entirely from YAML:
411
+
412
+ **Single-command chain** (warmup → main, auto-resume from warmup adapter):
413
  ```bash
414
+ bash scripts/run_warmup_then_main.sh
415
+ # Logs: /tmp/train_warmup.log (then /tmp/train_main.log on success)
416
+ # Final adapter: ./checkpoints/stage2_final/
 
417
  ```
418
 
419
+ This chain is what we run to produce the trained model. Internally:
420
+
421
+ **Step 1 — Warmup (2-3 over curriculum, 30 steps, ~50–60 min on a single H200):**
422
+ - Curriculum-distributed `max_overs` (heavy on 2-over, tail to 3-over) so episodes
423
+ complete inside the token budget and `r_result` fires reliably.
424
+ - Bootstraps the LoRA adapter from base Qwen3-4B-Instruct-2507 → saves to
425
+ `./checkpoints/stage2_final/`.
426
+
427
+ **Step 2 — Main (5-over end-to-end, 100 steps, ~5–7 hrs):**
428
+ - Resumes the warmup adapter via `resume_from: ./checkpoints/stage2_final`.
429
+ - Trains on full 5-over matches with the `r_result` outcome signal as the dominant gradient driver.
430
+ - Final adapter at `./checkpoints/stage2_final/` (overwritten — that's the deliverable).
431
+
432
+ **Run components individually:**
433
  ```bash
434
+ # Warmup only
435
+ PYTORCH_ALLOC_CONF=expandable_segments:True \
436
+ python train.py train --config configs/cricket_train_qwen3_warmup.yaml
437
 
438
+ # Main only (assumes ./checkpoints/stage2_final/ exists)
439
+ PYTORCH_ALLOC_CONF=expandable_segments:True \
440
+ python train.py train --config configs/cricket_train_qwen3.yaml
441
 
442
+ # Main without resuming (fresh adapter)
443
+ python train.py train --config configs/cricket_train_qwen3.yaml --resume-from ""
444
+ ```
445
 
446
+ The opponent is `heuristic` (rule-based) by default for fast iteration.
447
+ Switch to `mode: llm_live` in `configs/cricket_train_qwen3.yaml` (and set `HF_TOKEN`) to train
448
+ against the live Gemma adversary.
449
 
450
+ ### 6. Evaluating the trained model
 
451
 
452
+ The eval-time story is the headline. Training caps rollouts at the warmup/main token budgets (16k / 24k), so warmup rollouts cover 2–3 overs and main rollouts cover 5-over matches. At inference there's no token cap — full T20 matches play out. This is the same pattern coding-agent RL papers (SWE-RL, AgentR) use: train on partial windows, evaluate on full task completion.
453
+
454
+ ```bash
455
+ # Baseline (untrained Qwen3-4B-Instruct-2507)
456
+ python compare_eval.py --model Qwen/Qwen3-4B-Instruct-2507 \
457
+ --label baseline --episodes 20 --max-overs 5 \
458
+ --output eval_results/baseline.json
459
+
460
+ # Trained (uses LoRA adapter)
461
+ python compare_eval.py --model Qwen/Qwen3-4B-Instruct-2507 \
462
+ --adapter ./checkpoints/stage2_final \
463
+ --label trained --episodes 20 --max-overs 5 \
464
+ --output eval_results/trained.json
465
+
466
+ # Side-by-side comparison table
467
+ python compare_eval.py --compare \
468
+ eval_results/baseline.json \
469
+ eval_results/trained.json
470
  ```
471
 
472
+ Each `--episodes 20` run takes ~30-45 min depending on match length. The comparison
473
+ prints a side-by-side table of: match completion rate, win rate, mean agent score,
474
+ mean wickets lost, mean tool calls, validity rate, and per-rubric reward breakdown.
475
+
476
+ ### 7. Tuning batch size for your GPU
477
+
478
+ The dominant memory consumer during GRPO is **attention prefill of input prompts** (game state ~1300 tokens) for `generation_batch_size = batch_size × grad_accum` simultaneous sequences.
479
+
480
+ | GPU VRAM | flash-attn? | Recommended `batch_size` × `grad_accum` | `max_completion_length` |
481
+ |---|---|---|---|
482
+ | 24 GB (A10/3090) | required | 1 × 4 = 4 | 512 |
483
+ | 46 GB (L40S) | recommended | 4 × 4 = 16 | 512 |
484
+ | 46 GB (L40S) | not installed | 2 × 4 = 8 | 512 |
485
+ | 80 GB (A100/H100) | required | 8 × 4 = 32 | 1024 |
486
+
487
+ Cricket tool calls are short JSON objects (~20–300 tokens), so `max_completion_length: 512` is plenty. Without flash-attn, SDPA allocates `O(seq_len²)` attention matrices — keep batches small or install flash-attn.
488
+
489
+ If you OOM, halve `batch_size` first. If you OOM on prefill specifically (during generation, not gradient), it's the prompt length × generation_batch_size — install flash-attn or shrink `grad_accum`.
490
+
491
+ ### 7. Where to find logs and outputs
492
+
493
+ | Path | What |
494
+ |---|---|
495
+ | stdout | Per-step loss, reward, lr (every `logging_steps=10`) plus full sampled completions |
496
+ | `checkpoints/stage{1,2}/` | HF Trainer state, intermediate LoRA checkpoints (every ~80 steps) |
497
+ | `checkpoints/stage{1,2}_final/` | Final LoRA + tokenizer |
498
+ | `illustrations/exp_*/run_output.txt` | Per-step environment trace from `inference.py` and `train.py train-smoke` |
499
+
500
+ For a tensorboard dashboard, set `train.report_to: tensorboard` in the YAML, then:
501
+ ```bash
502
+ tensorboard --logdir checkpoints/
503
+ ```
504
+
505
+ ### 8. Smoke test (no model load)
506
+
507
+ Verify the environment + opponent + tool budget end-to-end without loading any model:
508
+ ```bash
509
+ PYTHONPATH=. python train.py train-smoke \
510
+ --config configs/extras/default.yaml \
511
+ --matches 1 --max-overs 2 \
512
+ --opponent-mode heuristic
513
+ ```
514
+ This runs one short match with random actions and writes a full step log to `illustrations/exp_*/`.
515
+
516
  ### Lightning / Remote Runtime Notes
517
 
518
  `localhost:8000` only works when the agent process and server process are in the same network namespace. On Lightning, expose the server port and pass the resulting WebSocket URL via:
 
539
 
540
  1. Random baseline with heuristic opponent.
541
  2. Base/untrained LLM baseline.
542
+ 3. GRPO warmup (`configs/cricket_train_qwen3_warmup.yaml`) for format mastery on short matches.
543
+ 4. GRPO main (`configs/cricket_train_qwen3.yaml`) for full 5-over strategic depth, resuming the warmup adapter.
544
+ 5. Eval with `adaptive_t20_v1` and cached opponent decisions.
 
545
 
546
  See [`docs/experiment_workflow.md`](docs/experiment_workflow.md) for exact commands and rationale.
547
 
 
681
 
682
  ---
683
 
684
+ ## Hackathon Submission Materials
685
+
686
+ | Material | Link |
687
+ |---|---|
688
+ | **Live HF Space** (env runs here) | https://huggingface.co/spaces/pratinavseth/cricket-captain-llm |
689
+ | **GitHub repo** | https://github.com/pratinavseth/cricket-captain-llm |
690
+ | **W&B project** (training runs) | https://wandb.ai/ptnv-s-research/huggingface |
691
+ | **Mini-blog (HF blog)** | _placeholder — will be added before submission_ |
692
+ | **Demo video (≤2 min, YouTube)** | _placeholder — will be added before submission_ |
693
+ | **Slide deck** | _optional — TBD_ |
694
+
695
+ ### Submission checklist (per hackathon guidelines)
696
+
697
+ - [x] Use OpenEnv (latest release): `openenv-core[core]>=0.2.2` ([pyproject.toml](pyproject.toml))
698
+ - [x] Working training script (TRL GRPO): [train.py](train.py)
699
+ - [x] Reward + training pipeline coherent: 4-rubric composite (`r_result`, `r_cricket`, `r_behavior`, `r_validity`) with documented signal flow
700
+ - [x] HF Space pushed and discoverable
701
+ - [ ] Loss + reward plots from real run (in progress — generated after main run completes)
702
+ - [ ] Mini-blog or ≤2 min video (in progress)
703
+ - [x] README motivates problem, explains env, links materials
704
+ - [x] OpenEnv compliance: `Environment` base class, valid `openenv.yaml`, no reserved tool names
705
+
706
+ ---
707
+
708
  ## Citation
709
 
710
  ```
compare_eval.py ADDED
@@ -0,0 +1,395 @@
1
+ """
2
+ compare_eval.py — Baseline vs trained head-to-head evaluation.
3
+
4
+ Plays N full cricket matches with the BASELINE model (untrained Qwen3.5-4B)
5
+ and the TRAINED model (Qwen3.5-4B + LoRA adapter from a checkpoint), then
6
+ dumps a comparison table:
7
+
8
+ win_rate, mean_agent_score, mean_opp_score, mean_wickets, match_completion_rate,
9
+ mean_tool_calls_per_episode, validity_rate, plus a few illustrative transcripts.
10
+
11
+ Why this is the right eval for our setup
12
+ ----------------------------------------
13
+ Training uses a token budget per rollout (~4096 tokens, ~16 turns) which truncates
14
+ most matches. At EVAL time we lift that cap entirely — the model gets unlimited
15
+ context and can actually play full matches. This is the same pattern coding-agent
16
+ RL papers use: train on partial windows, eval on full task completion. The trained
17
+ policy generalizes because it learned good per-state decisions, not a specific
18
+ trajectory length.
19
+
20
+ Usage
21
+ -----
22
+ # Baseline (untrained Qwen3.5-4B base)
23
+ python compare_eval.py \\
24
+ --model Qwen/Qwen3.5-4B \\
25
+ --label baseline \\
26
+ --episodes 20 --max-overs 5 \\
27
+ --output eval_results/baseline.json
28
+
29
+ # Trained (warmup + main checkpoint)
30
+ python compare_eval.py \\
31
+ --model Qwen/Qwen3.5-4B \\
32
+ --adapter ./checkpoints/stage2_final \\
33
+ --label trained \\
34
+ --episodes 20 --max-overs 5 \\
35
+ --output eval_results/trained.json
36
+
37
+ # Side-by-side comparison
38
+ python compare_eval.py --compare eval_results/baseline.json eval_results/trained.json
39
+ """
40
+
41
+ import argparse
42
+ import json
43
+ import os
44
+ import sys
45
+ import time
46
+ from collections import Counter
47
+ from pathlib import Path
48
+
49
+ import torch
50
+ from peft import PeftModel
51
+ from transformers import AutoModelForCausalLM, AutoTokenizer
52
+
53
+ from server.cricket_environment import CricketEnvironment
54
+ from models import CricketAction
55
+ import train as train_module # reuse SYSTEM_PROMPT and _parse_completion
56
+
57
+
58
+ # ----------------------------------------------------------------------------
59
+ # Model loading
60
+ # ----------------------------------------------------------------------------
61
+
62
+ def load_model_for_eval(model_name: str, adapter_path: str | None = None):
63
+ """Load base model in bf16; optionally apply a LoRA adapter on top."""
64
+ print(f"Loading base model: {model_name}")
65
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
66
+ if tokenizer.pad_token is None:
67
+ tokenizer.pad_token = tokenizer.eos_token
68
+ model = AutoModelForCausalLM.from_pretrained(
69
+ model_name,
70
+ torch_dtype=torch.bfloat16,
71
+ device_map="auto",
72
+ trust_remote_code=True,
73
+ )
74
+ if adapter_path:
75
+ print(f"Loading LoRA adapter: {adapter_path}")
76
+ model = PeftModel.from_pretrained(model, adapter_path, is_trainable=False)
77
+ model.eval()
78
+ return model, tokenizer
79
+
80
+
81
+ # ----------------------------------------------------------------------------
82
+ # Single-episode rollout (no token cap — let matches actually complete)
83
+ # ----------------------------------------------------------------------------
84
+
85
+ def play_one_episode(
86
+ *,
87
+ model,
88
+ tokenizer,
89
+ max_overs: int,
90
+ opponent_mode: str,
91
+ agent_team: str,
92
+ eval_pack_id: str,
93
+ seed: int,
94
+ max_tool_calls: int = 800,
95
+ max_completion_per_turn: int = 256, # per-turn (NOT per-rollout) — eval is turn-by-turn
96
+ temperature: float = 0.3, # deterministic-ish at eval
97
+ verbose: bool = False,
98
+ ) -> dict:
99
+ """Run one full match. Returns per-episode stats."""
100
+ env = CricketEnvironment()
101
+ obs = env.reset(seed=seed, options={
102
+ "task": "stage2_full",
103
+ "random_start": False,
104
+ "max_overs": max_overs,
105
+ "eval_pack_id": eval_pack_id,
106
+ "opponent_mode": opponent_mode,
107
+ "agent_team": agent_team,
108
+ })
109
+
110
+ # Build the message log progressively. Each turn appends model output + tool response.
111
+ system_prompt = train_module.SYSTEM_PROMPT
112
+ messages = [
113
+ {"role": "system", "content": system_prompt},
114
+ {"role": "user", "content": obs.prompt_text},
115
+ ]
116
+
117
+ tool_calls_made = 0
118
+ tool_breakdown: Counter = Counter()
119
+ parse_failures = 0
120
+ illegal_tool_attempts = 0
121
+ start_t = time.time()
122
+
123
+ while not obs.done and (tool_calls_made + parse_failures + illegal_tool_attempts) < max_tool_calls:  # bound retries too, so repeated parse failures cannot spin forever
124
+ # Render chat using model's tool template
125
+ try:
126
+ inputs = tokenizer.apply_chat_template(
127
+ messages,
128
+ tokenize=True,
129
+ add_generation_prompt=True,
130
+ return_tensors="pt",
131
+ ).to(model.device)
132
+ except Exception as e:
133
+ print(f" apply_chat_template error: {e}")
134
+ break
135
+
136
+ with torch.no_grad():
137
+ out = model.generate(
138
+ inputs,
139
+ max_new_tokens=max_completion_per_turn,
140
+ do_sample=(temperature > 0),
141
+ temperature=max(temperature, 1e-5),
142
+ pad_token_id=tokenizer.pad_token_id,
143
+ )
144
+ gen_ids = out[0, inputs.shape[1]:]
145
+ completion = tokenizer.decode(gen_ids, skip_special_tokens=False)
146
+
147
+ # Parse the tool call
148
+ parsed = train_module._parse_completion(completion)
149
+ if parsed is None:
150
+ parse_failures += 1
151
+ if verbose:
152
+ print(f" PARSE FAIL: {completion[:200]}...")
153
+ messages.append({"role": "assistant", "content": completion})
154
+ messages.append({"role": "user", "content": "Your previous output was not parseable. Please emit exactly one tool call."})
155
+ continue
156
+
157
+ tool_name = parsed.get("tool", "")
158
+ tool_args = parsed.get("arguments", {}) or {}
159
+ tool_breakdown[tool_name] += 1
160
+
161
+ # Apply to env
162
+ try:
163
+ obs = env.step(CricketAction(tool=tool_name, arguments=tool_args))
164
+ tool_calls_made += 1
165
+ except Exception as e:
166
+ illegal_tool_attempts += 1
167
+ if verbose:
168
+ print(f" ILLEGAL TOOL: {tool_name} → {e}")
169
+ messages.append({"role": "assistant", "content": completion})
170
+ messages.append({"role": "user", "content": f"Tool error: {e}. Try a different tool."})
171
+ continue
172
+
173
+ messages.append({"role": "assistant", "content": completion})
174
+ messages.append({"role": "user", "content": obs.prompt_text})
175
+
176
+ elapsed = time.time() - start_t
177
+ state = env.state
178
+ breakdown = state.reward_breakdown or {}
179
+
180
+ # Determine match result
181
+ is_complete = bool(obs.done)
182
+ agent_score = int(state.total_score or 0)
183
+ opp_score = int(state.first_innings_score or 0) if state.innings_type == "second" else None
184
+ target = state.target
185
+ won = None
186
+ if is_complete:
187
+ # Crude win check; env's match_result string is the canonical source
188
+ result_str = (state.match_result or "").lower()
189
+ if "won" in result_str and "agent" in result_str:
190
+ won = True
191
+ elif "lost" in result_str or "won" in result_str:
192
+ won = False
193
+ else:
194
+ won = None
195
+
196
+ return {
197
+ "seed": seed,
198
+ "max_overs": max_overs,
199
+ "opponent_mode": opponent_mode,
200
+ "tool_calls_made": tool_calls_made,
201
+ "match_complete": is_complete,
202
+ "won": won,
203
+ "agent_score": agent_score,
204
+ "opponent_first_innings_score": opp_score,
205
+ "target": target,
206
+ "wickets_lost": int(state.wickets_lost or 0),
207
+ "match_result": state.match_result or "",
208
+ "tool_breakdown": dict(tool_breakdown),
209
+ "parse_failures": parse_failures,
210
+ "illegal_tool_attempts": illegal_tool_attempts,
211
+ "validity_rate": round(1.0 - (parse_failures + illegal_tool_attempts) / max(tool_calls_made + parse_failures + illegal_tool_attempts, 1), 4),
212
+ "reward_breakdown": dict(breakdown),
213
+ "elapsed_seconds": round(elapsed, 1),
214
+ }
215
+
216
+
217
+ # ----------------------------------------------------------------------------
218
+ # Run N episodes
219
+ # ----------------------------------------------------------------------------
220
+
221
+ def run_n_episodes(
222
+ *, model, tokenizer, episodes: int, max_overs: int, opponent_mode: str,
223
+ agent_team: str, eval_pack_id: str, seed_base: int, max_tool_calls: int,
224
+ max_completion_per_turn: int, temperature: float, verbose: bool,
225
+ ) -> dict:
226
+ results = []
227
+ for i in range(episodes):
228
+ seed = seed_base + i
229
+ print(f" [{i+1}/{episodes}] seed={seed} …", end="", flush=True)
230
+ try:
231
+ res = play_one_episode(
232
+ model=model, tokenizer=tokenizer,
233
+ max_overs=max_overs, opponent_mode=opponent_mode,
234
+ agent_team=agent_team, eval_pack_id=eval_pack_id, seed=seed,
235
+ max_tool_calls=max_tool_calls,
236
+ max_completion_per_turn=max_completion_per_turn,
237
+ temperature=temperature, verbose=verbose,
238
+ )
239
+ print(f" {res['tool_calls_made']} tool calls, "
240
+ f"{'COMPLETE' if res['match_complete'] else 'truncated'}, "
241
+ f"score {res['agent_score']}/{res['wickets_lost']}, "
242
+ f"{res['elapsed_seconds']}s")
243
+ results.append(res)
244
+ except Exception as e:
245
+ print(f" FAILED: {e}")
246
+ results.append({"seed": seed, "error": str(e)})
247
+
248
+ # Aggregate
249
+ valid = [r for r in results if "error" not in r]
250
+ n = len(valid)
251
+ if n == 0:
252
+ return {"results": results, "summary": {"n": 0, "error": "all episodes failed"}}
253
+
254
+ completed = [r for r in valid if r["match_complete"]]
255
+ won = [r for r in completed if r.get("won") is True]
256
+ summary = {
257
+ "n_episodes": n,
258
+ "match_completion_rate": round(len(completed) / n, 4),
259
+ "win_rate_among_completed": round(len(won) / max(len(completed), 1), 4),
260
+ "win_rate_overall": round(len(won) / n, 4),
261
+ "mean_agent_score": round(sum(r["agent_score"] for r in valid) / n, 2),
262
+ "mean_wickets_lost": round(sum(r["wickets_lost"] for r in valid) / n, 2),
263
+ "mean_tool_calls": round(sum(r["tool_calls_made"] for r in valid) / n, 1),
264
+ "mean_validity_rate": round(sum(r["validity_rate"] for r in valid) / n, 4),
265
+ "mean_composite_reward": round(sum(r["reward_breakdown"].get("composite", 0.0) for r in valid) / n, 4),
266
+ "mean_r_result": round(sum(r["reward_breakdown"].get("r_result", 0.0) for r in valid) / n, 4),
267
+ "mean_r_cricket": round(sum(r["reward_breakdown"].get("r_cricket", 0.0) for r in valid) / n, 4),
268
+ "mean_r_behavior": round(sum(r["reward_breakdown"].get("r_behavior", 0.0) for r in valid) / n, 4),
269
+ "mean_r_validity": round(sum(r["reward_breakdown"].get("r_validity", 0.0) for r in valid) / n, 4),
270
+ "tool_freq": {},
271
+ }
272
+ # Aggregate tool frequencies
273
+ all_tools: Counter = Counter()
274
+ for r in valid:
275
+ for t, c in (r.get("tool_breakdown") or {}).items():
276
+ all_tools[t] += c
277
+ total = sum(all_tools.values()) or 1
278
+ summary["tool_freq"] = {t: round(c / total, 3) for t, c in all_tools.most_common()}
279
+
280
+ return {"results": results, "summary": summary}
281
+
282
+
283
+ # ----------------------------------------------------------------------------
284
+ # Comparison printer
285
+ # ----------------------------------------------------------------------------
286
+
287
+ def print_comparison(baseline_path: str, trained_path: str):
288
+ with open(baseline_path) as f:
289
+ b = json.load(f)
290
+ with open(trained_path) as f:
291
+ t = json.load(f)
292
+ bs = b["summary"]
293
+ ts = t["summary"]
294
+
295
+ def row(label, key, fmt="{:.4f}"):
296
+ bv = bs.get(key)
297
+ tv = ts.get(key)
298
+ b_str = fmt.format(bv) if bv is not None else "-"
299
+ t_str = fmt.format(tv) if tv is not None else "-"
300
+ delta = ""
301
+ if isinstance(bv, (int, float)) and isinstance(tv, (int, float)):
302
+ d = tv - bv
303
+ delta = f" ({'+' if d >= 0 else ''}{d:.3f})"
304
+ print(f" {label:<32} {b_str:>12} {t_str:>12}{delta}")
305
+
306
+ print(f"\n{'='*80}")
307
+ print(f"BASELINE vs TRAINED — {bs['n_episodes']} episodes each")
308
+ print(f" baseline label: {b.get('label')} | trained label: {t.get('label')}")
309
+ print(f"{'='*80}")
310
+ print(f" {'metric':<32} {'baseline':>12} {'trained':>12}")
311
+ print(f" {'-'*32} {'-'*12} {'-'*12}")
312
+ row("match_completion_rate", "match_completion_rate")
313
+ row("win_rate_overall", "win_rate_overall")
314
+ row("win_rate_among_completed", "win_rate_among_completed")
315
+ row("mean_agent_score", "mean_agent_score", "{:.2f}")
316
+ row("mean_wickets_lost", "mean_wickets_lost", "{:.2f}")
317
+ row("mean_tool_calls", "mean_tool_calls", "{:.1f}")
318
+ row("mean_validity_rate", "mean_validity_rate")
319
+ row("mean_composite_reward", "mean_composite_reward")
320
+ row("mean_r_result", "mean_r_result")
321
+ row("mean_r_cricket", "mean_r_cricket")
322
+ row("mean_r_behavior", "mean_r_behavior")
323
+ row("mean_r_validity", "mean_r_validity")
324
+ print(f"{'='*80}\n")
325
+
326
+
327
+ # ----------------------------------------------------------------------------
328
+ # Main
329
+ # ----------------------------------------------------------------------------
330
+
331
+ def main():
332
+ parser = argparse.ArgumentParser(description="Baseline vs trained eval for CricketCaptain.")
333
+ parser.add_argument("--model", default="Qwen/Qwen3.5-4B", help="Base HF model id")
334
+ parser.add_argument("--adapter", default=None, help="Optional LoRA adapter directory")
335
+ parser.add_argument("--label", default="run", help="Label for this run (used in output)")
336
+ parser.add_argument("--episodes", type=int, default=10)
337
+ parser.add_argument("--max-overs", type=int, default=5)
338
+ parser.add_argument("--opponent-mode", default="heuristic",
339
+ choices=["heuristic", "llm_live", "llm_cached", "cricsheet"])
340
+ parser.add_argument("--agent-team", default="india")
341
+ parser.add_argument("--eval-pack-id", default="adaptive_t20_v1")
342
+ parser.add_argument("--seed-base", type=int, default=10000)
343
+ parser.add_argument("--max-tool-calls", type=int, default=800)
344
+ parser.add_argument("--max-completion-per-turn", type=int, default=256)
345
+ parser.add_argument("--temperature", type=float, default=0.3)
346
+ parser.add_argument("--output", default=None, help="JSON output path")
347
+ parser.add_argument("--verbose", action="store_true")
348
+
349
+ parser.add_argument("--compare", nargs=2, default=None, metavar=("BASELINE_JSON", "TRAINED_JSON"),
350
+ help="Skip eval; just print comparison from two existing JSON files")
351
+ args = parser.parse_args()
352
+
353
+ if args.compare:
354
+ print_comparison(args.compare[0], args.compare[1])
355
+ return
356
+
357
+ print(f"\nCricketCaptain compare-eval — label='{args.label}'")
358
+ print(f" model={args.model} adapter={args.adapter or '(none)'}")
359
+ print(f" {args.episodes} episodes × {args.max_overs} overs vs {args.opponent_mode} opponent\n")
360
+
361
+ model, tokenizer = load_model_for_eval(args.model, args.adapter)
362
+
363
+ out = run_n_episodes(
364
+ model=model, tokenizer=tokenizer,
365
+ episodes=args.episodes, max_overs=args.max_overs,
366
+ opponent_mode=args.opponent_mode,
367
+ agent_team=args.agent_team, eval_pack_id=args.eval_pack_id,
368
+ seed_base=args.seed_base, max_tool_calls=args.max_tool_calls,
369
+ max_completion_per_turn=args.max_completion_per_turn,
370
+ temperature=args.temperature, verbose=args.verbose,
371
+ )
372
+ out["label"] = args.label
373
+ out["model"] = args.model
374
+ out["adapter"] = args.adapter
375
+ out["config"] = {
376
+ "episodes": args.episodes, "max_overs": args.max_overs,
377
+ "opponent_mode": args.opponent_mode, "agent_team": args.agent_team,
378
+ "max_tool_calls": args.max_tool_calls,
379
+ "max_completion_per_turn": args.max_completion_per_turn,
380
+ "temperature": args.temperature,
381
+ }
382
+
383
+ print("\n=== SUMMARY ===")
384
+ print(json.dumps(out["summary"], indent=2))
385
+
386
+ if args.output:
387
+ out_path = Path(args.output)
388
+ out_path.parent.mkdir(parents=True, exist_ok=True)
389
+ with out_path.open("w") as f:
390
+ json.dump(out, f, indent=2)
391
+ print(f"\nResults → {out_path}")
392
+
393
+
394
+ if __name__ == "__main__":
395
+ main()
configs/cricket_train_qwen3.yaml ADDED
@@ -0,0 +1,99 @@
 
 
 
 
1
+ # =============================================================================
2
+ # CricketCaptain-LLM — Qwen3 MAIN run (5-over format, end-to-end matches)
3
+ #
4
+ # Differences from cricket_train.yaml:
5
+ # - Model: Qwen3-4B-Instruct-2507 (vs Qwen3.5-4B which has no vLLM class)
6
+ # - vLLM colocate enabled (3-5× throughput gain)
7
+ # - 5-over format instead of 20-over (20-over end-to-end is infeasible on 1× H200)
8
+ # - Resumes Qwen3 warmup adapter, not Qwen3.5 adapter
9
+ # - Reward clip removed at train.py:1240 (separately) — let GRPO standardize advantage
10
+ # - More steps (100 vs 30) since vLLM makes them affordable
11
+ #
12
+ # Stack: TRL 1.2.0 + transformers 5.6.2 + vLLM + PEFT (LoRA) + bf16 base
13
+ # Target hardware: 1× H200 (144 GB), single node
14
+ # =============================================================================
15
+
16
+ env:
17
+ eval_pack_id: adaptive_t20_v1
18
+ agent_team: india
19
+ max_overs: 5 # 5-over end-to-end. ~180 tool calls per match.
20
+ env_url: ws://localhost:8000
21
+
22
+ opponent:
23
+ mode: heuristic # llm_live optional but slows training ~2-3×
24
+
25
+ train:
26
+ # ---- Model & adapter ----
27
+ model: Qwen/Qwen3-4B-Instruct-2507
28
+ # Resume from the Qwen3 warmup adapter (NOT from Qwen3.5 stage2_final).
29
+ # Comment this line out to start the main run with a fresh adapter.
30
+ resume_from: ./checkpoints/stage2_final
31
+ stage: 2
32
+
33
+ # ---- Dataset ----
34
+ prompts: 256
35
+
36
+ # ---- Schedule ----
37
+ # 5-over rollouts ≈ ~180 tool calls each. With vLLM colocate at B=1 sims=4,
38
+ # roughly ~3-4 min/step. 100 steps ≈ ~5-7 hrs.
39
+ steps: 100
40
+ logging_steps: 1
41
+ save_steps: 20
42
+ save_total_limit: 5
43
+
44
+ # ---- Throughput knobs ----
45
+ # B=1 × grad_accum=4 × G=4 → gen_batch=4, 4 sims in flight per step.
46
+ # Matches the working warmup config. Warmup at B=2 / 24k OOM'd on backward
47
+ # due to gradient-accum buffer churn — at 32k completion it would OOM worse.
48
+ # KV cache ≈ 4 × 36k × 144 KB ≈ 21 GB, comfortable in vLLM 0.55 pool (78 GB).
49
+ batch_size: 1
50
+ grad_accum: 4
51
+ num_generations: 4
52
+
53
+ # ---- Length budget ----
54
+ # 24k completion. At 32k we OOM'd at step 7 of main (140+ GB used). 24k gives
55
+ # ~7 GB headroom for backward-pass activation memory. Full 5-over match needs
56
+ # ~9k tokens of model output, so 24k is 2.5x headroom — plenty.
57
+ max_completion_length: 24576
58
+ max_tool_calling_iterations: 240
59
+
60
+ # ---- Optimizer ----
61
+ learning_rate: 5.0e-6
62
+
63
+ # ---- GRPO ----
64
+ # beta=0.0: no reference model. With 24k completion + B=16 sims, the 8 GB
65
+ # ref model would push us past the H200's 144 GB. Format penalty in reward
66
+ # is the soft anchor instead.
67
+ beta: 0.0
68
+ temperature: 0.9
69
+ top_p: 0.95
70
+
71
+ # ---- Memory ----
72
+ gradient_checkpointing: true
73
+ gradient_checkpointing_use_reentrant: false
74
+ bf16_base: true
75
+
76
+ # ---- vLLM ----
77
+ use_vllm: true
78
+ vllm_gpu_memory: 0.55
79
+
80
+ # ---- Dataloader ----
81
+ dataloader_pin_memory: true
82
+ dataloader_num_workers: 4
83
+
84
+ # ---- Logging ----
85
+ report_to: wandb
86
+ run_name: cricket_qwen3_main
87
+
88
+ # =============================================================================
89
+ # Reward composition (defined in server/reward_calculator.py):
90
+ # composite = 0.35·r_result + 0.30·r_cricket + 0.25·r_behavior + 0.10·r_validity (weights set in configs/game_knowledge.yaml)
91
+ #
92
+ # WATCH during training:
93
+ # - rollout/match_completion_rate ≥ 0.7 within ~50 steps
94
+ # → if not, episodes are still hitting cap; tighten per-turn schema or drop max_overs
95
+ # - reward/r_result_mean separate from composite
96
+ # → if r_result stays at 0 while composite rises, you're optimizing format only
97
+ # - episode/tool_calls_mean
98
+ # → should be ~150-200 for 5-over; >220 means truncation events are common
99
+ # =============================================================================
configs/cricket_train_qwen3_smoke.yaml ADDED
@@ -0,0 +1,99 @@
 
 
 
 
1
+ # =============================================================================
2
+ # CricketCaptain-LLM — Qwen3 SMOKE TEST (validate vLLM colocate works)
3
+ #
4
+ # Purpose: shortest possible run to confirm:
5
+ # 1. Qwen3-4B-Instruct-2507 loads cleanly into vLLM colocate
6
+ # (no Qwen3.5 architecture-class registration error)
7
+ # 2. TRL multi-turn environment_factory steps execute without crashing
8
+ # 3. At least one episode reaches `done=True` so `r_result` fires
9
+ #
10
+ # Expected runtime: ~10-15 min on a single H200.
11
+ # Run before kicking off cricket_train_qwen3_warmup.yaml.
12
+ #
13
+ # Stack target: conda cloudspace (torch 2.10, transformers 5.6.2, trl 1.2.0)
14
+ # + vllm + flash-attn installed via .venv-qwen3
15
+ # =============================================================================
16
+
17
+ env:
18
+ eval_pack_id: adaptive_t20_v1
19
+ agent_team: india
20
+ max_overs: 2 # smallest format — match must complete in ~70 turns
21
+ env_url: ws://localhost:8000
22
+
23
+ opponent:
24
+ mode: heuristic # deterministic-ish, no API costs
25
+
26
+ train:
27
+ # ---- Model ----
28
+ model: Qwen/Qwen3-4B-Instruct-2507
29
+ # 256k native context, no <think> blocks, native Qwen3ForCausalLM in vLLM.
30
+ # Fresh adapter — do NOT load Qwen3.5-trained weights into Qwen3 base.
31
+ # resume_from intentionally omitted.
32
+
33
+ stage: 2
34
+
35
+ # ---- Dataset ----
36
+ prompts: 16 # tiny — smoke only
37
+
38
+ # ---- Schedule ----
39
+ steps: 2 # absolute minimum to test gradient + save
40
+ logging_steps: 1
41
+ save_steps: 2
42
+ save_total_limit: 1
43
+
44
+ # ---- Throughput knobs ----
45
+ # bs=1 + grad_accum=4 + G=4 → generation_batch_size=4 divides G cleanly
46
+ # (TRL 1.2 GRPOConfig requires bs*grad_accum divisible by num_generations).
47
+ # 4 sim episodes in flight. KV cache ≈ 4 × 16k × 144 KB ≈ 9.5 GB → tiny.
48
+ batch_size: 1
49
+ grad_accum: 4
50
+ num_generations: 4
51
+
52
+ # ---- Length budget ----
53
+ # 2-over needs ~70 tool calls. At <120 tok/turn this fits in ~8k.
54
+ # Generous 16k completion to catch any per-turn bloat.
55
+ max_completion_length: 16384
56
+ max_tool_calling_iterations: 120
57
+
58
+ # ---- Optimizer ----
59
+ learning_rate: 5.0e-6
60
+
61
+ # ---- GRPO ----
62
+ beta: 0.0 # no reference model (saves ~8 GB VRAM)
63
+ temperature: 0.9
64
+ top_p: 0.95
65
+
66
+ # ---- Memory ----
67
+ gradient_checkpointing: true
68
+ gradient_checkpointing_use_reentrant: false
69
+ bf16_base: true
70
+
71
+ # ---- vLLM colocate (THE thing being tested) ----
72
+ use_vllm: true
73
+ vllm_gpu_memory: 0.50
74
+ # vllm_model_impl omitted → vLLM picks Qwen3ForCausalLM natively.
75
+ # If you fall back to Qwen3.5-4B for some reason, set this to "transformers".
76
+
77
+ # ---- Dataloader ----
78
+ dataloader_pin_memory: true
79
+ dataloader_num_workers: 2
80
+
81
+ # ---- Logging ----
82
+ report_to: wandb
83
+ run_name: cricket_qwen3_smoke
84
+
85
+ # =============================================================================
86
+ # Run with:
87
+ # source .venv-qwen3/bin/activate
88
+ # python train.py train --config configs/cricket_train_qwen3_smoke.yaml
89
+ #
90
+ # Pass criteria:
91
+ # - 2 gradient steps complete without OOM
92
+ # - logs show `rollout/match_completion_rate > 0`
93
+ # - at least one episode in episode_stats.jsonl has `termination_reason: natural`
94
+ #
95
+ # If pass → run cricket_train_qwen3_warmup.yaml
96
+ # If fail with "Qwen3_5..." → wrong model name, check spelling
97
+ # If fail with vLLM class error → vLLM build doesn't include Qwen3 support, upgrade
98
+ # If fail with LoRA error → known TRL+vLLM+LoRA issue, set vllm_model_impl: transformers
99
+ # =============================================================================
configs/cricket_train_qwen3_warmup.yaml ADDED
@@ -0,0 +1,98 @@
 
 
 
 
1
+ # =============================================================================
2
+ # CricketCaptain-LLM — Qwen3 WARMUP (2-3 over curriculum, fast iterations)
3
+ #
4
+ # Differences from cricket_train_warmup.yaml:
5
+ # - Model: Qwen3-4B-Instruct-2507 (256k native, no <think>, native vLLM class)
6
+ # - vLLM colocate enabled (works because Qwen3ForCausalLM is a registered class)
7
+ # - Fresh adapter (do NOT resume Qwen3.5 LoRA on Qwen3 base — incompatible)
8
+ # - Slightly tighter schedule given vLLM throughput gain (~3-5x)
9
+ #
10
+ # Stack: TRL 1.2.0 + transformers 5.6.2 + vLLM + PEFT (LoRA) + bf16 base
11
+ # Target hardware: 1× H200 (144 GB)
12
+ # =============================================================================
13
+
14
+ env:
15
+ eval_pack_id: adaptive_t20_v1
16
+ agent_team: india
17
+ max_overs: 0 # 0 = use overs_distribution below
18
+ env_url: ws://localhost:8000
19
+
20
+ opponent:
21
+ mode: heuristic # fast iteration; switch to llm_live for final eval
22
+
23
+ train:
24
+ # ---- Model & adapter ----
25
+ model: Qwen/Qwen3-4B-Instruct-2507
26
+ stage: 2 # full composite reward
27
+ # No resume_from — start fresh on Qwen3 base.
28
+
29
+ # ---- Dataset ----
30
+ prompts: 64
31
+
32
+ # ---- Schedule ----
33
+ # 2-3 over rollouts ≈ ~70-110 tool calls each. With vLLM colocate, ~2-3 min/step
34
+ # at 16 sim episodes. 30 steps ≈ ~1.5 hrs total.
35
+ steps: 30
36
+ logging_steps: 1
37
+ save_steps: 5
38
+ save_total_limit: 5
39
+
40
+ # ---- Curriculum (per-scenario max_overs) ----
41
+ # Heavier on T2 (cleanly completes in token budget), tail to T3.
42
+ # Skip T4/T5 in warmup — those go in the main run.
43
+ overs_distribution: [2, 2, 2, 2, 2, 2, 3, 3, 3]
44
+
45
+ # ---- Throughput knobs ----
46
+ # B=1 × grad_accum=4 × G=4 → gen_batch=4, 4 sims in flight per step.
47
+ # Matches the smoke config that ran cleanly. Per-episode throughput is lower than
+ # B=2 (~30-40s/step here vs ~55s/step for twice the episodes), but B=2 OOM'd on backward of step 2 due to gradient-accumulation
49
+ # micro-batch buffer churn even with expandable_segments. 30 steps ≈ ~20 min.
50
+ batch_size: 1
51
+ grad_accum: 4
52
+ num_generations: 4
53
+
54
+ # ---- Length budget ----
55
+ # 24k completion = ~130 tok/turn × 180 turns. Generous for 2-3 over format,
56
+ # matches Qwen3-4B-Instruct-2507 recommendation of ≥32k output for most queries
57
+ # (per-rollout cumulative across multi-turn).
58
+ max_completion_length: 24576
59
+ max_tool_calling_iterations: 240
60
+
61
+ # ---- Optimizer ----
62
+ learning_rate: 5.0e-6
63
+
64
+ # ---- GRPO ----
65
+ # beta=0.0: no reference model (saves ~8 GB VRAM, lets G=4 fit at 16k completion).
66
+ # Reward shaping has format penalty as soft anchor.
67
+ beta: 0.0
68
+ temperature: 0.9
69
+ top_p: 0.95
70
+
71
+ # ---- Memory ----
72
+ gradient_checkpointing: true
73
+ gradient_checkpointing_use_reentrant: false
74
+ bf16_base: true
75
+
76
+ # ---- vLLM (THE big change vs Qwen3.5 config) ----
77
+ use_vllm: true
78
+ vllm_gpu_memory: 0.55
79
+ # vllm_model_impl omitted → vLLM picks Qwen3ForCausalLM natively.
80
+
81
+ # ---- Dataloader ----
82
+ dataloader_pin_memory: true
83
+ dataloader_num_workers: 4
84
+
85
+ # ---- Logging ----
86
+ report_to: wandb
87
+ run_name: cricket_qwen3_warmup
88
+
89
+ # =============================================================================
90
+ # Workflow:
91
+ # 1. Smoke test (10-15 min):
92
+ # python train.py train --config configs/cricket_train_qwen3_smoke.yaml
93
+ # 2. Warmup (1-2 hrs):
94
+ # python train.py train --config configs/cricket_train_qwen3_warmup.yaml
95
+ # 3. Main run (5-7 hrs):
96
+ # python train.py train --config configs/cricket_train_qwen3.yaml
97
+ # (resumes from ./checkpoints/stage2_final saved by step 2)
98
+ # =============================================================================
configs/extras/cached_eval.yaml ADDED
@@ -0,0 +1,18 @@
 
 
 
 
1
+ env:
2
+ # Used by server + runners for reproducible comparison runs.
3
+ eval_pack_id: adaptive_t20_v1
4
+ max_overs: 5
5
+ env_url: ws://localhost:8000
6
+
7
+ opponent:
8
+ # llm_cached does not call `model` live. It replays pre-generated decisions
9
+ # from cache_path so every captain model faces the same opponent behavior.
10
+ mode: llm_cached
11
+ cache_path: data/opponent_cache/adaptive_t20_v1_official_gemma2b.jsonl
12
+
13
+ captain:
14
+ # Captain still calls HF router live in this config.
15
+ model: google/gemma-4-26B-A4B-it
16
+ api_base: https://router.huggingface.co/v1
17
+ api_key_env: HF_TOKEN
18
+
configs/extras/cricket_train.yaml ADDED
@@ -0,0 +1,125 @@
 
 
 
 
1
+ # =============================================================================
2
+ # CricketCaptain-LLM — dedicated training config
3
+ # Stack: TRL 1.2.0 GRPO + Transformers 5.6.2 + PEFT (LoRA) + bf16 base
4
+ # NOT using vLLM (incompatible with our env), NOT using Unsloth (incompatible
5
+ # with TRL's multi-turn environment_factory).
6
+ # Target run: 10-hour budget on a single H200 (144 GB VRAM), Qwen3.5-4B base.
7
+ # =============================================================================
8
+
9
+ env:
10
+ # eval_pack_id selects scenario distribution + opponent rosters
11
+ eval_pack_id: adaptive_t20_v1
12
+ agent_team: india
13
+ # Match length per episode. 20 = full T20.
14
+ # 5-over (~180 tool calls) → ~5 min/step
15
+ # 20-over (~720 tool calls) → ~15-18 min/step ← current choice
16
+ max_overs: 20
17
+ env_url: ws://localhost:8000
18
+
19
+ opponent:
20
+ # heuristic | llm_live | llm_cached | cricsheet
21
+ # heuristic: rule-based, fast, deterministic-ish — best for fast iteration.
22
+ # llm_live: adversarial Gemma via HF router — realistic but slow + costs API.
23
+ mode: heuristic
24
+
25
+
26
+ train:
27
+ # ---- Model & adapter ----
28
+ model: Qwen/Qwen3.5-4B
29
+ # Resume LoRA from the warmup checkpoint. When this run starts, base model
30
+ # loads from Qwen/Qwen3.5-4B above + LoRA adapter loads from this path.
31
+ # Comment out (with #) to start with a fresh adapter instead.
32
+ resume_from: ./checkpoints/stage2_final
33
+ # Single-stage training. Code uses curriculum_stage=2 internally to mean
34
+ # "full composite reward". Stage 1 (validity-only warm-up) was dropped because
35
+ # Qwen3.5-4B already knows tool calling natively (XML+JSON both accepted).
36
+ stage: 2
37
+
38
+ # ---- Dataset ----
39
+ prompts: 256 # number of unique scenarios; trainer cycles through them
40
+
41
+ # ---- Schedule ----
42
+ # 30 steps × ~15 min/step = ~7-8 hrs for the main run. Total chain (warmup +
43
+ # main) fits in ~10-hr budget. Bump to 100+ for a longer training run if
44
+ # compute is unconstrained — the resume_from path lets you continue cleanly.
45
+ steps: 30
46
+ logging_steps: 1 # log every step (set 10 for less wandb noise)
47
+
48
+ # ---- Throughput knobs (H200, 144 GB VRAM) ----
49
+ # batch_size × grad_accum = effective rollout per gradient update.
50
+ # batch_size × num_generations = simultaneous in-flight episodes (memory driver).
51
+ # 4 × 4 = 16 sim ≈ ~80 GB peak at 4096 max_completion — needed since 32 sim @ 4096
52
+ # OOMs at ~110 GB. grad_accum=4 keeps effective batch healthy.
53
+ batch_size: 4
54
+ grad_accum: 4
55
+ num_generations: 4 # GRPO requires ≥2 for group advantage; 4 is the sweet spot
56
+
57
+ # ---- Length budget ----
58
+ # 4096 = ~16 turns per rollout at ~250 tok/turn. Larger budget than warmup
59
+ # because main run uses smaller batch (32 sim @ batch=8 × num_gen=4) so
60
+ # KV cache fits comfortably.
61
+ max_completion_length: 4096
62
+ # Hard cap on tool calls per episode. 5-over needs ~180; 20-over needs ~720.
63
+ # 800 leaves slack for extra-balls (no-balls/wides) without truncating matches.
64
+ max_tool_calling_iterations: 800
65
+
66
+ # ---- Optimizer ----
67
+ # Bumped 5e-6 → 1e-5 for the main run. Warmup at 5e-6 showed flat reward and
68
+ # tiny loss magnitudes (~0.02) with grad_norm 0.35 — gradients flowing but
69
+ # weights barely moving. 2× LR doubles step size without re-entering instability
70
+ # territory (still well below the 5e-5 commonly used for r=64 LoRA SFT).
71
+ learning_rate: 1.0e-5
72
+
73
+ # ---- GRPO knobs ----
74
+ # KL coefficient. Default in TRL 1.2 is 0.0 — and 0.0 specifically means the
75
+ # reference model is NOT loaded (saves ~8 GB weights + ref-forward-pass mem).
76
+ # Any beta > 0 loads a frozen copy of the base model for KL anchoring.
77
+ # We use 0.0 because: (a) memory budget is tight at batch=16 + 20-over,
78
+ # (b) reward is well-shaped (composite with format penalty), so format collapse
79
+ # is unlikely, (c) cricket strategy isn't in the base distribution — we want drift.
80
+ beta: 0.0
81
+ # Sampling — slight bump from 0.8 for GRPO group diversity. top_p=0.95 trims
82
+ # the long tail (rare bad tokens during 720-turn rollouts).
83
+ temperature: 0.9
84
+ top_p: 0.95
85
+
86
+ # ---- Memory ----
87
+ # Trade ~30% extra backward-pass compute for big activation memory savings
88
+ # (>20 GB). Required to fit batch_size=16 + num_gen=4 + 3072 completion in 144 GB.
89
+ # use_reentrant=False is the modern path, more stable with LoRA than the legacy True.
90
+ gradient_checkpointing: true
91
+ gradient_checkpointing_use_reentrant: false
92
+
93
+ # ---- Dataloader ----
94
+ # Cheap micro-opts — pin host memory + a few worker threads so CPU isn't a
95
+ # bottleneck feeding prompts. Tiny win since our dataset is only 256 prompts.
96
+ dataloader_pin_memory: true
97
+ dataloader_num_workers: 4
98
+
99
+ # ---- Checkpointing ----
100
+ # Save every 10 steps; keep only the 5 most recent on disk to cap usage.
101
+ save_steps: 10
102
+ save_total_limit: 5
103
+
104
+ # ---- LoRA (currently hardcoded in train.py load_model()) ----
105
+ # r=64, alpha=128, dropout=0.05, targets q/k/v/o/gate/up/down
106
+ # → 85M trainable params (1.98% of 4.2B base)
107
+
108
+ # ---- Precision ----
109
+ # bf16 base + bf16 LoRA adapter. NO 4-bit quant — H200 has the VRAM and bf16
110
+ # is 15-20% faster than 4-bit dequant on every forward pass.
111
+ bf16_base: true
112
+
113
+ # ---- Logging ----
114
+ report_to: wandb # or "tensorboard" for local-only
115
+ run_name: cricket_captain_v10
116
+
117
+ # =============================================================================
118
+ # Reward composition (weights live in configs/game_knowledge.yaml `reward:` block)
119
+ # composite = 0.35·r_result + 0.30·r_cricket + 0.25·r_behavior + 0.10·r_validity
120
+ #
121
+ # r_result — match outcome (chase margin, defense margin, win/tie bonus)
122
+ # r_cricket — Dream11 fantasy proxy as dense per-ball signal
123
+ # r_behavior — coherence (50%) + adaptation (20%) + opponent_awareness (20%) + regret (10%)
124
+ # r_validity — fraction of legal tool calls
125
+ # =============================================================================
configs/extras/cricket_train_warmup.yaml ADDED
@@ -0,0 +1,95 @@
 
 
 
 
1
+ # =============================================================================
2
+ # CricketCaptain-LLM — WARMUP training config (5-over format, fast iterations)
3
+ # Stack: TRL 1.2.0 GRPO + Transformers 5.6.2 + PEFT (LoRA) + bf16 base
4
+ # Purpose: short matches + bigger GRPO group = fast format mastery before
5
+ # moving to the full 20-over training run (configs/cricket_train.yaml).
6
+ #
7
+ # Why bigger num_generations here?
8
+ # 5-over rollouts are ~4× faster per step than 20-over, so we have memory +
9
+ # wall-clock budget for G=8. Bigger group → more stable GRPO advantage signal,
10
+ # especially useful when the model is just learning format/tactics.
11
+ # =============================================================================
12
+
13
+ env:
14
+ eval_pack_id: adaptive_t20_v1
15
+ agent_team: india
16
+ # max_overs: 0 unlocks the per-scenario curriculum distribution (see train.overs_distribution).
17
+ # Set to a positive integer to lock all scenarios to that single format.
18
+ max_overs: 0
19
+ env_url: ws://localhost:8000
20
+
21
+ opponent:
22
+ # Heuristic for fast iteration. Switch to llm_live for final eval, not warmup.
23
+ mode: heuristic
24
+
25
+
26
+ train:
27
+ # ---- Model & adapter ----
28
+ model: Qwen/Qwen3.5-4B
29
+ # Single-stage. curriculum_stage=2 internally = full composite reward.
30
+ stage: 2
31
+
32
+ # ---- Dataset ----
33
+ prompts: 64 # smaller dataset since we're only doing 25 steps
34
+
35
+ # ---- Schedule ----
36
+ # 5-over ≈ ~180 tool calls / episode → ~4-5 min / step at 32 sim episodes.
37
+ # 25 steps ≈ ~2 hrs total. Quick warmup before kicking off the full
38
+ # 20-over run with cricket_train.yaml.
39
+ steps: 25
40
+ logging_steps: 1
41
+ # Save every 5 steps → 5 checkpoints across 25 steps. Keep only 5 on disk.
42
+ save_steps: 5
43
+ save_total_limit: 5
44
+
45
+ # ---- Curriculum (per-scenario max_overs) ----
46
+ # Heavy on T2 (the only format that actually completes inside our token budget),
47
+ # tail to T5 (target eval distribution). Activated when env.max_overs is 0 or unset.
48
+ # Frequencies: ~45% T2, ~27% T3, ~18% T4, ~9% T5.
49
+ overs_distribution: [2, 2, 2, 2, 2, 3, 3, 3, 4, 4, 5]
50
+
51
+ # ---- Throughput knobs ----
52
+ # 4 × 4 = 16 sim episodes — needed because 4096 max_completion at 32 sim OOMs
53
+ # (~110 GB observed). At 16 sim: ~80 GB, fits with margin. grad_accum=4 keeps
54
+ # the effective gradient batch size where G=4 still gives stable advantages.
55
+ batch_size: 4
56
+ grad_accum: 4
57
+ num_generations: 4
58
+
59
+ # ---- Length budget ----
60
+ # 4096 = ~16 turns per rollout at ~250 tok/turn. Lets T2 episodes get further
61
+ # toward completion (~5-8 balls per innings vs ~3-5 at 3072).
62
+ max_completion_length: 4096
63
+ # Cap above what 5-over needs (~180) so no truncation from extras.
64
+ max_tool_calling_iterations: 240
65
+
66
+ # ---- Optimizer ----
67
+ learning_rate: 5.0e-6
68
+
69
+ # ---- GRPO knobs ----
70
+ # beta=0.0 → no reference model loaded → saves ~8 GB VRAM and lets G=8 fit.
71
+ beta: 0.0
72
+ temperature: 0.9
73
+ top_p: 0.95
74
+
75
+ # ---- Memory ----
76
+ gradient_checkpointing: true
77
+ gradient_checkpointing_use_reentrant: false
78
+
79
+ # ---- Dataloader ----
80
+ dataloader_pin_memory: true
81
+ dataloader_num_workers: 4
82
+
83
+ # ---- Precision ----
84
+ bf16_base: true
85
+
86
+ # ---- Logging ----
87
+ report_to: wandb
88
+ run_name: cricket_captain_warmup_5over
89
+
90
+ # =============================================================================
91
+ # Workflow:
92
+ # 1. Train warmup: python train.py train --config configs/cricket_train_warmup.yaml
93
+ # 2. Train main run: python train.py train --config configs/cricket_train.yaml \
94
+ # --model ./checkpoints/stage2_final (resume from warmup)
95
+ # =============================================================================
configs/extras/default.yaml ADDED
@@ -0,0 +1,59 @@
 
 
 
 
1
+ env:
2
+ # Used by server + runners
3
+ eval_pack_id: adaptive_t20_v1
4
+ max_overs: 20
5
+ agent_team: india
6
+ env_url: ws://localhost:8000
7
+
8
+ opponent:
9
+ # heuristic | llm_live | llm_cached
10
+ # llm_live is the adversarial opponent. The captain being trained still runs
11
+ # locally through TRL/Transformers, not through the HF inference endpoint.
12
+ # For reproducible cached evaluation, use configs/cached_eval.yaml instead.
13
+ # H200 default: use the live LLM opponent for realism.
14
+ mode: llm_live
15
+ model: google/gemma-4-26B-A4B-it
16
+ api_base: https://router.huggingface.co/v1
17
+ api_key_env: HF_TOKEN
18
+
19
+ captain:
20
+ # For inference/eval runner when using an API model (OpenAI-compatible).
21
+ # You can still pass --model random for baseline runs.
22
+ model: google/gemma-4-26B-A4B-it
23
+ api_base: https://router.huggingface.co/v1
24
+ api_key_env: HF_TOKEN
25
+
26
+ train:
27
+ model: Qwen/Qwen3.5-4B
28
+ # SINGLE-STAGE training. The two-stage curriculum (Stage 1 = format-only)
29
+ # was dropped — Qwen3.5-4B already knows tool calling natively (XML+JSON),
30
+ # so we run the full composite reward (r_result + r_cricket + r_behavior + r_validity)
31
+ # from step 0. Code still uses curriculum_stage=2 internally to mean "full reward".
32
+ stage: 2
33
+ prompts: 256
34
+ # 10-hr H200 budget. ~150 steps × ~4-5 min each = ~10 hrs.
35
+ steps: 150
36
+ # H200 (144GB) v8 config: bigger micro-batch, fewer grad accum steps, more updates.
37
+ # Effective rollout per gradient update = batch_size * grad_accum = 32 prompts
38
+ # → with num_generations=4 that's 128 episodes per gradient update.
39
+ batch_size: 16
40
+ grad_accum: 2
41
+ num_generations: 4 # GRPO requires >= 2 for group advantage; 4 is the standard sweet spot
42
+ # 3072 = enough headroom for thinking (v5 mean was 1115 tokens) without the
43
+ # KV-cache bloat of 4096. Cuts memory for the rollout buffer ~25%.
44
+ max_completion_length: 3072
45
+ # Episodes terminate either when match finishes or this cap hits.
46
+ # 5-over match needs ~180 tool calls, 20-over T20 needs ~720. Don't truncate.
47
+ max_tool_calling_iterations: 800
48
+ # Learning rate — was hardcoded to 1e-5; lowered to 5e-6 because of LoRA r=64
49
+ # (4× more trainable params than r=16) and partially-sparse outcome reward.
50
+ # Smaller LR = more stable updates when r_result is binary at innings end.
51
+ learning_rate: 5e-6
52
+ # Trade ~30% extra compute per backward pass for big activation memory savings
53
+ # (>20 GB freed). Lets us push micro-batch / num_generations with safety margin.
54
+ gradient_checkpointing: true
55
+ logging_steps: 1 # log loss/reward every step (set 10 for less noise)
56
+ # Switch to `tensorboard` if you prefer local-only logging.
57
+ report_to: wandb
58
+ run_name: cricket_captain_v10
59
+
configs/game_knowledge.yaml CHANGED
@@ -29,11 +29,22 @@ transition_overs: [6, 16]
29
  # Reward weights (must sum to 1.0)
30
  # ---------------------------------------------------------------------------
31
  reward:
32
- # Episode-level composite
33
- r_result: 0.55 # match outcome: win/loss, target margin, DLS/par
34
- r_cricket: 0.25 # dense cricket position signal (Dream11 proxy)
35
- r_behavior: 0.15 # plan-action coherence, adaptation, opponent awareness
36
- r_validity: 0.05 # legal JSON tool use gate
 
 
 
 
 
 
 
 
 
 
 
37
 
38
  # Within r_behavior
39
  behavior:
 
29
  # Reward weights (must sum to 1.0)
30
  # ---------------------------------------------------------------------------
31
  reward:
32
+ # Episode-level composite — REBALANCED for partial-trajectory training.
33
+ # Original 55/25/15/5 was textbook "outcome-dominated" but assumed matches
34
+ # would actually complete in the token budget. Reality: episodes truncate at
35
+ # ~25% of T2 and r_result almost never fires. Putting 55% weight on a signal
36
+ # that fires <5% of the time washes out gradient.
37
+ #
38
+ # New weights match the SWE-RL / coding-agent-RL recipe. Re-rebalanced for
39
+ # the main 20-over run: at full T20 length r_result actually fires (matches
40
+ # complete), so we shift 0.15 weight back from r_cricket → r_result to give
41
+ # the outcome signal real gradient pull. r_cricket alone was producing flat,
42
+ # bouncy reward curves in warmup — per-ball Dream11 is too noisy as 45% of
43
+ # the gradient when group size is only 4.
44
+ r_result: 0.35 # match outcome — fires reliably at 20 overs, biggest signal
45
+ r_cricket: 0.30 # dense per-ball Dream11 (analog: partial test-pass)
46
+ r_behavior: 0.25 # per-turn coherence/adaptation (analog: lint/quality)
47
+ r_validity: 0.10 # tool format (analog: compile success)
48
 
49
  # Within r_behavior
50
  behavior:
docs/benchmark_explainer.md CHANGED
@@ -30,14 +30,13 @@ The original motivation came from strategic coherence: LLMs often say one thing
30
 
31
  ## 2. Fit With OpenEnv Competition Themes
32
 
33
- The environment aligns with multiple OpenEnv hackathon themes.
34
-
35
  ### Multi-Agent Interactions
36
 
37
  The submitted captain agent plays against an opponent policy. The opponent can be:
38
 
39
- - `heuristic`: fast deterministic/stochastic cricket logic.
40
- - `llm_live`: live OpenAI-compatible LLM opponent.
 
41
  - `llm_cached`: replayed opponent decisions for reproducible evaluation.
42
 
43
  This tests whether the agent can reason about another actor's incentives, field settings, and likely plans.
@@ -48,29 +47,11 @@ A full match has many decisions across innings, phases, wickets, and pressure st
48
 
49
  ### World Modeling
50
 
51
- The agent observes a partially summarized cricket world:
52
-
53
- - score,
54
- - over/ball,
55
- - wickets,
56
- - target,
57
- - phase,
58
- - field,
59
- - batter profile,
60
- - bowler profile,
61
- - opponent plan,
62
- - previous outcome.
63
-
64
- It must maintain an internal model of what is happening and update that model after every ball.
65
 
66
  ### Self-Improvement
67
 
68
- The same environment can support:
69
-
70
- - heuristic curriculum training,
71
- - cached-opponent official evaluation,
72
- - live LLM opponent self-play,
73
- - future agent-vs-agent training.
74
 
75
  ## 3. Environment Flow
76
 
@@ -86,8 +67,6 @@ Within each batting or bowling phase, the tactical loop is:
86
  PRE_OVER -> PRE_BALL -> BALL_RESOLUTION -> POST_BALL -> next decision
87
  ```
88
 
89
- The captain can use tools at different points in that loop.
90
-
91
  ### Toss
92
 
93
  ```json
@@ -97,40 +76,21 @@ The captain can use tools at different points in that loop.
97
  ### Batting Tools
98
 
99
  ```json
100
- {"tool": "select_batter", "arguments": {"name": "Anchor", "style": "anchor", "aggression": 0.35, "rationale": "Preserve wickets in the middle overs."}}
101
- ```
102
-
103
- ```json
104
  {"tool": "set_strategy", "arguments": {"phase_intent": "consolidate", "aggression": 0.35, "rationale": "Rotate strike against spin and keep wickets in hand."}}
105
- ```
106
-
107
- ```json
108
  {"tool": "plan_shot", "arguments": {"shot_intent": "single", "target_area": "midwicket", "risk": "low", "trajectory": "ground", "rationale": "Field is spread, so take the easy single."}}
109
- ```
110
-
111
- ```json
112
  {"tool": "play_delivery", "arguments": {"shot_intent": "single", "explanation": "Work into the gap."}}
113
  ```
114
 
115
- ### Bowling Tools
116
 
117
- ```json
118
- {"tool": "choose_bowler", "arguments": {"name": "Death Specialist", "bowler_type": "pace", "style": "yorker", "rationale": "Attack the stumps at the death."}}
119
- ```
120
 
121
  ```json
 
122
  {"tool": "set_bowling_strategy", "arguments": {"bowler_type": "pace", "line": "stumps", "length": "full", "delivery_type": "yorker", "rationale": "Limit swing room."}}
123
- ```
124
-
125
- ```json
126
  {"tool": "set_field_setting", "arguments": {"setting": "Defensive"}}
127
- ```
128
-
129
- ```json
130
  {"tool": "plan_delivery", "arguments": {"bowler_type": "pace", "line": "stumps", "length": "full", "delivery_type": "yorker", "rationale": "Protect boundaries and force a low-percentage shot."}}
131
- ```
132
-
133
- ```json
134
  {"tool": "bowl_delivery", "arguments": {}}
135
  ```
136
 
@@ -146,309 +106,197 @@ The captain can use tools at different points in that loop.
146
  {"tool": "analyze_situation", "arguments": {"query_type": "match_situation"}}
147
  ```
148
 
149
- ## 4. OpenEnv Architecture
 
 
 
 
 
 
 
 
150
 
151
- The environment follows OpenEnv's standard client/server split.
 
 
 
 
152
 
153
  ```text
154
  LLM Agent / Evaluator
155
  |
156
- | WebSocket
157
  v
158
- OpenEnv FastAPI server
159
  |
160
  v
161
- CricketEnvironment
162
  |
163
- +--> MarkovCricketEngine
164
- +--> OpponentPolicy
165
- +--> CoherenceGrader
166
- +--> RewardCalculator
167
- ```
168
-
169
- Important files:
170
-
171
- - `server/app.py`: creates the OpenEnv server using `create_app(...)`.
172
- - `server/cricket_environment.py`: implements `reset`, `step`, and `state`.
173
- - `models.py`: defines `CricketAction`, `CricketObservation`, and `CricketState`.
174
- - `client.py`: defines the WebSocket client `CricketCaptainEnv`.
175
- - `inference.py`: runs random or OpenAI-compatible agents against the server.
176
- - `eval.py`: runs evaluation episodes and saves plots/raw logs.
177
-
178
- The standard OpenEnv lifecycle is:
179
-
180
- ```python
181
- obs = env.reset(...)
182
- obs = env.step(CricketAction(tool="...", arguments={...}))
183
- state = env.state
184
- ```
185
-
186
- This matters for competition compliance because clients do not need to import server internals. They interact through the OpenEnv API.
187
-
188
- ## 5. What The Observation Contains
189
-
190
- Each step returns a `CricketObservation` with fields like:
191
-
192
- - `game_state`: toss / batting / bowling / finished.
193
- - `strategic_phase`: pre_over / pre_ball / ball_resolution / post_ball.
194
- - `game_context`: score, wickets, over, ball, target, phase.
195
- - `declared_strategy`: current batting strategy.
196
- - `bowling_strategy`: current bowling plan.
197
- - `field_setting`: Aggressive / Balanced / Defensive.
198
- - `current_batter`: batter profile.
199
- - `current_bowler`: bowler profile.
200
- - `opponent_plan`: last visible opponent policy decision.
201
- - `last_outcome`: previous ball outcome plus tactical metadata such as event type, shot zone, delivery features, field pressure, and fielder effect.
202
- - `available_tools`: legal tools for current state.
203
- - `prompt_text`: rendered prompt for the LLM.
204
-
205
- The LLM sees enough information to reason tactically, but not the entire simulator internals.
206
-
207
- ## 6. Opponent Policies
208
-
209
- Opponent behavior lives in `server/opponent_policy.py`.
210
-
211
- There are three modes:
212
-
213
- ### `heuristic`
214
-
215
- Fast local policy. Useful for:
216
-
217
- - tests,
218
- - development,
219
- - cheap training rollouts,
220
- - baseline comparison.
221
-
222
- ### `llm_live`
223
-
224
- Calls an OpenAI-compatible LLM with a fixed prompt. Useful for:
225
-
226
- - demos,
227
- - realistic opponent behavior,
228
- - self-play-style experiments.
229
-
230
- The current default live opponent/captain model is:
231
-
232
- ```text
233
- google/gemma-4-26B-A4B-it via https://router.huggingface.co/v1
234
  ```
235
 
236
- In this mode, the opponent actually calls the configured model during the run.
237
-
238
- ### `llm_cached`
239
-
240
- Reads pre-recorded opponent decisions from JSONL. Useful for:
241
 
242
- - official leaderboard-style evaluation,
243
- - reproducibility,
244
- - preventing eval randomness.
 
 
 
 
 
 
 
 
 
245
 
246
- In this mode, the opponent does **not** call the configured model live. It replays the JSONL cache so every compared captain faces the same opponent decisions.
247
 
248
- The key idea:
249
 
250
- > Teams can change their agent however they want, but the evaluation opponent should be frozen.
251
-
252
- ## 7. Ball Physics And Markov Engine
253
-
254
- The simulation uses `server/markov_engine.py` plus the field/zone definitions in `server/field_model.py`.
255
-
256
- It supports:
257
-
258
- 1. Synthetic transition probabilities from `data/transition_probs.json`.
259
- 2. Cricsheet-derived transition tables from `data/processed/cricket_transitions_v1.pkl`.
260
-
261
- The upgraded ball resolver uses both sides' plans:
262
-
263
- ```text
264
- outcome ~ P(
265
- shot_plan,
266
- delivery_plan,
267
- batter_profile,
268
- bowler_profile,
269
- field_setting,
270
- phase,
271
- score,
272
- wickets,
273
- target_pressure
274
- )
275
- ```
276
 
277
- The engine first samples from the transition table, then applies a hybrid tactical layer:
 
 
 
 
278
 
279
- - hitter/finisher benefits aggressive shots,
280
- - anchor benefits low-risk rotation,
281
- - yorker/death specialist suppresses boundaries,
282
- - shot target zones (`cover`, `point`, `midwicket`, `long_on`, etc.) are matched against delivery line/length/variation,
283
- - field presets expand `Aggressive`, `Balanced`, and `Defensive` into named fielder zones,
284
- - boundary riders can cut off fours/sixes and inner ring fielders can save singles,
285
- - close catchers/slips/gully can convert edges into wickets,
286
- - wides/no-balls, drops, misfields, overthrows, run-outs, bowled/LBW routes, and caught-in-zone events add bounded stochastic noise,
287
- - high chase pressure makes defensive batting less useful.
288
 
289
- This keeps the simulator simple enough to train on while making actions meaningfully interact: a six toward an unprotected leg-side boundary is not the same as a lofted hit toward deep midwicket with a rider waiting.
290
 
291
- ## 8. Reward Design
292
 
293
- Rewards are intentionally multi-component. We do not want an agent that wins by gaming one metric.
294
 
295
- Main components:
296
 
297
- ### Result Quality
298
 
299
- Measures the long-horizon cricket objective:
 
 
 
 
 
 
 
 
 
 
 
300
 
301
- - win/loss,
302
- - score vs DLS/par,
303
- - chase success,
304
- - defense success,
305
- - wickets preserved or taken.
306
 
307
- This is the benchmark's primary outcome. The agent is trained over many simulated matches so it learns policies that improve match result, not just isolated ball-level actions.
308
 
309
- ### Dream11 Auxiliary Signal
310
 
311
- Provides a dense cricket contribution proxy:
312
 
313
- - batting runs, strike rate, boundaries,
314
- - bowling wickets, dots, economy,
315
- - milestone and dismissal bonuses.
316
 
317
- This helps training get intermediate signal, but it does **not** replace the win/loss objective.
318
 
319
- ### Plan-Action Coherence
 
 
320
 
321
- Checks whether the action matches the declared plan.
322
 
323
- Example:
324
 
325
- - Declared aggression `0.30` plus `single` is coherent.
326
- - Declared aggression `0.30` plus `six` is less coherent.
327
 
328
- ### Strategic Adaptation
329
 
330
- Rewards plans that change with context:
331
 
332
- - new phase,
333
- - target pressure,
334
- - wickets down,
335
- - previous reflection,
336
- - opponent behavior.
337
 
338
- ### Opponent Awareness
339
 
340
- Rewards plans that mention or respond to:
 
 
341
 
342
- - field setting,
343
- - bowler type,
344
- - batter style,
345
- - opponent plan,
346
- - phase/matchup.
 
347
 
348
- ### Regret-Style Score
349
 
350
- Compares chosen action quality against simple heuristic alternatives.
351
 
352
- ### Tool Efficiency
 
 
 
 
 
353
 
354
- Rewards useful `analyze_situation` calls and penalizes spam indirectly.
355
 
356
- ### Format Validity
357
 
358
- Rewards valid JSON and valid tools.
359
 
360
- The episode reward combines:
361
 
362
- ```text
363
- 25% result quality / match outcome
364
- 10% Dream11 dense cricket proxy
365
- 30% strategy bundle
366
- 20% tool efficiency
367
- 15% format validity
368
  ```
369
-
370
- The strategy bundle includes coherence, adaptation, opponent awareness, and regret-style scoring.
371
-
372
- ## 9. Data Curation Pipeline
373
-
374
- The data pipeline is designed to keep training and environment behavior aligned.
375
-
376
- ### Step 1: Curate Ball Outcomes
377
-
378
- Script:
379
-
380
- ```bash
381
- python scripts/curate_transitions.py --format t20
382
  ```
383
 
384
- Outputs:
385
-
386
- - `data/processed/ball_outcomes_t20_v1.pkl`
387
- - `data/processed/cricket_transitions_v1.pkl`
388
 
389
- The rich ball outcome records include:
 
 
390
 
391
- - both innings,
392
- - target,
393
- - required rate,
394
- - legal vs extra delivery,
395
- - batter name,
396
- - bowler name,
397
- - bowler type,
398
- - dismissal type,
399
- - phase,
400
- - score before ball,
401
- - wickets before ball,
402
- - runs and wicket outcome.
403
 
404
- ### Step 2: Build Player Profiles
405
 
406
- Script:
407
 
408
  ```bash
409
- python scripts/build_player_profiles.py
 
410
  ```
411
 
412
- Output:
413
-
414
- - `data/processed/player_profiles_t20_v1.json`
415
-
416
- Profiles include:
417
-
418
- - batter style: anchor, balanced, hitter, finisher,
419
- - bowler style: pace, spin, death specialist, economy, wicket-taker,
420
- - phase strengths,
421
- - economy,
422
- - strike rate,
423
- - dot rate,
424
- - Dream11-style pressure proxy.
425
 
426
- ### Step 3: Build Evaluation Pack
427
-
428
- Script:
429
 
430
  ```bash
431
  python scripts/build_eval_pack.py --eval-pack-id adaptive_t20_v1
432
  ```
433
 
434
- Output:
435
-
436
- - `data/eval_packs/adaptive_t20_v1.json`
437
-
438
- The pack has:
439
-
440
- - dev scenarios,
441
- - official scenarios,
442
- - chase states,
443
- - defense states,
444
- - collapse states,
445
- - death-over states,
446
- - matchup states,
447
- - opponent config.
448
-
449
- ### Step 4: Generate Opponent Cache
450
-
451
- Script:
452
 
453
  ```bash
454
  python scripts/generate_opponent_cache.py \
@@ -458,199 +306,90 @@ python scripts/generate_opponent_cache.py \
458
  --output data/opponent_cache/adaptive_t20_v1.jsonl
459
  ```
460
 
461
- For official evaluation, use `llm_cached` mode with a fixed cache.
462
-
463
- ## 10. Training Pipeline
464
-
465
- The training plan uses SFT only as a bootstrapping stage. The main optimization path remains GRPO.
466
-
467
- ### Stage 0: SFT Tool Warmup
468
 
469
- Command:
470
 
471
  ```bash
472
- python train.py sft-data --output data/training/tool_sft_examples.jsonl
 
 
473
  ```
474
 
475
- Purpose:
476
-
477
- - teach valid JSON,
478
- - teach tool names,
479
- - teach argument schemas,
480
- - reduce parse errors before RL.
481
-
482
- This does not replace RL. It makes RL less wasteful.
483
-
484
- ### Stage 1: GRPO Format / Tool Correctness
485
 
486
- Command:
487
 
488
  ```bash
489
- python train.py train --stage 1 --steps 200 --model Qwen/Qwen2.5-7B-Instruct
 
490
  ```
491
 
492
- Purpose:
493
-
494
- - train valid tool calls,
495
- - reduce invalid JSON,
496
- - stabilize action format.
497
-
498
- ### Stage 2: GRPO Strategic Behavior
499
-
500
- Command:
501
 
502
  ```bash
503
- python train.py train --stage 2 --steps 600 --model ./checkpoints/stage1_final
 
504
  ```
505
 
506
- Purpose:
507
-
508
- - improve coherence,
509
- - improve adaptation,
510
- - improve opponent awareness,
511
- - improve tool efficiency,
512
- - improve match result quality.
513
-
514
- ### Evaluation
515
-
516
- Command:
517
 
518
  ```bash
519
- python eval.py --episodes 20 --env-url "$CRICKET_CAPTAIN_ENV_URL" --eval-pack-id adaptive_t20_v1
520
  ```
521
 
522
- Expected comparison:
523
-
524
- - random/untrained baseline,
525
- - SFT-warmed model,
526
- - GRPO-trained model.
527
-
528
- For the competition, we should produce plots showing:
529
-
530
- - reward over training,
531
- - parse error rate,
532
- - coherence,
533
- - adaptation,
534
- - opponent awareness,
535
- - score/chase/defense metrics.
536
-
537
- ## 11. How This Complies With The Competition Instructions
538
-
539
- The competition requires:
540
-
541
- ### Use OpenEnv
542
-
543
- Implemented through:
544
-
545
- - `server/app.py`
546
- - `server/cricket_environment.py`
547
- - `models.py`
548
- - `client.py`
549
-
550
- The environment follows `reset`, `step`, and `state`.
551
-
552
- ### Training Script With HF TRL / Unsloth
553
-
554
- Implemented through:
555
-
556
- - `train.py`
557
-
558
- It uses Hugging Face TRL GRPO paths when training dependencies are installed.
559
-
560
- ### Hosted Environment
561
-
562
- The repo has Hugging Face Spaces metadata in `README.md` and a Docker-based app path. The server binds to `0.0.0.0`, and remote clients should use `CRICKET_CAPTAIN_ENV_URL`.
563
-
564
- ### README With Problem / Environment / Results
565
-
566
- The README now explains:
567
-
568
- - problem statement,
569
- - tools,
570
- - reward architecture,
571
- - environment design,
572
- - data pipeline,
573
- - Lightning/HF runtime notes.
574
-
575
- Still needed for final submission:
576
-
577
- - actual HF Space URL,
578
- - training result plots,
579
- - mini-blog/video link.
580
-
581
- ### Show Improvement
582
-
583
- The environment and scripts support this, but the final artifact still needs a real training run with plots.
584
-
585
- Minimum evidence to add:
586
-
587
- - random baseline metrics,
588
- - trained model metrics,
589
- - reward curve,
590
- - parse error curve,
591
- - example before/after decisions.
592
-
593
- ## 12. Recommended Demo Story
594
-
595
- A simple judge-friendly demo:
596
-
597
- 1. Show a late chase scenario:
598
-
599
- ```text
600
- Over 16.0, 128/5, target 172
601
- ```
602
 
603
- 2. Random/untrained model:
604
 
605
- - may choose invalid tools,
606
- - may attack blindly,
607
- - may ignore field/opponent.
608
 
609
- 3. Trained/adaptive model:
610
 
611
- - checks target pressure,
612
- - selects finisher,
613
- - plans boundary zones,
614
- - responds after wicket/boundary,
615
- - changes risk level.
 
 
 
 
 
 
 
 
 
 
 
 
 
616
 
617
- 4. Show metrics:
618
 
619
- - parse errors down,
620
- - adaptation up,
621
- - opponent awareness up,
622
- - reward up.
623
 
624
- This tells the story clearly:
 
 
 
 
625
 
626
- > The model learned to captain, not just output JSON.
627
 
628
- ## 13. Current Status
 
 
 
629
 
630
- Implemented:
631
 
632
- - OpenEnv environment.
633
- - Rich strategic tool surface.
634
- - Opponent policies.
635
- - Adaptive eval pack.
636
- - T20 data curation script.
637
- - Player profile builder.
638
- - Opponent cache generator.
639
- - GRPO training script path.
640
- - SFT bootstrap data generator.
641
- - Eval and plotting scripts.
642
 
643
- Verified with smoke tests:
644
 
645
- - Python compile checks.
646
- - Lint checks.
647
- - `train.py test`.
648
- - `train.py sft-data`.
649
- - opponent cache generation.
650
- - server startup.
651
- - random inference run.
652
- - eval run with plots.
653
 
654
- Next major task:
655
 
656
- Run real training on compute and commit the resulting plots/metrics for the final submission.
 
30
 
31
  ## 2. Fit With OpenEnv Competition Themes
32
 
 
 
33
  ### Multi-Agent Interactions
34
 
35
  The submitted captain agent plays against an opponent policy. The opponent can be:
36
 
37
+ - `heuristic`: fast format-aware cricket logic (T5/T20/ODI rules).
38
+ - `cricsheet`: real Cricsheet ball-by-ball match data sampled by game context.
39
+ - `llm_live`: live OpenAI-compatible LLM opponent (google/gemma-4-26B-A4B-it via HF Router).
40
  - `llm_cached`: replayed opponent decisions for reproducible evaluation.
41
 
42
  This tests whether the agent can reason about another actor's incentives, field settings, and likely plans.
 
47
 
48
  ### World Modeling
49
 
50
+ The agent observes a partially summarized cricket world: score, over/ball, wickets, target, phase, field, batter profile, bowler profile, previous outcome. It must maintain an internal model of what is happening and update that model after every ball.
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
  ### Self-Improvement
53
 
54
+ The same environment can support heuristic curriculum training, cached-opponent official evaluation, live LLM opponent self-play, and future agent-vs-agent training.
 
 
 
 
 
55
 
56
  ## 3. Environment Flow
57
 
 
67
  PRE_OVER -> PRE_BALL -> BALL_RESOLUTION -> POST_BALL -> next decision
68
  ```
69
 
 
 
70
  ### Toss
71
 
72
  ```json
 
76
  ### Batting Tools
77
 
78
  ```json
79
+ {"tool": "select_batter", "arguments": {"name": "Virat Kohli", "style": "anchor", "aggression": 0.35, "rationale": "Preserve wickets in the middle overs."}}
 
 
 
80
  {"tool": "set_strategy", "arguments": {"phase_intent": "consolidate", "aggression": 0.35, "rationale": "Rotate strike against spin and keep wickets in hand."}}
 
 
 
81
  {"tool": "plan_shot", "arguments": {"shot_intent": "single", "target_area": "midwicket", "risk": "low", "trajectory": "ground", "rationale": "Field is spread, so take the easy single."}}
 
 
 
82
  {"tool": "play_delivery", "arguments": {"shot_intent": "single", "explanation": "Work into the gap."}}
83
  ```
84
 
85
+ `plan_shot` is **not** an overhead tool. Only `set_strategy`, `set_bowling_strategy`, `plan_delivery`, `reflect_after_ball`, and `analyze_situation` count against the three-per-over limit (see the Tool budget section below).
86
 
87
+ ### Bowling Tools
 
 
88
 
89
  ```json
90
+ {"tool": "choose_bowler", "arguments": {"name": "Jasprit Bumrah", "bowler_type": "pace", "style": "yorker", "rationale": "Attack the stumps at the death."}}
91
  {"tool": "set_bowling_strategy", "arguments": {"bowler_type": "pace", "line": "stumps", "length": "full", "delivery_type": "yorker", "rationale": "Limit swing room."}}
 
 
 
92
  {"tool": "set_field_setting", "arguments": {"setting": "Defensive"}}
 
 
 
93
  {"tool": "plan_delivery", "arguments": {"bowler_type": "pace", "line": "stumps", "length": "full", "delivery_type": "yorker", "rationale": "Protect boundaries and force a low-percentage shot."}}
 
 
 
94
  {"tool": "bowl_delivery", "arguments": {}}
95
  ```
96
 
 
106
  {"tool": "analyze_situation", "arguments": {"query_type": "match_situation"}}
107
  ```
108
 
109
+ ## 4. Tool budget
110
+
111
+ The environment enforces a **3-call overhead budget per over** (see `CricketEnvironment.TOOL_BUDGET_PER_OVER` and `TOOL_FINE_PER_EXCESS` in `server/cricket_environment.py`).
112
+
113
+ **Overhead tools** (increment the per-over counter; the 4th+ in the same over are fined):
114
+ `set_strategy`, `set_bowling_strategy`, `plan_delivery`, `reflect_after_ball`, `analyze_situation`
115
+
116
+ **Not overhead** (do not use the 3 free “slots”):
117
+ `play_delivery`, `bowl_delivery`, `plan_shot`, `call_toss`, `select_batter`, `choose_bowler`, `set_field_setting`, `set_match_plan`, `update_match_plan`
118
 
119
+ Each overhead call **beyond the third in that over** incurs an immediate **−0.04** step reward. The prompt shows `Tool budget: N/3 overhead calls used this over`.
120
+
121
+ **Training connection:** `train.py train` uses real `CricketEnvironment` steps, so these fines are part of the return GRPO optimizes. That keeps long-horizon training aligned with the benchmark: agents must choose when to pay for `analyze_situation` and `reflect_after_ball`, while `set_match_plan` / `update_match_plan` let them carry structure across overs without spending overhead budget.
122
+
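To make the accounting concrete, here is a minimal sketch of the per-over counter and fine. The constant names mirror the ones cited above; the helper function itself is hypothetical and is not the code in `server/cricket_environment.py`.

```python
# Hypothetical sketch of the per-over overhead accounting described above.
# Constant names mirror CricketEnvironment.TOOL_BUDGET_PER_OVER / TOOL_FINE_PER_EXCESS.
TOOL_BUDGET_PER_OVER = 3
TOOL_FINE_PER_EXCESS = 0.04
OVERHEAD_TOOLS = {
    "set_strategy", "set_bowling_strategy", "plan_delivery",
    "reflect_after_ball", "analyze_situation",
}

def overhead_fine(tool: str, over_counter: dict) -> float:
    """Return the step-reward penalty for one tool call; mutates the per-over counter."""
    if tool not in OVERHEAD_TOOLS:
        return 0.0  # ball-advancing calls, plan_shot, and match-plan tools are free
    over_counter["overhead"] = over_counter.get("overhead", 0) + 1
    if over_counter["overhead"] > TOOL_BUDGET_PER_OVER:
        return -TOOL_FINE_PER_EXCESS  # 4th+ overhead call in the same over
    return 0.0

counter = {}  # reset at the start of every over
for tool in ["analyze_situation", "set_strategy", "plan_delivery", "reflect_after_ball"]:
    print(tool, overhead_fine(tool, counter))  # the last call prints -0.04
```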
123
+ ## 5. OpenEnv Architecture
124
 
125
  ```text
126
  LLM Agent / Evaluator
127
  |
128
+ | WebSocket (OpenEnv)
129
  v
130
+ FastAPI server (server/app.py)
131
  |
132
  v
133
+ CricketEnvironment (server/cricket_environment.py)
134
  |
135
+ +--> MarkovCricketEngine (server/markov_engine.py)
136
+ +--> FormatMapper (server/format_mapper.py)
137
+ +--> OpponentPolicy (server/opponent_policy.py)
138
+ +--> PlayerRoster (server/player_roster.py)
139
+ +--> CoherenceGrader (server/coherence_grader.py)
140
+ +--> RewardCalculator (server/reward_calculator.py)
141
+ +--> FieldModel (server/field_model.py)
 
 
 
 
142
  ```
143
 
144
+ Key files:
 
 
 
 
145
 
146
+ | File | Role |
147
+ |------|------|
148
+ | `server/app.py` | OpenEnv server entry point |
149
+ | `server/cricket_environment.py` | `reset`, `step`, `state` implementation |
150
+ | `server/format_mapper.py` | T5/T20/ODI closest-format selector; phase-aware shot weights, batter/bowler roles |
151
+ | `server/opponent_policy.py` | Heuristic, Cricsheet, live LLM, cached LLM opponent policies |
152
+ | `server/player_roster.py` | Fuzzy player lookup; batter/bowler profile extractor |
153
+ | `models.py` | `CricketAction`, `CricketObservation`, `CricketState` |
154
+ | `client.py` | WebSocket client `CricketCaptainEnv` |
155
+ | `inference.py` | Random + LLM agent evaluation |
156
+ | `train.py` | MT-GRPO + SFT training pipeline |
157
+ | `eval.py` | Coherence heatmaps, reward curves, tool analytics |
158
 
159
+ ## 6. Format-Aware Rules
160
 
161
+ `server/format_mapper.py` auto-selects T5 / T20 / ODI rules by `|max_overs − format_overs|`:
162
 
163
+ | Format | max_overs | Key differences |
164
+ |--------|-----------|-----------------|
165
+ | T5 | ≤ 7 | High-aggression throughout, powerplay dominates all overs |
166
+ | T20 | 8–35 | Three phases (PP/Middle/Death); spin-heavy middle |
167
+ | ODI | > 35 | Four phases (PP/Middle-early/Middle-late/Death); anchor roles |
 
 
 
 
 
168
 
169
+ The format mapper provides:
170
+ - **Phase-aware shot weights**: boundary/six probability rises sharply in death overs
171
+ - **Batter roles** with `overs_active` windows (opener, anchor, middle_order, finisher)
172
+ - **Bowler roles** with `preferred_phases` (pace_opener, spin_controller, death_specialist)
173
+ - **Bowling strategy** per phase (line, length, delivery_type, field_setting)
174
 
175
+ Both the heuristic opponent and the `select_batter` / `choose_bowler` tools draw from these tables.
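As a quick illustration of the selection rule, a band check equivalent to the table above fits in a few lines. The real selector in `server/format_mapper.py` is described as picking the closest format by `|max_overs − format_overs|`, so treat the cutoffs below as the documented bands, not the exact code.

```python
# Illustration only: format selection by the bands in the table above.
def pick_format(max_overs: int) -> str:
    if max_overs <= 7:
        return "T5"
    if max_overs <= 35:
        return "T20"
    return "ODI"

for overs in (2, 5, 20, 50):
    print(overs, pick_format(overs))  # 2->T5, 5->T5, 20->T20, 50->ODI
```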
 
 
 
 
 
 
 
 
176
 
177
+ ## 7. Player Rosters
178
 
179
+ `server/player_roster.py` loads team profiles from `data/player_profiles/` — 10 T20I squads: India, Australia, England, Pakistan, South Africa, New Zealand, West Indies, Sri Lanka, Bangladesh, Afghanistan.
180
 
181
+ When the agent calls `select_batter` or `choose_bowler` with a player name, the roster performs **fuzzy lookup** (exact → surname → word-overlap) and fills in real aggression, batting/bowling style, and phase strengths from the profile.
182
 
183
+ ## 8. What The Observation Contains
184
 
185
+ Each step returns a `CricketObservation` with:
186
 
187
+ - `game_state`: toss / batting / bowling / finished
188
+ - `strategic_phase`: pre_over / pre_ball / ball_resolution / post_ball
189
+ - `game_context`: score, wickets, over, ball, target, phase, run_rate, req_rate
190
+ - `declared_strategy`: current batting strategy (aggression, intent, rationale)
191
+ - `bowling_strategy`: current bowling plan
192
+ - `field_setting`: Aggressive / Balanced / Defensive
193
+ - `current_batter`: batter profile (style, aggression, phase strengths)
194
+ - `current_bowler`: bowler profile
195
+ - `last_outcome`: ball outcome + tactical metadata (event type, shot zone, delivery features, field pressure, fielder effect)
196
+ - `available_tools`: legal tools for current state (phase-gated)
197
+ - `tool_budget`: overhead calls used this over vs 3-call limit
198
+ - `prompt_text`: rendered prompt for the LLM
199
 
200
+ The LLM sees enough information to reason tactically, but not simulator internals.
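In client terms, each of these fields arrives on the observation object returned by `reset` / `step`. A rough usage sketch is below; the constructor argument and attribute access are assumptions, so check `client.py` and `models.py` for the actual signatures.

```python
# Rough usage sketch (constructor argument and attribute names are assumptions;
# the reset/step/state lifecycle and CricketAction fields follow this repo's docs).
import os
from client import CricketCaptainEnv
from models import CricketAction

env = CricketCaptainEnv(os.environ.get("CRICKET_CAPTAIN_ENV_URL", "ws://localhost:8000"))
obs = env.reset()
print(obs.game_state, obs.strategic_phase)   # e.g. "toss", "pre_over"
print(obs.available_tools)                   # phase-gated legal tools
print(obs.tool_budget)                       # overhead calls used this over

# analyze_situation counts against the overhead budget and is only legal
# where the phase-gated tool list allows it.
obs = env.step(CricketAction(tool="analyze_situation",
                             arguments={"query_type": "match_situation"}))
print(obs.game_context)                      # score / wickets / over / target
```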
 
 
 
 
201
 
202
+ ## 9. Opponent Policies
203
 
204
+ Four modes in `server/opponent_policy.py`:
205
 
206
+ ### `heuristic`
207
 
208
+ Format-aware local policy using T5/T20/ODI rules from `format_mapper.py`. Picks shot intent from phase-weighted distributions, adjusts for wicket pressure (shifts toward conservative intents once 7+ wickets are down), and selects batter/bowler roles by current over and format. Fast, and no API key needed.
 
 
209
 
210
+ ### `cricsheet`
211
 
212
+ Samples real Cricsheet ball-by-ball deliveries indexed by `(phase, wickets_band, innings_type)`. Automatically selects T20 or ODI data based on `max_overs`:
213
+ - ≤ 25 overs → `ball_outcomes_t20_v1.pkl` (1.17M T20 deliveries from 5,176 matches)
214
+ - > 25 overs → `ball_outcomes_odi_v1.pkl` (1.65M ODI deliveries from 3,116 matches)
215
 
216
+ Progressive fallback widening (drop innings_type → drop wickets_band → any record in the phase) ensures no dead buckets; if the data file is absent, the policy falls back to the heuristic.
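The bucket lookup plus widening reduces to a short loop over progressively looser keys. A hypothetical sketch (the real logic lives in `server/opponent_policy.py`):

```python
import random

# Hypothetical sketch of bucketed sampling with progressive fallback widening.
def sample_delivery(outcomes, phase, wickets_band, innings_type, rng=random):
    """outcomes maps (phase, wickets_band, innings_type) -> list of ball records."""
    keys = [
        (phase, wickets_band, innings_type),  # exact bucket
        (phase, wickets_band, None),          # drop innings_type
        (phase, None, None),                  # drop wickets_band: any record in this phase
    ]
    for key in keys:
        bucket = outcomes.get(key)
        if bucket:
            return rng.choice(bucket)
    return None  # caller falls back to the heuristic opponent

outcomes = {("death", None, None): [{"runs": 6}, {"runs": 1}, {"runs": 0, "wicket": True}]}
print(sample_delivery(outcomes, "death", "3-5", "chase"))  # widened to the phase-only bucket
```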
217
 
218
+ ### `llm_live`
219
 
220
+ Calls `google/gemma-4-26B-A4B-it` via HF Router (or any OpenAI-compatible API). Graceful heuristic fallback when no API key is present, so local development never breaks.
 
221
 
222
+ ### `llm_cached`
223
 
224
+ Replays pre-recorded opponent decisions from JSONL. Does **not** call the configured model live. Use for official leaderboard-style evaluation where every compared captain faces identical opponent decisions.
225
 
226
+ ## 10. Ball Physics And Markov Engine
 
 
 
 
227
 
228
+ The simulation uses `server/markov_engine.py` plus field/zone definitions in `server/field_model.py`.
229
 
230
+ Ball transition tables keyed by `(over, wickets, score_band, phase, bowler_type)`:
231
+ 1. **Cricsheet-derived**: `data/processed/cricket_transitions_v1.pkl` when available
232
+ 2. **Calibrated synthetic**: `data/transition_probs.json` as fallback
233
 
234
+ After the base Markov draw, a **hybrid tactical layer** applies (see the sketch after this list):
235
+ - Shot target zones (`cover`, `point`, `midwicket`, `long_on`, …) matched against delivery line/length/variation
236
+ - Field presets (`Aggressive`, `Balanced`, `Defensive`) expand into named fielder zones
237
+ - Boundary riders cut off fours/sixes; inner-ring fielders save singles; slips/gully convert edges
238
+ - Wides/no-balls, drops, misfields, overthrows, run-outs, caught-in-zone events add bounded stochastic noise
239
+ - High chase pressure makes defensive batting less useful
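+ A condensed sketch of the two-step resolution (table layout, probabilities, and helper names are illustrative assumptions; the key tuple and the layering order come from the description above):

+ ```python
+ import random
+ 
+ def resolve_ball(transitions, context, field_setting):
+     # 1. Base Markov draw, keyed as described above.
+     key = (context["over"], context["wickets"], context["score_band"],
+            context["phase"], context["bowler_type"])
+     outcomes, probs = transitions[key]          # e.g. (["0", "1", "2", "4", "6", "W"], [...])
+     outcome = random.choices(outcomes, weights=probs, k=1)[0]
+ 
+     # 2. Hybrid tactical layer: field placement can downgrade boundaries (illustrative odds).
+     if outcome == "4" and field_setting == "Defensive" and random.random() < 0.3:
+         outcome = "2"                           # boundary rider cuts the four off
+     return outcome
+ ```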
240
 
241
+ ## 11. Reward Design
242
 
243
+ Four-rubric composite reward:
244
 
245
+ | Rubric | Weight | Frequency | Measures |
246
+ |--------|--------|-----------|----------|
247
+ | `r_cricket` | **45%** | Per ball | Dream11 proxy: runs, wickets, dots, milestones, economy, strike rate |
248
+ | `r_behavior` | **25%** | Every turn | Coherence (50%) + adaptation (20%) + opponent_awareness (20%) + regret (10%) |
249
+ | `r_result` | **20%** | Innings/episode end | Win/loss vs DLS par, target margin, wickets |
250
+ | `r_validity` | **10%** | Every turn | Valid tool-call structure and legal phase-gated tool use |
251
 
252
+ Plus a **progress bonus** added to `r_result`: `min(0.25, tool_calls_made / 40.0)` — caps at +0.25 once the agent makes ≥10 tool calls. Directly rewards escaping the planning-loop trap (where the policy maxes overhead tools without ever calling `play_delivery`).
253
 
254
+ **Why these weights** (rebalanced from the original 55/25/15/5): partial-trajectory training means `r_result` rarely fires (episodes truncate before completion). Putting 55% weight on a signal that fires <5% of the time washes out the gradient. The new 45/25/20/10 split mirrors the SWE-RL recipe (60% intermediate / 40% terminal) and matches what working coding-agent RL setups actually use.
255
 
256
+ `r_tools` is computed and logged but excluded from the composite — tool discipline is measured through outcomes.
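+ Putting the weights and the progress bonus together, the composite reduces to roughly the following (a sketch; whether the progress bonus is weighted together with `r_result` or added afterwards is an assumption here):

+ ```python
+ def composite_reward(r_cricket, r_behavior, r_result, r_validity, tool_calls_made):
+     # Progress bonus: min(0.25, tool_calls_made / 40.0), caps at +0.25 once >= 10 calls.
+     progress_bonus = min(0.25, tool_calls_made / 40.0)
+     return (
+         0.45 * r_cricket
+         + 0.25 * r_behavior
+         + 0.20 * (r_result + progress_bonus)   # assumption: bonus shares the r_result weight
+         + 0.10 * r_validity
+     )
+ ```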
257
 
258
+ ### Coherence Scoring (batting)
259
 
 
 
 
 
 
 
260
  ```
261
+ coherence = aggression_match × rationale_specificity × phase_appropriate
262
+ aggression_match = 1 − |declared_aggression − shot_aggression_proxy|
263
+ rationale_specificity = (word_count_score + cricket_keyword_density) / 2
264
+ phase_appropriate = 1 − |declared_aggression − phase_baseline|
265
+ phase_baselines: powerplay=0.55, middle=0.35, death=0.75
 
 
 
 
 
 
 
 
266
  ```
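+ A quick worked example with illustrative numbers:

+ ```python
+ declared_aggression = 0.40      # declared "consolidate" intent
+ shot_aggression_proxy = 0.60    # but the shot actually played was fairly aggressive
+ phase_baseline = 0.35           # middle overs
+ 
+ aggression_match = 1 - abs(declared_aggression - shot_aggression_proxy)    # 0.80
+ rationale_specificity = (0.7 + 0.5) / 2                                     # 0.60 (assumed sub-scores)
+ phase_appropriate = 1 - abs(declared_aggression - phase_baseline)           # 0.95
+ 
+ coherence = aggression_match * rationale_specificity * phase_appropriate    # ~0.456
+ print(round(coherence, 3))
+ ```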
267
 
268
+ ### Single-Stage Training with Format Curriculum
 
 
 
269
 
270
+ The original two-stage (format → strategy) curriculum was collapsed because Qwen3.5-4B
271
+ already does tool calling natively (XML+JSON via `_parse_completion`). The full composite
272
+ reward fires from step 0.
273
 
274
+ What remains is a **format-length curriculum within the warmup config**: per-scenario
275
+ `max_overs` is sampled from `[2, 2, 2, 2, 2, 3, 3, 3, 4, 4, 5]` (heavy on T2-T3 so episodes
276
+ actually complete inside the token budget). The main run then trains on full T20 (20-over)
277
+ matches, optionally resuming from the warmup adapter.
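+ The warmup sampling is just a weighted choice over that list, for example (a sketch, not the exact loader code):

+ ```python
+ import random
+ 
+ CURRICULUM_OVERS = [2, 2, 2, 2, 2, 3, 3, 3, 4, 4, 5]
+ 
+ def sample_max_overs() -> int:
+     # ~45% of warmup scenarios are 2-over matches; only ~9% reach 5 overs.
+     return random.choice(CURRICULUM_OVERS)
+ ```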
 
 
 
 
 
 
 
 
278
 
279
+ ## 12. Data Curation Pipeline
280
 
281
+ ### Step 1: Curate Ball Outcomes
282
 
283
  ```bash
284
+ python scripts/curate_transitions.py --format t20 # → ball_outcomes_t20_v1.pkl
285
+ python scripts/curate_transitions.py --format odi # → ball_outcomes_odi_v1.pkl
286
  ```
287
 
288
+ Both files already generated:
289
+ - `data/processed/ball_outcomes_t20_v1.pkl` — 1.17M T20 deliveries, 5,176 matches
290
+ - `data/processed/ball_outcomes_odi_v1.pkl` — 1.65M ODI deliveries, 3,116 matches
291
+ - `data/processed/cricket_transitions_v1.pkl` — 5,138 Markov keys, 2,878 high-confidence
 
 
 
 
 
 
 
 
 
292
 
293
+ ### Step 2: Build Evaluation Pack
 
 
294
 
295
  ```bash
296
  python scripts/build_eval_pack.py --eval-pack-id adaptive_t20_v1
297
  ```
298
 
299
+ ### Step 3: Generate Opponent Cache
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
300
 
301
  ```bash
302
  python scripts/generate_opponent_cache.py \
 
306
  --output data/opponent_cache/adaptive_t20_v1.jsonl
307
  ```
308
 
309
+ ## 13. Training Pipeline
 
 
 
 
 
 
310
 
311
+ ### Recommended: Single-Command Chain
312
 
313
  ```bash
314
+ # Warmup (5-over curriculum, 25 steps) → Main (20-over T20, 100 steps).
315
+ # Main auto-resumes from warmup adapter at ./checkpoints/stage2_final.
316
+ bash scripts/run_warmup_then_main.sh
317
  ```
318
 
319
+ ### Run Components Individually
 
 
 
 
 
 
 
 
 
320
 
321
+ **Warmup only — short curriculum, bootstraps the LoRA adapter:**
322
 
323
  ```bash
324
+ PYTORCH_ALLOC_CONF=expandable_segments:True \
325
+ python train.py train --config configs/cricket_train_warmup.yaml
326
  ```
327
 
328
+ **Main only — full T20, resumes warmup adapter (or fresh if `resume_from` is empty):**
 
 
 
 
 
 
 
 
329
 
330
  ```bash
331
+ PYTORCH_ALLOC_CONF=expandable_segments:True \
332
+ python train.py train --config configs/cricket_train.yaml
333
  ```
334
 
335
+ **Optional SFT bootstrap** (legacy, not needed for Qwen3.5-4B which has native tool calling):
 
 
 
 
 
 
 
 
 
 
336
 
337
  ```bash
338
+ python train.py sft-data --output data/training/tool_sft_examples.jsonl
339
  ```
340
 
341
+ `train.py train` uses TRL `GRPOTrainer` with `environment_factory=CricketCaptainToolEnv`. The captain being trained is loaded locally by Transformers/TRL and interacts with live environment instances through tool methods. `--opponent-mode llm_live` affects only the adversary; it does not mean the trained captain is served through the HF inference endpoint.
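+ Conceptually the wiring looks like the sketch below (the constructor arguments and import path are assumptions pieced together from this description and the YAML keys; `train.py` is the authoritative version):

+ ```python
+ from trl import GRPOConfig, GRPOTrainer
+ 
+ from cricket_env import CricketCaptainToolEnv  # hypothetical import path; real class lives in this repo
+ 
+ config = GRPOConfig(
+     output_dir="./checkpoints",
+     num_generations=4,            # GRPO group size from the main-run config
+     max_completion_length=3072,
+     beta=0.0,                     # no reference model, as in cricket_train.yaml
+ )
+ 
+ trainer = GRPOTrainer(
+     model="Qwen/Qwen3.5-4B",
+     args=config,
+     environment_factory=CricketCaptainToolEnv,  # multi-turn tool-calling rollouts
+     # prompt dataset and reward plumbing omitted here; see train.py for the full call
+ )
+ trainer.train()
+ ```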
 
342
 
343
+ The default training model is `Qwen/Qwen3.5-4B`. The default live opponent model is `google/gemma-4-26B-A4B-it`. Roster-backed training requires `--agent-team` or `env.agent_team` in YAML so `select_batter` and `choose_bowler` use real player profiles instead of generic names.
344
 
345
+ ## 14. Current Status (2026-04-25)
 
 
346
 
347
+ ### Implemented and verified
348
 
349
+ | Component | Status |
350
+ |-----------|--------|
351
+ | OpenEnv server + client | ✅ |
352
+ | 14-tool strategic surface | ✅ |
353
+ | 4-rubric reward system | ✅ |
354
+ | Tool budget system (3/over, −0.04 fine) | ✅ |
355
+ | Format mapper (T5/T20/ODI) | ✅ |
356
+ | Player rosters (10 T20I teams, fuzzy lookup) | ✅ |
357
+ | Cricsheet T20 data (1.17M deliveries) | ✅ |
358
+ | Cricsheet ODI data (1.65M deliveries) | ✅ |
359
+ | Heuristic opponent (format-aware) | ✅ |
360
+ | Cricsheet opponent (T20+ODI, context-indexed) | ✅ |
361
+ | LLM live opponent (HF Router / OpenAI-compatible API) | ✅ |
362
+ | LLM cached opponent | ✅ |
363
+ | GRPO training script (`environment_factory` agent rollouts) | ✅ |
364
+ | SFT data generator | ✅ |
365
+ | Gradio demo UI | ✅ |
366
+ | Colab training notebook | ✅ |
367
 
368
+ ### Verified end-to-end (2026-04-25)
369
 
370
+ Three of the four opponent modes (heuristic, cricsheet, llm_live) verified at 5-over inference + train-smoke:
 
 
 
371
 
372
+ | Mode | inference parse_err | train-smoke r_validity | coherence |
373
+ |------|--------------------|-----------------------|-----------|
374
+ | heuristic | 0% | 1.0 | 0.556 |
375
+ | cricsheet | 0% | 1.0 | 0.620 |
376
+ | llm_live | 0% | 1.0 | 0.537 |
377
 
378
+ ### Pending for submission
379
 
380
+ - Real GRPO training run with reward curves (requires HF compute)
381
+ - HF Space deployment URL
382
+ - Training-vs-baseline comparison plots
383
+ - Mini-blog / video
384
 
385
+ ## 15. Recommended Demo Story
386
 
387
+ 1. **Show a late chase scenario**: Over 16.0, 128/5, target 172
 
 
 
 
 
 
 
 
 
388
 
389
+ 2. **Random/untrained model**: invalid tools, blind aggression, ignores field/opponent
390
 
391
+ 3. **Trained model**: checks target pressure → selects finisher → plans boundary zones → responds after wicket → changes risk level
 
 
 
 
 
 
 
392
 
393
+ 4. **Show metrics**: parse errors ↓, coherence ↑, adaptation ↑, opponent_awareness ↑, reward ↑
394
 
395
+ > The model learned to captain, not just emit a valid tool-call object.
docs/experiment_workflow.md CHANGED
@@ -1,12 +1,12 @@
1
  # Experiment Workflow: Baselines, Opponents, Short Runs, and Training
2
 
3
- This document explains how to run CricketCaptain experiments in a practical order: first short 2-over smoke checks, then 5-over untrained model baselines, then GRPO training, then longer 20-over evaluation.
4
 
5
- ## 1. Why Start With 2-Over Smoke + 5-Over Baselines?
6
 
7
- A full T20 innings is 20 overs. That is useful for final evaluation, but it is slow and noisy for debugging.
8
 
9
- For early code-path experiments, 2-over smoke runs are better because they quickly answer:
10
 
11
  - Is the OpenEnv server working?
12
  - Is the client connecting correctly?
@@ -19,461 +19,267 @@ For early code-path experiments, 2-over smoke runs are better because they quick
19
  The workflow should be:
20
 
21
  ```text
22
- 2-over smoke -> 5-over untrained baseline -> short training -> 5-over trained eval -> 20-over final eval
23
  ```
24
 
25
- Do not start with full 20-over training unless the 2-over smoke and 5-over baseline loops are stable.
26
 
27
- ## 2. Current Opponent Modes
28
 
29
- Opponent behavior is controlled by YAML (`configs/default.yaml` / `configs/cached_eval.yaml`), `CRICKET_OPPONENT_MODE`, or `--opponent-mode`.
30
 
31
- The code supports three modes in `server/opponent_policy.py`.
32
 
33
- ## 2.1 Heuristic Opponent
34
 
35
  ```bash
36
- export CRICKET_OPPONENT_MODE=heuristic
37
  ```
38
 
39
- This is the current default.
 
 
 
40
 
41
- It is not a live LLM. It is a deterministic/stochastic cricket policy that chooses sensible batting and bowling plans based on phase, wickets, target, batter style, and field.
 
42
 
43
- Use it for:
44
-
45
- - fast local tests,
46
- - smoke tests,
47
- - cheap training rollouts,
48
- - deterministic-ish baselines.
49
-
50
- Pros:
51
-
52
- - cheap,
53
- - fast,
54
- - no API key,
55
- - stable enough for debugging.
56
-
57
- Cons:
58
-
59
- - less realistic than an LLM opponent,
60
- - less diverse.
61
-
62
- ## 2.2 Live LLM Opponent
63
 
64
  ```bash
65
- export CRICKET_OPPONENT_MODE=llm_live
66
- export CRICKET_OPPONENT_MODEL=google/gemma-4-26B-A4B-it
67
- export CRICKET_OPPONENT_API_BASE=https://router.huggingface.co/v1
68
- export HF_TOKEN=...
69
  ```
70
 
71
- The default live opponent model in code/config is:
72
-
73
- ```text
74
- google/gemma-4-26B-A4B-it
75
- ```
76
-
77
- This mode calls an OpenAI-compatible API from `LLMOpponentPolicy`.
78
-
79
- Use it for:
80
 
81
- - demos,
82
- - more realistic opponent behavior,
83
- - generating cached opponent decisions,
84
- - future self-play-style experiments.
85
 
86
- Pros:
 
87
 
88
- - more realistic strategic behavior,
89
- - can react with natural tactical reasoning,
90
- - good for storytelling/demo.
91
-
92
- Cons:
93
-
94
- - costs API calls,
95
- - can be non-deterministic,
96
- - not ideal for official evaluation directly.
97
-
98
- ## 2.3 Cached LLM Opponent
99
 
100
  ```bash
101
- export CRICKET_OPPONENT_MODE=llm_cached
102
- export CRICKET_OPPONENT_CACHE=data/opponent_cache/adaptive_t20_v1_official_gemma2b.jsonl
103
  ```
104
 
105
- This mode replays pre-recorded opponent decisions. It does **not** call `CRICKET_OPPONENT_MODEL` live during the run; `cache_path` is the source of opponent behavior.
106
 
107
- Use it for:
108
 
109
- - official/fair evaluation,
110
- - leaderboard-style comparison,
111
- - reproducible experiments.
112
 
113
- Recommended official flow:
114
-
115
- ```text
116
- llm_live once -> save opponent decisions -> llm_cached for all model comparisons
117
- ```
118
-
119
- This gives the benefit of an LLM opponent while ensuring every model faces the same opponent decisions.
120
-
121
- ## 3. What Model Is The Opposite Team?
122
-
123
- Currently, the default live opponent is:
124
-
125
- ```text
126
- google/gemma-4-26B-A4B-it via https://router.huggingface.co/v1
127
  ```
128
 
129
- Important distinction:
130
-
131
- ```text
132
- llm_live -> calls the configured model during the run
133
- llm_cached -> ignores live model calls and replays cache_path
134
- heuristic -> uses local rule-based cricket policy
135
- ```
136
 
137
- This can be changed with:
138
 
139
  ```bash
140
- export CRICKET_OPPONENT_MODE=llm_live
141
- export CRICKET_OPPONENT_MODEL=<model-name>
142
- ```
143
-
144
- For example:
145
 
146
- ```bash
147
- export CRICKET_OPPONENT_MODEL=google/gemma-4-26B-A4B-it
148
  ```
149
 
150
- or with another OpenAI-compatible server:
151
 
 
152
  ```bash
153
- export CRICKET_OPPONENT_API_BASE=http://localhost:8080/v1
154
- export CRICKET_OPPONENT_MODEL=<local-model-name>
155
  ```
156
 
157
- ## 4. Baseline First: Random Agent
158
-
159
- Before testing any trained model, run the random baseline.
160
-
161
- Start the server:
162
-
163
  ```bash
164
- PYTHONPATH=. python server/app.py
165
  ```
166
 
167
- Run baseline:
 
 
168
 
169
  ```bash
170
- PYTHONPATH=. python inference.py \
171
- --model random \
172
- --episodes 5 \
173
- --env-url "$CRICKET_CAPTAIN_ENV_URL" \
174
- --opponent-mode heuristic \
175
- --eval-pack-id adaptive_t20_v1
176
  ```
177
 
178
- Track:
179
 
180
- - `total_reward`
181
- - `mean_coherence`
182
- - `adaptation`
183
- - `parse_error_rate`
184
- - score/wickets
185
 
186
- This tells us whether the environment works and gives a baseline to beat.
187
 
188
- ## 5. Untrained LLM Baseline
189
-
190
- Next, evaluate a base model without training.
191
-
192
- Example:
193
 
194
  ```bash
195
- PYTHONPATH=. python inference.py \
196
- --model Qwen/Qwen2.5-7B-Instruct \
197
- --episodes 5 \
198
- --env-url "$CRICKET_CAPTAIN_ENV_URL" \
199
- --opponent-mode heuristic \
200
- --eval-pack-id adaptive_t20_v1
201
  ```
202
 
203
- This shows what a general instruction model can do before any RL.
204
-
205
- Expected weaknesses:
206
-
207
- - may output invalid JSON,
208
- - may choose wrong tools,
209
- - may ignore opponent plan,
210
- - may be verbose instead of tool-only,
211
- - may not adapt after bad outcomes.
212
-
213
- ## 6. Why SFT Exists If We Use GRPO
214
-
215
- SFT is not the main training objective. It is a warmup.
216
-
217
- GRPO should optimize strategic behavior, but if the model cannot produce valid tool JSON, GRPO wastes rollouts learning syntax.
218
-
219
- SFT helps the model learn:
220
-
221
- - valid JSON shape,
222
- - available tools,
223
- - argument schemas,
224
- - when tools are legal,
225
- - one-tool-call responses.
226
-
227
- Then GRPO can focus on:
228
-
229
- - coherence,
230
- - adaptation,
231
- - opponent awareness,
232
- - match result quality.
233
 
234
- Recommended stack:
235
 
236
- ```text
237
- SFT/tool warmup -> GRPO stage 1 -> GRPO stage 2 -> eval
238
  ```
239
 
240
- ## 7. Stage 0: Generate SFT Tool Data
241
 
242
  ```bash
243
- python train.py sft-data \
244
- --output data/training/tool_sft_examples.jsonl \
245
- --examples 500
 
 
 
 
 
246
  ```
247
 
248
- This creates supervised examples for:
249
-
250
- - toss,
251
- - batting tools,
252
- - bowling tools,
253
- - reflection,
254
- - analysis.
255
-
256
- These examples are useful for quick tool-format finetuning.
257
 
258
- ## 8. Stage 1: GRPO Format / Tool Training
259
 
260
  ```bash
261
- python train.py train \
262
- --stage 1 \
263
- --steps 100 \
264
- --prompts 200 \
265
- --model Qwen/Qwen2.5-7B-Instruct
266
  ```
267
 
268
- Goal:
269
-
270
- - reduce parse errors,
271
- - make tool calls valid,
272
- - stabilize action format.
273
-
274
- Metrics to watch:
275
 
276
- - format reward,
277
- - parse error rate,
278
- - invalid tool rate.
279
 
280
- ## 9. Stage 2: GRPO Strategic Training
 
 
281
 
282
  ```bash
283
- python train.py train \
284
- --stage 2 \
285
- --steps 200 \
286
- --prompts 300 \
287
- --model ./checkpoints/stage1_final
288
  ```
289
 
290
- Goal:
 
291
 
292
- - improve plan-action coherence,
293
- - improve adaptation,
294
- - improve opponent awareness,
295
- - improve tool efficiency,
296
- - improve cricket result quality.
297
 
298
- Metrics to watch:
299
-
300
- - total reward,
301
- - coherence,
302
- - adaptation,
303
- - opponent awareness,
304
- - regret-style score,
305
- - score/wickets,
306
- - chase/defense success.
307
-
308
- ## 10. 5-Over vs 20-Over Evaluation
309
-
310
- ### 5-Over Evaluation
311
-
312
- Use for:
313
-
314
- - debugging,
315
- - model sanity checks,
316
- - comparing before/after quickly,
317
- - cheap experiments.
318
-
319
- Both `inference.py` and `eval.py` support `--max-overs`, and the YAML configs set `max_overs: 5` by default for quick iteration.
320
-
321
- Random captain sanity check:
322
 
323
  ```bash
324
- PYTHONPATH=. python inference.py \
325
- --model random \
326
- --episodes 5 \
327
- --max-overs 5 \
328
- --env-url "$CRICKET_CAPTAIN_ENV_URL" \
329
- --eval-pack-id adaptive_t20_v1 \
330
- --opponent-mode llm_cached
331
  ```
332
 
333
- HF Gemma captain with live HF inference:
 
334
 
335
- ```bash
336
- export HF_TOKEN="hf_..."
337
- PYTHONPATH=. python inference.py \
338
- --model google/gemma-4-26B-A4B-it \
339
- --api-base https://router.huggingface.co/v1 \
340
- --api-key "$HF_TOKEN" \
341
- --episodes 1 \
342
- --max-overs 5 \
343
- --env-url "$CRICKET_CAPTAIN_ENV_URL" \
344
- --eval-pack-id adaptive_t20_v1 \
345
- --opponent-mode llm_cached
346
  ```
 
 
347
 
348
- ### 20-Over Evaluation
 
 
349
 
350
- Use for:
 
351
 
352
- - final benchmark,
353
- - README numbers,
354
- - competition evidence,
355
- - trained-vs-baseline comparison.
356
 
357
- Command:
358
 
359
  ```bash
360
- PYTHONPATH=. python eval.py \
361
  --episodes 20 \
362
- --env-url "$CRICKET_CAPTAIN_ENV_URL" \
363
- --eval-pack-id adaptive_t20_v1
364
- ```
365
-
366
- ## 11. Evaluation Pack
367
-
368
- The main adaptive pack is:
369
-
370
- ```text
371
- data/eval_packs/adaptive_t20_v1.json
372
  ```
373
 
374
- It contains:
375
 
376
- - 5 dev scenarios,
377
- - 60 official scenarios,
378
- - chase states,
379
- - defend states,
380
- - death-over states,
381
- - collapse states,
382
- - matchup states.
383
 
384
- Use dev scenarios for local iteration and official scenarios for final comparison.
 
 
 
 
 
385
 
386
- ## 12. Recommended Experiment Order
387
 
388
- Use this sequence:
389
 
390
- ```text
391
- 1. Random baseline, 5-over
392
- 2. Base LLM baseline, 5-over
393
- 3. Training-side rollout smoke, 1 match / 5 overs
394
- 4. SFT warmup
395
- 5. GRPO stage 1, short
396
- 6. GRPO stage 2, short
397
- 7. Trained eval, 5-over
398
- 8. Trained eval, 20-over
399
- 9. Cached LLM opponent official eval
400
- 10. Add plots and before/after examples to README
401
- ```
402
 
403
- Training-side smoke command:
404
 
405
- ```bash
406
- python train.py train-smoke \
407
- --matches 1 \
408
- --max-overs 2 \
409
- --max-steps 240 \
410
- --log-steps 90 \
411
- --eval-pack-id adaptive_t20_v1 \
412
- --opponent-mode heuristic
413
- ```
414
 
415
- This does not load a model or run GRPO. It verifies the rollout/prompt/reward path before spending GPU time. Smoke logs include timing fields (`t_elapsed`, `step_dt`, `since_prev`, `match_elapsed`, `avg_step_dt`) for latency analysis.
416
 
417
- ## 13. What To Show In The Final Submission
418
 
419
- For the OpenEnv competition, the strongest evidence is:
420
 
421
- | Evidence | Why It Matters |
422
- |---|---|
423
- | Random baseline | Shows the environment is non-trivial |
424
- | Base LLM baseline | Shows general LLM behavior before training |
425
- | Trained model metrics | Shows improvement |
426
- | Reward curve | Shows learning progress |
427
- | Parse error curve | Shows tool-use improvement |
428
- | Before/after examples | Makes the story clear |
429
- | Eval against cached opponent | Shows fairness/reproducibility |
430
 
431
- Minimum final numbers to report:
 
 
432
 
433
- - total reward,
434
- - parse error rate,
435
- - coherence,
436
- - adaptation,
437
- - opponent awareness,
438
- - score/wickets,
439
- - chase or defense success rate.
440
 
441
- ## 14. Latest Smoke-Test Evidence
442
 
443
- The current reproducible run artifacts live under `illustrations/`.
444
 
445
- ```text
446
- Random captain + cached LLM opponent:
447
- mean score: 13.5 across 2 episodes
448
- mean reward: 0.984
449
- mean coherence: 0.555
450
- parse errors: 0.0%
451
-
452
- HF Gemma 4 captain + cached LLM opponent:
453
- model: google/gemma-4-26B-A4B-it
454
- trace: 40 OpenEnv turns with reset/step/action logs
455
- score: 7/0 after 2.2 overs
456
- reward sum: 0.168
457
- coherence: 0.657
458
- adaptation: 0.502
459
- opponent awareness: 0.750
460
- parse errors: 0.0%
461
-
462
- Latest captured training-side smoke, 1 match / 5 overs:
463
- first innings: opponent 30/6, target 31
464
- first-innings reward: +0.170 from par/run-rate/wicket context
465
- chase: 26/1 in 5 overs
466
- match result: loss
467
- terminal reward: 0.634 (r_cric=0.759, r_dream11=1.317, r_strategy=0.536)
468
- tactical events: deep-cover save, edge/catch chances, no-ball, misfield, caught-in-zone
469
  ```
470
 
471
- These are smoke checks, not final leaderboard numbers. They demonstrate that OpenEnv websocket interaction, HF router inference, tool-call parsing, cached opponent replay, observation updates, opponent plans, target/run-rate context, field-aware tactical outcomes, rule-gated bowler/batter changes, timing instrumentation, and reward metrics are all functioning.
472
-
473
- ## 15. Immediate Next Engineering Improvement
474
-
475
- Next useful work:
476
 
477
- - Generate a fresh cached-opponent file using `google/gemma-4-26B-A4B-it` in `llm_live` mode.
478
- - Run a 5-over base-model comparison across random, Gemma 4, and one trained checkpoint.
479
- - Move the strongest setup to 20-over evaluation.
 
 
1
  # Experiment Workflow: Baselines, Opponents, Short Runs, and Training
2
 
3
+ This document explains how to run CricketCaptain experiments in a practical order: smoke checks, 5-over baselines, training, then longer evaluation.
4
 
5
+ ## 1. Why Start With 5-Over Smoke + Baselines?
6
 
7
+ A full T20 innings is 20 overs. That is useful for final evaluation but slow for debugging.
8
 
9
+ For early experiments, 5-over runs are better because they quickly answer:
10
 
11
  - Is the OpenEnv server working?
12
  - Is the client connecting correctly?
 
19
  The workflow should be:
20
 
21
  ```text
22
+ 5-over smoke → 5-over untrained baseline → short training → 5-over trained eval → 20-over final eval
23
  ```
24
 
25
+ Do not start with full 20-over training unless the 5-over smoke loop is stable.
26
 
27
+ ## 2. Opponent Modes
28
 
29
+ Four modes in `server/opponent_policy.py`. Controlled via `--opponent-mode`, `CRICKET_OPPONENT_MODE`, or `configs/default.yaml`.
30
 
31
+ **Default is `llm_live`** in `configs/default.yaml` so training can face a real LLM opponent when credentials are present. For cheap/local checks, explicitly pass `--opponent-mode heuristic`.
32
 
33
+ ### 2.1 Heuristic Opponent
34
 
35
  ```bash
36
+ --opponent-mode heuristic
37
  ```
38
 
39
+ Format-aware local policy. Uses T5/T20/ODI rules from `server/format_mapper.py`:
40
+ - Phase-weighted shot distributions (powerplay/middle/death per format)
41
+ - Wicket-pressure shift (heavier weight toward defensive shots when 7+ down)
42
+ - Batter/bowler roles selected from `data/format_rules.json`
43
 
44
+ Use for: fast local tests, cheap training rollouts, deterministic-ish baselines.
45
+ **No API key needed.**
46
 
47
+ ### 2.2 Cricsheet Opponent
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
  ```bash
50
+ --opponent-mode cricsheet
 
 
 
51
  ```
52
 
53
+ Samples real Cricsheet ball-by-ball deliveries, indexed by `(phase, wickets_band, innings_type)`.
 
 
 
 
 
 
 
 
54
 
55
+ Auto-selects data by format:
56
+ - `max_overs ≤ 25` → `ball_outcomes_t20_v1.pkl` (1.17M T20 deliveries)
57
+ - `max_overs > 25` → `ball_outcomes_odi_v1.pkl` (1.65M ODI deliveries)
 
58
 
59
+ Progressive fallback: drop innings_type → drop wickets_band → any phase record.
60
+ **No API key needed.** Data files must be present under `data/processed/`.
61
 
62
+ ### 2.3 Live LLM Opponent
 
 
 
 
 
 
 
 
 
 
63
 
64
  ```bash
65
+ export HF_TOKEN=hf_...
66
+ --opponent-mode llm_live
67
  ```
68
 
69
+ Calls `google/gemma-4-26B-A4B-it` via HF Router (or any OpenAI-compatible endpoint). Set `HF_TOKEN` or pass `--opponent-api-key`; otherwise use `--opponent-mode heuristic` for local runs.
70
 
71
+ Use for: demos, realistic opponent behavior, self-play experiments.
72
 
73
+ ### 2.4 Cached LLM Opponent
 
 
74
 
75
+ ```bash
76
+ --opponent-mode llm_cached
77
+ export CRICKET_OPPONENT_CACHE=data/opponent_cache/adaptive_t20_v1.jsonl
 
 
 
 
 
 
 
 
 
 
 
78
  ```
79
 
80
+ Replays pre-recorded decisions. Does **not** call any live model. Use for official/reproducible eval — every compared captain faces identical opponent decisions.
 
 
 
 
 
 
81
 
82
+ ## 3. Starting The Server
83
 
84
  ```bash
85
+ # Recommended (uvicorn auto-reload)
86
+ cd cricket_captain
87
+ python -m uvicorn server.app:app --port 8001
 
 
88
 
89
+ # Or via app.py directly
90
+ PYTHONPATH=. python server/app.py --port 8001
91
  ```
92
 
93
+ Health check: `curl http://localhost:8001/health` → `{"status":"healthy"}`
94
 
95
+ Set the URL for runners:
96
  ```bash
97
+ export CRICKET_CAPTAIN_ENV_URL=http://localhost:8001
 
98
  ```
99
 
100
+ On Lightning / remote runtimes, expose the port and pass the external URL:
 
 
 
 
 
101
  ```bash
102
+ export CRICKET_CAPTAIN_ENV_URL=ws://<lightning-exposed-host>/ws
103
  ```
104
 
105
+ ## 4. Step 1: Random Baseline (heuristic, cricsheet, llm_live)
106
+
107
+ No GPU needed; heuristic and cricsheet need no API key (llm_live needs `HF_TOKEN`). Verify the full loop works.
108
 
109
  ```bash
110
+ # Run all 3 modes in parallel
111
+ python inference.py --model random --episodes 5 --max-overs 5 --opponent-mode heuristic --env-url http://localhost:8001
112
+ python inference.py --model random --episodes 5 --max-overs 5 --opponent-mode cricsheet --env-url http://localhost:8001
113
+ python inference.py --model random --episodes 5 --max-overs 5 --opponent-mode llm_live --env-url http://localhost:8001
 
 
114
  ```
115
 
116
+ **Verified baselines (2026-04-25):**
117
 
118
+ | Opponent | score | coherence | reward | parse_err |
119
+ |----------|-------|-----------|--------|-----------|
120
+ | heuristic | 20.8 | 0.556 | −0.826 | 0% |
121
+ | cricsheet | 28.0 | 0.527 | −0.410 | 0% |
122
+ | llm_live | 27.4 | 0.537 | −0.723 | 0% |
123
 
124
+ ## 5. Step 2: Train-Smoke (verify reward signals, no GPU)
125
 
126
+ `train.py train-smoke` runs direct `CricketEnvironment` rollouts — **no server needed**.
 
 
 
 
127
 
128
  ```bash
129
+ python train.py train-smoke --matches 3 --max-overs 5 --opponent-mode heuristic
130
+ python train.py train-smoke --matches 3 --max-overs 5 --opponent-mode cricsheet
131
+ python train.py train-smoke --matches 3 --max-overs 5 --opponent-mode llm_live
 
 
 
132
  ```
133
 
134
+ **Verified train-smoke baselines (2026-04-25):** r_validity=1.0 on all 9 matches (3 modes × 3 matches). All reward signals active: coherence, adaptation, opponent_awareness, plan_commitment, staleness, regret.
 
135
 
136
+ Quick 2-over smoke for CI-style checks:
137
 
138
+ ```bash
139
+ python train.py train-smoke --matches 1 --max-overs 2 --max-steps 240 --log-steps 90 --opponent-mode heuristic
140
  ```
141
 
142
+ ## 6. Step 3: Untrained LLM Baseline (requires HF token)
143
 
144
  ```bash
145
+ export HF_TOKEN=hf_...
146
+ python inference.py \
147
+ --model google/gemma-4-26B-A4B-it \
148
+ --api-base https://router.huggingface.co/v1 \
149
+ --api-key "$HF_TOKEN" \
150
+ --episodes 3 --max-overs 5 \
151
+ --opponent-mode llm_live \
152
+ --env-url http://localhost:8001
153
  ```
154
 
155
+ **Verified LLM captain run (2026-04-25):**
156
+ ```
157
+ model: google/gemma-4-26B-A4B-it via HF Router
158
+ coherence: 0.657 | adaptation: 0.502 | opp_aware: 0.750
159
+ parse errors: 0.0% | reward: +0.168
160
+ ```
 
 
 
161
 
162
+ ## 7. Step 4: SFT Tool Warmup
163
 
164
  ```bash
165
+ python train.py sft-data --output data/training/tool_sft_examples.jsonl --examples 500
 
 
 
 
166
  ```
167
 
168
+ Teaches tool-call shape, tool names, and argument schemas before RL. Not the main objective — just reduces wasted GRPO rollouts on syntax/tool-selection errors.
 
 
 
 
 
 
169
 
170
+ ## 8. Step 5: GRPO Warmup (5-over format)
 
 
171
 
172
+ The two-stage curriculum was collapsed into a single-stage full-reward run because Qwen3.5-4B
173
+ already supports tool calling natively (XML+JSON both accepted). What used to be "Stage 1" is
174
+ now a fast 5-over warmup, controlled entirely from YAML:
175
 
176
  ```bash
177
+ PYTORCH_ALLOC_CONF=expandable_segments:True \
178
+ python train.py train --config configs/cricket_train_warmup.yaml
 
 
 
179
  ```
180
 
181
+ Config: `max_overs=5`, `steps=25`, `num_generations=8`, `batch_size=8`, `max_completion_length=3072`,
182
+ `save_steps=5`, full composite reward from step 0. Approx 1.5–2 hrs on H200.
183
 
184
+ Goal: bootstrap the LoRA adapter on cheap short matches before the longer 20-over run.
185
+ Watch in WandB: `reward/composite_mean`, `tools/freq_*`, `rollout/match_completion_rate`,
186
+ `completions/clipped_ratio`.
 
 
187
 
188
+ ## 9. Step 6: GRPO Main Run (20-over T20)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
 
190
  ```bash
191
+ PYTORCH_ALLOC_CONF=expandable_segments:True \
192
+ python train.py train --config configs/cricket_train.yaml
 
 
 
 
 
193
  ```
194
 
195
+ Config: `max_overs=20`, `steps=100`, `num_generations=4`, `batch_size=8`, `max_completion_length=3072`,
196
+ `save_steps=10`, `beta=0.0` (no reference model). Approx 15-18 min/step (~30-40 steps fit a 10-hr budget).
197
 
198
+ To **resume from the warmup adapter**, uncomment in `configs/cricket_train.yaml`:
199
+ ```yaml
200
+ resume_from: ./checkpoints/stage2_final
 
 
 
 
 
 
 
 
201
  ```
202
+ or pass `--resume-from ./checkpoints/stage2_final` on the CLI. The base model still loads from
203
+ `Qwen/Qwen3.5-4B`; only the LoRA weights resume.
204
 
205
+ Goal: improve coherence, adaptation, opponent awareness, match outcomes on full T20s.
206
+ Watch in WandB: `reward/r_result_mean` (sparse outcome), `reward/r_coherence_mean`,
207
+ `reward/r_adaptation_mean`, `episode/tool_calls_mean` (should approach 720), `episode/agent_score_mean`.
208
 
209
+ Switch the opponent in YAML (`opponent.mode: llm_live`) and set `HF_TOKEN` for adversarial
210
+ training against live Gemma. Use `cricsheet` or `heuristic` for cheaper ablations.
211
 
212
+ `train.py train` uses TRL `GRPOTrainer(environment_factory=CricketCaptainToolEnv)`, so the model interacts with a live environment over multiple tool-calling turns. This is not inference through the HF Router; the trained captain model is loaded locally by Transformers/TRL, with LoRA adapters when using quantized weights.
 
 
 
213
 
214
+ ## 10. Step 7: Evaluation
215
 
216
  ```bash
217
+ python eval.py \
218
  --episodes 20 \
219
+ --env-url http://localhost:8001 \
220
+ --eval-pack-id adaptive_t20_v1 \
221
+ --opponent-mode llm_cached
 
 
 
 
 
 
 
222
  ```
223
 
224
+ Compare: random baseline → untrained Qwen3.5-4B → trained Qwen3.5-4B (warmup + main adapter via `compare_eval.py`).
225
 
226
+ ## 11. Format Comparison
 
 
 
 
 
 
227
 
228
+ | max_overs | Format selected | Data used by cricsheet | Typical target |
229
+ |-----------|----------------|------------------------|---------------|
230
+ | 5 | T5 | T20 pkl (closest) | ~47 runs |
231
+ | 7 | T5 | T20 pkl | ~66 runs |
232
+ | 20 | T20 | T20 pkl | ~160 runs |
233
+ | 50 | ODI | ODI pkl | ~290 runs |
234
 
235
+ All formats work with all opponent modes. Use `--max-overs N` with any runner.
236
 
237
+ ## 12. Tool budget and training
238
 
239
+ Implemented in `server/cricket_environment.py` (`TOOL_BUDGET_PER_OVER=3`, `TOOL_FINE_PER_EXCESS=0.04`).
 
 
 
 
 
 
 
 
 
 
 
240
 
241
+ **Overhead tools (only these increment the per-over counter):** `set_strategy`, `set_bowling_strategy`, `plan_delivery`, `reflect_after_ball`, `analyze_situation`.
242
 
243
+ **Not overhead:** `plan_shot`, `set_match_plan`, `update_match_plan`, `select_batter`, `choose_bowler`, `set_field_setting`, `play_delivery`, `bowl_delivery`, `call_toss`, and other tools that advance or directly execute the ball.
 
 
 
 
 
 
 
 
244
 
245
+ **Rule:** the first 3 overhead calls in each over are not fined; each further overhead call in that over adds **−0.04** to the step reward. The prompt includes `Tool budget: N/3 overhead calls used this over`.
246
 
247
+ **Why this matters for GRPO:** training uses the same environment as inference. Fines are part of the reward the trainer optimizes, so the policy learns to use reflection and `analyze_situation` when they matter, and to lean on `plan_shot` plus match-level planning (`set_match_plan` / `update_match_plan`) for routine structure without spending the 3 free overhead calls every ball.
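+ A sketch of the budget bookkeeping described above (the constants follow the text; the counter handling shown here is illustrative):

+ ```python
+ TOOL_BUDGET_PER_OVER = 3
+ TOOL_FINE_PER_EXCESS = 0.04
+ 
+ OVERHEAD_TOOLS = {"set_strategy", "set_bowling_strategy", "plan_delivery",
+                   "reflect_after_ball", "analyze_situation"}
+ 
+ def overhead_fine(tool: str, overhead_used_this_over: int) -> float:
+     """Step-reward penalty for one tool call (0.0 if not overhead or still within budget)."""
+     if tool not in OVERHEAD_TOOLS:
+         return 0.0
+     if overhead_used_this_over < TOOL_BUDGET_PER_OVER:
+         return 0.0                  # still inside the 3 free overhead calls for this over
+     return -TOOL_FINE_PER_EXCESS    # every further overhead call this over costs 0.04
+ 
+ # The environment increments the counter on each overhead call and resets it at the start
+ # of every over; ball-advancing tools like play_delivery never touch it.
+ ```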
248
 
249
+ ## 13. Configs
250
 
251
+ ```bash
252
+ # Start server with default config (llm_live opponent, 5-over default)
253
+ PYTHONPATH=. python server/app.py --port 8001 --config configs/default.yaml
 
 
 
 
 
 
254
 
255
+ # Start with reproducible cached eval
256
+ PYTHONPATH=. python server/app.py --port 8001 --config configs/cached_eval.yaml
257
+ ```
258
 
259
+ Config controls: `env.agent_team`, `env.max_overs`, `env.eval_pack_id`, `train.model`, `train.max_completion_length`, `train.max_tool_calling_iterations`, `opponent.mode`, `opponent.model`, `opponent.api_base`, `captain.model`, and `captain.api_base`.
 
 
 
 
 
 
260
 
261
+ ## 14. Latest Verified Run Results (2026-04-25)
262
 
263
+ All runs in `illustrations/`. Zero parse errors across all 14 inference runs. r_validity=1.0 across all train-smoke matches.
264
 
265
+ ```
266
+ Random agent — 5-over, heuristic: score=20.8 coherence=0.556 reward=−0.826
267
+ Random agent — 5-over, cricsheet: score=28.0 coherence=0.527 reward=−0.410
268
+ Random agent — 5-over, llm_live: score=27.4 coherence=0.537 reward=−0.723
269
+ Random agent — 20-over, cricsheet: score=63.6 coherence=0.568 reward=−5.632
270
+ Random agent — 20-over, heuristic: score=82.4 coherence=0.545 reward=−8.174
271
+
272
+ Train-smoke — 5-over, heuristic: r_validity=1.0 coherence=0.596 3W/0L
273
+ Train-smoke — 5-over, cricsheet: r_validity=1.0 coherence=0.620 2W/1L
274
+ Train-smoke — 5-over, llm_live: r_validity=1.0 coherence=0.552 2W/1L
275
+
276
+ LLM captain (gemma-4-26B) — 3-over, llm_live:
277
+ coherence=0.657 adaptation=0.502 opp_aware=0.750 parse_err=0%
 
 
 
 
 
 
 
 
 
 
 
278
  ```
279
 
280
+ ## 15. Immediate Next Steps
 
 
 
 
281
 
282
+ 1. **Run GRPO training** via `bash scripts/run_warmup_then_main.sh` (warmup curriculum + main run + auto-resume) to produce reward curves.
283
+ 2. **Deploy HF Space** for live Gradio demo — Dockerfile present, just needs HF push.
284
+ 3. **Generate opponent cache** using `llm_live` for reproducible official eval.
285
+ 4. **Produce training plots** — coherence heatmap, reward curve, tool-usage timeline.
docs/slides.html CHANGED
@@ -350,35 +350,36 @@
350
  <tr>
351
  <th>Rubric</th><th>Weight</th><th>Frequency</th><th>Measures</th><th>Key Sub-signals</th>
352
  </tr>
353
- <tr>
354
- <td><code>r_result</code></td>
355
- <td><strong>55%</strong></td>
356
- <td>Episode end</td>
357
- <td>Win/loss vs DLS par, target margin</td>
358
- <td>score/par, wickets_remaining, lead/deficit</td>
359
- </tr>
360
  <tr>
361
  <td><code>r_cricket</code></td>
362
- <td><strong>25%</strong></td>
363
- <td>Innings end</td>
364
  <td>Dream11 proxy: runs, wickets, milestones</td>
365
- <td>dot%, boundary%, 50s/100s, maiden overs</td>
366
  </tr>
367
  <tr>
368
  <td><code>r_behavior</code></td>
369
- <td><strong>15%</strong></td>
370
- <td>Every delivery</td>
371
  <td>Declaration-execution alignment</td>
372
  <td>coherence (50%) + adaptation (20%) + opponent_awareness (20%) + regret (10%)</td>
373
  </tr>
 
 
 
 
 
 
 
374
  <tr>
375
  <td><code>r_validity</code></td>
376
- <td><strong>5%</strong></td>
377
  <td>Every turn</td>
378
- <td>Parseable JSON tool call</td>
379
  <td>Format gate; 0 = parse fail, 1 = valid</td>
380
  </tr>
381
  </table>
 
382
  <div class="two-col" style="margin-top:18px;">
383
  <div>
384
  <h3>Coherence Score Formula (per delivery)</h3>
@@ -389,12 +390,12 @@
389
  )</pre>
390
  </div>
391
  <div>
392
- <h3>Two-Stage Curriculum (ToolRL)</h3>
393
  <ul>
394
- <li><strong>Stage 1:</strong> <code>r_validity</code> only teaches JSON format fast</li>
395
- <li><strong>Stage 2:</strong> all 4 rubrics teaches strategy and coherence</li>
396
- <li>Non-zero floor (0.05–0.15) for valid structural callsprevents dead gradient</li>
397
- <li>GRPO group size = 8; per-turn advantage estimation (MT-GRPO)</li>
398
  </ul>
399
  </div>
400
  </div>
@@ -445,15 +446,19 @@ python inference.py \
445
  --config configs/default.yaml \
446
  --max-overs 3 --opponent-mode llm_live
447
 
448
- <span class="dim"># 4. Stage 1 format mastery</span>
449
- python train.py train \
450
- --config configs/default.yaml \
451
- --stage 1 --steps 200
452
-
453
- <span class="dim"># 5. Stage 2 strategic coherence</span>
454
- python train.py train \
455
- --config configs/default.yaml \
456
- --stage 2 --steps 600</pre>
 
 
 
 
457
  <div class="wn" style="font-size:0.84rem;">
458
  All model / API / env settings live in <code>configs/default.yaml</code>. Zero hardcoding.
459
  </div>
@@ -489,8 +494,8 @@ python train.py train \
489
  <div>
490
  <h3>What training should produce (target)</h3>
491
  <ul>
492
- <li>r_validity: 0.70 → 0.98+ after Stage 1 (50 steps)</li>
493
- <li>Coherence: ~0.52 (random) → 0.75+ after Stage 2</li>
494
  <li>analyze_situation calls cluster at over 6, 16, 36 transitions</li>
495
  <li>Strategy declarations become more specific (&gt;15 word rationales)</li>
496
  <li>Shot choices match declared aggression level &gt;80% of deliveries</li>
@@ -570,7 +575,7 @@ python train.py train \
570
  <div>
571
  <h3>🔴 Critical Path (on-site, Day 1–2)</h3>
572
  <ul>
573
- <li>Run Colab notebook on HF compute credits Stage 1 then Stage 2 training</li>
574
  <li>Export: reward_curves.png, coherence_heatmap.png, tool_timeline.png</li>
575
  <li>Deploy to HuggingFace Spaces → live interactive Gradio demo URL</li>
576
  <li>Add HF Space URL + plot images to README</li>
 
350
  <tr>
351
  <th>Rubric</th><th>Weight</th><th>Frequency</th><th>Measures</th><th>Key Sub-signals</th>
352
  </tr>
 
 
 
 
 
 
 
353
  <tr>
354
  <td><code>r_cricket</code></td>
355
+ <td><strong>45%</strong></td>
356
+ <td>Per ball</td>
357
  <td>Dream11 proxy: runs, wickets, milestones</td>
358
+ <td>dot%, boundary%, 50s/100s, maiden overs, economy</td>
359
  </tr>
360
  <tr>
361
  <td><code>r_behavior</code></td>
362
+ <td><strong>25%</strong></td>
363
+ <td>Every turn</td>
364
  <td>Declaration-execution alignment</td>
365
  <td>coherence (50%) + adaptation (20%) + opponent_awareness (20%) + regret (10%)</td>
366
  </tr>
367
+ <tr>
368
+ <td><code>r_result</code></td>
369
+ <td><strong>20%</strong></td>
370
+ <td>Innings/episode end</td>
371
+ <td>Win/loss vs DLS par, target margin</td>
372
+ <td>score/par, wickets_remaining, lead/deficit, +0.25 progress bonus</td>
373
+ </tr>
374
  <tr>
375
  <td><code>r_validity</code></td>
376
+ <td><strong>10%</strong></td>
377
  <td>Every turn</td>
378
+ <td>Parseable XML/JSON tool call</td>
379
  <td>Format gate; 0 = parse fail, 1 = valid</td>
380
  </tr>
381
  </table>
382
+ <p style="margin-top:8px;font-size:0.9em;color:#888">Rebalanced from 55/25/15/5 → 45/25/20/10 to match the SWE-RL recipe (60% intermediate / 40% terminal). Reasoning: partial-trajectory training rarely fires <code>r_result</code>; weighting it 55% wastes gradient on a near-zero signal.</p>
383
  <div class="two-col" style="margin-top:18px;">
384
  <div>
385
  <h3>Coherence Score Formula (per delivery)</h3>
 
390
  )</pre>
391
  </div>
392
  <div>
393
+ <h3>Single-Stage Training + Format Curriculum</h3>
394
  <ul>
395
+ <li><strong>Warmup (5-over curriculum):</strong> per-scenario <code>max_overs</code> sampled from [2,2,2,2,2,3,3,3,4,4,5] so episodes complete in budget and <code>r_result</code> can fire</li>
396
+ <li><strong>Main run (full T20):</strong> resumes warmup adapter, trains on target eval distribution</li>
397
+ <li>Qwen3.5-4B already does tool calling natively (XML+JSON) — no Stage 1 SFT needed</li>
398
+ <li>GRPO group size = 4; full episode advantages (TRL <code>environment_factory</code>)</li>
399
  </ul>
400
  </div>
401
  </div>
 
446
  --config configs/default.yaml \
447
  --max-overs 3 --opponent-mode llm_live
448
 
449
+ <span class="dim"># 4. Warmup → Main chained run (auto-resumes adapter)</span>
450
+ bash scripts/run_warmup_then_main.sh
451
+
452
+ <span class="dim"># 5. Eval: untrained vs trained head-to-head</span>
453
+ python compare_eval.py --model Qwen/Qwen3.5-4B \
454
+ --label baseline --episodes 20 --max-overs 5 \
455
+ --output eval_results/baseline.json
456
+ python compare_eval.py --model Qwen/Qwen3.5-4B \
457
+ --adapter ./checkpoints/stage2_final \
458
+ --label trained --episodes 20 --max-overs 5 \
459
+ --output eval_results/trained.json
460
+ python compare_eval.py --compare \
461
+ eval_results/baseline.json eval_results/trained.json</pre>
462
  <div class="wn" style="font-size:0.84rem;">
463
  All model / API / env settings live in <code>configs/default.yaml</code>. Zero hardcoding.
464
  </div>
 
494
  <div>
495
  <h3>What training should produce (target)</h3>
496
  <ul>
497
+ <li>r_validity: 0.70 → 0.98+ after warmup (25 steps)</li>
498
+ <li>Coherence: ~0.52 (random) → 0.75+ after main run</li>
499
  <li>analyze_situation calls cluster at over 6, 16, 36 transitions</li>
500
  <li>Strategy declarations become more specific (&gt;15 word rationales)</li>
501
  <li>Shot choices match declared aggression level &gt;80% of deliveries</li>
 
575
  <div>
576
  <h3>🔴 Critical Path (on-site, Day 1–2)</h3>
577
  <ul>
578
+ <li>Run Colab notebook (notebooks/colab_train_minimal.ipynb) — warmup → main chained training</li>
579
  <li>Export: reward_curves.png, coherence_heatmap.png, tool_timeline.png</li>
580
  <li>Deploy to HuggingFace Spaces → live interactive Gradio demo URL</li>
581
  <li>Add HF Space URL + plot images to README</li>
openenv.yaml CHANGED
@@ -6,23 +6,24 @@ app: server.app:app
6
  port: 8000
7
 
8
  description: >
9
- CricketCaptain-LLM trains LLMs to exhibit strategic coherence aligning
10
- declared intentions with executed actions across 300 sequential decisions.
11
- The agent uses 14 tools (toss, match-plan, batting, bowling, reflection)
12
- and is scored on four rubrics: match outcome (55%), cricket contribution (25%),
13
- behavioral coherence (15%), and tool-call validity (5%).
 
 
14
  Two-sided: a live or heuristic LLM opponent plays the opposing team.
15
- Two-stage ToolRL curriculum (format mastery full strategic reward).
 
 
 
16
 
17
  tasks:
18
- - name: stage1_format
19
- description: "5-over mini-match. r_validity only — teaches valid JSON tool-call structure (ToolRL Stage 1)."
20
- difficulty: easy
21
-
22
  - name: stage2_full
23
- description: "20-over match. Full reward: r_result (55%) + r_cricket (25%) + r_behavior (15%) + r_validity (5%)."
24
  difficulty: medium
25
 
26
  - name: eval_50over
27
- description: "Full 50-over ODI. Evaluation benchmark — measures trained coherence vs DLS par."
28
  difficulty: hard
 
6
  port: 8000
7
 
8
  description: >
9
+ CricketCaptain-LLM is a multi-turn agentic RL benchmark: train an LLM to captain
10
+ a cricket match end-to-end, alternating between batting and bowling phases across
11
+ ~180 sequential tool calls. The agent uses 14 phase-gated tools (toss, match plan,
12
+ batting, bowling, fielding, reflection, analysis) and is scored by a composite
13
+ 4-rubric reward: Dream11-style per-ball cricket-position signal (45%), per-turn
14
+ behavioral coherence/adaptation/opponent-awareness/regret (25%), match outcome
15
+ with DLS par + win bonus + progress bonus (20%), and tool-call validity (10%).
16
  Two-sided: a live or heuristic LLM opponent plays the opposing team.
17
+ Real Markov outcome engine trained on 1.65M cricsheet deliveries.
18
+ Single-stage GRPO training with format-length curriculum (T2 → T5 in warmup,
19
+ full T20 in main run). Partial-trajectory training generalizes to full match
20
+ completion at eval time (same recipe as SWE-RL coding agents).
21
 
22
  tasks:
 
 
 
 
23
  - name: stage2_full
24
+ description: "Full T20 match (configurable max_overs 2-50). All 4 rubrics active. Composite reward: r_cricket (45%) + r_behavior (25%) + r_result (20%) + r_validity (10%)."
25
  difficulty: medium
26
 
27
  - name: eval_50over
28
+ description: "Full 50-over ODI. Evaluation benchmark — measures trained captaincy across the longest format (DLS par + chase pressure)."
29
  difficulty: hard
pyproject.toml CHANGED
@@ -15,15 +15,26 @@ dependencies = [
15
  "openai>=1.0.0",
16
  ]
17
 
 
 
 
 
 
 
 
18
  [project.optional-dependencies]
19
  train = [
20
- "trl>=0.24.0",
21
- "transformers>=4.50.0",
 
 
 
22
  "accelerate>=1.0.0",
23
  "datasets>=4.0.0",
24
  "bitsandbytes>=0.43.0",
25
- "mergekit>=0.1.0",
26
- "peft>=0.13.0",
 
27
  ]
28
  eval = [
29
  "matplotlib>=3.8.0",
 
15
  "openai>=1.0.0",
16
  ]
17
 
18
+ # Training extras — these are the versions that actually reconcile in 2026:
19
+ # transformers 5.6.2 ─┐
20
+ # trl 1.2.0 ├─ TRL multi-turn environment_factory needs transformers >=5.2,
21
+ # vllm 0.19.1 ┘ vLLM 0.19+ supports transformers 5, vLLM 0.18 does not.
22
+ # Earlier we tried vllm 0.11.x — it pinned transformers <5 and broke environment_factory.
23
+ # mergekit removed: pinned pydantic <2.11 which conflicts with openenv-core 0.2.3 (>=2.11.7).
24
+ # Not used by training anyway.
25
  [project.optional-dependencies]
26
  train = [
27
+ "torch==2.10.0",
28
+ "transformers==5.6.2",
29
+ "trl==1.2.0",
30
+ "vllm==0.19.1",
31
+ "peft>=0.13.0,<0.20.0",
32
  "accelerate>=1.0.0",
33
  "datasets>=4.0.0",
34
  "bitsandbytes>=0.43.0",
35
+ "wandb>=0.16",
36
+ # flash-attn is optional — vLLM has its own attention backends; uncomment if you want it:
37
+ # "flash-attn>=2.5.0",
38
  ]
39
  eval = [
40
  "matplotlib>=3.8.0",
scripts/eval_all_checkpoints.sh ADDED
@@ -0,0 +1,95 @@
1
+ #!/usr/bin/env bash
2
+ # Evaluate every saved checkpoint and find the best one.
3
+ #
4
+ # RL training is bumpy — the FINAL checkpoint isn't always the best policy.
5
+ # This script runs compare_eval.py against the baseline AND each checkpoint
6
+ # in ./checkpoints/, then picks whichever has the highest mean composite
7
+ # reward to be the "best trained" submission.
8
+ #
9
+ # Usage:
10
+ # bash scripts/eval_all_checkpoints.sh # 10 episodes each
11
+ # EVAL_EPISODES=20 bash scripts/eval_all_checkpoints.sh # more confidence
12
+
13
+ cd "$(dirname "$0")/.."
14
+ mkdir -p eval_results
15
+
16
+ EVAL_EPISODES="${EVAL_EPISODES:-10}"
17
+ EVAL_MAX_OVERS="${EVAL_MAX_OVERS:-5}"
18
+ BASE_MODEL="${BASE_MODEL:-Qwen/Qwen3.5-4B}"
19
+
20
+ # 1. Baseline (untrained)
21
+ if [ ! -f eval_results/baseline.json ]; then
22
+ echo "=== [baseline] running ==="
23
+ python compare_eval.py \
24
+ --model "$BASE_MODEL" \
25
+ --label baseline \
26
+ --episodes "$EVAL_EPISODES" \
27
+ --max-overs "$EVAL_MAX_OVERS" \
28
+ --opponent-mode heuristic \
29
+ --output eval_results/baseline.json
30
+ fi
31
+
32
+ # 2. Each checkpoint
33
+ for ckpt in ./checkpoints/stage*/checkpoint-* ./checkpoints/stage*_final; do
34
+ if [ -d "$ckpt" ]; then
35
+ # Verify it's a PEFT adapter dir (has adapter_config.json)
36
+ if [ ! -f "$ckpt/adapter_config.json" ]; then
37
+ echo "=== [$ckpt] no adapter_config.json — skip ==="
38
+ continue
39
+ fi
40
+ # Use a safe filename derived from path
41
+ label=$(echo "$ckpt" | sed 's|[/.]|_|g' | sed 's|^_*||')
42
+ out="eval_results/${label}.json"
43
+ if [ -f "$out" ]; then
44
+ echo "=== [$ckpt] already evaluated, reading $out ==="
45
+ continue
46
+ fi
47
+ echo "=== [$ckpt] evaluating ==="
48
+ python compare_eval.py \
49
+ --model "$BASE_MODEL" \
50
+ --adapter "$ckpt" \
51
+ --label "$label" \
52
+ --episodes "$EVAL_EPISODES" \
53
+ --max-overs "$EVAL_MAX_OVERS" \
54
+ --opponent-mode heuristic \
55
+ --output "$out"
56
+ fi
57
+ done
58
+
59
+ # 3. Pick the best one and run a final compare against baseline
60
+ echo ""
61
+ echo "=== picking best checkpoint by mean composite reward ==="
62
+ python - <<'PYEOF'
63
+ import json, glob
64
+
65
+ best = None
66
+ for path in glob.glob("eval_results/*.json"):
67
+ if path.endswith("baseline.json"):
68
+ continue
69
+ try:
70
+ with open(path) as f:
71
+ data = json.load(f)
72
+ score = data["summary"].get("mean_composite_reward", 0.0) or 0.0
73
+ win_rate = data["summary"].get("win_rate_overall", 0.0) or 0.0
74
+ # Composite ranking: composite reward + 0.5 * win_rate
75
+ composite_score = score + 0.5 * win_rate
76
+ print(f" {path:50s} composite={score:.4f} win_rate={win_rate:.4f} ranking={composite_score:.4f}")
77
+ if best is None or composite_score > best[0]:
78
+ best = (composite_score, path, data)
79
+ except Exception as e:
80
+ print(f" {path}: skip ({e})")
81
+
82
+ if best:
83
+ print(f"\nBEST: {best[1]} (composite_score={best[0]:.4f}, label={best[2].get('label')})")
84
+ print(f" adapter: {best[2].get('adapter')}")
85
+
86
+ # Save a symlink/copy as 'best.json' for easy reference
87
+ import shutil
88
+ shutil.copy(best[1], "eval_results/best_trained.json")
89
+ print(f"\nCopied to eval_results/best_trained.json")
90
+ PYEOF
91
+
92
+ echo ""
93
+ echo "=== final comparison: baseline vs best trained ==="
94
+ python compare_eval.py --compare eval_results/baseline.json eval_results/best_trained.json \
95
+ | tee eval_results/final_comparison.txt
scripts/generate_training_plots.py ADDED
@@ -0,0 +1,248 @@
1
+ """
2
+ Generate labeled PNG plots for the README from a WandB run OR from local
3
+ episode_stats.jsonl files.
4
+
5
+ Usage:
6
+ # From a WandB run id (preferred — uses the per-step rebalanced metrics)
7
+ python scripts/generate_training_plots.py \\
8
+ --wandb-run ptnv-s-research/huggingface/<RUN_ID> \\
9
+ --output-dir docs/plots/
10
+
11
+ # From local episode_stats.jsonl (faster, no API call)
12
+ python scripts/generate_training_plots.py \\
13
+ --jsonl logs/run_*/episode_stats.jsonl \\
14
+ --output-dir docs/plots/
15
+
16
+ Generates (with axis labels + units):
17
+ docs/plots/training_reward_over_steps.png
18
+ docs/plots/per_rubric_breakdown.png
19
+ docs/plots/tool_call_frequency.png
20
+ docs/plots/match_completion_rate.png
21
+ docs/plots/before_after_comparison.png (if --compare given)
22
+ """
23
+ import argparse
24
+ import glob
25
+ import json
26
+ import os
27
+ from pathlib import Path
28
+ from typing import Any
29
+
30
+ import matplotlib
31
+ matplotlib.use("Agg") # headless
32
+ import matplotlib.pyplot as plt
33
+
34
+
35
+ def _load_jsonl(path: str) -> list[dict[str, Any]]:
36
+ rows = []
37
+ paths = glob.glob(path) if "*" in path else [path]
38
+ for p in paths:
39
+ with open(p) as f:
40
+ for line in f:
41
+ line = line.strip()
42
+ if line:
43
+ try:
44
+ rows.append(json.loads(line))
45
+ except json.JSONDecodeError:
46
+ continue
47
+ return rows
48
+
49
+
50
+ def _load_wandb(run_path: str) -> tuple[list[dict[str, Any]], dict[str, Any]]:
51
+ """Returns (history, config). Requires `pip install wandb` and login."""
52
+ try:
53
+ import wandb
54
+ except ImportError:
55
+ raise RuntimeError("wandb not installed. pip install wandb")
56
+ api = wandb.Api()
57
+ run = api.run(run_path)
58
+ history = list(run.history(samples=10000))
59
+ return history, run.config
60
+
61
+
62
+ def plot_training_reward(history, out_dir: Path, label: str):
63
+ steps, rewards = [], []
64
+ for row in history:
65
+ if "rewards/environment_reward/mean" in row and row["rewards/environment_reward/mean"] is not None:
66
+ steps.append(row.get("_step", row.get("step", len(steps))))
67
+ rewards.append(row["rewards/environment_reward/mean"])
68
+ if not rewards:
69
+ print(" no environment_reward/mean found, skipping")
70
+ return
71
+ fig, ax = plt.subplots(figsize=(8, 4.5))
72
+ ax.plot(steps, rewards, marker="o", linewidth=1.5, markersize=4, color="#0066cc")
73
+ ax.set_xlabel("Training step (gradient updates)")
74
+ ax.set_ylabel("Mean environment reward (composite)")
75
+ ax.set_title(f"GRPO training reward over time — {label}")
76
+ ax.grid(alpha=0.3)
77
+ fig.tight_layout()
78
+ out_path = out_dir / "training_reward_over_steps.png"
79
+ fig.savefig(out_path, dpi=130)
80
+ plt.close(fig)
81
+ print(f" → {out_path}")
82
+
83
+
84
+ def plot_per_rubric_breakdown(history, out_dir: Path, label: str):
85
+ """Plot the per-step means of all 4 rubrics on one axes."""
86
+ rubrics = ("reward/composite_mean", "reward/r_result_mean",
87
+ "reward/r_cricket_mean", "reward/r_behavior_mean",
88
+ "reward/r_validity_mean")
89
+ series = {r: [] for r in rubrics}
90
+ steps_per = {r: [] for r in rubrics}
91
+ for row in history:
92
+ for r in rubrics:
93
+ if r in row and row[r] is not None:
94
+ series[r].append(row[r])
95
+ steps_per[r].append(row.get("_step", row.get("step", len(series[r]))))
96
+ if not any(series.values()):
97
+ print(" no per-rubric metrics found, skipping")
98
+ return
99
+ fig, ax = plt.subplots(figsize=(9, 5))
100
+ colors = {"reward/composite_mean": "#000",
101
+ "reward/r_result_mean": "#cc0000",
102
+ "reward/r_cricket_mean": "#0066cc",
103
+ "reward/r_behavior_mean": "#009900",
104
+ "reward/r_validity_mean": "#9900cc"}
105
+ for r in rubrics:
106
+ if series[r]:
107
+ ax.plot(steps_per[r], series[r], marker="o", markersize=3, linewidth=1.3,
108
+ label=r.replace("reward/", "").replace("_mean", ""),
109
+ color=colors[r])
110
+ ax.set_xlabel("Training step (gradient updates)")
111
+ ax.set_ylabel("Mean reward")
112
+ ax.set_title(f"Per-rubric reward breakdown — {label}")
113
+ ax.legend(loc="best", fontsize=9)
114
+ ax.grid(alpha=0.3)
115
+ fig.tight_layout()
116
+ out_path = out_dir / "per_rubric_breakdown.png"
117
+ fig.savefig(out_path, dpi=130)
118
+ plt.close(fig)
119
+ print(f" → {out_path}")
120
+
121
+
122
+ def plot_tool_call_frequency(history, out_dir: Path, label: str):
123
+ steps, freq = [], []
124
+ for row in history:
125
+ if "tools/call_frequency" in row and row["tools/call_frequency"] is not None:
126
+ steps.append(row.get("_step", row.get("step", len(steps))))
127
+ freq.append(row["tools/call_frequency"])
128
+ if not freq:
129
+ print(" no tools/call_frequency found, skipping")
130
+ return
131
+ fig, ax = plt.subplots(figsize=(8, 4.5))
132
+ ax.plot(steps, freq, marker="o", linewidth=1.5, markersize=4, color="#cc6600")
133
+ ax.set_xlabel("Training step (gradient updates)")
134
+ ax.set_ylabel("Mean tool calls per rollout")
135
+ ax.set_title(f"Tool-call execution frequency (proxy for match progress) — {label}")
136
+ ax.grid(alpha=0.3)
137
+ fig.tight_layout()
138
+ out_path = out_dir / "tool_call_frequency.png"
139
+ fig.savefig(out_path, dpi=130)
140
+ plt.close(fig)
141
+ print(f" → {out_path}")
142
+
143
+
144
+ def plot_completion_rate(history, out_dir: Path, label: str):
145
+ steps, rate = [], []
146
+ for row in history:
147
+ if "rollout/match_completion_rate" in row and row["rollout/match_completion_rate"] is not None:
148
+ steps.append(row.get("_step", row.get("step", len(steps))))
149
+ rate.append(row["rollout/match_completion_rate"])
150
+ if not rate:
151
+ print(" no match_completion_rate found, skipping")
152
+ return
153
+ fig, ax = plt.subplots(figsize=(8, 4.5))
154
+ ax.plot(steps, rate, marker="o", linewidth=1.5, markersize=4, color="#009966")
155
+ ax.set_xlabel("Training step (gradient updates)")
156
+ ax.set_ylabel("Match completion rate")
157
+ ax.set_ylim(0, 1.05)
158
+ ax.set_title(f"Fraction of rollouts that completed the full match — {label}")
159
+ ax.grid(alpha=0.3)
160
+ fig.tight_layout()
161
+ out_path = out_dir / "match_completion_rate.png"
162
+ fig.savefig(out_path, dpi=130)
163
+ plt.close(fig)
164
+ print(f" → {out_path}")
165
+
166
+
167
+ def plot_before_after(baseline_json: str, trained_json: str, out_dir: Path):
168
+ """Bar chart comparing baseline vs trained on key eval metrics."""
169
+ with open(baseline_json) as f:
170
+ b = json.load(f)
171
+ with open(trained_json) as f:
172
+ t = json.load(f)
173
+ bs, ts = b["summary"], t["summary"]
174
+ metrics = [
175
+ ("match_completion_rate", "Match\ncompletion rate"),
176
+ ("win_rate_overall", "Overall\nwin rate"),
177
+ ("mean_validity_rate", "Mean\nvalidity rate"),
178
+ ("mean_composite_reward", "Mean composite\nreward (scaled)"),
179
+ ]
180
+ bvals = [bs.get(k, 0) or 0 for k, _ in metrics]
181
+ tvals = [ts.get(k, 0) or 0 for k, _ in metrics]
182
+ labels = [lbl for _, lbl in metrics]
183
+
184
+ x = range(len(metrics))
185
+ fig, ax = plt.subplots(figsize=(9, 5))
186
+ width = 0.35
187
+ bars_b = ax.bar([xi - width/2 for xi in x], bvals, width, label="baseline (untrained)", color="#999")
188
+ bars_t = ax.bar([xi + width/2 for xi in x], tvals, width, label="trained (LoRA r=64)", color="#0066cc")
189
+
190
+ for bars in (bars_b, bars_t):
191
+ for bar in bars:
192
+ h = bar.get_height()
193
+ ax.text(bar.get_x() + bar.get_width()/2, h + 0.01,
194
+ f"{h:.2f}", ha="center", fontsize=8)
195
+
196
+ ax.set_xticks(list(x))
197
+ ax.set_xticklabels(labels)
198
+ ax.set_ylabel("Metric value")
199
+ ax.set_title(f"Before vs After training — {bs['n_episodes']} eval matches each")
200
+ ax.legend()
201
+ ax.grid(axis="y", alpha=0.3)
202
+ fig.tight_layout()
203
+ out_path = out_dir / "before_after_comparison.png"
204
+ fig.savefig(out_path, dpi=130)
205
+ plt.close(fig)
206
+ print(f" → {out_path}")
207
+
208
+
209
+ def main():
210
+ p = argparse.ArgumentParser()
211
+ p.add_argument("--wandb-run", default=None,
212
+ help="WandB run path: entity/project/run_id (e.g. ptnv-s-research/huggingface/abc123)")
213
+ p.add_argument("--jsonl", default=None,
214
+ help="Local episode_stats.jsonl path (or glob)")
215
+ p.add_argument("--output-dir", default="docs/plots",
216
+ help="Output directory for PNGs (default: docs/plots/)")
217
+ p.add_argument("--label", default="warmup", help="Label suffix for plot titles")
218
+ p.add_argument("--compare", nargs=2, metavar=("BASELINE_JSON", "TRAINED_JSON"),
219
+ help="Also generate before/after bar chart from two compare_eval JSON files")
220
+ args = p.parse_args()
221
+
222
+ out_dir = Path(args.output_dir)
223
+ out_dir.mkdir(parents=True, exist_ok=True)
224
+
225
+ history = []
226
+ if args.wandb_run:
227
+ print(f"Loading WandB run: {args.wandb_run}")
228
+ history, _ = _load_wandb(args.wandb_run)
229
+ print(f" {len(history)} history rows")
230
+ elif args.jsonl:
231
+ print(f"Loading local jsonl: {args.jsonl}")
232
+ history = _load_jsonl(args.jsonl)
233
+ print(f" {len(history)} rows")
234
+
235
+ if history:
236
+ plot_training_reward(history, out_dir, args.label)
237
+ plot_per_rubric_breakdown(history, out_dir, args.label)
238
+ plot_tool_call_frequency(history, out_dir, args.label)
239
+ plot_completion_rate(history, out_dir, args.label)
240
+
241
+ if args.compare:
242
+ plot_before_after(args.compare[0], args.compare[1], out_dir)
243
+
244
+ print(f"\nDone — PNGs in {out_dir}/")
245
+
246
+
247
+ if __name__ == "__main__":
248
+ main()
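For orientation, `plot_before_after` only reads the `summary` block of the two compare_eval JSON files. A minimal sketch of the shape it expects, with key names taken from the code above and placeholder numbers rather than real results:

```python
# Minimal shape of eval_results/baseline.json / eval_results/trained.json as
# consumed by plot_before_after. Only these keys are read; values are illustrative.
example_eval_json = {
    "summary": {
        "n_episodes": 15,                 # used in the chart title
        "match_completion_rate": 0.80,    # bar 1
        "win_rate_overall": 0.40,         # bar 2
        "mean_validity_rate": 0.95,       # bar 3
        "mean_composite_reward": 0.42,    # bar 4 (scaled)
    }
}
```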
scripts/run_full_pipeline.sh ADDED
@@ -0,0 +1,84 @@
1
+ #!/usr/bin/env bash
2
+ # Full pipeline: warmup → main → baseline eval → trained eval → comparison.
3
+ #
4
+ # This is the end-to-end deliverable:
5
+ # 1. Warmup (2-3 over curriculum, 30 steps) ~2 hr
6
+ # 2. Main run (5-over T20, 100 steps) ~6-8 hr
7
+ # 3. Baseline eval (untrained Qwen3.5-4B) ~30 min
8
+ # 4. Trained eval (warmup+main checkpoint) ~30 min
9
+ # 5. Print comparison table ~instant
10
+ #
11
+ # Outputs:
12
+ # - ./checkpoints/stage2_final/ trained LoRA adapter
13
+ # - eval_results/baseline.json baseline match stats
14
+ # - eval_results/trained.json trained match stats
15
+ # - /tmp/train_warmup.log warmup training log
16
+ # - /tmp/train_main.log main run training log
17
+ # - /tmp/eval_baseline.log baseline eval log
18
+ # - /tmp/eval_trained.log trained eval log
19
+ #
20
+ # Usage: bash scripts/run_full_pipeline.sh
21
+ # (or run individual stages as documented in the README)
22
+
23
+ # NOTE: deliberately NOT using `set -e`. We want to inspect each stage's exit
24
+ # code and decide whether to continue, not abort on first non-zero return.
25
+ cd "$(dirname "$0")/.."
26
+
27
+ export PYTORCH_ALLOC_CONF=expandable_segments:True
28
+
29
+ EVAL_EPISODES="${EVAL_EPISODES:-15}"
30
+ EVAL_MAX_OVERS="${EVAL_MAX_OVERS:-5}"
31
+
32
+ # --------------------------------------------------------------------------
33
+ # Stage 1+2: warmup → main (chained)
34
+ # --------------------------------------------------------------------------
35
+ echo "=== [$(date '+%H:%M:%S')] FULL PIPELINE: warmup → main → eval ==="
36
+ bash scripts/run_warmup_then_main.sh
37
+ PIPE_STATUS=$?
38
+ if [ $PIPE_STATUS -ne 0 ]; then
39
+ echo "!!! Training pipeline failed (exit $PIPE_STATUS). Skipping eval."
40
+ exit $PIPE_STATUS
41
+ fi
42
+
43
+ # --------------------------------------------------------------------------
44
+ # Stage 3: baseline eval (untrained Qwen3.5-4B)
45
+ # --------------------------------------------------------------------------
46
+ mkdir -p eval_results
47
+ echo "=== [$(date '+%H:%M:%S')] EVAL: baseline (untrained Qwen3.5-4B) ==="
48
+ python compare_eval.py \
49
+ --model Qwen/Qwen3.5-4B \
50
+ --label baseline \
51
+ --episodes "$EVAL_EPISODES" \
52
+ --max-overs "$EVAL_MAX_OVERS" \
53
+ --opponent-mode heuristic \
54
+ --output eval_results/baseline.json \
55
+ > /tmp/eval_baseline.log 2>&1
56
+
57
+ # --------------------------------------------------------------------------
58
+ # Stage 4: trained eval (warmup + main adapter)
59
+ # --------------------------------------------------------------------------
60
+ echo "=== [$(date '+%H:%M:%S')] EVAL: trained (LoRA from ./checkpoints/stage2_final) ==="
61
+ python compare_eval.py \
62
+ --model Qwen/Qwen3.5-4B \
63
+ --adapter ./checkpoints/stage2_final \
64
+ --label trained \
65
+ --episodes "$EVAL_EPISODES" \
66
+ --max-overs "$EVAL_MAX_OVERS" \
67
+ --opponent-mode heuristic \
68
+ --output eval_results/trained.json \
69
+ > /tmp/eval_trained.log 2>&1
70
+
71
+ # --------------------------------------------------------------------------
72
+ # Stage 5: comparison
73
+ # --------------------------------------------------------------------------
74
+ echo "=== [$(date '+%H:%M:%S')] COMPARISON ==="
75
+ python compare_eval.py --compare eval_results/baseline.json eval_results/trained.json \
76
+ | tee eval_results/comparison.txt
77
+
78
+ echo ""
79
+ echo "=== [$(date '+%H:%M:%S')] DONE ==="
80
+ echo "Trained adapter: ./checkpoints/stage2_final/"
81
+ echo "Eval JSON: eval_results/{baseline,trained}.json"
82
+ echo "Comparison: eval_results/comparison.txt"
83
+ echo "Training logs: /tmp/train_warmup.log, /tmp/train_main.log"
84
+ echo "Eval logs: /tmp/eval_baseline.log, /tmp/eval_trained.log"
scripts/run_warmup_then_main.sh ADDED
@@ -0,0 +1,43 @@
1
+ #!/usr/bin/env bash
2
+ # Chained run: warmup (2-3 over curriculum) → main (5-over T20).
3
+ # Main auto-resumes from warmup's LoRA adapter at ./checkpoints/stage2_final
4
+ # (set via configs/cricket_train_qwen3.yaml resume_from).
5
+ #
6
+ # Active configs target Qwen3-4B-Instruct-2507 + vLLM colocate. The legacy
7
+ # Qwen3.5-4B configs are archived in configs/extras/.
8
+ #
9
+ # Usage: bash scripts/run_warmup_then_main.sh
10
+ # Logs: /tmp/train_warmup.log then /tmp/train_main.log
11
+
12
+ # NOTE: deliberately NOT using `set -e`. We want to inspect the warmup exit
13
+ # code and decide whether to continue, not abort on first non-zero return.
14
+ cd "$(dirname "$0")/.."
15
+
16
+ export PYTORCH_ALLOC_CONF=expandable_segments:True
17
+ export TRL_EXPERIMENTAL_SILENCE=1
18
+
19
+ echo "=== [$(date '+%H:%M:%S')] WARMUP starting (2-3 over curriculum, 30 steps) ==="
20
+ python train.py train --config configs/cricket_train_qwen3_warmup.yaml \
21
+ > /tmp/train_warmup.log 2>&1
22
+
23
+ WARMUP_EXIT=$?
24
+ if [ $WARMUP_EXIT -ne 0 ]; then
25
+ echo "!!! WARMUP failed with exit $WARMUP_EXIT — see /tmp/train_warmup.log"
26
+ echo " Skipping main run."
27
+ exit $WARMUP_EXIT
28
+ fi
29
+
30
+ # Sanity: confirm checkpoint exists before launching main.
31
+ if [ ! -d ./checkpoints/stage2_final ]; then
32
+ echo "!!! Expected ./checkpoints/stage2_final not found after warmup."
33
+ echo " Main run wants to resume from there — aborting."
34
+ exit 1
35
+ fi
36
+
37
+ echo "=== [$(date '+%H:%M:%S')] WARMUP done — adapter saved to ./checkpoints/stage2_final ==="
38
+ echo "=== [$(date '+%H:%M:%S')] MAIN starting (5-over, 100 steps, resuming from warmup) ==="
39
+
40
+ python train.py train --config configs/cricket_train_qwen3.yaml \
41
+ > /tmp/train_main.log 2>&1
42
+
43
+ echo "=== [$(date '+%H:%M:%S')] MAIN done ==="
server/coherence_grader.py CHANGED
@@ -64,37 +64,44 @@ def bowling_coherence_score(
64
  ) -> float:
65
  """
66
  Grade bowling strategy coherence.
67
- Weights: 40% rationale specificity + 40% line/length/field logic + 20% phase appropriateness.
 
 
 
68
  """
69
  if not bowling_strategy:
70
  return 0.0
71
-
72
  rationale = bowling_strategy.get("rationale", "")
73
  r_spec = rationale_specificity(rationale)
74
-
75
- # Simple logic: Aggressive field with attacking line/length gets high score
76
- line = bowling_strategy.get("line", "outside off")
77
- length = bowling_strategy.get("length", "good length")
78
-
79
- logic_score = 0.5
 
 
 
 
 
80
  if field_setting == "Aggressive":
81
- if line in ["stumps", "on pads"] or length in ["bouncer", "short"]: # Attacking
82
- logic_score = 1.0
83
  elif field_setting == "Defensive":
84
- if line in ["outside off", "wide"] or length in ["yorker", "full"]: # Defensive
85
- logic_score = 1.0
86
- else:
87
- logic_score = 0.8 # Balanced is generally safe
88
-
89
- # Phase appropriate (e.g., spin in middle overs)
90
- p_approp = 1.0
91
  bowler_type = bowling_strategy.get("bowler_type", "pace")
92
  if phase == "middle" and bowler_type == "spin":
93
  p_approp = 1.0
94
- elif phase == "powerplay" and bowler_type == "pace":
95
  p_approp = 1.0
96
-
97
- score = 0.40 * r_spec + 0.40 * logic_score + 0.20 * p_approp
 
 
98
  return round(score, 4)
99
 
100
 
 
64
  ) -> float:
65
  """
66
  Grade bowling strategy coherence.
67
+ Weights (from game_knowledge.yaml): 40% rationale + 30% field logic + 30% phase fit.
68
+
69
+ Line/length values must already be normalized (normalize_line / normalize_length
70
+ from field_model.py) — e.g. "pads" not "on pads", "outside_off" not "outside off".
71
  """
72
  if not bowling_strategy:
73
  return 0.0
74
+
75
  rationale = bowling_strategy.get("rationale", "")
76
  r_spec = rationale_specificity(rationale)
77
+
78
+ line = bowling_strategy.get("line", "outside_off")
79
+ length = bowling_strategy.get("length", "good")
80
+
81
+ # Attacking plan: attack the stumps/pads with short/full threatening lengths
82
+ _ATTACKING_LINES = {"stumps", "pads"}
83
+ _ATTACKING_LENGTHS = {"bouncer", "short", "yorker"}
84
+ # Containing plan: bowl wide or full to restrict scoring
85
+ _DEFENSIVE_LINES = {"outside_off", "wide"}
86
+ _DEFENSIVE_LENGTHS = {"yorker", "full"}
87
+
88
  if field_setting == "Aggressive":
89
+ logic_score = 1.0 if (line in _ATTACKING_LINES or length in _ATTACKING_LENGTHS) else 0.5
 
90
  elif field_setting == "Defensive":
91
+ logic_score = 1.0 if (line in _DEFENSIVE_LINES or length in _DEFENSIVE_LENGTHS) else 0.5
92
+ else: # Balanced
93
+ logic_score = 0.8
94
+
95
+ # Phase appropriateness: spin in middle, pace in powerplay/death
 
 
96
  bowler_type = bowling_strategy.get("bowler_type", "pace")
97
  if phase == "middle" and bowler_type == "spin":
98
  p_approp = 1.0
99
+ elif phase in {"powerplay", "death"} and bowler_type == "pace":
100
  p_approp = 1.0
101
+ else:
102
+ p_approp = 0.6
103
+
104
+ score = 0.40 * r_spec + 0.30 * logic_score + 0.30 * p_approp
105
  return round(score, 4)
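As a quick sanity check on the revised 40/30/30 split, here is an illustrative trace; the `rationale_specificity` value of 0.5 is an assumed placeholder, not a value the grader is guaranteed to return:

```python
# Illustrative trace of the revised bowling_coherence_score weighting.
bowling_strategy = {
    "line": "wide",        # already normalized via normalize_line
    "length": "yorker",    # already normalized via normalize_length
    "bowler_type": "pace",
    "rationale": "Wide yorkers at the death to protect the off-side boundary.",
}
# Suppose rationale_specificity(rationale) -> 0.5 (placeholder).
# field_setting="Defensive": "wide" is a defensive line  -> logic_score = 1.0
# phase="death" with a pace bowler                       -> p_approp   = 1.0
# score = 0.40 * 0.5 + 0.30 * 1.0 + 0.30 * 1.0 = 0.80
```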
106
 
107
 
server/cricket_environment.py CHANGED
@@ -163,7 +163,7 @@ class CricketEnvironment(Environment):
163
  start_wickets = self._rng.randint(0, 9)
164
  start_score = int(start_over * self._rng.uniform(5.5, 8.5))
165
 
166
- start_phase = over_to_phase(start_over)
167
  start_bowler = sample_bowler_type(start_phase, self._rng)
168
 
169
  self._state = CricketState(
@@ -206,6 +206,20 @@ class CricketEnvironment(Environment):
206
  # Load roster for the agent's team
207
  agent_team = options.get("agent_team", os.environ.get("CRICKET_AGENT_TEAM", "india"))
208
  self._agent_roster = load_team_roster(agent_team)
 
 
209
  # Reset tool budget
210
  self._overhead_calls_this_over = 0
211
  self._total_tool_fines = 0.0
@@ -427,6 +441,23 @@ class CricketEnvironment(Environment):
427
  if shot_intent not in VALID_SHOT_INTENTS:
428
  self._format_violations += 1
429
  shot_intent = "defensive"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
430
  context = self._context_for_policy()
431
  self._opponent_plan = self._opponent.bowling_plan(context)
432
  self._state.opponent_plan = self._opponent_plan
@@ -559,6 +590,23 @@ class CricketEnvironment(Environment):
559
  return self._build_obs(last_ball=f"Field set to {setting}.")
560
 
561
  def _handle_bowl_delivery(self, args: dict) -> CricketObservation:
 
 
 
562
  context = self._context_for_policy()
563
  self._opponent_plan = self._opponent.batting_plan(context)
564
  self._state.opponent_plan = self._opponent_plan
@@ -599,6 +647,7 @@ class CricketEnvironment(Environment):
599
  field_setting=self._field_setting,
600
  wickets=self._state.wickets_lost,
601
  score=self._state.total_score,
 
602
  )
603
  metadata = {"event_type": "base_outcome", "target_area": normalize_target_area("", shot_intent)}
604
  return self._process_delivery(runs, wicket, extra, shot_intent, dismissal_type, metadata)
@@ -843,7 +892,7 @@ class CricketEnvironment(Environment):
843
  return round(max(-0.15, min(0.15, reward)), 4)
844
 
845
  def _update_phase(self):
846
- new_phase = over_to_phase(self._state.over)
847
  if new_phase != self._state.phase:
848
  self._state.phase = new_phase
849
  self._bowler_type = sample_bowler_type(new_phase, self._rng)
@@ -1083,8 +1132,8 @@ class CricketEnvironment(Environment):
1083
  if wicket:
1084
  return 0.0
1085
  if self._state.game_state == GameState.BATTING:
1086
- expected = self._engine.expected_runs(self._state.over, shot_intent, self._bowler_type)
1087
- baseline = max(self._engine.expected_runs(self._state.over, shot, self._bowler_type) for shot in SHOT_INTENTS)
1088
  return round(min(1.0, expected / max(baseline, 0.1)), 4)
1089
  # For bowling, lower conceded runs are better.
1090
  return round(max(0.0, 1.0 - (runs / 6.0)), 4)
@@ -1189,7 +1238,10 @@ def _render_prompt(ctx, batting_strat, bowling_strat, shot_plan, delivery_plan,
1189
  budget = 3
1190
  budget_str = f"Tool budget: {overhead_used}/{budget} overhead calls used this over"
1191
  if overhead_used >= budget:
1192
- budget_str += " ⚠ BUDGET EXHAUSTED — further set_strategy/plan_shot/analyze/reflect calls will be FINED"
 
 
 
1193
  lines.append(budget_str)
1194
  # Opponent plan intentionally NOT shown — agent must infer via analyze_situation
1195
 
 
163
  start_wickets = self._rng.randint(0, 9)
164
  start_score = int(start_over * self._rng.uniform(5.5, 8.5))
165
 
166
+ start_phase = over_to_phase(start_over, max_overs)
167
  start_bowler = sample_bowler_type(start_phase, self._rng)
168
 
169
  self._state = CricketState(
 
206
  # Load roster for the agent's team
207
  agent_team = options.get("agent_team", os.environ.get("CRICKET_AGENT_TEAM", "india"))
208
  self._agent_roster = load_team_roster(agent_team)
209
+ if self._agent_roster:
210
+ playing_xi = build_playing_xi(self._agent_roster)
211
+ batters = [p for p in playing_xi if p.get("role") != "bowler"]
212
+ bowlers = [p for p in playing_xi if p.get("bowler_type")]
213
+ if len(batters) >= 2:
214
+ self._current_batter = batter_profile_from_player(batters[0])
215
+ self._non_striker = batter_profile_from_player(batters[1])
216
+ elif len(playing_xi) >= 2:
217
+ self._current_batter = batter_profile_from_player(playing_xi[0])
218
+ self._non_striker = batter_profile_from_player(playing_xi[1])
219
+ matching_bowler = next((p for p in bowlers if p.get("bowler_type") == self._bowler_type), None)
220
+ if matching_bowler or bowlers:
221
+ self._current_bowler = bowler_profile_from_player(matching_bowler or bowlers[0])
222
+ self._bowler_type = self._current_bowler["type"]
223
  # Reset tool budget
224
  self._overhead_calls_this_over = 0
225
  self._total_tool_fines = 0.0
 
441
  if shot_intent not in VALID_SHOT_INTENTS:
442
  self._format_violations += 1
443
  shot_intent = "defensive"
444
+ # Inline shot-plan capture: if the model passes target_area/risk/trajectory/
445
+ # rationale directly to play_delivery (the new "execute-first" pattern that
446
+ # collapses plan_shot+play_delivery into a single turn), update the shot plan
447
+ # and score adaptation/opp_awareness here so we don't lose the reward signal.
448
+ has_plan_args = any(args.get(k) for k in ("target_area", "risk", "trajectory", "rationale", "explanation"))
449
+ if has_plan_args:
450
+ target_area = normalize_target_area(args.get("target_area", "gaps"), shot_intent)
451
+ risk = str(args.get("risk", "balanced")).lower()
452
+ self._shot_plan = {
453
+ "shot_intent": shot_intent,
454
+ "target_area": target_area,
455
+ "trajectory": infer_trajectory(shot_intent, risk, args.get("trajectory")),
456
+ "risk": risk,
457
+ "rationale": str(args.get("rationale", args.get("explanation", ""))),
458
+ }
459
+ self._score_adaptation(self._shot_plan)
460
+ self._score_opponent_awareness(self._shot_plan)
461
  context = self._context_for_policy()
462
  self._opponent_plan = self._opponent.bowling_plan(context)
463
  self._state.opponent_plan = self._opponent_plan
 
590
  return self._build_obs(last_ball=f"Field set to {setting}.")
591
 
592
  def _handle_bowl_delivery(self, args: dict) -> CricketObservation:
593
+ # Inline delivery-plan capture: if the model passes line/length/delivery_type
594
+ # directly to bowl_delivery (the new "execute-first" pattern that collapses
595
+ # plan_delivery+bowl_delivery into a single turn), update the delivery plan
596
+ # and score adaptation/opp_awareness here so we don't lose the reward signal.
597
+ has_plan_args = any(args.get(k) for k in ("line", "length", "delivery_type", "rationale"))
598
+ if has_plan_args:
599
+ current_type = str(self._current_bowler.get("type", self._bowler_type)).lower()
600
+ self._delivery_plan = {
601
+ "bowler_type": current_type,
602
+ "line": normalize_line(args.get("line", "outside off")),
603
+ "length": normalize_length(args.get("length", "good length")),
604
+ "delivery_type": normalize_variation(args.get("delivery_type", "stock"), current_type),
605
+ "rationale": str(args.get("rationale", "")),
606
+ }
607
+ self._bowling_strategy = dict(self._delivery_plan)
608
+ self._score_adaptation(self._delivery_plan)
609
+ self._score_opponent_awareness(self._delivery_plan)
610
  context = self._context_for_policy()
611
  self._opponent_plan = self._opponent.batting_plan(context)
612
  self._state.opponent_plan = self._opponent_plan
 
647
  field_setting=self._field_setting,
648
  wickets=self._state.wickets_lost,
649
  score=self._state.total_score,
650
+ max_overs=self._state.max_overs,
651
  )
652
  metadata = {"event_type": "base_outcome", "target_area": normalize_target_area("", shot_intent)}
653
  return self._process_delivery(runs, wicket, extra, shot_intent, dismissal_type, metadata)
 
892
  return round(max(-0.15, min(0.15, reward)), 4)
893
 
894
  def _update_phase(self):
895
+ new_phase = over_to_phase(self._state.over, self._state.max_overs)
896
  if new_phase != self._state.phase:
897
  self._state.phase = new_phase
898
  self._bowler_type = sample_bowler_type(new_phase, self._rng)
 
1132
  if wicket:
1133
  return 0.0
1134
  if self._state.game_state == GameState.BATTING:
1135
+ expected = self._engine.expected_runs(self._state.over, shot_intent, self._bowler_type, self._state.max_overs)
1136
+ baseline = max(self._engine.expected_runs(self._state.over, shot, self._bowler_type, self._state.max_overs) for shot in SHOT_INTENTS)
1137
  return round(min(1.0, expected / max(baseline, 0.1)), 4)
1138
  # For bowling, lower conceded runs are better.
1139
  return round(max(0.0, 1.0 - (runs / 6.0)), 4)
 
1238
  budget = 3
1239
  budget_str = f"Tool budget: {overhead_used}/{budget} overhead calls used this over"
1240
  if overhead_used >= budget:
1241
+ budget_str += (
1242
+ " ⚠ BUDGET EXHAUSTED — further set_strategy, set_bowling_strategy, plan_delivery, "
1243
+ "analyze_situation, or reflect_after_ball calls will be FINED"
1244
+ )
1245
  lines.append(budget_str)
1246
  # Opponent plan intentionally NOT shown — agent must infer via analyze_situation
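The intent of the inline-capture change is easiest to see from the agent's side: a single `play_delivery` call can now carry the plan fields that previously required a separate `plan_shot` turn, and the environment backfills `self._shot_plan` from them (the same pattern applies to `bowl_delivery`). A sketch of such a call, with argument values chosen purely for illustration:

```python
# One-turn "execute-first" batting action: plan fields ride along with
# play_delivery, so adaptation and opponent-awareness scoring still fire.
action = {
    "tool": "play_delivery",
    "arguments": {
        "shot_intent": "rotate",
        "target_area": "midwicket",
        "risk": "low",
        "trajectory": "ground",
        "rationale": "Spinner angling into the pads; work it square and keep the strike moving.",
    },
}
```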
1247
 
server/markov_engine.py CHANGED
@@ -65,12 +65,19 @@ BOWLER_TYPES = ["pace", "spin"]
65
  EXTRAS_RATE = 0.05
66
 
67
 
68
- def over_to_phase(over: int) -> str:
69
- if over <= 5:
70
- return "powerplay"
71
- if over <= 35:
72
- return "middle"
73
- return "death"
 
 
 
 
 
 
 
74
 
75
 
76
  def sample_bowler_type(phase: str, rng: random.Random) -> str:
@@ -118,6 +125,7 @@ class MarkovCricketEngine:
118
  score: int = 0,
119
  bowler_type: str = "pace",
120
  field_setting: str = "Balanced",
 
121
  ) -> tuple[int, bool, bool, str]:
122
  """Sample an outcome for one delivery.
123
 
@@ -131,7 +139,7 @@ class MarkovCricketEngine:
131
  if self._rng.random() < EXTRAS_RATE:
132
  return 1, False, True, ""
133
 
134
- phase = over_to_phase(over)
135
 
136
  if self._cricsheet is not None:
137
  runs, wicket = self._cricsheet_step(over, wickets, score, phase, bowler_type, shot_intent)
@@ -214,6 +222,7 @@ class MarkovCricketEngine:
214
  score=score,
215
  bowler_type=bowler_type,
216
  field_setting="Balanced",
 
217
  )
218
  if extra:
219
  metadata.update({"event_type": "wide", "fielder_effect": "none", "base_runs": runs, "base_wicket": wicket})
@@ -274,12 +283,13 @@ class MarkovCricketEngine:
274
  field_setting: str = "Balanced",
275
  wickets: int = 0,
276
  score: int = 0,
 
277
  ) -> tuple[int, bool, bool, str]:
278
  """Simulate an AI batter faced with the agent's bowling/fielding.
279
 
280
  Returns (runs, wicket, extra, shot_intent).
281
  """
282
- phase = over_to_phase(over)
283
 
284
  # Decide AI batter's shot intent based on state and phase
285
  # Aggression increases in death overs or with wickets in hand
@@ -311,15 +321,16 @@ class MarkovCricketEngine:
311
  wickets=wickets,
312
  score=score,
313
  bowler_type=bowler_type,
314
- field_setting=field_setting
 
315
  )
316
 
317
  return runs, wicket, extra, shot_intent, dismissal_type
318
 
319
- def expected_runs(self, over: int, shot_intent: str, bowler_type: str = "pace") -> float:
320
  if shot_intent not in SHOT_AGGRESSION:
321
  return 0.0
322
- phase = over_to_phase(over)
323
  if self._cricsheet:
324
  dist = self._get_cricsheet_dist(over, 3, 15, phase, bowler_type, shot_intent)
325
  if dist:
@@ -327,10 +338,10 @@ class MarkovCricketEngine:
327
  dist = self._synthetic[shot_intent][phase]
328
  return sum(r * p for r, _, p in dist)
329
 
330
- def wicket_probability(self, over: int, shot_intent: str, bowler_type: str = "pace") -> float:
331
  if shot_intent not in SHOT_AGGRESSION:
332
  return 0.0
333
- phase = over_to_phase(over)
334
  if self._cricsheet:
335
  dist = self._get_cricsheet_dist(over, 3, 15, phase, bowler_type, shot_intent)
336
  if dist:
 
65
  EXTRAS_RATE = 0.05
66
 
67
 
68
+ def over_to_phase(over: int, max_overs: int | None = None) -> str:
69
+ """Return the phase label for a given over, respecting the match format.
70
+
71
+ Without max_overs the old hardcoded thresholds (designed for ODI) would
72
+ leave T20 overs 16-19 classified as "middle" instead of "death". We now
73
+ delegate to format_mapper.get_phase which reads the correct phase windows
74
+ from data/format_rules.json.
75
+ """
76
+ try:
77
+ from server.format_mapper import get_phase
78
+ except ImportError:
79
+ from .format_mapper import get_phase
80
+ return get_phase(over, max_overs)
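The practical effect of threading `max_overs` through is that phase labels now track the match format rather than the old fixed cutoffs. A hedged sketch of the intended behaviour: the exact windows live in data/format_rules.json, which is not part of this diff, so the expected outputs below are assumptions rather than asserted values.

```python
from server.markov_engine import over_to_phase

# With the old hardcoded thresholds, over 17 of a T20 was labelled "middle".
# With format-aware phases it should resolve to "death" for a 20-over match.
print(over_to_phase(17, max_overs=20))  # expected: "death"     (assumed window)
print(over_to_phase(3, max_overs=20))   # expected: "powerplay" (assumed window)
print(over_to_phase(17))                # no format given: falls back to the default windows
```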
81
 
82
 
83
  def sample_bowler_type(phase: str, rng: random.Random) -> str:
 
125
  score: int = 0,
126
  bowler_type: str = "pace",
127
  field_setting: str = "Balanced",
128
+ max_overs: int | None = None,
129
  ) -> tuple[int, bool, bool, str]:
130
  """Sample an outcome for one delivery.
131
 
 
139
  if self._rng.random() < EXTRAS_RATE:
140
  return 1, False, True, ""
141
 
142
+ phase = over_to_phase(over, max_overs)
143
 
144
  if self._cricsheet is not None:
145
  runs, wicket = self._cricsheet_step(over, wickets, score, phase, bowler_type, shot_intent)
 
222
  score=score,
223
  bowler_type=bowler_type,
224
  field_setting="Balanced",
225
+ max_overs=max_overs,
226
  )
227
  if extra:
228
  metadata.update({"event_type": "wide", "fielder_effect": "none", "base_runs": runs, "base_wicket": wicket})
 
283
  field_setting: str = "Balanced",
284
  wickets: int = 0,
285
  score: int = 0,
286
+ max_overs: int | None = None,
287
  ) -> tuple[int, bool, bool, str]:
288
  """Simulate an AI batter faced with the agent's bowling/fielding.
289
 
290
  Returns (runs, wicket, extra, shot_intent).
291
  """
292
+ phase = over_to_phase(over, max_overs)
293
 
294
  # Decide AI batter's shot intent based on state and phase
295
  # Aggression increases in death overs or with wickets in hand
 
321
  wickets=wickets,
322
  score=score,
323
  bowler_type=bowler_type,
324
+ field_setting=field_setting,
325
+ max_overs=max_overs,
326
  )
327
 
328
  return runs, wicket, extra, shot_intent, dismissal_type
329
 
330
+ def expected_runs(self, over: int, shot_intent: str, bowler_type: str = "pace", max_overs: int | None = None) -> float:
331
  if shot_intent not in SHOT_AGGRESSION:
332
  return 0.0
333
+ phase = over_to_phase(over, max_overs)
334
  if self._cricsheet:
335
  dist = self._get_cricsheet_dist(over, 3, 15, phase, bowler_type, shot_intent)
336
  if dist:
 
338
  dist = self._synthetic[shot_intent][phase]
339
  return sum(r * p for r, _, p in dist)
340
 
341
+ def wicket_probability(self, over: int, shot_intent: str, bowler_type: str = "pace", max_overs: int | None = None) -> float:
342
  if shot_intent not in SHOT_AGGRESSION:
343
  return 0.0
344
+ phase = over_to_phase(over, max_overs)
345
  if self._cricsheet:
346
  dist = self._get_cricsheet_dist(over, 3, 15, phase, bowler_type, shot_intent)
347
  if dist:
server/reward_calculator.py CHANGED
@@ -130,36 +130,47 @@ def compute_episode_reward(
130
  chase_progress = total_score / max(target, 1)
131
  wicket_penalty = wickets_lost * 0.08
132
  if total_score >= target:
133
- outcome_bonus = 1.0
134
  elif total_score == target - 1:
135
- outcome_bonus = 0.35
136
  else:
137
- outcome_bonus = 0.0
138
  r_cric = chase_progress + outcome_bonus - wicket_penalty
139
  else:
140
  # Bowling to defend: reward keeping opponent below target.
141
  defense_margin = max(target - total_score, 0) / max(target, 1)
142
  wicket_pressure = wickets_lost * 0.08
143
  if total_score < target - 1:
144
- outcome_bonus = 1.0
145
  elif total_score == target - 1:
146
- outcome_bonus = 0.35
147
  else:
148
- outcome_bonus = 0.0
149
  r_cric = defense_margin + wicket_pressure + outcome_bonus
150
  elif game_state == "batting":
151
  r_cric = (total_score / max(dls_par, 1.0)) - (wickets_lost * 0.08)
152
  else:
153
  # Bowling first innings: reward conceding fewer runs than DLS par.
154
- r_cric = max(0.0, (dls_par - total_score) / max(dls_par, 1.0))
155
- r_cric = max(-1.5, min(2.5, r_cric))
 
 
 
 
 
 
156
 
157
  # r_cricket: dense per-ball position signal via Dream11 proxy.
158
  # Normalised per innings then averaged so two-innings totals stay in [0, 1].
 
 
 
159
  if dream11_scores:
160
  r_dream11 = mean(normalize_dream11(s) for s in dream11_scores)
 
161
  else:
162
  r_dream11 = 0.0
 
163
 
164
  # Load weights from game_knowledge.yaml (cached after first load).
165
  w = get_reward_weights() if get_reward_weights is not None else None
@@ -179,12 +190,21 @@ def compute_episode_reward(
179
  r_tools = compute_tool_efficiency(tool_calls_made, analyze_calls, overs_played)
180
 
181
  eff_behavior_w = (w.r_behavior if w else 0.15) * coherence_weight_ramp
 
 
 
 
 
 
 
 
 
182
 
183
  composite = (
184
- (w.r_result if w else 0.55) * r_cric
185
- + (w.r_cricket if w else 0.25) * r_dream11
186
  + eff_behavior_w * r_strategy
187
- + (w.r_validity if w else 0.05) * r_format
188
  )
189
 
190
  return {
 
130
  chase_progress = total_score / max(target, 1)
131
  wicket_penalty = wickets_lost * 0.08
132
  if total_score >= target:
133
+ outcome_bonus = 1.0 # win
134
  elif total_score == target - 1:
135
+ outcome_bonus = 0.0 # tie — neutral, no consolation
136
  else:
137
+ outcome_bonus = -1.0 # loss — explicit negative signal
138
  r_cric = chase_progress + outcome_bonus - wicket_penalty
139
  else:
140
  # Bowling to defend: reward keeping opponent below target.
141
  defense_margin = max(target - total_score, 0) / max(target, 1)
142
  wicket_pressure = wickets_lost * 0.08
143
  if total_score < target - 1:
144
+ outcome_bonus = 1.0 # win (defended)
145
  elif total_score == target - 1:
146
+ outcome_bonus = 0.0 # tie
147
  else:
148
+ outcome_bonus = -1.0 # loss (target chased)
149
  r_cric = defense_margin + wicket_pressure + outcome_bonus
150
  elif game_state == "batting":
151
  r_cric = (total_score / max(dls_par, 1.0)) - (wickets_lost * 0.08)
152
  else:
153
  # Bowling first innings: reward conceding fewer runs than DLS par.
154
+ # Allow negative when conceding above par (was clamped to ≥0; now signed).
155
+ r_cric = (dls_par - total_score) / max(dls_par, 1.0)
156
+ # Progress bonus: small reward for actually executing balls instead of getting
157
+ # stuck in a planning loop. Reduced cap (0.25 → 0.10) so it doesn't drown out
158
+ # the loss penalty. Caps at +0.10 once the agent makes >=10 tool calls.
159
+ progress_bonus = min(0.10, tool_calls_made / 100.0)
160
+ r_cric = r_cric + progress_bonus
161
+ r_cric = max(-2.0, min(2.5, r_cric))
162
 
163
  # r_cricket: dense per-ball position signal via Dream11 proxy.
164
  # Normalised per innings then averaged so two-innings totals stay in [0, 1].
165
+ # When no innings has completed yet, dream11_scores is empty: redistribute
166
+ # its weight to r_result so the composite stays in the same [0, 1] range
167
+ # instead of silently capping at 0.75.
168
  if dream11_scores:
169
  r_dream11 = mean(normalize_dream11(s) for s in dream11_scores)
170
+ r_cricket_available = True
171
  else:
172
  r_dream11 = 0.0
173
+ r_cricket_available = False
174
 
175
  # Load weights from game_knowledge.yaml (cached after first load).
176
  w = get_reward_weights() if get_reward_weights is not None else None
 
190
  r_tools = compute_tool_efficiency(tool_calls_made, analyze_calls, overs_played)
191
 
192
  eff_behavior_w = (w.r_behavior if w else 0.15) * coherence_weight_ramp
193
+ r_result_w = w.r_result if w else 0.55
194
+ r_cricket_w = w.r_cricket if w else 0.25
195
+ r_validity_w = w.r_validity if w else 0.05
196
+
197
+ # If no innings has completed, fold r_cricket weight into r_result so the
198
+ # composite ceiling stays at 1.0 and gradients are not systematically suppressed.
199
+ if not r_cricket_available:
200
+ r_result_w = r_result_w + r_cricket_w
201
+ r_cricket_w = 0.0
202
 
203
  composite = (
204
+ r_result_w * r_cric
205
+ + r_cricket_w * r_dream11
206
  + eff_behavior_w * r_strategy
207
+ + r_validity_w * r_format
208
  )
209
 
210
  return {
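A quick numeric check of the new outcome_bonus and the weight redistribution, using the documented default weights (0.55 / 0.25 / 0.15 / 0.05); the match numbers are illustrative:

```python
# Chasing 160 and finishing 150/6: the old scheme gave outcome_bonus = 0.0,
# the new scheme makes the loss explicit.
target, total_score, wickets_lost = 160, 150, 6
chase_progress = total_score / max(target, 1)              # 0.9375
outcome_bonus = -1.0                                       # loss
wicket_penalty = wickets_lost * 0.08                       # 0.48
r_cric = chase_progress + outcome_bonus - wicket_penalty   # ~ -0.54, inside [-2.0, 2.5]

# Weight redistribution when no innings has completed (dream11_scores empty):
r_result_w, r_cricket_w = 0.55, 0.25
r_result_w, r_cricket_w = r_result_w + r_cricket_w, 0.0    # -> 0.80 and 0.0
```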
train.py CHANGED
@@ -2,26 +2,32 @@
2
  MT-GRPO training script for CricketCaptain-LLM.
3
 
4
  Two-stage curriculum (ToolRL-style):
5
- Stage 1 (steps 0–N): format mastery — reward only valid JSON
6
- Stage 2 (steps N–M): full 4-rubric reward coherence + cricket score + tools
7
 
8
  Design:
9
- - Prompts are collected by running the CricketEnvironment directly (no server needed)
10
- - GRPOTrainer generates its own completions and calls our stateless reward_fn
11
- - reward_fn(prompts, completions, **kwargs) no shared env state required
12
- - Plain TRL + transformers + bitsandbytes (no Unsloth)
13
-
14
- Usage:
15
- python train.py --stage 1 --steps 200 --model Qwen/Qwen2.5-7B-Instruct
16
- python train.py --stage 2 --steps 600 --model ./checkpoints/stage1_final
 
 
 
17
  """
18
 
19
  import argparse
 
 
20
  import json
21
  import os
22
  import random
23
  import re
24
  import sys
 
25
  import time
26
  from pathlib import Path
27
  from typing import Any
@@ -32,12 +38,16 @@ from typing import Any
32
  try:
33
  import torch
34
  from datasets import Dataset
 
35
  from trl import GRPOConfig, GRPOTrainer
36
  from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
37
  _TRAIN_IMPORTS_AVAILABLE = True
38
  except ImportError:
39
  torch = None
40
  Dataset = None
 
 
 
41
  GRPOConfig = None
42
  GRPOTrainer = None
43
  AutoModelForCausalLM = None
@@ -49,12 +59,14 @@ try:
49
  from server.cricket_environment import CricketEnvironment
50
  from server.coherence_grader import aggression_match, phase_appropriate, rationale_specificity
51
  from server.markov_engine import SHOT_AGGRESSION
 
52
  from models import CricketAction
53
  from config_yaml import get_game_constants, get_reward_weights
54
  except ImportError:
55
  from cricket_captain.server.cricket_environment import CricketEnvironment
56
  from cricket_captain.server.coherence_grader import aggression_match, phase_appropriate, rationale_specificity
57
  from cricket_captain.server.markov_engine import SHOT_AGGRESSION
 
58
  from cricket_captain.models import CricketAction
59
  from cricket_captain.config_yaml import get_game_constants, get_reward_weights
60
 
@@ -106,37 +118,172 @@ def extract_phase_from_prompt(prompt: str) -> str:
106
  # Per-turn reward components (all stateless) #
107
  # ------------------------------------------------------------------ #
108
 
 
 
 
 
109
  def _parse_completion(raw: str) -> dict | None:
 
 
 
110
  raw = raw.strip()
 
 
 
 
 
 
 
111
  if raw.startswith("```"):
112
  lines = raw.split("\n")
113
  raw = "\n".join(lines[1:-1]) if len(lines) > 2 else raw
 
 
114
  try:
115
  return json.loads(raw)
116
  except (json.JSONDecodeError, ValueError):
 
 
 
 
117
  return None
118
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
 
120
  def r_validity(completion: str) -> float:
121
- """r_validity: legal tool call with valid required fields. Returns 0 or 1."""
 
 
 
 
 
122
  data = _parse_completion(completion)
123
  if data is None:
124
  return 0.0
 
 
125
  tool = data.get("tool", "")
126
  args = data.get("arguments", {})
 
 
127
  if tool not in _VALID_TOOLS:
128
- return 0.0
 
 
129
  if tool == "play_delivery" and args.get("shot_intent") not in SHOT_AGGRESSION:
130
- return 0.0
131
  if tool == "set_strategy":
132
  agg = args.get("aggression")
133
  if not isinstance(agg, (int, float)):
134
- return 0.0
135
  if tool == "plan_shot" and args.get("shot_intent") not in SHOT_AGGRESSION:
136
- return 0.0
137
  if tool in {"choose_bowler", "set_bowling_strategy", "plan_delivery"}:
138
  if args.get("bowler_type") not in (None, "pace", "spin"):
139
- return 0.0
140
  return 1.0
141
 
142
 
@@ -299,13 +446,9 @@ def make_reward_fn(curriculum_stage: int):
299
  Returns reward_fn(prompts, completions, **kwargs) → list[float].
300
 
301
  Weights align with compute_episode_reward in reward_calculator.py:
302
- r_result 55% not computable stateless; omitted, remainder re-scaled
303
- r_cricket 25% not computable stateless; omitted
304
- r_behavior 15% → 0.75 of stateless composite
305
- r_validity 5% → 0.25 of stateless composite
306
-
307
- Scaling: behaviour=0.75, validity=0.25 preserves the relative 15:5 ratio
308
- from the episode-level rubric while using only what's available per-turn.
309
  """
310
  # Minimum reward for any structurally valid completion — ensures GRPO has a
311
  # positive gradient to reinforce valid tool use even for unscored tool types.
@@ -315,8 +458,18 @@ def make_reward_fn(curriculum_stage: int):
315
  rewards = []
316
  for prompt, completion in zip(prompts, completions):
317
  fmt = r_validity(completion)
 
318
  if curriculum_stage == 1:
319
- rewards.append(fmt)
 
 
 
 
 
 
 
 
 
320
  continue
321
 
322
  behavior = r_behavior_stateless(prompt, completion)
@@ -327,9 +480,12 @@ def make_reward_fn(curriculum_stage: int):
327
  + _RW.behavior_adaptation * adapt
328
  + _RW.behavior_opponent_awareness * aware
329
  )
330
- reward = _RW.training_behavior * r_beh + _RW.training_validity * fmt
 
 
 
331
  # Floor: valid JSON should always beat invalid JSON (reward=0)
332
- if fmt > 0.0:
333
  reward = max(reward, _VALID_FLOOR)
334
  rewards.append(round(reward, 4))
335
  return rewards
@@ -344,66 +500,136 @@ def make_reward_fn(curriculum_stage: int):
344
 
345
  SYSTEM_PROMPT = (
346
  "You are an expert adaptive cricket captain. Each turn you receive a scorecard "
347
- "and must respond with a SINGLE valid JSON tool call.\n\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
348
  "Available tools:\n"
349
- " call_toss — Call heads/tails and choose bat/bowl\n"
350
- " select_batter — Choose batter profile for the match situation\n"
351
- " set_strategy — Declare batting intent (aggression 0–1, rationale)\n"
352
- " plan_shot — Pre-ball batting plan\n"
353
- " play_delivery — Choose a shot and advance the game\n\n"
354
- " choose_bowler — Choose bowler profile for the situation\n"
355
- " set_bowling_strategy — Declare bowling line/length/type/rationale\n"
356
- " plan_delivery — Pre-ball bowling plan\n"
357
- " set_field_setting — Aggressive/Balanced/Defensive field\n"
358
- " bowl_delivery — Execute the delivery\n"
359
- " reflect_after_ball — Adapt after the previous ball\n"
360
- " analyze_situation — Query pitch/bowler/field info\n\n"
361
  "Shot intents: leave | defensive | single | rotate | boundary | six\n\n"
362
- "Be specific about phase, target pressure, opponent plan, field, batter, and bowler.\n\n"
363
- "Respond with exactly one JSON object on a single line."
364
  )
365
 
 
 
 
 
366
  _RANDOM_SHOTS = list(SHOT_AGGRESSION.keys())
367
  _RANDOM_QUERIES = ["pitch_conditions", "bowler_info", "field_setting", "match_situation"]
368
  _RANDOM_ZONES = ["cover", "point", "straight", "midwicket", "square_leg", "fine_leg", "long_on", "long_off"]
369
 
370
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
371
  def _random_action(
372
  rng: random.Random,
373
  game_state: str = "batting",
374
  available_tools: list[str] | None = None,
375
  current_bowler_type: str | None = None,
 
376
  ) -> CricketAction:
377
- def allowed(action: CricketAction) -> CricketAction:
378
- if available_tools is None or action.tool in available_tools:
379
- return action
380
- if "bowl_delivery" in available_tools:
381
- return CricketAction(tool="bowl_delivery", arguments={})
382
- if "play_delivery" in available_tools:
383
- return CricketAction(tool="play_delivery", arguments={"shot_intent": "defensive", "explanation": "fallback"})
384
- if "call_toss" in available_tools:
385
- return CricketAction(tool="call_toss", arguments={"call": "heads", "decision": "bat"})
386
- return action
 
 
 
 
387
 
388
  if game_state == "toss":
389
- return allowed(CricketAction(
390
  tool="call_toss",
391
  arguments={"call": rng.choice(["heads", "tails"]), "decision": rng.choice(["bat", "bowl"])},
392
- ))
 
 
 
 
 
 
 
 
 
 
393
  if game_state == "bowling":
394
  choice = rng.random()
395
- if choice < 0.15:
396
- return allowed(CricketAction(
 
397
  tool="choose_bowler",
398
  arguments={
399
- "name": rng.choice(["Strike Pacer", "Control Spinner", "Death Specialist"]),
400
- "bowler_type": rng.choice(["pace", "spin"]),
401
- "style": rng.choice(["swing", "economy", "yorker"]),
402
- "rationale": "Match bowler to phase and batter matchup",
403
  },
404
- ))
405
- if choice < 0.35:
406
- return allowed(CricketAction(
407
  tool="plan_delivery",
408
  arguments={
409
  "bowler_type": current_bowler_type or rng.choice(["pace", "spin"]),
@@ -412,74 +638,96 @@ def _random_action(
412
  "delivery_type": rng.choice(["stock", "yorker", "bouncer", "slower ball"]),
413
  "rationale": "Use field and batter style to control scoring zones",
414
  },
415
- ))
416
- if choice < 0.5:
417
- return allowed(CricketAction(tool="set_field_setting", arguments={"setting": rng.choice(["Aggressive", "Balanced", "Defensive"])}))
418
- if choice < 0.6:
419
- return allowed(CricketAction(tool="reflect_after_ball", arguments={"reflection": "Adjust line and field after the last ball"}))
420
- return allowed(CricketAction(tool="bowl_delivery", arguments={}))
 
 
 
 
 
 
 
 
 
 
421
 
422
  choice = rng.random()
423
- if choice < 0.15:
424
- return allowed(CricketAction(
 
425
  tool="select_batter",
426
  arguments={
427
- "name": rng.choice(["Opener", "Anchor", "Finisher"]),
428
- "style": rng.choice(["balanced", "anchor", "hitter", "finisher"]),
429
- "aggression": round(rng.uniform(0.2, 0.8), 2),
430
  "rationale": "Select batter based on phase, wickets, and target pressure",
431
  },
432
- ))
433
- if choice < 0.3:
434
- return allowed(CricketAction(
435
  tool="set_strategy",
436
  arguments={
437
  "phase_intent": rng.choice(["attack", "consolidate", "rotate"]),
438
  "aggression": round(rng.uniform(0.1, 0.9), 2),
439
- "rationale": "random rollout",
440
  },
441
- ))
442
- if choice < 0.45:
443
- return allowed(CricketAction(
444
  tool="plan_shot",
445
  arguments={
446
  "shot_intent": rng.choice(_RANDOM_SHOTS),
447
- "target_area": rng.choice(_RANDOM_ZONES),
448
- "trajectory": rng.choice(["ground", "lofted", "aerial"]),
449
  "risk": rng.choice(["low", "balanced", "high"]),
450
  "rationale": "Plan shot against bowler, field, and required rate",
451
  },
452
- ))
453
- if choice < 0.55:
454
- return allowed(CricketAction(
455
  tool="analyze_situation",
456
  arguments={"query_type": rng.choice(_RANDOM_QUERIES)},
457
- ))
458
- if choice < 0.65:
459
- return allowed(CricketAction(tool="reflect_after_ball", arguments={"reflection": "Revise risk after previous ball"}))
460
- return allowed(CricketAction(
461
- tool="play_delivery",
462
- arguments={"shot_intent": rng.choice(_RANDOM_SHOTS), "explanation": "random"},
463
- ))
 
 
464
 
465
 
466
  def collect_prompts(
467
  n_prompts: int,
468
  task: str = "stage2_full",
469
  seed: int = 42,
 
 
470
  ) -> list[str]:
471
  """
472
  Collect game-state prompts by running episodes with random actions.
473
  Returns a list of prompt strings (one per game state observation).
474
  """
475
  rng = random.Random(seed)
 
 
476
  prompts: list[str] = []
477
  ep_count = 0
478
 
479
  while len(prompts) < n_prompts:
480
  env = CricketEnvironment()
481
- obs = env.reset(seed=rng.randint(0, 99999), options={"task": task, "random_start": True})
482
- prompts.append(_format_prompt(obs.prompt_text))
 
 
 
 
 
483
  steps = 0
484
 
485
  while not obs.done and steps < 80:
@@ -488,10 +736,11 @@ def collect_prompts(
488
  obs.game_state,
489
  obs.available_tools,
490
  obs.current_bowler.get("type") if obs.current_bowler else None,
 
491
  )
492
  obs = env.step(action)
493
  if not obs.done:
494
- prompts.append(_format_prompt(obs.prompt_text))
495
  steps += 1
496
 
497
  ep_count += 1
@@ -513,13 +762,548 @@ def build_dataset(prompts: list[str]) -> Dataset:
513
  return Dataset.from_dict({"prompt": prompts})
514
 
515
 
516
- def generate_sft_examples(out_path: str, n_examples: int = 240, seed: int = 42):
 
 
 
 
517
  """Stage 0 bootstrap data: valid tool JSON for every tool family."""
518
  rng = random.Random(seed)
 
519
  examples = []
520
  for _ in range(n_examples):
521
  game_state = rng.choice(["toss", "batting", "bowling"])
522
- action = _random_action(rng, game_state)
523
  prompt = (
524
  f"{SYSTEM_PROMPT}\n\n"
525
  f"[CricketCaptain] {game_state.upper()} | Example adaptive scenario\n"
@@ -546,28 +1330,68 @@ def generate_sft_examples(out_path: str, n_examples: int = 240, seed: int = 42):
546
  # Model loading (plain transformers + bitsandbytes 4-bit) #
547
  # ------------------------------------------------------------------ #
548
 
549
- def load_model(model_name: str):
 
 
 
 
 
 
 
550
  if not _TRAIN_IMPORTS_AVAILABLE:
551
  raise ImportError("Training dependencies are missing. Install with: pip install '.[train]'")
552
- print(f"Loading {model_name} …")
553
- bnb_cfg = BitsAndBytesConfig(
554
- load_in_4bit=True,
555
- bnb_4bit_compute_dtype=torch.bfloat16,
556
- bnb_4bit_use_double_quant=True,
557
- bnb_4bit_quant_type="nf4",
558
- )
559
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
560
  if tokenizer.pad_token is None:
561
  tokenizer.pad_token = tokenizer.eos_token
562
 
563
- model = AutoModelForCausalLM.from_pretrained(
564
- model_name,
565
- quantization_config=bnb_cfg,
 
 
 
 
566
  device_map="auto",
567
  trust_remote_code=True,
568
  torch_dtype=torch.bfloat16,
 
569
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
570
  print(f"Loaded. Parameters: {model.num_parameters():,}")
 
571
  return model, tokenizer
572
 
573
 
@@ -578,52 +1402,173 @@ def load_model(model_name: str):
578
  def train(args):
579
  if not _TRAIN_IMPORTS_AVAILABLE:
580
  raise ImportError("Training dependencies are missing. Install with: pip install '.[train]'")
 
 
 
 
 
 
 
581
  task = "stage1_format" if args.stage == 1 else "stage2_full"
582
- out_dir = f"./checkpoints/stage{args.stage}"
583
- save_dir = f"./checkpoints/stage{args.stage}_final"
 
 
 
584
 
585
  print(f"\n=== Stage {args.stage} Training ===")
586
  print(f"Task: {task} | Prompts: {args.prompts} | Steps: {args.steps}")
587
-
588
- # Collect prompts
589
- print("\nCollecting game-state prompts …")
590
- prompts = collect_prompts(args.prompts, task=task, seed=args.seed)
591
- dataset = build_dataset(prompts)
592
-
593
- # Load model
594
- model, tokenizer = load_model(args.model)
 
 
 
595
 
596
  # GRPO config
 
 
 
 
597
  config = GRPOConfig(
598
  output_dir=out_dir,
 
599
  num_train_epochs=1,
600
  max_steps=args.steps,
601
  per_device_train_batch_size=args.batch_size,
602
  gradient_accumulation_steps=args.grad_accum,
603
- learning_rate=2e-5 if args.stage == 1 else 1e-5,
604
  warmup_ratio=0.05,
605
  lr_scheduler_type="cosine",
606
- logging_steps=10,
607
- save_steps=max(50, args.steps // 5),
608
- save_total_limit=2,
609
  bf16=True,
610
- max_prompt_length=512,
611
- max_completion_length=256,
612
  num_generations=args.num_generations,
613
- temperature=0.8,
614
- report_to="none",
 
 
615
  log_completions=True,
616
  seed=args.seed,
 
 
 
 
617
  )
618
 
619
- reward_fn = make_reward_fn(args.stage)
 
 
 
 
620
 
621
  trainer = GRPOTrainer(
622
  model=model,
623
- reward_funcs=reward_fn,
624
  args=config,
625
  train_dataset=dataset,
626
  processing_class=tokenizer,
 
627
  )
628
 
629
  print(f"\nStarting training ({args.steps} steps, {len(dataset)} prompts) …")
@@ -659,7 +1604,11 @@ def evaluate(args):
659
 
660
  for ep in range(args.eval_episodes):
661
  env = CricketEnvironment()
662
- obs = env.reset(seed=rng.randint(0, 99999), options={"task": "stage2_full", "random_start": False})
 
 
 
 
663
  steps = 0
664
 
665
  while not obs.done and steps < 150:
@@ -670,12 +1619,7 @@ def evaluate(args):
670
  if data:
671
  action = CricketAction(tool=data["tool"], arguments=data.get("arguments", {}))
672
  else:
673
- if obs.game_state == "bowling":
674
- action = CricketAction(tool="bowl_delivery", arguments={})
675
- elif obs.game_state == "toss":
676
- action = CricketAction(tool="call_toss", arguments={"call": "heads", "decision": "bat"})
677
- else:
678
- action = CricketAction(tool="play_delivery", arguments={"shot_intent": "defensive", "explanation": "fallback"})
679
 
680
  obs = env.step(action)
681
  steps += 1
@@ -706,6 +1650,7 @@ def _make_run_folder(prefix: str, model: str | None, opponent_mode: str | None,
706
  def train_smoke(args):
707
  """Run short direct-environment training rollouts without loading a model."""
708
  rng = random.Random(args.seed)
 
709
 
710
  # Auto-create run folder unless --output explicitly given
711
  if args.output:
@@ -746,6 +1691,7 @@ def train_smoke(args):
746
  "eval_pack_id": args.eval_pack_id,
747
  "opponent_mode": args.opponent_mode,
748
  "opponent_cache_path": args.opponent_cache_path,
 
749
  })
750
  prompts = [_format_prompt(obs.prompt_text)]
751
  total_reward = 0.0
@@ -763,6 +1709,7 @@ def train_smoke(args):
763
  obs.game_state,
764
  obs.available_tools,
765
  obs.current_bowler.get("type") if obs.current_bowler else None,
 
766
  )
767
  obs = env.step(action)
768
  step_end = time.perf_counter()
@@ -855,6 +1802,7 @@ def train_smoke(args):
855
  def _apply_yaml_defaults(args, cfg: dict) -> None:
856
  """Merge YAML config values into args, CLI args take precedence."""
857
  captain = cfg.get("captain", {}) or {}
 
858
  env_cfg = cfg.get("env", {}) or {}
859
  train_cfg = cfg.get("train", {}) or {}
860
 
@@ -862,17 +1810,48 @@ def _apply_yaml_defaults(args, cfg: dict) -> None:
862
  if val is not None and getattr(args, attr, None) is None:
863
  setattr(args, attr, val)
864
 
865
- _set("model", captain.get("model"))
 
 
 
866
  _set("api_base", captain.get("api_base"))
867
  _set("api_key", os.environ.get(captain.get("api_key_env", "")) or None)
868
  _set("eval_pack_id", env_cfg.get("eval_pack_id"))
869
- _set("opponent_mode", cfg.get("opponent", {}).get("mode"))
870
- _set("opponent_cache_path", cfg.get("opponent", {}).get("cache_path"))
 
 
 
 
871
  _set("max_overs", env_cfg.get("max_overs"))
 
872
  _set("steps", train_cfg.get("steps"))
873
  _set("prompts", train_cfg.get("prompts"))
874
  _set("batch_size", train_cfg.get("batch_size"))
 
875
  _set("stage", train_cfg.get("stage"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
876
 
877
 
878
  def main():
@@ -888,15 +1867,55 @@ def main():
888
  t.add_argument("--prompts", type=int, default=None, help="Game state prompts to collect")
889
  t.add_argument("--steps", type=int, default=None, help="GRPOTrainer max_steps")
890
  t.add_argument("--batch-size", type=int, default=None, dest="batch_size")
891
- t.add_argument("--grad-accum", type=int, default=4, dest="grad_accum")
892
- t.add_argument("--num-generations", type=int, default=4, dest="num_generations")
 
 
 
 
 
 
 
 
 
 
 
 
 
893
  t.add_argument("--seed", type=int, default=42)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
894
 
895
  # eval
896
  e = sub.add_parser("eval", help="Evaluate a checkpoint")
897
  e.add_argument("--config", default=None)
898
  e.add_argument("--model", default=None)
899
  e.add_argument("--eval-episodes", type=int, default=10, dest="eval_episodes")
 
900
  e.add_argument("--seed", type=int, default=0)
901
 
902
  # quick test (no GPU needed)
@@ -911,12 +1930,14 @@ def main():
911
  smoke.add_argument("--eval-pack-id", default=None, dest="eval_pack_id")
912
  smoke.add_argument("--opponent-mode", default=None, choices=["heuristic", "llm_live", "llm_cached", "cricsheet"], dest="opponent_mode")
913
  smoke.add_argument("--opponent-cache-path", default=None, dest="opponent_cache_path")
 
914
  smoke.add_argument("--output", default=None)
915
  smoke.add_argument("--seed", type=int, default=42)
916
 
917
  sft = sub.add_parser("sft-data", help="Generate Stage 0 supervised tool-format examples")
918
  sft.add_argument("--output", default="./data/training/tool_sft_examples.jsonl")
919
  sft.add_argument("--examples", type=int, default=240)
 
920
  sft.add_argument("--seed", type=int, default=42)
921
 
922
  args = parser.parse_args()
@@ -934,35 +1955,45 @@ def main():
934
  if getattr(args, "stage", None) is None:
935
  args.stage = 1
936
  if getattr(args, "model", None) is None:
937
- args.model = "Qwen/Qwen2.5-7B-Instruct"
938
  if getattr(args, "steps", None) is None:
939
  args.steps = 200
940
  if getattr(args, "prompts", None) is None:
941
  args.prompts = 500
942
  if getattr(args, "batch_size", None) is None:
943
  args.batch_size = 2
 
 
944
  if getattr(args, "eval_pack_id", None) is None:
945
  args.eval_pack_id = "adaptive_t20_v1"
946
  if getattr(args, "opponent_mode", None) is None:
947
  args.opponent_mode = "llm_live"
948
  if getattr(args, "max_overs", None) is None:
949
  args.max_overs = 5
 
 
 
 
 
 
 
 
950
 
951
  if args.cmd == "train":
952
  train(args)
953
  elif args.cmd == "eval":
954
  evaluate(args)
955
  elif args.cmd == "test":
956
- _smoke_test()
957
  elif args.cmd == "train-smoke":
958
  train_smoke(args)
959
  elif args.cmd == "sft-data":
960
- generate_sft_examples(args.output, args.examples, args.seed)
961
  else:
962
  parser.print_help()
963
 
964
 
965
- def _smoke_test():
966
  """Verify reward functions work correctly."""
967
  cases = [
968
  (
@@ -990,12 +2021,12 @@ def _smoke_test():
990
  ]
991
  print("Reward function smoke test:\n")
992
  for prompt, completion, expected in cases:
993
- fmt = r_format(completion)
994
- coh = r_coherence_stateless(prompt, completion)
995
  print(f" expected={expected:4s} | fmt={fmt:.0f} | coh={coh:.3f} | {completion[:60]}")
996
 
997
  print("\nPrompt collection test (5 prompts):")
998
- p = collect_prompts(5, task="stage1_format", seed=1)
999
  for i, pp in enumerate(p):
1000
  print(f" [{i}] {pp[:80].strip()} …")
1001
 
 
2
  MT-GRPO training script for CricketCaptain-LLM.
3
 
4
  Two-stage curriculum (ToolRL-style):
5
+ Stage 1: tool-call mastery — emphasize valid, phase-legal tool usage
6
+ Stage 2: strategic behavior — full environment-backed reward (result + cricket + behavior + validity)
7
 
8
  Design:
9
+ - Training uses TRL GRPO with environment_factory=CricketCaptainToolEnv
10
+ - The model interacts with live CricketEnvironment instances over multi-turn tool calls
11
+ - Rewards are collected from the environment (environment_reward), not only from stateless prompt parsing
12
+ - The opponent policy is part of the environment: heuristic/cricsheet/llm_live/llm_cached
13
+ - Plain TRL + Transformers + bitsandbytes + PEFT (LoRA adapters for 4-bit models)
14
+
15
+ Usage (canonical Qwen3 setup):
16
+ python train.py train --config configs/cricket_train_qwen3_warmup.yaml # warmup
17
+ python train.py train --config configs/cricket_train_qwen3.yaml # main 5-over
18
+
19
+ Legacy Qwen3.5 configs live in configs/extras/.
20
  """
21
 
22
  import argparse
23
+ import copy
24
+ import datetime
25
  import json
26
  import os
27
  import random
28
  import re
29
  import sys
30
+ import threading
31
  import time
32
  from pathlib import Path
33
  from typing import Any
 
38
  try:
39
  import torch
40
  from datasets import Dataset
41
+ from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
42
  from trl import GRPOConfig, GRPOTrainer
43
  from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
44
  _TRAIN_IMPORTS_AVAILABLE = True
45
  except ImportError:
46
  torch = None
47
  Dataset = None
48
+ LoraConfig = None
49
+ get_peft_model = None
50
+ prepare_model_for_kbit_training = None
51
  GRPOConfig = None
52
  GRPOTrainer = None
53
  AutoModelForCausalLM = None
 
59
  from server.cricket_environment import CricketEnvironment
60
  from server.coherence_grader import aggression_match, phase_appropriate, rationale_specificity
61
  from server.markov_engine import SHOT_AGGRESSION
62
+ from server.player_roster import build_playing_xi, load_team_roster
63
  from models import CricketAction
64
  from config_yaml import get_game_constants, get_reward_weights
65
  except ImportError:
66
  from cricket_captain.server.cricket_environment import CricketEnvironment
67
  from cricket_captain.server.coherence_grader import aggression_match, phase_appropriate, rationale_specificity
68
  from cricket_captain.server.markov_engine import SHOT_AGGRESSION
69
+ from cricket_captain.server.player_roster import build_playing_xi, load_team_roster
70
  from cricket_captain.models import CricketAction
71
  from cricket_captain.config_yaml import get_game_constants, get_reward_weights
72
 
 
118
  # Per-turn reward components (all stateless) #
119
  # ------------------------------------------------------------------ #
120
 
121
+ _XML_FN_RE = re.compile(r"<function\s*=?\s*([^>\s]+)\s*>", re.IGNORECASE)
122
+ _XML_PARAM_RE = re.compile(r"<parameter\s*=\s*([^>\s]+)\s*>(.*?)</parameter>", re.IGNORECASE | re.DOTALL)
123
+
124
+
125
  def _parse_completion(raw: str) -> dict | None:
126
+ """Parse a tool-call from the raw completion into our canonical {tool, arguments} dict.
127
+
128
+ Handles four common model output patterns:
129
+ 1. Plain JSON (ideal).
130
+ 2. Markdown code block (```json ... ```).
131
+ 3. Thinking-model preamble: <think>...</think> followed by JSON.
132
+ Qwen3/Qwen3.5 in default mode emits reasoning inside <think> tags;
133
+ we strip everything up to and including the closing </think> tag.
134
+ 4. XML function-call format that Qwen3.5 was trained on:
135
+ <function=tool_name><parameter=foo>bar</parameter>...</function>
136
+ Empirically (see logs/run_2026-04-25_21-08-45) every Stage-1 completion
137
+ emitted this XML form instead of JSON — so we extract it as a fallback
138
+ to give GRPO a non-zero gradient before the model has been trained
139
+ to follow the JSON contract.
140
+ """
141
  raw = raw.strip()
142
+
143
+ # Strip <think>...</think> preamble emitted by thinking-mode models.
144
+ if "<think>" in raw:
145
+ think_end = raw.rfind("</think>")
146
+ if think_end != -1:
147
+ raw = raw[think_end + len("</think>"):].strip()
148
+
149
  if raw.startswith("```"):
150
  lines = raw.split("\n")
151
  raw = "\n".join(lines[1:-1]) if len(lines) > 2 else raw
152
+
153
+ # Try parsing the whole string, then fall back to the first {...} block.
154
  try:
155
  return json.loads(raw)
156
  except (json.JSONDecodeError, ValueError):
157
+ pass
158
+
159
+ start = raw.find("{")
160
+ end = raw.rfind("}")
161
+ if start != -1 and end > start:
162
+ try:
163
+ return json.loads(raw[start : end + 1])
164
+ except (json.JSONDecodeError, ValueError):
165
+ pass
166
+
167
+ # XML function-call fallback (Qwen3.5 default tool-call emission style).
168
+ fn_match = _XML_FN_RE.search(raw)
169
+ if fn_match:
170
+ tool = fn_match.group(1).strip().strip("\"'")
171
+ arguments: dict[str, Any] = {}
172
+ for pname, pval in _XML_PARAM_RE.findall(raw):
173
+ v = pval.strip()
174
+ # Coerce numeric/bool literals so downstream validators accept them.
175
+ try:
176
+ arguments[pname] = json.loads(v)
177
+ except (json.JSONDecodeError, ValueError):
178
+ arguments[pname] = v
179
+ return {"tool": tool, "arguments": arguments}
180
+
181
+ return None
182
+
183
+
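As a quick sanity check, a minimal sketch of the three happy paths (the completions below are hypothetical; they assume `_parse_completion` behaves exactly as defined above):

```python
# All three hypothetical completions should normalise to the same canonical dict.
json_form = '{"tool": "play_delivery", "arguments": {"shot_intent": "single"}}'
think_form = "<think>Powerplay, rotate strike.</think>\n" + json_form
xml_form = "<function=play_delivery><parameter=shot_intent>single</parameter></function>"

for raw in (json_form, think_form, xml_form):
    assert _parse_completion(raw) == {"tool": "play_delivery", "arguments": {"shot_intent": "single"}}
```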
184
+ # Bounded FIFO cache (oldest insertion evicted first). Each snapshot is a deepcopy of CricketEnvironment
185
+ # (~1 MB) and only used by the LEGACY single-turn r_environment_rollout path,
186
+ # not by the multi-turn environment_factory training path. Cap at 4096 entries
187
+ # (~4 GB worst case) so a long collect_prompts call can't blow up host RAM.
188
+ _PROMPT_ENV_SNAPSHOTS: dict[str, CricketEnvironment] = {}
189
+ _PROMPT_SNAPSHOT_CAP = 4096
190
+ _ENV_REWARD_ROLLOUT_STEPS = 12
191
+
192
+
193
+ def _remember_prompt(obs_text: str, env: CricketEnvironment) -> str:
194
+ """Format an observation and keep the exact env state for rollout reward."""
195
+ prompt = _format_prompt(obs_text)
196
+ if len(_PROMPT_ENV_SNAPSHOTS) >= _PROMPT_SNAPSHOT_CAP:
197
+ # Evict oldest insertion (dict preserves insertion order in py3.7+).
198
+ oldest_key = next(iter(_PROMPT_ENV_SNAPSHOTS))
199
+ del _PROMPT_ENV_SNAPSHOTS[oldest_key]
200
+ _PROMPT_ENV_SNAPSHOTS[prompt] = copy.deepcopy(env)
201
+ return prompt
202
+
203
+
204
+ def r_environment_rollout(prompt: str, completion: str) -> float | None:
205
+ """Env-backed score for a generated tool call plus short continuation.
206
+
207
+ Returns None when the prompt was not collected from an env snapshot, allowing
208
+ callers to fall back to stateless scoring. Otherwise returns [0, 1], where 0
209
+ means invalid JSON/tool-for-state and higher values reflect the env reward.
210
+ """
211
+ snapshot = _PROMPT_ENV_SNAPSHOTS.get(prompt)
212
+ if snapshot is None:
213
  return None
214
 
215
+ data = _parse_completion(completion)
216
+ if data is None:
217
+ return 0.0
218
+
219
+ tool = data.get("tool", "")
220
+ args = data.get("arguments", {})
221
+ if not isinstance(args, dict):
222
+ return 0.0
223
+
224
+ env = copy.deepcopy(snapshot)
225
+ if tool not in env._get_available_tools():
226
+ return 0.0
227
+
228
+ try:
229
+ obs = env.step(CricketAction(tool=tool, arguments=args))
230
+ except Exception:
231
+ return 0.0
232
+
233
+ reward = float(obs.reward or 0.0)
234
+ rng = random.Random(hash(prompt + completion) & 0xFFFFFFFF)
235
+ roster = build_playing_xi(getattr(env, "_agent_roster", []))
236
+ for _ in range(_ENV_REWARD_ROLLOUT_STEPS):
237
+ if obs.done:
238
+ break
239
+ action = _random_action(
240
+ rng,
241
+ obs.game_state,
242
+ obs.available_tools,
243
+ obs.current_bowler.get("type") if obs.current_bowler else None,
244
+ roster,
245
+ )
246
+ obs = env.step(action)
247
+ reward += float(obs.reward or 0.0)
248
+
249
+ if obs.done and env.state.reward_breakdown:
250
+ reward += float(env.state.reward_breakdown.get("composite", 0.0))
251
+
252
+ # Map rollout reward into [0,1] while preserving penalties for bad tool choices.
253
+ return round(max(0.0, min(1.0, 0.5 + reward)), 4)
254
+
255
 
256
  def r_validity(completion: str) -> float:
257
+ """Schema reward for tool calling.
258
+
259
+ Exact env-executable calls receive 1.0. Malformed but parseable JSON gets a
260
+ small shaping signal so early GRPO has non-zero variance before the model has
261
+ learned the strict `{"tool": ..., "arguments": {...}}` contract.
262
+ """
263
  data = _parse_completion(completion)
264
  if data is None:
265
  return 0.0
266
+ if not isinstance(data, dict):
267
+ return 0.05
268
  tool = data.get("tool", "")
269
  args = data.get("arguments", {})
270
+ if "tool" not in data or "arguments" not in data:
271
+ return 0.15
272
  if tool not in _VALID_TOOLS:
273
+ return 0.25
274
+ if not isinstance(args, dict):
275
+ return 0.35
276
  if tool == "play_delivery" and args.get("shot_intent") not in SHOT_AGGRESSION:
277
+ return 0.5
278
  if tool == "set_strategy":
279
  agg = args.get("aggression")
280
  if not isinstance(agg, (int, float)):
281
+ return 0.5
282
  if tool == "plan_shot" and args.get("shot_intent") not in SHOT_AGGRESSION:
283
+ return 0.5
284
  if tool in {"choose_bowler", "set_bowling_strategy", "plan_delivery"}:
285
  if args.get("bowler_type") not in (None, "pace", "spin"):
286
+ return 0.5
287
  return 1.0
288
 
289
 
 
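To make the tiers concrete, a few hedged examples; the exact scores assume `play_delivery` is in `_VALID_TOOLS`, `single` is in `SHOT_AGGRESSION`, and made-up names like `teleport` or `slog` are not:

```python
r_validity("not a tool call at all")                                              # 0.0  unparseable
r_validity('{"tool": "teleport", "arguments": {}}')                               # 0.25 unknown tool
r_validity('{"tool": "play_delivery", "arguments": {"shot_intent": "slog"}}')     # 0.5  bad shot_intent
r_validity('{"tool": "play_delivery", "arguments": {"shot_intent": "single"}}')   # 1.0  env-executable
```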
446
  Returns reward_fn(prompts, completions, **kwargs) → list[float].
447
 
448
  Weights align with compute_episode_reward in reward_calculator.py:
449
+ r_env one-step env rollout reward when prompt snapshot exists
450
+ r_behavior — stateless tactical/tool coherence
451
+ r_validity JSON/tool schema validity
 
 
 
 
452
  """
453
  # Minimum reward for any structurally valid completion — ensures GRPO has a
454
  # positive gradient to reinforce valid tool use even for unscored tool types.
 
458
  rewards = []
459
  for prompt, completion in zip(prompts, completions):
460
  fmt = r_validity(completion)
461
+ env_score = r_environment_rollout(prompt, completion)
462
  if curriculum_stage == 1:
463
+ # Length-efficiency penalty: a valid JSON tool call is ≤400 chars.
464
+ # Models with thinking mode (Qwen3/3.5) generate 800-2000 char
465
+ # preambles before the JSON; penalise that verbosity so GRPO
466
+ # learns to emit short, direct JSON. The multiplier ramps linearly from
467
+ # 1.0 (no penalty) at ≤400 chars down to 0.0 at ≥2400 chars.
468
+ _JSON_TARGET = 400
469
+ _RAMP_RANGE = 2000
470
+ length_eff = max(0.0, 1.0 - max(0, len(completion) - _JSON_TARGET) / _RAMP_RANGE)
471
+ base = 0.5 * fmt + 0.5 * (env_score if env_score is not None else fmt)
472
+ rewards.append(round(length_eff * base, 4))
473
  continue
474
 
475
  behavior = r_behavior_stateless(prompt, completion)
 
480
  + _RW.behavior_adaptation * adapt
481
  + _RW.behavior_opponent_awareness * aware
482
  )
483
+ if env_score is None:
484
+ reward = _RW.training_behavior * r_beh + _RW.training_validity * fmt
485
+ else:
486
+ reward = 0.45 * env_score + 0.40 * r_beh + 0.15 * fmt
487
  # Floor: valid JSON should always beat invalid JSON (reward=0)
488
+ if fmt > 0.0 and (env_score is None or env_score > 0.0):
489
  reward = max(reward, _VALID_FLOOR)
490
  rewards.append(round(reward, 4))
491
  return rewards
 
500
 
501
  SYSTEM_PROMPT = (
502
  "You are an expert adaptive cricket captain. Each turn you receive a scorecard "
503
+ "and must choose exactly one cricket captaincy tool call.\n\n"
504
+ "EXECUTE FIRST — strict rule:\n"
505
+ " - The match only progresses when you call `play_delivery` (batting) or\n"
506
+ " `bowl_delivery` (bowling). Every other tool is overhead.\n"
507
+ " - Default action on EVERY ball: call `play_delivery` / `bowl_delivery` with\n"
508
+ " plan args INLINE: e.g. `play_delivery(shot_intent='single', risk='low', rationale='rotate')`\n"
509
+ " or `bowl_delivery(line='outside_off', length='good', delivery_type='stock')`.\n"
510
+ " - Use `set_match_plan` ONCE at the very start of an innings to declare strategy.\n"
511
+ " - Use `set_strategy` / `set_bowling_strategy` ONCE per phase boundary.\n"
512
+ " - DO NOT call `plan_shot` or `plan_delivery` (deprecated) — they only add a\n"
513
+ " wasted turn. Pass the same parameters to play_delivery / bowl_delivery directly.\n"
514
+ " - SKIP `reflect_after_ball` unless the previous ball was a wicket or boundary.\n"
515
+ " - You are scored on MATCH OUTCOMES, not on philosophical depth. Bloated\n"
516
+ " pre-ball planning truncates the episode and you forfeit the result reward.\n\n"
517
+ "THINKING BUDGET — HARD LIMIT:\n"
518
+ " - Per turn: ONE sentence of reasoning, max 30 tokens, inside <think>...</think>.\n"
519
+ " - Do NOT enumerate options, restate the scorecard, or re-derive the plan.\n"
520
+ " - Bad: '<think>This is the first ball, the field is balanced, Kohli is on strike at 0.45 aggression, I should consider...'\n"
521
+ " - Good: '<think>Powerplay, balanced field — single to rotate.</think>'\n"
522
+ " - Token budget per rollout is finite. Long thinking = match truncated = ZERO result reward.\n"
523
+ " - The plan you set at the start carries the strategy; do not re-derive it every ball.\n\n"
524
+ "Emit exactly one tool call wrapped in <tool_call>...</tool_call> XML tags. "
525
+ "Bare JSON without the wrapper is NOT recognized and will end the rollout.\n"
526
+ 'Example: <tool_call>{"name": "play_delivery", "arguments": {"shot_intent": "single", "explanation": "rotate strike"}}</tool_call>\n\n'
527
  "Available tools:\n"
528
+ " call_toss — Call heads/tails and choose bat/bowl\n"
529
+ " select_batter — Choose batter profile for the match situation\n"
530
+ " set_strategy — Declare batting intent (aggression 0–1, rationale)\n"
531
+ " plan_shot — Pre-ball batting plan\n"
532
+ " play_delivery — Choose a shot and advance the game\n"
533
+ " choose_bowler — Choose bowler profile for the situation\n"
534
+ " set_bowling_strategy — Declare bowling line/length/type/rationale\n"
535
+ " plan_delivery — Pre-ball bowling plan\n"
536
+ " set_field_setting — Aggressive/Balanced/Defensive field\n"
537
+ " bowl_delivery — Execute the delivery\n"
538
+ " reflect_after_ball — Adapt after the previous ball\n"
539
+ " analyze_situation — Query pitch/bowler/field info\n\n"
540
  "Shot intents: leave | defensive | single | rotate | boundary | six\n\n"
541
+ "PRIORITIES (in order): (1) finish the match, (2) win the match, (3) score well per ball.\n"
542
+ "Verbose reasoning forfeits all three. Decide fast, act, move on."
543
  )
544
 
545
+
546
+ def get_system_prompt(stage: int = 2) -> str:
547
+ return SYSTEM_PROMPT
548
+
549
  _RANDOM_SHOTS = list(SHOT_AGGRESSION.keys())
550
  _RANDOM_QUERIES = ["pitch_conditions", "bowler_info", "field_setting", "match_situation"]
551
  _RANDOM_ZONES = ["cover", "point", "straight", "midwicket", "square_leg", "fine_leg", "long_on", "long_off"]
552
 
553
 
554
+ def _training_roster(agent_team: str | None = None) -> list[dict]:
555
+ team = agent_team or os.environ.get("CRICKET_AGENT_TEAM")
556
+ if not team:
557
+ raise ValueError("Roster-backed training requires --agent-team or CRICKET_AGENT_TEAM.")
558
+ roster = load_team_roster(team)
559
+ if not roster:
560
+ raise ValueError(f"No player profile roster found for agent team '{team}'.")
561
+ playing_xi = build_playing_xi(roster)
562
+ if len(playing_xi) < 11:
563
+ raise ValueError(f"Player profile roster for '{team}' could not produce a playing XI.")
564
+ return playing_xi
565
+
566
+
567
+ def _sample_batter(rng: random.Random, roster: list[dict]) -> dict:
568
+ batters = [p for p in roster if p.get("role") != "bowler"] or roster
569
+ if not batters:
570
+ raise ValueError("Roster-backed training requires at least one batting-capable player.")
571
+ return rng.choice(batters)
572
+
573
+
574
+ def _sample_bowler(rng: random.Random, roster: list[dict]) -> dict:
575
+ bowlers = [p for p in roster if p.get("bowler_type")]
576
+ if not bowlers:
577
+ raise ValueError("Roster-backed training requires at least one bowling-capable player.")
578
+ return rng.choice(bowlers)
579
+
580
+
581
  def _random_action(
582
  rng: random.Random,
583
  game_state: str = "batting",
584
  available_tools: list[str] | None = None,
585
  current_bowler_type: str | None = None,
586
+ roster: list[dict] | None = None,
587
  ) -> CricketAction:
588
+ legal = set(available_tools or [])
589
+
590
+ def can(tool: str) -> bool:
591
+ return available_tools is None or tool in legal
592
+
593
+ def match_plan_action() -> CricketAction:
594
+ return CricketAction(tool="set_match_plan", arguments={
595
+ "powerplay_intent": "Use roster strengths to establish tempo while protecting wickets",
596
+ "middle_intent": "Rotate strike, attack favorable matchups, and preserve finishers",
597
+ "death_intent": "Commit boundary options with wickets and target pressure in mind",
598
+ "risk_budget": "Escalate only when phase, target, and wickets justify the risk",
599
+ "trigger_conditions": "Review after wicket clusters, phase changes, target pressure, or repeated boundary/dot outcomes",
600
+ "rationale": "Create a long-horizon plan before choosing ball-by-ball tactics",
601
+ })
602
 
603
  if game_state == "toss":
604
+ return CricketAction(
605
  tool="call_toss",
606
  arguments={"call": rng.choice(["heads", "tails"]), "decision": rng.choice(["bat", "bowl"])},
607
+ )
608
+
609
+ if can("set_match_plan") and rng.random() < 0.12:
610
+ return match_plan_action()
611
+
612
+ if can("update_match_plan") and rng.random() < 0.08:
613
+ return CricketAction(tool="update_match_plan", arguments={
614
+ "reason": "Adjust plan after phase, score pressure, wickets, and field information",
615
+ "risk_budget": "Shift risk based on current target pressure and wickets in hand",
616
+ })
617
+
618
  if game_state == "bowling":
619
  choice = rng.random()
620
+ if choice < 0.15 and can("choose_bowler"):
621
+ bowler = _sample_bowler(rng, roster or [])
622
+ return CricketAction(
623
  tool="choose_bowler",
624
  arguments={
625
+ "name": bowler["name"],
626
+ "bowler_type": bowler["bowler_type"],
627
+ "style": bowler.get("bowl_style", bowler.get("style", "stock")),
628
+ "rationale": "Match roster bowler to phase, batter matchup, and remaining overs",
629
  },
630
+ )
631
+ if choice < 0.35 and can("plan_delivery"):
632
+ return CricketAction(
633
  tool="plan_delivery",
634
  arguments={
635
  "bowler_type": current_bowler_type or rng.choice(["pace", "spin"]),
 
638
  "delivery_type": rng.choice(["stock", "yorker", "bouncer", "slower ball"]),
639
  "rationale": "Use field and batter style to control scoring zones",
640
  },
641
+ )
642
+ if choice < 0.5 and can("set_field_setting"):
643
+ return CricketAction(tool="set_field_setting", arguments={"setting": rng.choice(["Aggressive", "Balanced", "Defensive"])})
644
+ if choice < 0.6 and can("reflect_after_ball"):
645
+ return CricketAction(tool="reflect_after_ball", arguments={"reflection": "Adjust line and field after the last ball"})
646
+ if can("bowl_delivery"):
647
+ return CricketAction(tool="bowl_delivery", arguments={})
648
+ if can("set_bowling_strategy"):
649
+ return CricketAction(tool="set_bowling_strategy", arguments={
650
+ "bowler_type": current_bowler_type or "pace",
651
+ "line": "outside off",
652
+ "length": "good length",
653
+ "delivery_type": "stock",
654
+ "rationale": "Set a legal bowling plan before executing the delivery",
655
+ })
656
+ raise ValueError(f"No legal bowling action available from tools={available_tools}")
657
 
658
  choice = rng.random()
659
+ if choice < 0.15 and can("select_batter"):
660
+ batter = _sample_batter(rng, roster or [])
661
+ return CricketAction(
662
  tool="select_batter",
663
  arguments={
664
+ "name": batter["name"],
665
+ "style": batter.get("style", "balanced"),
666
+ "aggression": round(float(batter["aggression"]), 2),
667
  "rationale": "Select batter based on phase, wickets, and target pressure",
668
  },
669
+ )
670
+ if choice < 0.3 and can("set_strategy"):
671
+ return CricketAction(
672
  tool="set_strategy",
673
  arguments={
674
  "phase_intent": rng.choice(["attack", "consolidate", "rotate"]),
675
  "aggression": round(rng.uniform(0.1, 0.9), 2),
676
+ "rationale": "Align roster strengths with phase, target pressure, and wickets",
677
  },
678
+ )
679
+ if choice < 0.45 and can("plan_shot"):
680
+ return CricketAction(
681
  tool="plan_shot",
682
  arguments={
683
  "shot_intent": rng.choice(_RANDOM_SHOTS),
684
+ "target_area": rng.choice(_RANDOM_ZONES),
685
+ "trajectory": rng.choice(["ground", "lofted", "aerial"]),
686
  "risk": rng.choice(["low", "balanced", "high"]),
687
  "rationale": "Plan shot against bowler, field, and required rate",
688
  },
689
+ )
690
+ if choice < 0.55 and can("analyze_situation"):
691
+ return CricketAction(
692
  tool="analyze_situation",
693
  arguments={"query_type": rng.choice(_RANDOM_QUERIES)},
694
+ )
695
+ if choice < 0.65 and can("reflect_after_ball"):
696
+ return CricketAction(tool="reflect_after_ball", arguments={"reflection": "Revise risk after previous ball"})
697
+ if can("play_delivery"):
698
+ return CricketAction(
699
+ tool="play_delivery",
700
+ arguments={"shot_intent": rng.choice(_RANDOM_SHOTS), "explanation": "Advance the innings according to the current plan"},
701
+ )
702
+ raise ValueError(f"No legal batting action available from tools={available_tools}")
703
 
704
 
705
  def collect_prompts(
706
  n_prompts: int,
707
  task: str = "stage2_full",
708
  seed: int = 42,
709
+ agent_team: str | None = None,
710
+ opponent_mode: str = "heuristic",
711
  ) -> list[str]:
712
  """
713
  Collect game-state prompts by running episodes with random actions.
714
  Returns a list of prompt strings (one per game state observation).
715
  """
716
  rng = random.Random(seed)
717
+ roster = _training_roster(agent_team)
718
+ _PROMPT_ENV_SNAPSHOTS.clear()
719
  prompts: list[str] = []
720
  ep_count = 0
721
 
722
  while len(prompts) < n_prompts:
723
  env = CricketEnvironment()
724
+ obs = env.reset(seed=rng.randint(0, 99999), options={
725
+ "task": task,
726
+ "random_start": True,
727
+ "agent_team": agent_team or os.environ.get("CRICKET_AGENT_TEAM"),
728
+ "opponent_mode": opponent_mode,
729
+ })
730
+ prompts.append(_remember_prompt(obs.prompt_text, env))
731
  steps = 0
732
 
733
  while not obs.done and steps < 80:
 
736
  obs.game_state,
737
  obs.available_tools,
738
  obs.current_bowler.get("type") if obs.current_bowler else None,
739
+ roster,
740
  )
741
  obs = env.step(action)
742
  if not obs.done:
743
+ prompts.append(_remember_prompt(obs.prompt_text, env))
744
  steps += 1
745
 
746
  ep_count += 1
 
762
  return Dataset.from_dict({"prompt": prompts})
763
 
764
 
765
+ class CricketCaptainToolEnv:
766
+ """TRL environment wrapper exposing CricketCaptain actions as real tools."""
767
+
768
+ _stats_lock = threading.Lock()
769
+
770
+ def __init__(self):
771
+ self.env = CricketEnvironment()
772
+ self.reward = 0.0
773
+ self.done = False
774
+ self.final_reward = 0.0
775
+ self._episode_seed: int | None = None
776
+ self._episode_started = False
777
+ self._max_tool_iters: int | None = None
778
+ self._episode_had_step = False
779
+ self._episode_logged = False
780
+
781
+ def _maybe_log_episode_end(self, termination_reason: str):
782
+ # Avoid double-logging the same episode (e.g. once at termination, again on reset()).
783
+ if self._episode_logged:
784
+ return
785
+ stats_path = os.environ.get("CRICKET_EPISODE_STATS_PATH")
786
+ if not stats_path:
787
+ return
788
+
789
+ state = getattr(self.env, "state", None)
790
+
791
+ payload = {
792
+ "ts": datetime.datetime.now().isoformat(),
793
+ "seed": self._episode_seed,
794
+ "done": bool(self.done),
795
+ "termination_reason": termination_reason,
796
+ "reward_running_sum": float(self.reward),
797
+ "final_reward_bonus": float(self.final_reward),
798
+ }
799
+
800
+ if state is not None:
801
+ # ---- match config / context ----
802
+ payload["max_overs"] = getattr(state, "max_overs", None)
803
+ payload["opponent_mode"] = getattr(state, "opponent_mode", None)
804
+ payload["agent_team"] = getattr(state, "eval_pack_id", None) or getattr(state, "agent_team", None)
805
+ payload["innings_type"] = getattr(state, "innings_type", None)
806
+ payload["game_state"] = getattr(state, "game_state", None)
807
+
808
+ # ---- match outcome ----
809
+ payload["overs_played"] = getattr(state, "over", None)
810
+ payload["balls_played"] = getattr(state, "ball", None)
811
+ payload["agent_score"] = getattr(state, "total_score", None)
812
+ payload["wickets_lost"] = getattr(state, "wickets_lost", None)
813
+ payload["first_innings_score"] = getattr(state, "first_innings_score", None)
814
+ payload["target"] = getattr(state, "target", None)
815
+ payload["match_result"] = getattr(state, "match_result", None) or None
816
+
817
+ # ---- tool calls ----
818
+ tool_calls_made = int(getattr(state, "tool_calls_made", 0) or 0)
819
+ payload["tool_calls"] = tool_calls_made
820
+ tool_history = getattr(state, "tool_history", None) or []
821
+ tool_breakdown: dict[str, int] = {}
822
+ for c in tool_history:
823
+ t = c.get("tool", "unknown")
824
+ tool_breakdown[t] = tool_breakdown.get(t, 0) + 1
825
+ payload["tool_breakdown"] = tool_breakdown
826
+ payload["analyze_calls"] = len(getattr(state, "analyze_calls", []) or [])
827
+
828
+ # ---- per-turn rubric averages (mean across the full episode) ----
829
+ def _mean(xs):
830
+ xs = list(xs or [])
831
+ return round(sum(xs) / len(xs), 4) if xs else None
832
+ payload["mean_coherence"] = _mean(getattr(state, "coherence_scores", None))
833
+ payload["mean_adaptation"] = _mean(getattr(state, "adaptation_scores", None))
834
+ payload["mean_opponent_awareness"] = _mean(getattr(state, "opponent_awareness_scores", None))
835
+ payload["mean_regret"] = _mean(getattr(state, "regret_scores", None))
836
+ payload["mean_plan_commitment"] = _mean(getattr(state, "plan_commitment_scores", None))
837
+ payload["mean_plan_freshness"] = _mean(getattr(state, "plan_freshness_scores", None))
838
+ payload["strategy_changes"] = getattr(state, "strategy_changes", None)
839
+ payload["plan_version"] = getattr(state, "plan_version", None)
840
+
841
+ # ---- composite + per-rubric reward (already computed in reward_calculator) ----
842
+ if getattr(state, "reward_breakdown", None):
843
+ payload["reward_breakdown"] = dict(state.reward_breakdown)
844
+
845
+ with self._stats_lock:
846
+ with open(stats_path, "a", encoding="utf-8") as f:
847
+ f.write(json.dumps(payload, ensure_ascii=False) + "\n")
848
+ f.flush()
849
+ self._episode_logged = True
850
+
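For reference, a single episode_stats.jsonl record assembled from the keys above might look like the sketch below; the key names come from the payload code, every value is invented:

```python
example_record = {
    "ts": "2026-01-01T00:00:00", "seed": 17, "done": True, "termination_reason": "natural",
    "reward_running_sum": 0.62, "final_reward_bonus": 0.41,
    "max_overs": 5, "opponent_mode": "heuristic", "game_state": "bowling",
    "overs_played": 5, "balls_played": 0, "agent_score": 43, "wickets_lost": 3,
    "tool_calls": 176, "tool_breakdown": {"play_delivery": 30, "bowl_delivery": 30, "set_strategy": 3},
    "analyze_calls": 2, "mean_coherence": 0.71, "strategy_changes": 2, "plan_version": 2,
    "reward_breakdown": {"composite": 0.41},
}
```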
851
+ def reset(self, **kwargs) -> str:
852
+ # If the previous episode ended because the trainer hit the tool-iteration cap,
853
+ # TRL will stop calling tools and then call reset() for the next scenario.
854
+ # In that case, self.done will still be False, but tool_calls_made will be at/near the cap.
855
+ if self._episode_started and self._episode_had_step and not self._episode_logged:
856
+ prev_calls = getattr(getattr(self.env, "state", None), "tool_calls_made", None)
857
+ if self.done:
858
+ self._maybe_log_episode_end("natural")
859
+ elif self._max_tool_iters and prev_calls is not None and int(prev_calls) >= int(self._max_tool_iters):
860
+ self._maybe_log_episode_end("cap")
861
+ # Otherwise: trainer reset the env mid-episode (e.g. generation bookkeeping).
862
+ # Don't log — it would skew the termination distribution.
863
+
864
+ self.reward = 0.0
865
+ self.done = False
866
+ self.final_reward = 0.0
867
+
868
+ self._episode_seed = kwargs.get("seed")
869
+ self._episode_started = True
870
+ self._episode_had_step = False
871
+ self._episode_logged = False
872
+ self._max_tool_iters = (
873
+ int(kwargs["max_tool_calling_iterations"])
874
+ if "max_tool_calling_iterations" in kwargs and kwargs["max_tool_calling_iterations"] is not None
875
+ else (int(os.environ["CRICKET_MAX_TOOL_ITERS"]) if os.environ.get("CRICKET_MAX_TOOL_ITERS") else None)
876
+ )
877
+
878
+ obs = self.env.reset(seed=kwargs.get("seed"), options={
879
+ "task": kwargs.get("task", "stage2_full"),
880
+ "random_start": bool(kwargs.get("random_start", False)),
881
+ "max_overs": int(kwargs.get("max_overs", 5)),
882
+ "eval_pack_id": kwargs.get("eval_pack_id", "adaptive_t20_v1"),
883
+ "opponent_mode": kwargs.get("opponent_mode", "heuristic"),
884
+ "opponent_cache_path": kwargs.get("opponent_cache_path"),
885
+ "agent_team": kwargs.get("agent_team"),
886
+ })
887
+ return obs.prompt_text
888
+
889
+ def _apply(self, tool: str, arguments: dict[str, Any]) -> str:
890
+ if self.done:
891
+ raise ValueError("Match is already finished.")
892
+ self._episode_had_step = True
893
+ available = self.env._get_available_tools()
894
+ if tool not in available:
895
+ self.reward -= 0.2
896
+ raise ValueError(f"Tool '{tool}' is not available. Available tools: {available}")
897
+ obs = self.env.step(CricketAction(tool=tool, arguments=arguments))
898
+ self.done = bool(obs.done)
899
+ self.reward += float(obs.reward or 0.0)
900
+ if obs.done and self.env.state.reward_breakdown:
901
+ self.final_reward = float(self.env.state.reward_breakdown.get("composite", 0.0))
902
+ self.reward += self.final_reward
903
+ # Log at the time of termination (do not wait for reset()) so the file appears promptly.
904
+ if self.done:
905
+ self._maybe_log_episode_end("natural")
906
+ # Also log cap termination as soon as we hit it, so runs always get stats even if TRL delays reset().
907
+ elif self._max_tool_iters:
908
+ state = getattr(self.env, "state", None)
909
+ calls = getattr(state, "tool_calls_made", None) if state is not None else None
910
+ if calls is not None and int(calls) >= int(self._max_tool_iters):
911
+ self._maybe_log_episode_end("cap")
912
+ return obs.prompt_text
913
+
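A rough sketch of how one rollout turn flows through this wrapper when driven by hand (TRL normally issues these calls itself; the seed and team id are placeholders):

```python
env = CricketCaptainToolEnv()
prompt = env.reset(seed=7, task="stage2_full", max_overs=2, agent_team="IND")  # team id is a placeholder
prompt = env.call_toss(call="heads", decision="bat")
# Whether play_delivery or bowl_delivery is legal next depends on who won the toss;
# an illegal choice raises ValueError (and costs -0.2), which TRL feeds back to the model.
print(env.reward, env.done)
```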
914
+ def call_toss(self, call: str, decision: str) -> str:
915
+ """
916
+ Call the coin toss and choose whether to bat or bowl if the toss is won.
917
+
918
+ Args:
919
+ call: Coin call, either "heads" or "tails".
920
+ decision: Preferred decision, either "bat" or "bowl".
921
+
922
+ Returns:
923
+ Updated match observation after the toss.
924
+ """
925
+ return self._apply("call_toss", {"call": call, "decision": decision})
926
+
927
+ def set_match_plan(
928
+ self,
929
+ powerplay_intent: str,
930
+ middle_intent: str,
931
+ death_intent: str,
932
+ risk_budget: str,
933
+ trigger_conditions: str,
934
+ rationale: str,
935
+ ) -> str:
936
+ """
937
+ Establish the long-horizon plan for the innings.
938
+
939
+ Args:
940
+ powerplay_intent: Plan for overs in the powerplay.
941
+ middle_intent: Plan for middle overs.
942
+ death_intent: Plan for death overs.
943
+ risk_budget: How wickets, overs, and target pressure affect risk.
944
+ trigger_conditions: Match-state changes that should trigger a plan update.
945
+ rationale: Why this plan fits the roster and match situation.
946
+
947
+ Returns:
948
+ Updated match observation after setting the plan.
949
+ """
950
+ return self._apply("set_match_plan", {
951
+ "powerplay_intent": powerplay_intent,
952
+ "middle_intent": middle_intent,
953
+ "death_intent": death_intent,
954
+ "risk_budget": risk_budget,
955
+ "trigger_conditions": trigger_conditions,
956
+ "rationale": rationale,
957
+ })
958
+
959
+ def update_match_plan(self, reason: str, risk_budget: str = "", trigger_conditions: str = "") -> str:
960
+ """
961
+ Update the long-horizon plan after a meaningful match-state change.
962
+
963
+ Args:
964
+ reason: Specific reason for updating the plan.
965
+ risk_budget: Optional revised risk budget.
966
+ trigger_conditions: Optional revised trigger conditions.
967
+
968
+ Returns:
969
+ Updated match observation after revising the plan.
970
+ """
971
+ args = {"reason": reason}
972
+ if risk_budget:
973
+ args["risk_budget"] = risk_budget
974
+ if trigger_conditions:
975
+ args["trigger_conditions"] = trigger_conditions
976
+ return self._apply("update_match_plan", args)
977
+
978
+ def select_batter(self, name: str, style: str, aggression: float, rationale: str) -> str:
979
+ """
980
+ Select the next batter from the configured roster.
981
+
982
+ Args:
983
+ name: Player name from the team roster.
984
+ style: Batter style from the roster or tactical role.
985
+ aggression: Batting aggression from 0.0 to 1.0.
986
+ rationale: Why this batter fits the phase, wickets, and target.
987
+
988
+ Returns:
989
+ Updated match observation after selecting the batter.
990
+ """
991
+ return self._apply("select_batter", {
992
+ "name": name,
993
+ "style": style,
994
+ "aggression": aggression,
995
+ "rationale": rationale,
996
+ })
997
+
998
+ def set_strategy(self, phase_intent: str, aggression: float, rationale: str) -> str:
999
+ """
1000
+ Set batting strategy for the current phase.
1001
+
1002
+ Args:
1003
+ phase_intent: Tactical batting intent for this phase.
1004
+ aggression: Batting aggression from 0.0 to 1.0.
1005
+ rationale: Why the strategy fits score, wickets, target, and field.
1006
+
1007
+ Returns:
1008
+ Updated match observation after setting batting strategy.
1009
+ """
1010
+ return self._apply("set_strategy", {
1011
+ "phase_intent": phase_intent,
1012
+ "aggression": aggression,
1013
+ "rationale": rationale,
1014
+ })
1015
+
1016
+ def plan_shot(self, shot_intent: str, target_area: str, risk: str, trajectory: str, rationale: str) -> str:
1017
+ """DEPRECATED — pass these args inline to play_delivery() instead.
1018
+
1019
+ Args:
1020
+ shot_intent: leave|defensive|single|rotate|boundary|six.
1021
+ target_area: scoring area.
1022
+ risk: low|balanced|high.
1023
+ trajectory: ground|lofted|aerial.
1024
+ rationale: one-line reason.
1025
+
1026
+ Returns:
1027
+ Updated observation.
1028
+ """
1029
+ return self._apply("plan_shot", {
1030
+ "shot_intent": shot_intent,
1031
+ "target_area": target_area,
1032
+ "risk": risk,
1033
+ "trajectory": trajectory,
1034
+ "rationale": rationale,
1035
+ })
1036
+
1037
+ def play_delivery(
1038
+ self,
1039
+ shot_intent: str = "",
1040
+ target_area: str = "",
1041
+ risk: str = "",
1042
+ trajectory: str = "",
1043
+ rationale: str = "",
1044
+ ) -> str:
1045
+ """
1046
+ Execute the ball. Pass shot params inline to skip plan_shot.
1047
+
1048
+ Args:
1049
+ shot_intent: leave|defensive|single|rotate|boundary|six.
1050
+ target_area: optional scoring area.
1051
+ risk: optional low|balanced|high.
1052
+ trajectory: optional ground|lofted|aerial.
1053
+ rationale: optional one-line reason.
1054
+
1055
+ Returns:
1056
+ Updated observation after the ball outcome.
1057
+ """
1058
+ args: dict[str, Any] = {}
1059
+ if shot_intent: args["shot_intent"] = shot_intent
1060
+ if target_area: args["target_area"] = target_area
1061
+ if risk: args["risk"] = risk
1062
+ if trajectory: args["trajectory"] = trajectory
1063
+ if rationale: args["rationale"] = rationale
1064
+ return self._apply("play_delivery", args)
1065
+
1066
+ def choose_bowler(self, name: str, bowler_type: str, style: str, rationale: str) -> str:
1067
+ """
1068
+ Choose the bowler at the start of an over from the configured roster.
1069
+
1070
+ Args:
1071
+ name: Player name from the team roster.
1072
+ bowler_type: Bowler type, either pace or spin.
1073
+ style: Bowling style or role.
1074
+ rationale: Why this bowler fits phase, matchup, and remaining overs.
1075
+
1076
+ Returns:
1077
+ Updated match observation after choosing the bowler.
1078
+ """
1079
+ return self._apply("choose_bowler", {
1080
+ "name": name,
1081
+ "bowler_type": bowler_type,
1082
+ "style": style,
1083
+ "rationale": rationale,
1084
+ })
1085
+
1086
+ def set_bowling_strategy(self, bowler_type: str, line: str, length: str, delivery_type: str, rationale: str) -> str:
1087
+ """
1088
+ Set bowling strategy for the current bowler.
1089
+
1090
+ Args:
1091
+ bowler_type: Current bowler type, either pace or spin.
1092
+ line: Intended line.
1093
+ length: Intended length.
1094
+ delivery_type: Variation or stock delivery type.
1095
+ rationale: Why this plan fits batter, field, phase, and target.
1096
+
1097
+ Returns:
1098
+ Updated match observation after setting bowling strategy.
1099
+ """
1100
+ return self._apply("set_bowling_strategy", {
1101
+ "bowler_type": bowler_type,
1102
+ "line": line,
1103
+ "length": length,
1104
+ "delivery_type": delivery_type,
1105
+ "rationale": rationale,
1106
+ })
1107
+
1108
+ def plan_delivery(self, bowler_type: str, line: str, length: str, delivery_type: str, rationale: str) -> str:
1109
+ """DEPRECATED — pass these args inline to bowl_delivery() instead.
1110
+
1111
+ Args:
1112
+ bowler_type: pace|spin.
1113
+ line: line.
1114
+ length: length.
1115
+ delivery_type: variation or stock.
1116
+ rationale: one-line reason.
1117
+
1118
+ Returns:
1119
+ Updated observation.
1120
+ """
1121
+ return self._apply("plan_delivery", {
1122
+ "bowler_type": bowler_type,
1123
+ "line": line,
1124
+ "length": length,
1125
+ "delivery_type": delivery_type,
1126
+ "rationale": rationale,
1127
+ })
1128
+
1129
+ def set_field_setting(self, setting: str) -> str:
1130
+ """
1131
+ Set the field preset.
1132
+
1133
+ Args:
1134
+ setting: One of Aggressive, Balanced, or Defensive.
1135
+
1136
+ Returns:
1137
+ Updated match observation after setting the field.
1138
+ """
1139
+ return self._apply("set_field_setting", {"setting": setting})
1140
+
1141
+ def bowl_delivery(
1142
+ self,
1143
+ line: str = "",
1144
+ length: str = "",
1145
+ delivery_type: str = "",
1146
+ rationale: str = "",
1147
+ ) -> str:
1148
+ """
1149
+ Execute the delivery. Pass plan params inline to skip plan_delivery.
1150
+
1151
+ Args:
1152
+ line: optional line.
1153
+ length: optional length.
1154
+ delivery_type: optional variation or stock.
1155
+ rationale: optional one-line reason.
1156
+
1157
+ Returns:
1158
+ Updated observation after the ball outcome.
1159
+ """
1160
+ args: dict[str, Any] = {}
1161
+ if line: args["line"] = line
1162
+ if length: args["length"] = length
1163
+ if delivery_type: args["delivery_type"] = delivery_type
1164
+ if rationale: args["rationale"] = rationale
1165
+ return self._apply("bowl_delivery", args)
1166
+
1167
+ def reflect_after_ball(self, reflection: str) -> str:
1168
+ """
1169
+ Reflect after the previous ball and adapt the plan.
1170
+
1171
+ Args:
1172
+ reflection: Specific tactical lesson from the previous ball.
1173
+
1174
+ Returns:
1175
+ Updated match observation after recording reflection.
1176
+ """
1177
+ return self._apply("reflect_after_ball", {"reflection": reflection})
1178
+
1179
+ def analyze_situation(self, query_type: str) -> str:
1180
+ """
1181
+ Analyze part of the match context.
1182
+
1183
+ Args:
1184
+ query_type: One of pitch_conditions, bowler_info, field_setting, or match_situation.
1185
+
1186
+ Returns:
1187
+ Updated observation containing the analysis result.
1188
+ """
1189
+ return self._apply("analyze_situation", {"query_type": query_type})
1190
+
1191
+
1192
+ def build_agent_dataset(n_examples: int, args) -> Dataset:
1193
+ if Dataset is None:
1194
+ raise ImportError("datasets is required for training. Install with: pip install '.[train]'")
1195
+ rows = []
1196
+ rng = random.Random(args.seed)
1197
+ stage_prompt = get_system_prompt(args.stage)
1198
+ # Curriculum distribution. If --max-overs is set, use it as a fixed format.
1199
+ # Otherwise sample per-scenario from a T2-heavy distribution that tapers to T5.
1200
+ # Rationale: T2 episodes (~72 tool calls) actually COMPLETE within our token
1201
+ # budget so r_result fires; T5 episodes (~180) train the model on its
1202
+ # eval distribution. Heavy weight on short formats early so the policy
1203
+ # escapes the "planning loop" before tackling longer matches.
1204
+ overs_distribution = getattr(args, "overs_distribution", None)
1205
+ fixed_overs = args.max_overs if args.max_overs and args.max_overs > 0 else None
1206
+ if fixed_overs is None and not overs_distribution:
1207
+ # default curriculum (11 weights): ~45% T2, ~27% T3, ~18% T4, ~9% T5
1208
+ overs_distribution = [2, 2, 2, 2, 2, 3, 3, 3, 4, 4, 5]
1209
+ for idx in range(n_examples):
1210
+ scenario_overs = fixed_overs if fixed_overs is not None else rng.choice(overs_distribution)
1211
+ rows.append({
1212
+ "prompt": [
1213
+ {"role": "system", "content": stage_prompt},
1214
+ {"role": "user", "content": ""},
1215
+ ],
1216
+ "seed": rng.randint(0, 999999),
1217
+ "task": "stage1_format" if args.stage == 1 else "stage2_full",
1218
+ "random_start": False,
1219
+ "max_overs": scenario_overs,
1220
+ "eval_pack_id": args.eval_pack_id,
1221
+ "opponent_mode": args.opponent_mode,
1222
+ "opponent_cache_path": getattr(args, "opponent_cache_path", None),
1223
+ "agent_team": args.agent_team,
1224
+ "scenario_id": idx,
1225
+ })
1226
+ return Dataset.from_list(rows)
1227
+
1228
+
1229
+ def environment_reward(environments, **kwargs) -> list[float]:
1230
+ rewards = []
1231
+ # Aggregate metrics across all envs in this gradient step for WandB logging.
1232
+ agg = {
1233
+ "r_result": [], "r_cricket": [], "r_behavior": [], "r_validity": [],
1234
+ "r_coherence": [], "r_adaptation": [], "r_opponent_awareness": [], "r_regret": [],
1235
+ "composite": [], "tool_calls": [], "wickets_lost": [], "agent_score": [],
1236
+ "matches_completed": 0, "n": 0,
1237
+ }
1238
+ tool_freq: dict[str, int] = {}
1239
+ for env in environments:
1240
+ state = env.env.state
1241
+ breakdown = state.reward_breakdown or {}
1242
+ terminal = float(breakdown.get("composite", 0.0))
1243
+ plan_score = (sum(state.plan_commitment_scores) / len(state.plan_commitment_scores)) if state.plan_commitment_scores else 0.0
1244
+ validity = 1.0 - min(1.0, len([c for c in state.tool_history if c.get("tool") == "invalid_json"]) / max(state.step_count, 1))
1245
+ reward = env.reward + terminal + 0.1 * plan_score + 0.05 * validity
1246
+ # Reward clip removed: when rollouts complete naturally, the composite
1247
+ # reward easily saturates [-1, 1], causing GRPO group-std → 0 and
1248
+ # killing the gradient signal. Let GRPO standardize the advantage itself.
1249
+ rewards.append(round(reward, 4))
1250
+
1251
+ # Collect for aggregate logging.
1252
+ agg["n"] += 1
1253
+ if env.done:
1254
+ agg["matches_completed"] += 1
1255
+ for k in ("r_result", "r_cricket", "r_behavior", "r_validity",
1256
+ "r_coherence", "r_adaptation", "r_opponent_awareness",
1257
+ "r_regret", "composite"):
1258
+ v = breakdown.get(k)
1259
+ if v is not None:
1260
+ agg[k].append(float(v))
1261
+ agg["tool_calls"].append(int(getattr(state, "tool_calls_made", 0) or 0))
1262
+ agg["wickets_lost"].append(int(getattr(state, "wickets_lost", 0) or 0))
1263
+ agg["agent_score"].append(int(getattr(state, "total_score", 0) or 0))
1264
+ for c in (state.tool_history or []):
1265
+ t = c.get("tool", "unknown")
1266
+ tool_freq[t] = tool_freq.get(t, 0) + 1
1267
+
1268
+ # WandB log — only if wandb is initialised in this process.
1269
+ try:
1270
+ import wandb
1271
+ if wandb.run is not None and agg["n"] > 0:
1272
+ log_dict: dict[str, float] = {
1273
+ "rollout/n_episodes": agg["n"],
1274
+ "rollout/matches_completed": agg["matches_completed"],
1275
+ "rollout/match_completion_rate": agg["matches_completed"] / agg["n"],
1276
+ }
1277
+ for k in ("r_result", "r_cricket", "r_behavior", "r_validity",
1278
+ "r_coherence", "r_adaptation", "r_opponent_awareness",
1279
+ "r_regret", "composite"):
1280
+ if agg[k]:
1281
+ log_dict[f"reward/{k}_mean"] = sum(agg[k]) / len(agg[k])
1282
+ log_dict[f"reward/{k}_max"] = max(agg[k])
1283
+ log_dict[f"reward/{k}_min"] = min(agg[k])
1284
+ for k in ("tool_calls", "wickets_lost", "agent_score"):
1285
+ if agg[k]:
1286
+ log_dict[f"episode/{k}_mean"] = sum(agg[k]) / len(agg[k])
1287
+ log_dict[f"episode/{k}_max"] = max(agg[k])
1288
+ # Tool usage breakdown — frequency per tool name across this step.
1289
+ total_tools = sum(tool_freq.values()) or 1
1290
+ for t, n in tool_freq.items():
1291
+ log_dict[f"tools/freq_{t}"] = n / total_tools
1292
+ wandb.log(log_dict)
1293
+ except Exception:
1294
+ # Never let logging break training.
1295
+ pass
1296
+ return rewards
1297
+
1298
+
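As a worked example with invented numbers: an episode with a running in-episode reward of 0.30, a terminal composite of 0.45, mean plan-commitment 0.8 and a validity fraction of 1.0 would score 0.30 + 0.45 + 0.1 * 0.8 + 0.05 * 1.0 = 0.88.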
1299
+ def generate_sft_examples(out_path: str, n_examples: int = 240, seed: int = 42, agent_team: str | None = None):
1300
  """Stage 0 bootstrap data: valid tool JSON for every tool family."""
1301
  rng = random.Random(seed)
1302
+ roster = _training_roster(agent_team)
1303
  examples = []
1304
  for _ in range(n_examples):
1305
  game_state = rng.choice(["toss", "batting", "bowling"])
1306
+ action = _random_action(rng, game_state, roster=roster)
1307
  prompt = (
1308
  f"{SYSTEM_PROMPT}\n\n"
1309
  f"[CricketCaptain] {game_state.upper()} | Example adaptive scenario\n"
 
1330
  # Model loading (plain transformers + bitsandbytes 4-bit) #
1331
  # ------------------------------------------------------------------ #
1332
 
1333
+ def load_model(model_name: str, *, use_vllm: bool = False, resume_adapter_from: str | None = None):
1334
+ """Load base + LoRA. When use_vllm=True, base is loaded in bf16 (vLLM
1335
+ does not support 4-bit BNB inference); otherwise 4-bit NF4.
1336
+
1337
+ resume_adapter_from: optional path to a PEFT adapter directory (e.g. a previous
1338
+ checkpoint dir). If provided, loads the adapter weights instead of initializing
1339
+ a fresh LoRA. The base model is still loaded from `model_name`. The adapter's
1340
+ LoraConfig is preserved (so you can resume even if r or alpha differ between runs)."""
1341
  if not _TRAIN_IMPORTS_AVAILABLE:
1342
  raise ImportError("Training dependencies are missing. Install with: pip install '.[train]'")
1343
+ print(f"Loading {model_name} … (use_vllm={use_vllm}, dtype={'bf16' if use_vllm else '4-bit NF4'})")
 
 
 
 
 
 
1344
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
1345
  if tokenizer.pad_token is None:
1346
  tokenizer.pad_token = tokenizer.eos_token
1347
 
1348
+ try:
1349
+ import flash_attn # noqa: F401
1350
+ attn_impl = "flash_attention_2"
1351
+ except ImportError:
1352
+ attn_impl = "sdpa"
1353
+
1354
+ load_kwargs = dict(
1355
  device_map="auto",
1356
  trust_remote_code=True,
1357
  torch_dtype=torch.bfloat16,
1358
+ attn_implementation=attn_impl,
1359
  )
1360
+ if not use_vllm:
1361
+ load_kwargs["quantization_config"] = BitsAndBytesConfig(
1362
+ load_in_4bit=True,
1363
+ bnb_4bit_compute_dtype=torch.bfloat16,
1364
+ bnb_4bit_use_double_quant=True,
1365
+ bnb_4bit_quant_type="nf4",
1366
+ )
1367
+
1368
+ model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)
1369
+ if not use_vllm:
1370
+ model = prepare_model_for_kbit_training(model)
1371
+
1372
+ if resume_adapter_from:
1373
+ # Resume from a previous PEFT adapter checkpoint (e.g. warmup output).
1374
+ # PeftModel.from_pretrained reads the adapter_config.json from the dir,
1375
+ # so any r/alpha/target_modules saved with the warmup run is preserved.
1376
+ from peft import PeftModel
1377
+ adapter_path = Path(resume_adapter_from)
1378
+ if not adapter_path.exists():
1379
+ raise FileNotFoundError(f"resume_adapter_from path does not exist: {adapter_path}")
1380
+ print(f"Resuming LoRA adapter from {adapter_path}")
1381
+ model = PeftModel.from_pretrained(model, str(adapter_path), is_trainable=True)
1382
+ else:
1383
+ lora_cfg = LoraConfig(
1384
+ r=64,
1385
+ lora_alpha=128,
1386
+ lora_dropout=0.05,
1387
+ bias="none",
1388
+ task_type="CAUSAL_LM",
1389
+ target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
1390
+ )
1391
+ model = get_peft_model(model, lora_cfg)
1392
+
1393
  print(f"Loaded. Parameters: {model.num_parameters():,}")
1394
+ model.print_trainable_parameters()
1395
  return model, tokenizer
1396
 
1397
 
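Typical invocations, as a sketch (the model id is just an example here; use whatever the active config pins):

```python
# Fresh 4-bit NF4 base with a newly initialised LoRA adapter:
model, tok = load_model("Qwen/Qwen3-4B-Instruct-2507")

# bf16 base for vLLM colocate mode, resuming the warmup adapter:
model, tok = load_model(
    "Qwen/Qwen3-4B-Instruct-2507",
    use_vllm=True,
    resume_adapter_from="./checkpoints/stage1_final",
)
```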
 
1402
  def train(args):
1403
  if not _TRAIN_IMPORTS_AVAILABLE:
1404
  raise ImportError("Training dependencies are missing. Install with: pip install '.[train]'")
1405
+ if args.opponent_mode == "llm_live":
1406
+ if args.opponent_model:
1407
+ os.environ["CRICKET_OPPONENT_MODEL"] = args.opponent_model
1408
+ if args.opponent_api_base:
1409
+ os.environ["CRICKET_OPPONENT_API_BASE"] = args.opponent_api_base
1410
+ if args.opponent_api_key:
1411
+ os.environ["CRICKET_OPPONENT_API_KEY"] = args.opponent_api_key
1412
  task = "stage1_format" if args.stage == 1 else "stage2_full"
1413
+ # CRICKET_CKPT_ROOT lets a side-by-side run write checkpoints to a different
1414
+ # tree (e.g. ./checkpoints_smoke) without trampling an active production run.
1415
+ # Default unchanged: ./checkpoints/.
1416
+ ckpt_root = os.environ.get("CRICKET_CKPT_ROOT", "./checkpoints").rstrip("/")
1417
+ out_dir = f"{ckpt_root}/stage{args.stage}"
1418
+ save_dir = f"{ckpt_root}/stage{args.stage}_final"
1419
+
1420
+ ts = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
1421
+ log_dir = Path(f"./logs/run_{ts}_stage{args.stage}_{args.opponent_mode}")
1422
+ log_dir.mkdir(parents=True, exist_ok=True)
1423
+
1424
+ # Make episode termination stats available to the environment wrapper.
1425
+ # This lets us distinguish natural terminations from tool-iteration cap truncations.
1426
+ stats_path = log_dir / "episode_stats.jsonl"
1427
+ os.environ["CRICKET_EPISODE_STATS_PATH"] = str(stats_path)
1428
+ os.environ["CRICKET_MAX_TOOL_ITERS"] = str(args.max_tool_calling_iterations)
1429
+ # Create the file immediately so users can find/tail it even before the first termination.
1430
+ stats_path.touch(exist_ok=True)
1431
 
1432
  print(f"\n=== Stage {args.stage} Training ===")
1433
  print(f"Task: {task} | Prompts: {args.prompts} | Steps: {args.steps}")
1434
+ print(f"Logs: {log_dir}/ | Checkpoints: {out_dir}/")
1435
+ print(f"max_tool_calling_iterations={args.max_tool_calling_iterations} (full 5-over match needs ~180; 20-over needs ~720)")
1436
+
1437
+ (log_dir / "metadata.json").write_text(json.dumps({
1438
+ "stage": args.stage, "model": args.model, "agent_team": args.agent_team,
1439
+ "max_overs": args.max_overs, "opponent_mode": args.opponent_mode,
1440
+ "prompts": args.prompts, "steps": args.steps,
1441
+ "batch_size": args.batch_size, "grad_accum": args.grad_accum,
1442
+ "num_generations": args.num_generations,
1443
+ "max_completion_length": args.max_completion_length,
1444
+ "max_tool_calling_iterations": args.max_tool_calling_iterations,
1445
+ "logging_steps": args.logging_steps,
1446
+ "timestamp": ts,
1447
+ }, indent=2))
1448
+
1449
+ # Build scenario seeds. TRL's environment_factory performs the actual
1450
+ # multi-turn rollout and tool execution during training.
1451
+ print("\nBuilding environment scenarios …")
1452
+ dataset = build_agent_dataset(args.prompts, args)
1453
+
1454
+ # Load model — bf16 if vLLM is on (vLLM rejects 4-bit BNB) or --bf16-base, else 4-bit NF4.
1455
+ # If resume_from is set, load the LoRA adapter from that path instead of fresh init.
1456
+ bf16_base = getattr(args, "use_vllm", False) or getattr(args, "bf16_base", False)
1457
+ resume_from = getattr(args, "resume_from", None)
1458
+ model, tokenizer = load_model(args.model, use_vllm=bf16_base, resume_adapter_from=resume_from)
1459
 
1460
  # GRPO config
1461
+ #
1462
+ # Qwen3 / Qwen3.5 ship with hybrid thinking ENABLED by default. Empirically
1463
+ # (see logs/run_2026-04-25_21-08-45 completions parquet) every generation
1464
+ # spent ~1200 chars inside <think>...</think> and then emitted XML-style
1465
+ # <function><parameter> tags instead of the JSON tool call we asked for.
1466
+ # That meant 0/32 generations were parseable, _apply() never advanced the
1467
+ # match, and episodes always hit max_tool_calling_iterations before any
1468
+ # innings finished — so r_result (55% of the composite) was never earned.
1469
+ #
1470
+ chat_template_kwargs = {}
1471
+ generation_kwargs = {}
1472
+
1473
+ completion_len = max(args.max_completion_length, 2048)
1474
+ use_vllm = getattr(args, "use_vllm", False)
1475
+ vllm_kwargs = {}
1476
+ if use_vllm:
1477
+ # vllm_model_impl: None (default) → vLLM picks its native class. Use this for
1478
+ # Qwen3-* (Qwen3ForCausalLM is registered, native path with full LoRA support).
1479
+ # Set to "transformers" only for Qwen3.5-* where vLLM has no text-only class
1480
+ # registered and the native path tries to load a vision tower.
1481
+ vllm_kwargs = dict(
1482
+ use_vllm=True,
1483
+ vllm_mode="colocate",
1484
+ vllm_gpu_memory_utilization=getattr(args, "vllm_gpu_memory", 0.5),
1485
+ vllm_max_model_length=completion_len + 2048,
1486
+ )
1487
+ vllm_impl = getattr(args, "vllm_model_impl", None)
1488
+ if vllm_impl:
1489
+ vllm_kwargs["vllm_model_impl"] = vllm_impl
1490
+
1491
+ # Resolve hyperparameters from YAML/CLI with sensible fallbacks.
1492
+ lr = args.learning_rate if getattr(args, "learning_rate", None) is not None \
1493
+ else (2e-5 if args.stage == 1 else 1e-5)
1494
+ grpo_beta = getattr(args, "beta", None)
1495
+ grpo_temp = getattr(args, "temperature", None) or 0.8
1496
+ grpo_top_p = getattr(args, "top_p", None)
1497
+ grad_ckpt = getattr(args, "gradient_checkpointing", None)
1498
+ grad_ckpt_kwargs = None
1499
+ if grad_ckpt and getattr(args, "gradient_checkpointing_use_reentrant", None) is not None:
1500
+ grad_ckpt_kwargs = {"use_reentrant": bool(args.gradient_checkpointing_use_reentrant)}
1501
+
1502
+ optional_cfg = {}
1503
+ if grpo_beta is not None:
1504
+ optional_cfg["beta"] = grpo_beta
1505
+ if grpo_top_p is not None:
1506
+ optional_cfg["top_p"] = grpo_top_p
1507
+ if grad_ckpt is not None:
1508
+ optional_cfg["gradient_checkpointing"] = bool(grad_ckpt)
1509
+ if grad_ckpt_kwargs is not None:
1510
+ optional_cfg["gradient_checkpointing_kwargs"] = grad_ckpt_kwargs
1511
+ if getattr(args, "dataloader_pin_memory", None) is not None:
1512
+ optional_cfg["dataloader_pin_memory"] = bool(args.dataloader_pin_memory)
1513
+ if getattr(args, "dataloader_num_workers", None) is not None:
1514
+ optional_cfg["dataloader_num_workers"] = int(args.dataloader_num_workers)
1515
+
1516
  config = GRPOConfig(
1517
  output_dir=out_dir,
1518
+ logging_dir=str(log_dir / "tensorboard"),
1519
  num_train_epochs=1,
1520
  max_steps=args.steps,
1521
  per_device_train_batch_size=args.batch_size,
1522
  gradient_accumulation_steps=args.grad_accum,
1523
+ learning_rate=lr,
1524
  warmup_ratio=0.05,
1525
  lr_scheduler_type="cosine",
1526
+ logging_steps=args.logging_steps,
1527
+ save_steps=getattr(args, "save_steps", None) or 10,
1528
+ save_total_limit=getattr(args, "save_total_limit", None) or 20,
1529
  bf16=True,
1530
+ max_completion_length=completion_len,
 
1531
  num_generations=args.num_generations,
1532
+ max_tool_calling_iterations=args.max_tool_calling_iterations,
1533
+ temperature=grpo_temp,
1534
+ report_to=args.report_to,
1535
+ run_name=args.run_name,
1536
  log_completions=True,
1537
  seed=args.seed,
1538
+ chat_template_kwargs=chat_template_kwargs,
1539
+ generation_kwargs=generation_kwargs,
1540
+ **optional_cfg,
1541
+ **vllm_kwargs,
1542
  )
1543
 
1544
+ # TRL's add_response_schema pattern-matches tokenizer.chat_template against
1545
+ # a fixed list and raises "Unrecognized chat template" if no match. Some
1546
+ # newer Qwen3 builds (e.g. Qwen3-4B-Instruct-2507, Aug 2025) ship a
1547
+ # template that differs from TRL's stored string (the Instruct release
1548
+ # dropped the enable_thinking block) — but the tool-call format
1549
+ # (<tool_call>…</tool_call>) is identical, so the appropriate schema still
1550
+ # parses correctly. We assign it manually before GRPOTrainer init; TRL
1551
+ # checks `response_schema is None` first so this is a safe override.
1552
+ if getattr(tokenizer, "response_schema", None) is None:
1553
+ try:
1554
+ from trl.chat_template_utils import qwen3_schema, qwen3_5_schema
1555
+ m = args.model.lower()
1556
+ if "qwen3.5" in m or "qwen3_5" in m:
1557
+ tokenizer.response_schema = qwen3_5_schema
1558
+ print("Set tokenizer.response_schema = qwen3_5_schema (manual override).")
1559
+ elif "qwen3" in m:
1560
+ tokenizer.response_schema = qwen3_schema
1561
+ print("Set tokenizer.response_schema = qwen3_schema (manual override).")
1562
+ except ImportError:
1563
+ pass
1564
 
1565
  trainer = GRPOTrainer(
1566
  model=model,
1567
+ reward_funcs=environment_reward,
1568
  args=config,
1569
  train_dataset=dataset,
1570
  processing_class=tokenizer,
1571
+ environment_factory=CricketCaptainToolEnv,
1572
  )
1573
 
1574
  print(f"\nStarting training ({args.steps} steps, {len(dataset)} prompts) …")
 
1604
 
1605
  for ep in range(args.eval_episodes):
1606
  env = CricketEnvironment()
1607
+ obs = env.reset(seed=rng.randint(0, 99999), options={
1608
+ "task": "stage2_full",
1609
+ "random_start": False,
1610
+ "agent_team": args.agent_team,
1611
+ })
1612
  steps = 0
1613
 
1614
  while not obs.done and steps < 150:
 
1619
  if data:
1620
  action = CricketAction(tool=data["tool"], arguments=data.get("arguments", {}))
1621
  else:
1622
+ action = CricketAction(tool="invalid_json", arguments={})
 
 
 
 
 
1623
 
1624
  obs = env.step(action)
1625
  steps += 1
 
1650
  def train_smoke(args):
1651
  """Run short direct-environment training rollouts without loading a model."""
1652
  rng = random.Random(args.seed)
1653
+ roster = _training_roster(args.agent_team)
1654
 
1655
  # Auto-create run folder unless --output explicitly given
1656
  if args.output:
 
1691
  "eval_pack_id": args.eval_pack_id,
1692
  "opponent_mode": args.opponent_mode,
1693
  "opponent_cache_path": args.opponent_cache_path,
1694
+ "agent_team": args.agent_team,
1695
  })
1696
  prompts = [_format_prompt(obs.prompt_text)]
1697
  total_reward = 0.0
 
1709
  obs.game_state,
1710
  obs.available_tools,
1711
  obs.current_bowler.get("type") if obs.current_bowler else None,
1712
+ roster,
1713
  )
1714
  obs = env.step(action)
1715
  step_end = time.perf_counter()
 
1802
  def _apply_yaml_defaults(args, cfg: dict) -> None:
1803
  """Merge YAML config values into args, CLI args take precedence."""
1804
  captain = cfg.get("captain", {}) or {}
1805
+ opponent = cfg.get("opponent", {}) or {}
1806
  env_cfg = cfg.get("env", {}) or {}
1807
  train_cfg = cfg.get("train", {}) or {}
1808
 
 
1810
  if val is not None and getattr(args, attr, None) is None:
1811
  setattr(args, attr, val)
1812
 
1813
+ if getattr(args, "cmd", None) == "train":
1814
+ _set("model", train_cfg.get("model"))
1815
+ else:
1816
+ _set("model", captain.get("model"))
1817
  _set("api_base", captain.get("api_base"))
1818
  _set("api_key", os.environ.get(captain.get("api_key_env", "")) or None)
1819
  _set("eval_pack_id", env_cfg.get("eval_pack_id"))
1820
+ _set("opponent_mode", opponent.get("mode"))
1821
+ _set("opponent_cache_path", opponent.get("cache_path"))
1822
+ _set("opponent_model", opponent.get("model"))
1823
+ _set("opponent_api_base", opponent.get("api_base"))
1824
+ api_key_env = opponent.get("api_key_env")
1825
+ _set("opponent_api_key", os.environ.get(api_key_env, "") if api_key_env else None)
1826
  _set("max_overs", env_cfg.get("max_overs"))
1827
+ _set("agent_team", env_cfg.get("agent_team"))
1828
  _set("steps", train_cfg.get("steps"))
1829
  _set("prompts", train_cfg.get("prompts"))
1830
  _set("batch_size", train_cfg.get("batch_size"))
1831
+ _set("grad_accum", train_cfg.get("grad_accum"))
1832
  _set("stage", train_cfg.get("stage"))
1833
+ _set("num_generations", train_cfg.get("num_generations"))
1834
+ _set("max_completion_length", train_cfg.get("max_completion_length"))
1835
+ _set("max_tool_calling_iterations", train_cfg.get("max_tool_calling_iterations"))
1836
+ _set("logging_steps", train_cfg.get("logging_steps"))
1837
+ _set("report_to", train_cfg.get("report_to"))
1838
+ _set("run_name", train_cfg.get("run_name"))
1839
+ _set("learning_rate", train_cfg.get("learning_rate"))
1840
+ _set("beta", train_cfg.get("beta"))
1841
+ _set("temperature", train_cfg.get("temperature"))
1842
+ _set("top_p", train_cfg.get("top_p"))
1843
+ _set("gradient_checkpointing", train_cfg.get("gradient_checkpointing"))
1844
+ _set("gradient_checkpointing_use_reentrant", train_cfg.get("gradient_checkpointing_use_reentrant"))
1845
+ _set("dataloader_pin_memory", train_cfg.get("dataloader_pin_memory"))
1846
+ _set("dataloader_num_workers", train_cfg.get("dataloader_num_workers"))
1847
+ _set("bf16_base", train_cfg.get("bf16_base"))
1848
+ _set("save_steps", train_cfg.get("save_steps"))
1849
+ _set("save_total_limit", train_cfg.get("save_total_limit"))
1850
+ _set("resume_from", train_cfg.get("resume_from"))
1851
+ _set("overs_distribution", train_cfg.get("overs_distribution"))
1852
+ _set("use_vllm", train_cfg.get("use_vllm"))
1853
+ _set("vllm_gpu_memory", train_cfg.get("vllm_gpu_memory"))
1854
+ _set("vllm_model_impl", train_cfg.get("vllm_model_impl"))
1855
 
1856
 
1857
  def main():
 
1867
  t.add_argument("--prompts", type=int, default=None, help="Game state prompts to collect")
1868
  t.add_argument("--steps", type=int, default=None, help="GRPOTrainer max_steps")
1869
  t.add_argument("--batch-size", type=int, default=None, dest="batch_size")
1870
+ t.add_argument("--grad-accum", type=int, default=None, dest="grad_accum")
1871
+ t.add_argument("--num-generations", type=int, default=None, dest="num_generations")
1872
+ t.add_argument("--agent-team", default=None, dest="agent_team")
1873
+ t.add_argument("--opponent-mode", default=None, choices=["heuristic", "llm_live", "llm_cached", "cricsheet"], dest="opponent_mode")
1874
+ t.add_argument("--opponent-model", default=None, dest="opponent_model")
1875
+ t.add_argument("--opponent-api-base", default=None, dest="opponent_api_base")
1876
+ t.add_argument("--opponent-api-key", default=None, dest="opponent_api_key")
1877
+ t.add_argument("--max-overs", type=int, default=None, dest="max_overs")
1878
+ t.add_argument("--eval-pack-id", default=None, dest="eval_pack_id")
1879
+ t.add_argument("--opponent-cache-path", default=None, dest="opponent_cache_path")
1880
+ t.add_argument("--max-completion-length", type=int, default=None, dest="max_completion_length")
1881
+ t.add_argument("--max-tool-calling-iterations", type=int, default=None, dest="max_tool_calling_iterations")
1882
+ t.add_argument("--logging-steps", type=int, default=None, dest="logging_steps")
1883
+ t.add_argument("--report-to", default=None, dest="report_to")
1884
+ t.add_argument("--run-name", default=None, dest="run_name")
1885
  t.add_argument("--seed", type=int, default=42)
1886
+ t.add_argument("--resume-from", default=None, dest="resume_from",
1887
+ help="Path to a previous LoRA adapter dir (e.g. ./checkpoints/stage2_final). "
1888
+ "When set, the adapter is loaded on top of the base model instead of a fresh init.")
1889
+ t.add_argument("--save-steps", type=int, default=None, dest="save_steps")
1890
+ t.add_argument("--save-total-limit", type=int, default=None, dest="save_total_limit")
1891
+ t.add_argument("--learning-rate", type=float, default=None, dest="learning_rate")
1892
+ t.add_argument("--beta", type=float, default=None, dest="beta",
1893
+ help="GRPO KL coefficient. Lower = more exploration.")
1894
+ t.add_argument("--temperature", type=float, default=None, dest="temperature")
1895
+ t.add_argument("--top-p", type=float, default=None, dest="top_p")
1896
+ t.add_argument("--gradient-checkpointing", action="store_true", dest="gradient_checkpointing", default=None)
1897
+ t.add_argument("--no-gradient-checkpointing", action="store_false", dest="gradient_checkpointing")
1898
+ t.add_argument("--gradient-checkpointing-use-reentrant", action="store_true",
1899
+ dest="gradient_checkpointing_use_reentrant", default=None)
1900
+ t.add_argument("--dataloader-pin-memory", action="store_true", dest="dataloader_pin_memory", default=None)
1901
+ t.add_argument("--dataloader-num-workers", type=int, default=None, dest="dataloader_num_workers")
1902
+ t.add_argument("--use-vllm", action="store_true", dest="use_vllm", default=None,
1903
+ help="Use vLLM-backed rollouts (colocate). Forces bf16 base.")
1904
+ t.add_argument("--bf16-base", action="store_true", dest="bf16_base", default=None,
1905
+ help="Load base model in bf16 instead of 4-bit NF4. Faster matmul on H200 since 4B fits in 8GB.")
1906
+ t.add_argument("--vllm-gpu-memory", type=float, default=0.5, dest="vllm_gpu_memory",
1907
+ help="Fraction of GPU memory reserved for vLLM (colocate). Default 0.5.")
1908
+ t.add_argument("--vllm-model-impl", default=None, dest="vllm_model_impl",
1909
+ choices=["transformers", "vllm"],
1910
+ help="vLLM model backend. None (default) = native vLLM class (e.g. Qwen3ForCausalLM); "
1911
+ "'transformers' = HF transformers backend (workaround for Qwen3.5 — flaky with LoRA).")
1912
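Taken together, an illustrative invocation of this subcommand (every flag is defined above; the values are placeholders, not recommendations): `python train.py train --steps 200 --batch-size 2 --grad-accum 4 --num-generations 8 --use-vllm --report-to wandb --run-name demo_run`.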
 
1913
  # eval
1914
  e = sub.add_parser("eval", help="Evaluate a checkpoint")
1915
  e.add_argument("--config", default=None)
1916
  e.add_argument("--model", default=None)
1917
  e.add_argument("--eval-episodes", type=int, default=10, dest="eval_episodes")
1918
+ e.add_argument("--agent-team", default=None, dest="agent_team")
1919
  e.add_argument("--seed", type=int, default=0)
1920
 
1921
  # quick test (no GPU needed)
 
1930
  smoke.add_argument("--eval-pack-id", default=None, dest="eval_pack_id")
1931
  smoke.add_argument("--opponent-mode", default=None, choices=["heuristic", "llm_live", "llm_cached", "cricsheet"], dest="opponent_mode")
1932
  smoke.add_argument("--opponent-cache-path", default=None, dest="opponent_cache_path")
1933
+ smoke.add_argument("--agent-team", default=None, dest="agent_team")
1934
  smoke.add_argument("--output", default=None)
1935
  smoke.add_argument("--seed", type=int, default=42)
1936
 
1937
  sft = sub.add_parser("sft-data", help="Generate Stage 0 supervised tool-format examples")
1938
  sft.add_argument("--output", default="./data/training/tool_sft_examples.jsonl")
1939
  sft.add_argument("--examples", type=int, default=240)
1940
+ sft.add_argument("--agent-team", default=None, dest="agent_team")
1941
  sft.add_argument("--seed", type=int, default=42)
1942
 
1943
  args = parser.parse_args()
 
1955
  if getattr(args, "stage", None) is None:
1956
  args.stage = 1
1957
  if getattr(args, "model", None) is None:
1958
+ args.model = "Qwen/Qwen3.5-4B"
1959
  if getattr(args, "steps", None) is None:
1960
  args.steps = 200
1961
  if getattr(args, "prompts", None) is None:
1962
  args.prompts = 500
1963
  if getattr(args, "batch_size", None) is None:
1964
  args.batch_size = 2
1965
+ if getattr(args, "grad_accum", None) is None:
1966
+ args.grad_accum = 4
1967
  if getattr(args, "eval_pack_id", None) is None:
1968
  args.eval_pack_id = "adaptive_t20_v1"
1969
  if getattr(args, "opponent_mode", None) is None:
1970
  args.opponent_mode = "llm_live"
1971
  if getattr(args, "max_overs", None) is None:
1972
  args.max_overs = 5
1973
+ if getattr(args, "agent_team", None) is None:
1974
+ args.agent_team = os.environ.get("CRICKET_AGENT_TEAM")
1975
+ if getattr(args, "max_tool_calling_iterations", None) is None:
1976
+ args.max_tool_calling_iterations = 200
1977
+ if getattr(args, "logging_steps", None) is None:
1978
+ args.logging_steps = 1
1979
+ if getattr(args, "report_to", None) is None:
1980
+ args.report_to = "none"
1981
 
1982
  if args.cmd == "train":
1983
  train(args)
1984
  elif args.cmd == "eval":
1985
  evaluate(args)
1986
  elif args.cmd == "test":
1987
+ _smoke_test(args.agent_team, args.opponent_mode)
1988
  elif args.cmd == "train-smoke":
1989
  train_smoke(args)
1990
  elif args.cmd == "sft-data":
1991
+ generate_sft_examples(args.output, args.examples, args.seed, args.agent_team)
1992
  else:
1993
  parser.print_help()
1994
 
1995
 
1996
+ def _smoke_test(agent_team: str | None, opponent_mode: str):
1997
  """Verify reward functions work correctly."""
1998
  cases = [
1999
  (
 
2021
  ]
2022
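  # Each case above is a (prompt, completion, expected) triple: the elided
  # completions are presumably canned tool-call strings, and "expected" is a
  # short label printed verbatim below.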
  print("Reward function smoke test:\n")
2023
  for prompt, completion, expected in cases:
2024
+ fmt = r_validity(completion)
2025
+ coh = r_behavior_stateless(prompt, completion)
2026
  print(f" expected={expected:4s} | fmt={fmt:.0f} | coh={coh:.3f} | {completion[:60]}")
2027
 
2028
  print("\nPrompt collection test (5 prompts):")
2029
+ p = collect_prompts(5, task="stage1_format", seed=1, agent_team=agent_team, opponent_mode=opponent_mode)
2030
  for i, pp in enumerate(p):
2031
  print(f" [{i}] {pp[:80].strip()} …")
2032