File size: 2,449 Bytes
ad12dda
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
"""Tests for the drug-target-validation POMDP schema models."""

import pytest

from models import (
    ActionType,
    CreditUsage,
    DrugTargetAction,
    EvidenceDossier,
    IntermediateOutput,
    OutputType,
    ValidationObservation,
    ValidationStepRecord,
    ValidationTaskSpec,
)


def test_drug_target_action_roundtrip():
    """A DrugTargetAction survives a model_dump() -> constructor round trip.

    Checks the serialized form (action_type compares equal to its string
    value, parameters are preserved) and that rebuilding the model from the
    dump yields the same data for every field, not just action_type.
    """
    a = DrugTargetAction(
        action_type=ActionType.QUERY_EXPRESSION,
        parameters={"database": "GTEx"},
        reasoning="Establish tissue baseline",
    )
    d = a.model_dump()
    assert d["action_type"] == "query_expression"
    assert d["parameters"]["database"] == "GTEx"
    reconstructed = DrugTargetAction(**d)
    assert reconstructed.action_type == ActionType.QUERY_EXPRESSION
    # Compare whole dumps so any field dropped or altered during
    # reconstruction (e.g. reasoning, parameters) fails the round trip.
    assert reconstructed.model_dump() == d


def test_drug_target_action_terminal_payload():
    """The submit-report action carries a final decision and a confidence value."""
    terminal_action = DrugTargetAction(
        action_type=ActionType.SUBMIT_VALIDATION_REPORT,
        reasoning="All evidence gathered",
        final_decision="go",
        confidence=0.82,
    )
    # Float field: compare with a tolerance rather than exact equality.
    expected_confidence = pytest.approx(0.82)
    assert terminal_action.final_decision == "go"
    assert terminal_action.confidence == expected_confidence


def test_validation_observation_defaults():
    """A fresh ValidationObservation starts at step 0 with an empty history.

    It also has a full 50-credit budget and an EvidenceDossier instance.
    """
    observation = ValidationObservation(done=False, reward=0.0)
    assert observation.step_index == 0
    assert observation.pipeline_history == []
    # Both budget counters default to the same 50-credit value.
    for budget_field in ("credits_remaining", "credits_total"):
        assert getattr(observation, budget_field) == 50
    assert isinstance(observation.dossier, EvidenceDossier)


def test_intermediate_output_quality_bounds():
    """An IntermediateOutput with quality_score above 1.0 is rejected.

    The models appear to be pydantic (they expose model_dump), whose
    ValidationError subclasses ValueError. Expecting ValueError instead of
    bare Exception keeps this test from passing on unrelated errors such as
    a TypeError, AttributeError or NameError.
    """
    with pytest.raises(ValueError):
        IntermediateOutput(
            output_type=OutputType.EXPRESSION_RESULT,
            step_index=1,
            quality_score=1.5,
        )


def test_validation_task_spec_defaults():
    """A default ValidationTaskSpec has a 50-credit limit and no expected findings.

    It must also offer the terminal submit action.
    """
    spec = ValidationTaskSpec()
    assert spec.credits_limit == 50
    assert spec.expected_findings == []
    offered_actions = spec.available_actions
    assert "submit_validation_report" in offered_actions


def test_validation_step_record_roundtrip():
    """A ValidationStepRecord survives a model_dump() -> constructor round trip.

    Checks the serialized form (action_type compares equal to its string
    value, credit_cost preserved), then rebuilds the record from the dump and
    verifies it carries the same data.
    """
    rec = ValidationStepRecord(
        step_index=2,
        action_type=ActionType.DRUGGABILITY_SCREEN,
        parameters={"include_allosteric": True},
        output_summary="Druggability score=0.81",
        output_type=OutputType.DRUGGABILITY_RESULT,
        success=True,
        quality_score=0.9,
        credit_cost=3,
    )
    d = rec.model_dump()
    assert d["action_type"] == "druggability_screen"
    assert d["credit_cost"] == 3
    # The original test stopped after dumping; reconstruct so the "roundtrip"
    # actually exercises deserialization, mirroring the DrugTargetAction test.
    reconstructed = ValidationStepRecord(**d)
    assert reconstructed.action_type == ActionType.DRUGGABILITY_SCREEN
    assert reconstructed.model_dump() == d


def test_credit_usage_defaults():
    """An untouched CreditUsage has spent nothing out of a 50-credit budget."""
    usage = CreditUsage()
    expected_defaults = {
        "credits_total": 50,
        "credits_remaining": 50,
        "credits_used": 0,
    }
    for field_name, expected_value in expected_defaults.items():
        assert getattr(usage, field_name) == expected_value