modelbuilderhq's picture
Upload folder using huggingface_hub
f2beac3 verified
raw
history blame
3.69 kB
from typing import Dict, List, Optional, Tuple
from pydantic import BaseModel, Field
from tasks import TaskDefinition, get_task
class AdverseEventReport(BaseModel):
    """One adverse-event report, declared as a pydantic model.

    Instances are carried in ``Observation.reports``. Every field is required
    (none declares a default). The categorical fields are plain ``str``; this
    class does not restrict them to a fixed set of values.
    """

    report_id: str  # identifier of this report
    patient_age: int  # NOTE(review): unit is presumably years -- confirm
    patient_sex: str
    drugs: List[str]  # drug names listed on the report
    suspect_drug: str  # NOTE(review): presumably one of `drugs`; not validated here
    reaction: str
    onset_days: int  # NOTE(review): presumably days until the reaction began -- confirm the reference point
    severity: str
    outcome: str
    similar_reports_last_30d: int  # NOTE(review): presumably a count of similar reports in the prior 30 days -- confirm
class Observation(BaseModel):
    """What ``PharmaVigilanceEnv.reset()`` and ``step()`` hand back to the agent.

    Both methods fill every field from the currently loaded task and the
    environment's own step counters.
    """

    task_id: str  # id of the loaded task
    reports: List[AdverseEventReport]  # the task's adverse-event reports
    drug_interaction_db: dict  # taken from the task as-is; key/value schema is not visible in this file
    step_number: int  # number of step() calls made since the last reset()
    max_steps: int  # the environment's max_steps value
    feedback: Optional[str] = None  # status message from the environment; the only optional field
class Action(BaseModel):
    """The agent's assessment, submitted to ``PharmaVigilanceEnv.step()``.

    ``step()`` passes the whole object to the task's ``action_grader``. All
    fields are required free-form strings; no allowed values are enforced here.
    """

    # The first four field names are also the keys that step() looks up in
    # the reward breakdown when it counts matched fields.
    classification: str
    suspect_drug: str
    severity_assessment: str
    recommended_action: str
    reasoning: str  # NOTE(review): presumably a free-text rationale; how it is graded is not visible here
class Reward(BaseModel):
    """Score for one submitted ``Action``."""

    # Required (``...``) overall score; validation rejects values outside [0.0, 1.0].
    total: float = Field(..., ge=0.0, le=1.0)
    # Per-component scores. PharmaVigilanceEnv.step() reads the keys
    # "classification", "suspect_drug", "severity_assessment" and
    # "recommended_action" with a default of 0.0 when a key is absent.
    breakdown: dict
class PharmaVigilanceEnv:
    """Single-submission environment that grades one pharmacovigilance assessment.

    Call ``reset()`` to load a task, then ``step()`` with an ``Action``.
    ``max_steps`` is fixed at 1 in ``__init__`` and every ``step()`` call
    reports ``done=True``.
    """

    # Reward-breakdown keys that count towards ``info["matched_fields"]``.
    _SCORED_FIELDS = (
        "classification",
        "suspect_drug",
        "severity_assessment",
        "recommended_action",
    )

    # (minimum total, message) pairs, checked in order; the first match wins.
    _FEEDBACK_TIERS = (
        (0.9, "Strong assessment. The key safety judgment and follow-up were correct."),
        (0.5, "Partially correct assessment. Some causal or operational details were missed."),
    )
    # Message used when the total reaches none of the tiers above.
    _FEEDBACK_FLOOR = "Weak assessment. This case would need human analyst correction."

    def __init__(self):
        """Create an environment with no task loaded."""
        # Episode counters.
        self.max_steps = 1
        self.step_number = 0
        # Nothing is loaded until reset() is called.
        self.current_task: Optional[TaskDefinition] = None
        self.current_task_id: Optional[str] = None
        # Plain-dict dumps of the most recent step() input and its score.
        self.last_action: Optional[dict] = None
        self.last_reward: Optional[dict] = None

    def _observe(self, feedback: str) -> Observation:
        """Build an Observation of the loaded task carrying ``feedback``."""
        task = self.current_task
        return Observation(
            task_id=task.task_id,
            reports=task.reports,
            drug_interaction_db=task.drug_interaction_db,
            step_number=self.step_number,
            max_steps=self.max_steps,
            feedback=feedback,
        )

    def reset(self, task_id: str = "known_signal_easy") -> Observation:
        """Load the task named ``task_id`` and start a fresh episode.

        Looks the task up with ``get_task``, zeroes the step counter and
        clears the record of the previous action and reward.

        Returns:
            The initial observation for the loaded task.
        """
        task = get_task(task_id)
        self.current_task = task
        self.current_task_id = task.task_id
        self.step_number = 0
        self.last_action = None
        self.last_reward = None
        return self._observe("Task loaded. Submit one final pharmacovigilance assessment.")

    def step(self, action: Action) -> Tuple[Observation, Reward, bool, Dict]:
        """Grade ``action`` against the loaded task.

        Returns:
            ``(observation, reward, done, info)``. ``done`` is always ``True``.
            ``info`` holds the number of scored fields with a positive
            breakdown value, the task difficulty and the reward breakdown.

        Raises:
            RuntimeError: if no task has been loaded with ``reset()``.
        """
        task = self.current_task
        if task is None:
            raise RuntimeError("Call reset() before step().")

        reward = task.action_grader(action)

        # Record the attempt before deriving anything from the score.
        self.step_number += 1
        self.last_action = action.model_dump()
        self.last_reward = reward.model_dump()

        done = True  # one submission ends the episode

        scores = reward.breakdown
        hits = [name for name in self._SCORED_FIELDS if scores.get(name, 0.0) > 0]

        total = reward.total
        feedback = next(
            (text for floor, text in self._FEEDBACK_TIERS if total >= floor),
            self._FEEDBACK_FLOOR,
        )

        observation = self._observe(feedback)
        info = {
            "matched_fields": len(hits),
            "difficulty": task.difficulty,
            "reward_breakdown": scores,
        }
        return observation, reward, done, info

    def state(self) -> dict:
        """Return the current task id, step count, and last action and reward."""
        snapshot = {
            "task_id": self.current_task_id,
            "step_number": self.step_number,
            "last_action": self.last_action,
            "last_reward": self.last_reward,
        }
        return snapshot