Spaces:
Sleeping
Sleeping
Upload folder using huggingface_hub
Browse files- Dockerfile +73 -0
- README.md +244 -12
- __init__.py +23 -0
- agent.py +179 -0
- client.py +131 -0
- data.py +136 -0
- env.py +116 -0
- graders.py +33 -0
- inference.py +213 -0
- models.py +52 -0
- openenv.yaml +113 -0
- pyproject.toml +24 -0
- requirements.txt +6 -0
- run_demo.py +23 -0
- server.py +50 -0
- server/__init__.py +11 -0
- server/app.py +50 -0
- server/graders.py +179 -0
- server/pharma_vigilance_env_environment.py +5 -0
- server/requirements.txt +6 -0
- server/tasks.py +27 -0
- tasks.py +222 -0
- tests/test_env.py +132 -0
- uv.lock +0 -0
- validate-submission.sh +185 -0
Dockerfile
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# Multi-stage build using openenv-base
|
| 8 |
+
# This Dockerfile is flexible and works for both:
|
| 9 |
+
# - In-repo environments (with local OpenEnv sources)
|
| 10 |
+
# - Standalone environments (with openenv from PyPI/Git)
|
| 11 |
+
# The build script (openenv build) handles context detection and sets appropriate build args.
|
| 12 |
+
|
| 13 |
+
ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
|
| 14 |
+
FROM ${BASE_IMAGE} AS builder
|
| 15 |
+
|
| 16 |
+
WORKDIR /app
|
| 17 |
+
|
| 18 |
+
# Ensure git is available (required for installing dependencies from VCS)
|
| 19 |
+
RUN apt-get update && \
|
| 20 |
+
apt-get install -y --no-install-recommends git && \
|
| 21 |
+
rm -rf /var/lib/apt/lists/*
|
| 22 |
+
|
| 23 |
+
# Build argument to control whether we're building standalone or in-repo
|
| 24 |
+
ARG BUILD_MODE=in-repo
|
| 25 |
+
ARG ENV_NAME=pharma_vigilance_env
|
| 26 |
+
|
| 27 |
+
# Copy environment code (always at root of build context)
|
| 28 |
+
COPY . /app/env
|
| 29 |
+
|
| 30 |
+
# For in-repo builds, openenv is already vendored in the build context
|
| 31 |
+
# For standalone builds, openenv will be installed via pyproject.toml
|
| 32 |
+
WORKDIR /app/env
|
| 33 |
+
|
| 34 |
+
# Ensure uv is available (for local builds where base image lacks it)
|
| 35 |
+
RUN if ! command -v uv >/dev/null 2>&1; then \
|
| 36 |
+
curl -LsSf https://astral.sh/uv/install.sh | sh && \
|
| 37 |
+
mv /root/.local/bin/uv /usr/local/bin/uv && \
|
| 38 |
+
mv /root/.local/bin/uvx /usr/local/bin/uvx; \
|
| 39 |
+
fi
|
| 40 |
+
|
| 41 |
+
# Install dependencies into a local virtualenv using the repo requirements file.
|
| 42 |
+
RUN --mount=type=cache,target=/root/.cache/uv \
|
| 43 |
+
uv venv .venv && \
|
| 44 |
+
. .venv/bin/activate && \
|
| 45 |
+
uv pip install -r requirements.txt
|
| 46 |
+
|
| 47 |
+
# Final runtime stage
|
| 48 |
+
FROM ${BASE_IMAGE}
|
| 49 |
+
|
| 50 |
+
WORKDIR /app
|
| 51 |
+
|
| 52 |
+
# Copy the virtual environment from builder
|
| 53 |
+
COPY --from=builder /app/env/.venv /app/.venv
|
| 54 |
+
|
| 55 |
+
# Copy the environment code
|
| 56 |
+
COPY --from=builder /app/env /app/env
|
| 57 |
+
|
| 58 |
+
# Set PATH to use the virtual environment
|
| 59 |
+
ENV PATH="/app/.venv/bin:$PATH"
|
| 60 |
+
|
| 61 |
+
# Set PYTHONPATH so imports work correctly
|
| 62 |
+
ENV PYTHONPATH="/app/env:$PYTHONPATH"
|
| 63 |
+
ENV ENABLE_WEB_INTERFACE=true
|
| 64 |
+
|
| 65 |
+
EXPOSE 7860
|
| 66 |
+
|
| 67 |
+
# Health check
|
| 68 |
+
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
|
| 69 |
+
CMD curl -f http://localhost:7860/health || exit 1
|
| 70 |
+
|
| 71 |
+
# Run the FastAPI server
|
| 72 |
+
# The module path is constructed to work with this repo's package layout.
|
| 73 |
+
CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 7860"]
|
README.md
CHANGED
|
@@ -1,12 +1,244 @@
|
|
| 1 |
-
---
|
| 2 |
-
title:
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
pinned: false
|
| 8 |
-
license: mit
|
| 9 |
-
short_description: OpenEnv pharmacovigilance signal detection environment
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Pharmacovigilance Signal Detector
|
| 3 |
+
colorFrom: blue
|
| 4 |
+
colorTo: green
|
| 5 |
+
sdk: docker
|
| 6 |
+
app_port: 7860
|
| 7 |
+
pinned: false
|
| 8 |
+
license: mit
|
| 9 |
+
short_description: OpenEnv pharmacovigilance signal detection environment
|
| 10 |
+
tags:
|
| 11 |
+
- openenv
|
| 12 |
+
- healthcare
|
| 13 |
+
- pharmacovigilance
|
| 14 |
+
- safety
|
| 15 |
+
- real-world
|
| 16 |
+
base_path: /web
|
| 17 |
+
---
|
| 18 |
+
|
| 19 |
+
# Pharmacovigilance Signal Detector
|
| 20 |
+
|
| 21 |
+
`Pharmacovigilance Signal Detector` is a real-world OpenEnv environment where an agent acts like a drug-safety analyst. The agent reviews synthetic adverse event reports, uses a hardcoded drug interaction knowledge base, and decides whether the case is a new safety signal, a known side effect, or low-value noise. This mirrors pharmacovigilance triage work performed by regulators and pharmaceutical safety teams.
|
| 22 |
+
|
| 23 |
+
All case data in this repo is synthetic. No real patient data is used.
|
| 24 |
+
|
| 25 |
+
## Why This Environment Matters
|
| 26 |
+
|
| 27 |
+
Pharmacovigilance teams are responsible for detecting harmful safety patterns after a drug is already on the market. That work is operationally important, high-stakes, and difficult: analysts must distinguish expected reactions from true emerging risks, recognize confounding from polypharmacy, and escalate only when justified. This makes the domain a strong fit for agent evaluation because it tests causal reasoning, prioritization, and safety-sensitive decision making.
|
| 28 |
+
|
| 29 |
+
## Environment Overview
|
| 30 |
+
|
| 31 |
+
| Item | Value |
|
| 32 |
+
|---|---|
|
| 33 |
+
| Environment name | `pharma-vigilance` |
|
| 34 |
+
| Domain | Pharmacovigilance / drug safety triage |
|
| 35 |
+
| Episode length | 1 step per task |
|
| 36 |
+
| Task count | 3 |
|
| 37 |
+
| Difficulties | Easy, Medium, Hard |
|
| 38 |
+
| Reward range | `0.0` to `1.0` |
|
| 39 |
+
| API | `reset()`, `step()`, `state()` |
|
| 40 |
+
| Server | FastAPI |
|
| 41 |
+
|
| 42 |
+
The agent receives one final-decision task per episode. Each task includes one or more synthetic reports plus a hardcoded drug interaction database. The environment never exposes ground truth to the agent.
|
| 43 |
+
|
| 44 |
+
## Action Space
|
| 45 |
+
|
| 46 |
+
| Field | Type | Allowed values | Purpose |
|
| 47 |
+
|---|---|---|---|
|
| 48 |
+
| `classification` | `str` | `new_signal`, `known_side_effect`, `noise`, `duplicate` | Overall pharmacovigilance judgment |
|
| 49 |
+
| `suspect_drug` | `str` | Free text | Drug or interaction the agent believes is causal |
|
| 50 |
+
| `severity_assessment` | `str` | `mild`, `moderate`, `severe`, `critical` | Clinical severity assessment |
|
| 51 |
+
| `recommended_action` | `str` | `escalate`, `log_and_monitor`, `dismiss`, `request_more_info` | Operational follow-up |
|
| 52 |
+
| `reasoning` | `str` | Free text | Short explanation used for grading bonus on hard task |
|
| 53 |
+
|
| 54 |
+
## Observation Space
|
| 55 |
+
|
| 56 |
+
| Field | Type | Description |
|
| 57 |
+
|---|---|---|
|
| 58 |
+
| `task_id` | `str` | Current task identifier |
|
| 59 |
+
| `reports` | `List[AdverseEventReport]` | Synthetic adverse event reports for the task |
|
| 60 |
+
| `drug_interaction_db` | `dict` | Hardcoded safety and interaction hints |
|
| 61 |
+
| `step_number` | `int` | Current step index |
|
| 62 |
+
| `max_steps` | `int` | Maximum number of steps in the episode |
|
| 63 |
+
| `feedback` | `Optional[str]` | Feedback message after the previous action |
|
| 64 |
+
|
| 65 |
+
Each `AdverseEventReport` contains:
|
| 66 |
+
|
| 67 |
+
| Field | Description |
|
| 68 |
+
|---|---|
|
| 69 |
+
| `report_id` | Unique synthetic report identifier |
|
| 70 |
+
| `patient_age` | Patient age |
|
| 71 |
+
| `patient_sex` | Patient sex |
|
| 72 |
+
| `drugs` | All drugs the patient was taking |
|
| 73 |
+
| `suspect_drug` | Drug named by the original reporter |
|
| 74 |
+
| `reaction` | Observed adverse reaction |
|
| 75 |
+
| `onset_days` | Days after drug start when reaction began |
|
| 76 |
+
| `severity` | Reported severity |
|
| 77 |
+
| `outcome` | Recovery status |
|
| 78 |
+
| `similar_reports_last_30d` | Count of similar recent reports |
|
| 79 |
+
|
| 80 |
+
## Tasks
|
| 81 |
+
|
| 82 |
+
| Task | Difficulty | Scenario | Ground-truth goal | Expected baseline |
|
| 83 |
+
|---|---|---|---|---|
|
| 84 |
+
| `known_signal_easy` | Easy | Patient on `Lisinopril` develops a persistent dry cough; many similar recent reports exist and the reaction is already on the label | Recognize a known side effect and recommend `log_and_monitor` | Around `0.85` |
|
| 85 |
+
| `cluster_signal_medium` | Medium | Four recent `Cardiovexa` cases show symptomatic bradycardia and near-syncope despite no labeled rhythm toxicity | Recognize a plausible emerging signal and `escalate` | Around `0.65` |
|
| 86 |
+
| `confounded_hard` | Hard | Acute kidney injury in a transplant patient is blamed on `Trimethoprim-sulfamethoxazole`, but the underlying cause is a `Voriconazole`-`Tacrolimus` interaction | Detect the interaction, classify as `new_signal`, and `escalate` | Around `0.40` |
|
| 87 |
+
|
| 88 |
+
The hard task is intentionally more difficult because the named suspect drug is not the true cause. The agent must reason over interaction evidence and therapeutic drug-monitoring clues in the provided hardcoded drug database.
|
| 89 |
+
|
| 90 |
+
## Reward Function
|
| 91 |
+
|
| 92 |
+
The environment uses deterministic programmatic graders.
|
| 93 |
+
|
| 94 |
+
| Reward component | Value |
|
| 95 |
+
|---|---|
|
| 96 |
+
| Correct `classification` | `+0.25` |
|
| 97 |
+
| Correct `suspect_drug` | `+0.25` |
|
| 98 |
+
| Correct `severity_assessment` | `+0.25` |
|
| 99 |
+
| Correct `recommended_action` | `+0.25` |
|
| 100 |
+
| False alarm penalty: agent says `new_signal` when truth is `noise` | `-0.10` |
|
| 101 |
+
| Missed signal penalty: agent says `noise` when truth is `new_signal` | `-0.20` |
|
| 102 |
+
| Hard-task reasoning bonus if explanation mentions `drug interaction`, `tacrolimus`, `voriconazole`, `azole`, `calcineurin`, or `level monitoring` | `+0.15` |
|
| 103 |
+
|
| 104 |
+
Notes:
|
| 105 |
+
- Final reward is clamped to `[0.0, 1.0]`.
|
| 106 |
+
- `suspect_drug` matching is forgiving for the hard task and allows substring matches.
|
| 107 |
+
- The environment is deterministic and reproducible because all tasks and grading logic are hardcoded.
|
| 108 |
+
|
| 109 |
+
## Project Structure
|
| 110 |
+
|
| 111 |
+
| Path | Purpose |
|
| 112 |
+
|---|---|
|
| 113 |
+
| `env.py` | Main environment class and Pydantic models |
|
| 114 |
+
| `tasks.py` | Task definitions and grader functions |
|
| 115 |
+
| `data.py` | Synthetic reports and drug interaction database |
|
| 116 |
+
| `server.py` | Root FastAPI entrypoint |
|
| 117 |
+
| `server/app.py` | OpenEnv-compatible app entrypoint |
|
| 118 |
+
| `inference.py` | Baseline inference runner |
|
| 119 |
+
| `openenv.yaml` | OpenEnv metadata |
|
| 120 |
+
| `Dockerfile` | Multi-stage OpenEnv-style container build |
|
| 121 |
+
| `tests/test_env.py` | Local tests |
|
| 122 |
+
| `validate-submission.sh` | Pre-submission validation helper |
|
| 123 |
+
|
| 124 |
+
## Running Locally
|
| 125 |
+
|
| 126 |
+
### Option 1: Local virtual environment
|
| 127 |
+
|
| 128 |
+
If you already created the local virtual environment in this repo:
|
| 129 |
+
|
| 130 |
+
```powershell
|
| 131 |
+
.\.venv\Scripts\Activate.ps1
|
| 132 |
+
```
|
| 133 |
+
|
| 134 |
+
Install dependencies if needed:
|
| 135 |
+
|
| 136 |
+
```bash
|
| 137 |
+
pip install -r requirements.txt
|
| 138 |
+
```
|
| 139 |
+
|
| 140 |
+
Start the server:
|
| 141 |
+
|
| 142 |
+
```bash
|
| 143 |
+
uvicorn server:app --host 0.0.0.0 --port 7860
|
| 144 |
+
```
|
| 145 |
+
|
| 146 |
+
### Option 2: Docker
|
| 147 |
+
|
| 148 |
+
Build the image:
|
| 149 |
+
|
| 150 |
+
```bash
|
| 151 |
+
docker build -t pharmacovigilance-env .
|
| 152 |
+
```
|
| 153 |
+
|
| 154 |
+
Run the container:
|
| 155 |
+
|
| 156 |
+
```bash
|
| 157 |
+
docker run -p 7860:7860 pharmacovigilance-env
|
| 158 |
+
```
|
| 159 |
+
|
| 160 |
+
The health endpoint will be available at:
|
| 161 |
+
|
| 162 |
+
```text
|
| 163 |
+
http://localhost:7860/health
|
| 164 |
+
```
|
| 165 |
+
|
| 166 |
+
## API Endpoints
|
| 167 |
+
|
| 168 |
+
| Method | Endpoint | Description |
|
| 169 |
+
|---|---|---|
|
| 170 |
+
| `POST` | `/reset` | Starts a task and returns the initial observation |
|
| 171 |
+
| `POST` | `/step` | Submits the final agent action and returns observation, reward, done, info |
|
| 172 |
+
| `GET` | `/state` | Returns internal environment state summary |
|
| 173 |
+
| `GET` | `/tasks` | Lists available task ids |
|
| 174 |
+
| `GET` | `/health` | Health check endpoint |
|
| 175 |
+
|
| 176 |
+
## Baseline Inference Script
|
| 177 |
+
|
| 178 |
+
The required baseline runner is `inference.py`.
|
| 179 |
+
|
| 180 |
+
It:
|
| 181 |
+
- reads `API_BASE_URL`, `MODEL_NAME`, `HF_TOKEN`, and optional `ENV_URL`
|
| 182 |
+
- uses the OpenAI client for all model calls
|
| 183 |
+
- runs all three tasks sequentially
|
| 184 |
+
- emits the required `[START]`, `[STEP]`, and `[END]` lines
|
| 185 |
+
- keeps stdout restricted to the judge-expected line types
|
| 186 |
+
|
| 187 |
+
Required environment variables:
|
| 188 |
+
|
| 189 |
+
```bash
|
| 190 |
+
export API_BASE_URL=https://router.huggingface.co/v1
|
| 191 |
+
export MODEL_NAME=Qwen/Qwen2.5-72B-Instruct
|
| 192 |
+
export HF_TOKEN=hf_your_token_here
|
| 193 |
+
export ENV_URL=http://localhost:7860
|
| 194 |
+
```
|
| 195 |
+
|
| 196 |
+
Run:
|
| 197 |
+
|
| 198 |
+
```bash
|
| 199 |
+
python inference.py
|
| 200 |
+
```
|
| 201 |
+
|
| 202 |
+
## Testing And Validation
|
| 203 |
+
|
| 204 |
+
Run local tests:
|
| 205 |
+
|
| 206 |
+
```bash
|
| 207 |
+
pytest tests/test_env.py -q
|
| 208 |
+
```
|
| 209 |
+
|
| 210 |
+
Run OpenEnv validation:
|
| 211 |
+
|
| 212 |
+
```bash
|
| 213 |
+
openenv validate
|
| 214 |
+
```
|
| 215 |
+
|
| 216 |
+
Run the pre-submission helper:
|
| 217 |
+
|
| 218 |
+
```bash
|
| 219 |
+
chmod +x validate-submission.sh
|
| 220 |
+
./validate-submission.sh https://your-space.hf.space
|
| 221 |
+
```
|
| 222 |
+
|
| 223 |
+
That script checks:
|
| 224 |
+
1. your Hugging Face Space responds to `POST /reset`
|
| 225 |
+
2. the Docker image builds
|
| 226 |
+
3. `openenv validate` passes
|
| 227 |
+
|
| 228 |
+
## Submission Checklist
|
| 229 |
+
|
| 230 |
+
- `openenv validate` passes
|
| 231 |
+
- `docker build` succeeds
|
| 232 |
+
- `docker run` starts cleanly
|
| 233 |
+
- `POST /reset` returns HTTP `200`
|
| 234 |
+
- `inference.py` runs all 3 tasks successfully
|
| 235 |
+
- your Hugging Face Space responds to `POST /reset`
|
| 236 |
+
- replace the expected baseline values with your measured live baseline values before final submission
|
| 237 |
+
|
| 238 |
+
## Notes
|
| 239 |
+
|
| 240 |
+
- No external API calls are made by the environment itself.
|
| 241 |
+
- The drug interaction database is hardcoded.
|
| 242 |
+
- Ground truth is never exposed in the observation returned to the agent.
|
| 243 |
+
- The environment is lightweight enough for a 2 vCPU / 8GB RAM target.
|
| 244 |
+
- The expected baseline scores in this README are planning targets until replaced with measured live results.
|
__init__.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""Pharmacovigilance Signal Detector Environment."""
|
| 8 |
+
|
| 9 |
+
try:
|
| 10 |
+
from .client import PharmaVigilanceEnvClient
|
| 11 |
+
from .models import PharmaAction, PharmaObservation, PharmaReward
|
| 12 |
+
except ImportError:
|
| 13 |
+
PharmaVigilanceEnvClient = None
|
| 14 |
+
from env import Action as PharmaAction
|
| 15 |
+
from env import Observation as PharmaObservation
|
| 16 |
+
from env import Reward as PharmaReward
|
| 17 |
+
|
| 18 |
+
__all__ = [
|
| 19 |
+
"PharmaVigilanceEnvClient",
|
| 20 |
+
"PharmaAction",
|
| 21 |
+
"PharmaObservation",
|
| 22 |
+
"PharmaReward",
|
| 23 |
+
]
|
agent.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
from typing import Optional
|
| 5 |
+
|
| 6 |
+
from openai import OpenAI
|
| 7 |
+
|
| 8 |
+
try:
|
| 9 |
+
from .env import Action
|
| 10 |
+
except ImportError:
|
| 11 |
+
from env import Action
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# Process-wide cache for the chat client; stays None until a client is built.
_cached_client: Optional[OpenAI] = None
# Model identifier, read once at import time from MODEL_NAME.
_cached_model = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")


def _maybe_get_client() -> Optional[OpenAI]:
    """Return the shared OpenAI-compatible client, building it on first use.

    Returns:
        The cached client, or ``None`` when ``API_BASE_URL`` is unset or
        blank. In that case a warning is written to stderr and nothing is
        cached, so a later call re-reads the environment.
    """
    global _cached_client

    if _cached_client is None:
        endpoint = os.environ.get("API_BASE_URL", "").strip()
        if not endpoint:
            print(
                "[WARN] API_BASE_URL is not configured; AnalystAgent will use heuristic mode.",
                file=sys.stderr,
            )
            return None

        # HF_TOKEN wins over API_KEY; a placeholder is used when neither is set.
        token = os.environ.get("HF_TOKEN") or os.environ.get("API_KEY") or "hf-missing-token"
        _cached_client = OpenAI(base_url=endpoint, api_key=token)

    return _cached_client
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class AnalystAgent:
    """
    Lightweight pharmacovigilance agent for demos and smoke testing.

    The agent can call an OpenAI-compatible chat endpoint when configured, but
    it also has a deterministic fallback policy for offline or local use.

    Attributes:
        review_memory: Records of low-reward decisions appended by ``learn``;
            the three most recent are replayed into the prompt.
    """

    def __init__(self) -> None:
        self.review_memory: list[dict] = []

    @staticmethod
    def _extract_json(raw: str) -> object:
        """Parse a model reply as JSON, tolerating fences or surrounding prose.

        The reply is first parsed as-is. If that fails, the span from the
        first ``{`` to the last ``}`` is parsed instead, which covers replies
        wrapped in Markdown code fences or short explanatory text.

        Raises:
            json.JSONDecodeError: If no parseable JSON object can be found.
        """
        text = raw.strip()
        try:
            return json.loads(text)
        except json.JSONDecodeError:
            start, end = text.find("{"), text.rfind("}")
            if start == -1 or end <= start:
                raise
            return json.loads(text[start:end + 1])

    def _case_snapshot(self, observation) -> str:
        """Render the observation (and recent mistakes) as prompt text.

        Args:
            observation: Object exposing ``task_id``, ``reports`` and
                ``drug_interaction_db``.

        Returns:
            A text block with one line per report, the knowledge base as
            indented JSON, and, when ``review_memory`` is non-empty, a list of
            up to three recent low-reward decisions.
        """
        report_lines = []
        for report in observation.reports:
            # Co-medications are required to reason about interactions, so the
            # full drug list is shown and not only the reporter's suspect drug.
            drugs = "; ".join(str(drug) for drug in (getattr(report, "drugs", None) or []))
            report_lines.append(
                f"- {report.report_id}: suspect={report.suspect_drug}, "
                f"drugs={drugs or 'not listed'}, "
                f"reaction={report.reaction}, onset_days={report.onset_days}, "
                f"severity={report.severity}, outcome={report.outcome}, "
                f"similar_30d={report.similar_reports_last_30d}"
            )

        memory_block = ""
        if self.review_memory:
            memory_block = "\nRecent mistakes to avoid:\n" + "".join(
                f"- On {item['task_id']} you underperformed after choosing "
                f"{item['classification']} / {item['recommended_action']}. "
                f"Reason note: {item['note']}\n"
                for item in self.review_memory[-3:]
            )

        knowledge_base = json.dumps(observation.drug_interaction_db, ensure_ascii=True, indent=2)
        return (
            f"Task id: {observation.task_id}\n"
            "Reports:\n" + "\n".join(report_lines) + "\n"
            f"Knowledge base:\n{knowledge_base}"
            f"{memory_block}"
        )

    def _build_prompt(self, observation) -> str:
        """Build the single user message sent to the chat endpoint."""
        return f"""You are a pharmacovigilance case assessor.

Review the case below and return one JSON object only.

Return fields:
- classification: one of new_signal, known_side_effect, noise, duplicate
- suspect_drug: likely causal drug or interaction
- severity_assessment: one of mild, moderate, severe, critical
- recommended_action: one of escalate, log_and_monitor, dismiss, request_more_info
- reasoning: concise mechanistic explanation

Decision principles:
- Repeated known labeled reactions should usually be known_side_effect
- Small but coherent post-marketing clusters on a newer drug can justify new_signal
- If the reporter blames the wrong medication, prefer the stronger causal interaction
- Missing a serious signal is worse than overcalling a weak case

Case:
{self._case_snapshot(observation)}
"""

    def _llm_decision(self, observation) -> Optional[Action]:
        """Ask the configured chat endpoint for a decision.

        Returns:
            The parsed ``Action``, or ``None`` when no client is configured or
            when the request, JSON parsing, or ``Action`` construction fails
            (a warning is written to stderr in the failure case).
        """
        client = _maybe_get_client()
        if client is None:
            return None

        # Deliberately broad: any transport, parsing or validation failure
        # should degrade to the heuristic policy instead of aborting the run.
        try:
            response = client.chat.completions.create(
                model=_cached_model,
                messages=[{"role": "user", "content": self._build_prompt(observation)}],
                temperature=0.0,
                max_tokens=220,
            )
            raw = (response.choices[0].message.content or "").strip()
            payload = self._extract_json(raw)
            return Action(**payload)
        except Exception as exc:
            print(f"[WARN] AnalystAgent LLM path failed: {exc}; falling back to heuristics.", file=sys.stderr)
            return None

    def _heuristic_decision(self, observation) -> Action:
        """Return a deterministic decision from keyword rules.

        Rules are checked in order: dry cough with an ACE-inhibitor entry in
        the knowledge base; three or more reports mentioning bradycardia or
        syncope; a knowledge base mentioning both tacrolimus and voriconazole.
        Otherwise the first report's suspect drug and severity are echoed with
        ``request_more_info``; with no reports at all, ``"unknown"`` and
        ``"moderate"`` are used.
        """
        reports = observation.reports
        reaction_blob = " ".join(item.reaction.lower() for item in reports)
        db_blob = json.dumps(observation.drug_interaction_db).lower()

        if "dry cough" in reaction_blob and "ace inhibitor" in db_blob:
            return Action(
                classification="known_side_effect",
                suspect_drug="Lisinopril",
                severity_assessment="mild",
                recommended_action="log_and_monitor",
                reasoning="Persistent dry cough is a classic labeled ACE inhibitor effect.",
            )

        if len(reports) >= 3 and ("brady" in reaction_blob or "syncope" in reaction_blob):
            return Action(
                classification="new_signal",
                suspect_drug="Cardiovexa",
                severity_assessment="severe",
                recommended_action="escalate",
                reasoning="A coherent cluster of bradycardia reports on a recently launched drug warrants escalation.",
            )

        if "tacrolimus" in db_blob and "voriconazole" in db_blob:
            return Action(
                classification="new_signal",
                suspect_drug="Tacrolimus+Voriconazole",
                severity_assessment="critical",
                recommended_action="escalate",
                reasoning="This looks like a tacrolimus exposure interaction requiring urgent escalation and level review.",
            )

        # Only the fallback needs a concrete report, so an empty report list
        # is handled here instead of failing before the rules above can run.
        if reports:
            suspect = reports[0].suspect_drug
            reported_severity = reports[0].severity
        else:
            suspect = "unknown"
            reported_severity = None

        fallback_severity = (
            reported_severity
            if reported_severity in {"mild", "moderate", "severe", "critical"}
            else "moderate"
        )
        return Action(
            classification="new_signal",
            suspect_drug=suspect,
            severity_assessment=fallback_severity,
            recommended_action="request_more_info",
            reasoning="The case is ambiguous, so additional information is needed before final triage.",
        )

    def act(self, observation) -> Action:
        """Return the LLM decision when available, else the heuristic one."""
        llm_action = self._llm_decision(observation)
        if llm_action is not None:
            return llm_action
        return self._heuristic_decision(observation)

    def learn(self, action: Action, observation) -> None:
        """Remember a decision whose reward was below 0.5.

        Args:
            action: The action that was submitted.
            observation: Result object; ``reward``, ``task_id`` and
                ``feedback`` are read with ``getattr``, and a missing or
                ``None`` reward counts as ``0.0``.
        """
        reward = getattr(observation, "reward", 0.0)
        if reward is None:
            reward = 0.0

        if reward < 0.5:
            self.review_memory.append(
                {
                    "task_id": getattr(observation, "task_id", "unknown"),
                    "classification": action.classification,
                    "recommended_action": action.recommended_action,
                    "note": getattr(observation, "feedback", "") or "weak outcome",
                }
            )
|
client.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""Pharmacovigilance Signal Detector Environment Client."""
|
| 8 |
+
|
| 9 |
+
from typing import Dict
|
| 10 |
+
|
| 11 |
+
from openenv.core import EnvClient
|
| 12 |
+
from openenv.core.client_types import StepResult
|
| 13 |
+
from openenv.core.env_server.types import State
|
| 14 |
+
|
| 15 |
+
try:
|
| 16 |
+
from .env import Action, Observation, AdverseEventReport
|
| 17 |
+
except ImportError:
|
| 18 |
+
from env import Action, Observation, AdverseEventReport
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class PharmaVigilanceEnvClient(
|
| 22 |
+
EnvClient[Action, Observation, State]
|
| 23 |
+
):
|
| 24 |
+
"""
|
| 25 |
+
Client for the Pharmacovigilance Signal Detector environment.
|
| 26 |
+
|
| 27 |
+
This client maintains a persistent connection to the environment server and
|
| 28 |
+
parses server responses into strongly-typed observation models.
|
| 29 |
+
|
| 30 |
+
Example:
|
| 31 |
+
>>> with PharmaVigilanceEnvClient(base_url="http://localhost:7860") as env:
|
| 32 |
+
... result = env.reset(task_id="known_signal_easy")
|
| 33 |
+
... print(result.observation.task_id)
|
| 34 |
+
...
|
| 35 |
+
... action = Action(
|
| 36 |
+
... classification="known_side_effect",
|
| 37 |
+
... suspect_drug="Ibuprofen",
|
| 38 |
+
... severity_assessment="moderate",
|
| 39 |
+
... recommended_action="log_and_monitor",
|
| 40 |
+
... reasoning="GI bleeding is a known ibuprofen adverse effect.",
|
| 41 |
+
... )
|
| 42 |
+
... result = env.step(action)
|
| 43 |
+
... print(result.observation.feedback)
|
| 44 |
+
... print(result.reward)
|
| 45 |
+
|
| 46 |
+
Example with Docker:
|
| 47 |
+
>>> client = PharmaVigilanceEnvClient.from_docker_image("pharmacovigilance-env:latest")
|
| 48 |
+
>>> try:
|
| 49 |
+
... result = client.reset(task_id="cluster_signal_medium")
|
| 50 |
+
... action = Action(
|
| 51 |
+
... classification="new_signal",
|
| 52 |
+
... suspect_drug="Gliptozin",
|
| 53 |
+
... severity_assessment="severe",
|
| 54 |
+
... recommended_action="escalate",
|
| 55 |
+
... reasoning="Clustered vision loss on a new drug warrants escalation.",
|
| 56 |
+
... )
|
| 57 |
+
... result = client.step(action)
|
| 58 |
+
... finally:
|
| 59 |
+
... client.close()
|
| 60 |
+
"""
|
| 61 |
+
|
| 62 |
+
def _step_payload(self, action: Action) -> Dict:
|
| 63 |
+
"""
|
| 64 |
+
Convert an Action model into the JSON payload sent to /step.
|
| 65 |
+
|
| 66 |
+
Args:
|
| 67 |
+
action: Typed agent action.
|
| 68 |
+
|
| 69 |
+
Returns:
|
| 70 |
+
Dictionary representation suitable for JSON transport.
|
| 71 |
+
"""
|
| 72 |
+
return {
|
| 73 |
+
"classification": action.classification,
|
| 74 |
+
"suspect_drug": action.suspect_drug,
|
| 75 |
+
"severity_assessment": action.severity_assessment,
|
| 76 |
+
"recommended_action": action.recommended_action,
|
| 77 |
+
"reasoning": action.reasoning,
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
def _parse_result(self, payload: Dict) -> StepResult[Observation]:
|
| 81 |
+
"""
|
| 82 |
+
Parse a server /step response into StepResult[Observation].
|
| 83 |
+
|
| 84 |
+
Args:
|
| 85 |
+
payload: JSON response from the environment server.
|
| 86 |
+
|
| 87 |
+
Returns:
|
| 88 |
+
StepResult containing the typed observation, reward, and done flag.
|
| 89 |
+
"""
|
| 90 |
+
obs_data = payload.get("observation", {})
|
| 91 |
+
reports = [
|
| 92 |
+
AdverseEventReport(**report)
|
| 93 |
+
for report in obs_data.get("reports", [])
|
| 94 |
+
]
|
| 95 |
+
|
| 96 |
+
observation = Observation(
|
| 97 |
+
task_id=obs_data.get("task_id", ""),
|
| 98 |
+
reports=reports,
|
| 99 |
+
drug_interaction_db=obs_data.get("drug_interaction_db", {}),
|
| 100 |
+
step_number=obs_data.get("step_number", 0),
|
| 101 |
+
max_steps=obs_data.get("max_steps", 1),
|
| 102 |
+
feedback=obs_data.get("feedback"),
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
reward_payload = payload.get("reward", 0.0)
|
| 106 |
+
reward_total = (
|
| 107 |
+
reward_payload.get("total", 0.0)
|
| 108 |
+
if isinstance(reward_payload, dict)
|
| 109 |
+
else reward_payload
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
return StepResult(
|
| 113 |
+
observation=observation,
|
| 114 |
+
reward=reward_total,
|
| 115 |
+
done=payload.get("done", False),
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
def _parse_state(self, payload: Dict) -> State:
    """
    Parse the /state response into an OpenEnv State object.

    Args:
        payload: JSON response from the state endpoint.

    Returns:
        State with a task-derived episode identifier and current step count.
    """
    # dict.get(key, default) only falls back when the key is absent. A state
    # payload can carry an explicit null task_id (env.py reports None until
    # reset() is called), so use `or` to apply the fallback to nulls as well.
    return State(
        episode_id=payload.get("task_id") or "pharma-vigilance",
        step_count=payload.get("step_number") or 0,
    )
|
data.py
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Synthetic task fixtures keyed by task id. Each entry holds the adverse-event
# "reports", the expected "ground_truth" answer, and the "drug_interaction_db"
# reference data for that task.
TASK_DATA = {
    # Easy: one report of dry cough on Lisinopril; "dry cough" is listed in the
    # drug's known_reactions and the expected answer is known_side_effect.
    "known_signal_easy": {
        "reports": [
            {
                "report_id": "PV-EASY-001",
                "patient_age": 59,
                "patient_sex": "female",
                "drugs": ["Lisinopril 20mg"],
                "suspect_drug": "Lisinopril",
                "reaction": "Persistent dry cough",
                "onset_days": 11,
                "severity": "mild",
                "outcome": "not_recovered",
                "similar_reports_last_30d": 1264,
            }
        ],
        "ground_truth": {
            "classification": "known_side_effect",
            "suspect_drug": "Lisinopril",
            "severity_assessment": "mild",
            "recommended_action": "log_and_monitor",
        },
        "drug_interaction_db": {
            "Lisinopril": {
                "known_reactions": ["dry cough", "hyperkalemia", "angioedema"],
                "class_note": "ACE inhibitors frequently cause persistent non-productive cough.",
            }
        },
    },
    # Medium: four Cardiovexa reports describing bradycardia-type reactions,
    # while the reference entry lists no rhythm effects; expected answer is
    # new_signal with escalation.
    "cluster_signal_medium": {
        "reports": [
            {
                "report_id": "PV-MED-001",
                "patient_age": 44,
                "patient_sex": "female",
                "drugs": ["Cardiovexa"],
                "suspect_drug": "Cardiovexa",
                "reaction": "symptomatic bradycardia with dizziness",
                "onset_days": 9,
                "severity": "moderate",
                "outcome": "not_recovered",
                "similar_reports_last_30d": 5,
            },
            {
                "report_id": "PV-MED-002",
                "patient_age": 69,
                "patient_sex": "male",
                "drugs": ["Cardiovexa"],
                "suspect_drug": "Cardiovexa",
                "reaction": "heart rate 32 with near-syncope",
                "onset_days": 13,
                "severity": "severe",
                "outcome": "not_recovered",
                "similar_reports_last_30d": 5,
            },
            {
                "report_id": "PV-MED-003",
                "patient_age": 57,
                "patient_sex": "female",
                "drugs": ["Cardiovexa"],
                "suspect_drug": "Cardiovexa",
                "reaction": "fatigue and sinus bradycardia",
                "onset_days": 7,
                "severity": "moderate",
                "outcome": "not_recovered",
                "similar_reports_last_30d": 5,
            },
            {
                "report_id": "PV-MED-004",
                "patient_age": 63,
                "patient_sex": "male",
                "drugs": ["Cardiovexa"],
                "suspect_drug": "Cardiovexa",
                "reaction": "bradyarrhythmia requiring ER evaluation",
                "onset_days": 11,
                "severity": "severe",
                "outcome": "not_recovered",
                "similar_reports_last_30d": 5,
            },
        ],
        "ground_truth": {
            "classification": "new_signal",
            "suspect_drug": "Cardiovexa",
            "severity_assessment": "severe",
            "recommended_action": "escalate",
        },
        "drug_interaction_db": {
            "Cardiovexa": {
                "known_reactions": ["headache", "fatigue"],
                "approval_date": "5 months ago",
                "label_note": "No labeled conduction or rhythm adverse effects.",
            }
        },
    },
    # Hard: the report's own suspect_drug (Trimethoprim-sulfamethoxazole)
    # differs from the ground-truth suspect ("Tacrolimus+Voriconazole"); the
    # reference data describes the Voriconazole/Tacrolimus interaction.
    "confounded_hard": {
        "reports": [
            {
                "report_id": "PV-HARD-001",
                "patient_age": 63,
                "patient_sex": "male",
                "drugs": [
                    "Tacrolimus",
                    "Prednisone",
                    "Amlodipine",
                    "Magnesium oxide",
                    "Voriconazole",
                    "Trimethoprim-sulfamethoxazole",
                ],
                "suspect_drug": "Trimethoprim-sulfamethoxazole",
                "reaction": "Acute kidney injury with tacrolimus trough 4x baseline",
                "onset_days": 6,
                "severity": "critical",
                "outcome": "not_recovered",
                "similar_reports_last_30d": 1,
            }
        ],
        "ground_truth": {
            "classification": "new_signal",
            "suspect_drug": "Tacrolimus+Voriconazole",
            "severity_assessment": "critical",
            "recommended_action": "escalate",
        },
        "drug_interaction_db": {
            "Voriconazole": {
                "strong_metabolic_inhibitor": True,
                "interacts_with": ["Tacrolimus", "Cyclosporine"],
                "interaction_note": "Markedly increases tacrolimus exposure; dose reduction and level monitoring required.",
            },
            "Tacrolimus": {
                "narrow_therapeutic_index": True,
                "known_reactions": ["nephrotoxicity", "tremor"],
                "requires_level_monitoring": True,
            },
        },
    },
}
|
env.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, List, Optional, Tuple
|
| 2 |
+
|
| 3 |
+
from pydantic import BaseModel, Field
|
| 4 |
+
|
| 5 |
+
from tasks import TaskDefinition, get_task
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class AdverseEventReport(BaseModel):
    # One adverse-event report inside an Observation. All fields are required.
    report_id: str
    patient_age: int
    patient_sex: str
    drugs: List[str]
    suspect_drug: str  # presumably the drug blamed by the reporter, not the verified cause; confirm against the task data
    reaction: str
    onset_days: int
    severity: str
    outcome: str
    similar_reports_last_30d: int
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class Observation(BaseModel):
    # Case bundle returned by PharmaVigilanceEnv.reset() and step().
    task_id: str
    reports: List[AdverseEventReport]
    drug_interaction_db: dict
    step_number: int
    max_steps: int
    feedback: Optional[str] = None  # human-readable status text; defaults to None
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class Action(BaseModel):
    # Agent decision submitted to PharmaVigilanceEnv.step(). All five fields
    # are required strings; no value whitelist is enforced at this layer.
    classification: str
    suspect_drug: str
    severity_assessment: str
    recommended_action: str
    reasoning: str
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class Reward(BaseModel):
    # Grader result. `total` is validated to lie within [0.0, 1.0].
    total: float = Field(..., ge=0.0, le=1.0)
    # Per-component scores; step() looks up the four scored action field names here.
    breakdown: dict
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class PharmaVigilanceEnv:
    """Single-decision pharmacovigilance environment.

    `reset()` loads a task and returns its case bundle, `step()` grades one
    action with the task's grader and always ends the episode, and `state()`
    exposes a small bookkeeping snapshot.
    """

    # Action fields counted towards the "matched_fields" entry of the step info.
    _SCORED_FIELDS = (
        "classification",
        "suspect_drug",
        "severity_assessment",
        "recommended_action",
    )

    def __init__(self):
        self.current_task: Optional[TaskDefinition] = None
        self.current_task_id: Optional[str] = None
        self.step_number = 0
        self.max_steps = 1
        self.last_action: Optional[dict] = None
        self.last_reward: Optional[dict] = None

    def _observe(self, feedback: Optional[str]) -> Observation:
        """Build an Observation of the loaded task with the given feedback."""
        task = self.current_task
        return Observation(
            task_id=task.task_id,
            reports=task.reports,
            drug_interaction_db=task.drug_interaction_db,
            step_number=self.step_number,
            max_steps=self.max_steps,
            feedback=feedback,
        )

    def reset(self, task_id: str = "known_signal_easy") -> Observation:
        """Load `task_id`, clear episode bookkeeping, and return the first observation."""
        task = get_task(task_id)
        self.current_task = task
        self.current_task_id = task.task_id
        self.step_number = 0
        self.last_action = None
        self.last_reward = None
        return self._observe(
            "Task loaded. Submit one final pharmacovigilance assessment."
        )

    def step(self, action: Action) -> Tuple[Observation, Reward, bool, Dict]:
        """Grade `action` and return (observation, reward, done, info).

        Raises:
            RuntimeError: if no task has been loaded via reset().
        """
        task = self.current_task
        if task is None:
            raise RuntimeError("Call reset() before step().")

        reward = task.action_grader(action)
        self.step_number += 1
        self.last_action = action.model_dump()
        self.last_reward = reward.model_dump()

        # A field counts as matched when its breakdown component is positive.
        matched = len(
            [
                name
                for name in self._SCORED_FIELDS
                if reward.breakdown.get(name, 0.0) > 0
            ]
        )

        total = reward.total
        if total >= 0.9:
            feedback = "Strong assessment. The key safety judgment and follow-up were correct."
        elif total >= 0.5:
            feedback = "Partially correct assessment. Some causal or operational details were missed."
        else:
            feedback = "Weak assessment. This case would need human analyst correction."

        observation = self._observe(feedback)
        info = {
            "matched_fields": matched,
            "difficulty": task.difficulty,
            "reward_breakdown": reward.breakdown,
        }
        # Every episode is a single decision, so the step always terminates it.
        return observation, reward, True, info

    def state(self) -> dict:
        """Return the current task id, step count, and last action/reward dumps."""
        return {
            "task_id": self.current_task_id,
            "step_number": self.step_number,
            "last_action": self.last_action,
            "last_reward": self.last_reward,
        }
|
graders.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Public grader entrypoints for OpenEnv validation and judging."""
|
| 2 |
+
|
| 3 |
+
from server.graders import (
|
| 4 |
+
cluster_signal_medium_grader,
|
| 5 |
+
confounded_hard_grader,
|
| 6 |
+
easy_grader,
|
| 7 |
+
hard_grader,
|
| 8 |
+
known_signal_easy_grader,
|
| 9 |
+
medium_grader,
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
# Grader callables keyed by task id; the ids match the task entries in
# openenv.yaml.
TASK_TO_GRADER = {
    "known_signal_easy": known_signal_easy_grader,
    "cluster_signal_medium": cluster_signal_medium_grader,
    "confounded_hard": confounded_hard_grader,
}

# Grader callables keyed by difficulty tier.
TIER_TO_GRADER = {
    "easy": easy_grader,
    "medium": medium_grader,
    "hard": hard_grader,
}

# Public names re-exported from this module.
__all__ = [
    "TASK_TO_GRADER",
    "TIER_TO_GRADER",
    "easy_grader",
    "medium_grader",
    "hard_grader",
    "known_signal_easy_grader",
    "cluster_signal_medium_grader",
    "confounded_hard_grader",
]
|
inference.py
ADDED
|
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Baseline runner for the Pharmacovigilance Signal Detector submission.
|
| 3 |
+
|
| 4 |
+
This script queries a chat model through the OpenAI client, sends its decision
|
| 5 |
+
to the environment server, and prints the exact machine-readable lines expected
|
| 6 |
+
by the evaluator.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import argparse
|
| 10 |
+
import json
|
| 11 |
+
import os
|
| 12 |
+
from typing import Iterable, List
|
| 13 |
+
|
| 14 |
+
import requests
|
| 15 |
+
from openai import OpenAI
|
| 16 |
+
from pydantic import ValidationError
|
| 17 |
+
|
| 18 |
+
try:
|
| 19 |
+
from .models import PharmaAction
|
| 20 |
+
except ImportError:
|
| 21 |
+
from models import PharmaAction
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# Runtime configuration, read from environment variables once at import time.
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
# HF_TOKEN wins when set and non-empty; API_KEY is the fallback credential.
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
# Trailing slashes are stripped so endpoint paths can be appended as "/reset", "/step".
ENV_URL = os.getenv("ENV_URL", "http://localhost:7860").rstrip("/")
# When non-empty, choose_tasks() runs only this task regardless of --difficulty.
TASK_OVERRIDE = os.getenv("TASK_NAME", "").strip()
# Environment name printed in the [START] line.
BENCHMARK = "pharma-vigilance"

# Task ids to run for each --difficulty choice.
TASK_SETS = {
    "easy": ("known_signal_easy",),
    "medium": ("cluster_signal_medium",),
    "hard": ("confounded_hard",),
    "all": ("known_signal_easy", "cluster_signal_medium", "confounded_hard"),
}

# System prompt sent with every request; it restricts the reply to a single
# JSON object using the action keys and allowed values listed below.
SYSTEM_MESSAGE = """
You are acting as a pharmacovigilance triage analyst.

Read the synthetic case bundle and reply with exactly one JSON object.
Allowed keys:
- classification
- suspect_drug
- severity_assessment
- recommended_action
- reasoning

Allowed values:
- classification: new_signal, known_side_effect, noise, duplicate
- severity_assessment: mild, moderate, severe, critical
- recommended_action: escalate, log_and_monitor, dismiss, request_more_info

No markdown. No explanation outside the JSON object.
""".strip()
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def emit_start(task_name: str) -> None:
    """Print the machine-readable [START] line for one task run."""
    line = "[START] task={} env={} model={}".format(task_name, BENCHMARK, MODEL_NAME)
    print(line, flush=True)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def emit_step(step_no: int, action_text: str, reward: float, done: bool, error: str | None) -> None:
|
| 63 |
+
error_text = error if error else "null"
|
| 64 |
+
print(
|
| 65 |
+
f"[STEP] step={step_no} action={action_text} reward={reward:.2f} "
|
| 66 |
+
f"done={str(done).lower()} error={error_text}",
|
| 67 |
+
flush=True,
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def emit_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the machine-readable [END] line summarising one task run."""
    formatted_rewards = [format(value, ".2f") for value in rewards]
    summary = (
        "[END] success=" + str(success).lower()
        + " steps=" + format(steps)
        + " score=" + format(score, ".2f")
        + " rewards=" + ",".join(formatted_rewards)
    )
    print(summary, flush=True)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def choose_tasks(selection: str) -> Iterable[str]:
    """Return the task ids to run: the TASK_NAME override if set, else the chosen set."""
    return (TASK_OVERRIDE,) if TASK_OVERRIDE else TASK_SETS[selection]
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def client() -> OpenAI:
    """Create the OpenAI-compatible client.

    Raises:
        EnvironmentError: if neither HF_TOKEN nor API_KEY provided a credential.
    """
    if HF_TOKEN:
        return OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
    raise EnvironmentError("HF_TOKEN or API_KEY must be set before running inference.py")
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def fetch_reset(task_name: str) -> dict:
    """POST to the environment's /reset endpoint and return the decoded JSON reply.

    Raises whatever `raise_for_status()` raises on an HTTP error status.
    """
    endpoint = f"{ENV_URL}/reset"
    body = {"task_id": task_name}
    reply = requests.post(endpoint, json=body, timeout=30)
    reply.raise_for_status()
    return reply.json()
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def submit_action(action: PharmaAction) -> dict:
    """POST the dumped action to the environment's /step endpoint and return the JSON reply.

    Raises whatever `raise_for_status()` raises on an HTTP error status.
    """
    endpoint = f"{ENV_URL}/step"
    body = action.model_dump()
    reply = requests.post(endpoint, json=body, timeout=30)
    reply.raise_for_status()
    return reply.json()
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def prompt_for_case(observation: dict) -> str:
    """Build the user prompt: instructions, the observation as indented JSON, then guidance."""
    case_json = json.dumps(observation, ensure_ascii=True, indent=2)
    guidance = (
        "Focus on whether the case is novel or known, the most plausible causal "
        "drug or interaction, the right severity band, and the operational next step."
    )
    sections = [
        "Assess the following pharmacovigilance case.",
        "Return one final structured judgment.",
        case_json,
        guidance,
    ]
    return "\n\n".join(sections)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def _parse_model_json(text: str):
    """Decode a model reply as JSON, tolerating markdown fences or stray prose.

    The reply is first parsed as-is. If that fails, the span from the first
    "{" to the last "}" is parsed instead, which recovers replies wrapped in
    a fenced code block or surrounded by commentary.

    Raises:
        json.JSONDecodeError: if no JSON value can be recovered.
    """
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        start = text.find("{")
        end = text.rfind("}")
        if start == -1 or end <= start:
            # Nothing that looks like an object: surface the original error.
            raise
        return json.loads(text[start : end + 1])


def ask_model(llm: OpenAI, observation: dict) -> PharmaAction:
    """Ask the chat model for a decision on `observation` and return it as a PharmaAction.

    Args:
        llm: OpenAI-compatible client.
        observation: Case bundle returned by the environment's /reset endpoint.

    Returns:
        The validated action built from the model's JSON reply.

    Raises:
        json.JSONDecodeError: if the reply contains no parseable JSON.
        pydantic.ValidationError: if the JSON object does not satisfy PharmaAction.
    """
    completion = llm.chat.completions.create(
        model=MODEL_NAME,
        messages=[
            {"role": "system", "content": SYSTEM_MESSAGE},
            {"role": "user", "content": prompt_for_case(observation)},
        ],
        temperature=0.0,
        max_tokens=260,
        stream=False,
    )
    # The message content may be None; treat that as an empty reply.
    text = (completion.choices[0].message.content or "").strip()
    # Models sometimes wrap the object in markdown fences despite the system
    # prompt, so fall back to the outermost {...} span before giving up.
    payload = _parse_model_json(text)
    return PharmaAction(**payload)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def compact_action(action: PharmaAction) -> str:
    """Return "classification/suspect_drug", or just the classification when no drug is set."""
    classification = action.classification
    drug = action.suspect_drug
    if not drug:
        return classification
    return "{}/{}".format(classification, drug)
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def final_score(rewards: List[float]) -> float:
    """Return the mean reward rounded to 4 decimals and clamped to [0.0, 1.0]; 0.0 if empty."""
    if not rewards:
        return 0.0
    mean = round(sum(rewards) / len(rewards), 4)
    if mean < 0.0:
        return 0.0
    if mean > 1.0:
        return 1.0
    return mean
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def run_one_task(llm: OpenAI, task_name: str) -> None:
    """Run one single-step episode for `task_name` and print its [START]/[STEP]/[END] lines.

    Any failure is reported as a zero-reward step: "parse_error" for invalid
    JSON from the model, "schema_error" for a reply that fails validation, and
    "error" with the exception text for everything else. The [END] line is
    always printed.
    """
    rewards: List[float] = []
    steps_taken = 0
    score = 0.0
    success = False

    emit_start(task_name)

    try:
        case = fetch_reset(task_name)
        decision = ask_model(llm, case)
        label = compact_action(decision)

        outcome = submit_action(decision)
        # The reward may be a dict with a "total" key or a bare number.
        raw_reward = outcome.get("reward", {})
        if isinstance(raw_reward, dict):
            reward = float(raw_reward.get("total", 0.0))
        else:
            reward = float(raw_reward)
        finished = bool(outcome.get("done", False))

        rewards.append(reward)
        steps_taken = 1
        emit_step(1, label, reward, finished, None)

        score = final_score(rewards)
        success = score >= 0.75
    except Exception as exc:
        # Checked in the same order as the specific handlers they replace.
        if isinstance(exc, json.JSONDecodeError):
            tag, detail = "parse_error", "parse_error"
        elif isinstance(exc, ValidationError):
            tag, detail = "schema_error", "schema_error"
        else:
            tag, detail = "error", str(exc)
        rewards = [0.0]
        steps_taken = 1
        emit_step(1, tag, 0.0, True, detail)
    finally:
        emit_end(success, steps_taken, score, rewards or [0.0])
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def main() -> None:
    """Parse --difficulty, build the model client, and run each selected task in turn."""
    cli = argparse.ArgumentParser(description="Run the pharmacovigilance baseline agent")
    cli.add_argument(
        "--difficulty",
        choices=["easy", "medium", "hard", "all"],
        default="all",
        help="Which task subset to run",
    )
    selection = cli.parse_args().difficulty

    llm = client()
    for task_name in choose_tasks(selection):
        run_one_task(llm, task_name)
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
if __name__ == "__main__":
|
| 213 |
+
main()
|
models.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Optional
|
| 2 |
+
|
| 3 |
+
from openenv.core.env_server.types import Action, Observation
|
| 4 |
+
from pydantic import BaseModel, ConfigDict, Field
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class AdverseEventReport(BaseModel):
    # One synthetic adverse-event report. Every field except `drugs` is
    # required; `drugs` defaults to an empty list.
    # revalidate_instances="never": existing instances are accepted as-is
    # rather than re-validated when nested inside another model.
    model_config = ConfigDict(revalidate_instances="never")

    report_id: str = Field(..., description="Unique synthetic report identifier")
    patient_age: int = Field(..., description="Patient age in years")
    patient_sex: str = Field(..., description="Patient sex")
    drugs: List[str] = Field(default_factory=list, description="All drugs the patient was taking")
    suspect_drug: str = Field(..., description="Drug named as suspect by the original reporter")
    reaction: str = Field(..., description="Observed adverse reaction")
    onset_days: int = Field(..., description="Days from drug start to reaction onset")
    severity: str = Field(..., description="Reported case severity")
    outcome: str = Field(..., description="Clinical outcome status")
    similar_reports_last_30d: int = Field(..., description="Count of similar reports in the last 30 days")
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class PharmaObservation(Observation):
    # Observation payload for the pharmacovigilance environment.
    # Only task_id is required; every other field has a default.
    model_config = ConfigDict(
        extra="forbid",  # unknown fields are rejected
        validate_assignment=True,  # attribute assignments are validated too
        arbitrary_types_allowed=True,
        revalidate_instances="never",
    )

    task_id: str = Field(..., description="Current task identifier")
    reports: List[AdverseEventReport] = Field(default_factory=list, description="Synthetic adverse event reports")
    drug_interaction_db: dict = Field(default_factory=dict, description="Hardcoded interaction and safety lookup")
    step_number: int = Field(default=0, description="Current step number")
    max_steps: int = Field(default=1, description="Maximum number of steps in the episode")
    feedback: Optional[str] = Field(default=None, description="Feedback after the previous action")

    # NOTE(review): these three presumably mirror or override fields of the
    # openenv base Observation - confirm against openenv.core.env_server.types.
    reward: float = Field(default=0.0, description="Reward from the last action")
    done: bool = Field(default=False, description="Episode termination flag")
    metadata: dict = Field(default_factory=dict, description="Additional environment metadata")
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class PharmaAction(Action):
    # Agent decision for one case. The allowed values appear only in the field
    # descriptions; the fields are plain strings, so no whitelist is enforced here.
    classification: str = Field(..., description="new_signal | known_side_effect | noise | duplicate")
    suspect_drug: str = Field(..., description="Drug or interaction believed to be causal")
    severity_assessment: str = Field(..., description="mild | moderate | severe | critical")
    recommended_action: str = Field(..., description="escalate | log_and_monitor | dismiss | request_more_info")
    # Optional: defaults to an empty string when omitted.
    reasoning: str = Field(default="", description="Short explanation of the decision")
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class PharmaReward(BaseModel):
    # Reward payload. The 0.0-1.0 range is stated in the description only;
    # this model declares no ge/le constraint on `total`.
    total: float = Field(..., description="Total reward in the 0.0-1.0 range")
    breakdown: dict = Field(default_factory=dict, description="Per-component reward breakdown")
|
openenv.yaml
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
spec_version: 1
|
| 2 |
+
name: pharma_vigilance_env
|
| 3 |
+
display_name: "Pharmacovigilance Signal Detector"
|
| 4 |
+
description: >
|
| 5 |
+
A real-world OpenEnv environment where an AI agent acts as a pharmacovigilance
|
| 6 |
+
analyst. The agent reviews synthetic adverse-event cases, decides whether they
|
| 7 |
+
represent known labeled effects, emerging safety signals, or low-value noise,
|
| 8 |
+
and recommends the correct operational follow-up. Tasks cover known class
|
| 9 |
+
effects, clustered post-marketing signal detection, and confounded
|
| 10 |
+
drug-drug-interaction cases that require causal reasoning rather than surface
|
| 11 |
+
blame assignment.
|
| 12 |
+
type: space
|
| 13 |
+
runtime: fastapi
|
| 14 |
+
app: server.app:app
|
| 15 |
+
port: 7860
|
| 16 |
+
tags:
|
| 17 |
+
- openenv
|
| 18 |
+
- healthcare
|
| 19 |
+
- pharmacovigilance
|
| 20 |
+
- drug-safety
|
| 21 |
+
- reinforcement-learning
|
| 22 |
+
|
| 23 |
+
action_space:
|
| 24 |
+
type: structured
|
| 25 |
+
fields:
|
| 26 |
+
- name: classification
|
| 27 |
+
type: string
|
| 28 |
+
values: [new_signal, known_side_effect, noise, duplicate]
|
| 29 |
+
description: "Top-level safety classification chosen by the agent"
|
| 30 |
+
- name: suspect_drug
|
| 31 |
+
type: string
|
| 32 |
+
description: "Drug or drug interaction the agent believes is causally responsible"
|
| 33 |
+
- name: severity_assessment
|
| 34 |
+
type: string
|
| 35 |
+
values: [mild, moderate, severe, critical]
|
| 36 |
+
description: "Agent-assigned clinical severity for the case"
|
| 37 |
+
- name: recommended_action
|
| 38 |
+
type: string
|
| 39 |
+
values: [escalate, log_and_monitor, dismiss, request_more_info]
|
| 40 |
+
description: "Operational pharmacovigilance follow-up decision"
|
| 41 |
+
- name: reasoning
|
| 42 |
+
type: string
|
| 43 |
+
description: "Brief free-text rationale used for partial credit on the hard task"
|
| 44 |
+
|
| 45 |
+
observation_space:
|
| 46 |
+
type: structured
|
| 47 |
+
fields:
|
| 48 |
+
- name: task_id
|
| 49 |
+
type: string
|
| 50 |
+
description: "Identifier of the current pharmacovigilance task"
|
| 51 |
+
- name: reports
|
| 52 |
+
type: array
|
| 53 |
+
description: "One or more synthetic adverse-event reports included in the case"
|
| 54 |
+
- name: drug_interaction_db
|
| 55 |
+
type: object
|
| 56 |
+
description: "Hardcoded safety and interaction reference data visible to the agent"
|
| 57 |
+
- name: step_number
|
| 58 |
+
type: integer
|
| 59 |
+
description: "Current step index within the episode"
|
| 60 |
+
- name: max_steps
|
| 61 |
+
type: integer
|
| 62 |
+
description: "Maximum number of steps allowed in the episode"
|
| 63 |
+
- name: feedback
|
| 64 |
+
type: string
|
| 65 |
+
required: false
|
| 66 |
+
description: "Human-readable feedback from the previous action"
|
| 67 |
+
|
| 68 |
+
reward:
|
| 69 |
+
min: 0.0
|
| 70 |
+
max: 1.0
|
| 71 |
+
description: >
|
| 72 |
+
Reward is built from four 0.25 components for classification correctness,
|
| 73 |
+
causal suspect selection, severity assessment, and recommended operational
|
| 74 |
+
action. A false alarm penalty of -0.10 applies when the agent escalates a
|
| 75 |
+
case that is truly noise, and a larger missed-signal penalty of -0.20
|
| 76 |
+
applies when the agent dismisses a true new signal. The hard task can earn
|
| 77 |
+
an additional +0.15 reasoning bonus when the explanation explicitly
|
| 78 |
+
references the interaction mechanism or therapeutic drug monitoring clues.
|
| 79 |
+
|
| 80 |
+
difficulties:
|
| 81 |
+
- easy
|
| 82 |
+
- medium
|
| 83 |
+
- hard
|
| 84 |
+
|
| 85 |
+
max_steps: 1
|
| 86 |
+
|
| 87 |
+
tasks:
|
| 88 |
+
- id: known_signal_easy
|
| 89 |
+
difficulty: easy
|
| 90 |
+
description: >
|
| 91 |
+
Review a synthetic single-patient report in which an ACE inhibitor is
|
| 92 |
+
followed by persistent dry cough and many similar recent cases already
|
| 93 |
+
exist. The correct behavior is to recognize this as a known labeled effect
|
| 94 |
+
and recommend routine logging and monitoring rather than escalation.
|
| 95 |
+
grader: graders.known_signal_easy_grader
|
| 96 |
+
|
| 97 |
+
- id: cluster_signal_medium
|
| 98 |
+
difficulty: medium
|
| 99 |
+
description: >
|
| 100 |
+
Review a clustered set of recent post-marketing reports tied to a newly
|
| 101 |
+
launched cardiovascular therapy. The reports show symptomatic bradycardia
|
| 102 |
+
and near-syncope despite the label lacking rhythm-related warnings. The
|
| 103 |
+
agent should detect an emerging signal and escalate.
|
| 104 |
+
grader: graders.cluster_signal_medium_grader
|
| 105 |
+
|
| 106 |
+
- id: confounded_hard
|
| 107 |
+
difficulty: hard
|
| 108 |
+
description: >
|
| 109 |
+
Review a confounded transplant-medicine case in which the reporter blames
|
| 110 |
+
the wrong drug. The correct judgment requires identifying a tacrolimus and
|
| 111 |
+
voriconazole interaction, recognizing acute kidney injury risk from toxic
|
| 112 |
+
exposure, and escalating the case as a clinically serious new signal.
|
| 113 |
+
grader: graders.confounded_hard_grader
|
pyproject.toml
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools>=68", "wheel"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "pharmacovigilance-env"
|
| 7 |
+
version = "1.0.0"
|
| 8 |
+
description = "Pharmacovigilance Signal Detector"
|
| 9 |
+
requires-python = ">=3.10"
|
| 10 |
+
dependencies = [
|
| 11 |
+
"fastapi",
|
| 12 |
+
"uvicorn",
|
| 13 |
+
"pydantic",
|
| 14 |
+
"openai",
|
| 15 |
+
"requests",
|
| 16 |
+
"openenv-core",
|
| 17 |
+
]
|
| 18 |
+
|
| 19 |
+
[project.scripts]
|
| 20 |
+
server = "server.app:main"
|
| 21 |
+
|
| 22 |
+
[tool.setuptools]
|
| 23 |
+
py-modules = ["env", "tasks", "data", "inference"]
|
| 24 |
+
packages = ["server"]
|
requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi
|
| 2 |
+
uvicorn
|
| 3 |
+
pydantic
|
| 4 |
+
openai
|
| 5 |
+
requests
|
| 6 |
+
openenv-core
|
run_demo.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from agent import RuleBasedPharmaAgent
|
| 2 |
+
from env import PharmaVigilanceEnv
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def main() -> None:
    """Run the rule-based agent once on each bundled task and print the result.

    For every task id the environment is reset, the agent produces a single
    action from the initial observation, and the outcome of stepping the
    environment with that action is printed to stdout.
    """
    environment = PharmaVigilanceEnv()
    policy = RuleBasedPharmaAgent()
    demo_task_ids = ("known_signal_easy", "cluster_signal_medium", "confounded_hard")

    for task_id in demo_task_ids:
        initial_observation = environment.reset(task_id)
        action = policy.act(initial_observation)
        # step() hands back a 4-tuple; the returned observation carries the feedback text.
        final_observation, reward, done, info = environment.step(action)

        print(f"\nTask: {task_id}")
        print(f"Action: {action.classification} / {action.suspect_drug}")
        print(f"Reward: {reward.total:.2f}")
        print(f"Done: {done}")
        print(f"Feedback: {final_observation.feedback}")
        print(f"Info: {info}")
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
if __name__ == "__main__":
|
| 23 |
+
main()
|
server.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import FastAPI
|
| 2 |
+
|
| 3 |
+
from env import Action, PharmaVigilanceEnv
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
app = FastAPI()
|
| 7 |
+
env = PharmaVigilanceEnv()
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@app.post("/reset")
def reset(body: dict | None = None):
    """Start a new episode and return the initial observation as a plain dict.

    Args:
        body: Optional JSON object. Its ``task_id`` key selects the task;
            when the body or the key is absent, ``"known_signal_easy"`` is used.

    Returns:
        The observation returned by ``env.reset`` serialised with ``model_dump()``.
    """
    # A ``{}`` default would be one shared dict object reused by every call;
    # default to None and normalise here instead.
    payload = body or {}
    task_id = payload.get("task_id", "known_signal_easy")
    obs = env.reset(task_id)
    return obs.model_dump()
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@app.post("/step")
def step(action: Action):
    """Apply one agent action to the shared environment.

    Returns:
        A dict with the keys ``observation``, ``reward``, ``done`` and ``info``;
        the observation and reward are serialised with ``model_dump()``.
    """
    next_observation, step_reward, finished, details = env.step(action)
    response = dict(
        observation=next_observation.model_dump(),
        reward=step_reward.model_dump(),
        done=finished,
        info=details,
    )
    return response
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@app.get("/state")
def state():
    """Return whatever ``env.state()`` reports for the shared environment."""
    snapshot = env.state()
    return snapshot
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@app.get("/tasks")
def list_tasks():
    """Return the ids of the three bundled tasks under the ``tasks`` key."""
    task_ids = ["known_signal_easy", "cluster_signal_medium", "confounded_hard"]
    return dict(tasks=task_ids)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@app.get("/health")
def health():
    """Liveness probe: always answers ``{"status": "ok"}``."""
    return dict(status="ok")
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def main(host: str = "0.0.0.0", port: int = 7860) -> None:
    """Serve ``app`` with uvicorn on the given host and port."""
    # Imported lazily so that importing this module does not require uvicorn.
    from uvicorn import run as serve

    serve(app, host=host, port=port)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
if __name__ == "__main__":
|
| 50 |
+
main()
|
server/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""Pharmacovigilance environment server components."""
|
| 8 |
+
|
| 9 |
+
from .pharma_vigilance_env_environment import PharmaVigilanceEnv
|
| 10 |
+
|
| 11 |
+
__all__ = ["PharmaVigilanceEnv"]
|
server/app.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import FastAPI
|
| 2 |
+
|
| 3 |
+
from env import Action, PharmaVigilanceEnv
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
app = FastAPI()
|
| 7 |
+
env = PharmaVigilanceEnv()
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@app.post("/reset")
def reset(body: dict | None = None):
    """Start a new episode and return the initial observation as a plain dict.

    Args:
        body: Optional JSON object. Its ``task_id`` key selects the task;
            when the body or the key is absent, ``"known_signal_easy"`` is used.

    Returns:
        The observation returned by ``env.reset`` serialised with ``model_dump()``.
    """
    # A ``{}`` default would be one shared dict object reused by every call;
    # default to None and normalise here instead.
    payload = body or {}
    task_id = payload.get("task_id", "known_signal_easy")
    obs = env.reset(task_id)
    return obs.model_dump()
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@app.post("/step")
def step(action: Action):
    """Apply one agent action to the shared environment.

    Returns:
        A dict with the keys ``observation``, ``reward``, ``done`` and ``info``;
        the observation and reward are serialised with ``model_dump()``.
    """
    next_observation, step_reward, finished, details = env.step(action)
    response = dict(
        observation=next_observation.model_dump(),
        reward=step_reward.model_dump(),
        done=finished,
        info=details,
    )
    return response
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@app.get("/state")
def state():
    """Return whatever ``env.state()`` reports for the shared environment."""
    snapshot = env.state()
    return snapshot
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@app.get("/tasks")
def list_tasks():
    """Return the ids of the three bundled tasks under the ``tasks`` key."""
    task_ids = ["known_signal_easy", "cluster_signal_medium", "confounded_hard"]
    return dict(tasks=task_ids)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@app.get("/health")
def health():
    """Liveness probe: always answers ``{"status": "ok"}``."""
    return dict(status="ok")
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def main(host: str = "0.0.0.0", port: int = 7860) -> None:
    """Serve ``app`` with uvicorn on the given host and port."""
    # Imported lazily so that importing this module does not require uvicorn.
    from uvicorn import run as serve

    serve(app, host=host, port=port)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
if __name__ == "__main__":
|
| 50 |
+
main()
|
server/graders.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Trajectory scorers for the Pharmacovigilance Signal Detector.
|
| 3 |
+
|
| 4 |
+
These functions are intentionally pharmacovigilance-specific rather than
|
| 5 |
+
generic "reward bucket" adapters. The scoring rubric emphasizes:
|
| 6 |
+
|
| 7 |
+
1. Signal sensitivity: missing a true novel safety signal is costly.
|
| 8 |
+
2. Operational judgment: escalation/log/dismiss choices matter independently.
|
| 9 |
+
3. Causal calibration: high scores should reflect not just suspicion, but
|
| 10 |
+
identifying the right drug or interaction.
|
| 11 |
+
|
| 12 |
+
All public grader outputs are clamped to the judge-safe closed interval [0.01, 0.99].
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
from typing import Any, Iterable, List
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# Inclusive lower and upper limits of every score a public grader may return.
STRICT_MIN = 0.01
STRICT_MAX = 0.99


def _bounded(value: float) -> float:
    """Round ``value`` to four decimals and clamp it to [STRICT_MIN, STRICT_MAX].

    NaN is mapped to ``STRICT_MIN``: every comparison with NaN is false, so a
    plain ``min(max(...))`` clamp would hand NaN back unchanged and let it
    escape the interval.
    """
    import math

    if math.isnan(value):
        return STRICT_MIN
    return min(max(round(value, 4), STRICT_MIN), STRICT_MAX)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def _as_reward_list(trajectory: dict | None) -> List[float]:
|
| 27 |
+
payload = trajectory or {}
|
| 28 |
+
|
| 29 |
+
rewards = payload.get("rewards")
|
| 30 |
+
if isinstance(rewards, list) and rewards:
|
| 31 |
+
return [float(item) for item in rewards]
|
| 32 |
+
|
| 33 |
+
if "score" in payload:
|
| 34 |
+
return [float(payload["score"])]
|
| 35 |
+
|
| 36 |
+
reward = payload.get("reward")
|
| 37 |
+
if isinstance(reward, dict) and "total" in reward:
|
| 38 |
+
return [float(reward["total"])]
|
| 39 |
+
if reward is not None:
|
| 40 |
+
return [float(reward)]
|
| 41 |
+
|
| 42 |
+
return []
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def _reward_profile(reward: float) -> str:
    """Map a step reward onto its pharmacovigilance interpretation bucket.

    The two lowest buckets include their upper limit; the remaining ones
    exclude it. Anything that falls through every tier is ``"expert_triage"``.
    """
    # (upper limit, whether the limit itself belongs to the bucket, label)
    tiers = (
        (0.05, True, "unsafe_miss"),
        (0.20, True, "bad_call"),
        (0.50, False, "weak_triage"),
        (0.80, False, "workable_triage"),
        (0.95, False, "strong_triage"),
    )
    for limit, inclusive, label in tiers:
        if reward < limit or (inclusive and reward == limit):
            return label
    return "expert_triage"
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def _mean(values: Iterable[float]) -> float:
    """Return the arithmetic mean of ``values``, or 0.5 when there are none."""
    # Materialise once so that one-shot iterators can be both counted and summed.
    samples = list(values)
    return sum(samples) / len(samples) if samples else 0.5
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def _score_episode(
    rewards: List[float],
    *,
    miss_cost: float,
    overcall_cost: float,
    stability_gain: float,
    expertise_gain: float,
) -> float:
    """Turn a list of step rewards into one bounded episode score.

    The score starts from the mean reward, subtracts capped penalties for
    steps bucketed as ``unsafe_miss``, ``bad_call`` and ``weak_triage``, and
    adds bonuses when enough steps are strong or expert.

    Args:
        rewards: Step rewards of the episode.
        miss_cost: Penalty per ``unsafe_miss`` step (total capped at 0.35).
        overcall_cost: Penalty per ``bad_call`` step (total capped at 0.15).
        stability_gain: Bonus when at least 80% of steps are strong or expert.
        expertise_gain: Bonus when at least 60% of steps are expert.

    Returns:
        The score after ``_bounded``; 0.5 when ``rewards`` is empty.
    """
    if not rewards:
        return 0.5

    # Count how many steps landed in each interpretation bucket.
    tally = {}
    for step_reward in rewards:
        bucket = _reward_profile(step_reward)
        tally[bucket] = tally.get(bucket, 0) + 1

    average = _mean(rewards)
    n_steps = len(rewards)
    n_expert = tally.get("expert_triage", 0)
    # Expert steps also count towards the "strong" share.
    n_strong = tally.get("strong_triage", 0) + n_expert

    penalty = (
        min(tally.get("unsafe_miss", 0) * miss_cost, 0.35)
        + min(tally.get("bad_call", 0) * overcall_cost, 0.15)
        + min(tally.get("weak_triage", 0) * 0.015, 0.06)
    )

    bonus = 0.0
    if n_strong / n_steps >= 0.80:
        bonus += stability_gain
    if n_expert / n_steps >= 0.60:
        bonus += expertise_gain

    return _bounded(average - penalty + bonus)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def easy_grader(trajectory: dict = None) -> float:
    """Score an easy-tier trajectory (obvious known-signal recognition).

    Uses the largest per-miss penalty of the three tiers, because these
    cases are the least ambiguous.
    """
    weights = dict(
        miss_cost=0.12,
        overcall_cost=0.03,
        stability_gain=0.05,
        expertise_gain=0.01,
    )
    return _score_episode(_as_reward_list(trajectory), **weights)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def medium_grader(trajectory: dict = None) -> float:
    """Score a medium-tier trajectory (cluster recognition and escalation readiness).

    These cases reward moving from single-case thinking to population-level
    signal interpretation.
    """
    weights = dict(
        miss_cost=0.09,
        overcall_cost=0.04,
        stability_gain=0.03,
        expertise_gain=0.02,
    )
    return _score_episode(_as_reward_list(trajectory), **weights)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def hard_grader(trajectory: dict = None) -> float:
    """Score a hard-tier trajectory (confounding and interaction reasoning).

    The expertise bonus is the largest of the three tiers, since this tier is
    meant to separate shallow pattern matching from mechanistic reasoning.
    """
    weights = dict(
        miss_cost=0.07,
        overcall_cost=0.03,
        stability_gain=0.02,
        expertise_gain=0.04,
    )
    return _score_episode(_as_reward_list(trajectory), **weights)
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def known_signal_easy_grader(trajectory: dict = None) -> float:
    """Task-id alias: score ``known_signal_easy`` with the easy-tier grader."""
    score = easy_grader(trajectory)
    return score
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def cluster_signal_medium_grader(trajectory: dict = None) -> float:
    """Task-id alias: score ``cluster_signal_medium`` with the medium-tier grader."""
    score = medium_grader(trajectory)
    return score
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def confounded_hard_grader(trajectory: dict = None) -> float:
    """Task-id alias: score ``confounded_hard`` with the hard-tier grader."""
    score = hard_grader(trajectory)
    return score
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
__all__ = [
|
| 173 |
+
"easy_grader",
|
| 174 |
+
"medium_grader",
|
| 175 |
+
"hard_grader",
|
| 176 |
+
"known_signal_easy_grader",
|
| 177 |
+
"cluster_signal_medium_grader",
|
| 178 |
+
"confounded_hard_grader",
|
| 179 |
+
]
|
server/pharma_vigilance_env_environment.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Compatibility wrapper exposing the main environment class under server/."""
|
| 2 |
+
|
| 3 |
+
from env import PharmaVigilanceEnv
|
| 4 |
+
|
| 5 |
+
__all__ = ["PharmaVigilanceEnv"]
|
server/requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi
|
| 2 |
+
uvicorn[standard]
|
| 3 |
+
pydantic>=2.0
|
| 4 |
+
openai
|
| 5 |
+
requests
|
| 6 |
+
openenv-core
|
server/tasks.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Server-side task exports for the pharmacovigilance environment."""
|
| 2 |
+
|
| 3 |
+
from tasks import (
|
| 4 |
+
GroundTruth,
|
| 5 |
+
TaskDefinition,
|
| 6 |
+
cluster_signal_medium_action_grader,
|
| 7 |
+
cluster_signal_medium_grader,
|
| 8 |
+
confounded_hard_action_grader,
|
| 9 |
+
confounded_hard_grader,
|
| 10 |
+
get_task,
|
| 11 |
+
get_tasks,
|
| 12 |
+
known_signal_easy_action_grader,
|
| 13 |
+
known_signal_easy_grader,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
__all__ = [
|
| 17 |
+
"GroundTruth",
|
| 18 |
+
"TaskDefinition",
|
| 19 |
+
"get_task",
|
| 20 |
+
"get_tasks",
|
| 21 |
+
"known_signal_easy_action_grader",
|
| 22 |
+
"cluster_signal_medium_action_grader",
|
| 23 |
+
"confounded_hard_action_grader",
|
| 24 |
+
"known_signal_easy_grader",
|
| 25 |
+
"cluster_signal_medium_grader",
|
| 26 |
+
"confounded_hard_grader",
|
| 27 |
+
]
|
tasks.py
ADDED
|
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
from typing import Any, Callable, Dict, List, Optional
|
| 3 |
+
|
| 4 |
+
from pydantic import BaseModel, ConfigDict, Field
|
| 5 |
+
|
| 6 |
+
from data import TASK_DATA
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class GroundTruth(BaseModel):
    """Reference answer for one task; the action graders compare against these four fields."""

    # Expected triage label (values such as "new_signal" and "noise" are compared by the graders).
    classification: str
    # Expected drug name; matched case-insensitively by the graders.
    suspect_drug: str
    # Expected severity label.
    severity_assessment: str
    # Expected follow-up action.
    recommended_action: str
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class TaskDefinition(BaseModel):
    """One pharmacovigilance task: its reports, context data, reference answer and action grader."""

    # arbitrary_types_allowed lets the model hold non-pydantic values such as the grader callable.
    model_config = ConfigDict(arbitrary_types_allowed=True, revalidate_instances="never")

    task_id: str = Field(..., description="Unique pharmacovigilance task identifier")
    difficulty: str = Field(..., description="easy | medium | hard")
    reports: List[Any] = Field(default_factory=list, description="Synthetic adverse event reports")
    drug_interaction_db: dict = Field(default_factory=dict, description="Hardcoded interaction and safety context")
    ground_truth: GroundTruth
    # Callable that scores a single action for this task.
    action_grader: Callable[[Any], Any]
    description: str = Field(default="", description="Human-readable task summary")

    @property
    def id(self) -> str:
        """Read-only alias for ``task_id``."""
        return self.task_id
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _base_breakdown(action: Any, ground_truth: GroundTruth) -> dict:
    """Score an action field by field against the ground truth.

    Each of the four graded fields is worth 0.25. The suspect drug is compared
    case-insensitively after stripping whitespace, and also counts as a match
    when one name contains the other (e.g. one drug of an interaction pair).

    Args:
        action: Object exposing ``classification``, ``suspect_drug``,
            ``severity_assessment`` and ``recommended_action``.
        ground_truth: Reference answer exposing the same four attributes.

    Returns:
        Dict with the four field scores plus ``false_alarm_penalty`` (-0.10 for
        calling noise a new signal), ``missed_signal_penalty`` (-0.20 for
        calling a new signal noise) and ``reasoning_bonus`` (always 0.0 here).
    """
    action_suspect = action.suspect_drug.strip().lower()
    truth_suspect = ground_truth.suspect_drug.strip().lower()
    # The empty string is a substring of every string, so a blank answer would
    # otherwise earn full suspect credit. Only apply the containment checks
    # when both names are non-empty.
    suspect_match = action_suspect == truth_suspect or (
        bool(action_suspect)
        and bool(truth_suspect)
        and (action_suspect in truth_suspect or truth_suspect in action_suspect)
    )

    breakdown = {
        "classification": 0.25 if action.classification == ground_truth.classification else 0.0,
        "suspect_drug": 0.25 if suspect_match else 0.0,
        "severity_assessment": 0.25 if action.severity_assessment == ground_truth.severity_assessment else 0.0,
        "recommended_action": 0.25 if action.recommended_action == ground_truth.recommended_action else 0.0,
        "false_alarm_penalty": 0.0,
        "missed_signal_penalty": 0.0,
        "reasoning_bonus": 0.0,
    }

    if action.classification == "new_signal" and ground_truth.classification == "noise":
        breakdown["false_alarm_penalty"] = -0.10
    if action.classification == "noise" and ground_truth.classification == "new_signal":
        breakdown["missed_signal_penalty"] = -0.20

    return breakdown
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def _reward_from_breakdown(breakdown: dict):
    """Wrap a score breakdown in a ``Reward`` whose total is clamped to [0.0, 1.0].

    The total is the sum of all breakdown values rounded to four decimals;
    the breakdown itself is passed through unchanged.
    """
    # Imported inside the function, as in the other env-dependent helpers of this module.
    from env import Reward

    raw_total = round(sum(breakdown.values()), 4)
    clamped_total = max(0.0, min(1.0, raw_total))
    return Reward(total=clamped_total, breakdown=breakdown)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def known_signal_easy_action_grader(action: Any):
    """Grade one action against the ``known_signal_easy`` ground truth."""
    expected = GroundTruth(**TASK_DATA["known_signal_easy"]["ground_truth"])
    return _reward_from_breakdown(_base_breakdown(action, expected))
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def cluster_signal_medium_action_grader(action: Any):
    """Grade one action against the ``cluster_signal_medium`` ground truth."""
    expected = GroundTruth(**TASK_DATA["cluster_signal_medium"]["ground_truth"])
    return _reward_from_breakdown(_base_breakdown(action, expected))
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def confounded_hard_action_grader(action: Any):
    """Grade one action against the ``confounded_hard`` ground truth.

    On top of the four field scores, a 0.15 reasoning bonus is added when the
    lower-cased ``action.reasoning`` mentions any of the mechanism terms below.
    """
    expected = GroundTruth(**TASK_DATA["confounded_hard"]["ground_truth"])
    breakdown = _base_breakdown(action, expected)

    mechanism_terms = (
        "drug interaction",
        "tacrolimus",
        "voriconazole",
        "azole",
        "calcineurin",
        "level monitoring",
    )
    rationale = action.reasoning.lower()
    for term in mechanism_terms:
        if term in rationale:
            breakdown["reasoning_bonus"] = 0.15
            break

    return _reward_from_breakdown(breakdown)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def _grader_score_from_trajectory(trajectory: Any = None) -> float:
    """Pull a single score out of a trajectory payload and clamp it to [0.01, 0.99].

    Lookup order for dict payloads: ``"score"``, then the last item of a
    non-empty ``"rewards"`` sequence, then ``"reward"`` (either a dict with a
    ``"total"`` key or a bare number). Empty, missing or non-dict payloads
    score 0.5.
    """

    def _clamp(score: float) -> float:
        return max(0.01, min(0.99, round(score, 4)))

    if not trajectory or not isinstance(trajectory, dict):
        return _clamp(0.5)

    if "score" in trajectory:
        return _clamp(float(trajectory["score"]))

    if trajectory.get("rewards"):
        return _clamp(float(trajectory["rewards"][-1]))

    if "reward" in trajectory:
        reward_val = trajectory["reward"]
        if isinstance(reward_val, dict) and "total" in reward_val:
            return _clamp(float(reward_val["total"]))
        return _clamp(float(reward_val))

    return _clamp(0.5)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def known_signal_easy_grader(trajectory: Any = None) -> float:
    """Delegate to ``server.graders.known_signal_easy_grader``."""
    # Imported at call time, as in the other delegating graders of this module.
    import server.graders as _impl

    return _impl.known_signal_easy_grader(trajectory)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def cluster_signal_medium_grader(trajectory: Any = None) -> float:
    """Delegate to ``server.graders.cluster_signal_medium_grader``."""
    # Imported at call time, as in the other delegating graders of this module.
    import server.graders as _impl

    return _impl.cluster_signal_medium_grader(trajectory)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def confounded_hard_grader(trajectory: Any = None) -> float:
    """Delegate to ``server.graders.confounded_hard_grader``."""
    # Imported at call time, as in the other delegating graders of this module.
    import server.graders as _impl

    return _impl.confounded_hard_grader(trajectory)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def _task_definition(
    task_id: str,
    difficulty: str,
    description: str,
    action_grader: Callable[[Any], Any],
) -> TaskDefinition:
    """Assemble a ``TaskDefinition`` from the ``TASK_DATA`` entry for ``task_id``.

    Each raw report is turned into an ``AdverseEventReport`` and the raw
    ground truth into a ``GroundTruth``; the interaction database is passed
    through as stored.
    """
    from env import AdverseEventReport

    raw = TASK_DATA[task_id]

    parsed_reports = []
    for report_fields in raw["reports"]:
        parsed_reports.append(AdverseEventReport(**report_fields))

    interaction_db = raw["drug_interaction_db"]
    truth = GroundTruth(**raw["ground_truth"])

    return TaskDefinition(
        task_id=task_id,
        difficulty=difficulty,
        reports=parsed_reports,
        drug_interaction_db=interaction_db,
        ground_truth=truth,
        action_grader=action_grader,
        description=description,
    )
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def _build_all_tasks() -> Dict[str, List[TaskDefinition]]:
    """Build and return the complete task pool grouped by difficulty."""
    # (difficulty, task id, summary, action grader), in the order the tiers are listed.
    catalog = (
        (
            "easy",
            "known_signal_easy",
            "Known ACE-inhibitor cough case that should be logged and monitored rather than escalated.",
            known_signal_easy_action_grader,
        ),
        (
            "medium",
            "cluster_signal_medium",
            "Cluster of bradycardia reports around a newly approved therapy that should be escalated as a signal.",
            cluster_signal_medium_action_grader,
        ),
        (
            "hard",
            "confounded_hard",
            "Confounded transplant case where the blamed drug is wrong and the real problem is a tacrolimus interaction.",
            confounded_hard_action_grader,
        ),
    )

    pool = {}
    for tier, task_id, summary, grader in catalog:
        pool.setdefault(tier, []).append(_task_definition(task_id, tier, summary, grader))
    return pool
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def get_tasks(
    difficulty: Optional[str] = None,
    seed: Optional[int] = None,
    n: int = 5,
    grouped: bool = False,
):
    """
    Return tasks as a flat task-id map, a grouped dict, or a per-difficulty list.

    Args:
        difficulty: Optional difficulty tier to select from.
        seed: Optional seed for reproducible shuffling; only used when a
            difficulty is given.
        n: Maximum number of tasks kept per difficulty tier. It limits the
            returned list when a difficulty is given and each tier's
            contribution in the flat and grouped forms.
        grouped: When True and difficulty is None, return the difficulty-grouped dict.

    Returns:
        If difficulty is given:
            List[TaskDefinition] (empty for an unknown difficulty)
        If difficulty is None and grouped=True:
            Dict[str, List[TaskDefinition]]
        Otherwise:
            Dict[str, TaskDefinition] keyed by task id
    """
    all_tasks = _build_all_tasks()

    if difficulty is not None:
        selected = list(all_tasks.get(difficulty, []))
        if seed is not None:
            random.Random(seed).shuffle(selected)
        return selected[:n]

    if grouped:
        return {level: tier_tasks[:n] for level, tier_tasks in all_tasks.items()}

    by_id = {}
    for tier_tasks in all_tasks.values():
        for task in tier_tasks[:n]:
            by_id[task.task_id] = task
    return by_id
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def get_task(task_id: str) -> TaskDefinition:
    """Return the task definition registered under ``task_id``.

    The lookup walks the complete task pool rather than ``get_tasks()``,
    whose default ``n=5`` keeps only the first five tasks of each tier and
    would hide any task beyond that.

    Raises:
        KeyError: If no task has the given id.
    """
    for tier_tasks in _build_all_tasks().values():
        for task in tier_tasks:
            if task.task_id == task_id:
                return task
    raise KeyError(f"Unknown task_id: {task_id}")
|
tests/test_env.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
|
| 4 |
+
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
| 5 |
+
|
| 6 |
+
from env import Action, PharmaVigilanceEnv
|
| 7 |
+
from tasks import (
|
| 8 |
+
cluster_signal_medium_action_grader,
|
| 9 |
+
cluster_signal_medium_grader,
|
| 10 |
+
confounded_hard_action_grader,
|
| 11 |
+
confounded_hard_grader,
|
| 12 |
+
get_task,
|
| 13 |
+
get_tasks,
|
| 14 |
+
known_signal_easy_action_grader,
|
| 15 |
+
known_signal_easy_grader,
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def test_reset_loads_easy_task():
    """Resetting to the easy task yields a step-0 observation with one report."""
    observation = PharmaVigilanceEnv().reset("known_signal_easy")

    assert observation.task_id == "known_signal_easy"
    assert observation.step_number == 0
    assert len(observation.reports) == 1
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def test_known_signal_grader_full_credit():
    """An action matching all four ground-truth fields earns a total of 1.0."""
    perfect_action = Action(
        classification="known_side_effect",
        suspect_drug="Lisinopril",
        severity_assessment="mild",
        recommended_action="log_and_monitor",
        reasoning="Known reaction pattern.",
    )

    assert known_signal_easy_action_grader(perfect_action).total == 1.0
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def test_medium_cluster_grader_partial_credit():
    """The medium-task action below is expected to score exactly 0.75."""
    partly_right_action = Action(
        classification="new_signal",
        suspect_drug="Cardiovexa",
        severity_assessment="moderate",
        recommended_action="escalate",
        reasoning="A cluster is forming.",
    )

    assert cluster_signal_medium_action_grader(partly_right_action).total == 0.75
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def test_hard_grader_reasoning_bonus():
    """Mechanistic reasoning earns the 0.15 bonus while the total stays capped at 1.0."""
    expert_action = Action(
        classification="new_signal",
        suspect_drug="Tacrolimus+Voriconazole",
        severity_assessment="critical",
        recommended_action="escalate",
        reasoning="This looks like a tacrolimus-voriconazole drug interaction with toxic levels.",
    )

    graded = confounded_hard_action_grader(expert_action)

    assert graded.total == 1.0
    assert graded.breakdown["reasoning_bonus"] == 0.15
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def test_hard_grader_substring_suspect_match():
    """Naming one drug of the interacting pair still earns the suspect-drug credit."""
    single_drug_action = Action(
        classification="new_signal",
        suspect_drug="Tacrolimus",
        severity_assessment="critical",
        recommended_action="escalate",
        reasoning="Voriconazole likely raised tacrolimus exposure.",
    )

    graded = confounded_hard_action_grader(single_drug_action)

    assert graded.breakdown["suspect_drug"] == 0.25
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def test_env_step_returns_done():
    """A single step on the hard task finishes the episode and reports the breakdown."""
    environment = PharmaVigilanceEnv()
    environment.reset("confounded_hard")
    action = Action(
        classification="new_signal",
        suspect_drug="Tacrolimus+Voriconazole",
        severity_assessment="critical",
        recommended_action="escalate",
        reasoning="Tacrolimus toxicity from an azole interaction.",
    )

    observation, reward, done, info = environment.step(action)

    assert done is True
    assert observation.step_number == 1
    assert "reward_breakdown" in info
    assert reward.total >= 0.85
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def test_state_tracks_last_action():
    """After one step, state() exposes the step count and the last action taken."""
    environment = PharmaVigilanceEnv()
    environment.reset("known_signal_easy")
    action = Action(
        classification="known_side_effect",
        suspect_drug="Lisinopril",
        severity_assessment="mild",
        recommended_action="log_and_monitor",
        reasoning="Known adverse effect.",
    )
    environment.step(action)

    snapshot = environment.state()

    assert snapshot["step_number"] == 1
    assert snapshot["last_action"]["classification"] == "known_side_effect"
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def test_all_tasks_available():
    """The flat task map contains exactly the three bundled task ids."""
    expected_ids = {
        "known_signal_easy",
        "cluster_signal_medium",
        "confounded_hard",
    }

    assert set(get_tasks().keys()) == expected_ids
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def test_get_task_returns_hard_truth():
    """The hard task's ground truth names the interacting drug pair."""
    hard_task = get_task("confounded_hard")

    assert hard_task.ground_truth.suspect_drug == "Tacrolimus+Voriconazole"
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def test_public_graders_are_strictly_bounded():
    """Public graders clamp extreme inputs to the 0.01 / 0.99 limits."""
    cases = (
        (known_signal_easy_grader, {"rewards": [1.0]}, 0.99),
        (cluster_signal_medium_grader, {"rewards": [0.0]}, 0.01),
        (confounded_hard_grader, {"score": 1.5}, 0.99),
    )

    for grader, trajectory, expected in cases:
        assert grader(trajectory) == expected
|
uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
validate-submission.sh
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
#
|
| 3 |
+
# validate-submission.sh — OpenEnv Submission Validator
|
| 4 |
+
#
|
| 5 |
+
# Checks that your HF Space is live, Docker image builds, and openenv validate passes.
|
| 6 |
+
#
|
| 7 |
+
# Prerequisites:
|
| 8 |
+
# - Docker: https://docs.docker.com/get-docker/
|
| 9 |
+
# - openenv-core: pip install openenv-core
|
| 10 |
+
# - curl (usually pre-installed)
|
| 11 |
+
#
|
| 12 |
+
# Run:
|
| 13 |
+
# curl -fsSL https://raw.githubusercontent.com/<owner>/<repo>/main/scripts/validate-submission.sh | bash -s -- <ping_url> [repo_dir]
|
| 14 |
+
#
|
| 15 |
+
# Or download and run locally:
|
| 16 |
+
# chmod +x validate-submission.sh
|
| 17 |
+
# ./validate-submission.sh <ping_url> [repo_dir]
|
| 18 |
+
#
|
| 19 |
+
# Arguments:
|
| 20 |
+
# ping_url Your HuggingFace Space URL (e.g. https://your-space.hf.space)
|
| 21 |
+
# repo_dir Path to your repo (default: current directory)
|
| 22 |
+
#
|
| 23 |
+
# Examples:
|
| 24 |
+
# ./validate-submission.sh https://my-team.hf.space
|
| 25 |
+
# ./validate-submission.sh https://my-team.hf.space ./my-repo
|
| 26 |
+
#
|
| 27 |
+
|
| 28 |
+
set -uo pipefail
|
| 29 |
+
|
| 30 |
+
DOCKER_BUILD_TIMEOUT=600
|
| 31 |
+
if [ -t 1 ]; then
|
| 32 |
+
RED='\033[0;31m'
|
| 33 |
+
GREEN='\033[0;32m'
|
| 34 |
+
YELLOW='\033[1;33m'
|
| 35 |
+
BOLD='\033[1m'
|
| 36 |
+
NC='\033[0m'
|
| 37 |
+
else
|
| 38 |
+
RED='' GREEN='' YELLOW='' BOLD='' NC=''
|
| 39 |
+
fi
|
| 40 |
+
|
| 41 |
+
run_with_timeout() {
  # Usage: run_with_timeout <seconds> <command> [args...]
  # Runs the command under a time limit and returns its exit status.
  local limit="$1"
  shift

  # Prefer an installed timeout binary; "gtimeout" is tried when "timeout" is missing.
  local helper
  for helper in timeout gtimeout; do
    if command -v "$helper" &>/dev/null; then
      "$helper" "$limit" "$@"
      return $?
    fi
  done

  # Fallback: run the command in the background next to a watchdog subshell
  # that kills it once the limit has elapsed.
  "$@" &
  local child=$!
  ( sleep "$limit" && kill "$child" 2>/dev/null ) &
  local watchdog=$!
  wait "$child" 2>/dev/null
  local rc=$?
  # The command is done (or was killed); stop and reap the watchdog.
  kill "$watchdog" 2>/dev/null
  wait "$watchdog" 2>/dev/null
  return $rc
}
|
| 59 |
+
|
| 60 |
+
portable_mktemp() {
  # Create a temp file named "<prefix>-XXXXXX" under $TMPDIR (or /tmp) and
  # print its path; fall back to a bare mktemp when the template form fails.
  local prefix="${1:-validate}"
  local template="${TMPDIR:-/tmp}/${prefix}-XXXXXX"
  mktemp "$template" 2>/dev/null || mktemp
}
|
| 64 |
+
|
| 65 |
+
CLEANUP_FILES=()
|
| 66 |
+
cleanup() {
  # Remove every registered temp file. The ${arr[@]+...} form expands to
  # nothing when the array is empty, which keeps "set -u" from tripping.
  rm -f "${CLEANUP_FILES[@]+"${CLEANUP_FILES[@]}"}"
}
|
| 67 |
+
trap cleanup EXIT
|
| 68 |
+
|
| 69 |
+
# Positional arguments: the Space URL (required) and the repo path
# (optional, defaults to the current directory).
PING_URL="${1:-}"
REPO_DIR="${2:-.}"

# Without a URL there is nothing to validate: print usage and bail out.
if [ -z "$PING_URL" ]; then
  printf "Usage: %s <ping_url> [repo_dir]\n\n" "$0"
  printf " ping_url Your HuggingFace Space URL (e.g. https://your-space.hf.space)\n"
  printf " repo_dir Path to your repo (default: current directory)\n"
  exit 1
fi

# Resolve the repo path to an absolute directory; the error message reports
# the path as the user typed it.
REPO_DIR="$(cd "$REPO_DIR" 2>/dev/null && pwd)" || {
  printf "Error: directory '%s' not found\n" "${2:-.}"
  exit 1
}

# Drop one trailing slash so "$PING_URL/reset" never contains "//".
PING_URL="${PING_URL%/}"
export PING_URL

# Number of checks that have passed so far.
PASS=0
|
| 87 |
+
|
| 88 |
+
# Print a message prefixed with the current UTC time as [HH:MM:SS].
# All arguments are joined into one message; backslash escapes in it are
# interpreted (%b), which is how the colour codes get rendered.
log() {
  local stamp
  stamp="$(date -u +%H:%M:%S)"
  printf "[%s] %b\n" "$stamp" "$*"
}
|
| 89 |
+
# Log a PASSED line for the check described by $1 and bump the PASS counter.
pass() {
  local summary="$1"
  log "${GREEN}PASSED${NC} -- ${summary}"
  PASS=$(( PASS + 1 ))
}
|
| 90 |
+
# Log a FAILED line for the check described by $1.
fail() {
  local summary="$1"
  log "${RED}FAILED${NC} -- ${summary}"
}
|
| 91 |
+
# Print a "Hint:" line containing $1. The text is emitted through %b, so
# backslash escapes in it are interpreted but percent signs are not.
hint() {
  local advice="$1"
  printf " ${YELLOW}Hint:${NC} %b\n" "$advice"
}
|
| 92 |
+
# Report which step ($1) halted validation, then exit with status 1.
stop_at() {
  local step="$1"
  printf "\n${RED}${BOLD}Validation stopped at %s.${NC} Fix the above before continuing.\n" "$step"
  exit 1
}
|
| 97 |
+
|
| 98 |
+
# Opening banner followed by the resolved inputs.
_rule="${BOLD}========================================${NC}"
printf "\n"
printf "${_rule}\n"
printf "${BOLD} OpenEnv Submission Validator${NC}\n"
printf "${_rule}\n"
log "Repo: $REPO_DIR"
log "Ping URL: $PING_URL"
printf "\n"
|
| 105 |
+
|
| 106 |
+
# Step 1: POST an empty JSON body to <space>/reset and require HTTP 200.
log "${BOLD}Step 1/3: Pinging HF Space${NC} ($PING_URL/reset) ..."

# Without curl the request below could only ever look like a network
# failure, so report the real cause (mirrors the docker/openenv checks).
if ! command -v curl &>/dev/null; then
  fail "curl command not found"
  hint "Install curl: https://curl.se/download.html"
  stop_at "Step 1"
fi

# The response body goes to a temp file that is removed on exit.
CURL_OUTPUT=$(portable_mktemp "validate-curl")
CLEANUP_FILES+=("$CURL_OUTPUT")

# On a transport failure curl itself prints "000" for %{http_code} before
# exiting non-zero, so the fallback is assigned rather than appended to the
# captured output (appending would yield "000000"). stderr is discarded
# instead of being written over the body file.
HTTP_CODE=$(curl -s -o "$CURL_OUTPUT" -w "%{http_code}" -X POST \
  -H "Content-Type: application/json" -d '{}' \
  "$PING_URL/reset" --max-time 30 2>/dev/null) || HTTP_CODE="000"

if [ "$HTTP_CODE" = "200" ]; then
  pass "HF Space is live and responds to /reset"
elif [ "$HTTP_CODE" = "000" ]; then
  fail "HF Space not reachable (connection failed or timed out)"
  hint "Check your network connection and that the Space is running."
  # hint prints its argument via %b, so a single % is what reaches the user.
  hint "Try: curl -s -o /dev/null -w '%{http_code}' -X POST $PING_URL/reset"
  stop_at "Step 1"
else
  fail "HF Space /reset returned HTTP $HTTP_CODE (expected 200)"
  hint "Make sure your Space is running and the URL is correct."
  hint "Try opening $PING_URL in your browser first."
  stop_at "Step 1"
fi
|
| 127 |
+
|
| 128 |
+
# Step 2: build the Docker image from the repo (or its server/ directory).
log "${BOLD}Step 2/3: Running docker build${NC} ..."

command -v docker &>/dev/null || {
  fail "docker command not found"
  hint "Install Docker: https://docs.docker.com/get-docker/"
  stop_at "Step 2"
}

# The build context is the first of these directories holding a Dockerfile:
# the repo root, then server/.
DOCKER_CONTEXT=""
for _candidate in "$REPO_DIR" "$REPO_DIR/server"; do
  if [ -f "$_candidate/Dockerfile" ]; then
    DOCKER_CONTEXT="$_candidate"
    break
  fi
done

if [ -z "$DOCKER_CONTEXT" ]; then
  fail "No Dockerfile found in repo root or server/ directory"
  stop_at "Step 2"
fi

log " Found Dockerfile in $DOCKER_CONTEXT"

# Capture stdout and stderr together so the tail can be shown on failure.
if BUILD_OUTPUT=$(run_with_timeout "$DOCKER_BUILD_TIMEOUT" docker build "$DOCKER_CONTEXT" 2>&1); then
  BUILD_OK=true
else
  BUILD_OK=false
fi

if [ "$BUILD_OK" != true ]; then
  fail "Docker build failed (timeout=${DOCKER_BUILD_TIMEOUT}s)"
  printf "%s\n" "$BUILD_OUTPUT" | tail -20
  stop_at "Step 2"
fi
pass "Docker build succeeded"
|
| 157 |
+
|
| 158 |
+
# Step 3: run `openenv validate` from inside the repo directory.
log "${BOLD}Step 3/3: Running openenv validate${NC} ..."

command -v openenv &>/dev/null || {
  fail "openenv command not found"
  hint "Install it: pip install openenv-core"
  stop_at "Step 3"
}

# The cd happens inside the command substitution, so the script's own
# working directory is left untouched.
if VALIDATE_OUTPUT=$(cd "$REPO_DIR" && openenv validate 2>&1); then
  VALIDATE_OK=true
else
  VALIDATE_OK=false
fi

if [ "$VALIDATE_OK" != true ]; then
  fail "openenv validate failed"
  printf "%s\n" "$VALIDATE_OUTPUT"
  stop_at "Step 3"
fi

pass "openenv validate passed"
# Echo whatever the validator printed, if anything.
if [ -n "$VALIDATE_OUTPUT" ]; then
  log " $VALIDATE_OUTPUT"
fi
|
| 177 |
+
|
| 178 |
+
# Closing banner: only reached when all three steps passed, because every
# failure path above exits through stop_at.
_rule="${BOLD}========================================${NC}"
printf "\n"
printf "${_rule}\n"
printf "${GREEN}${BOLD} All 3/3 checks passed!${NC}\n"
printf "${GREEN}${BOLD} Your submission is ready to submit.${NC}\n"
printf "${_rule}\n"
printf "\n"

exit 0
|