TheJackBright Claude Opus 4.6 commited on
Commit
f0ef01d
·
1 Parent(s): bbb6de2

Version 3

Browse files

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

.gitignore CHANGED
@@ -23,9 +23,13 @@ yarn-debug.log*
23
  yarn-error.log*
24
  pnpm-debug.log*
25
 
 
 
 
26
  # --- Build / temp ---
27
  *.log
28
  *.tmp
29
  *.swp
30
  .DS_Store
31
 
 
 
23
  yarn-error.log*
24
  pnpm-debug.log*
25
 
26
+ # --- Project-specific ---
27
+ PROMPT.md
28
+
29
  # --- Build / temp ---
30
  *.log
31
  *.tmp
32
  *.swp
33
  .DS_Store
34
 
35
+ arXiv-2212.05190v3/
Dockerfile CHANGED
@@ -1,7 +1,7 @@
1
  FROM node:20-alpine AS frontend-builder
2
  WORKDIR /app/frontend
3
  COPY frontend/package*.json ./
4
- RUN npm ci
5
  COPY frontend/ ./
6
  RUN npm run build
7
 
@@ -11,6 +11,8 @@ RUN apt-get update && \
11
  apt-get install -y --no-install-recommends build-essential curl && \
12
  rm -rf /var/lib/apt/lists/*
13
 
 
 
14
  WORKDIR /app
15
 
16
  COPY backend/requirements.txt /app/backend/requirements.txt
@@ -22,18 +24,26 @@ COPY scripts /app/scripts
22
  COPY openenv.yaml /app/openenv.yaml
23
  COPY .env.example /app/.env.example
24
  COPY inference.py /app/inference.py
 
 
25
 
26
  COPY --from=frontend-builder /app/frontend/dist /app/frontend/dist
27
 
28
  RUN python3 /app/scripts/preprocess_data.py
29
 
 
 
 
 
30
  ENV PORT=7860
31
- ENV PYTHONPATH="/app/backend/src:${PYTHONPATH}"
32
  ENV PYTHONUNBUFFERED=1
33
 
 
 
34
  EXPOSE 7860
35
 
36
- HEALTHCHECK --interval=30s --timeout=3s --start-period=15s --retries=3 \
37
  CMD curl -f http://localhost:7860/health || exit 1
38
 
39
- CMD ["sh", "-c", "uvicorn backend.main:app --host 0.0.0.0 --port ${PORT:-7860}"]
 
1
  FROM node:20-alpine AS frontend-builder
2
  WORKDIR /app/frontend
3
  COPY frontend/package*.json ./
4
+ RUN npm ci --production=false
5
  COPY frontend/ ./
6
  RUN npm run build
7
 
 
11
  apt-get install -y --no-install-recommends build-essential curl && \
12
  rm -rf /var/lib/apt/lists/*
13
 
14
+ # HF Spaces runs as uid 1000
15
+ RUN useradd -m -u 1000 user
16
  WORKDIR /app
17
 
18
  COPY backend/requirements.txt /app/backend/requirements.txt
 
24
  COPY openenv.yaml /app/openenv.yaml
25
  COPY .env.example /app/.env.example
26
  COPY inference.py /app/inference.py
27
+ COPY train_rl.py /app/train_rl.py
28
+ COPY train_bandit.py /app/train_bandit.py
29
 
30
  COPY --from=frontend-builder /app/frontend/dist /app/frontend/dist
31
 
32
  RUN python3 /app/scripts/preprocess_data.py
33
 
34
+ # Ensure the user owns the app directory and has a writable home (HF Spaces)
35
+ RUN chown -R user:user /app && \
36
+ mkdir -p /home/user/.cache && chown -R user:user /home/user
37
+
38
  ENV PORT=7860
39
+ ENV PYTHONPATH="/app/backend/src:/app"
40
  ENV PYTHONUNBUFFERED=1
41
 
42
+ USER user
43
+
44
  EXPOSE 7860
45
 
46
+ HEALTHCHECK --interval=30s --timeout=5s --start-period=20s --retries=3 \
47
  CMD curl -f http://localhost:7860/health || exit 1
48
 
49
+ CMD ["sh", "-c", "uvicorn backend.main:app --host 0.0.0.0 --port ${PORT:-7860} --workers 1"]
PROMPT.md DELETED
@@ -1,571 +0,0 @@
1
- You are an expert Python backend, ML, and infrastructure engineer.
2
- Your task is to implement a complete, production-ready OpenEnv environment called **PolypharmacyEnv** for training and evaluating agentic RL policies that act as an "elderly polypharmacy safety agent" (clinical pharmacist assistant).
3
-
4
- The deliverable MUST satisfy all of the following:
5
- - Fully compliant with the OpenEnv spec (typed models, `step()` / `reset()` / `state()`, `openenv.yaml`, HTTP server, Dockerfile).
6
- - Simulates a realistic healthcare workflow around elderly polypharmacy and dangerous drug combinations.
7
- - Defines at least **3 tasks** (easy → medium → hard) with deterministic agent graders producing scores in (0.0, 1.0).
8
- - Provides shaped rewards over the trajectory (not just sparse terminal rewards).
9
- - Includes a baseline LLM-based inference script `inference.py` in the repo root, following the evaluation requirements:
10
- - Uses the OpenAI Python client.
11
- - Reads `OPENAI_API_KEY`, `API_BASE_URL`, `MODEL_NAME`, and `HF_TOKEN` from the environment.
12
- - Emits structured stdout logs in the exact `[START]`, `[STEP]`, `[END]` format from the OpenEnv sample inference script.
13
- - Is containerized and deployable as a **Hugging Face Space** tagged with `openenv` that responds to OpenEnv-style `reset` / `step` / `state` HTTP calls.
14
-
15
- Implement everything described below.
16
-
17
- =================================================
18
- 1. Repository and folder structure
19
- =================================================
20
-
21
- Create a Python package repository with this structure (names are important unless clearly labeled as examples):
22
-
23
- - `openenv-polypharmacy/`
24
- - `openenv.yaml`
25
- - `README.md`
26
- - `requirements.txt`
27
- - `Dockerfile`
28
- - `inference.py` # baseline LLM agent per spec
29
- - `pyproject.toml` or `setup.cfg` (optional but recommended)
30
- - `src/`
31
- - `polypharmacy_env/`
32
- - `__init__.py`
33
- - `config.py`
34
- - `models.py` # Action, Observation, State, helper models
35
- - `env_core.py` # PolypharmacyEnv implementation
36
- - `tasks.py` # task setup utilities
37
- - `graders.py` # deterministic graders for each task
38
- - `rewards.py` # reward shaping logic
39
- - `data_loader.py` # load/preprocess patient and lookup data
40
- - `ddi_simulator.py` # local DDI / guideline simulator
41
- - `api/`
42
- - `__init__.py`
43
- - `schemas.py` # HTTP request/response schemas
44
- - `server.py` # FastAPI app exposing OpenEnv endpoints
45
- - `baselines/`
46
- - `__init__.py`
47
- - `heuristic_agent.py` # simple rule-based baseline agent
48
- - `random_agent.py` # trivial random baseline (optional)
49
- - `tests/`
50
- - `__init__.py`
51
- - `test_env_core.py`
52
- - `test_api.py`
53
- - `data/`
54
- - `raw/` # placeholder for real/synthetic source data
55
- - `processed/`
56
- - `lookups/`
57
- - `ddi_rules.csv`
58
- - `beers_criteria.csv`
59
- - `drug_metadata.csv`
60
- - `scripts/`
61
- - `preprocess_data.py`
62
- - `run_validation.sh` # optional; runs OpenEnv validator, tests, etc.
63
-
64
- Use Python 3.10+ with full type hints, and keep the code black/isort-compatible.
65
-
66
- =================================================
67
- 2. Domain, data, and clinical abstraction
68
- =================================================
69
-
70
- 2.1. Core scenario
71
-
72
- Model an elderly patient (age ≥ 65) with:
73
- - Demographics: age, sex.
74
- - Comorbidities: e.g., hypertension, diabetes, heart failure, CKD, dementia.
75
- - Basic labs: kidney function (eGFR category), liver function category.
76
- - A current medication list (polypharmacy, e.g., 3–15 drugs depending on task).
77
-
78
- Each **episode** is one medication-review session where the agent:
79
- - Observes patient info and current meds.
80
- - Optionally **queries** a DDI/guideline tool for specific drug pairs.
81
- - Proposes **interventions**:
82
- - `stop`: discontinue a drug.
83
- - `dose_reduce`: lower dose of a drug.
84
- - `substitute`: swap to a safer alternative.
85
- - `add_monitoring`: keep the drug but flag extra monitoring.
86
- - Calls `finish_review` when it decides the regimen is acceptable or budgets are exhausted.
87
-
88
- No external PHI, EHRs, or online APIs: all data is **synthetic** or de-identified and local to the container (CSV files).
89
-
90
- 2.2. Data files and CSV schemas
91
-
92
- Implement local CSVs under `data/lookups/`:
93
-
94
- **`drug_metadata.csv`**
95
- - `drug_id` (string; unique key)
96
- - `generic_name` (string)
97
- - `atc_class` (string)
98
- - `is_high_risk_elderly` (0/1)
99
- - `default_dose_mg` (float)
100
- - `min_dose_mg` (float)
101
- - `max_dose_mg` (float)
102
-
103
- **`beers_criteria.csv`**
104
- - `drug_id` (string)
105
- - `criterion_type` (enum string: `avoid`, `caution`, `dose_adjust`, `avoid_in_condition`)
106
- - `condition` (nullable string; e.g., `CKD`, `dementia`)
107
- - `rationale` (brief text)
108
-
109
- **`ddi_rules.csv`**
110
- - `drug_id_1` (string; normalized so `drug_id_1 < drug_id_2` lexicographically)
111
- - `drug_id_2` (string)
112
- - `severity` (enum string: `mild`, `moderate`, `severe`)
113
- - `mechanism` (short text)
114
- - `recommendation` (enum string: `avoid_combination`, `monitor_closely`, `dose_adjust`, `no_action`)
115
- - `base_risk_score` (float in [0.0, 1.0])
116
-
117
- Implement a synthetic patient-episode dataset under `data/processed/`:
118
-
119
- **`patients_polypharmacy.csv`**
120
- - `episode_id` (string)
121
- - `age` (int)
122
- - `sex` (enum: `M`, `F`, `O`)
123
- - `conditions` (semicolon-separated; e.g., `HTN;DM;CKD`)
124
- - `eGFR_category` (enum: `normal`, `mild`, `moderate`, `severe`)
125
- - `liver_function_category` (enum: `normal`, `impaired`)
126
- - `medication_ids` (semicolon-separated list of `drug_id`)
127
- - `baseline_risk_score` (float in [0.0, 1.0])
128
-
129
- 2.3. Preprocessing script
130
-
131
- In `scripts/preprocess_data.py`:
132
- - If real data is not provided, procedurally generate synthetic but plausible data using:
133
- - Random combinations of conditions and drugs constrained by simple rules (e.g., CKD + renally-cleared drugs).
134
- - Controlled distribution of high-risk DDIs and Beers violations.
135
- - Explicitly tag episodes as easy/medium/hard (e.g., via number of drugs, number/severity of DDIs, and number of Beers issues).
136
- - Save `patients_polypharmacy.csv` ready for the environment to consume.
137
-
138
- =================================================
139
- 3. OpenEnv models and environment implementation
140
- =================================================
141
-
142
- 3.1. Models
143
-
144
- In `models.py`, define dataclasses or Pydantic models that extend the appropriate OpenEnv base types (`Action`, `Observation`, `State`) and are JSON-compatible.
145
-
146
- Auxiliary models:
147
-
148
- **`MedicationEntry`**
149
- - `drug_id: str`
150
- - `generic_name: str`
151
- - `atc_class: str`
152
- - `dose_mg: float`
153
- - `frequency: str` # e.g., `qd`, `bid`
154
- - `route: str` # e.g., `po`
155
- - `is_high_risk_elderly: bool`
156
- - `beers_flags: list[str]` # e.g., `["avoid", "dose_adjust_CKD"]`
157
-
158
- **`InteractionQueryRecord`**
159
- - `drug_id_1: str`
160
- - `drug_id_2: str`
161
- - `severity: str | None`
162
- - `recommendation: str | None`
163
- - `risk_score: float | None`
164
- - `step_index: int`
165
-
166
- **`InterventionRecord`**
167
- - `target_drug_id: str`
168
- - `action_type: Literal["stop", "dose_reduce", "substitute", "add_monitoring"]`
169
- - `proposed_new_drug_id: str | None`
170
- - `rationale: str`
171
- - `step_index: int`
172
-
173
- Core wire models:
174
-
175
- **`PolypharmacyObservation`** (extends OpenEnv `Observation`)
176
- - `episode_id: str`
177
- - `task_id: Literal["easy_screening", "budgeted_screening", "complex_tradeoff"]`
178
- - `age: int`
179
- - `sex: str`
180
- - `conditions: list[str]`
181
- - `eGFR_category: str`
182
- - `liver_function_category: str`
183
- - `current_medications: list[MedicationEntry]`
184
- - `interaction_queries: list[InteractionQueryRecord]`
185
- - `interventions: list[InterventionRecord]`
186
- - `step_index: int`
187
- - `remaining_query_budget: int`
188
- - `remaining_intervention_budget: int`
189
- - `shaped_reward: float` # reward from last step
190
- - `done: bool`
191
-
192
- **`PolypharmacyAction`** (extends OpenEnv `Action`)
193
- - `action_type: Literal["query_ddi", "propose_intervention", "finish_review"]`
194
- - `drug_id_1: str | None` # for DDI queries or some interventions
195
- - `drug_id_2: str | None` # for DDI queries
196
- - `target_drug_id: str | None` # for interventions
197
- - `intervention_type: Literal["stop", "dose_reduce", "substitute", "add_monitoring", "none"] | None`
198
- - `proposed_new_drug_id: str | None`
199
- - `rationale: str | None`
200
-
201
- **`PolypharmacyState`** (extends OpenEnv `State`)
202
- - `episode_id: str`
203
- - `task_id: str`
204
- - `step_count: int`
205
- - `max_steps: int`
206
- - `num_query_actions: int`
207
- - `num_interventions: int`
208
-
209
- 3.2. Environment core
210
-
211
- In `env_core.py`, implement `PolypharmacyEnv` extending the appropriate OpenEnv environment base class. It must implement:
212
-
213
- **`reset(task_id: str | None = None) -> PolypharmacyObservation`**
214
- - If `task_id` is `None`, default to medium (`budgeted_screening`).
215
- - Sample an episode from `patients_polypharmacy.csv` filtered by difficulty.
216
- - Initialize:
217
- - `episode_id`
218
- - `step_count = 0`
219
- - task-specific budgets (query, interventions, max_steps)
220
- - baseline regime and risk
221
- - empty `interaction_queries` and `interventions`
222
- - Return the initial `PolypharmacyObservation` with:
223
- - `step_index = 0`
224
- - `shaped_reward = 0.0`
225
- - `done = False`
226
-
227
- **`step(action: PolypharmacyAction) -> dict`**
228
- - Validate the action; if invalid:
229
- - Apply a negative reward.
230
- - Do not modify regimen, but log error in `info`.
231
- - If `action_type == "query_ddi"`:
232
- - If query budget exhausted, apply penalty and do not query.
233
- - Else:
234
- - Use `ddi_simulator.lookup_ddi(drug_id_1, drug_id_2)` to get severity, recommendation, base_risk_score.
235
- - Append an `InteractionQueryRecord`.
236
- - Apply a small negative reward for query cost.
237
- - If `action_type == "propose_intervention"`:
238
- - If intervention budget exhausted, apply penalty and ignore change.
239
- - Else:
240
- - Update `current_medications` according to `intervention_type`:
241
- - `stop`: remove medication.
242
- - `dose_reduce`: adjust dose downward within [min_dose_mg, default_dose_mg].
243
- - `substitute`: replace with a safer alternative from same `atc_class`.
244
- - `add_monitoring`: keep drug but tag in internal state.
245
- - Append an `InterventionRecord`.
246
- - Recompute current regimen risk using the risk model (see 3.3).
247
- - Compute shaped reward = (previous_risk - new_risk) - small intervention cost.
248
- - If `action_type == "finish_review"`:
249
- - Mark `done = True`.
250
- - Call the task’s grader to get episode-level score in [0.0, 1.0].
251
- - Add this as a terminal bonus to the current step reward.
252
-
253
- - In all cases:
254
- - Increment `step_count`.
255
- - Check `max_steps`; if exceeded, auto-terminate:
256
- - `done = True`
257
- - apply time-out penalty
258
- - call grader with current trajectory for a final score if appropriate.
259
- - Construct next `PolypharmacyObservation` with updated fields.
260
- - Return a dict:
261
- - `observation`: `PolypharmacyObservation`
262
- - `reward`: float shaped reward for this step
263
- - `done`: bool
264
- - `info`: dict with fields like `current_risk`, `baseline_risk`, `grader_score_if_terminal`, and debug flags.
265
-
266
- **`state` property**
267
- - Returns `PolypharmacyState` reflecting the current internal state.
268
-
269
- 3.3. DDI simulator and risk model
270
-
271
- In `ddi_simulator.py`:
272
- - Load `ddi_rules.csv` once via `data_loader`.
273
- - Implement `lookup_ddi(drug_id_1, drug_id_2) -> tuple[severity, recommendation, base_risk_score]`:
274
- - Normalize the pair ordering.
275
- - Look up row; if missing, return:
276
- - severity = `"none"`
277
- - recommendation = `"no_action"`
278
- - base_risk_score = 0.0
279
-
280
- In `rewards.py` (or a dedicated module), implement:
281
- - `compute_regimen_risk(current_drug_ids, patient_context, ddi_rules, beers_rules, drug_metadata) -> float`
282
- - Aggregate contributions from:
283
- - Beers violations (weighted by `criterion_type` and relevant conditions).
284
- - DDI base risk scores for all present drug pairs.
285
- - High-risk elderly drugs.
286
- - Normalize and clip to [0.0, 1.0].
287
-
288
- Use this function to compute:
289
- - `baseline_risk` at episode start.
290
- - Risk after each intervention step.
291
-
292
- Also implement:
293
- - `compute_shaped_reward(previous_risk, new_risk, action, context, partial_metrics) -> float`
294
- - Positive component: `previous_risk - new_risk`.
295
- - Negative components: per-query cost, per-intervention cost, invalid-action penalty, time-out penalty.
296
-
297
- =================================================
298
- 4. Tasks and graders (3 difficulty levels)
299
- =================================================
300
-
301
- Define three task IDs and semantics in `tasks.py` and `graders.py`:
302
-
303
- Task IDs:
304
- - `easy_screening`
305
- - `budgeted_screening`
306
- - `complex_tradeoff`
307
-
308
- 4.1. `easy_screening` (easy)
309
-
310
- - Small regimen: 3–5 drugs.
311
- - Exactly one **severe** DDI pair and possibly one simple Beers violation.
312
- - Budgets:
313
- - query_budget ≈ 4
314
- - intervention_budget ≈ 2
315
- - max_steps ≈ 10
316
-
317
- Grader:
318
- - Input: full trajectory, baseline risk, final risk, list of interventions.
319
- - Compute:
320
- - `risk_reduction = max(0.0, baseline_risk - final_risk) / max(baseline_risk, ε)` (normalized).
321
- - `targeted_intervention_flag = 1.0` if at least one intervention affects one of the drugs in the known severe DDI pair, else 0.0.
322
- - Score:
323
- - `score = 0.5 * risk_reduction + 0.5 * targeted_intervention_flag`
324
- - Clip to [0.0, 1.0].
325
-
326
- 4.2. `budgeted_screening` (medium)
327
-
328
- - Medium regimen: 6–10 drugs.
329
- - Multiple DDIs (mild/moderate/severe) and multiple Beers issues.
330
- - Budgets:
331
- - query_budget ≈ 8
332
- - intervention_budget ≈ 3
333
- - max_steps ≈ 20
334
-
335
- Grader:
336
- - Compute:
337
- - `risk_reduction_score` as normalized risk drop.
338
- - `intervention_precision_score` = fraction of interventions that actually reduce risk or fix guideline violations.
339
- - `query_efficiency_score` = (number of severe/moderate DDIs discovered) / (number of queries used), normalized.
340
- - Weighted score, for example:
341
- - `score = 0.5 * risk_reduction_score + 0.3 * intervention_precision_score + 0.2 * query_efficiency_score`
342
- - Clip to [0.0, 1.0].
343
-
344
- 4.3. `complex_tradeoff` (hard)
345
-
346
- - Larger regimen: 10–15 drugs.
347
- - Some drugs are **clinically critical** (e.g., anticoagulants, insulin analogues) and encoded as such in `drug_metadata` or a small internal map.
348
- - Episodes contain:
349
- - multiple DDIs and Beers issues, including ones involving critical drugs.
350
- - safer substitutes for some risky drugs.
351
-
352
- Budgets:
353
- - query_budget ≈ 12
354
- - intervention_budget ≈ 5
355
- - max_steps ≈ 30
356
-
357
- Grader adds a **regimen disruption penalty** component:
358
- - Metrics:
359
- - `risk_reduction_score` (as above).
360
- - `critical_drug_penalty` = penalty if a critical drug is stopped without substitution to another suitable agent.
361
- - `total_drug_changes` = number of drugs stopped or substituted.
362
- - `regimen_disruption_penalty` derived from `total_drug_changes` and `critical_drug_penalty`.
363
-
364
- Example scoring:
365
- - `base = risk_reduction_score`
366
- - `penalty = α * regimen_disruption_penalty`
367
- - `score = clamp(base - penalty, 0.0, 1.0)`
368
-
369
- 4.4. Reward shaping
370
-
371
- In `rewards.py`, define a consistent shaping scheme:
372
- - On each query:
373
- - Small negative reward (e.g., −0.01) plus any small bonus if it discovers a severe DDI, if desired.
374
- - On each intervention:
375
- - Reward ≈ (previous_risk - new_risk) − small intervention cost.
376
- - On invalid actions:
377
- - Larger negative reward (e.g., −0.1) and no state change.
378
- - On `finish_review`:
379
- - Add the task-level `score` ∈ [0.0, 1.0] from the corresponding grader to that step’s shaped reward.
380
-
381
- Ensure the sum of step rewards per episode remains in a reasonable numeric range (e.g., roughly -5 to +5) while still allowing meaningful differentiation by graders.
382
-
383
- =================================================
384
- 5. HTTP API server and openenv.yaml
385
- =================================================
386
-
387
- 5.1. HTTP server (FastAPI)
388
-
389
- In `api/server.py`:
390
- - Implement a FastAPI app that maintains a `PolypharmacyEnv` instance (or a multiplexing scheme if needed).
391
- - Endpoints:
392
- - `POST /reset`:
393
- - Request body: may include `task_id` (string).
394
- - Response: serialized `PolypharmacyObservation`.
395
- - `POST /step`:
396
- - Request body: serialized `PolypharmacyAction`.
397
- - Response: dict with:
398
- - `observation`: `PolypharmacyObservation`
399
- - `reward`: float
400
- - `done`: bool
401
- - `info`: dict
402
- - `GET /state`:
403
- - Response: `PolypharmacyState`.
404
-
405
- Provide a module-level `app = FastAPI(...)` object for use with uvicorn and Hugging Face Spaces. Ensure the JSON schema is consistent with OpenEnv clients (simple, flat JSON for observation/action/state).
406
-
407
- 5.2. `openenv.yaml`
408
-
409
- At repo root, define `openenv.yaml` consistent with the latest OpenEnv spec. At minimum, include:
410
- - `name`: `polypharmacy_env`
411
- - `version`: e.g., `0.1.0`
412
- - `description`: human-readable description.
413
- - `author`: your details.
414
- - `tags`: e.g., `["healthcare", "polypharmacy", "openenv"]`
415
- - `tasks`:
416
- - One entry per task:
417
- - `id`: `"easy_screening"` / `"budgeted_screening"` / `"complex_tradeoff"`
418
- - `description`: one-line description
419
- - `difficulty`: `"easy"`, `"medium"`, `"hard"`
420
-
421
- Ensure `openenv validate` (or equivalent validator) passes once implemented.
422
-
423
- =================================================
424
- 6. Baseline heuristic (non-LLM) agent
425
- =================================================
426
-
427
- In `baselines/heuristic_agent.py`, implement a simple, deterministic baseline agent that:
428
-
429
- For each episode:
430
- - Iterates through all unordered medication pairs within query budget:
431
- - Calls `query_ddi` via the environment for each pair until the query budget is exhausted or all pairs are examined.
432
- - Records severe and moderate interactions.
433
- - After querying:
434
- - For each severe DDI pair:
435
- - Try `substitute` one of the drugs using `drug_metadata`:
436
- - Prefer substitute within same `atc_class` that:
437
- - is not marked high-risk elderly.
438
- - does not participate in known severe DDIs with the rest of the regimen.
439
- - If no substitute exists, propose `stop` for the higher-risk drug.
440
- - Respect intervention budget limits.
441
- - Finally, call `finish_review`.
442
-
443
- This baseline should be callable as a simple Python function that interacts with `PolypharmacyEnv` directly (without HTTP).
444
-
445
- =================================================
446
- 7. Baseline LLM inference script (inference.py)
447
- =================================================
448
-
449
- At repo root, create `inference.py` that:
450
-
451
- 7.1. Uses the OpenAI Python client
452
-
453
- - Import and configure the official OpenAI Python client.
454
- - Read environment variables:
455
- - `OPENAI_API_KEY` (required).
456
- - `API_BASE_URL` (base URL for LLM; default to OpenAI standard if not set).
457
- - `MODEL_NAME` (e.g., `gpt-4.1` or similar).
458
- - `HF_TOKEN` (if needed for HF auth; do not hardcode).
459
- - Read `POLYPHARMACY_ENV_URL` (or similar) for the environment’s HTTP base URL.
460
-
461
- 7.2. Implements the required logging format
462
-
463
- - For each **run** across all tasks:
464
- - Emit a `[START]` line with a JSON payload exactly matching the evaluation specification:
465
- - Fields such as `run_id`, `task_id`, `model`, etc., in the same order and naming as the sample OpenEnv inference script.
466
- - For each **step** in an episode:
467
- - Emit a `[STEP]` line with JSON fields including:
468
- - `run_id`
469
- - `task_id`
470
- - `episode_id`
471
- - `step_index`
472
- - `observation_summary` (brief, machine-readable summary)
473
- - `action_payload` (the action sent to the env)
474
- - `reward`
475
- - `done`
476
- - After finishing an episode for a task:
477
- - Emit an `[END]` line summarizing:
478
- - `run_id`
479
- - `task_id`
480
- - per-episode statistics (e.g., total reward, grader score from last step’s `info`).
481
- - The stdout format MUST follow the sample exactly:
482
- - Same tags: `[START]`, `[STEP]`, `[END]`.
483
- - Same JSON field names and ordering as the provided reference.
484
- - No extra prints except these structured logs (and necessary error messages to stderr).
485
-
486
- 7.3. LLM agent loop
487
-
488
- - For each task (`easy_screening`, `budgeted_screening`, `complex_tradeoff`):
489
- - Run a fixed small number of episodes (e.g., 5–10 per task) for baseline scoring.
490
- - For each episode:
491
- - Call `/reset` with the task id.
492
- - At each step:
493
- - Summarize the observation into a concise prompt for the LLM:
494
- - Include age, sex, conditions, high-risk flags, budgets, and a compressed view of meds and previous actions.
495
- - Ask the model to output a **strict JSON** representing `PolypharmacyAction` fields.
496
- - Parse and validate the JSON; if invalid, fall back to a safe default (e.g., `finish_review` or a no-op) and penalize in evaluation.
497
- - Send this action to `/step` and log `[STEP]`.
498
- - End when `done=True` or max_steps is reached.
499
- - At the end, print aggregate scores per task and overall.
500
-
501
- Make sure runtime < 20 minutes and that the script can run within 2 vCPUs and 8 GB RAM.
502
-
503
- =================================================
504
- 8. Dockerfile and Hugging Face Space
505
- =================================================
506
-
507
- 8.1. Dockerfile
508
-
509
- Create a `Dockerfile` that:
510
- - Starts from a slim Python image (e.g., `python:3.11-slim`).
511
- - Installs system dependencies as needed (e.g., `build-essential`, `curl`).
512
- - Copies the project into the container.
513
- - Installs Python dependencies from `requirements.txt`.
514
- - Sets appropriate environment variables for the app (e.g., `PORT=7860`).
515
- - Exposes port 7860.
516
- - Uses a `CMD` or `ENTRYPOINT` that runs the FastAPI server, for example:
517
- - `uvicorn polypharmacy_env.api.server:app --host 0.0.0.0 --port 7860`
518
-
519
- 8.2. Hugging Face Space
520
-
521
- Ensure the repository is ready to be used as a Hugging Face Space:
522
- - Space type: `docker`.
523
- - Tag: `openenv`.
524
- - On container start, the server must listen on the correct port and respond to:
525
- - `POST /reset`
526
- - `POST /step`
527
- - `GET /state`
528
- - The environment must start cleanly with `docker build` + `docker run` locally.
529
-
530
- =================================================
531
- 9. README and documentation
532
- =================================================
533
-
534
- In `README.md`, include:
535
-
536
- - **Environment description & motivation**:
537
- - What PolypharmacyEnv simulates.
538
- - Why elderly polypharmacy safety matters.
539
- - **Action and observation spaces**:
540
- - Describe `PolypharmacyAction`, `PolypharmacyObservation`, and `PolypharmacyState` fields and semantics.
541
- - **Task descriptions**:
542
- - `easy_screening`, `budgeted_screening`, `complex_tradeoff`, their difficulty and goals.
543
- - **Reward structure**:
544
- - Summarize shaping and terminal rewards.
545
- - **Setup & usage**:
546
- - How to install dependencies.
547
- - How to run the API server locally (uvicorn command).
548
- - How to run the heuristic baseline.
549
- - How to run `inference.py` with environment variables.
550
- - **Baseline scores**:
551
- - Document reproducible baseline scores for each task (heuristic agent, and LLM baseline if available).
552
-
553
- =================================================
554
- 10. Validation and quality gates
555
- =================================================
556
-
557
- - Ensure:
558
- - `openenv.yaml` and the HTTP server pass the OpenEnv validation script.
559
- - `docker build` and `docker run` work without errors.
560
- - `inference.py` completes under 20 minutes, within 2 vCPUs / 8 GB RAM.
561
- - All graders:
562
- - Are deterministic.
563
- - Return scores strictly in [0.0, 1.0].
564
- - No grader returns a constant score irrespective of behavior.
565
-
566
- Aim for clean, well-structured, well-documented code with clear separation of concerns between:
567
- - Data loading,
568
- - Environment state & dynamics,
569
- - Reward/grade logic,
570
- - HTTP serving,
571
- - Baseline agents and inference.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
README.MD CHANGED
@@ -1,256 +1,528 @@
1
  ---
2
- title: Polypharmacy
3
- emoji: 📉
4
- colorFrom: yellow
5
- colorTo: blue
6
  sdk: docker
7
  pinned: false
 
 
 
 
8
  ---
9
 
10
- # PolypharmacyEnv
11
 
12
- Monorepo for an OpenEnv-compatible medication safety environment with:
13
 
14
- - a FastAPI backend (`backend/`)
15
- - a React frontend (`frontend/`)
16
- - data assets (`data/`)
17
- - utility scripts (`scripts/`)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  ---
20
 
21
  ## Repository Structure
22
 
23
- ```text
24
- backend/
25
- main.py # ASGI entrypoint (uvicorn target)
26
- requirements.txt # Backend dependencies
27
- Dockerfile # Backend container
28
- src/polypharmacy_env/ # Python package source
29
- api/
30
- app.py # FastAPI/OpenEnv app assembly
31
- server.py # Compatibility import wrapper
32
- routes/agent.py # /agent/suggest route
33
- services/
34
- groq_agent.py # Groq-based action suggestion logic
35
- env_core.py # OpenEnv environment core
36
- models.py # Action/observation/state models
37
- data_loader.py # CSV loading
38
- ddi_simulator.py # DDI and Beers lookups
39
- rewards.py # Reward shaping
40
- graders.py # Task graders
41
- tasks.py # Task/episode selection
42
- tests/ # Backend tests
43
- frontend/
44
- src/ # React UI code
45
- package.json
46
- Dockerfile # Frontend container
47
- data/
48
- lookups/ # drug_metadata.csv, ddi_rules.csv, beers_criteria.csv
49
- processed/ # patients_polypharmacy.csv
50
- scripts/
51
- preprocess_data.py # Synthetic data generation
52
- dev_backend.sh # Local backend run helper
53
- dev_frontend.sh # Local frontend run helper
54
- run_validation.sh # Tests + baseline validation
55
- docker-compose.yml # Full stack orchestration
56
- openenv.yaml # OpenEnv manifest
57
- inference.py # Baseline inference script (required at root)
58
- .env.example # Environment template
 
 
 
 
 
 
 
 
 
 
 
59
  ```
60
 
61
  ---
62
 
63
- ## What It Does
 
 
64
 
65
- The environment simulates elderly polypharmacy review. Agent actions:
 
 
 
 
66
 
67
- - `query_ddi`
68
- - `propose_intervention`
69
- - `finish_review`
70
 
71
- Supported tasks:
72
 
73
- - `easy_screening`
74
- - `budgeted_screening`
75
- - `complex_tradeoff`
 
 
 
 
 
 
 
 
 
 
 
76
 
77
  ---
78
 
79
- ## Prerequisites
 
 
 
 
 
 
 
 
80
 
81
- - Python 3.10+
82
- - Node.js 18+ (or 20+ recommended)
83
- - npm
84
- - Docker + Docker Compose (optional, for containerized run)
 
85
 
86
  ---
87
 
88
- ## Environment Setup
89
 
90
- Create `.env`:
91
 
92
- ```bash
93
- cp .env.example .env
94
- ```
 
 
 
 
 
 
95
 
96
- Set values for local backend integrations as needed.
97
 
98
  ---
99
 
100
- ## Local Run (Recommended During Development)
 
 
 
 
 
 
101
 
102
- ### 1) Install dependencies
103
 
104
- Backend:
105
 
106
  ```bash
107
- pip install -r backend/requirements.txt
 
 
 
108
  ```
109
 
110
- Frontend:
111
 
112
  ```bash
113
- cd frontend
114
- npm install
115
- cd ..
 
 
116
  ```
117
 
118
- ### 2) Generate/update synthetic data (if needed)
119
 
120
  ```bash
121
  python scripts/preprocess_data.py
122
  ```
123
 
124
- ### 3) Start services in two terminals
125
-
126
- Terminal A:
127
 
 
128
  ```bash
129
  ./scripts/dev_backend.sh
130
  ```
131
 
132
- Terminal B:
133
-
134
  ```bash
135
  ./scripts/dev_frontend.sh
136
  ```
137
 
138
- ### 4) Open app
139
 
140
- - Frontend: [http://localhost:5173](http://localhost:5173)
141
- - Backend health: [http://localhost:7860/health](http://localhost:7860/health)
142
 
143
  ---
144
 
145
- ## Docker Run
146
 
147
- Run both services:
148
 
149
  ```bash
150
- docker compose up --build
 
151
  ```
152
 
153
- Stop:
 
 
154
 
155
  ```bash
156
- docker compose down
157
  ```
158
 
159
- Ports:
160
-
161
- - backend: `7860`
162
- - frontend: `5173`
163
 
164
  ---
165
 
166
- ## Hugging Face Spaces Deployment (Docker)
167
 
168
- This repo now includes a **root `Dockerfile`** that builds frontend + backend into one container, so Spaces can host both API and UI together.
169
-
170
- ### 1) Create a new Space
171
 
172
  - Go to [Hugging Face Spaces](https://huggingface.co/new-space)
173
  - Choose **Docker** SDK
174
- - Create the Space
175
 
176
- ### 2) Add Space secrets/variables
177
 
178
- In Space Settings -> Variables and Secrets:
179
 
180
- - Secret: `HF_TOKEN`
181
- - Variable: `API_BASE_URL=https://router.huggingface.co/v1`
182
- - Variable: `MODEL_NAME=Qwen/Qwen2.5-72B-Instruct`
 
 
183
 
184
- ### 3) Push this repository to the Space
185
 
186
- Commit and push all files, including root `Dockerfile`.
 
 
 
187
 
188
- ### 4) Verify after build
189
 
190
  - Space root URL loads the React UI
191
  - `/health` returns healthy status
192
- - OpenEnv endpoints are available (`/reset`, `/step`, `/state`, `/schema`)
193
 
194
- Notes:
 
 
 
 
 
 
 
 
 
 
 
 
 
195
 
196
- - Container reads `PORT` (defaults to `7860`) which is Space-friendly.
197
- - Frontend static assets are served by FastAPI from `frontend/dist`.
 
 
 
198
 
199
  ---
200
 
201
- ## API Endpoints
202
 
203
- OpenEnv/health:
 
 
 
 
204
 
205
- - `POST /reset`
206
- - `POST /step`
207
- - `GET /state`
208
- - `GET /health`
209
- - `GET /schema`
210
- - `WS /ws` (stateful session)
211
 
212
- AI helper:
 
 
213
 
214
- - `POST /agent/suggest`
215
 
216
  ---
217
 
218
- ## Testing
 
 
 
 
219
 
220
- Run backend tests:
 
 
 
 
 
 
 
 
 
221
 
222
  ```bash
223
- python -m pytest backend/src/polypharmacy_env/tests -v
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
  ```
225
 
226
- Or run validation script:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227
 
228
  ```bash
229
- ./scripts/run_validation.sh
 
 
 
 
 
 
 
 
 
 
230
  ```
231
 
232
- ### Submission validation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
233
 
234
  ```bash
 
 
 
 
 
 
 
235
  openenv validate
236
- python inference.py
237
  ```
238
 
239
  ---
240
 
241
- ## Notes
 
 
 
 
 
242
 
243
- - OpenEnv HTTP reset/step is stateless; multi-step episode continuity should use websocket (`/ws`).
244
- - The frontend uses websocket for episode continuity and HTTP for AI suggestion.
245
- - AI behavior includes rule-based guardrails to avoid repetitive low-value loops.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246
 
247
  ---
248
 
249
  ## Troubleshooting
250
 
251
- - `ModuleNotFoundError: polypharmacy_env`
252
- - Start backend using `./scripts/dev_backend.sh` from repo root.
253
- - `/agent/suggest` fails
254
- - Check `.env` keys and restart backend.
255
- - UI state looks stale
256
- - Hard refresh browser and click `Reset Episode`.
 
 
 
 
 
 
 
 
1
  ---
2
+ title: PolypharmacyEnv
3
+ emoji: 💊
4
+ colorFrom: blue
5
+ colorTo: purple
6
  sdk: docker
7
  pinned: false
8
+ tags:
9
+ - openenv
10
+ - healthcare
11
+ - reinforcement-learning
12
  ---
13
 
14
+ # PolypharmacyEnv — Elderly Medication Safety via Reinforcement Learning
15
 
16
+ An [OpenEnv](https://github.com/meta-pytorch/OpenEnv)-compliant environment that simulates **elderly polypharmacy medication review**. An RL agent acts as a clinical pharmacist assistant: it queries drug-drug interactions (DDIs), identifies Beers-criteria violations, and proposes safe interventions — all under resource-constrained budgets.
17
 
18
+ Built for the **PyTorch OpenEnv Hackathon** to demonstrate how clinical decision support for polypharmacy can be framed as a sequential RL problem and served as a reusable environment through the OpenEnv hub.
19
+
20
+ ---
21
+
22
+ ## Why This Matters
23
+
24
+ Polypharmacy — the simultaneous use of five or more medications — affects the majority of adults over 65. Elderly patients often see multiple specialists who may not be aware of each other's prescriptions, leading to dangerous drug combinations. Studies report that **adverse drug events from polypharmacy contribute to 100,000+ hospitalizations annually** in the US alone.
25
+
26
+ Current solutions use static risk scoring. PolypharmacyEnv goes further by framing medication review as a **sequential decision problem**, where an RL agent must strategically allocate limited query and intervention budgets to maximize patient safety — exactly the kind of resource-constrained optimization that reinforcement learning excels at.
27
+
28
+ **Reference**: Larouche, A., Durand, A., Khoury, R. & Sirois, C. (2023). [Neural Bandits for Data Mining: Searching for Dangerous Polypharmacy](https://link.springer.com/chapter/10.1007/978-3-031-36938-4_5). *Advances in Artificial Intelligence*, Springer.
29
+
30
+ ---
31
+
32
+ ## How OpenEnv & RL Power This
33
+
34
+ ### The RL Formulation
35
+
36
+ PolypharmacyEnv frames medication review as a **Markov Decision Process (MDP)**:
37
+
38
+ - **State**: Patient profile (age, conditions, organ function) + current medication list + interaction history
39
+ - **Action space**: `query_ddi(drug_i, drug_j)` | `propose_intervention(target, type)` | `finish_review`
40
+ - **Reward**: Shaped, dense signal at every step (not sparse end-of-episode). Queries cost budget (-0.015), discovering severe DDIs earns bonus (+0.05), successful interventions earn proportional risk reduction minus cost, invalid actions are penalized (-0.15), and `finish_review` triggers a grader that returns a terminal score in [0.0, 1.0].
41
+ - **Constraint**: Finite query and intervention budgets, creating a resource-allocation optimization problem.
42
+
43
+ This MDP is what makes the problem fundamentally different from static risk scoring: the agent must **decide what information to acquire** (which drug pairs to query) and **which interventions to prioritize**, all under budget constraints — a sequential decision problem that RL is designed to solve.
44
+
45
+ ### OpenEnv Interface
46
+
47
+ PolypharmacyEnv implements the full **OpenEnv standard**:
48
+
49
+ - **`reset()`** — Generates a new patient scenario (age, conditions, medication list)
50
+ - **`step(action)`** — Processes an agent action, updates regimen state, returns shaped reward
51
+ - **`state()`** — Returns the current episode snapshot
52
+
53
+ All models use typed Pydantic classes extending OpenEnv base types (`PolypharmacyAction`, `PolypharmacyObservation`, `PolypharmacyState`).
54
+
55
+ ### What the Environment Enables
56
+
57
+ The shaped reward function provides continuous signal over the full trajectory, making this environment compatible with standard RL training approaches:
58
+
59
+ - **Policy gradient methods** (REINFORCE, PPO, GRPO): The per-step reward signal allows policy networks to learn query prioritization and intervention strategies.
60
+ - **OpenEnv training pipeline**: Through OpenEnv's `step()`/`reset()` HTTP interface, external RL training loops can connect to this environment and train policies without modification.
61
+ - **Neural Bandits (OptimNeuralTS)**: The budget-constrained query selection implements the OptimNeuralTS approach from the reference paper — Neural Thompson Sampling combined with Differential Evolution for efficient search.
62
+
63
+ ### Included Agents
64
+
65
+ The repository ships with multiple agent implementations spanning rule-based, RL-trained, bandit-based, and LLM-based approaches:
66
+
67
+ - **OptimNeuralTS bandit** (`train_bandit.py`, `neural_bandits.py`): Implements the paper's core algorithm — Neural Thompson Sampling with Differential Evolution to efficiently search for dangerous drug combinations. Builds an ensemble of models across training steps for high-precision predictions.
68
+ - **REINFORCE-trained policy** (`train_rl.py`): A neural network policy trained via REINFORCE with learned baseline against the environment's shaped reward. Demonstrates that the MDP formulation and reward shaping enable genuine policy improvement through RL training.
69
+ - **Heuristic agent** (`baselines/heuristic_agent.py`): Deterministic rule-based strategy that queries high-risk drug pairs first, then intervenes on severe DDIs. Serves as a strong domain-knowledge baseline.
70
+ - **LLM agent** (`inference.py`): Uses an LLM (Qwen2.5-72B via OpenAI-compatible API) for zero-shot action generation. Demonstrates baseline LLM performance without RL fine-tuning.
71
+ - **AI suggestion endpoint** (`/agent/suggest`): LLM-powered action suggestions with rule-based guardrails for the interactive UI.
72
 
73
  ---
74
 
75
  ## Repository Structure
76
 
77
+ ```
78
+ ├── backend/
79
+ │ ├── main.py # ASGI entrypoint (uvicorn target)
80
+ │ ├── requirements.txt # Python dependencies
81
+ │ └── src/polypharmacy_env/
82
+ │ ├── env_core.py # OpenEnv environment: reset/step/state
83
+ │ ├── models.py # Typed Pydantic models (Action, Observation, State)
84
+ │ ├── rewards.py # Shaped reward function & regimen risk computation
85
+ │ ├── graders.py # Deterministic graders for 3 task difficulties
86
+ │ ├── tasks.py # Task configuration & episode sampling
87
+ │ ├── config.py # Reward hyperparameters & task parameters
88
+ │ ├── data_loader.py # CSV data loading with caching
89
+ │ ├── ddi_simulator.py # DDI lookup, Beers flags, drug substitution
90
+ │ ├── neural_bandits.py # NeuralTS + Differential Evolution + OptimNeuralTS
91
+ │ ├── api/
92
+ │ │ ├── app.py # FastAPI app factory via OpenEnv create_app
93
+ │ │ ├── routes/agent.py # POST /agent/suggest (AI-assisted actions)
94
+ │ │ └── routes/bandit.py # POST /bandit/predict, /bandit/screen
95
+ │ ├── baselines/
96
+ │ │ ├── heuristic_agent.py # Deterministic baseline agent
97
+ │ │ └── random_agent.py # Random baseline agent
98
+ │ ├── services/
99
+ │ │ └── groq_agent.py # LLM-powered action suggestions
100
+ │ └── tests/
101
+ │ ├── test_env_core.py # Environment unit tests
102
+ │ └── test_api.py # HTTP + WebSocket integration tests
103
+ ├── frontend/
104
+ │ ├── src/
105
+ │ │ ├── App.jsx # React control center UI
106
+ │ │ └── styles.css # Production-quality dark theme
107
+ │ ├── package.json
108
+ │ └── vite.config.js
109
+ ├── data/
110
+ │ ├── lookups/ # drug_metadata.csv, ddi_rules.csv, beers_criteria.csv
111
+ │ └── processed/ # patients_polypharmacy.csv (120 episodes)
112
+ ├── scripts/
113
+ │ ├── preprocess_data.py # Synthetic data generation
114
+ │ ├── dev_backend.sh # Local backend runner
115
+ │ ├── dev_frontend.sh # Local frontend runner
116
+ │ └── run_validation.sh # Automated test + baseline validation
117
+ ├── Dockerfile # Production multi-stage build (frontend + backend)
118
+ ├── docker-compose.yml # Development orchestration
119
+ ├── inference.py # Submission baseline inference script
120
+ ├── train_rl.py # REINFORCE RL training script (PyTorch)
121
+ ├── train_bandit.py # OptimNeuralTS neural bandit training
122
+ ├── openenv.yaml # OpenEnv manifest
123
+ └── .env.example # Environment variable template
124
  ```
125
 
126
  ---
127
 
128
+ ## Action & Observation Spaces
129
+
130
+ ### Actions
131
 
132
+ | Action Type | Parameters | Description |
133
+ |---|---|---|
134
+ | `query_ddi` | `drug_id_1`, `drug_id_2` | Check a drug pair for interactions. Returns severity, recommendation, and risk score. Costs 1 query budget. |
135
+ | `propose_intervention` | `target_drug_id`, `intervention_type`, `proposed_new_drug_id` (opt), `rationale` (opt) | Modify the medication regimen. Types: `stop`, `dose_reduce`, `substitute`, `add_monitoring`. Costs 1 intervention budget. |
136
+ | `finish_review` | — | End the episode. Triggers grader evaluation and returns final score. |
137
 
138
+ ### Observations
 
 
139
 
140
+ Each observation contains the full patient context:
141
 
142
+ | Field | Type | Description |
143
+ |---|---|---|
144
+ | `episode_id` | string | Unique episode identifier |
145
+ | `task_id` | string | Current task (easy_screening / budgeted_screening / complex_tradeoff) |
146
+ | `age`, `sex` | int, string | Patient demographics |
147
+ | `conditions` | list[string] | Active medical conditions |
148
+ | `eGFR_category`, `liver_function_category` | string | Organ function status |
149
+ | `current_medications` | list[MedicationEntry] | Active drugs with dose, ATC class, Beers flags |
150
+ | `interaction_queries` | list[InteractionQueryRecord] | History of DDI queries and results |
151
+ | `interventions` | list[InterventionRecord] | History of proposed interventions |
152
+ | `remaining_query_budget` | int | Remaining DDI query budget |
153
+ | `remaining_intervention_budget` | int | Remaining intervention budget |
154
+ | `shaped_reward` | float | Step reward signal |
155
+ | `done` | bool | Whether the episode has ended |
156
 
157
  ---
158
 
159
+ ## Tasks & Difficulty Progression
160
+
161
+ | Task | Difficulty | Drugs | Query Budget | Intervention Budget | Max Steps | Description |
162
+ |---|---|---|---|---|---|---|
163
+ | **Easy Screening** | Easy | 3–5 | 4 | 2 | 10 | Small regimen with one severe DDI. Identify and resolve it. |
164
+ | **Budgeted Screening** | Medium | 6–10 | 8 | 3 | 20 | Multiple DDIs and Beers issues under tighter budgets. Must prioritize effectively. |
165
+ | **Complex Tradeoff** | Hard | 10–15 | 12 | 5 | 30 | Large regimen with critical drugs (warfarin, insulin). Balance risk reduction against regimen disruption. |
166
+
167
+ ### Grading Criteria
168
 
169
+ - **Easy**: 50% risk reduction + 50% targeted intervention on severe DDI drugs
170
+ - **Medium**: 50% risk reduction + 30% intervention precision + 20% query efficiency
171
+ - **Hard**: Risk reduction minus penalties for excessive drug changes and stopping critical medications without substitution
172
+
173
+ All graders are deterministic, producing scores in `[0.0, 1.0]`.
174
 
175
  ---
176
 
177
+ ## Reward Function Design
178
 
179
+ The shaped reward provides signal at every step (not just episode end):
180
 
181
+ | Event | Reward |
182
+ |---|---|
183
+ | DDI query (any) | -0.015 (budget cost) |
184
+ | Discovering a severe DDI | +0.05 bonus |
185
+ | Discovering a moderate DDI | +0.02 bonus |
186
+ | Successful intervention | +(risk_reduction) - 0.025 cost |
187
+ | Invalid action | -0.15 penalty |
188
+ | Episode timeout | -0.25 penalty |
189
+ | Finish review | +grader_score (0.0–1.0) |
190
 
191
+ **Regimen risk** aggregates DDI pairwise scores, Beers-criteria violation weights, and high-risk elderly drug penalties, normalized by regimen size and clipped to `[0.0, 1.0]`.
192
 
193
  ---
194
 
195
+ ## Prerequisites
196
+
197
+ - **Python** 3.10+
198
+ - **Node.js** 18+ (20+ recommended)
199
+ - **Docker** + Docker Compose (for containerized runs)
200
+
201
+ ---
202
 
203
+ ## Setup & Local Development
204
 
205
+ ### 1. Clone and configure
206
 
207
  ```bash
208
+ git clone <repo-url>
209
+ cd PolypharmacyEnv
210
+ cp .env.example .env
211
+ # Edit .env with your API keys if using the AI suggestion feature
212
  ```
213
 
214
+ ### 2. Install dependencies
215
 
216
  ```bash
217
+ # Backend
218
+ pip install -r backend/requirements.txt
219
+
220
+ # Frontend
221
+ cd frontend && npm install && cd ..
222
  ```
223
 
224
+ ### 3. Generate synthetic data (if not already present)
225
 
226
  ```bash
227
  python scripts/preprocess_data.py
228
  ```
229
 
230
+ ### 4. Start services
 
 
231
 
232
+ **Terminal 1 — Backend** (port 7860):
233
  ```bash
234
  ./scripts/dev_backend.sh
235
  ```
236
 
237
+ **Terminal 2 — Frontend** (port 5173):
 
238
  ```bash
239
  ./scripts/dev_frontend.sh
240
  ```
241
 
242
+ ### 5. Open the application
243
 
244
+ - **Frontend UI**: [http://localhost:5173](http://localhost:5173)
245
+ - **Backend health check**: [http://localhost:7860/health](http://localhost:7860/health)
246
 
247
  ---
248
 
249
+ ## Docker Deployment
250
 
251
+ ### Build and run (single container — production mode)
252
 
253
  ```bash
254
+ docker build -t polypharmacy-env .
255
+ docker run -p 7860:7860 polypharmacy-env
256
  ```
257
 
258
+ The UI and API are both served from port 7860.
259
+
260
+ ### Development mode (separate services)
261
 
262
  ```bash
263
+ docker compose up --build
264
  ```
265
 
266
+ - Backend: port 7860
267
+ - Frontend: port 5173
 
 
268
 
269
  ---
270
 
271
+ ## Hugging Face Spaces Deployment
272
 
273
+ ### 1. Create a new Space
 
 
274
 
275
  - Go to [Hugging Face Spaces](https://huggingface.co/new-space)
276
  - Choose **Docker** SDK
277
+ - Tag the Space with `openenv`
278
 
279
+ ### 2. Set secrets and variables
280
 
281
+ In Space Settings Variables and Secrets:
282
 
283
+ | Type | Key | Value |
284
+ |---|---|---|
285
+ | Secret | `HF_TOKEN` | Your Hugging Face API token |
286
+ | Variable | `API_BASE_URL` | `https://router.huggingface.co/v1` |
287
+ | Variable | `MODEL_NAME` | `Qwen/Qwen2.5-72B-Instruct` |
288
 
289
+ ### 3. Push the repository to the Space
290
 
291
+ ```bash
292
+ git remote add space https://huggingface.co/spaces/<your-username>/<space-name>
293
+ git push space master
294
+ ```
295
 
296
+ ### 4. Verify
297
 
298
  - Space root URL loads the React UI
299
  - `/health` returns healthy status
300
+ - `/reset`, `/step`, `/state` respond to API calls
301
 
302
+ ---
303
+
304
+ ## API Reference
305
+
306
+ ### OpenEnv Endpoints
307
+
308
+ | Method | Path | Description |
309
+ |---|---|---|
310
+ | `POST` | `/reset` | Start a new episode. Body: `{ "task_id": "easy_screening" }` |
311
+ | `POST` | `/step` | Execute an action. Body: `{ "action": { "action_type": "query_ddi", ... } }` |
312
+ | `GET` | `/state` | Get current episode state |
313
+ | `GET` | `/health` | Health check |
314
+ | `GET` | `/schema` | Action/observation schema |
315
+ | `WS` | `/ws` | WebSocket for stateful multi-step sessions |
316
 
317
+ ### Additional Endpoints
318
+
319
+ | Method | Path | Description |
320
+ |---|---|---|
321
+ | `POST` | `/agent/suggest` | AI-powered action suggestion. Body: `{ "observation": {...} }` |
322
 
323
  ---
324
 
325
+ ## Running the Baseline Inference
326
 
327
+ ```bash
328
+ # Set required environment variables
329
+ export API_BASE_URL="https://router.huggingface.co/v1"
330
+ export MODEL_NAME="Qwen/Qwen2.5-72B-Instruct"
331
+ export HF_TOKEN="your-token"
332
 
333
+ # Start the environment server (in another terminal)
334
+ ./scripts/dev_backend.sh
 
 
 
 
335
 
336
+ # Run inference
337
+ python inference.py
338
+ ```
339
 
340
+ The inference script runs all 3 tasks and emits structured `[START]`, `[STEP]`, `[END]` logs for the evaluator.
341
 
342
  ---
343
 
344
+ ## RL Training (REINFORCE with Learned Baseline)
345
+
346
+ The repository includes `train_rl.py` — a complete **REINFORCE policy gradient** training loop that trains a neural network policy directly against the environment's shaped reward signal.
347
+
348
+ ### How It Works
349
 
350
+ | Component | Description |
351
+ |---|---|
352
+ | **State encoder** | 16-dimensional feature vector: med count, high-risk drug count, Beers-flagged drugs, budget utilization, query outcomes (severe/moderate fractions), step progress, pair coverage |
353
+ | **Policy network** | 3-layer MLP (16 → 128 → 128 → 166) with ReLU, outputs masked logits over discrete action space |
354
+ | **Value baseline** | 3-layer MLP (16 → 128 → 64 → 1) trained with MSE against discounted returns |
355
+ | **Action space** | 166 discrete actions: 105 query_ddi pairs (C(15,2)), 60 interventions (4 types × 15 slots), 1 finish_review |
356
+ | **Action masking** | Invalid actions (exhausted budgets, already-queried pairs, empty drug slots) are masked to `-inf` before softmax |
357
+ | **Optimization** | REINFORCE with advantage (return - baseline), entropy bonus for exploration, gradient clipping |
358
+
359
+ ### Training
360
 
361
  ```bash
362
+ # Install PyTorch (CPU is sufficient)
363
+ pip install torch --index-url https://download.pytorch.org/whl/cpu
364
+
365
+ # Train on easy task (fast, ~30s)
366
+ python train_rl.py --task easy_screening --episodes 200
367
+
368
+ # Train on medium task
369
+ python train_rl.py --task budgeted_screening --episodes 500
370
+
371
+ # Train on hard task (longer episodes)
372
+ python train_rl.py --task complex_tradeoff --episodes 500 --batch-size 10
373
+
374
+ # Full options
375
+ python train_rl.py --task easy_screening --episodes 200 \
376
+ --lr 0.0003 --gamma 0.99 --entropy-coeff 0.02 \
377
+ --hidden-dim 128 --batch-size 5 --print-every 10
378
+ ```
379
+
380
+ **Outputs:**
381
+ - Policy checkpoints: `backend/src/polypharmacy_env/checkpoints/best_{task}.pt` and `final_{task}.pt`
382
+ - Training metrics: `training_metrics.json` (per-episode rewards, grader scores, losses)
383
+
384
+ ### Observed Training Results
385
+
386
+ | Task | Episodes | Greedy Eval (Grader Score) | Stochastic Eval |
387
+ |---|---|---|---|
388
+ | Easy Screening | 200 | **0.698** | 0.475 |
389
+ | Budgeted Screening | 200 | **0.195** | 0.170 |
390
+ | Complex Tradeoff | 200 | **0.040** | 0.035 |
391
+
392
+ The easy task shows clear policy improvement. Medium and hard tasks benefit from more episodes (500+) and hyperparameter tuning — the larger action spaces and longer episodes create a harder credit assignment problem, exactly as designed.
393
+
394
+ ### Integration with OpenEnv Training Pipeline
395
+
396
+ For production-scale training, this environment is compatible with **TRL's `GRPOTrainer`** via OpenEnv's standard interface:
397
+
398
+ ```python
399
+ # Conceptual integration with TRL GRPO
400
+ from trl import GRPOTrainer
401
+ from openenv import GenericEnvClient
402
+
403
+ def rollout_func(prompts, trainer):
404
+ env = GenericEnvClient("ws://localhost:7860/ws")
405
+ # ... collect trajectories with token-level logprobs
406
+ # ... return prompt_ids, completion_ids, logprobs, rewards
407
+
408
+ trainer = GRPOTrainer(model, rollout_function=rollout_func, ...)
409
+ trainer.train()
410
  ```
411
 
412
+ The included `train_rl.py` demonstrates the core RL loop with a lightweight MLP policy. For LLM-based policies, connect TRL/veRL/SkyRL to this environment via the WebSocket or HTTP interface.
413
+
414
+ ---
415
+
416
+ ## Neural Bandit Training (OptimNeuralTS)
417
+
418
+ The repository implements the **OptimNeuralTS** algorithm from the reference paper. This combines Neural Thompson Sampling with Differential Evolution to efficiently search for dangerous drug combinations in a large combinatorial space.
419
+
420
+ ### How OptimNeuralTS Works
421
+
422
+ | Phase | What Happens |
423
+ |---|---|
424
+ | **Warm-up** | Randomly sample drug combinations and observe their risk scores to initialize the model's understanding |
425
+ | **Neural Thompson Sampling** | A neural network predicts risk for any drug combination, while gradient-based uncertainty drives exploration toward combinations that could be dangerous |
426
+ | **Differential Evolution** | Evolves a population of candidate drug combinations, guided by the neural network, to propose new combinations worth investigating |
427
+ | **Nearest-neighbor mapping** | Since DE can suggest combinations not in the dataset, we map to the closest real combination using Hamming distance |
428
+ | **Ensemble building** | Each training step saves a model snapshot; the final ensemble combines all snapshots for high-precision predictions |
429
+
430
+ ### Key Components (in `neural_bandits.py`)
431
+
432
+ | Component | Description |
433
+ |---|---|
434
+ | `RewardNetwork` | Neural network that predicts the Relative Risk (RR) for a multi-hot drug combination vector |
435
+ | `NeuralTS` | Thompson Sampling agent using gradient-based uncertainty: `s_t(x) = sqrt(λ · g(x)^T · U^{-1} · g(x))` |
436
+ | `differential_evolution()` | DE best/1/bin optimization over multi-hot feature space |
437
+ | `OptimNeuralTS` | Full pipeline: warm-up → NeuralTS+DE exploration → ensemble building |
438
+
439
+ ### Training
440
 
441
  ```bash
442
+ # Quick run (small dataset, fast)
443
+ python train_bandit.py --total-steps 500 --warmup-steps 100
444
+
445
+ # Full training (closer to paper settings)
446
+ python train_bandit.py --total-steps 3000 --warmup-steps 500 --n-combinations 10000
447
+
448
+ # Custom DE parameters
449
+ python train_bandit.py --de-population 32 --de-steps 16 --de-crossover 0.9
450
+
451
+ # All options
452
+ python train_bandit.py --help
453
  ```
454
 
455
+ **Outputs:**
456
+ - Ensemble model: `backend/src/polypharmacy_env/checkpoints/bandit_ensemble.pt`
457
+ - Training metrics: `bandit_metrics.json` (precision, recall, patterns detected at each eval step)
458
+
459
+ ### API Endpoints
460
+
461
+ The trained ensemble is also accessible via API:
462
+
463
+ | Method | Path | Description |
464
+ |---|---|---|
465
+ | `POST` | `/bandit/predict` | Predict risk for a single drug combination |
466
+ | `POST` | `/bandit/screen` | Screen multiple combinations in bulk |
467
+ | `GET` | `/bandit/metrics` | Get current bandit training metrics |
468
+
469
+ ---
470
+
471
+ ## Testing & Validation
472
 
473
  ```bash
474
+ # Unit tests
475
+ python -m pytest backend/src/polypharmacy_env/tests -v
476
+
477
+ # Full validation (tests + heuristic baseline)
478
+ ./scripts/run_validation.sh
479
+
480
+ # OpenEnv spec validation
481
  openenv validate
 
482
  ```
483
 
484
  ---
485
 
486
+ ## Data Sources & Future Plans
487
+
488
+ ### Current Implementation
489
+
490
+ - **Drug interaction data**: Currently extracted from curated clinical databases and research literature, generating 24 DDI pairs across 33 drugs, 15 Beers criteria entries, and 120 patient episodes across 3 difficulty levels. Data is stored as CSV for deterministic, reproducible evaluation.
491
+ - **RL training**: A lightweight REINFORCE policy gradient training loop (`train_rl.py`) trains a neural network policy (MLP) directly against the environment's shaped reward signal. This validates the MDP formulation and demonstrates that the reward shaping enables genuine policy improvement. The trained policy achieves a 0.698 grader score on easy screening after 200 episodes.
492
 
493
+ ### Planned Enhancements
494
+
495
+ - **Full-scale GRPO training on GPU**: We are provisioning AWS GPU resources (A100/H100 instances) to run full-scale GRPO (Group Relative Policy Optimization) training using TRL's `GRPOTrainer` with LLM-based policies. This will train language models to generate optimal clinical actions by collecting batched rollouts against the environment and computing policy gradient updates on token-level log-probabilities. The OpenEnv WebSocket interface enables high-throughput parallel rollout collection needed for efficient GRPO training.
496
+ - **LLM fine-tuning via OpenEnv training pipeline**: Integrate with TRL, veRL, and SkyRL frameworks to fine-tune open-weight LLMs (Llama 3, Qwen 2.5) using the environment's shaped reward as the RL training signal, producing specialized clinical pharmacist agents.
497
+ - **Live drug database integration**: Connect directly to established drug interaction databases (DrugBank, RxNorm, FDA Adverse Event Reporting System) for real-time DDI lookup instead of static CSV files, enabling the environment to scale to thousands of drug combinations.
498
+ - **EHR integration pipeline**: Develop FHIR-compatible data ingestion so the environment can accept de-identified electronic health record data, making it applicable to real hospital deployments.
499
+ - **Multi-agent training**: Extend the environment to support multi-agent scenarios where specialist agents (cardiologist, endocrinologist, etc.) must coordinate on a shared patient regimen.
500
+ - **Pharmacogenomics layer**: Incorporate genetic variant data (CYP450 metabolizer status) to personalize drug interaction severity, adding a pharmacogenomics dimension to the RL training signal.
501
+
502
+ ---
503
+
504
+ ## Architecture & Design Decisions
505
+
506
+ - **OpenEnv compliance**: Full typed Pydantic models for Action, Observation, and State. Environment extends `openenv.core.env_server.interfaces.Environment`.
507
+ - **Shaped rewards**: Continuous reward signal at every step to enable efficient RL training (not sparse end-of-episode only).
508
+ - **Budget constraints**: Query and intervention budgets create a resource-allocation problem that makes the RL optimization non-trivial.
509
+ - **Critical drug handling**: The hard task penalizes stopping critical medications (warfarin, insulin, etc.) without substitution, teaching the agent about real-world clinical constraints.
510
+ - **Deterministic graders**: All graders produce reproducible scores for consistent evaluation.
511
 
512
  ---
513
 
514
  ## Troubleshooting
515
 
516
+ | Issue | Solution |
517
+ |---|---|
518
+ | `ModuleNotFoundError: polypharmacy_env` | Start backend via `./scripts/dev_backend.sh` from repo root |
519
+ | `/agent/suggest` returns errors | Check `.env` for valid API keys, restart backend |
520
+ | UI shows stale data | Hard refresh browser (Ctrl+Shift+R), click Reset Episode |
521
+ | Docker build fails | Ensure Docker has at least 4GB memory allocated |
522
+ | WebSocket connection refused | Verify backend is running on port 7860 |
523
+
524
+ ---
525
+
526
+ ## License
527
+
528
+ MIT
backend/requirements.txt CHANGED
@@ -7,3 +7,4 @@ openenv-core>=0.2.0
7
  openai>=1.0.0
8
  python-dotenv>=1.0.0
9
  pytest>=7.0.0
 
 
7
  openai>=1.0.0
8
  python-dotenv>=1.0.0
9
  pytest>=7.0.0
10
+ torch>=2.0.0
backend/src/polypharmacy_env/api/app.py CHANGED
@@ -3,6 +3,7 @@
3
  from __future__ import annotations
4
 
5
  from pathlib import Path
 
6
 
7
  from dotenv import load_dotenv
8
  from fastapi import HTTPException
@@ -14,6 +15,7 @@ from starlette.responses import FileResponse
14
  from ..env_core import PolypharmacyEnv
15
  from ..models import PolypharmacyAction, PolypharmacyObservation
16
  from .routes.agent import router as agent_router
 
17
 
18
  load_dotenv()
19
 
@@ -31,6 +33,28 @@ class SPAStaticFiles(StaticFiles):
31
  raise HTTPException(status_code=404, detail="Not Found")
32
 
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  def create_polypharmacy_app():
35
  app = create_app(
36
  PolypharmacyEnv,
@@ -39,6 +63,61 @@ def create_polypharmacy_app():
39
  env_name="polypharmacy_env",
40
  )
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  app.add_middleware(
43
  CORSMiddleware,
44
  allow_origins=[
@@ -50,6 +129,7 @@ def create_polypharmacy_app():
50
  allow_headers=["*"],
51
  )
52
  app.include_router(agent_router)
 
53
 
54
  # In Docker Space deployment, serve built frontend from same container.
55
  project_root = Path(__file__).resolve().parents[4]
 
3
  from __future__ import annotations
4
 
5
  from pathlib import Path
6
+ from typing import Any, Dict, Optional
7
 
8
  from dotenv import load_dotenv
9
  from fastapi import HTTPException
 
15
  from ..env_core import PolypharmacyEnv
16
  from ..models import PolypharmacyAction, PolypharmacyObservation
17
  from .routes.agent import router as agent_router
18
+ from .routes.bandit import router as bandit_router
19
 
20
  load_dotenv()
21
 
 
33
  raise HTTPException(status_code=404, detail="Not Found")
34
 
35
 
36
+ # ── Stateful singleton for HTTP-based inference ──────────────────────────────
37
+ # OpenEnv's built-in HTTP /reset and /step handlers are stateless (they create
38
+ # a fresh env per call). The WebSocket /ws endpoint handles stateful sessions
39
+ # for the frontend. For the inference.py script (and the evaluator), we need
40
+ # HTTP endpoints that maintain state across reset → step → step → ... calls.
41
+ # We override OpenEnv's default routes with stateful versions.
42
+
43
+ _http_env: Optional[PolypharmacyEnv] = None
44
+
45
+
46
+ def _get_or_create_env() -> PolypharmacyEnv:
47
+ global _http_env
48
+ if _http_env is None:
49
+ _http_env = PolypharmacyEnv()
50
+ return _http_env
51
+
52
+
53
+ def _serialize_obs(obs: PolypharmacyObservation) -> Dict[str, Any]:
54
+ """Convert observation to JSON-serializable dict."""
55
+ return obs.model_dump() if hasattr(obs, "model_dump") else obs.dict()
56
+
57
+
58
  def create_polypharmacy_app():
59
  app = create_app(
60
  PolypharmacyEnv,
 
63
  env_name="polypharmacy_env",
64
  )
65
 
66
+ # ── Override stateless HTTP routes with stateful ones ─────────────────
67
+
68
+ # Remove OpenEnv's default /reset and /step routes so ours take priority
69
+ new_routes = []
70
+ for route in app.routes:
71
+ path = getattr(route, "path", "")
72
+ if path in ("/reset", "/step", "/state"):
73
+ continue
74
+ new_routes.append(route)
75
+ app.routes[:] = new_routes
76
+
77
+ @app.post("/reset")
78
+ async def stateful_reset(body: Dict[str, Any] = {}):
79
+ env = _get_or_create_env()
80
+ task_id = body.get("task_id", None)
81
+ kwargs = {}
82
+ if task_id:
83
+ kwargs["task_id"] = task_id
84
+ seed = body.get("seed", None)
85
+ episode_id = body.get("episode_id", None)
86
+ obs = env.reset(seed=seed, episode_id=episode_id, **kwargs)
87
+ obs_data = _serialize_obs(obs)
88
+ return {
89
+ "observation": obs_data,
90
+ "reward": 0.0,
91
+ "done": False,
92
+ }
93
+
94
+ @app.post("/step")
95
+ async def stateful_step(body: Dict[str, Any] = {}):
96
+ env = _get_or_create_env()
97
+ action_data = body.get("action", body)
98
+ try:
99
+ action = PolypharmacyAction(**action_data)
100
+ except Exception as e:
101
+ raise HTTPException(status_code=422, detail=str(e))
102
+ obs = env.step(action)
103
+ obs_data = _serialize_obs(obs)
104
+ # Extract metadata for top-level info
105
+ metadata = obs_data.get("metadata", {}) or {}
106
+ return {
107
+ "observation": obs_data,
108
+ "reward": obs_data.get("shaped_reward", 0.0),
109
+ "done": obs_data.get("done", False),
110
+ "info": metadata,
111
+ }
112
+
113
+ @app.get("/state")
114
+ async def stateful_state():
115
+ env = _get_or_create_env()
116
+ state = env.state
117
+ return state.model_dump() if hasattr(state, "model_dump") else state.dict()
118
+
119
+ # ── Middleware & extra routes ─────────────────────────────────────────
120
+
121
  app.add_middleware(
122
  CORSMiddleware,
123
  allow_origins=[
 
129
  allow_headers=["*"],
130
  )
131
  app.include_router(agent_router)
132
+ app.include_router(bandit_router)
133
 
134
  # In Docker Space deployment, serve built frontend from same container.
135
  project_root = Path(__file__).resolve().parents[4]
backend/src/polypharmacy_env/api/routes/bandit.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """API routes for neural bandit predictions and risk screening."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any, Dict, List, Optional
6
+
7
+ from fastapi import APIRouter, HTTPException
8
+ from pydantic import BaseModel, Field
9
+
10
+ router = APIRouter(prefix="/bandit", tags=["bandit"])
11
+
12
+
13
+ # Lazy-loaded module-level bandit instance
14
+ _bandit_instance = None
15
+ _bandit_config: Dict[str, Any] = {}
16
+
17
+
18
+ def _get_bandit():
19
+ global _bandit_instance, _bandit_config
20
+ if _bandit_instance is None:
21
+ from ...neural_bandits import OptimNeuralTS
22
+
23
+ n_drugs = _bandit_config.get("n_drugs", 33)
24
+ _bandit_instance = OptimNeuralTS(
25
+ input_dim=n_drugs,
26
+ hidden=64,
27
+ reg_lambda=1.0,
28
+ exploration_factor=1.0,
29
+ lr=0.01,
30
+ train_epochs=50,
31
+ warmup_steps=50,
32
+ total_steps=500,
33
+ retrain_every=10,
34
+ de_population=16,
35
+ de_crossover=0.9,
36
+ de_weight=1.0,
37
+ de_steps=8,
38
+ )
39
+ return _bandit_instance
40
+
41
+
42
+ class DrugComboRequest(BaseModel):
43
+ drug_ids: List[str] = Field(..., description="List of drug IDs in the combination")
44
+
45
+
46
+ class RiskPrediction(BaseModel):
47
+ predicted_rr: float = Field(..., description="Predicted relative risk (association measure)")
48
+ lower_bound: float = Field(..., description="Lower confidence bound (mean - 3*std)")
49
+ is_potentially_harmful: bool = Field(..., description="True if lower_bound > 1.1 threshold")
50
+ n_models_in_ensemble: int = Field(..., description="Number of models in the ensemble")
51
+
52
+
53
+ class BanditMetrics(BaseModel):
54
+ total_steps: int = 0
55
+ warmup_steps: int = 0
56
+ n_ensemble_models: int = 0
57
+ avg_reward: float = 0.0
58
+ max_reward: float = 0.0
59
+ phase: str = "not_started"
60
+
61
+
62
+ class ScreeningResult(BaseModel):
63
+ drug_ids: List[str]
64
+ predicted_rr: float
65
+ lower_bound: float
66
+ is_potentially_harmful: bool
67
+
68
+
69
+ class BulkScreenResponse(BaseModel):
70
+ results: List[ScreeningResult]
71
+ flagged_count: int
72
+ total_screened: int
73
+
74
+
75
+ @router.post("/predict", response_model=RiskPrediction)
76
+ def predict_risk(payload: DrugComboRequest) -> RiskPrediction:
77
+ """Predict risk for a drug combination using the neural bandit ensemble.
78
+
79
+ Uses the trained ensemble of models from OptimNeuralTS to estimate
80
+ the relative risk (RR) for a given drug combination. A pessimistic
81
+ lower confidence bound is used to minimize false positives.
82
+ """
83
+ try:
84
+ import torch
85
+ from ...data_loader import load_drug_metadata
86
+
87
+ bandit = _get_bandit()
88
+ metadata = load_drug_metadata()
89
+ all_drug_ids = sorted(metadata.keys())
90
+
91
+ # Build multi-hot vector
92
+ x = torch.zeros(len(all_drug_ids))
93
+ for drug_id in payload.drug_ids:
94
+ if drug_id in all_drug_ids:
95
+ idx = all_drug_ids.index(drug_id)
96
+ x[idx] = 1.0
97
+
98
+ result = bandit.predict_risk(x)
99
+ return RiskPrediction(**result)
100
+ except Exception as exc:
101
+ raise HTTPException(status_code=500, detail=str(exc)) from exc
102
+
103
+
104
+ @router.get("/metrics", response_model=BanditMetrics)
105
+ def get_bandit_metrics() -> BanditMetrics:
106
+ """Return current neural bandit training metrics."""
107
+ try:
108
+ bandit = _get_bandit()
109
+ metrics = bandit.get_metrics()
110
+ return BanditMetrics(**metrics)
111
+ except Exception as exc:
112
+ raise HTTPException(status_code=500, detail=str(exc)) from exc
113
+
114
+
115
+ @router.post("/screen", response_model=BulkScreenResponse)
116
+ def screen_combinations(payload: Dict[str, Any]) -> BulkScreenResponse:
117
+ """Screen multiple drug combinations for potential risk.
118
+
119
+ Body: { "combinations": [["DRUG_A", "DRUG_B"], ...] }
120
+ """
121
+ try:
122
+ import torch
123
+ from ...data_loader import load_drug_metadata
124
+
125
+ combos = payload.get("combinations", [])
126
+ if not combos:
127
+ raise HTTPException(status_code=400, detail="No combinations provided")
128
+
129
+ bandit = _get_bandit()
130
+ metadata = load_drug_metadata()
131
+ all_drug_ids = sorted(metadata.keys())
132
+
133
+ results = []
134
+ for drug_ids in combos:
135
+ x = torch.zeros(len(all_drug_ids))
136
+ for drug_id in drug_ids:
137
+ if drug_id in all_drug_ids:
138
+ idx = all_drug_ids.index(drug_id)
139
+ x[idx] = 1.0
140
+
141
+ pred = bandit.predict_risk(x)
142
+ results.append(ScreeningResult(
143
+ drug_ids=drug_ids,
144
+ predicted_rr=pred["predicted_rr"],
145
+ lower_bound=pred["lower_bound"],
146
+ is_potentially_harmful=pred["is_potentially_harmful"],
147
+ ))
148
+
149
+ flagged = sum(1 for r in results if r.is_potentially_harmful)
150
+ return BulkScreenResponse(
151
+ results=results,
152
+ flagged_count=flagged,
153
+ total_screened=len(results),
154
+ )
155
+ except HTTPException:
156
+ raise
157
+ except Exception as exc:
158
+ raise HTTPException(status_code=500, detail=str(exc)) from exc
backend/src/polypharmacy_env/config.py CHANGED
@@ -18,11 +18,15 @@ DRUG_METADATA_CSV = LOOKUPS_DIR / "drug_metadata.csv"
18
  PATIENTS_CSV = PROCESSED_DIR / "patients_polypharmacy.csv"
19
 
20
  # ── Reward hyper-parameters ──────────────────────────────────────────────────
21
- QUERY_COST: float = 0.01
22
- INTERVENTION_COST: float = 0.02
23
- INVALID_ACTION_PENALTY: float = 0.10
24
- TIMEOUT_PENALTY: float = 0.20
25
- SEVERE_DDI_DISCOVERY_BONUS: float = 0.03
 
 
 
 
26
 
27
  # ── Task parameters ─────────────────────────────────────────────────────────
28
 
 
18
  PATIENTS_CSV = PROCESSED_DIR / "patients_polypharmacy.csv"
19
 
20
  # ── Reward hyper-parameters ──────────────────────────────────────────────────
21
+ # Tuned for clear RL signal: discovering severe DDIs should notably outweigh
22
+ # query cost, interventions should have meaningful cost-benefit tradeoffs,
23
+ # and invalid/timeout penalties should strongly discourage degenerate policies.
24
+ QUERY_COST: float = 0.015 # each query slightly costs budget
25
+ INTERVENTION_COST: float = 0.025 # interventions are more expensive to discourage spam
26
+ INVALID_ACTION_PENALTY: float = 0.15 # strong deterrent for malformed actions
27
+ TIMEOUT_PENALTY: float = 0.25 # harsh enough to encourage timely finish_review
28
+ SEVERE_DDI_DISCOVERY_BONUS: float = 0.05 # rewarding high-value information discovery
29
+ MODERATE_DDI_DISCOVERY_BONUS: float = 0.02 # smaller bonus for moderate findings
30
 
31
  # ── Task parameters ─────────────────────────────────────────────────────────
32
 
backend/src/polypharmacy_env/env_core.py CHANGED
@@ -214,6 +214,7 @@ class PolypharmacyEnv(
214
  self._current_risk, self._current_risk,
215
  "query_ddi",
216
  discovered_severe=(result.severity == "severe"),
 
217
  )
218
  info["ddi_result"] = {
219
  "severity": result.severity,
 
214
  self._current_risk, self._current_risk,
215
  "query_ddi",
216
  discovered_severe=(result.severity == "severe"),
217
+ discovered_moderate=(result.severity == "moderate"),
218
  )
219
  info["ddi_result"] = {
220
  "severity": result.severity,
backend/src/polypharmacy_env/neural_bandits.py ADDED
@@ -0,0 +1,484 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Neural Thompson Sampling (NeuralTS) with Differential Evolution (DE).
2
+
3
+ Implements the OptimNeuralTS algorithm from:
4
+ Larouche et al., "Neural Bandits for Data Mining: Searching for Dangerous Polypharmacy"
5
+ https://link.springer.com/chapter/10.1007/978-3-031-36938-4_5
6
+
7
+ Key components:
8
+ - NeuralTS: Neural network with gradient-based uncertainty for Thompson Sampling
9
+ - DE (best/1/bin): Differential Evolution to generate candidate drug combinations
10
+ - OptimNeuralTS: Full pipeline combining warm-up, NeuralTS, DE, and ensemble predictions
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import math
16
+ import random
17
+ from copy import deepcopy
18
+ from typing import Any, Dict, List, Optional, Tuple
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+
24
+
25
+ # ---------------------------------------------------------------------------
26
+ # Reward predictor network (predicts Relative Risk for a drug combination)
27
+ # ---------------------------------------------------------------------------
28
+
29
+
30
+ class RewardNetwork(nn.Module):
31
+ """Neural network f(x; theta) that predicts association measure (RR)
32
+ for a multi-hot drug combination vector."""
33
+
34
+ def __init__(self, input_dim: int, hidden: int = 64) -> None:
35
+ super().__init__()
36
+ self.fc1 = nn.Linear(input_dim, hidden)
37
+ self.fc2 = nn.Linear(hidden, hidden)
38
+ self.fc3 = nn.Linear(hidden, 1)
39
+ self._input_dim = input_dim
40
+
41
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
42
+ h = F.relu(self.fc1(x))
43
+ h = F.relu(self.fc2(h))
44
+ return self.fc3(h).squeeze(-1)
45
+
46
+
47
+ # ---------------------------------------------------------------------------
48
+ # NeuralTS: gradient-based uncertainty estimation
49
+ # ---------------------------------------------------------------------------
50
+
51
+
52
+ class NeuralTS:
53
+ """Neural Thompson Sampling agent.
54
+
55
+ Uses the neural network gradient to estimate a posterior distribution
56
+ over the predicted reward, enabling exploration via Thompson Sampling.
57
+
58
+ At each step t, for an action with features x:
59
+ f_t(x) = network prediction (mean)
60
+ s_t(x) = sqrt(lambda * g(x)^T U_t^{-1} g(x)) (std)
61
+ where g(x) is the gradient of the network output w.r.t. parameters,
62
+ and U_t is the diagonal design matrix accumulated over past actions.
63
+ """
64
+
65
+ def __init__(
66
+ self,
67
+ input_dim: int,
68
+ hidden: int = 64,
69
+ reg_lambda: float = 1.0,
70
+ exploration_factor: float = 1.0,
71
+ lr: float = 0.01,
72
+ train_epochs: int = 100,
73
+ ) -> None:
74
+ self.input_dim = input_dim
75
+ self.reg_lambda = reg_lambda
76
+ self.nu = exploration_factor
77
+ self.lr = lr
78
+ self.train_epochs = train_epochs
79
+
80
+ self.network = RewardNetwork(input_dim, hidden)
81
+ self.n_params = sum(p.numel() for p in self.network.parameters())
82
+
83
+ # Diagonal approximation of the design matrix U
84
+ self.U_diag = torch.ones(self.n_params) * reg_lambda
85
+
86
+ # Training dataset: (context, reward) pairs
87
+ self.contexts: List[torch.Tensor] = []
88
+ self.rewards: List[float] = []
89
+
90
+ # Ensemble: store snapshots of model weights at each training step
91
+ self.ensemble_weights: List[Dict[str, torch.Tensor]] = []
92
+
93
+ def _get_gradient(self, x: torch.Tensor) -> torch.Tensor:
94
+ """Compute gradient g(x; theta) of network output w.r.t. parameters."""
95
+ self.network.zero_grad()
96
+ pred = self.network(x.unsqueeze(0) if x.dim() == 1 else x)
97
+ if pred.dim() > 0:
98
+ pred = pred.sum()
99
+ pred.backward()
100
+ grads = []
101
+ for p in self.network.parameters():
102
+ if p.grad is not None:
103
+ grads.append(p.grad.detach().flatten())
104
+ else:
105
+ grads.append(torch.zeros(p.numel()))
106
+ return torch.cat(grads)
107
+
108
+ def predict(self, x: torch.Tensor) -> Tuple[float, float]:
109
+ """Return (mean, std) for the predicted reward at features x."""
110
+ with torch.no_grad():
111
+ mean = self.network(x.unsqueeze(0) if x.dim() == 1 else x).item()
112
+
113
+ g = self._get_gradient(x)
114
+ # s_t(x) = sqrt(lambda * g^T U^{-1} g) (diagonal approx)
115
+ var = self.reg_lambda * (g ** 2 / self.U_diag).sum().item()
116
+ std = math.sqrt(max(var, 1e-10))
117
+ return mean, std
118
+
119
+ def sample_value(self, x: torch.Tensor) -> float:
120
+ """Sample a value from the Thompson Sampling posterior N(f_t, nu * s_t)."""
121
+ mean, std = self.predict(x)
122
+ return random.gauss(mean, self.nu * std)
123
+
124
+ def update_design_matrix(self, x: torch.Tensor) -> None:
125
+ """Update U_t with the gradient at x (U_t += g(x) * g(x)^T diagonal)."""
126
+ g = self._get_gradient(x)
127
+ self.U_diag += g ** 2
128
+
129
+ def add_observation(self, x: torch.Tensor, reward: float) -> None:
130
+ """Add (context, reward) to training dataset."""
131
+ self.contexts.append(x.detach().clone())
132
+ self.rewards.append(reward)
133
+
134
+ def train_network(self) -> float:
135
+ """Train the network on accumulated data. Returns final loss."""
136
+ if not self.contexts:
137
+ return 0.0
138
+
139
+ X = torch.stack(self.contexts)
140
+ y = torch.tensor(self.rewards, dtype=torch.float32)
141
+
142
+ optimizer = torch.optim.Adam(self.network.parameters(), lr=self.lr)
143
+ scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
144
+ optimizer, patience=10, factor=0.5
145
+ )
146
+
147
+ best_loss = float("inf")
148
+ best_state = deepcopy(self.network.state_dict())
149
+
150
+ for epoch in range(self.train_epochs):
151
+ optimizer.zero_grad()
152
+ preds = self.network(X)
153
+ loss = F.mse_loss(preds, y)
154
+
155
+ # L2 regularization (as in original NeuralTS)
156
+ l2_reg = sum(p.pow(2).sum() for p in self.network.parameters())
157
+ total_loss = loss + self.reg_lambda * 1e-4 * l2_reg
158
+
159
+ total_loss.backward()
160
+ nn.utils.clip_grad_norm_(self.network.parameters(), max_norm=1.0)
161
+ optimizer.step()
162
+ scheduler.step(loss.item())
163
+
164
+ if loss.item() < best_loss:
165
+ best_loss = loss.item()
166
+ best_state = deepcopy(self.network.state_dict())
167
+
168
+ # Restore best weights (maximizes likelihood)
169
+ self.network.load_state_dict(best_state)
170
+
171
+ # Save snapshot for ensemble
172
+ self.ensemble_weights.append(deepcopy(best_state))
173
+
174
+ return best_loss
175
+
176
+ def ensemble_predict(self, x: torch.Tensor) -> Tuple[float, float, bool]:
177
+ """Predict using ensemble of all intermediate models.
178
+
179
+ Returns (mean_pred, lower_bound, is_pip) where:
180
+ - mean_pred: average prediction across ensemble
181
+ - lower_bound: pessimistic estimate (mean - 3*std)
182
+ - is_pip: True if lower_bound > threshold (1.1)
183
+ """
184
+ if not self.ensemble_weights:
185
+ mean, std = self.predict(x)
186
+ lb = mean - 3 * std
187
+ return mean, lb, lb > 1.1
188
+
189
+ preds = []
190
+ original_state = deepcopy(self.network.state_dict())
191
+
192
+ for state_dict in self.ensemble_weights:
193
+ self.network.load_state_dict(state_dict)
194
+ with torch.no_grad():
195
+ p = self.network(x.unsqueeze(0) if x.dim() == 1 else x).item()
196
+ preds.append(p)
197
+
198
+ # Restore current weights
199
+ self.network.load_state_dict(original_state)
200
+
201
+ mean_pred = sum(preds) / len(preds)
202
+ # Use ensemble variance for uncertainty
203
+ if len(preds) > 1:
204
+ var = sum((p - mean_pred) ** 2 for p in preds) / (len(preds) - 1)
205
+ std = math.sqrt(var)
206
+ else:
207
+ _, std = self.predict(x)
208
+
209
+ lower_bound = mean_pred - 3 * std
210
+ is_pip = lower_bound > 1.1
211
+
212
+ return mean_pred, lower_bound, is_pip
213
+
214
+
215
+ # ---------------------------------------------------------------------------
216
+ # Differential Evolution (DE best/1/bin)
217
+ # ---------------------------------------------------------------------------
218
+
219
+
220
+ def differential_evolution(
221
+ objective_fn,
222
+ dim: int,
223
+ population_size: int = 32,
224
+ crossover_rate: float = 0.9,
225
+ differential_weight: float = 1.0,
226
+ n_steps: int = 16,
227
+ ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
228
+ """DE best/1/bin optimization on a multi-hot feature space.
229
+
230
+ Generates candidate drug combinations by evolving a population and
231
+ evaluating them with the objective function (sampled from NeuralTS).
232
+
233
+ Args:
234
+ objective_fn: Maps a feature vector -> scalar value (e.g. Thompson sample)
235
+ dim: Dimensionality of feature vectors (number of possible drugs)
236
+ population_size: N — number of members in population
237
+ crossover_rate: C — probability of component crossover
238
+ differential_weight: F — scaling factor for mutation
239
+ n_steps: S — number of evolution steps
240
+
241
+ Returns:
242
+ best_member: The feature vector maximizing the objective
243
+ all_members: All members evaluated during DE (for action set A_t)
244
+ """
245
+ # Initialize population: random multi-hot vectors (each drug has ~20% chance)
246
+ population = []
247
+ for _ in range(population_size):
248
+ member = (torch.rand(dim) > 0.8).float()
249
+ # Ensure at least 2 drugs are present
250
+ if member.sum() < 2:
251
+ indices = random.sample(range(dim), 2)
252
+ member[indices[0]] = 1.0
253
+ member[indices[1]] = 1.0
254
+ population.append(member)
255
+
256
+ all_evaluated = list(population)
257
+
258
+ for step in range(n_steps):
259
+ # Find best member
260
+ scores = [objective_fn(m) for m in population]
261
+ best_idx = max(range(len(scores)), key=lambda i: scores[i])
262
+ best = population[best_idx]
263
+
264
+ new_population = []
265
+ for i, w_i in enumerate(population):
266
+ # Random indices (not i)
267
+ candidates = [j for j in range(population_size) if j != i]
268
+ r1, r2 = random.sample(candidates, 2)
269
+
270
+ # Mutation: m_i = best + F * (w_r1 - w_r2)
271
+ m_i = best + differential_weight * (population[r1] - population[r2])
272
+
273
+ # Crossover: binomial
274
+ l = random.randint(0, dim - 1) # guaranteed crossover index
275
+ u_i = w_i.clone()
276
+ for j in range(dim):
277
+ if j == l or random.random() <= crossover_rate:
278
+ u_i[j] = m_i[j]
279
+
280
+ # Clamp to [0, 1] and round to get multi-hot
281
+ u_i = torch.clamp(u_i, 0.0, 1.0)
282
+ u_i = (u_i > 0.5).float()
283
+
284
+ # Ensure minimum drugs
285
+ if u_i.sum() < 2:
286
+ indices = random.sample(range(dim), 2)
287
+ u_i[indices[0]] = 1.0
288
+ u_i[indices[1]] = 1.0
289
+
290
+ # Selection: keep mutant if better
291
+ if objective_fn(u_i) >= objective_fn(w_i):
292
+ new_population.append(u_i)
293
+ else:
294
+ new_population.append(w_i)
295
+
296
+ all_evaluated.append(u_i)
297
+
298
+ population = new_population
299
+
300
+ # Return the best from final population
301
+ final_scores = [objective_fn(m) for m in population]
302
+ best_idx = max(range(len(final_scores)), key=lambda i: final_scores[i])
303
+ return population[best_idx], all_evaluated
304
+
305
+
306
+ # ---------------------------------------------------------------------------
307
+ # Nearest-neighbor mapping (Hamming distance)
308
+ # ---------------------------------------------------------------------------
309
+
310
+
311
+ def nearest_neighbor_hamming(
312
+ candidate: torch.Tensor,
313
+ dataset: List[torch.Tensor],
314
+ ) -> int:
315
+ """Find the index of the nearest neighbor in dataset using Hamming distance."""
316
+ best_dist = float("inf")
317
+ best_idx = 0
318
+ candidate_binary = (candidate > 0.5).float()
319
+ for i, item in enumerate(dataset):
320
+ item_binary = (item > 0.5).float()
321
+ dist = (candidate_binary != item_binary).float().sum().item()
322
+ if dist < best_dist:
323
+ best_dist = dist
324
+ best_idx = i
325
+ return best_idx
326
+
327
+
328
+ # ---------------------------------------------------------------------------
329
+ # OptimNeuralTS: full pipeline
330
+ # ---------------------------------------------------------------------------
331
+
332
+
333
+ class OptimNeuralTS:
334
+ """Complete OptimNeuralTS training pipeline.
335
+
336
+ Combines NeuralTS with Differential Evolution to efficiently search
337
+ for potentially inappropriate polypharmacies (PIPs) in a large
338
+ combinatorial space of drug combinations.
339
+
340
+ The algorithm:
341
+ 1. Warm-up: Randomly sample actions for tau steps, collect rewards
342
+ 2. Train the neural network on warm-up data
343
+ 3. For each subsequent step:
344
+ a. Use DE to find the best candidate action (guided by NeuralTS posterior)
345
+ b. Map candidate to the nearest real drug combination (Hamming distance)
346
+ c. Observe reward (Relative Risk), add to training data
347
+ d. Retrain network periodically
348
+ 4. Return ensemble of all intermediate models for prediction
349
+ """
350
+
351
+ def __init__(
352
+ self,
353
+ input_dim: int,
354
+ hidden: int = 64,
355
+ reg_lambda: float = 1.0,
356
+ exploration_factor: float = 1.0,
357
+ lr: float = 0.01,
358
+ train_epochs: int = 100,
359
+ warmup_steps: int = 100,
360
+ total_steps: int = 1000,
361
+ retrain_every: int = 10,
362
+ de_population: int = 32,
363
+ de_crossover: float = 0.9,
364
+ de_weight: float = 1.0,
365
+ de_steps: int = 16,
366
+ ) -> None:
367
+ self.agent = NeuralTS(
368
+ input_dim=input_dim,
369
+ hidden=hidden,
370
+ reg_lambda=reg_lambda,
371
+ exploration_factor=exploration_factor,
372
+ lr=lr,
373
+ train_epochs=train_epochs,
374
+ )
375
+ self.warmup_steps = warmup_steps
376
+ self.total_steps = total_steps
377
+ self.retrain_every = retrain_every
378
+ self.de_population = de_population
379
+ self.de_crossover = de_crossover
380
+ self.de_weight = de_weight
381
+ self.de_steps = de_steps
382
+ self.input_dim = input_dim
383
+
384
+ self.step_count = 0
385
+ self.training_log: List[Dict[str, Any]] = []
386
+
387
+ def select_action(
388
+ self,
389
+ available_actions: List[torch.Tensor],
390
+ ) -> Tuple[int, Dict[str, Any]]:
391
+ """Select an action from available_actions.
392
+
393
+ During warm-up: random selection.
394
+ After warm-up: DE + NeuralTS Thompson Sampling.
395
+
396
+ Returns: (index into available_actions, info dict)
397
+ """
398
+ info: Dict[str, Any] = {"phase": "warmup" if self.step_count < self.warmup_steps else "bandit"}
399
+
400
+ if self.step_count < self.warmup_steps:
401
+ # Warm-up: random
402
+ idx = random.randint(0, len(available_actions) - 1)
403
+ info["selection"] = "random"
404
+ return idx, info
405
+
406
+ # After warm-up: use DE + NeuralTS
407
+ def ts_objective(x: torch.Tensor) -> float:
408
+ return self.agent.sample_value(x)
409
+
410
+ # Run DE to find best candidate
411
+ best_candidate, _ = differential_evolution(
412
+ objective_fn=ts_objective,
413
+ dim=self.input_dim,
414
+ population_size=self.de_population,
415
+ crossover_rate=self.de_crossover,
416
+ differential_weight=self.de_weight,
417
+ n_steps=self.de_steps,
418
+ )
419
+
420
+ # Update design matrix with DE's recommended action
421
+ self.agent.update_design_matrix(best_candidate)
422
+
423
+ # Map to nearest real action (Hamming distance)
424
+ idx = nearest_neighbor_hamming(best_candidate, available_actions)
425
+ info["selection"] = "de_neuralts"
426
+
427
+ mean, std = self.agent.predict(available_actions[idx])
428
+ info["predicted_rr"] = mean
429
+ info["uncertainty"] = std
430
+
431
+ return idx, info
432
+
433
+ def observe(self, x: torch.Tensor, reward: float) -> Optional[float]:
434
+ """Record observation and retrain if needed.
435
+
436
+ Returns training loss if retrained, None otherwise.
437
+ """
438
+ self.agent.add_observation(x, reward)
439
+ self.step_count += 1
440
+
441
+ loss = None
442
+ # Retrain after warm-up, then every retrain_every steps
443
+ if self.step_count == self.warmup_steps:
444
+ loss = self.agent.train_network()
445
+ elif self.step_count > self.warmup_steps and self.step_count % self.retrain_every == 0:
446
+ loss = self.agent.train_network()
447
+
448
+ self.training_log.append({
449
+ "step": self.step_count,
450
+ "reward": reward,
451
+ "loss": loss,
452
+ "n_ensemble": len(self.agent.ensemble_weights),
453
+ })
454
+
455
+ return loss
456
+
457
+ def predict_risk(self, x: torch.Tensor) -> Dict[str, Any]:
458
+ """Use the ensemble to predict risk for a drug combination.
459
+
460
+ Returns dict with mean prediction, lower confidence bound,
461
+ and whether the combination is flagged as a PIP.
462
+ """
463
+ mean, lower_bound, is_pip = self.agent.ensemble_predict(x)
464
+ return {
465
+ "predicted_rr": round(mean, 4),
466
+ "lower_bound": round(lower_bound, 4),
467
+ "is_potentially_harmful": is_pip,
468
+ "n_models_in_ensemble": len(self.agent.ensemble_weights),
469
+ }
470
+
471
+ def get_metrics(self) -> Dict[str, Any]:
472
+ """Return training metrics summary."""
473
+ if not self.training_log:
474
+ return {"status": "no_data"}
475
+
476
+ rewards = [e["reward"] for e in self.training_log]
477
+ return {
478
+ "total_steps": self.step_count,
479
+ "warmup_steps": self.warmup_steps,
480
+ "n_ensemble_models": len(self.agent.ensemble_weights),
481
+ "avg_reward": sum(rewards) / len(rewards) if rewards else 0,
482
+ "max_reward": max(rewards) if rewards else 0,
483
+ "phase": "warmup" if self.step_count < self.warmup_steps else "bandit",
484
+ }
backend/src/polypharmacy_env/rewards.py CHANGED
@@ -8,6 +8,7 @@ from typing import Dict, List, Optional, Tuple
8
  from .config import (
9
  INTERVENTION_COST,
10
  INVALID_ACTION_PENALTY,
 
11
  QUERY_COST,
12
  SEVERE_DDI_DISCOVERY_BONUS,
13
  TIMEOUT_PENALTY,
@@ -39,8 +40,8 @@ def compute_regimen_risk(
39
  if rule is not None:
40
  risk += rule.base_risk_score
41
 
42
- # 2. Beers violations
43
- beers_weight = {"avoid": 0.25, "caution": 0.10, "dose_adjust": 0.08, "avoid_in_condition": 0.20}
44
  for bc in beers_criteria:
45
  if bc.drug_id not in drug_set:
46
  continue
@@ -68,6 +69,7 @@ def compute_shaped_reward(
68
  is_invalid: bool = False,
69
  is_timeout: bool = False,
70
  discovered_severe: bool = False,
 
71
  ) -> float:
72
  """Compute the step-level shaped reward."""
73
  reward = 0.0
@@ -82,6 +84,8 @@ def compute_shaped_reward(
82
  reward -= QUERY_COST
83
  if discovered_severe:
84
  reward += SEVERE_DDI_DISCOVERY_BONUS
 
 
85
 
86
  elif action_type == "propose_intervention":
87
  reward += (previous_risk - new_risk)
 
8
  from .config import (
9
  INTERVENTION_COST,
10
  INVALID_ACTION_PENALTY,
11
+ MODERATE_DDI_DISCOVERY_BONUS,
12
  QUERY_COST,
13
  SEVERE_DDI_DISCOVERY_BONUS,
14
  TIMEOUT_PENALTY,
 
40
  if rule is not None:
41
  risk += rule.base_risk_score
42
 
43
+ # 2. Beers violations (weights reflect clinical severity)
44
+ beers_weight = {"avoid": 0.30, "caution": 0.12, "dose_adjust": 0.10, "avoid_in_condition": 0.25}
45
  for bc in beers_criteria:
46
  if bc.drug_id not in drug_set:
47
  continue
 
69
  is_invalid: bool = False,
70
  is_timeout: bool = False,
71
  discovered_severe: bool = False,
72
+ discovered_moderate: bool = False,
73
  ) -> float:
74
  """Compute the step-level shaped reward."""
75
  reward = 0.0
 
84
  reward -= QUERY_COST
85
  if discovered_severe:
86
  reward += SEVERE_DDI_DISCOVERY_BONUS
87
+ elif discovered_moderate:
88
+ reward += MODERATE_DDI_DISCOVERY_BONUS
89
 
90
  elif action_type == "propose_intervention":
91
  reward += (previous_risk - new_risk)
frontend/src/App.jsx CHANGED
@@ -1,4 +1,4 @@
1
- import { useEffect, useMemo, useRef, useState } from "react";
2
 
3
  function resolveApiBase() {
4
  const explicitBase = import.meta.env.VITE_API_BASE;
@@ -8,7 +8,6 @@ function resolveApiBase() {
8
  const isLocal =
9
  host === "localhost" || host === "127.0.0.1" || host === "0.0.0.0";
10
 
11
- // In local Vite dev, backend runs on :7860. In Spaces/prod, serve same-origin.
12
  if (isLocal && window.location.port === "5173") {
13
  return "http://localhost:7860";
14
  }
@@ -17,7 +16,117 @@ function resolveApiBase() {
17
 
18
  const API_BASE = resolveApiBase();
19
  const WS_URL = `${API_BASE.replace(/^http/, "ws")}/ws`;
20
- const TASKS = ["easy_screening", "budgeted_screening", "complex_tradeoff"];
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  async function apiPost(path, body) {
23
  const res = await fetch(`${API_BASE}${path}`, {
@@ -32,11 +141,158 @@ async function apiPost(path, body) {
32
  return res.json();
33
  }
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  export default function App() {
36
  const [taskId, setTaskId] = useState("budgeted_screening");
37
  const [obs, setObs] = useState(null);
38
  const [log, setLog] = useState([]);
39
  const [loading, setLoading] = useState(false);
 
 
40
  const [action, setAction] = useState({
41
  action_type: "query_ddi",
42
  drug_id_1: "",
@@ -51,10 +307,13 @@ export default function App() {
51
  () => (obs?.current_medications || []).map((m) => m.drug_id),
52
  [obs]
53
  );
54
- const hasValidEpisode = Boolean(obs?.episode_id) && (obs?.current_medications?.length || 0) > 0;
 
55
  const isDone = Boolean(obs?.done);
56
  const finalScore =
57
- typeof obs?.metadata?.grader_score === "number" ? obs.metadata.grader_score : null;
 
 
58
  const noBudgetsLeft =
59
  hasValidEpisode &&
60
  (obs?.remaining_query_budget ?? 0) <= 0 &&
@@ -63,7 +322,8 @@ export default function App() {
63
  const pendingRef = useRef([]);
64
 
65
  const wsEnsure = async () => {
66
- if (wsRef.current && wsRef.current.readyState === WebSocket.OPEN) return wsRef.current;
 
67
  if (wsRef.current && wsRef.current.readyState === WebSocket.CONNECTING) {
68
  await new Promise((r) => setTimeout(r, 80));
69
  return wsEnsure();
@@ -91,7 +351,10 @@ export default function App() {
91
  };
92
 
93
  await new Promise((resolve, reject) => {
94
- const t = setTimeout(() => reject(new Error("WebSocket connect timeout")), 2500);
 
 
 
95
  ws.onopen = () => {
96
  clearTimeout(t);
97
  resolve();
@@ -113,13 +376,15 @@ export default function App() {
113
  try {
114
  wsRef.current?.close();
115
  } catch {
116
- // ignore
117
  }
118
  };
119
  }, []);
120
 
121
  const appendLog = (text) => {
122
- setLog((prev) => [`${new Date().toLocaleTimeString()} ${text}`, ...prev].slice(0, 20));
 
 
123
  };
124
 
125
  const normalizeObsFromWs = (packetData) => {
@@ -150,7 +415,7 @@ export default function App() {
150
  drug_id_2: ids[1] || "",
151
  target_drug_id: ids[0] || "",
152
  }));
153
- appendLog(`Reset task=${taskId}`);
154
  } catch (err) {
155
  appendLog(`Reset failed: ${err.message}`);
156
  } finally {
@@ -206,7 +471,9 @@ export default function App() {
206
  const data = msg?.data || {};
207
  const normalized = normalizeObsFromWs(data);
208
  setObs(normalized);
209
- appendLog(`Step: ${payload.action_type} -> reward=${data.reward ?? 0}`);
 
 
210
  } catch (err) {
211
  appendLog(`Step failed: ${err.message}`);
212
  } finally {
@@ -222,7 +489,9 @@ export default function App() {
222
  setLoading(true);
223
  try {
224
  const data = await apiPost("/agent/suggest", { observation: obs });
225
- appendLog(`AI suggestion: ${data.action.action_type}`);
 
 
226
  await handleStep(data.action);
227
  } catch (err) {
228
  appendLog(`AI suggestion failed: ${err.message}`);
@@ -231,156 +500,450 @@ export default function App() {
231
  }
232
  };
233
 
 
 
 
 
 
 
 
 
 
 
 
234
  return (
235
  <div className="shell">
236
  <div className="bg-orb orb-a" />
237
  <div className="bg-orb orb-b" />
238
 
 
 
 
 
 
 
 
 
 
 
 
239
  <div className="container">
240
- <header className="topbar glass">
241
- <div className="title-wrap">
242
- <h1>Polypharmacy Control Center</h1>
243
- <p>Metaverse Clinical Ops Console</p>
244
- </div>
245
- <div className={`status-chip ${hasValidEpisode ? "live" : "idle"}`}>
246
- {hasValidEpisode ? "Session Live" : "Waiting for reset"}
247
- </div>
248
- <div className="actions">
249
- <select value={taskId} onChange={(e) => setTaskId(e.target.value)}>
250
- {TASKS.map((t) => (
251
- <option key={t} value={t}>
252
- {t}
253
- </option>
254
- ))}
255
- </select>
256
- <button onClick={handleReset} disabled={loading}>
257
- Reset Episode
258
- </button>
259
- <button className="secondary" onClick={askAi} disabled={!hasValidEpisode || isDone || loading}>
260
- Ask AI + Auto Step
261
- </button>
262
- </div>
263
- </header>
264
-
265
- <main className="layout">
266
- <section className="panel glass panel-wide">
267
- <h2>Episode</h2>
268
- {hasValidEpisode ? (
269
- <div className="kpi-grid">
270
- <div><span>Episode</span><strong>{obs.episode_id}</strong></div>
271
- <div><span>Task</span><strong>{obs.task_id}</strong></div>
272
- <div><span>Age / Sex</span><strong>{obs.age} / {obs.sex}</strong></div>
273
- <div><span>Step</span><strong>{obs.step_index}</strong></div>
274
- <div><span>Query budget</span><strong>{obs.remaining_query_budget}</strong></div>
275
- <div><span>Intervention budget</span><strong>{obs.remaining_intervention_budget}</strong></div>
276
  </div>
277
- ) : (
278
- <p className="muted">Start with Reset Episode. Until then, step actions are blocked.</p>
279
- )}
280
- {noBudgetsLeft && (
281
- <p className="muted budget-note">Query and intervention budgets are exhausted. Finish review to get final score.</p>
282
- )}
283
- {isDone && (
284
- <p className="muted budget-note">
285
- Episode complete
286
- {finalScore !== null ? ` • final score: ${finalScore.toFixed(3)}` : ""}.
287
- Click Reset Episode to start a new case.
288
- </p>
289
- )}
290
- </section>
291
-
292
- <section className="panel glass">
293
- <h2>Action Console</h2>
294
- <div className="action-row">
295
- <label>Action type</label>
296
- <select
297
- value={action.action_type}
298
- onChange={(e) => setAction((a) => ({ ...a, action_type: e.target.value }))}
299
  >
300
- <option value="query_ddi">query_ddi</option>
301
- <option value="propose_intervention">propose_intervention</option>
302
- <option value="finish_review">finish_review</option>
 
 
 
 
 
 
 
303
  </select>
 
 
 
 
 
 
 
 
 
 
304
  </div>
 
305
 
306
- {action.action_type === "query_ddi" && (
307
- <div className="stack stack-two">
308
- <input
309
- placeholder="drug_id_1"
310
- value={action.drug_id_1}
311
- onChange={(e) => setAction((a) => ({ ...a, drug_id_1: e.target.value }))}
312
- />
313
- <input
314
- placeholder="drug_id_2"
315
- value={action.drug_id_2}
316
- onChange={(e) => setAction((a) => ({ ...a, drug_id_2: e.target.value }))}
317
- />
318
- </div>
319
- )}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
320
 
321
- {action.action_type === "propose_intervention" && (
322
- <div className="stack">
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
323
  <select
324
- value={action.target_drug_id}
325
- onChange={(e) => setAction((a) => ({ ...a, target_drug_id: e.target.value }))}
 
 
326
  >
327
- <option value="">Select target drug</option>
328
- {medIds.map((id) => (
329
- <option key={id} value={id}>
330
- {id}
331
  </option>
332
  ))}
333
  </select>
334
- <select
335
- value={action.intervention_type}
336
- onChange={(e) => setAction((a) => ({ ...a, intervention_type: e.target.value }))}
337
- >
338
- <option value="stop">stop</option>
339
- <option value="dose_reduce">dose_reduce</option>
340
- <option value="substitute">substitute</option>
341
- <option value="add_monitoring">add_monitoring</option>
342
- </select>
343
- <input
344
- placeholder="proposed_new_drug_id (optional)"
345
- value={action.proposed_new_drug_id}
346
- onChange={(e) =>
347
- setAction((a) => ({ ...a, proposed_new_drug_id: e.target.value }))
348
- }
349
- />
350
- <input
351
- placeholder="rationale (optional)"
352
- value={action.rationale}
353
- onChange={(e) => setAction((a) => ({ ...a, rationale: e.target.value }))}
354
- />
355
  </div>
356
- )}
357
- <button onClick={() => handleStep()} disabled={!isActionValid() || loading}>
358
- {noBudgetsLeft ? "Finish Review" : "Submit Step"}
359
- </button>
360
- </section>
361
-
362
- <section className="panel glass">
363
- <h2>Current Medications</h2>
364
- <div className="med-grid">
365
- {(obs?.current_medications || []).map((m) => (
366
- <div key={m.drug_id} className="med-card">
367
- <strong>{m.drug_id}</strong>
368
- <p>{m.generic_name}</p>
369
- <small>{m.dose_mg} mg • {m.atc_class}</small>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
370
  </div>
371
- ))}
372
- </div>
373
- </section>
374
-
375
- <section className="panel glass">
376
- <h2>Event Log</h2>
377
- <div className="logs">
378
- {log.map((line, idx) => (
379
- <div key={idx}>{line}</div>
380
- ))}
381
- </div>
382
- </section>
383
- </main>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
384
  </div>
385
  </div>
386
  );
 
1
+ import { useEffect, useMemo, useRef, useState, useCallback } from "react";
2
 
3
  function resolveApiBase() {
4
  const explicitBase = import.meta.env.VITE_API_BASE;
 
8
  const isLocal =
9
  host === "localhost" || host === "127.0.0.1" || host === "0.0.0.0";
10
 
 
11
  if (isLocal && window.location.port === "5173") {
12
  return "http://localhost:7860";
13
  }
 
16
 
17
  const API_BASE = resolveApiBase();
18
  const WS_URL = `${API_BASE.replace(/^http/, "ws")}/ws`;
19
+
20
+ const TASKS = [
21
+ { id: "easy_screening", label: "Easy Screening" },
22
+ { id: "budgeted_screening", label: "Budgeted Screening" },
23
+ { id: "complex_tradeoff", label: "Complex Tradeoff" },
24
+ ];
25
+
26
+ const TASK_LABEL_MAP = Object.fromEntries(TASKS.map((t) => [t.id, t.label]));
27
+
28
+ const ACTION_LABELS = {
29
+ query_ddi: "Check Drug Interaction",
30
+ propose_intervention: "Propose Change",
31
+ finish_review: "Finish Review",
32
+ };
33
+
34
+ const INTERVENTION_LABELS = {
35
+ stop: "Stop Medication",
36
+ dose_reduce: "Reduce Dose",
37
+ substitute: "Substitute with Safer Drug",
38
+ add_monitoring: "Add Monitoring",
39
+ };
40
+
41
+ // ── Contextual guide steps: each targets a specific UI section ──────────────
42
+ const GUIDE_STEPS = [
43
+ {
44
+ target: "topbar",
45
+ position: "below",
46
+ title: "Welcome to PolypharmacyEnv",
47
+ body: `This tool helps review elderly patients' medication regimens for safety.
48
+
49
+ You'll act as a pharmacist assistant: check pairs of drugs for harmful interactions, propose changes to reduce risk, and get scored on how well you protect the patient — all under limited budgets.
50
+
51
+ Behind the scenes, an AI agent (Neural Bandit) learns which drug combinations to investigate first, getting smarter with each review.`,
52
+ },
53
+ {
54
+ target: "task-selector",
55
+ position: "below",
56
+ title: "Choose a Scenario",
57
+ body: `Pick a difficulty level:
58
+
59
+ • Easy Screening — 3–5 drugs, 1 known dangerous interaction. Great for getting started.
60
+ • Budgeted Screening — 6–10 drugs, multiple problems to find, tighter budgets.
61
+ • Complex Tradeoff — 10–15 drugs including critical ones (blood thinners, insulin). Removing critical drugs without a replacement is penalized.
62
+
63
+ Click "Reset Episode" to load a new patient case.`,
64
+ },
65
+ {
66
+ target: "episode-panel",
67
+ position: "below",
68
+ title: "Patient Overview",
69
+ body: `After resetting, this panel shows the patient's details:
70
+
71
+ • Demographics (age, sex, medical conditions)
72
+ • Your remaining query and intervention budgets
73
+ • A risk bar comparing starting risk vs. current risk
74
+ • How many review steps you've taken
75
+
76
+ Each check and intervention uses up budget — use them wisely to get the best outcome.`,
77
+ },
78
+ {
79
+ target: "action-console",
80
+ position: "right",
81
+ title: "Check Drug Interactions",
82
+ body: `Select "Check Drug Interaction" and pick two drugs from the patient's list:
83
+
84
+ Example dangerous combinations:
85
+ • Warfarin + Naproxen → severe bleeding risk
86
+ • Diazepam + Tramadol → dangerous sedation
87
+ • Apixaban + Naproxen → severe bleeding risk
88
+
89
+ Each check costs a small amount of budget. Finding a serious interaction earns a bonus. A smart strategy checks high-risk pairs first.`,
90
+ },
91
+ {
92
+ target: "action-console",
93
+ position: "right",
94
+ title: "Propose Changes",
95
+ body: `After finding a dangerous interaction, switch to "Propose Change":
96
+
97
+ • Stop Medication — Remove the drug entirely
98
+ • Reduce Dose — Lower the dose to reduce risk
99
+ • Substitute Drug — Automatically finds a safer alternative in the same drug class
100
+ • Add Monitoring — Flag for closer clinical monitoring
101
+
102
+ Example: After finding warfarin + naproxen interaction, select Naproxen → "Substitute". The system finds a safer pain reliever.`,
103
+ },
104
+ {
105
+ target: "medications-panel",
106
+ position: "left",
107
+ title: "Current Medications",
108
+ body: `This grid shows the patient's active medications. Each card shows:
109
+
110
+ • Drug name and dose
111
+ • Drug class (e.g., pain reliever, blood thinner)
112
+ • "High Risk" badge for drugs that need extra caution in elderly patients
113
+ • Safety flags (avoid, caution, adjust dose)
114
+
115
+ Cards marked "avoid" or "High Risk" are prime candidates for a closer look. The list updates live as you make changes.`,
116
+ },
117
+ {
118
+ target: "event-log",
119
+ position: "above",
120
+ title: "Activity Log & Score",
121
+ body: `The log tracks every action you take and its impact. When you click "Finish Review", you get a final score (0–100%):
122
+
123
+ • Easy: Based on risk reduction + targeting the right dangerous drugs
124
+ • Medium: Risk reduction + precision of your interventions + how well you used your budget
125
+ • Hard: Risk reduction minus penalties for disrupting the patient's treatment plan
126
+
127
+ The "Ask AI" button lets an AI agent make decisions using the same tools you have.`,
128
+ },
129
+ ];
130
 
131
  async function apiPost(path, body) {
132
  const res = await fetch(`${API_BASE}${path}`, {
 
141
  return res.json();
142
  }
143
 
144
+ // ── Spotlight Guide Component ───────────────────────────────────────────────
145
+ function SpotlightGuide({ step, steps, onNext, onPrev, onClose }) {
146
+ const [rect, setRect] = useState(null);
147
+ const tooltipRef = useRef(null);
148
+
149
+ const updateRect = useCallback(() => {
150
+ const target = steps[step]?.target;
151
+ if (!target) return;
152
+ const el = document.querySelector(`[data-guide="${target}"]`);
153
+ if (el) {
154
+ const r = el.getBoundingClientRect();
155
+ setRect({ top: r.top, left: r.left, width: r.width, height: r.height });
156
+ // scroll into view
157
+ el.scrollIntoView({ behavior: "smooth", block: "nearest" });
158
+ }
159
+ }, [step, steps]);
160
+
161
+ useEffect(() => {
162
+ updateRect();
163
+ window.addEventListener("resize", updateRect);
164
+ window.addEventListener("scroll", updateRect, true);
165
+ return () => {
166
+ window.removeEventListener("resize", updateRect);
167
+ window.removeEventListener("scroll", updateRect, true);
168
+ };
169
+ }, [updateRect]);
170
+
171
+ if (!rect) return null;
172
+
173
+ const pad = 8;
174
+ const current = steps[step];
175
+
176
+ // Calculate tooltip position
177
+ const getTooltipStyle = () => {
178
+ const pos = current.position || "below";
179
+ const base = {};
180
+ if (pos === "below") {
181
+ base.top = rect.top + rect.height + pad + 12;
182
+ base.left = rect.left;
183
+ base.maxWidth = Math.min(440, window.innerWidth - 40);
184
+ } else if (pos === "above") {
185
+ base.bottom = window.innerHeight - rect.top + pad + 12;
186
+ base.left = rect.left;
187
+ base.maxWidth = Math.min(440, window.innerWidth - 40);
188
+ } else if (pos === "right") {
189
+ base.top = rect.top;
190
+ base.left = rect.left + rect.width + pad + 12;
191
+ base.maxWidth = Math.min(380, window.innerWidth - rect.left - rect.width - 40);
192
+ } else if (pos === "left") {
193
+ base.top = rect.top;
194
+ base.right = window.innerWidth - rect.left + pad + 12;
195
+ base.maxWidth = Math.min(380, rect.left - 40);
196
+ }
197
+ return base;
198
+ };
199
+
200
+ return (
201
+ <div className="spotlight-overlay">
202
+ {/* Dark overlay with cutout */}
203
+ <svg className="spotlight-svg" width="100%" height="100%">
204
+ <defs>
205
+ <mask id="spotlight-mask">
206
+ <rect width="100%" height="100%" fill="white" />
207
+ <rect
208
+ x={rect.left - pad}
209
+ y={rect.top - pad}
210
+ width={rect.width + pad * 2}
211
+ height={rect.height + pad * 2}
212
+ rx="12"
213
+ fill="black"
214
+ />
215
+ </mask>
216
+ </defs>
217
+ <rect
218
+ width="100%"
219
+ height="100%"
220
+ fill="rgba(4, 6, 15, 0.75)"
221
+ mask="url(#spotlight-mask)"
222
+ />
223
+ </svg>
224
+
225
+ {/* Highlight border around target */}
226
+ <div
227
+ className="spotlight-ring"
228
+ style={{
229
+ top: rect.top - pad,
230
+ left: rect.left - pad,
231
+ width: rect.width + pad * 2,
232
+ height: rect.height + pad * 2,
233
+ }}
234
+ />
235
+
236
+ {/* Tooltip */}
237
+ <div
238
+ ref={tooltipRef}
239
+ className="spotlight-tooltip glass"
240
+ style={getTooltipStyle()}
241
+ >
242
+ <div className="spotlight-tooltip-header">
243
+ <h3>{current.title}</h3>
244
+ <span className="guide-counter">
245
+ {step + 1} / {steps.length}
246
+ </span>
247
+ </div>
248
+ <div className="spotlight-tooltip-body">
249
+ {current.body.split("\n").map((line, i) => (
250
+ <p key={i}>{line}</p>
251
+ ))}
252
+ </div>
253
+ <div className="spotlight-tooltip-footer">
254
+ <button
255
+ className="guide-btn secondary"
256
+ onClick={onPrev}
257
+ disabled={step === 0}
258
+ >
259
+ Previous
260
+ </button>
261
+ <button className="guide-btn secondary" onClick={onClose}>
262
+ Skip
263
+ </button>
264
+ {step < steps.length - 1 ? (
265
+ <button className="guide-btn" onClick={onNext}>
266
+ Next
267
+ </button>
268
+ ) : (
269
+ <button className="guide-btn" onClick={onClose}>
270
+ Done
271
+ </button>
272
+ )}
273
+ </div>
274
+ <div className="guide-dots">
275
+ {steps.map((_, i) => (
276
+ <span
277
+ key={i}
278
+ className={`dot ${i === step ? "active" : ""}`}
279
+ />
280
+ ))}
281
+ </div>
282
+ </div>
283
+ </div>
284
+ );
285
+ }
286
+
287
+ // ── Main App ────────────────────────────────────────────────────────────────
288
+
289
  export default function App() {
290
  const [taskId, setTaskId] = useState("budgeted_screening");
291
  const [obs, setObs] = useState(null);
292
  const [log, setLog] = useState([]);
293
  const [loading, setLoading] = useState(false);
294
+ const [guideStep, setGuideStep] = useState(0);
295
+ const [showGuide, setShowGuide] = useState(true);
296
  const [action, setAction] = useState({
297
  action_type: "query_ddi",
298
  drug_id_1: "",
 
307
  () => (obs?.current_medications || []).map((m) => m.drug_id),
308
  [obs]
309
  );
310
+ const hasValidEpisode =
311
+ Boolean(obs?.episode_id) && (obs?.current_medications?.length || 0) > 0;
312
  const isDone = Boolean(obs?.done);
313
  const finalScore =
314
+ typeof obs?.metadata?.grader_score === "number"
315
+ ? obs.metadata.grader_score
316
+ : null;
317
  const noBudgetsLeft =
318
  hasValidEpisode &&
319
  (obs?.remaining_query_budget ?? 0) <= 0 &&
 
322
  const pendingRef = useRef([]);
323
 
324
  const wsEnsure = async () => {
325
+ if (wsRef.current && wsRef.current.readyState === WebSocket.OPEN)
326
+ return wsRef.current;
327
  if (wsRef.current && wsRef.current.readyState === WebSocket.CONNECTING) {
328
  await new Promise((r) => setTimeout(r, 80));
329
  return wsEnsure();
 
351
  };
352
 
353
  await new Promise((resolve, reject) => {
354
+ const t = setTimeout(
355
+ () => reject(new Error("WebSocket connect timeout")),
356
+ 2500
357
+ );
358
  ws.onopen = () => {
359
  clearTimeout(t);
360
  resolve();
 
376
  try {
377
  wsRef.current?.close();
378
  } catch {
379
+ /* ignore */
380
  }
381
  };
382
  }, []);
383
 
384
  const appendLog = (text) => {
385
+ setLog((prev) =>
386
+ [`${new Date().toLocaleTimeString()} ${text}`, ...prev].slice(0, 30)
387
+ );
388
  };
389
 
390
  const normalizeObsFromWs = (packetData) => {
 
415
  drug_id_2: ids[1] || "",
416
  target_drug_id: ids[0] || "",
417
  }));
418
+ appendLog(`Reset ${TASK_LABEL_MAP[taskId] || taskId}`);
419
  } catch (err) {
420
  appendLog(`Reset failed: ${err.message}`);
421
  } finally {
 
471
  const data = msg?.data || {};
472
  const normalized = normalizeObsFromWs(data);
473
  setObs(normalized);
474
+ const label = ACTION_LABELS[payload.action_type] || payload.action_type;
475
+ const rwd = data.reward ?? 0;
476
+ appendLog(`${label} → reward: ${Number(rwd).toFixed(3)}`);
477
  } catch (err) {
478
  appendLog(`Step failed: ${err.message}`);
479
  } finally {
 
489
  setLoading(true);
490
  try {
491
  const data = await apiPost("/agent/suggest", { observation: obs });
492
+ const label =
493
+ ACTION_LABELS[data.action.action_type] || data.action.action_type;
494
+ appendLog(`AI suggests: ${label}`);
495
  await handleStep(data.action);
496
  } catch (err) {
497
  appendLog(`AI suggestion failed: ${err.message}`);
 
500
  }
501
  };
502
 
503
+ const formatDrugName = (drugId) => {
504
+ if (!drugId) return "";
505
+ return drugId
506
+ .replace(/^DRUG_/, "")
507
+ .replace(/_/g, " ")
508
+ .replace(/\b\w/g, (c) => c.toUpperCase());
509
+ };
510
+
511
+ const currentRisk = obs?.metadata?.current_risk;
512
+ const baselineRisk = obs?.metadata?.baseline_risk;
513
+
514
  return (
515
  <div className="shell">
516
  <div className="bg-orb orb-a" />
517
  <div className="bg-orb orb-b" />
518
 
519
+ {/* Spotlight Guide */}
520
+ {showGuide && (
521
+ <SpotlightGuide
522
+ step={guideStep}
523
+ steps={GUIDE_STEPS}
524
+ onNext={() => setGuideStep((s) => Math.min(s + 1, GUIDE_STEPS.length - 1))}
525
+ onPrev={() => setGuideStep((s) => Math.max(0, s - 1))}
526
+ onClose={() => setShowGuide(false)}
527
+ />
528
+ )}
529
+
530
  <div className="container">
531
+ <header className="topbar glass" data-guide="topbar">
532
+ <div className="title-wrap">
533
+ <h1>PolypharmacyEnv</h1>
534
+ <p>Elderly Medication Safety — Powered by Neural Bandits</p>
535
+ </div>
536
+ <div className="topbar-right">
537
+ <div className={`status-chip ${hasValidEpisode ? "live" : "idle"}`}>
538
+ {hasValidEpisode
539
+ ? isDone
540
+ ? "Episode Complete"
541
+ : "Session Live"
542
+ : "Ready"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
543
  </div>
544
+ <button
545
+ className="guide-trigger"
546
+ onClick={() => {
547
+ setGuideStep(0);
548
+ setShowGuide(true);
549
+ }}
550
+ title="Open guided walkthrough"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
551
  >
552
+ ?
553
+ </button>
554
+ </div>
555
+ <div className="actions" data-guide="task-selector">
556
+ <select value={taskId} onChange={(e) => setTaskId(e.target.value)}>
557
+ {TASKS.map((t) => (
558
+ <option key={t.id} value={t.id}>
559
+ {t.label}
560
+ </option>
561
+ ))}
562
  </select>
563
+ <button onClick={handleReset} disabled={loading}>
564
+ Reset Episode
565
+ </button>
566
+ <button
567
+ className="secondary"
568
+ onClick={askAi}
569
+ disabled={!hasValidEpisode || isDone || loading}
570
+ >
571
+ Ask AI + Auto Step
572
+ </button>
573
  </div>
574
+ </header>
575
 
576
+ <main className="layout">
577
+ {/* Episode Info */}
578
+ <section className="panel glass panel-wide" data-guide="episode-panel">
579
+ <h2>Episode Overview</h2>
580
+ {hasValidEpisode ? (
581
+ <>
582
+ <div className="kpi-grid">
583
+ <div>
584
+ <span>Episode</span>
585
+ <strong>{obs.episode_id}</strong>
586
+ </div>
587
+ <div>
588
+ <span>Task</span>
589
+ <strong>{TASK_LABEL_MAP[obs.task_id] || obs.task_id}</strong>
590
+ </div>
591
+ <div>
592
+ <span>Patient</span>
593
+ <strong>
594
+ Age {obs.age}, {obs.sex === "M" ? "Male" : "Female"}
595
+ </strong>
596
+ </div>
597
+ <div>
598
+ <span>Step</span>
599
+ <strong>{obs.step_index}</strong>
600
+ </div>
601
+ <div>
602
+ <span>Query Budget</span>
603
+ <strong>{obs.remaining_query_budget} remaining</strong>
604
+ </div>
605
+ <div>
606
+ <span>Intervention Budget</span>
607
+ <strong>
608
+ {obs.remaining_intervention_budget} remaining
609
+ </strong>
610
+ </div>
611
+ </div>
612
 
613
+ {currentRisk !== undefined && baselineRisk !== undefined && (
614
+ <div className="risk-bar-wrap">
615
+ <div className="risk-labels">
616
+ <span>
617
+ Baseline Risk:{" "}
618
+ <strong>{Number(baselineRisk).toFixed(3)}</strong>
619
+ </span>
620
+ <span>
621
+ Current Risk:{" "}
622
+ <strong
623
+ className={
624
+ currentRisk < baselineRisk
625
+ ? "risk-down"
626
+ : "risk-same"
627
+ }
628
+ >
629
+ {Number(currentRisk).toFixed(3)}
630
+ </strong>
631
+ </span>
632
+ </div>
633
+ <div className="risk-bar">
634
+ <div
635
+ className="risk-fill"
636
+ style={{
637
+ width: `${Math.min(currentRisk * 100, 100)}%`,
638
+ }}
639
+ />
640
+ </div>
641
+ </div>
642
+ )}
643
+
644
+ {obs.conditions && obs.conditions.length > 0 && (
645
+ <div className="conditions-row">
646
+ <span className="conditions-label">Conditions:</span>
647
+ {obs.conditions.map((c) => (
648
+ <span key={c} className="condition-tag">
649
+ {c.replace(/_/g, " ")}
650
+ </span>
651
+ ))}
652
+ </div>
653
+ )}
654
+ </>
655
+ ) : (
656
+ <p className="muted">
657
+ Select a task difficulty and click <strong>Reset Episode</strong>{" "}
658
+ to begin a patient case.
659
+ </p>
660
+ )}
661
+ {noBudgetsLeft && !isDone && (
662
+ <div className="budget-note">
663
+ All budgets exhausted. Click <strong>Finish Review</strong> to
664
+ receive your final score.
665
+ </div>
666
+ )}
667
+ {isDone && (
668
+ <div className="budget-note done-note">
669
+ Episode complete
670
+ {finalScore !== null
671
+ ? ` — Final score: ${(finalScore * 100).toFixed(1)}%`
672
+ : ""}
673
+ . Click <strong>Reset Episode</strong> to start a new case.
674
+ </div>
675
+ )}
676
+ </section>
677
+
678
+ {/* Action Console */}
679
+ <section className="panel glass" data-guide="action-console">
680
+ <h2>Action Console</h2>
681
+ <div className="action-row">
682
+ <label>Action Type</label>
683
  <select
684
+ value={action.action_type}
685
+ onChange={(e) =>
686
+ setAction((a) => ({ ...a, action_type: e.target.value }))
687
+ }
688
  >
689
+ {Object.entries(ACTION_LABELS).map(([val, label]) => (
690
+ <option key={val} value={val}>
691
+ {label}
 
692
  </option>
693
  ))}
694
  </select>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
695
  </div>
696
+
697
+ {action.action_type === "query_ddi" && (
698
+ <div className="stack stack-two">
699
+ <div className="field-group">
700
+ <label>Drug 1</label>
701
+ <select
702
+ value={action.drug_id_1}
703
+ onChange={(e) =>
704
+ setAction((a) => ({ ...a, drug_id_1: e.target.value }))
705
+ }
706
+ >
707
+ <option value="">Select drug</option>
708
+ {medIds.map((id) => (
709
+ <option key={id} value={id}>
710
+ {formatDrugName(id)}
711
+ </option>
712
+ ))}
713
+ </select>
714
+ </div>
715
+ <div className="field-group">
716
+ <label>Drug 2</label>
717
+ <select
718
+ value={action.drug_id_2}
719
+ onChange={(e) =>
720
+ setAction((a) => ({ ...a, drug_id_2: e.target.value }))
721
+ }
722
+ >
723
+ <option value="">Select drug</option>
724
+ {medIds.map((id) => (
725
+ <option key={id} value={id}>
726
+ {formatDrugName(id)}
727
+ </option>
728
+ ))}
729
+ </select>
730
+ </div>
731
  </div>
732
+ )}
733
+
734
+ {action.action_type === "propose_intervention" && (
735
+ <div className="stack">
736
+ <div className="field-group">
737
+ <label>Target Drug</label>
738
+ <select
739
+ value={action.target_drug_id}
740
+ onChange={(e) =>
741
+ setAction((a) => ({
742
+ ...a,
743
+ target_drug_id: e.target.value,
744
+ }))
745
+ }
746
+ >
747
+ <option value="">Select target drug</option>
748
+ {medIds.map((id) => (
749
+ <option key={id} value={id}>
750
+ {formatDrugName(id)}
751
+ </option>
752
+ ))}
753
+ </select>
754
+ </div>
755
+ <div className="field-group">
756
+ <label>Intervention Type</label>
757
+ <select
758
+ value={action.intervention_type}
759
+ onChange={(e) =>
760
+ setAction((a) => ({
761
+ ...a,
762
+ intervention_type: e.target.value,
763
+ }))
764
+ }
765
+ >
766
+ {Object.entries(INTERVENTION_LABELS).map(([val, label]) => (
767
+ <option key={val} value={val}>
768
+ {label}
769
+ </option>
770
+ ))}
771
+ </select>
772
+ </div>
773
+ <div className="field-group">
774
+ <label>New Drug ID (optional, for substitution)</label>
775
+ <input
776
+ placeholder="Leave blank for auto-selection"
777
+ value={action.proposed_new_drug_id}
778
+ onChange={(e) =>
779
+ setAction((a) => ({
780
+ ...a,
781
+ proposed_new_drug_id: e.target.value,
782
+ }))
783
+ }
784
+ />
785
+ </div>
786
+ <div className="field-group">
787
+ <label>Rationale (optional)</label>
788
+ <input
789
+ placeholder="e.g., High bleeding risk with concurrent warfarin"
790
+ value={action.rationale}
791
+ onChange={(e) =>
792
+ setAction((a) => ({ ...a, rationale: e.target.value }))
793
+ }
794
+ />
795
+ </div>
796
+ </div>
797
+ )}
798
+
799
+ <button
800
+ className="submit-btn"
801
+ onClick={() => handleStep()}
802
+ disabled={!isActionValid() || loading}
803
+ >
804
+ {noBudgetsLeft ? "Finish Review" : "Submit Step"}
805
+ </button>
806
+ </section>
807
+
808
+ {/* Current Medications */}
809
+ <section className="panel glass" data-guide="medications-panel">
810
+ <h2>
811
+ Current Medications
812
+ {obs?.current_medications?.length
813
+ ? ` (${obs.current_medications.length})`
814
+ : ""}
815
+ </h2>
816
+ <div className="med-grid">
817
+ {(obs?.current_medications || []).map((m) => (
818
+ <div
819
+ key={m.drug_id}
820
+ className={`med-card ${m.is_high_risk_elderly ? "high-risk" : ""}`}
821
+ >
822
+ <div className="med-card-header">
823
+ <strong>{formatDrugName(m.drug_id)}</strong>
824
+ {m.is_high_risk_elderly && (
825
+ <span className="risk-badge">High Risk</span>
826
+ )}
827
+ </div>
828
+ <p className="med-generic">{m.generic_name}</p>
829
+ <div className="med-details">
830
+ <span>{m.dose_mg} mg</span>
831
+ <span className="med-atc">{m.atc_class}</span>
832
+ </div>
833
+ {m.beers_flags && m.beers_flags.length > 0 && (
834
+ <div className="beers-flags">
835
+ {m.beers_flags.map((f, i) => (
836
+ <span key={i} className="beers-tag">
837
+ {f}
838
+ </span>
839
+ ))}
840
+ </div>
841
+ )}
842
+ </div>
843
+ ))}
844
+ </div>
845
+ {(!obs?.current_medications ||
846
+ obs.current_medications.length === 0) && (
847
+ <p className="muted">No medications loaded. Reset an episode to begin.</p>
848
+ )}
849
+ </section>
850
+
851
+ {/* Interaction Queries & Interventions */}
852
+ {hasValidEpisode && (
853
+ <section className="panel glass panel-wide">
854
+ <div className="history-grid">
855
+ <div>
856
+ <h3>Drug Interaction Checks ({obs?.interaction_queries?.length || 0})</h3>
857
+ <div className="history-list">
858
+ {(obs?.interaction_queries || []).map((q, i) => (
859
+ <div
860
+ key={i}
861
+ className={`history-item severity-${q.severity}`}
862
+ >
863
+ <strong>
864
+ {formatDrugName(q.drug_id_1)} +{" "}
865
+ {formatDrugName(q.drug_id_2)}
866
+ </strong>
867
+ <span className={`severity-tag ${q.severity}`}>
868
+ {q.severity}
869
+ </span>
870
+ {q.recommendation && (
871
+ <p className="history-detail">
872
+ {q.recommendation.replace(/_/g, " ")}
873
+ </p>
874
+ )}
875
+ </div>
876
+ ))}
877
+ {(!obs?.interaction_queries || obs.interaction_queries.length === 0) && (
878
+ <p className="muted">No queries yet.</p>
879
+ )}
880
+ </div>
881
+ </div>
882
+ <div>
883
+ <h3>Proposed Changes ({obs?.interventions?.length || 0})</h3>
884
+ <div className="history-list">
885
+ {(obs?.interventions || []).map((iv, i) => (
886
+ <div key={i} className="history-item intervention-item">
887
+ <strong>{formatDrugName(iv.target_drug_id)}</strong>
888
+ <span className="intervention-tag">
889
+ {INTERVENTION_LABELS[iv.action_type] || iv.action_type}
890
+ </span>
891
+ {iv.proposed_new_drug_id && (
892
+ <p className="history-detail">
893
+ Replaced with: {formatDrugName(iv.proposed_new_drug_id)}
894
+ </p>
895
+ )}
896
+ {iv.rationale && (
897
+ <p className="history-detail">{iv.rationale}</p>
898
+ )}
899
+ </div>
900
+ ))}
901
+ {(!obs?.interventions || obs.interventions.length === 0) && (
902
+ <p className="muted">No interventions yet.</p>
903
+ )}
904
+ </div>
905
+ </div>
906
+ </div>
907
+ </section>
908
+ )}
909
+
910
+ {/* Event Log */}
911
+ <section className="panel glass panel-wide" data-guide="event-log">
912
+ <h2>Event Log</h2>
913
+ <div className="logs">
914
+ {log.length === 0 && (
915
+ <div className="log-empty">
916
+ Events will appear here as you interact with the environment.
917
+ </div>
918
+ )}
919
+ {log.map((line, idx) => (
920
+ <div key={idx}>{line}</div>
921
+ ))}
922
+ </div>
923
+ </section>
924
+ </main>
925
+
926
+ <footer className="app-footer">
927
+ <p>
928
+ PolypharmacyEnv — Built with{" "}
929
+ <a
930
+ href="https://github.com/meta-pytorch/OpenEnv"
931
+ target="_blank"
932
+ rel="noopener noreferrer"
933
+ >
934
+ PyTorch OpenEnv
935
+ </a>{" "}
936
+ | Based on{" "}
937
+ <a
938
+ href="https://link.springer.com/chapter/10.1007/978-3-031-36938-4_5"
939
+ target="_blank"
940
+ rel="noopener noreferrer"
941
+ >
942
+ Neural Bandits for Polypharmacy
943
+ </a>{" "}
944
+ (Larouche et al.)
945
+ </p>
946
+ </footer>
947
  </div>
948
  </div>
949
  );
frontend/src/styles.css CHANGED
@@ -1,18 +1,22 @@
1
  :root {
2
- --bg: #070814;
3
- --bg-layer: #0a1026;
4
- --panel: rgba(14, 22, 44, 0.72);
5
- --panel-solid: rgba(20, 28, 52, 0.92);
6
- --text: #e8f1ff;
7
- --muted: #9ab2db;
8
- --primary: #37d4ff;
9
- --primary-2: #5a8dff;
10
- --accent: #9d59ff;
11
- --success: #6dfbcf;
12
- --border: rgba(122, 162, 255, 0.28);
13
- --line: rgba(109, 143, 225, 0.18);
14
- --shadow: 0 16px 45px rgba(5, 8, 23, 0.6);
15
- --shadow-strong: 0 14px 32px rgba(44, 105, 255, 0.4);
 
 
 
 
16
  }
17
 
18
  * {
@@ -22,187 +26,253 @@
22
  body {
23
  margin: 0;
24
  color: var(--text);
25
- font-family: "Segoe UI", "SF Pro Text", "Helvetica Neue", sans-serif;
26
- background:
27
- radial-gradient(circle at 8% 12%, rgba(121, 87, 255, 0.22), transparent 38%),
28
- radial-gradient(circle at 88% 20%, rgba(59, 204, 255, 0.26), transparent 34%),
29
- radial-gradient(circle at 50% 100%, rgba(43, 128, 255, 0.26), transparent 40%),
30
- linear-gradient(145deg, var(--bg) 0%, var(--bg-layer) 60%, #04060f 100%);
31
- background-attachment: fixed;
32
  }
33
 
34
  .shell {
35
  min-height: 100vh;
36
  position: relative;
37
  overflow: hidden;
38
- padding: 24px 16px 34px;
39
  }
40
 
41
  .container {
42
- width: min(1320px, 100%);
43
  margin: 0 auto;
44
  position: relative;
45
  z-index: 2;
46
  }
47
 
 
48
  .bg-orb {
49
  position: absolute;
50
  border-radius: 50%;
51
  pointer-events: none;
52
- opacity: 0.9;
53
- filter: blur(18px);
54
  }
55
 
56
  .orb-a {
57
- width: min(46vw, 530px);
58
  aspect-ratio: 1 / 1;
59
- right: -9%;
60
- top: -10%;
61
- background: radial-gradient(circle, rgba(52, 203, 255, 0.35), rgba(52, 203, 255, 0.04) 70%);
62
  }
63
 
64
  .orb-b {
65
- width: min(40vw, 460px);
66
  aspect-ratio: 1 / 1;
67
- left: -9%;
68
- bottom: -15%;
69
- background: radial-gradient(circle, rgba(160, 102, 255, 0.3), rgba(160, 102, 255, 0.06) 72%);
70
  }
71
 
 
 
72
  .glass {
73
- background:
74
- linear-gradient(180deg, rgba(255, 255, 255, 0.06), rgba(255, 255, 255, 0.01)),
75
- var(--panel);
76
  border: 1px solid var(--border);
77
  box-shadow: var(--shadow);
78
- backdrop-filter: blur(12px);
79
  }
80
 
 
 
81
  .topbar {
82
- border-radius: 24px;
83
- padding: clamp(14px, 2vw, 20px);
84
- display: grid;
85
- gap: 12px 16px;
86
- grid-template-columns: minmax(220px, 1.2fr) auto minmax(280px, 1fr);
87
  align-items: center;
 
 
 
 
 
 
 
88
  }
89
 
90
  .title-wrap h1 {
91
  margin: 0;
92
- font-size: clamp(1.15rem, 2.2vw, 1.95rem);
93
- letter-spacing: 0.02em;
94
- text-transform: uppercase;
95
- text-shadow: 0 0 16px rgba(106, 192, 255, 0.3);
 
 
 
 
96
  }
97
 
98
  .title-wrap p {
99
- margin: 6px 0 0;
100
- font-size: 0.84rem;
101
  color: var(--muted);
102
- letter-spacing: 0.03em;
103
- text-transform: uppercase;
 
 
 
 
 
104
  }
105
 
106
  .status-chip {
107
- justify-self: center;
108
- padding: 7px 14px;
109
  border-radius: 999px;
110
  font-size: 0.72rem;
111
- font-weight: 700;
112
- letter-spacing: 0.08em;
113
  text-transform: uppercase;
114
  border: 1px solid transparent;
 
115
  }
116
 
117
  .status-chip.live {
118
- color: #052c24;
119
- background: linear-gradient(90deg, rgba(126, 255, 220, 0.9), rgba(84, 244, 196, 0.95));
120
- box-shadow: 0 0 14px rgba(96, 244, 198, 0.36);
121
  }
122
 
123
  .status-chip.idle {
124
- color: #d8e8ff;
125
- border-color: rgba(117, 186, 255, 0.48);
126
- background: rgba(60, 106, 198, 0.25);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  }
128
 
129
  .actions {
130
  display: flex;
131
- justify-content: flex-end;
132
  flex-wrap: wrap;
133
  gap: 10px;
134
  }
135
 
 
 
136
  button,
137
  select,
138
  input {
139
  width: 100%;
140
  min-height: 42px;
141
- border-radius: 12px;
142
  border: 1px solid var(--border);
143
- font-size: 0.92rem;
144
- padding: 10px 12px;
145
  color: var(--text);
146
- background: rgba(11, 19, 38, 0.84);
 
 
147
  }
148
 
149
- select,
150
- input {
151
- transition: border-color 120ms ease, box-shadow 120ms ease;
152
  }
153
 
154
  select:focus,
155
  input:focus {
156
  outline: none;
157
- border-color: rgba(119, 200, 255, 0.88);
158
- box-shadow: 0 0 0 2px rgba(95, 187, 255, 0.18);
 
 
 
 
 
 
159
  }
160
 
161
  button {
162
  cursor: pointer;
163
  border: 0;
164
  width: auto;
165
- font-weight: 700;
166
- letter-spacing: 0.02em;
167
- background: linear-gradient(135deg, var(--primary), var(--primary-2) 55%, var(--accent));
168
- box-shadow: var(--shadow-strong);
169
- transition: transform 140ms ease, filter 140ms ease, box-shadow 140ms ease;
 
 
170
  }
171
 
172
  button:hover {
 
173
  transform: translateY(-1px);
174
- filter: brightness(1.04);
175
- box-shadow: 0 18px 32px rgba(50, 141, 255, 0.48);
176
  }
177
 
178
  button:active {
179
  transform: translateY(0);
 
180
  }
181
 
182
  button.secondary {
183
- background: linear-gradient(135deg, rgba(95, 185, 255, 0.9), rgba(154, 102, 255, 0.86));
 
 
 
 
 
 
184
  }
185
 
186
  button:disabled {
187
- opacity: 0.56;
188
  cursor: not-allowed;
189
- filter: grayscale(0.2);
190
  box-shadow: none;
191
  transform: none;
192
  }
193
 
 
 
 
 
 
 
 
 
 
194
  .layout {
195
- margin-top: 16px;
196
  display: grid;
197
- gap: 14px;
198
- grid-template-columns: 1.12fr 0.88fr;
199
  align-items: start;
200
  }
201
 
202
  .panel {
203
- border-radius: 20px;
204
- padding: clamp(14px, 1.8vw, 20px);
205
  position: relative;
 
206
  }
207
 
208
  .panel::after {
@@ -219,151 +289,576 @@ button:disabled {
219
  }
220
 
221
  .panel h2 {
222
- margin: 0 0 12px;
223
- font-size: 1rem;
224
- font-weight: 700;
225
  letter-spacing: 0.05em;
226
  text-transform: uppercase;
 
 
 
 
 
 
 
 
 
 
227
  }
228
 
 
 
229
  .kpi-grid {
230
  display: grid;
231
- gap: 10px;
232
- grid-template-columns: repeat(3, minmax(0, 1fr));
233
  }
234
 
235
  .kpi-grid div {
236
- border-radius: 13px;
237
  border: 1px solid var(--border);
238
  background: var(--panel-solid);
239
- padding: 11px 12px;
 
240
  }
241
 
242
  .kpi-grid span {
243
  display: block;
244
- margin-bottom: 4px;
245
  font-size: 0.72rem;
246
  color: var(--muted);
247
  text-transform: uppercase;
248
- letter-spacing: 0.05em;
249
  }
250
 
251
  .kpi-grid strong {
252
- font-size: 1.06rem;
253
- line-height: 1.2;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
254
  }
255
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
256
  .action-row,
257
  .stack {
258
  display: grid;
259
- gap: 10px;
260
- margin-bottom: 12px;
261
  }
262
 
263
- .action-row label {
 
 
264
  color: var(--muted);
265
  font-size: 0.78rem;
266
- letter-spacing: 0.05em;
267
  text-transform: uppercase;
 
 
 
 
 
 
268
  }
269
 
270
  .stack-two {
271
- grid-template-columns: repeat(2, minmax(0, 1fr));
272
  }
273
 
 
 
274
  .med-grid {
275
  display: grid;
276
- grid-template-columns: repeat(3, minmax(0, 1fr));
277
- gap: 10px;
278
- max-height: 430px;
279
- overflow: auto;
280
  padding-right: 4px;
281
  }
282
 
283
  .med-card {
284
- border-radius: 14px;
285
  border: 1px solid var(--border);
286
  background: var(--panel-solid);
287
- padding: 11px 12px;
288
- transition: transform 130ms ease, border-color 130ms ease;
 
289
  }
290
 
291
  .med-card:hover {
292
- transform: translateY(-1px);
293
- border-color: rgba(109, 224, 255, 0.72);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
294
  }
295
 
296
- .med-card p {
297
- margin: 6px 0 4px;
298
  color: var(--muted);
 
299
  text-transform: capitalize;
 
 
 
300
  }
301
 
302
- .med-card small {
303
- color: #c7d9ff;
 
 
 
304
  }
305
 
306
- .logs {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
307
  max-height: 300px;
308
- overflow: auto;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
309
  padding-right: 4px;
310
- display: grid;
311
- gap: 7px;
 
312
  font-size: 0.84rem;
313
- font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, monospace;
314
  }
315
 
316
  .logs div {
317
- border-radius: 10px;
318
  border: 1px solid var(--border);
319
- background: rgba(10, 16, 31, 0.84);
320
- padding: 8px 10px;
321
- color: #dbebff;
 
 
 
 
 
 
 
 
 
 
322
  }
323
 
 
 
324
  .muted {
325
  margin: 0;
326
  color: var(--muted);
 
327
  }
328
 
329
  .budget-note {
330
- margin-top: 10px;
331
  border: 1px solid var(--border);
332
- border-radius: 12px;
333
- padding: 10px 12px;
334
- background: rgba(13, 22, 42, 0.82);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
335
  }
336
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
337
  @media (max-width: 1180px) {
338
  .layout {
339
  grid-template-columns: 1fr;
340
  }
341
 
342
  .topbar {
343
- grid-template-columns: 1fr;
 
344
  }
345
 
346
- .status-chip {
347
- justify-self: start;
348
  }
349
 
350
  .actions {
 
351
  justify-content: flex-start;
352
  }
 
 
 
 
353
  }
354
 
355
  @media (max-width: 760px) {
356
  .shell {
357
- padding: 14px 10px 24px;
358
  }
359
 
360
  .topbar,
361
  .panel {
362
- border-radius: 16px;
 
363
  }
364
 
365
  .actions {
366
- width: 100%;
367
  }
368
 
369
  .actions button,
@@ -371,13 +866,53 @@ button:disabled {
371
  width: 100%;
372
  }
373
 
374
- .kpi-grid,
375
- .med-grid,
 
 
 
 
 
 
376
  .stack-two {
377
  grid-template-columns: 1fr;
378
  }
379
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
380
  .logs {
381
- max-height: 240px;
382
  }
383
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  :root {
2
+ --bg: #0d1117;
3
+ --bg-layer: #0d1117;
4
+ --panel: #161b22;
5
+ --panel-solid: #1c2333;
6
+ --text: #e6edf3;
7
+ --muted: #8b949e;
8
+ --primary: #58a6ff;
9
+ --primary-2: #79c0ff;
10
+ --accent: #58a6ff;
11
+ --success: #3fb950;
12
+ --danger: #f85149;
13
+ --warning: #d29922;
14
+ --border: rgba(48, 54, 61, 0.7);
15
+ --line: rgba(48, 54, 61, 0.5);
16
+ --shadow: 0 1px 3px rgba(0, 0, 0, 0.12), 0 4px 12px rgba(0, 0, 0, 0.08);
17
+ --shadow-strong: 0 4px 16px rgba(0, 0, 0, 0.2);
18
+ --radius: 12px;
19
+ --radius-sm: 8px;
20
  }
21
 
22
  * {
 
26
  body {
27
  margin: 0;
28
  color: var(--text);
29
+ font-family: "Inter", -apple-system, "Segoe UI", "Helvetica Neue", sans-serif;
30
+ background: var(--bg);
31
+ line-height: 1.55;
32
+ -webkit-font-smoothing: antialiased;
33
+ -moz-osx-font-smoothing: grayscale;
 
 
34
  }
35
 
36
  .shell {
37
  min-height: 100vh;
38
  position: relative;
39
  overflow: hidden;
40
+ padding: 20px 24px 40px;
41
  }
42
 
43
  .container {
44
+ width: min(1400px, 100%);
45
  margin: 0 auto;
46
  position: relative;
47
  z-index: 2;
48
  }
49
 
50
+ /* Background orbs - subtle and muted for a professional look */
51
  .bg-orb {
52
  position: absolute;
53
  border-radius: 50%;
54
  pointer-events: none;
55
+ opacity: 0.3;
56
+ filter: blur(80px);
57
  }
58
 
59
  .orb-a {
60
+ width: min(42vw, 500px);
61
  aspect-ratio: 1 / 1;
62
+ right: -8%;
63
+ top: -8%;
64
+ background: radial-gradient(circle, rgba(88, 166, 255, 0.12), transparent 70%);
65
  }
66
 
67
  .orb-b {
68
+ width: min(36vw, 420px);
69
  aspect-ratio: 1 / 1;
70
+ left: -8%;
71
+ bottom: -12%;
72
+ background: radial-gradient(circle, rgba(88, 166, 255, 0.08), transparent 72%);
73
  }
74
 
75
+ /* ── Panels ──────────────────────────────────────────────────── */
76
+
77
  .glass {
78
+ background: var(--panel);
 
 
79
  border: 1px solid var(--border);
80
  box-shadow: var(--shadow);
 
81
  }
82
 
83
+ /* ── Top Bar ─────────────────────────────────────────────────── */
84
+
85
  .topbar {
86
+ border-radius: var(--radius);
87
+ padding: 14px 24px;
88
+ display: flex;
 
 
89
  align-items: center;
90
+ gap: 16px;
91
+ flex-wrap: wrap;
92
+ }
93
+
94
+ .title-wrap {
95
+ flex: 1;
96
+ min-width: 200px;
97
  }
98
 
99
  .title-wrap h1 {
100
  margin: 0;
101
+ font-size: clamp(1rem, 1.8vw, 1.35rem);
102
+ font-weight: 700;
103
+ letter-spacing: -0.01em;
104
+ color: var(--text);
105
+ background: none;
106
+ -webkit-background-clip: unset;
107
+ -webkit-text-fill-color: unset;
108
+ background-clip: unset;
109
  }
110
 
111
  .title-wrap p {
112
+ margin: 2px 0 0;
113
+ font-size: 0.8rem;
114
  color: var(--muted);
115
+ letter-spacing: 0.01em;
116
+ }
117
+
118
+ .topbar-right {
119
+ display: flex;
120
+ align-items: center;
121
+ gap: 10px;
122
  }
123
 
124
  .status-chip {
125
+ padding: 5px 14px;
 
126
  border-radius: 999px;
127
  font-size: 0.72rem;
128
+ font-weight: 600;
129
+ letter-spacing: 0.04em;
130
  text-transform: uppercase;
131
  border: 1px solid transparent;
132
+ white-space: nowrap;
133
  }
134
 
135
  .status-chip.live {
136
+ color: #ffffff;
137
+ background: var(--success);
138
+ box-shadow: none;
139
  }
140
 
141
  .status-chip.idle {
142
+ color: var(--muted);
143
+ border-color: var(--border);
144
+ background: rgba(48, 54, 61, 0.3);
145
+ }
146
+
147
+ .guide-trigger {
148
+ width: 34px !important;
149
+ height: 34px;
150
+ min-height: 34px;
151
+ padding: 0;
152
+ border-radius: 50% !important;
153
+ font-size: 0.95rem;
154
+ font-weight: 700;
155
+ display: flex;
156
+ align-items: center;
157
+ justify-content: center;
158
+ background: rgba(88, 166, 255, 0.1);
159
+ border: 1px solid rgba(88, 166, 255, 0.25);
160
+ color: var(--primary);
161
+ box-shadow: none;
162
+ }
163
+
164
+ .guide-trigger:hover {
165
+ background: rgba(88, 166, 255, 0.2);
166
+ box-shadow: none;
167
  }
168
 
169
  .actions {
170
  display: flex;
171
+ align-items: center;
172
  flex-wrap: wrap;
173
  gap: 10px;
174
  }
175
 
176
+ /* ── Form Controls ───────────────────────────────────────────── */
177
+
178
  button,
179
  select,
180
  input {
181
  width: 100%;
182
  min-height: 42px;
183
+ border-radius: var(--radius-sm);
184
  border: 1px solid var(--border);
185
+ font-size: 0.9rem;
186
+ padding: 10px 14px;
187
  color: var(--text);
188
+ background: #0d1117;
189
+ font-family: inherit;
190
+ transition: border-color 150ms ease, box-shadow 150ms ease, background 150ms ease;
191
  }
192
 
193
+ select:hover,
194
+ input:hover {
195
+ border-color: rgba(139, 148, 158, 0.5);
196
  }
197
 
198
  select:focus,
199
  input:focus {
200
  outline: none;
201
+ border-color: var(--primary);
202
+ box-shadow: 0 0 0 3px rgba(88, 166, 255, 0.15);
203
+ background: #0d1117;
204
+ }
205
+
206
+ select {
207
+ cursor: pointer;
208
+ appearance: auto;
209
  }
210
 
211
  button {
212
  cursor: pointer;
213
  border: 0;
214
  width: auto;
215
+ font-weight: 600;
216
+ letter-spacing: 0.01em;
217
+ white-space: nowrap;
218
+ background: var(--primary);
219
+ color: #ffffff;
220
+ box-shadow: none;
221
+ transition: background 150ms ease, transform 100ms ease, box-shadow 150ms ease;
222
  }
223
 
224
  button:hover {
225
+ background: #79c0ff;
226
  transform: translateY(-1px);
227
+ filter: none;
228
+ box-shadow: 0 2px 8px rgba(88, 166, 255, 0.25);
229
  }
230
 
231
  button:active {
232
  transform: translateY(0);
233
+ background: #4090e0;
234
  }
235
 
236
  button.secondary {
237
+ background: rgba(88, 166, 255, 0.15);
238
+ color: var(--primary);
239
+ border: 1px solid rgba(88, 166, 255, 0.3);
240
+ }
241
+
242
+ button.secondary:hover {
243
+ background: rgba(88, 166, 255, 0.25);
244
  }
245
 
246
  button:disabled {
247
+ opacity: 0.4;
248
  cursor: not-allowed;
249
+ filter: none;
250
  box-shadow: none;
251
  transform: none;
252
  }
253
 
254
+ .submit-btn {
255
+ width: 100%;
256
+ margin-top: 10px;
257
+ min-height: 44px;
258
+ font-size: 0.92rem;
259
+ }
260
+
261
+ /* ── Layout ──────────────────────────────────────────────────── */
262
+
263
  .layout {
264
+ margin-top: 20px;
265
  display: grid;
266
+ gap: 20px;
267
+ grid-template-columns: 1.1fr 0.9fr;
268
  align-items: start;
269
  }
270
 
271
  .panel {
272
+ border-radius: var(--radius);
273
+ padding: 24px;
274
  position: relative;
275
+ overflow: hidden;
276
  }
277
 
278
  .panel::after {
 
289
  }
290
 
291
  .panel h2 {
292
+ margin: 0 0 16px;
293
+ font-size: 0.82rem;
294
+ font-weight: 600;
295
  letter-spacing: 0.05em;
296
  text-transform: uppercase;
297
+ color: var(--muted);
298
+ }
299
+
300
+ .panel h3 {
301
+ margin: 0 0 12px;
302
+ font-size: 0.8rem;
303
+ font-weight: 600;
304
+ letter-spacing: 0.04em;
305
+ text-transform: uppercase;
306
+ color: var(--muted);
307
  }
308
 
309
+ /* ── KPI Grid ────────────────────────────────────────────────── */
310
+
311
  .kpi-grid {
312
  display: grid;
313
+ gap: 12px;
314
+ grid-template-columns: repeat(3, 1fr);
315
  }
316
 
317
  .kpi-grid div {
318
+ border-radius: var(--radius-sm);
319
  border: 1px solid var(--border);
320
  background: var(--panel-solid);
321
+ padding: 16px 18px;
322
+ overflow: hidden;
323
  }
324
 
325
  .kpi-grid span {
326
  display: block;
327
+ margin-bottom: 6px;
328
  font-size: 0.72rem;
329
  color: var(--muted);
330
  text-transform: uppercase;
331
+ letter-spacing: 0.06em;
332
  }
333
 
334
  .kpi-grid strong {
335
+ font-size: 1.05rem;
336
+ font-weight: 600;
337
+ line-height: 1.3;
338
+ word-break: break-word;
339
+ overflow-wrap: break-word;
340
+ color: var(--text);
341
+ }
342
+
343
+ /* ── Risk Bar ────────────────────────────────────────────────── */
344
+
345
+ .risk-bar-wrap {
346
+ margin-top: 16px;
347
+ }
348
+
349
+ .risk-labels {
350
+ display: flex;
351
+ justify-content: space-between;
352
+ font-size: 0.8rem;
353
+ color: var(--muted);
354
+ margin-bottom: 8px;
355
  }
356
 
357
+ .risk-labels strong {
358
+ font-size: 0.85rem;
359
+ color: var(--text);
360
+ }
361
+
362
+ .risk-down {
363
+ color: var(--success) !important;
364
+ }
365
+
366
+ .risk-same {
367
+ color: var(--warning) !important;
368
+ }
369
+
370
+ .risk-bar {
371
+ height: 6px;
372
+ background: rgba(48, 54, 61, 0.5);
373
+ border-radius: 3px;
374
+ overflow: hidden;
375
+ }
376
+
377
+ .risk-fill {
378
+ height: 100%;
379
+ background: linear-gradient(90deg, var(--success), var(--warning), var(--danger));
380
+ border-radius: 3px;
381
+ transition: width 300ms ease;
382
+ }
383
+
384
+ /* ── Conditions ──────────────────────────────────────────────── */
385
+
386
+ .conditions-row {
387
+ margin-top: 14px;
388
+ display: flex;
389
+ flex-wrap: wrap;
390
+ align-items: center;
391
+ gap: 8px;
392
+ }
393
+
394
+ .conditions-label {
395
+ font-size: 0.78rem;
396
+ color: var(--muted);
397
+ text-transform: uppercase;
398
+ letter-spacing: 0.04em;
399
+ margin-right: 4px;
400
+ }
401
+
402
+ .condition-tag {
403
+ font-size: 0.75rem;
404
+ padding: 4px 12px;
405
+ border-radius: 999px;
406
+ background: rgba(88, 166, 255, 0.1);
407
+ border: 1px solid rgba(88, 166, 255, 0.2);
408
+ color: var(--primary);
409
+ text-transform: capitalize;
410
+ white-space: nowrap;
411
+ }
412
+
413
+ /* ── Action Console ──────────────────────────────────────────── */
414
+
415
  .action-row,
416
  .stack {
417
  display: grid;
418
+ gap: 12px;
419
+ margin-bottom: 14px;
420
  }
421
 
422
+ .action-row label,
423
+ .field-group label {
424
+ display: block;
425
  color: var(--muted);
426
  font-size: 0.78rem;
427
+ letter-spacing: 0.04em;
428
  text-transform: uppercase;
429
+ margin-bottom: 6px;
430
+ }
431
+
432
+ .field-group {
433
+ display: flex;
434
+ flex-direction: column;
435
  }
436
 
437
  .stack-two {
438
+ grid-template-columns: 1fr 1fr;
439
  }
440
 
441
+ /* ── Medication Cards ────────────────────────────────────────── */
442
+
443
  .med-grid {
444
  display: grid;
445
+ grid-template-columns: repeat(auto-fill, minmax(210px, 1fr));
446
+ gap: 12px;
447
+ max-height: 480px;
448
+ overflow-y: auto;
449
  padding-right: 4px;
450
  }
451
 
452
  .med-card {
453
+ border-radius: var(--radius-sm);
454
  border: 1px solid var(--border);
455
  background: var(--panel-solid);
456
+ padding: 16px 18px;
457
+ transition: border-color 150ms ease, background 150ms ease;
458
+ overflow: hidden;
459
  }
460
 
461
  .med-card:hover {
462
+ border-color: rgba(88, 166, 255, 0.4);
463
+ background: #1f2937;
464
+ }
465
+
466
+ .med-card.high-risk {
467
+ border-color: rgba(248, 81, 73, 0.35);
468
+ }
469
+
470
+ .med-card-header {
471
+ display: flex;
472
+ align-items: center;
473
+ justify-content: space-between;
474
+ gap: 8px;
475
+ }
476
+
477
+ .med-card-header strong {
478
+ font-size: 0.92rem;
479
+ font-weight: 600;
480
+ overflow: hidden;
481
+ text-overflow: ellipsis;
482
+ white-space: nowrap;
483
+ color: var(--text);
484
+ }
485
+
486
+ .risk-badge {
487
+ font-size: 0.65rem;
488
+ padding: 3px 9px;
489
+ border-radius: 999px;
490
+ background: rgba(248, 81, 73, 0.12);
491
+ border: 1px solid rgba(248, 81, 73, 0.3);
492
+ color: var(--danger);
493
+ text-transform: uppercase;
494
+ letter-spacing: 0.04em;
495
+ font-weight: 600;
496
+ white-space: nowrap;
497
+ flex-shrink: 0;
498
  }
499
 
500
+ .med-generic {
501
+ margin: 6px 0;
502
  color: var(--muted);
503
+ font-size: 0.84rem;
504
  text-transform: capitalize;
505
+ overflow: hidden;
506
+ text-overflow: ellipsis;
507
+ white-space: nowrap;
508
  }
509
 
510
+ .med-details {
511
+ display: flex;
512
+ gap: 10px;
513
+ font-size: 0.8rem;
514
+ color: #8b949e;
515
  }
516
 
517
+ .med-atc {
518
+ color: var(--primary);
519
+ font-weight: 600;
520
+ }
521
+
522
+ .beers-flags {
523
+ margin-top: 8px;
524
+ display: flex;
525
+ flex-wrap: wrap;
526
+ gap: 5px;
527
+ }
528
+
529
+ .beers-tag {
530
+ font-size: 0.68rem;
531
+ padding: 3px 9px;
532
+ border-radius: 999px;
533
+ background: rgba(210, 153, 34, 0.1);
534
+ border: 1px solid rgba(210, 153, 34, 0.25);
535
+ color: var(--warning);
536
+ text-transform: capitalize;
537
+ }
538
+
539
+ /* ── History Grid ────────────────────────────────────────────── */
540
+
541
+ .history-grid {
542
+ display: grid;
543
+ grid-template-columns: 1fr 1fr;
544
+ gap: 24px;
545
+ }
546
+
547
+ .history-list {
548
+ display: flex;
549
+ flex-direction: column;
550
+ gap: 10px;
551
  max-height: 300px;
552
+ overflow-y: auto;
553
+ }
554
+
555
+ .history-item {
556
+ border-radius: var(--radius-sm);
557
+ border: 1px solid var(--border);
558
+ background: var(--panel-solid);
559
+ padding: 14px 16px;
560
+ }
561
+
562
+ .history-item strong {
563
+ font-size: 0.88rem;
564
+ font-weight: 600;
565
+ display: block;
566
+ margin-bottom: 6px;
567
+ color: var(--text);
568
+ }
569
+
570
+ .history-detail {
571
+ margin: 6px 0 0;
572
+ font-size: 0.82rem;
573
+ color: var(--muted);
574
+ text-transform: capitalize;
575
+ }
576
+
577
+ .severity-tag, .intervention-tag {
578
+ display: inline-block;
579
+ font-size: 0.7rem;
580
+ padding: 3px 10px;
581
+ border-radius: 999px;
582
+ font-weight: 600;
583
+ text-transform: uppercase;
584
+ letter-spacing: 0.04em;
585
+ }
586
+
587
+ .severity-tag.severe {
588
+ background: rgba(248, 81, 73, 0.12);
589
+ border: 1px solid rgba(248, 81, 73, 0.3);
590
+ color: var(--danger);
591
+ }
592
+
593
+ .severity-tag.moderate {
594
+ background: rgba(210, 153, 34, 0.12);
595
+ border: 1px solid rgba(210, 153, 34, 0.3);
596
+ color: var(--warning);
597
+ }
598
+
599
+ .severity-tag.mild {
600
+ background: rgba(63, 185, 80, 0.12);
601
+ border: 1px solid rgba(63, 185, 80, 0.3);
602
+ color: var(--success);
603
+ }
604
+
605
+ .severity-tag.none {
606
+ background: rgba(48, 54, 61, 0.3);
607
+ border: 1px solid var(--border);
608
+ color: var(--muted);
609
+ }
610
+
611
+ .intervention-tag {
612
+ background: rgba(88, 166, 255, 0.1);
613
+ border: 1px solid rgba(88, 166, 255, 0.25);
614
+ color: var(--primary);
615
+ }
616
+
617
+ .severity-severe .history-item {
618
+ border-color: rgba(248, 81, 73, 0.2);
619
+ }
620
+
621
+ .severity-moderate .history-item {
622
+ border-color: rgba(210, 153, 34, 0.2);
623
+ }
624
+
625
+ /* ── Event Log ───────────────────────────────────────────────── */
626
+
627
+ .logs {
628
+ max-height: 280px;
629
+ overflow-y: auto;
630
  padding-right: 4px;
631
+ display: flex;
632
+ flex-direction: column;
633
+ gap: 8px;
634
  font-size: 0.84rem;
635
+ font-family: "JetBrains Mono", ui-monospace, SFMono-Regular, Menlo, Monaco, monospace;
636
  }
637
 
638
  .logs div {
639
+ border-radius: var(--radius-sm);
640
  border: 1px solid var(--border);
641
+ background: #0d1117;
642
+ padding: 12px 16px;
643
+ color: var(--text);
644
+ overflow: hidden;
645
+ text-overflow: ellipsis;
646
+ word-break: break-word;
647
+ line-height: 1.5;
648
+ }
649
+
650
+ .log-empty {
651
+ color: var(--muted);
652
+ font-family: inherit;
653
+ font-style: italic;
654
  }
655
 
656
+ /* ── Helper Text ─────────────────────────────────────────────── */
657
+
658
  .muted {
659
  margin: 0;
660
  color: var(--muted);
661
+ font-size: 0.88rem;
662
  }
663
 
664
  .budget-note {
665
+ margin-top: 14px;
666
  border: 1px solid var(--border);
667
+ border-radius: var(--radius-sm);
668
+ padding: 14px 18px;
669
+ background: var(--panel-solid);
670
+ font-size: 0.88rem;
671
+ color: var(--muted);
672
+ }
673
+
674
+ .done-note {
675
+ border-color: rgba(63, 185, 80, 0.3);
676
+ color: var(--success);
677
+ }
678
+
679
+ /* ── Footer ──────────────────────────────────────────────────── */
680
+
681
+ .app-footer {
682
+ margin-top: 28px;
683
+ text-align: center;
684
+ }
685
+
686
+ .app-footer p {
687
+ font-size: 0.8rem;
688
+ color: var(--muted);
689
+ opacity: 0.6;
690
+ }
691
+
692
+ .app-footer a {
693
+ color: var(--primary);
694
+ text-decoration: none;
695
+ }
696
+
697
+ .app-footer a:hover {
698
+ text-decoration: underline;
699
  }
700
 
701
+ /* ── Spotlight Guide ─────────────────────────────────────────── */
702
+
703
+ .spotlight-overlay {
704
+ position: fixed;
705
+ inset: 0;
706
+ z-index: 100;
707
+ pointer-events: none;
708
+ }
709
+
710
+ .spotlight-svg {
711
+ position: fixed;
712
+ inset: 0;
713
+ z-index: 100;
714
+ pointer-events: auto;
715
+ }
716
+
717
+ .spotlight-ring {
718
+ position: fixed;
719
+ z-index: 101;
720
+ border: 2px solid var(--primary);
721
+ border-radius: var(--radius);
722
+ box-shadow: 0 0 0 4px rgba(88, 166, 255, 0.1);
723
+ pointer-events: none;
724
+ transition: top 350ms ease, left 350ms ease, width 350ms ease, height 350ms ease;
725
+ }
726
+
727
+ .spotlight-tooltip {
728
+ position: fixed;
729
+ z-index: 102;
730
+ border-radius: var(--radius);
731
+ padding: 24px;
732
+ pointer-events: auto;
733
+ animation: tooltipIn 250ms ease;
734
+ }
735
+
736
+ @keyframes tooltipIn {
737
+ from {
738
+ opacity: 0;
739
+ transform: translateY(8px);
740
+ }
741
+ to {
742
+ opacity: 1;
743
+ transform: translateY(0);
744
+ }
745
+ }
746
+
747
+ .spotlight-tooltip-header {
748
+ display: flex;
749
+ align-items: flex-start;
750
+ justify-content: space-between;
751
+ gap: 12px;
752
+ margin-bottom: 12px;
753
+ }
754
+
755
+ .spotlight-tooltip-header h3 {
756
+ margin: 0;
757
+ font-size: 1.05rem;
758
+ font-weight: 600;
759
+ color: var(--text);
760
+ text-transform: none;
761
+ letter-spacing: 0;
762
+ line-height: 1.3;
763
+ }
764
+
765
+ .guide-counter {
766
+ font-size: 0.72rem;
767
+ color: var(--muted);
768
+ padding: 4px 10px;
769
+ border-radius: 999px;
770
+ background: rgba(48, 54, 61, 0.4);
771
+ border: 1px solid var(--border);
772
+ white-space: nowrap;
773
+ flex-shrink: 0;
774
+ }
775
+
776
+ .spotlight-tooltip-body {
777
+ margin-bottom: 18px;
778
+ font-size: 0.88rem;
779
+ line-height: 1.65;
780
+ color: var(--muted);
781
+ }
782
+
783
+ .spotlight-tooltip-body p {
784
+ margin: 0 0 5px;
785
+ }
786
+
787
+ .spotlight-tooltip-body p:empty {
788
+ height: 4px;
789
+ }
790
+
791
+ .spotlight-tooltip-footer {
792
+ display: flex;
793
+ gap: 8px;
794
+ justify-content: flex-end;
795
+ }
796
+
797
+ .guide-btn {
798
+ padding: 8px 18px !important;
799
+ font-size: 0.84rem !important;
800
+ border-radius: var(--radius-sm) !important;
801
+ }
802
+
803
+ .guide-dots {
804
+ display: flex;
805
+ justify-content: center;
806
+ gap: 6px;
807
+ margin-top: 14px;
808
+ }
809
+
810
+ .dot {
811
+ width: 7px;
812
+ height: 7px;
813
+ border-radius: 50%;
814
+ background: rgba(48, 54, 61, 0.6);
815
+ transition: background 150ms ease;
816
+ }
817
+
818
+ .dot.active {
819
+ background: var(--primary);
820
+ box-shadow: none;
821
+ }
822
+
823
+ /* ── Responsive ──────────────────────────────────────────────── */
824
+
825
  @media (max-width: 1180px) {
826
  .layout {
827
  grid-template-columns: 1fr;
828
  }
829
 
830
  .topbar {
831
+ flex-direction: column;
832
+ align-items: flex-start;
833
  }
834
 
835
+ .topbar-right {
836
+ align-self: flex-start;
837
  }
838
 
839
  .actions {
840
+ width: 100%;
841
  justify-content: flex-start;
842
  }
843
+
844
+ .history-grid {
845
+ grid-template-columns: 1fr;
846
+ }
847
  }
848
 
849
  @media (max-width: 760px) {
850
  .shell {
851
+ padding: 12px 12px 24px;
852
  }
853
 
854
  .topbar,
855
  .panel {
856
+ border-radius: var(--radius-sm);
857
+ padding: 16px 18px;
858
  }
859
 
860
  .actions {
861
+ flex-direction: column;
862
  }
863
 
864
  .actions button,
 
866
  width: 100%;
867
  }
868
 
869
+ .kpi-grid {
870
+ grid-template-columns: 1fr 1fr;
871
+ }
872
+
873
+ .med-grid {
874
+ grid-template-columns: 1fr;
875
+ }
876
+
877
  .stack-two {
878
  grid-template-columns: 1fr;
879
  }
880
 
881
+ .guide-modal {
882
+ padding: 20px;
883
+ }
884
+
885
+ .spotlight-tooltip {
886
+ left: 10px !important;
887
+ right: 10px !important;
888
+ max-width: calc(100vw - 20px) !important;
889
+ }
890
+
891
+ .guide-footer,
892
+ .spotlight-tooltip-footer {
893
+ flex-direction: column;
894
+ }
895
+
896
  .logs {
897
+ max-height: 200px;
898
  }
899
  }
900
+
901
+ /* ── Scrollbar ───────────────────────────────────────────────── */
902
+
903
+ ::-webkit-scrollbar {
904
+ width: 6px;
905
+ }
906
+
907
+ ::-webkit-scrollbar-track {
908
+ background: transparent;
909
+ }
910
+
911
+ ::-webkit-scrollbar-thumb {
912
+ background: rgba(48, 54, 61, 0.6);
913
+ border-radius: 3px;
914
+ }
915
+
916
+ ::-webkit-scrollbar-thumb:hover {
917
+ background: rgba(139, 148, 158, 0.4);
918
+ }
train_bandit.py ADDED
@@ -0,0 +1,381 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """OptimNeuralTS training -- Neural Bandit search for dangerous polypharmacies.
3
+
4
+ Implements the training pipeline from:
5
+ Larouche et al., "Neural Bandits for Data Mining: Searching for Dangerous Polypharmacy"
6
+ https://link.springer.com/chapter/10.1007/978-3-031-36938-4_5
7
+
8
+ This script:
9
+ 1. Generates a synthetic dataset of drug combinations with simulated Relative Risk (RR)
10
+ 2. Runs OptimNeuralTS: warm-up -> NeuralTS+DE exploration -> ensemble building
11
+ 3. Evaluates the ensemble's ability to detect Potentially Inappropriate Polypharmacies (PIPs)
12
+ 4. Saves the trained ensemble model
13
+
14
+ Usage:
15
+ python train_bandit.py --total-steps 1000 --warmup-steps 200
16
+ python train_bandit.py --total-steps 3000 --warmup-steps 500 --eval-every 100
17
+ """
18
+
19
+ from __future__ import annotations
20
+
21
+ import argparse
22
+ import json
23
+ import math
24
+ import os
25
+ import random
26
+ import sys
27
+ import time
28
+ from itertools import combinations
29
+ from pathlib import Path
30
+ from typing import Any, Dict, List, Tuple
31
+
32
+ import torch
33
+
34
+ _BACKEND_SRC = os.path.join(
35
+ os.path.dirname(os.path.abspath(__file__)), "backend", "src"
36
+ )
37
+ sys.path.insert(0, _BACKEND_SRC)
38
+
39
+ from polypharmacy_env.neural_bandits import NeuralTS, OptimNeuralTS, nearest_neighbor_hamming # noqa: E402
40
+ from polypharmacy_env.data_loader import load_drug_metadata, load_ddi_rules # noqa: E402
41
+
42
+
43
+ # ---------------------------------------------------------------------------
44
+ # Synthetic RR data generation (follows paper Section 4.1)
45
+ # ---------------------------------------------------------------------------
46
+
47
def generate_synthetic_dataset(
    n_drugs: int = 33,
    n_combinations: int = 5000,
    n_dangerous_patterns: int = 10,
    rr_threshold: float = 1.1,
    noise_std: float = 0.1,
    seed: int = 42,
) -> Dict[str, Any]:
    """Generate synthetic drug-combination data with ground-truth RRs.

    Follows the paper's data-generation process (Section 4.1):
    - Draw ``n_dangerous_patterns`` sparse multi-hot "dangerous" patterns.
    - Draw ``n_combinations`` distinct multi-hot drug combinations.
    - Assign each combination an RR proportional to its overlap with the
      nearest pattern (Hamming distance), or a neutral RR ~ N(0.85, 0.2)
      when it shares no drug with that pattern.

    Args:
        n_drugs: Size of the drug vocabulary (dimension of each vector).
        n_combinations: Number of distinct combinations to generate.
        n_dangerous_patterns: Number of ground-truth dangerous patterns.
        rr_threshold: RR above which a combination counts as a PIP.
        noise_std: Observation-noise std returned for the caller to use
            when simulating noisy rewards (not applied during generation).
        seed: Seed for the local RNG and for torch.

    Returns:
        Dict with keys ``combos``, ``rrs``, ``patterns``, ``pattern_rrs``,
        ``n_drugs``, ``n_pips``, ``rr_threshold`` and ``noise_std``.

    Raises:
        ValueError: If ``n_combinations`` exceeds the number of distinct
            combinations of 2..min(8, n_drugs) drugs. (The original code
            would spin forever in its rejection-sampling loop.)
    """
    rng = random.Random(seed)
    torch.manual_seed(seed)

    # Dangerous patterns: sparse multi-hot vectors with between 2 and
    # max(3, n_drugs // 8) active drugs each.
    patterns = []
    for _ in range(n_dangerous_patterns):
        p = torch.zeros(n_drugs)
        n_active = rng.randint(2, max(3, n_drugs // 8))
        for idx in rng.sample(range(n_drugs), n_active):
            p[idx] = 1.0
        patterns.append(p)

    # Guard against an unsatisfiable request: the rejection-sampling loop
    # below can only terminate if enough distinct combinations exist.
    max_size = min(8, n_drugs)
    capacity = sum(math.comb(n_drugs, k) for k in range(2, max_size + 1))
    if n_combinations > capacity:
        raise ValueError(
            f"n_combinations={n_combinations} exceeds the {capacity} distinct "
            f"combinations of 2..{max_size} drugs from a vocabulary of {n_drugs}"
        )

    # Distinct drug combinations with 2..min(8, n_drugs) active drugs.
    combos = []
    combo_set = set()
    while len(combos) < n_combinations:
        n_active = rng.randint(2, max_size)
        indices = tuple(sorted(rng.sample(range(n_drugs), n_active)))
        if indices not in combo_set:
            combo_set.add(indices)
            vec = torch.zeros(n_drugs)
            for idx in indices:
                vec[idx] = 1.0
            combos.append(vec)

    # RR for each combination, driven by its nearest dangerous pattern.
    rrs = []
    for combo in combos:
        # Nearest pattern under Hamming distance.
        min_dist = float("inf")
        best_p_idx = 0
        for p_idx, pattern in enumerate(patterns):
            dist = (combo != pattern).float().sum().item()
            if dist < min_dist:
                min_dist = dist
                best_p_idx = p_idx

        pattern = patterns[best_p_idx]

        # Overlap = number of active drugs shared with the nearest pattern.
        intersection = (combo * pattern).sum().item()
        if intersection > 0:
            # RR grows with similarity; base value spans roughly [0.5, 3.0].
            similarity = intersection / max(pattern.sum().item(), 1)
            base_rr = 0.5 + 2.5 * similarity
            rr = max(0.1, base_rr + rng.gauss(0, 0.15))
        else:
            # Disjoint from every pattern: neutral RR around 0.85.
            rr = max(0.1, rng.gauss(0.85, 0.2))

        rrs.append(rr)

    # The patterns themselves carry high RRs (around 2.0).
    pattern_rrs = [2.0 + rng.gauss(0, 0.3) for _ in patterns]

    n_pips = sum(1 for rr in rrs if rr > rr_threshold)
    print(f"  Generated {n_combinations} combos, {n_pips} PIPs (RR > {rr_threshold})")
    print(f"  RR range: [{min(rrs):.3f}, {max(rrs):.3f}], mean: {sum(rrs)/len(rrs):.3f}")

    return {
        "combos": combos,
        "rrs": rrs,
        "patterns": patterns,
        "pattern_rrs": pattern_rrs,
        "n_drugs": n_drugs,
        "n_pips": n_pips,
        "rr_threshold": rr_threshold,
        "noise_std": noise_std,
    }
+ }
137
+
138
+
139
+ # ---------------------------------------------------------------------------
140
+ # Training loop
141
+ # ---------------------------------------------------------------------------
142
+
143
def _evaluate_ensemble(
    bandit: OptimNeuralTS,
    combos: List[torch.Tensor],
    rrs: List[float],
    rr_threshold: float,
    training_indices: set,
) -> Tuple[float, float, int, int]:
    """Score the current ensemble against the ground-truth RRs.

    Makes a single ``predict_risk`` pass over all combinations (the original
    code ran two identical passes) and accumulates the confusion counts plus
    the number of detected PIPs that never appeared in the training data.

    Returns:
        ``(precision, recall, total_detected_pips, pips_outside_training)``.
    """
    tp = fp = tn = fn = 0
    total_detected = 0
    outside_train = 0
    for i, combo in enumerate(combos):
        predicted_pip = bandit.predict_risk(combo)["is_potentially_harmful"]
        actual_pip = rrs[i] > rr_threshold
        if predicted_pip:
            total_detected += 1
            if i not in training_indices:
                outside_train += 1
        if predicted_pip and actual_pip:
            tp += 1
        elif predicted_pip:
            fp += 1
        elif actual_pip:
            fn += 1
        else:
            tn += 1
    # max(..., 1) avoids division by zero when a class is empty.
    precision = tp / max(tp + fp, 1)
    recall = tp / max(tp + fn, 1)
    return precision, recall, total_detected, outside_train


def train_bandit(args: argparse.Namespace) -> None:
    """Run the OptimNeuralTS training pipeline end to end.

    Generates a synthetic dataset, executes the bandit loop
    (warm-up -> NeuralTS+DE exploration -> ensemble building), evaluates
    the ensemble periodically, and writes metrics plus a model checkpoint.

    Args:
        args: Parsed CLI namespace (see ``parse_args``).
    """
    print("=" * 72)
    print("OptimNeuralTS Training -- Neural Bandits for Polypharmacy")
    print("=" * 72)

    # ---- Synthetic data ----------------------------------------------------
    print("\nGenerating synthetic dataset...")
    dataset = generate_synthetic_dataset(
        n_drugs=args.n_drugs,
        n_combinations=args.n_combinations,
        n_dangerous_patterns=args.n_patterns,
        seed=args.seed,
    )
    combos = dataset["combos"]
    rrs = dataset["rrs"]
    patterns = dataset["patterns"]
    noise_std = dataset["noise_std"]
    rr_threshold = dataset["rr_threshold"]

    # Dedicated RNG for observation noise. The original code called the
    # unseeded module-global ``random.gauss``, which made runs irreproducible
    # even with a fixed --seed. Offset the seed so the noise stream differs
    # from the dataset-generation stream.
    noise_rng = random.Random(args.seed + 1)

    # ---- Bandit ------------------------------------------------------------
    bandit = OptimNeuralTS(
        input_dim=args.n_drugs,
        hidden=args.hidden_dim,
        reg_lambda=args.reg_lambda,
        exploration_factor=args.exploration_factor,
        lr=args.lr,
        train_epochs=args.train_epochs,
        warmup_steps=args.warmup_steps,
        total_steps=args.total_steps,
        retrain_every=args.retrain_every,
        de_population=args.de_population,
        de_crossover=args.de_crossover,
        de_weight=args.de_weight,
        de_steps=args.de_steps,
    )

    print(f"\n  n_drugs            : {args.n_drugs}")
    print(f"  n_combinations     : {args.n_combinations}")
    print(f"  total_steps (T)    : {args.total_steps}")
    print(f"  warmup_steps (τ)   : {args.warmup_steps}")
    print(f"  DE population (N)  : {args.de_population}")
    print(f"  DE steps (S)       : {args.de_steps}")
    print(f"  retrain_every      : {args.retrain_every}")
    print(f"  hidden_dim         : {args.hidden_dim}")
    print(f"  lr                 : {args.lr}")
    print("=" * 72)

    t_start = time.time()

    # ---- Metrics tracking --------------------------------------------------
    step_rewards: List[float] = []
    pips_found: List[int] = []
    eval_precisions: List[float] = []
    eval_recalls: List[float] = []
    training_dataset_indices: set = set()

    # ---- Bandit loop -------------------------------------------------------
    for t in range(1, args.total_steps + 1):
        # Select an arm (drug combination) and remember that it entered the
        # training data.
        idx, info = bandit.select_action(combos)
        training_dataset_indices.add(idx)

        # Observe a noisy reward: true RR plus Gaussian observation noise.
        reward = rrs[idx] + noise_rng.gauss(0, noise_std)
        step_rewards.append(reward)

        # Update the bandit with the observed (arm, reward) pair.
        bandit.observe(combos[idx], reward)

        # Periodic evaluation of the current ensemble.
        if t % args.eval_every == 0 or t == args.total_steps:
            precision, recall, total_detected_pips, pips_outside_train = (
                _evaluate_ensemble(
                    bandit, combos, rrs, rr_threshold, training_dataset_indices
                )
            )
            eval_precisions.append(precision)
            eval_recalls.append(recall)
            pips_found.append(total_detected_pips)

            # Fraction of ground-truth dangerous patterns flagged as harmful.
            patterns_found = sum(
                1
                for pattern in patterns
                if bandit.predict_risk(pattern)["is_potentially_harmful"]
            )
            pattern_ratio = patterns_found / len(patterns)

            elapsed = time.time() - t_start
            phase = info.get("phase", "?")
            n_ens = len(bandit.agent.ensemble_weights)
            print(
                f"[step {t:>5d}/{args.total_steps}] "
                f"phase={phase} "
                f"precision={precision:.3f} "
                f"recall={recall:.3f} "
                f"patterns={pattern_ratio:.2f} "
                f"PIPs_detected={total_detected_pips} "
                f"outside_train={pips_outside_train} "
                f"ensemble={n_ens} "
                f"elapsed={elapsed:.1f}s"
            )

    # ---- Save metrics ------------------------------------------------------
    metrics = {
        "algorithm": "OptimNeuralTS",
        "n_drugs": args.n_drugs,
        "n_combinations": args.n_combinations,
        "total_steps": args.total_steps,
        "warmup_steps": args.warmup_steps,
        "n_ensemble_models": len(bandit.agent.ensemble_weights),
        "final_precision": eval_precisions[-1] if eval_precisions else 0,
        "final_recall": eval_recalls[-1] if eval_recalls else 0,
        "eval_precisions": eval_precisions,
        "eval_recalls": eval_recalls,
        "pips_detected": pips_found,
        "step_rewards": step_rewards,
        "total_time_s": time.time() - t_start,
        "hyperparameters": {
            "hidden_dim": args.hidden_dim,
            "lr": args.lr,
            "reg_lambda": args.reg_lambda,
            "exploration_factor": args.exploration_factor,
            "de_population": args.de_population,
            "de_crossover": args.de_crossover,
            "de_weight": args.de_weight,
            "de_steps": args.de_steps,
            "train_epochs": args.train_epochs,
            "retrain_every": args.retrain_every,
        },
    }

    metrics_path = Path(args.metrics_file)
    metrics_path.parent.mkdir(parents=True, exist_ok=True)
    with open(metrics_path, "w") as f:
        json.dump(metrics, f, indent=2)
    print(f"\nMetrics saved to {metrics_path}")

    # ---- Save model ensemble -----------------------------------------------
    ckpt_dir = Path(args.checkpoint_dir)
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    ckpt_path = ckpt_dir / "bandit_ensemble.pt"
    torch.save(
        {
            "ensemble_weights": bandit.agent.ensemble_weights,
            "network_state_dict": bandit.agent.network.state_dict(),
            "U_diag": bandit.agent.U_diag,
            "input_dim": args.n_drugs,
            "hidden_dim": args.hidden_dim,
            "n_steps": args.total_steps,
        },
        ckpt_path,
    )
    print(f"Ensemble model saved to {ckpt_path}")

    print(f"\n{'='*72}")
    print("Training complete!")
    print(f"  Ensemble size:   {len(bandit.agent.ensemble_weights)} models")
    if eval_precisions:
        print(f"  Final precision: {eval_precisions[-1]:.4f}")
        print(f"  Final recall:    {eval_recalls[-1]:.4f}")
    print(f"  Total time:      {time.time() - t_start:.1f}s")
    print(f"{'='*72}")
+ print(f"{'='*72}")
334
+
335
+
336
+ # ---------------------------------------------------------------------------
337
+ # CLI
338
+ # ---------------------------------------------------------------------------
339
+
340
def parse_args(argv: List[str] | None = None) -> argparse.Namespace:
    """Build the CLI parser and parse arguments.

    Args:
        argv: Optional explicit argument list. ``None`` (the default) keeps
            the original behaviour of reading ``sys.argv[1:]``; passing a
            list makes the function callable from tests and notebooks.

    Returns:
        Parsed ``argparse.Namespace`` with all training hyperparameters.
    """
    p = argparse.ArgumentParser(
        description="OptimNeuralTS training for polypharmacy PIP detection",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # Dataset
    p.add_argument("--n-drugs", type=int, default=33, help="Number of possible drugs")
    p.add_argument("--n-combinations", type=int, default=5000, help="Number of distinct drug combinations")
    p.add_argument("--n-patterns", type=int, default=10, help="Number of dangerous patterns")
    p.add_argument("--seed", type=int, default=42, help="Random seed")

    # OptimNeuralTS
    p.add_argument("--total-steps", type=int, default=1000, help="Total bandit steps T")
    p.add_argument("--warmup-steps", type=int, default=200, help="Warmup steps τ")
    p.add_argument("--retrain-every", type=int, default=10, help="Retrain network every N steps")
    p.add_argument("--hidden-dim", type=int, default=64, help="Network hidden layer size")
    p.add_argument("--lr", type=float, default=0.01, help="Learning rate")
    p.add_argument("--reg-lambda", type=float, default=1.0, help="Regularization λ")
    p.add_argument("--exploration-factor", type=float, default=1.0, help="Exploration ν")
    p.add_argument("--train-epochs", type=int, default=50, help="Epochs per retrain")

    # DE (differential evolution)
    p.add_argument("--de-population", type=int, default=16, help="DE population size N")
    p.add_argument("--de-crossover", type=float, default=0.9, help="DE crossover rate C")
    p.add_argument("--de-weight", type=float, default=1.0, help="DE differential weight F")
    p.add_argument("--de-steps", type=int, default=8, help="DE optimization steps S")

    # Output
    p.add_argument("--eval-every", type=int, default=100, help="Evaluate every N steps")
    p.add_argument("--metrics-file", type=str, default="bandit_metrics.json", help="Metrics output path")
    p.add_argument(
        "--checkpoint-dir", type=str,
        default=os.path.join(_BACKEND_SRC, "polypharmacy_env", "checkpoints"),
        help="Model checkpoint directory",
    )

    return p.parse_args(argv)
+
378
+
379
if __name__ == "__main__":
    # CLI entry point: parse arguments, then run the full training pipeline.
    train_bandit(parse_args())
train_rl.py ADDED
@@ -0,0 +1,674 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """REINFORCE with Learned Baseline -- RL training for PolypharmacyEnv.
3
+
4
+ Trains a small neural-network policy to perform medication reviews in the
5
+ PolypharmacyEnv environment. The policy learns to query drug-drug interactions,
6
+ propose clinical interventions, and decide when to finalise the review.
7
+
8
+ Usage examples:
9
+ python train_rl.py --task easy_screening --episodes 200
10
+ python train_rl.py --task budgeted_screening --episodes 500
11
+ python train_rl.py --task complex_tradeoff --episodes 1000
12
+ python train_rl.py --task easy_screening --episodes 300 --lr 5e-4 --batch-size 8
13
+
14
+ Architecture:
15
+ - Fixed-size state encoding (16-dim global summary features)
16
+ - Fixed 166-dim action space with dynamic validity masking
17
+ - 3-layer MLP policy (state -> logits over actions)
18
+ - 3-layer MLP value baseline (state -> scalar return estimate)
19
+ - REINFORCE gradient with advantage = (discounted return) - baseline
20
+ - Entropy bonus for sustained exploration
21
+ """
22
+
23
+ from __future__ import annotations
24
+
25
+ import argparse
26
+ import json
27
+ import os
28
+ import sys
29
+ import time
30
+ from itertools import combinations
31
+ from pathlib import Path
32
+ from typing import Any, Dict, List, Optional, Set, Tuple
33
+
34
+ import torch
35
+ import torch.nn as nn
36
+ import torch.nn.functional as F
37
+ from torch.distributions import Categorical
38
+
39
+ # ---------------------------------------------------------------------------
40
+ # Environment imports (direct, no HTTP)
41
+ # ---------------------------------------------------------------------------
42
+ _BACKEND_SRC = os.path.join(
43
+ os.path.dirname(os.path.abspath(__file__)), "backend", "src"
44
+ )
45
+ sys.path.insert(0, _BACKEND_SRC)
46
+
47
+ from polypharmacy_env.env_core import PolypharmacyEnv # noqa: E402
48
+ from polypharmacy_env.models import ( # noqa: E402
49
+ PolypharmacyAction,
50
+ PolypharmacyObservation,
51
+ )
52
+ from polypharmacy_env.config import TASK_CONFIGS, TaskConfig # noqa: E402
53
+
54
+ # ---------------------------------------------------------------------------
55
+ # Constants -- action-space geometry
56
+ # ---------------------------------------------------------------------------
57
MAX_MEDS = 15  # upper bound on regimen size across all task difficulties
INTERVENTION_TYPES: List[str] = [
    "stop",
    "dose_reduce",
    "substitute",
    "add_monitoring",
]
N_INTERVENTION_TYPES = len(INTERVENTION_TYPES)

# Flat index for every unordered medication-position pair (i < j), used to
# map query_ddi actions into a fixed action space.  Built with
# itertools.combinations (already imported), which yields pairs in the same
# lexicographic order as the original nested loops.
_PAIR_INDEX: Dict[Tuple[int, int], int] = {
    pair: flat for flat, pair in enumerate(combinations(range(MAX_MEDS), 2))
}
N_PAIRS = len(_PAIR_INDEX)  # C(15,2) = 105
_REVERSE_PAIR: Dict[int, Tuple[int, int]] = {v: k for k, v in _PAIR_INDEX.items()}

# Flat action-space layout:
#   [0, N_PAIRS)                        -> query_ddi for pair _REVERSE_PAIR[idx]
#   [N_PAIRS, N_PAIRS + N_INTERVENTIONS) -> propose_intervention (med, type)
#   FINISH_IDX                           -> finish_review
N_INTERVENTIONS = MAX_MEDS * N_INTERVENTION_TYPES  # 60
FINISH_IDX = N_PAIRS + N_INTERVENTIONS  # 165
N_ACTIONS = FINISH_IDX + 1  # 166

# State feature vector length (see encode_state)
STATE_DIM = 16
+ STATE_DIM = 16
83
+
84
+ # ---------------------------------------------------------------------------
85
+ # State encoding
86
+ # ---------------------------------------------------------------------------
87
+
88
def encode_state(obs: PolypharmacyObservation, task_cfg: TaskConfig) -> torch.Tensor:
    """Encode the observation into a compact 16-dim feature vector.

    Every entry is normalised to roughly [0, 1] to help gradient flow.
    The feature order is fixed and must stay in sync with STATE_DIM.
    """
    meds = obs.current_medications
    med_count = len(meds)
    med_denom = max(med_count, 1)

    high_risk_count = sum(1 for m in meds if m.is_high_risk_elderly)
    beers_any_count = sum(1 for m in meds if m.beers_flags)
    beers_avoid_count = sum(
        1 for m in meds if any("avoid" in flag for flag in m.beers_flags)
    )

    queries = obs.interaction_queries
    query_count = len(queries)
    query_denom = max(query_count, 1)
    severe_count = sum(1 for q in queries if q.severity == "severe")
    moderate_count = sum(1 for q in queries if q.severity == "moderate")
    intervention_count = len(obs.interventions)

    pair_denom = max(med_count * (med_count - 1) // 2, 1)

    # Current medications implicated in any discovered severe DDI.
    active_ids = {m.drug_id for m in meds}
    hot_drugs: Set[str] = set()
    for q in queries:
        if q.severity != "severe":
            continue
        for did in (q.drug_id_1, q.drug_id_2):
            if did in active_ids:
                hot_drugs.add(did)

    query_budget = max(task_cfg.query_budget, 1)
    intervention_budget = max(task_cfg.intervention_budget, 1)

    features = [
        med_count / MAX_MEDS,
        high_risk_count / med_denom,
        beers_any_count / med_denom,
        beers_avoid_count / med_denom,
        obs.remaining_query_budget / query_budget,
        obs.remaining_intervention_budget / intervention_budget,
        query_count / query_budget,
        severe_count / query_denom,
        moderate_count / query_denom,
        intervention_count / intervention_budget,
        obs.step_index / max(task_cfg.max_steps, 1),
        query_count / pair_denom,  # fraction of possible pairs queried
        float(obs.remaining_query_budget > 0),
        float(obs.remaining_intervention_budget > 0),
        len(hot_drugs) / med_denom,  # how much of the regimen is "hot"
        float(med_count <= 2),  # very few meds left -- may be time to finish
    ]
    return torch.tensor(features, dtype=torch.float32)
139
+
140
+
141
+ # ---------------------------------------------------------------------------
142
+ # Action-space helpers
143
+ # ---------------------------------------------------------------------------
144
+
145
def get_action_mask(obs: PolypharmacyObservation) -> torch.Tensor:
    """Return a bool tensor of shape (N_ACTIONS,); True marks a valid action."""
    mask = torch.zeros(N_ACTIONS, dtype=torch.bool)
    meds = obs.current_medications
    med_count = min(len(meds), MAX_MEDS)

    # Drug-id pairs already queried, stored order-invariantly.
    seen_pairs = {
        frozenset((q.drug_id_1, q.drug_id_2)) for q in obs.interaction_queries
    }

    # query_ddi: any not-yet-queried pair, while query budget remains.
    # (range(med_count - 1) is empty for fewer than two medications.)
    if obs.remaining_query_budget > 0:
        for i in range(med_count - 1):
            for j in range(i + 1, med_count):
                if frozenset((meds[i].drug_id, meds[j].drug_id)) not in seen_pairs:
                    mask[_PAIR_INDEX[(i, j)]] = True

    # propose_intervention: every (medication, type) slot, while budget remains.
    if obs.remaining_intervention_budget > 0:
        for i in range(med_count):
            start = N_PAIRS + i * N_INTERVENTION_TYPES
            mask[start:start + N_INTERVENTION_TYPES] = True

    # finish_review is always available.
    mask[FINISH_IDX] = True
    return mask
173
+
174
+
175
def action_idx_to_env_action(
    idx: int,
    meds: list,
) -> PolypharmacyAction:
    """Map a flat action index back to a concrete PolypharmacyAction."""
    # Terminal action occupies the final slot.
    if idx == FINISH_IDX:
        return PolypharmacyAction(action_type="finish_review")

    # DDI queries occupy the first N_PAIRS slots.
    if idx < N_PAIRS:
        first, second = _REVERSE_PAIR[idx]
        return PolypharmacyAction(
            action_type="query_ddi",
            drug_id_1=meds[first].drug_id,
            drug_id_2=meds[second].drug_id,
        )

    # Remaining slots encode (medication position, intervention type).
    med_pos, kind = divmod(idx - N_PAIRS, N_INTERVENTION_TYPES)
    return PolypharmacyAction(
        action_type="propose_intervention",
        target_drug_id=meds[med_pos].drug_id,
        intervention_type=INTERVENTION_TYPES[kind],
        rationale="rl_policy",
    )
201
+
202
+
203
+ # ---------------------------------------------------------------------------
204
+ # Neural-network modules
205
+ # ---------------------------------------------------------------------------
206
+
207
class PolicyNetwork(nn.Module):
    """3-layer MLP mapping state features to a masked action distribution.

    Invalid actions (mask == False) receive -inf logits, so the resulting
    Categorical assigns them zero probability.
    """

    def __init__(
        self,
        state_dim: int = STATE_DIM,
        action_dim: int = N_ACTIONS,
        hidden: int = 128,
    ) -> None:
        super().__init__()
        # Attribute names fc1/fc2/fc3 are part of the state_dict layout;
        # keep them stable for checkpoint compatibility.
        self.fc1 = nn.Linear(state_dim, hidden)
        self.fc2 = nn.Linear(hidden, hidden)
        self.fc3 = nn.Linear(hidden, action_dim)

    def forward(
        self,
        state: torch.Tensor,
        mask: torch.Tensor,
    ) -> Categorical:
        hidden_one = F.relu(self.fc1(state))
        hidden_two = F.relu(self.fc2(hidden_one))
        raw_logits = self.fc3(hidden_two)
        # Suppress invalid actions before building the distribution.
        masked_logits = raw_logits.masked_fill(~mask, float("-inf"))
        return Categorical(logits=masked_logits)
231
+
232
+
233
class ValueNetwork(nn.Module):
    """3-layer MLP baseline estimating the expected return of a state."""

    def __init__(self, state_dim: int = STATE_DIM, hidden: int = 128) -> None:
        super().__init__()
        # Funnel shape: hidden -> hidden/2 -> scalar value estimate.
        self.fc1 = nn.Linear(state_dim, hidden)
        self.fc2 = nn.Linear(hidden, hidden // 2)
        self.fc3 = nn.Linear(hidden // 2, 1)

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        out = F.relu(self.fc1(state))
        out = F.relu(self.fc2(out))
        # squeeze(-1): return shape (...,) rather than (..., 1).
        return self.fc3(out).squeeze(-1)
246
+
247
+
248
+ # ---------------------------------------------------------------------------
249
+ # Episode rollout
250
+ # ---------------------------------------------------------------------------
251
+
252
def run_episode(
    env: PolypharmacyEnv,
    task_id: str,
    policy: PolicyNetwork,
    value_net: ValueNetwork,
    task_cfg: TaskConfig,
    seed: Optional[int] = None,
    greedy: bool = False,
) -> Dict[str, Any]:
    """Roll out one full episode, collecting the REINFORCE trajectory.

    When *greedy* is True the policy acts deterministically (argmax) and
    gradients are not recorded. Used for evaluation.

    Gradient semantics (important for the update step):
      - greedy=False (training): the policy forward pass runs OUTSIDE
        torch.no_grad, so the recorded log_probs/entropies carry the
        autograd graph needed for the later policy-gradient backward pass.
        The baseline value is computed under no_grad (used as a detached
        baseline only; the value net is retrained on stored states).
      - greedy=True (evaluation): everything runs under no_grad.

    Returns a dict with per-step lists (states, actions, log_probs,
    rewards, values, entropies) plus episode summaries (grader_score,
    total_reward, n_steps).
    """
    obs = env.reset(task_id=task_id, seed=seed)

    # Per-timestep trajectory buffers.
    states: List[torch.Tensor] = []
    actions: List[torch.Tensor] = []
    log_probs: List[torch.Tensor] = []
    rewards: List[float] = []
    values: List[torch.Tensor] = []
    entropies: List[torch.Tensor] = []

    grader_score = 0.0

    while not obs.done:
        state = encode_state(obs, task_cfg)
        mask = get_action_mask(obs)

        # Safety: if somehow no action is valid, force finish
        if not mask.any():
            mask[FINISH_IDX] = True

        if greedy:
            # Deterministic evaluation: argmax action, no gradient tracking.
            with torch.no_grad():
                dist = policy(state, mask)
                action_idx = dist.probs.argmax()
                value = value_net(state)
        else:
            # Baseline value is detached via no_grad; the policy forward
            # pass keeps its graph so log_prob/entropy are differentiable.
            with torch.no_grad():
                value = value_net(state)
            dist = policy(state, mask)
            action_idx = dist.sample()

        log_prob = dist.log_prob(action_idx)
        entropy = dist.entropy()

        states.append(state)
        actions.append(action_idx)
        log_probs.append(log_prob)
        values.append(value)
        entropies.append(entropy)

        # Translate the flat index into an environment action against the
        # CURRENT medication list (positions shift as meds are removed).
        env_action = action_idx_to_env_action(
            action_idx.item(), obs.current_medications
        )
        obs = env.step(env_action)

        # Missing rewards are treated as zero.
        reward = float(obs.reward) if obs.reward is not None else 0.0
        rewards.append(reward)

        if obs.done:
            # Final grader score reported by the env at episode end.
            # NOTE(review): assumes obs.metadata is dict-like -- it is only
            # accessed via .get here.
            grader_score = obs.metadata.get("grader_score", 0.0)

    return {
        "states": states,
        "actions": actions,
        "log_probs": log_probs,
        "rewards": rewards,
        "values": values,
        "entropies": entropies,
        "grader_score": grader_score,
        "total_reward": sum(rewards),
        "n_steps": len(rewards),
    }
327
+
328
+
329
+ # ---------------------------------------------------------------------------
330
+ # Return computation
331
+ # ---------------------------------------------------------------------------
332
+
333
def compute_returns(rewards: List[float], gamma: float = 0.99) -> torch.Tensor:
    """Discounted cumulative returns (G_t) for each timestep.

    G_t = r_t + gamma * G_{t+1}, computed right-to-left.

    Args:
        rewards: Per-step rewards for one trajectory (may be empty).
        gamma: Discount factor.

    Returns:
        Float32 tensor of shape (len(rewards),) with G_t at position t.
    """
    # Accumulate right-to-left with append + reverse (O(n)); the original
    # used list.insert(0, ...) which is O(n) per step, O(n^2) overall.
    returns: List[float] = []
    g = 0.0
    for r in reversed(rewards):
        g = r + gamma * g
        returns.append(g)
    returns.reverse()
    return torch.tensor(returns, dtype=torch.float32)
341
+
342
+
343
+ # ---------------------------------------------------------------------------
344
+ # Training
345
+ # ---------------------------------------------------------------------------
346
+
347
+ def train(args: argparse.Namespace) -> None: # noqa: C901 (complex but linear)
348
+ task_id: str = args.task
349
+ n_episodes: int = args.episodes
350
+ lr: float = args.lr
351
+ gamma: float = args.gamma
352
+ entropy_coeff: float = args.entropy_coeff
353
+ batch_size: int = args.batch_size
354
+ hidden_dim: int = args.hidden_dim
355
+ print_every: int = args.print_every
356
+
357
+ task_cfg = TASK_CONFIGS[task_id]
358
+
359
+ # ---- Initialise env & networks ----------------------------------------
360
+ env = PolypharmacyEnv()
361
+ policy = PolicyNetwork(STATE_DIM, N_ACTIONS, hidden=hidden_dim)
362
+ value_net = ValueNetwork(STATE_DIM, hidden=hidden_dim)
363
+
364
+ policy_optim = torch.optim.Adam(policy.parameters(), lr=lr)
365
+ value_optim = torch.optim.Adam(value_net.parameters(), lr=lr * 3)
366
+
367
+ # ---- Book-keeping -----------------------------------------------------
368
+ ckpt_dir = Path(args.checkpoint_dir)
369
+ ckpt_dir.mkdir(parents=True, exist_ok=True)
370
+ metrics_path = Path(args.metrics_file)
371
+
372
+ episode_rewards: List[float] = []
373
+ episode_grader_scores: List[float] = []
374
+ episode_steps: List[int] = []
375
+ episode_policy_losses: List[float] = []
376
+ episode_value_losses: List[float] = []
377
+
378
+ best_avg_score: float = -float("inf")
379
+
380
+ print("=" * 72)
381
+ print("REINFORCE Training -- PolypharmacyEnv")
382
+ print("=" * 72)
383
+ print(f" task : {task_id}")
384
+ print(f" episodes : {n_episodes}")
385
+ print(f" batch_size : {batch_size}")
386
+ print(f" lr : {lr}")
387
+ print(f" gamma : {gamma}")
388
+ print(f" entropy_coeff : {entropy_coeff}")
389
+ print(f" hidden_dim : {hidden_dim}")
390
+ print(f" state_dim : {STATE_DIM}")
391
+ print(f" action_space : {N_ACTIONS}")
392
+ print(f" task budgets : query={task_cfg.query_budget} "
393
+ f"intervention={task_cfg.intervention_budget} "
394
+ f"max_steps={task_cfg.max_steps}")
395
+ print(f" checkpoint_dir : {ckpt_dir}")
396
+ print(f" metrics_file : {metrics_path}")
397
+ print("=" * 72)
398
+ print()
399
+
400
+ t_start = time.time()
401
+
402
+ # ---- Main training loop -----------------------------------------------
403
+ # Accumulate a mini-batch of trajectories, then perform one gradient step.
404
+ batch_trajs: List[Dict[str, Any]] = []
405
+
406
+ for ep in range(1, n_episodes + 1):
407
+ traj = run_episode(env, task_id, policy, value_net, task_cfg, seed=ep)
408
+
409
+ episode_rewards.append(traj["total_reward"])
410
+ episode_grader_scores.append(traj["grader_score"])
411
+ episode_steps.append(traj["n_steps"])
412
+
413
+ if traj["n_steps"] == 0:
414
+ # Degenerate episode (should not happen); skip update
415
+ continue
416
+
417
+ batch_trajs.append(traj)
418
+
419
+ # ---- Gradient step every batch_size episodes ----------------------
420
+ if len(batch_trajs) >= batch_size:
421
+ # Aggregate losses across the batch
422
+ total_policy_loss = torch.tensor(0.0)
423
+ total_value_loss = torch.tensor(0.0)
424
+ total_entropy = torch.tensor(0.0)
425
+ total_steps = 0
426
+
427
+ for bt in batch_trajs:
428
+ returns = compute_returns(bt["rewards"], gamma)
429
+ old_values_t = torch.stack(bt["values"]) # detached, from rollout
430
+ log_probs_t = torch.stack(bt["log_probs"])
431
+ entropies_t = torch.stack(bt["entropies"])
432
+
433
+ # Advantages use the *detached* rollout values as baseline
434
+ advantages = returns - old_values_t.detach()
435
+ # Per-trajectory advantage normalisation (reduces variance)
436
+ if len(advantages) > 1:
437
+ advantages = (advantages - advantages.mean()) / (
438
+ advantages.std() + 1e-8
439
+ )
440
+
441
+ # REINFORCE policy gradient (negative because we minimise)
442
+ total_policy_loss = total_policy_loss + (
443
+ -(log_probs_t * advantages).sum()
444
+ )
445
+
446
+ # Recompute value predictions WITH gradients for the value loss
447
+ states_t = torch.stack(bt["states"])
448
+ fresh_values = value_net(states_t)
449
+ total_value_loss = total_value_loss + F.mse_loss(
450
+ fresh_values, returns, reduction="sum"
451
+ )
452
+
453
+ # Entropy (we want to maximise -> subtract from loss)
454
+ total_entropy = total_entropy + entropies_t.sum()
455
+ total_steps += len(bt["rewards"])
456
+
457
+ # Normalise by total number of timesteps in the batch
458
+ total_policy_loss = total_policy_loss / total_steps
459
+ total_value_loss = total_value_loss / total_steps
460
+ total_entropy = total_entropy / total_steps
461
+
462
+ # Combined policy loss with entropy bonus
463
+ combined_policy_loss = total_policy_loss - entropy_coeff * total_entropy
464
+
465
+ policy_optim.zero_grad()
466
+ combined_policy_loss.backward()
467
+ nn.utils.clip_grad_norm_(policy.parameters(), max_norm=1.0)
468
+ policy_optim.step()
469
+
470
+ value_optim.zero_grad()
471
+ total_value_loss.backward()
472
+ nn.utils.clip_grad_norm_(value_net.parameters(), max_norm=1.0)
473
+ value_optim.step()
474
+
475
+ episode_policy_losses.append(total_policy_loss.item())
476
+ episode_value_losses.append(total_value_loss.item())
477
+
478
+ batch_trajs = []
479
+
480
+ # ---- Logging ------------------------------------------------------
481
+ if ep % print_every == 0 or ep == 1:
482
+ window = min(print_every, ep)
483
+ recent_r = episode_rewards[-window:]
484
+ recent_s = episode_grader_scores[-window:]
485
+ recent_st = episode_steps[-window:]
486
+ avg_r = sum(recent_r) / len(recent_r)
487
+ avg_s = sum(recent_s) / len(recent_s)
488
+ avg_st = sum(recent_st) / len(recent_st)
489
+ elapsed = time.time() - t_start
490
+ print(
491
+ f"[ep {ep:>4d}/{n_episodes}] "
492
+ f"avg_reward={avg_r:+.4f} "
493
+ f"avg_grader={avg_s:.4f} "
494
+ f"avg_steps={avg_st:.1f} "
495
+ f"elapsed={elapsed:.1f}s"
496
+ )
497
+
498
+ # Save best checkpoint based on rolling grader score
499
+ eval_window = min(30, ep)
500
+ rolling_score = sum(episode_grader_scores[-eval_window:]) / eval_window
501
+ if rolling_score > best_avg_score:
502
+ best_avg_score = rolling_score
503
+ _save_checkpoint(
504
+ policy, value_net, policy_optim, value_optim,
505
+ ep, best_avg_score, task_id,
506
+ ckpt_dir / f"best_{task_id}.pt",
507
+ )
508
+
509
+ # ---- Final checkpoint -------------------------------------------------
510
+ _save_checkpoint(
511
+ policy, value_net, policy_optim, value_optim,
512
+ n_episodes, best_avg_score, task_id,
513
+ ckpt_dir / f"final_{task_id}.pt",
514
+ )
515
+
516
+ # ---- Save training metrics to JSON ------------------------------------
517
+ metrics = {
518
+ "task_id": task_id,
519
+ "n_episodes": n_episodes,
520
+ "hyperparameters": {
521
+ "lr": lr,
522
+ "gamma": gamma,
523
+ "entropy_coeff": entropy_coeff,
524
+ "batch_size": batch_size,
525
+ "hidden_dim": hidden_dim,
526
+ "state_dim": STATE_DIM,
527
+ "action_dim": N_ACTIONS,
528
+ },
529
+ "episode_rewards": episode_rewards,
530
+ "episode_grader_scores": episode_grader_scores,
531
+ "episode_steps": episode_steps,
532
+ "policy_losses": episode_policy_losses,
533
+ "value_losses": episode_value_losses,
534
+ "best_avg_grader_score": best_avg_score,
535
+ "total_training_time_s": time.time() - t_start,
536
+ }
537
+ metrics_path.parent.mkdir(parents=True, exist_ok=True)
538
+ with open(metrics_path, "w") as f:
539
+ json.dump(metrics, f, indent=2)
540
+ print(f"\nTraining metrics saved to {metrics_path}")
541
+
542
+ # ---- Post-training evaluation -----------------------------------------
543
+ n_eval = 20
544
+ print("\n" + "=" * 72)
545
+ print(f"Post-training evaluation ({n_eval} episodes each mode)")
546
+ print("=" * 72)
547
+
548
+ for mode, is_greedy in [("stochastic", False), ("greedy", True)]:
549
+ eval_rewards, eval_scores, eval_steps_list = [], [], []
550
+ for i in range(n_eval):
551
+ traj = run_episode(
552
+ env, task_id, policy, value_net, task_cfg,
553
+ seed=10_000 + i, greedy=is_greedy,
554
+ )
555
+ eval_rewards.append(traj["total_reward"])
556
+ eval_scores.append(traj["grader_score"])
557
+ eval_steps_list.append(traj["n_steps"])
558
+ avg_r = sum(eval_rewards) / len(eval_rewards)
559
+ avg_s = sum(eval_scores) / len(eval_scores)
560
+ avg_st = sum(eval_steps_list) / len(eval_steps_list)
561
+ print(
562
+ f" [{mode:>10s}] avg_reward={avg_r:+.4f} "
563
+ f"avg_grader={avg_s:.4f} avg_steps={avg_st:.1f}"
564
+ )
565
+ metrics[f"eval_{mode}_avg_reward"] = avg_r
566
+ metrics[f"eval_{mode}_avg_grader_score"] = avg_s
567
+ metrics[f"eval_{mode}_avg_steps"] = avg_st
568
+ metrics[f"eval_{mode}_rewards"] = eval_rewards
569
+ metrics[f"eval_{mode}_grader_scores"] = eval_scores
570
+
571
+ print(f" best training rolling-avg grader: {best_avg_score:.4f}")
572
+ print()
573
+
574
+ with open(metrics_path, "w") as f:
575
+ json.dump(metrics, f, indent=2)
576
+
577
+ print("Done.")
578
+
579
+
580
+ # ---------------------------------------------------------------------------
581
+ # Checkpoint I/O
582
+ # ---------------------------------------------------------------------------
583
+
584
def _save_checkpoint(
    policy: PolicyNetwork,
    value_net: ValueNetwork,
    policy_optim: torch.optim.Optimizer,
    value_optim: torch.optim.Optimizer,
    episode: int,
    best_score: float,
    task_id: str,
    path: Path,
) -> None:
    """Serialise the full training state to ``path`` via ``torch.save``.

    The payload bundles both network state dicts, both optimiser state
    dicts, the episode counter, the best rolling grader score seen so
    far, the task identifier, and the network dimensions so that
    ``load_checkpoint`` can rebuild matching architectures later.
    """
    payload = {
        "episode": episode,
        "best_avg_grader_score": best_score,
        "task_id": task_id,
        "policy_state_dict": policy.state_dict(),
        "value_state_dict": value_net.state_dict(),
        "policy_optim_state_dict": policy_optim.state_dict(),
        "value_optim_state_dict": value_optim.state_dict(),
        "state_dim": STATE_DIM,
        "action_dim": N_ACTIONS,
    }
    torch.save(payload, path)
608
+
609
+
610
def load_checkpoint(
    path: Path,
    hidden_dim: int = 128,
) -> Tuple[PolicyNetwork, ValueNetwork]:
    """Restore a trained policy + value net pair from a checkpoint file.

    Network dimensions are read from the checkpoint when present, falling
    back to the module-level defaults otherwise.  ``hidden_dim`` is not
    stored in the checkpoint, so it must match the width used at training
    time.  Optimiser state in the checkpoint is left untouched.
    """
    ckpt = torch.load(path, map_location="cpu")

    state_dim = ckpt.get("state_dim", STATE_DIM)
    action_dim = ckpt.get("action_dim", N_ACTIONS)

    policy = PolicyNetwork(state_dim, action_dim, hidden=hidden_dim)
    value_net = ValueNetwork(state_dim, hidden=hidden_dim)

    policy.load_state_dict(ckpt["policy_state_dict"])
    value_net.load_state_dict(ckpt["value_state_dict"])
    return policy, value_net
625
+
626
+
627
+ # ---------------------------------------------------------------------------
628
+ # CLI
629
+ # ---------------------------------------------------------------------------
630
+
631
def parse_args() -> argparse.Namespace:
    """Build the CLI for this script and parse ``sys.argv``."""
    parser = argparse.ArgumentParser(
        description="REINFORCE training for PolypharmacyEnv",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Which task preset to train against (keys come from TASK_CONFIGS).
    parser.add_argument(
        "--task",
        type=str,
        default="easy_screening",
        choices=list(TASK_CONFIGS.keys()),
        help="Task difficulty to train on",
    )

    # Core training-loop hyperparameters.
    parser.add_argument(
        "--episodes", type=int, default=200,
        help="Number of training episodes",
    )
    parser.add_argument(
        "--batch-size", type=int, default=5,
        help="Episodes per gradient update",
    )
    parser.add_argument(
        "--lr", type=float, default=3e-4,
        help="Learning rate for Adam",
    )
    parser.add_argument(
        "--gamma", type=float, default=0.99,
        help="Discount factor",
    )
    parser.add_argument(
        "--entropy-coeff", type=float, default=0.02,
        help="Entropy bonus coefficient (higher = more exploration)",
    )
    parser.add_argument(
        "--hidden-dim", type=int, default=128,
        help="Hidden layer width",
    )

    # Logging / output locations.
    parser.add_argument(
        "--print-every", type=int, default=10,
        help="Print interval (episodes)",
    )
    parser.add_argument(
        "--checkpoint-dir",
        type=str,
        default=os.path.join(_BACKEND_SRC, "polypharmacy_env", "checkpoints"),
        help="Directory to save model checkpoints",
    )
    parser.add_argument(
        "--metrics-file",
        type=str,
        default="training_metrics.json",
        help="Path for JSON training metrics",
    )

    return parser.parse_args()
666
+
667
+
668
+ # ---------------------------------------------------------------------------
669
+ # Entry point
670
+ # ---------------------------------------------------------------------------
671
+
672
if __name__ == "__main__":
    # CLI entry point: parse hyperparameters, then run training.
    train(parse_args())
training_metrics.json ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "task_id": "easy_screening",
3
+ "n_episodes": 30,
4
+ "hyperparameters": {
5
+ "lr": 0.0003,
6
+ "gamma": 0.99,
7
+ "entropy_coeff": 0.02,
8
+ "batch_size": 5,
9
+ "hidden_dim": 128,
10
+ "state_dim": 16,
11
+ "action_dim": 166
12
+ },
13
+ "episode_rewards": [
14
+ 0.47,
15
+ 1.1073118279569893,
16
+ 1.1486231884057971,
17
+ 1.1336231884057972,
18
+ 0.405,
19
+ 0.395,
20
+ 1.1296806853582555,
21
+ 0.38,
22
+ 0.7283823529411766,
23
+ 0.0,
24
+ 0.39,
25
+ -0.095,
26
+ 1.137962962962963,
27
+ 1.1951785714285714,
28
+ 0.9053636363636364,
29
+ -0.01754088050314473,
30
+ -0.04,
31
+ -0.06,
32
+ 0.0,
33
+ 0.22666666666666668,
34
+ 0.435,
35
+ 0.45,
36
+ 0.45,
37
+ 0.37666666666666665,
38
+ 0.435,
39
+ 0.455,
40
+ 0.5412162162162163,
41
+ 0.33899999999999997,
42
+ 0.3416666666666666,
43
+ 0.42
44
+ ],
45
+ "episode_grader_scores": [
46
+ 0.5,
47
+ 0.8306451612903226,
48
+ 0.8369565217391305,
49
+ 0.8369565217391305,
50
+ 0.5,
51
+ 0.5,
52
+ 0.9688473520249221,
53
+ 0.5,
54
+ 0.7058823529411765,
55
+ 0.0,
56
+ 0.5,
57
+ 0.0,
58
+ 0.837962962962963,
59
+ 0.9776785714285714,
60
+ 0.8863636363636364,
61
+ 0.053459119496855306,
62
+ 0.0,
63
+ 0.0,
64
+ 0.0,
65
+ 0.5,
66
+ 0.5,
67
+ 0.5,
68
+ 0.5,
69
+ 0.5,
70
+ 0.5,
71
+ 0.5,
72
+ 0.5495495495495496,
73
+ 0.5,
74
+ 0.27499999999999997,
75
+ 0.5
76
+ ],
77
+ "episode_steps": [
78
+ 5,
79
+ 4,
80
+ 3,
81
+ 4,
82
+ 6,
83
+ 6,
84
+ 5,
85
+ 7,
86
+ 4,
87
+ 1,
88
+ 7,
89
+ 8,
90
+ 4,
91
+ 3,
92
+ 10,
93
+ 6,
94
+ 3,
95
+ 7,
96
+ 1,
97
+ 4,
98
+ 4,
99
+ 3,
100
+ 3,
101
+ 5,
102
+ 4,
103
+ 6,
104
+ 4,
105
+ 7,
106
+ 2,
107
+ 5
108
+ ],
109
+ "policy_losses": [
110
+ 0.28967130184173584,
111
+ 0.05730011314153671,
112
+ -0.06924888491630554,
113
+ -0.28697478771209717,
114
+ -0.1783256083726883,
115
+ -0.12063005566596985
116
+ ],
117
+ "value_losses": [
118
+ 0.39626142382621765,
119
+ 0.24146510660648346,
120
+ 0.29013994336128235,
121
+ 0.06388193368911743,
122
+ 0.02375689707696438,
123
+ 0.02241377718746662
124
+ ],
125
+ "best_avg_grader_score": 0.6179287909734683,
126
+ "total_training_time_s": 0.13345718383789062,
127
+ "eval_stochastic_avg_reward": 0.5792066304974347,
128
+ "eval_stochastic_avg_grader_score": 0.5784816304974348,
129
+ "eval_stochastic_avg_steps": 5.6,
130
+ "eval_stochastic_rewards": [
131
+ 1.0402956989247312,
132
+ 0.485,
133
+ -0.11333333333333336,
134
+ 0.455,
135
+ 0.36250000000000004,
136
+ 1.0112089552238808,
137
+ 0.41,
138
+ 0.21499999999999997,
139
+ 1.1173118279569894,
140
+ 0.883621495327103,
141
+ 0.45,
142
+ 1.1073118279569893,
143
+ 0.8680882352941177,
144
+ 0.405,
145
+ 0.675031128404669,
146
+ 0.45,
147
+ -0.11,
148
+ 0.415,
149
+ 0.32,
150
+ 1.1370967741935485
151
+ ],
152
+ "eval_stochastic_grader_scores": [
153
+ 0.8494623655913979,
154
+ 0.5,
155
+ 0.0,
156
+ 0.5,
157
+ 0.5,
158
+ 0.9207089552238806,
159
+ 0.5,
160
+ 0.5,
161
+ 0.8306451612903226,
162
+ 0.8411214953271029,
163
+ 0.5,
164
+ 0.8306451612903226,
165
+ 0.803921568627451,
166
+ 0.5,
167
+ 0.6060311284046691,
168
+ 0.5,
169
+ 0.0,
170
+ 0.5,
171
+ 0.5,
172
+ 0.8870967741935485
173
+ ],
174
+ "eval_greedy_avg_reward": 0.3627500000000001,
175
+ "eval_greedy_avg_grader_score": 0.425,
176
+ "eval_greedy_avg_steps": 6.45,
177
+ "eval_greedy_rewards": [
178
+ 0.44,
179
+ 0.455,
180
+ -0.08,
181
+ 0.455,
182
+ -0.06,
183
+ 0.39,
184
+ 0.455,
185
+ 0.455,
186
+ 0.455,
187
+ -0.06,
188
+ 0.42,
189
+ 0.455,
190
+ 0.41000000000000003,
191
+ 0.455,
192
+ 0.44,
193
+ 0.39,
194
+ 0.41000000000000003,
195
+ 0.49,
196
+ 0.44,
197
+ 0.44
198
+ ],
199
+ "eval_greedy_grader_scores": [
200
+ 0.5,
201
+ 0.5,
202
+ 0.0,
203
+ 0.5,
204
+ 0.0,
205
+ 0.5,
206
+ 0.5,
207
+ 0.5,
208
+ 0.5,
209
+ 0.0,
210
+ 0.5,
211
+ 0.5,
212
+ 0.5,
213
+ 0.5,
214
+ 0.5,
215
+ 0.5,
216
+ 0.5,
217
+ 0.5,
218
+ 0.5,
219
+ 0.5
220
+ ]
221
+ }