| """Tests for the per-step + terminal reward function.
|
|
|
| This is the most safety-critical module in CERNenv. The test suite is
|
| structured around the same anti-hacking principles called out in the
|
| hackathon FAQ (Q12, Q13, Q42, Q56, Q57):
|
|
|
| * the terminal grade should *dominate* total reward,
|
| * shaping rewards must be hard to farm,
|
| * obvious model "cheats" (string-spam, claim-spam, JSON-spam) must
|
| not produce high reward.
|
| """
|
|
|
from __future__ import annotations

import pytest

from models import (
    ActionType,
    DiscoveryClaim,
    ExperimentAction,
    IntermediateOutput,
    OutputType,
    PipelineStepRecord,
)
from server.rewards.reward_function import (
    RewardWeights,
    _mass_score,
    _significance_overclaim,
    _significance_score,
    compute_step_reward,
    compute_terminal_reward,
)
from server.rules.engine import RuleResult, ViolationCode
from server.simulator.latent_state import FullLatentState, LatentParticle, ResourceState
|
|
|
|
|
|
|
|
|
|
|
def _passing_rule_result() -> RuleResult:
    """Build a rule-engine result representing an allowed action."""
    result = RuleResult(allowed=True)
    return result
|
|
|
|
|
def _failing_rule_result(*violations: ViolationCode) -> RuleResult:
    """Build a rule-engine result carrying each given violation code.

    NOTE(review): constructed with ``allowed=True``; presumably
    ``RuleResult.add`` flips the allowed flag when a violation is
    recorded — confirm against the rules engine.
    """
    result = RuleResult(allowed=True)
    for code in violations:
        result.add(code, str(code))
    return result
|
|
|
|
|
def _ok_output(action_type: ActionType = ActionType.CONFIGURE_BEAM) -> IntermediateOutput:
    """Return a successful, high-quality intermediate output.

    ``action_type`` is accepted for call-site symmetry but is currently
    unused: the helper always produces a BEAM_CONFIG record.
    """
    return IntermediateOutput(
        success=True,
        step_index=0,
        output_type=OutputType.BEAM_CONFIG,
        quality_score=0.9,
        summary="ok",
    )
|
|
|
|
|
def _fresh_state() -> FullLatentState:
    """Return a latent state with a default particle and default resources."""
    particle = LatentParticle()
    resources = ResourceState()
    return FullLatentState(particle=particle, resources=resources)
|
|
|
|
|
|
|
|
|
|
|
def test_mass_score_perfect_inside_tolerance():
    """An exact mass match earns the full score of 1.0."""
    score = _mass_score(125.0, 125.0, unc=None)
    assert score == pytest.approx(1.0)
|
|
|
|
|
def test_mass_score_decays_outside_tolerance():
    """The score falls as the claimed mass drifts further from the truth."""
    near_miss = _mass_score(125.0, 125.5, unc=None)
    far_miss = _mass_score(125.0, 130.0, unc=None)
    assert 0.0 < far_miss < near_miss <= 1.0
|
|
|
|
|
def test_mass_score_zero_when_far_off():
    """A wildly wrong mass claim earns exactly zero."""
    score = _mass_score(125.0, 200.0, unc=None)
    assert score == 0.0
|
|
|
|
|
def test_mass_score_returns_zero_when_claim_missing():
    """Omitting the mass claim entirely yields zero rather than an error."""
    assert _mass_score(125.0, None, unc=None) == 0.0
|
|
|
|
|
|
|
|
|
|
|
def test_significance_score_uses_measured_value():
    """Matching the measured sigma scores 1.0; over-claiming scores less
    (anti-hacking) and a missing claim scores zero."""
    state = _fresh_state()
    state.progress.best_significance_sigma = 5.0
    exact = _significance_score(state, claim_sigma=5.0)
    inflated = _significance_score(state, claim_sigma=20.0)
    missing = _significance_score(state, claim_sigma=None)
    assert exact == pytest.approx(1.0)
    assert inflated < exact
    assert missing == 0.0
|
|
|
|
|
def test_significance_overclaim_only_above_threshold():
    """A small over-claim is tolerated; a large one triggers the penalty."""
    state = _fresh_state()
    state.progress.best_significance_sigma = 4.0
    assert _significance_overclaim(state, claim_sigma=4.5) == 0.0
    assert _significance_overclaim(state, claim_sigma=10.0) > 0.0
|
|
|
|
|
|
|
|
|
|
|
def test_bogus_method_string_is_penalised_not_rewarded():
    """A gibberish method string must earn a penalty, never the fit bonus."""
    state = _fresh_state()
    result = compute_step_reward(
        action=ExperimentAction(
            action_type=ActionType.FIT_RESONANCE,
            method="LITERAL_GIBBERISH_BOGUS_LMAO",
        ),
        output=_ok_output(),
        state_before=state,
        state_after=state,
        rule_result=_passing_rule_result(),
    )
    components = result.breakdown.components
    assert "tool_fit" not in components
    assert components.get("bogus_method", 0.0) < 0.0
|
|
|
|
|
def test_real_method_with_correct_category_is_rewarded():
    """A recognised fitting tool used on a fit action earns the tool bonus."""
    state = _fresh_state()
    result = compute_step_reward(
        action=ExperimentAction(
            action_type=ActionType.FIT_RESONANCE,
            method="ROOT_RooFit",
        ),
        output=_ok_output(),
        state_before=state,
        state_after=state,
        rule_result=_passing_rule_result(),
    )
    assert result.breakdown.components.get("tool_fit", 0.0) > 0.0
|
|
|
|
|
def test_real_method_with_mismatched_category_is_silent():
    """A real tool on the wrong action category is neither rewarded nor
    flagged as bogus — it simply contributes nothing."""
    state = _fresh_state()
    result = compute_step_reward(
        action=ExperimentAction(
            action_type=ActionType.CALIBRATE_DETECTOR,
            method="BumpHunter",
        ),
        output=_ok_output(),
        state_before=state,
        state_after=state,
        rule_result=_passing_rule_result(),
    )
    components = result.breakdown.components
    assert "tool_fit" not in components
    assert "bogus_method" not in components
|
|
|
|
|
|
|
|
|
|
|
def test_repeat_action_penalty_escalates():
    """Three identical action_types in a row should be penalised.

    The ``PipelineStepRecord`` import is now at the top of the file with
    the other ``models`` imports, matching the file's own convention of
    keeping all imports at module level.
    """
    state = _fresh_state()
    action_type = ActionType.REQUEST_THEORY_REVIEW
    # Three prior steps of the exact same action type.
    history = [
        PipelineStepRecord(
            step_index=i,
            action_type=action_type,
            output_type=OutputType.THEORY_REVIEW,
            output_summary="...",
        )
        for i in range(3)
    ]
    result = compute_step_reward(
        action=ExperimentAction(action_type=action_type),
        output=_ok_output(),
        state_before=state,
        state_after=state,
        rule_result=_passing_rule_result(),
        history=history,
    )
    assert result.breakdown.components.get("repeat_action", 0.0) < 0.0
|
|
|
|
|
def test_no_repeat_penalty_for_first_use():
    """The very first use of an action type carries no repeat penalty."""
    state = _fresh_state()
    result = compute_step_reward(
        action=ExperimentAction(action_type=ActionType.CONFIGURE_BEAM),
        output=_ok_output(),
        state_before=state,
        state_after=state,
        rule_result=_passing_rule_result(),
        history=[],
    )
    assert "repeat_action" not in result.breakdown.components
|
|
|
|
|
|
|
|
|
|
|
def test_step_reward_is_clipped_above():
    """Per-step reward never exceeds the configured clip, even for a
    well-formed action with a recognised method."""
    state = _fresh_state()
    result = compute_step_reward(
        action=ExperimentAction(
            action_type=ActionType.CONFIGURE_BEAM, method="ATLAS_HLT"
        ),
        output=_ok_output(),
        state_before=state,
        state_after=state,
        rule_result=_passing_rule_result(),
        weights=RewardWeights(step_reward_clip=0.1),
    )
    # Small epsilon absorbs float rounding at the clip boundary.
    assert result.reward <= 0.1 + 1e-9
|
|
|
|
|
|
|
|
|
|
|
def test_terminal_reward_high_for_correct_calibrated_claim():
    """A fully correct, well-calibrated claim is marked as a discovery and
    out-earns any single step (the terminal grade dominates)."""
    state = _fresh_state()
    state.particle = LatentParticle(
        mass_gev=125.0,
        primary_channel="diphoton",
        spin=0,
        parity="+",
        width_gev=0.004,
    )
    state.progress.best_significance_sigma = 5.5
    claim = DiscoveryClaim(
        mass_estimate_gev=125.0,
        mass_uncertainty_gev=0.5,
        significance_sigma=5.5,
        decay_channel="diphoton",
        spin_hypothesis=0,
        parity="+",
        confidence=0.9,
    )
    result = compute_terminal_reward(state=state, claim=claim)
    assert result.discovered
    assert result.correct_mass and result.correct_channel and result.correct_spin
    assert result.reward > 1.0
|
|
|
|
|
def test_terminal_reward_overconfident_wrong_is_punished():
    """A confident but wrong claim is not a discovery and nets a negative
    reward with an explicit overconfidence component."""
    state = _fresh_state()
    state.particle = LatentParticle(mass_gev=125.0, primary_channel="diphoton")
    state.progress.best_significance_sigma = 4.5
    wrong_claim = DiscoveryClaim(
        mass_estimate_gev=600.0,
        decay_channel="dijet",
        significance_sigma=5.0,
        confidence=0.95,
    )
    result = compute_terminal_reward(state=state, claim=wrong_claim)
    assert not result.discovered
    assert result.reward < 0.0
    assert result.breakdown.components.get("overconfident_wrong", 0.0) < 0.0
|
|
|
|
|
def test_significance_overclaim_penalty_fires():
    """Claiming 20 sigma when only 1 sigma was measured triggers the
    significance over-claim penalty (anti-hacking)."""
    state = _fresh_state()
    state.particle = LatentParticle(mass_gev=125.0, primary_channel="diphoton")
    state.progress.best_significance_sigma = 1.0
    inflated_claim = DiscoveryClaim(
        mass_estimate_gev=125.0,
        decay_channel="diphoton",
        significance_sigma=20.0,
        confidence=0.5,
    )
    result = compute_terminal_reward(state=state, claim=inflated_claim)
    assert result.breakdown.components.get("overclaim_significance", 0.0) < 0.0
|
|
|
|
|
def test_no_information_claim_clamped_nonpositive():
    """An empty claim (no fields filled in) can never earn positive reward."""
    result = compute_terminal_reward(state=_fresh_state(), claim=DiscoveryClaim())
    assert result.reward <= 0.0
|
|
|
|
|
|
|
|
|
|
|
def test_hard_violation_dominates_step_reward():
    """A rule violation paired with a failed output must yield a negative
    step reward overall — violations dominate any shaping terms."""
    state = _fresh_state()
    failed_output = IntermediateOutput(
        output_type=OutputType.FAILURE_REPORT,
        step_index=0,
        success=False,
        quality_score=0.0,
    )
    result = compute_step_reward(
        action=ExperimentAction(action_type=ActionType.COLLECT_COLLISIONS),
        output=failed_output,
        state_before=state,
        state_after=state,
        rule_result=_failing_rule_result(ViolationCode.PREREQ_MISSING),
    )
    assert result.reward < 0.0
|
|
|