Spaces:
Sleeping
Sleeping
File size: 4,643 Bytes
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Pharmacovigilance Signal Detector Environment Client."""
from typing import Dict
from openenv.core import EnvClient
from openenv.core.client_types import StepResult
from openenv.core.env_server.types import State
try:
from .env import Action, Observation, AdverseEventReport
except ImportError:
from env import Action, Observation, AdverseEventReport
class PharmaVigilanceEnvClient(
    EnvClient[Action, Observation, State]
):
    """
    Client for the Pharmacovigilance Signal Detector environment.

    This client maintains a persistent connection to the environment server and
    parses server responses into strongly-typed observation models.

    Example:
        >>> with PharmaVigilanceEnvClient(base_url="http://localhost:7860") as env:
        ...     result = env.reset(task_id="known_signal_easy")
        ...     print(result.observation.task_id)
        ...
        ...     action = Action(
        ...         classification="known_side_effect",
        ...         suspect_drug="Ibuprofen",
        ...         severity_assessment="moderate",
        ...         recommended_action="log_and_monitor",
        ...         reasoning="GI bleeding is a known ibuprofen adverse effect.",
        ...     )
        ...     result = env.step(action)
        ...     print(result.observation.feedback)
        ...     print(result.reward)

    Example with Docker:
        >>> client = PharmaVigilanceEnvClient.from_docker_image("pharmacovigilance-env:latest")
        >>> try:
        ...     result = client.reset(task_id="cluster_signal_medium")
        ...     action = Action(
        ...         classification="new_signal",
        ...         suspect_drug="Gliptozin",
        ...         severity_assessment="severe",
        ...         recommended_action="escalate",
        ...         reasoning="Clustered vision loss on a new drug warrants escalation.",
        ...     )
        ...     result = client.step(action)
        ... finally:
        ...     client.close()
    """

    def _step_payload(self, action: Action) -> Dict:
        """
        Convert an Action model into the JSON payload sent to /step.

        Args:
            action: Typed agent action.

        Returns:
            Dictionary with the five action fields, suitable for JSON transport.
        """
        return {
            "classification": action.classification,
            "suspect_drug": action.suspect_drug,
            "severity_assessment": action.severity_assessment,
            "recommended_action": action.recommended_action,
            "reasoning": action.reasoning,
        }

    def _parse_result(self, payload: Dict) -> StepResult[Observation]:
        """
        Parse a server /step response into StepResult[Observation].

        Keys that are missing, or present but explicitly ``null``, fall back
        to empty defaults instead of raising.

        Args:
            payload: JSON response from the environment server.

        Returns:
            StepResult containing the typed observation, reward, and done flag.
        """
        # ``dict.get(key, default)`` only applies the default for a *missing*
        # key; ``or {}`` also covers an explicit ``"observation": null``.
        obs_data = payload.get("observation") or {}
        return StepResult(
            observation=self._observation_from_dict(obs_data),
            reward=self._reward_total(payload.get("reward", 0.0)),
            done=payload.get("done", False),
        )

    @staticmethod
    def _observation_from_dict(obs_data: Dict) -> Observation:
        """
        Build an Observation from the ``observation`` sub-dict of a response.

        Args:
            obs_data: The decoded ``observation`` object (may be empty).

        Returns:
            Observation with each entry of ``reports`` converted into an
            AdverseEventReport.
        """
        # ``or`` fallbacks guard against keys sent as JSON null, which would
        # otherwise make iteration fail or forward ``None`` into the model.
        reports = [
            AdverseEventReport(**report)
            for report in obs_data.get("reports") or []
        ]
        return Observation(
            task_id=obs_data.get("task_id", ""),
            reports=reports,
            drug_interaction_db=obs_data.get("drug_interaction_db") or {},
            step_number=obs_data.get("step_number", 0),
            max_steps=obs_data.get("max_steps", 1),
            feedback=obs_data.get("feedback"),
        )

    @staticmethod
    def _reward_total(reward_payload):
        """
        Reduce the response's ``reward`` field to the value exposed to callers.

        Args:
            reward_payload: Either a scalar reward or a dict breakdown.

        Returns:
            The dict's ``total`` entry (0.0 if absent) when given a dict;
            otherwise ``reward_payload`` unchanged.
        """
        # The server may send a per-component breakdown; only its ``total``
        # is surfaced on the StepResult.
        if isinstance(reward_payload, dict):
            return reward_payload.get("total", 0.0)
        return reward_payload

    def _parse_state(self, payload: Dict) -> State:
        """
        Parse the /state response into an OpenEnv State object.

        Args:
            payload: JSON response from the state endpoint.

        Returns:
            State whose ``episode_id`` is the payload's ``task_id`` (default
            ``"pharma-vigilance"``) and whose ``step_count`` is the payload's
            ``step_number`` (default 0).
        """
        # NOTE(review): reads ``task_id``/``step_number`` rather than
        # ``episode_id``/``step_count`` -- presumably this server's /state
        # response uses the observation-style keys; confirm against the server.
        return State(
            episode_id=payload.get("task_id", "pharma-vigilance"),
            step_count=payload.get("step_number", 0),
        )
|