Aman Khare committed on
Commit
8b7bdb7
·
1 Parent(s): 3c850f8

Optimize codebase + add minimalist frontend

Browse files
environment/env.py CHANGED
@@ -1,18 +1,4 @@
1
- """ClinicalNoteScribeEnv — core environment loop.
2
-
3
- Implements the ``reset() → Observation``, ``step(Action) → (Observation, Reward, bool, dict)``,
4
- and ``state() → EnvironmentState`` interface required by the OpenEnv spec.
5
-
6
- Structured logging
7
- ------------------
8
- Every episode emits exactly three kinds of JSON log lines to **stdout**:
9
-
10
- - ``{"event": "START", "task_id": "...", "timestamp": ...}``
11
- - ``{"event": "STEP", "step": N, "action_type": "...", "reward": R}``
12
- - ``{"event": "END", "task_id": "...", "final_score": S}``
13
-
14
- The OpenEnv validator scrapes ``[START]``, ``[STEP]``, ``[END]`` keywords.
15
- """
16
 
17
  from __future__ import annotations
18
 
@@ -27,115 +13,55 @@
27
  from environment.tasks import GRADER_REGISTRY, TASK_REGISTRY
28
 
29
  logger = logging.getLogger("clinical_note_scribe")
 
30
 
31
 
32
- # ---------------------------------------------------------------------------
33
- # Helpers
34
- # ---------------------------------------------------------------------------
35
-
36
- def _load_transcript(transcript_path: str) -> str:
37
- """Load a transcript text file from *project-root-relative* path."""
38
- base = Path(__file__).resolve().parent.parent # clinical-note-scribe/
39
- full_path = base / transcript_path
40
- if full_path.exists():
41
- return full_path.read_text(encoding="utf-8")
42
- return f"[Transcript file not found: {transcript_path}]"
43
-
44
-
45
- def _log_event(event: str, **kwargs: Any) -> None:
46
- """Emit a structured JSON log line to stdout via the logger."""
47
- payload: dict[str, Any] = {"event": event, "timestamp": time.time()}
48
- payload.update(kwargs)
49
- logger.info(json.dumps(payload))
50
 
51
 
52
- def _soap_to_text(soap: SOAPNote) -> str:
53
- """Flatten a SOAPNote into a readable multi-line string."""
54
- return (
55
- f"S: {soap.subjective}\n"
56
- f"O: {soap.objective}\n"
57
- f"A: {soap.assessment}\n"
58
- f"P: {soap.plan}"
59
- )
60
-
61
 
62
- # ---------------------------------------------------------------------------
63
- # Main environment class
64
- # ---------------------------------------------------------------------------
 
65
 
66
- class ClinicalNoteScribeEnv:
67
- """Open-environment wrapper for the clinical note-scribe tasks.
68
-
69
- Lifecycle
70
- ---------
71
- 1. ``env.reset(task_id)`` → returns initial ``Observation``
72
- 2. ``env.step(action)`` → returns ``(Observation, Reward, done, info)``
73
- 3. ``env.state()`` → returns full ``EnvironmentState`` snapshot
74
-
75
- Parameters
76
- ----------
77
- clarify_answers_path:
78
- Project-root-relative path to the clarification lookup JSON.
79
- """
80
-
81
- def __init__(
82
- self,
83
- clarify_answers_path: str = "data/clarify_answers.json",
84
- ) -> None:
85
- self._clarify_answers: dict[str, str] = {}
86
- base = Path(__file__).resolve().parent.parent
87
- ca_path = base / clarify_answers_path
88
- if ca_path.exists():
89
- self._clarify_answers = json.loads(ca_path.read_text(encoding="utf-8"))
90
-
91
- # Episode state (initialised properly in reset())
92
  self._task: dict[str, Any] = {}
93
- self._task_id: str = ""
94
- self._transcript: str = ""
95
  self._patient_context: dict[str, Any] = {}
96
- self._max_steps: int = 10
97
- self._step_count: int = 0
98
- self._done: bool = True
99
  self._current_draft: str | None = None
100
  self._errors_so_far: list[str] = []
101
  self._last_reward: Reward | None = None
102
  self._last_observation: Observation | None = None
103
 
104
- # --------------------------------------------------------------------- #
 
 
 
 
 
 
 
 
 
105
  # Public API
106
- # --------------------------------------------------------------------- #
107
 
108
  def reset(self, task_id: str | None = None) -> Observation:
109
- """Start (or restart) an episode for the given *task_id*.
110
-
111
- Parameters
112
- ----------
113
- task_id:
114
- One of the keys in ``TASK_REGISTRY``. When ``None`` the first
115
- registered task is used.
116
-
117
- Returns
118
- -------
119
- Observation
120
- The initial observation for the episode.
121
-
122
- Raises
123
- ------
124
- ValueError
125
- If *task_id* is not found in the registry.
126
- """
127
- if task_id is None:
128
- task_id = next(iter(TASK_REGISTRY))
129
-
130
  if task_id not in TASK_REGISTRY:
131
- available = ", ".join(TASK_REGISTRY.keys())
132
- raise ValueError(
133
- f"Unknown task_id '{task_id}'. Available: {available}"
134
- )
135
 
136
  self._task = TASK_REGISTRY[task_id]
137
  self._task_id = task_id
138
- self._transcript = _load_transcript(self._task["transcript_file"])
 
139
  self._patient_context = self._task.get("patient_context", {})
140
  self._max_steps = self._task.get("max_steps", 10)
141
  self._step_count = 0
@@ -144,254 +70,118 @@ def reset(self, task_id: str | None = None) -> Observation:
144
  self._errors_so_far = []
145
  self._last_reward = None
146
 
147
- _log_event("START", task_id=self._task_id)
148
-
149
- obs = self._build_observation()
150
- self._last_observation = obs
151
- return obs
152
 
153
  def step(self, action: Action) -> tuple[Observation, Reward, bool, dict]:
154
- """Execute one agent action and return the resulting observation, reward,
155
- done flag, and info dict.
156
-
157
- Parameters
158
- ----------
159
- action:
160
- The agent's chosen action.
161
-
162
- Returns
163
- -------
164
- tuple[Observation, Reward, bool, dict]
165
- """
166
  if self._done:
167
- raise RuntimeError(
168
- "Episode is done. Call reset() before stepping again."
169
- )
170
 
171
  self._step_count += 1
172
  info: dict[str, Any] = {}
173
 
174
- # ---- dispatch by action type ----
175
- if action.action_type == "submit_note":
176
- reward = self._handle_submit(action, info)
177
- elif action.action_type == "request_clarify":
178
- reward = self._handle_clarify(action, info)
179
- elif action.action_type == "revise_section":
180
- reward = self._handle_revise(action, info)
 
 
181
  else:
182
- # Should never happen thanks to the Literal type, but be safe
183
- self._errors_so_far.append(
184
- f"Unknown action_type: {action.action_type}"
185
- )
186
- reward = compute_reward(
187
- action,
188
- grader_score=0.0,
189
- step_count=self._step_count,
190
- errors_so_far=self._errors_so_far,
191
- done=False,
192
- info={"error": "bad_action"},
193
- )
194
-
195
- # ---- enforce max-step termination ----
196
  if self._step_count >= self._max_steps and not self._done:
197
  self._done = True
198
- reward = Reward(
199
- value=reward.value,
200
- signals=reward.signals,
201
- done=True,
202
- info={**reward.info, "termination_reason": "max_steps_reached"},
203
- )
204
 
205
  self._last_reward = reward
206
-
207
- _log_event(
208
- "STEP",
209
- step=self._step_count,
210
- action_type=action.action_type,
211
- reward=reward.value,
212
- )
213
-
214
  if self._done:
215
- _log_event(
216
- "END",
217
- task_id=self._task_id,
218
- final_score=reward.value,
219
- )
220
 
221
- obs = self._build_observation()
222
- self._last_observation = obs
223
- return obs, reward, self._done, info
224
 
225
  def state(self) -> EnvironmentState:
226
- """Return the full internal state snapshot."""
227
  return EnvironmentState(
228
- task_id=self._task_id,
229
- step_count=self._step_count,
230
- max_steps=self._max_steps,
231
- done=self._done,
232
- current_draft=self._current_draft,
233
  errors_so_far=list(self._errors_so_far),
234
- last_reward=self._last_reward,
235
- observation=self._last_observation,
236
  )
237
 
238
- # --------------------------------------------------------------------- #
239
  # Action handlers
240
- # --------------------------------------------------------------------- #
241
-
242
- def _handle_submit(self, action: Action, info: dict) -> Reward:
243
- """Process a ``submit_note`` action.
244
 
245
- If ``action.soap_note`` is provided, it is used directly.
246
- Otherwise, if the agent has built up a draft via ``revise_section``,
247
- the draft is parsed into a SOAPNote automatically.
248
- """
249
  soap = action.soap_note
250
 
251
- # Fall back to the current draft if no explicit note is provided
252
  if soap is None and self._current_draft:
253
- sections: dict[str, str] = {}
254
  for line in self._current_draft.split("\n"):
255
- for prefix in ("S: ", "O: ", "A: ", "P: "):
256
- if line.startswith(prefix):
257
- sections[prefix[0]] = line[len(prefix):]
258
- if all(k in sections for k in "SOAP"):
259
- soap = SOAPNote(
260
- subjective=sections["S"],
261
- objective=sections["O"],
262
- assessment=sections["A"],
263
- plan=sections["P"],
264
- )
265
 
266
  if soap is None:
267
- error = "submit_note requires a non-null soap_note (or a complete draft from revise_section)."
268
- self._errors_so_far.append(error)
269
- return compute_reward(
270
- action,
271
- grader_score=0.0,
272
- step_count=self._step_count,
273
- errors_so_far=self._errors_so_far,
274
- done=False,
275
- info={"error": error},
276
- )
277
-
278
- self._current_draft = _soap_to_text(soap)
279
  self._done = True
280
 
281
- # Attempt to grade via the task-specific grader
282
  grader = GRADER_REGISTRY.get(self._task_id)
283
- if grader is None:
284
  info["warning"] = "No grader registered; returning default reward."
285
- return compute_reward(
286
- action,
287
- grader_score=0.5,
288
- step_count=self._step_count,
289
- errors_so_far=self._errors_so_far,
290
- done=True,
291
- info=info,
292
- )
293
 
294
  try:
295
- raw_signals = grader(soap, self._task)
296
- # Grader returns a signals dict; extract a single scalar score
297
- # as the mean of its values for use as grader_score.
298
- grader_score = (
299
- sum(raw_signals.values()) / len(raw_signals)
300
- if raw_signals else 0.0
301
- )
302
- info["grader_signals"] = raw_signals
303
  except Exception as exc:
304
  info["warning"] = f"Grader error: {exc}"
305
- grader_score = 0.0
306
 
307
- return compute_reward(
308
- action,
309
- grader_score=grader_score,
310
- step_count=self._step_count,
311
- errors_so_far=self._errors_so_far,
312
- done=True,
313
- info=info,
314
- )
315
 
316
- def _handle_clarify(self, action: Action, info: dict) -> Reward:
317
- """Process a ``request_clarify`` action."""
318
- question = (action.clarify_question or "").strip()
319
- if not question:
320
- error = "request_clarify requires a non-empty clarify_question."
321
- self._errors_so_far.append(error)
322
- return Reward(
323
- value=0.0,
324
- signals={"error": 1.0},
325
- done=False,
326
- info={"error": error},
327
- )
328
-
329
- # Lookup a canned answer (case-insensitive key match)
330
- answer = self._clarify_answers.get(question.lower())
331
- if answer:
332
- info["clarify_answer"] = answer
333
- else:
334
- info["clarify_answer"] = (
335
- "No additional information available for that question."
336
- )
337
-
338
- # Intermediate actions get zero reward — only submit_note earns score
339
- return Reward(
340
- value=0.0,
341
- signals={"intermediate_step": 1.0},
342
- done=False,
343
- info=info,
344
- )
345
 
346
- def _handle_revise(self, action: Action, info: dict) -> Reward:
347
- """Process a ``revise_section`` action."""
348
  if action.section is None or action.revision_text is None:
349
- error = "revise_section requires both 'section' and 'revision_text'."
350
- self._errors_so_far.append(error)
351
- return Reward(
352
- value=0.0,
353
- signals={"error": 1.0},
354
- done=False,
355
- info={"error": error},
356
- )
357
-
358
- # If there is an existing draft, patch the requested section
359
  if self._current_draft:
360
  lines = self._current_draft.split("\n")
361
- prefix = f"{action.section}: "
362
- patched = False
363
  for i, line in enumerate(lines):
364
  if line.startswith(prefix):
365
- lines[i] = f"{prefix}{action.revision_text}"
366
- patched = True
367
  break
368
- if patched:
369
- self._current_draft = "\n".join(lines)
370
  else:
371
  self._current_draft += f"\n{prefix}{action.revision_text}"
372
  else:
373
- self._current_draft = f"{action.section}: {action.revision_text}"
374
 
375
  info["revised_section"] = action.section
376
-
377
- # Intermediate actions get zero reward — only submit_note earns score
378
- return Reward(
379
- value=0.0,
380
- signals={"intermediate_step": 1.0},
381
- done=False,
382
- info=info,
383
- )
384
-
385
- # --------------------------------------------------------------------- #
386
- # Internal helpers
387
- # --------------------------------------------------------------------- #
388
-
389
- def _build_observation(self) -> Observation:
390
- return Observation(
391
- transcript=self._transcript,
392
- task_id=self._task_id,
393
- patient_context=self._patient_context,
394
- current_draft=self._current_draft,
395
- errors_so_far=list(self._errors_so_far),
396
- step_count=self._step_count,
397
- )
 
1
+ """ClinicalNoteScribeEnv — core environment implementing reset/step/state."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  from __future__ import annotations
4
 
 
13
  from environment.tasks import GRADER_REGISTRY, TASK_REGISTRY
14
 
15
  logger = logging.getLogger("clinical_note_scribe")
16
+ _ROOT = Path(__file__).resolve().parent.parent
17
 
18
 
19
+ def _log(event: str, **kw: Any) -> None:
20
+ logger.info(json.dumps({"event": event, "timestamp": time.time(), **kw}))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
 
23
+ class ClinicalNoteScribeEnv:
24
+ """OpenEnv-compliant environment for clinical SOAP-note generation."""
 
 
 
 
 
 
 
25
 
26
+ def __init__(self, clarify_answers_path: str = "data/clarify_answers.json") -> None:
27
+ ca = _ROOT / clarify_answers_path
28
+ self._clarify_answers: dict[str, str] = json.loads(ca.read_text()) if ca.exists() else {}
29
+ self._reset_state()
30
 
31
+ def _reset_state(self) -> None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  self._task: dict[str, Any] = {}
33
+ self._task_id = ""
34
+ self._transcript = ""
35
  self._patient_context: dict[str, Any] = {}
36
+ self._max_steps = 10
37
+ self._step_count = 0
38
+ self._done = True
39
  self._current_draft: str | None = None
40
  self._errors_so_far: list[str] = []
41
  self._last_reward: Reward | None = None
42
  self._last_observation: Observation | None = None
43
 
44
+ def _obs(self) -> Observation:
45
+ return Observation(
46
+ transcript=self._transcript,
47
+ task_id=self._task_id,
48
+ patient_context=self._patient_context,
49
+ current_draft=self._current_draft,
50
+ errors_so_far=list(self._errors_so_far),
51
+ step_count=self._step_count,
52
+ )
53
+
54
  # Public API
 
55
 
56
  def reset(self, task_id: str | None = None) -> Observation:
57
+ task_id = task_id or next(iter(TASK_REGISTRY))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  if task_id not in TASK_REGISTRY:
59
+ raise ValueError(f"Unknown task_id '{task_id}'. Available: {', '.join(TASK_REGISTRY)}")
 
 
 
60
 
61
  self._task = TASK_REGISTRY[task_id]
62
  self._task_id = task_id
63
+ path = _ROOT / self._task["transcript_file"]
64
+ self._transcript = path.read_text(encoding="utf-8") if path.exists() else f"[Not found: {path}]"
65
  self._patient_context = self._task.get("patient_context", {})
66
  self._max_steps = self._task.get("max_steps", 10)
67
  self._step_count = 0
 
70
  self._errors_so_far = []
71
  self._last_reward = None
72
 
73
+ _log("START", task_id=task_id)
74
+ self._last_observation = self._obs()
75
+ return self._last_observation
 
 
76
 
77
  def step(self, action: Action) -> tuple[Observation, Reward, bool, dict]:
 
 
 
 
 
 
 
 
 
 
 
 
78
  if self._done:
79
+ raise RuntimeError("Episode is done. Call reset() before stepping again.")
 
 
80
 
81
  self._step_count += 1
82
  info: dict[str, Any] = {}
83
 
84
+ # Dispatch
85
+ handler = {
86
+ "submit_note": self._submit,
87
+ "request_clarify": self._clarify,
88
+ "revise_section": self._revise,
89
+ }.get(action.action_type)
90
+
91
+ if handler:
92
+ reward = handler(action, info)
93
  else:
94
+ self._errors_so_far.append(f"Unknown action_type: {action.action_type}")
95
+ reward = compute_reward(action, 0.0, self._step_count, self._errors_so_far, done=False, info={"error": "bad_action"})
96
+
97
+ # Max-step termination
 
 
 
 
 
 
 
 
 
 
98
  if self._step_count >= self._max_steps and not self._done:
99
  self._done = True
100
+ reward = Reward(value=reward.value, signals=reward.signals, done=True,
101
+ info={**reward.info, "termination_reason": "max_steps_reached"})
 
 
 
 
102
 
103
  self._last_reward = reward
104
+ _log("STEP", step=self._step_count, action_type=action.action_type, reward=reward.value)
 
 
 
 
 
 
 
105
  if self._done:
106
+ _log("END", task_id=self._task_id, final_score=reward.value)
 
 
 
 
107
 
108
+ self._last_observation = self._obs()
109
+ return self._last_observation, reward, self._done, info
 
110
 
111
  def state(self) -> EnvironmentState:
 
112
  return EnvironmentState(
113
+ task_id=self._task_id, step_count=self._step_count, max_steps=self._max_steps,
114
+ done=self._done, current_draft=self._current_draft,
 
 
 
115
  errors_so_far=list(self._errors_so_far),
116
+ last_reward=self._last_reward, observation=self._last_observation,
 
117
  )
118
 
 
119
  # Action handlers
 
 
 
 
120
 
121
+ def _submit(self, action: Action, info: dict) -> Reward:
 
 
 
122
  soap = action.soap_note
123
 
124
+ # Fall back to draft
125
  if soap is None and self._current_draft:
126
+ secs = {}
127
  for line in self._current_draft.split("\n"):
128
+ for p in ("S: ", "O: ", "A: ", "P: "):
129
+ if line.startswith(p):
130
+ secs[p[0]] = line[len(p):]
131
+ if all(k in secs for k in "SOAP"):
132
+ soap = SOAPNote(subjective=secs["S"], objective=secs["O"], assessment=secs["A"], plan=secs["P"])
 
 
 
 
 
133
 
134
  if soap is None:
135
+ err = "submit_note requires a non-null soap_note (or a complete draft from revise_section)."
136
+ self._errors_so_far.append(err)
137
+ return compute_reward(action, 0.0, self._step_count, self._errors_so_far, done=False, info={"error": err})
138
+
139
+ self._current_draft = f"S: {soap.subjective}\nO: {soap.objective}\nA: {soap.assessment}\nP: {soap.plan}"
 
 
 
 
 
 
 
140
  self._done = True
141
 
 
142
  grader = GRADER_REGISTRY.get(self._task_id)
143
+ if not grader:
144
  info["warning"] = "No grader registered; returning default reward."
145
+ return compute_reward(action, 0.5, self._step_count, self._errors_so_far, done=True, info=info)
 
 
 
 
 
 
 
146
 
147
  try:
148
+ signals = grader(soap, self._task)
149
+ score = sum(signals.values()) / len(signals) if signals else 0.0
150
+ info["grader_signals"] = signals
 
 
 
 
 
151
  except Exception as exc:
152
  info["warning"] = f"Grader error: {exc}"
153
+ score = 0.0
154
 
155
+ return compute_reward(action, score, self._step_count, self._errors_so_far, done=True, info=info)
 
 
 
 
 
 
 
156
 
157
+ def _clarify(self, action: Action, info: dict) -> Reward:
158
+ q = (action.clarify_question or "").strip()
159
+ if not q:
160
+ err = "request_clarify requires a non-empty clarify_question."
161
+ self._errors_so_far.append(err)
162
+ return Reward(value=0.0, signals={"error": 1.0}, done=False, info={"error": err})
163
+
164
+ info["clarify_answer"] = self._clarify_answers.get(q.lower(), "No additional information available for that question.")
165
+ return Reward(value=0.0, signals={"intermediate_step": 1.0}, done=False, info=info)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
 
167
+ def _revise(self, action: Action, info: dict) -> Reward:
 
168
  if action.section is None or action.revision_text is None:
169
+ err = "revise_section requires both 'section' and 'revision_text'."
170
+ self._errors_so_far.append(err)
171
+ return Reward(value=0.0, signals={"error": 1.0}, done=False, info={"error": err})
172
+
173
+ prefix = f"{action.section}: "
 
 
 
 
 
174
  if self._current_draft:
175
  lines = self._current_draft.split("\n")
 
 
176
  for i, line in enumerate(lines):
177
  if line.startswith(prefix):
178
+ lines[i] = prefix + action.revision_text
179
+ self._current_draft = "\n".join(lines)
180
  break
 
 
181
  else:
182
  self._current_draft += f"\n{prefix}{action.revision_text}"
183
  else:
184
+ self._current_draft = prefix + action.revision_text
185
 
186
  info["revised_section"] = action.section
187
+ return Reward(value=0.0, signals={"intermediate_step": 1.0}, done=False, info=info)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
environment/models.py CHANGED
@@ -1,175 +1,47 @@
1
- """Pydantic v2 models for the Clinical Note Scribe environment.
2
-
3
- Defines the typed contracts for observations, actions, rewards,
4
- and overall environment state used by the OpenEnv spec.
5
- """
6
 
7
  from __future__ import annotations
8
-
9
  from typing import Any, Literal, Optional
10
-
11
  from pydantic import BaseModel, Field
12
 
13
 
14
- # ---------------------------------------------------------------------------
15
- # Observation — what the agent sees after each step
16
- # ---------------------------------------------------------------------------
17
-
18
  class Observation(BaseModel):
19
- """Snapshot of the environment returned to the agent."""
 
 
 
 
 
20
 
21
- transcript: str = Field(
22
- ...,
23
- description="Full doctor–patient transcript for the current task.",
24
- )
25
- task_id: str = Field(
26
- ...,
27
- description="Unique identifier for the task (e.g. 'easy_routine_checkup').",
28
- )
29
- patient_context: dict[str, Any] = Field(
30
- default_factory=dict,
31
- description="Structured patient demographics and history.",
32
- )
33
- current_draft: Optional[str] = Field(
34
- default=None,
35
- description="The agent's most recent SOAP-note draft, if any.",
36
- )
37
- errors_so_far: list[str] = Field(
38
- default_factory=list,
39
- description="Accumulated error/feedback messages from prior steps.",
40
- )
41
- step_count: int = Field(
42
- default=0,
43
- ge=0,
44
- description="Number of steps taken in the current episode.",
45
- )
46
-
47
-
48
- # ---------------------------------------------------------------------------
49
- # SOAPNote — structured clinical note
50
- # ---------------------------------------------------------------------------
51
 
52
  class SOAPNote(BaseModel):
53
- """Standard SOAP clinical-note format."""
54
-
55
- subjective: str = Field(
56
- ...,
57
- description="Patient's self-reported symptoms and history.",
58
- )
59
- objective: str = Field(
60
- ...,
61
- description="Clinician's measurable findings (vitals, exam, labs).",
62
- )
63
- assessment: str = Field(
64
- ...,
65
- description="Clinician's diagnosis or differential.",
66
- )
67
- plan: str = Field(
68
- ...,
69
- description="Treatment plan, follow-ups, and prescriptions.",
70
- )
71
-
72
 
73
- # ---------------------------------------------------------------------------
74
- # Action — what the agent can do
75
- # ---------------------------------------------------------------------------
76
 
77
  class Action(BaseModel):
78
- """An action the agent submits to the environment."""
 
 
 
 
79
 
80
- action_type: Literal["submit_note", "request_clarify", "revise_section"] = Field(
81
- ...,
82
- description="The kind of action the agent is taking.",
83
- )
84
-
85
- # --- submit_note fields ---
86
- soap_note: Optional[SOAPNote] = Field(
87
- default=None,
88
- description="Complete SOAP note (required when action_type == 'submit_note').",
89
- )
90
-
91
- # --- revise_section fields ---
92
- section: Optional[Literal["S", "O", "A", "P"]] = Field(
93
- default=None,
94
- description="Which SOAP section to revise (required when action_type == 'revise_section').",
95
- )
96
- revision_text: Optional[str] = Field(
97
- default=None,
98
- description="Replacement text for the specified section.",
99
- )
100
-
101
- # --- request_clarify fields ---
102
- clarify_question: Optional[str] = Field(
103
- default=None,
104
- description="Free-text question the agent asks for clarification.",
105
- )
106
-
107
-
108
- # ---------------------------------------------------------------------------
109
- # Reward — multi-signal feedback
110
- # ---------------------------------------------------------------------------
111
 
112
  class Reward(BaseModel):
113
- """Reward returned after each step."""
114
-
115
- value: float = Field(
116
- ...,
117
- ge=0.0,
118
- le=1.0,
119
- description="Aggregate reward in the range [0.0, 1.0].",
120
- )
121
- signals: dict[str, float] = Field(
122
- default_factory=dict,
123
- description="Breakdown of individual reward sub-signals.",
124
- )
125
- done: bool = Field(
126
- ...,
127
- description="Whether the episode has ended.",
128
- )
129
- info: dict[str, Any] = Field(
130
- default_factory=dict,
131
- description="Auxiliary metadata (e.g. grader diagnostics).",
132
- )
133
 
134
 
135
- # ---------------------------------------------------------------------------
136
- # EnvironmentState — full internal state exposed by state()
137
- # ---------------------------------------------------------------------------
138
-
139
  class EnvironmentState(BaseModel):
140
- """Complete snapshot of the environment's internal state."""
141
-
142
- task_id: str = Field(
143
- ...,
144
- description="Active task identifier.",
145
- )
146
- step_count: int = Field(
147
- default=0,
148
- ge=0,
149
- description="Steps taken so far in this episode.",
150
- )
151
- max_steps: int = Field(
152
- default=10,
153
- ge=1,
154
- description="Maximum steps allowed per episode.",
155
- )
156
- done: bool = Field(
157
- default=False,
158
- description="Whether the current episode has terminated.",
159
- )
160
- current_draft: Optional[str] = Field(
161
- default=None,
162
- description="Latest SOAP-note draft text, if any.",
163
- )
164
- errors_so_far: list[str] = Field(
165
- default_factory=list,
166
- description="Accumulated feedback/error messages.",
167
- )
168
- last_reward: Optional[Reward] = Field(
169
- default=None,
170
- description="Most recent reward object, if a step has been taken.",
171
- )
172
- observation: Optional[Observation] = Field(
173
- default=None,
174
- description="Most recent observation returned to the agent.",
175
- )
 
1
+ """Pydantic v2 models for the Clinical Note Scribe environment."""
 
 
 
 
2
 
3
  from __future__ import annotations
 
4
  from typing import Any, Literal, Optional
 
5
  from pydantic import BaseModel, Field
6
 
7
 
 
 
 
 
8
  class Observation(BaseModel):
9
+ transcript: str = Field(..., description="Full doctor–patient transcript.")
10
+ task_id: str = Field(..., description="Unique task identifier.")
11
+ patient_context: dict[str, Any] = Field(default_factory=dict, description="Patient demographics and history.")
12
+ current_draft: Optional[str] = Field(None, description="Most recent SOAP-note draft.")
13
+ errors_so_far: list[str] = Field(default_factory=list, description="Accumulated error messages.")
14
+ step_count: int = Field(0, ge=0, description="Steps taken in the current episode.")
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
  class SOAPNote(BaseModel):
18
+ subjective: str = Field(..., description="Patient's self-reported symptoms and history.")
19
+ objective: str = Field(..., description="Clinician's measurable findings.")
20
+ assessment: str = Field(..., description="Clinician's diagnosis or differential.")
21
+ plan: str = Field(..., description="Treatment plan, follow-ups, and prescriptions.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
 
 
 
23
 
24
  class Action(BaseModel):
25
+ action_type: Literal["submit_note", "request_clarify", "revise_section"] = Field(..., description="Action kind.")
26
+ soap_note: Optional[SOAPNote] = Field(None, description="SOAP note (required for submit_note).")
27
+ section: Optional[Literal["S", "O", "A", "P"]] = Field(None, description="Section to revise.")
28
+ revision_text: Optional[str] = Field(None, description="Replacement text for the section.")
29
+ clarify_question: Optional[str] = Field(None, description="Clarification question.")
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
  class Reward(BaseModel):
33
+ value: float = Field(..., ge=0.0, le=1.0, description="Aggregate reward [0, 1].")
34
+ signals: dict[str, float] = Field(default_factory=dict, description="Reward sub-signals.")
35
+ done: bool = Field(..., description="Whether the episode ended.")
36
+ info: dict[str, Any] = Field(default_factory=dict, description="Auxiliary metadata.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
 
 
 
 
 
39
  class EnvironmentState(BaseModel):
40
+ task_id: str = Field(..., description="Active task identifier.")
41
+ step_count: int = Field(0, ge=0, description="Steps taken so far.")
42
+ max_steps: int = Field(10, ge=1, description="Max steps per episode.")
43
+ done: bool = Field(False, description="Whether the episode terminated.")
44
+ current_draft: Optional[str] = Field(None, description="Latest SOAP draft text.")
45
+ errors_so_far: list[str] = Field(default_factory=list, description="Error messages.")
46
+ last_reward: Optional[Reward] = Field(None, description="Most recent reward.")
47
+ observation: Optional[Observation] = Field(None, description="Most recent observation.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
environment/reward.py CHANGED
@@ -1,18 +1,4 @@
1
- """Multi-signal reward computation for the Clinical Note Scribe environment.
2
-
3
- Reward formula (all weights sum to 1.0 before penalties):
4
-
5
- weighted_sum = grader_score × 0.60 (clinical accuracy from task grader)
6
- + conciseness_bonus × 0.10 (1.0 if note ≤ 400 words, else 0.0)
7
- + safe_language_score× 0.15 (1.0 if no unsafe-certainty phrases)
8
- + format_valid × 0.15 (1.0 if SOAP JSON is well-formed)
9
-
10
- Deductions (applied after weighted sum):
11
- - 0.05 × max(0, step_count - 3) (penalty for excessive clarification steps)
12
- - 0.10 × len(errors_so_far) (penalty for each invalid action)
13
-
14
- Final value is clamped to [0.0, 1.0].
15
- """
16
 
17
  from __future__ import annotations
18
 
@@ -21,99 +7,38 @@
21
 
22
  from environment.models import Action, Reward, SOAPNote
23
 
 
 
24
 
25
- # ---------------------------------------------------------------------------
26
- # Weights
27
- # ---------------------------------------------------------------------------
28
-
29
- W_GRADER = 0.60
30
- W_CONCISE = 0.10
31
- W_SAFE_LANG = 0.15
32
- W_FORMAT = 0.15
33
-
34
- # Deduction constants
35
- STEP_PENALTY_RATE = 0.05 # per step beyond FREE_STEPS
36
- FREE_STEPS = 3
37
- ERROR_PENALTY_RATE = 0.10 # per item in errors_so_far
38
-
39
- # Conciseness threshold
40
  WORD_LIMIT = 400
41
 
42
- # Phrases that indicate unsafe clinical certainty
43
- # (over-confident language that a scribe should avoid in a note)
44
- _UNSAFE_PATTERNS: list[re.Pattern[str]] = [
45
- re.compile(p, re.IGNORECASE)
46
- for p in [
47
- r"\bpatient definitely has\b",
48
- r"\bdiagnosis is certain\b",
49
- r"\bno doubt\b",
50
- r"\babsolutely confirmed\b",
51
- r"\b100%\s+certain\b",
52
- r"\bwill definitely\b",
53
- r"\bguaranteed to\b",
54
- r"\bcannot be\s+\w+\s+else\b",
55
- r"\bwithout question\b",
56
- r"\bthis is clearly\b",
57
- ]
58
- ]
59
-
60
-
61
- # ---------------------------------------------------------------------------
62
- # Sub-signal helpers
63
- # ---------------------------------------------------------------------------
64
-
65
- def _conciseness_bonus(soap_note: Optional[SOAPNote]) -> float:
66
- """Return 1.0 if the total SOAP note word count is at or below WORD_LIMIT."""
67
- if soap_note is None:
68
- return 0.0
69
- text = " ".join([
70
- soap_note.subjective,
71
- soap_note.objective,
72
- soap_note.assessment,
73
- soap_note.plan,
74
- ])
75
- word_count = len(text.split())
76
- return 1.0 if word_count <= WORD_LIMIT else 0.0
77
-
78
-
79
- def _safe_language_score(soap_note: Optional[SOAPNote]) -> float:
80
- """Return 1.0 if no unsafe-certainty phrases are found in the SOAP note."""
81
- if soap_note is None:
82
- return 1.0 # no note submitted → no unsafe language
83
- text = " ".join([
84
- soap_note.subjective,
85
- soap_note.objective,
86
- soap_note.assessment,
87
- soap_note.plan,
88
- ])
89
- for pattern in _UNSAFE_PATTERNS:
90
- if pattern.search(text):
91
- return 0.0
92
- return 1.0
93
 
94
 
95
- def _format_valid(action: Action) -> float:
96
- """Return 1.0 if the submitted note has all required non-empty SOAP fields.
97
-
98
- This acts as a lightweight structural / «JSON well-formed» check:
99
- each of S, O, A, P must be a non-empty string, and the action_type
100
- must be ``submit_note``.
101
- """
102
- if action.action_type != "submit_note":
103
- return 1.0 # non-submission actions are not graded on format
104
- if action.soap_note is None:
105
- return 0.0
106
- soap = action.soap_note
107
- fields = [soap.subjective, soap.objective, soap.assessment, soap.plan]
108
- if all(isinstance(f, str) and f.strip() for f in fields):
109
- return 1.0
110
- return 0.0
111
-
112
-
113
- # ---------------------------------------------------------------------------
114
- # Public API
115
- # ---------------------------------------------------------------------------
116
-
117
  def compute_reward(
118
  action: Action,
119
  grader_score: float,
@@ -123,75 +48,54 @@ def compute_reward(
123
  done: bool = False,
124
  info: Optional[dict[str, Any]] = None,
125
  ) -> Reward:
126
- """Compute the multi-signal reward for a completed step.
127
-
128
- Parameters
129
- ----------
130
- action:
131
- The action that was just executed.
132
- grader_score:
133
- Clinical-accuracy score returned by the task-specific grader (0.0–1.0).
134
- Use 0.0 for non-submission actions.
135
- step_count:
136
- Total number of steps taken so far in the episode (including this one).
137
- errors_so_far:
138
- List of error messages accumulated during the episode.
139
- done:
140
- Whether the episode ended with this step.
141
- info:
142
- Optional auxiliary metadata dict to include in the Reward.
143
-
144
- Returns
145
- -------
146
- Reward
147
- Fully populated Reward with ``value`` and ``signals`` breakdown.
148
- """
149
  grader_score = max(0.0, min(1.0, grader_score))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
 
151
- # ---- per-signal scores ----
152
- conciseness = _conciseness_bonus(action.soap_note)
153
- safe_lang = _safe_language_score(action.soap_note)
154
- fmt = _format_valid(action)
155
-
156
- # ---- weighted sum ----
157
- weighted = (
158
  grader_score * W_GRADER
159
- + conciseness * W_CONCISE
160
- + safe_lang * W_SAFE_LANG
161
- + fmt * W_FORMAT
 
 
162
  )
163
-
164
- # ---- deductions ----
165
- extra_steps = max(0, step_count - FREE_STEPS)
166
- step_penalty = extra_steps * STEP_PENALTY_RATE
167
- error_penalty = len(errors_so_far) * ERROR_PENALTY_RATE
168
-
169
- raw = weighted - step_penalty - error_penalty
170
-
171
- # ---- clamp ----
172
- value = max(0.01, min(0.99, raw))
173
-
174
- signals: dict[str, float] = {
175
- # positive contributions
176
- "grader_score": round(grader_score * W_GRADER, 4),
177
- "conciseness_bonus": round(conciseness * W_CONCISE, 4),
178
- "safe_language_score": round(safe_lang * W_SAFE_LANG, 4),
179
- "format_valid": round(fmt * W_FORMAT, 4),
180
- # deductions (stored as negative numbers for clarity)
181
- "step_penalty": round(-step_penalty, 4),
182
- "error_penalty": round(-error_penalty, 4),
183
- # raw sub-signal values (unweighted, for introspection)
184
- "_grader_score_raw": round(grader_score, 4),
185
- "_conciseness_raw": round(conciseness, 4),
186
- "_safe_language_raw": round(safe_lang, 4),
187
- "_format_valid_raw": round(fmt, 4),
188
- "_extra_steps": float(extra_steps),
189
- "_error_count": float(len(errors_so_far)),
190
- }
191
 
192
  return Reward(
193
- value=round(value, 4),
194
- signals=signals,
 
 
 
 
 
 
 
 
 
 
 
 
 
195
  done=done,
196
  info=info or {},
197
  )
 
1
+ """Multi-signal reward computation for the Clinical Note Scribe environment."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  from __future__ import annotations
4
 
 
7
 
8
  from environment.models import Action, Reward, SOAPNote
9
 
10
# --- Reward weights -------------------------------------------------------
# The four positive sub-signals are weighted so a perfect note scores 1.0
# before any deductions are applied.
W_GRADER = 0.60
W_CONCISE = 0.10
W_SAFE_LANG = 0.15
W_FORMAT = 0.15

# --- Deductions -----------------------------------------------------------
STEP_PENALTY_RATE = 0.05   # charged per step beyond FREE_STEPS
FREE_STEPS = 3             # steps allowed before the step penalty starts
ERROR_PENALTY_RATE = 0.10  # charged per accumulated episode error

# Conciseness threshold: the bonus requires the whole SOAP note to fit
# within this many words.
WORD_LIMIT = 400

# Phrases expressing unsafe clinical certainty. Kept as a tuple so new
# phrases can be appended without touching the compiled pattern below.
_UNSAFE_PHRASES = (
    r"\bpatient definitely has\b",
    r"\bdiagnosis is certain\b",
    r"\bno doubt\b",
    r"\babsolutely confirmed\b",
    r"\b100%\s+certain\b",
    r"\bwill definitely\b",
    r"\bguaranteed to\b",
    r"\bcannot be\s+\w+\s+else\b",
    r"\bwithout question\b",
    r"\bthis is clearly\b",
)

# Single pre-compiled, case-insensitive alternation over all unsafe phrases.
_UNSAFE_RE = re.compile("|".join(_UNSAFE_PHRASES), re.IGNORECASE)


def _soap_text(soap: Optional[SOAPNote]) -> Optional[str]:
    """Join the four SOAP fields into one space-separated string.

    Returns ``None`` when no note was provided so callers can tell
    "no note submitted" apart from "note with empty fields".
    """
    if soap is None:
        return None
    return " ".join((soap.subjective, soap.objective, soap.assessment, soap.plan))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
def compute_reward(
    action: Action,
    grader_score: float,
    step_count: int,
    errors_so_far: list[str],
    done: bool = False,
    info: Optional[dict[str, Any]] = None,
) -> Reward:
    """Compute the multi-signal reward for a completed step.

    Combines the task grader's clinical-accuracy score with conciseness,
    safe-language, and format sub-signals, then subtracts step and error
    penalties. The final value is clamped into (0.01, 0.99) so neither
    endpoint is ever reported.

    Parameters
    ----------
    action:
        The action that was just executed.
    grader_score:
        Clinical-accuracy score in [0, 1]; use 0.0 for non-submissions.
    step_count:
        Total steps taken so far this episode, including this one.
    errors_so_far:
        Error messages accumulated during the episode.
    done:
        Whether the episode ended with this step.
    info:
        Optional auxiliary metadata dict to include in the Reward.
    """
    grader_score = max(0.0, min(1.0, grader_score))
    text = _soap_text(action.soap_note)

    # Conciseness bonus: only when a note exists and fits the word limit.
    # (Previously the else-branch was a redundant `0.0 if text else 0.0`.)
    conciseness = 1.0 if text is not None and len(text.split()) <= WORD_LIMIT else 0.0

    # Safe language: fails only when a note exists and an unsafe phrase matches;
    # "no note" trivially contains no unsafe language.
    safe_lang = 0.0 if text is not None and _UNSAFE_RE.search(text) else 1.0

    # Format check: submissions must carry a note whose four SOAP fields are
    # all non-empty strings; non-submission actions are not graded on format.
    if action.action_type != "submit_note":
        fmt = 1.0
    elif action.soap_note is None:
        fmt = 0.0
    else:
        fields = (
            action.soap_note.subjective,
            action.soap_note.objective,
            action.soap_note.assessment,
            action.soap_note.plan,
        )
        fmt = 1.0 if all(isinstance(f, str) and f.strip() for f in fields) else 0.0

    # Weighted sum minus deductions, clamped away from the exact endpoints.
    extra_steps = max(0, step_count - FREE_STEPS)
    step_penalty = extra_steps * STEP_PENALTY_RATE
    error_penalty = len(errors_so_far) * ERROR_PENALTY_RATE
    raw = (
        grader_score * W_GRADER
        + conciseness * W_CONCISE
        + safe_lang * W_SAFE_LANG
        + fmt * W_FORMAT
        - step_penalty
        - error_penalty
    )
    value = round(max(0.01, min(0.99, raw)), 4)

    return Reward(
        value=value,
        signals={
            # positive, weighted contributions
            "grader_score": round(grader_score * W_GRADER, 4),
            "conciseness_bonus": round(conciseness * W_CONCISE, 4),
            "safe_language_score": round(safe_lang * W_SAFE_LANG, 4),
            "format_valid": round(fmt * W_FORMAT, 4),
            # deductions (stored as negatives for clarity)
            "step_penalty": round(-step_penalty, 4),
            "error_penalty": round(-error_penalty, 4),
            # raw, unweighted sub-signals for introspection
            "_grader_score_raw": round(grader_score, 4),
            "_conciseness_raw": round(conciseness, 4),
            "_safe_language_raw": round(safe_lang, 4),
            "_format_valid_raw": round(fmt, 4),
            "_extra_steps": float(extra_steps),
            "_error_count": float(len(errors_so_far)),
        },
        done=done,
        info=info or {},
    )
environment/tasks/task_easy.py CHANGED
@@ -1,85 +1,37 @@
1
- """Easy task — routine check-up.
2
-
3
- Grader uses keyword-based clinical rubric scoring to evaluate the SOAP note
4
- against expected findings from a simple cold / blood pressure check visit.
5
- """
6
 
7
  from __future__ import annotations
8
-
9
  from typing import Any
10
-
11
  from environment.models import SOAPNote
12
 
13
-
14
- # ---------------------------------------------------------------------------
15
- # Task definition
16
- # ---------------------------------------------------------------------------
17
-
18
  EASY_TASK: dict[str, Any] = {
19
  "task_id": "easy_routine_checkup",
20
  "description": "Generate a SOAP note for a routine annual check-up visit.",
21
  "transcript_file": "data/transcripts/easy.txt",
22
  "patient_context": {
23
- "patient_id": "P-1001",
24
- "name": "Jane Doe",
25
- "age": 34,
26
- "sex": "F",
27
- "known_conditions": [],
28
- "current_medications": [],
29
- "allergies": ["Penicillin"],
30
  },
31
  "max_steps": 5,
32
  }
33
 
34
 
35
- # ---------------------------------------------------------------------------
36
- # Grader
37
- # ---------------------------------------------------------------------------
38
-
39
  def grade_easy(soap_note: SOAPNote, task: dict[str, Any]) -> dict[str, float]:
40
- """Score a submitted SOAP note against the easy-task rubric.
41
-
42
- Checks for mention of key clinical findings from the transcript:
43
- chief complaints, vitals, viral URI assessment, and supportive plan.
44
-
45
- Returns
46
- -------
47
- dict mapping signal names to float scores in [0, 1].
48
- """
49
- text_s = soap_note.subjective.lower()
50
- text_o = soap_note.objective.lower()
51
- text_a = soap_note.assessment.lower()
52
- text_p = soap_note.plan.lower()
53
-
54
- # 1. Subjective — chief complaints
55
- s_score = 0.0
56
- if "sore throat" in text_s or "runny nose" in text_s or "congestion" in text_s:
57
- s_score += 0.5
58
- if "5 days" in text_s or "five days" in text_s or "headache" in text_s:
59
- s_score += 0.5
60
-
61
- # 2. Objective — vitals
62
- o_score = 0.0
63
- if "118/76" in text_o or "118 over 76" in text_o or "blood pressure" in text_o:
64
- o_score += 0.5
65
- if "72" in text_o or "heart rate" in text_o or "lungs clear" in text_o:
66
- o_score += 0.5
67
-
68
- # 3. Assessment — viral URI
69
- a_score = 0.0
70
- if "viral" in text_a or "uri" in text_a or "upper respiratory" in text_a:
71
- a_score += 1.0
72
-
73
- # 4. Plan — supportive care
74
- p_score = 0.0
75
- if "fluids" in text_p or "rest" in text_p or "hydrat" in text_p:
76
- p_score += 0.5
77
- if "dayquil" in text_p or "follow" in text_p or "return" in text_p:
78
- p_score += 0.5
79
-
80
  return {
81
- "subjective_accuracy": max(0.01, min(s_score, 0.99)),
82
- "objective_accuracy": max(0.01, min(o_score, 0.99)),
83
- "assessment_accuracy": max(0.01, min(a_score, 0.99)),
84
- "plan_accuracy": max(0.01, min(p_score, 0.99)),
85
  }
 
1
+ """Easy task — routine check-up."""
 
 
 
 
2
 
3
  from __future__ import annotations
 
4
  from typing import Any
 
5
  from environment.models import SOAPNote
6
 
 
 
 
 
 
7
# Task definition: simple cold / blood-pressure check during an annual visit.
EASY_TASK: dict[str, Any] = {
    "task_id": "easy_routine_checkup",
    "description": "Generate a SOAP note for a routine annual check-up visit.",
    "transcript_file": "data/transcripts/easy.txt",
    "patient_context": {
        "patient_id": "P-1001",
        "name": "Jane Doe",
        "age": 34,
        "sex": "F",
        "known_conditions": [],
        "current_medications": [],
        "allergies": ["Penicillin"],
    },
    "max_steps": 5,
}
 
18
 
 
 
 
 
19
def grade_easy(soap_note: SOAPNote, task: dict[str, Any]) -> dict[str, float]:
    """Score a submitted SOAP note against the easy-task keyword rubric.

    Each SOAP section earns partial credit for mentioning expected findings
    from the transcript: chief complaints (S), vitals (O), the viral-URI
    assessment (A), and supportive care (P). Each section score is clamped
    into [0.01, 0.99].
    """
    def hits(text: str, keywords: tuple[str, ...]) -> float:
        # 0.5 partial credit when any keyword of the group appears.
        return 0.5 if any(k in text for k in keywords) else 0.0

    def clamp(score: float) -> float:
        return max(0.01, min(score, 0.99))

    s = soap_note.subjective.lower()
    o = soap_note.objective.lower()
    a = soap_note.assessment.lower()
    p = soap_note.plan.lower()

    # Subjective: chief complaints and duration/associated symptoms.
    s_score = (
        hits(s, ("sore throat", "runny nose", "congestion"))
        + hits(s, ("5 days", "five days", "headache"))
    )
    # Objective: blood pressure and heart-rate / lung findings.
    o_score = (
        hits(o, ("118/76", "118 over 76", "blood pressure"))
        + hits(o, ("72", "heart rate", "lungs clear"))
    )
    # Assessment: viral upper-respiratory infection (single full-credit group).
    a_score = 1.0 if any(k in a for k in ("viral", "uri", "upper respiratory")) else 0.0
    # Plan: supportive care and follow-up.
    p_score = (
        hits(p, ("fluids", "rest", "hydrat"))
        + hits(p, ("dayquil", "follow", "return"))
    )

    return {
        "subjective_accuracy": clamp(s_score),
        "objective_accuracy": clamp(o_score),
        "assessment_accuracy": clamp(a_score),
        "plan_accuracy": clamp(p_score),
    }
environment/tasks/task_hard.py CHANGED
@@ -1,118 +1,42 @@
1
- """Hard task — complex ER visit.
2
-
3
- Grader uses keyword-based clinical rubric scoring to evaluate the SOAP note
4
- against expected findings from a complex ER visit with overlapping chest pain,
5
- SOB, and a possible PE complicated by a contrast dye allergy.
6
- """
7
 
8
  from __future__ import annotations
9
-
10
  from typing import Any
11
-
12
  from environment.models import SOAPNote
13
 
14
-
15
- # ---------------------------------------------------------------------------
16
- # Task definition
17
- # ---------------------------------------------------------------------------
18
-
19
  HARD_TASK: dict[str, Any] = {
20
  "task_id": "hard_complex_er_visit",
21
- "description": (
22
- "Generate a SOAP note for a complex emergency-room visit involving "
23
- "chest pain, polytrauma assessment, and multiple co-morbidities."
24
- ),
25
  "transcript_file": "data/transcripts/hard.txt",
26
  "patient_context": {
27
- "patient_id": "P-3782",
28
- "name": "Maria Garcia",
29
- "age": 72,
30
- "sex": "F",
31
- "known_conditions": [
32
- "Coronary Artery Disease",
33
- "Atrial Fibrillation",
34
- "Chronic Kidney Disease Stage 3",
35
- "Osteoarthritis",
36
- ],
37
- "current_medications": [
38
- "Aspirin 81 mg daily",
39
- "Warfarin 5 mg daily",
40
- "Metoprolol 50 mg BID",
41
- "Furosemide 40 mg daily",
42
- "Amlodipine 5 mg daily",
43
- ],
44
  "allergies": ["Sulfa drugs", "Contrast dye"],
45
- "recent_labs": {
46
- "troponin_I": "0.08 ng/mL",
47
- "BNP": "450 pg/mL",
48
- "creatinine": "1.9 mg/dL",
49
- "eGFR": "34 mL/min",
50
- "INR": "2.6",
51
- "hemoglobin": "10.2 g/dL",
52
- },
53
- "vitals_on_arrival": {
54
- "BP": "168/94 mmHg",
55
- "HR": "112 bpm (irregular)",
56
- "RR": "22 breaths/min",
57
- "SpO2": "91% on room air",
58
- "Temp": "37.2°C",
59
- },
60
  },
61
  "max_steps": 10,
62
  }
63
 
64
 
65
- # ---------------------------------------------------------------------------
66
- # Grader
67
- # ---------------------------------------------------------------------------
68
-
69
  def grade_hard(soap_note: SOAPNote, task: dict[str, Any]) -> dict[str, float]:
70
- """Score a submitted SOAP note against the hard-task rubric.
71
-
72
- Checks for chest pain / SOB and the nitroglycerin contradiction (subjective),
73
- D-dimer and contrast allergy (objective), ACS vs PE differential (assessment),
74
- and V/Q scan + ICU admission (plan).
75
-
76
- Returns
77
- -------
78
- dict mapping signal names to float scores in [0, 1].
79
- """
80
- text_s = soap_note.subjective.lower()
81
- text_o = soap_note.objective.lower()
82
- text_a = soap_note.assessment.lower()
83
- text_p = soap_note.plan.lower()
84
-
85
- # 1. Subjective — catching the contradiction and presenting complaints
86
- s_score = 0.0
87
- if "chest pain" in text_s or "shortness of breath" in text_s or "sob" in text_s:
88
- s_score += 0.5
89
- if "nitroglycerin" in text_s or "contradict" in text_s or "denied" in text_s:
90
- s_score += 0.5
91
-
92
- # 2. Objective — elevated D-dimer and allergy awareness
93
- o_score = 0.0
94
- if "d-dimer" in text_o or "1840" in text_o or "d dimer" in text_o:
95
- o_score += 0.5
96
- if "allergy" in text_o or "contrast" in text_o or "troponin" in text_o:
97
- o_score += 0.5
98
-
99
- # 3. Assessment — the dual differential (ACS vs PE)
100
- a_score = 0.0
101
- if "acs" in text_a or "acute coronary" in text_a or "coronary" in text_a or "ischemia" in text_a:
102
- a_score += 0.5
103
- if "pe" in text_a or "pulmonary embolism" in text_a or "embolism" in text_a:
104
- a_score += 0.5
105
-
106
- # 4. Plan — adapting to the allergy (V/Q scan) and admission
107
- p_score = 0.0
108
- if "v/q" in text_p or "ventilation" in text_p or "perfusion" in text_p:
109
- p_score += 0.5
110
- if "icu" in text_p or "admit" in text_p or "cardiac" in text_p:
111
- p_score += 0.5
112
-
113
  return {
114
- "subjective_accuracy": max(0.01, min(s_score, 0.99)),
115
- "objective_accuracy": max(0.01, min(o_score, 0.99)),
116
- "assessment_accuracy": max(0.01, min(a_score, 0.99)),
117
- "plan_accuracy": max(0.01, min(p_score, 0.99)),
118
  }
 
1
+ """Hard task — complex ER visit."""
 
 
 
 
 
2
 
3
  from __future__ import annotations
 
4
  from typing import Any
 
5
  from environment.models import SOAPNote
6
 
 
 
 
 
 
7
# Task definition: complex ER visit — chest pain, possible PE, contrast allergy.
HARD_TASK: dict[str, Any] = {
    "task_id": "hard_complex_er_visit",
    "description": "Generate a SOAP note for a complex ER visit with chest pain, PE differential, and contrast allergy.",
    "transcript_file": "data/transcripts/hard.txt",
    "patient_context": {
        "patient_id": "P-3782",
        "name": "Maria Garcia",
        "age": 72,
        "sex": "F",
        "known_conditions": [
            "Coronary Artery Disease",
            "Atrial Fibrillation",
            "Chronic Kidney Disease Stage 3",
            "Osteoarthritis",
        ],
        "current_medications": [
            "Aspirin 81 mg daily",
            "Warfarin 5 mg daily",
            "Metoprolol 50 mg BID",
            "Furosemide 40 mg daily",
            "Amlodipine 5 mg daily",
        ],
        "allergies": ["Sulfa drugs", "Contrast dye"],
        "recent_labs": {
            "troponin_I": "0.08 ng/mL",
            "BNP": "450 pg/mL",
            "creatinine": "1.9 mg/dL",
            "eGFR": "34 mL/min",
            "INR": "2.6",
            "hemoglobin": "10.2 g/dL",
        },
        "vitals_on_arrival": {
            "BP": "168/94 mmHg",
            "HR": "112 bpm (irregular)",
            "RR": "22 breaths/min",
            "SpO2": "91% on room air",
            "Temp": "37.2°C",
        },
    },
    "max_steps": 10,
}
21
 
22
 
 
 
 
 
23
def grade_hard(soap_note: SOAPNote, task: dict[str, Any]) -> dict[str, float]:
    """Score a submitted SOAP note against the hard-task keyword rubric.

    Looks for the presenting complaints and the nitroglycerin contradiction
    (Subjective), the elevated D-dimer and contrast-allergy awareness
    (Objective), the ACS-vs-PE differential (Assessment), and the V/Q scan
    plus admission (Plan). Each section score is clamped into [0.01, 0.99].
    """
    def hits(text: str, keywords: tuple[str, ...]) -> float:
        # 0.5 partial credit when any keyword of the group appears.
        return 0.5 if any(k in text for k in keywords) else 0.0

    def clamp(score: float) -> float:
        return max(0.01, min(score, 0.99))

    s = soap_note.subjective.lower()
    o = soap_note.objective.lower()
    a = soap_note.assessment.lower()
    p = soap_note.plan.lower()

    # Subjective: complaints plus catching the nitroglycerin contradiction.
    s_score = (
        hits(s, ("chest pain", "shortness of breath", "sob"))
        + hits(s, ("nitroglycerin", "contradict", "denied"))
    )
    # Objective: elevated D-dimer and allergy / troponin awareness.
    o_score = (
        hits(o, ("d-dimer", "1840", "d dimer"))
        + hits(o, ("allergy", "contrast", "troponin"))
    )
    # Assessment: the dual differential (ACS vs PE).
    a_score = (
        hits(a, ("acs", "acute coronary", "coronary", "ischemia"))
        + hits(a, ("pe", "pulmonary embolism", "embolism"))
    )
    # Plan: adapting to the contrast allergy (V/Q scan) and admission.
    p_score = (
        hits(p, ("v/q", "ventilation", "perfusion"))
        + hits(p, ("icu", "admit", "cardiac"))
    )

    return {
        "subjective_accuracy": clamp(s_score),
        "objective_accuracy": clamp(o_score),
        "assessment_accuracy": clamp(a_score),
        "plan_accuracy": clamp(p_score),
    }
environment/tasks/task_medium.py CHANGED
@@ -1,98 +1,41 @@
1
- """Medium task — chronic disease follow-up.
2
-
3
- Grader uses keyword-based clinical rubric scoring to evaluate the SOAP note
4
- against expected findings from a Type 2 Diabetes / Hypertension follow-up.
5
- """
6
 
7
  from __future__ import annotations
8
-
9
  from typing import Any
10
-
11
  from environment.models import SOAPNote
12
 
13
-
14
- # ---------------------------------------------------------------------------
15
- # Task definition
16
- # ---------------------------------------------------------------------------
17
-
18
  MEDIUM_TASK: dict[str, Any] = {
19
  "task_id": "medium_chronic_disease_followup",
20
  "description": "Generate a SOAP note for a Type 2 Diabetes follow-up visit.",
21
  "transcript_file": "data/transcripts/medium.txt",
22
  "patient_context": {
23
- "patient_id": "P-2045",
24
- "name": "Robert Smith",
25
- "age": 58,
26
- "sex": "M",
27
  "known_conditions": ["Type 2 Diabetes Mellitus", "Hypertension"],
28
- "current_medications": [
29
- "Metformin 1000 mg BID",
30
- "Lisinopril 20 mg daily",
31
- "Atorvastatin 40 mg daily",
32
- ],
33
  "allergies": [],
34
- "recent_labs": {
35
- "HbA1c": "7.8%",
36
- "fasting_glucose": "156 mg/dL",
37
- "creatinine": "1.1 mg/dL",
38
- "eGFR": "78 mL/min",
39
- "LDL": "102 mg/dL",
40
- },
41
  },
42
  "max_steps": 8,
43
  }
44
 
45
 
46
- # ---------------------------------------------------------------------------
47
- # Grader
48
- # ---------------------------------------------------------------------------
49
-
50
  def grade_medium(soap_note: SOAPNote, task: dict[str, Any]) -> dict[str, float]:
51
- """Score a submitted SOAP note against the medium-task rubric.
52
-
53
- Checks for mention of dietary habits, HbA1c lab values, core diagnoses,
54
- and medication adjustments (glipizide, lisinopril uptitration).
55
-
56
- Returns
57
- -------
58
- dict mapping signal names to float scores in [0, 1].
59
- """
60
- text_s = soap_note.subjective.lower()
61
- text_o = soap_note.objective.lower()
62
- text_a = soap_note.assessment.lower()
63
- text_p = soap_note.plan.lower()
64
-
65
- # 1. Subjective — dietary habits / statin gap
66
- s_score = 0.0
67
- if "restaurant" in text_s or "diet" in text_s or "eating" in text_s:
68
- s_score += 0.5
69
- if "statin" in text_s or "gap" in text_s or "missed" in text_s:
70
- s_score += 0.5
71
-
72
- # 2. Objective — HbA1c values
73
- o_score = 0.0
74
- if "7.8" in text_o or "7.2" in text_o or "a1c" in text_o or "hba1c" in text_o:
75
- o_score += 0.5
76
- if "156" in text_o or "fasting glucose" in text_o or "glucose" in text_o:
77
- o_score += 0.5
78
-
79
- # 3. Assessment — core diagnoses
80
- a_score = 0.0
81
- if "diabetes" in text_a or "t2dm" in text_a or "dm" in text_a:
82
- a_score += 0.5
83
- if "hypertension" in text_a or "htn" in text_a or "blood pressure" in text_a:
84
- a_score += 0.5
85
-
86
- # 4. Plan — medication changes
87
- p_score = 0.0
88
- if "glipizide" in text_p and ("5" in text_p or "add" in text_p):
89
- p_score += 0.5
90
- if "lisinopril" in text_p and ("40" in text_p or "increase" in text_p or "uptitrat" in text_p):
91
- p_score += 0.5
92
-
93
  return {
94
- "subjective_accuracy": max(0.01, min(s_score, 0.99)),
95
- "objective_accuracy": max(0.01, min(o_score, 0.99)),
96
- "assessment_accuracy": max(0.01, min(a_score, 0.99)),
97
- "plan_accuracy": max(0.01, min(p_score, 0.99)),
98
  }
 
1
+ """Medium task — chronic disease follow-up."""
 
 
 
 
2
 
3
  from __future__ import annotations
 
4
  from typing import Any
 
5
  from environment.models import SOAPNote
6
 
 
 
 
 
 
7
# Task definition: Type 2 Diabetes / Hypertension chronic-disease follow-up.
MEDIUM_TASK: dict[str, Any] = {
    "task_id": "medium_chronic_disease_followup",
    "description": "Generate a SOAP note for a Type 2 Diabetes follow-up visit.",
    "transcript_file": "data/transcripts/medium.txt",
    "patient_context": {
        "patient_id": "P-2045",
        "name": "Robert Smith",
        "age": 58,
        "sex": "M",
        "known_conditions": ["Type 2 Diabetes Mellitus", "Hypertension"],
        "current_medications": [
            "Metformin 1000 mg BID",
            "Lisinopril 20 mg daily",
            "Atorvastatin 40 mg daily",
        ],
        "allergies": [],
        "recent_labs": {
            "HbA1c": "7.8%",
            "fasting_glucose": "156 mg/dL",
            "creatinine": "1.1 mg/dL",
            "eGFR": "78 mL/min",
            "LDL": "102 mg/dL",
        },
    },
    "max_steps": 8,
}
20
 
21
 
 
 
 
 
22
def grade_medium(soap_note: SOAPNote, task: dict[str, Any]) -> dict[str, float]:
    """Score a submitted SOAP note against the medium-task keyword rubric.

    Looks for dietary habits and the statin gap (Subjective), HbA1c / glucose
    values (Objective), the two core diagnoses (Assessment), and the
    medication adjustments — glipizide start, lisinopril uptitration (Plan).
    Each section score is clamped into [0.01, 0.99].
    """
    def hits(text: str, keywords: tuple[str, ...]) -> float:
        # 0.5 partial credit when any keyword of the group appears.
        return 0.5 if any(k in text for k in keywords) else 0.0

    def clamp(score: float) -> float:
        return max(0.01, min(score, 0.99))

    s = soap_note.subjective.lower()
    o = soap_note.objective.lower()
    a = soap_note.assessment.lower()
    p = soap_note.plan.lower()

    # Subjective: dietary habits and the missed-statin gap.
    s_score = (
        hits(s, ("restaurant", "diet", "eating"))
        + hits(s, ("statin", "gap", "missed"))
    )
    # Objective: HbA1c and fasting-glucose values.
    o_score = (
        hits(o, ("7.8", "7.2", "a1c", "hba1c"))
        + hits(o, ("156", "fasting glucose", "glucose"))
    )
    # Assessment: both core diagnoses.
    a_score = (
        hits(a, ("diabetes", "t2dm", "dm"))
        + hits(a, ("hypertension", "htn", "blood pressure"))
    )
    # Plan: each medication change requires the drug name plus a dose/verb cue.
    p_score = 0.0
    if "glipizide" in p and any(k in p for k in ("5", "add")):
        p_score += 0.5
    if "lisinopril" in p and any(k in p for k in ("40", "increase", "uptitrat")):
        p_score += 0.5

    return {
        "subjective_accuracy": clamp(s_score),
        "objective_accuracy": clamp(o_score),
        "assessment_accuracy": clamp(a_score),
        "plan_accuracy": clamp(p_score),
    }
frontend/app.js ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
// Clinical Note Scribe — Frontend Logic

const API = "";

/** Shorthand for document.getElementById. */
const $ = (id) => document.getElementById(id);

// Controls
const taskSelect = $("taskSelect");
const resetBtn = $("resetBtn");
const stepBtn = $("stepBtn");
const statusBadge = $("statusBadge");

// Context + transcript panels
const contextSection = $("contextSection");
const contextGrid = $("contextGrid");
const transcriptArea = $("transcriptArea");

// Action panel
const actionSection = $("actionSection");
const actionType = $("actionType");
const sectionSelect = $("sectionSelect");
const soapInputs = $("soapInputs");
const reviseInput = $("reviseInput");
const clarifyInput = $("clarifyInput");

// Reward + draft panels
const rewardSection = $("rewardSection");
const scoreValue = $("scoreValue");
const rewardFill = $("rewardFill");
const draftArea = $("draftArea");
const draftEmpty = $("draftEmpty");
const soapDraft = $("soapDraft");
const soapGrid = $("soapGrid");

// Log panel
const logContainer = $("logContainer");

// Episode state
let currentObs = null;
let isDone = false;
30
+
31
// Logging
/**
 * Prepend a timestamped entry to the activity log.
 * @param {string} msg  - message text (rendered as plain text, never HTML)
 * @param {string} type - optional CSS modifier class ("success", "error", ...)
 */
function addLog(msg, type = "") {
  const time = new Date().toLocaleTimeString("en-US", { hour12: false });
  const entry = document.createElement("div");
  entry.className = "log-entry " + type;
  const ts = document.createElement("span");
  ts.className = "log-time";
  ts.textContent = time;
  entry.appendChild(ts);
  // Use a text node (not innerHTML interpolation) so server-supplied error
  // text in `msg` cannot inject markup into the page.
  entry.appendChild(document.createTextNode(msg));
  logContainer.prepend(entry);
}
39
+
40
// Status badge
/** Update the header status badge to "idle" | "active" | "done". */
function setStatus(state) {
  const labels = { idle: "Idle", active: "Active" };
  statusBadge.className = "status-badge " + state;
  // Any unrecognized state falls back to "Done", matching the badge default.
  statusBadge.textContent = labels[state] ?? "Done";
}
45
+
46
// Show only the inputs relevant to the currently selected action type.
actionType.addEventListener("change", () => {
  const val = actionType.value;
  const isSubmit = val === "submit_note";
  const isRevise = val === "revise_section";
  const isClarify = val === "request_clarify";
  soapInputs.style.display = isSubmit ? "block" : "none";
  reviseInput.style.display = isRevise ? "block" : "none";
  clarifyInput.style.display = isClarify ? "block" : "none";
  sectionSelect.style.display = isRevise ? "inline-block" : "none";
});
54
+
55
// Format transcript with speaker highlighting
/** Render transcript text as HTML, color-coding doctor vs patient lines. */
function renderTranscript(text) {
  if (!text) return "";
  const speakerClass = (line) => {
    const trimmed = line.trim();
    if (/^(Dr\.|Doctor)/i.test(trimmed)) return "speaker-doctor";
    if (/^(Patient|Pt)/i.test(trimmed)) return "speaker-patient";
    return null;
  };
  return text
    .split("\n")
    .map((line) => {
      const cls = speakerClass(line);
      const safe = escapeHtml(line);
      return cls
        ? `<div><span class="${cls}">${safe}</span></div>`
        : `<div>${safe}</div>`;
    })
    .join("");
}
68
+
69
/** Escape a string for safe interpolation into an HTML template. */
function escapeHtml(str) {
  // A detached element's textContent -> innerHTML round-trip escapes
  // &, <, and > for us.
  const scratch = document.createElement("div");
  scratch.textContent = str;
  return scratch.innerHTML;
}
74
+
75
// Render patient context as cards
/** Render the patient-context object as a grid of label/value cards. */
function renderContext(ctx) {
  const hasData = ctx && Object.keys(ctx).length > 0;
  contextSection.style.display = hasData ? "block" : "none";
  if (!hasData) return;

  contextGrid.innerHTML = "";
  for (const [label, value] of Object.entries(flattenContext(ctx))) {
    const card = document.createElement("div");
    card.className = "context-card";
    card.innerHTML =
      `<div class="label">${escapeHtml(label)}</div>` +
      `<div class="value">${escapeHtml(String(value))}</div>`;
    contextGrid.appendChild(card);
  }
}
92
+
93
/**
 * Flatten a nested context object into a single level of "a › b" keys.
 * Arrays become comma-joined strings; null/undefined values and empty
 * arrays become an em-dash placeholder.
 */
function flattenContext(obj, prefix = "") {
  const out = {};
  for (const [k, v] of Object.entries(obj)) {
    const label = prefix ? `${prefix} › ${k}` : k;
    if (Array.isArray(v)) {
      out[label] = v.length ? v.join(", ") : "—";
    } else if (v && typeof v === "object") {
      Object.assign(out, flattenContext(v, label));
    } else {
      out[label] = v ?? "—";
    }
  }
  return out;
}
107
+
108
// Render SOAP draft
/** Render the current SOAP draft, or the empty-state placeholder. */
function renderDraft(draftText) {
  if (!draftText) {
    draftEmpty.style.display = "flex";
    soapDraft.style.display = "none";
    return;
  }
  draftEmpty.style.display = "none";
  soapDraft.style.display = "block";

  // Collect section text from "S: ", "O: ", "A: ", "P: " prefixed lines
  // (a later line for the same section replaces an earlier one).
  const sections = { S: "", O: "", A: "", P: "" };
  for (const line of draftText.split("\n")) {
    for (const key of ["S", "O", "A", "P"]) {
      const prefix = key + ": ";
      if (line.startsWith(prefix)) {
        sections[key] = line.slice(prefix.length);
      }
    }
  }

  const labels = { S: "Subjective", O: "Objective", A: "Assessment", P: "Plan" };
  soapGrid.innerHTML = "";
  for (const [key, label] of Object.entries(labels)) {
    const card = document.createElement("div");
    card.className = `soap-card ${key.toLowerCase()}`;
    card.innerHTML = `
      <div class="soap-label">${label}</div>
      <div class="soap-text">${escapeHtml(sections[key]) || '<em style="opacity:0.4">Empty</em>'}</div>
    `;
    soapGrid.appendChild(card);
  }
}
140
+
141
// Render reward
/** Show the reward panel: numeric score, fill bar, and traffic-light color. */
function renderReward(rewardObj) {
  if (!rewardObj) {
    rewardSection.style.display = "none";
    return;
  }
  rewardSection.style.display = "block";

  const val = rewardObj.value;
  scoreValue.textContent = val.toFixed(4);
  rewardFill.style.width = (val * 100) + "%";

  // Color-code the score: green >= 0.7, yellow >= 0.4, red below.
  const color =
    val >= 0.7 ? "var(--green)" : val >= 0.4 ? "var(--yellow)" : "var(--red)";
  scoreValue.style.color = color;
}
160
+
161
// Update UI from observation
/** Refresh every panel from a fresh observation (plus optional reward/done). */
function updateUI(obs, reward = null, done = false) {
  currentObs = obs;
  isDone = done;

  transcriptArea.innerHTML =
    `<div class="transcript-box">${renderTranscript(obs.transcript)}</div>`;
  renderContext(obs.patient_context);
  renderDraft(obs.current_draft);
  if (reward) renderReward(reward);

  actionSection.style.display = done ? "none" : "block";
  setStatus(done ? "done" : "active");

  if (done) {
    const score = reward ? reward.value.toFixed(4) : "N/A";
    addLog(`Episode complete — score: ${score}`, "success");
  }
}
179
+
180
// Reset
// POST /reset with the selected task, then repaint the whole UI.
resetBtn.addEventListener("click", async () => {
  const taskId = taskSelect.value;
  resetBtn.disabled = true;
  addLog(`Resetting with task: ${taskId}`);

  try {
    const res = await fetch(`${API}/reset`, {
      method: "POST",
      headers: { "Content-Type": "application/json" },
      body: JSON.stringify({ task_id: taskId }),
    });
    if (!res.ok) {
      throw new Error(await res.text());
    }
    const obs = await res.json();
    // Hide any stale reward from the previous episode before repainting.
    rewardSection.style.display = "none";
    updateUI(obs);
    addLog("Environment reset successfully", "success");
  } catch (err) {
    addLog(`Reset failed: ${err.message}`, "error");
  } finally {
    resetBtn.disabled = false;
  }
});
203
+
204
// Step
// Build the action payload from the form, POST /step, and render the result.
stepBtn.addEventListener("click", async () => {
  if (isDone) {
    addLog("Episode is done. Reset first.", "error");
    return;
  }

  const action = actionType.value;
  let payload = {};

  if (action === "submit_note") {
    payload = {
      action_type: "submit_note",
      soap_note: {
        subjective: document.getElementById("inputS").value,
        objective: document.getElementById("inputO").value,
        assessment: document.getElementById("inputA").value,
        plan: document.getElementById("inputP").value,
      },
    };
  } else if (action === "revise_section") {
    payload = {
      action_type: "revise_section",
      section: sectionSelect.value,
      revision_text: document.getElementById("inputRevision").value,
    };
  } else if (action === "request_clarify") {
    payload = {
      action_type: "request_clarify",
      clarify_question: document.getElementById("inputClarify").value,
    };
  }

  stepBtn.disabled = true;
  addLog(`Sending action: ${action}`);

  try {
    const res = await fetch(`${API}/step`, {
      method: "POST",
      headers: { "Content-Type": "application/json" },
      body: JSON.stringify(payload),
    });
    if (!res.ok) throw new Error(await res.text());
    const data = await res.json();
    updateUI(data.observation, data.reward, data.done);

    if (data.info && data.info.clarify_answer) {
      addLog(`Clarify answer: ${data.info.clarify_answer}`);
    }
    // Guard against a response with no reward object — previously this
    // dereferenced data.reward.value unconditionally and could throw.
    const rewardText = data.reward ? data.reward.value.toFixed(4) : "N/A";
    addLog(`Step done — reward: ${rewardText}, done: ${data.done}`);
  } catch (err) {
    addLog(`Step failed: ${err.message}`, "error");
  } finally {
    stepBtn.disabled = false;
  }
});
260
+
261
+ // Init
262
+ addLog("Frontend loaded. Select a task and click Reset.");
frontend/index.html ADDED
@@ -0,0 +1,392 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>Clinical Note Scribe</title>
7
+ <link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap" rel="stylesheet">
8
+ <style>
9
+ *, *::before, *::after { box-sizing: border-box; margin: 0; padding: 0; }
10
+
11
+ :root {
12
+ --bg: #0f1117;
13
+ --surface: #1a1d27;
14
+ --surface-2: #242836;
15
+ --border: #2e3345;
16
+ --text: #e4e6ef;
17
+ --text-muted: #8b8fa3;
18
+ --accent: #4f8cff;
19
+ --accent-glow: rgba(79, 140, 255, 0.15);
20
+ --green: #34d399;
21
+ --red: #f87171;
22
+ --yellow: #fbbf24;
23
+ --radius: 10px;
24
+ --font: 'Inter', -apple-system, sans-serif;
25
+ }
26
+
27
+ body {
28
+ font-family: var(--font);
29
+ background: var(--bg);
30
+ color: var(--text);
31
+ min-height: 100vh;
32
+ line-height: 1.5;
33
+ }
34
+
35
+ /* Header */
36
+ .header {
37
+ display: flex;
38
+ align-items: center;
39
+ justify-content: space-between;
40
+ padding: 16px 28px;
41
+ border-bottom: 1px solid var(--border);
42
+ background: var(--surface);
43
+ }
44
+ .header h1 {
45
+ font-size: 18px;
46
+ font-weight: 600;
47
+ display: flex;
48
+ align-items: center;
49
+ gap: 8px;
50
+ }
51
+ .header h1 span { font-size: 22px; }
52
+ .header-right {
53
+ display: flex;
54
+ align-items: center;
55
+ gap: 12px;
56
+ }
57
+ .status-badge {
58
+ font-size: 12px;
59
+ padding: 4px 10px;
60
+ border-radius: 20px;
61
+ font-weight: 500;
62
+ }
63
+ .status-badge.idle { background: var(--surface-2); color: var(--text-muted); }
64
+ .status-badge.active { background: rgba(52,211,153,0.15); color: var(--green); }
65
+ .status-badge.done { background: rgba(79,140,255,0.15); color: var(--accent); }
66
+
67
+ /* Layout */
68
+ .layout {
69
+ display: grid;
70
+ grid-template-columns: 1fr 1fr;
71
+ gap: 0;
72
+ height: calc(100vh - 57px);
73
+ }
74
+
75
+ .panel {
76
+ display: flex;
77
+ flex-direction: column;
78
+ overflow: hidden;
79
+ }
80
+ .panel-left { border-right: 1px solid var(--border); }
81
+
82
+ .panel-section {
83
+ padding: 16px 20px;
84
+ border-bottom: 1px solid var(--border);
85
+ }
86
+ .panel-section-title {
87
+ font-size: 11px;
88
+ font-weight: 600;
89
+ text-transform: uppercase;
90
+ letter-spacing: 0.8px;
91
+ color: var(--text-muted);
92
+ margin-bottom: 10px;
93
+ }
94
+
95
+ .scrollable {
96
+ flex: 1;
97
+ overflow-y: auto;
98
+ padding: 16px 20px;
99
+ }
100
+ .scrollable::-webkit-scrollbar { width: 6px; }
101
+ .scrollable::-webkit-scrollbar-track { background: transparent; }
102
+ .scrollable::-webkit-scrollbar-thumb { background: var(--border); border-radius: 3px; }
103
+
104
+ /* Controls */
105
+ .controls {
106
+ display: flex;
107
+ gap: 8px;
108
+ flex-wrap: wrap;
109
+ align-items: center;
110
+ }
111
+ select, input, textarea {
112
+ font-family: var(--font);
113
+ font-size: 13px;
114
+ background: var(--surface-2);
115
+ border: 1px solid var(--border);
116
+ color: var(--text);
117
+ border-radius: 6px;
118
+ padding: 7px 10px;
119
+ outline: none;
120
+ transition: border-color 0.15s;
121
+ }
122
+ select:focus, input:focus, textarea:focus {
123
+ border-color: var(--accent);
124
+ }
125
+ textarea {
126
+ width: 100%;
127
+ resize: vertical;
128
+ min-height: 80px;
129
+ }
130
+
131
+ .btn {
132
+ font-family: var(--font);
133
+ font-size: 13px;
134
+ font-weight: 500;
135
+ padding: 7px 16px;
136
+ border: none;
137
+ border-radius: 6px;
138
+ cursor: pointer;
139
+ transition: all 0.15s;
140
+ }
141
+ .btn:disabled { opacity: 0.4; cursor: not-allowed; }
142
+ .btn-primary { background: var(--accent); color: #fff; }
143
+ .btn-primary:hover:not(:disabled) { background: #3d7ae8; }
144
+ .btn-secondary { background: var(--surface-2); color: var(--text); border: 1px solid var(--border); }
145
+ .btn-secondary:hover:not(:disabled) { background: var(--border); }
146
+ .btn-green { background: var(--green); color: #0f1117; }
147
+ .btn-green:hover:not(:disabled) { background: #2bc48d; }
148
+
149
+ /* Transcript */
150
+ .transcript-box {
151
+ font-size: 13px;
152
+ white-space: pre-wrap;
153
+ color: var(--text);
154
+ line-height: 1.7;
155
+ }
156
+ .transcript-box .speaker-doctor { color: var(--accent); font-weight: 500; }
157
+ .transcript-box .speaker-patient { color: var(--green); font-weight: 500; }
158
+
159
+ /* Context cards */
160
+ .context-grid {
161
+ display: grid;
162
+ grid-template-columns: repeat(auto-fill, minmax(180px, 1fr));
163
+ gap: 8px;
164
+ }
165
+ .context-card {
166
+ background: var(--surface-2);
167
+ border-radius: 8px;
168
+ padding: 10px 12px;
169
+ }
170
+ .context-card .label {
171
+ font-size: 10px;
172
+ font-weight: 600;
173
+ text-transform: uppercase;
174
+ letter-spacing: 0.5px;
175
+ color: var(--text-muted);
176
+ margin-bottom: 3px;
177
+ }
178
+ .context-card .value {
179
+ font-size: 13px;
180
+ color: var(--text);
181
+ }
182
+
183
+ /* SOAP Draft */
184
+ .soap-grid {
185
+ display: grid;
186
+ grid-template-columns: 1fr 1fr;
187
+ gap: 10px;
188
+ }
189
+ .soap-card {
190
+ background: var(--surface-2);
191
+ border-radius: 8px;
192
+ padding: 12px 14px;
193
+ border-left: 3px solid var(--border);
194
+ }
195
+ .soap-card.s { border-left-color: #60a5fa; }
196
+ .soap-card.o { border-left-color: #34d399; }
197
+ .soap-card.a { border-left-color: #fbbf24; }
198
+ .soap-card.p { border-left-color: #c084fc; }
199
+ .soap-card .soap-label {
200
+ font-size: 11px;
201
+ font-weight: 600;
202
+ text-transform: uppercase;
203
+ letter-spacing: 0.5px;
204
+ margin-bottom: 6px;
205
+ }
206
+ .soap-card.s .soap-label { color: #60a5fa; }
207
+ .soap-card.o .soap-label { color: #34d399; }
208
+ .soap-card.a .soap-label { color: #fbbf24; }
209
+ .soap-card.p .soap-label { color: #c084fc; }
210
+ .soap-card .soap-text {
211
+ font-size: 13px;
212
+ color: var(--text-muted);
213
+ line-height: 1.6;
214
+ }
215
+
216
+ /* Reward bar */
217
+ .reward-bar {
218
+ display: flex;
219
+ align-items: center;
220
+ gap: 12px;
221
+ padding: 10px 0;
222
+ }
223
+ .reward-bar .score-value {
224
+ font-size: 28px;
225
+ font-weight: 700;
226
+ font-variant-numeric: tabular-nums;
227
+ }
228
+ .reward-meter {
229
+ flex: 1;
230
+ height: 8px;
231
+ background: var(--surface-2);
232
+ border-radius: 4px;
233
+ overflow: hidden;
234
+ }
235
+ .reward-meter-fill {
236
+ height: 100%;
237
+ border-radius: 4px;
238
+ background: linear-gradient(90deg, var(--accent), var(--green));
239
+ transition: width 0.4s ease;
240
+ }
241
+
242
+ /* Action form */
243
+ .action-form {
244
+ display: flex;
245
+ flex-direction: column;
246
+ gap: 10px;
247
+ }
248
+ .action-row {
249
+ display: flex;
250
+ gap: 8px;
251
+ align-items: flex-start;
252
+ }
253
+ .action-row select { min-width: 160px; }
254
+
255
+ /* Log */
256
+ .log-container {
257
+ font-size: 12px;
258
+ font-family: 'SF Mono', 'Fira Code', monospace;
259
+ color: var(--text-muted);
260
+ max-height: 120px;
261
+ overflow-y: auto;
262
+ padding: 8px 0;
263
+ }
264
+ .log-entry {
265
+ padding: 2px 0;
266
+ border-bottom: 1px solid rgba(46, 51, 69, 0.4);
267
+ }
268
+ .log-entry .log-time { color: var(--border); margin-right: 8px; }
269
+ .log-entry.error { color: var(--red); }
270
+ .log-entry.success { color: var(--green); }
271
+
272
+ /* Empty state */
273
+ .empty-state {
274
+ display: flex;
275
+ flex-direction: column;
276
+ align-items: center;
277
+ justify-content: center;
278
+ height: 100%;
279
+ color: var(--text-muted);
280
+ text-align: center;
281
+ gap: 12px;
282
+ }
283
+ .empty-state .icon { font-size: 36px; opacity: 0.5; }
284
+ .empty-state p { font-size: 14px; max-width: 280px; }
285
+ </style>
286
+ </head>
287
+ <body>
288
+
289
+ <header class="header">
290
+ <h1><span>🏥</span> Clinical Note Scribe</h1>
291
+ <div class="header-right">
292
+ <span class="status-badge idle" id="statusBadge">Idle</span>
293
+ </div>
294
+ </header>
295
+
296
+ <div class="layout">
297
+
298
+ <!-- Left Panel: Transcript + Context -->
299
+ <div class="panel panel-left">
300
+ <div class="panel-section">
301
+ <div class="controls">
302
+ <select id="taskSelect">
303
+ <option value="easy_routine_checkup">🟢 Easy — Routine Check-Up</option>
304
+ <option value="medium_chronic_disease_followup">🟡 Medium — Chronic Follow-Up</option>
305
+ <option value="hard_complex_er_visit">🔴 Hard — Complex ER Visit</option>
306
+ </select>
307
+ <button class="btn btn-primary" id="resetBtn">Reset</button>
308
+ </div>
309
+ </div>
310
+
311
+ <div class="panel-section" id="contextSection" style="display:none;">
312
+ <div class="panel-section-title">Patient Context</div>
313
+ <div class="context-grid" id="contextGrid"></div>
314
+ </div>
315
+
316
+ <div class="scrollable" id="transcriptArea">
317
+ <div class="empty-state">
318
+ <div class="icon">📋</div>
319
+ <p>Select a task and click <strong>Reset</strong> to load the transcript.</p>
320
+ </div>
321
+ </div>
322
+ </div>
323
+
324
+ <!-- Right Panel: Actions + Draft + Reward -->
325
+ <div class="panel">
326
+
327
+ <div class="panel-section" id="actionSection" style="display:none;">
328
+ <div class="panel-section-title">Action</div>
329
+ <div class="action-form">
330
+ <div class="action-row">
331
+ <select id="actionType">
332
+ <option value="submit_note">Submit Note</option>
333
+ <option value="revise_section">Revise Section</option>
334
+ <option value="request_clarify">Request Clarify</option>
335
+ </select>
336
+ <select id="sectionSelect" style="display:none;">
337
+ <option value="S">Subjective</option>
338
+ <option value="O">Objective</option>
339
+ <option value="A">Assessment</option>
340
+ <option value="P">Plan</option>
341
+ </select>
342
+ <button class="btn btn-green" id="stepBtn">Send</button>
343
+ </div>
344
+ <div id="soapInputs">
345
+ <div style="display:grid; grid-template-columns:1fr 1fr; gap:8px;">
346
+ <textarea id="inputS" placeholder="Subjective..." rows="3"></textarea>
347
+ <textarea id="inputO" placeholder="Objective..." rows="3"></textarea>
348
+ <textarea id="inputA" placeholder="Assessment..." rows="3"></textarea>
349
+ <textarea id="inputP" placeholder="Plan..." rows="3"></textarea>
350
+ </div>
351
+ </div>
352
+ <div id="reviseInput" style="display:none;">
353
+ <textarea id="inputRevision" placeholder="Revision text..." rows="3"></textarea>
354
+ </div>
355
+ <div id="clarifyInput" style="display:none;">
356
+ <input id="inputClarify" type="text" placeholder="Your question..." style="width:100%;">
357
+ </div>
358
+ </div>
359
+ </div>
360
+
361
+ <div class="panel-section" id="rewardSection" style="display:none;">
362
+ <div class="panel-section-title">Reward</div>
363
+ <div class="reward-bar">
364
+ <div class="score-value" id="scoreValue">0.00</div>
365
+ <div class="reward-meter">
366
+ <div class="reward-meter-fill" id="rewardFill" style="width:0%"></div>
367
+ </div>
368
+ </div>
369
+ </div>
370
+
371
+ <div class="scrollable" id="draftArea">
372
+ <div class="empty-state" id="draftEmpty">
373
+ <div class="icon">📝</div>
374
+ <p>Your SOAP note draft will appear here after you submit or revise.</p>
375
+ </div>
376
+ <div id="soapDraft" style="display:none;">
377
+ <div class="panel-section-title" style="margin-bottom:12px;">Current Draft</div>
378
+ <div class="soap-grid" id="soapGrid"></div>
379
+ </div>
380
+ </div>
381
+
382
+ <div class="panel-section" style="border-top:1px solid var(--border); border-bottom:none;">
383
+ <div class="panel-section-title">Log</div>
384
+ <div class="log-container" id="logContainer"></div>
385
+ </div>
386
+
387
+ </div>
388
+ </div>
389
+
390
+ <script src="/static/app.js"></script>
391
+ </body>
392
+ </html>
inference.py CHANGED
@@ -1,19 +1,18 @@
1
  """
2
  Inference Script — Clinical Note Scribe
3
- ===================================
4
  MANDATORY
5
- - Before submitting, ensure the following variables are defined in your environment configuration:
6
  API_BASE_URL The API endpoint for the LLM.
7
  MODEL_NAME The model identifier to use for inference.
8
  HF_TOKEN Your Hugging Face / API key.
9
- LOCAL_IMAGE_NAME The name of the local image to use for the environment
10
- if you are using from_docker_image() method.
11
 
12
  - Defaults are set only for API_BASE_URL and MODEL_NAME:
13
  API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
14
  MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
15
 
16
- - The inference script must be named `inference.py` and placed in the root directory.
17
  - Participants must use OpenAI Client for all LLM calls using above variables.
18
 
19
  STDOUT FORMAT
@@ -47,29 +46,20 @@
47
 
48
  from openai import OpenAI
49
 
50
- # ---------------------------------------------------------------------------
51
  # Silence the underlying env's stdout JSON logs (redirect them to stderr)
52
- # ---------------------------------------------------------------------------
53
  env_logger = logging.getLogger("clinical_note_scribe")
54
  env_logger.setLevel(logging.INFO)
55
  env_logger.handlers.clear()
56
  env_logger.addHandler(logging.StreamHandler(sys.stderr))
57
  env_logger.propagate = False
58
 
59
-
60
- # ---------------------------------------------------------------------------
61
  # Environment imports
62
- # ---------------------------------------------------------------------------
63
-
64
  sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
65
 
66
  from environment import ClinicalNoteScribeEnv, Action, SOAPNote # noqa: E402
67
  from environment.tasks import TASK_REGISTRY # noqa: E402
68
 
69
- # ---------------------------------------------------------------------------
70
  # Config
71
- # ---------------------------------------------------------------------------
72
-
73
  LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
74
  HF_TOKEN = os.getenv("HF_TOKEN")
75
  API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
@@ -77,14 +67,11 @@
77
 
78
  BENCHMARK = "clinical-note-scribe"
79
  TASK_IDS = list(TASK_REGISTRY.keys())
80
- MAX_STEPS = 5 # Max steps per task (submit + optional clarify/revise)
81
  MAX_TOKENS = 1024
82
  TEMPERATURE = 0.2
83
 
84
- # ---------------------------------------------------------------------------
85
  # System prompt
86
- # ---------------------------------------------------------------------------
87
-
88
  SYSTEM_PROMPT = textwrap.dedent("""\
89
  You are a clinical documentation assistant. Given a doctor-patient transcript
90
  and patient context, generate a concise, clinically accurate SOAP note.
@@ -109,10 +96,7 @@
109
  """).strip()
110
 
111
 
112
- # ---------------------------------------------------------------------------
113
  # Stdout logging — mandatory hackathon format
114
- # ---------------------------------------------------------------------------
115
-
116
  def log_start(task: str, env: str, model: str) -> None:
117
  print(f"[START] task={task} env={env} model={model}", flush=True)
118
 
@@ -134,10 +118,7 @@ def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> No
134
  )
135
 
136
 
137
- # ---------------------------------------------------------------------------
138
  # Helpers
139
- # ---------------------------------------------------------------------------
140
-
141
  def _build_user_prompt(transcript: str, patient_context: dict[str, Any]) -> str:
142
  """Build the user message containing the transcript and context."""
143
  ctx_str = json.dumps(patient_context, indent=2, default=str)
@@ -187,10 +168,7 @@ def get_soap_note(client: OpenAI, transcript: str, patient_context: dict[str, An
187
  raise
188
 
189
 
190
- # ---------------------------------------------------------------------------
191
  # Per-task runner
192
- # ---------------------------------------------------------------------------
193
-
194
  def run_task(client: OpenAI, env: ClinicalNoteScribeEnv, task_id: str) -> dict[str, Any]:
195
  """Run a single task episode and return the result dict."""
196
  rewards: List[float] = []
@@ -202,18 +180,15 @@ def run_task(client: OpenAI, env: ClinicalNoteScribeEnv, task_id: str) -> dict[s
202
  log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
203
 
204
  try:
205
- # ---- reset ----
206
  obs = env.reset(task_id)
207
 
208
  for step in range(1, MAX_STEPS + 1):
209
- # ---- generate SOAP note via LLM ----
210
  try:
211
  action_dict = get_soap_note(client, obs.transcript, obs.patient_context)
212
  action = Action(**action_dict)
213
  action_str = f"submit_note(sections=S,O,A,P)"
214
  except Exception as exc:
215
- # On model / parse failure, submit an empty note so all sub-signals
216
- # grade to 0.0 (format_valid=0 because fields are empty, grader=0).
217
  action = Action(
218
  action_type="submit_note",
219
  soap_note=SOAPNote(
@@ -226,14 +201,12 @@ def run_task(client: OpenAI, env: ClinicalNoteScribeEnv, task_id: str) -> dict[s
226
  action_str = "submit_note(fallback)"
227
  last_error = str(exc)
228
 
229
- # ---- step ----
230
  obs, reward_obj, done, info = env.step(action)
231
 
232
  reward_val = reward_obj.value
233
  rewards.append(reward_val)
234
  steps_taken = step
235
 
236
- # Check for env-level errors
237
  error_msg = None
238
  if obs.errors_so_far:
239
  error_msg = obs.errors_so_far[-1]
@@ -252,7 +225,6 @@ def run_task(client: OpenAI, env: ClinicalNoteScribeEnv, task_id: str) -> dict[s
252
  if done:
253
  break
254
 
255
- # Final score = last reward value (already in [0, 1])
256
  score = rewards[-1] if rewards else 0.0
257
  score = min(max(score, 0.0), 1.0)
258
  success = score > 0.0
@@ -274,10 +246,7 @@ def run_task(client: OpenAI, env: ClinicalNoteScribeEnv, task_id: str) -> dict[s
274
  }
275
 
276
 
277
- # ---------------------------------------------------------------------------
278
  # Main
279
- # ---------------------------------------------------------------------------
280
-
281
  def main() -> None:
282
  if not HF_TOKEN:
283
  print(
@@ -295,7 +264,7 @@ def main() -> None:
295
  result = run_task(client, env, task_id)
296
  results.append(result)
297
 
298
- # ---- Summary table ----
299
  print("", file=sys.stderr, flush=True)
300
  print("=" * 60, file=sys.stderr, flush=True)
301
  print(" SUMMARY", file=sys.stderr, flush=True)
 
1
  """
2
  Inference Script — Clinical Note Scribe
3
+ ========================================
4
  MANDATORY
5
+ - Before submitting, ensure the following variables are defined:
6
  API_BASE_URL The API endpoint for the LLM.
7
  MODEL_NAME The model identifier to use for inference.
8
  HF_TOKEN Your Hugging Face / API key.
9
+ LOCAL_IMAGE_NAME The name of the local image for the environment.
 
10
 
11
  - Defaults are set only for API_BASE_URL and MODEL_NAME:
12
  API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
13
  MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
14
 
15
+ - The inference script must be named inference.py and placed in the root directory.
16
  - Participants must use OpenAI Client for all LLM calls using above variables.
17
 
18
  STDOUT FORMAT
 
46
 
47
  from openai import OpenAI
48
 
 
49
  # Silence the underlying env's stdout JSON logs (redirect them to stderr)
 
50
  env_logger = logging.getLogger("clinical_note_scribe")
51
  env_logger.setLevel(logging.INFO)
52
  env_logger.handlers.clear()
53
  env_logger.addHandler(logging.StreamHandler(sys.stderr))
54
  env_logger.propagate = False
55
 
 
 
56
  # Environment imports
 
 
57
  sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
58
 
59
  from environment import ClinicalNoteScribeEnv, Action, SOAPNote # noqa: E402
60
  from environment.tasks import TASK_REGISTRY # noqa: E402
61
 
 
62
  # Config
 
 
63
  LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
64
  HF_TOKEN = os.getenv("HF_TOKEN")
65
  API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
 
67
 
68
  BENCHMARK = "clinical-note-scribe"
69
  TASK_IDS = list(TASK_REGISTRY.keys())
70
+ MAX_STEPS = 5
71
  MAX_TOKENS = 1024
72
  TEMPERATURE = 0.2
73
 
 
74
  # System prompt
 
 
75
  SYSTEM_PROMPT = textwrap.dedent("""\
76
  You are a clinical documentation assistant. Given a doctor-patient transcript
77
  and patient context, generate a concise, clinically accurate SOAP note.
 
96
  """).strip()
97
 
98
 
 
99
  # Stdout logging — mandatory hackathon format
 
 
100
  def log_start(task: str, env: str, model: str) -> None:
101
  print(f"[START] task={task} env={env} model={model}", flush=True)
102
 
 
118
  )
119
 
120
 
 
121
  # Helpers
 
 
122
  def _build_user_prompt(transcript: str, patient_context: dict[str, Any]) -> str:
123
  """Build the user message containing the transcript and context."""
124
  ctx_str = json.dumps(patient_context, indent=2, default=str)
 
168
  raise
169
 
170
 
 
171
  # Per-task runner
 
 
172
  def run_task(client: OpenAI, env: ClinicalNoteScribeEnv, task_id: str) -> dict[str, Any]:
173
  """Run a single task episode and return the result dict."""
174
  rewards: List[float] = []
 
180
  log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
181
 
182
  try:
 
183
  obs = env.reset(task_id)
184
 
185
  for step in range(1, MAX_STEPS + 1):
 
186
  try:
187
  action_dict = get_soap_note(client, obs.transcript, obs.patient_context)
188
  action = Action(**action_dict)
189
  action_str = f"submit_note(sections=S,O,A,P)"
190
  except Exception as exc:
191
+ # On model / parse failure, submit an empty note
 
192
  action = Action(
193
  action_type="submit_note",
194
  soap_note=SOAPNote(
 
201
  action_str = "submit_note(fallback)"
202
  last_error = str(exc)
203
 
 
204
  obs, reward_obj, done, info = env.step(action)
205
 
206
  reward_val = reward_obj.value
207
  rewards.append(reward_val)
208
  steps_taken = step
209
 
 
210
  error_msg = None
211
  if obs.errors_so_far:
212
  error_msg = obs.errors_so_far[-1]
 
225
  if done:
226
  break
227
 
 
228
  score = rewards[-1] if rewards else 0.0
229
  score = min(max(score, 0.0), 1.0)
230
  success = score > 0.0
 
246
  }
247
 
248
 
 
249
  # Main
 
 
250
  def main() -> None:
251
  if not HF_TOKEN:
252
  print(
 
264
  result = run_task(client, env, task_id)
265
  results.append(result)
266
 
267
+ # Summary table
268
  print("", file=sys.stderr, flush=True)
269
  print("=" * 60, file=sys.stderr, flush=True)
270
  print(" SUMMARY", file=sys.stderr, flush=True)
server/app.py CHANGED
@@ -1,60 +1,40 @@
1
- """FastAPI application for the Clinical Note Scribe environment.
2
-
3
- Run locally::
4
-
5
- uvicorn server.app:app --host 0.0.0.0 --port 8000 --reload
6
-
7
- Or via Docker (see ``Dockerfile`` in project root).
8
- """
9
 
10
  from __future__ import annotations
11
 
12
  import logging
13
  import sys
 
14
 
15
  from fastapi import FastAPI
 
 
16
 
17
  from server.routes import router
18
 
19
- # ---------------------------------------------------------------------------
20
- # Configure root logging → structured JSON to stdout
21
- # ---------------------------------------------------------------------------
22
-
23
- logging.basicConfig(
24
- level=logging.INFO,
25
- format="%(message)s",
26
- handlers=[logging.StreamHandler(sys.stdout)],
27
- )
28
-
29
- # Silence noisy uvicorn access logs so our structured events stay clean
30
  logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
31
 
32
- # ---------------------------------------------------------------------------
33
- # Application factory
34
- # ---------------------------------------------------------------------------
35
-
36
  app = FastAPI(
37
  title="Clinical Note Scribe – OpenEnv",
38
- description=(
39
- "An OpenEnv-compliant environment for evaluating AI agents on "
40
- "clinical SOAP-note generation from doctor–patient transcripts."
41
- ),
42
- version="0.1.0",
43
  )
 
44
 
45
- from fastapi.responses import RedirectResponse
 
46
 
47
- # Mount all routes at root (/)
48
- app.include_router(router)
49
 
50
  @app.get("/", include_in_schema=False)
51
  async def root():
52
- """Redirect to the FastAPI interactive documentation."""
53
- return RedirectResponse(url="/docs")
54
 
55
  def main():
56
  import uvicorn
57
  uvicorn.run("server.app:app", host="0.0.0.0", port=7860)
58
 
 
59
  if __name__ == "__main__":
60
  main()
 
1
+ """FastAPI application for the Clinical Note Scribe environment."""
 
 
 
 
 
 
 
2
 
3
  from __future__ import annotations
4
 
5
  import logging
6
  import sys
7
+ from pathlib import Path
8
 
9
  from fastapi import FastAPI
10
+ from fastapi.staticfiles import StaticFiles
11
+ from fastapi.responses import FileResponse
12
 
13
  from server.routes import router
14
 
15
+ logging.basicConfig(level=logging.INFO, format="%(message)s", handlers=[logging.StreamHandler(sys.stdout)])
 
 
 
 
 
 
 
 
 
 
16
  logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
17
 
 
 
 
 
18
  app = FastAPI(
19
  title="Clinical Note Scribe – OpenEnv",
20
+ description="OpenEnv-compliant environment for evaluating AI agents on clinical SOAP-note generation.",
21
+ version="1.0.0",
 
 
 
22
  )
23
+ app.include_router(router)
24
 
25
+ FRONTEND_DIR = Path(__file__).resolve().parent.parent / "frontend"
26
+ app.mount("/static", StaticFiles(directory=str(FRONTEND_DIR)), name="static")
27
 
 
 
28
 
29
  @app.get("/", include_in_schema=False)
30
  async def root():
31
+ return FileResponse(str(FRONTEND_DIR / "index.html"))
32
+
33
 
34
  def main():
35
  import uvicorn
36
  uvicorn.run("server.app:app", host="0.0.0.0", port=7860)
37
 
38
+
39
  if __name__ == "__main__":
40
  main()
server/routes.py CHANGED
@@ -1,61 +1,29 @@
1
- """FastAPI route definitions for the Clinical Note Scribe environment.
2
-
3
- Endpoints
4
- ---------
5
- POST /reset – start a new episode (takes ``task_id``, returns ``Observation``)
6
- POST /step – execute an action (takes ``Action``, returns step result)
7
- GET /state – inspect env state (returns ``EnvironmentState``)
8
- GET /health – liveness probe (returns ``{"status": "ok"}``)
9
-
10
- Structured logging
11
- ------------------
12
- The underlying ``ClinicalNoteScribeEnv`` already emits ``[START]``, ``[STEP]``,
13
- and ``[END]`` JSON lines to stdout via Python's ``logging`` module. This router
14
- adds a thin request-level log wrapper so every inbound HTTP call is also
15
- traceable.
16
- """
17
 
18
  from __future__ import annotations
19
 
20
- import logging
21
  import json
 
22
  import time
23
  from typing import Any, Optional
24
 
25
  from fastapi import APIRouter, HTTPException
26
- from pydantic import BaseModel, Field
27
-
28
- from environment.models import (
29
- Action,
30
- EnvironmentState,
31
- Observation,
32
- Reward,
33
- )
34
  from environment.env import ClinicalNoteScribeEnv
35
 
36
  logger = logging.getLogger("clinical_note_scribe.server")
37
-
38
-
39
- # ---------------------------------------------------------------------------
40
- # Singleton environment instance
41
- # ---------------------------------------------------------------------------
42
-
43
  _env = ClinicalNoteScribeEnv()
 
 
44
 
 
 
45
 
46
- # ---------------------------------------------------------------------------
47
- # Request / response schemas
48
- # ---------------------------------------------------------------------------
49
 
50
  class ResetRequest(BaseModel):
51
- task_id: Optional[str] = Field(
52
- default=None,
53
- description=(
54
- "Task to load. One of: easy_routine_checkup, "
55
- "medium_chronic_disease_followup, hard_complex_er_visit. "
56
- "Defaults to the first registered task."
57
- ),
58
- )
59
 
60
 
61
  class StepResponse(BaseModel):
@@ -69,114 +37,47 @@ class HealthResponse(BaseModel):
69
  status: str = "ok"
70
 
71
 
72
- # ---------------------------------------------------------------------------
73
- # Helpers
74
- # ---------------------------------------------------------------------------
75
-
76
- def _log(event: str, **kwargs: Any) -> None:
77
- """Emit a structured JSON log line to stdout."""
78
- payload: dict[str, Any] = {"event": event, "timestamp": time.time()}
79
- payload.update(kwargs)
80
- logger.info(json.dumps(payload, default=str))
81
-
82
-
83
- # ---------------------------------------------------------------------------
84
- # Router
85
- # ---------------------------------------------------------------------------
86
-
87
- router = APIRouter()
88
-
89
-
90
- @router.post(
91
- "/reset",
92
- response_model=Observation,
93
- summary="Reset the environment and start a new episode",
94
- )
95
  async def reset(body: Optional[ResetRequest] = None) -> Observation:
96
- """Load a task and return the initial ``Observation``.
97
-
98
- The underlying environment emits a ``[START]`` log event.
99
- """
100
  task_id = body.task_id if body else None
101
  _log("START", endpoint="/reset", task_id=task_id)
102
  try:
103
- obs = _env.reset(task_id=task_id)
104
  except ValueError as exc:
105
  raise HTTPException(status_code=400, detail=str(exc))
106
- return obs
107
 
108
 
109
- @router.post(
110
- "/step",
111
- response_model=StepResponse,
112
- summary="Submit an action and advance the environment by one step",
113
- )
114
  async def step(payload: dict[str, Any]) -> StepResponse:
115
- """Execute an action in the current episode.
116
-
117
- Accepts a raw JSON body and validates it into an ``Action``.
118
- If validation fails, the error is recorded in the environment
119
- instead of returning an HTTP 422.
120
- """
121
- from pydantic import ValidationError
122
- from environment.models import Reward
123
-
124
  try:
125
  action = Action(**payload)
126
  except (ValidationError, TypeError) as exc:
127
- # Gracefully absorb bad payloads instead of crashing with HTTP 422
128
  _log("STEP", endpoint="/step", action_type="invalid", error=str(exc))
129
  error_msg = f"Invalid action payload: {exc}"
130
  _env._errors_so_far.append(error_msg)
131
  _env._step_count += 1
132
-
133
- obs = _env._build_observation()
134
- reward = Reward(
135
- value=0.0,
136
- signals={"error": 1.0},
137
- done=False,
138
- info={"error": error_msg},
139
- )
140
  return StepResponse(
141
- observation=obs,
142
- reward=reward,
143
- done=False,
144
- info={"error": error_msg},
145
  )
146
 
147
  _log("STEP", endpoint="/step", action_type=action.action_type)
148
  try:
149
  obs, reward, done, info = _env.step(action)
150
  except RuntimeError as exc:
151
- # e.g. stepping after episode is done without reset
152
  raise HTTPException(status_code=409, detail=str(exc))
153
 
154
  if done:
155
  _log("END", endpoint="/step", final_score=reward.value)
156
-
157
- return StepResponse(
158
- observation=obs,
159
- reward=reward,
160
- done=done,
161
- info=info,
162
- )
163
 
164
 
165
- @router.get(
166
- "/state",
167
- response_model=EnvironmentState,
168
- summary="Return the full internal environment state",
169
- )
170
  async def state() -> EnvironmentState:
171
- """Inspect the environment without mutating it."""
172
  return _env.state()
173
 
174
 
175
- @router.get(
176
- "/health",
177
- response_model=HealthResponse,
178
- summary="Liveness probe",
179
- )
180
  async def health() -> HealthResponse:
181
- """Returns HTTP 200 with ``{"status": "ok"}``."""
182
  return HealthResponse()
 
1
+ """FastAPI routes for the Clinical Note Scribe environment."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  from __future__ import annotations
4
 
 
5
  import json
6
+ import logging
7
  import time
8
  from typing import Any, Optional
9
 
10
  from fastapi import APIRouter, HTTPException
11
+ from pydantic import BaseModel, Field, ValidationError
12
+
13
+ from environment.models import Action, EnvironmentState, Observation, Reward
 
 
 
 
 
14
  from environment.env import ClinicalNoteScribeEnv
15
 
16
  logger = logging.getLogger("clinical_note_scribe.server")
 
 
 
 
 
 
17
  _env = ClinicalNoteScribeEnv()
18
+ router = APIRouter()
19
+
20
 
21
def _log(event: str, **kw: Any) -> None:
    """Emit one structured JSON log line: event name, timestamp, then extra fields."""
    payload: dict[str, Any] = {"event": event, "timestamp": time.time()}
    payload.update(kw)
    # default=str keeps non-serializable values (e.g. exceptions) loggable.
    logger.info(json.dumps(payload, default=str))
23
 
 
 
 
24
 
25
class ResetRequest(BaseModel):
    """Request body accepted by ``POST /reset``."""

    # A ``None`` task_id lets the environment fall back to its default task.
    task_id: Optional[str] = Field(
        default=None,
        description="Task to load. Defaults to first registered task.",
    )
 
 
 
 
 
 
 
27
 
28
 
29
  class StepResponse(BaseModel):
 
37
  status: str = "ok"
38
 
39
 
40
@router.post("/reset", response_model=Observation, summary="Reset and start a new episode")
async def reset(body: Optional[ResetRequest] = None) -> Observation:
    """Load a task and return the initial ``Observation``.

    A request-level START line is logged here; the underlying environment
    emits its own ``[START]`` event as well.

    Raises:
        HTTPException: 400 when ``task_id`` does not name a registered task
            (surfaced by the environment as a ``ValueError``).
    """
    task_id = body.task_id if body else None
    _log("START", endpoint="/reset", task_id=task_id)
    try:
        return _env.reset(task_id=task_id)
    except ValueError as exc:
        # Chain the cause (B904) so server tracebacks show the original error.
        raise HTTPException(status_code=400, detail=str(exc)) from exc
 
48
 
49
 
50
@router.post("/step", response_model=StepResponse, summary="Submit an action")
async def step(payload: dict[str, Any]) -> StepResponse:
    """Execute one action in the current episode.

    The raw JSON body is validated into an ``Action`` manually (rather than in
    the route signature) so malformed payloads are absorbed as an in-episode
    error with zero reward instead of surfacing as an HTTP 422.

    Raises:
        HTTPException: 409 when stepping a finished episode without a reset.
    """
    try:
        action = Action(**payload)
    except (ValidationError, TypeError) as exc:
        _log("STEP", endpoint="/step", action_type="invalid", error=str(exc))
        error_msg = f"Invalid action payload: {exc}"
        # NOTE(review): reaches into private env state; consider exposing a
        # public ``record_error()`` API on ClinicalNoteScribeEnv instead.
        _env._errors_so_far.append(error_msg)
        _env._step_count += 1
        return StepResponse(
            observation=_env._obs(),
            reward=Reward(value=0.0, signals={"error": 1.0}, done=False, info={"error": error_msg}),
            done=False,
            info={"error": error_msg},
        )

    _log("STEP", endpoint="/step", action_type=action.action_type)
    try:
        obs, reward, done, info = _env.step(action)
    except RuntimeError as exc:
        # e.g. stepping after the episode is done without reset; chain the
        # cause (B904) so the 409 traceback keeps the original RuntimeError.
        raise HTTPException(status_code=409, detail=str(exc)) from exc

    if done:
        _log("END", endpoint="/step", final_score=reward.value)
    return StepResponse(observation=obs, reward=reward, done=done, info=info)
 
 
 
 
 
 
74
 
75
 
76
@router.get("/state", response_model=EnvironmentState, summary="Inspect environment state")
async def state() -> EnvironmentState:
    """Return the environment's full internal state without mutating it."""
    snapshot = _env.state()
    return snapshot
79
 
80
 
81
@router.get("/health", response_model=HealthResponse, summary="Liveness probe")
async def health() -> HealthResponse:
    """Liveness probe; reports that the server process is up."""
    response = HealthResponse()
    return response