Spaces:
Sleeping
Sleeping
| """Groq-powered action suggester for PolypharmacyEnv.""" | |
| from __future__ import annotations | |
| import json | |
| import os | |
| from typing import Any | |
| from openai import OpenAI | |
| from ..models import PolypharmacyAction, PolypharmacyObservation | |
| DEFAULT_MODEL = "llama-3.1-8b-instant" | |
| FALLBACK_MODELS = [ | |
| "llama-3.1-8b-instant", | |
| "llama-3.3-70b-versatile", | |
| "gemma2-9b-it", | |
| ] | |
| CRITICAL_DRUG_IDS = {"DRUG_WARFARIN", "DRUG_INSULIN_GLARGINE", "DRUG_DIGOXIN"} | |
| SYSTEM_PROMPT = """You are a clinical medication safety assistant. | |
| Return exactly one JSON object describing the next action. | |
| Allowed output schema: | |
| { | |
| "action_type": "query_ddi" | "propose_intervention" | "finish_review", | |
| "drug_id_1": "optional", | |
| "drug_id_2": "optional", | |
| "target_drug_id": "optional", | |
| "intervention_type": "stop|dose_reduce|substitute|add_monitoring|none", | |
| "proposed_new_drug_id": "optional", | |
| "rationale": "optional" | |
| } | |
| No markdown fences. No extra text. | |
| Do NOT use finish_review early. First, gather evidence with query_ddi and/or | |
| perform at least one meaningful intervention when needed. | |
| """ | |
| def _obs_to_prompt(obs: PolypharmacyObservation) -> str: | |
| meds = ", ".join(m.drug_id for m in obs.current_medications) | |
| conds = ", ".join(obs.conditions) | |
| return ( | |
| f"Task: {obs.task_id}\n" | |
| f"Age: {obs.age}, sex: {obs.sex}\n" | |
| f"Conditions: {conds}\n" | |
| f"Medications: {meds}\n" | |
| f"Query budget: {obs.remaining_query_budget}\n" | |
| f"Intervention budget: {obs.remaining_intervention_budget}\n" | |
| f"Step index: {obs.step_index}\n" | |
| "Choose the single safest, most useful next action." | |
| ) | |
| def _parse_action(text: str) -> PolypharmacyAction: | |
| raw = text.strip() | |
| if raw.startswith("```"): | |
| raw = raw.split("\n", 1)[-1] | |
| if raw.endswith("```"): | |
| raw = raw.rsplit("```", 1)[0] | |
| raw = raw.strip() | |
| payload: dict[str, Any] = json.loads(raw) | |
| return PolypharmacyAction.model_validate(payload) | |
| def _fallback_query_action(obs: PolypharmacyObservation) -> PolypharmacyAction: | |
| meds = [m.drug_id for m in obs.current_medications] | |
| if len(meds) >= 2 and obs.remaining_query_budget > 0: | |
| return PolypharmacyAction( | |
| action_type="query_ddi", | |
| drug_id_1=meds[0], | |
| drug_id_2=meds[1], | |
| ) | |
| return PolypharmacyAction(action_type="finish_review") | |
| def _norm_pair(a: str, b: str) -> tuple[str, str]: | |
| return (a, b) if a < b else (b, a) | |
| def _pick_unseen_query_pair(obs: PolypharmacyObservation) -> tuple[str, str] | None: | |
| meds = [m.drug_id for m in obs.current_medications] | |
| if len(meds) < 2 or obs.remaining_query_budget <= 0: | |
| return None | |
| seen = { | |
| _norm_pair(q.drug_id_1, q.drug_id_2) | |
| for q in obs.interaction_queries | |
| } | |
| # Prioritize pairs containing high-risk drugs. | |
| high_risk = [m.drug_id for m in obs.current_medications if m.is_high_risk_elderly] | |
| ordered = high_risk + [m for m in meds if m not in set(high_risk)] | |
| for i in range(len(ordered)): | |
| for j in range(i + 1, len(ordered)): | |
| p = _norm_pair(ordered[i], ordered[j]) | |
| if p not in seen: | |
| return p | |
| return None | |
| def _pick_intervention_target(obs: PolypharmacyObservation) -> str | None: | |
| if obs.remaining_intervention_budget <= 0: | |
| return None | |
| med_set = {m.drug_id for m in obs.current_medications} | |
| # Use latest discovered severe/moderate query as intervention target. | |
| for q in reversed(obs.interaction_queries): | |
| if q.severity in ("severe", "moderate"): | |
| m1 = next((m for m in obs.current_medications if m.drug_id == q.drug_id_1), None) | |
| m2 = next((m for m in obs.current_medications if m.drug_id == q.drug_id_2), None) | |
| candidates = [m for m in (m1, m2) if m is not None] | |
| if not candidates: | |
| continue | |
| # Prefer non-critical risky drugs first. | |
| candidates.sort( | |
| key=lambda m: ( | |
| m.drug_id in CRITICAL_DRUG_IDS, | |
| 0 if any("avoid" in f for f in m.beers_flags) else 1, | |
| 0 if m.is_high_risk_elderly else 1, | |
| ) | |
| ) | |
| return candidates[0].drug_id | |
| # Fallback: if no severe/moderate discovered, still intervene on obviously | |
| # risky medications (Beers/high-risk flags) when budgets permit. | |
| risky = sorted( | |
| obs.current_medications, | |
| key=lambda m: ( | |
| 0 if any("avoid" in f for f in m.beers_flags) else 1, | |
| 0 if m.is_high_risk_elderly else 1, | |
| 1 if m.drug_id in CRITICAL_DRUG_IDS else 0, | |
| ), | |
| ) | |
| for med in risky: | |
| if any("avoid" in f for f in med.beers_flags) or med.is_high_risk_elderly: | |
| return med.drug_id | |
| return None | |
| def _rule_based_action(obs: PolypharmacyObservation) -> PolypharmacyAction | None: | |
| # If we already discovered significant risk, intervene before more querying. | |
| target = _pick_intervention_target(obs) | |
| if target and ( | |
| obs.step_index >= 1 | |
| and ( | |
| obs.remaining_query_budget <= 2 | |
| or len(obs.interaction_queries) >= 4 | |
| or any(q.severity in ("severe", "moderate") for q in obs.interaction_queries) | |
| ) | |
| ): | |
| intervention = "stop" | |
| rationale = "Remove likely contributor to discovered interaction risk" | |
| if target in CRITICAL_DRUG_IDS: | |
| # Avoid blunt stop for critical meds. | |
| intervention = "dose_reduce" | |
| rationale = "Critical medication: prefer dose reduction over abrupt stop" | |
| return PolypharmacyAction( | |
| action_type="propose_intervention", | |
| target_drug_id=target, | |
| intervention_type=intervention, | |
| rationale=rationale, | |
| ) | |
| pair = _pick_unseen_query_pair(obs) | |
| if pair: | |
| return PolypharmacyAction( | |
| action_type="query_ddi", | |
| drug_id_1=pair[0], | |
| drug_id_2=pair[1], | |
| ) | |
| if obs.remaining_intervention_budget > 0: | |
| # Final fallback before finish: at least one safety action. | |
| target = _pick_intervention_target(obs) | |
| if target: | |
| return PolypharmacyAction( | |
| action_type="propose_intervention", | |
| target_drug_id=target, | |
| intervention_type="dose_reduce" | |
| if target in CRITICAL_DRUG_IDS | |
| else "stop", | |
| rationale="Fallback intervention when query options are exhausted", | |
| ) | |
| if obs.step_index >= 3: | |
| return PolypharmacyAction(action_type="finish_review") | |
| return None | |
| def _postprocess_action( | |
| obs: PolypharmacyObservation, action: PolypharmacyAction | |
| ) -> PolypharmacyAction: | |
| # First apply deterministic guardrails to avoid repetitive loops. | |
| ruled = _rule_based_action(obs) | |
| if ruled is not None: | |
| return ruled | |
| # Guardrail: prevent useless immediate finish actions. | |
| if action.action_type == "finish_review": | |
| if obs.step_index < 2 and obs.remaining_query_budget > 0: | |
| return _fallback_query_action(obs) | |
| if len(obs.interaction_queries) == 0 and obs.remaining_query_budget > 0: | |
| return _fallback_query_action(obs) | |
| return action | |
| def suggest_action_from_observation( | |
| observation: PolypharmacyObservation, | |
| model_name: str | None = None, | |
| ) -> PolypharmacyAction: | |
| """Use Groq chat completions to suggest a valid action.""" | |
| api_key = os.getenv("GROQ_API_KEY", "").strip() | |
| if not api_key: | |
| raise ValueError("GROQ_API_KEY is missing. Add it to your .env file.") | |
| base_url = os.getenv("GROQ_BASE_URL", "https://api.groq.com/openai/v1").strip() | |
| model = (model_name or os.getenv("GROQ_MODEL_NAME", DEFAULT_MODEL)).strip() | |
| client = OpenAI(api_key=api_key, base_url=base_url) | |
| user_prompt = _obs_to_prompt(observation) | |
| tried: list[tuple[str, str]] = [] | |
| candidates: list[str] = [model] + [m for m in FALLBACK_MODELS if m != model] | |
| for candidate in candidates: | |
| try: | |
| resp = client.chat.completions.create( | |
| model=candidate, | |
| messages=[ | |
| {"role": "system", "content": SYSTEM_PROMPT}, | |
| {"role": "user", "content": user_prompt}, | |
| ], | |
| temperature=0.2, | |
| max_tokens=220, | |
| ) | |
| generated = (resp.choices[0].message.content or "").strip() | |
| parsed = _parse_action(generated) | |
| return _postprocess_action(observation, parsed) | |
| except Exception as exc: | |
| tried.append((candidate, str(exc))) | |
| tried_txt = " | ".join(f"{m}: {err}" for m, err in tried) | |
| raise ValueError( | |
| "No Groq model worked. Try one of: " | |
| "llama-3.3-70b-versatile, llama-3.1-8b-instant, gemma2-9b-it. " | |
| f"Errors: {tried_txt}" | |
| ) | |