Spaces:
Sleeping
Sleeping
File size: 14,165 Bytes
5afe05e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 | """
Genuine ReAct Agent β LLM is the brain, Python is just the hands.
No fake [THOUGHT] logs. No Python detectors solving the problem.
The LLM reads raw data, decides actions, gets feedback, and iterates.
"""
from __future__ import annotations
import json
import re
import time
from typing import Optional
from openai import OpenAI
from models import AuditAction
REACT_SYSTEM_PROMPT = """You are a Senior Clinical Data Manager auditing a clinical trial dataset.
You interact with an audit environment through structured JSON actions.
AVAILABLE ACTIONS:
1. investigate_pattern β examine a variable's distribution
{"action_type": "investigate_pattern", "variable": "<field_name>"}
2. compute_distribution β analyze control-arm demographics
{"action_type": "compute_distribution", "variable": "<field_name>"}
3. flag_error β flag a specific patient with a protocol violation
{"action_type": "flag_error", "patient_id": "<id>", "error_type": "<type>", "reason": "<why>", "confidence": <0.0-1.0>}
Error types: invalid_age, temporal_inconsistency, protocol_window_violation, selection_bias
For selection_bias: omit patient_id.
4. submit_report β submit final audit report (once, at the end)
{"action_type": "submit_report", "report": "<text>"}
CRITICAL RULES:
- Read the PROTOCOL EXCERPT carefully. Age ranges and treatment windows CHANGE every episode.
- You MUST investigate all required variables BEFORE you can flag any errors.
- You MUST compute required distributions BEFORE flagging selection_bias.
- You MUST flag at least one error BEFORE submitting a report.
- False positive flags are penalized HEAVILY (3x the reward of a correct flag).
- Wrong flags with confidence > 0.8 get a 1.8x EXTRA penalty.
- Be precise with patient_id values β do not guess or hallucinate IDs.
- Do NOT flag valid patients. Read dates and numbers carefully.
- Stage IV patients may have an extended treatment window β check the protocol.
OUTPUT FORMAT β respond with ONLY valid JSON:
{"reasoning": "your analysis", "actions": [{"action_type": "...", ...}, ...]}
"""
def build_task_instructions(task_id: str) -> str:
    """Return the task-specific constraint text for the given task id.

    Any unrecognized ``task_id`` falls back to the Hard task instructions,
    mirroring the original if/elif/else behavior.
    """
    known_tasks = {
        "task_easy": (
            "TASK: Dynamic Eligibility Screening (Easy)\n"
            "REQUIRED INVESTIGATIONS: age\n"
            "ALLOWED ERROR TYPES: invalid_age ONLY\n"
            "DO NOT flag temporal_inconsistency, protocol_window_violation, or selection_bias.\n"
            "Focus: Compare each patient's age against the protocol's specific age range."
        ),
        "task_medium": (
            "TASK: Protocol Timeline Audit (Medium)\n"
            "REQUIRED INVESTIGATIONS: age, death_date, enrollment_date, stage\n"
            "ALLOWED ERROR TYPES: invalid_age, temporal_inconsistency, protocol_window_violation\n"
            "DO NOT flag selection_bias.\n"
            "Focus:\n"
            "- Check ages against protocol range\n"
            "- Check death_date must NEVER precede treatment_start\n"
            "- Check enrollment-to-treatment delay against the protocol window\n"
            "- Stage IV patients may have an extended window β read the protocol"
        ),
    }
    hard_default = (
        "TASK: Equity + Protocol Audit (Hard)\n"
        "REQUIRED INVESTIGATIONS: age, death_date, enrollment_date, stage, comorbidity_index\n"
        "REQUIRED DISTRIBUTIONS: ethnicity, gender, outcome (compute these before bias analysis)\n"
        "ALLOWED ERROR TYPES: invalid_age, temporal_inconsistency, protocol_window_violation, selection_bias\n"
        "Focus: All checks from Medium, PLUS:\n"
        "- CRITICAL: Read the protocol's comorbidity override β it may revoke the Stage IV exception\n"
        "- For selection_bias: compute distributions first, then use stage-adjusted mortality, not raw numbers\n"
        "- High-risk outreach sites may have elevated mortality that is NOT bias"
    )
    return known_tasks.get(task_id, hard_default)
def filter_dataset_for_llm(dataset: list[dict]) -> list[dict]:
    """Strip None values and empty lists from each record to save LLM tokens.

    Other falsy values (0, "", {}) are deliberately kept — only ``None`` and
    ``[]`` are dropped, exactly as before.
    """
    return [
        {key: val for key, val in record.items() if val is not None and val != []}
        for record in dataset
    ]
def parse_llm_actions(response_text: str) -> list[AuditAction]:
"""Parse LLM JSON response into AuditAction objects with fallback strategies."""
text = response_text.strip()
raw_actions = None
# Strategy 1: direct JSON parse
try:
data = json.loads(text)
if isinstance(data, dict) and "actions" in data:
raw_actions = data["actions"]
elif isinstance(data, list):
raw_actions = data
elif isinstance(data, dict) and "action_type" in data:
raw_actions = [data]
except json.JSONDecodeError:
pass
# Strategy 2: JSON inside markdown code fence
if raw_actions is None:
fence = re.search(r"```(?:json)?\s*([\s\S]*?)```", text)
if fence:
try:
data = json.loads(fence.group(1).strip())
if isinstance(data, dict) and "actions" in data:
raw_actions = data["actions"]
elif isinstance(data, list):
raw_actions = data
except json.JSONDecodeError:
pass
# Strategy 3: find largest JSON block
if raw_actions is None:
for match in re.finditer(r"\{[\s\S]*?\}", text):
try:
data = json.loads(match.group())
if isinstance(data, dict) and "actions" in data:
raw_actions = data["actions"]
break
except json.JSONDecodeError:
continue
if not raw_actions:
return []
# Convert to AuditAction objects
result = []
for raw in raw_actions:
if not isinstance(raw, dict) or "action_type" not in raw:
continue
try:
conf = raw.get("confidence")
if conf is not None:
conf = float(conf)
result.append(AuditAction(
action_type=str(raw["action_type"]),
patient_id=raw.get("patient_id"),
error_type=raw.get("error_type"),
reason=raw.get("reason"),
variable=raw.get("variable"),
report=raw.get("report"),
confidence=conf,
proposed_value=raw.get("proposed_value"),
))
except Exception:
continue
return result
def run_react_task(
    client: OpenAI,
    model_name: str,
    env_session,
    task_id: str,
    task_name: str,
    seed: int,
    log_step_fn=None,
    log_start_fn=None,
    log_end_fn=None,
) -> tuple[float, int, int, float, float]:
    """Genuine ReAct agent: LLM reads raw data and drives all decisions.
    Python only: (1) builds prompts, (2) parses JSON, (3) calls env.step().

    Args:
        client: OpenAI-compatible client used for chat completions.
        model_name: model identifier passed to the completions API.
        env_session: zero-arg context-manager factory yielding an env with
            ``reset(task_id, seed)`` and ``step(action)``.
        task_id: one of the task ids understood by build_task_instructions.
        task_name: human-readable name, used only for console output.
        seed: episode seed forwarded to ``env.reset``.
        log_step_fn: optional callback ``(step, action_type, reward, done)``.
        log_start_fn: optional callback ``(task_id)`` fired after reset.
        log_end_fn: optional callback ``(success, steps, score, rewards)``.

    Returns:
        (final_score, steps_taken, llm_calls, flag_precision, elapsed_seconds)

    NOTE(review): several string literals below contain "β" glyphs that look
    like mojibake from the original encoding (likely check/cross marks and
    dashes). In particular the `if`/`elif` membership tests on ``fb`` use the
    SAME "β" string, so the elif arm can only fire via "REJECTED" — confirm
    the intended distinct symbols before relying on true/false-positive counts.
    """
    from collections import Counter
    print(f"\n Task: {task_name}")
    print(" " + "-" * 54)
    # Episode-level accumulators.
    llm_calls = 0
    step = 0
    final_score = 0.0
    rewards: list[float] = []
    true_pos = 0       # flags the env confirmed ("Correct" feedback)
    false_pos = 0      # flags the env rejected ("REJECTED" feedback)
    total_flagged = 0  # all flag_error actions attempted
    task_start = time.time()
    with env_session() as env:
        # --- Episode setup: reset env and pull the initial observation. ---
        result = env.reset(task_id=task_id, seed=seed)
        obs = result.observation.model_dump()
        dataset = obs["dataset"]
        protocol_excerpt = obs["trial_protocol_excerpt"]
        max_steps = obs["attempts_remaining"]
        print(f" Model: {model_name}")
        print(f" Protocol: {obs.get('protocol_title', 'N/A')} | "
              f"Patients: {len(dataset)} | Max steps: {max_steps}")
        if log_start_fn:
            log_start_fn(task_id)
        # Prepare dataset (strip nulls to save tokens); compact JSON, and
        # default=str so dates/non-JSON types serialize instead of raising.
        filtered = filter_dataset_for_llm(dataset)
        dataset_json = json.dumps(filtered, separators=(",", ":"), default=str)
        token_est = len(dataset_json) // 4  # rough chars/4 token estimate
        print(f" Dataset context: ~{token_est:,} tokens")
        task_instructions = build_task_instructions(task_id)
        initial_prompt = (
            f"TRIAL PROTOCOL (read carefully β rules change every episode):\n"
            f"{protocol_excerpt}\n\n"
            f"{task_instructions}\n\n"
            f"PATIENT DATASET ({len(filtered)} records, {max_steps} steps available):\n"
            f"{dataset_json}\n\n"
            f"Begin your audit. Investigate the required variables first, "
            f"then flag errors, then submit your report. Respond with JSON only."
        )
        messages = [
            {"role": "system", "content": REACT_SYSTEM_PROMPT},
            {"role": "user", "content": initial_prompt},
        ]
        done = False
        max_turns = 6  # cap on LLM round-trips, separate from env step budget
        # --- Main ReAct loop: one LLM call per turn, then execute actions. ---
        for turn in range(1, max_turns + 1):
            if done or step >= max_steps:
                break
            print(f"\n ββ LLM Turn {turn} ββ")
            # Call LLM
            try:
                t0 = time.time()
                completion = client.chat.completions.create(
                    model=model_name,
                    messages=messages,
                    temperature=0.1,
                    max_tokens=4096,
                )
                elapsed = time.time() - t0
                llm_calls += 1
                llm_response = completion.choices[0].message.content or ""
                print(f" LLM responded in {elapsed:.1f}s ({len(llm_response)} chars)")
            except Exception as exc:
                # Any API failure aborts the turn loop; auto-submit below
                # may still salvage a report.
                print(f" [LLM ERROR] {exc}")
                break
            # Parse actions; on failure, ask the LLM to resend valid JSON.
            actions = parse_llm_actions(llm_response)
            if not actions:
                print(f" [PARSE ERROR] No valid actions parsed")
                preview = llm_response[:300].replace("\n", " ")
                print(f" LLM said: {preview}...")
                if turn < max_turns:
                    messages.append({"role": "assistant", "content": llm_response})
                    messages.append({"role": "user", "content":
                        "Your response was not valid JSON. Respond with ONLY: "
                        '{"reasoning": "...", "actions": [...]}'
                    })
                continue
            action_counts = Counter(a.action_type for a in actions)
            print(f" Parsed {len(actions)} actions: "
                  + ", ".join(f"{k}={v}" for k, v in action_counts.items()))
            # Execute each action against the env, collecting feedback lines.
            turn_feedback = []
            for action in actions:
                if done or step >= max_steps:
                    break
                result = env.step(action)
                obs = result.observation.model_dump()
                step_reward = float(result.reward or 0.0)
                done = result.done
                step += 1
                fb = obs.get("feedback", "")
                turn_feedback.append(f"Step {step}/{max_steps} [{action.action_type}]: {fb}")
                rewards.append(step_reward)
                if action.action_type == "flag_error":
                    total_flagged += 1
                    # NOTE(review): "β" here is presumed mojibake for a
                    # success mark; see function docstring.
                    if "β" in fb or "Correct" in fb:
                        true_pos += 1
                    elif "β" in fb or "REJECTED" in fb:
                        false_pos += 1
                if log_step_fn:
                    log_step_fn(step, action.action_type, step_reward, done)
                score = obs.get("score_so_far", 0.0)
                tag = "β" if "β" in fb else "β" if "β" in fb else "β"
                print(f" Step {step}: score={score:.2f} [{tag}] {fb[:100]}")
                final_score = score
                if done:
                    break
            # Build feedback for next turn
            feedback_text = "\n".join(turn_feedback)
            remaining = max_steps - step
            phase = obs.get("phase", "unknown")
            errs_found = len(obs.get("errors_found", []))
            messages.append({"role": "assistant", "content": llm_response})
            messages.append({"role": "user", "content": (
                f"RESULTS:\n{feedback_text}\n\n"
                f"STATUS: {remaining} steps left | Score: {final_score:.2f} | "
                f"Phase: {phase} | Errors found: {errs_found}\n\n"
                f"Continue. Respond with JSON actions."
            )})
            # Context management: after turn 2, condense the (large) dataset
            # message so later turns don't resend the full patient JSON.
            if turn >= 2 and len(messages[1]["content"]) > 10000:
                messages[1] = {"role": "user", "content": (
                    f"PROTOCOL:\n{protocol_excerpt}\n\n"
                    f"{task_instructions}\n\n"
                    f"[Dataset of {len(filtered)} patients was provided in turn 1]\n"
                    f"Continue your audit. {remaining} steps left."
                )}
        # If LLM never submitted report, auto-submit a minimal one so the
        # episode is scored rather than abandoned.
        if not done and step < max_steps:
            result = env.step(AuditAction(
                action_type="submit_report",
                report=(
                    "Audit report: protocol compliance review completed. "
                    "Recommend corrective action for flagged violations."
                ),
            ))
            obs = result.observation.model_dump()
            final_score = obs.get("score_so_far", 0.0)
            step += 1
            rewards.append(float(result.reward or 0.0))
            if log_step_fn:
                log_step_fn(step, "submit_report", rewards[-1], True)
            print(f" Step {step}: score={final_score:.2f} [β] auto-submitted report")
    # --- Episode summary metrics. ---
    task_elapsed = time.time() - task_start
    precision = true_pos / total_flagged if total_flagged else 0.0
    if log_end_fn:
        log_end_fn(final_score > 0.0, step, final_score, rewards)
    print(f" Metrics: {true_pos}/{total_flagged} correct "
          f"(precision={precision:.0%}) | {step} steps | {llm_calls} LLM calls | {task_elapsed:.1f}s")
    return final_score, step, llm_calls, precision, task_elapsed
|