File size: 5,244 Bytes
ad12dda
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
"""Tests for the latent-state simulator modules."""

from models import ActionType, DrugTargetAction, OutputType
from server.simulator.latent_state import (
    CreditState,
    DataQualityState,
    FullLatentState,
    TargetProfile,
    ValidationProgress,
)
from server.simulator.noise import NoiseModel
from server.simulator.output_generator import OutputGenerator
from server.simulator.transition import TransitionEngine


def _make_state() -> FullLatentState:
    """Build a fresh ``FullLatentState`` fixture for a "go"-labelled target.

    The target profile has high, tissue-specific expression, a good binding
    pocket with known ligands and positive clinical precedent. It is paired
    with a ``DataQualityState`` at noise level 0.1, a default
    ``ValidationProgress`` and a 50-credit budget. Every call constructs new
    objects, so callers may mutate the returned state without affecting
    other tests.
    """
    target_fields = dict(
        expression_level="high_specific",
        tissue_specificity=0.8,
        disease_overexpression=3.0,
        druggability_score=0.7,
        binding_pocket_quality="good",
        has_known_ligands=True,
        allosteric_site_available=True,
        selectivity_ratio=8.0,
        off_target_count=2,
        off_target_genes=["OFF1", "OFF2"],
        toxicity_profile="mild",
        toxicity_tissues=[],
        clinical_precedent="positive",
        clinical_stage_reached="phase2",
        in_vitro_ic50_nM=10.0,
        in_vivo_efficacy="strong",
        crispr_essentiality=-1.0,
        true_viability_score=0.8,
        correct_decision="go",
        key_evidence_dimensions=["expression", "druggability"],
    )
    profile = TargetProfile(**target_fields)
    return FullLatentState(
        target=profile,
        data_quality=DataQualityState(noise_level=0.1),
        progress=ValidationProgress(),
        credits=CreditState(credits_total=50),
    )


class TestNoiseModel:
    """Behavioural checks for ``NoiseModel``."""

    def test_deterministic_with_seed(self):
        # Two independently built models sharing a seed must agree on their
        # first QC-metric sample.
        first, second = NoiseModel(seed=42), NoiseModel(seed=42)
        sample_a = first.sample_qc_metric(0.5, 0.1)
        sample_b = second.sample_qc_metric(0.5, 0.1)
        assert sample_a == sample_b

    def test_false_positives(self):
        model = NoiseModel(seed=0)
        genes = model.generate_false_positives(1000, 0.01)
        # Every synthesised false-positive gene must carry the FP_GENE_ prefix.
        # NOTE(review): this passes vacuously if no genes are generated.
        for gene in genes:
            assert gene.startswith("FP_GENE_")

    def test_quality_degradation_bounded(self):
        model = NoiseModel(seed=0)
        # 100 repeated draws must all stay inside [0.0, 1.0]; all() stops at
        # the first out-of-range draw, just like an assert inside a loop.
        assert all(
            0.0 <= model.quality_degradation(0.9, [0.8, 0.7]) <= 1.0
            for _ in range(100)
        )


class TestOutputGenerator:
    """Checks that ``OutputGenerator`` maps actions to the expected output types."""

    @staticmethod
    def _generate(seed, action, step):
        """Run a freshly seeded generator on *action* against ``_make_state()``.

        *step* is forwarded positionally as ``generate``'s third argument
        (presumably a step index; confirm against ``OutputGenerator``).
        """
        generator = OutputGenerator(NoiseModel(seed=seed))
        return generator.generate(action, _make_state(), step)

    def test_query_expression_returns_expression_result(self):
        action = DrugTargetAction(
            action_type=ActionType.QUERY_EXPRESSION,
            parameters={"database": "GTEx"},
        )
        result = self._generate(1, action, 1)
        assert result.output_type == OutputType.EXPRESSION_RESULT
        # The reported level should echo the fixture's TargetProfile value.
        assert result.data["expression_level"] == "high_specific"

    def test_druggability_screen_returns_druggability_result(self):
        action = DrugTargetAction(action_type=ActionType.DRUGGABILITY_SCREEN)
        result = self._generate(42, action, 2)
        assert result.output_type == OutputType.DRUGGABILITY_RESULT
        assert "druggability_score" in result.data

    def test_binding_site_allosteric_flag(self):
        action = DrugTargetAction(
            action_type=ActionType.BINDING_SITE_ANALYSIS,
            parameters={"include_allosteric": True},
        )
        result = self._generate(7, action, 3)
        assert result.output_type == OutputType.BINDING_SITE_RESULT
        # The fixture sets allosteric_site_available=True, so the analysis
        # is expected to report the allosteric site.
        assert result.data["allosteric_site_detected"] is True


class TestTransitionEngine:
    """State-transition checks for ``TransitionEngine``."""

    @staticmethod
    def _engine():
        """Return a transition engine driven by a noise model seeded with 0."""
        return TransitionEngine(NoiseModel(seed=0))

    def test_progress_flag_set_after_action(self):
        engine = self._engine()
        state = _make_state()
        query = DrugTargetAction(action_type=ActionType.QUERY_EXPRESSION)
        outcome = engine.step(state, query)
        assert outcome.next_state.progress.expression_queried is True

    def test_credit_deduction(self):
        engine = self._engine()
        state = _make_state()
        in_vivo = DrugTargetAction(action_type=ActionType.IN_VIVO_MODEL)
        outcome = engine.step(state, in_vivo)
        # NOTE(review): 8 is presumably the IN_VIVO_MODEL cost starting from
        # zero credits used; confirm against the engine's cost definition.
        assert outcome.next_state.credits.credits_used == 8

    def test_hard_violation_blocks(self):
        engine = self._engine()
        state = _make_state()
        query = DrugTargetAction(action_type=ActionType.QUERY_EXPRESSION)
        outcome = engine.step(state, query, hard_violations=["test_block"])
        # A hard violation must turn the step into a failure report.
        assert outcome.output.success is False
        assert outcome.output.output_type == OutputType.FAILURE_REPORT

    def test_submit_validation_report_ends_episode(self):
        engine = self._engine()
        state = _make_state()
        report = DrugTargetAction(
            action_type=ActionType.SUBMIT_VALIDATION_REPORT,
            final_decision="go",
            confidence=0.8,
        )
        outcome = engine.step(state, report)
        assert outcome.done is True

    def test_credit_exhaustion_ends_episode(self):
        engine = self._engine()
        state = _make_state()
        # Leave a single credit; the query below is expected to exhaust it.
        state.credits.credits_used = state.credits.credits_total - 1
        query = DrugTargetAction(action_type=ActionType.QUERY_EXPRESSION)
        outcome = engine.step(state, query)
        assert outcome.done is True