| """Decomposed reward function. | |
| Two stages: | |
| 1. **Per-step reward** ``compute_step_reward``: shapes behaviour with small | |
| incentives (progress, evidence quality, valid prerequisites) and | |
| penalties (rule violations, repeated work, wasted resources). | |
| 2. **Terminal reward** ``compute_terminal_reward``: graded only when the | |
| agent submits a discovery claim or runs out of resources. Compares the | |
| submitted claim against the hidden ``LatentParticle`` truth. | |
| The terminal reward is intentionally dominant so the policy must care about | |
| the *correct* discovery, not just looking busy. | |
| Anti-reward-hacking design notes | |
| -------------------------------- | |
| The shaping reward is layered with several independent checks so that | |
| exploiting any single one alone cannot dominate the terminal grade | |
| (see hackathon guidance: *"use multiple independent reward functions"*): | |
| * ``tool_fit`` is **gated**: the agent only earns it when ``method`` is | |
| in ``TOOL_REGISTRY`` *and* the tool's category matches the action's | |
| expected category. Bogus method strings get **penalized**, not rewarded. | |
| * ``valid_action`` is gated on a parsed structured action that the rules | |
| engine accepts — pure JSON-shaped junk does not earn it. | |
| * ``progress_milestone`` only fires on the *first* time a milestone is | |
| unlocked, so re-doing already-completed steps cannot farm it. | |
| * ``redundancy`` and the new ``repeat_action_penalty`` punish loops that | |
| re-emit the same action type many times in a row. | |
| * The terminal grade dominates total reward via ``terminal_scale``, and | |
| the overconfident-wrong penalty also fires when the claim *significance* | |
| exceeds what was actually measured. | |
| """ | |

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Dict, List, Optional

import numpy as np

from models import (
    ActionType,
    DiscoveryClaim,
    ExperimentAction,
    IntermediateOutput,
    TOOL_REGISTRY,
    is_recommended_tool,
)
from server.rules.engine import RuleResult, ViolationCode
from server.simulator.latent_state import FullLatentState
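
# Usage sketch (illustrative, not part of this module's API). A training
# loop is expected to call the two entry points roughly like this; the
# names ``prev_state``, ``next_state``, ``step_records``, ``final_state``
# and ``steps`` are hypothetical caller-side variables:
#
#     step = compute_step_reward(
#         action=action,
#         output=output,
#         state_before=prev_state,
#         state_after=next_state,
#         rule_result=rule_result,
#         history=step_records,  # records *before* this step
#     )
#     terminal = compute_terminal_reward(state=final_state, claim=claim)
#     episode_return = sum(s.reward for s in steps) + terminal.reward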

# ── Configuration ────────────────────────────────────────────────────────


@dataclass
class RewardWeights:
    # ── per-step shaping ──────────────────────────────────────
    valid_action: float = 0.05
    progress_milestone: float = 0.25
    evidence_quality: float = 0.20
    # Cut to ~1/3 of the original value (was 0.10) to lower the per-step
    # shaping floor. Combined with a smaller step_reward_clip and a heavier
    # repeat-action penalty this prevents the agent from farming ~+0.20/step
    # by cycling well-formed-but-inert tool calls.
    tool_fit: float = 0.033  # paid only on a method ∈ TOOL_REGISTRY
    # whose category matches the action.
    bogus_method_penalty: float = -0.05  # penalises method strings outside
    # TOOL_REGISTRY (anti-string-spam).
    # Was -0.08; bumped to -0.5 because the previous value was easily out-
    # earned by stacking format_bonus + valid_action + tool_fit. The
    # gating in compute_step_reward also now triggers from the *2nd*
    # consecutive identical action_type instead of the 3rd.
    repeat_action_penalty: float = -0.5
    soft_violation: float = -0.05
    hard_violation: float = -0.50
    redundancy: float = -0.10
    resource_overspend: float = -0.30
    failure: float = -0.30
    # Hard cap on what a single shaping step can earn. Without this a
    # policy could in principle stack milestone + evidence_quality +
    # tool_fit + valid_action and approach the terminal reward magnitude.
    # Cut from 0.75 → 0.25 so the per-step shaping floor cannot exceed
    # ~1/3 of the wrong-claim terminal penalty.
    step_reward_clip: float = 0.25

    # ── terminal grading ──────────────────────────────────────
    terminal_scale: float = 5.0  # multiplied with the convex sum below
    mass_calibration: float = 0.30
    significance_quality: float = 0.20
    channel_correctness: float = 0.20
    spin_correctness: float = 0.10
    width_calibration: float = 0.05
    confidence_calibration: float = 0.10
    efficiency_bonus: float = 0.05
    overconfident_wrong_penalty: float = 4.0  # subtracted from terminal
    overclaim_significance_penalty: float = 1.5  # claim_sigma >> measured_sigma
    # Big bonus for getting BOTH mass and channel right, on top of the
    # terminal grade. Makes the bandit math strictly favour attempting a
    # claim when uncertain rather than running out the clock: a correct
    # claim now returns ~+10–12, a wrong one ~−1.85, no claim ~−5.
    correct_claim_bonus: float = 6.0
    # Penalty applied at episode end when the trajectory never even
    # *attempted* a SUBMIT_DISCOVERY_CLAIM. Defeats the "hide forever and
    # farm shaping" reward hack we observed in v1 (mean +0.22/step over
    # ~12 steps was a better deal than risking the wrong-claim penalty).
    no_claim_terminal_penalty: float = -5.0
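
# Sketch: because RewardWeights is a plain dataclass, ablations can override
# individual weights at construction time (the values below are illustrative,
# not tuned defaults):
#
#     strict = RewardWeights(step_reward_clip=0.10, repeat_action_penalty=-1.0)
#     step = compute_step_reward(..., weights=strict)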

# ── Outputs ──────────────────────────────────────────────────────────────


@dataclass
class RewardBreakdown:
    components: Dict[str, float] = field(default_factory=dict)
    total: float = 0.0

    def add(self, key: str, value: float) -> None:
        self.components[key] = self.components.get(key, 0.0) + value
        self.total += value


@dataclass
class StepReward:
    reward: float
    breakdown: RewardBreakdown


@dataclass
class TerminalReward:
    reward: float
    breakdown: RewardBreakdown
    discovered: bool
    correct_mass: bool
    correct_channel: bool
    correct_spin: bool

# ── Per-step ─────────────────────────────────────────────────────────────

_PROGRESS_FLAGS = [
    "beam_configured",
    "luminosity_allocated",
    "trigger_set",
    "collisions_collected",
    "channel_selected",
    "tracks_reconstructed",
    "detector_calibrated",
    "invariant_mass_built",
    "background_subtracted",
    "resonance_fitted",
    "significance_estimated",
]


def _milestone_progress(state_before: FullLatentState, state_after: FullLatentState) -> int:
    """Number of new progress milestones unlocked this step."""
    delta = 0
    for flag in _PROGRESS_FLAGS:
        was = getattr(state_before.progress, flag)
        now = getattr(state_after.progress, flag)
        if now and not was:
            delta += 1
    return delta


def _consecutive_repeat_count(
    history: List, action_type: ActionType, look_back: int = 6
) -> int:
    """How many times this action_type appeared *consecutively* most recently
    (excluding the just-applied action). Used to dampen loops.
    """
    if not history:
        return 0
    n = 0
    for rec in reversed(history[-look_back:]):
        if getattr(rec, "action_type", None) == action_type:
            n += 1
        else:
            break
    return n
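
# Worked example (hypothetical history): if the ``action_type`` fields of the
# records end ..., A, B, B and the incoming action is of type B, the count is
# 2, so the incoming action is the 3rd in a row and compute_step_reward below
# charges it 2× the base repeat_action_penalty.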

def compute_step_reward(
    *,
    action: ExperimentAction,
    output: IntermediateOutput,
    state_before: FullLatentState,
    state_after: FullLatentState,
    rule_result: RuleResult,
    weights: RewardWeights = RewardWeights(),
    history: Optional[List] = None,
) -> StepReward:
    """Compute the per-step shaping reward.

    ``history`` is the list of ``PipelineStepRecord`` *before* this step. We
    use it to detect consecutive-repeat loops (e.g. a model spamming the
    same action_type to farm shaping). All other fields are local.
    """
    breakdown = RewardBreakdown()

    # ── basic validity / failure ────────────────────────────────────
    if rule_result.allowed and output.success:
        breakdown.add("valid_action", weights.valid_action)
    if not output.success:
        breakdown.add("failure", weights.failure)

    # ── milestone progress (one-shot per flag, anti-farming) ────────
    new_milestones = _milestone_progress(state_before, state_after)
    if new_milestones > 0:
        breakdown.add("progress", weights.progress_milestone * new_milestones)

    # ── evidence quality ────────────────────────────────────────────
    if output.success:
        breakdown.add(
            "evidence_quality",
            weights.evidence_quality * float(output.quality_score),
        )

    # ── tool fit: gated on TOOL_REGISTRY membership + category match ─
    # Bogus or mismatched method strings are explicitly penalised so the
    # model can't farm shaping reward by setting method='whatever'.
    if action.method:
        if is_recommended_tool(action.action_type, action.method):
            breakdown.add("tool_fit", weights.tool_fit)
        elif action.method not in TOOL_REGISTRY:
            breakdown.add("bogus_method", weights.bogus_method_penalty)
        # If the tool exists but the category doesn't match the action,
        # we silently award nothing (no penalty, no reward).

    # ── rule penalties ──────────────────────────────────────────────
    if rule_result.violations:
        breakdown.add(
            "hard_violation",
            weights.hard_violation * len(rule_result.violations),
        )
    if rule_result.soft_violations:
        soft_redundant = sum(
            1 for v in rule_result.soft_violations if v == ViolationCode.REDUNDANT
        )
        soft_other = len(rule_result.soft_violations) - soft_redundant
        if soft_redundant:
            breakdown.add("redundancy", weights.redundancy * soft_redundant)
        if soft_other:
            breakdown.add("soft_violation", weights.soft_violation * soft_other)

    # ── consecutive-repeat penalty (catches loop hacks) ─────────────
    # Triggers from the *2nd* identical action in a row (previously it
    # only kicked in at the 3rd). The multiplier escalates with the run
    # length: the n-th identical action in a row is penalised at
    # (n - 1)× the base amount, up to the 6-step look-back. This matters
    # because v1 found that a tiny -0.08 was easily out-earned by the
    # +0.22/step shaping floor.
    repeats = _consecutive_repeat_count(history or [], action.action_type)
    if repeats >= 1:
        breakdown.add(
            "repeat_action",
            weights.repeat_action_penalty * repeats,
        )

    # ── resource overspend ──────────────────────────────────────────
    res = state_after.resources
    if res.budget_used_musd > res.budget_total_musd:
        breakdown.add("budget_overspend", weights.resource_overspend)
    if res.luminosity_used_fb > res.luminosity_total_fb:
        breakdown.add("lumi_overspend", weights.resource_overspend)
    if res.time_used_days > res.time_limit_days:
        breakdown.add("time_overspend", weights.resource_overspend)

    # ── total + soft cap ────────────────────────────────────────────
    total = float(breakdown.total)
    if weights.step_reward_clip > 0:
        total = float(np.clip(total, -10.0, weights.step_reward_clip))
    return StepReward(reward=total, breakdown=breakdown)
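
# Worked clip example under the default weights (sketch): a step that earns a
# milestone (+0.25), valid_action (+0.05) and perfect evidence quality (+0.20)
# sums to +0.50, but step_reward_clip caps the paid reward at +0.25. The
# unclipped components remain visible in ``breakdown`` for logging.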

# ── Terminal grading ─────────────────────────────────────────────────────


def _mass_score(true_mass: float, claim_mass: Optional[float], unc: Optional[float]) -> float:
    """1.0 within the tolerance, decaying smoothly to 0 by 5× the tolerance,
    where tolerance = max(1 GeV, 1% of the true mass, the claimed uncertainty).
    """
    if claim_mass is None or true_mass <= 0:
        return 0.0
    err = abs(claim_mass - true_mass)
    # Tolerance: max(1.0 GeV, 1% of true mass, claimed unc)
    tol = max(1.0, 0.01 * true_mass)
    if unc is not None and unc > 0:
        tol = max(tol, float(unc))
    if err <= tol:
        return 1.0
    if err >= 5 * tol:
        return 0.0
    return float(np.clip(1.0 - (err - tol) / (4 * tol), 0.0, 1.0))
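
# Worked example (hypothetical numbers): true_mass=125.0, claim_mass=126.5,
# unc=None. Then tol = max(1.0, 1.25) = 1.25 GeV and err = 1.5 GeV, so the
# score is 1.0 - (1.5 - 1.25) / (4 * 1.25) = 0.95.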

def _significance_score(state: FullLatentState, claim_sigma: Optional[float]) -> float:
    """High score when claimed σ matches measured σ and is ≥ 5.

    A claim_sigma far above the measured significance is a classic
    reward-hacking pattern (just write '50' in the field), so we penalise
    over-claiming proportionally.
    """
    measured = state.progress.best_significance_sigma or 0.0
    if claim_sigma is None:
        return 0.0
    over_claim = max(0.0, claim_sigma - measured)
    base = float(np.clip(measured / 5.0, 0.0, 1.0))
    penalty = float(np.clip(over_claim / 3.0, 0.0, 1.0))
    return float(np.clip(base - 0.5 * penalty, 0.0, 1.0))
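
# Worked example (hypothetical numbers): measured σ = 4.0 but the claim says
# σ = 9.0. Then base = 4.0/5.0 = 0.8, the over-claim penalty saturates at
# 1.0, and the score drops to 0.8 - 0.5 = 0.3; an honest σ = 4.0 claim
# would score 0.8.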

def _significance_overclaim(
    state: FullLatentState, claim_sigma: Optional[float], threshold_sigma: float = 1.5
) -> float:
    """How many σ the claim *exceeds* what the env actually measured.

    Used as an extra penalty — distinct from ``_significance_score`` —
    so that a model can't compensate for a giant over-claim by getting
    the mass slightly more accurate. Returns ``max(0, claim - measured - τ)``
    with ``τ = threshold_sigma``.
    """
    if claim_sigma is None:
        return 0.0
    measured = state.progress.best_significance_sigma or 0.0
    return float(max(0.0, claim_sigma - measured - threshold_sigma))
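
# Worked example (hypothetical numbers): measured σ = 3.0, claimed σ = 6.0
# returns 6.0 - 3.0 - 1.5 = 1.5; claims within 1.5σ of the measurement
# return 0.0 and incur no extra penalty.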

def _confidence_calibration(claim_conf: float, mass_score: float, channel_correct: bool) -> float:
    """Reward agents whose confidence tracks their actual accuracy."""
    actual = 0.5 * mass_score + 0.5 * (1.0 if channel_correct else 0.0)
    err = abs(actual - claim_conf)
    return float(np.clip(1.0 - err, 0.0, 1.0))


def _efficiency_bonus(state: FullLatentState) -> float:
    """Reward leftover budget (encourages succinct experiments)."""
    res = state.resources
    score = 0.0
    score += np.clip(res.budget_remaining / res.budget_total_musd, 0.0, 1.0)
    score += np.clip(res.luminosity_remaining / res.luminosity_total_fb, 0.0, 1.0)
    score += np.clip(res.time_remaining / res.time_limit_days, 0.0, 1.0)
    return float(score / 3.0)
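
# Worked example (hypothetical numbers): an episode that ends with half its
# budget, half its luminosity and half its time left scores
# (0.5 + 0.5 + 0.5) / 3 = 0.5, worth 0.5 * efficiency_bonus in the breakdown
# before terminal_scale is applied.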

def compute_terminal_reward(
    *,
    state: FullLatentState,
    claim: Optional[DiscoveryClaim],
    weights: RewardWeights = RewardWeights(),
) -> TerminalReward:
    """Grade the end-of-episode submission.

    ``claim`` is ``None`` when the episode terminated for *any* reason
    other than a ``submit_discovery_claim`` action (max_steps, budget
    exhausted, time exhausted) AND the trajectory never attempted to
    submit a claim. In that case we return a flat
    ``no_claim_terminal_penalty`` so the bandit math always favours
    *attempting* a claim over hiding forever to farm per-step shaping.

    See: v1 (anugrahhu/cernenv-grpo-smollm2-360m), which exploited this
    exact gap by spamming request_systematics for ~+0.22/step instead
    of risking the wrong-claim penalty (~−1.85).
    """
    breakdown = RewardBreakdown()
    if claim is None:
        breakdown.add("no_claim_terminal_penalty", weights.no_claim_terminal_penalty)
        return TerminalReward(
            reward=float(weights.no_claim_terminal_penalty),
            breakdown=breakdown,
            discovered=False,
            correct_mass=False,
            correct_channel=False,
            correct_spin=False,
        )

    truth = state.particle
    mass_score = _mass_score(truth.mass_gev, claim.mass_estimate_gev, claim.mass_uncertainty_gev)
    breakdown.add("mass_calibration", weights.mass_calibration * mass_score)

    sig_score = _significance_score(state, claim.significance_sigma)
    breakdown.add("significance_quality", weights.significance_quality * sig_score)

    channel_ok = claim.decay_channel == truth.primary_channel
    breakdown.add("channel_correctness", weights.channel_correctness * (1.0 if channel_ok else 0.0))

    spin_ok = claim.spin_hypothesis is not None and claim.spin_hypothesis == truth.spin
    breakdown.add("spin_correctness", weights.spin_correctness * (1.0 if spin_ok else 0.0))

    width_score = 0.0
    if claim.width_estimate_gev is not None and truth.width_gev > 0:
        rel = abs(claim.width_estimate_gev - truth.width_gev) / max(truth.width_gev, 1e-3)
        width_score = float(np.clip(1.0 - rel, 0.0, 1.0))
    breakdown.add("width_calibration", weights.width_calibration * width_score)

    conf_score = _confidence_calibration(claim.confidence, mass_score, channel_ok)
    breakdown.add("confidence_calibration", weights.confidence_calibration * conf_score)

    eff_score = _efficiency_bonus(state)
    breakdown.add("efficiency_bonus", weights.efficiency_bonus * eff_score)

    discovered = (
        mass_score >= 0.5
        and channel_ok
        and (claim.significance_sigma or 0.0) >= 4.5
    )

    raw = breakdown.total * weights.terminal_scale

    # Asymmetric claim cost (Fix #4). When the claim gets BOTH the mass
    # and the decay channel right, add a flat bonus on top of the graded
    # terminal so that a correct attempt is worth substantially more
    # than the no-claim penalty (-5) and the wrong-claim penalty (~-1.85).
    # This makes the bandit math: correct +10–12 ≫ no-claim −5 > wrong −2.
    if mass_score >= 0.5 and channel_ok:
        raw += weights.correct_claim_bonus
        breakdown.add("correct_claim_bonus", weights.correct_claim_bonus)

    # Overconfident-wrong penalty: high confidence with a wrong channel or
    # a far-off mass.
    if claim.confidence >= 0.8 and (mass_score < 0.2 or not channel_ok):
        raw -= weights.overconfident_wrong_penalty
        breakdown.add("overconfident_wrong", -weights.overconfident_wrong_penalty)

    # Significance-overclaim penalty (anti-reward-hacking): discourages the
    # model from just writing a giant σ in the claim regardless of evidence.
    overclaim_sigma = _significance_overclaim(state, claim.significance_sigma)
    if overclaim_sigma > 0:
        pen = weights.overclaim_significance_penalty * float(
            np.clip(overclaim_sigma / 3.0, 0.0, 2.0)
        )
        raw -= pen
        breakdown.add("overclaim_significance", -pen)

    # If the claim has zero/None mass and zero/None significance, treat it
    # as a "no-information" submission — clamp the raw reward so the model
    # can't pass the rules engine and then submit garbage to end early.
    if (claim.mass_estimate_gev is None) and (claim.significance_sigma in (None, 0.0)):
        raw = float(min(raw, 0.0))
        breakdown.add("no_information_claim", 0.0)

    return TerminalReward(
        reward=float(raw),
        breakdown=breakdown,
        discovered=discovered,
        correct_mass=mass_score >= 0.5,
        correct_channel=channel_ok,
        correct_spin=spin_ok,
    )
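
# Magnitude check under the default weights (sketch): a claim that maxes
# every graded component has breakdown.total = 0.30 + 0.20 + 0.20 + 0.10
# + 0.05 + 0.10 + 0.05 = 1.00, so raw = 1.00 * 5.0 + 6.0 (correct_claim_bonus)
# = 11.0, consistent with the "+10–12 for a correct claim" note on
# ``correct_claim_bonus`` above.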

__all__ = [
    "RewardBreakdown",
    "RewardWeights",
    "StepReward",
    "TerminalReward",
    "compute_step_reward",
    "compute_terminal_reward",
]