| """Tests for the per-step + terminal reward function. | |
| This is the most safety-critical module in CERNenv. The test suite is | |
| structured around the same anti-hacking principles called out in the | |
| hackathon FAQ (Q12, Q13, Q42, Q56, Q57): | |
| * the terminal grade should *dominate* total reward, | |
| * shaping rewards must be hard to farm, | |
| * obvious model "cheats" (string-spam, claim-spam, JSON-spam) must | |
| not produce high reward. | |
| """ | |
| from __future__ import annotations | |
| import pytest | |
| from models import ( | |
| ActionType, | |
| DiscoveryClaim, | |
| ExperimentAction, | |
| IntermediateOutput, | |
| OutputType, | |
| ) | |
| from server.rewards.reward_function import ( | |
| RewardWeights, | |
| _mass_score, | |
| _significance_overclaim, | |
| _significance_score, | |
| compute_step_reward, | |
| compute_terminal_reward, | |
| ) | |
| from server.rules.engine import RuleResult, ViolationCode | |
| from server.simulator.latent_state import FullLatentState, LatentParticle, ResourceState | |
# ── helpers ─────────────────────────────────────────────────────────────
def _passing_rule_result() -> RuleResult:
    """Return a rule result carrying no violations (action allowed)."""
    return RuleResult(allowed=True)
def _failing_rule_result(*violations: ViolationCode) -> RuleResult:
    """Return a rule result with each of *violations* recorded on it.

    NOTE(review): this starts from ``allowed=True``; presumably
    ``RuleResult.add`` flips the allowed flag when a violation is
    recorded — confirm against server.rules.engine.
    """
    result = RuleResult(allowed=True)
    for code in violations:
        result.add(code, str(code))
    return result
def _ok_output(action_type: ActionType = ActionType.CONFIGURE_BEAM) -> IntermediateOutput:
    """Return a successful, high-quality BEAM_CONFIG output.

    ``action_type`` is accepted for call-site readability but is currently
    unused — the returned output is always a BEAM_CONFIG.
    """
    return IntermediateOutput(
        output_type=OutputType.BEAM_CONFIG,
        step_index=0,
        success=True,
        quality_score=0.9,
        summary="ok",
    )
def _fresh_state() -> FullLatentState:
    """Return a brand-new latent state with default particle and resources."""
    particle = LatentParticle()
    resources = ResourceState()
    return FullLatentState(particle=particle, resources=resources)
# ── _mass_score ─────────────────────────────────────────────────────────
def test_mass_score_perfect_inside_tolerance():
    """An exact mass match scores a full 1.0."""
    score = _mass_score(125.0, 125.0, unc=None)
    assert score == pytest.approx(1.0)
def test_mass_score_decays_outside_tolerance():
    """Score shrinks monotonically as the claim drifts from the true mass."""
    near_miss = _mass_score(125.0, 125.5, unc=None)
    far_miss = _mass_score(125.0, 130.0, unc=None)
    assert 0.0 < far_miss < near_miss <= 1.0
def test_mass_score_zero_when_far_off():
    """A wildly wrong mass claim earns exactly zero."""
    assert _mass_score(125.0, 200.0, unc=None) == 0.0
def test_mass_score_returns_zero_when_claim_missing():
    """A missing (None) mass claim scores zero rather than raising.

    Uses ``unc=None`` as a keyword for consistency with every sibling
    ``_mass_score`` call in this file.
    """
    assert _mass_score(125.0, None, unc=None) == 0.0
# ── _significance_score and overclaim ───────────────────────────────────
def test_significance_score_uses_measured_value():
    """Under-claiming is fine (we just return the base); over-claiming is
    actively penalised (anti-hacking)."""
    state = _fresh_state()
    state.progress.best_significance_sigma = 5.0
    matched = _significance_score(state, claim_sigma=5.0)
    inflated = _significance_score(state, claim_sigma=20.0)
    missing = _significance_score(state, claim_sigma=None)
    assert matched == pytest.approx(1.0)
    assert inflated < matched
    assert missing == 0.0
def test_significance_overclaim_only_above_threshold():
    """A small over-claim is tolerated; a gross one triggers the penalty."""
    state = _fresh_state()
    state.progress.best_significance_sigma = 4.0
    assert _significance_overclaim(state, claim_sigma=4.5) == 0.0
    assert _significance_overclaim(state, claim_sigma=10.0) > 0.0
# ── compute_step_reward: tool_fit gating (anti-hacking) ─────────────────
def test_bogus_method_string_is_penalised_not_rewarded():
    """A made-up method string must yield a penalty, never a tool_fit bonus."""
    state = _fresh_state()
    action = ExperimentAction(
        action_type=ActionType.FIT_RESONANCE,
        method="LITERAL_GIBBERISH_BOGUS_LMAO",
    )
    result = compute_step_reward(
        action=action,
        output=_ok_output(),
        state_before=state,
        state_after=state,
        rule_result=_passing_rule_result(),
    )
    components = result.breakdown.components
    assert "tool_fit" not in components
    assert components.get("bogus_method", 0.0) < 0.0
def test_real_method_with_correct_category_is_rewarded():
    """A registered method whose category matches the action earns tool_fit."""
    state = _fresh_state()
    action = ExperimentAction(
        action_type=ActionType.FIT_RESONANCE,
        method="ROOT_RooFit",  # ANALYSIS category, matches FIT_RESONANCE
    )
    result = compute_step_reward(
        action=action,
        output=_ok_output(),
        state_before=state,
        state_after=state,
        rule_result=_passing_rule_result(),
    )
    assert result.breakdown.components.get("tool_fit", 0.0) > 0.0
def test_real_method_with_mismatched_category_is_silent():
    """A real tool applied to the wrong task earns neither bonus nor penalty."""
    state = _fresh_state()
    action = ExperimentAction(
        action_type=ActionType.CALIBRATE_DETECTOR,
        method="BumpHunter",  # STATISTICS, mismatch with CALIBRATION
    )
    result = compute_step_reward(
        action=action,
        output=_ok_output(),
        state_before=state,
        state_after=state,
        rule_result=_passing_rule_result(),
    )
    components = result.breakdown.components
    assert "tool_fit" not in components
    # method IS in the registry → no bogus_method penalty either
    assert "bogus_method" not in components
# ── compute_step_reward: repeat-action penalty ──────────────────────────
def test_repeat_action_penalty_escalates():
    """Three identical action_types in a row should be penalised."""
    from models import PipelineStepRecord

    state = _fresh_state()
    repeated = ActionType.REQUEST_THEORY_REVIEW
    history = [
        PipelineStepRecord(
            step_index=step,
            action_type=repeated,
            output_type=OutputType.THEORY_REVIEW,
            output_summary="...",
        )
        for step in range(3)
    ]
    result = compute_step_reward(
        action=ExperimentAction(action_type=repeated),
        output=_ok_output(),
        state_before=state,
        state_after=state,
        rule_result=_passing_rule_result(),
        history=history,
    )
    assert result.breakdown.components.get("repeat_action", 0.0) < 0.0
def test_no_repeat_penalty_for_first_use():
    """With an empty history there is no repeat_action component at all."""
    state = _fresh_state()
    result = compute_step_reward(
        action=ExperimentAction(action_type=ActionType.CONFIGURE_BEAM),
        output=_ok_output(),
        state_before=state,
        state_after=state,
        rule_result=_passing_rule_result(),
        history=[],
    )
    assert "repeat_action" not in result.breakdown.components
# ── compute_step_reward: clip ───────────────────────────────────────────
def test_step_reward_is_clipped_above():
    """Per-step reward never exceeds the configured clip (anti-farming)."""
    state = _fresh_state()
    clip = 0.1
    action = ExperimentAction(action_type=ActionType.CONFIGURE_BEAM, method="ATLAS_HLT")
    result = compute_step_reward(
        action=action,
        output=_ok_output(),
        state_before=state,
        state_after=state,
        rule_result=_passing_rule_result(),
        weights=RewardWeights(step_reward_clip=clip),
    )
    assert result.reward <= clip + 1e-9
# ── compute_terminal_reward: correctness + overconfidence ───────────────
def test_terminal_reward_high_for_correct_calibrated_claim():
    """A fully correct, well-calibrated claim counts as a discovery and
    earns a terminal reward above 1.0 (terminal grade dominates)."""
    state = _fresh_state()
    state.particle = LatentParticle(
        mass_gev=125.0,
        primary_channel="diphoton",
        spin=0,
        parity="+",
        width_gev=0.004,
    )
    state.progress.best_significance_sigma = 5.5
    claim = DiscoveryClaim(
        mass_estimate_gev=125.0,
        mass_uncertainty_gev=0.5,
        significance_sigma=5.5,
        decay_channel="diphoton",
        spin_hypothesis=0,
        parity="+",
        confidence=0.9,
    )
    outcome = compute_terminal_reward(state=state, claim=claim)
    assert outcome.discovered
    assert outcome.correct_mass and outcome.correct_channel and outcome.correct_spin
    assert outcome.reward > 1.0
def test_terminal_reward_overconfident_wrong_is_punished():
    """High confidence on a wrong claim must drive the reward negative."""
    state = _fresh_state()
    state.particle = LatentParticle(mass_gev=125.0, primary_channel="diphoton")
    state.progress.best_significance_sigma = 4.5
    claim = DiscoveryClaim(
        mass_estimate_gev=600.0,  # way off
        decay_channel="dijet",  # wrong
        significance_sigma=5.0,
        confidence=0.95,
    )
    outcome = compute_terminal_reward(state=state, claim=claim)
    assert not outcome.discovered
    assert outcome.reward < 0.0
    assert outcome.breakdown.components.get("overconfident_wrong", 0.0) < 0.0
def test_significance_overclaim_penalty_fires():
    """Claiming 20σ on 1σ of actual evidence triggers the over-claim penalty."""
    state = _fresh_state()
    state.particle = LatentParticle(mass_gev=125.0, primary_channel="diphoton")
    state.progress.best_significance_sigma = 1.0  # weak evidence
    claim = DiscoveryClaim(
        mass_estimate_gev=125.0,
        decay_channel="diphoton",
        significance_sigma=20.0,  # absurd over-claim
        confidence=0.5,
    )
    outcome = compute_terminal_reward(state=state, claim=claim)
    assert outcome.breakdown.components.get("overclaim_significance", 0.0) < 0.0
def test_no_information_claim_clamped_nonpositive():
    """An empty claim can never earn positive terminal reward."""
    empty_claim = DiscoveryClaim()  # everything None / zero
    outcome = compute_terminal_reward(state=_fresh_state(), claim=empty_claim)
    assert outcome.reward <= 0.0
# ── Hard / soft / failure penalties ─────────────────────────────────────
def test_hard_violation_dominates_step_reward():
    """A hard rule violation pushes the whole step reward below zero."""
    state = _fresh_state()
    failure_output = IntermediateOutput(
        output_type=OutputType.FAILURE_REPORT,
        step_index=0,
        success=False,
        quality_score=0.0,
    )
    result = compute_step_reward(
        action=ExperimentAction(action_type=ActionType.COLLECT_COLLISIONS),
        output=failure_output,
        state_before=state,
        state_after=state,
        rule_result=_failing_rule_result(ViolationCode.PREREQ_MISSING),
    )
    assert result.reward < 0.0