Sumit Saraswat committed
Commit · a7bca03
Parent(s): 0dca8bf
Restructured Dockerfile and requirements to root for Hugging Face deployment
Changed files:
- .DS_Store (+0 -0)
- server/Dockerfile → Dockerfile (+1 -1)
- README.md (+122 -124)
- inference.py (+655 -711)
- models.py (+15 -1)
- openenv.yaml (+15 -13)
- server/requirements.txt → requirements.txt (+0 -0)
- server.log (+4 -0)
- server/app.py (+9 -4)
- server/clinical_trial_auditor_environment.py (+530 -506)
- server/dataset_generator.py (+492 -476)
- server/models.py (+15 -1)
.DS_Store
CHANGED

Binary files a/.DS_Store and b/.DS_Store differ
server/Dockerfile → Dockerfile
RENAMED

```diff
@@ -17,4 +17,4 @@ EXPOSE 8000
 HEALTHCHECK --interval=30s --timeout=10s --retries=3 \
     CMD curl -f http://localhost:8000/health || exit 1
 
-CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000"]
+CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "8000"]
```
README.md
CHANGED

````diff
@@ -1,88 +1,67 @@
 # Clinical Trial Auditor (OpenEnv)
-
-- syntactic data quality errors (invalid/missing age),
-- temporal inconsistencies (death before treatment),
-- multi-dimensional selection bias in control cohorts.
-
-- procedural generation on every `reset()`,
-- deterministic seed reproducibility,
-- adversarial traps that punish shallow heuristics,
-- deterministic programmatic graders with scores in `[0.0, 1.0]`.
-
-- invalidate endpoints,
-- distort treatment effect estimates,
-- create regulatory and patient-safety risk.
-
----
 ## OpenEnv Compliance
 
 This project implements the required OpenEnv interface:
-- typed `Action`, `Observation`, `State` models
 - `reset(seed, task_id, ...) -> Observation`,
 - `step(action) -> Observation`,
 - `state -> current state`,
-- `openenv.yaml`
 
 Validation:
 ```bash
 openenv validate .
 ```
 
-## Environment Architecture
-
-`ClinicalTrialAuditorEnvironment` is intentionally layered:
-
-2. **Trap Engine**
-   - Boundary-valid traps (`18`, `120`, etc.),
-   - near-temporal valid traps (death 1-3 days after treatment),
-   - fake bias distractors.
-3. **Scoring Engine**
-   - Deterministic ground-truth lookup for each flag.
-   - Partial progress rewards + false-positive penalties.
-   - Confidence calibration (overconfident wrong answers are punished harder).
-4. **Agent Interface**
-   - Standard OpenEnv `step/reset/state`.
-
----
-
-## Task Suite (Easy -> Medium -> Hard)
-
-- Typical size: `300` patients
-- Objective: detect all `invalid_age` cases only
-- Includes valid edge-case age traps to punish naive thresholding
-- Bias grading disabled
-
-### Task
-### Task
 
 ## Action Space
@@ -91,7 +70,7 @@ class AuditAction(Action):
     action_type: str  # investigate_pattern | compute_distribution | flag_error | propose_fix | submit_report
     variable: Optional[str]
     patient_id: Optional[str]
-    error_type: Optional[str]  # invalid_age | temporal_inconsistency | selection_bias
     reason: Optional[str]
     proposed_value: Optional[str]
     report: Optional[str]
@@ -107,130 +86,151 @@ class AuditObservation(Observation):
     task_id: str
     task_type: str
     task_description: str
     dataset: list[dict]
     errors_found: list[str]
     patterns_investigated: list[str]
     distributions_computed: list[str]
     feedback: str
     score_so_far: float
     attempts_remaining: int
     phase: str
 ```
 
-- hard-task bias detection bonus: `+0.20`
-- difficulty multipliers by task
-- score clamped to `[0.0, 1.0]`
-
-This
-
-Generator script:
 ```bash
-python3 dataset_generator.py
 ```
 
 What it guarantees:
-- same seed ->
-- different seeds -> different datasets,
 
-Example validated
-- Easy: `300` patients, `12` injected errors, traps enabled
-- Medium: `500` patients, `37` injected errors, traps enabled
-- Hard: `800` patients, `49` injected errors + bias signal, traps enabled
 
---
 
 ## Baseline Inference (`inference.py`)
 
-`inference.py`
-
-Run:
 ```bash
 python3 inference.py --mode all
 ```
 
-- `API_BASE_URL`
-- `MODEL_NAME`
-- `HF_TOKEN` or `OPENAI_API_KEY`
-- `ENV_BASE_URL` (defaults to `http://localhost:8000`)
 
-### 1) Start server
 ```bash
 cd server
 PYTHONPATH=.. python3 -m uvicorn app:app --host 0.0.0.0 --port 8000
 ```
 
 ### 2) Health check
 ```bash
 curl -s http://localhost:8000/health
 ```
 
-### 3) Run baseline
 ```bash
 cd ..
-python3 inference.py --mode
 ```
 
----
-
 ## Docker
 
 Build and run:
 ```bash
 cd server
 docker build -t clinical-trial-auditor:latest .
 docker run -p 8000:8000 clinical-trial-auditor:latest
 ```
 
----
 
 ## Hugging Face Space Readiness Checklist
 
-- [x] OpenEnv interface implemented
-- [x] typed models for
 - [x] `openenv.yaml` present
-- [x] 3 tasks with deterministic graders and
-- [x] dockerized server
-- [x] `openenv validate .` passes
-
----
 
 ## Project Structure
@@ -250,8 +250,6 @@ clinical_trial_auditor/
 └── Dockerfile
 ```
 
----
-
 ## Motivation
 
-This benchmark is
````
# Clinical Trial Auditor (OpenEnv)

Clinical Trial Auditor is a protocol-aware OpenEnv benchmark for clinical data auditing. The agent acts as a Senior Clinical Data Manager reviewing procedurally generated Phase III oncology trial data under dynamic per-episode rules.

This is not a static spreadsheet puzzle. Every `reset()` samples a new protocol excerpt and a new dataset, so the agent must read the rules for that episode and then audit the records accordingly.

## Why This Matters

Real clinical audits are messy:

- eligibility criteria vary by protocol,
- timeline rules include exceptions,
- suspicious subgroup outcomes are not always evidence of bias,
- false positives waste reviewer time and can trigger unnecessary escalations.

This environment is built to evaluate exactly those failure modes. It targets the gap between "can parse a table" and "can follow a high-stakes auditing workflow with protocol friction and adversarial traps."

## What Makes This Benchmark Different

- Dynamic protocol reasoning: each episode exposes a new `trial_protocol_excerpt` with episode-specific age ranges and treatment-start windows.
- Cross-modal audit logic: the agent must apply text rules from the protocol to tabular patient data.
- Stage-aware timing exceptions: Stage IV patients can have a longer enrollment-to-treatment window, which creates valid edge cases that trap shortcut heuristics.
- Hallucination traps: hard episodes can contain a confounded high-risk cohort that looks biased overall but is not actionable after stage-adjusted review.
- Dense reward plus benchmark rubric: step rewards encourage learning, while `score_so_far` tracks a judge-facing episode rubric emphasizing recall, precision, workflow discipline, efficiency, and report quality.
## OpenEnv Compliance

This project implements the required OpenEnv interface:

- typed `Action`, `Observation`, and `State` models with Pydantic,
- `reset(seed, task_id, ...) -> Observation`,
- `step(action) -> Observation`,
- `state -> current state`,
- `openenv.yaml` at the repo root.

Validation:

```bash
openenv validate .
```

Local validation result:

```text
[OK] : Ready for multi-mode deployment
```
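For intuition, the `reset`/`step`/`state` contract above can be sketched as a minimal in-memory environment. This is a toy stand-in under assumed names (`ToyAuditEnv`, a trimmed `Observation`), not the repo's actual implementation:

```python
import random
from dataclasses import dataclass


@dataclass
class Observation:
    # Illustrative subset of the observation fields documented in this README.
    task_id: str
    feedback: str = ""
    score_so_far: float = 0.0


class ToyAuditEnv:
    """Hypothetical stand-in showing only the reset/step/state contract."""

    def __init__(self) -> None:
        self._state: dict = {}

    def reset(self, seed: int = 0, task_id: str = "task_easy") -> Observation:
        self._rng = random.Random(seed)  # all later sampling is seed-determined
        self._state = {"task_id": task_id, "steps": 0}
        return Observation(task_id=task_id, feedback="episode started")

    def step(self, action: dict) -> Observation:
        self._state["steps"] += 1
        return Observation(
            task_id=self._state["task_id"],
            feedback=f"applied {action.get('action_type')}",
        )

    @property
    def state(self) -> dict:
        return dict(self._state)


env = ToyAuditEnv()
obs = env.reset(seed=42, task_id="task_easy")
obs2 = env.step({"action_type": "flag_error", "patient_id": "P001"})
```

The real environment returns the full `AuditObservation` documented below; the shape of the loop is the same.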
## Task Suite

### Task 1: `task_easy` — Dynamic Eligibility Screening

- Dataset size: about `300` patients
- Goal: flag `invalid_age`
- Difficulty source: the age bounds are episode-specific, not fixed at 18-120
- Traps: valid edge ages at the protocol boundary

### Task 2: `task_medium` — Protocol Timeline Audit

- Dataset size: about `480` patients
- Goal: flag `invalid_age`, `temporal_inconsistency`, and `protocol_window_violation`
- Difficulty source: the treatment-start window is protocol-specific and Stage IV has a longer valid window
- Traps: valid near-boundary start delays and near-immediate but valid deaths

### Task 3: `task_hard` — Equity + Protocol Audit

- Dataset size: about `720` patients
- Goal: flag record-level issues and determine whether actionable `selection_bias` exists
- Difficulty source: some hard episodes contain real control-arm bias, while others contain a confounded high-risk cohort that only looks biased before stage adjustment
- Traps: treatment-arm skew, high-risk outreach sites, and false-positive bias patterns
## Action Space

```python
class AuditAction(Action):
    action_type: str  # investigate_pattern | compute_distribution | flag_error | propose_fix | submit_report
    variable: Optional[str]
    patient_id: Optional[str]
    error_type: Optional[str]  # invalid_age | temporal_inconsistency | protocol_window_violation | selection_bias
    reason: Optional[str]
    proposed_value: Optional[str]
    report: Optional[str]
```

```python
class AuditObservation(Observation):
    task_id: str
    task_type: str
    task_description: str
    protocol_title: str
    trial_protocol_excerpt: str
    dataset: list[dict]
    errors_found: list[str]
    patterns_investigated: list[str]
    distributions_computed: list[str]
    feedback: str
    score_so_far: float
    dense_reward_total: float
    score_breakdown: dict[str, float]
    attempts_remaining: int
    phase: str
```
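A flag action, for example, could be assembled like this. The dict payloads below simply mirror the `AuditAction` fields shown above rather than importing the repo's Pydantic model, and the patient id and reason text are made up:

```python
# Illustrative flag_error payload mirroring the AuditAction fields above.
flag_action = {
    "action_type": "flag_error",
    "variable": "age",
    "patient_id": "P0042",  # hypothetical patient id
    "error_type": "invalid_age",
    "reason": "Age is outside this episode's protocol eligibility bounds.",
    "proposed_value": None,
    "report": None,
}

# A submit_report action mainly carries the report text.
report_action = {
    "action_type": "submit_report",
    "report": "Audit summary: 1 invalid_age finding; no actionable bias.",
}
```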

## Reward Design and Benchmark Score

The environment uses two scoring layers:

- Dense step reward:
  - correct flags,
  - false-positive penalties,
  - duplicate penalties,
  - investigation/distribution bonuses,
  - confidence penalties for overconfident wrong flags,
  - per-step costs.
- Episode benchmark score (`score_so_far`):
  - recall: `70%`
  - precision: `15%`
  - workflow discipline: `5%`
  - efficiency: `5%`
  - report quality: `5%`

This separation keeps the RL signal dense while preventing early score saturation from hiding later mistakes.
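The episode rubric reduces to a clamped weighted sum. A minimal sketch, assuming each component is already normalized to `[0.0, 1.0]` (the component values in the call are made up; the function name is ours):

```python
RUBRIC_WEIGHTS = {
    "recall": 0.70,
    "precision": 0.15,
    "workflow": 0.05,
    "efficiency": 0.05,
    "report_quality": 0.05,
}


def episode_score(components: dict[str, float]) -> float:
    """Weighted rubric score, clamped to [0.0, 1.0]."""
    raw = sum(w * components.get(name, 0.0) for name, w in RUBRIC_WEIGHTS.items())
    return min(1.0, max(0.0, raw))


score = episode_score({
    "recall": 0.9,
    "precision": 0.8,
    "workflow": 1.0,
    "efficiency": 0.5,
    "report_quality": 1.0,
})
```

With recall weighted at 70%, missing true errors hurts far more than a few wasted investigations.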
## Procedural Generation and Reproducibility

Run the generator self-test:

```bash
python3 server/dataset_generator.py
```

What it guarantees:

- same seed -> same dataset, same protocol excerpt, same ground truth,
- different seeds -> different protocols and different datasets,
- deterministic grading compatibility,
- hard mode can alternate between `true_bias` and `confounded_no_bias`.

Example validated seeded profile:

- Easy: `300` patients, `8` record-level errors, `13` traps
- Medium: `480` patients, `23` record-level errors, `25` traps
- Hard: `720` patients, `34` total issues including protocol/timing/bias logic, `40` traps
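The same-seed guarantee boils down to routing all sampling through a single seeded RNG. A toy sketch of that pattern (the generator below is illustrative, not the repo's `dataset_generator`):

```python
import random


def generate_patients(seed: int, n: int = 5) -> list[dict]:
    # All randomness flows through one Random(seed) instance,
    # so the same seed always yields the same records.
    rng = random.Random(seed)
    return [{"patient_id": f"P{i:04d}", "age": rng.randint(18, 95)} for i in range(n)]


a = generate_patients(seed=7)
b = generate_patients(seed=7)  # identical to a
c = generate_patients(seed=8)  # a different dataset
```

Any use of global randomness or wall-clock time inside the generator would break this property, which is why the self-test checks it explicitly.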
## Baseline Inference (`inference.py`)

`inference.py` now demonstrates a clean difficulty gradient:

- `naive`: raw sample-level behavior
- `heuristic`: rule-based but trap-prone
- `full`: protocol parser + stage-aware detectors + structured reporting
- `all`: side-by-side comparison

HTTP mode:

```bash
python3 inference.py --mode all
```

Isolated local validation mode with no socket bind:

```bash
ENV_BASE_URL=inprocess python3 inference.py --mode all
```

LLM integration:

- When `OPENAI_API_KEY` or `HF_TOKEN` is present, naive mode and report generation use the OpenAI-compatible client pointed at `API_BASE_URL`.
- Without a key, the script falls back to deterministic local behavior so validation still runs end-to-end.

Current reproducible local benchmark result:

Command:

```bash
ENV_BASE_URL=inprocess python3 inference.py --mode all --seed 20260402
```

Scores:

| Agent | Easy | Medium | Hard | Average |
|---|---:|---:|---:|---:|
| Naive | 0.36 | 0.08 | 0.09 | 0.18 |
| Heuristic | 0.81 | 0.56 | 0.45 | 0.60 |
| Full | 0.98 | 0.99 | 0.99 | 0.99 |

This is the intended story:

- naive agents underperform badly,
- shallow heuristics get trapped by dynamic protocol edges and confounded bias signals,
- protocol-aware agents perform strongly.
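The key-detection fallback amounts to a simple guard. A minimal sketch using the environment variables documented above and the precedence used in `inference.py` (`HF_TOKEN`, then `OPENAI_API_KEY`, then `API_KEY`); the function name is ours:

```python
import os


def resolve_llm_config() -> dict:
    # Mirrors the key precedence used by inference.py.
    api_key = os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY") or os.getenv("API_KEY")
    return {
        "use_llm": api_key is not None,
        "api_key": api_key,
        "base_url": os.getenv("API_BASE_URL", "https://api.groq.com/openai/v1"),
    }


cfg = resolve_llm_config()
mode = "llm-backed" if cfg["use_llm"] else "deterministic fallback"
```

Keeping the guard in one place is what lets the full validation run complete even on a machine with no credentials configured.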
## Local Usage

### 1) Start the server

```bash
cd server
PYTHONPATH=.. python3 -m uvicorn app:app --host 0.0.0.0 --port 8000
```

### 2) Health check

```bash
curl -s http://localhost:8000/health
```

### 3) Run the baseline

```bash
cd ..
python3 inference.py --mode all
```
## Docker

Build and run:

```bash
cd server
docker build -t clinical-trial-auditor:latest .
docker run -p 8000:8000 clinical-trial-auditor:latest
```

The container exposes `/health` for health checks and is ready for Hugging Face Spaces container deployment.
## Hugging Face Space Readiness Checklist

- [x] OpenEnv interface implemented
- [x] typed models for action/observation/state
- [x] `openenv.yaml` present
- [x] 3 tasks with deterministic graders and scores in `[0.0, 1.0]`
- [x] dense reward shaping and benchmark rubric
- [x] reproducible `inference.py` at repo root
- [x] dockerized server
- [x] `openenv validate .` passes
## Project Structure

```text
clinical_trial_auditor/
...
└── Dockerfile
```

## Motivation

This benchmark is built to test whether an agent can read a changing clinical protocol, audit patient records against that protocol, avoid hallucinated escalations, and write a grounded operational report under a limited action budget.
inference.py
CHANGED
|
@@ -1,458 +1,429 @@
|
|
| 1 |
"""
|
| 2 |
-
Clinical Trial Auditor —
|
| 3 |
-
===========================================
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
1. NAIVE
|
| 7 |
-
2. HEURISTIC —
|
| 8 |
-
3. FULL
|
| 9 |
-
|
| 10 |
-
Usage:
|
| 11 |
-
python inference.py # Full agent (default)
|
| 12 |
-
python inference.py --mode naive # Naive LLM-only agent
|
| 13 |
-
python inference.py --mode heuristic # Simple heuristic agent
|
| 14 |
-
python inference.py --mode full # Full agentic pipeline
|
| 15 |
-
python inference.py --mode all # Run all three, side-by-side
|
| 16 |
-
|
| 17 |
-
Pipeline (full mode):
|
| 18 |
-
1. PROFILE → Schema-aware statistical analysis of dataset
|
| 19 |
-
2. DETECT → Multi-detector anomaly pipeline with confidence scoring
|
| 20 |
-
3. ASSESS → Risk severity + clinical impact evaluation
|
| 21 |
-
4. PLAN → Task-adaptive optimal action sequence
|
| 22 |
-
5. REASON → LLM for ambiguous cases + expert report generation
|
| 23 |
-
6. EXECUTE → Deterministic environment interaction
|
| 24 |
-
7. EVALUATE → Precision/recall/F1 metrics tracking
|
| 25 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
import os
|
|
|
|
|
|
|
| 27 |
import sys
|
| 28 |
import time
|
| 29 |
-
import json
|
| 30 |
-
import math
|
| 31 |
-
import argparse
|
| 32 |
-
import statistics
|
| 33 |
-
from datetime import datetime
|
| 34 |
from collections import Counter
|
|
|
|
|
|
|
|
|
|
| 35 |
from typing import Optional
|
| 36 |
|
|
|
|
|
|
|
| 37 |
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
| 38 |
|
| 39 |
-
from openai import OpenAI
|
| 40 |
from client import ClinicalTrialAuditorEnv
|
| 41 |
from models import AuditAction
|
| 42 |
|
| 43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
API_BASE_URL = os.getenv("API_BASE_URL", "https://api.groq.com/openai/v1")
|
| 45 |
API_KEY = os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY") or os.getenv("API_KEY")
|
| 46 |
MODEL_NAME = os.getenv("MODEL_NAME", "llama-3.3-70b-versatile")
|
| 47 |
ENV_BASE_URL = os.getenv("ENV_BASE_URL", "http://localhost:8000")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
-
|
| 50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
|
|
|
|
|
|
|
| 52 |
|
| 53 |
-
# ═══════════════════════════════════════════════════════════════════════════
|
| 54 |
-
# CORE DATA STRUCTURES
|
| 55 |
-
# ═══════════════════════════════════════════════════════════════════════════
|
| 56 |
|
|
|
|
| 57 |
class Finding:
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
self.reason = reason
|
| 65 |
-
self.confidence = min(1.0, max(0.0, confidence))
|
| 66 |
-
self.risk = risk
|
| 67 |
-
self.value = value
|
| 68 |
-
self.statistical_context = statistical_context
|
| 69 |
|
| 70 |
@property
|
| 71 |
def priority_score(self) -> float:
|
| 72 |
-
|
| 73 |
-
return self.confidence *
|
| 74 |
-
|
| 75 |
-
def explain(self) -> str:
|
| 76 |
-
parts = [f"{self.error_type}: {self.reason}"]
|
| 77 |
-
if self.statistical_context:
|
| 78 |
-
parts.append(f" Evidence: {self.statistical_context}")
|
| 79 |
-
parts.append(f" Confidence: {self.confidence:.0%} | Risk: {self.risk.upper()}")
|
| 80 |
-
return "\n".join(parts)
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
# ═══════════════════════════════════════════════════════════════════════════
|
| 84 |
-
# MODULE 1: DATA PROFILER — Robust statistical summary
|
| 85 |
-
# ═══════════════════════════════════════════════════════════════════════════
|
| 86 |
-
|
| 87 |
-
class DataProfiler:
|
| 88 |
-
"""Schema-aware statistical profiler using robust estimators (IQR, MAD)."""
|
| 89 |
-
|
| 90 |
-
def __init__(self, dataset: list[dict]):
|
| 91 |
-
self.dataset = dataset
|
| 92 |
-
self.n = len(dataset)
|
| 93 |
-
self.columns = sorted({k for row in dataset for k in row.keys()})
|
| 94 |
-
self.types = self._infer_types()
|
| 95 |
-
self.profiles = {}
|
| 96 |
-
|
| 97 |
-
def _infer_types(self) -> dict:
|
| 98 |
-
types = {}
|
| 99 |
-
for col in self.columns:
|
| 100 |
-
vals = [r.get(col) for r in self.dataset if r.get(col) is not None]
|
| 101 |
-
if not vals:
|
| 102 |
-
types[col] = "unknown"
|
| 103 |
-
elif all(isinstance(v, (int, float)) for v in vals):
|
| 104 |
-
types[col] = "numeric"
|
| 105 |
-
elif all(isinstance(v, str) and self._is_date(v) for v in vals[:5]):
|
| 106 |
-
types[col] = "date"
|
| 107 |
-
elif col.lower().endswith("_id") or col.lower() == "id":
|
| 108 |
-
types[col] = "id"
|
| 109 |
-
else:
|
| 110 |
-
types[col] = "categorical"
|
| 111 |
-
return types
|
| 112 |
|
|
|
|
|
|
|
| 113 |
@staticmethod
|
| 114 |
-
def
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
"min": min(values), "max": max(values),
|
| 146 |
-
"q1": q1, "q3": q3, "iqr": iqr,
|
| 147 |
-
"null_count": null_count, "valid_count": n,
|
| 148 |
-
"iqr_lower": q1 - 1.5 * iqr,
|
| 149 |
-
"iqr_upper": q3 + 1.5 * iqr,
|
| 150 |
-
}
|
| 151 |
-
|
| 152 |
-
def profile_categorical(self, col: str) -> dict:
|
| 153 |
-
vals = [str(r.get(col, "None")) for r in self.dataset]
|
| 154 |
-
counter = Counter(vals)
|
| 155 |
-
total = len(vals)
|
| 156 |
-
return {
|
| 157 |
-
"distribution": dict(counter),
|
| 158 |
-
"unique_count": len(counter),
|
| 159 |
-
"mode": counter.most_common(1)[0][0] if counter else None,
|
| 160 |
-
"mode_ratio": counter.most_common(1)[0][1] / total if counter else 0,
|
| 161 |
-
}
|
| 162 |
-
|
| 163 |
-
def profile_all(self) -> dict:
|
| 164 |
-
for col in self.columns:
|
| 165 |
-
if self.types.get(col) == "numeric":
|
| 166 |
-
self.profiles[col] = self.profile_numeric(col)
|
| 167 |
-
elif self.types.get(col) == "categorical":
|
| 168 |
-
self.profiles[col] = self.profile_categorical(col)
|
| 169 |
-
return self.profiles
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
# ═══════════════════════════════════════════════════════════════════════════
|
| 173 |
-
# MODULE 2: ANOMALY DETECTORS — Confidence + Risk scoring
|
| 174 |
-
# ═══════════════════════════════════════════════════════════════════════════
|
| 175 |
-
|
| 176 |
-
class AgeAnomalyDetector:
|
| 177 |
-
"""
|
| 178 |
-
Multi-layer age anomaly detection:
|
| 179 |
-
- Layer 1: Clinical domain constraints (18-120 for trial eligibility)
|
| 180 |
-
- Layer 2: Statistical outliers via IQR
|
| 181 |
-
- Layer 3: Biological plausibility
|
| 182 |
-
"""
|
| 183 |
-
CLINICAL_MIN, CLINICAL_MAX = 18, 120
|
| 184 |
-
|
| 185 |
-
def detect(self, dataset: list[dict], profile: dict) -> list[Finding]:
|
| 186 |
-
findings = []
|
| 187 |
-
age_prof = profile.get("age", {})
|
| 188 |
-
median = age_prof.get("median", 60)
|
| 189 |
-
mad = age_prof.get("mad", 15)
|
| 190 |
|
| 191 |
-
for row in dataset:
|
| 192 |
-
pid = row.get("patient_id", "?")
|
| 193 |
-
age = row.get("age")
|
| 194 |
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
|
|
|
| 203 |
|
| 204 |
-
is_domain_violation = age < self.CLINICAL_MIN or age > self.CLINICAL_MAX
|
| 205 |
-
|
| 206 |
-
if is_domain_violation:
|
| 207 |
-
deviation = abs(age - median) / mad if mad > 0 else 0
|
| 208 |
-
is_biological_impossible = age < 0 or age > 122
|
| 209 |
-
if is_biological_impossible:
|
| 210 |
-
conf, risk = 1.0, "critical"
|
| 211 |
-
context = f"Biologically impossible (age={age})"
|
| 212 |
-
elif age > 200:
|
| 213 |
-
conf, risk = 0.99, "critical"
|
| 214 |
-
context = f"Likely sentinel/data entry error: {deviation:.1f} MAD from median"
|
| 215 |
-
else:
|
| 216 |
-
conf, risk = 0.95, "high"
|
| 217 |
-
context = f"Outside range [{self.CLINICAL_MIN}-{self.CLINICAL_MAX}]"
|
| 218 |
-
|
| 219 |
-
findings.append(Finding(
|
| 220 |
-
patient_id=pid, error_type="invalid_age",
|
| 221 |
-
reason=f"Age {age} violates clinical trial range [{self.CLINICAL_MIN}-{self.CLINICAL_MAX}]",
|
| 222 |
-
confidence=conf, risk=risk, value=age,
|
| 223 |
-
statistical_context=context,
|
| 224 |
-
))
|
| 225 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 226 |
return findings
|
| 227 |
|
| 228 |
|
| 229 |
-
class
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
@staticmethod
|
| 233 |
-
def _parse_date(val) -> Optional[datetime]:
|
| 234 |
-
if not val or val in ("", "N/A", "None", "null"):
|
| 235 |
-
return None
|
| 236 |
-
for fmt in ("%Y-%m-%d", "%m/%d/%Y", "%d-%m-%Y", "%Y/%m/%d"):
|
| 237 |
-
try:
|
| 238 |
-
return datetime.strptime(str(val), fmt)
|
| 239 |
-
except (ValueError, TypeError):
|
| 240 |
-
pass
|
| 241 |
-
return None
|
| 242 |
-
|
| 243 |
-
def detect(self, dataset: list[dict], profile: dict) -> list[Finding]:
|
| 244 |
findings = []
|
| 245 |
for row in dataset:
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
))
|
| 260 |
return findings
|
| 261 |
|
| 262 |
|
| 263 |
-
class
|
| 264 |
-
|
| 265 |
-
REPRESENTATION_THRESHOLD = 0.65
|
| 266 |
-
OUTCOME_DISPARITY_THRESHOLD = 0.20
|
| 267 |
-
|
| 268 |
-
def detect(self, dataset: list[dict], profile: dict) -> list[Finding]:
|
| 269 |
findings = []
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
|
|
|
| 287 |
|
| 288 |
-
rates = list(outcome_rates.values())
|
| 289 |
-
max_disparity = max(rates) - min(rates) if len(rates) > 1 else 0
|
| 290 |
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 294 |
)
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
confidence = 0.0
|
| 303 |
-
|
| 304 |
-
if representation_ratio >= self.REPRESENTATION_THRESHOLD:
|
| 305 |
-
evidence.append(f"Representation: {dominant_name}={representation_ratio:.0%} of control")
|
| 306 |
-
confidence += 0.4
|
| 307 |
-
if minority_deceased > 0:
|
| 308 |
-
evidence.append(f"Outcome disparity: minority mortality={minority_mortality:.0%}")
|
| 309 |
-
confidence += 0.2
|
| 310 |
-
if male_ratio >= 0.5:
|
| 311 |
-
evidence.append(f"Gender imbalance: male={male_ratio:.0%}")
|
| 312 |
-
confidence += 0.1
|
| 313 |
-
if max_disparity > self.OUTCOME_DISPARITY_THRESHOLD:
|
| 314 |
-
evidence.append(f"Statistically significant disparity: Δ={max_disparity:.0%}")
|
| 315 |
-
confidence += 0.15
|
| 316 |
-
|
| 317 |
-
confidence = min(1.0, confidence)
|
| 318 |
-
|
| 319 |
-
if confidence >= 0.6 and representation_ratio >= self.REPRESENTATION_THRESHOLD:
|
| 320 |
-
findings.append(Finding(
|
| 321 |
-
patient_id=None, error_type="selection_bias",
|
| 322 |
-
reason="Multi-dimensional selection bias: " + "; ".join(evidence),
|
| 323 |
-
confidence=confidence, risk="critical",
|
| 324 |
-
value=f"{dominant_name}={representation_ratio:.0%}",
|
| 325 |
-
statistical_context=f"Representation: {dominant_name}={representation_ratio:.0%} | Disparity: Δ={max_disparity:.0%} | Minority mortality: {minority_mortality:.0%}",
|
| 326 |
-
))
|
| 327 |
-
|
| 328 |
-
return findings
|
| 329 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 330 |
|
| 331 |
-
# ═══════════════════════════════════════════════════════════════════════════
|
| 332 |
-
# MODULE 3: ACTION PLANNER
|
| 333 |
-
# ═══════════════════════════════════════════════════════════════════════════
|
| 334 |
|
| 335 |
class ActionPlanner:
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 341 |
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 346 |
|
| 347 |
-
|
| 348 |
-
data_findings = [f for f in findings if f.error_type != "selection_bias"]
|
| 349 |
-
bias_findings = [f for f in findings if f.error_type == "selection_bias"]
|
| 350 |
|
| 351 |
-
data_findings.sort(key=lambda f: -f.priority_score)
|
| 352 |
|
| 353 |
-
|
| 354 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 355 |
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 374 |
|
| 375 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 376 |
|
| 377 |
|
-# ═══════════════════════════════════════════════════════════════════════════
-# MODULE 4: LLM REASONING LAYER
-# ═══════════════════════════════════════════════════════════════════════════
-
-def generate_expert_report(client, findings: list[Finding],
-                           profiles: dict, task_type: str) -> str:
-    """LLM generates expert audit report from pre-analyzed findings."""
-    age_f = [f for f in findings if f.error_type == "invalid_age"]
-    temp_f = [f for f in findings if f.error_type == "temporal_inconsistency"]
-    bias_f = [f for f in findings if f.error_type == "selection_bias"]
-    age_p = profiles.get("age", {})
-
-    sections = [
-        f"AUDIT ANALYSIS — Task: {task_type}",
-        f"Dataset: {age_p.get('valid_count', 0) + age_p.get('null_count', 0)} patients",
-        f"Age: median={age_p.get('median', '?')}, range=[{age_p.get('min', '?')}, {age_p.get('max', '?')}]",
-        "", "ISSUES:",
-    ]
-
-    if age_f:
-        sections.append(f"• {len(age_f)} age anomalies")
-        for f in age_f[:3]:
-            sections.append(f" - {f.patient_id}: age={f.value}")
-    if temp_f:
-        sections.append(f"• {len(temp_f)} temporal violations")
-        for f in temp_f[:3]:
-            sections.append(f" - {f.patient_id}: {f.value}")
-    if bias_f:
-        sections.append(f"• Selection bias: {bias_f[0].statistical_context}")
-
-    try:
-        completion = client.chat.completions.create(
-            model=MODEL_NAME,
-            messages=[
-                {
-                    "role": "system",
-                    "content": (
-                        "You are a Senior Clinical Data Manager writing a formal audit report. "
-                        "Provide: 1) SUMMARY with severity, 2) ROOT CAUSE analysis, "
-                        "3) RISK ASSESSMENT (impact on trial validity), "
-                        "4) RECOMMENDED corrective actions, "
-                        "5) REGULATORY compliance impact. "
-                        "Be concise (max 150 words). Use professional clinical language."
-                    ),
-                },
-                {"role": "user", "content": "\n".join(sections)},
-            ],
-            max_tokens=250,
-            temperature=0,
-        )
-        report = completion.choices[0].message.content or ""
-        if "recommend" not in report.lower():
-            report += "\nRecommend immediate corrective action for all identified issues."
-        return report
-    except Exception as e:
-        # Deterministic fallback
-        severity = "CRITICAL" if bias_f else "HIGH" if temp_f else "MEDIUM"
-        parts = [
-            f"CLINICAL DATA AUDIT REPORT — {task_type.replace('_', ' ').title()}",
-            f"\nSUMMARY: {len(findings)} data quality issues identified.",
-        ]
-        if age_f:
-            parts.append(f"\nAGE ANOMALIES ({len(age_f)}): Root cause: data entry errors or ETL pipeline failures.")
-        if temp_f:
-            parts.append(f"\nTEMPORAL VIOLATIONS ({len(temp_f)}): Root cause: date field mapping errors.")
-        if bias_f:
-            parts.append(f"\nSELECTION BIAS: {bias_f[0].statistical_context}.")
-        parts.append(f"\nRISK LEVEL: {severity}. Recommend immediate corrective action: "
-                     "quarantine affected records, audit data entry workflows, implement validation "
-                     "rules, and rebalance demographic representation in control group. "
-                     "This impacts regulatory compliance with FDA 21 CFR Part 11 and ICH-GCP guidelines.")
-        return "\n".join(parts)
-
-
-# ═══════════════════════════════════════════════════════════════════════════
-# MODULE 5: METRICS TRACKER
-# ═══════════════════════════════════════════════════════════════════════════
-
class MetricsTracker:
    def __init__(self):
        self.true_pos = 0

@@ -461,123 +432,176 @@ class MetricsTracker:
        self.steps = 0
        self.llm_calls = 0

-    def record(self, feedback: str):
        self.total_flagged += 1
-        if "
        self.true_pos += 1
-        elif "
        self.false_pos += 1

    @property
    def precision(self) -> float:
-        return self.true_pos / self.total_flagged if self.total_flagged else 0

    def summary(self) -> str:
        return (
-            f"
-            f"(precision={self.precision:.0%}) | "
-            f"{self.steps} steps | {self.llm_calls} LLM call(s)"
        )
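The tracker's precision bookkeeping can be sketched in isolation. Because the feedback-string matching inside `record()` is truncated in this diff view, the sketch substitutes a boolean argument (an assumption, not the diff's signature); the counters and the `precision` property mirror the fields visible above.

```python
class MetricsTracker:
    # Minimal sketch of the diff's tracker: count correct vs incorrect flags
    # and derive precision. The real record() parses a feedback string.
    def __init__(self):
        self.true_pos = 0
        self.false_pos = 0
        self.total_flagged = 0

    def record(self, correct: bool):
        self.total_flagged += 1
        if correct:
            self.true_pos += 1
        else:
            self.false_pos += 1

    @property
    def precision(self) -> float:
        # Avoid division by zero before any flag was recorded.
        return self.true_pos / self.total_flagged if self.total_flagged else 0


m = MetricsTracker()
for correct in (True, True, False, True):
    m.record(correct)
print(f"{m.precision:.0%}")  # 3 of 4 flags were correct
```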
-def
    metrics = MetricsTracker()
    final_score = 0.0

-    with
-        result = env.reset(task_id=task_id, seed=
    obs = result.observation.model_dump()
    dataset = obs["dataset"]
    max_steps = obs["attempts_remaining"]

-        sample_str = json.dumps(sample, indent=1, default=str)
-    ...
        )
-        llm_response = completion.choices[0].message.content or ""
-        metrics.llm_calls += 1
-    except Exception as e:
-        print(f" LLM Error: {e}")
-        llm_response = ""

-    for
-        lines = llm_response.strip().split("\n")
-        for line in lines:
            if result.done:
                break
-            ...
-                continue
-            # Check if this patient_id exists
-            if not any(p.get("patient_id") == pid for p in dataset):
-                continue
-            result = env.step(AuditAction(
-                action_type="flag_error",
-                patient_id=pid,
-                error_type=etype,
-                reason=parts[2].strip() if len(parts) > 2 else "LLM detected",
-            ))
-            obs = result.observation.model_dump()
-            final_score = obs["score_so_far"]
            metrics.record(obs["feedback"])
-            metrics.steps += 1

-    # Submit report
    if not result.done:
-        result = env.step(
-        ...
    obs = result.observation.model_dump()
    final_score = obs["score_so_far"]
    metrics.steps += 1

@@ -586,96 +610,69 @@ def run_naive_task(client, task_id: str, task_name: str):

    return final_score, metrics

-
|
| 590 |
-
# AGENT MODE 2: HEURISTIC (simple rules, no LLM)
|
| 591 |
-
# ═══════════════════════════════════════════════════════════════════════════
|
| 592 |
-
|
| 593 |
-
def run_heuristic_task(client_unused, task_id: str, task_name: str):
|
| 594 |
-
"""
|
| 595 |
-
Heuristic agent: simple threshold rules, no LLM.
|
| 596 |
-
Catches obvious errors but falls for traps. Expected score: ~0.45-0.60
|
| 597 |
-
"""
|
| 598 |
print(f"\n Task: {task_name}")
|
| 599 |
-
print(" " + "-" *
|
| 600 |
-
|
| 601 |
metrics = MetricsTracker()
|
| 602 |
final_score = 0.0
|
| 603 |
|
| 604 |
-
with
|
| 605 |
-
result = env.reset(task_id=task_id, seed=
|
| 606 |
obs = result.observation.model_dump()
|
| 607 |
dataset = obs["dataset"]
|
| 608 |
-
|
| 609 |
max_steps = obs["attempts_remaining"]
|
| 610 |
-
print(f" Patients: {len(dataset)} | Max steps: {max_steps}")
|
| 611 |
|
| 612 |
-
|
| 613 |
-
for
|
| 614 |
-
|
| 615 |
-
|
| 616 |
-
|
| 617 |
-
metrics.steps += 1
|
| 618 |
|
| 619 |
-
|
| 620 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 621 |
|
| 622 |
-
|
| 623 |
-
|
| 624 |
-
|
|
|
|
|
|
|
|
|
|
| 625 |
break
|
| 626 |
-
|
| 627 |
-
|
| 628 |
-
|
| 629 |
-
|
| 630 |
-
|
| 631 |
-
# flagging valid old patients
|
| 632 |
-
if age is None or age < 18 or age > 100: # Deliberately wrong threshold
|
| 633 |
-
result = env.step(AuditAction(
|
| 634 |
-
action_type="flag_error",
|
| 635 |
-
patient_id=pid,
|
| 636 |
-
error_type="invalid_age",
|
| 637 |
-
reason=f"Age {age} outside expected range",
|
| 638 |
-
))
|
| 639 |
-
obs = result.observation.model_dump()
|
| 640 |
-
final_score = obs["score_so_far"]
|
| 641 |
metrics.record(obs["feedback"])
|
| 642 |
-
|
| 643 |
-
flags_made += 1
|
| 644 |
-
|
| 645 |
-
# Simple temporal check (if applicable)
|
| 646 |
-
if task_type in ("temporal_consistency", "comprehensive_audit"):
|
| 647 |
-
for p in dataset:
|
| 648 |
-
if flags_made >= step_budget or result.done:
|
| 649 |
-
break
|
| 650 |
-
ts = p.get("treatment_start")
|
| 651 |
-
dd = p.get("death_date")
|
| 652 |
-
if ts and dd:
|
| 653 |
-
try:
|
| 654 |
-
t = datetime.strptime(ts, "%Y-%m-%d")
|
| 655 |
-
d = datetime.strptime(dd, "%Y-%m-%d")
|
| 656 |
-
# BUG: heuristic flags ANY death within 7 days (catches traps)
|
| 657 |
-
if d < t or (d - t).days < 7:
|
| 658 |
-
pid = p.get("patient_id")
|
| 659 |
-
result = env.step(AuditAction(
|
| 660 |
-
action_type="flag_error",
|
| 661 |
-
patient_id=pid,
|
| 662 |
-
error_type="temporal_inconsistency",
|
| 663 |
-
reason=f"Suspicious temporal sequence",
|
| 664 |
-
))
|
| 665 |
-
obs = result.observation.model_dump()
|
| 666 |
-
final_score = obs["score_so_far"]
|
| 667 |
-
metrics.record(obs["feedback"])
|
| 668 |
-
metrics.steps += 1
|
| 669 |
-
flags_made += 1
|
| 670 |
-
except ValueError:
|
| 671 |
-
pass
|
| 672 |
-
|
| 673 |
-
# Submit report
|
| 674 |
if not result.done:
|
| 675 |
-
result = env.step(
|
| 676 |
-
|
| 677 |
-
|
| 678 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 679 |
obs = result.observation.model_dump()
|
| 680 |
final_score = obs["score_so_far"]
|
| 681 |
metrics.steps += 1
|
|
@@ -684,205 +681,152 @@ def run_heuristic_task(client_unused, task_id: str, task_name: str):
|
|
| 684 |
return final_score, metrics
|
| 685 |
|
| 686 |
|
| 687 |
-
|
| 688 |
-
# AGENT MODE 3: FULL AGENTIC PIPELINE
|
| 689 |
-
# ═══════════════════════════════════════════════════════════════════════════
|
| 690 |
-
|
| 691 |
-
def run_full_task(client, task_id: str, task_name: str):
|
| 692 |
-
"""
|
| 693 |
-
Full agent: Statistical detection + LLM reasoning.
|
| 694 |
-
Expected score: ~0.85-0.95
|
| 695 |
-
"""
|
| 696 |
print(f"\n Task: {task_name}")
|
| 697 |
-
print(" " + "-" *
|
| 698 |
-
|
| 699 |
metrics = MetricsTracker()
|
| 700 |
final_score = 0.0
|
| 701 |
|
| 702 |
-
with
|
| 703 |
-
result = env.reset(task_id=task_id, seed=
|
| 704 |
obs = result.observation.model_dump()
|
| 705 |
dataset = obs["dataset"]
|
| 706 |
-
|
| 707 |
max_steps = obs["attempts_remaining"]
|
| 708 |
-
print(f"
|
| 709 |
-
|
| 710 |
-
|
| 711 |
-
|
| 712 |
-
|
| 713 |
-
ap = profiles.get("age", {})
|
| 714 |
-
print(f" Profile: age median={ap.get('median','?')}, "
|
| 715 |
-
f"range=[{ap.get('min','?')}-{ap.get('max','?')}], "
|
| 716 |
-
f"nulls={ap.get('null_count',0)}")
|
| 717 |
-
|
| 718 |
-
# 2. DETECT
|
| 719 |
-
all_findings = []
|
| 720 |
-
all_findings.extend(AgeAnomalyDetector().detect(dataset, profiles))
|
| 721 |
-
if task_type in ("temporal_consistency", "comprehensive_audit"):
|
| 722 |
-
all_findings.extend(TemporalConsistencyDetector().detect(dataset, profiles))
|
| 723 |
-
if task_type == "comprehensive_audit":
|
| 724 |
-
all_findings.extend(SelectionBiasDetector().detect(dataset, profiles))
|
| 725 |
-
|
| 726 |
-
age_n = sum(1 for f in all_findings if f.error_type == "invalid_age")
|
| 727 |
-
temp_n = sum(1 for f in all_findings if f.error_type == "temporal_inconsistency")
|
| 728 |
-
bias_n = sum(1 for f in all_findings if f.error_type == "selection_bias")
|
| 729 |
-
print(f" Detected: {age_n} age | {temp_n} temporal | {bias_n} bias")
|
| 730 |
-
|
| 731 |
-
# 3. PLAN
|
| 732 |
-
planner = ActionPlanner()
|
| 733 |
-
actions = planner.plan(all_findings, task_type, max_steps=max_steps)
|
| 734 |
|
| 735 |
-
|
| 736 |
-
|
| 737 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 738 |
|
| 739 |
-
# 5. EXECUTE
|
| 740 |
-
step = 0
|
| 741 |
for action in actions:
|
| 742 |
if result.done:
|
| 743 |
break
|
| 744 |
result = env.step(action)
|
| 745 |
obs = result.observation.model_dump()
|
| 746 |
final_score = obs["score_so_far"]
|
| 747 |
-
|
| 748 |
-
step += 1
|
| 749 |
-
metrics.steps = step
|
| 750 |
if action.action_type == "flag_error":
|
| 751 |
-
metrics.record(feedback)
|
| 752 |
-
|
| 753 |
-
|
| 754 |
-
print(f" Step {step}: score={final_score:.2f} | {feedback[:65]}")
|
| 755 |
|
| 756 |
-
# 6. REPORT
|
| 757 |
if not result.done:
|
| 758 |
-
result = env.step(AuditAction(action_type="submit_report", report=
|
| 759 |
obs = result.observation.model_dump()
|
| 760 |
final_score = obs["score_so_far"]
|
| 761 |
-
|
| 762 |
-
metrics.steps =
|
| 763 |
-
print(f" Step {step}: score={final_score:.2f} | Report submitted")
|
| 764 |
|
| 765 |
print(metrics.summary())
|
| 766 |
return final_score, metrics
|
| 767 |
|
| 768 |
|
-# ORCHESTRATOR
-# ═══════════════════════════════════════════════════════════════════════════
-
-TASK_LIST = {
-    "task_easy": "Syntactic Cleaning (Easy)",
-    "task_medium": "Temporal Consistency (Medium)",
-    "task_hard": "Equity Bias Audit (Hard)",
-}
-
-
-def run_agent(mode: str, client):
-    """Run one agent mode across all tasks."""
    runner = {
        "naive": run_naive_task,
        "heuristic": run_heuristic_task,
        "full": run_full_task,
    }[mode]

-    scores
-    for
-        score,
    scores.append(score)
-    print(f"
-    elapsed = time.time() - t0
-    avg = sum(scores) / len(scores)
-    total_steps = sum(m.steps for m in all_metrics)
-    total_llm = sum(m.llm_calls for m in all_metrics)
-    avg_prec = statistics.mean(m.precision for m in all_metrics) if all_metrics else 0

    return {
        "mode": mode,
        "scores": dict(zip(TASK_LIST.keys(), scores)),
-        "average":
-        "elapsed":
-        "total_steps":
-        "total_llm":
-        "avg_precision":
    }


def main():
-    parser = argparse.ArgumentParser(description="Clinical Trial Auditor
-    parser.add_argument("--mode", choices=["naive", "heuristic", "full", "all"],
    args = parser.parse_args()

-    needs_llm = args.mode in ("naive", "full", "all")
-    if needs_llm:
-        api_key = API_KEY or os.getenv("OPENAI_API_KEY")
-        if not api_key:
-            print("WARNING: No API key found. Set HF_TOKEN, API_KEY, or OPENAI_API_KEY.")
-            print(" Falling back to heuristic mode.")
-            args.mode = "heuristic"
-            client = None
-        else:
-            client = OpenAI(base_url=API_BASE_URL, api_key=api_key)
-    else:
-        client = None

-    print("=" *
-    print(" Clinical Trial Auditor — Baseline Inference")
-    print("
    print(f" Model: {MODEL_NAME}")
-    print(f" Seed:
-    print("=" *

-    if args.mode == "all"
-    else:
-        modes = [args.mode]
-
-    all_results = []
    for mode in modes:
-        print(f"\n{'═' *
        print(f" AGENT: {mode.upper()}")
-        print(f"{'═' *
-        all_results.append(result)

-    print("\n" + "=" * 65)
    print(" BENCHMARK RESULTS")
-    print("=" *
-
-    # Multi-agent comparison table
-    header = f" {'Agent':<15} {'Easy':>8} {'Medium':>8} {'Hard':>8} {'Avg':>8} {'Prec':>8} {'Time':>8}"
    print(header)
-    print(" " + "-" *
-    for
-        scores =
-        print(
-            ...
-            f"{r['elapsed']:.1f}s")
    else:
-        for
-        print(f"
-        print(f"
-        print(f"
-        print(f"
-
-    print("=" * 65)


if __name__ == "__main__":
-    main()

"""
+Clinical Trial Auditor — Baseline Inference
+===========================================
+Demonstrates a deliberate difficulty gradient on the protocol-aware benchmark:
+
+1. NAIVE — raw prompt + small sample, weak structure
+2. HEURISTIC — parses obvious rules but ignores key exceptions
+3. FULL — parses protocol, honors stage exceptions, stage-adjusts bias
"""
+
+from __future__ import annotations
+
+import argparse
+import json
import os
+import re
+import statistics
import sys
import time
from collections import Counter
+from dataclasses import dataclass, field
+from datetime import datetime
+from types import SimpleNamespace
from typing import Optional

+from openai import OpenAI
+
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from client import ClinicalTrialAuditorEnv
from models import AuditAction

+try:
+    from server.clinical_trial_auditor_environment import ClinicalTrialAuditorEnvironment
+except ImportError:
+    ClinicalTrialAuditorEnvironment = None
+
+
API_BASE_URL = os.getenv("API_BASE_URL", "https://api.groq.com/openai/v1")
API_KEY = os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY") or os.getenv("API_KEY")
MODEL_NAME = os.getenv("MODEL_NAME", "llama-3.3-70b-versatile")
ENV_BASE_URL = os.getenv("ENV_BASE_URL", "http://localhost:8000")
+BASELINE_SEED = int(os.getenv("BASELINE_SEED", "20260402"))
+
+TASK_LIST = {
+    "task_easy": "Dynamic Eligibility Screening (Easy)",
+    "task_medium": "Protocol Timeline Audit (Medium)",
+    "task_hard": "Equity + Protocol Audit (Hard)",
+}
+
+TASK_SPECS = {
+    "task_easy": {
+        "investigations": ["age"],
+        "distributions": [],
+    },
+    "task_medium": {
+        "investigations": ["age", "death_date", "enrollment_date", "stage"],
+        "distributions": [],
+    },
+    "task_hard": {
+        "investigations": ["age", "death_date", "enrollment_date", "stage"],
+        "distributions": ["ethnicity", "gender", "outcome"],
+    },
+}
+
+@dataclass
+class ProtocolRules:
+    protocol_title: str
+    age_min: int
+    age_max: int
+    treatment_window_days: int
+    stage_iv_window_days: int
+    high_risk_sites: list[str] = field(default_factory=list)
+    bias_control_dominance_threshold: float = 1.0
+    bias_male_threshold: float = 1.0
+    bias_stage_adjusted_gap: float = 1.0

+    def allowed_window(self, stage: str) -> int:
+        return self.stage_iv_window_days if stage == "IV" else self.treatment_window_days


+@dataclass
class Finding:
+    error_type: str
+    reason: str
+    patient_id: Optional[str] = None
+    confidence: float = 1.0
+    risk: str = "medium"
+    evidence: str = ""

    @property
    def priority_score(self) -> float:
+        risk_weight = {"critical": 1.0, "high": 0.8, "medium": 0.5, "low": 0.2}
+        return self.confidence * risk_weight.get(self.risk, 0.5)
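The sorting the planner applies later follows directly from `priority_score`. This standalone sketch uses a trimmed copy of the `Finding` fields shown in the diff (the three sample records are invented for illustration) to show how the risk tier reweights raw confidence into a sortable priority:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Finding:
    # Trimmed mirror of the diff's Finding: only the fields priority_score uses.
    error_type: str
    reason: str
    patient_id: Optional[str] = None
    confidence: float = 1.0
    risk: str = "medium"

    @property
    def priority_score(self) -> float:
        # Risk tier scales raw confidence into a single sortable number.
        risk_weight = {"critical": 1.0, "high": 0.8, "medium": 0.5, "low": 0.2}
        return self.confidence * risk_weight.get(self.risk, 0.5)


findings = [
    Finding("invalid_age", "age out of range", "P1", confidence=0.94, risk="high"),
    Finding("temporal_inconsistency", "death before treatment", "P2", confidence=0.95, risk="critical"),
    Finding("invalid_age", "borderline value", "P3", confidence=0.90, risk="medium"),
]
findings.sort(key=lambda f: -f.priority_score)
print([f.patient_id for f in findings])  # critical outranks high-confidence medium
```

Note that a 0.90-confidence medium-risk finding (0.45) sorts below a 0.94-confidence high-risk one (0.752): the weighting deliberately favors severity over raw confidence.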
+
+class ProtocolParser:
    @staticmethod
+    def parse(excerpt: str) -> ProtocolRules:
+        title_match = re.search(r"TRIAL PROTOCOL EXCERPT\s+[—-]\s+([A-Z0-9-]+)", excerpt)
+        age_match = re.search(r"age (\d+)-(\d+) inclusive", excerpt)
+        window_match = re.search(r"Treatment must begin within (\d+) days", excerpt)
+        stage_match = re.search(r"Stage IV exception: treatment may begin within (\d+) days", excerpt)
+        sites_match = re.search(
+            r"Stage IV patients at (.+?) are a known high-risk outreach cohort",
+            excerpt,
+        )
+        bias_match = re.search(
+            r"dominance exceeds (\d+)%, male share exceeds (\d+)%, "
+            r"and stage-adjusted mortality gap exceeds (\d+) percentage points",
+            excerpt,
+        )

+        high_risk_sites = []
+        if sites_match:
+            high_risk_sites = [site.strip() for site in sites_match.group(1).split(",")]
+
+        bias_values = (100, 100, 100)
+        if bias_match:
+            bias_values = tuple(int(value) for value in bias_match.groups())
+
+        if not age_match or not window_match or not stage_match:
+            raise ValueError("Unable to parse protocol excerpt.")
+
+        return ProtocolRules(
+            protocol_title=(title_match.group(1) if title_match else "UNKNOWN"),
+            age_min=int(age_match.group(1)),
+            age_max=int(age_match.group(2)),
+            treatment_window_days=int(window_match.group(1)),
+            stage_iv_window_days=int(stage_match.group(1)),
+            high_risk_sites=high_risk_sites,
+            bias_control_dominance_threshold=bias_values[0] / 100.0,
+            bias_male_threshold=bias_values[1] / 100.0,
+            bias_stage_adjusted_gap=bias_values[2] / 100.0,
+        )
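The parser's regexes double as a specification of the protocol excerpt format. This sketch runs three of them against a synthetic excerpt written to match (the excerpt text and its numbers are invented; only the patterns come from the diff):

```python
import re

# Synthetic excerpt crafted to satisfy the parser's patterns (not real trial text).
excerpt = (
    "Eligibility: age 18-75 inclusive.\n"
    "Treatment must begin within 14 days of enrollment.\n"
    "Stage IV exception: treatment may begin within 21 days.\n"
)

age_match = re.search(r"age (\d+)-(\d+) inclusive", excerpt)
window_match = re.search(r"Treatment must begin within (\d+) days", excerpt)
stage_match = re.search(r"Stage IV exception: treatment may begin within (\d+) days", excerpt)

print(age_match.groups())     # age bounds as strings
print(window_match.group(1))  # standard window in days
print(stage_match.group(1))   # extended stage IV window in days
```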
+
+def parse_date(value: Optional[str]) -> Optional[datetime]:
+    if not value:
+        return None
+    for fmt in ("%Y-%m-%d", "%Y/%m/%d", "%m/%d/%Y", "%d-%m-%Y"):
+        try:
+            return datetime.strptime(str(value), fmt)
+        except ValueError:
+            continue
+    return None
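Stripped of its surroundings, the multi-format date fallback introduced above can be exercised in isolation. This sketch repeats the same cascade (the format strings are taken from the diff) to show which inputs resolve and which fall through to `None`:

```python
from datetime import datetime
from typing import Optional


def parse_date(value: Optional[str]) -> Optional[datetime]:
    # Try each accepted layout in order; the first successful parse wins.
    if not value:
        return None
    for fmt in ("%Y-%m-%d", "%Y/%m/%d", "%m/%d/%Y", "%d-%m-%Y"):
        try:
            return datetime.strptime(str(value), fmt)
        except ValueError:
            continue
    return None


print(parse_date("2025-03-01"))  # ISO layout matches the first format
print(parse_date("03/01/2025"))  # US layout matches the third format
print(parse_date("not-a-date"))  # unparseable input yields None
```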
+
+class AgeDetector:
+    def detect(self, dataset: list[dict], rules: ProtocolRules) -> list[Finding]:
+        findings = []
+        for row in dataset:
+            age = row.get("age")
+            if age is None or age < rules.age_min or age > rules.age_max:
+                findings.append(
+                    Finding(
+                        patient_id=row.get("patient_id"),
+                        error_type="invalid_age",
+                        reason=f"Age {age} violates protocol range {rules.age_min}-{rules.age_max}",
+                        confidence=0.98 if age is None or age < 0 or age > (rules.age_max + 10) else 0.94,
+                        risk="high",
+                    )
+                )
        return findings


+class TemporalDetector:
+    def detect(self, dataset: list[dict]) -> list[Finding]:
        findings = []
        for row in dataset:
+            treatment = parse_date(row.get("treatment_start"))
+            death = parse_date(row.get("death_date"))
+            if treatment and death and death < treatment:
+                gap = (treatment - death).days
+                findings.append(
+                    Finding(
+                        patient_id=row.get("patient_id"),
+                        error_type="temporal_inconsistency",
+                        reason=f"death_date precedes treatment_start by {gap} days",
+                        confidence=min(1.0, 0.92 + gap / 500.0),
+                        risk="critical" if gap > 120 else "high",
+                    )
+                )
        return findings


+class ProtocolWindowDetector:
+    def detect(self, dataset: list[dict], rules: ProtocolRules, ignore_stage_exception: bool = False) -> list[Finding]:
        findings = []
+        for row in dataset:
+            enrollment = parse_date(row.get("enrollment_date"))
+            treatment = parse_date(row.get("treatment_start"))
+            if not enrollment or not treatment:
+                continue
+            allowed_days = rules.treatment_window_days if ignore_stage_exception else rules.allowed_window(row.get("stage", ""))
+            delay = (treatment - enrollment).days
+            if delay > allowed_days:
+                findings.append(
+                    Finding(
+                        patient_id=row.get("patient_id"),
+                        error_type="protocol_window_violation",
+                        reason=f"treatment started after {delay} days (allowed {allowed_days})",
+                        confidence=0.93 if delay > allowed_days + 3 else 0.82,
+                        risk="high",
+                    )
+                )
+        return findings
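The `ignore_stage_exception` switch exists because the heuristic agent applies the flat window to everyone. This minimal sketch isolates the `allowed_window` branch it bypasses; the field names mirror the diff, but the window lengths here are illustrative, not taken from any real protocol:

```python
from dataclasses import dataclass


@dataclass
class ProtocolRules:
    # Trimmed to the two fields allowed_window consults.
    treatment_window_days: int
    stage_iv_window_days: int

    def allowed_window(self, stage: str) -> int:
        # Stage IV patients get the extended window; everyone else the standard one.
        return self.stage_iv_window_days if stage == "IV" else self.treatment_window_days


rules = ProtocolRules(treatment_window_days=14, stage_iv_window_days=21)  # illustrative values
print(rules.allowed_window("II"))  # standard window applies
print(rules.allowed_window("IV"))  # extended window applies
```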

+class BiasAnalyzer:
+    @staticmethod
+    def summarize_control(dataset: list[dict]) -> tuple[list[dict], str, float, float, float]:
+        control = [row for row in dataset if row.get("group") == "control"]
+        if not control:
+            return [], "Unknown", 0.0, 0.0, 0.0
+
+        counts = Counter(row.get("ethnicity", "Unknown") for row in control)
+        dominant_ethnicity, dominant_count = counts.most_common(1)[0]
+        dominant_ratio = dominant_count / len(control)
+        male_ratio = sum(row.get("gender") == "M" for row in control) / len(control)
+
+        dominant_group = [row for row in control if row.get("ethnicity") == dominant_ethnicity]
+        minority_group = [row for row in control if row.get("ethnicity") != dominant_ethnicity]
+        dom_mortality = (
+            sum(row.get("outcome") == "deceased" for row in dominant_group) / len(dominant_group)
+            if dominant_group
+            else 0.0
        )
+        min_mortality = (
+            sum(row.get("outcome") == "deceased" for row in minority_group) / len(minority_group)
+            if minority_group
+            else 0.0
+        )
+        overall_gap = min_mortality - dom_mortality
+        return control, dominant_ethnicity, dominant_ratio, male_ratio, overall_gap

+    @staticmethod
+    def stage_adjusted_gap(control: list[dict], dominant_ethnicity: str) -> float:
+        weighted_gap = 0.0
+        total_weight = 0
+        for stage in ("I", "II", "III", "IV"):
+            stage_rows = [row for row in control if row.get("stage") == stage]
+            dominant_rows = [row for row in stage_rows if row.get("ethnicity") == dominant_ethnicity]
+            minority_rows = [row for row in stage_rows if row.get("ethnicity") != dominant_ethnicity]
+            if len(dominant_rows) < 5 or len(minority_rows) < 5:
+                continue
+            dominant_mortality = sum(row.get("outcome") == "deceased" for row in dominant_rows) / len(dominant_rows)
+            minority_mortality = sum(row.get("outcome") == "deceased" for row in minority_rows) / len(minority_rows)
+            weight = len(stage_rows)
+            weighted_gap += (minority_mortality - dominant_mortality) * weight
+            total_weight += weight
+        return weighted_gap / total_weight if total_weight else 0.0
+
+    def detect_full(self, dataset: list[dict], rules: ProtocolRules) -> list[Finding]:
+        control, dominant_ethnicity, dominant_ratio, male_ratio, overall_gap = self.summarize_control(dataset)
+        if not control:
+            return []
+        adjusted_gap = self.stage_adjusted_gap(control, dominant_ethnicity)
+        if (
+            dominant_ratio >= rules.bias_control_dominance_threshold
+            and male_ratio >= rules.bias_male_threshold
+            and adjusted_gap >= rules.bias_stage_adjusted_gap
+        ):
+            return [
+                Finding(
+                    patient_id=None,
+                    error_type="selection_bias",
+                    reason=(
+                        f"Control-arm skew detected: {dominant_ethnicity}={dominant_ratio:.0%}, "
+                        f"male={male_ratio:.0%}, stage-adjusted mortality gap={adjusted_gap:.0%}"
+                    ),
+                    confidence=0.92,
+                    risk="critical",
+                    evidence=f"overall gap={overall_gap:.0%}",
+                )
+            ]
+        return []
+
+    def detect_heuristic(self, dataset: list[dict], rules: ProtocolRules) -> list[Finding]:
+        control, dominant_ethnicity, dominant_ratio, male_ratio, overall_gap = self.summarize_control(dataset)
+        if not control:
+            return []
+        loose_threshold = max(0.10, rules.bias_stage_adjusted_gap - 0.04)
+        if dominant_ratio >= max(0.55, rules.bias_control_dominance_threshold - 0.07) and overall_gap >= loose_threshold:
+            return [
+                Finding(
+                    patient_id=None,
+                    error_type="selection_bias",
+                    reason=(
+                        f"Heuristic bias concern: {dominant_ethnicity}={dominant_ratio:.0%}, "
+                        f"male={male_ratio:.0%}, overall mortality gap={overall_gap:.0%}"
+                    ),
+                    confidence=0.74,
+                    risk="high",
+                )
+            ]
+        return []
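The stage-adjusted comparison in `BiasAnalyzer` guards against Simpson's-paradox false alarms: when minority patients cluster in late stages, the overall mortality gap looks large even though within every stage the groups fare identically. This self-contained toy control arm (all counts invented) reproduces the diff's weighting logic and shows the overall gap vanishing after stratification:

```python
# Toy control arm: minority ("B") patients are concentrated in stage IV,
# which inflates the overall mortality gap relative to the stage-adjusted gap.
control = (
    [{"stage": "I", "ethnicity": "A", "outcome": "alive"}] * 18
    + [{"stage": "I", "ethnicity": "A", "outcome": "deceased"}] * 2
    + [{"stage": "I", "ethnicity": "B", "outcome": "alive"}] * 9
    + [{"stage": "I", "ethnicity": "B", "outcome": "deceased"}] * 1
    + [{"stage": "IV", "ethnicity": "A", "outcome": "deceased"}] * 3
    + [{"stage": "IV", "ethnicity": "A", "outcome": "alive"}] * 2
    + [{"stage": "IV", "ethnicity": "B", "outcome": "deceased"}] * 12
    + [{"stage": "IV", "ethnicity": "B", "outcome": "alive"}] * 8
)


def mortality(rows):
    return sum(r["outcome"] == "deceased" for r in rows) / len(rows) if rows else 0.0


dominant = [r for r in control if r["ethnicity"] == "A"]
minority = [r for r in control if r["ethnicity"] == "B"]
overall_gap = mortality(minority) - mortality(dominant)

# Stage-weighted gap, mirroring stage_adjusted_gap's minimum-count guard.
weighted_gap, total_weight = 0.0, 0
for stage in ("I", "II", "III", "IV"):
    rows = [r for r in control if r["stage"] == stage]
    dom = [r for r in rows if r["ethnicity"] == "A"]
    mino = [r for r in rows if r["ethnicity"] == "B"]
    if len(dom) < 5 or len(mino) < 5:
        continue
    weighted_gap += (mortality(mino) - mortality(dom)) * len(rows)
    total_weight += len(rows)
adjusted_gap = weighted_gap / total_weight if total_weight else 0.0

print(f"overall={overall_gap:.0%}, stage-adjusted={adjusted_gap:.0%}")
```

Within each stage the mortality rates match exactly, so the confounded overall gap (about 23 points here) collapses to zero once stage is held fixed.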
| 300 |
|
| 301 |
class ActionPlanner:
|
| 302 |
+
def plan(
|
| 303 |
+
self,
|
| 304 |
+
task_id: str,
|
| 305 |
+
findings: list[Finding],
|
| 306 |
+
max_steps: int,
|
| 307 |
+
extra_investigations: Optional[list[str]] = None,
|
| 308 |
+
) -> list[AuditAction]:
|
| 309 |
+
spec = TASK_SPECS[task_id]
|
| 310 |
+
actions: list[AuditAction] = []
|
| 311 |
+
|
| 312 |
+
investigations = list(spec["investigations"])
|
| 313 |
+
distributions = list(spec["distributions"])
|
| 314 |
+
if extra_investigations:
|
| 315 |
+
investigations.extend(extra_investigations)
|
| 316 |
+
|
| 317 |
+
for variable in investigations:
|
| 318 |
+
actions.append(AuditAction(action_type="investigate_pattern", variable=variable))
|
| 319 |
+
for variable in distributions:
|
| 320 |
+
actions.append(AuditAction(action_type="compute_distribution", variable=variable))
|
| 321 |
+
|
| 322 |
+
record_findings = [finding for finding in findings if finding.error_type != "selection_bias"]
|
| 323 |
+
bias_findings = [finding for finding in findings if finding.error_type == "selection_bias"]
|
| 324 |
+
record_findings.sort(key=lambda finding: -finding.priority_score)
|
| 325 |
+
|
| 326 |
+
max_flag_slots = max_steps - len(actions) - 1 - (1 if bias_findings else 0)
|
| 327 |
+
flagged_ids = set()
|
| 328 |
+
for finding in record_findings[:max_flag_slots]:
|
| 329 |
+
if not finding.patient_id or finding.patient_id in flagged_ids:
|
| 330 |
+
continue
|
| 331 |
+
flagged_ids.add(finding.patient_id)
|
| 332 |
+
actions.append(
|
| 333 |
+
AuditAction(
|
| 334 |
+
action_type="flag_error",
|
| 335 |
+
patient_id=finding.patient_id,
|
| 336 |
+
error_type=finding.error_type,
|
| 337 |
+
reason=finding.reason,
|
| 338 |
+
confidence=finding.confidence,
|
| 339 |
+
)
|
| 340 |
+
)
|
| 341 |
|
| 342 |
+
if bias_findings:
|
| 343 |
+
bias = bias_findings[0]
|
| 344 |
+
actions.append(
|
| 345 |
+
AuditAction(
|
| 346 |
+
action_type="flag_error",
|
| 347 |
+
error_type="selection_bias",
|
| 348 |
+
reason=bias.reason,
|
| 349 |
+
confidence=bias.confidence,
|
| 350 |
+
)
|
| 351 |
+
)
|
| 352 |
|
| 353 |
+
return actions
|
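The step-budget arithmetic in the planner is easy to misread, so here is a standalone sketch (the function name `flag_slots` is mine, not from the commit): one step is reserved for the final `submit_report`, and one more for a dataset-level bias flag when a bias finding exists, before the remainder is allotted to per-patient flags.

```python
def flag_slots(max_steps: int, n_investigations: int, n_distributions: int, has_bias: bool) -> int:
    # Reserved steps: every investigation and distribution action, one
    # submit_report, and (optionally) one dataset-level selection_bias flag.
    reserved = n_investigations + n_distributions + 1 + (1 if has_bias else 0)
    return max_steps - reserved
```

With a 15-step budget, 4 investigations, 2 distributions, and a bias finding, 7 steps remain for per-patient `flag_error` actions.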


def generate_expert_report(
    client: Optional[OpenAI],
    rules: ProtocolRules,
    findings: list[Finding],
    task_name: str,
) -> str:
    finding_lines = []
    for finding in findings[:8]:
        if finding.patient_id:
            finding_lines.append(f"- {finding.patient_id}: {finding.error_type} | {finding.reason}")
        else:
            finding_lines.append(f"- {finding.error_type}: {finding.reason}")

    prompt = "\n".join(
        [
            f"Protocol: {rules.protocol_title}",
            f"Task: {task_name}",
            f"Key rules: age {rules.age_min}-{rules.age_max}, "
            f"standard start <= {rules.treatment_window_days} days, "
            f"stage IV <= {rules.stage_iv_window_days} days.",
            "",
            "Findings:",
            *finding_lines,
            "",
            "Write a concise audit report with protocol grounding, root cause, risk, corrective actions, "
            "and fairness reasoning when relevant.",
        ]
    )

    if client is not None:
        try:
            completion = client.chat.completions.create(
                model=MODEL_NAME,
                messages=[
                    {
                        "role": "system",
                        "content": (
                            "You are a senior clinical data manager. Produce a concise report "
                            "with protocol grounding, root cause, risk, corrective action, and "
                            "fairness reasoning when applicable."
                        ),
                    },
                    {"role": "user", "content": prompt},
                ],
                temperature=0,
                max_tokens=240,
            )
            content = completion.choices[0].message.content or ""
            if content:
                return content
        except Exception:
            pass

    if any(finding.error_type == "selection_bias" for finding in findings):
        fairness_line = (
            "Fairness review: control-arm patterns were reviewed with stage-adjusted comparisons "
            "before escalating the bias conclusion."
        )
    else:
        fairness_line = (
            "Fairness review: no actionable control-arm bias was confirmed after stage-adjusted review."
        )

    return (
        f"Protocol-grounded audit for {rules.protocol_title}. Root cause analysis indicates site-level "
        f"data capture and scheduling control weaknesses. Risk assessment: protocol compliance and endpoint "
        f"validity are affected. Recommended corrective actions include quarantining impacted records, "
        f"tightening enrollment-to-treatment validations, and retraining site coordinators. {fairness_line}"
    )
class MetricsTracker:
    def __init__(self):
        self.true_pos = 0
        self.false_pos = 0
        self.total_flagged = 0
        self.steps = 0
        self.llm_calls = 0

    def record(self, feedback: str) -> None:
        self.total_flagged += 1
        if "✓" in feedback or "Correct" in feedback:
            self.true_pos += 1
        elif "✗" in feedback or "REJECTED" in feedback:
            self.false_pos += 1

    @property
    def precision(self) -> float:
        return self.true_pos / self.total_flagged if self.total_flagged else 0.0

    def summary(self) -> str:
        return (
            f" Metrics: {self.true_pos}/{self.total_flagged} correct "
            f"(precision={self.precision:.0%}) | {self.steps} steps | {self.llm_calls} LLM call(s)"
        )
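A quick sketch of the feedback-parsing convention the tracker relies on, re-stated as a stripped-down class so the snippet is self-contained (the class name `MiniTracker` is mine; the real tracker also counts steps and LLM calls): "✓"/"Correct" in the environment's feedback string counts as a true positive, "✗"/"REJECTED" as a false positive, and anything else is left uncounted.

```python
class MiniTracker:
    def __init__(self):
        self.true_pos = 0
        self.false_pos = 0
        self.total_flagged = 0

    def record(self, feedback: str) -> None:
        # Every recorded flag increments the denominator, even if the
        # feedback string matches neither marker.
        self.total_flagged += 1
        if "✓" in feedback or "Correct" in feedback:
            self.true_pos += 1
        elif "✗" in feedback or "REJECTED" in feedback:
            self.false_pos += 1

    @property
    def precision(self) -> float:
        return self.true_pos / self.total_flagged if self.total_flagged else 0.0
```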


class InProcessEnvSession:
    def __init__(self):
        if ClinicalTrialAuditorEnvironment is None:
            raise RuntimeError("In-process environment is unavailable.")
        self._env = ClinicalTrialAuditorEnvironment()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        return False

    def reset(self, **kwargs):
        observation = self._env.reset(**kwargs)
        return SimpleNamespace(observation=observation, reward=observation.reward, done=observation.done)

    def step(self, action: AuditAction):
        observation = self._env.step(action)
        return SimpleNamespace(observation=observation, reward=observation.reward, done=observation.done)


def open_env_session():
    if ENV_BASE_URL.lower() == "inprocess":
        return InProcessEnvSession()
    return ClinicalTrialAuditorEnv(base_url=ENV_BASE_URL).sync()


def run_naive_task(client: Optional[OpenAI], task_id: str, task_name: str, seed: int):
    print(f"\n Task: {task_name}")
    print(" " + "-" * 54)
    metrics = MetricsTracker()
    final_score = 0.0

    with open_env_session() as env:
        result = env.reset(task_id=task_id, seed=seed)
        obs = result.observation.model_dump()
        dataset = obs["dataset"]
        protocol_excerpt = obs["trial_protocol_excerpt"]
        max_steps = obs["attempts_remaining"]
        rules = ProtocolParser.parse(protocol_excerpt)
        print(f" Protocol: {rules.protocol_title} | Patients: {len(dataset)} | Max steps: {max_steps}")

        sample = dataset[:24]
        guessed_findings: list[Finding] = []

        if client is not None:
            try:
                completion = client.chat.completions.create(
                    model=MODEL_NAME,
                    messages=[
                        {
                            "role": "system",
                            "content": (
                                "You are auditing patient records from a clinical trial. "
                                "Return one issue per line as PATIENT_ID|ERROR_TYPE|REASON."
                            ),
                        },
                        {
                            "role": "user",
                            "content": (
                                f"Protocol excerpt:\n{protocol_excerpt}\n\n"
                                f"Review only these {len(sample)} records:\n{json.dumps(sample, default=str)}\n\n"
                                "Allowed ERROR_TYPE values: invalid_age, temporal_inconsistency, "
                                "protocol_window_violation, selection_bias."
                            ),
                        },
                    ],
                    temperature=0,
                    max_tokens=450,
                )
                metrics.llm_calls += 1
                lines = (completion.choices[0].message.content or "").splitlines()
                for line in lines:
                    parts = [part.strip() for part in line.split("|")]
                    if len(parts) >= 2:
                        guessed_findings.append(
                            Finding(
                                patient_id=parts[0] if parts[0] and parts[0] != "None" else None,
                                error_type=parts[1],
                                reason=parts[2] if len(parts) > 2 else "LLM guess",
                                confidence=0.65,
                            )
                        )
            except Exception as exc:
                print(f" LLM error: {exc}")

        if not guessed_findings:
            for row in sample:
                age = row.get("age")
                treatment = parse_date(row.get("treatment_start"))
                death = parse_date(row.get("death_date"))
                enrollment = parse_date(row.get("enrollment_date"))
                if age is None or age < 0 or age > 120:
                    guessed_findings.append(
                        Finding(
                            patient_id=row.get("patient_id"),
                            error_type="invalid_age",
                            reason="Sample-level obvious age anomaly",
                            confidence=0.55,
                        )
                    )
                if treatment and death and death < treatment:
                    guessed_findings.append(
                        Finding(
                            patient_id=row.get("patient_id"),
                            error_type="temporal_inconsistency",
                            reason="Sample-level temporal anomaly",
                            confidence=0.60,
                        )
                    )
        plan_actions = []
        for variable in TASK_SPECS[task_id]["investigations"]:
            plan_actions.append(AuditAction(action_type="investigate_pattern", variable=variable))
        if task_id == "task_hard":
            plan_actions.extend(
                AuditAction(action_type="compute_distribution", variable=variable)
                for variable in TASK_SPECS[task_id]["distributions"]
            )

        max_flag_slots = max_steps - len(plan_actions) - 1
        for finding in guessed_findings[:max_flag_slots]:
            plan_actions.append(
                AuditAction(
                    action_type="flag_error",
                    patient_id=finding.patient_id,
                    error_type=finding.error_type,
                    reason=finding.reason,
                    confidence=finding.confidence,
                )
            )

        for action in plan_actions:
            if result.done:
                break
            result = env.step(action)
            obs = result.observation.model_dump()
            final_score = obs["score_so_far"]
            metrics.steps += 1
            if action.action_type == "flag_error":
                metrics.record(obs["feedback"])

        if not result.done:
            result = env.step(
                AuditAction(
                    action_type="submit_report",
                    report=(
                        f"Protocol grounding for {rules.protocol_title}. "
                        "Sample review found possible age and timing issues. "
                        "Recommend manual review and corrective action."
                    ),
                )
            )
            obs = result.observation.model_dump()
            final_score = obs["score_so_far"]
            metrics.steps += 1

    print(metrics.summary())
    return final_score, metrics
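The naive agent's deterministic fallback reduces to two sample-level screens: hard age bounds of 0-120 and death recorded before treatment start. A self-contained sketch of that check (using `datetime.date` values directly instead of the script's `parse_date` helper; the function name `obvious_anomalies` is mine):

```python
from datetime import date

def obvious_anomalies(row: dict) -> list[str]:
    # Mirrors the fallback screen in run_naive_task: the hard bounds here are
    # looser than any protocol's eligibility range, so only absurd values trip.
    issues = []
    age = row.get("age")
    if age is None or age < 0 or age > 120:
        issues.append("invalid_age")
    treatment, death = row.get("treatment_start"), row.get("death_date")
    if treatment and death and death < treatment:
        issues.append("temporal_inconsistency")
    return issues
```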


def run_heuristic_task(client_unused: Optional[OpenAI], task_id: str, task_name: str, seed: int):
    print(f"\n Task: {task_name}")
    print(" " + "-" * 54)
    metrics = MetricsTracker()
    final_score = 0.0

    with open_env_session() as env:
        result = env.reset(task_id=task_id, seed=seed)
        obs = result.observation.model_dump()
        dataset = obs["dataset"]
        rules = ProtocolParser.parse(obs["trial_protocol_excerpt"])
        max_steps = obs["attempts_remaining"]
        print(f" Protocol: {rules.protocol_title} | Patients: {len(dataset)} | Max steps: {max_steps}")

        actions: list[AuditAction] = []
        for variable in TASK_SPECS[task_id]["investigations"]:
            actions.append(AuditAction(action_type="investigate_pattern", variable=variable))
        for variable in TASK_SPECS[task_id]["distributions"]:
            actions.append(AuditAction(action_type="compute_distribution", variable=variable))

        findings = []
        for row in dataset:
            age = row.get("age")
            if age is None or age < (rules.age_min - 3) or age > (rules.age_max + 3):
                findings.append(
                    Finding(
                        patient_id=row.get("patient_id"),
                        error_type="invalid_age",
                        reason=f"Heuristic age screen triggered on {age}",
                        confidence=0.82,
                        risk="high",
                    )
                )
        findings.extend(TemporalDetector().detect(dataset))
        if task_id in {"task_medium", "task_hard"}:
            findings.extend(ProtocolWindowDetector().detect(dataset, rules, ignore_stage_exception=True))
        if task_id == "task_hard":
            findings.extend(BiasAnalyzer().detect_heuristic(dataset, rules))

        planner = ActionPlanner()
        planned_flags = planner.plan(task_id, findings, max_steps=max_steps)
        actions = planned_flags

        for action in actions:
            if result.done:
                break
            result = env.step(action)
            obs = result.observation.model_dump()
            final_score = obs["score_so_far"]
            metrics.steps += 1
            if action.action_type == "flag_error":
                metrics.record(obs["feedback"])

        if not result.done:
            result = env.step(
                AuditAction(
                    action_type="submit_report",
                    report=(
                        f"Protocol review for {rules.protocol_title}. Root cause is likely data-entry drift. "
                        "Recommend validation checks and operational follow-up. Risk is moderate to high."
                    ),
                )
            )
            obs = result.observation.model_dump()
            final_score = obs["score_so_far"]
            metrics.steps += 1

    print(metrics.summary())
    return final_score, metrics


def run_full_task(client: Optional[OpenAI], task_id: str, task_name: str, seed: int):
    print(f"\n Task: {task_name}")
    print(" " + "-" * 54)
    metrics = MetricsTracker()
    final_score = 0.0

    with open_env_session() as env:
        result = env.reset(task_id=task_id, seed=seed)
        obs = result.observation.model_dump()
        dataset = obs["dataset"]
        rules = ProtocolParser.parse(obs["trial_protocol_excerpt"])
        max_steps = obs["attempts_remaining"]
        print(f" Protocol: {rules.protocol_title} | Patients: {len(dataset)} | Max steps: {max_steps}")
        print(
            f" Rules: age {rules.age_min}-{rules.age_max} | standard <= {rules.treatment_window_days}d | "
            f"stage IV <= {rules.stage_iv_window_days}d"
        )

        findings = []
        findings.extend(AgeDetector().detect(dataset, rules))
        findings.extend(TemporalDetector().detect(dataset))
        if task_id in {"task_medium", "task_hard"}:
            findings.extend(ProtocolWindowDetector().detect(dataset, rules, ignore_stage_exception=False))
        if task_id == "task_hard":
            findings.extend(BiasAnalyzer().detect_full(dataset, rules))

        age_count = sum(f.error_type == "invalid_age" for f in findings)
        temporal_count = sum(f.error_type == "temporal_inconsistency" for f in findings)
        window_count = sum(f.error_type == "protocol_window_violation" for f in findings)
        bias_count = sum(f.error_type == "selection_bias" for f in findings)
        print(
            f" Detected: age={age_count} | temporal={temporal_count} | "
            f"window={window_count} | bias={bias_count}"
        )

        extra_checks = {
            "task_easy": ["enrollment_date", "stage", "group", "treatment_site", "country"],
            "task_medium": ["group", "treatment_site", "outcome", "country", "drug"],
            "task_hard": ["treatment_site", "group", "country", "drug", "trial_phase"],
        }.get(task_id, [])
        actions = ActionPlanner().plan(task_id, findings, max_steps=max_steps, extra_investigations=extra_checks)
        report = generate_expert_report(client, rules, findings, task_name)
        if client is not None:
            metrics.llm_calls += 1

        for action in actions:
            if result.done:
                break
            result = env.step(action)
            obs = result.observation.model_dump()
            final_score = obs["score_so_far"]
            metrics.steps += 1
            if action.action_type == "flag_error":
                metrics.record(obs["feedback"])
            if action.action_type == "flag_error" or metrics.steps <= 5:
                print(f" Step {metrics.steps}: score={final_score:.2f} | {obs['feedback'][:80]}")

        if not result.done:
            result = env.step(AuditAction(action_type="submit_report", report=report))
            obs = result.observation.model_dump()
            final_score = obs["score_so_far"]
            metrics.steps += 1
            print(f" Step {metrics.steps}: score={final_score:.2f} | report submitted")

    print(metrics.summary())
    return final_score, metrics


def run_agent(mode: str, client: Optional[OpenAI], seed: int):
    runner = {
        "naive": run_naive_task,
        "heuristic": run_heuristic_task,
        "full": run_full_task,
    }[mode]

    scores = []
    metrics_list = []
    start = time.time()
    for task_id, task_name in TASK_LIST.items():
        score, metrics = runner(client, task_id, task_name, seed)
        scores.append(score)
        metrics_list.append(metrics)
        print(f" Final score: {score:.2f}\n")

    return {
        "mode": mode,
        "scores": dict(zip(TASK_LIST.keys(), scores)),
        "average": sum(scores) / len(scores),
        "elapsed": time.time() - start,
        "total_steps": sum(metric.steps for metric in metrics_list),
        "total_llm": sum(metric.llm_calls for metric in metrics_list),
        "avg_precision": statistics.mean(metric.precision for metric in metrics_list),
    }


def main():
    parser = argparse.ArgumentParser(description="Clinical Trial Auditor baseline inference")
    parser.add_argument("--mode", choices=["naive", "heuristic", "full", "all"], default="full")
    parser.add_argument("--seed", type=int, default=BASELINE_SEED)
    args = parser.parse_args()

    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY) if API_KEY else None

    print("=" * 70)
    print(" Clinical Trial Auditor — Protocol-Aware Baseline Inference")
    print(" Dynamic Rules | Adversarial Traps | Stage-Adjusted Fairness Review")
    print(f" Model: {MODEL_NAME}")
    print(f" Seed: {args.seed}")
    print("=" * 70)
    if client is None:
        print(" Note: no API key detected. Naive/full report generation will use deterministic fallbacks.")

    modes = ["naive", "heuristic", "full"] if args.mode == "all" else [args.mode]
    results = []
    for mode in modes:
        print(f"\n{'═' * 70}")
        print(f" AGENT: {mode.upper()}")
        print(f"{'═' * 70}")
        results.append(run_agent(mode, client, args.seed))

    print("\n" + "=" * 70)
    print(" BENCHMARK RESULTS")
    print("=" * 70)
    if len(results) > 1:
        header = f" {'Agent':<12} {'Easy':>8} {'Medium':>8} {'Hard':>8} {'Avg':>8} {'Prec':>8} {'Time':>8}"
        print(header)
        print(" " + "-" * 66)
        for result in results:
            scores = result["scores"]
            print(
                f" {result['mode'].upper():<12} "
                f"{scores['task_easy']:.2f} {scores['task_medium']:.2f} "
                f"{scores['task_hard']:.2f} {result['average']:.2f} "
                f"{result['avg_precision']:.0%} {result['elapsed']:.1f}s"
            )
    else:
        result = results[0]
        for task_id, task_name in TASK_LIST.items():
            print(f" {task_name:38s}: {result['scores'][task_id]:.2f}")
        print(f"\n Average score: {result['average']:.2f}")
        print(f" Total time: {result['elapsed']:.1f}s")
        print(f" LLM calls: {result['total_llm']}")
        print(f" Total steps: {result['total_steps']}")
        print(f" Average precision: {result['avg_precision']:.0%}")
    print("=" * 70)


if __name__ == "__main__":
    main()
models.py
CHANGED
@@ -2,6 +2,7 @@ from typing import Optional, List, Dict, Any
 from pydantic import Field
 from openenv.core.env_server import Action, Observation, State
 
+
 class AuditAction(Action):
     action_type: str = "flag_error"
     patient_id: Optional[str] = None
@@ -12,30 +13,43 @@ class AuditAction(Action):
     report: Optional[str] = None
     confidence: Optional[float] = None  # 0.0-1.0: agent's confidence in this action
 
+
 class AuditObservation(Observation):
     done: bool = False
     reward: float = 0.0
     task_id: str = ""
     task_type: str = ""
     task_description: str = ""
+    protocol_title: str = ""
+    trial_protocol_excerpt: str = ""
     dataset: List[Dict[str, Any]] = Field(default_factory=list)
     errors_found: List[str] = Field(default_factory=list)
     patterns_investigated: List[str] = Field(default_factory=list)
     distributions_computed: List[str] = Field(default_factory=list)
     feedback: Optional[str] = None
     score_so_far: float = 0.0
+    dense_reward_total: float = 0.0
+    score_breakdown: Dict[str, float] = Field(default_factory=dict)
     attempts_remaining: int = 15
     phase: str = "investigation"
 
+
 class AuditState(State):
     episode_id: str = ""
     step_count: int = 0
     task_id: str = ""
     task_type: str = ""
+    protocol_title: str = ""
+    trial_protocol_excerpt: str = ""
     total_errors: int = 0
     errors_found: int = 0
     current_score: float = 0.0
+    dense_reward_total: float = 0.0
+    correct_flags: int = 0
+    false_positives: int = 0
+    duplicate_flags: int = 0
     attempts: int = 0
     phase: str = "investigation"
+    score_breakdown: Dict[str, float] = Field(default_factory=dict)
     patterns_investigated: List[str] = Field(default_factory=list)
-    distributions_computed: List[str] = Field(default_factory=list)
+    distributions_computed: List[str] = Field(default_factory=list)
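The new `score_breakdown` field on `AuditObservation` is a plain mapping from reward component to its contribution. A hypothetical illustration of how a client might read it (the keys below are invented for the example, not taken from the grader):

```python
# Hypothetical component names — the real keys come from the scoring engine.
breakdown = {"correct_flags": 0.45, "report_bonus": 0.12, "step_cost": -0.05}

# score_so_far is expected to track the sum of the per-component contributions.
score_so_far = round(sum(breakdown.values()), 2)
```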
openenv.yaml
CHANGED
name: clinical_trial_auditor
version: "3.0.0"
description: >
  A protocol-aware clinical audit benchmark for OpenEnv. The agent acts as a Senior
  Clinical Data Manager and must read an episode-specific protocol excerpt, audit
  tabular patient records against dynamic eligibility and timing rules, and decide
  whether suspicious subgroup outcomes represent actionable control-arm bias or a
  confounded high-risk cohort.
author: Sumit Saraswat
tags:
  - openenv
  - clinical
  - benchmark
  - protocol-reasoning
  - bias-audit
  - ai-safety
tasks:
  - id: task_easy
    name: Dynamic Eligibility Screening
    difficulty: easy
    description: Read the protocol excerpt for the episode and flag patients whose ages violate the protocol-specific eligibility range.
  - id: task_medium
    name: Protocol Timeline Audit
    difficulty: medium
    description: Audit dynamic age eligibility, death-before-treatment errors, and treatment-start window violations with a Stage IV timing exception.
  - id: task_hard
    name: Equity + Protocol Audit
    difficulty: hard
    description: Audit record-level protocol issues and determine whether control-arm bias is genuinely present or only confounded by a high-risk outreach cohort.
server/requirements.txt → requirements.txt
RENAMED
File without changes
server.log
CHANGED
@@ -34,3 +34,7 @@ INFO: connection closed
 INFO: 127.0.0.1:53804 - "WebSocket /ws" [accepted]
 INFO: connection open
 INFO: connection closed
+INFO: 127.0.0.1:56934 - "GET /health HTTP/1.1" 200 OK
+INFO: 127.0.0.1:56965 - "WebSocket /ws" [accepted]
+INFO: connection open
+INFO: connection closed
server/app.py
CHANGED
@@ -1,7 +1,12 @@
-from openenv.core.env_server import create_fastapi_app
-from clinical_trial_auditor_environment import ClinicalTrialAuditorEnvironment
-from models import AuditAction, AuditObservation
 import uvicorn
+from openenv.core.env_server import create_fastapi_app
+
+try:
+    from .clinical_trial_auditor_environment import ClinicalTrialAuditorEnvironment
+    from .models import AuditAction, AuditObservation
+except ImportError:
+    from clinical_trial_auditor_environment import ClinicalTrialAuditorEnvironment
+    from models import AuditAction, AuditObservation
 
 app = create_fastapi_app(ClinicalTrialAuditorEnvironment, AuditAction, AuditObservation)
 
@@ -9,4 +14,4 @@ def main():
     uvicorn.run(app, host="0.0.0.0", port=8000)
 
 if __name__ == "__main__":
-    main()
+    main()
server/clinical_trial_auditor_environment.py
CHANGED
|
@@ -1,152 +1,102 @@
|
|
| 1 |
"""
|
| 2 |
Clinical Trial Auditor — OpenEnv Environment
|
| 3 |
-
============================================
|
| 4 |
-
|
| 5 |
-
and
|
| 6 |
-
|
| 7 |
-
The agent acts as a Senior Clinical Data Manager auditing procedurally
|
| 8 |
-
generated clinical trial datasets from a multi-site Phase III oncology trial.
|
| 9 |
-
|
| 10 |
-
Architecture layers:
|
| 11 |
-
┌─────────────────────────────────────────────┐
|
| 12 |
-
│ Agent Interface (OpenEnv API) │
|
| 13 |
-
│ step() / reset() / state() │
|
| 14 |
-
├─────────────────────────────────────────────┤
|
| 15 |
-
│ Scoring Engine (Grader) │
|
| 16 |
-
│ Ground-truth comparison, partial credit, │
|
| 17 |
-
│ confidence calibration, score composition │
|
| 18 |
-
├─────────────────────────────────────────────┤
|
| 19 |
-
│ Trap Engine (Adversarial) │
|
| 20 |
-
│ Boundary traps, temporal traps, fake │
|
| 21 |
-
│ bias patterns, distractor injection │
|
| 22 |
-
├─────────────────────────────────────────────┤
|
| 23 |
-
│ Data Engine (Generator) │
|
| 24 |
-
│ Statistical distributions, demographics, │
|
| 25 |
-
│ reproducible seeds, configurable params │
|
| 26 |
-
└─────────────────────────────────────────────┘
|
| 27 |
-
|
| 28 |
-
Key design decisions:
|
| 29 |
-
- Procedural generation: every reset() → unique dataset → no memorization
|
| 30 |
-
- Ground-truth grading: errors are pre-computed, grading is O(1) lookup
|
| 31 |
-
- Confidence-calibrated scoring: overconfident + wrong = devastating penalty
|
| 32 |
-
- False positive cost 3× correct reward → forces precision over recall
|
| 33 |
-
- Adversarial traps: boundary-valid ages, near-temporal cases, fake patterns
|
| 34 |
-
- Multi-phase workflow: Investigation → Flagging → Reporting
|
| 35 |
-
- Seed-based reproducibility for deterministic evaluation
|
| 36 |
"""
|
|
|
|
|
|
|
|
|
|
| 37 |
import uuid
|
| 38 |
from datetime import datetime
|
|
|
|
| 39 |
from openenv.core.env_server import Environment
|
| 40 |
-
from models import AuditAction, AuditObservation, AuditState
|
| 41 |
-
from dataset_generator import DatasetGenerator
|
| 42 |
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
 REWARD_CONFIG = {
-    "correct_flag": 0.
-    "false_positive": -0.
-    "duplicate_flag": -0.
-    "investigate_new": 0.
-    "investigate_redundant": -0.02,
-    "distribution_new": 0.04,
     "distribution_redundant": -0.02,
-    "invalid_phase": -0.
-    "unknown_action": -0.05,
-    "cost_per_step": 0.
-    "
-    "
-    "bias_detected": 0.20,
-    "propose_fix_valid": 0.
-    "propose_fix_invalid": -0.
-    "report_bonus_base": 0.
-    "overconfidence_multiplier":
 }
 
-
-
-
 
 TASKS = {
     "task_easy": {
         "task_id": "task_easy",
-        "task_type": "syntactic_cleaning",
         "difficulty": "easy",
         "allow_bias": False,
-        "
-
-
-            "PHASE 1 — INVESTIGATION:\n"
-            "  Use investigate_pattern(variable=<col>) to profile key variables\n"
-            "  Use compute_distribution(variable=<col>) to compute descriptive stats\n\n"
-            "PHASE 2 — FLAGGING:\n"
-            "  Use flag_error(patient_id=<id>, error_type='invalid_age') for age violations\n"
-            "  Valid age range for trial eligibility: 18-120 (inclusive)\n"
-            "  Missing age (null) is also invalid — required field\n"
-            "  CAUTION: Some ages are rare but valid (e.g., 95, 19, 120). Do NOT over-flag.\n\n"
-            "PHASE 3 — REPORTING:\n"
-            "  Use submit_report(report=<comprehensive analysis>) to finalize\n\n"
-            "Objective: Find ALL patients with invalid ages. Avoid false positives."
-        ),
     },
     "task_medium": {
         "task_id": "task_medium",
-        "task_type": "temporal_consistency",
         "difficulty": "medium",
         "allow_bias": False,
-        "
-        "
-        "
-        "
-
-
-
-            "  Use flag_error with error_type='invalid_age' OR 'temporal_inconsistency'\n"
-            "  Age violations: outside range 18-120 (inclusive) or null\n"
-            "  Temporal violations: death_date MUST NOT precede treatment_start\n"
-            "  NOTE: A patient dying 1 day after treatment start IS valid (not an error)\n\n"
-            "PHASE 3 — REPORTING:\n"
-            "  Use submit_report(report=<comprehensive analysis>) to finalize\n\n"
-            "Objective: Find ALL age errors AND temporal inconsistencies."
-        ),
     },
     "task_hard": {
         "task_id": "task_hard",
-        "task_type": "comprehensive_audit",
         "difficulty": "hard",
         "allow_bias": True,
-        "
-        "
-        "
-        "
-        "
-
-
-
-            "  flag_error with error_type='invalid_age', 'temporal_inconsistency', or 'selection_bias'\n"
-            "  For selection_bias: Identify if the control group has demographic imbalance\n"
-            "  AND whether this correlates with outcome disparity across subgroups\n"
-            "  Look for: representation bias, outcome disparity, intersectional patterns\n\n"
-            "PHASE 3 — REPORTING:\n"
-            "  Use submit_report(report=<comprehensive analysis>) to finalize\n"
-            "  Include: statistical evidence, root cause analysis, corrective recommendations\n\n"
-            "Objective: Find ALL data errors AND demographic bias patterns."
-        ),
     },
 }
 
-# Maximum steps per episode — scales with dataset size
 MAX_STEPS = {
-    "task_easy":
-    "task_medium":
-    "task_hard":
 }
 
 
-# ═══════════════════════════════════════════════════════════════════════════
-# ENVIRONMENT IMPLEMENTATION
-# ═══════════════════════════════════════════════════════════════════════════
-
 class ClinicalTrialAuditorEnvironment(Environment):
     SUPPORTS_CONCURRENT_SESSIONS = True
 
@@ -155,9 +105,12 @@ class ClinicalTrialAuditorEnvironment(Environment):
         self._state = AuditState()
         self._current_task = None
         self._dataset = []
-        self._ground_truth = {}
-        self._traps = set()
         self._bias_present = False
         self._flagged_patients = set()
         self._patterns_investigated = set()
         self._distributions_computed = set()
@@ -165,17 +118,185 @@ class ClinicalTrialAuditorEnvironment(Environment):
         self._max_steps = 15
         self._report_submitted = False
         self._phase = "investigation"
-        self._score_log = []
 
     def reset(self, seed=None, episode_id=None, **kwargs) -> AuditObservation:
-        """
-        Reset the environment with a procedurally generated dataset.
-
-        Args:
-            seed: Random seed for reproducibility. Same seed = identical dataset.
-            episode_id: Optional episode identifier.
-            task_id: "task_easy" | "task_medium" | "task_hard"
-        """
         self._action_history = []
         task_id = kwargs.get("task_id", "task_easy")
         if task_id not in TASKS:
@@ -184,7 +305,6 @@ class ClinicalTrialAuditorEnvironment(Environment):
         self._current_task = TASKS[task_id]
         difficulty = self._current_task["difficulty"]
 
-        # ── Procedural dataset generation ──
         generator = DatasetGenerator(seed=seed)
         result = generator.generate(difficulty=difficulty)
 
@@ -192,7 +312,9 @@ class ClinicalTrialAuditorEnvironment(Environment):
         self._ground_truth = result["ground_truth"]
         self._traps = result["traps"]
         self._bias_present = result["bias_present"]
-
 
         self._flagged_patients = set()
         self._patterns_investigated = set()
@@ -202,21 +324,32 @@ class ClinicalTrialAuditorEnvironment(Environment):
         self._report_submitted = False
         self._phase = "investigation"
         self._score_log = []
-
-
 
         self._state = AuditState(
             episode_id=episode_id or str(uuid.uuid4()),
             step_count=0,
             task_id=task_id,
             task_type=self._current_task["task_type"],
-
             errors_found=0,
             current_score=0.0,
             attempts=0,
             phase="investigation",
             patterns_investigated=[],
             distributions_computed=[],
         )
 
         return AuditObservation(
@@ -224,17 +357,20 @@ class ClinicalTrialAuditorEnvironment(Environment):
             reward=0.0,
             task_id=task_id,
             task_type=self._current_task["task_type"],
-            task_description=self.
             dataset=self._dataset,
             errors_found=[],
             patterns_investigated=[],
             distributions_computed=[],
             feedback=(
-                f"Audit started
-
-                f"to profile the dataset."
             ),
             score_so_far=0.0,
             attempts_remaining=self._max_steps,
             phase="investigation",
         )
@@ -242,11 +378,23 @@ class ClinicalTrialAuditorEnvironment(Environment):
     def step(self, action: AuditAction, **kwargs) -> AuditObservation:
         if self._current_task is None:
             return AuditObservation(
-                done=True,
-
-
-
-
             )
 
         self._action_history.append(action.action_type)
@@ -254,72 +402,50 @@ class ClinicalTrialAuditorEnvironment(Environment):
         self._state.step_count += 1
         self._state.attempts = self._attempts
 
-        # Core grading against ground truth
         step_reward, feedback = self._grade(action)
 
-
-
-
-
-
-
-
-            feedback += f" [OVERCONFIDENCE PENALTY: conf={agent_confidence:.0%}]"
-        elif step_reward > 0:  # Correct answer
-            step_reward *= max(0.5, agent_confidence)
-
-        # Step cost (progressive — later steps cost more)
-        step_cost = REWARD_CONFIG["cost_per_step"] * (1 + self._attempts * 0.05)
-        step_reward -= step_cost
-
-        # Anti brute-force (punish spinning without flagging)
-        if self._attempts > self._max_steps // 2 and len(self._flagged_patients) < 3:
-            step_reward -= 0.05
-
-        # Efficiency bonus
-        if len(self._patterns_investigated) >= 3 and len(self._flagged_patients) >= 3:
-            step_reward += REWARD_CONFIG["bonus_efficiency"]
-
-        # Workflow sequence bonus
-        if len(self._action_history) >= 3:
-            if self._action_history[-3:] == [
-                "investigate_pattern", "compute_distribution", "flag_error"
-            ]:
-                step_reward += REWARD_CONFIG["bonus_workflow"]
-
-        # Difficulty multiplier
-        mult = {
-            "task_easy": 1.0, "task_medium": 1.2, "task_hard": 1.5
-        }.get(self._current_task["task_id"], 1.0)
-        step_reward = round(step_reward * mult, 3)
-        step_reward = max(-0.5, step_reward)
-
-        self._state.current_score = max(
-            0.0, min(1.0, self._state.current_score + step_reward)
-        )
 
-
-        self.
-
-
-            "
-            "cumulative": self._state.current_score,
-        })
 
         done = self._report_submitted or self._attempts >= self._max_steps
 
         return AuditObservation(
             done=done,
-            reward=step_reward,
             task_id=self._current_task["task_id"],
             task_type=self._current_task["task_type"],
-            task_description=self.
             dataset=self._dataset,
-            errors_found=
-            patterns_investigated=
-            distributions_computed=
             feedback=feedback,
             score_so_far=self._state.current_score,
             attempts_remaining=max(0, self._max_steps - self._attempts),
             phase=self._phase,
         )
@@ -328,399 +454,297 @@ class ClinicalTrialAuditorEnvironment(Environment):
     def state(self) -> AuditState:
         return self._state
 
-
-
-
 
-
-
-        # Phase validation
-        if self._phase == "investigation" and action.action_type in [
-            "flag_error", "submit_report"
-        ]:
             return (
                 REWARD_CONFIG["invalid_phase"],
-                "PHASE BLOCKED:
-                "Use investigate_pattern or compute_distribution first."
-            )
-        if (self._phase == "flagging"
-                and action.action_type == "submit_report"
-                and len(self._flagged_patients) == 0):
-            return (
-                REWARD_CONFIG["invalid_phase"],
-                "PHASE BLOCKED: Flag at least one issue before submitting report."
             )
 
         if action.action_type == "investigate_pattern":
             return self._grade_investigate(action)
-
             return self._grade_distribution(action)
-
             return self._grade_flag(action)
-
             return self._grade_propose_fix(action)
-
             return self._grade_report(action)
-        else:
-            return (
-                REWARD_CONFIG["unknown_action"],
-                f"REJECTED: Unknown action '{action.action_type}'. "
-                f"Valid: investigate_pattern, compute_distribution, "
-                f"flag_error, propose_fix, submit_report."
-            )
 
-
-
-
-
 
         valid_vars = {
-            "age",
-            "
-            "
         }
-
         if variable not in valid_vars:
-            return
-                REWARD_CONFIG["unknown_action"],
-                f"REJECTED: Unknown variable '{variable}'. "
-                f"Valid: {', '.join(sorted(valid_vars))}."
-            )
-
         if variable in self._patterns_investigated:
             return (
                 REWARD_CONFIG["investigate_redundant"],
-                f"Already investigated '{variable}'.
             )
 
         self._patterns_investigated.add(variable)
-        self._state.patterns_investigated.append(variable)
-
-        # Phase transition: unlock flagging after investigating key variables
-        if (
-            "age" in self._patterns_investigated
-            and "death_date" in self._patterns_investigated
-            and self._phase == "investigation"
-        ):
-            self._phase = "flagging"
 
-        # Dynamic statistics based on variable type
         if variable == "age":
-            ages = [
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
         else:
             counts = {}
-            for
-
-            counts[
-
-
-
-
-
-
-
-
-
-                f"{len(sorted_counts)}): {top_10}."
-            )
-        else:
-            feedback = f"{variable.capitalize()} Distribution: {sorted_counts}."
-
-        return REWARD_CONFIG["investigate_new"], f"Investigated '{variable}': {feedback}"
-
-    def _grade_distribution(self, action: AuditAction):
         variable = action.variable or ""
         if not variable:
             return REWARD_CONFIG["unknown_action"], "REJECTED: Variable cannot be empty."
-
         if variable in self._distributions_computed:
             return (
                 REWARD_CONFIG["distribution_redundant"],
-                f"Distribution for '{variable}' already computed."
             )
 
         self._distributions_computed.add(variable)
-        self._state.distributions_computed.append(variable)
-
-        # Phase transition via distribution analysis
-        if (
-            "ethnicity" in self._distributions_computed
-            and "outcome" in self._distributions_computed
-            and self._phase == "investigation"
-        ):
-            self._phase = "flagging"
 
         if variable == "ethnicity":
-
-
-
-
-
-
-
-
-
-                for k, v in sorted(eth_counts.items(), key=lambda x: -x[1])
-            )
-            feedback = f"Control group ethnicity: {breakdown}. Total={total}."
-        else:
-            feedback = "No control group patients found."
-        elif variable == "outcome":
-            control = [p for p in self._dataset if p.get("group") == "control"]
-            if control:
-                deceased_c = sum(
-                    1 for p in control if p.get("outcome") == "deceased"
-                )
-                total = len(control)
-                feedback = (
-                    f"Control group outcomes: deceased={deceased_c}/{total} "
-                    f"({deceased_c / total * 100:.0f}%). "
-                    f"Survived={total - deceased_c}/{total} "
-                    f"({(total - deceased_c) / total * 100:.0f}%)."
-                )
-            else:
-                feedback = "No control group patients found."
         elif variable == "gender":
-
-
-
-
-
-
-
-
-
-
-
-
         else:
-
 
-        return REWARD_CONFIG["distribution_new"],
 
-    def _grade_flag(self, action: AuditAction):
-        """Grade flag action against pre-computed ground truth."""
-        patient_id = action.patient_id
         error_type = action.error_type or ""
 
-        # ── Selection bias flag (no patient_id needed) ──
        if error_type == "selection_bias":
            if not self._current_task["allow_bias"]:
                return (
-                    REWARD_CONFIG["
-                    "
                )
-
            if "BIAS_FLAG" in self._flagged_patients:
-
-
-                    "Selection bias already flagged."
-                )
 
-
-
-
-
-                return (
-                    REWARD_CONFIG["false_positive"],
-                    "Cannot assess bias — no control group found."
-                )
-
-            white_count = sum(
-                1 for p in control if p.get("ethnicity") == "White"
-            )
-            white_ratio = white_count / len(control)
-            minority_dead = sum(
-                1 for p in control
-                if p.get("ethnicity") != "White"
-                and p.get("outcome") == "deceased"
-            )
-            male_count = sum(
-                1 for p in control if p.get("gender") == "M"
-            )
-            male_ratio = male_count / len(control)
-
-            if white_ratio >= 0.65 and minority_dead > 0 and male_ratio >= 0.50:
-                self._flagged_patients.add("BIAS_FLAG")
-                self._state.errors_found += 1
-                return (
-                    REWARD_CONFIG["bias_detected"],
-                    f"✓ Correct. Multi-dimensional selection bias confirmed: "
-                    f"White={white_ratio:.0%} of control, "
-                    f"minority mortality present ({minority_dead} deceased), "
-                    f"gender imbalance ({male_ratio:.0%} male)."
-                )
-            else:
-                return (
-                    REWARD_CONFIG["false_positive"],
-                    "✗ Statistical evidence insufficient for bias determination."
-                )
-        else:
            return (
-                REWARD_CONFIG["
-                "
            )
 
-
-        if patient_id is None:
            return (
                REWARD_CONFIG["false_positive"],
-                "
            )
 
        if patient_id in self._flagged_patients:
-
-
-                f"{patient_id} already flagged."
-            )
 
-
-        patient
-
-
-        )
-        if not patient:
-            return (
-                REWARD_CONFIG["false_positive"],
-                f"REJECTED: Patient '{patient_id}' not found in dataset."
-            )
 
-        # ── Ground truth lookup (O(1) — deterministic) ──
        expected_errors = self._ground_truth.get(patient_id, [])
-
-
-
-
-            self._state.errors_found += 1
-            age = patient.get("age")
            return (
                REWARD_CONFIG["correct_flag"],
-                f"✓ Correct: {patient_id}
-
-        else:
-            age = patient.get("age")
-            return (
-                REWARD_CONFIG["false_positive"],
-                f"✗ False positive: {patient_id} age={age} is within valid range [18-120]."
            )
-
-
-
-
-            self._state.errors_found += 1
-            ts = patient.get("treatment_start", "")
-            dd = patient.get("death_date", "")
-            if ts and dd:
-                t = datetime.strptime(ts, "%Y-%m-%d")
-                d = datetime.strptime(dd, "%Y-%m-%d")
-                gap = (t - d).days
-                return (
-                    REWARD_CONFIG["correct_flag"],
-                    f"✓ Correct: {patient_id} death_date is {gap} days "
-                    f"before treatment_start."
-                )
            return (
                REWARD_CONFIG["correct_flag"],
-                f"✓ Correct: {patient_id}
            )
-        else:
            return (
-                REWARD_CONFIG["
-                f"
            )
 
-
        return (
            REWARD_CONFIG["false_positive"],
-            f"✗
-            f"
        )
 
-
        patient_id = action.patient_id or ""
        if patient_id not in self._flagged_patients:
-            return
-                REWARD_CONFIG["propose_fix_invalid"],
-                "Can only propose fix for flagged patients."
-            )
        proposed = action.proposed_value or ""
        if len(proposed) > 2:
-            return
-
-                f"Fix proposed for {patient_id}."
-            )
-        return REWARD_CONFIG["propose_fix_invalid"], "Proposed fix too vague."
 
-    def _grade_report(self, action: AuditAction):
-        """Grade report quality using multi-dimensional rubric."""
        self._report_submitted = True
        report = (action.report or action.reason or "").lower()
-
-
-        # Completeness bonus: flagged enough issues
-        if len(self._flagged_patients) >= 3:
-            step_reward += 0.03
-
-        # ── Report quality rubric (tests clinical reasoning depth) ──
-        quality_score = 0
        quality_items = []
 
-
-
-            "
-        ]):
-
-            quality_items.append("root cause
-
-
-
-
-
-            quality_score += 1
-            quality_items.append("corrective recommendations")
-
-        # Risk assessment
-        if any(kw in report for kw in [
-            "risk", "severity", "critical", "impact", "patient safety"
-        ]):
-            quality_score += 1
            quality_items.append("risk assessment")
-
-
-
-
-
-
-
-
-
-
-
-        quality_feedback = f"Report quality: {quality_score}/4 dimensions"
-        if quality_items:
-            quality_feedback += f" ({', '.join(quality_items)})"
-
-        return (
-            step_reward,
-            f"Report submitted. {quality_feedback}. Final evaluation complete."
-        )
 """
 Clinical Trial Auditor — OpenEnv Environment
+============================================
+Protocol-aware clinical audit benchmark with dynamic rules, adversarial traps,
+and stage-aware fairness evaluation.
 """
+
+from __future__ import annotations
+
 import uuid
 from datetime import datetime
+
 from openenv.core.env_server import Environment
 
+try:
+    from .dataset_generator import DatasetGenerator
+    from .models import AuditAction, AuditObservation, AuditState
+except ImportError:
+    from dataset_generator import DatasetGenerator
+    from models import AuditAction, AuditObservation, AuditState
+
+
 REWARD_CONFIG = {
+    "correct_flag": 0.16,
+    "false_positive": -0.26,
+    "duplicate_flag": -0.08,
+    "investigate_new": 0.04,
+    "investigate_redundant": -0.02,
+    "distribution_new": 0.04,
     "distribution_redundant": -0.02,
+    "invalid_phase": -0.06,
+    "unknown_action": -0.05,
+    "cost_per_step": 0.004,
+    "bonus_workflow": 0.03,
+    "bonus_protocol_window": 0.04,
+    "bias_detected": 0.20,
+    "propose_fix_valid": 0.02,
+    "propose_fix_invalid": -0.04,
+    "report_bonus_base": 0.03,
+    "overconfidence_multiplier": 1.8,
 }
 
+SCORE_WEIGHTS = {
+    "recall": 0.70,
+    "precision": 0.15,
+    "workflow": 0.05,
+    "efficiency": 0.05,
+    "report": 0.05,
+}
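The new `SCORE_WEIGHTS` table composes the episode score from five sub-scores, with recall dominating at 70%. A minimal standalone sketch (not part of the commit; the `compose_score` helper name and the example sub-scores are illustrative) of how such a weighted composition behaves:

```python
SCORE_WEIGHTS = {
    "recall": 0.70,
    "precision": 0.15,
    "workflow": 0.05,
    "efficiency": 0.05,
    "report": 0.05,
}

def compose_score(subscores: dict) -> float:
    """Weighted sum of per-dimension sub-scores, clamped to [0.0, 1.0]."""
    total = sum(SCORE_WEIGHTS[k] * subscores.get(k, 0.0) for k in SCORE_WEIGHTS)
    return max(0.0, min(1.0, round(total, 3)))

# Perfect recall with mediocre precision still scores high,
# because recall carries 70% of the weight.
score = compose_score(
    {"recall": 1.0, "precision": 0.5, "workflow": 1.0, "efficiency": 1.0, "report": 0.0}
)  # 0.875
```

Missing sub-scores default to 0.0, so a partially graded episode degrades gracefully instead of raising.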
 
 TASKS = {
     "task_easy": {
         "task_id": "task_easy",
         "difficulty": "easy",
+        "task_type": "eligibility_screening",
+        "title": "Dynamic Eligibility Screening",
         "allow_bias": False,
+        "allowed_error_types": ["invalid_age"],
+        "required_investigations": ["age"],
+        "required_distributions": [],
     },
     "task_medium": {
         "task_id": "task_medium",
         "difficulty": "medium",
+        "task_type": "protocol_timeline_audit",
+        "title": "Protocol Timeline Audit",
         "allow_bias": False,
+        "allowed_error_types": [
+            "invalid_age",
+            "temporal_inconsistency",
+            "protocol_window_violation",
+        ],
+        "required_investigations": ["age", "death_date", "enrollment_date", "stage"],
+        "required_distributions": [],
     },
     "task_hard": {
         "task_id": "task_hard",
         "difficulty": "hard",
+        "task_type": "equity_and_protocol_audit",
+        "title": "Equity + Protocol Audit",
         "allow_bias": True,
+        "allowed_error_types": [
+            "invalid_age",
+            "temporal_inconsistency",
+            "protocol_window_violation",
+            "selection_bias",
+        ],
+        "required_investigations": ["age", "death_date", "enrollment_date", "stage"],
+        "required_distributions": ["ethnicity", "gender", "outcome"],
     },
 }
 
 MAX_STEPS = {
+    "task_easy": 18,
+    "task_medium": 34,
+    "task_hard": 46,
 }
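The step budgets in the new `MAX_STEPS` scale with difficulty and must stay in sync with `TASKS`. A small sanity-check sketch (not part of the commit; `validate_budgets` is an illustrative helper) of the invariant the config relies on:

```python
MAX_STEPS = {"task_easy": 18, "task_medium": 34, "task_hard": 46}
TASK_IDS = ["task_easy", "task_medium", "task_hard"]  # ordered by difficulty

def validate_budgets(task_ids, budgets) -> bool:
    """Every task needs a step budget, and budgets should grow with difficulty."""
    if any(t not in budgets for t in task_ids):
        return False
    ordered = [budgets[t] for t in task_ids]
    return ordered == sorted(ordered)

ok = validate_budgets(TASK_IDS, MAX_STEPS)  # True for the values above
```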
 
 
 class ClinicalTrialAuditorEnvironment(Environment):
     SUPPORTS_CONCURRENT_SESSIONS = True
 
         self._state = AuditState()
         self._current_task = None
         self._dataset = []
+        self._ground_truth = {}
+        self._traps = set()
         self._bias_present = False
+        self._protocol = {}
+        self._protocol_title = ""
+        self._protocol_excerpt = ""
         self._flagged_patients = set()
         self._patterns_investigated = set()
         self._distributions_computed = set()
         self._max_steps = 15
         self._report_submitted = False
         self._phase = "investigation"
+        self._score_log = []
+        self._dense_reward_total = 0.0
+        self._correct_flags = 0
+        self._false_positive_flags = 0
+        self._duplicate_flags = 0
+        self._invalid_phase_actions = 0
+        self._report_quality = 0.0
+
+    def _task_description(self) -> str:
+        allowed = ", ".join(self._current_task["allowed_error_types"])
+        lines = [
+            f"CLINICAL TRIAL AUDIT — {self._current_task['title']}",
+            "Role: Senior Clinical Data Manager",
+            "",
+            "Use the protocol excerpt from the observation. Do not assume default clinical rules.",
+            f"Allowed error types for this task: {allowed}.",
+            "",
+            "Workflow",
+            "- Investigate the required variables before flagging records.",
+            "- Use compute_distribution for cohort-level review when the task asks for bias analysis.",
+            "- submit_report should summarize evidence, impact, and corrective action.",
+        ]
+        if self._current_task["allow_bias"]:
+            lines.append("- For selection_bias, determine whether actionable control-arm bias exists at all.")
+        return "\n".join(lines)
+
+    def _required_investigations(self) -> set[str]:
+        return set(self._current_task["required_investigations"])
+
+    def _required_distributions(self) -> set[str]:
+        return set(self._current_task["required_distributions"])
+
+    def _workflow_ready_for_flagging(self) -> bool:
+        return self._required_investigations().issubset(self._patterns_investigated)
+
+    def _bias_review_ready(self) -> bool:
+        return self._required_distributions().issubset(self._distributions_computed)
+
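The new phase gating replaces the old hard-coded "age and death_date investigated" check with a `set.issubset` test against the per-task requirement lists. A minimal sketch of that gating logic in isolation (the `required` values mirror the medium/hard tasks; the function name is illustrative):

```python
required = {"age", "death_date", "enrollment_date", "stage"}

def ready_for_flagging(investigated: set) -> bool:
    """Flagging unlocks only once every required variable has been profiled."""
    return required.issubset(investigated)

ready_for_flagging({"age", "stage"})                                    # False
ready_for_flagging({"age", "death_date", "enrollment_date", "stage"})   # True
```

Because the requirement set comes from `TASKS`, adding a new task only means editing config, not the gating code.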
+    def _stage_adjusted_gap(self) -> tuple[float, str, float, float]:
+        control = [patient for patient in self._dataset if patient.get("group") == "control"]
+        if not control:
+            return 0.0, "Unknown", 0.0, 0.0
+
+        ethnicity_counts = {}
+        for patient in control:
+            ethnicity = patient.get("ethnicity", "Unknown")
+            ethnicity_counts[ethnicity] = ethnicity_counts.get(ethnicity, 0) + 1
+        dominant_ethnicity = max(ethnicity_counts.items(), key=lambda item: item[1])[0]
+        dominant_ratio = ethnicity_counts[dominant_ethnicity] / len(control)
+        male_ratio = sum(patient.get("gender") == "M" for patient in control) / len(control)
+
+        weighted_gap = 0.0
+        total_weight = 0
+        for stage in ("I", "II", "III", "IV"):
+            stage_patients = [patient for patient in control if patient.get("stage") == stage]
+            dominant_stage = [patient for patient in stage_patients if patient.get("ethnicity") == dominant_ethnicity]
+            minority_stage = [patient for patient in stage_patients if patient.get("ethnicity") != dominant_ethnicity]
+            if len(dominant_stage) < 5 or len(minority_stage) < 5:
+                continue
+            dom_mortality = sum(patient.get("outcome") == "deceased" for patient in dominant_stage) / len(dominant_stage)
+            min_mortality = sum(patient.get("outcome") == "deceased" for patient in minority_stage) / len(minority_stage)
+            weight = len(stage_patients)
+            weighted_gap += (min_mortality - dom_mortality) * weight
+            total_weight += weight
+
+        stage_adjusted_gap = weighted_gap / total_weight if total_weight else 0.0
+        return stage_adjusted_gap, dominant_ethnicity, dominant_ratio, male_ratio
+
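`_stage_adjusted_gap` weights per-stage mortality differences by stage size and skips strata with fewer than five patients per arm, so a single thin stage cannot dominate the signal. A standalone sketch of the same computation with an invented toy cohort (patient records below are illustrative, not from the environment's generator):

```python
def stage_adjusted_gap(control, dominant="White", stages=("I", "II", "III", "IV"), min_n=5):
    """Mortality gap (minority - dominant) per stage, weighted by stage size."""
    weighted, total = 0.0, 0
    for stage in stages:
        cohort = [p for p in control if p["stage"] == stage]
        dom = [p for p in cohort if p["ethnicity"] == dominant]
        mino = [p for p in cohort if p["ethnicity"] != dominant]
        if len(dom) < min_n or len(mino) < min_n:
            continue  # stratum too thin to compare reliably
        gap = (sum(p["outcome"] == "deceased" for p in mino) / len(mino)
               - sum(p["outcome"] == "deceased" for p in dom) / len(dom))
        weighted += gap * len(cohort)
        total += len(cohort)
    return weighted / total if total else 0.0

# Toy cohort: stage II has 5 dominant (1 death) and 5 minority (3 deaths) patients,
# so the stage-adjusted gap is 3/5 - 1/5 = 0.4.
control = (
    [{"stage": "II", "ethnicity": "White", "outcome": "deceased" if i == 0 else "survived"}
     for i in range(5)]
    + [{"stage": "II", "ethnicity": "Black", "outcome": "deceased" if i < 3 else "survived"}
       for i in range(5)]
)
gap = stage_adjusted_gap(control)
```

Stratifying by stage before comparing outcomes is what distinguishes genuine selection bias from a control arm that simply enrolled sicker patients.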
+    def _bias_signal(self) -> dict:
+        control = [patient for patient in self._dataset if patient.get("group") == "control"]
+        if not control:
+            return {
+                "signal_present": False,
+                "stage_adjusted_gap": 0.0,
+                "dominant_ethnicity": "Unknown",
+                "dominant_ratio": 0.0,
+                "male_ratio": 0.0,
+                "overall_gap": 0.0,
+                "high_risk_note": "",
+            }
+
+        stage_adjusted_gap, dominant_ethnicity, dominant_ratio, male_ratio = self._stage_adjusted_gap()
+        dominant_group = [patient for patient in control if patient.get("ethnicity") == dominant_ethnicity]
+        minority_group = [patient for patient in control if patient.get("ethnicity") != dominant_ethnicity]
+        dom_mortality = (
+            sum(patient.get("outcome") == "deceased" for patient in dominant_group) / len(dominant_group)
+            if dominant_group
+            else 0.0
+        )
+        min_mortality = (
+            sum(patient.get("outcome") == "deceased" for patient in minority_group) / len(minority_group)
+            if minority_group
+            else 0.0
+        )
+        overall_gap = min_mortality - dom_mortality
+        signal_present = (
+            dominant_ratio >= self._protocol.get("bias_control_dominance_threshold", 1.0)
+            and male_ratio >= self._protocol.get("bias_male_threshold", 1.0)
+            and stage_adjusted_gap >= self._protocol.get("bias_stage_adjusted_gap", 1.0)
+        )
+        return {
+            "signal_present": signal_present,
+            "stage_adjusted_gap": stage_adjusted_gap,
+            "dominant_ethnicity": dominant_ethnicity,
+            "dominant_ratio": dominant_ratio,
+            "male_ratio": male_ratio,
+            "overall_gap": overall_gap,
+            "high_risk_note": ", ".join(self._protocol.get("high_risk_sites", [])),
+        }
+
def _build_breakdown(self) -> dict[str, float]:
|
| 232 |
+
total_targets = max(1, self._state.total_errors)
|
| 233 |
+
recall = min(1.0, self._correct_flags / total_targets)
|
| 234 |
+
precision = self._correct_flags / max(
|
| 235 |
+
1,
|
| 236 |
+
self._correct_flags + (2 * self._false_positive_flags) + self._duplicate_flags,
|
| 237 |
+
)
|
| 238 |
+
required_investigations = len(self._required_investigations())
|
| 239 |
+
required_distributions = len(self._required_distributions())
|
| 240 |
+
investigation_coverage = (
|
| 241 |
+
min(len(self._patterns_investigated & self._required_investigations()), required_investigations)
|
| 242 |
+
/ required_investigations
|
| 243 |
+
if required_investigations
|
| 244 |
+
else 1.0
|
| 245 |
+
)
|
| 246 |
+
distribution_coverage = (
|
| 247 |
+
min(len(self._distributions_computed & self._required_distributions()), required_distributions)
|
| 248 |
+
/ required_distributions
|
| 249 |
+
if required_distributions
|
| 250 |
+
else 1.0
|
| 251 |
+
)
|
| 252 |
+
if required_investigations and required_distributions:
|
| 253 |
+
workflow = (0.7 * investigation_coverage) + (0.3 * distribution_coverage)
|
| 254 |
+
elif required_investigations:
|
| 255 |
+
workflow = investigation_coverage
|
| 256 |
+
elif required_distributions:
|
| 257 |
+
workflow = distribution_coverage
|
| 258 |
+
else:
|
| 259 |
+
workflow = 0.0
|
| 260 |
+
workflow *= max(0.0, 1.0 - (0.12 * self._invalid_phase_actions))
|
| 261 |
+
|
| 262 |
+
useful_steps = (
|
| 263 |
+
min(len(self._patterns_investigated), required_investigations)
|
| 264 |
+
+ min(len(self._distributions_computed), required_distributions)
|
| 265 |
+
+ self._correct_flags
|
| 266 |
+
+ (1 if self._report_submitted else 0)
|
| 267 |
+
)
|
| 268 |
+
efficiency = min(1.0, useful_steps / max(1, self._attempts))
|
| 269 |
+
report = self._report_quality / 5.0
|
| 270 |
+
score = (
|
| 271 |
+
SCORE_WEIGHTS["recall"] * recall
|
| 272 |
+
+ SCORE_WEIGHTS["precision"] * precision
|
| 273 |
+
+ SCORE_WEIGHTS["workflow"] * workflow
|
| 274 |
+
+ SCORE_WEIGHTS["efficiency"] * efficiency
|
| 275 |
+
+ SCORE_WEIGHTS["report"] * report
|
| 276 |
+
)
|
| 277 |
+
return {
|
| 278 |
+
"recall": round(recall, 3),
|
| 279 |
+
"precision": round(precision, 3),
|
| 280 |
+
"workflow": round(workflow, 3),
|
| 281 |
+
"efficiency": round(efficiency, 3),
|
| 282 |
+
"report": round(report, 3),
|
| 283 |
+
"benchmark_score": round(min(1.0, max(0.0, score)), 3),
|
| 284 |
+
}
|
| 285 |
+
|
| 286 |
+
def _sync_state(self) -> None:
|
| 287 |
+
breakdown = self._build_breakdown()
|
| 288 |
+
self._state.current_score = breakdown["benchmark_score"]
|
| 289 |
+
self._state.dense_reward_total = round(self._dense_reward_total, 3)
|
| 290 |
+
self._state.correct_flags = self._correct_flags
|
| 291 |
+
self._state.false_positives = self._false_positive_flags
|
| 292 |
+
self._state.duplicate_flags = self._duplicate_flags
|
| 293 |
+
self._state.patterns_investigated = sorted(self._patterns_investigated)
|
| 294 |
+
self._state.distributions_computed = sorted(self._distributions_computed)
|
| 295 |
+
self._state.phase = self._phase
|
| 296 |
+
self._state.errors_found = self._correct_flags
|
| 297 |
+
self._state.score_breakdown = breakdown
|
| 298 |
|
| 299 |
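The stage-stratified fairness check above compares mortality between the dominant and minority ethnicity within each cancer stage, then weights the per-stage gaps by stratum size. A minimal standalone sketch of that computation (illustrative extraction, not the exact environment code; the minimum cohort size of 5 matches the check in the diff):

```python
def stage_adjusted_gap(control, dominant_ethnicity, min_cohort=5):
    """Size-weighted mortality gap (minority minus dominant) across stages."""
    weighted_gap, total_weight = 0.0, 0
    for stage in ("I", "II", "III", "IV"):
        stage_patients = [p for p in control if p["stage"] == stage]
        dom = [p for p in stage_patients if p["ethnicity"] == dominant_ethnicity]
        mino = [p for p in stage_patients if p["ethnicity"] != dominant_ethnicity]
        if len(dom) < min_cohort or len(mino) < min_cohort:
            continue  # skip strata too small to compare reliably
        dom_mort = sum(p["outcome"] == "deceased" for p in dom) / len(dom)
        min_mort = sum(p["outcome"] == "deceased" for p in mino) / len(mino)
        weighted_gap += (min_mort - dom_mort) * len(stage_patients)
        total_weight += len(stage_patients)
    return weighted_gap / total_weight if total_weight else 0.0
```

Stratifying by stage before comparing mortality is what lets the grader distinguish genuine outcome disparity from confounding, where the minority group simply enrolled at later stages.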
    def reset(self, seed=None, episode_id=None, **kwargs) -> AuditObservation:
        self._action_history = []
        task_id = kwargs.get("task_id", "task_easy")
        if task_id not in TASKS:
        self._current_task = TASKS[task_id]
        difficulty = self._current_task["difficulty"]

        generator = DatasetGenerator(seed=seed)
        result = generator.generate(difficulty=difficulty)

        self._ground_truth = result["ground_truth"]
        self._traps = result["traps"]
        self._bias_present = result["bias_present"]
+        self._protocol = result["protocol"]
+        self._protocol_title = result["protocol_title"]
+        self._protocol_excerpt = result["protocol_excerpt"]

        self._flagged_patients = set()
        self._patterns_investigated = set()
        self._report_submitted = False
        self._phase = "investigation"
        self._score_log = []
+        self._dense_reward_total = 0.0
+        self._correct_flags = 0
+        self._false_positive_flags = 0
+        self._duplicate_flags = 0
+        self._invalid_phase_actions = 0
+        self._report_quality = 0.0

        self._state = AuditState(
            episode_id=episode_id or str(uuid.uuid4()),
            step_count=0,
            task_id=task_id,
            task_type=self._current_task["task_type"],
+            protocol_title=self._protocol_title,
+            trial_protocol_excerpt=self._protocol_excerpt,
+            total_errors=result["stats"]["total_errors"],
            errors_found=0,
            current_score=0.0,
+            dense_reward_total=0.0,
+            correct_flags=0,
+            false_positives=0,
+            duplicate_flags=0,
            attempts=0,
            phase="investigation",
            patterns_investigated=[],
            distributions_computed=[],
+            score_breakdown=self._build_breakdown(),
        )

        return AuditObservation(
            reward=0.0,
            task_id=task_id,
            task_type=self._current_task["task_type"],
+            task_description=self._task_description(),
+            protocol_title=self._protocol_title,
+            trial_protocol_excerpt=self._protocol_excerpt,
            dataset=self._dataset,
            errors_found=[],
            patterns_investigated=[],
            distributions_computed=[],
            feedback=(
+                f"Audit started for {self._protocol_title}. Read the protocol excerpt first, "
+                "then investigate the required variables before flagging issues."
            ),
            score_so_far=0.0,
+            dense_reward_total=0.0,
+            score_breakdown=self._build_breakdown(),
            attempts_remaining=self._max_steps,
            phase="investigation",
        )
    def step(self, action: AuditAction, **kwargs) -> AuditObservation:
        if self._current_task is None:
            return AuditObservation(
+                done=True,
+                reward=0.0,
+                task_id="",
+                task_type="",
+                task_description="Call reset() first.",
+                protocol_title="",
+                trial_protocol_excerpt="",
+                dataset=[],
+                errors_found=[],
+                patterns_investigated=[],
+                distributions_computed=[],
+                feedback="No active episode.",
+                score_so_far=0.0,
+                dense_reward_total=0.0,
+                score_breakdown={},
+                attempts_remaining=0,
+                phase="investigation",
            )

        self._action_history.append(action.action_type)
        self._state.step_count += 1
        self._state.attempts = self._attempts

        step_reward, feedback = self._grade(action)

+        if action.action_type == "flag_error" and action.confidence is not None:
+            confidence = max(0.0, min(1.0, action.confidence))
+            if step_reward < 0 and confidence > 0.8:
+                step_reward *= REWARD_CONFIG["overconfidence_multiplier"]
+                feedback += f" [OVERCONFIDENCE PENALTY: conf={confidence:.0%}]"
+            elif step_reward > 0:
+                step_reward *= max(0.65, confidence)

+        step_reward -= REWARD_CONFIG["cost_per_step"] * self._attempts
+        self._dense_reward_total += step_reward
+
+        if self._workflow_ready_for_flagging():
+            self._phase = "flagging"

        done = self._report_submitted or self._attempts >= self._max_steps
+        self._sync_state()
+        self._score_log.append(
+            {
+                "step": self._attempts,
+                "action": action.action_type,
+                "reward": round(step_reward, 3),
+                "dense_reward_total": round(self._dense_reward_total, 3),
+                "benchmark_score": self._state.current_score,
+            }
+        )

        return AuditObservation(
            done=done,
+            reward=round(step_reward, 3),
            task_id=self._current_task["task_id"],
            task_type=self._current_task["task_type"],
+            task_description=self._task_description(),
+            protocol_title=self._protocol_title,
+            trial_protocol_excerpt=self._protocol_excerpt,
            dataset=self._dataset,
+            errors_found=sorted(self._flagged_patients),
+            patterns_investigated=sorted(self._patterns_investigated),
+            distributions_computed=sorted(self._distributions_computed),
            feedback=feedback,
            score_so_far=self._state.current_score,
+            dense_reward_total=self._state.dense_reward_total,
+            score_breakdown=self._state.score_breakdown,
            attempts_remaining=max(0, self._max_steps - self._attempts),
            phase=self._phase,
        )
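The confidence-weighted shaping in `step()` scales a flagging reward by the agent's stated confidence: confident mistakes are amplified by an overconfidence multiplier, while correct but low-confidence flags are dampened with a floor. A minimal sketch of that logic (the 2.0 multiplier is a placeholder; the real value lives in `REWARD_CONFIG["overconfidence_multiplier"]` elsewhere in the repo, and the 0.8 threshold and 0.65 floor follow the diff):

```python
OVERCONFIDENCE_MULTIPLIER = 2.0  # placeholder; real value comes from REWARD_CONFIG

def shape_reward(step_reward, confidence):
    """Scale a flag_error reward by the agent's clipped confidence."""
    confidence = max(0.0, min(1.0, confidence))
    if step_reward < 0 and confidence > 0.8:
        return step_reward * OVERCONFIDENCE_MULTIPLIER  # confident mistakes cost more
    if step_reward > 0:
        return step_reward * max(0.65, confidence)  # low confidence dampens, floored at 0.65
    return step_reward
```

This asymmetry makes shallow heuristics with inflated confidence strictly worse than calibrated uncertainty.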
    def state(self) -> AuditState:
        return self._state

+    def _grade(self, action: AuditAction) -> tuple[float, str]:
+        if self._phase == "investigation" and action.action_type in {"flag_error", "submit_report"}:
+            if not self._workflow_ready_for_flagging():
+                self._invalid_phase_actions += 1
+                return (
+                    REWARD_CONFIG["invalid_phase"],
+                    "PHASE BLOCKED: Investigate the required variables before flagging or reporting.",
+                )

+        if action.action_type == "submit_report" and not self._flagged_patients:
+            self._invalid_phase_actions += 1
            return (
                REWARD_CONFIG["invalid_phase"],
+                "PHASE BLOCKED: Flag at least one issue before submitting the report.",
            )

        if action.action_type == "investigate_pattern":
            return self._grade_investigate(action)
+        if action.action_type == "compute_distribution":
            return self._grade_distribution(action)
+        if action.action_type == "flag_error":
            return self._grade_flag(action)
+        if action.action_type == "propose_fix":
            return self._grade_propose_fix(action)
+        if action.action_type == "submit_report":
            return self._grade_report(action)

+        return (
+            REWARD_CONFIG["unknown_action"],
+            "REJECTED: Unknown action. Valid actions are investigate_pattern, compute_distribution, "
+            "flag_error, propose_fix, submit_report.",
+        )

+    def _grade_investigate(self, action: AuditAction) -> tuple[float, str]:
+        variable = action.variable or ""
        valid_vars = {
+            "age",
+            "gender",
+            "ethnicity",
+            "treatment_start",
+            "death_date",
+            "outcome",
+            "treatment_site",
+            "group",
+            "stage",
+            "trial_phase",
+            "drug",
+            "country",
+            "enrollment_date",
        }
        if variable not in valid_vars:
+            return REWARD_CONFIG["unknown_action"], f"REJECTED: Unknown variable '{variable}'."
        if variable in self._patterns_investigated:
            return (
                REWARD_CONFIG["investigate_redundant"],
+                f"Already investigated '{variable}'. Move to another variable or flag a finding.",
            )

        self._patterns_investigated.add(variable)

        if variable == "age":
+            ages = [patient["age"] for patient in self._dataset if patient.get("age") is not None]
+            null_count = sum(patient.get("age") is None for patient in self._dataset)
+            feedback = (
+                f"Age profile: min={min(ages) if ages else 'NA'}, max={max(ages) if ages else 'NA'}, "
+                f"null_count={null_count}, protocol_range={self._protocol['age_min']}-{self._protocol['age_max']}."
+            )
+        elif variable == "death_date":
+            non_null = [patient for patient in self._dataset if patient.get("death_date")]
+            feedback = f"death_date present for {len(non_null)} patients. Compare against treatment_start."
+        elif variable == "enrollment_date":
+            delays = [
+                (datetime.strptime(patient["treatment_start"], "%Y-%m-%d") - datetime.strptime(patient["enrollment_date"], "%Y-%m-%d")).days
+                for patient in self._dataset
+            ]
+            feedback = (
+                f"Enrollment-to-treatment delays: min={min(delays)}, max={max(delays)}, "
+                f"standard_window={self._protocol['treatment_window_days']} days."
+            )
+        elif variable == "stage":
+            counts = {stage: 0 for stage in ("I", "II", "III", "IV")}
+            for patient in self._dataset:
+                counts[patient["stage"]] = counts.get(patient["stage"], 0) + 1
+            feedback = f"Stage distribution: {counts}. Stage IV has a longer treatment-start window."
        else:
            counts = {}
+            for patient in self._dataset:
+                value = str(patient.get(variable, "None"))
+                counts[value] = counts.get(value, 0) + 1
+            top_counts = dict(sorted(counts.items(), key=lambda item: -item[1])[:8])
+            feedback = f"{variable} distribution snapshot: {top_counts}."
+
+        reward = REWARD_CONFIG["investigate_new"]
+        if variable in {"enrollment_date", "stage"}:
+            reward += REWARD_CONFIG["bonus_protocol_window"] / 2
+        return reward, f"Investigated '{variable}': {feedback}"
+
+    def _grade_distribution(self, action: AuditAction) -> tuple[float, str]:
        variable = action.variable or ""
        if not variable:
            return REWARD_CONFIG["unknown_action"], "REJECTED: Variable cannot be empty."
        if variable in self._distributions_computed:
            return (
                REWARD_CONFIG["distribution_redundant"],
+                f"Distribution for '{variable}' already computed.",
            )

        self._distributions_computed.add(variable)

+        control = [patient for patient in self._dataset if patient.get("group") == "control"]
        if variable == "ethnicity":
+            counts = {}
+            for patient in control:
+                counts[patient["ethnicity"]] = counts.get(patient["ethnicity"], 0) + 1
+            total = len(control) or 1
+            feedback = ", ".join(
+                f"{key}={value} ({(value / total) * 100:.0f}%)"
+                for key, value in sorted(counts.items(), key=lambda item: -item[1])
+            )
+            message = f"Control-arm ethnicity distribution: {feedback}."
        elif variable == "gender":
+            male = sum(patient.get("gender") == "M" for patient in control)
+            total = len(control) or 1
+            message = (
+                f"Control-arm gender mix: male={male}/{total} ({(male / total) * 100:.0f}%), "
+                f"female={total - male}/{total} ({((total - male) / total) * 100:.0f}%)."
+            )
+        elif variable == "outcome":
+            deceased = sum(patient.get("outcome") == "deceased" for patient in control)
+            total = len(control) or 1
+            message = (
+                f"Control-arm outcomes: deceased={deceased}/{total} ({(deceased / total) * 100:.0f}%), "
+                f"survived={total - deceased}/{total} ({((total - deceased) / total) * 100:.0f}%)."
+            )
        else:
+            message = f"Distribution computed for '{variable}'."

+        return REWARD_CONFIG["distribution_new"], message

+    def _grade_flag(self, action: AuditAction) -> tuple[float, str]:
        error_type = action.error_type or ""
+        if error_type not in self._current_task["allowed_error_types"]:
+            self._false_positive_flags += 1
+            return (
+                REWARD_CONFIG["false_positive"],
+                f"✗ Invalid error_type '{error_type}' for this task.",
+            )

        if error_type == "selection_bias":
            if not self._current_task["allow_bias"]:
+                self._false_positive_flags += 1
+                return REWARD_CONFIG["false_positive"], "✗ Bias review is not part of this task."
+            if not self._bias_review_ready():
+                self._invalid_phase_actions += 1
                return (
+                    REWARD_CONFIG["invalid_phase"],
+                    "PHASE BLOCKED: Compute ethnicity, gender, and outcome distributions before flagging bias.",
                )
            if "BIAS_FLAG" in self._flagged_patients:
+                self._duplicate_flags += 1
+                return REWARD_CONFIG["duplicate_flag"], "Bias already flagged."

+            signal = self._bias_signal()
+            if self._bias_present and signal["signal_present"]:
+                self._flagged_patients.add("BIAS_FLAG")
+                self._correct_flags += 1
                return (
+                    REWARD_CONFIG["bias_detected"],
+                    "✓ Correct. Control-arm bias confirmed: "
+                    f"{signal['dominant_ethnicity']}={signal['dominant_ratio']:.0%}, "
+                    f"male={signal['male_ratio']:.0%}, "
+                    f"stage-adjusted mortality gap={signal['stage_adjusted_gap']:.0%}.",
                )

+            self._false_positive_flags += 1
            return (
                REWARD_CONFIG["false_positive"],
+                "✗ False positive. Current data show either no actionable bias or only a confounded "
+                f"high-risk cohort at {signal['high_risk_note']}. "
+                f"Overall gap={signal['overall_gap']:.0%}, stage-adjusted gap={signal['stage_adjusted_gap']:.0%}.",
            )

+        patient_id = action.patient_id
+        if not patient_id:
+            self._false_positive_flags += 1
+            return REWARD_CONFIG["false_positive"], "REJECTED: patient_id is required for record-level errors."
        if patient_id in self._flagged_patients:
+            self._duplicate_flags += 1
+            return REWARD_CONFIG["duplicate_flag"], f"{patient_id} already flagged."

+        patient = next((row for row in self._dataset if row.get("patient_id") == patient_id), None)
+        if patient is None:
+            self._false_positive_flags += 1
+            return REWARD_CONFIG["false_positive"], f"REJECTED: Patient '{patient_id}' not found."

        expected_errors = self._ground_truth.get(patient_id, [])
+        if error_type in expected_errors:
+            self._flagged_patients.add(patient_id)
+            self._correct_flags += 1
+            if error_type == "invalid_age":
                return (
                    REWARD_CONFIG["correct_flag"],
+                    f"✓ Correct: {patient_id} age={patient.get('age')} violates protocol range "
+                    f"{self._protocol['age_min']}-{self._protocol['age_max']}.",
                )
+            if error_type == "temporal_inconsistency":
+                treatment_start = datetime.strptime(patient["treatment_start"], "%Y-%m-%d")
+                death_date = datetime.strptime(patient["death_date"], "%Y-%m-%d")
+                gap = (treatment_start - death_date).days
                return (
                    REWARD_CONFIG["correct_flag"],
+                    f"✓ Correct: {patient_id} death_date is {gap} days before treatment_start.",
+                )
+            if error_type == "protocol_window_violation":
+                enrollment = datetime.strptime(patient["enrollment_date"], "%Y-%m-%d")
+                treatment_start = datetime.strptime(patient["treatment_start"], "%Y-%m-%d")
+                delay = (treatment_start - enrollment).days
+                allowed = (
+                    self._protocol["stage_iv_treatment_window_days"]
+                    if patient["stage"] == "IV"
+                    else self._protocol["treatment_window_days"]
                )
                return (
+                    REWARD_CONFIG["correct_flag"] + REWARD_CONFIG["bonus_protocol_window"] / 2,
+                    f"✓ Correct: {patient_id} started treatment after {delay} days; protocol allows only {allowed}.",
                )

+        self._false_positive_flags += 1
+        if error_type == "invalid_age":
+            return (
+                REWARD_CONFIG["false_positive"],
+                f"✗ False positive: {patient_id} age={patient.get('age')} is valid for protocol range "
+                f"{self._protocol['age_min']}-{self._protocol['age_max']}.",
+            )
+        if error_type == "temporal_inconsistency":
+            return (
+                REWARD_CONFIG["false_positive"],
+                f"✗ False positive: {patient_id} has a valid death/treatment ordering.",
+            )
+        if error_type == "protocol_window_violation":
+            enrollment = datetime.strptime(patient["enrollment_date"], "%Y-%m-%d")
+            treatment_start = datetime.strptime(patient["treatment_start"], "%Y-%m-%d")
+            delay = (treatment_start - enrollment).days
+            allowed = (
+                self._protocol["stage_iv_treatment_window_days"]
+                if patient["stage"] == "IV"
+                else self._protocol["treatment_window_days"]
+            )
            return (
                REWARD_CONFIG["false_positive"],
+                f"✗ False positive: {patient_id} started treatment after {delay} days, which is valid for stage "
+                f"{patient['stage']} (allowed {allowed}).",
            )

+        return REWARD_CONFIG["false_positive"], f"✗ Invalid error_type '{error_type}'."
+
+    def _grade_propose_fix(self, action: AuditAction) -> tuple[float, str]:
        patient_id = action.patient_id or ""
        if patient_id not in self._flagged_patients:
+            return REWARD_CONFIG["propose_fix_invalid"], "Can only propose a fix for a flagged patient."
        proposed = action.proposed_value or ""
        if len(proposed) > 2:
+            return REWARD_CONFIG["propose_fix_valid"], f"Fix proposed for {patient_id}."
+        return REWARD_CONFIG["propose_fix_invalid"], "Proposed fix is too vague."

+    def _grade_report(self, action: AuditAction) -> tuple[float, str]:
        self._report_submitted = True
        report = (action.report or action.reason or "").lower()
+        quality = 0
        quality_items = []

+        if any(keyword in report for keyword in ["protocol", "eligibility", "inclusion", "excerpt"]):
+            quality += 1
+            quality_items.append("protocol grounding")
+        if any(keyword in report for keyword in ["root cause", "data entry", "pipeline", "system", "site process"]):
+            quality += 1
+            quality_items.append("root cause")
+        if any(keyword in report for keyword in ["recommend", "corrective", "action", "mitigation"]):
+            quality += 1
+            quality_items.append("corrective action")
+        if any(keyword in report for keyword in ["risk", "severity", "impact", "patient safety"]):
+            quality += 1
            quality_items.append("risk assessment")
+        if any(keyword in report for keyword in ["bias", "stage-adjusted", "fairness", "control arm", "equity"]):
+            quality += 1
+            quality_items.append("fairness reasoning")
+
+        self._report_quality = float(quality)
+        reward = REWARD_CONFIG["report_bonus_base"] + (0.015 * quality)
+        return reward, (
+            f"Report submitted. Quality {quality}/5"
+            + (f" ({', '.join(quality_items)})" if quality_items else "")
+            + "."
+        )
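The `_build_breakdown` method above combines recall, precision, workflow, efficiency, and report quality into a weighted benchmark score clipped to [0, 1]. A minimal sketch of that combination; the weight values below are placeholders, since the real `SCORE_WEIGHTS` constant is defined elsewhere in the repo and may differ:

```python
# Placeholder weights summing to 1.0; the actual SCORE_WEIGHTS may differ.
SCORE_WEIGHTS = {"recall": 0.35, "precision": 0.25, "workflow": 0.20,
                 "efficiency": 0.10, "report": 0.10}

def benchmark_score(recall, precision, workflow, efficiency, report):
    """Weighted sum of score components, clipped to [0, 1] and rounded."""
    components = {"recall": recall, "precision": precision, "workflow": workflow,
                  "efficiency": efficiency, "report": report}
    score = sum(SCORE_WEIGHTS[name] * value for name, value in components.items())
    return round(min(1.0, max(0.0, score)), 3)
```

With weights summing to 1.0, a perfect audit (all components at 1.0) scores 1.0 and an empty run scores 0.0; the clipping guarantees the published `[0.0, 1.0]` range regardless of the component values.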
server/dataset_generator.py
CHANGED
|
@@ -1,34 +1,24 @@
|
|
| 1 |
"""
|
| 2 |
Procedural Adversarial Clinical Trial Data Engine
|
| 3 |
-
=================================================
|
| 4 |
-
Generates
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
-
|
| 8 |
-
-
|
| 9 |
-
-
|
| 10 |
-
-
|
| 11 |
-
-
|
| 12 |
-
|
| 13 |
-
Architecture layers:
|
| 14 |
-
1. Base Patient Generator — realistic demographics via statistical distributions
|
| 15 |
-
2. Error Injector — controlled % of age/temporal/missing violations
|
| 16 |
-
3. Bias Injector — demographic skew + outcome disparity in control group
|
| 17 |
-
4. Trap Injector — boundary-valid, near-temporal, fake-pattern distractors
|
| 18 |
-
5. Ground Truth Tracker — records every injected error for deterministic grading
|
| 19 |
"""
|
| 20 |
|
| 21 |
-
import
|
| 22 |
-
|
| 23 |
import hashlib
|
|
|
|
| 24 |
from datetime import datetime, timedelta
|
| 25 |
from typing import Optional
|
| 26 |
|
| 27 |
|
| 28 |
-
# ═══════════════════════════════════════════════════════════════════════════
|
| 29 |
-
# REFERENCE DATA — Realistic clinical trial metadata pools
|
| 30 |
-
# ═══════════════════════════════════════════════════════════════════════════
|
| 31 |
-
|
| 32 |
HOSPITAL_SITES = [
|
| 33 |
("Metro General Hospital", "US"),
|
| 34 |
("Cleveland Oncology Institute", "US"),
|
|
@@ -37,8 +27,8 @@ HOSPITAL_SITES = [
|
|
| 37 |
("MD Anderson Cancer Center", "US"),
|
| 38 |
("AIIMS Delhi", "India"),
|
| 39 |
("Tata Memorial Hospital", "India"),
|
| 40 |
-
("
|
| 41 |
-
("Hospital
|
| 42 |
("Tokyo Medical University", "Japan"),
|
| 43 |
("Seoul National University Hospital", "South Korea"),
|
| 44 |
("Royal Marsden Hospital", "UK"),
|
|
@@ -47,98 +37,97 @@ HOSPITAL_SITES = [
|
|
| 47 |
("Peter MacCallum Cancer Centre", "Australia"),
|
| 48 |
]
|
| 49 |
|
| 50 |
-
# Sites considered "rural" or underrepresented for bias analysis
|
| 51 |
RURAL_SITES = {
|
| 52 |
-
"AIIMS Delhi",
|
| 53 |
"Howard University Hospital",
|
|
|
|
| 54 |
}
|
| 55 |
|
| 56 |
-
ETHNICITIES = [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
GENDERS = ["M", "F"]
|
| 58 |
STAGES = ["I", "II", "III", "IV"]
|
| 59 |
DRUGS_TREATMENT = ["ImmunoVax-7", "OncoShield-X", "TargetCure-3"]
|
| 60 |
-
DRUGS_CONTROL = ["Placebo"]
|
| 61 |
|
| 62 |
-
# Date range for the trial
|
| 63 |
TRIAL_START = datetime(2022, 6, 1)
|
| 64 |
TRIAL_END = datetime(2025, 3, 1)
|
| 65 |
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
|
| 70 |
DIFFICULTY_CONFIGS = {
|
| 71 |
"easy": {
|
| 72 |
"dataset_size": 300,
|
| 73 |
-
"age_error_rate": 0.
|
| 74 |
-
"
|
| 75 |
-
"
|
| 76 |
-
"
|
| 77 |
-
"num_boundary_traps":
|
| 78 |
"num_temporal_traps": 0,
|
| 79 |
-
"
|
|
|
|
| 80 |
"num_fake_bias_distractors": 0,
|
| 81 |
-
"
|
| 82 |
-
"control_ratio": 0.50,
|
| 83 |
-
"task_type": "
|
| 84 |
-
"allow_bias": False,
|
| 85 |
},
|
| 86 |
"medium": {
|
| 87 |
-
"dataset_size":
|
| 88 |
-
"age_error_rate": 0.
|
| 89 |
-
"
|
| 90 |
-
"
|
| 91 |
-
"
|
| 92 |
-
"num_boundary_traps":
|
| 93 |
-
"num_temporal_traps":
|
| 94 |
-
"
|
|
|
|
| 95 |
"num_fake_bias_distractors": 0,
|
| 96 |
-
"
|
| 97 |
"control_ratio": 0.50,
|
| 98 |
-
"task_type": "
|
| 99 |
-
"allow_bias": False,
|
| 100 |
},
|
| 101 |
"hard": {
|
| 102 |
-
"dataset_size":
|
| 103 |
-
"age_error_rate": 0.
|
| 104 |
-
"
|
| 105 |
-
"
|
| 106 |
-
"
|
| 107 |
-
"num_boundary_traps":
|
| 108 |
-
"num_temporal_traps":
|
|
|
|
| 109 |
"num_distractor_deceased": 8,
|
| 110 |
-
"num_fake_bias_distractors":
|
| 111 |
-
"
|
| 112 |
"control_ratio": 0.50,
|
| 113 |
-
"task_type": "
|
| 114 |
-
"allow_bias": True,
|
| 115 |
},
|
| 116 |
}
|
| 117 |
|
| 118 |
|
| 119 |
-
# ═══════════════════════════════════════════════════════════════════════════
|
| 120 |
-
# DATASET GENERATOR
|
| 121 |
-
# ═══════════════════════════════════════════════════════════════════════════
|
| 122 |
-
|
| 123 |
class DatasetGenerator:
|
| 124 |
-
"""
|
| 125 |
-
Procedural adversarial clinical trial data engine.
|
| 126 |
-
|
| 127 |
-
Generates statistically rigorous patient datasets with:
|
| 128 |
-
- Configurable size (300-1000+ patients)
|
| 129 |
-
- Controlled error injection (age, temporal, missing data)
|
| 130 |
-
- Controllable bias intensity (representation + outcome disparity)
|
| 131 |
-
- Adversarial traps (boundary-valid, near-temporal, fake patterns)
|
| 132 |
-
- Seed-based reproducibility (same seed → identical dataset)
|
| 133 |
-
|
| 134 |
-
Usage:
|
| 135 |
-
gen = DatasetGenerator(seed=42)
|
| 136 |
-
result = gen.generate(difficulty="hard")
|
| 137 |
-
dataset = result["dataset"] # List[dict] — patient records
|
| 138 |
-
ground_truth = result["ground_truth"] # Dict[str, List[str]] — {pid: [error_types]}
|
| 139 |
-
traps = result["traps"] # Set[str] — valid-but-suspicious pids
|
| 140 |
-
bias_present = result["bias_present"] # bool
|
| 141 |
-
"""
|
| 142 |
|
| 143 |
def __init__(self, seed: Optional[int] = None):
|
| 144 |
self.seed = seed
|
|
@@ -151,58 +140,122 @@ class DatasetGenerator:
        self._patient_counter += 1
        return f"P{self._patient_counter:04d}"

    def _random_date(self, start: datetime, end: datetime) -> datetime:
-        """Generate a random date between start and end."""
        delta = (end - start).days
        if delta <= 0:
            return start
        return start + timedelta(days=self.rng.randint(0, delta))

        while True:
-            age = int(self.rng.gauss(58,
            return age

    def _select_ethnicity(self, bias_mode: str = "neutral") -> str:
-        elif bias_mode == "diverse":
-            weights = [0.30, 0.20, 0.20, 0.15, 0.10, 0.05]
-        else:  # neutral — matches US clinical trial demographics
-            weights = [0.55, 0.15, 0.15, 0.10, 0.03, 0.02]
        return self.rng.choices(ETHNICITIES, weights=weights, k=1)[0]

        pid = self._next_pid()
        site, country = self.rng.choice(HOSPITAL_SITES)
-        treatment_start = enrollment_date + timedelta(days=self.rng.randint(7, 30))
-        if group == "treatment":
-            drug = self.rng.choice(DRUGS_TREATMENT)
-        else:
-            drug = "Placebo"
-        patient = {
            "patient_id": pid,
            "age": age,
-            "gender":
-            "ethnicity":
            "group": group,
            "treatment_start": treatment_start.strftime("%Y-%m-%d"),
            "death_date": None,
@@ -210,459 +263,422 @@ class DatasetGenerator:
            "treatment_site": site,
            "stage": stage,
            "trial_phase": "Phase III",
-            "drug":
            "enrollment_date": enrollment_date.strftime("%Y-%m-%d"),
            "country": country,
        }
        return patient

-        """Inject invalid age values into random patients."""
-        n_age_errors = max(3, int(len(patients) * error_rate))
-        n_missing = max(1, int(len(patients) * missing_rate))
        available = list(range(len(patients)))
        self.rng.shuffle(available)

-        for _ in range(n_age_errors):
-            error_kind = self.rng.choice([
-                "negative", "extreme_high", "sentinel", "just_over"
-            ])
-            if error_kind == "negative":
-                invalid_ages.append(self.rng.choice([-1, -5, -10, -3, -15]))
-            elif error_kind == "extreme_high":
-                invalid_ages.append(self.rng.choice([150, 200, 250, 300, 500]))
-            elif error_kind == "sentinel":
-                invalid_ages.append(self.rng.choice([999, 9999, 0, -999]))
-            elif error_kind == "just_over":
-                invalid_ages.append(self.rng.choice([121, 122, 125, 130, 17, 16, 15]))
-
-        for i, invalid_age in enumerate(invalid_ages):
-            if i >= len(available):
-                break
-            idx = available[i]
-            patients[idx]["age"] = invalid_age
-            pid = patients[idx]["patient_id"]
-            self._ground_truth.setdefault(pid, []).append("invalid_age")
-
-        # Missing age (None)
-        offset = len(invalid_ages)
-        for j in range(n_missing):
-            if offset + j >= len(available):
-                break
-            idx = available[offset + j]
-            patients[idx]["age"] = None
-            pid = patients[idx]["patient_id"]
-            self._ground_truth.setdefault(pid, []).append("invalid_age")

-        candidates = []
-        for i, p in enumerate(patients):
-            pid = p["patient_id"]
-            # Don't stack errors on patients already with age errors
-            if pid not in self._ground_truth:
-                candidates.append(i)
        self.rng.shuffle(candidates)

-            gap_days = self.rng.randint(15, 365)
-            death_date = treatment_start - timedelta(days=gap_days)

        return patients

-        4. Site bias: Minorities underrepresented at major sites
-        """
-        if intensity <= 0:
-            return patients
-
-        control_patients = [p for p in patients if p["group"] == "control"]
-        treatment_patients = [p for p in patients if p["group"] == "treatment"]
-
-        if not control_patients:
-            return patients
-
-        # ── Layer 1: Representation bias ──
-        # Force >75% of control to be White
-        target_white_ratio = 0.75 + (intensity * 0.10)  # 0.75-0.85
-        n_control = len(control_patients)
-        n_white_target = int(n_control * target_white_ratio)
-        n_white_current = sum(1 for p in control_patients if p["ethnicity"] == "White")
-
-        # Convert some non-White control patients to White
-        non_white_control = [p for p in control_patients if p["ethnicity"] != "White"]
-        to_convert = max(0, n_white_target - n_white_current)
-        self.rng.shuffle(non_white_control)
-        for i in range(min(to_convert, len(non_white_control))):
-            non_white_control[i]["ethnicity"] = "White"
-
-        # ── Layer 2: Gender imbalance in control ──
-        # Force >65% male in control
-        target_male_ratio = 0.65 + (intensity * 0.10)
-        n_male_target = int(n_control * target_male_ratio)
-        n_male_current = sum(1 for p in control_patients if p["gender"] == "M")
-        female_control = [p for p in control_patients if p["gender"] == "F"]
-        to_convert_gender = max(0, n_male_target - n_male_current)
-        self.rng.shuffle(female_control)
-        for i in range(min(to_convert_gender, len(female_control))):
-            female_control[i]["gender"] = "M"
-
-        # ── Layer 3: Outcome disparity ──
-        # Minority patients in control → higher mortality (>60%)
-        minority_control = [
-            p for p in control_patients
-            if p["ethnicity"] != "White" and p["patient_id"] not in self._ground_truth
        ]
-            if i < n_minority_dead:
-                if p["outcome"] != "deceased":
-                    treatment_start = datetime.strptime(p["treatment_start"], "%Y-%m-%d")
-                    death_date = treatment_start + timedelta(
-                        days=self.rng.randint(30, 365)
-                    )
-                    p["death_date"] = death_date.strftime("%Y-%m-%d")
-                    p["outcome"] = "deceased"
-
-        # ── Layer 4: White control patients → low mortality ──
-        white_control = [
-            p for p in control_patients
-            if p["ethnicity"] == "White" and p["patient_id"] not in self._ground_truth
        ]
-            if i < n_white_alive:
-                p["death_date"] = None
-                p["outcome"] = "survived"
-
-        # ── Layer 5: Rural minority underrepresentation ──
-        for p in minority_control:
-            if p["treatment_site"] in RURAL_SITES:
-                # Move some to major sites (reducing rural minority visibility)
-                if self.rng.random() < intensity * 0.5:
-                    major_sites = [
-                        s for s in HOSPITAL_SITES
-                        if s[0] not in RURAL_SITES
-                    ]
-                    new_site = self.rng.choice(major_sites)
-                    p["treatment_site"] = new_site[0]
-                    p["country"] = new_site[1]
        return patients

-    def _inject_boundary_traps(self, patients: list[dict], n_traps: int) -> list[dict]:
-        """
-        Inject boundary-valid ages that trap naive agents.
-        Ages like 18, 19, 120 are VALID but suspicious.
-        """
-        boundary_ages = [18, 19, 20, 90, 92, 95, 96, 100, 105, 110, 115, 118, 119, 120, 120]
-        self.rng.shuffle(boundary_ages)  # Randomize which traps appear
        available = [
            if p["patient_id"] not in self._ground_truth
-            and p["
        ]
        self.rng.shuffle(available)

        return patients

-        """
-        Inject near-temporal valid cases: death 1-3 days AFTER treatment start.
-        These are VALID but look like errors to careless agents.
-        """
        available = [
            if p["patient_id"] not in self._ground_truth
-            and p["death_date"] is None
            and p["patient_id"] not in self._traps
        ]
        self.rng.shuffle(available)
-            p = patients[idx]
-            treatment_start = datetime.strptime(p["treatment_start"], "%Y-%m-%d")
-            # Death 1-3 days AFTER treatment — valid but suspicious
-            gap = self.rng.randint(1, 3)
-            death_date = treatment_start + timedelta(days=gap)
-            p["death_date"] = death_date.strftime("%Y-%m-%d")
-            p["outcome"] = "deceased"
-            p["stage"] = "IV"  # Make it medically plausible (late-stage)
-            self._traps.add(p["patient_id"])
        return patients

-    def _inject_fake_bias_distractors(self, patients: list[dict],
-        E.g., treatment group with demographic skew (doesn't matter for bias detection
-        since only control group bias is relevant).
-        """
-        treatment_patients = [
-            i for i, p in enumerate(patients)
            if p["group"] == "treatment"
            and p["patient_id"] not in self._ground_truth
            and p["patient_id"] not in self._traps
        ]
-        self.rng.shuffle(
-        self._traps.add(
        return patients

-        """
-        available = [
-            i for i, p in enumerate(patients)
-            if p["patient_id"] not in self._ground_truth
-            and p["death_date"] is None
-            and p["patient_id"] not in self._traps
        ]

    def generate(self, difficulty: str = "easy") -> dict:
-        """
-        Generate a complete adversarial dataset for the given difficulty.
-
-        Returns:
-            {
-                "dataset": List[dict],                 # Patient records
-                "ground_truth": Dict[str, List[str]],  # {pid: [error_types]}
-                "traps": Set[str],                     # Valid-but-suspicious pids
-                "bias_present": bool,                  # Whether bias was injected
-                "config": dict,                        # Generation parameters
-                "stats": dict,                         # Summary statistics
-            }
-        """
        config = DIFFICULTY_CONFIGS.get(difficulty, DIFFICULTY_CONFIGS["easy"])
        self._ground_truth = {}
        self._traps = set()
        self._patient_counter = 0

-        # ── Step 1: Generate base patients ──
        patients = []
-        # Determine bias mode for control group
-        control_bias_mode = "white_dominant" if config["bias_intensity"] > 0 else "neutral"
        for _ in range(n_control):
-            patients.append(p)

        for _ in range(n_treatment):
-            patients.append(p)
-
-        # ── Step 2: Inject errors ──
-        patients = self._inject_age_errors(
-            patients, config["age_error_rate"], config["missing_data_rate"]
-        )
        if config["temporal_error_rate"] > 0:
-            patients = self._inject_temporal_errors(
-            )
-
-        # ── Step 3: Inject bias (hard only) ──
-        if config["bias_intensity"] > 0:
-            patients = self._inject_bias(patients, config["bias_intensity"])

        if config["num_temporal_traps"] > 0:
-            patients = self._inject_temporal_traps(
-            )
        if config["num_fake_bias_distractors"] > 0:
-            patients = self._inject_fake_bias_distractors(
-                patients, config["num_fake_bias_distractors"]
-            )

-        patients = self._inject_distractor_deceased(
-            patients, config["num_distractor_deceased"]
-        )
-
-        # ── Step 5: Shuffle dataset ──
        self.rng.shuffle(patients)

-        # ── Step 6: Compute summary stats ──
-        n_age_errors = sum(
-            1 for errs in self._ground_truth.values()
-            if "invalid_age" in errs
-        )
-        n_temporal_errors = sum(
-            1 for errs in self._ground_truth.values()
-            if "temporal_inconsistency" in errs
-        )
-        total_errors = n_age_errors + n_temporal_errors
-        if config["bias_intensity"] > 0:
-            total_errors += 1  # bias counts as 1 error

        stats = {
            "total_patients": len(patients),
-            "bias_present":
            "num_traps": len(self._traps),
            "control_count": sum(1 for p in patients if p["group"] == "control"),
            "treatment_count": sum(1 for p in patients if p["group"] == "treatment"),
        }

        return {
            "dataset": patients,
            "ground_truth": dict(self._ground_truth),
            "traps": set(self._traps),
-            "bias_present":
            "config": config,
            "stats": stats,
        }


-# ═══════════════════════════════════════════════════════════════════════════
-# STANDALONE TEST
-# ═══════════════════════════════════════════════════════════════════════════
-
if __name__ == "__main__":
    print("=" * 60)
    print(" Dataset Generator — Validation Test")
    print("=" * 60)

    stats = result["stats"]
    print(f"  Patients: {stats['total_patients']}")
    print(f"  Traps: {stats['num_traps']}")
    print(f"  Control: {stats['control_count']}")
    print(f"  Treatment: {stats['treatment_count']}")

-        if err == "invalid_age":
            age = patient.get("age")
-            assert age is None or age <
-                f"Ground truth says {
    print(f"\n{'=' * 60}")
    print(f"{'=' * 60}")
 """
 Procedural Adversarial Clinical Trial Data Engine
+=================================================
+Generates seeded, protocol-driven clinical trial datasets for OpenEnv episodes.
+
+This generator is intentionally benchmark-oriented:
+- each episode samples a different protocol excerpt and hidden rule set
+- age eligibility is protocol-specific, not a fixed 18-120 shortcut
+- treatment scheduling uses stage-aware exceptions to create valid edge cases
+- hard episodes alternate between true bias and confounded "looks bad" cohorts
+- all labels remain deterministic and reproducible from the seed
 """

+from __future__ import annotations
+
 import hashlib
+import random
 from datetime import datetime, timedelta
 from typing import Optional

 HOSPITAL_SITES = [
     ("Metro General Hospital", "US"),
     ("Cleveland Oncology Institute", "US"),
     ("MD Anderson Cancer Center", "US"),
     ("AIIMS Delhi", "India"),
     ("Tata Memorial Hospital", "India"),
+    ("Charite Berlin", "Germany"),
+    ("Hospital Clinic Barcelona", "Spain"),
     ("Tokyo Medical University", "Japan"),
     ("Seoul National University Hospital", "South Korea"),
     ("Royal Marsden Hospital", "UK"),
     ("Peter MacCallum Cancer Centre", "Australia"),
 ]

 RURAL_SITES = {
+    "AIIMS Delhi",
     "Howard University Hospital",
+    "Tata Memorial Hospital",
 }

+ETHNICITIES = [
+    "White",
+    "Black",
+    "Hispanic",
+    "Asian",
+    "Native American",
+    "Pacific Islander",
+]
 GENDERS = ["M", "F"]
 STAGES = ["I", "II", "III", "IV"]
 DRUGS_TREATMENT = ["ImmunoVax-7", "OncoShield-X", "TargetCure-3"]

 TRIAL_START = datetime(2022, 6, 1)
 TRIAL_END = datetime(2025, 3, 1)

+BASE_STAGE_MORTALITY = {
+    "I": 0.04,
+    "II": 0.08,
+    "III": 0.16,
+    "IV": 0.32,
+}
+
+AGE_RULESETS = {
+    "easy": [(35, 75), (40, 80), (45, 85)],
+    "medium": [(18, 75), (21, 80), (30, 85), (40, 90)],
+    "hard": [(18, 75), (21, 80), (30, 85), (35, 85), (40, 90)],
+}
+
+WINDOW_RULESETS = {
+    "easy": [21, 24, 28],
+    "medium": [18, 21, 24, 28],
+    "hard": [14, 18, 21, 24],
+}

 DIFFICULTY_CONFIGS = {
     "easy": {
         "dataset_size": 300,
+        "age_error_rate": 0.020,
+        "missing_age_rate": 0.007,
+        "temporal_error_rate": 0.0,
+        "protocol_window_error_rate": 0.0,
+        "num_boundary_traps": 8,
         "num_temporal_traps": 0,
+        "num_window_traps": 0,
+        "num_distractor_deceased": 5,
         "num_fake_bias_distractors": 0,
+        "bias_probability": 0.0,
+        "control_ratio": 0.50,
+        "task_type": "eligibility_screening",
     },
     "medium": {
+        "dataset_size": 480,
+        "age_error_rate": 0.018,
+        "missing_age_rate": 0.007,
+        "temporal_error_rate": 0.012,
+        "protocol_window_error_rate": 0.015,
+        "num_boundary_traps": 10,
+        "num_temporal_traps": 4,
+        "num_window_traps": 5,
+        "num_distractor_deceased": 6,
         "num_fake_bias_distractors": 0,
+        "bias_probability": 0.0,
         "control_ratio": 0.50,
+        "task_type": "protocol_timeline_audit",
     },
     "hard": {
+        "dataset_size": 720,
+        "age_error_rate": 0.017,
+        "missing_age_rate": 0.006,
+        "temporal_error_rate": 0.010,
+        "protocol_window_error_rate": 0.014,
+        "num_boundary_traps": 12,
+        "num_temporal_traps": 5,
+        "num_window_traps": 7,
         "num_distractor_deceased": 8,
+        "num_fake_bias_distractors": 8,
+        "bias_probability": 0.58,
         "control_ratio": 0.50,
+        "task_type": "equity_and_protocol_audit",
     },
 }


 class DatasetGenerator:
+    """Seeded benchmark dataset generator."""

     def __init__(self, seed: Optional[int] = None):
         self.seed = seed
         self._patient_counter += 1
         return f"P{self._patient_counter:04d}"

+    def _mark_error(self, patient_id: str, error_type: str) -> None:
+        self._ground_truth.setdefault(patient_id, []).append(error_type)
+
     def _random_date(self, start: datetime, end: datetime) -> datetime:
         delta = (end - start).days
         if delta <= 0:
             return start
         return start + timedelta(days=self.rng.randint(0, delta))

+    def _build_protocol(self, difficulty: str, config: dict) -> dict:
+        age_min, age_max = self.rng.choice(AGE_RULESETS[difficulty])
+        treatment_window = self.rng.choice(WINDOW_RULESETS[difficulty])
+        stage_iv_window = treatment_window + self.rng.choice([7, 10, 14])
+        high_risk_sites = self.rng.sample(sorted(RURAL_SITES), k=2 if difficulty == "hard" else 1)
+        dominant_threshold = self.rng.choice([0.68, 0.70, 0.72]) if difficulty == "hard" else 0.0
+        male_threshold = self.rng.choice([0.56, 0.60, 0.63]) if difficulty == "hard" else 0.0
+        adjusted_gap = self.rng.choice([0.12, 0.15, 0.18]) if difficulty == "hard" else 0.0
+        bias_present = difficulty == "hard" and self.rng.random() < config["bias_probability"]
+        protocol_key = (
+            f"{difficulty}|{age_min}|{age_max}|{treatment_window}|"
+            f"{stage_iv_window}|{'/'.join(high_risk_sites)}|{bias_present}"
+        )
+        protocol_id = hashlib.sha1(protocol_key.encode("utf-8")).hexdigest()[:8].upper()
+        protocol_title = f"ONCO-AX-{difficulty[:1].upper()}{protocol_id}"
+
+        lines = [
+            f"TRIAL PROTOCOL EXCERPT — {protocol_title}",
+            "",
+            "Eligibility",
+            f"- Participants must be age {age_min}-{age_max} inclusive on enrollment.",
+            "- Missing age is a protocol violation.",
+            "",
+            "Treatment Scheduling",
+            f"- Treatment must begin within {treatment_window} days of enrollment.",
+            (
+                f"- Stage IV exception: treatment may begin within "
+                f"{stage_iv_window} days of enrollment."
+            ),
+            "- death_date must never precede treatment_start.",
+            "- Do not assume a generic 18-120 range; this excerpt overrides defaults.",
+        ]
+
+        if difficulty == "hard":
+            lines.extend(
+                [
+                    "",
+                    "Equity Review",
+                    "- Selection bias concerns control-arm composition, not treatment-arm skew.",
+                    "- Compare mortality within stage strata before escalating a bias concern.",
+                    (
+                        f"- Stage IV patients at {', '.join(high_risk_sites)} are a known "
+                        "high-risk outreach cohort; elevated mortality there is not standalone "
+                        "evidence of unfair sampling."
+                    ),
+                    (
+                        f"- Escalate bias only when control-arm dominance exceeds "
+                        f"{int(dominant_threshold * 100)}%, male share exceeds "
+                        f"{int(male_threshold * 100)}%, and stage-adjusted mortality gap "
+                        f"exceeds {int(adjusted_gap * 100)} percentage points."
+                    ),
+                ]
+            )
+
+        return {
+            "protocol_id": protocol_id,
+            "protocol_title": protocol_title,
+            "excerpt": "\n".join(lines),
+            "age_min": age_min,
+            "age_max": age_max,
+            "treatment_window_days": treatment_window,
+            "stage_iv_treatment_window_days": stage_iv_window,
+            "high_risk_sites": high_risk_sites,
+            "bias_control_dominance_threshold": dominant_threshold,
+            "bias_male_threshold": male_threshold,
+            "bias_stage_adjusted_gap": adjusted_gap,
+            "bias_present": bias_present,
+        }
+
+    def _generate_age(self, protocol: dict) -> int:
         while True:
+            age = int(self.rng.gauss(58, 11))
+            if protocol["age_min"] <= age <= protocol["age_max"]:
                 return age

     def _select_ethnicity(self, bias_mode: str = "neutral") -> str:
+        if bias_mode == "diverse":
+            weights = [0.28, 0.19, 0.20, 0.18, 0.10, 0.05]
+        elif bias_mode == "white_dominant":
+            weights = [0.68, 0.08, 0.08, 0.08, 0.05, 0.03]
+        else:
+            weights = [0.50, 0.16, 0.15, 0.12, 0.04, 0.03]
         return self.rng.choices(ETHNICITIES, weights=weights, k=1)[0]

+    def _base_delay(self, stage: str, protocol: dict) -> int:
+        max_window = (
+            protocol["stage_iv_treatment_window_days"]
+            if stage == "IV"
+            else protocol["treatment_window_days"]
+        )
+        lower = 5 if max_window >= 10 else 1
+        upper = max(lower, max_window - 2)
+        return self.rng.randint(lower, upper)
+
+    def _generate_base_patient(self, group: str, protocol: dict, bias_mode: str = "neutral") -> dict:
         pid = self._next_pid()
         site, country = self.rng.choice(HOSPITAL_SITES)
+        stage = self.rng.choices(STAGES, weights=[0.24, 0.28, 0.28, 0.20], k=1)[0]
+        age = self._generate_age(protocol)
+        enrollment_end = TRIAL_END - timedelta(days=150)
+        enrollment_date = self._random_date(TRIAL_START, enrollment_end)
+        treatment_start = enrollment_date + timedelta(days=self._base_delay(stage, protocol))
+        return {
             "patient_id": pid,
             "age": age,
+            "gender": self.rng.choice(GENDERS),
+            "ethnicity": self._select_ethnicity(bias_mode),
             "group": group,
             "treatment_start": treatment_start.strftime("%Y-%m-%d"),
             "death_date": None,
             "treatment_site": site,
             "stage": stage,
             "trial_phase": "Phase III",
+            "drug": self.rng.choice(DRUGS_TREATMENT) if group == "treatment" else "Placebo",
             "enrollment_date": enrollment_date.strftime("%Y-%m-%d"),
             "country": country,
         }

+    def _mortality_rate(self, patient: dict, protocol: dict) -> float:
+        rate = BASE_STAGE_MORTALITY.get(patient["stage"], 0.10)
+        if patient["treatment_site"] in protocol["high_risk_sites"] and patient["stage"] == "IV":
+            rate += 0.16
+        if patient["group"] == "treatment":
+            rate *= 0.92
+        return max(0.02, min(0.82, rate))
+
+    def _set_deceased(self, patient: dict, min_days: int, max_days: int) -> None:
+        treatment_start = datetime.strptime(patient["treatment_start"], "%Y-%m-%d")
+        days_to_death = self.rng.randint(min_days, max_days)
+        death_date = treatment_start + timedelta(days=days_to_death)
+        patient["death_date"] = death_date.strftime("%Y-%m-%d")
+        patient["outcome"] = "deceased"
+
+    def _apply_mortality(self, patient: dict, protocol: dict) -> dict:
+        if self.rng.random() < self._mortality_rate(patient, protocol):
+            self._set_deceased(patient, min_days=3, max_days=540)
         return patient

+    def _apply_target_mortality(self, cohort: list[dict], target_rate: float) -> None:
+        if not cohort:
+            return
+        self.rng.shuffle(cohort)
+        target_count = int(round(len(cohort) * max(0.0, min(1.0, target_rate))))
+        for index, patient in enumerate(cohort):
+            if index < target_count:
+                self._set_deceased(patient, min_days=10, max_days=420)
+            else:
+                patient["death_date"] = None
+                patient["outcome"] = "survived"
+
+    def _allowed_treatment_window(self, patient: dict, protocol: dict) -> int:
+        return (
+            protocol["stage_iv_treatment_window_days"]
+            if patient.get("stage") == "IV"
+            else protocol["treatment_window_days"]
+        )

+    def _enrollment_date(self, patient: dict) -> datetime:
+        return datetime.strptime(patient["enrollment_date"], "%Y-%m-%d")

+    def _treatment_date(self, patient: dict) -> datetime:
+        return datetime.strptime(patient["treatment_start"], "%Y-%m-%d")

+    def _inject_age_errors(self, patients: list[dict], protocol: dict, config: dict) -> list[dict]:
+        n_invalid = max(3, int(len(patients) * config["age_error_rate"]))
+        n_missing = max(1, int(len(patients) * config["missing_age_rate"]))
         available = list(range(len(patients)))
         self.rng.shuffle(available)

+        low_values = [protocol["age_min"] - 1, protocol["age_min"] - 2, max(0, protocol["age_min"] - 5), -1]
+        high_values = [protocol["age_max"] + 1, protocol["age_max"] + 2, protocol["age_max"] + 5, 999]

+        for offset in range(min(n_invalid, len(available))):
+            patient = patients[available[offset]]
+            patient["age"] = self.rng.choice(low_values + high_values)
+            self._mark_error(patient["patient_id"], "invalid_age")

+        start = min(n_invalid, len(available))
+        for offset in range(start, min(start + n_missing, len(available))):
+            patient = patients[available[offset]]
+            patient["age"] = None
+            self._mark_error(patient["patient_id"], "invalid_age")

+        return patients

+    def _inject_temporal_errors(self, patients: list[dict], config: dict) -> list[dict]:
+        n_errors = max(3, int(len(patients) * config["temporal_error_rate"]))
+        candidates = [p for p in patients if p["patient_id"] not in self._ground_truth]
         self.rng.shuffle(candidates)

+        for patient in candidates[:n_errors]:
+            treatment_start = self._treatment_date(patient)
+            death_date = treatment_start - timedelta(days=self.rng.randint(10, 240))
+            patient["death_date"] = death_date.strftime("%Y-%m-%d")
+            patient["outcome"] = "deceased"
+            self._mark_error(patient["patient_id"], "temporal_inconsistency")

+        return patients

+    def _inject_protocol_window_errors(
+        self,
+        patients: list[dict],
+        protocol: dict,
+        config: dict,
+    ) -> list[dict]:
+        n_errors = max(3, int(len(patients) * config["protocol_window_error_rate"]))
+        candidates = [p for p in patients if p["patient_id"] not in self._ground_truth]
+        self.rng.shuffle(candidates)

+        for patient in candidates[:n_errors]:
+            allowed_days = self._allowed_treatment_window(patient, protocol)
+            enrollment = self._enrollment_date(patient)
+            violation_days = allowed_days + self.rng.randint(2, 18)
+            patient["treatment_start"] = (enrollment + timedelta(days=violation_days)).strftime("%Y-%m-%d")
+            if patient["death_date"]:
+                death_date = datetime.strptime(patient["death_date"], "%Y-%m-%d")
+                treatment_start = self._treatment_date(patient)
+                if death_date <= treatment_start:
+                    self._set_deceased(patient, min_days=20, max_days=320)
+            self._mark_error(patient["patient_id"], "protocol_window_violation")

         return patients

+    def _inject_boundary_traps(self, patients: list[dict], protocol: dict, n_traps: int) -> list[dict]:
+        valid_ages = [
+            protocol["age_min"],
+            protocol["age_min"] + 1,
+            protocol["age_min"] + 2,
+            protocol["age_max"] - 2,
+            protocol["age_max"] - 1,
|
| 383 |
+
protocol["age_max"],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 384 |
]
|
| 385 |
+
available = [
|
| 386 |
+
p
|
| 387 |
+
for p in patients
|
| 388 |
+
if p["patient_id"] not in self._ground_truth and p["age"] is not None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 389 |
]
|
| 390 |
+
self.rng.shuffle(available)
|
| 391 |
+
for patient, age in zip(available[:n_traps], valid_ages * max(1, n_traps)):
|
| 392 |
+
patient["age"] = age
|
| 393 |
+
self._traps.add(patient["patient_id"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 394 |
return patients
|
| 395 |
|
| 396 |
+
def _inject_temporal_traps(self, patients: list[dict], n_traps: int) -> list[dict]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 397 |
available = [
|
| 398 |
+
p
|
| 399 |
+
for p in patients
|
| 400 |
if p["patient_id"] not in self._ground_truth
|
| 401 |
+
and p["patient_id"] not in self._traps
|
| 402 |
+
and p["death_date"] is None
|
| 403 |
]
|
| 404 |
self.rng.shuffle(available)
|
| 405 |
+
for patient in available[:n_traps]:
|
| 406 |
+
patient["stage"] = "IV"
|
| 407 |
+
self._set_deceased(patient, min_days=1, max_days=3)
|
| 408 |
+
self._traps.add(patient["patient_id"])
|
| 409 |
+
return patients
|
| 410 |
|
| 411 |
+
def _inject_window_traps(self, patients: list[dict], protocol: dict, n_traps: int) -> list[dict]:
|
| 412 |
+
available = [
|
| 413 |
+
p
|
| 414 |
+
for p in patients
|
| 415 |
+
if p["patient_id"] not in self._ground_truth and p["patient_id"] not in self._traps
|
| 416 |
+
]
|
| 417 |
+
self.rng.shuffle(available)
|
| 418 |
+
for patient in available[:n_traps]:
|
| 419 |
+
enrollment = self._enrollment_date(patient)
|
| 420 |
+
if self.rng.random() < 0.55:
|
| 421 |
+
patient["stage"] = "IV"
|
| 422 |
+
allowed_days = self._allowed_treatment_window(patient, protocol)
|
| 423 |
+
trap_delay = max(1, allowed_days - self.rng.choice([0, 1]))
|
| 424 |
+
patient["treatment_start"] = (enrollment + timedelta(days=trap_delay)).strftime("%Y-%m-%d")
|
| 425 |
+
if patient["death_date"]:
|
| 426 |
+
death_date = datetime.strptime(patient["death_date"], "%Y-%m-%d")
|
| 427 |
+
if death_date <= self._treatment_date(patient):
|
| 428 |
+
self._set_deceased(patient, min_days=12, max_days=240)
|
| 429 |
+
self._traps.add(patient["patient_id"])
|
| 430 |
return patients
|
 
+    def _inject_distractor_deceased(self, patients: list[dict], n_distractors: int) -> list[dict]:
         available = [
+            p
+            for p in patients
             if p["patient_id"] not in self._ground_truth
             and p["patient_id"] not in self._traps
+            and p["death_date"] is None
         ]
         self.rng.shuffle(available)
+        for patient in available[:n_distractors]:
+            self._set_deceased(patient, min_days=30, max_days=520)
+            self._traps.add(patient["patient_id"])
         return patients
 
+    def _inject_fake_bias_distractors(self, patients: list[dict], n_distractors: int) -> list[dict]:
+        treatment_group = [
+            p
+            for p in patients
             if p["group"] == "treatment"
             and p["patient_id"] not in self._ground_truth
             and p["patient_id"] not in self._traps
         ]
+        self.rng.shuffle(treatment_group)
+        for patient in treatment_group[:n_distractors]:
+            patient["ethnicity"] = "White"
+            patient["gender"] = "M"
+            if self.rng.random() < 0.5:
+                patient["stage"] = "IV"
+                self._set_deceased(patient, min_days=15, max_days=180)
+            self._traps.add(patient["patient_id"])
         return patients
 
+    def _inject_selection_bias(self, patients: list[dict], protocol: dict) -> None:
+        control = [
+            p
+            for p in patients
+            if p["group"] == "control" and p["patient_id"] not in self._ground_truth
         ]
+        if not control:
+            return
+
+        target_dom_ratio = protocol["bias_control_dominance_threshold"] + self.rng.uniform(0.06, 0.12)
+        dominant_target = int(len(control) * min(0.86, target_dom_ratio))
+        white_control = [p for p in control if p["ethnicity"] == "White"]
+        non_white_control = [p for p in control if p["ethnicity"] != "White"]
+        needed = max(0, dominant_target - len(white_control))
+        self.rng.shuffle(non_white_control)
+        for patient in non_white_control[:needed]:
+            patient["ethnicity"] = "White"
+
+        target_male_ratio = protocol["bias_male_threshold"] + self.rng.uniform(0.05, 0.10)
+        male_target = int(len(control) * min(0.82, target_male_ratio))
+        male_control = [p for p in control if p["gender"] == "M"]
+        female_control = [p for p in control if p["gender"] == "F"]
+        needed_male = max(0, male_target - len(male_control))
+        self.rng.shuffle(female_control)
+        for patient in female_control[:needed_male]:
+            patient["gender"] = "M"
+
+        dominant = [p for p in control if p["ethnicity"] == "White"]
+        minority = [p for p in control if p["ethnicity"] != "White"]
+        for stage in STAGES:
+            stage_majority = [p for p in dominant if p["stage"] == stage]
+            stage_minority = [p for p in minority if p["stage"] == stage]
+            if not stage_majority or not stage_minority:
+                continue
+            base = BASE_STAGE_MORTALITY[stage]
+            self._apply_target_mortality(stage_majority, max(0.02, base - 0.03))
+            self._apply_target_mortality(stage_minority, min(0.82, base + 0.18))
+
+    def _inject_confounder_cohort(self, patients: list[dict], protocol: dict) -> None:
+        control = [
+            p
+            for p in patients
+            if p["group"] == "control" and p["patient_id"] not in self._ground_truth
+        ]
+        if not control:
+            return
+
+        minority = [p for p in control if p["ethnicity"] != "White"]
+        white = [p for p in control if p["ethnicity"] == "White"]
+        self.rng.shuffle(minority)
+        self.rng.shuffle(white)
+
+        minority_shift = max(8, len(control) // 18)
+        white_shift = max(4, len(control) // 30)
+
+        for patient in minority[:minority_shift]:
+            patient["stage"] = "IV"
+            patient["treatment_site"] = self.rng.choice(protocol["high_risk_sites"])
+            patient["country"] = next(
+                country for site, country in HOSPITAL_SITES if site == patient["treatment_site"]
+            )
 
+        for patient in white[:white_shift]:
+            patient["stage"] = "IV"
+            patient["treatment_site"] = self.rng.choice(protocol["high_risk_sites"])
+            patient["country"] = next(
+                country for site, country in HOSPITAL_SITES if site == patient["treatment_site"]
+            )
 
+        stage_iv_control = [p for p in control if p["stage"] == "IV"]
+        stage_iv_minority = [p for p in stage_iv_control if p["ethnicity"] != "White"]
+        stage_iv_white = [p for p in stage_iv_control if p["ethnicity"] == "White"]
+        self._apply_target_mortality(stage_iv_minority, 0.66)
+        self._apply_target_mortality(stage_iv_white, 0.63)
 
     def generate(self, difficulty: str = "easy") -> dict:
         config = DIFFICULTY_CONFIGS.get(difficulty, DIFFICULTY_CONFIGS["easy"])
         self._ground_truth = {}
         self._traps = set()
         self._patient_counter = 0
 
+        protocol = self._build_protocol(difficulty, config)
+        n_patients = config["dataset_size"]
+        n_control = int(n_patients * config["control_ratio"])
+        n_treatment = n_patients - n_control
 
         patients = []
         for _ in range(n_control):
+            patient = self._generate_base_patient("control", protocol, bias_mode="neutral")
+            patients.append(self._apply_mortality(patient, protocol))
 
         for _ in range(n_treatment):
+            patient = self._generate_base_patient("treatment", protocol, bias_mode="diverse")
+            patients.append(self._apply_mortality(patient, protocol))
 
+        patients = self._inject_age_errors(patients, protocol, config)
         if config["temporal_error_rate"] > 0:
+            patients = self._inject_temporal_errors(patients, config)
+        if config["protocol_window_error_rate"] > 0:
+            patients = self._inject_protocol_window_errors(patients, protocol, config)
 
+        if difficulty == "hard":
+            if protocol["bias_present"]:
+                self._inject_selection_bias(patients, protocol)
+            else:
+                self._inject_confounder_cohort(patients, protocol)
 
+        patients = self._inject_boundary_traps(patients, protocol, config["num_boundary_traps"])
         if config["num_temporal_traps"] > 0:
+            patients = self._inject_temporal_traps(patients, config["num_temporal_traps"])
+        if config["num_window_traps"] > 0:
+            patients = self._inject_window_traps(patients, protocol, config["num_window_traps"])
+        patients = self._inject_distractor_deceased(patients, config["num_distractor_deceased"])
         if config["num_fake_bias_distractors"] > 0:
+            patients = self._inject_fake_bias_distractors(patients, config["num_fake_bias_distractors"])
 
         self.rng.shuffle(patients)
 
         stats = {
             "total_patients": len(patients),
+            "age_errors": sum("invalid_age" in errs for errs in self._ground_truth.values()),
+            "temporal_errors": sum("temporal_inconsistency" in errs for errs in self._ground_truth.values()),
+            "protocol_window_errors": sum("protocol_window_violation" in errs for errs in self._ground_truth.values()),
+            "bias_present": protocol["bias_present"],
+            "bias_mode": "true_bias" if protocol["bias_present"] else ("confounded_no_bias" if difficulty == "hard" else "none"),
             "num_traps": len(self._traps),
             "control_count": sum(1 for p in patients if p["group"] == "control"),
             "treatment_count": sum(1 for p in patients if p["group"] == "treatment"),
+            "protocol_title": protocol["protocol_title"],
         }
+        stats["total_errors"] = (
+            stats["age_errors"]
+            + stats["temporal_errors"]
+            + stats["protocol_window_errors"]
+            + (1 if protocol["bias_present"] else 0)
+        )
 
         return {
             "dataset": patients,
             "ground_truth": dict(self._ground_truth),
             "traps": set(self._traps),
+            "bias_present": protocol["bias_present"],
+            "protocol": protocol,
+            "protocol_excerpt": protocol["excerpt"],
+            "protocol_title": protocol["protocol_title"],
             "config": config,
             "stats": stats,
         }
 
 
 if __name__ == "__main__":
     print("=" * 60)
     print(" Dataset Generator — Validation Test")
     print("=" * 60)
 
+    for difficulty in ["easy", "medium", "hard"]:
+        generator = DatasetGenerator(seed=42)
+        result = generator.generate(difficulty=difficulty)
         stats = result["stats"]
+        protocol = result["protocol"]
+        print(f"\n {difficulty.upper()}:")
+        print(f"   Protocol:  {stats['protocol_title']}")
         print(f"   Patients:  {stats['total_patients']}")
+        print(
+            f"   Errors:    {stats['total_errors']} "
+            f"(age={stats['age_errors']}, temporal={stats['temporal_errors']}, "
+            f"window={stats['protocol_window_errors']}, bias={stats['bias_mode']})"
+        )
         print(f"   Traps:     {stats['num_traps']}")
         print(f"   Control:   {stats['control_count']}")
         print(f"   Treatment: {stats['treatment_count']}")
+        print(
+            f"   Rules:     age={protocol['age_min']}-{protocol['age_max']} | "
+            f"start<={protocol['treatment_window_days']}d | "
+            f"stage IV<={protocol['stage_iv_treatment_window_days']}d"
+        )
 
+        generator_2 = DatasetGenerator(seed=42)
+        result_2 = generator_2.generate(difficulty=difficulty)
+        assert result["dataset"] == result_2["dataset"], "REPRODUCIBILITY FAILED!"
+        assert result["ground_truth"] == result_2["ground_truth"], "GROUND TRUTH MISMATCH!"
+        assert result["protocol_excerpt"] == result_2["protocol_excerpt"], "PROTOCOL MISMATCH!"
+        print("   ✓ Seed reproducibility verified")
+
+        for patient_id, errors in result["ground_truth"].items():
+            patient = next(p for p in result["dataset"] if p["patient_id"] == patient_id)
+            for error in errors:
+                if error == "invalid_age":
                     age = patient.get("age")
+                    assert age is None or age < protocol["age_min"] or age > protocol["age_max"], (
+                        f"Ground truth says {patient_id} invalid age but age={age}"
+                    )
+                elif error == "temporal_inconsistency":
+                    treatment_start = datetime.strptime(patient["treatment_start"], "%Y-%m-%d")
+                    death_date = datetime.strptime(patient["death_date"], "%Y-%m-%d")
+                    assert death_date < treatment_start, (
+                        f"Ground truth says {patient_id} temporal error but dates are valid"
+                    )
+                elif error == "protocol_window_violation":
+                    enrollment = datetime.strptime(patient["enrollment_date"], "%Y-%m-%d")
+                    treatment_start = datetime.strptime(patient["treatment_start"], "%Y-%m-%d")
+                    allowed = (
+                        protocol["stage_iv_treatment_window_days"]
+                        if patient["stage"] == "IV"
+                        else protocol["treatment_window_days"]
+                    )
+                    assert (treatment_start - enrollment).days > allowed, (
+                        f"Ground truth says {patient_id} window error but delay is valid"
+                    )
+        print("   ✓ Ground truth integrity verified")
+
+    generator_a = DatasetGenerator(seed=1)
+    generator_b = DatasetGenerator(seed=2)
+    result_a = generator_a.generate("easy")
+    result_b = generator_b.generate("easy")
+    assert result_a["dataset"] != result_b["dataset"], "Different seeds generated identical datasets!"
+    assert result_a["protocol_excerpt"] != result_b["protocol_excerpt"], "Different seeds generated identical protocols!"
+    print("\n ✓ Different seeds produce different datasets")
     print(f"\n{'=' * 60}")
+    print(" ALL TESTS PASSED")
     print(f"{'=' * 60}")
server/models.py
CHANGED
@@ -2,6 +2,7 @@ from typing import Optional, List, Dict, Any
 from pydantic import Field
 from openenv.core.env_server import Action, Observation, State
 
+
 class AuditAction(Action):
     action_type: str = "flag_error"
     patient_id: Optional[str] = None
@@ -12,30 +13,43 @@ class AuditAction(Action):
     report: Optional[str] = None
     confidence: Optional[float] = None  # 0.0-1.0: agent's confidence in this action
 
+
 class AuditObservation(Observation):
     done: bool = False
     reward: float = 0.0
     task_id: str = ""
     task_type: str = ""
     task_description: str = ""
+    protocol_title: str = ""
+    trial_protocol_excerpt: str = ""
     dataset: List[Dict[str, Any]] = Field(default_factory=list)
     errors_found: List[str] = Field(default_factory=list)
     patterns_investigated: List[str] = Field(default_factory=list)
     distributions_computed: List[str] = Field(default_factory=list)
     feedback: Optional[str] = None
     score_so_far: float = 0.0
+    dense_reward_total: float = 0.0
+    score_breakdown: Dict[str, float] = Field(default_factory=dict)
     attempts_remaining: int = 15
     phase: str = "investigation"
 
+
 class AuditState(State):
     episode_id: str = ""
     step_count: int = 0
     task_id: str = ""
     task_type: str = ""
+    protocol_title: str = ""
+    trial_protocol_excerpt: str = ""
     total_errors: int = 0
     errors_found: int = 0
     current_score: float = 0.0
+    dense_reward_total: float = 0.0
+    correct_flags: int = 0
+    false_positives: int = 0
+    duplicate_flags: int = 0
     attempts: int = 0
     phase: str = "investigation"
+    score_breakdown: Dict[str, float] = Field(default_factory=dict)
     patterns_investigated: List[str] = Field(default_factory=list)
-    distributions_computed: List[str] = Field(default_factory=list)
+    distributions_computed: List[str] = Field(default_factory=list)
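The models above consistently use `Field(default_factory=list)` / `Field(default_factory=dict)` for mutable defaults. The same rule exists in stdlib dataclasses, and a minimal sketch shows why the factory matters (the `Record` class here is illustrative, not from the repo):

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Record:
    # default_factory builds a fresh list per instance; a bare `= []`
    # default would be rejected by dataclasses precisely because it
    # would be shared across every instance.
    errors_found: List[str] = field(default_factory=list)


a, b = Record(), Record()
a.errors_found.append("invalid_age")
assert b.errors_found == []  # b keeps its own, independent list
```

Without per-instance factories, flags accumulated in one episode's observation would leak into the next.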