Spaces:
Sleeping
Sleeping
| import random | |
| from typing import Any, Callable, Dict, List, Optional | |
| from pydantic import BaseModel, ConfigDict, Field | |
| from data import TASK_DATA | |
class GroundTruth(BaseModel):
    """Expected correct answer for one pharmacovigilance task.

    The per-task action graders compare an analyst action's fields against
    these values to build the score breakdown.
    """

    # Classification label compared exactly; graders branch on values such as
    # "new_signal" and "noise" (see _base_breakdown penalties).
    classification: str
    # Implicated drug name; matched case-insensitively with substring fallback.
    suspect_drug: str
    # Severity label, e.g. "mild" / "moderate" / "severe" / "critical"
    # (the set used by _is_consistent_pipeline).
    severity_assessment: str
    # Expected disposition, e.g. "escalate", "dismiss", "log_and_monitor",
    # "request_more_info" (values referenced by _is_consistent_pipeline).
    recommended_action: str
class TaskDefinition(BaseModel):
    """A single graded pharmacovigilance exercise: data, context, and grader."""

    # arbitrary_types_allowed: reports and action_grader hold non-pydantic
    # objects; revalidate_instances="never" skips re-validation of instances.
    model_config = ConfigDict(arbitrary_types_allowed=True, revalidate_instances="never")
    task_id: str = Field(..., description="Unique pharmacovigilance task identifier")
    difficulty: str = Field(..., description="easy | medium | hard")
    reports: List[Any] = Field(default_factory=list, description="Synthetic adverse event reports")
    drug_interaction_db: dict = Field(default_factory=dict, description="Hardcoded interaction and safety context")
    # The expected correct analysis for this task.
    ground_truth: GroundTruth
    # Callable that scores an analyst action; returns a Reward-like object
    # (see the *_action_grader functions in this module).
    action_grader: Callable[[Any], Any]
    description: str = Field(default="", description="Human-readable task summary")

    def id(self) -> str:
        """Return the unique task identifier."""
        return self.task_id
| def _base_breakdown(action: Any, ground_truth: GroundTruth) -> dict: | |
| action_suspect = action.suspect_drug.strip().lower() | |
| truth_suspect = ground_truth.suspect_drug.strip().lower() | |
| suspect_match = ( | |
| action_suspect == truth_suspect | |
| or action_suspect in truth_suspect | |
| or truth_suspect in action_suspect | |
| ) | |
| breakdown = { | |
| "classification": 0.25 if action.classification == ground_truth.classification else 0.0, | |
| "suspect_drug": 0.25 if suspect_match else 0.0, | |
| "severity_assessment": 0.20 if action.severity_assessment == ground_truth.severity_assessment else 0.0, | |
| "recommended_action": 0.15 if action.recommended_action == ground_truth.recommended_action else 0.0, | |
| "consistency_bonus": 0.0, | |
| "confidence_adjustment": 0.0, | |
| "false_alarm_penalty": 0.0, | |
| "missed_signal_penalty": 0.0, | |
| "reasoning_bonus": 0.0, | |
| } | |
| if action.classification == "new_signal" and ground_truth.classification == "noise": | |
| breakdown["false_alarm_penalty"] = -0.10 | |
| if action.classification == "noise" and ground_truth.classification == "new_signal": | |
| breakdown["missed_signal_penalty"] = -0.20 | |
| return breakdown | |
def _is_consistent_pipeline(action: Any, ground_truth: GroundTruth, report_count: int) -> bool:
    """Check whether the action's fields form an internally coherent decision.

    NOTE: ground_truth is accepted for a uniform helper signature; coherence
    is judged from the action itself (plus report_count) only.
    """
    label = action.classification
    severity = action.severity_assessment
    disposition = action.recommended_action
    has_suspect = bool(action.suspect_drug.strip())

    if label == "known_side_effect":
        return (
            has_suspect
            and disposition in {"log_and_monitor", "request_more_info"}
            and severity in {"mild", "moderate", "severe"}
        )
    if label == "new_signal":
        escalated = disposition == "escalate"
        strong_evidence = severity in {"severe", "critical"} or report_count > 1
        return has_suspect and escalated and strong_evidence
    if label == "noise":
        return disposition in {"dismiss", "request_more_info"} and severity in {"mild", "moderate"}
    if label == "duplicate":
        return disposition in {"dismiss", "log_and_monitor"}
    return False
| def _confidence_adjustment(action: Any, correctness_score: float) -> float: | |
| confidence = getattr(action, "confidence", None) | |
| if confidence is None: | |
| return 0.0 | |
| if confidence >= 80 and correctness_score >= 0.75: | |
| return 0.05 | |
| if confidence >= 80 and correctness_score < 0.40: | |
| return -0.10 | |
| if confidence <= 40 and correctness_score >= 0.75: | |
| return -0.03 | |
| return 0.0 | |
def _reward_from_breakdown(breakdown: dict):
    """Sum a score breakdown into a Reward whose total is clamped to [-0.25, 1.0]."""
    # Import lazily so this module works both as a package member and a script.
    try:
        from .env import Reward
    except ImportError:
        from env import Reward
    total = round(sum(breakdown.values()), 4)
    clamped = min(1.0, max(-0.25, total))
    return Reward(total=clamped, breakdown=breakdown)
def known_signal_easy_action_grader(action: Any):
    """Grade an analyst action on the easy known-signal task.

    Combines the base field-match breakdown with a consistency bonus and a
    confidence-calibration adjustment, then wraps the result in a Reward.
    """
    task_data = TASK_DATA["known_signal_easy"]
    truth = GroundTruth(**task_data["ground_truth"])
    breakdown = _base_breakdown(action, truth)
    core_fields = ("classification", "suspect_drug", "severity_assessment", "recommended_action")
    correctness_score = sum(breakdown[field] for field in core_fields)
    if _is_consistent_pipeline(action, truth, len(task_data["reports"])):
        breakdown["consistency_bonus"] = 0.10
    breakdown["confidence_adjustment"] = _confidence_adjustment(action, correctness_score)
    return _reward_from_breakdown(breakdown)
def cluster_signal_medium_action_grader(action: Any):
    """Grade an analyst action on the medium bradycardia-cluster task.

    Same scoring recipe as the easy task: base field breakdown, consistency
    bonus, and confidence-calibration adjustment, wrapped in a Reward.
    """
    task_data = TASK_DATA["cluster_signal_medium"]
    truth = GroundTruth(**task_data["ground_truth"])
    breakdown = _base_breakdown(action, truth)
    core_fields = ("classification", "suspect_drug", "severity_assessment", "recommended_action")
    correctness_score = sum(breakdown[field] for field in core_fields)
    if _is_consistent_pipeline(action, truth, len(task_data["reports"])):
        breakdown["consistency_bonus"] = 0.10
    breakdown["confidence_adjustment"] = _confidence_adjustment(action, correctness_score)
    return _reward_from_breakdown(breakdown)
def confounded_hard_action_grader(action: Any):
    """Grade an analyst action on the hard confounded-interaction task.

    In addition to the shared breakdown, consistency bonus, and confidence
    adjustment, awards a small reasoning bonus when the free-text reasoning
    mentions the underlying tacrolimus/azole interaction.
    """
    task_data = TASK_DATA["confounded_hard"]
    truth = GroundTruth(**task_data["ground_truth"])
    breakdown = _base_breakdown(action, truth)
    correctness_score = (
        breakdown["classification"]
        + breakdown["suspect_drug"]
        + breakdown["severity_assessment"]
        + breakdown["recommended_action"]
    )
    if _is_consistent_pipeline(action, truth, len(task_data["reports"])):
        breakdown["consistency_bonus"] = 0.10
    breakdown["confidence_adjustment"] = _confidence_adjustment(action, correctness_score)
    # Fix: use getattr so a missing or None reasoning attribute does not crash
    # grading with AttributeError (mirrors the getattr pattern used for
    # action.confidence in _confidence_adjustment).
    reasoning = (getattr(action, "reasoning", "") or "").lower()
    interaction_terms = (
        "drug interaction",
        "tacrolimus",
        "voriconazole",
        "azole",
        "calcineurin",
        "level monitoring",
    )
    if any(term in reasoning for term in interaction_terms):
        breakdown["reasoning_bonus"] = 0.05
    return _reward_from_breakdown(breakdown)
| def _grader_score_from_trajectory(trajectory: Any = None) -> float: | |
| trajectory = trajectory or {} | |
| raw_score = 0.5 | |
| if isinstance(trajectory, dict): | |
| if "score" in trajectory: | |
| raw_score = float(trajectory["score"]) | |
| elif "rewards" in trajectory and trajectory["rewards"]: | |
| raw_score = float(trajectory["rewards"][-1]) | |
| elif "reward" in trajectory: | |
| reward_val = trajectory["reward"] | |
| if isinstance(reward_val, dict) and "total" in reward_val: | |
| raw_score = float(reward_val["total"]) | |
| else: | |
| raw_score = float(reward_val) | |
| return max(0.01, min(0.99, round(raw_score, 4))) | |
def known_signal_easy_grader(trajectory: Any = None) -> float:
    """Delegate trajectory scoring to the canonical server-side grader."""
    import server.graders
    return server.graders.known_signal_easy_grader(trajectory)
def cluster_signal_medium_grader(trajectory: Any = None) -> float:
    """Delegate trajectory scoring to the canonical server-side grader."""
    import server.graders
    return server.graders.cluster_signal_medium_grader(trajectory)
def confounded_hard_grader(trajectory: Any = None) -> float:
    """Delegate trajectory scoring to the canonical server-side grader."""
    import server.graders
    return server.graders.confounded_hard_grader(trajectory)
def _task_definition(
    task_id: str,
    difficulty: str,
    description: str,
    action_grader: Callable[[Any], Any],
) -> TaskDefinition:
    """Assemble a TaskDefinition from the hardcoded TASK_DATA entry for task_id."""
    # Import lazily so this module works both as a package member and a script.
    try:
        from .env import AdverseEventReport
    except ImportError:
        from env import AdverseEventReport
    data = TASK_DATA[task_id]
    reports = [AdverseEventReport(**raw) for raw in data["reports"]]
    return TaskDefinition(
        task_id=task_id,
        difficulty=difficulty,
        reports=reports,
        drug_interaction_db=data["drug_interaction_db"],
        ground_truth=GroundTruth(**data["ground_truth"]),
        action_grader=action_grader,
        description=description,
    )
def _build_all_tasks() -> Dict[str, List[TaskDefinition]]:
    """Build and return the complete task pool grouped by difficulty."""
    # (task_id, description, action_grader) per difficulty tier.
    specs = {
        "easy": [
            (
                "known_signal_easy",
                "Known ACE-inhibitor cough case that should be logged and monitored rather than escalated.",
                known_signal_easy_action_grader,
            ),
        ],
        "medium": [
            (
                "cluster_signal_medium",
                "Cluster of bradycardia reports around a newly approved therapy that should be escalated as a signal.",
                cluster_signal_medium_action_grader,
            ),
        ],
        "hard": [
            (
                "confounded_hard",
                "Confounded transplant case where the blamed drug is wrong and the real problem is a tacrolimus interaction.",
                confounded_hard_action_grader,
            ),
        ],
    }
    return {
        level: [
            _task_definition(task_id, level, description, grader)
            for task_id, description, grader in entries
        ]
        for level, entries in specs.items()
    }
def get_tasks(
    difficulty: Optional[str] = None,
    seed: Optional[int] = None,
    n: int = 5,
    grouped: bool = False,
):
    """
    Return tasks either as a flat task-id map or a difficulty-filtered list.
    Args:
        difficulty: Optional difficulty tier to select from.
        seed: Optional seed for reproducible shuffling within a difficulty pool.
        n: Maximum number of tasks to return when selecting by difficulty.
        grouped: When True and difficulty is None, return the difficulty-grouped dict.
    Returns:
        If grouped=True and difficulty is None:
            Dict[str, List[TaskDefinition]]
        If difficulty is None:
            Dict[str, TaskDefinition]
        Otherwise:
            List[TaskDefinition]
    """
    all_tasks = _build_all_tasks()
    # Difficulty-filtered path: optionally shuffle (reproducibly) then truncate.
    if difficulty is not None:
        selected = list(all_tasks.get(difficulty, []))
        if seed is not None:
            random.Random(seed).shuffle(selected)
        return selected[:n]
    # Grouped view keeps the difficulty keys, truncating each pool to n.
    if grouped:
        return {level: tasks[:n] for level, tasks in all_tasks.items()}
    # Flat view: task_id -> task across all tiers, at most n per tier.
    flat = {}
    for tasks in all_tasks.values():
        for task in tasks[:n]:
            flat[task.task_id] = task
    return flat
def get_task(task_id: str) -> TaskDefinition:
    """Return the task with the given id, raising KeyError if unknown."""
    catalog = get_tasks()
    if task_id in catalog:
        return catalog[task_id]
    raise KeyError(f"Unknown task_id: {task_id}")