cleanup: strip verbose comments from physix/training/reward_fns.py
physix/training/reward_fns.py
CHANGED
@@ -1,34 +1,4 @@
-"""TRL-compatible reward functions for GRPO training.
-
-Responsibility: expose a stateless reward function for each independent
-reward signal. Internally each component delegates to a shared
-:class:`Scorer` so a single completion is parsed and simulated exactly
-once per training step regardless of how many reward functions query it.
-
-The TRL signature for a reward function is::
-
-    def reward_func(*, prompts, completions, **kwargs) -> list[float]: ...
-
-where ``prompts`` and ``completions`` are batched lists. Extra columns from
-the training dataset arrive as keyword arguments — we expect the columns
-listed in :class:`SystemContext` to be present.
-
-Reward set design (anti-hack, RCA from W&B run 5kuqns9x):
-
-- ``reward_match`` — raw R² on the trajectory (linear).
-- ``reward_match_dense`` — sqrt(R²); denser gradient at low values.
-- ``reward_correctness`` — binary cliff at R² ≥ 0.70; pushes past plateau.
-- ``reward_simplicity`` — gated on R² ≥ 0.10 (no free reward for trivial
-  equations).
-- ``reward_format`` — 1.0 only if the equation parsed *and*
-  simulated. No partial credit for parseable
-  but uncomputable garbage.
-
-The legacy ``reward_progress`` is intentionally absent. In single-turn
-GRPO every dataset row carries ``previous_r_match=0``, which made
-``progress = max(0, match - 0) = match`` for every rollout — a perfect
-duplicate of ``reward_match`` that diluted advantage estimation.
-"""
+"""TRL-compatible reward functions for GRPO training."""
 
 from __future__ import annotations
 
@@ -45,31 +15,7 @@ RewardFunction = Callable[..., list[float]]
 def make_reward_funcs(
     scorer: Scorer | None = None,
 ) -> dict[str, RewardFunction]:
-    """Build
-
-    Each function is named ``reward_<component>`` so TRL's GRPO trainer
-    logs them individually to W&B under
-    ``train/rewards/reward_<component>/mean``.
-
-    The scorer is shared across all functions. TRL calls reward functions
-    one-by-one for the same batch (same ``completions`` list, same indices).
-    The ``match`` function resets the cache and populates it; the
-    remaining functions (``match_dense``, ``correctness``, ``simplicity``,
-    ``format``) reuse the cached results via ``cache_key=i``. This means
-    each completion is parsed + simulated exactly once per step regardless
-    of how many reward functions query it.
-
-    Returns a dict whose keys are:
-
-    - ``match`` / ``simplicity`` / ``format`` — direct reads from the
-      :class:`RewardBreakdown`. ``simplicity`` is internally gated on
-      match ≥ 0.10 and ``format`` on simulation success.
-    - ``match_dense`` — ``sqrt(match)`` for denser low-value gradient.
-    - ``correctness`` — binary 1.0 above an R² threshold (``0.70``).
-
-    All functions share the scorer cache, so they cost one parse +
-    simulate per completion combined, not five.
-    """
+    """Build reward functions keyed by component name, sharing a single scorer cache."""
     shared = scorer if scorer is not None else Scorer()
 
     def _make_breakdown_reader(component: str, *, reset_cache: bool) -> RewardFunction:
@@ -125,8 +71,7 @@ def make_reward_funcs(
 
     _reward_correctness.__name__ = "reward_correctness"
 
-    #
-    # so subsequent functions get fresh results for this step's completions.
+    # match resets the cache first so subsequent functions reuse parsed results.
     funcs: dict[str, RewardFunction] = {
         "match": _make_breakdown_reader("match", reset_cache=True),
         "simplicity": _make_breakdown_reader("simplicity", reset_cache=False),
@@ -138,12 +83,7 @@ def make_reward_funcs(
 
 
 def _hydrate_contexts(batch_size: int, kwargs: dict[str, Any]) -> list[SystemContext]:
-    """
-
-    TRL passes dataset columns as kwargs where each value is a list of
-    length ``batch_size``. We zip them together into per-row dicts and hand
-    each off to :func:`SystemContext.from_row`.
-    """
+    """Convert TRL batch kwargs into per-row SystemContext records."""
     expected_keys = (
         "system_id",
         "state_variables",
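For readers landing on this commit without the deleted context: the removed module docstring documented the reward set design. Below is a minimal sketch of how those five components could be derived from a single raw R² match score, keeping the two thresholds the docstring names (0.70 cliff, 0.10 gate). The helper name and the parsimony/simulated inputs are illustrative, not the repo's API.

import math

# Thresholds taken from the removed docstring; everything else is a sketch.
R2_CORRECTNESS_CLIFF = 0.70  # binary cliff: pushes policies past the plateau
R2_SIMPLICITY_GATE = 0.10    # no free simplicity reward for trivial equations

def shape_rewards(match: float, parsimony: float, simulated: bool) -> dict[str, float]:
    """Derive the five reward components from one raw R² match score."""
    return {
        "match": match,                             # raw R², linear
        "match_dense": math.sqrt(max(match, 0.0)),  # denser gradient at low R²
        "correctness": 1.0 if match >= R2_CORRECTNESS_CLIFF else 0.0,
        "simplicity": parsimony if match >= R2_SIMPLICITY_GATE else 0.0,
        "format": 1.0 if simulated else 0.0,        # parsed AND simulated only
    }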
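The deleted make_reward_funcs docstring also spelled out the caching contract: TRL calls the reward functions one at a time over the same batch, the match reader resets and populates the shared scorer cache, and the remaining readers reuse results via cache_key=i. A self-contained sketch of that pattern follows; CachingScorer and make_reader are hypothetical stand-ins for the repo's Scorer and _make_breakdown_reader.

from typing import Any, Callable

class CachingScorer:
    """Illustrative stand-in for the shared Scorer."""

    def __init__(self) -> None:
        self._cache: dict[int, dict[str, float]] = {}

    def reset(self) -> None:
        self._cache.clear()

    def breakdown(self, completion: str, *, cache_key: int) -> dict[str, float]:
        if cache_key not in self._cache:
            # The one expensive step: parse the equation, then simulate it.
            self._cache[cache_key] = self._parse_and_simulate(completion)
        return self._cache[cache_key]

    def _parse_and_simulate(self, completion: str) -> dict[str, float]:
        # Placeholder scores; the real Scorer runs a trajectory simulation here.
        return {"match": 0.0, "simplicity": 0.0, "format": 0.0}

def make_reader(scorer: CachingScorer, component: str, *, reset_cache: bool) -> Callable[..., list[float]]:
    def reward_fn(*, prompts: list[str], completions: list[str], **kwargs: Any) -> list[float]:
        if reset_cache:     # only the "match" reader resets, so it must run first
            scorer.reset()  # fresh results for this step's completions
        return [scorer.breakdown(c, cache_key=i)[component]
                for i, c in enumerate(completions)]
    return reward_fn

Because only the first reader pays for parsing and simulation, five reward signals cost one parse + simulate per completion combined, which is the claim the removed text made.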
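The rationale for dropping reward_progress can be checked in a few lines. With every dataset row carrying previous_r_match=0, the progress term reduces exactly to the match term; a sketch using the formula as the docstring stated it:

def reward_progress(match: float, previous_r_match: float = 0.0) -> float:
    return max(0.0, match - previous_r_match)

# In single-turn GRPO previous_r_match is always 0, so progress == match:
assert all(reward_progress(m) == m for m in (0.0, 0.25, 0.7, 1.0))

Two perfectly correlated reward columns add no information to the group-relative advantage, which is the dilution the docstring's RCA describes.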