Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI | |
| from openenv.core.env_server import create_web_interface_app | |
| from openenv.core.env_server.types import State | |
| try: | |
| from ..env import PharmaVigilanceEnv | |
| from ..models import PharmaAction, PharmaObservation | |
| except ImportError: | |
| from env import PharmaVigilanceEnv | |
| from models import PharmaAction, PharmaObservation | |
# Task identifiers returned by ``list_tasks``; the first entry is also the
# default ``task_id`` used by the adapter's ``reset``.
TASK_IDS = [
    "known_signal_easy",
    "cluster_signal_medium",
    "confounded_hard",
]
class OpenEnvPharmaAdapter:
    """
    Thin adapter that exposes the local environment through the interface
    expected by OpenEnv's HTTP server and web playground helpers.

    Every adapter instance shares a single ``PharmaVigilanceEnv`` and a single
    ``State`` snapshot through class attributes, so an episode started through
    one instance is visible to all other instances.
    """

    # Created once, at class-definition time, and shared by every instance.
    _shared_env = PharmaVigilanceEnv()
    _shared_state = State(episode_id=None, step_count=0)

    def __init__(self) -> None:
        """Bind this instance to the shared environment and state snapshot."""
        self._env = self.__class__._shared_env
        self._last_state = self.__class__._shared_state

    @staticmethod
    def _normalize_reports(reports):
        """Return ``reports`` as a new list of plain values.

        Items exposing ``model_dump`` are replaced by the result of calling
        it; every other item is passed through unchanged.

        Declared as a static method because it is invoked as
        ``self._normalize_reports(...)`` and takes no ``self`` parameter;
        without the decorator that call raises ``TypeError``.
        """
        return [
            report.model_dump() if hasattr(report, "model_dump") else report
            for report in reports
        ]

    def _record_state(self, episode_id, step_count) -> None:
        """Store a fresh ``State`` on this instance and on the class."""
        self._last_state = State(episode_id=episode_id, step_count=step_count)
        # Publish on the class so adapters created later start from it.
        self.__class__._shared_state = self._last_state

    def _to_observation(self, observation, reward, done, metadata) -> PharmaObservation:
        """Copy an environment observation into a ``PharmaObservation``.

        ``reward``, ``done`` and ``metadata`` are supplied by the caller; the
        remaining fields are copied from ``observation``, with its reports
        passed through ``_normalize_reports``.
        """
        return PharmaObservation(
            task_id=observation.task_id,
            reports=self._normalize_reports(observation.reports),
            drug_interaction_db=observation.drug_interaction_db,
            step_number=observation.step_number,
            max_steps=observation.max_steps,
            feedback=observation.feedback,
            reward=reward,
            done=done,
            metadata=metadata,
        )

    def reset(self, task_id: str = "known_signal_easy") -> PharmaObservation:
        """Reset the shared environment to ``task_id`` and return the first observation.

        The stored state is replaced with ``episode_id=task_id`` and
        ``step_count=0``. The returned observation has ``reward=0.0``,
        ``done=False`` and a ``metadata`` dict whose ``"difficulty"`` entry is
        the current task's difficulty, or ``None`` when the environment has
        no current task.
        """
        observation = self._env.reset(task_id=task_id)
        self._record_state(task_id, 0)
        current_task = self._env.current_task
        difficulty = current_task.difficulty if current_task else None
        return self._to_observation(
            observation,
            reward=0.0,
            done=False,
            metadata={"difficulty": difficulty},
        )

    async def reset_async(self, task_id: str = "known_signal_easy") -> PharmaObservation:
        """Async wrapper that calls the synchronous ``reset``."""
        return self.reset(task_id=task_id)

    def step(self, action: PharmaAction) -> PharmaObservation:
        """Apply ``action`` to the shared environment and return the result.

        The environment's ``step`` returns ``(observation, reward, done,
        info)``. The stored state is updated from the observation's
        ``task_id`` and ``step_number``; the returned observation carries
        ``reward.total`` as its reward and ``info`` as its metadata.
        """
        observation, reward, done, info = self._env.step(action)
        self._record_state(observation.task_id, observation.step_number)
        return self._to_observation(
            observation,
            reward=reward.total,
            done=done,
            metadata=info,
        )

    async def step_async(self, action: PharmaAction) -> PharmaObservation:
        """Async wrapper that calls the synchronous ``step``."""
        return self.step(action)

    # NOTE(review): exposed as a plain method here; confirm whether the
    # OpenEnv server expects ``state`` to be a property instead.
    def state(self) -> State:
        """Return the most recent state recorded for the shared environment.

        Reads the class-level snapshot rather than the per-instance copy so
        that an instance created before another adapter advanced the shared
        environment does not report a stale state.
        """
        return self.__class__._shared_state

    def close(self) -> None:
        """Do nothing; the shared environment is left as it is."""
        return None
# Build the FastAPI application from the adapter class and the action /
# observation models, registered under the name "pharma_vigilance_env".
app: FastAPI = create_web_interface_app(
    OpenEnvPharmaAdapter, PharmaAction, PharmaObservation,
    env_name="pharma_vigilance_env",
)
# NOTE(review): no route decorator is visible on this function; confirm it is
# registered with ``app`` somewhere if it is meant to be an HTTP endpoint.
def list_tasks():
    """Return a dict whose ``"tasks"`` entry is the module's ``TASK_IDS`` list."""
    payload = {"tasks": TASK_IDS}
    return payload
def main(host: str = "0.0.0.0", port: int = 7860) -> None:
    """Serve ``app`` with uvicorn on ``host``:``port``.

    uvicorn is imported inside the function, so importing this module does
    not require it.
    """
    from uvicorn import run as serve

    serve(app, host=host, port=port)


# Start the server with the default host and port when run as a script.
if __name__ == "__main__":
    main()