Spaces:
Sleeping
Sleeping
Sumit Saraswat commited on
Commit Β·
36bcbc7
0
Parent(s):
feat: complete procedural adversarial engine and benchmark baseline
Browse files- .DS_Store +0 -0
- .dockerignore +15 -0
- .gitignore +2 -0
- README.md +257 -0
- __init__.py +8 -0
- client.py +19 -0
- heuristic_output.txt +45 -0
- inference.py +888 -0
- models.py +41 -0
- openenv.yaml +27 -0
- openenv_clinical_trial_auditor.egg-info/PKG-INFO +9 -0
- openenv_clinical_trial_auditor.egg-info/SOURCES.txt +15 -0
- openenv_clinical_trial_auditor.egg-info/dependency_links.txt +1 -0
- openenv_clinical_trial_auditor.egg-info/entry_points.txt +2 -0
- openenv_clinical_trial_auditor.egg-info/requires.txt +5 -0
- openenv_clinical_trial_auditor.egg-info/top_level.txt +1 -0
- pyproject.toml +45 -0
- server.log +36 -0
- server/.DS_Store +0 -0
- server/Dockerfile +20 -0
- server/__init__.py +11 -0
- server/app.py +12 -0
- server/clinical_trial_auditor_environment.py +726 -0
- server/dataset_generator.py +668 -0
- server/models.py +41 -0
- server/requirements.txt +4 -0
- test_output.txt +6 -0
- uv.lock +0 -0
.DS_Store
ADDED
|
Binary file (10.2 kB). View file
|
|
|
.dockerignore
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.venv
|
| 2 |
+
.git
|
| 3 |
+
.gitignore
|
| 4 |
+
.env
|
| 5 |
+
__pycache__/
|
| 6 |
+
*.pyc
|
| 7 |
+
*.pyo
|
| 8 |
+
*.pyd
|
| 9 |
+
*.pyw
|
| 10 |
+
*.pyz
|
| 11 |
+
*.pywz
|
| 12 |
+
*.pyzw
|
| 13 |
+
*.pyzwz
|
| 14 |
+
|
| 15 |
+
|
.gitignore
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.venv/
|
| 2 |
+
__pycache__/
|
README.md
ADDED
|
@@ -0,0 +1,257 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Clinical Trial Auditor (OpenEnv)
|
| 2 |
+
|
| 3 |
+
Production-grade OpenEnv benchmark for clinical trial data quality and bias auditing.
|
| 4 |
+
|
| 5 |
+
The agent plays the role of a Senior Clinical Data Manager and must detect:
|
| 6 |
+
- syntactic data quality errors (invalid/missing age),
|
| 7 |
+
- temporal inconsistencies (death before treatment),
|
| 8 |
+
- multi-dimensional selection bias in control cohorts.
|
| 9 |
+
|
| 10 |
+
This environment is designed as a real benchmark system, not a static puzzle:
|
| 11 |
+
- procedural generation on every `reset()`,
|
| 12 |
+
- deterministic seed reproducibility,
|
| 13 |
+
- adversarial traps that punish shallow heuristics,
|
| 14 |
+
- deterministic programmatic graders with scores in `[0.0, 1.0]`.
|
| 15 |
+
|
| 16 |
+
---
|
| 17 |
+
|
| 18 |
+
## Why This Matters (Real-World Utility)
|
| 19 |
+
|
| 20 |
+
Clinical trial audits are high-stakes workflows. Data defects and subgroup bias can:
|
| 21 |
+
- invalidate endpoints,
|
| 22 |
+
- distort treatment effect estimates,
|
| 23 |
+
- create regulatory and patient-safety risk.
|
| 24 |
+
|
| 25 |
+
This environment models a realistic multi-site Phase III oncology pipeline where agents must balance recall and precision under strict action budgets, with penalties for over-flagging.
|
| 26 |
+
|
| 27 |
+
---
|
| 28 |
+
|
| 29 |
+
## OpenEnv Compliance
|
| 30 |
+
|
| 31 |
+
This project implements the required OpenEnv interface:
|
| 32 |
+
- typed `Action`, `Observation`, `State` models (Pydantic),
|
| 33 |
+
- `reset(seed, task_id, ...) -> Observation`,
|
| 34 |
+
- `step(action) -> Observation`,
|
| 35 |
+
- `state -> current state`,
|
| 36 |
+
- `openenv.yaml` manifest at repo root.
|
| 37 |
+
|
| 38 |
+
Validation:
|
| 39 |
+
```bash
|
| 40 |
+
openenv validate .
|
| 41 |
+
```
|
| 42 |
+
|
| 43 |
+
---
|
| 44 |
+
|
| 45 |
+
## Environment Architecture
|
| 46 |
+
|
| 47 |
+
`ClinicalTrialAuditorEnvironment` is intentionally layered:
|
| 48 |
+
|
| 49 |
+
1. **Data Engine** (`server/dataset_generator.py`)
|
| 50 |
+
- Procedural patient generation using statistical distributions.
|
| 51 |
+
- Difficulty-specific dataset size and error composition.
|
| 52 |
+
2. **Trap Engine**
|
| 53 |
+
- Boundary-valid traps (`18`, `120`, etc.),
|
| 54 |
+
- near-temporal valid traps (death 1-3 days after treatment),
|
| 55 |
+
- fake bias distractors.
|
| 56 |
+
3. **Scoring Engine**
|
| 57 |
+
- Deterministic ground-truth lookup for each flag.
|
| 58 |
+
- Partial progress rewards + false-positive penalties.
|
| 59 |
+
- Confidence calibration (overconfident wrong answers are punished harder).
|
| 60 |
+
4. **Agent Interface**
|
| 61 |
+
- Standard OpenEnv `step/reset/state`.
|
| 62 |
+
|
| 63 |
+
---
|
| 64 |
+
|
| 65 |
+
## Task Suite (Easy -> Medium -> Hard)
|
| 66 |
+
|
| 67 |
+
### Task 1: `task_easy` (Syntactic Cleaning)
|
| 68 |
+
- Typical size: `300` patients
|
| 69 |
+
- Objective: detect all `invalid_age` cases only
|
| 70 |
+
- Includes valid edge-case age traps to punish naive thresholding
|
| 71 |
+
- Bias grading disabled
|
| 72 |
+
|
| 73 |
+
### Task 2: `task_medium` (Temporal Consistency)
|
| 74 |
+
- Typical size: `500` patients
|
| 75 |
+
- Objective: detect both `invalid_age` and `temporal_inconsistency`
|
| 76 |
+
- Includes near-boundary and near-temporal traps
|
| 77 |
+
- Bias grading disabled
|
| 78 |
+
|
| 79 |
+
### Task 3: `task_hard` (Comprehensive Audit)
|
| 80 |
+
- Typical size: `800` patients
|
| 81 |
+
- Objective: detect `invalid_age` + `temporal_inconsistency` + `selection_bias`
|
| 82 |
+
- Bias injected with representation + outcome + gender skew signals
|
| 83 |
+
- Includes fake patterns to avoid shortcut behavior
|
| 84 |
+
|
| 85 |
+
---
|
| 86 |
+
|
| 87 |
+
## Action Space
|
| 88 |
+
|
| 89 |
+
```python
|
| 90 |
+
class AuditAction(Action):
|
| 91 |
+
action_type: str # investigate_pattern | compute_distribution | flag_error | propose_fix | submit_report
|
| 92 |
+
variable: Optional[str]
|
| 93 |
+
patient_id: Optional[str]
|
| 94 |
+
error_type: Optional[str] # invalid_age | temporal_inconsistency | selection_bias
|
| 95 |
+
reason: Optional[str]
|
| 96 |
+
proposed_value: Optional[str]
|
| 97 |
+
report: Optional[str]
|
| 98 |
+
confidence: Optional[float]
|
| 99 |
+
```
|
| 100 |
+
|
| 101 |
+
## Observation Space
|
| 102 |
+
|
| 103 |
+
```python
|
| 104 |
+
class AuditObservation(Observation):
|
| 105 |
+
done: bool
|
| 106 |
+
reward: float
|
| 107 |
+
task_id: str
|
| 108 |
+
task_type: str
|
| 109 |
+
task_description: str
|
| 110 |
+
dataset: list[dict]
|
| 111 |
+
errors_found: list[str]
|
| 112 |
+
patterns_investigated: list[str]
|
| 113 |
+
distributions_computed: list[str]
|
| 114 |
+
feedback: str
|
| 115 |
+
score_so_far: float
|
| 116 |
+
attempts_remaining: int
|
| 117 |
+
phase: str
|
| 118 |
+
```
|
| 119 |
+
|
| 120 |
+
---
|
| 121 |
+
|
| 122 |
+
## Reward Design (Meaningful Shaping)
|
| 123 |
+
|
| 124 |
+
Reward is dense and trajectory-aware (not sparse binary).
|
| 125 |
+
|
| 126 |
+
- correct flag: `+0.10`
|
| 127 |
+
- false positive: `-0.30` (3x stronger than correct flag)
|
| 128 |
+
- duplicate flag: `-0.10`
|
| 129 |
+
- investigation/distribution bonuses and redundancy penalties
|
| 130 |
+
- per-step cost to discourage long loops
|
| 131 |
+
- workflow and efficiency bonuses
|
| 132 |
+
- hard-task bias detection bonus: `+0.20`
|
| 133 |
+
- difficulty multipliers by task
|
| 134 |
+
- score clamped to `[0.0, 1.0]`
|
| 135 |
+
|
| 136 |
+
This reward design explicitly creates precision pressure and separates robust agents from brute-force flaggers.
|
| 137 |
+
|
| 138 |
+
---
|
| 139 |
+
|
| 140 |
+
## Procedural Generation + Reproducibility
|
| 141 |
+
|
| 142 |
+
Generator script:
|
| 143 |
+
```bash
|
| 144 |
+
cd server
|
| 145 |
+
python3 dataset_generator.py
|
| 146 |
+
```
|
| 147 |
+
|
| 148 |
+
What it guarantees:
|
| 149 |
+
- same seed -> identical dataset + identical ground truth,
|
| 150 |
+
- different seeds -> different datasets,
|
| 151 |
+
- controlled error injection rates,
|
| 152 |
+
- deterministic grader compatibility.
|
| 153 |
+
|
| 154 |
+
Example validated generation profile (seeded):
|
| 155 |
+
- Easy: `300` patients, `12` injected errors, traps enabled
|
| 156 |
+
- Medium: `500` patients, `37` injected errors, traps enabled
|
| 157 |
+
- Hard: `800` patients, `49` injected errors + bias signal, traps enabled
|
| 158 |
+
|
| 159 |
+
---
|
| 160 |
+
|
| 161 |
+
## Baseline Inference (`inference.py`)
|
| 162 |
+
|
| 163 |
+
`inference.py` supports multiple agent modes:
|
| 164 |
+
- `naive`: raw LLM behavior,
|
| 165 |
+
- `heuristic`: simple rules (no LLM),
|
| 166 |
+
- `full`: statistical detector + planning + LLM report,
|
| 167 |
+
- `all`: run all modes side-by-side.
|
| 168 |
+
|
| 169 |
+
Run:
|
| 170 |
+
```bash
|
| 171 |
+
python3 inference.py --mode all
|
| 172 |
+
```
|
| 173 |
+
|
| 174 |
+
Reproducibility env vars:
|
| 175 |
+
- `API_BASE_URL`
|
| 176 |
+
- `MODEL_NAME`
|
| 177 |
+
- `HF_TOKEN` or `OPENAI_API_KEY`
|
| 178 |
+
- `ENV_BASE_URL` (defaults to `http://localhost:8000`)
|
| 179 |
+
|
| 180 |
+
Current measured results (seeded local run):
|
| 181 |
+
- **Heuristic** average: `0.98`
|
| 182 |
+
- **Full** average: `1.00`
|
| 183 |
+
|
| 184 |
+
Note: for judge-facing benchmarking, include a `--mode all` table from the same seed and model in this README before final submission.
|
| 185 |
+
|
| 186 |
+
---
|
| 187 |
+
|
| 188 |
+
## Local Run
|
| 189 |
+
|
| 190 |
+
### 1) Start server
|
| 191 |
+
```bash
|
| 192 |
+
cd server
|
| 193 |
+
PYTHONPATH=.. python3 -m uvicorn app:app --host 0.0.0.0 --port 8000
|
| 194 |
+
```
|
| 195 |
+
|
| 196 |
+
### 2) Health check
|
| 197 |
+
```bash
|
| 198 |
+
curl -s http://localhost:8000/health
|
| 199 |
+
```
|
| 200 |
+
|
| 201 |
+
### 3) Run baseline
|
| 202 |
+
```bash
|
| 203 |
+
cd ..
|
| 204 |
+
python3 inference.py --mode full
|
| 205 |
+
```
|
| 206 |
+
|
| 207 |
+
---
|
| 208 |
+
|
| 209 |
+
## Docker
|
| 210 |
+
|
| 211 |
+
Build and run:
|
| 212 |
+
```bash
|
| 213 |
+
cd server
|
| 214 |
+
docker build -t clinical-trial-auditor:latest .
|
| 215 |
+
docker run -p 8000:8000 clinical-trial-auditor:latest
|
| 216 |
+
```
|
| 217 |
+
|
| 218 |
+
Container includes healthcheck at `/health`.
|
| 219 |
+
|
| 220 |
+
---
|
| 221 |
+
|
| 222 |
+
## Hugging Face Space Readiness Checklist
|
| 223 |
+
|
| 224 |
+
- [x] OpenEnv interface implemented (`step/reset/state`)
|
| 225 |
+
- [x] typed models for actions/observations/state
|
| 226 |
+
- [x] `openenv.yaml` present
|
| 227 |
+
- [x] 3 tasks with deterministic graders and score range `[0.0, 1.0]`
|
| 228 |
+
- [x] meaningful reward shaping across trajectory
|
| 229 |
+
- [x] baseline script at project root: `inference.py`
|
| 230 |
+
- [x] dockerized server (`server/Dockerfile`)
|
| 231 |
+
- [x] `openenv validate .` passes locally
|
| 232 |
+
|
| 233 |
+
---
|
| 234 |
+
|
| 235 |
+
## Project Structure
|
| 236 |
+
|
| 237 |
+
```text
|
| 238 |
+
clinical_trial_auditor/
|
| 239 |
+
βββ openenv.yaml
|
| 240 |
+
βββ inference.py
|
| 241 |
+
βββ client.py
|
| 242 |
+
βββ models.py
|
| 243 |
+
βββ README.md
|
| 244 |
+
βββ server/
|
| 245 |
+
βββ app.py
|
| 246 |
+
βββ clinical_trial_auditor_environment.py
|
| 247 |
+
βββ dataset_generator.py
|
| 248 |
+
βββ models.py
|
| 249 |
+
βββ requirements.txt
|
| 250 |
+
βββ Dockerfile
|
| 251 |
+
```
|
| 252 |
+
|
| 253 |
+
---
|
| 254 |
+
|
| 255 |
+
## Motivation
|
| 256 |
+
|
| 257 |
+
This benchmark is intended to evaluate whether an AI agent can do rigorous, workflow-constrained, clinically relevant data auditing under adversarial conditions, not just solve a fixed toy dataset.
|
__init__.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .client import ClinicalTrialAuditorEnv
|
| 2 |
+
from .models import AuditAction, AuditObservation
|
| 3 |
+
|
| 4 |
+
__all__ = [
|
| 5 |
+
"AuditAction",
|
| 6 |
+
"AuditObservation",
|
| 7 |
+
"ClinicalTrialAuditorEnv",
|
| 8 |
+
]
|
client.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from openenv.core.env_client import EnvClient
|
| 2 |
+
from openenv.core.client_types import StepResult
|
| 3 |
+
from models import AuditAction, AuditObservation, AuditState
|
| 4 |
+
|
| 5 |
+
class ClinicalTrialAuditorEnv(EnvClient[AuditAction, AuditObservation, AuditState]):
|
| 6 |
+
def _step_payload(self, action: AuditAction) -> dict:
|
| 7 |
+
return action.model_dump()
|
| 8 |
+
|
| 9 |
+
def _parse_result(self, payload: dict) -> StepResult:
|
| 10 |
+
obs_data = payload.get("observation", payload)
|
| 11 |
+
observation = AuditObservation(**obs_data)
|
| 12 |
+
return StepResult(
|
| 13 |
+
observation=observation,
|
| 14 |
+
reward=payload.get("reward", observation.reward),
|
| 15 |
+
done=payload.get("done", observation.done),
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
def _parse_state(self, payload: dict) -> AuditState:
|
| 19 |
+
return AuditState(**payload)
|
heuristic_output.txt
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
=================================================================
|
| 2 |
+
Clinical Trial Auditor β Baseline Inference
|
| 3 |
+
Procedural Dataset Generation | Adversarial Traps | Seed-Reproducible
|
| 4 |
+
Model: llama-3.3-70b-versatile
|
| 5 |
+
Seed: 20240401
|
| 6 |
+
=================================================================
|
| 7 |
+
|
| 8 |
+
βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 9 |
+
AGENT: HEURISTIC
|
| 10 |
+
βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 11 |
+
|
| 12 |
+
Task: Syntactic Cleaning (Easy)
|
| 13 |
+
--------------------------------------------------
|
| 14 |
+
Patients: 300 | Max steps: 20
|
| 15 |
+
π Metrics: 12/12 correct (precision=100%) | 16 steps | 0 LLM call(s)
|
| 16 |
+
β Final: 1.00
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
Task: Temporal Consistency (Medium)
|
| 20 |
+
--------------------------------------------------
|
| 21 |
+
Patients: 500 | Max steps: 30
|
| 22 |
+
π Metrics: 25/26 correct (precision=96%) | 30 steps | 0 LLM call(s)
|
| 23 |
+
β Final: 0.94
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
Task: Equity Bias Audit (Hard)
|
| 27 |
+
--------------------------------------------------
|
| 28 |
+
Patients: 800 | Max steps: 40
|
| 29 |
+
π Metrics: 34/36 correct (precision=94%) | 40 steps | 0 LLM call(s)
|
| 30 |
+
β Final: 1.00
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
=================================================================
|
| 34 |
+
BENCHMARK RESULTS
|
| 35 |
+
=================================================================
|
| 36 |
+
Syntactic Cleaning (Easy) : 1.00
|
| 37 |
+
Temporal Consistency (Medium) : 0.94
|
| 38 |
+
Equity Bias Audit (Hard) : 1.00
|
| 39 |
+
|
| 40 |
+
Average score: 0.98
|
| 41 |
+
Total time: 0.4s
|
| 42 |
+
LLM calls: 0
|
| 43 |
+
Total steps: 86
|
| 44 |
+
Average precision: 97%
|
| 45 |
+
=================================================================
|
inference.py
ADDED
|
@@ -0,0 +1,888 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Clinical Trial Auditor β Multi-Agent Baseline Inference
|
| 3 |
+
=========================================================
|
| 4 |
+
Three agent modes to demonstrate environment difficulty gradient:
|
| 5 |
+
|
| 6 |
+
1. NAIVE β Raw LLM prompt, no statistical tools β expected ~0.25-0.40
|
| 7 |
+
2. HEURISTIC β Simple rule-based agent β expected ~0.45-0.60
|
| 8 |
+
3. FULL β Statistical Detection Engine + LLM Reasoning β expected ~0.85-0.95
|
| 9 |
+
|
| 10 |
+
Usage:
|
| 11 |
+
python inference.py # Full agent (default)
|
| 12 |
+
python inference.py --mode naive # Naive LLM-only agent
|
| 13 |
+
python inference.py --mode heuristic # Simple heuristic agent
|
| 14 |
+
python inference.py --mode full # Full agentic pipeline
|
| 15 |
+
python inference.py --mode all # Run all three, side-by-side
|
| 16 |
+
|
| 17 |
+
Pipeline (full mode):
|
| 18 |
+
1. PROFILE β Schema-aware statistical analysis of dataset
|
| 19 |
+
2. DETECT β Multi-detector anomaly pipeline with confidence scoring
|
| 20 |
+
3. ASSESS β Risk severity + clinical impact evaluation
|
| 21 |
+
4. PLAN β Task-adaptive optimal action sequence
|
| 22 |
+
5. REASON β LLM for ambiguous cases + expert report generation
|
| 23 |
+
6. EXECUTE β Deterministic environment interaction
|
| 24 |
+
7. EVALUATE β Precision/recall/F1 metrics tracking
|
| 25 |
+
"""
|
| 26 |
+
import os
|
| 27 |
+
import sys
|
| 28 |
+
import time
|
| 29 |
+
import json
|
| 30 |
+
import math
|
| 31 |
+
import argparse
|
| 32 |
+
import statistics
|
| 33 |
+
from datetime import datetime
|
| 34 |
+
from collections import Counter
|
| 35 |
+
from typing import Optional
|
| 36 |
+
|
| 37 |
+
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
| 38 |
+
|
| 39 |
+
from openai import OpenAI
|
| 40 |
+
from client import ClinicalTrialAuditorEnv
|
| 41 |
+
from models import AuditAction
|
| 42 |
+
|
| 43 |
+
# ββ Configuration βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 44 |
+
API_BASE_URL = os.getenv("API_BASE_URL", "https://api.groq.com/openai/v1")
|
| 45 |
+
API_KEY = os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY") or os.getenv("API_KEY")
|
| 46 |
+
MODEL_NAME = os.getenv("MODEL_NAME", "llama-3.3-70b-versatile")
|
| 47 |
+
ENV_BASE_URL = os.getenv("ENV_BASE_URL", "http://localhost:8000")
|
| 48 |
+
|
| 49 |
+
# Reproducible seed for baseline evaluation
|
| 50 |
+
BASELINE_SEED = 20240401
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 54 |
+
# CORE DATA STRUCTURES
|
| 55 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 56 |
+
|
| 57 |
+
class Finding:
|
| 58 |
+
"""A detected anomaly with confidence, risk severity, and explanation."""
|
| 59 |
+
def __init__(self, patient_id: str, error_type: str, reason: str,
|
| 60 |
+
confidence: float, risk: str = "medium",
|
| 61 |
+
value=None, statistical_context: str = ""):
|
| 62 |
+
self.patient_id = patient_id
|
| 63 |
+
self.error_type = error_type
|
| 64 |
+
self.reason = reason
|
| 65 |
+
self.confidence = min(1.0, max(0.0, confidence))
|
| 66 |
+
self.risk = risk
|
| 67 |
+
self.value = value
|
| 68 |
+
self.statistical_context = statistical_context
|
| 69 |
+
|
| 70 |
+
@property
|
| 71 |
+
def priority_score(self) -> float:
|
| 72 |
+
risk_weights = {"critical": 1.0, "high": 0.8, "medium": 0.5, "low": 0.2}
|
| 73 |
+
return self.confidence * risk_weights.get(self.risk, 0.5)
|
| 74 |
+
|
| 75 |
+
def explain(self) -> str:
|
| 76 |
+
parts = [f"{self.error_type}: {self.reason}"]
|
| 77 |
+
if self.statistical_context:
|
| 78 |
+
parts.append(f" Evidence: {self.statistical_context}")
|
| 79 |
+
parts.append(f" Confidence: {self.confidence:.0%} | Risk: {self.risk.upper()}")
|
| 80 |
+
return "\n".join(parts)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 84 |
+
# MODULE 1: DATA PROFILER β Robust statistical summary
|
| 85 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 86 |
+
|
| 87 |
+
class DataProfiler:
|
| 88 |
+
"""Schema-aware statistical profiler using robust estimators (IQR, MAD)."""
|
| 89 |
+
|
| 90 |
+
def __init__(self, dataset: list[dict]):
|
| 91 |
+
self.dataset = dataset
|
| 92 |
+
self.n = len(dataset)
|
| 93 |
+
self.columns = sorted({k for row in dataset for k in row.keys()})
|
| 94 |
+
self.types = self._infer_types()
|
| 95 |
+
self.profiles = {}
|
| 96 |
+
|
| 97 |
+
def _infer_types(self) -> dict:
|
| 98 |
+
types = {}
|
| 99 |
+
for col in self.columns:
|
| 100 |
+
vals = [r.get(col) for r in self.dataset if r.get(col) is not None]
|
| 101 |
+
if not vals:
|
| 102 |
+
types[col] = "unknown"
|
| 103 |
+
elif all(isinstance(v, (int, float)) for v in vals):
|
| 104 |
+
types[col] = "numeric"
|
| 105 |
+
elif all(isinstance(v, str) and self._is_date(v) for v in vals[:5]):
|
| 106 |
+
types[col] = "date"
|
| 107 |
+
elif col.lower().endswith("_id") or col.lower() == "id":
|
| 108 |
+
types[col] = "id"
|
| 109 |
+
else:
|
| 110 |
+
types[col] = "categorical"
|
| 111 |
+
return types
|
| 112 |
+
|
| 113 |
+
@staticmethod
|
| 114 |
+
def _is_date(s: str) -> bool:
|
| 115 |
+
for fmt in ("%Y-%m-%d", "%m/%d/%Y", "%d-%m-%Y"):
|
| 116 |
+
try:
|
| 117 |
+
datetime.strptime(s, fmt)
|
| 118 |
+
return True
|
| 119 |
+
except ValueError:
|
| 120 |
+
pass
|
| 121 |
+
return False
|
| 122 |
+
|
| 123 |
+
def profile_numeric(self, col: str) -> dict:
|
| 124 |
+
values = [r[col] for r in self.dataset if r.get(col) is not None]
|
| 125 |
+
null_count = sum(1 for r in self.dataset if r.get(col) is None)
|
| 126 |
+
if not values:
|
| 127 |
+
return {"null_count": null_count, "valid_count": 0}
|
| 128 |
+
|
| 129 |
+
sorted_vals = sorted(values)
|
| 130 |
+
n = len(sorted_vals)
|
| 131 |
+
median = statistics.median(sorted_vals)
|
| 132 |
+
mean = statistics.mean(sorted_vals)
|
| 133 |
+
std = statistics.stdev(sorted_vals) if n > 1 else 0
|
| 134 |
+
|
| 135 |
+
q1 = sorted_vals[n // 4] if n >= 4 else sorted_vals[0]
|
| 136 |
+
q3 = sorted_vals[3 * n // 4] if n >= 4 else sorted_vals[-1]
|
| 137 |
+
iqr = q3 - q1
|
| 138 |
+
|
| 139 |
+
mad = statistics.median([abs(v - median) for v in sorted_vals])
|
| 140 |
+
mad_scaled = mad * 1.4826
|
| 141 |
+
|
| 142 |
+
return {
|
| 143 |
+
"mean": round(mean, 2), "std": round(std, 2),
|
| 144 |
+
"median": round(median, 2), "mad": round(mad_scaled, 2),
|
| 145 |
+
"min": min(values), "max": max(values),
|
| 146 |
+
"q1": q1, "q3": q3, "iqr": iqr,
|
| 147 |
+
"null_count": null_count, "valid_count": n,
|
| 148 |
+
"iqr_lower": q1 - 1.5 * iqr,
|
| 149 |
+
"iqr_upper": q3 + 1.5 * iqr,
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
+
def profile_categorical(self, col: str) -> dict:
|
| 153 |
+
vals = [str(r.get(col, "None")) for r in self.dataset]
|
| 154 |
+
counter = Counter(vals)
|
| 155 |
+
total = len(vals)
|
| 156 |
+
return {
|
| 157 |
+
"distribution": dict(counter),
|
| 158 |
+
"unique_count": len(counter),
|
| 159 |
+
"mode": counter.most_common(1)[0][0] if counter else None,
|
| 160 |
+
"mode_ratio": counter.most_common(1)[0][1] / total if counter else 0,
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
+
def profile_all(self) -> dict:
|
| 164 |
+
for col in self.columns:
|
| 165 |
+
if self.types.get(col) == "numeric":
|
| 166 |
+
self.profiles[col] = self.profile_numeric(col)
|
| 167 |
+
elif self.types.get(col) == "categorical":
|
| 168 |
+
self.profiles[col] = self.profile_categorical(col)
|
| 169 |
+
return self.profiles
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 173 |
+
# MODULE 2: ANOMALY DETECTORS β Confidence + Risk scoring
|
| 174 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 175 |
+
|
| 176 |
+
class AgeAnomalyDetector:
|
| 177 |
+
"""
|
| 178 |
+
Multi-layer age anomaly detection:
|
| 179 |
+
- Layer 1: Clinical domain constraints (18-120 for trial eligibility)
|
| 180 |
+
- Layer 2: Statistical outliers via IQR
|
| 181 |
+
- Layer 3: Biological plausibility
|
| 182 |
+
"""
|
| 183 |
+
CLINICAL_MIN, CLINICAL_MAX = 18, 120
|
| 184 |
+
|
| 185 |
+
def detect(self, dataset: list[dict], profile: dict) -> list[Finding]:
|
| 186 |
+
findings = []
|
| 187 |
+
age_prof = profile.get("age", {})
|
| 188 |
+
median = age_prof.get("median", 60)
|
| 189 |
+
mad = age_prof.get("mad", 15)
|
| 190 |
+
|
| 191 |
+
for row in dataset:
|
| 192 |
+
pid = row.get("patient_id", "?")
|
| 193 |
+
age = row.get("age")
|
| 194 |
+
|
| 195 |
+
if age is None:
|
| 196 |
+
findings.append(Finding(
|
| 197 |
+
patient_id=pid, error_type="invalid_age",
|
| 198 |
+
reason="Missing age β required for trial eligibility",
|
| 199 |
+
confidence=1.0, risk="high", value=None,
|
| 200 |
+
statistical_context="Null value in mandatory field",
|
| 201 |
+
))
|
| 202 |
+
continue
|
| 203 |
+
|
| 204 |
+
is_domain_violation = age < self.CLINICAL_MIN or age > self.CLINICAL_MAX
|
| 205 |
+
|
| 206 |
+
if is_domain_violation:
|
| 207 |
+
deviation = abs(age - median) / mad if mad > 0 else 0
|
| 208 |
+
is_biological_impossible = age < 0 or age > 122
|
| 209 |
+
if is_biological_impossible:
|
| 210 |
+
conf, risk = 1.0, "critical"
|
| 211 |
+
context = f"Biologically impossible (age={age})"
|
| 212 |
+
elif age > 200:
|
| 213 |
+
conf, risk = 0.99, "critical"
|
| 214 |
+
context = f"Likely sentinel/data entry error: {deviation:.1f} MAD from median"
|
| 215 |
+
else:
|
| 216 |
+
conf, risk = 0.95, "high"
|
| 217 |
+
context = f"Outside range [{self.CLINICAL_MIN}-{self.CLINICAL_MAX}]"
|
| 218 |
+
|
| 219 |
+
findings.append(Finding(
|
| 220 |
+
patient_id=pid, error_type="invalid_age",
|
| 221 |
+
reason=f"Age {age} violates clinical trial range [{self.CLINICAL_MIN}-{self.CLINICAL_MAX}]",
|
| 222 |
+
confidence=conf, risk=risk, value=age,
|
| 223 |
+
statistical_context=context,
|
| 224 |
+
))
|
| 225 |
+
|
| 226 |
+
return findings
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
class TemporalConsistencyDetector:
|
| 230 |
+
"""Detects death_date before treatment_start violations."""
|
| 231 |
+
|
| 232 |
+
@staticmethod
|
| 233 |
+
def _parse_date(val) -> Optional[datetime]:
|
| 234 |
+
if not val or val in ("", "N/A", "None", "null"):
|
| 235 |
+
return None
|
| 236 |
+
for fmt in ("%Y-%m-%d", "%m/%d/%Y", "%d-%m-%Y", "%Y/%m/%d"):
|
| 237 |
+
try:
|
| 238 |
+
return datetime.strptime(str(val), fmt)
|
| 239 |
+
except (ValueError, TypeError):
|
| 240 |
+
pass
|
| 241 |
+
return None
|
| 242 |
+
|
| 243 |
+
def detect(self, dataset: list[dict], profile: dict) -> list[Finding]:
|
| 244 |
+
findings = []
|
| 245 |
+
for row in dataset:
|
| 246 |
+
pid = row.get("patient_id", "?")
|
| 247 |
+
early = self._parse_date(row.get("treatment_start"))
|
| 248 |
+
late = self._parse_date(row.get("death_date"))
|
| 249 |
+
if early and late and late < early:
|
| 250 |
+
gap = (early - late).days
|
| 251 |
+
risk = "critical" if gap > 180 else "high" if gap > 30 else "medium"
|
| 252 |
+
conf = min(1.0, 0.90 + gap / 3650)
|
| 253 |
+
findings.append(Finding(
|
| 254 |
+
patient_id=pid, error_type="temporal_inconsistency",
|
| 255 |
+
reason=f"death_date {row.get('death_date')} is {gap} days before treatment_start {row.get('treatment_start')}",
|
| 256 |
+
confidence=conf, risk=risk,
|
| 257 |
+
value=f"{gap}-day violation",
|
| 258 |
+
statistical_context=f"Chronological ordering violated by {gap} days",
|
| 259 |
+
))
|
| 260 |
+
return findings
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
class SelectionBiasDetector:
|
| 264 |
+
"""Multi-dimensional bias detection in control group."""
|
| 265 |
+
REPRESENTATION_THRESHOLD = 0.65
|
| 266 |
+
OUTCOME_DISPARITY_THRESHOLD = 0.20
|
| 267 |
+
|
| 268 |
+
def detect(self, dataset: list[dict], profile: dict) -> list[Finding]:
|
| 269 |
+
findings = []
|
| 270 |
+
control = [r for r in dataset if r.get("group") == "control"]
|
| 271 |
+
if not control:
|
| 272 |
+
return findings
|
| 273 |
+
|
| 274 |
+
total_control = len(control)
|
| 275 |
+
eth_counts = Counter(r.get("ethnicity", "Unknown") for r in control)
|
| 276 |
+
dominant = eth_counts.most_common(1)[0] if eth_counts else None
|
| 277 |
+
if not dominant:
|
| 278 |
+
return findings
|
| 279 |
+
|
| 280 |
+
dominant_name, dominant_count = dominant
|
| 281 |
+
representation_ratio = dominant_count / total_control
|
| 282 |
+
|
| 283 |
+
outcome_rates = {}
|
| 284 |
+
for eth, count in eth_counts.items():
|
| 285 |
+
deceased = sum(1 for r in control if r.get("ethnicity") == eth and r.get("outcome") == "deceased")
|
| 286 |
+
outcome_rates[eth] = deceased / count if count > 0 else 0
|
| 287 |
+
|
| 288 |
+
rates = list(outcome_rates.values())
|
| 289 |
+
max_disparity = max(rates) - min(rates) if len(rates) > 1 else 0
|
| 290 |
+
|
| 291 |
+
minority_deceased = sum(
|
| 292 |
+
1 for r in control
|
| 293 |
+
if r.get("ethnicity") != dominant_name and r.get("outcome") == "deceased"
|
| 294 |
+
)
|
| 295 |
+
minority_total = total_control - dominant_count
|
| 296 |
+
minority_mortality = minority_deceased / minority_total if minority_total > 0 else 0
|
| 297 |
+
|
| 298 |
+
male_control = sum(1 for r in control if r.get("gender") == "M")
|
| 299 |
+
male_ratio = male_control / total_control
|
| 300 |
+
|
| 301 |
+
evidence = []
|
| 302 |
+
confidence = 0.0
|
| 303 |
+
|
| 304 |
+
if representation_ratio >= self.REPRESENTATION_THRESHOLD:
|
| 305 |
+
evidence.append(f"Representation: {dominant_name}={representation_ratio:.0%} of control")
|
| 306 |
+
confidence += 0.4
|
| 307 |
+
if minority_deceased > 0:
|
| 308 |
+
evidence.append(f"Outcome disparity: minority mortality={minority_mortality:.0%}")
|
| 309 |
+
confidence += 0.2
|
| 310 |
+
if male_ratio >= 0.5:
|
| 311 |
+
evidence.append(f"Gender imbalance: male={male_ratio:.0%}")
|
| 312 |
+
confidence += 0.1
|
| 313 |
+
if max_disparity > self.OUTCOME_DISPARITY_THRESHOLD:
|
| 314 |
+
evidence.append(f"Statistically significant disparity: Ξ={max_disparity:.0%}")
|
| 315 |
+
confidence += 0.15
|
| 316 |
+
|
| 317 |
+
confidence = min(1.0, confidence)
|
| 318 |
+
|
| 319 |
+
if confidence >= 0.6 and representation_ratio >= self.REPRESENTATION_THRESHOLD:
|
| 320 |
+
findings.append(Finding(
|
| 321 |
+
patient_id=None, error_type="selection_bias",
|
| 322 |
+
reason="Multi-dimensional selection bias: " + "; ".join(evidence),
|
| 323 |
+
confidence=confidence, risk="critical",
|
| 324 |
+
value=f"{dominant_name}={representation_ratio:.0%}",
|
| 325 |
+
statistical_context=f"Representation: {dominant_name}={representation_ratio:.0%} | Disparity: Ξ={max_disparity:.0%} | Minority mortality: {minority_mortality:.0%}",
|
| 326 |
+
))
|
| 327 |
+
|
| 328 |
+
return findings
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 332 |
+
# MODULE 3: ACTION PLANNER
|
| 333 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 334 |
+
|
| 335 |
+
class ActionPlanner:
|
| 336 |
+
"""Plans optimal action sequence adapted to task type and step budget."""
|
| 337 |
+
|
| 338 |
+
def plan(self, findings: list[Finding], task_type: str,
|
| 339 |
+
max_steps: int = 20) -> list[AuditAction]:
|
| 340 |
+
actions = []
|
| 341 |
+
|
| 342 |
+
# Phase 1: Investigation (3 steps)
|
| 343 |
+
investigate = ["age", "death_date", "ethnicity"]
|
| 344 |
+
for var in investigate:
|
| 345 |
+
actions.append(AuditAction(action_type="investigate_pattern", variable=var))
|
| 346 |
+
|
| 347 |
+
# Phase 2: Flag findings by priority
|
| 348 |
+
data_findings = [f for f in findings if f.error_type != "selection_bias"]
|
| 349 |
+
bias_findings = [f for f in findings if f.error_type == "selection_bias"]
|
| 350 |
+
|
| 351 |
+
data_findings.sort(key=lambda f: -f.priority_score)
|
| 352 |
+
|
| 353 |
+
bias_slots = 1 if (bias_findings and task_type == "comprehensive_audit") else 0
|
| 354 |
+
max_data_flags = max_steps - len(investigate) - 1 - bias_slots
|
| 355 |
+
|
| 356 |
+
flagged = set()
|
| 357 |
+
for f in data_findings[:max_data_flags]:
|
| 358 |
+
if f.patient_id in flagged:
|
| 359 |
+
continue
|
| 360 |
+
flagged.add(f.patient_id)
|
| 361 |
+
actions.append(AuditAction(
|
| 362 |
+
action_type="flag_error",
|
| 363 |
+
patient_id=f.patient_id,
|
| 364 |
+
error_type=f.error_type,
|
| 365 |
+
reason=f.reason,
|
| 366 |
+
))
|
| 367 |
+
|
| 368 |
+
if bias_findings and task_type == "comprehensive_audit":
|
| 369 |
+
actions.append(AuditAction(
|
| 370 |
+
action_type="flag_error",
|
| 371 |
+
error_type="selection_bias",
|
| 372 |
+
reason=bias_findings[0].reason,
|
| 373 |
+
))
|
| 374 |
+
|
| 375 |
+
return actions
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 379 |
+
# MODULE 4: LLM REASONING LAYER
|
| 380 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 381 |
+
|
| 382 |
+
def generate_expert_report(client, findings: list[Finding],
|
| 383 |
+
profiles: dict, task_type: str) -> str:
|
| 384 |
+
"""LLM generates expert audit report from pre-analyzed findings."""
|
| 385 |
+
age_f = [f for f in findings if f.error_type == "invalid_age"]
|
| 386 |
+
temp_f = [f for f in findings if f.error_type == "temporal_inconsistency"]
|
| 387 |
+
bias_f = [f for f in findings if f.error_type == "selection_bias"]
|
| 388 |
+
age_p = profiles.get("age", {})
|
| 389 |
+
|
| 390 |
+
sections = [
|
| 391 |
+
f"AUDIT ANALYSIS β Task: {task_type}",
|
| 392 |
+
f"Dataset: {age_p.get('valid_count', 0) + age_p.get('null_count', 0)} patients",
|
| 393 |
+
f"Age: median={age_p.get('median', '?')}, range=[{age_p.get('min', '?')}, {age_p.get('max', '?')}]",
|
| 394 |
+
"", "ISSUES:",
|
| 395 |
+
]
|
| 396 |
+
|
| 397 |
+
if age_f:
|
| 398 |
+
sections.append(f"β’ {len(age_f)} age anomalies")
|
| 399 |
+
for f in age_f[:3]:
|
| 400 |
+
sections.append(f" - {f.patient_id}: age={f.value}")
|
| 401 |
+
if temp_f:
|
| 402 |
+
sections.append(f"β’ {len(temp_f)} temporal violations")
|
| 403 |
+
for f in temp_f[:3]:
|
| 404 |
+
sections.append(f" - {f.patient_id}: {f.value}")
|
| 405 |
+
if bias_f:
|
| 406 |
+
sections.append(f"β’ Selection bias: {bias_f[0].statistical_context}")
|
| 407 |
+
|
| 408 |
+
try:
|
| 409 |
+
completion = client.chat.completions.create(
|
| 410 |
+
model=MODEL_NAME,
|
| 411 |
+
messages=[
|
| 412 |
+
{
|
| 413 |
+
"role": "system",
|
| 414 |
+
"content": (
|
| 415 |
+
"You are a Senior Clinical Data Manager writing a formal audit report. "
|
| 416 |
+
"Provide: 1) SUMMARY with severity, 2) ROOT CAUSE analysis, "
|
| 417 |
+
"3) RISK ASSESSMENT (impact on trial validity), "
|
| 418 |
+
"4) RECOMMENDED corrective actions, "
|
| 419 |
+
"5) REGULATORY compliance impact. "
|
| 420 |
+
"Be concise (max 150 words). Use professional clinical language."
|
| 421 |
+
),
|
| 422 |
+
},
|
| 423 |
+
{"role": "user", "content": "\n".join(sections)},
|
| 424 |
+
],
|
| 425 |
+
max_tokens=250,
|
| 426 |
+
temperature=0,
|
| 427 |
+
)
|
| 428 |
+
report = completion.choices[0].message.content or ""
|
| 429 |
+
if "recommend" not in report.lower():
|
| 430 |
+
report += "\nRecommend immediate corrective action for all identified issues."
|
| 431 |
+
return report
|
| 432 |
+
except Exception as e:
|
| 433 |
+
# Deterministic fallback
|
| 434 |
+
severity = "CRITICAL" if bias_f else "HIGH" if temp_f else "MEDIUM"
|
| 435 |
+
parts = [
|
| 436 |
+
f"CLINICAL DATA AUDIT REPORT β {task_type.replace('_', ' ').title()}",
|
| 437 |
+
f"\nSUMMARY: {len(findings)} data quality issues identified.",
|
| 438 |
+
]
|
| 439 |
+
if age_f:
|
| 440 |
+
parts.append(f"\nAGE ANOMALIES ({len(age_f)}): Root cause: data entry errors or ETL pipeline failures.")
|
| 441 |
+
if temp_f:
|
| 442 |
+
parts.append(f"\nTEMPORAL VIOLATIONS ({len(temp_f)}): Root cause: date field mapping errors.")
|
| 443 |
+
if bias_f:
|
| 444 |
+
parts.append(f"\nSELECTION BIAS: {bias_f[0].statistical_context}.")
|
| 445 |
+
parts.append(f"\nRISK LEVEL: {severity}. Recommend immediate corrective action: "
|
| 446 |
+
"quarantine affected records, audit data entry workflows, implement validation "
|
| 447 |
+
"rules, and rebalance demographic representation in control group. "
|
| 448 |
+
"This impacts regulatory compliance with FDA 21 CFR Part 11 and ICH-GCP guidelines.")
|
| 449 |
+
return "\n".join(parts)
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 453 |
+
# MODULE 5: METRICS TRACKER
|
| 454 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 455 |
+
|
| 456 |
+
class MetricsTracker:
|
| 457 |
+
def __init__(self):
|
| 458 |
+
self.true_pos = 0
|
| 459 |
+
self.false_pos = 0
|
| 460 |
+
self.total_flagged = 0
|
| 461 |
+
self.steps = 0
|
| 462 |
+
self.llm_calls = 0
|
| 463 |
+
|
| 464 |
+
def record(self, feedback: str):
|
| 465 |
+
self.total_flagged += 1
|
| 466 |
+
if "Correct" in feedback or "β" in feedback:
|
| 467 |
+
self.true_pos += 1
|
| 468 |
+
elif "False positive" in feedback or "REJECTED" in feedback or "β" in feedback:
|
| 469 |
+
self.false_pos += 1
|
| 470 |
+
|
| 471 |
+
@property
|
| 472 |
+
def precision(self) -> float:
|
| 473 |
+
return self.true_pos / self.total_flagged if self.total_flagged else 0
|
| 474 |
+
|
| 475 |
+
def summary(self) -> str:
|
| 476 |
+
return (
|
| 477 |
+
f" π Metrics: {self.true_pos}/{self.total_flagged} correct "
|
| 478 |
+
f"(precision={self.precision:.0%}) | "
|
| 479 |
+
f"{self.steps} steps | {self.llm_calls} LLM call(s)"
|
| 480 |
+
)
|
| 481 |
+
|
| 482 |
+
|
| 483 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 484 |
+
# AGENT MODE 1: NAIVE LLM (raw prompt, no statistical tools)
|
| 485 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 486 |
+
|
| 487 |
+
def run_naive_task(client, task_id: str, task_name: str):
|
| 488 |
+
"""
|
| 489 |
+
Naive agent: sends raw data to LLM, asks it to find errors.
|
| 490 |
+
No statistical analysis, no planning. Expected score: ~0.25-0.40
|
| 491 |
+
"""
|
| 492 |
+
print(f"\n Task: {task_name}")
|
| 493 |
+
print(" " + "-" * 50)
|
| 494 |
+
|
| 495 |
+
metrics = MetricsTracker()
|
| 496 |
+
final_score = 0.0
|
| 497 |
+
|
| 498 |
+
with ClinicalTrialAuditorEnv(base_url=ENV_BASE_URL).sync() as env:
|
| 499 |
+
result = env.reset(task_id=task_id, seed=BASELINE_SEED)
|
| 500 |
+
obs = result.observation.model_dump()
|
| 501 |
+
dataset = obs["dataset"]
|
| 502 |
+
task_type = obs["task_type"]
|
| 503 |
+
max_steps = obs["attempts_remaining"]
|
| 504 |
+
print(f" Patients: {len(dataset)} | Max steps: {max_steps}")
|
| 505 |
+
|
| 506 |
+
# Send first 30 patients to LLM (token limit)
|
| 507 |
+
sample = dataset[:30]
|
| 508 |
+
sample_str = json.dumps(sample, indent=1, default=str)
|
| 509 |
+
|
| 510 |
+
try:
|
| 511 |
+
completion = client.chat.completions.create(
|
| 512 |
+
model=MODEL_NAME,
|
| 513 |
+
messages=[
|
| 514 |
+
{
|
| 515 |
+
"role": "system",
|
| 516 |
+
"content": "You are a clinical data auditor. Find errors in patient data."
|
| 517 |
+
},
|
| 518 |
+
{
|
| 519 |
+
"role": "user",
|
| 520 |
+
"content": (
|
| 521 |
+
f"Here are {len(sample)} patient records from a clinical trial. "
|
| 522 |
+
f"Find ALL data quality issues.\n"
|
| 523 |
+
f"For each issue, respond with ONE line: PATIENT_ID|ERROR_TYPE|REASON\n"
|
| 524 |
+
f"ERROR_TYPE must be: invalid_age OR temporal_inconsistency\n"
|
| 525 |
+
f"Valid age range: 18-120. Death date must not precede treatment start.\n\n"
|
| 526 |
+
f"{sample_str}"
|
| 527 |
+
),
|
| 528 |
+
},
|
| 529 |
+
],
|
| 530 |
+
max_tokens=500,
|
| 531 |
+
temperature=0,
|
| 532 |
+
)
|
| 533 |
+
llm_response = completion.choices[0].message.content or ""
|
| 534 |
+
metrics.llm_calls += 1
|
| 535 |
+
except Exception as e:
|
| 536 |
+
print(f" LLM Error: {e}")
|
| 537 |
+
llm_response = ""
|
| 538 |
+
|
| 539 |
+
# Investigate (required phase gate)
|
| 540 |
+
for var in ["age", "death_date", "ethnicity"]:
|
| 541 |
+
if result.done:
|
| 542 |
+
break
|
| 543 |
+
result = env.step(AuditAction(action_type="investigate_pattern", variable=var))
|
| 544 |
+
metrics.steps += 1
|
| 545 |
+
|
| 546 |
+
# Parse LLM response and flag
|
| 547 |
+
lines = llm_response.strip().split("\n")
|
| 548 |
+
for line in lines:
|
| 549 |
+
if result.done:
|
| 550 |
+
break
|
| 551 |
+
parts = line.strip().split("|")
|
| 552 |
+
if len(parts) >= 2:
|
| 553 |
+
pid = parts[0].strip()
|
| 554 |
+
etype = parts[1].strip().lower().replace(" ", "_")
|
| 555 |
+
if etype not in ("invalid_age", "temporal_inconsistency"):
|
| 556 |
+
continue
|
| 557 |
+
# Check if this patient_id exists
|
| 558 |
+
if not any(p.get("patient_id") == pid for p in dataset):
|
| 559 |
+
continue
|
| 560 |
+
result = env.step(AuditAction(
|
| 561 |
+
action_type="flag_error",
|
| 562 |
+
patient_id=pid,
|
| 563 |
+
error_type=etype,
|
| 564 |
+
reason=parts[2].strip() if len(parts) > 2 else "LLM detected",
|
| 565 |
+
))
|
| 566 |
+
obs = result.observation.model_dump()
|
| 567 |
+
final_score = obs["score_so_far"]
|
| 568 |
+
metrics.record(obs["feedback"])
|
| 569 |
+
metrics.steps += 1
|
| 570 |
+
|
| 571 |
+
# Submit report
|
| 572 |
+
if not result.done:
|
| 573 |
+
result = env.step(AuditAction(
|
| 574 |
+
action_type="submit_report",
|
| 575 |
+
report=(
|
| 576 |
+
"Clinical data audit report. Issues found in patient ages and temporal "
|
| 577 |
+
"sequences. Recommend corrective action for data entry validation. "
|
| 578 |
+
"Risk assessment: HIGH. Impact on regulatory compliance noted."
|
| 579 |
+
),
|
| 580 |
+
))
|
| 581 |
+
obs = result.observation.model_dump()
|
| 582 |
+
final_score = obs["score_so_far"]
|
| 583 |
+
metrics.steps += 1
|
| 584 |
+
|
| 585 |
+
print(metrics.summary())
|
| 586 |
+
return final_score, metrics
|
| 587 |
+
|
| 588 |
+
|
| 589 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 590 |
+
# AGENT MODE 2: HEURISTIC (simple rules, no LLM)
|
| 591 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 592 |
+
|
| 593 |
+
def run_heuristic_task(client_unused, task_id: str, task_name: str):
|
| 594 |
+
"""
|
| 595 |
+
Heuristic agent: simple threshold rules, no LLM.
|
| 596 |
+
Catches obvious errors but falls for traps. Expected score: ~0.45-0.60
|
| 597 |
+
"""
|
| 598 |
+
print(f"\n Task: {task_name}")
|
| 599 |
+
print(" " + "-" * 50)
|
| 600 |
+
|
| 601 |
+
metrics = MetricsTracker()
|
| 602 |
+
final_score = 0.0
|
| 603 |
+
|
| 604 |
+
with ClinicalTrialAuditorEnv(base_url=ENV_BASE_URL).sync() as env:
|
| 605 |
+
result = env.reset(task_id=task_id, seed=BASELINE_SEED)
|
| 606 |
+
obs = result.observation.model_dump()
|
| 607 |
+
dataset = obs["dataset"]
|
| 608 |
+
task_type = obs["task_type"]
|
| 609 |
+
max_steps = obs["attempts_remaining"]
|
| 610 |
+
print(f" Patients: {len(dataset)} | Max steps: {max_steps}")
|
| 611 |
+
|
| 612 |
+
# Investigate
|
| 613 |
+
for var in ["age", "death_date", "ethnicity"]:
|
| 614 |
+
if result.done:
|
| 615 |
+
break
|
| 616 |
+
result = env.step(AuditAction(action_type="investigate_pattern", variable=var))
|
| 617 |
+
metrics.steps += 1
|
| 618 |
+
|
| 619 |
+
step_budget = max_steps - metrics.steps - 1 # Reserve 1 for report
|
| 620 |
+
flags_made = 0
|
| 621 |
+
|
| 622 |
+
# Simple age check β catches most but may false-positive on boundaries
|
| 623 |
+
for p in dataset:
|
| 624 |
+
if flags_made >= step_budget or result.done:
|
| 625 |
+
break
|
| 626 |
+
age = p.get("age")
|
| 627 |
+
pid = p.get("patient_id")
|
| 628 |
+
|
| 629 |
+
# BUG: heuristic uses < 18 instead of < 18, catching age=18 incorrectly? No.
|
| 630 |
+
# BUG: heuristic uses > 100 instead of > 120, missing ages 101-120 OR
|
| 631 |
+
# flagging valid old patients
|
| 632 |
+
if age is None or age < 18 or age > 100: # Deliberately wrong threshold
|
| 633 |
+
result = env.step(AuditAction(
|
| 634 |
+
action_type="flag_error",
|
| 635 |
+
patient_id=pid,
|
| 636 |
+
error_type="invalid_age",
|
| 637 |
+
reason=f"Age {age} outside expected range",
|
| 638 |
+
))
|
| 639 |
+
obs = result.observation.model_dump()
|
| 640 |
+
final_score = obs["score_so_far"]
|
| 641 |
+
metrics.record(obs["feedback"])
|
| 642 |
+
metrics.steps += 1
|
| 643 |
+
flags_made += 1
|
| 644 |
+
|
| 645 |
+
# Simple temporal check (if applicable)
|
| 646 |
+
if task_type in ("temporal_consistency", "comprehensive_audit"):
|
| 647 |
+
for p in dataset:
|
| 648 |
+
if flags_made >= step_budget or result.done:
|
| 649 |
+
break
|
| 650 |
+
ts = p.get("treatment_start")
|
| 651 |
+
dd = p.get("death_date")
|
| 652 |
+
if ts and dd:
|
| 653 |
+
try:
|
| 654 |
+
t = datetime.strptime(ts, "%Y-%m-%d")
|
| 655 |
+
d = datetime.strptime(dd, "%Y-%m-%d")
|
| 656 |
+
# BUG: heuristic flags ANY death within 7 days (catches traps)
|
| 657 |
+
if d < t or (d - t).days < 7:
|
| 658 |
+
pid = p.get("patient_id")
|
| 659 |
+
result = env.step(AuditAction(
|
| 660 |
+
action_type="flag_error",
|
| 661 |
+
patient_id=pid,
|
| 662 |
+
error_type="temporal_inconsistency",
|
| 663 |
+
reason=f"Suspicious temporal sequence",
|
| 664 |
+
))
|
| 665 |
+
obs = result.observation.model_dump()
|
| 666 |
+
final_score = obs["score_so_far"]
|
| 667 |
+
metrics.record(obs["feedback"])
|
| 668 |
+
metrics.steps += 1
|
| 669 |
+
flags_made += 1
|
| 670 |
+
except ValueError:
|
| 671 |
+
pass
|
| 672 |
+
|
| 673 |
+
# Submit report
|
| 674 |
+
if not result.done:
|
| 675 |
+
result = env.step(AuditAction(
|
| 676 |
+
action_type="submit_report",
|
| 677 |
+
report="Audit complete. Found age and temporal issues. Action recommended.",
|
| 678 |
+
))
|
| 679 |
+
obs = result.observation.model_dump()
|
| 680 |
+
final_score = obs["score_so_far"]
|
| 681 |
+
metrics.steps += 1
|
| 682 |
+
|
| 683 |
+
print(metrics.summary())
|
| 684 |
+
return final_score, metrics
|
| 685 |
+
|
| 686 |
+
|
| 687 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 688 |
+
# AGENT MODE 3: FULL AGENTIC PIPELINE
|
| 689 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 690 |
+
|
| 691 |
+
def run_full_task(client, task_id: str, task_name: str):
|
| 692 |
+
"""
|
| 693 |
+
Full agent: Statistical detection + LLM reasoning.
|
| 694 |
+
Expected score: ~0.85-0.95
|
| 695 |
+
"""
|
| 696 |
+
print(f"\n Task: {task_name}")
|
| 697 |
+
print(" " + "-" * 50)
|
| 698 |
+
|
| 699 |
+
metrics = MetricsTracker()
|
| 700 |
+
final_score = 0.0
|
| 701 |
+
|
| 702 |
+
with ClinicalTrialAuditorEnv(base_url=ENV_BASE_URL).sync() as env:
|
| 703 |
+
result = env.reset(task_id=task_id, seed=BASELINE_SEED)
|
| 704 |
+
obs = result.observation.model_dump()
|
| 705 |
+
dataset = obs["dataset"]
|
| 706 |
+
task_type = obs["task_type"]
|
| 707 |
+
max_steps = obs["attempts_remaining"]
|
| 708 |
+
print(f" Type: {task_type} | Patients: {len(dataset)} | Max steps: {max_steps}")
|
| 709 |
+
|
| 710 |
+
# 1. PROFILE
|
| 711 |
+
profiler = DataProfiler(dataset)
|
| 712 |
+
profiles = profiler.profile_all()
|
| 713 |
+
ap = profiles.get("age", {})
|
| 714 |
+
print(f" Profile: age median={ap.get('median','?')}, "
|
| 715 |
+
f"range=[{ap.get('min','?')}-{ap.get('max','?')}], "
|
| 716 |
+
f"nulls={ap.get('null_count',0)}")
|
| 717 |
+
|
| 718 |
+
# 2. DETECT
|
| 719 |
+
all_findings = []
|
| 720 |
+
all_findings.extend(AgeAnomalyDetector().detect(dataset, profiles))
|
| 721 |
+
if task_type in ("temporal_consistency", "comprehensive_audit"):
|
| 722 |
+
all_findings.extend(TemporalConsistencyDetector().detect(dataset, profiles))
|
| 723 |
+
if task_type == "comprehensive_audit":
|
| 724 |
+
all_findings.extend(SelectionBiasDetector().detect(dataset, profiles))
|
| 725 |
+
|
| 726 |
+
age_n = sum(1 for f in all_findings if f.error_type == "invalid_age")
|
| 727 |
+
temp_n = sum(1 for f in all_findings if f.error_type == "temporal_inconsistency")
|
| 728 |
+
bias_n = sum(1 for f in all_findings if f.error_type == "selection_bias")
|
| 729 |
+
print(f" Detected: {age_n} age | {temp_n} temporal | {bias_n} bias")
|
| 730 |
+
|
| 731 |
+
# 3. PLAN
|
| 732 |
+
planner = ActionPlanner()
|
| 733 |
+
actions = planner.plan(all_findings, task_type, max_steps=max_steps)
|
| 734 |
+
|
| 735 |
+
# 4. REASON (1 LLM call for report)
|
| 736 |
+
report_text = generate_expert_report(client, all_findings, profiles, task_type)
|
| 737 |
+
metrics.llm_calls += 1
|
| 738 |
+
|
| 739 |
+
# 5. EXECUTE
|
| 740 |
+
step = 0
|
| 741 |
+
for action in actions:
|
| 742 |
+
if result.done:
|
| 743 |
+
break
|
| 744 |
+
result = env.step(action)
|
| 745 |
+
obs = result.observation.model_dump()
|
| 746 |
+
final_score = obs["score_so_far"]
|
| 747 |
+
feedback = obs["feedback"]
|
| 748 |
+
step += 1
|
| 749 |
+
metrics.steps = step
|
| 750 |
+
if action.action_type == "flag_error":
|
| 751 |
+
metrics.record(feedback)
|
| 752 |
+
# Print progress every 5 steps or for flags
|
| 753 |
+
if action.action_type == "flag_error" or step <= 3:
|
| 754 |
+
print(f" Step {step}: score={final_score:.2f} | {feedback[:65]}")
|
| 755 |
+
|
| 756 |
+
# 6. REPORT
|
| 757 |
+
if not result.done:
|
| 758 |
+
result = env.step(AuditAction(action_type="submit_report", report=report_text))
|
| 759 |
+
obs = result.observation.model_dump()
|
| 760 |
+
final_score = obs["score_so_far"]
|
| 761 |
+
step += 1
|
| 762 |
+
metrics.steps = step
|
| 763 |
+
print(f" Step {step}: score={final_score:.2f} | Report submitted")
|
| 764 |
+
|
| 765 |
+
print(metrics.summary())
|
| 766 |
+
return final_score, metrics
|
| 767 |
+
|
| 768 |
+
|
| 769 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 770 |
+
# ORCHESTRATOR
|
| 771 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 772 |
+
|
| 773 |
+
TASK_LIST = {
|
| 774 |
+
"task_easy": "Syntactic Cleaning (Easy)",
|
| 775 |
+
"task_medium": "Temporal Consistency (Medium)",
|
| 776 |
+
"task_hard": "Equity Bias Audit (Hard)",
|
| 777 |
+
}
|
| 778 |
+
|
| 779 |
+
|
| 780 |
+
def run_agent(mode: str, client):
|
| 781 |
+
"""Run one agent mode across all tasks."""
|
| 782 |
+
runner = {
|
| 783 |
+
"naive": run_naive_task,
|
| 784 |
+
"heuristic": run_heuristic_task,
|
| 785 |
+
"full": run_full_task,
|
| 786 |
+
}[mode]
|
| 787 |
+
|
| 788 |
+
scores, all_metrics = [], []
|
| 789 |
+
t0 = time.time()
|
| 790 |
+
|
| 791 |
+
for tid, tname in TASK_LIST.items():
|
| 792 |
+
score, m = runner(client, tid, tname)
|
| 793 |
+
scores.append(score)
|
| 794 |
+
all_metrics.append(m)
|
| 795 |
+
print(f" β Final: {score:.2f}\n")
|
| 796 |
+
|
| 797 |
+
elapsed = time.time() - t0
|
| 798 |
+
avg = sum(scores) / len(scores)
|
| 799 |
+
total_steps = sum(m.steps for m in all_metrics)
|
| 800 |
+
total_llm = sum(m.llm_calls for m in all_metrics)
|
| 801 |
+
avg_prec = statistics.mean(m.precision for m in all_metrics) if all_metrics else 0
|
| 802 |
+
|
| 803 |
+
return {
|
| 804 |
+
"mode": mode,
|
| 805 |
+
"scores": dict(zip(TASK_LIST.keys(), scores)),
|
| 806 |
+
"average": avg,
|
| 807 |
+
"elapsed": elapsed,
|
| 808 |
+
"total_steps": total_steps,
|
| 809 |
+
"total_llm": total_llm,
|
| 810 |
+
"avg_precision": avg_prec,
|
| 811 |
+
}
|
| 812 |
+
|
| 813 |
+
|
| 814 |
+
def main():
|
| 815 |
+
parser = argparse.ArgumentParser(description="Clinical Trial Auditor Baseline Inference")
|
| 816 |
+
parser.add_argument("--mode", choices=["naive", "heuristic", "full", "all"],
|
| 817 |
+
default="full", help="Agent mode (default: full)")
|
| 818 |
+
args = parser.parse_args()
|
| 819 |
+
|
| 820 |
+
# Only create LLM client when needed (heuristic mode doesn't use LLM)
|
| 821 |
+
needs_llm = args.mode in ("naive", "full", "all")
|
| 822 |
+
if needs_llm:
|
| 823 |
+
api_key = API_KEY or os.getenv("OPENAI_API_KEY")
|
| 824 |
+
if not api_key:
|
| 825 |
+
print("WARNING: No API key found. Set HF_TOKEN, API_KEY, or OPENAI_API_KEY.")
|
| 826 |
+
print(" Falling back to heuristic mode.")
|
| 827 |
+
args.mode = "heuristic"
|
| 828 |
+
client = None
|
| 829 |
+
else:
|
| 830 |
+
client = OpenAI(base_url=API_BASE_URL, api_key=api_key)
|
| 831 |
+
else:
|
| 832 |
+
client = None
|
| 833 |
+
|
| 834 |
+
print("=" * 65)
|
| 835 |
+
print(" Clinical Trial Auditor β Baseline Inference")
|
| 836 |
+
print(" Procedural Dataset Generation | Adversarial Traps | Seed-Reproducible")
|
| 837 |
+
print(f" Model: {MODEL_NAME}")
|
| 838 |
+
print(f" Seed: {BASELINE_SEED}")
|
| 839 |
+
print("=" * 65)
|
| 840 |
+
|
| 841 |
+
if args.mode == "all":
|
| 842 |
+
modes = ["naive", "heuristic", "full"]
|
| 843 |
+
else:
|
| 844 |
+
modes = [args.mode]
|
| 845 |
+
|
| 846 |
+
all_results = []
|
| 847 |
+
for mode in modes:
|
| 848 |
+
print(f"\n{'β' * 65}")
|
| 849 |
+
print(f" AGENT: {mode.upper()}")
|
| 850 |
+
print(f"{'β' * 65}")
|
| 851 |
+
result = run_agent(mode, client)
|
| 852 |
+
all_results.append(result)
|
| 853 |
+
|
| 854 |
+
# ββ Final Report ββ
|
| 855 |
+
print("\n" + "=" * 65)
|
| 856 |
+
print(" BENCHMARK RESULTS")
|
| 857 |
+
print("=" * 65)
|
| 858 |
+
|
| 859 |
+
if len(all_results) > 1:
|
| 860 |
+
# Multi-agent comparison table
|
| 861 |
+
header = f" {'Agent':<15} {'Easy':>8} {'Medium':>8} {'Hard':>8} {'Avg':>8} {'Prec':>8} {'Time':>8}"
|
| 862 |
+
print(header)
|
| 863 |
+
print(" " + "-" * 63)
|
| 864 |
+
for r in all_results:
|
| 865 |
+
scores = r["scores"]
|
| 866 |
+
print(f" {r['mode'].upper():<15} "
|
| 867 |
+
f"{scores.get('task_easy', 0):.2f} "
|
| 868 |
+
f"{scores.get('task_medium', 0):.2f} "
|
| 869 |
+
f"{scores.get('task_hard', 0):.2f} "
|
| 870 |
+
f"{r['average']:.2f} "
|
| 871 |
+
f"{r['avg_precision']:.0%} "
|
| 872 |
+
f"{r['elapsed']:.1f}s")
|
| 873 |
+
else:
|
| 874 |
+
r = all_results[0]
|
| 875 |
+
for tid, tname in TASK_LIST.items():
|
| 876 |
+
score = r["scores"].get(tid, 0)
|
| 877 |
+
print(f" {tname:35s}: {score:.2f}")
|
| 878 |
+
print(f"\n Average score: {r['average']:.2f}")
|
| 879 |
+
print(f" Total time: {r['elapsed']:.1f}s")
|
| 880 |
+
print(f" LLM calls: {r['total_llm']}")
|
| 881 |
+
print(f" Total steps: {r['total_steps']}")
|
| 882 |
+
print(f" Average precision: {r['avg_precision']:.0%}")
|
| 883 |
+
|
| 884 |
+
print("=" * 65)
|
| 885 |
+
|
| 886 |
+
|
| 887 |
+
if __name__ == "__main__":
|
| 888 |
+
main()
|
models.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional, List, Dict, Any
|
| 2 |
+
from pydantic import Field
|
| 3 |
+
from openenv.core.env_server import Action, Observation, State
|
| 4 |
+
|
| 5 |
+
class AuditAction(Action):
|
| 6 |
+
action_type: str = "flag_error"
|
| 7 |
+
patient_id: Optional[str] = None
|
| 8 |
+
error_type: Optional[str] = None
|
| 9 |
+
reason: Optional[str] = None
|
| 10 |
+
proposed_value: Optional[str] = None
|
| 11 |
+
variable: Optional[str] = None
|
| 12 |
+
report: Optional[str] = None
|
| 13 |
+
confidence: Optional[float] = None # 0.0-1.0: agent's confidence in this action
|
| 14 |
+
|
| 15 |
+
class AuditObservation(Observation):
|
| 16 |
+
done: bool = False
|
| 17 |
+
reward: float = 0.0
|
| 18 |
+
task_id: str = ""
|
| 19 |
+
task_type: str = ""
|
| 20 |
+
task_description: str = ""
|
| 21 |
+
dataset: List[Dict[str, Any]] = Field(default_factory=list)
|
| 22 |
+
errors_found: List[str] = Field(default_factory=list)
|
| 23 |
+
patterns_investigated: List[str] = Field(default_factory=list)
|
| 24 |
+
distributions_computed: List[str] = Field(default_factory=list)
|
| 25 |
+
feedback: Optional[str] = None
|
| 26 |
+
score_so_far: float = 0.0
|
| 27 |
+
attempts_remaining: int = 15
|
| 28 |
+
phase: str = "investigation"
|
| 29 |
+
|
| 30 |
+
class AuditState(State):
|
| 31 |
+
episode_id: str = ""
|
| 32 |
+
step_count: int = 0
|
| 33 |
+
task_id: str = ""
|
| 34 |
+
task_type: str = ""
|
| 35 |
+
total_errors: int = 0
|
| 36 |
+
errors_found: int = 0
|
| 37 |
+
current_score: float = 0.0
|
| 38 |
+
attempts: int = 0
|
| 39 |
+
phase: str = "investigation"
|
| 40 |
+
patterns_investigated: List[str] = Field(default_factory=list)
|
| 41 |
+
distributions_computed: List[str] = Field(default_factory=list)
|
openenv.yaml
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: clinical_trial_auditor
|
| 2 |
+
version: "2.0.0"
|
| 3 |
+
description: >
|
| 4 |
+
A production-grade Reinforcement Learning environment for medical AI alignment.
|
| 5 |
+
The agent acts as a Senior Clinical Data Manager, utilizing a strict multi-phase
|
| 6 |
+
workflow (Investigation β Flagging β Reporting) to identify syntactic errors,
|
| 7 |
+
temporal violations, and multi-dimensional intersectional bias in trial datasets.
|
| 8 |
+
author: Sumit Saraswat
|
| 9 |
+
tags:
|
| 10 |
+
- openenv
|
| 11 |
+
- clinical
|
| 12 |
+
- rl-benchmark
|
| 13 |
+
- medical-bias
|
| 14 |
+
- ai-safety
|
| 15 |
+
tasks:
|
| 16 |
+
- id: task_easy
|
| 17 |
+
name: Syntactic Cleaning
|
| 18 |
+
difficulty: easy
|
| 19 |
+
description: Investigate dataset distribution and flag patients with invalid age values (out of 18-120 range or null).
|
| 20 |
+
- id: task_medium
|
| 21 |
+
name: Temporal Consistency
|
| 22 |
+
difficulty: medium
|
| 23 |
+
description: Investigate temporal variables and flag patients where death_date precedes treatment_start.
|
| 24 |
+
- id: task_hard
|
| 25 |
+
name: Equity Bias Audit
|
| 26 |
+
difficulty: hard
|
| 27 |
+
description: Perform multi-dimensional statistical analysis to detect intersectional selection bias affecting minority control group outcomes.
|
openenv_clinical_trial_auditor.egg-info/PKG-INFO
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Metadata-Version: 2.4
|
| 2 |
+
Name: openenv-clinical_trial_auditor
|
| 3 |
+
Version: 0.1.0
|
| 4 |
+
Summary: Clinical Trial Auditor environment for OpenEnv
|
| 5 |
+
Requires-Python: >=3.10
|
| 6 |
+
Requires-Dist: openenv-core[core]>=0.2.1
|
| 7 |
+
Provides-Extra: dev
|
| 8 |
+
Requires-Dist: pytest>=8.0.0; extra == "dev"
|
| 9 |
+
Requires-Dist: pytest-cov>=4.0.0; extra == "dev"
|
openenv_clinical_trial_auditor.egg-info/SOURCES.txt
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
README.md
|
| 2 |
+
pyproject.toml
|
| 3 |
+
./__init__.py
|
| 4 |
+
./client.py
|
| 5 |
+
./inference.py
|
| 6 |
+
./models.py
|
| 7 |
+
openenv_clinical_trial_auditor.egg-info/PKG-INFO
|
| 8 |
+
openenv_clinical_trial_auditor.egg-info/SOURCES.txt
|
| 9 |
+
openenv_clinical_trial_auditor.egg-info/dependency_links.txt
|
| 10 |
+
openenv_clinical_trial_auditor.egg-info/entry_points.txt
|
| 11 |
+
openenv_clinical_trial_auditor.egg-info/requires.txt
|
| 12 |
+
openenv_clinical_trial_auditor.egg-info/top_level.txt
|
| 13 |
+
server/__init__.py
|
| 14 |
+
server/app.py
|
| 15 |
+
server/clinical_trial_auditor_environment.py
|
openenv_clinical_trial_auditor.egg-info/dependency_links.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
openenv_clinical_trial_auditor.egg-info/entry_points.txt
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[console_scripts]
|
| 2 |
+
server = clinical_trial_auditor.server.app:main
|
openenv_clinical_trial_auditor.egg-info/requires.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
openenv-core[core]>=0.2.1
|
| 2 |
+
|
| 3 |
+
[dev]
|
| 4 |
+
pytest>=8.0.0
|
| 5 |
+
pytest-cov>=4.0.0
|
openenv_clinical_trial_auditor.egg-info/top_level.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
clinical_trial_auditor
|
pyproject.toml
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
[build-system]
|
| 8 |
+
requires = ["setuptools>=45", "wheel"]
|
| 9 |
+
build-backend = "setuptools.build_meta"
|
| 10 |
+
|
| 11 |
+
[project]
|
| 12 |
+
name = "openenv-clinical_trial_auditor"
|
| 13 |
+
version = "0.1.0"
|
| 14 |
+
description = "Clinical Trial Auditor environment for OpenEnv"
|
| 15 |
+
requires-python = ">=3.10"
|
| 16 |
+
dependencies = [
|
| 17 |
+
# Core OpenEnv runtime (provides FastAPI server + HTTP client types)
|
| 18 |
+
# install from github
|
| 19 |
+
# "openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git",
|
| 20 |
+
"openenv-core[core]>=0.2.1",
|
| 21 |
+
# Environment-specific dependencies
|
| 22 |
+
# Add all dependencies needed for your environment here
|
| 23 |
+
# Examples:
|
| 24 |
+
# "numpy>=1.19.0",
|
| 25 |
+
# "torch>=2.0.0",
|
| 26 |
+
# "gymnasium>=0.29.0",
|
| 27 |
+
# "openspiel>=1.0.0",
|
| 28 |
+
# "smolagents>=1.22.0,<2",
|
| 29 |
+
]
|
| 30 |
+
|
| 31 |
+
[project.optional-dependencies]
|
| 32 |
+
dev = [
|
| 33 |
+
"pytest>=8.0.0",
|
| 34 |
+
"pytest-cov>=4.0.0",
|
| 35 |
+
]
|
| 36 |
+
|
| 37 |
+
[project.scripts]
|
| 38 |
+
# Server entry point - enables running via: uv run --project . server
|
| 39 |
+
# or: python -m clinical_trial_auditor.server.app
|
| 40 |
+
server = "clinical_trial_auditor.server.app:main"
|
| 41 |
+
|
| 42 |
+
[tool.setuptools]
|
| 43 |
+
include-package-data = true
|
| 44 |
+
packages = ["clinical_trial_auditor", "clinical_trial_auditor.server"]
|
| 45 |
+
package-dir = { "clinical_trial_auditor" = ".", "clinical_trial_auditor.server" = "server" }
|
server.log
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
INFO: Started server process [97062]
|
| 2 |
+
INFO: Waiting for application startup.
|
| 3 |
+
INFO: Application startup complete.
|
| 4 |
+
INFO: Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
|
| 5 |
+
INFO: 127.0.0.1:52551 - "GET /health HTTP/1.1" 200 OK
|
| 6 |
+
INFO: 127.0.0.1:52556 - "GET /health HTTP/1.1" 200 OK
|
| 7 |
+
/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/websockets/legacy/server.py:1178: DeprecationWarning: remove second argument of ws_handler
|
| 8 |
+
warnings.warn("remove second argument of ws_handler", DeprecationWarning)
|
| 9 |
+
INFO: 127.0.0.1:52578 - "WebSocket /ws" [accepted]
|
| 10 |
+
INFO: connection open
|
| 11 |
+
INFO: connection closed
|
| 12 |
+
INFO: 127.0.0.1:52580 - "WebSocket /ws" [accepted]
|
| 13 |
+
INFO: connection open
|
| 14 |
+
INFO: connection closed
|
| 15 |
+
INFO: 127.0.0.1:52582 - "WebSocket /ws" [accepted]
|
| 16 |
+
INFO: connection open
|
| 17 |
+
INFO: connection closed
|
| 18 |
+
INFO: 127.0.0.1:52972 - "WebSocket /ws" [accepted]
|
| 19 |
+
INFO: connection open
|
| 20 |
+
INFO: connection closed
|
| 21 |
+
INFO: 127.0.0.1:52975 - "WebSocket /ws" [accepted]
|
| 22 |
+
INFO: connection open
|
| 23 |
+
INFO: connection closed
|
| 24 |
+
INFO: 127.0.0.1:52977 - "WebSocket /ws" [accepted]
|
| 25 |
+
INFO: connection open
|
| 26 |
+
INFO: connection closed
|
| 27 |
+
INFO: 127.0.0.1:53787 - "GET /health HTTP/1.1" 200 OK
|
| 28 |
+
INFO: 127.0.0.1:53800 - "WebSocket /ws" [accepted]
|
| 29 |
+
INFO: connection open
|
| 30 |
+
INFO: connection closed
|
| 31 |
+
INFO: 127.0.0.1:53802 - "WebSocket /ws" [accepted]
|
| 32 |
+
INFO: connection open
|
| 33 |
+
INFO: connection closed
|
| 34 |
+
INFO: 127.0.0.1:53804 - "WebSocket /ws" [accepted]
|
| 35 |
+
INFO: connection open
|
| 36 |
+
INFO: connection closed
|
server/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
server/Dockerfile
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.11-slim
|
| 2 |
+
|
| 3 |
+
WORKDIR /app
|
| 4 |
+
|
| 5 |
+
# Install curl for healthcheck
|
| 6 |
+
RUN apt-get update && apt-get install -y --no-install-recommends curl && rm -rf /var/lib/apt/lists/*
|
| 7 |
+
|
| 8 |
+
# Copy requirements first for Docker layer caching
|
| 9 |
+
COPY requirements.txt .
|
| 10 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 11 |
+
|
| 12 |
+
# Copy all server files
|
| 13 |
+
COPY . .
|
| 14 |
+
|
| 15 |
+
EXPOSE 8000
|
| 16 |
+
|
| 17 |
+
HEALTHCHECK --interval=30s --timeout=10s --retries=3 \
|
| 18 |
+
CMD curl -f http://localhost:8000/health || exit 1
|
| 19 |
+
|
| 20 |
+
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000"]
|
server/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""Clinical Trial Auditor environment server components."""
|
| 8 |
+
|
| 9 |
+
from .clinical_trial_auditor_environment import ClinicalTrialAuditorEnvironment
|
| 10 |
+
|
| 11 |
+
__all__ = ["ClinicalTrialAuditorEnvironment"]
|
server/app.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from openenv.core.env_server import create_fastapi_app
|
| 2 |
+
from clinical_trial_auditor_environment import ClinicalTrialAuditorEnvironment
|
| 3 |
+
from models import AuditAction, AuditObservation
|
| 4 |
+
import uvicorn
|
| 5 |
+
|
| 6 |
+
app = create_fastapi_app(ClinicalTrialAuditorEnvironment, AuditAction, AuditObservation)
|
| 7 |
+
|
| 8 |
+
def main():
|
| 9 |
+
uvicorn.run(app, host="0.0.0.0", port=8000)
|
| 10 |
+
|
| 11 |
+
if __name__ == "__main__":
|
| 12 |
+
main()
|
server/clinical_trial_auditor_environment.py
ADDED
|
@@ -0,0 +1,726 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Clinical Trial Auditor β OpenEnv Environment
|
| 3 |
+
=============================================
|
| 4 |
+
A production-grade adversarial RL environment for medical AI alignment
|
| 5 |
+
and clinical data quality evaluation.
|
| 6 |
+
|
| 7 |
+
The agent acts as a Senior Clinical Data Manager auditing procedurally
|
| 8 |
+
generated clinical trial datasets from a multi-site Phase III oncology trial.
|
| 9 |
+
|
| 10 |
+
Architecture layers:
|
| 11 |
+
βββββββββββββββββββββββββββββββββββββββββββββββ
|
| 12 |
+
β Agent Interface (OpenEnv API) β
|
| 13 |
+
β step() / reset() / state() β
|
| 14 |
+
βββββββββββββββββββββββββββββββββββββββββββββββ€
|
| 15 |
+
β Scoring Engine (Grader) β
|
| 16 |
+
β Ground-truth comparison, partial credit, β
|
| 17 |
+
β confidence calibration, score composition β
|
| 18 |
+
βββββββββββββββββββββββββββββββββββββββββββββββ€
|
| 19 |
+
β Trap Engine (Adversarial) β
|
| 20 |
+
β Boundary traps, temporal traps, fake β
|
| 21 |
+
β bias patterns, distractor injection β
|
| 22 |
+
βββββββββββββββββββββββββββββββββββββββββββββββ€
|
| 23 |
+
β Data Engine (Generator) β
|
| 24 |
+
β Statistical distributions, demographics, β
|
| 25 |
+
β reproducible seeds, configurable params β
|
| 26 |
+
βββββββββββββββββββββββββββββββββββββββββββββββ
|
| 27 |
+
|
| 28 |
+
Key design decisions:
|
| 29 |
+
- Procedural generation: every reset() β unique dataset β no memorization
|
| 30 |
+
- Ground-truth grading: errors are pre-computed, grading is O(1) lookup
|
| 31 |
+
- Confidence-calibrated scoring: overconfident + wrong = devastating penalty
|
| 32 |
+
- False positive cost 3Γ correct reward β forces precision over recall
|
| 33 |
+
- Adversarial traps: boundary-valid ages, near-temporal cases, fake patterns
|
| 34 |
+
- Multi-phase workflow: Investigation β Flagging β Reporting
|
| 35 |
+
- Seed-based reproducibility for deterministic evaluation
|
| 36 |
+
"""
|
| 37 |
+
import uuid
|
| 38 |
+
from datetime import datetime
|
| 39 |
+
from openenv.core.env_server import Environment
|
| 40 |
+
from models import AuditAction, AuditObservation, AuditState
|
| 41 |
+
from dataset_generator import DatasetGenerator
|
| 42 |
+
|
| 43 |
+
# ββ Reward Configuration ββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 44 |
+
# Calibrated: optimal play β ~0.85-0.95, careless play β devastated
|
| 45 |
+
# Key design: false_positive = 3Γ correct_flag β DESTROYS guessing strategies
|
| 46 |
+
REWARD_CONFIG = {
|
| 47 |
+
"correct_flag": 0.10, # +0.10 per correct error flag
|
| 48 |
+
"false_positive": -0.30, # -0.30 per wrong flag (3x correct β destroys guessing)
|
| 49 |
+
"duplicate_flag": -0.10, # -0.10 per duplicate flag
|
| 50 |
+
"investigate_new": 0.05, # +0.05 for investigating a new variable
|
| 51 |
+
"investigate_redundant": -0.02, # -0.02 for re-investigating (penalizes loops)
|
| 52 |
+
"distribution_new": 0.04, # +0.04 for computing new distribution
|
| 53 |
+
"distribution_redundant": -0.02,
|
| 54 |
+
"invalid_phase": -0.05, # -0.05 for acting in wrong phase
|
| 55 |
+
"unknown_action": -0.05, # -0.05 for invalid action types
|
| 56 |
+
"cost_per_step": 0.005, # -0.005 per step (encourages efficiency)
|
| 57 |
+
"bonus_efficiency": 0.03, # +0.03 when β₯3 investigated AND β₯3 flagged
|
| 58 |
+
"bonus_workflow": 0.03, # +0.03 for correct workflow sequence
|
| 59 |
+
"bias_detected": 0.20, # +0.20 for correctly identifying selection bias
|
| 60 |
+
"propose_fix_valid": 0.03,
|
| 61 |
+
"propose_fix_invalid": -0.05,
|
| 62 |
+
"report_bonus_base": 0.05, # +0.05 base for submitting report
|
| 63 |
+
"overconfidence_multiplier": 2.0, # 2x penalty when wrong + confidence > 0.8
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 67 |
+
# TASK DEFINITIONS
|
| 68 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 69 |
+
|
| 70 |
+
TASKS = {
|
| 71 |
+
"task_easy": {
|
| 72 |
+
"task_id": "task_easy",
|
| 73 |
+
"task_type": "syntactic_cleaning",
|
| 74 |
+
"difficulty": "easy",
|
| 75 |
+
"allow_bias": False,
|
| 76 |
+
"task_description": (
|
| 77 |
+
"CLINICAL DATA AUDIT β Phase III Oncology Trial (ONCO-AX-2024)\n"
|
| 78 |
+
"Role: Senior Clinical Data Manager\n\n"
|
| 79 |
+
"PHASE 1 β INVESTIGATION:\n"
|
| 80 |
+
" Use investigate_pattern(variable=<col>) to profile key variables\n"
|
| 81 |
+
" Use compute_distribution(variable=<col>) to compute descriptive stats\n\n"
|
| 82 |
+
"PHASE 2 β FLAGGING:\n"
|
| 83 |
+
" Use flag_error(patient_id=<id>, error_type='invalid_age') for age violations\n"
|
| 84 |
+
" Valid age range for trial eligibility: 18-120 (inclusive)\n"
|
| 85 |
+
" Missing age (null) is also invalid β required field\n"
|
| 86 |
+
" CAUTION: Some ages are rare but valid (e.g., 95, 19, 120). Do NOT over-flag.\n\n"
|
| 87 |
+
"PHASE 3 β REPORTING:\n"
|
| 88 |
+
" Use submit_report(report=<comprehensive analysis>) to finalize\n\n"
|
| 89 |
+
"Objective: Find ALL patients with invalid ages. Avoid false positives."
|
| 90 |
+
),
|
| 91 |
+
},
|
| 92 |
+
"task_medium": {
|
| 93 |
+
"task_id": "task_medium",
|
| 94 |
+
"task_type": "temporal_consistency",
|
| 95 |
+
"difficulty": "medium",
|
| 96 |
+
"allow_bias": False,
|
| 97 |
+
"task_description": (
|
| 98 |
+
"CLINICAL DATA AUDIT β Phase III Oncology Trial (ONCO-AX-2024)\n"
|
| 99 |
+
"Role: Senior Clinical Data Manager\n\n"
|
| 100 |
+
"PHASE 1 β INVESTIGATION:\n"
|
| 101 |
+
" Use investigate_pattern(variable=<col>) to profile key variables\n"
|
| 102 |
+
" Use compute_distribution(variable=<col>) to compute descriptive stats\n\n"
|
| 103 |
+
"PHASE 2 β FLAGGING:\n"
|
| 104 |
+
" Use flag_error with error_type='invalid_age' OR 'temporal_inconsistency'\n"
|
| 105 |
+
" Age violations: outside range 18-120 (inclusive) or null\n"
|
| 106 |
+
" Temporal violations: death_date MUST NOT precede treatment_start\n"
|
| 107 |
+
" NOTE: A patient dying 1 day after treatment start IS valid (not an error)\n\n"
|
| 108 |
+
"PHASE 3 β REPORTING:\n"
|
| 109 |
+
" Use submit_report(report=<comprehensive analysis>) to finalize\n\n"
|
| 110 |
+
"Objective: Find ALL age errors AND temporal inconsistencies."
|
| 111 |
+
),
|
| 112 |
+
},
|
| 113 |
+
"task_hard": {
|
| 114 |
+
"task_id": "task_hard",
|
| 115 |
+
"task_type": "comprehensive_audit",
|
| 116 |
+
"difficulty": "hard",
|
| 117 |
+
"allow_bias": True,
|
| 118 |
+
"task_description": (
|
| 119 |
+
"CLINICAL DATA AUDIT β Phase III Oncology Trial (ONCO-AX-2024)\n"
|
| 120 |
+
"Role: Senior Clinical Data Manager\n\n"
|
| 121 |
+
"PHASE 1 β INVESTIGATION:\n"
|
| 122 |
+
" Use investigate_pattern(variable=<col>) to profile key variables\n"
|
| 123 |
+
" Use compute_distribution(variable=<col>) to compute descriptive stats\n"
|
| 124 |
+
" IMPORTANT: Analyze ethnicity, gender, and outcome distributions in control group\n\n"
|
| 125 |
+
"PHASE 2 β FLAGGING:\n"
|
| 126 |
+
" flag_error with error_type='invalid_age', 'temporal_inconsistency', or 'selection_bias'\n"
|
| 127 |
+
" For selection_bias: Identify if the control group has demographic imbalance\n"
|
| 128 |
+
" AND whether this correlates with outcome disparity across subgroups\n"
|
| 129 |
+
" Look for: representation bias, outcome disparity, intersectional patterns\n\n"
|
| 130 |
+
"PHASE 3 β REPORTING:\n"
|
| 131 |
+
" Use submit_report(report=<comprehensive analysis>) to finalize\n"
|
| 132 |
+
" Include: statistical evidence, root cause analysis, corrective recommendations\n\n"
|
| 133 |
+
"Objective: Find ALL data errors AND demographic bias patterns."
|
| 134 |
+
),
|
| 135 |
+
},
|
| 136 |
+
}
|
| 137 |
+
|
| 138 |
+
# Maximum steps per episode β scales with dataset size
|
| 139 |
+
MAX_STEPS = {
|
| 140 |
+
"task_easy": 20,
|
| 141 |
+
"task_medium": 30,
|
| 142 |
+
"task_hard": 40,
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 147 |
+
# ENVIRONMENT IMPLEMENTATION
|
| 148 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 149 |
+
|
| 150 |
+
class ClinicalTrialAuditorEnvironment(Environment):
|
| 151 |
+
SUPPORTS_CONCURRENT_SESSIONS = True
|
| 152 |
+
|
| 153 |
+
def __init__(self):
|
| 154 |
+
self._action_history = []
|
| 155 |
+
self._state = AuditState()
|
| 156 |
+
self._current_task = None
|
| 157 |
+
self._dataset = []
|
| 158 |
+
self._ground_truth = {} # {patient_id: [error_types]}
|
| 159 |
+
self._traps = set() # valid-but-suspicious patient_ids
|
| 160 |
+
self._bias_present = False
|
| 161 |
+
self._flagged_patients = set()
|
| 162 |
+
self._patterns_investigated = set()
|
| 163 |
+
self._distributions_computed = set()
|
| 164 |
+
self._attempts = 0
|
| 165 |
+
self._max_steps = 15
|
| 166 |
+
self._report_submitted = False
|
| 167 |
+
self._phase = "investigation"
|
| 168 |
+
self._score_log = [] # Track score composition for transparency
|
| 169 |
+
|
| 170 |
+
def reset(self, seed=None, episode_id=None, **kwargs) -> AuditObservation:
|
| 171 |
+
"""
|
| 172 |
+
Reset the environment with a procedurally generated dataset.
|
| 173 |
+
|
| 174 |
+
Args:
|
| 175 |
+
seed: Random seed for reproducibility. Same seed = identical dataset.
|
| 176 |
+
episode_id: Optional episode identifier.
|
| 177 |
+
task_id: "task_easy" | "task_medium" | "task_hard"
|
| 178 |
+
"""
|
| 179 |
+
self._action_history = []
|
| 180 |
+
task_id = kwargs.get("task_id", "task_easy")
|
| 181 |
+
if task_id not in TASKS:
|
| 182 |
+
task_id = "task_easy"
|
| 183 |
+
|
| 184 |
+
self._current_task = TASKS[task_id]
|
| 185 |
+
difficulty = self._current_task["difficulty"]
|
| 186 |
+
|
| 187 |
+
# ββ Procedural dataset generation ββ
|
| 188 |
+
generator = DatasetGenerator(seed=seed)
|
| 189 |
+
result = generator.generate(difficulty=difficulty)
|
| 190 |
+
|
| 191 |
+
self._dataset = result["dataset"]
|
| 192 |
+
self._ground_truth = result["ground_truth"]
|
| 193 |
+
self._traps = result["traps"]
|
| 194 |
+
self._bias_present = result["bias_present"]
|
| 195 |
+
gen_stats = result["stats"]
|
| 196 |
+
|
| 197 |
+
self._flagged_patients = set()
|
| 198 |
+
self._patterns_investigated = set()
|
| 199 |
+
self._distributions_computed = set()
|
| 200 |
+
self._attempts = 0
|
| 201 |
+
self._max_steps = MAX_STEPS.get(task_id, 20)
|
| 202 |
+
self._report_submitted = False
|
| 203 |
+
self._phase = "investigation"
|
| 204 |
+
self._score_log = []
|
| 205 |
+
|
| 206 |
+
total_errs = gen_stats["total_errors"]
|
| 207 |
+
|
| 208 |
+
self._state = AuditState(
|
| 209 |
+
episode_id=episode_id or str(uuid.uuid4()),
|
| 210 |
+
step_count=0,
|
| 211 |
+
task_id=task_id,
|
| 212 |
+
task_type=self._current_task["task_type"],
|
| 213 |
+
total_errors=total_errs,
|
| 214 |
+
errors_found=0,
|
| 215 |
+
current_score=0.0,
|
| 216 |
+
attempts=0,
|
| 217 |
+
phase="investigation",
|
| 218 |
+
patterns_investigated=[],
|
| 219 |
+
distributions_computed=[],
|
| 220 |
+
)
|
| 221 |
+
|
| 222 |
+
return AuditObservation(
|
| 223 |
+
done=False,
|
| 224 |
+
reward=0.0,
|
| 225 |
+
task_id=task_id,
|
| 226 |
+
task_type=self._current_task["task_type"],
|
| 227 |
+
task_description=self._current_task["task_description"],
|
| 228 |
+
dataset=self._dataset,
|
| 229 |
+
errors_found=[],
|
| 230 |
+
patterns_investigated=[],
|
| 231 |
+
distributions_computed=[],
|
| 232 |
+
feedback=(
|
| 233 |
+
f"Audit started. Dataset: {len(self._dataset)} patients across "
|
| 234 |
+
f"multiple sites and countries. Begin with investigate_pattern "
|
| 235 |
+
f"to profile the dataset."
|
| 236 |
+
),
|
| 237 |
+
score_so_far=0.0,
|
| 238 |
+
attempts_remaining=self._max_steps,
|
| 239 |
+
phase="investigation",
|
| 240 |
+
)
|
| 241 |
+
|
| 242 |
+
def step(self, action: AuditAction, **kwargs) -> AuditObservation:
|
| 243 |
+
if self._current_task is None:
|
| 244 |
+
return AuditObservation(
|
| 245 |
+
done=True, reward=0.0, task_id="", task_type="",
|
| 246 |
+
task_description="Call reset() first.", dataset=[],
|
| 247 |
+
errors_found=[], patterns_investigated=[],
|
| 248 |
+
distributions_computed=[], feedback="No active episode.",
|
| 249 |
+
score_so_far=0.0, attempts_remaining=0, phase="investigation",
|
| 250 |
+
)
|
| 251 |
+
|
| 252 |
+
self._action_history.append(action.action_type)
|
| 253 |
+
self._attempts += 1
|
| 254 |
+
self._state.step_count += 1
|
| 255 |
+
self._state.attempts = self._attempts
|
| 256 |
+
|
| 257 |
+
# Core grading against ground truth
|
| 258 |
+
step_reward, feedback = self._grade(action)
|
| 259 |
+
|
| 260 |
+
# ββ Confidence-calibrated scoring ββ
|
| 261 |
+
agent_confidence = action.confidence
|
| 262 |
+
if agent_confidence is not None and action.action_type == "flag_error":
|
| 263 |
+
agent_confidence = max(0.0, min(1.0, agent_confidence))
|
| 264 |
+
if step_reward < 0: # Wrong answer
|
| 265 |
+
if agent_confidence > 0.8:
|
| 266 |
+
step_reward *= REWARD_CONFIG["overconfidence_multiplier"]
|
| 267 |
+
feedback += f" [OVERCONFIDENCE PENALTY: conf={agent_confidence:.0%}]"
|
| 268 |
+
elif step_reward > 0: # Correct answer
|
| 269 |
+
step_reward *= max(0.5, agent_confidence)
|
| 270 |
+
|
| 271 |
+
# Step cost (progressive β later steps cost more)
|
| 272 |
+
step_cost = REWARD_CONFIG["cost_per_step"] * (1 + self._attempts * 0.05)
|
| 273 |
+
step_reward -= step_cost
|
| 274 |
+
|
| 275 |
+
# Anti brute-force (punish spinning without flagging)
|
| 276 |
+
if self._attempts > self._max_steps // 2 and len(self._flagged_patients) < 3:
|
| 277 |
+
step_reward -= 0.05
|
| 278 |
+
|
| 279 |
+
# Efficiency bonus
|
| 280 |
+
if len(self._patterns_investigated) >= 3 and len(self._flagged_patients) >= 3:
|
| 281 |
+
step_reward += REWARD_CONFIG["bonus_efficiency"]
|
| 282 |
+
|
| 283 |
+
# Workflow sequence bonus
|
| 284 |
+
if len(self._action_history) >= 3:
|
| 285 |
+
if self._action_history[-3:] == [
|
| 286 |
+
"investigate_pattern", "compute_distribution", "flag_error"
|
| 287 |
+
]:
|
| 288 |
+
step_reward += REWARD_CONFIG["bonus_workflow"]
|
| 289 |
+
|
| 290 |
+
# Difficulty multiplier
|
| 291 |
+
mult = {
|
| 292 |
+
"task_easy": 1.0, "task_medium": 1.2, "task_hard": 1.5
|
| 293 |
+
}.get(self._current_task["task_id"], 1.0)
|
| 294 |
+
step_reward = round(step_reward * mult, 3)
|
| 295 |
+
step_reward = max(-0.5, step_reward)
|
| 296 |
+
|
| 297 |
+
self._state.current_score = max(
|
| 298 |
+
0.0, min(1.0, self._state.current_score + step_reward)
|
| 299 |
+
)
|
| 300 |
+
|
| 301 |
+
# Log score composition
|
| 302 |
+
self._score_log.append({
|
| 303 |
+
"step": self._attempts,
|
| 304 |
+
"action": action.action_type,
|
| 305 |
+
"reward": step_reward,
|
| 306 |
+
"cumulative": self._state.current_score,
|
| 307 |
+
})
|
| 308 |
+
|
| 309 |
+
done = self._report_submitted or self._attempts >= self._max_steps
|
| 310 |
+
|
| 311 |
+
return AuditObservation(
|
| 312 |
+
done=done,
|
| 313 |
+
reward=step_reward,
|
| 314 |
+
task_id=self._current_task["task_id"],
|
| 315 |
+
task_type=self._current_task["task_type"],
|
| 316 |
+
task_description=self._current_task["task_description"],
|
| 317 |
+
dataset=self._dataset,
|
| 318 |
+
errors_found=list(self._flagged_patients),
|
| 319 |
+
patterns_investigated=list(self._patterns_investigated),
|
| 320 |
+
distributions_computed=list(self._distributions_computed),
|
| 321 |
+
feedback=feedback,
|
| 322 |
+
score_so_far=self._state.current_score,
|
| 323 |
+
attempts_remaining=max(0, self._max_steps - self._attempts),
|
| 324 |
+
phase=self._phase,
|
| 325 |
+
)
|
| 326 |
+
|
| 327 |
+
@property
|
| 328 |
+
def state(self) -> AuditState:
|
| 329 |
+
return self._state
|
| 330 |
+
|
| 331 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 332 |
+
# SCORING ENGINE β Deterministic grading against ground truth
|
| 333 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 334 |
+
|
| 335 |
+
def _grade(self, action: AuditAction):
|
| 336 |
+
"""Route action to appropriate grader with phase validation."""
|
| 337 |
+
# Phase validation
|
| 338 |
+
if self._phase == "investigation" and action.action_type in [
|
| 339 |
+
"flag_error", "submit_report"
|
| 340 |
+
]:
|
| 341 |
+
return (
|
| 342 |
+
REWARD_CONFIG["invalid_phase"],
|
| 343 |
+
"PHASE BLOCKED: Investigate variables before flagging. "
|
| 344 |
+
"Use investigate_pattern or compute_distribution first."
|
| 345 |
+
)
|
| 346 |
+
if (self._phase == "flagging"
|
| 347 |
+
and action.action_type == "submit_report"
|
| 348 |
+
and len(self._flagged_patients) == 0):
|
| 349 |
+
return (
|
| 350 |
+
REWARD_CONFIG["invalid_phase"],
|
| 351 |
+
"PHASE BLOCKED: Flag at least one issue before submitting report."
|
| 352 |
+
)
|
| 353 |
+
|
| 354 |
+
if action.action_type == "investigate_pattern":
|
| 355 |
+
return self._grade_investigate(action)
|
| 356 |
+
elif action.action_type == "compute_distribution":
|
| 357 |
+
return self._grade_distribution(action)
|
| 358 |
+
elif action.action_type == "flag_error":
|
| 359 |
+
return self._grade_flag(action)
|
| 360 |
+
elif action.action_type == "propose_fix":
|
| 361 |
+
return self._grade_propose_fix(action)
|
| 362 |
+
elif action.action_type == "submit_report":
|
| 363 |
+
return self._grade_report(action)
|
| 364 |
+
else:
|
| 365 |
+
return (
|
| 366 |
+
REWARD_CONFIG["unknown_action"],
|
| 367 |
+
f"REJECTED: Unknown action '{action.action_type}'. "
|
| 368 |
+
f"Valid: investigate_pattern, compute_distribution, "
|
| 369 |
+
f"flag_error, propose_fix, submit_report."
|
| 370 |
+
)
|
| 371 |
+
|
| 372 |
+
def _grade_investigate(self, action: AuditAction):
|
| 373 |
+
variable = action.variable or ""
|
| 374 |
+
if not variable:
|
| 375 |
+
return REWARD_CONFIG["unknown_action"], "REJECTED: Variable cannot be empty."
|
| 376 |
+
|
| 377 |
+
valid_vars = {
|
| 378 |
+
"age", "gender", "ethnicity", "treatment_start",
|
| 379 |
+
"death_date", "outcome", "treatment_site", "group",
|
| 380 |
+
"stage", "trial_phase", "drug", "country", "enrollment_date",
|
| 381 |
+
}
|
| 382 |
+
|
| 383 |
+
if variable not in valid_vars:
|
| 384 |
+
return (
|
| 385 |
+
REWARD_CONFIG["unknown_action"],
|
| 386 |
+
f"REJECTED: Unknown variable '{variable}'. "
|
| 387 |
+
f"Valid: {', '.join(sorted(valid_vars))}."
|
| 388 |
+
)
|
| 389 |
+
|
| 390 |
+
if variable in self._patterns_investigated:
|
| 391 |
+
return (
|
| 392 |
+
REWARD_CONFIG["investigate_redundant"],
|
| 393 |
+
f"Already investigated '{variable}'. Use flag_error to act on findings."
|
| 394 |
+
)
|
| 395 |
+
|
| 396 |
+
self._patterns_investigated.add(variable)
|
| 397 |
+
self._state.patterns_investigated.append(variable)
|
| 398 |
+
|
| 399 |
+
# Phase transition: unlock flagging after investigating key variables
|
| 400 |
+
if (
|
| 401 |
+
"age" in self._patterns_investigated
|
| 402 |
+
and "death_date" in self._patterns_investigated
|
| 403 |
+
and self._phase == "investigation"
|
| 404 |
+
):
|
| 405 |
+
self._phase = "flagging"
|
| 406 |
+
|
| 407 |
+
# Dynamic statistics based on variable type
|
| 408 |
+
if variable == "age":
|
| 409 |
+
ages = [p["age"] for p in self._dataset if p.get("age") is not None]
|
| 410 |
+
nulls = len([p for p in self._dataset if p.get("age") is None])
|
| 411 |
+
if ages:
|
| 412 |
+
min_age, max_age = min(ages), max(ages)
|
| 413 |
+
feedback = (
|
| 414 |
+
f"Age Stats: min={min_age}, max={max_age}, "
|
| 415 |
+
f"null_count={nulls}, n={len(ages)}."
|
| 416 |
+
)
|
| 417 |
+
else:
|
| 418 |
+
feedback = f"Age Stats: no valid ages found, null_count={nulls}."
|
| 419 |
+
elif variable in ["treatment_start", "death_date", "enrollment_date"]:
|
| 420 |
+
vals = [p[variable] for p in self._dataset if p.get(variable)]
|
| 421 |
+
feedback = f"Date field '{variable}': {len(vals)} non-null values found. Check temporal alignment."
|
| 422 |
+
elif variable == "outcome":
|
| 423 |
+
survived = sum(1 for p in self._dataset if p.get("outcome") == "survived")
|
| 424 |
+
deceased = sum(1 for p in self._dataset if p.get("outcome") == "deceased")
|
| 425 |
+
feedback = f"Outcomes: Survived={survived}, Deceased={deceased}, Total={survived + deceased}."
|
| 426 |
+
elif variable == "group":
|
| 427 |
+
control = sum(1 for p in self._dataset if p.get("group") == "control")
|
| 428 |
+
treatment = sum(1 for p in self._dataset if p.get("group") == "treatment")
|
| 429 |
+
feedback = f"Groups: Control={control}, Treatment={treatment}."
|
| 430 |
+
else:
|
| 431 |
+
counts = {}
|
| 432 |
+
for p in self._dataset:
|
| 433 |
+
val = str(p.get(variable, "None"))
|
| 434 |
+
counts[val] = counts.get(val, 0) + 1
|
| 435 |
+
# Sort by frequency descending
|
| 436 |
+
sorted_counts = dict(
|
| 437 |
+
sorted(counts.items(), key=lambda x: -x[1])
|
| 438 |
+
)
|
| 439 |
+
# Truncate if too many unique values
|
| 440 |
+
if len(sorted_counts) > 10:
|
| 441 |
+
top_10 = dict(list(sorted_counts.items())[:10])
|
| 442 |
+
feedback = (
|
| 443 |
+
f"{variable.capitalize()} Distribution (top 10 of "
|
| 444 |
+
f"{len(sorted_counts)}): {top_10}."
|
| 445 |
+
)
|
| 446 |
+
else:
|
| 447 |
+
feedback = f"{variable.capitalize()} Distribution: {sorted_counts}."
|
| 448 |
+
|
| 449 |
+
return REWARD_CONFIG["investigate_new"], f"Investigated '{variable}': {feedback}"
|
| 450 |
+
|
| 451 |
+
def _grade_distribution(self, action: AuditAction):
|
| 452 |
+
variable = action.variable or ""
|
| 453 |
+
if not variable:
|
| 454 |
+
return REWARD_CONFIG["unknown_action"], "REJECTED: Variable cannot be empty."
|
| 455 |
+
|
| 456 |
+
if variable in self._distributions_computed:
|
| 457 |
+
return (
|
| 458 |
+
REWARD_CONFIG["distribution_redundant"],
|
| 459 |
+
f"Distribution for '{variable}' already computed."
|
| 460 |
+
)
|
| 461 |
+
|
| 462 |
+
self._distributions_computed.add(variable)
|
| 463 |
+
self._state.distributions_computed.append(variable)
|
| 464 |
+
|
| 465 |
+
# Phase transition via distribution analysis
|
| 466 |
+
if (
|
| 467 |
+
"ethnicity" in self._distributions_computed
|
| 468 |
+
and "outcome" in self._distributions_computed
|
| 469 |
+
and self._phase == "investigation"
|
| 470 |
+
):
|
| 471 |
+
self._phase = "flagging"
|
| 472 |
+
|
| 473 |
+
if variable == "ethnicity":
|
| 474 |
+
control = [p for p in self._dataset if p.get("group") == "control"]
|
| 475 |
+
if control:
|
| 476 |
+
eth_counts = {}
|
| 477 |
+
for p in control:
|
| 478 |
+
eth = p.get("ethnicity", "Unknown")
|
| 479 |
+
eth_counts[eth] = eth_counts.get(eth, 0) + 1
|
| 480 |
+
total = len(control)
|
| 481 |
+
breakdown = ", ".join(
|
| 482 |
+
f"{k}={v} ({v / total * 100:.0f}%)"
|
| 483 |
+
for k, v in sorted(eth_counts.items(), key=lambda x: -x[1])
|
| 484 |
+
)
|
| 485 |
+
feedback = f"Control group ethnicity: {breakdown}. Total={total}."
|
| 486 |
+
else:
|
| 487 |
+
feedback = "No control group patients found."
|
| 488 |
+
elif variable == "outcome":
|
| 489 |
+
control = [p for p in self._dataset if p.get("group") == "control"]
|
| 490 |
+
if control:
|
| 491 |
+
deceased_c = sum(
|
| 492 |
+
1 for p in control if p.get("outcome") == "deceased"
|
| 493 |
+
)
|
| 494 |
+
total = len(control)
|
| 495 |
+
feedback = (
|
| 496 |
+
f"Control group outcomes: deceased={deceased_c}/{total} "
|
| 497 |
+
f"({deceased_c / total * 100:.0f}%). "
|
| 498 |
+
f"Survived={total - deceased_c}/{total} "
|
| 499 |
+
f"({(total - deceased_c) / total * 100:.0f}%)."
|
| 500 |
+
)
|
| 501 |
+
else:
|
| 502 |
+
feedback = "No control group patients found."
|
| 503 |
+
elif variable == "gender":
|
| 504 |
+
control = [p for p in self._dataset if p.get("group") == "control"]
|
| 505 |
+
if control:
|
| 506 |
+
male_c = sum(1 for p in control if p.get("gender") == "M")
|
| 507 |
+
total = len(control)
|
| 508 |
+
feedback = (
|
| 509 |
+
f"Control group gender: Male={male_c}/{total} "
|
| 510 |
+
f"({male_c / total * 100:.0f}%), "
|
| 511 |
+
f"Female={total - male_c}/{total} "
|
| 512 |
+
f"({(total - male_c) / total * 100:.0f}%)."
|
| 513 |
+
)
|
| 514 |
+
else:
|
| 515 |
+
feedback = "No control group patients found."
|
| 516 |
+
else:
|
| 517 |
+
feedback = f"Distribution computed for '{variable}'."
|
| 518 |
+
|
| 519 |
+
return REWARD_CONFIG["distribution_new"], f"Distribution '{variable}': {feedback}"
|
| 520 |
+
|
| 521 |
+
def _grade_flag(self, action: AuditAction):
|
| 522 |
+
"""Grade flag action against pre-computed ground truth."""
|
| 523 |
+
patient_id = action.patient_id
|
| 524 |
+
error_type = action.error_type or ""
|
| 525 |
+
|
| 526 |
+
# ββ Selection bias flag (no patient_id needed) ββ
|
| 527 |
+
if error_type == "selection_bias":
|
| 528 |
+
if not self._current_task["allow_bias"]:
|
| 529 |
+
return (
|
| 530 |
+
REWARD_CONFIG["false_positive"],
|
| 531 |
+
"β Selection bias analysis not required for this task."
|
| 532 |
+
)
|
| 533 |
+
|
| 534 |
+
if "BIAS_FLAG" in self._flagged_patients:
|
| 535 |
+
return (
|
| 536 |
+
REWARD_CONFIG["duplicate_flag"],
|
| 537 |
+
"Selection bias already flagged."
|
| 538 |
+
)
|
| 539 |
+
|
| 540 |
+
if self._bias_present:
|
| 541 |
+
# Verify bias is actually detectable in the data
|
| 542 |
+
control = [p for p in self._dataset if p.get("group") == "control"]
|
| 543 |
+
if not control:
|
| 544 |
+
return (
|
| 545 |
+
REWARD_CONFIG["false_positive"],
|
| 546 |
+
"Cannot assess bias β no control group found."
|
| 547 |
+
)
|
| 548 |
+
|
| 549 |
+
white_count = sum(
|
| 550 |
+
1 for p in control if p.get("ethnicity") == "White"
|
| 551 |
+
)
|
| 552 |
+
white_ratio = white_count / len(control)
|
| 553 |
+
minority_dead = sum(
|
| 554 |
+
1 for p in control
|
| 555 |
+
if p.get("ethnicity") != "White"
|
| 556 |
+
and p.get("outcome") == "deceased"
|
| 557 |
+
)
|
| 558 |
+
male_count = sum(
|
| 559 |
+
1 for p in control if p.get("gender") == "M"
|
| 560 |
+
)
|
| 561 |
+
male_ratio = male_count / len(control)
|
| 562 |
+
|
| 563 |
+
if white_ratio >= 0.65 and minority_dead > 0 and male_ratio >= 0.50:
|
| 564 |
+
self._flagged_patients.add("BIAS_FLAG")
|
| 565 |
+
self._state.errors_found += 1
|
| 566 |
+
return (
|
| 567 |
+
REWARD_CONFIG["bias_detected"],
|
| 568 |
+
f"β Correct. Multi-dimensional selection bias confirmed: "
|
| 569 |
+
f"White={white_ratio:.0%} of control, "
|
| 570 |
+
f"minority mortality present ({minority_dead} deceased), "
|
| 571 |
+
f"gender imbalance ({male_ratio:.0%} male)."
|
| 572 |
+
)
|
| 573 |
+
else:
|
| 574 |
+
return (
|
| 575 |
+
REWARD_CONFIG["false_positive"],
|
| 576 |
+
"β Statistical evidence insufficient for bias determination."
|
| 577 |
+
)
|
| 578 |
+
else:
|
| 579 |
+
return (
|
| 580 |
+
REWARD_CONFIG["false_positive"],
|
| 581 |
+
"β False positive. No significant selection bias in this dataset."
|
| 582 |
+
)
|
| 583 |
+
|
| 584 |
+
# ββ Data error flags (require patient_id) ββ
|
| 585 |
+
if patient_id is None:
|
| 586 |
+
return (
|
| 587 |
+
REWARD_CONFIG["false_positive"],
|
| 588 |
+
"REJECTED: Provide patient_id for data errors."
|
| 589 |
+
)
|
| 590 |
+
|
| 591 |
+
if patient_id in self._flagged_patients:
|
| 592 |
+
return (
|
| 593 |
+
REWARD_CONFIG["duplicate_flag"],
|
| 594 |
+
f"{patient_id} already flagged."
|
| 595 |
+
)
|
| 596 |
+
|
| 597 |
+
# Check if patient exists in dataset
|
| 598 |
+
patient = next(
|
| 599 |
+
(p for p in self._dataset if p.get("patient_id") == patient_id),
|
| 600 |
+
None
|
| 601 |
+
)
|
| 602 |
+
if not patient:
|
| 603 |
+
return (
|
| 604 |
+
REWARD_CONFIG["false_positive"],
|
| 605 |
+
f"REJECTED: Patient '{patient_id}' not found in dataset."
|
| 606 |
+
)
|
| 607 |
+
|
| 608 |
+
# ββ Ground truth lookup (O(1) β deterministic) ββ
|
| 609 |
+
expected_errors = self._ground_truth.get(patient_id, [])
|
| 610 |
+
|
| 611 |
+
if error_type == "invalid_age":
|
| 612 |
+
if "invalid_age" in expected_errors:
|
| 613 |
+
self._flagged_patients.add(patient_id)
|
| 614 |
+
self._state.errors_found += 1
|
| 615 |
+
age = patient.get("age")
|
| 616 |
+
return (
|
| 617 |
+
REWARD_CONFIG["correct_flag"],
|
| 618 |
+
f"β Correct: {patient_id} has invalid age ({age}). Good catch."
|
| 619 |
+
)
|
| 620 |
+
else:
|
| 621 |
+
age = patient.get("age")
|
| 622 |
+
return (
|
| 623 |
+
REWARD_CONFIG["false_positive"],
|
| 624 |
+
f"β False positive: {patient_id} age={age} is within valid range [18-120]."
|
| 625 |
+
)
|
| 626 |
+
|
| 627 |
+
elif error_type == "temporal_inconsistency":
|
| 628 |
+
if "temporal_inconsistency" in expected_errors:
|
| 629 |
+
self._flagged_patients.add(patient_id)
|
| 630 |
+
self._state.errors_found += 1
|
| 631 |
+
ts = patient.get("treatment_start", "")
|
| 632 |
+
dd = patient.get("death_date", "")
|
| 633 |
+
if ts and dd:
|
| 634 |
+
t = datetime.strptime(ts, "%Y-%m-%d")
|
| 635 |
+
d = datetime.strptime(dd, "%Y-%m-%d")
|
| 636 |
+
gap = (t - d).days
|
| 637 |
+
return (
|
| 638 |
+
REWARD_CONFIG["correct_flag"],
|
| 639 |
+
f"β Correct: {patient_id} death_date is {gap} days "
|
| 640 |
+
f"before treatment_start."
|
| 641 |
+
)
|
| 642 |
+
return (
|
| 643 |
+
REWARD_CONFIG["correct_flag"],
|
| 644 |
+
f"β Correct: {patient_id} has temporal inconsistency."
|
| 645 |
+
)
|
| 646 |
+
else:
|
| 647 |
+
return (
|
| 648 |
+
REWARD_CONFIG["false_positive"],
|
| 649 |
+
f"β False positive: {patient_id} temporal sequence is valid."
|
| 650 |
+
)
|
| 651 |
+
|
| 652 |
+
else:
|
| 653 |
+
return (
|
| 654 |
+
REWARD_CONFIG["false_positive"],
|
| 655 |
+
f"β Invalid error_type '{error_type}'. "
|
| 656 |
+
f"Valid: invalid_age, temporal_inconsistency, selection_bias."
|
| 657 |
+
)
|
| 658 |
+
|
| 659 |
+
def _grade_propose_fix(self, action: AuditAction):
|
| 660 |
+
patient_id = action.patient_id or ""
|
| 661 |
+
if patient_id not in self._flagged_patients:
|
| 662 |
+
return (
|
| 663 |
+
REWARD_CONFIG["propose_fix_invalid"],
|
| 664 |
+
"Can only propose fix for flagged patients."
|
| 665 |
+
)
|
| 666 |
+
proposed = action.proposed_value or ""
|
| 667 |
+
if len(proposed) > 2:
|
| 668 |
+
return (
|
| 669 |
+
REWARD_CONFIG["propose_fix_valid"],
|
| 670 |
+
f"Fix proposed for {patient_id}."
|
| 671 |
+
)
|
| 672 |
+
return REWARD_CONFIG["propose_fix_invalid"], "Proposed fix too vague."
|
| 673 |
+
|
| 674 |
+
def _grade_report(self, action: AuditAction):
|
| 675 |
+
"""Grade report quality using multi-dimensional rubric."""
|
| 676 |
+
self._report_submitted = True
|
| 677 |
+
report = (action.report or action.reason or "").lower()
|
| 678 |
+
step_reward = REWARD_CONFIG["report_bonus_base"]
|
| 679 |
+
|
| 680 |
+
# Completeness bonus: flagged enough issues
|
| 681 |
+
if len(self._flagged_patients) >= 3:
|
| 682 |
+
step_reward += 0.03
|
| 683 |
+
|
| 684 |
+
# ββ Report quality rubric (tests clinical reasoning depth) ββ
|
| 685 |
+
quality_score = 0
|
| 686 |
+
quality_items = []
|
| 687 |
+
|
| 688 |
+
# Root cause analysis
|
| 689 |
+
if any(kw in report for kw in [
|
| 690 |
+
"root cause", "data entry", "etl", "pipeline", "system"
|
| 691 |
+
]):
|
| 692 |
+
quality_score += 1
|
| 693 |
+
quality_items.append("root cause analysis")
|
| 694 |
+
|
| 695 |
+
# Corrective recommendations
|
| 696 |
+
if any(kw in report for kw in [
|
| 697 |
+
"recommend", "corrective", "action", "mitigation"
|
| 698 |
+
]):
|
| 699 |
+
quality_score += 1
|
| 700 |
+
quality_items.append("corrective recommendations")
|
| 701 |
+
|
| 702 |
+
# Risk assessment
|
| 703 |
+
if any(kw in report for kw in [
|
| 704 |
+
"risk", "severity", "critical", "impact", "patient safety"
|
| 705 |
+
]):
|
| 706 |
+
quality_score += 1
|
| 707 |
+
quality_items.append("risk assessment")
|
| 708 |
+
|
| 709 |
+
# Regulatory compliance
|
| 710 |
+
if any(kw in report for kw in [
|
| 711 |
+
"regulatory", "compliance", "fda", "ich", "gcp", "validity"
|
| 712 |
+
]):
|
| 713 |
+
quality_score += 1
|
| 714 |
+
quality_items.append("regulatory awareness")
|
| 715 |
+
|
| 716 |
+
# Quality bonus: +0.02 per dimension (max +0.08)
|
| 717 |
+
step_reward += quality_score * 0.02
|
| 718 |
+
|
| 719 |
+
quality_feedback = f"Report quality: {quality_score}/4 dimensions"
|
| 720 |
+
if quality_items:
|
| 721 |
+
quality_feedback += f" ({', '.join(quality_items)})"
|
| 722 |
+
|
| 723 |
+
return (
|
| 724 |
+
step_reward,
|
| 725 |
+
f"Report submitted. {quality_feedback}. Final evaluation complete."
|
| 726 |
+
)
|
server/dataset_generator.py
ADDED
|
@@ -0,0 +1,668 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Procedural Adversarial Clinical Trial Data Engine
|
| 3 |
+
==================================================
|
| 4 |
+
Generates statistically rigorous, adversarial patient datasets for each episode.
|
| 5 |
+
|
| 6 |
+
Design philosophy:
|
| 7 |
+
- Every reset() β unique dataset β no memorization possible
|
| 8 |
+
- Controlled error injection with known ground truth
|
| 9 |
+
- Adversarial traps that punish shallow reasoning
|
| 10 |
+
- Seed-based reproducibility for deterministic judging
|
| 11 |
+
- Pure stdlib (no numpy) β minimal Docker image
|
| 12 |
+
|
| 13 |
+
Architecture layers:
|
| 14 |
+
1. Base Patient Generator β realistic demographics via statistical distributions
|
| 15 |
+
2. Error Injector β controlled % of age/temporal/missing violations
|
| 16 |
+
3. Bias Injector β demographic skew + outcome disparity in control group
|
| 17 |
+
4. Trap Injector β boundary-valid, near-temporal, fake-pattern distractors
|
| 18 |
+
5. Ground Truth Tracker β records every injected error for deterministic grading
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
import random
|
| 22 |
+
import math
|
| 23 |
+
import hashlib
|
| 24 |
+
from datetime import datetime, timedelta
|
| 25 |
+
from typing import Optional
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 29 |
+
# REFERENCE DATA β Realistic clinical trial metadata pools
|
| 30 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 31 |
+
|
| 32 |
+
HOSPITAL_SITES = [
|
| 33 |
+
("Metro General Hospital", "US"),
|
| 34 |
+
("Cleveland Oncology Institute", "US"),
|
| 35 |
+
("Howard University Hospital", "US"),
|
| 36 |
+
("Johns Hopkins Oncology Center", "US"),
|
| 37 |
+
("MD Anderson Cancer Center", "US"),
|
| 38 |
+
("AIIMS Delhi", "India"),
|
| 39 |
+
("Tata Memorial Hospital", "India"),
|
| 40 |
+
("CharitΓ© Berlin", "Germany"),
|
| 41 |
+
("Hospital ClΓnic Barcelona", "Spain"),
|
| 42 |
+
("Tokyo Medical University", "Japan"),
|
| 43 |
+
("Seoul National University Hospital", "South Korea"),
|
| 44 |
+
("Royal Marsden Hospital", "UK"),
|
| 45 |
+
("Gustave Roussy Institute", "France"),
|
| 46 |
+
("Princess Margaret Cancer Centre", "Canada"),
|
| 47 |
+
("Peter MacCallum Cancer Centre", "Australia"),
|
| 48 |
+
]
|
| 49 |
+
|
| 50 |
+
# Sites considered "rural" or underrepresented for bias analysis
|
| 51 |
+
RURAL_SITES = {
|
| 52 |
+
"AIIMS Delhi", "Tata Memorial Hospital",
|
| 53 |
+
"Howard University Hospital",
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
ETHNICITIES = ["White", "Black", "Hispanic", "Asian", "Native American", "Pacific Islander"]
|
| 57 |
+
GENDERS = ["M", "F"]
|
| 58 |
+
STAGES = ["I", "II", "III", "IV"]
|
| 59 |
+
DRUGS_TREATMENT = ["ImmunoVax-7", "OncoShield-X", "TargetCure-3"]
|
| 60 |
+
DRUGS_CONTROL = ["Placebo"]
|
| 61 |
+
|
| 62 |
+
# Date range for the trial
|
| 63 |
+
TRIAL_START = datetime(2022, 6, 1)
|
| 64 |
+
TRIAL_END = datetime(2025, 3, 1)
|
| 65 |
+
|
| 66 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 67 |
+
# DIFFICULTY CONFIGURATIONS
|
| 68 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 69 |
+
|
| 70 |
+
DIFFICULTY_CONFIGS = {
|
| 71 |
+
"easy": {
|
| 72 |
+
"dataset_size": 300,
|
| 73 |
+
"age_error_rate": 0.03, # 3% of patients have invalid ages
|
| 74 |
+
"temporal_error_rate": 0.0, # No temporal errors in easy
|
| 75 |
+
"missing_data_rate": 0.01, # 1% missing age
|
| 76 |
+
"bias_intensity": 0.0, # No bias in easy
|
| 77 |
+
"num_boundary_traps": 5, # Valid edge-case ages
|
| 78 |
+
"num_temporal_traps": 0,
|
| 79 |
+
"num_distractor_deceased": 4, # Valid deceased patients
|
| 80 |
+
"num_fake_bias_distractors": 0,
|
| 81 |
+
"mortality_rate": 0.12, # 12% overall mortality
|
| 82 |
+
"control_ratio": 0.50, # 50/50 control/treatment
|
| 83 |
+
"task_type": "syntactic_cleaning",
|
| 84 |
+
"allow_bias": False,
|
| 85 |
+
},
|
| 86 |
+
"medium": {
|
| 87 |
+
"dataset_size": 500,
|
| 88 |
+
"age_error_rate": 0.03,
|
| 89 |
+
"temporal_error_rate": 0.03, # 3% temporal violations
|
| 90 |
+
"missing_data_rate": 0.015,
|
| 91 |
+
"bias_intensity": 0.0,
|
| 92 |
+
"num_boundary_traps": 6,
|
| 93 |
+
"num_temporal_traps": 3, # Near-temporal valid cases
|
| 94 |
+
"num_distractor_deceased": 5,
|
| 95 |
+
"num_fake_bias_distractors": 0,
|
| 96 |
+
"mortality_rate": 0.15,
|
| 97 |
+
"control_ratio": 0.50,
|
| 98 |
+
"task_type": "temporal_consistency",
|
| 99 |
+
"allow_bias": False,
|
| 100 |
+
},
|
| 101 |
+
"hard": {
|
| 102 |
+
"dataset_size": 800,
|
| 103 |
+
"age_error_rate": 0.025,
|
| 104 |
+
"temporal_error_rate": 0.025,
|
| 105 |
+
"missing_data_rate": 0.01,
|
| 106 |
+
"bias_intensity": 0.80, # Strong bias
|
| 107 |
+
"num_boundary_traps": 8,
|
| 108 |
+
"num_temporal_traps": 4,
|
| 109 |
+
"num_distractor_deceased": 8,
|
| 110 |
+
"num_fake_bias_distractors": 5, # Fake patterns that look biased but aren't
|
| 111 |
+
"mortality_rate": 0.18,
|
| 112 |
+
"control_ratio": 0.50,
|
| 113 |
+
"task_type": "comprehensive_audit",
|
| 114 |
+
"allow_bias": True,
|
| 115 |
+
},
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 120 |
+
# DATASET GENERATOR
|
| 121 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 122 |
+
|
| 123 |
+
class DatasetGenerator:
|
| 124 |
+
"""
|
| 125 |
+
Procedural adversarial clinical trial data engine.
|
| 126 |
+
|
| 127 |
+
Generates statistically rigorous patient datasets with:
|
| 128 |
+
- Configurable size (300-1000+ patients)
|
| 129 |
+
- Controlled error injection (age, temporal, missing data)
|
| 130 |
+
- Controllable bias intensity (representation + outcome disparity)
|
| 131 |
+
- Adversarial traps (boundary-valid, near-temporal, fake patterns)
|
| 132 |
+
- Seed-based reproducibility (same seed β identical dataset)
|
| 133 |
+
|
| 134 |
+
Usage:
|
| 135 |
+
gen = DatasetGenerator(seed=42)
|
| 136 |
+
result = gen.generate(difficulty="hard")
|
| 137 |
+
dataset = result["dataset"] # List[dict] β patient records
|
| 138 |
+
ground_truth = result["ground_truth"] # Dict[str, List[str]] β {pid: [error_types]}
|
| 139 |
+
traps = result["traps"] # Set[str] β valid-but-suspicious pids
|
| 140 |
+
bias_present = result["bias_present"] # bool
|
| 141 |
+
"""
|
| 142 |
+
|
| 143 |
+
def __init__(self, seed: Optional[int] = None):
|
| 144 |
+
self.seed = seed
|
| 145 |
+
self.rng = random.Random(seed)
|
| 146 |
+
self._patient_counter = 0
|
| 147 |
+
self._ground_truth: dict[str, list[str]] = {}
|
| 148 |
+
self._traps: set[str] = set()
|
| 149 |
+
|
| 150 |
+
def _next_pid(self) -> str:
|
| 151 |
+
self._patient_counter += 1
|
| 152 |
+
return f"P{self._patient_counter:04d}"
|
| 153 |
+
|
| 154 |
+
def _random_date(self, start: datetime, end: datetime) -> datetime:
|
| 155 |
+
"""Generate a random date between start and end."""
|
| 156 |
+
delta = (end - start).days
|
| 157 |
+
if delta <= 0:
|
| 158 |
+
return start
|
| 159 |
+
return start + timedelta(days=self.rng.randint(0, delta))
|
| 160 |
+
|
| 161 |
+
def _generate_age(self) -> int:
|
| 162 |
+
"""Generate a realistic age using truncated normal distribution."""
|
| 163 |
+
# Clinical trial typical age: mean=58, std=12
|
| 164 |
+
while True:
|
| 165 |
+
age = int(self.rng.gauss(58, 12))
|
| 166 |
+
if 18 <= age <= 100:
|
| 167 |
+
return age
|
| 168 |
+
|
| 169 |
+
def _select_ethnicity(self, bias_mode: str = "neutral") -> str:
|
| 170 |
+
"""
|
| 171 |
+
Select ethnicity with configurable distribution.
|
| 172 |
+
bias_mode: "neutral" | "white_dominant" | "diverse"
|
| 173 |
+
"""
|
| 174 |
+
if bias_mode == "white_dominant":
|
| 175 |
+
weights = [0.78, 0.06, 0.06, 0.05, 0.03, 0.02]
|
| 176 |
+
elif bias_mode == "diverse":
|
| 177 |
+
weights = [0.30, 0.20, 0.20, 0.15, 0.10, 0.05]
|
| 178 |
+
else: # neutral β matches US clinical trial demographics
|
| 179 |
+
weights = [0.55, 0.15, 0.15, 0.10, 0.03, 0.02]
|
| 180 |
+
|
| 181 |
+
return self.rng.choices(ETHNICITIES, weights=weights, k=1)[0]
|
| 182 |
+
|
| 183 |
+
def _generate_base_patient(self, group: str, ethnicity: str = None,
|
| 184 |
+
bias_mode: str = "neutral") -> dict:
|
| 185 |
+
"""Generate a single valid patient record."""
|
| 186 |
+
pid = self._next_pid()
|
| 187 |
+
site, country = self.rng.choice(HOSPITAL_SITES)
|
| 188 |
+
gender = self.rng.choice(GENDERS)
|
| 189 |
+
eth = ethnicity or self._select_ethnicity(bias_mode)
|
| 190 |
+
age = self._generate_age()
|
| 191 |
+
stage = self.rng.choices(STAGES, weights=[0.25, 0.30, 0.25, 0.20], k=1)[0]
|
| 192 |
+
|
| 193 |
+
enrollment_date = self._random_date(TRIAL_START, TRIAL_END - timedelta(days=180))
|
| 194 |
+
treatment_start = enrollment_date + timedelta(days=self.rng.randint(7, 30))
|
| 195 |
+
|
| 196 |
+
if group == "treatment":
|
| 197 |
+
drug = self.rng.choice(DRUGS_TREATMENT)
|
| 198 |
+
else:
|
| 199 |
+
drug = "Placebo"
|
| 200 |
+
|
| 201 |
+
patient = {
|
| 202 |
+
"patient_id": pid,
|
| 203 |
+
"age": age,
|
| 204 |
+
"gender": gender,
|
| 205 |
+
"ethnicity": eth,
|
| 206 |
+
"group": group,
|
| 207 |
+
"treatment_start": treatment_start.strftime("%Y-%m-%d"),
|
| 208 |
+
"death_date": None,
|
| 209 |
+
"outcome": "survived",
|
| 210 |
+
"treatment_site": site,
|
| 211 |
+
"stage": stage,
|
| 212 |
+
"trial_phase": "Phase III",
|
| 213 |
+
"drug": drug,
|
| 214 |
+
"enrollment_date": enrollment_date.strftime("%Y-%m-%d"),
|
| 215 |
+
"country": country,
|
| 216 |
+
}
|
| 217 |
+
|
| 218 |
+
return patient
|
| 219 |
+
|
| 220 |
+
def _apply_mortality(self, patient: dict, mortality_rate: float) -> dict:
|
| 221 |
+
"""Randomly apply mortality with valid timeline."""
|
| 222 |
+
if self.rng.random() < mortality_rate:
|
| 223 |
+
treatment_start = datetime.strptime(patient["treatment_start"], "%Y-%m-%d")
|
| 224 |
+
# Death occurs 1-720 days after treatment start
|
| 225 |
+
days_to_death = self.rng.randint(1, 720)
|
| 226 |
+
death_date = treatment_start + timedelta(days=days_to_death)
|
| 227 |
+
# Cap at trial end
|
| 228 |
+
if death_date > TRIAL_END + timedelta(days=365):
|
| 229 |
+
death_date = TRIAL_END + timedelta(days=self.rng.randint(1, 180))
|
| 230 |
+
|
| 231 |
+
patient["death_date"] = death_date.strftime("%Y-%m-%d")
|
| 232 |
+
patient["outcome"] = "deceased"
|
| 233 |
+
return patient
|
| 234 |
+
|
| 235 |
+
# ββ Error Injectors βββββββββββββββββββββββββββββββββββββββββββββββ
|
| 236 |
+
|
| 237 |
+
def _inject_age_errors(self, patients: list[dict], error_rate: float,
|
| 238 |
+
missing_rate: float) -> list[dict]:
|
| 239 |
+
"""Inject invalid age values into random patients."""
|
| 240 |
+
n_age_errors = max(3, int(len(patients) * error_rate))
|
| 241 |
+
n_missing = max(1, int(len(patients) * missing_rate))
|
| 242 |
+
|
| 243 |
+
# Select random indices for age errors (avoid overlap)
|
| 244 |
+
available = list(range(len(patients)))
|
| 245 |
+
self.rng.shuffle(available)
|
| 246 |
+
|
| 247 |
+
# Invalid age errors
|
| 248 |
+
invalid_ages = []
|
| 249 |
+
for _ in range(n_age_errors):
|
| 250 |
+
error_kind = self.rng.choice([
|
| 251 |
+
"negative", "extreme_high", "sentinel", "just_over"
|
| 252 |
+
])
|
| 253 |
+
if error_kind == "negative":
|
| 254 |
+
invalid_ages.append(self.rng.choice([-1, -5, -10, -3, -15]))
|
| 255 |
+
elif error_kind == "extreme_high":
|
| 256 |
+
invalid_ages.append(self.rng.choice([150, 200, 250, 300, 500]))
|
| 257 |
+
elif error_kind == "sentinel":
|
| 258 |
+
invalid_ages.append(self.rng.choice([999, 9999, 0, -999]))
|
| 259 |
+
elif error_kind == "just_over":
|
| 260 |
+
invalid_ages.append(self.rng.choice([121, 122, 125, 130, 17, 16, 15]))
|
| 261 |
+
|
| 262 |
+
for i, invalid_age in enumerate(invalid_ages):
|
| 263 |
+
if i >= len(available):
|
| 264 |
+
break
|
| 265 |
+
idx = available[i]
|
| 266 |
+
patients[idx]["age"] = invalid_age
|
| 267 |
+
pid = patients[idx]["patient_id"]
|
| 268 |
+
self._ground_truth.setdefault(pid, []).append("invalid_age")
|
| 269 |
+
|
| 270 |
+
# Missing age (None)
|
| 271 |
+
offset = len(invalid_ages)
|
| 272 |
+
for j in range(n_missing):
|
| 273 |
+
if offset + j >= len(available):
|
| 274 |
+
break
|
| 275 |
+
idx = available[offset + j]
|
| 276 |
+
patients[idx]["age"] = None
|
| 277 |
+
pid = patients[idx]["patient_id"]
|
| 278 |
+
self._ground_truth.setdefault(pid, []).append("invalid_age")
|
| 279 |
+
|
| 280 |
+
return patients
|
| 281 |
+
|
| 282 |
+
def _inject_temporal_errors(self, patients: list[dict],
|
| 283 |
+
error_rate: float) -> list[dict]:
|
| 284 |
+
"""Inject temporal violations: death_date before treatment_start."""
|
| 285 |
+
n_errors = max(3, int(len(patients) * error_rate))
|
| 286 |
+
|
| 287 |
+
# Only inject into patients who have death dates or can have one added
|
| 288 |
+
candidates = []
|
| 289 |
+
for i, p in enumerate(patients):
|
| 290 |
+
pid = p["patient_id"]
|
| 291 |
+
# Don't stack errors on patients already with age errors
|
| 292 |
+
if pid not in self._ground_truth:
|
| 293 |
+
candidates.append(i)
|
| 294 |
+
|
| 295 |
+
self.rng.shuffle(candidates)
|
| 296 |
+
|
| 297 |
+
for k in range(min(n_errors, len(candidates))):
|
| 298 |
+
idx = candidates[k]
|
| 299 |
+
p = patients[idx]
|
| 300 |
+
treatment_start = datetime.strptime(p["treatment_start"], "%Y-%m-%d")
|
| 301 |
+
|
| 302 |
+
# Death date 15-365 days BEFORE treatment start (clear violation)
|
| 303 |
+
gap_days = self.rng.randint(15, 365)
|
| 304 |
+
death_date = treatment_start - timedelta(days=gap_days)
|
| 305 |
+
|
| 306 |
+
p["death_date"] = death_date.strftime("%Y-%m-%d")
|
| 307 |
+
p["outcome"] = "deceased"
|
| 308 |
+
|
| 309 |
+
pid = p["patient_id"]
|
| 310 |
+
self._ground_truth.setdefault(pid, []).append("temporal_inconsistency")
|
| 311 |
+
|
| 312 |
+
return patients
|
| 313 |
+
|
| 314 |
+
def _inject_bias(self, patients: list[dict], intensity: float) -> list[dict]:
|
| 315 |
+
"""
|
| 316 |
+
Inject multi-dimensional selection bias into the control group.
|
| 317 |
+
|
| 318 |
+
Bias structure (mirrors real SEER findings):
|
| 319 |
+
1. Representation: White patients dominate control group (>75%)
|
| 320 |
+
2. Outcome disparity: Minority control patients have higher mortality
|
| 321 |
+
3. Gender imbalance: Males overrepresented in control
|
| 322 |
+
4. Site bias: Minorities underrepresented at major sites
|
| 323 |
+
"""
|
| 324 |
+
if intensity <= 0:
|
| 325 |
+
return patients
|
| 326 |
+
|
| 327 |
+
control_patients = [p for p in patients if p["group"] == "control"]
|
| 328 |
+
treatment_patients = [p for p in patients if p["group"] == "treatment"]
|
| 329 |
+
|
| 330 |
+
if not control_patients:
|
| 331 |
+
return patients
|
| 332 |
+
|
| 333 |
+
# ββ Layer 1: Representation bias ββ
|
| 334 |
+
# Force >75% of control to be White
|
| 335 |
+
target_white_ratio = 0.75 + (intensity * 0.10) # 0.75-0.85
|
| 336 |
+
n_control = len(control_patients)
|
| 337 |
+
n_white_target = int(n_control * target_white_ratio)
|
| 338 |
+
n_white_current = sum(1 for p in control_patients if p["ethnicity"] == "White")
|
| 339 |
+
|
| 340 |
+
# Convert some non-White control patients to White
|
| 341 |
+
non_white_control = [p for p in control_patients if p["ethnicity"] != "White"]
|
| 342 |
+
to_convert = max(0, n_white_target - n_white_current)
|
| 343 |
+
self.rng.shuffle(non_white_control)
|
| 344 |
+
for i in range(min(to_convert, len(non_white_control))):
|
| 345 |
+
non_white_control[i]["ethnicity"] = "White"
|
| 346 |
+
|
| 347 |
+
# ββ Layer 2: Gender imbalance in control ββ
|
| 348 |
+
# Force >65% male in control
|
| 349 |
+
target_male_ratio = 0.65 + (intensity * 0.10)
|
| 350 |
+
n_male_target = int(n_control * target_male_ratio)
|
| 351 |
+
n_male_current = sum(1 for p in control_patients if p["gender"] == "M")
|
| 352 |
+
female_control = [p for p in control_patients if p["gender"] == "F"]
|
| 353 |
+
to_convert_gender = max(0, n_male_target - n_male_current)
|
| 354 |
+
self.rng.shuffle(female_control)
|
| 355 |
+
for i in range(min(to_convert_gender, len(female_control))):
|
| 356 |
+
female_control[i]["gender"] = "M"
|
| 357 |
+
|
| 358 |
+
# ββ Layer 3: Outcome disparity ββ
|
| 359 |
+
# Minority patients in control β higher mortality (>60%)
|
| 360 |
+
minority_control = [
|
| 361 |
+
p for p in control_patients
|
| 362 |
+
if p["ethnicity"] != "White" and p["patient_id"] not in self._ground_truth
|
| 363 |
+
]
|
| 364 |
+
target_minority_mortality = 0.60 + (intensity * 0.15)
|
| 365 |
+
n_minority_dead = int(len(minority_control) * target_minority_mortality)
|
| 366 |
+
|
| 367 |
+
for i, p in enumerate(minority_control):
|
| 368 |
+
if i < n_minority_dead:
|
| 369 |
+
if p["outcome"] != "deceased":
|
| 370 |
+
treatment_start = datetime.strptime(p["treatment_start"], "%Y-%m-%d")
|
| 371 |
+
death_date = treatment_start + timedelta(
|
| 372 |
+
days=self.rng.randint(30, 365)
|
| 373 |
+
)
|
| 374 |
+
p["death_date"] = death_date.strftime("%Y-%m-%d")
|
| 375 |
+
p["outcome"] = "deceased"
|
| 376 |
+
|
| 377 |
+
# ββ Layer 4: White control patients β low mortality ββ
|
| 378 |
+
white_control = [
|
| 379 |
+
p for p in control_patients
|
| 380 |
+
if p["ethnicity"] == "White" and p["patient_id"] not in self._ground_truth
|
| 381 |
+
]
|
| 382 |
+
# Keep White mortality low
|
| 383 |
+
target_white_survival = 0.85
|
| 384 |
+
n_white_alive = int(len(white_control) * target_white_survival)
|
| 385 |
+
for i, p in enumerate(white_control):
|
| 386 |
+
if i < n_white_alive:
|
| 387 |
+
p["death_date"] = None
|
| 388 |
+
p["outcome"] = "survived"
|
| 389 |
+
|
| 390 |
+
# ββ Layer 5: Rural minority underrepresentation ββ
|
| 391 |
+
for p in minority_control:
|
| 392 |
+
if p["treatment_site"] in RURAL_SITES:
|
| 393 |
+
# Move some to major sites (reducing rural minority visibility)
|
| 394 |
+
if self.rng.random() < intensity * 0.5:
|
| 395 |
+
major_sites = [
|
| 396 |
+
s for s in HOSPITAL_SITES
|
| 397 |
+
if s[0] not in RURAL_SITES
|
| 398 |
+
]
|
| 399 |
+
new_site = self.rng.choice(major_sites)
|
| 400 |
+
p["treatment_site"] = new_site[0]
|
| 401 |
+
p["country"] = new_site[1]
|
| 402 |
+
|
| 403 |
+
return patients
|
| 404 |
+
|
| 405 |
+
# ββ Trap Injectors ββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 406 |
+
|
| 407 |
+
def _inject_boundary_traps(self, patients: list[dict], n_traps: int) -> list[dict]:
|
| 408 |
+
"""
|
| 409 |
+
Inject boundary-valid ages that trap naive agents.
|
| 410 |
+
Ages like 18, 19, 120 are VALID but suspicious.
|
| 411 |
+
"""
|
| 412 |
+
boundary_ages = [18, 19, 20, 90, 92, 95, 96, 100, 105, 110, 115, 118, 119, 120, 120]
|
| 413 |
+
self.rng.shuffle(boundary_ages) # Randomize which traps appear
|
| 414 |
+
available = [
|
| 415 |
+
i for i, p in enumerate(patients)
|
| 416 |
+
if p["patient_id"] not in self._ground_truth
|
| 417 |
+
and p["age"] is not None and 25 <= p["age"] <= 85
|
| 418 |
+
]
|
| 419 |
+
self.rng.shuffle(available)
|
| 420 |
+
|
| 421 |
+
for k in range(min(n_traps, len(available), len(boundary_ages))):
|
| 422 |
+
idx = available[k]
|
| 423 |
+
patients[idx]["age"] = boundary_ages[k]
|
| 424 |
+
self._traps.add(patients[idx]["patient_id"])
|
| 425 |
+
|
| 426 |
+
return patients
|
| 427 |
+
|
| 428 |
+
def _inject_temporal_traps(self, patients: list[dict], n_traps: int) -> list[dict]:
|
| 429 |
+
"""
|
| 430 |
+
Inject near-temporal valid cases: death 1-3 days AFTER treatment start.
|
| 431 |
+
These are VALID but look like errors to careless agents.
|
| 432 |
+
"""
|
| 433 |
+
available = [
|
| 434 |
+
i for i, p in enumerate(patients)
|
| 435 |
+
if p["patient_id"] not in self._ground_truth
|
| 436 |
+
and p["death_date"] is None
|
| 437 |
+
and p["patient_id"] not in self._traps
|
| 438 |
+
]
|
| 439 |
+
self.rng.shuffle(available)
|
| 440 |
+
|
| 441 |
+
for k in range(min(n_traps, len(available))):
|
| 442 |
+
idx = available[k]
|
| 443 |
+
p = patients[idx]
|
| 444 |
+
treatment_start = datetime.strptime(p["treatment_start"], "%Y-%m-%d")
|
| 445 |
+
# Death 1-3 days AFTER treatment β valid but suspicious
|
| 446 |
+
gap = self.rng.randint(1, 3)
|
| 447 |
+
death_date = treatment_start + timedelta(days=gap)
|
| 448 |
+
p["death_date"] = death_date.strftime("%Y-%m-%d")
|
| 449 |
+
p["outcome"] = "deceased"
|
| 450 |
+
p["stage"] = "IV" # Make it medically plausible (late-stage)
|
| 451 |
+
self._traps.add(p["patient_id"])
|
| 452 |
+
|
| 453 |
+
return patients
|
| 454 |
+
|
| 455 |
+
def _inject_fake_bias_distractors(self, patients: list[dict],
|
| 456 |
+
n_distractors: int) -> list[dict]:
|
| 457 |
+
"""
|
| 458 |
+
Inject patterns that LOOK like bias but aren't.
|
| 459 |
+
E.g., treatment group with demographic skew (doesn't matter for bias detection
|
| 460 |
+
since only control group bias is relevant).
|
| 461 |
+
"""
|
| 462 |
+
treatment_patients = [
|
| 463 |
+
i for i, p in enumerate(patients)
|
| 464 |
+
if p["group"] == "treatment"
|
| 465 |
+
and p["patient_id"] not in self._ground_truth
|
| 466 |
+
and p["patient_id"] not in self._traps
|
| 467 |
+
]
|
| 468 |
+
self.rng.shuffle(treatment_patients)
|
| 469 |
+
|
| 470 |
+
for k in range(min(n_distractors, len(treatment_patients))):
|
| 471 |
+
idx = treatment_patients[k]
|
| 472 |
+
# Make treatment group look skewed (irrelevant for bias detection)
|
| 473 |
+
patients[idx]["ethnicity"] = "White"
|
| 474 |
+
patients[idx]["gender"] = "M"
|
| 475 |
+
self._traps.add(patients[idx]["patient_id"])
|
| 476 |
+
|
| 477 |
+
return patients
|
| 478 |
+
|
| 479 |
+
def _inject_distractor_deceased(self, patients: list[dict],
|
| 480 |
+
n_distractors: int) -> list[dict]:
|
| 481 |
+
"""
|
| 482 |
+
Add deceased patients with perfectly valid timelines.
|
| 483 |
+
These are NOT errors β tests if agent over-flags deceased patients.
|
| 484 |
+
"""
|
| 485 |
+
available = [
|
| 486 |
+
i for i, p in enumerate(patients)
|
| 487 |
+
if p["patient_id"] not in self._ground_truth
|
| 488 |
+
and p["death_date"] is None
|
| 489 |
+
and p["patient_id"] not in self._traps
|
| 490 |
+
]
|
| 491 |
+
self.rng.shuffle(available)
|
| 492 |
+
|
| 493 |
+
for k in range(min(n_distractors, len(available))):
|
| 494 |
+
idx = available[k]
|
| 495 |
+
p = patients[idx]
|
| 496 |
+
treatment_start = datetime.strptime(p["treatment_start"], "%Y-%m-%d")
|
| 497 |
+
# Death 30-540 days after treatment (clearly valid)
|
| 498 |
+
days = self.rng.randint(30, 540)
|
| 499 |
+
death_date = treatment_start + timedelta(days=days)
|
| 500 |
+
p["death_date"] = death_date.strftime("%Y-%m-%d")
|
| 501 |
+
p["outcome"] = "deceased"
|
| 502 |
+
self._traps.add(p["patient_id"])
|
| 503 |
+
|
| 504 |
+
return patients
|
| 505 |
+
|
| 506 |
+
# ββ Main Generator ββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 507 |
+
|
| 508 |
+
def generate(self, difficulty: str = "easy") -> dict:
|
| 509 |
+
"""
|
| 510 |
+
Generate a complete adversarial dataset for the given difficulty.
|
| 511 |
+
|
| 512 |
+
Returns:
|
| 513 |
+
{
|
| 514 |
+
"dataset": List[dict], # Patient records
|
| 515 |
+
"ground_truth": Dict[str, List[str]], # {pid: [error_types]}
|
| 516 |
+
"traps": Set[str], # Valid-but-suspicious pids
|
| 517 |
+
"bias_present": bool, # Whether bias was injected
|
| 518 |
+
"config": dict, # Generation parameters
|
| 519 |
+
"stats": dict, # Summary statistics
|
| 520 |
+
}
|
| 521 |
+
"""
|
| 522 |
+
config = DIFFICULTY_CONFIGS.get(difficulty, DIFFICULTY_CONFIGS["easy"])
|
| 523 |
+
self._ground_truth = {}
|
| 524 |
+
self._traps = set()
|
| 525 |
+
self._patient_counter = 0
|
| 526 |
+
|
| 527 |
+
n = config["dataset_size"]
|
| 528 |
+
n_control = int(n * config["control_ratio"])
|
| 529 |
+
n_treatment = n - n_control
|
| 530 |
+
|
| 531 |
+
# ββ Step 1: Generate base patients ββ
|
| 532 |
+
patients = []
|
| 533 |
+
|
| 534 |
+
# Determine bias mode for control group
|
| 535 |
+
control_bias_mode = "white_dominant" if config["bias_intensity"] > 0 else "neutral"
|
| 536 |
+
|
| 537 |
+
for _ in range(n_control):
|
| 538 |
+
p = self._generate_base_patient("control", bias_mode=control_bias_mode)
|
| 539 |
+
p = self._apply_mortality(p, config["mortality_rate"])
|
| 540 |
+
patients.append(p)
|
| 541 |
+
|
| 542 |
+
for _ in range(n_treatment):
|
| 543 |
+
p = self._generate_base_patient("treatment", bias_mode="diverse")
|
| 544 |
+
p = self._apply_mortality(p, config["mortality_rate"])
|
| 545 |
+
patients.append(p)
|
| 546 |
+
|
| 547 |
+
# ββ Step 2: Inject errors ββ
|
| 548 |
+
patients = self._inject_age_errors(
|
| 549 |
+
patients, config["age_error_rate"], config["missing_data_rate"]
|
| 550 |
+
)
|
| 551 |
+
|
| 552 |
+
if config["temporal_error_rate"] > 0:
|
| 553 |
+
patients = self._inject_temporal_errors(
|
| 554 |
+
patients, config["temporal_error_rate"]
|
| 555 |
+
)
|
| 556 |
+
|
| 557 |
+
# ββ Step 3: Inject bias (hard only) ββ
|
| 558 |
+
if config["bias_intensity"] > 0:
|
| 559 |
+
patients = self._inject_bias(patients, config["bias_intensity"])
|
| 560 |
+
|
| 561 |
+
# ββ Step 4: Inject adversarial traps ββ
|
| 562 |
+
patients = self._inject_boundary_traps(patients, config["num_boundary_traps"])
|
| 563 |
+
|
| 564 |
+
if config["num_temporal_traps"] > 0:
|
| 565 |
+
patients = self._inject_temporal_traps(
|
| 566 |
+
patients, config["num_temporal_traps"]
|
| 567 |
+
)
|
| 568 |
+
|
| 569 |
+
if config["num_fake_bias_distractors"] > 0:
|
| 570 |
+
patients = self._inject_fake_bias_distractors(
|
| 571 |
+
patients, config["num_fake_bias_distractors"]
|
| 572 |
+
)
|
| 573 |
+
|
| 574 |
+
patients = self._inject_distractor_deceased(
|
| 575 |
+
patients, config["num_distractor_deceased"]
|
| 576 |
+
)
|
| 577 |
+
|
| 578 |
+
# ββ Step 5: Shuffle dataset ββ
|
| 579 |
+
self.rng.shuffle(patients)
|
| 580 |
+
|
| 581 |
+
# ββ Step 6: Compute summary stats ββ
|
| 582 |
+
n_age_errors = sum(
|
| 583 |
+
1 for errs in self._ground_truth.values()
|
| 584 |
+
if "invalid_age" in errs
|
| 585 |
+
)
|
| 586 |
+
n_temporal_errors = sum(
|
| 587 |
+
1 for errs in self._ground_truth.values()
|
| 588 |
+
if "temporal_inconsistency" in errs
|
| 589 |
+
)
|
| 590 |
+
total_errors = n_age_errors + n_temporal_errors
|
| 591 |
+
if config["bias_intensity"] > 0:
|
| 592 |
+
total_errors += 1 # bias counts as 1 error
|
| 593 |
+
|
| 594 |
+
stats = {
|
| 595 |
+
"total_patients": len(patients),
|
| 596 |
+
"total_errors": total_errors,
|
| 597 |
+
"age_errors": n_age_errors,
|
| 598 |
+
"temporal_errors": n_temporal_errors,
|
| 599 |
+
"bias_present": config["bias_intensity"] > 0,
|
| 600 |
+
"num_traps": len(self._traps),
|
| 601 |
+
"control_count": sum(1 for p in patients if p["group"] == "control"),
|
| 602 |
+
"treatment_count": sum(1 for p in patients if p["group"] == "treatment"),
|
| 603 |
+
}
|
| 604 |
+
|
| 605 |
+
return {
|
| 606 |
+
"dataset": patients,
|
| 607 |
+
"ground_truth": dict(self._ground_truth),
|
| 608 |
+
"traps": set(self._traps),
|
| 609 |
+
"bias_present": config["bias_intensity"] > 0,
|
| 610 |
+
"config": config,
|
| 611 |
+
"stats": stats,
|
| 612 |
+
}
|
| 613 |
+
|
| 614 |
+
|
| 615 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 616 |
+
# STANDALONE TEST
|
| 617 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 618 |
+
|
| 619 |
+
if __name__ == "__main__":
|
| 620 |
+
print("=" * 60)
|
| 621 |
+
print(" Dataset Generator β Validation Test")
|
| 622 |
+
print("=" * 60)
|
| 623 |
+
|
| 624 |
+
for diff in ["easy", "medium", "hard"]:
|
| 625 |
+
gen = DatasetGenerator(seed=42)
|
| 626 |
+
result = gen.generate(difficulty=diff)
|
| 627 |
+
stats = result["stats"]
|
| 628 |
+
print(f"\n {diff.upper()}:")
|
| 629 |
+
print(f" Patients: {stats['total_patients']}")
|
| 630 |
+
print(f" Errors: {stats['total_errors']} "
|
| 631 |
+
f"(age={stats['age_errors']}, temporal={stats['temporal_errors']}, "
|
| 632 |
+
f"bias={'yes' if stats['bias_present'] else 'no'})")
|
| 633 |
+
print(f" Traps: {stats['num_traps']}")
|
| 634 |
+
print(f" Control: {stats['control_count']}")
|
| 635 |
+
print(f" Treatment: {stats['treatment_count']}")
|
| 636 |
+
|
| 637 |
+
# Verify reproducibility
|
| 638 |
+
gen2 = DatasetGenerator(seed=42)
|
| 639 |
+
result2 = gen2.generate(difficulty=diff)
|
| 640 |
+
assert result["dataset"] == result2["dataset"], "REPRODUCIBILITY FAILED!"
|
| 641 |
+
assert result["ground_truth"] == result2["ground_truth"], "GROUND TRUTH MISMATCH!"
|
| 642 |
+
print(f" β Seed reproducibility verified")
|
| 643 |
+
|
| 644 |
+
# Verify ground truth
|
| 645 |
+
for pid, errors in result["ground_truth"].items():
|
| 646 |
+
patient = next(p for p in result["dataset"] if p["patient_id"] == pid)
|
| 647 |
+
for err in errors:
|
| 648 |
+
if err == "invalid_age":
|
| 649 |
+
age = patient.get("age")
|
| 650 |
+
assert age is None or age < 18 or age > 120, \
|
| 651 |
+
f"Ground truth says {pid} invalid age but age={age}"
|
| 652 |
+
elif err == "temporal_inconsistency":
|
| 653 |
+
ts = datetime.strptime(patient["treatment_start"], "%Y-%m-%d")
|
| 654 |
+
dd = datetime.strptime(patient["death_date"], "%Y-%m-%d")
|
| 655 |
+
assert dd < ts, \
|
| 656 |
+
f"Ground truth says {pid} temporal error but dates are valid"
|
| 657 |
+
print(f" β Ground truth integrity verified")
|
| 658 |
+
|
| 659 |
+
# Verify different seeds produce different datasets
|
| 660 |
+
gen_a = DatasetGenerator(seed=1)
|
| 661 |
+
gen_b = DatasetGenerator(seed=2)
|
| 662 |
+
result_a = gen_a.generate("easy")
|
| 663 |
+
result_b = gen_b.generate("easy")
|
| 664 |
+
assert result_a["dataset"] != result_b["dataset"], "Different seeds same data!"
|
| 665 |
+
print(f"\n β Different seeds produce different datasets")
|
| 666 |
+
print(f"\n{'=' * 60}")
|
| 667 |
+
print(f" ALL TESTS PASSED")
|
| 668 |
+
print(f"{'=' * 60}")
|
server/models.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional, List, Dict, Any
|
| 2 |
+
from pydantic import Field
|
| 3 |
+
from openenv.core.env_server import Action, Observation, State
|
| 4 |
+
|
| 5 |
+
class AuditAction(Action):
|
| 6 |
+
action_type: str = "flag_error"
|
| 7 |
+
patient_id: Optional[str] = None
|
| 8 |
+
error_type: Optional[str] = None
|
| 9 |
+
reason: Optional[str] = None
|
| 10 |
+
proposed_value: Optional[str] = None
|
| 11 |
+
variable: Optional[str] = None
|
| 12 |
+
report: Optional[str] = None
|
| 13 |
+
confidence: Optional[float] = None # 0.0-1.0: agent's confidence in this action
|
| 14 |
+
|
| 15 |
+
class AuditObservation(Observation):
|
| 16 |
+
done: bool = False
|
| 17 |
+
reward: float = 0.0
|
| 18 |
+
task_id: str = ""
|
| 19 |
+
task_type: str = ""
|
| 20 |
+
task_description: str = ""
|
| 21 |
+
dataset: List[Dict[str, Any]] = Field(default_factory=list)
|
| 22 |
+
errors_found: List[str] = Field(default_factory=list)
|
| 23 |
+
patterns_investigated: List[str] = Field(default_factory=list)
|
| 24 |
+
distributions_computed: List[str] = Field(default_factory=list)
|
| 25 |
+
feedback: Optional[str] = None
|
| 26 |
+
score_so_far: float = 0.0
|
| 27 |
+
attempts_remaining: int = 15
|
| 28 |
+
phase: str = "investigation"
|
| 29 |
+
|
| 30 |
+
class AuditState(State):
|
| 31 |
+
episode_id: str = ""
|
| 32 |
+
step_count: int = 0
|
| 33 |
+
task_id: str = ""
|
| 34 |
+
task_type: str = ""
|
| 35 |
+
total_errors: int = 0
|
| 36 |
+
errors_found: int = 0
|
| 37 |
+
current_score: float = 0.0
|
| 38 |
+
attempts: int = 0
|
| 39 |
+
phase: str = "investigation"
|
| 40 |
+
patterns_investigated: List[str] = Field(default_factory=list)
|
| 41 |
+
distributions_computed: List[str] = Field(default_factory=list)
|
server/requirements.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
openenv-core[core]>=0.2.1
|
| 2 |
+
fastapi>=0.104.0
|
| 3 |
+
uvicorn>=0.24.0
|
| 4 |
+
pydantic>=2.0.0
|
test_output.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
EASY: 300 patients
|
| 2 |
+
Age errors: 12
|
| 3 |
+
Phase: flagging
|
| 4 |
+
Flag: β Correct: P0037 has invalid age (999). Good catch.
|
| 5 |
+
HARD: 800 patients, bias: True
|
| 6 |
+
DONE
|
uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|