"""Deterministic heuristic baseline agent for PolypharmacyEnv.

Strategy:
  1. Query all unordered medication pairs for DDIs (while query budget
     remains), prioritising high-risk elderly drugs first.
  2. For each severe DDI found, attempt substitution, falling back to
     stopping the drug.
  3. For each moderate DDI found, attempt substitution, falling back to
     stopping the drug.
  4. With any remaining intervention budget, address Beers-flagged "avoid"
     drugs.
  5. Call finish_review.
"""
from __future__ import annotations
from itertools import combinations
from typing import List, Tuple
from ..env_core import PolypharmacyEnv
from ..models import PolypharmacyAction, PolypharmacyObservation
def _pick_target(a: str, b: str, intervened: set[str]) -> str | None:
    """Return the first drug of the pair ``(a, b)`` not yet targeted, else ``None``.

    ``intervened`` holds every drug an intervention has already been
    attempted on, whether or not that attempt succeeded.
    """
    if a not in intervened:
        return a
    if b not in intervened:
        return b
    return None


def run_heuristic_episode(
    env: PolypharmacyEnv,
    task_id: str = "budgeted_screening",
    seed: int | None = None,
) -> Tuple[float, float, int]:
    """Run one episode with the heuristic agent.

    Args:
        env: Environment to drive; it is reset at the start of the episode.
        task_id: Task identifier passed to ``env.reset``.
        seed: Seed passed to ``env.reset``.

    Returns:
        ``(total_reward, grader_score, steps)``:
        - ``total_reward`` is the sum of per-step rewards, with a ``None``
          reward counted as 0.0.
        - ``grader_score`` is read from the final observation's
          ``metadata["grader_score"]``, or 0.0 if absent.
        - ``steps`` is the number of ``env.step`` calls made.
    """
    obs = env.reset(seed=seed, task_id=task_id)
    total_reward = 0.0
    steps = 0

    def _step(action: PolypharmacyAction) -> PolypharmacyObservation:
        """Apply ``action``, accumulating its reward and the step count."""
        nonlocal total_reward, steps
        new_obs = env.step(action)
        total_reward += new_obs.reward or 0.0
        steps += 1
        return new_obs

    def _result(final_obs: PolypharmacyObservation) -> Tuple[float, float, int]:
        """Package the episode outcome, taking the grader score from ``final_obs``."""
        return total_reward, final_obs.metadata.get("grader_score", 0.0), steps

    def _try_intervene(
        target: str,
        rationale: str,
    ) -> Tuple[bool, PolypharmacyObservation]:
        """Try to substitute ``target``, falling back to stopping it.

        Returns ``(done, obs)`` for the last step taken.
        """
        new_obs = _step(
            PolypharmacyAction(
                action_type="propose_intervention",
                target_drug_id=target,
                intervention_type="substitute",
                rationale=rationale,
            )
        )
        if new_obs.done:
            return True, new_obs
        # A warning signals the substitution failed; fall back to stopping
        # the drug, but only while intervention budget remains.
        if not new_obs.metadata.get("warning"):
            return False, new_obs
        if new_obs.remaining_intervention_budget <= 0:
            return False, new_obs
        new_obs = _step(
            PolypharmacyAction(
                action_type="propose_intervention",
                target_drug_id=target,
                intervention_type="stop",
                rationale=f"No substitute; {rationale}",
            )
        )
        return bool(new_obs.done), new_obs

    # Phase 1: query DDIs between medication pairs. Sorting puts high-risk
    # elderly drugs first, then drugs with more Beers flags, so their pairs
    # are queried before the query budget runs out.
    meds_sorted = sorted(
        obs.current_medications,
        key=lambda m: (not m.is_high_risk_elderly, -len(m.beers_flags), m.drug_id),
    )
    med_ids = [m.drug_id for m in meds_sorted]
    severe_pairs: List[Tuple[str, str]] = []
    moderate_pairs: List[Tuple[str, str]] = []
    for a, b in combinations(med_ids, 2):
        if obs.remaining_query_budget <= 0:
            break
        obs = _step(
            PolypharmacyAction(action_type="query_ddi", drug_id_1=a, drug_id_2=b)
        )
        if obs.done:
            return _result(obs)
        # `or {}` also tolerates an explicit None stored under "ddi_result".
        ddi_info = obs.metadata.get("ddi_result") or {}
        sev = ddi_info.get("severity", "none")
        if sev == "severe":
            severe_pairs.append((a, b))
        elif sev == "moderate":
            moderate_pairs.append((a, b))

    # Phase 2: intervene on drugs in severe DDI pairs first, then moderate.
    # NOTE(review): when the first drug of a pair was already targeted, the
    # second is targeted too, even if the earlier intervention may already
    # have resolved this pair.
    intervened: set[str] = set()
    for label, flagged_pairs in (
        ("Severe", severe_pairs),
        ("Moderate", moderate_pairs),
    ):
        for a, b in flagged_pairs:
            if obs.remaining_intervention_budget <= 0:
                break
            target = _pick_target(a, b, intervened)
            if target is None:
                continue
            intervened.add(target)
            done, obs = _try_intervene(target, f"{label} DDI between {a} and {b}")
            if done:
                return _result(obs)

    # Phase 3: address Beers-flagged "avoid" drugs not already targeted.
    for med in meds_sorted:
        if obs.remaining_intervention_budget <= 0:
            break
        if med.drug_id in intervened:
            continue
        if not any("avoid" in flag for flag in med.beers_flags):
            continue
        intervened.add(med.drug_id)
        done, obs = _try_intervene(
            med.drug_id, f"Beers criteria: {', '.join(med.beers_flags)}"
        )
        if done:
            return _result(obs)

    # Phase 4: finish the review.
    obs = _step(PolypharmacyAction(action_type="finish_review"))
    return _result(obs)
def run_heuristic_baseline(
    n_episodes: int = 5,
    task_ids: List[str] | None = None,
) -> None:
    """Evaluate the heuristic agent on several tasks and print the results.

    A single ``PolypharmacyEnv`` instance is reused for every episode. For
    each task, one line is printed per episode, followed by the task's
    average grader score and average total reward.

    Args:
        n_episodes: Number of episodes per task; episode ``i`` runs with
            ``seed=i``.
        task_ids: Tasks to evaluate. ``None`` selects ``easy_screening``,
            ``budgeted_screening`` and ``complex_tradeoff``.
    """
    selected_tasks = (
        task_ids
        if task_ids is not None
        else ["easy_screening", "budgeted_screening", "complex_tradeoff"]
    )
    env = PolypharmacyEnv()
    for tid in selected_tasks:
        rewards: list[float] = []
        scores: list[float] = []
        for i in range(n_episodes):
            total_r, score, steps = run_heuristic_episode(env, task_id=tid, seed=i)
            rewards.append(total_r)
            scores.append(score)
            print(f" [{tid}] ep={i} steps={steps} reward={total_r:.4f} score={score:.4f}")
        # Both lists always have one entry per episode; guard n_episodes == 0.
        n_done = len(scores)
        avg_s = sum(scores) / n_done if n_done else 0.0
        avg_r = sum(rewards) / n_done if n_done else 0.0
        print(f" [{tid}] avg_score={avg_s:.4f} avg_reward={avg_r:.4f}\n")
if __name__ == "__main__":
    # Script entry point: run the baseline with its default task list and
    # episode count.
    run_heuristic_baseline()
|