| """Decomposed reward function. | |
| Two stages: | |
| 1. **Per-step reward** ``compute_step_reward``: shapes behaviour with small | |
| incentives (progress, evidence quality, valid prerequisites) and | |
| penalties (rule violations, repeated work, wasted resources). | |
| 2. **Terminal reward** ``compute_terminal_reward``: graded only when the | |
| agent submits a discovery claim or runs out of resources. Compares the | |
| submitted claim against the hidden ``LatentParticle`` truth. | |
| The terminal reward is intentionally dominant so the policy must care about | |
| the *correct* discovery, not just looking busy. | |
| """ | |
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Dict, Optional

import numpy as np

from models import (
    DiscoveryClaim,
    ExperimentAction,
    IntermediateOutput,
)
from server.rules.engine import RuleResult, ViolationCode
from server.simulator.latent_state import FullLatentState
# ── Configuration ────────────────────────────────────────────────────────

@dataclass(frozen=True)
class RewardWeights:
    # ── per-step shaping ──────────────────────────────────────
    valid_action: float = 0.05
    progress_milestone: float = 0.25
    evidence_quality: float = 0.20
    tool_fit: float = 0.10
    soft_violation: float = -0.05
    hard_violation: float = -0.50
    redundancy: float = -0.10
    resource_overspend: float = -0.30
    failure: float = -0.30

    # ── terminal grading ──────────────────────────────────────
    terminal_scale: float = 5.0  # multiplies the convex sum of the component scores below
    mass_calibration: float = 0.30
    significance_quality: float = 0.20
    channel_correctness: float = 0.20
    spin_correctness: float = 0.10
    width_calibration: float = 0.05
    confidence_calibration: float = 0.10
    efficiency_bonus: float = 0.05

    overconfident_wrong_penalty: float = 4.0  # subtracted from the terminal reward
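
# Note: the seven terminal component weights above sum to exactly 1.0, so the
# positive part of the terminal reward is bounded by ``terminal_scale * 1.0``
# (5.0 with the defaults) before the overconfident-wrong penalty:
#   0.30 + 0.20 + 0.20 + 0.10 + 0.05 + 0.10 + 0.05 = 1.00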

# ── Outputs ──────────────────────────────────────────────────────────────

@dataclass
class RewardBreakdown:
    components: Dict[str, float] = field(default_factory=dict)
    total: float = 0.0

    def add(self, key: str, value: float) -> None:
        self.components[key] = self.components.get(key, 0.0) + value
        self.total += value
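
# Example (illustrative):
#   bd = RewardBreakdown()
#   bd.add("valid_action", 0.05)
#   bd.add("progress", 0.25)
#   bd.total       # -> 0.30
#   bd.components  # -> {"valid_action": 0.05, "progress": 0.25}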

@dataclass
class StepReward:
    reward: float
    breakdown: RewardBreakdown


@dataclass
class TerminalReward:
    reward: float
    breakdown: RewardBreakdown
    discovered: bool
    correct_mass: bool
    correct_channel: bool
    correct_spin: bool

# ── Per-step ─────────────────────────────────────────────────────────────

_PROGRESS_FLAGS = [
    "beam_configured",
    "luminosity_allocated",
    "trigger_set",
    "collisions_collected",
    "channel_selected",
    "tracks_reconstructed",
    "detector_calibrated",
    "invariant_mass_built",
    "background_subtracted",
    "resonance_fitted",
    "significance_estimated",
]

def _milestone_progress(state_before: FullLatentState, state_after: FullLatentState) -> int:
    """Number of new progress milestones unlocked this step."""
    delta = 0
    for flag in _PROGRESS_FLAGS:
        was = getattr(state_before.progress, flag)
        now = getattr(state_after.progress, flag)
        if now and not was:
            delta += 1
    return delta

def compute_step_reward(
    *,
    action: ExperimentAction,
    output: IntermediateOutput,
    state_before: FullLatentState,
    state_after: FullLatentState,
    rule_result: RuleResult,
    weights: RewardWeights = RewardWeights(),
) -> StepReward:
    breakdown = RewardBreakdown()

    if rule_result.allowed and output.success:
        breakdown.add("valid_action", weights.valid_action)
    if not output.success:
        breakdown.add("failure", weights.failure)

    # progress
    new_milestones = _milestone_progress(state_before, state_after)
    if new_milestones > 0:
        breakdown.add("progress", weights.progress_milestone * new_milestones)

    # evidence quality
    if output.success:
        breakdown.add("evidence_quality", weights.evidence_quality * float(output.quality_score))

    # tool fit: half credit whenever the action explicitly names a method
    if action.method:
        breakdown.add("tool_fit", weights.tool_fit * 0.5)

    # rule penalties
    if rule_result.violations:
        breakdown.add("hard_violation", weights.hard_violation * len(rule_result.violations))
    if rule_result.soft_violations:
        soft_redundant = sum(1 for v in rule_result.soft_violations if v == ViolationCode.REDUNDANT)
        soft_other = len(rule_result.soft_violations) - soft_redundant
        if soft_redundant:
            breakdown.add("redundancy", weights.redundancy * soft_redundant)
        if soft_other:
            breakdown.add("soft_violation", weights.soft_violation * soft_other)

    # resource overspend
    res = state_after.resources
    if res.budget_used_musd > res.budget_total_musd:
        breakdown.add("budget_overspend", weights.resource_overspend)
    if res.luminosity_used_fb > res.luminosity_total_fb:
        breakdown.add("lumi_overspend", weights.resource_overspend)
    if res.time_used_days > res.time_limit_days:
        breakdown.add("time_overspend", weights.resource_overspend)

    return StepReward(reward=float(breakdown.total), breakdown=breakdown)
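
# Worked example (illustrative numbers): a successful, rule-clean step that
# unlocks one milestone, names a method, and reports quality_score = 0.8 earns
#   0.05 (valid_action) + 0.25 (progress) + 0.20 * 0.8 (evidence_quality)
#   + 0.10 * 0.5 (tool_fit) = 0.51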

# ── Terminal grading ─────────────────────────────────────────────────────

def _mass_score(true_mass: float, claim_mass: Optional[float], unc: Optional[float]) -> float:
    """1.0 within the tolerance, decaying linearly to 0 at 5x the tolerance.

    Tolerance = max(1.0 GeV, 1% of the true mass, the claimed uncertainty).
    """
    if claim_mass is None or true_mass <= 0:
        return 0.0
    err = abs(claim_mass - true_mass)
    tol = max(1.0, 0.01 * true_mass)
    if unc is not None and unc > 0:
        tol = max(tol, float(unc))
    if err <= tol:
        return 1.0
    if err >= 5 * tol:
        return 0.0
    return float(np.clip(1.0 - (err - tol) / (4 * tol), 0.0, 1.0))
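
# Examples: true mass 125 GeV with no claimed uncertainty -> tol = 1.25 GeV.
#   _mass_score(125.0, 125.5, None) -> 1.0   (err 0.5 <= tol)
#   _mass_score(125.0, 128.0, None) -> 0.65  (linear region: 1 - (3.0 - 1.25) / 5.0)
#   _mass_score(125.0, 132.0, None) -> 0.0   (err 7.0 >= 5 * tol)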

def _significance_score(state: FullLatentState, claim_sigma: Optional[float]) -> float:
    """Full credit when the measured σ reaches 5; claiming above the
    measured value is penalised as over-claiming."""
    measured = state.progress.best_significance_sigma or 0.0
    if claim_sigma is None:
        return 0.0
    over_claim = max(0.0, claim_sigma - measured)
    base = float(np.clip(measured / 5.0, 0.0, 1.0))
    penalty = float(np.clip(over_claim / 3.0, 0.0, 1.0))
    return float(np.clip(base - 0.5 * penalty, 0.0, 1.0))
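
# Example: measured 5.2σ. Claiming 5.0σ -> base 1.0, no over-claim -> 1.0.
# Claiming 8.2σ over-claims by 3σ -> penalty 1.0 -> score 1.0 - 0.5 = 0.5.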

def _confidence_calibration(claim_conf: float, mass_score: float, channel_correct: bool) -> float:
    """Reward agents whose confidence tracks their actual accuracy."""
    actual = 0.5 * mass_score + 0.5 * (1.0 if channel_correct else 0.0)
    err = abs(actual - claim_conf)
    return float(np.clip(1.0 - err, 0.0, 1.0))
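
# Example: mass_score 1.0 and a correct channel give actual accuracy 1.0, so
# a claimed confidence of 0.8 scores 1 - |1.0 - 0.8| = 0.8.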

def _efficiency_bonus(state: FullLatentState) -> float:
    """Mean leftover fraction of budget, luminosity, and time (encourages
    succinct experiments)."""
    res = state.resources
    score = 0.0
    score += np.clip(res.budget_remaining / res.budget_total_musd, 0.0, 1.0)
    score += np.clip(res.luminosity_remaining / res.luminosity_total_fb, 0.0, 1.0)
    score += np.clip(res.time_remaining / res.time_limit_days, 0.0, 1.0)
    return float(score / 3.0)
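
# Example: half the budget, half the luminosity, and a quarter of the time
# left -> (0.5 + 0.5 + 0.25) / 3 ≈ 0.417.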

def compute_terminal_reward(
    *,
    state: FullLatentState,
    claim: DiscoveryClaim,
    weights: RewardWeights = RewardWeights(),
) -> TerminalReward:
    breakdown = RewardBreakdown()
    truth = state.particle

    mass_score = _mass_score(truth.mass_gev, claim.mass_estimate_gev, claim.mass_uncertainty_gev)
    breakdown.add("mass_calibration", weights.mass_calibration * mass_score)

    sig_score = _significance_score(state, claim.significance_sigma)
    breakdown.add("significance_quality", weights.significance_quality * sig_score)

    channel_ok = claim.decay_channel == truth.primary_channel
    breakdown.add("channel_correctness", weights.channel_correctness * (1.0 if channel_ok else 0.0))

    spin_ok = claim.spin_hypothesis is not None and claim.spin_hypothesis == truth.spin
    breakdown.add("spin_correctness", weights.spin_correctness * (1.0 if spin_ok else 0.0))

    width_score = 0.0
    if claim.width_estimate_gev is not None and truth.width_gev > 0:
        rel = abs(claim.width_estimate_gev - truth.width_gev) / max(truth.width_gev, 1e-3)
        width_score = float(np.clip(1.0 - rel, 0.0, 1.0))
    breakdown.add("width_calibration", weights.width_calibration * width_score)

    conf_score = _confidence_calibration(claim.confidence, mass_score, channel_ok)
    breakdown.add("confidence_calibration", weights.confidence_calibration * conf_score)

    eff_score = _efficiency_bonus(state)
    breakdown.add("efficiency_bonus", weights.efficiency_bonus * eff_score)

    discovered = (
        mass_score >= 0.5
        and channel_ok
        and (claim.significance_sigma or 0.0) >= 4.5
    )

    raw = breakdown.total * weights.terminal_scale

    # Overconfident-wrong penalty: high confidence with a badly wrong mass or
    # a wrong channel
    if claim.confidence >= 0.8 and (mass_score < 0.2 or not channel_ok):
        raw -= weights.overconfident_wrong_penalty
        breakdown.add("overconfident_wrong", -weights.overconfident_wrong_penalty)

    return TerminalReward(
        reward=float(raw),
        breakdown=breakdown,
        discovered=discovered,
        correct_mass=mass_score >= 0.5,
        correct_channel=channel_ok,
        correct_spin=spin_ok,
    )
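
# Worked example (illustrative): a claim with perfect mass, channel, spin,
# significance, width, and confidence scores, plus eff_score = 0.4, gives
#   total = 0.30 + 0.20 + 0.20 + 0.10 + 0.05 + 0.10 + 0.05 * 0.4 = 0.97
#   reward = 5.0 * 0.97 = 4.85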

__all__ = [
    "RewardBreakdown",
    "RewardWeights",
    "StepReward",
    "TerminalReward",
    "compute_step_reward",
    "compute_terminal_reward",
]