# cernenv-trainer / server/rewards/reward_function.py
# anugrahhu — commit d91fe20 (verified): "sft+reward-fix: server/rewards/reward_function.py"
"""Decomposed reward function.
Two stages:
1. **Per-step reward** ``compute_step_reward``: shapes behaviour with small
incentives (progress, evidence quality, valid prerequisites) and
penalties (rule violations, repeated work, wasted resources).
2. **Terminal reward** ``compute_terminal_reward``: graded only when the
agent submits a discovery claim or runs out of resources. Compares the
submitted claim against the hidden ``LatentParticle`` truth.
The terminal reward is intentionally dominant so the policy must care about
the *correct* discovery, not just looking busy.
Anti-reward-hacking design notes
--------------------------------
The shaping reward is layered with several independent checks so that
exploiting any single one alone cannot dominate the terminal grade
(see hackathon guidance: *"use multiple independent reward functions"*):
* ``tool_fit`` is **gated**: the agent only earns it when ``method`` is
in ``TOOL_REGISTRY`` *and* the tool's category matches the action's
expected category. Bogus method strings get **penalized**, not rewarded.
* ``valid_action`` is gated on a parsed structured action that the rules
engine accepts — pure JSON-shaped junk does not earn it.
* ``progress_milestone`` only fires on the *first* time a milestone is
unlocked, so re-doing already-completed steps cannot farm it.
* ``redundancy`` and ``repeat_action_penalty`` punish loops that re-emit
  the same action type; the repeat penalty starts at the second
  consecutive identical action and grows with the length of the run.
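* The terminal grade dominates total reward via ``terminal_scale``. The
  overconfident-wrong penalty fires on high-confidence claims with a wrong
  channel or a far-off mass, and a separate ``overclaim_significance_penalty``
  fires when the claimed *significance* exceeds what was actually measured.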
"""
from __future__ import annotations
from collections import deque
from dataclasses import dataclass, field
from typing import Deque, Dict, List, Optional
import numpy as np
from models import (
ActionType,
DiscoveryClaim,
ExperimentAction,
IntermediateOutput,
TOOL_REGISTRY,
is_recommended_tool,
)
from server.rules.engine import RuleResult, ViolationCode
from server.simulator.latent_state import FullLatentState
# ── Configuration ────────────────────────────────────────────────────────
@dataclass
class RewardWeights:
    """Coefficients for per-step shaping and terminal grading.

    Fields with negative defaults are penalties that the reward functions add
    directly. ``overconfident_wrong_penalty`` and
    ``overclaim_significance_penalty`` are positive magnitudes that
    ``compute_terminal_reward`` subtracts. Dataclass field order defines the
    positional constructor signature, so keep it stable.
    """

    # ── per-step shaping ────────────────────────────────────────
    valid_action: float = 0.05  # rules engine allowed it AND output.success
    progress_milestone: float = 0.25  # per milestone flag newly unlocked
    evidence_quality: float = 0.20  # scaled by output.quality_score on success
    # Cut to ~1/3 of original (was 0.10) to lower the per-step shaping
    # floor. Combined with a smaller step_reward_clip and a heavier
    # repeat-action penalty this prevents the agent from farming
    # +0.20+/step by cycling well-formed-but-inert tool calls.
    tool_fit: float = 0.033  # paid only on a method ∈ TOOL_REGISTRY
    # whose category matches the action.
    bogus_method_penalty: float = -0.05  # penalises method strings outside
    # TOOL_REGISTRY (anti-string-spam).
    # Was -0.08; bumped to -0.5 because the previous value was easily out-
    # earned by stacking format_bonus + valid_action + tool_fit. The
    # gating in compute_step_reward also now triggers from the *2nd*
    # consecutive identical action_type instead of the 3rd. It is multiplied
    # by the number of *preceding* consecutive identical actions.
    repeat_action_penalty: float = -0.5
    soft_violation: float = -0.05  # per soft violation other than REDUNDANT
    hard_violation: float = -0.50  # per hard rule violation
    redundancy: float = -0.10  # per REDUNDANT soft violation
    resource_overspend: float = -0.30  # per over-spent resource (budget/lumi/time)
    failure: float = -0.30  # applied when output.success is False
    # Hard cap on what a single shaping step can earn. Without this a
    # policy could in principle stack milestone + evidence_quality +
    # tool_fit + valid_action and approach the terminal reward magnitude.
    # Cut from 0.75 → 0.25 so the per-step shaping floor cannot exceed
    # ~1/3 of the wrong-claim terminal penalty.
    # NOTE(review): against the ~-1.85 wrong-claim value quoted below, 0.25
    # is closer to 1/7; confirm which reference the "~1/3" was meant for.
    # Only the upper bound is configurable (compute_step_reward uses a fixed
    # lower bound of -10); a value <= 0 disables clipping entirely.
    step_reward_clip: float = 0.25
    # ── terminal grading ────────────────────────────────────────
    terminal_scale: float = 5.0  # multiplied with the convex sum below
    # The seven graded weights below sum to 1.0, so before scaling the graded
    # part of the terminal reward lies in [0, 1].
    mass_calibration: float = 0.30
    significance_quality: float = 0.20
    channel_correctness: float = 0.20
    spin_correctness: float = 0.10
    width_calibration: float = 0.05
    confidence_calibration: float = 0.10
    efficiency_bonus: float = 0.05
    overconfident_wrong_penalty: float = 4.0  # subtracted from terminal
    overclaim_significance_penalty: float = 1.5  # claim_sigma >> measured_sigma
    # Big bonus for getting BOTH mass and channel right, on top of the
    # terminal grade. Makes the bandit math strictly favour attempting a
    # claim when uncertain rather than running out the clock: a correct
    # claim now returns ~+10–12, a wrong one ~−1.85, no claim ~−5
    # (approximate figures from the author's runs; a perfect graded score is
    # 1.0 * terminal_scale + correct_claim_bonus = 11).
    correct_claim_bonus: float = 6.0
    # Penalty applied at episode end when the trajectory never even
    # *attempted* a SUBMIT_DISCOVERY_CLAIM. Defeats the "hide forever and
    # farm shaping" reward hack we observed in v1 (mean +0.22/step over
    # ~12 steps was a better deal than risking the wrong-claim penalty).
    no_claim_terminal_penalty: float = -5.0
# ── Outputs ──────────────────────────────────────────────────────────────
@dataclass
class RewardBreakdown:
    """Named reward components plus their running sum.

    ``components`` maps each component name to its accumulated value.
    ``total`` is the sum of every value ever passed to :meth:`add`.
    """

    components: Dict[str, float] = field(default_factory=dict)
    total: float = 0.0

    def add(self, key: str, value: float) -> None:
        """Accumulate ``value`` under ``key`` and into the running ``total``."""
        previous = self.components.get(key, 0.0)
        self.components[key] = previous + value
        self.total += value
@dataclass
class StepReward:
    """Per-step shaping result produced by ``compute_step_reward``."""

    reward: float  # component sum after the optional step clip; may differ from breakdown.total
    breakdown: RewardBreakdown  # unclipped per-component contributions
@dataclass
class TerminalReward:
    """End-of-episode grade produced by ``compute_terminal_reward``."""

    # Final terminal reward. It generally differs from breakdown.total
    # because the graded components are multiplied by ``terminal_scale``
    # before bonuses/penalties are applied.
    reward: float
    breakdown: RewardBreakdown  # unscaled graded components plus bonus/penalty terms
    discovered: bool  # mass score >= 0.5, correct channel, claimed sigma >= 4.5
    correct_mass: bool  # mass score >= 0.5
    correct_channel: bool  # claimed decay channel equals the true primary channel
    correct_spin: bool  # a spin hypothesis was given and equals the true spin
# ── Per-step ─────────────────────────────────────────────────────────────
# Attribute names on ``state.progress`` that count as pipeline milestones.
# ``_milestone_progress`` counts each flag that goes from falsy to truthy
# within a single step; each such flag earns one ``progress_milestone``.
_PROGRESS_FLAGS = [
    "beam_configured",
    "luminosity_allocated",
    "trigger_set",
    "collisions_collected",
    "channel_selected",
    "tracks_reconstructed",
    "detector_calibrated",
    "invariant_mass_built",
    "background_subtracted",
    "resonance_fitted",
    "significance_estimated",
]
def _milestone_progress(state_before: FullLatentState, state_after: FullLatentState) -> int:
    """Count the ``_PROGRESS_FLAGS`` that were newly unlocked this step.

    A flag counts when it is falsy on ``state_before.progress`` and truthy
    on ``state_after.progress``. Flags that were already set contribute
    nothing, so redoing finished work cannot re-earn a milestone.
    """
    transitions = (
        (getattr(state_before.progress, flag), getattr(state_after.progress, flag))
        for flag in _PROGRESS_FLAGS
    )
    return sum(1 for was_set, is_set in transitions if is_set and not was_set)
def _consecutive_repeat_count(
    history: List, action_type: ActionType, look_back: int = 6
) -> int:
    """Length of the trailing run of ``action_type`` at the end of ``history``.

    Only the last ``look_back`` records are inspected (for a positive
    ``look_back``). ``history`` holds the steps *before* the just-applied
    action, so that action is not counted. Records without an
    ``action_type`` attribute compare as ``None``. Used to dampen loops.
    """
    if not history:
        return 0
    window = history[-look_back:]
    run_length = 0
    for record in reversed(window):
        matches = getattr(record, "action_type", None) == action_type
        if not matches:
            break
        run_length += 1
    return run_length
def compute_step_reward(
    *,
    action: ExperimentAction,
    output: IntermediateOutput,
    state_before: FullLatentState,
    state_after: FullLatentState,
    rule_result: RuleResult,
    weights: Optional[RewardWeights] = None,
    history: Optional[List] = None,
) -> StepReward:
    """Compute the per-step shaping reward.

    Args:
        action: the structured action that was just applied.
        output: the simulator output for ``action``.
        state_before: latent state before the step (milestone deltas).
        state_after: latent state after the step (milestone deltas and
            resource-overspend checks).
        rule_result: rules-engine verdict (hard and soft violations).
        weights: reward coefficients. ``None`` (the default) means a fresh
            ``RewardWeights()``; creating it per call avoids one mutable
            default instance being shared by every call.
        history: the list of ``PipelineStepRecord`` *before* this step. It
            is used only to detect consecutive-repeat loops (e.g. a model
            spamming the same action_type to farm shaping). ``None`` is
            treated as empty.

    Returns:
        ``StepReward``. Its ``reward`` is the component sum clipped to
        ``[-10, weights.step_reward_clip]``, with no clipping when the clip
        is <= 0. Its ``breakdown`` holds the *unclipped* component values.
    """
    if weights is None:
        weights = RewardWeights()
    breakdown = RewardBreakdown()

    # ── basic validity / failure ────────────────────────────────────
    if rule_result.allowed and output.success:
        breakdown.add("valid_action", weights.valid_action)
    if not output.success:
        breakdown.add("failure", weights.failure)

    # ── milestone progress (one-shot per flag, anti-farming) ────────
    new_milestones = _milestone_progress(state_before, state_after)
    if new_milestones > 0:
        breakdown.add("progress", weights.progress_milestone * new_milestones)

    # ── evidence quality ────────────────────────────────────────────
    if output.success:
        breakdown.add(
            "evidence_quality",
            weights.evidence_quality * float(output.quality_score),
        )

    # ── tool fit: gated on TOOL_REGISTRY membership + category match ─
    # Bogus method strings are explicitly penalised so the model can't farm
    # shaping reward by setting method='whatever'.
    if action.method:
        if is_recommended_tool(action.action_type, action.method):
            breakdown.add("tool_fit", weights.tool_fit)
        elif action.method not in TOOL_REGISTRY:
            breakdown.add("bogus_method", weights.bogus_method_penalty)
        # If the tool exists but the category doesn't match the action,
        # we silently award nothing (no penalty, no reward).

    # ── rule penalties ──────────────────────────────────────────────
    if rule_result.violations:
        breakdown.add(
            "hard_violation",
            weights.hard_violation * len(rule_result.violations),
        )
    if rule_result.soft_violations:
        soft_redundant = sum(
            1 for v in rule_result.soft_violations if v == ViolationCode.REDUNDANT
        )
        soft_other = len(rule_result.soft_violations) - soft_redundant
        if soft_redundant:
            breakdown.add("redundancy", weights.redundancy * soft_redundant)
        if soft_other:
            breakdown.add("soft_violation", weights.soft_violation * soft_other)

    # ── consecutive-repeat penalty (catches loop hacks) ─────────────
    # ``repeats`` counts the *preceding* consecutive records with this
    # action_type. The penalty therefore starts at the 2nd identical action
    # in a row and scales linearly: the 4th in a row gets 3× the base
    # penalty. The helper's default look-back window caps the multiplier at
    # 6. (v1 found that a tiny -0.08 was easily out-earned by the
    # +0.22/step shaping floor.)
    repeats = _consecutive_repeat_count(history or [], action.action_type)
    if repeats >= 1:
        breakdown.add(
            "repeat_action",
            weights.repeat_action_penalty * repeats,
        )

    # ── resource overspend ──────────────────────────────────────────
    res = state_after.resources
    if res.budget_used_musd > res.budget_total_musd:
        breakdown.add("budget_overspend", weights.resource_overspend)
    if res.luminosity_used_fb > res.luminosity_total_fb:
        breakdown.add("lumi_overspend", weights.resource_overspend)
    if res.time_used_days > res.time_limit_days:
        breakdown.add("time_overspend", weights.resource_overspend)

    # ── total + soft cap ────────────────────────────────────────────
    # Only the returned reward is clipped; the breakdown keeps raw values
    # so logs show what each component contributed.
    total = float(breakdown.total)
    if weights.step_reward_clip > 0:
        total = float(np.clip(total, -10.0, weights.step_reward_clip))
    return StepReward(reward=total, breakdown=breakdown)
# ── Terminal grading ─────────────────────────────────────────────────────
def _mass_score(
    true_mass: float,
    claim_mass: Optional[float],
    unc: Optional[float],
    max_unc_factor: Optional[float] = 3.0,
) -> float:
    """Score a claimed mass against the true mass, in ``[0, 1]``.

    The base tolerance is ``max(1 GeV, 1% of true_mass)``. A positive
    claimed uncertainty ``unc`` may widen it, but only up to
    ``max_unc_factor`` times the base tolerance. Without that cap, a claim
    with a huge (or infinite) uncertainty would score 1.0 for *any* mass,
    which also unlocks ``correct_claim_bonus`` in the terminal grade. Pass
    ``max_unc_factor=None`` to restore the uncapped legacy behaviour.

    Returns 1.0 when ``|claim - true| <= tol`` and decays linearly to 0.0 at
    ``5 * tol``. Returns 0.0 for a missing or non-finite claim, or for a
    non-positive true mass.
    """
    if claim_mass is None or true_mass <= 0 or not np.isfinite(claim_mass):
        return 0.0
    err = abs(claim_mass - true_mass)
    base_tol = max(1.0, 0.01 * true_mass)
    tol = base_tol
    # ``unc > 0`` is False for NaN, so a NaN uncertainty is simply ignored.
    if unc is not None and unc > 0:
        widened = float(unc)
        if max_unc_factor is not None:
            widened = min(widened, max_unc_factor * base_tol)
        tol = max(tol, widened)
    if err <= tol:
        return 1.0
    if err >= 5 * tol:
        return 0.0
    return float(np.clip(1.0 - (err - tol) / (4 * tol), 0.0, 1.0))
def _significance_score(state: FullLatentState, claim_sigma: Optional[float]) -> float:
    """Grade a claimed significance against the measured one, in ``[0, 1]``.

    Credit comes from the *measured* evidence: ``measured / 5``, clipped to
    ``[0, 1]``, so a measured 5σ earns full credit. Claiming more than was
    measured costs up to 0.5, and that penalty saturates at a 3σ over-claim.
    Just writing '50' in the field is a classic reward hack, hence the
    penalty. Under-claiming is not penalised. A missing claim scores 0.
    """
    measured_sigma = state.progress.best_significance_sigma or 0.0
    if claim_sigma is None:
        return 0.0
    evidence_credit = float(np.clip(measured_sigma / 5.0, 0.0, 1.0))
    excess_sigma = max(0.0, claim_sigma - measured_sigma)
    overclaim_frac = float(np.clip(excess_sigma / 3.0, 0.0, 1.0))
    return float(np.clip(evidence_credit - 0.5 * overclaim_frac, 0.0, 1.0))
def _significance_overclaim(
    state: FullLatentState, claim_sigma: Optional[float], threshold_sigma: float = 1.5
) -> float:
    """Return how far, in σ, the claim overshoots the measured significance.

    The result is ``max(0, claim - measured - threshold_sigma)``, with a
    missing measurement treated as 0σ and a missing claim giving 0.
    Deliberately separate from ``_significance_score`` so a giant
    over-claim cannot be offset by a slightly better mass estimate.
    """
    if claim_sigma is None:
        return 0.0
    measured_sigma = state.progress.best_significance_sigma or 0.0
    excess = claim_sigma - measured_sigma - threshold_sigma
    return float(max(0.0, excess))
def _confidence_calibration(claim_conf: float, mass_score: float, channel_correct: bool) -> float:
    """Score how closely the claimed confidence tracks realised accuracy.

    Realised accuracy is an equal-weight blend of ``mass_score`` and a 0/1
    channel indicator. The score is ``1 - |accuracy - claim_conf|``,
    clipped to ``[0, 1]``.
    """
    channel_term = 1.0 if channel_correct else 0.0
    realised = 0.5 * mass_score + 0.5 * channel_term
    gap = abs(realised - claim_conf)
    return float(np.clip(1.0 - gap, 0.0, 1.0))
def _efficiency_bonus(state: FullLatentState) -> float:
    """Mean fraction of budget, luminosity and time left unused, in ``[0, 1]``.

    Each resource contributes ``remaining / total`` clipped to ``[0, 1]``,
    so over-spent resources contribute 0. A resource whose total is zero or
    negative has no leftover to reward; it contributes 0 instead of raising
    ``ZeroDivisionError``. The result encourages succinct experiments.
    """
    res = state.resources
    pairs = (
        (res.budget_remaining, res.budget_total_musd),
        (res.luminosity_remaining, res.luminosity_total_fb),
        (res.time_remaining, res.time_limit_days),
    )
    score = 0.0
    for remaining, total in pairs:
        if total > 0:
            score += float(np.clip(remaining / total, 0.0, 1.0))
    return float(score / 3.0)
def compute_terminal_reward(
    *,
    state: FullLatentState,
    claim: Optional[DiscoveryClaim],
    weights: Optional[RewardWeights] = None,
) -> TerminalReward:
    """Grade the end-of-episode submission.

    ``claim`` is ``None`` when the episode terminated by *any* reason
    other than a ``submit_discovery_claim`` action (max_steps, budget
    exhausted, time exhausted) AND the trajectory never attempted to
    submit a claim. In that case we return a flat
    ``no_claim_terminal_penalty`` so the bandit math always favours
    *attempting* a claim over hiding forever to farm per-step shaping.
    See: v1 (anugrahhu/cernenv-grpo-smollm2-360m) which exploited this
    exact gap by spamming request_systematics for ~+0.22/step instead
    of risking the wrong-claim penalty (~−1.85).

    ``weights`` defaults to a fresh ``RewardWeights()`` per call rather than
    one shared mutable default instance.

    Claim fields are model-generated, so NaN values are neutralised before
    they can turn the whole reward into NaN:

    * a NaN mass score counts as 0 (no mass credit);
    * a non-finite width estimate earns no width credit;
    * a NaN confidence is graded as 1.0 (maximally confident). It is then
      scored as badly calibrated and cannot dodge the overconfident-wrong
      penalty, because ``NaN >= 0.8`` would be False.
    """
    if weights is None:
        weights = RewardWeights()
    breakdown = RewardBreakdown()
    if claim is None:
        breakdown.add("no_claim_terminal_penalty", weights.no_claim_terminal_penalty)
        return TerminalReward(
            reward=float(weights.no_claim_terminal_penalty),
            breakdown=breakdown,
            discovered=False,
            correct_mass=False,
            correct_channel=False,
            correct_spin=False,
        )

    truth = state.particle

    mass_score = _mass_score(truth.mass_gev, claim.mass_estimate_gev, claim.mass_uncertainty_gev)
    if not np.isfinite(mass_score):
        mass_score = 0.0
    breakdown.add("mass_calibration", weights.mass_calibration * mass_score)

    sig_score = _significance_score(state, claim.significance_sigma)
    breakdown.add("significance_quality", weights.significance_quality * sig_score)

    channel_ok = claim.decay_channel == truth.primary_channel
    breakdown.add("channel_correctness", weights.channel_correctness * (1.0 if channel_ok else 0.0))

    spin_ok = claim.spin_hypothesis is not None and claim.spin_hypothesis == truth.spin
    breakdown.add("spin_correctness", weights.spin_correctness * (1.0 if spin_ok else 0.0))

    width_score = 0.0
    width_claim = claim.width_estimate_gev
    if width_claim is not None and np.isfinite(width_claim) and truth.width_gev > 0:
        rel = abs(width_claim - truth.width_gev) / max(truth.width_gev, 1e-3)
        width_score = float(np.clip(1.0 - rel, 0.0, 1.0))
    breakdown.add("width_calibration", weights.width_calibration * width_score)

    confidence = claim.confidence
    if confidence is not None and np.isnan(confidence):
        confidence = 1.0
    conf_score = _confidence_calibration(confidence, mass_score, channel_ok)
    breakdown.add("confidence_calibration", weights.confidence_calibration * conf_score)

    eff_score = _efficiency_bonus(state)
    breakdown.add("efficiency_bonus", weights.efficiency_bonus * eff_score)

    # Note: uses the *claimed* significance, not the measured one.
    discovered = (
        mass_score >= 0.5
        and channel_ok
        and (claim.significance_sigma or 0.0) >= 4.5
    )

    raw = breakdown.total * weights.terminal_scale

    # Asymmetric claim cost (Fix #4). When the claim gets BOTH the mass
    # and the decay channel right, add a flat bonus on top of the graded
    # terminal so that a correct attempt is worth substantially more
    # than the no-claim penalty (-5) and the wrong-claim penalty (~-1.85).
    # This makes the bandit math: correct +10–12 ≫ no-claim −5 > wrong −2.
    if mass_score >= 0.5 and channel_ok:
        raw += weights.correct_claim_bonus
        breakdown.add("correct_claim_bonus", weights.correct_claim_bonus)

    # Overconfident-wrong penalty: high confidence with EITHER a far-off
    # mass OR the wrong channel.
    if confidence >= 0.8 and (mass_score < 0.2 or not channel_ok):
        raw -= weights.overconfident_wrong_penalty
        breakdown.add("overconfident_wrong", -weights.overconfident_wrong_penalty)

    # Significance-overclaim penalty (anti-reward-hacking): discourages the
    # model from just writing a giant σ in the claim regardless of evidence.
    # Saturates at 2× the base penalty (an overclaim of 6σ beyond tolerance).
    overclaim_sigma = _significance_overclaim(state, claim.significance_sigma)
    if overclaim_sigma > 0:
        pen = weights.overclaim_significance_penalty * float(
            np.clip(overclaim_sigma / 3.0, 0.0, 2.0)
        )
        raw -= pen
        breakdown.add("overclaim_significance", -pen)

    # If the claim has no mass and zero/None significance, treat it as a
    # "no-information" submission: clamp the raw reward so the model can't
    # pass the rules engine and then submit garbage to end early.
    if (claim.mass_estimate_gev is None) and (claim.significance_sigma in (None, 0.0)):
        raw = float(min(raw, 0.0))
        breakdown.add("no_information_claim", 0.0)

    return TerminalReward(
        reward=float(raw),
        breakdown=breakdown,
        discovered=discovered,
        correct_mass=mass_score >= 0.5,
        correct_channel=channel_ok,
        correct_spin=spin_ok,
    )
# Public API of this module. The ``_``-prefixed scoring helpers and
# ``_PROGRESS_FLAGS`` are internal and intentionally not exported.
__all__ = [
    "RewardBreakdown",
    "RewardWeights",
    "StepReward",
    "TerminalReward",
    "compute_step_reward",
    "compute_terminal_reward",
]