| """Decomposed reward function.
|
|
|
| Two stages:
|
| 1. **Per-step reward** ``compute_step_reward``: shapes behaviour with small
|
| incentives (progress, evidence quality, valid prerequisites) and
|
| penalties (rule violations, repeated work, wasted resources).
|
| 2. **Terminal reward** ``compute_terminal_reward``: graded only when the
|
| agent submits a discovery claim or runs out of resources. Compares the
|
| submitted claim against the hidden ``LatentParticle`` truth.
|
|
|
| The terminal reward is intentionally dominant so the policy must care about
|
| the *correct* discovery, not just looking busy.
|
| """
|
|
|
| from __future__ import annotations
|
|
|
| from dataclasses import dataclass, field
|
| from typing import Dict, List, Optional
|
|
|
| import numpy as np
|
|
|
| from models import (
|
| ActionType,
|
| DiscoveryClaim,
|
| ExperimentAction,
|
| IntermediateOutput,
|
| )
|
|
|
| from server.rules.engine import RuleResult, ViolationCode
|
| from server.simulator.latent_state import FullLatentState
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class RewardWeights:
    """Tunable coefficients for the step-level and terminal reward terms.

    Step-level coefficients are added directly to the per-step reward
    (negative values act as penalties). Terminal coefficients multiply the
    individual graded scores of a discovery claim; the weighted sum is then
    multiplied by ``terminal_scale``, and ``overconfident_wrong_penalty`` is
    subtracted from that scaled value.
    """

    # ---- per-step shaping -------------------------------------------------
    valid_action: float = 0.05
    progress_milestone: float = 0.25  # paid once per newly unlocked milestone
    evidence_quality: float = 0.20  # multiplied by the output's quality score
    tool_fit: float = 0.10
    soft_violation: float = -0.05
    hard_violation: float = -0.50  # charged per hard violation
    redundancy: float = -0.10  # charged per REDUNDANT soft violation
    resource_overspend: float = -0.30  # charged per over-limit resource
    failure: float = -0.30

    # ---- terminal grading -------------------------------------------------
    terminal_scale: float = 5.0  # multiplier applied to the weighted terminal sum

    # Weights of the graded terminal components; the defaults sum to 1.0.
    mass_calibration: float = 0.30
    significance_quality: float = 0.20
    channel_correctness: float = 0.20
    spin_correctness: float = 0.10
    width_calibration: float = 0.05
    confidence_calibration: float = 0.10
    efficiency_bonus: float = 0.05

    # Flat amount subtracted (after scaling) from confident but wrong claims.
    overconfident_wrong_penalty: float = 4.0
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class RewardBreakdown:
    """Labelled reward contributions together with their running total.

    ``components`` maps each label to the sum of every value recorded under
    it; ``total`` is the sum of all recorded values across labels.
    """

    components: Dict[str, float] = field(default_factory=dict)
    total: float = 0.0

    def add(self, key: str, value: float) -> None:
        """Accumulate *value* under *key* and into the overall ``total``."""
        previous = self.components.get(key, 0.0)
        self.components[key] = previous + value
        self.total += value
|
|
|
|
|
@dataclass
class StepReward:
    """Per-step reward value paired with the breakdown that produced it."""

    # Scalar reward for the step; compute_step_reward sets it to
    # float(breakdown.total).
    reward: float
    # Labelled contributions that make up ``reward``.
    breakdown: RewardBreakdown
|
|
|
|
|
@dataclass
class TerminalReward:
    """Result of grading a discovery claim against the hidden truth."""

    # Final scalar terminal reward.
    reward: float
    # Labelled graded components recorded while computing ``reward``.
    breakdown: RewardBreakdown
    # Boolean verdicts derived from the graded claim.
    discovered: bool
    correct_mass: bool
    correct_channel: bool
    correct_spin: bool
|
|
|
|
|
|
|
|
|
|
|
| _PROGRESS_FLAGS = [
|
| "beam_configured",
|
| "luminosity_allocated",
|
| "trigger_set",
|
| "collisions_collected",
|
| "channel_selected",
|
| "tracks_reconstructed",
|
| "detector_calibrated",
|
| "invariant_mass_built",
|
| "background_subtracted",
|
| "resonance_fitted",
|
| "significance_estimated",
|
| ]
|
|
|
|
|
| def _milestone_progress(state_before: FullLatentState, state_after: FullLatentState) -> int:
|
| """Number of new progress milestones unlocked this step."""
|
| delta = 0
|
| for flag in _PROGRESS_FLAGS:
|
| was = getattr(state_before.progress, flag)
|
| now = getattr(state_after.progress, flag)
|
| if now and not was:
|
| delta += 1
|
| return delta
|
|
|
|
|
def compute_step_reward(
    *,
    action: ExperimentAction,
    output: IntermediateOutput,
    state_before: FullLatentState,
    state_after: FullLatentState,
    rule_result: RuleResult,
    weights: Optional[RewardWeights] = None,
) -> StepReward:
    """Compute the shaped per-step reward.

    Args:
        action: The action taken this step.
        output: The simulator output produced for ``action``.
        state_before: Latent state before this step.
        state_after: Latent state after this step.
        rule_result: Rule-engine verdict for ``action``.
        weights: Reward coefficients. When omitted, a fresh
            ``RewardWeights()`` is created for this call.

    Returns:
        A ``StepReward`` whose ``reward`` equals ``float(breakdown.total)``.
    """
    # Build the default per call: a ``RewardWeights()`` default argument is
    # evaluated once at import time and shared (and mutable) across calls.
    if weights is None:
        weights = RewardWeights()
    breakdown = RewardBreakdown()

    # Validity / failure of the action itself.
    if rule_result.allowed and output.success:
        breakdown.add("valid_action", weights.valid_action)
    if not output.success:
        breakdown.add("failure", weights.failure)

    # One bonus per milestone flag newly unlocked by this step.
    new_milestones = _milestone_progress(state_before, state_after)
    if new_milestones > 0:
        breakdown.add("progress", weights.progress_milestone * new_milestones)

    # Evidence quality scales with the output's reported quality score.
    if output.success:
        breakdown.add("evidence_quality", weights.evidence_quality * float(output.quality_score))

    # NOTE(review): paid whenever a method is named, even for failed or
    # disallowed actions; confirm that this is intended.
    if action.method:
        breakdown.add("tool_fit", weights.tool_fit * 0.5)

    # Rule violations: every hard violation is charged; soft violations are
    # split into REDUNDANT ones (own weight) and all others.
    if rule_result.violations:
        breakdown.add("hard_violation", weights.hard_violation * len(rule_result.violations))
    if rule_result.soft_violations:
        soft_redundant = sum(1 for v in rule_result.soft_violations if v == ViolationCode.REDUNDANT)
        soft_other = len(rule_result.soft_violations) - soft_redundant
        if soft_redundant:
            breakdown.add("redundancy", weights.redundancy * soft_redundant)
        if soft_other:
            breakdown.add("soft_violation", weights.soft_violation * soft_other)

    # Overspend is charged on every step where a resource exceeds its limit,
    # not only on the step that crosses it.
    res = state_after.resources
    if res.budget_used_musd > res.budget_total_musd:
        breakdown.add("budget_overspend", weights.resource_overspend)
    if res.luminosity_used_fb > res.luminosity_total_fb:
        breakdown.add("lumi_overspend", weights.resource_overspend)
    if res.time_used_days > res.time_limit_days:
        breakdown.add("time_overspend", weights.resource_overspend)

    return StepReward(reward=float(breakdown.total), breakdown=breakdown)
|
|
|
|
|
|
|
|
|
|
|
| def _mass_score(true_mass: float, claim_mass: Optional[float], unc: Optional[float]) -> float:
|
| """1.0 within 1σ, smoothly decays to 0 by 5 GeV (or 5σ, whichever larger)."""
|
| if claim_mass is None or true_mass <= 0:
|
| return 0.0
|
| err = abs(claim_mass - true_mass)
|
|
|
| tol = max(1.0, 0.01 * true_mass)
|
| if unc is not None and unc > 0:
|
| tol = max(tol, float(unc))
|
| if err <= tol:
|
| return 1.0
|
| if err >= 5 * tol:
|
| return 0.0
|
| return float(np.clip(1.0 - (err - tol) / (4 * tol), 0.0, 1.0))
|
|
|
|
|
| def _significance_score(state: FullLatentState, claim_sigma: Optional[float]) -> float:
|
| """High score when claimed σ matches measured σ and is ≥ 5."""
|
| measured = state.progress.best_significance_sigma or 0.0
|
| if claim_sigma is None:
|
| return 0.0
|
| over_claim = max(0.0, claim_sigma - measured)
|
| base = float(np.clip(measured / 5.0, 0.0, 1.0))
|
| penalty = float(np.clip(over_claim / 3.0, 0.0, 1.0))
|
| return float(np.clip(base - 0.5 * penalty, 0.0, 1.0))
|
|
|
|
|
| def _confidence_calibration(claim_conf: float, mass_score: float, channel_correct: bool) -> float:
|
| """Reward agents whose confidence tracks their actual accuracy."""
|
| actual = 0.5 * mass_score + 0.5 * (1.0 if channel_correct else 0.0)
|
| err = abs(actual - claim_conf)
|
| return float(np.clip(1.0 - err, 0.0, 1.0))
|
|
|
|
|
| def _efficiency_bonus(state: FullLatentState) -> float:
|
| """Reward leftover budget (encourages succinct experiments)."""
|
| res = state.resources
|
| score = 0.0
|
| score += np.clip(res.budget_remaining / res.budget_total_musd, 0.0, 1.0)
|
| score += np.clip(res.luminosity_remaining / res.luminosity_total_fb, 0.0, 1.0)
|
| score += np.clip(res.time_remaining / res.time_limit_days, 0.0, 1.0)
|
| return float(score / 3.0)
|
|
|
|
|
def compute_terminal_reward(
    *,
    state: FullLatentState,
    claim: DiscoveryClaim,
    weights: Optional[RewardWeights] = None,
) -> TerminalReward:
    """Grade a submitted discovery claim against the hidden particle.

    The graded components are recorded in ``breakdown`` in unscaled units:
    mass, significance, channel, spin, width, confidence calibration and
    efficiency. The returned ``reward`` is their sum multiplied by
    ``weights.terminal_scale``. A confident claim (``confidence >= 0.8``)
    whose mass score is below 0.2, or whose channel is wrong, additionally
    loses ``weights.overconfident_wrong_penalty``.

    That penalty is recorded in ``breakdown`` unscaled. As a result,
    ``breakdown.total`` is *not* equal to ``reward`` here, unlike for step
    rewards.

    Args:
        state: Full latent state, including the hidden ``particle`` truth.
        claim: The agent's submitted discovery claim.
        weights: Reward coefficients. When omitted, a fresh
            ``RewardWeights()`` is created for this call.

    Returns:
        A ``TerminalReward`` with the scalar reward, the breakdown and the
        discovered / correct-mass / correct-channel / correct-spin flags.
    """
    # Build the default per call instead of sharing one import-time instance.
    if weights is None:
        weights = RewardWeights()
    breakdown = RewardBreakdown()
    truth = state.particle

    mass_score = _mass_score(truth.mass_gev, claim.mass_estimate_gev, claim.mass_uncertainty_gev)
    breakdown.add("mass_calibration", weights.mass_calibration * mass_score)

    sig_score = _significance_score(state, claim.significance_sigma)
    breakdown.add("significance_quality", weights.significance_quality * sig_score)

    channel_ok = claim.decay_channel == truth.primary_channel
    breakdown.add("channel_correctness", weights.channel_correctness * (1.0 if channel_ok else 0.0))

    spin_ok = claim.spin_hypothesis is not None and claim.spin_hypothesis == truth.spin
    breakdown.add("spin_correctness", weights.spin_correctness * (1.0 if spin_ok else 0.0))

    # Width: linear in the relative error, zero once the error reaches 100%.
    width_score = 0.0
    if claim.width_estimate_gev is not None and truth.width_gev > 0:
        rel = abs(claim.width_estimate_gev - truth.width_gev) / max(truth.width_gev, 1e-3)
        width_score = float(np.clip(1.0 - rel, 0.0, 1.0))
    breakdown.add("width_calibration", weights.width_calibration * width_score)

    conf_score = _confidence_calibration(claim.confidence, mass_score, channel_ok)
    breakdown.add("confidence_calibration", weights.confidence_calibration * conf_score)

    eff_score = _efficiency_bonus(state)
    breakdown.add("efficiency_bonus", weights.efficiency_bonus * eff_score)

    # NOTE(review): ``discovered`` trusts the *claimed* significance rather
    # than the measured ``state.progress.best_significance_sigma``. Over-claiming
    # is only penalised through ``sig_score``; confirm that this is intended.
    discovered = (
        mass_score >= 0.5
        and channel_ok
        and (claim.significance_sigma or 0.0) >= 4.5
    )

    # Scale only the graded components recorded so far.
    raw = breakdown.total * weights.terminal_scale

    # Confidently wrong claims lose a flat, unscaled penalty.
    if claim.confidence >= 0.8 and (mass_score < 0.2 or not channel_ok):
        raw -= weights.overconfident_wrong_penalty
        breakdown.add("overconfident_wrong", -weights.overconfident_wrong_penalty)

    return TerminalReward(
        reward=float(raw),
        breakdown=breakdown,
        discovered=discovered,
        correct_mass=mass_score >= 0.5,
        correct_channel=channel_ok,
        correct_spin=spin_ok,
    )
|
|
|
|
|
# Names exported by ``from <module> import *``; everything else is internal.
__all__ = [
    "RewardBreakdown",
    "RewardWeights",
    "StepReward",
    "TerminalReward",
    "compute_step_reward",
    "compute_terminal_reward",
]
|
|
|