File size: 3,427 Bytes
b837a03
 
a612f9c
b837a03
4bde500
 
 
 
 
 
b837a03
 
 
 
 
 
 
 
 
 
 
dfc0f77
 
 
b837a03
dfc0f77
 
a612f9c
 
 
 
 
 
 
 
 
 
b837a03
 
 
a612f9c
dfc0f77
b837a03
 
a612f9c
b837a03
 
 
 
 
 
 
 
 
a612f9c
 
 
b837a03
 
a612f9c
 
 
 
dfc0f77
b837a03
 
a612f9c
b837a03
 
 
 
 
 
 
 
 
a612f9c
 
 
b837a03
a612f9c
b837a03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
from fastapi import FastAPI
from openenv.core.env_server import create_web_interface_app
from openenv.core.env_server.types import State

try:
    from ..env import PharmaVigilanceEnv
    from ..models import PharmaAction, PharmaObservation
except ImportError:
    from env import PharmaVigilanceEnv
    from models import PharmaAction, PharmaObservation


# Task identifiers advertised by the ``/tasks`` endpoint; the first entry is also
# the default ``task_id`` of ``OpenEnvPharmaAdapter.reset``. Names carry their
# difficulty as a suffix (easy / medium / hard).
TASK_IDS = [
    "known_signal_easy",
    "cluster_signal_medium",
    "confounded_hard",
]


class OpenEnvPharmaAdapter:
    """
    Thin adapter that exposes the local environment through the interface
    expected by OpenEnv's HTTP server and web playground helpers.

    The wrapped ``PharmaVigilanceEnv`` and the most recent ``State`` live on
    the class, so every adapter instance drives -- and reports -- the same
    episode no matter which instance handled the previous call.
    """

    # Single environment shared by all adapter instances; created once when
    # this class body executes (i.e. when the module is imported).
    _shared_env = PharmaVigilanceEnv()
    # Latest episode state, overwritten by reset()/step() on any instance.
    _shared_state = State(episode_id=None, step_count=0)

    def __init__(self) -> None:
        cls = type(self)
        self._env = cls._shared_env
        # Kept for backward compatibility; the ``state`` property reads the
        # shared class slot so it can never go stale.
        self._last_state = cls._shared_state

    @staticmethod
    def _normalize_reports(reports):
        """
        Return ``reports`` as a new list of plain objects.

        Items exposing ``model_dump()`` (presumably pydantic models) are
        converted with it; any other item is passed through unchanged.
        """
        return [
            report.model_dump() if hasattr(report, "model_dump") else report
            for report in reports
        ]

    def _record_state(self, episode_id, step_count) -> None:
        """Store a fresh ``State`` on this instance and in the shared class slot."""
        new_state = State(episode_id=episode_id, step_count=step_count)
        self._last_state = new_state
        type(self)._shared_state = new_state

    def _build_observation(self, observation, *, reward, done, metadata) -> PharmaObservation:
        """
        Convert an observation returned by ``PharmaVigilanceEnv`` into the
        ``PharmaObservation`` model served over HTTP.

        ``reward``, ``done`` and ``metadata`` are supplied by the caller
        because ``reset`` and ``step`` source them differently.
        """
        return PharmaObservation(
            task_id=observation.task_id,
            reports=self._normalize_reports(observation.reports),
            drug_interaction_db=observation.drug_interaction_db,
            step_number=observation.step_number,
            max_steps=observation.max_steps,
            feedback=observation.feedback,
            reward=reward,
            done=done,
            metadata=metadata,
        )

    def reset(self, task_id: str = "known_signal_easy") -> PharmaObservation:
        """
        Start a new episode of ``task_id`` on the shared environment.

        Returns the initial observation with ``reward=0.0``, ``done=False``
        and ``metadata={"difficulty": ...}``, where the difficulty is ``None``
        when the environment reports no current task.
        """
        observation = self._env.reset(task_id=task_id)
        # NOTE(review): episode_id reuses the task id, so two resets of the
        # same task are indistinguishable by episode_id.
        self._record_state(episode_id=task_id, step_count=0)
        current_task = self._env.current_task
        difficulty = current_task.difficulty if current_task else None
        return self._build_observation(
            observation,
            reward=0.0,
            done=False,
            metadata={"difficulty": difficulty},
        )

    async def reset_async(self, task_id: str = "known_signal_easy") -> PharmaObservation:
        """Async entry point; runs ``reset`` synchronously on the event loop."""
        return self.reset(task_id=task_id)

    def step(self, action: PharmaAction) -> PharmaObservation:
        """
        Apply ``action`` to the shared environment.

        The environment's reward object is flattened to its ``total`` and its
        ``info`` value is forwarded unchanged as ``metadata``.
        """
        observation, reward, done, info = self._env.step(action)
        self._record_state(
            episode_id=observation.task_id,
            step_count=observation.step_number,
        )
        return self._build_observation(
            observation,
            reward=reward.total,
            done=done,
            metadata=info,
        )

    async def step_async(self, action: PharmaAction) -> PharmaObservation:
        """Async entry point; runs ``step`` synchronously on the event loop."""
        return self.step(action)

    @property
    def state(self) -> State:
        """
        Latest episode state recorded by *any* adapter instance.

        Read from the shared class slot rather than ``self._last_state`` so a
        long-lived instance does not report stale state after another
        instance has reset or stepped the shared environment.
        """
        return type(self)._shared_state

    def close(self) -> None:
        """No-op: the shared environment stays alive for other adapter instances."""
        return None


# FastAPI application built by OpenEnv's web-interface helper from the adapter
# and the action/observation models. The adapter *class* (not an instance) is
# passed, so the helper may construct adapters on demand; all of them share one
# PharmaVigilanceEnv via OpenEnvPharmaAdapter's class-level state.
app: FastAPI = create_web_interface_app(
    OpenEnvPharmaAdapter,
    PharmaAction,
    PharmaObservation,
    env_name="pharma_vigilance_env",
)


@app.get("/tasks")
def list_tasks():
    """Return the identifiers of every task the environment can be reset into."""
    payload = {"tasks": TASK_IDS}
    return payload


def main(host: str = "0.0.0.0", port: int = 7860) -> None:
    """Serve the module-level ``app`` with uvicorn on ``host``:``port``."""
    # Imported lazily so merely importing this module does not require uvicorn.
    from uvicorn import run as serve_app

    serve_app(app, host=host, port=port)


if __name__ == "__main__":
    main()