"""Decomposed reward function. Two stages: 1. **Per-step reward** ``compute_step_reward``: shapes behaviour with small incentives (progress, evidence quality, valid prerequisites) and penalties (rule violations, repeated work, wasted resources). 2. **Terminal reward** ``compute_terminal_reward``: graded only when the agent submits a discovery claim or runs out of resources. Compares the submitted claim against the hidden ``LatentParticle`` truth. The terminal reward is intentionally dominant so the policy must care about the *correct* discovery, not just looking busy. Anti-reward-hacking design notes -------------------------------- The shaping reward is layered with several independent checks so that exploiting any single one alone cannot dominate the terminal grade (see hackathon guidance: *"use multiple independent reward functions"*): * ``tool_fit`` is **gated**: the agent only earns it when ``method`` is in ``TOOL_REGISTRY`` *and* the tool's category matches the action's expected category. Bogus method strings get **penalized**, not rewarded. * ``valid_action`` is gated on a parsed structured action that the rules engine accepts — pure JSON-shaped junk does not earn it. * ``progress_milestone`` only fires on the *first* time a milestone is unlocked, so re-doing already-completed steps cannot farm it. * ``redundancy`` and the new ``repeat_action_penalty`` punish loops that re-emit the same action type many times in a row. * The terminal grade dominates total reward via ``terminal_scale``, and the overconfident-wrong penalty also fires when the claim *significance* exceeds what was actually measured. """ from __future__ import annotations from collections import deque from dataclasses import dataclass, field from typing import Deque, Dict, List, Optional import numpy as np from models import ( ActionType, DiscoveryClaim, ExperimentAction, IntermediateOutput, TOOL_REGISTRY, is_recommended_tool, ) from server.rules.engine import RuleResult, ViolationCode from server.simulator.latent_state import FullLatentState # ── Configuration ──────────────────────────────────────────────────────── @dataclass class RewardWeights: # ── per-step shaping ──────────────────────────────────────── valid_action: float = 0.05 progress_milestone: float = 0.25 evidence_quality: float = 0.20 tool_fit: float = 0.10 # paid only on a method ∈ TOOL_REGISTRY # whose category matches the action. bogus_method_penalty: float = -0.05 # penalises method strings outside # TOOL_REGISTRY (anti-string-spam). repeat_action_penalty: float = -0.08 # per consecutive repeat beyond the # second identical action_type in a row. soft_violation: float = -0.05 hard_violation: float = -0.50 redundancy: float = -0.10 resource_overspend: float = -0.30 failure: float = -0.30 # Hard cap on what a single shaping step can earn. Without this a # policy could in principle stack milestone + evidence_quality + # tool_fit + valid_action and approach the terminal reward magnitude. step_reward_clip: float = 0.75 # ── terminal grading ──────────────────────────────────────── terminal_scale: float = 5.0 # multiplied with the convex sum below mass_calibration: float = 0.30 significance_quality: float = 0.20 channel_correctness: float = 0.20 spin_correctness: float = 0.10 width_calibration: float = 0.05 confidence_calibration: float = 0.10 efficiency_bonus: float = 0.05 overconfident_wrong_penalty: float = 4.0 # subtracted from terminal overclaim_significance_penalty: float = 1.5 # claim_sigma >> measured_sigma # ── Outputs ────────────────────────────────────────────────────────────── @dataclass class RewardBreakdown: components: Dict[str, float] = field(default_factory=dict) total: float = 0.0 def add(self, key: str, value: float) -> None: self.components[key] = self.components.get(key, 0.0) + value self.total += value @dataclass class StepReward: reward: float breakdown: RewardBreakdown @dataclass class TerminalReward: reward: float breakdown: RewardBreakdown discovered: bool correct_mass: bool correct_channel: bool correct_spin: bool # ── Per-step ───────────────────────────────────────────────────────────── _PROGRESS_FLAGS = [ "beam_configured", "luminosity_allocated", "trigger_set", "collisions_collected", "channel_selected", "tracks_reconstructed", "detector_calibrated", "invariant_mass_built", "background_subtracted", "resonance_fitted", "significance_estimated", ] def _milestone_progress(state_before: FullLatentState, state_after: FullLatentState) -> int: """Number of new progress milestones unlocked this step.""" delta = 0 for flag in _PROGRESS_FLAGS: was = getattr(state_before.progress, flag) now = getattr(state_after.progress, flag) if now and not was: delta += 1 return delta def _consecutive_repeat_count( history: List, action_type: ActionType, look_back: int = 6 ) -> int: """How many times this action_type appeared *consecutively* most recently (excluding the just-applied action). Used to dampen loops. """ if not history: return 0 n = 0 for rec in reversed(history[-look_back:]): if getattr(rec, "action_type", None) == action_type: n += 1 else: break return n def compute_step_reward( *, action: ExperimentAction, output: IntermediateOutput, state_before: FullLatentState, state_after: FullLatentState, rule_result: RuleResult, weights: RewardWeights = RewardWeights(), history: Optional[List] = None, ) -> StepReward: """Compute the per-step shaping reward. ``history`` is the list of ``PipelineStepRecord`` *before* this step. We use it to detect consecutive-repeat loops (e.g. a model spamming the same action_type to farm shaping). All other fields are local. """ breakdown = RewardBreakdown() # ── basic validity / failure ──────────────────────────────────── if rule_result.allowed and output.success: breakdown.add("valid_action", weights.valid_action) if not output.success: breakdown.add("failure", weights.failure) # ── milestone progress (one-shot per flag, anti-farming) ──────── new_milestones = _milestone_progress(state_before, state_after) if new_milestones > 0: breakdown.add("progress", weights.progress_milestone * new_milestones) # ── evidence quality ──────────────────────────────────────────── if output.success: breakdown.add( "evidence_quality", weights.evidence_quality * float(output.quality_score), ) # ── tool fit: gated on TOOL_REGISTRY membership + category match ─ # Bogus or mismatched method strings are explicitly penalised so the # model can't farm shaping reward by setting method='whatever'. if action.method: if is_recommended_tool(action.action_type, action.method): breakdown.add("tool_fit", weights.tool_fit) elif action.method not in TOOL_REGISTRY: breakdown.add("bogus_method", weights.bogus_method_penalty) # If the tool exists but the category doesn't match the action, # we silently award nothing (no penalty, no reward). # ── rule penalties ────────────────────────────────────────────── if rule_result.violations: breakdown.add( "hard_violation", weights.hard_violation * len(rule_result.violations), ) if rule_result.soft_violations: soft_redundant = sum( 1 for v in rule_result.soft_violations if v == ViolationCode.REDUNDANT ) soft_other = len(rule_result.soft_violations) - soft_redundant if soft_redundant: breakdown.add("redundancy", weights.redundancy * soft_redundant) if soft_other: breakdown.add("soft_violation", weights.soft_violation * soft_other) # ── consecutive-repeat penalty (catches loop hacks) ───────────── # Two-in-a-row is mildly OK (sometimes you re-collect data); three # or more identical action_types in a row earns escalating penalty. repeats = _consecutive_repeat_count(history or [], action.action_type) if repeats >= 2: breakdown.add( "repeat_action", weights.repeat_action_penalty * (repeats - 1), ) # ── resource overspend ────────────────────────────────────────── res = state_after.resources if res.budget_used_musd > res.budget_total_musd: breakdown.add("budget_overspend", weights.resource_overspend) if res.luminosity_used_fb > res.luminosity_total_fb: breakdown.add("lumi_overspend", weights.resource_overspend) if res.time_used_days > res.time_limit_days: breakdown.add("time_overspend", weights.resource_overspend) # ── total + soft cap ──────────────────────────────────────────── total = float(breakdown.total) if weights.step_reward_clip > 0: total = float(np.clip(total, -10.0, weights.step_reward_clip)) return StepReward(reward=total, breakdown=breakdown) # ── Terminal grading ───────────────────────────────────────────────────── def _mass_score(true_mass: float, claim_mass: Optional[float], unc: Optional[float]) -> float: """1.0 within 1σ, smoothly decays to 0 by 5 GeV (or 5σ, whichever larger).""" if claim_mass is None or true_mass <= 0: return 0.0 err = abs(claim_mass - true_mass) # Tolerance: max(1.0 GeV, 1% of true mass, claimed unc) tol = max(1.0, 0.01 * true_mass) if unc is not None and unc > 0: tol = max(tol, float(unc)) if err <= tol: return 1.0 if err >= 5 * tol: return 0.0 return float(np.clip(1.0 - (err - tol) / (4 * tol), 0.0, 1.0)) def _significance_score(state: FullLatentState, claim_sigma: Optional[float]) -> float: """High score when claimed σ matches measured σ and is ≥ 5. A claim_sigma far above the measured significance is a classic reward-hacking pattern (just write '50' in the field), so we penalise over-claiming proportionally. """ measured = state.progress.best_significance_sigma or 0.0 if claim_sigma is None: return 0.0 over_claim = max(0.0, claim_sigma - measured) base = float(np.clip(measured / 5.0, 0.0, 1.0)) penalty = float(np.clip(over_claim / 3.0, 0.0, 1.0)) return float(np.clip(base - 0.5 * penalty, 0.0, 1.0)) def _significance_overclaim( state: FullLatentState, claim_sigma: Optional[float], threshold_sigma: float = 1.5 ) -> float: """How many σ the claim *exceeds* what the env actually measured. Used as an extra penalty — distinct from ``_significance_score`` — so that a model can't compensate a giant over-claim by getting the mass slightly more accurate. Returns ``max(0, claim - measured - τ)``. """ if claim_sigma is None: return 0.0 measured = state.progress.best_significance_sigma or 0.0 return float(max(0.0, claim_sigma - measured - threshold_sigma)) def _confidence_calibration(claim_conf: float, mass_score: float, channel_correct: bool) -> float: """Reward agents whose confidence tracks their actual accuracy.""" actual = 0.5 * mass_score + 0.5 * (1.0 if channel_correct else 0.0) err = abs(actual - claim_conf) return float(np.clip(1.0 - err, 0.0, 1.0)) def _efficiency_bonus(state: FullLatentState) -> float: """Reward leftover budget (encourages succinct experiments).""" res = state.resources score = 0.0 score += np.clip(res.budget_remaining / res.budget_total_musd, 0.0, 1.0) score += np.clip(res.luminosity_remaining / res.luminosity_total_fb, 0.0, 1.0) score += np.clip(res.time_remaining / res.time_limit_days, 0.0, 1.0) return float(score / 3.0) def compute_terminal_reward( *, state: FullLatentState, claim: DiscoveryClaim, weights: RewardWeights = RewardWeights(), ) -> TerminalReward: breakdown = RewardBreakdown() truth = state.particle mass_score = _mass_score(truth.mass_gev, claim.mass_estimate_gev, claim.mass_uncertainty_gev) breakdown.add("mass_calibration", weights.mass_calibration * mass_score) sig_score = _significance_score(state, claim.significance_sigma) breakdown.add("significance_quality", weights.significance_quality * sig_score) channel_ok = claim.decay_channel == truth.primary_channel breakdown.add("channel_correctness", weights.channel_correctness * (1.0 if channel_ok else 0.0)) spin_ok = claim.spin_hypothesis is not None and claim.spin_hypothesis == truth.spin breakdown.add("spin_correctness", weights.spin_correctness * (1.0 if spin_ok else 0.0)) width_score = 0.0 if claim.width_estimate_gev is not None and truth.width_gev > 0: rel = abs(claim.width_estimate_gev - truth.width_gev) / max(truth.width_gev, 1e-3) width_score = float(np.clip(1.0 - rel, 0.0, 1.0)) breakdown.add("width_calibration", weights.width_calibration * width_score) conf_score = _confidence_calibration(claim.confidence, mass_score, channel_ok) breakdown.add("confidence_calibration", weights.confidence_calibration * conf_score) eff_score = _efficiency_bonus(state) breakdown.add("efficiency_bonus", weights.efficiency_bonus * eff_score) discovered = ( mass_score >= 0.5 and channel_ok and (claim.significance_sigma or 0.0) >= 4.5 ) raw = breakdown.total * weights.terminal_scale # Overconfident-wrong penalty: high confidence but wrong channel & far mass if claim.confidence >= 0.8 and (mass_score < 0.2 or not channel_ok): raw -= weights.overconfident_wrong_penalty breakdown.add("overconfident_wrong", -weights.overconfident_wrong_penalty) # Significance-overclaim penalty (anti-reward-hacking): discourages the # model from just writing a giant σ in the claim regardless of evidence. overclaim_sigma = _significance_overclaim(state, claim.significance_sigma) if overclaim_sigma > 0: pen = weights.overclaim_significance_penalty * float( np.clip(overclaim_sigma / 3.0, 0.0, 2.0) ) raw -= pen breakdown.add("overclaim_significance", -pen) # If the claim has zero/None mass and zero/None significance, treat it # as a "no-information" submission — clamp the raw reward so the model # can't pass the rules engine and then submit garbage to end early. if (claim.mass_estimate_gev is None) and (claim.significance_sigma in (None, 0.0)): raw = float(min(raw, 0.0)) breakdown.add("no_information_claim", 0.0) return TerminalReward( reward=float(raw), breakdown=breakdown, discovered=discovered, correct_mass=mass_score >= 0.5, correct_channel=channel_ok, correct_spin=spin_ok, ) __all__ = [ "RewardBreakdown", "RewardWeights", "StepReward", "TerminalReward", "compute_step_reward", "compute_terminal_reward", ]