File size: 8,283 Bytes
e543908
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0aa6a46
 
 
 
 
 
e543908
 
 
250a6d2
 
0aa6a46
 
 
e543908
 
 
 
 
 
 
 
0aa6a46
 
 
e543908
0aa6a46
 
 
 
 
 
 
e543908
0aa6a46
 
 
 
 
e543908
 
 
 
 
 
 
 
 
 
 
 
c314a65
373c99b
e543908
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
373c99b
e543908
 
 
 
 
 
 
 
 
 
 
 
0aa6a46
e543908
 
 
 
 
0aa6a46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e543908
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0aa6a46
c314a65
 
e543908
 
 
 
 
 
 
 
c314a65
e543908
 
 
 
 
 
 
 
 
 
 
 
 
c314a65
e543908
 
 
 
 
 
 
 
 
 
 
 
 
c314a65
e543908
 
 
 
 
 
 
 
0aa6a46
 
e543908
0aa6a46
e543908
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
#!/usr/bin/env python3
"""Submission inference script for Polypharmacy OpenEnv environment.

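Environment variables:
  API_BASE_URL   OpenAI-compatible base URL (default: https://api.groq.com/openai/v1)
  MODEL_NAME     Model identifier (default: llama-3.1-8b-instant)
  HF_TOKEN       API key/token; required unless GROQ_API_KEY or API_KEY is set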

Optional:
  POLYPHARMACY_ENV_URL  Environment API base (default: http://localhost:7860)
"""

from __future__ import annotations

import json
import os
import re
import sys
from typing import Any, Dict, List

try:
    from dotenv import load_dotenv
    load_dotenv()
except ImportError:
    pass

import requests
from openai import OpenAI

# Endpoint, model and credentials for the OpenAI-compatible client, read from
# the process environment (the first two fall back to built-in defaults).
API_BASE_URL = os.getenv("API_BASE_URL", "https://api.groq.com/openai/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "llama-3.1-8b-instant")
HF_TOKEN = os.getenv("HF_TOKEN")
# Also accept GROQ_API_KEY or API_KEY as fallback for the token
_API_KEY = HF_TOKEN or os.getenv("GROQ_API_KEY") or os.getenv("API_KEY")
# Base URL of the environment server; trailing slashes are stripped so that
# endpoint paths can be appended with a single "/".
ENV_URL = os.getenv("POLYPHARMACY_ENV_URL", "http://localhost:7860").rstrip("/")

# Benchmark name printed in the [START] log record.
BENCHMARK = "polypharmacy_env"
# Task ids that main() runs, in this order.
TASKS = ["easy_screening", "budgeted_screening", "complex_tradeoff"]
# Upper bound on environment steps per task (loop limit in run_task).
MAX_STEPS = 16
# Sampling parameters passed to every chat-completion request.
TEMPERATURE = 0.0
MAX_TOKENS = 220

# Values accepted by _sanitize_action: an unknown action type becomes
# "finish_review" and an unknown intervention type becomes "add_monitoring".
VALID_ACTION_TYPES = {"query_ddi", "propose_intervention", "finish_review"}
VALID_INTERVENTIONS = {"stop", "dose_reduce", "substitute", "add_monitoring"}

# System message sent with every chat-completion request. It restricts the
# model to a single JSON object, spells out the schema of each of the three
# action types, and suggests a query -> intervene -> finish strategy.
SYSTEM_PROMPT = (
    "You are a clinical-pharmacist agent reviewing an elderly patient's medications. "
    "You MUST return ONLY a single JSON object (no markdown, no explanation). "
    "The action_type MUST be exactly one of: query_ddi, propose_intervention, finish_review. "
    "Schema for query_ddi: "
    '{"action_type":"query_ddi","drug_id_1":"DRUG_X","drug_id_2":"DRUG_Y"} '
    "Schema for propose_intervention: "
    '{"action_type":"propose_intervention","target_drug_id":"DRUG_X",'
    '"intervention_type":"stop|dose_reduce|substitute|add_monitoring",'
    '"rationale":"reason"} '
    "Schema for finish_review: "
    '{"action_type":"finish_review"} '
    "Strategy: First query_ddi for high-risk drug pairs (especially those with beers_flags). "
    "Then propose_intervention for dangerous findings. Finally finish_review."
)


def _b(v: bool) -> str:
    return str(bool(v)).lower()


def _fmt_reward(v: float) -> str:
    return f"{float(v):.2f}"


def _clamp01(v: float) -> float:
    """Clamp score to strict (0.001, 0.999) — never outside this range."""
    return max(0.001, min(0.999, float(v)))


def log_start(task: str) -> None:
    """Print the ``[START]`` record for *task* to stdout and flush immediately."""
    line = "[START] task={} env={} model={}".format(task, BENCHMARK, MODEL_NAME)
    print(line, flush=True)


def log_step(step: int, action_str: str, reward: float, done: bool, error: str | None) -> None:
    """Print one ``[STEP]`` record to stdout and flush immediately.

    The record lists, in order, the step number, the action text, the reward
    with two decimals, the done flag as ``true``/``false`` and the error text.
    A missing or empty *error* is written as ``null``.
    """
    error_field = error or "null"
    fields = (
        f"step={step}",
        f"action={action_str}",
        f"reward={_fmt_reward(reward)}",
        f"done={_b(done)}",
        f"error={error_field}",
    )
    print("[STEP] " + " ".join(fields), flush=True)


def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the closing ``[END]`` record to stdout and flush immediately.

    The score is clamped with ``_clamp01`` and printed with three decimals;
    the rewards are printed as a comma-separated list with two decimals each.
    """
    formatted_rewards = [_fmt_reward(r) for r in rewards]
    line = "[END] success={} steps={} score={:.3f} rewards={}".format(
        _b(success),
        steps,
        _clamp01(score),
        ",".join(formatted_rewards),
    )
    print(line, flush=True)


def _safe_json(text: str) -> Dict[str, Any]:
    text = text.strip()
    if text.startswith("```"):
        text = re.sub(r"^```[a-zA-Z]*\n?", "", text)
        text = text.replace("```", "").strip()
    try:
        data = json.loads(text)
        if isinstance(data, dict):
            return _sanitize_action(data)
    except Exception:
        pass
    return {"action_type": "finish_review"}


def _sanitize_action(raw: Dict[str, Any]) -> Dict[str, Any]:
    """Build a clean action dict with only the fields relevant to the action type."""
    atype = raw.get("action_type", "")
    if atype not in VALID_ACTION_TYPES:
        return {"action_type": "finish_review"}

    if atype == "query_ddi":
        return {
            "action_type": "query_ddi",
            "drug_id_1": raw.get("drug_id_1") or None,
            "drug_id_2": raw.get("drug_id_2") or None,
        }
    if atype == "propose_intervention":
        it = raw.get("intervention_type", "")
        if it not in VALID_INTERVENTIONS:
            it = "add_monitoring"
        result: Dict[str, Any] = {
            "action_type": "propose_intervention",
            "target_drug_id": raw.get("target_drug_id") or None,
            "intervention_type": it,
        }
        new_drug = raw.get("proposed_new_drug_id") or None
        if new_drug:
            result["proposed_new_drug_id"] = new_drug
        rationale = raw.get("rationale") or None
        if rationale:
            result["rationale"] = rationale
        return result
    return {"action_type": "finish_review"}


def _llm_action(client: OpenAI, obs: Dict[str, Any]) -> Dict[str, Any]:
    """Ask the model for the next action given the current observation.

    A compact summary of *obs* (step index, remaining budgets, conditions,
    a trimmed view of each medication, and the query/intervention history)
    is serialized as the user message. The reply text is parsed and
    sanitized by ``_safe_json``.
    """
    # Keep only the per-medication fields that are forwarded to the model.
    med_fields = ("drug_id", "generic_name", "dose_mg")
    medications = []
    for med in obs.get("current_medications", []):
        entry = {field: med.get(field) for field in med_fields}
        entry["beers_flags"] = med.get("beers_flags", [])
        medications.append(entry)

    # Counters default to 0 and collections to an empty list when absent.
    counter_keys = ("step_index", "remaining_query_budget", "remaining_intervention_budget")
    summary: Dict[str, Any] = {key: obs.get(key, 0) for key in counter_keys}
    summary["conditions"] = obs.get("conditions", [])
    summary["current_medications"] = medications
    summary["interaction_queries"] = obs.get("interaction_queries", [])
    summary["interventions"] = obs.get("interventions", [])

    user_message = json.dumps(summary, separators=(",", ":"))
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_message},
    ]
    resp = client.chat.completions.create(
        model=MODEL_NAME,
        temperature=TEMPERATURE,
        max_tokens=MAX_TOKENS,
        messages=messages,
    )
    reply = resp.choices[0].message.content or ""
    return _safe_json(reply.strip())


def _reset(task_id: str) -> Dict[str, Any]:
    """POST ``{"task_id": task_id}`` to the environment's ``/reset`` endpoint.

    Returns the decoded JSON body. ``raise_for_status()`` raises on an HTTP
    error status.
    """
    endpoint = ENV_URL + "/reset"
    payload = {"task_id": task_id}
    response = requests.post(endpoint, json=payload, timeout=45)
    response.raise_for_status()
    return response.json()


def _step(action: Dict[str, Any]) -> Dict[str, Any]:
    """POST *action* to the environment's ``/step`` endpoint.

    Returns the decoded JSON body. An HTTP 422 response (invalid action) does
    not raise: a synthetic payload with an empty observation, a reward of
    0.01, ``done=False`` and the first 200 characters of the response body
    under ``info["error"]`` is returned so the agent can continue. Any other
    HTTP error status raises via ``raise_for_status()``.
    """
    response = requests.post(ENV_URL + "/step", json={"action": action}, timeout=45)
    if response.status_code != 422:
        response.raise_for_status()
        return response.json()
    # Invalid action: hand back a small fixed reward instead of failing.
    return {
        "observation": {},
        "reward": 0.01,
        "done": False,
        "info": {"error": response.text[:200]},
    }


def run_task(client: OpenAI, task_id: str) -> None:
    """Run one episode of *task_id* against the environment and log it.

    Prints one ``[START]`` record, one ``[STEP]`` record per environment step
    and exactly one ``[END]`` record on stdout, even when the episode fails.

    Args:
        client: OpenAI-compatible client used to choose each action.
        task_id: Task identifier sent to the environment's ``/reset`` endpoint.

    Failures inside the episode (HTTP errors, LLM errors, malformed payloads)
    do not propagate. They are reported on stderr and the episode ends with
    ``success=false`` and the score and rewards accumulated so far.
    """
    rewards: List[float] = []
    steps = 0
    success = False
    score = 0.001  # lowest reportable score; kept if the episode never finishes
    log_start(task_id)
    try:
        reset_payload = _reset(task_id)
        # Treat a null observation like a missing one so _llm_action always
        # receives a dict.
        obs = reset_payload.get("observation") or {}
        done = bool(reset_payload.get("done", False))

        for i in range(1, MAX_STEPS + 1):
            if done:
                break
            action = _llm_action(client, obs)
            action_str = json.dumps(action, separators=(",", ":"))
            step_payload = _step(action)
            # NOTE(review): a rejected (HTTP 422) action comes back with an
            # empty observation, so the next prompt carries no medication
            # data. Confirm this is intended.
            obs = step_payload.get("observation") or {}
            reward = _clamp01(float(step_payload.get("reward") or 0.0))
            done = bool(step_payload.get("done", False))
            metadata = obs.get("metadata") or {}
            last_error = metadata.get("error")
            info = step_payload.get("info")
            if not last_error and isinstance(info, dict):
                # _step reports rejected actions under info["error"], not in
                # the observation metadata; without this they log as "null".
                last_error = info.get("error")
            # Collapse whitespace so each [STEP] record stays on a single line.
            error_text = " ".join(str(last_error).split()) if last_error else None
            rewards.append(reward)
            steps = i
            log_step(i, action_str, reward, done, error_text or None)

            if done:
                raw_score = metadata.get("grader_score", None)
                if raw_score is not None:
                    score = _clamp01(float(raw_score))
                else:
                    # No grader score supplied: fall back to the mean reward.
                    score = _clamp01(sum(max(0.0, r) for r in rewards) / max(len(rewards), 1))
                success = score > 0.001
                break
    except Exception as exc:
        # Top-level boundary for one task: keep the run going, but leave a
        # diagnostic on stderr so stdout keeps only the structured records.
        success = False
        print(f"[ERROR] task={task_id} {type(exc).__name__}: {exc}", file=sys.stderr, flush=True)
    finally:
        log_end(success=success, steps=steps, score=score, rewards=rewards)


def main() -> int:
    """Run every task in ``TASKS`` and return a process exit code.

    Returns 1 after printing a message when no API key is configured.
    Otherwise one client is created, each task is run in order, and 0 is
    returned.
    """
    if _API_KEY:
        client = OpenAI(base_url=API_BASE_URL, api_key=_API_KEY)
        for task in TASKS:
            run_task(client, task)
        return 0
    print("HF_TOKEN (or GROQ_API_KEY / API_KEY) is required", flush=True)
    return 1


# Script entry point: run main() only when executed directly, and use its
# return value as the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())