Sumit Saraswat committed on
Commit a7bca03 · 1 Parent(s): 0dca8bf

Restructured Dockerfile and requirements to root for Hugging Face deployment

.DS_Store CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
 
server/Dockerfile → Dockerfile RENAMED
@@ -17,4 +17,4 @@ EXPOSE 8000
  HEALTHCHECK --interval=30s --timeout=10s --retries=3 \
    CMD curl -f http://localhost:8000/health || exit 1

- CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000"]
+ CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "8000"]
README.md CHANGED
@@ -1,88 +1,67 @@
  # Clinical Trial Auditor (OpenEnv)

- Production-grade OpenEnv benchmark for clinical trial data quality and bias auditing.
-
- The agent plays the role of a Senior Clinical Data Manager and must detect:
- - syntactic data quality errors (invalid/missing age),
- - temporal inconsistencies (death before treatment),
- - multi-dimensional selection bias in control cohorts.
-
- This environment is designed as a real benchmark system, not a static puzzle:
- - procedural generation on every `reset()`,
- - deterministic seed reproducibility,
- - adversarial traps that punish shallow heuristics,
- - deterministic programmatic graders with scores in `[0.0, 1.0]`.
-
- ---
-
- ## Why This Matters (Real-World Utility)
-
- Clinical trial audits are high-stakes workflows. Data defects and subgroup bias can:
- - invalidate endpoints,
- - distort treatment effect estimates,
- - create regulatory and patient-safety risk.
-
- This environment models a realistic multi-site Phase III oncology pipeline where agents must balance recall and precision under strict action budgets, with penalties for over-flagging.
-
- ---

  ## OpenEnv Compliance

  This project implements the required OpenEnv interface:
- - typed `Action`, `Observation`, `State` models (Pydantic),
  - `reset(seed, task_id, ...) -> Observation`,
  - `step(action) -> Observation`,
  - `state -> current state`,
- - `openenv.yaml` manifest at repo root.

  Validation:
  ```bash
  openenv validate .
  ```

- ---
-
- ## Environment Architecture
-
- `ClinicalTrialAuditorEnvironment` is intentionally layered:
-
- 1. **Data Engine** (`server/dataset_generator.py`)
-    - Procedural patient generation using statistical distributions.
-    - Difficulty-specific dataset size and error composition.
- 2. **Trap Engine**
-    - Boundary-valid traps (`18`, `120`, etc.),
-    - near-temporal valid traps (death 1-3 days after treatment),
-    - fake bias distractors.
- 3. **Scoring Engine**
-    - Deterministic ground-truth lookup for each flag.
-    - Partial progress rewards + false-positive penalties.
-    - Confidence calibration (overconfident wrong answers are punished harder).
- 4. **Agent Interface**
-    - Standard OpenEnv `step/reset/state`.
-
- ---
-
- ## Task Suite (Easy -> Medium -> Hard)
-
- ### Task 1: `task_easy` (Syntactic Cleaning)
- - Typical size: `300` patients
- - Objective: detect all `invalid_age` cases only
- - Includes valid edge-case age traps to punish naive thresholding
- - Bias grading disabled
-
- ### Task 2: `task_medium` (Temporal Consistency)
- - Typical size: `500` patients
- - Objective: detect both `invalid_age` and `temporal_inconsistency`
- - Includes near-boundary and near-temporal traps
- - Bias grading disabled
-
- ### Task 3: `task_hard` (Comprehensive Audit)
- - Typical size: `800` patients
- - Objective: detect `invalid_age` + `temporal_inconsistency` + `selection_bias`
- - Bias injected with representation + outcome + gender skew signals
- - Includes fake patterns to avoid shortcut behavior
-
- ---

  ## Action Space

@@ -91,7 +70,7 @@ class AuditAction(Action):
  action_type: str  # investigate_pattern | compute_distribution | flag_error | propose_fix | submit_report
  variable: Optional[str]
  patient_id: Optional[str]
- error_type: Optional[str]  # invalid_age | temporal_inconsistency | selection_bias
  reason: Optional[str]
  proposed_value: Optional[str]
  report: Optional[str]
@@ -107,130 +86,151 @@ class AuditObservation(Observation):
  task_id: str
  task_type: str
  task_description: str
  dataset: list[dict]
  errors_found: list[str]
  patterns_investigated: list[str]
  distributions_computed: list[str]
  feedback: str
  score_so_far: float
  attempts_remaining: int
  phase: str
  ```

- ---
-
- ## Reward Design (Meaningful Shaping)
-
- Reward is dense and trajectory-aware (not sparse binary).
-
- - correct flag: `+0.10`
- - false positive: `-0.30` (3x stronger than correct flag)
- - duplicate flag: `-0.10`
- - investigation/distribution bonuses and redundancy penalties
- - per-step cost to discourage long loops
- - workflow and efficiency bonuses
- - hard-task bias detection bonus: `+0.20`
- - difficulty multipliers by task
- - score clamped to `[0.0, 1.0]`
-
- This reward design explicitly creates precision pressure and separates robust agents from brute-force flaggers.
-
- ---
-
- ## Procedural Generation + Reproducibility
-
- Generator script:
  ```bash
- cd server
- python3 dataset_generator.py
  ```

  What it guarantees:
- - same seed -> identical dataset + identical ground truth,
- - different seeds -> different datasets,
- - controlled error injection rates,
- - deterministic grader compatibility.
-
- Example validated generation profile (seeded):
- - Easy: `300` patients, `12` injected errors, traps enabled
- - Medium: `500` patients, `37` injected errors, traps enabled
- - Hard: `800` patients, `49` injected errors + bias signal, traps enabled
-
- ---

  ## Baseline Inference (`inference.py`)

- `inference.py` supports multiple agent modes:
- - `naive`: raw LLM behavior,
- - `heuristic`: simple rules (no LLM),
- - `full`: statistical detector + planning + LLM report,
- - `all`: run all modes side-by-side.
-
- Run:
  ```bash
  python3 inference.py --mode all
  ```

- Reproducibility env vars:
- - `API_BASE_URL`
- - `MODEL_NAME`
- - `HF_TOKEN` or `OPENAI_API_KEY`
- - `ENV_BASE_URL` (defaults to `http://localhost:8000`)
-
- Current measured results (seeded local run):
- - **Heuristic** average: `0.98`
- - **Full** average: `1.00`
-
- Note: for judge-facing benchmarking, include a `--mode all` table from the same seed and model in this README before final submission.
-
- ---
-
- ## Local Run
-
- ### 1) Start server
  ```bash
  cd server
  PYTHONPATH=.. python3 -m uvicorn app:app --host 0.0.0.0 --port 8000
  ```

  ### 2) Health check
  ```bash
  curl -s http://localhost:8000/health
  ```

- ### 3) Run baseline
  ```bash
  cd ..
- python3 inference.py --mode full
  ```

- ---
-
  ## Docker

  Build and run:
  ```bash
  cd server
  docker build -t clinical-trial-auditor:latest .
  docker run -p 8000:8000 clinical-trial-auditor:latest
  ```

- Container includes healthcheck at `/health`.
-
- ---

  ## Hugging Face Space Readiness Checklist

- - [x] OpenEnv interface implemented (`step/reset/state`)
- - [x] typed models for actions/observations/state
  - [x] `openenv.yaml` present
- - [x] 3 tasks with deterministic graders and score range `[0.0, 1.0]`
- - [x] meaningful reward shaping across trajectory
- - [x] baseline script at project root: `inference.py`
- - [x] dockerized server (`server/Dockerfile`)
- - [x] `openenv validate .` passes locally
-
- ---

  ## Project Structure

@@ -250,8 +250,6 @@ clinical_trial_auditor/
  └── Dockerfile
  ```

- ---
-
  ## Motivation

- This benchmark is intended to evaluate whether an AI agent can do rigorous, workflow-constrained, clinically relevant data auditing under adversarial conditions, not just solve a fixed toy dataset.

  # Clinical Trial Auditor (OpenEnv)

+ Clinical Trial Auditor is a protocol-aware OpenEnv benchmark for clinical data auditing. The agent acts as a Senior Clinical Data Manager reviewing procedurally generated Phase III oncology trial data under dynamic per-episode rules.
+
+ This is not a static spreadsheet puzzle. Every `reset()` samples a new protocol excerpt and a new dataset, so the agent must read the rules for that episode and then audit the records accordingly.
+
+ ## Why This Matters
+
+ Real clinical audits are messy:
+ - eligibility criteria vary by protocol,
+ - timeline rules include exceptions,
+ - suspicious subgroup outcomes are not always evidence of bias,
+ - false positives waste reviewer time and can trigger unnecessary escalations.
+
+ This environment is built to evaluate exactly those failure modes. It targets the gap between "can parse a table" and "can follow a high-stakes auditing workflow with protocol friction and adversarial traps."
+
+ ## What Makes This Benchmark Different
+
+ - Dynamic protocol reasoning: each episode exposes a new `trial_protocol_excerpt` with episode-specific age ranges and treatment-start windows.
+ - Cross-modal audit logic: the agent must apply text rules from the protocol to tabular patient data.
+ - Stage-aware timing exceptions: Stage IV patients can have a longer enrollment-to-treatment window, which creates valid edge cases that trap shortcut heuristics.
+ - Hallucination traps: hard episodes can contain a confounded high-risk cohort that looks biased overall but is not actionable after stage-adjusted review.
+ - Dense reward plus benchmark rubric: step rewards encourage learning, while `score_so_far` tracks a judge-facing episode rubric emphasizing recall, precision, workflow discipline, efficiency, and report quality.

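The stage-aware timing exception can be sketched as a small validity check. This is a minimal illustration of the idea, not the environment's actual grader; the function name and the `base_window_days`/`stage_iv_window_days` defaults are hypothetical stand-ins for the episode-specific protocol values:

```python
def within_protocol_window(stage: str, days_to_treatment: int,
                           base_window_days: int = 14,
                           stage_iv_window_days: int = 28) -> bool:
    """Return True if the enrollment-to-treatment delay is protocol-valid.

    Stage IV patients are allowed a longer window, so a delay that is a
    violation for a Stage II patient can be a valid edge case for Stage IV.
    """
    limit = stage_iv_window_days if stage == "IV" else base_window_days
    return 0 <= days_to_treatment <= limit


# A naive fixed-threshold heuristic flags any delay over the base window;
# the protocol-aware check accepts the Stage IV case.
print(within_protocol_window("IV", 21))  # True: inside the Stage IV window
print(within_protocol_window("II", 21))  # False: outside the base window
```

This is exactly the kind of rule a shortcut heuristic misses: the same 21-day delay is valid or invalid depending on a field elsewhere in the record.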
  ## OpenEnv Compliance

  This project implements the required OpenEnv interface:
+ - typed `Action`, `Observation`, and `State` models with Pydantic,
  - `reset(seed, task_id, ...) -> Observation`,
  - `step(action) -> Observation`,
  - `state -> current state`,
+ - `openenv.yaml` at the repo root.

  Validation:
+
  ```bash
  openenv validate .
  ```

+ Local validation result:
+
+ ```text
+ [OK] : Ready for multi-mode deployment
+ ```

+ ## Task Suite

+ ### Task 1: `task_easy` (Dynamic Eligibility Screening)
+ - Dataset size: about `300` patients
+ - Goal: flag `invalid_age`
+ - Difficulty source: the age bounds are episode-specific, not fixed at 18-120
+ - Traps: valid edge ages at the protocol boundary

+ ### Task 2: `task_medium` (Protocol Timeline Audit)
+ - Dataset size: about `480` patients
+ - Goal: flag `invalid_age`, `temporal_inconsistency`, and `protocol_window_violation`
+ - Difficulty source: the treatment-start window is protocol-specific and Stage IV has a longer valid window
+ - Traps: valid near-boundary start delays and near-immediate but valid deaths

+ ### Task 3: `task_hard` (Equity + Protocol Audit)
+ - Dataset size: about `720` patients
+ - Goal: flag record-level issues and determine whether actionable `selection_bias` exists
+ - Difficulty source: some hard episodes contain real control-arm bias, while others contain a confounded high-risk cohort that only looks biased before stage adjustment
+ - Traps: treatment-arm skew, high-risk outreach sites, and false-positive bias patterns

  ## Action Space

  action_type: str  # investigate_pattern | compute_distribution | flag_error | propose_fix | submit_report
  variable: Optional[str]
  patient_id: Optional[str]
+ error_type: Optional[str]  # invalid_age | temporal_inconsistency | protocol_window_violation | selection_bias
  reason: Optional[str]
  proposed_value: Optional[str]
  report: Optional[str]

  task_id: str
  task_type: str
  task_description: str
+ protocol_title: str
+ trial_protocol_excerpt: str
  dataset: list[dict]
  errors_found: list[str]
  patterns_investigated: list[str]
  distributions_computed: list[str]
  feedback: str
  score_so_far: float
+ dense_reward_total: float
+ score_breakdown: dict[str, float]
  attempts_remaining: int
  phase: str
  ```
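For readers skimming the diff, the action shape above can be mocked up as a plain dataclass. The real models in `models.py` are Pydantic classes; this standalone sketch only mirrors the fields shown in this README and the example values are made up:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class AuditAction:
    """Sketch of the action fields shown above (not the real Pydantic model)."""
    action_type: str                  # investigate_pattern | compute_distribution | flag_error | propose_fix | submit_report
    variable: Optional[str] = None
    patient_id: Optional[str] = None
    error_type: Optional[str] = None  # invalid_age | temporal_inconsistency | protocol_window_violation | selection_bias
    reason: Optional[str] = None
    proposed_value: Optional[str] = None
    report: Optional[str] = None


# Flag a timeline violation for one (hypothetical) patient.
action = AuditAction(
    action_type="flag_error",
    patient_id="P-0042",
    error_type="protocol_window_violation",
    reason="Treatment started 45 days after enrollment; protocol allows 14.",
)
print(action.action_type)  # flag_error
```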

+ ## Reward Design and Benchmark Score

+ The environment uses two scoring layers:

+ - Dense step reward:
+   - correct flags,
+   - false-positive penalties,
+   - duplicate penalties,
+   - investigation/distribution bonuses,
+   - confidence penalties for overconfident wrong flags,
+   - per-step costs.
+ - Episode benchmark score (`score_so_far`):
+   - recall: `70%`
+   - precision: `15%`
+   - workflow discipline: `5%`
+   - efficiency: `5%`
+   - report quality: `5%`

+ This separation keeps the RL signal dense while preventing early score saturation from hiding later mistakes.

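With those weights, the episode rubric reduces to a weighted sum. A minimal sketch, using the weights listed above; the component values in the example are invented, and the actual grader may normalize differently:

```python
# Rubric weights from the README: recall 70%, precision 15%,
# workflow discipline 5%, efficiency 5%, report quality 5%.
WEIGHTS = {
    "recall": 0.70,
    "precision": 0.15,
    "workflow": 0.05,
    "efficiency": 0.05,
    "report_quality": 0.05,
}


def episode_score(components: dict[str, float]) -> float:
    """Weighted rubric score, clamped to [0.0, 1.0]."""
    raw = sum(WEIGHTS[k] * components.get(k, 0.0) for k in WEIGHTS)
    return max(0.0, min(1.0, round(raw, 4)))


# Perfect recall with sloppy precision still scores high, which is
# why the rubric is recall-dominated.
print(episode_score({"recall": 1.0, "precision": 0.4, "workflow": 1.0,
                     "efficiency": 1.0, "report_quality": 1.0}))  # 0.91
```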
+ ## Procedural Generation and Reproducibility

+ Run the generator self-test:

  ```bash
+ python3 server/dataset_generator.py
  ```

  What it guarantees:
+ - same seed -> same dataset, same protocol excerpt, same ground truth,
+ - different seeds -> different protocols and different datasets,
+ - deterministic grading compatibility,
+ - hard mode can alternate between `true_bias` and `confounded_no_bias`.

+ Example validated seeded profile:
+ - Easy: `300` patients, `8` record-level errors, `13` traps
+ - Medium: `480` patients, `23` record-level errors, `25` traps
+ - Hard: `720` patients, `34` total issues including protocol/timing/bias logic, `40` traps

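The same-seed guarantee is the standard seeded-PRNG pattern: all randomness flows from one generator constructed from the seed. A toy illustration of the property (this is not the repo's generator, just the mechanism it relies on):

```python
import random


def toy_dataset(seed: int, n: int = 5) -> list[dict]:
    """Tiny stand-in for a procedural generator: the output is a pure
    function of the seed because every draw comes from one seeded PRNG."""
    rng = random.Random(seed)
    return [{"patient_id": f"P-{i:04d}", "age": rng.randint(18, 90)}
            for i in range(n)]


# Same seed -> identical dataset; different seeds -> different datasets.
assert toy_dataset(42) == toy_dataset(42)
assert toy_dataset(42) != toy_dataset(43)
print("determinism holds")
```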
  ## Baseline Inference (`inference.py`)

+ `inference.py` now demonstrates a clean difficulty gradient:
+
+ - `naive`: raw sample-level behavior
+ - `heuristic`: rule-based but trap-prone
+ - `full`: protocol parser + stage-aware detectors + structured reporting
+ - `all`: side-by-side comparison
+
+ HTTP mode:

  ```bash
  python3 inference.py --mode all
  ```

+ Isolated local validation mode with no socket bind:
+
+ ```bash
+ ENV_BASE_URL=inprocess python3 inference.py --mode all
+ ```

+ LLM integration:
+ - When `OPENAI_API_KEY` or `HF_TOKEN` is present, naive mode and report generation use the OpenAI-compatible client pointed at `API_BASE_URL`.
+ - Without a key, the script falls back to deterministic local behavior so validation still runs end-to-end.

+ Current reproducible local benchmark result:
+
+ ```bash
+ ENV_BASE_URL=inprocess python3 inference.py --mode all --seed 20260402
+ ```
+
+ Scores:
+
+ | Agent | Easy | Medium | Hard | Average |
+ |---|---:|---:|---:|---:|
+ | Naive | 0.36 | 0.08 | 0.09 | 0.18 |
+ | Heuristic | 0.81 | 0.56 | 0.45 | 0.60 |
+ | Full | 0.98 | 0.99 | 0.99 | 0.99 |
+
+ This is the intended story:
+ - naive agents underperform badly,
+ - shallow heuristics get trapped by dynamic protocol edges and confounded bias signals,
+ - protocol-aware agents perform strongly.
+
+ ## Local Usage
+
+ ### 1) Start the server

  ```bash
  cd server
  PYTHONPATH=.. python3 -m uvicorn app:app --host 0.0.0.0 --port 8000
  ```

  ### 2) Health check
+
  ```bash
  curl -s http://localhost:8000/health
  ```

+ ### 3) Run the baseline
+
  ```bash
  cd ..
+ python3 inference.py --mode all
  ```

  ## Docker

  Build and run:
+
  ```bash
  cd server
  docker build -t clinical-trial-auditor:latest .
  docker run -p 8000:8000 clinical-trial-auditor:latest
  ```

+ The container exposes `/health` for health checks and is ready for Hugging Face Spaces container deployment.

  ## Hugging Face Space Readiness Checklist

+ - [x] OpenEnv interface implemented
+ - [x] typed models for action/observation/state
  - [x] `openenv.yaml` present
+ - [x] 3 tasks with deterministic graders and scores in `[0.0, 1.0]`
+ - [x] dense reward shaping and benchmark rubric
+ - [x] reproducible `inference.py` at repo root
+ - [x] dockerized server
+ - [x] `openenv validate .` passes

  ## Project Structure

  └── Dockerfile
  ```

  ## Motivation

+ This benchmark is built to test whether an agent can read a changing clinical protocol, audit patient records against that protocol, avoid hallucinated escalations, and write a grounded operational report under a limited action budget.
inference.py CHANGED
@@ -1,458 +1,429 @@
1
  """
2
- Clinical Trial Auditor — Multi-Agent Baseline Inference
3
- =========================================================
4
- Three agent modes to demonstrate environment difficulty gradient:
5
-
6
- 1. NAIVE Raw LLM prompt, no statistical tools expected ~0.25-0.40
7
- 2. HEURISTIC — Simple rule-based agent expected ~0.45-0.60
8
- 3. FULL Statistical Detection Engine + LLM Reasoning → expected ~0.85-0.95
9
-
10
- Usage:
11
- python inference.py # Full agent (default)
12
- python inference.py --mode naive # Naive LLM-only agent
13
- python inference.py --mode heuristic # Simple heuristic agent
14
- python inference.py --mode full # Full agentic pipeline
15
- python inference.py --mode all # Run all three, side-by-side
16
-
17
- Pipeline (full mode):
18
- 1. PROFILE → Schema-aware statistical analysis of dataset
19
- 2. DETECT → Multi-detector anomaly pipeline with confidence scoring
20
- 3. ASSESS → Risk severity + clinical impact evaluation
21
- 4. PLAN → Task-adaptive optimal action sequence
22
- 5. REASON → LLM for ambiguous cases + expert report generation
23
- 6. EXECUTE → Deterministic environment interaction
24
- 7. EVALUATE → Precision/recall/F1 metrics tracking
25
  """
 
 
 
 
 
26
  import os
 
 
27
  import sys
28
  import time
29
- import json
30
- import math
31
- import argparse
32
- import statistics
33
- from datetime import datetime
34
  from collections import Counter
 
 
 
35
  from typing import Optional
36
 
 
 
37
  sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
38
 
39
- from openai import OpenAI
40
  from client import ClinicalTrialAuditorEnv
41
  from models import AuditAction
42
 
43
- # ── Configuration ─────────────────────────────────────────────────────────
 
 
 
 
 
44
  API_BASE_URL = os.getenv("API_BASE_URL", "https://api.groq.com/openai/v1")
45
  API_KEY = os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY") or os.getenv("API_KEY")
46
  MODEL_NAME = os.getenv("MODEL_NAME", "llama-3.3-70b-versatile")
47
  ENV_BASE_URL = os.getenv("ENV_BASE_URL", "http://localhost:8000")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
- # Reproducible seed for baseline evaluation
50
- BASELINE_SEED = 20240401
 
 
 
 
 
 
 
 
 
51
 
 
 
52
 
53
- # ═══════════════════════════════════════════════════════════════════════════
54
- # CORE DATA STRUCTURES
55
- # ═══════════════════════════════════════════════════════════════════════════
56
 
 
57
  class Finding:
58
- """A detected anomaly with confidence, risk severity, and explanation."""
59
- def __init__(self, patient_id: str, error_type: str, reason: str,
60
- confidence: float, risk: str = "medium",
61
- value=None, statistical_context: str = ""):
62
- self.patient_id = patient_id
63
- self.error_type = error_type
64
- self.reason = reason
65
- self.confidence = min(1.0, max(0.0, confidence))
66
- self.risk = risk
67
- self.value = value
68
- self.statistical_context = statistical_context
69
 
70
  @property
71
  def priority_score(self) -> float:
72
- risk_weights = {"critical": 1.0, "high": 0.8, "medium": 0.5, "low": 0.2}
73
- return self.confidence * risk_weights.get(self.risk, 0.5)
74
-
75
- def explain(self) -> str:
76
- parts = [f"{self.error_type}: {self.reason}"]
77
- if self.statistical_context:
78
- parts.append(f" Evidence: {self.statistical_context}")
79
- parts.append(f" Confidence: {self.confidence:.0%} | Risk: {self.risk.upper()}")
80
- return "\n".join(parts)
81
-
82
-
83
- # ═══════════════════════════════════════════════════════════════════════════
84
- # MODULE 1: DATA PROFILER — Robust statistical summary
85
- # ═══════════════════════════════════════════════════════════════════════════
86
-
87
- class DataProfiler:
88
- """Schema-aware statistical profiler using robust estimators (IQR, MAD)."""
89
-
90
- def __init__(self, dataset: list[dict]):
91
- self.dataset = dataset
92
- self.n = len(dataset)
93
- self.columns = sorted({k for row in dataset for k in row.keys()})
94
- self.types = self._infer_types()
95
- self.profiles = {}
96
-
97
- def _infer_types(self) -> dict:
98
- types = {}
99
- for col in self.columns:
100
- vals = [r.get(col) for r in self.dataset if r.get(col) is not None]
101
- if not vals:
102
- types[col] = "unknown"
103
- elif all(isinstance(v, (int, float)) for v in vals):
104
- types[col] = "numeric"
105
- elif all(isinstance(v, str) and self._is_date(v) for v in vals[:5]):
106
- types[col] = "date"
107
- elif col.lower().endswith("_id") or col.lower() == "id":
108
- types[col] = "id"
109
- else:
110
- types[col] = "categorical"
111
- return types
112
 
 
 
113
  @staticmethod
114
- def _is_date(s: str) -> bool:
115
- for fmt in ("%Y-%m-%d", "%m/%d/%Y", "%d-%m-%Y"):
116
- try:
117
- datetime.strptime(s, fmt)
118
- return True
119
- except ValueError:
120
- pass
121
- return False
 
 
 
 
 
 
122
 
123
- def profile_numeric(self, col: str) -> dict:
124
- values = [r[col] for r in self.dataset if r.get(col) is not None]
125
- null_count = sum(1 for r in self.dataset if r.get(col) is None)
126
- if not values:
127
- return {"null_count": null_count, "valid_count": 0}
128
-
129
- sorted_vals = sorted(values)
130
- n = len(sorted_vals)
131
- median = statistics.median(sorted_vals)
132
- mean = statistics.mean(sorted_vals)
133
- std = statistics.stdev(sorted_vals) if n > 1 else 0
134
-
135
- q1 = sorted_vals[n // 4] if n >= 4 else sorted_vals[0]
136
- q3 = sorted_vals[3 * n // 4] if n >= 4 else sorted_vals[-1]
137
- iqr = q3 - q1
138
-
139
- mad = statistics.median([abs(v - median) for v in sorted_vals])
140
- mad_scaled = mad * 1.4826
141
-
142
- return {
143
- "mean": round(mean, 2), "std": round(std, 2),
144
- "median": round(median, 2), "mad": round(mad_scaled, 2),
145
- "min": min(values), "max": max(values),
146
- "q1": q1, "q3": q3, "iqr": iqr,
147
- "null_count": null_count, "valid_count": n,
148
- "iqr_lower": q1 - 1.5 * iqr,
149
- "iqr_upper": q3 + 1.5 * iqr,
150
- }
151
-
152
- def profile_categorical(self, col: str) -> dict:
153
- vals = [str(r.get(col, "None")) for r in self.dataset]
154
- counter = Counter(vals)
155
- total = len(vals)
156
- return {
157
- "distribution": dict(counter),
158
- "unique_count": len(counter),
159
- "mode": counter.most_common(1)[0][0] if counter else None,
160
- "mode_ratio": counter.most_common(1)[0][1] / total if counter else 0,
161
- }
162
-
163
- def profile_all(self) -> dict:
164
- for col in self.columns:
165
- if self.types.get(col) == "numeric":
166
- self.profiles[col] = self.profile_numeric(col)
167
- elif self.types.get(col) == "categorical":
168
- self.profiles[col] = self.profile_categorical(col)
169
- return self.profiles
170
-
171
-
172
- # ═══════════════════════════════════════════════════════════════════════════
173
- # MODULE 2: ANOMALY DETECTORS — Confidence + Risk scoring
174
- # ═══════════════════════════════════════════════════════════════════════════
175
-
176
- class AgeAnomalyDetector:
177
- """
178
- Multi-layer age anomaly detection:
179
- - Layer 1: Clinical domain constraints (18-120 for trial eligibility)
180
- - Layer 2: Statistical outliers via IQR
181
- - Layer 3: Biological plausibility
182
- """
183
- CLINICAL_MIN, CLINICAL_MAX = 18, 120
184
-
185
- def detect(self, dataset: list[dict], profile: dict) -> list[Finding]:
186
- findings = []
187
- age_prof = profile.get("age", {})
188
- median = age_prof.get("median", 60)
189
- mad = age_prof.get("mad", 15)
190
 
191
- for row in dataset:
192
- pid = row.get("patient_id", "?")
193
- age = row.get("age")
194
 
195
- if age is None:
196
- findings.append(Finding(
197
- patient_id=pid, error_type="invalid_age",
198
- reason="Missing age required for trial eligibility",
199
- confidence=1.0, risk="high", value=None,
200
- statistical_context="Null value in mandatory field",
201
- ))
202
- continue
 
203
 
204
- is_domain_violation = age < self.CLINICAL_MIN or age > self.CLINICAL_MAX
205
-
206
- if is_domain_violation:
207
- deviation = abs(age - median) / mad if mad > 0 else 0
208
- is_biological_impossible = age < 0 or age > 122
209
- if is_biological_impossible:
210
- conf, risk = 1.0, "critical"
211
- context = f"Biologically impossible (age={age})"
212
- elif age > 200:
213
- conf, risk = 0.99, "critical"
214
- context = f"Likely sentinel/data entry error: {deviation:.1f} MAD from median"
215
- else:
216
- conf, risk = 0.95, "high"
217
- context = f"Outside range [{self.CLINICAL_MIN}-{self.CLINICAL_MAX}]"
218
-
219
- findings.append(Finding(
220
- patient_id=pid, error_type="invalid_age",
221
- reason=f"Age {age} violates clinical trial range [{self.CLINICAL_MIN}-{self.CLINICAL_MAX}]",
222
- confidence=conf, risk=risk, value=age,
223
- statistical_context=context,
224
- ))
225
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
  return findings
227
 
228
 
229
- class TemporalConsistencyDetector:
230
- """Detects death_date before treatment_start violations."""
231
-
232
- @staticmethod
233
- def _parse_date(val) -> Optional[datetime]:
234
- if not val or val in ("", "N/A", "None", "null"):
235
- return None
236
- for fmt in ("%Y-%m-%d", "%m/%d/%Y", "%d-%m-%Y", "%Y/%m/%d"):
237
- try:
238
- return datetime.strptime(str(val), fmt)
239
- except (ValueError, TypeError):
240
- pass
241
- return None
242
-
243
- def detect(self, dataset: list[dict], profile: dict) -> list[Finding]:
244
  findings = []
245
  for row in dataset:
246
- pid = row.get("patient_id", "?")
247
- early = self._parse_date(row.get("treatment_start"))
248
- late = self._parse_date(row.get("death_date"))
249
- if early and late and late < early:
250
- gap = (early - late).days
251
- risk = "critical" if gap > 180 else "high" if gap > 30 else "medium"
252
- conf = min(1.0, 0.90 + gap / 3650)
253
- findings.append(Finding(
254
- patient_id=pid, error_type="temporal_inconsistency",
255
- reason=f"death_date {row.get('death_date')} is {gap} days before treatment_start {row.get('treatment_start')}",
256
- confidence=conf, risk=risk,
257
- value=f"{gap}-day violation",
258
- statistical_context=f"Chronological ordering violated by {gap} days",
259
- ))
260
  return findings
261
 
262
 
263
- class SelectionBiasDetector:
264
- """Multi-dimensional bias detection in control group."""
265
- REPRESENTATION_THRESHOLD = 0.65
266
- OUTCOME_DISPARITY_THRESHOLD = 0.20
267
-
268
- def detect(self, dataset: list[dict], profile: dict) -> list[Finding]:
269
  findings = []
270
- control = [r for r in dataset if r.get("group") == "control"]
271
- if not control:
272
- return findings
273
-
274
- total_control = len(control)
275
- eth_counts = Counter(r.get("ethnicity", "Unknown") for r in control)
276
- dominant = eth_counts.most_common(1)[0] if eth_counts else None
277
- if not dominant:
278
- return findings
279
-
280
- dominant_name, dominant_count = dominant
281
- representation_ratio = dominant_count / total_control
282
-
283
- outcome_rates = {}
284
- for eth, count in eth_counts.items():
285
- deceased = sum(1 for r in control if r.get("ethnicity") == eth and r.get("outcome") == "deceased")
286
- outcome_rates[eth] = deceased / count if count > 0 else 0
 
287
 
288
- rates = list(outcome_rates.values())
289
- max_disparity = max(rates) - min(rates) if len(rates) > 1 else 0
290
 
291
- minority_deceased = sum(
292
- 1 for r in control
293
- if r.get("ethnicity") != dominant_name and r.get("outcome") == "deceased"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
294
  )
295
- minority_total = total_control - dominant_count
296
- minority_mortality = minority_deceased / minority_total if minority_total > 0 else 0
297
-
298
- male_control = sum(1 for r in control if r.get("gender") == "M")
299
- male_ratio = male_control / total_control
300
-
301
- evidence = []
302
- confidence = 0.0
303
-
304
- if representation_ratio >= self.REPRESENTATION_THRESHOLD:
305
- evidence.append(f"Representation: {dominant_name}={representation_ratio:.0%} of control")
306
- confidence += 0.4
307
- if minority_deceased > 0:
308
- evidence.append(f"Outcome disparity: minority mortality={minority_mortality:.0%}")
309
- confidence += 0.2
310
- if male_ratio >= 0.5:
311
- evidence.append(f"Gender imbalance: male={male_ratio:.0%}")
312
- confidence += 0.1
313
- if max_disparity > self.OUTCOME_DISPARITY_THRESHOLD:
314
- evidence.append(f"Statistically significant disparity: Δ={max_disparity:.0%}")
315
- confidence += 0.15
316
-
317
- confidence = min(1.0, confidence)
318
-
319
- if confidence >= 0.6 and representation_ratio >= self.REPRESENTATION_THRESHOLD:
320
- findings.append(Finding(
321
- patient_id=None, error_type="selection_bias",
322
- reason="Multi-dimensional selection bias: " + "; ".join(evidence),
323
- confidence=confidence, risk="critical",
324
- value=f"{dominant_name}={representation_ratio:.0%}",
325
- statistical_context=f"Representation: {dominant_name}={representation_ratio:.0%} | Disparity: Δ={max_disparity:.0%} | Minority mortality: {minority_mortality:.0%}",
326
- ))
327
-
328
- return findings
329
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
-# ═══════════════════════════════════════════════════════════════════════════
-# MODULE 3: ACTION PLANNER
-# ═══════════════════════════════════════════════════════════════════════════
 
 class ActionPlanner:
-    """Plans optimal action sequence adapted to task type and step budget."""
-
-    def plan(self, findings: list[Finding], task_type: str,
-             max_steps: int = 20) -> list[AuditAction]:
-        actions = []
-
-        # Phase 1: Investigation (3 steps)
-        investigate = ["age", "death_date", "ethnicity"]
-        for var in investigate:
-            actions.append(AuditAction(action_type="investigate_pattern", variable=var))
-
-        # Phase 2: Flag findings by priority
-        data_findings = [f for f in findings if f.error_type != "selection_bias"]
-        bias_findings = [f for f in findings if f.error_type == "selection_bias"]
-
-        data_findings.sort(key=lambda f: -f.priority_score)
-
-        bias_slots = 1 if (bias_findings and task_type == "comprehensive_audit") else 0
-        max_data_flags = max_steps - len(investigate) - 1 - bias_slots
-
-        flagged = set()
-        for f in data_findings[:max_data_flags]:
-            if f.patient_id in flagged:
-                continue
-            flagged.add(f.patient_id)
-            actions.append(AuditAction(
-                action_type="flag_error",
-                patient_id=f.patient_id,
-                error_type=f.error_type,
-                reason=f.reason,
-            ))
-
-        if bias_findings and task_type == "comprehensive_audit":
-            actions.append(AuditAction(
-                action_type="flag_error",
-                error_type="selection_bias",
-                reason=bias_findings[0].reason,
-            ))
-
-        return actions
 
 
-# ═══════════════════════════════════════════════════════════════════════════
-# MODULE 4: LLM REASONING LAYER
-# ═══════════════════════════════════════════════════════════════════════════
-
-def generate_expert_report(client, findings: list[Finding],
-                           profiles: dict, task_type: str) -> str:
-    """LLM generates expert audit report from pre-analyzed findings."""
-    age_f = [f for f in findings if f.error_type == "invalid_age"]
-    temp_f = [f for f in findings if f.error_type == "temporal_inconsistency"]
-    bias_f = [f for f in findings if f.error_type == "selection_bias"]
-    age_p = profiles.get("age", {})
-
-    sections = [
-        f"AUDIT ANALYSIS — Task: {task_type}",
-        f"Dataset: {age_p.get('valid_count', 0) + age_p.get('null_count', 0)} patients",
-        f"Age: median={age_p.get('median', '?')}, range=[{age_p.get('min', '?')}, {age_p.get('max', '?')}]",
-        "", "ISSUES:",
-    ]
-
-    if age_f:
-        sections.append(f"• {len(age_f)} age anomalies")
-        for f in age_f[:3]:
-            sections.append(f"  - {f.patient_id}: age={f.value}")
-    if temp_f:
-        sections.append(f"• {len(temp_f)} temporal violations")
-        for f in temp_f[:3]:
-            sections.append(f"  - {f.patient_id}: {f.value}")
-    if bias_f:
-        sections.append(f"• Selection bias: {bias_f[0].statistical_context}")
-
-    try:
-        completion = client.chat.completions.create(
-            model=MODEL_NAME,
-            messages=[
-                {
-                    "role": "system",
-                    "content": (
-                        "You are a Senior Clinical Data Manager writing a formal audit report. "
-                        "Provide: 1) SUMMARY with severity, 2) ROOT CAUSE analysis, "
-                        "3) RISK ASSESSMENT (impact on trial validity), "
-                        "4) RECOMMENDED corrective actions, "
-                        "5) REGULATORY compliance impact. "
-                        "Be concise (max 150 words). Use professional clinical language."
-                    ),
-                },
-                {"role": "user", "content": "\n".join(sections)},
-            ],
-            max_tokens=250,
-            temperature=0,
-        )
-        report = completion.choices[0].message.content or ""
-        if "recommend" not in report.lower():
-            report += "\nRecommend immediate corrective action for all identified issues."
-        return report
-    except Exception as e:
-        # Deterministic fallback
-        severity = "CRITICAL" if bias_f else "HIGH" if temp_f else "MEDIUM"
-        parts = [
-            f"CLINICAL DATA AUDIT REPORT — {task_type.replace('_', ' ').title()}",
-            f"\nSUMMARY: {len(findings)} data quality issues identified.",
-        ]
-        if age_f:
-            parts.append(f"\nAGE ANOMALIES ({len(age_f)}): Root cause: data entry errors or ETL pipeline failures.")
-        if temp_f:
-            parts.append(f"\nTEMPORAL VIOLATIONS ({len(temp_f)}): Root cause: date field mapping errors.")
-        if bias_f:
-            parts.append(f"\nSELECTION BIAS: {bias_f[0].statistical_context}.")
-        parts.append(f"\nRISK LEVEL: {severity}. Recommend immediate corrective action: "
-                     "quarantine affected records, audit data entry workflows, implement validation "
-                     "rules, and rebalance demographic representation in control group. "
-                     "This impacts regulatory compliance with FDA 21 CFR Part 11 and ICH-GCP guidelines.")
-        return "\n".join(parts)
-
-
-# ═══════════════════════════════════════════════════════════════════════════
-# MODULE 5: METRICS TRACKER
-# ═══════════════════════════════════════════════════════════════════════════
-
 class MetricsTracker:
     def __init__(self):
         self.true_pos = 0
@@ -461,123 +432,176 @@ class MetricsTracker:
         self.steps = 0
         self.llm_calls = 0
 
-    def record(self, feedback: str):
         self.total_flagged += 1
-        if "Correct" in feedback or "✓" in feedback:
         self.true_pos += 1
-        elif "False positive" in feedback or "REJECTED" in feedback or "✗" in feedback:
         self.false_pos += 1
 
     @property
     def precision(self) -> float:
-        return self.true_pos / self.total_flagged if self.total_flagged else 0
 
     def summary(self) -> str:
         return (
-            f"  📊 Metrics: {self.true_pos}/{self.total_flagged} correct "
-            f"(precision={self.precision:.0%}) | "
-            f"{self.steps} steps | {self.llm_calls} LLM call(s)"
         )
 
 
-# ═══════════════════════════════════════════════════════════════════════════
-# AGENT MODE 1: NAIVE LLM (raw prompt, no statistical tools)
-# ═══════════════════════════════════════════════════════════════════════════
 
 
-def run_naive_task(client, task_id: str, task_name: str):
-    """
-    Naive agent: sends raw data to LLM, asks it to find errors.
-    No statistical analysis, no planning. Expected score: ~0.25-0.40
-    """
-    print(f"\n  Task: {task_name}")
-    print("  " + "-" * 50)
 
     metrics = MetricsTracker()
     final_score = 0.0
 
-    with ClinicalTrialAuditorEnv(base_url=ENV_BASE_URL).sync() as env:
-        result = env.reset(task_id=task_id, seed=BASELINE_SEED)
         obs = result.observation.model_dump()
         dataset = obs["dataset"]
-        task_type = obs["task_type"]
         max_steps = obs["attempts_remaining"]
-        print(f"  Patients: {len(dataset)} | Max steps: {max_steps}")
 
-        # Send first 30 patients to LLM (token limit)
-        sample = dataset[:30]
-        sample_str = json.dumps(sample, indent=1, default=str)
 
-        try:
-            completion = client.chat.completions.create(
-                model=MODEL_NAME,
-                messages=[
-                    {
-                        "role": "system",
-                        "content": "You are a clinical data auditor. Find errors in patient data."
-                    },
-                    {
-                        "role": "user",
-                        "content": (
-                            f"Here are {len(sample)} patient records from a clinical trial. "
-                            f"Find ALL data quality issues.\n"
-                            f"For each issue, respond with ONE line: PATIENT_ID|ERROR_TYPE|REASON\n"
-                            f"ERROR_TYPE must be: invalid_age OR temporal_inconsistency\n"
-                            f"Valid age range: 18-120. Death date must not precede treatment start.\n\n"
-                            f"{sample_str}"
-                        ),
-                    },
-                ],
-                max_tokens=500,
-                temperature=0,
             )
-            llm_response = completion.choices[0].message.content or ""
-            metrics.llm_calls += 1
-        except Exception as e:
-            print(f"  LLM Error: {e}")
-            llm_response = ""
 
-        # Investigate (required phase gate)
-        for var in ["age", "death_date", "ethnicity"]:
-            if result.done:
-                break
-            result = env.step(AuditAction(action_type="investigate_pattern", variable=var))
-            metrics.steps += 1
 
-        # Parse LLM response and flag
-        lines = llm_response.strip().split("\n")
-        for line in lines:
            if result.done:
                break
-            parts = line.strip().split("|")
-            if len(parts) >= 2:
-                pid = parts[0].strip()
-                etype = parts[1].strip().lower().replace(" ", "_")
-                if etype not in ("invalid_age", "temporal_inconsistency"):
-                    continue
-                # Check if this patient_id exists
-                if not any(p.get("patient_id") == pid for p in dataset):
-                    continue
-                result = env.step(AuditAction(
-                    action_type="flag_error",
-                    patient_id=pid,
-                    error_type=etype,
-                    reason=parts[2].strip() if len(parts) > 2 else "LLM detected",
-                ))
-                obs = result.observation.model_dump()
-                final_score = obs["score_so_far"]
                metrics.record(obs["feedback"])
-                metrics.steps += 1
 
-        # Submit report
        if not result.done:
-            result = env.step(AuditAction(
-                action_type="submit_report",
-                report=(
-                    "Clinical data audit report. Issues found in patient ages and temporal "
-                    "sequences. Recommend corrective action for data entry validation. "
-                    "Risk assessment: HIGH. Impact on regulatory compliance noted."
-                ),
-            ))
 
            obs = result.observation.model_dump()
            final_score = obs["score_so_far"]
            metrics.steps += 1
@@ -586,96 +610,69 @@ def run_naive_task(client, task_id: str, task_name: str):
     return final_score, metrics
 
 
-# ═══════════════════════════════════════════════════════════════════════════
-# AGENT MODE 2: HEURISTIC (simple rules, no LLM)
-# ═══════════════════════════════════════════════════════════════════════════
-
-def run_heuristic_task(client_unused, task_id: str, task_name: str):
-    """
-    Heuristic agent: simple threshold rules, no LLM.
-    Catches obvious errors but falls for traps. Expected score: ~0.45-0.60
-    """
     print(f"\n  Task: {task_name}")
-    print("  " + "-" * 50)
-
     metrics = MetricsTracker()
     final_score = 0.0
 
-    with ClinicalTrialAuditorEnv(base_url=ENV_BASE_URL).sync() as env:
-        result = env.reset(task_id=task_id, seed=BASELINE_SEED)
        obs = result.observation.model_dump()
        dataset = obs["dataset"]
-        task_type = obs["task_type"]
        max_steps = obs["attempts_remaining"]
-        print(f"  Patients: {len(dataset)} | Max steps: {max_steps}")
 
-        # Investigate
-        for var in ["age", "death_date", "ethnicity"]:
-            if result.done:
-                break
-            result = env.step(AuditAction(action_type="investigate_pattern", variable=var))
-            metrics.steps += 1
 
-        step_budget = max_steps - metrics.steps - 1  # Reserve 1 for report
-        flags_made = 0
 
-        # Simple age check — catches most but may false-positive on boundaries
-        for p in dataset:
-            if flags_made >= step_budget or result.done:
                break
-            age = p.get("age")
-            pid = p.get("patient_id")
-
-            # BUG: heuristic uses < 18 instead of < 18, catching age=18 incorrectly? No.
-            # BUG: heuristic uses > 100 instead of > 120, missing ages 101-120 OR
-            #      flagging valid old patients
-            if age is None or age < 18 or age > 100:  # Deliberately wrong threshold
-                result = env.step(AuditAction(
-                    action_type="flag_error",
-                    patient_id=pid,
-                    error_type="invalid_age",
-                    reason=f"Age {age} outside expected range",
-                ))
-                obs = result.observation.model_dump()
-                final_score = obs["score_so_far"]
                metrics.record(obs["feedback"])
-                metrics.steps += 1
-                flags_made += 1
-
-        # Simple temporal check (if applicable)
-        if task_type in ("temporal_consistency", "comprehensive_audit"):
-            for p in dataset:
-                if flags_made >= step_budget or result.done:
-                    break
-                ts = p.get("treatment_start")
-                dd = p.get("death_date")
-                if ts and dd:
-                    try:
-                        t = datetime.strptime(ts, "%Y-%m-%d")
-                        d = datetime.strptime(dd, "%Y-%m-%d")
-                        # BUG: heuristic flags ANY death within 7 days (catches traps)
-                        if d < t or (d - t).days < 7:
-                            pid = p.get("patient_id")
-                            result = env.step(AuditAction(
-                                action_type="flag_error",
-                                patient_id=pid,
-                                error_type="temporal_inconsistency",
-                                reason=f"Suspicious temporal sequence",
-                            ))
-                            obs = result.observation.model_dump()
-                            final_score = obs["score_so_far"]
-                            metrics.record(obs["feedback"])
-                            metrics.steps += 1
-                            flags_made += 1
-                    except ValueError:
-                        pass
-
-        # Submit report
        if not result.done:
-            result = env.step(AuditAction(
-                action_type="submit_report",
-                report="Audit complete. Found age and temporal issues. Action recommended.",
-            ))
 
            obs = result.observation.model_dump()
            final_score = obs["score_so_far"]
            metrics.steps += 1
@@ -684,205 +681,152 @@ def run_heuristic_task(client_unused, task_id: str, task_name: str):
 
 
-# ═══════════════════════════════════════════════════════════════════════════
-# AGENT MODE 3: FULL AGENTIC PIPELINE
-# ═══════════════════════════════════════════════════════════════════════════
-
-def run_full_task(client, task_id: str, task_name: str):
-    """
-    Full agent: Statistical detection + LLM reasoning.
-    Expected score: ~0.85-0.95
-    """
     print(f"\n  Task: {task_name}")
-    print("  " + "-" * 50)
-
     metrics = MetricsTracker()
     final_score = 0.0
 
-    with ClinicalTrialAuditorEnv(base_url=ENV_BASE_URL).sync() as env:
-        result = env.reset(task_id=task_id, seed=BASELINE_SEED)
        obs = result.observation.model_dump()
        dataset = obs["dataset"]
-        task_type = obs["task_type"]
        max_steps = obs["attempts_remaining"]
-        print(f"  Type: {task_type} | Patients: {len(dataset)} | Max steps: {max_steps}")
-
-        # 1. PROFILE
-        profiler = DataProfiler(dataset)
-        profiles = profiler.profile_all()
-        ap = profiles.get("age", {})
-        print(f"  Profile: age median={ap.get('median','?')}, "
-              f"range=[{ap.get('min','?')}-{ap.get('max','?')}], "
-              f"nulls={ap.get('null_count',0)}")
-
-        # 2. DETECT
-        all_findings = []
-        all_findings.extend(AgeAnomalyDetector().detect(dataset, profiles))
-        if task_type in ("temporal_consistency", "comprehensive_audit"):
-            all_findings.extend(TemporalConsistencyDetector().detect(dataset, profiles))
-        if task_type == "comprehensive_audit":
-            all_findings.extend(SelectionBiasDetector().detect(dataset, profiles))
-
-        age_n = sum(1 for f in all_findings if f.error_type == "invalid_age")
-        temp_n = sum(1 for f in all_findings if f.error_type == "temporal_inconsistency")
-        bias_n = sum(1 for f in all_findings if f.error_type == "selection_bias")
-        print(f"  Detected: {age_n} age | {temp_n} temporal | {bias_n} bias")
-
-        # 3. PLAN
-        planner = ActionPlanner()
-        actions = planner.plan(all_findings, task_type, max_steps=max_steps)
 
-        # 4. REASON (1 LLM call for report)
-        report_text = generate_expert_report(client, all_findings, profiles, task_type)
-        metrics.llm_calls += 1
 
-        # 5. EXECUTE
-        step = 0
        for action in actions:
            if result.done:
                break
            result = env.step(action)
            obs = result.observation.model_dump()
            final_score = obs["score_so_far"]
-            feedback = obs["feedback"]
-            step += 1
-            metrics.steps = step
            if action.action_type == "flag_error":
-                metrics.record(feedback)
-            # Print progress every 5 steps or for flags
-            if action.action_type == "flag_error" or step <= 3:
-                print(f"  Step {step}: score={final_score:.2f} | {feedback[:65]}")
 
-        # 6. REPORT
        if not result.done:
-            result = env.step(AuditAction(action_type="submit_report", report=report_text))
            obs = result.observation.model_dump()
            final_score = obs["score_so_far"]
-            step += 1
-            metrics.steps = step
-            print(f"  Step {step}: score={final_score:.2f} | Report submitted")
 
        print(metrics.summary())
        return final_score, metrics
 
 
-# ═══════════════════════════════════════════════════════════════════════════
-# ORCHESTRATOR
-# ═══════════════════════════════════════════════════════════════════════════
-
-TASK_LIST = {
-    "task_easy": "Syntactic Cleaning (Easy)",
-    "task_medium": "Temporal Consistency (Medium)",
-    "task_hard": "Equity Bias Audit (Hard)",
-}
-
-
-def run_agent(mode: str, client):
-    """Run one agent mode across all tasks."""
    runner = {
        "naive": run_naive_task,
        "heuristic": run_heuristic_task,
        "full": run_full_task,
    }[mode]
 
-    scores, all_metrics = [], []
-    t0 = time.time()
-
-    for tid, tname in TASK_LIST.items():
-        score, m = runner(client, tid, tname)
        scores.append(score)
-        all_metrics.append(m)
-        print(f"  Final: {score:.2f}\n")
-
-    elapsed = time.time() - t0
-    avg = sum(scores) / len(scores)
-    total_steps = sum(m.steps for m in all_metrics)
-    total_llm = sum(m.llm_calls for m in all_metrics)
-    avg_prec = statistics.mean(m.precision for m in all_metrics) if all_metrics else 0
 
    return {
        "mode": mode,
        "scores": dict(zip(TASK_LIST.keys(), scores)),
-        "average": avg,
-        "elapsed": elapsed,
-        "total_steps": total_steps,
-        "total_llm": total_llm,
-        "avg_precision": avg_prec,
    }
 
 
 def main():
-    parser = argparse.ArgumentParser(description="Clinical Trial Auditor Baseline Inference")
-    parser.add_argument("--mode", choices=["naive", "heuristic", "full", "all"],
-                        default="full", help="Agent mode (default: full)")
    args = parser.parse_args()
 
-    # Only create LLM client when needed (heuristic mode doesn't use LLM)
-    needs_llm = args.mode in ("naive", "full", "all")
-    if needs_llm:
-        api_key = API_KEY or os.getenv("OPENAI_API_KEY")
-        if not api_key:
-            print("WARNING: No API key found. Set HF_TOKEN, API_KEY, or OPENAI_API_KEY.")
-            print("         Falling back to heuristic mode.")
-            args.mode = "heuristic"
-            client = None
-        else:
-            client = OpenAI(base_url=API_BASE_URL, api_key=api_key)
-    else:
-        client = None
 
-    print("=" * 65)
-    print("  Clinical Trial Auditor — Baseline Inference")
-    print("  Procedural Dataset Generation | Adversarial Traps | Seed-Reproducible")
    print(f"  Model: {MODEL_NAME}")
-    print(f"  Seed: {BASELINE_SEED}")
-    print("=" * 65)
 
-    if args.mode == "all":
-        modes = ["naive", "heuristic", "full"]
-    else:
-        modes = [args.mode]
-
-    all_results = []
    for mode in modes:
-        print(f"\n{'═' * 65}")
        print(f"  AGENT: {mode.upper()}")
-        print(f"{'═' * 65}")
-        result = run_agent(mode, client)
-        all_results.append(result)
 
-    # ── Final Report ──
-    print("\n" + "=" * 65)
    print("  BENCHMARK RESULTS")
-    print("=" * 65)
-
-    if len(all_results) > 1:
-        # Multi-agent comparison table
-        header = f"  {'Agent':<15} {'Easy':>8} {'Medium':>8} {'Hard':>8} {'Avg':>8} {'Prec':>8} {'Time':>8}"
        print(header)
-        print("  " + "-" * 63)
-        for r in all_results:
-            scores = r["scores"]
-            print(f"  {r['mode'].upper():<15} "
-                  f"{scores.get('task_easy', 0):.2f} "
-                  f"{scores.get('task_medium', 0):.2f} "
-                  f"{scores.get('task_hard', 0):.2f} "
-                  f"{r['average']:.2f} "
-                  f"{r['avg_precision']:.0%} "
-                  f"{r['elapsed']:.1f}s")
    else:
-        r = all_results[0]
-        for tid, tname in TASK_LIST.items():
-            score = r["scores"].get(tid, 0)
-            print(f"  {tname:35s}: {score:.2f}")
-        print(f"\n  Average score: {r['average']:.2f}")
-        print(f"  Total time: {r['elapsed']:.1f}s")
-        print(f"  LLM calls: {r['total_llm']}")
-        print(f"  Total steps: {r['total_steps']}")
-        print(f"  Average precision: {r['avg_precision']:.0%}")
-
-    print("=" * 65)
 
 
 if __name__ == "__main__":
-    main()
 """
+Clinical Trial Auditor — Baseline Inference
+===========================================
+Demonstrates a deliberate difficulty gradient on the protocol-aware benchmark:
+
+1. NAIVE     — raw prompt + small sample, weak structure
+2. HEURISTIC — parses obvious rules but ignores key exceptions
+3. FULL      — parses protocol, honors stage exceptions, stage-adjusts bias
 """
+
+from __future__ import annotations
+
+import argparse
+import json
 import os
+import re
+import statistics
 import sys
 import time
 from collections import Counter
+from dataclasses import dataclass, field
+from datetime import datetime
+from types import SimpleNamespace
 from typing import Optional
 
+from openai import OpenAI
+
 sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
 
 from client import ClinicalTrialAuditorEnv
 from models import AuditAction
 
+try:
+    from server.clinical_trial_auditor_environment import ClinicalTrialAuditorEnvironment
+except ImportError:
+    ClinicalTrialAuditorEnvironment = None
+
+
 API_BASE_URL = os.getenv("API_BASE_URL", "https://api.groq.com/openai/v1")
 API_KEY = os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY") or os.getenv("API_KEY")
 MODEL_NAME = os.getenv("MODEL_NAME", "llama-3.3-70b-versatile")
 ENV_BASE_URL = os.getenv("ENV_BASE_URL", "http://localhost:8000")
+BASELINE_SEED = int(os.getenv("BASELINE_SEED", "20260402"))
+
+TASK_LIST = {
+    "task_easy": "Dynamic Eligibility Screening (Easy)",
+    "task_medium": "Protocol Timeline Audit (Medium)",
+    "task_hard": "Equity + Protocol Audit (Hard)",
+}
+
+TASK_SPECS = {
+    "task_easy": {
+        "investigations": ["age"],
+        "distributions": [],
+    },
+    "task_medium": {
+        "investigations": ["age", "death_date", "enrollment_date", "stage"],
+        "distributions": [],
+    },
+    "task_hard": {
+        "investigations": ["age", "death_date", "enrollment_date", "stage"],
+        "distributions": ["ethnicity", "gender", "outcome"],
+    },
+}
+
 
+@dataclass
+class ProtocolRules:
+    protocol_title: str
+    age_min: int
+    age_max: int
+    treatment_window_days: int
+    stage_iv_window_days: int
+    high_risk_sites: list[str] = field(default_factory=list)
+    bias_control_dominance_threshold: float = 1.0
+    bias_male_threshold: float = 1.0
+    bias_stage_adjusted_gap: float = 1.0
+
+    def allowed_window(self, stage: str) -> int:
+        return self.stage_iv_window_days if stage == "IV" else self.treatment_window_days
 
+@dataclass
 class Finding:
+    error_type: str
+    reason: str
+    patient_id: Optional[str] = None
+    confidence: float = 1.0
+    risk: str = "medium"
+    evidence: str = ""
 
     @property
     def priority_score(self) -> float:
+        risk_weight = {"critical": 1.0, "high": 0.8, "medium": 0.5, "low": 0.2}
+        return self.confidence * risk_weight.get(self.risk, 0.5)
 
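The `priority_score` property in the diff is a plain confidence-times-severity product that the planner later sorts on. A minimal standalone sketch of that weighting (the `MiniFinding` class and the sample values are illustrative stand-ins, not part of the repository):

```python
from dataclasses import dataclass

# Same weights as the diff's risk_weight mapping.
RISK_WEIGHT = {"critical": 1.0, "high": 0.8, "medium": 0.5, "low": 0.2}


@dataclass
class MiniFinding:
    confidence: float
    risk: str

    @property
    def priority_score(self) -> float:
        # Confidence scaled by a severity weight; 0.5 is the
        # fallback for unrecognized risk labels.
        return self.confidence * RISK_WEIGHT.get(self.risk, 0.5)


findings = [
    MiniFinding(confidence=0.92, risk="critical"),
    MiniFinding(confidence=0.98, risk="high"),
    MiniFinding(confidence=0.80, risk="unknown"),
]
# Sort descending, as ActionPlanner.plan does before flagging.
ranked = sorted(findings, key=lambda f: -f.priority_score)
```

Note that a high-confidence "high" finding (0.98 × 0.8 = 0.784) still ranks below a slightly less confident "critical" one (0.92 × 1.0 = 0.92).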
 
+
+class ProtocolParser:
     @staticmethod
+    def parse(excerpt: str) -> ProtocolRules:
+        title_match = re.search(r"TRIAL PROTOCOL EXCERPT\s+[—-]\s+([A-Z0-9-]+)", excerpt)
+        age_match = re.search(r"age (\d+)-(\d+) inclusive", excerpt)
+        window_match = re.search(r"Treatment must begin within (\d+) days", excerpt)
+        stage_match = re.search(r"Stage IV exception: treatment may begin within (\d+) days", excerpt)
+        sites_match = re.search(
+            r"Stage IV patients at (.+?) are a known high-risk outreach cohort",
+            excerpt,
+        )
+        bias_match = re.search(
+            r"dominance exceeds (\d+)%, male share exceeds (\d+)%, "
+            r"and stage-adjusted mortality gap exceeds (\d+) percentage points",
+            excerpt,
+        )
 
+        high_risk_sites = []
+        if sites_match:
+            high_risk_sites = [site.strip() for site in sites_match.group(1).split(",")]
+
+        bias_values = (100, 100, 100)
+        if bias_match:
+            bias_values = tuple(int(value) for value in bias_match.groups())
+
+        if not age_match or not window_match or not stage_match:
+            raise ValueError("Unable to parse protocol excerpt.")
+
+        return ProtocolRules(
+            protocol_title=(title_match.group(1) if title_match else "UNKNOWN"),
+            age_min=int(age_match.group(1)),
+            age_max=int(age_match.group(2)),
+            treatment_window_days=int(window_match.group(1)),
+            stage_iv_window_days=int(stage_match.group(1)),
+            high_risk_sites=high_risk_sites,
+            bias_control_dominance_threshold=bias_values[0] / 100.0,
+            bias_male_threshold=bias_values[1] / 100.0,
+            bias_stage_adjusted_gap=bias_values[2] / 100.0,
+        )
 
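`ProtocolParser.parse` above recovers numeric thresholds from free-text protocol language with regular expressions. A reduced sketch of the same idea, using the same patterns on a hypothetical excerpt (the wording below is invented to match those patterns and is not taken from the environment):

```python
import re

# Hypothetical excerpt written to satisfy the parser's regexes.
excerpt = (
    "TRIAL PROTOCOL EXCERPT — ONC-2026-11\n"
    "Eligibility: age 18-85 inclusive.\n"
    "Treatment must begin within 14 days of enrollment.\n"
    "Stage IV exception: treatment may begin within 21 days."
)

# Capture groups pull the numbers out of the prose.
age = re.search(r"age (\d+)-(\d+) inclusive", excerpt)
window = re.search(r"Treatment must begin within (\d+) days", excerpt)
stage = re.search(r"Stage IV exception: treatment may begin within (\d+) days", excerpt)

age_min, age_max = int(age.group(1)), int(age.group(2))
standard_days, stage_iv_days = int(window.group(1)), int(stage.group(1))
```

This is also why `parse` raises when the age, window, or stage patterns fail: a silently missing threshold would make every downstream detector wrong.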
+def parse_date(value: Optional[str]) -> Optional[datetime]:
+    if not value:
+        return None
+    for fmt in ("%Y-%m-%d", "%Y/%m/%d", "%m/%d/%Y", "%d-%m-%Y"):
+        try:
+            return datetime.strptime(str(value), fmt)
+        except ValueError:
+            continue
+    return None
 
 
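`parse_date` tries several layouts in order and returns `None` only when all of them fail. The same fallback pattern in isolation (`parse_any` is an illustrative replica, not the repository function):

```python
from datetime import datetime
from typing import Optional

# Same candidate layouts as the diff's parse_date.
FORMATS = ("%Y-%m-%d", "%Y/%m/%d", "%m/%d/%Y", "%d-%m-%Y")


def parse_any(value: Optional[str]) -> Optional[datetime]:
    # Try each known layout; the first that parses wins.
    if not value:
        return None
    for fmt in FORMATS:
        try:
            return datetime.strptime(str(value), fmt)
        except ValueError:
            continue
    return None
```

Order matters for ambiguous strings: a value that could parse under two layouts is resolved by whichever format appears first in the tuple.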
 
+class AgeDetector:
+    def detect(self, dataset: list[dict], rules: ProtocolRules) -> list[Finding]:
+        findings = []
+        for row in dataset:
+            age = row.get("age")
+            if age is None or age < rules.age_min or age > rules.age_max:
+                findings.append(
+                    Finding(
+                        patient_id=row.get("patient_id"),
+                        error_type="invalid_age",
+                        reason=f"Age {age} violates protocol range {rules.age_min}-{rules.age_max}",
+                        confidence=0.98 if age is None or age < 0 or age > (rules.age_max + 10) else 0.94,
+                        risk="high",
+                    )
+                )
         return findings
 
 
+class TemporalDetector:
+    def detect(self, dataset: list[dict]) -> list[Finding]:
         findings = []
         for row in dataset:
+            treatment = parse_date(row.get("treatment_start"))
+            death = parse_date(row.get("death_date"))
+            if treatment and death and death < treatment:
+                gap = (treatment - death).days
+                findings.append(
+                    Finding(
+                        patient_id=row.get("patient_id"),
+                        error_type="temporal_inconsistency",
+                        reason=f"death_date precedes treatment_start by {gap} days",
+                        confidence=min(1.0, 0.92 + gap / 500.0),
+                        risk="critical" if gap > 120 else "high",
+                    )
+                )
         return findings
 
 
+class ProtocolWindowDetector:
+    def detect(self, dataset: list[dict], rules: ProtocolRules, ignore_stage_exception: bool = False) -> list[Finding]:
         findings = []
+        for row in dataset:
+            enrollment = parse_date(row.get("enrollment_date"))
+            treatment = parse_date(row.get("treatment_start"))
+            if not enrollment or not treatment:
+                continue
+            allowed_days = rules.treatment_window_days if ignore_stage_exception else rules.allowed_window(row.get("stage", ""))
+            delay = (treatment - enrollment).days
+            if delay > allowed_days:
+                findings.append(
+                    Finding(
+                        patient_id=row.get("patient_id"),
+                        error_type="protocol_window_violation",
+                        reason=f"treatment started after {delay} days (allowed {allowed_days})",
+                        confidence=0.93 if delay > allowed_days + 3 else 0.82,
+                        risk="high",
+                    )
+                )
+        return findings
 
+class BiasAnalyzer:
+    @staticmethod
+    def summarize_control(dataset: list[dict]) -> tuple[list[dict], str, float, float, float]:
+        control = [row for row in dataset if row.get("group") == "control"]
+        if not control:
+            return [], "Unknown", 0.0, 0.0, 0.0
+
+        counts = Counter(row.get("ethnicity", "Unknown") for row in control)
+        dominant_ethnicity, dominant_count = counts.most_common(1)[0]
+        dominant_ratio = dominant_count / len(control)
+        male_ratio = sum(row.get("gender") == "M" for row in control) / len(control)
+
+        dominant_group = [row for row in control if row.get("ethnicity") == dominant_ethnicity]
+        minority_group = [row for row in control if row.get("ethnicity") != dominant_ethnicity]
+        dom_mortality = (
+            sum(row.get("outcome") == "deceased" for row in dominant_group) / len(dominant_group)
+            if dominant_group
+            else 0.0
         )
+        min_mortality = (
+            sum(row.get("outcome") == "deceased" for row in minority_group) / len(minority_group)
+            if minority_group
+            else 0.0
+        )
+        overall_gap = min_mortality - dom_mortality
+        return control, dominant_ethnicity, dominant_ratio, male_ratio, overall_gap
 
+    @staticmethod
+    def stage_adjusted_gap(control: list[dict], dominant_ethnicity: str) -> float:
+        weighted_gap = 0.0
+        total_weight = 0
+        for stage in ("I", "II", "III", "IV"):
+            stage_rows = [row for row in control if row.get("stage") == stage]
+            dominant_rows = [row for row in stage_rows if row.get("ethnicity") == dominant_ethnicity]
+            minority_rows = [row for row in stage_rows if row.get("ethnicity") != dominant_ethnicity]
+            if len(dominant_rows) < 5 or len(minority_rows) < 5:
+                continue
+            dominant_mortality = sum(row.get("outcome") == "deceased" for row in dominant_rows) / len(dominant_rows)
+            minority_mortality = sum(row.get("outcome") == "deceased" for row in minority_rows) / len(minority_rows)
+            weight = len(stage_rows)
+            weighted_gap += (minority_mortality - dominant_mortality) * weight
+            total_weight += weight
+        return weighted_gap / total_weight if total_weight else 0.0
+
+    def detect_full(self, dataset: list[dict], rules: ProtocolRules) -> list[Finding]:
+        control, dominant_ethnicity, dominant_ratio, male_ratio, overall_gap = self.summarize_control(dataset)
+        if not control:
+            return []
+        adjusted_gap = self.stage_adjusted_gap(control, dominant_ethnicity)
+        if (
+            dominant_ratio >= rules.bias_control_dominance_threshold
+            and male_ratio >= rules.bias_male_threshold
+            and adjusted_gap >= rules.bias_stage_adjusted_gap
+        ):
+            return [
+                Finding(
+                    patient_id=None,
+                    error_type="selection_bias",
+                    reason=(
+                        f"Control-arm skew detected: {dominant_ethnicity}={dominant_ratio:.0%}, "
+                        f"male={male_ratio:.0%}, stage-adjusted mortality gap={adjusted_gap:.0%}"
+                    ),
+                    confidence=0.92,
+                    risk="critical",
+                    evidence=f"overall gap={overall_gap:.0%}",
+                )
+            ]
+        return []
+
+    def detect_heuristic(self, dataset: list[dict], rules: ProtocolRules) -> list[Finding]:
+        control, dominant_ethnicity, dominant_ratio, male_ratio, overall_gap = self.summarize_control(dataset)
+        if not control:
+            return []
+        loose_threshold = max(0.10, rules.bias_stage_adjusted_gap - 0.04)
+        if dominant_ratio >= max(0.55, rules.bias_control_dominance_threshold - 0.07) and overall_gap >= loose_threshold:
+            return [
+                Finding(
+                    patient_id=None,
+                    error_type="selection_bias",
+                    reason=(
+                        f"Heuristic bias concern: {dominant_ethnicity}={dominant_ratio:.0%}, "
+                        f"male={male_ratio:.0%}, overall mortality gap={overall_gap:.0%}"
+                    ),
+                    confidence=0.74,
+                    risk="high",
+                )
+            ]
+        return []
 
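`BiasAnalyzer.stage_adjusted_gap` is a stratified mortality comparison: the minority-vs-dominant gap is computed within each disease stage and then averaged with stage sizes as weights. This protects the full agent's bias check from Simpson's-paradox traps, where stage mix alone produces a large overall gap. A toy version of the same computation (the cohort data is invented for illustration; the repository's per-stratum minimum of 5 is relaxed to 1 here so the tiny sample qualifies):

```python
def stage_adjusted_gap(control, dominant="A", min_n=1):
    # Weighted average of per-stage (minority - dominant) mortality gaps.
    weighted, total = 0.0, 0
    for stage in ("I", "II", "III", "IV"):
        rows = [r for r in control if r["stage"] == stage]
        dom = [r for r in rows if r["ethnicity"] == dominant]
        mino = [r for r in rows if r["ethnicity"] != dominant]
        if len(dom) < min_n or len(mino) < min_n:
            continue  # skip strata too small to compare
        gap = (sum(r["dead"] for r in mino) / len(mino)
               - sum(r["dead"] for r in dom) / len(dom))
        weighted += gap * len(rows)
        total += len(rows)
    return weighted / total if total else 0.0


# Toy cohort: within each stage, mortality is identical across groups,
# but group B is concentrated in stage IV, so the OVERALL gap looks large.
control = (
    [{"stage": "I", "ethnicity": "A", "dead": 0}] * 8
    + [{"stage": "I", "ethnicity": "B", "dead": 0}] * 2
    + [{"stage": "IV", "ethnicity": "A", "dead": 1}] * 2
    + [{"stage": "IV", "ethnicity": "B", "dead": 1}] * 8
)
minority = [r for r in control if r["ethnicity"] != "A"]
dominant_rows = [r for r in control if r["ethnicity"] == "A"]
overall_gap = (sum(r["dead"] for r in minority) / len(minority)
               - sum(r["dead"] for r in dominant_rows) / len(dominant_rows))
adjusted = stage_adjusted_gap(control)
```

Here the naive overall gap is 0.6, while the stage-adjusted gap is 0.0, which is exactly the distinction that separates `detect_full` from `detect_heuristic`.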
301
  class ActionPlanner:
302
+ def plan(
303
+ self,
304
+ task_id: str,
305
+ findings: list[Finding],
306
+ max_steps: int,
307
+ extra_investigations: Optional[list[str]] = None,
308
+ ) -> list[AuditAction]:
309
+ spec = TASK_SPECS[task_id]
310
+ actions: list[AuditAction] = []
311
+
312
+ investigations = list(spec["investigations"])
313
+ distributions = list(spec["distributions"])
314
+ if extra_investigations:
315
+ investigations.extend(extra_investigations)
316
+
317
+ for variable in investigations:
318
+ actions.append(AuditAction(action_type="investigate_pattern", variable=variable))
319
+ for variable in distributions:
320
+ actions.append(AuditAction(action_type="compute_distribution", variable=variable))
321
+
322
+ record_findings = [finding for finding in findings if finding.error_type != "selection_bias"]
323
+ bias_findings = [finding for finding in findings if finding.error_type == "selection_bias"]
324
+ record_findings.sort(key=lambda finding: -finding.priority_score)
325
+
326
+ max_flag_slots = max_steps - len(actions) - 1 - (1 if bias_findings else 0)
327
+ flagged_ids = set()
328
+ for finding in record_findings[:max_flag_slots]:
329
+ if not finding.patient_id or finding.patient_id in flagged_ids:
330
+ continue
331
+ flagged_ids.add(finding.patient_id)
332
+ actions.append(
333
+ AuditAction(
334
+ action_type="flag_error",
335
+ patient_id=finding.patient_id,
336
+ error_type=finding.error_type,
337
+ reason=finding.reason,
338
+ confidence=finding.confidence,
339
+ )
340
+ )
341
 
342
+ if bias_findings:
343
+ bias = bias_findings[0]
344
+ actions.append(
345
+ AuditAction(
346
+ action_type="flag_error",
347
+ error_type="selection_bias",
348
+ reason=bias.reason,
349
+ confidence=bias.confidence,
350
+ )
351
+ )
352
 
353
+ return actions
 
 
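`ActionPlanner.plan` budgets per-patient flags so that one step always remains for `submit_report`, plus one more when a dataset-level bias finding must be flagged. That arithmetic, pulled out into a hypothetical helper:

```python
def flag_budget(max_steps: int, planned_probes: int,
                has_bias_finding: bool) -> int:
    """Per-patient flag slots left after investigation/distribution probes,
    the reserved submit_report step, and an optional dataset-level bias flag."""
    reserved = 1 + (1 if has_bias_finding else 0)
    return max_steps - planned_probes - reserved
```

For example, with a 15-step budget, 5 probe actions, and a bias finding present, 8 slots remain for record-level flags.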
354
 
 
355
 
356
+ def generate_expert_report(
357
+ client: Optional[OpenAI],
358
+ rules: ProtocolRules,
359
+ findings: list[Finding],
360
+ task_name: str,
361
+ ) -> str:
362
+ finding_lines = []
363
+ for finding in findings[:8]:
364
+ if finding.patient_id:
365
+ finding_lines.append(f"- {finding.patient_id}: {finding.error_type} | {finding.reason}")
366
+ else:
367
+ finding_lines.append(f"- {finding.error_type}: {finding.reason}")
368
+
369
+ prompt = "\n".join(
370
+ [
371
+ f"Protocol: {rules.protocol_title}",
372
+ f"Task: {task_name}",
373
+ f"Key rules: age {rules.age_min}-{rules.age_max}, "
374
+ f"standard start <= {rules.treatment_window_days} days, "
375
+ f"stage IV <= {rules.stage_iv_window_days} days.",
376
+ "",
377
+ "Findings:",
378
+ *finding_lines,
379
+ "",
380
+ "Write a concise audit report with protocol grounding, root cause, risk, corrective actions, "
381
+ "and fairness reasoning when relevant.",
382
+ ]
383
+ )
384
 
385
+ if client is not None:
386
+ try:
387
+ completion = client.chat.completions.create(
388
+ model=MODEL_NAME,
389
+ messages=[
390
+ {
391
+ "role": "system",
392
+ "content": (
393
+ "You are a senior clinical data manager. Produce a concise report "
394
+ "with protocol grounding, root cause, risk, corrective action, and "
395
+ "fairness reasoning when applicable."
396
+ ),
397
+ },
398
+ {"role": "user", "content": prompt},
399
+ ],
400
+ temperature=0,
401
+ max_tokens=240,
402
+ )
403
+ content = completion.choices[0].message.content or ""
404
+ if content:
405
+ return content
406
+ except Exception:
407
+ pass
408
+
409
+ if any(finding.error_type == "selection_bias" for finding in findings):
410
+ fairness_line = (
411
+ "Fairness review: control-arm patterns were reviewed with stage-adjusted comparisons "
412
+ "before escalating the bias conclusion."
413
+ )
414
+ else:
415
+ fairness_line = (
416
+ "Fairness review: no actionable control-arm bias was confirmed after stage-adjusted review."
417
+ )
418
 
419
+ return (
420
+ f"Protocol-grounded audit for {rules.protocol_title}. Root cause analysis indicates site-level "
421
+ f"data capture and scheduling control weaknesses. Risk assessment: protocol compliance and endpoint "
422
+ f"validity are affected. Recommended corrective actions include quarantining impacted records, "
423
+ f"tightening enrollment-to-treatment validations, and retraining site coordinators. {fairness_line}"
424
+ )
425
 
426
 
427
  class MetricsTracker:
428
  def __init__(self):
429
  self.true_pos = 0
 
432
  self.steps = 0
433
  self.llm_calls = 0
434
 
435
+ def record(self, feedback: str) -> None:
436
  self.total_flagged += 1
437
+ if "Correct" in feedback:
438
  self.true_pos += 1
439
+ elif "REJECTED" in feedback:
440
  self.false_pos += 1
441
 
442
  @property
443
  def precision(self) -> float:
444
+ return self.true_pos / self.total_flagged if self.total_flagged else 0.0
445
 
446
  def summary(self) -> str:
447
  return (
448
+ f" Metrics: {self.true_pos}/{self.total_flagged} correct "
449
+ f"(precision={self.precision:.0%}) | {self.steps} steps | {self.llm_calls} LLM call(s)"
 
450
  )
451
 
452
 
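`MetricsTracker.precision` matters because the environment's reward asymmetry (+0.10 per correct flag vs. -0.30 per false positive, per the REWARD_CONFIG in this repository) makes flagging net-negative below 75% precision. A quick check of that break-even point (helper name is illustrative):

```python
def expected_flag_reward(precision: float,
                         correct: float = 0.10,
                         false_positive: float = -0.30) -> float:
    """Expected per-flag reward for a policy with the given precision."""
    return precision * correct + (1.0 - precision) * false_positive
```

At precision 0.75 the expected reward is exactly zero, which is why the heuristic and full agents prune low-confidence findings instead of flagging everything.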
453
+ class InProcessEnvSession:
454
+ def __init__(self):
455
+ if ClinicalTrialAuditorEnvironment is None:
456
+ raise RuntimeError("In-process environment is unavailable.")
457
+ self._env = ClinicalTrialAuditorEnvironment()
458
 
459
+ def __enter__(self):
460
+ return self
461
+
462
+ def __exit__(self, exc_type, exc, tb):
463
+ return False
464
+
465
+ def reset(self, **kwargs):
466
+ observation = self._env.reset(**kwargs)
467
+ return SimpleNamespace(observation=observation, reward=observation.reward, done=observation.done)
468
+
469
+ def step(self, action: AuditAction):
470
+ observation = self._env.step(action)
471
+ return SimpleNamespace(observation=observation, reward=observation.reward, done=observation.done)
472
+
473
+
474
+ def open_env_session():
475
+ if ENV_BASE_URL.lower() == "inprocess":
476
+ return InProcessEnvSession()
477
+ return ClinicalTrialAuditorEnv(base_url=ENV_BASE_URL).sync()
478
 
479
+
480
+ def run_naive_task(client: Optional[OpenAI], task_id: str, task_name: str, seed: int):
481
+ print(f"\n Task: {task_name}")
482
+ print(" " + "-" * 54)
483
  metrics = MetricsTracker()
484
  final_score = 0.0
485
 
486
+ with open_env_session() as env:
487
+ result = env.reset(task_id=task_id, seed=seed)
488
  obs = result.observation.model_dump()
489
  dataset = obs["dataset"]
490
+ protocol_excerpt = obs["trial_protocol_excerpt"]
491
  max_steps = obs["attempts_remaining"]
492
+ rules = ProtocolParser.parse(protocol_excerpt)
493
+ print(f" Protocol: {rules.protocol_title} | Patients: {len(dataset)} | Max steps: {max_steps}")
494
 
495
+ sample = dataset[:24]
496
+ guessed_findings: list[Finding] = []
 
497
 
498
+ if client is not None:
499
+ try:
500
+ completion = client.chat.completions.create(
501
+ model=MODEL_NAME,
502
+ messages=[
503
+ {
504
+ "role": "system",
505
+ "content": (
506
+ "You are auditing patient records from a clinical trial. "
507
+ "Return one issue per line as PATIENT_ID|ERROR_TYPE|REASON."
508
+ ),
509
+ },
510
+ {
511
+ "role": "user",
512
+ "content": (
513
+ f"Protocol excerpt:\n{protocol_excerpt}\n\n"
514
+ f"Review only these {len(sample)} records:\n{json.dumps(sample, default=str)}\n\n"
515
+ "Allowed ERROR_TYPE values: invalid_age, temporal_inconsistency, "
516
+ "protocol_window_violation, selection_bias."
517
+ ),
518
+ },
519
+ ],
520
+ temperature=0,
521
+ max_tokens=450,
522
+ )
523
+ metrics.llm_calls += 1
524
+ lines = (completion.choices[0].message.content or "").splitlines()
525
+ for line in lines:
526
+ parts = [part.strip() for part in line.split("|")]
527
+ if len(parts) >= 2:
528
+ guessed_findings.append(
529
+ Finding(
530
+ patient_id=parts[0] if parts[0] and parts[0] != "None" else None,
531
+ error_type=parts[1],
532
+ reason=parts[2] if len(parts) > 2 else "LLM guess",
533
+ confidence=0.65,
534
+ )
535
+ )
536
+ except Exception as exc:
537
+ print(f" LLM error: {exc}")
538
+
539
+ if not guessed_findings:
540
+ for row in sample:
541
+ age = row.get("age")
542
+ treatment = parse_date(row.get("treatment_start"))
543
+ death = parse_date(row.get("death_date"))
544
+ enrollment = parse_date(row.get("enrollment_date"))
545
+ if age is None or age < 0 or age > 120:
546
+ guessed_findings.append(
547
+ Finding(
548
+ patient_id=row.get("patient_id"),
549
+ error_type="invalid_age",
550
+ reason="Sample-level obvious age anomaly",
551
+ confidence=0.55,
552
+ )
553
+ )
554
+ if treatment and death and death < treatment:
555
+ guessed_findings.append(
556
+ Finding(
557
+ patient_id=row.get("patient_id"),
558
+ error_type="temporal_inconsistency",
559
+ reason="Sample-level temporal anomaly",
560
+ confidence=0.60,
561
+ )
562
+ )
563
+ plan_actions = []
564
+ for variable in TASK_SPECS[task_id]["investigations"]:
565
+ plan_actions.append(AuditAction(action_type="investigate_pattern", variable=variable))
566
+ if task_id == "task_hard":
567
+ plan_actions.extend(
568
+ AuditAction(action_type="compute_distribution", variable=variable)
569
+ for variable in TASK_SPECS[task_id]["distributions"]
570
  )
 
 
 
 
 
571
 
572
+ max_flag_slots = max_steps - len(plan_actions) - 1
573
+ for finding in guessed_findings[:max_flag_slots]:
574
+ plan_actions.append(
575
+ AuditAction(
576
+ action_type="flag_error",
577
+ patient_id=finding.patient_id,
578
+ error_type=finding.error_type,
579
+ reason=finding.reason,
580
+ confidence=finding.confidence,
581
+ )
582
+ )
583
 
584
+ for action in plan_actions:
 
 
585
  if result.done:
586
  break
587
+ result = env.step(action)
588
+ obs = result.observation.model_dump()
589
+ final_score = obs["score_so_far"]
590
+ metrics.steps += 1
591
+ if action.action_type == "flag_error":
 
592
  metrics.record(obs["feedback"])
 
593
 
 
594
  if not result.done:
595
+ result = env.step(
596
+ AuditAction(
597
+ action_type="submit_report",
598
+ report=(
599
+ f"Protocol grounding for {rules.protocol_title}. "
600
+ "Sample review found possible age and timing issues. "
601
+ "Recommend manual review and corrective action."
602
+ ),
603
+ )
604
+ )
605
  obs = result.observation.model_dump()
606
  final_score = obs["score_so_far"]
607
  metrics.steps += 1
 
610
  return final_score, metrics
611
 
612
 
613
+ def run_heuristic_task(client_unused: Optional[OpenAI], task_id: str, task_name: str, seed: int):
614
  print(f"\n Task: {task_name}")
615
+ print(" " + "-" * 54)
 
616
  metrics = MetricsTracker()
617
  final_score = 0.0
618
 
619
+ with open_env_session() as env:
620
+ result = env.reset(task_id=task_id, seed=seed)
621
  obs = result.observation.model_dump()
622
  dataset = obs["dataset"]
623
+ rules = ProtocolParser.parse(obs["trial_protocol_excerpt"])
624
  max_steps = obs["attempts_remaining"]
625
+ print(f" Protocol: {rules.protocol_title} | Patients: {len(dataset)} | Max steps: {max_steps}")
626
 
627
+ actions: list[AuditAction] = []
628
+ for variable in TASK_SPECS[task_id]["investigations"]:
629
+ actions.append(AuditAction(action_type="investigate_pattern", variable=variable))
630
+ for variable in TASK_SPECS[task_id]["distributions"]:
631
+ actions.append(AuditAction(action_type="compute_distribution", variable=variable))
 
632
 
633
+ findings = []
634
+ for row in dataset:
635
+ age = row.get("age")
636
+ if age is None or age < (rules.age_min - 3) or age > (rules.age_max + 3):
637
+ findings.append(
638
+ Finding(
639
+ patient_id=row.get("patient_id"),
640
+ error_type="invalid_age",
641
+ reason=f"Heuristic age screen triggered on {age}",
642
+ confidence=0.82,
643
+ risk="high",
644
+ )
645
+ )
646
+ findings.extend(TemporalDetector().detect(dataset))
647
+ if task_id in {"task_medium", "task_hard"}:
648
+ findings.extend(ProtocolWindowDetector().detect(dataset, rules, ignore_stage_exception=True))
649
+ if task_id == "task_hard":
650
+ findings.extend(BiasAnalyzer().detect_heuristic(dataset, rules))
651
 
652
+ planner = ActionPlanner()
653
+ planned_flags = planner.plan(task_id, findings, max_steps=max_steps)
654
+ actions = planned_flags
655
+
656
+ for action in actions:
657
+ if result.done:
658
  break
659
+ result = env.step(action)
660
+ obs = result.observation.model_dump()
661
+ final_score = obs["score_so_far"]
662
+ metrics.steps += 1
663
+ if action.action_type == "flag_error":
 
664
  metrics.record(obs["feedback"])
665
+
666
  if not result.done:
667
+ result = env.step(
668
+ AuditAction(
669
+ action_type="submit_report",
670
+ report=(
671
+ f"Protocol review for {rules.protocol_title}. Root cause is likely data-entry drift. "
672
+ "Recommend validation checks and operational follow-up. Risk is moderate to high."
673
+ ),
674
+ )
675
+ )
676
  obs = result.observation.model_dump()
677
  final_score = obs["score_so_far"]
678
  metrics.steps += 1
 
681
  return final_score, metrics
682
 
683
 
684
+ def run_full_task(client: Optional[OpenAI], task_id: str, task_name: str, seed: int):
685
  print(f"\n Task: {task_name}")
686
+ print(" " + "-" * 54)
 
687
  metrics = MetricsTracker()
688
  final_score = 0.0
689
 
690
+ with open_env_session() as env:
691
+ result = env.reset(task_id=task_id, seed=seed)
692
  obs = result.observation.model_dump()
693
  dataset = obs["dataset"]
694
+ rules = ProtocolParser.parse(obs["trial_protocol_excerpt"])
695
  max_steps = obs["attempts_remaining"]
696
+ print(f" Protocol: {rules.protocol_title} | Patients: {len(dataset)} | Max steps: {max_steps}")
697
+ print(
698
+ f" Rules: age {rules.age_min}-{rules.age_max} | standard <= {rules.treatment_window_days}d | "
699
+ f"stage IV <= {rules.stage_iv_window_days}d"
700
+ )
 
701
 
702
+ findings = []
703
+ findings.extend(AgeDetector().detect(dataset, rules))
704
+ findings.extend(TemporalDetector().detect(dataset))
705
+ if task_id in {"task_medium", "task_hard"}:
706
+ findings.extend(ProtocolWindowDetector().detect(dataset, rules, ignore_stage_exception=False))
707
+ if task_id == "task_hard":
708
+ findings.extend(BiasAnalyzer().detect_full(dataset, rules))
709
+
710
+ age_count = sum(f.error_type == "invalid_age" for f in findings)
711
+ temporal_count = sum(f.error_type == "temporal_inconsistency" for f in findings)
712
+ window_count = sum(f.error_type == "protocol_window_violation" for f in findings)
713
+ bias_count = sum(f.error_type == "selection_bias" for f in findings)
714
+ print(
715
+ f" Detected: age={age_count} | temporal={temporal_count} | "
716
+ f"window={window_count} | bias={bias_count}"
717
+ )
718
+
719
+ extra_checks = {
720
+ "task_easy": ["enrollment_date", "stage", "group", "treatment_site", "country"],
721
+ "task_medium": ["group", "treatment_site", "outcome", "country", "drug"],
722
+ "task_hard": ["treatment_site", "group", "country", "drug", "trial_phase"],
723
+ }.get(task_id, [])
724
+ actions = ActionPlanner().plan(task_id, findings, max_steps=max_steps, extra_investigations=extra_checks)
725
+ report = generate_expert_report(client, rules, findings, task_name)
726
+ if client is not None:
727
+ metrics.llm_calls += 1
728
 
 
 
729
  for action in actions:
730
  if result.done:
731
  break
732
  result = env.step(action)
733
  obs = result.observation.model_dump()
734
  final_score = obs["score_so_far"]
735
+ metrics.steps += 1
 
 
736
  if action.action_type == "flag_error":
737
+ metrics.record(obs["feedback"])
738
+ if action.action_type == "flag_error" or metrics.steps <= 5:
739
+ print(f" Step {metrics.steps}: score={final_score:.2f} | {obs['feedback'][:80]}")
 
740
 
 
741
  if not result.done:
742
+ result = env.step(AuditAction(action_type="submit_report", report=report))
743
  obs = result.observation.model_dump()
744
  final_score = obs["score_so_far"]
745
+ metrics.steps += 1
746
+ print(f" Step {metrics.steps}: score={final_score:.2f} | report submitted")
 
747
 
748
  print(metrics.summary())
749
  return final_score, metrics
750
 
751
 
752
+ def run_agent(mode: str, client: Optional[OpenAI], seed: int):
 
753
  runner = {
754
  "naive": run_naive_task,
755
  "heuristic": run_heuristic_task,
756
  "full": run_full_task,
757
  }[mode]
758
 
759
+ scores = []
760
+ metrics_list = []
761
+ start = time.time()
762
+ for task_id, task_name in TASK_LIST.items():
763
+ score, metrics = runner(client, task_id, task_name, seed)
764
  scores.append(score)
765
+ metrics_list.append(metrics)
766
+ print(f" Final score: {score:.2f}\n")
 
767
 
768
  return {
769
  "mode": mode,
770
  "scores": dict(zip(TASK_LIST.keys(), scores)),
771
+ "average": sum(scores) / len(scores),
772
+ "elapsed": time.time() - start,
773
+ "total_steps": sum(metric.steps for metric in metrics_list),
774
+ "total_llm": sum(metric.llm_calls for metric in metrics_list),
775
+ "avg_precision": statistics.mean(metric.precision for metric in metrics_list),
776
  }
777
 
778
 
779
  def main():
780
+ parser = argparse.ArgumentParser(description="Clinical Trial Auditor baseline inference")
781
+ parser.add_argument("--mode", choices=["naive", "heuristic", "full", "all"], default="full")
782
+ parser.add_argument("--seed", type=int, default=BASELINE_SEED)
783
  args = parser.parse_args()
784
 
785
+ client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY) if API_KEY else None
786
 
787
+ print("=" * 70)
788
+ print(" Clinical Trial Auditor — Protocol-Aware Baseline Inference")
789
+ print(" Dynamic Rules | Adversarial Traps | Stage-Adjusted Fairness Review")
790
  print(f" Model: {MODEL_NAME}")
791
+ print(f" Seed: {args.seed}")
792
+ print("=" * 70)
793
+ if client is None:
794
+ print(" Note: no API key detected. Naive/full report generation will use deterministic fallbacks.")
795
 
796
+ modes = ["naive", "heuristic", "full"] if args.mode == "all" else [args.mode]
797
+ results = []
 
 
 
 
798
  for mode in modes:
799
+ print(f"\n{'═' * 70}")
800
  print(f" AGENT: {mode.upper()}")
801
+ print(f"{'═' * 70}")
802
+ results.append(run_agent(mode, client, args.seed))
 
803
 
804
+ print("\n" + "=" * 70)
 
805
  print(" BENCHMARK RESULTS")
806
+ print("=" * 70)
807
+ if len(results) > 1:
808
+ header = f" {'Agent':<12} {'Easy':>8} {'Medium':>8} {'Hard':>8} {'Avg':>8} {'Prec':>8} {'Time':>8}"
 
 
809
  print(header)
810
+ print(" " + "-" * 66)
811
+ for result in results:
812
+ scores = result["scores"]
813
+ print(
814
+ f" {result['mode'].upper():<12} "
815
+ f"{scores['task_easy']:.2f} {scores['task_medium']:.2f} "
816
+ f"{scores['task_hard']:.2f} {result['average']:.2f} "
817
+ f"{result['avg_precision']:.0%} {result['elapsed']:.1f}s"
818
+ )
 
819
  else:
820
+ result = results[0]
821
+ for task_id, task_name in TASK_LIST.items():
822
+ print(f" {task_name:38s}: {result['scores'][task_id]:.2f}")
823
+ print(f"\n Average score: {result['average']:.2f}")
824
+ print(f" Total time: {result['elapsed']:.1f}s")
825
+ print(f" LLM calls: {result['total_llm']}")
826
+ print(f" Total steps: {result['total_steps']}")
827
+ print(f" Average precision: {result['avg_precision']:.0%}")
828
+ print("=" * 70)
 
 
829
 
830
 
831
  if __name__ == "__main__":
832
+ main()
models.py CHANGED
@@ -2,6 +2,7 @@ from typing import Optional, List, Dict, Any
2
  from pydantic import Field
3
  from openenv.core.env_server import Action, Observation, State
4
 
 
5
  class AuditAction(Action):
6
  action_type: str = "flag_error"
7
  patient_id: Optional[str] = None
@@ -12,30 +13,43 @@ class AuditAction(Action):
12
  report: Optional[str] = None
13
  confidence: Optional[float] = None # 0.0-1.0: agent's confidence in this action
14
 
 
15
  class AuditObservation(Observation):
16
  done: bool = False
17
  reward: float = 0.0
18
  task_id: str = ""
19
  task_type: str = ""
20
  task_description: str = ""
 
 
21
  dataset: List[Dict[str, Any]] = Field(default_factory=list)
22
  errors_found: List[str] = Field(default_factory=list)
23
  patterns_investigated: List[str] = Field(default_factory=list)
24
  distributions_computed: List[str] = Field(default_factory=list)
25
  feedback: Optional[str] = None
26
  score_so_far: float = 0.0
 
 
27
  attempts_remaining: int = 15
28
  phase: str = "investigation"
29
 
 
30
  class AuditState(State):
31
  episode_id: str = ""
32
  step_count: int = 0
33
  task_id: str = ""
34
  task_type: str = ""
 
 
35
  total_errors: int = 0
36
  errors_found: int = 0
37
  current_score: float = 0.0
 
 
 
 
38
  attempts: int = 0
39
  phase: str = "investigation"
 
40
  patterns_investigated: List[str] = Field(default_factory=list)
41
- distributions_computed: List[str] = Field(default_factory=list)
 
2
  from pydantic import Field
3
  from openenv.core.env_server import Action, Observation, State
4
 
5
+
6
  class AuditAction(Action):
7
  action_type: str = "flag_error"
8
  patient_id: Optional[str] = None
 
13
  report: Optional[str] = None
14
  confidence: Optional[float] = None # 0.0-1.0: agent's confidence in this action
15
 
16
+
17
  class AuditObservation(Observation):
18
  done: bool = False
19
  reward: float = 0.0
20
  task_id: str = ""
21
  task_type: str = ""
22
  task_description: str = ""
23
+ protocol_title: str = ""
24
+ trial_protocol_excerpt: str = ""
25
  dataset: List[Dict[str, Any]] = Field(default_factory=list)
26
  errors_found: List[str] = Field(default_factory=list)
27
  patterns_investigated: List[str] = Field(default_factory=list)
28
  distributions_computed: List[str] = Field(default_factory=list)
29
  feedback: Optional[str] = None
30
  score_so_far: float = 0.0
31
+ dense_reward_total: float = 0.0
32
+ score_breakdown: Dict[str, float] = Field(default_factory=dict)
33
  attempts_remaining: int = 15
34
  phase: str = "investigation"
35
 
36
+
37
  class AuditState(State):
38
  episode_id: str = ""
39
  step_count: int = 0
40
  task_id: str = ""
41
  task_type: str = ""
42
+ protocol_title: str = ""
43
+ trial_protocol_excerpt: str = ""
44
  total_errors: int = 0
45
  errors_found: int = 0
46
  current_score: float = 0.0
47
+ dense_reward_total: float = 0.0
48
+ correct_flags: int = 0
49
+ false_positives: int = 0
50
+ duplicate_flags: int = 0
51
  attempts: int = 0
52
  phase: str = "investigation"
53
+ score_breakdown: Dict[str, float] = Field(default_factory=dict)
54
  patterns_investigated: List[str] = Field(default_factory=list)
55
+ distributions_computed: List[str] = Field(default_factory=list)
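For reference, the observation and action fields added in this commit can be mirrored with plain dataclasses. This is illustrative only — the server uses the pydantic `Action`/`Observation` base classes from `openenv.core.env_server`, and only a subset of fields is shown:

```python
from dataclasses import dataclass, field
from typing import Optional, Dict

@dataclass
class AuditActionSketch:
    action_type: str = "flag_error"
    patient_id: Optional[str] = None
    error_type: Optional[str] = None
    reason: Optional[str] = None
    report: Optional[str] = None
    confidence: Optional[float] = None  # 0.0-1.0

@dataclass
class AuditObservationSketch:
    done: bool = False
    reward: float = 0.0
    protocol_title: str = ""
    trial_protocol_excerpt: str = ""
    score_so_far: float = 0.0
    dense_reward_total: float = 0.0
    score_breakdown: Dict[str, float] = field(default_factory=dict)
    attempts_remaining: int = 15
```

The new `protocol_title` / `trial_protocol_excerpt` fields are what make each episode's rules dynamic: agents must parse the excerpt rather than assume a fixed 18-120 age range.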
openenv.yaml CHANGED
@@ -1,27 +1,29 @@
1
  name: clinical_trial_auditor
2
- version: "2.0.0"
3
  description: >
4
- A production-grade Reinforcement Learning environment for medical AI alignment.
5
- The agent acts as a Senior Clinical Data Manager, utilizing a strict multi-phase
6
- workflow (Investigation Flagging Reporting) to identify syntactic errors,
7
- temporal violations, and multi-dimensional intersectional bias in trial datasets.
 
8
  author: Sumit Saraswat
9
  tags:
10
  - openenv
11
  - clinical
12
- - rl-benchmark
13
- - medical-bias
 
14
  - ai-safety
15
  tasks:
16
  - id: task_easy
17
- name: Syntactic Cleaning
18
  difficulty: easy
19
- description: Investigate dataset distribution and flag patients with invalid age values (out of 18-120 range or null).
20
  - id: task_medium
21
- name: Temporal Consistency
22
  difficulty: medium
23
- description: Investigate temporal variables and flag patients where death_date precedes treatment_start.
24
  - id: task_hard
25
- name: Equity Bias Audit
26
  difficulty: hard
27
- description: Perform multi-dimensional statistical analysis to detect intersectional selection bias affecting minority control group outcomes.
 
1
  name: clinical_trial_auditor
2
+ version: "3.0.0"
3
  description: >
4
+ A protocol-aware clinical audit benchmark for OpenEnv. The agent acts as a Senior
5
+ Clinical Data Manager and must read an episode-specific protocol excerpt, audit
6
+ tabular patient records against dynamic eligibility and timing rules, and decide
7
+ whether suspicious subgroup outcomes represent actionable control-arm bias or a
8
+ confounded high-risk cohort.
9
  author: Sumit Saraswat
10
  tags:
11
  - openenv
12
  - clinical
13
+ - benchmark
14
+ - protocol-reasoning
15
+ - bias-audit
16
  - ai-safety
17
  tasks:
18
  - id: task_easy
19
+ name: Dynamic Eligibility Screening
20
  difficulty: easy
21
+ description: Read the protocol excerpt for the episode and flag patients whose ages violate the protocol-specific eligibility range.
22
  - id: task_medium
23
+ name: Protocol Timeline Audit
24
  difficulty: medium
25
+ description: Audit dynamic age eligibility, death-before-treatment errors, and treatment-start window violations with a Stage IV timing exception.
26
  - id: task_hard
27
+ name: Equity + Protocol Audit
28
  difficulty: hard
29
+ description: Audit record-level protocol issues and determine whether control-arm bias is genuinely present or only confounded by a high-risk outreach cohort.
server/requirements.txt → requirements.txt RENAMED
File without changes
server.log CHANGED
@@ -34,3 +34,7 @@ INFO: connection closed
34
  INFO: 127.0.0.1:53804 - "WebSocket /ws" [accepted]
35
  INFO: connection open
36
  INFO: connection closed
34
  INFO: 127.0.0.1:53804 - "WebSocket /ws" [accepted]
35
  INFO: connection open
36
  INFO: connection closed
37
+ INFO: 127.0.0.1:56934 - "GET /health HTTP/1.1" 200 OK
38
+ INFO: 127.0.0.1:56965 - "WebSocket /ws" [accepted]
39
+ INFO: connection open
40
+ INFO: connection closed
server/app.py CHANGED
@@ -1,7 +1,12 @@
1
- from openenv.core.env_server import create_fastapi_app
2
- from clinical_trial_auditor_environment import ClinicalTrialAuditorEnvironment
3
- from models import AuditAction, AuditObservation
4
  import uvicorn
 
 
6
  app = create_fastapi_app(ClinicalTrialAuditorEnvironment, AuditAction, AuditObservation)
7
 
@@ -9,4 +14,4 @@ def main():
9
  uvicorn.run(app, host="0.0.0.0", port=8000)
10
 
11
  if __name__ == "__main__":
12
- main()
 
 
 
 
1
  import uvicorn
2
+ from openenv.core.env_server import create_fastapi_app
3
+
4
+ try:
5
+ from .clinical_trial_auditor_environment import ClinicalTrialAuditorEnvironment
6
+ from .models import AuditAction, AuditObservation
7
+ except ImportError:
8
+ from clinical_trial_auditor_environment import ClinicalTrialAuditorEnvironment
9
+ from models import AuditAction, AuditObservation
10
 
11
  app = create_fastapi_app(ClinicalTrialAuditorEnvironment, AuditAction, AuditObservation)
12
 
 
14
  uvicorn.run(app, host="0.0.0.0", port=8000)
15
 
16
  if __name__ == "__main__":
17
+ main()
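The try/except import in `server/app.py` lets the same module run both as a package (the Docker / Hugging Face layout this commit targets) and from a flat local checkout. The same pattern, generalized into a hypothetical helper and demonstrated with a stdlib module for testability:

```python
import importlib

def import_with_fallback(package: str, module: str):
    """Prefer `package.module`; fall back to the top-level `module`."""
    try:
        return importlib.import_module(f"{package}.{module}")
    except ImportError:  # ModuleNotFoundError subclasses ImportError
        return importlib.import_module(module)
```

This mirrors why the Dockerfile's CMD changed to `server.app:app`: inside the image the code is imported as a package, while local scripts still import the flat module names.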
server/clinical_trial_auditor_environment.py CHANGED
@@ -1,152 +1,102 @@
1
  """
2
  Clinical Trial Auditor — OpenEnv Environment
3
- =============================================
4
- A production-grade adversarial RL environment for medical AI alignment
5
- and clinical data quality evaluation.
6
-
7
- The agent acts as a Senior Clinical Data Manager auditing procedurally
8
- generated clinical trial datasets from a multi-site Phase III oncology trial.
9
-
10
- Architecture layers:
11
- ┌─────────────────────────────────────────────┐
12
- │ Agent Interface (OpenEnv API) │
13
- │ step() / reset() / state() │
14
- ├─────────────────────────────────────────────┤
15
- │ Scoring Engine (Grader) │
16
- │ Ground-truth comparison, partial credit, │
17
- │ confidence calibration, score composition │
18
- ├─────────────────────────────────────────────┤
19
- │ Trap Engine (Adversarial) │
20
- │ Boundary traps, temporal traps, fake │
21
- │ bias patterns, distractor injection │
22
- ├─────────────────────────────────────────────┤
23
- │ Data Engine (Generator) │
24
- │ Statistical distributions, demographics, │
25
- │ reproducible seeds, configurable params │
26
- └─────────────────────────────────────────────┘
27
-
28
- Key design decisions:
29
- - Procedural generation: every reset() → unique dataset → no memorization
30
- - Ground-truth grading: errors are pre-computed, grading is O(1) lookup
31
- - Confidence-calibrated scoring: overconfident + wrong = devastating penalty
32
- - False positive cost 3× correct reward → forces precision over recall
33
- - Adversarial traps: boundary-valid ages, near-temporal cases, fake patterns
34
- - Multi-phase workflow: Investigation → Flagging → Reporting
35
- - Seed-based reproducibility for deterministic evaluation
36
  """
 
 
 
37
  import uuid
38
  from datetime import datetime
 
39
  from openenv.core.env_server import Environment
40
- from models import AuditAction, AuditObservation, AuditState
41
- from dataset_generator import DatasetGenerator
42
 
43
- # ── Reward Configuration ──────────────────────────────────────────────────
44
- # Calibrated: optimal play → ~0.85-0.95, careless play → devastated
45
- # Key design: false_positive = 3× correct_flag → DESTROYS guessing strategies
46
  REWARD_CONFIG = {
47
- "correct_flag": 0.10, # +0.10 per correct error flag
48
- "false_positive": -0.30, # -0.30 per wrong flag (3x correct → destroys guessing)
49
- "duplicate_flag": -0.10, # -0.10 per duplicate flag
50
- "investigate_new": 0.05, # +0.05 for investigating a new variable
51
- "investigate_redundant": -0.02, # -0.02 for re-investigating (penalizes loops)
52
- "distribution_new": 0.04, # +0.04 for computing new distribution
53
  "distribution_redundant": -0.02,
54
- "invalid_phase": -0.05, # -0.05 for acting in wrong phase
55
- "unknown_action": -0.05, # -0.05 for invalid action types
56
- "cost_per_step": 0.005, # -0.005 per step (encourages efficiency)
57
- "bonus_efficiency": 0.03, # +0.03 when ≥3 investigated AND ≥3 flagged
58
- "bonus_workflow": 0.03, # +0.03 for correct workflow sequence
59
- "bias_detected": 0.20, # +0.20 for correctly identifying selection bias
60
- "propose_fix_valid": 0.03,
61
- "propose_fix_invalid": -0.05,
62
- "report_bonus_base": 0.05, # +0.05 base for submitting report
63
- "overconfidence_multiplier": 2.0, # 2x penalty when wrong + confidence > 0.8
64
  }
65
 
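How these flagging entries compose per action, including the overconfidence rule from the comments above, can be sketched as follows (reward values are copied from REWARD_CONFIG; the helper itself is illustrative):

```python
from typing import Optional

def flag_reward(is_correct: bool, confidence: Optional[float]) -> float:
    """Per-flag reward: +0.10 when correct; -0.30 when wrong, doubled to
    -0.60 when the wrong flag was issued with confidence > 0.8."""
    if is_correct:
        return 0.10
    penalty = -0.30  # false positive costs 3x a correct flag
    if confidence is not None and confidence > 0.8:
        penalty *= 2.0  # overconfidence_multiplier
    return penalty
```

The asymmetry is the point: a confidently wrong flag erases six correct ones, so blanket-flagging strategies score below zero.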
66
- # ═══════════════════════════════════════════════════════════════════════════
67
- # TASK DEFINITIONS
68
- # ═══════════════════════════════════════════════════════════════════════════
 
 
 
 
69
 
70
  TASKS = {
71
  "task_easy": {
72
  "task_id": "task_easy",
73
- "task_type": "syntactic_cleaning",
74
  "difficulty": "easy",
 
 
75
  "allow_bias": False,
76
- "task_description": (
77
- "CLINICAL DATA AUDIT — Phase III Oncology Trial (ONCO-AX-2024)\n"
78
- "Role: Senior Clinical Data Manager\n\n"
79
- "PHASE 1 — INVESTIGATION:\n"
80
- " Use investigate_pattern(variable=<col>) to profile key variables\n"
81
- " Use compute_distribution(variable=<col>) to compute descriptive stats\n\n"
82
- "PHASE 2 — FLAGGING:\n"
83
- " Use flag_error(patient_id=<id>, error_type='invalid_age') for age violations\n"
84
- " Valid age range for trial eligibility: 18-120 (inclusive)\n"
85
- " Missing age (null) is also invalid — required field\n"
86
- " CAUTION: Some ages are rare but valid (e.g., 95, 19, 120). Do NOT over-flag.\n\n"
87
- "PHASE 3 — REPORTING:\n"
88
- " Use submit_report(report=<comprehensive analysis>) to finalize\n\n"
89
- "Objective: Find ALL patients with invalid ages. Avoid false positives."
90
- ),
91
  },
92
  "task_medium": {
93
  "task_id": "task_medium",
94
- "task_type": "temporal_consistency",
95
  "difficulty": "medium",
 
 
96
  "allow_bias": False,
97
- "task_description": (
98
- "CLINICAL DATA AUDIT — Phase III Oncology Trial (ONCO-AX-2024)\n"
99
- "Role: Senior Clinical Data Manager\n\n"
100
- "PHASE 1 — INVESTIGATION:\n"
101
- " Use investigate_pattern(variable=<col>) to profile key variables\n"
102
- " Use compute_distribution(variable=<col>) to compute descriptive stats\n\n"
103
- "PHASE 2 — FLAGGING:\n"
104
- " Use flag_error with error_type='invalid_age' OR 'temporal_inconsistency'\n"
105
- " Age violations: outside range 18-120 (inclusive) or null\n"
106
- " Temporal violations: death_date MUST NOT precede treatment_start\n"
107
- " NOTE: A patient dying 1 day after treatment start IS valid (not an error)\n\n"
108
- "PHASE 3 — REPORTING:\n"
109
- " Use submit_report(report=<comprehensive analysis>) to finalize\n\n"
110
- "Objective: Find ALL age errors AND temporal inconsistencies."
111
- ),
112
  },
113
  "task_hard": {
114
  "task_id": "task_hard",
115
- "task_type": "comprehensive_audit",
116
  "difficulty": "hard",
 
 
117
  "allow_bias": True,
118
- "task_description": (
119
- "CLINICAL DATA AUDIT — Phase III Oncology Trial (ONCO-AX-2024)\n"
120
- "Role: Senior Clinical Data Manager\n\n"
121
- "PHASE 1 — INVESTIGATION:\n"
122
- " Use investigate_pattern(variable=<col>) to profile key variables\n"
123
- " Use compute_distribution(variable=<col>) to compute descriptive stats\n"
124
- " IMPORTANT: Analyze ethnicity, gender, and outcome distributions in control group\n\n"
125
- "PHASE 2 — FLAGGING:\n"
126
- " flag_error with error_type='invalid_age', 'temporal_inconsistency', or 'selection_bias'\n"
127
- " For selection_bias: Identify if the control group has demographic imbalance\n"
128
- " AND whether this correlates with outcome disparity across subgroups\n"
129
- " Look for: representation bias, outcome disparity, intersectional patterns\n\n"
130
- "PHASE 3 — REPORTING:\n"
131
- " Use submit_report(report=<comprehensive analysis>) to finalize\n"
132
- " Include: statistical evidence, root cause analysis, corrective recommendations\n\n"
133
- "Objective: Find ALL data errors AND demographic bias patterns."
134
- ),
135
  },
136
  }
137
 
138
- # Maximum steps per episode — scales with dataset size
139
  MAX_STEPS = {
140
- "task_easy": 20,
141
- "task_medium": 30,
142
- "task_hard": 40,
143
  }
144
 
145
 
146
- # ═══════════════════════════════════════════════════════════════════════════
147
- # ENVIRONMENT IMPLEMENTATION
148
- # ═══════════════════════════════════════════════════════════════════════════
149
-
150
  class ClinicalTrialAuditorEnvironment(Environment):
151
  SUPPORTS_CONCURRENT_SESSIONS = True
152
 
@@ -155,9 +105,12 @@ class ClinicalTrialAuditorEnvironment(Environment):
         self._state = AuditState()
         self._current_task = None
         self._dataset = []
-        self._ground_truth = {}  # {patient_id: [error_types]}
-        self._traps = set()  # valid-but-suspicious patient_ids
         self._bias_present = False
         self._flagged_patients = set()
         self._patterns_investigated = set()
         self._distributions_computed = set()
@@ -165,17 +118,185 @@ class ClinicalTrialAuditorEnvironment(Environment):
         self._max_steps = 15
         self._report_submitted = False
         self._phase = "investigation"
-        self._score_log = []  # Track score composition for transparency
     def reset(self, seed=None, episode_id=None, **kwargs) -> AuditObservation:
-        """
-        Reset the environment with a procedurally generated dataset.
-
-        Args:
-            seed: Random seed for reproducibility. Same seed = identical dataset.
-            episode_id: Optional episode identifier.
-            task_id: "task_easy" | "task_medium" | "task_hard"
-        """
         self._action_history = []
         task_id = kwargs.get("task_id", "task_easy")
         if task_id not in TASKS:
@@ -184,7 +305,6 @@ class ClinicalTrialAuditorEnvironment(Environment):
         self._current_task = TASKS[task_id]
         difficulty = self._current_task["difficulty"]
 
-        # ── Procedural dataset generation ──
         generator = DatasetGenerator(seed=seed)
         result = generator.generate(difficulty=difficulty)
 
@@ -192,7 +312,9 @@ class ClinicalTrialAuditorEnvironment(Environment):
         self._ground_truth = result["ground_truth"]
         self._traps = result["traps"]
         self._bias_present = result["bias_present"]
-        gen_stats = result["stats"]
 
         self._flagged_patients = set()
         self._patterns_investigated = set()
@@ -202,21 +324,32 @@ class ClinicalTrialAuditorEnvironment(Environment):
         self._report_submitted = False
         self._phase = "investigation"
         self._score_log = []
-
-        total_errs = gen_stats["total_errors"]
 
         self._state = AuditState(
             episode_id=episode_id or str(uuid.uuid4()),
             step_count=0,
             task_id=task_id,
             task_type=self._current_task["task_type"],
-            total_errors=total_errs,
             errors_found=0,
             current_score=0.0,
             attempts=0,
             phase="investigation",
             patterns_investigated=[],
             distributions_computed=[],
         )
 
         return AuditObservation(
@@ -224,17 +357,20 @@ class ClinicalTrialAuditorEnvironment(Environment):
             reward=0.0,
             task_id=task_id,
             task_type=self._current_task["task_type"],
-            task_description=self._current_task["task_description"],
             dataset=self._dataset,
             errors_found=[],
             patterns_investigated=[],
             distributions_computed=[],
             feedback=(
-                f"Audit started. Dataset: {len(self._dataset)} patients across "
-                f"multiple sites and countries. Begin with investigate_pattern "
-                f"to profile the dataset."
             ),
             score_so_far=0.0,
             attempts_remaining=self._max_steps,
             phase="investigation",
         )
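The removed `reset()` docstring above promises seed reproducibility: the same seed must yield an identical dataset. Outside the diff, the usual pattern behind that guarantee can be sketched with the standard library alone; `make_cohort` here is a hypothetical stand-in for `DatasetGenerator.generate`, not code from this commit:

```python
import random

def make_cohort(seed, n=5):
    # Hypothetical stand-in for DatasetGenerator: all randomness flows
    # through a single PRNG instance seeded at construction time, so the
    # same seed always replays the same sequence of draws.
    rng = random.Random(seed)
    return [
        {
            "patient_id": f"P{i:03d}",
            "age": rng.randint(18, 90),
            "group": rng.choice(["control", "treatment"]),
        }
        for i in range(n)
    ]

a, b = make_cohort(42), make_cohort(42)
print(a == b)  # True: same seed, byte-identical cohort
```

The key design point is that no draw bypasses `rng`; any call to the module-level `random` functions would break determinism under concurrent sessions.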
@@ -242,11 +378,23 @@ class ClinicalTrialAuditorEnvironment(Environment):
     def step(self, action: AuditAction, **kwargs) -> AuditObservation:
         if self._current_task is None:
             return AuditObservation(
-                done=True, reward=0.0, task_id="", task_type="",
-                task_description="Call reset() first.", dataset=[],
-                errors_found=[], patterns_investigated=[],
-                distributions_computed=[], feedback="No active episode.",
-                score_so_far=0.0, attempts_remaining=0, phase="investigation",
             )
 
         self._action_history.append(action.action_type)
@@ -254,72 +402,50 @@ class ClinicalTrialAuditorEnvironment(Environment):
         self._state.step_count += 1
         self._state.attempts = self._attempts
 
-        # Core grading against ground truth
         step_reward, feedback = self._grade(action)
 
-        # ── Confidence-calibrated scoring ──
-        agent_confidence = action.confidence
-        if agent_confidence is not None and action.action_type == "flag_error":
-            agent_confidence = max(0.0, min(1.0, agent_confidence))
-            if step_reward < 0:  # Wrong answer
-                if agent_confidence > 0.8:
-                    step_reward *= REWARD_CONFIG["overconfidence_multiplier"]
-                    feedback += f" [OVERCONFIDENCE PENALTY: conf={agent_confidence:.0%}]"
-            elif step_reward > 0:  # Correct answer
-                step_reward *= max(0.5, agent_confidence)
-
-        # Step cost (progressive — later steps cost more)
-        step_cost = REWARD_CONFIG["cost_per_step"] * (1 + self._attempts * 0.05)
-        step_reward -= step_cost
-
-        # Anti brute-force (punish spinning without flagging)
-        if self._attempts > self._max_steps // 2 and len(self._flagged_patients) < 3:
-            step_reward -= 0.05
-
-        # Efficiency bonus
-        if len(self._patterns_investigated) >= 3 and len(self._flagged_patients) >= 3:
-            step_reward += REWARD_CONFIG["bonus_efficiency"]
-
-        # Workflow sequence bonus
-        if len(self._action_history) >= 3:
-            if self._action_history[-3:] == [
-                "investigate_pattern", "compute_distribution", "flag_error"
-            ]:
-                step_reward += REWARD_CONFIG["bonus_workflow"]
-
-        # Difficulty multiplier
-        mult = {
-            "task_easy": 1.0, "task_medium": 1.2, "task_hard": 1.5
-        }.get(self._current_task["task_id"], 1.0)
-        step_reward = round(step_reward * mult, 3)
-        step_reward = max(-0.5, step_reward)
-
-        self._state.current_score = max(
-            0.0, min(1.0, self._state.current_score + step_reward)
-        )
 
-        # Log score composition
-        self._score_log.append({
-            "step": self._attempts,
-            "action": action.action_type,
-            "reward": step_reward,
-            "cumulative": self._state.current_score,
-        })
 
         done = self._report_submitted or self._attempts >= self._max_steps
 
         return AuditObservation(
             done=done,
-            reward=step_reward,
             task_id=self._current_task["task_id"],
             task_type=self._current_task["task_type"],
-            task_description=self._current_task["task_description"],
             dataset=self._dataset,
-            errors_found=list(self._flagged_patients),
-            patterns_investigated=list(self._patterns_investigated),
-            distributions_computed=list(self._distributions_computed),
             feedback=feedback,
             score_so_far=self._state.current_score,
             attempts_remaining=max(0, self._max_steps - self._attempts),
             phase=self._phase,
         )
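The confidence-calibration rule removed from `step()` above is compact enough to show in isolation. This sketch is not part of the commit; it mirrors the deleted logic (applied only to `flag_error` actions there) with the relevant `REWARD_CONFIG` values inlined as defaults:

```python
def calibrate(step_reward, confidence, overconfidence_multiplier=1.8):
    """Scale a graded step reward by the agent's stated confidence.

    Mirrors the removed logic: confident mistakes are amplified by the
    overconfidence multiplier, and correct answers are scaled by
    confidence with a floor of 0.5 so hedging cannot halve-and-hide.
    """
    if confidence is None:
        return step_reward
    confidence = max(0.0, min(1.0, confidence))
    if step_reward < 0 and confidence > 0.8:
        return step_reward * overconfidence_multiplier
    if step_reward > 0:
        return step_reward * max(0.5, confidence)
    return step_reward

print(round(calibrate(-0.26, 0.95), 3))  # -0.468: overconfident false positive
print(round(calibrate(0.16, 0.9), 3))    # 0.144: correct flag scaled by confidence
```

The asymmetry is the point: a 95%-confident false positive costs almost twice a hedged one, so the agent is rewarded for calibrated uncertainty rather than bravado.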
@@ -328,399 +454,297 @@ class ClinicalTrialAuditorEnvironment(Environment):
     def state(self) -> AuditState:
         return self._state
 
-    # ═══════════════════════════════════════════════════════════════════
-    # SCORING ENGINE — Deterministic grading against ground truth
-    # ═══════════════════════════════════════════════════════════════════
 
-    def _grade(self, action: AuditAction):
-        """Route action to appropriate grader with phase validation."""
-        # Phase validation
-        if self._phase == "investigation" and action.action_type in [
-            "flag_error", "submit_report"
-        ]:
             return (
                 REWARD_CONFIG["invalid_phase"],
-                "PHASE BLOCKED: Investigate variables before flagging. "
-                "Use investigate_pattern or compute_distribution first."
-            )
-        if (self._phase == "flagging"
-                and action.action_type == "submit_report"
-                and len(self._flagged_patients) == 0):
-            return (
-                REWARD_CONFIG["invalid_phase"],
-                "PHASE BLOCKED: Flag at least one issue before submitting report."
             )
 
         if action.action_type == "investigate_pattern":
             return self._grade_investigate(action)
-        elif action.action_type == "compute_distribution":
             return self._grade_distribution(action)
-        elif action.action_type == "flag_error":
             return self._grade_flag(action)
-        elif action.action_type == "propose_fix":
             return self._grade_propose_fix(action)
-        elif action.action_type == "submit_report":
             return self._grade_report(action)
-        else:
-            return (
-                REWARD_CONFIG["unknown_action"],
-                f"REJECTED: Unknown action '{action.action_type}'. "
-                f"Valid: investigate_pattern, compute_distribution, "
-                f"flag_error, propose_fix, submit_report."
-            )
-    def _grade_investigate(self, action: AuditAction):
-        variable = action.variable or ""
-        if not variable:
-            return REWARD_CONFIG["unknown_action"], "REJECTED: Variable cannot be empty."
 
         valid_vars = {
-            "age", "gender", "ethnicity", "treatment_start",
-            "death_date", "outcome", "treatment_site", "group",
-            "stage", "trial_phase", "drug", "country", "enrollment_date",
         }
-
         if variable not in valid_vars:
-            return (
-                REWARD_CONFIG["unknown_action"],
-                f"REJECTED: Unknown variable '{variable}'. "
-                f"Valid: {', '.join(sorted(valid_vars))}."
-            )
-
         if variable in self._patterns_investigated:
             return (
                 REWARD_CONFIG["investigate_redundant"],
-                f"Already investigated '{variable}'. Use flag_error to act on findings."
             )
 
         self._patterns_investigated.add(variable)
-        self._state.patterns_investigated.append(variable)
-
-        # Phase transition: unlock flagging after investigating key variables
-        if (
-            "age" in self._patterns_investigated
-            and "death_date" in self._patterns_investigated
-            and self._phase == "investigation"
-        ):
-            self._phase = "flagging"
 
-        # Dynamic statistics based on variable type
         if variable == "age":
-            ages = [p["age"] for p in self._dataset if p.get("age") is not None]
-            nulls = len([p for p in self._dataset if p.get("age") is None])
-            if ages:
-                min_age, max_age = min(ages), max(ages)
-                feedback = (
-                    f"Age Stats: min={min_age}, max={max_age}, "
-                    f"null_count={nulls}, n={len(ages)}."
-                )
-            else:
-                feedback = f"Age Stats: no valid ages found, null_count={nulls}."
-        elif variable in ["treatment_start", "death_date", "enrollment_date"]:
-            vals = [p[variable] for p in self._dataset if p.get(variable)]
-            feedback = f"Date field '{variable}': {len(vals)} non-null values found. Check temporal alignment."
-        elif variable == "outcome":
-            survived = sum(1 for p in self._dataset if p.get("outcome") == "survived")
-            deceased = sum(1 for p in self._dataset if p.get("outcome") == "deceased")
-            feedback = f"Outcomes: Survived={survived}, Deceased={deceased}, Total={survived + deceased}."
-        elif variable == "group":
-            control = sum(1 for p in self._dataset if p.get("group") == "control")
-            treatment = sum(1 for p in self._dataset if p.get("group") == "treatment")
-            feedback = f"Groups: Control={control}, Treatment={treatment}."
         else:
             counts = {}
-            for p in self._dataset:
-                val = str(p.get(variable, "None"))
-                counts[val] = counts.get(val, 0) + 1
-            # Sort by frequency descending
-            sorted_counts = dict(
-                sorted(counts.items(), key=lambda x: -x[1])
-            )
-            # Truncate if too many unique values
-            if len(sorted_counts) > 10:
-                top_10 = dict(list(sorted_counts.items())[:10])
-                feedback = (
-                    f"{variable.capitalize()} Distribution (top 10 of "
-                    f"{len(sorted_counts)}): {top_10}."
-                )
-            else:
-                feedback = f"{variable.capitalize()} Distribution: {sorted_counts}."
-
-        return REWARD_CONFIG["investigate_new"], f"Investigated '{variable}': {feedback}"
-
-    def _grade_distribution(self, action: AuditAction):
         variable = action.variable or ""
         if not variable:
             return REWARD_CONFIG["unknown_action"], "REJECTED: Variable cannot be empty."
-
         if variable in self._distributions_computed:
             return (
                 REWARD_CONFIG["distribution_redundant"],
-                f"Distribution for '{variable}' already computed."
             )
 
         self._distributions_computed.add(variable)
-        self._state.distributions_computed.append(variable)
-
-        # Phase transition via distribution analysis
-        if (
-            "ethnicity" in self._distributions_computed
-            and "outcome" in self._distributions_computed
-            and self._phase == "investigation"
-        ):
-            self._phase = "flagging"
 
         if variable == "ethnicity":
-            control = [p for p in self._dataset if p.get("group") == "control"]
-            if control:
-                eth_counts = {}
-                for p in control:
-                    eth = p.get("ethnicity", "Unknown")
-                    eth_counts[eth] = eth_counts.get(eth, 0) + 1
-                total = len(control)
-                breakdown = ", ".join(
-                    f"{k}={v} ({v / total * 100:.0f}%)"
-                    for k, v in sorted(eth_counts.items(), key=lambda x: -x[1])
-                )
-                feedback = f"Control group ethnicity: {breakdown}. Total={total}."
-            else:
-                feedback = "No control group patients found."
-        elif variable == "outcome":
-            control = [p for p in self._dataset if p.get("group") == "control"]
-            if control:
-                deceased_c = sum(
-                    1 for p in control if p.get("outcome") == "deceased"
-                )
-                total = len(control)
-                feedback = (
-                    f"Control group outcomes: deceased={deceased_c}/{total} "
-                    f"({deceased_c / total * 100:.0f}%). "
-                    f"Survived={total - deceased_c}/{total} "
-                    f"({(total - deceased_c) / total * 100:.0f}%)."
-                )
-            else:
-                feedback = "No control group patients found."
         elif variable == "gender":
-            control = [p for p in self._dataset if p.get("group") == "control"]
-            if control:
-                male_c = sum(1 for p in control if p.get("gender") == "M")
-                total = len(control)
-                feedback = (
-                    f"Control group gender: Male={male_c}/{total} "
-                    f"({male_c / total * 100:.0f}%), "
-                    f"Female={total - male_c}/{total} "
-                    f"({(total - male_c) / total * 100:.0f}%)."
-                )
-            else:
-                feedback = "No control group patients found."
         else:
-            feedback = f"Distribution computed for '{variable}'."
 
-        return REWARD_CONFIG["distribution_new"], f"Distribution '{variable}': {feedback}"
 
-    def _grade_flag(self, action: AuditAction):
-        """Grade flag action against pre-computed ground truth."""
-        patient_id = action.patient_id
         error_type = action.error_type or ""
 
-        # ── Selection bias flag (no patient_id needed) ──
         if error_type == "selection_bias":
             if not self._current_task["allow_bias"]:
                 return (
-                    REWARD_CONFIG["false_positive"],
-                    "✗ Selection bias analysis not required for this task."
                 )
-
             if "BIAS_FLAG" in self._flagged_patients:
-                return (
-                    REWARD_CONFIG["duplicate_flag"],
-                    "Selection bias already flagged."
-                )
 
-            if self._bias_present:
-                # Verify bias is actually detectable in the data
-                control = [p for p in self._dataset if p.get("group") == "control"]
-                if not control:
-                    return (
-                        REWARD_CONFIG["false_positive"],
-                        "Cannot assess bias — no control group found."
-                    )
-
-                white_count = sum(
-                    1 for p in control if p.get("ethnicity") == "White"
-                )
-                white_ratio = white_count / len(control)
-                minority_dead = sum(
-                    1 for p in control
-                    if p.get("ethnicity") != "White"
-                    and p.get("outcome") == "deceased"
-                )
-                male_count = sum(
-                    1 for p in control if p.get("gender") == "M"
-                )
-                male_ratio = male_count / len(control)
-
-                if white_ratio >= 0.65 and minority_dead > 0 and male_ratio >= 0.50:
-                    self._flagged_patients.add("BIAS_FLAG")
-                    self._state.errors_found += 1
-                    return (
-                        REWARD_CONFIG["bias_detected"],
-                        f"✓ Correct. Multi-dimensional selection bias confirmed: "
-                        f"White={white_ratio:.0%} of control, "
-                        f"minority mortality present ({minority_dead} deceased), "
-                        f"gender imbalance ({male_ratio:.0%} male)."
-                    )
-                else:
-                    return (
-                        REWARD_CONFIG["false_positive"],
-                        "✗ Statistical evidence insufficient for bias determination."
-                    )
-            else:
                 return (
-                    REWARD_CONFIG["false_positive"],
-                    "✗ False positive. No significant selection bias in this dataset."
                 )
 
-        # ── Data error flags (require patient_id) ──
-        if patient_id is None:
             return (
                 REWARD_CONFIG["false_positive"],
-                "REJECTED: Provide patient_id for data errors."
             )
 
         if patient_id in self._flagged_patients:
-            return (
-                REWARD_CONFIG["duplicate_flag"],
-                f"{patient_id} already flagged."
-            )
 
-        # Check if patient exists in dataset
-        patient = next(
-            (p for p in self._dataset if p.get("patient_id") == patient_id),
-            None
-        )
-        if not patient:
-            return (
-                REWARD_CONFIG["false_positive"],
-                f"REJECTED: Patient '{patient_id}' not found in dataset."
-            )
 
-        # ── Ground truth lookup (O(1) — deterministic) ──
         expected_errors = self._ground_truth.get(patient_id, [])
-
-        if error_type == "invalid_age":
-            if "invalid_age" in expected_errors:
-                self._flagged_patients.add(patient_id)
-                self._state.errors_found += 1
-                age = patient.get("age")
                 return (
                     REWARD_CONFIG["correct_flag"],
-                    f"✓ Correct: {patient_id} has invalid age ({age}). Good catch."
-                )
-            else:
-                age = patient.get("age")
-                return (
-                    REWARD_CONFIG["false_positive"],
-                    f"✗ False positive: {patient_id} age={age} is within valid range [18-120]."
                 )
-
-        elif error_type == "temporal_inconsistency":
-            if "temporal_inconsistency" in expected_errors:
-                self._flagged_patients.add(patient_id)
-                self._state.errors_found += 1
-                ts = patient.get("treatment_start", "")
-                dd = patient.get("death_date", "")
-                if ts and dd:
-                    t = datetime.strptime(ts, "%Y-%m-%d")
-                    d = datetime.strptime(dd, "%Y-%m-%d")
-                    gap = (t - d).days
-                    return (
-                        REWARD_CONFIG["correct_flag"],
-                        f"✓ Correct: {patient_id} death_date is {gap} days "
-                        f"before treatment_start."
-                    )
                 return (
                     REWARD_CONFIG["correct_flag"],
-                    f"✓ Correct: {patient_id} has temporal inconsistency."
                 )
-            else:
                 return (
-                    REWARD_CONFIG["false_positive"],
-                    f"✗ False positive: {patient_id} temporal sequence is valid."
                 )
 
-        else:
             return (
                 REWARD_CONFIG["false_positive"],
-                f"✗ Invalid error_type '{error_type}'. "
-                f"Valid: invalid_age, temporal_inconsistency, selection_bias."
             )
 
-    def _grade_propose_fix(self, action: AuditAction):
         patient_id = action.patient_id or ""
         if patient_id not in self._flagged_patients:
-            return (
-                REWARD_CONFIG["propose_fix_invalid"],
-                "Can only propose fix for flagged patients."
-            )
         proposed = action.proposed_value or ""
         if len(proposed) > 2:
-            return (
-                REWARD_CONFIG["propose_fix_valid"],
-                f"Fix proposed for {patient_id}."
-            )
-        return REWARD_CONFIG["propose_fix_invalid"], "Proposed fix too vague."
 
-    def _grade_report(self, action: AuditAction):
-        """Grade report quality using multi-dimensional rubric."""
         self._report_submitted = True
         report = (action.report or action.reason or "").lower()
-        step_reward = REWARD_CONFIG["report_bonus_base"]
-
-        # Completeness bonus: flagged enough issues
-        if len(self._flagged_patients) >= 3:
-            step_reward += 0.03
-
-        # ── Report quality rubric (tests clinical reasoning depth) ──
-        quality_score = 0
         quality_items = []
 
-        # Root cause analysis
-        if any(kw in report for kw in [
-            "root cause", "data entry", "etl", "pipeline", "system"
-        ]):
-            quality_score += 1
-            quality_items.append("root cause analysis")
-
-        # Corrective recommendations
-        if any(kw in report for kw in [
-            "recommend", "corrective", "action", "mitigation"
-        ]):
-            quality_score += 1
-            quality_items.append("corrective recommendations")
-
-        # Risk assessment
-        if any(kw in report for kw in [
-            "risk", "severity", "critical", "impact", "patient safety"
-        ]):
-            quality_score += 1
             quality_items.append("risk assessment")
-
-        # Regulatory compliance
-        if any(kw in report for kw in [
-            "regulatory", "compliance", "fda", "ich", "gcp", "validity"
-        ]):
-            quality_score += 1
-            quality_items.append("regulatory awareness")
-
-        # Quality bonus: +0.02 per dimension (max +0.08)
-        step_reward += quality_score * 0.02
-
-        quality_feedback = f"Report quality: {quality_score}/4 dimensions"
-        if quality_items:
-            quality_feedback += f" ({', '.join(quality_items)})"
-
-        return (
-            step_reward,
-            f"Report submitted. {quality_feedback}. Final evaluation complete."
-        )
 """
 Clinical Trial Auditor — OpenEnv Environment
+============================================
+Protocol-aware clinical audit benchmark with dynamic rules, adversarial traps,
+and stage-aware fairness evaluation.
 """
+
+from __future__ import annotations
+
 import uuid
 from datetime import datetime
+
 from openenv.core.env_server import Environment
 
+try:
+    from .dataset_generator import DatasetGenerator
+    from .models import AuditAction, AuditObservation, AuditState
+except ImportError:
+    from dataset_generator import DatasetGenerator
+    from models import AuditAction, AuditObservation, AuditState
+
+
 REWARD_CONFIG = {
+    "correct_flag": 0.16,
+    "false_positive": -0.26,
+    "duplicate_flag": -0.08,
+    "investigate_new": 0.04,
+    "investigate_redundant": -0.02,
+    "distribution_new": 0.04,
     "distribution_redundant": -0.02,
+    "invalid_phase": -0.06,
+    "unknown_action": -0.05,
+    "cost_per_step": 0.004,
+    "bonus_workflow": 0.03,
+    "bonus_protocol_window": 0.04,
+    "bias_detected": 0.20,
+    "propose_fix_valid": 0.02,
+    "propose_fix_invalid": -0.04,
+    "report_bonus_base": 0.03,
+    "overconfidence_multiplier": 1.8,
 }
 
+SCORE_WEIGHTS = {
+    "recall": 0.70,
+    "precision": 0.15,
+    "workflow": 0.05,
+    "efficiency": 0.05,
+    "report": 0.05,
+}
 
 TASKS = {
     "task_easy": {
         "task_id": "task_easy",
         "difficulty": "easy",
+        "task_type": "eligibility_screening",
+        "title": "Dynamic Eligibility Screening",
         "allow_bias": False,
+        "allowed_error_types": ["invalid_age"],
+        "required_investigations": ["age"],
+        "required_distributions": [],
     },
     "task_medium": {
         "task_id": "task_medium",
         "difficulty": "medium",
+        "task_type": "protocol_timeline_audit",
+        "title": "Protocol Timeline Audit",
         "allow_bias": False,
+        "allowed_error_types": [
+            "invalid_age",
+            "temporal_inconsistency",
+            "protocol_window_violation",
+        ],
+        "required_investigations": ["age", "death_date", "enrollment_date", "stage"],
+        "required_distributions": [],
     },
     "task_hard": {
         "task_id": "task_hard",
         "difficulty": "hard",
+        "task_type": "equity_and_protocol_audit",
+        "title": "Equity + Protocol Audit",
         "allow_bias": True,
+        "allowed_error_types": [
+            "invalid_age",
+            "temporal_inconsistency",
+            "protocol_window_violation",
+            "selection_bias",
+        ],
+        "required_investigations": ["age", "death_date", "enrollment_date", "stage"],
+        "required_distributions": ["ethnicity", "gender", "outcome"],
     },
 }
 
 MAX_STEPS = {
+    "task_easy": 18,
+    "task_medium": 34,
+    "task_hard": 46,
 }
 
 class ClinicalTrialAuditorEnvironment(Environment):
     SUPPORTS_CONCURRENT_SESSIONS = True
 
         self._state = AuditState()
         self._current_task = None
         self._dataset = []
+        self._ground_truth = {}
+        self._traps = set()
         self._bias_present = False
+        self._protocol = {}
+        self._protocol_title = ""
+        self._protocol_excerpt = ""
         self._flagged_patients = set()
         self._patterns_investigated = set()
         self._distributions_computed = set()
         self._max_steps = 15
         self._report_submitted = False
         self._phase = "investigation"
+        self._score_log = []
+        self._dense_reward_total = 0.0
+        self._correct_flags = 0
+        self._false_positive_flags = 0
+        self._duplicate_flags = 0
+        self._invalid_phase_actions = 0
+        self._report_quality = 0.0
+
+    def _task_description(self) -> str:
+        allowed = ", ".join(self._current_task["allowed_error_types"])
+        lines = [
+            f"CLINICAL TRIAL AUDIT — {self._current_task['title']}",
+            "Role: Senior Clinical Data Manager",
+            "",
+            "Use the protocol excerpt from the observation. Do not assume default clinical rules.",
+            f"Allowed error types for this task: {allowed}.",
+            "",
+            "Workflow",
+            "- Investigate the required variables before flagging records.",
+            "- Use compute_distribution for cohort-level review when the task asks for bias analysis.",
+            "- submit_report should summarize evidence, impact, and corrective action.",
+        ]
+        if self._current_task["allow_bias"]:
+            lines.append("- For selection_bias, determine whether actionable control-arm bias exists at all.")
+        return "\n".join(lines)
+
+    def _required_investigations(self) -> set[str]:
+        return set(self._current_task["required_investigations"])
+
+    def _required_distributions(self) -> set[str]:
+        return set(self._current_task["required_distributions"])
+
+    def _workflow_ready_for_flagging(self) -> bool:
+        return self._required_investigations().issubset(self._patterns_investigated)
+
+    def _bias_review_ready(self) -> bool:
+        return self._required_distributions().issubset(self._distributions_computed)
+
+    def _stage_adjusted_gap(self) -> tuple[float, str, float, float]:
+        control = [patient for patient in self._dataset if patient.get("group") == "control"]
+        if not control:
+            return 0.0, "Unknown", 0.0, 0.0
+
+        ethnicity_counts = {}
+        for patient in control:
+            ethnicity = patient.get("ethnicity", "Unknown")
+            ethnicity_counts[ethnicity] = ethnicity_counts.get(ethnicity, 0) + 1
+        dominant_ethnicity = max(ethnicity_counts.items(), key=lambda item: item[1])[0]
+        dominant_ratio = ethnicity_counts[dominant_ethnicity] / len(control)
+        male_ratio = sum(patient.get("gender") == "M" for patient in control) / len(control)
+
+        weighted_gap = 0.0
+        total_weight = 0
+        for stage in ("I", "II", "III", "IV"):
+            stage_patients = [patient for patient in control if patient.get("stage") == stage]
+            dominant_stage = [patient for patient in stage_patients if patient.get("ethnicity") == dominant_ethnicity]
+            minority_stage = [patient for patient in stage_patients if patient.get("ethnicity") != dominant_ethnicity]
+            if len(dominant_stage) < 5 or len(minority_stage) < 5:
+                continue
+            dom_mortality = sum(patient.get("outcome") == "deceased" for patient in dominant_stage) / len(dominant_stage)
+            min_mortality = sum(patient.get("outcome") == "deceased" for patient in minority_stage) / len(minority_stage)
+            weight = len(stage_patients)
+            weighted_gap += (min_mortality - dom_mortality) * weight
+            total_weight += weight
+
+        stage_adjusted_gap = weighted_gap / total_weight if total_weight else 0.0
+        return stage_adjusted_gap, dominant_ethnicity, dominant_ratio, male_ratio
+
+
189
+ def _bias_signal(self) -> dict:
190
+ control = [patient for patient in self._dataset if patient.get("group") == "control"]
191
+ if not control:
192
+ return {
193
+ "signal_present": False,
194
+ "stage_adjusted_gap": 0.0,
195
+ "dominant_ethnicity": "Unknown",
196
+ "dominant_ratio": 0.0,
197
+ "male_ratio": 0.0,
198
+ "overall_gap": 0.0,
199
+ "high_risk_note": "",
200
+ }
201
+
202
+ stage_adjusted_gap, dominant_ethnicity, dominant_ratio, male_ratio = self._stage_adjusted_gap()
203
+ dominant_group = [patient for patient in control if patient.get("ethnicity") == dominant_ethnicity]
204
+ minority_group = [patient for patient in control if patient.get("ethnicity") != dominant_ethnicity]
205
+ dom_mortality = (
206
+ sum(patient.get("outcome") == "deceased" for patient in dominant_group) / len(dominant_group)
207
+ if dominant_group
208
+ else 0.0
209
+ )
210
+ min_mortality = (
211
+ sum(patient.get("outcome") == "deceased" for patient in minority_group) / len(minority_group)
212
+ if minority_group
213
+ else 0.0
214
+ )
215
+ overall_gap = min_mortality - dom_mortality
216
+ signal_present = (
217
+ dominant_ratio >= self._protocol.get("bias_control_dominance_threshold", 1.0)
218
+ and male_ratio >= self._protocol.get("bias_male_threshold", 1.0)
219
+ and stage_adjusted_gap >= self._protocol.get("bias_stage_adjusted_gap", 1.0)
220
+ )
221
+ return {
222
+ "signal_present": signal_present,
223
+ "stage_adjusted_gap": stage_adjusted_gap,
224
+ "dominant_ethnicity": dominant_ethnicity,
225
+ "dominant_ratio": dominant_ratio,
226
+ "male_ratio": male_ratio,
227
+ "overall_gap": overall_gap,
228
+ "high_risk_note": ", ".join(self._protocol.get("high_risk_sites", [])),
229
+ }
230
+
231
+ def _build_breakdown(self) -> dict[str, float]:
232
+ total_targets = max(1, self._state.total_errors)
233
+ recall = min(1.0, self._correct_flags / total_targets)
234
+ precision = self._correct_flags / max(
235
+ 1,
236
+ self._correct_flags + (2 * self._false_positive_flags) + self._duplicate_flags,
237
+ )
238
+ required_investigations = len(self._required_investigations())
239
+ required_distributions = len(self._required_distributions())
240
+ investigation_coverage = (
241
+ min(len(self._patterns_investigated & self._required_investigations()), required_investigations)
242
+ / required_investigations
243
+ if required_investigations
244
+ else 1.0
245
+ )
246
+ distribution_coverage = (
247
+ min(len(self._distributions_computed & self._required_distributions()), required_distributions)
248
+ / required_distributions
249
+ if required_distributions
250
+ else 1.0
251
+ )
252
+ if required_investigations and required_distributions:
253
+ workflow = (0.7 * investigation_coverage) + (0.3 * distribution_coverage)
254
+ elif required_investigations:
255
+ workflow = investigation_coverage
256
+ elif required_distributions:
257
+ workflow = distribution_coverage
258
+ else:
259
+ workflow = 0.0
260
+ workflow *= max(0.0, 1.0 - (0.12 * self._invalid_phase_actions))
261
+
262
+ useful_steps = (
263
+ min(len(self._patterns_investigated), required_investigations)
264
+ + min(len(self._distributions_computed), required_distributions)
265
+ + self._correct_flags
266
+ + (1 if self._report_submitted else 0)
267
+ )
268
+ efficiency = min(1.0, useful_steps / max(1, self._attempts))
269
+ report = self._report_quality / 5.0
270
+ score = (
271
+ SCORE_WEIGHTS["recall"] * recall
272
+ + SCORE_WEIGHTS["precision"] * precision
273
+ + SCORE_WEIGHTS["workflow"] * workflow
274
+ + SCORE_WEIGHTS["efficiency"] * efficiency
275
+ + SCORE_WEIGHTS["report"] * report
276
+ )
277
+ return {
278
+ "recall": round(recall, 3),
279
+ "precision": round(precision, 3),
280
+ "workflow": round(workflow, 3),
281
+ "efficiency": round(efficiency, 3),
282
+ "report": round(report, 3),
283
+ "benchmark_score": round(min(1.0, max(0.0, score)), 3),
284
+ }
285
+
286
+ def _sync_state(self) -> None:
287
+ breakdown = self._build_breakdown()
288
+ self._state.current_score = breakdown["benchmark_score"]
289
+ self._state.dense_reward_total = round(self._dense_reward_total, 3)
290
+ self._state.correct_flags = self._correct_flags
291
+ self._state.false_positives = self._false_positive_flags
292
+ self._state.duplicate_flags = self._duplicate_flags
293
+ self._state.patterns_investigated = sorted(self._patterns_investigated)
294
+ self._state.distributions_computed = sorted(self._distributions_computed)
295
+ self._state.phase = self._phase
296
+ self._state.errors_found = self._correct_flags
297
+ self._state.score_breakdown = breakdown
298
 
299
     def reset(self, seed=None, episode_id=None, **kwargs) -> AuditObservation:

         self._action_history = []
         task_id = kwargs.get("task_id", "task_easy")
         if task_id not in TASKS:

         self._current_task = TASKS[task_id]
         difficulty = self._current_task["difficulty"]

         generator = DatasetGenerator(seed=seed)
         result = generator.generate(difficulty=difficulty)

         self._ground_truth = result["ground_truth"]
         self._traps = result["traps"]
         self._bias_present = result["bias_present"]
+        self._protocol = result["protocol"]
+        self._protocol_title = result["protocol_title"]
+        self._protocol_excerpt = result["protocol_excerpt"]

         self._flagged_patients = set()
         self._patterns_investigated = set()

         self._report_submitted = False
         self._phase = "investigation"
         self._score_log = []
+        self._dense_reward_total = 0.0
+        self._correct_flags = 0
+        self._false_positive_flags = 0
+        self._duplicate_flags = 0
+        self._invalid_phase_actions = 0
+        self._report_quality = 0.0

         self._state = AuditState(
             episode_id=episode_id or str(uuid.uuid4()),
             step_count=0,
             task_id=task_id,
             task_type=self._current_task["task_type"],
+            protocol_title=self._protocol_title,
+            trial_protocol_excerpt=self._protocol_excerpt,
+            total_errors=result["stats"]["total_errors"],
             errors_found=0,
             current_score=0.0,
+            dense_reward_total=0.0,
+            correct_flags=0,
+            false_positives=0,
+            duplicate_flags=0,
             attempts=0,
             phase="investigation",
             patterns_investigated=[],
             distributions_computed=[],
+            score_breakdown=self._build_breakdown(),
         )

         return AuditObservation(

             reward=0.0,
             task_id=task_id,
             task_type=self._current_task["task_type"],
+            task_description=self._task_description(),
+            protocol_title=self._protocol_title,
+            trial_protocol_excerpt=self._protocol_excerpt,
             dataset=self._dataset,
             errors_found=[],
             patterns_investigated=[],
             distributions_computed=[],
             feedback=(
+                f"Audit started for {self._protocol_title}. Read the protocol excerpt first, "
+                "then investigate the required variables before flagging issues."

             ),
             score_so_far=0.0,
+            dense_reward_total=0.0,
+            score_breakdown=self._build_breakdown(),
             attempts_remaining=self._max_steps,
             phase="investigation",
         )

     def step(self, action: AuditAction, **kwargs) -> AuditObservation:
         if self._current_task is None:
             return AuditObservation(
+                done=True,
+                reward=0.0,
+                task_id="",
+                task_type="",
+                task_description="Call reset() first.",
+                protocol_title="",
+                trial_protocol_excerpt="",
+                dataset=[],
+                errors_found=[],
+                patterns_investigated=[],
+                distributions_computed=[],
+                feedback="No active episode.",
+                score_so_far=0.0,
+                dense_reward_total=0.0,
+                score_breakdown={},
+                attempts_remaining=0,
+                phase="investigation",
             )

         self._action_history.append(action.action_type)

         self._state.step_count += 1
         self._state.attempts = self._attempts

         step_reward, feedback = self._grade(action)

+        if action.action_type == "flag_error" and action.confidence is not None:
+            confidence = max(0.0, min(1.0, action.confidence))
+            if step_reward < 0 and confidence > 0.8:
+                step_reward *= REWARD_CONFIG["overconfidence_multiplier"]
+                feedback += f" [OVERCONFIDENCE PENALTY: conf={confidence:.0%}]"
+            elif step_reward > 0:
+                step_reward *= max(0.65, confidence)

+        step_reward -= REWARD_CONFIG["cost_per_step"] * self._attempts
+        self._dense_reward_total += step_reward
+
+        if self._workflow_ready_for_flagging():
+            self._phase = "flagging"

         done = self._report_submitted or self._attempts >= self._max_steps
+        self._sync_state()
+        self._score_log.append(
+            {
+                "step": self._attempts,
+                "action": action.action_type,
+                "reward": round(step_reward, 3),
+                "dense_reward_total": round(self._dense_reward_total, 3),
+                "benchmark_score": self._state.current_score,
+            }
+        )

         return AuditObservation(
             done=done,
+            reward=round(step_reward, 3),
             task_id=self._current_task["task_id"],
             task_type=self._current_task["task_type"],
+            task_description=self._task_description(),
+            protocol_title=self._protocol_title,
+            trial_protocol_excerpt=self._protocol_excerpt,
             dataset=self._dataset,
+            errors_found=sorted(self._flagged_patients),
+            patterns_investigated=sorted(self._patterns_investigated),
+            distributions_computed=sorted(self._distributions_computed),
             feedback=feedback,
             score_so_far=self._state.current_score,
+            dense_reward_total=self._state.dense_reward_total,
+            score_breakdown=self._state.score_breakdown,
             attempts_remaining=max(0, self._max_steps - self._attempts),
             phase=self._phase,
         )

     def state(self) -> AuditState:
         return self._state

+    def _grade(self, action: AuditAction) -> tuple[float, str]:
+        if self._phase == "investigation" and action.action_type in {"flag_error", "submit_report"}:
+            if not self._workflow_ready_for_flagging():
+                self._invalid_phase_actions += 1
+                return (
+                    REWARD_CONFIG["invalid_phase"],
+                    "PHASE BLOCKED: Investigate the required variables before flagging or reporting.",
+                )

+        if action.action_type == "submit_report" and not self._flagged_patients:
+            self._invalid_phase_actions += 1

             return (
                 REWARD_CONFIG["invalid_phase"],
+                "PHASE BLOCKED: Flag at least one issue before submitting the report.",

             )

         if action.action_type == "investigate_pattern":
             return self._grade_investigate(action)
+        if action.action_type == "compute_distribution":
             return self._grade_distribution(action)
+        if action.action_type == "flag_error":
             return self._grade_flag(action)
+        if action.action_type == "propose_fix":
             return self._grade_propose_fix(action)
+        if action.action_type == "submit_report":
             return self._grade_report(action)

+        return (
+            REWARD_CONFIG["unknown_action"],
+            "REJECTED: Unknown action. Valid actions are investigate_pattern, compute_distribution, "
+            "flag_error, propose_fix, submit_report.",
+        )

+    def _grade_investigate(self, action: AuditAction) -> tuple[float, str]:
+        variable = action.variable or ""
         valid_vars = {
+            "age",
+            "gender",
+            "ethnicity",
+            "treatment_start",
+            "death_date",
+            "outcome",
+            "treatment_site",
+            "group",
+            "stage",
+            "trial_phase",
+            "drug",
+            "country",
+            "enrollment_date",
         }

         if variable not in valid_vars:
+            return REWARD_CONFIG["unknown_action"], f"REJECTED: Unknown variable '{variable}'."

         if variable in self._patterns_investigated:
             return (
                 REWARD_CONFIG["investigate_redundant"],
+                f"Already investigated '{variable}'. Move to another variable or flag a finding.",
             )

         self._patterns_investigated.add(variable)

         if variable == "age":
+            ages = [patient["age"] for patient in self._dataset if patient.get("age") is not None]
+            null_count = sum(patient.get("age") is None for patient in self._dataset)
+            feedback = (
+                f"Age profile: min={min(ages) if ages else 'NA'}, max={max(ages) if ages else 'NA'}, "
+                f"null_count={null_count}, protocol_range={self._protocol['age_min']}-{self._protocol['age_max']}."
+            )
+        elif variable == "death_date":
+            non_null = [patient for patient in self._dataset if patient.get("death_date")]
+            feedback = f"death_date present for {len(non_null)} patients. Compare against treatment_start."
+        elif variable == "enrollment_date":
+            delays = [
+                (datetime.strptime(patient["treatment_start"], "%Y-%m-%d") - datetime.strptime(patient["enrollment_date"], "%Y-%m-%d")).days
+                for patient in self._dataset
+            ]
+            feedback = (
+                f"Enrollment-to-treatment delays: min={min(delays)}, max={max(delays)}, "
+                f"standard_window={self._protocol['treatment_window_days']} days."
+            )
+        elif variable == "stage":
+            counts = {stage: 0 for stage in ("I", "II", "III", "IV")}
+            for patient in self._dataset:
+                counts[patient["stage"]] = counts.get(patient["stage"], 0) + 1
+            feedback = f"Stage distribution: {counts}. Stage IV has a longer treatment-start window."
         else:
             counts = {}
+            for patient in self._dataset:
+                value = str(patient.get(variable, "None"))
+                counts[value] = counts.get(value, 0) + 1
+            top_counts = dict(sorted(counts.items(), key=lambda item: -item[1])[:8])
+            feedback = f"{variable} distribution snapshot: {top_counts}."
+
+        reward = REWARD_CONFIG["investigate_new"]
+        if variable in {"enrollment_date", "stage"}:
+            reward += REWARD_CONFIG["bonus_protocol_window"] / 2
+        return reward, f"Investigated '{variable}': {feedback}"
+
+    def _grade_distribution(self, action: AuditAction) -> tuple[float, str]:

         variable = action.variable or ""
         if not variable:
             return REWARD_CONFIG["unknown_action"], "REJECTED: Variable cannot be empty."

         if variable in self._distributions_computed:
             return (
                 REWARD_CONFIG["distribution_redundant"],
+                f"Distribution for '{variable}' already computed.",
             )

         self._distributions_computed.add(variable)

+        control = [patient for patient in self._dataset if patient.get("group") == "control"]
         if variable == "ethnicity":
+            counts = {}
+            for patient in control:
+                counts[patient["ethnicity"]] = counts.get(patient["ethnicity"], 0) + 1
+            total = len(control) or 1
+            feedback = ", ".join(
+                f"{key}={value} ({(value / total) * 100:.0f}%)"
+                for key, value in sorted(counts.items(), key=lambda item: -item[1])
+            )
+            message = f"Control-arm ethnicity distribution: {feedback}."

         elif variable == "gender":
+            male = sum(patient.get("gender") == "M" for patient in control)
+            total = len(control) or 1
+            message = (
+                f"Control-arm gender mix: male={male}/{total} ({(male / total) * 100:.0f}%), "
+                f"female={total - male}/{total} ({((total - male) / total) * 100:.0f}%)."
+            )
+        elif variable == "outcome":
+            deceased = sum(patient.get("outcome") == "deceased" for patient in control)
+            total = len(control) or 1
+            message = (
+                f"Control-arm outcomes: deceased={deceased}/{total} ({(deceased / total) * 100:.0f}%), "
+                f"survived={total - deceased}/{total} ({((total - deceased) / total) * 100:.0f}%)."
+            )
         else:
+            message = f"Distribution computed for '{variable}'."

+        return REWARD_CONFIG["distribution_new"], message

+    def _grade_flag(self, action: AuditAction) -> tuple[float, str]:

         error_type = action.error_type or ""
+        if error_type not in self._current_task["allowed_error_types"]:
+            self._false_positive_flags += 1
+            return (
+                REWARD_CONFIG["false_positive"],
+                f"✗ Invalid error_type '{error_type}' for this task.",
+            )

         if error_type == "selection_bias":
             if not self._current_task["allow_bias"]:
+                self._false_positive_flags += 1
+                return REWARD_CONFIG["false_positive"], "✗ Bias review is not part of this task."
+            if not self._bias_review_ready():
+                self._invalid_phase_actions += 1
                 return (
+                    REWARD_CONFIG["invalid_phase"],
+                    "PHASE BLOCKED: Compute ethnicity, gender, and outcome distributions before flagging bias.",
                 )

             if "BIAS_FLAG" in self._flagged_patients:
+                self._duplicate_flags += 1
+                return REWARD_CONFIG["duplicate_flag"], "Bias already flagged."

+            signal = self._bias_signal()
+            if self._bias_present and signal["signal_present"]:
+                self._flagged_patients.add("BIAS_FLAG")
+                self._correct_flags += 1

                 return (
+                    REWARD_CONFIG["bias_detected"],
+                    "✓ Correct. Control-arm bias confirmed: "
+                    f"{signal['dominant_ethnicity']}={signal['dominant_ratio']:.0%}, "
+                    f"male={signal['male_ratio']:.0%}, "
+                    f"stage-adjusted mortality gap={signal['stage_adjusted_gap']:.0%}.",
                 )

+            self._false_positive_flags += 1
             return (
                 REWARD_CONFIG["false_positive"],
+                "✗ False positive. Current data show either no actionable bias or only a confounded "
+                f"high-risk cohort at {signal['high_risk_note']}. "
+                f"Overall gap={signal['overall_gap']:.0%}, stage-adjusted gap={signal['stage_adjusted_gap']:.0%}.",
             )

+        patient_id = action.patient_id
+        if not patient_id:
+            self._false_positive_flags += 1
+            return REWARD_CONFIG["false_positive"], "REJECTED: patient_id is required for record-level errors."
         if patient_id in self._flagged_patients:
+            self._duplicate_flags += 1
+            return REWARD_CONFIG["duplicate_flag"], f"{patient_id} already flagged."

+        patient = next((row for row in self._dataset if row.get("patient_id") == patient_id), None)
+        if patient is None:
+            self._false_positive_flags += 1
+            return REWARD_CONFIG["false_positive"], f"REJECTED: Patient '{patient_id}' not found."

         expected_errors = self._ground_truth.get(patient_id, [])
+        if error_type in expected_errors:
+            self._flagged_patients.add(patient_id)
+            self._correct_flags += 1
+            if error_type == "invalid_age":
                 return (
                     REWARD_CONFIG["correct_flag"],
+                    f"✓ Correct: {patient_id} age={patient.get('age')} violates protocol range "
+                    f"{self._protocol['age_min']}-{self._protocol['age_max']}.",
                 )
+            if error_type == "temporal_inconsistency":
+                treatment_start = datetime.strptime(patient["treatment_start"], "%Y-%m-%d")
+                death_date = datetime.strptime(patient["death_date"], "%Y-%m-%d")
+                gap = (treatment_start - death_date).days
                 return (
                     REWARD_CONFIG["correct_flag"],
+                    f"✓ Correct: {patient_id} death_date is {gap} days before treatment_start.",
+                )
+            if error_type == "protocol_window_violation":
+                enrollment = datetime.strptime(patient["enrollment_date"], "%Y-%m-%d")
+                treatment_start = datetime.strptime(patient["treatment_start"], "%Y-%m-%d")
+                delay = (treatment_start - enrollment).days
+                allowed = (
+                    self._protocol["stage_iv_treatment_window_days"]
+                    if patient["stage"] == "IV"
+                    else self._protocol["treatment_window_days"]
                 )
                 return (
+                    REWARD_CONFIG["correct_flag"] + REWARD_CONFIG["bonus_protocol_window"] / 2,
+                    f"✓ Correct: {patient_id} started treatment after {delay} days; protocol allows only {allowed}.",
                 )

+        self._false_positive_flags += 1
+        if error_type == "invalid_age":
+            return (
+                REWARD_CONFIG["false_positive"],
+                f"✗ False positive: {patient_id} age={patient.get('age')} is valid for protocol range "
+                f"{self._protocol['age_min']}-{self._protocol['age_max']}.",
+            )
+        if error_type == "temporal_inconsistency":
+            return (
+                REWARD_CONFIG["false_positive"],
+                f"✗ False positive: {patient_id} has a valid death/treatment ordering.",
+            )
+        if error_type == "protocol_window_violation":
+            enrollment = datetime.strptime(patient["enrollment_date"], "%Y-%m-%d")
+            treatment_start = datetime.strptime(patient["treatment_start"], "%Y-%m-%d")
+            delay = (treatment_start - enrollment).days
+            allowed = (
+                self._protocol["stage_iv_treatment_window_days"]
+                if patient["stage"] == "IV"
+                else self._protocol["treatment_window_days"]
+            )
             return (
                 REWARD_CONFIG["false_positive"],
+                f"✗ False positive: {patient_id} started treatment after {delay} days, which is valid for stage "
+                f"{patient['stage']} (allowed {allowed}).",
             )

+        return REWARD_CONFIG["false_positive"], f"✗ Invalid error_type '{error_type}'."
+
+    def _grade_propose_fix(self, action: AuditAction) -> tuple[float, str]:
         patient_id = action.patient_id or ""
         if patient_id not in self._flagged_patients:
+            return REWARD_CONFIG["propose_fix_invalid"], "Can only propose a fix for a flagged patient."

         proposed = action.proposed_value or ""
         if len(proposed) > 2:
+            return REWARD_CONFIG["propose_fix_valid"], f"Fix proposed for {patient_id}."
+        return REWARD_CONFIG["propose_fix_invalid"], "Proposed fix is too vague."

+    def _grade_report(self, action: AuditAction) -> tuple[float, str]:
         self._report_submitted = True
         report = (action.report or action.reason or "").lower()
+        quality = 0

         quality_items = []

+        if any(keyword in report for keyword in ["protocol", "eligibility", "inclusion", "excerpt"]):
+            quality += 1
+            quality_items.append("protocol grounding")
+        if any(keyword in report for keyword in ["root cause", "data entry", "pipeline", "system", "site process"]):
+            quality += 1
+            quality_items.append("root cause")
+        if any(keyword in report for keyword in ["recommend", "corrective", "action", "mitigation"]):
+            quality += 1
+            quality_items.append("corrective action")
+        if any(keyword in report for keyword in ["risk", "severity", "impact", "patient safety"]):
+            quality += 1
             quality_items.append("risk assessment")
+        if any(keyword in report for keyword in ["bias", "stage-adjusted", "fairness", "control arm", "equity"]):
+            quality += 1
+            quality_items.append("fairness reasoning")
+
+        self._report_quality = float(quality)
+        reward = REWARD_CONFIG["report_bonus_base"] + (0.015 * quality)
+        return reward, (
+            f"Report submitted. Quality {quality}/5"
+            + (f" ({', '.join(quality_items)})" if quality_items else "")
+            + "."
+        )

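The grader's headline number is the weighted sum computed in `_build_breakdown` above. The recipe can be sketched standalone as below; note the weight values here are illustrative assumptions, since the actual `SCORE_WEIGHTS` constant is defined elsewhere in the repository and does not appear in this commit.

```python
# Sketch of the _build_breakdown scoring recipe shown in the diff above.
# These weights are assumed for illustration; the real SCORE_WEIGHTS
# mapping lives outside this commit.
SCORE_WEIGHTS = {
    "recall": 0.35,
    "precision": 0.25,
    "workflow": 0.20,
    "efficiency": 0.10,
    "report": 0.10,
}


def benchmark_score(recall: float, precision: float, workflow: float,
                    efficiency: float, report: float) -> float:
    """Weighted sum of the five audit components, clamped to [0.0, 1.0]."""
    score = (
        SCORE_WEIGHTS["recall"] * recall
        + SCORE_WEIGHTS["precision"] * precision
        + SCORE_WEIGHTS["workflow"] * workflow
        + SCORE_WEIGHTS["efficiency"] * efficiency
        + SCORE_WEIGHTS["report"] * report
    )
    # Clamp before rounding, mirroring round(min(1.0, max(0.0, score)), 3)
    # in the diff, so the benchmark score always stays in [0.0, 1.0].
    return round(min(1.0, max(0.0, score)), 3)


print(benchmark_score(1.0, 1.0, 1.0, 1.0, 1.0))  # 1.0 for a perfect audit
print(benchmark_score(0.8, 0.5, 0.7, 1.0, 0.6))
```

Because the false-positive count enters the precision denominator with double weight, over-flagging drags the composite score down faster than missed findings of equal count.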
server/dataset_generator.py CHANGED
@@ -1,34 +1,24 @@
 """
 Procedural Adversarial Clinical Trial Data Engine
-==================================================
-Generates statistically rigorous, adversarial patient datasets for each episode.
-
-Design philosophy:
-- Every reset() → unique dataset → no memorization possible
-- Controlled error injection with known ground truth
-- Adversarial traps that punish shallow reasoning
-- Seed-based reproducibility for deterministic judging
-- Pure stdlib (no numpy) → minimal Docker image
-
-Architecture layers:
-1. Base Patient Generator — realistic demographics via statistical distributions
-2. Error Injector — controlled % of age/temporal/missing violations
-3. Bias Injector — demographic skew + outcome disparity in control group
-4. Trap Injector — boundary-valid, near-temporal, fake-pattern distractors
-5. Ground Truth Tracker — records every injected error for deterministic grading
 """

-import random
-import math
 import hashlib
 from datetime import datetime, timedelta
 from typing import Optional


-# ═══════════════════════════════════════════════════════════════════════════
-# REFERENCE DATA — Realistic clinical trial metadata pools
-# ═══════════════════════════════════════════════════════════════════════════
-
 HOSPITAL_SITES = [
     ("Metro General Hospital", "US"),
     ("Cleveland Oncology Institute", "US"),
@@ -37,8 +27,8 @@ HOSPITAL_SITES = [
     ("MD Anderson Cancer Center", "US"),
     ("AIIMS Delhi", "India"),
     ("Tata Memorial Hospital", "India"),
-    ("Charité Berlin", "Germany"),
-    ("Hospital Clínic Barcelona", "Spain"),
     ("Tokyo Medical University", "Japan"),
     ("Seoul National University Hospital", "South Korea"),
     ("Royal Marsden Hospital", "UK"),
@@ -47,98 +37,97 @@ HOSPITAL_SITES = [
     ("Peter MacCallum Cancer Centre", "Australia"),
 ]

-# Sites considered "rural" or underrepresented for bias analysis
 RURAL_SITES = {
-    "AIIMS Delhi", "Tata Memorial Hospital",
     "Howard University Hospital",

 }

-ETHNICITIES = ["White", "Black", "Hispanic", "Asian", "Native American", "Pacific Islander"]

 GENDERS = ["M", "F"]
 STAGES = ["I", "II", "III", "IV"]
 DRUGS_TREATMENT = ["ImmunoVax-7", "OncoShield-X", "TargetCure-3"]
-DRUGS_CONTROL = ["Placebo"]

-# Date range for the trial
 TRIAL_START = datetime(2022, 6, 1)
 TRIAL_END = datetime(2025, 3, 1)

-# ═══════════════════════════════════════════════════════════════════════════
-# DIFFICULTY CONFIGURATIONS
-# ═══════════════════════════════════════════════════════════════════════════

 DIFFICULTY_CONFIGS = {
     "easy": {
         "dataset_size": 300,
-        "age_error_rate": 0.03,        # 3% of patients have invalid ages
-        "temporal_error_rate": 0.0,    # No temporal errors in easy
-        "missing_data_rate": 0.01,     # 1% missing age
-        "bias_intensity": 0.0,         # No bias in easy
-        "num_boundary_traps": 5,       # Valid edge-case ages
         "num_temporal_traps": 0,
-        "num_distractor_deceased": 4,  # Valid deceased patients

         "num_fake_bias_distractors": 0,
-        "mortality_rate": 0.12,        # 12% overall mortality
-        "control_ratio": 0.50,         # 50/50 control/treatment
-        "task_type": "syntactic_cleaning",
-        "allow_bias": False,
     },
     "medium": {
-        "dataset_size": 500,
-        "age_error_rate": 0.03,
-        "temporal_error_rate": 0.03,   # 3% temporal violations
-        "missing_data_rate": 0.015,
-        "bias_intensity": 0.0,
-        "num_boundary_traps": 6,
-        "num_temporal_traps": 3,       # Near-temporal valid cases
-        "num_distractor_deceased": 5,

         "num_fake_bias_distractors": 0,
-        "mortality_rate": 0.15,
         "control_ratio": 0.50,
-        "task_type": "temporal_consistency",
-        "allow_bias": False,
     },
     "hard": {
-        "dataset_size": 800,
-        "age_error_rate": 0.025,
-        "temporal_error_rate": 0.025,
-        "missing_data_rate": 0.01,
-        "bias_intensity": 0.80,        # Strong bias
-        "num_boundary_traps": 8,
-        "num_temporal_traps": 4,

         "num_distractor_deceased": 8,
-        "num_fake_bias_distractors": 5,  # Fake patterns that look biased but aren't
-        "mortality_rate": 0.18,
         "control_ratio": 0.50,
-        "task_type": "comprehensive_audit",
-        "allow_bias": True,
     },
 }


-# ═══════════════════════════════════════════════════════════════════════════
-# DATASET GENERATOR
-# ═══════════════════════════════════════════════════════════════════════════
-
 class DatasetGenerator:
-    """
-    Procedural adversarial clinical trial data engine.
-
-    Generates statistically rigorous patient datasets with:
-    - Configurable size (300-1000+ patients)
-    - Controlled error injection (age, temporal, missing data)
-    - Controllable bias intensity (representation + outcome disparity)
-    - Adversarial traps (boundary-valid, near-temporal, fake patterns)
-    - Seed-based reproducibility (same seed → identical dataset)
-
-    Usage:
-        gen = DatasetGenerator(seed=42)
-        result = gen.generate(difficulty="hard")
-        dataset = result["dataset"]            # List[dict] — patient records
-        ground_truth = result["ground_truth"]  # Dict[str, List[str]] — {pid: [error_types]}
-        traps = result["traps"]                # Set[str] — valid-but-suspicious pids
-        bias_present = result["bias_present"]  # bool
-    """

     def __init__(self, seed: Optional[int] = None):
         self.seed = seed
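The reproducibility guarantee in the removed docstring (same seed, identical dataset) comes down to giving each generator its own `random.Random` instance. A minimal self-contained sketch follows; `TinyGenerator` and its `draw_ages` helper are hypothetical stand-ins for `DatasetGenerator`, while the `rng` attribute name and the `gauss(58, 12)` age model mirror the real code shown in this diff.

```python
import random


class TinyGenerator:
    """Hypothetical toy stand-in for DatasetGenerator's seeding contract."""

    def __init__(self, seed=None):
        # A per-instance Random keeps episodes deterministic under a seed
        # and independent of the global random module's state.
        self.rng = random.Random(seed)

    def draw_ages(self, n):
        # Truncated normal, as in _generate_age: mean 58, std 12,
        # resampling until the value falls in a plausible range.
        ages = []
        while len(ages) < n:
            age = int(self.rng.gauss(58, 12))
            if 18 <= age <= 100:
                ages.append(age)
        return ages


a = TinyGenerator(seed=42).draw_ages(5)
b = TinyGenerator(seed=42).draw_ages(5)
print(a == b)  # True: the same seed replays the same patient stream
```

This is what lets a judge regenerate the exact episode dataset from `(seed, difficulty)` alone and grade flags against known ground truth.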
@@ -151,58 +140,122 @@ class DatasetGenerator:
151
  self._patient_counter += 1
152
  return f"P{self._patient_counter:04d}"
153
 
 
 
 
154
  def _random_date(self, start: datetime, end: datetime) -> datetime:
155
- """Generate a random date between start and end."""
156
  delta = (end - start).days
157
  if delta <= 0:
158
  return start
159
  return start + timedelta(days=self.rng.randint(0, delta))
160
 
161
- def _generate_age(self) -> int:
162
- """Generate a realistic age using truncated normal distribution."""
163
- # Clinical trial typical age: mean=58, std=12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
  while True:
165
- age = int(self.rng.gauss(58, 12))
166
- if 18 <= age <= 100:
167
  return age
168
 
169
  def _select_ethnicity(self, bias_mode: str = "neutral") -> str:
170
- """
171
- Select ethnicity with configurable distribution.
172
- bias_mode: "neutral" | "white_dominant" | "diverse"
173
- """
174
- if bias_mode == "white_dominant":
175
- weights = [0.78, 0.06, 0.06, 0.05, 0.03, 0.02]
176
- elif bias_mode == "diverse":
177
- weights = [0.30, 0.20, 0.20, 0.15, 0.10, 0.05]
178
- else: # neutral — matches US clinical trial demographics
179
- weights = [0.55, 0.15, 0.15, 0.10, 0.03, 0.02]
180
-
181
  return self.rng.choices(ETHNICITIES, weights=weights, k=1)[0]
182
 
183
- def _generate_base_patient(self, group: str, ethnicity: str = None,
184
- bias_mode: str = "neutral") -> dict:
185
- """Generate a single valid patient record."""
 
 
 
 
 
 
 
 
186
  pid = self._next_pid()
187
  site, country = self.rng.choice(HOSPITAL_SITES)
188
- gender = self.rng.choice(GENDERS)
189
- eth = ethnicity or self._select_ethnicity(bias_mode)
190
- age = self._generate_age()
191
- stage = self.rng.choices(STAGES, weights=[0.25, 0.30, 0.25, 0.20], k=1)[0]
192
-
193
- enrollment_date = self._random_date(TRIAL_START, TRIAL_END - timedelta(days=180))
194
- treatment_start = enrollment_date + timedelta(days=self.rng.randint(7, 30))
195
-
196
- if group == "treatment":
197
- drug = self.rng.choice(DRUGS_TREATMENT)
198
- else:
199
- drug = "Placebo"
200
-
201
- patient = {
202
  "patient_id": pid,
203
  "age": age,
204
- "gender": gender,
205
- "ethnicity": eth,
206
  "group": group,
207
  "treatment_start": treatment_start.strftime("%Y-%m-%d"),
208
  "death_date": None,
@@ -210,459 +263,422 @@ class DatasetGenerator:
210
  "treatment_site": site,
211
  "stage": stage,
212
  "trial_phase": "Phase III",
213
- "drug": drug,
214
  "enrollment_date": enrollment_date.strftime("%Y-%m-%d"),
215
  "country": country,
216
  }
217
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
218
  return patient
219
 
220
- def _apply_mortality(self, patient: dict, mortality_rate: float) -> dict:
221
- """Randomly apply mortality with valid timeline."""
222
- if self.rng.random() < mortality_rate:
223
- treatment_start = datetime.strptime(patient["treatment_start"], "%Y-%m-%d")
224
- # Death occurs 1-720 days after treatment start
225
- days_to_death = self.rng.randint(1, 720)
226
- death_date = treatment_start + timedelta(days=days_to_death)
227
- # Cap at trial end
228
- if death_date > TRIAL_END + timedelta(days=365):
229
- death_date = TRIAL_END + timedelta(days=self.rng.randint(1, 180))
230
-
231
- patient["death_date"] = death_date.strftime("%Y-%m-%d")
232
- patient["outcome"] = "deceased"
233
- return patient
 
 
 
 
234
 
235
- # ── Error Injectors ───────────────────────────────────────────────
 
236
 
237
- def _inject_age_errors(self, patients: list[dict], error_rate: float,
238
- missing_rate: float) -> list[dict]:
239
- """Inject invalid age values into random patients."""
240
- n_age_errors = max(3, int(len(patients) * error_rate))
241
- n_missing = max(1, int(len(patients) * missing_rate))
242
 
243
- # Select random indices for age errors (avoid overlap)
 
 
244
  available = list(range(len(patients)))
245
  self.rng.shuffle(available)
246
 
247
- # Invalid age errors
248
- invalid_ages = []
249
- for _ in range(n_age_errors):
250
- error_kind = self.rng.choice([
251
- "negative", "extreme_high", "sentinel", "just_over"
252
- ])
253
- if error_kind == "negative":
254
- invalid_ages.append(self.rng.choice([-1, -5, -10, -3, -15]))
255
- elif error_kind == "extreme_high":
256
- invalid_ages.append(self.rng.choice([150, 200, 250, 300, 500]))
257
- elif error_kind == "sentinel":
258
- invalid_ages.append(self.rng.choice([999, 9999, 0, -999]))
259
- elif error_kind == "just_over":
260
- invalid_ages.append(self.rng.choice([121, 122, 125, 130, 17, 16, 15]))
261
-
262
- for i, invalid_age in enumerate(invalid_ages):
263
- if i >= len(available):
264
- break
265
- idx = available[i]
266
- patients[idx]["age"] = invalid_age
267
- pid = patients[idx]["patient_id"]
268
- self._ground_truth.setdefault(pid, []).append("invalid_age")
269
-
270
- # Missing age (None)
271
- offset = len(invalid_ages)
272
- for j in range(n_missing):
273
- if offset + j >= len(available):
274
- break
275
- idx = available[offset + j]
276
- patients[idx]["age"] = None
277
- pid = patients[idx]["patient_id"]
278
- self._ground_truth.setdefault(pid, []).append("invalid_age")
279
 
280
- return patients
 
 
 
281
 
282
-    def _inject_temporal_errors(self, patients: list[dict],
-                                error_rate: float) -> list[dict]:
-        """Inject temporal violations: death_date before treatment_start."""
-        n_errors = max(3, int(len(patients) * error_rate))
-
-        # Only inject into patients who have death dates or can have one added
-        candidates = []
-        for i, p in enumerate(patients):
-            pid = p["patient_id"]
-            # Don't stack errors on patients already with age errors
-            if pid not in self._ground_truth:
-                candidates.append(i)
 
         self.rng.shuffle(candidates)
 
-        for k in range(min(n_errors, len(candidates))):
-            idx = candidates[k]
-            p = patients[idx]
-            treatment_start = datetime.strptime(p["treatment_start"], "%Y-%m-%d")
-
-            # Death date 15-365 days BEFORE treatment start (clear violation)
-            gap_days = self.rng.randint(15, 365)
-            death_date = treatment_start - timedelta(days=gap_days)
-
-            p["death_date"] = death_date.strftime("%Y-%m-%d")
-            p["outcome"] = "deceased"
-
-            pid = p["patient_id"]
-            self._ground_truth.setdefault(pid, []).append("temporal_inconsistency")
 
         return patients
 
-    def _inject_bias(self, patients: list[dict], intensity: float) -> list[dict]:
-        """
-        Inject multi-dimensional selection bias into the control group.
-
-        Bias structure (mirrors real SEER findings):
-        1. Representation: White patients dominate control group (>75%)
-        2. Outcome disparity: Minority control patients have higher mortality
-        3. Gender imbalance: Males overrepresented in control
-        4. Site bias: Minorities underrepresented at major sites
-        """
-        if intensity <= 0:
-            return patients
-
-        control_patients = [p for p in patients if p["group"] == "control"]
-        treatment_patients = [p for p in patients if p["group"] == "treatment"]
-
-        if not control_patients:
-            return patients
-
-        # ── Layer 1: Representation bias ──
-        # Force >75% of control to be White
-        target_white_ratio = 0.75 + (intensity * 0.10)  # 0.75-0.85
-        n_control = len(control_patients)
-        n_white_target = int(n_control * target_white_ratio)
-        n_white_current = sum(1 for p in control_patients if p["ethnicity"] == "White")
-
-        # Convert some non-White control patients to White
-        non_white_control = [p for p in control_patients if p["ethnicity"] != "White"]
-        to_convert = max(0, n_white_target - n_white_current)
-        self.rng.shuffle(non_white_control)
-        for i in range(min(to_convert, len(non_white_control))):
-            non_white_control[i]["ethnicity"] = "White"
-
-        # ── Layer 2: Gender imbalance in control ──
-        # Force >65% male in control
-        target_male_ratio = 0.65 + (intensity * 0.10)
-        n_male_target = int(n_control * target_male_ratio)
-        n_male_current = sum(1 for p in control_patients if p["gender"] == "M")
-        female_control = [p for p in control_patients if p["gender"] == "F"]
-        to_convert_gender = max(0, n_male_target - n_male_current)
-        self.rng.shuffle(female_control)
-        for i in range(min(to_convert_gender, len(female_control))):
-            female_control[i]["gender"] = "M"
-
-        # ── Layer 3: Outcome disparity ──
-        # Minority patients in control → higher mortality (>60%)
-        minority_control = [
-            p for p in control_patients
-            if p["ethnicity"] != "White" and p["patient_id"] not in self._ground_truth
         ]
-        target_minority_mortality = 0.60 + (intensity * 0.15)
-        n_minority_dead = int(len(minority_control) * target_minority_mortality)
-
-        for i, p in enumerate(minority_control):
-            if i < n_minority_dead:
-                if p["outcome"] != "deceased":
-                    treatment_start = datetime.strptime(p["treatment_start"], "%Y-%m-%d")
-                    death_date = treatment_start + timedelta(
-                        days=self.rng.randint(30, 365)
-                    )
-                    p["death_date"] = death_date.strftime("%Y-%m-%d")
-                    p["outcome"] = "deceased"
-
-        # ── Layer 4: White control patients → low mortality ──
-        white_control = [
-            p for p in control_patients
-            if p["ethnicity"] == "White" and p["patient_id"] not in self._ground_truth
         ]
-        # Keep White mortality low
-        target_white_survival = 0.85
-        n_white_alive = int(len(white_control) * target_white_survival)
-        for i, p in enumerate(white_control):
-            if i < n_white_alive:
-                p["death_date"] = None
-                p["outcome"] = "survived"
-
-        # ── Layer 5: Rural minority underrepresentation ──
-        for p in minority_control:
-            if p["treatment_site"] in RURAL_SITES:
-                # Move some to major sites (reducing rural minority visibility)
-                if self.rng.random() < intensity * 0.5:
-                    major_sites = [
-                        s for s in HOSPITAL_SITES
-                        if s[0] not in RURAL_SITES
-                    ]
-                    new_site = self.rng.choice(major_sites)
-                    p["treatment_site"] = new_site[0]
-                    p["country"] = new_site[1]
-
         return patients
 
-    # ── Trap Injectors ────────────────────────────────────────────────
-
-    def _inject_boundary_traps(self, patients: list[dict], n_traps: int) -> list[dict]:
-        """
-        Inject boundary-valid ages that trap naive agents.
-        Ages like 18, 19, 120 are VALID but suspicious.
-        """
-        boundary_ages = [18, 19, 20, 90, 92, 95, 96, 100, 105, 110, 115, 118, 119, 120, 120]
-        self.rng.shuffle(boundary_ages)  # Randomize which traps appear
         available = [
-            i for i, p in enumerate(patients)
             if p["patient_id"] not in self._ground_truth
-            and p["age"] is not None and 25 <= p["age"] <= 85
         ]
         self.rng.shuffle(available)
 
-        for k in range(min(n_traps, len(available), len(boundary_ages))):
-            idx = available[k]
-            patients[idx]["age"] = boundary_ages[k]
-            self._traps.add(patients[idx]["patient_id"])
-
         return patients
 
-    def _inject_temporal_traps(self, patients: list[dict], n_traps: int) -> list[dict]:
-        """
-        Inject near-temporal valid cases: death 1-3 days AFTER treatment start.
-        These are VALID but look like errors to careless agents.
-        """
         available = [
-            i for i, p in enumerate(patients)
             if p["patient_id"] not in self._ground_truth
-            and p["death_date"] is None
             and p["patient_id"] not in self._traps
         ]
         self.rng.shuffle(available)
-
-        for k in range(min(n_traps, len(available))):
-            idx = available[k]
-            p = patients[idx]
-            treatment_start = datetime.strptime(p["treatment_start"], "%Y-%m-%d")
-            # Death 1-3 days AFTER treatment — valid but suspicious
-            gap = self.rng.randint(1, 3)
-            death_date = treatment_start + timedelta(days=gap)
-            p["death_date"] = death_date.strftime("%Y-%m-%d")
-            p["outcome"] = "deceased"
-            p["stage"] = "IV"  # Make it medically plausible (late-stage)
-            self._traps.add(p["patient_id"])
-
         return patients
 
-    def _inject_fake_bias_distractors(self, patients: list[dict],
-                                      n_distractors: int) -> list[dict]:
-        """
-        Inject patterns that LOOK like bias but aren't.
-        E.g., treatment group with demographic skew (doesn't matter for bias detection
-        since only control group bias is relevant).
-        """
-        treatment_patients = [
-            i for i, p in enumerate(patients)
             if p["group"] == "treatment"
             and p["patient_id"] not in self._ground_truth
             and p["patient_id"] not in self._traps
         ]
-        self.rng.shuffle(treatment_patients)
-
-        for k in range(min(n_distractors, len(treatment_patients))):
-            idx = treatment_patients[k]
-            # Make treatment group look skewed (irrelevant for bias detection)
-            patients[idx]["ethnicity"] = "White"
-            patients[idx]["gender"] = "M"
-            self._traps.add(patients[idx]["patient_id"])
-
         return patients
 
-    def _inject_distractor_deceased(self, patients: list[dict],
-                                    n_distractors: int) -> list[dict]:
-        """
-        Add deceased patients with perfectly valid timelines.
-        These are NOT errors; this tests whether the agent over-flags deceased patients.
-        """
-        available = [
-            i for i, p in enumerate(patients)
-            if p["patient_id"] not in self._ground_truth
-            and p["death_date"] is None
-            and p["patient_id"] not in self._traps
         ]
-        self.rng.shuffle(available)
-
-        for k in range(min(n_distractors, len(available))):
-            idx = available[k]
-            p = patients[idx]
-            treatment_start = datetime.strptime(p["treatment_start"], "%Y-%m-%d")
-            # Death 30-540 days after treatment (clearly valid)
-            days = self.rng.randint(30, 540)
-            death_date = treatment_start + timedelta(days=days)
-            p["death_date"] = death_date.strftime("%Y-%m-%d")
-            p["outcome"] = "deceased"
-            self._traps.add(p["patient_id"])
 
-        return patients
 
-    # ── Main Generator ────────────────────────────────────────────────
 
     def generate(self, difficulty: str = "easy") -> dict:
-        """
-        Generate a complete adversarial dataset for the given difficulty.
-
-        Returns:
-            {
-                "dataset": List[dict],                 # Patient records
-                "ground_truth": Dict[str, List[str]],  # {pid: [error_types]}
-                "traps": Set[str],                     # Valid-but-suspicious pids
-                "bias_present": bool,                  # Whether bias was injected
-                "config": dict,                        # Generation parameters
-                "stats": dict,                         # Summary statistics
-            }
-        """
         config = DIFFICULTY_CONFIGS.get(difficulty, DIFFICULTY_CONFIGS["easy"])
         self._ground_truth = {}
         self._traps = set()
         self._patient_counter = 0
 
-        n = config["dataset_size"]
-        n_control = int(n * config["control_ratio"])
-        n_treatment = n - n_control
 
-        # ── Step 1: Generate base patients ──
         patients = []
-
-        # Determine bias mode for control group
-        control_bias_mode = "white_dominant" if config["bias_intensity"] > 0 else "neutral"
-
         for _ in range(n_control):
-            p = self._generate_base_patient("control", bias_mode=control_bias_mode)
-            p = self._apply_mortality(p, config["mortality_rate"])
-            patients.append(p)
 
         for _ in range(n_treatment):
-            p = self._generate_base_patient("treatment", bias_mode="diverse")
-            p = self._apply_mortality(p, config["mortality_rate"])
-            patients.append(p)
-
-        # ── Step 2: Inject errors ──
-        patients = self._inject_age_errors(
-            patients, config["age_error_rate"], config["missing_data_rate"]
-        )
 
         if config["temporal_error_rate"] > 0:
-            patients = self._inject_temporal_errors(
-                patients, config["temporal_error_rate"]
-            )
-
-        # ── Step 3: Inject bias (hard only) ──
-        if config["bias_intensity"] > 0:
-            patients = self._inject_bias(patients, config["bias_intensity"])
 
-        # ── Step 4: Inject adversarial traps ──
-        patients = self._inject_boundary_traps(patients, config["num_boundary_traps"])
 
         if config["num_temporal_traps"] > 0:
-            patients = self._inject_temporal_traps(
-                patients, config["num_temporal_traps"]
-            )
-
         if config["num_fake_bias_distractors"] > 0:
-            patients = self._inject_fake_bias_distractors(
-                patients, config["num_fake_bias_distractors"]
-            )
 
-        patients = self._inject_distractor_deceased(
-            patients, config["num_distractor_deceased"]
-        )
-
-        # ── Step 5: Shuffle dataset ──
         self.rng.shuffle(patients)
 
-        # ── Step 6: Compute summary stats ──
-        n_age_errors = sum(
-            1 for errs in self._ground_truth.values()
-            if "invalid_age" in errs
-        )
-        n_temporal_errors = sum(
-            1 for errs in self._ground_truth.values()
-            if "temporal_inconsistency" in errs
-        )
-        total_errors = n_age_errors + n_temporal_errors
-        if config["bias_intensity"] > 0:
-            total_errors += 1  # bias counts as 1 error
-
         stats = {
             "total_patients": len(patients),
-            "total_errors": total_errors,
-            "age_errors": n_age_errors,
-            "temporal_errors": n_temporal_errors,
-            "bias_present": config["bias_intensity"] > 0,
             "num_traps": len(self._traps),
             "control_count": sum(1 for p in patients if p["group"] == "control"),
             "treatment_count": sum(1 for p in patients if p["group"] == "treatment"),
         }
 
         return {
             "dataset": patients,
             "ground_truth": dict(self._ground_truth),
             "traps": set(self._traps),
-            "bias_present": config["bias_intensity"] > 0,
             "config": config,
             "stats": stats,
         }
 
-# ═══════════════════════════════════════════════════════════════════════════
-# STANDALONE TEST
-# ═══════════════════════════════════════════════════════════════════════════
-
 if __name__ == "__main__":
     print("=" * 60)
     print(" Dataset Generator — Validation Test")
     print("=" * 60)
 
-    for diff in ["easy", "medium", "hard"]:
-        gen = DatasetGenerator(seed=42)
-        result = gen.generate(difficulty=diff)
         stats = result["stats"]
-        print(f"\n  {diff.upper()}:")
         print(f"  Patients: {stats['total_patients']}")
-        print(f"  Errors: {stats['total_errors']} "
-              f"(age={stats['age_errors']}, temporal={stats['temporal_errors']}, "
-              f"bias={'yes' if stats['bias_present'] else 'no'})")
         print(f"  Traps: {stats['num_traps']}")
         print(f"  Control: {stats['control_count']}")
         print(f"  Treatment: {stats['treatment_count']}")
 
-        # Verify reproducibility
-        gen2 = DatasetGenerator(seed=42)
-        result2 = gen2.generate(difficulty=diff)
-        assert result["dataset"] == result2["dataset"], "REPRODUCIBILITY FAILED!"
-        assert result["ground_truth"] == result2["ground_truth"], "GROUND TRUTH MISMATCH!"
-        print(f"  ✓ Seed reproducibility verified")
-
-        # Verify ground truth
-        for pid, errors in result["ground_truth"].items():
-            patient = next(p for p in result["dataset"] if p["patient_id"] == pid)
-            for err in errors:
-                if err == "invalid_age":
                     age = patient.get("age")
-                    assert age is None or age < 18 or age > 120, \
-                        f"Ground truth says {pid} invalid age but age={age}"
-                elif err == "temporal_inconsistency":
-                    ts = datetime.strptime(patient["treatment_start"], "%Y-%m-%d")
-                    dd = datetime.strptime(patient["death_date"], "%Y-%m-%d")
-                    assert dd < ts, \
-                        f"Ground truth says {pid} temporal error but dates are valid"
-        print(f"  Ground truth integrity verified")
-
-    # Verify different seeds produce different datasets
-    gen_a = DatasetGenerator(seed=1)
-    gen_b = DatasetGenerator(seed=2)
-    result_a = gen_a.generate("easy")
-    result_b = gen_b.generate("easy")
-    assert result_a["dataset"] != result_b["dataset"], "Different seeds same data!"
-    print(f"\n  ✓ Different seeds produce different datasets")
 
     print(f"\n{'=' * 60}")
-    print(f"  ALL TESTS PASSED")
     print(f"{'=' * 60}")
 
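The deleted self-test above asserts the generator's core contract: a fixed seed reproduces the dataset and ground truth exactly, while different seeds diverge. That property comes from routing every random draw through one seeded `random.Random` instance (the generator's `self.rng`). A minimal, self-contained sketch of the same pattern (the function name here is illustrative, not from the repository):

```python
import random


def make_records(seed: int, n: int = 5) -> list[int]:
    # One RNG per build: every draw flows through this instance, so a
    # fixed seed replays the identical sequence of choices.
    rng = random.Random(seed)
    return [rng.randint(0, 100) for _ in range(n)]


a = make_records(seed=42)
b = make_records(seed=42)
c = make_records(seed=7)
assert a == b  # same seed, identical dataset
assert a != c  # different seeds diverge
```

Anything that bypasses the shared instance (e.g. module-level `random.shuffle` or wall-clock timestamps) would silently break both reproducibility assertions.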
 """
 Procedural Adversarial Clinical Trial Data Engine
+=================================================
+Generates seeded, protocol-driven clinical trial datasets for OpenEnv episodes.
+
+This generator is intentionally benchmark-oriented:
+- each episode samples a different protocol excerpt and hidden rule set
+- age eligibility is protocol-specific, not a fixed 18-120 shortcut
+- treatment scheduling uses stage-aware exceptions to create valid edge cases
+- hard episodes alternate between true bias and confounded "looks bad" cohorts
+- all labels remain deterministic and reproducible from the seed
 """
 
+from __future__ import annotations
+
 import hashlib
+import random
 from datetime import datetime, timedelta
 from typing import Optional
 
 
 HOSPITAL_SITES = [
     ("Metro General Hospital", "US"),
     ("Cleveland Oncology Institute", "US"),
...
     ("MD Anderson Cancer Center", "US"),
     ("AIIMS Delhi", "India"),
     ("Tata Memorial Hospital", "India"),
+    ("Charite Berlin", "Germany"),
+    ("Hospital Clinic Barcelona", "Spain"),
     ("Tokyo Medical University", "Japan"),
     ("Seoul National University Hospital", "South Korea"),
     ("Royal Marsden Hospital", "UK"),
...
     ("Peter MacCallum Cancer Centre", "Australia"),
 ]
 
 RURAL_SITES = {
+    "AIIMS Delhi",
     "Howard University Hospital",
+    "Tata Memorial Hospital",
 }
 
+ETHNICITIES = [
+    "White",
+    "Black",
+    "Hispanic",
+    "Asian",
+    "Native American",
+    "Pacific Islander",
+]
 GENDERS = ["M", "F"]
 STAGES = ["I", "II", "III", "IV"]
 DRUGS_TREATMENT = ["ImmunoVax-7", "OncoShield-X", "TargetCure-3"]
 
 TRIAL_START = datetime(2022, 6, 1)
 TRIAL_END = datetime(2025, 3, 1)
 
+BASE_STAGE_MORTALITY = {
+    "I": 0.04,
+    "II": 0.08,
+    "III": 0.16,
+    "IV": 0.32,
+}
+
+AGE_RULESETS = {
+    "easy": [(35, 75), (40, 80), (45, 85)],
+    "medium": [(18, 75), (21, 80), (30, 85), (40, 90)],
+    "hard": [(18, 75), (21, 80), (30, 85), (35, 85), (40, 90)],
+}
+
+WINDOW_RULESETS = {
+    "easy": [21, 24, 28],
+    "medium": [18, 21, 24, 28],
+    "hard": [14, 18, 21, 24],
+}
 
 DIFFICULTY_CONFIGS = {
     "easy": {
         "dataset_size": 300,
+        "age_error_rate": 0.020,
+        "missing_age_rate": 0.007,
+        "temporal_error_rate": 0.0,
+        "protocol_window_error_rate": 0.0,
+        "num_boundary_traps": 8,
         "num_temporal_traps": 0,
+        "num_window_traps": 0,
+        "num_distractor_deceased": 5,
         "num_fake_bias_distractors": 0,
+        "bias_probability": 0.0,
+        "control_ratio": 0.50,
+        "task_type": "eligibility_screening",
     },
     "medium": {
+        "dataset_size": 480,
+        "age_error_rate": 0.018,
+        "missing_age_rate": 0.007,
+        "temporal_error_rate": 0.012,
+        "protocol_window_error_rate": 0.015,
+        "num_boundary_traps": 10,
+        "num_temporal_traps": 4,
+        "num_window_traps": 5,
+        "num_distractor_deceased": 6,
         "num_fake_bias_distractors": 0,
+        "bias_probability": 0.0,
         "control_ratio": 0.50,
+        "task_type": "protocol_timeline_audit",
     },
     "hard": {
+        "dataset_size": 720,
+        "age_error_rate": 0.017,
+        "missing_age_rate": 0.006,
+        "temporal_error_rate": 0.010,
+        "protocol_window_error_rate": 0.014,
+        "num_boundary_traps": 12,
+        "num_temporal_traps": 5,
+        "num_window_traps": 7,
         "num_distractor_deceased": 8,
+        "num_fake_bias_distractors": 8,
+        "bias_probability": 0.58,
         "control_ratio": 0.50,
+        "task_type": "equity_and_protocol_audit",
     },
 }
 
 
 class DatasetGenerator:
+    """Seeded benchmark dataset generator."""
 
     def __init__(self, seed: Optional[int] = None):
         self.seed = seed
...
         self._patient_counter += 1
         return f"P{self._patient_counter:04d}"
 
+    def _mark_error(self, patient_id: str, error_type: str) -> None:
+        self._ground_truth.setdefault(patient_id, []).append(error_type)
+
     def _random_date(self, start: datetime, end: datetime) -> datetime:
         delta = (end - start).days
         if delta <= 0:
             return start
         return start + timedelta(days=self.rng.randint(0, delta))
 
+    def _build_protocol(self, difficulty: str, config: dict) -> dict:
+        age_min, age_max = self.rng.choice(AGE_RULESETS[difficulty])
+        treatment_window = self.rng.choice(WINDOW_RULESETS[difficulty])
+        stage_iv_window = treatment_window + self.rng.choice([7, 10, 14])
+        high_risk_sites = self.rng.sample(sorted(RURAL_SITES), k=2 if difficulty == "hard" else 1)
+        dominant_threshold = self.rng.choice([0.68, 0.70, 0.72]) if difficulty == "hard" else 0.0
+        male_threshold = self.rng.choice([0.56, 0.60, 0.63]) if difficulty == "hard" else 0.0
+        adjusted_gap = self.rng.choice([0.12, 0.15, 0.18]) if difficulty == "hard" else 0.0
+        bias_present = difficulty == "hard" and self.rng.random() < config["bias_probability"]
+        protocol_key = (
+            f"{difficulty}|{age_min}|{age_max}|{treatment_window}|"
+            f"{stage_iv_window}|{'/'.join(high_risk_sites)}|{bias_present}"
+        )
+        protocol_id = hashlib.sha1(protocol_key.encode("utf-8")).hexdigest()[:8].upper()
+        protocol_title = f"ONCO-AX-{difficulty[:1].upper()}{protocol_id}"
+
+        lines = [
+            f"TRIAL PROTOCOL EXCERPT — {protocol_title}",
+            "",
+            "Eligibility",
+            f"- Participants must be age {age_min}-{age_max} inclusive on enrollment.",
+            "- Missing age is a protocol violation.",
+            "",
+            "Treatment Scheduling",
+            f"- Treatment must begin within {treatment_window} days of enrollment.",
+            (
+                f"- Stage IV exception: treatment may begin within "
+                f"{stage_iv_window} days of enrollment."
+            ),
+            "- death_date must never precede treatment_start.",
+            "- Do not assume a generic 18-120 range; this excerpt overrides defaults.",
+        ]
+
+        if difficulty == "hard":
+            lines.extend(
+                [
+                    "",
+                    "Equity Review",
+                    "- Selection bias concerns control-arm composition, not treatment-arm skew.",
+                    "- Compare mortality within stage strata before escalating a bias concern.",
+                    (
+                        f"- Stage IV patients at {', '.join(high_risk_sites)} are a known "
+                        "high-risk outreach cohort; elevated mortality there is not standalone "
+                        "evidence of unfair sampling."
+                    ),
+                    (
+                        f"- Escalate bias only when control-arm dominance exceeds "
+                        f"{int(dominant_threshold * 100)}%, male share exceeds "
+                        f"{int(male_threshold * 100)}%, and stage-adjusted mortality gap "
+                        f"exceeds {int(adjusted_gap * 100)} percentage points."
+                    ),
+                ]
+            )
+
+        return {
+            "protocol_id": protocol_id,
+            "protocol_title": protocol_title,
+            "excerpt": "\n".join(lines),
+            "age_min": age_min,
+            "age_max": age_max,
+            "treatment_window_days": treatment_window,
+            "stage_iv_treatment_window_days": stage_iv_window,
+            "high_risk_sites": high_risk_sites,
+            "bias_control_dominance_threshold": dominant_threshold,
+            "bias_male_threshold": male_threshold,
+            "bias_stage_adjusted_gap": adjusted_gap,
+            "bias_present": bias_present,
+        }
+
+    def _generate_age(self, protocol: dict) -> int:
         while True:
+            age = int(self.rng.gauss(58, 11))
+            if protocol["age_min"] <= age <= protocol["age_max"]:
                 return age
 
     def _select_ethnicity(self, bias_mode: str = "neutral") -> str:
+        if bias_mode == "diverse":
+            weights = [0.28, 0.19, 0.20, 0.18, 0.10, 0.05]
+        elif bias_mode == "white_dominant":
+            weights = [0.68, 0.08, 0.08, 0.08, 0.05, 0.03]
+        else:
+            weights = [0.50, 0.16, 0.15, 0.12, 0.04, 0.03]
         return self.rng.choices(ETHNICITIES, weights=weights, k=1)[0]
 
+    def _base_delay(self, stage: str, protocol: dict) -> int:
+        max_window = (
+            protocol["stage_iv_treatment_window_days"]
+            if stage == "IV"
+            else protocol["treatment_window_days"]
+        )
+        lower = 5 if max_window >= 10 else 1
+        upper = max(lower, max_window - 2)
+        return self.rng.randint(lower, upper)
+
+    def _generate_base_patient(self, group: str, protocol: dict, bias_mode: str = "neutral") -> dict:
         pid = self._next_pid()
         site, country = self.rng.choice(HOSPITAL_SITES)
+        stage = self.rng.choices(STAGES, weights=[0.24, 0.28, 0.28, 0.20], k=1)[0]
+        age = self._generate_age(protocol)
+        enrollment_end = TRIAL_END - timedelta(days=150)
+        enrollment_date = self._random_date(TRIAL_START, enrollment_end)
+        treatment_start = enrollment_date + timedelta(days=self._base_delay(stage, protocol))
+        return {
             "patient_id": pid,
             "age": age,
+            "gender": self.rng.choice(GENDERS),
+            "ethnicity": self._select_ethnicity(bias_mode),
             "group": group,
             "treatment_start": treatment_start.strftime("%Y-%m-%d"),
             "death_date": None,
...
             "treatment_site": site,
             "stage": stage,
             "trial_phase": "Phase III",
+            "drug": self.rng.choice(DRUGS_TREATMENT) if group == "treatment" else "Placebo",
             "enrollment_date": enrollment_date.strftime("%Y-%m-%d"),
             "country": country,
         }
 
+    def _mortality_rate(self, patient: dict, protocol: dict) -> float:
+        rate = BASE_STAGE_MORTALITY.get(patient["stage"], 0.10)
+        if patient["treatment_site"] in protocol["high_risk_sites"] and patient["stage"] == "IV":
+            rate += 0.16
+        if patient["group"] == "treatment":
+            rate *= 0.92
+        return max(0.02, min(0.82, rate))
+
+    def _set_deceased(self, patient: dict, min_days: int, max_days: int) -> None:
+        treatment_start = datetime.strptime(patient["treatment_start"], "%Y-%m-%d")
+        days_to_death = self.rng.randint(min_days, max_days)
+        death_date = treatment_start + timedelta(days=days_to_death)
+        patient["death_date"] = death_date.strftime("%Y-%m-%d")
+        patient["outcome"] = "deceased"
+
+    def _apply_mortality(self, patient: dict, protocol: dict) -> dict:
+        if self.rng.random() < self._mortality_rate(patient, protocol):
+            self._set_deceased(patient, min_days=3, max_days=540)
         return patient
 
+    def _apply_target_mortality(self, cohort: list[dict], target_rate: float) -> None:
+        if not cohort:
+            return
+        self.rng.shuffle(cohort)
+        target_count = int(round(len(cohort) * max(0.0, min(1.0, target_rate))))
+        for index, patient in enumerate(cohort):
+            if index < target_count:
+                self._set_deceased(patient, min_days=10, max_days=420)
+            else:
+                patient["death_date"] = None
+                patient["outcome"] = "survived"
+
+    def _allowed_treatment_window(self, patient: dict, protocol: dict) -> int:
+        return (
+            protocol["stage_iv_treatment_window_days"]
+            if patient.get("stage") == "IV"
+            else protocol["treatment_window_days"]
+        )
 
+    def _enrollment_date(self, patient: dict) -> datetime:
+        return datetime.strptime(patient["enrollment_date"], "%Y-%m-%d")
 
+    def _treatment_date(self, patient: dict) -> datetime:
+        return datetime.strptime(patient["treatment_start"], "%Y-%m-%d")
 
+    def _inject_age_errors(self, patients: list[dict], protocol: dict, config: dict) -> list[dict]:
+        n_invalid = max(3, int(len(patients) * config["age_error_rate"]))
+        n_missing = max(1, int(len(patients) * config["missing_age_rate"]))
         available = list(range(len(patients)))
         self.rng.shuffle(available)
 
+        low_values = [protocol["age_min"] - 1, protocol["age_min"] - 2, max(0, protocol["age_min"] - 5), -1]
+        high_values = [protocol["age_max"] + 1, protocol["age_max"] + 2, protocol["age_max"] + 5, 999]
 
+        for offset in range(min(n_invalid, len(available))):
+            patient = patients[available[offset]]
+            patient["age"] = self.rng.choice(low_values + high_values)
+            self._mark_error(patient["patient_id"], "invalid_age")
 
+        start = min(n_invalid, len(available))
+        for offset in range(start, min(start + n_missing, len(available))):
+            patient = patients[available[offset]]
+            patient["age"] = None
+            self._mark_error(patient["patient_id"], "invalid_age")
 
+        return patients
 
+    def _inject_temporal_errors(self, patients: list[dict], config: dict) -> list[dict]:
+        n_errors = max(3, int(len(patients) * config["temporal_error_rate"]))
+        candidates = [p for p in patients if p["patient_id"] not in self._ground_truth]
         self.rng.shuffle(candidates)
 
+        for patient in candidates[:n_errors]:
+            treatment_start = self._treatment_date(patient)
+            death_date = treatment_start - timedelta(days=self.rng.randint(10, 240))
+            patient["death_date"] = death_date.strftime("%Y-%m-%d")
+            patient["outcome"] = "deceased"
+            self._mark_error(patient["patient_id"], "temporal_inconsistency")
 
+        return patients
 
+    def _inject_protocol_window_errors(
+        self,
+        patients: list[dict],
+        protocol: dict,
+        config: dict,
+    ) -> list[dict]:
+        n_errors = max(3, int(len(patients) * config["protocol_window_error_rate"]))
+        candidates = [p for p in patients if p["patient_id"] not in self._ground_truth]
+        self.rng.shuffle(candidates)
 
+        for patient in candidates[:n_errors]:
+            allowed_days = self._allowed_treatment_window(patient, protocol)
+            enrollment = self._enrollment_date(patient)
+            violation_days = allowed_days + self.rng.randint(2, 18)
+            patient["treatment_start"] = (enrollment + timedelta(days=violation_days)).strftime("%Y-%m-%d")
+            if patient["death_date"]:
+                death_date = datetime.strptime(patient["death_date"], "%Y-%m-%d")
+                treatment_start = self._treatment_date(patient)
+                if death_date <= treatment_start:
+                    self._set_deceased(patient, min_days=20, max_days=320)
+            self._mark_error(patient["patient_id"], "protocol_window_violation")
 
         return patients
 
+    def _inject_boundary_traps(self, patients: list[dict], protocol: dict, n_traps: int) -> list[dict]:
+        valid_ages = [
+            protocol["age_min"],
+            protocol["age_min"] + 1,
+            protocol["age_min"] + 2,
+            protocol["age_max"] - 2,
+            protocol["age_max"] - 1,
+            protocol["age_max"],
         ]
+        available = [
+            p
+            for p in patients
+            if p["patient_id"] not in self._ground_truth and p["age"] is not None
         ]
+        self.rng.shuffle(available)
+        for patient, age in zip(available[:n_traps], valid_ages * max(1, n_traps)):
+            patient["age"] = age
+            self._traps.add(patient["patient_id"])
         return patients
 
+    def _inject_temporal_traps(self, patients: list[dict], n_traps: int) -> list[dict]:
         available = [
+            p
+            for p in patients
             if p["patient_id"] not in self._ground_truth
+            and p["patient_id"] not in self._traps
+            and p["death_date"] is None
         ]
         self.rng.shuffle(available)
+        for patient in available[:n_traps]:
+            patient["stage"] = "IV"
+            self._set_deceased(patient, min_days=1, max_days=3)
+            self._traps.add(patient["patient_id"])
+        return patients
 
+    def _inject_window_traps(self, patients: list[dict], protocol: dict, n_traps: int) -> list[dict]:
+        available = [
+            p
+            for p in patients
+            if p["patient_id"] not in self._ground_truth and p["patient_id"] not in self._traps
+        ]
+        self.rng.shuffle(available)
+        for patient in available[:n_traps]:
+            enrollment = self._enrollment_date(patient)
+            if self.rng.random() < 0.55:
+                patient["stage"] = "IV"
+            allowed_days = self._allowed_treatment_window(patient, protocol)
+            trap_delay = max(1, allowed_days - self.rng.choice([0, 1]))
+            patient["treatment_start"] = (enrollment + timedelta(days=trap_delay)).strftime("%Y-%m-%d")
+            if patient["death_date"]:
+                death_date = datetime.strptime(patient["death_date"], "%Y-%m-%d")
+                if death_date <= self._treatment_date(patient):
+                    self._set_deceased(patient, min_days=12, max_days=240)
+            self._traps.add(patient["patient_id"])
         return patients
 
+    def _inject_distractor_deceased(self, patients: list[dict], n_distractors: int) -> list[dict]:
         available = [
+            p
+            for p in patients
             if p["patient_id"] not in self._ground_truth
             and p["patient_id"] not in self._traps
+            and p["death_date"] is None
         ]
         self.rng.shuffle(available)
+        for patient in available[:n_distractors]:
+            self._set_deceased(patient, min_days=30, max_days=520)
+            self._traps.add(patient["patient_id"])
         return patients
 
+    def _inject_fake_bias_distractors(self, patients: list[dict], n_distractors: int) -> list[dict]:
+        treatment_group = [
+            p
+            for p in patients
             if p["group"] == "treatment"
             and p["patient_id"] not in self._ground_truth
             and p["patient_id"] not in self._traps
         ]
+        self.rng.shuffle(treatment_group)
+        for patient in treatment_group[:n_distractors]:
+            patient["ethnicity"] = "White"
+            patient["gender"] = "M"
+            if self.rng.random() < 0.5:
+                patient["stage"] = "IV"
+                self._set_deceased(patient, min_days=15, max_days=180)
+            self._traps.add(patient["patient_id"])
         return patients
 
+    def _inject_selection_bias(self, patients: list[dict], protocol: dict) -> None:
+        control = [
+            p
+            for p in patients
+            if p["group"] == "control" and p["patient_id"] not in self._ground_truth
+        ]
+        if not control:
+            return
+
+        target_dom_ratio = protocol["bias_control_dominance_threshold"] + self.rng.uniform(0.06, 0.12)
+        dominant_target = int(len(control) * min(0.86, target_dom_ratio))
+        white_control = [p for p in control if p["ethnicity"] == "White"]
+        non_white_control = [p for p in control if p["ethnicity"] != "White"]
+        needed = max(0, dominant_target - len(white_control))
+        self.rng.shuffle(non_white_control)
+        for patient in non_white_control[:needed]:
+            patient["ethnicity"] = "White"
+
+        target_male_ratio = protocol["bias_male_threshold"] + self.rng.uniform(0.05, 0.10)
+        male_target = int(len(control) * min(0.82, target_male_ratio))
+        male_control = [p for p in control if p["gender"] == "M"]
+        female_control = [p for p in control if p["gender"] == "F"]
+        needed_male = max(0, male_target - len(male_control))
+        self.rng.shuffle(female_control)
+        for patient in female_control[:needed_male]:
+            patient["gender"] = "M"
+
+        dominant = [p for p in control if p["ethnicity"] == "White"]
+        minority = [p for p in control if p["ethnicity"] != "White"]
+        for stage in STAGES:
+            stage_majority = [p for p in dominant if p["stage"] == stage]
+            stage_minority = [p for p in minority if p["stage"] == stage]
+            if not stage_majority or not stage_minority:
+                continue
+            base = BASE_STAGE_MORTALITY[stage]
+            self._apply_target_mortality(stage_majority, max(0.02, base - 0.03))
+            self._apply_target_mortality(stage_minority, min(0.82, base + 0.18))
+
+    def _inject_confounder_cohort(self, patients: list[dict], protocol: dict) -> None:
+        control = [
+            p
+            for p in patients
+            if p["group"] == "control" and p["patient_id"] not in self._ground_truth
+        ]
+        if not control:
+            return
+
+        minority = [p for p in control if p["ethnicity"] != "White"]
+        white = [p for p in control if p["ethnicity"] == "White"]
+        self.rng.shuffle(minority)
+        self.rng.shuffle(white)
+
+        minority_shift = max(8, len(control) // 18)
+        white_shift = max(4, len(control) // 30)
+
+        for patient in minority[:minority_shift]:
+            patient["stage"] = "IV"
+            patient["treatment_site"] = self.rng.choice(protocol["high_risk_sites"])
+            patient["country"] = next(
+                country for site, country in HOSPITAL_SITES if site == patient["treatment_site"]
+            )
+
+        for patient in white[:white_shift]:
+            patient["stage"] = "IV"
+            patient["treatment_site"] = self.rng.choice(protocol["high_risk_sites"])
+            patient["country"] = next(
+                country for site, country in HOSPITAL_SITES if site == patient["treatment_site"]
+            )
+
+        stage_iv_control = [p for p in control if p["stage"] == "IV"]
+        stage_iv_minority = [p for p in stage_iv_control if p["ethnicity"] != "White"]
+        stage_iv_white = [p for p in stage_iv_control if p["ethnicity"] == "White"]
+        self._apply_target_mortality(stage_iv_minority, 0.66)
+        self._apply_target_mortality(stage_iv_white, 0.63)
 
     def generate(self, difficulty: str = "easy") -> dict:
         config = DIFFICULTY_CONFIGS.get(difficulty, DIFFICULTY_CONFIGS["easy"])
         self._ground_truth = {}
         self._traps = set()
         self._patient_counter = 0
 
+        protocol = self._build_protocol(difficulty, config)
+        n_patients = config["dataset_size"]
+        n_control = int(n_patients * config["control_ratio"])
+        n_treatment = n_patients - n_control
 
         patients = []
         for _ in range(n_control):
+            patient = self._generate_base_patient("control", protocol, bias_mode="neutral")
+            patients.append(self._apply_mortality(patient, protocol))
 
         for _ in range(n_treatment):
+            patient = self._generate_base_patient("treatment", protocol, bias_mode="diverse")
+            patients.append(self._apply_mortality(patient, protocol))
 
+        patients = self._inject_age_errors(patients, protocol, config)
         if config["temporal_error_rate"] > 0:
+            patients = self._inject_temporal_errors(patients, config)
+        if config["protocol_window_error_rate"] > 0:
+            patients = self._inject_protocol_window_errors(patients, protocol, config)
 
+        if difficulty == "hard":
+            if protocol["bias_present"]:
+                self._inject_selection_bias(patients, protocol)
+            else:
+                self._inject_confounder_cohort(patients, protocol)
 
+        patients = self._inject_boundary_traps(patients, protocol, config["num_boundary_traps"])
         if config["num_temporal_traps"] > 0:
+            patients = self._inject_temporal_traps(patients, config["num_temporal_traps"])
+        if config["num_window_traps"] > 0:
+            patients = self._inject_window_traps(patients, protocol, config["num_window_traps"])
+        patients = self._inject_distractor_deceased(patients, config["num_distractor_deceased"])
         if config["num_fake_bias_distractors"] > 0:
+            patients = self._inject_fake_bias_distractors(patients, config["num_fake_bias_distractors"])
 
         self.rng.shuffle(patients)
 
         stats = {
             "total_patients": len(patients),
+            "age_errors": sum("invalid_age" in errs for errs in self._ground_truth.values()),
+            "temporal_errors": sum("temporal_inconsistency" in errs for errs in self._ground_truth.values()),
+            "protocol_window_errors": sum("protocol_window_violation" in errs for errs in self._ground_truth.values()),
+            "bias_present": protocol["bias_present"],
+            "bias_mode": "true_bias" if protocol["bias_present"] else ("confounded_no_bias" if difficulty == "hard" else "none"),
             "num_traps": len(self._traps),
             "control_count": sum(1 for p in patients if p["group"] == "control"),
             "treatment_count": sum(1 for p in patients if p["group"] == "treatment"),
+            "protocol_title": protocol["protocol_title"],
         }
+        stats["total_errors"] = (
+            stats["age_errors"]
+            + stats["temporal_errors"]
+            + stats["protocol_window_errors"]
+            + (1 if protocol["bias_present"] else 0)
+        )
 
         return {
             "dataset": patients,
             "ground_truth": dict(self._ground_truth),
             "traps": set(self._traps),
+            "bias_present": protocol["bias_present"],
+            "protocol": protocol,
+            "protocol_excerpt": protocol["excerpt"],
+            "protocol_title": protocol["protocol_title"],
             "config": config,
             "stats": stats,
         }
 
 
 if __name__ == "__main__":
     print("=" * 60)
     print(" Dataset Generator — Validation Test")
     print("=" * 60)
 
+    for difficulty in ["easy", "medium", "hard"]:
+        generator = DatasetGenerator(seed=42)
+        result = generator.generate(difficulty=difficulty)
         stats = result["stats"]
+        protocol = result["protocol"]
+        print(f"\n {difficulty.upper()}:")
+        print(f"   Protocol: {stats['protocol_title']}")
         print(f"   Patients: {stats['total_patients']}")
+        print(
+            f"   Errors: {stats['total_errors']} "
+            f"(age={stats['age_errors']}, temporal={stats['temporal_errors']}, "
+            f"window={stats['protocol_window_errors']}, bias={stats['bias_mode']})"
+        )
         print(f"   Traps: {stats['num_traps']}")
         print(f"   Control: {stats['control_count']}")
         print(f"   Treatment: {stats['treatment_count']}")
+        print(
+            f"   Rules: age={protocol['age_min']}-{protocol['age_max']} | "
+            f"start<={protocol['treatment_window_days']}d | "
+            f"stage IV<={protocol['stage_iv_treatment_window_days']}d"
+        )
 
+        generator_2 = DatasetGenerator(seed=42)
+        result_2 = generator_2.generate(difficulty=difficulty)
+        assert result["dataset"] == result_2["dataset"], "REPRODUCIBILITY FAILED!"
+        assert result["ground_truth"] == result_2["ground_truth"], "GROUND TRUTH MISMATCH!"
+        assert result["protocol_excerpt"] == result_2["protocol_excerpt"], "PROTOCOL MISMATCH!"
+        print("   ✓ Seed reproducibility verified")
+
+        for patient_id, errors in result["ground_truth"].items():
+            patient = next(p for p in result["dataset"] if p["patient_id"] == patient_id)
+            for error in errors:
+                if error == "invalid_age":
                     age = patient.get("age")
+                    assert age is None or age < protocol["age_min"] or age > protocol["age_max"], (
+                        f"Ground truth says {patient_id} invalid age but age={age}"
+                    )
+                elif error == "temporal_inconsistency":
+                    treatment_start = datetime.strptime(patient["treatment_start"], "%Y-%m-%d")
+                    death_date = datetime.strptime(patient["death_date"], "%Y-%m-%d")
+                    assert death_date < treatment_start, (
+                        f"Ground truth says {patient_id} temporal error but dates are valid"
+                    )
+                elif error == "protocol_window_violation":
+                    enrollment = datetime.strptime(patient["enrollment_date"], "%Y-%m-%d")
+                    treatment_start = datetime.strptime(patient["treatment_start"], "%Y-%m-%d")
+                    allowed = (
+                        protocol["stage_iv_treatment_window_days"]
+                        if patient["stage"] == "IV"
+                        else protocol["treatment_window_days"]
+                    )
+                    assert (treatment_start - enrollment).days > allowed, (
+                        f"Ground truth says {patient_id} window error but delay is valid"
+                    )
+        print("   ✓ Ground truth integrity verified")
+
+    generator_a = DatasetGenerator(seed=1)
+    generator_b = DatasetGenerator(seed=2)
+    result_a = generator_a.generate("easy")
+    result_b = generator_b.generate("easy")
+    assert result_a["dataset"] != result_b["dataset"], "Different seeds generated identical datasets!"
+    assert result_a["protocol_excerpt"] != result_b["protocol_excerpt"], "Different seeds generated identical protocols!"
+    print("\n ✓ Different seeds produce different datasets")
     print(f"\n{'=' * 60}")
+    print(" ALL TESTS PASSED")
     print(f"{'=' * 60}")
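The validation harness above hinges on one contract: a `DatasetGenerator` built with the same seed must replay the identical dataset, while different seeds must diverge. A minimal dependency-free sketch of that contract, assuming the standard per-instance-RNG pattern (`TinyGenerator` and its fields are illustrative stand-ins, not the repo's `DatasetGenerator`):

```python
import random


class TinyGenerator:
    """Illustrative stand-in for a seeded dataset generator."""

    def __init__(self, seed: int):
        # Each instance owns its own RNG, so episodes never share state
        # and module-level random.* calls can't break reproducibility.
        self.rng = random.Random(seed)

    def generate(self, n: int = 5) -> list[dict]:
        # Every draw goes through self.rng, so a fixed seed
        # replays the exact same dataset.
        return [
            {"patient_id": f"P{i:03d}", "age": self.rng.randint(18, 85)}
            for i in range(n)
        ]


a = TinyGenerator(seed=42).generate()
b = TinyGenerator(seed=42).generate()
c = TinyGenerator(seed=7).generate()
assert a == b   # same seed -> identical dataset
assert a != c   # different seed -> divergent dataset
```

The same design shows up in the real generator: all stochastic choices (`self.rng.shuffle`, `self.rng.choice`, `self.rng.uniform`) route through the instance RNG, which is what makes the `seed=42` reproducibility assertions in the `__main__` block hold.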
server/models.py CHANGED
@@ -2,6 +2,7 @@ from typing import Optional, List, Dict, Any
 from pydantic import Field
 from openenv.core.env_server import Action, Observation, State
 
+
 class AuditAction(Action):
     action_type: str = "flag_error"
     patient_id: Optional[str] = None
@@ -12,30 +13,43 @@ class AuditAction(Action):
     report: Optional[str] = None
     confidence: Optional[float] = None  # 0.0-1.0: agent's confidence in this action
 
+
 class AuditObservation(Observation):
     done: bool = False
     reward: float = 0.0
     task_id: str = ""
     task_type: str = ""
     task_description: str = ""
+    protocol_title: str = ""
+    trial_protocol_excerpt: str = ""
     dataset: List[Dict[str, Any]] = Field(default_factory=list)
     errors_found: List[str] = Field(default_factory=list)
     patterns_investigated: List[str] = Field(default_factory=list)
     distributions_computed: List[str] = Field(default_factory=list)
     feedback: Optional[str] = None
     score_so_far: float = 0.0
+    dense_reward_total: float = 0.0
+    score_breakdown: Dict[str, float] = Field(default_factory=dict)
     attempts_remaining: int = 15
     phase: str = "investigation"
 
+
 class AuditState(State):
     episode_id: str = ""
     step_count: int = 0
     task_id: str = ""
     task_type: str = ""
+    protocol_title: str = ""
+    trial_protocol_excerpt: str = ""
     total_errors: int = 0
     errors_found: int = 0
     current_score: float = 0.0
+    dense_reward_total: float = 0.0
+    correct_flags: int = 0
+    false_positives: int = 0
+    duplicate_flags: int = 0
     attempts: int = 0
     phase: str = "investigation"
+    score_breakdown: Dict[str, float] = Field(default_factory=dict)
     patterns_investigated: List[str] = Field(default_factory=list)
-    distributions_computed: List[str] = Field(default_factory=list)
+    distributions_computed: List[str] = Field(default_factory=list)
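The list- and dict-valued fields on these models all use `Field(default_factory=...)` rather than a literal `[]`/`{}` default, so each episode's observation and state get their own fresh containers. Pydantic's `Field(default_factory=...)` mirrors the stdlib `dataclasses` pattern; a dependency-free sketch of why that matters (`AuditObservationSketch` is a hypothetical simplified stand-in, using `dataclasses` instead of pydantic):

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class AuditObservationSketch:
    # field(default_factory=...) here plays the role of pydantic's
    # Field(default_factory=...) in AuditObservation.
    score_so_far: float = 0.0
    score_breakdown: Dict[str, float] = field(default_factory=dict)
    errors_found: List[str] = field(default_factory=list)


obs_a = AuditObservationSketch()
obs_b = AuditObservationSketch()
obs_a.errors_found.append("invalid_age")

# Each instance gets a fresh list/dict, so mutating one
# observation never leaks into another episode's observation.
assert obs_b.errors_found == []
assert obs_a.score_breakdown is not obs_b.score_breakdown
```

A shared mutable default (e.g. `errors_found: List[str] = []` in a plain class) would leak flags across episodes; the factory pattern rules that out by construction.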