Spaces:
Sleeping
Sleeping
| from typing import Dict, List, Optional, Tuple | |
| from pydantic import BaseModel, Field | |
| from tasks import TaskDefinition, get_task | |
class AdverseEventReport(BaseModel):
    """One adverse-event case report shown to the agent inside an Observation.

    Field order is significant only for schema/``model_dump()`` key order and
    is kept as originally declared.
    """

    # Report identity and patient demographics.
    report_id: str
    patient_age: int
    patient_sex: str

    # All drugs on the report, and the single drug singled out as suspect
    # (presumably by the original reporter -- confirm against task data).
    drugs: List[str]
    suspect_drug: str

    # Description of the event. NOTE(review): onset_days is presumably days
    # from drug start to reaction onset; units are not enforced here.
    reaction: str
    onset_days: int
    severity: str
    outcome: str

    # Number of comparable reports seen in the preceding 30 days (per name).
    similar_reports_last_30d: int
class Observation(BaseModel):
    """Snapshot returned to the agent by ``reset()`` and ``step()``."""

    # Which task the reports belong to.
    task_id: str
    # Case reports the agent must assess.
    reports: List[AdverseEventReport]
    # Lookup data supplied by the task; its schema is defined by the task.
    drug_interaction_db: dict
    # Progress counters: steps taken so far and the step budget.
    step_number: int
    max_steps: int
    # Optional human-readable status message for the agent.
    feedback: Optional[str] = None
class Action(BaseModel):
    """The agent's pharmacovigilance assessment, scored by the task grader."""

    # Fields the environment inspects when counting matched fields.
    classification: str
    suspect_drug: str
    severity_assessment: str
    recommended_action: str

    # Free-text justification. NOTE(review): not part of the env's
    # matched-field count; whether the grader scores it depends on `tasks`.
    reasoning: str
class Reward(BaseModel):
    """Score produced by grading one Action."""

    # Required overall score, validated to lie in the closed range [0.0, 1.0].
    total: float = Field(default=..., ge=0.0, le=1.0)
    # Per-field scores. The environment compares values against 0, so they are
    # presumably numeric and keyed by Action field names -- confirm in `tasks`.
    breakdown: dict
class PharmaVigilanceEnv:
    """Pharmacovigilance environment with a small, fixed step budget.

    ``reset()`` loads a task via ``get_task`` and exposes its reports;
    ``step()`` grades one :class:`Action` with the task's ``action_grader``
    and reports the episode as done once ``max_steps`` steps have been taken
    (one step by default).
    """

    # Action fields whose positive breakdown score counts toward the
    # ``matched_fields`` entry of the ``info`` dict returned by ``step()``.
    _GRADED_FIELDS = (
        "classification",
        "suspect_drug",
        "severity_assessment",
        "recommended_action",
    )

    def __init__(self):
        self.current_task: Optional[TaskDefinition] = None
        self.current_task_id: Optional[str] = None
        self.step_number = 0
        self.max_steps = 1
        # Serialized copies of the most recent action/reward, exposed by state().
        self.last_action: Optional[dict] = None
        self.last_reward: Optional[dict] = None

    def _observation(self, feedback: Optional[str]) -> Observation:
        """Build an Observation of the current task at the current step.

        Shared by ``reset()`` and ``step()`` so both always expose the same
        fields. Requires ``self.current_task`` to be set.
        """
        return Observation(
            task_id=self.current_task.task_id,
            reports=self.current_task.reports,
            drug_interaction_db=self.current_task.drug_interaction_db,
            step_number=self.step_number,
            max_steps=self.max_steps,
            feedback=feedback,
        )

    @staticmethod
    def _feedback_for(total: float) -> str:
        """Map an overall reward score to a feedback message."""
        if total >= 0.9:
            return "Strong assessment. The key safety judgment and follow-up were correct."
        if total >= 0.5:
            return "Partially correct assessment. Some causal or operational details were missed."
        return "Weak assessment. This case would need human analyst correction."

    def reset(self, task_id: str = "known_signal_easy") -> Observation:
        """Load ``task_id``, clear episode state, and return the first observation.

        Args:
            task_id: Identifier passed to ``get_task``.

        Returns:
            An Observation at step 0 with a "task loaded" feedback message.
        """
        # Fetch the task first so a failing lookup leaves prior state intact.
        self.current_task = get_task(task_id)
        self.current_task_id = self.current_task.task_id
        self.step_number = 0
        self.last_action = None
        self.last_reward = None
        return self._observation(
            "Task loaded. Submit one final pharmacovigilance assessment."
        )

    def step(self, action: Action) -> Tuple[Observation, Reward, bool, Dict]:
        """Grade ``action`` and advance the episode by one step.

        Args:
            action: The agent's assessment.

        Returns:
            ``(observation, reward, done, info)`` where ``info`` holds
            ``matched_fields``, ``difficulty`` and ``reward_breakdown``.

        Raises:
            RuntimeError: If ``reset()`` has not been called.
        """
        if self.current_task is None:
            raise RuntimeError("Call reset() before step().")

        # Grade before mutating state so a grader error leaves the env unchanged.
        reward = self.current_task.action_grader(action)
        self.step_number += 1
        self.last_action = action.model_dump()
        self.last_reward = reward.model_dump()

        # Derived from the step budget instead of hard-coded True; identical
        # for the default max_steps == 1. NOTE(review): stepping again after
        # done is still allowed and re-grades; consider rejecting it.
        done = self.step_number >= self.max_steps

        matched = sum(
            1
            for name in self._GRADED_FIELDS
            if reward.breakdown.get(name, 0.0) > 0
        )

        observation = self._observation(self._feedback_for(reward.total))
        info = {
            "matched_fields": matched,
            "difficulty": self.current_task.difficulty,
            "reward_breakdown": reward.breakdown,
        }
        return observation, reward, done, info

    def state(self) -> dict:
        """Return the current task id, step count, and last action/reward."""
        return {
            "task_id": self.current_task_id,
            "step_number": self.step_number,
            "last_action": self.last_action,
            "last_reward": self.last_reward,
        }