# Author: adithya9903
# Commit message: Flatten project to root for OpenEnv submission readiness.
# Commit: fa51dd9
"""PolypharmacyEnv – core environment implementing OpenEnv step / reset / state."""
from __future__ import annotations
from copy import deepcopy
from itertools import combinations
from typing import Any, Dict, List, Optional, Tuple
from openenv.core.env_server.interfaces import Environment
from .config import CRITICAL_DRUG_IDS, TaskConfig
from .data_loader import PatientEpisode
from .ddi_simulator import DDISimulator
from .graders import (
grade_budgeted_screening,
grade_complex_tradeoff,
grade_easy_screening,
)
from .models import (
InteractionQueryRecord,
InterventionRecord,
MedicationEntry,
PolypharmacyAction,
PolypharmacyObservation,
PolypharmacyState,
)
from .rewards import compute_regimen_risk, compute_shaped_reward
from .tasks import get_task_config, sample_episode
class PolypharmacyEnv(
    Environment[PolypharmacyAction, PolypharmacyObservation, PolypharmacyState]
):
    """OpenEnv-compliant environment for elderly polypharmacy medication review.

    Extends openenv.core.env_server.interfaces.Environment with typed
    Action/Observation/State generics.

    Episode flow: ``reset()`` samples a patient episode and builds the initial
    regimen. ``step()`` then accepts ``query_ddi``, ``propose_intervention``
    or ``finish_review`` actions. The episode ends on ``finish_review`` or
    when the task's ``max_steps`` is reached. In both cases the task grader's
    score is added to the reward of the final step.
    """

    def __init__(self) -> None:
        super().__init__()
        self._sim = DDISimulator()
        self._task_cfg: Optional[TaskConfig] = None
        self._episode: Optional[PatientEpisode] = None
        self._medications: List[MedicationEntry] = []
        self._interaction_queries: List[InteractionQueryRecord] = []
        self._interventions: List[InterventionRecord] = []
        self._risk_deltas: List[float] = []  # per-intervention risk improvement
        # Canonical (sorted) drug pairs already found to be severe/moderate.
        # Re-querying one of these pairs is not counted or rewarded again.
        self._discovered_pairs: set[Tuple[str, str]] = set()
        self._step_count: int = 0
        self._done: bool = True  # no active episode until reset() is called
        self._baseline_risk: float = 0.0
        self._current_risk: float = 0.0
        self._remaining_query_budget: int = 0
        self._remaining_intervention_budget: int = 0
        self._severe_moderate_discovered: int = 0
        self._total_drug_changes: int = 0
        self._critical_stopped_without_sub: int = 0
        self._last_reward: float = 0.0

    # ── reset ────────────────────────────────────────────────────────────────

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> PolypharmacyObservation:
        """Start a new episode and return its initial observation.

        Args:
            seed: Forwarded to ``sample_episode``.
            episode_id: Forwarded to ``sample_episode``.
            **kwargs: An optional ``task_id`` is passed to both
                ``get_task_config`` and ``sample_episode``. Other keys are
                ignored.

        Returns:
            The initial observation (reward 0.0, ``done=False``).
        """
        task_id = kwargs.get("task_id", None)
        self._task_cfg = get_task_config(task_id)
        self._episode = sample_episode(task_id, seed=seed, episode_id=episode_id)

        # Build the medication list. Drugs the simulator has no metadata for
        # are skipped.
        self._medications = []
        for did in self._episode.medication_ids:
            meta = self._sim.get_drug_meta(did)
            if meta is None:
                continue
            flags = self._sim.get_beers_flags(did, self._episode.conditions)
            self._medications.append(MedicationEntry(
                drug_id=did,
                generic_name=meta.generic_name,
                atc_class=meta.atc_class,
                dose_mg=meta.default_dose_mg,
                is_high_risk_elderly=meta.is_high_risk_elderly,
                beers_flags=flags,
            ))

        self._interaction_queries = []
        self._interventions = []
        self._risk_deltas = []
        self._discovered_pairs = set()
        self._step_count = 0
        self._done = False
        self._remaining_query_budget = self._task_cfg.query_budget
        self._remaining_intervention_budget = self._task_cfg.intervention_budget
        self._severe_moderate_discovered = 0
        self._total_drug_changes = 0
        self._critical_stopped_without_sub = 0
        self._last_reward = 0.0

        # Baseline risk is the reference point for graders and risk deltas.
        self._baseline_risk = self._compute_risk()
        self._current_risk = self._baseline_risk
        return self._make_observation()

    # ── step ─────────────────────────────────────────────────────────────────

    def step(
        self,
        action: PolypharmacyAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> PolypharmacyObservation:
        """Apply one action and return the resulting observation.

        Actions that fail ``_validate_action`` are scored with
        ``compute_shaped_reward(..., is_invalid=True)`` and still consume a
        step. ``finish_review`` ends the episode and returns the grader score
        as the reward. Once the episode is done, further calls return the
        current observation with reward 0.0 and do not advance the step count.

        Args:
            action: The agent's action.
            timeout_s: Accepted for interface compatibility; unused here.
            **kwargs: Ignored.
        """
        if self._done:
            return self._make_observation()

        assert self._task_cfg is not None
        assert self._episode is not None

        reward = 0.0
        info: Dict[str, Any] = {}

        # Validate basic action structure.
        valid, err = self._validate_action(action)
        if not valid:
            reward = compute_shaped_reward(
                self._current_risk, self._current_risk,
                action.action_type, is_invalid=True,
            )
            info["error"] = err
            self._step_count += 1
            return self._check_timeout_and_build_obs(reward, info)

        if action.action_type == "query_ddi":
            reward, info = self._handle_query(action)
        elif action.action_type == "propose_intervention":
            reward, info = self._handle_intervention(action)
        elif action.action_type == "finish_review":
            self._done = True
            score = self._run_grader()
            reward = score  # terminal bonus
            info["grader_score"] = score

        self._step_count += 1
        return self._check_timeout_and_build_obs(reward, info)

    # ── state property ───────────────────────────────────────────────────────

    @property
    def state(self) -> PolypharmacyState:
        """Return a summary of the current episode (placeholder values before reset())."""
        return PolypharmacyState(
            episode_id=self._episode.episode_id if self._episode else None,
            step_count=self._step_count,
            task_id=self._task_cfg.task_id if self._task_cfg else "",
            max_steps=self._task_cfg.max_steps if self._task_cfg else 0,
            num_query_actions=len(self._interaction_queries),
            num_interventions=len(self._interventions),
        )

    # ── Internal helpers ─────────────────────────────────────────────────────

    def _compute_risk(self) -> float:
        """Compute regimen risk from the current drug ids and the patient's conditions.

        Only drug ids are passed to ``compute_regimen_risk``, so dose changes
        do not affect the result.
        """
        drug_ids = [m.drug_id for m in self._medications]
        return compute_regimen_risk(
            drug_ids,
            self._episode.conditions if self._episode else [],
            self._sim.ddi_rules,
            self._sim.beers_criteria,
            self._sim.drug_metadata,
        )

    def _validate_action(self, action: PolypharmacyAction) -> Tuple[bool, str]:
        """Check that required fields are present for the action type.

        Returns:
            ``(True, "")`` if the action is well-formed, otherwise
            ``(False, error_message)``.
        """
        if action.action_type == "query_ddi":
            if not action.drug_id_1 or not action.drug_id_2:
                return False, "query_ddi requires drug_id_1 and drug_id_2"
        elif action.action_type == "propose_intervention":
            if not action.target_drug_id:
                return False, "propose_intervention requires target_drug_id"
            if action.intervention_type in (None, "none"):
                return False, "propose_intervention requires a valid intervention_type"
        return True, ""

    def _handle_query(self, action: PolypharmacyAction) -> Tuple[float, Dict[str, Any]]:
        """Look up a drug-drug interaction, consuming one unit of query budget.

        A severe/moderate result increments ``_severe_moderate_discovered``
        only the first time that unordered pair is found. The severe-discovery
        reward bonus is likewise granted only on the first discovery, so
        repeated queries of a known pair cannot inflate either.

        Returns:
            ``(reward, info)``. On success, ``info["ddi_result"]`` holds the
            severity, recommendation and risk score.
        """
        info: Dict[str, Any] = {}
        assert action.drug_id_1 and action.drug_id_2

        if self._remaining_query_budget <= 0:
            reward = compute_shaped_reward(
                self._current_risk, self._current_risk,
                "query_ddi", is_invalid=True,
            )
            info["error"] = "Query budget exhausted"
            return reward, info

        result = self._sim.lookup_ddi(action.drug_id_1, action.drug_id_2)
        self._remaining_query_budget -= 1
        self._interaction_queries.append(InteractionQueryRecord(
            drug_id_1=action.drug_id_1,
            drug_id_2=action.drug_id_2,
            severity=result.severity,
            recommendation=result.recommendation,
            risk_score=result.base_risk_score,
            step_index=self._step_count,
        ))

        # Order-independent key so (A, B) and (B, A) count as the same pair.
        first, second = action.drug_id_1, action.drug_id_2
        pair = (first, second) if first <= second else (second, first)
        is_new_discovery = (
            result.severity in ("severe", "moderate")
            and pair not in self._discovered_pairs
        )
        if is_new_discovery:
            self._discovered_pairs.add(pair)
            self._severe_moderate_discovered += 1

        reward = compute_shaped_reward(
            self._current_risk, self._current_risk,
            "query_ddi",
            discovered_severe=(is_new_discovery and result.severity == "severe"),
        )
        info["ddi_result"] = {
            "severity": result.severity,
            "recommendation": result.recommendation,
            "risk_score": result.base_risk_score,
        }
        return reward, info

    def _handle_intervention(self, action: PolypharmacyAction) -> Tuple[float, Dict[str, Any]]:
        """Apply a stop / dose_reduce / substitute / add_monitoring intervention.

        A successful intervention consumes one unit of intervention budget,
        recomputes risk, and appends an ``InterventionRecord`` plus its risk
        delta. A failed substitution changes nothing and is not recorded:
        there is no suitable candidate, the candidate is already in the
        regimen, or it has no metadata.

        Returns:
            ``(reward, info)``. ``info["risk_delta"]`` is
            ``previous_risk - new_risk`` (0.0 for a failed substitution).
        """
        info: Dict[str, Any] = {}
        assert action.target_drug_id
        assert action.intervention_type and action.intervention_type != "none"

        if self._remaining_intervention_budget <= 0:
            reward = compute_shaped_reward(
                self._current_risk, self._current_risk,
                "propose_intervention", is_invalid=True,
            )
            info["error"] = "Intervention budget exhausted"
            return reward, info

        # Find the target medication in the current regimen.
        target_idx: Optional[int] = next(
            (i for i, m in enumerate(self._medications)
             if m.drug_id == action.target_drug_id),
            None,
        )
        if target_idx is None:
            reward = compute_shaped_reward(
                self._current_risk, self._current_risk,
                "propose_intervention", is_invalid=True,
            )
            info["error"] = f"Drug {action.target_drug_id} not in current medications"
            return reward, info

        previous_risk = self._current_risk
        target_med = self._medications[target_idx]
        # Drug id stored on the record. For substitutions this becomes the drug
        # actually introduced, which may have been auto-selected.
        recorded_new_drug_id = action.proposed_new_drug_id

        if action.intervention_type == "stop":
            self._medications.pop(target_idx)
            self._total_drug_changes += 1
            if action.target_drug_id in CRITICAL_DRUG_IDS:
                self._critical_stopped_without_sub += 1

        elif action.intervention_type == "dose_reduce":
            meta = self._sim.get_drug_meta(action.target_drug_id)
            if meta:
                # Halve the dose, but never below the drug's minimum dose.
                new_dose = max(meta.min_dose_mg, target_med.dose_mg * 0.5)
                self._medications[target_idx] = target_med.model_copy(
                    update={"dose_mg": new_dose}
                )

        elif action.intervention_type == "substitute":
            current_ids = [m.drug_id for m in self._medications]
            new_drug_id = action.proposed_new_drug_id
            if not new_drug_id:
                # Auto-find a substitute, telling the simulator what is already taken.
                new_drug_id = self._sim.find_substitute(action.target_drug_id, current_ids)

            failure: Optional[str] = None
            new_meta = None
            if not new_drug_id:
                failure = "No suitable substitute found"
            elif new_drug_id in current_ids:
                # Covers substituting a drug with itself and duplicating an
                # existing medication.
                failure = f"Substitute {new_drug_id} already in current medications"
            else:
                new_meta = self._sim.get_drug_meta(new_drug_id)
                if not new_meta:
                    failure = f"Substitute {new_drug_id} not found in metadata"

            if failure is not None:
                # Failed substitute is a no-op: no budget consumed, no record,
                # no risk change.
                info["warning"] = failure
                info["risk_delta"] = 0.0
                reward = compute_shaped_reward(
                    previous_risk, self._current_risk, "propose_intervention"
                )
                return reward, info

            flags = self._sim.get_beers_flags(
                new_drug_id,
                self._episode.conditions if self._episode else [],
            )
            self._medications[target_idx] = MedicationEntry(
                drug_id=new_drug_id,
                generic_name=new_meta.generic_name,
                atc_class=new_meta.atc_class,
                dose_mg=new_meta.default_dose_mg,
                is_high_risk_elderly=new_meta.is_high_risk_elderly,
                beers_flags=flags,
            )
            self._total_drug_changes += 1
            recorded_new_drug_id = new_drug_id
            # Substituting a critical drug is acceptable, so
            # _critical_stopped_without_sub is deliberately not incremented.

        elif action.intervention_type == "add_monitoring":
            # Tag the medication without changing the regimen; tag only once.
            if "monitored" not in target_med.beers_flags:
                self._medications[target_idx] = target_med.model_copy(
                    update={"beers_flags": target_med.beers_flags + ["monitored"]}
                )

        self._remaining_intervention_budget -= 1
        self._current_risk = self._compute_risk()
        risk_delta = previous_risk - self._current_risk
        self._risk_deltas.append(risk_delta)
        self._interventions.append(InterventionRecord(
            target_drug_id=action.target_drug_id,
            action_type=action.intervention_type,
            proposed_new_drug_id=recorded_new_drug_id,
            rationale=action.rationale or "",
            step_index=self._step_count,
        ))

        reward = compute_shaped_reward(previous_risk, self._current_risk, "propose_intervention")
        info["risk_delta"] = risk_delta
        return reward, info

    def _run_grader(self) -> float:
        """Dispatch to the grader for the current task; unknown tasks score 0.0."""
        assert self._task_cfg is not None
        tid = self._task_cfg.task_id
        if tid == "easy_screening":
            severe_pairs = self._get_severe_pairs()
            return grade_easy_screening(
                self._baseline_risk,
                self._current_risk,
                self._interventions,
                severe_pairs,
            )
        elif tid == "budgeted_screening":
            return grade_budgeted_screening(
                self._baseline_risk,
                self._current_risk,
                self._interventions,
                self._risk_deltas,
                len(self._interaction_queries),
                self._severe_moderate_discovered,
            )
        elif tid == "complex_tradeoff":
            return grade_complex_tradeoff(
                self._baseline_risk,
                self._current_risk,
                self._interventions,
                self._total_drug_changes,
                self._critical_stopped_without_sub,
            )
        return 0.0

    def _get_severe_pairs(self) -> List[Tuple[str, str]]:
        """Return all severe DDI pairs present in the *initial* medication list.

        Pairs are taken from the episode's ``medication_ids``, not the current
        (possibly modified) regimen.
        """
        if not self._episode:
            return []
        unique_ids = sorted(set(self._episode.medication_ids))
        pairs: List[Tuple[str, str]] = []
        # combinations() over a sorted list yields (a, b) with a < b, the same
        # key order used for the ddi_rules lookup.
        for pair in combinations(unique_ids, 2):
            rule = self._sim.ddi_rules.get(pair)
            if rule and rule.severity == "severe":
                pairs.append(pair)
        return pairs

    def _check_timeout_and_build_obs(
        self, reward: float, info: Dict[str, Any]
    ) -> PolypharmacyObservation:
        """End the episode if ``max_steps`` is reached, then build the observation.

        On timeout, the timeout penalty from ``compute_shaped_reward`` and the
        grader score are both added to ``reward``. Current and baseline risk
        are always reported in ``info``.
        """
        assert self._task_cfg is not None
        if not self._done and self._step_count >= self._task_cfg.max_steps:
            self._done = True
            timeout_penalty = compute_shaped_reward(
                self._current_risk, self._current_risk,
                "finish_review", is_timeout=True,
            )
            score = self._run_grader()
            reward += timeout_penalty + score
            info["timeout"] = True
            info["grader_score"] = score
        self._last_reward = reward
        info["current_risk"] = self._current_risk
        info["baseline_risk"] = self._baseline_risk
        return self._make_observation(reward=reward, info=info)

    def _make_observation(
        self, reward: float = 0.0, info: Optional[Dict[str, Any]] = None,
    ) -> PolypharmacyObservation:
        """Build an observation snapshot.

        Mutable lists are deep-copied so callers cannot alter environment
        state. The literal fallbacks (e.g. age 65, ``"budgeted_screening"``)
        are placeholders used only before the first ``reset()``.
        """
        ep = self._episode
        cfg = self._task_cfg
        return PolypharmacyObservation(
            episode_id=ep.episode_id if ep else "",
            task_id=cfg.task_id if cfg else "budgeted_screening",
            age=ep.age if ep else 65,
            sex=ep.sex if ep else "M",
            conditions=ep.conditions if ep else [],
            eGFR_category=ep.eGFR_category if ep else "normal",
            liver_function_category=ep.liver_function_category if ep else "normal",
            current_medications=deepcopy(self._medications),
            interaction_queries=deepcopy(self._interaction_queries),
            interventions=deepcopy(self._interventions),
            step_index=self._step_count,
            remaining_query_budget=self._remaining_query_budget,
            remaining_intervention_budget=self._remaining_intervention_budget,
            shaped_reward=reward,
            done=self._done,
            reward=reward,
            metadata=info or {},
        )