anugrahhu committed · Commit d91fe20 · verified · 1 parent: 7df4308

sft+reward-fix: server/rewards/reward_function.py

Files changed (1):
  1. server/rewards/reward_function.py (+66 −9)
server/rewards/reward_function.py CHANGED

@@ -61,12 +61,19 @@ class RewardWeights:
     valid_action: float = 0.05
     progress_milestone: float = 0.25
     evidence_quality: float = 0.20
-    tool_fit: float = 0.10               # paid only on a method ∈ TOOL_REGISTRY
+    # Cut to ~1/3 of original (was 0.10) to lower the per-step shaping
+    # floor. Combined with a smaller step_reward_clip and a heavier
+    # repeat-action penalty this prevents the agent from farming
+    # +0.20+/step by cycling well-formed-but-inert tool calls.
+    tool_fit: float = 0.033              # paid only on a method ∈ TOOL_REGISTRY
                                          # whose category matches the action.
     bogus_method_penalty: float = -0.05  # penalises method strings outside
                                          # TOOL_REGISTRY (anti-string-spam).
-    repeat_action_penalty: float = -0.08 # per consecutive repeat beyond the
-                                         # second identical action_type in a row.
+    # Was -0.08; bumped to -0.5 because the previous value was easily out-
+    # earned by stacking format_bonus + valid_action + tool_fit. The
+    # gating in compute_step_reward also now triggers from the *2nd*
+    # consecutive identical action_type instead of the 3rd.
+    repeat_action_penalty: float = -0.5
     soft_violation: float = -0.05
     hard_violation: float = -0.50
     redundancy: float = -0.10
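Taken together, the two re-weighted fields flip the sign of the looping strategy those comments describe. A quick worked check: `loop_step_value` and the 0.12 `other_shaping` base are illustrative assumptions; only the weight values and gating thresholds come from this commit.

```python
# Illustrative only: reproduces the loop-farming arithmetic cited in the
# comments above. `other_shaping` stands in for format_bonus + valid_action
# and friends, whose exact values are not shown in this diff.
def loop_step_value(n_in_row: int, *, tool_fit: float, repeat_penalty: float,
                    gate: int, other_shaping: float = 0.12) -> float:
    """Shaping earned by the n-th identical action in a row."""
    repeats = n_in_row - 1                      # prior identical actions
    over = repeats - gate + 1                   # escalation multiplier
    penalty = repeat_penalty * over if repeats >= gate else 0.0
    return other_shaping + tool_fit + penalty

old = [loop_step_value(n, tool_fit=0.10, repeat_penalty=-0.08, gate=2) for n in range(1, 6)]
new = [loop_step_value(n, tool_fit=0.033, repeat_penalty=-0.5, gate=1) for n in range(1, 6)]
# old ≈ [0.22, 0.22, 0.14, 0.06, -0.02]: looping stays profitable for four steps
# new ≈ [0.153, -0.347, -0.847, -1.347, -1.847]: the second repeat already loses money
```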
@@ -76,7 +83,9 @@ class RewardWeights:
     # Hard cap on what a single shaping step can earn. Without this a
     # policy could in principle stack milestone + evidence_quality +
     # tool_fit + valid_action and approach the terminal reward magnitude.
-    step_reward_clip: float = 0.75
+    # Cut from 0.75 → 0.25 so the per-step shaping floor cannot exceed
+    # ~1/3 of the wrong-claim terminal penalty.
+    step_reward_clip: float = 0.25
 
     # ── terminal grading ────────────────────────────────────────
     terminal_scale: float = 5.0          # multiplied with the convex sum below
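The stacking scenario the cap guards against can be checked directly from the weights above. The clipping call site is outside this diff; applying the cap with `min` is an assumption about how it is used.

```python
# Best possible un-clipped stack on a single step, using the new weights:
stack = 0.25 + 0.20 + 0.033 + 0.05   # milestone + evidence + tool_fit + valid_action
# stack == 0.533, still about twice the new cap, so the clip is what binds:
step_reward = min(stack, 0.25)       # step_reward_clip
```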
@@ -92,6 +101,18 @@ class RewardWeights:
     overconfident_wrong_penalty: float = 4.0    # subtracted from terminal
     overclaim_significance_penalty: float = 1.5 # claim_sigma >> measured_sigma
 
+    # Big bonus for getting BOTH mass and channel right, on top of the
+    # terminal grade. Makes the bandit math strictly favour attempting a
+    # claim when uncertain rather than running out the clock: a correct
+    # claim now returns ~+10–12, a wrong one ~−1.85, no claim ~−5.
+    correct_claim_bonus: float = 6.0
+
+    # Penalty applied at episode end when the trajectory never even
+    # *attempted* a SUBMIT_DISCOVERY_CLAIM. Defeats the "hide forever and
+    # farm shaping" reward hack we observed in v1 (mean +0.22/step over
+    # ~12 steps was a better deal than risking the wrong-claim penalty).
+    no_claim_terminal_penalty: float = -5.0
+
 
 # ── Outputs ──────────────────────────────────────────────────────────────
 
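A sanity check on the bandit math those comments quote, taking +11 as a representative point in the +10–12 correct-claim range. `ev_submit` and the probability framing are illustrative, not part of the file.

```python
# Expected value of submitting a claim the agent believes with probability p,
# versus never claiming. Payoffs are the figures quoted in the comments above.
def ev_submit(p: float, correct: float = 11.0, wrong: float = -1.85) -> float:
    return p * correct + (1.0 - p) * wrong

EV_NEVER_CLAIM = -5.0
# ev_submit(0.0) == -1.85 > EV_NEVER_CLAIM: even a claim known to be wrong
# beats running out the clock, so submitting dominates for every p in [0, 1].
```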
@@ -233,13 +254,16 @@ def compute_step_reward(
     breakdown.add("soft_violation", weights.soft_violation * soft_other)
 
     # ── consecutive-repeat penalty (catches loop hacks) ─────────────
-    # Two-in-a-row is mildly OK (sometimes you re-collect data); three
-    # or more identical action_types in a row earns escalating penalty.
+    # Triggers from the *2nd* identical action in a row (previously
+    # only kicked in at the 3rd). The escalating multiplier scales with
+    # the run length so that 4-in-a-row gets 3× the base penalty —
+    # important because v1 found that a tiny -0.08 was easily out-earned
+    # by the +0.22/step shaping floor.
     repeats = _consecutive_repeat_count(history or [], action.action_type)
-    if repeats >= 2:
+    if repeats >= 1:
         breakdown.add(
             "repeat_action",
-            weights.repeat_action_penalty * (repeats - 1),
+            weights.repeat_action_penalty * repeats,
         )
 
     # ── resource overspend ──────────────────────────────────────────
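`_consecutive_repeat_count` is used but never shown in this diff; below is a minimal sketch consistent with the semantics both comment versions imply (N identical actions in a row yields N − 1 repeats, so the new `repeats >= 1` gate fires on the 2nd and the old `repeats >= 2` gate fired on the 3rd).

```python
# Assumed implementation, inferred from usage: count the trailing history
# entries whose action_type matches the action about to be scored.
def _consecutive_repeat_count(history: list, action_type: str) -> int:
    count = 0
    for past in reversed(history):
        if past.action_type != action_type:
            break
        count += 1
    return count
```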
@@ -328,10 +352,34 @@ def _efficiency_bonus(state: FullLatentState) -> float:
 def compute_terminal_reward(
     *,
     state: FullLatentState,
-    claim: DiscoveryClaim,
+    claim: Optional[DiscoveryClaim],
     weights: RewardWeights = RewardWeights(),
 ) -> TerminalReward:
+    """Grade the end-of-episode submission.
+
+    ``claim`` is ``None`` when the episode terminated for *any* reason
+    other than a ``submit_discovery_claim`` action (max_steps, budget
+    exhausted, time exhausted) AND the trajectory never attempted to
+    submit a claim. In that case we return a flat
+    ``no_claim_terminal_penalty`` so the bandit math always favours
+    *attempting* a claim over hiding forever to farm per-step shaping.
+    See: v1 (anugrahhu/cernenv-grpo-smollm2-360m), which exploited this
+    exact gap by spamming request_systematics for ~+0.22/step instead
+    of risking the wrong-claim penalty (~−1.85).
+    """
     breakdown = RewardBreakdown()
+
+    if claim is None:
+        breakdown.add("no_claim_terminal_penalty", weights.no_claim_terminal_penalty)
+        return TerminalReward(
+            reward=float(weights.no_claim_terminal_penalty),
+            breakdown=breakdown,
+            discovered=False,
+            correct_mass=False,
+            correct_channel=False,
+            correct_spin=False,
+        )
+
     truth = state.particle
 
     mass_score = _mass_score(truth.mass_gev, claim.mass_estimate_gev, claim.mass_uncertainty_gev)
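The new signature leans on `Optional`, which has to be imported from `typing` somewhere above this hunk (the import block is not part of the diff). A usage sketch of the no-claim path; `final_state` is a placeholder name.

```python
from typing import Optional  # required by the new signature

# Episode ended via max_steps / budget / time without ever submitting:
terminal = compute_terminal_reward(state=final_state, claim=None)
assert terminal.reward == -5.0        # flat no_claim_terminal_penalty
assert terminal.discovered is False   # nothing was graded
```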
@@ -366,6 +414,15 @@
 
     raw = breakdown.total * weights.terminal_scale
 
+    # Asymmetric claim cost (Fix #4). When the claim gets BOTH the mass
+    # and the decay channel right, add a flat bonus on top of the graded
+    # terminal so that a correct attempt is worth substantially more
+    # than the no-claim penalty (-5) and the wrong-claim penalty (~-1.85).
+    # This makes the bandit math: correct +10–12 ≫ no-claim −5 > wrong −2.
+    if mass_score >= 0.5 and channel_ok:
+        raw += weights.correct_claim_bonus
+        breakdown.add("correct_claim_bonus", weights.correct_claim_bonus)
+
     # Overconfident-wrong penalty: high confidence but wrong channel & far mass
     if claim.confidence >= 0.8 and (mass_score < 0.2 or not channel_ok):
         raw -= weights.overconfident_wrong_penalty
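Rough composition of a fully correct claim under these weights, confirming the +10–12 band quoted throughout. The 0.95 convex grade is an assumed near-perfect score; the grading terms themselves sit outside this diff.

```python
graded = 0.95 * 5.0    # breakdown.total (convex sum, assumed ~0.95) * terminal_scale
total = graded + 6.0   # plus correct_claim_bonus when mass AND channel are right
# total == 10.75, inside the "+10–12" range the weight comments advertise
```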
 