"""Groq-powered action suggester for PolypharmacyEnv.""" from __future__ import annotations import json import os from typing import Any from openai import OpenAI from ..models import PolypharmacyAction, PolypharmacyObservation DEFAULT_MODEL = "llama-3.1-8b-instant" FALLBACK_MODELS = [ "llama-3.1-8b-instant", "llama-3.3-70b-versatile", "gemma2-9b-it", ] CRITICAL_DRUG_IDS = {"DRUG_WARFARIN", "DRUG_INSULIN_GLARGINE", "DRUG_DIGOXIN"} SYSTEM_PROMPT = """You are a clinical medication safety assistant. Return exactly one JSON object describing the next action. Allowed output schema: { "action_type": "query_ddi" | "propose_intervention" | "finish_review", "drug_id_1": "optional", "drug_id_2": "optional", "target_drug_id": "optional", "intervention_type": "stop|dose_reduce|substitute|add_monitoring|none", "proposed_new_drug_id": "optional", "rationale": "optional" } No markdown fences. No extra text. Do NOT use finish_review early. First, gather evidence with query_ddi and/or perform at least one meaningful intervention when needed. """ def _obs_to_prompt(obs: PolypharmacyObservation) -> str: meds = ", ".join(m.drug_id for m in obs.current_medications) conds = ", ".join(obs.conditions) return ( f"Task: {obs.task_id}\n" f"Age: {obs.age}, sex: {obs.sex}\n" f"Conditions: {conds}\n" f"Medications: {meds}\n" f"Query budget: {obs.remaining_query_budget}\n" f"Intervention budget: {obs.remaining_intervention_budget}\n" f"Step index: {obs.step_index}\n" "Choose the single safest, most useful next action." ) def _parse_action(text: str) -> PolypharmacyAction: raw = text.strip() if raw.startswith("```"): raw = raw.split("\n", 1)[-1] if raw.endswith("```"): raw = raw.rsplit("```", 1)[0] raw = raw.strip() payload: dict[str, Any] = json.loads(raw) return PolypharmacyAction.model_validate(payload) def _fallback_query_action(obs: PolypharmacyObservation) -> PolypharmacyAction: meds = [m.drug_id for m in obs.current_medications] if len(meds) >= 2 and obs.remaining_query_budget > 0: return PolypharmacyAction( action_type="query_ddi", drug_id_1=meds[0], drug_id_2=meds[1], ) return PolypharmacyAction(action_type="finish_review") def _norm_pair(a: str, b: str) -> tuple[str, str]: return (a, b) if a < b else (b, a) def _pick_unseen_query_pair(obs: PolypharmacyObservation) -> tuple[str, str] | None: meds = [m.drug_id for m in obs.current_medications] if len(meds) < 2 or obs.remaining_query_budget <= 0: return None seen = { _norm_pair(q.drug_id_1, q.drug_id_2) for q in obs.interaction_queries } # Prioritize pairs containing high-risk drugs. high_risk = [m.drug_id for m in obs.current_medications if m.is_high_risk_elderly] ordered = high_risk + [m for m in meds if m not in set(high_risk)] for i in range(len(ordered)): for j in range(i + 1, len(ordered)): p = _norm_pair(ordered[i], ordered[j]) if p not in seen: return p return None def _pick_intervention_target(obs: PolypharmacyObservation) -> str | None: if obs.remaining_intervention_budget <= 0: return None med_set = {m.drug_id for m in obs.current_medications} # Use latest discovered severe/moderate query as intervention target. for q in reversed(obs.interaction_queries): if q.severity in ("severe", "moderate"): m1 = next((m for m in obs.current_medications if m.drug_id == q.drug_id_1), None) m2 = next((m for m in obs.current_medications if m.drug_id == q.drug_id_2), None) candidates = [m for m in (m1, m2) if m is not None] if not candidates: continue # Prefer non-critical risky drugs first. candidates.sort( key=lambda m: ( m.drug_id in CRITICAL_DRUG_IDS, 0 if any("avoid" in f for f in m.beers_flags) else 1, 0 if m.is_high_risk_elderly else 1, ) ) return candidates[0].drug_id # Fallback: if no severe/moderate discovered, still intervene on obviously # risky medications (Beers/high-risk flags) when budgets permit. risky = sorted( obs.current_medications, key=lambda m: ( 0 if any("avoid" in f for f in m.beers_flags) else 1, 0 if m.is_high_risk_elderly else 1, 1 if m.drug_id in CRITICAL_DRUG_IDS else 0, ), ) for med in risky: if any("avoid" in f for f in med.beers_flags) or med.is_high_risk_elderly: return med.drug_id return None def _rule_based_action(obs: PolypharmacyObservation) -> PolypharmacyAction | None: # If we already discovered significant risk, intervene before more querying. target = _pick_intervention_target(obs) if target and ( obs.step_index >= 1 and ( obs.remaining_query_budget <= 2 or len(obs.interaction_queries) >= 4 or any(q.severity in ("severe", "moderate") for q in obs.interaction_queries) ) ): intervention = "stop" rationale = "Remove likely contributor to discovered interaction risk" if target in CRITICAL_DRUG_IDS: # Avoid blunt stop for critical meds. intervention = "dose_reduce" rationale = "Critical medication: prefer dose reduction over abrupt stop" return PolypharmacyAction( action_type="propose_intervention", target_drug_id=target, intervention_type=intervention, rationale=rationale, ) pair = _pick_unseen_query_pair(obs) if pair: return PolypharmacyAction( action_type="query_ddi", drug_id_1=pair[0], drug_id_2=pair[1], ) if obs.remaining_intervention_budget > 0: # Final fallback before finish: at least one safety action. target = _pick_intervention_target(obs) if target: return PolypharmacyAction( action_type="propose_intervention", target_drug_id=target, intervention_type="dose_reduce" if target in CRITICAL_DRUG_IDS else "stop", rationale="Fallback intervention when query options are exhausted", ) if obs.step_index >= 3: return PolypharmacyAction(action_type="finish_review") return None def _postprocess_action( obs: PolypharmacyObservation, action: PolypharmacyAction ) -> PolypharmacyAction: # First apply deterministic guardrails to avoid repetitive loops. ruled = _rule_based_action(obs) if ruled is not None: return ruled # Guardrail: prevent useless immediate finish actions. if action.action_type == "finish_review": if obs.step_index < 2 and obs.remaining_query_budget > 0: return _fallback_query_action(obs) if len(obs.interaction_queries) == 0 and obs.remaining_query_budget > 0: return _fallback_query_action(obs) return action def suggest_action_from_observation( observation: PolypharmacyObservation, model_name: str | None = None, ) -> PolypharmacyAction: """Use Groq chat completions to suggest a valid action.""" api_key = os.getenv("GROQ_API_KEY", "").strip() if not api_key: raise ValueError("GROQ_API_KEY is missing. Add it to your .env file.") base_url = os.getenv("GROQ_BASE_URL", "https://api.groq.com/openai/v1").strip() model = (model_name or os.getenv("GROQ_MODEL_NAME", DEFAULT_MODEL)).strip() client = OpenAI(api_key=api_key, base_url=base_url) user_prompt = _obs_to_prompt(observation) tried: list[tuple[str, str]] = [] candidates: list[str] = [model] + [m for m in FALLBACK_MODELS if m != model] for candidate in candidates: try: resp = client.chat.completions.create( model=candidate, messages=[ {"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": user_prompt}, ], temperature=0.2, max_tokens=220, ) generated = (resp.choices[0].message.content or "").strip() parsed = _parse_action(generated) return _postprocess_action(observation, parsed) except Exception as exc: tried.append((candidate, str(exc))) tried_txt = " | ".join(f"{m}: {err}" for m, err in tried) raise ValueError( "No Groq model worked. Try one of: " "llama-3.3-70b-versatile, llama-3.1-8b-instant, gemma2-9b-it. " f"Errors: {tried_txt}" )