File size: 8,721 Bytes
60c0453
 
 
 
 
f2beac3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60c0453
 
 
 
 
 
 
f2beac3
 
9ab33d8
 
 
f2beac3
 
60c0453
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9ab33d8
60c0453
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f2beac3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60c0453
 
 
 
 
 
 
f2beac3
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
from typing import Dict, List, Optional, Tuple

from pydantic import BaseModel, Field

from tasks import TaskDefinition, get_task


class AdverseEventReport(BaseModel):
    """One adverse-event case report presented to the agent as input."""

    report_id: str  # unique identifier for this report
    patient_age: int
    patient_sex: str
    drugs: List[str]  # all drugs the patient was taking concurrently
    suspect_drug: str  # the drug the reporter blamed for the reaction
    reaction: str  # reported adverse reaction
    onset_days: int  # presumably days from exposure to reaction onset — confirm with task data
    severity: str
    outcome: str
    similar_reports_last_30d: int  # recent case volume for context during triage


class Observation(BaseModel):
    """What the agent observes at each step of an episode."""

    task_id: str  # identifier of the active task
    reports: List[AdverseEventReport]  # the case reports under review
    drug_interaction_db: dict  # task-provided interaction lookup; schema defined by the task, not visible here
    step_number: int  # steps completed so far (0 immediately after reset)
    max_steps: int  # episode length; this environment uses 2 (initial triage, then final review)
    feedback: Optional[str] = None  # textual guidance from the environment


class Action(BaseModel):
    """The agent's triage/assessment submission for one step."""

    classification: str  # compared against ground truth; "new_signal" and "noise" trigger penalties
    suspect_drug: str  # matched case-insensitively (and by substring) against ground truth
    severity_assessment: str
    recommended_action: str
    reasoning: str  # free-text rationale; not read by the environment's reward logic
    confidence: Optional[int] = Field(default=None, ge=0, le=100)  # optional percent confidence


class Reward(BaseModel):
    """Scalar reward plus a per-component breakdown."""

    total: float = Field(..., ge=-1.0, le=1.0)  # env clamps into [-0.25, 1.0] before constructing this
    breakdown: dict  # component name -> signed contribution


class PharmaVigilanceEnv:
    """Two-phase pharmacovigilance triage environment.

    Episode shape: ``reset()`` loads a task; the first ``step()`` records an
    initial triage and returns a partial reward plus senior-review context;
    the second ``step()`` grades the final assessment and ends the episode.
    """

    def __init__(self):
        # Episode bookkeeping; (re)populated by reset().
        self.current_task: Optional[TaskDefinition] = None
        self.current_task_id: Optional[str] = None
        self.step_number = 0
        self.max_steps = 2
        self.last_action: Optional[dict] = None
        self.last_reward: Optional[dict] = None
        self.initial_action: Optional[Action] = None
        self.initial_reward: Optional[Reward] = None

    def _review_note(self) -> str:
        """Return the senior-review context line for the active task."""
        fallback = "Senior review note: additional case review context is now available."
        per_task = {
            "known_signal_easy": (
                "Senior review note: labeling already documents ACE-inhibitor cough, "
                "and the recent case volume suggests this is a routine known-effect triage question."
            ),
            "cluster_signal_medium": (
                "Senior review note: the safety mailbox added 3 follow-up summaries showing "
                "symptomatic bradycardia with no competing causative drug class in common."
            ),
            "confounded_hard": (
                "Senior review note: tacrolimus trough levels returned at 4x baseline after "
                "recent voriconazole exposure, which is more mechanistically informative than the reporter's blamed drug."
            ),
        }
        return per_task.get(self.current_task_id or "", fallback)

    @staticmethod
    def _clamp_reward(total: float, breakdown: dict) -> Reward:
        """Round the total to 4 decimals and clamp it into [-0.25, 1.0]."""
        clamped = round(total, 4)
        if clamped > 1.0:
            clamped = 1.0
        elif clamped < -0.25:
            clamped = -0.25
        return Reward(total=clamped, breakdown=breakdown)

    def _observe(self, feedback: str) -> Observation:
        """Build an Observation for the active task at the current step count."""
        task = self.current_task
        return Observation(
            task_id=task.task_id,
            reports=task.reports,
            drug_interaction_db=task.drug_interaction_db,
            step_number=self.step_number,
            max_steps=self.max_steps,
            feedback=feedback,
        )

    def _initial_triage_reward(self, action: Action) -> Reward:
        """Score the first-phase triage against the task's ground truth."""
        gt = self.current_task.ground_truth
        submitted = action.suspect_drug.strip().lower()
        expected = gt.suspect_drug.strip().lower()
        # Exact match or either name containing the other counts as a hit.
        drug_hit = submitted == expected or submitted in expected or expected in submitted

        scores = {
            "initial_classification": 0.15 if action.classification == gt.classification else 0.0,
            "initial_suspect_drug": 0.15 if drug_hit else 0.0,
            "initial_severity": 0.05 if action.severity_assessment == gt.severity_assessment else 0.0,
            "initial_action": 0.05 if action.recommended_action == gt.recommended_action else 0.0,
            # Calling noise a signal costs less than missing a real signal.
            "initial_false_alarm_penalty": (
                -0.05
                if action.classification == "new_signal" and gt.classification == "noise"
                else 0.0
            ),
            "initial_missed_signal_penalty": (
                -0.10
                if action.classification == "noise" and gt.classification == "new_signal"
                else 0.0
            ),
        }
        return self._clamp_reward(sum(scores.values()), scores)

    def _finalize_reward(self, action: Action) -> Reward:
        """Grade the final assessment and apply revision-behavior shaping."""
        graded = self.current_task.action_grader(action)
        parts = dict(graded.breakdown)
        baseline = 0.0 if self.initial_reward is None else self.initial_reward.total

        # Shaping terms start at zero and are filled in below.
        parts["revision_bonus"] = 0.05 if graded.total - baseline >= 0.20 else 0.0
        parts["stubborn_penalty"] = 0.0
        parts["flip_penalty"] = 0.0

        first = self.initial_action
        if first is not None:
            unchanged = (
                first.classification == action.classification
                and first.suspect_drug.strip().lower() == action.suspect_drug.strip().lower()
                and first.recommended_action == action.recommended_action
            )
            # Repeating a weak initial answer verbatim is penalized.
            if baseline < 0.20 and unchanged:
                parts["stubborn_penalty"] = -0.05
            # Abandoning a strong initial answer for a worse one is penalized.
            if baseline >= 0.70 and baseline - graded.total >= 0.25:
                parts["flip_penalty"] = -0.04

        return self._clamp_reward(sum(parts.values()), parts)

    def reset(self, task_id: str = "known_signal_easy") -> Observation:
        """Load a task and clear all per-episode state."""
        task = get_task(task_id)
        self.current_task = task
        self.current_task_id = task.task_id
        self.step_number = 0
        self.last_action = None
        self.last_reward = None
        self.initial_action = None
        self.initial_reward = None
        return self._observe(
            "Task loaded. Submit an initial triage, then revise after senior review context arrives."
        )

    def step(self, action: Action) -> Tuple[Observation, Reward, bool, Dict]:
        """Advance one phase; returns (observation, reward, done, info)."""
        if self.current_task is None:
            raise RuntimeError("Call reset() before step().")
        if self.step_number >= self.max_steps:
            raise RuntimeError("Episode already complete. Call reset() before another step().")

        first_phase = self.step_number == 0
        reward = self._initial_triage_reward(action) if first_phase else self._finalize_reward(action)
        if first_phase:
            # Remember the initial submission so the final grade can shape on it.
            self.initial_action = action
            self.initial_reward = reward
        self.step_number += 1
        self.last_action = action.model_dump()
        self.last_reward = reward.model_dump()

        if first_phase:
            feedback = (
                "Initial triage recorded. "
                f"{self._review_note()} "
                "Review the added context and submit your final assessment."
            )
            info = {
                "phase": "initial_triage",
                "difficulty": self.current_task.difficulty,
                "reward_breakdown": reward.breakdown,
            }
            return self._observe(feedback), reward, False, info

        # Count how many graded fields earned a positive contribution.
        tracked = (
            "classification",
            "suspect_drug",
            "severity_assessment",
            "recommended_action",
        )
        matched = sum(1 for key in tracked if reward.breakdown.get(key, 0.0) > 0)

        if reward.total >= 0.9:
            feedback = "Strong assessment. The key safety judgment and follow-up were correct."
        elif reward.total >= 0.5:
            feedback = "Partially correct assessment. Some causal or operational details were missed."
        else:
            feedback = "Weak assessment. This case would need human analyst correction."

        info = {
            "matched_fields": matched,
            "difficulty": self.current_task.difficulty,
            "phase": "final_review",
            "reward_breakdown": reward.breakdown,
        }
        return self._observe(feedback), reward, True, info

    def state(self) -> dict:
        """Snapshot of episode bookkeeping for external inspection."""
        return {
            "task_id": self.current_task_id,
            "step_number": self.step_number,
            "last_action": self.last_action,
            "last_reward": self.last_reward,
        }