"""Tests for the per-step + terminal reward function. This is the most safety-critical module in CERNenv. The test suite is structured around the same anti-hacking principles called out in the hackathon FAQ (Q12, Q13, Q42, Q56, Q57): * the terminal grade should *dominate* total reward, * shaping rewards must be hard to farm, * obvious model "cheats" (string-spam, claim-spam, JSON-spam) must not produce high reward. """ from __future__ import annotations import pytest from models import ( ActionType, DiscoveryClaim, ExperimentAction, IntermediateOutput, OutputType, ) from server.rewards.reward_function import ( RewardWeights, _mass_score, _significance_overclaim, _significance_score, compute_step_reward, compute_terminal_reward, ) from server.rules.engine import RuleResult, ViolationCode from server.simulator.latent_state import FullLatentState, LatentParticle, ResourceState # ── helpers ───────────────────────────────────────────────────────────── def _passing_rule_result() -> RuleResult: return RuleResult(allowed=True) def _failing_rule_result(*violations: ViolationCode) -> RuleResult: r = RuleResult(allowed=True) for v in violations: r.add(v, str(v)) return r def _ok_output(action_type: ActionType = ActionType.CONFIGURE_BEAM) -> IntermediateOutput: return IntermediateOutput( output_type=OutputType.BEAM_CONFIG, step_index=0, success=True, quality_score=0.9, summary="ok", ) def _fresh_state() -> FullLatentState: return FullLatentState( particle=LatentParticle(), resources=ResourceState(), ) # ── _mass_score ───────────────────────────────────────────────────────── def test_mass_score_perfect_inside_tolerance(): assert _mass_score(125.0, 125.0, unc=None) == pytest.approx(1.0) def test_mass_score_decays_outside_tolerance(): high = _mass_score(125.0, 125.5, unc=None) low = _mass_score(125.0, 130.0, unc=None) assert 0.0 < low < high <= 1.0 def test_mass_score_zero_when_far_off(): assert _mass_score(125.0, 200.0, unc=None) == 0.0 def test_mass_score_returns_zero_when_claim_missing(): assert _mass_score(125.0, None, None) == 0.0 # ── _significance_score and overclaim ─────────────────────────────────── def test_significance_score_uses_measured_value(): """Under-claiming is fine (we just return the base); over-claiming is actively penalised (anti-hacking).""" s = _fresh_state() s.progress.best_significance_sigma = 5.0 score_match = _significance_score(s, claim_sigma=5.0) score_over = _significance_score(s, claim_sigma=20.0) score_none = _significance_score(s, claim_sigma=None) assert score_match == pytest.approx(1.0) assert score_over < score_match assert score_none == 0.0 def test_significance_overclaim_only_above_threshold(): s = _fresh_state() s.progress.best_significance_sigma = 4.0 assert _significance_overclaim(s, claim_sigma=4.5) == 0.0 assert _significance_overclaim(s, claim_sigma=10.0) > 0.0 # ── compute_step_reward: tool_fit gating (anti-hacking) ───────────────── def test_bogus_method_string_is_penalised_not_rewarded(): state = _fresh_state() action = ExperimentAction( action_type=ActionType.FIT_RESONANCE, method="LITERAL_GIBBERISH_BOGUS_LMAO", ) out = _ok_output() result = compute_step_reward( action=action, output=out, state_before=state, state_after=state, rule_result=_passing_rule_result(), ) assert "tool_fit" not in result.breakdown.components assert result.breakdown.components.get("bogus_method", 0.0) < 0.0 def test_real_method_with_correct_category_is_rewarded(): state = _fresh_state() action = ExperimentAction( action_type=ActionType.FIT_RESONANCE, method="ROOT_RooFit", # 
# ── _significance_score and overclaim ───────────────────────────────────


def test_significance_score_uses_measured_value():
    """Under-claiming is fine (we just return the base); over-claiming is
    actively penalised (anti-hacking)."""
    s = _fresh_state()
    s.progress.best_significance_sigma = 5.0
    score_match = _significance_score(s, claim_sigma=5.0)
    score_over = _significance_score(s, claim_sigma=20.0)
    score_none = _significance_score(s, claim_sigma=None)
    assert score_match == pytest.approx(1.0)
    assert score_over < score_match
    assert score_none == 0.0


def test_significance_overclaim_only_above_threshold():
    s = _fresh_state()
    s.progress.best_significance_sigma = 4.0
    assert _significance_overclaim(s, claim_sigma=4.5) == 0.0
    assert _significance_overclaim(s, claim_sigma=10.0) > 0.0


# ── compute_step_reward: tool_fit gating (anti-hacking) ─────────────────


def test_bogus_method_string_is_penalised_not_rewarded():
    state = _fresh_state()
    action = ExperimentAction(
        action_type=ActionType.FIT_RESONANCE,
        method="LITERAL_GIBBERISH_BOGUS_LMAO",
    )
    out = _ok_output()
    result = compute_step_reward(
        action=action,
        output=out,
        state_before=state,
        state_after=state,
        rule_result=_passing_rule_result(),
    )
    assert "tool_fit" not in result.breakdown.components
    assert result.breakdown.components.get("bogus_method", 0.0) < 0.0


def test_real_method_with_correct_category_is_rewarded():
    state = _fresh_state()
    action = ExperimentAction(
        action_type=ActionType.FIT_RESONANCE,
        method="ROOT_RooFit",  # ANALYSIS category, matches FIT_RESONANCE
    )
    out = _ok_output()
    result = compute_step_reward(
        action=action,
        output=out,
        state_before=state,
        state_after=state,
        rule_result=_passing_rule_result(),
    )
    assert result.breakdown.components.get("tool_fit", 0.0) > 0.0


def test_real_method_with_mismatched_category_is_silent():
    state = _fresh_state()
    action = ExperimentAction(
        action_type=ActionType.CALIBRATE_DETECTOR,
        method="BumpHunter",  # STATISTICS, mismatch with CALIBRATION
    )
    out = _ok_output()
    result = compute_step_reward(
        action=action,
        output=out,
        state_before=state,
        state_after=state,
        rule_result=_passing_rule_result(),
    )
    assert "tool_fit" not in result.breakdown.components
    # method IS in the registry → no bogus_method penalty either
    assert "bogus_method" not in result.breakdown.components


# ── compute_step_reward: repeat-action penalty ──────────────────────────


def test_repeat_action_penalty_escalates():
    """Three identical action_types in a row should be penalised."""
    state = _fresh_state()
    action_type = ActionType.REQUEST_THEORY_REVIEW
    history = [
        PipelineStepRecord(
            step_index=i,
            action_type=action_type,
            output_type=OutputType.THEORY_REVIEW,
            output_summary="...",
        )
        for i in range(3)
    ]
    action = ExperimentAction(action_type=action_type)
    result = compute_step_reward(
        action=action,
        output=_ok_output(),
        state_before=state,
        state_after=state,
        rule_result=_passing_rule_result(),
        history=history,
    )
    assert result.breakdown.components.get("repeat_action", 0.0) < 0.0


def test_no_repeat_penalty_for_first_use():
    state = _fresh_state()
    action = ExperimentAction(action_type=ActionType.CONFIGURE_BEAM)
    result = compute_step_reward(
        action=action,
        output=_ok_output(),
        state_before=state,
        state_after=state,
        rule_result=_passing_rule_result(),
        history=[],
    )
    assert "repeat_action" not in result.breakdown.components


# ── compute_step_reward: clip ───────────────────────────────────────────


def test_step_reward_is_clipped_above():
    state = _fresh_state()
    weights = RewardWeights(step_reward_clip=0.1)
    # registry method included so positive shaping components can accrue
    action = ExperimentAction(action_type=ActionType.CONFIGURE_BEAM, method="ATLAS_HLT")
    result = compute_step_reward(
        action=action,
        output=_ok_output(),
        state_before=state,
        state_after=state,
        rule_result=_passing_rule_result(),
        weights=weights,
    )
    assert result.reward <= 0.1 + 1e-9


# ── compute_terminal_reward: correctness + overconfidence ───────────────


def test_terminal_reward_high_for_correct_calibrated_claim():
    s = _fresh_state()
    s.particle = LatentParticle(
        mass_gev=125.0,
        primary_channel="diphoton",
        spin=0,
        parity="+",
        width_gev=0.004,
    )
    s.progress.best_significance_sigma = 5.5
    claim = DiscoveryClaim(
        mass_estimate_gev=125.0,
        mass_uncertainty_gev=0.5,
        significance_sigma=5.5,
        decay_channel="diphoton",
        spin_hypothesis=0,
        parity="+",
        confidence=0.9,
    )
    out = compute_terminal_reward(state=s, claim=claim)
    assert out.discovered
    assert out.correct_mass and out.correct_channel and out.correct_spin
    assert out.reward > 1.0


def test_terminal_reward_overconfident_wrong_is_punished():
    s = _fresh_state()
    s.particle = LatentParticle(mass_gev=125.0, primary_channel="diphoton")
    s.progress.best_significance_sigma = 4.5
    claim = DiscoveryClaim(
        mass_estimate_gev=600.0,  # way off
        decay_channel="dijet",  # wrong
        significance_sigma=5.0,
        confidence=0.95,
    )
    out = compute_terminal_reward(state=s, claim=claim)
    assert not out.discovered
    assert out.reward < 0.0
    assert out.breakdown.components.get("overconfident_wrong", 0.0) < 0.0
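
# Hypothetical companion sketch to the test above (our addition, not part
# of the original suite): it assumes the overconfident_wrong penalty
# scales with the stated confidence, so an equally wrong but *hedged*
# claim should never score below the overconfident one.


def test_terminal_reward_humble_wrong_is_punished_less():
    s = _fresh_state()
    s.particle = LatentParticle(mass_gev=125.0, primary_channel="diphoton")
    s.progress.best_significance_sigma = 4.5
    wrong = dict(mass_estimate_gev=600.0, decay_channel="dijet", significance_sigma=5.0)
    humble = compute_terminal_reward(state=s, claim=DiscoveryClaim(confidence=0.2, **wrong))
    cocky = compute_terminal_reward(state=s, claim=DiscoveryClaim(confidence=0.95, **wrong))
    assert humble.reward >= cocky.reward
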
def test_significance_overclaim_penalty_fires():
    s = _fresh_state()
    s.particle = LatentParticle(mass_gev=125.0, primary_channel="diphoton")
    s.progress.best_significance_sigma = 1.0  # weak evidence
    claim = DiscoveryClaim(
        mass_estimate_gev=125.0,
        decay_channel="diphoton",
        significance_sigma=20.0,  # absurd over-claim
        confidence=0.5,
    )
    out = compute_terminal_reward(state=s, claim=claim)
    assert out.breakdown.components.get("overclaim_significance", 0.0) < 0.0


def test_no_information_claim_clamped_nonpositive():
    s = _fresh_state()
    claim = DiscoveryClaim()  # everything None / zero
    out = compute_terminal_reward(state=s, claim=claim)
    assert out.reward <= 0.0


# ── Hard / soft / failure penalties ─────────────────────────────────────


def test_hard_violation_dominates_step_reward():
    state = _fresh_state()
    rule = _failing_rule_result(ViolationCode.PREREQ_MISSING)
    action = ExperimentAction(action_type=ActionType.COLLECT_COLLISIONS)
    out = IntermediateOutput(
        output_type=OutputType.FAILURE_REPORT,
        step_index=0,
        success=False,
        quality_score=0.0,
    )
    r = compute_step_reward(
        action=action,
        output=out,
        state_before=state,
        state_after=state,
        rule_result=rule,
    )
    assert r.reward < 0.0
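
# End-to-end sketch of the docstring's first principle (the terminal grade
# should dominate total reward). This is an illustrative addition: it
# reuses the setup from test_terminal_reward_high_for_correct_calibrated_claim
# and the clip from test_step_reward_is_clipped_above, and checks that ten
# maximally-farmed shaping steps are still worth less than one correct,
# calibrated terminal claim.


def test_terminal_reward_dominates_clipped_shaping():
    s = _fresh_state()
    s.particle = LatentParticle(
        mass_gev=125.0,
        primary_channel="diphoton",
        spin=0,
        parity="+",
        width_gev=0.004,
    )
    s.progress.best_significance_sigma = 5.5
    claim = DiscoveryClaim(
        mass_estimate_gev=125.0,
        mass_uncertainty_gev=0.5,
        significance_sigma=5.5,
        decay_channel="diphoton",
        spin_hypothesis=0,
        parity="+",
        confidence=0.9,
    )
    terminal = compute_terminal_reward(state=s, claim=claim)
    step_clip = 0.1  # same clip as test_step_reward_is_clipped_above
    # terminal.reward > 1.0 is asserted above, and 10 * 0.1 == 1.0
    assert terminal.reward > 10 * step_clip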