# Uploaded by modelbuilderhq ("Upload folder using huggingface_hub"),
# commit 9ab33d8 (verified).
from typing import Dict, List, Optional, Tuple
from pydantic import BaseModel, Field
from tasks import TaskDefinition, get_task
class AdverseEventReport(BaseModel):
    """A single adverse-event case report, delivered to the agent via ``Observation.reports``."""

    report_id: str
    patient_age: int
    patient_sex: str
    # All drugs listed on the case; presumably includes the suspect drug — verify against task data.
    drugs: List[str]
    # Drug named as suspect on the report. Presumably the reporter's attribution, which need not be
    # the true causal agent (cf. the "confounded_hard" review note) — confirm against task data.
    suspect_drug: str
    reaction: str
    # Presumably days from drug exposure to reaction onset — TODO confirm units/reference point.
    onset_days: int
    severity: str
    outcome: str
    # Count of similar reports in the preceding 30 days (per the field name; confirm source definition).
    similar_reports_last_30d: int
class Observation(BaseModel):
    """Environment view returned by ``PharmaVigilanceEnv.reset()`` and ``PharmaVigilanceEnv.step()``."""

    task_id: str
    # Case reports for the current task.
    reports: List[AdverseEventReport]
    # Interaction reference data supplied by the task; its schema is defined outside this file.
    drug_interaction_db: dict
    # Number of steps already taken in the episode (0 immediately after reset()).
    step_number: int
    max_steps: int
    # Human-readable status text (e.g. the senior review note after the first step); may be None.
    feedback: Optional[str] = None
class Action(BaseModel):
    """An agent's triage decision, submitted to ``PharmaVigilanceEnv.step()``."""

    # Compared for exact equality with the task's ground truth; the reward code in this file
    # treats "new_signal" and "noise" specially (false-alarm / missed-signal penalties).
    classification: str
    # Compared case-insensitively, ignoring surrounding whitespace, with substring matching.
    suspect_drug: str
    severity_assessment: str
    recommended_action: str
    # Free-text rationale; not read by the reward code in this file (the task's grader may use it).
    reasoning: str
    # Optional self-reported confidence; pydantic validates it to 0..100 inclusive.
    confidence: Optional[int] = Field(default=None, ge=0, le=100)
class Reward(BaseModel):
    """Scalar reward for one step plus a per-component breakdown."""

    # Required; pydantic validates it to [-1.0, 1.0]. PharmaVigilanceEnv further clamps it to [-0.25, 1.0].
    total: float = Field(..., ge=-1.0, le=1.0)
    # For rewards built by PharmaVigilanceEnv: component name -> float contribution.
    breakdown: dict
class PharmaVigilanceEnv:
    """Two-step pharmacovigilance triage environment.

    After :meth:`reset`, an episode consists of ``max_steps`` (2) calls to :meth:`step`:

    1. Initial triage, scored by :meth:`_initial_triage_reward`. The returned
       observation's feedback carries a task-specific senior review note.
    2. Final assessment, scored by the task's ``action_grader`` plus the
       revision adjustments in :meth:`_finalize_reward`; ``done`` is True.

    Further calls to :meth:`step` raise ``RuntimeError`` until :meth:`reset`.
    """

    # Maximum attainable initial-triage score: the sum of the positive credits
    # awarded in _initial_triage_reward (0.15 + 0.15 + 0.05 + 0.05).
    _INITIAL_MAX_SCORE = 0.40

    # Grader breakdown keys counted as "matched" in the final step's info dict.
    _GRADED_FIELDS = (
        "classification",
        "suspect_drug",
        "severity_assessment",
        "recommended_action",
    )

    # Task-specific context revealed to the agent after its initial triage.
    _REVIEW_NOTES = {
        "known_signal_easy": (
            "Senior review note: labeling already documents ACE-inhibitor cough, "
            "and the recent case volume suggests this is a routine known-effect triage question."
        ),
        "cluster_signal_medium": (
            "Senior review note: the safety mailbox added 3 follow-up summaries showing "
            "symptomatic bradycardia with no competing causative drug class in common."
        ),
        "confounded_hard": (
            "Senior review note: tacrolimus trough levels returned at 4x baseline after "
            "recent voriconazole exposure, which is more mechanistically informative than the reporter's blamed drug."
        ),
    }
    _DEFAULT_REVIEW_NOTE = "Senior review note: additional case review context is now available."

    def __init__(self):
        self.current_task: Optional[TaskDefinition] = None
        self.current_task_id: Optional[str] = None
        self.step_number = 0
        self.max_steps = 2
        # Latest action/reward as plain dicts (model_dump output), exposed via state().
        self.last_action: Optional[dict] = None
        self.last_reward: Optional[dict] = None
        # Snapshot of the first-step action and its reward, used by _finalize_reward.
        self.initial_action: Optional[Action] = None
        self.initial_reward: Optional[Reward] = None

    def _review_note(self) -> str:
        """Return the senior review note for the current task, or a generic note."""
        return self._REVIEW_NOTES.get(self.current_task_id or "", self._DEFAULT_REVIEW_NOTE)

    @staticmethod
    def _clamp_reward(total: float, breakdown: dict) -> Reward:
        """Round *total* to 4 decimals, clamp it to [-0.25, 1.0], and wrap it in a Reward."""
        return Reward(total=max(-0.25, min(1.0, round(total, 4))), breakdown=breakdown)

    @staticmethod
    def _normalize_drug(name: str) -> str:
        """Normalise a drug name for comparison (strip surrounding whitespace, lowercase)."""
        return name.strip().lower()

    @staticmethod
    def _suspect_matches(candidate: str, expected: str) -> bool:
        """Return True when two drug names match after normalisation.

        Exact matches and containment in either direction both count. Blank
        names never match: the empty string is a substring of every string,
        so without this guard an empty ``suspect_drug`` would earn credit.
        """
        given = PharmaVigilanceEnv._normalize_drug(candidate)
        wanted = PharmaVigilanceEnv._normalize_drug(expected)
        if not given or not wanted:
            return False
        return given == wanted or given in wanted or wanted in given

    def _initial_triage_reward(self, action: Action) -> Reward:
        """Score the first-step triage against the task's ground truth.

        Credits (max 0.40): classification 0.15, suspect drug 0.15, severity
        0.05, recommended action 0.05. Penalties: -0.05 for calling a
        ``noise`` case a ``new_signal`` (false alarm) and -0.10 for calling
        a ``new_signal`` case ``noise`` (missed signal).
        """
        truth = self.current_task.ground_truth
        suspect_ok = self._suspect_matches(action.suspect_drug, truth.suspect_drug)
        breakdown = {
            "initial_classification": 0.15 if action.classification == truth.classification else 0.0,
            "initial_suspect_drug": 0.15 if suspect_ok else 0.0,
            "initial_severity": 0.05 if action.severity_assessment == truth.severity_assessment else 0.0,
            "initial_action": 0.05 if action.recommended_action == truth.recommended_action else 0.0,
            "initial_false_alarm_penalty": 0.0,
            "initial_missed_signal_penalty": 0.0,
        }
        if action.classification == "new_signal" and truth.classification == "noise":
            breakdown["initial_false_alarm_penalty"] = -0.05
        if action.classification == "noise" and truth.classification == "new_signal":
            breakdown["initial_missed_signal_penalty"] = -0.10
        return self._clamp_reward(sum(breakdown.values()), breakdown)

    def _finalize_reward(self, action: Action) -> Reward:
        """Score the final assessment: task grader plus revision adjustments.

        Adjustments added to a copy of the grader's breakdown:

        * ``revision_bonus`` (+0.05): the grader total exceeds the raw
          initial-triage total by at least 0.20.
        * ``stubborn_penalty`` (-0.05): the initial triage scored below 0.20
          and the final answer repeats its classification, suspect drug, and
          recommended action.
        * ``flip_penalty`` (-0.04): the initial triage earned at least 70% of
          the attainable initial score, and the grader total is at least 0.25
          below that normalised initial score.

        The returned total is the clamped sum of all breakdown values.
        """
        final_reward = self.current_task.action_grader(action)
        breakdown = dict(final_reward.breakdown)
        initial_total = self.initial_reward.total if self.initial_reward else 0.0
        # The initial triage is scored out of 0.40, so normalise it before asking
        # "was the initial triage strong?". A raw `initial_total >= 0.70` check
        # is unreachable.
        initial_ratio = initial_total / self._INITIAL_MAX_SCORE
        breakdown["revision_bonus"] = 0.0
        breakdown["stubborn_penalty"] = 0.0
        breakdown["flip_penalty"] = 0.0
        # NOTE(review): this compares the raw initial total (max 0.40) with the
        # grader total, so a strong final answer can earn the bonus even after a
        # perfect initial triage. Behavior kept as-is; confirm intent.
        if final_reward.total - initial_total >= 0.20:
            breakdown["revision_bonus"] = 0.05
        first = self.initial_action
        if (
            first is not None
            and initial_total < 0.20
            and first.classification == action.classification
            and self._normalize_drug(first.suspect_drug) == self._normalize_drug(action.suspect_drug)
            and first.recommended_action == action.recommended_action
        ):
            breakdown["stubborn_penalty"] = -0.05
        if first is not None and initial_ratio >= 0.70 and initial_ratio - final_reward.total >= 0.25:
            breakdown["flip_penalty"] = -0.04
        return self._clamp_reward(sum(breakdown.values()), breakdown)

    def _build_observation(self, feedback: str) -> Observation:
        """Return an observation of the current task at the current step number."""
        return Observation(
            task_id=self.current_task.task_id,
            reports=self.current_task.reports,
            drug_interaction_db=self.current_task.drug_interaction_db,
            step_number=self.step_number,
            max_steps=self.max_steps,
            feedback=feedback,
        )

    def _record(self, action: Action, reward: Reward) -> None:
        """Advance the step counter and store the latest action/reward as plain dicts."""
        self.step_number += 1
        self.last_action = action.model_dump()
        self.last_reward = reward.model_dump()

    def reset(self, task_id: str = "known_signal_easy") -> Observation:
        """Load *task_id* via ``get_task``, clear episode state, and return the first observation."""
        self.current_task = get_task(task_id)
        self.current_task_id = self.current_task.task_id
        self.step_number = 0
        self.last_action = None
        self.last_reward = None
        self.initial_action = None
        self.initial_reward = None
        return self._build_observation(
            "Task loaded. Submit an initial triage, then revise after senior review context arrives."
        )

    def step(self, action: Action) -> Tuple[Observation, Reward, bool, Dict]:
        """Submit *action* and return ``(observation, reward, done, info)``.

        The first call after :meth:`reset` records the initial triage; the
        second records the final assessment and ends the episode.

        Raises:
            RuntimeError: if :meth:`reset` was never called, or the episode
                has already used ``max_steps`` steps.
        """
        if self.current_task is None:
            raise RuntimeError("Call reset() before step().")
        if self.step_number >= self.max_steps:
            raise RuntimeError("Episode already complete. Call reset() before another step().")
        if self.step_number == 0:
            return self._step_initial(action)
        return self._step_final(action)

    def _step_initial(self, action: Action) -> Tuple[Observation, Reward, bool, Dict]:
        """Handle the initial-triage step and reveal the senior review note."""
        reward = self._initial_triage_reward(action)
        # Snapshot rather than alias: callers that mutate and resubmit the same
        # Action object would otherwise make the "unchanged answer" check in
        # _finalize_reward compare the object with itself.
        self.initial_action = action.model_copy()
        self.initial_reward = reward
        self._record(action, reward)
        observation = self._build_observation(
            "Initial triage recorded. "
            f"{self._review_note()} "
            "Review the added context and submit your final assessment."
        )
        info = {
            "phase": "initial_triage",
            "difficulty": self.current_task.difficulty,
            "reward_breakdown": reward.breakdown,
        }
        return observation, reward, False, info

    def _step_final(self, action: Action) -> Tuple[Observation, Reward, bool, Dict]:
        """Handle the final-review step and end the episode."""
        reward = self._finalize_reward(action)
        self._record(action, reward)
        # Number of graded fields the grader credited with a positive score.
        matched = sum(1 for field in self._GRADED_FIELDS if reward.breakdown.get(field, 0.0) > 0)
        if reward.total >= 0.9:
            feedback = "Strong assessment. The key safety judgment and follow-up were correct."
        elif reward.total >= 0.5:
            feedback = "Partially correct assessment. Some causal or operational details were missed."
        else:
            feedback = "Weak assessment. This case would need human analyst correction."
        observation = self._build_observation(feedback)
        info = {
            "matched_fields": matched,
            "difficulty": self.current_task.difficulty,
            "phase": "final_review",
            "reward_breakdown": reward.breakdown,
        }
        return observation, reward, True, info

    def state(self) -> dict:
        """Return a plain-dict snapshot of the episode: task id, step number, last action and reward."""
        return {
            "task_id": self.current_task_id,
            "step_number": self.step_number,
            "last_action": self.last_action,
            "last_reward": self.last_reward,
        }