Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """Submission inference script for Polypharmacy OpenEnv environment. | |
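| Environment variables for the model endpoint: | |
| API_BASE_URL OpenAI-compatible base URL (default: https://api.groq.com/openai/v1) | |
| MODEL_NAME Model identifier (default: llama-3.1-8b-instant) | |
| HF_TOKEN API key/token (required; GROQ_API_KEY or API_KEY are accepted as fallbacks) | |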
| Optional: | |
| POLYPHARMACY_ENV_URL Environment API base (default: http://localhost:7860) | |
| """ | |
from __future__ import annotations

import json
import os
import re
import sys
from typing import Any, Dict, List

import requests
from openai import OpenAI

# python-dotenv is optional: when installed, a local .env file is loaded before
# the configuration constants below read the environment.
try:
    from dotenv import load_dotenv

    load_dotenv()
except ImportError:
    pass
# --- Model endpoint configuration (read once, at import time) ---------------
API_BASE_URL = os.getenv("API_BASE_URL", "https://api.groq.com/openai/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "llama-3.1-8b-instant")
HF_TOKEN = os.getenv("HF_TOKEN")
# Also accept GROQ_API_KEY or API_KEY as fallback for the token
_API_KEY = HF_TOKEN or os.getenv("GROQ_API_KEY") or os.getenv("API_KEY")

# --- Environment server ------------------------------------------------------
# Trailing slashes are stripped so endpoint paths can be appended as "/reset", "/step".
ENV_URL = os.getenv("POLYPHARMACY_ENV_URL", "http://localhost:7860").rstrip("/")

# --- Run parameters ----------------------------------------------------------
BENCHMARK = "polypharmacy_env"  # printed as env=... on the [START] line
TASKS = ["easy_screening", "budgeted_screening", "complex_tradeoff"]
MAX_STEPS = 16  # upper bound on agent steps per task
TEMPERATURE = 0.0  # sampling temperature sent with every completion request
MAX_TOKENS = 220  # completion length cap sent with every completion request

# --- Action vocabulary used when sanitizing model output ---------------------
VALID_ACTION_TYPES = {"query_ddi", "propose_intervention", "finish_review"}
VALID_INTERVENTIONS = {"stop", "dose_reduce", "substitute", "add_monitoring"}
# System message sent with every completion request. It pins the reply to a
# single JSON object in one of three action shapes and outlines the intended
# query -> intervene -> finish strategy. The pieces are adjacent string
# literals, so they concatenate into one prompt with no separators added.
SYSTEM_PROMPT = (
    "You are a clinical-pharmacist agent reviewing an elderly patient's medications. "
    "You MUST return ONLY a single JSON object (no markdown, no explanation). "
    "The action_type MUST be exactly one of: query_ddi, propose_intervention, finish_review. "
    "Schema for query_ddi: "
    '{"action_type":"query_ddi","drug_id_1":"DRUG_X","drug_id_2":"DRUG_Y"} '
    "Schema for propose_intervention: "
    '{"action_type":"propose_intervention","target_drug_id":"DRUG_X",'
    '"intervention_type":"stop|dose_reduce|substitute|add_monitoring",'
    '"rationale":"reason"} '
    "Schema for finish_review: "
    '{"action_type":"finish_review"} '
    "Strategy: First query_ddi for high-risk drug pairs (especially those with beers_flags). "
    "Then propose_intervention for dangerous findings. Finally finish_review."
)
def _b(v: bool) -> str:
    """Render a value as the lowercase literal ``true`` or ``false``.

    The value is judged by truthiness, so any object is accepted, not only
    real booleans.
    """
    return "true" if v else "false"
def _fmt_reward(v: float) -> str:
    """Format a reward with exactly two digits after the decimal point.

    The value is passed through ``float()`` first, so ints and numeric
    strings are accepted as well.
    """
    return format(float(v), ".2f")
def _clamp01(v: float) -> float:
    """Clamp a score into the closed range [0.001, 0.999].

    The value is converted with ``float()`` first. NaN maps to the lower
    bound: an undefined score must never be reported as the best possible one.

    Args:
        v: Any value accepted by ``float()``.

    Returns:
        ``v`` limited to [0.001, 0.999], or 0.001 when ``v`` is NaN.
    """
    value = float(v)
    # Phrased as "not >=" on purpose: NaN fails every comparison, so it takes
    # this branch together with everything below the floor.
    if not value >= 0.001:
        return 0.001
    return min(0.999, value)
def log_start(task: str) -> None:
    """Print the ``[START]`` line for a task, naming the benchmark and model in use."""
    print(f"[START] task={task} env={BENCHMARK} model={MODEL_NAME}", flush=True)
def log_step(step: int, action_str: str, reward: float, done: bool, error: str | None) -> None:
    """Print one ``[STEP]`` line describing a single agent step.

    Args:
        step: 1-based step number.
        action_str: The action that was sent, already serialized to text.
        reward: Reward for the step; printed with two decimals.
        done: Whether the episode ended on this step.
        error: Error text for the step; ``None`` or an empty string is
            printed as the literal ``null``.
    """
    err = error if error else "null"
    print(
        f"[STEP] step={step} action={action_str} reward={_fmt_reward(reward)} "
        f"done={_b(done)} error={err}",
        flush=True,
    )
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the closing ``[END]`` line for a task.

    Args:
        success: Whether the task counts as successful.
        steps: Number of steps that were executed.
        score: Final score; clamped via ``_clamp01`` and printed with three decimals.
        rewards: Per-step rewards, printed comma-separated with two decimals each.
    """
    rewards_str = ",".join(_fmt_reward(r) for r in rewards)
    print(
        f"[END] success={_b(success)} steps={steps} score={_clamp01(score):.3f} rewards={rewards_str}",
        flush=True,
    )
def _safe_json(text: str) -> Dict[str, Any]:
    """Turn raw model output into a sanitized action dict, never raising.

    Parsing is attempted in two stages:

    1. The whole text (after removing a leading Markdown code fence) is
       parsed as JSON.
    2. If that fails, the first JSON value starting at the first ``{`` is
       decoded, which recovers replies such as ``Here is the action: {...}``
       or an object followed by trailing commentary.

    Only a JSON object is accepted; it is then passed through
    ``_sanitize_action``. Anything else falls back to ending the review.

    Args:
        text: Raw completion text from the model.

    Returns:
        A sanitized action dict, or ``{"action_type": "finish_review"}`` when
        no usable JSON object could be extracted.
    """
    text = text.strip()
    if text.startswith("```"):
        # Drop the opening fence (with optional language tag), then any
        # remaining fence markers.
        text = re.sub(r"^```[a-zA-Z]*\n?", "", text)
        text = text.replace("```", "").strip()
    try:
        try:
            data = json.loads(text)
        except ValueError:
            start = text.find("{")
            if start == -1:
                raise
            # raw_decode stops at the end of the first complete JSON value,
            # so text after the object is ignored.
            data, _ = json.JSONDecoder().raw_decode(text, start)
        if isinstance(data, dict):
            return _sanitize_action(data)
    except Exception:
        # Deliberate best-effort boundary: whatever goes wrong while decoding
        # or sanitizing, the agent degrades to finishing the review.
        pass
    return {"action_type": "finish_review"}
def _sanitize_action(raw: Dict[str, Any]) -> Dict[str, Any]:
    """Build a clean action dict with only the fields relevant to the action type.

    Rules:
      * An ``action_type`` that is missing, not a string, or not in
        ``VALID_ACTION_TYPES`` yields ``finish_review``.
      * ``query_ddi`` keeps ``drug_id_1`` / ``drug_id_2``; falsy values become ``None``.
      * ``propose_intervention`` keeps ``target_drug_id`` and ``intervention_type``;
        an intervention type that is not a string in ``VALID_INTERVENTIONS`` is
        replaced by ``add_monitoring``. ``proposed_new_drug_id`` and ``rationale``
        are copied only when truthy.

    Args:
        raw: Parsed JSON object produced by the model.

    Returns:
        A new dict containing only recognised fields.
    """
    atype = raw.get("action_type")
    # The isinstance check comes first: an unhashable value such as a list
    # would make the set membership test raise TypeError.
    if not isinstance(atype, str) or atype not in VALID_ACTION_TYPES:
        return {"action_type": "finish_review"}

    if atype == "query_ddi":
        return {
            "action_type": "query_ddi",
            "drug_id_1": raw.get("drug_id_1") or None,
            "drug_id_2": raw.get("drug_id_2") or None,
        }

    if atype == "propose_intervention":
        kind = raw.get("intervention_type")
        if not isinstance(kind, str) or kind not in VALID_INTERVENTIONS:
            kind = "add_monitoring"
        result: Dict[str, Any] = {
            "action_type": "propose_intervention",
            "target_drug_id": raw.get("target_drug_id") or None,
            "intervention_type": kind,
        }
        # Optional fields are included only when they carry a value.
        new_drug = raw.get("proposed_new_drug_id") or None
        if new_drug:
            result["proposed_new_drug_id"] = new_drug
        rationale = raw.get("rationale") or None
        if rationale:
            result["rationale"] = rationale
        return result

    # The only remaining valid type is finish_review, which carries no fields.
    return {"action_type": "finish_review"}
def _llm_action(client: OpenAI, obs: Dict[str, Any]) -> Dict[str, Any]:
    """Ask the model for the next action given the current observation.

    A compact summary of the observation (budgets, conditions, a trimmed view
    of each medication, and the query/intervention history) is serialized as
    JSON and sent as the user message, with ``SYSTEM_PROMPT`` as the system
    message. The reply is parsed and sanitized by ``_safe_json``.

    Args:
        client: OpenAI-compatible client used for the chat completion.
        obs: Latest observation from the environment; ``None`` or an empty
            dict is treated as an observation with no fields.

    Returns:
        A sanitized action dict. An empty or unparseable reply yields
        ``{"action_type": "finish_review"}`` via ``_safe_json``.
    """
    obs = obs or {}
    # `or []` also covers an explicit null, which .get(key, []) would pass through.
    meds = obs.get("current_medications") or []
    summary = {
        "step_index": obs.get("step_index", 0),
        "remaining_query_budget": obs.get("remaining_query_budget", 0),
        "remaining_intervention_budget": obs.get("remaining_intervention_budget", 0),
        "conditions": obs.get("conditions", []),
        "current_medications": [
            {
                "drug_id": m.get("drug_id"),
                "generic_name": m.get("generic_name"),
                "dose_mg": m.get("dose_mg"),
                "beers_flags": m.get("beers_flags", []),
            }
            for m in meds
            if isinstance(m, dict)  # skip malformed entries instead of crashing
        ],
        "interaction_queries": obs.get("interaction_queries", []),
        "interventions": obs.get("interventions", []),
    }
    resp = client.chat.completions.create(
        model=MODEL_NAME,
        temperature=TEMPERATURE,
        max_tokens=MAX_TOKENS,
        messages=[
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": json.dumps(summary, separators=(",", ":"))},
        ],
    )
    # A response without choices is handled like an empty reply.
    choices = resp.choices or []
    content = (choices[0].message.content or "").strip() if choices else ""
    return _safe_json(content)
def _reset(task_id: str) -> Dict[str, Any]:
    """Start a new episode for ``task_id`` on the environment server.

    Posts ``{"task_id": task_id}`` to ``{ENV_URL}/reset`` with a 45-second timeout.

    Returns:
        The decoded JSON body of the response.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    response = requests.post(f"{ENV_URL}/reset", json={"task_id": task_id}, timeout=45)
    response.raise_for_status()
    return response.json()
def _step(action: Dict[str, Any]) -> Dict[str, Any]:
    """Send one action to the environment server and return its reply.

    Posts ``{"action": action}`` to ``{ENV_URL}/step`` with a 45-second timeout.
    An HTTP 422 is not treated as fatal: a synthetic payload is returned with
    an empty observation, a small fixed reward, ``done`` false, and the first
    200 characters of the response text under ``info.error``.

    Returns:
        The decoded JSON body, or the synthetic payload described above.

    Raises:
        requests.HTTPError: For any other error status.
    """
    response = requests.post(f"{ENV_URL}/step", json={"action": action}, timeout=45)
    if response.status_code == 422:
        # Invalid action — return a clamped penalty and let the agent continue
        return {
            "observation": {},
            "reward": 0.01,
            "done": False,
            "info": {"error": response.text[:200]},
        }
    response.raise_for_status()
    return response.json()
def run_task(client: OpenAI, task_id: str) -> None:
    """Run one task episode end to end and print its START/STEP/END log lines.

    The episode is reset, then up to ``MAX_STEPS`` actions are requested from
    the model and sent to the environment. When the environment reports
    ``done``, the final score is taken from ``metadata.grader_score`` of the
    last observation if present, otherwise from the mean of the per-step
    rewards. If the episode never reports ``done``, the score stays at 0.001
    and the task is not counted as successful.

    Any exception ends the task early. It is reported on stderr and the
    ``[END]`` line is still printed, so one failing task never prevents the
    remaining tasks from running.

    Args:
        client: OpenAI-compatible client passed through to ``_llm_action``.
        task_id: Identifier of the task to reset the environment with.
    """
    rewards: List[float] = []
    steps = 0
    success = False
    score = 0.001  # lower bound of the [0.001, 0.999] range enforced by _clamp01
    log_start(task_id)
    try:
        reset_payload = _reset(task_id)
        obs = reset_payload.get("observation", {})
        done = bool(reset_payload.get("done", False))
        for i in range(1, MAX_STEPS + 1):
            if done:
                break
            action = _llm_action(client, obs)
            action_str = json.dumps(action, separators=(",", ":"))
            step_payload = _step(action)
            obs = step_payload.get("observation", {})
            reward = _clamp01(float(step_payload.get("reward") or 0.0))
            done = bool(step_payload.get("done", False))
            metadata = (obs or {}).get("metadata", {}) or {}

            # Errors normally arrive in the observation metadata, but a step
            # rejected with HTTP 422 carries its message under "info" instead.
            info = step_payload.get("info") or {}
            last_error = metadata.get("error")
            if not last_error and isinstance(info, dict):
                last_error = info.get("error")
            # Collapse whitespace so the error cannot break the one-line log record.
            error_text = " ".join(str(last_error).split()) if last_error else None

            rewards.append(reward)
            steps = i
            log_step(i, action_str, reward, done, error_text)
            if done:
                raw_score = metadata.get("grader_score", None)
                if raw_score is not None:
                    score = _clamp01(float(raw_score))
                else:
                    score = _clamp01(sum(max(0.0, r) for r in rewards) / max(len(rewards), 1))
                success = score > 0.001
                break
    except Exception as exc:
        # Top-level boundary for one task. stdout carries the
        # [START]/[STEP]/[END] lines, so diagnostics go to stderr and leave
        # that stream unchanged.
        print(
            f"[ERROR] task={task_id} {type(exc).__name__}: {exc}",
            file=sys.stderr,
            flush=True,
        )
        success = False
    finally:
        log_end(success=success, steps=steps, score=score, rewards=rewards)
def main() -> int:
    """Run every task in ``TASKS`` in order and return a process exit code.

    Returns:
        1 when no API key is configured (a message is printed and nothing is
        run), otherwise 0 once all tasks have been attempted.
    """
    if not _API_KEY:
        print("HF_TOKEN (or GROQ_API_KEY / API_KEY) is required", flush=True)
        return 1
    client = OpenAI(base_url=API_BASE_URL, api_key=_API_KEY)
    for task in TASKS:
        run_task(client, task)
    return 0
# Script entry point: the return value of main() becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())