# polyguard-openenv: app/training/openenv_wrapper.py
# Committed by: TheJackBright
# Commit 877add7 (verified): "Deploy PolyGuard OpenEnv Space"
"""OpenEnv-compatible wrapper around the local PolyGuard env service.

The wrapper intentionally exposes meaningful clinician-facing tool methods for
LLM policy training instead of a single opaque ``step(action)`` interface.
"""
from __future__ import annotations
from typing import Any, Literal
from app.env.client import PolyGuardEnvClient
# ``openenv`` is optional: without it (or if importing it fails for any
# reason) the wrapper below falls back to the plain HTTP client.
try:
    from openenv import GenericEnvClient
except Exception:  # noqa: BLE001
    GenericEnvClient = None  # type: ignore[assignment]
class LocalOpenEnvWrapper:
    """Tool-style facade over the PolyGuard environment service.

    Episode control (``reset`` / ``step`` / ``state``) goes through an OpenEnv
    ``GenericEnvClient`` sync session when ``openenv`` is importable and the
    connection succeeds; otherwise it falls back to ``PolyGuardEnvClient``.
    Diagnostic queries (trace, legal actions, reward breakdown, uncertainty)
    always use ``PolyGuardEnvClient``.
    """

    # Maps an ``adjust_dose`` direction to its (action_type, dose_bucket).
    _DOSE_DIRECTIONS = {
        "increase": ("INCREASE_DOSE_BUCKET", "HIGH"),
        "reduce": ("REDUCE_DOSE_BUCKET", "LOW"),
        "hold": ("DOSE_HOLD", "HOLD"),
    }

    def __init__(self, base_url: str = "http://127.0.0.1:8100") -> None:
        """Create the clients for ``base_url``.

        Args:
            base_url: Root URL of the env service, shared by both clients.
        """
        self.http_client = PolyGuardEnvClient(base_url=base_url)
        self.base_url = base_url
        self._sync_client: Any = None
        if GenericEnvClient is None:
            return
        client: Any = None
        try:
            client = GenericEnvClient(base_url=base_url).sync()
            client.connect()
        except Exception:  # noqa: BLE001
            # Best effort: any failure means "use HTTP". If the client was
            # built but connect() failed, release it instead of leaking it.
            if client is not None:
                self._close_quietly(client)
            return
        self._sync_client = client

    @staticmethod
    def _close_quietly(client: Any) -> None:
        """Call ``client.close()``, ignoring any exception it raises."""
        try:
            client.close()
        except Exception:  # noqa: BLE001
            pass

    @staticmethod
    def _as_result_dict(result: Any) -> dict[str, Any]:
        """Flatten a sync-client reset/step result into an observation/reward/done dict."""
        return {
            "observation": result.observation,
            "reward": result.reward,
            "done": result.done,
        }

    def reset(self, **kwargs: Any) -> dict[str, Any]:
        """Start a new episode; ``kwargs`` are forwarded unchanged to the active client."""
        if self._sync_client is not None:
            return self._as_result_dict(self._sync_client.reset(**kwargs))
        return self.http_client.reset(**kwargs)

    def step(self, action: dict[str, Any]) -> dict[str, Any]:
        """Send one action payload to the active client and return its result."""
        if self._sync_client is not None:
            return self._as_result_dict(self._sync_client.step(action))
        return self.http_client.step(action)

    def state(self) -> dict[str, Any]:
        """Return the current environment state from the active client."""
        # NOTE(review): unlike reset/step, the sync client's state() result is
        # returned as-is; inspect_regimen calls .get() on it, so it is assumed
        # to be dict-like — confirm against the openenv client.
        if self._sync_client is not None:
            return self._sync_client.state()
        return self.http_client.state()

    def trace(self) -> list[dict[str, Any]]:
        """Return the episode trace from the HTTP client."""
        return self.http_client.trace()

    def legal_actions(self) -> list[dict[str, Any]]:
        """Return the candidate actions reported by the HTTP client."""
        return self.http_client.legal_actions()

    def reward_breakdown(self) -> dict[str, Any]:
        """Return the reward breakdown reported by the HTTP client."""
        return self.http_client.reward_breakdown()

    def uncertainty(self) -> dict[str, Any]:
        """Return the uncertainty report from the HTTP client."""
        return self.http_client.uncertainty()

    def inspect_regimen(self) -> dict[str, Any]:
        """Return a compact clinical snapshot of the active case.

        Missing (or ``None``) ``patient`` / ``medications`` entries are treated
        as empty, so the snapshot never fails on a sparse state.
        """
        state = self.state()
        patient = state.get("patient") or {}
        risk_summary = state.get("risk_summary", {})
        meds = patient.get("medications") or []
        return {
            "patient_id": patient.get("patient_id"),
            "age": patient.get("age"),
            "comorbidities": patient.get("comorbidities", []),
            "medication_count": len(meds),
            "medications": meds,
            "risk_summary": risk_summary,
            "burden_score": state.get("burden_score"),
            "step_count": state.get("step_count"),
            "max_steps": state.get("max_steps"),
        }

    def evaluate_candidate(self, candidate_id: str) -> dict[str, Any]:
        """Look up a legal candidate action by ``candidate_id``.

        Returns the first matching candidate dict exactly as ``legal_actions``
        reported it (it has no ``"found"`` key). When nothing matches, returns
        ``{"candidate_id": candidate_id, "found": False}``.
        """
        return next(
            (c for c in self.legal_actions() if c.get("candidate_id") == candidate_id),
            {"candidate_id": candidate_id, "found": False},
        )

    def _execute_action(
        self,
        mode: str,
        action_type: str,
        target_drug: str | None = None,
        replacement_drug: str | None = None,
        dose_bucket: str = "NA",
        taper_days: int | None = None,
        monitoring_plan: str | None = None,
        candidate_id: str = "cand_manual",
        confidence: float = 0.65,
        rationale_brief: str = "tool_action",
    ) -> dict[str, Any]:
        """Build the full action payload from the arguments and submit it via ``step``."""
        payload = {
            "mode": mode,
            "action_type": action_type,
            "target_drug": target_drug,
            "replacement_drug": replacement_drug,
            "dose_bucket": dose_bucket,
            "taper_days": taper_days,
            "monitoring_plan": monitoring_plan,
            "candidate_id": candidate_id,
            "confidence": confidence,
            "rationale_brief": rationale_brief,
        }
        return self.step(payload)

    def stop_drug(
        self,
        target_drug: str,
        taper_days: int | None = None,
        candidate_id: str = "cand_stop_tool",
    ) -> dict[str, Any]:
        """Issue a STOP_DRUG action for a single medication."""
        return self._execute_action(
            mode="REGIMEN_OPT",
            action_type="STOP_DRUG",
            target_drug=target_drug,
            taper_days=taper_days,
            candidate_id=candidate_id,
            rationale_brief=f"stop_drug:{target_drug}",
        )

    def substitute_drug(
        self,
        target_drug: str,
        replacement_drug: str,
        candidate_id: str = "cand_substitute_tool",
    ) -> dict[str, Any]:
        """Issue a SUBSTITUTE_WITHIN_CLASS action replacing ``target_drug``."""
        return self._execute_action(
            mode="REGIMEN_OPT",
            action_type="SUBSTITUTE_WITHIN_CLASS",
            target_drug=target_drug,
            replacement_drug=replacement_drug,
            candidate_id=candidate_id,
            rationale_brief=f"substitute:{target_drug}->{replacement_drug}",
        )

    def start_taper(
        self,
        target_drug: str,
        taper_days: int = 14,
        candidate_id: str = "cand_taper_start_tool",
    ) -> dict[str, Any]:
        """Issue a TAPER_INITIATE action."""
        return self._execute_action(
            mode="REGIMEN_OPT",
            action_type="TAPER_INITIATE",
            target_drug=target_drug,
            taper_days=taper_days,
            candidate_id=candidate_id,
            rationale_brief=f"taper_start:{target_drug}",
        )

    def continue_taper(
        self,
        target_drug: str,
        taper_days: int = 7,
        candidate_id: str = "cand_taper_continue_tool",
    ) -> dict[str, Any]:
        """Issue a TAPER_CONTINUE action."""
        return self._execute_action(
            mode="REGIMEN_OPT",
            action_type="TAPER_CONTINUE",
            target_drug=target_drug,
            taper_days=taper_days,
            candidate_id=candidate_id,
            rationale_brief=f"taper_continue:{target_drug}",
        )

    def adjust_dose(
        self,
        target_drug: str,
        direction: Literal["increase", "reduce", "hold"],
        candidate_id: str = "cand_adjust_dose_tool",
    ) -> dict[str, Any]:
        """Adjust the dose bucket of ``target_drug`` in an explicit direction.

        Raises:
            ValueError: if ``direction`` is not "increase", "reduce" or "hold".
                Unknown values used to be silently treated as "hold", which
                submits a different clinical action than the caller asked for.
        """
        try:
            action_type, dose_bucket = self._DOSE_DIRECTIONS[direction]
        except KeyError:
            raise ValueError(
                f"direction must be one of {sorted(self._DOSE_DIRECTIONS)}, got {direction!r}"
            ) from None
        return self._execute_action(
            mode="DOSE_OPT",
            action_type=action_type,
            target_drug=target_drug,
            dose_bucket=dose_bucket,
            candidate_id=candidate_id,
            rationale_brief=f"adjust_dose:{direction}:{target_drug}",
        )

    def request_review(
        self,
        review_type: Literal["pharmacist", "specialist"] = "specialist",
        candidate_id: str = "cand_review_tool",
    ) -> dict[str, Any]:
        """Request human review when uncertainty or legality concerns are high.

        Any ``review_type`` other than "pharmacist" requests a specialist review.
        """
        action_type = (
            "REQUEST_PHARMACIST_REVIEW"
            if review_type == "pharmacist"
            else "REQUEST_SPECIALIST_REVIEW"
        )
        return self._execute_action(
            mode="ABSTAIN_REVIEW",
            action_type=action_type,
            candidate_id=candidate_id,
            rationale_brief=f"request_review:{review_type}",
        )

    def finish_case(self, candidate_id: str = "cand_finish_tool") -> dict[str, Any]:
        """Close the episode with a conservative KEEP_REGIMEN action."""
        return self._execute_action(
            mode="REGIMEN_OPT",
            action_type="KEEP_REGIMEN",
            candidate_id=candidate_id,
            rationale_brief="finish_case",
        )

    def close(self) -> None:
        """Close the sync client, if any; safe to call more than once.

        The client is detached before closing, so a second call is a no-op.
        Later reset/step/state calls use the HTTP client instead of a closed
        session.
        """
        client, self._sync_client = self._sync_client, None
        if client is not None:
            self._close_quietly(client)