File size: 16,755 Bytes
2043afa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
"""PolypharmacyEnv – core environment implementing OpenEnv step / reset / state."""

from __future__ import annotations

from copy import deepcopy
from itertools import combinations
from typing import Any, Dict, List, Optional, Tuple

from openenv.core.env_server.interfaces import Environment

from .config import CRITICAL_DRUG_IDS, TaskConfig
from .data_loader import PatientEpisode
from .ddi_simulator import DDISimulator
from .graders import (
    grade_budgeted_screening,
    grade_complex_tradeoff,
    grade_easy_screening,
)
from .models import (
    InteractionQueryRecord,
    InterventionRecord,
    MedicationEntry,
    PolypharmacyAction,
    PolypharmacyObservation,
    PolypharmacyState,
)
from .rewards import compute_regimen_risk, compute_shaped_reward
from .tasks import get_task_config, sample_episode


class PolypharmacyEnv(
    Environment[PolypharmacyAction, PolypharmacyObservation, PolypharmacyState]
):
    """OpenEnv-compliant environment for elderly polypharmacy medication review.

    Extends openenv.core.env_server.interfaces.Environment with typed
    Action/Observation/State generics.
    """

    def __init__(self) -> None:
        """Create an idle environment; call ``reset`` to start an episode.

        All per-episode fields start at neutral values and ``_done`` starts as
        ``True``, so ``step`` only returns a default observation until
        ``reset`` has been called.
        """
        super().__init__()

        # Lookup backend: drug metadata, DDI rules and Beers flags.
        self._sim = DDISimulator()

        # Episode definition, populated by reset().
        self._task_cfg: Optional[TaskConfig] = None
        self._episode: Optional[PatientEpisode] = None

        # Lifecycle.
        self._done: bool = True
        self._step_count: int = 0
        self._last_reward: float = 0.0

        # Current regimen and the action history exposed in observations.
        self._medications: List[MedicationEntry] = []
        self._interaction_queries: List[InteractionQueryRecord] = []
        self._interventions: List[InterventionRecord] = []
        # One entry per recorded intervention: risk before minus risk after.
        self._risk_deltas: List[float] = []

        # Regimen risk at episode start and after the latest intervention.
        self._baseline_risk: float = 0.0
        self._current_risk: float = 0.0

        # Remaining action allowances for the episode.
        self._remaining_query_budget: int = 0
        self._remaining_intervention_budget: int = 0

        # Counters forwarded to the task graders.
        self._severe_moderate_discovered: int = 0
        self._total_drug_changes: int = 0
        self._critical_stopped_without_sub: int = 0

    # ── reset ────────────────────────────────────────────────────────────────

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> PolypharmacyObservation:
        """Start a new episode and return its first observation.

        Args:
            seed: Forwarded to ``sample_episode``.
            episode_id: Forwarded to ``sample_episode``.
            **kwargs: Only ``task_id`` is read; it selects the task
                configuration and is also forwarded to ``sample_episode``.

        Returns:
            The observation for step 0 of the freshly sampled episode.
        """
        task_id = kwargs.get("task_id")
        self._task_cfg = get_task_config(task_id)
        self._episode = sample_episode(task_id, seed=seed, episode_id=episode_id)

        # Rebuild the regimen; drugs the simulator has no metadata for are
        # left out of the medication list.
        self._medications = []
        for drug_id in self._episode.medication_ids:
            drug_meta = self._sim.get_drug_meta(drug_id)
            if drug_meta is not None:
                entry_flags = self._sim.get_beers_flags(
                    drug_id, self._episode.conditions
                )
                self._medications.append(
                    MedicationEntry(
                        drug_id=drug_id,
                        generic_name=drug_meta.generic_name,
                        atc_class=drug_meta.atc_class,
                        dose_mg=drug_meta.default_dose_mg,
                        is_high_risk_elderly=drug_meta.is_high_risk_elderly,
                        beers_flags=entry_flags,
                    )
                )

        # Fresh history for the new episode.
        self._interaction_queries = []
        self._interventions = []
        self._risk_deltas = []

        # Zero the counters and re-arm the lifecycle flags.
        self._step_count = self._severe_moderate_discovered = 0
        self._total_drug_changes = self._critical_stopped_without_sub = 0
        self._last_reward = 0.0
        self._done = False

        # Budgets come from the selected task configuration.
        self._remaining_query_budget = self._task_cfg.query_budget
        self._remaining_intervention_budget = self._task_cfg.intervention_budget

        # The untouched regimen defines the baseline later risks are compared to.
        self._baseline_risk = self._compute_risk()
        self._current_risk = self._baseline_risk

        return self._make_observation()

    # ── step ─────────────────────────────────────────────────────────────────

    def step(
        self,
        action: PolypharmacyAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> PolypharmacyObservation:
        """Apply one agent action and return the resulting observation.

        Once the episode is done the action is ignored and a default
        observation (zero reward, empty metadata) is returned. Otherwise the
        action is validated and dispatched on ``action.action_type``; an
        action type that matches none of the handled values falls through
        with a reward of 0.0. Every processed action, valid or not, counts as
        one step.

        Args:
            action: The action to apply.
            timeout_s: Accepted but not used by this method.
            **kwargs: Accepted but not used by this method.

        Returns:
            The observation after the action, carrying reward and info.
        """
        if self._done:
            return self._make_observation()

        assert self._task_cfg is not None
        assert self._episode is not None

        reward = 0.0
        info: Dict[str, Any] = {}
        is_valid, problem = self._validate_action(action)

        if not is_valid:
            # Malformed action: risk is unchanged, only the invalid-action
            # shaping applies.
            reward = compute_shaped_reward(
                self._current_risk, self._current_risk,
                action.action_type, is_invalid=True,
            )
            info["error"] = problem
        elif action.action_type == "query_ddi":
            reward, info = self._handle_query(action)
        elif action.action_type == "propose_intervention":
            reward, info = self._handle_intervention(action)
        elif action.action_type == "finish_review":
            # Voluntary termination: the grader score is the terminal reward.
            self._done = True
            score = self._run_grader()
            reward = score
            info["grader_score"] = score

        self._step_count += 1
        return self._check_timeout_and_build_obs(reward, info)

    # ── state property ───────────────────────────────────────────────────────

    @property
    def state(self) -> PolypharmacyState:
        """Summarise episode progress as a ``PolypharmacyState``.

        Falls back to ``None`` for the episode id, ``""`` for the task id and
        ``0`` for the step limit when no episode or task is loaded.
        """
        episode = self._episode
        cfg = self._task_cfg

        current_episode_id = episode.episode_id if episode else None
        if cfg:
            current_task_id = cfg.task_id
            step_limit = cfg.max_steps
        else:
            current_task_id = ""
            step_limit = 0

        return PolypharmacyState(
            episode_id=current_episode_id,
            step_count=self._step_count,
            task_id=current_task_id,
            max_steps=step_limit,
            num_query_actions=len(self._interaction_queries),
            num_interventions=len(self._interventions),
        )

    # ── Internal helpers ─────────────────────────────────────────────────────

    def _compute_risk(self) -> float:
        """Return the regimen risk for the medications currently held.

        Delegates to ``compute_regimen_risk`` with the current drug ids, the
        patient's conditions (empty when no episode is loaded) and the
        simulator's rule tables.
        """
        regimen_ids = [entry.drug_id for entry in self._medications]
        patient_conditions = self._episode.conditions if self._episode else []
        sim = self._sim
        return compute_regimen_risk(
            regimen_ids,
            patient_conditions,
            sim.ddi_rules,
            sim.beers_criteria,
            sim.drug_metadata,
        )

    def _validate_action(self, action: PolypharmacyAction) -> Tuple[bool, str]:
        """Check that an action carries the fields its type requires.

        Only ``query_ddi`` and ``propose_intervention`` have required fields;
        any other action type is accepted as-is.

        Returns:
            ``(True, "")`` when the action is well-formed, otherwise
            ``(False, message)`` describing the first missing field.
        """
        kind = action.action_type

        if kind == "query_ddi" and not (action.drug_id_1 and action.drug_id_2):
            return False, "query_ddi requires drug_id_1 and drug_id_2"

        if kind == "propose_intervention":
            if not action.target_drug_id:
                return False, "propose_intervention requires target_drug_id"
            if action.intervention_type in (None, "none"):
                return False, "propose_intervention requires a valid intervention_type"

        return True, ""

    def _handle_query(self, action: PolypharmacyAction) -> Tuple[float, Dict[str, Any]]:
        """Look up the interaction between two drugs and log the query.

        With no query budget left, nothing is looked up or logged and the
        invalid-action reward is returned with an ``"error"`` entry. Otherwise
        the lookup costs one unit of budget and is appended to the query
        history. Every lookup is charged and logged, including a repeat of a
        pair queried earlier.

        Returns:
            ``(reward, info)``; on success ``info["ddi_result"]`` holds the
            severity, recommendation and risk score of the pair.
        """
        assert action.drug_id_1 and action.drug_id_2
        info: Dict[str, Any] = {}

        if self._remaining_query_budget <= 0:
            penalty = compute_shaped_reward(
                self._current_risk, self._current_risk,
                "query_ddi", is_invalid=True,
            )
            info["error"] = "Query budget exhausted"
            return penalty, info

        first_id, second_id = action.drug_id_1, action.drug_id_2
        ddi = self._sim.lookup_ddi(first_id, second_id)
        self._remaining_query_budget -= 1

        severity = ddi.severity
        recommendation = ddi.recommendation
        risk_score = ddi.base_risk_score

        self._interaction_queries.append(
            InteractionQueryRecord(
                drug_id_1=first_id,
                drug_id_2=second_id,
                severity=severity,
                recommendation=recommendation,
                risk_score=risk_score,
                step_index=self._step_count,
            )
        )

        # The grader-facing counter tracks severe *and* moderate findings,
        # while the shaped reward below is only told about severe ones.
        if severity in ("severe", "moderate"):
            self._severe_moderate_discovered += 1

        reward = compute_shaped_reward(
            self._current_risk, self._current_risk,
            "query_ddi",
            discovered_severe=(severity == "severe"),
        )
        info["ddi_result"] = {
            "severity": severity,
            "recommendation": recommendation,
            "risk_score": risk_score,
        }
        return reward, info

    def _handle_intervention(self, action: PolypharmacyAction) -> Tuple[float, Dict[str, Any]]:
        """Apply a proposed intervention to the regimen and score it.

        Rejected outright (invalid-action reward, ``info["error"]``, no state
        change) when the intervention budget is exhausted or the target drug
        is not in the current regimen.

        Otherwise the intervention type decides the regimen change:

        * ``stop`` removes the drug, and counts it when the drug is listed in
          ``CRITICAL_DRUG_IDS``.
        * ``dose_reduce`` halves the dose, bounded below by the drug's minimum
          dose; nothing changes if the drug has no metadata.
        * ``substitute`` swaps in the proposed drug, or one chosen by the
          simulator when none is proposed. If no usable substitute exists the
          regimen is unchanged, ``info["warning"]`` is set and no budget is
          charged.
        * ``add_monitoring`` appends a ``"monitored"`` flag to the entry.

        The risk is then recomputed and the intervention is appended to the
        history together with its risk delta.

        NOTE(review): a failed substitution is still recorded as an
        intervention with its risk delta, and the record stores the caller's
        ``proposed_new_drug_id`` rather than an auto-selected substitute --
        confirm both are intended.

        Returns:
            ``(reward, info)``; on an applied (or attempted) intervention
            ``info["risk_delta"]`` is previous risk minus new risk.
        """
        assert action.target_drug_id
        assert action.intervention_type and action.intervention_type != "none"

        info: Dict[str, Any] = {}
        kind = action.intervention_type

        def rejected(message: str) -> Tuple[float, Dict[str, Any]]:
            # Invalid proposal: regimen, budget and history stay untouched.
            penalty = compute_shaped_reward(
                self._current_risk, self._current_risk,
                "propose_intervention", is_invalid=True,
            )
            info["error"] = message
            return penalty, info

        if self._remaining_intervention_budget <= 0:
            return rejected("Intervention budget exhausted")

        # Position of the first regimen entry matching the target drug.
        slot = next(
            (
                idx
                for idx, med in enumerate(self._medications)
                if med.drug_id == action.target_drug_id
            ),
            None,
        )
        if slot is None:
            return rejected(f"Drug {action.target_drug_id} not in current medications")

        risk_before = self._current_risk
        current_entry = self._medications[slot]
        charge_budget = True  # cleared when a substitution cannot be carried out

        if kind == "stop":
            del self._medications[slot]
            self._total_drug_changes += 1
            if action.target_drug_id in CRITICAL_DRUG_IDS:
                self._critical_stopped_without_sub += 1

        elif kind == "dose_reduce":
            meta = self._sim.get_drug_meta(action.target_drug_id)
            if meta:
                halved = current_entry.dose_mg * 0.5
                self._medications[slot] = current_entry.model_copy(
                    update={"dose_mg": max(meta.min_dose_mg, halved)}
                )

        elif kind == "substitute":
            new_drug_id = action.proposed_new_drug_id
            if not new_drug_id:
                # No explicit proposal: let the simulator pick a substitute
                # given the drugs already in the regimen.
                regimen_ids = [med.drug_id for med in self._medications]
                new_drug_id = self._sim.find_substitute(
                    action.target_drug_id, regimen_ids
                )

            if not new_drug_id:
                info["warning"] = "No suitable substitute found"
                charge_budget = False
            else:
                new_meta = self._sim.get_drug_meta(new_drug_id)
                if not new_meta:
                    info["warning"] = f"Substitute {new_drug_id} not found in metadata"
                    charge_budget = False
                else:
                    new_flags = self._sim.get_beers_flags(
                        new_drug_id,
                        self._episode.conditions if self._episode else [],
                    )
                    # Replacing a critical drug is acceptable, so the
                    # critical-stop counter is not touched on this path.
                    self._medications[slot] = MedicationEntry(
                        drug_id=new_drug_id,
                        generic_name=new_meta.generic_name,
                        atc_class=new_meta.atc_class,
                        dose_mg=new_meta.default_dose_mg,
                        is_high_risk_elderly=new_meta.is_high_risk_elderly,
                        beers_flags=new_flags,
                    )
                    self._total_drug_changes += 1

        elif kind == "add_monitoring":
            # Only tags the entry; the set of drugs is unchanged.
            self._medications[slot] = current_entry.model_copy(
                update={"beers_flags": current_entry.beers_flags + ["monitored"]}
            )

        if charge_budget:
            self._remaining_intervention_budget -= 1

        self._current_risk = self._compute_risk()
        risk_delta = risk_before - self._current_risk
        self._risk_deltas.append(risk_delta)

        self._interventions.append(
            InterventionRecord(
                target_drug_id=action.target_drug_id,
                action_type=kind,
                proposed_new_drug_id=action.proposed_new_drug_id,
                rationale=action.rationale or "",
                step_index=self._step_count,
            )
        )

        info["risk_delta"] = risk_delta
        reward = compute_shaped_reward(
            risk_before, self._current_risk, "propose_intervention"
        )
        return reward, info

    def _run_grader(self) -> float:
        """Score the episode with the grader matching the active task.

        Returns:
            The grader's score, or ``0.0`` when the task id is not one of
            ``easy_screening``, ``budgeted_screening`` or ``complex_tradeoff``.
        """
        assert self._task_cfg is not None
        task_name = self._task_cfg.task_id

        if task_name == "easy_screening":
            # This grader additionally needs the severe pairs of the initial regimen.
            initial_severe_pairs = self._get_severe_pairs()
            return grade_easy_screening(
                self._baseline_risk,
                self._current_risk,
                self._interventions,
                initial_severe_pairs,
            )

        if task_name == "budgeted_screening":
            return grade_budgeted_screening(
                self._baseline_risk,
                self._current_risk,
                self._interventions,
                self._risk_deltas,
                len(self._interaction_queries),
                self._severe_moderate_discovered,
            )

        if task_name == "complex_tradeoff":
            return grade_complex_tradeoff(
                self._baseline_risk,
                self._current_risk,
                self._interventions,
                self._total_drug_changes,
                self._critical_stopped_without_sub,
            )

        return 0.0

    def _get_severe_pairs(self) -> List[Tuple[str, str]]:
        """Return all severe DDI pairs present in the *initial* medication list.

        Pairs are taken from the episode's own medication ids, not from the
        current regimen, and are looked up in the simulator's rule table.
        Returns an empty list when no episode is loaded.
        """
        if not self._episode:
            return []

        # combinations() over a sorted, de-duplicated list already yields each
        # pair exactly once with the smaller id first, i.e. in the same
        # (low, high) form used as the rule-table key.
        unique_ids = sorted(set(self._episode.medication_ids))
        looked_up = (
            (pair, self._sim.ddi_rules.get(pair))
            for pair in combinations(unique_ids, 2)
        )
        return [
            pair
            for pair, rule in looked_up
            if rule and rule.severity == "severe"
        ]

    def _check_timeout_and_build_obs(
        self, reward: float, info: Dict[str, Any]
    ) -> PolypharmacyObservation:
        """Enforce the step limit, then package reward and info as an observation.

        If the episode is still running and the step count has reached the
        task's ``max_steps``, the episode is ended: the timeout shaping and
        the grader score are added to ``reward`` and ``info`` gains
        ``"timeout"`` and ``"grader_score"``. In every case ``info`` also
        receives the current and baseline risk.

        Args:
            reward: Reward earned by the action just processed.
            info: Metadata dict for the step; it is updated in place.

        Returns:
            The observation for the step.
        """
        assert self._task_cfg is not None

        timed_out = (
            not self._done and self._step_count >= self._task_cfg.max_steps
        )
        if timed_out:
            self._done = True
            timeout_penalty = compute_shaped_reward(
                self._current_risk, self._current_risk,
                "finish_review", is_timeout=True,
            )
            score = self._run_grader()
            reward = reward + (timeout_penalty + score)
            info.update(timeout=True, grader_score=score)

        self._last_reward = reward
        info.update(
            current_risk=self._current_risk,
            baseline_risk=self._baseline_risk,
        )

        return self._make_observation(reward=reward, info=info)

    def _make_observation(
        self, reward: float = 0.0, info: Optional[Dict[str, Any]] = None,
    ) -> PolypharmacyObservation:
        """Build an observation from the current environment state.

        Patient fields come from the loaded episode, or fixed placeholder
        values when none is loaded. The medication, query and intervention
        lists are passed as deep copies.

        Args:
            reward: Value reported as both ``shaped_reward`` and ``reward``.
            info: Step metadata; an empty dict is used when omitted or empty.
        """
        episode = self._episode
        if episode:
            patient: Dict[str, Any] = {
                "episode_id": episode.episode_id,
                "age": episode.age,
                "sex": episode.sex,
                "conditions": episode.conditions,
                "eGFR_category": episode.eGFR_category,
                "liver_function_category": episode.liver_function_category,
            }
        else:
            # Placeholders used before the first reset().
            patient = {
                "episode_id": "",
                "age": 65,
                "sex": "M",
                "conditions": [],
                "eGFR_category": "normal",
                "liver_function_category": "normal",
            }

        active_task_id = (
            self._task_cfg.task_id if self._task_cfg else "budgeted_screening"
        )

        return PolypharmacyObservation(
            task_id=active_task_id,
            current_medications=deepcopy(self._medications),
            interaction_queries=deepcopy(self._interaction_queries),
            interventions=deepcopy(self._interventions),
            step_index=self._step_count,
            remaining_query_budget=self._remaining_query_budget,
            remaining_intervention_budget=self._remaining_intervention_budget,
            shaped_reward=reward,
            done=self._done,
            reward=reward,
            metadata=info or {},
            **patient,
        )