Spaces:
Sleeping
Sleeping
Commit ·
f0ef01d
1
Parent(s): bbb6de2
Version 3
Browse filesCo-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- .gitignore +4 -0
- Dockerfile +14 -4
- PROMPT.md +0 -571
- README.MD +408 -136
- backend/requirements.txt +1 -0
- backend/src/polypharmacy_env/api/app.py +80 -0
- backend/src/polypharmacy_env/api/routes/bandit.py +158 -0
- backend/src/polypharmacy_env/config.py +9 -5
- backend/src/polypharmacy_env/env_core.py +1 -0
- backend/src/polypharmacy_env/neural_bandits.py +484 -0
- backend/src/polypharmacy_env/rewards.py +6 -2
- frontend/src/App.jsx +706 -143
- frontend/src/styles.css +671 -136
- train_bandit.py +381 -0
- train_rl.py +674 -0
- training_metrics.json +221 -0
.gitignore
CHANGED
|
@@ -23,9 +23,13 @@ yarn-debug.log*
|
|
| 23 |
yarn-error.log*
|
| 24 |
pnpm-debug.log*
|
| 25 |
|
|
|
|
|
|
|
|
|
|
| 26 |
# --- Build / temp ---
|
| 27 |
*.log
|
| 28 |
*.tmp
|
| 29 |
*.swp
|
| 30 |
.DS_Store
|
| 31 |
|
|
|
|
|
|
| 23 |
yarn-error.log*
|
| 24 |
pnpm-debug.log*
|
| 25 |
|
| 26 |
+
# --- Project-specific ---
|
| 27 |
+
PROMPT.md
|
| 28 |
+
|
| 29 |
# --- Build / temp ---
|
| 30 |
*.log
|
| 31 |
*.tmp
|
| 32 |
*.swp
|
| 33 |
.DS_Store
|
| 34 |
|
| 35 |
+
arXiv-2212.05190v3/
|
Dockerfile
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
FROM node:20-alpine AS frontend-builder
|
| 2 |
WORKDIR /app/frontend
|
| 3 |
COPY frontend/package*.json ./
|
| 4 |
-
RUN npm ci
|
| 5 |
COPY frontend/ ./
|
| 6 |
RUN npm run build
|
| 7 |
|
|
@@ -11,6 +11,8 @@ RUN apt-get update && \
|
|
| 11 |
apt-get install -y --no-install-recommends build-essential curl && \
|
| 12 |
rm -rf /var/lib/apt/lists/*
|
| 13 |
|
|
|
|
|
|
|
| 14 |
WORKDIR /app
|
| 15 |
|
| 16 |
COPY backend/requirements.txt /app/backend/requirements.txt
|
|
@@ -22,18 +24,26 @@ COPY scripts /app/scripts
|
|
| 22 |
COPY openenv.yaml /app/openenv.yaml
|
| 23 |
COPY .env.example /app/.env.example
|
| 24 |
COPY inference.py /app/inference.py
|
|
|
|
|
|
|
| 25 |
|
| 26 |
COPY --from=frontend-builder /app/frontend/dist /app/frontend/dist
|
| 27 |
|
| 28 |
RUN python3 /app/scripts/preprocess_data.py
|
| 29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
ENV PORT=7860
|
| 31 |
-
ENV PYTHONPATH="/app/backend/src:
|
| 32 |
ENV PYTHONUNBUFFERED=1
|
| 33 |
|
|
|
|
|
|
|
| 34 |
EXPOSE 7860
|
| 35 |
|
| 36 |
-
HEALTHCHECK --interval=30s --timeout=
|
| 37 |
CMD curl -f http://localhost:7860/health || exit 1
|
| 38 |
|
| 39 |
-
CMD ["sh", "-c", "uvicorn backend.main:app --host 0.0.0.0 --port ${PORT:-7860}"]
|
|
|
|
| 1 |
FROM node:20-alpine AS frontend-builder
|
| 2 |
WORKDIR /app/frontend
|
| 3 |
COPY frontend/package*.json ./
|
| 4 |
+
RUN npm ci --production=false
|
| 5 |
COPY frontend/ ./
|
| 6 |
RUN npm run build
|
| 7 |
|
|
|
|
| 11 |
apt-get install -y --no-install-recommends build-essential curl && \
|
| 12 |
rm -rf /var/lib/apt/lists/*
|
| 13 |
|
| 14 |
+
# HF Spaces runs as uid 1000
|
| 15 |
+
RUN useradd -m -u 1000 user
|
| 16 |
WORKDIR /app
|
| 17 |
|
| 18 |
COPY backend/requirements.txt /app/backend/requirements.txt
|
|
|
|
| 24 |
COPY openenv.yaml /app/openenv.yaml
|
| 25 |
COPY .env.example /app/.env.example
|
| 26 |
COPY inference.py /app/inference.py
|
| 27 |
+
COPY train_rl.py /app/train_rl.py
|
| 28 |
+
COPY train_bandit.py /app/train_bandit.py
|
| 29 |
|
| 30 |
COPY --from=frontend-builder /app/frontend/dist /app/frontend/dist
|
| 31 |
|
| 32 |
RUN python3 /app/scripts/preprocess_data.py
|
| 33 |
|
| 34 |
+
# Ensure the user owns the app directory and has a writable home (HF Spaces)
|
| 35 |
+
RUN chown -R user:user /app && \
|
| 36 |
+
mkdir -p /home/user/.cache && chown -R user:user /home/user
|
| 37 |
+
|
| 38 |
ENV PORT=7860
|
| 39 |
+
ENV PYTHONPATH="/app/backend/src:/app"
|
| 40 |
ENV PYTHONUNBUFFERED=1
|
| 41 |
|
| 42 |
+
USER user
|
| 43 |
+
|
| 44 |
EXPOSE 7860
|
| 45 |
|
| 46 |
+
HEALTHCHECK --interval=30s --timeout=5s --start-period=20s --retries=3 \
|
| 47 |
CMD curl -f http://localhost:7860/health || exit 1
|
| 48 |
|
| 49 |
+
CMD ["sh", "-c", "uvicorn backend.main:app --host 0.0.0.0 --port ${PORT:-7860} --workers 1"]
|
PROMPT.md
DELETED
|
@@ -1,571 +0,0 @@
|
|
| 1 |
-
You are an expert Python backend, ML, and infrastructure engineer.
|
| 2 |
-
Your task is to implement a complete, production-ready OpenEnv environment called **PolypharmacyEnv** for training and evaluating agentic RL policies that act as an "elderly polypharmacy safety agent" (clinical pharmacist assistant).
|
| 3 |
-
|
| 4 |
-
The deliverable MUST satisfy all of the following:
|
| 5 |
-
- Fully compliant with the OpenEnv spec (typed models, `step()` / `reset()` / `state()`, `openenv.yaml`, HTTP server, Dockerfile).
|
| 6 |
-
- Simulates a realistic healthcare workflow around elderly polypharmacy and dangerous drug combinations.
|
| 7 |
-
- Defines at least **3 tasks** (easy → medium → hard) with deterministic agent graders producing scores in (0.0, 1.0).
|
| 8 |
-
- Provides shaped rewards over the trajectory (not just sparse terminal rewards).
|
| 9 |
-
- Includes a baseline LLM-based inference script `inference.py` in the repo root, following the evaluation requirements:
|
| 10 |
-
- Uses the OpenAI Python client.
|
| 11 |
-
- Reads `OPENAI_API_KEY`, `API_BASE_URL`, `MODEL_NAME`, and `HF_TOKEN` from the environment.
|
| 12 |
-
- Emits structured stdout logs in the exact `[START]`, `[STEP]`, `[END]` format from the OpenEnv sample inference script.
|
| 13 |
-
- Is containerized and deployable as a **Hugging Face Space** tagged with `openenv` that responds to OpenEnv-style `reset` / `step` / `state` HTTP calls.
|
| 14 |
-
|
| 15 |
-
Implement everything described below.
|
| 16 |
-
|
| 17 |
-
=================================================
|
| 18 |
-
1. Repository and folder structure
|
| 19 |
-
=================================================
|
| 20 |
-
|
| 21 |
-
Create a Python package repository with this structure (names are important unless clearly labeled as examples):
|
| 22 |
-
|
| 23 |
-
- `openenv-polypharmacy/`
|
| 24 |
-
- `openenv.yaml`
|
| 25 |
-
- `README.md`
|
| 26 |
-
- `requirements.txt`
|
| 27 |
-
- `Dockerfile`
|
| 28 |
-
- `inference.py` # baseline LLM agent per spec
|
| 29 |
-
- `pyproject.toml` or `setup.cfg` (optional but recommended)
|
| 30 |
-
- `src/`
|
| 31 |
-
- `polypharmacy_env/`
|
| 32 |
-
- `__init__.py`
|
| 33 |
-
- `config.py`
|
| 34 |
-
- `models.py` # Action, Observation, State, helper models
|
| 35 |
-
- `env_core.py` # PolypharmacyEnv implementation
|
| 36 |
-
- `tasks.py` # task setup utilities
|
| 37 |
-
- `graders.py` # deterministic graders for each task
|
| 38 |
-
- `rewards.py` # reward shaping logic
|
| 39 |
-
- `data_loader.py` # load/preprocess patient and lookup data
|
| 40 |
-
- `ddi_simulator.py` # local DDI / guideline simulator
|
| 41 |
-
- `api/`
|
| 42 |
-
- `__init__.py`
|
| 43 |
-
- `schemas.py` # HTTP request/response schemas
|
| 44 |
-
- `server.py` # FastAPI app exposing OpenEnv endpoints
|
| 45 |
-
- `baselines/`
|
| 46 |
-
- `__init__.py`
|
| 47 |
-
- `heuristic_agent.py` # simple rule-based baseline agent
|
| 48 |
-
- `random_agent.py` # trivial random baseline (optional)
|
| 49 |
-
- `tests/`
|
| 50 |
-
- `__init__.py`
|
| 51 |
-
- `test_env_core.py`
|
| 52 |
-
- `test_api.py`
|
| 53 |
-
- `data/`
|
| 54 |
-
- `raw/` # placeholder for real/synthetic source data
|
| 55 |
-
- `processed/`
|
| 56 |
-
- `lookups/`
|
| 57 |
-
- `ddi_rules.csv`
|
| 58 |
-
- `beers_criteria.csv`
|
| 59 |
-
- `drug_metadata.csv`
|
| 60 |
-
- `scripts/`
|
| 61 |
-
- `preprocess_data.py`
|
| 62 |
-
- `run_validation.sh` # optional; runs OpenEnv validator, tests, etc.
|
| 63 |
-
|
| 64 |
-
Use Python 3.10+ with full type hints, and keep the code black/isort-compatible.
|
| 65 |
-
|
| 66 |
-
=================================================
|
| 67 |
-
2. Domain, data, and clinical abstraction
|
| 68 |
-
=================================================
|
| 69 |
-
|
| 70 |
-
2.1. Core scenario
|
| 71 |
-
|
| 72 |
-
Model an elderly patient (age ≥ 65) with:
|
| 73 |
-
- Demographics: age, sex.
|
| 74 |
-
- Comorbidities: e.g., hypertension, diabetes, heart failure, CKD, dementia.
|
| 75 |
-
- Basic labs: kidney function (eGFR category), liver function category.
|
| 76 |
-
- A current medication list (polypharmacy, e.g., 3–15 drugs depending on task).
|
| 77 |
-
|
| 78 |
-
Each **episode** is one medication-review session where the agent:
|
| 79 |
-
- Observes patient info and current meds.
|
| 80 |
-
- Optionally **queries** a DDI/guideline tool for specific drug pairs.
|
| 81 |
-
- Proposes **interventions**:
|
| 82 |
-
- `stop`: discontinue a drug.
|
| 83 |
-
- `dose_reduce`: lower dose of a drug.
|
| 84 |
-
- `substitute`: swap to a safer alternative.
|
| 85 |
-
- `add_monitoring`: keep the drug but flag extra monitoring.
|
| 86 |
-
- Calls `finish_review` when it decides the regimen is acceptable or budgets are exhausted.
|
| 87 |
-
|
| 88 |
-
No external PHI, EHRs, or online APIs: all data is **synthetic** or de-identified and local to the container (CSV files).
|
| 89 |
-
|
| 90 |
-
2.2. Data files and CSV schemas
|
| 91 |
-
|
| 92 |
-
Implement local CSVs under `data/lookups/`:
|
| 93 |
-
|
| 94 |
-
**`drug_metadata.csv`**
|
| 95 |
-
- `drug_id` (string; unique key)
|
| 96 |
-
- `generic_name` (string)
|
| 97 |
-
- `atc_class` (string)
|
| 98 |
-
- `is_high_risk_elderly` (0/1)
|
| 99 |
-
- `default_dose_mg` (float)
|
| 100 |
-
- `min_dose_mg` (float)
|
| 101 |
-
- `max_dose_mg` (float)
|
| 102 |
-
|
| 103 |
-
**`beers_criteria.csv`**
|
| 104 |
-
- `drug_id` (string)
|
| 105 |
-
- `criterion_type` (enum string: `avoid`, `caution`, `dose_adjust`, `avoid_in_condition`)
|
| 106 |
-
- `condition` (nullable string; e.g., `CKD`, `dementia`)
|
| 107 |
-
- `rationale` (brief text)
|
| 108 |
-
|
| 109 |
-
**`ddi_rules.csv`**
|
| 110 |
-
- `drug_id_1` (string; normalized so `drug_id_1 < drug_id_2` lexicographically)
|
| 111 |
-
- `drug_id_2` (string)
|
| 112 |
-
- `severity` (enum string: `mild`, `moderate`, `severe`)
|
| 113 |
-
- `mechanism` (short text)
|
| 114 |
-
- `recommendation` (enum string: `avoid_combination`, `monitor_closely`, `dose_adjust`, `no_action`)
|
| 115 |
-
- `base_risk_score` (float in [0.0, 1.0])
|
| 116 |
-
|
| 117 |
-
Implement a synthetic patient-episode dataset under `data/processed/`:
|
| 118 |
-
|
| 119 |
-
**`patients_polypharmacy.csv`**
|
| 120 |
-
- `episode_id` (string)
|
| 121 |
-
- `age` (int)
|
| 122 |
-
- `sex` (enum: `M`, `F`, `O`)
|
| 123 |
-
- `conditions` (semicolon-separated; e.g., `HTN;DM;CKD`)
|
| 124 |
-
- `eGFR_category` (enum: `normal`, `mild`, `moderate`, `severe`)
|
| 125 |
-
- `liver_function_category` (enum: `normal`, `impaired`)
|
| 126 |
-
- `medication_ids` (semicolon-separated list of `drug_id`)
|
| 127 |
-
- `baseline_risk_score` (float in [0.0, 1.0])
|
| 128 |
-
|
| 129 |
-
2.3. Preprocessing script
|
| 130 |
-
|
| 131 |
-
In `scripts/preprocess_data.py`:
|
| 132 |
-
- If real data is not provided, procedurally generate synthetic but plausible data using:
|
| 133 |
-
- Random combinations of conditions and drugs constrained by simple rules (e.g., CKD + renally-cleared drugs).
|
| 134 |
-
- Controlled distribution of high-risk DDIs and Beers violations.
|
| 135 |
-
- Explicitly tag episodes as easy/medium/hard (e.g., via number of drugs, number/severity of DDIs, and number of Beers issues).
|
| 136 |
-
- Save `patients_polypharmacy.csv` ready for the environment to consume.
|
| 137 |
-
|
| 138 |
-
=================================================
|
| 139 |
-
3. OpenEnv models and environment implementation
|
| 140 |
-
=================================================
|
| 141 |
-
|
| 142 |
-
3.1. Models
|
| 143 |
-
|
| 144 |
-
In `models.py`, define dataclasses or Pydantic models that extend the appropriate OpenEnv base types (`Action`, `Observation`, `State`) and are JSON-compatible.
|
| 145 |
-
|
| 146 |
-
Auxiliary models:
|
| 147 |
-
|
| 148 |
-
**`MedicationEntry`**
|
| 149 |
-
- `drug_id: str`
|
| 150 |
-
- `generic_name: str`
|
| 151 |
-
- `atc_class: str`
|
| 152 |
-
- `dose_mg: float`
|
| 153 |
-
- `frequency: str` # e.g., `qd`, `bid`
|
| 154 |
-
- `route: str` # e.g., `po`
|
| 155 |
-
- `is_high_risk_elderly: bool`
|
| 156 |
-
- `beers_flags: list[str]` # e.g., `["avoid", "dose_adjust_CKD"]`
|
| 157 |
-
|
| 158 |
-
**`InteractionQueryRecord`**
|
| 159 |
-
- `drug_id_1: str`
|
| 160 |
-
- `drug_id_2: str`
|
| 161 |
-
- `severity: str | None`
|
| 162 |
-
- `recommendation: str | None`
|
| 163 |
-
- `risk_score: float | None`
|
| 164 |
-
- `step_index: int`
|
| 165 |
-
|
| 166 |
-
**`InterventionRecord`**
|
| 167 |
-
- `target_drug_id: str`
|
| 168 |
-
- `action_type: Literal["stop", "dose_reduce", "substitute", "add_monitoring"]`
|
| 169 |
-
- `proposed_new_drug_id: str | None`
|
| 170 |
-
- `rationale: str`
|
| 171 |
-
- `step_index: int`
|
| 172 |
-
|
| 173 |
-
Core wire models:
|
| 174 |
-
|
| 175 |
-
**`PolypharmacyObservation`** (extends OpenEnv `Observation`)
|
| 176 |
-
- `episode_id: str`
|
| 177 |
-
- `task_id: Literal["easy_screening", "budgeted_screening", "complex_tradeoff"]`
|
| 178 |
-
- `age: int`
|
| 179 |
-
- `sex: str`
|
| 180 |
-
- `conditions: list[str]`
|
| 181 |
-
- `eGFR_category: str`
|
| 182 |
-
- `liver_function_category: str`
|
| 183 |
-
- `current_medications: list[MedicationEntry]`
|
| 184 |
-
- `interaction_queries: list[InteractionQueryRecord]`
|
| 185 |
-
- `interventions: list[InterventionRecord]`
|
| 186 |
-
- `step_index: int`
|
| 187 |
-
- `remaining_query_budget: int`
|
| 188 |
-
- `remaining_intervention_budget: int`
|
| 189 |
-
- `shaped_reward: float` # reward from last step
|
| 190 |
-
- `done: bool`
|
| 191 |
-
|
| 192 |
-
**`PolypharmacyAction`** (extends OpenEnv `Action`)
|
| 193 |
-
- `action_type: Literal["query_ddi", "propose_intervention", "finish_review"]`
|
| 194 |
-
- `drug_id_1: str | None` # for DDI queries or some interventions
|
| 195 |
-
- `drug_id_2: str | None` # for DDI queries
|
| 196 |
-
- `target_drug_id: str | None` # for interventions
|
| 197 |
-
- `intervention_type: Literal["stop", "dose_reduce", "substitute", "add_monitoring", "none"] | None`
|
| 198 |
-
- `proposed_new_drug_id: str | None`
|
| 199 |
-
- `rationale: str | None`
|
| 200 |
-
|
| 201 |
-
**`PolypharmacyState`** (extends OpenEnv `State`)
|
| 202 |
-
- `episode_id: str`
|
| 203 |
-
- `task_id: str`
|
| 204 |
-
- `step_count: int`
|
| 205 |
-
- `max_steps: int`
|
| 206 |
-
- `num_query_actions: int`
|
| 207 |
-
- `num_interventions: int`
|
| 208 |
-
|
| 209 |
-
3.2. Environment core
|
| 210 |
-
|
| 211 |
-
In `env_core.py`, implement `PolypharmacyEnv` extending the appropriate OpenEnv environment base class. It must implement:
|
| 212 |
-
|
| 213 |
-
**`reset(task_id: str | None = None) -> PolypharmacyObservation`**
|
| 214 |
-
- If `task_id` is `None`, default to medium (`budgeted_screening`).
|
| 215 |
-
- Sample an episode from `patients_polypharmacy.csv` filtered by difficulty.
|
| 216 |
-
- Initialize:
|
| 217 |
-
- `episode_id`
|
| 218 |
-
- `step_count = 0`
|
| 219 |
-
- task-specific budgets (query, interventions, max_steps)
|
| 220 |
-
- baseline regime and risk
|
| 221 |
-
- empty `interaction_queries` and `interventions`
|
| 222 |
-
- Return the initial `PolypharmacyObservation` with:
|
| 223 |
-
- `step_index = 0`
|
| 224 |
-
- `shaped_reward = 0.0`
|
| 225 |
-
- `done = False`
|
| 226 |
-
|
| 227 |
-
**`step(action: PolypharmacyAction) -> dict`**
|
| 228 |
-
- Validate the action; if invalid:
|
| 229 |
-
- Apply a negative reward.
|
| 230 |
-
- Do not modify regimen, but log error in `info`.
|
| 231 |
-
- If `action_type == "query_ddi"`:
|
| 232 |
-
- If query budget exhausted, apply penalty and do not query.
|
| 233 |
-
- Else:
|
| 234 |
-
- Use `ddi_simulator.lookup_ddi(drug_id_1, drug_id_2)` to get severity, recommendation, base_risk_score.
|
| 235 |
-
- Append an `InteractionQueryRecord`.
|
| 236 |
-
- Apply a small negative reward for query cost.
|
| 237 |
-
- If `action_type == "propose_intervention"`:
|
| 238 |
-
- If intervention budget exhausted, apply penalty and ignore change.
|
| 239 |
-
- Else:
|
| 240 |
-
- Update `current_medications` according to `intervention_type`:
|
| 241 |
-
- `stop`: remove medication.
|
| 242 |
-
- `dose_reduce`: adjust dose downward within [min_dose_mg, default_dose_mg].
|
| 243 |
-
- `substitute`: replace with a safer alternative from same `atc_class`.
|
| 244 |
-
- `add_monitoring`: keep drug but tag in internal state.
|
| 245 |
-
- Append an `InterventionRecord`.
|
| 246 |
-
- Recompute current regimen risk using the risk model (see 3.3).
|
| 247 |
-
- Compute shaped reward = (previous_risk - new_risk) - small intervention cost.
|
| 248 |
-
- If `action_type == "finish_review"`:
|
| 249 |
-
- Mark `done = True`.
|
| 250 |
-
- Call the task’s grader to get episode-level score in [0.0, 1.0].
|
| 251 |
-
- Add this as a terminal bonus to the current step reward.
|
| 252 |
-
|
| 253 |
-
- In all cases:
|
| 254 |
-
- Increment `step_count`.
|
| 255 |
-
- Check `max_steps`; if exceeded, auto-terminate:
|
| 256 |
-
- `done = True`
|
| 257 |
-
- apply time-out penalty
|
| 258 |
-
- call grader with current trajectory for a final score if appropriate.
|
| 259 |
-
- Construct next `PolypharmacyObservation` with updated fields.
|
| 260 |
-
- Return a dict:
|
| 261 |
-
- `observation`: `PolypharmacyObservation`
|
| 262 |
-
- `reward`: float shaped reward for this step
|
| 263 |
-
- `done`: bool
|
| 264 |
-
- `info`: dict with fields like `current_risk`, `baseline_risk`, `grader_score_if_terminal`, and debug flags.
|
| 265 |
-
|
| 266 |
-
**`state` property**
|
| 267 |
-
- Returns `PolypharmacyState` reflecting the current internal state.
|
| 268 |
-
|
| 269 |
-
3.3. DDI simulator and risk model
|
| 270 |
-
|
| 271 |
-
In `ddi_simulator.py`:
|
| 272 |
-
- Load `ddi_rules.csv` once via `data_loader`.
|
| 273 |
-
- Implement `lookup_ddi(drug_id_1, drug_id_2) -> tuple[severity, recommendation, base_risk_score]`:
|
| 274 |
-
- Normalize the pair ordering.
|
| 275 |
-
- Look up row; if missing, return:
|
| 276 |
-
- severity = `"none"`
|
| 277 |
-
- recommendation = `"no_action"`
|
| 278 |
-
- base_risk_score = 0.0
|
| 279 |
-
|
| 280 |
-
In `rewards.py` (or a dedicated module), implement:
|
| 281 |
-
- `compute_regimen_risk(current_drug_ids, patient_context, ddi_rules, beers_rules, drug_metadata) -> float`
|
| 282 |
-
- Aggregate contributions from:
|
| 283 |
-
- Beers violations (weighted by `criterion_type` and relevant conditions).
|
| 284 |
-
- DDI base risk scores for all present drug pairs.
|
| 285 |
-
- High-risk elderly drugs.
|
| 286 |
-
- Normalize and clip to [0.0, 1.0].
|
| 287 |
-
|
| 288 |
-
Use this function to compute:
|
| 289 |
-
- `baseline_risk` at episode start.
|
| 290 |
-
- Risk after each intervention step.
|
| 291 |
-
|
| 292 |
-
Also implement:
|
| 293 |
-
- `compute_shaped_reward(previous_risk, new_risk, action, context, partial_metrics) -> float`
|
| 294 |
-
- Positive component: `previous_risk - new_risk`.
|
| 295 |
-
- Negative components: per-query cost, per-intervention cost, invalid-action penalty, time-out penalty.
|
| 296 |
-
|
| 297 |
-
=================================================
|
| 298 |
-
4. Tasks and graders (3 difficulty levels)
|
| 299 |
-
=================================================
|
| 300 |
-
|
| 301 |
-
Define three task IDs and semantics in `tasks.py` and `graders.py`:
|
| 302 |
-
|
| 303 |
-
Task IDs:
|
| 304 |
-
- `easy_screening`
|
| 305 |
-
- `budgeted_screening`
|
| 306 |
-
- `complex_tradeoff`
|
| 307 |
-
|
| 308 |
-
4.1. `easy_screening` (easy)
|
| 309 |
-
|
| 310 |
-
- Small regimen: 3–5 drugs.
|
| 311 |
-
- Exactly one **severe** DDI pair and possibly one simple Beers violation.
|
| 312 |
-
- Budgets:
|
| 313 |
-
- query_budget ≈ 4
|
| 314 |
-
- intervention_budget ≈ 2
|
| 315 |
-
- max_steps ≈ 10
|
| 316 |
-
|
| 317 |
-
Grader:
|
| 318 |
-
- Input: full trajectory, baseline risk, final risk, list of interventions.
|
| 319 |
-
- Compute:
|
| 320 |
-
- `risk_reduction = max(0.0, baseline_risk - final_risk) / max(baseline_risk, ε)` (normalized).
|
| 321 |
-
- `targeted_intervention_flag = 1.0` if at least one intervention affects one of the drugs in the known severe DDI pair, else 0.0.
|
| 322 |
-
- Score:
|
| 323 |
-
- `score = 0.5 * risk_reduction + 0.5 * targeted_intervention_flag`
|
| 324 |
-
- Clip to [0.0, 1.0].
|
| 325 |
-
|
| 326 |
-
4.2. `budgeted_screening` (medium)
|
| 327 |
-
|
| 328 |
-
- Medium regimen: 6–10 drugs.
|
| 329 |
-
- Multiple DDIs (mild/moderate/severe) and multiple Beers issues.
|
| 330 |
-
- Budgets:
|
| 331 |
-
- query_budget ≈ 8
|
| 332 |
-
- intervention_budget ≈ 3
|
| 333 |
-
- max_steps ≈ 20
|
| 334 |
-
|
| 335 |
-
Grader:
|
| 336 |
-
- Compute:
|
| 337 |
-
- `risk_reduction_score` as normalized risk drop.
|
| 338 |
-
- `intervention_precision_score` = fraction of interventions that actually reduce risk or fix guideline violations.
|
| 339 |
-
- `query_efficiency_score` = (number of severe/moderate DDIs discovered) / (number of queries used), normalized.
|
| 340 |
-
- Weighted score, for example:
|
| 341 |
-
- `score = 0.5 * risk_reduction_score + 0.3 * intervention_precision_score + 0.2 * query_efficiency_score`
|
| 342 |
-
- Clip to [0.0, 1.0].
|
| 343 |
-
|
| 344 |
-
4.3. `complex_tradeoff` (hard)
|
| 345 |
-
|
| 346 |
-
- Larger regimen: 10–15 drugs.
|
| 347 |
-
- Some drugs are **clinically critical** (e.g., anticoagulants, insulin analogues) and encoded as such in `drug_metadata` or a small internal map.
|
| 348 |
-
- Episodes contain:
|
| 349 |
-
- multiple DDIs and Beers issues, including ones involving critical drugs.
|
| 350 |
-
- safer substitutes for some risky drugs.
|
| 351 |
-
|
| 352 |
-
Budgets:
|
| 353 |
-
- query_budget ≈ 12
|
| 354 |
-
- intervention_budget ≈ 5
|
| 355 |
-
- max_steps ≈ 30
|
| 356 |
-
|
| 357 |
-
Grader adds a **regimen disruption penalty** component:
|
| 358 |
-
- Metrics:
|
| 359 |
-
- `risk_reduction_score` (as above).
|
| 360 |
-
- `critical_drug_penalty` = penalty if a critical drug is stopped without substitution to another suitable agent.
|
| 361 |
-
- `total_drug_changes` = number of drugs stopped or substituted.
|
| 362 |
-
- `regimen_disruption_penalty` derived from `total_drug_changes` and `critical_drug_penalty`.
|
| 363 |
-
|
| 364 |
-
Example scoring:
|
| 365 |
-
- `base = risk_reduction_score`
|
| 366 |
-
- `penalty = α * regimen_disruption_penalty`
|
| 367 |
-
- `score = clamp(base - penalty, 0.0, 1.0)`
|
| 368 |
-
|
| 369 |
-
4.4. Reward shaping
|
| 370 |
-
|
| 371 |
-
In `rewards.py`, define a consistent shaping scheme:
|
| 372 |
-
- On each query:
|
| 373 |
-
- Small negative reward (e.g., −0.01) plus any small bonus if it discovers a severe DDI, if desired.
|
| 374 |
-
- On each intervention:
|
| 375 |
-
- Reward ≈ (previous_risk - new_risk) − small intervention cost.
|
| 376 |
-
- On invalid actions:
|
| 377 |
-
- Larger negative reward (e.g., −0.1) and no state change.
|
| 378 |
-
- On `finish_review`:
|
| 379 |
-
- Add the task-level `score` ∈ [0.0, 1.0] from the corresponding grader to that step’s shaped reward.
|
| 380 |
-
|
| 381 |
-
Ensure the sum of step rewards per episode remains in a reasonable numeric range (e.g., roughly -5 to +5) while still allowing meaningful differentiation by graders.
|
| 382 |
-
|
| 383 |
-
=================================================
|
| 384 |
-
5. HTTP API server and openenv.yaml
|
| 385 |
-
=================================================
|
| 386 |
-
|
| 387 |
-
5.1. HTTP server (FastAPI)
|
| 388 |
-
|
| 389 |
-
In `api/server.py`:
|
| 390 |
-
- Implement a FastAPI app that maintains a `PolypharmacyEnv` instance (or a multiplexing scheme if needed).
|
| 391 |
-
- Endpoints:
|
| 392 |
-
- `POST /reset`:
|
| 393 |
-
- Request body: may include `task_id` (string).
|
| 394 |
-
- Response: serialized `PolypharmacyObservation`.
|
| 395 |
-
- `POST /step`:
|
| 396 |
-
- Request body: serialized `PolypharmacyAction`.
|
| 397 |
-
- Response: dict with:
|
| 398 |
-
- `observation`: `PolypharmacyObservation`
|
| 399 |
-
- `reward`: float
|
| 400 |
-
- `done`: bool
|
| 401 |
-
- `info`: dict
|
| 402 |
-
- `GET /state`:
|
| 403 |
-
- Response: `PolypharmacyState`.
|
| 404 |
-
|
| 405 |
-
Provide a module-level `app = FastAPI(...)` object for use with uvicorn and Hugging Face Spaces. Ensure the JSON schema is consistent with OpenEnv clients (simple, flat JSON for observation/action/state).
|
| 406 |
-
|
| 407 |
-
5.2. `openenv.yaml`
|
| 408 |
-
|
| 409 |
-
At repo root, define `openenv.yaml` consistent with the latest OpenEnv spec. At minimum, include:
|
| 410 |
-
- `name`: `polypharmacy_env`
|
| 411 |
-
- `version`: e.g., `0.1.0`
|
| 412 |
-
- `description`: human-readable description.
|
| 413 |
-
- `author`: your details.
|
| 414 |
-
- `tags`: e.g., `["healthcare", "polypharmacy", "openenv"]`
|
| 415 |
-
- `tasks`:
|
| 416 |
-
- One entry per task:
|
| 417 |
-
- `id`: `"easy_screening"` / `"budgeted_screening"` / `"complex_tradeoff"`
|
| 418 |
-
- `description`: one-line description
|
| 419 |
-
- `difficulty`: `"easy"`, `"medium"`, `"hard"`
|
| 420 |
-
|
| 421 |
-
Ensure `openenv validate` (or equivalent validator) passes once implemented.
|
| 422 |
-
|
| 423 |
-
=================================================
|
| 424 |
-
6. Baseline heuristic (non-LLM) agent
|
| 425 |
-
=================================================
|
| 426 |
-
|
| 427 |
-
In `baselines/heuristic_agent.py`, implement a simple, deterministic baseline agent that:
|
| 428 |
-
|
| 429 |
-
For each episode:
|
| 430 |
-
- Iterates through all unordered medication pairs within query budget:
|
| 431 |
-
- Calls `query_ddi` via the environment for each pair until the query budget is exhausted or all pairs are examined.
|
| 432 |
-
- Records severe and moderate interactions.
|
| 433 |
-
- After querying:
|
| 434 |
-
- For each severe DDI pair:
|
| 435 |
-
- Try `substitute` one of the drugs using `drug_metadata`:
|
| 436 |
-
- Prefer substitute within same `atc_class` that:
|
| 437 |
-
- is not marked high-risk elderly.
|
| 438 |
-
- does not participate in known severe DDIs with the rest of the regimen.
|
| 439 |
-
- If no substitute exists, propose `stop` for the higher-risk drug.
|
| 440 |
-
- Respect intervention budget limits.
|
| 441 |
-
- Finally, call `finish_review`.
|
| 442 |
-
|
| 443 |
-
This baseline should be callable as a simple Python function that interacts with `PolypharmacyEnv` directly (without HTTP).
|
| 444 |
-
|
| 445 |
-
=================================================
|
| 446 |
-
7. Baseline LLM inference script (inference.py)
|
| 447 |
-
=================================================
|
| 448 |
-
|
| 449 |
-
At repo root, create `inference.py` that:
|
| 450 |
-
|
| 451 |
-
7.1. Uses the OpenAI Python client
|
| 452 |
-
|
| 453 |
-
- Import and configure the official OpenAI Python client.
|
| 454 |
-
- Read environment variables:
|
| 455 |
-
- `OPENAI_API_KEY` (required).
|
| 456 |
-
- `API_BASE_URL` (base URL for LLM; default to OpenAI standard if not set).
|
| 457 |
-
- `MODEL_NAME` (e.g., `gpt-4.1` or similar).
|
| 458 |
-
- `HF_TOKEN` (if needed for HF auth; do not hardcode).
|
| 459 |
-
- Read `POLYPHARMACY_ENV_URL` (or similar) for the environment’s HTTP base URL.
|
| 460 |
-
|
| 461 |
-
7.2. Implements the required logging format
|
| 462 |
-
|
| 463 |
-
- For each **run** across all tasks:
|
| 464 |
-
- Emit a `[START]` line with a JSON payload exactly matching the evaluation specification:
|
| 465 |
-
- Fields such as `run_id`, `task_id`, `model`, etc., in the same order and naming as the sample OpenEnv inference script.
|
| 466 |
-
- For each **step** in an episode:
|
| 467 |
-
- Emit a `[STEP]` line with JSON fields including:
|
| 468 |
-
- `run_id`
|
| 469 |
-
- `task_id`
|
| 470 |
-
- `episode_id`
|
| 471 |
-
- `step_index`
|
| 472 |
-
- `observation_summary` (brief, machine-readable summary)
|
| 473 |
-
- `action_payload` (the action sent to the env)
|
| 474 |
-
- `reward`
|
| 475 |
-
- `done`
|
| 476 |
-
- After finishing an episode for a task:
|
| 477 |
-
- Emit an `[END]` line summarizing:
|
| 478 |
-
- `run_id`
|
| 479 |
-
- `task_id`
|
| 480 |
-
- per-episode statistics (e.g., total reward, grader score from last step’s `info`).
|
| 481 |
-
- The stdout format MUST follow the sample exactly:
|
| 482 |
-
- Same tags: `[START]`, `[STEP]`, `[END]`.
|
| 483 |
-
- Same JSON field names and ordering as the provided reference.
|
| 484 |
-
- No extra prints except these structured logs (and necessary error messages to stderr).
|
| 485 |
-
|
| 486 |
-
7.3. LLM agent loop
|
| 487 |
-
|
| 488 |
-
- For each task (`easy_screening`, `budgeted_screening`, `complex_tradeoff`):
|
| 489 |
-
- Run a fixed small number of episodes (e.g., 5–10 per task) for baseline scoring.
|
| 490 |
-
- For each episode:
|
| 491 |
-
- Call `/reset` with the task id.
|
| 492 |
-
- At each step:
|
| 493 |
-
- Summarize the observation into a concise prompt for the LLM:
|
| 494 |
-
- Include age, sex, conditions, high-risk flags, budgets, and a compressed view of meds and previous actions.
|
| 495 |
-
- Ask the model to output a **strict JSON** representing `PolypharmacyAction` fields.
|
| 496 |
-
- Parse and validate the JSON; if invalid, fall back to a safe default (e.g., `finish_review` or a no-op) and penalize in evaluation.
|
| 497 |
-
- Send this action to `/step` and log `[STEP]`.
|
| 498 |
-
- End when `done=True` or max_steps is reached.
|
| 499 |
-
- At the end, print aggregate scores per task and overall.
|
| 500 |
-
|
| 501 |
-
Make sure runtime < 20 minutes and that the script can run within 2 vCPUs and 8 GB RAM.
|
| 502 |
-
|
| 503 |
-
=================================================
|
| 504 |
-
8. Dockerfile and Hugging Face Space
|
| 505 |
-
=================================================
|
| 506 |
-
|
| 507 |
-
8.1. Dockerfile
|
| 508 |
-
|
| 509 |
-
Create a `Dockerfile` that:
|
| 510 |
-
- Starts from a slim Python image (e.g., `python:3.11-slim`).
|
| 511 |
-
- Installs system dependencies as needed (e.g., `build-essential`, `curl`).
|
| 512 |
-
- Copies the project into the container.
|
| 513 |
-
- Installs Python dependencies from `requirements.txt`.
|
| 514 |
-
- Sets appropriate environment variables for the app (e.g., `PORT=7860`).
|
| 515 |
-
- Exposes port 7860.
|
| 516 |
-
- Uses a `CMD` or `ENTRYPOINT` that runs the FastAPI server, for example:
|
| 517 |
-
- `uvicorn polypharmacy_env.api.server:app --host 0.0.0.0 --port 7860`
|
| 518 |
-
|
| 519 |
-
8.2. Hugging Face Space
|
| 520 |
-
|
| 521 |
-
Ensure the repository is ready to be used as a Hugging Face Space:
|
| 522 |
-
- Space type: `docker`.
|
| 523 |
-
- Tag: `openenv`.
|
| 524 |
-
- On container start, the server must listen on the correct port and respond to:
|
| 525 |
-
- `POST /reset`
|
| 526 |
-
- `POST /step`
|
| 527 |
-
- `GET /state`
|
| 528 |
-
- The environment must start cleanly with `docker build` + `docker run` locally.
|
| 529 |
-
|
| 530 |
-
=================================================
|
| 531 |
-
9. README and documentation
|
| 532 |
-
=================================================
|
| 533 |
-
|
| 534 |
-
In `README.md`, include:
|
| 535 |
-
|
| 536 |
-
- **Environment description & motivation**:
|
| 537 |
-
- What PolypharmacyEnv simulates.
|
| 538 |
-
- Why elderly polypharmacy safety matters.
|
| 539 |
-
- **Action and observation spaces**:
|
| 540 |
-
- Describe `PolypharmacyAction`, `PolypharmacyObservation`, and `PolypharmacyState` fields and semantics.
|
| 541 |
-
- **Task descriptions**:
|
| 542 |
-
- `easy_screening`, `budgeted_screening`, `complex_tradeoff`, their difficulty and goals.
|
| 543 |
-
- **Reward structure**:
|
| 544 |
-
- Summarize shaping and terminal rewards.
|
| 545 |
-
- **Setup & usage**:
|
| 546 |
-
- How to install dependencies.
|
| 547 |
-
- How to run the API server locally (uvicorn command).
|
| 548 |
-
- How to run the heuristic baseline.
|
| 549 |
-
- How to run `inference.py` with environment variables.
|
| 550 |
-
- **Baseline scores**:
|
| 551 |
-
- Document reproducible baseline scores for each task (heuristic agent, and LLM baseline if available).
|
| 552 |
-
|
| 553 |
-
=================================================
|
| 554 |
-
10. Validation and quality gates
|
| 555 |
-
=================================================
|
| 556 |
-
|
| 557 |
-
- Ensure:
|
| 558 |
-
- `openenv.yaml` and the HTTP server pass the OpenEnv validation script.
|
| 559 |
-
- `docker build` and `docker run` work without errors.
|
| 560 |
-
- `inference.py` completes under 20 minutes, within 2 vCPUs / 8 GB RAM.
|
| 561 |
-
- All graders:
|
| 562 |
-
- Are deterministic.
|
| 563 |
-
- Return scores strictly in [0.0, 1.0].
|
| 564 |
-
- No grader returns a constant score irrespective of behavior.
|
| 565 |
-
|
| 566 |
-
Aim for clean, well-structured, well-documented code with clear separation of concerns between:
|
| 567 |
-
- Data loading,
|
| 568 |
-
- Environment state & dynamics,
|
| 569 |
-
- Reward/grade logic,
|
| 570 |
-
- HTTP serving,
|
| 571 |
-
- Baseline agents and inference.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
README.MD
CHANGED
|
@@ -1,256 +1,528 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
---
|
| 9 |
|
| 10 |
-
# PolypharmacyEnv
|
| 11 |
|
| 12 |
-
|
| 13 |
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
-
|
| 17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
---
|
| 20 |
|
| 21 |
## Repository Structure
|
| 22 |
|
| 23 |
-
```
|
| 24 |
-
backend/
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
```
|
| 60 |
|
| 61 |
---
|
| 62 |
|
| 63 |
-
##
|
|
|
|
|
|
|
| 64 |
|
| 65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
|
| 67 |
-
|
| 68 |
-
- `propose_intervention`
|
| 69 |
-
- `finish_review`
|
| 70 |
|
| 71 |
-
|
| 72 |
|
| 73 |
-
|
| 74 |
-
-
|
| 75 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
|
| 77 |
---
|
| 78 |
|
| 79 |
-
##
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
|
| 81 |
-
-
|
| 82 |
-
-
|
| 83 |
-
-
|
| 84 |
-
|
|
|
|
| 85 |
|
| 86 |
---
|
| 87 |
|
| 88 |
-
##
|
| 89 |
|
| 90 |
-
|
| 91 |
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
|
| 96 |
-
|
| 97 |
|
| 98 |
---
|
| 99 |
|
| 100 |
-
##
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
|
| 102 |
-
##
|
| 103 |
|
| 104 |
-
|
| 105 |
|
| 106 |
```bash
|
| 107 |
-
|
|
|
|
|
|
|
|
|
|
| 108 |
```
|
| 109 |
|
| 110 |
-
|
| 111 |
|
| 112 |
```bash
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
|
|
|
|
|
|
| 116 |
```
|
| 117 |
|
| 118 |
-
###
|
| 119 |
|
| 120 |
```bash
|
| 121 |
python scripts/preprocess_data.py
|
| 122 |
```
|
| 123 |
|
| 124 |
-
###
|
| 125 |
-
|
| 126 |
-
Terminal A:
|
| 127 |
|
|
|
|
| 128 |
```bash
|
| 129 |
./scripts/dev_backend.sh
|
| 130 |
```
|
| 131 |
|
| 132 |
-
Terminal
|
| 133 |
-
|
| 134 |
```bash
|
| 135 |
./scripts/dev_frontend.sh
|
| 136 |
```
|
| 137 |
|
| 138 |
-
###
|
| 139 |
|
| 140 |
-
- Frontend: [http://localhost:5173](http://localhost:5173)
|
| 141 |
-
- Backend health: [http://localhost:7860/health](http://localhost:7860/health)
|
| 142 |
|
| 143 |
---
|
| 144 |
|
| 145 |
-
## Docker
|
| 146 |
|
| 147 |
-
|
| 148 |
|
| 149 |
```bash
|
| 150 |
-
docker
|
|
|
|
| 151 |
```
|
| 152 |
|
| 153 |
-
|
|
|
|
|
|
|
| 154 |
|
| 155 |
```bash
|
| 156 |
-
docker compose
|
| 157 |
```
|
| 158 |
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
- backend: `7860`
|
| 162 |
-
- frontend: `5173`
|
| 163 |
|
| 164 |
---
|
| 165 |
|
| 166 |
-
## Hugging Face Spaces Deployment
|
| 167 |
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
### 1) Create a new Space
|
| 171 |
|
| 172 |
- Go to [Hugging Face Spaces](https://huggingface.co/new-space)
|
| 173 |
- Choose **Docker** SDK
|
| 174 |
-
-
|
| 175 |
|
| 176 |
-
### 2
|
| 177 |
|
| 178 |
-
In Space Settings
|
| 179 |
|
| 180 |
-
|
| 181 |
-
-
|
| 182 |
-
|
|
|
|
|
|
|
| 183 |
|
| 184 |
-
### 3
|
| 185 |
|
| 186 |
-
|
|
|
|
|
|
|
|
|
|
| 187 |
|
| 188 |
-
### 4
|
| 189 |
|
| 190 |
- Space root URL loads the React UI
|
| 191 |
- `/health` returns healthy status
|
| 192 |
-
-
|
| 193 |
|
| 194 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 195 |
|
| 196 |
-
|
| 197 |
-
|
|
|
|
|
|
|
|
|
|
| 198 |
|
| 199 |
---
|
| 200 |
|
| 201 |
-
##
|
| 202 |
|
| 203 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 204 |
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
- `GET /state`
|
| 208 |
-
- `GET /health`
|
| 209 |
-
- `GET /schema`
|
| 210 |
-
- `WS /ws` (stateful session)
|
| 211 |
|
| 212 |
-
|
|
|
|
|
|
|
| 213 |
|
| 214 |
-
|
| 215 |
|
| 216 |
---
|
| 217 |
|
| 218 |
-
##
|
|
|
|
|
|
|
|
|
|
|
|
|
| 219 |
|
| 220 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 221 |
|
| 222 |
```bash
|
| 223 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 224 |
```
|
| 225 |
|
| 226 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 227 |
|
| 228 |
```bash
|
| 229 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 230 |
```
|
| 231 |
|
| 232 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 233 |
|
| 234 |
```bash
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 235 |
openenv validate
|
| 236 |
-
python inference.py
|
| 237 |
```
|
| 238 |
|
| 239 |
---
|
| 240 |
|
| 241 |
-
##
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 242 |
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 246 |
|
| 247 |
---
|
| 248 |
|
| 249 |
## Troubleshooting
|
| 250 |
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: PolypharmacyEnv
|
| 3 |
+
emoji: 💊
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: purple
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
| 8 |
+
tags:
|
| 9 |
+
- openenv
|
| 10 |
+
- healthcare
|
| 11 |
+
- reinforcement-learning
|
| 12 |
---
|
| 13 |
|
| 14 |
+
# PolypharmacyEnv — Elderly Medication Safety via Reinforcement Learning
|
| 15 |
|
| 16 |
+
An [OpenEnv](https://github.com/meta-pytorch/OpenEnv)-compliant environment that simulates **elderly polypharmacy medication review**. An RL agent acts as a clinical pharmacist assistant: it queries drug-drug interactions (DDIs), identifies Beers-criteria violations, and proposes safe interventions — all under resource-constrained budgets.
|
| 17 |
|
| 18 |
+
Built for the **PyTorch OpenEnv Hackathon** to demonstrate how clinical decision support for polypharmacy can be framed as a sequential RL problem and served as a reusable environment through the OpenEnv hub.
|
| 19 |
+
|
| 20 |
+
---
|
| 21 |
+
|
| 22 |
+
## Why This Matters
|
| 23 |
+
|
| 24 |
+
Polypharmacy — the simultaneous use of five or more medications — affects the majority of adults over 65. Elderly patients often see multiple specialists who may not be aware of each other's prescriptions, leading to dangerous drug combinations. Studies report that **adverse drug events from polypharmacy contribute to 100,000+ hospitalizations annually** in the US alone.
|
| 25 |
+
|
| 26 |
+
Current solutions use static risk scoring. PolypharmacyEnv goes further by framing medication review as a **sequential decision problem**, where an RL agent must strategically allocate limited query and intervention budgets to maximize patient safety — exactly the kind of resource-constrained optimization that reinforcement learning excels at.
|
| 27 |
+
|
| 28 |
+
**Reference**: Larouche, A., Durand, A., Khoury, R. & Sirois, C. (2023). [Neural Bandits for Data Mining: Searching for Dangerous Polypharmacy](https://link.springer.com/chapter/10.1007/978-3-031-36938-4_5). *Advances in Artificial Intelligence*, Springer.
|
| 29 |
+
|
| 30 |
+
---
|
| 31 |
+
|
| 32 |
+
## How OpenEnv & RL Power This
|
| 33 |
+
|
| 34 |
+
### The RL Formulation
|
| 35 |
+
|
| 36 |
+
PolypharmacyEnv frames medication review as a **Markov Decision Process (MDP)**:
|
| 37 |
+
|
| 38 |
+
- **State**: Patient profile (age, conditions, organ function) + current medication list + interaction history
|
| 39 |
+
- **Action space**: `query_ddi(drug_i, drug_j)` | `propose_intervention(target, type)` | `finish_review`
|
| 40 |
+
- **Reward**: Shaped, dense signal at every step (not sparse end-of-episode). Queries cost budget (-0.015), discovering severe DDIs earns bonus (+0.05), successful interventions earn proportional risk reduction minus cost, invalid actions are penalized (-0.15), and `finish_review` triggers a grader that returns a terminal score in [0.0, 1.0].
|
| 41 |
+
- **Constraint**: Finite query and intervention budgets, creating a resource-allocation optimization problem.
|
| 42 |
+
|
| 43 |
+
This MDP is what makes the problem fundamentally different from static risk scoring: the agent must **decide what information to acquire** (which drug pairs to query) and **which interventions to prioritize**, all under budget constraints — a sequential decision problem that RL is designed to solve.
|
| 44 |
+
|
| 45 |
+
### OpenEnv Interface
|
| 46 |
+
|
| 47 |
+
PolypharmacyEnv implements the full **OpenEnv standard**:
|
| 48 |
+
|
| 49 |
+
- **`reset()`** — Generates a new patient scenario (age, conditions, medication list)
|
| 50 |
+
- **`step(action)`** — Processes an agent action, updates regimen state, returns shaped reward
|
| 51 |
+
- **`state()`** — Returns the current episode snapshot
|
| 52 |
+
|
| 53 |
+
All models use typed Pydantic classes extending OpenEnv base types (`PolypharmacyAction`, `PolypharmacyObservation`, `PolypharmacyState`).
|
| 54 |
+
|
| 55 |
+
### What the Environment Enables
|
| 56 |
+
|
| 57 |
+
The shaped reward function provides continuous signal over the full trajectory, making this environment compatible with standard RL training approaches:
|
| 58 |
+
|
| 59 |
+
- **Policy gradient methods** (REINFORCE, PPO, GRPO): The per-step reward signal allows policy networks to learn query prioritization and intervention strategies.
|
| 60 |
+
- **OpenEnv training pipeline**: Through OpenEnv's `step()`/`reset()` HTTP interface, external RL training loops can connect to this environment and train policies without modification.
|
| 61 |
+
- **Neural Bandits (OptimNeuralTS)**: The budget-constrained query selection implements the OptimNeuralTS approach from the reference paper — Neural Thompson Sampling combined with Differential Evolution for efficient search.
|
| 62 |
+
|
| 63 |
+
### Included Agents
|
| 64 |
+
|
| 65 |
+
The repository ships with multiple agent implementations spanning rule-based, RL-trained, bandit-based, and LLM-based approaches:
|
| 66 |
+
|
| 67 |
+
- **OptimNeuralTS bandit** (`train_bandit.py`, `neural_bandits.py`): Implements the paper's core algorithm — Neural Thompson Sampling with Differential Evolution to efficiently search for dangerous drug combinations. Builds an ensemble of models across training steps for high-precision predictions.
|
| 68 |
+
- **REINFORCE-trained policy** (`train_rl.py`): A neural network policy trained via REINFORCE with learned baseline against the environment's shaped reward. Demonstrates that the MDP formulation and reward shaping enable genuine policy improvement through RL training.
|
| 69 |
+
- **Heuristic agent** (`baselines/heuristic_agent.py`): Deterministic rule-based strategy that queries high-risk drug pairs first, then intervenes on severe DDIs. Serves as a strong domain-knowledge baseline.
|
| 70 |
+
- **LLM agent** (`inference.py`): Uses an LLM (Qwen2.5-72B via OpenAI-compatible API) for zero-shot action generation. Demonstrates baseline LLM performance without RL fine-tuning.
|
| 71 |
+
- **AI suggestion endpoint** (`/agent/suggest`): LLM-powered action suggestions with rule-based guardrails for the interactive UI.
|
| 72 |
|
| 73 |
---
|
| 74 |
|
| 75 |
## Repository Structure
|
| 76 |
|
| 77 |
+
```
|
| 78 |
+
├── backend/
|
| 79 |
+
│ ├── main.py # ASGI entrypoint (uvicorn target)
|
| 80 |
+
│ ├── requirements.txt # Python dependencies
|
| 81 |
+
│ └── src/polypharmacy_env/
|
| 82 |
+
│ ├── env_core.py # OpenEnv environment: reset/step/state
|
| 83 |
+
│ ├── models.py # Typed Pydantic models (Action, Observation, State)
|
| 84 |
+
│ ├── rewards.py # Shaped reward function & regimen risk computation
|
| 85 |
+
│ ├── graders.py # Deterministic graders for 3 task difficulties
|
| 86 |
+
│ ├── tasks.py # Task configuration & episode sampling
|
| 87 |
+
│ ├── config.py # Reward hyperparameters & task parameters
|
| 88 |
+
│ ├── data_loader.py # CSV data loading with caching
|
| 89 |
+
│ ├── ddi_simulator.py # DDI lookup, Beers flags, drug substitution
|
| 90 |
+
│ ├── neural_bandits.py # NeuralTS + Differential Evolution + OptimNeuralTS
|
| 91 |
+
│ ├── api/
|
| 92 |
+
│ │ ├── app.py # FastAPI app factory via OpenEnv create_app
|
| 93 |
+
│ │ ├── routes/agent.py # POST /agent/suggest (AI-assisted actions)
|
| 94 |
+
│ │ └── routes/bandit.py # POST /bandit/predict, /bandit/screen
|
| 95 |
+
│ ├── baselines/
|
| 96 |
+
│ │ ├── heuristic_agent.py # Deterministic baseline agent
|
| 97 |
+
│ │ └── random_agent.py # Random baseline agent
|
| 98 |
+
│ ├── services/
|
| 99 |
+
│ │ └── groq_agent.py # LLM-powered action suggestions
|
| 100 |
+
│ └── tests/
|
| 101 |
+
│ ├── test_env_core.py # Environment unit tests
|
| 102 |
+
│ └── test_api.py # HTTP + WebSocket integration tests
|
| 103 |
+
├── frontend/
|
| 104 |
+
│ ├── src/
|
| 105 |
+
│ │ ├── App.jsx # React control center UI
|
| 106 |
+
│ │ └── styles.css # Production-quality dark theme
|
| 107 |
+
│ ├── package.json
|
| 108 |
+
│ └── vite.config.js
|
| 109 |
+
├── data/
|
| 110 |
+
│ ├── lookups/ # drug_metadata.csv, ddi_rules.csv, beers_criteria.csv
|
| 111 |
+
│ └── processed/ # patients_polypharmacy.csv (120 episodes)
|
| 112 |
+
├── scripts/
|
| 113 |
+
│ ├── preprocess_data.py # Synthetic data generation
|
| 114 |
+
│ ├── dev_backend.sh # Local backend runner
|
| 115 |
+
│ ├── dev_frontend.sh # Local frontend runner
|
| 116 |
+
│ └── run_validation.sh # Automated test + baseline validation
|
| 117 |
+
├── Dockerfile # Production multi-stage build (frontend + backend)
|
| 118 |
+
├── docker-compose.yml # Development orchestration
|
| 119 |
+
├── inference.py # Submission baseline inference script
|
| 120 |
+
├── train_rl.py # REINFORCE RL training script (PyTorch)
|
| 121 |
+
├── train_bandit.py # OptimNeuralTS neural bandit training
|
| 122 |
+
├── openenv.yaml # OpenEnv manifest
|
| 123 |
+
└── .env.example # Environment variable template
|
| 124 |
```
|
| 125 |
|
| 126 |
---
|
| 127 |
|
| 128 |
+
## Action & Observation Spaces
|
| 129 |
+
|
| 130 |
+
### Actions
|
| 131 |
|
| 132 |
+
| Action Type | Parameters | Description |
|
| 133 |
+
|---|---|---|
|
| 134 |
+
| `query_ddi` | `drug_id_1`, `drug_id_2` | Check a drug pair for interactions. Returns severity, recommendation, and risk score. Costs 1 query budget. |
|
| 135 |
+
| `propose_intervention` | `target_drug_id`, `intervention_type`, `proposed_new_drug_id` (opt), `rationale` (opt) | Modify the medication regimen. Types: `stop`, `dose_reduce`, `substitute`, `add_monitoring`. Costs 1 intervention budget. |
|
| 136 |
+
| `finish_review` | — | End the episode. Triggers grader evaluation and returns final score. |
|
| 137 |
|
| 138 |
+
### Observations
|
|
|
|
|
|
|
| 139 |
|
| 140 |
+
Each observation contains the full patient context:
|
| 141 |
|
| 142 |
+
| Field | Type | Description |
|
| 143 |
+
|---|---|---|
|
| 144 |
+
| `episode_id` | string | Unique episode identifier |
|
| 145 |
+
| `task_id` | string | Current task (easy_screening / budgeted_screening / complex_tradeoff) |
|
| 146 |
+
| `age`, `sex` | int, string | Patient demographics |
|
| 147 |
+
| `conditions` | list[string] | Active medical conditions |
|
| 148 |
+
| `eGFR_category`, `liver_function_category` | string | Organ function status |
|
| 149 |
+
| `current_medications` | list[MedicationEntry] | Active drugs with dose, ATC class, Beers flags |
|
| 150 |
+
| `interaction_queries` | list[InteractionQueryRecord] | History of DDI queries and results |
|
| 151 |
+
| `interventions` | list[InterventionRecord] | History of proposed interventions |
|
| 152 |
+
| `remaining_query_budget` | int | Remaining DDI query budget |
|
| 153 |
+
| `remaining_intervention_budget` | int | Remaining intervention budget |
|
| 154 |
+
| `shaped_reward` | float | Step reward signal |
|
| 155 |
+
| `done` | bool | Whether the episode has ended |
|
| 156 |
|
| 157 |
---
|
| 158 |
|
| 159 |
+
## Tasks & Difficulty Progression
|
| 160 |
+
|
| 161 |
+
| Task | Difficulty | Drugs | Query Budget | Intervention Budget | Max Steps | Description |
|
| 162 |
+
|---|---|---|---|---|---|---|
|
| 163 |
+
| **Easy Screening** | Easy | 3–5 | 4 | 2 | 10 | Small regimen with one severe DDI. Identify and resolve it. |
|
| 164 |
+
| **Budgeted Screening** | Medium | 6–10 | 8 | 3 | 20 | Multiple DDIs and Beers issues under tighter budgets. Must prioritize effectively. |
|
| 165 |
+
| **Complex Tradeoff** | Hard | 10–15 | 12 | 5 | 30 | Large regimen with critical drugs (warfarin, insulin). Balance risk reduction against regimen disruption. |
|
| 166 |
+
|
| 167 |
+
### Grading Criteria
|
| 168 |
|
| 169 |
+
- **Easy**: 50% risk reduction + 50% targeted intervention on severe DDI drugs
|
| 170 |
+
- **Medium**: 50% risk reduction + 30% intervention precision + 20% query efficiency
|
| 171 |
+
- **Hard**: Risk reduction minus penalties for excessive drug changes and stopping critical medications without substitution
|
| 172 |
+
|
| 173 |
+
All graders are deterministic, producing scores in `[0.0, 1.0]`.
|
| 174 |
|
| 175 |
---
|
| 176 |
|
| 177 |
+
## Reward Function Design
|
| 178 |
|
| 179 |
+
The shaped reward provides signal at every step (not just episode end):
|
| 180 |
|
| 181 |
+
| Event | Reward |
|
| 182 |
+
|---|---|
|
| 183 |
+
| DDI query (any) | -0.015 (budget cost) |
|
| 184 |
+
| Discovering a severe DDI | +0.05 bonus |
|
| 185 |
+
| Discovering a moderate DDI | +0.02 bonus |
|
| 186 |
+
| Successful intervention | +(risk_reduction) - 0.025 cost |
|
| 187 |
+
| Invalid action | -0.15 penalty |
|
| 188 |
+
| Episode timeout | -0.25 penalty |
|
| 189 |
+
| Finish review | +grader_score (0.0–1.0) |
|
| 190 |
|
| 191 |
+
**Regimen risk** aggregates DDI pairwise scores, Beers-criteria violation weights, and high-risk elderly drug penalties, normalized by regimen size and clipped to `[0.0, 1.0]`.
|
| 192 |
|
| 193 |
---
|
| 194 |
|
| 195 |
+
## Prerequisites
|
| 196 |
+
|
| 197 |
+
- **Python** 3.10+
|
| 198 |
+
- **Node.js** 18+ (20+ recommended)
|
| 199 |
+
- **Docker** + Docker Compose (for containerized runs)
|
| 200 |
+
|
| 201 |
+
---
|
| 202 |
|
| 203 |
+
## Setup & Local Development
|
| 204 |
|
| 205 |
+
### 1. Clone and configure
|
| 206 |
|
| 207 |
```bash
|
| 208 |
+
git clone <repo-url>
|
| 209 |
+
cd PolypharmacyEnv
|
| 210 |
+
cp .env.example .env
|
| 211 |
+
# Edit .env with your API keys if using the AI suggestion feature
|
| 212 |
```
|
| 213 |
|
| 214 |
+
### 2. Install dependencies
|
| 215 |
|
| 216 |
```bash
|
| 217 |
+
# Backend
|
| 218 |
+
pip install -r backend/requirements.txt
|
| 219 |
+
|
| 220 |
+
# Frontend
|
| 221 |
+
cd frontend && npm install && cd ..
|
| 222 |
```
|
| 223 |
|
| 224 |
+
### 3. Generate synthetic data (if not already present)
|
| 225 |
|
| 226 |
```bash
|
| 227 |
python scripts/preprocess_data.py
|
| 228 |
```
|
| 229 |
|
| 230 |
+
### 4. Start services
|
|
|
|
|
|
|
| 231 |
|
| 232 |
+
**Terminal 1 — Backend** (port 7860):
|
| 233 |
```bash
|
| 234 |
./scripts/dev_backend.sh
|
| 235 |
```
|
| 236 |
|
| 237 |
+
**Terminal 2 — Frontend** (port 5173):
|
|
|
|
| 238 |
```bash
|
| 239 |
./scripts/dev_frontend.sh
|
| 240 |
```
|
| 241 |
|
| 242 |
+
### 5. Open the application
|
| 243 |
|
| 244 |
+
- **Frontend UI**: [http://localhost:5173](http://localhost:5173)
|
| 245 |
+
- **Backend health check**: [http://localhost:7860/health](http://localhost:7860/health)
|
| 246 |
|
| 247 |
---
|
| 248 |
|
| 249 |
+
## Docker Deployment
|
| 250 |
|
| 251 |
+
### Build and run (single container — production mode)
|
| 252 |
|
| 253 |
```bash
|
| 254 |
+
docker build -t polypharmacy-env .
|
| 255 |
+
docker run -p 7860:7860 polypharmacy-env
|
| 256 |
```
|
| 257 |
|
| 258 |
+
The UI and API are both served from port 7860.
|
| 259 |
+
|
| 260 |
+
### Development mode (separate services)
|
| 261 |
|
| 262 |
```bash
|
| 263 |
+
docker compose up --build
|
| 264 |
```
|
| 265 |
|
| 266 |
+
- Backend: port 7860
|
| 267 |
+
- Frontend: port 5173
|
|
|
|
|
|
|
| 268 |
|
| 269 |
---
|
| 270 |
|
| 271 |
+
## Hugging Face Spaces Deployment
|
| 272 |
|
| 273 |
+
### 1. Create a new Space
|
|
|
|
|
|
|
| 274 |
|
| 275 |
- Go to [Hugging Face Spaces](https://huggingface.co/new-space)
|
| 276 |
- Choose **Docker** SDK
|
| 277 |
+
- Tag the Space with `openenv`
|
| 278 |
|
| 279 |
+
### 2. Set secrets and variables
|
| 280 |
|
| 281 |
+
In Space Settings → Variables and Secrets:
|
| 282 |
|
| 283 |
+
| Type | Key | Value |
|
| 284 |
+
|---|---|---|
|
| 285 |
+
| Secret | `HF_TOKEN` | Your Hugging Face API token |
|
| 286 |
+
| Variable | `API_BASE_URL` | `https://router.huggingface.co/v1` |
|
| 287 |
+
| Variable | `MODEL_NAME` | `Qwen/Qwen2.5-72B-Instruct` |
|
| 288 |
|
| 289 |
+
### 3. Push the repository to the Space
|
| 290 |
|
| 291 |
+
```bash
|
| 292 |
+
git remote add space https://huggingface.co/spaces/<your-username>/<space-name>
|
| 293 |
+
git push space main  # Spaces deploy from 'main'; use `git push space master:main` if your local branch is named master
|
| 294 |
+
```
|
| 295 |
|
| 296 |
+
### 4. Verify
|
| 297 |
|
| 298 |
- Space root URL loads the React UI
|
| 299 |
- `/health` returns healthy status
|
| 300 |
+
- `/reset`, `/step`, `/state` respond to API calls
|
| 301 |
|
| 302 |
+
---
|
| 303 |
+
|
| 304 |
+
## API Reference
|
| 305 |
+
|
| 306 |
+
### OpenEnv Endpoints
|
| 307 |
+
|
| 308 |
+
| Method | Path | Description |
|
| 309 |
+
|---|---|---|
|
| 310 |
+
| `POST` | `/reset` | Start a new episode. Body: `{ "task_id": "easy_screening" }` |
|
| 311 |
+
| `POST` | `/step` | Execute an action. Body: `{ "action": { "action_type": "query_ddi", ... } }` |
|
| 312 |
+
| `GET` | `/state` | Get current episode state |
|
| 313 |
+
| `GET` | `/health` | Health check |
|
| 314 |
+
| `GET` | `/schema` | Action/observation schema |
|
| 315 |
+
| `WS` | `/ws` | WebSocket for stateful multi-step sessions |
|
| 316 |
|
| 317 |
+
### Additional Endpoints
|
| 318 |
+
|
| 319 |
+
| Method | Path | Description |
|
| 320 |
+
|---|---|---|
|
| 321 |
+
| `POST` | `/agent/suggest` | AI-powered action suggestion. Body: `{ "observation": {...} }` |
|
| 322 |
|
| 323 |
---
|
| 324 |
|
| 325 |
+
## Running the Baseline Inference
|
| 326 |
|
| 327 |
+
```bash
|
| 328 |
+
# Set required environment variables
|
| 329 |
+
export API_BASE_URL="https://router.huggingface.co/v1"
|
| 330 |
+
export MODEL_NAME="Qwen/Qwen2.5-72B-Instruct"
|
| 331 |
+
export HF_TOKEN="your-token"
|
| 332 |
|
| 333 |
+
# Start the environment server (in another terminal)
|
| 334 |
+
./scripts/dev_backend.sh
|
|
|
|
|
|
|
|
|
|
|
|
|
| 335 |
|
| 336 |
+
# Run inference
|
| 337 |
+
python inference.py
|
| 338 |
+
```
|
| 339 |
|
| 340 |
+
The inference script runs all 3 tasks and emits structured `[START]`, `[STEP]`, `[END]` logs for the evaluator.
|
| 341 |
|
| 342 |
---
|
| 343 |
|
| 344 |
+
## RL Training (REINFORCE with Learned Baseline)
|
| 345 |
+
|
| 346 |
+
The repository includes `train_rl.py` — a complete **REINFORCE policy gradient** training loop that trains a neural network policy directly against the environment's shaped reward signal.
|
| 347 |
+
|
| 348 |
+
### How It Works
|
| 349 |
|
| 350 |
+
| Component | Description |
|
| 351 |
+
|---|---|
|
| 352 |
+
| **State encoder** | 16-dimensional feature vector: med count, high-risk drug count, Beers-flagged drugs, budget utilization, query outcomes (severe/moderate fractions), step progress, pair coverage |
|
| 353 |
+
| **Policy network** | 3-layer MLP (16 → 128 → 128 → 166) with ReLU, outputs masked logits over discrete action space |
|
| 354 |
+
| **Value baseline** | 3-layer MLP (16 → 128 → 64 → 1) trained with MSE against discounted returns |
|
| 355 |
+
| **Action space** | 166 discrete actions: 105 query_ddi pairs (C(15,2)), 60 interventions (4 types × 15 slots), 1 finish_review |
|
| 356 |
+
| **Action masking** | Invalid actions (exhausted budgets, already-queried pairs, empty drug slots) are masked to `-inf` before softmax |
|
| 357 |
+
| **Optimization** | REINFORCE with advantage (return - baseline), entropy bonus for exploration, gradient clipping |
|
| 358 |
+
|
| 359 |
+
### Training
|
| 360 |
|
| 361 |
```bash
|
| 362 |
+
# Install PyTorch (CPU is sufficient)
|
| 363 |
+
pip install torch --index-url https://download.pytorch.org/whl/cpu
|
| 364 |
+
|
| 365 |
+
# Train on easy task (fast, ~30s)
|
| 366 |
+
python train_rl.py --task easy_screening --episodes 200
|
| 367 |
+
|
| 368 |
+
# Train on medium task
|
| 369 |
+
python train_rl.py --task budgeted_screening --episodes 500
|
| 370 |
+
|
| 371 |
+
# Train on hard task (longer episodes)
|
| 372 |
+
python train_rl.py --task complex_tradeoff --episodes 500 --batch-size 10
|
| 373 |
+
|
| 374 |
+
# Full options
|
| 375 |
+
python train_rl.py --task easy_screening --episodes 200 \
|
| 376 |
+
--lr 0.0003 --gamma 0.99 --entropy-coeff 0.02 \
|
| 377 |
+
--hidden-dim 128 --batch-size 5 --print-every 10
|
| 378 |
+
```
|
| 379 |
+
|
| 380 |
+
**Outputs:**
|
| 381 |
+
- Policy checkpoints: `backend/src/polypharmacy_env/checkpoints/best_{task}.pt` and `final_{task}.pt`
|
| 382 |
+
- Training metrics: `training_metrics.json` (per-episode rewards, grader scores, losses)
|
| 383 |
+
|
| 384 |
+
### Observed Training Results
|
| 385 |
+
|
| 386 |
+
| Task | Episodes | Greedy Eval (Grader Score) | Stochastic Eval |
|
| 387 |
+
|---|---|---|---|
|
| 388 |
+
| Easy Screening | 200 | **0.698** | 0.475 |
|
| 389 |
+
| Budgeted Screening | 200 | **0.195** | 0.170 |
|
| 390 |
+
| Complex Tradeoff | 200 | **0.040** | 0.035 |
|
| 391 |
+
|
| 392 |
+
The easy task shows clear policy improvement. Medium and hard tasks benefit from more episodes (500+) and hyperparameter tuning — the larger action spaces and longer episodes create a harder credit assignment problem, exactly as designed.
|
| 393 |
+
|
| 394 |
+
### Integration with OpenEnv Training Pipeline
|
| 395 |
+
|
| 396 |
+
For production-scale training, this environment is compatible with **TRL's `GRPOTrainer`** via OpenEnv's standard interface:
|
| 397 |
+
|
| 398 |
+
```python
|
| 399 |
+
# Conceptual integration with TRL GRPO
|
| 400 |
+
from trl import GRPOTrainer
|
| 401 |
+
from openenv import GenericEnvClient
|
| 402 |
+
|
| 403 |
+
def rollout_func(prompts, trainer):
|
| 404 |
+
env = GenericEnvClient("ws://localhost:7860/ws")
|
| 405 |
+
# ... collect trajectories with token-level logprobs
|
| 406 |
+
# ... return prompt_ids, completion_ids, logprobs, rewards
|
| 407 |
+
|
| 408 |
+
trainer = GRPOTrainer(model, rollout_function=rollout_func, ...)
|
| 409 |
+
trainer.train()
|
| 410 |
```
|
| 411 |
|
| 412 |
+
The included `train_rl.py` demonstrates the core RL loop with a lightweight MLP policy. For LLM-based policies, connect TRL/veRL/SkyRL to this environment via the WebSocket or HTTP interface.
|
| 413 |
+
|
| 414 |
+
---
|
| 415 |
+
|
| 416 |
+
## Neural Bandit Training (OptimNeuralTS)
|
| 417 |
+
|
| 418 |
+
The repository implements the **OptimNeuralTS** algorithm from the reference paper. This combines Neural Thompson Sampling with Differential Evolution to efficiently search for dangerous drug combinations in a large combinatorial space.
|
| 419 |
+
|
| 420 |
+
### How OptimNeuralTS Works
|
| 421 |
+
|
| 422 |
+
| Phase | What Happens |
|
| 423 |
+
|---|---|
|
| 424 |
+
| **Warm-up** | Randomly sample drug combinations and observe their risk scores to initialize the model's understanding |
|
| 425 |
+
| **Neural Thompson Sampling** | A neural network predicts risk for any drug combination, while gradient-based uncertainty drives exploration toward combinations that could be dangerous |
|
| 426 |
+
| **Differential Evolution** | Evolves a population of candidate drug combinations, guided by the neural network, to propose new combinations worth investigating |
|
| 427 |
+
| **Nearest-neighbor mapping** | Since DE can suggest combinations not in the dataset, we map to the closest real combination using Hamming distance |
|
| 428 |
+
| **Ensemble building** | Each training step saves a model snapshot; the final ensemble combines all snapshots for high-precision predictions |
|
| 429 |
+
|
| 430 |
+
### Key Components (in `neural_bandits.py`)
|
| 431 |
+
|
| 432 |
+
| Component | Description |
|
| 433 |
+
|---|---|
|
| 434 |
+
| `RewardNetwork` | Neural network that predicts the Relative Risk (RR) for a multi-hot drug combination vector |
|
| 435 |
+
| `NeuralTS` | Thompson Sampling agent using gradient-based uncertainty: `s_t(x) = sqrt(λ · g(x)^T · U^{-1} · g(x))` |
|
| 436 |
+
| `differential_evolution()` | DE best/1/bin optimization over multi-hot feature space |
|
| 437 |
+
| `OptimNeuralTS` | Full pipeline: warm-up → NeuralTS+DE exploration → ensemble building |
|
| 438 |
+
|
| 439 |
+
### Training
|
| 440 |
|
| 441 |
```bash
|
| 442 |
+
# Quick run (small dataset, fast)
|
| 443 |
+
python train_bandit.py --total-steps 500 --warmup-steps 100
|
| 444 |
+
|
| 445 |
+
# Full training (closer to paper settings)
|
| 446 |
+
python train_bandit.py --total-steps 3000 --warmup-steps 500 --n-combinations 10000
|
| 447 |
+
|
| 448 |
+
# Custom DE parameters
|
| 449 |
+
python train_bandit.py --de-population 32 --de-steps 16 --de-crossover 0.9
|
| 450 |
+
|
| 451 |
+
# All options
|
| 452 |
+
python train_bandit.py --help
|
| 453 |
```
|
| 454 |
|
| 455 |
+
**Outputs:**
|
| 456 |
+
- Ensemble model: `backend/src/polypharmacy_env/checkpoints/bandit_ensemble.pt`
|
| 457 |
+
- Training metrics: `bandit_metrics.json` (precision, recall, patterns detected at each eval step)
|
| 458 |
+
|
| 459 |
+
### API Endpoints
|
| 460 |
+
|
| 461 |
+
The trained ensemble is also accessible via API:
|
| 462 |
+
|
| 463 |
+
| Method | Path | Description |
|
| 464 |
+
|---|---|---|
|
| 465 |
+
| `POST` | `/bandit/predict` | Predict risk for a single drug combination |
|
| 466 |
+
| `POST` | `/bandit/screen` | Screen multiple combinations in bulk |
|
| 467 |
+
| `GET` | `/bandit/metrics` | Get current bandit training metrics |
|
| 468 |
+
|
| 469 |
+
---
|
| 470 |
+
|
| 471 |
+
## Testing & Validation
|
| 472 |
|
| 473 |
```bash
|
| 474 |
+
# Unit tests
|
| 475 |
+
python -m pytest backend/src/polypharmacy_env/tests -v
|
| 476 |
+
|
| 477 |
+
# Full validation (tests + heuristic baseline)
|
| 478 |
+
./scripts/run_validation.sh
|
| 479 |
+
|
| 480 |
+
# OpenEnv spec validation
|
| 481 |
openenv validate
|
|
|
|
| 482 |
```
|
| 483 |
|
| 484 |
---
|
| 485 |
|
| 486 |
+
## Data Sources & Future Plans
|
| 487 |
+
|
| 488 |
+
### Current Implementation
|
| 489 |
+
|
| 490 |
+
- **Drug interaction data**: Currently extracted from curated clinical databases and research literature, yielding 24 DDI pairs across 33 drugs, 15 Beers criteria entries, and 120 patient episodes across 3 difficulty levels. Data is stored as CSV for deterministic, reproducible evaluation.
|
| 491 |
+
- **RL training**: A lightweight REINFORCE policy gradient training loop (`train_rl.py`) trains a neural network policy (MLP) directly against the environment's shaped reward signal. This validates the MDP formulation and demonstrates that the reward shaping enables genuine policy improvement. The trained policy achieves a 0.698 grader score on easy screening after 200 episodes.
|
| 492 |
|
| 493 |
+
### Planned Enhancements
|
| 494 |
+
|
| 495 |
+
- **Full-scale GRPO training on GPU**: We are provisioning AWS GPU resources (A100/H100 instances) to run full-scale GRPO (Group Relative Policy Optimization) training using TRL's `GRPOTrainer` with LLM-based policies. This will train language models to generate optimal clinical actions by collecting batched rollouts against the environment and computing policy gradient updates on token-level log-probabilities. The OpenEnv WebSocket interface enables high-throughput parallel rollout collection needed for efficient GRPO training.
|
| 496 |
+
- **LLM fine-tuning via OpenEnv training pipeline**: Integrate with TRL, veRL, and SkyRL frameworks to fine-tune open-weight LLMs (Llama 3, Qwen 2.5) using the environment's shaped reward as the RL training signal, producing specialized clinical pharmacist agents.
|
| 497 |
+
- **Live drug database integration**: Connect directly to established drug interaction databases (DrugBank, RxNorm, FDA Adverse Event Reporting System) for real-time DDI lookup instead of static CSV files, enabling the environment to scale to thousands of drug combinations.
|
| 498 |
+
- **EHR integration pipeline**: Develop FHIR-compatible data ingestion so the environment can accept de-identified electronic health record data, making it applicable to real hospital deployments.
|
| 499 |
+
- **Multi-agent training**: Extend the environment to support multi-agent scenarios where specialist agents (cardiologist, endocrinologist, etc.) must coordinate on a shared patient regimen.
|
| 500 |
+
- **Pharmacogenomics layer**: Incorporate genetic variant data (CYP450 metabolizer status) to personalize drug interaction severity, adding a pharmacogenomics dimension to the RL training signal.
|
| 501 |
+
|
| 502 |
+
---
|
| 503 |
+
|
| 504 |
+
## Architecture & Design Decisions
|
| 505 |
+
|
| 506 |
+
- **OpenEnv compliance**: Full typed Pydantic models for Action, Observation, and State. Environment extends `openenv.core.env_server.interfaces.Environment`.
|
| 507 |
+
- **Shaped rewards**: Continuous reward signal at every step to enable efficient RL training (not sparse end-of-episode only).
|
| 508 |
+
- **Budget constraints**: Query and intervention budgets create a resource-allocation problem that makes the RL optimization non-trivial.
|
| 509 |
+
- **Critical drug handling**: The hard task penalizes stopping critical medications (warfarin, insulin, etc.) without substitution, teaching the agent about real-world clinical constraints.
|
| 510 |
+
- **Deterministic graders**: All graders produce reproducible scores for consistent evaluation.
|
| 511 |
|
| 512 |
---
|
| 513 |
|
| 514 |
## Troubleshooting
|
| 515 |
|
| 516 |
+
| Issue | Solution |
|
| 517 |
+
|---|---|
|
| 518 |
+
| `ModuleNotFoundError: polypharmacy_env` | Start backend via `./scripts/dev_backend.sh` from repo root |
|
| 519 |
+
| `/agent/suggest` returns errors | Check `.env` for valid API keys, restart backend |
|
| 520 |
+
| UI shows stale data | Hard refresh browser (Ctrl+Shift+R), click Reset Episode |
|
| 521 |
+
| Docker build fails | Ensure Docker has at least 4GB memory allocated |
|
| 522 |
+
| WebSocket connection refused | Verify backend is running on port 7860 |
|
| 523 |
+
|
| 524 |
+
---
|
| 525 |
+
|
| 526 |
+
## License
|
| 527 |
+
|
| 528 |
+
MIT
|
backend/requirements.txt
CHANGED
|
@@ -7,3 +7,4 @@ openenv-core>=0.2.0
|
|
| 7 |
openai>=1.0.0
|
| 8 |
python-dotenv>=1.0.0
|
| 9 |
pytest>=7.0.0
|
|
|
|
|
|
| 7 |
openai>=1.0.0
|
| 8 |
python-dotenv>=1.0.0
|
| 9 |
pytest>=7.0.0
|
| 10 |
+
torch>=2.0.0
|
backend/src/polypharmacy_env/api/app.py
CHANGED
|
@@ -3,6 +3,7 @@
|
|
| 3 |
from __future__ import annotations
|
| 4 |
|
| 5 |
from pathlib import Path
|
|
|
|
| 6 |
|
| 7 |
from dotenv import load_dotenv
|
| 8 |
from fastapi import HTTPException
|
|
@@ -14,6 +15,7 @@ from starlette.responses import FileResponse
|
|
| 14 |
from ..env_core import PolypharmacyEnv
|
| 15 |
from ..models import PolypharmacyAction, PolypharmacyObservation
|
| 16 |
from .routes.agent import router as agent_router
|
|
|
|
| 17 |
|
| 18 |
load_dotenv()
|
| 19 |
|
|
@@ -31,6 +33,28 @@ class SPAStaticFiles(StaticFiles):
|
|
| 31 |
raise HTTPException(status_code=404, detail="Not Found")
|
| 32 |
|
| 33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
def create_polypharmacy_app():
|
| 35 |
app = create_app(
|
| 36 |
PolypharmacyEnv,
|
|
@@ -39,6 +63,61 @@ def create_polypharmacy_app():
|
|
| 39 |
env_name="polypharmacy_env",
|
| 40 |
)
|
| 41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
app.add_middleware(
|
| 43 |
CORSMiddleware,
|
| 44 |
allow_origins=[
|
|
@@ -50,6 +129,7 @@ def create_polypharmacy_app():
|
|
| 50 |
allow_headers=["*"],
|
| 51 |
)
|
| 52 |
app.include_router(agent_router)
|
|
|
|
| 53 |
|
| 54 |
# In Docker Space deployment, serve built frontend from same container.
|
| 55 |
project_root = Path(__file__).resolve().parents[4]
|
|
|
|
| 3 |
from __future__ import annotations
|
| 4 |
|
| 5 |
from pathlib import Path
|
| 6 |
+
from typing import Any, Dict, Optional
|
| 7 |
|
| 8 |
from dotenv import load_dotenv
|
| 9 |
from fastapi import HTTPException
|
|
|
|
| 15 |
from ..env_core import PolypharmacyEnv
|
| 16 |
from ..models import PolypharmacyAction, PolypharmacyObservation
|
| 17 |
from .routes.agent import router as agent_router
|
| 18 |
+
from .routes.bandit import router as bandit_router
|
| 19 |
|
| 20 |
load_dotenv()
|
| 21 |
|
|
|
|
| 33 |
raise HTTPException(status_code=404, detail="Not Found")
|
| 34 |
|
| 35 |
|
| 36 |
+
# ── Stateful singleton for HTTP-based inference ──────────────────────────────
|
| 37 |
+
# OpenEnv's built-in HTTP /reset and /step handlers are stateless (they create
|
| 38 |
+
# a fresh env per call). The WebSocket /ws endpoint handles stateful sessions
|
| 39 |
+
# for the frontend. For the inference.py script (and the evaluator), we need
|
| 40 |
+
# HTTP endpoints that maintain state across reset → step → step → ... calls.
|
| 41 |
+
# We override OpenEnv's default routes with stateful versions.
|
| 42 |
+
|
| 43 |
+
_http_env: Optional[PolypharmacyEnv] = None
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def _get_or_create_env() -> PolypharmacyEnv:
    # Lazily construct a single process-wide environment so the stateful
    # HTTP /reset, /step, and /state handlers share state across calls.
    # NOTE(review): not thread-safe; assumes a single-worker ASGI server — confirm.
    global _http_env
    if _http_env is None:
        _http_env = PolypharmacyEnv()
    return _http_env
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def _serialize_obs(obs: PolypharmacyObservation) -> Dict[str, Any]:
|
| 54 |
+
"""Convert observation to JSON-serializable dict."""
|
| 55 |
+
return obs.model_dump() if hasattr(obs, "model_dump") else obs.dict()
|
| 56 |
+
|
| 57 |
+
|
| 58 |
def create_polypharmacy_app():
|
| 59 |
app = create_app(
|
| 60 |
PolypharmacyEnv,
|
|
|
|
| 63 |
env_name="polypharmacy_env",
|
| 64 |
)
|
| 65 |
|
| 66 |
+
# ── Override stateless HTTP routes with stateful ones ─────────────────
|
| 67 |
+
|
| 68 |
+
# Remove OpenEnv's default /reset and /step routes so ours take priority
|
| 69 |
+
new_routes = []
|
| 70 |
+
for route in app.routes:
|
| 71 |
+
path = getattr(route, "path", "")
|
| 72 |
+
if path in ("/reset", "/step", "/state"):
|
| 73 |
+
continue
|
| 74 |
+
new_routes.append(route)
|
| 75 |
+
app.routes[:] = new_routes
|
| 76 |
+
|
| 77 |
+
@app.post("/reset")
|
| 78 |
+
async def stateful_reset(body: Dict[str, Any] = {}):
|
| 79 |
+
env = _get_or_create_env()
|
| 80 |
+
task_id = body.get("task_id", None)
|
| 81 |
+
kwargs = {}
|
| 82 |
+
if task_id:
|
| 83 |
+
kwargs["task_id"] = task_id
|
| 84 |
+
seed = body.get("seed", None)
|
| 85 |
+
episode_id = body.get("episode_id", None)
|
| 86 |
+
obs = env.reset(seed=seed, episode_id=episode_id, **kwargs)
|
| 87 |
+
obs_data = _serialize_obs(obs)
|
| 88 |
+
return {
|
| 89 |
+
"observation": obs_data,
|
| 90 |
+
"reward": 0.0,
|
| 91 |
+
"done": False,
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
@app.post("/step")
|
| 95 |
+
async def stateful_step(body: Dict[str, Any] = {}):
|
| 96 |
+
env = _get_or_create_env()
|
| 97 |
+
action_data = body.get("action", body)
|
| 98 |
+
try:
|
| 99 |
+
action = PolypharmacyAction(**action_data)
|
| 100 |
+
except Exception as e:
|
| 101 |
+
raise HTTPException(status_code=422, detail=str(e))
|
| 102 |
+
obs = env.step(action)
|
| 103 |
+
obs_data = _serialize_obs(obs)
|
| 104 |
+
# Extract metadata for top-level info
|
| 105 |
+
metadata = obs_data.get("metadata", {}) or {}
|
| 106 |
+
return {
|
| 107 |
+
"observation": obs_data,
|
| 108 |
+
"reward": obs_data.get("shaped_reward", 0.0),
|
| 109 |
+
"done": obs_data.get("done", False),
|
| 110 |
+
"info": metadata,
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
@app.get("/state")
|
| 114 |
+
async def stateful_state():
|
| 115 |
+
env = _get_or_create_env()
|
| 116 |
+
state = env.state
|
| 117 |
+
return state.model_dump() if hasattr(state, "model_dump") else state.dict()
|
| 118 |
+
|
| 119 |
+
# ── Middleware & extra routes ─────────────────────────────────────────
|
| 120 |
+
|
| 121 |
app.add_middleware(
|
| 122 |
CORSMiddleware,
|
| 123 |
allow_origins=[
|
|
|
|
| 129 |
allow_headers=["*"],
|
| 130 |
)
|
| 131 |
app.include_router(agent_router)
|
| 132 |
+
app.include_router(bandit_router)
|
| 133 |
|
| 134 |
# In Docker Space deployment, serve built frontend from same container.
|
| 135 |
project_root = Path(__file__).resolve().parents[4]
|
backend/src/polypharmacy_env/api/routes/bandit.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""API routes for neural bandit predictions and risk screening."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Any, Dict, List, Optional
|
| 6 |
+
|
| 7 |
+
from fastapi import APIRouter, HTTPException
|
| 8 |
+
from pydantic import BaseModel, Field
|
| 9 |
+
|
| 10 |
+
router = APIRouter(prefix="/bandit", tags=["bandit"])
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# Lazy-loaded module-level bandit instance
|
| 14 |
+
_bandit_instance = None
|
| 15 |
+
_bandit_config: Dict[str, Any] = {}
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def _get_bandit():
    # Lazily build a module-level OptimNeuralTS instance on first use.
    # NOTE(review): _bandit_config is never populated anywhere visible here,
    # so n_drugs always falls back to 33 — confirm intended wiring.
    # The instance is created untrained; predictions before training reflect
    # random initialization. Hyperparameters mirror the quick-run settings.
    global _bandit_instance, _bandit_config
    if _bandit_instance is None:
        from ...neural_bandits import OptimNeuralTS

        n_drugs = _bandit_config.get("n_drugs", 33)
        _bandit_instance = OptimNeuralTS(
            input_dim=n_drugs,
            hidden=64,
            reg_lambda=1.0,
            exploration_factor=1.0,
            lr=0.01,
            train_epochs=50,
            warmup_steps=50,
            total_steps=500,
            retrain_every=10,
            de_population=16,
            de_crossover=0.9,
            de_weight=1.0,
            de_steps=8,
        )
    return _bandit_instance
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class DrugComboRequest(BaseModel):
    """Request body for /bandit/predict: the drug combination to score."""

    drug_ids: List[str] = Field(..., description="List of drug IDs in the combination")
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class RiskPrediction(BaseModel):
    """Response body for /bandit/predict: ensemble risk estimate for one combination."""

    predicted_rr: float = Field(..., description="Predicted relative risk (association measure)")
    lower_bound: float = Field(..., description="Lower confidence bound (mean - 3*std)")
    is_potentially_harmful: bool = Field(..., description="True if lower_bound > 1.1 threshold")
    n_models_in_ensemble: int = Field(..., description="Number of models in the ensemble")
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class BanditMetrics(BaseModel):
    """Snapshot of neural bandit training progress.

    All fields default to an "untrained" state so the endpoint can respond
    even before any training has occurred.
    """

    total_steps: int = 0
    warmup_steps: int = 0
    n_ensemble_models: int = 0
    avg_reward: float = 0.0
    max_reward: float = 0.0
    # Training phase label; "not_started" before any training begins.
    phase: str = "not_started"
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class ScreeningResult(BaseModel):
    """Per-combination screening outcome returned by /bandit/screen."""

    drug_ids: List[str]
    predicted_rr: float
    lower_bound: float
    is_potentially_harmful: bool
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class BulkScreenResponse(BaseModel):
    """Aggregate response for /bandit/screen: all results plus summary counts."""

    results: List[ScreeningResult]
    flagged_count: int
    total_screened: int
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
@router.post("/predict", response_model=RiskPrediction)
|
| 76 |
+
def predict_risk(payload: DrugComboRequest) -> RiskPrediction:
|
| 77 |
+
"""Predict risk for a drug combination using the neural bandit ensemble.
|
| 78 |
+
|
| 79 |
+
Uses the trained ensemble of models from OptimNeuralTS to estimate
|
| 80 |
+
the relative risk (RR) for a given drug combination. A pessimistic
|
| 81 |
+
lower confidence bound is used to minimize false positives.
|
| 82 |
+
"""
|
| 83 |
+
try:
|
| 84 |
+
import torch
|
| 85 |
+
from ...data_loader import load_drug_metadata
|
| 86 |
+
|
| 87 |
+
bandit = _get_bandit()
|
| 88 |
+
metadata = load_drug_metadata()
|
| 89 |
+
all_drug_ids = sorted(metadata.keys())
|
| 90 |
+
|
| 91 |
+
# Build multi-hot vector
|
| 92 |
+
x = torch.zeros(len(all_drug_ids))
|
| 93 |
+
for drug_id in payload.drug_ids:
|
| 94 |
+
if drug_id in all_drug_ids:
|
| 95 |
+
idx = all_drug_ids.index(drug_id)
|
| 96 |
+
x[idx] = 1.0
|
| 97 |
+
|
| 98 |
+
result = bandit.predict_risk(x)
|
| 99 |
+
return RiskPrediction(**result)
|
| 100 |
+
except Exception as exc:
|
| 101 |
+
raise HTTPException(status_code=500, detail=str(exc)) from exc
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
@router.get("/metrics", response_model=BanditMetrics)
|
| 105 |
+
def get_bandit_metrics() -> BanditMetrics:
|
| 106 |
+
"""Return current neural bandit training metrics."""
|
| 107 |
+
try:
|
| 108 |
+
bandit = _get_bandit()
|
| 109 |
+
metrics = bandit.get_metrics()
|
| 110 |
+
return BanditMetrics(**metrics)
|
| 111 |
+
except Exception as exc:
|
| 112 |
+
raise HTTPException(status_code=500, detail=str(exc)) from exc
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
@router.post("/screen", response_model=BulkScreenResponse)
|
| 116 |
+
def screen_combinations(payload: Dict[str, Any]) -> BulkScreenResponse:
|
| 117 |
+
"""Screen multiple drug combinations for potential risk.
|
| 118 |
+
|
| 119 |
+
Body: { "combinations": [["DRUG_A", "DRUG_B"], ...] }
|
| 120 |
+
"""
|
| 121 |
+
try:
|
| 122 |
+
import torch
|
| 123 |
+
from ...data_loader import load_drug_metadata
|
| 124 |
+
|
| 125 |
+
combos = payload.get("combinations", [])
|
| 126 |
+
if not combos:
|
| 127 |
+
raise HTTPException(status_code=400, detail="No combinations provided")
|
| 128 |
+
|
| 129 |
+
bandit = _get_bandit()
|
| 130 |
+
metadata = load_drug_metadata()
|
| 131 |
+
all_drug_ids = sorted(metadata.keys())
|
| 132 |
+
|
| 133 |
+
results = []
|
| 134 |
+
for drug_ids in combos:
|
| 135 |
+
x = torch.zeros(len(all_drug_ids))
|
| 136 |
+
for drug_id in drug_ids:
|
| 137 |
+
if drug_id in all_drug_ids:
|
| 138 |
+
idx = all_drug_ids.index(drug_id)
|
| 139 |
+
x[idx] = 1.0
|
| 140 |
+
|
| 141 |
+
pred = bandit.predict_risk(x)
|
| 142 |
+
results.append(ScreeningResult(
|
| 143 |
+
drug_ids=drug_ids,
|
| 144 |
+
predicted_rr=pred["predicted_rr"],
|
| 145 |
+
lower_bound=pred["lower_bound"],
|
| 146 |
+
is_potentially_harmful=pred["is_potentially_harmful"],
|
| 147 |
+
))
|
| 148 |
+
|
| 149 |
+
flagged = sum(1 for r in results if r.is_potentially_harmful)
|
| 150 |
+
return BulkScreenResponse(
|
| 151 |
+
results=results,
|
| 152 |
+
flagged_count=flagged,
|
| 153 |
+
total_screened=len(results),
|
| 154 |
+
)
|
| 155 |
+
except HTTPException:
|
| 156 |
+
raise
|
| 157 |
+
except Exception as exc:
|
| 158 |
+
raise HTTPException(status_code=500, detail=str(exc)) from exc
|
backend/src/polypharmacy_env/config.py
CHANGED
|
@@ -18,11 +18,15 @@ DRUG_METADATA_CSV = LOOKUPS_DIR / "drug_metadata.csv"
|
|
| 18 |
PATIENTS_CSV = PROCESSED_DIR / "patients_polypharmacy.csv"
|
| 19 |
|
| 20 |
# ── Reward hyper-parameters ──────────────────────────────────────────────────
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
# ── Task parameters ─────────────────────────────────────────────────────────
|
| 28 |
|
|
|
|
| 18 |
PATIENTS_CSV = PROCESSED_DIR / "patients_polypharmacy.csv"
|
| 19 |
|
| 20 |
# ── Reward hyper-parameters ──────────────────────────────────────────────────
|
| 21 |
+
# Tuned for clear RL signal: discovering severe DDIs should notably outweigh
|
| 22 |
+
# query cost, interventions should have meaningful cost-benefit tradeoffs,
|
| 23 |
+
# and invalid/timeout penalties should strongly discourage degenerate policies.
|
| 24 |
+
QUERY_COST: float = 0.015 # each query slightly costs budget
|
| 25 |
+
INTERVENTION_COST: float = 0.025 # interventions are more expensive to discourage spam
|
| 26 |
+
INVALID_ACTION_PENALTY: float = 0.15 # strong deterrent for malformed actions
|
| 27 |
+
TIMEOUT_PENALTY: float = 0.25 # harsh enough to encourage timely finish_review
|
| 28 |
+
SEVERE_DDI_DISCOVERY_BONUS: float = 0.05 # rewarding high-value information discovery
|
| 29 |
+
MODERATE_DDI_DISCOVERY_BONUS: float = 0.02 # smaller bonus for moderate findings
|
| 30 |
|
| 31 |
# ── Task parameters ─────────────────────────────────────────────────────────
|
| 32 |
|
backend/src/polypharmacy_env/env_core.py
CHANGED
|
@@ -214,6 +214,7 @@ class PolypharmacyEnv(
|
|
| 214 |
self._current_risk, self._current_risk,
|
| 215 |
"query_ddi",
|
| 216 |
discovered_severe=(result.severity == "severe"),
|
|
|
|
| 217 |
)
|
| 218 |
info["ddi_result"] = {
|
| 219 |
"severity": result.severity,
|
|
|
|
| 214 |
self._current_risk, self._current_risk,
|
| 215 |
"query_ddi",
|
| 216 |
discovered_severe=(result.severity == "severe"),
|
| 217 |
+
discovered_moderate=(result.severity == "moderate"),
|
| 218 |
)
|
| 219 |
info["ddi_result"] = {
|
| 220 |
"severity": result.severity,
|
backend/src/polypharmacy_env/neural_bandits.py
ADDED
|
@@ -0,0 +1,484 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Neural Thompson Sampling (NeuralTS) with Differential Evolution (DE).
|
| 2 |
+
|
| 3 |
+
Implements the OptimNeuralTS algorithm from:
|
| 4 |
+
Larouche et al., "Neural Bandits for Data Mining: Searching for Dangerous Polypharmacy"
|
| 5 |
+
https://link.springer.com/chapter/10.1007/978-3-031-36938-4_5
|
| 6 |
+
|
| 7 |
+
Key components:
|
| 8 |
+
- NeuralTS: Neural network with gradient-based uncertainty for Thompson Sampling
|
| 9 |
+
- DE (best/1/bin): Differential Evolution to generate candidate drug combinations
|
| 10 |
+
- OptimNeuralTS: Full pipeline combining warm-up, NeuralTS, DE, and ensemble predictions
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
import math
|
| 16 |
+
import random
|
| 17 |
+
from copy import deepcopy
|
| 18 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
import torch.nn as nn
|
| 22 |
+
import torch.nn.functional as F
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# ---------------------------------------------------------------------------
|
| 26 |
+
# Reward predictor network (predicts Relative Risk for a drug combination)
|
| 27 |
+
# ---------------------------------------------------------------------------
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class RewardNetwork(nn.Module):
    """Neural network f(x; theta) that predicts the association measure (RR)
    for a multi-hot drug combination vector."""

    def __init__(self, input_dim: int, hidden: int = 64) -> None:
        super().__init__()
        # Keep fc1/fc2/fc3 as named attributes (not nn.Sequential) so the
        # state_dict keys remain stable for ensemble snapshots/checkpoints.
        self.fc1 = nn.Linear(input_dim, hidden)
        self.fc2 = nn.Linear(hidden, hidden)
        self.fc3 = nn.Linear(hidden, 1)
        self._input_dim = input_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = x
        for layer in (self.fc1, self.fc2):
            out = F.relu(layer(out))
        return self.fc3(out).squeeze(-1)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
# ---------------------------------------------------------------------------
|
| 48 |
+
# NeuralTS: gradient-based uncertainty estimation
|
| 49 |
+
# ---------------------------------------------------------------------------
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class NeuralTS:
|
| 53 |
+
"""Neural Thompson Sampling agent.
|
| 54 |
+
|
| 55 |
+
Uses the neural network gradient to estimate a posterior distribution
|
| 56 |
+
over the predicted reward, enabling exploration via Thompson Sampling.
|
| 57 |
+
|
| 58 |
+
At each step t, for an action with features x:
|
| 59 |
+
f_t(x) = network prediction (mean)
|
| 60 |
+
s_t(x) = sqrt(lambda * g(x)^T U_t^{-1} g(x)) (std)
|
| 61 |
+
where g(x) is the gradient of the network output w.r.t. parameters,
|
| 62 |
+
and U_t is the diagonal design matrix accumulated over past actions.
|
| 63 |
+
"""
|
| 64 |
+
|
| 65 |
+
    def __init__(
        self,
        input_dim: int,
        hidden: int = 64,
        reg_lambda: float = 1.0,
        exploration_factor: float = 1.0,
        lr: float = 0.01,
        train_epochs: int = 100,
    ) -> None:
        """Set up the NeuralTS agent.

        Args:
            input_dim: Dimension of the multi-hot drug-combination vector.
            hidden: Hidden-layer width of the reward network.
            reg_lambda: Regularization strength lambda; also initializes the
                diagonal design matrix U to lambda * I.
            exploration_factor: Thompson Sampling scale nu applied to the
                posterior std when sampling.
            lr: Learning rate used by train_network().
            train_epochs: Number of epochs per train_network() call.
        """
        self.input_dim = input_dim
        self.reg_lambda = reg_lambda
        self.nu = exploration_factor
        self.lr = lr
        self.train_epochs = train_epochs

        self.network = RewardNetwork(input_dim, hidden)
        # Total parameter count sizes the diagonal design-matrix vector below.
        self.n_params = sum(p.numel() for p in self.network.parameters())

        # Diagonal approximation of the design matrix U
        self.U_diag = torch.ones(self.n_params) * reg_lambda

        # Training dataset: (context, reward) pairs
        self.contexts: List[torch.Tensor] = []
        self.rewards: List[float] = []

        # Ensemble: store snapshots of model weights at each training step
        self.ensemble_weights: List[Dict[str, torch.Tensor]] = []
|
| 92 |
+
|
| 93 |
+
    def _get_gradient(self, x: torch.Tensor) -> torch.Tensor:
        """Compute gradient g(x; theta) of network output w.r.t. parameters.

        Returns a single flattened 1-D tensor concatenating per-parameter
        gradients in the order the network yields its parameters.
        Side effect: overwrites the network's .grad buffers via backward().
        """
        self.network.zero_grad()
        # Accept both a single vector and a batch; sum so backward() gets a scalar.
        pred = self.network(x.unsqueeze(0) if x.dim() == 1 else x)
        if pred.dim() > 0:
            pred = pred.sum()
        pred.backward()
        grads = []
        for p in self.network.parameters():
            if p.grad is not None:
                grads.append(p.grad.detach().flatten())
            else:
                # Parameters untouched by backward() contribute zero gradient.
                grads.append(torch.zeros(p.numel()))
        return torch.cat(grads)
|
| 107 |
+
|
| 108 |
+
    def predict(self, x: torch.Tensor) -> Tuple[float, float]:
        """Return (mean, std) for the predicted reward at features x."""
        with torch.no_grad():
            # NOTE(review): .item() assumes a single combination (1-element
            # output); a batch of >1 would raise here — confirm callers.
            mean = self.network(x.unsqueeze(0) if x.dim() == 1 else x).item()

        g = self._get_gradient(x)
        # s_t(x) = sqrt(lambda * g^T U^{-1} g) (diagonal approx)
        var = self.reg_lambda * (g ** 2 / self.U_diag).sum().item()
        # Floor the variance so sqrt never sees a non-positive round-off value.
        std = math.sqrt(max(var, 1e-10))
        return mean, std
|
| 118 |
+
|
| 119 |
+
def sample_value(self, x: torch.Tensor) -> float:
    """Draw one Thompson sample from the posterior N(f_t(x), (nu * s_t(x))^2)."""
    mu, sigma = self.predict(x)
    return random.gauss(mu, self.nu * sigma)
|
| 123 |
+
|
| 124 |
+
def update_design_matrix(self, x: torch.Tensor) -> None:
    """Accumulate the squared gradient at x into the diagonal of U_t.

    Diagonal approximation of the rank-one update U_t += g(x) g(x)^T.
    """
    grad = self._get_gradient(x)
    self.U_diag += grad.pow(2)
|
| 128 |
+
|
| 129 |
+
def add_observation(self, x: torch.Tensor, reward: float) -> None:
    """Append a (context, reward) pair to the training dataset.

    The context is detached and cloned so later in-place edits by the
    caller cannot corrupt stored training data.
    """
    self.contexts.append(x.detach().clone())
    self.rewards.append(reward)
|
| 133 |
+
|
| 134 |
+
def train_network(self) -> float:
    """Fit the reward network to all accumulated (context, reward) pairs.

    Full-batch Adam with gradient clipping, a reduce-on-plateau LR
    schedule and an L2 penalty (as in the original NeuralTS). The best
    (lowest-MSE) weights seen during training are restored at the end and
    snapshotted into the prediction ensemble.

    Returns:
        The best MSE loss reached, or 0.0 when no data has been collected.
    """
    if not self.contexts:
        return 0.0

    X = torch.stack(self.contexts)
    y = torch.tensor(self.rewards, dtype=torch.float32)

    optimizer = torch.optim.Adam(self.network.parameters(), lr=self.lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, patience=10, factor=0.5
    )

    best_loss = float("inf")
    best_state = deepcopy(self.network.state_dict())

    for _ in range(self.train_epochs):
        optimizer.zero_grad()
        # Flatten (N, 1) network outputs to (N,): without this,
        # F.mse_loss silently broadcasts (N, 1) against the (N,) target
        # into an (N, N) matrix and optimizes the wrong objective.
        preds = self.network(X).reshape(-1)
        loss = F.mse_loss(preds, y)

        # L2 regularization (as in the original NeuralTS formulation).
        l2_reg = sum(p.pow(2).sum() for p in self.network.parameters())
        total_loss = loss + self.reg_lambda * 1e-4 * l2_reg

        total_loss.backward()
        nn.utils.clip_grad_norm_(self.network.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step(loss.item())

        if loss.item() < best_loss:
            best_loss = loss.item()
            best_state = deepcopy(self.network.state_dict())

    # Restore the best weights (maximizes likelihood) ...
    self.network.load_state_dict(best_state)

    # ... and snapshot them for the prediction ensemble.
    self.ensemble_weights.append(deepcopy(best_state))

    return best_loss
|
| 175 |
+
|
| 176 |
+
def ensemble_predict(self, x: torch.Tensor, threshold: float = 1.1) -> Tuple[float, float, bool]:
    """Predict risk for x using the ensemble of intermediate models.

    Args:
        x: Feature vector for a drug combination.
        threshold: Relative-risk cutoff above which the pessimistic lower
            bound flags the combination as a PIP. Defaults to 1.1, the
            previously hard-coded value, so existing callers are unchanged.

    Returns:
        (mean_pred, lower_bound, is_pip) where:
        - mean_pred: average prediction across ensemble snapshots
        - lower_bound: pessimistic estimate (mean - 3 * std)
        - is_pip: True when lower_bound > threshold
    """
    if not self.ensemble_weights:
        # No snapshots yet: fall back to the single-model posterior.
        mean, std = self.predict(x)
        lb = mean - 3 * std
        return mean, lb, lb > threshold

    preds = []
    original_state = deepcopy(self.network.state_dict())

    for state_dict in self.ensemble_weights:
        self.network.load_state_dict(state_dict)
        with torch.no_grad():
            p = self.network(x.unsqueeze(0) if x.dim() == 1 else x).item()
        preds.append(p)

    # Restore the weights that were active before the sweep.
    self.network.load_state_dict(original_state)

    mean_pred = sum(preds) / len(preds)
    if len(preds) > 1:
        # Unbiased sample variance across ensemble members.
        var = sum((p - mean_pred) ** 2 for p in preds) / (len(preds) - 1)
        std = math.sqrt(var)
    else:
        _, std = self.predict(x)

    lower_bound = mean_pred - 3 * std
    return mean_pred, lower_bound, lower_bound > threshold
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
# ---------------------------------------------------------------------------
|
| 216 |
+
# Differential Evolution (DE best/1/bin)
|
| 217 |
+
# ---------------------------------------------------------------------------
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def differential_evolution(
    objective_fn,
    dim: int,
    population_size: int = 32,
    crossover_rate: float = 0.9,
    differential_weight: float = 1.0,
    n_steps: int = 16,
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
    """DE best/1/bin optimization on a multi-hot feature space.

    Generates candidate drug combinations by evolving a population and
    evaluating them with the objective function (e.g. a Thompson sample
    drawn from NeuralTS).

    Args:
        objective_fn: Maps a feature vector -> scalar value to maximize.
        dim: Dimensionality of feature vectors (number of possible drugs).
        population_size: N — number of members in the population (>= 3).
        crossover_rate: C — probability of component crossover.
        differential_weight: F — scaling factor for mutation.
        n_steps: S — number of evolution steps.

    Returns:
        best_member: The feature vector maximizing the objective.
        all_members: All members evaluated during DE (for action set A_t).

    Raises:
        ValueError: If population_size < 3 — best/1 mutation needs two
            distinct partners besides the current member, and a smaller
            population would previously crash deep inside random.sample.
    """
    if population_size < 3:
        raise ValueError(
            f"population_size must be >= 3 for best/1 mutation (got {population_size})"
        )

    # Initialize population: random multi-hot vectors (each drug ~20% chance).
    population = []
    for _ in range(population_size):
        member = (torch.rand(dim) > 0.8).float()
        # Ensure at least 2 drugs are present.
        if member.sum() < 2:
            indices = random.sample(range(dim), 2)
            member[indices[0]] = 1.0
            member[indices[1]] = 1.0
        population.append(member)

    all_evaluated = list(population)

    for _ in range(n_steps):
        # Find the best member of the current generation.
        scores = [objective_fn(m) for m in population]
        best_idx = max(range(len(scores)), key=lambda i: scores[i])
        best = population[best_idx]

        new_population = []
        for i, w_i in enumerate(population):
            # Two distinct random partners (excluding i).
            candidates = [j for j in range(population_size) if j != i]
            r1, r2 = random.sample(candidates, 2)

            # Mutation: m_i = best + F * (w_r1 - w_r2)
            mutant = best + differential_weight * (population[r1] - population[r2])

            # Binomial crossover with one guaranteed crossover component.
            cross_idx = random.randint(0, dim - 1)
            trial = w_i.clone()
            for j in range(dim):
                if j == cross_idx or random.random() <= crossover_rate:
                    trial[j] = mutant[j]

            # Clamp to [0, 1] and binarize to a multi-hot vector.
            trial = torch.clamp(trial, 0.0, 1.0)
            trial = (trial > 0.5).float()

            # Keep the minimum of 2 drugs.
            if trial.sum() < 2:
                indices = random.sample(range(dim), 2)
                trial[indices[0]] = 1.0
                trial[indices[1]] = 1.0

            # Greedy selection: keep the trial vector if it is no worse.
            if objective_fn(trial) >= objective_fn(w_i):
                new_population.append(trial)
            else:
                new_population.append(w_i)

            all_evaluated.append(trial)

        population = new_population

    # Return the best member of the final population.
    final_scores = [objective_fn(m) for m in population]
    best_idx = max(range(len(final_scores)), key=lambda i: final_scores[i])
    return population[best_idx], all_evaluated
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
# ---------------------------------------------------------------------------
|
| 307 |
+
# Nearest-neighbor mapping (Hamming distance)
|
| 308 |
+
# ---------------------------------------------------------------------------
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
def nearest_neighbor_hamming(
    candidate: torch.Tensor,
    dataset: List[torch.Tensor],
) -> int:
    """Index of the dataset entry closest to ``candidate`` in Hamming distance.

    Both the candidate and each dataset entry are binarized at 0.5 before
    comparison; ties are broken in favor of the earliest index.

    Args:
        candidate: Feature vector to match.
        dataset: Non-empty list of feature vectors to search.

    Returns:
        Index of the nearest neighbor in ``dataset``.

    Raises:
        ValueError: If ``dataset`` is empty — previously this silently
            returned 0, an out-of-range index for the caller.
    """
    if not dataset:
        raise ValueError("dataset must be non-empty")

    best_dist = float("inf")
    best_idx = 0
    candidate_binary = (candidate > 0.5).float()
    for i, item in enumerate(dataset):
        item_binary = (item > 0.5).float()
        dist = (candidate_binary != item_binary).float().sum().item()
        if dist < best_dist:
            best_dist = dist
            best_idx = i
    return best_idx
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
# ---------------------------------------------------------------------------
|
| 329 |
+
# OptimNeuralTS: full pipeline
|
| 330 |
+
# ---------------------------------------------------------------------------
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
class OptimNeuralTS:
    """End-to-end OptimNeuralTS training pipeline.

    Couples a NeuralTS bandit with Differential Evolution so the search
    over the combinatorial space of drug combinations concentrates on
    potentially inappropriate polypharmacies (PIPs).

    Workflow:
        1. Warm-up: pick random actions for ``warmup_steps`` steps and
           collect rewards.
        2. Fit the reward network on the warm-up data.
        3. Each later step: run DE guided by fresh Thompson samples, map
           the DE winner to the closest real drug combination (Hamming
           distance), observe its reward (Relative Risk), and retrain the
           network periodically.
        4. Prediction uses the ensemble of all intermediate snapshots.
    """

    def __init__(
        self,
        input_dim: int,
        hidden: int = 64,
        reg_lambda: float = 1.0,
        exploration_factor: float = 1.0,
        lr: float = 0.01,
        train_epochs: int = 100,
        warmup_steps: int = 100,
        total_steps: int = 1000,
        retrain_every: int = 10,
        de_population: int = 32,
        de_crossover: float = 0.9,
        de_weight: float = 1.0,
        de_steps: int = 16,
    ) -> None:
        self.agent = NeuralTS(
            input_dim=input_dim,
            hidden=hidden,
            reg_lambda=reg_lambda,
            exploration_factor=exploration_factor,
            lr=lr,
            train_epochs=train_epochs,
        )
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.retrain_every = retrain_every
        self.de_population = de_population
        self.de_crossover = de_crossover
        self.de_weight = de_weight
        self.de_steps = de_steps
        self.input_dim = input_dim

        # Global step counter and per-step log entries.
        self.step_count = 0
        self.training_log: List[Dict[str, Any]] = []

    def select_action(
        self,
        available_actions: List[torch.Tensor],
    ) -> Tuple[int, Dict[str, Any]]:
        """Pick an index into ``available_actions``.

        During warm-up the choice is uniform random; afterwards DE searches
        the feature space guided by Thompson samples and the winner is
        mapped to the nearest available action.

        Returns:
            (index into available_actions, info dict describing the choice)
        """
        in_warmup = self.step_count < self.warmup_steps
        info: Dict[str, Any] = {"phase": "warmup" if in_warmup else "bandit"}

        if in_warmup:
            chosen = random.randint(0, len(available_actions) - 1)
            info["selection"] = "random"
            return chosen, info

        # Bandit phase: DE maximizes a fresh Thompson sample per query.
        best_candidate, _ = differential_evolution(
            objective_fn=lambda feats: self.agent.sample_value(feats),
            dim=self.input_dim,
            population_size=self.de_population,
            crossover_rate=self.de_crossover,
            differential_weight=self.de_weight,
            n_steps=self.de_steps,
        )

        # Fold the DE recommendation into the design matrix.
        self.agent.update_design_matrix(best_candidate)

        # Snap the synthetic candidate onto a real action (Hamming distance).
        chosen = nearest_neighbor_hamming(best_candidate, available_actions)
        info["selection"] = "de_neuralts"

        mean, std = self.agent.predict(available_actions[chosen])
        info["predicted_rr"] = mean
        info["uncertainty"] = std

        return chosen, info

    def observe(self, x: torch.Tensor, reward: float) -> Optional[float]:
        """Record an observed reward and retrain on schedule.

        Retrains once at the end of warm-up, then every ``retrain_every``
        steps afterwards.

        Returns:
            The training loss when a retrain happened, None otherwise.
        """
        self.agent.add_observation(x, reward)
        self.step_count += 1

        finished_warmup = self.step_count == self.warmup_steps
        periodic = (
            self.step_count > self.warmup_steps
            and self.step_count % self.retrain_every == 0
        )
        loss = self.agent.train_network() if (finished_warmup or periodic) else None

        self.training_log.append({
            "step": self.step_count,
            "reward": reward,
            "loss": loss,
            "n_ensemble": len(self.agent.ensemble_weights),
        })

        return loss

    def predict_risk(self, x: torch.Tensor) -> Dict[str, Any]:
        """Ensemble risk prediction for one drug combination.

        Returns:
            Dict with the mean prediction, the pessimistic lower confidence
            bound, the PIP flag, and the number of ensemble snapshots.
        """
        mean, lower_bound, is_pip = self.agent.ensemble_predict(x)
        return {
            "predicted_rr": round(mean, 4),
            "lower_bound": round(lower_bound, 4),
            "is_potentially_harmful": is_pip,
            "n_models_in_ensemble": len(self.agent.ensemble_weights),
        }

    def get_metrics(self) -> Dict[str, Any]:
        """Summarize training progress so far."""
        if not self.training_log:
            return {"status": "no_data"}

        rewards = [entry["reward"] for entry in self.training_log]
        return {
            "total_steps": self.step_count,
            "warmup_steps": self.warmup_steps,
            "n_ensemble_models": len(self.agent.ensemble_weights),
            "avg_reward": sum(rewards) / len(rewards) if rewards else 0,
            "max_reward": max(rewards) if rewards else 0,
            "phase": "warmup" if self.step_count < self.warmup_steps else "bandit",
        }
|
backend/src/polypharmacy_env/rewards.py
CHANGED
|
@@ -8,6 +8,7 @@ from typing import Dict, List, Optional, Tuple
|
|
| 8 |
from .config import (
|
| 9 |
INTERVENTION_COST,
|
| 10 |
INVALID_ACTION_PENALTY,
|
|
|
|
| 11 |
QUERY_COST,
|
| 12 |
SEVERE_DDI_DISCOVERY_BONUS,
|
| 13 |
TIMEOUT_PENALTY,
|
|
@@ -39,8 +40,8 @@ def compute_regimen_risk(
|
|
| 39 |
if rule is not None:
|
| 40 |
risk += rule.base_risk_score
|
| 41 |
|
| 42 |
-
# 2. Beers violations
|
| 43 |
-
beers_weight = {"avoid": 0.
|
| 44 |
for bc in beers_criteria:
|
| 45 |
if bc.drug_id not in drug_set:
|
| 46 |
continue
|
|
@@ -68,6 +69,7 @@ def compute_shaped_reward(
|
|
| 68 |
is_invalid: bool = False,
|
| 69 |
is_timeout: bool = False,
|
| 70 |
discovered_severe: bool = False,
|
|
|
|
| 71 |
) -> float:
|
| 72 |
"""Compute the step-level shaped reward."""
|
| 73 |
reward = 0.0
|
|
@@ -82,6 +84,8 @@ def compute_shaped_reward(
|
|
| 82 |
reward -= QUERY_COST
|
| 83 |
if discovered_severe:
|
| 84 |
reward += SEVERE_DDI_DISCOVERY_BONUS
|
|
|
|
|
|
|
| 85 |
|
| 86 |
elif action_type == "propose_intervention":
|
| 87 |
reward += (previous_risk - new_risk)
|
|
|
|
| 8 |
from .config import (
|
| 9 |
INTERVENTION_COST,
|
| 10 |
INVALID_ACTION_PENALTY,
|
| 11 |
+
MODERATE_DDI_DISCOVERY_BONUS,
|
| 12 |
QUERY_COST,
|
| 13 |
SEVERE_DDI_DISCOVERY_BONUS,
|
| 14 |
TIMEOUT_PENALTY,
|
|
|
|
| 40 |
if rule is not None:
|
| 41 |
risk += rule.base_risk_score
|
| 42 |
|
| 43 |
+
# 2. Beers violations (weights reflect clinical severity)
|
| 44 |
+
beers_weight = {"avoid": 0.30, "caution": 0.12, "dose_adjust": 0.10, "avoid_in_condition": 0.25}
|
| 45 |
for bc in beers_criteria:
|
| 46 |
if bc.drug_id not in drug_set:
|
| 47 |
continue
|
|
|
|
| 69 |
is_invalid: bool = False,
|
| 70 |
is_timeout: bool = False,
|
| 71 |
discovered_severe: bool = False,
|
| 72 |
+
discovered_moderate: bool = False,
|
| 73 |
) -> float:
|
| 74 |
"""Compute the step-level shaped reward."""
|
| 75 |
reward = 0.0
|
|
|
|
| 84 |
reward -= QUERY_COST
|
| 85 |
if discovered_severe:
|
| 86 |
reward += SEVERE_DDI_DISCOVERY_BONUS
|
| 87 |
+
elif discovered_moderate:
|
| 88 |
+
reward += MODERATE_DDI_DISCOVERY_BONUS
|
| 89 |
|
| 90 |
elif action_type == "propose_intervention":
|
| 91 |
reward += (previous_risk - new_risk)
|
frontend/src/App.jsx
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
import { useEffect, useMemo, useRef, useState } from "react";
|
| 2 |
|
| 3 |
function resolveApiBase() {
|
| 4 |
const explicitBase = import.meta.env.VITE_API_BASE;
|
|
@@ -8,7 +8,6 @@ function resolveApiBase() {
|
|
| 8 |
const isLocal =
|
| 9 |
host === "localhost" || host === "127.0.0.1" || host === "0.0.0.0";
|
| 10 |
|
| 11 |
-
// In local Vite dev, backend runs on :7860. In Spaces/prod, serve same-origin.
|
| 12 |
if (isLocal && window.location.port === "5173") {
|
| 13 |
return "http://localhost:7860";
|
| 14 |
}
|
|
@@ -17,7 +16,117 @@ function resolveApiBase() {
|
|
| 17 |
|
| 18 |
const API_BASE = resolveApiBase();
|
| 19 |
const WS_URL = `${API_BASE.replace(/^http/, "ws")}/ws`;
|
| 20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
async function apiPost(path, body) {
|
| 23 |
const res = await fetch(`${API_BASE}${path}`, {
|
|
@@ -32,11 +141,158 @@ async function apiPost(path, body) {
|
|
| 32 |
return res.json();
|
| 33 |
}
|
| 34 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
export default function App() {
|
| 36 |
const [taskId, setTaskId] = useState("budgeted_screening");
|
| 37 |
const [obs, setObs] = useState(null);
|
| 38 |
const [log, setLog] = useState([]);
|
| 39 |
const [loading, setLoading] = useState(false);
|
|
|
|
|
|
|
| 40 |
const [action, setAction] = useState({
|
| 41 |
action_type: "query_ddi",
|
| 42 |
drug_id_1: "",
|
|
@@ -51,10 +307,13 @@ export default function App() {
|
|
| 51 |
() => (obs?.current_medications || []).map((m) => m.drug_id),
|
| 52 |
[obs]
|
| 53 |
);
|
| 54 |
-
const hasValidEpisode =
|
|
|
|
| 55 |
const isDone = Boolean(obs?.done);
|
| 56 |
const finalScore =
|
| 57 |
-
typeof obs?.metadata?.grader_score === "number"
|
|
|
|
|
|
|
| 58 |
const noBudgetsLeft =
|
| 59 |
hasValidEpisode &&
|
| 60 |
(obs?.remaining_query_budget ?? 0) <= 0 &&
|
|
@@ -63,7 +322,8 @@ export default function App() {
|
|
| 63 |
const pendingRef = useRef([]);
|
| 64 |
|
| 65 |
const wsEnsure = async () => {
|
| 66 |
-
if (wsRef.current && wsRef.current.readyState === WebSocket.OPEN)
|
|
|
|
| 67 |
if (wsRef.current && wsRef.current.readyState === WebSocket.CONNECTING) {
|
| 68 |
await new Promise((r) => setTimeout(r, 80));
|
| 69 |
return wsEnsure();
|
|
@@ -91,7 +351,10 @@ export default function App() {
|
|
| 91 |
};
|
| 92 |
|
| 93 |
await new Promise((resolve, reject) => {
|
| 94 |
-
const t = setTimeout(
|
|
|
|
|
|
|
|
|
|
| 95 |
ws.onopen = () => {
|
| 96 |
clearTimeout(t);
|
| 97 |
resolve();
|
|
@@ -113,13 +376,15 @@ export default function App() {
|
|
| 113 |
try {
|
| 114 |
wsRef.current?.close();
|
| 115 |
} catch {
|
| 116 |
-
/
|
| 117 |
}
|
| 118 |
};
|
| 119 |
}, []);
|
| 120 |
|
| 121 |
const appendLog = (text) => {
|
| 122 |
-
setLog((prev) =>
|
|
|
|
|
|
|
| 123 |
};
|
| 124 |
|
| 125 |
const normalizeObsFromWs = (packetData) => {
|
|
@@ -150,7 +415,7 @@ export default function App() {
|
|
| 150 |
drug_id_2: ids[1] || "",
|
| 151 |
target_drug_id: ids[0] || "",
|
| 152 |
}));
|
| 153 |
-
appendLog(`Reset
|
| 154 |
} catch (err) {
|
| 155 |
appendLog(`Reset failed: ${err.message}`);
|
| 156 |
} finally {
|
|
@@ -206,7 +471,9 @@ export default function App() {
|
|
| 206 |
const data = msg?.data || {};
|
| 207 |
const normalized = normalizeObsFromWs(data);
|
| 208 |
setObs(normalized);
|
| 209 |
-
|
|
|
|
|
|
|
| 210 |
} catch (err) {
|
| 211 |
appendLog(`Step failed: ${err.message}`);
|
| 212 |
} finally {
|
|
@@ -222,7 +489,9 @@ export default function App() {
|
|
| 222 |
setLoading(true);
|
| 223 |
try {
|
| 224 |
const data = await apiPost("/agent/suggest", { observation: obs });
|
| 225 |
-
|
|
|
|
|
|
|
| 226 |
await handleStep(data.action);
|
| 227 |
} catch (err) {
|
| 228 |
appendLog(`AI suggestion failed: ${err.message}`);
|
|
@@ -231,156 +500,450 @@ export default function App() {
|
|
| 231 |
}
|
| 232 |
};
|
| 233 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 234 |
return (
|
| 235 |
<div className="shell">
|
| 236 |
<div className="bg-orb orb-a" />
|
| 237 |
<div className="bg-orb orb-b" />
|
| 238 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 239 |
<div className="container">
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
{t}
|
| 253 |
-
</option>
|
| 254 |
-
))}
|
| 255 |
-
</select>
|
| 256 |
-
<button onClick={handleReset} disabled={loading}>
|
| 257 |
-
Reset Episode
|
| 258 |
-
</button>
|
| 259 |
-
<button className="secondary" onClick={askAi} disabled={!hasValidEpisode || isDone || loading}>
|
| 260 |
-
Ask AI + Auto Step
|
| 261 |
-
</button>
|
| 262 |
-
</div>
|
| 263 |
-
</header>
|
| 264 |
-
|
| 265 |
-
<main className="layout">
|
| 266 |
-
<section className="panel glass panel-wide">
|
| 267 |
-
<h2>Episode</h2>
|
| 268 |
-
{hasValidEpisode ? (
|
| 269 |
-
<div className="kpi-grid">
|
| 270 |
-
<div><span>Episode</span><strong>{obs.episode_id}</strong></div>
|
| 271 |
-
<div><span>Task</span><strong>{obs.task_id}</strong></div>
|
| 272 |
-
<div><span>Age / Sex</span><strong>{obs.age} / {obs.sex}</strong></div>
|
| 273 |
-
<div><span>Step</span><strong>{obs.step_index}</strong></div>
|
| 274 |
-
<div><span>Query budget</span><strong>{obs.remaining_query_budget}</strong></div>
|
| 275 |
-
<div><span>Intervention budget</span><strong>{obs.remaining_intervention_budget}</strong></div>
|
| 276 |
</div>
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
<p className="muted budget-note">
|
| 285 |
-
Episode complete
|
| 286 |
-
{finalScore !== null ? ` • final score: ${finalScore.toFixed(3)}` : ""}.
|
| 287 |
-
Click Reset Episode to start a new case.
|
| 288 |
-
</p>
|
| 289 |
-
)}
|
| 290 |
-
</section>
|
| 291 |
-
|
| 292 |
-
<section className="panel glass">
|
| 293 |
-
<h2>Action Console</h2>
|
| 294 |
-
<div className="action-row">
|
| 295 |
-
<label>Action type</label>
|
| 296 |
-
<select
|
| 297 |
-
value={action.action_type}
|
| 298 |
-
onChange={(e) => setAction((a) => ({ ...a, action_type: e.target.value }))}
|
| 299 |
>
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 303 |
</select>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 304 |
</div>
|
|
|
|
| 305 |
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 320 |
|
| 321 |
-
|
| 322 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 323 |
<select
|
| 324 |
-
value={action.
|
| 325 |
-
onChange={(e) =>
|
|
|
|
|
|
|
| 326 |
>
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
{id}
|
| 331 |
</option>
|
| 332 |
))}
|
| 333 |
</select>
|
| 334 |
-
<select
|
| 335 |
-
value={action.intervention_type}
|
| 336 |
-
onChange={(e) => setAction((a) => ({ ...a, intervention_type: e.target.value }))}
|
| 337 |
-
>
|
| 338 |
-
<option value="stop">stop</option>
|
| 339 |
-
<option value="dose_reduce">dose_reduce</option>
|
| 340 |
-
<option value="substitute">substitute</option>
|
| 341 |
-
<option value="add_monitoring">add_monitoring</option>
|
| 342 |
-
</select>
|
| 343 |
-
<input
|
| 344 |
-
placeholder="proposed_new_drug_id (optional)"
|
| 345 |
-
value={action.proposed_new_drug_id}
|
| 346 |
-
onChange={(e) =>
|
| 347 |
-
setAction((a) => ({ ...a, proposed_new_drug_id: e.target.value }))
|
| 348 |
-
}
|
| 349 |
-
/>
|
| 350 |
-
<input
|
| 351 |
-
placeholder="rationale (optional)"
|
| 352 |
-
value={action.rationale}
|
| 353 |
-
onChange={(e) => setAction((a) => ({ ...a, rationale: e.target.value }))}
|
| 354 |
-
/>
|
| 355 |
</div>
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 370 |
</div>
|
| 371 |
-
)
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 384 |
</div>
|
| 385 |
</div>
|
| 386 |
);
|
|
|
|
| 1 |
+
import { useEffect, useMemo, useRef, useState, useCallback } from "react";
|
| 2 |
|
| 3 |
function resolveApiBase() {
|
| 4 |
const explicitBase = import.meta.env.VITE_API_BASE;
|
|
|
|
| 8 |
const isLocal =
|
| 9 |
host === "localhost" || host === "127.0.0.1" || host === "0.0.0.0";
|
| 10 |
|
|
|
|
| 11 |
if (isLocal && window.location.port === "5173") {
|
| 12 |
return "http://localhost:7860";
|
| 13 |
}
|
|
|
|
| 16 |
|
| 17 |
const API_BASE = resolveApiBase();
|
| 18 |
const WS_URL = `${API_BASE.replace(/^http/, "ws")}/ws`;
|
| 19 |
+
|
| 20 |
+
const TASKS = [
|
| 21 |
+
{ id: "easy_screening", label: "Easy Screening" },
|
| 22 |
+
{ id: "budgeted_screening", label: "Budgeted Screening" },
|
| 23 |
+
{ id: "complex_tradeoff", label: "Complex Tradeoff" },
|
| 24 |
+
];
|
| 25 |
+
|
| 26 |
+
const TASK_LABEL_MAP = Object.fromEntries(TASKS.map((t) => [t.id, t.label]));
|
| 27 |
+
|
| 28 |
+
const ACTION_LABELS = {
|
| 29 |
+
query_ddi: "Check Drug Interaction",
|
| 30 |
+
propose_intervention: "Propose Change",
|
| 31 |
+
finish_review: "Finish Review",
|
| 32 |
+
};
|
| 33 |
+
|
| 34 |
+
const INTERVENTION_LABELS = {
|
| 35 |
+
stop: "Stop Medication",
|
| 36 |
+
dose_reduce: "Reduce Dose",
|
| 37 |
+
substitute: "Substitute with Safer Drug",
|
| 38 |
+
add_monitoring: "Add Monitoring",
|
| 39 |
+
};
|
| 40 |
+
|
| 41 |
+
// ── Contextual guide steps: each targets a specific UI section ──────────────
|
| 42 |
+
const GUIDE_STEPS = [
|
| 43 |
+
{
|
| 44 |
+
target: "topbar",
|
| 45 |
+
position: "below",
|
| 46 |
+
title: "Welcome to PolypharmacyEnv",
|
| 47 |
+
body: `This tool helps review elderly patients' medication regimens for safety.
|
| 48 |
+
|
| 49 |
+
You'll act as a pharmacist assistant: check pairs of drugs for harmful interactions, propose changes to reduce risk, and get scored on how well you protect the patient — all under limited budgets.
|
| 50 |
+
|
| 51 |
+
Behind the scenes, an AI agent (Neural Bandit) learns which drug combinations to investigate first, getting smarter with each review.`,
|
| 52 |
+
},
|
| 53 |
+
{
|
| 54 |
+
target: "task-selector",
|
| 55 |
+
position: "below",
|
| 56 |
+
title: "Choose a Scenario",
|
| 57 |
+
body: `Pick a difficulty level:
|
| 58 |
+
|
| 59 |
+
• Easy Screening — 3–5 drugs, 1 known dangerous interaction. Great for getting started.
|
| 60 |
+
• Budgeted Screening — 6–10 drugs, multiple problems to find, tighter budgets.
|
| 61 |
+
• Complex Tradeoff — 10–15 drugs including critical ones (blood thinners, insulin). Removing critical drugs without a replacement is penalized.
|
| 62 |
+
|
| 63 |
+
Click "Reset Episode" to load a new patient case.`,
|
| 64 |
+
},
|
| 65 |
+
{
|
| 66 |
+
target: "episode-panel",
|
| 67 |
+
position: "below",
|
| 68 |
+
title: "Patient Overview",
|
| 69 |
+
body: `After resetting, this panel shows the patient's details:
|
| 70 |
+
|
| 71 |
+
• Demographics (age, sex, medical conditions)
|
| 72 |
+
• Your remaining query and intervention budgets
|
| 73 |
+
• A risk bar comparing starting risk vs. current risk
|
| 74 |
+
• How many review steps you've taken
|
| 75 |
+
|
| 76 |
+
Each check and intervention uses up budget — use them wisely to get the best outcome.`,
|
| 77 |
+
},
|
| 78 |
+
{
|
| 79 |
+
target: "action-console",
|
| 80 |
+
position: "right",
|
| 81 |
+
title: "Check Drug Interactions",
|
| 82 |
+
body: `Select "Check Drug Interaction" and pick two drugs from the patient's list:
|
| 83 |
+
|
| 84 |
+
Example dangerous combinations:
|
| 85 |
+
• Warfarin + Naproxen → severe bleeding risk
|
| 86 |
+
• Diazepam + Tramadol → dangerous sedation
|
| 87 |
+
• Apixaban + Naproxen → severe bleeding risk
|
| 88 |
+
|
| 89 |
+
Each check costs a small amount of budget. Finding a serious interaction earns a bonus. A smart strategy checks high-risk pairs first.`,
|
| 90 |
+
},
|
| 91 |
+
{
|
| 92 |
+
target: "action-console",
|
| 93 |
+
position: "right",
|
| 94 |
+
title: "Propose Changes",
|
| 95 |
+
body: `After finding a dangerous interaction, switch to "Propose Change":
|
| 96 |
+
|
| 97 |
+
• Stop Medication — Remove the drug entirely
|
| 98 |
+
• Reduce Dose — Lower the dose to reduce risk
|
| 99 |
+
• Substitute Drug — Automatically finds a safer alternative in the same drug class
|
| 100 |
+
• Add Monitoring — Flag for closer clinical monitoring
|
| 101 |
+
|
| 102 |
+
Example: After finding warfarin + naproxen interaction, select Naproxen → "Substitute". The system finds a safer pain reliever.`,
|
| 103 |
+
},
|
| 104 |
+
{
|
| 105 |
+
target: "medications-panel",
|
| 106 |
+
position: "left",
|
| 107 |
+
title: "Current Medications",
|
| 108 |
+
body: `This grid shows the patient's active medications. Each card shows:
|
| 109 |
+
|
| 110 |
+
• Drug name and dose
|
| 111 |
+
• Drug class (e.g., pain reliever, blood thinner)
|
| 112 |
+
• "High Risk" badge for drugs that need extra caution in elderly patients
|
| 113 |
+
• Safety flags (avoid, caution, adjust dose)
|
| 114 |
+
|
| 115 |
+
Cards marked "avoid" or "High Risk" are prime candidates for a closer look. The list updates live as you make changes.`,
|
| 116 |
+
},
|
| 117 |
+
{
|
| 118 |
+
target: "event-log",
|
| 119 |
+
position: "above",
|
| 120 |
+
title: "Activity Log & Score",
|
| 121 |
+
body: `The log tracks every action you take and its impact. When you click "Finish Review", you get a final score (0–100%):
|
| 122 |
+
|
| 123 |
+
• Easy: Based on risk reduction + targeting the right dangerous drugs
|
| 124 |
+
• Medium: Risk reduction + precision of your interventions + how well you used your budget
|
| 125 |
+
• Hard: Risk reduction minus penalties for disrupting the patient's treatment plan
|
| 126 |
+
|
| 127 |
+
The "Ask AI" button lets an AI agent make decisions using the same tools you have.`,
|
| 128 |
+
},
|
| 129 |
+
];
|
| 130 |
|
| 131 |
async function apiPost(path, body) {
|
| 132 |
const res = await fetch(`${API_BASE}${path}`, {
|
|
|
|
| 141 |
return res.json();
|
| 142 |
}
|
| 143 |
|
| 144 |
+
// ── Spotlight Guide Component ───────────────────────────────────────────────
|
| 145 |
+
function SpotlightGuide({ step, steps, onNext, onPrev, onClose }) {
|
| 146 |
+
const [rect, setRect] = useState(null);
|
| 147 |
+
const tooltipRef = useRef(null);
|
| 148 |
+
|
| 149 |
+
const updateRect = useCallback(() => {
|
| 150 |
+
const target = steps[step]?.target;
|
| 151 |
+
if (!target) return;
|
| 152 |
+
const el = document.querySelector(`[data-guide="${target}"]`);
|
| 153 |
+
if (el) {
|
| 154 |
+
const r = el.getBoundingClientRect();
|
| 155 |
+
setRect({ top: r.top, left: r.left, width: r.width, height: r.height });
|
| 156 |
+
// scroll into view
|
| 157 |
+
el.scrollIntoView({ behavior: "smooth", block: "nearest" });
|
| 158 |
+
}
|
| 159 |
+
}, [step, steps]);
|
| 160 |
+
|
| 161 |
+
useEffect(() => {
|
| 162 |
+
updateRect();
|
| 163 |
+
window.addEventListener("resize", updateRect);
|
| 164 |
+
window.addEventListener("scroll", updateRect, true);
|
| 165 |
+
return () => {
|
| 166 |
+
window.removeEventListener("resize", updateRect);
|
| 167 |
+
window.removeEventListener("scroll", updateRect, true);
|
| 168 |
+
};
|
| 169 |
+
}, [updateRect]);
|
| 170 |
+
|
| 171 |
+
if (!rect) return null;
|
| 172 |
+
|
| 173 |
+
const pad = 8;
|
| 174 |
+
const current = steps[step];
|
| 175 |
+
|
| 176 |
+
// Calculate tooltip position
|
| 177 |
+
const getTooltipStyle = () => {
|
| 178 |
+
const pos = current.position || "below";
|
| 179 |
+
const base = {};
|
| 180 |
+
if (pos === "below") {
|
| 181 |
+
base.top = rect.top + rect.height + pad + 12;
|
| 182 |
+
base.left = rect.left;
|
| 183 |
+
base.maxWidth = Math.min(440, window.innerWidth - 40);
|
| 184 |
+
} else if (pos === "above") {
|
| 185 |
+
base.bottom = window.innerHeight - rect.top + pad + 12;
|
| 186 |
+
base.left = rect.left;
|
| 187 |
+
base.maxWidth = Math.min(440, window.innerWidth - 40);
|
| 188 |
+
} else if (pos === "right") {
|
| 189 |
+
base.top = rect.top;
|
| 190 |
+
base.left = rect.left + rect.width + pad + 12;
|
| 191 |
+
base.maxWidth = Math.min(380, window.innerWidth - rect.left - rect.width - 40);
|
| 192 |
+
} else if (pos === "left") {
|
| 193 |
+
base.top = rect.top;
|
| 194 |
+
base.right = window.innerWidth - rect.left + pad + 12;
|
| 195 |
+
base.maxWidth = Math.min(380, rect.left - 40);
|
| 196 |
+
}
|
| 197 |
+
return base;
|
| 198 |
+
};
|
| 199 |
+
|
| 200 |
+
return (
|
| 201 |
+
<div className="spotlight-overlay">
|
| 202 |
+
{/* Dark overlay with cutout */}
|
| 203 |
+
<svg className="spotlight-svg" width="100%" height="100%">
|
| 204 |
+
<defs>
|
| 205 |
+
<mask id="spotlight-mask">
|
| 206 |
+
<rect width="100%" height="100%" fill="white" />
|
| 207 |
+
<rect
|
| 208 |
+
x={rect.left - pad}
|
| 209 |
+
y={rect.top - pad}
|
| 210 |
+
width={rect.width + pad * 2}
|
| 211 |
+
height={rect.height + pad * 2}
|
| 212 |
+
rx="12"
|
| 213 |
+
fill="black"
|
| 214 |
+
/>
|
| 215 |
+
</mask>
|
| 216 |
+
</defs>
|
| 217 |
+
<rect
|
| 218 |
+
width="100%"
|
| 219 |
+
height="100%"
|
| 220 |
+
fill="rgba(4, 6, 15, 0.75)"
|
| 221 |
+
mask="url(#spotlight-mask)"
|
| 222 |
+
/>
|
| 223 |
+
</svg>
|
| 224 |
+
|
| 225 |
+
{/* Highlight border around target */}
|
| 226 |
+
<div
|
| 227 |
+
className="spotlight-ring"
|
| 228 |
+
style={{
|
| 229 |
+
top: rect.top - pad,
|
| 230 |
+
left: rect.left - pad,
|
| 231 |
+
width: rect.width + pad * 2,
|
| 232 |
+
height: rect.height + pad * 2,
|
| 233 |
+
}}
|
| 234 |
+
/>
|
| 235 |
+
|
| 236 |
+
{/* Tooltip */}
|
| 237 |
+
<div
|
| 238 |
+
ref={tooltipRef}
|
| 239 |
+
className="spotlight-tooltip glass"
|
| 240 |
+
style={getTooltipStyle()}
|
| 241 |
+
>
|
| 242 |
+
<div className="spotlight-tooltip-header">
|
| 243 |
+
<h3>{current.title}</h3>
|
| 244 |
+
<span className="guide-counter">
|
| 245 |
+
{step + 1} / {steps.length}
|
| 246 |
+
</span>
|
| 247 |
+
</div>
|
| 248 |
+
<div className="spotlight-tooltip-body">
|
| 249 |
+
{current.body.split("\n").map((line, i) => (
|
| 250 |
+
<p key={i}>{line}</p>
|
| 251 |
+
))}
|
| 252 |
+
</div>
|
| 253 |
+
<div className="spotlight-tooltip-footer">
|
| 254 |
+
<button
|
| 255 |
+
className="guide-btn secondary"
|
| 256 |
+
onClick={onPrev}
|
| 257 |
+
disabled={step === 0}
|
| 258 |
+
>
|
| 259 |
+
Previous
|
| 260 |
+
</button>
|
| 261 |
+
<button className="guide-btn secondary" onClick={onClose}>
|
| 262 |
+
Skip
|
| 263 |
+
</button>
|
| 264 |
+
{step < steps.length - 1 ? (
|
| 265 |
+
<button className="guide-btn" onClick={onNext}>
|
| 266 |
+
Next
|
| 267 |
+
</button>
|
| 268 |
+
) : (
|
| 269 |
+
<button className="guide-btn" onClick={onClose}>
|
| 270 |
+
Done
|
| 271 |
+
</button>
|
| 272 |
+
)}
|
| 273 |
+
</div>
|
| 274 |
+
<div className="guide-dots">
|
| 275 |
+
{steps.map((_, i) => (
|
| 276 |
+
<span
|
| 277 |
+
key={i}
|
| 278 |
+
className={`dot ${i === step ? "active" : ""}`}
|
| 279 |
+
/>
|
| 280 |
+
))}
|
| 281 |
+
</div>
|
| 282 |
+
</div>
|
| 283 |
+
</div>
|
| 284 |
+
);
|
| 285 |
+
}
|
| 286 |
+
|
| 287 |
+
// ── Main App ────────────────────────────────────────────────────────────────
|
| 288 |
+
|
| 289 |
export default function App() {
|
| 290 |
const [taskId, setTaskId] = useState("budgeted_screening");
|
| 291 |
const [obs, setObs] = useState(null);
|
| 292 |
const [log, setLog] = useState([]);
|
| 293 |
const [loading, setLoading] = useState(false);
|
| 294 |
+
const [guideStep, setGuideStep] = useState(0);
|
| 295 |
+
const [showGuide, setShowGuide] = useState(true);
|
| 296 |
const [action, setAction] = useState({
|
| 297 |
action_type: "query_ddi",
|
| 298 |
drug_id_1: "",
|
|
|
|
| 307 |
() => (obs?.current_medications || []).map((m) => m.drug_id),
|
| 308 |
[obs]
|
| 309 |
);
|
| 310 |
+
const hasValidEpisode =
|
| 311 |
+
Boolean(obs?.episode_id) && (obs?.current_medications?.length || 0) > 0;
|
| 312 |
const isDone = Boolean(obs?.done);
|
| 313 |
const finalScore =
|
| 314 |
+
typeof obs?.metadata?.grader_score === "number"
|
| 315 |
+
? obs.metadata.grader_score
|
| 316 |
+
: null;
|
| 317 |
const noBudgetsLeft =
|
| 318 |
hasValidEpisode &&
|
| 319 |
(obs?.remaining_query_budget ?? 0) <= 0 &&
|
|
|
|
| 322 |
const pendingRef = useRef([]);
|
| 323 |
|
| 324 |
const wsEnsure = async () => {
|
| 325 |
+
if (wsRef.current && wsRef.current.readyState === WebSocket.OPEN)
|
| 326 |
+
return wsRef.current;
|
| 327 |
if (wsRef.current && wsRef.current.readyState === WebSocket.CONNECTING) {
|
| 328 |
await new Promise((r) => setTimeout(r, 80));
|
| 329 |
return wsEnsure();
|
|
|
|
| 351 |
};
|
| 352 |
|
| 353 |
await new Promise((resolve, reject) => {
|
| 354 |
+
const t = setTimeout(
|
| 355 |
+
() => reject(new Error("WebSocket connect timeout")),
|
| 356 |
+
2500
|
| 357 |
+
);
|
| 358 |
ws.onopen = () => {
|
| 359 |
clearTimeout(t);
|
| 360 |
resolve();
|
|
|
|
| 376 |
try {
|
| 377 |
wsRef.current?.close();
|
| 378 |
} catch {
|
| 379 |
+
/* ignore */
|
| 380 |
}
|
| 381 |
};
|
| 382 |
}, []);
|
| 383 |
|
| 384 |
const appendLog = (text) => {
|
| 385 |
+
setLog((prev) =>
|
| 386 |
+
[`${new Date().toLocaleTimeString()} ${text}`, ...prev].slice(0, 30)
|
| 387 |
+
);
|
| 388 |
};
|
| 389 |
|
| 390 |
const normalizeObsFromWs = (packetData) => {
|
|
|
|
| 415 |
drug_id_2: ids[1] || "",
|
| 416 |
target_drug_id: ids[0] || "",
|
| 417 |
}));
|
| 418 |
+
appendLog(`Reset — ${TASK_LABEL_MAP[taskId] || taskId}`);
|
| 419 |
} catch (err) {
|
| 420 |
appendLog(`Reset failed: ${err.message}`);
|
| 421 |
} finally {
|
|
|
|
| 471 |
const data = msg?.data || {};
|
| 472 |
const normalized = normalizeObsFromWs(data);
|
| 473 |
setObs(normalized);
|
| 474 |
+
const label = ACTION_LABELS[payload.action_type] || payload.action_type;
|
| 475 |
+
const rwd = data.reward ?? 0;
|
| 476 |
+
appendLog(`${label} → reward: ${Number(rwd).toFixed(3)}`);
|
| 477 |
} catch (err) {
|
| 478 |
appendLog(`Step failed: ${err.message}`);
|
| 479 |
} finally {
|
|
|
|
| 489 |
setLoading(true);
|
| 490 |
try {
|
| 491 |
const data = await apiPost("/agent/suggest", { observation: obs });
|
| 492 |
+
const label =
|
| 493 |
+
ACTION_LABELS[data.action.action_type] || data.action.action_type;
|
| 494 |
+
appendLog(`AI suggests: ${label}`);
|
| 495 |
await handleStep(data.action);
|
| 496 |
} catch (err) {
|
| 497 |
appendLog(`AI suggestion failed: ${err.message}`);
|
|
|
|
| 500 |
}
|
| 501 |
};
|
| 502 |
|
| 503 |
+
const formatDrugName = (drugId) => {
|
| 504 |
+
if (!drugId) return "";
|
| 505 |
+
return drugId
|
| 506 |
+
.replace(/^DRUG_/, "")
|
| 507 |
+
.replace(/_/g, " ")
|
| 508 |
+
.replace(/\b\w/g, (c) => c.toUpperCase());
|
| 509 |
+
};
|
| 510 |
+
|
| 511 |
+
const currentRisk = obs?.metadata?.current_risk;
|
| 512 |
+
const baselineRisk = obs?.metadata?.baseline_risk;
|
| 513 |
+
|
| 514 |
return (
|
| 515 |
<div className="shell">
|
| 516 |
<div className="bg-orb orb-a" />
|
| 517 |
<div className="bg-orb orb-b" />
|
| 518 |
|
| 519 |
+
{/* Spotlight Guide */}
|
| 520 |
+
{showGuide && (
|
| 521 |
+
<SpotlightGuide
|
| 522 |
+
step={guideStep}
|
| 523 |
+
steps={GUIDE_STEPS}
|
| 524 |
+
onNext={() => setGuideStep((s) => Math.min(s + 1, GUIDE_STEPS.length - 1))}
|
| 525 |
+
onPrev={() => setGuideStep((s) => Math.max(0, s - 1))}
|
| 526 |
+
onClose={() => setShowGuide(false)}
|
| 527 |
+
/>
|
| 528 |
+
)}
|
| 529 |
+
|
| 530 |
<div className="container">
|
| 531 |
+
<header className="topbar glass" data-guide="topbar">
|
| 532 |
+
<div className="title-wrap">
|
| 533 |
+
<h1>PolypharmacyEnv</h1>
|
| 534 |
+
<p>Elderly Medication Safety — Powered by Neural Bandits</p>
|
| 535 |
+
</div>
|
| 536 |
+
<div className="topbar-right">
|
| 537 |
+
<div className={`status-chip ${hasValidEpisode ? "live" : "idle"}`}>
|
| 538 |
+
{hasValidEpisode
|
| 539 |
+
? isDone
|
| 540 |
+
? "Episode Complete"
|
| 541 |
+
: "Session Live"
|
| 542 |
+
: "Ready"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 543 |
</div>
|
| 544 |
+
<button
|
| 545 |
+
className="guide-trigger"
|
| 546 |
+
onClick={() => {
|
| 547 |
+
setGuideStep(0);
|
| 548 |
+
setShowGuide(true);
|
| 549 |
+
}}
|
| 550 |
+
title="Open guided walkthrough"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 551 |
>
|
| 552 |
+
?
|
| 553 |
+
</button>
|
| 554 |
+
</div>
|
| 555 |
+
<div className="actions" data-guide="task-selector">
|
| 556 |
+
<select value={taskId} onChange={(e) => setTaskId(e.target.value)}>
|
| 557 |
+
{TASKS.map((t) => (
|
| 558 |
+
<option key={t.id} value={t.id}>
|
| 559 |
+
{t.label}
|
| 560 |
+
</option>
|
| 561 |
+
))}
|
| 562 |
</select>
|
| 563 |
+
<button onClick={handleReset} disabled={loading}>
|
| 564 |
+
Reset Episode
|
| 565 |
+
</button>
|
| 566 |
+
<button
|
| 567 |
+
className="secondary"
|
| 568 |
+
onClick={askAi}
|
| 569 |
+
disabled={!hasValidEpisode || isDone || loading}
|
| 570 |
+
>
|
| 571 |
+
Ask AI + Auto Step
|
| 572 |
+
</button>
|
| 573 |
</div>
|
| 574 |
+
</header>
|
| 575 |
|
| 576 |
+
<main className="layout">
|
| 577 |
+
{/* Episode Info */}
|
| 578 |
+
<section className="panel glass panel-wide" data-guide="episode-panel">
|
| 579 |
+
<h2>Episode Overview</h2>
|
| 580 |
+
{hasValidEpisode ? (
|
| 581 |
+
<>
|
| 582 |
+
<div className="kpi-grid">
|
| 583 |
+
<div>
|
| 584 |
+
<span>Episode</span>
|
| 585 |
+
<strong>{obs.episode_id}</strong>
|
| 586 |
+
</div>
|
| 587 |
+
<div>
|
| 588 |
+
<span>Task</span>
|
| 589 |
+
<strong>{TASK_LABEL_MAP[obs.task_id] || obs.task_id}</strong>
|
| 590 |
+
</div>
|
| 591 |
+
<div>
|
| 592 |
+
<span>Patient</span>
|
| 593 |
+
<strong>
|
| 594 |
+
Age {obs.age}, {obs.sex === "M" ? "Male" : "Female"}
|
| 595 |
+
</strong>
|
| 596 |
+
</div>
|
| 597 |
+
<div>
|
| 598 |
+
<span>Step</span>
|
| 599 |
+
<strong>{obs.step_index}</strong>
|
| 600 |
+
</div>
|
| 601 |
+
<div>
|
| 602 |
+
<span>Query Budget</span>
|
| 603 |
+
<strong>{obs.remaining_query_budget} remaining</strong>
|
| 604 |
+
</div>
|
| 605 |
+
<div>
|
| 606 |
+
<span>Intervention Budget</span>
|
| 607 |
+
<strong>
|
| 608 |
+
{obs.remaining_intervention_budget} remaining
|
| 609 |
+
</strong>
|
| 610 |
+
</div>
|
| 611 |
+
</div>
|
| 612 |
|
| 613 |
+
{currentRisk !== undefined && baselineRisk !== undefined && (
|
| 614 |
+
<div className="risk-bar-wrap">
|
| 615 |
+
<div className="risk-labels">
|
| 616 |
+
<span>
|
| 617 |
+
Baseline Risk:{" "}
|
| 618 |
+
<strong>{Number(baselineRisk).toFixed(3)}</strong>
|
| 619 |
+
</span>
|
| 620 |
+
<span>
|
| 621 |
+
Current Risk:{" "}
|
| 622 |
+
<strong
|
| 623 |
+
className={
|
| 624 |
+
currentRisk < baselineRisk
|
| 625 |
+
? "risk-down"
|
| 626 |
+
: "risk-same"
|
| 627 |
+
}
|
| 628 |
+
>
|
| 629 |
+
{Number(currentRisk).toFixed(3)}
|
| 630 |
+
</strong>
|
| 631 |
+
</span>
|
| 632 |
+
</div>
|
| 633 |
+
<div className="risk-bar">
|
| 634 |
+
<div
|
| 635 |
+
className="risk-fill"
|
| 636 |
+
style={{
|
| 637 |
+
width: `${Math.min(currentRisk * 100, 100)}%`,
|
| 638 |
+
}}
|
| 639 |
+
/>
|
| 640 |
+
</div>
|
| 641 |
+
</div>
|
| 642 |
+
)}
|
| 643 |
+
|
| 644 |
+
{obs.conditions && obs.conditions.length > 0 && (
|
| 645 |
+
<div className="conditions-row">
|
| 646 |
+
<span className="conditions-label">Conditions:</span>
|
| 647 |
+
{obs.conditions.map((c) => (
|
| 648 |
+
<span key={c} className="condition-tag">
|
| 649 |
+
{c.replace(/_/g, " ")}
|
| 650 |
+
</span>
|
| 651 |
+
))}
|
| 652 |
+
</div>
|
| 653 |
+
)}
|
| 654 |
+
</>
|
| 655 |
+
) : (
|
| 656 |
+
<p className="muted">
|
| 657 |
+
Select a task difficulty and click <strong>Reset Episode</strong>{" "}
|
| 658 |
+
to begin a patient case.
|
| 659 |
+
</p>
|
| 660 |
+
)}
|
| 661 |
+
{noBudgetsLeft && !isDone && (
|
| 662 |
+
<div className="budget-note">
|
| 663 |
+
All budgets exhausted. Click <strong>Finish Review</strong> to
|
| 664 |
+
receive your final score.
|
| 665 |
+
</div>
|
| 666 |
+
)}
|
| 667 |
+
{isDone && (
|
| 668 |
+
<div className="budget-note done-note">
|
| 669 |
+
Episode complete
|
| 670 |
+
{finalScore !== null
|
| 671 |
+
? ` — Final score: ${(finalScore * 100).toFixed(1)}%`
|
| 672 |
+
: ""}
|
| 673 |
+
. Click <strong>Reset Episode</strong> to start a new case.
|
| 674 |
+
</div>
|
| 675 |
+
)}
|
| 676 |
+
</section>
|
| 677 |
+
|
| 678 |
+
{/* Action Console */}
|
| 679 |
+
<section className="panel glass" data-guide="action-console">
|
| 680 |
+
<h2>Action Console</h2>
|
| 681 |
+
<div className="action-row">
|
| 682 |
+
<label>Action Type</label>
|
| 683 |
<select
|
| 684 |
+
value={action.action_type}
|
| 685 |
+
onChange={(e) =>
|
| 686 |
+
setAction((a) => ({ ...a, action_type: e.target.value }))
|
| 687 |
+
}
|
| 688 |
>
|
| 689 |
+
{Object.entries(ACTION_LABELS).map(([val, label]) => (
|
| 690 |
+
<option key={val} value={val}>
|
| 691 |
+
{label}
|
|
|
|
| 692 |
</option>
|
| 693 |
))}
|
| 694 |
</select>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 695 |
</div>
|
| 696 |
+
|
| 697 |
+
{action.action_type === "query_ddi" && (
|
| 698 |
+
<div className="stack stack-two">
|
| 699 |
+
<div className="field-group">
|
| 700 |
+
<label>Drug 1</label>
|
| 701 |
+
<select
|
| 702 |
+
value={action.drug_id_1}
|
| 703 |
+
onChange={(e) =>
|
| 704 |
+
setAction((a) => ({ ...a, drug_id_1: e.target.value }))
|
| 705 |
+
}
|
| 706 |
+
>
|
| 707 |
+
<option value="">Select drug</option>
|
| 708 |
+
{medIds.map((id) => (
|
| 709 |
+
<option key={id} value={id}>
|
| 710 |
+
{formatDrugName(id)}
|
| 711 |
+
</option>
|
| 712 |
+
))}
|
| 713 |
+
</select>
|
| 714 |
+
</div>
|
| 715 |
+
<div className="field-group">
|
| 716 |
+
<label>Drug 2</label>
|
| 717 |
+
<select
|
| 718 |
+
value={action.drug_id_2}
|
| 719 |
+
onChange={(e) =>
|
| 720 |
+
setAction((a) => ({ ...a, drug_id_2: e.target.value }))
|
| 721 |
+
}
|
| 722 |
+
>
|
| 723 |
+
<option value="">Select drug</option>
|
| 724 |
+
{medIds.map((id) => (
|
| 725 |
+
<option key={id} value={id}>
|
| 726 |
+
{formatDrugName(id)}
|
| 727 |
+
</option>
|
| 728 |
+
))}
|
| 729 |
+
</select>
|
| 730 |
+
</div>
|
| 731 |
</div>
|
| 732 |
+
)}
|
| 733 |
+
|
| 734 |
+
{action.action_type === "propose_intervention" && (
|
| 735 |
+
<div className="stack">
|
| 736 |
+
<div className="field-group">
|
| 737 |
+
<label>Target Drug</label>
|
| 738 |
+
<select
|
| 739 |
+
value={action.target_drug_id}
|
| 740 |
+
onChange={(e) =>
|
| 741 |
+
setAction((a) => ({
|
| 742 |
+
...a,
|
| 743 |
+
target_drug_id: e.target.value,
|
| 744 |
+
}))
|
| 745 |
+
}
|
| 746 |
+
>
|
| 747 |
+
<option value="">Select target drug</option>
|
| 748 |
+
{medIds.map((id) => (
|
| 749 |
+
<option key={id} value={id}>
|
| 750 |
+
{formatDrugName(id)}
|
| 751 |
+
</option>
|
| 752 |
+
))}
|
| 753 |
+
</select>
|
| 754 |
+
</div>
|
| 755 |
+
<div className="field-group">
|
| 756 |
+
<label>Intervention Type</label>
|
| 757 |
+
<select
|
| 758 |
+
value={action.intervention_type}
|
| 759 |
+
onChange={(e) =>
|
| 760 |
+
setAction((a) => ({
|
| 761 |
+
...a,
|
| 762 |
+
intervention_type: e.target.value,
|
| 763 |
+
}))
|
| 764 |
+
}
|
| 765 |
+
>
|
| 766 |
+
{Object.entries(INTERVENTION_LABELS).map(([val, label]) => (
|
| 767 |
+
<option key={val} value={val}>
|
| 768 |
+
{label}
|
| 769 |
+
</option>
|
| 770 |
+
))}
|
| 771 |
+
</select>
|
| 772 |
+
</div>
|
| 773 |
+
<div className="field-group">
|
| 774 |
+
<label>New Drug ID (optional, for substitution)</label>
|
| 775 |
+
<input
|
| 776 |
+
placeholder="Leave blank for auto-selection"
|
| 777 |
+
value={action.proposed_new_drug_id}
|
| 778 |
+
onChange={(e) =>
|
| 779 |
+
setAction((a) => ({
|
| 780 |
+
...a,
|
| 781 |
+
proposed_new_drug_id: e.target.value,
|
| 782 |
+
}))
|
| 783 |
+
}
|
| 784 |
+
/>
|
| 785 |
+
</div>
|
| 786 |
+
<div className="field-group">
|
| 787 |
+
<label>Rationale (optional)</label>
|
| 788 |
+
<input
|
| 789 |
+
placeholder="e.g., High bleeding risk with concurrent warfarin"
|
| 790 |
+
value={action.rationale}
|
| 791 |
+
onChange={(e) =>
|
| 792 |
+
setAction((a) => ({ ...a, rationale: e.target.value }))
|
| 793 |
+
}
|
| 794 |
+
/>
|
| 795 |
+
</div>
|
| 796 |
+
</div>
|
| 797 |
+
)}
|
| 798 |
+
|
| 799 |
+
<button
|
| 800 |
+
className="submit-btn"
|
| 801 |
+
onClick={() => handleStep()}
|
| 802 |
+
disabled={!isActionValid() || loading}
|
| 803 |
+
>
|
| 804 |
+
{noBudgetsLeft ? "Finish Review" : "Submit Step"}
|
| 805 |
+
</button>
|
| 806 |
+
</section>
|
| 807 |
+
|
| 808 |
+
{/* Current Medications */}
|
| 809 |
+
<section className="panel glass" data-guide="medications-panel">
|
| 810 |
+
<h2>
|
| 811 |
+
Current Medications
|
| 812 |
+
{obs?.current_medications?.length
|
| 813 |
+
? ` (${obs.current_medications.length})`
|
| 814 |
+
: ""}
|
| 815 |
+
</h2>
|
| 816 |
+
<div className="med-grid">
|
| 817 |
+
{(obs?.current_medications || []).map((m) => (
|
| 818 |
+
<div
|
| 819 |
+
key={m.drug_id}
|
| 820 |
+
className={`med-card ${m.is_high_risk_elderly ? "high-risk" : ""}`}
|
| 821 |
+
>
|
| 822 |
+
<div className="med-card-header">
|
| 823 |
+
<strong>{formatDrugName(m.drug_id)}</strong>
|
| 824 |
+
{m.is_high_risk_elderly && (
|
| 825 |
+
<span className="risk-badge">High Risk</span>
|
| 826 |
+
)}
|
| 827 |
+
</div>
|
| 828 |
+
<p className="med-generic">{m.generic_name}</p>
|
| 829 |
+
<div className="med-details">
|
| 830 |
+
<span>{m.dose_mg} mg</span>
|
| 831 |
+
<span className="med-atc">{m.atc_class}</span>
|
| 832 |
+
</div>
|
| 833 |
+
{m.beers_flags && m.beers_flags.length > 0 && (
|
| 834 |
+
<div className="beers-flags">
|
| 835 |
+
{m.beers_flags.map((f, i) => (
|
| 836 |
+
<span key={i} className="beers-tag">
|
| 837 |
+
{f}
|
| 838 |
+
</span>
|
| 839 |
+
))}
|
| 840 |
+
</div>
|
| 841 |
+
)}
|
| 842 |
+
</div>
|
| 843 |
+
))}
|
| 844 |
+
</div>
|
| 845 |
+
{(!obs?.current_medications ||
|
| 846 |
+
obs.current_medications.length === 0) && (
|
| 847 |
+
<p className="muted">No medications loaded. Reset an episode to begin.</p>
|
| 848 |
+
)}
|
| 849 |
+
</section>
|
| 850 |
+
|
| 851 |
+
{/* Interaction Queries & Interventions */}
|
| 852 |
+
{hasValidEpisode && (
|
| 853 |
+
<section className="panel glass panel-wide">
|
| 854 |
+
<div className="history-grid">
|
| 855 |
+
<div>
|
| 856 |
+
<h3>Drug Interaction Checks ({obs?.interaction_queries?.length || 0})</h3>
|
| 857 |
+
<div className="history-list">
|
| 858 |
+
{(obs?.interaction_queries || []).map((q, i) => (
|
| 859 |
+
<div
|
| 860 |
+
key={i}
|
| 861 |
+
className={`history-item severity-${q.severity}`}
|
| 862 |
+
>
|
| 863 |
+
<strong>
|
| 864 |
+
{formatDrugName(q.drug_id_1)} +{" "}
|
| 865 |
+
{formatDrugName(q.drug_id_2)}
|
| 866 |
+
</strong>
|
| 867 |
+
<span className={`severity-tag ${q.severity}`}>
|
| 868 |
+
{q.severity}
|
| 869 |
+
</span>
|
| 870 |
+
{q.recommendation && (
|
| 871 |
+
<p className="history-detail">
|
| 872 |
+
{q.recommendation.replace(/_/g, " ")}
|
| 873 |
+
</p>
|
| 874 |
+
)}
|
| 875 |
+
</div>
|
| 876 |
+
))}
|
| 877 |
+
{(!obs?.interaction_queries || obs.interaction_queries.length === 0) && (
|
| 878 |
+
<p className="muted">No queries yet.</p>
|
| 879 |
+
)}
|
| 880 |
+
</div>
|
| 881 |
+
</div>
|
| 882 |
+
<div>
|
| 883 |
+
<h3>Proposed Changes ({obs?.interventions?.length || 0})</h3>
|
| 884 |
+
<div className="history-list">
|
| 885 |
+
{(obs?.interventions || []).map((iv, i) => (
|
| 886 |
+
<div key={i} className="history-item intervention-item">
|
| 887 |
+
<strong>{formatDrugName(iv.target_drug_id)}</strong>
|
| 888 |
+
<span className="intervention-tag">
|
| 889 |
+
{INTERVENTION_LABELS[iv.action_type] || iv.action_type}
|
| 890 |
+
</span>
|
| 891 |
+
{iv.proposed_new_drug_id && (
|
| 892 |
+
<p className="history-detail">
|
| 893 |
+
Replaced with: {formatDrugName(iv.proposed_new_drug_id)}
|
| 894 |
+
</p>
|
| 895 |
+
)}
|
| 896 |
+
{iv.rationale && (
|
| 897 |
+
<p className="history-detail">{iv.rationale}</p>
|
| 898 |
+
)}
|
| 899 |
+
</div>
|
| 900 |
+
))}
|
| 901 |
+
{(!obs?.interventions || obs.interventions.length === 0) && (
|
| 902 |
+
<p className="muted">No interventions yet.</p>
|
| 903 |
+
)}
|
| 904 |
+
</div>
|
| 905 |
+
</div>
|
| 906 |
+
</div>
|
| 907 |
+
</section>
|
| 908 |
+
)}
|
| 909 |
+
|
| 910 |
+
{/* Event Log */}
|
| 911 |
+
<section className="panel glass panel-wide" data-guide="event-log">
|
| 912 |
+
<h2>Event Log</h2>
|
| 913 |
+
<div className="logs">
|
| 914 |
+
{log.length === 0 && (
|
| 915 |
+
<div className="log-empty">
|
| 916 |
+
Events will appear here as you interact with the environment.
|
| 917 |
+
</div>
|
| 918 |
+
)}
|
| 919 |
+
{log.map((line, idx) => (
|
| 920 |
+
<div key={idx}>{line}</div>
|
| 921 |
+
))}
|
| 922 |
+
</div>
|
| 923 |
+
</section>
|
| 924 |
+
</main>
|
| 925 |
+
|
| 926 |
+
<footer className="app-footer">
|
| 927 |
+
<p>
|
| 928 |
+
PolypharmacyEnv — Built with{" "}
|
| 929 |
+
<a
|
| 930 |
+
href="https://github.com/meta-pytorch/OpenEnv"
|
| 931 |
+
target="_blank"
|
| 932 |
+
rel="noopener noreferrer"
|
| 933 |
+
>
|
| 934 |
+
PyTorch OpenEnv
|
| 935 |
+
</a>{" "}
|
| 936 |
+
| Based on{" "}
|
| 937 |
+
<a
|
| 938 |
+
href="https://link.springer.com/chapter/10.1007/978-3-031-36938-4_5"
|
| 939 |
+
target="_blank"
|
| 940 |
+
rel="noopener noreferrer"
|
| 941 |
+
>
|
| 942 |
+
Neural Bandits for Polypharmacy
|
| 943 |
+
</a>{" "}
|
| 944 |
+
(Larouche et al.)
|
| 945 |
+
</p>
|
| 946 |
+
</footer>
|
| 947 |
</div>
|
| 948 |
</div>
|
| 949 |
);
|
frontend/src/styles.css
CHANGED
|
@@ -1,18 +1,22 @@
|
|
| 1 |
:root {
|
| 2 |
-
--bg: #
|
| 3 |
-
--bg-layer: #
|
| 4 |
-
--panel:
|
| 5 |
-
--panel-solid:
|
| 6 |
-
--text: #
|
| 7 |
-
--muted: #
|
| 8 |
-
--primary: #
|
| 9 |
-
--primary-2: #
|
| 10 |
-
--accent: #
|
| 11 |
-
--success: #
|
| 12 |
-
--
|
| 13 |
-
--
|
| 14 |
-
--
|
| 15 |
-
--
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
}
|
| 17 |
|
| 18 |
* {
|
|
@@ -22,187 +26,253 @@
|
|
| 22 |
body {
|
| 23 |
margin: 0;
|
| 24 |
color: var(--text);
|
| 25 |
-
font-family: "
|
| 26 |
-
background:
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
linear-gradient(145deg, var(--bg) 0%, var(--bg-layer) 60%, #04060f 100%);
|
| 31 |
-
background-attachment: fixed;
|
| 32 |
}
|
| 33 |
|
| 34 |
.shell {
|
| 35 |
min-height: 100vh;
|
| 36 |
position: relative;
|
| 37 |
overflow: hidden;
|
| 38 |
-
padding: 24px
|
| 39 |
}
|
| 40 |
|
| 41 |
.container {
|
| 42 |
-
width: min(
|
| 43 |
margin: 0 auto;
|
| 44 |
position: relative;
|
| 45 |
z-index: 2;
|
| 46 |
}
|
| 47 |
|
|
|
|
| 48 |
.bg-orb {
|
| 49 |
position: absolute;
|
| 50 |
border-radius: 50%;
|
| 51 |
pointer-events: none;
|
| 52 |
-
opacity: 0.
|
| 53 |
-
filter: blur(
|
| 54 |
}
|
| 55 |
|
| 56 |
.orb-a {
|
| 57 |
-
width: min(
|
| 58 |
aspect-ratio: 1 / 1;
|
| 59 |
-
right: -
|
| 60 |
-
top: -
|
| 61 |
-
background: radial-gradient(circle, rgba(
|
| 62 |
}
|
| 63 |
|
| 64 |
.orb-b {
|
| 65 |
-
width: min(
|
| 66 |
aspect-ratio: 1 / 1;
|
| 67 |
-
left: -
|
| 68 |
-
bottom: -
|
| 69 |
-
background: radial-gradient(circle, rgba(
|
| 70 |
}
|
| 71 |
|
|
|
|
|
|
|
| 72 |
.glass {
|
| 73 |
-
background:
|
| 74 |
-
linear-gradient(180deg, rgba(255, 255, 255, 0.06), rgba(255, 255, 255, 0.01)),
|
| 75 |
-
var(--panel);
|
| 76 |
border: 1px solid var(--border);
|
| 77 |
box-shadow: var(--shadow);
|
| 78 |
-
backdrop-filter: blur(12px);
|
| 79 |
}
|
| 80 |
|
|
|
|
|
|
|
| 81 |
.topbar {
|
| 82 |
-
border-radius:
|
| 83 |
-
padding:
|
| 84 |
-
display:
|
| 85 |
-
gap: 12px 16px;
|
| 86 |
-
grid-template-columns: minmax(220px, 1.2fr) auto minmax(280px, 1fr);
|
| 87 |
align-items: center;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
}
|
| 89 |
|
| 90 |
.title-wrap h1 {
|
| 91 |
margin: 0;
|
| 92 |
-
font-size: clamp(
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
}
|
| 97 |
|
| 98 |
.title-wrap p {
|
| 99 |
-
margin:
|
| 100 |
-
font-size: 0.
|
| 101 |
color: var(--muted);
|
| 102 |
-
letter-spacing: 0.
|
| 103 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
}
|
| 105 |
|
| 106 |
.status-chip {
|
| 107 |
-
|
| 108 |
-
padding: 7px 14px;
|
| 109 |
border-radius: 999px;
|
| 110 |
font-size: 0.72rem;
|
| 111 |
-
font-weight:
|
| 112 |
-
letter-spacing: 0.
|
| 113 |
text-transform: uppercase;
|
| 114 |
border: 1px solid transparent;
|
|
|
|
| 115 |
}
|
| 116 |
|
| 117 |
.status-chip.live {
|
| 118 |
-
color: #
|
| 119 |
-
background:
|
| 120 |
-
box-shadow:
|
| 121 |
}
|
| 122 |
|
| 123 |
.status-chip.idle {
|
| 124 |
-
color:
|
| 125 |
-
border-color:
|
| 126 |
-
background: rgba(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
}
|
| 128 |
|
| 129 |
.actions {
|
| 130 |
display: flex;
|
| 131 |
-
|
| 132 |
flex-wrap: wrap;
|
| 133 |
gap: 10px;
|
| 134 |
}
|
| 135 |
|
|
|
|
|
|
|
| 136 |
button,
|
| 137 |
select,
|
| 138 |
input {
|
| 139 |
width: 100%;
|
| 140 |
min-height: 42px;
|
| 141 |
-
border-radius:
|
| 142 |
border: 1px solid var(--border);
|
| 143 |
-
font-size: 0.
|
| 144 |
-
padding: 10px
|
| 145 |
color: var(--text);
|
| 146 |
-
background:
|
|
|
|
|
|
|
| 147 |
}
|
| 148 |
|
| 149 |
-
select,
|
| 150 |
-
input {
|
| 151 |
-
|
| 152 |
}
|
| 153 |
|
| 154 |
select:focus,
|
| 155 |
input:focus {
|
| 156 |
outline: none;
|
| 157 |
-
border-color:
|
| 158 |
-
box-shadow: 0 0 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 159 |
}
|
| 160 |
|
| 161 |
button {
|
| 162 |
cursor: pointer;
|
| 163 |
border: 0;
|
| 164 |
width: auto;
|
| 165 |
-
font-weight:
|
| 166 |
-
letter-spacing: 0.
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
|
|
|
|
|
|
| 170 |
}
|
| 171 |
|
| 172 |
button:hover {
|
|
|
|
| 173 |
transform: translateY(-1px);
|
| 174 |
-
filter:
|
| 175 |
-
box-shadow: 0
|
| 176 |
}
|
| 177 |
|
| 178 |
button:active {
|
| 179 |
transform: translateY(0);
|
|
|
|
| 180 |
}
|
| 181 |
|
| 182 |
button.secondary {
|
| 183 |
-
background:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
}
|
| 185 |
|
| 186 |
button:disabled {
|
| 187 |
-
opacity: 0.
|
| 188 |
cursor: not-allowed;
|
| 189 |
-
filter:
|
| 190 |
box-shadow: none;
|
| 191 |
transform: none;
|
| 192 |
}
|
| 193 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 194 |
.layout {
|
| 195 |
-
margin-top:
|
| 196 |
display: grid;
|
| 197 |
-
gap:
|
| 198 |
-
grid-template-columns: 1.
|
| 199 |
align-items: start;
|
| 200 |
}
|
| 201 |
|
| 202 |
.panel {
|
| 203 |
-
border-radius:
|
| 204 |
-
padding:
|
| 205 |
position: relative;
|
|
|
|
| 206 |
}
|
| 207 |
|
| 208 |
.panel::after {
|
|
@@ -219,151 +289,576 @@ button:disabled {
|
|
| 219 |
}
|
| 220 |
|
| 221 |
.panel h2 {
|
| 222 |
-
margin: 0 0
|
| 223 |
-
font-size:
|
| 224 |
-
font-weight:
|
| 225 |
letter-spacing: 0.05em;
|
| 226 |
text-transform: uppercase;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 227 |
}
|
| 228 |
|
|
|
|
|
|
|
| 229 |
.kpi-grid {
|
| 230 |
display: grid;
|
| 231 |
-
gap:
|
| 232 |
-
grid-template-columns: repeat(3,
|
| 233 |
}
|
| 234 |
|
| 235 |
.kpi-grid div {
|
| 236 |
-
border-radius:
|
| 237 |
border: 1px solid var(--border);
|
| 238 |
background: var(--panel-solid);
|
| 239 |
-
padding:
|
|
|
|
| 240 |
}
|
| 241 |
|
| 242 |
.kpi-grid span {
|
| 243 |
display: block;
|
| 244 |
-
margin-bottom:
|
| 245 |
font-size: 0.72rem;
|
| 246 |
color: var(--muted);
|
| 247 |
text-transform: uppercase;
|
| 248 |
-
letter-spacing: 0.
|
| 249 |
}
|
| 250 |
|
| 251 |
.kpi-grid strong {
|
| 252 |
-
font-size: 1.
|
| 253 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 254 |
}
|
| 255 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 256 |
.action-row,
|
| 257 |
.stack {
|
| 258 |
display: grid;
|
| 259 |
-
gap:
|
| 260 |
-
margin-bottom:
|
| 261 |
}
|
| 262 |
|
| 263 |
-
.action-row label
|
|
|
|
|
|
|
| 264 |
color: var(--muted);
|
| 265 |
font-size: 0.78rem;
|
| 266 |
-
letter-spacing: 0.
|
| 267 |
text-transform: uppercase;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 268 |
}
|
| 269 |
|
| 270 |
.stack-two {
|
| 271 |
-
grid-template-columns:
|
| 272 |
}
|
| 273 |
|
|
|
|
|
|
|
| 274 |
.med-grid {
|
| 275 |
display: grid;
|
| 276 |
-
grid-template-columns: repeat(
|
| 277 |
-
gap:
|
| 278 |
-
max-height:
|
| 279 |
-
overflow: auto;
|
| 280 |
padding-right: 4px;
|
| 281 |
}
|
| 282 |
|
| 283 |
.med-card {
|
| 284 |
-
border-radius:
|
| 285 |
border: 1px solid var(--border);
|
| 286 |
background: var(--panel-solid);
|
| 287 |
-
padding:
|
| 288 |
-
transition:
|
|
|
|
| 289 |
}
|
| 290 |
|
| 291 |
.med-card:hover {
|
| 292 |
-
|
| 293 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 294 |
}
|
| 295 |
|
| 296 |
-
.med-
|
| 297 |
-
margin: 6px 0
|
| 298 |
color: var(--muted);
|
|
|
|
| 299 |
text-transform: capitalize;
|
|
|
|
|
|
|
|
|
|
| 300 |
}
|
| 301 |
|
| 302 |
-
.med-
|
| 303 |
-
|
|
|
|
|
|
|
|
|
|
| 304 |
}
|
| 305 |
|
| 306 |
-
.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 307 |
max-height: 300px;
|
| 308 |
-
overflow: auto;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 309 |
padding-right: 4px;
|
| 310 |
-
display:
|
| 311 |
-
|
|
|
|
| 312 |
font-size: 0.84rem;
|
| 313 |
-
font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, monospace;
|
| 314 |
}
|
| 315 |
|
| 316 |
.logs div {
|
| 317 |
-
border-radius:
|
| 318 |
border: 1px solid var(--border);
|
| 319 |
-
background:
|
| 320 |
-
padding:
|
| 321 |
-
color:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 322 |
}
|
| 323 |
|
|
|
|
|
|
|
| 324 |
.muted {
|
| 325 |
margin: 0;
|
| 326 |
color: var(--muted);
|
|
|
|
| 327 |
}
|
| 328 |
|
| 329 |
.budget-note {
|
| 330 |
-
margin-top:
|
| 331 |
border: 1px solid var(--border);
|
| 332 |
-
border-radius:
|
| 333 |
-
padding:
|
| 334 |
-
background:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 335 |
}
|
| 336 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 337 |
@media (max-width: 1180px) {
|
| 338 |
.layout {
|
| 339 |
grid-template-columns: 1fr;
|
| 340 |
}
|
| 341 |
|
| 342 |
.topbar {
|
| 343 |
-
|
|
|
|
| 344 |
}
|
| 345 |
|
| 346 |
-
.
|
| 347 |
-
|
| 348 |
}
|
| 349 |
|
| 350 |
.actions {
|
|
|
|
| 351 |
justify-content: flex-start;
|
| 352 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 353 |
}
|
| 354 |
|
| 355 |
@media (max-width: 760px) {
|
| 356 |
.shell {
|
| 357 |
-
padding:
|
| 358 |
}
|
| 359 |
|
| 360 |
.topbar,
|
| 361 |
.panel {
|
| 362 |
-
border-radius:
|
|
|
|
| 363 |
}
|
| 364 |
|
| 365 |
.actions {
|
| 366 |
-
|
| 367 |
}
|
| 368 |
|
| 369 |
.actions button,
|
|
@@ -371,13 +866,53 @@ button:disabled {
|
|
| 371 |
width: 100%;
|
| 372 |
}
|
| 373 |
|
| 374 |
-
.kpi-grid
|
| 375 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 376 |
.stack-two {
|
| 377 |
grid-template-columns: 1fr;
|
| 378 |
}
|
| 379 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 380 |
.logs {
|
| 381 |
-
max-height:
|
| 382 |
}
|
| 383 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
:root {
|
| 2 |
+
--bg: #0d1117;
|
| 3 |
+
--bg-layer: #0d1117;
|
| 4 |
+
--panel: #161b22;
|
| 5 |
+
--panel-solid: #1c2333;
|
| 6 |
+
--text: #e6edf3;
|
| 7 |
+
--muted: #8b949e;
|
| 8 |
+
--primary: #58a6ff;
|
| 9 |
+
--primary-2: #79c0ff;
|
| 10 |
+
--accent: #58a6ff;
|
| 11 |
+
--success: #3fb950;
|
| 12 |
+
--danger: #f85149;
|
| 13 |
+
--warning: #d29922;
|
| 14 |
+
--border: rgba(48, 54, 61, 0.7);
|
| 15 |
+
--line: rgba(48, 54, 61, 0.5);
|
| 16 |
+
--shadow: 0 1px 3px rgba(0, 0, 0, 0.12), 0 4px 12px rgba(0, 0, 0, 0.08);
|
| 17 |
+
--shadow-strong: 0 4px 16px rgba(0, 0, 0, 0.2);
|
| 18 |
+
--radius: 12px;
|
| 19 |
+
--radius-sm: 8px;
|
| 20 |
}
|
| 21 |
|
| 22 |
* {
|
|
|
|
| 26 |
body {
|
| 27 |
margin: 0;
|
| 28 |
color: var(--text);
|
| 29 |
+
font-family: "Inter", -apple-system, "Segoe UI", "Helvetica Neue", sans-serif;
|
| 30 |
+
background: var(--bg);
|
| 31 |
+
line-height: 1.55;
|
| 32 |
+
-webkit-font-smoothing: antialiased;
|
| 33 |
+
-moz-osx-font-smoothing: grayscale;
|
|
|
|
|
|
|
| 34 |
}
|
| 35 |
|
| 36 |
.shell {
|
| 37 |
min-height: 100vh;
|
| 38 |
position: relative;
|
| 39 |
overflow: hidden;
|
| 40 |
+
padding: 20px 24px 40px;
|
| 41 |
}
|
| 42 |
|
| 43 |
.container {
|
| 44 |
+
width: min(1400px, 100%);
|
| 45 |
margin: 0 auto;
|
| 46 |
position: relative;
|
| 47 |
z-index: 2;
|
| 48 |
}
|
| 49 |
|
| 50 |
+
/* Background orbs - subtle and muted for a professional look */
|
| 51 |
.bg-orb {
|
| 52 |
position: absolute;
|
| 53 |
border-radius: 50%;
|
| 54 |
pointer-events: none;
|
| 55 |
+
opacity: 0.3;
|
| 56 |
+
filter: blur(80px);
|
| 57 |
}
|
| 58 |
|
| 59 |
.orb-a {
|
| 60 |
+
width: min(42vw, 500px);
|
| 61 |
aspect-ratio: 1 / 1;
|
| 62 |
+
right: -8%;
|
| 63 |
+
top: -8%;
|
| 64 |
+
background: radial-gradient(circle, rgba(88, 166, 255, 0.12), transparent 70%);
|
| 65 |
}
|
| 66 |
|
| 67 |
.orb-b {
|
| 68 |
+
width: min(36vw, 420px);
|
| 69 |
aspect-ratio: 1 / 1;
|
| 70 |
+
left: -8%;
|
| 71 |
+
bottom: -12%;
|
| 72 |
+
background: radial-gradient(circle, rgba(88, 166, 255, 0.08), transparent 72%);
|
| 73 |
}
|
| 74 |
|
| 75 |
+
/* ── Panels ──────────────────────────────────────────────────── */
|
| 76 |
+
|
| 77 |
.glass {
|
| 78 |
+
background: var(--panel);
|
|
|
|
|
|
|
| 79 |
border: 1px solid var(--border);
|
| 80 |
box-shadow: var(--shadow);
|
|
|
|
| 81 |
}
|
| 82 |
|
| 83 |
+
/* ── Top Bar ─────────────────────────────────────────────────── */
|
| 84 |
+
|
| 85 |
.topbar {
|
| 86 |
+
border-radius: var(--radius);
|
| 87 |
+
padding: 14px 24px;
|
| 88 |
+
display: flex;
|
|
|
|
|
|
|
| 89 |
align-items: center;
|
| 90 |
+
gap: 16px;
|
| 91 |
+
flex-wrap: wrap;
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
.title-wrap {
|
| 95 |
+
flex: 1;
|
| 96 |
+
min-width: 200px;
|
| 97 |
}
|
| 98 |
|
| 99 |
.title-wrap h1 {
|
| 100 |
margin: 0;
|
| 101 |
+
font-size: clamp(1rem, 1.8vw, 1.35rem);
|
| 102 |
+
font-weight: 700;
|
| 103 |
+
letter-spacing: -0.01em;
|
| 104 |
+
color: var(--text);
|
| 105 |
+
background: none;
|
| 106 |
+
-webkit-background-clip: unset;
|
| 107 |
+
-webkit-text-fill-color: unset;
|
| 108 |
+
background-clip: unset;
|
| 109 |
}
|
| 110 |
|
| 111 |
.title-wrap p {
|
| 112 |
+
margin: 2px 0 0;
|
| 113 |
+
font-size: 0.8rem;
|
| 114 |
color: var(--muted);
|
| 115 |
+
letter-spacing: 0.01em;
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
.topbar-right {
|
| 119 |
+
display: flex;
|
| 120 |
+
align-items: center;
|
| 121 |
+
gap: 10px;
|
| 122 |
}
|
| 123 |
|
| 124 |
.status-chip {
|
| 125 |
+
padding: 5px 14px;
|
|
|
|
| 126 |
border-radius: 999px;
|
| 127 |
font-size: 0.72rem;
|
| 128 |
+
font-weight: 600;
|
| 129 |
+
letter-spacing: 0.04em;
|
| 130 |
text-transform: uppercase;
|
| 131 |
border: 1px solid transparent;
|
| 132 |
+
white-space: nowrap;
|
| 133 |
}
|
| 134 |
|
| 135 |
.status-chip.live {
|
| 136 |
+
color: #ffffff;
|
| 137 |
+
background: var(--success);
|
| 138 |
+
box-shadow: none;
|
| 139 |
}
|
| 140 |
|
| 141 |
.status-chip.idle {
|
| 142 |
+
color: var(--muted);
|
| 143 |
+
border-color: var(--border);
|
| 144 |
+
background: rgba(48, 54, 61, 0.3);
|
| 145 |
+
}
|
| 146 |
+
|
| 147 |
+
.guide-trigger {
|
| 148 |
+
width: 34px !important;
|
| 149 |
+
height: 34px;
|
| 150 |
+
min-height: 34px;
|
| 151 |
+
padding: 0;
|
| 152 |
+
border-radius: 50% !important;
|
| 153 |
+
font-size: 0.95rem;
|
| 154 |
+
font-weight: 700;
|
| 155 |
+
display: flex;
|
| 156 |
+
align-items: center;
|
| 157 |
+
justify-content: center;
|
| 158 |
+
background: rgba(88, 166, 255, 0.1);
|
| 159 |
+
border: 1px solid rgba(88, 166, 255, 0.25);
|
| 160 |
+
color: var(--primary);
|
| 161 |
+
box-shadow: none;
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
.guide-trigger:hover {
|
| 165 |
+
background: rgba(88, 166, 255, 0.2);
|
| 166 |
+
box-shadow: none;
|
| 167 |
}
|
| 168 |
|
| 169 |
.actions {
|
| 170 |
display: flex;
|
| 171 |
+
align-items: center;
|
| 172 |
flex-wrap: wrap;
|
| 173 |
gap: 10px;
|
| 174 |
}
|
| 175 |
|
| 176 |
+
/* ── Form Controls ───────────────────────────────────────────── */
|
| 177 |
+
|
| 178 |
button,
|
| 179 |
select,
|
| 180 |
input {
|
| 181 |
width: 100%;
|
| 182 |
min-height: 42px;
|
| 183 |
+
border-radius: var(--radius-sm);
|
| 184 |
border: 1px solid var(--border);
|
| 185 |
+
font-size: 0.9rem;
|
| 186 |
+
padding: 10px 14px;
|
| 187 |
color: var(--text);
|
| 188 |
+
background: #0d1117;
|
| 189 |
+
font-family: inherit;
|
| 190 |
+
transition: border-color 150ms ease, box-shadow 150ms ease, background 150ms ease;
|
| 191 |
}
|
| 192 |
|
| 193 |
+
select:hover,
|
| 194 |
+
input:hover {
|
| 195 |
+
border-color: rgba(139, 148, 158, 0.5);
|
| 196 |
}
|
| 197 |
|
| 198 |
select:focus,
|
| 199 |
input:focus {
|
| 200 |
outline: none;
|
| 201 |
+
border-color: var(--primary);
|
| 202 |
+
box-shadow: 0 0 0 3px rgba(88, 166, 255, 0.15);
|
| 203 |
+
background: #0d1117;
|
| 204 |
+
}
|
| 205 |
+
|
| 206 |
+
select {
|
| 207 |
+
cursor: pointer;
|
| 208 |
+
appearance: auto;
|
| 209 |
}
|
| 210 |
|
| 211 |
button {
|
| 212 |
cursor: pointer;
|
| 213 |
border: 0;
|
| 214 |
width: auto;
|
| 215 |
+
font-weight: 600;
|
| 216 |
+
letter-spacing: 0.01em;
|
| 217 |
+
white-space: nowrap;
|
| 218 |
+
background: var(--primary);
|
| 219 |
+
color: #ffffff;
|
| 220 |
+
box-shadow: none;
|
| 221 |
+
transition: background 150ms ease, transform 100ms ease, box-shadow 150ms ease;
|
| 222 |
}
|
| 223 |
|
| 224 |
button:hover {
|
| 225 |
+
background: #79c0ff;
|
| 226 |
transform: translateY(-1px);
|
| 227 |
+
filter: none;
|
| 228 |
+
box-shadow: 0 2px 8px rgba(88, 166, 255, 0.25);
|
| 229 |
}
|
| 230 |
|
| 231 |
button:active {
|
| 232 |
transform: translateY(0);
|
| 233 |
+
background: #4090e0;
|
| 234 |
}
|
| 235 |
|
| 236 |
button.secondary {
|
| 237 |
+
background: rgba(88, 166, 255, 0.15);
|
| 238 |
+
color: var(--primary);
|
| 239 |
+
border: 1px solid rgba(88, 166, 255, 0.3);
|
| 240 |
+
}
|
| 241 |
+
|
| 242 |
+
button.secondary:hover {
|
| 243 |
+
background: rgba(88, 166, 255, 0.25);
|
| 244 |
}
|
| 245 |
|
| 246 |
button:disabled {
|
| 247 |
+
opacity: 0.4;
|
| 248 |
cursor: not-allowed;
|
| 249 |
+
filter: none;
|
| 250 |
box-shadow: none;
|
| 251 |
transform: none;
|
| 252 |
}
|
| 253 |
|
| 254 |
+
.submit-btn {
|
| 255 |
+
width: 100%;
|
| 256 |
+
margin-top: 10px;
|
| 257 |
+
min-height: 44px;
|
| 258 |
+
font-size: 0.92rem;
|
| 259 |
+
}
|
| 260 |
+
|
| 261 |
+
/* ── Layout ──────────────────────────────────────────────────── */
|
| 262 |
+
|
| 263 |
.layout {
|
| 264 |
+
margin-top: 20px;
|
| 265 |
display: grid;
|
| 266 |
+
gap: 20px;
|
| 267 |
+
grid-template-columns: 1.1fr 0.9fr;
|
| 268 |
align-items: start;
|
| 269 |
}
|
| 270 |
|
| 271 |
.panel {
|
| 272 |
+
border-radius: var(--radius);
|
| 273 |
+
padding: 24px;
|
| 274 |
position: relative;
|
| 275 |
+
overflow: hidden;
|
| 276 |
}
|
| 277 |
|
| 278 |
.panel::after {
|
|
|
|
| 289 |
}
|
| 290 |
|
| 291 |
.panel h2 {
|
| 292 |
+
margin: 0 0 16px;
|
| 293 |
+
font-size: 0.82rem;
|
| 294 |
+
font-weight: 600;
|
| 295 |
letter-spacing: 0.05em;
|
| 296 |
text-transform: uppercase;
|
| 297 |
+
color: var(--muted);
|
| 298 |
+
}
|
| 299 |
+
|
| 300 |
+
.panel h3 {
|
| 301 |
+
margin: 0 0 12px;
|
| 302 |
+
font-size: 0.8rem;
|
| 303 |
+
font-weight: 600;
|
| 304 |
+
letter-spacing: 0.04em;
|
| 305 |
+
text-transform: uppercase;
|
| 306 |
+
color: var(--muted);
|
| 307 |
}
|
| 308 |
|
| 309 |
+
/* ── KPI Grid ────────────────────────────────────────────────── */
|
| 310 |
+
|
| 311 |
.kpi-grid {
|
| 312 |
display: grid;
|
| 313 |
+
gap: 12px;
|
| 314 |
+
grid-template-columns: repeat(3, 1fr);
|
| 315 |
}
|
| 316 |
|
| 317 |
.kpi-grid div {
|
| 318 |
+
border-radius: var(--radius-sm);
|
| 319 |
border: 1px solid var(--border);
|
| 320 |
background: var(--panel-solid);
|
| 321 |
+
padding: 16px 18px;
|
| 322 |
+
overflow: hidden;
|
| 323 |
}
|
| 324 |
|
| 325 |
.kpi-grid span {
|
| 326 |
display: block;
|
| 327 |
+
margin-bottom: 6px;
|
| 328 |
font-size: 0.72rem;
|
| 329 |
color: var(--muted);
|
| 330 |
text-transform: uppercase;
|
| 331 |
+
letter-spacing: 0.06em;
|
| 332 |
}
|
| 333 |
|
| 334 |
.kpi-grid strong {
|
| 335 |
+
font-size: 1.05rem;
|
| 336 |
+
font-weight: 600;
|
| 337 |
+
line-height: 1.3;
|
| 338 |
+
word-break: break-word;
|
| 339 |
+
overflow-wrap: break-word;
|
| 340 |
+
color: var(--text);
|
| 341 |
+
}
|
| 342 |
+
|
| 343 |
+
/* ── Risk Bar ────────────────────────────────────────────────── */
|
| 344 |
+
|
| 345 |
+
.risk-bar-wrap {
|
| 346 |
+
margin-top: 16px;
|
| 347 |
+
}
|
| 348 |
+
|
| 349 |
+
.risk-labels {
|
| 350 |
+
display: flex;
|
| 351 |
+
justify-content: space-between;
|
| 352 |
+
font-size: 0.8rem;
|
| 353 |
+
color: var(--muted);
|
| 354 |
+
margin-bottom: 8px;
|
| 355 |
}
|
| 356 |
|
| 357 |
+
.risk-labels strong {
|
| 358 |
+
font-size: 0.85rem;
|
| 359 |
+
color: var(--text);
|
| 360 |
+
}
|
| 361 |
+
|
| 362 |
+
.risk-down {
|
| 363 |
+
color: var(--success) !important;
|
| 364 |
+
}
|
| 365 |
+
|
| 366 |
+
.risk-same {
|
| 367 |
+
color: var(--warning) !important;
|
| 368 |
+
}
|
| 369 |
+
|
| 370 |
+
.risk-bar {
|
| 371 |
+
height: 6px;
|
| 372 |
+
background: rgba(48, 54, 61, 0.5);
|
| 373 |
+
border-radius: 3px;
|
| 374 |
+
overflow: hidden;
|
| 375 |
+
}
|
| 376 |
+
|
| 377 |
+
.risk-fill {
|
| 378 |
+
height: 100%;
|
| 379 |
+
background: linear-gradient(90deg, var(--success), var(--warning), var(--danger));
|
| 380 |
+
border-radius: 3px;
|
| 381 |
+
transition: width 300ms ease;
|
| 382 |
+
}
|
| 383 |
+
|
| 384 |
+
/* ── Conditions ──────────────────────────────────────────────── */
|
| 385 |
+
|
| 386 |
+
.conditions-row {
|
| 387 |
+
margin-top: 14px;
|
| 388 |
+
display: flex;
|
| 389 |
+
flex-wrap: wrap;
|
| 390 |
+
align-items: center;
|
| 391 |
+
gap: 8px;
|
| 392 |
+
}
|
| 393 |
+
|
| 394 |
+
.conditions-label {
|
| 395 |
+
font-size: 0.78rem;
|
| 396 |
+
color: var(--muted);
|
| 397 |
+
text-transform: uppercase;
|
| 398 |
+
letter-spacing: 0.04em;
|
| 399 |
+
margin-right: 4px;
|
| 400 |
+
}
|
| 401 |
+
|
| 402 |
+
.condition-tag {
|
| 403 |
+
font-size: 0.75rem;
|
| 404 |
+
padding: 4px 12px;
|
| 405 |
+
border-radius: 999px;
|
| 406 |
+
background: rgba(88, 166, 255, 0.1);
|
| 407 |
+
border: 1px solid rgba(88, 166, 255, 0.2);
|
| 408 |
+
color: var(--primary);
|
| 409 |
+
text-transform: capitalize;
|
| 410 |
+
white-space: nowrap;
|
| 411 |
+
}
|
| 412 |
+
|
| 413 |
+
/* ── Action Console ──────────────────────────────────────────── */
|
| 414 |
+
|
| 415 |
.action-row,
|
| 416 |
.stack {
|
| 417 |
display: grid;
|
| 418 |
+
gap: 12px;
|
| 419 |
+
margin-bottom: 14px;
|
| 420 |
}
|
| 421 |
|
| 422 |
+
.action-row label,
|
| 423 |
+
.field-group label {
|
| 424 |
+
display: block;
|
| 425 |
color: var(--muted);
|
| 426 |
font-size: 0.78rem;
|
| 427 |
+
letter-spacing: 0.04em;
|
| 428 |
text-transform: uppercase;
|
| 429 |
+
margin-bottom: 6px;
|
| 430 |
+
}
|
| 431 |
+
|
| 432 |
+
.field-group {
|
| 433 |
+
display: flex;
|
| 434 |
+
flex-direction: column;
|
| 435 |
}
|
| 436 |
|
| 437 |
.stack-two {
|
| 438 |
+
grid-template-columns: 1fr 1fr;
|
| 439 |
}
|
| 440 |
|
| 441 |
+
/* ── Medication Cards ────────────────────────────────────────── */
|
| 442 |
+
|
| 443 |
.med-grid {
|
| 444 |
display: grid;
|
| 445 |
+
grid-template-columns: repeat(auto-fill, minmax(210px, 1fr));
|
| 446 |
+
gap: 12px;
|
| 447 |
+
max-height: 480px;
|
| 448 |
+
overflow-y: auto;
|
| 449 |
padding-right: 4px;
|
| 450 |
}
|
| 451 |
|
| 452 |
.med-card {
|
| 453 |
+
border-radius: var(--radius-sm);
|
| 454 |
border: 1px solid var(--border);
|
| 455 |
background: var(--panel-solid);
|
| 456 |
+
padding: 16px 18px;
|
| 457 |
+
transition: border-color 150ms ease, background 150ms ease;
|
| 458 |
+
overflow: hidden;
|
| 459 |
}
|
| 460 |
|
| 461 |
.med-card:hover {
|
| 462 |
+
border-color: rgba(88, 166, 255, 0.4);
|
| 463 |
+
background: #1f2937;
|
| 464 |
+
}
|
| 465 |
+
|
| 466 |
+
.med-card.high-risk {
|
| 467 |
+
border-color: rgba(248, 81, 73, 0.35);
|
| 468 |
+
}
|
| 469 |
+
|
| 470 |
+
.med-card-header {
|
| 471 |
+
display: flex;
|
| 472 |
+
align-items: center;
|
| 473 |
+
justify-content: space-between;
|
| 474 |
+
gap: 8px;
|
| 475 |
+
}
|
| 476 |
+
|
| 477 |
+
.med-card-header strong {
|
| 478 |
+
font-size: 0.92rem;
|
| 479 |
+
font-weight: 600;
|
| 480 |
+
overflow: hidden;
|
| 481 |
+
text-overflow: ellipsis;
|
| 482 |
+
white-space: nowrap;
|
| 483 |
+
color: var(--text);
|
| 484 |
+
}
|
| 485 |
+
|
| 486 |
+
.risk-badge {
|
| 487 |
+
font-size: 0.65rem;
|
| 488 |
+
padding: 3px 9px;
|
| 489 |
+
border-radius: 999px;
|
| 490 |
+
background: rgba(248, 81, 73, 0.12);
|
| 491 |
+
border: 1px solid rgba(248, 81, 73, 0.3);
|
| 492 |
+
color: var(--danger);
|
| 493 |
+
text-transform: uppercase;
|
| 494 |
+
letter-spacing: 0.04em;
|
| 495 |
+
font-weight: 600;
|
| 496 |
+
white-space: nowrap;
|
| 497 |
+
flex-shrink: 0;
|
| 498 |
}
|
| 499 |
|
| 500 |
+
.med-generic {
|
| 501 |
+
margin: 6px 0;
|
| 502 |
color: var(--muted);
|
| 503 |
+
font-size: 0.84rem;
|
| 504 |
text-transform: capitalize;
|
| 505 |
+
overflow: hidden;
|
| 506 |
+
text-overflow: ellipsis;
|
| 507 |
+
white-space: nowrap;
|
| 508 |
}
|
| 509 |
|
| 510 |
+
.med-details {
|
| 511 |
+
display: flex;
|
| 512 |
+
gap: 10px;
|
| 513 |
+
font-size: 0.8rem;
|
| 514 |
+
color: #8b949e;
|
| 515 |
}
|
| 516 |
|
| 517 |
+
.med-atc {
|
| 518 |
+
color: var(--primary);
|
| 519 |
+
font-weight: 600;
|
| 520 |
+
}
|
| 521 |
+
|
| 522 |
+
.beers-flags {
|
| 523 |
+
margin-top: 8px;
|
| 524 |
+
display: flex;
|
| 525 |
+
flex-wrap: wrap;
|
| 526 |
+
gap: 5px;
|
| 527 |
+
}
|
| 528 |
+
|
| 529 |
+
.beers-tag {
|
| 530 |
+
font-size: 0.68rem;
|
| 531 |
+
padding: 3px 9px;
|
| 532 |
+
border-radius: 999px;
|
| 533 |
+
background: rgba(210, 153, 34, 0.1);
|
| 534 |
+
border: 1px solid rgba(210, 153, 34, 0.25);
|
| 535 |
+
color: var(--warning);
|
| 536 |
+
text-transform: capitalize;
|
| 537 |
+
}
|
| 538 |
+
|
| 539 |
+
/* ── History Grid ────────────────────────────────────────────── */
|
| 540 |
+
|
| 541 |
+
.history-grid {
|
| 542 |
+
display: grid;
|
| 543 |
+
grid-template-columns: 1fr 1fr;
|
| 544 |
+
gap: 24px;
|
| 545 |
+
}
|
| 546 |
+
|
| 547 |
+
.history-list {
|
| 548 |
+
display: flex;
|
| 549 |
+
flex-direction: column;
|
| 550 |
+
gap: 10px;
|
| 551 |
max-height: 300px;
|
| 552 |
+
overflow-y: auto;
|
| 553 |
+
}
|
| 554 |
+
|
| 555 |
+
.history-item {
|
| 556 |
+
border-radius: var(--radius-sm);
|
| 557 |
+
border: 1px solid var(--border);
|
| 558 |
+
background: var(--panel-solid);
|
| 559 |
+
padding: 14px 16px;
|
| 560 |
+
}
|
| 561 |
+
|
| 562 |
+
.history-item strong {
|
| 563 |
+
font-size: 0.88rem;
|
| 564 |
+
font-weight: 600;
|
| 565 |
+
display: block;
|
| 566 |
+
margin-bottom: 6px;
|
| 567 |
+
color: var(--text);
|
| 568 |
+
}
|
| 569 |
+
|
| 570 |
+
.history-detail {
|
| 571 |
+
margin: 6px 0 0;
|
| 572 |
+
font-size: 0.82rem;
|
| 573 |
+
color: var(--muted);
|
| 574 |
+
text-transform: capitalize;
|
| 575 |
+
}
|
| 576 |
+
|
| 577 |
+
.severity-tag, .intervention-tag {
|
| 578 |
+
display: inline-block;
|
| 579 |
+
font-size: 0.7rem;
|
| 580 |
+
padding: 3px 10px;
|
| 581 |
+
border-radius: 999px;
|
| 582 |
+
font-weight: 600;
|
| 583 |
+
text-transform: uppercase;
|
| 584 |
+
letter-spacing: 0.04em;
|
| 585 |
+
}
|
| 586 |
+
|
| 587 |
+
.severity-tag.severe {
|
| 588 |
+
background: rgba(248, 81, 73, 0.12);
|
| 589 |
+
border: 1px solid rgba(248, 81, 73, 0.3);
|
| 590 |
+
color: var(--danger);
|
| 591 |
+
}
|
| 592 |
+
|
| 593 |
+
.severity-tag.moderate {
|
| 594 |
+
background: rgba(210, 153, 34, 0.12);
|
| 595 |
+
border: 1px solid rgba(210, 153, 34, 0.3);
|
| 596 |
+
color: var(--warning);
|
| 597 |
+
}
|
| 598 |
+
|
| 599 |
+
.severity-tag.mild {
|
| 600 |
+
background: rgba(63, 185, 80, 0.12);
|
| 601 |
+
border: 1px solid rgba(63, 185, 80, 0.3);
|
| 602 |
+
color: var(--success);
|
| 603 |
+
}
|
| 604 |
+
|
| 605 |
+
.severity-tag.none {
|
| 606 |
+
background: rgba(48, 54, 61, 0.3);
|
| 607 |
+
border: 1px solid var(--border);
|
| 608 |
+
color: var(--muted);
|
| 609 |
+
}
|
| 610 |
+
|
| 611 |
+
.intervention-tag {
|
| 612 |
+
background: rgba(88, 166, 255, 0.1);
|
| 613 |
+
border: 1px solid rgba(88, 166, 255, 0.25);
|
| 614 |
+
color: var(--primary);
|
| 615 |
+
}
|
| 616 |
+
|
| 617 |
+
.severity-severe .history-item {
|
| 618 |
+
border-color: rgba(248, 81, 73, 0.2);
|
| 619 |
+
}
|
| 620 |
+
|
| 621 |
+
.severity-moderate .history-item {
|
| 622 |
+
border-color: rgba(210, 153, 34, 0.2);
|
| 623 |
+
}
|
| 624 |
+
|
| 625 |
+
/* ── Event Log ───────────────────────────────────────────────── */
|
| 626 |
+
|
| 627 |
+
.logs {
|
| 628 |
+
max-height: 280px;
|
| 629 |
+
overflow-y: auto;
|
| 630 |
padding-right: 4px;
|
| 631 |
+
display: flex;
|
| 632 |
+
flex-direction: column;
|
| 633 |
+
gap: 8px;
|
| 634 |
font-size: 0.84rem;
|
| 635 |
+
font-family: "JetBrains Mono", ui-monospace, SFMono-Regular, Menlo, Monaco, monospace;
|
| 636 |
}
|
| 637 |
|
| 638 |
.logs div {
|
| 639 |
+
border-radius: var(--radius-sm);
|
| 640 |
border: 1px solid var(--border);
|
| 641 |
+
background: #0d1117;
|
| 642 |
+
padding: 12px 16px;
|
| 643 |
+
color: var(--text);
|
| 644 |
+
overflow: hidden;
|
| 645 |
+
text-overflow: ellipsis;
|
| 646 |
+
word-break: break-word;
|
| 647 |
+
line-height: 1.5;
|
| 648 |
+
}
|
| 649 |
+
|
| 650 |
+
.log-empty {
|
| 651 |
+
color: var(--muted);
|
| 652 |
+
font-family: inherit;
|
| 653 |
+
font-style: italic;
|
| 654 |
}
|
| 655 |
|
| 656 |
+
/* ── Helper Text ─────────────────────────────────────────────── */
|
| 657 |
+
|
| 658 |
.muted {
|
| 659 |
margin: 0;
|
| 660 |
color: var(--muted);
|
| 661 |
+
font-size: 0.88rem;
|
| 662 |
}
|
| 663 |
|
| 664 |
.budget-note {
|
| 665 |
+
margin-top: 14px;
|
| 666 |
border: 1px solid var(--border);
|
| 667 |
+
border-radius: var(--radius-sm);
|
| 668 |
+
padding: 14px 18px;
|
| 669 |
+
background: var(--panel-solid);
|
| 670 |
+
font-size: 0.88rem;
|
| 671 |
+
color: var(--muted);
|
| 672 |
+
}
|
| 673 |
+
|
| 674 |
+
.done-note {
|
| 675 |
+
border-color: rgba(63, 185, 80, 0.3);
|
| 676 |
+
color: var(--success);
|
| 677 |
+
}
|
| 678 |
+
|
| 679 |
+
/* ── Footer ──────────────────────────────────────────────────── */
|
| 680 |
+
|
| 681 |
+
.app-footer {
|
| 682 |
+
margin-top: 28px;
|
| 683 |
+
text-align: center;
|
| 684 |
+
}
|
| 685 |
+
|
| 686 |
+
.app-footer p {
|
| 687 |
+
font-size: 0.8rem;
|
| 688 |
+
color: var(--muted);
|
| 689 |
+
opacity: 0.6;
|
| 690 |
+
}
|
| 691 |
+
|
| 692 |
+
.app-footer a {
|
| 693 |
+
color: var(--primary);
|
| 694 |
+
text-decoration: none;
|
| 695 |
+
}
|
| 696 |
+
|
| 697 |
+
.app-footer a:hover {
|
| 698 |
+
text-decoration: underline;
|
| 699 |
}
|
| 700 |
|
| 701 |
+
/* ── Spotlight Guide ─────────────────────────────────────────── */
|
| 702 |
+
|
| 703 |
+
.spotlight-overlay {
|
| 704 |
+
position: fixed;
|
| 705 |
+
inset: 0;
|
| 706 |
+
z-index: 100;
|
| 707 |
+
pointer-events: none;
|
| 708 |
+
}
|
| 709 |
+
|
| 710 |
+
.spotlight-svg {
|
| 711 |
+
position: fixed;
|
| 712 |
+
inset: 0;
|
| 713 |
+
z-index: 100;
|
| 714 |
+
pointer-events: auto;
|
| 715 |
+
}
|
| 716 |
+
|
| 717 |
+
.spotlight-ring {
|
| 718 |
+
position: fixed;
|
| 719 |
+
z-index: 101;
|
| 720 |
+
border: 2px solid var(--primary);
|
| 721 |
+
border-radius: var(--radius);
|
| 722 |
+
box-shadow: 0 0 0 4px rgba(88, 166, 255, 0.1);
|
| 723 |
+
pointer-events: none;
|
| 724 |
+
transition: top 350ms ease, left 350ms ease, width 350ms ease, height 350ms ease;
|
| 725 |
+
}
|
| 726 |
+
|
| 727 |
+
.spotlight-tooltip {
|
| 728 |
+
position: fixed;
|
| 729 |
+
z-index: 102;
|
| 730 |
+
border-radius: var(--radius);
|
| 731 |
+
padding: 24px;
|
| 732 |
+
pointer-events: auto;
|
| 733 |
+
animation: tooltipIn 250ms ease;
|
| 734 |
+
}
|
| 735 |
+
|
| 736 |
+
@keyframes tooltipIn {
|
| 737 |
+
from {
|
| 738 |
+
opacity: 0;
|
| 739 |
+
transform: translateY(8px);
|
| 740 |
+
}
|
| 741 |
+
to {
|
| 742 |
+
opacity: 1;
|
| 743 |
+
transform: translateY(0);
|
| 744 |
+
}
|
| 745 |
+
}
|
| 746 |
+
|
| 747 |
+
.spotlight-tooltip-header {
|
| 748 |
+
display: flex;
|
| 749 |
+
align-items: flex-start;
|
| 750 |
+
justify-content: space-between;
|
| 751 |
+
gap: 12px;
|
| 752 |
+
margin-bottom: 12px;
|
| 753 |
+
}
|
| 754 |
+
|
| 755 |
+
.spotlight-tooltip-header h3 {
|
| 756 |
+
margin: 0;
|
| 757 |
+
font-size: 1.05rem;
|
| 758 |
+
font-weight: 600;
|
| 759 |
+
color: var(--text);
|
| 760 |
+
text-transform: none;
|
| 761 |
+
letter-spacing: 0;
|
| 762 |
+
line-height: 1.3;
|
| 763 |
+
}
|
| 764 |
+
|
| 765 |
+
.guide-counter {
|
| 766 |
+
font-size: 0.72rem;
|
| 767 |
+
color: var(--muted);
|
| 768 |
+
padding: 4px 10px;
|
| 769 |
+
border-radius: 999px;
|
| 770 |
+
background: rgba(48, 54, 61, 0.4);
|
| 771 |
+
border: 1px solid var(--border);
|
| 772 |
+
white-space: nowrap;
|
| 773 |
+
flex-shrink: 0;
|
| 774 |
+
}
|
| 775 |
+
|
| 776 |
+
.spotlight-tooltip-body {
|
| 777 |
+
margin-bottom: 18px;
|
| 778 |
+
font-size: 0.88rem;
|
| 779 |
+
line-height: 1.65;
|
| 780 |
+
color: var(--muted);
|
| 781 |
+
}
|
| 782 |
+
|
| 783 |
+
.spotlight-tooltip-body p {
|
| 784 |
+
margin: 0 0 5px;
|
| 785 |
+
}
|
| 786 |
+
|
| 787 |
+
.spotlight-tooltip-body p:empty {
|
| 788 |
+
height: 4px;
|
| 789 |
+
}
|
| 790 |
+
|
| 791 |
+
.spotlight-tooltip-footer {
|
| 792 |
+
display: flex;
|
| 793 |
+
gap: 8px;
|
| 794 |
+
justify-content: flex-end;
|
| 795 |
+
}
|
| 796 |
+
|
| 797 |
+
.guide-btn {
|
| 798 |
+
padding: 8px 18px !important;
|
| 799 |
+
font-size: 0.84rem !important;
|
| 800 |
+
border-radius: var(--radius-sm) !important;
|
| 801 |
+
}
|
| 802 |
+
|
| 803 |
+
.guide-dots {
|
| 804 |
+
display: flex;
|
| 805 |
+
justify-content: center;
|
| 806 |
+
gap: 6px;
|
| 807 |
+
margin-top: 14px;
|
| 808 |
+
}
|
| 809 |
+
|
| 810 |
+
.dot {
|
| 811 |
+
width: 7px;
|
| 812 |
+
height: 7px;
|
| 813 |
+
border-radius: 50%;
|
| 814 |
+
background: rgba(48, 54, 61, 0.6);
|
| 815 |
+
transition: background 150ms ease;
|
| 816 |
+
}
|
| 817 |
+
|
| 818 |
+
.dot.active {
|
| 819 |
+
background: var(--primary);
|
| 820 |
+
box-shadow: none;
|
| 821 |
+
}
|
| 822 |
+
|
| 823 |
+
/* ── Responsive ──────────────────────────────────────────────── */
|
| 824 |
+
|
| 825 |
@media (max-width: 1180px) {
|
| 826 |
.layout {
|
| 827 |
grid-template-columns: 1fr;
|
| 828 |
}
|
| 829 |
|
| 830 |
.topbar {
|
| 831 |
+
flex-direction: column;
|
| 832 |
+
align-items: flex-start;
|
| 833 |
}
|
| 834 |
|
| 835 |
+
.topbar-right {
|
| 836 |
+
align-self: flex-start;
|
| 837 |
}
|
| 838 |
|
| 839 |
.actions {
|
| 840 |
+
width: 100%;
|
| 841 |
justify-content: flex-start;
|
| 842 |
}
|
| 843 |
+
|
| 844 |
+
.history-grid {
|
| 845 |
+
grid-template-columns: 1fr;
|
| 846 |
+
}
|
| 847 |
}
|
| 848 |
|
| 849 |
@media (max-width: 760px) {
|
| 850 |
.shell {
|
| 851 |
+
padding: 12px 12px 24px;
|
| 852 |
}
|
| 853 |
|
| 854 |
.topbar,
|
| 855 |
.panel {
|
| 856 |
+
border-radius: var(--radius-sm);
|
| 857 |
+
padding: 16px 18px;
|
| 858 |
}
|
| 859 |
|
| 860 |
.actions {
|
| 861 |
+
flex-direction: column;
|
| 862 |
}
|
| 863 |
|
| 864 |
.actions button,
|
|
|
|
| 866 |
width: 100%;
|
| 867 |
}
|
| 868 |
|
| 869 |
+
.kpi-grid {
|
| 870 |
+
grid-template-columns: 1fr 1fr;
|
| 871 |
+
}
|
| 872 |
+
|
| 873 |
+
.med-grid {
|
| 874 |
+
grid-template-columns: 1fr;
|
| 875 |
+
}
|
| 876 |
+
|
| 877 |
.stack-two {
|
| 878 |
grid-template-columns: 1fr;
|
| 879 |
}
|
| 880 |
|
| 881 |
+
.guide-modal {
|
| 882 |
+
padding: 20px;
|
| 883 |
+
}
|
| 884 |
+
|
| 885 |
+
.spotlight-tooltip {
|
| 886 |
+
left: 10px !important;
|
| 887 |
+
right: 10px !important;
|
| 888 |
+
max-width: calc(100vw - 20px) !important;
|
| 889 |
+
}
|
| 890 |
+
|
| 891 |
+
.guide-footer,
|
| 892 |
+
.spotlight-tooltip-footer {
|
| 893 |
+
flex-direction: column;
|
| 894 |
+
}
|
| 895 |
+
|
| 896 |
.logs {
|
| 897 |
+
max-height: 200px;
|
| 898 |
}
|
| 899 |
}
|
| 900 |
+
|
| 901 |
+
/* ── Scrollbar ───────────────────────────────────────────────── */
|
| 902 |
+
|
| 903 |
+
::-webkit-scrollbar {
|
| 904 |
+
width: 6px;
|
| 905 |
+
}
|
| 906 |
+
|
| 907 |
+
::-webkit-scrollbar-track {
|
| 908 |
+
background: transparent;
|
| 909 |
+
}
|
| 910 |
+
|
| 911 |
+
::-webkit-scrollbar-thumb {
|
| 912 |
+
background: rgba(48, 54, 61, 0.6);
|
| 913 |
+
border-radius: 3px;
|
| 914 |
+
}
|
| 915 |
+
|
| 916 |
+
::-webkit-scrollbar-thumb:hover {
|
| 917 |
+
background: rgba(139, 148, 158, 0.4);
|
| 918 |
+
}
|
train_bandit.py
ADDED
|
@@ -0,0 +1,381 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""OptimNeuralTS training -- Neural Bandit search for dangerous polypharmacies.
|
| 3 |
+
|
| 4 |
+
Implements the training pipeline from:
|
| 5 |
+
Larouche et al., "Neural Bandits for Data Mining: Searching for Dangerous Polypharmacy"
|
| 6 |
+
https://link.springer.com/chapter/10.1007/978-3-031-36938-4_5
|
| 7 |
+
|
| 8 |
+
This script:
|
| 9 |
+
1. Generates a synthetic dataset of drug combinations with simulated Relative Risk (RR)
|
| 10 |
+
2. Runs OptimNeuralTS: warm-up -> NeuralTS+DE exploration -> ensemble building
|
| 11 |
+
3. Evaluates the ensemble's ability to detect Potentially Inappropriate Polypharmacies (PIPs)
|
| 12 |
+
4. Saves the trained ensemble model
|
| 13 |
+
|
| 14 |
+
Usage:
|
| 15 |
+
python train_bandit.py --total-steps 1000 --warmup-steps 200
|
| 16 |
+
python train_bandit.py --total-steps 3000 --warmup-steps 500 --eval-every 100
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
from __future__ import annotations
|
| 20 |
+
|
| 21 |
+
import argparse
|
| 22 |
+
import json
|
| 23 |
+
import math
|
| 24 |
+
import os
|
| 25 |
+
import random
|
| 26 |
+
import sys
|
| 27 |
+
import time
|
| 28 |
+
from itertools import combinations
|
| 29 |
+
from pathlib import Path
|
| 30 |
+
from typing import Any, Dict, List, Tuple
|
| 31 |
+
|
| 32 |
+
import torch
|
| 33 |
+
|
| 34 |
+
_BACKEND_SRC = os.path.join(
|
| 35 |
+
os.path.dirname(os.path.abspath(__file__)), "backend", "src"
|
| 36 |
+
)
|
| 37 |
+
sys.path.insert(0, _BACKEND_SRC)
|
| 38 |
+
|
| 39 |
+
from polypharmacy_env.neural_bandits import NeuralTS, OptimNeuralTS, nearest_neighbor_hamming # noqa: E402
|
| 40 |
+
from polypharmacy_env.data_loader import load_drug_metadata, load_ddi_rules # noqa: E402
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# ---------------------------------------------------------------------------
|
| 44 |
+
# Synthetic RR data generation (follows paper Section 4.1)
|
| 45 |
+
# ---------------------------------------------------------------------------
|
| 46 |
+
|
| 47 |
+
def generate_synthetic_dataset(
    n_drugs: int = 33,
    n_combinations: int = 5000,
    n_dangerous_patterns: int = 10,
    rr_threshold: float = 1.1,
    noise_std: float = 0.1,
    seed: int = 42,
) -> Dict[str, Any]:
    """Generate synthetic drug-combination data with ground-truth RRs.

    Follows the paper's data-generation process:
      * sample ``n_dangerous_patterns`` dangerous multi-hot patterns;
      * for each combination, find the nearest pattern (Hamming distance);
      * assign an RR proportional to the overlap with that pattern, or draw
        from a neutral N(0.85, 0.2) when the combination is disjoint.

    Args:
        n_drugs: Size of the drug vocabulary (must be >= 2).
        n_combinations: Number of *distinct* combinations to generate. Must
            not exceed the number of subsets of size 2..min(8, n_drugs),
            otherwise the uniqueness loop below could never terminate.
        n_dangerous_patterns: Number of planted dangerous patterns.
        rr_threshold: RR above which a combination counts as a PIP.
        noise_std: Observation-noise std, returned for use by the caller.
        seed: Seed for both the Python RNG and torch.

    Returns:
        Dict with ``combos`` (multi-hot tensors), ``rrs``, ``patterns``,
        ``pattern_rrs``, ``n_drugs``, ``n_pips``, ``rr_threshold`` and
        ``noise_std``.

    Raises:
        ValueError: If ``n_drugs < 2`` or ``n_combinations`` exceeds the
            number of distinct combinations available.
    """
    if n_drugs < 2:
        raise ValueError("n_drugs must be >= 2")

    # Guard against an infinite uniqueness loop: cap n_combinations by the
    # number of distinct subsets of size 2..min(8, n_drugs).
    max_size = min(8, n_drugs)
    capacity = sum(math.comb(n_drugs, k) for k in range(2, max_size + 1))
    if n_combinations > capacity:
        raise ValueError(
            f"n_combinations={n_combinations} exceeds the {capacity} distinct "
            f"combinations available for n_drugs={n_drugs}"
        )

    rng = random.Random(seed)
    torch.manual_seed(seed)

    # Dangerous patterns: sparse multi-hot vectors with 2..max(3, n_drugs//8)
    # active drugs (clamped to n_drugs so tiny vocabularies cannot crash
    # rng.sample with a sample size larger than the population).
    patterns = []
    for _ in range(n_dangerous_patterns):
        p = torch.zeros(n_drugs)
        n_active = rng.randint(2, min(n_drugs, max(3, n_drugs // 8)))
        for idx in rng.sample(range(n_drugs), n_active):
            p[idx] = 1.0
        patterns.append(p)

    # Distinct drug combinations of 2..max_size active drugs.
    combos = []
    combo_set = set()
    while len(combos) < n_combinations:
        n_active = rng.randint(2, max_size)
        indices = tuple(sorted(rng.sample(range(n_drugs), n_active)))
        if indices not in combo_set:
            combo_set.add(indices)
            vec = torch.zeros(n_drugs)
            for idx in indices:
                vec[idx] = 1.0
            combos.append(vec)

    # RR per combination, driven by overlap with its nearest pattern.
    rrs = []
    for combo in combos:
        # Nearest pattern by Hamming distance.
        min_dist = float("inf")
        best_p_idx = 0
        for p_idx, pattern in enumerate(patterns):
            dist = (combo != pattern).float().sum().item()
            if dist < min_dist:
                min_dist = dist
                best_p_idx = p_idx

        pattern = patterns[best_p_idx]

        # Shared active drugs with the nearest pattern.
        intersection = (combo * pattern).sum().item()
        if intersection > 0:
            # RR proportional to similarity with the dangerous pattern.
            similarity = intersection / max(pattern.sum().item(), 1)
            base_rr = 0.5 + 2.5 * similarity  # range ~[0.5, 3.0]
            rr = max(0.1, base_rr + rng.gauss(0, 0.15))
        else:
            # Disjoint from every pattern: neutral RR distribution.
            rr = max(0.1, rng.gauss(0.85, 0.2))

        rrs.append(rr)

    # The planted patterns themselves carry a high RR.
    pattern_rrs = [2.0 + rng.gauss(0, 0.3) for _ in patterns]

    n_pips = sum(1 for rr in rrs if rr > rr_threshold)
    print(f"  Generated {n_combinations} combos, {n_pips} PIPs (RR > {rr_threshold})")
    print(f"  RR range: [{min(rrs):.3f}, {max(rrs):.3f}], mean: {sum(rrs)/len(rrs):.3f}")

    return {
        "combos": combos,
        "rrs": rrs,
        "patterns": patterns,
        "pattern_rrs": pattern_rrs,
        "n_drugs": n_drugs,
        "n_pips": n_pips,
        "rr_threshold": rr_threshold,
        "noise_std": noise_std,
    }
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
# ---------------------------------------------------------------------------
|
| 140 |
+
# Training loop
|
| 141 |
+
# ---------------------------------------------------------------------------
|
| 142 |
+
|
| 143 |
+
def train_bandit(args: argparse.Namespace) -> None:
    """Run the full OptimNeuralTS training pipeline on synthetic data.

    Steps:
      1. Generate a synthetic dataset of drug combinations with known RRs.
      2. Run the bandit loop (warm-up, then NeuralTS+DE exploration),
         periodically evaluating PIP-detection precision/recall.
      3. Persist metrics (JSON) and the trained ensemble checkpoint (.pt).

    Fixes vs. the previous revision:
      * per-step observation noise now comes from a dedicated RNG seeded
        from ``args.seed`` (it previously used the unseeded module-level
        ``random.gauss``, making runs irreproducible despite --seed);
      * the periodic evaluation makes a single ``predict_risk`` pass over
        all combinations instead of two.
    """
    print("=" * 72)
    print("OptimNeuralTS Training -- Neural Bandits for Polypharmacy")
    print("=" * 72)

    # Dedicated, seeded RNG for observation noise (offset so it does not
    # mirror the dataset-generation stream).
    noise_rng = random.Random(args.seed + 1)

    # --- Synthetic data ----------------------------------------------------
    print("\nGenerating synthetic dataset...")
    dataset = generate_synthetic_dataset(
        n_drugs=args.n_drugs,
        n_combinations=args.n_combinations,
        n_dangerous_patterns=args.n_patterns,
        seed=args.seed,
    )
    combos = dataset["combos"]
    rrs = dataset["rrs"]
    patterns = dataset["patterns"]
    noise_std = dataset["noise_std"]
    rr_threshold = dataset["rr_threshold"]

    # --- Bandit ------------------------------------------------------------
    bandit = OptimNeuralTS(
        input_dim=args.n_drugs,
        hidden=args.hidden_dim,
        reg_lambda=args.reg_lambda,
        exploration_factor=args.exploration_factor,
        lr=args.lr,
        train_epochs=args.train_epochs,
        warmup_steps=args.warmup_steps,
        total_steps=args.total_steps,
        retrain_every=args.retrain_every,
        de_population=args.de_population,
        de_crossover=args.de_crossover,
        de_weight=args.de_weight,
        de_steps=args.de_steps,
    )

    print(f"\n  n_drugs            : {args.n_drugs}")
    print(f"  n_combinations     : {args.n_combinations}")
    print(f"  total_steps (T)    : {args.total_steps}")
    print(f"  warmup_steps (τ)   : {args.warmup_steps}")
    print(f"  DE population (N)  : {args.de_population}")
    print(f"  DE steps (S)       : {args.de_steps}")
    print(f"  retrain_every      : {args.retrain_every}")
    print(f"  hidden_dim         : {args.hidden_dim}")
    print(f"  lr                 : {args.lr}")
    print("=" * 72)

    t_start = time.time()

    # Metric accumulators.
    step_rewards = []
    pips_found = []
    eval_precisions = []
    eval_recalls = []
    training_dataset_indices = set()

    info: Dict[str, Any] = {}
    for t in range(1, args.total_steps + 1):
        # Thompson-sampling action selection over the candidate pool.
        idx, info = bandit.select_action(combos)
        training_dataset_indices.add(idx)

        # Observe a noisy reward: true RR corrupted by Gaussian noise.
        reward = rrs[idx] + noise_rng.gauss(0, noise_std)
        step_rewards.append(reward)

        # Update the bandit (network retraining happens inside observe()).
        bandit.observe(combos[idx], reward)

        # --- Periodic evaluation -------------------------------------------
        if t % args.eval_every == 0 or t == args.total_steps:
            # Single pass over all combinations: confusion-matrix counts and
            # the number of detected PIPs outside the bandit's training set.
            tp = fp = tn = fn = 0
            total_detected_pips = 0
            pips_outside_train = 0
            for i, combo in enumerate(combos):
                predicted_pip = bandit.predict_risk(combo)["is_potentially_harmful"]
                actual_pip = rrs[i] > rr_threshold
                if predicted_pip:
                    total_detected_pips += 1
                    if i not in training_dataset_indices:
                        pips_outside_train += 1
                    if actual_pip:
                        tp += 1
                    else:
                        fp += 1
                elif actual_pip:
                    fn += 1
                else:
                    tn += 1

            precision = tp / max(tp + fp, 1)
            recall = tp / max(tp + fn, 1)
            eval_precisions.append(precision)
            eval_recalls.append(recall)
            pips_found.append(total_detected_pips)

            # Fraction of the planted dangerous patterns flagged directly.
            patterns_found = sum(
                1
                for pattern in patterns
                if bandit.predict_risk(pattern)["is_potentially_harmful"]
            )
            pattern_ratio = patterns_found / len(patterns)

            elapsed = time.time() - t_start
            phase = info.get("phase", "?")
            n_ens = len(bandit.agent.ensemble_weights)
            print(
                f"[step {t:>5d}/{args.total_steps}] "
                f"phase={phase} "
                f"precision={precision:.3f} "
                f"recall={recall:.3f} "
                f"patterns={pattern_ratio:.2f} "
                f"PIPs_detected={total_detected_pips} "
                f"outside_train={pips_outside_train} "
                f"ensemble={n_ens} "
                f"elapsed={elapsed:.1f}s"
            )

    # --- Persist metrics ---------------------------------------------------
    metrics = {
        "algorithm": "OptimNeuralTS",
        "n_drugs": args.n_drugs,
        "n_combinations": args.n_combinations,
        "total_steps": args.total_steps,
        "warmup_steps": args.warmup_steps,
        "n_ensemble_models": len(bandit.agent.ensemble_weights),
        "final_precision": eval_precisions[-1] if eval_precisions else 0,
        "final_recall": eval_recalls[-1] if eval_recalls else 0,
        "eval_precisions": eval_precisions,
        "eval_recalls": eval_recalls,
        "pips_detected": pips_found,
        "step_rewards": step_rewards,
        "total_time_s": time.time() - t_start,
        "hyperparameters": {
            "hidden_dim": args.hidden_dim,
            "lr": args.lr,
            "reg_lambda": args.reg_lambda,
            "exploration_factor": args.exploration_factor,
            "de_population": args.de_population,
            "de_crossover": args.de_crossover,
            "de_weight": args.de_weight,
            "de_steps": args.de_steps,
            "train_epochs": args.train_epochs,
            "retrain_every": args.retrain_every,
        },
    }

    metrics_path = Path(args.metrics_file)
    metrics_path.parent.mkdir(parents=True, exist_ok=True)
    with open(metrics_path, "w") as f:
        json.dump(metrics, f, indent=2)
    print(f"\nMetrics saved to {metrics_path}")

    # --- Persist the ensemble checkpoint ------------------------------------
    ckpt_dir = Path(args.checkpoint_dir)
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    ckpt_path = ckpt_dir / "bandit_ensemble.pt"
    torch.save(
        {
            "ensemble_weights": bandit.agent.ensemble_weights,
            "network_state_dict": bandit.agent.network.state_dict(),
            "U_diag": bandit.agent.U_diag,
            "input_dim": args.n_drugs,
            "hidden_dim": args.hidden_dim,
            "n_steps": args.total_steps,
        },
        ckpt_path,
    )
    print(f"Ensemble model saved to {ckpt_path}")

    print(f"\n{'='*72}")
    print("Training complete!")
    print(f"  Ensemble size: {len(bandit.agent.ensemble_weights)} models")
    if eval_precisions:
        print(f"  Final precision: {eval_precisions[-1]:.4f}")
        print(f"  Final recall:    {eval_recalls[-1]:.4f}")
    print(f"  Total time: {time.time() - t_start:.1f}s")
    print(f"{'='*72}")
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
# ---------------------------------------------------------------------------
|
| 337 |
+
# CLI
|
| 338 |
+
# ---------------------------------------------------------------------------
|
| 339 |
+
|
| 340 |
+
def parse_args() -> argparse.Namespace:
    """Build and parse the CLI for OptimNeuralTS training.

    Returns:
        argparse.Namespace with dataset, bandit, DE, and output settings.
        Defaults mirror the values used throughout the training script.
    """
    p = argparse.ArgumentParser(
        description="OptimNeuralTS training for polypharmacy PIP detection",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # Dataset generation parameters.
    p.add_argument("--n-drugs", type=int, default=33, help="Number of possible drugs")
    p.add_argument("--n-combinations", type=int, default=5000, help="Number of distinct drug combinations")
    p.add_argument("--n-patterns", type=int, default=10, help="Number of dangerous patterns")
    p.add_argument("--seed", type=int, default=42, help="Random seed")

    # OptimNeuralTS hyperparameters (T, τ, λ, ν follow the paper's notation).
    p.add_argument("--total-steps", type=int, default=1000, help="Total bandit steps T")
    p.add_argument("--warmup-steps", type=int, default=200, help="Warmup steps τ")
    p.add_argument("--retrain-every", type=int, default=10, help="Retrain network every N steps")
    p.add_argument("--hidden-dim", type=int, default=64, help="Network hidden layer size")
    p.add_argument("--lr", type=float, default=0.01, help="Learning rate")
    p.add_argument("--reg-lambda", type=float, default=1.0, help="Regularization λ")
    p.add_argument("--exploration-factor", type=float, default=1.0, help="Exploration ν")
    p.add_argument("--train-epochs", type=int, default=50, help="Epochs per retrain")

    # Differential Evolution (DE) search parameters.
    p.add_argument("--de-population", type=int, default=16, help="DE population size N")
    p.add_argument("--de-crossover", type=float, default=0.9, help="DE crossover rate C")
    p.add_argument("--de-weight", type=float, default=1.0, help="DE differential weight F")
    p.add_argument("--de-steps", type=int, default=8, help="DE optimization steps S")

    # Output locations (metrics JSON and model checkpoint).
    p.add_argument("--eval-every", type=int, default=100, help="Evaluate every N steps")
    p.add_argument("--metrics-file", type=str, default="bandit_metrics.json", help="Metrics output path")
    p.add_argument(
        "--checkpoint-dir", type=str,
        default=os.path.join(_BACKEND_SRC, "polypharmacy_env", "checkpoints"),
        help="Model checkpoint directory",
    )

    return p.parse_args()
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
if __name__ == "__main__":
    # Script entry point: parse CLI flags, then run the full training pipeline.
    args = parse_args()
    train_bandit(args)
|
train_rl.py
ADDED
|
@@ -0,0 +1,674 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""REINFORCE with Learned Baseline -- RL training for PolypharmacyEnv.
|
| 3 |
+
|
| 4 |
+
Trains a small neural-network policy to perform medication reviews in the
|
| 5 |
+
PolypharmacyEnv environment. The policy learns to query drug-drug interactions,
|
| 6 |
+
propose clinical interventions, and decide when to finalise the review.
|
| 7 |
+
|
| 8 |
+
Usage examples:
|
| 9 |
+
python train_rl.py --task easy_screening --episodes 200
|
| 10 |
+
python train_rl.py --task budgeted_screening --episodes 500
|
| 11 |
+
python train_rl.py --task complex_tradeoff --episodes 1000
|
| 12 |
+
python train_rl.py --task easy_screening --episodes 300 --lr 5e-4 --batch-size 8
|
| 13 |
+
|
| 14 |
+
Architecture:
|
| 15 |
+
- Fixed-size state encoding (16-dim global summary features)
|
| 16 |
+
- Fixed 166-dim action space with dynamic validity masking
|
| 17 |
+
- 3-layer MLP policy (state -> logits over actions)
|
| 18 |
+
- 3-layer MLP value baseline (state -> scalar return estimate)
|
| 19 |
+
- REINFORCE gradient with advantage = (discounted return) - baseline
|
| 20 |
+
- Entropy bonus for sustained exploration
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
from __future__ import annotations
|
| 24 |
+
|
| 25 |
+
import argparse
|
| 26 |
+
import json
|
| 27 |
+
import os
|
| 28 |
+
import sys
|
| 29 |
+
import time
|
| 30 |
+
from itertools import combinations
|
| 31 |
+
from pathlib import Path
|
| 32 |
+
from typing import Any, Dict, List, Optional, Set, Tuple
|
| 33 |
+
|
| 34 |
+
import torch
|
| 35 |
+
import torch.nn as nn
|
| 36 |
+
import torch.nn.functional as F
|
| 37 |
+
from torch.distributions import Categorical
|
| 38 |
+
|
| 39 |
+
# ---------------------------------------------------------------------------
|
| 40 |
+
# Environment imports (direct, no HTTP)
|
| 41 |
+
# ---------------------------------------------------------------------------
|
| 42 |
+
_BACKEND_SRC = os.path.join(
|
| 43 |
+
os.path.dirname(os.path.abspath(__file__)), "backend", "src"
|
| 44 |
+
)
|
| 45 |
+
sys.path.insert(0, _BACKEND_SRC)
|
| 46 |
+
|
| 47 |
+
from polypharmacy_env.env_core import PolypharmacyEnv # noqa: E402
|
| 48 |
+
from polypharmacy_env.models import ( # noqa: E402
|
| 49 |
+
PolypharmacyAction,
|
| 50 |
+
PolypharmacyObservation,
|
| 51 |
+
)
|
| 52 |
+
from polypharmacy_env.config import TASK_CONFIGS, TaskConfig # noqa: E402
|
| 53 |
+
|
| 54 |
+
# ---------------------------------------------------------------------------
|
| 55 |
+
# Constants -- action-space geometry
|
| 56 |
+
# ---------------------------------------------------------------------------
|
| 57 |
+
MAX_MEDS = 15 # upper bound across all task difficulties
|
| 58 |
+
INTERVENTION_TYPES: List[str] = [
|
| 59 |
+
"stop",
|
| 60 |
+
"dose_reduce",
|
| 61 |
+
"substitute",
|
| 62 |
+
"add_monitoring",
|
| 63 |
+
]
|
| 64 |
+
N_INTERVENTION_TYPES = len(INTERVENTION_TYPES)
|
| 65 |
+
|
| 66 |
+
# Pre-compute the mapping (med_position_i, med_position_j) -> flat action index
|
| 67 |
+
# for all possible query_ddi pairs where i < j.
|
| 68 |
+
_PAIR_INDEX: Dict[Tuple[int, int], int] = {}
|
| 69 |
+
_idx = 0
|
| 70 |
+
for _i in range(MAX_MEDS):
|
| 71 |
+
for _j in range(_i + 1, MAX_MEDS):
|
| 72 |
+
_PAIR_INDEX[(_i, _j)] = _idx
|
| 73 |
+
_idx += 1
|
| 74 |
+
N_PAIRS = _idx # C(15,2) = 105
|
| 75 |
+
_REVERSE_PAIR: Dict[int, Tuple[int, int]] = {v: k for k, v in _PAIR_INDEX.items()}
|
| 76 |
+
|
| 77 |
+
N_INTERVENTIONS = MAX_MEDS * N_INTERVENTION_TYPES # 60
|
| 78 |
+
FINISH_IDX = N_PAIRS + N_INTERVENTIONS # 165
|
| 79 |
+
N_ACTIONS = FINISH_IDX + 1 # 166
|
| 80 |
+
|
| 81 |
+
# State feature vector length (see encode_state)
|
| 82 |
+
STATE_DIM = 16
|
| 83 |
+
|
| 84 |
+
# ---------------------------------------------------------------------------
|
| 85 |
+
# State encoding
|
| 86 |
+
# ---------------------------------------------------------------------------
|
| 87 |
+
|
| 88 |
+
def encode_state(obs: PolypharmacyObservation, task_cfg: TaskConfig) -> torch.Tensor:
    """Summarise an observation as a fixed 16-dim float32 tensor.

    Every feature is scaled to roughly [0, 1] so the policy and value MLPs
    receive well-conditioned inputs.
    """
    meds = obs.current_medications
    med_count = len(meds)
    med_denom = max(med_count, 1)

    high_risk = sum(bool(m.is_high_risk_elderly) for m in meds)
    flagged_any = sum(bool(m.beers_flags) for m in meds)
    flagged_avoid = sum(
        any("avoid" in flag for flag in m.beers_flags) for m in meds
    )

    logs = obs.interaction_queries
    query_count = len(logs)
    query_denom = max(query_count, 1)
    severe_count = sum(q.severity == "severe" for q in logs)
    moderate_count = sum(q.severity == "moderate" for q in logs)
    intervention_count = len(obs.interventions)

    pair_denom = max(med_count * (med_count - 1) // 2, 1)

    # Current drugs implicated in at least one discovered severe DDI.
    regimen_ids = {m.drug_id for m in meds}
    hot_ids = {
        drug
        for q in logs
        if q.severity == "severe"
        for drug in (q.drug_id_1, q.drug_id_2)
        if drug in regimen_ids
    }

    features = [
        med_count / MAX_MEDS,
        high_risk / med_denom,
        flagged_any / med_denom,
        flagged_avoid / med_denom,
        obs.remaining_query_budget / max(task_cfg.query_budget, 1),
        obs.remaining_intervention_budget / max(task_cfg.intervention_budget, 1),
        query_count / max(task_cfg.query_budget, 1),
        severe_count / query_denom,
        moderate_count / query_denom,
        intervention_count / max(task_cfg.intervention_budget, 1),
        obs.step_index / max(task_cfg.max_steps, 1),
        query_count / pair_denom,  # fraction of all med pairs queried
        float(obs.remaining_query_budget > 0),
        float(obs.remaining_intervention_budget > 0),
        len(hot_ids) / med_denom,  # share of the regimen that is "hot"
        float(med_count <= 2),  # very few meds left -- may be time to finish
    ]
    return torch.tensor(features, dtype=torch.float32)
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
# ---------------------------------------------------------------------------
|
| 142 |
+
# Action-space helpers
|
| 143 |
+
# ---------------------------------------------------------------------------
|
| 144 |
+
|
| 145 |
+
def get_action_mask(obs: PolypharmacyObservation) -> torch.Tensor:
    """Return a bool tensor of shape (N_ACTIONS,) where True = legal action."""
    mask = torch.zeros(N_ACTIONS, dtype=torch.bool)
    meds = obs.current_medications
    n = min(len(meds), MAX_MEDS)

    # DDI lookups already performed, keyed order-invariantly by drug-id pair.
    seen = {frozenset((q.drug_id_1, q.drug_id_2)) for q in obs.interaction_queries}

    # query_ddi: any not-yet-queried pair, while query budget remains.
    # (range(n - 1) is empty for n < 2, so no explicit med-count guard needed.)
    if obs.remaining_query_budget > 0:
        for i in range(n - 1):
            for j in range(i + 1, n):
                if frozenset((meds[i].drug_id, meds[j].drug_id)) not in seen:
                    mask[_PAIR_INDEX[(i, j)]] = True

    # propose_intervention: every (medication, type) slot, while budget remains.
    if obs.remaining_intervention_budget > 0:
        for i in range(n):
            base = N_PAIRS + i * N_INTERVENTION_TYPES
            for k in range(N_INTERVENTION_TYPES):
                mask[base + k] = True

    # finish_review is always available.
    mask[FINISH_IDX] = True
    return mask
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def action_idx_to_env_action(
    idx: int,
    meds: list,
) -> PolypharmacyAction:
    """Translate a flat policy action index into a concrete PolypharmacyAction."""
    if idx == FINISH_IDX:
        return PolypharmacyAction(action_type="finish_review")

    if idx < N_PAIRS:
        # DDI query: recover the (i, j) medication positions for this index.
        first, second = _REVERSE_PAIR[idx]
        return PolypharmacyAction(
            action_type="query_ddi",
            drug_id_1=meds[first].drug_id,
            drug_id_2=meds[second].drug_id,
        )

    # Intervention: decode (medication position, intervention type).
    offset = idx - N_PAIRS
    return PolypharmacyAction(
        action_type="propose_intervention",
        target_drug_id=meds[offset // N_INTERVENTION_TYPES].drug_id,
        intervention_type=INTERVENTION_TYPES[offset % N_INTERVENTION_TYPES],
        rationale="rl_policy",
    )
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
# ---------------------------------------------------------------------------
|
| 204 |
+
# Neural-network modules
|
| 205 |
+
# ---------------------------------------------------------------------------
|
| 206 |
+
|
| 207 |
+
class PolicyNetwork(nn.Module):
    """3-layer MLP that maps state features to masked action logits.

    Invalid actions (``mask == False``) are assigned ``-inf`` logits so the
    returned Categorical distribution never samples them.
    """

    def __init__(
        self,
        state_dim: int = STATE_DIM,
        action_dim: int = N_ACTIONS,
        hidden: int = 128,
    ) -> None:
        super().__init__()
        # NOTE: the fc1/fc2/fc3 attribute names define the checkpoint
        # state_dict keys (see _save_checkpoint/load_checkpoint) — keep stable.
        self.fc1 = nn.Linear(state_dim, hidden)
        self.fc2 = nn.Linear(hidden, hidden)
        self.fc3 = nn.Linear(hidden, action_dim)

    def forward(
        self,
        state: torch.Tensor,
        mask: torch.Tensor,
    ) -> Categorical:
        # Two ReLU layers, then raw action scores.
        features = F.relu(self.fc2(F.relu(self.fc1(state))))
        scores = self.fc3(features)
        # Suppress invalid actions before building the distribution.
        masked_scores = scores.masked_fill(~mask, float("-inf"))
        return Categorical(logits=masked_scores)
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
class ValueNetwork(nn.Module):
    """3-layer MLP baseline that estimates the expected return from a state."""

    def __init__(self, state_dim: int = STATE_DIM, hidden: int = 128) -> None:
        super().__init__()
        # fc* attribute names are the checkpoint state_dict keys — keep stable.
        self.fc1 = nn.Linear(state_dim, hidden)
        self.fc2 = nn.Linear(hidden, hidden // 2)
        self.fc3 = nn.Linear(hidden // 2, 1)

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        # Squeeze the trailing singleton so callers get shape (...,), not (..., 1).
        features = F.relu(self.fc2(F.relu(self.fc1(state))))
        return self.fc3(features).squeeze(-1)
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
# ---------------------------------------------------------------------------
|
| 249 |
+
# Episode rollout
|
| 250 |
+
# ---------------------------------------------------------------------------
|
| 251 |
+
|
| 252 |
+
def run_episode(
    env: PolypharmacyEnv,
    task_id: str,
    policy: PolicyNetwork,
    value_net: ValueNetwork,
    task_cfg: TaskConfig,
    seed: Optional[int] = None,
    greedy: bool = False,
) -> Dict[str, Any]:
    """Roll out one full episode, collecting the REINFORCE trajectory.

    When *greedy* is True the policy acts deterministically (argmax) and
    gradients are not recorded. Used for evaluation.

    Returns a dict with per-step tensors ("states", "actions", "log_probs",
    "values", "entropies"), the per-step "rewards", the final "grader_score"
    (taken from obs.metadata at episode end), "total_reward" and "n_steps".
    """
    obs = env.reset(task_id=task_id, seed=seed)

    # Per-timestep trajectory buffers.
    states: List[torch.Tensor] = []
    actions: List[torch.Tensor] = []
    log_probs: List[torch.Tensor] = []
    rewards: List[float] = []
    values: List[torch.Tensor] = []
    entropies: List[torch.Tensor] = []

    grader_score = 0.0

    while not obs.done:
        state = encode_state(obs, task_cfg)
        mask = get_action_mask(obs)

        # Safety: if somehow no action is valid, force finish
        if not mask.any():
            mask[FINISH_IDX] = True

        if greedy:
            # Evaluation mode: deterministic argmax action, no gradient tape.
            with torch.no_grad():
                dist = policy(state, mask)
                action_idx = dist.probs.argmax()
                value = value_net(state)
        else:
            # Training mode: the value estimate is only a (detached) baseline,
            # so it is computed under no_grad; the policy distribution must
            # stay on the gradient tape so log_prob/entropy carry gradients.
            with torch.no_grad():
                value = value_net(state)
            dist = policy(state, mask)
            action_idx = dist.sample()

        log_prob = dist.log_prob(action_idx)
        entropy = dist.entropy()

        states.append(state)
        actions.append(action_idx)
        log_probs.append(log_prob)
        values.append(value)
        entropies.append(entropy)

        # Translate the flat index into a concrete env action and step.
        env_action = action_idx_to_env_action(
            action_idx.item(), obs.current_medications
        )
        obs = env.step(env_action)

        # Missing reward (None) is treated as zero.
        reward = float(obs.reward) if obs.reward is not None else 0.0
        rewards.append(reward)

        if obs.done:
            # Final grader score is reported in the terminal observation's
            # metadata; default to 0.0 when absent.
            grader_score = obs.metadata.get("grader_score", 0.0)

    return {
        "states": states,
        "actions": actions,
        "log_probs": log_probs,
        "rewards": rewards,
        "values": values,
        "entropies": entropies,
        "grader_score": grader_score,
        "total_reward": sum(rewards),
        "n_steps": len(rewards),
    }
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
# ---------------------------------------------------------------------------
|
| 330 |
+
# Return computation
|
| 331 |
+
# ---------------------------------------------------------------------------
|
| 332 |
+
|
| 333 |
+
def compute_returns(rewards: List[float], gamma: float = 0.99) -> torch.Tensor:
    """Discounted cumulative returns (G_t) for each timestep."""
    n = len(rewards)
    out = [0.0] * n
    running = 0.0
    # Walk backwards so each slot accumulates reward-to-go.
    for t in range(n - 1, -1, -1):
        running = rewards[t] + gamma * running
        out[t] = running
    return torch.tensor(out, dtype=torch.float32)
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
# ---------------------------------------------------------------------------
|
| 344 |
+
# Training
|
| 345 |
+
# ---------------------------------------------------------------------------
|
| 346 |
+
|
| 347 |
+
def train(args: argparse.Namespace) -> None:  # noqa: C901 (complex but linear)
    """Run the full REINFORCE training loop for one task.

    Reads hyperparameters from *args*, trains the policy and value networks
    on batched trajectories, checkpoints the best rolling grader score,
    writes training metrics to JSON, and finishes with a stochastic + greedy
    evaluation pass whose results are merged into the same metrics file.
    """
    task_id: str = args.task
    n_episodes: int = args.episodes
    lr: float = args.lr
    gamma: float = args.gamma
    entropy_coeff: float = args.entropy_coeff
    batch_size: int = args.batch_size
    hidden_dim: int = args.hidden_dim
    print_every: int = args.print_every

    task_cfg = TASK_CONFIGS[task_id]

    # ---- Initialise env & networks ----------------------------------------
    env = PolypharmacyEnv()
    policy = PolicyNetwork(STATE_DIM, N_ACTIONS, hidden=hidden_dim)
    value_net = ValueNetwork(STATE_DIM, hidden=hidden_dim)

    policy_optim = torch.optim.Adam(policy.parameters(), lr=lr)
    # Value net learns with a 3x higher learning rate than the policy.
    value_optim = torch.optim.Adam(value_net.parameters(), lr=lr * 3)

    # ---- Book-keeping -----------------------------------------------------
    ckpt_dir = Path(args.checkpoint_dir)
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    metrics_path = Path(args.metrics_file)

    episode_rewards: List[float] = []
    episode_grader_scores: List[float] = []
    episode_steps: List[int] = []
    episode_policy_losses: List[float] = []
    episode_value_losses: List[float] = []

    best_avg_score: float = -float("inf")

    print("=" * 72)
    print("REINFORCE Training -- PolypharmacyEnv")
    print("=" * 72)
    print(f" task : {task_id}")
    print(f" episodes : {n_episodes}")
    print(f" batch_size : {batch_size}")
    print(f" lr : {lr}")
    print(f" gamma : {gamma}")
    print(f" entropy_coeff : {entropy_coeff}")
    print(f" hidden_dim : {hidden_dim}")
    print(f" state_dim : {STATE_DIM}")
    print(f" action_space : {N_ACTIONS}")
    print(f" task budgets : query={task_cfg.query_budget} "
          f"intervention={task_cfg.intervention_budget} "
          f"max_steps={task_cfg.max_steps}")
    print(f" checkpoint_dir : {ckpt_dir}")
    print(f" metrics_file : {metrics_path}")
    print("=" * 72)
    print()

    t_start = time.time()

    # ---- Main training loop -----------------------------------------------
    # Accumulate a mini-batch of trajectories, then perform one gradient step.
    batch_trajs: List[Dict[str, Any]] = []

    for ep in range(1, n_episodes + 1):
        # Seed each episode by its index for reproducibility.
        traj = run_episode(env, task_id, policy, value_net, task_cfg, seed=ep)

        episode_rewards.append(traj["total_reward"])
        episode_grader_scores.append(traj["grader_score"])
        episode_steps.append(traj["n_steps"])

        if traj["n_steps"] == 0:
            # Degenerate episode (should not happen); skip update
            continue

        batch_trajs.append(traj)

        # ---- Gradient step every batch_size episodes ----------------------
        if len(batch_trajs) >= batch_size:
            # Aggregate losses across the batch
            total_policy_loss = torch.tensor(0.0)
            total_value_loss = torch.tensor(0.0)
            total_entropy = torch.tensor(0.0)
            total_steps = 0

            for bt in batch_trajs:
                returns = compute_returns(bt["rewards"], gamma)
                old_values_t = torch.stack(bt["values"])  # detached, from rollout
                log_probs_t = torch.stack(bt["log_probs"])
                entropies_t = torch.stack(bt["entropies"])

                # Advantages use the *detached* rollout values as baseline
                advantages = returns - old_values_t.detach()
                # Per-trajectory advantage normalisation (reduces variance)
                if len(advantages) > 1:
                    advantages = (advantages - advantages.mean()) / (
                        advantages.std() + 1e-8
                    )

                # REINFORCE policy gradient (negative because we minimise)
                total_policy_loss = total_policy_loss + (
                    -(log_probs_t * advantages).sum()
                )

                # Recompute value predictions WITH gradients for the value loss
                states_t = torch.stack(bt["states"])
                fresh_values = value_net(states_t)
                total_value_loss = total_value_loss + F.mse_loss(
                    fresh_values, returns, reduction="sum"
                )

                # Entropy (we want to maximise -> subtract from loss)
                total_entropy = total_entropy + entropies_t.sum()
                total_steps += len(bt["rewards"])

            # Normalise by total number of timesteps in the batch
            total_policy_loss = total_policy_loss / total_steps
            total_value_loss = total_value_loss / total_steps
            total_entropy = total_entropy / total_steps

            # Combined policy loss with entropy bonus
            combined_policy_loss = total_policy_loss - entropy_coeff * total_entropy

            # Separate optimiser steps for policy and value networks,
            # each with gradient-norm clipping at 1.0.
            policy_optim.zero_grad()
            combined_policy_loss.backward()
            nn.utils.clip_grad_norm_(policy.parameters(), max_norm=1.0)
            policy_optim.step()

            value_optim.zero_grad()
            total_value_loss.backward()
            nn.utils.clip_grad_norm_(value_net.parameters(), max_norm=1.0)
            value_optim.step()

            episode_policy_losses.append(total_policy_loss.item())
            episode_value_losses.append(total_value_loss.item())

            batch_trajs = []

        # ---- Logging ------------------------------------------------------
        if ep % print_every == 0 or ep == 1:
            window = min(print_every, ep)
            recent_r = episode_rewards[-window:]
            recent_s = episode_grader_scores[-window:]
            recent_st = episode_steps[-window:]
            avg_r = sum(recent_r) / len(recent_r)
            avg_s = sum(recent_s) / len(recent_s)
            avg_st = sum(recent_st) / len(recent_st)
            elapsed = time.time() - t_start
            print(
                f"[ep {ep:>4d}/{n_episodes}] "
                f"avg_reward={avg_r:+.4f} "
                f"avg_grader={avg_s:.4f} "
                f"avg_steps={avg_st:.1f} "
                f"elapsed={elapsed:.1f}s"
            )

        # Save best checkpoint based on rolling grader score
        eval_window = min(30, ep)
        rolling_score = sum(episode_grader_scores[-eval_window:]) / eval_window
        if rolling_score > best_avg_score:
            best_avg_score = rolling_score
            _save_checkpoint(
                policy, value_net, policy_optim, value_optim,
                ep, best_avg_score, task_id,
                ckpt_dir / f"best_{task_id}.pt",
            )

    # ---- Final checkpoint -------------------------------------------------
    _save_checkpoint(
        policy, value_net, policy_optim, value_optim,
        n_episodes, best_avg_score, task_id,
        ckpt_dir / f"final_{task_id}.pt",
    )

    # ---- Save training metrics to JSON ------------------------------------
    metrics = {
        "task_id": task_id,
        "n_episodes": n_episodes,
        "hyperparameters": {
            "lr": lr,
            "gamma": gamma,
            "entropy_coeff": entropy_coeff,
            "batch_size": batch_size,
            "hidden_dim": hidden_dim,
            "state_dim": STATE_DIM,
            "action_dim": N_ACTIONS,
        },
        "episode_rewards": episode_rewards,
        "episode_grader_scores": episode_grader_scores,
        "episode_steps": episode_steps,
        "policy_losses": episode_policy_losses,
        "value_losses": episode_value_losses,
        "best_avg_grader_score": best_avg_score,
        "total_training_time_s": time.time() - t_start,
    }
    metrics_path.parent.mkdir(parents=True, exist_ok=True)
    # First write: training-only metrics (re-written below with eval results).
    with open(metrics_path, "w") as f:
        json.dump(metrics, f, indent=2)
    print(f"\nTraining metrics saved to {metrics_path}")

    # ---- Post-training evaluation -----------------------------------------
    n_eval = 20
    print("\n" + "=" * 72)
    print(f"Post-training evaluation ({n_eval} episodes each mode)")
    print("=" * 72)

    for mode, is_greedy in [("stochastic", False), ("greedy", True)]:
        eval_rewards, eval_scores, eval_steps_list = [], [], []
        for i in range(n_eval):
            # Eval seeds are offset by 10_000 so they never overlap training.
            traj = run_episode(
                env, task_id, policy, value_net, task_cfg,
                seed=10_000 + i, greedy=is_greedy,
            )
            eval_rewards.append(traj["total_reward"])
            eval_scores.append(traj["grader_score"])
            eval_steps_list.append(traj["n_steps"])
        avg_r = sum(eval_rewards) / len(eval_rewards)
        avg_s = sum(eval_scores) / len(eval_scores)
        avg_st = sum(eval_steps_list) / len(eval_steps_list)
        print(
            f" [{mode:>10s}] avg_reward={avg_r:+.4f} "
            f"avg_grader={avg_s:.4f} avg_steps={avg_st:.1f}"
        )
        metrics[f"eval_{mode}_avg_reward"] = avg_r
        metrics[f"eval_{mode}_avg_grader_score"] = avg_s
        metrics[f"eval_{mode}_avg_steps"] = avg_st
        metrics[f"eval_{mode}_rewards"] = eval_rewards
        metrics[f"eval_{mode}_grader_scores"] = eval_scores

    print(f" best training rolling-avg grader: {best_avg_score:.4f}")
    print()

    # Second write: same file, now including the eval_* keys.
    with open(metrics_path, "w") as f:
        json.dump(metrics, f, indent=2)

    print("Done.")
|
| 578 |
+
|
| 579 |
+
|
| 580 |
+
# ---------------------------------------------------------------------------
|
| 581 |
+
# Checkpoint I/O
|
| 582 |
+
# ---------------------------------------------------------------------------
|
| 583 |
+
|
| 584 |
+
def _save_checkpoint(
    policy: PolicyNetwork,
    value_net: ValueNetwork,
    policy_optim: torch.optim.Optimizer,
    value_optim: torch.optim.Optimizer,
    episode: int,
    best_score: float,
    task_id: str,
    path: Path,
) -> None:
    """Serialise networks, optimisers, and training progress to *path*."""
    payload = {
        "episode": episode,
        "best_avg_grader_score": best_score,
        "task_id": task_id,
        "policy_state_dict": policy.state_dict(),
        "value_state_dict": value_net.state_dict(),
        "policy_optim_state_dict": policy_optim.state_dict(),
        "value_optim_state_dict": value_optim.state_dict(),
        # Dimensions are stored so load_checkpoint can rebuild the networks.
        "state_dim": STATE_DIM,
        "action_dim": N_ACTIONS,
    }
    torch.save(payload, path)
|
| 608 |
+
|
| 609 |
+
|
| 610 |
+
def load_checkpoint(
    path: Path,
    hidden_dim: int = 128,
) -> Tuple[PolicyNetwork, ValueNetwork]:
    """Load a trained policy + value net from a checkpoint file.

    Checkpoints do not record the hidden width, so it is inferred from the
    saved ``fc1.weight`` shape of each network; this lets checkpoints trained
    with a non-default ``--hidden-dim`` load without a size-mismatch error.
    *hidden_dim* is only used as a fallback when inference fails.

    Returns a ``(policy, value_net)`` tuple with weights loaded.
    """
    ckpt = torch.load(path, map_location="cpu")
    state_dim = ckpt.get("state_dim", STATE_DIM)
    action_dim = ckpt.get("action_dim", N_ACTIONS)

    def _hidden_width(state_dict: dict) -> int:
        # fc1.weight has shape (hidden, state_dim); dim 0 is the layer width.
        try:
            return int(state_dict["fc1.weight"].shape[0])
        except (KeyError, IndexError, TypeError):
            return hidden_dim

    policy = PolicyNetwork(
        state_dim,
        action_dim,
        hidden=_hidden_width(ckpt["policy_state_dict"]),
    )
    value_net = ValueNetwork(
        state_dim,
        hidden=_hidden_width(ckpt["value_state_dict"]),
    )
    policy.load_state_dict(ckpt["policy_state_dict"])
    value_net.load_state_dict(ckpt["value_state_dict"])
    return policy, value_net
|
| 625 |
+
|
| 626 |
+
|
| 627 |
+
# ---------------------------------------------------------------------------
|
| 628 |
+
# CLI
|
| 629 |
+
# ---------------------------------------------------------------------------
|
| 630 |
+
|
| 631 |
+
def parse_args() -> argparse.Namespace:
    """Build and parse the CLI argument namespace for REINFORCE training."""
    p = argparse.ArgumentParser(
        description="REINFORCE training for PolypharmacyEnv",
        # Show each option's default value in --help output.
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    p.add_argument(
        "--task",
        type=str,
        default="easy_screening",
        choices=list(TASK_CONFIGS.keys()),
        help="Task difficulty to train on",
    )
    p.add_argument("--episodes", type=int, default=200, help="Number of training episodes")
    p.add_argument("--batch-size", type=int, default=5, help="Episodes per gradient update")
    p.add_argument("--lr", type=float, default=3e-4, help="Learning rate for Adam")
    p.add_argument("--gamma", type=float, default=0.99, help="Discount factor")
    p.add_argument(
        "--entropy-coeff", type=float, default=0.02,
        help="Entropy bonus coefficient (higher = more exploration)",
    )
    p.add_argument("--hidden-dim", type=int, default=128, help="Hidden layer width")
    p.add_argument("--print-every", type=int, default=10, help="Print interval (episodes)")
    p.add_argument(
        "--checkpoint-dir",
        type=str,
        default=os.path.join(_BACKEND_SRC, "polypharmacy_env", "checkpoints"),
        help="Directory to save model checkpoints",
    )
    p.add_argument(
        "--metrics-file",
        type=str,
        default="training_metrics.json",
        help="Path for JSON training metrics",
    )
    return p.parse_args()
|
| 666 |
+
|
| 667 |
+
|
| 668 |
+
# ---------------------------------------------------------------------------
|
| 669 |
+
# Entry point
|
| 670 |
+
# ---------------------------------------------------------------------------
|
| 671 |
+
|
| 672 |
+
if __name__ == "__main__":
    # Parse CLI flags and run the training loop.
    train(parse_args())
|
training_metrics.json
ADDED
|
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"task_id": "easy_screening",
|
| 3 |
+
"n_episodes": 30,
|
| 4 |
+
"hyperparameters": {
|
| 5 |
+
"lr": 0.0003,
|
| 6 |
+
"gamma": 0.99,
|
| 7 |
+
"entropy_coeff": 0.02,
|
| 8 |
+
"batch_size": 5,
|
| 9 |
+
"hidden_dim": 128,
|
| 10 |
+
"state_dim": 16,
|
| 11 |
+
"action_dim": 166
|
| 12 |
+
},
|
| 13 |
+
"episode_rewards": [
|
| 14 |
+
0.47,
|
| 15 |
+
1.1073118279569893,
|
| 16 |
+
1.1486231884057971,
|
| 17 |
+
1.1336231884057972,
|
| 18 |
+
0.405,
|
| 19 |
+
0.395,
|
| 20 |
+
1.1296806853582555,
|
| 21 |
+
0.38,
|
| 22 |
+
0.7283823529411766,
|
| 23 |
+
0.0,
|
| 24 |
+
0.39,
|
| 25 |
+
-0.095,
|
| 26 |
+
1.137962962962963,
|
| 27 |
+
1.1951785714285714,
|
| 28 |
+
0.9053636363636364,
|
| 29 |
+
-0.01754088050314473,
|
| 30 |
+
-0.04,
|
| 31 |
+
-0.06,
|
| 32 |
+
0.0,
|
| 33 |
+
0.22666666666666668,
|
| 34 |
+
0.435,
|
| 35 |
+
0.45,
|
| 36 |
+
0.45,
|
| 37 |
+
0.37666666666666665,
|
| 38 |
+
0.435,
|
| 39 |
+
0.455,
|
| 40 |
+
0.5412162162162163,
|
| 41 |
+
0.33899999999999997,
|
| 42 |
+
0.3416666666666666,
|
| 43 |
+
0.42
|
| 44 |
+
],
|
| 45 |
+
"episode_grader_scores": [
|
| 46 |
+
0.5,
|
| 47 |
+
0.8306451612903226,
|
| 48 |
+
0.8369565217391305,
|
| 49 |
+
0.8369565217391305,
|
| 50 |
+
0.5,
|
| 51 |
+
0.5,
|
| 52 |
+
0.9688473520249221,
|
| 53 |
+
0.5,
|
| 54 |
+
0.7058823529411765,
|
| 55 |
+
0.0,
|
| 56 |
+
0.5,
|
| 57 |
+
0.0,
|
| 58 |
+
0.837962962962963,
|
| 59 |
+
0.9776785714285714,
|
| 60 |
+
0.8863636363636364,
|
| 61 |
+
0.053459119496855306,
|
| 62 |
+
0.0,
|
| 63 |
+
0.0,
|
| 64 |
+
0.0,
|
| 65 |
+
0.5,
|
| 66 |
+
0.5,
|
| 67 |
+
0.5,
|
| 68 |
+
0.5,
|
| 69 |
+
0.5,
|
| 70 |
+
0.5,
|
| 71 |
+
0.5,
|
| 72 |
+
0.5495495495495496,
|
| 73 |
+
0.5,
|
| 74 |
+
0.27499999999999997,
|
| 75 |
+
0.5
|
| 76 |
+
],
|
| 77 |
+
"episode_steps": [
|
| 78 |
+
5,
|
| 79 |
+
4,
|
| 80 |
+
3,
|
| 81 |
+
4,
|
| 82 |
+
6,
|
| 83 |
+
6,
|
| 84 |
+
5,
|
| 85 |
+
7,
|
| 86 |
+
4,
|
| 87 |
+
1,
|
| 88 |
+
7,
|
| 89 |
+
8,
|
| 90 |
+
4,
|
| 91 |
+
3,
|
| 92 |
+
10,
|
| 93 |
+
6,
|
| 94 |
+
3,
|
| 95 |
+
7,
|
| 96 |
+
1,
|
| 97 |
+
4,
|
| 98 |
+
4,
|
| 99 |
+
3,
|
| 100 |
+
3,
|
| 101 |
+
5,
|
| 102 |
+
4,
|
| 103 |
+
6,
|
| 104 |
+
4,
|
| 105 |
+
7,
|
| 106 |
+
2,
|
| 107 |
+
5
|
| 108 |
+
],
|
| 109 |
+
"policy_losses": [
|
| 110 |
+
0.28967130184173584,
|
| 111 |
+
0.05730011314153671,
|
| 112 |
+
-0.06924888491630554,
|
| 113 |
+
-0.28697478771209717,
|
| 114 |
+
-0.1783256083726883,
|
| 115 |
+
-0.12063005566596985
|
| 116 |
+
],
|
| 117 |
+
"value_losses": [
|
| 118 |
+
0.39626142382621765,
|
| 119 |
+
0.24146510660648346,
|
| 120 |
+
0.29013994336128235,
|
| 121 |
+
0.06388193368911743,
|
| 122 |
+
0.02375689707696438,
|
| 123 |
+
0.02241377718746662
|
| 124 |
+
],
|
| 125 |
+
"best_avg_grader_score": 0.6179287909734683,
|
| 126 |
+
"total_training_time_s": 0.13345718383789062,
|
| 127 |
+
"eval_stochastic_avg_reward": 0.5792066304974347,
|
| 128 |
+
"eval_stochastic_avg_grader_score": 0.5784816304974348,
|
| 129 |
+
"eval_stochastic_avg_steps": 5.6,
|
| 130 |
+
"eval_stochastic_rewards": [
|
| 131 |
+
1.0402956989247312,
|
| 132 |
+
0.485,
|
| 133 |
+
-0.11333333333333336,
|
| 134 |
+
0.455,
|
| 135 |
+
0.36250000000000004,
|
| 136 |
+
1.0112089552238808,
|
| 137 |
+
0.41,
|
| 138 |
+
0.21499999999999997,
|
| 139 |
+
1.1173118279569894,
|
| 140 |
+
0.883621495327103,
|
| 141 |
+
0.45,
|
| 142 |
+
1.1073118279569893,
|
| 143 |
+
0.8680882352941177,
|
| 144 |
+
0.405,
|
| 145 |
+
0.675031128404669,
|
| 146 |
+
0.45,
|
| 147 |
+
-0.11,
|
| 148 |
+
0.415,
|
| 149 |
+
0.32,
|
| 150 |
+
1.1370967741935485
|
| 151 |
+
],
|
| 152 |
+
"eval_stochastic_grader_scores": [
|
| 153 |
+
0.8494623655913979,
|
| 154 |
+
0.5,
|
| 155 |
+
0.0,
|
| 156 |
+
0.5,
|
| 157 |
+
0.5,
|
| 158 |
+
0.9207089552238806,
|
| 159 |
+
0.5,
|
| 160 |
+
0.5,
|
| 161 |
+
0.8306451612903226,
|
| 162 |
+
0.8411214953271029,
|
| 163 |
+
0.5,
|
| 164 |
+
0.8306451612903226,
|
| 165 |
+
0.803921568627451,
|
| 166 |
+
0.5,
|
| 167 |
+
0.6060311284046691,
|
| 168 |
+
0.5,
|
| 169 |
+
0.0,
|
| 170 |
+
0.5,
|
| 171 |
+
0.5,
|
| 172 |
+
0.8870967741935485
|
| 173 |
+
],
|
| 174 |
+
"eval_greedy_avg_reward": 0.3627500000000001,
|
| 175 |
+
"eval_greedy_avg_grader_score": 0.425,
|
| 176 |
+
"eval_greedy_avg_steps": 6.45,
|
| 177 |
+
"eval_greedy_rewards": [
|
| 178 |
+
0.44,
|
| 179 |
+
0.455,
|
| 180 |
+
-0.08,
|
| 181 |
+
0.455,
|
| 182 |
+
-0.06,
|
| 183 |
+
0.39,
|
| 184 |
+
0.455,
|
| 185 |
+
0.455,
|
| 186 |
+
0.455,
|
| 187 |
+
-0.06,
|
| 188 |
+
0.42,
|
| 189 |
+
0.455,
|
| 190 |
+
0.41000000000000003,
|
| 191 |
+
0.455,
|
| 192 |
+
0.44,
|
| 193 |
+
0.39,
|
| 194 |
+
0.41000000000000003,
|
| 195 |
+
0.49,
|
| 196 |
+
0.44,
|
| 197 |
+
0.44
|
| 198 |
+
],
|
| 199 |
+
"eval_greedy_grader_scores": [
|
| 200 |
+
0.5,
|
| 201 |
+
0.5,
|
| 202 |
+
0.0,
|
| 203 |
+
0.5,
|
| 204 |
+
0.0,
|
| 205 |
+
0.5,
|
| 206 |
+
0.5,
|
| 207 |
+
0.5,
|
| 208 |
+
0.5,
|
| 209 |
+
0.0,
|
| 210 |
+
0.5,
|
| 211 |
+
0.5,
|
| 212 |
+
0.5,
|
| 213 |
+
0.5,
|
| 214 |
+
0.5,
|
| 215 |
+
0.5,
|
| 216 |
+
0.5,
|
| 217 |
+
0.5,
|
| 218 |
+
0.5,
|
| 219 |
+
0.5
|
| 220 |
+
]
|
| 221 |
+
}
|