Pratyush-01 committed
Commit 0128624 · verified · 1 Parent(s): b4bd6d8

cleanup: strip verbose comments from physix/training/reward_fns.py

Files changed (1)
  1. physix/training/reward_fns.py +4 -64
physix/training/reward_fns.py CHANGED
@@ -1,34 +1,4 @@
-"""TRL-compatible reward functions for GRPO training.
-
-Responsibility: expose a stateless reward function for each independent
-reward signal. Internally each component delegates to a shared
-:class:`Scorer` so a single completion is parsed and simulated exactly
-once per training step regardless of how many reward functions query it.
-
-The TRL signature for a reward function is::
-
-    def reward_func(*, prompts, completions, **kwargs) -> list[float]: ...
-
-where ``prompts`` and ``completions`` are batched lists. Extra columns from
-the training dataset arrive as keyword arguments — we expect the columns
-listed in :class:`SystemContext` to be present.
-
-Reward set design (anti-hack, RCA from W&B run 5kuqns9x):
-
-- ``reward_match`` — raw R² on the trajectory (linear).
-- ``reward_match_dense`` — sqrt(R²); denser gradient at low values.
-- ``reward_correctness`` — binary cliff at R² ≥ 0.70; pushes past plateau.
-- ``reward_simplicity`` — gated on R² ≥ 0.10 (no free reward for trivial
-  equations).
-- ``reward_format`` — 1.0 only if the equation parsed *and*
-  simulated. No partial credit for parseable
-  but uncomputable garbage.
-
-The legacy ``reward_progress`` is intentionally absent. In single-turn
-GRPO every dataset row carries ``previous_r_match=0``, which made
-``progress = max(0, match - 0) = match`` for every rollout — a perfect
-duplicate of ``reward_match`` that diluted advantage estimation.
-"""
+"""TRL-compatible reward functions for GRPO training."""
 
 from __future__ import annotations
 
@@ -45,31 +15,7 @@ RewardFunction = Callable[..., list[float]]
 def make_reward_funcs(
     scorer: Scorer | None = None,
 ) -> dict[str, RewardFunction]:
-    """Build a fresh dict of reward functions wired to a shared scorer.
-
-    Each function is named ``reward_<component>`` so TRL's GRPO trainer
-    logs them individually to W&B under
-    ``train/rewards/reward_<component>/mean``.
-
-    The scorer is shared across all functions. TRL calls reward functions
-    one-by-one for the same batch (same ``completions`` list, same indices).
-    The ``match`` function resets the cache and populates it; the
-    remaining functions (``match_dense``, ``correctness``, ``simplicity``,
-    ``format``) reuse the cached results via ``cache_key=i``. This means
-    each completion is parsed + simulated exactly once per step regardless
-    of how many reward functions query it.
-
-    Returns a dict whose keys are:
-
-    - ``match`` / ``simplicity`` / ``format`` — direct reads from the
-      :class:`RewardBreakdown`. ``simplicity`` is internally gated on
-      match ≥ 0.10 and ``format`` on simulation success.
-    - ``match_dense`` — ``sqrt(match)`` for denser low-value gradient.
-    - ``correctness`` — binary 1.0 above an R² threshold (``0.70``).
-
-    All functions share the scorer cache, so they cost one parse +
-    simulate per completion combined, not five.
-    """
+    """Build reward functions keyed by component name, sharing a single scorer cache."""
     shared = scorer if scorer is not None else Scorer()
 
     def _make_breakdown_reader(component: str, *, reset_cache: bool) -> RewardFunction:
@@ -125,8 +71,7 @@ def make_reward_funcs(
 
     _reward_correctness.__name__ = "reward_correctness"
 
-    # ``match`` is always the first function TRL calls; it resets the cache
-    # so subsequent functions get fresh results for this step's completions.
+    # match resets the cache first so subsequent functions reuse parsed results.
     funcs: dict[str, RewardFunction] = {
         "match": _make_breakdown_reader("match", reset_cache=True),
         "simplicity": _make_breakdown_reader("simplicity", reset_cache=False),
@@ -138,12 +83,7 @@ def make_reward_funcs(
 
 
 def _hydrate_contexts(batch_size: int, kwargs: dict[str, Any]) -> list[SystemContext]:
-    """Project per-row kwargs into :class:`SystemContext` records.
-
-    TRL passes dataset columns as kwargs where each value is a list of
-    length ``batch_size``. We zip them together into per-row dicts and hand
-    each off to :func:`SystemContext.from_row`.
-    """
+    """Convert TRL batch kwargs into per-row SystemContext records."""
     expected_keys = (
         "system_id",
         "state_variables",