Pratyush-01 committed
Commit b4bd6d8 · verified · 1 Parent(s): b788dab

cleanup: strip verbose comments from physix/training/loop.py

Files changed (1)
  1. physix/training/loop.py +18 -255
physix/training/loop.py CHANGED
@@ -1,21 +1,6 @@
-"""GRPO training loop using Unsloth + TRL + W&B.
+"""GRPO training loop using Unsloth + TRL.
 
-Requires the ``[train]`` optional dependency group. Importing this module on
-a machine without the heavy ML deps installed will fail at module load,
-which is the documented contract — local development tools (env server,
-verifier, demo UI) live in lighter modules and remain usable.
-
-Run via::
-
-    python -m physix.training.loop \
-        --model Qwen/Qwen2.5-1.5B-Instruct \
-        --output-dir runs/physix-1.5b-rl \
-        --num-steps 300
-
-Environment variables:
-
-- ``WANDB_PROJECT`` (default ``physix-live``)
-- ``HUGGINGFACE_HUB_TOKEN`` if pushing the adapter to the Hub
+Requires the ``[train]`` optional dependency group (heavy ML deps).
 """
 
 from __future__ import annotations
@@ -40,17 +25,8 @@ from physix.training.dataset import (
 from physix.training.reward_fns import make_reward_funcs
 from physix.training.scorer import Scorer
 
-# IMPORTANT: Unsloth's GRPO patches must be applied *before* importing
-# ``GRPOTrainer`` so its kernels are swapped in. Without this, the trainer
-# falls back to the stock TRL path and Unsloth's optimisations are bypassed
-# (and on recent versions the import will hard-fail). Keep this block
-# directly above the ``trl`` import — order matters.
-#
-# Version note: this requires ``trl<=0.24.0``. Newer TRL versions ship
-# ``trl.experimental.openenv`` which Unsloth's ``patch_trl_openenv``
-# hook tries to ``inspect.getsource()`` on; that fails with ``OSError:
-# could not get source code`` and crashes ``PatchFastRL``. ``trl==0.24.0``
-# is the pinned upper bound declared in unsloth's pyproject.toml.
+# Unsloth patches must be applied before importing GRPOTrainer — order matters.
+# Requires trl<=0.24.0; newer versions break PatchFastRL.
 from unsloth import FastLanguageModel, PatchFastRL  # noqa: E402
 
 PatchFastRL("GRPO", FastLanguageModel)
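The ordering contract the stripped comment documented is easy to get wrong when refactoring imports, and the ``trl`` side of the sequence sits outside this hunk. A minimal sketch of the full required order (the ``trl`` import line is an assumption, not shown in the diff)::

    from unsloth import FastLanguageModel, PatchFastRL  # noqa: E402

    PatchFastRL("GRPO", FastLanguageModel)  # swap Unsloth kernels into TRL's GRPO path

    # Only now is it safe to import the trainer; with trl>0.24.0 PatchFastRL
    # itself crashes (inspect.getsource on trl.experimental.openenv fails).
    from trl import GRPOConfig, GRPOTrainer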
@@ -71,19 +47,9 @@ class TrainingConfig(BaseModel):
     model_config = ConfigDict(frozen=True)
 
     model_name: str = "Qwen/Qwen2.5-1.5B-Instruct"
-    #: Optional path to a LoRA adapter produced by the SFT warm-start step.
-    #: When set, the base model is loaded and the adapter weights are applied
-    #: before GRPO begins. Without this the cold base model rarely produces
-    #: any reward signal in early steps.
+    #: Path to merged SFT model to warm-start GRPO from.
     sft_checkpoint: Optional[str] = None
-    #: Optional Hub repo id (or local path) of an existing LoRA adapter to
-    #: warm-start GRPO from — e.g. a previous GRPO run that was interrupted
-    #: and pushed checkpoints to ``hub_checkpoint_repo_id``. When set, the
-    #: base ``model_name`` is loaded and this adapter is applied as the
-    #: starting trainable LoRA (skipping the fresh ``get_peft_model`` call).
-    #: SFT is unnecessary in this case (the adapter is already downstream
-    #: of an SFT warm-start), so leave ``sft_checkpoint`` unset when using
-    #: this flag.
+    #: Hub repo id or local path of an existing LoRA adapter to resume from.
     lora_adapter_repo: Optional[str] = None
     output_dir: str = "runs/physix-1.5b-rl"
     max_seq_length: int = 2048
@@ -97,31 +63,19 @@ class TrainingConfig(BaseModel):
     per_device_train_batch_size: int = 1
     gradient_accumulation_steps: int = 8
     num_steps: int = 300
-    #: Stop early if ``reward_std`` stays below 0.05 for this many consecutive
-    #: logged steps. Set to 0 to disable early stopping.
+    #: Set to 0 to disable early stopping.
     early_stop_patience: int = 50
     seed: int = 0
     instances_per_system: int = 32
-    #: Subset of system IDs to train on. Defaults to all SUPPORTED_SYSTEMS.
-    #: Pass a single ID (e.g. ``("damped_spring",)``) for focused single-task runs.
     system_ids: tuple[str, ...] = SUPPORTED_SYSTEMS
     ablation: Optional[Ablation] = None
     wandb_project: str = "physix-live"
     wandb_run_name: Optional[str] = None
     push_to_hub: bool = False
     hub_repo_id: Optional[str] = None
-    #: HF repo to push LoRA checkpoints to every save_steps during GRPO.
-    #: Separate from hub_repo_id (which receives the final merged model).
-    #: Set this to enable mid-run checkpoint persistence and W&B artifact logging.
+    #: HF repo to push LoRA checkpoints to every save_steps.
     hub_checkpoint_repo_id: Optional[str] = None
-    #: Path to a Trainer checkpoint dir to resume GRPO from (e.g. from a
-    #: previous run killed mid-training). Set automatically by train.sh.
     resume_from_checkpoint: Optional[str] = None
-    #: How to persist the final adapter. ``"lora"`` saves only the adapter
-    #: weights (small, requires the base model at load time). ``"merged_16bit"``
-    #: merges the adapter into the base and saves a deployable bf16/fp16
-    #: checkpoint (large, but loadable as a normal HF model — what you want
-    #: for Hub pushes and Ollama exports).
     save_method: SaveMethod = "merged_16bit"
 
 
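The two resume-related fields whose comments were stripped are meant to be mutually exclusive with ``sft_checkpoint``. A hypothetical resume configuration using only fields defined above (repo ids are placeholders)::

    config = TrainingConfig(
        model_name="Qwen/Qwen2.5-1.5B-Instruct",
        # Resume: apply an existing adapter; leave sft_checkpoint unset,
        # since that adapter is already downstream of an SFT warm-start.
        lora_adapter_repo="your-org/physix-grpo-ckpt",
        hub_checkpoint_repo_id="your-org/physix-grpo-ckpt",
        early_stop_patience=0,  # 0 disables early stopping
    )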
@@ -140,8 +94,6 @@ def train(config: TrainingConfig) -> None:
         resume="allow",
     )
 
-    # Pin a few high-signal pointers into the run summary right away so the
-    # W&B "Overview" tab shows them prominently (no scrolling, no hunting).
     if config.hub_checkpoint_repo_id:
         ckpt_url = f"https://huggingface.co/{config.hub_checkpoint_repo_id}"
         wandb.run.summary["checkpoint/repo"] = config.hub_checkpoint_repo_id
@@ -156,8 +108,6 @@ def train(config: TrainingConfig) -> None:
         wandb.run.summary["resume/from_url"] = (
             f"https://huggingface.co/{config.lora_adapter_repo}"
         )
-        # If a parent W&B run is named (set by the orchestrator script),
-        # surface it prominently so the lineage is one click away.
         parent_run = os.environ.get("WANDB_RESUMED_FROM")
         if parent_run:
             wandb.run.summary["resume/parent_wandb_run"] = parent_run
@@ -217,21 +167,7 @@ def train(config: TrainingConfig) -> None:
 
 
 def _log_reward_summary(trainer: "GRPOTrainer") -> None:
-    """Emit a final reward-signal summary at end of training.
-
-    Pulls the last ``log_history`` entry that contains reward keys and prints
-    the mean of every ``rewards/*/mean`` it finds. If *no* reward keys are
-    present we hard-fail — that means the reward functions never produced a
-    non-NaN value, which is a real bug worth surfacing.
-
-    Note on ``train/loss``: this scalar IS the GRPO surrogate objective
-    (advantage-weighted token log-probabilities, plus the KL-to-ref penalty
-    when ``beta > 0``). Per the TRL docs (``trl/docs/source/grpo_trainer.md``)
-    the ``Trainer`` superclass logs the full surrogate as ``loss``, not just
-    the KL term. So ``train/loss`` collapsing without ``train/reward`` rising
-    is a real failure mode — typically a sign of reward hacking or saturated
-    advantages — and should be debugged, not dismissed.
-    """
+    """Log first→last reward delta for every component. Raises if no rewards were logged."""
     history = getattr(trainer.state, "log_history", []) or []
     reward_entries = [
         entry for entry in history
@@ -257,16 +193,6 @@ def _log_reward_summary(trainer: "GRPOTrainer") -> None:
         v1 = last.get(key)
         if isinstance(v0, (int, float)) and isinstance(v1, (int, float)):
             _log.info("  %-40s %.4f → %.4f (Δ=%+.4f)", key, v0, v1, v1 - v0)
-    _log.info("-" * 60)
-    _log.info("Interpretation guide:")
-    _log.info("  train/loss — full GRPO surrogate (policy + KL*beta).")
-    _log.info("    Should DECREASE as advantages get exploited.")
-    _log.info("  train/reward — mean episode reward across rollouts.")
-    _log.info("    Should INCREASE; this is the headline curve.")
-    _log.info("  train/kl — KL(policy || ref). Should grow slowly.")
-    _log.info("  rewards/*/mean — per-component reward (match, simplicity, …).")
-    _log.info("Loss-down WITHOUT reward-up is a red flag (reward hacking or")
-    _log.info("advantage saturation).")
     _log.info("=" * 60)
 
 
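The deleted interpretation guide boils down to one check: the loss (the GRPO surrogate) falling while the reward stays flat is a failure signal, not progress. A post-hoc sketch over a finished trainer's ``log_history`` (variable names hypothetical)::

    history = trainer.state.log_history
    loss = [(e["step"], e["loss"]) for e in history if "loss" in e]
    reward = [(e["step"], e["reward"]) for e in history if "reward" in e]
    if loss and reward and loss[-1][1] < loss[0][1] and reward[-1][1] <= reward[0][1]:
        # Loss down without reward up: suspect reward hacking or saturated
        # advantages; inspect the rewards/*/mean series per component.
        print("reward signal did not improve: debug before trusting this run")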
@@ -274,32 +200,7 @@ def _render_training_curves(
     trainer: "GRPOTrainer",
     config: TrainingConfig,
 ) -> None:
-    """Render the headline training curves to PNG and ship them.
-
-    Why we do this in-process at end of training (instead of pulling from
-    W&B post-hoc):
-
-    1. The competition's automated validation requires PNG plots committed
-       to the public repo at submission time. Wandb-only links don't count.
-    2. ``trainer.state.log_history`` already contains every metric the
-       Trainer logged step-by-step — no API roundtrip needed.
-    3. We can also push the PNGs to the model Hub repo so they're discoverable
-       from the model card without a separate deploy step.
-
-    Renders three curves:
-
-    - ``loss.png`` — ``train/loss`` over global step.
-      GRPO surrogate; SHOULD trend down.
-    - ``reward.png`` — ``reward`` (or ``train/reward``) over step
-      with ±1σ band. SHOULD trend up.
-    - ``reward_components.png`` — overlay of every ``rewards/<name>/mean``
-      so reward hacking shows up visually (e.g. ``simplicity`` rising while
-      ``match`` regresses).
-
-    Failures are logged and swallowed — a missing plot must not crash a
-    successful training run, since the model artefact is still useful.
-    """
+    """Render loss/reward/component PNGs from log_history and push to Hub."""
     try:
         import matplotlib
         matplotlib.use("Agg")  # headless / no display server in HF Jobs
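The rendering path the removed docstring motivates needs no W&B roundtrip; everything comes from ``trainer.state.log_history``. A standalone sketch of the ``loss.png`` case (figure size as in the code below; output path illustrative)::

    from pathlib import Path

    import matplotlib
    matplotlib.use("Agg")  # choose the headless backend before pyplot loads
    import matplotlib.pyplot as plt

    history = trainer.state.log_history  # assumes a finished GRPOTrainer
    steps = [e["step"] for e in history if "loss" in e]
    losses = [e["loss"] for e in history if "loss" in e]

    Path("plots").mkdir(exist_ok=True)
    fig, ax = plt.subplots(figsize=(8, 4.5))
    ax.plot(steps, losses)
    ax.set_xlabel("global step")
    ax.set_ylabel("train/loss (GRPO surrogate)")
    fig.savefig("plots/loss.png", dpi=150)
    plt.close(fig)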
@@ -329,7 +230,6 @@
 
 
     rendered: list[Path] = []
-    # 1) Loss — the GRPO surrogate.
     steps_l, losses = _series("loss")
     if steps_l:
         fig, ax = plt.subplots(figsize=(8, 4.5))
@@ -346,7 +246,6 @@
     else:
         _log.warning("No 'loss' entries in log_history.")
 
-    # 2) Reward — headline curve (with ±std band when available).
     steps_r, rewards = _series("reward")
     _, reward_std = _series("reward_std")
     if steps_r:
@@ -371,7 +270,6 @@
     else:
         _log.warning("No 'reward' entries in log_history.")
 
-    # 3) Per-component reward overlay — exposes reward hacking patterns.
     component_keys = sorted({
         k for entry in history for k in entry
         if k.startswith("rewards/") and k.endswith("/mean")
@@ -400,8 +298,6 @@
 
     _log.info("Rendered %d curve PNG(s) to %s", len(rendered), plots_dir)
 
-    # Log the PNGs as wandb.Images so they appear in the run's Media tab,
-    # and persist to the run summary as a reference table.
     try:
         import wandb
         if wandb.run is not None:
@@ -412,8 +308,6 @@
     except Exception as exc:  # noqa: BLE001
         _log.warning("Could not log plots to wandb: %s", exc)
 
-    # Push PNGs to the final Hub model repo under ``plots/`` so the model
-    # card can render them and ``sync-plots.sh`` can pull them locally.
     if config.push_to_hub and config.hub_repo_id:
         try:
             from huggingface_hub import HfApi, create_repo
@@ -445,22 +339,8 @@
 def _load_model_and_tokenizer(
     config: TrainingConfig,
 ) -> tuple[FastLanguageModel, AutoTokenizer]:
-    """Load Qwen via Unsloth in 4-bit and attach a LoRA adapter.
-
-    If ``config.sft_checkpoint`` is set, the SFT adapter weights are merged
-    on top of the base model before GRPO starts. This gives GRPO a warm base
-    policy that already knows the JSON format and equation grammar, so early
-    rollouts produce meaningful reward signal instead of all scoring zero.
-    """
+    """Load model via Unsloth in 4-bit and attach a LoRA adapter."""
     if config.lora_adapter_repo:
-        # Resume path: load the base model and attach the existing LoRA
-        # adapter via PEFT. We deliberately do NOT call
-        # ``FastLanguageModel.from_pretrained(model_name=adapter_repo)``
-        # because the adapter's ``adapter_config.json`` may carry a stale
-        # ``base_model_name_or_path`` pointing at a path that only existed
-        # inside the previous training container (e.g. ``/tmp/physix-sft/merged``).
-        # PEFT's ``load_adapter`` ignores that field — it adapts onto whatever
-        # base we hand it.
         _log.info(
             "Resuming from existing LoRA adapter %s on top of %s",
             config.lora_adapter_repo,
@@ -472,12 +352,6 @@
             load_in_4bit=True,
             dtype=None,
         )
-        # Wrap the base in a fresh trainable LoRA, then overwrite its
-        # weights with the saved adapter. We use the adapter's own r/alpha
-        # by relying on PEFT's ``load_adapter`` resolving from the repo's
-        # adapter_config.json. The dummy ``get_peft_model`` call is just to
-        # turn the model into a ``PeftModel`` instance whose ``load_adapter``
-        # method accepts a hub repo id.
         model = FastLanguageModel.get_peft_model(
             model,
             r=config.lora_r,
@@ -490,8 +364,6 @@
             use_gradient_checkpointing="unsloth",
             random_state=config.seed,
         )
-        # Overwrite the freshly-initialised LoRA weights with the saved ones.
-        # ``adapter_name='default'`` matches what ``get_peft_model`` creates.
         model.load_adapter(
             config.lora_adapter_repo,
             adapter_name="default",
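The resume path the deleted comments walk through (base loaded by name, fresh LoRA wrapper, saved weights loaded over it) reads more clearly in isolation. A sketch under the same assumptions, with the repo id and LoRA hyperparameters as placeholders::

    from unsloth import FastLanguageModel

    # Load the base by name, NOT via the adapter repo: the adapter's
    # adapter_config.json may carry a stale base_model_name_or_path.
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name="Qwen/Qwen2.5-1.5B-Instruct",
        load_in_4bit=True,
    )
    # get_peft_model makes the base a PeftModel so load_adapter exists.
    model = FastLanguageModel.get_peft_model(
        model,
        r=16,  # placeholder; the real run passes config.lora_r
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    )
    # Overwrite the fresh "default" adapter with the saved weights;
    # load_adapter ignores the stale base path in the repo's config.
    model.load_adapter("your-org/physix-grpo-ckpt", adapter_name="default")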
@@ -571,28 +443,7 @@
 
 
 def _select_reward_funcs(ablation: Optional[Ablation]) -> list[object]:
-    """Return the GRPO reward function set.
-
-    Default set (5 functions, summed by GRPOTrainer into the advantage):
-
-    - ``reward_match`` — raw R² (linear).
-    - ``reward_match_dense`` — sqrt(R²); dense low-value gradient.
-    - ``reward_correctness`` — binary cliff at R² ≥ 0.70.
-    - ``reward_simplicity`` — gated on R² ≥ 0.10 (anti-hack).
-    - ``reward_format`` — 1.0 only if parsed AND simulated.
-
-    Why this composition: empirically (RCA from W&B run 5kuqns9x) the
-    previous ``{match, progress, simplicity, format}`` mix had a
-    progress-equals-match duplicate (single-turn ``previous_r_match=0``)
-    AND let the model farm format+simplicity by emitting trivial
-    parseable equations. The new set both removes the duplicate and
-    triple-weights correctness via three different correctness-shaped
-    signals (match, match_dense, correctness_bonus) so that physical
-    accuracy dominates the GRPO advantage.
-
-    Ablations strip one signal at a time (used by the experiment matrix,
-    not by the main runs).
-    """
+    """Return the active reward function list, optionally with one signal ablated."""
     scorer = Scorer()
     funcs = make_reward_funcs(scorer)
     full = [
@@ -609,10 +460,7 @@
     if ablation == "no_format":
         return [funcs["match"], funcs["match_dense"], funcs["correctness"], funcs["simplicity"]]
     if ablation == "no_progress":
-        # Backward-compat alias: ``progress`` no longer exists, the new
-        # reward set already excludes it. Treat ``no_progress`` as the
-        # full default set so old job configs still work without surprise.
-        return full
+        return full  # progress was removed; treat as full set for backward compat
     raise ValueError(
         f"Unknown ablation {ablation!r}. Choose from "
         "no_progress | no_simplicity | no_format | None."
@@ -620,17 +468,7 @@
 
 
 class _RewardConvergenceCallback(TrainerCallback):
-    """Stop training early when the GRPO reward has converged.
-
-    Convergence criterion: ``reward_std`` (std of the total reward across the
-    rollout batch) stays below ``min_std`` for ``patience`` consecutive
-    logged steps. When ``reward_std ≈ 0`` every generation scores the
-    same, so the GRPO advantage estimates are all zero and the policy
-    gradient vanishes — continuing burns compute without learning.
-
-    The callback also logs the early-stop event to W&B so the decision
-    is visible on the run page.
-    """
+    """Stop early when reward_std stays below min_std for `patience` consecutive steps."""
 
     def __init__(self, patience: int = 50, min_std: float = 0.05) -> None:
         self._patience = patience
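The convergence criterion from the stripped docstring maps one-to-one onto the ``transformers`` callback API. A minimal sketch (the real class also logs the stop event to W&B)::

    from transformers import TrainerCallback

    class RewardConvergenceSketch(TrainerCallback):
        """Stop once reward_std < min_std for `patience` consecutive logs."""

        def __init__(self, patience: int = 50, min_std: float = 0.05) -> None:
            self._patience = patience
            self._min_std = min_std
            self._streak = 0  # consecutive low-variance log events

        def on_log(self, args, state, control, logs=None, **kwargs):
            std = (logs or {}).get("reward_std")
            if std is None:
                return
            # reward_std ≈ 0 means every rollout scores the same, so all
            # GRPO advantages are zero and the policy gradient vanishes.
            self._streak = self._streak + 1 if std < self._min_std else 0
            if self._patience and self._streak >= self._patience:
                control.should_training_stop = True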
@@ -679,37 +517,12 @@
 
 
 class _WandbCheckpointCallback(TrainerCallback):
-    """Make checkpoints first-class in W&B.
-
-    After every Trainer save, this callback:
-
-    1. Resolves the latest commit hash on the Hub repo (best-effort — the
-       trainer's own ``PushToHubCallback`` runs ``git push`` asynchronously
-       so we may briefly see an older commit; that is fine, it self-corrects
-       on the next save).
-    2. Updates the W&B run summary with persistent, prominent keys
-       (visible in the "Overview" tab of the run):
-       - ``checkpoint/last_step``
-       - ``checkpoint/last_commit``
-       - ``checkpoint/repo_url``
-       - ``checkpoint/last_url``
-    3. Logs a step-indexed scalar ``checkpoint/step`` so a chart appears
-       on the W&B run page (one tick per save).
-    4. Maintains a running ``checkpoint_history`` ``wandb.Table`` so every
-       saved checkpoint is browsable as a sortable table directly on the
-       run page (Tables tab).
-    5. Prints a banner to stdout (visible in ``hf jobs logs``) with the
-       direct URL — so the checkpoint is also impossible to miss in the
-       job logs.
-
-    No model bytes are uploaded to W&B; the actual weights live on the HF
-    Hub checkpoint repo. We never crash training if any of this fails.
-    """
+    """Log checkpoint metadata to W&B summary and stdout after each Trainer save."""
 
     def __init__(self, hub_checkpoint_repo_id: str) -> None:
         self._repo = hub_checkpoint_repo_id
         self._repo_url = f"https://huggingface.co/{hub_checkpoint_repo_id}"
-        self._table = None  # lazy: wandb may not be initialised at __init__
+        self._table = None
 
     def on_train_begin(
         self,
@@ -718,8 +531,6 @@
         control: TrainerControl,
         **kwargs,
     ) -> None:
-        # Pin the repo URL into the run config + summary at the very start
-        # so the link is visible on the W&B "Overview" panel from step 0.
         try:
             import wandb
 
@@ -735,11 +546,6 @@
                 f"\n[wandb] Checkpoint repo pinned in run summary: {self._repo_url}\n",
                 flush=True,
             )
-
-            # Stash the W&B run id at the *root* of the checkpoint repo so a
-            # future re-launch can find it without W&B API calls. Atomic with
-            # checkpoint storage, ~36 bytes. We do this once at train begin
-            # instead of every save to avoid 200 redundant commits.
             self._publish_wandb_run_id(wandb.run.id)
         except Exception as exc:  # noqa: BLE001
             _log.warning("Could not pin checkpoint repo to W&B summary: %s", exc)
@@ -788,28 +594,18 @@
             else f"{self._repo_url}/tree/main"
         )
 
-        # 1. Persistent summary keys (top-of-run, always visible).
         wandb.run.summary["checkpoint/last_step"] = step
         wandb.run.summary["checkpoint/last_commit"] = commit_sha or "pending"
         wandb.run.summary["checkpoint/last_url"] = tree_url
-
-        # 2. Step-indexed scalar so a small chart appears on the run page.
         wandb.log({"checkpoint/step": step}, step=step)
 
-        # 3. Running history table.
         if self._table is None:
             self._table = wandb.Table(
                 columns=["step", "commit", "url", "repo"]
            )
         self._table.add_data(step, commit_sha or "pending", tree_url, self._repo)
-        # Re-log the entire table each time so the latest version shows.
         wandb.log({"checkpoint_history": self._table}, step=step)
 
-        # 4. Pointer-only W&B Artifact (~200 bytes JSON). Doesn't upload
-        #    weights — those are on the Hub already — but makes every
-        #    checkpoint a first-class, addressable W&B artifact that can
-        #    be looked up later by `wandb artifact get`. Side effect:
-        #    populates the run's "Artifacts" panel with one entry per save.
         if commit_sha:
             from physix.training.checkpoints import (
                 CheckpointHandle,
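The pointer-only artifact described in the removed comment block uploads coordinates, not weights. A generic sketch with plain ``wandb`` calls (the real code goes through ``physix.training.checkpoints`` helpers; ``repo_id``, ``commit_sha`` and ``step`` are assumed in scope)::

    import json

    import wandb

    # ~200-byte pointer: weights stay on the Hub, W&B stores only the
    # repo/commit coordinates, making each save an addressable artifact.
    artifact = wandb.Artifact("physix-grpo-checkpoint", type="model-pointer")
    with artifact.new_file("pointer.json", mode="w") as f:
        json.dump({"repo": repo_id, "commit": commit_sha, "step": step}, f)
    wandb.log_artifact(artifact)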
@@ -826,7 +622,6 @@
                 artifact_name="physix-grpo-checkpoint",
             )
 
-        # 5. Stdout banner — also visible in `hf jobs logs`.
         print(
             "\n"
             "================ CHECKPOINT SAVED ================\n"
@@ -852,13 +647,7 @@
         )
 
     def _latest_commit_sha(self) -> Optional[str]:
-        """Best-effort fetch of the most recent commit on the checkpoint repo.
-
-        Uses ``HfApi.list_repo_commits`` if available; returns ``None`` on
-        any failure. The async ``git push`` may not be done at the instant
-        ``on_save`` fires, so we may see the *previous* checkpoint's commit;
-        that's acceptable — it self-corrects on the next save.
-        """
+        """Best-effort fetch of the latest commit SHA; returns None on failure."""
         try:
             from huggingface_hub import HfApi
 
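A minimal sketch of the lookup the removed docstring described (repo id is a placeholder; ``list_repo_commits`` returns commits newest first)::

    from huggingface_hub import HfApi

    # The trainer's async git push may still be in flight, so the head
    # commit can briefly lag the save that triggered this call.
    commits = HfApi().list_repo_commits("your-org/physix-grpo-ckpt")
    latest_sha = commits[0].commit_id if commits else None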
@@ -872,19 +661,6 @@
 
 
 def _build_grpo_config(config: TrainingConfig) -> GRPOConfig:
-    # Note on the metrics this run will produce in W&B (per TRL docs):
-    #   train/loss       — the GRPO surrogate objective being minimized.
-    #                      = -E[advantage * logπ(action|state)] + β * KL.
-    #                      Should DECREASE as the policy exploits advantages.
-    #   train/reward     — mean total reward per rollout. Should INCREASE.
-    #   train/kl         — KL(policy || reference). Bounded by β; grows slowly.
-    #   rewards/<f>/mean — per-component reward (one per reward function).
-    #
-    # ``train/loss`` going to ~0 *only* if ``train/reward`` rises in lockstep
-    # is fine — it just means advantages got fully exploited. Loss collapsing
-    # without reward growth is reward hacking, broken parsing, or a saturated
-    # KL anchor. We surface both via _log_reward_summary at end of training
-    # AND via _GenerateCurvesCallback which renders both curves to PNG.
     effective_batch = (
         config.per_device_train_batch_size * config.gradient_accumulation_steps
     )
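The function body past ``effective_batch`` sits outside this diff; a hypothetical construction consistent with the config fields above (all parameters are stock ``GRPOConfig``/``TrainingArguments`` ones)::

    effective_batch = (
        config.per_device_train_batch_size * config.gradient_accumulation_steps
    )  # 1 × 8 = 8 prompts per optimizer step with the defaults above

    return GRPOConfig(
        output_dir=config.output_dir,
        per_device_train_batch_size=config.per_device_train_batch_size,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        max_steps=config.num_steps,
        seed=config.seed,
        report_to="wandb",
    )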
@@ -933,20 +709,7 @@ def _save_artifacts(
     tokenizer: AutoTokenizer,
     config: TrainingConfig,
 ) -> None:
-    """Persist the trained adapter via Unsloth's save path.
-
-    ``save_pretrained_merged`` dispatches on ``save_method``:
-
-    - ``"lora"``: writes only the adapter weights (small; requires the base
-      model at load time).
-    - ``"merged_16bit"``: merges LoRA into base and writes a standard HF
-      checkpoint in bf16/fp16 (large; loadable without Unsloth, exportable to
-      GGUF for Ollama).
-    - ``"merged_4bit"``: same merge but quantised back to 4-bit.
-
-    Hub pushes use the same ``save_method`` so the on-disk artifact and the
-    Hub artifact are byte-identical.
-    """
+    """Save model locally and optionally push to Hub."""
     out_path = Path(config.output_dir)
     out_path.mkdir(parents=True, exist_ok=True)
 
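The dispatch table in the removed docstring is Unsloth's saving API; a usage sketch (output dir and repo id are placeholders)::

    # "lora"         → adapter weights only (small; needs the base at load)
    # "merged_16bit" → LoRA merged into the base; plain HF checkpoint
    # "merged_4bit"  → same merge, re-quantised to 4-bit
    model.save_pretrained_merged(
        "runs/physix-1.5b-rl", tokenizer, save_method="merged_16bit"
    )
    # Same save_method for the Hub push so both artifacts match:
    model.push_to_hub_merged(
        "your-org/physix-1.5b-rl", tokenizer, save_method="merged_16bit"
    )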