File size: 7,309 Bytes
77e1e28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
"""Transition dynamics engine for the drug-target-validation simulator.

Orchestrates latent-state updates, output generation, credit accounting,
and constraint propagation for every agent action.
"""

from __future__ import annotations

from copy import deepcopy
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple

from models import (
    ActionType,
    DrugTargetAction,
    IntermediateOutput,
    OutputType,
)

from .latent_state import FullLatentState
from .noise import NoiseModel
from .output_generator import OutputGenerator


# Credit costs per ActionType.
#
# Each value is the integer number of credits charged for one call of that
# action type. ``compute_action_cost`` reads this table and treats any
# action type that has no entry here as costing 0 credits.
# ``FLAG_RED_FLAG`` and ``SUBMIT_VALIDATION_REPORT`` are listed explicitly
# with a cost of 0.
_BASE_ACTION_COSTS: Dict[ActionType, int] = {
    ActionType.QUERY_EXPRESSION: 2,
    ActionType.DIFFERENTIAL_EXPRESSION: 2,
    ActionType.PATHWAY_ENRICHMENT: 2,
    ActionType.COEXPRESSION_NETWORK: 2,
    ActionType.PROTEIN_STRUCTURE_LOOKUP: 3,
    ActionType.BINDING_SITE_ANALYSIS: 3,
    ActionType.PROTEIN_INTERACTION_NETWORK: 2,
    ActionType.DRUGGABILITY_SCREEN: 3,
    ActionType.CLINICAL_TRIAL_LOOKUP: 3,
    ActionType.TOXICITY_PANEL: 3,
    ActionType.OFF_TARGET_SCREEN: 3,
    ActionType.PATIENT_STRATIFICATION: 3,
    ActionType.LITERATURE_SEARCH: 1,
    ActionType.EVIDENCE_SYNTHESIS: 1,
    ActionType.COMPETITOR_LANDSCAPE: 1,
    ActionType.IN_VITRO_ASSAY: 5,
    ActionType.IN_VIVO_MODEL: 8,
    ActionType.CRISPR_KNOCKOUT: 4,
    ActionType.BIOMARKER_CORRELATION: 3,
    ActionType.FLAG_RED_FLAG: 0,
    ActionType.REQUEST_EXPERT_REVIEW: 1,
    ActionType.SUBMIT_VALIDATION_REPORT: 0,
}

# Public alias kept for callers that historically imported ACTION_COSTS.
# NOTE: this is the same dict object as ``_BASE_ACTION_COSTS`` (no copy is
# made), so a mutation through either name is visible through the other.
ACTION_COSTS = _BASE_ACTION_COSTS


def compute_action_cost(action: DrugTargetAction) -> int:
    """Look up how many credits *action* consumes.

    The price is keyed on ``action.action_type``. An action type with no
    entry in the cost table is treated as free and yields 0.
    """
    kind = action.action_type
    try:
        return _BASE_ACTION_COSTS[kind]
    except KeyError:
        # Unlisted action types are not charged.
        return 0


# Map action type → progress flag that should be set when it succeeds.
#
# Each value is the name of an attribute on ``state.progress``;
# ``TransitionEngine.step`` sets that attribute to ``True`` when the
# action's generated output reports success.
# Several action types deliberately share one flag (for example both
# expression actions set ``expression_queried``, and both
# ``LITERATURE_SEARCH`` and ``COMPETITOR_LANDSCAPE`` set
# ``literature_reviewed``).
# ``ActionType.FLAG_RED_FLAG`` has no entry, so it never sets a flag.
_PROGRESS_MAP: Dict[ActionType, str] = {
    ActionType.QUERY_EXPRESSION: "expression_queried",
    ActionType.DIFFERENTIAL_EXPRESSION: "expression_queried",
    ActionType.PATHWAY_ENRICHMENT: "pathway_analysed",
    ActionType.COEXPRESSION_NETWORK: "interactions_mapped",
    ActionType.PROTEIN_STRUCTURE_LOOKUP: "structure_resolved",
    ActionType.BINDING_SITE_ANALYSIS: "druggability_assessed",
    ActionType.PROTEIN_INTERACTION_NETWORK: "interactions_mapped",
    ActionType.DRUGGABILITY_SCREEN: "druggability_assessed",
    ActionType.CLINICAL_TRIAL_LOOKUP: "clinical_checked",
    ActionType.TOXICITY_PANEL: "toxicity_assessed",
    ActionType.OFF_TARGET_SCREEN: "selectivity_checked",
    ActionType.PATIENT_STRATIFICATION: "patient_stratification_done",
    ActionType.LITERATURE_SEARCH: "literature_reviewed",
    ActionType.EVIDENCE_SYNTHESIS: "evidence_synthesised",
    ActionType.COMPETITOR_LANDSCAPE: "literature_reviewed",
    ActionType.IN_VITRO_ASSAY: "in_vitro_done",
    ActionType.IN_VIVO_MODEL: "in_vivo_done",
    ActionType.CRISPR_KNOCKOUT: "crispr_done",
    ActionType.BIOMARKER_CORRELATION: "biomarker_correlated",
    ActionType.REQUEST_EXPERT_REVIEW: "expert_reviewed",
    ActionType.SUBMIT_VALIDATION_REPORT: "report_submitted",
}


@dataclass
class TransitionResult:
    """Bundle returned by the transition engine after one step.

    ``TransitionEngine.step`` builds one instance per call. Only
    ``next_state`` and ``output`` are required; every other field has a
    default, and the mutable defaults are created per instance through
    ``default_factory``.
    """

    # Latent state after the step. ``TransitionEngine.step`` passes the deep
    # copy it worked on, never the caller's original state object.
    next_state: FullLatentState
    # Simulated output for the action, or a failure report when the action
    # was blocked by hard violations.
    output: IntermediateOutput
    # Named reward terms. ``TransitionEngine.step`` leaves this empty;
    # presumably a reward computer elsewhere fills it in — TODO confirm.
    reward_components: Dict[str, float] = field(default_factory=dict)
    # Violations that blocked the action; empty when the action ran.
    hard_violations: List[str] = field(default_factory=list)
    # Non-blocking violations reported for the action.
    soft_violations: List[str] = field(default_factory=list)
    # True when the episode should end after this step.
    done: bool = False


class TransitionEngine:
    """Single-step core of the simulator.

    Given a latent state and an agent action, ``step`` returns a
    ``TransitionResult`` holding an updated copy of the state and a
    simulated intermediate output. The output is produced by an
    ``OutputGenerator`` constructed from the supplied noise model.
    """

    def __init__(self, noise: NoiseModel):
        """Store *noise* and build the ``OutputGenerator`` that uses it."""
        self.noise = noise
        self.output_gen = OutputGenerator(noise)

    def step(
        self,
        state: FullLatentState,
        action: DrugTargetAction,
        *,
        hard_violations: Optional[List[str]] = None,
        soft_violations: Optional[List[str]] = None,
    ) -> TransitionResult:
        """Apply *action* to a deep copy of *state* and report the outcome.

        Args:
            state: Current latent state. All updates are made on a deep
                copy, which is returned as ``next_state``.
            action: The agent action to apply.
            hard_violations: Blocking rule violations. When non-empty the
                action is not executed: no call count is recorded, no
                credits are deducted, and a failure report is returned.
            soft_violations: Non-blocking violations. When non-empty the
                generated output's ``quality_score`` is scaled by 0.7
                (floored at 0.0) and the messages are appended to its
                ``warnings``.

        Returns:
            A ``TransitionResult``. ``done`` is set when the action is
            ``SUBMIT_VALIDATION_REPORT`` (even if blocked) or, for an
            executed action, when ``credits.exhausted`` is truthy right
            after the cost has been deducted.
        """
        working = deepcopy(state)
        # The step index is derived from the calls recorded so far; blocked
        # actions are never recorded, so they do not advance it.
        step_number = 1 + sum(working.action_call_counts.values())

        blocking = hard_violations if hard_violations else []
        advisories = soft_violations if soft_violations else []

        if blocking:
            return self._blocked_result(
                working, action, step_number, blocking, advisories
            )

        # Track call counts before deduction so the rule engine can use
        # them when reasoning about redundancy on the next step.
        call_key = action.action_type.value
        tally = working.action_call_counts
        tally[call_key] = tally.get(call_key, 0) + 1

        # Charge the action against the credit budget.
        price = compute_action_cost(action)
        working.credits.credits_used += price

        # Sample exhaustion immediately after the deduction and before the
        # output is generated; it feeds the termination decision below.
        out_of_credits = working.credits.exhausted

        result_output = self.output_gen.generate(action, working, step_number)

        if advisories:
            # Soft violations degrade output quality and surface as warnings.
            scaled = result_output.quality_score * 0.7
            result_output.quality_score = float(max(0.0, scaled))
            result_output.warnings = [*result_output.warnings, *advisories]

        # Only a successful action marks its progress flag (if it has one).
        progress_attr = _PROGRESS_MAP.get(action.action_type)
        if progress_attr and result_output.success:
            setattr(working.progress, progress_attr, True)

        # The episode ends on report submission or on credit exhaustion.
        finished = (
            action.action_type == ActionType.SUBMIT_VALIDATION_REPORT
            or out_of_credits
        )

        return TransitionResult(
            next_state=working,
            output=result_output,
            soft_violations=advisories,
            done=finished,
        )

    @staticmethod
    def _blocked_result(
        snapshot: FullLatentState,
        action: DrugTargetAction,
        step_number: int,
        blocking: List[str],
        advisories: List[str],
    ) -> TransitionResult:
        """Build the result for an action rejected by hard violations.

        *snapshot* is returned unchanged as ``next_state``. The episode is
        flagged as done only when the rejected action was a report
        submission.
        """
        reasons = "; ".join(blocking)
        failure = IntermediateOutput(
            output_type=OutputType.FAILURE_REPORT,
            step_index=step_number,
            success=False,
            summary=f"Action blocked: {reasons}",
        )
        return TransitionResult(
            next_state=snapshot,
            output=failure,
            hard_violations=blocking,
            soft_violations=advisories,
            done=action.action_type == ActionType.SUBMIT_VALIDATION_REPORT,
        )

    @staticmethod
    def covered_evidence_dimensions(s: FullLatentState) -> List[str]:
        """List the *evidence dimensions* whose progress flag is set.

        The result is a list of dimension names in a fixed order, one for
        each truthy flag on ``s.progress``. The names mirror the keys used
        in ``TargetProfile.key_evidence_dimensions`` so the reward computer
        can compute coverage directly.
        """
        progress = s.progress
        # Insertion order of this mapping fixes the order of the result.
        status_by_dimension: Dict[str, bool] = {
            "expression": progress.expression_queried,
            "druggability": progress.druggability_assessed,
            "off_target": progress.selectivity_checked,
            "toxicity": progress.toxicity_assessed,
            "clinical": progress.clinical_checked,
            "literature": progress.literature_reviewed,
            "in_vitro": progress.in_vitro_done,
            "in_vivo": progress.in_vivo_done,
            "patient_stratification": progress.patient_stratification_done,
            "pathway": progress.pathway_analysed,
            "structure": progress.structure_resolved,
            "interactions": progress.interactions_mapped,
            "crispr": progress.crispr_done,
            "biomarker": progress.biomarker_correlated,
        }
        return [
            dimension
            for dimension, touched in status_by_dimension.items()
            if touched
        ]