Aman Khare commited on
Commit ·
8b7bdb7
1
Parent(s): 3c850f8
Optimize codebase + add minimalist frontend
Browse files- environment/env.py +95 -305
- environment/models.py +28 -156
- environment/reward.py +70 -166
- environment/tasks/task_easy.py +19 -67
- environment/tasks/task_hard.py +24 -100
- environment/tasks/task_medium.py +21 -78
- frontend/app.js +262 -0
- frontend/index.html +392 -0
- inference.py +7 -38
- server/app.py +13 -33
- server/routes.py +19 -118
environment/env.py
CHANGED
|
@@ -1,18 +1,4 @@
|
|
| 1 |
-
"""ClinicalNoteScribeEnv — core environment
|
| 2 |
-
|
| 3 |
-
Implements the ``reset() → Observation``, ``step(Action) → (Observation, Reward, bool, dict)``,
|
| 4 |
-
and ``state() → EnvironmentState`` interface required by the OpenEnv spec.
|
| 5 |
-
|
| 6 |
-
Structured logging
|
| 7 |
-
------------------
|
| 8 |
-
Every episode emits exactly three kinds of JSON log lines to **stdout**:
|
| 9 |
-
|
| 10 |
-
- ``{"event": "START", "task_id": "...", "timestamp": ...}``
|
| 11 |
-
- ``{"event": "STEP", "step": N, "action_type": "...", "reward": R}``
|
| 12 |
-
- ``{"event": "END", "task_id": "...", "final_score": S}``
|
| 13 |
-
|
| 14 |
-
The OpenEnv validator scrapes ``[START]``, ``[STEP]``, ``[END]`` keywords.
|
| 15 |
-
"""
|
| 16 |
|
| 17 |
from __future__ import annotations
|
| 18 |
|
|
@@ -27,115 +13,55 @@
|
|
| 27 |
from environment.tasks import GRADER_REGISTRY, TASK_REGISTRY
|
| 28 |
|
| 29 |
logger = logging.getLogger("clinical_note_scribe")
|
|
|
|
| 30 |
|
| 31 |
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
# ---------------------------------------------------------------------------
|
| 35 |
-
|
| 36 |
-
def _load_transcript(transcript_path: str) -> str:
|
| 37 |
-
"""Load a transcript text file from *project-root-relative* path."""
|
| 38 |
-
base = Path(__file__).resolve().parent.parent # clinical-note-scribe/
|
| 39 |
-
full_path = base / transcript_path
|
| 40 |
-
if full_path.exists():
|
| 41 |
-
return full_path.read_text(encoding="utf-8")
|
| 42 |
-
return f"[Transcript file not found: {transcript_path}]"
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
def _log_event(event: str, **kwargs: Any) -> None:
|
| 46 |
-
"""Emit a structured JSON log line to stdout via the logger."""
|
| 47 |
-
payload: dict[str, Any] = {"event": event, "timestamp": time.time()}
|
| 48 |
-
payload.update(kwargs)
|
| 49 |
-
logger.info(json.dumps(payload))
|
| 50 |
|
| 51 |
|
| 52 |
-
|
| 53 |
-
"""
|
| 54 |
-
return (
|
| 55 |
-
f"S: {soap.subjective}\n"
|
| 56 |
-
f"O: {soap.objective}\n"
|
| 57 |
-
f"A: {soap.assessment}\n"
|
| 58 |
-
f"P: {soap.plan}"
|
| 59 |
-
)
|
| 60 |
-
|
| 61 |
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
|
|
|
| 65 |
|
| 66 |
-
|
| 67 |
-
"""Open-environment wrapper for the clinical note-scribe tasks.
|
| 68 |
-
|
| 69 |
-
Lifecycle
|
| 70 |
-
---------
|
| 71 |
-
1. ``env.reset(task_id)`` → returns initial ``Observation``
|
| 72 |
-
2. ``env.step(action)`` → returns ``(Observation, Reward, done, info)``
|
| 73 |
-
3. ``env.state()`` → returns full ``EnvironmentState`` snapshot
|
| 74 |
-
|
| 75 |
-
Parameters
|
| 76 |
-
----------
|
| 77 |
-
clarify_answers_path:
|
| 78 |
-
Project-root-relative path to the clarification lookup JSON.
|
| 79 |
-
"""
|
| 80 |
-
|
| 81 |
-
def __init__(
|
| 82 |
-
self,
|
| 83 |
-
clarify_answers_path: str = "data/clarify_answers.json",
|
| 84 |
-
) -> None:
|
| 85 |
-
self._clarify_answers: dict[str, str] = {}
|
| 86 |
-
base = Path(__file__).resolve().parent.parent
|
| 87 |
-
ca_path = base / clarify_answers_path
|
| 88 |
-
if ca_path.exists():
|
| 89 |
-
self._clarify_answers = json.loads(ca_path.read_text(encoding="utf-8"))
|
| 90 |
-
|
| 91 |
-
# Episode state (initialised properly in reset())
|
| 92 |
self._task: dict[str, Any] = {}
|
| 93 |
-
self._task_id
|
| 94 |
-
self._transcript
|
| 95 |
self._patient_context: dict[str, Any] = {}
|
| 96 |
-
self._max_steps
|
| 97 |
-
self._step_count
|
| 98 |
-
self._done
|
| 99 |
self._current_draft: str | None = None
|
| 100 |
self._errors_so_far: list[str] = []
|
| 101 |
self._last_reward: Reward | None = None
|
| 102 |
self._last_observation: Observation | None = None
|
| 103 |
|
| 104 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
# Public API
|
| 106 |
-
# --------------------------------------------------------------------- #
|
| 107 |
|
| 108 |
def reset(self, task_id: str | None = None) -> Observation:
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
Parameters
|
| 112 |
-
----------
|
| 113 |
-
task_id:
|
| 114 |
-
One of the keys in ``TASK_REGISTRY``. When ``None`` the first
|
| 115 |
-
registered task is used.
|
| 116 |
-
|
| 117 |
-
Returns
|
| 118 |
-
-------
|
| 119 |
-
Observation
|
| 120 |
-
The initial observation for the episode.
|
| 121 |
-
|
| 122 |
-
Raises
|
| 123 |
-
------
|
| 124 |
-
ValueError
|
| 125 |
-
If *task_id* is not found in the registry.
|
| 126 |
-
"""
|
| 127 |
-
if task_id is None:
|
| 128 |
-
task_id = next(iter(TASK_REGISTRY))
|
| 129 |
-
|
| 130 |
if task_id not in TASK_REGISTRY:
|
| 131 |
-
|
| 132 |
-
raise ValueError(
|
| 133 |
-
f"Unknown task_id '{task_id}'. Available: {available}"
|
| 134 |
-
)
|
| 135 |
|
| 136 |
self._task = TASK_REGISTRY[task_id]
|
| 137 |
self._task_id = task_id
|
| 138 |
-
|
|
|
|
| 139 |
self._patient_context = self._task.get("patient_context", {})
|
| 140 |
self._max_steps = self._task.get("max_steps", 10)
|
| 141 |
self._step_count = 0
|
|
@@ -144,254 +70,118 @@ def reset(self, task_id: str | None = None) -> Observation:
|
|
| 144 |
self._errors_so_far = []
|
| 145 |
self._last_reward = None
|
| 146 |
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
self._last_observation = obs
|
| 151 |
-
return obs
|
| 152 |
|
| 153 |
def step(self, action: Action) -> tuple[Observation, Reward, bool, dict]:
|
| 154 |
-
"""Execute one agent action and return the resulting observation, reward,
|
| 155 |
-
done flag, and info dict.
|
| 156 |
-
|
| 157 |
-
Parameters
|
| 158 |
-
----------
|
| 159 |
-
action:
|
| 160 |
-
The agent's chosen action.
|
| 161 |
-
|
| 162 |
-
Returns
|
| 163 |
-
-------
|
| 164 |
-
tuple[Observation, Reward, bool, dict]
|
| 165 |
-
"""
|
| 166 |
if self._done:
|
| 167 |
-
raise RuntimeError(
|
| 168 |
-
"Episode is done. Call reset() before stepping again."
|
| 169 |
-
)
|
| 170 |
|
| 171 |
self._step_count += 1
|
| 172 |
info: dict[str, Any] = {}
|
| 173 |
|
| 174 |
-
#
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
|
|
|
|
|
|
| 181 |
else:
|
| 182 |
-
|
| 183 |
-
self.
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
reward = compute_reward(
|
| 187 |
-
action,
|
| 188 |
-
grader_score=0.0,
|
| 189 |
-
step_count=self._step_count,
|
| 190 |
-
errors_so_far=self._errors_so_far,
|
| 191 |
-
done=False,
|
| 192 |
-
info={"error": "bad_action"},
|
| 193 |
-
)
|
| 194 |
-
|
| 195 |
-
# ---- enforce max-step termination ----
|
| 196 |
if self._step_count >= self._max_steps and not self._done:
|
| 197 |
self._done = True
|
| 198 |
-
reward = Reward(
|
| 199 |
-
|
| 200 |
-
signals=reward.signals,
|
| 201 |
-
done=True,
|
| 202 |
-
info={**reward.info, "termination_reason": "max_steps_reached"},
|
| 203 |
-
)
|
| 204 |
|
| 205 |
self._last_reward = reward
|
| 206 |
-
|
| 207 |
-
_log_event(
|
| 208 |
-
"STEP",
|
| 209 |
-
step=self._step_count,
|
| 210 |
-
action_type=action.action_type,
|
| 211 |
-
reward=reward.value,
|
| 212 |
-
)
|
| 213 |
-
|
| 214 |
if self._done:
|
| 215 |
-
|
| 216 |
-
"END",
|
| 217 |
-
task_id=self._task_id,
|
| 218 |
-
final_score=reward.value,
|
| 219 |
-
)
|
| 220 |
|
| 221 |
-
|
| 222 |
-
self._last_observation
|
| 223 |
-
return obs, reward, self._done, info
|
| 224 |
|
| 225 |
def state(self) -> EnvironmentState:
|
| 226 |
-
"""Return the full internal state snapshot."""
|
| 227 |
return EnvironmentState(
|
| 228 |
-
task_id=self._task_id,
|
| 229 |
-
|
| 230 |
-
max_steps=self._max_steps,
|
| 231 |
-
done=self._done,
|
| 232 |
-
current_draft=self._current_draft,
|
| 233 |
errors_so_far=list(self._errors_so_far),
|
| 234 |
-
last_reward=self._last_reward,
|
| 235 |
-
observation=self._last_observation,
|
| 236 |
)
|
| 237 |
|
| 238 |
-
# --------------------------------------------------------------------- #
|
| 239 |
# Action handlers
|
| 240 |
-
# --------------------------------------------------------------------- #
|
| 241 |
-
|
| 242 |
-
def _handle_submit(self, action: Action, info: dict) -> Reward:
|
| 243 |
-
"""Process a ``submit_note`` action.
|
| 244 |
|
| 245 |
-
|
| 246 |
-
Otherwise, if the agent has built up a draft via ``revise_section``,
|
| 247 |
-
the draft is parsed into a SOAPNote automatically.
|
| 248 |
-
"""
|
| 249 |
soap = action.soap_note
|
| 250 |
|
| 251 |
-
# Fall back to
|
| 252 |
if soap is None and self._current_draft:
|
| 253 |
-
|
| 254 |
for line in self._current_draft.split("\n"):
|
| 255 |
-
for
|
| 256 |
-
if line.startswith(
|
| 257 |
-
|
| 258 |
-
if all(k in
|
| 259 |
-
soap = SOAPNote(
|
| 260 |
-
subjective=sections["S"],
|
| 261 |
-
objective=sections["O"],
|
| 262 |
-
assessment=sections["A"],
|
| 263 |
-
plan=sections["P"],
|
| 264 |
-
)
|
| 265 |
|
| 266 |
if soap is None:
|
| 267 |
-
|
| 268 |
-
self._errors_so_far.append(
|
| 269 |
-
return compute_reward(
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
step_count=self._step_count,
|
| 273 |
-
errors_so_far=self._errors_so_far,
|
| 274 |
-
done=False,
|
| 275 |
-
info={"error": error},
|
| 276 |
-
)
|
| 277 |
-
|
| 278 |
-
self._current_draft = _soap_to_text(soap)
|
| 279 |
self._done = True
|
| 280 |
|
| 281 |
-
# Attempt to grade via the task-specific grader
|
| 282 |
grader = GRADER_REGISTRY.get(self._task_id)
|
| 283 |
-
if
|
| 284 |
info["warning"] = "No grader registered; returning default reward."
|
| 285 |
-
return compute_reward(
|
| 286 |
-
action,
|
| 287 |
-
grader_score=0.5,
|
| 288 |
-
step_count=self._step_count,
|
| 289 |
-
errors_so_far=self._errors_so_far,
|
| 290 |
-
done=True,
|
| 291 |
-
info=info,
|
| 292 |
-
)
|
| 293 |
|
| 294 |
try:
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
grader_score = (
|
| 299 |
-
sum(raw_signals.values()) / len(raw_signals)
|
| 300 |
-
if raw_signals else 0.0
|
| 301 |
-
)
|
| 302 |
-
info["grader_signals"] = raw_signals
|
| 303 |
except Exception as exc:
|
| 304 |
info["warning"] = f"Grader error: {exc}"
|
| 305 |
-
|
| 306 |
|
| 307 |
-
return compute_reward(
|
| 308 |
-
action,
|
| 309 |
-
grader_score=grader_score,
|
| 310 |
-
step_count=self._step_count,
|
| 311 |
-
errors_so_far=self._errors_so_far,
|
| 312 |
-
done=True,
|
| 313 |
-
info=info,
|
| 314 |
-
)
|
| 315 |
|
| 316 |
-
def
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
done=False,
|
| 326 |
-
info={"error": error},
|
| 327 |
-
)
|
| 328 |
-
|
| 329 |
-
# Lookup a canned answer (case-insensitive key match)
|
| 330 |
-
answer = self._clarify_answers.get(question.lower())
|
| 331 |
-
if answer:
|
| 332 |
-
info["clarify_answer"] = answer
|
| 333 |
-
else:
|
| 334 |
-
info["clarify_answer"] = (
|
| 335 |
-
"No additional information available for that question."
|
| 336 |
-
)
|
| 337 |
-
|
| 338 |
-
# Intermediate actions get zero reward — only submit_note earns score
|
| 339 |
-
return Reward(
|
| 340 |
-
value=0.0,
|
| 341 |
-
signals={"intermediate_step": 1.0},
|
| 342 |
-
done=False,
|
| 343 |
-
info=info,
|
| 344 |
-
)
|
| 345 |
|
| 346 |
-
def
|
| 347 |
-
"""Process a ``revise_section`` action."""
|
| 348 |
if action.section is None or action.revision_text is None:
|
| 349 |
-
|
| 350 |
-
self._errors_so_far.append(
|
| 351 |
-
return Reward(
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
done=False,
|
| 355 |
-
info={"error": error},
|
| 356 |
-
)
|
| 357 |
-
|
| 358 |
-
# If there is an existing draft, patch the requested section
|
| 359 |
if self._current_draft:
|
| 360 |
lines = self._current_draft.split("\n")
|
| 361 |
-
prefix = f"{action.section}: "
|
| 362 |
-
patched = False
|
| 363 |
for i, line in enumerate(lines):
|
| 364 |
if line.startswith(prefix):
|
| 365 |
-
lines[i] =
|
| 366 |
-
|
| 367 |
break
|
| 368 |
-
if patched:
|
| 369 |
-
self._current_draft = "\n".join(lines)
|
| 370 |
else:
|
| 371 |
self._current_draft += f"\n{prefix}{action.revision_text}"
|
| 372 |
else:
|
| 373 |
-
self._current_draft =
|
| 374 |
|
| 375 |
info["revised_section"] = action.section
|
| 376 |
-
|
| 377 |
-
# Intermediate actions get zero reward — only submit_note earns score
|
| 378 |
-
return Reward(
|
| 379 |
-
value=0.0,
|
| 380 |
-
signals={"intermediate_step": 1.0},
|
| 381 |
-
done=False,
|
| 382 |
-
info=info,
|
| 383 |
-
)
|
| 384 |
-
|
| 385 |
-
# --------------------------------------------------------------------- #
|
| 386 |
-
# Internal helpers
|
| 387 |
-
# --------------------------------------------------------------------- #
|
| 388 |
-
|
| 389 |
-
def _build_observation(self) -> Observation:
|
| 390 |
-
return Observation(
|
| 391 |
-
transcript=self._transcript,
|
| 392 |
-
task_id=self._task_id,
|
| 393 |
-
patient_context=self._patient_context,
|
| 394 |
-
current_draft=self._current_draft,
|
| 395 |
-
errors_so_far=list(self._errors_so_far),
|
| 396 |
-
step_count=self._step_count,
|
| 397 |
-
)
|
|
|
|
| 1 |
+
"""ClinicalNoteScribeEnv — core environment implementing reset/step/state."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
from __future__ import annotations
|
| 4 |
|
|
|
|
| 13 |
from environment.tasks import GRADER_REGISTRY, TASK_REGISTRY
|
| 14 |
|
| 15 |
logger = logging.getLogger("clinical_note_scribe")
|
| 16 |
+
_ROOT = Path(__file__).resolve().parent.parent
|
| 17 |
|
| 18 |
|
| 19 |
+
def _log(event: str, **kw: Any) -> None:
|
| 20 |
+
logger.info(json.dumps({"event": event, "timestamp": time.time(), **kw}))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
|
| 23 |
+
class ClinicalNoteScribeEnv:
|
| 24 |
+
"""OpenEnv-compliant environment for clinical SOAP-note generation."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
+
def __init__(self, clarify_answers_path: str = "data/clarify_answers.json") -> None:
|
| 27 |
+
ca = _ROOT / clarify_answers_path
|
| 28 |
+
self._clarify_answers: dict[str, str] = json.loads(ca.read_text()) if ca.exists() else {}
|
| 29 |
+
self._reset_state()
|
| 30 |
|
| 31 |
+
def _reset_state(self) -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
self._task: dict[str, Any] = {}
|
| 33 |
+
self._task_id = ""
|
| 34 |
+
self._transcript = ""
|
| 35 |
self._patient_context: dict[str, Any] = {}
|
| 36 |
+
self._max_steps = 10
|
| 37 |
+
self._step_count = 0
|
| 38 |
+
self._done = True
|
| 39 |
self._current_draft: str | None = None
|
| 40 |
self._errors_so_far: list[str] = []
|
| 41 |
self._last_reward: Reward | None = None
|
| 42 |
self._last_observation: Observation | None = None
|
| 43 |
|
| 44 |
+
def _obs(self) -> Observation:
|
| 45 |
+
return Observation(
|
| 46 |
+
transcript=self._transcript,
|
| 47 |
+
task_id=self._task_id,
|
| 48 |
+
patient_context=self._patient_context,
|
| 49 |
+
current_draft=self._current_draft,
|
| 50 |
+
errors_so_far=list(self._errors_so_far),
|
| 51 |
+
step_count=self._step_count,
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
# Public API
|
|
|
|
| 55 |
|
| 56 |
def reset(self, task_id: str | None = None) -> Observation:
|
| 57 |
+
task_id = task_id or next(iter(TASK_REGISTRY))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
if task_id not in TASK_REGISTRY:
|
| 59 |
+
raise ValueError(f"Unknown task_id '{task_id}'. Available: {', '.join(TASK_REGISTRY)}")
|
|
|
|
|
|
|
|
|
|
| 60 |
|
| 61 |
self._task = TASK_REGISTRY[task_id]
|
| 62 |
self._task_id = task_id
|
| 63 |
+
path = _ROOT / self._task["transcript_file"]
|
| 64 |
+
self._transcript = path.read_text(encoding="utf-8") if path.exists() else f"[Not found: {path}]"
|
| 65 |
self._patient_context = self._task.get("patient_context", {})
|
| 66 |
self._max_steps = self._task.get("max_steps", 10)
|
| 67 |
self._step_count = 0
|
|
|
|
| 70 |
self._errors_so_far = []
|
| 71 |
self._last_reward = None
|
| 72 |
|
| 73 |
+
_log("START", task_id=task_id)
|
| 74 |
+
self._last_observation = self._obs()
|
| 75 |
+
return self._last_observation
|
|
|
|
|
|
|
| 76 |
|
| 77 |
def step(self, action: Action) -> tuple[Observation, Reward, bool, dict]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
if self._done:
|
| 79 |
+
raise RuntimeError("Episode is done. Call reset() before stepping again.")
|
|
|
|
|
|
|
| 80 |
|
| 81 |
self._step_count += 1
|
| 82 |
info: dict[str, Any] = {}
|
| 83 |
|
| 84 |
+
# Dispatch
|
| 85 |
+
handler = {
|
| 86 |
+
"submit_note": self._submit,
|
| 87 |
+
"request_clarify": self._clarify,
|
| 88 |
+
"revise_section": self._revise,
|
| 89 |
+
}.get(action.action_type)
|
| 90 |
+
|
| 91 |
+
if handler:
|
| 92 |
+
reward = handler(action, info)
|
| 93 |
else:
|
| 94 |
+
self._errors_so_far.append(f"Unknown action_type: {action.action_type}")
|
| 95 |
+
reward = compute_reward(action, 0.0, self._step_count, self._errors_so_far, done=False, info={"error": "bad_action"})
|
| 96 |
+
|
| 97 |
+
# Max-step termination
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
if self._step_count >= self._max_steps and not self._done:
|
| 99 |
self._done = True
|
| 100 |
+
reward = Reward(value=reward.value, signals=reward.signals, done=True,
|
| 101 |
+
info={**reward.info, "termination_reason": "max_steps_reached"})
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
|
| 103 |
self._last_reward = reward
|
| 104 |
+
_log("STEP", step=self._step_count, action_type=action.action_type, reward=reward.value)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
if self._done:
|
| 106 |
+
_log("END", task_id=self._task_id, final_score=reward.value)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
|
| 108 |
+
self._last_observation = self._obs()
|
| 109 |
+
return self._last_observation, reward, self._done, info
|
|
|
|
| 110 |
|
| 111 |
def state(self) -> EnvironmentState:
|
|
|
|
| 112 |
return EnvironmentState(
|
| 113 |
+
task_id=self._task_id, step_count=self._step_count, max_steps=self._max_steps,
|
| 114 |
+
done=self._done, current_draft=self._current_draft,
|
|
|
|
|
|
|
|
|
|
| 115 |
errors_so_far=list(self._errors_so_far),
|
| 116 |
+
last_reward=self._last_reward, observation=self._last_observation,
|
|
|
|
| 117 |
)
|
| 118 |
|
|
|
|
| 119 |
# Action handlers
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
|
| 121 |
+
def _submit(self, action: Action, info: dict) -> Reward:
|
|
|
|
|
|
|
|
|
|
| 122 |
soap = action.soap_note
|
| 123 |
|
| 124 |
+
# Fall back to draft
|
| 125 |
if soap is None and self._current_draft:
|
| 126 |
+
secs = {}
|
| 127 |
for line in self._current_draft.split("\n"):
|
| 128 |
+
for p in ("S: ", "O: ", "A: ", "P: "):
|
| 129 |
+
if line.startswith(p):
|
| 130 |
+
secs[p[0]] = line[len(p):]
|
| 131 |
+
if all(k in secs for k in "SOAP"):
|
| 132 |
+
soap = SOAPNote(subjective=secs["S"], objective=secs["O"], assessment=secs["A"], plan=secs["P"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
|
| 134 |
if soap is None:
|
| 135 |
+
err = "submit_note requires a non-null soap_note (or a complete draft from revise_section)."
|
| 136 |
+
self._errors_so_far.append(err)
|
| 137 |
+
return compute_reward(action, 0.0, self._step_count, self._errors_so_far, done=False, info={"error": err})
|
| 138 |
+
|
| 139 |
+
self._current_draft = f"S: {soap.subjective}\nO: {soap.objective}\nA: {soap.assessment}\nP: {soap.plan}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 140 |
self._done = True
|
| 141 |
|
|
|
|
| 142 |
grader = GRADER_REGISTRY.get(self._task_id)
|
| 143 |
+
if not grader:
|
| 144 |
info["warning"] = "No grader registered; returning default reward."
|
| 145 |
+
return compute_reward(action, 0.5, self._step_count, self._errors_so_far, done=True, info=info)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
|
| 147 |
try:
|
| 148 |
+
signals = grader(soap, self._task)
|
| 149 |
+
score = sum(signals.values()) / len(signals) if signals else 0.0
|
| 150 |
+
info["grader_signals"] = signals
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
except Exception as exc:
|
| 152 |
info["warning"] = f"Grader error: {exc}"
|
| 153 |
+
score = 0.0
|
| 154 |
|
| 155 |
+
return compute_reward(action, score, self._step_count, self._errors_so_far, done=True, info=info)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
|
| 157 |
+
def _clarify(self, action: Action, info: dict) -> Reward:
|
| 158 |
+
q = (action.clarify_question or "").strip()
|
| 159 |
+
if not q:
|
| 160 |
+
err = "request_clarify requires a non-empty clarify_question."
|
| 161 |
+
self._errors_so_far.append(err)
|
| 162 |
+
return Reward(value=0.0, signals={"error": 1.0}, done=False, info={"error": err})
|
| 163 |
+
|
| 164 |
+
info["clarify_answer"] = self._clarify_answers.get(q.lower(), "No additional information available for that question.")
|
| 165 |
+
return Reward(value=0.0, signals={"intermediate_step": 1.0}, done=False, info=info)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 166 |
|
| 167 |
+
def _revise(self, action: Action, info: dict) -> Reward:
|
|
|
|
| 168 |
if action.section is None or action.revision_text is None:
|
| 169 |
+
err = "revise_section requires both 'section' and 'revision_text'."
|
| 170 |
+
self._errors_so_far.append(err)
|
| 171 |
+
return Reward(value=0.0, signals={"error": 1.0}, done=False, info={"error": err})
|
| 172 |
+
|
| 173 |
+
prefix = f"{action.section}: "
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 174 |
if self._current_draft:
|
| 175 |
lines = self._current_draft.split("\n")
|
|
|
|
|
|
|
| 176 |
for i, line in enumerate(lines):
|
| 177 |
if line.startswith(prefix):
|
| 178 |
+
lines[i] = prefix + action.revision_text
|
| 179 |
+
self._current_draft = "\n".join(lines)
|
| 180 |
break
|
|
|
|
|
|
|
| 181 |
else:
|
| 182 |
self._current_draft += f"\n{prefix}{action.revision_text}"
|
| 183 |
else:
|
| 184 |
+
self._current_draft = prefix + action.revision_text
|
| 185 |
|
| 186 |
info["revised_section"] = action.section
|
| 187 |
+
return Reward(value=0.0, signals={"intermediate_step": 1.0}, done=False, info=info)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
environment/models.py
CHANGED
|
@@ -1,175 +1,47 @@
|
|
| 1 |
-
"""Pydantic v2 models for the Clinical Note Scribe environment.
|
| 2 |
-
|
| 3 |
-
Defines the typed contracts for observations, actions, rewards,
|
| 4 |
-
and overall environment state used by the OpenEnv spec.
|
| 5 |
-
"""
|
| 6 |
|
| 7 |
from __future__ import annotations
|
| 8 |
-
|
| 9 |
from typing import Any, Literal, Optional
|
| 10 |
-
|
| 11 |
from pydantic import BaseModel, Field
|
| 12 |
|
| 13 |
|
| 14 |
-
# ---------------------------------------------------------------------------
|
| 15 |
-
# Observation — what the agent sees after each step
|
| 16 |
-
# ---------------------------------------------------------------------------
|
| 17 |
-
|
| 18 |
class Observation(BaseModel):
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
-
transcript: str = Field(
|
| 22 |
-
...,
|
| 23 |
-
description="Full doctor–patient transcript for the current task.",
|
| 24 |
-
)
|
| 25 |
-
task_id: str = Field(
|
| 26 |
-
...,
|
| 27 |
-
description="Unique identifier for the task (e.g. 'easy_routine_checkup').",
|
| 28 |
-
)
|
| 29 |
-
patient_context: dict[str, Any] = Field(
|
| 30 |
-
default_factory=dict,
|
| 31 |
-
description="Structured patient demographics and history.",
|
| 32 |
-
)
|
| 33 |
-
current_draft: Optional[str] = Field(
|
| 34 |
-
default=None,
|
| 35 |
-
description="The agent's most recent SOAP-note draft, if any.",
|
| 36 |
-
)
|
| 37 |
-
errors_so_far: list[str] = Field(
|
| 38 |
-
default_factory=list,
|
| 39 |
-
description="Accumulated error/feedback messages from prior steps.",
|
| 40 |
-
)
|
| 41 |
-
step_count: int = Field(
|
| 42 |
-
default=0,
|
| 43 |
-
ge=0,
|
| 44 |
-
description="Number of steps taken in the current episode.",
|
| 45 |
-
)
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
# ---------------------------------------------------------------------------
|
| 49 |
-
# SOAPNote — structured clinical note
|
| 50 |
-
# ---------------------------------------------------------------------------
|
| 51 |
|
| 52 |
class SOAPNote(BaseModel):
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
description="Patient's self-reported symptoms and history.",
|
| 58 |
-
)
|
| 59 |
-
objective: str = Field(
|
| 60 |
-
...,
|
| 61 |
-
description="Clinician's measurable findings (vitals, exam, labs).",
|
| 62 |
-
)
|
| 63 |
-
assessment: str = Field(
|
| 64 |
-
...,
|
| 65 |
-
description="Clinician's diagnosis or differential.",
|
| 66 |
-
)
|
| 67 |
-
plan: str = Field(
|
| 68 |
-
...,
|
| 69 |
-
description="Treatment plan, follow-ups, and prescriptions.",
|
| 70 |
-
)
|
| 71 |
-
|
| 72 |
|
| 73 |
-
# ---------------------------------------------------------------------------
|
| 74 |
-
# Action — what the agent can do
|
| 75 |
-
# ---------------------------------------------------------------------------
|
| 76 |
|
| 77 |
class Action(BaseModel):
|
| 78 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
|
| 80 |
-
action_type: Literal["submit_note", "request_clarify", "revise_section"] = Field(
|
| 81 |
-
...,
|
| 82 |
-
description="The kind of action the agent is taking.",
|
| 83 |
-
)
|
| 84 |
-
|
| 85 |
-
# --- submit_note fields ---
|
| 86 |
-
soap_note: Optional[SOAPNote] = Field(
|
| 87 |
-
default=None,
|
| 88 |
-
description="Complete SOAP note (required when action_type == 'submit_note').",
|
| 89 |
-
)
|
| 90 |
-
|
| 91 |
-
# --- revise_section fields ---
|
| 92 |
-
section: Optional[Literal["S", "O", "A", "P"]] = Field(
|
| 93 |
-
default=None,
|
| 94 |
-
description="Which SOAP section to revise (required when action_type == 'revise_section').",
|
| 95 |
-
)
|
| 96 |
-
revision_text: Optional[str] = Field(
|
| 97 |
-
default=None,
|
| 98 |
-
description="Replacement text for the specified section.",
|
| 99 |
-
)
|
| 100 |
-
|
| 101 |
-
# --- request_clarify fields ---
|
| 102 |
-
clarify_question: Optional[str] = Field(
|
| 103 |
-
default=None,
|
| 104 |
-
description="Free-text question the agent asks for clarification.",
|
| 105 |
-
)
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
# ---------------------------------------------------------------------------
|
| 109 |
-
# Reward — multi-signal feedback
|
| 110 |
-
# ---------------------------------------------------------------------------
|
| 111 |
|
| 112 |
class Reward(BaseModel):
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
ge=0.0,
|
| 118 |
-
le=1.0,
|
| 119 |
-
description="Aggregate reward in the range [0.0, 1.0].",
|
| 120 |
-
)
|
| 121 |
-
signals: dict[str, float] = Field(
|
| 122 |
-
default_factory=dict,
|
| 123 |
-
description="Breakdown of individual reward sub-signals.",
|
| 124 |
-
)
|
| 125 |
-
done: bool = Field(
|
| 126 |
-
...,
|
| 127 |
-
description="Whether the episode has ended.",
|
| 128 |
-
)
|
| 129 |
-
info: dict[str, Any] = Field(
|
| 130 |
-
default_factory=dict,
|
| 131 |
-
description="Auxiliary metadata (e.g. grader diagnostics).",
|
| 132 |
-
)
|
| 133 |
|
| 134 |
|
| 135 |
-
# ---------------------------------------------------------------------------
|
| 136 |
-
# EnvironmentState — full internal state exposed by state()
|
| 137 |
-
# ---------------------------------------------------------------------------
|
| 138 |
-
|
| 139 |
class EnvironmentState(BaseModel):
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
)
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
ge=0,
|
| 149 |
-
description="Steps taken so far in this episode.",
|
| 150 |
-
)
|
| 151 |
-
max_steps: int = Field(
|
| 152 |
-
default=10,
|
| 153 |
-
ge=1,
|
| 154 |
-
description="Maximum steps allowed per episode.",
|
| 155 |
-
)
|
| 156 |
-
done: bool = Field(
|
| 157 |
-
default=False,
|
| 158 |
-
description="Whether the current episode has terminated.",
|
| 159 |
-
)
|
| 160 |
-
current_draft: Optional[str] = Field(
|
| 161 |
-
default=None,
|
| 162 |
-
description="Latest SOAP-note draft text, if any.",
|
| 163 |
-
)
|
| 164 |
-
errors_so_far: list[str] = Field(
|
| 165 |
-
default_factory=list,
|
| 166 |
-
description="Accumulated feedback/error messages.",
|
| 167 |
-
)
|
| 168 |
-
last_reward: Optional[Reward] = Field(
|
| 169 |
-
default=None,
|
| 170 |
-
description="Most recent reward object, if a step has been taken.",
|
| 171 |
-
)
|
| 172 |
-
observation: Optional[Observation] = Field(
|
| 173 |
-
default=None,
|
| 174 |
-
description="Most recent observation returned to the agent.",
|
| 175 |
-
)
|
|
|
|
| 1 |
+
"""Pydantic v2 models for the Clinical Note Scribe environment."""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
from __future__ import annotations
|
|
|
|
| 4 |
from typing import Any, Literal, Optional
|
|
|
|
| 5 |
from pydantic import BaseModel, Field
|
| 6 |
|
| 7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
class Observation(BaseModel):
|
| 9 |
+
transcript: str = Field(..., description="Full doctor–patient transcript.")
|
| 10 |
+
task_id: str = Field(..., description="Unique task identifier.")
|
| 11 |
+
patient_context: dict[str, Any] = Field(default_factory=dict, description="Patient demographics and history.")
|
| 12 |
+
current_draft: Optional[str] = Field(None, description="Most recent SOAP-note draft.")
|
| 13 |
+
errors_so_far: list[str] = Field(default_factory=list, description="Accumulated error messages.")
|
| 14 |
+
step_count: int = Field(0, ge=0, description="Steps taken in the current episode.")
|
| 15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
class SOAPNote(BaseModel):
|
| 18 |
+
subjective: str = Field(..., description="Patient's self-reported symptoms and history.")
|
| 19 |
+
objective: str = Field(..., description="Clinician's measurable findings.")
|
| 20 |
+
assessment: str = Field(..., description="Clinician's diagnosis or differential.")
|
| 21 |
+
plan: str = Field(..., description="Treatment plan, follow-ups, and prescriptions.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
class Action(BaseModel):
|
| 25 |
+
action_type: Literal["submit_note", "request_clarify", "revise_section"] = Field(..., description="Action kind.")
|
| 26 |
+
soap_note: Optional[SOAPNote] = Field(None, description="SOAP note (required for submit_note).")
|
| 27 |
+
section: Optional[Literal["S", "O", "A", "P"]] = Field(None, description="Section to revise.")
|
| 28 |
+
revision_text: Optional[str] = Field(None, description="Replacement text for the section.")
|
| 29 |
+
clarify_question: Optional[str] = Field(None, description="Clarification question.")
|
| 30 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
|
| 32 |
class Reward(BaseModel):
|
| 33 |
+
value: float = Field(..., ge=0.0, le=1.0, description="Aggregate reward [0, 1].")
|
| 34 |
+
signals: dict[str, float] = Field(default_factory=dict, description="Reward sub-signals.")
|
| 35 |
+
done: bool = Field(..., description="Whether the episode ended.")
|
| 36 |
+
info: dict[str, Any] = Field(default_factory=dict, description="Auxiliary metadata.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
|
| 38 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
class EnvironmentState(BaseModel):
|
| 40 |
+
task_id: str = Field(..., description="Active task identifier.")
|
| 41 |
+
step_count: int = Field(0, ge=0, description="Steps taken so far.")
|
| 42 |
+
max_steps: int = Field(10, ge=1, description="Max steps per episode.")
|
| 43 |
+
done: bool = Field(False, description="Whether the episode terminated.")
|
| 44 |
+
current_draft: Optional[str] = Field(None, description="Latest SOAP draft text.")
|
| 45 |
+
errors_so_far: list[str] = Field(default_factory=list, description="Error messages.")
|
| 46 |
+
last_reward: Optional[Reward] = Field(None, description="Most recent reward.")
|
| 47 |
+
observation: Optional[Observation] = Field(None, description="Most recent observation.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
environment/reward.py
CHANGED
|
@@ -1,18 +1,4 @@
|
|
| 1 |
-
"""Multi-signal reward computation for the Clinical Note Scribe environment.
|
| 2 |
-
|
| 3 |
-
Reward formula (all weights sum to 1.0 before penalties):
|
| 4 |
-
|
| 5 |
-
weighted_sum = grader_score × 0.60 (clinical accuracy from task grader)
|
| 6 |
-
+ conciseness_bonus × 0.10 (1.0 if note ≤ 400 words, else 0.0)
|
| 7 |
-
+ safe_language_score× 0.15 (1.0 if no unsafe-certainty phrases)
|
| 8 |
-
+ format_valid × 0.15 (1.0 if SOAP JSON is well-formed)
|
| 9 |
-
|
| 10 |
-
Deductions (applied after weighted sum):
|
| 11 |
-
- 0.05 × max(0, step_count - 3) (penalty for excessive clarification steps)
|
| 12 |
-
- 0.10 × len(errors_so_far) (penalty for each invalid action)
|
| 13 |
-
|
| 14 |
-
Final value is clamped to [0.0, 1.0].
|
| 15 |
-
"""
|
| 16 |
|
| 17 |
from __future__ import annotations
|
| 18 |
|
|
@@ -21,99 +7,38 @@
|
|
| 21 |
|
| 22 |
from environment.models import Action, Reward, SOAPNote
|
| 23 |
|
|
|
|
|
|
|
| 24 |
|
| 25 |
-
#
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
W_GRADER = 0.60
|
| 30 |
-
W_CONCISE = 0.10
|
| 31 |
-
W_SAFE_LANG = 0.15
|
| 32 |
-
W_FORMAT = 0.15
|
| 33 |
-
|
| 34 |
-
# Deduction constants
|
| 35 |
-
STEP_PENALTY_RATE = 0.05 # per step beyond FREE_STEPS
|
| 36 |
-
FREE_STEPS = 3
|
| 37 |
-
ERROR_PENALTY_RATE = 0.10 # per item in errors_so_far
|
| 38 |
-
|
| 39 |
-
# Conciseness threshold
|
| 40 |
WORD_LIMIT = 400
|
| 41 |
|
| 42 |
-
#
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
]
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
# ---------------------------------------------------------------------------
|
| 64 |
-
|
| 65 |
-
def _conciseness_bonus(soap_note: Optional[SOAPNote]) -> float:
|
| 66 |
-
"""Return 1.0 if the total SOAP note word count is at or below WORD_LIMIT."""
|
| 67 |
-
if soap_note is None:
|
| 68 |
-
return 0.0
|
| 69 |
-
text = " ".join([
|
| 70 |
-
soap_note.subjective,
|
| 71 |
-
soap_note.objective,
|
| 72 |
-
soap_note.assessment,
|
| 73 |
-
soap_note.plan,
|
| 74 |
-
])
|
| 75 |
-
word_count = len(text.split())
|
| 76 |
-
return 1.0 if word_count <= WORD_LIMIT else 0.0
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
def _safe_language_score(soap_note: Optional[SOAPNote]) -> float:
|
| 80 |
-
"""Return 1.0 if no unsafe-certainty phrases are found in the SOAP note."""
|
| 81 |
-
if soap_note is None:
|
| 82 |
-
return 1.0 # no note submitted → no unsafe language
|
| 83 |
-
text = " ".join([
|
| 84 |
-
soap_note.subjective,
|
| 85 |
-
soap_note.objective,
|
| 86 |
-
soap_note.assessment,
|
| 87 |
-
soap_note.plan,
|
| 88 |
-
])
|
| 89 |
-
for pattern in _UNSAFE_PATTERNS:
|
| 90 |
-
if pattern.search(text):
|
| 91 |
-
return 0.0
|
| 92 |
-
return 1.0
|
| 93 |
|
| 94 |
|
| 95 |
-
def _format_valid(action: Action) -> float:
|
| 96 |
-
"""Return 1.0 if the submitted note has all required non-empty SOAP fields.
|
| 97 |
-
|
| 98 |
-
This acts as a lightweight structural / «JSON well-formed» check:
|
| 99 |
-
each of S, O, A, P must be a non-empty string, and the action_type
|
| 100 |
-
must be ``submit_note``.
|
| 101 |
-
"""
|
| 102 |
-
if action.action_type != "submit_note":
|
| 103 |
-
return 1.0 # non-submission actions are not graded on format
|
| 104 |
-
if action.soap_note is None:
|
| 105 |
-
return 0.0
|
| 106 |
-
soap = action.soap_note
|
| 107 |
-
fields = [soap.subjective, soap.objective, soap.assessment, soap.plan]
|
| 108 |
-
if all(isinstance(f, str) and f.strip() for f in fields):
|
| 109 |
-
return 1.0
|
| 110 |
-
return 0.0
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
# ---------------------------------------------------------------------------
|
| 114 |
-
# Public API
|
| 115 |
-
# ---------------------------------------------------------------------------
|
| 116 |
-
|
| 117 |
def compute_reward(
|
| 118 |
action: Action,
|
| 119 |
grader_score: float,
|
|
@@ -123,75 +48,54 @@ def compute_reward(
|
|
| 123 |
done: bool = False,
|
| 124 |
info: Optional[dict[str, Any]] = None,
|
| 125 |
) -> Reward:
|
| 126 |
-
"""Compute the multi-signal reward for a completed step.
|
| 127 |
-
|
| 128 |
-
Parameters
|
| 129 |
-
----------
|
| 130 |
-
action:
|
| 131 |
-
The action that was just executed.
|
| 132 |
-
grader_score:
|
| 133 |
-
Clinical-accuracy score returned by the task-specific grader (0.0–1.0).
|
| 134 |
-
Use 0.0 for non-submission actions.
|
| 135 |
-
step_count:
|
| 136 |
-
Total number of steps taken so far in the episode (including this one).
|
| 137 |
-
errors_so_far:
|
| 138 |
-
List of error messages accumulated during the episode.
|
| 139 |
-
done:
|
| 140 |
-
Whether the episode ended with this step.
|
| 141 |
-
info:
|
| 142 |
-
Optional auxiliary metadata dict to include in the Reward.
|
| 143 |
-
|
| 144 |
-
Returns
|
| 145 |
-
-------
|
| 146 |
-
Reward
|
| 147 |
-
Fully populated Reward with ``value`` and ``signals`` breakdown.
|
| 148 |
-
"""
|
| 149 |
grader_score = max(0.0, min(1.0, grader_score))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
|
| 151 |
-
#
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
fmt = _format_valid(action)
|
| 155 |
-
|
| 156 |
-
# ---- weighted sum ----
|
| 157 |
-
weighted = (
|
| 158 |
grader_score * W_GRADER
|
| 159 |
-
+ conciseness
|
| 160 |
-
+ safe_lang
|
| 161 |
-
+ fmt
|
|
|
|
|
|
|
| 162 |
)
|
| 163 |
-
|
| 164 |
-
# ---- deductions ----
|
| 165 |
-
extra_steps = max(0, step_count - FREE_STEPS)
|
| 166 |
-
step_penalty = extra_steps * STEP_PENALTY_RATE
|
| 167 |
-
error_penalty = len(errors_so_far) * ERROR_PENALTY_RATE
|
| 168 |
-
|
| 169 |
-
raw = weighted - step_penalty - error_penalty
|
| 170 |
-
|
| 171 |
-
# ---- clamp ----
|
| 172 |
-
value = max(0.01, min(0.99, raw))
|
| 173 |
-
|
| 174 |
-
signals: dict[str, float] = {
|
| 175 |
-
# positive contributions
|
| 176 |
-
"grader_score": round(grader_score * W_GRADER, 4),
|
| 177 |
-
"conciseness_bonus": round(conciseness * W_CONCISE, 4),
|
| 178 |
-
"safe_language_score": round(safe_lang * W_SAFE_LANG, 4),
|
| 179 |
-
"format_valid": round(fmt * W_FORMAT, 4),
|
| 180 |
-
# deductions (stored as negative numbers for clarity)
|
| 181 |
-
"step_penalty": round(-step_penalty, 4),
|
| 182 |
-
"error_penalty": round(-error_penalty, 4),
|
| 183 |
-
# raw sub-signal values (unweighted, for introspection)
|
| 184 |
-
"_grader_score_raw": round(grader_score, 4),
|
| 185 |
-
"_conciseness_raw": round(conciseness, 4),
|
| 186 |
-
"_safe_language_raw": round(safe_lang, 4),
|
| 187 |
-
"_format_valid_raw": round(fmt, 4),
|
| 188 |
-
"_extra_steps": float(extra_steps),
|
| 189 |
-
"_error_count": float(len(errors_so_far)),
|
| 190 |
-
}
|
| 191 |
|
| 192 |
return Reward(
|
| 193 |
-
value=
|
| 194 |
-
signals=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 195 |
done=done,
|
| 196 |
info=info or {},
|
| 197 |
)
|
|
|
|
| 1 |
+
"""Multi-signal reward computation for the Clinical Note Scribe environment."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
from __future__ import annotations
|
| 4 |
|
|
|
|
| 7 |
|
| 8 |
from environment.models import Action, Reward, SOAPNote
|
| 9 |
|
| 10 |
+
# Reward weights (sum to 1.0)
|
| 11 |
+
W_GRADER, W_CONCISE, W_SAFE_LANG, W_FORMAT = 0.60, 0.10, 0.15, 0.15
|
| 12 |
|
| 13 |
+
# Deductions
|
| 14 |
+
STEP_PENALTY_RATE = 0.05
|
| 15 |
+
FREE_STEPS = 3
|
| 16 |
+
ERROR_PENALTY_RATE = 0.10
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
WORD_LIMIT = 400
|
| 18 |
|
| 19 |
+
# Pre-compiled unsafe clinical certainty patterns
|
| 20 |
+
_UNSAFE_RE = re.compile(
|
| 21 |
+
r"\bpatient definitely has\b"
|
| 22 |
+
r"|\bdiagnosis is certain\b"
|
| 23 |
+
r"|\bno doubt\b"
|
| 24 |
+
r"|\babsolutely confirmed\b"
|
| 25 |
+
r"|\b100%\s+certain\b"
|
| 26 |
+
r"|\bwill definitely\b"
|
| 27 |
+
r"|\bguaranteed to\b"
|
| 28 |
+
r"|\bcannot be\s+\w+\s+else\b"
|
| 29 |
+
r"|\bwithout question\b"
|
| 30 |
+
r"|\bthis is clearly\b",
|
| 31 |
+
re.IGNORECASE,
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _soap_text(soap: Optional[SOAPNote]) -> Optional[str]:
|
| 36 |
+
"""Join all SOAP fields into one string. Returns None if no note."""
|
| 37 |
+
if soap is None:
|
| 38 |
+
return None
|
| 39 |
+
return f"{soap.subjective} {soap.objective} {soap.assessment} {soap.plan}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
|
| 41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
def compute_reward(
|
| 43 |
action: Action,
|
| 44 |
grader_score: float,
|
|
|
|
| 48 |
done: bool = False,
|
| 49 |
info: Optional[dict[str, Any]] = None,
|
| 50 |
) -> Reward:
|
| 51 |
+
"""Compute the multi-signal reward for a completed step."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
grader_score = max(0.0, min(1.0, grader_score))
|
| 53 |
+
text = _soap_text(action.soap_note)
|
| 54 |
+
|
| 55 |
+
# Sub-signals
|
| 56 |
+
conciseness = 1.0 if text and len(text.split()) <= WORD_LIMIT else (0.0 if text else 0.0)
|
| 57 |
+
safe_lang = 0.0 if (text and _UNSAFE_RE.search(text)) else 1.0
|
| 58 |
+
fmt = (
|
| 59 |
+
1.0
|
| 60 |
+
if action.action_type != "submit_note"
|
| 61 |
+
else (
|
| 62 |
+
1.0
|
| 63 |
+
if action.soap_note and all(
|
| 64 |
+
getattr(action.soap_note, f).strip()
|
| 65 |
+
for f in ("subjective", "objective", "assessment", "plan")
|
| 66 |
+
)
|
| 67 |
+
else 0.0
|
| 68 |
+
)
|
| 69 |
+
)
|
| 70 |
|
| 71 |
+
# Weighted sum minus deductions, clamped to (0, 1)
|
| 72 |
+
extra_steps = max(0, step_count - FREE_STEPS)
|
| 73 |
+
raw = (
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
grader_score * W_GRADER
|
| 75 |
+
+ conciseness * W_CONCISE
|
| 76 |
+
+ safe_lang * W_SAFE_LANG
|
| 77 |
+
+ fmt * W_FORMAT
|
| 78 |
+
- extra_steps * STEP_PENALTY_RATE
|
| 79 |
+
- len(errors_so_far) * ERROR_PENALTY_RATE
|
| 80 |
)
|
| 81 |
+
value = round(max(0.01, min(0.99, raw)), 4)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
|
| 83 |
return Reward(
|
| 84 |
+
value=value,
|
| 85 |
+
signals={
|
| 86 |
+
"grader_score": round(grader_score * W_GRADER, 4),
|
| 87 |
+
"conciseness_bonus": round(conciseness * W_CONCISE, 4),
|
| 88 |
+
"safe_language_score": round(safe_lang * W_SAFE_LANG, 4),
|
| 89 |
+
"format_valid": round(fmt * W_FORMAT, 4),
|
| 90 |
+
"step_penalty": round(-extra_steps * STEP_PENALTY_RATE, 4),
|
| 91 |
+
"error_penalty": round(-len(errors_so_far) * ERROR_PENALTY_RATE, 4),
|
| 92 |
+
"_grader_score_raw": round(grader_score, 4),
|
| 93 |
+
"_conciseness_raw": round(conciseness, 4),
|
| 94 |
+
"_safe_language_raw": round(safe_lang, 4),
|
| 95 |
+
"_format_valid_raw": round(fmt, 4),
|
| 96 |
+
"_extra_steps": float(extra_steps),
|
| 97 |
+
"_error_count": float(len(errors_so_far)),
|
| 98 |
+
},
|
| 99 |
done=done,
|
| 100 |
info=info or {},
|
| 101 |
)
|
environment/tasks/task_easy.py
CHANGED
|
@@ -1,85 +1,37 @@
|
|
| 1 |
-
"""Easy task — routine check-up.
|
| 2 |
-
|
| 3 |
-
Grader uses keyword-based clinical rubric scoring to evaluate the SOAP note
|
| 4 |
-
against expected findings from a simple cold / blood pressure check visit.
|
| 5 |
-
"""
|
| 6 |
|
| 7 |
from __future__ import annotations
|
| 8 |
-
|
| 9 |
from typing import Any
|
| 10 |
-
|
| 11 |
from environment.models import SOAPNote
|
| 12 |
|
| 13 |
-
|
| 14 |
-
# ---------------------------------------------------------------------------
|
| 15 |
-
# Task definition
|
| 16 |
-
# ---------------------------------------------------------------------------
|
| 17 |
-
|
| 18 |
EASY_TASK: dict[str, Any] = {
|
| 19 |
"task_id": "easy_routine_checkup",
|
| 20 |
"description": "Generate a SOAP note for a routine annual check-up visit.",
|
| 21 |
"transcript_file": "data/transcripts/easy.txt",
|
| 22 |
"patient_context": {
|
| 23 |
-
"patient_id": "P-1001",
|
| 24 |
-
"
|
| 25 |
-
"age": 34,
|
| 26 |
-
"sex": "F",
|
| 27 |
-
"known_conditions": [],
|
| 28 |
-
"current_medications": [],
|
| 29 |
-
"allergies": ["Penicillin"],
|
| 30 |
},
|
| 31 |
"max_steps": 5,
|
| 32 |
}
|
| 33 |
|
| 34 |
|
| 35 |
-
# ---------------------------------------------------------------------------
|
| 36 |
-
# Grader
|
| 37 |
-
# ---------------------------------------------------------------------------
|
| 38 |
-
|
| 39 |
def grade_easy(soap_note: SOAPNote, task: dict[str, Any]) -> dict[str, float]:
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
"""
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
text_p = soap_note.plan.lower()
|
| 53 |
-
|
| 54 |
-
# 1. Subjective — chief complaints
|
| 55 |
-
s_score = 0.0
|
| 56 |
-
if "sore throat" in text_s or "runny nose" in text_s or "congestion" in text_s:
|
| 57 |
-
s_score += 0.5
|
| 58 |
-
if "5 days" in text_s or "five days" in text_s or "headache" in text_s:
|
| 59 |
-
s_score += 0.5
|
| 60 |
-
|
| 61 |
-
# 2. Objective — vitals
|
| 62 |
-
o_score = 0.0
|
| 63 |
-
if "118/76" in text_o or "118 over 76" in text_o or "blood pressure" in text_o:
|
| 64 |
-
o_score += 0.5
|
| 65 |
-
if "72" in text_o or "heart rate" in text_o or "lungs clear" in text_o:
|
| 66 |
-
o_score += 0.5
|
| 67 |
-
|
| 68 |
-
# 3. Assessment — viral URI
|
| 69 |
-
a_score = 0.0
|
| 70 |
-
if "viral" in text_a or "uri" in text_a or "upper respiratory" in text_a:
|
| 71 |
-
a_score += 1.0
|
| 72 |
-
|
| 73 |
-
# 4. Plan — supportive care
|
| 74 |
-
p_score = 0.0
|
| 75 |
-
if "fluids" in text_p or "rest" in text_p or "hydrat" in text_p:
|
| 76 |
-
p_score += 0.5
|
| 77 |
-
if "dayquil" in text_p or "follow" in text_p or "return" in text_p:
|
| 78 |
-
p_score += 0.5
|
| 79 |
-
|
| 80 |
return {
|
| 81 |
-
"subjective_accuracy":
|
| 82 |
-
"objective_accuracy":
|
| 83 |
-
"assessment_accuracy":
|
| 84 |
-
"plan_accuracy":
|
| 85 |
}
|
|
|
|
| 1 |
+
"""Easy task — routine check-up."""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
from __future__ import annotations
|
|
|
|
| 4 |
from typing import Any
|
|
|
|
| 5 |
from environment.models import SOAPNote
|
| 6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
EASY_TASK: dict[str, Any] = {
|
| 8 |
"task_id": "easy_routine_checkup",
|
| 9 |
"description": "Generate a SOAP note for a routine annual check-up visit.",
|
| 10 |
"transcript_file": "data/transcripts/easy.txt",
|
| 11 |
"patient_context": {
|
| 12 |
+
"patient_id": "P-1001", "name": "Jane Doe", "age": 34, "sex": "F",
|
| 13 |
+
"known_conditions": [], "current_medications": [], "allergies": ["Penicillin"],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
},
|
| 15 |
"max_steps": 5,
|
| 16 |
}
|
| 17 |
|
| 18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
def grade_easy(soap_note: SOAPNote, task: dict[str, Any]) -> dict[str, float]:
|
| 20 |
+
s, o, a, p = (soap_note.subjective.lower(), soap_note.objective.lower(),
|
| 21 |
+
soap_note.assessment.lower(), soap_note.plan.lower())
|
| 22 |
+
|
| 23 |
+
s_score = 0.5 * any(k in s for k in ("sore throat", "runny nose", "congestion")) \
|
| 24 |
+
+ 0.5 * any(k in s for k in ("5 days", "five days", "headache"))
|
| 25 |
+
o_score = 0.5 * any(k in o for k in ("118/76", "118 over 76", "blood pressure")) \
|
| 26 |
+
+ 0.5 * any(k in o for k in ("72", "heart rate", "lungs clear"))
|
| 27 |
+
a_score = 1.0 * any(k in a for k in ("viral", "uri", "upper respiratory"))
|
| 28 |
+
p_score = 0.5 * any(k in p for k in ("fluids", "rest", "hydrat")) \
|
| 29 |
+
+ 0.5 * any(k in p for k in ("dayquil", "follow", "return"))
|
| 30 |
+
|
| 31 |
+
clamp = lambda v: max(0.01, min(v, 0.99))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
return {
|
| 33 |
+
"subjective_accuracy": clamp(s_score),
|
| 34 |
+
"objective_accuracy": clamp(o_score),
|
| 35 |
+
"assessment_accuracy": clamp(a_score),
|
| 36 |
+
"plan_accuracy": clamp(p_score),
|
| 37 |
}
|
environment/tasks/task_hard.py
CHANGED
|
@@ -1,118 +1,42 @@
|
|
| 1 |
-
"""Hard task — complex ER visit.
|
| 2 |
-
|
| 3 |
-
Grader uses keyword-based clinical rubric scoring to evaluate the SOAP note
|
| 4 |
-
against expected findings from a complex ER visit with overlapping chest pain,
|
| 5 |
-
SOB, and a possible PE complicated by a contrast dye allergy.
|
| 6 |
-
"""
|
| 7 |
|
| 8 |
from __future__ import annotations
|
| 9 |
-
|
| 10 |
from typing import Any
|
| 11 |
-
|
| 12 |
from environment.models import SOAPNote
|
| 13 |
|
| 14 |
-
|
| 15 |
-
# ---------------------------------------------------------------------------
|
| 16 |
-
# Task definition
|
| 17 |
-
# ---------------------------------------------------------------------------
|
| 18 |
-
|
| 19 |
HARD_TASK: dict[str, Any] = {
|
| 20 |
"task_id": "hard_complex_er_visit",
|
| 21 |
-
"description":
|
| 22 |
-
"Generate a SOAP note for a complex emergency-room visit involving "
|
| 23 |
-
"chest pain, polytrauma assessment, and multiple co-morbidities."
|
| 24 |
-
),
|
| 25 |
"transcript_file": "data/transcripts/hard.txt",
|
| 26 |
"patient_context": {
|
| 27 |
-
"patient_id": "P-3782",
|
| 28 |
-
"
|
| 29 |
-
"
|
| 30 |
-
"sex": "F",
|
| 31 |
-
"known_conditions": [
|
| 32 |
-
"Coronary Artery Disease",
|
| 33 |
-
"Atrial Fibrillation",
|
| 34 |
-
"Chronic Kidney Disease Stage 3",
|
| 35 |
-
"Osteoarthritis",
|
| 36 |
-
],
|
| 37 |
-
"current_medications": [
|
| 38 |
-
"Aspirin 81 mg daily",
|
| 39 |
-
"Warfarin 5 mg daily",
|
| 40 |
-
"Metoprolol 50 mg BID",
|
| 41 |
-
"Furosemide 40 mg daily",
|
| 42 |
-
"Amlodipine 5 mg daily",
|
| 43 |
-
],
|
| 44 |
"allergies": ["Sulfa drugs", "Contrast dye"],
|
| 45 |
-
"recent_labs": {
|
| 46 |
-
|
| 47 |
-
"BNP": "450 pg/mL",
|
| 48 |
-
"creatinine": "1.9 mg/dL",
|
| 49 |
-
"eGFR": "34 mL/min",
|
| 50 |
-
"INR": "2.6",
|
| 51 |
-
"hemoglobin": "10.2 g/dL",
|
| 52 |
-
},
|
| 53 |
-
"vitals_on_arrival": {
|
| 54 |
-
"BP": "168/94 mmHg",
|
| 55 |
-
"HR": "112 bpm (irregular)",
|
| 56 |
-
"RR": "22 breaths/min",
|
| 57 |
-
"SpO2": "91% on room air",
|
| 58 |
-
"Temp": "37.2°C",
|
| 59 |
-
},
|
| 60 |
},
|
| 61 |
"max_steps": 10,
|
| 62 |
}
|
| 63 |
|
| 64 |
|
| 65 |
-
# ---------------------------------------------------------------------------
|
| 66 |
-
# Grader
|
| 67 |
-
# ---------------------------------------------------------------------------
|
| 68 |
-
|
| 69 |
def grade_hard(soap_note: SOAPNote, task: dict[str, Any]) -> dict[str, float]:
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
"""
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
text_p = soap_note.plan.lower()
|
| 84 |
-
|
| 85 |
-
# 1. Subjective — catching the contradiction and presenting complaints
|
| 86 |
-
s_score = 0.0
|
| 87 |
-
if "chest pain" in text_s or "shortness of breath" in text_s or "sob" in text_s:
|
| 88 |
-
s_score += 0.5
|
| 89 |
-
if "nitroglycerin" in text_s or "contradict" in text_s or "denied" in text_s:
|
| 90 |
-
s_score += 0.5
|
| 91 |
-
|
| 92 |
-
# 2. Objective — elevated D-dimer and allergy awareness
|
| 93 |
-
o_score = 0.0
|
| 94 |
-
if "d-dimer" in text_o or "1840" in text_o or "d dimer" in text_o:
|
| 95 |
-
o_score += 0.5
|
| 96 |
-
if "allergy" in text_o or "contrast" in text_o or "troponin" in text_o:
|
| 97 |
-
o_score += 0.5
|
| 98 |
-
|
| 99 |
-
# 3. Assessment — the dual differential (ACS vs PE)
|
| 100 |
-
a_score = 0.0
|
| 101 |
-
if "acs" in text_a or "acute coronary" in text_a or "coronary" in text_a or "ischemia" in text_a:
|
| 102 |
-
a_score += 0.5
|
| 103 |
-
if "pe" in text_a or "pulmonary embolism" in text_a or "embolism" in text_a:
|
| 104 |
-
a_score += 0.5
|
| 105 |
-
|
| 106 |
-
# 4. Plan — adapting to the allergy (V/Q scan) and admission
|
| 107 |
-
p_score = 0.0
|
| 108 |
-
if "v/q" in text_p or "ventilation" in text_p or "perfusion" in text_p:
|
| 109 |
-
p_score += 0.5
|
| 110 |
-
if "icu" in text_p or "admit" in text_p or "cardiac" in text_p:
|
| 111 |
-
p_score += 0.5
|
| 112 |
-
|
| 113 |
return {
|
| 114 |
-
"subjective_accuracy":
|
| 115 |
-
"objective_accuracy":
|
| 116 |
-
"assessment_accuracy":
|
| 117 |
-
"plan_accuracy":
|
| 118 |
}
|
|
|
|
| 1 |
+
"""Hard task — complex ER visit."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
from __future__ import annotations
|
|
|
|
| 4 |
from typing import Any
|
|
|
|
| 5 |
from environment.models import SOAPNote
|
| 6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
HARD_TASK: dict[str, Any] = {
|
| 8 |
"task_id": "hard_complex_er_visit",
|
| 9 |
+
"description": "Generate a SOAP note for a complex ER visit with chest pain, PE differential, and contrast allergy.",
|
|
|
|
|
|
|
|
|
|
| 10 |
"transcript_file": "data/transcripts/hard.txt",
|
| 11 |
"patient_context": {
|
| 12 |
+
"patient_id": "P-3782", "name": "Maria Garcia", "age": 72, "sex": "F",
|
| 13 |
+
"known_conditions": ["Coronary Artery Disease", "Atrial Fibrillation", "Chronic Kidney Disease Stage 3", "Osteoarthritis"],
|
| 14 |
+
"current_medications": ["Aspirin 81 mg daily", "Warfarin 5 mg daily", "Metoprolol 50 mg BID", "Furosemide 40 mg daily", "Amlodipine 5 mg daily"],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
"allergies": ["Sulfa drugs", "Contrast dye"],
|
| 16 |
+
"recent_labs": {"troponin_I": "0.08 ng/mL", "BNP": "450 pg/mL", "creatinine": "1.9 mg/dL", "eGFR": "34 mL/min", "INR": "2.6", "hemoglobin": "10.2 g/dL"},
|
| 17 |
+
"vitals_on_arrival": {"BP": "168/94 mmHg", "HR": "112 bpm (irregular)", "RR": "22 breaths/min", "SpO2": "91% on room air", "Temp": "37.2°C"},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
},
|
| 19 |
"max_steps": 10,
|
| 20 |
}
|
| 21 |
|
| 22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
def grade_hard(soap_note: SOAPNote, task: dict[str, Any]) -> dict[str, float]:
|
| 24 |
+
s, o, a, p = (soap_note.subjective.lower(), soap_note.objective.lower(),
|
| 25 |
+
soap_note.assessment.lower(), soap_note.plan.lower())
|
| 26 |
+
|
| 27 |
+
s_score = 0.5 * any(k in s for k in ("chest pain", "shortness of breath", "sob")) \
|
| 28 |
+
+ 0.5 * any(k in s for k in ("nitroglycerin", "contradict", "denied"))
|
| 29 |
+
o_score = 0.5 * any(k in o for k in ("d-dimer", "1840", "d dimer")) \
|
| 30 |
+
+ 0.5 * any(k in o for k in ("allergy", "contrast", "troponin"))
|
| 31 |
+
a_score = 0.5 * any(k in a for k in ("acs", "acute coronary", "coronary", "ischemia")) \
|
| 32 |
+
+ 0.5 * any(k in a for k in ("pe", "pulmonary embolism", "embolism"))
|
| 33 |
+
p_score = 0.5 * any(k in p for k in ("v/q", "ventilation", "perfusion")) \
|
| 34 |
+
+ 0.5 * any(k in p for k in ("icu", "admit", "cardiac"))
|
| 35 |
+
|
| 36 |
+
clamp = lambda v: max(0.01, min(v, 0.99))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
return {
|
| 38 |
+
"subjective_accuracy": clamp(s_score),
|
| 39 |
+
"objective_accuracy": clamp(o_score),
|
| 40 |
+
"assessment_accuracy": clamp(a_score),
|
| 41 |
+
"plan_accuracy": clamp(p_score),
|
| 42 |
}
|
environment/tasks/task_medium.py
CHANGED
|
@@ -1,98 +1,41 @@
|
|
| 1 |
-
"""Medium task — chronic disease follow-up.
|
| 2 |
-
|
| 3 |
-
Grader uses keyword-based clinical rubric scoring to evaluate the SOAP note
|
| 4 |
-
against expected findings from a Type 2 Diabetes / Hypertension follow-up.
|
| 5 |
-
"""
|
| 6 |
|
| 7 |
from __future__ import annotations
|
| 8 |
-
|
| 9 |
from typing import Any
|
| 10 |
-
|
| 11 |
from environment.models import SOAPNote
|
| 12 |
|
| 13 |
-
|
| 14 |
-
# ---------------------------------------------------------------------------
|
| 15 |
-
# Task definition
|
| 16 |
-
# ---------------------------------------------------------------------------
|
| 17 |
-
|
| 18 |
MEDIUM_TASK: dict[str, Any] = {
|
| 19 |
"task_id": "medium_chronic_disease_followup",
|
| 20 |
"description": "Generate a SOAP note for a Type 2 Diabetes follow-up visit.",
|
| 21 |
"transcript_file": "data/transcripts/medium.txt",
|
| 22 |
"patient_context": {
|
| 23 |
-
"patient_id": "P-2045",
|
| 24 |
-
"name": "Robert Smith",
|
| 25 |
-
"age": 58,
|
| 26 |
-
"sex": "M",
|
| 27 |
"known_conditions": ["Type 2 Diabetes Mellitus", "Hypertension"],
|
| 28 |
-
"current_medications": [
|
| 29 |
-
"Metformin 1000 mg BID",
|
| 30 |
-
"Lisinopril 20 mg daily",
|
| 31 |
-
"Atorvastatin 40 mg daily",
|
| 32 |
-
],
|
| 33 |
"allergies": [],
|
| 34 |
-
"recent_labs": {
|
| 35 |
-
"HbA1c": "7.8%",
|
| 36 |
-
"fasting_glucose": "156 mg/dL",
|
| 37 |
-
"creatinine": "1.1 mg/dL",
|
| 38 |
-
"eGFR": "78 mL/min",
|
| 39 |
-
"LDL": "102 mg/dL",
|
| 40 |
-
},
|
| 41 |
},
|
| 42 |
"max_steps": 8,
|
| 43 |
}
|
| 44 |
|
| 45 |
|
| 46 |
-
# ---------------------------------------------------------------------------
|
| 47 |
-
# Grader
|
| 48 |
-
# ---------------------------------------------------------------------------
|
| 49 |
-
|
| 50 |
def grade_medium(soap_note: SOAPNote, task: dict[str, Any]) -> dict[str, float]:
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
# 1. Subjective — dietary habits / statin gap
|
| 66 |
-
s_score = 0.0
|
| 67 |
-
if "restaurant" in text_s or "diet" in text_s or "eating" in text_s:
|
| 68 |
-
s_score += 0.5
|
| 69 |
-
if "statin" in text_s or "gap" in text_s or "missed" in text_s:
|
| 70 |
-
s_score += 0.5
|
| 71 |
-
|
| 72 |
-
# 2. Objective — HbA1c values
|
| 73 |
-
o_score = 0.0
|
| 74 |
-
if "7.8" in text_o or "7.2" in text_o or "a1c" in text_o or "hba1c" in text_o:
|
| 75 |
-
o_score += 0.5
|
| 76 |
-
if "156" in text_o or "fasting glucose" in text_o or "glucose" in text_o:
|
| 77 |
-
o_score += 0.5
|
| 78 |
-
|
| 79 |
-
# 3. Assessment — core diagnoses
|
| 80 |
-
a_score = 0.0
|
| 81 |
-
if "diabetes" in text_a or "t2dm" in text_a or "dm" in text_a:
|
| 82 |
-
a_score += 0.5
|
| 83 |
-
if "hypertension" in text_a or "htn" in text_a or "blood pressure" in text_a:
|
| 84 |
-
a_score += 0.5
|
| 85 |
-
|
| 86 |
-
# 4. Plan — medication changes
|
| 87 |
-
p_score = 0.0
|
| 88 |
-
if "glipizide" in text_p and ("5" in text_p or "add" in text_p):
|
| 89 |
-
p_score += 0.5
|
| 90 |
-
if "lisinopril" in text_p and ("40" in text_p or "increase" in text_p or "uptitrat" in text_p):
|
| 91 |
-
p_score += 0.5
|
| 92 |
-
|
| 93 |
return {
|
| 94 |
-
"subjective_accuracy":
|
| 95 |
-
"objective_accuracy":
|
| 96 |
-
"assessment_accuracy":
|
| 97 |
-
"plan_accuracy":
|
| 98 |
}
|
|
|
|
| 1 |
+
"""Medium task — chronic disease follow-up."""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
from __future__ import annotations
|
|
|
|
| 4 |
from typing import Any
|
|
|
|
| 5 |
from environment.models import SOAPNote
|
| 6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
MEDIUM_TASK: dict[str, Any] = {
|
| 8 |
"task_id": "medium_chronic_disease_followup",
|
| 9 |
"description": "Generate a SOAP note for a Type 2 Diabetes follow-up visit.",
|
| 10 |
"transcript_file": "data/transcripts/medium.txt",
|
| 11 |
"patient_context": {
|
| 12 |
+
"patient_id": "P-2045", "name": "Robert Smith", "age": 58, "sex": "M",
|
|
|
|
|
|
|
|
|
|
| 13 |
"known_conditions": ["Type 2 Diabetes Mellitus", "Hypertension"],
|
| 14 |
+
"current_medications": ["Metformin 1000 mg BID", "Lisinopril 20 mg daily", "Atorvastatin 40 mg daily"],
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
"allergies": [],
|
| 16 |
+
"recent_labs": {"HbA1c": "7.8%", "fasting_glucose": "156 mg/dL", "creatinine": "1.1 mg/dL", "eGFR": "78 mL/min", "LDL": "102 mg/dL"},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
},
|
| 18 |
"max_steps": 8,
|
| 19 |
}
|
| 20 |
|
| 21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
def grade_medium(soap_note: SOAPNote, task: dict[str, Any]) -> dict[str, float]:
|
| 23 |
+
s, o, a, p = (soap_note.subjective.lower(), soap_note.objective.lower(),
|
| 24 |
+
soap_note.assessment.lower(), soap_note.plan.lower())
|
| 25 |
+
|
| 26 |
+
s_score = 0.5 * any(k in s for k in ("restaurant", "diet", "eating")) \
|
| 27 |
+
+ 0.5 * any(k in s for k in ("statin", "gap", "missed"))
|
| 28 |
+
o_score = 0.5 * any(k in o for k in ("7.8", "7.2", "a1c", "hba1c")) \
|
| 29 |
+
+ 0.5 * any(k in o for k in ("156", "fasting glucose", "glucose"))
|
| 30 |
+
a_score = 0.5 * any(k in a for k in ("diabetes", "t2dm", "dm")) \
|
| 31 |
+
+ 0.5 * any(k in a for k in ("hypertension", "htn", "blood pressure"))
|
| 32 |
+
p_score = 0.5 * ("glipizide" in p and any(k in p for k in ("5", "add"))) \
|
| 33 |
+
+ 0.5 * ("lisinopril" in p and any(k in p for k in ("40", "increase", "uptitrat")))
|
| 34 |
+
|
| 35 |
+
clamp = lambda v: max(0.01, min(v, 0.99))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
return {
|
| 37 |
+
"subjective_accuracy": clamp(s_score),
|
| 38 |
+
"objective_accuracy": clamp(o_score),
|
| 39 |
+
"assessment_accuracy": clamp(a_score),
|
| 40 |
+
"plan_accuracy": clamp(p_score),
|
| 41 |
}
|
frontend/app.js
ADDED
|
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Clinical Note Scribe — Frontend Logic
|
| 2 |
+
|
| 3 |
+
const API = "";
|
| 4 |
+
|
| 5 |
+
// DOM refs
|
| 6 |
+
const taskSelect = document.getElementById("taskSelect");
|
| 7 |
+
const resetBtn = document.getElementById("resetBtn");
|
| 8 |
+
const stepBtn = document.getElementById("stepBtn");
|
| 9 |
+
const statusBadge = document.getElementById("statusBadge");
|
| 10 |
+
const contextSection = document.getElementById("contextSection");
|
| 11 |
+
const contextGrid = document.getElementById("contextGrid");
|
| 12 |
+
const transcriptArea = document.getElementById("transcriptArea");
|
| 13 |
+
const actionSection = document.getElementById("actionSection");
|
| 14 |
+
const actionType = document.getElementById("actionType");
|
| 15 |
+
const sectionSelect = document.getElementById("sectionSelect");
|
| 16 |
+
const soapInputs = document.getElementById("soapInputs");
|
| 17 |
+
const reviseInput = document.getElementById("reviseInput");
|
| 18 |
+
const clarifyInput = document.getElementById("clarifyInput");
|
| 19 |
+
const rewardSection = document.getElementById("rewardSection");
|
| 20 |
+
const scoreValue = document.getElementById("scoreValue");
|
| 21 |
+
const rewardFill = document.getElementById("rewardFill");
|
| 22 |
+
const draftArea = document.getElementById("draftArea");
|
| 23 |
+
const draftEmpty = document.getElementById("draftEmpty");
|
| 24 |
+
const soapDraft = document.getElementById("soapDraft");
|
| 25 |
+
const soapGrid = document.getElementById("soapGrid");
|
| 26 |
+
const logContainer = document.getElementById("logContainer");
|
| 27 |
+
|
| 28 |
+
let currentObs = null;
|
| 29 |
+
let isDone = false;
|
| 30 |
+
|
| 31 |
+
// Logging
|
| 32 |
+
function addLog(msg, type = "") {
|
| 33 |
+
const time = new Date().toLocaleTimeString("en-US", { hour12: false });
|
| 34 |
+
const entry = document.createElement("div");
|
| 35 |
+
entry.className = "log-entry " + type;
|
| 36 |
+
entry.innerHTML = `<span class="log-time">${time}</span>${msg}`;
|
| 37 |
+
logContainer.prepend(entry);
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
// Status badge
|
| 41 |
+
function setStatus(state) {
|
| 42 |
+
statusBadge.className = "status-badge " + state;
|
| 43 |
+
statusBadge.textContent = state === "idle" ? "Idle" : state === "active" ? "Active" : "Done";
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
// Toggle action inputs based on action type
|
| 47 |
+
actionType.addEventListener("change", () => {
|
| 48 |
+
const val = actionType.value;
|
| 49 |
+
soapInputs.style.display = val === "submit_note" ? "block" : "none";
|
| 50 |
+
reviseInput.style.display = val === "revise_section" ? "block" : "none";
|
| 51 |
+
clarifyInput.style.display = val === "request_clarify" ? "block" : "none";
|
| 52 |
+
sectionSelect.style.display = val === "revise_section" ? "inline-block" : "none";
|
| 53 |
+
});
|
| 54 |
+
|
| 55 |
+
// Format transcript with speaker highlighting
|
| 56 |
+
function renderTranscript(text) {
|
| 57 |
+
if (!text) return "";
|
| 58 |
+
const lines = text.split("\n");
|
| 59 |
+
return lines.map(line => {
|
| 60 |
+
if (/^(Dr\.|Doctor)/i.test(line.trim())) {
|
| 61 |
+
return `<div><span class="speaker-doctor">${escapeHtml(line)}</span></div>`;
|
| 62 |
+
} else if (/^(Patient|Pt)/i.test(line.trim())) {
|
| 63 |
+
return `<div><span class="speaker-patient">${escapeHtml(line)}</span></div>`;
|
| 64 |
+
}
|
| 65 |
+
return `<div>${escapeHtml(line)}</div>`;
|
| 66 |
+
}).join("");
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
function escapeHtml(str) {
|
| 70 |
+
const div = document.createElement("div");
|
| 71 |
+
div.textContent = str;
|
| 72 |
+
return div.innerHTML;
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
// Render patient context as cards
|
| 76 |
+
function renderContext(ctx) {
|
| 77 |
+
if (!ctx || Object.keys(ctx).length === 0) {
|
| 78 |
+
contextSection.style.display = "none";
|
| 79 |
+
return;
|
| 80 |
+
}
|
| 81 |
+
contextSection.style.display = "block";
|
| 82 |
+
contextGrid.innerHTML = "";
|
| 83 |
+
|
| 84 |
+
const flat = flattenContext(ctx);
|
| 85 |
+
for (const [key, val] of Object.entries(flat)) {
|
| 86 |
+
const card = document.createElement("div");
|
| 87 |
+
card.className = "context-card";
|
| 88 |
+
card.innerHTML = `<div class="label">${escapeHtml(key)}</div><div class="value">${escapeHtml(String(val))}</div>`;
|
| 89 |
+
contextGrid.appendChild(card);
|
| 90 |
+
}
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
function flattenContext(obj, prefix = "") {
|
| 94 |
+
const result = {};
|
| 95 |
+
for (const [k, v] of Object.entries(obj)) {
|
| 96 |
+
const key = prefix ? `${prefix} › ${k}` : k;
|
| 97 |
+
if (v && typeof v === "object" && !Array.isArray(v)) {
|
| 98 |
+
Object.assign(result, flattenContext(v, key));
|
| 99 |
+
} else if (Array.isArray(v)) {
|
| 100 |
+
result[key] = v.length > 0 ? v.join(", ") : "—";
|
| 101 |
+
} else {
|
| 102 |
+
result[key] = v ?? "—";
|
| 103 |
+
}
|
| 104 |
+
}
|
| 105 |
+
return result;
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
// Render SOAP draft
|
| 109 |
+
function renderDraft(draftText) {
|
| 110 |
+
if (!draftText) {
|
| 111 |
+
draftEmpty.style.display = "flex";
|
| 112 |
+
soapDraft.style.display = "none";
|
| 113 |
+
return;
|
| 114 |
+
}
|
| 115 |
+
draftEmpty.style.display = "none";
|
| 116 |
+
soapDraft.style.display = "block";
|
| 117 |
+
|
| 118 |
+
const sections = { S: "", O: "", A: "", P: "" };
|
| 119 |
+
const lines = draftText.split("\n");
|
| 120 |
+
for (const line of lines) {
|
| 121 |
+
for (const prefix of ["S: ", "O: ", "A: ", "P: "]) {
|
| 122 |
+
if (line.startsWith(prefix)) {
|
| 123 |
+
sections[prefix[0]] = line.slice(prefix.length);
|
| 124 |
+
}
|
| 125 |
+
}
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
const labels = { S: "Subjective", O: "Objective", A: "Assessment", P: "Plan" };
|
| 129 |
+
soapGrid.innerHTML = "";
|
| 130 |
+
for (const [key, label] of Object.entries(labels)) {
|
| 131 |
+
const card = document.createElement("div");
|
| 132 |
+
card.className = `soap-card ${key.toLowerCase()}`;
|
| 133 |
+
card.innerHTML = `
|
| 134 |
+
<div class="soap-label">${label}</div>
|
| 135 |
+
<div class="soap-text">${escapeHtml(sections[key]) || '<em style="opacity:0.4">Empty</em>'}</div>
|
| 136 |
+
`;
|
| 137 |
+
soapGrid.appendChild(card);
|
| 138 |
+
}
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
// Render reward
|
| 142 |
+
function renderReward(rewardObj) {
|
| 143 |
+
if (!rewardObj) {
|
| 144 |
+
rewardSection.style.display = "none";
|
| 145 |
+
return;
|
| 146 |
+
}
|
| 147 |
+
rewardSection.style.display = "block";
|
| 148 |
+
const val = rewardObj.value;
|
| 149 |
+
scoreValue.textContent = val.toFixed(4);
|
| 150 |
+
rewardFill.style.width = (val * 100) + "%";
|
| 151 |
+
|
| 152 |
+
if (val >= 0.7) {
|
| 153 |
+
scoreValue.style.color = "var(--green)";
|
| 154 |
+
} else if (val >= 0.4) {
|
| 155 |
+
scoreValue.style.color = "var(--yellow)";
|
| 156 |
+
} else {
|
| 157 |
+
scoreValue.style.color = "var(--red)";
|
| 158 |
+
}
|
| 159 |
+
}
|
| 160 |
+
|
| 161 |
+
// Update UI from observation
|
| 162 |
+
function updateUI(obs, reward = null, done = false) {
|
| 163 |
+
currentObs = obs;
|
| 164 |
+
isDone = done;
|
| 165 |
+
|
| 166 |
+
transcriptArea.innerHTML = `<div class="transcript-box">${renderTranscript(obs.transcript)}</div>`;
|
| 167 |
+
renderContext(obs.patient_context);
|
| 168 |
+
renderDraft(obs.current_draft);
|
| 169 |
+
|
| 170 |
+
if (reward) renderReward(reward);
|
| 171 |
+
|
| 172 |
+
actionSection.style.display = done ? "none" : "block";
|
| 173 |
+
setStatus(done ? "done" : "active");
|
| 174 |
+
|
| 175 |
+
if (done) {
|
| 176 |
+
addLog(`Episode complete — score: ${reward ? reward.value.toFixed(4) : "N/A"}`, "success");
|
| 177 |
+
}
|
| 178 |
+
}
|
| 179 |
+
|
| 180 |
+
// Reset
|
| 181 |
+
resetBtn.addEventListener("click", async () => {
|
| 182 |
+
const taskId = taskSelect.value;
|
| 183 |
+
resetBtn.disabled = true;
|
| 184 |
+
addLog(`Resetting with task: ${taskId}`);
|
| 185 |
+
|
| 186 |
+
try {
|
| 187 |
+
const res = await fetch(`${API}/reset`, {
|
| 188 |
+
method: "POST",
|
| 189 |
+
headers: { "Content-Type": "application/json" },
|
| 190 |
+
body: JSON.stringify({ task_id: taskId }),
|
| 191 |
+
});
|
| 192 |
+
if (!res.ok) throw new Error(await res.text());
|
| 193 |
+
const obs = await res.json();
|
| 194 |
+
rewardSection.style.display = "none";
|
| 195 |
+
updateUI(obs);
|
| 196 |
+
addLog("Environment reset successfully", "success");
|
| 197 |
+
} catch (err) {
|
| 198 |
+
addLog(`Reset failed: ${err.message}`, "error");
|
| 199 |
+
} finally {
|
| 200 |
+
resetBtn.disabled = false;
|
| 201 |
+
}
|
| 202 |
+
});
|
| 203 |
+
|
| 204 |
+
// Step
|
| 205 |
+
stepBtn.addEventListener("click", async () => {
|
| 206 |
+
if (isDone) {
|
| 207 |
+
addLog("Episode is done. Reset first.", "error");
|
| 208 |
+
return;
|
| 209 |
+
}
|
| 210 |
+
|
| 211 |
+
const action = actionType.value;
|
| 212 |
+
let payload = {};
|
| 213 |
+
|
| 214 |
+
if (action === "submit_note") {
|
| 215 |
+
payload = {
|
| 216 |
+
action_type: "submit_note",
|
| 217 |
+
soap_note: {
|
| 218 |
+
subjective: document.getElementById("inputS").value,
|
| 219 |
+
objective: document.getElementById("inputO").value,
|
| 220 |
+
assessment: document.getElementById("inputA").value,
|
| 221 |
+
plan: document.getElementById("inputP").value,
|
| 222 |
+
},
|
| 223 |
+
};
|
| 224 |
+
} else if (action === "revise_section") {
|
| 225 |
+
payload = {
|
| 226 |
+
action_type: "revise_section",
|
| 227 |
+
section: sectionSelect.value,
|
| 228 |
+
revision_text: document.getElementById("inputRevision").value,
|
| 229 |
+
};
|
| 230 |
+
} else if (action === "request_clarify") {
|
| 231 |
+
payload = {
|
| 232 |
+
action_type: "request_clarify",
|
| 233 |
+
clarify_question: document.getElementById("inputClarify").value,
|
| 234 |
+
};
|
| 235 |
+
}
|
| 236 |
+
|
| 237 |
+
stepBtn.disabled = true;
|
| 238 |
+
addLog(`Sending action: ${action}`);
|
| 239 |
+
|
| 240 |
+
try {
|
| 241 |
+
const res = await fetch(`${API}/step`, {
|
| 242 |
+
method: "POST",
|
| 243 |
+
headers: { "Content-Type": "application/json" },
|
| 244 |
+
body: JSON.stringify(payload),
|
| 245 |
+
});
|
| 246 |
+
if (!res.ok) throw new Error(await res.text());
|
| 247 |
+
const data = await res.json();
|
| 248 |
+
updateUI(data.observation, data.reward, data.done);
|
| 249 |
+
|
| 250 |
+
if (data.info && data.info.clarify_answer) {
|
| 251 |
+
addLog(`Clarify answer: ${data.info.clarify_answer}`);
|
| 252 |
+
}
|
| 253 |
+
addLog(`Step done — reward: ${data.reward.value.toFixed(4)}, done: ${data.done}`);
|
| 254 |
+
} catch (err) {
|
| 255 |
+
addLog(`Step failed: ${err.message}`, "error");
|
| 256 |
+
} finally {
|
| 257 |
+
stepBtn.disabled = false;
|
| 258 |
+
}
|
| 259 |
+
});
|
| 260 |
+
|
| 261 |
+
// Init
|
| 262 |
+
addLog("Frontend loaded. Select a task and click Reset.");
|
frontend/index.html
ADDED
|
@@ -0,0 +1,392 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!DOCTYPE html>
|
| 2 |
+
<html lang="en">
|
| 3 |
+
<head>
|
| 4 |
+
<meta charset="UTF-8">
|
| 5 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
| 6 |
+
<title>Clinical Note Scribe</title>
|
| 7 |
+
<link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap" rel="stylesheet">
|
| 8 |
+
<style>
|
| 9 |
+
*, *::before, *::after { box-sizing: border-box; margin: 0; padding: 0; }
|
| 10 |
+
|
| 11 |
+
:root {
|
| 12 |
+
--bg: #0f1117;
|
| 13 |
+
--surface: #1a1d27;
|
| 14 |
+
--surface-2: #242836;
|
| 15 |
+
--border: #2e3345;
|
| 16 |
+
--text: #e4e6ef;
|
| 17 |
+
--text-muted: #8b8fa3;
|
| 18 |
+
--accent: #4f8cff;
|
| 19 |
+
--accent-glow: rgba(79, 140, 255, 0.15);
|
| 20 |
+
--green: #34d399;
|
| 21 |
+
--red: #f87171;
|
| 22 |
+
--yellow: #fbbf24;
|
| 23 |
+
--radius: 10px;
|
| 24 |
+
--font: 'Inter', -apple-system, sans-serif;
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
body {
|
| 28 |
+
font-family: var(--font);
|
| 29 |
+
background: var(--bg);
|
| 30 |
+
color: var(--text);
|
| 31 |
+
min-height: 100vh;
|
| 32 |
+
line-height: 1.5;
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
/* Header */
|
| 36 |
+
.header {
|
| 37 |
+
display: flex;
|
| 38 |
+
align-items: center;
|
| 39 |
+
justify-content: space-between;
|
| 40 |
+
padding: 16px 28px;
|
| 41 |
+
border-bottom: 1px solid var(--border);
|
| 42 |
+
background: var(--surface);
|
| 43 |
+
}
|
| 44 |
+
.header h1 {
|
| 45 |
+
font-size: 18px;
|
| 46 |
+
font-weight: 600;
|
| 47 |
+
display: flex;
|
| 48 |
+
align-items: center;
|
| 49 |
+
gap: 8px;
|
| 50 |
+
}
|
| 51 |
+
.header h1 span { font-size: 22px; }
|
| 52 |
+
.header-right {
|
| 53 |
+
display: flex;
|
| 54 |
+
align-items: center;
|
| 55 |
+
gap: 12px;
|
| 56 |
+
}
|
| 57 |
+
.status-badge {
|
| 58 |
+
font-size: 12px;
|
| 59 |
+
padding: 4px 10px;
|
| 60 |
+
border-radius: 20px;
|
| 61 |
+
font-weight: 500;
|
| 62 |
+
}
|
| 63 |
+
.status-badge.idle { background: var(--surface-2); color: var(--text-muted); }
|
| 64 |
+
.status-badge.active { background: rgba(52,211,153,0.15); color: var(--green); }
|
| 65 |
+
.status-badge.done { background: rgba(79,140,255,0.15); color: var(--accent); }
|
| 66 |
+
|
| 67 |
+
/* Layout */
|
| 68 |
+
.layout {
|
| 69 |
+
display: grid;
|
| 70 |
+
grid-template-columns: 1fr 1fr;
|
| 71 |
+
gap: 0;
|
| 72 |
+
height: calc(100vh - 57px);
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
.panel {
|
| 76 |
+
display: flex;
|
| 77 |
+
flex-direction: column;
|
| 78 |
+
overflow: hidden;
|
| 79 |
+
}
|
| 80 |
+
.panel-left { border-right: 1px solid var(--border); }
|
| 81 |
+
|
| 82 |
+
.panel-section {
|
| 83 |
+
padding: 16px 20px;
|
| 84 |
+
border-bottom: 1px solid var(--border);
|
| 85 |
+
}
|
| 86 |
+
.panel-section-title {
|
| 87 |
+
font-size: 11px;
|
| 88 |
+
font-weight: 600;
|
| 89 |
+
text-transform: uppercase;
|
| 90 |
+
letter-spacing: 0.8px;
|
| 91 |
+
color: var(--text-muted);
|
| 92 |
+
margin-bottom: 10px;
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
.scrollable {
|
| 96 |
+
flex: 1;
|
| 97 |
+
overflow-y: auto;
|
| 98 |
+
padding: 16px 20px;
|
| 99 |
+
}
|
| 100 |
+
.scrollable::-webkit-scrollbar { width: 6px; }
|
| 101 |
+
.scrollable::-webkit-scrollbar-track { background: transparent; }
|
| 102 |
+
.scrollable::-webkit-scrollbar-thumb { background: var(--border); border-radius: 3px; }
|
| 103 |
+
|
| 104 |
+
/* Controls */
|
| 105 |
+
.controls {
|
| 106 |
+
display: flex;
|
| 107 |
+
gap: 8px;
|
| 108 |
+
flex-wrap: wrap;
|
| 109 |
+
align-items: center;
|
| 110 |
+
}
|
| 111 |
+
select, input, textarea {
|
| 112 |
+
font-family: var(--font);
|
| 113 |
+
font-size: 13px;
|
| 114 |
+
background: var(--surface-2);
|
| 115 |
+
border: 1px solid var(--border);
|
| 116 |
+
color: var(--text);
|
| 117 |
+
border-radius: 6px;
|
| 118 |
+
padding: 7px 10px;
|
| 119 |
+
outline: none;
|
| 120 |
+
transition: border-color 0.15s;
|
| 121 |
+
}
|
| 122 |
+
select:focus, input:focus, textarea:focus {
|
| 123 |
+
border-color: var(--accent);
|
| 124 |
+
}
|
| 125 |
+
textarea {
|
| 126 |
+
width: 100%;
|
| 127 |
+
resize: vertical;
|
| 128 |
+
min-height: 80px;
|
| 129 |
+
}
|
| 130 |
+
|
| 131 |
+
.btn {
|
| 132 |
+
font-family: var(--font);
|
| 133 |
+
font-size: 13px;
|
| 134 |
+
font-weight: 500;
|
| 135 |
+
padding: 7px 16px;
|
| 136 |
+
border: none;
|
| 137 |
+
border-radius: 6px;
|
| 138 |
+
cursor: pointer;
|
| 139 |
+
transition: all 0.15s;
|
| 140 |
+
}
|
| 141 |
+
.btn:disabled { opacity: 0.4; cursor: not-allowed; }
|
| 142 |
+
.btn-primary { background: var(--accent); color: #fff; }
|
| 143 |
+
.btn-primary:hover:not(:disabled) { background: #3d7ae8; }
|
| 144 |
+
.btn-secondary { background: var(--surface-2); color: var(--text); border: 1px solid var(--border); }
|
| 145 |
+
.btn-secondary:hover:not(:disabled) { background: var(--border); }
|
| 146 |
+
.btn-green { background: var(--green); color: #0f1117; }
|
| 147 |
+
.btn-green:hover:not(:disabled) { background: #2bc48d; }
|
| 148 |
+
|
| 149 |
+
/* Transcript */
|
| 150 |
+
.transcript-box {
|
| 151 |
+
font-size: 13px;
|
| 152 |
+
white-space: pre-wrap;
|
| 153 |
+
color: var(--text);
|
| 154 |
+
line-height: 1.7;
|
| 155 |
+
}
|
| 156 |
+
.transcript-box .speaker-doctor { color: var(--accent); font-weight: 500; }
|
| 157 |
+
.transcript-box .speaker-patient { color: var(--green); font-weight: 500; }
|
| 158 |
+
|
| 159 |
+
/* Context cards */
|
| 160 |
+
.context-grid {
|
| 161 |
+
display: grid;
|
| 162 |
+
grid-template-columns: repeat(auto-fill, minmax(180px, 1fr));
|
| 163 |
+
gap: 8px;
|
| 164 |
+
}
|
| 165 |
+
.context-card {
|
| 166 |
+
background: var(--surface-2);
|
| 167 |
+
border-radius: 8px;
|
| 168 |
+
padding: 10px 12px;
|
| 169 |
+
}
|
| 170 |
+
.context-card .label {
|
| 171 |
+
font-size: 10px;
|
| 172 |
+
font-weight: 600;
|
| 173 |
+
text-transform: uppercase;
|
| 174 |
+
letter-spacing: 0.5px;
|
| 175 |
+
color: var(--text-muted);
|
| 176 |
+
margin-bottom: 3px;
|
| 177 |
+
}
|
| 178 |
+
.context-card .value {
|
| 179 |
+
font-size: 13px;
|
| 180 |
+
color: var(--text);
|
| 181 |
+
}
|
| 182 |
+
|
| 183 |
+
/* SOAP Draft */
|
| 184 |
+
.soap-grid {
|
| 185 |
+
display: grid;
|
| 186 |
+
grid-template-columns: 1fr 1fr;
|
| 187 |
+
gap: 10px;
|
| 188 |
+
}
|
| 189 |
+
.soap-card {
|
| 190 |
+
background: var(--surface-2);
|
| 191 |
+
border-radius: 8px;
|
| 192 |
+
padding: 12px 14px;
|
| 193 |
+
border-left: 3px solid var(--border);
|
| 194 |
+
}
|
| 195 |
+
.soap-card.s { border-left-color: #60a5fa; }
|
| 196 |
+
.soap-card.o { border-left-color: #34d399; }
|
| 197 |
+
.soap-card.a { border-left-color: #fbbf24; }
|
| 198 |
+
.soap-card.p { border-left-color: #c084fc; }
|
| 199 |
+
.soap-card .soap-label {
|
| 200 |
+
font-size: 11px;
|
| 201 |
+
font-weight: 600;
|
| 202 |
+
text-transform: uppercase;
|
| 203 |
+
letter-spacing: 0.5px;
|
| 204 |
+
margin-bottom: 6px;
|
| 205 |
+
}
|
| 206 |
+
.soap-card.s .soap-label { color: #60a5fa; }
|
| 207 |
+
.soap-card.o .soap-label { color: #34d399; }
|
| 208 |
+
.soap-card.a .soap-label { color: #fbbf24; }
|
| 209 |
+
.soap-card.p .soap-label { color: #c084fc; }
|
| 210 |
+
.soap-card .soap-text {
|
| 211 |
+
font-size: 13px;
|
| 212 |
+
color: var(--text-muted);
|
| 213 |
+
line-height: 1.6;
|
| 214 |
+
}
|
| 215 |
+
|
| 216 |
+
/* Reward bar */
|
| 217 |
+
.reward-bar {
|
| 218 |
+
display: flex;
|
| 219 |
+
align-items: center;
|
| 220 |
+
gap: 12px;
|
| 221 |
+
padding: 10px 0;
|
| 222 |
+
}
|
| 223 |
+
.reward-bar .score-value {
|
| 224 |
+
font-size: 28px;
|
| 225 |
+
font-weight: 700;
|
| 226 |
+
font-variant-numeric: tabular-nums;
|
| 227 |
+
}
|
| 228 |
+
.reward-meter {
|
| 229 |
+
flex: 1;
|
| 230 |
+
height: 8px;
|
| 231 |
+
background: var(--surface-2);
|
| 232 |
+
border-radius: 4px;
|
| 233 |
+
overflow: hidden;
|
| 234 |
+
}
|
| 235 |
+
.reward-meter-fill {
|
| 236 |
+
height: 100%;
|
| 237 |
+
border-radius: 4px;
|
| 238 |
+
background: linear-gradient(90deg, var(--accent), var(--green));
|
| 239 |
+
transition: width 0.4s ease;
|
| 240 |
+
}
|
| 241 |
+
|
| 242 |
+
/* Action form */
|
| 243 |
+
.action-form {
|
| 244 |
+
display: flex;
|
| 245 |
+
flex-direction: column;
|
| 246 |
+
gap: 10px;
|
| 247 |
+
}
|
| 248 |
+
.action-row {
|
| 249 |
+
display: flex;
|
| 250 |
+
gap: 8px;
|
| 251 |
+
align-items: flex-start;
|
| 252 |
+
}
|
| 253 |
+
.action-row select { min-width: 160px; }
|
| 254 |
+
|
| 255 |
+
/* Log */
|
| 256 |
+
.log-container {
|
| 257 |
+
font-size: 12px;
|
| 258 |
+
font-family: 'SF Mono', 'Fira Code', monospace;
|
| 259 |
+
color: var(--text-muted);
|
| 260 |
+
max-height: 120px;
|
| 261 |
+
overflow-y: auto;
|
| 262 |
+
padding: 8px 0;
|
| 263 |
+
}
|
| 264 |
+
.log-entry {
|
| 265 |
+
padding: 2px 0;
|
| 266 |
+
border-bottom: 1px solid rgba(46, 51, 69, 0.4);
|
| 267 |
+
}
|
| 268 |
+
.log-entry .log-time { color: var(--border); margin-right: 8px; }
|
| 269 |
+
.log-entry.error { color: var(--red); }
|
| 270 |
+
.log-entry.success { color: var(--green); }
|
| 271 |
+
|
| 272 |
+
/* Empty state */
|
| 273 |
+
.empty-state {
|
| 274 |
+
display: flex;
|
| 275 |
+
flex-direction: column;
|
| 276 |
+
align-items: center;
|
| 277 |
+
justify-content: center;
|
| 278 |
+
height: 100%;
|
| 279 |
+
color: var(--text-muted);
|
| 280 |
+
text-align: center;
|
| 281 |
+
gap: 12px;
|
| 282 |
+
}
|
| 283 |
+
.empty-state .icon { font-size: 36px; opacity: 0.5; }
|
| 284 |
+
.empty-state p { font-size: 14px; max-width: 280px; }
|
| 285 |
+
</style>
|
| 286 |
+
</head>
|
| 287 |
+
<body>
|
| 288 |
+
|
| 289 |
+
<header class="header">
|
| 290 |
+
<h1><span>🏥</span> Clinical Note Scribe</h1>
|
| 291 |
+
<div class="header-right">
|
| 292 |
+
<span class="status-badge idle" id="statusBadge">Idle</span>
|
| 293 |
+
</div>
|
| 294 |
+
</header>
|
| 295 |
+
|
| 296 |
+
<div class="layout">
|
| 297 |
+
|
| 298 |
+
<!-- Left Panel: Transcript + Context -->
|
| 299 |
+
<div class="panel panel-left">
|
| 300 |
+
<div class="panel-section">
|
| 301 |
+
<div class="controls">
|
| 302 |
+
<select id="taskSelect">
|
| 303 |
+
<option value="easy_routine_checkup">🟢 Easy — Routine Check-Up</option>
|
| 304 |
+
<option value="medium_chronic_disease_followup">🟡 Medium — Chronic Follow-Up</option>
|
| 305 |
+
<option value="hard_complex_er_visit">🔴 Hard — Complex ER Visit</option>
|
| 306 |
+
</select>
|
| 307 |
+
<button class="btn btn-primary" id="resetBtn">Reset</button>
|
| 308 |
+
</div>
|
| 309 |
+
</div>
|
| 310 |
+
|
| 311 |
+
<div class="panel-section" id="contextSection" style="display:none;">
|
| 312 |
+
<div class="panel-section-title">Patient Context</div>
|
| 313 |
+
<div class="context-grid" id="contextGrid"></div>
|
| 314 |
+
</div>
|
| 315 |
+
|
| 316 |
+
<div class="scrollable" id="transcriptArea">
|
| 317 |
+
<div class="empty-state">
|
| 318 |
+
<div class="icon">📋</div>
|
| 319 |
+
<p>Select a task and click <strong>Reset</strong> to load the transcript.</p>
|
| 320 |
+
</div>
|
| 321 |
+
</div>
|
| 322 |
+
</div>
|
| 323 |
+
|
| 324 |
+
<!-- Right Panel: Actions + Draft + Reward -->
|
| 325 |
+
<div class="panel">
|
| 326 |
+
|
| 327 |
+
<div class="panel-section" id="actionSection" style="display:none;">
|
| 328 |
+
<div class="panel-section-title">Action</div>
|
| 329 |
+
<div class="action-form">
|
| 330 |
+
<div class="action-row">
|
| 331 |
+
<select id="actionType">
|
| 332 |
+
<option value="submit_note">Submit Note</option>
|
| 333 |
+
<option value="revise_section">Revise Section</option>
|
| 334 |
+
<option value="request_clarify">Request Clarify</option>
|
| 335 |
+
</select>
|
| 336 |
+
<select id="sectionSelect" style="display:none;">
|
| 337 |
+
<option value="S">Subjective</option>
|
| 338 |
+
<option value="O">Objective</option>
|
| 339 |
+
<option value="A">Assessment</option>
|
| 340 |
+
<option value="P">Plan</option>
|
| 341 |
+
</select>
|
| 342 |
+
<button class="btn btn-green" id="stepBtn">Send</button>
|
| 343 |
+
</div>
|
| 344 |
+
<div id="soapInputs">
|
| 345 |
+
<div style="display:grid; grid-template-columns:1fr 1fr; gap:8px;">
|
| 346 |
+
<textarea id="inputS" placeholder="Subjective..." rows="3"></textarea>
|
| 347 |
+
<textarea id="inputO" placeholder="Objective..." rows="3"></textarea>
|
| 348 |
+
<textarea id="inputA" placeholder="Assessment..." rows="3"></textarea>
|
| 349 |
+
<textarea id="inputP" placeholder="Plan..." rows="3"></textarea>
|
| 350 |
+
</div>
|
| 351 |
+
</div>
|
| 352 |
+
<div id="reviseInput" style="display:none;">
|
| 353 |
+
<textarea id="inputRevision" placeholder="Revision text..." rows="3"></textarea>
|
| 354 |
+
</div>
|
| 355 |
+
<div id="clarifyInput" style="display:none;">
|
| 356 |
+
<input id="inputClarify" type="text" placeholder="Your question..." style="width:100%;">
|
| 357 |
+
</div>
|
| 358 |
+
</div>
|
| 359 |
+
</div>
|
| 360 |
+
|
| 361 |
+
<div class="panel-section" id="rewardSection" style="display:none;">
|
| 362 |
+
<div class="panel-section-title">Reward</div>
|
| 363 |
+
<div class="reward-bar">
|
| 364 |
+
<div class="score-value" id="scoreValue">0.00</div>
|
| 365 |
+
<div class="reward-meter">
|
| 366 |
+
<div class="reward-meter-fill" id="rewardFill" style="width:0%"></div>
|
| 367 |
+
</div>
|
| 368 |
+
</div>
|
| 369 |
+
</div>
|
| 370 |
+
|
| 371 |
+
<div class="scrollable" id="draftArea">
|
| 372 |
+
<div class="empty-state" id="draftEmpty">
|
| 373 |
+
<div class="icon">📝</div>
|
| 374 |
+
<p>Your SOAP note draft will appear here after you submit or revise.</p>
|
| 375 |
+
</div>
|
| 376 |
+
<div id="soapDraft" style="display:none;">
|
| 377 |
+
<div class="panel-section-title" style="margin-bottom:12px;">Current Draft</div>
|
| 378 |
+
<div class="soap-grid" id="soapGrid"></div>
|
| 379 |
+
</div>
|
| 380 |
+
</div>
|
| 381 |
+
|
| 382 |
+
<div class="panel-section" style="border-top:1px solid var(--border); border-bottom:none;">
|
| 383 |
+
<div class="panel-section-title">Log</div>
|
| 384 |
+
<div class="log-container" id="logContainer"></div>
|
| 385 |
+
</div>
|
| 386 |
+
|
| 387 |
+
</div>
|
| 388 |
+
</div>
|
| 389 |
+
|
| 390 |
+
<script src="/static/app.js"></script>
|
| 391 |
+
</body>
|
| 392 |
+
</html>
|
inference.py
CHANGED
|
@@ -1,19 +1,18 @@
|
|
| 1 |
"""
|
| 2 |
Inference Script — Clinical Note Scribe
|
| 3 |
-
===================================
|
| 4 |
MANDATORY
|
| 5 |
-
- Before submitting, ensure the following variables are defined
|
| 6 |
API_BASE_URL The API endpoint for the LLM.
|
| 7 |
MODEL_NAME The model identifier to use for inference.
|
| 8 |
HF_TOKEN Your Hugging Face / API key.
|
| 9 |
-
LOCAL_IMAGE_NAME The name of the local image
|
| 10 |
-
if you are using from_docker_image() method.
|
| 11 |
|
| 12 |
- Defaults are set only for API_BASE_URL and MODEL_NAME:
|
| 13 |
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
|
| 14 |
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
|
| 15 |
|
| 16 |
-
- The inference script must be named
|
| 17 |
- Participants must use OpenAI Client for all LLM calls using above variables.
|
| 18 |
|
| 19 |
STDOUT FORMAT
|
|
@@ -47,29 +46,20 @@
|
|
| 47 |
|
| 48 |
from openai import OpenAI
|
| 49 |
|
| 50 |
-
# ---------------------------------------------------------------------------
|
| 51 |
# Silence the underlying env's stdout JSON logs (redirect them to stderr)
|
| 52 |
-
# ---------------------------------------------------------------------------
|
| 53 |
env_logger = logging.getLogger("clinical_note_scribe")
|
| 54 |
env_logger.setLevel(logging.INFO)
|
| 55 |
env_logger.handlers.clear()
|
| 56 |
env_logger.addHandler(logging.StreamHandler(sys.stderr))
|
| 57 |
env_logger.propagate = False
|
| 58 |
|
| 59 |
-
|
| 60 |
-
# ---------------------------------------------------------------------------
|
| 61 |
# Environment imports
|
| 62 |
-
# ---------------------------------------------------------------------------
|
| 63 |
-
|
| 64 |
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
| 65 |
|
| 66 |
from environment import ClinicalNoteScribeEnv, Action, SOAPNote # noqa: E402
|
| 67 |
from environment.tasks import TASK_REGISTRY # noqa: E402
|
| 68 |
|
| 69 |
-
# ---------------------------------------------------------------------------
|
| 70 |
# Config
|
| 71 |
-
# ---------------------------------------------------------------------------
|
| 72 |
-
|
| 73 |
LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
|
| 74 |
HF_TOKEN = os.getenv("HF_TOKEN")
|
| 75 |
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
|
|
@@ -77,14 +67,11 @@
|
|
| 77 |
|
| 78 |
BENCHMARK = "clinical-note-scribe"
|
| 79 |
TASK_IDS = list(TASK_REGISTRY.keys())
|
| 80 |
-
MAX_STEPS = 5
|
| 81 |
MAX_TOKENS = 1024
|
| 82 |
TEMPERATURE = 0.2
|
| 83 |
|
| 84 |
-
# ---------------------------------------------------------------------------
|
| 85 |
# System prompt
|
| 86 |
-
# ---------------------------------------------------------------------------
|
| 87 |
-
|
| 88 |
SYSTEM_PROMPT = textwrap.dedent("""\
|
| 89 |
You are a clinical documentation assistant. Given a doctor-patient transcript
|
| 90 |
and patient context, generate a concise, clinically accurate SOAP note.
|
|
@@ -109,10 +96,7 @@
|
|
| 109 |
""").strip()
|
| 110 |
|
| 111 |
|
| 112 |
-
# ---------------------------------------------------------------------------
|
| 113 |
# Stdout logging — mandatory hackathon format
|
| 114 |
-
# ---------------------------------------------------------------------------
|
| 115 |
-
|
| 116 |
def log_start(task: str, env: str, model: str) -> None:
|
| 117 |
print(f"[START] task={task} env={env} model={model}", flush=True)
|
| 118 |
|
|
@@ -134,10 +118,7 @@ def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> No
|
|
| 134 |
)
|
| 135 |
|
| 136 |
|
| 137 |
-
# ---------------------------------------------------------------------------
|
| 138 |
# Helpers
|
| 139 |
-
# ---------------------------------------------------------------------------
|
| 140 |
-
|
| 141 |
def _build_user_prompt(transcript: str, patient_context: dict[str, Any]) -> str:
|
| 142 |
"""Build the user message containing the transcript and context."""
|
| 143 |
ctx_str = json.dumps(patient_context, indent=2, default=str)
|
|
@@ -187,10 +168,7 @@ def get_soap_note(client: OpenAI, transcript: str, patient_context: dict[str, An
|
|
| 187 |
raise
|
| 188 |
|
| 189 |
|
| 190 |
-
# ---------------------------------------------------------------------------
|
| 191 |
# Per-task runner
|
| 192 |
-
# ---------------------------------------------------------------------------
|
| 193 |
-
|
| 194 |
def run_task(client: OpenAI, env: ClinicalNoteScribeEnv, task_id: str) -> dict[str, Any]:
|
| 195 |
"""Run a single task episode and return the result dict."""
|
| 196 |
rewards: List[float] = []
|
|
@@ -202,18 +180,15 @@ def run_task(client: OpenAI, env: ClinicalNoteScribeEnv, task_id: str) -> dict[s
|
|
| 202 |
log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
|
| 203 |
|
| 204 |
try:
|
| 205 |
-
# ---- reset ----
|
| 206 |
obs = env.reset(task_id)
|
| 207 |
|
| 208 |
for step in range(1, MAX_STEPS + 1):
|
| 209 |
-
# ---- generate SOAP note via LLM ----
|
| 210 |
try:
|
| 211 |
action_dict = get_soap_note(client, obs.transcript, obs.patient_context)
|
| 212 |
action = Action(**action_dict)
|
| 213 |
action_str = f"submit_note(sections=S,O,A,P)"
|
| 214 |
except Exception as exc:
|
| 215 |
-
# On model / parse failure, submit an empty note
|
| 216 |
-
# grade to 0.0 (format_valid=0 because fields are empty, grader=0).
|
| 217 |
action = Action(
|
| 218 |
action_type="submit_note",
|
| 219 |
soap_note=SOAPNote(
|
|
@@ -226,14 +201,12 @@ def run_task(client: OpenAI, env: ClinicalNoteScribeEnv, task_id: str) -> dict[s
|
|
| 226 |
action_str = "submit_note(fallback)"
|
| 227 |
last_error = str(exc)
|
| 228 |
|
| 229 |
-
# ---- step ----
|
| 230 |
obs, reward_obj, done, info = env.step(action)
|
| 231 |
|
| 232 |
reward_val = reward_obj.value
|
| 233 |
rewards.append(reward_val)
|
| 234 |
steps_taken = step
|
| 235 |
|
| 236 |
-
# Check for env-level errors
|
| 237 |
error_msg = None
|
| 238 |
if obs.errors_so_far:
|
| 239 |
error_msg = obs.errors_so_far[-1]
|
|
@@ -252,7 +225,6 @@ def run_task(client: OpenAI, env: ClinicalNoteScribeEnv, task_id: str) -> dict[s
|
|
| 252 |
if done:
|
| 253 |
break
|
| 254 |
|
| 255 |
-
# Final score = last reward value (already in [0, 1])
|
| 256 |
score = rewards[-1] if rewards else 0.0
|
| 257 |
score = min(max(score, 0.0), 1.0)
|
| 258 |
success = score > 0.0
|
|
@@ -274,10 +246,7 @@ def run_task(client: OpenAI, env: ClinicalNoteScribeEnv, task_id: str) -> dict[s
|
|
| 274 |
}
|
| 275 |
|
| 276 |
|
| 277 |
-
# ---------------------------------------------------------------------------
|
| 278 |
# Main
|
| 279 |
-
# ---------------------------------------------------------------------------
|
| 280 |
-
|
| 281 |
def main() -> None:
|
| 282 |
if not HF_TOKEN:
|
| 283 |
print(
|
|
@@ -295,7 +264,7 @@ def main() -> None:
|
|
| 295 |
result = run_task(client, env, task_id)
|
| 296 |
results.append(result)
|
| 297 |
|
| 298 |
-
#
|
| 299 |
print("", file=sys.stderr, flush=True)
|
| 300 |
print("=" * 60, file=sys.stderr, flush=True)
|
| 301 |
print(" SUMMARY", file=sys.stderr, flush=True)
|
|
|
|
| 1 |
"""
|
| 2 |
Inference Script — Clinical Note Scribe
|
| 3 |
+
========================================
|
| 4 |
MANDATORY
|
| 5 |
+
- Before submitting, ensure the following variables are defined:
|
| 6 |
API_BASE_URL The API endpoint for the LLM.
|
| 7 |
MODEL_NAME The model identifier to use for inference.
|
| 8 |
HF_TOKEN Your Hugging Face / API key.
|
| 9 |
+
LOCAL_IMAGE_NAME The name of the local image for the environment.
|
|
|
|
| 10 |
|
| 11 |
- Defaults are set only for API_BASE_URL and MODEL_NAME:
|
| 12 |
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
|
| 13 |
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
|
| 14 |
|
| 15 |
+
- The inference script must be named inference.py and placed in the root directory.
|
| 16 |
- Participants must use OpenAI Client for all LLM calls using above variables.
|
| 17 |
|
| 18 |
STDOUT FORMAT
|
|
|
|
| 46 |
|
| 47 |
from openai import OpenAI
|
| 48 |
|
|
|
|
| 49 |
# Silence the underlying env's stdout JSON logs (redirect them to stderr)
|
|
|
|
| 50 |
env_logger = logging.getLogger("clinical_note_scribe")
|
| 51 |
env_logger.setLevel(logging.INFO)
|
| 52 |
env_logger.handlers.clear()
|
| 53 |
env_logger.addHandler(logging.StreamHandler(sys.stderr))
|
| 54 |
env_logger.propagate = False
|
| 55 |
|
|
|
|
|
|
|
| 56 |
# Environment imports
|
|
|
|
|
|
|
| 57 |
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
| 58 |
|
| 59 |
from environment import ClinicalNoteScribeEnv, Action, SOAPNote # noqa: E402
|
| 60 |
from environment.tasks import TASK_REGISTRY # noqa: E402
|
| 61 |
|
|
|
|
| 62 |
# Config
|
|
|
|
|
|
|
| 63 |
LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
|
| 64 |
HF_TOKEN = os.getenv("HF_TOKEN")
|
| 65 |
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
|
|
|
|
| 67 |
|
| 68 |
BENCHMARK = "clinical-note-scribe"
|
| 69 |
TASK_IDS = list(TASK_REGISTRY.keys())
|
| 70 |
+
MAX_STEPS = 5
|
| 71 |
MAX_TOKENS = 1024
|
| 72 |
TEMPERATURE = 0.2
|
| 73 |
|
|
|
|
| 74 |
# System prompt
|
|
|
|
|
|
|
| 75 |
SYSTEM_PROMPT = textwrap.dedent("""\
|
| 76 |
You are a clinical documentation assistant. Given a doctor-patient transcript
|
| 77 |
and patient context, generate a concise, clinically accurate SOAP note.
|
|
|
|
| 96 |
""").strip()
|
| 97 |
|
| 98 |
|
|
|
|
| 99 |
# Stdout logging — mandatory hackathon format
|
|
|
|
|
|
|
| 100 |
def log_start(task: str, env: str, model: str) -> None:
|
| 101 |
print(f"[START] task={task} env={env} model={model}", flush=True)
|
| 102 |
|
|
|
|
| 118 |
)
|
| 119 |
|
| 120 |
|
|
|
|
| 121 |
# Helpers
|
|
|
|
|
|
|
| 122 |
def _build_user_prompt(transcript: str, patient_context: dict[str, Any]) -> str:
|
| 123 |
"""Build the user message containing the transcript and context."""
|
| 124 |
ctx_str = json.dumps(patient_context, indent=2, default=str)
|
|
|
|
| 168 |
raise
|
| 169 |
|
| 170 |
|
|
|
|
| 171 |
# Per-task runner
|
|
|
|
|
|
|
| 172 |
def run_task(client: OpenAI, env: ClinicalNoteScribeEnv, task_id: str) -> dict[str, Any]:
|
| 173 |
"""Run a single task episode and return the result dict."""
|
| 174 |
rewards: List[float] = []
|
|
|
|
| 180 |
log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
|
| 181 |
|
| 182 |
try:
|
|
|
|
| 183 |
obs = env.reset(task_id)
|
| 184 |
|
| 185 |
for step in range(1, MAX_STEPS + 1):
|
|
|
|
| 186 |
try:
|
| 187 |
action_dict = get_soap_note(client, obs.transcript, obs.patient_context)
|
| 188 |
action = Action(**action_dict)
|
| 189 |
action_str = f"submit_note(sections=S,O,A,P)"
|
| 190 |
except Exception as exc:
|
| 191 |
+
# On model / parse failure, submit an empty note
|
|
|
|
| 192 |
action = Action(
|
| 193 |
action_type="submit_note",
|
| 194 |
soap_note=SOAPNote(
|
|
|
|
| 201 |
action_str = "submit_note(fallback)"
|
| 202 |
last_error = str(exc)
|
| 203 |
|
|
|
|
| 204 |
obs, reward_obj, done, info = env.step(action)
|
| 205 |
|
| 206 |
reward_val = reward_obj.value
|
| 207 |
rewards.append(reward_val)
|
| 208 |
steps_taken = step
|
| 209 |
|
|
|
|
| 210 |
error_msg = None
|
| 211 |
if obs.errors_so_far:
|
| 212 |
error_msg = obs.errors_so_far[-1]
|
|
|
|
| 225 |
if done:
|
| 226 |
break
|
| 227 |
|
|
|
|
| 228 |
score = rewards[-1] if rewards else 0.0
|
| 229 |
score = min(max(score, 0.0), 1.0)
|
| 230 |
success = score > 0.0
|
|
|
|
| 246 |
}
|
| 247 |
|
| 248 |
|
|
|
|
| 249 |
# Main
|
|
|
|
|
|
|
| 250 |
def main() -> None:
|
| 251 |
if not HF_TOKEN:
|
| 252 |
print(
|
|
|
|
| 264 |
result = run_task(client, env, task_id)
|
| 265 |
results.append(result)
|
| 266 |
|
| 267 |
+
# Summary table
|
| 268 |
print("", file=sys.stderr, flush=True)
|
| 269 |
print("=" * 60, file=sys.stderr, flush=True)
|
| 270 |
print(" SUMMARY", file=sys.stderr, flush=True)
|
server/app.py
CHANGED
|
@@ -1,60 +1,40 @@
|
|
| 1 |
-
"""FastAPI application for the Clinical Note Scribe environment.
|
| 2 |
-
|
| 3 |
-
Run locally::
|
| 4 |
-
|
| 5 |
-
uvicorn server.app:app --host 0.0.0.0 --port 8000 --reload
|
| 6 |
-
|
| 7 |
-
Or via Docker (see ``Dockerfile`` in project root).
|
| 8 |
-
"""
|
| 9 |
|
| 10 |
from __future__ import annotations
|
| 11 |
|
| 12 |
import logging
|
| 13 |
import sys
|
|
|
|
| 14 |
|
| 15 |
from fastapi import FastAPI
|
|
|
|
|
|
|
| 16 |
|
| 17 |
from server.routes import router
|
| 18 |
|
| 19 |
-
|
| 20 |
-
# Configure root logging → structured JSON to stdout
|
| 21 |
-
# ---------------------------------------------------------------------------
|
| 22 |
-
|
| 23 |
-
logging.basicConfig(
|
| 24 |
-
level=logging.INFO,
|
| 25 |
-
format="%(message)s",
|
| 26 |
-
handlers=[logging.StreamHandler(sys.stdout)],
|
| 27 |
-
)
|
| 28 |
-
|
| 29 |
-
# Silence noisy uvicorn access logs so our structured events stay clean
|
| 30 |
logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
|
| 31 |
|
| 32 |
-
# ---------------------------------------------------------------------------
|
| 33 |
-
# Application factory
|
| 34 |
-
# ---------------------------------------------------------------------------
|
| 35 |
-
|
| 36 |
app = FastAPI(
|
| 37 |
title="Clinical Note Scribe – OpenEnv",
|
| 38 |
-
description=
|
| 39 |
-
|
| 40 |
-
"clinical SOAP-note generation from doctor–patient transcripts."
|
| 41 |
-
),
|
| 42 |
-
version="0.1.0",
|
| 43 |
)
|
|
|
|
| 44 |
|
| 45 |
-
|
|
|
|
| 46 |
|
| 47 |
-
# Mount all routes at root (/)
|
| 48 |
-
app.include_router(router)
|
| 49 |
|
| 50 |
@app.get("/", include_in_schema=False)
|
| 51 |
async def root():
|
| 52 |
-
|
| 53 |
-
|
| 54 |
|
| 55 |
def main():
|
| 56 |
import uvicorn
|
| 57 |
uvicorn.run("server.app:app", host="0.0.0.0", port=7860)
|
| 58 |
|
|
|
|
| 59 |
if __name__ == "__main__":
|
| 60 |
main()
|
|
|
|
| 1 |
+
"""FastAPI application for the Clinical Note Scribe environment."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
from __future__ import annotations
|
| 4 |
|
| 5 |
import logging
|
| 6 |
import sys
|
| 7 |
+
from pathlib import Path
|
| 8 |
|
| 9 |
from fastapi import FastAPI
|
| 10 |
+
from fastapi.staticfiles import StaticFiles
|
| 11 |
+
from fastapi.responses import FileResponse
|
| 12 |
|
| 13 |
from server.routes import router
|
| 14 |
|
| 15 |
+
logging.basicConfig(level=logging.INFO, format="%(message)s", handlers=[logging.StreamHandler(sys.stdout)])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
|
| 17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
app = FastAPI(
|
| 19 |
title="Clinical Note Scribe – OpenEnv",
|
| 20 |
+
description="OpenEnv-compliant environment for evaluating AI agents on clinical SOAP-note generation.",
|
| 21 |
+
version="1.0.0",
|
|
|
|
|
|
|
|
|
|
| 22 |
)
|
| 23 |
+
app.include_router(router)
|
| 24 |
|
| 25 |
+
FRONTEND_DIR = Path(__file__).resolve().parent.parent / "frontend"
|
| 26 |
+
app.mount("/static", StaticFiles(directory=str(FRONTEND_DIR)), name="static")
|
| 27 |
|
|
|
|
|
|
|
| 28 |
|
| 29 |
@app.get("/", include_in_schema=False)
|
| 30 |
async def root():
|
| 31 |
+
return FileResponse(str(FRONTEND_DIR / "index.html"))
|
| 32 |
+
|
| 33 |
|
| 34 |
def main():
|
| 35 |
import uvicorn
|
| 36 |
uvicorn.run("server.app:app", host="0.0.0.0", port=7860)
|
| 37 |
|
| 38 |
+
|
| 39 |
if __name__ == "__main__":
|
| 40 |
main()
|
server/routes.py
CHANGED
|
@@ -1,61 +1,29 @@
|
|
| 1 |
-
"""FastAPI
|
| 2 |
-
|
| 3 |
-
Endpoints
|
| 4 |
-
---------
|
| 5 |
-
POST /reset – start a new episode (takes ``task_id``, returns ``Observation``)
|
| 6 |
-
POST /step – execute an action (takes ``Action``, returns step result)
|
| 7 |
-
GET /state – inspect env state (returns ``EnvironmentState``)
|
| 8 |
-
GET /health – liveness probe (returns ``{"status": "ok"}``)
|
| 9 |
-
|
| 10 |
-
Structured logging
|
| 11 |
-
------------------
|
| 12 |
-
The underlying ``ClinicalNoteScribeEnv`` already emits ``[START]``, ``[STEP]``,
|
| 13 |
-
and ``[END]`` JSON lines to stdout via Python's ``logging`` module. This router
|
| 14 |
-
adds a thin request-level log wrapper so every inbound HTTP call is also
|
| 15 |
-
traceable.
|
| 16 |
-
"""
|
| 17 |
|
| 18 |
from __future__ import annotations
|
| 19 |
|
| 20 |
-
import logging
|
| 21 |
import json
|
|
|
|
| 22 |
import time
|
| 23 |
from typing import Any, Optional
|
| 24 |
|
| 25 |
from fastapi import APIRouter, HTTPException
|
| 26 |
-
from pydantic import BaseModel, Field
|
| 27 |
-
|
| 28 |
-
from environment.models import
|
| 29 |
-
Action,
|
| 30 |
-
EnvironmentState,
|
| 31 |
-
Observation,
|
| 32 |
-
Reward,
|
| 33 |
-
)
|
| 34 |
from environment.env import ClinicalNoteScribeEnv
|
| 35 |
|
| 36 |
logger = logging.getLogger("clinical_note_scribe.server")
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
# ---------------------------------------------------------------------------
|
| 40 |
-
# Singleton environment instance
|
| 41 |
-
# ---------------------------------------------------------------------------
|
| 42 |
-
|
| 43 |
_env = ClinicalNoteScribeEnv()
|
|
|
|
|
|
|
| 44 |
|
|
|
|
|
|
|
| 45 |
|
| 46 |
-
# ---------------------------------------------------------------------------
|
| 47 |
-
# Request / response schemas
|
| 48 |
-
# ---------------------------------------------------------------------------
|
| 49 |
|
| 50 |
class ResetRequest(BaseModel):
|
| 51 |
-
task_id: Optional[str] = Field(
|
| 52 |
-
default=None,
|
| 53 |
-
description=(
|
| 54 |
-
"Task to load. One of: easy_routine_checkup, "
|
| 55 |
-
"medium_chronic_disease_followup, hard_complex_er_visit. "
|
| 56 |
-
"Defaults to the first registered task."
|
| 57 |
-
),
|
| 58 |
-
)
|
| 59 |
|
| 60 |
|
| 61 |
class StepResponse(BaseModel):
|
|
@@ -69,114 +37,47 @@ class HealthResponse(BaseModel):
|
|
| 69 |
status: str = "ok"
|
| 70 |
|
| 71 |
|
| 72 |
-
|
| 73 |
-
# Helpers
|
| 74 |
-
# ---------------------------------------------------------------------------
|
| 75 |
-
|
| 76 |
-
def _log(event: str, **kwargs: Any) -> None:
|
| 77 |
-
"""Emit a structured JSON log line to stdout."""
|
| 78 |
-
payload: dict[str, Any] = {"event": event, "timestamp": time.time()}
|
| 79 |
-
payload.update(kwargs)
|
| 80 |
-
logger.info(json.dumps(payload, default=str))
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
# ---------------------------------------------------------------------------
|
| 84 |
-
# Router
|
| 85 |
-
# ---------------------------------------------------------------------------
|
| 86 |
-
|
| 87 |
-
router = APIRouter()
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
@router.post(
|
| 91 |
-
"/reset",
|
| 92 |
-
response_model=Observation,
|
| 93 |
-
summary="Reset the environment and start a new episode",
|
| 94 |
-
)
|
| 95 |
async def reset(body: Optional[ResetRequest] = None) -> Observation:
|
| 96 |
-
"""Load a task and return the initial ``Observation``.
|
| 97 |
-
|
| 98 |
-
The underlying environment emits a ``[START]`` log event.
|
| 99 |
-
"""
|
| 100 |
task_id = body.task_id if body else None
|
| 101 |
_log("START", endpoint="/reset", task_id=task_id)
|
| 102 |
try:
|
| 103 |
-
|
| 104 |
except ValueError as exc:
|
| 105 |
raise HTTPException(status_code=400, detail=str(exc))
|
| 106 |
-
return obs
|
| 107 |
|
| 108 |
|
| 109 |
-
@router.post(
|
| 110 |
-
"/step",
|
| 111 |
-
response_model=StepResponse,
|
| 112 |
-
summary="Submit an action and advance the environment by one step",
|
| 113 |
-
)
|
| 114 |
async def step(payload: dict[str, Any]) -> StepResponse:
|
| 115 |
-
"""Execute an action in the current episode.
|
| 116 |
-
|
| 117 |
-
Accepts a raw JSON body and validates it into an ``Action``.
|
| 118 |
-
If validation fails, the error is recorded in the environment
|
| 119 |
-
instead of returning an HTTP 422.
|
| 120 |
-
"""
|
| 121 |
-
from pydantic import ValidationError
|
| 122 |
-
from environment.models import Reward
|
| 123 |
-
|
| 124 |
try:
|
| 125 |
action = Action(**payload)
|
| 126 |
except (ValidationError, TypeError) as exc:
|
| 127 |
-
# Gracefully absorb bad payloads instead of crashing with HTTP 422
|
| 128 |
_log("STEP", endpoint="/step", action_type="invalid", error=str(exc))
|
| 129 |
error_msg = f"Invalid action payload: {exc}"
|
| 130 |
_env._errors_so_far.append(error_msg)
|
| 131 |
_env._step_count += 1
|
| 132 |
-
|
| 133 |
-
obs = _env._build_observation()
|
| 134 |
-
reward = Reward(
|
| 135 |
-
value=0.0,
|
| 136 |
-
signals={"error": 1.0},
|
| 137 |
-
done=False,
|
| 138 |
-
info={"error": error_msg},
|
| 139 |
-
)
|
| 140 |
return StepResponse(
|
| 141 |
-
observation=
|
| 142 |
-
reward=
|
| 143 |
-
done=False,
|
| 144 |
-
info={"error": error_msg},
|
| 145 |
)
|
| 146 |
|
| 147 |
_log("STEP", endpoint="/step", action_type=action.action_type)
|
| 148 |
try:
|
| 149 |
obs, reward, done, info = _env.step(action)
|
| 150 |
except RuntimeError as exc:
|
| 151 |
-
# e.g. stepping after episode is done without reset
|
| 152 |
raise HTTPException(status_code=409, detail=str(exc))
|
| 153 |
|
| 154 |
if done:
|
| 155 |
_log("END", endpoint="/step", final_score=reward.value)
|
| 156 |
-
|
| 157 |
-
return StepResponse(
|
| 158 |
-
observation=obs,
|
| 159 |
-
reward=reward,
|
| 160 |
-
done=done,
|
| 161 |
-
info=info,
|
| 162 |
-
)
|
| 163 |
|
| 164 |
|
| 165 |
-
@router.get(
|
| 166 |
-
"/state",
|
| 167 |
-
response_model=EnvironmentState,
|
| 168 |
-
summary="Return the full internal environment state",
|
| 169 |
-
)
|
| 170 |
async def state() -> EnvironmentState:
|
| 171 |
-
"""Inspect the environment without mutating it."""
|
| 172 |
return _env.state()
|
| 173 |
|
| 174 |
|
| 175 |
-
@router.get(
|
| 176 |
-
"/health",
|
| 177 |
-
response_model=HealthResponse,
|
| 178 |
-
summary="Liveness probe",
|
| 179 |
-
)
|
| 180 |
async def health() -> HealthResponse:
|
| 181 |
-
"""Returns HTTP 200 with ``{"status": "ok"}``."""
|
| 182 |
return HealthResponse()
|
|
|
|
| 1 |
+
"""FastAPI routes for the Clinical Note Scribe environment."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
from __future__ import annotations
|
| 4 |
|
|
|
|
| 5 |
import json
|
| 6 |
+
import logging
|
| 7 |
import time
|
| 8 |
from typing import Any, Optional
|
| 9 |
|
| 10 |
from fastapi import APIRouter, HTTPException
|
| 11 |
+
from pydantic import BaseModel, Field, ValidationError
|
| 12 |
+
|
| 13 |
+
from environment.models import Action, EnvironmentState, Observation, Reward
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
from environment.env import ClinicalNoteScribeEnv
|
| 15 |
|
| 16 |
logger = logging.getLogger("clinical_note_scribe.server")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
_env = ClinicalNoteScribeEnv()
|
| 18 |
+
router = APIRouter()
|
| 19 |
+
|
| 20 |
|
| 21 |
+
def _log(event: str, **kw: Any) -> None:
|
| 22 |
+
logger.info(json.dumps({"event": event, "timestamp": time.time(), **kw}, default=str))
|
| 23 |
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
class ResetRequest(BaseModel):
|
| 26 |
+
task_id: Optional[str] = Field(None, description="Task to load. Defaults to first registered task.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
|
| 29 |
class StepResponse(BaseModel):
|
|
|
|
| 37 |
status: str = "ok"
|
| 38 |
|
| 39 |
|
| 40 |
+
@router.post("/reset", response_model=Observation, summary="Reset and start a new episode")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
async def reset(body: Optional[ResetRequest] = None) -> Observation:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
task_id = body.task_id if body else None
|
| 43 |
_log("START", endpoint="/reset", task_id=task_id)
|
| 44 |
try:
|
| 45 |
+
return _env.reset(task_id=task_id)
|
| 46 |
except ValueError as exc:
|
| 47 |
raise HTTPException(status_code=400, detail=str(exc))
|
|
|
|
| 48 |
|
| 49 |
|
| 50 |
+
@router.post("/step", response_model=StepResponse, summary="Submit an action")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
async def step(payload: dict[str, Any]) -> StepResponse:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
try:
|
| 53 |
action = Action(**payload)
|
| 54 |
except (ValidationError, TypeError) as exc:
|
|
|
|
| 55 |
_log("STEP", endpoint="/step", action_type="invalid", error=str(exc))
|
| 56 |
error_msg = f"Invalid action payload: {exc}"
|
| 57 |
_env._errors_so_far.append(error_msg)
|
| 58 |
_env._step_count += 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
return StepResponse(
|
| 60 |
+
observation=_env._obs(),
|
| 61 |
+
reward=Reward(value=0.0, signals={"error": 1.0}, done=False, info={"error": error_msg}),
|
| 62 |
+
done=False, info={"error": error_msg},
|
|
|
|
| 63 |
)
|
| 64 |
|
| 65 |
_log("STEP", endpoint="/step", action_type=action.action_type)
|
| 66 |
try:
|
| 67 |
obs, reward, done, info = _env.step(action)
|
| 68 |
except RuntimeError as exc:
|
|
|
|
| 69 |
raise HTTPException(status_code=409, detail=str(exc))
|
| 70 |
|
| 71 |
if done:
|
| 72 |
_log("END", endpoint="/step", final_score=reward.value)
|
| 73 |
+
return StepResponse(observation=obs, reward=reward, done=done, info=info)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
|
| 75 |
|
| 76 |
+
@router.get("/state", response_model=EnvironmentState, summary="Inspect environment state")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
async def state() -> EnvironmentState:
|
|
|
|
| 78 |
return _env.state()
|
| 79 |
|
| 80 |
|
| 81 |
+
@router.get("/health", response_model=HealthResponse, summary="Liveness probe")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
async def health() -> HealthResponse:
|
|
|
|
| 83 |
return HealthResponse()
|