"""OpenEnv-compatible wrapper around local env service. The wrapper intentionally exposes meaningful clinician-facing tool methods for LLM policy training instead of a single opaque ``step(action)`` interface. """ from __future__ import annotations from typing import Any, Literal from app.env.client import PolyGuardEnvClient try: from openenv import GenericEnvClient except Exception: # noqa: BLE001 GenericEnvClient = None # type: ignore[assignment] class LocalOpenEnvWrapper: def __init__(self, base_url: str = "http://127.0.0.1:8100") -> None: self.http_client = PolyGuardEnvClient(base_url=base_url) self.base_url = base_url self._sync_client: Any = None if GenericEnvClient is not None: try: self._sync_client = GenericEnvClient(base_url=base_url).sync() self._sync_client.connect() except Exception: # noqa: BLE001 self._sync_client = None def reset(self, **kwargs: Any) -> dict[str, Any]: if self._sync_client is not None: result = self._sync_client.reset(**kwargs) return { "observation": result.observation, "reward": result.reward, "done": result.done, } return self.http_client.reset(**kwargs) def step(self, action: dict[str, Any]) -> dict[str, Any]: if self._sync_client is not None: result = self._sync_client.step(action) return { "observation": result.observation, "reward": result.reward, "done": result.done, } return self.http_client.step(action) def state(self) -> dict[str, Any]: if self._sync_client is not None: return self._sync_client.state() return self.http_client.state() def trace(self) -> list[dict[str, Any]]: return self.http_client.trace() def legal_actions(self) -> list[dict[str, Any]]: return self.http_client.legal_actions() def reward_breakdown(self) -> dict[str, Any]: return self.http_client.reward_breakdown() def uncertainty(self) -> dict[str, Any]: return self.http_client.uncertainty() def inspect_regimen(self) -> dict[str, Any]: """Return a compact clinical snapshot of the active case.""" state = self.state() patient = state.get("patient", {}) risk_summary = state.get("risk_summary", {}) meds = patient.get("medications", []) return { "patient_id": patient.get("patient_id"), "age": patient.get("age"), "comorbidities": patient.get("comorbidities", []), "medication_count": len(meds), "medications": meds, "risk_summary": risk_summary, "burden_score": state.get("burden_score"), "step_count": state.get("step_count"), "max_steps": state.get("max_steps"), } def evaluate_candidate(self, candidate_id: str) -> dict[str, Any]: """Lookup a legal candidate action by candidate id.""" candidates = self.legal_actions() for candidate in candidates: if candidate.get("candidate_id") == candidate_id: return candidate return {"candidate_id": candidate_id, "found": False} def _execute_action( self, mode: str, action_type: str, target_drug: str | None = None, replacement_drug: str | None = None, dose_bucket: str = "NA", taper_days: int | None = None, monitoring_plan: str | None = None, candidate_id: str = "cand_manual", confidence: float = 0.65, rationale_brief: str = "tool_action", ) -> dict[str, Any]: payload = { "mode": mode, "action_type": action_type, "target_drug": target_drug, "replacement_drug": replacement_drug, "dose_bucket": dose_bucket, "taper_days": taper_days, "monitoring_plan": monitoring_plan, "candidate_id": candidate_id, "confidence": confidence, "rationale_brief": rationale_brief, } return self.step(payload) def stop_drug(self, target_drug: str, taper_days: int | None = None, candidate_id: str = "cand_stop_tool") -> dict[str, Any]: """Issue STOP_DRUG action for a single medication.""" return self._execute_action( mode="REGIMEN_OPT", action_type="STOP_DRUG", target_drug=target_drug, taper_days=taper_days, candidate_id=candidate_id, rationale_brief=f"stop_drug:{target_drug}", ) def substitute_drug( self, target_drug: str, replacement_drug: str, candidate_id: str = "cand_substitute_tool", ) -> dict[str, Any]: """Issue SUBSTITUTE_WITHIN_CLASS action.""" return self._execute_action( mode="REGIMEN_OPT", action_type="SUBSTITUTE_WITHIN_CLASS", target_drug=target_drug, replacement_drug=replacement_drug, candidate_id=candidate_id, rationale_brief=f"substitute:{target_drug}->{replacement_drug}", ) def start_taper(self, target_drug: str, taper_days: int = 14, candidate_id: str = "cand_taper_start_tool") -> dict[str, Any]: """Issue TAPER_INITIATE action.""" return self._execute_action( mode="REGIMEN_OPT", action_type="TAPER_INITIATE", target_drug=target_drug, taper_days=taper_days, candidate_id=candidate_id, rationale_brief=f"taper_start:{target_drug}", ) def continue_taper(self, target_drug: str, taper_days: int = 7, candidate_id: str = "cand_taper_continue_tool") -> dict[str, Any]: """Issue TAPER_CONTINUE action.""" return self._execute_action( mode="REGIMEN_OPT", action_type="TAPER_CONTINUE", target_drug=target_drug, taper_days=taper_days, candidate_id=candidate_id, rationale_brief=f"taper_continue:{target_drug}", ) def adjust_dose( self, target_drug: str, direction: Literal["increase", "reduce", "hold"], candidate_id: str = "cand_adjust_dose_tool", ) -> dict[str, Any]: """Adjust dose bucket with an explicit direction.""" if direction == "increase": action_type = "INCREASE_DOSE_BUCKET" dose_bucket = "HIGH" elif direction == "reduce": action_type = "REDUCE_DOSE_BUCKET" dose_bucket = "LOW" else: action_type = "DOSE_HOLD" dose_bucket = "HOLD" return self._execute_action( mode="DOSE_OPT", action_type=action_type, target_drug=target_drug, dose_bucket=dose_bucket, candidate_id=candidate_id, rationale_brief=f"adjust_dose:{direction}:{target_drug}", ) def request_review( self, review_type: Literal["pharmacist", "specialist"] = "specialist", candidate_id: str = "cand_review_tool", ) -> dict[str, Any]: """Request human review when uncertainty or legality concerns are high.""" action_type = "REQUEST_PHARMACIST_REVIEW" if review_type == "pharmacist" else "REQUEST_SPECIALIST_REVIEW" return self._execute_action( mode="ABSTAIN_REVIEW", action_type=action_type, candidate_id=candidate_id, rationale_brief=f"request_review:{review_type}", ) def finish_case(self, candidate_id: str = "cand_finish_tool") -> dict[str, Any]: """Close the episode with a conservative keep action.""" return self._execute_action( mode="REGIMEN_OPT", action_type="KEEP_REGIMEN", candidate_id=candidate_id, rationale_brief="finish_case", ) def close(self) -> None: if self._sync_client is not None: try: self._sync_client.close() except Exception: # noqa: BLE001 pass