| """OpenEnv-compatible wrapper around local env service. |
| |
| The wrapper intentionally exposes meaningful clinician-facing tool methods for |
| LLM policy training instead of a single opaque ``step(action)`` interface. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from typing import Any, Literal |
|
|
| from app.env.client import PolyGuardEnvClient |
|
|
| try: |
| from openenv import GenericEnvClient |
| except Exception: |
| GenericEnvClient = None |
|
|
|
|
class LocalOpenEnvWrapper:
    """Clinician-facing tool wrapper around the local PolyGuard env service.

    ``reset``/``step``/``state`` go through an OpenEnv ``GenericEnvClient``
    sync client when ``openenv`` was importable and connecting succeeded;
    otherwise they fall back to the HTTP ``PolyGuardEnvClient``. ``trace``,
    ``legal_actions``, ``reward_breakdown`` and ``uncertainty`` always use the
    HTTP client.

    The tool methods (``stop_drug``, ``adjust_dose``, ``request_review``, ...)
    assemble a complete action payload and submit it through :meth:`step`.
    """

    # adjust_dose direction -> (action_type, dose_bucket).
    _DOSE_ACTIONS: dict[str, tuple[str, str]] = {
        "increase": ("INCREASE_DOSE_BUCKET", "HIGH"),
        "reduce": ("REDUCE_DOSE_BUCKET", "LOW"),
        "hold": ("DOSE_HOLD", "HOLD"),
    }
    # request_review review_type -> action_type.
    _REVIEW_ACTIONS: dict[str, str] = {
        "pharmacist": "REQUEST_PHARMACIST_REVIEW",
        "specialist": "REQUEST_SPECIALIST_REVIEW",
    }

    def __init__(self, base_url: str = "http://127.0.0.1:8100") -> None:
        """Create the HTTP client and, when possible, a connected OpenEnv client.

        Args:
            base_url: Root URL of the environment service, used by both clients.
        """
        self.http_client = PolyGuardEnvClient(base_url=base_url)
        self.base_url = base_url
        # None means "use the HTTP fallback" for reset/step/state.
        self._sync_client: Any = None
        if GenericEnvClient is not None:
            self._sync_client = self._connect_sync_client(base_url)

    def _connect_sync_client(self, base_url: str) -> Any:
        """Return a connected OpenEnv sync client, or ``None`` if setup fails."""
        client: Any = None
        try:
            client = GenericEnvClient(base_url=base_url).sync()
            client.connect()
        except Exception:
            # Best effort: any failure just selects the HTTP fallback. A client
            # that was created but failed to connect is closed so it is not leaked.
            self._log_suppressed("OpenEnv sync client unavailable; using HTTP fallback")
            if client is not None:
                self._close_quietly(client)
            return None
        return client

    @staticmethod
    def _log_suppressed(message: str) -> None:
        """Log the exception currently being handled at DEBUG level.

        Intended to be called only from inside an ``except`` block.
        """
        import logging

        logging.getLogger(__name__).debug(message, exc_info=True)

    @classmethod
    def _close_quietly(cls, client: Any) -> None:
        """Call ``client.close()``, logging rather than raising on failure."""
        try:
            client.close()
        except Exception:
            cls._log_suppressed("Error while closing OpenEnv sync client")

    @staticmethod
    def _result_to_dict(result: Any) -> dict[str, Any]:
        """Flatten an OpenEnv reset/step result into an observation/reward/done dict."""
        return {
            "observation": result.observation,
            "reward": result.reward,
            "done": result.done,
        }

    def reset(self, **kwargs: Any) -> dict[str, Any]:
        """Start a new episode, forwarding ``kwargs`` unchanged to the backend.

        Returns:
            ``{"observation", "reward", "done"}`` when using the OpenEnv client,
            otherwise whatever ``PolyGuardEnvClient.reset`` returns.
        """
        if self._sync_client is not None:
            return self._result_to_dict(self._sync_client.reset(**kwargs))
        return self.http_client.reset(**kwargs)

    def step(self, action: dict[str, Any]) -> dict[str, Any]:
        """Submit one action payload.

        Returns:
            ``{"observation", "reward", "done"}`` when using the OpenEnv client,
            otherwise whatever ``PolyGuardEnvClient.step`` returns.
        """
        if self._sync_client is not None:
            return self._result_to_dict(self._sync_client.step(action))
        return self.http_client.step(action)

    def state(self) -> dict[str, Any]:
        """Return the current environment state from the active backend."""
        if self._sync_client is not None:
            return self._sync_client.state()
        return self.http_client.state()

    def trace(self) -> list[dict[str, Any]]:
        """Return the episode trace (always fetched over HTTP)."""
        return self.http_client.trace()

    def legal_actions(self) -> list[dict[str, Any]]:
        """Return the legal candidate actions (always fetched over HTTP)."""
        return self.http_client.legal_actions()

    def reward_breakdown(self) -> dict[str, Any]:
        """Return the reward breakdown (always fetched over HTTP)."""
        return self.http_client.reward_breakdown()

    def uncertainty(self) -> dict[str, Any]:
        """Return the uncertainty report (always fetched over HTTP)."""
        return self.http_client.uncertainty()

    def inspect_regimen(self) -> dict[str, Any]:
        """Return a compact clinical snapshot of the active case.

        Missing or ``None`` ``patient``/``medications``/``comorbidities``/
        ``risk_summary`` entries are treated as empty rather than crashing.
        """
        state = self.state()
        patient = state.get("patient") or {}
        meds = patient.get("medications") or []
        return {
            "patient_id": patient.get("patient_id"),
            "age": patient.get("age"),
            "comorbidities": patient.get("comorbidities") or [],
            "medication_count": len(meds),
            "medications": meds,
            "risk_summary": state.get("risk_summary") or {},
            "burden_score": state.get("burden_score"),
            "step_count": state.get("step_count"),
            "max_steps": state.get("max_steps"),
        }

    def evaluate_candidate(self, candidate_id: str) -> dict[str, Any]:
        """Look up a legal candidate action by ``candidate_id``.

        Returns:
            The matching candidate dict as returned by :meth:`legal_actions`, or
            ``{"candidate_id": candidate_id, "found": False}`` if none matches.
        """
        match = next(
            (c for c in self.legal_actions() if c.get("candidate_id") == candidate_id),
            None,
        )
        if match is not None:
            return match
        return {"candidate_id": candidate_id, "found": False}

    def _execute_action(
        self,
        mode: str,
        action_type: str,
        target_drug: str | None = None,
        replacement_drug: str | None = None,
        dose_bucket: str = "NA",
        taper_days: int | None = None,
        monitoring_plan: str | None = None,
        candidate_id: str = "cand_manual",
        confidence: float = 0.65,
        rationale_brief: str = "tool_action",
    ) -> dict[str, Any]:
        """Build a full action payload (every field present) and submit it via :meth:`step`."""
        payload = {
            "mode": mode,
            "action_type": action_type,
            "target_drug": target_drug,
            "replacement_drug": replacement_drug,
            "dose_bucket": dose_bucket,
            "taper_days": taper_days,
            "monitoring_plan": monitoring_plan,
            "candidate_id": candidate_id,
            "confidence": confidence,
            "rationale_brief": rationale_brief,
        }
        return self.step(payload)

    def stop_drug(self, target_drug: str, taper_days: int | None = None, candidate_id: str = "cand_stop_tool") -> dict[str, Any]:
        """Issue a STOP_DRUG action for a single medication."""
        return self._execute_action(
            mode="REGIMEN_OPT",
            action_type="STOP_DRUG",
            target_drug=target_drug,
            taper_days=taper_days,
            candidate_id=candidate_id,
            rationale_brief=f"stop_drug:{target_drug}",
        )

    def substitute_drug(
        self,
        target_drug: str,
        replacement_drug: str,
        candidate_id: str = "cand_substitute_tool",
    ) -> dict[str, Any]:
        """Issue a SUBSTITUTE_WITHIN_CLASS action replacing ``target_drug``."""
        return self._execute_action(
            mode="REGIMEN_OPT",
            action_type="SUBSTITUTE_WITHIN_CLASS",
            target_drug=target_drug,
            replacement_drug=replacement_drug,
            candidate_id=candidate_id,
            rationale_brief=f"substitute:{target_drug}->{replacement_drug}",
        )

    def start_taper(self, target_drug: str, taper_days: int = 14, candidate_id: str = "cand_taper_start_tool") -> dict[str, Any]:
        """Issue a TAPER_INITIATE action."""
        return self._execute_action(
            mode="REGIMEN_OPT",
            action_type="TAPER_INITIATE",
            target_drug=target_drug,
            taper_days=taper_days,
            candidate_id=candidate_id,
            rationale_brief=f"taper_start:{target_drug}",
        )

    def continue_taper(self, target_drug: str, taper_days: int = 7, candidate_id: str = "cand_taper_continue_tool") -> dict[str, Any]:
        """Issue a TAPER_CONTINUE action."""
        return self._execute_action(
            mode="REGIMEN_OPT",
            action_type="TAPER_CONTINUE",
            target_drug=target_drug,
            taper_days=taper_days,
            candidate_id=candidate_id,
            rationale_brief=f"taper_continue:{target_drug}",
        )

    def adjust_dose(
        self,
        target_drug: str,
        direction: Literal["increase", "reduce", "hold"],
        candidate_id: str = "cand_adjust_dose_tool",
    ) -> dict[str, Any]:
        """Adjust the dose bucket of ``target_drug`` in an explicit direction.

        ``increase`` -> INCREASE_DOSE_BUCKET/HIGH, ``reduce`` ->
        REDUCE_DOSE_BUCKET/LOW, ``hold`` -> DOSE_HOLD/HOLD.

        Raises:
            ValueError: If ``direction`` is not one of the three values above.
                ``Literal`` is not enforced at runtime, and silently mapping an
                unknown value (e.g. "decrease") to a hold would be wrong.
        """
        try:
            action_type, dose_bucket = self._DOSE_ACTIONS[direction]
        except KeyError:
            raise ValueError(
                f"direction must be one of {sorted(self._DOSE_ACTIONS)}, got {direction!r}"
            ) from None
        return self._execute_action(
            mode="DOSE_OPT",
            action_type=action_type,
            target_drug=target_drug,
            dose_bucket=dose_bucket,
            candidate_id=candidate_id,
            rationale_brief=f"adjust_dose:{direction}:{target_drug}",
        )

    def request_review(
        self,
        review_type: Literal["pharmacist", "specialist"] = "specialist",
        candidate_id: str = "cand_review_tool",
    ) -> dict[str, Any]:
        """Request human review when uncertainty or legality concerns are high.

        Raises:
            ValueError: If ``review_type`` is not ``"pharmacist"`` or ``"specialist"``.
        """
        try:
            action_type = self._REVIEW_ACTIONS[review_type]
        except KeyError:
            raise ValueError(
                f"review_type must be one of {sorted(self._REVIEW_ACTIONS)}, got {review_type!r}"
            ) from None
        return self._execute_action(
            mode="ABSTAIN_REVIEW",
            action_type=action_type,
            candidate_id=candidate_id,
            rationale_brief=f"request_review:{review_type}",
        )

    def finish_case(self, candidate_id: str = "cand_finish_tool") -> dict[str, Any]:
        """Close the episode with a conservative KEEP_REGIMEN action."""
        return self._execute_action(
            mode="REGIMEN_OPT",
            action_type="KEEP_REGIMEN",
            candidate_id=candidate_id,
            rationale_brief="finish_case",
        )

    def close(self) -> None:
        """Close the OpenEnv sync client, if any. Safe to call more than once.

        Errors while closing are logged at DEBUG level, not raised. The
        reference is cleared afterwards, so later ``reset``/``step``/``state``
        calls use the HTTP fallback instead of a closed client.
        """
        client, self._sync_client = self._sync_client, None
        if client is not None:
            self._close_quietly(client)
|
|