# drugenv-trainer/tests/test_environment.py
# anugrahteesdollar's picture
# fix: include requirements-train.txt + tests (glob bug)
# ad12dda verified
"""Integration tests for the full DrugTargetEnvironment."""
from models import ActionType, DrugTargetAction
from server.hackathon_environment import DrugTargetEnvironment
class TestEnvironmentLifecycle:
    """Drive a DrugTargetEnvironment from reset through the end of an episode."""

    @staticmethod
    def _fresh_egfr_env(seed):
        """Build the ``egfr_nsclc_viable`` environment and reset it with *seed*.

        Domain randomisation is switched off, matching every scenario-based
        test in this class. The leading underscore keeps pytest from
        collecting this helper as a test.
        """
        environment = DrugTargetEnvironment(
            scenario_name="egfr_nsclc_viable",
            domain_randomise=False,
        )
        environment.reset(seed=seed)
        return environment

    def test_reset_returns_valid_observation(self):
        """reset() on a default environment yields an unfinished step-0 observation."""
        first_obs = DrugTargetEnvironment().reset()

        assert first_obs.step_index == 0
        assert first_obs.done is False
        assert first_obs.target_gene != ""
        assert first_obs.credits_remaining > 0

    def test_step_increments_step_count(self):
        """One step moves both the observation index and the state counter to 1."""
        environment = DrugTargetEnvironment()
        environment.reset()

        query = DrugTargetAction(action_type=ActionType.QUERY_EXPRESSION)
        result = environment.step(query)

        assert result.step_index == 1
        assert environment.state.step_count == 1

    def test_valid_investigation_trajectory(self):
        """Each action of a five-step investigation succeeds and spends credits."""
        environment = self._fresh_egfr_env(seed=0)

        # Only the opening expression query carries explicit parameters.
        trajectory = [
            DrugTargetAction(
                action_type=ActionType.QUERY_EXPRESSION,
                parameters={"database": "GTEx"},
            ),
        ]
        trajectory += [
            DrugTargetAction(action_type=kind)
            for kind in (
                ActionType.DRUGGABILITY_SCREEN,
                ActionType.OFF_TARGET_SCREEN,
                ActionType.TOXICITY_PANEL,
                ActionType.CLINICAL_TRIAL_LOOKUP,
            )
        ]

        for action in trajectory:
            result = environment.step(action)
            latest = result.latest_output
            assert latest is not None
            assert latest.success is True, (
                f"Step {action.action_type} failed: {result.rule_violations}"
            )

        # `result` now holds the observation returned by the final step.
        assert result.step_index == len(trajectory)
        assert result.credits_remaining < result.credits_total

    def test_submit_report_ends_episode(self):
        """Submitting a validation report after two evidence steps sets done."""
        environment = self._fresh_egfr_env(seed=1)

        for kind in (ActionType.QUERY_EXPRESSION, ActionType.DRUGGABILITY_SCREEN):
            environment.step(DrugTargetAction(action_type=kind))

        report = DrugTargetAction(
            action_type=ActionType.SUBMIT_VALIDATION_REPORT,
            final_decision="go",
            confidence=0.85,
            reasoning="Test conclusion",
        )
        result = environment.step(report)

        assert result.done is True

    def test_submit_without_evidence_is_blocked(self):
        """A report submitted as the very first action fails with an evidence violation."""
        environment = self._fresh_egfr_env(seed=2)

        premature_report = DrugTargetAction(
            action_type=ActionType.SUBMIT_VALIDATION_REPORT,
            final_decision="go",
            confidence=0.85,
        )
        result = environment.step(premature_report)

        assert result.latest_output is not None
        assert result.latest_output.success is False
        assert any(
            "evidence" in message.lower() for message in result.rule_violations
        )

    def test_credit_exhaustion_ends_episode(self):
        """With one credit left, a literature search ends the episode at zero credits."""
        environment = self._fresh_egfr_env(seed=3)

        # Rather than playing many steps, write to the latent credit ledger
        # directly so a single credit is left, then take one more action to
        # trigger the exhaustion check.
        ledger = environment._latent.credits
        ledger.credits_used = ledger.credits_total - 1

        result = environment.step(
            DrugTargetAction(action_type=ActionType.LITERATURE_SEARCH)
        )

        assert result.done is True
        assert result.credits_remaining == 0