from typing import Dict, List, Optional, Tuple

from pydantic import BaseModel, Field

from tasks import TaskDefinition, get_task
class AdverseEventReport(BaseModel):
    report_id: str
    patient_age: int
    patient_sex: str
    drugs: List[str]
    suspect_drug: str
    reaction: str
    onset_days: int
    severity: str
    outcome: str
    similar_reports_last_30d: int
class Observation(BaseModel):
    task_id: str
    reports: List[AdverseEventReport]
    drug_interaction_db: dict
    step_number: int
    max_steps: int
    feedback: Optional[str] = None
class Action(BaseModel):
    classification: str
    suspect_drug: str
    severity_assessment: str
    recommended_action: str
    reasoning: str
    confidence: Optional[int] = Field(default=None, ge=0, le=100)
class Reward(BaseModel):
    total: float = Field(..., ge=-1.0, le=1.0)
    breakdown: dict
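# Two-step pharmacovigilance triage environment: step 0 scores an initial triage
# with partial credit and returns a "senior review note" as added context; step 1
# grades the revised assessment with the task's action_grader plus revision,
# stubbornness, and flip shaping terms, then ends the episode.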
class PharmaVigilanceEnv:
    def __init__(self):
        self.current_task: Optional[TaskDefinition] = None
        self.current_task_id: Optional[str] = None
        self.step_number = 0
        self.max_steps = 2
        self.last_action: Optional[dict] = None
        self.last_reward: Optional[dict] = None
        self.initial_action: Optional[Action] = None
        self.initial_reward: Optional[Reward] = None
    def _review_note(self) -> str:
        notes = {
            "known_signal_easy": (
                "Senior review note: labeling already documents ACE-inhibitor cough, "
                "and the recent case volume suggests this is a routine known-effect triage question."
            ),
            "cluster_signal_medium": (
                "Senior review note: the safety mailbox added 3 follow-up summaries showing "
                "symptomatic bradycardia with no competing causative drug class in common."
            ),
            "confounded_hard": (
                "Senior review note: tacrolimus trough levels returned at 4x baseline after "
                "recent voriconazole exposure, which is more mechanistically informative than the reporter's blamed drug."
            ),
        }
        return notes.get(
            self.current_task_id or "",
            "Senior review note: additional case review context is now available.",
        )
    @staticmethod
    def _clamp_reward(total: float, breakdown: dict) -> Reward:
        # Round for stable logging and clamp into the range the Reward model accepts.
        return Reward(total=max(-0.25, min(1.0, round(total, 4))), breakdown=breakdown)
    def _initial_triage_reward(self, action: Action) -> Reward:
        truth = self.current_task.ground_truth
        action_suspect = action.suspect_drug.strip().lower()
        truth_suspect = truth.suspect_drug.strip().lower()
        # Accept exact matches or substring containment in either direction.
        suspect_match = (
            action_suspect == truth_suspect
            or action_suspect in truth_suspect
            or truth_suspect in action_suspect
        )
        breakdown = {
            "initial_classification": 0.15 if action.classification == truth.classification else 0.0,
            "initial_suspect_drug": 0.15 if suspect_match else 0.0,
            "initial_severity": 0.05 if action.severity_assessment == truth.severity_assessment else 0.0,
            "initial_action": 0.05 if action.recommended_action == truth.recommended_action else 0.0,
            "initial_false_alarm_penalty": 0.0,
            "initial_missed_signal_penalty": 0.0,
        }
        # Asymmetric error costs: missing a real signal is penalized more than a false alarm.
        if action.classification == "new_signal" and truth.classification == "noise":
            breakdown["initial_false_alarm_penalty"] = -0.05
        if action.classification == "noise" and truth.classification == "new_signal":
            breakdown["initial_missed_signal_penalty"] = -0.10
        return self._clamp_reward(sum(breakdown.values()), breakdown)
    def _finalize_reward(self, action: Action) -> Reward:
        final_reward = self.current_task.action_grader(action)
        breakdown = dict(final_reward.breakdown)
        initial_total = self.initial_reward.total if self.initial_reward else 0.0
        breakdown["revision_bonus"] = 0.0
        breakdown["stubborn_penalty"] = 0.0
        breakdown["flip_penalty"] = 0.0
        # Bonus for meaningfully improving on the initial triage score.
        if final_reward.total - initial_total >= 0.20:
            breakdown["revision_bonus"] = 0.05
        # Penalty for repeating a weak initial answer unchanged despite the review note.
        if (
            self.initial_action is not None
            and initial_total < 0.20
            and self.initial_action.classification == action.classification
            and self.initial_action.suspect_drug.strip().lower() == action.suspect_drug.strip().lower()
            and self.initial_action.recommended_action == action.recommended_action
        ):
            breakdown["stubborn_penalty"] = -0.05
        # Penalty for abandoning a strong initial answer and scoring much worse.
        if self.initial_action is not None and initial_total >= 0.70 and initial_total - final_reward.total >= 0.25:
            breakdown["flip_penalty"] = -0.04
        return self._clamp_reward(sum(breakdown.values()), breakdown)
    def reset(self, task_id: str = "known_signal_easy") -> Observation:
        self.current_task = get_task(task_id)
        self.current_task_id = self.current_task.task_id
        self.step_number = 0
        self.last_action = None
        self.last_reward = None
        self.initial_action = None
        self.initial_reward = None
        return Observation(
            task_id=self.current_task.task_id,
            reports=self.current_task.reports,
            drug_interaction_db=self.current_task.drug_interaction_db,
            step_number=self.step_number,
            max_steps=self.max_steps,
            feedback="Task loaded. Submit an initial triage, then revise after senior review context arrives.",
        )
    def step(self, action: Action) -> Tuple[Observation, Reward, bool, Dict]:
        if self.current_task is None:
            raise RuntimeError("Call reset() before step().")
        if self.step_number >= self.max_steps:
            raise RuntimeError("Episode already complete. Call reset() before another step().")
        if self.step_number == 0:
            # Phase 1: record the initial triage and surface the senior review note.
            reward = self._initial_triage_reward(action)
            self.initial_action = action
            self.initial_reward = reward
            self.step_number += 1
            self.last_action = action.model_dump()
            self.last_reward = reward.model_dump()
            done = False
            observation = Observation(
                task_id=self.current_task.task_id,
                reports=self.current_task.reports,
                drug_interaction_db=self.current_task.drug_interaction_db,
                step_number=self.step_number,
                max_steps=self.max_steps,
                feedback=(
                    "Initial triage recorded. "
                    f"{self._review_note()} "
                    "Review the added context and submit your final assessment."
                ),
            )
            info = {
                "phase": "initial_triage",
                "difficulty": self.current_task.difficulty,
                "reward_breakdown": reward.breakdown,
            }
            return observation, reward, done, info
        # Phase 2: grade the revised assessment and end the episode.
        reward = self._finalize_reward(action)
        self.step_number += 1
        self.last_action = action.model_dump()
        self.last_reward = reward.model_dump()
        done = True
        matched = sum(
            1
            for field in (
                "classification",
                "suspect_drug",
                "severity_assessment",
                "recommended_action",
            )
            if reward.breakdown.get(field, 0.0) > 0
        )
        if reward.total >= 0.9:
            feedback = "Strong assessment. The key safety judgment and follow-up were correct."
        elif reward.total >= 0.5:
            feedback = "Partially correct assessment. Some causal or operational details were missed."
        else:
            feedback = "Weak assessment. This case would need human analyst correction."
        observation = Observation(
            task_id=self.current_task.task_id,
            reports=self.current_task.reports,
            drug_interaction_db=self.current_task.drug_interaction_db,
            step_number=self.step_number,
            max_steps=self.max_steps,
            feedback=feedback,
        )
        info = {
            "matched_fields": matched,
            "difficulty": self.current_task.difficulty,
            "phase": "final_review",
            "reward_breakdown": reward.breakdown,
        }
        return observation, reward, done, info
    def state(self) -> dict:
        return {
            "task_id": self.current_task_id,
            "step_number": self.step_number,
            "last_action": self.last_action,
            "last_reward": self.last_reward,
        }
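# Minimal usage sketch (illustrative only): runs one two-step episode against the
# default "known_signal_easy" task. The Action field values below are made-up
# placeholders; the labels a real task expects depend on the definitions in tasks.py.
if __name__ == "__main__":
    env = PharmaVigilanceEnv()
    obs = env.reset(task_id="known_signal_easy")
    print(obs.feedback)

    draft = Action(
        classification="known_signal",            # placeholder label
        suspect_drug="lisinopril",                 # placeholder drug name
        severity_assessment="non_serious",         # placeholder severity label
        recommended_action="routine_monitoring",   # placeholder follow-up
        reasoning="Cough is a labeled ACE-inhibitor effect.",
        confidence=70,
    )

    # Step 0: initial triage; the feedback includes the senior review note.
    obs, reward, done, info = env.step(draft)
    print(info["phase"], reward.total, obs.feedback)

    # Step 1: resubmit (possibly revised) after reading the review note.
    obs, reward, done, info = env.step(draft)
    print(info["phase"], reward.total, done, obs.feedback)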