from typing import Dict, List, Optional, Tuple
from pydantic import BaseModel, Field
from tasks import TaskDefinition, get_task
class AdverseEventReport(BaseModel):
    """A single adverse-event case report presented to the agent."""

    report_id: str  # unique case identifier
    patient_age: int
    patient_sex: str
    drugs: List[str]  # all drugs listed on the report
    suspect_drug: str  # the drug the reporter blamed (may not be the true cause — see the confounded task's review note)
    reaction: str
    onset_days: int  # presumably days from exposure to reaction onset — TODO confirm against task data
    severity: str
    outcome: str
    similar_reports_last_30d: int  # recent case volume for the same reaction
class Observation(BaseModel):
    """What the agent sees at each step of a two-step episode."""

    task_id: str
    reports: List[AdverseEventReport]
    drug_interaction_db: dict  # reference interaction data supplied by the task
    step_number: int  # 0 before the first action; incremented by the env after each step
    max_steps: int
    feedback: Optional[str] = None  # environment guidance, incl. the senior review note after step 0
class Action(BaseModel):
    """The agent's triage decision, submitted at both episode steps."""

    classification: str  # graded by string equality; "new_signal" / "noise" are the values the env's penalties test for
    suspect_drug: str  # compared case-insensitively with substring leniency
    severity_assessment: str
    recommended_action: str
    reasoning: str
    confidence: Optional[int] = Field(default=None, ge=0, le=100)  # optional self-rating; not used in grading here
class Reward(BaseModel):
    """Scalar reward plus a per-component breakdown."""

    total: float = Field(..., ge=-1.0, le=1.0)  # env additionally clamps to [-0.25, 1.0] via _clamp_reward
    breakdown: dict  # component name -> contribution; values sum to (unclamped) total
class PharmaVigilanceEnv:
    """Two-step pharmacovigilance triage environment.

    Episode protocol:
      * step 0 — the agent submits an initial triage, scored by the fixed
        rubric in ``_initial_triage_reward`` (total capped at 0.40);
      * step 1 — after a senior-review note is revealed, the agent submits a
        final assessment, scored by the task's own ``action_grader`` plus the
        revision-shaping terms in ``_finalize_reward``. The episode then ends.
    """

    def __init__(self):
        self.current_task: Optional[TaskDefinition] = None
        self.current_task_id: Optional[str] = None
        self.step_number = 0
        self.max_steps = 2  # exactly two decisions per episode
        self.last_action: Optional[dict] = None
        self.last_reward: Optional[dict] = None
        # Step-0 snapshot, consumed by _finalize_reward for revision shaping.
        self.initial_action: Optional[Action] = None
        self.initial_reward: Optional[Reward] = None

    def _review_note(self) -> str:
        """Return the per-task senior-review context revealed after step 0."""
        notes = {
            "known_signal_easy": (
                "Senior review note: labeling already documents ACE-inhibitor cough, "
                "and the recent case volume suggests this is a routine known-effect triage question."
            ),
            "cluster_signal_medium": (
                "Senior review note: the safety mailbox added 3 follow-up summaries showing "
                "symptomatic bradycardia with no competing causative drug class in common."
            ),
            "confounded_hard": (
                "Senior review note: tacrolimus trough levels returned at 4x baseline after "
                "recent voriconazole exposure, which is more mechanistically informative than the reporter's blamed drug."
            ),
        }
        return notes.get(self.current_task_id or "", "Senior review note: additional case review context is now available.")

    @staticmethod
    def _clamp_reward(total: float, breakdown: dict) -> Reward:
        """Round and clamp a raw total into [-0.25, 1.0], wrapped in a Reward."""
        # The -0.25 floor is deliberately tighter than Reward's ge=-1.0 bound.
        return Reward(total=max(-0.25, min(1.0, round(total, 4))), breakdown=breakdown)

    def _observation(self, feedback: Optional[str]) -> Observation:
        """Build an Observation for the current task/step with the given feedback."""
        return Observation(
            task_id=self.current_task.task_id,
            reports=self.current_task.reports,
            drug_interaction_db=self.current_task.drug_interaction_db,
            step_number=self.step_number,
            max_steps=self.max_steps,
            feedback=feedback,
        )

    def _initial_triage_reward(self, action: Action) -> Reward:
        """Score the step-0 triage against ground truth (max achievable total 0.40)."""
        truth = self.current_task.ground_truth
        action_suspect = action.suspect_drug.strip().lower()
        truth_suspect = truth.suspect_drug.strip().lower()
        # Lenient drug match: exact, or either normalized name contained in
        # the other (tolerates brand/generic or salt-form suffixes).
        suspect_match = (
            action_suspect == truth_suspect
            or action_suspect in truth_suspect
            or truth_suspect in action_suspect
        )
        breakdown = {
            "initial_classification": 0.15 if action.classification == truth.classification else 0.0,
            "initial_suspect_drug": 0.15 if suspect_match else 0.0,
            "initial_severity": 0.05 if action.severity_assessment == truth.severity_assessment else 0.0,
            "initial_action": 0.05 if action.recommended_action == truth.recommended_action else 0.0,
            "initial_false_alarm_penalty": 0.0,
            "initial_missed_signal_penalty": 0.0,
        }
        # Asymmetric penalties: missing a real signal costs more than a false alarm.
        if action.classification == "new_signal" and truth.classification == "noise":
            breakdown["initial_false_alarm_penalty"] = -0.05
        if action.classification == "noise" and truth.classification == "new_signal":
            breakdown["initial_missed_signal_penalty"] = -0.10
        return self._clamp_reward(sum(breakdown.values()), breakdown)

    def _finalize_reward(self, action: Action) -> Reward:
        """Score the final assessment: task grader total plus revision-shaping terms."""
        final_reward = self.current_task.action_grader(action)
        breakdown = dict(final_reward.breakdown)
        initial_total = self.initial_reward.total if self.initial_reward else 0.0
        breakdown["revision_bonus"] = 0.0
        breakdown["stubborn_penalty"] = 0.0
        breakdown["flip_penalty"] = 0.0
        # Reward a meaningful improvement over the initial triage.
        if final_reward.total - initial_total >= 0.20:
            breakdown["revision_bonus"] = 0.05
        # Penalize repeating a poor initial answer verbatim after new evidence.
        if (
            self.initial_action is not None
            and initial_total < 0.20
            and self.initial_action.classification == action.classification
            and self.initial_action.suspect_drug.strip().lower() == action.suspect_drug.strip().lower()
            and self.initial_action.recommended_action == action.recommended_action
        ):
            breakdown["stubborn_penalty"] = -0.05
        # Penalize abandoning a strong initial answer for a much worse final one.
        # BUG FIX: the previous cutoff (initial_total >= 0.70) was unreachable —
        # _initial_triage_reward's components sum to at most 0.40 — so this
        # penalty could never fire. 0.30 is the analogous "strong initial"
        # threshold on that 0.40-point scale.
        if self.initial_action is not None and initial_total >= 0.30 and initial_total - final_reward.total >= 0.25:
            breakdown["flip_penalty"] = -0.04
        return self._clamp_reward(sum(breakdown.values()), breakdown)

    def reset(self, task_id: str = "known_signal_easy") -> Observation:
        """Load the given task, clear all episode state, and return the opening observation."""
        self.current_task = get_task(task_id)
        self.current_task_id = self.current_task.task_id
        self.step_number = 0
        self.last_action = None
        self.last_reward = None
        self.initial_action = None
        self.initial_reward = None
        return self._observation(
            "Task loaded. Submit an initial triage, then revise after senior review context arrives."
        )

    def step(self, action: Action) -> Tuple[Observation, Reward, bool, Dict]:
        """Advance the episode by one step.

        Step 0 records the initial triage and reveals the senior-review note
        (done=False); step 1 grades the final assessment and ends the episode
        (done=True).

        Raises:
            RuntimeError: if called before reset(), or after the episode ended.
        """
        if self.current_task is None:
            raise RuntimeError("Call reset() before step().")
        if self.step_number >= self.max_steps:
            raise RuntimeError("Episode already complete. Call reset() before another step().")

        if self.step_number == 0:
            reward = self._initial_triage_reward(action)
            # Snapshot step 0 so _finalize_reward can shape the final score.
            self.initial_action = action
            self.initial_reward = reward
            self.step_number += 1
            self.last_action = action.model_dump()
            self.last_reward = reward.model_dump()
            observation = self._observation(
                "Initial triage recorded. "
                f"{self._review_note()} "
                "Review the added context and submit your final assessment."
            )
            info = {
                "phase": "initial_triage",
                "difficulty": self.current_task.difficulty,
                "reward_breakdown": reward.breakdown,
            }
            return observation, reward, False, info

        reward = self._finalize_reward(action)
        self.step_number += 1
        self.last_action = action.model_dump()
        self.last_reward = reward.model_dump()
        # Count how many rubric fields earned positive credit from the grader.
        matched = sum(
            1
            for field in (
                "classification",
                "suspect_drug",
                "severity_assessment",
                "recommended_action",
            )
            if reward.breakdown.get(field, 0.0) > 0
        )
        if reward.total >= 0.9:
            feedback = "Strong assessment. The key safety judgment and follow-up were correct."
        elif reward.total >= 0.5:
            feedback = "Partially correct assessment. Some causal or operational details were missed."
        else:
            feedback = "Weak assessment. This case would need human analyst correction."
        observation = self._observation(feedback)
        info = {
            "matched_fields": matched,
            "difficulty": self.current_task.difficulty,
            "phase": "final_review",
            "reward_breakdown": reward.breakdown,
        }
        return observation, reward, True, info

    def state(self) -> dict:
        """Lightweight snapshot of episode progress for logging/debugging."""
        return {
            "task_id": self.current_task_id,
            "step_number": self.step_number,
            "last_action": self.last_action,
            "last_reward": self.last_reward,
        }