"""Deterministic heuristic baseline agent for PolypharmacyEnv. Strategy: 1. Query all unordered medication pairs for DDIs (within budget), prioritising high-risk elderly drugs first. 2. For each severe DDI found, attempt substitution or stop. 3. For each moderate DDI found, attempt substitution or stop. 4. For remaining budget, address Beers-flagged "avoid" drugs. 5. Call finish_review. """ from __future__ import annotations from itertools import combinations from typing import List, Tuple from ..env_core import PolypharmacyEnv from ..models import PolypharmacyAction, PolypharmacyObservation def run_heuristic_episode( env: PolypharmacyEnv, task_id: str = "budgeted_screening", seed: int | None = None, ) -> Tuple[float, float, int]: """Run one episode with the heuristic agent. Returns (total_reward, grader_score, steps). """ obs = env.reset(seed=seed, task_id=task_id) total_reward = 0.0 grader_score = 0.0 steps = 0 # Phase 1: Query DDIs between medication pairs, prioritising high-risk drugs meds = obs.current_medications # Sort: high-risk elderly drugs first, then by Beers flag count meds_sorted = sorted( meds, key=lambda m: (not m.is_high_risk_elderly, -len(m.beers_flags), m.drug_id), ) med_ids = [m.drug_id for m in meds_sorted] pairs: List[Tuple[str, str]] = list(combinations(med_ids, 2)) severe_pairs: List[Tuple[str, str]] = [] moderate_pairs: List[Tuple[str, str]] = [] for a, b in pairs: if obs.remaining_query_budget <= 0: break action = PolypharmacyAction( action_type="query_ddi", drug_id_1=a, drug_id_2=b, ) obs = env.step(action) reward = obs.reward or 0.0 total_reward += reward steps += 1 if obs.done: grader_score = obs.metadata.get("grader_score", 0.0) return total_reward, grader_score, steps # Track severity from metadata ddi_info = obs.metadata.get("ddi_result", {}) sev = ddi_info.get("severity", "none") if sev == "severe": severe_pairs.append((a, b)) elif sev == "moderate": moderate_pairs.append((a, b)) # Phase 2: Intervene on severe DDI drugs first intervened: set[str] = set() def _try_intervene( target: str, rationale: str, ) -> Tuple[bool, PolypharmacyObservation]: """Try substitute then stop. Returns (done, obs).""" nonlocal total_reward, steps # Try substitute first act = PolypharmacyAction( action_type="propose_intervention", target_drug_id=target, intervention_type="substitute", rationale=rationale, ) obs_new = env.step(act) total_reward += obs_new.reward or 0.0 steps += 1 if obs_new.done: return True, obs_new # If substitute failed, try stop if obs_new.metadata.get("warning"): if obs_new.remaining_intervention_budget <= 0: return False, obs_new act2 = PolypharmacyAction( action_type="propose_intervention", target_drug_id=target, intervention_type="stop", rationale=f"No substitute; {rationale}", ) obs_new = env.step(act2) total_reward += obs_new.reward or 0.0 steps += 1 if obs_new.done: return True, obs_new return False, obs_new # Intervene on severe pairs for a, b in severe_pairs: if obs.remaining_intervention_budget <= 0: break target = b if a in intervened else a if target in intervened: target = b if target in intervened: continue intervened.add(target) done, obs = _try_intervene(target, f"Severe DDI between {a} and {b}") if done: grader_score = obs.metadata.get("grader_score", 0.0) return total_reward, grader_score, steps # Phase 2b: Intervene on moderate DDI drugs for a, b in moderate_pairs: if obs.remaining_intervention_budget <= 0: break target = b if a in intervened else a if target in intervened: target = b if target in intervened: continue intervened.add(target) done, obs = _try_intervene(target, f"Moderate DDI between {a} and {b}") if done: grader_score = obs.metadata.get("grader_score", 0.0) return total_reward, grader_score, steps # Phase 3: Address Beers-flagged "avoid" drugs for med in meds_sorted: if obs.remaining_intervention_budget <= 0: break if med.drug_id in intervened: continue if not med.beers_flags: continue if any("avoid" in f for f in med.beers_flags): intervened.add(med.drug_id) done, obs = _try_intervene( med.drug_id, f"Beers criteria: {', '.join(med.beers_flags)}" ) if done: grader_score = obs.metadata.get("grader_score", 0.0) return total_reward, grader_score, steps # Phase 4: Finish action = PolypharmacyAction(action_type="finish_review") obs = env.step(action) total_reward += obs.reward or 0.0 steps += 1 grader_score = obs.metadata.get("grader_score", 0.0) return total_reward, grader_score, steps def run_heuristic_baseline( n_episodes: int = 5, task_ids: List[str] | None = None, ) -> None: """Run the heuristic agent across tasks and print results.""" if task_ids is None: task_ids = ["easy_screening", "budgeted_screening", "complex_tradeoff"] env = PolypharmacyEnv() for tid in task_ids: scores: list[float] = [] rewards: list[float] = [] for i in range(n_episodes): total_r, score, steps = run_heuristic_episode(env, task_id=tid, seed=i) scores.append(score) rewards.append(total_r) print(f" [{tid}] ep={i} steps={steps} reward={total_r:.4f} score={score:.4f}") avg_s = sum(scores) / len(scores) if scores else 0.0 avg_r = sum(rewards) / len(rewards) if rewards else 0.0 print(f" [{tid}] avg_score={avg_s:.4f} avg_reward={avg_r:.4f}\n") if __name__ == "__main__": run_heuristic_baseline()