modelbuilderhq committed on
Commit
dfc0f77
·
verified ·
1 Parent(s): 9bfdbc5

Upload folder using huggingface_hub

Browse files
Files changed (3) hide show
  1. inference.py +6 -6
  2. server/app.py +7 -2
  3. tests/test_env.py +40 -11
inference.py CHANGED
@@ -99,12 +99,12 @@ def fetch_reset(task_name: str) -> dict:
99
  return response.json()
100
 
101
 
102
- def submit_action(action: PharmaAction) -> dict:
103
- response = requests.post(
104
- f"{ENV_URL}/step",
105
- json=action.model_dump(),
106
- timeout=30,
107
- )
108
  response.raise_for_status()
109
  return response.json()
110
 
 
99
  return response.json()
100
 
101
 
102
+ def submit_action(action: PharmaAction) -> dict:
103
+ response = requests.post(
104
+ f"{ENV_URL}/step",
105
+ json={"action": action.model_dump()},
106
+ timeout=30,
107
+ )
108
  response.raise_for_status()
109
  return response.json()
110
 
server/app.py CHANGED
@@ -19,9 +19,12 @@ class OpenEnvPharmaAdapter:
19
  expected by OpenEnv's HTTP server and web playground helpers.
20
  """
21
 
 
 
 
22
  def __init__(self) -> None:
23
- self._env = PharmaVigilanceEnv()
24
- self._last_state = State(episode_id=None, step_count=0)
25
 
26
  @staticmethod
27
  def _normalize_reports(reports):
@@ -36,6 +39,7 @@ class OpenEnvPharmaAdapter:
36
  def reset(self, task_id: str = "known_signal_easy") -> PharmaObservation:
37
  observation = self._env.reset(task_id=task_id)
38
  self._last_state = State(episode_id=task_id, step_count=0)
 
39
  return PharmaObservation(
40
  task_id=observation.task_id,
41
  reports=self._normalize_reports(observation.reports),
@@ -57,6 +61,7 @@ class OpenEnvPharmaAdapter:
57
  episode_id=observation.task_id,
58
  step_count=observation.step_number,
59
  )
 
60
  return PharmaObservation(
61
  task_id=observation.task_id,
62
  reports=self._normalize_reports(observation.reports),
 
19
  expected by OpenEnv's HTTP server and web playground helpers.
20
  """
21
 
22
+ _shared_env = PharmaVigilanceEnv()
23
+ _shared_state = State(episode_id=None, step_count=0)
24
+
25
  def __init__(self) -> None:
26
+ self._env = self.__class__._shared_env
27
+ self._last_state = self.__class__._shared_state
28
 
29
  @staticmethod
30
  def _normalize_reports(reports):
 
39
  def reset(self, task_id: str = "known_signal_easy") -> PharmaObservation:
40
  observation = self._env.reset(task_id=task_id)
41
  self._last_state = State(episode_id=task_id, step_count=0)
42
+ self.__class__._shared_state = self._last_state
43
  return PharmaObservation(
44
  task_id=observation.task_id,
45
  reports=self._normalize_reports(observation.reports),
 
61
  episode_id=observation.task_id,
62
  step_count=observation.step_number,
63
  )
64
+ self.__class__._shared_state = self._last_state
65
  return PharmaObservation(
66
  task_id=observation.task_id,
67
  reports=self._normalize_reports(observation.reports),
tests/test_env.py CHANGED
@@ -1,10 +1,11 @@
1
- import sys
2
- from pathlib import Path
3
-
4
- sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
5
-
6
- from env import Action, PharmaVigilanceEnv
7
- from tasks import (
 
8
  cluster_signal_medium_action_grader,
9
  cluster_signal_medium_grader,
10
  confounded_hard_action_grader,
@@ -126,7 +127,35 @@ def test_get_task_returns_hard_truth():
126
  assert task.ground_truth.suspect_drug == "Tacrolimus+Voriconazole"
127
 
128
 
129
- def test_public_graders_are_strictly_bounded():
130
- assert known_signal_easy_grader({"rewards": [1.0]}) == 0.99
131
- assert cluster_signal_medium_grader({"rewards": [0.0]}) == 0.01
132
- assert confounded_hard_grader({"score": 1.5}) == 0.99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ from pathlib import Path
3
+
4
+ import pytest
5
+ sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
6
+
7
+ from env import Action, PharmaVigilanceEnv
8
+ from tasks import (
9
  cluster_signal_medium_action_grader,
10
  cluster_signal_medium_grader,
11
  confounded_hard_action_grader,
 
127
  assert task.ground_truth.suspect_drug == "Tacrolimus+Voriconazole"
128
 
129
 
130
+ def test_public_graders_are_strictly_bounded():
131
+ assert known_signal_easy_grader({"rewards": [1.0]}) == 0.99
132
+ assert cluster_signal_medium_grader({"rewards": [0.0]}) == 0.01
133
+ assert confounded_hard_grader({"score": 1.5}) == 0.99
134
+
135
+
136
+ def test_http_reset_then_step_roundtrip():
137
+ pytest.importorskip("openenv")
138
+ from fastapi.testclient import TestClient
139
+ from server.app import app
140
+
141
+ client = TestClient(app)
142
+
143
+ reset_response = client.post("/reset", json={})
144
+ assert reset_response.status_code == 200
145
+
146
+ step_response = client.post(
147
+ "/step",
148
+ json={
149
+ "action": {
150
+ "classification": "known_side_effect",
151
+ "suspect_drug": "Lisinopril",
152
+ "severity_assessment": "mild",
153
+ "recommended_action": "log_and_monitor",
154
+ "reasoning": "Known ACE inhibitor cough.",
155
+ }
156
+ },
157
+ )
158
+ assert step_response.status_code == 200
159
+ payload = step_response.json()
160
+ assert payload["done"] is True
161
+ assert payload["reward"] == 1.0