# Bio-EnvRL / tests/test_rewards.py
# Uploaded by Ev3Dev using huggingface_hub (commit df98fca, verified).
"""Tests for the decomposable reward function."""
from models import ActionType, ConclusionClaim, ExperimentAction, IntermediateOutput, OutputType
from server.rewards.reward import RewardComputer
from server.simulator.latent_state import (
ExperimentProgress,
FullLatentState,
LatentBiologicalState,
ResourceState,
)
def _states(
    prev_flags: dict | None = None,
    next_flags: dict | None = None,
    budget_used: float = 0.0,
    *,
    budget_total: float = 100_000,
    step_cost: float = 5000,
) -> "tuple[FullLatentState, FullLatentState]":
    """Build a ``(prev, nxt)`` pair of latent states describing one step.

    Args:
        prev_flags: Keyword flags passed to ``ExperimentProgress`` for the
            previous state. ``None`` means no flags.
        next_flags: Flags layered on top of ``prev_flags`` for the next
            state. On a key clash, the value in ``next_flags`` wins.
        budget_used: Budget already spent in the previous state.
        budget_total: Total budget given to both states. The default
            matches the value this helper always used.
        step_cost: Extra spend charged in the next state. The default
            matches the value this helper always used.

    Returns:
        A ``(prev, nxt)`` tuple of ``FullLatentState`` objects.
    """
    prev_progress = dict(prev_flags or {})
    # Copy before merging so the caller's dicts are never mutated.
    next_progress = {**prev_progress, **(next_flags or {})}

    prev = FullLatentState(
        progress=ExperimentProgress(**prev_progress),
        resources=ResourceState(budget_total=budget_total, budget_used=budget_used),
    )
    nxt = FullLatentState(
        progress=ExperimentProgress(**next_progress),
        resources=ResourceState(
            budget_total=budget_total,
            budget_used=budget_used + step_cost,
        ),
    )
    return prev, nxt
class TestStepReward:
    """Sign checks on ``RewardComputer.step_reward`` for a single action."""

    def test_valid_step_positive(self):
        """Sequencing after samples are collected and the library is prepared scores above zero."""
        computer = RewardComputer()
        before, after = _states(
            prev_flags={"samples_collected": True, "library_prepared": True},
            next_flags={"cells_sequenced": True},
        )
        seq_result = IntermediateOutput(
            output_type=OutputType.SEQUENCING_RESULT,
            step_index=1,
            quality_score=0.85,
            uncertainty=0.15,
        )
        sequence = ExperimentAction(action_type=ActionType.SEQUENCE_CELLS)

        # The last two positional args are both empty here. The sibling test
        # passes ["blocked"] in the first of them to signal a hard violation.
        breakdown = computer.step_reward(sequence, before, after, seq_result, [], [])

        assert breakdown.total > 0

    def test_hard_violation_negative(self):
        """A failed SEQUENCE_CELLS step flagged as "blocked" scores below zero."""
        computer = RewardComputer()
        before, after = _states()
        failure = IntermediateOutput(
            output_type=OutputType.FAILURE_REPORT,
            step_index=1,
            success=False,
        )
        sequence = ExperimentAction(action_type=ActionType.SEQUENCE_CELLS)

        breakdown = computer.step_reward(
            sequence, before, after, failure, ["blocked"], []
        )

        assert breakdown.total < 0
class TestTerminalReward:
    """Checks on ``RewardComputer.terminal_reward`` for final conclusion claims."""

    def test_correct_conclusion_rewarded(self):
        """A confident causal claim matching the true mechanism earns a positive terminal reward."""
        computer = RewardComputer()

        ground_truth = LatentBiologicalState(
            causal_mechanisms=["TGF-beta-driven fibrosis"],
            true_markers=["NPPA"],
        )
        # Every pipeline stage through the conclusion is marked complete.
        completed = (
            "samples_collected",
            "cells_sequenced",
            "qc_performed",
            "data_filtered",
            "data_normalized",
            "de_performed",
            "conclusion_reached",
        )
        progress = ExperimentProgress(**dict.fromkeys(completed, True))
        budget = ResourceState(budget_total=100_000, budget_used=40_000)
        final_state = FullLatentState(
            biology=ground_truth,
            progress=progress,
            resources=budget,
        )

        correct_claim = ConclusionClaim(
            claim="TGF-beta-driven fibrosis observed",
            confidence=0.9,
            claim_type="causal",
        )
        breakdown = computer.terminal_reward(final_state, [correct_claim], [])

        assert breakdown.terminal > 0

    def test_overconfident_wrong_claim_penalised(self):
        """A wrong causal claim made at 0.95 confidence gets a negative overconfidence component."""
        computer = RewardComputer()
        final_state = FullLatentState(
            biology=LatentBiologicalState(causal_mechanisms=["real_mechanism"]),
            progress=ExperimentProgress(conclusion_reached=True),
        )
        wrong_claim = ConclusionClaim(
            claim="completely_wrong_mechanism",
            confidence=0.95,
            claim_type="causal",
        )

        breakdown = computer.terminal_reward(final_state, [wrong_claim], [])
        # A missing component counts as 0, which makes the assertion fail.
        penalty = breakdown.components.get("overconfidence_penalty", 0)

        assert penalty < 0