File size: 3,455 Bytes
2043afa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
"""Pydantic models for the PolypharmacyEnv environment.

Extends OpenEnv base types (Action, Observation, State) and defines
auxiliary records for medications, interactions, and interventions.
"""

from __future__ import annotations

from typing import Any, Dict, List, Literal, Optional

from pydantic import BaseModel, ConfigDict, Field

from openenv.core.env_server.types import (
    Action as OpenEnvAction,
    Observation as OpenEnvObservation,
    State as OpenEnvState,
)


# ── Auxiliary models ─────────────────────────────────────────────────────────

class MedicationEntry(BaseModel):
    """One medication record, as listed in an observation's medications.

    Unknown keys are rejected (``extra="forbid"``). ``drug_id``,
    ``generic_name``, ``atc_class`` and ``dose_mg`` must be supplied; the
    remaining fields fall back to the defaults declared below.
    """

    model_config = ConfigDict(extra="forbid")

    # Required identification and dosing fields.
    drug_id: str
    generic_name: str
    atc_class: str
    dose_mg: float

    # Administration details; the default codes are plain strings and are
    # not validated against any vocabulary by this model.
    frequency: str = Field(default="qd")
    route: str = Field(default="po")

    # Risk annotations. presumably "beers" refers to the Beers Criteria;
    # confirm against the data source.
    is_high_risk_elderly: bool = Field(default=False)
    beers_flags: List[str] = Field(default_factory=list)


class InteractionQueryRecord(BaseModel):
    """Record of a single drug-pair interaction query.

    Unknown keys are rejected (``extra="forbid"``). Only the two drug ids
    are required; the result fields default to ``None`` and ``step_index``
    defaults to ``0``.
    """

    model_config = ConfigDict(extra="forbid")

    # The pair that was queried.
    drug_id_1: str
    drug_id_2: str

    # Result fields; each stays None unless a value is provided.
    severity: Optional[str] = Field(default=None)
    recommendation: Optional[str] = Field(default=None)
    risk_score: Optional[float] = Field(default=None)

    # presumably the episode step at which the query was issued; verify
    # against the environment implementation.
    step_index: int = Field(default=0)


class InterventionRecord(BaseModel):
    """Record of a single intervention applied to one drug.

    Unknown keys are rejected (``extra="forbid"``). ``target_drug_id`` and
    ``action_type`` are required; ``action_type`` must be one of the four
    literal values listed below.
    """

    model_config = ConfigDict(extra="forbid")

    # What was done, and to which drug.
    target_drug_id: str
    action_type: Literal["stop", "dose_reduce", "substitute", "add_monitoring"]

    # Optional replacement drug; presumably only meaningful for
    # "substitute" (this model does not enforce that).
    proposed_new_drug_id: Optional[str] = Field(default=None)

    # Free-text justification; empty string when none is given.
    rationale: str = Field(default="")

    # presumably the episode step at which the intervention was proposed;
    # verify against the environment implementation.
    step_index: int = Field(default=0)


# ── OpenEnv wire models ─────────────────────────────────────────────────────

class PolypharmacyAction(OpenEnvAction):
    """Per-step action submitted by the agent.

    Subclasses ``openenv.core.env_server.types.Action``. ``action_type`` is
    the only required field declared here; every other field defaults to
    ``None``. No validator is declared in this class tying the optional
    fields to a particular ``action_type``.
    """

    action_type: Literal["query_ddi", "propose_intervention", "finish_review"]

    # Drug pair; presumably read for "query_ddi" actions (verify against
    # the environment's step handler).
    drug_id_1: Optional[str] = Field(default=None)
    drug_id_2: Optional[str] = Field(default=None)

    # Intervention details; presumably read for "propose_intervention"
    # actions. Note that "none" is accepted here as an intervention type.
    target_drug_id: Optional[str] = Field(default=None)
    intervention_type: Optional[
        Literal["stop", "dose_reduce", "substitute", "add_monitoring", "none"]
    ] = Field(default=None)
    proposed_new_drug_id: Optional[str] = Field(default=None)
    rationale: Optional[str] = Field(default=None)


class PolypharmacyObservation(OpenEnvObservation):
    """Observation handed back to the agent.

    Subclasses ``openenv.core.env_server.types.Observation``, which
    contributes:
    - done: bool
    - reward: float | None
    - metadata: Dict[str, Any]

    Every field declared here has a default, so an instance can be built
    with no arguments.
    """

    # Episode / task identification.
    episode_id: str = Field(default="")
    task_id: str = Field(default="budgeted_screening")

    # Patient profile.
    age: int = Field(default=65)
    sex: str = Field(default="M")
    conditions: List[str] = Field(default_factory=list)
    eGFR_category: str = Field(default="normal")
    liver_function_category: str = Field(default="normal")

    # Medication list plus the query and intervention records; each list
    # default is a fresh empty list per instance.
    current_medications: List[MedicationEntry] = Field(default_factory=list)
    interaction_queries: List[InteractionQueryRecord] = Field(default_factory=list)
    interventions: List[InterventionRecord] = Field(default_factory=list)

    # Step counter, remaining budgets and shaped reward.
    step_index: int = Field(default=0)
    remaining_query_budget: int = Field(default=0)
    remaining_intervention_budget: int = Field(default=0)
    shaped_reward: float = Field(default=0.0)


class PolypharmacyState(OpenEnvState):
    """Compact state snapshot served by the /state endpoint.

    Subclasses ``openenv.core.env_server.types.State``, which contributes:
    - episode_id: str | None
    - step_count: int

    Every field declared here has a default.
    """

    # Task identifier and its step limit.
    task_id: str = Field(default="")
    max_steps: int = Field(default=0)

    # Counters for the two kinds of agent activity.
    num_query_actions: int = Field(default=0)
    num_interventions: int = Field(default=0)