anugrahhu committed (verified)
Commit 11307a1 · 1 Parent(s): 3080a66

vanilla GRPO: backport EvidenceCallback for live evidence/*.csv + plots
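A representative invocation of the backported pipeline (every flag below is declared in the new argparse block in this diff; the model name and paths are just the script defaults, shown for illustration):

    python -m training.training_script \
        --model_name HuggingFaceTB/SmolLM2-360M-Instruct \
        --total_episodes 200 --max_steps 18 --curriculum \
        --output_dir training/grpo-output \
        --evidence_dir evidence \
        --checkpoint_eval_steps 25 --checkpoint_eval_episodes 8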

Files changed (1): training/training_script.py +633 -437
training/training_script.py CHANGED
@@ -1,437 +1,633 @@
- """GRPO (Group-Relative Policy Optimization) training script for CERNenv.
-
- Uses Hugging Face TRL (Transformer Reinforcement Learning) ``GRPOTrainer`` to
- fine-tune a small instruction-tuned model on full episodes of the CERN
- environment. Each ``query`` is a prompt sampled from a freshly-reset env;
- the reward function rolls the model's response through the environment and
- returns the per-step + (optional) terminal reward.
-
- This script is intentionally CPU-friendly and self-contained. For
- GPU-accelerated training with LoRA, prefer ``training_unsloth.py``.
-
- Run:
-     python -m training.training_script \
-         --model_name HuggingFaceTB/SmolLM2-360M-Instruct \
-         --total_episodes 200 --max_steps 18 --output_dir training/grpo-output
- """
-
- from __future__ import annotations
-
- import argparse
- import logging
- import math
- import os
- import threading
- from dataclasses import dataclass, field
- from typing import Any, Dict, List, Optional, TYPE_CHECKING
-
- # Heavy ML deps (torch, datasets, transformers) are imported lazily inside
- # ``main`` and ``build_dataset`` so the lightweight helpers (reward
- # function, curriculum schedule, format-validity bonus) remain importable
- # in environments that only have the env's runtime dependencies (numpy,
- # pydantic, openenv-core). This keeps ``tests/`` runnable on CPU.
-
- from models import ExperimentAction
- from server.environment import CERNCollisionEnvironment
- from training.llm_agent import (
-     LLMAgentConfig,
-     build_chat,
-     parse_action,
-     safe_default_action,
- )
-
- if TYPE_CHECKING:  # pragma: no cover
-     from datasets import Dataset
-
-
- logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
- logger = logging.getLogger(__name__)
-
-
- # ── Episode reward harness ───────────────────────────────────────────────
-
-
- @dataclass
- class EpisodeContext:
-     """Per-prompt reusable env + default rollout config.
-
-     ``seed`` and ``difficulty`` here are *fallback* values used when the
-     TRL reward function does not receive per-prompt overrides via dataset
-     columns. With a curriculum-aware dataset we always pass per-prompt
-     ``seed``/``difficulty`` so the reward truly corresponds to the
-     scored prompt.
-     """
-
-     env: CERNCollisionEnvironment
-     seed: int
-     scenario: Optional[str]
-     difficulty: Optional[str]
-
-
- @dataclass
- class EpisodeStats:
-     """Per-rollout reward breakdown surfaced for component-level logging.
-
-     The hackathon FAQ (Q17, Q43, Q52) repeatedly warns: "watch individual
-     reward function columns, not just average reward". This struct gives
-     the EvidenceCallback enough information to log each component on its
-     own column so a reviewer (or you) can see *which* reward terms drove
-     the policy update at any given training step.
-     """
-
-     cumulative_reward: float = 0.0
-     terminal_reward: float = 0.0
-     step_shaping: float = 0.0  # cumulative_reward - terminal_reward
-     discovered: bool = False
-     correct_mass: bool = False
-     correct_channel: bool = False
-     correct_spin: bool = False
-     parsed_ok: bool = False
-     n_steps: int = 0
-     difficulty: Optional[str] = None
-
-
- def _stepwise_reward(
-     *,
-     completion_text: str,
-     ctx: EpisodeContext,
-     seed: Optional[int] = None,
-     difficulty: Optional[str] = None,
-     scenario: Optional[str] = None,
-     out_stats: Optional[EpisodeStats] = None,
- ) -> float:
-     """Roll the model's first response through one full episode and
-     return the cumulative reward (per-step + terminal).
-
-     The completion is interpreted as the first action only; subsequent
-     steps fall back to the safe default policy. This keeps the reward
-     bandwidth high for early-exploration training without requiring
-     multi-turn rollouts inside GRPO.
-
-     If ``out_stats`` is provided, it is populated in-place with a
-     per-rollout breakdown (terminal vs shaping reward, success flags)
-     so the caller can stream component-level metrics into the evidence
-     log instead of relying only on aggregate reward.
-     """
-     env = ctx.env
-     obs = env.reset(
-         seed=seed if seed is not None else ctx.seed,
-         scenario=scenario if scenario is not None else ctx.scenario,
-         difficulty=difficulty if difficulty is not None else ctx.difficulty,
-     )
-
-     parsed = parse_action(completion_text)
-     action = parsed or safe_default_action(obs)
-     obs = env.step(action)
-     cumulative = float(obs.reward or 0.0)
-     n_steps = 1
-
-     while not obs.done:
-         fallback = safe_default_action(obs)
-         obs = env.step(fallback)
-         cumulative += float(obs.reward or 0.0)
-         n_steps += 1
-
-     if out_stats is not None:
-         st = env.state
-         terminal = float(st.terminal_reward or 0.0)
-         out_stats.cumulative_reward = cumulative
-         out_stats.terminal_reward = terminal
-         out_stats.step_shaping = cumulative - terminal
-         out_stats.discovered = bool(st.discovered) if st.discovered is not None else False
-         out_stats.correct_mass = bool(st.correct_mass) if st.correct_mass is not None else False
-         out_stats.correct_channel = (
-             bool(st.correct_channel) if st.correct_channel is not None else False
-         )
-         out_stats.correct_spin = bool(st.correct_spin) if st.correct_spin is not None else False
-         out_stats.parsed_ok = parsed is not None
-         out_stats.n_steps = n_steps
-         out_stats.difficulty = st.difficulty
-
-     return cumulative
-
-
- # ── Reward-component accumulator (used by EvidenceCallback) ──────────────
-
-
- class RewardComponentAccumulator:
-     """Thread-safe rolling buffer of per-rollout ``EpisodeStats``.
-
-     The reward function appends to this; the EvidenceCallback drains it
-     on each ``on_log`` and writes one summary row to
-     ``evidence/reward_components.csv``. By pairing each row with the
-     matching GRPO ``state.global_step``, we can plot per-component reward
-     curves *aligned* with the loss curve.
-     """
-
-     def __init__(self) -> None:
-         self._lock = threading.Lock()
-         self._buf: List[EpisodeStats] = []
-
-     def append(self, stats: EpisodeStats) -> None:
-         with self._lock:
-             self._buf.append(stats)
-
-     def drain(self) -> List[EpisodeStats]:
-         with self._lock:
-             out, self._buf = self._buf, []
-         return out
-
-     @staticmethod
-     def summarise(stats: List[EpisodeStats]) -> Dict[str, float]:
-         if not stats:
-             return {
-                 "n": 0,
-                 "mean_cumulative": 0.0,
-                 "mean_terminal": 0.0,
-                 "mean_step_shaping": 0.0,
-                 "discovered_rate": 0.0,
-                 "mass_correct_rate": 0.0,
-                 "channel_correct_rate": 0.0,
-                 "spin_correct_rate": 0.0,
-                 "parsed_rate": 0.0,
-                 "mean_n_steps": 0.0,
-             }
-         n = len(stats)
-         return {
-             "n": n,
-             "mean_cumulative": sum(s.cumulative_reward for s in stats) / n,
-             "mean_terminal": sum(s.terminal_reward for s in stats) / n,
-             "mean_step_shaping": sum(s.step_shaping for s in stats) / n,
-             "discovered_rate": sum(1 for s in stats if s.discovered) / n,
-             "mass_correct_rate": sum(1 for s in stats if s.correct_mass) / n,
-             "channel_correct_rate": sum(1 for s in stats if s.correct_channel) / n,
-             "spin_correct_rate": sum(1 for s in stats if s.correct_spin) / n,
-             "parsed_rate": sum(1 for s in stats if s.parsed_ok) / n,
-             "mean_n_steps": sum(s.n_steps for s in stats) / n,
-         }
-
-
- FORMAT_BONUS_VALID = 0.15
- FORMAT_BONUS_INVALID = -0.20
-
-
- def _format_validity_bonus(completion_text: str) -> float:
-     """Small ± nudge for emitting a structured action.
-
-     Kept intentionally small (≪ terminal_scale) so the policy can't be
-     dominated by a "spam well-formed JSON" objective. The negative branch
-     is slightly larger than the positive branch so unparseable garbage
-     is dispreferred without crowding out the actual task reward.
-     """
-     return FORMAT_BONUS_VALID if parse_action(completion_text) is not None else FORMAT_BONUS_INVALID
-
-
- def make_reward_fn(
-     ctx: EpisodeContext,
-     accumulator: Optional[RewardComponentAccumulator] = None,
- ):
-     """Return a TRL-compatible reward function.
-
-     TRL forwards extra dataset columns (e.g. ``seed``, ``difficulty``)
-     as ``kwargs`` aligned 1-to-1 with ``prompts``/``completions``. We
-     use those here so the rollout used to score completion ``i`` matches
-     the prompt that produced it, which also unlocks curriculum training.
-
-     If ``accumulator`` is provided, every rollout's ``EpisodeStats`` is
-     appended to it so the trainer's ``on_log`` callback can flush a
-     per-component summary into the evidence CSV; that's what produces
-     the "watch individual reward function columns" view recommended in
-     the hackathon FAQ.
-     """
-
-     def reward_fn(
-         prompts: List[str],
-         completions: List[str],
-         **kwargs: Any,
-     ) -> List[float]:
-         seeds = kwargs.get("seed")
-         diffs = kwargs.get("difficulty")
-         scenarios = kwargs.get("scenario")
-         rewards: List[float] = []
-         for i, completion in enumerate(completions):
-             stats = EpisodeStats() if accumulator is not None else None
-             r = _stepwise_reward(
-                 completion_text=completion,
-                 ctx=ctx,
-                 seed=int(seeds[i]) if seeds is not None else None,
-                 difficulty=diffs[i] if diffs is not None else None,
-                 scenario=scenarios[i] if scenarios is not None else None,
-                 out_stats=stats,
-             )
-             r += _format_validity_bonus(completion)
-             rewards.append(float(r))
-             if accumulator is not None and stats is not None:
-                 accumulator.append(stats)
-         return rewards
-
-     return reward_fn
-
-
- # ── Prompt dataset ───────────────────────────────────────────────────────
-
-
- DEFAULT_CURRICULUM_SCHEDULE: List[tuple] = [
-     ("easy", 0.50),
-     ("medium", 0.30),
-     ("hard", 0.20),
- ]
-
-
- def curriculum_difficulty_for(
-     idx: int,
-     n_prompts: int,
-     schedule: Optional[List[tuple]] = None,
- ) -> str:
-     """Map an episode index to a difficulty using a deterministic ramp.
-
-     A simple "easy first → harder later" schedule (FAQ Q14, help-guide §6)
-     is enough to keep early-training success rate non-zero, which is the
-     whole point of curriculum: the policy must occasionally see positive
-     reward before RL can move probability mass toward it.
-     """
-     sched = schedule or DEFAULT_CURRICULUM_SCHEDULE
-     boundaries: List[tuple] = []
-     cumulative = 0.0
-     for diff, frac in sched:
-         cumulative += frac
-         boundaries.append((diff, cumulative * n_prompts))
-     for diff, upper in boundaries:
-         if idx < upper:
-             return diff
-     return boundaries[-1][0]
-
-
- def build_dataset(
-     *,
-     tokenizer,
-     n_prompts: int,
-     seed: int,
-     scenario: Optional[str],
-     difficulty: Optional[str],
-     curriculum: bool = False,
-     schedule: Optional[List[tuple]] = None,
- ) -> "Dataset":
-     from datasets import Dataset  # lazy: heavy import path
-
-     env = CERNCollisionEnvironment()
-     prompts: List[str] = []
-     seeds: List[int] = []
-     diffs: List[str] = []
-     for i in range(n_prompts):
-         ep_seed = seed + i
-         ep_diff = (
-             curriculum_difficulty_for(i, n_prompts, schedule)
-             if curriculum else (difficulty or "easy")
-         )
-         obs = env.reset(seed=ep_seed, scenario=scenario, difficulty=ep_diff)
-         chat = build_chat(obs)
-         prompt = tokenizer.apply_chat_template(
-             chat, add_generation_prompt=True, tokenize=False
-         )
-         prompts.append(prompt)
-         seeds.append(ep_seed)
-         diffs.append(ep_diff)
-     return Dataset.from_dict({
-         "prompt": prompts,
-         "seed": seeds,
-         "difficulty": diffs,
-     })
-
-
- # ── Main ─────────────────────────────────────────────────────────────────
-
-
- def main() -> None:  # pragma: no cover - training entrypoint
-     parser = argparse.ArgumentParser()
-     parser.add_argument("--model_name", default="HuggingFaceTB/SmolLM2-360M-Instruct")
-     parser.add_argument("--scenario", default=None)
-     parser.add_argument("--difficulty", choices=["easy", "medium", "hard"], default="easy")
-     parser.add_argument(
-         "--curriculum",
-         action="store_true",
-         help="Build the prompt set with an easy→medium→hard ramp.",
-     )
-     parser.add_argument("--total_episodes", type=int, default=200)
-     parser.add_argument("--seed", type=int, default=42)
-     parser.add_argument("--max_steps", type=int, default=18)
-     parser.add_argument("--num_generations", type=int, default=4)
-     parser.add_argument("--learning_rate", type=float, default=1e-5)
-     parser.add_argument("--max_prompt_length", type=int, default=1024)
-     parser.add_argument("--max_completion_length", type=int, default=256)
-     parser.add_argument("--output_dir", default="training/grpo-output")
-     args = parser.parse_args()
-
-     try:
-         import torch
-         from transformers import AutoModelForCausalLM, AutoTokenizer
-         from trl import GRPOConfig, GRPOTrainer
-     except ImportError as exc:  # pragma: no cover
-         raise SystemExit(
-             "TRL (Transformer Reinforcement Learning) is required: "
-             "pip install -r requirements-train.txt"
-         ) from exc
-
-     logger.info("Loading tokenizer + model: %s", args.model_name)
-     tokenizer = AutoTokenizer.from_pretrained(args.model_name)
-     if tokenizer.pad_token is None:
-         tokenizer.pad_token = tokenizer.eos_token
-     model = AutoModelForCausalLM.from_pretrained(
-         args.model_name,
-         torch_dtype=torch.float32,
-     )
-
-     logger.info(
-         "Building prompt dataset (%d prompts, curriculum=%s)",
-         args.total_episodes, args.curriculum,
-     )
-     dataset = build_dataset(
-         tokenizer=tokenizer,
-         n_prompts=args.total_episodes,
-         seed=args.seed,
-         scenario=args.scenario,
-         difficulty=args.difficulty,
-         curriculum=args.curriculum,
-     )
-
-     env = CERNCollisionEnvironment(max_steps=args.max_steps)
-     ctx = EpisodeContext(
-         env=env,
-         seed=args.seed,
-         scenario=args.scenario,
-         difficulty=args.difficulty,
-     )
-     reward_fn = make_reward_fn(ctx)
-
-     cfg = GRPOConfig(
-         output_dir=args.output_dir,
-         per_device_train_batch_size=2,
-         gradient_accumulation_steps=2,
-         num_generations=args.num_generations,
-         learning_rate=args.learning_rate,
-         max_prompt_length=args.max_prompt_length,
-         max_completion_length=args.max_completion_length,
-         logging_steps=5,
-         save_steps=50,
-         seed=args.seed,
-         bf16=False,
-         fp16=False,
-         report_to=[],
-     )
-
-     trainer = GRPOTrainer(
-         model=model,
-         processing_class=tokenizer,
-         train_dataset=dataset,
-         reward_funcs=[reward_fn],
-         args=cfg,
-     )
-     logger.info("Starting GRPO training")
-     trainer.train()
-     trainer.save_model(args.output_dir)
-     tokenizer.save_pretrained(args.output_dir)
-     logger.info("Saved model to %s", args.output_dir)
-
-
- if __name__ == "__main__":  # pragma: no cover
-     main()
+ """GRPO (Group-Relative Policy Optimization) training script for CERNenv.
+
+ Uses Hugging Face TRL (Transformer Reinforcement Learning) ``GRPOTrainer`` to
+ fine-tune a small instruction-tuned model on full episodes of the CERN
+ environment. Each ``query`` is a prompt sampled from a freshly-reset env;
+ the reward function rolls the model's response through the environment and
+ returns the per-step + (optional) terminal reward.
+
+ This script is intentionally CPU-friendly and self-contained. For
+ GPU-accelerated training with LoRA, prefer ``training_unsloth.py``.
+
+ Run:
+     python -m training.training_script \
+         --model_name HuggingFaceTB/SmolLM2-360M-Instruct \
+         --total_episodes 200 --max_steps 18 --output_dir training/grpo-output
+ """
+
+ from __future__ import annotations
+
+ import argparse
+ import logging
+ import math
+ import os
+ import threading
+ from dataclasses import dataclass, field
+ from pathlib import Path
+ from typing import Any, Dict, List, Optional, TYPE_CHECKING
+
+ # Heavy ML deps (torch, datasets, transformers) are imported lazily inside
+ # ``main`` and ``build_dataset`` so the lightweight helpers (reward
+ # function, curriculum schedule, format-validity bonus) remain importable
+ # in environments that only have the env's runtime dependencies (numpy,
+ # pydantic, openenv-core). This keeps ``tests/`` runnable on CPU.
+
+ from models import ExperimentAction
+ from server.environment import CERNCollisionEnvironment
+ from training.llm_agent import (
+     LLMAgentConfig,
+     build_chat,
+     parse_action,
+     safe_default_action,
+ )
+
+ if TYPE_CHECKING:  # pragma: no cover
+     from datasets import Dataset
+
+
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
+ logger = logging.getLogger(__name__)
+
+
+ # ── Episode reward harness ───────────────────────────────────────────────
+
+
+ @dataclass
+ class EpisodeContext:
+     """Per-prompt reusable env + default rollout config.
+
+     ``seed`` and ``difficulty`` here are *fallback* values used when the
+     TRL reward function does not receive per-prompt overrides via dataset
+     columns. With a curriculum-aware dataset we always pass per-prompt
+     ``seed``/``difficulty`` so the reward truly corresponds to the
+     scored prompt.
+     """
+
+     env: CERNCollisionEnvironment
+     seed: int
+     scenario: Optional[str]
+     difficulty: Optional[str]
+
+
+ @dataclass
+ class EpisodeStats:
+     """Per-rollout reward breakdown surfaced for component-level logging.
+
+     The hackathon FAQ (Q17, Q43, Q52) repeatedly warns: "watch individual
+     reward function columns, not just average reward". This struct gives
+     the EvidenceCallback enough information to log each component on its
+     own column so a reviewer (or you) can see *which* reward terms drove
+     the policy update at any given training step.
+     """
+
+     cumulative_reward: float = 0.0
+     terminal_reward: float = 0.0
+     step_shaping: float = 0.0  # cumulative_reward - terminal_reward
+     discovered: bool = False
+     correct_mass: bool = False
+     correct_channel: bool = False
+     correct_spin: bool = False
+     parsed_ok: bool = False
+     n_steps: int = 0
+     difficulty: Optional[str] = None
+
+
+ def _stepwise_reward(
+     *,
+     completion_text: str,
+     ctx: EpisodeContext,
+     seed: Optional[int] = None,
+     difficulty: Optional[str] = None,
+     scenario: Optional[str] = None,
+     out_stats: Optional[EpisodeStats] = None,
+ ) -> float:
+     """Roll the model's first response through one full episode and
+     return the cumulative reward (per-step + terminal).
+
+     The completion is interpreted as the first action only; subsequent
+     steps fall back to the safe default policy. This keeps the reward
+     bandwidth high for early-exploration training without requiring
+     multi-turn rollouts inside GRPO.
+
+     If ``out_stats`` is provided, it is populated in-place with a
+     per-rollout breakdown (terminal vs shaping reward, success flags)
+     so the caller can stream component-level metrics into the evidence
+     log instead of relying only on aggregate reward.
+     """
+     env = ctx.env
+     obs = env.reset(
+         seed=seed if seed is not None else ctx.seed,
+         scenario=scenario if scenario is not None else ctx.scenario,
+         difficulty=difficulty if difficulty is not None else ctx.difficulty,
+     )
+
+     parsed = parse_action(completion_text)
+     action = parsed or safe_default_action(obs)
+     obs = env.step(action)
+     cumulative = float(obs.reward or 0.0)
+     n_steps = 1
+
+     while not obs.done:
+         fallback = safe_default_action(obs)
+         obs = env.step(fallback)
+         cumulative += float(obs.reward or 0.0)
+         n_steps += 1
+
+     if out_stats is not None:
+         st = env.state
+         terminal = float(st.terminal_reward or 0.0)
+         out_stats.cumulative_reward = cumulative
+         out_stats.terminal_reward = terminal
+         out_stats.step_shaping = cumulative - terminal
+         out_stats.discovered = bool(st.discovered) if st.discovered is not None else False
+         out_stats.correct_mass = bool(st.correct_mass) if st.correct_mass is not None else False
+         out_stats.correct_channel = (
+             bool(st.correct_channel) if st.correct_channel is not None else False
+         )
+         out_stats.correct_spin = bool(st.correct_spin) if st.correct_spin is not None else False
+         out_stats.parsed_ok = parsed is not None
+         out_stats.n_steps = n_steps
+         out_stats.difficulty = st.difficulty
+
+     return cumulative
+
+
+ # ── Reward-component accumulator (used by EvidenceCallback) ──────────────
+
+
+ class RewardComponentAccumulator:
+     """Thread-safe rolling buffer of per-rollout ``EpisodeStats``.
+
+     The reward function appends to this; the EvidenceCallback drains it
+     on each ``on_log`` and writes one summary row to
+     ``evidence/reward_components.csv``. By pairing each row with the
+     matching GRPO ``state.global_step``, we can plot per-component reward
+     curves *aligned* with the loss curve.
+     """
+
+     def __init__(self) -> None:
+         self._lock = threading.Lock()
+         self._buf: List[EpisodeStats] = []
+
+     def append(self, stats: EpisodeStats) -> None:
+         with self._lock:
+             self._buf.append(stats)
+
+     def drain(self) -> List[EpisodeStats]:
+         with self._lock:
+             out, self._buf = self._buf, []
+         return out
+
+     @staticmethod
+     def summarise(stats: List[EpisodeStats]) -> Dict[str, float]:
+         if not stats:
+             return {
+                 "n": 0,
+                 "mean_cumulative": 0.0,
+                 "mean_terminal": 0.0,
+                 "mean_step_shaping": 0.0,
+                 "discovered_rate": 0.0,
+                 "mass_correct_rate": 0.0,
+                 "channel_correct_rate": 0.0,
+                 "spin_correct_rate": 0.0,
+                 "parsed_rate": 0.0,
+                 "mean_n_steps": 0.0,
+             }
+         n = len(stats)
+         return {
+             "n": n,
+             "mean_cumulative": sum(s.cumulative_reward for s in stats) / n,
+             "mean_terminal": sum(s.terminal_reward for s in stats) / n,
+             "mean_step_shaping": sum(s.step_shaping for s in stats) / n,
+             "discovered_rate": sum(1 for s in stats if s.discovered) / n,
+             "mass_correct_rate": sum(1 for s in stats if s.correct_mass) / n,
+             "channel_correct_rate": sum(1 for s in stats if s.correct_channel) / n,
+             "spin_correct_rate": sum(1 for s in stats if s.correct_spin) / n,
+             "parsed_rate": sum(1 for s in stats if s.parsed_ok) / n,
+             "mean_n_steps": sum(s.n_steps for s in stats) / n,
+         }
+
+
+ FORMAT_BONUS_VALID = 0.15
+ FORMAT_BONUS_INVALID = -0.20
+
+
+ def _format_validity_bonus(completion_text: str) -> float:
+     """Small ± nudge for emitting a structured action.
+
+     Kept intentionally small (≪ terminal_scale) so the policy can't be
+     dominated by a "spam well-formed JSON" objective. The negative branch
+     is slightly larger than the positive branch so unparseable garbage
+     is dispreferred without crowding out the actual task reward.
+     """
+     return FORMAT_BONUS_VALID if parse_action(completion_text) is not None else FORMAT_BONUS_INVALID
+
+
+ def make_reward_fn(
+     ctx: EpisodeContext,
+     accumulator: Optional[RewardComponentAccumulator] = None,
+ ):
+     """Return a TRL-compatible reward function.
+
+     TRL forwards extra dataset columns (e.g. ``seed``, ``difficulty``)
+     as ``kwargs`` aligned 1-to-1 with ``prompts``/``completions``. We
+     use those here so the rollout used to score completion ``i`` matches
+     the prompt that produced it, which also unlocks curriculum training.
+
+     If ``accumulator`` is provided, every rollout's ``EpisodeStats`` is
+     appended to it so the trainer's ``on_log`` callback can flush a
+     per-component summary into the evidence CSV; that's what produces
+     the "watch individual reward function columns" view recommended in
+     the hackathon FAQ.
+     """
+
+     def reward_fn(
+         prompts: List[str],
+         completions: List[str],
+         **kwargs: Any,
+     ) -> List[float]:
+         seeds = kwargs.get("seed")
+         diffs = kwargs.get("difficulty")
+         scenarios = kwargs.get("scenario")
+         rewards: List[float] = []
+         for i, completion in enumerate(completions):
+             stats = EpisodeStats() if accumulator is not None else None
+             r = _stepwise_reward(
+                 completion_text=completion,
+                 ctx=ctx,
+                 seed=int(seeds[i]) if seeds is not None else None,
+                 difficulty=diffs[i] if diffs is not None else None,
+                 scenario=scenarios[i] if scenarios is not None else None,
+                 out_stats=stats,
+             )
+             r += _format_validity_bonus(completion)
+             rewards.append(float(r))
+             if accumulator is not None and stats is not None:
+                 accumulator.append(stats)
+         return rewards
+
+     return reward_fn
+
+
+ # ── Prompt dataset ───────────────────────────────────────────────────────
+
+
+ DEFAULT_CURRICULUM_SCHEDULE: List[tuple] = [
+     ("easy", 0.50),
+     ("medium", 0.30),
+     ("hard", 0.20),
+ ]
+
+
+ def curriculum_difficulty_for(
+     idx: int,
+     n_prompts: int,
+     schedule: Optional[List[tuple]] = None,
+ ) -> str:
+     """Map an episode index to a difficulty using a deterministic ramp.
+
+     A simple "easy first → harder later" schedule (FAQ Q14, help-guide §6)
+     is enough to keep early-training success rate non-zero, which is the
+     whole point of curriculum: the policy must occasionally see positive
+     reward before RL can move probability mass toward it.
+     """
+     sched = schedule or DEFAULT_CURRICULUM_SCHEDULE
+     boundaries: List[tuple] = []
+     cumulative = 0.0
+     for diff, frac in sched:
+         cumulative += frac
+         boundaries.append((diff, cumulative * n_prompts))
+     for diff, upper in boundaries:
+         if idx < upper:
+             return diff
+     return boundaries[-1][0]
+
+
+ def build_dataset(
+     *,
+     tokenizer,
+     n_prompts: int,
+     seed: int,
+     scenario: Optional[str],
+     difficulty: Optional[str],
+     curriculum: bool = False,
+     schedule: Optional[List[tuple]] = None,
+ ) -> "Dataset":
+     from datasets import Dataset  # lazy: heavy import path
+
+     env = CERNCollisionEnvironment()
+     prompts: List[str] = []
+     seeds: List[int] = []
+     diffs: List[str] = []
+     for i in range(n_prompts):
+         ep_seed = seed + i
+         ep_diff = (
+             curriculum_difficulty_for(i, n_prompts, schedule)
+             if curriculum else (difficulty or "easy")
+         )
+         obs = env.reset(seed=ep_seed, scenario=scenario, difficulty=ep_diff)
+         chat = build_chat(obs)
+         prompt = tokenizer.apply_chat_template(
+             chat, add_generation_prompt=True, tokenize=False
+         )
+         prompts.append(prompt)
+         seeds.append(ep_seed)
+         diffs.append(ep_diff)
+     return Dataset.from_dict({
+         "prompt": prompts,
+         "seed": seeds,
+         "difficulty": diffs,
+     })
+
+
+ # ── Main ─────────────────────────────────────────────────────────────────
+
+
+ def main() -> None:  # pragma: no cover - training entrypoint
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--model_name", default="HuggingFaceTB/SmolLM2-360M-Instruct")
+     parser.add_argument("--scenario", default=None)
+     parser.add_argument("--difficulty", choices=["easy", "medium", "hard"], default="easy")
+     parser.add_argument(
+         "--curriculum",
+         action="store_true",
+         help="Build the prompt set with an easy→medium→hard ramp.",
+     )
+     parser.add_argument("--total_episodes", type=int, default=200)
+     parser.add_argument("--seed", type=int, default=42)
+     parser.add_argument("--max_steps", type=int, default=18)
+     parser.add_argument("--num_generations", type=int, default=4)
+     parser.add_argument("--learning_rate", type=float, default=1e-5)
+     parser.add_argument("--max_prompt_length", type=int, default=1024)
+     parser.add_argument("--max_completion_length", type=int, default=256)
+     parser.add_argument("--output_dir", default="training/grpo-output")
+     parser.add_argument(
+         "--evidence_dir",
+         default="evidence",
+         help="Directory for training_log.csv, reward_components.csv, "
+         "checkpoint_evals.csv and the corresponding *.png plots.",
+     )
+     parser.add_argument(
+         "--checkpoint_eval_steps",
+         type=int,
+         default=25,
+         help="Run a held-out eval every N GRPO updates for the progression curve.",
+     )
+     parser.add_argument(
+         "--checkpoint_eval_episodes",
+         type=int,
+         default=8,
+         help="Number of held-out episodes per mid-training eval.",
+     )
+     args = parser.parse_args()
+
+     try:
+         import torch
+         from transformers import AutoModelForCausalLM, AutoTokenizer
+         from trl import GRPOConfig, GRPOTrainer
+     except ImportError as exc:  # pragma: no cover
+         raise SystemExit(
+             "TRL (Transformer Reinforcement Learning) is required: "
+             "pip install -r requirements-train.txt"
+         ) from exc
+
+     logger.info("Loading tokenizer + model: %s", args.model_name)
+     tokenizer = AutoTokenizer.from_pretrained(args.model_name)
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token
+     model = AutoModelForCausalLM.from_pretrained(
+         args.model_name,
+         torch_dtype=torch.float32,
+     )
+
+     logger.info(
+         "Building prompt dataset (%d prompts, curriculum=%s)",
+         args.total_episodes, args.curriculum,
+     )
+     dataset = build_dataset(
+         tokenizer=tokenizer,
+         n_prompts=args.total_episodes,
+         seed=args.seed,
+         scenario=args.scenario,
+         difficulty=args.difficulty,
+         curriculum=args.curriculum,
+     )
+
+     env = CERNCollisionEnvironment(max_steps=args.max_steps)
+     ctx = EpisodeContext(
+         env=env,
+         seed=args.seed,
+         scenario=args.scenario,
+         difficulty=args.difficulty,
+     )
+
+     # ── Evidence wiring (training_log.csv / reward_components.csv /
+     # checkpoint_evals.csv + PNG plots). Mirrors training_unsloth.py so the
+     # vanilla GRPO backend hydrates the same dashboard cards. The render
+     # helpers are best-effort: matplotlib import failures are swallowed and
+     # the corresponding PNG is skipped, never crashing training.
+     import time as _time
+     from transformers import TrainerCallback
+     from training.evidence import (
+         CheckpointEvalWriter,
+         EvidencePaths,
+         RewardComponentLogWriter,
+         TrainingLogWriter,
+         render_checkpoint_progression,
+         render_reward_components,
+         render_training_curve,
+     )
+     from training.llm_agent import LLMAgentConfig
+     from training.rollouts import collect_episode
+
+     paths = EvidencePaths(root=Path(args.evidence_dir))
+     paths.ensure()
+     log_writer = TrainingLogWriter(paths.training_log_csv)
+     ckpt_writer = CheckpointEvalWriter(paths.checkpoint_evals_csv)
+     component_writer = RewardComponentLogWriter(paths.reward_components_csv)
+     component_accumulator = RewardComponentAccumulator()
+     held_out_seeds = list(range(900_000, 900_000 + args.checkpoint_eval_episodes))
+
+     reward_fn = make_reward_fn(ctx, accumulator=component_accumulator)
+
+     cfg = GRPOConfig(
+         output_dir=args.output_dir,
+         per_device_train_batch_size=2,
+         gradient_accumulation_steps=2,
+         num_generations=args.num_generations,
+         learning_rate=args.learning_rate,
+         max_prompt_length=args.max_prompt_length,
+         max_completion_length=args.max_completion_length,
+         logging_steps=5,
+         save_steps=50,
+         seed=args.seed,
+         bf16=False,
+         fp16=False,
+         report_to=[],
+     )
+
+     class EvidenceCallback(TrainerCallback):
+         """Stream training metrics + run periodic mid-training evals.
+
+         Backported from training/training_unsloth.py so the vanilla GRPO
+         path produces the same evidence/*.csv + *.png artefacts the
+         dashboard reads. Differs from the Unsloth version only in the
+         train/eval mode toggle: plain transformers uses model.eval() /
+         model.train() instead of FastLanguageModel.for_inference().
+         """
+
+         def __init__(self) -> None:
+             self._t0 = _time.time()
+             self._last_eval_step = -1
+
+         def on_log(self, _args, state, control, logs=None, **kw):
+             logs = logs or {}
+             row = {
+                 "step": state.global_step,
+                 "epoch": logs.get("epoch"),
+                 "loss": logs.get("loss"),
+                 "reward": logs.get("reward") or logs.get("rewards/mean"),
+                 "reward_std": logs.get("reward_std") or logs.get("rewards/std"),
+                 "kl": logs.get("kl"),
+                 "grad_norm": logs.get("grad_norm"),
+                 "learning_rate": logs.get("learning_rate"),
+                 "wall_time_s": round(_time.time() - self._t0, 2),
+             }
+             if any(v is not None for k, v in row.items() if k != "step"):
+                 log_writer.append(row)
+                 try:
+                     render_training_curve(paths.training_log_csv, paths.training_curve_png)
+                 except Exception as exc:  # pragma: no cover - plotting is best-effort
+                     logger.warning("training curve render failed: %s", exc)
+
+             drained = component_accumulator.drain()
+             if drained:
+                 summary = RewardComponentAccumulator.summarise(drained)
+                 summary["step"] = state.global_step
+                 component_writer.append(summary)
+                 try:
+                     render_reward_components(
+                         paths.reward_components_csv, paths.reward_components_png,
+                     )
+                 except Exception as exc:  # pragma: no cover
+                     logger.warning("reward components render failed: %s", exc)
+
+         def on_step_end(self, _args, state, control, **kw):
+             step = state.global_step
+             if step <= 0 or step == self._last_eval_step:
+                 return control
+             if step % args.checkpoint_eval_steps != 0:
+                 return control
+             self._last_eval_step = step
+             try:
+                 self._run_checkpoint_eval(step, state)
+             except Exception as exc:
+                 logger.warning("checkpoint eval failed at step %d: %s", step, exc)
+             return control
+
+         def _run_checkpoint_eval(self, step: int, state) -> None:
+             was_training = model.training
+             model.eval()
+             try:
+                 episodes = []
+                 for s in held_out_seeds:
+                     ep = self._rollout_one(seed=s)
+                     if ep is not None:
+                         episodes.append(ep)
+                 if not episodes:
+                     return
+                 rewards = [e.cumulative_reward for e in episodes]
+                 success_rate = sum(1 for e in episodes if e.discovered) / len(episodes)
+                 ckpt_writer.append(
+                     step=step,
+                     fraction_done=round(step / max(state.max_steps or step, 1), 4),
+                     episodes=len(episodes),
+                     mean_reward=round(sum(rewards) / len(rewards), 4),
+                     success_rate=round(success_rate, 4),
+                     mass_acc=round(
+                         sum(1 for e in episodes if e.correct_mass) / len(episodes), 4,
+                     ),
+                     channel_acc=round(
+                         sum(1 for e in episodes if e.correct_channel) / len(episodes), 4,
+                     ),
+                 )
+                 try:
+                     render_checkpoint_progression(
+                         paths.checkpoint_evals_csv,
+                         paths.checkpoint_progression_png,
+                     )
+                 except Exception as exc:  # pragma: no cover
+                     logger.warning("checkpoint progression render failed: %s", exc)
+                 logger.info(
+                     "[checkpoint-eval step=%d] reward=%.3f success=%.2f",
+                     step,
+                     sum(rewards) / len(rewards) if rewards else 0.0,
+                     success_rate,
+                 )
+             finally:
+                 if was_training:
+                     model.train()
+
+         def _rollout_one(self, seed: int):
+             def prompt_fn(chat):
+                 return tokenizer.apply_chat_template(
+                     chat, add_generation_prompt=True, tokenize=False,
+                 )
+
+             def generate_fn(prompt: str, _config) -> str:
+                 inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+                 with torch.no_grad():
+                     outputs = model.generate(
+                         **inputs,
+                         max_new_tokens=args.max_completion_length,
+                         do_sample=True,
+                         temperature=0.7,
+                         top_p=0.95,
+                         pad_token_id=tokenizer.pad_token_id,
+                     )
+                 gen = outputs[0][inputs["input_ids"].shape[1]:]
+                 return tokenizer.decode(gen, skip_special_tokens=True)
+
+             return collect_episode(
+                 env=env,
+                 seed=seed,
+                 scenario=args.scenario,
+                 difficulty=args.difficulty,
+                 prompt_fn=prompt_fn,
+                 generate_fn=generate_fn,
+                 config=LLMAgentConfig(),
+             )
+
+     trainer = GRPOTrainer(
+         model=model,
+         processing_class=tokenizer,
+         train_dataset=dataset,
+         reward_funcs=[reward_fn],
+         args=cfg,
+         callbacks=[EvidenceCallback()],
+     )
+     logger.info("Starting GRPO training")
+     trainer.train()
+
+     # Drain any rollouts the final on_log didn't catch so the last row of
+     # reward_components.csv reflects the end-of-training state.
+     final_drain = component_accumulator.drain()
+     if final_drain:
+         summary = RewardComponentAccumulator.summarise(final_drain)
+         summary["step"] = trainer.state.global_step
+         component_writer.append(summary)
+         try:
+             render_reward_components(
+                 paths.reward_components_csv, paths.reward_components_png,
+             )
+         except Exception as exc:  # pragma: no cover
+             logger.warning("final reward components render failed: %s", exc)
+
+     trainer.save_model(args.output_dir)
+     tokenizer.save_pretrained(args.output_dir)
+     logger.info("Saved model to %s", args.output_dir)
+     logger.info("Evidence artifacts in %s", paths.root)
+
+
+ if __name__ == "__main__":  # pragma: no cover
+     main()
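The new code imports CheckpointEvalWriter, EvidencePaths, RewardComponentLogWriter, TrainingLogWriter and the render_* helpers from training/evidence.py, but that module is not part of this commit. Below is a minimal sketch of what that interface could look like, reconstructed purely from the call sites in the diff; every class, signature, and the one render function shown are assumptions, not the project's actual code.

# training/evidence.py -- hypothetical sketch, inferred from the call sites
# in the diff above. Names and signatures mirror how the script uses them;
# the real module may differ.
from __future__ import annotations

import csv
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict


@dataclass
class EvidencePaths:
    """Canonical locations for the evidence CSVs and their PNG renders."""

    root: Path

    def ensure(self) -> None:
        self.root.mkdir(parents=True, exist_ok=True)

    @property
    def training_log_csv(self) -> Path:
        return self.root / "training_log.csv"

    @property
    def training_curve_png(self) -> Path:
        return self.root / "training_curve.png"

    # reward_components_csv/_png, checkpoint_evals_csv and
    # checkpoint_progression_png would follow the same pattern.


class _CsvAppender:
    """Append dict rows to a CSV, writing a header on first use.

    Assumes a stable column set per file, which holds for the fixed-schema
    rows the callback emits.
    """

    def __init__(self, path: Path) -> None:
        self.path = path

    def append(self, row: Dict[str, Any]) -> None:
        is_new = not self.path.exists()
        with self.path.open("a", newline="") as fh:
            writer = csv.DictWriter(fh, fieldnames=list(row))
            if is_new:
                writer.writeheader()
            writer.writerow(row)


class TrainingLogWriter(_CsvAppender):
    pass


class RewardComponentLogWriter(_CsvAppender):
    pass


class CheckpointEvalWriter(_CsvAppender):
    # The script calls this with keyword arguments rather than a dict.
    def append(self, **row: Any) -> None:
        super().append(dict(row))


def render_training_curve(csv_path: Path, png_path: Path) -> None:
    """Best-effort reward-vs-step plot; callers swallow any exception."""
    import matplotlib

    matplotlib.use("Agg")  # headless backend so CPU boxes need no display
    import matplotlib.pyplot as plt

    steps, rewards = [], []
    with csv_path.open() as fh:
        for row in csv.DictReader(fh):
            if row.get("reward"):  # skip rows where the trainer logged no reward
                steps.append(int(row["step"]))
                rewards.append(float(row["reward"]))
    fig, ax = plt.subplots()
    ax.plot(steps, rewards)
    ax.set_xlabel("global step")
    ax.set_ylabel("mean reward")
    fig.savefig(png_path)
    plt.close(fig)

Keeping the matplotlib import inside the render function matches the diff's own comment that "matplotlib import failures are swallowed": a missing plotting stack degrades to the logged warning in EvidenceCallback rather than crashing training.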