Spaces:
Running
Running
Deploy PolyGuard OpenEnv Space
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .env.example +13 -0
- .gitignore +23 -0
- Dockerfile +11 -0
- LICENSE +21 -0
- Makefile +29 -0
- README.md +241 -5
- __init__.py +5 -0
- app/__init__.py +1 -0
- app/agents/__init__.py +5 -0
- app/agents/candidate_agent.py +14 -0
- app/agents/critic_agent.py +43 -0
- app/agents/critic_safety_agent.py +11 -0
- app/agents/dosing_agent.py +52 -0
- app/agents/evidence_agent.py +14 -0
- app/agents/explainer_agent.py +22 -0
- app/agents/graph_agent.py +28 -0
- app/agents/graph_safety_agent.py +11 -0
- app/agents/medrec_agent.py +22 -0
- app/agents/orchestrator.py +151 -0
- app/agents/planner_agent.py +44 -0
- app/agents/supervisor_agent.py +23 -0
- app/api/__init__.py +34 -0
- app/api/__main__.py +7 -0
- app/api/dependencies.py +11 -0
- app/api/main.py +10 -0
- app/api/routes.py +134 -0
- app/api/schemas.py +57 -0
- app/api/service.py +186 -0
- app/common/config.py +39 -0
- app/common/constants.py +40 -0
- app/common/enums.py +61 -0
- app/common/exceptions.py +19 -0
- app/common/json_utils.py +14 -0
- app/common/logging_utils.py +17 -0
- app/common/normalization.py +24 -0
- app/common/seeding.py +17 -0
- app/common/types.py +175 -0
- app/dataops/__init__.py +5 -0
- app/dataops/ddi_api.py +65 -0
- app/dataops/normalizer.py +13 -0
- app/dataops/package_loader.py +19 -0
- app/dataops/parser.py +26 -0
- app/dataops/provenance.py +31 -0
- app/dataops/scraper.py +9 -0
- app/dataops/source_manager.py +111 -0
- app/dataops/synthetic_mix.py +9 -0
- app/dataops/web_agent.py +20 -0
- app/dataops/web_fallback.py +59 -0
- app/env/__init__.py +27 -0
- app/env/actions.py +7 -0
.env.example
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
POLYGUARD_DATA_DIR=./data
|
| 2 |
+
POLYGUARD_LOG_LEVEL=INFO
|
| 3 |
+
POLYGUARD_SEED=42
|
| 4 |
+
POLYGUARD_ENV_HOST=127.0.0.1
|
| 5 |
+
POLYGUARD_ENV_PORT=8100
|
| 6 |
+
POLYGUARD_API_HOST=127.0.0.1
|
| 7 |
+
POLYGUARD_API_PORT=8200
|
| 8 |
+
POLYGUARD_UI_PORT=5173
|
| 9 |
+
POLYGUARD_OLLAMA_MODEL=qwen2.5:3b-instruct
|
| 10 |
+
POLYGUARD_FRONTIER_MODEL=Qwen/Qwen2.5-7B-Instruct
|
| 11 |
+
POLYGUARD_ALLOW_WEB_FETCH=false
|
| 12 |
+
POLYGUARD_REWARD_MIN=0.001
|
| 13 |
+
POLYGUARD_REWARD_MAX=0.999
|
.gitignore
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.DS_Store
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.pyc
|
| 4 |
+
*.pyo
|
| 5 |
+
*.pyd
|
| 6 |
+
.pytest_cache/
|
| 7 |
+
.mypy_cache/
|
| 8 |
+
.ruff_cache/
|
| 9 |
+
.venv/
|
| 10 |
+
.env
|
| 11 |
+
node_modules/
|
| 12 |
+
dist/
|
| 13 |
+
build/
|
| 14 |
+
*.log
|
| 15 |
+
outputs/
|
| 16 |
+
checkpoints/
|
| 17 |
+
artifacts/
|
| 18 |
+
data/cache/*
|
| 19 |
+
data/processed/*
|
| 20 |
+
data/synthetic/*
|
| 21 |
+
data/retrieval_index/*
|
| 22 |
+
!data/**/.gitkeep
|
| 23 |
+
app/ui/frontend/.vite/
|
Dockerfile
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.11-slim
|
| 2 |
+
|
| 3 |
+
WORKDIR /app
|
| 4 |
+
COPY requirements.txt .
|
| 5 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 6 |
+
|
| 7 |
+
COPY . .
|
| 8 |
+
|
| 9 |
+
EXPOSE 8100 8200
|
| 10 |
+
|
| 11 |
+
CMD ["python", "-m", "server.app", "--host", "0.0.0.0", "--port", "8100"]
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2026
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
Makefile
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.PHONY: install test lint env api ui smoke run-all
|
| 2 |
+
|
| 3 |
+
VENV_DIR := .venv
|
| 4 |
+
PYTHON := $(VENV_DIR)/bin/python
|
| 5 |
+
PIP := $(VENV_DIR)/bin/pip
|
| 6 |
+
|
| 7 |
+
$(PYTHON):
|
| 8 |
+
python3 -m venv $(VENV_DIR)
|
| 9 |
+
|
| 10 |
+
install: $(PYTHON)
|
| 11 |
+
bash scripts/bootstrap_venv.sh
|
| 12 |
+
|
| 13 |
+
test: $(PYTHON)
|
| 14 |
+
PYTHONPATH=. $(PYTHON) -m pytest
|
| 15 |
+
|
| 16 |
+
env: $(PYTHON)
|
| 17 |
+
PYTHONPATH=. $(PYTHON) -m app.env.fastapi_app
|
| 18 |
+
|
| 19 |
+
api: $(PYTHON)
|
| 20 |
+
PYTHONPATH=. $(PYTHON) -m app.api
|
| 21 |
+
|
| 22 |
+
ui:
|
| 23 |
+
cd app/ui/frontend && npm install && npm run dev
|
| 24 |
+
|
| 25 |
+
smoke:
|
| 26 |
+
bash scripts/smoke_test_all.sh
|
| 27 |
+
|
| 28 |
+
run-all: $(PYTHON)
|
| 29 |
+
bash scripts/run_all_local.sh --full
|
README.md
CHANGED
|
@@ -1,10 +1,246 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
colorTo: purple
|
| 6 |
sdk: docker
|
|
|
|
| 7 |
pinned: false
|
| 8 |
---
|
| 9 |
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: PolyGuard OpenEnv
|
| 3 |
+
colorFrom: blue
|
| 4 |
+
colorTo: green
|
|
|
|
| 5 |
sdk: docker
|
| 6 |
+
app_port: 8100
|
| 7 |
pinned: false
|
| 8 |
---
|
| 9 |
|
| 10 |
+
# POLYGUARD-OPENENV
|
| 11 |
+
|
| 12 |
+
PolyGuard is an OpenEnv-compatible reinforcement-learning environment for **polypharmacy safety, medication optimization, deprescribing, and precision dosing**. The project turns medication decision making into a stateful environment where an LLM agent observes a patient/regimen state, chooses constrained clinical actions, receives verifier-backed reward, and improves through TRL/GRPO-style post-training.
|
| 13 |
+
|
| 14 |
+
> Clinical safety note: this is a research environment and demo system for RL environment design. It is not a medical device and must not be used for patient care.
|
| 15 |
+
|
| 16 |
+
## Submission Links
|
| 17 |
+
|
| 18 |
+
- GitHub Repo URL: [https://github.com/Vishwa-docs/Meta_Pytorch_OpenEnv_Scaler_VK](https://github.com/Vishwa-docs/Meta_Pytorch_OpenEnv_Scaler_VK)
|
| 19 |
+
- HF Space URL: [https://huggingface.co/spaces/Vishwa-docs/polyguard-openenv](https://huggingface.co/spaces/Vishwa-docs/polyguard-openenv) *(deployment target; verify before final submission)*
|
| 20 |
+
- Colab Notebook URL: [https://colab.research.google.com/github/Vishwa-docs/Meta_Pytorch_OpenEnv_Scaler_VK/blob/master/polyguard-rl/notebooks/09_training_loop.ipynb](https://colab.research.google.com/github/Vishwa-docs/Meta_Pytorch_OpenEnv_Scaler_VK/blob/master/polyguard-rl/notebooks/09_training_loop.ipynb)
|
| 21 |
+
- YouTube Video URL: not used for this submission; the Hugging Face blog URL below is the selected story artifact.
|
| 22 |
+
- Hugging Face Blog URL: [https://huggingface.co/blog/Vishwa-docs/polyguard-openenv](https://huggingface.co/blog/Vishwa-docs/polyguard-openenv) *(story target; publish before final submission)*
|
| 23 |
+
|
| 24 |
+
## Current Readiness
|
| 25 |
+
|
| 26 |
+
Verified locally:
|
| 27 |
+
|
| 28 |
+
- `uv run pytest`: 36 tests passed during the audit pass.
|
| 29 |
+
- `uv run openenv validate .`: local OpenEnv packaging passed.
|
| 30 |
+
- `bash scripts/bootstrap_openenv.sh --runtime-check`: runtime OpenEnv HTTP contract passed when localhost access was allowed.
|
| 31 |
+
- `npm run build` in `app/ui/frontend`: production UI build passed.
|
| 32 |
+
|
| 33 |
+
Still required for final judge-ready submission:
|
| 34 |
+
|
| 35 |
+
- Authenticate Hugging Face with `./.venv/bin/hf auth login`.
|
| 36 |
+
- Deploy and verify the HF Space.
|
| 37 |
+
- Run real TRL/Unsloth SFT and GRPO on GPU/Colab so reports no longer show fallback paths.
|
| 38 |
+
- Replace `docs/results/hf_space_verification.json` with a successful verification payload.
|
| 39 |
+
- Regenerate final plots and reports with `improvement_report.improved == true`.
|
| 40 |
+
- Run strict readiness: `POLYGUARD_ENFORCE_SUBMISSION_LINKS=true ./.venv/bin/python scripts/acceptance_gate.py`.
|
| 41 |
+
|
| 42 |
+
## Problem Statement
|
| 43 |
+
|
| 44 |
+
Polypharmacy decisions are long-horizon, partially observable, and safety-critical. A useful LLM agent must do more than produce a plausible recommendation: it should identify drug-drug interaction risk, reason over comorbidities and labs, choose safe substitutions or deprescribing sequences, request review when uncertain, and expose why it acted.
|
| 45 |
+
|
| 46 |
+
PolyGuard targets the OpenEnv **World Modeling / Professional Tasks** theme, with multi-agent and self-improvement elements. It asks whether environment-backed feedback can make a model better at safe medication action selection than prompt-only or rule-only baselines.
|
| 47 |
+
|
| 48 |
+
## Environment
|
| 49 |
+
|
| 50 |
+
The environment is implemented by `PolyGuardEnv` and exposed through FastAPI/OpenEnv-compatible endpoints:
|
| 51 |
+
|
| 52 |
+
- `POST /reset`
|
| 53 |
+
- `POST /step`
|
| 54 |
+
- `GET /state`
|
| 55 |
+
- `GET /metadata`
|
| 56 |
+
- `GET /schema`
|
| 57 |
+
- `POST /mcp`
|
| 58 |
+
- `GET /health`
|
| 59 |
+
- Backward-compatible aliases under `/env/*` plus `/ws`
|
| 60 |
+
|
| 61 |
+
OpenEnv packaging lives at repo root:
|
| 62 |
+
|
| 63 |
+
- `openenv.yaml`
|
| 64 |
+
- `__init__.py`
|
| 65 |
+
- `client.py`
|
| 66 |
+
- `models.py`
|
| 67 |
+
- `server/app.py`
|
| 68 |
+
|
| 69 |
+
Each episode samples a patient/regimen scenario and a sub-environment:
|
| 70 |
+
|
| 71 |
+
- `DDI`
|
| 72 |
+
- `BANDIT_MINING`
|
| 73 |
+
- `REGIMEN_RISK`
|
| 74 |
+
- `PRECISION_DOSING`
|
| 75 |
+
- `LONGITUDINAL_DEPRESCRIBING`
|
| 76 |
+
- `WEB_SEARCH_MISSING_DATA`
|
| 77 |
+
- `ALTERNATIVE_SUGGESTION`
|
| 78 |
+
- `NEW_DRUG_DECOMPOSITION`
|
| 79 |
+
|
| 80 |
+
Difficulty tracks are available as easy, medium, and hard scenario sets.
|
| 81 |
+
|
| 82 |
+
## Agent Capabilities
|
| 83 |
+
|
| 84 |
+
The agent stack is deliberately decomposed so reward, safety, and explanation can be inspected:
|
| 85 |
+
|
| 86 |
+
- Medication reconciliation
|
| 87 |
+
- Evidence retrieval and missing-data recovery
|
| 88 |
+
- Graph safety analysis for DDI and side effects
|
| 89 |
+
- Dosing guardrails
|
| 90 |
+
- Candidate generation
|
| 91 |
+
- Supervisor routing between regimen, dose, and review modes
|
| 92 |
+
- Planner policy selection
|
| 93 |
+
- Critic safety veto
|
| 94 |
+
- Explanation generation
|
| 95 |
+
- Contextual bandit ranking for policy-stack ablations
|
| 96 |
+
|
| 97 |
+
## Tasks
|
| 98 |
+
|
| 99 |
+
PolyGuard evaluates these action-selection tasks:
|
| 100 |
+
|
| 101 |
+
- Find bad drug combinations and reduce DDI/polypharmacy side-effect risk.
|
| 102 |
+
- Recommend safe adds, substitutions, and alternatives.
|
| 103 |
+
- Optimize regimens under uncertainty.
|
| 104 |
+
- Produce taper/deprescribing sequences over time.
|
| 105 |
+
- Choose precision dosing actions when organ function or dose sensitivity matters.
|
| 106 |
+
- Fetch evidence when critical data is missing.
|
| 107 |
+
- Decompose a new drug into components for first-pass safety reasoning.
|
| 108 |
+
|
| 109 |
+
## Reward Model / Evaluation Logic
|
| 110 |
+
|
| 111 |
+
Rewards are verifier-backed and clamped to `[0.001, 0.999]`. The environment exposes 13 detailed reward columns and 4 primary channels:
|
| 112 |
+
|
| 113 |
+
- `safety_legality`
|
| 114 |
+
- `clinical_improvement`
|
| 115 |
+
- `dosing_quality`
|
| 116 |
+
- `process_integrity`
|
| 117 |
+
|
| 118 |
+
Reward logic combines:
|
| 119 |
+
|
| 120 |
+
- Legal action checks
|
| 121 |
+
- Safety delta and burden improvement
|
| 122 |
+
- Dosing quality
|
| 123 |
+
- Abstention quality under uncertainty
|
| 124 |
+
- Format compliance
|
| 125 |
+
- Process fidelity
|
| 126 |
+
- Explanation grounding
|
| 127 |
+
- Anti-cheat and timeout penalties
|
| 128 |
+
|
| 129 |
+
Anti-hacking checks block repeated action loops, review abuse, keep-regimen abuse, candidate ID mismatches, parser exploit patterns, and unsafe no-op behavior on known holdout DDIs.
|
| 130 |
+
|
| 131 |
+
## Training And Post-Training Strategy
|
| 132 |
+
|
| 133 |
+
The intended pipeline is:
|
| 134 |
+
|
| 135 |
+
1. Build data assets from local knowledge, synthetic patients, scenario rollouts, optional HF instruction data, optional DDI API augmentation, and optional web fallback.
|
| 136 |
+
2. Run SFT with TRL and optional Unsloth/QLoRA acceleration to teach action-selection format.
|
| 137 |
+
3. Run GRPO with environment-backed reward verification.
|
| 138 |
+
4. Track per-component reward columns and sampled generations.
|
| 139 |
+
5. Run policy-stack ablations against baselines.
|
| 140 |
+
6. Merge/export adapters safely.
|
| 141 |
+
7. Validate post-save inference from the exported artifact.
|
| 142 |
+
8. Deploy the OpenEnv environment to Hugging Face Spaces.
|
| 143 |
+
|
| 144 |
+
Core commands:
|
| 145 |
+
|
| 146 |
+
```bash
|
| 147 |
+
cd polyguard-rl
|
| 148 |
+
bash scripts/bootstrap_venv.sh
|
| 149 |
+
.venv/bin/python scripts/bootstrap_data.py
|
| 150 |
+
.venv/bin/python scripts/build_training_corpus.py --profile small --with-local --with-synthetic --with-hf
|
| 151 |
+
.venv/bin/python scripts/train_sft_trl.py --model-id Qwen/Qwen2.5-1.5B-Instruct --epochs 1 --max-steps 20 --use-unsloth
|
| 152 |
+
.venv/bin/python scripts/train_grpo_trl.py --model-id Qwen/Qwen2.5-1.5B-Instruct --max-steps 20 --num-generations 2 --use-unsloth
|
| 153 |
+
.venv/bin/python scripts/merge_adapters_safe.py --adapter-dir checkpoints/sft_adapter --output-dir checkpoints/merged
|
| 154 |
+
.venv/bin/python scripts/test_inference_postsave.py --samples 3
|
| 155 |
+
.venv/bin/python scripts/evaluate_all.py
|
| 156 |
+
```
|
| 157 |
+
|
| 158 |
+
## Results
|
| 159 |
+
|
| 160 |
+
Tracked smoke/evaluation artifacts are mirrored in `docs/results/` because `outputs/` and `checkpoints/` are intentionally ignored.
|
| 161 |
+
|
| 162 |
+

|
| 163 |
+
|
| 164 |
+

|
| 165 |
+
|
| 166 |
+
Current smoke reports show the environment, evaluation, and plotting paths are wired, but final training is not yet judge-ready:
|
| 167 |
+
|
| 168 |
+
- `docs/results/sft_trl_run.json` currently records a fallback backend.
|
| 169 |
+
- `docs/results/grpo_trl_run.json` currently records an environment-reward fallback path.
|
| 170 |
+
- `docs/results/postsave_inference.json` currently uses fallback inference.
|
| 171 |
+
- `docs/results/improvement_report.json` currently records no positive improvement.
|
| 172 |
+
- `docs/results/hf_space_verification.json` is blocked until HF auth/deployment succeeds.
|
| 173 |
+
|
| 174 |
+
Final submission should replace these with real GPU/Colab TRL/Unsloth artifacts.
|
| 175 |
+
|
| 176 |
+
## Dataset Gather
|
| 177 |
+
|
| 178 |
+
Implemented data generation and packaging covers:
|
| 179 |
+
|
| 180 |
+
- Normalized drug vocabulary and class tables
|
| 181 |
+
- Interaction graph edges
|
| 182 |
+
- Burden, taper, renal, hepatic, duplicate-therapy, and substitution rules
|
| 183 |
+
- Synthetic patients
|
| 184 |
+
- Easy/medium/hard scenario files
|
| 185 |
+
- Retrieval corpus and local evidence index
|
| 186 |
+
- Unified SFT and GRPO prompt corpora
|
| 187 |
+
|
| 188 |
+
The current local corpus summary is in `data/processed/training_corpus_summary.json` when generated.
|
| 189 |
+
|
| 190 |
+
## Deployment
|
| 191 |
+
|
| 192 |
+
Use the repository-local HF CLI entrypoint. The global `hf` command on this machine is known to be incompatible with its installed Typer version.
|
| 193 |
+
|
| 194 |
+
```bash
|
| 195 |
+
./.venv/bin/hf auth login
|
| 196 |
+
./.venv/bin/hf auth whoami
|
| 197 |
+
export HF_SPACE_REPO_ID="Vishwa-docs/polyguard-openenv"
|
| 198 |
+
bash scripts/deploy_space.sh --repo-id "$HF_SPACE_REPO_ID"
|
| 199 |
+
./.venv/bin/hf spaces info "$HF_SPACE_REPO_ID"
|
| 200 |
+
openenv validate --url "https://Vishwa-docs-polyguard-openenv.hf.space"
|
| 201 |
+
```
|
| 202 |
+
|
| 203 |
+
After deployment, save the successful Space info plus OpenEnv validation payload into `docs/results/hf_space_verification.json`.
|
| 204 |
+
|
| 205 |
+
## Strict Submission Gate
|
| 206 |
+
|
| 207 |
+
Non-strict local readiness:
|
| 208 |
+
|
| 209 |
+
```bash
|
| 210 |
+
.venv/bin/python scripts/acceptance_gate.py
|
| 211 |
+
```
|
| 212 |
+
|
| 213 |
+
Final submission readiness:
|
| 214 |
+
|
| 215 |
+
```bash
|
| 216 |
+
export POLYGUARD_ENFORCE_SUBMISSION_LINKS=true
|
| 217 |
+
.venv/bin/python scripts/acceptance_gate.py
|
| 218 |
+
```
|
| 219 |
+
|
| 220 |
+
Strict mode fails unless README links are real, tracked plots exist, HF Space verification passed, SFT/GRPO used real TRL/Unsloth paths, post-save inference uses the exported artifact, and measured improvement is positive.
|
| 221 |
+
|
| 222 |
+
## Documentation
|
| 223 |
+
|
| 224 |
+
- [Architecture](docs/architecture.md)
|
| 225 |
+
- [Environment Design](docs/environment_design.md)
|
| 226 |
+
- [Reward Design](docs/reward_design.md)
|
| 227 |
+
- [Training](docs/training.md)
|
| 228 |
+
- [Evaluation](docs/evaluation.md)
|
| 229 |
+
- [Deployment](docs/deployment.md)
|
| 230 |
+
- [Safety](docs/safety.md)
|
| 231 |
+
- [Agents](docs/agents.md)
|
| 232 |
+
- [Datasets](docs/datasets.md)
|
| 233 |
+
- [Math](docs/math.md)
|
| 234 |
+
- [Submission Checklist](docs/submission_checklist.md)
|
| 235 |
+
|
| 236 |
+
## Future Work
|
| 237 |
+
|
| 238 |
+
- Medicine image/barcode ingestion for regimen capture
|
| 239 |
+
- Larger model GRPO sweeps
|
| 240 |
+
- Stronger real-world drug-label ingestion and calibration
|
| 241 |
+
- More clinician-facing explanation studies
|
| 242 |
+
- Published HF blog or short video walkthrough
|
| 243 |
+
|
| 244 |
+
## License
|
| 245 |
+
|
| 246 |
+
MIT
|
__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Root OpenEnv package shim for POLYGUARD-OPENENV."""
|
| 2 |
+
|
| 3 |
+
from app.env.env_core import PolyGuardEnv
|
| 4 |
+
|
| 5 |
+
__all__ = ["PolyGuardEnv"]
|
app/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""POLYGUARD-RL application package."""
|
app/agents/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Agent package."""
|
| 2 |
+
|
| 3 |
+
from app.agents.orchestrator import Orchestrator
|
| 4 |
+
|
| 5 |
+
__all__ = ["Orchestrator"]
|
app/agents/candidate_agent.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Candidate generation agent."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from app.common.types import PolyGuardState
|
| 6 |
+
from app.models.policy.candidate_builder import build_candidates
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class CandidateAgent:
|
| 10 |
+
name = "CandidateAgent"
|
| 11 |
+
|
| 12 |
+
def run(self, state: PolyGuardState) -> dict:
|
| 13 |
+
candidates = build_candidates(state)
|
| 14 |
+
return {"candidates": [c.model_dump(mode="json") for c in candidates]}
|
app/agents/critic_agent.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Safety critic agent."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from app.common.enums import ActionType, DecisionMode, DoseBucket
|
| 6 |
+
from app.common.types import PolyGuardAction, PolyGuardState
|
| 7 |
+
from app.env.verifier import verify_action_legality
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class CriticAgent:
|
| 11 |
+
name = "CriticAgent"
|
| 12 |
+
|
| 13 |
+
def run(self, state: PolyGuardState, proposed: PolyGuardAction) -> dict:
|
| 14 |
+
report = verify_action_legality(state, proposed)
|
| 15 |
+
if report.legal:
|
| 16 |
+
report_payload = report.model_dump(mode="json")
|
| 17 |
+
return {
|
| 18 |
+
"approved": True,
|
| 19 |
+
"report": report_payload,
|
| 20 |
+
"final_action": proposed,
|
| 21 |
+
"legal": True,
|
| 22 |
+
"violations": report_payload.get("violations", []),
|
| 23 |
+
}
|
| 24 |
+
fallback = PolyGuardAction(
|
| 25 |
+
mode=DecisionMode.REVIEW,
|
| 26 |
+
action_type=ActionType.REQUEST_SPECIALIST_REVIEW,
|
| 27 |
+
target_drug=None,
|
| 28 |
+
replacement_drug=None,
|
| 29 |
+
dose_bucket=DoseBucket.NA,
|
| 30 |
+
taper_days=None,
|
| 31 |
+
monitoring_plan="critic_veto",
|
| 32 |
+
candidate_id="cand_veto_fallback",
|
| 33 |
+
confidence=0.62,
|
| 34 |
+
rationale_brief=f"Critic veto: {', '.join(report.violations)}",
|
| 35 |
+
)
|
| 36 |
+
report_payload = report.model_dump(mode="json")
|
| 37 |
+
return {
|
| 38 |
+
"approved": False,
|
| 39 |
+
"report": report_payload,
|
| 40 |
+
"final_action": fallback,
|
| 41 |
+
"legal": False,
|
| 42 |
+
"violations": report_payload.get("violations", []),
|
| 43 |
+
}
|
app/agents/critic_safety_agent.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Canonical CriticSafety agent module.
|
| 2 |
+
|
| 3 |
+
This file preserves required naming while reusing the current critic
|
| 4 |
+
implementation.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
from app.agents.critic_agent import CriticAgent as CriticSafetyAgent
|
| 10 |
+
|
| 11 |
+
__all__ = ["CriticSafetyAgent"]
|
app/agents/dosing_agent.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Dosing analysis agent."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from app.common.types import PolyGuardState
|
| 6 |
+
from app.knowledge.drug_catalog import DRUG_CLASSES
|
| 7 |
+
from app.models.dosing.dose_policy_features import build_dose_features
|
| 8 |
+
from app.models.dosing.infer import infer_dosing_quality
|
| 9 |
+
from app.models.dosing.pkpd_state import PKPDState
|
| 10 |
+
from app.models.dosing.surrogate_pkpd import step_pkpd
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class DosingAgent:
|
| 14 |
+
name = "DosingAgent"
|
| 15 |
+
|
| 16 |
+
def run(self, state: PolyGuardState) -> dict:
|
| 17 |
+
sensitive_classes = {"anticoagulant", "sedative", "glucose_lowering"}
|
| 18 |
+
dose_sensitive = [
|
| 19 |
+
m.drug
|
| 20 |
+
for m in state.patient.medications
|
| 21 |
+
if DRUG_CLASSES.get(m.drug) in sensitive_classes
|
| 22 |
+
][:3]
|
| 23 |
+
analyses: list[dict] = []
|
| 24 |
+
for drug in dose_sensitive:
|
| 25 |
+
feats = build_dose_features(state.patient, drug)
|
| 26 |
+
base_state = PKPDState(
|
| 27 |
+
effect_level=min(1.0, 0.35 + feats["adherence"] * 0.45),
|
| 28 |
+
toxicity_level=min(1.0, 0.08 + feats["organ_stress"] * 0.4),
|
| 29 |
+
underdose_risk=max(0.0, 1.0 - (0.35 + feats["adherence"] * 0.45)),
|
| 30 |
+
organ_stress=feats["organ_stress"],
|
| 31 |
+
interaction_load=feats["interaction_load"],
|
| 32 |
+
)
|
| 33 |
+
lower = infer_dosing_quality(step_pkpd(base_state, dose_delta=-0.2, organ_factor=feats["organ_stress"], interaction_factor=feats["interaction_load"]))
|
| 34 |
+
hold = infer_dosing_quality(step_pkpd(base_state, dose_delta=0.0, organ_factor=feats["organ_stress"], interaction_factor=feats["interaction_load"]))
|
| 35 |
+
higher = infer_dosing_quality(step_pkpd(base_state, dose_delta=0.2, organ_factor=feats["organ_stress"], interaction_factor=feats["interaction_load"]))
|
| 36 |
+
analyses.append(
|
| 37 |
+
{
|
| 38 |
+
"drug": drug,
|
| 39 |
+
"features": feats,
|
| 40 |
+
"options": {
|
| 41 |
+
"reduce": lower,
|
| 42 |
+
"hold": hold,
|
| 43 |
+
"increase": higher,
|
| 44 |
+
},
|
| 45 |
+
}
|
| 46 |
+
)
|
| 47 |
+
return {
|
| 48 |
+
"dose_sensitive_drugs": dose_sensitive,
|
| 49 |
+
"dosing_active": bool(dose_sensitive),
|
| 50 |
+
"recommend_mode": "DOSE_OPT" if dose_sensitive else "REGIMEN_OPT",
|
| 51 |
+
"analyses": analyses,
|
| 52 |
+
}
|
app/agents/evidence_agent.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Evidence retrieval agent."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from app.common.types import PolyGuardState
|
| 6 |
+
from app.knowledge.evidence_retriever import retrieve_evidence
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class EvidenceAgent:
|
| 10 |
+
name = "EvidenceAgent"
|
| 11 |
+
|
| 12 |
+
def run(self, state: PolyGuardState) -> dict:
|
| 13 |
+
query = " ".join(state.patient.comorbidities + [m.drug for m in state.patient.medications[:2]])
|
| 14 |
+
return {"evidence": retrieve_evidence(query=query, top_k=3)}
|
app/agents/explainer_agent.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Explanation agent."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from app.common.types import PolyGuardAction, PolyGuardState
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class ExplainerAgent:
|
| 9 |
+
name = "ExplainerAgent"
|
| 10 |
+
|
| 11 |
+
def run(self, state: PolyGuardState, action: PolyGuardAction, critic_report: dict) -> dict:
|
| 12 |
+
return {
|
| 13 |
+
"explanation": (
|
| 14 |
+
f"Action {action.action_type.value} selected for mode {action.mode.value}. "
|
| 15 |
+
f"Burden score={state.burden_score:.3f}, meds={len(state.patient.medications)}. "
|
| 16 |
+
f"Critic legal={critic_report.get('legal', False)}."
|
| 17 |
+
),
|
| 18 |
+
"grounded_facts": {
|
| 19 |
+
"burden_score": state.burden_score,
|
| 20 |
+
"polypharmacy_count": len(state.patient.medications),
|
| 21 |
+
},
|
| 22 |
+
}
|
app/agents/graph_agent.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Graph safety agent."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from app.common.types import PolyGuardState
|
| 6 |
+
from app.knowledge.ddi_knowledge import top_risky_pairs
|
| 7 |
+
from app.models.graph.infer import infer_graph_risk
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class GraphSafetyAgent:
|
| 11 |
+
name = "GraphSafetyAgent"
|
| 12 |
+
|
| 13 |
+
def run(self, state: PolyGuardState) -> dict:
|
| 14 |
+
drugs = [m.drug for m in state.patient.medications]
|
| 15 |
+
risk = infer_graph_risk(drugs)
|
| 16 |
+
top_pairs = top_risky_pairs(drugs)
|
| 17 |
+
triples = []
|
| 18 |
+
if len(drugs) >= 3:
|
| 19 |
+
triples = [
|
| 20 |
+
[drugs[i], drugs[i + 1], drugs[i + 2]]
|
| 21 |
+
for i in range(min(2, len(drugs) - 2))
|
| 22 |
+
]
|
| 23 |
+
return {
|
| 24 |
+
**risk,
|
| 25 |
+
"top_dangerous_pairs": top_pairs[:5],
|
| 26 |
+
"top_dangerous_triples": triples,
|
| 27 |
+
"mechanism_tags": list(risk.get("side_effect_probs", {}).keys())[:5],
|
| 28 |
+
}
|
app/agents/graph_safety_agent.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Canonical GraphSafety agent module.
|
| 2 |
+
|
| 3 |
+
This file is kept for required path compatibility and re-exports the
|
| 4 |
+
implementation from ``graph_agent.py``.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
from app.agents.graph_agent import GraphSafetyAgent
|
| 10 |
+
|
| 11 |
+
__all__ = ["GraphSafetyAgent"]
|
app/agents/medrec_agent.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Medication reconciliation agent."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from app.common.types import PolyGuardState
|
| 6 |
+
from app.knowledge.drug_catalog import canonicalize_drug_name
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class MedRecAgent:
|
| 10 |
+
name = "MedRecAgent"
|
| 11 |
+
|
| 12 |
+
def run(self, state: PolyGuardState) -> dict:
|
| 13 |
+
normalized = []
|
| 14 |
+
duplicates = set()
|
| 15 |
+
seen = set()
|
| 16 |
+
for med in state.patient.medications:
|
| 17 |
+
med.drug = canonicalize_drug_name(med.drug)
|
| 18 |
+
normalized.append(med.drug)
|
| 19 |
+
if med.drug in seen:
|
| 20 |
+
duplicates.add(med.drug)
|
| 21 |
+
seen.add(med.drug)
|
| 22 |
+
return {"normalized_meds": normalized, "duplicates": sorted(duplicates)}
|
app/agents/orchestrator.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Multi-agent orchestration graph."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
from typing import Any
|
| 7 |
+
|
| 8 |
+
from app.agents.candidate_agent import CandidateAgent
|
| 9 |
+
from app.agents.critic_agent import CriticAgent
|
| 10 |
+
from app.agents.dosing_agent import DosingAgent
|
| 11 |
+
from app.agents.evidence_agent import EvidenceAgent
|
| 12 |
+
from app.agents.explainer_agent import ExplainerAgent
|
| 13 |
+
from app.agents.graph_agent import GraphSafetyAgent
|
| 14 |
+
from app.agents.medrec_agent import MedRecAgent
|
| 15 |
+
from app.agents.planner_agent import PlannerAgent
|
| 16 |
+
from app.agents.supervisor_agent import SupervisorAgent
|
| 17 |
+
from app.common.enums import CoordinationMode
|
| 18 |
+
from app.common.types import CandidateAction, PolyGuardAction
|
| 19 |
+
from app.env.env_core import PolyGuardEnv
|
| 20 |
+
from app.models.baselines.contextual_bandit_policy import ContextualBanditPolicy
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class Orchestrator:
    """Runs one decision cycle across the analysis, planning, and critique agents.

    The pipeline is fixed: medrec → evidence → graph → dosing → candidate
    enumeration, then supervisor routing, optional bandit pre-ranking,
    planner selection, critic review (with optional replan/debate), a single
    environment step, and finally an explanation.
    """

    def __init__(self, env: PolyGuardEnv, coordination_mode: CoordinationMode = CoordinationMode.SEQUENTIAL) -> None:
        self.env = env
        self.coordination_mode = coordination_mode
        self.medrec = MedRecAgent()
        self.evidence = EvidenceAgent()
        self.graph = GraphSafetyAgent()
        self.dosing = DosingAgent()
        self.candidate = CandidateAgent()
        self.supervisor = SupervisorAgent()
        self.planner = PlannerAgent()
        self.critic = CriticAgent()
        self.explainer = ExplainerAgent()
        # Bandit configuration is read from the environment; unknown algorithm
        # names silently fall back to LinUCB.
        bandit_algo = os.getenv("POLYGUARD_BANDIT_ALGO", "linucb").strip().lower()
        if bandit_algo not in {"linucb", "thompson"}:
            bandit_algo = "linucb"
        self.bandit = ContextualBanditPolicy(
            algorithm=bandit_algo,  # type: ignore[arg-type]
            alpha=float(os.getenv("POLYGUARD_BANDIT_ALPHA", "0.55")),
            epsilon=float(os.getenv("POLYGUARD_BANDIT_EPSILON", "0.1")),
            seed=int(os.getenv("POLYGUARD_BANDIT_SEED", "42")),
        )
        # Policy stack: "llm+bandit" (default), "bandit-only", or "llm-only".
        self.policy_stack = os.getenv("POLYGUARD_POLICY_STACK", "llm+bandit").strip().lower()
        self.bandit_top_k = int(os.getenv("POLYGUARD_BANDIT_TOP_K", "3"))

    def set_mode(self, coordination_mode: CoordinationMode) -> None:
        """Switch the coordination mode used by subsequent ``run_step`` calls."""
        self.coordination_mode = coordination_mode

    def run_step(self, coordination_mode: str | None = None) -> dict[str, Any]:
        """Execute one full orchestration cycle and step the environment once.

        Args:
            coordination_mode: optional mode name; when given it permanently
                overrides ``self.coordination_mode`` (raises ValueError for
                unknown names via the CoordinationMode constructor).

        Returns:
            A dict bundling every agent's output, the proposed and final
            actions, the environment transition, bandit diagnostics, and
            replan/debate bookkeeping.
        """
        if coordination_mode is not None:
            self.coordination_mode = CoordinationMode(coordination_mode)
        state = self.env.state
        # Analysis agents run unconditionally and in a fixed sequence.
        medrec_out = self.medrec.run(state)
        evidence_out = self.evidence.run(state)
        graph_out = self.graph.run(state)
        dosing_out = self.dosing.run(state)
        candidate_out = self.candidate.run(state)
        candidates = [CandidateAction.model_validate(item) for item in candidate_out["candidates"]]

        supervisor_out = self.supervisor.run(state, dosing_active=dosing_out["dosing_active"])
        # Restrict candidates to the supervisor's mode; fall back to all when empty.
        planner_candidates = [c for c in candidates if c.mode.value == supervisor_out["mode"]] or candidates
        if self.coordination_mode == CoordinationMode.SUPERVISOR_ROUTED and supervisor_out["mode"] == "REVIEW":
            planner_candidates = [c for c in candidates if c.mode.value == "REVIEW"] or planner_candidates

        candidate_by_id = {item.candidate_id: item for item in planner_candidates}
        # Bandit pre-ranks candidates; unknown ids from the bandit are dropped.
        bandit_proposals = self.bandit.propose(planner_candidates, top_k=self.bandit_top_k)
        bandit_candidates = [candidate_by_id[item.candidate_id] for item in bandit_proposals if item.candidate_id in candidate_by_id]
        if not bandit_candidates:
            bandit_candidates = planner_candidates

        if self.policy_stack == "bandit-only":
            # Skip the LLM planner entirely; take the bandit's top pick.
            selected = bandit_candidates[0]
            proposed = PolyGuardAction(
                mode=selected.mode,
                action_type=selected.action_type,
                target_drug=selected.target_drug,
                replacement_drug=selected.replacement_drug,
                dose_bucket=selected.dose_bucket,
                taper_days=selected.taper_days,
                monitoring_plan=selected.monitoring_plan,
                candidate_id=selected.candidate_id,
                confidence=max(0.45, 1.0 - selected.uncertainty_score),
                rationale_brief="Bandit-only policy selected top contextual candidate.",
            )
        elif self.policy_stack == "llm-only":
            proposed = self.planner.run(candidates=planner_candidates, mode=supervisor_out["mode"])
        else:
            # Default hybrid stack: planner chooses among bandit-ranked candidates.
            proposed = self.planner.run(
                candidates=bandit_candidates,
                mode=supervisor_out["mode"],
                provider_prompt={
                    "coordination_mode": self.coordination_mode.value,
                    "policy_stack": self.policy_stack,
                    "candidate_count": len(bandit_candidates),
                    "sub_environment": state.sub_environment.value,
                },
            )

        critic_out = self.critic.run(state, proposed)
        final_action: PolyGuardAction = critic_out["final_action"]
        replan_triggered = False
        debate_rounds = 0

        # On a critic veto, the replan/debate modes retry once with REVIEW candidates.
        if self.coordination_mode in {CoordinationMode.REPLAN_ON_VETO, CoordinationMode.LIGHT_DEBATE} and not critic_out["approved"]:
            replan_triggered = True
            review_candidates = [c for c in candidates if c.mode.value == "REVIEW"] or candidates
            proposed = self.planner.run(candidates=review_candidates, mode="REVIEW")
            critic_out = self.critic.run(state, proposed)
            final_action = critic_out["final_action"]
            debate_rounds = 1

        # LIGHT_DEBATE counts a second "round" when the critic approved but
        # altered the action type relative to the proposal.
        if self.coordination_mode == CoordinationMode.LIGHT_DEBATE and critic_out["approved"] and proposed.action_type != final_action.action_type:
            debate_rounds = 2

        obs, reward, done, info = self.env.step(final_action)
        # Only update the bandit for actions traceable to a known candidate.
        selected_for_update = candidate_by_id.get(final_action.candidate_id)
        if selected_for_update is not None:
            self.bandit.update(selected_for_update, reward=reward)

        explanation_out = self.explainer.run(state, final_action, critic_out["report"])
        return {
            "medrec": medrec_out,
            "evidence": evidence_out,
            "graph": graph_out,
            "dosing": dosing_out,
            "supervisor": supervisor_out,
            "proposed_action": proposed.model_dump(mode="json"),
            "critic": critic_out["report"],
            "final_action": final_action.model_dump(mode="json"),
            "observation": obs.model_dump(mode="json"),
            "reward": reward,
            "done": done,
            "info": info,
            "explanation": explanation_out,
            "coordination_mode": self.coordination_mode.value,
            "policy_stack": self.policy_stack,
            "bandit_topk": [item.candidate_id for item in bandit_candidates],
            "bandit_scores": [
                {
                    "candidate_id": item.candidate_id,
                    "score": item.score,
                    "exploration_bonus": item.exploration_bonus,
                    "algorithm": item.algorithm,
                }
                for item in bandit_proposals
            ],
            "replan_triggered": replan_triggered,
            "debate_rounds": debate_rounds,
        }
|
app/agents/planner_agent.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Planner agent."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Any
|
| 6 |
+
|
| 7 |
+
from app.common.types import CandidateAction, PolyGuardAction
|
| 8 |
+
from app.models.policy.provider_runtime import PolicyProviderRouter
|
| 9 |
+
from app.models.policy.safety_ranker import rank_candidates
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class PlannerAgent:
    """Turns a pool of candidate actions into one concrete PolyGuardAction."""

    name = "PlannerAgent"

    def __init__(self) -> None:
        self.provider_router = PolicyProviderRouter()

    def run(
        self,
        candidates: list[CandidateAction],
        mode: str,
        provider_prompt: dict[str, Any] | None = None,
        provider_preference: tuple[str, ...] = ("transformers",),
    ) -> PolyGuardAction:
        """Select the best candidate matching *mode* and wrap it as an action.

        Candidates whose mode does not match are filtered out unless that
        would leave an empty pool, in which case all candidates are kept.
        """
        matching = [item for item in candidates if item.mode.value == mode]
        pool = matching if matching else candidates
        selection = self.provider_router.select_candidate(
            candidates=pool,
            prompt=provider_prompt or {"mode": mode},
            provider_preference=provider_preference,
        )
        lookup = {item.candidate_id: item for item in pool}
        # Safety-ranked fallback covers providers returning an unknown id.
        fallback = rank_candidates(pool)[0]
        chosen = lookup.get(selection.candidate_id, fallback)
        return PolyGuardAction(
            mode=chosen.mode,
            action_type=chosen.action_type,
            target_drug=chosen.target_drug,
            replacement_drug=chosen.replacement_drug,
            dose_bucket=chosen.dose_bucket,
            taper_days=chosen.taper_days,
            monitoring_plan=chosen.monitoring_plan,
            candidate_id=chosen.candidate_id,
            confidence=max(0.45, 1.0 - chosen.uncertainty_score),
            rationale_brief=selection.rationale,
        )
|
app/agents/supervisor_agent.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Supervisor agent."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from app.common.enums import DecisionMode
|
| 6 |
+
from app.common.types import PolyGuardState
|
| 7 |
+
from app.models.policy.uncertainty import estimate_uncertainty
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class SupervisorAgent:
    """Routes the current state into a decision mode based on estimated uncertainty."""

    name = "SupervisorAgent"

    # Above this uncertainty, always defer to human review.
    _REVIEW_THRESHOLD = 0.72

    def run(self, state: PolyGuardState, dosing_active: bool) -> dict:
        """Pick REVIEW, DOSE_OPT, or REGIMEN_OPT for the current state."""
        uncertainty = estimate_uncertainty(state)
        sub_env = state.sub_environment.value
        if uncertainty > self._REVIEW_THRESHOLD:
            mode = DecisionMode.REVIEW
        elif sub_env == "PRECISION_DOSING" or dosing_active:
            mode = DecisionMode.DOSE_OPT
        else:
            mode = DecisionMode.REGIMEN_OPT
        return {"mode": mode.value, "uncertainty": uncertainty, "sub_environment": sub_env}
|
app/api/__init__.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""API application entrypoint."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
import uvicorn
|
| 8 |
+
from fastapi import FastAPI
|
| 9 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 10 |
+
|
| 11 |
+
from app.api.routes import router
|
| 12 |
+
|
| 13 |
+
# Application instance; routers are attached below and served by uvicorn.
app = FastAPI(title="POLYGUARD-RL API", version="0.1.0")
# CORS is restricted to the local Vite dev server (port 5173).
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "http://127.0.0.1:5173",
        "http://localhost:5173",
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.include_router(router)


def main() -> None:
    """Serve the API with uvicorn, honoring POLYGUARD_API_HOST/PORT env vars."""
    host = os.getenv("POLYGUARD_API_HOST", "127.0.0.1")
    port = int(os.getenv("POLYGUARD_API_PORT", "8200"))
    # Import-string form so uvicorn owns app construction in worker processes.
    uvicorn.run("app.api:app", host=host, port=port, reload=False)


if __name__ == "__main__":
    main()
|
app/api/__main__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Run API with `python -m app.api`."""
|
| 2 |
+
|
| 3 |
+
from app.api import main
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
if __name__ == "__main__":
|
| 7 |
+
main()
|
app/api/dependencies.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""API dependencies."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from app.api.service import APIService
|
| 6 |
+
|
| 7 |
+
_SERVICE = APIService()
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def get_service() -> APIService:
|
| 11 |
+
return _SERVICE
|
app/api/main.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Canonical API module path.
|
| 2 |
+
|
| 3 |
+
Keeps compatibility with required file path while reusing ``app.api`` app.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
from app.api import app, main
|
| 9 |
+
|
| 10 |
+
__all__ = ["app", "main"]
|
app/api/routes.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""API routes."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from fastapi import APIRouter, Depends, HTTPException
|
| 6 |
+
|
| 7 |
+
from app.api.dependencies import get_service
|
| 8 |
+
from app.api.schemas import (
|
| 9 |
+
BatchInferRequest,
|
| 10 |
+
EvidenceQueryRequest,
|
| 11 |
+
OrchestrateRequest,
|
| 12 |
+
ResetRequest,
|
| 13 |
+
StepCandidateRequest,
|
| 14 |
+
StepRequest,
|
| 15 |
+
)
|
| 16 |
+
from app.api.service import APIService
|
| 17 |
+
|
| 18 |
+
router = APIRouter()
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@router.get("/health")
|
| 22 |
+
def health() -> dict[str, str]:
|
| 23 |
+
return {"status": "ok"}
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@router.post("/env/reset")
|
| 27 |
+
def env_reset(payload: ResetRequest, service: APIService = Depends(get_service)) -> dict:
|
| 28 |
+
try:
|
| 29 |
+
return service.reset(**payload.model_dump(mode="json"))
|
| 30 |
+
except ValueError as exc:
|
| 31 |
+
raise HTTPException(status_code=422, detail=str(exc)) from exc
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@router.post("/env/step")
|
| 35 |
+
def env_step(payload: StepRequest, service: APIService = Depends(get_service)) -> dict:
|
| 36 |
+
return service.step(payload.model_dump(mode="json"))
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
@router.post("/env/step_candidate")
|
| 40 |
+
def env_step_candidate(payload: StepCandidateRequest, service: APIService = Depends(get_service)) -> dict:
|
| 41 |
+
result = service.step_candidate(
|
| 42 |
+
candidate_id=payload.candidate_id,
|
| 43 |
+
confidence=payload.confidence,
|
| 44 |
+
rationale_brief=payload.rationale_brief,
|
| 45 |
+
)
|
| 46 |
+
if result is None:
|
| 47 |
+
raise HTTPException(status_code=404, detail=f"Candidate {payload.candidate_id!r} is not legal in this state.")
|
| 48 |
+
return result
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
@router.get("/env/catalog")
|
| 52 |
+
def env_catalog(service: APIService = Depends(get_service)) -> dict:
|
| 53 |
+
return service.catalog()
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
@router.get("/env/state")
|
| 57 |
+
def env_state(service: APIService = Depends(get_service)) -> dict:
|
| 58 |
+
return service.env.get_state()
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
@router.get("/env/trace")
|
| 62 |
+
def env_trace(service: APIService = Depends(get_service)) -> list[dict]:
|
| 63 |
+
return service.env.get_trace()
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
@router.get("/env/legal_actions")
|
| 67 |
+
def env_legal_actions(service: APIService = Depends(get_service)) -> list[dict]:
|
| 68 |
+
return service.env.get_legal_actions()
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
@router.get("/env/reward_breakdown")
|
| 72 |
+
def env_reward_breakdown(service: APIService = Depends(get_service)) -> dict:
|
| 73 |
+
return service.env.get_reward_breakdown()
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
@router.get("/env/uncertainty")
|
| 77 |
+
def env_uncertainty(service: APIService = Depends(get_service)) -> dict:
|
| 78 |
+
return service.env.get_uncertainty_report().model_dump(mode="json")
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
@router.post("/agents/orchestrate")
|
| 82 |
+
def agents_orchestrate(
|
| 83 |
+
payload: OrchestrateRequest = OrchestrateRequest(),
|
| 84 |
+
service: APIService = Depends(get_service),
|
| 85 |
+
) -> dict:
|
| 86 |
+
return service.orchestrate(coordination_mode=payload.coordination_mode)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
@router.post("/policy/infer")
|
| 90 |
+
def policy_infer(service: APIService = Depends(get_service)) -> dict:
|
| 91 |
+
return service.infer_policy()
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
@router.post("/policy/batch_infer")
|
| 95 |
+
def policy_batch_infer(
|
| 96 |
+
payload: BatchInferRequest = BatchInferRequest(),
|
| 97 |
+
service: APIService = Depends(get_service),
|
| 98 |
+
) -> list[dict]:
|
| 99 |
+
return service.batch_infer(batch_size=payload.batch_size)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
@router.post("/eval/run_baselines")
|
| 103 |
+
def eval_baselines(service: APIService = Depends(get_service)) -> dict:
|
| 104 |
+
return service.run_baselines()
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
@router.post("/eval/run_policy")
|
| 108 |
+
def eval_run_policy(service: APIService = Depends(get_service)) -> dict:
|
| 109 |
+
return service.run_policy_eval()
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
@router.post("/eval/run_dosing")
|
| 113 |
+
def eval_run_dosing(service: APIService = Depends(get_service)) -> dict:
|
| 114 |
+
return service.run_dosing_eval()
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
@router.get("/metrics/training")
|
| 118 |
+
def metrics_training(service: APIService = Depends(get_service)) -> dict:
|
| 119 |
+
return service.get_metrics()
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
@router.get("/cases/sample")
|
| 123 |
+
def cases_sample(service: APIService = Depends(get_service)) -> dict:
|
| 124 |
+
return service.sample_case()
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
@router.get("/cases/search")
|
| 128 |
+
def cases_search(q: str, service: APIService = Depends(get_service)) -> list[dict]:
|
| 129 |
+
return service.search_cases(q)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
@router.post("/evidence/query")
|
| 133 |
+
def evidence_query(payload: EvidenceQueryRequest, service: APIService = Depends(get_service)) -> list[dict]:
|
| 134 |
+
return service.evidence_query(query=payload.query, top_k=payload.top_k)
|
app/api/schemas.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""API schemas."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Any, Optional
|
| 6 |
+
|
| 7 |
+
from pydantic import BaseModel, ConfigDict, Field
|
| 8 |
+
|
| 9 |
+
from app.common.enums import ActionType, DecisionMode, Difficulty, DoseBucket, SubEnvironment
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class StrictSchema(BaseModel):
    """Base schema rejecting unknown fields so client typos fail loudly."""

    model_config = ConfigDict(extra="forbid")


class ResetRequest(StrictSchema):
    """Options for /env/reset; all fields optional (defaults chosen server-side)."""

    task_id: Optional[str] = None
    seed: Optional[int] = None
    difficulty: Optional[Difficulty] = None
    sub_environment: Optional[SubEnvironment] = None
    scenario_id: Optional[str] = None
    patient_id: Optional[str] = None


class StepRequest(StrictSchema):
    """A fully specified environment action for /env/step."""

    mode: DecisionMode
    action_type: ActionType
    target_drug: Optional[str] = None
    replacement_drug: Optional[str] = None
    dose_bucket: DoseBucket
    taper_days: Optional[int] = None
    monitoring_plan: Optional[str] = None
    evidence_query: Optional[str] = None
    new_drug_name: Optional[str] = None
    candidate_components: list[str] = Field(default_factory=list)
    candidate_id: str
    confidence: float
    rationale_brief: str


class StepCandidateRequest(StrictSchema):
    """Step by referencing a pre-enumerated legal candidate id."""

    candidate_id: str
    confidence: float
    rationale_brief: str


class OrchestrateRequest(StrictSchema):
    """Optional coordination-mode override for /agents/orchestrate."""

    coordination_mode: Optional[str] = None


class BatchInferRequest(StrictSchema):
    """Batch size for /policy/batch_infer."""

    batch_size: int = 4


class EvidenceQueryRequest(StrictSchema):
    """Free-text evidence query with result count."""

    query: str
    top_k: int = 5
|
app/api/service.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""API service layer."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Any
|
| 7 |
+
|
| 8 |
+
from app.agents.orchestrator import Orchestrator
|
| 9 |
+
from app.env.catalog import apply_task_preset, env_catalog
|
| 10 |
+
from app.env.env_core import PolyGuardEnv
|
| 11 |
+
from app.evaluation.benchmark_report import build_benchmark_report
|
| 12 |
+
from app.evaluation.dosing_eval import dosing_eval
|
| 13 |
+
from app.knowledge.evidence_retriever import retrieve_evidence
|
| 14 |
+
from app.models.retrieval.retriever import retrieve
|
| 15 |
+
from app.models.baselines import (
|
| 16 |
+
choose_beam_search,
|
| 17 |
+
choose_contextual_bandit,
|
| 18 |
+
choose_contextual_bandit_topk,
|
| 19 |
+
choose_greedy,
|
| 20 |
+
choose_no_change,
|
| 21 |
+
choose_rules_only,
|
| 22 |
+
)
|
| 23 |
+
from app.training import train_dosing_grpo, train_planner_grpo, train_supervisor_grpo
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class APIService:
    """Facade wiring the environment, orchestrator, training, and retrieval for the API."""

    def __init__(self) -> None:
        self.env = PolyGuardEnv()
        self.orchestrator = Orchestrator(self.env)
        # In-memory cache of the most recent training metrics.
        self.training_metrics: dict[str, Any] = {}
        # Repository root: two levels above app/api/.
        self.root = Path(__file__).resolve().parents[2]

    def reset(self, **kwargs: Any) -> dict[str, Any]:
        """Reset the environment after expanding any task preset in *kwargs*."""
        kwargs = apply_task_preset(dict(kwargs))
        obs = self.env.reset(**kwargs)
        return obs.model_dump(mode="json")

    def step(self, action: dict[str, Any]) -> dict[str, Any]:
        """Step the environment and add Gym-style terminated/truncated flags."""
        obs, reward, done, info = self.env.step(action)
        reason = str(info.get("termination_reason", "")) if isinstance(info, dict) else ""
        # Timeouts and budget exhaustion count as truncation, not natural termination.
        truncated = reason in {"wall_clock_timeout", "step_timeout", "step_budget_exhausted"}
        return {
            "observation": obs.model_dump(mode="json"),
            "reward": reward,
            "done": done,
            "terminated": done,
            "truncated": truncated,
            "info": info,
        }

    def catalog(self) -> dict[str, Any]:
        """Static environment catalog."""
        return env_catalog()

    def step_candidate(self, candidate_id: str, confidence: float, rationale_brief: str) -> dict[str, Any] | None:
        """Step with a legal candidate id; return None when the id is not legal now."""
        for action in self.env.get_legal_actions():
            if action.get("candidate_id") != candidate_id:
                continue
            payload = dict(action)
            payload["confidence"] = confidence
            payload["rationale_brief"] = rationale_brief
            return self.step(payload)
        return None

    def orchestrate(self, coordination_mode: str | None = None) -> dict[str, Any]:
        """Run one multi-agent orchestration step."""
        return self.orchestrator.run_step(coordination_mode=coordination_mode)

    def infer_policy(self) -> dict[str, Any]:
        """Trivial policy: the first legal action, or {} when none exist."""
        legal = self.env.get_legal_actions()
        return legal[0] if legal else {}

    def batch_infer(self, batch_size: int = 4) -> list[dict[str, Any]]:
        """Return up to *batch_size* legal actions."""
        legal = self.env.get_legal_actions()
        return legal[:batch_size]

    def run_baselines(self) -> dict[str, Any]:
        """Evaluate each baseline policy on the current legal candidates.

        Resets the environment once if no legal candidates are available.
        """
        candidates = [c for c in self.env.get_candidate_actions() if c.get("legality_precheck")]
        if not candidates:
            self.env.reset()
            candidates = [c for c in self.env.get_candidate_actions() if c.get("legality_precheck")]
        # Validate once; the original rebuilt this list for every baseline.
        candidate_objs = [self._candidate_obj(c) for c in candidates]
        return {
            "no_change": choose_no_change().model_dump(mode="json"),
            "rules_only": choose_rules_only(candidate_objs).model_dump(mode="json"),
            "greedy": choose_greedy(candidate_objs).model_dump(mode="json"),
            "contextual_bandit": choose_contextual_bandit(candidate_objs).model_dump(mode="json"),
            "contextual_bandit_topk": [
                {
                    "candidate_id": item.candidate_id,
                    "score": item.score,
                    "exploration_bonus": item.exploration_bonus,
                    "algorithm": item.algorithm,
                }
                for item in choose_contextual_bandit_topk(candidate_objs, top_k=3)
            ],
            "beam_search": choose_beam_search(candidate_objs).model_dump(mode="json"),
        }

    def run_policy_eval(self) -> dict[str, Any]:
        """Build the benchmark report and return its payload."""
        return build_benchmark_report(Path("outputs/reports/benchmark_report.txt"))

    def run_dosing_eval(self) -> dict[str, Any]:
        """Run the dosing evaluation suite."""
        return dosing_eval()

    def run_training(self) -> dict[str, Any]:
        """Run short GRPO training loops and cache the resulting metrics."""
        out_dir = Path("checkpoints")
        out_dir.mkdir(parents=True, exist_ok=True)
        self.training_metrics = {
            "supervisor": train_supervisor_grpo(episodes=4, checkpoint_dir=out_dir),
            "planner": train_planner_grpo(episodes=6, checkpoint_dir=out_dir),
            "dosing": train_dosing_grpo(episodes=4, checkpoint_dir=out_dir),
        }
        return self.training_metrics

    def get_metrics(self) -> dict[str, Any]:
        """Return cached training metrics, falling back to report files on disk.

        When planner metrics are present they form the top level of the
        response, with all per-model metrics nested under ``model_metrics``.
        """
        if self.training_metrics:
            if "planner" in self.training_metrics:
                merged = dict(self.training_metrics["planner"])
                merged["model_metrics"] = self.training_metrics
                return merged
            return self.training_metrics
        import json  # local import; hoisted out of the loop below

        reports_dir = Path("outputs/reports")
        metrics: dict[str, Any] = {}
        for name in ["supervisor_grpo", "planner_grpo", "dosing_grpo"]:
            path = reports_dir / f"{name}.json"
            if path.exists():
                metrics[name] = json.loads(path.read_text(encoding="utf-8"))
        self.training_metrics = metrics
        if "planner_grpo" in metrics:
            merged = dict(metrics["planner_grpo"])
            merged["model_metrics"] = metrics
            return merged
        return metrics

    def sample_case(self) -> dict[str, Any]:
        """Reset to sample a fresh case and return the initial observation."""
        obs = self.env.reset()
        return obs.model_dump(mode="json")

    def search_cases(self, query: str) -> list[dict[str, Any]]:
        """Search the retrieval index; fall back to scanning the raw corpus."""
        index_file = self.root / "data" / "retrieval_index" / "index.json"
        hits = retrieve(index_file=index_file, query=query, top_k=5)
        if hits:
            return [
                {
                    "patient_id": Path(item.get("path", f"case_{idx}")).stem,
                    "query": query,
                    "source_path": item.get("path", ""),
                    "snippet": str(item.get("text", ""))[:280],
                }
                for idx, item in enumerate(hits)
            ]

        # Fallback: linear token-match scan over the JSONL corpus, capped at 5 hits.
        fallback: list[dict[str, Any]] = []
        corpus = self.root / "data" / "processed" / "retrieval_corpus.jsonl"
        if corpus.exists():
            query_tokens = {token for token in query.lower().split() if token}
            with corpus.open("r", encoding="utf-8") as handle:
                for idx, line in enumerate(handle):
                    if len(fallback) >= 5:
                        break
                    text = line.strip()
                    if not text:
                        continue
                    hay = text.lower()
                    if query_tokens and not any(token in hay for token in query_tokens):
                        continue
                    fallback.append(
                        {
                            "patient_id": f"retrieval_corpus_{idx}",
                            "query": query,
                            "source_path": str(corpus),
                            "snippet": text[:280],
                        }
                    )
        return fallback

    def evidence_query(self, query: str, top_k: int = 5) -> list[dict[str, str]]:
        """Delegate to the knowledge-base evidence retriever."""
        return retrieve_evidence(query=query, top_k=top_k)

    @staticmethod
    def _candidate_obj(payload: dict) -> Any:
        """Validate a raw candidate dict into a CandidateAction (local import avoids cycles)."""
        from app.common.types import CandidateAction

        return CandidateAction.model_validate(payload)
|
app/common/config.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Configuration loading."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Any
|
| 8 |
+
|
| 9 |
+
import yaml
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def _read_yaml(path: Path) -> dict[str, Any]:
    """Load a YAML mapping from *path*; missing files and empty documents yield {}."""
    if not path.exists():
        return {}
    with path.open("r", encoding="utf-8") as stream:
        loaded = yaml.safe_load(stream)
    return loaded or {}
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def load_config(config_name: str = "base.yaml") -> dict[str, Any]:
    """Read ``configs/<config_name>`` relative to the repository root."""
    repo_root = Path(__file__).resolve().parents[2]
    return _read_yaml(repo_root / "configs" / config_name)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def env_bool(name: str, default: bool = False) -> bool:
    """Interpret environment variable *name* as a boolean flag."""
    raw = os.getenv(name)
    if raw is None:
        return default
    # Accept common truthy spellings, case-insensitively.
    truthy = {"1", "true", "yes", "on"}
    return raw.strip().lower() in truthy
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def env_int(name: str, default: int) -> int:
    """Interpret environment variable *name* as an int, else *default*."""
    raw = os.getenv(name)
    if raw is None:
        return default
    try:
        return int(raw)
    except ValueError:
        # Malformed values degrade to the default rather than crashing startup.
        return default
|
app/common/constants.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Shared constants for POLYGUARD-RL."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
# Rewards are clamped into [0.001, 0.999] and rounded to three decimals
# everywhere in the system.
REWARD_MIN: float = 0.001
REWARD_MAX: float = 0.999
REWARD_PRECISION: int = 3

# Episode / environment defaults.
DEFAULT_SEED: int = 42
DEFAULT_MAX_STEPS: int = 10
MAX_REPEATED_ACTIONS: int = 3
MAX_KEEP_REGIMEN_RATIO: float = 0.6
MAX_REVIEW_RATIO: float = 0.5
DEFAULT_STEP_TIMEOUT_SECONDS: float = 2.5
DEFAULT_EPISODE_TIMEOUT_SECONDS: float = 45.0

# Per-component weights of the shaped reward; the weights sum to 1.0.
DEFAULT_REWARD_WEIGHTS: dict[str, float] = {
    "format_compliance_score": 0.08,
    "candidate_alignment_score": 0.08,
    "legality_score": 0.12,
    "safety_delta_score": 0.15,
    "burden_improvement_score": 0.08,
    "disease_stability_score": 0.10,
    "dosing_quality_score": 0.08,
    "abstention_quality_score": 0.06,
    "efficiency_score": 0.06,
    "process_fidelity_score": 0.06,
    "explanation_grounding_score": 0.03,
    "anti_cheat_score": 0.06,
    "uncertainty_calibration_score": 0.04,
}

# Every reward breakdown must report exactly these component keys.
REQUIRED_REWARD_KEYS: tuple[str, ...] = tuple(DEFAULT_REWARD_WEIGHTS.keys())

# High-level reward axes reported alongside the per-component scores.
PRIMARY_REWARD_KEYS: tuple[str, ...] = (
    "safety_legality",
    "clinical_improvement",
    "dosing_quality",
    "process_integrity",
)
|
app/common/enums.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Enumerations used throughout POLYGUARD-RL."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from enum import Enum
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class Difficulty(str, Enum):
    """Scenario difficulty tiers."""

    EASY = "easy"
    MEDIUM = "medium"
    HARD = "hard"
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class SubEnvironment(str, Enum):
    """The task family an episode belongs to."""

    DDI = "DDI"
    BANDIT_MINING = "BANDIT_MINING"
    REGIMEN_RISK = "REGIMEN_RISK"
    PRECISION_DOSING = "PRECISION_DOSING"
    LONGITUDINAL_DEPRESCRIBING = "LONGITUDINAL_DEPRESCRIBING"
    WEB_SEARCH_MISSING_DATA = "WEB_SEARCH_MISSING_DATA"
    ALTERNATIVE_SUGGESTION = "ALTERNATIVE_SUGGESTION"
    NEW_DRUG_DECOMPOSITION = "NEW_DRUG_DECOMPOSITION"
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class DecisionMode(str, Enum):
    """Top-level decision mode an action is taken under."""

    REGIMEN_OPT = "REGIMEN_OPT"
    DOSE_OPT = "DOSE_OPT"
    REVIEW = "REVIEW"
    ABSTAIN_REVIEW = "ABSTAIN_REVIEW"
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class ActionType(str, Enum):
    """Concrete action vocabulary available to the policy."""

    KEEP_REGIMEN = "KEEP_REGIMEN"
    STOP_DRUG = "STOP_DRUG"
    SUBSTITUTE_WITHIN_CLASS = "SUBSTITUTE_WITHIN_CLASS"
    RECOMMEND_ALTERNATIVE = "RECOMMEND_ALTERNATIVE"
    REDUCE_DOSE_BUCKET = "REDUCE_DOSE_BUCKET"
    INCREASE_DOSE_BUCKET = "INCREASE_DOSE_BUCKET"
    TAPER_INITIATE = "TAPER_INITIATE"
    TAPER_CONTINUE = "TAPER_CONTINUE"
    DOSE_HOLD = "DOSE_HOLD"
    ORDER_MONITORING_AND_WAIT = "ORDER_MONITORING_AND_WAIT"
    FETCH_EXTERNAL_EVIDENCE = "FETCH_EXTERNAL_EVIDENCE"
    DECOMPOSE_NEW_DRUG = "DECOMPOSE_NEW_DRUG"
    REQUEST_SPECIALIST_REVIEW = "REQUEST_SPECIALIST_REVIEW"
    REQUEST_PHARMACIST_REVIEW = "REQUEST_PHARMACIST_REVIEW"
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class DoseBucket(str, Enum):
    """Discretized dose levels; NA means no dose applies."""

    LOW = "LOW"
    MEDIUM = "MEDIUM"
    HIGH = "HIGH"
    HOLD = "HOLD"
    NA = "NA"
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class CoordinationMode(str, Enum):
    """How the multi-agent orchestrator coordinates its agents."""

    SEQUENTIAL = "sequential_pipeline"
    SUPERVISOR_ROUTED = "supervisor_routed"
    REPLAN_ON_VETO = "replan_on_veto"
    LIGHT_DEBATE = "lightweight_debate"
|
app/common/exceptions.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Custom exceptions."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class PolyGuardError(Exception):
    """Base exception for project errors."""


class InvalidActionError(PolyGuardError):
    """Raised when an action is malformed or disallowed."""


class SafetyVetoError(PolyGuardError):
    """Raised when safety governance rejects an action."""


class ParserError(PolyGuardError):
    """Raised when structured policy output cannot be parsed."""
|
app/common/json_utils.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Strict JSON helpers."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
from typing import Any
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def safe_json_dumps(payload: Any) -> str:
    """Serialize *payload* deterministically (sorted keys, ASCII-only).

    ``default=str`` deliberately stringifies non-JSON-native values instead
    of raising, trading fidelity for never failing on serialization.
    """
    return json.dumps(payload, ensure_ascii=True, sort_keys=True, default=str)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def safe_json_loads(payload: str) -> Any:
    """Parse a JSON string into Python objects (raises on invalid JSON)."""
    return json.loads(payload)
|
app/common/logging_utils.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Logging utilities."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import logging
|
| 6 |
+
from typing import Optional
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def configure_logging(level: str = "INFO") -> None:
    """Initialise root logging with a pipe-delimited format.

    Unknown level names silently fall back to INFO.
    """
    resolved = getattr(logging, level.upper(), logging.INFO)
    logging.basicConfig(
        level=resolved,
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    )
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def get_logger(name: Optional[str] = None) -> logging.Logger:
    """Return a named logger, defaulting to the project-wide "polyguard" one."""
    return logging.getLogger(name or "polyguard")
|
app/common/normalization.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Normalization and reward range utilities."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from app.common.constants import REWARD_MAX, REWARD_MIN, REWARD_PRECISION
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def clamp_reward(value: float) -> float:
    """Clamp and quantize reward to [0.001, 0.999] with 3 decimals."""
    bounded = max(REWARD_MIN, min(REWARD_MAX, float(value)))
    return round(bounded, REWARD_PRECISION)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def normalize_unit_interval(value: float, lower: float, upper: float) -> float:
    """Map *value* linearly from [lower, upper] onto [0, 1], clipping overflow.

    A degenerate range (upper <= lower) yields the neutral midpoint 0.5.
    """
    if upper <= lower:
        return 0.5
    fraction = (value - lower) / (upper - lower)
    return float(max(0.0, min(1.0, fraction)))
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def to_reward(value: float, lower: float, upper: float) -> float:
    """Normalize *value* into [0, 1] then rescale into the reward band."""
    fraction = normalize_unit_interval(value, lower, upper)
    # Linear rescale of [0, 1] onto [REWARD_MIN, REWARD_MAX].
    scaled = REWARD_MIN + fraction * (REWARD_MAX - REWARD_MIN)
    return clamp_reward(scaled)
|
app/common/seeding.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Deterministic seeding helpers."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import random
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
|
| 10 |
+
from app.common.constants import DEFAULT_SEED
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def set_global_seed(seed: int = DEFAULT_SEED) -> int:
    """Seed Python's and NumPy's RNGs for reproducible runs; return the seed.

    NOTE(review): assigning PYTHONHASHSEED here only affects subprocesses;
    the current interpreter's hash randomization was fixed at startup.
    """
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    return seed
|
app/common/types.py
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Core typed models."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from datetime import datetime
|
| 6 |
+
from typing import Any, Optional
|
| 7 |
+
|
| 8 |
+
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
| 9 |
+
|
| 10 |
+
from app.common.enums import ActionType, DecisionMode, Difficulty, DoseBucket, SubEnvironment
|
| 11 |
+
from app.common.normalization import clamp_reward
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class StrictBase(BaseModel):
    """Shared pydantic base: unknown fields are rejected outright."""

    model_config = ConfigDict(extra="forbid")
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class Medication(StrictBase):
    """One drug entry in a patient's regimen."""

    drug: str
    dose_bucket: DoseBucket = DoseBucket.MEDIUM
    indication: Optional[str] = None
    class_name: Optional[str] = None
    requires_taper: bool = False
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class LabSummary(StrictBase):
    """Sparse lab values; None means the value was not provided."""

    egfr: Optional[float] = None
    ast: Optional[float] = None
    alt: Optional[float] = None
    inr: Optional[float] = None
    glucose: Optional[float] = None
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class PatientProfile(StrictBase):
    """Patient snapshot: demographics, regimen, labs, and risk context."""

    patient_id: str
    age: int
    sex: str
    comorbidities: list[str] = Field(default_factory=list)
    medications: list[Medication] = Field(default_factory=list)
    labs: LabSummary = Field(default_factory=LabSummary)
    vitals: dict[str, float] = Field(default_factory=dict)
    specialist_conflicts: list[str] = Field(default_factory=list)
    prior_ade_history: list[str] = Field(default_factory=list)
    frailty_score: float = 0.3
    adherence_estimate: float = 0.8
    latent_confounders: dict[str, float] = Field(default_factory=dict)
    monitoring_gaps: list[str] = Field(default_factory=list)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class CandidateAction(StrictBase):
    """One proposed action in the candidate set shown to the policy."""

    candidate_id: str
    mode: DecisionMode
    action_type: ActionType
    target_drug: Optional[str] = None
    replacement_drug: Optional[str] = None
    dose_bucket: DoseBucket = DoseBucket.NA
    taper_days: Optional[int] = None
    monitoring_plan: Optional[str] = None
    evidence_query: Optional[str] = None
    new_drug_name: Optional[str] = None
    candidate_components: list[str] = Field(default_factory=list)
    estimated_safety_delta: float = 0.0
    burden_delta: float = 0.0
    disease_stability_estimate: float = 0.0
    uncertainty_score: float = 0.5
    rationale_tags: list[str] = Field(default_factory=list)
    required_monitoring: list[str] = Field(default_factory=list)
    legality_precheck: bool = True
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class PolyGuardAction(StrictBase):
    """The action actually selected by the policy, plus confidence/rationale."""

    mode: DecisionMode
    action_type: ActionType
    target_drug: Optional[str] = None
    replacement_drug: Optional[str] = None
    dose_bucket: DoseBucket = DoseBucket.NA
    taper_days: Optional[int] = None
    monitoring_plan: Optional[str] = None
    evidence_query: Optional[str] = None
    new_drug_name: Optional[str] = None
    candidate_components: list[str] = Field(default_factory=list)
    candidate_id: str
    confidence: float
    rationale_brief: str

    @field_validator("confidence")
    @classmethod
    def _valid_confidence(cls, value: float) -> float:
        # Confidence is clamped/quantized into the reward band [0.001, 0.999].
        return clamp_reward(value)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class RewardBreakdown(StrictBase):
    """Per-component reward scores, four primary axes, and the total."""

    format_compliance_score: float
    candidate_alignment_score: float
    legality_score: float
    safety_delta_score: float
    burden_improvement_score: float
    disease_stability_score: float
    dosing_quality_score: float
    abstention_quality_score: float
    efficiency_score: float
    process_fidelity_score: float
    explanation_grounding_score: float
    anti_cheat_score: float
    uncertainty_calibration_score: float
    primary_safety_legality: float = 0.5
    primary_clinical_improvement: float = 0.5
    primary_dosing_quality: float = 0.5
    primary_process_integrity: float = 0.5
    total_reward: float
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class SafetyReport(StrictBase):
    """Outcome of the safety/legality check over a proposed action."""

    legal: bool
    violations: list[str] = Field(default_factory=list)
    severity: str = "none"
    recommended_fallback: Optional[ActionType] = None
    uncertainty_notes: list[str] = Field(default_factory=list)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
class UncertaintyReport(StrictBase):
    """Aggregate uncertainty signals, including an abstention recommendation."""

    overall_uncertainty: float = 0.5
    missing_data_flags: list[str] = Field(default_factory=list)
    abstention_recommended: bool = False
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
class PolyGuardState(StrictBase):
    """Full mutable episode state tracked by the environment."""

    episode_id: str
    seed: int
    scenario_id: Optional[str] = None
    difficulty: Difficulty
    sub_environment: SubEnvironment = SubEnvironment.REGIMEN_RISK
    step_count: int
    max_steps: int
    patient: PatientProfile
    active_mode: DecisionMode = DecisionMode.REGIMEN_OPT
    cumulative_reward: float = 0.0
    unresolved_conflicts: list[str] = Field(default_factory=list)
    risk_summary: dict[str, float] = Field(default_factory=dict)
    burden_score: float = 0.5
    precision_dosing_flags: list[str] = Field(default_factory=list)
    action_history: list[dict[str, Any]] = Field(default_factory=list)
    done: bool = False
    # NOTE(review): datetime.utcnow is deprecated in Python 3.12+; consider
    # datetime.now(timezone.utc) once a tz-aware timestamp is acceptable.
    created_at: datetime = Field(default_factory=datetime.utcnow)
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
class PolyGuardObservation(StrictBase):
    """What the policy sees each step: summaries, candidates, and budget."""

    patient_summary: dict[str, Any]
    medication_table: list[dict[str, Any]]
    comorbidity_summary: list[str]
    organ_function_summary: dict[str, Any]
    labs_vitals_summary: dict[str, Any]
    graph_safety_summary: dict[str, Any]
    burden_score_summary: dict[str, Any]
    precision_dosing_flags: list[str]
    unresolved_conflicts: list[str]
    candidate_action_set: list[CandidateAction]
    step_budget_remaining: int
    action_history: list[dict[str, Any]]
    warning_summary: list[str]
    abstention_indicators: dict[str, Any]
    sub_environment: SubEnvironment
    deterministic_contract: dict[str, Any] = Field(default_factory=dict)
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
class StepTrace(StrictBase):
    """Per-step audit record: observation, action, rewards, and failures."""

    step: int
    observation_snapshot: PolyGuardObservation
    selected_action: Optional[PolyGuardAction] = None
    critic_output: dict[str, Any] = Field(default_factory=dict)
    reward_components: dict[str, float] = Field(default_factory=dict)
    transition_delta: dict[str, Any] = Field(default_factory=dict)
    uncertainty_report: UncertaintyReport = Field(default_factory=UncertaintyReport)
    failure_reasons: list[str] = Field(default_factory=list)
    timeout: bool = False
|
app/dataops/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Data operations package."""
|
| 2 |
+
|
| 3 |
+
from app.dataops.source_manager import SourceManager
|
| 4 |
+
|
| 5 |
+
__all__ = ["SourceManager"]
|
app/dataops/ddi_api.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""DDI API ingestion helpers with offline-first caching."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Any
|
| 8 |
+
|
| 9 |
+
import requests
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
DEFAULT_DDI_API_URL = "https://api.fda.gov/drug/label.json"


def fetch_ddi_api_records(
    drugs: list[str],
    timeout: int = 20,
    api_url: str = DEFAULT_DDI_API_URL,
) -> list[dict[str, Any]]:
    """Query the openFDA label API once per drug, never raising.

    Each result records either the raw payload ("status": "ok") or the
    stringified error ("status": "error"), so callers can fall back offline.
    """
    records: list[dict[str, Any]] = []
    for drug in drugs:
        record: dict[str, Any] = {"drug": drug, "source": api_url}
        try:
            response = requests.get(
                api_url,
                params={"search": f"openfda.generic_name:{drug}", "limit": 1},
                timeout=timeout,
            )
            response.raise_for_status()
            record["status"] = "ok"
            record["payload"] = response.json()
        except Exception as exc:  # noqa: BLE001 - deliberate best-effort fetch
            record["status"] = "error"
            record["error"] = str(exc)
        records.append(record)
    return records
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def load_cached_ddi(path: Path) -> list[dict[str, Any]]:
    """Best-effort read of cached DDI records; anything unexpected yields []."""
    if not path.exists():
        return []
    try:
        data = json.loads(path.read_text(encoding="utf-8"))
    except Exception:  # noqa: BLE001 - a corrupt cache is treated as empty
        return []
    return data if isinstance(data, list) else []
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def cache_ddi_records(path: Path, records: list[dict[str, Any]]) -> Path:
    """Persist records as pretty-printed ASCII JSON, creating parent dirs."""
    path.parent.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(records, ensure_ascii=True, indent=2)
    path.write_text(serialized, encoding="utf-8")
    return path
|
app/dataops/normalizer.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Entity normalizer."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from app.knowledge.drug_catalog import canonicalize_drug_name
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def normalize_drug_entities(items: list[str]) -> list[str]:
    """Canonicalize drug names and return a sorted, de-duplicated list."""
    canonical = {canonicalize_drug_name(item) for item in items}
    return sorted(canonical)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def normalize_component_entities(items: list[str]) -> list[str]:
    """Canonicalize non-empty component names, underscoring hyphens."""
    canonical = {
        canonicalize_drug_name(item).replace("-", "_")
        for item in items
        if item
    }
    return sorted(canonical)
|
app/dataops/package_loader.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Package/local artifact loading."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Any
|
| 8 |
+
|
| 9 |
+
import yaml
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def load_artifact(path: Path) -> Any:
    """Load a local artifact, dispatching on its file extension.

    JSON/YAML parse into Python objects, txt/md return text, and anything
    else falls back to raw bytes.
    """
    suffix = path.suffix.lower()
    if suffix == ".json":
        return json.loads(path.read_text(encoding="utf-8"))
    if suffix in {".yaml", ".yml"}:
        return yaml.safe_load(path.read_text(encoding="utf-8"))
    if suffix in {".txt", ".md"}:
        return path.read_text(encoding="utf-8")
    return path.read_bytes()
|
app/dataops/parser.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Raw text parser for knowledge ingestion."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import re
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def extract_drug_mentions(text: str) -> list[str]:
    """Pull candidate drug-like tokens: 4+ letters/hyphens/underscores."""
    matches = re.findall(r"[a-zA-Z_-]{4,}", text.lower())
    return sorted(set(matches))
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def extract_components(text: str) -> list[str]:
    """Extract ingredient/component tokens from label-style text.

    Scans lines mentioning "ingredient", "component", or "contains", takes
    the text after the first ':', '.', or ';' delimiter, then normalizes
    each comma/slash/" and "-separated item into an underscored token of
    3-40 characters. Returns sorted, de-duplicated tokens.
    """
    lines = [line.strip().lower() for line in text.splitlines() if line.strip()]
    components: list[str] = []
    for line in lines:
        if "ingredient" in line or "component" in line or "contains" in line:
            # BUG FIX: the original pattern r":|\\.|;" matched a literal
            # backslash followed by any character, so '.' never acted as a
            # delimiter. A character class covers all three delimiters.
            parts = re.split(r"[:.;]", line, maxsplit=1)
            if len(parts) > 1:
                rhs = parts[1]
                for item in re.split(r",|/| and ", rhs):
                    token = re.sub(r"[^a-z0-9_ -]", "", item).strip().replace(" ", "_")
                    if 3 <= len(token) <= 40:
                        components.append(token)
    return sorted(set(components))
|
app/dataops/provenance.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Provenance tracking."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@dataclass(slots=True)
|
| 10 |
+
class ProvenanceRecord:
|
| 11 |
+
source: str
|
| 12 |
+
source_type: str
|
| 13 |
+
fetched_at: str
|
| 14 |
+
transform: str
|
| 15 |
+
|
| 16 |
+
def to_dict(self) -> dict[str, str]:
|
| 17 |
+
return {
|
| 18 |
+
"source": self.source,
|
| 19 |
+
"source_type": self.source_type,
|
| 20 |
+
"fetched_at": self.fetched_at,
|
| 21 |
+
"transform": self.transform,
|
| 22 |
+
}
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def make_provenance(source: str, source_type: str, transform: str) -> ProvenanceRecord:
    """Build a ProvenanceRecord stamped with the current UTC time."""
    # NOTE(review): datetime.utcnow() is deprecated in Python 3.12+; switch to
    # datetime.now(timezone.utc) once the extra "+00:00" suffix is acceptable.
    timestamp = datetime.utcnow().isoformat()
    return ProvenanceRecord(
        source=source,
        source_type=source_type,
        fetched_at=timestamp,
        transform=transform,
    )
|
app/dataops/scraper.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Controlled scraper facade."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from app.dataops.web_agent import fetch_url
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def scrape_allowed_page(url: str, allow_domains: list[str]) -> str:
    """Fetch *url* through the allow-listed web agent and return its text."""
    return fetch_url(url, allowed_domains=allow_domains)
|
app/dataops/source_manager.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Source management for offline-first ingestion."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import hashlib
|
| 6 |
+
import json
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import Any
|
| 9 |
+
|
| 10 |
+
from app.dataops.web_agent import fetch_url
|
| 11 |
+
from app.dataops.parser import extract_components, extract_drug_mentions
|
| 12 |
+
from app.dataops.normalizer import normalize_component_entities, normalize_drug_entities
|
| 13 |
+
from app.dataops.provenance import make_provenance
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class SourceManager:
|
| 17 |
+
def __init__(self, root: Path) -> None:
|
| 18 |
+
self.root = root
|
| 19 |
+
self.raw = root / "data" / "raw"
|
| 20 |
+
self.cache = root / "data" / "cache"
|
| 21 |
+
self.cache.mkdir(parents=True, exist_ok=True)
|
| 22 |
+
|
| 23 |
+
def local_sources(self) -> list[Path]:
|
| 24 |
+
return [p for p in self.raw.rglob("*") if p.is_file()]
|
| 25 |
+
|
| 26 |
+
@staticmethod
|
| 27 |
+
def checksum_text(text: str) -> str:
|
| 28 |
+
return hashlib.sha256(text.encode("utf-8")).hexdigest()
|
| 29 |
+
|
| 30 |
+
def cache_text(self, namespace: str, key: str, text: str) -> Path:
|
| 31 |
+
ns_dir = self.cache / namespace
|
| 32 |
+
ns_dir.mkdir(parents=True, exist_ok=True)
|
| 33 |
+
checksum = self.checksum_text(text)
|
| 34 |
+
target = ns_dir / f"{key}_{checksum[:12]}.txt"
|
| 35 |
+
target.write_text(text, encoding="utf-8")
|
| 36 |
+
meta = {
|
| 37 |
+
"key": key,
|
| 38 |
+
"checksum": checksum,
|
| 39 |
+
"path": str(target),
|
| 40 |
+
}
|
| 41 |
+
(ns_dir / f"{key}.meta.json").write_text(json.dumps(meta, ensure_ascii=True, indent=2), encoding="utf-8")
|
| 42 |
+
return target
|
| 43 |
+
|
| 44 |
+
def read_cached(self, namespace: str, key: str) -> str | None:
|
| 45 |
+
meta_path = self.cache / namespace / f"{key}.meta.json"
|
| 46 |
+
if not meta_path.exists():
|
| 47 |
+
return None
|
| 48 |
+
meta = json.loads(meta_path.read_text(encoding="utf-8"))
|
| 49 |
+
target = Path(meta["path"])
|
| 50 |
+
if target.exists():
|
| 51 |
+
return target.read_text(encoding="utf-8")
|
| 52 |
+
return None
|
| 53 |
+
|
| 54 |
+
def fetch_with_cache(
|
| 55 |
+
self,
|
| 56 |
+
url: str,
|
| 57 |
+
allow_domains: list[str],
|
| 58 |
+
namespace: str = "web",
|
| 59 |
+
offline_first: bool = True,
|
| 60 |
+
) -> dict[str, Any]:
|
| 61 |
+
key = url.replace("https://", "").replace("http://", "").replace("/", "_")
|
| 62 |
+
if offline_first:
|
| 63 |
+
cached = self.read_cached(namespace=namespace, key=key)
|
| 64 |
+
if cached is not None:
|
| 65 |
+
provenance = make_provenance(source=url, source_type="cache", transform="read_cached")
|
| 66 |
+
return {"text": cached, "provenance": provenance.__dict__, "from_cache": True}
|
| 67 |
+
text = fetch_url(url, allowed_domains=allow_domains)
|
| 68 |
+
self.cache_text(namespace=namespace, key=key, text=text)
|
| 69 |
+
provenance = make_provenance(source=url, source_type="web", transform="fetch_with_cache")
|
| 70 |
+
return {"text": text, "provenance": provenance.__dict__, "from_cache": False}
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class DataAcquisitionAgent:
    """Acquires drug knowledge from local raw files and allow-listed pages."""

    def __init__(self, root: Path, allow_domains: list[str]) -> None:
        self.manager = SourceManager(root=root)
        self.allow_domains = allow_domains

    def acquire_local_knowledge(self) -> list[dict[str, Any]]:
        """Parse every local raw file into a mention/component record."""
        records: list[dict[str, Any]] = []
        for source in self.manager.local_sources():
            text = source.read_text(encoding="utf-8", errors="ignore")
            records.append(
                {
                    "source": str(source),
                    "mentions": normalize_drug_entities(extract_drug_mentions(text)),
                    "components": normalize_component_entities(extract_components(text)),
                    "provenance": make_provenance(
                        source=str(source), source_type="local_file", transform="parse_local"
                    ).to_dict(),
                }
            )
        return records

    def acquire_web_knowledge(self, url: str, offline_first: bool = True) -> dict[str, Any]:
        """Fetch one URL (cache-first) and extract mention/component entities."""
        blob = self.manager.fetch_with_cache(
            url=url,
            allow_domains=self.allow_domains,
            namespace="drug_labels",
            offline_first=offline_first,
        )
        text = blob["text"]
        return {
            "url": url,
            "mentions": normalize_drug_entities(extract_drug_mentions(text)),
            "components": normalize_component_entities(extract_components(text)),
            "provenance": blob["provenance"],
            "from_cache": blob["from_cache"],
        }
|
app/dataops/synthetic_mix.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Synthetic and mock data blending."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Any
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def merge_sources(local_items: list[dict[str, Any]], generated_items: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Concatenate local and generated records, local items first."""
    return [*local_items, *generated_items]
|
app/dataops/web_agent.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Allow-listed web retrieval."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from urllib.parse import urlparse
|
| 6 |
+
|
| 7 |
+
import requests
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def fetch_url(url: str, allowed_domains: list[str]) -> str:
    """Fetch *url* if its host is on the allow list.

    Raises:
        ValueError: when the host is not allow-listed.
        RuntimeError: when the HTTP fetch itself fails.
    """
    host = urlparse(url).netloc.lower()
    # BUG FIX: plain endswith() let "evilexample.com" match an "example.com"
    # allow list. Require an exact match or a dot-separated subdomain.
    allowed = any(
        host == domain or host.endswith("." + domain)
        for domain in allowed_domains
    )
    if not allowed:
        raise ValueError(f"Domain not allow-listed: {host}")
    try:
        response = requests.get(url, timeout=20)
        response.raise_for_status()
        return response.text
    except Exception as exc:  # noqa: BLE001
        # Explicit failure message makes offline-first behavior easier to reason about upstream.
        raise RuntimeError(f"web_fetch_failed:{host}:{exc}") from exc
|
app/dataops/web_fallback.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Optional web fallback ingestion via Scrapling and Playwright."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from urllib.parse import urlparse
|
| 6 |
+
|
| 7 |
+
import requests
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def _allowed(url: str, allow_domains: list[str]) -> bool:
|
| 11 |
+
host = urlparse(url).netloc.lower()
|
| 12 |
+
return any(host.endswith(domain) for domain in allow_domains)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def _scrape_with_scrapling(url: str) -> str:
    """Best-effort fetch of *url* via Scrapling; returns page text, or "" if absent."""
    # Scrapling API compatibility may vary by version, so this path is best-effort.
    from scrapling import Fetcher  # type: ignore

    page = Fetcher().get(url)
    text = getattr(page, "text", "")
    return text if text else ""
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _scrape_with_playwright(url: str) -> str:
    """Render *url* in headless Chromium and return the full page HTML.

    The browser is closed in a ``finally`` block: previously a navigation
    failure (e.g. the 30s ``goto`` timeout) skipped ``browser.close()`` and
    leaked a headless Chromium process per failed scrape.
    """
    from playwright.sync_api import sync_playwright  # type: ignore

    with sync_playwright() as p:
        browser = p.chromium.launch(headless=True)
        try:
            page = browser.new_page()
            page.goto(url, timeout=30_000)
            return page.content()
        finally:
            browser.close()
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def scrape_with_fallback(url: str, allow_domains: list[str]) -> dict[str, str]:
    """Scrape *url* using the best available backend.

    Backends are tried in order — Scrapling, then Playwright, then plain
    ``requests`` — and the first one that yields non-empty text wins.
    Hosts outside the allow-list are rejected up front without any fetch.
    """
    if not _allowed(url, allow_domains):
        return {"status": "blocked", "url": url, "backend": "allowlist"}

    # Optional backends: any failure or empty result falls through silently.
    for backend_name, backend_fetch in (
        ("scrapling", _scrape_with_scrapling),
        ("playwright", _scrape_with_playwright),
    ):
        try:
            body = backend_fetch(url)
        except Exception:  # noqa: BLE001 - backend is optional; try the next one
            continue
        if body:
            return {"status": "ok", "url": url, "backend": backend_name, "text": body}

    try:
        response = requests.get(url, timeout=20)
        response.raise_for_status()
        return {"status": "ok", "url": url, "backend": "requests", "text": response.text}
    except Exception as exc:  # noqa: BLE001
        return {"status": "error", "url": url, "backend": "none", "error": str(exc)}
|
app/env/__init__.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Environment package."""
|
| 2 |
+
|
| 3 |
+
__all__ = ["PolyGuardEnv", "EnvironmentA", "EnvironmentB", "EnvironmentC", "EnvironmentD"]
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def __getattr__(name: str):
|
| 7 |
+
if name == "PolyGuardEnv":
|
| 8 |
+
from app.env.env_core import PolyGuardEnv
|
| 9 |
+
|
| 10 |
+
return PolyGuardEnv
|
| 11 |
+
if name == "EnvironmentA":
|
| 12 |
+
from app.env.environment_a import EnvironmentA
|
| 13 |
+
|
| 14 |
+
return EnvironmentA
|
| 15 |
+
if name == "EnvironmentB":
|
| 16 |
+
from app.env.environment_b import EnvironmentB
|
| 17 |
+
|
| 18 |
+
return EnvironmentB
|
| 19 |
+
if name == "EnvironmentC":
|
| 20 |
+
from app.env.environment_c import EnvironmentC
|
| 21 |
+
|
| 22 |
+
return EnvironmentC
|
| 23 |
+
if name == "EnvironmentD":
|
| 24 |
+
from app.env.environment_d import EnvironmentD
|
| 25 |
+
|
| 26 |
+
return EnvironmentD
|
| 27 |
+
raise AttributeError(name)
|
app/env/actions.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Environment action helpers."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from app.common.types import PolyGuardAction
|
| 6 |
+
|
| 7 |
+
__all__ = ["PolyGuardAction"]
|