modelbuilderhq committed on
Commit
a612f9c
·
verified ·
1 Parent(s): 4bde500

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. server/app.py +26 -15
server/app.py CHANGED
@@ -1,5 +1,6 @@
1
  from fastapi import FastAPI
2
  from openenv.core.env_server import create_web_interface_app
 
3
 
4
  try:
5
  from ..env import PharmaVigilanceEnv
@@ -20,20 +21,24 @@ class OpenEnvPharmaAdapter:
20
 
21
  def __init__(self) -> None:
22
  self._env = PharmaVigilanceEnv()
23
- self._last_state: dict = {
24
- "episode_id": None,
25
- "step_count": 0,
26
- }
 
 
 
 
 
 
 
27
 
28
  def reset(self, task_id: str = "known_signal_easy") -> PharmaObservation:
29
  observation = self._env.reset(task_id=task_id)
30
- self._last_state = {
31
- "episode_id": task_id,
32
- "step_count": 0,
33
- }
34
  return PharmaObservation(
35
  task_id=observation.task_id,
36
- reports=observation.reports,
37
  drug_interaction_db=observation.drug_interaction_db,
38
  step_number=observation.step_number,
39
  max_steps=observation.max_steps,
@@ -43,15 +48,18 @@ class OpenEnvPharmaAdapter:
43
  metadata={"difficulty": self._env.current_task.difficulty if self._env.current_task else None},
44
  )
45
 
 
 
 
46
  def step(self, action: PharmaAction) -> PharmaObservation:
47
  observation, reward, done, info = self._env.step(action)
48
- self._last_state = {
49
- "episode_id": observation.task_id,
50
- "step_count": observation.step_number,
51
- }
52
  return PharmaObservation(
53
  task_id=observation.task_id,
54
- reports=observation.reports,
55
  drug_interaction_db=observation.drug_interaction_db,
56
  step_number=observation.step_number,
57
  max_steps=observation.max_steps,
@@ -61,8 +69,11 @@ class OpenEnvPharmaAdapter:
61
  metadata=info,
62
  )
63
 
 
 
 
64
  @property
65
- def state(self) -> dict:
66
  return self._last_state
67
 
68
  def close(self) -> None:
 
1
  from fastapi import FastAPI
2
  from openenv.core.env_server import create_web_interface_app
3
+ from openenv.core.env_server.types import State
4
 
5
  try:
6
  from ..env import PharmaVigilanceEnv
 
21
 
22
  def __init__(self) -> None:
23
  self._env = PharmaVigilanceEnv()
24
+ self._last_state = State(episode_id=None, step_count=0)
25
+
26
+ @staticmethod
27
+ def _normalize_reports(reports):
28
+ normalized = []
29
+ for report in reports:
30
+ if hasattr(report, "model_dump"):
31
+ normalized.append(report.model_dump())
32
+ else:
33
+ normalized.append(report)
34
+ return normalized
35
 
36
  def reset(self, task_id: str = "known_signal_easy") -> PharmaObservation:
37
  observation = self._env.reset(task_id=task_id)
38
+ self._last_state = State(episode_id=task_id, step_count=0)
 
 
 
39
  return PharmaObservation(
40
  task_id=observation.task_id,
41
+ reports=self._normalize_reports(observation.reports),
42
  drug_interaction_db=observation.drug_interaction_db,
43
  step_number=observation.step_number,
44
  max_steps=observation.max_steps,
 
48
  metadata={"difficulty": self._env.current_task.difficulty if self._env.current_task else None},
49
  )
50
 
51
+ async def reset_async(self, task_id: str = "known_signal_easy") -> PharmaObservation:
52
+ return self.reset(task_id=task_id)
53
+
54
  def step(self, action: PharmaAction) -> PharmaObservation:
55
  observation, reward, done, info = self._env.step(action)
56
+ self._last_state = State(
57
+ episode_id=observation.task_id,
58
+ step_count=observation.step_number,
59
+ )
60
  return PharmaObservation(
61
  task_id=observation.task_id,
62
+ reports=self._normalize_reports(observation.reports),
63
  drug_interaction_db=observation.drug_interaction_db,
64
  step_number=observation.step_number,
65
  max_steps=observation.max_steps,
 
69
  metadata=info,
70
  )
71
 
72
+ async def step_async(self, action: PharmaAction) -> PharmaObservation:
73
+ return self.step(action)
74
+
75
  @property
76
+ def state(self) -> State:
77
  return self._last_state
78
 
79
  def close(self) -> None: