Spaces:
Paused
Paused
sft+reward-fix: server/rewards/reward_function.py
Browse files
server/rewards/reward_function.py
CHANGED
|
@@ -61,12 +61,19 @@ class RewardWeights:
|
|
| 61 |
valid_action: float = 0.05
|
| 62 |
progress_milestone: float = 0.25
|
| 63 |
evidence_quality: float = 0.20
|
| 64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
# whose category matches the action.
|
| 66 |
bogus_method_penalty: float = -0.05 # penalises method strings outside
|
| 67 |
# TOOL_REGISTRY (anti-string-spam).
|
| 68 |
-
|
| 69 |
-
|
|
|
|
|
|
|
|
|
|
| 70 |
soft_violation: float = -0.05
|
| 71 |
hard_violation: float = -0.50
|
| 72 |
redundancy: float = -0.10
|
|
@@ -76,7 +83,9 @@ class RewardWeights:
|
|
| 76 |
# Hard cap on what a single shaping step can earn. Without this a
|
| 77 |
# policy could in principle stack milestone + evidence_quality +
|
| 78 |
# tool_fit + valid_action and approach the terminal reward magnitude.
|
| 79 |
-
|
|
|
|
|
|
|
| 80 |
|
| 81 |
# ── terminal grading ────────────────────────────────────────
|
| 82 |
terminal_scale: float = 5.0 # multiplied with the convex sum below
|
|
@@ -92,6 +101,18 @@ class RewardWeights:
|
|
| 92 |
overconfident_wrong_penalty: float = 4.0 # subtracted from terminal
|
| 93 |
overclaim_significance_penalty: float = 1.5 # claim_sigma >> measured_sigma
|
| 94 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
|
| 96 |
# ── Outputs ──────────────────────────────────────────────────────────────
|
| 97 |
|
|
@@ -233,13 +254,16 @@ def compute_step_reward(
|
|
| 233 |
breakdown.add("soft_violation", weights.soft_violation * soft_other)
|
| 234 |
|
| 235 |
# ── consecutive-repeat penalty (catches loop hacks) ─────────────
|
| 236 |
-
#
|
| 237 |
-
#
|
|
|
|
|
|
|
|
|
|
| 238 |
repeats = _consecutive_repeat_count(history or [], action.action_type)
|
| 239 |
-
if repeats >=
|
| 240 |
breakdown.add(
|
| 241 |
"repeat_action",
|
| 242 |
-
weights.repeat_action_penalty *
|
| 243 |
)
|
| 244 |
|
| 245 |
# ── resource overspend ──────────────────────────────────────────
|
|
@@ -328,10 +352,34 @@ def _efficiency_bonus(state: FullLatentState) -> float:
|
|
| 328 |
def compute_terminal_reward(
|
| 329 |
*,
|
| 330 |
state: FullLatentState,
|
| 331 |
-
claim: DiscoveryClaim,
|
| 332 |
weights: RewardWeights = RewardWeights(),
|
| 333 |
) -> TerminalReward:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 334 |
breakdown = RewardBreakdown()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 335 |
truth = state.particle
|
| 336 |
|
| 337 |
mass_score = _mass_score(truth.mass_gev, claim.mass_estimate_gev, claim.mass_uncertainty_gev)
|
|
@@ -366,6 +414,15 @@ def compute_terminal_reward(
|
|
| 366 |
|
| 367 |
raw = breakdown.total * weights.terminal_scale
|
| 368 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 369 |
# Overconfident-wrong penalty: high confidence but wrong channel or far mass
|
| 370 |
if claim.confidence >= 0.8 and (mass_score < 0.2 or not channel_ok):
|
| 371 |
raw -= weights.overconfident_wrong_penalty
|
|
|
|
| 61 |
valid_action: float = 0.05
|
| 62 |
progress_milestone: float = 0.25
|
| 63 |
evidence_quality: float = 0.20
|
| 64 |
+
# Cut to ~1/3 of original (was 0.10) to lower the per-step shaping
|
| 65 |
+
# floor. Combined with a smaller step_reward_clip and a heavier
|
| 66 |
+
# repeat-action penalty this prevents the agent from farming
|
| 67 |
+
# +0.20+/step by cycling well-formed-but-inert tool calls.
|
| 68 |
+
tool_fit: float = 0.033 # paid only on a method ∈ TOOL_REGISTRY
|
| 69 |
# whose category matches the action.
|
| 70 |
bogus_method_penalty: float = -0.05 # penalises method strings outside
|
| 71 |
# TOOL_REGISTRY (anti-string-spam).
|
| 72 |
+
# Was -0.08; bumped to -0.5 because the previous value was easily out-
|
| 73 |
+
# earned by stacking format_bonus + valid_action + tool_fit. The
|
| 74 |
+
# gating in compute_step_reward also now triggers from the *2nd*
|
| 75 |
+
# consecutive identical action_type instead of the 3rd.
|
| 76 |
+
repeat_action_penalty: float = -0.5
|
| 77 |
soft_violation: float = -0.05
|
| 78 |
hard_violation: float = -0.50
|
| 79 |
redundancy: float = -0.10
|
|
|
|
| 83 |
# Hard cap on what a single shaping step can earn. Without this a
|
| 84 |
# policy could in principle stack milestone + evidence_quality +
|
| 85 |
# tool_fit + valid_action and approach the terminal reward magnitude.
|
| 86 |
+
# Cut from 0.75 → 0.25 so the per-step shaping floor cannot exceed
|
| 87 |
+
# ~1/3 of the wrong-claim terminal penalty.
|
| 88 |
+
step_reward_clip: float = 0.25
|
| 89 |
|
| 90 |
# ── terminal grading ────────────────────────────────────────
|
| 91 |
terminal_scale: float = 5.0 # multiplied with the convex sum below
|
|
|
|
| 101 |
overconfident_wrong_penalty: float = 4.0 # subtracted from terminal
|
| 102 |
overclaim_significance_penalty: float = 1.5 # claim_sigma >> measured_sigma
|
| 103 |
|
| 104 |
+
# Big bonus for getting BOTH mass and channel right, on top of the
|
| 105 |
+
# terminal grade. Makes the bandit math strictly favour attempting a
|
| 106 |
+
# claim when uncertain rather than running out the clock: a correct
|
| 107 |
+
# claim now returns ~+10–12, a wrong one ~−1.85, no claim ~−5.
|
| 108 |
+
correct_claim_bonus: float = 6.0
|
| 109 |
+
|
| 110 |
+
# Penalty applied at episode end when the trajectory never even
|
| 111 |
+
# *attempted* a SUBMIT_DISCOVERY_CLAIM. Defeats the "hide forever and
|
| 112 |
+
# farm shaping" reward hack we observed in v1 (mean +0.22/step over
|
| 113 |
+
# ~12 steps was a better deal than risking the wrong-claim penalty).
|
| 114 |
+
no_claim_terminal_penalty: float = -5.0
|
| 115 |
+
|
| 116 |
|
| 117 |
# ── Outputs ──────────────────────────────────────────────────────────────
|
| 118 |
|
|
|
|
| 254 |
breakdown.add("soft_violation", weights.soft_violation * soft_other)
|
| 255 |
|
| 256 |
# ── consecutive-repeat penalty (catches loop hacks) ─────────────
|
| 257 |
+
# Triggers from the *2nd* identical action in a row (previously
|
| 258 |
+
# only kicked in at the 3rd). The escalating multiplier scales with
|
| 259 |
+
# the run length so that the 4th identical action in a row gets 3× the base penalty —
|
| 260 |
+
# important because v1 found that a tiny -0.08 was easily out-earned
|
| 261 |
+
# by the +0.22/step shaping floor.
|
| 262 |
repeats = _consecutive_repeat_count(history or [], action.action_type)
|
| 263 |
+
if repeats >= 1:
|
| 264 |
breakdown.add(
|
| 265 |
"repeat_action",
|
| 266 |
+
weights.repeat_action_penalty * repeats,
|
| 267 |
)
|
| 268 |
|
| 269 |
# ── resource overspend ──────────────────────────────────────────
|
|
|
|
| 352 |
def compute_terminal_reward(
|
| 353 |
*,
|
| 354 |
state: FullLatentState,
|
| 355 |
+
claim: Optional[DiscoveryClaim],
|
| 356 |
weights: RewardWeights = RewardWeights(),
|
| 357 |
) -> TerminalReward:
|
| 358 |
+
"""Grade the end-of-episode submission.
|
| 359 |
+
|
| 360 |
+
``claim`` is ``None`` when the episode terminated for *any* reason
|
| 361 |
+
other than a ``submit_discovery_claim`` action (max_steps, budget
|
| 362 |
+
exhausted, time exhausted) AND the trajectory never attempted to
|
| 363 |
+
submit a claim. In that case we return a flat
|
| 364 |
+
``no_claim_terminal_penalty`` so the bandit math always favours
|
| 365 |
+
*attempting* a claim over hiding forever to farm per-step shaping.
|
| 366 |
+
See: v1 (anugrahhu/cernenv-grpo-smollm2-360m) which exploited this
|
| 367 |
+
exact gap by spamming request_systematics for ~+0.22/step instead
|
| 368 |
+
of risking the wrong-claim penalty (~−1.85).
|
| 369 |
+
"""
|
| 370 |
breakdown = RewardBreakdown()
|
| 371 |
+
|
| 372 |
+
if claim is None:
|
| 373 |
+
breakdown.add("no_claim_terminal_penalty", weights.no_claim_terminal_penalty)
|
| 374 |
+
return TerminalReward(
|
| 375 |
+
reward=float(weights.no_claim_terminal_penalty),
|
| 376 |
+
breakdown=breakdown,
|
| 377 |
+
discovered=False,
|
| 378 |
+
correct_mass=False,
|
| 379 |
+
correct_channel=False,
|
| 380 |
+
correct_spin=False,
|
| 381 |
+
)
|
| 382 |
+
|
| 383 |
truth = state.particle
|
| 384 |
|
| 385 |
mass_score = _mass_score(truth.mass_gev, claim.mass_estimate_gev, claim.mass_uncertainty_gev)
|
|
|
|
| 414 |
|
| 415 |
raw = breakdown.total * weights.terminal_scale
|
| 416 |
|
| 417 |
+
# Asymmetric claim cost (Fix #4). When the claim gets BOTH the mass
|
| 418 |
+
# and the decay channel right, add a flat bonus on top of the graded
|
| 419 |
+
# terminal so that a correct attempt is worth substantially more
|
| 420 |
+
# than the no-claim penalty (-5) and the wrong-claim penalty (~-1.85).
|
| 421 |
+
# This makes the bandit math: correct +10–12 ≫ wrong −2 > no-claim −5.
|
| 422 |
+
if mass_score >= 0.5 and channel_ok:
|
| 423 |
+
raw += weights.correct_claim_bonus
|
| 424 |
+
breakdown.add("correct_claim_bonus", weights.correct_claim_bonus)
|
| 425 |
+
|
| 426 |
# Overconfident-wrong penalty: high confidence but wrong channel or far mass
|
| 427 |
if claim.confidence >= 0.8 and (mass_score < 0.2 or not channel_ok):
|
| 428 |
raw -= weights.overconfident_wrong_penalty
|