Aman Khare commited on
Commit
3856d60
·
1 Parent(s): b3d1ac3

edited inference.py according to submission template

Browse files
Files changed (3) hide show
  1. err.txt +24 -0
  2. inference.py +240 -242
  3. out.txt +9 -0
err.txt ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {"event": "START", "timestamp": 1775576189.364181, "task_id": "easy_routine_checkup"}
2
+ [DEBUG] Model request failed: Error code: 401 - {'error': 'Invalid username or password.'}
3
+ {"event": "STEP", "timestamp": 1775576190.2672057, "step": 1, "action_type": "submit_note", "reward": 0.7}
4
+ {"event": "END", "timestamp": 1775576190.2674263, "task_id": "easy_routine_checkup", "final_score": 0.7}
5
+ {"event": "START", "timestamp": 1775576190.269494, "task_id": "medium_chronic_disease_followup"}
6
+ [DEBUG] Model request failed: Error code: 401 - {'error': 'Invalid username or password.'}
7
+ {"event": "STEP", "timestamp": 1775576190.6036963, "step": 1, "action_type": "submit_note", "reward": 0.7}
8
+ {"event": "END", "timestamp": 1775576190.6037915, "task_id": "medium_chronic_disease_followup", "final_score": 0.7}
9
+ {"event": "START", "timestamp": 1775576190.604777, "task_id": "hard_complex_er_visit"}
10
+ [DEBUG] Model request failed: Error code: 401 - {'error': 'Invalid username or password.'}
11
+ {"event": "STEP", "timestamp": 1775576190.9611442, "step": 1, "action_type": "submit_note", "reward": 0.7}
12
+ {"event": "END", "timestamp": 1775576190.961212, "task_id": "hard_complex_er_visit", "final_score": 0.7}
13
+
14
+ ============================================================
15
+ SUMMARY
16
+ ============================================================
17
+ Task Score Steps
18
+ ------------------------------- ------- -----
19
+ easy_routine_checkup 0.7000 1
20
+ medium_chronic_disease_followup 0.7000 1
21
+ hard_complex_er_visit 0.7000 1
22
+ ------------------------------- ------- -----
23
+ AVERAGE 0.7000
24
+
inference.py CHANGED
@@ -1,17 +1,37 @@
1
- """Baseline inference script for the Clinical Note Scribe environment.
2
-
3
- Runs all three tasks (easy → medium → hard) sequentially using an
4
- OpenAI-compatible API to generate SOAP notes from doctor–patient transcripts.
5
-
6
- Environment variables
7
- ---------------------
8
- OPENAI_API_KEY – API key for the model provider
9
- API_BASE_URL – Base URL for the OpenAI-compatible endpoint (default: https://api.openai.com/v1)
10
- MODEL_NAME – Model identifier to use (default: gpt-4o-mini)
11
-
12
- Usage::
13
-
14
- python inference.py
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  Designed to complete in under 20 minutes on 2 vCPU / 8 GB RAM.
17
  """
@@ -22,23 +42,23 @@
22
  import logging
23
  import os
24
  import sys
25
- import time
26
- from typing import Any
 
 
27
 
28
  # ---------------------------------------------------------------------------
29
- # Bootstrap logging BEFORE importing environment modules so the root logger
30
- # is configured and child loggers (clinical_note_scribe.*) propagate cleanly.
31
  # ---------------------------------------------------------------------------
 
 
 
 
 
32
 
33
- logging.basicConfig(
34
- level=logging.INFO,
35
- format="%(message)s",
36
- handlers=[logging.StreamHandler(sys.stdout)],
37
- )
38
- logger = logging.getLogger("inference")
39
 
40
  # ---------------------------------------------------------------------------
41
- # Environment imports (after logging is configured)
42
  # ---------------------------------------------------------------------------
43
 
44
  sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
@@ -50,275 +70,253 @@
50
  # Config
51
  # ---------------------------------------------------------------------------
52
 
53
- API_KEY = os.environ.get("OPENAI_API_KEY", "")
54
- API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.openai.com/v1")
55
- MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4o-mini")
 
56
 
57
- TASK_IDS = list(TASK_REGISTRY.keys()) # deterministic order
58
-
59
- # Maximum tokens for the model response keeps latency low
60
- MAX_TOKENS = 1024
 
61
 
62
  # ---------------------------------------------------------------------------
63
  # System prompt
64
  # ---------------------------------------------------------------------------
65
 
66
- SYSTEM_PROMPT = """\
67
- You are a clinical documentation assistant. Given a doctorpatient transcript \
68
- and patient context, generate a concise, clinically accurate SOAP note.
69
-
70
- RULES:
71
- 1. Use professional medical language. Avoid over-certain phrasing such as \
72
- "patient definitely has", "diagnosis is certain", or "100% certain".
73
- 2. Keep the note concise aim for under 400 words total across all four sections.
74
- 3. Return your output as a **single valid JSON object** matching this schema exactly:
75
-
76
- {
77
- "action_type": "submit_note",
78
- "soap_note": {
79
- "subjective": "<patient's reported symptoms, history, and concerns>",
80
- "objective": "<exam findings, vitals, lab results, imaging>",
81
- "assessment": "<differential diagnoses and clinical reasoning>",
82
- "plan": "<treatment plan, medications, follow-up, referrals>"
83
- }
84
- }
85
-
86
- Return ONLY the JSON object. No markdown fences, no commentary, no extra keys.
87
- """
 
88
 
89
  # ---------------------------------------------------------------------------
90
- # Helpers
91
  # ---------------------------------------------------------------------------
92
 
 
 
93
 
94
- def _build_user_prompt(transcript: str, patient_context: dict[str, Any]) -> str:
95
- """Build the user message containing the transcript and context."""
96
- ctx_str = json.dumps(patient_context, indent=2, default=str)
97
- return (
98
- f"## Patient Context\n```json\n{ctx_str}\n```\n\n"
99
- f"## Doctor–Patient Transcript\n```\n{transcript}\n```\n\n"
100
- "Generate the SOAP note as a JSON Action object."
101
- )
102
 
 
 
 
 
 
 
 
103
 
104
- def _call_model(user_prompt: str) -> dict[str, Any]:
105
- """Call the OpenAI-compatible API and return the parsed JSON action dict.
106
 
107
- Uses ``urllib`` so there is zero dependency on ``openai`` package
108
- this keeps the Docker image small and avoids version conflicts.
109
- Falls back to the ``openai`` package if installed.
110
- """
111
- try:
112
- return _call_model_sdk(user_prompt)
113
- except ImportError:
114
- return _call_model_urllib(user_prompt)
115
 
116
 
117
- def _call_model_sdk(user_prompt: str) -> dict[str, Any]:
118
- """Call via the ``openai`` Python SDK."""
119
- from openai import OpenAI # noqa: F811
120
 
121
- client = OpenAI(
122
- api_key=API_KEY,
123
- base_url=API_BASE_URL,
124
- )
125
- response = client.chat.completions.create(
126
- model=MODEL_NAME,
127
- messages=[
128
- {"role": "system", "content": SYSTEM_PROMPT},
129
- {"role": "user", "content": user_prompt},
130
- ],
131
- max_tokens=MAX_TOKENS,
132
- temperature=0.2,
133
- )
134
- raw = response.choices[0].message.content.strip()
135
- return _parse_json(raw)
136
-
137
-
138
- def _call_model_urllib(user_prompt: str) -> dict[str, Any]:
139
- """Fallback: call the API with ``urllib`` (no extra dependencies)."""
140
- import urllib.request
141
-
142
- url = f"{API_BASE_URL.rstrip('/')}/chat/completions"
143
- payload = json.dumps({
144
- "model": MODEL_NAME,
145
- "messages": [
146
- {"role": "system", "content": SYSTEM_PROMPT},
147
- {"role": "user", "content": user_prompt},
148
- ],
149
- "max_tokens": MAX_TOKENS,
150
- "temperature": 0.2,
151
- }).encode()
152
-
153
- req = urllib.request.Request(
154
- url,
155
- data=payload,
156
- headers={
157
- "Content-Type": "application/json",
158
- "Authorization": f"Bearer {API_KEY}",
159
- },
160
  )
161
- with urllib.request.urlopen(req, timeout=120) as resp:
162
- body = json.loads(resp.read())
163
-
164
- raw = body["choices"][0]["message"]["content"].strip()
165
- return _parse_json(raw)
166
 
167
 
168
  def _parse_json(raw: str) -> dict[str, Any]:
169
  """Parse the model's raw text output into a dict, tolerating markdown fences."""
170
- # Strip markdown code fences if present
171
- cleaned = raw
172
  if cleaned.startswith("```"):
173
- # remove opening fence (possibly ```json)
174
  first_newline = cleaned.index("\n")
175
  cleaned = cleaned[first_newline + 1:]
176
  if cleaned.endswith("```"):
177
- cleaned = cleaned[: -3]
178
  cleaned = cleaned.strip()
179
 
180
  try:
181
  return json.loads(cleaned)
182
  except json.JSONDecodeError as exc:
183
- logger.error("Failed to parse model output as JSON: %s", exc)
184
- logger.error("Raw output:\n%s", raw)
185
  raise
186
 
187
 
188
- def _log_event(event: str, **kwargs: Any) -> None:
189
- """Emit a structured JSON log line."""
190
- payload: dict[str, Any] = {"event": event, "timestamp": time.time()}
191
- payload.update(kwargs)
192
- logger.info(json.dumps(payload, default=str))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
 
194
 
195
  # ---------------------------------------------------------------------------
196
- # Main loop
197
  # ---------------------------------------------------------------------------
198
 
 
 
 
 
 
 
 
199
 
200
- def run_all_tasks() -> list[dict[str, Any]]:
201
- """Run every registered task and return a list of result dicts."""
202
- env = ClinicalNoteScribeEnv()
203
- results: list[dict[str, Any]] = []
204
-
205
- for task_id in TASK_IDS:
206
- logger.info("")
207
- logger.info("=" * 60)
208
- logger.info(" TASK: %s", task_id)
209
- logger.info("=" * 60)
210
-
211
- t0 = time.time()
212
- _log_event("INFERENCE_START", task_id=task_id)
213
 
 
214
  # ---- reset ----
215
  obs = env.reset(task_id)
216
- logger.info(" Transcript length : %d chars", len(obs.transcript))
217
- logger.info(" Patient context keys: %s", list(obs.patient_context.keys()))
218
-
219
- # ---- generate SOAP note via LLM ----
220
- user_prompt = _build_user_prompt(obs.transcript, obs.patient_context)
221
- logger.info(" Calling model (%s) ...", MODEL_NAME)
222
-
223
- try:
224
- action_dict = _call_model(user_prompt)
225
- except Exception as exc:
226
- logger.error(" Model call failed: %s", exc)
227
- results.append({
228
- "task_id": task_id,
229
- "score": 0.0,
230
- "error": str(exc),
231
- "elapsed_s": round(time.time() - t0, 2),
232
- })
233
- _log_event("INFERENCE_ERROR", task_id=task_id, error=str(exc))
234
- continue
235
-
236
- # ---- validate and create Action ----
237
- try:
238
- action = Action(**action_dict)
239
- except Exception as exc:
240
- logger.error(" Invalid action schema: %s", exc)
241
- logger.error(" Model returned: %s", json.dumps(action_dict, indent=2))
242
- results.append({
243
- "task_id": task_id,
244
- "score": 0.0,
245
- "error": f"schema_error: {exc}",
246
- "elapsed_s": round(time.time() - t0, 2),
247
- })
248
- _log_event("INFERENCE_ERROR", task_id=task_id, error=str(exc))
249
- continue
250
-
251
- # ---- step (submit) ----
252
- obs2, reward, done, info = env.step(action)
253
- elapsed = round(time.time() - t0, 2)
254
-
255
- logger.info(" Done: %s | Reward: %.4f | Elapsed: %.1fs", done, reward.value, elapsed)
256
- logger.info(" Signals: %s",
257
- {k: v for k, v in reward.signals.items() if not k.startswith("_")})
258
-
259
- _log_event("INFERENCE_END", task_id=task_id, score=reward.value, elapsed_s=elapsed)
260
-
261
- results.append({
262
- "task_id": task_id,
263
- "score": reward.value,
264
- "elapsed_s": elapsed,
265
- })
266
-
267
- return results
268
-
269
-
270
- def _print_summary(results: list[dict[str, Any]]) -> None:
271
- """Print a formatted summary table."""
272
- logger.info("")
273
- logger.info("=" * 60)
274
- logger.info(" SUMMARY")
275
- logger.info("=" * 60)
276
-
277
- col_task = max(len("Task"), *(len(r["task_id"]) for r in results))
278
- col_score = 7 # "Score" + padding
279
- col_time = 9 # "Time (s)"
280
-
281
- header = f" {'Task':<{col_task}} {'Score':>{col_score}} {'Time (s)':>{col_time}}"
282
- sep = f" {'-' * col_task} {'-' * col_score} {'-' * col_time}"
283
- logger.info(header)
284
- logger.info(sep)
285
-
286
- total_score = 0.0
287
- for r in results:
288
- score_str = f"{r['score']:.4f}" if "error" not in r else "ERROR"
289
- time_str = f"{r['elapsed_s']:.1f}"
290
- logger.info(f" {r['task_id']:<{col_task}} {score_str:>{col_score}} {time_str:>{col_time}}")
291
- total_score += r["score"]
292
 
293
- logger.info(sep)
294
- avg = total_score / len(results) if results else 0.0
295
- logger.info(f" {'AVERAGE':<{col_task}} {avg:>{col_score}.4f}")
296
- logger.info("")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
297
 
298
 
299
  # ---------------------------------------------------------------------------
300
- # Entry point
301
  # ---------------------------------------------------------------------------
302
 
303
- if __name__ == "__main__":
304
  if not API_KEY:
305
- logger.warning(
306
- "OPENAI_API_KEY is not set. The model calls will fail unless "
307
- "the API endpoint does not require authentication."
 
 
308
  )
309
 
310
- logger.info("Clinical Note Scribe — Baseline Inference")
311
- logger.info(" Model : %s", MODEL_NAME)
312
- logger.info(" API Base : %s", API_BASE_URL)
313
- logger.info(" Tasks : %s", TASK_IDS)
314
- logger.info("")
 
 
315
 
316
- start = time.time()
317
- results = run_all_tasks()
318
- total_elapsed = round(time.time() - start, 2)
 
 
319
 
320
- _print_summary(results)
321
- logger.info(" Total wall-clock time: %.1fs", total_elapsed)
 
 
 
322
 
323
- _log_event("INFERENCE_COMPLETE", total_elapsed_s=total_elapsed,
324
- scores={r["task_id"]: r["score"] for r in results})
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Inference Script — Clinical Note Scribe
3
+ ===================================
4
+ MANDATORY
5
+ - Before submitting, ensure the following variables are defined in your environment configuration:
6
+ API_BASE_URL The API endpoint for the LLM.
7
+ MODEL_NAME The model identifier to use for inference.
8
+ HF_TOKEN Your Hugging Face / API key.
9
+ LOCAL_IMAGE_NAME The name of the local image to use for the environment
10
+ if you are using from_docker_image() method.
11
+
12
+ - Defaults are set only for API_BASE_URL and MODEL_NAME:
13
+ API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
14
+ MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
15
+
16
+ - The inference script must be named `inference.py` and placed in the root directory.
17
+ - Participants must use OpenAI Client for all LLM calls using above variables.
18
+
19
+ STDOUT FORMAT
20
+ - The script must emit exactly three line types to stdout, in this order:
21
+
22
+ [START] task=<task_name> env=<benchmark> model=<model_name>
23
+ [STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null>
24
+ [END] success=<true|false> steps=<n> score=<score> rewards=<r1,r2,...,rn>
25
+
26
+ Rules:
27
+ - One [START] line at episode begin.
28
+ - One [STEP] line per step, immediately after env.step() returns.
29
+ - One [END] line after the task finishes, always emitted (even on exception).
30
+ - reward and rewards are formatted to 2 decimal places.
31
+ - done and success are lowercase booleans: true or false.
32
+ - error is the raw last_action_error string, or null if none.
33
+ - All fields on a single line with no newlines within a line.
34
+ - Each task should return score in [0, 1].
35
 
36
  Designed to complete in under 20 minutes on 2 vCPU / 8 GB RAM.
37
  """
 
42
  import logging
43
  import os
44
  import sys
45
+ import textwrap
46
+ from typing import Any, List, Optional
47
+
48
+ from openai import OpenAI
49
 
50
  # ---------------------------------------------------------------------------
51
+ # Silence the underlying env's stdout JSON logs (redirect them to stderr)
 
52
  # ---------------------------------------------------------------------------
53
+ env_logger = logging.getLogger("clinical_note_scribe")
54
+ env_logger.setLevel(logging.INFO)
55
+ env_logger.handlers.clear()
56
+ env_logger.addHandler(logging.StreamHandler(sys.stderr))
57
+ env_logger.propagate = False
58
 
 
 
 
 
 
 
59
 
60
  # ---------------------------------------------------------------------------
61
+ # Environment imports
62
  # ---------------------------------------------------------------------------
63
 
64
  sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
 
70
  # Config
71
  # ---------------------------------------------------------------------------
72
 
73
+ IMAGE_NAME = os.getenv("IMAGE_NAME")
74
+ API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
75
+ API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1"
76
+ MODEL_NAME = os.getenv("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct"
77
 
78
+ BENCHMARK = "clinical-note-scribe"
79
+ TASK_IDS = list(TASK_REGISTRY.keys())
80
+ MAX_STEPS = 5 # Max steps per task (submit + optional clarify/revise)
81
+ MAX_TOKENS = 1024
82
+ TEMPERATURE = 0.2
83
 
84
  # ---------------------------------------------------------------------------
85
  # System prompt
86
  # ---------------------------------------------------------------------------
87
 
88
+ SYSTEM_PROMPT = textwrap.dedent("""\
89
+ You are a clinical documentation assistant. Given a doctor-patient transcript
90
+ and patient context, generate a concise, clinically accurate SOAP note.
91
+
92
+ RULES:
93
+ 1. Use professional medical language. Avoid over-certain phrasing such as
94
+ "patient definitely has", "diagnosis is certain", or "100% certain".
95
+ 2. Keep the note concise - aim for under 400 words total across all four sections.
96
+ 3. Return your output as a **single valid JSON object** matching this schema exactly:
97
+
98
+ {
99
+ "action_type": "submit_note",
100
+ "soap_note": {
101
+ "subjective": "<patient's reported symptoms, history, and concerns>",
102
+ "objective": "<exam findings, vitals, lab results, imaging>",
103
+ "assessment": "<differential diagnoses and clinical reasoning>",
104
+ "plan": "<treatment plan, medications, follow-up, referrals>"
105
+ }
106
+ }
107
+
108
+ Return ONLY the JSON object. No markdown fences, no commentary, no extra keys.
109
+ """).strip()
110
+
111
 
112
  # ---------------------------------------------------------------------------
113
+ # Stdout logging — mandatory hackathon format
114
  # ---------------------------------------------------------------------------
115
 
116
def log_start(task: str, env: str, model: str) -> None:
    """Emit the mandatory ``[START]`` line for an episode on stdout."""
    fields = f"task={task} env={env} model={model}"
    print("[START] " + fields, flush=True)
118
 
 
 
 
 
 
 
 
 
119
 
120
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Emit the mandatory ``[STEP]`` line on stdout.

    A falsy ``error`` (None or empty string) is rendered as the literal
    ``null``; ``done`` is rendered as a lowercase boolean.
    """
    parts = [
        f"[STEP] step={step}",
        f"action={action}",
        f"reward={reward:.2f}",
        f"done={str(done).lower()}",
        f"error={error or 'null'}",
    ]
    print(" ".join(parts), flush=True)
127
 
 
 
128
 
129
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Emit the mandatory ``[END]`` line on stdout.

    ``rewards`` is rendered as a comma-joined list of 2-decimal values
    (empty string when no steps were taken).
    """
    joined = ",".join(format(r, ".2f") for r in rewards)
    line = (
        f"[END] success={str(success).lower()} "
        f"steps={steps} score={score:.2f} rewards={joined}"
    )
    print(line, flush=True)
 
 
135
 
136
 
137
+ # ---------------------------------------------------------------------------
138
+ # Helpers
139
+ # ---------------------------------------------------------------------------
140
 
141
+ def _build_user_prompt(transcript: str, patient_context: dict[str, Any]) -> str:
142
+ """Build the user message containing the transcript and context."""
143
+ ctx_str = json.dumps(patient_context, indent=2, default=str)
144
+ return (
145
+ f"## Patient Context\n```json\n{ctx_str}\n```\n\n"
146
+ f"## Doctor-Patient Transcript\n```\n{transcript}\n```\n\n"
147
+ "Generate the SOAP note as a JSON Action object."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
  )
 
 
 
 
 
149
 
150
 
151
  def _parse_json(raw: str) -> dict[str, Any]:
152
  """Parse the model's raw text output into a dict, tolerating markdown fences."""
153
+ cleaned = raw.strip()
 
154
  if cleaned.startswith("```"):
 
155
  first_newline = cleaned.index("\n")
156
  cleaned = cleaned[first_newline + 1:]
157
  if cleaned.endswith("```"):
158
+ cleaned = cleaned[:-3]
159
  cleaned = cleaned.strip()
160
 
161
  try:
162
  return json.loads(cleaned)
163
  except json.JSONDecodeError as exc:
164
+ print(f"[DEBUG] Failed to parse model output as JSON: {exc}", file=sys.stderr, flush=True)
165
+ print(f"[DEBUG] Raw output:\n{raw}", file=sys.stderr, flush=True)
166
  raise
167
 
168
 
169
def get_soap_note(client: OpenAI, transcript: str, patient_context: dict[str, Any]) -> dict[str, Any]:
    """Call the OpenAI-compatible API and return the parsed JSON action dict.

    Failures (transport errors, bad JSON) are logged to stderr and re-raised
    so the caller can fall back to a minimal note.
    """
    prompt = _build_user_prompt(transcript, patient_context)
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": prompt},
    ]
    try:
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=messages,
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS,
            stream=False,
        )
        text = (completion.choices[0].message.content or "").strip()
    except Exception as exc:
        print(f"[DEBUG] Model request failed: {exc}", file=sys.stderr, flush=True)
        raise
    try:
        return _parse_json(text)
    except Exception as exc:
        print(f"[DEBUG] Model request failed: {exc}", file=sys.stderr, flush=True)
        raise
188
 
189
 
190
  # ---------------------------------------------------------------------------
191
+ # Per-task runner
192
  # ---------------------------------------------------------------------------
193
 
194
def run_task(client: OpenAI, env: ClinicalNoteScribeEnv, task_id: str) -> dict[str, Any]:
    """Run a single task episode and return its result dict.

    Emits exactly one ``[START]`` line, one ``[STEP]`` line per
    ``env.step()`` call, and one ``[END]`` line (always, even on exception)
    in the mandatory stdout format.

    Returns:
        dict with keys ``task_id``, ``score`` (clamped to [0, 1]),
        ``steps``, ``rewards``, and ``success``.
    """
    rewards: List[float] = []
    steps_taken = 0
    score = 0.0
    success = False
    last_error: Optional[str] = None  # LLM/parse failure from the current step only

    log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)

    try:
        # ---- reset ----
        obs = env.reset(task_id)

        for step in range(1, MAX_STEPS + 1):
            # ---- generate SOAP note via LLM ----
            try:
                action_dict = get_soap_note(client, obs.transcript, obs.patient_context)
                action = Action(**action_dict)
                action_str = "submit_note(sections=S,O,A,P)"
            except Exception as exc:
                # On model / parse failure, submit a minimal note so the
                # episode still terminates instead of hanging.
                action = Action(
                    action_type="submit_note",
                    soap_note=SOAPNote(
                        subjective="Unable to generate.",
                        objective="Unable to generate.",
                        assessment="Unable to generate.",
                        plan="Unable to generate.",
                    ),
                )
                action_str = "submit_note(fallback)"
                last_error = str(exc)

            # ---- step ----
            obs, reward_obj, done, info = env.step(action)

            reward_val = reward_obj.value
            rewards.append(reward_val)
            steps_taken = step

            # Prefer the env-reported error; otherwise surface the LLM error.
            # last_error is cleared unconditionally so a stale LLM error from
            # this step can never leak into a later step's [STEP] line
            # (previously it was only cleared on the elif branch).
            error_msg = None
            if obs.errors_so_far:
                error_msg = obs.errors_so_far[-1]
            elif last_error:
                error_msg = last_error
            last_error = None

            log_step(
                step=step,
                action=action_str,
                reward=reward_val,
                done=done,
                error=error_msg,
            )

            if done:
                break

        # Final score = last reward value, clamped to [0, 1].
        score = rewards[-1] if rewards else 0.0
        score = min(max(score, 0.0), 1.0)
        success = score > 0.0

    except Exception as exc:
        # Episode-level failure (reset/step raised): record a zero score.
        print(f"[DEBUG] Task {task_id} failed: {exc}", file=sys.stderr, flush=True)
        score = 0.0
        success = False

    finally:
        # [END] is always emitted, even when the episode raised.
        log_end(success=success, steps=steps_taken, score=score, rewards=rewards)

    return {
        "task_id": task_id,
        "score": score,
        "steps": steps_taken,
        "rewards": rewards,
        "success": success,
    }
274
 
275
 
276
  # ---------------------------------------------------------------------------
277
+ # Main
278
  # ---------------------------------------------------------------------------
279
 
280
def main() -> None:
    """Run every registered task sequentially and print a summary to stderr.

    Stdout is reserved for the mandatory [START]/[STEP]/[END] lines emitted
    by run_task; all diagnostics and the summary table go to stderr.
    """
    if not API_KEY:
        print(
            "[DEBUG] WARNING: HF_TOKEN / API_KEY is not set. "
            "Model calls will fail unless the endpoint requires no auth.",
            file=sys.stderr,
            flush=True,
        )

    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
    env = ClinicalNoteScribeEnv()
    results: List[dict[str, Any]] = []

    for task_id in TASK_IDS:
        results.append(run_task(client, env, task_id))

    # ---- Summary table (stderr only) ----
    print("", file=sys.stderr, flush=True)
    print("=" * 60, file=sys.stderr, flush=True)
    print(" SUMMARY", file=sys.stderr, flush=True)
    print("=" * 60, file=sys.stderr, flush=True)

    # max(..., default=0) avoids the TypeError the previous
    # max(len("Task"), *generator) form raised when results was empty.
    col_task = max(max((len(r["task_id"]) for r in results), default=0), len("Task"))
    header = f" {'Task':<{col_task}} {'Score':>7} {'Steps':>5}"
    sep = f" {'-' * col_task} {'-' * 7} {'-' * 5}"
    print(header, file=sys.stderr, flush=True)
    print(sep, file=sys.stderr, flush=True)

    total_score = 0.0
    for r in results:
        # NOTE(review): a task that legitimately scored 0.0 also renders as
        # ERROR because success is derived from score > 0 in run_task.
        s = f"{r['score']:.4f}" if r["success"] else "ERROR"
        print(f" {r['task_id']:<{col_task}} {s:>7} {r['steps']:>5}", file=sys.stderr, flush=True)
        total_score += r["score"]

    print(sep, file=sys.stderr, flush=True)
    avg = total_score / len(results) if results else 0.0
    print(f" {'AVERAGE':<{col_task}} {avg:>7.4f}", file=sys.stderr, flush=True)
    print("", file=sys.stderr, flush=True)


if __name__ == "__main__":
    main()
out.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ [START] task=easy_routine_checkup env=clinical-note-scribe model=gpt-4o-mini
2
+ [STEP] step=1 action=submit_note(fallback) reward=0.70 done=true error=Error code: 401 - {'error': 'Invalid username or password.'}
3
+ [END] success=true steps=1 score=0.70 rewards=0.70
4
+ [START] task=medium_chronic_disease_followup env=clinical-note-scribe model=gpt-4o-mini
5
+ [STEP] step=1 action=submit_note(fallback) reward=0.70 done=true error=Error code: 401 - {'error': 'Invalid username or password.'}
6
+ [END] success=true steps=1 score=0.70 rewards=0.70
7
+ [START] task=hard_complex_er_visit env=clinical-note-scribe model=gpt-4o-mini
8
+ [STEP] step=1 action=submit_note(fallback) reward=0.70 done=true error=Error code: 401 - {'error': 'Invalid username or password.'}
9
+ [END] success=true steps=1 score=0.70 rewards=0.70