# cernenv/tests/test_rewards.py
"""Tests for the per-step + terminal reward function.
This is the most safety-critical module in CERNenv. The test suite is
structured around the same anti-hacking principles called out in the
hackathon FAQ (Q12, Q13, Q42, Q56, Q57):
* the terminal grade should *dominate* total reward,
* shaping rewards must be hard to farm,
* obvious model "cheats" (string-spam, claim-spam, JSON-spam) must
not produce high reward.
"""
from __future__ import annotations
import pytest
from models import (
ActionType,
DiscoveryClaim,
ExperimentAction,
IntermediateOutput,
OutputType,
)
from server.rewards.reward_function import (
RewardWeights,
_mass_score,
_significance_overclaim,
_significance_score,
compute_step_reward,
compute_terminal_reward,
)
from server.rules.engine import RuleResult, ViolationCode
from server.simulator.latent_state import FullLatentState, LatentParticle, ResourceState
# ── helpers ─────────────────────────────────────────────────────────────
def _passing_rule_result() -> RuleResult:
    """Build a RuleResult that allows the action and carries no violations."""
    return RuleResult(allowed=True)
def _failing_rule_result(*violations: ViolationCode) -> RuleResult:
    """Build a RuleResult carrying the given violation codes.

    NOTE(review): constructed with ``allowed=True`` — presumably
    ``RuleResult.add`` flips the allowed flag when a violation is
    recorded; confirm against server.rules.engine.
    """
    result = RuleResult(allowed=True)
    for violation in violations:
        result.add(violation, str(violation))
    return result
def _ok_output(action_type: ActionType = ActionType.CONFIGURE_BEAM) -> IntermediateOutput:
    """Return a successful, high-quality beam-config output stub."""
    return IntermediateOutput(
        summary="ok",
        success=True,
        quality_score=0.9,
        step_index=0,
        output_type=OutputType.BEAM_CONFIG,
    )
def _fresh_state() -> FullLatentState:
    """Return a pristine latent state with default particle and resources."""
    return FullLatentState(particle=LatentParticle(), resources=ResourceState())
# ── _mass_score ─────────────────────────────────────────────────────────
def test_mass_score_perfect_inside_tolerance():
    """An exact mass match scores a perfect 1.0."""
    score = _mass_score(125.0, 125.0, unc=None)
    assert score == pytest.approx(1.0)
def test_mass_score_decays_outside_tolerance():
    """The score shrinks as the claimed mass drifts from the true mass."""
    near_miss = _mass_score(125.0, 125.5, unc=None)
    far_miss = _mass_score(125.0, 130.0, unc=None)
    assert 0.0 < far_miss < near_miss <= 1.0
def test_mass_score_zero_when_far_off():
    """A wildly wrong mass claim earns exactly zero."""
    score = _mass_score(125.0, 200.0, unc=None)
    assert score == 0.0
def test_mass_score_returns_zero_when_claim_missing():
    """A missing (None) mass claim must score zero, not raise.

    Uses the keyword form ``unc=None`` for consistency with the other
    ``_mass_score`` tests in this section.
    """
    assert _mass_score(125.0, None, unc=None) == 0.0
# ── _significance_score and overclaim ───────────────────────────────────
def test_significance_score_uses_measured_value():
    """Under-claiming is fine (we just return the base); over-claiming is
    actively penalised (anti-hacking)."""
    state = _fresh_state()
    state.progress.best_significance_sigma = 5.0
    exact_claim = _significance_score(state, claim_sigma=5.0)
    inflated_claim = _significance_score(state, claim_sigma=20.0)
    missing_claim = _significance_score(state, claim_sigma=None)
    assert exact_claim == pytest.approx(1.0)
    assert inflated_claim < exact_claim
    assert missing_claim == 0.0
def test_significance_overclaim_only_above_threshold():
    """Small over-claims are tolerated; large ones trigger the penalty."""
    state = _fresh_state()
    state.progress.best_significance_sigma = 4.0
    mild_overclaim = _significance_overclaim(state, claim_sigma=4.5)
    gross_overclaim = _significance_overclaim(state, claim_sigma=10.0)
    assert mild_overclaim == 0.0
    assert gross_overclaim > 0.0
# ── compute_step_reward: tool_fit gating (anti-hacking) ─────────────────
def test_bogus_method_string_is_penalised_not_rewarded():
    """A made-up method string must draw a penalty, never a tool_fit bonus."""
    state = _fresh_state()
    result = compute_step_reward(
        action=ExperimentAction(
            action_type=ActionType.FIT_RESONANCE,
            method="LITERAL_GIBBERISH_BOGUS_LMAO",
        ),
        output=_ok_output(),
        state_before=state,
        state_after=state,
        rule_result=_passing_rule_result(),
    )
    components = result.breakdown.components
    assert "tool_fit" not in components
    assert components.get("bogus_method", 0.0) < 0.0
def test_real_method_with_correct_category_is_rewarded():
    """A registered method whose category matches the action earns tool_fit."""
    state = _fresh_state()
    result = compute_step_reward(
        action=ExperimentAction(
            action_type=ActionType.FIT_RESONANCE,
            # ANALYSIS category, matches FIT_RESONANCE
            method="ROOT_RooFit",
        ),
        output=_ok_output(),
        state_before=state,
        state_after=state,
        rule_result=_passing_rule_result(),
    )
    assert result.breakdown.components.get("tool_fit", 0.0) > 0.0
def test_real_method_with_mismatched_category_is_silent():
    """A registered method in the wrong category gets neither bonus nor penalty."""
    state = _fresh_state()
    result = compute_step_reward(
        action=ExperimentAction(
            action_type=ActionType.CALIBRATE_DETECTOR,
            # STATISTICS, mismatch with CALIBRATION
            method="BumpHunter",
        ),
        output=_ok_output(),
        state_before=state,
        state_after=state,
        rule_result=_passing_rule_result(),
    )
    components = result.breakdown.components
    assert "tool_fit" not in components
    # method IS in the registry → no bogus_method penalty either
    assert "bogus_method" not in components
# ── compute_step_reward: repeat-action penalty ──────────────────────────
def test_repeat_action_penalty_escalates():
    """Three identical action_types in a row should be penalised."""
    from models import PipelineStepRecord

    state = _fresh_state()
    repeated_type = ActionType.REQUEST_THEORY_REVIEW
    prior_steps = [
        PipelineStepRecord(
            step_index=idx,
            action_type=repeated_type,
            output_type=OutputType.THEORY_REVIEW,
            output_summary="...",
        )
        for idx in range(3)
    ]
    result = compute_step_reward(
        action=ExperimentAction(action_type=repeated_type),
        output=_ok_output(),
        state_before=state,
        state_after=state,
        rule_result=_passing_rule_result(),
        history=prior_steps,
    )
    assert result.breakdown.components.get("repeat_action", 0.0) < 0.0
def test_no_repeat_penalty_for_first_use():
    """With an empty history the repeat-action penalty must not fire."""
    state = _fresh_state()
    result = compute_step_reward(
        action=ExperimentAction(action_type=ActionType.CONFIGURE_BEAM),
        output=_ok_output(),
        state_before=state,
        state_after=state,
        rule_result=_passing_rule_result(),
        history=[],
    )
    assert "repeat_action" not in result.breakdown.components
# ── compute_step_reward: clip ───────────────────────────────────────────
def test_step_reward_is_clipped_above():
    """With a tight clip configured, the step reward cannot exceed the cap."""
    state = _fresh_state()
    tight_weights = RewardWeights(step_reward_clip=0.1)
    result = compute_step_reward(
        action=ExperimentAction(action_type=ActionType.CONFIGURE_BEAM, method="ATLAS_HLT"),
        output=_ok_output(),
        state_before=state,
        state_after=state,
        rule_result=_passing_rule_result(),
        weights=tight_weights,
    )
    # small epsilon tolerates float rounding at the clip boundary
    assert result.reward <= 0.1 + 1e-9
# ── compute_terminal_reward: correctness + overconfidence ───────────────
def test_terminal_reward_high_for_correct_calibrated_claim():
    """A fully correct, well-calibrated claim should dominate step shaping."""
    state = _fresh_state()
    state.particle = LatentParticle(
        mass_gev=125.0, primary_channel="diphoton", spin=0, parity="+", width_gev=0.004,
    )
    state.progress.best_significance_sigma = 5.5
    correct_claim = DiscoveryClaim(
        mass_estimate_gev=125.0,
        mass_uncertainty_gev=0.5,
        significance_sigma=5.5,
        decay_channel="diphoton",
        spin_hypothesis=0,
        parity="+",
        confidence=0.9,
    )
    outcome = compute_terminal_reward(state=state, claim=correct_claim)
    assert outcome.discovered
    assert outcome.correct_mass and outcome.correct_channel and outcome.correct_spin
    assert outcome.reward > 1.0
def test_terminal_reward_overconfident_wrong_is_punished():
    """High confidence on a wrong claim must go negative (anti-hacking)."""
    state = _fresh_state()
    state.particle = LatentParticle(mass_gev=125.0, primary_channel="diphoton")
    state.progress.best_significance_sigma = 4.5
    wrong_claim = DiscoveryClaim(
        mass_estimate_gev=600.0,  # way off
        decay_channel="dijet",  # wrong
        significance_sigma=5.0,
        confidence=0.95,
    )
    outcome = compute_terminal_reward(state=state, claim=wrong_claim)
    assert not outcome.discovered
    assert outcome.reward < 0.0
    assert outcome.breakdown.components.get("overconfident_wrong", 0.0) < 0.0
def test_significance_overclaim_penalty_fires():
    """Claiming huge significance on weak evidence triggers the penalty term."""
    state = _fresh_state()
    state.particle = LatentParticle(mass_gev=125.0, primary_channel="diphoton")
    state.progress.best_significance_sigma = 1.0  # weak evidence
    inflated_claim = DiscoveryClaim(
        mass_estimate_gev=125.0,
        decay_channel="diphoton",
        significance_sigma=20.0,  # absurd over-claim
        confidence=0.5,
    )
    outcome = compute_terminal_reward(state=state, claim=inflated_claim)
    assert outcome.breakdown.components.get("overclaim_significance", 0.0) < 0.0
def test_no_information_claim_clamped_nonpositive():
    """An empty claim (all fields None / zero) cannot earn positive reward."""
    empty_claim = DiscoveryClaim()
    outcome = compute_terminal_reward(state=_fresh_state(), claim=empty_claim)
    assert outcome.reward <= 0.0
# ── Hard / soft / failure penalties ─────────────────────────────────────
def test_hard_violation_dominates_step_reward():
    """A rule violation on a failed step must push the step reward negative."""
    state = _fresh_state()
    failure_output = IntermediateOutput(
        output_type=OutputType.FAILURE_REPORT,
        step_index=0,
        success=False,
        quality_score=0.0,
    )
    result = compute_step_reward(
        action=ExperimentAction(action_type=ActionType.COLLECT_COLLISIONS),
        output=failure_output,
        state_before=state,
        state_after=state,
        rule_result=_failing_rule_result(ViolationCode.PREREQ_MISSING),
    )
    assert result.reward < 0.0