"""Decomposed reward function.

Two stages:

1. **Per-step reward** ``compute_step_reward``: shapes behaviour with small
   incentives (progress, evidence quality, valid prerequisites) and penalties
   (rule violations, repeated work, wasted resources).
2. **Terminal reward** ``compute_terminal_reward``: graded only when the agent
   submits a discovery claim or runs out of resources. Compares the submitted
   claim against the hidden ``LatentParticle`` truth.

The terminal reward is intentionally dominant so the policy must care about
the *correct* discovery, not just looking busy.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Dict, List, Optional

import numpy as np

from models import (
    ActionType,
    DiscoveryClaim,
    ExperimentAction,
    IntermediateOutput,
)
from server.rules.engine import RuleResult, ViolationCode
from server.simulator.latent_state import FullLatentState


# ── Configuration ────────────────────────────────────────────────────────


@dataclass
class RewardWeights:
    """Tunable weights for the per-step and terminal reward terms."""

    # ── per-step shaping ────────────────────────────────────────
    valid_action: float = 0.05
    progress_milestone: float = 0.25  # per newly unlocked milestone
    evidence_quality: float = 0.20  # scaled by output.quality_score
    tool_fit: float = 0.10
    soft_violation: float = -0.05  # per non-redundancy soft violation
    hard_violation: float = -0.50  # per hard violation
    redundancy: float = -0.10  # per REDUNDANT soft violation
    resource_overspend: float = -0.30  # per over-limit resource, per step
    failure: float = -0.30

    # ── terminal grading ────────────────────────────────────────
    terminal_scale: float = 5.0  # multiplied with the convex sum below
    # The seven weights below sum to 1.0 (a convex combination of [0, 1] scores).
    mass_calibration: float = 0.30
    significance_quality: float = 0.20
    channel_correctness: float = 0.20
    spin_correctness: float = 0.10
    width_calibration: float = 0.05
    confidence_calibration: float = 0.10
    efficiency_bonus: float = 0.05
    overconfident_wrong_penalty: float = 4.0  # subtracted from terminal


# ── Outputs ──────────────────────────────────────────────────────────────


@dataclass
class RewardBreakdown:
    """Named reward components plus their running sum."""

    components: Dict[str, float] = field(default_factory=dict)
    total: float = 0.0

    def add(self, key: str, value: float) -> None:
        """Accumulate ``value`` under ``key`` and into ``total``."""
        self.components[key] = self.components.get(key, 0.0) + value
        self.total += value


@dataclass
class StepReward:
    """Per-step shaping reward; ``reward`` equals ``breakdown.total``."""

    reward: float
    breakdown: RewardBreakdown


@dataclass
class TerminalReward:
    """Terminal grade of a discovery claim plus correctness flags."""

    reward: float
    breakdown: RewardBreakdown
    discovered: bool
    correct_mass: bool
    correct_channel: bool
    correct_spin: bool


# ── Per-step ─────────────────────────────────────────────────────────────

# Boolean attributes of ``state.progress`` whose False -> True transition
# counts as a progress milestone.
_PROGRESS_FLAGS = [
    "beam_configured",
    "luminosity_allocated",
    "trigger_set",
    "collisions_collected",
    "channel_selected",
    "tracks_reconstructed",
    "detector_calibrated",
    "invariant_mass_built",
    "background_subtracted",
    "resonance_fitted",
    "significance_estimated",
]


def _milestone_progress(state_before: FullLatentState, state_after: FullLatentState) -> int:
    """Number of new progress milestones unlocked this step."""
    return sum(
        1
        for flag in _PROGRESS_FLAGS
        if getattr(state_after.progress, flag) and not getattr(state_before.progress, flag)
    )


def compute_step_reward(
    *,
    action: ExperimentAction,
    output: IntermediateOutput,
    state_before: FullLatentState,
    state_after: FullLatentState,
    rule_result: RuleResult,
    weights: Optional[RewardWeights] = None,
) -> StepReward:
    """Compute the shaping reward for a single environment step.

    Args:
        action: The action the agent took this step.
        output: Result of executing the action (``success``, ``quality_score``).
        state_before: Latent state before the action was applied.
        state_after: Latent state after the action was applied.
        rule_result: Rule-engine verdict (``allowed``, hard ``violations``,
            ``soft_violations``).
        weights: Reward weights. A fresh ``RewardWeights()`` is used when omitted.

    Returns:
        A ``StepReward`` whose ``reward`` equals ``breakdown.total``.
    """
    if weights is None:
        # A default instance built at definition time would be one mutable
        # object shared by every call; build a fresh one instead.
        weights = RewardWeights()
    breakdown = RewardBreakdown()

    if rule_result.allowed and output.success:
        breakdown.add("valid_action", weights.valid_action)
    if not output.success:
        breakdown.add("failure", weights.failure)

    # progress
    new_milestones = _milestone_progress(state_before, state_after)
    if new_milestones > 0:
        breakdown.add("progress", weights.progress_milestone * new_milestones)

    # evidence quality
    if output.success:
        breakdown.add("evidence_quality", weights.evidence_quality * float(output.quality_score))

    # tool fit
    # NOTE(review): there is no lookup against a recommended toolset here — any
    # action that names a method earns half of ``tool_fit``. Confirm intent.
    if action.method:
        breakdown.add("tool_fit", weights.tool_fit * 0.5)

    # rule penalties
    if rule_result.violations:
        breakdown.add("hard_violation", weights.hard_violation * len(rule_result.violations))
    if rule_result.soft_violations:
        soft_redundant = sum(1 for v in rule_result.soft_violations if v == ViolationCode.REDUNDANT)
        soft_other = len(rule_result.soft_violations) - soft_redundant
        if soft_redundant:
            breakdown.add("redundancy", weights.redundancy * soft_redundant)
        if soft_other:
            breakdown.add("soft_violation", weights.soft_violation * soft_other)

    # resource overspend: charged on every step while the resource stays over
    # its limit, not only on the step that crossed it.
    res = state_after.resources
    if res.budget_used_musd > res.budget_total_musd:
        breakdown.add("budget_overspend", weights.resource_overspend)
    if res.luminosity_used_fb > res.luminosity_total_fb:
        breakdown.add("lumi_overspend", weights.resource_overspend)
    if res.time_used_days > res.time_limit_days:
        breakdown.add("time_overspend", weights.resource_overspend)

    return StepReward(reward=float(breakdown.total), breakdown=breakdown)


# ── Terminal grading ─────────────────────────────────────────────────────

# Cap on how far a claimed uncertainty may widen the mass tolerance, as a
# multiple of the base tolerance. Without it, claiming an enormous uncertainty
# earns full mass credit for any mass.
_MAX_UNC_TOL_FACTOR = 3.0


def _mass_score(
    true_mass: float,
    claim_mass: Optional[float],
    unc: Optional[float],
    *,
    max_unc_factor: float = _MAX_UNC_TOL_FACTOR,
) -> float:
    """Score the claimed mass in [0, 1].

    The base tolerance is ``max(1.0 GeV, 1% of true mass)``. A positive claimed
    uncertainty widens it, but to at most ``max_unc_factor`` times the base.
    The score is 1.0 within the tolerance and decays linearly to 0.0 at five
    times the tolerance. A missing or non-finite claim, or a non-positive true
    mass, scores 0.0.
    """
    if claim_mass is None or true_mass <= 0 or not np.isfinite(claim_mass):
        return 0.0
    err = abs(claim_mass - true_mass)
    base_tol = max(1.0, 0.01 * true_mass)
    tol = base_tol
    if unc is not None and unc > 0:
        tol = max(tol, min(float(unc), max_unc_factor * base_tol))
    if err <= tol:
        return 1.0
    if err >= 5 * tol:
        return 0.0
    return float(np.clip(1.0 - (err - tol) / (4 * tol), 0.0, 1.0))


def _significance_score(state: FullLatentState, claim_sigma: Optional[float]) -> float:
    """Score the claimed significance in [0, 1].

    The base score rises linearly with the *measured* best significance and
    saturates at 5σ. Claiming more than was measured subtracts up to 0.5,
    reached at 3σ of over-claim. Under-claiming is not penalised. No claim
    scores 0.0.
    """
    measured = state.progress.best_significance_sigma or 0.0
    if claim_sigma is None:
        return 0.0
    over_claim = max(0.0, claim_sigma - measured)
    base = float(np.clip(measured / 5.0, 0.0, 1.0))
    penalty = float(np.clip(over_claim / 3.0, 0.0, 1.0))
    return float(np.clip(base - 0.5 * penalty, 0.0, 1.0))


def _confidence_calibration(claim_conf: float, mass_score: float, channel_correct: bool) -> float:
    """Reward agents whose confidence tracks their actual accuracy.

    "Accuracy" is the mean of the mass score and channel correctness (0/1).
    The score is ``1 - |accuracy - claim_conf|``, clipped to [0, 1].
    """
    actual = 0.5 * mass_score + 0.5 * (1.0 if channel_correct else 0.0)
    err = abs(actual - claim_conf)
    return float(np.clip(1.0 - err, 0.0, 1.0))


def _efficiency_bonus(state: FullLatentState) -> float:
    """Reward leftover resources (encourages succinct experiments).

    Returns the mean, over budget, luminosity and time, of the remaining
    fraction, with each fraction clipped to [0, 1]. A resource with a
    non-positive total contributes 0.
    """
    res = state.resources

    def _remaining_fraction(remaining: float, total: float) -> float:
        # Guard against a zero allocation instead of raising ZeroDivisionError.
        if total <= 0:
            return 0.0
        return float(np.clip(remaining / total, 0.0, 1.0))

    score = (
        _remaining_fraction(res.budget_remaining, res.budget_total_musd)
        + _remaining_fraction(res.luminosity_remaining, res.luminosity_total_fb)
        + _remaining_fraction(res.time_remaining, res.time_limit_days)
    )
    return float(score / 3.0)


def compute_terminal_reward(
    *,
    state: FullLatentState,
    claim: DiscoveryClaim,
    weights: Optional[RewardWeights] = None,
) -> TerminalReward:
    """Grade a discovery claim against the hidden particle truth.

    The graded components (mass, significance, channel, spin, width,
    confidence calibration, efficiency) form a weighted sum that is multiplied
    by ``weights.terminal_scale``. A confident-but-wrong claim then has
    ``overconfident_wrong_penalty`` subtracted.

    The breakdown records each graded component at its unscaled weight, plus a
    ``terminal_scaling`` entry for the scale factor. As a result,
    ``breakdown.total`` equals ``reward`` up to float rounding.

    Args:
        state: Full latent state; ``state.particle`` holds the truth.
        claim: The agent's submitted discovery claim.
        weights: Reward weights. A fresh ``RewardWeights()`` is used when omitted.
    """
    if weights is None:
        weights = RewardWeights()
    breakdown = RewardBreakdown()
    truth = state.particle

    mass_score = _mass_score(truth.mass_gev, claim.mass_estimate_gev, claim.mass_uncertainty_gev)
    breakdown.add("mass_calibration", weights.mass_calibration * mass_score)

    sig_score = _significance_score(state, claim.significance_sigma)
    breakdown.add("significance_quality", weights.significance_quality * sig_score)

    channel_ok = claim.decay_channel == truth.primary_channel
    breakdown.add("channel_correctness", weights.channel_correctness * (1.0 if channel_ok else 0.0))

    spin_ok = claim.spin_hypothesis is not None and claim.spin_hypothesis == truth.spin
    breakdown.add("spin_correctness", weights.spin_correctness * (1.0 if spin_ok else 0.0))

    width_score = 0.0
    if claim.width_estimate_gev is not None and truth.width_gev > 0:
        rel = abs(claim.width_estimate_gev - truth.width_gev) / max(truth.width_gev, 1e-3)
        width_score = float(np.clip(1.0 - rel, 0.0, 1.0))
    breakdown.add("width_calibration", weights.width_calibration * width_score)

    conf_score = _confidence_calibration(claim.confidence, mass_score, channel_ok)
    breakdown.add("confidence_calibration", weights.confidence_calibration * conf_score)

    eff_score = _efficiency_bonus(state)
    breakdown.add("efficiency_bonus", weights.efficiency_bonus * eff_score)

    # NOTE(review): this uses the *claimed* significance, not the measured one,
    # so an over-claimed sigma can still mark the claim as discovered. Confirm intent.
    discovered = (
        mass_score >= 0.5
        and channel_ok
        and (claim.significance_sigma or 0.0) >= 4.5
    )

    convex_sum = breakdown.total
    raw = convex_sum * weights.terminal_scale
    # Record the scale factor as its own component so breakdown.total tracks
    # the returned reward instead of mixing unscaled terms with the penalty.
    breakdown.add("terminal_scaling", raw - convex_sum)

    # Overconfident-wrong penalty: high confidence but a far-off mass OR a
    # wrong channel.
    if claim.confidence >= 0.8 and (mass_score < 0.2 or not channel_ok):
        raw -= weights.overconfident_wrong_penalty
        breakdown.add("overconfident_wrong", -weights.overconfident_wrong_penalty)

    return TerminalReward(
        reward=float(raw),
        breakdown=breakdown,
        discovered=discovered,
        correct_mass=mass_score >= 0.5,
        correct_channel=channel_ok,
        correct_spin=spin_ok,
    )


__all__ = [
    "RewardBreakdown",
    "RewardWeights",
    "StepReward",
    "TerminalReward",
    "compute_step_reward",
    "compute_terminal_reward",
]