| """Tests for the decomposable reward function.""" | |
| from models import ActionType, ConclusionClaim, ExperimentAction, IntermediateOutput, OutputType | |
| from server.rewards.reward import RewardComputer | |
| from server.simulator.latent_state import ( | |
| ExperimentProgress, | |
| FullLatentState, | |
| LatentBiologicalState, | |
| ResourceState, | |
| ) | |
| def _states( | |
| prev_flags: dict | None = None, | |
| next_flags: dict | None = None, | |
| budget_used: float = 0.0, | |
| ): | |
| prev = FullLatentState( | |
| progress=ExperimentProgress(**(prev_flags or {})), | |
| resources=ResourceState(budget_total=100_000, budget_used=budget_used), | |
| ) | |
| nf = dict(prev_flags or {}) | |
| nf.update(next_flags or {}) | |
| nxt = FullLatentState( | |
| progress=ExperimentProgress(**nf), | |
| resources=ResourceState(budget_total=100_000, budget_used=budget_used + 5000), | |
| ) | |
| return prev, nxt | |
| class TestStepReward: | |
| def test_valid_step_positive(self): | |
| rc = RewardComputer() | |
| prev, nxt = _states( | |
| prev_flags={"samples_collected": True, "library_prepared": True}, | |
| next_flags={"cells_sequenced": True}, | |
| ) | |
| output = IntermediateOutput( | |
| output_type=OutputType.SEQUENCING_RESULT, | |
| step_index=1, | |
| quality_score=0.85, | |
| uncertainty=0.15, | |
| ) | |
| rb = rc.step_reward( | |
| ExperimentAction(action_type=ActionType.SEQUENCE_CELLS), | |
| prev, nxt, output, [], [], | |
| ) | |
| assert rb.total > 0 | |
| def test_hard_violation_negative(self): | |
| rc = RewardComputer() | |
| prev, nxt = _states() | |
| output = IntermediateOutput( | |
| output_type=OutputType.FAILURE_REPORT, | |
| step_index=1, | |
| success=False, | |
| ) | |
| rb = rc.step_reward( | |
| ExperimentAction(action_type=ActionType.SEQUENCE_CELLS), | |
| prev, nxt, output, ["blocked"], [], | |
| ) | |
| assert rb.total < 0 | |
| def test_premature_meta_action_gets_penalized(self): | |
| rc = RewardComputer() | |
| prev, nxt = _states( | |
| prev_flags={"data_normalized": True}, | |
| next_flags={"followup_designed": True}, | |
| budget_used=2_000, | |
| ) | |
| output = IntermediateOutput( | |
| output_type=OutputType.FOLLOWUP_DESIGN, | |
| step_index=2, | |
| quality_score=1.0, | |
| uncertainty=0.0, | |
| ) | |
| rb = rc.step_reward( | |
| ExperimentAction(action_type=ActionType.DESIGN_FOLLOWUP), | |
| prev, | |
| nxt, | |
| output, | |
| [], | |
| [], | |
| ) | |
| assert rb.components.get("premature_meta_action_penalty", 0.0) < 0.0 | |
| class TestTerminalReward: | |
| def test_correct_conclusion_rewarded(self): | |
| rc = RewardComputer() | |
| state = FullLatentState( | |
| biology=LatentBiologicalState( | |
| causal_mechanisms=["TGF-beta-driven fibrosis"], | |
| true_markers=["NPPA"], | |
| ), | |
| progress=ExperimentProgress( | |
| samples_collected=True, cells_sequenced=True, | |
| qc_performed=True, data_filtered=True, | |
| data_normalized=True, de_performed=True, | |
| conclusion_reached=True, | |
| ), | |
| resources=ResourceState(budget_total=100_000, budget_used=40_000), | |
| ) | |
| claims = [ | |
| ConclusionClaim( | |
| claim="TGF-beta-driven fibrosis observed", | |
| confidence=0.9, | |
| claim_type="causal", | |
| ), | |
| ] | |
| rb = rc.terminal_reward( | |
| state, | |
| claims, | |
| [], | |
| discovered_markers=["NPPA"], | |
| candidate_mechanisms=["TGF-beta-driven fibrosis"], | |
| ) | |
| assert rb.terminal > 0 | |
| def test_overconfident_wrong_claim_penalised(self): | |
| rc = RewardComputer() | |
| state = FullLatentState( | |
| biology=LatentBiologicalState(causal_mechanisms=["real_mechanism"]), | |
| progress=ExperimentProgress(conclusion_reached=True), | |
| ) | |
| claims = [ | |
| ConclusionClaim( | |
| claim="completely_wrong_mechanism", | |
| confidence=0.95, | |
| claim_type="causal", | |
| ), | |
| ] | |
| rb = rc.terminal_reward(state, claims, []) | |
| assert rb.components.get("overconfidence_penalty", 0) < 0 | |
| def test_discovery_error_penalizes_wrong_markers_and_mechanisms(self): | |
| rc = RewardComputer() | |
| state = FullLatentState( | |
| biology=LatentBiologicalState( | |
| true_markers=["NPPA", "NPPB"], | |
| causal_mechanisms=["TGF-beta-driven fibrosis"], | |
| ), | |
| progress=ExperimentProgress( | |
| samples_collected=True, | |
| cells_sequenced=True, | |
| qc_performed=True, | |
| data_filtered=True, | |
| data_normalized=True, | |
| de_performed=True, | |
| markers_discovered=True, | |
| conclusion_reached=True, | |
| ), | |
| resources=ResourceState(budget_total=100_000, budget_used=40_000), | |
| ) | |
| aligned = rc.terminal_reward( | |
| state, | |
| [], | |
| [], | |
| discovered_markers=["NPPA", "NPPB"], | |
| candidate_mechanisms=["TGF-beta-driven fibrosis"], | |
| ) | |
| misaligned = rc.terminal_reward( | |
| state, | |
| [], | |
| [], | |
| discovered_markers=["WRONG1", "WRONG2"], | |
| candidate_mechanisms=["unrelated inflammatory process"], | |
| ) | |
| assert aligned.components["discovery_alignment"] > misaligned.components["discovery_alignment"] | |
| assert aligned.components["discovery_error_penalty"] > misaligned.components["discovery_error_penalty"] | |
| assert aligned.terminal > misaligned.terminal | |
| def test_conclusion_error_penalizes_wrong_structured_claims(self): | |
| rc = RewardComputer() | |
| state = FullLatentState( | |
| biology=LatentBiologicalState( | |
| true_markers=["NPPA", "NPPB"], | |
| causal_mechanisms=["TGF-beta-driven fibrosis"], | |
| ), | |
| progress=ExperimentProgress( | |
| data_normalized=True, | |
| de_performed=True, | |
| markers_discovered=True, | |
| pathways_analyzed=True, | |
| conclusion_reached=True, | |
| ), | |
| resources=ResourceState(budget_total=100_000, budget_used=40_000), | |
| ) | |
| aligned = rc.terminal_reward( | |
| state, | |
| [ | |
| ConclusionClaim( | |
| top_markers=["NPPA", "NPPB"], | |
| causal_mechanisms=["TGF-beta-driven fibrosis"], | |
| confidence=0.8, | |
| ), | |
| ], | |
| [], | |
| ) | |
| misaligned = rc.terminal_reward( | |
| state, | |
| [ | |
| ConclusionClaim( | |
| top_markers=["WRONG1"], | |
| causal_mechanisms=["unrelated process"], | |
| confidence=0.8, | |
| ), | |
| ], | |
| [], | |
| ) | |
| assert aligned.components["conclusion_alignment"] > misaligned.components["conclusion_alignment"] | |
| assert aligned.components["conclusion_error_penalty"] > misaligned.components["conclusion_error_penalty"] | |
| assert aligned.terminal > misaligned.terminal | |