File size: 3,692 Bytes
f2beac3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
from typing import Dict, List, Optional, Tuple

from pydantic import BaseModel, Field

from tasks import TaskDefinition, get_task


class AdverseEventReport(BaseModel):
    """One adverse-event case report supplied to the agent as task input.

    Field semantics are inferred from names and usage elsewhere in this
    module; schema is ultimately defined by the tasks module.
    """

    report_id: str  # unique identifier for this report
    patient_age: int
    patient_sex: str
    drugs: List[str]  # all drugs the patient was taking (presumably; confirm against tasks module)
    suspect_drug: str  # drug suspected of causing the reaction
    reaction: str  # observed adverse reaction
    onset_days: int  # NOTE(review): presumably days from drug start to reaction onset — confirm
    severity: str
    outcome: str
    similar_reports_last_30d: int  # count of similar reports seen in the trailing 30 days


class Observation(BaseModel):
    """What the environment returns to the agent on reset() and step()."""

    task_id: str  # identifier of the currently loaded task
    reports: List[AdverseEventReport]  # case reports for the current task
    drug_interaction_db: dict  # task-provided lookup; schema defined by the tasks module
    step_number: int  # steps taken so far (0 right after reset)
    max_steps: int  # episode length limit (this env uses 1)
    feedback: Optional[str] = None  # textual feedback message, if any


class Action(BaseModel):
    """The agent's single final assessment submitted to PharmaVigilanceEnv.step().

    The first four fields are individually graded (see the reward breakdown
    consumed in step()); ``reasoning`` is free text.
    """

    classification: str
    suspect_drug: str
    severity_assessment: str
    recommended_action: str
    reasoning: str


class Reward(BaseModel):
    """Graded score for an Action, produced by the task's grader."""

    total: float = Field(..., ge=0.0, le=1.0)  # overall score, constrained to [0, 1]
    breakdown: dict  # per-component scores; step() reads keys matching Action field names


class PharmaVigilanceEnv:
    """Single-step pharmacovigilance assessment environment.

    The agent loads one task via :meth:`reset`, submits exactly one
    :class:`Action` via :meth:`step`, and the task's own grader produces
    the :class:`Reward`. Every episode terminates after the first step.
    """

    # Action fields whose per-field reward components are counted in
    # the "matched_fields" info entry.
    _GRADED_FIELDS = (
        "classification",
        "suspect_drug",
        "severity_assessment",
        "recommended_action",
    )

    def __init__(self):
        # No task loaded until reset() is called.
        self.current_task: Optional[TaskDefinition] = None
        self.current_task_id: Optional[str] = None
        self.step_number = 0
        self.max_steps = 1  # single-decision environment
        self.last_action: Optional[dict] = None
        self.last_reward: Optional[dict] = None

    def _observe(self, feedback: str) -> Observation:
        """Build an Observation snapshot of the currently loaded task."""
        task = self.current_task
        return Observation(
            task_id=task.task_id,
            reports=task.reports,
            drug_interaction_db=task.drug_interaction_db,
            step_number=self.step_number,
            max_steps=self.max_steps,
            feedback=feedback,
        )

    def reset(self, task_id: str = "known_signal_easy") -> Observation:
        """Load the task named by *task_id* and return the initial observation."""
        task = get_task(task_id)
        self.current_task = task
        self.current_task_id = task.task_id
        self.step_number = 0
        self.last_action = None
        self.last_reward = None
        return self._observe(
            "Task loaded. Submit one final pharmacovigilance assessment."
        )

    def step(self, action: Action) -> Tuple[Observation, Reward, bool, Dict]:
        """Grade *action* against the current task; always terminal.

        Raises:
            RuntimeError: if called before :meth:`reset`.
        """
        task = self.current_task
        if task is None:
            raise RuntimeError("Call reset() before step().")

        reward = task.action_grader(action)
        self.step_number += 1
        self.last_action = action.model_dump()
        self.last_reward = reward.model_dump()

        # Count how many graded Action fields earned a positive component.
        matched = sum(
            reward.breakdown.get(name, 0.0) > 0 for name in self._GRADED_FIELDS
        )

        score = reward.total
        if score >= 0.9:
            feedback = "Strong assessment. The key safety judgment and follow-up were correct."
        elif score >= 0.5:
            feedback = "Partially correct assessment. Some causal or operational details were missed."
        else:
            feedback = "Weak assessment. This case would need human analyst correction."

        info = {
            "matched_fields": matched,
            "difficulty": task.difficulty,
            "reward_breakdown": reward.breakdown,
        }
        # done is always True: one assessment per episode.
        return self._observe(feedback), reward, True, info

    def state(self) -> dict:
        """Return a serializable snapshot of the environment's bookkeeping."""
        return {
            "task_id": self.current_task_id,
            "step_number": self.step_number,
            "last_action": self.last_action,
            "last_reward": self.last_reward,
        }