Spaces:
Sleeping
Sleeping
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the BSD-style license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """Pharmacovigilance Signal Detector Environment Client.""" | |
| from typing import Dict | |
| from openenv.core import EnvClient | |
| from openenv.core.client_types import StepResult | |
| from openenv.core.env_server.types import State | |
| try: | |
| from .env import Action, Observation, AdverseEventReport | |
| except ImportError: | |
| from env import Action, Observation, AdverseEventReport | |
class PharmaVigilanceEnvClient(EnvClient[Action, Observation, State]):
    """
    Typed client for the Pharmacovigilance Signal Detector environment.

    The class specialises ``EnvClient`` for this environment's ``Action``,
    ``Observation`` and ``State`` types. It only defines the conversions
    between those typed models and the plain dictionaries exchanged with the
    server; everything else (connecting, ``reset``, ``step``, ``close``, ...)
    is not implemented here and is expected to come from ``EnvClient``.

    Example:
        >>> with PharmaVigilanceEnvClient(base_url="http://localhost:7860") as env:
        ...     result = env.reset(task_id="known_signal_easy")
        ...     print(result.observation.task_id)
        ...
        ...     action = Action(
        ...         classification="known_side_effect",
        ...         suspect_drug="Ibuprofen",
        ...         severity_assessment="moderate",
        ...         recommended_action="log_and_monitor",
        ...         reasoning="GI bleeding is a known ibuprofen adverse effect.",
        ...     )
        ...     result = env.step(action)
        ...     print(result.observation.feedback)
        ...     print(result.reward)

    Example with Docker:
        >>> client = PharmaVigilanceEnvClient.from_docker_image("pharmacovigilance-env:latest")
        >>> try:
        ...     result = client.reset(task_id="cluster_signal_medium")
        ...     action = Action(
        ...         classification="new_signal",
        ...         suspect_drug="Gliptozin",
        ...         severity_assessment="severe",
        ...         recommended_action="escalate",
        ...         reasoning="Clustered vision loss on a new drug warrants escalation.",
        ...     )
        ...     result = client.step(action)
        ... finally:
        ...     client.close()
    """

    def _step_payload(self, action: Action) -> Dict:
        """
        Build the dictionary that represents ``action`` on the wire.

        Args:
            action: The agent's typed action.

        Returns:
            A dictionary holding the action's six fields under keys of the
            same name, in the order listed below.
        """
        # Attribute name and payload key are identical for every field, so a
        # single tuple drives both the lookup and the key order.
        field_names = (
            "classification",
            "suspect_drug",
            "severity_assessment",
            "recommended_action",
            "reasoning",
            "confidence",
        )
        return {name: getattr(action, name) for name in field_names}

    def _parse_result(self, payload: Dict) -> StepResult[Observation]:
        """
        Turn a server response dictionary into a ``StepResult[Observation]``.

        Missing keys fall back to defaults: an empty observation mapping, no
        reports, an empty interaction database, step 0 of 1, no feedback,
        a reward of 0.0 and ``done=False``.

        Args:
            payload: Decoded response dictionary from the environment server.

        Returns:
            StepResult carrying the typed observation, the reward and the
            done flag.
        """
        obs_fields = payload.get("observation", {})

        # Each raw report mapping is expanded into AdverseEventReport kwargs.
        parsed_reports = []
        for raw_report in obs_fields.get("reports", []):
            parsed_reports.append(AdverseEventReport(**raw_report))

        typed_observation = Observation(
            task_id=obs_fields.get("task_id", ""),
            reports=parsed_reports,
            drug_interaction_db=obs_fields.get("drug_interaction_db", {}),
            step_number=obs_fields.get("step_number", 0),
            max_steps=obs_fields.get("max_steps", 1),
            feedback=obs_fields.get("feedback"),
        )

        # The reward may arrive either as a mapping (its "total" entry is
        # used) or as a bare value, which is passed through unchanged.
        raw_reward = payload.get("reward", 0.0)
        if isinstance(raw_reward, dict):
            reward_value = raw_reward.get("total", 0.0)
        else:
            reward_value = raw_reward

        is_done = payload.get("done", False)
        return StepResult(
            observation=typed_observation,
            reward=reward_value,
            done=is_done,
        )

    def _parse_state(self, payload: Dict) -> State:
        """
        Turn a state response dictionary into an OpenEnv ``State``.

        Args:
            payload: Decoded response dictionary describing the current state.

        Returns:
            State whose ``episode_id`` is the payload's ``task_id`` (or
            ``"pharma-vigilance"`` when absent) and whose ``step_count`` is
            the payload's ``step_number`` (or 0 when absent).
        """
        episode = payload.get("task_id", "pharma-vigilance")
        steps_taken = payload.get("step_number", 0)
        return State(episode_id=episode, step_count=steps_taken)