| """Decomposed reward function.
|
|
|
| Two stages:
|
| 1. **Per-step reward** ``compute_step_reward``: shapes behaviour with small
|
| incentives (progress, evidence quality, valid prerequisites) and
|
| penalties (rule violations, repeated work, wasted resources).
|
| 2. **Terminal reward** ``compute_terminal_reward``: graded only when the
|
| agent submits a discovery claim or runs out of resources. Compares the
|
| submitted claim against the hidden ``LatentParticle`` truth.
|
|
|
| The terminal reward is intentionally dominant so the policy must care about
|
| the *correct* discovery, not just looking busy.
|
|
|
| Anti-reward-hacking design notes
|
| --------------------------------
|
| The shaping reward is layered with several independent checks so that
|
| exploiting any single one alone cannot dominate the terminal grade
|
| (see hackathon guidance: *"use multiple independent reward functions"*):
|
|
|
| * ``tool_fit`` is **gated**: the agent only earns it when ``method`` is
|
| in ``TOOL_REGISTRY`` *and* the tool's category matches the action's
|
| expected category. Bogus method strings get **penalized**, not rewarded.
|
| * ``valid_action`` is gated on a parsed structured action that the rules
|
| engine accepts — pure JSON-shaped junk does not earn it.
|
| * ``progress_milestone`` only fires on the *first* time a milestone is
|
| unlocked, so re-doing already-completed steps cannot farm it.
|
| * ``redundancy`` and the new ``repeat_action_penalty`` punish loops that
|
| re-emit the same action type many times in a row.
|
| * The terminal grade dominates total reward via ``terminal_scale``, and
|
| the overconfident-wrong penalty also fires when the claim *significance*
|
| exceeds what was actually measured.
|
| """

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Dict, List, Optional

import numpy as np

from models import (
    ActionType,
    DiscoveryClaim,
    ExperimentAction,
    IntermediateOutput,
    TOOL_REGISTRY,
    is_recommended_tool,
)
from server.rules.engine import RuleResult, ViolationCode
from server.simulator.latent_state import FullLatentState


@dataclass
class RewardWeights:
    """Tunable weights for the step-shaping and terminal-grading rewards."""

    # Positive shaping components (small, so they cannot rival the grade).
    valid_action: float = 0.05
    progress_milestone: float = 0.25
    evidence_quality: float = 0.20
    tool_fit: float = 0.10

    # Penalty for a ``method`` string that is not in ``TOOL_REGISTRY`` at all.
    bogus_method_penalty: float = -0.05

    # Penalty per extra consecutive repeat of the same action type.
    repeat_action_penalty: float = -0.08

    # Rule-engine, resource, and failure penalties.
    soft_violation: float = -0.05
    hard_violation: float = -0.50
    redundancy: float = -0.10
    resource_overspend: float = -0.30
    failure: float = -0.30

    # Upper clip on the total per-step shaping reward.
    step_reward_clip: float = 0.75

    # Terminal grading: overall scale plus per-component weights
    # (the component weights sum to 1.0).
    terminal_scale: float = 5.0

    mass_calibration: float = 0.30
    significance_quality: float = 0.20
    channel_correctness: float = 0.20
    spin_correctness: float = 0.10
    width_calibration: float = 0.05
    confidence_calibration: float = 0.10
    efficiency_bonus: float = 0.05

    # Anti-hacking penalties applied after scaling.
    overconfident_wrong_penalty: float = 4.0
    overclaim_significance_penalty: float = 1.5
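
# Usage sketch: callers can rebalance the trade-off without touching the
# reward code, e.g. a stricter variant that leans even harder on the
# terminal grade (illustrative values, not defaults):
#
#     strict = RewardWeights(terminal_scale=8.0, step_reward_clip=0.5)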


@dataclass
class RewardBreakdown:
    """Named reward components plus their running total."""

    components: Dict[str, float] = field(default_factory=dict)
    total: float = 0.0

    def add(self, key: str, value: float) -> None:
        self.components[key] = self.components.get(key, 0.0) + value
        self.total += value
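
# Accumulation semantics: repeated keys merge rather than overwrite, and
# ``total`` always tracks the sum over ``components``:
#
#     b = RewardBreakdown()
#     b.add("progress", 0.25)
#     b.add("progress", 0.25)
#     b.components  # {"progress": 0.5}
#     b.total       # 0.5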


@dataclass
class StepReward:
    reward: float
    breakdown: RewardBreakdown


@dataclass
class TerminalReward:
    reward: float
    breakdown: RewardBreakdown
    discovered: bool
    correct_mass: bool
    correct_channel: bool
    correct_spin: bool


_PROGRESS_FLAGS = [
    "beam_configured",
    "luminosity_allocated",
    "trigger_set",
    "collisions_collected",
    "channel_selected",
    "tracks_reconstructed",
    "detector_calibrated",
    "invariant_mass_built",
    "background_subtracted",
    "resonance_fitted",
    "significance_estimated",
]


def _milestone_progress(state_before: FullLatentState, state_after: FullLatentState) -> int:
    """Number of new progress milestones unlocked this step."""
    delta = 0
    for flag in _PROGRESS_FLAGS:
        was = getattr(state_before.progress, flag)
        now = getattr(state_after.progress, flag)
        if now and not was:
            delta += 1
    return delta
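
# Example: if ``beam_configured`` flips False -> True and every other flag is
# unchanged, this returns 1; re-running an already-completed step returns 0,
# because only False -> True transitions are counted.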


def _consecutive_repeat_count(
    history: List, action_type: ActionType, look_back: int = 6
) -> int:
    """Count how many times ``action_type`` was emitted consecutively at the
    end of ``history`` (which excludes the just-applied action). Used to
    dampen loops.
    """
    if not history:
        return 0
    n = 0
    # Walk backwards from the most recent record; stop at the first break in
    # the run of identical action types.
    for rec in reversed(history[-look_back:]):
        if getattr(rec, "action_type", None) == action_type:
            n += 1
        else:
            break
    return n
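
# Example (hypothetical records; any object exposing an ``action_type``
# attribute works, mirroring the ``PipelineStepRecord`` mentioned in
# ``compute_step_reward``):
#
#     from types import SimpleNamespace
#     hist = [SimpleNamespace(action_type=a) for a in ("FIT", "FIT", "FIT")]
#     _consecutive_repeat_count(hist, "FIT")  # -> 3 (window capped at 6)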


def compute_step_reward(
    *,
    action: ExperimentAction,
    output: IntermediateOutput,
    state_before: FullLatentState,
    state_after: FullLatentState,
    rule_result: RuleResult,
    weights: RewardWeights = RewardWeights(),
    history: Optional[List] = None,
) -> StepReward:
    """Compute the per-step shaping reward.

    ``history`` is the list of ``PipelineStepRecord`` objects from *before*
    this step. It is used to detect consecutive-repeat loops (e.g. a model
    spamming the same action_type to farm shaping). All other inputs are
    local to this step.
    """
    breakdown = RewardBreakdown()

    # Gated validity credit: the rules engine must accept the action *and*
    # the simulator must report success.
    if rule_result.allowed and output.success:
        breakdown.add("valid_action", weights.valid_action)
    if not output.success:
        breakdown.add("failure", weights.failure)

    # First-time milestone credit (re-doing completed steps earns nothing).
    new_milestones = _milestone_progress(state_before, state_after)
    if new_milestones > 0:
        breakdown.add("progress", weights.progress_milestone * new_milestones)

    # Evidence quality, proportional to the simulator's quality score.
    if output.success:
        breakdown.add(
            "evidence_quality",
            weights.evidence_quality * float(output.quality_score),
        )

    # Tool fit: reward only registered tools whose category matches the
    # action; unregistered method strings are penalized instead.
    if action.method:
        if is_recommended_tool(action.action_type, action.method):
            breakdown.add("tool_fit", weights.tool_fit)
        elif action.method not in TOOL_REGISTRY:
            breakdown.add("bogus_method", weights.bogus_method_penalty)

    # Rule violations: hard violations are costly; soft violations split
    # into redundancy versus everything else.
    if rule_result.violations:
        breakdown.add(
            "hard_violation",
            weights.hard_violation * len(rule_result.violations),
        )
    if rule_result.soft_violations:
        soft_redundant = sum(
            1 for v in rule_result.soft_violations if v == ViolationCode.REDUNDANT
        )
        soft_other = len(rule_result.soft_violations) - soft_redundant
        if soft_redundant:
            breakdown.add("redundancy", weights.redundancy * soft_redundant)
        if soft_other:
            breakdown.add("soft_violation", weights.soft_violation * soft_other)

    # Escalating penalty for emitting the same action type many times in a
    # row; the penalty grows with the length of the run.
    repeats = _consecutive_repeat_count(history or [], action.action_type)
    if repeats >= 2:
        breakdown.add(
            "repeat_action",
            weights.repeat_action_penalty * (repeats - 1),
        )

    # Resource overspend penalties.
    res = state_after.resources
    if res.budget_used_musd > res.budget_total_musd:
        breakdown.add("budget_overspend", weights.resource_overspend)
    if res.luminosity_used_fb > res.luminosity_total_fb:
        breakdown.add("lumi_overspend", weights.resource_overspend)
    if res.time_used_days > res.time_limit_days:
        breakdown.add("time_overspend", weights.resource_overspend)

    # Clip asymmetrically: the upside is capped at ``step_reward_clip`` so
    # shaping cannot rival the terminal grade; the downside may reach -10.
    total = float(breakdown.total)
    if weights.step_reward_clip > 0:
        total = float(np.clip(total, -10.0, weights.step_reward_clip))
    return StepReward(reward=total, breakdown=breakdown)
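
# Call-site sketch (illustrative; ``step_records``, ``rules.check``, and the
# surrounding environment objects are assumed to come from the episode
# runner, not defined in this module):
#
#     step = compute_step_reward(
#         action=action,
#         output=output,
#         state_before=prev_state,
#         state_after=next_state,
#         rule_result=rules.check(action, prev_state),
#         history=step_records,
#     )
#     log.info("step reward %.3f %s", step.reward, step.breakdown.components)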


def _mass_score(true_mass: float, claim_mass: Optional[float], unc: Optional[float]) -> float:
    """1.0 while the mass error is within tolerance, then a linear decay to 0
    at five times the tolerance, where the tolerance is the largest of 1 GeV,
    1% of the true mass, and the claimed uncertainty.
    """
    if claim_mass is None or true_mass <= 0:
        return 0.0
    err = abs(claim_mass - true_mass)

    tol = max(1.0, 0.01 * true_mass)
    if unc is not None and unc > 0:
        tol = max(tol, float(unc))
    if err <= tol:
        return 1.0
    if err >= 5 * tol:
        return 0.0
    return float(np.clip(1.0 - (err - tol) / (4 * tol), 0.0, 1.0))
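
# Worked example: true mass 91.2 GeV, claim 93.0 GeV, no uncertainty given.
# tol = max(1.0, 0.912) = 1.0 and err = 1.8, so the score is
# 1 - (1.8 - 1.0) / 4.0 = 0.8. Quoting a wide uncertainty stretches the
# tolerance, but the claim still has to land near the truth to score well.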


def _significance_score(state: FullLatentState, claim_sigma: Optional[float]) -> float:
    """High score when the claimed σ matches the measured σ and is ≥ 5.

    A ``claim_sigma`` far above the measured significance is a classic
    reward-hacking pattern (just write '50' in the field), so over-claiming
    is penalized proportionally.
    """
    measured = state.progress.best_significance_sigma or 0.0
    if claim_sigma is None:
        return 0.0
    over_claim = max(0.0, claim_sigma - measured)
    base = float(np.clip(measured / 5.0, 0.0, 1.0))
    penalty = float(np.clip(over_claim / 3.0, 0.0, 1.0))
    return float(np.clip(base - 0.5 * penalty, 0.0, 1.0))
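
# Worked examples: measured 5.2σ, claimed 5.0σ -> base 1.0, no over-claim,
# score 1.0. Measured 2.0σ, claimed 50σ -> base 0.4, penalty term capped at
# 1.0, score max(0, 0.4 - 0.5) = 0.0: writing a huge σ zeroes the component
# instead of inflating it.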


def _significance_overclaim(
    state: FullLatentState, claim_sigma: Optional[float], threshold_sigma: float = 1.5
) -> float:
    """How many σ the claim *exceeds* what the env actually measured.

    Used as an extra penalty, distinct from ``_significance_score``, so that
    a model cannot compensate for a giant over-claim by getting the mass
    slightly more accurate. Returns ``max(0, claim - measured - τ)`` with
    ``τ = threshold_sigma``.
    """
    if claim_sigma is None:
        return 0.0
    measured = state.progress.best_significance_sigma or 0.0
    return float(max(0.0, claim_sigma - measured - threshold_sigma))
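
# Worked example: measured 3.0σ, claimed 6.0σ, τ = 1.5 -> over-claim of 1.5σ,
# which later becomes a terminal penalty; over-claims within the 1.5σ
# threshold are tolerated.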


def _confidence_calibration(claim_conf: float, mass_score: float, channel_correct: bool) -> float:
    """Reward agents whose stated confidence tracks their actual accuracy."""
    actual = 0.5 * mass_score + 0.5 * (1.0 if channel_correct else 0.0)
    err = abs(actual - claim_conf)
    return float(np.clip(1.0 - err, 0.0, 1.0))
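
# Worked example: mass_score 0.8 with the correct channel gives an actual
# accuracy of 0.9; a stated confidence of 0.9 scores 1.0, while a blustering
# confidence of 1.0 on a wrong-channel, mass_score-0.2 claim (actual 0.1)
# scores only 0.1.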


def _efficiency_bonus(state: FullLatentState) -> float:
    """Reward leftover budget (encourages succinct experiments)."""
    res = state.resources
    # Average the remaining fraction of each of the three resources.
    score = 0.0
    score += np.clip(res.budget_remaining / res.budget_total_musd, 0.0, 1.0)
    score += np.clip(res.luminosity_remaining / res.luminosity_total_fb, 0.0, 1.0)
    score += np.clip(res.time_remaining / res.time_limit_days, 0.0, 1.0)
    return float(score / 3.0)
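
# Example: half the budget, all the luminosity, and a third of the time left
# averages to (0.5 + 1.0 + 1/3) / 3 ≈ 0.61, which the terminal grade then
# weights by the small ``efficiency_bonus`` (0.05 by default).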


def compute_terminal_reward(
    *,
    state: FullLatentState,
    claim: DiscoveryClaim,
    weights: RewardWeights = RewardWeights(),
) -> TerminalReward:
    breakdown = RewardBreakdown()
    truth = state.particle

    mass_score = _mass_score(truth.mass_gev, claim.mass_estimate_gev, claim.mass_uncertainty_gev)
    breakdown.add("mass_calibration", weights.mass_calibration * mass_score)

    sig_score = _significance_score(state, claim.significance_sigma)
    breakdown.add("significance_quality", weights.significance_quality * sig_score)

    channel_ok = claim.decay_channel == truth.primary_channel
    breakdown.add("channel_correctness", weights.channel_correctness * (1.0 if channel_ok else 0.0))

    spin_ok = claim.spin_hypothesis is not None and claim.spin_hypothesis == truth.spin
    breakdown.add("spin_correctness", weights.spin_correctness * (1.0 if spin_ok else 0.0))

    # Relative-error score for the width, when both claim and truth exist.
    width_score = 0.0
    if claim.width_estimate_gev is not None and truth.width_gev > 0:
        rel = abs(claim.width_estimate_gev - truth.width_gev) / max(truth.width_gev, 1e-3)
        width_score = float(np.clip(1.0 - rel, 0.0, 1.0))
    breakdown.add("width_calibration", weights.width_calibration * width_score)

    conf_score = _confidence_calibration(claim.confidence, mass_score, channel_ok)
    breakdown.add("confidence_calibration", weights.confidence_calibration * conf_score)

    eff_score = _efficiency_bonus(state)
    breakdown.add("efficiency_bonus", weights.efficiency_bonus * eff_score)

    # A claim counts as a discovery when the mass is close, the channel is
    # right, and the claimed significance roughly clears the 5σ bar.
    discovered = (
        mass_score >= 0.5
        and channel_ok
        and (claim.significance_sigma or 0.0) >= 4.5
    )

    raw = breakdown.total * weights.terminal_scale

    # Confident-but-wrong claims are punished harder than honest misses.
    if claim.confidence >= 0.8 and (mass_score < 0.2 or not channel_ok):
        raw -= weights.overconfident_wrong_penalty
        breakdown.add("overconfident_wrong", -weights.overconfident_wrong_penalty)

    # Separate penalty for claiming more significance than was measured, so
    # over-claiming cannot be traded off against a slightly better mass fit.
    overclaim_sigma = _significance_overclaim(state, claim.significance_sigma)
    if overclaim_sigma > 0:
        pen = weights.overclaim_significance_penalty * float(
            np.clip(overclaim_sigma / 3.0, 0.0, 2.0)
        )
        raw -= pen
        breakdown.add("overclaim_significance", -pen)

    # An empty claim (no mass, no significance) can never score positively.
    # The marker is recorded at 0.0 because the cap is applied to the scaled
    # total rather than to any single component.
    if (claim.mass_estimate_gev is None) and (claim.significance_sigma in (None, 0.0)):
        raw = float(min(raw, 0.0))
        breakdown.add("no_information_claim", 0.0)

    return TerminalReward(
        reward=float(raw),
        breakdown=breakdown,
        discovered=discovered,
        correct_mass=mass_score >= 0.5,
        correct_channel=channel_ok,
        correct_spin=spin_ok,
    )
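
# Penalty-interplay sketch: an agent that claims 50σ with confidence 1.0 on a
# wrong channel loses both ways. With the defaults, the scaled component sum
# tops out near 5.0, while the overconfident-wrong penalty (-4.0) plus the
# capped over-claim penalty (-1.5 * 2.0 = -3.0) subtract up to 7.0, so the
# hack nets a negative terminal reward.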


__all__ = [
    "RewardBreakdown",
    "RewardWeights",
    "StepReward",
    "TerminalReward",
    "compute_step_reward",
    "compute_terminal_reward",
]