Deploy PolyGuard: nginx + OpenEnv + API + static UI (CPU)
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .dockerignore +21 -0
- .env.example +22 -0
- .gitattributes +8 -0
- .gitignore +29 -0
- Dockerfile +41 -0
- Dockerfile.space +41 -0
- LICENSE +21 -0
- Makefile +29 -0
- PolyGuard_SFT_GRPO_One_Run_Runner.ipynb +481 -0
- README.md +6 -4
- README_HF_SPACE.md +12 -0
- __init__.py +5 -0
- app/__init__.py +1 -0
- app/agents/__init__.py +5 -0
- app/agents/candidate_agent.py +14 -0
- app/agents/critic_agent.py +43 -0
- app/agents/critic_safety_agent.py +11 -0
- app/agents/dosing_agent.py +52 -0
- app/agents/evidence_agent.py +14 -0
- app/agents/explainer_agent.py +22 -0
- app/agents/graph_agent.py +28 -0
- app/agents/graph_safety_agent.py +11 -0
- app/agents/medrec_agent.py +22 -0
- app/agents/orchestrator.py +151 -0
- app/agents/planner_agent.py +44 -0
- app/agents/supervisor_agent.py +23 -0
- app/api/__init__.py +46 -0
- app/api/__main__.py +7 -0
- app/api/dependencies.py +11 -0
- app/api/main.py +10 -0
- app/api/routes.py +139 -0
- app/api/schemas.py +57 -0
- app/api/service.py +219 -0
- app/common/config.py +57 -0
- app/common/constants.py +40 -0
- app/common/enums.py +61 -0
- app/common/exceptions.py +19 -0
- app/common/json_utils.py +14 -0
- app/common/logging_utils.py +17 -0
- app/common/normalization.py +24 -0
- app/common/seeding.py +17 -0
- app/common/types.py +175 -0
- app/dataops/__init__.py +5 -0
- app/dataops/ddi_api.py +65 -0
- app/dataops/normalizer.py +13 -0
- app/dataops/package_loader.py +19 -0
- app/dataops/parser.py +26 -0
- app/dataops/provenance.py +31 -0
- app/dataops/scraper.py +9 -0
- app/dataops/source_manager.py +111 -0
.dockerignore
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.venv
|
| 2 |
+
venv
|
| 3 |
+
__pycache__
|
| 4 |
+
*.pyc
|
| 5 |
+
.pytest_cache
|
| 6 |
+
.mypy_cache
|
| 7 |
+
.git
|
| 8 |
+
.gitignore
|
| 9 |
+
*.md
|
| 10 |
+
!README.md
|
| 11 |
+
node_modules
|
| 12 |
+
app/ui/frontend/node_modules
|
| 13 |
+
app/ui/frontend/dist
|
| 14 |
+
checkpoints/active
|
| 15 |
+
checkpoints/.hf_bundles
|
| 16 |
+
outputs
|
| 17 |
+
.env
|
| 18 |
+
*.log
|
| 19 |
+
submission_bundle
|
| 20 |
+
notebooks
|
| 21 |
+
.pytest_cache
|
.env.example
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
POLYGUARD_DATA_DIR=./data
|
| 2 |
+
POLYGUARD_LOG_LEVEL=INFO
|
| 3 |
+
POLYGUARD_SEED=42
|
| 4 |
+
POLYGUARD_ENV_HOST=127.0.0.1
|
| 5 |
+
POLYGUARD_ENV_PORT=8100
|
| 6 |
+
POLYGUARD_API_HOST=127.0.0.1
|
| 7 |
+
POLYGUARD_API_PORT=8200
|
| 8 |
+
POLYGUARD_UI_PORT=5173
|
| 9 |
+
POLYGUARD_ENABLE_OLLAMA=false
|
| 10 |
+
POLYGUARD_OLLAMA_MODEL=qwen2.5:3b-instruct
|
| 11 |
+
# Optional explicit order (comma-separated): transformers,ollama
|
| 12 |
+
# POLYGUARD_PROVIDER_PREFERENCE=transformers,ollama
|
| 13 |
+
POLYGUARD_PROVIDER_TIMEOUT_SECONDS=25
|
| 14 |
+
# Trained checkpoint (GRPO adapter + merged + SFT) from HF: run
|
| 15 |
+
# python scripts/install_hf_active_bundle.py
|
| 16 |
+
# Then enable loading from checkpoints/active/active_model_manifest.json.
|
| 17 |
+
POLYGUARD_ENABLE_ACTIVE_MODEL=true
|
| 18 |
+
POLYGUARD_HF_MODEL=Qwen/Qwen2.5-0.5B-Instruct
|
| 19 |
+
POLYGUARD_FRONTIER_MODEL=Qwen/Qwen2.5-7B-Instruct
|
| 20 |
+
POLYGUARD_ALLOW_WEB_FETCH=false
|
| 21 |
+
POLYGUARD_REWARD_MIN=0.001
|
| 22 |
+
POLYGUARD_REWARD_MAX=0.999
|
.gitattributes
CHANGED
|
@@ -33,3 +33,11 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
app/ui/frontend/dist/blackhole.webm filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
app/ui/frontend/public/blackhole.webm filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
docs/results/model_improvement_evidence_qwen_0_5b_1_5b/charts/reward_function/reward_component_bars.png filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
docs/results/qwen_completed_runs/charts/generated/reward_component_bars.png filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
docs/results/submission_evidence/qwen_0_5b_1_5b/reward_component_bars.png filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
docs/results/submission_evidence/qwen_0_5b_1_5b_3b/reward_component_bars.png filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
docs/results/submission_evidence_qwen_0_5b_1_5b/charts/generated/reward_component_bars.png filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
docs/results/submission_evidence_qwen_0_5b_1_5b_3b/charts/generated/reward_component_bars.png filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.DS_Store
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.pyc
|
| 4 |
+
*.pyo
|
| 5 |
+
*.pyd
|
| 6 |
+
.pytest_cache/
|
| 7 |
+
.mypy_cache/
|
| 8 |
+
.ruff_cache/
|
| 9 |
+
.venv/
|
| 10 |
+
.env
|
| 11 |
+
node_modules/
|
| 12 |
+
dist/
|
| 13 |
+
build/
|
| 14 |
+
*.log
|
| 15 |
+
# Weight bundles and run outputs are local-only; tracked READMEs explain layout.
|
| 16 |
+
checkpoints/*
|
| 17 |
+
!checkpoints/README.md
|
| 18 |
+
outputs/*
|
| 19 |
+
!outputs/README.md
|
| 20 |
+
artifacts/
|
| 21 |
+
submission_bundle/model_artifacts/
|
| 22 |
+
submission_bundle/*.zip
|
| 23 |
+
data/cache/*
|
| 24 |
+
data/processed/*
|
| 25 |
+
data/synthetic/*
|
| 26 |
+
data/retrieval_index/*
|
| 27 |
+
!data/**/.gitkeep
|
| 28 |
+
app/ui/frontend/.vite/
|
| 29 |
+
/demo.md
|
Dockerfile
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Hugging Face Space: single-port edge (nginx) + OpenEnv (8100) + API (8200) + static UI.
|
| 2 |
+
# Build from repository root: docker build -f Dockerfile.space -t polyguard-space .
|
| 3 |
+
# Cheap tier: use Space "CPU basic"; first boot downloads ~1.1GB model bundle.
|
| 4 |
+
|
| 5 |
+
FROM node:20-bookworm-slim AS frontend
|
| 6 |
+
WORKDIR /build
|
| 7 |
+
COPY app/ui/frontend/package.json app/ui/frontend/package-lock.json ./
|
| 8 |
+
RUN npm ci
|
| 9 |
+
COPY app/ui/frontend/ ./
|
| 10 |
+
ENV VITE_API_BASE=/api
|
| 11 |
+
RUN npm run build
|
| 12 |
+
|
| 13 |
+
FROM python:3.11-slim-bookworm
|
| 14 |
+
WORKDIR /app
|
| 15 |
+
ENV DEBIAN_FRONTEND=noninteractive
|
| 16 |
+
RUN apt-get update && apt-get install -y --no-install-recommends nginx \
|
| 17 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 18 |
+
|
| 19 |
+
COPY requirements-space.txt /app/requirements-space.txt
|
| 20 |
+
RUN pip install --no-cache-dir --upgrade pip \
|
| 21 |
+
&& pip install --no-cache-dir torch --index-url https://download.pytorch.org/whl/cpu \
|
| 22 |
+
&& pip install --no-cache-dir -r /app/requirements-space.txt
|
| 23 |
+
|
| 24 |
+
COPY . /app
|
| 25 |
+
COPY --from=frontend /build/dist /app/static
|
| 26 |
+
|
| 27 |
+
RUN chmod +x /app/docker/space/entrypoint.sh \
|
| 28 |
+
&& mkdir -p /app/data /app/checkpoints/active
|
| 29 |
+
|
| 30 |
+
ENV PORT=7860
|
| 31 |
+
ENV POLYGUARD_ALLOW_HF_SPACE_CORS=true
|
| 32 |
+
ENV POLYGUARD_ENABLE_OLLAMA=false
|
| 33 |
+
ENV POLYGUARD_ENABLE_ACTIVE_MODEL=true
|
| 34 |
+
ENV POLYGUARD_HF_MODEL=Qwen/Qwen2.5-0.5B-Instruct
|
| 35 |
+
ENV POLYGUARD_PROVIDER_PREFERENCE=transformers
|
| 36 |
+
ENV POLYGUARD_ALLOW_WEB_FETCH=false
|
| 37 |
+
ENV POLYGUARD_DATA_DIR=/app/data
|
| 38 |
+
ENV PYTHONUNBUFFERED=1
|
| 39 |
+
|
| 40 |
+
EXPOSE 7860
|
| 41 |
+
CMD ["/app/docker/space/entrypoint.sh"]
|
Dockerfile.space
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Hugging Face Space: single-port edge (nginx) + OpenEnv (8100) + API (8200) + static UI.
|
| 2 |
+
# Build from repository root: docker build -f Dockerfile.space -t polyguard-space .
|
| 3 |
+
# Cheap tier: use Space "CPU basic"; first boot downloads ~1.1GB model bundle.
|
| 4 |
+
|
| 5 |
+
FROM node:20-bookworm-slim AS frontend
|
| 6 |
+
WORKDIR /build
|
| 7 |
+
COPY app/ui/frontend/package.json app/ui/frontend/package-lock.json ./
|
| 8 |
+
RUN npm ci
|
| 9 |
+
COPY app/ui/frontend/ ./
|
| 10 |
+
ENV VITE_API_BASE=/api
|
| 11 |
+
RUN npm run build
|
| 12 |
+
|
| 13 |
+
FROM python:3.11-slim-bookworm
|
| 14 |
+
WORKDIR /app
|
| 15 |
+
ENV DEBIAN_FRONTEND=noninteractive
|
| 16 |
+
RUN apt-get update && apt-get install -y --no-install-recommends nginx \
|
| 17 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 18 |
+
|
| 19 |
+
COPY requirements-space.txt /app/requirements-space.txt
|
| 20 |
+
RUN pip install --no-cache-dir --upgrade pip \
|
| 21 |
+
&& pip install --no-cache-dir torch --index-url https://download.pytorch.org/whl/cpu \
|
| 22 |
+
&& pip install --no-cache-dir -r /app/requirements-space.txt
|
| 23 |
+
|
| 24 |
+
COPY . /app
|
| 25 |
+
COPY --from=frontend /build/dist /app/static
|
| 26 |
+
|
| 27 |
+
RUN chmod +x /app/docker/space/entrypoint.sh \
|
| 28 |
+
&& mkdir -p /app/data /app/checkpoints/active
|
| 29 |
+
|
| 30 |
+
ENV PORT=7860
|
| 31 |
+
ENV POLYGUARD_ALLOW_HF_SPACE_CORS=true
|
| 32 |
+
ENV POLYGUARD_ENABLE_OLLAMA=false
|
| 33 |
+
ENV POLYGUARD_ENABLE_ACTIVE_MODEL=true
|
| 34 |
+
ENV POLYGUARD_HF_MODEL=Qwen/Qwen2.5-0.5B-Instruct
|
| 35 |
+
ENV POLYGUARD_PROVIDER_PREFERENCE=transformers
|
| 36 |
+
ENV POLYGUARD_ALLOW_WEB_FETCH=false
|
| 37 |
+
ENV POLYGUARD_DATA_DIR=/app/data
|
| 38 |
+
ENV PYTHONUNBUFFERED=1
|
| 39 |
+
|
| 40 |
+
EXPOSE 7860
|
| 41 |
+
CMD ["/app/docker/space/entrypoint.sh"]
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2026
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
Makefile
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.PHONY: install test lint env api ui smoke run-all
|
| 2 |
+
|
| 3 |
+
VENV_DIR := .venv
|
| 4 |
+
PYTHON := $(VENV_DIR)/bin/python
|
| 5 |
+
PIP := $(VENV_DIR)/bin/pip
|
| 6 |
+
|
| 7 |
+
$(PYTHON):
|
| 8 |
+
python3 -m venv $(VENV_DIR)
|
| 9 |
+
|
| 10 |
+
install: $(PYTHON)
|
| 11 |
+
bash scripts/bootstrap_venv.sh
|
| 12 |
+
|
| 13 |
+
test: $(PYTHON)
|
| 14 |
+
PYTHONPATH=. $(PYTHON) -m pytest
|
| 15 |
+
|
| 16 |
+
env: $(PYTHON)
|
| 17 |
+
PYTHONPATH=. $(PYTHON) -m app.env.fastapi_app
|
| 18 |
+
|
| 19 |
+
api: $(PYTHON)
|
| 20 |
+
PYTHONPATH=. $(PYTHON) -m app.api
|
| 21 |
+
|
| 22 |
+
ui:
|
| 23 |
+
cd app/ui/frontend && npm install && npm run dev
|
| 24 |
+
|
| 25 |
+
smoke:
|
| 26 |
+
bash scripts/smoke_test_all.sh
|
| 27 |
+
|
| 28 |
+
run-all: $(PYTHON)
|
| 29 |
+
bash scripts/run_all_local.sh --full
|
PolyGuard_SFT_GRPO_One_Run_Runner.ipynb
ADDED
|
@@ -0,0 +1,481 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {},
|
| 6 |
+
"source": [
|
| 7 |
+
"# PolyGuard SFT + GRPO One-Run Runner\n",
|
| 8 |
+
"\n",
|
| 9 |
+
"`POLYGUARD_ONE_RUN_RUNNER`\n",
|
| 10 |
+
"\n",
|
| 11 |
+
"Run this notebook from top to bottom to execute the PolyGuard pipeline from data build through SFT baseline training, GRPO environment-reward training, artifact pull, inference validation, report/chart generation, and Hugging Face Space deployment.\n",
|
| 12 |
+
"\n",
|
| 13 |
+
"Default behavior uses Hugging Face Spaces for GPU training, not local Ollama or local GPU training. Keep `HF_TOKEN` in an environment variable or notebook secret; do not paste it into a cell output or commit it.\n",
|
| 14 |
+
"\n",
|
| 15 |
+
"Reward values are expected to remain numeric, rounded to 3 decimals, and clamped to `[0.001, 0.999]` throughout the API, reports, and charts."
|
| 16 |
+
]
|
| 17 |
+
},
|
| 18 |
+
{
|
| 19 |
+
"cell_type": "markdown",
|
| 20 |
+
"metadata": {},
|
| 21 |
+
"source": [
|
| 22 |
+
"## 0) Configuration Notes\n",
|
| 23 |
+
"\n",
|
| 24 |
+
"The notebook is intentionally root-level in `polyguard-rl/`. If opened from Colab without the rest of the repo, the first cell clones the GitHub repo and changes into `polyguard-rl/`.\n",
|
| 25 |
+
"\n",
|
| 26 |
+
"Useful overrides:\n",
|
| 27 |
+
"\n",
|
| 28 |
+
"- `HF_TOKEN`: write token for Spaces, model artifact repos, and private artifact pulls.\n",
|
| 29 |
+
"- `HF_USERNAME`: target Hub namespace. If omitted, the authenticated username is used.\n",
|
| 30 |
+
"- `POLYGUARD_MODEL_SWEEP`: comma-separated models, default Qwen 0.5B, 1.5B, and 3B instruct.\n",
|
| 31 |
+
"- `POLYGUARD_SFT_EPOCHS`, `POLYGUARD_GRPO_EPOCHS`: training epochs.\n",
|
| 32 |
+
"- `POLYGUARD_SFT_MAX_STEPS=0`, `POLYGUARD_GRPO_MAX_STEPS=0`, `POLYGUARD_GRPO_MAX_PROMPTS=0`: full-corpus/full-epoch mode.\n",
|
| 33 |
+
"- `POLYGUARD_WAIT_FOR_REMOTE_TRAINING=1`: keep polling until artifacts are pulled or timeout hits.\n",
|
| 34 |
+
"- `POLYGUARD_RUN_LOCAL_SMOKE=1`: also run a tiny local SFT/GRPO smoke loop."
|
| 35 |
+
]
|
| 36 |
+
},
|
| 37 |
+
{
|
| 38 |
+
"cell_type": "code",
|
| 39 |
+
"execution_count": null,
|
| 40 |
+
"metadata": {},
|
| 41 |
+
"outputs": [],
|
| 42 |
+
"source": [
|
| 43 |
+
"from __future__ import annotations\n",
|
| 44 |
+
"\n",
|
| 45 |
+
"import json\n",
|
| 46 |
+
"import os\n",
|
| 47 |
+
"from pathlib import Path\n",
|
| 48 |
+
"import subprocess\n",
|
| 49 |
+
"import sys\n",
|
| 50 |
+
"import time\n",
|
| 51 |
+
"\n",
|
| 52 |
+
"PROJECT_SUBDIR = \"polyguard-rl\"\n",
|
| 53 |
+
"DEFAULT_REPO_URL = \"https://github.com/Vishwa-docs/Meta_Pytorch_OpenEnv_Scaler_VK.git\"\n",
|
| 54 |
+
"REPO_URL = os.getenv(\"POLYGUARD_GITHUB_REPO_URL\", DEFAULT_REPO_URL)\n",
|
| 55 |
+
"\n",
|
| 56 |
+
"cwd = Path.cwd().resolve()\n",
|
| 57 |
+
"if (cwd / \"pyproject.toml\").exists() and (cwd / \"scripts\").exists():\n",
|
| 58 |
+
" ROOT = cwd\n",
|
| 59 |
+
"elif (cwd / PROJECT_SUBDIR / \"pyproject.toml\").exists():\n",
|
| 60 |
+
" ROOT = cwd / PROJECT_SUBDIR\n",
|
| 61 |
+
"else:\n",
|
| 62 |
+
" clone_root = Path(os.getenv(\"POLYGUARD_REPO_DIR\", \"/content/Meta_Pytorch_OpenEnv_Scaler_VK\")).resolve()\n",
|
| 63 |
+
" if not clone_root.exists():\n",
|
| 64 |
+
" subprocess.run([\"git\", \"clone\", REPO_URL, str(clone_root)], check=True)\n",
|
| 65 |
+
" ROOT = clone_root / PROJECT_SUBDIR\n",
|
| 66 |
+
"\n",
|
| 67 |
+
"os.chdir(ROOT)\n",
|
| 68 |
+
"print(f\"PolyGuard root: {ROOT}\")\n",
|
| 69 |
+
"\n",
|
| 70 |
+
"def run(cmd: list[str] | str, *, check: bool = True, env: dict[str, str] | None = None) -> subprocess.CompletedProcess[str]:\n",
|
| 71 |
+
" printable = cmd if isinstance(cmd, str) else \" \".join(cmd)\n",
|
| 72 |
+
" print(f\"\\n$ {printable}\")\n",
|
| 73 |
+
" merged_env = os.environ.copy()\n",
|
| 74 |
+
" if env:\n",
|
| 75 |
+
" merged_env.update(env)\n",
|
| 76 |
+
" completed = subprocess.run(cmd, check=False, text=True, env=merged_env)\n",
|
| 77 |
+
" if check and completed.returncode != 0:\n",
|
| 78 |
+
" raise RuntimeError(f\"command_failed:{printable}\")\n",
|
| 79 |
+
" return completed\n"
|
| 80 |
+
]
|
| 81 |
+
},
|
| 82 |
+
{
|
| 83 |
+
"cell_type": "code",
|
| 84 |
+
"execution_count": null,
|
| 85 |
+
"metadata": {},
|
| 86 |
+
"outputs": [],
|
| 87 |
+
"source": [
|
| 88 |
+
"# Install local runtime dependencies. This keeps the notebook kernel light while project commands run through uv.\n",
|
| 89 |
+
"run([sys.executable, \"-m\", \"pip\", \"install\", \"-q\", \"-U\", \"uv\", \"huggingface_hub\", \"gradio_client\"])\n",
|
| 90 |
+
"run([\"uv\", \"sync\"])\n"
|
| 91 |
+
]
|
| 92 |
+
},
|
| 93 |
+
{
|
| 94 |
+
"cell_type": "code",
|
| 95 |
+
"execution_count": null,
|
| 96 |
+
"metadata": {},
|
| 97 |
+
"outputs": [],
|
| 98 |
+
"source": [
|
| 99 |
+
"def read_colab_secret(name: str) -> str:\n",
|
| 100 |
+
" try:\n",
|
| 101 |
+
" from google.colab import userdata # type: ignore\n",
|
| 102 |
+
" except Exception:\n",
|
| 103 |
+
" return \"\"\n",
|
| 104 |
+
" try:\n",
|
| 105 |
+
" return str(userdata.get(name) or \"\")\n",
|
| 106 |
+
" except Exception:\n",
|
| 107 |
+
" return \"\"\n",
|
| 108 |
+
"\n",
|
| 109 |
+
"HF_TOKEN = os.getenv(\"HF_TOKEN\", \"\") or read_colab_secret(\"HF_TOKEN\")\n",
|
| 110 |
+
"if HF_TOKEN:\n",
|
| 111 |
+
" os.environ[\"HF_TOKEN\"] = HF_TOKEN\n",
|
| 112 |
+
"\n",
|
| 113 |
+
"if os.getenv(\"POLYGUARD_REQUIRE_HF_TOKEN\", \"1\") == \"1\" and not HF_TOKEN:\n",
|
| 114 |
+
" raise RuntimeError(\"Set HF_TOKEN as an environment variable or Colab secret before running the remote training cells.\")\n",
|
| 115 |
+
"\n",
|
| 116 |
+
"HF_USERNAME = os.getenv(\"HF_USERNAME\", \"\")\n",
|
| 117 |
+
"if HF_TOKEN and not HF_USERNAME:\n",
|
| 118 |
+
" from huggingface_hub import HfApi\n",
|
| 119 |
+
"\n",
|
| 120 |
+
" whoami = HfApi(token=HF_TOKEN).whoami(token=HF_TOKEN)\n",
|
| 121 |
+
" HF_USERNAME = str(whoami.get(\"name\") or whoami.get(\"fullname\") or \"\")\n",
|
| 122 |
+
"\n",
|
| 123 |
+
"if not HF_USERNAME:\n",
|
| 124 |
+
" HF_USERNAME = \"TheJackBright\"\n",
|
| 125 |
+
"\n",
|
| 126 |
+
"MODEL_SWEEP = os.getenv(\n",
|
| 127 |
+
" \"POLYGUARD_MODEL_SWEEP\",\n",
|
| 128 |
+
" \"Qwen/Qwen2.5-0.5B-Instruct,Qwen/Qwen2.5-1.5B-Instruct,Qwen/Qwen2.5-3B-Instruct\",\n",
|
| 129 |
+
")\n",
|
| 130 |
+
"TRAINING_SPACE_REPO_ID = os.getenv(\"POLYGUARD_TRAINING_SPACE_REPO_ID\", f\"{HF_USERNAME}/polyguard-openenv-training-full\")\n",
|
| 131 |
+
"ARTIFACT_REPO_ID = os.getenv(\"POLYGUARD_ARTIFACT_REPO_ID\", f\"{HF_USERNAME}/polyguard-openenv-training-full-artifacts\")\n",
|
| 132 |
+
"PRODUCT_SPACE_REPO_ID = os.getenv(\"POLYGUARD_PRODUCT_SPACE_REPO_ID\", f\"{HF_USERNAME}/polyguard-openenv\")\n",
|
| 133 |
+
"\n",
|
| 134 |
+
"SFT_EPOCHS = os.getenv(\"POLYGUARD_SFT_EPOCHS\", \"2\")\n",
|
| 135 |
+
"GRPO_EPOCHS = os.getenv(\"POLYGUARD_GRPO_EPOCHS\", \"1\")\n",
|
| 136 |
+
"SFT_MAX_STEPS = os.getenv(\"POLYGUARD_SFT_MAX_STEPS\", \"0\")\n",
|
| 137 |
+
"GRPO_MAX_STEPS = os.getenv(\"POLYGUARD_GRPO_MAX_STEPS\", \"0\")\n",
|
| 138 |
+
"GRPO_MAX_PROMPTS = os.getenv(\"POLYGUARD_GRPO_MAX_PROMPTS\", \"0\")\n",
|
| 139 |
+
"GRPO_NUM_GENERATIONS = os.getenv(\"POLYGUARD_GRPO_NUM_GENERATIONS\", \"2\")\n",
|
| 140 |
+
"DATA_PROFILE = os.getenv(\"POLYGUARD_DATA_PROFILE\", \"massive\")\n",
|
| 141 |
+
"\n",
|
| 142 |
+
"RUN_REMOTE_TRAINING = os.getenv(\"POLYGUARD_RUN_REMOTE_TRAINING\", \"1\") == \"1\"\n",
|
| 143 |
+
"WAIT_FOR_REMOTE_TRAINING = os.getenv(\"POLYGUARD_WAIT_FOR_REMOTE_TRAINING\", \"1\") == \"1\"\n",
|
| 144 |
+
"RUN_LOCAL_SMOKE = os.getenv(\"POLYGUARD_RUN_LOCAL_SMOKE\", \"0\") == \"1\"\n",
|
| 145 |
+
"DEPLOY_PRODUCT_SPACE = os.getenv(\"POLYGUARD_DEPLOY_PRODUCT_SPACE\", \"1\") == \"1\"\n",
|
| 146 |
+
"PRODUCT_SPACE_PRIVATE = os.getenv(\"POLYGUARD_PRODUCT_SPACE_PRIVATE\", \"0\") == \"1\"\n",
|
| 147 |
+
"REMOTE_TIMEOUT_HOURS = float(os.getenv(\"POLYGUARD_REMOTE_TIMEOUT_HOURS\", \"12\"))\n",
|
| 148 |
+
"REMOTE_POLL_SECONDS = int(os.getenv(\"POLYGUARD_REMOTE_POLL_SECONDS\", \"300\"))\n",
|
| 149 |
+
"\n",
|
| 150 |
+
"print(json.dumps({\n",
|
| 151 |
+
" \"hf_username\": HF_USERNAME,\n",
|
| 152 |
+
" \"model_sweep\": MODEL_SWEEP,\n",
|
| 153 |
+
" \"training_space_repo_id\": TRAINING_SPACE_REPO_ID,\n",
|
| 154 |
+
" \"artifact_repo_id\": ARTIFACT_REPO_ID,\n",
|
| 155 |
+
" \"product_space_repo_id\": PRODUCT_SPACE_REPO_ID,\n",
|
| 156 |
+
" \"data_profile\": DATA_PROFILE,\n",
|
| 157 |
+
" \"run_remote_training\": RUN_REMOTE_TRAINING,\n",
|
| 158 |
+
" \"wait_for_remote_training\": WAIT_FOR_REMOTE_TRAINING,\n",
|
| 159 |
+
" \"run_local_smoke\": RUN_LOCAL_SMOKE,\n",
|
| 160 |
+
" \"deploy_product_space\": DEPLOY_PRODUCT_SPACE,\n",
|
| 161 |
+
"}, indent=2))\n"
|
| 162 |
+
]
|
| 163 |
+
},
|
| 164 |
+
{
|
| 165 |
+
"cell_type": "markdown",
|
| 166 |
+
"metadata": {},
|
| 167 |
+
"source": [
|
| 168 |
+
"## 1) Build Data And Training Corpora\n",
|
| 169 |
+
"\n",
|
| 170 |
+
"This builds processed data, scenario artifacts, SFT records, and GRPO prompt episodes. The training Space repeats the full build inside its container so remote training is reproducible."
|
| 171 |
+
]
|
| 172 |
+
},
|
| 173 |
+
{
|
| 174 |
+
"cell_type": "code",
|
| 175 |
+
"execution_count": null,
|
| 176 |
+
"metadata": {},
|
| 177 |
+
"outputs": [],
|
| 178 |
+
"source": [
|
| 179 |
+
"run([\"uv\", \"run\", \"python\", \"scripts/bootstrap_data.py\"])\n",
|
| 180 |
+
"run([\n",
|
| 181 |
+
" \"uv\", \"run\", \"python\", \"scripts/build_training_corpus.py\",\n",
|
| 182 |
+
" \"--profile\", DATA_PROFILE,\n",
|
| 183 |
+
" \"--with-local\",\n",
|
| 184 |
+
" \"--with-synthetic\",\n",
|
| 185 |
+
" \"--with-hf\",\n",
|
| 186 |
+
"])\n",
|
| 187 |
+
"summary_path = Path(\"data/processed/training_corpus_summary.json\")\n",
|
| 188 |
+
"print(summary_path.read_text(encoding=\"utf-8\") if summary_path.exists() else \"training_corpus_summary_missing\")\n"
|
| 189 |
+
]
|
| 190 |
+
},
|
| 191 |
+
{
|
| 192 |
+
"cell_type": "markdown",
|
| 193 |
+
"metadata": {},
|
| 194 |
+
"source": [
|
| 195 |
+
"## 2) Local Contract Checks\n",
|
| 196 |
+
"\n",
|
| 197 |
+
"These checks verify the package, OpenEnv contract, reward bounds, and report-generation surfaces before spending GPU time."
|
| 198 |
+
]
|
| 199 |
+
},
|
| 200 |
+
{
|
| 201 |
+
"cell_type": "code",
|
| 202 |
+
"execution_count": null,
|
| 203 |
+
"metadata": {},
|
| 204 |
+
"outputs": [],
|
| 205 |
+
"source": [
|
| 206 |
+
"run([\"uv\", \"run\", \"pytest\"])\n",
|
| 207 |
+
"run([\"uv\", \"run\", \"openenv\", \"validate\", \".\"])\n"
|
| 208 |
+
]
|
| 209 |
+
},
|
| 210 |
+
{
|
| 211 |
+
"cell_type": "markdown",
|
| 212 |
+
"metadata": {},
|
| 213 |
+
"source": [
|
| 214 |
+
"## 3) Optional Local Smoke SFT And GRPO\n",
|
| 215 |
+
"\n",
|
| 216 |
+
"The final training path is the HF Space below. Set `POLYGUARD_RUN_LOCAL_SMOKE=1` only if you want a tiny local compliance run before the remote job."
|
| 217 |
+
]
|
| 218 |
+
},
|
| 219 |
+
{
|
| 220 |
+
"cell_type": "code",
|
| 221 |
+
"execution_count": null,
|
| 222 |
+
"metadata": {},
|
| 223 |
+
"outputs": [],
|
| 224 |
+
"source": [
|
| 225 |
+
"if RUN_LOCAL_SMOKE:\n",
|
| 226 |
+
" local_model = os.getenv(\"POLYGUARD_LOCAL_SMOKE_MODEL\", \"Qwen/Qwen2.5-0.5B-Instruct\")\n",
|
| 227 |
+
" run([\n",
|
| 228 |
+
" \"uv\", \"run\", \"python\", \"scripts/train_sft_trl.py\",\n",
|
| 229 |
+
" \"--model-id\", local_model,\n",
|
| 230 |
+
" \"--dataset-path\", \"data/processed/training_corpus_sft.json\",\n",
|
| 231 |
+
" \"--output-dir\", \"checkpoints/sft_adapter\",\n",
|
| 232 |
+
" \"--report-path\", \"outputs/reports/sft_trl_run.json\",\n",
|
| 233 |
+
" \"--epochs\", \"1\",\n",
|
| 234 |
+
" \"--max-steps\", \"20\",\n",
|
| 235 |
+
" \"--batch-size\", \"1\",\n",
|
| 236 |
+
" \"--use-unsloth\",\n",
|
| 237 |
+
" ])\n",
|
| 238 |
+
" run([\n",
|
| 239 |
+
" \"uv\", \"run\", \"python\", \"scripts/train_grpo_trl.py\",\n",
|
| 240 |
+
" \"--model-id\", local_model,\n",
|
| 241 |
+
" \"--prompts-path\", \"data/processed/training_corpus_grpo_prompts.jsonl\",\n",
|
| 242 |
+
" \"--output-dir\", \"checkpoints/grpo_adapter\",\n",
|
| 243 |
+
" \"--report-path\", \"outputs/reports/grpo_trl_run.json\",\n",
|
| 244 |
+
" \"--max-steps\", \"20\",\n",
|
| 245 |
+
" \"--max-prompts\", \"64\",\n",
|
| 246 |
+
" \"--num-generations\", \"2\",\n",
|
| 247 |
+
" \"--batch-size\", \"1\",\n",
|
| 248 |
+
" \"--use-unsloth\",\n",
|
| 249 |
+
" ])\n",
|
| 250 |
+
"else:\n",
|
| 251 |
+
" print(\"Local smoke skipped. Remote HF Space training remains the main path.\")\n"
|
| 252 |
+
]
|
| 253 |
+
},
|
| 254 |
+
{
|
| 255 |
+
"cell_type": "markdown",
|
| 256 |
+
"metadata": {},
|
| 257 |
+
"source": [
|
| 258 |
+
"## 4) Start SFT Baseline And GRPO Training On Hugging Face Spaces\n",
|
| 259 |
+
"\n",
|
| 260 |
+
"This deploys the private training Space and artifact repo, starts the Docker runner, builds the full corpus inside the Space, trains SFT as the baseline, trains GRPO with environment-backed rewards, runs post-save inference and ablations, then uploads reports, plots, adapters, and manifests."
|
| 261 |
+
]
|
| 262 |
+
},
|
| 263 |
+
{
|
| 264 |
+
"cell_type": "code",
|
| 265 |
+
"execution_count": null,
|
| 266 |
+
"metadata": {},
|
| 267 |
+
"outputs": [],
|
| 268 |
+
"source": [
|
| 269 |
+
"if RUN_REMOTE_TRAINING:\n",
|
| 270 |
+
" deploy_cmd = [\n",
|
| 271 |
+
" \"uv\", \"run\", \"python\", \"scripts/deploy_training_space.py\",\n",
|
| 272 |
+
" \"--repo-id\", TRAINING_SPACE_REPO_ID,\n",
|
| 273 |
+
" \"--artifact-repo-id\", ARTIFACT_REPO_ID,\n",
|
| 274 |
+
" \"--hardware\", os.getenv(\"POLYGUARD_HF_HARDWARE\", \"a10g-large\"),\n",
|
| 275 |
+
" \"--model-sweep\", MODEL_SWEEP,\n",
|
| 276 |
+
" \"--training-mode\", os.getenv(\"POLYGUARD_TRAINING_MODE\", \"full\"),\n",
|
| 277 |
+
" \"--sft-epochs\", SFT_EPOCHS,\n",
|
| 278 |
+
" \"--grpo-epochs\", GRPO_EPOCHS,\n",
|
| 279 |
+
" \"--sft-max-steps\", SFT_MAX_STEPS,\n",
|
| 280 |
+
" \"--grpo-max-steps\", GRPO_MAX_STEPS,\n",
|
| 281 |
+
" \"--grpo-max-prompts\", GRPO_MAX_PROMPTS,\n",
|
| 282 |
+
" \"--grpo-num-generations\", GRPO_NUM_GENERATIONS,\n",
|
| 283 |
+
" ]\n",
|
| 284 |
+
" if os.getenv(\"POLYGUARD_TRAINING_SPACE_PUBLIC\", \"0\") == \"1\":\n",
|
| 285 |
+
" deploy_cmd.append(\"--public\")\n",
|
| 286 |
+
" run(deploy_cmd)\n",
|
| 287 |
+
" print(f\"Training Space: https://huggingface.co/spaces/{TRAINING_SPACE_REPO_ID}\")\n",
|
| 288 |
+
" print(f\"Artifact repo: https://huggingface.co/{ARTIFACT_REPO_ID}\")\n",
|
| 289 |
+
"else:\n",
|
| 290 |
+
" print(\"Remote training deployment skipped by POLYGUARD_RUN_REMOTE_TRAINING=0\")\n"
|
| 291 |
+
]
|
| 292 |
+
},
|
| 293 |
+
{
|
| 294 |
+
"cell_type": "markdown",
|
| 295 |
+
"metadata": {},
|
| 296 |
+
"source": [
|
| 297 |
+
"## 5) Monitor Space And Pull Artifacts\n",
|
| 298 |
+
"\n",
|
| 299 |
+
"If `POLYGUARD_WAIT_FOR_REMOTE_TRAINING=1`, this cell keeps polling until `scripts/pull_training_artifacts.py` succeeds or the timeout is reached. It never prints the token."
|
| 300 |
+
]
|
| 301 |
+
},
|
| 302 |
+
{
|
| 303 |
+
"cell_type": "code",
|
| 304 |
+
"execution_count": null,
|
| 305 |
+
"metadata": {},
|
| 306 |
+
"outputs": [],
|
| 307 |
+
"source": [
|
| 308 |
+
"monitor_output = \"outputs/reports/training_space_runtime_status.json\"\n",
|
| 309 |
+
"\n",
|
| 310 |
+
"def monitor_once() -> int:\n",
|
| 311 |
+
" return run([\n",
|
| 312 |
+
" \"uv\", \"run\", \"python\", \"scripts/monitor_training_space_status.py\",\n",
|
| 313 |
+
" \"--space-id\", TRAINING_SPACE_REPO_ID,\n",
|
| 314 |
+
" \"--artifact-repo-id\", ARTIFACT_REPO_ID,\n",
|
| 315 |
+
" \"--output\", monitor_output,\n",
|
| 316 |
+
" ], check=False).returncode\n",
|
| 317 |
+
"\n",
|
| 318 |
+
"def pull_once() -> bool:\n",
|
| 319 |
+
" return run([\n",
|
| 320 |
+
" \"uv\", \"run\", \"python\", \"scripts/pull_training_artifacts.py\",\n",
|
| 321 |
+
" \"--artifact-repo-id\", ARTIFACT_REPO_ID,\n",
|
| 322 |
+
" ], check=False).returncode == 0\n",
|
| 323 |
+
"\n",
|
| 324 |
+
"pulled = False\n",
|
| 325 |
+
"if RUN_REMOTE_TRAINING and WAIT_FOR_REMOTE_TRAINING:\n",
|
| 326 |
+
" deadline = time.time() + REMOTE_TIMEOUT_HOURS * 3600\n",
|
| 327 |
+
" attempt = 0\n",
|
| 328 |
+
" while time.time() < deadline:\n",
|
| 329 |
+
" attempt += 1\n",
|
| 330 |
+
" print(f\"Remote poll {attempt}\")\n",
|
| 331 |
+
" monitor_once()\n",
|
| 332 |
+
" pulled = pull_once()\n",
|
| 333 |
+
" if pulled:\n",
|
| 334 |
+
" print(\"Remote training artifacts pulled successfully.\")\n",
|
| 335 |
+
" break\n",
|
| 336 |
+
" print(f\"Artifacts not ready yet. Sleeping {REMOTE_POLL_SECONDS} seconds.\")\n",
|
| 337 |
+
" time.sleep(REMOTE_POLL_SECONDS)\n",
|
| 338 |
+
" if not pulled:\n",
|
| 339 |
+
" raise TimeoutError(\"Remote training did not produce pullable artifacts before timeout.\")\n",
|
| 340 |
+
"else:\n",
|
| 341 |
+
" monitor_once()\n",
|
| 342 |
+
" pulled = pull_once()\n",
|
| 343 |
+
" print(f\"Single pull attempt success: {pulled}\")\n"
|
| 344 |
+
]
|
| 345 |
+
},
|
| 346 |
+
{
|
| 347 |
+
"cell_type": "markdown",
|
| 348 |
+
"metadata": {},
|
| 349 |
+
"source": [
|
| 350 |
+
"## 6) Generate Reports, Charts, And Evidence Bundles\n",
|
| 351 |
+
"\n",
|
| 352 |
+
"This creates SFT-vs-GRPO charts, Qwen model comparison charts, reward component bars, anti-hacking/overfit checks, basic-LLM-vs-PolyGuard evidence, action traces, and curated submission evidence folders."
|
| 353 |
+
]
|
| 354 |
+
},
|
| 355 |
+
{
|
| 356 |
+
"cell_type": "code",
|
| 357 |
+
"execution_count": null,
|
| 358 |
+
"metadata": {},
|
| 359 |
+
"outputs": [],
|
| 360 |
+
"source": [
|
| 361 |
+
"run([\"uv\", \"run\", \"python\", \"scripts/generate_hf_training_report.py\", \"--mode\", os.getenv(\"POLYGUARD_TRAINING_MODE\", \"full\")], check=False)\n",
|
| 362 |
+
"run([\"uv\", \"run\", \"python\", \"scripts/evaluate_policy_ablations.py\", \"--episodes\", os.getenv(\"POLYGUARD_ABLATION_EPISODES\", \"8\")], check=False)\n",
|
| 363 |
+
"run([\n",
|
| 364 |
+
" \"uv\", \"run\", \"python\", \"scripts/generate_submission_evidence.py\",\n",
|
| 365 |
+
" \"--models\", os.getenv(\"POLYGUARD_EVIDENCE_MODELS\", \"qwen-qwen2-5-0-5b-instruct,qwen-qwen2-5-1-5b-instruct\"),\n",
|
| 366 |
+
" \"--artifact-repo-id\", ARTIFACT_REPO_ID,\n",
|
| 367 |
+
" \"--training-space-url\", f\"https://{TRAINING_SPACE_REPO_ID.replace('/', '-').lower()}.hf.space\",\n",
|
| 368 |
+
" \"--episodes\", os.getenv(\"POLYGUARD_EVIDENCE_EPISODES\", \"8\"),\n",
|
| 369 |
+
"], check=False)\n",
|
| 370 |
+
"run([\"uv\", \"run\", \"python\", \"scripts/build_improvement_evidence_bundle.py\"], check=False)\n"
|
| 371 |
+
]
|
| 372 |
+
},
|
| 373 |
+
{
|
| 374 |
+
"cell_type": "markdown",
|
| 375 |
+
"metadata": {},
|
| 376 |
+
"source": [
|
| 377 |
+
"## 7) Activate A Model For Product Inference And Validate Post-Save Inference\n",
|
| 378 |
+
"\n",
|
| 379 |
+
"The app reads `checkpoints/active/active_model_manifest.json`. The default active run is Qwen 0.5B because it is the smallest practical implementation target; switch `POLYGUARD_ACTIVE_RUN_ID` to the 1.5B or 3B run after those artifacts are pulled."
|
| 380 |
+
]
|
| 381 |
+
},
|
| 382 |
+
{
|
| 383 |
+
"cell_type": "code",
|
| 384 |
+
"execution_count": null,
|
| 385 |
+
"metadata": {},
|
| 386 |
+
"outputs": [],
|
| 387 |
+
"source": [
|
| 388 |
+
"ACTIVE_RUN_ID = os.getenv(\"POLYGUARD_ACTIVE_RUN_ID\", \"qwen-qwen2-5-0-5b-instruct\")\n",
|
| 389 |
+
"run([\n",
|
| 390 |
+
" \"uv\", \"run\", \"python\", \"scripts/activate_sweep_model.py\",\n",
|
| 391 |
+
" \"--source\", \"sweep\",\n",
|
| 392 |
+
" \"--run-id\", ACTIVE_RUN_ID,\n",
|
| 393 |
+
" \"--preferred-artifact\", os.getenv(\"POLYGUARD_PREFERRED_ARTIFACT\", \"grpo_adapter\"),\n",
|
| 394 |
+
"], check=False)\n",
|
| 395 |
+
"run([\"uv\", \"run\", \"python\", \"scripts/test_inference_postsave.py\", \"--samples\", os.getenv(\"POLYGUARD_INFERENCE_SAMPLES\", \"3\")], check=False)\n",
|
| 396 |
+
"run([\"uv\", \"run\", \"python\", \"scripts/benchmark_inference.py\"], check=False)\n"
|
| 397 |
+
]
|
| 398 |
+
},
|
| 399 |
+
{
|
| 400 |
+
"cell_type": "markdown",
|
| 401 |
+
"metadata": {},
|
| 402 |
+
"source": [
|
| 403 |
+
"## 8) Deploy The Product OpenEnv Space\n",
|
| 404 |
+
"\n",
|
| 405 |
+
"This deploys the FastAPI/OpenEnv product Space. It is separate from the private GPU training Space."
|
| 406 |
+
]
|
| 407 |
+
},
|
| 408 |
+
{
|
| 409 |
+
"cell_type": "code",
|
| 410 |
+
"execution_count": null,
|
| 411 |
+
"metadata": {},
|
| 412 |
+
"outputs": [],
|
| 413 |
+
"source": [
|
| 414 |
+
"if DEPLOY_PRODUCT_SPACE:\n",
|
| 415 |
+
" product_cmd = [\"uv\", \"run\", \"python\", \"scripts/deploy_space_api.py\", \"--repo-id\", PRODUCT_SPACE_REPO_ID]\n",
|
| 416 |
+
" if PRODUCT_SPACE_PRIVATE:\n",
|
| 417 |
+
" product_cmd.append(\"--private\")\n",
|
| 418 |
+
" run(product_cmd)\n",
|
| 419 |
+
" runtime_url = f\"https://{PRODUCT_SPACE_REPO_ID.replace('/', '-').lower()}.hf.space\"\n",
|
| 420 |
+
" run([\"uv\", \"run\", \"openenv\", \"validate\", \"--url\", runtime_url], check=False)\n",
|
| 421 |
+
" print(f\"Product Space: https://huggingface.co/spaces/{PRODUCT_SPACE_REPO_ID}\")\n",
|
| 422 |
+
" print(f\"Runtime URL: {runtime_url}\")\n",
|
| 423 |
+
"else:\n",
|
| 424 |
+
" print(\"Product Space deploy skipped by POLYGUARD_DEPLOY_PRODUCT_SPACE=0\")\n"
|
| 425 |
+
]
|
| 426 |
+
},
|
| 427 |
+
{
|
| 428 |
+
"cell_type": "markdown",
|
| 429 |
+
"metadata": {},
|
| 430 |
+
"source": [
|
| 431 |
+
"## 9) Final Acceptance Gate And Output Summary"
|
| 432 |
+
]
|
| 433 |
+
},
|
| 434 |
+
{
|
| 435 |
+
"cell_type": "code",
|
| 436 |
+
"execution_count": null,
|
| 437 |
+
"metadata": {},
|
| 438 |
+
"outputs": [],
|
| 439 |
+
"source": [
|
| 440 |
+
"run([\"uv\", \"run\", \"python\", \"scripts/acceptance_gate.py\"], check=False)\n",
|
| 441 |
+
"\n",
|
| 442 |
+
"summary = {\n",
|
| 443 |
+
" \"training_space\": f\"https://huggingface.co/spaces/{TRAINING_SPACE_REPO_ID}\",\n",
|
| 444 |
+
" \"artifact_repo\": f\"https://huggingface.co/{ARTIFACT_REPO_ID}\",\n",
|
| 445 |
+
" \"product_space\": f\"https://huggingface.co/spaces/{PRODUCT_SPACE_REPO_ID}\",\n",
|
| 446 |
+
" \"reports\": [\n",
|
| 447 |
+
" \"outputs/reports/hf_sweep_summary.json\",\n",
|
| 448 |
+
" \"outputs/reports/anti_hacking_overfit_report.json\",\n",
|
| 449 |
+
" \"outputs/reports/postsave_inference.json\",\n",
|
| 450 |
+
" \"docs/results/submission_evidence_qwen_0_5b_1_5b/README.md\",\n",
|
| 451 |
+
" \"docs/results/model_improvement_evidence_qwen_0_5b_1_5b/README.md\",\n",
|
| 452 |
+
" ],\n",
|
| 453 |
+
" \"plots_dir\": \"outputs/plots\",\n",
|
| 454 |
+
" \"active_model_manifest\": \"checkpoints/active/active_model_manifest.json\",\n",
|
| 455 |
+
"}\n",
|
| 456 |
+
"print(json.dumps(summary, indent=2))\n"
|
| 457 |
+
]
|
| 458 |
+
}
|
| 459 |
+
],
|
| 460 |
+
"metadata": {
|
| 461 |
+
"kernelspec": {
|
| 462 |
+
"display_name": "Python 3",
|
| 463 |
+
"language": "python",
|
| 464 |
+
"name": "python3"
|
| 465 |
+
},
|
| 466 |
+
"language_info": {
|
| 467 |
+
"codemirror_mode": {
|
| 468 |
+
"name": "ipython",
|
| 469 |
+
"version": 3
|
| 470 |
+
},
|
| 471 |
+
"file_extension": ".py",
|
| 472 |
+
"mimetype": "text/x-python",
|
| 473 |
+
"name": "python",
|
| 474 |
+
"nbconvert_exporter": "python",
|
| 475 |
+
"pygments_lexer": "ipython3",
|
| 476 |
+
"version": "3.11"
|
| 477 |
+
}
|
| 478 |
+
},
|
| 479 |
+
"nbformat": 4,
|
| 480 |
+
"nbformat_minor": 5
|
| 481 |
+
}
|
README.md
CHANGED
|
@@ -1,10 +1,12 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
colorFrom: blue
|
| 5 |
-
colorTo:
|
| 6 |
sdk: docker
|
|
|
|
| 7 |
pinned: false
|
|
|
|
| 8 |
---
|
| 9 |
|
| 10 |
-
|
|
|
|
| 1 |
---
|
| 2 |
+
title: PolyGuard OpenEnv
|
| 3 |
+
emoji: 🛡️
|
| 4 |
colorFrom: blue
|
| 5 |
+
colorTo: purple
|
| 6 |
sdk: docker
|
| 7 |
+
app_port: 7860
|
| 8 |
pinned: false
|
| 9 |
+
license: mit
|
| 10 |
---
|
| 11 |
|
| 12 |
+
Full-stack **PolyGuard** workbench: OpenEnv (WebSocket), FastAPI, and React UI behind nginx on `PORT`. Uses **CPU basic**; first cold start downloads the public [usable model bundle](https://huggingface.co/TheJackBright/polyguard-openenv-training-full-artifacts/tree/main/usable_model_bundles/local-qwen-0-5b-active-smoke) (~1.1 GB). See `docker/space/README.md` for details.
|
README_HF_SPACE.md
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: PolyGuard OpenEnv
|
| 3 |
+
emoji: 🛡️
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: purple
|
| 6 |
+
sdk: docker
|
| 7 |
+
app_port: 7860
|
| 8 |
+
pinned: false
|
| 9 |
+
license: mit
|
| 10 |
+
---
|
| 11 |
+
|
| 12 |
+
Full-stack **PolyGuard** workbench: OpenEnv (WebSocket), FastAPI, and React UI behind nginx on `PORT`. Uses **CPU basic**; first cold start downloads the public [usable model bundle](https://huggingface.co/TheJackBright/polyguard-openenv-training-full-artifacts/tree/main/usable_model_bundles/local-qwen-0-5b-active-smoke) (~1.1 GB). See `docker/space/README.md` for details.
|
__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Root OpenEnv package shim for POLYGUARD-OPENENV."""

# Re-export the environment class at the repository root so OpenEnv tooling
# can resolve `PolyGuardEnv` without knowing the internal package layout.
from app.env.env_core import PolyGuardEnv

__all__ = ["PolyGuardEnv"]
|
app/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""POLYGUARD-RL application package."""
|
app/agents/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Agent package."""

# Expose the orchestrator as the package's primary public entry point.
from app.agents.orchestrator import Orchestrator

__all__ = ["Orchestrator"]
|
app/agents/candidate_agent.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Candidate generation agent."""

from __future__ import annotations

from app.common.types import PolyGuardState
from app.models.policy.candidate_builder import build_candidates


class CandidateAgent:
    """Produces the serialized candidate-action list for the current state."""

    name = "CandidateAgent"

    def run(self, state: PolyGuardState) -> dict:
        """Build candidates for *state* and serialize each to a JSON-safe dict."""
        built = build_candidates(state)
        payload = [item.model_dump(mode="json") for item in built]
        return {"candidates": payload}
|
app/agents/critic_agent.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Safety critic agent."""

from __future__ import annotations

from app.common.enums import ActionType, DecisionMode, DoseBucket
from app.common.types import PolyGuardAction, PolyGuardState
from app.env.verifier import verify_action_legality


class CriticAgent:
    """Verifies a proposed action and substitutes a safe fallback on veto."""

    name = "CriticAgent"

    def run(self, state: PolyGuardState, proposed: PolyGuardAction) -> dict:
        """Check *proposed* for legality against *state*.

        Returns a dict with the serialized verification report, the
        (possibly replaced) final action, and approval flags. Illegal
        actions are vetoed and replaced by a conservative specialist-review
        fallback rather than being executed.
        """
        report = verify_action_legality(state, proposed)
        # Serialize the report once and reuse it in both branches; the
        # previous version duplicated this call and the violations lookup.
        report_payload = report.model_dump(mode="json")
        violations = report_payload.get("violations", [])
        if report.legal:
            return {
                "approved": True,
                "report": report_payload,
                "final_action": proposed,
                "legal": True,
                "violations": violations,
            }
        # Veto path: route the case to a human specialist instead of acting.
        fallback = PolyGuardAction(
            mode=DecisionMode.REVIEW,
            action_type=ActionType.REQUEST_SPECIALIST_REVIEW,
            target_drug=None,
            replacement_drug=None,
            dose_bucket=DoseBucket.NA,
            taper_days=None,
            monitoring_plan="critic_veto",
            candidate_id="cand_veto_fallback",
            confidence=0.62,
            rationale_brief=f"Critic veto: {', '.join(report.violations)}",
        )
        return {
            "approved": False,
            "report": report_payload,
            "final_action": fallback,
            "legal": False,
            "violations": violations,
        }
|
app/agents/critic_safety_agent.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Canonical CriticSafety agent module.

Kept so the required module path resolves; the actual implementation
lives in ``critic_agent.py``.
"""

from __future__ import annotations

from app.agents import critic_agent

# Alias under the canonical public name expected by callers of this module.
CriticSafetyAgent = critic_agent.CriticAgent

__all__ = ["CriticSafetyAgent"]
|
app/agents/dosing_agent.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Dosing analysis agent."""

from __future__ import annotations

from app.common.types import PolyGuardState
from app.knowledge.drug_catalog import DRUG_CLASSES
from app.models.dosing.dose_policy_features import build_dose_features
from app.models.dosing.infer import infer_dosing_quality
from app.models.dosing.pkpd_state import PKPDState
from app.models.dosing.surrogate_pkpd import step_pkpd


class DosingAgent:
    """Scores reduce/hold/increase options for dose-sensitive medications."""

    name = "DosingAgent"

    def run(self, state: PolyGuardState) -> dict:
        """Analyze up to three dose-sensitive drugs under three dose deltas."""
        sensitive_classes = {"anticoagulant", "sedative", "glucose_lowering"}
        # Keep at most the first three medications in a sensitive class.
        dose_sensitive = [
            med.drug
            for med in state.patient.medications
            if DRUG_CLASSES.get(med.drug) in sensitive_classes
        ][:3]
        analyses: list[dict] = []
        for drug in dose_sensitive:
            feats = build_dose_features(state.patient, drug)
            # Shared subexpression: unclamped baseline effect from adherence.
            raw_effect = 0.35 + feats["adherence"] * 0.45
            base_state = PKPDState(
                effect_level=min(1.0, raw_effect),
                toxicity_level=min(1.0, 0.08 + feats["organ_stress"] * 0.4),
                underdose_risk=max(0.0, 1.0 - raw_effect),
                organ_stress=feats["organ_stress"],
                interaction_load=feats["interaction_load"],
            )

            def simulate(delta: float):
                # One surrogate PK/PD step at the given dose delta, then
                # quality inference on the resulting state.
                stepped = step_pkpd(
                    base_state,
                    dose_delta=delta,
                    organ_factor=feats["organ_stress"],
                    interaction_factor=feats["interaction_load"],
                )
                return infer_dosing_quality(stepped)

            analyses.append(
                {
                    "drug": drug,
                    "features": feats,
                    "options": {
                        "reduce": simulate(-0.2),
                        "hold": simulate(0.0),
                        "increase": simulate(0.2),
                    },
                }
            )
        return {
            "dose_sensitive_drugs": dose_sensitive,
            "dosing_active": bool(dose_sensitive),
            "recommend_mode": "DOSE_OPT" if dose_sensitive else "REGIMEN_OPT",
            "analyses": analyses,
        }
|
app/agents/evidence_agent.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Evidence retrieval agent."""

from __future__ import annotations

from app.common.types import PolyGuardState
from app.knowledge.evidence_retriever import retrieve_evidence


class EvidenceAgent:
    """Retrieves supporting evidence for the patient's clinical context."""

    name = "EvidenceAgent"

    def run(self, state: PolyGuardState) -> dict:
        """Query evidence from comorbidities plus the first two medications."""
        terms = list(state.patient.comorbidities)
        terms.extend(med.drug for med in state.patient.medications[:2])
        return {"evidence": retrieve_evidence(query=" ".join(terms), top_k=3)}
|
app/agents/explainer_agent.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Explanation agent."""

from __future__ import annotations

from app.common.types import PolyGuardAction, PolyGuardState


class ExplainerAgent:
    """Produces a human-readable summary of the chosen action."""

    name = "ExplainerAgent"

    def run(self, state: PolyGuardState, action: PolyGuardAction, critic_report: dict) -> dict:
        """Summarize the decision and echo the grounding numbers."""
        sentences = [
            f"Action {action.action_type.value} selected for mode {action.mode.value}.",
            f"Burden score={state.burden_score:.3f}, meds={len(state.patient.medications)}.",
            f"Critic legal={critic_report.get('legal', False)}.",
        ]
        return {
            "explanation": " ".join(sentences),
            "grounded_facts": {
                "burden_score": state.burden_score,
                "polypharmacy_count": len(state.patient.medications),
            },
        }
|
app/agents/graph_agent.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Graph safety agent."""

from __future__ import annotations

from app.common.types import PolyGuardState
from app.knowledge.ddi_knowledge import top_risky_pairs
from app.models.graph.infer import infer_graph_risk


class GraphSafetyAgent:
    """Flags risky drug pairs/triples and graph-level interaction risk."""

    name = "GraphSafetyAgent"

    def run(self, state: PolyGuardState) -> dict:
        """Score the medication graph and surface the riskiest combinations."""
        drugs = [med.drug for med in state.patient.medications]
        risk = infer_graph_risk(drugs)
        # Consecutive windows of three drugs, at most two windows.
        triples: list[list[str]] = []
        if len(drugs) >= 3:
            for start in range(min(2, len(drugs) - 2)):
                triples.append(drugs[start:start + 3])
        result = dict(risk)
        result["top_dangerous_pairs"] = top_risky_pairs(drugs)[:5]
        result["top_dangerous_triples"] = triples
        result["mechanism_tags"] = list(risk.get("side_effect_probs", {}))[:5]
        return result
|
app/agents/graph_safety_agent.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Canonical GraphSafety agent module.

Retained for path compatibility; the implementation lives in
``graph_agent.py``.
"""

from __future__ import annotations

from app.agents import graph_agent

# Re-export under the canonical module path.
GraphSafetyAgent = graph_agent.GraphSafetyAgent

__all__ = ["GraphSafetyAgent"]
|
app/agents/medrec_agent.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Medication reconciliation agent."""

from __future__ import annotations

from app.common.types import PolyGuardState
from app.knowledge.drug_catalog import canonicalize_drug_name


class MedRecAgent:
    """Canonicalizes medication names and reports duplicates."""

    name = "MedRecAgent"

    def run(self, state: PolyGuardState) -> dict:
        """Normalize each medication name in place and collect duplicates."""
        normalized: list[str] = []
        duplicates: set[str] = set()
        seen: set[str] = set()
        for med in state.patient.medications:
            canonical = canonicalize_drug_name(med.drug)
            # Intentional in-place normalization of the shared state.
            med.drug = canonical
            normalized.append(canonical)
            if canonical in seen:
                duplicates.add(canonical)
            seen.add(canonical)
        return {"normalized_meds": normalized, "duplicates": sorted(duplicates)}
|
app/agents/orchestrator.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Multi-agent orchestration graph."""

from __future__ import annotations

import os
from typing import Any

from app.agents.candidate_agent import CandidateAgent
from app.agents.critic_agent import CriticAgent
from app.agents.dosing_agent import DosingAgent
from app.agents.evidence_agent import EvidenceAgent
from app.agents.explainer_agent import ExplainerAgent
from app.agents.graph_agent import GraphSafetyAgent
from app.agents.medrec_agent import MedRecAgent
from app.agents.planner_agent import PlannerAgent
from app.agents.supervisor_agent import SupervisorAgent
from app.common.enums import CoordinationMode
from app.common.types import CandidateAction, PolyGuardAction
from app.env.env_core import PolyGuardEnv
from app.models.baselines.contextual_bandit_policy import ContextualBanditPolicy


class Orchestrator:
    """Runs one decision step through the full agent pipeline.

    Pipeline per step: medrec -> evidence -> graph -> dosing -> candidates
    -> supervisor (mode routing) -> bandit pre-filter -> planner (selection)
    -> critic (legality veto) -> env.step -> bandit update -> explainer.
    Statement order matters: the bandit is updated with the env reward only
    AFTER the environment step executes.
    """

    def __init__(self, env: PolyGuardEnv, coordination_mode: CoordinationMode = CoordinationMode.SEQUENTIAL) -> None:
        self.env = env
        self.coordination_mode = coordination_mode
        self.medrec = MedRecAgent()
        self.evidence = EvidenceAgent()
        self.graph = GraphSafetyAgent()
        self.dosing = DosingAgent()
        self.candidate = CandidateAgent()
        self.supervisor = SupervisorAgent()
        self.planner = PlannerAgent()
        self.critic = CriticAgent()
        self.explainer = ExplainerAgent()
        # Bandit algorithm is env-configurable; anything unrecognized falls
        # back to linucb rather than raising at construction time.
        bandit_algo = os.getenv("POLYGUARD_BANDIT_ALGO", "linucb").strip().lower()
        if bandit_algo not in {"linucb", "thompson"}:
            bandit_algo = "linucb"
        self.bandit = ContextualBanditPolicy(
            algorithm=bandit_algo,  # type: ignore[arg-type]
            alpha=float(os.getenv("POLYGUARD_BANDIT_ALPHA", "0.55")),
            epsilon=float(os.getenv("POLYGUARD_BANDIT_EPSILON", "0.1")),
            seed=int(os.getenv("POLYGUARD_BANDIT_SEED", "42")),
        )
        # Policy stack: "llm+bandit" (default), "bandit-only", or "llm-only".
        self.policy_stack = os.getenv("POLYGUARD_POLICY_STACK", "llm+bandit").strip().lower()
        self.bandit_top_k = int(os.getenv("POLYGUARD_BANDIT_TOP_K", "3"))

    def set_mode(self, coordination_mode: CoordinationMode) -> None:
        """Switch the coordination strategy for subsequent steps."""
        self.coordination_mode = coordination_mode

    def run_step(self, coordination_mode: str | None = None) -> dict[str, Any]:
        """Execute one full orchestrated decision step.

        If *coordination_mode* is given it overrides the stored mode for
        this and subsequent steps. Returns a dict with every agent's output
        plus the env transition (observation, reward, done, info).
        """
        if coordination_mode is not None:
            self.coordination_mode = CoordinationMode(coordination_mode)
        state = self.env.state
        # Context-gathering agents run unconditionally, in a fixed order.
        medrec_out = self.medrec.run(state)
        evidence_out = self.evidence.run(state)
        graph_out = self.graph.run(state)
        dosing_out = self.dosing.run(state)
        candidate_out = self.candidate.run(state)
        candidates = [CandidateAction.model_validate(item) for item in candidate_out["candidates"]]

        supervisor_out = self.supervisor.run(state, dosing_active=dosing_out["dosing_active"])
        # Restrict to candidates matching the supervisor's mode, but never
        # leave the planner with an empty pool.
        planner_candidates = [c for c in candidates if c.mode.value == supervisor_out["mode"]] or candidates
        if self.coordination_mode == CoordinationMode.SUPERVISOR_ROUTED and supervisor_out["mode"] == "REVIEW":
            planner_candidates = [c for c in candidates if c.mode.value == "REVIEW"] or planner_candidates

        candidate_by_id = {item.candidate_id: item for item in planner_candidates}
        # Bandit pre-filters the pool to its top-k proposals by candidate id.
        bandit_proposals = self.bandit.propose(planner_candidates, top_k=self.bandit_top_k)
        bandit_candidates = [candidate_by_id[item.candidate_id] for item in bandit_proposals if item.candidate_id in candidate_by_id]
        if not bandit_candidates:
            bandit_candidates = planner_candidates

        if self.policy_stack == "bandit-only":
            # No LLM planner: promote the bandit's top candidate directly.
            selected = bandit_candidates[0]
            proposed = PolyGuardAction(
                mode=selected.mode,
                action_type=selected.action_type,
                target_drug=selected.target_drug,
                replacement_drug=selected.replacement_drug,
                dose_bucket=selected.dose_bucket,
                taper_days=selected.taper_days,
                monitoring_plan=selected.monitoring_plan,
                candidate_id=selected.candidate_id,
                confidence=max(0.45, 1.0 - selected.uncertainty_score),
                rationale_brief="Bandit-only policy selected top contextual candidate.",
            )
        elif self.policy_stack == "llm-only":
            # Planner sees the full (mode-filtered) pool, no bandit filter.
            proposed = self.planner.run(candidates=planner_candidates, mode=supervisor_out["mode"])
        else:
            # Default "llm+bandit": planner chooses among bandit's top-k.
            proposed = self.planner.run(
                candidates=bandit_candidates,
                mode=supervisor_out["mode"],
                provider_prompt={
                    "coordination_mode": self.coordination_mode.value,
                    "policy_stack": self.policy_stack,
                    "candidate_count": len(bandit_candidates),
                    "sub_environment": state.sub_environment.value,
                },
            )

        critic_out = self.critic.run(state, proposed)
        final_action: PolyGuardAction = critic_out["final_action"]
        replan_triggered = False
        debate_rounds = 0

        # On a critic veto, REPLAN_ON_VETO / LIGHT_DEBATE modes get one
        # replanning pass restricted to REVIEW candidates.
        if self.coordination_mode in {CoordinationMode.REPLAN_ON_VETO, CoordinationMode.LIGHT_DEBATE} and not critic_out["approved"]:
            replan_triggered = True
            review_candidates = [c for c in candidates if c.mode.value == "REVIEW"] or candidates
            proposed = self.planner.run(candidates=review_candidates, mode="REVIEW")
            critic_out = self.critic.run(state, proposed)
            final_action = critic_out["final_action"]
            debate_rounds = 1

        if self.coordination_mode == CoordinationMode.LIGHT_DEBATE and critic_out["approved"] and proposed.action_type != final_action.action_type:
            debate_rounds = 2

        obs, reward, done, info = self.env.step(final_action)
        # Reward feedback: only update the bandit if the executed action
        # maps back to one of this step's candidates.
        selected_for_update = candidate_by_id.get(final_action.candidate_id)
        if selected_for_update is not None:
            self.bandit.update(selected_for_update, reward=reward)

        explanation_out = self.explainer.run(state, final_action, critic_out["report"])
        return {
            "medrec": medrec_out,
            "evidence": evidence_out,
            "graph": graph_out,
            "dosing": dosing_out,
            "supervisor": supervisor_out,
            "proposed_action": proposed.model_dump(mode="json"),
            "critic": critic_out["report"],
            "final_action": final_action.model_dump(mode="json"),
            "observation": obs.model_dump(mode="json"),
            "reward": reward,
            "done": done,
            "info": info,
            "explanation": explanation_out,
            "coordination_mode": self.coordination_mode.value,
            "policy_stack": self.policy_stack,
            "bandit_topk": [item.candidate_id for item in bandit_candidates],
            "bandit_scores": [
                {
                    "candidate_id": item.candidate_id,
                    "score": item.score,
                    "exploration_bonus": item.exploration_bonus,
                    "algorithm": item.algorithm,
                }
                for item in bandit_proposals
            ],
            "replan_triggered": replan_triggered,
            "debate_rounds": debate_rounds,
        }
|
app/agents/planner_agent.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Planner agent."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Any
|
| 6 |
+
|
| 7 |
+
from app.common.types import CandidateAction, PolyGuardAction
|
| 8 |
+
from app.models.policy.provider_runtime import PolicyProviderRouter, default_provider_preference
|
| 9 |
+
from app.models.policy.safety_ranker import rank_candidates
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class PlannerAgent:
|
| 13 |
+
name = "PlannerAgent"
|
| 14 |
+
|
| 15 |
+
def __init__(self) -> None:
|
| 16 |
+
self.provider_router = PolicyProviderRouter()
|
| 17 |
+
|
| 18 |
+
def run(
|
| 19 |
+
self,
|
| 20 |
+
candidates: list[CandidateAction],
|
| 21 |
+
mode: str,
|
| 22 |
+
provider_prompt: dict[str, Any] | None = None,
|
| 23 |
+
provider_preference: tuple[str, ...] | None = None,
|
| 24 |
+
) -> PolyGuardAction:
|
| 25 |
+
filtered = [c for c in candidates if c.mode.value == mode] or candidates
|
| 26 |
+
selection = self.provider_router.select_candidate(
|
| 27 |
+
candidates=filtered,
|
| 28 |
+
prompt=provider_prompt or {"mode": mode},
|
| 29 |
+
provider_preference=provider_preference or default_provider_preference(),
|
| 30 |
+
)
|
| 31 |
+
by_id = {item.candidate_id: item for item in filtered}
|
| 32 |
+
top = by_id.get(selection.candidate_id, rank_candidates(filtered)[0])
|
| 33 |
+
return PolyGuardAction(
|
| 34 |
+
mode=top.mode,
|
| 35 |
+
action_type=top.action_type,
|
| 36 |
+
target_drug=top.target_drug,
|
| 37 |
+
replacement_drug=top.replacement_drug,
|
| 38 |
+
dose_bucket=top.dose_bucket,
|
| 39 |
+
taper_days=top.taper_days,
|
| 40 |
+
monitoring_plan=top.monitoring_plan,
|
| 41 |
+
candidate_id=top.candidate_id,
|
| 42 |
+
confidence=max(0.45, 1.0 - top.uncertainty_score),
|
| 43 |
+
rationale_brief=selection.rationale,
|
| 44 |
+
)
|
app/agents/supervisor_agent.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Supervisor agent."""

from __future__ import annotations

from app.common.enums import DecisionMode
from app.common.types import PolyGuardState
from app.models.policy.uncertainty import estimate_uncertainty


class SupervisorAgent:
    """Routes each step to REVIEW, DOSE_OPT, or REGIMEN_OPT."""

    name = "SupervisorAgent"

    def run(self, state: PolyGuardState, dosing_active: bool) -> dict:
        """Choose a decision mode from uncertainty and dosing context."""
        uncertainty = estimate_uncertainty(state)
        precision_dosing = state.sub_environment.value == "PRECISION_DOSING"
        if uncertainty > 0.72:
            # High uncertainty always escalates to human review.
            mode = DecisionMode.REVIEW
        elif precision_dosing or dosing_active:
            mode = DecisionMode.DOSE_OPT
        else:
            mode = DecisionMode.REGIMEN_OPT
        return {"mode": mode.value, "uncertainty": uncertainty, "sub_environment": state.sub_environment.value}
|
app/api/__init__.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""API application entrypoint."""

from __future__ import annotations

import os

import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

from app.common.config import load_project_env
from app.api.routes import router

# Load project .env configuration BEFORE any os.getenv reads below.
load_project_env()

# Local dev origins for the UI are always allowed.
_cors_local = [
    "http://127.0.0.1:5173",
    "http://localhost:5173",
]
_extra = os.getenv("POLYGUARD_CORS_ORIGINS", "").strip()
# A literal "*" is ignored: wildcard origins are incompatible with
# allow_credentials=True under the CORS spec.
if _extra and _extra != "*":
    _cors_local = _cors_local + [o.strip() for o in _extra.split(",") if o.strip()]
_hf_space_regex = None
# Opt-in regex origin that matches any Hugging Face Space subdomain.
if os.getenv("POLYGUARD_ALLOW_HF_SPACE_CORS", "").lower() in {"1", "true", "yes", "on"}:
    _hf_space_regex = r"https://.*\.hf\.space"

app = FastAPI(title="POLYGUARD-RL API", version="0.1.0")
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_local,
    allow_origin_regex=_hf_space_regex,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.include_router(router)


def main() -> None:
    """Run the API under uvicorn with host/port from the environment."""
    host = os.getenv("POLYGUARD_API_HOST", "127.0.0.1")
    port = int(os.getenv("POLYGUARD_API_PORT", "8200"))
    uvicorn.run("app.api:app", host=host, port=port, reload=False)


if __name__ == "__main__":
    main()
|
app/api/__main__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Run API with `python -m app.api`."""
|
| 2 |
+
|
| 3 |
+
from app.api import main
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
if __name__ == "__main__":
|
| 7 |
+
main()
|
app/api/dependencies.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""API dependencies."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from app.api.service import APIService
|
| 6 |
+
|
| 7 |
+
_SERVICE = APIService()
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def get_service() -> APIService:
|
| 11 |
+
return _SERVICE
|
app/api/main.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Canonical API module path.
|
| 2 |
+
|
| 3 |
+
Keeps compatibility with required file path while reusing ``app.api`` app.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
from app.api import app, main
|
| 9 |
+
|
| 10 |
+
__all__ = ["app", "main"]
|
app/api/routes.py
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""API routes."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from fastapi import APIRouter, Depends, HTTPException
|
| 6 |
+
|
| 7 |
+
from app.api.dependencies import get_service
|
| 8 |
+
from app.api.schemas import (
|
| 9 |
+
BatchInferRequest,
|
| 10 |
+
EvidenceQueryRequest,
|
| 11 |
+
OrchestrateRequest,
|
| 12 |
+
ResetRequest,
|
| 13 |
+
StepCandidateRequest,
|
| 14 |
+
StepRequest,
|
| 15 |
+
)
|
| 16 |
+
from app.api.service import APIService
|
| 17 |
+
|
| 18 |
+
router = APIRouter()
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@router.get("/health")
|
| 22 |
+
def health() -> dict[str, str]:
|
| 23 |
+
return {"status": "ok"}
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@router.post("/env/reset")
|
| 27 |
+
def env_reset(payload: ResetRequest, service: APIService = Depends(get_service)) -> dict:
|
| 28 |
+
try:
|
| 29 |
+
return service.reset(**payload.model_dump(mode="json"))
|
| 30 |
+
except ValueError as exc:
|
| 31 |
+
raise HTTPException(status_code=422, detail=str(exc)) from exc
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@router.post("/env/step")
|
| 35 |
+
def env_step(payload: StepRequest, service: APIService = Depends(get_service)) -> dict:
|
| 36 |
+
return service.step(payload.model_dump(mode="json"))
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
@router.post("/env/step_candidate")
|
| 40 |
+
def env_step_candidate(payload: StepCandidateRequest, service: APIService = Depends(get_service)) -> dict:
|
| 41 |
+
result = service.step_candidate(
|
| 42 |
+
candidate_id=payload.candidate_id,
|
| 43 |
+
confidence=payload.confidence,
|
| 44 |
+
rationale_brief=payload.rationale_brief,
|
| 45 |
+
)
|
| 46 |
+
if result is None:
|
| 47 |
+
raise HTTPException(status_code=404, detail=f"Candidate {payload.candidate_id!r} is not legal in this state.")
|
| 48 |
+
return result
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
@router.get("/env/catalog")
|
| 52 |
+
def env_catalog(service: APIService = Depends(get_service)) -> dict:
|
| 53 |
+
return service.catalog()
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
@router.get("/env/state")
|
| 57 |
+
def env_state(service: APIService = Depends(get_service)) -> dict:
|
| 58 |
+
return service.env.get_state()
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
@router.get("/env/trace")
|
| 62 |
+
def env_trace(service: APIService = Depends(get_service)) -> list[dict]:
|
| 63 |
+
return service.env.get_trace()
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
@router.get("/env/legal_actions")
|
| 67 |
+
def env_legal_actions(service: APIService = Depends(get_service)) -> list[dict]:
|
| 68 |
+
return service.env.get_legal_actions()
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
@router.get("/env/reward_breakdown")
|
| 72 |
+
def env_reward_breakdown(service: APIService = Depends(get_service)) -> dict:
|
| 73 |
+
return service.env.get_reward_breakdown()
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
@router.get("/env/uncertainty")
|
| 77 |
+
def env_uncertainty(service: APIService = Depends(get_service)) -> dict:
|
| 78 |
+
return service.env.get_uncertainty_report().model_dump(mode="json")
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
@router.post("/agents/orchestrate")
|
| 82 |
+
def agents_orchestrate(
|
| 83 |
+
payload: OrchestrateRequest = OrchestrateRequest(),
|
| 84 |
+
service: APIService = Depends(get_service),
|
| 85 |
+
) -> dict:
|
| 86 |
+
return service.orchestrate(coordination_mode=payload.coordination_mode)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
@router.post("/policy/infer")
|
| 90 |
+
def policy_infer(service: APIService = Depends(get_service)) -> dict:
|
| 91 |
+
return service.infer_policy()
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
@router.get("/policy/model_status")
|
| 95 |
+
def policy_model_status(service: APIService = Depends(get_service)) -> dict:
|
| 96 |
+
return service.model_status()
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
@router.post("/policy/batch_infer")
|
| 100 |
+
def policy_batch_infer(
|
| 101 |
+
payload: BatchInferRequest = BatchInferRequest(),
|
| 102 |
+
service: APIService = Depends(get_service),
|
| 103 |
+
) -> list[dict]:
|
| 104 |
+
return service.batch_infer(batch_size=payload.batch_size)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
@router.post("/eval/run_baselines")
|
| 108 |
+
def eval_baselines(service: APIService = Depends(get_service)) -> dict:
|
| 109 |
+
return service.run_baselines()
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
@router.post("/eval/run_policy")
|
| 113 |
+
def eval_run_policy(service: APIService = Depends(get_service)) -> dict:
|
| 114 |
+
return service.run_policy_eval()
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
@router.post("/eval/run_dosing")
|
| 118 |
+
def eval_run_dosing(service: APIService = Depends(get_service)) -> dict:
|
| 119 |
+
return service.run_dosing_eval()
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
@router.get("/metrics/training")
|
| 123 |
+
def metrics_training(service: APIService = Depends(get_service)) -> dict:
|
| 124 |
+
return service.get_metrics()
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
@router.get("/cases/sample")
|
| 128 |
+
def cases_sample(service: APIService = Depends(get_service)) -> dict:
|
| 129 |
+
return service.sample_case()
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
@router.get("/cases/search")
|
| 133 |
+
def cases_search(q: str, service: APIService = Depends(get_service)) -> list[dict]:
|
| 134 |
+
return service.search_cases(q)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
@router.post("/evidence/query")
|
| 138 |
+
def evidence_query(payload: EvidenceQueryRequest, service: APIService = Depends(get_service)) -> list[dict]:
|
| 139 |
+
return service.evidence_query(query=payload.query, top_k=payload.top_k)
|
app/api/schemas.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""API schemas."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Any, Optional
|
| 6 |
+
|
| 7 |
+
from pydantic import BaseModel, ConfigDict, Field
|
| 8 |
+
|
| 9 |
+
from app.common.enums import ActionType, DecisionMode, Difficulty, DoseBucket, SubEnvironment
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class StrictSchema(BaseModel):
|
| 13 |
+
model_config = ConfigDict(extra="forbid")
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class ResetRequest(StrictSchema):
|
| 17 |
+
task_id: Optional[str] = None
|
| 18 |
+
seed: Optional[int] = None
|
| 19 |
+
difficulty: Optional[Difficulty] = None
|
| 20 |
+
sub_environment: Optional[SubEnvironment] = None
|
| 21 |
+
scenario_id: Optional[str] = None
|
| 22 |
+
patient_id: Optional[str] = None
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class StepRequest(StrictSchema):
|
| 26 |
+
mode: DecisionMode
|
| 27 |
+
action_type: ActionType
|
| 28 |
+
target_drug: Optional[str] = None
|
| 29 |
+
replacement_drug: Optional[str] = None
|
| 30 |
+
dose_bucket: DoseBucket
|
| 31 |
+
taper_days: Optional[int] = None
|
| 32 |
+
monitoring_plan: Optional[str] = None
|
| 33 |
+
evidence_query: Optional[str] = None
|
| 34 |
+
new_drug_name: Optional[str] = None
|
| 35 |
+
candidate_components: list[str] = Field(default_factory=list)
|
| 36 |
+
candidate_id: str
|
| 37 |
+
confidence: float
|
| 38 |
+
rationale_brief: str
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class StepCandidateRequest(StrictSchema):
|
| 42 |
+
candidate_id: str
|
| 43 |
+
confidence: float
|
| 44 |
+
rationale_brief: str
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class OrchestrateRequest(StrictSchema):
|
| 48 |
+
coordination_mode: Optional[str] = None
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class BatchInferRequest(StrictSchema):
|
| 52 |
+
batch_size: int = 4
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class EvidenceQueryRequest(StrictSchema):
|
| 56 |
+
query: str
|
| 57 |
+
top_k: int = 5
|
app/api/service.py
ADDED
|
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""API service layer."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Any
|
| 7 |
+
|
| 8 |
+
from app.agents.orchestrator import Orchestrator
|
| 9 |
+
from app.env.catalog import apply_task_preset, env_catalog
|
| 10 |
+
from app.env.env_core import PolyGuardEnv
|
| 11 |
+
from app.evaluation.benchmark_report import build_benchmark_report
|
| 12 |
+
from app.evaluation.dosing_eval import dosing_eval
|
| 13 |
+
from app.knowledge.evidence_retriever import retrieve_evidence
|
| 14 |
+
from app.models.retrieval.retriever import retrieve
|
| 15 |
+
from app.models.policy.provider_runtime import PolicyProviderRouter, default_provider_preference
|
| 16 |
+
from app.models.baselines import (
|
| 17 |
+
choose_beam_search,
|
| 18 |
+
choose_contextual_bandit,
|
| 19 |
+
choose_contextual_bandit_topk,
|
| 20 |
+
choose_greedy,
|
| 21 |
+
choose_no_change,
|
| 22 |
+
choose_rules_only,
|
| 23 |
+
)
|
| 24 |
+
from app.training import train_dosing_grpo, train_planner_grpo, train_supervisor_grpo
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class APIService:
    """Facade over the environment, orchestrator, and policy router.

    One instance is shared process-wide via ``app.api.dependencies.get_service``,
    so the env state here is the single source of truth for all routes.
    """

    def __init__(self) -> None:
        self.env = PolyGuardEnv()
        self.orchestrator = Orchestrator(self.env)
        self.policy_router = PolicyProviderRouter()
        # Latest in-process training metrics; empty until run_training (or a
        # get_metrics disk load) populates it.
        self.training_metrics: dict[str, Any] = {}
        # Repository root: two directories above app/api/.
        self.root = Path(__file__).resolve().parents[2]

    def reset(self, **kwargs: Any) -> dict[str, Any]:
        """Reset the env, expanding any task preset in *kwargs* first."""
        kwargs = apply_task_preset(dict(kwargs))
        obs = self.env.reset(**kwargs)
        return obs.model_dump(mode="json")

    def step(self, action: dict[str, Any]) -> dict[str, Any]:
        """Apply *action*; return a gym-style payload (observation/reward/done/...)."""
        obs, reward, done, info = self.env.step(action)
        reason = str(info.get("termination_reason", "")) if isinstance(info, dict) else ""
        # Gymnasium-style split: "truncated" means the episode ended by a
        # budget/timeout rather than by reaching a terminal state.
        truncated = reason in {"wall_clock_timeout", "step_timeout", "step_budget_exhausted"}
        return {
            "observation": obs.model_dump(mode="json"),
            "reward": reward,
            "done": done,
            "terminated": done,
            "truncated": truncated,
            "info": info,
        }

    def catalog(self) -> dict[str, Any]:
        """Static catalog of sub-environments / task presets."""
        return env_catalog()

    def step_candidate(self, candidate_id: str, confidence: float, rationale_brief: str) -> dict[str, Any] | None:
        """Step using a pre-enumerated legal action; None if the id is not legal now."""
        for action in self.env.get_legal_actions():
            if action.get("candidate_id") != candidate_id:
                continue
            payload = dict(action)
            payload["confidence"] = confidence
            payload["rationale_brief"] = rationale_brief
            return self.step(payload)
        return None

    def orchestrate(self, coordination_mode: str | None = None) -> dict[str, Any]:
        """Run one multi-agent orchestration step against the live env."""
        return self.orchestrator.run_step(coordination_mode=coordination_mode)

    def infer_policy(self) -> dict[str, Any]:
        """Select one legal action via the policy provider router.

        Falls back to the first legal action when no candidate passes the
        legality precheck; returns {} when there are no legal actions at all.
        """
        legal = self.env.get_legal_actions()
        if not legal:
            return {}
        candidate_payloads = [
            item for item in self.env.get_candidate_actions() if bool(item.get("legality_precheck", False))
        ]
        if not candidate_payloads:
            return legal[0]
        candidates = [self._candidate_obj(item) for item in candidate_payloads]
        state = self.env.state
        selection = self.policy_router.select_candidate(
            candidates=candidates,
            prompt={
                "patient_id": state.patient.patient_id,
                "difficulty": state.difficulty.value,
                "sub_environment": state.sub_environment.value,
                "step_count": state.step_count,
            },
            provider_preference=default_provider_preference(),
        )
        # If the router picked an id that is no longer legal, degrade to legal[0].
        selected = next((item for item in legal if item.get("candidate_id") == selection.candidate_id), legal[0])
        payload = dict(selected)
        payload["policy_selection"] = {
            "provider": selection.provider,
            "candidate_id": selection.candidate_id,
            "rationale": selection.rationale,
            "latency_ms": round(selection.latency_ms, 3),
            "raw_output": selection.raw_output,
        }
        return payload

    def model_status(self) -> dict[str, Any]:
        """Availability/status report for the configured policy providers."""
        return self.policy_router.model_status()

    def batch_infer(self, batch_size: int = 4) -> list[dict[str, Any]]:
        """Return up to *batch_size* legal actions (no model call involved)."""
        legal = self.env.get_legal_actions()
        return legal[:batch_size]

    def run_baselines(self) -> dict[str, Any]:
        """Run every baseline selector over the current candidate set.

        NOTE(review): the candidates are re-converted via _candidate_obj for
        each baseline below — could be hoisted into one list; verify
        CandidateAction instances are safe to share before doing so.
        """
        candidates = [c for c in self.env.get_candidate_actions() if c.get("legality_precheck")]
        if not candidates:
            # Empty candidate set (e.g. terminal state): reset and retry once.
            self.env.reset()
            candidates = [c for c in self.env.get_candidate_actions() if c.get("legality_precheck")]
        baseline_results = {
            "no_change": choose_no_change().model_dump(mode="json"),
            "rules_only": choose_rules_only([self._candidate_obj(c) for c in candidates]).model_dump(mode="json"),
            "greedy": choose_greedy([self._candidate_obj(c) for c in candidates]).model_dump(mode="json"),
            "contextual_bandit": choose_contextual_bandit([self._candidate_obj(c) for c in candidates]).model_dump(mode="json"),
            "contextual_bandit_topk": [
                {
                    "candidate_id": item.candidate_id,
                    "score": item.score,
                    "exploration_bonus": item.exploration_bonus,
                    "algorithm": item.algorithm,
                }
                for item in choose_contextual_bandit_topk([self._candidate_obj(c) for c in candidates], top_k=3)
            ],
            "beam_search": choose_beam_search([self._candidate_obj(c) for c in candidates]).model_dump(mode="json"),
        }
        return baseline_results

    def run_policy_eval(self) -> dict[str, Any]:
        """Build the policy benchmark report (also written to outputs/reports)."""
        out = build_benchmark_report(Path("outputs/reports/benchmark_report.txt"))
        return out

    def run_dosing_eval(self) -> dict[str, Any]:
        """Run the dosing-specific evaluation suite."""
        return dosing_eval()

    def run_training(self) -> dict[str, Any]:
        """Run short GRPO training for all three agents; caches metrics in-process."""
        out_dir = Path("checkpoints")
        out_dir.mkdir(parents=True, exist_ok=True)
        self.training_metrics = {
            "supervisor": train_supervisor_grpo(episodes=4, checkpoint_dir=out_dir),
            "planner": train_planner_grpo(episodes=6, checkpoint_dir=out_dir),
            "dosing": train_dosing_grpo(episodes=4, checkpoint_dir=out_dir),
        }
        return self.training_metrics

    def get_metrics(self) -> dict[str, Any]:
        """Return training metrics, preferring in-memory results over disk.

        The planner metrics are promoted to the top level (with the full
        per-model dict nested under "model_metrics") when available.
        """
        if self.training_metrics:
            if "planner" in self.training_metrics:
                merged = dict(self.training_metrics["planner"])
                merged["model_metrics"] = self.training_metrics
                return merged
            return self.training_metrics
        # No in-memory metrics: load whatever reports a previous run left on disk.
        reports_dir = Path("outputs/reports")
        metrics: dict[str, Any] = {}
        for name in ["supervisor_grpo", "planner_grpo", "dosing_grpo"]:
            path = reports_dir / f"{name}.json"
            if path.exists():
                import json

                metrics[name] = json.loads(path.read_text(encoding="utf-8"))
        self.training_metrics = metrics
        if "planner_grpo" in metrics:
            merged = dict(metrics["planner_grpo"])
            merged["model_metrics"] = metrics
            return merged
        return metrics

    def sample_case(self) -> dict[str, Any]:
        """Reset the shared env to a fresh case and return the first observation."""
        obs = self.env.reset()
        return obs.model_dump(mode="json")

    def search_cases(self, query: str) -> list[dict[str, Any]]:
        """Search cases via the retrieval index, with a raw-corpus token-match fallback."""
        index_file = self.root / "data" / "retrieval_index" / "index.json"
        hits = retrieve(index_file=index_file, query=query, top_k=5)
        if hits:
            return [
                {
                    "patient_id": Path(item.get("path", f"case_{idx}")).stem,
                    "query": query,
                    "source_path": item.get("path", ""),
                    "snippet": str(item.get("text", ""))[:280],
                }
                for idx, item in enumerate(hits)
            ]

        # Fallback: linear scan of the JSONL corpus, first 5 lines containing
        # any query token (case-insensitive substring match).
        fallback: list[dict[str, Any]] = []
        corpus = self.root / "data" / "processed" / "retrieval_corpus.jsonl"
        if corpus.exists():
            query_tokens = {token for token in query.lower().split() if token}
            with corpus.open("r", encoding="utf-8") as handle:
                for idx, line in enumerate(handle):
                    if len(fallback) >= 5:
                        break
                    text = line.strip()
                    if not text:
                        continue
                    hay = text.lower()
                    if query_tokens and not any(token in hay for token in query_tokens):
                        continue
                    fallback.append(
                        {
                            "patient_id": f"retrieval_corpus_{idx}",
                            "query": query,
                            "source_path": str(corpus),
                            "snippet": text[:280],
                        }
                    )
        return fallback

    def evidence_query(self, query: str, top_k: int = 5) -> list[dict[str, str]]:
        """Retrieve evidence snippets from the knowledge base."""
        return retrieve_evidence(query=query, top_k=top_k)

    @staticmethod
    def _candidate_obj(payload: dict) -> Any:
        # Local import avoids a module-level import cycle with app.common.types.
        from app.common.types import CandidateAction

        return CandidateAction.model_validate(payload)
|
app/common/config.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Configuration loading."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Any
|
| 8 |
+
|
| 9 |
+
import yaml
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def _read_yaml(path: Path) -> dict[str, Any]:
    """Parse *path* as YAML; a missing or empty file yields an empty dict."""
    if not path.exists():
        return {}
    with path.open("r", encoding="utf-8") as handle:
        loaded = yaml.safe_load(handle)
    return loaded or {}
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def load_config(config_name: str = "base.yaml") -> dict[str, Any]:
    """Load a YAML config from ``<project root>/configs/<config_name>``."""
    project_root = Path(__file__).resolve().parents[2]
    return _read_yaml(project_root / "configs" / config_name)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def load_project_env(path: Path | None = None) -> None:
|
| 26 |
+
"""Load simple KEY=VALUE pairs from .env without overriding shell env."""
|
| 27 |
+
|
| 28 |
+
root = Path(__file__).resolve().parents[2]
|
| 29 |
+
env_path = path or root / ".env"
|
| 30 |
+
if not env_path.exists():
|
| 31 |
+
return
|
| 32 |
+
for raw_line in env_path.read_text(encoding="utf-8").splitlines():
|
| 33 |
+
line = raw_line.strip()
|
| 34 |
+
if not line or line.startswith("#") or "=" not in line:
|
| 35 |
+
continue
|
| 36 |
+
key, value = line.split("=", 1)
|
| 37 |
+
key = key.strip()
|
| 38 |
+
if not key or key in os.environ:
|
| 39 |
+
continue
|
| 40 |
+
os.environ[key] = value.strip().strip('"').strip("'")
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def env_bool(name: str, default: bool = False) -> bool:
    """Read a boolean env var; unset falls back to *default*.

    Accepted truthy spellings (case-insensitive): 1, true, yes, on.
    """
    value = os.getenv(name)
    if value is None:
        return default
    normalized = value.strip().lower()
    return normalized in ("1", "true", "yes", "on")
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def env_int(name: str, default: int) -> int:
    """Read an integer env var, falling back to *default* when unset or invalid."""
    raw = os.getenv(name)
    if raw is not None:
        try:
            return int(raw)
        except ValueError:
            # Malformed value: behave as if the variable were unset.
            pass
    return default
|
app/common/constants.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Shared constants for POLYGUARD-RL."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
REWARD_MIN: float = 0.001
|
| 6 |
+
REWARD_MAX: float = 0.999
|
| 7 |
+
REWARD_PRECISION: int = 3
|
| 8 |
+
|
| 9 |
+
DEFAULT_SEED: int = 42
|
| 10 |
+
DEFAULT_MAX_STEPS: int = 10
|
| 11 |
+
MAX_REPEATED_ACTIONS: int = 3
|
| 12 |
+
MAX_KEEP_REGIMEN_RATIO: float = 0.6
|
| 13 |
+
MAX_REVIEW_RATIO: float = 0.5
|
| 14 |
+
DEFAULT_STEP_TIMEOUT_SECONDS: float = 2.5
|
| 15 |
+
DEFAULT_EPISODE_TIMEOUT_SECONDS: float = 45.0
|
| 16 |
+
|
| 17 |
+
DEFAULT_REWARD_WEIGHTS: dict[str, float] = {
|
| 18 |
+
"format_compliance_score": 0.08,
|
| 19 |
+
"candidate_alignment_score": 0.08,
|
| 20 |
+
"legality_score": 0.12,
|
| 21 |
+
"safety_delta_score": 0.15,
|
| 22 |
+
"burden_improvement_score": 0.08,
|
| 23 |
+
"disease_stability_score": 0.10,
|
| 24 |
+
"dosing_quality_score": 0.08,
|
| 25 |
+
"abstention_quality_score": 0.06,
|
| 26 |
+
"efficiency_score": 0.06,
|
| 27 |
+
"process_fidelity_score": 0.06,
|
| 28 |
+
"explanation_grounding_score": 0.03,
|
| 29 |
+
"anti_cheat_score": 0.06,
|
| 30 |
+
"uncertainty_calibration_score": 0.04,
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
REQUIRED_REWARD_KEYS: tuple[str, ...] = tuple(DEFAULT_REWARD_WEIGHTS.keys())
|
| 34 |
+
|
| 35 |
+
PRIMARY_REWARD_KEYS: tuple[str, ...] = (
|
| 36 |
+
"safety_legality",
|
| 37 |
+
"clinical_improvement",
|
| 38 |
+
"dosing_quality",
|
| 39 |
+
"process_integrity",
|
| 40 |
+
)
|
app/common/enums.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Enumerations used throughout POLYGUARD-RL."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from enum import Enum
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class Difficulty(str, Enum):
    """Episode difficulty tier."""

    EASY = "easy"
    MEDIUM = "medium"
    HARD = "hard"
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class SubEnvironment(str, Enum):
    """Scenario family an episode is drawn from."""

    DDI = "DDI"
    BANDIT_MINING = "BANDIT_MINING"
    REGIMEN_RISK = "REGIMEN_RISK"
    PRECISION_DOSING = "PRECISION_DOSING"
    LONGITUDINAL_DEPRESCRIBING = "LONGITUDINAL_DEPRESCRIBING"
    WEB_SEARCH_MISSING_DATA = "WEB_SEARCH_MISSING_DATA"
    ALTERNATIVE_SUGGESTION = "ALTERNATIVE_SUGGESTION"
    NEW_DRUG_DECOMPOSITION = "NEW_DRUG_DECOMPOSITION"
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class DecisionMode(str, Enum):
    """High-level decision mode selected per step (see SupervisorAgent)."""

    REGIMEN_OPT = "REGIMEN_OPT"
    DOSE_OPT = "DOSE_OPT"
    REVIEW = "REVIEW"
    ABSTAIN_REVIEW = "ABSTAIN_REVIEW"
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class ActionType(str, Enum):
    """Concrete environment action kinds."""

    KEEP_REGIMEN = "KEEP_REGIMEN"
    STOP_DRUG = "STOP_DRUG"
    SUBSTITUTE_WITHIN_CLASS = "SUBSTITUTE_WITHIN_CLASS"
    RECOMMEND_ALTERNATIVE = "RECOMMEND_ALTERNATIVE"
    # Dose-level adjustments.
    REDUCE_DOSE_BUCKET = "REDUCE_DOSE_BUCKET"
    INCREASE_DOSE_BUCKET = "INCREASE_DOSE_BUCKET"
    TAPER_INITIATE = "TAPER_INITIATE"
    TAPER_CONTINUE = "TAPER_CONTINUE"
    DOSE_HOLD = "DOSE_HOLD"
    # Information-gathering / deferral actions.
    ORDER_MONITORING_AND_WAIT = "ORDER_MONITORING_AND_WAIT"
    FETCH_EXTERNAL_EVIDENCE = "FETCH_EXTERNAL_EVIDENCE"
    DECOMPOSE_NEW_DRUG = "DECOMPOSE_NEW_DRUG"
    REQUEST_SPECIALIST_REVIEW = "REQUEST_SPECIALIST_REVIEW"
    REQUEST_PHARMACIST_REVIEW = "REQUEST_PHARMACIST_REVIEW"
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class DoseBucket(str, Enum):
    """Coarse dose level; NA when the action has no dose component."""

    LOW = "LOW"
    MEDIUM = "MEDIUM"
    HIGH = "HIGH"
    HOLD = "HOLD"
    NA = "NA"
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class CoordinationMode(str, Enum):
    """Multi-agent coordination strategy used by the orchestrator."""

    SEQUENTIAL = "sequential_pipeline"
    SUPERVISOR_ROUTED = "supervisor_routed"
    REPLAN_ON_VETO = "replan_on_veto"
    LIGHT_DEBATE = "lightweight_debate"
|
app/common/exceptions.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Custom exceptions."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class PolyGuardError(Exception):
    """Base exception for project errors; catch this to handle any PolyGuard failure."""
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class InvalidActionError(PolyGuardError):
    """Raised when an action is malformed or disallowed in the current state."""
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class SafetyVetoError(PolyGuardError):
    """Raised when safety governance rejects (vetoes) an action."""
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class ParserError(PolyGuardError):
    """Raised when structured policy output cannot be parsed."""
|
app/common/json_utils.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Strict JSON helpers."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
from typing import Any
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def safe_json_dumps(payload: Any) -> str:
    """Serialize to deterministic JSON: sorted keys, ASCII-only, str() fallback
    for objects json cannot encode natively."""
    return json.dumps(
        payload,
        ensure_ascii=True,
        sort_keys=True,
        default=str,
    )
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def safe_json_loads(payload: str) -> Any:
    """Parse a JSON string; propagates json.JSONDecodeError on malformed input."""
    return json.loads(payload)
|
app/common/logging_utils.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Logging utilities."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import logging
|
| 6 |
+
from typing import Optional
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def configure_logging(level: str = "INFO") -> None:
    """Initialize root logging with a timestamped pipe-separated format.

    Unknown level names fall back to INFO.
    """
    resolved = getattr(logging, level.upper(), logging.INFO)
    logging.basicConfig(
        level=resolved,
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    )
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def get_logger(name: Optional[str] = None) -> logging.Logger:
    """Return a named logger, defaulting to the project-wide "polyguard" logger."""
    logger_name = name if name else "polyguard"
    return logging.getLogger(logger_name)
|
app/common/normalization.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Normalization and reward range utilities."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from app.common.constants import REWARD_MAX, REWARD_MIN, REWARD_PRECISION
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def clamp_reward(value: float) -> float:
    """Clamp and quantize reward to [0.001, 0.999] with 3 decimals."""
    bounded = max(REWARD_MIN, min(REWARD_MAX, float(value)))
    return round(bounded, REWARD_PRECISION)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def normalize_unit_interval(value: float, lower: float, upper: float) -> float:
    """Linearly map *value* from [lower, upper] onto [0, 1], clipping overflow.

    A degenerate interval (``upper <= lower``) returns the neutral 0.5.
    """
    if upper <= lower:
        return 0.5
    span = upper - lower
    ratio = (value - lower) / span
    return float(min(1.0, max(0.0, ratio)))
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def to_reward(value: float, lower: float, upper: float) -> float:
    """Rescale *value* from [lower, upper] into the legal reward band."""
    fraction = normalize_unit_interval(value, lower, upper)
    span = REWARD_MAX - REWARD_MIN
    return clamp_reward(REWARD_MIN + fraction * span)
|
app/common/seeding.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Deterministic seeding helpers."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import random
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
|
| 10 |
+
from app.common.constants import DEFAULT_SEED
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def set_global_seed(seed: int = DEFAULT_SEED) -> int:
    """Seed Python's and NumPy's global RNGs for reproducible runs.

    Returns the seed so callers can log or propagate it.
    """
    random.seed(seed)
    np.random.seed(seed)
    # NOTE(review): setting PYTHONHASHSEED here only affects *child*
    # processes; this interpreter's hash randomization was fixed at startup.
    os.environ["PYTHONHASHSEED"] = str(seed)
    return seed
|
app/common/types.py
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Core typed models."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from datetime import datetime
|
| 6 |
+
from typing import Any, Optional
|
| 7 |
+
|
| 8 |
+
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
| 9 |
+
|
| 10 |
+
from app.common.enums import ActionType, DecisionMode, Difficulty, DoseBucket, SubEnvironment
|
| 11 |
+
from app.common.normalization import clamp_reward
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class StrictBase(BaseModel):
    """Base model for all project types: unknown fields are rejected."""

    model_config = ConfigDict(extra="forbid")
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class Medication(StrictBase):
    """One entry of a patient's medication list."""

    drug: str
    # Coarse dose level instead of an exact amount; defaults to the middle bucket.
    dose_bucket: DoseBucket = DoseBucket.MEDIUM
    indication: Optional[str] = None
    # Drug-class label; semantics defined by the knowledge catalog (not shown here).
    class_name: Optional[str] = None
    # True when discontinuation must be gradual rather than abrupt.
    requires_taper: bool = False
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class LabSummary(StrictBase):
    """Most recent lab values; any value may be missing (``None``)."""

    egfr: Optional[float] = None
    ast: Optional[float] = None
    alt: Optional[float] = None
    inr: Optional[float] = None
    glucose: Optional[float] = None
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class PatientProfile(StrictBase):
    """Full patient snapshot used to seed an episode.

    All list/dict fields default to empty so a minimal profile only
    needs ``patient_id``, ``age`` and ``sex``.
    """

    patient_id: str
    age: int
    sex: str
    comorbidities: list[str] = Field(default_factory=list)
    medications: list[Medication] = Field(default_factory=list)
    labs: LabSummary = Field(default_factory=LabSummary)
    vitals: dict[str, float] = Field(default_factory=dict)
    specialist_conflicts: list[str] = Field(default_factory=list)
    prior_ade_history: list[str] = Field(default_factory=list)
    # Scalar risk estimates; presumably in [0, 1] — not validated here.
    frailty_score: float = 0.3
    adherence_estimate: float = 0.8
    latent_confounders: dict[str, float] = Field(default_factory=dict)
    monitoring_gaps: list[str] = Field(default_factory=list)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class CandidateAction(StrictBase):
    """A proposed (not yet selected) action presented in the observation.

    Carries both the action parameters and the proposer's own estimates
    (safety/burden/stability deltas, uncertainty) used for ranking.
    """

    candidate_id: str
    mode: DecisionMode
    action_type: ActionType
    target_drug: Optional[str] = None
    replacement_drug: Optional[str] = None
    dose_bucket: DoseBucket = DoseBucket.NA
    taper_days: Optional[int] = None
    monitoring_plan: Optional[str] = None
    evidence_query: Optional[str] = None
    new_drug_name: Optional[str] = None
    candidate_components: list[str] = Field(default_factory=list)
    # Proposer-side estimates; deltas default to 0.0 (no expected change).
    estimated_safety_delta: float = 0.0
    burden_delta: float = 0.0
    disease_stability_estimate: float = 0.0
    uncertainty_score: float = 0.5
    rationale_tags: list[str] = Field(default_factory=list)
    required_monitoring: list[str] = Field(default_factory=list)
    # True when a preliminary legality check passed; defaults to permissive.
    legality_precheck: bool = True
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class PolyGuardAction(StrictBase):
    """The action actually submitted to the environment.

    Mirrors :class:`CandidateAction`'s action parameters, plus the chosen
    candidate's id, a confidence value and a brief rationale.
    """

    mode: DecisionMode
    action_type: ActionType
    target_drug: Optional[str] = None
    replacement_drug: Optional[str] = None
    dose_bucket: DoseBucket = DoseBucket.NA
    taper_days: Optional[int] = None
    monitoring_plan: Optional[str] = None
    evidence_query: Optional[str] = None
    new_drug_name: Optional[str] = None
    candidate_components: list[str] = Field(default_factory=list)
    candidate_id: str
    confidence: float
    rationale_brief: str

    @field_validator("confidence")
    @classmethod
    def _valid_confidence(cls, value: float) -> float:
        # Coerce (clamp + quantize) out-of-range confidences instead of rejecting.
        return clamp_reward(value)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class RewardBreakdown(StrictBase):
    """Per-component reward scores for a single step.

    The granular ``*_score`` fields are required; the four ``primary_*``
    aggregates default to the neutral 0.5; ``total_reward`` is the final
    scalar the agent optimizes.
    """

    format_compliance_score: float
    candidate_alignment_score: float
    legality_score: float
    safety_delta_score: float
    burden_improvement_score: float
    disease_stability_score: float
    dosing_quality_score: float
    abstention_quality_score: float
    efficiency_score: float
    process_fidelity_score: float
    explanation_grounding_score: float
    anti_cheat_score: float
    uncertainty_calibration_score: float
    # Aggregated primary axes; default to the neutral midpoint.
    primary_safety_legality: float = 0.5
    primary_clinical_improvement: float = 0.5
    primary_dosing_quality: float = 0.5
    primary_process_integrity: float = 0.5
    total_reward: float
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class SafetyReport(StrictBase):
    """Outcome of a safety/legality check on a proposed action."""

    legal: bool
    violations: list[str] = Field(default_factory=list)
    # Free-form severity label; "none" when no violations were found.
    severity: str = "none"
    # Suggested safer action type when the checked action is illegal.
    recommended_fallback: Optional[ActionType] = None
    uncertainty_notes: list[str] = Field(default_factory=list)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
class UncertaintyReport(StrictBase):
    """Aggregated uncertainty signals for one step."""

    # Overall uncertainty in [0, 1]; 0.5 is the uninformed default.
    overall_uncertainty: float = 0.5
    missing_data_flags: list[str] = Field(default_factory=list)
    abstention_recommended: bool = False
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
class PolyGuardState(StrictBase):
    """Mutable per-episode state tracked by the environment."""

    episode_id: str
    seed: int
    scenario_id: Optional[str] = None
    difficulty: Difficulty
    sub_environment: SubEnvironment = SubEnvironment.REGIMEN_RISK
    step_count: int
    max_steps: int
    patient: PatientProfile
    active_mode: DecisionMode = DecisionMode.REGIMEN_OPT
    cumulative_reward: float = 0.0
    unresolved_conflicts: list[str] = Field(default_factory=list)
    risk_summary: dict[str, float] = Field(default_factory=dict)
    burden_score: float = 0.5
    precision_dosing_flags: list[str] = Field(default_factory=list)
    action_history: list[dict[str, Any]] = Field(default_factory=list)
    done: bool = False
    # Naive UTC creation timestamp. NOTE(review): datetime.utcnow is
    # deprecated since Python 3.12 — consider datetime.now(timezone.utc),
    # but note that switches to tz-aware values.
    created_at: datetime = Field(default_factory=datetime.utcnow)
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
class PolyGuardObservation(StrictBase):
    """Everything the agent sees at one step.

    All summaries are plain dicts/lists so the observation is directly
    JSON-serializable; candidates are fully typed for validation.
    """

    patient_summary: dict[str, Any]
    medication_table: list[dict[str, Any]]
    comorbidity_summary: list[str]
    organ_function_summary: dict[str, Any]
    labs_vitals_summary: dict[str, Any]
    graph_safety_summary: dict[str, Any]
    burden_score_summary: dict[str, Any]
    precision_dosing_flags: list[str]
    unresolved_conflicts: list[str]
    candidate_action_set: list[CandidateAction]
    step_budget_remaining: int
    action_history: list[dict[str, Any]]
    warning_summary: list[str]
    abstention_indicators: dict[str, Any]
    sub_environment: SubEnvironment
    deterministic_contract: dict[str, Any] = Field(default_factory=dict)
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
class StepTrace(StrictBase):
    """Audit record of a single environment step (for replay/debugging)."""

    step: int
    observation_snapshot: PolyGuardObservation
    # None when the step ended without a selected action (e.g. on timeout).
    selected_action: Optional[PolyGuardAction] = None
    critic_output: dict[str, Any] = Field(default_factory=dict)
    reward_components: dict[str, float] = Field(default_factory=dict)
    transition_delta: dict[str, Any] = Field(default_factory=dict)
    uncertainty_report: UncertaintyReport = Field(default_factory=UncertaintyReport)
    failure_reasons: list[str] = Field(default_factory=list)
    timeout: bool = False
|
app/dataops/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Data operations package."""
|
| 2 |
+
|
| 3 |
+
from app.dataops.source_manager import SourceManager
|
| 4 |
+
|
| 5 |
+
__all__ = ["SourceManager"]
|
app/dataops/ddi_api.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""DDI API ingestion helpers with offline-first caching."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Any
|
| 8 |
+
|
| 9 |
+
import requests
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
DEFAULT_DDI_API_URL = "https://api.fda.gov/drug/label.json"


def fetch_ddi_api_records(
    drugs: list[str],
    timeout: int = 20,
    api_url: str = DEFAULT_DDI_API_URL,
) -> list[dict[str, Any]]:
    """Query the label API once per drug, never raising.

    Each returned record carries ``status`` "ok" (with ``payload``) or
    "error" (with the exception text), so partial failures are preserved
    instead of aborting the whole batch.
    """
    results: list[dict[str, Any]] = []
    for name in drugs:
        record: dict[str, Any] = {"drug": name, "source": api_url}
        try:
            resp = requests.get(
                api_url,
                params={"search": f"openfda.generic_name:{name}", "limit": 1},
                timeout=timeout,
            )
            resp.raise_for_status()
            record["status"] = "ok"
            record["payload"] = resp.json()
        except Exception as exc:  # noqa: BLE001
            record["status"] = "error"
            record["error"] = str(exc)
        results.append(record)
    return results
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def load_cached_ddi(path: Path) -> list[dict[str, Any]]:
    """Best-effort read of a cached DDI record list.

    Returns ``[]`` when the file is missing, unreadable, malformed, or
    does not contain a JSON list.
    """
    if not path.exists():
        return []
    try:
        data = json.loads(path.read_text(encoding="utf-8"))
    except Exception:
        return []
    return data if isinstance(data, list) else []
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def cache_ddi_records(path: Path, records: list[dict[str, Any]]) -> Path:
    """Persist *records* as pretty-printed ASCII JSON, creating parent dirs."""
    path.parent.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(records, ensure_ascii=True, indent=2)
    path.write_text(serialized, encoding="utf-8")
    return path
|
app/dataops/normalizer.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Entity normalizer."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from app.knowledge.drug_catalog import canonicalize_drug_name
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def normalize_drug_entities(items: list[str]) -> list[str]:
    """Canonicalize, deduplicate and sort drug names."""
    canonical = {canonicalize_drug_name(item) for item in items}
    return sorted(canonical)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def normalize_component_entities(items: list[str]) -> list[str]:
    """Canonicalize component names, skipping empties and converting hyphens to underscores."""
    seen: set[str] = set()
    for item in items:
        if item:
            seen.add(canonicalize_drug_name(item).replace("-", "_"))
    return sorted(seen)
|
app/dataops/package_loader.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Package/local artifact loading."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Any
|
| 8 |
+
|
| 9 |
+
import yaml
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def load_artifact(path: Path) -> Any:
    """Load a local artifact, dispatching on the file extension.

    JSON and YAML files are parsed, ``.txt``/``.md`` are returned as
    text, and anything else as raw bytes.
    """
    suffix = path.suffix.lower()
    if suffix == ".json":
        return json.loads(path.read_text(encoding="utf-8"))
    if suffix in (".yaml", ".yml"):
        return yaml.safe_load(path.read_text(encoding="utf-8"))
    if suffix in (".txt", ".md"):
        return path.read_text(encoding="utf-8")
    return path.read_bytes()
|
app/dataops/parser.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Raw text parser for knowledge ingestion."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import re
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def extract_drug_mentions(text: str) -> list[str]:
    """Return sorted unique word-like tokens (>= 4 chars) from *text*, lowercased."""
    pattern = re.compile(r"[a-zA-Z_-]{4,}")
    unique = set(pattern.findall(text.lower()))
    return sorted(unique)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def extract_components(text: str) -> list[str]:
    """Extract ingredient/component names from label-style text.

    Scans non-empty lines mentioning "ingredient", "component" or
    "contains", splits off the text after the first ':', '.' or ';'
    delimiter, and normalizes each comma/slash/" and "-separated item
    into a lowercase underscore-joined token of 3-40 characters.
    Returns the sorted, deduplicated tokens.
    """
    components: list[str] = []
    for raw_line in text.splitlines():
        line = raw_line.strip().lower()
        if not line:
            continue
        if "ingredient" not in line and "component" not in line and "contains" not in line:
            continue
        # BUG FIX: the delimiter pattern was r":|\\.|;" — in a raw string
        # that matches a literal backslash followed by any character, so
        # period-delimited lines ("Active ingredients. aspirin") were
        # never split. A character class matches ':', '.' and ';'.
        parts = re.split(r"[:.;]", line, maxsplit=1)
        if len(parts) < 2:
            continue
        for item in re.split(r",|/| and ", parts[1]):
            token = re.sub(r"[^a-z0-9_ -]", "", item).strip().replace(" ", "_")
            if 3 <= len(token) <= 40:
                components.append(token)
    return sorted(set(components))
|
app/dataops/provenance.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Provenance tracking."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@dataclass(slots=True)
|
| 10 |
+
class ProvenanceRecord:
|
| 11 |
+
source: str
|
| 12 |
+
source_type: str
|
| 13 |
+
fetched_at: str
|
| 14 |
+
transform: str
|
| 15 |
+
|
| 16 |
+
def to_dict(self) -> dict[str, str]:
|
| 17 |
+
return {
|
| 18 |
+
"source": self.source,
|
| 19 |
+
"source_type": self.source_type,
|
| 20 |
+
"fetched_at": self.fetched_at,
|
| 21 |
+
"transform": self.transform,
|
| 22 |
+
}
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def make_provenance(source: str, source_type: str, transform: str) -> ProvenanceRecord:
|
| 26 |
+
return ProvenanceRecord(
|
| 27 |
+
source=source,
|
| 28 |
+
source_type=source_type,
|
| 29 |
+
fetched_at=datetime.utcnow().isoformat(),
|
| 30 |
+
transform=transform,
|
| 31 |
+
)
|
app/dataops/scraper.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Controlled scraper facade."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from app.dataops.web_agent import fetch_url
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def scrape_allowed_page(url: str, allow_domains: list[str]) -> str:
    """Fetch *url* via the guarded web agent, restricted to *allow_domains*."""
    page_text = fetch_url(url, allowed_domains=allow_domains)
    return page_text
|
app/dataops/source_manager.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Source management for offline-first ingestion."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import hashlib
|
| 6 |
+
import json
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import Any
|
| 9 |
+
|
| 10 |
+
from app.dataops.web_agent import fetch_url
|
| 11 |
+
from app.dataops.parser import extract_components, extract_drug_mentions
|
| 12 |
+
from app.dataops.normalizer import normalize_component_entities, normalize_drug_entities
|
| 13 |
+
from app.dataops.provenance import make_provenance
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class SourceManager:
|
| 17 |
+
def __init__(self, root: Path) -> None:
|
| 18 |
+
self.root = root
|
| 19 |
+
self.raw = root / "data" / "raw"
|
| 20 |
+
self.cache = root / "data" / "cache"
|
| 21 |
+
self.cache.mkdir(parents=True, exist_ok=True)
|
| 22 |
+
|
| 23 |
+
def local_sources(self) -> list[Path]:
|
| 24 |
+
return [p for p in self.raw.rglob("*") if p.is_file()]
|
| 25 |
+
|
| 26 |
+
@staticmethod
|
| 27 |
+
def checksum_text(text: str) -> str:
|
| 28 |
+
return hashlib.sha256(text.encode("utf-8")).hexdigest()
|
| 29 |
+
|
| 30 |
+
def cache_text(self, namespace: str, key: str, text: str) -> Path:
|
| 31 |
+
ns_dir = self.cache / namespace
|
| 32 |
+
ns_dir.mkdir(parents=True, exist_ok=True)
|
| 33 |
+
checksum = self.checksum_text(text)
|
| 34 |
+
target = ns_dir / f"{key}_{checksum[:12]}.txt"
|
| 35 |
+
target.write_text(text, encoding="utf-8")
|
| 36 |
+
meta = {
|
| 37 |
+
"key": key,
|
| 38 |
+
"checksum": checksum,
|
| 39 |
+
"path": str(target),
|
| 40 |
+
}
|
| 41 |
+
(ns_dir / f"{key}.meta.json").write_text(json.dumps(meta, ensure_ascii=True, indent=2), encoding="utf-8")
|
| 42 |
+
return target
|
| 43 |
+
|
| 44 |
+
def read_cached(self, namespace: str, key: str) -> str | None:
|
| 45 |
+
meta_path = self.cache / namespace / f"{key}.meta.json"
|
| 46 |
+
if not meta_path.exists():
|
| 47 |
+
return None
|
| 48 |
+
meta = json.loads(meta_path.read_text(encoding="utf-8"))
|
| 49 |
+
target = Path(meta["path"])
|
| 50 |
+
if target.exists():
|
| 51 |
+
return target.read_text(encoding="utf-8")
|
| 52 |
+
return None
|
| 53 |
+
|
| 54 |
+
def fetch_with_cache(
|
| 55 |
+
self,
|
| 56 |
+
url: str,
|
| 57 |
+
allow_domains: list[str],
|
| 58 |
+
namespace: str = "web",
|
| 59 |
+
offline_first: bool = True,
|
| 60 |
+
) -> dict[str, Any]:
|
| 61 |
+
key = url.replace("https://", "").replace("http://", "").replace("/", "_")
|
| 62 |
+
if offline_first:
|
| 63 |
+
cached = self.read_cached(namespace=namespace, key=key)
|
| 64 |
+
if cached is not None:
|
| 65 |
+
provenance = make_provenance(source=url, source_type="cache", transform="read_cached")
|
| 66 |
+
return {"text": cached, "provenance": provenance.__dict__, "from_cache": True}
|
| 67 |
+
text = fetch_url(url, allowed_domains=allow_domains)
|
| 68 |
+
self.cache_text(namespace=namespace, key=key, text=text)
|
| 69 |
+
provenance = make_provenance(source=url, source_type="web", transform="fetch_with_cache")
|
| 70 |
+
return {"text": text, "provenance": provenance.__dict__, "from_cache": False}
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class DataAcquisitionAgent:
    """Turns local files and allow-listed web pages into normalized
    knowledge records (drug mentions + components + provenance)."""

    def __init__(self, root: Path, allow_domains: list[str]) -> None:
        self.manager = SourceManager(root=root)
        self.allow_domains = allow_domains

    def acquire_local_knowledge(self) -> list[dict[str, Any]]:
        """Parse every file under data/raw into a mentions/components record."""
        records: list[dict[str, Any]] = []
        for source in self.manager.local_sources():
            # errors="ignore" lets binary or badly-encoded files degrade to
            # best-effort text instead of raising.
            text = source.read_text(encoding="utf-8", errors="ignore")
            mentions = normalize_drug_entities(extract_drug_mentions(text))
            components = normalize_component_entities(extract_components(text))
            provenance = make_provenance(source=str(source), source_type="local_file", transform="parse_local").to_dict()
            records.append(
                {
                    "source": str(source),
                    "mentions": mentions,
                    "components": components,
                    "provenance": provenance,
                }
            )
        return records

    def acquire_web_knowledge(self, url: str, offline_first: bool = True) -> dict[str, Any]:
        """Fetch (or replay from cache) one URL and normalize its contents."""
        blob = self.manager.fetch_with_cache(
            url=url,
            allow_domains=self.allow_domains,
            namespace="drug_labels",
            offline_first=offline_first,
        )
        text = blob["text"]
        mentions = normalize_drug_entities(extract_drug_mentions(text))
        components = normalize_component_entities(extract_components(text))
        return {
            "url": url,
            "mentions": mentions,
            "components": components,
            "provenance": blob["provenance"],
            "from_cache": blob["from_cache"],
        }
|