Spaces:
Sleeping
Sleeping
Upload folder using huggingface_hub
Browse files- inference.py +6 -6
- server/app.py +7 -2
- tests/test_env.py +40 -11
inference.py
CHANGED
|
@@ -99,12 +99,12 @@ def fetch_reset(task_name: str) -> dict:
|
|
| 99 |
return response.json()
|
| 100 |
|
| 101 |
|
| 102 |
-
def submit_action(action: PharmaAction) -> dict:
|
| 103 |
-
response = requests.post(
|
| 104 |
-
f"{ENV_URL}/step",
|
| 105 |
-
json=action.model_dump(),
|
| 106 |
-
timeout=30,
|
| 107 |
-
)
|
| 108 |
response.raise_for_status()
|
| 109 |
return response.json()
|
| 110 |
|
|
|
|
| 99 |
return response.json()
|
| 100 |
|
| 101 |
|
| 102 |
+
def submit_action(action: PharmaAction) -> dict:
|
| 103 |
+
response = requests.post(
|
| 104 |
+
f"{ENV_URL}/step",
|
| 105 |
+
json={"action": action.model_dump()},
|
| 106 |
+
timeout=30,
|
| 107 |
+
)
|
| 108 |
response.raise_for_status()
|
| 109 |
return response.json()
|
| 110 |
|
server/app.py
CHANGED
|
@@ -19,9 +19,12 @@ class OpenEnvPharmaAdapter:
|
|
| 19 |
expected by OpenEnv's HTTP server and web playground helpers.
|
| 20 |
"""
|
| 21 |
|
|
|
|
|
|
|
|
|
|
| 22 |
def __init__(self) -> None:
|
| 23 |
-
self._env =
|
| 24 |
-
self._last_state =
|
| 25 |
|
| 26 |
@staticmethod
|
| 27 |
def _normalize_reports(reports):
|
|
@@ -36,6 +39,7 @@ class OpenEnvPharmaAdapter:
|
|
| 36 |
def reset(self, task_id: str = "known_signal_easy") -> PharmaObservation:
|
| 37 |
observation = self._env.reset(task_id=task_id)
|
| 38 |
self._last_state = State(episode_id=task_id, step_count=0)
|
|
|
|
| 39 |
return PharmaObservation(
|
| 40 |
task_id=observation.task_id,
|
| 41 |
reports=self._normalize_reports(observation.reports),
|
|
@@ -57,6 +61,7 @@ class OpenEnvPharmaAdapter:
|
|
| 57 |
episode_id=observation.task_id,
|
| 58 |
step_count=observation.step_number,
|
| 59 |
)
|
|
|
|
| 60 |
return PharmaObservation(
|
| 61 |
task_id=observation.task_id,
|
| 62 |
reports=self._normalize_reports(observation.reports),
|
|
|
|
| 19 |
expected by OpenEnv's HTTP server and web playground helpers.
|
| 20 |
"""
|
| 21 |
|
| 22 |
+
_shared_env = PharmaVigilanceEnv()
|
| 23 |
+
_shared_state = State(episode_id=None, step_count=0)
|
| 24 |
+
|
| 25 |
def __init__(self) -> None:
|
| 26 |
+
self._env = self.__class__._shared_env
|
| 27 |
+
self._last_state = self.__class__._shared_state
|
| 28 |
|
| 29 |
@staticmethod
|
| 30 |
def _normalize_reports(reports):
|
|
|
|
| 39 |
def reset(self, task_id: str = "known_signal_easy") -> PharmaObservation:
|
| 40 |
observation = self._env.reset(task_id=task_id)
|
| 41 |
self._last_state = State(episode_id=task_id, step_count=0)
|
| 42 |
+
self.__class__._shared_state = self._last_state
|
| 43 |
return PharmaObservation(
|
| 44 |
task_id=observation.task_id,
|
| 45 |
reports=self._normalize_reports(observation.reports),
|
|
|
|
| 61 |
episode_id=observation.task_id,
|
| 62 |
step_count=observation.step_number,
|
| 63 |
)
|
| 64 |
+
self.__class__._shared_state = self._last_state
|
| 65 |
return PharmaObservation(
|
| 66 |
task_id=observation.task_id,
|
| 67 |
reports=self._normalize_reports(observation.reports),
|
tests/test_env.py
CHANGED
|
@@ -1,10 +1,11 @@
|
|
| 1 |
-
import sys
|
| 2 |
-
from pathlib import Path
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
from
|
|
|
|
| 8 |
cluster_signal_medium_action_grader,
|
| 9 |
cluster_signal_medium_grader,
|
| 10 |
confounded_hard_action_grader,
|
|
@@ -126,7 +127,35 @@ def test_get_task_returns_hard_truth():
|
|
| 126 |
assert task.ground_truth.suspect_drug == "Tacrolimus+Voriconazole"
|
| 127 |
|
| 128 |
|
| 129 |
-
def test_public_graders_are_strictly_bounded():
|
| 130 |
-
assert known_signal_easy_grader({"rewards": [1.0]}) == 0.99
|
| 131 |
-
assert cluster_signal_medium_grader({"rewards": [0.0]}) == 0.01
|
| 132 |
-
assert confounded_hard_grader({"score": 1.5}) == 0.99
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
|
| 4 |
+
import pytest
|
| 5 |
+
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
| 6 |
+
|
| 7 |
+
from env import Action, PharmaVigilanceEnv
|
| 8 |
+
from tasks import (
|
| 9 |
cluster_signal_medium_action_grader,
|
| 10 |
cluster_signal_medium_grader,
|
| 11 |
confounded_hard_action_grader,
|
|
|
|
| 127 |
assert task.ground_truth.suspect_drug == "Tacrolimus+Voriconazole"
|
| 128 |
|
| 129 |
|
| 130 |
+
def test_public_graders_are_strictly_bounded():
|
| 131 |
+
assert known_signal_easy_grader({"rewards": [1.0]}) == 0.99
|
| 132 |
+
assert cluster_signal_medium_grader({"rewards": [0.0]}) == 0.01
|
| 133 |
+
assert confounded_hard_grader({"score": 1.5}) == 0.99
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def test_http_reset_then_step_roundtrip():
|
| 137 |
+
pytest.importorskip("openenv")
|
| 138 |
+
from fastapi.testclient import TestClient
|
| 139 |
+
from server.app import app
|
| 140 |
+
|
| 141 |
+
client = TestClient(app)
|
| 142 |
+
|
| 143 |
+
reset_response = client.post("/reset", json={})
|
| 144 |
+
assert reset_response.status_code == 200
|
| 145 |
+
|
| 146 |
+
step_response = client.post(
|
| 147 |
+
"/step",
|
| 148 |
+
json={
|
| 149 |
+
"action": {
|
| 150 |
+
"classification": "known_side_effect",
|
| 151 |
+
"suspect_drug": "Lisinopril",
|
| 152 |
+
"severity_assessment": "mild",
|
| 153 |
+
"recommended_action": "log_and_monitor",
|
| 154 |
+
"reasoning": "Known ACE inhibitor cough.",
|
| 155 |
+
}
|
| 156 |
+
},
|
| 157 |
+
)
|
| 158 |
+
assert step_response.status_code == 200
|
| 159 |
+
payload = step_response.json()
|
| 160 |
+
assert payload["done"] is True
|
| 161 |
+
assert payload["reward"] == 1.0
|