File size: 3,492 Bytes
ad12dda
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
"""Integration tests for the full DrugTargetEnvironment."""

from models import ActionType, DrugTargetAction
from server.hackathon_environment import DrugTargetEnvironment


class TestEnvironmentLifecycle:
    """Exercise the reset -> step -> termination cycle of DrugTargetEnvironment."""

    @staticmethod
    def _fixed_scenario_env():
        """Return an environment pinned to one scenario, randomisation off."""
        return DrugTargetEnvironment(
            scenario_name="egfr_nsclc_viable",
            domain_randomise=False,
        )

    def test_reset_returns_valid_observation(self):
        """reset() yields a step-zero, non-terminal observation with credits."""
        environment = DrugTargetEnvironment()
        initial = environment.reset()
        assert initial.step_index == 0
        assert initial.done is False
        assert initial.target_gene != ""
        assert initial.credits_remaining > 0

    def test_step_increments_step_count(self):
        """One step advances both the observation index and the state counter."""
        environment = DrugTargetEnvironment()
        environment.reset()
        expression_query = DrugTargetAction(
            action_type=ActionType.QUERY_EXPRESSION,
        )
        after_step = environment.step(expression_query)
        assert after_step.step_index == 1
        assert environment.state.step_count == 1

    def test_valid_investigation_trajectory(self):
        """Each action of a five-step investigation succeeds and spends credits."""
        env = self._fixed_scenario_env()
        env.reset(seed=0)

        # The first action carries explicit parameters; the rest use defaults.
        actions = [
            DrugTargetAction(
                action_type=ActionType.QUERY_EXPRESSION,
                parameters={"database": "GTEx"},
            ),
        ]
        actions.extend(
            DrugTargetAction(action_type=kind)
            for kind in (
                ActionType.DRUGGABILITY_SCREEN,
                ActionType.OFF_TARGET_SCREEN,
                ActionType.TOXICITY_PANEL,
                ActionType.CLINICAL_TRIAL_LOOKUP,
            )
        )

        for a in actions:
            obs = env.step(a)
            output = obs.latest_output
            assert output is not None
            assert output.success is True, (
                f"Step {a.action_type} failed: {obs.rule_violations}"
            )

        # After the loop `obs` is the observation from the final action.
        assert obs.step_index == len(actions)
        assert obs.credits_remaining < obs.credits_total

    def test_submit_report_ends_episode(self):
        """Submitting a report after two evidence-gathering steps sets done."""
        env = self._fixed_scenario_env()
        env.reset(seed=1)
        for kind in (
            ActionType.QUERY_EXPRESSION,
            ActionType.DRUGGABILITY_SCREEN,
        ):
            env.step(DrugTargetAction(action_type=kind))

        report = DrugTargetAction(
            action_type=ActionType.SUBMIT_VALIDATION_REPORT,
            final_decision="go",
            confidence=0.85,
            reasoning="Test conclusion",
        )
        final = env.step(report)
        assert final.done is True

    def test_submit_without_evidence_is_blocked(self):
        """A report submitted as the very first action fails with a violation."""
        env = self._fixed_scenario_env()
        env.reset(seed=2)
        premature_report = DrugTargetAction(
            action_type=ActionType.SUBMIT_VALIDATION_REPORT,
            final_decision="go",
            confidence=0.85,
        )
        rejected = env.step(premature_report)

        output = rejected.latest_output
        assert output is not None
        assert output.success is False
        # At least one violation message must mention "evidence" (any case).
        assert any(
            "evidence" in message.lower()
            for message in rejected.rule_violations
        )

    def test_credit_exhaustion_ends_episode(self):
        """With one credit left, a single further action ends the episode."""
        env = self._fixed_scenario_env()
        env.reset(seed=3)

        # Reach into the private latent state to leave exactly one credit
        # unspent, so the next action triggers the exhaustion check.
        ledger = env._latent.credits
        ledger.credits_used = ledger.credits_total - 1

        search = DrugTargetAction(action_type=ActionType.LITERATURE_SEARCH)
        last = env.step(search)
        assert last.done is True
        assert last.credits_remaining == 0