# pharma-vigilance / tasks.py
# Uploaded by modelbuilderhq using huggingface_hub (commit 9ab33d8, verified).
import random
from typing import Any, Callable, Dict, List, Optional
from pydantic import BaseModel, ConfigDict, Field
from data import TASK_DATA
class GroundTruth(BaseModel):
    """Expected answer for a task; graders compare an action against it field by field."""

    # Expected label; _base_breakdown compares it to the action's classification
    # and applies penalties for "noise" <-> "new_signal" confusions.
    classification: str
    # Compared case-insensitively and leniently (substring match) by _base_breakdown.
    suspect_drug: str
    # Compared by exact string equality.
    severity_assessment: str
    # Compared by exact string equality.
    recommended_action: str
class TaskDefinition(BaseModel):
    """One pharmacovigilance task: its reports, safety context, expected answer and grader."""

    # arbitrary_types_allowed permits field types pydantic has no built-in
    # validator for; revalidate_instances="never" accepts model instances passed
    # in (e.g. ground_truth) without re-validating them.
    model_config = ConfigDict(arbitrary_types_allowed=True, revalidate_instances="never")
    task_id: str = Field(..., description="Unique pharmacovigilance task identifier")
    difficulty: str = Field(..., description="easy | medium | hard")
    reports: List[Any] = Field(default_factory=list, description="Synthetic adverse event reports")
    drug_interaction_db: dict = Field(default_factory=dict, description="Hardcoded interaction and safety context")
    ground_truth: GroundTruth
    # Called with an agent action. In this module it is one of the
    # *_action_grader functions, which return the Reward built by
    # _reward_from_breakdown.
    action_grader: Callable[[Any], Any]
    description: str = Field(default="", description="Human-readable task summary")

    @property
    def id(self) -> str:
        """Alias for ``task_id``."""
        return self.task_id
def _base_breakdown(action: Any, ground_truth: "GroundTruth") -> dict:
    """Score the four core fields of ``action`` against ``ground_truth``.

    Returns a dict of reward components. The core fields earn fixed partial
    credit on a match:

    * ``classification``: 0.25
    * ``suspect_drug``: 0.25
    * ``severity_assessment``: 0.20
    * ``recommended_action``: 0.15

    Every other component starts at 0.0 for the task-specific graders to fill
    in, except the two classification penalties, which are set here:

    * ``false_alarm_penalty`` (-0.10) when true noise is reported as a new signal.
    * ``missed_signal_penalty`` (-0.20) when a true new signal is dismissed as noise.

    The suspect drug is matched case-insensitively, ignoring surrounding
    whitespace. An exact match always counts. Otherwise either name containing
    the other counts (e.g. "tacrolimus" vs "tacrolimus (prograf)"), but only
    when both names are non-empty. An empty string is a substring of every
    name, so without that guard a blank answer would earn full credit.
    """
    action_suspect = (action.suspect_drug or "").strip().lower()
    truth_suspect = ground_truth.suspect_drug.strip().lower()
    suspect_match = action_suspect == truth_suspect or (
        bool(action_suspect)
        and bool(truth_suspect)
        and (action_suspect in truth_suspect or truth_suspect in action_suspect)
    )
    breakdown = {
        "classification": 0.25 if action.classification == ground_truth.classification else 0.0,
        "suspect_drug": 0.25 if suspect_match else 0.0,
        "severity_assessment": 0.20 if action.severity_assessment == ground_truth.severity_assessment else 0.0,
        "recommended_action": 0.15 if action.recommended_action == ground_truth.recommended_action else 0.0,
        "consistency_bonus": 0.0,
        "confidence_adjustment": 0.0,
        "false_alarm_penalty": 0.0,
        "missed_signal_penalty": 0.0,
        "reasoning_bonus": 0.0,
    }
    if action.classification == "new_signal" and ground_truth.classification == "noise":
        breakdown["false_alarm_penalty"] = -0.10
    if action.classification == "noise" and ground_truth.classification == "new_signal":
        breakdown["missed_signal_penalty"] = -0.20
    return breakdown
def _is_consistent_pipeline(action: Any, ground_truth: GroundTruth, report_count: int) -> bool:
    """Return True when the action's fields form an internally coherent decision.

    Only the action itself (plus ``report_count`` for new signals) is inspected.
    ``ground_truth`` is accepted to match the grader call signature but is not
    consulted. Unknown classifications are never considered consistent.
    """
    label = action.classification
    sev = action.severity_assessment
    rec = action.recommended_action
    has_suspect = bool(action.suspect_drug.strip())

    known_recs = {"log_and_monitor", "request_more_info"}
    known_sevs = {"mild", "moderate", "severe"}
    signal_sevs = {"severe", "critical"}
    noise_recs = {"dismiss", "request_more_info"}
    noise_sevs = {"mild", "moderate"}
    duplicate_recs = {"dismiss", "log_and_monitor"}

    if label == "known_side_effect":
        # A known effect needs a named drug, a non-escalating action and a non-critical grade.
        return has_suspect and rec in known_recs and sev in known_sevs
    if label == "new_signal":
        # A new signal must be escalated, and be either serious or backed by several reports.
        return has_suspect and rec == "escalate" and (sev in signal_sevs or report_count > 1)
    if label == "noise":
        return rec in noise_recs and sev in noise_sevs
    if label == "duplicate":
        return rec in duplicate_recs
    return False
def _confidence_adjustment(action: Any, correctness_score: float) -> float:
    """Reward or penalize how well the stated confidence matches correctness.

    If the action has no ``confidence`` attribute, or it is None, the result is
    0.0. Otherwise:

    * confidence >= 80 with correctness >= 0.75 gives +0.05 (well calibrated).
    * confidence >= 80 with correctness < 0.40 gives -0.10 (overconfident).
    * confidence <= 40 with correctness >= 0.75 gives -0.03 (underconfident).
    * Anything else gives 0.0.
    """
    stated = getattr(action, "confidence", None)
    if stated is None:
        return 0.0

    mostly_right = correctness_score >= 0.75
    if stated >= 80:
        if mostly_right:
            return 0.05
        if correctness_score < 0.40:
            return -0.10
        return 0.0
    if stated <= 40 and mostly_right:
        return -0.03
    return 0.0
def _reward_from_breakdown(breakdown: dict):
    """Wrap a component breakdown in a ``Reward``.

    The total is the sum of all components, rounded to 4 decimals and clamped to
    [-0.25, 1.0]. The breakdown dict is passed through unchanged. ``Reward`` is
    imported lazily: first from the package-relative ``.env``, falling back to a
    top-level ``env`` module.
    """
    try:
        from .env import Reward
    except ImportError:
        from env import Reward

    component_sum = sum(breakdown.values())
    rounded_total = round(component_sum, 4)
    clamped_total = max(-0.25, min(1.0, rounded_total))
    return Reward(total=clamped_total, breakdown=breakdown)
# Breakdown components that measure raw field-level correctness (no bonuses/penalties).
_CORE_FIELDS = ("classification", "suspect_drug", "severity_assessment", "recommended_action")

# Mentioning any of these in the reasoning earns the confounded-case reasoning bonus.
_INTERACTION_REASONING_TERMS = (
    "drug interaction",
    "tacrolimus",
    "voriconazole",
    "azole",
    "calcineurin",
    "level monitoring",
)


def _score_core_fields(task_id: str, action: Any) -> dict:
    """Build the reward breakdown shared by every ``*_action_grader``.

    Steps:

    1. Load the task's ground truth and reports from ``TASK_DATA[task_id]``.
    2. Score the core fields with ``_base_breakdown``.
    3. Add the 0.10 consistency bonus when ``_is_consistent_pipeline`` accepts
       the action.
    4. Set the confidence adjustment from the summed core-field credit.
    """
    task_entry = TASK_DATA[task_id]
    truth = GroundTruth(**task_entry["ground_truth"])
    breakdown = _base_breakdown(action, truth)
    # Confidence calibration is judged against field correctness only, not penalties.
    correctness_score = sum(breakdown[field] for field in _CORE_FIELDS)
    if _is_consistent_pipeline(action, truth, len(task_entry["reports"])):
        breakdown["consistency_bonus"] = 0.10
    breakdown["confidence_adjustment"] = _confidence_adjustment(action, correctness_score)
    return breakdown


def known_signal_easy_action_grader(action: Any):
    """Grade an action for ``known_signal_easy`` and return its ``Reward``."""
    return _reward_from_breakdown(_score_core_fields("known_signal_easy", action))


def cluster_signal_medium_action_grader(action: Any):
    """Grade an action for ``cluster_signal_medium`` and return its ``Reward``."""
    return _reward_from_breakdown(_score_core_fields("cluster_signal_medium", action))


def confounded_hard_action_grader(action: Any):
    """Grade an action for ``confounded_hard`` and return its ``Reward``.

    On top of the shared scoring, a 0.05 reasoning bonus is awarded when the
    free-text reasoning mentions the drug-interaction explanation (see
    ``_INTERACTION_REASONING_TERMS``).
    """
    breakdown = _score_core_fields("confounded_hard", action)
    # A missing or None reasoning simply earns no bonus instead of raising.
    reasoning = (getattr(action, "reasoning", None) or "").lower()
    if any(term in reasoning for term in _INTERACTION_REASONING_TERMS):
        breakdown["reasoning_bonus"] = 0.05
    return _reward_from_breakdown(breakdown)
def _grader_score_from_trajectory(trajectory: Any = None) -> float:
    """Reduce a trajectory to a single score clamped to [0.01, 0.99].

    For a dict trajectory the score source has this precedence:

    1. An explicit ``"score"`` entry.
    2. Otherwise the last element of a non-empty ``"rewards"`` list.
    3. Otherwise ``"reward"``.

    Reward values may be plain numbers or dicts carrying a ``"total"`` key. A
    missing, empty or non-dict trajectory yields the neutral 0.5. A NaN score
    also falls back to 0.5. Without that fallback, ``min``/``max`` clamping
    silently turns NaN into the maximum score of 0.99.

    Raises:
        TypeError / ValueError: if the selected value cannot be converted to float.
    """
    import math

    def _as_float(value: Any) -> float:
        # Reward payloads may arrive serialised as {"total": ..., "breakdown": ...}.
        if isinstance(value, dict) and "total" in value:
            value = value["total"]
        return float(value)

    trajectory = trajectory or {}
    raw_score = 0.5
    if isinstance(trajectory, dict):
        if "score" in trajectory:
            raw_score = float(trajectory["score"])
        elif trajectory.get("rewards"):
            raw_score = _as_float(trajectory["rewards"][-1])
        elif "reward" in trajectory:
            raw_score = _as_float(trajectory["reward"])
    if math.isnan(raw_score):
        raw_score = 0.5
    return max(0.01, min(0.99, round(raw_score, 4)))
def known_signal_easy_grader(trajectory: Any = None) -> float:
    """Trajectory-level grader for ``known_signal_easy``, delegated to ``server.graders``.

    The import is function-scoped, so importing this module does not require
    ``server.graders`` (presumably also to avoid an import cycle).
    """
    from server.graders import known_signal_easy_grader as server_grader

    score = server_grader(trajectory)
    return score
def cluster_signal_medium_grader(trajectory: Any = None) -> float:
    """Trajectory-level grader for ``cluster_signal_medium``, delegated to ``server.graders``.

    The import is function-scoped, so importing this module does not require
    ``server.graders`` (presumably also to avoid an import cycle).
    """
    from server.graders import cluster_signal_medium_grader as server_grader

    score = server_grader(trajectory)
    return score
def confounded_hard_grader(trajectory: Any = None) -> float:
    """Trajectory-level grader for ``confounded_hard``, delegated to ``server.graders``.

    The import is function-scoped, so importing this module does not require
    ``server.graders`` (presumably also to avoid an import cycle).
    """
    from server.graders import confounded_hard_grader as server_grader

    score = server_grader(trajectory)
    return score
def _task_definition(
    task_id: str,
    difficulty: str,
    description: str,
    action_grader: Callable[[Any], Any],
) -> TaskDefinition:
    """Materialise one ``TaskDefinition`` from its ``TASK_DATA`` entry.

    Raw report dicts are parsed into ``AdverseEventReport`` models and the raw
    ground-truth dict into ``GroundTruth``. ``AdverseEventReport`` is imported
    lazily from ``.env``, falling back to a top-level ``env`` module.
    """
    try:
        from .env import AdverseEventReport
    except ImportError:
        from env import AdverseEventReport

    entry = TASK_DATA[task_id]
    parsed_reports = [AdverseEventReport(**raw_report) for raw_report in entry["reports"]]
    interaction_context = entry["drug_interaction_db"]
    expected = GroundTruth(**entry["ground_truth"])
    return TaskDefinition(
        task_id=task_id,
        difficulty=difficulty,
        reports=parsed_reports,
        drug_interaction_db=interaction_context,
        ground_truth=expected,
        action_grader=action_grader,
        description=description,
    )
def _build_all_tasks() -> Dict[str, List[TaskDefinition]]:
    """Build and return the complete task pool grouped by difficulty.

    Tiers appear in the order easy, medium, hard. Each task is built in table
    order via ``_task_definition``.
    """
    catalogue = (
        (
            "easy",
            "known_signal_easy",
            "Known ACE-inhibitor cough case that should be logged and monitored rather than escalated.",
            known_signal_easy_action_grader,
        ),
        (
            "medium",
            "cluster_signal_medium",
            "Cluster of bradycardia reports around a newly approved therapy that should be escalated as a signal.",
            cluster_signal_medium_action_grader,
        ),
        (
            "hard",
            "confounded_hard",
            "Confounded transplant case where the blamed drug is wrong and the real problem is a tacrolimus interaction.",
            confounded_hard_action_grader,
        ),
    )
    pool = {}
    for tier, task_id, summary, grader in catalogue:
        pool.setdefault(tier, []).append(_task_definition(task_id, tier, summary, grader))
    return pool
def get_tasks(
    difficulty: Optional[str] = None,
    seed: Optional[int] = None,
    n: int = 5,
    grouped: bool = False,
):
    """
    Return tasks as a flat task-id map, a difficulty-grouped map, or one tier's list.

    Args:
        difficulty: Optional difficulty tier to select from. An unknown tier
            yields an empty list.
        seed: Optional seed for reproducible shuffling within the selected tier.
            Without a seed, tasks keep their definition order. Ignored when
            ``difficulty`` is None.
        n: Maximum number of tasks to keep. When ``difficulty`` is given it caps
            the returned list. Otherwise it caps each tier separately.
        grouped: When True and difficulty is None, return the difficulty-grouped dict.

    Returns:
        If grouped=True and difficulty is None:
            Dict[str, List[TaskDefinition]]
        If difficulty is None:
            Dict[str, TaskDefinition]
        Otherwise:
            List[TaskDefinition]
    """
    tiers = _build_all_tasks()

    if difficulty is not None:
        selected = list(tiers.get(difficulty, []))
        if seed is not None:
            random.Random(seed).shuffle(selected)
        return selected[:n]

    if grouped:
        return {level: members[:n] for level, members in tiers.items()}

    flat = {}
    for members in tiers.values():
        for task in members[:n]:
            flat[task.task_id] = task
    return flat
def get_task(task_id: str) -> TaskDefinition:
    """Return the task with ``task_id`` from the full task pool.

    Looks through every task in every tier. Going through ``get_tasks()`` would
    apply its default ``n=5`` per-tier cap and make any task beyond the fifth in
    a tier impossible to fetch. If the same id occurs more than once, the last
    occurrence wins, as in the flat map returned by ``get_tasks()``.

    Raises:
        KeyError: if no task has the given id.
    """
    catalogue = {
        task.task_id: task
        for tier_tasks in _build_all_tasks().values()
        for task in tier_tasks
    }
    if task_id not in catalogue:
        raise KeyError(f"Unknown task_id: {task_id}")
    return catalogue[task_id]