adithya9903's picture
Flatten project to root for OpenEnv submission readiness.
fa51dd9
"""Deterministic heuristic baseline agent for PolypharmacyEnv.
Strategy:
1. Query all unordered medication pairs for DDIs (within budget),
prioritising high-risk elderly drugs first.
2. For each severe DDI found, attempt substitution or stop.
3. For each moderate DDI found, attempt substitution or stop.
4. For remaining budget, address Beers-flagged "avoid" drugs.
5. Call finish_review.
"""
from __future__ import annotations
from itertools import combinations
from typing import List, Tuple
from ..env_core import PolypharmacyEnv
from ..models import PolypharmacyAction, PolypharmacyObservation
def run_heuristic_episode(
    env: PolypharmacyEnv,
    task_id: str = "budgeted_screening",
    seed: int | None = None,
) -> Tuple[float, float, int]:
    """Run one episode with the heuristic agent.

    The agent works through these phases in order:

    1. Query DDIs for every unordered medication pair while query budget
       remains. High-risk-elderly drugs come first, then drugs with more
       Beers flags, then by drug id.
    2. Intervene on drugs involved in severe DDIs, then moderate DDIs.
    3. Intervene on drugs that carry a Beers flag containing ``"avoid"``.
    4. Call ``finish_review``.

    Each intervention tries ``substitute`` first. If the environment
    responds with a ``warning``, it falls back to ``stop``. The episode
    returns early as soon as any observation reports ``done``.

    Args:
        env: Environment to drive; it is reset at the start of the episode.
        task_id: Task identifier forwarded to ``env.reset``.
        seed: Seed forwarded to ``env.reset``.

    Returns:
        ``(total_reward, grader_score, steps)``. ``grader_score`` is read from
        the final observation's ``metadata`` and defaults to 0.0 if absent.
    """
    obs = env.reset(seed=seed, task_id=task_id)
    total_reward = 0.0
    steps = 0

    def _step(action: PolypharmacyAction) -> PolypharmacyObservation:
        """Apply one action, accumulating its reward and the step count."""
        nonlocal total_reward, steps
        result = env.step(action)
        total_reward += result.reward or 0.0
        steps += 1
        return result

    def _result(final_obs: PolypharmacyObservation) -> Tuple[float, float, int]:
        """Package the episode outcome using the last observation's grader score."""
        return total_reward, final_obs.metadata.get("grader_score", 0.0), steps

    # Phase 1: Query DDIs between medication pairs, prioritising high-risk drugs.
    # ``or ()`` keeps the sort key safe if a drug has no Beers-flag list.
    meds_sorted = sorted(
        obs.current_medications,
        key=lambda m: (
            not m.is_high_risk_elderly,
            -len(m.beers_flags or ()),
            m.drug_id,
        ),
    )
    med_ids = [m.drug_id for m in meds_sorted]
    severe_pairs: List[Tuple[str, str]] = []
    moderate_pairs: List[Tuple[str, str]] = []
    for a, b in combinations(med_ids, 2):
        if obs.remaining_query_budget <= 0:
            break
        obs = _step(
            PolypharmacyAction(action_type="query_ddi", drug_id_1=a, drug_id_2=b)
        )
        if obs.done:
            return _result(obs)
        # Severity is reported via metadata; tolerate a missing/None result.
        sev = (obs.metadata.get("ddi_result") or {}).get("severity", "none")
        if sev == "severe":
            severe_pairs.append((a, b))
        elif sev == "moderate":
            moderate_pairs.append((a, b))

    # Drugs already targeted, so each drug receives at most one intervention.
    intervened: set[str] = set()

    def _try_intervene(
        target: str,
        rationale: str,
    ) -> Tuple[bool, PolypharmacyObservation]:
        """Try ``substitute``, falling back to ``stop`` on a warning.

        Returns ``(done, obs)`` where ``obs`` is the last observation received.
        """
        result = _step(
            PolypharmacyAction(
                action_type="propose_intervention",
                target_drug_id=target,
                intervention_type="substitute",
                rationale=rationale,
            )
        )
        if result.done:
            return True, result
        # No warning: the substitution was accepted, nothing more to do.
        if not result.metadata.get("warning"):
            return False, result
        if result.remaining_intervention_budget <= 0:
            return False, result
        result = _step(
            PolypharmacyAction(
                action_type="propose_intervention",
                target_drug_id=target,
                intervention_type="stop",
                rationale=f"No substitute; {rationale}",
            )
        )
        return bool(result.done), result

    def _intervene_on_pairs(
        current: PolypharmacyObservation,
        ddi_pairs: List[Tuple[str, str]],
        label: str,
    ) -> Tuple[bool, PolypharmacyObservation]:
        """Intervene on one not-yet-targeted drug per DDI pair.

        Prefers the first drug of the pair; skips the pair if both drugs
        were already targeted. Returns ``(done, obs)``.
        """
        for a, b in ddi_pairs:
            if current.remaining_intervention_budget <= 0:
                break
            target = next((d for d in (a, b) if d not in intervened), None)
            if target is None:
                continue
            intervened.add(target)
            done, current = _try_intervene(
                target, f"{label} DDI between {a} and {b}"
            )
            if done:
                return True, current
        return False, current

    # Phase 2: Severe DDIs first, then moderate ones.
    for ddi_pairs, label in ((severe_pairs, "Severe"), (moderate_pairs, "Moderate")):
        done, obs = _intervene_on_pairs(obs, ddi_pairs, label)
        if done:
            return _result(obs)

    # Phase 3: Address Beers-flagged "avoid" drugs not already handled.
    for med in meds_sorted:
        if obs.remaining_intervention_budget <= 0:
            break
        if med.drug_id in intervened or not med.beers_flags:
            continue
        if any("avoid" in f for f in med.beers_flags):
            intervened.add(med.drug_id)
            done, obs = _try_intervene(
                med.drug_id, f"Beers criteria: {', '.join(med.beers_flags)}"
            )
            if done:
                return _result(obs)

    # Phase 4: Finish the review.
    obs = _step(PolypharmacyAction(action_type="finish_review"))
    return _result(obs)
def run_heuristic_baseline(
    n_episodes: int = 5,
    task_ids: List[str] | None = None,
) -> None:
    """Evaluate the heuristic agent on several tasks and print the results.

    A single environment instance is shared across all episodes. Episode
    ``i`` of each task is reset with ``seed=i``. One line is printed per
    episode, followed by a per-task line with the mean score and mean reward
    (0.0 when ``n_episodes`` is 0).

    Args:
        n_episodes: Number of episodes to run per task.
        task_ids: Tasks to evaluate. ``None`` selects ``easy_screening``,
            ``budgeted_screening`` and ``complex_tradeoff``.
    """
    if task_ids is not None:
        selected = task_ids
    else:
        selected = ["easy_screening", "budgeted_screening", "complex_tradeoff"]
    env = PolypharmacyEnv()
    for task in selected:
        outcomes: list[tuple[float, float]] = []  # (score, reward) per episode
        for episode in range(n_episodes):
            reward_sum, score, n_steps = run_heuristic_episode(
                env, task_id=task, seed=episode
            )
            outcomes.append((score, reward_sum))
            print(f" [{task}] ep={episode} steps={n_steps} reward={reward_sum:.4f} score={score:.4f}")
        count = len(outcomes)
        mean_score = sum(s for s, _ in outcomes) / count if count else 0.0
        mean_reward = sum(r for _, r in outcomes) / count if count else 0.0
        print(f" [{task}] avg_score={mean_score:.4f} avg_reward={mean_reward:.4f}\n")
if __name__ == "__main__":
    # Script entry point: evaluate the default task list with the default episode count.
    run_heuristic_baseline(n_episodes=5, task_ids=None)