File size: 4,643 Bytes
f2beac3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Pharmacovigilance Signal Detector Environment Client."""

from typing import Dict

from openenv.core import EnvClient
from openenv.core.client_types import StepResult
from openenv.core.env_server.types import State

try:
    from .env import Action, Observation, AdverseEventReport
except ImportError:
    from env import Action, Observation, AdverseEventReport


class PharmaVigilanceEnvClient(
    EnvClient[Action, Observation, State]
):
    """
    Client for the Pharmacovigilance Signal Detector environment.

    This client maintains a persistent connection to the environment server and
    parses server responses into strongly-typed observation models.

    The examples below need a running environment server (or Docker) and are
    illustrative.

    Example:
        >>> with PharmaVigilanceEnvClient(base_url="http://localhost:7860") as env:
        ...     result = env.reset(task_id="known_signal_easy")
        ...     print(result.observation.task_id)
        ...
        ...     action = Action(
        ...         classification="known_side_effect",
        ...         suspect_drug="Ibuprofen",
        ...         severity_assessment="moderate",
        ...         recommended_action="log_and_monitor",
        ...         reasoning="GI bleeding is a known ibuprofen adverse effect.",
        ...     )
        ...     result = env.step(action)
        ...     print(result.observation.feedback)
        ...     print(result.reward)

    Example with Docker:
        >>> client = PharmaVigilanceEnvClient.from_docker_image("pharmacovigilance-env:latest")
        >>> try:
        ...     result = client.reset(task_id="cluster_signal_medium")
        ...     action = Action(
        ...         classification="new_signal",
        ...         suspect_drug="Gliptozin",
        ...         severity_assessment="severe",
        ...         recommended_action="escalate",
        ...         reasoning="Clustered vision loss on a new drug warrants escalation.",
        ...     )
        ...     result = client.step(action)
        ... finally:
        ...     client.close()
    """

    def _step_payload(self, action: Action) -> Dict:
        """
        Convert an Action model into the JSON payload sent to /step.

        Args:
            action: Typed agent action.

        Returns:
            Dictionary holding the five action fields (classification,
            suspect_drug, severity_assessment, recommended_action, reasoning)
            keyed by field name, suitable for JSON transport.
        """
        return {
            "classification": action.classification,
            "suspect_drug": action.suspect_drug,
            "severity_assessment": action.severity_assessment,
            "recommended_action": action.recommended_action,
            "reasoning": action.reasoning,
        }

    def _parse_result(self, payload: Dict) -> StepResult[Observation]:
        """
        Parse a server /step response into StepResult[Observation].

        Missing keys fall back to defaults. The container fields
        ``observation``, ``reports`` and ``drug_interaction_db`` are also
        treated as empty when the server sends them as JSON ``null``, instead
        of raising.

        Args:
            payload: JSON response from the environment server.

        Returns:
            StepResult containing the typed observation, reward, and done flag.
            The reward may arrive either as a plain number or as a dict; in the
            dict case its ``"total"`` entry (default 0.0) is used.
        """
        # ``or`` rather than a .get() default, so a key that is present but
        # null is normalized too (a .get default only covers missing keys).
        obs_data = payload.get("observation") or {}
        reports = [
            AdverseEventReport(**report)
            for report in obs_data.get("reports") or []
        ]

        observation = Observation(
            task_id=obs_data.get("task_id", ""),
            reports=reports,
            drug_interaction_db=obs_data.get("drug_interaction_db") or {},
            step_number=obs_data.get("step_number", 0),
            # Plain .get default here: "or" would wrongly rewrite a legitimate 0.
            max_steps=obs_data.get("max_steps", 1),
            feedback=obs_data.get("feedback"),
        )

        # The reward is accepted either as a scalar or as a dict carrying a
        # "total" entry; normalize to the scalar for StepResult.
        reward_payload = payload.get("reward", 0.0)
        if isinstance(reward_payload, dict):
            reward_total = reward_payload.get("total", 0.0)
        else:
            reward_total = reward_payload

        return StepResult(
            observation=observation,
            reward=reward_total,
            done=payload.get("done", False),
        )

    def _parse_state(self, payload: Dict) -> State:
        """
        Parse the /state response into an OpenEnv State object.

        Args:
            payload: JSON response from the state endpoint.

        Returns:
            State whose ``episode_id`` is the payload's ``task_id`` (default
            ``"pharma-vigilance"``) and whose ``step_count`` is the payload's
            ``step_number`` (default 0).
        """
        # NOTE(review): episode_id is derived from task_id, so repeated
        # episodes of the same task share an id — confirm this is intended.
        return State(
            episode_id=payload.get("task_id", "pharma-vigilance"),
            step_count=payload.get("step_number", 0),
        )