"""Decomposed reward function.
Two stages:
1. **Per-step reward** ``compute_step_reward``: shapes behaviour with small
incentives (progress, evidence quality, valid prerequisites) and
penalties (rule violations, repeated work, wasted resources).
2. **Terminal reward** ``compute_terminal_reward``: graded only when the
agent submits a discovery claim or runs out of resources. Compares the
submitted claim against the hidden ``LatentParticle`` truth.
The terminal reward is intentionally dominant so the policy must care about
the *correct* discovery, not just looking busy.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Dict, List, Optional
import numpy as np
from models import (
    DiscoveryClaim,
    ExperimentAction,
    IntermediateOutput,
)
from server.rules.engine import RuleResult, ViolationCode
from server.simulator.latent_state import FullLatentState
# ── Configuration ────────────────────────────────────────────────────────
@dataclass
class RewardWeights:
# ── per-step shaping ────────────────────────────────────────
valid_action: float = 0.05
progress_milestone: float = 0.25
evidence_quality: float = 0.20
tool_fit: float = 0.10
soft_violation: float = -0.05
hard_violation: float = -0.50
redundancy: float = -0.10
resource_overspend: float = -0.30
failure: float = -0.30
# ── terminal grading ────────────────────────────────────────
terminal_scale: float = 5.0 # multiplied with the convex sum below
mass_calibration: float = 0.30
significance_quality: float = 0.20
channel_correctness: float = 0.20
spin_correctness: float = 0.10
width_calibration: float = 0.05
confidence_calibration: float = 0.10
efficiency_bonus: float = 0.05
overconfident_wrong_penalty: float = 4.0 # subtracted from terminal
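
# Sanity check on the terminal weights above: the seven grading components
# (0.30 + 0.20 + 0.20 + 0.10 + 0.05 + 0.10 + 0.05) sum to 1.0, so before
# penalties the terminal breakdown.total lies in [0, 1] and the scaled
# reward in [0, terminal_scale].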
# ── Outputs ──────────────────────────────────────────────────────────────
@dataclass
class RewardBreakdown:
components: Dict[str, float] = field(default_factory=dict)
total: float = 0.0
def add(self, key: str, value: float) -> None:
self.components[key] = self.components.get(key, 0.0) + value
self.total += value
@dataclass
class StepReward:
reward: float
breakdown: RewardBreakdown
@dataclass
class TerminalReward:
reward: float
breakdown: RewardBreakdown
discovered: bool
correct_mass: bool
correct_channel: bool
correct_spin: bool
# ── Per-step ─────────────────────────────────────────────────────────────
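# Ordered roughly along the analysis pipeline, from beam setup through
# significance estimation.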
_PROGRESS_FLAGS = [
"beam_configured",
"luminosity_allocated",
"trigger_set",
"collisions_collected",
"channel_selected",
"tracks_reconstructed",
"detector_calibrated",
"invariant_mass_built",
"background_subtracted",
"resonance_fitted",
"significance_estimated",
]
def _milestone_progress(state_before: FullLatentState, state_after: FullLatentState) -> int:
"""Number of new progress milestones unlocked this step."""
delta = 0
for flag in _PROGRESS_FLAGS:
was = getattr(state_before.progress, flag)
now = getattr(state_after.progress, flag)
if now and not was:
delta += 1
return delta
def compute_step_reward(
*,
action: ExperimentAction,
output: IntermediateOutput,
state_before: FullLatentState,
state_after: FullLatentState,
rule_result: RuleResult,
weights: RewardWeights = RewardWeights(),
) -> StepReward:
breakdown = RewardBreakdown()
if rule_result.allowed and output.success:
breakdown.add("valid_action", weights.valid_action)
if not output.success:
breakdown.add("failure", weights.failure)
# progress
new_milestones = _milestone_progress(state_before, state_after)
if new_milestones > 0:
breakdown.add("progress", weights.progress_milestone * new_milestones)
# evidence quality
if output.success:
breakdown.add("evidence_quality", weights.evidence_quality * float(output.quality_score))
    # tool fit: flat half credit for naming any method at all (no toolset
    # lookup happens here)
    if action.method:
        breakdown.add("tool_fit", weights.tool_fit * 0.5)
# rule penalties
if rule_result.violations:
breakdown.add("hard_violation", weights.hard_violation * len(rule_result.violations))
if rule_result.soft_violations:
soft_redundant = sum(1 for v in rule_result.soft_violations if v == ViolationCode.REDUNDANT)
soft_other = len(rule_result.soft_violations) - soft_redundant
if soft_redundant:
breakdown.add("redundancy", weights.redundancy * soft_redundant)
if soft_other:
breakdown.add("soft_violation", weights.soft_violation * soft_other)
# resource overspend
res = state_after.resources
if res.budget_used_musd > res.budget_total_musd:
breakdown.add("budget_overspend", weights.resource_overspend)
if res.luminosity_used_fb > res.luminosity_total_fb:
breakdown.add("lumi_overspend", weights.resource_overspend)
if res.time_used_days > res.time_limit_days:
breakdown.add("time_overspend", weights.resource_overspend)
return StepReward(reward=float(breakdown.total), breakdown=breakdown)
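
# Worked step-reward example (pure arithmetic, hypothetical numbers): a
# successful, rule-clean step that names a method, unlocks two milestones,
# and reports quality_score 0.8 earns
#   0.05 (valid) + 2 * 0.25 (progress) + 0.20 * 0.8 (evidence) + 0.05 (tool fit) = 0.76.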
# ── Terminal grading ─────────────────────────────────────────────────────
def _mass_score(true_mass: float, claim_mass: Optional[float], unc: Optional[float]) -> float:
"""1.0 within 1σ, smoothly decays to 0 by 5 GeV (or 5σ, whichever larger)."""
if claim_mass is None or true_mass <= 0:
return 0.0
err = abs(claim_mass - true_mass)
# Tolerance: max(1.0 GeV, 1% of true mass, claimed unc)
tol = max(1.0, 0.01 * true_mass)
if unc is not None and unc > 0:
tol = max(tol, float(unc))
if err <= tol:
return 1.0
if err >= 5 * tol:
return 0.0
return float(np.clip(1.0 - (err - tol) / (4 * tol), 0.0, 1.0))
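
# Worked example (pure arithmetic): true_mass=125.0, claim=127.5, unc=None
# gives tol = max(1.0, 1.25) = 1.25 and err = 2.5, so the score is
# 1 - (2.5 - 1.25) / (4 * 1.25) = 0.75.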
def _significance_score(state: FullLatentState, claim_sigma: Optional[float]) -> float:
"""High score when claimed σ matches measured σ and is ≥ 5."""
measured = state.progress.best_significance_sigma or 0.0
if claim_sigma is None:
return 0.0
over_claim = max(0.0, claim_sigma - measured)
base = float(np.clip(measured / 5.0, 0.0, 1.0))
penalty = float(np.clip(over_claim / 3.0, 0.0, 1.0))
return float(np.clip(base - 0.5 * penalty, 0.0, 1.0))
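
# Worked example: measured 5.2σ, claimed 6.0σ -> base = clip(5.2/5) = 1.0,
# over-claim 0.8 -> penalty = 0.8/3 ≈ 0.267, score = 1.0 - 0.5 * 0.267 ≈ 0.87.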
def _confidence_calibration(claim_conf: float, mass_score: float, channel_correct: bool) -> float:
"""Reward agents whose confidence tracks their actual accuracy."""
actual = 0.5 * mass_score + 0.5 * (1.0 if channel_correct else 0.0)
err = abs(actual - claim_conf)
return float(np.clip(1.0 - err, 0.0, 1.0))
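
# Worked example: mass_score 0.75 with the channel correct gives an actual
# accuracy of 0.5 * 0.75 + 0.5 = 0.875; a stated confidence of 0.9 misses it
# by 0.025 and scores 0.975.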
def _efficiency_bonus(state: FullLatentState) -> float:
    """Reward leftover budget, luminosity, and time (encourages succinct experiments)."""
    res = state.resources
    score = 0.0
    # Guard the denominators so a zero-total configuration cannot divide by zero.
    score += np.clip(res.budget_remaining / max(res.budget_total_musd, 1e-9), 0.0, 1.0)
    score += np.clip(res.luminosity_remaining / max(res.luminosity_total_fb, 1e-9), 0.0, 1.0)
    score += np.clip(res.time_remaining / max(res.time_limit_days, 1e-9), 0.0, 1.0)
    return float(score / 3.0)
def compute_terminal_reward(
*,
state: FullLatentState,
claim: DiscoveryClaim,
weights: RewardWeights = RewardWeights(),
) -> TerminalReward:
breakdown = RewardBreakdown()
truth = state.particle
mass_score = _mass_score(truth.mass_gev, claim.mass_estimate_gev, claim.mass_uncertainty_gev)
breakdown.add("mass_calibration", weights.mass_calibration * mass_score)
sig_score = _significance_score(state, claim.significance_sigma)
breakdown.add("significance_quality", weights.significance_quality * sig_score)
channel_ok = claim.decay_channel == truth.primary_channel
breakdown.add("channel_correctness", weights.channel_correctness * (1.0 if channel_ok else 0.0))
spin_ok = claim.spin_hypothesis is not None and claim.spin_hypothesis == truth.spin
breakdown.add("spin_correctness", weights.spin_correctness * (1.0 if spin_ok else 0.0))
width_score = 0.0
if claim.width_estimate_gev is not None and truth.width_gev > 0:
rel = abs(claim.width_estimate_gev - truth.width_gev) / max(truth.width_gev, 1e-3)
width_score = float(np.clip(1.0 - rel, 0.0, 1.0))
breakdown.add("width_calibration", weights.width_calibration * width_score)
conf_score = _confidence_calibration(claim.confidence, mass_score, channel_ok)
breakdown.add("confidence_calibration", weights.confidence_calibration * conf_score)
eff_score = _efficiency_bonus(state)
breakdown.add("efficiency_bonus", weights.efficiency_bonus * eff_score)
discovered = (
mass_score >= 0.5
and channel_ok
and (claim.significance_sigma or 0.0) >= 4.5
)
raw = breakdown.total * weights.terminal_scale
    # Overconfident-wrong penalty: high confidence despite a far-off mass OR a
    # wrong channel
if claim.confidence >= 0.8 and (mass_score < 0.2 or not channel_ok):
raw -= weights.overconfident_wrong_penalty
breakdown.add("overconfident_wrong", -weights.overconfident_wrong_penalty)
return TerminalReward(
reward=float(raw),
breakdown=breakdown,
discovered=discovered,
correct_mass=mass_score >= 0.5,
correct_channel=channel_ok,
correct_spin=spin_ok,
)
__all__ = [
"RewardBreakdown",
"RewardWeights",
"StepReward",
"TerminalReward",
"compute_step_reward",
"compute_terminal_reward",
]
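

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the environment. The real server
    # passes genuine FullLatentState / DiscoveryClaim instances; here we fake
    # only the attributes compute_terminal_reward actually reads, using
    # SimpleNamespace stand-ins (every field value below is hypothetical).
    from types import SimpleNamespace

    state = SimpleNamespace(
        particle=SimpleNamespace(
            mass_gev=125.0, width_gev=0.004, spin=0, primary_channel="diphoton"
        ),
        progress=SimpleNamespace(best_significance_sigma=5.2),
        resources=SimpleNamespace(
            budget_remaining=20.0, budget_total_musd=100.0,
            luminosity_remaining=30.0, luminosity_total_fb=150.0,
            time_remaining=40.0, time_limit_days=365.0,
        ),
    )
    claim = SimpleNamespace(
        mass_estimate_gev=125.4, mass_uncertainty_gev=0.5,
        significance_sigma=5.1, decay_channel="diphoton",
        spin_hypothesis=0, width_estimate_gev=0.005, confidence=0.85,
    )
    # Duck typing suffices here: the function only reads attributes.
    result = compute_terminal_reward(state=state, claim=claim)  # type: ignore[arg-type]
    print(result.reward, result.discovered, result.breakdown.components)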