| """Core PolyGuard environment implementation.""" |
|
|
| from __future__ import annotations |
|
|
| import time |
| import uuid |
| import os |
| from pathlib import Path |
| from typing import Optional |
|
|
| from app.common.constants import ( |
| DEFAULT_EPISODE_TIMEOUT_SECONDS, |
| DEFAULT_MAX_STEPS, |
| DEFAULT_SEED, |
| DEFAULT_STEP_TIMEOUT_SECONDS, |
| ) |
| from app.common.enums import Difficulty, SubEnvironment |
| from app.common.seeding import set_global_seed |
| from app.common.types import ( |
| CandidateAction, |
| PolyGuardAction, |
| PolyGuardObservation, |
| PolyGuardState, |
| RewardBreakdown, |
| StepTrace, |
| UncertaintyReport, |
| ) |
| from app.env.anti_cheat import evaluate_anti_cheat |
| from app.env.curriculum import pick_difficulty, pick_sub_environment |
| from app.env.reward_router import compute_reward_breakdown |
| from app.env.scenario_loader import load_or_generate_scenario |
| from app.env.termination import check_termination_with_timeout |
| from app.env.transition import apply_transition |
| from app.env.verifier import verify_action_legality |
| from app.knowledge.ddi_knowledge import top_risky_pairs |
| from app.models.policy.candidate_builder import build_candidates |
| from app.models.policy.uncertainty import estimate_uncertainty |
|
|
|
|
class PolyGuardEnv:
    """Stateful PolyGuard episode driver exposing an OpenEnv-style API.

    Call ``reset()`` to start an episode, then ``step(action)`` until it
    reports ``done``. Every observation produced (including the initial one
    from ``reset``) is recorded as a ``StepTrace`` and can be retrieved with
    ``get_trace()``.
    """

    # Uncertainty strictly above this value marks abstention as recommended.
    # Shared by the observation builder and ``get_uncertainty_report`` so the
    # two views can never disagree.
    ABSTENTION_THRESHOLD: float = 0.65

    # Divisor that normalises the medication count into a burden score.
    _BURDEN_NORMALIZER: float = 12.0

    def __init__(self, root: Optional[Path] = None) -> None:
        """Create an environment.

        Args:
            root: Project root handed to the scenario loader. Defaults to
                ``parents[2]`` of this file's resolved path.
        """
        self.root = root or Path(__file__).resolve().parents[2]
        # Incremented on every reset; feeds the curriculum pickers.
        self._episode_index = 0
        self._state: Optional[PolyGuardState] = None
        self._trace: list[StepTrace] = []
        self._last_reward: Optional[RewardBreakdown] = None
        self._episode_started_at: float = 0.0
        # Both timeouts can be overridden through environment variables.
        self._episode_timeout_seconds: float = self._env_float(
            "POLYGUARD_EPISODE_TIMEOUT_SECONDS", DEFAULT_EPISODE_TIMEOUT_SECONDS
        )
        self._step_timeout_seconds: float = self._env_float(
            "POLYGUARD_STEP_TIMEOUT_SECONDS", DEFAULT_STEP_TIMEOUT_SECONDS
        )

    @staticmethod
    def _env_float(name: str, default: float) -> float:
        """Read environment variable *name* as a float, falling back to *default*.

        Raises:
            ValueError: if the variable is set but is not a valid float.
        """
        return float(os.getenv(name, str(default)))

    @property
    def state(self) -> PolyGuardState:
        """Current episode state.

        Raises:
            RuntimeError: if ``reset()`` has not been called yet.
        """
        if self._state is None:
            raise RuntimeError("Environment has not been reset.")
        return self._state

    def reset(
        self,
        seed: Optional[int] = None,
        difficulty: Optional[str] = None,
        sub_environment: Optional[str] = None,
        scenario_id: Optional[str] = None,
        patient_id: Optional[str] = None,
    ) -> PolyGuardObservation:
        """Start a new episode and return its initial observation.

        Args:
            seed: RNG seed passed to ``set_global_seed``; ``DEFAULT_SEED`` when omitted.
            difficulty: ``Difficulty`` value; picked by the curriculum when falsy.
            sub_environment: ``SubEnvironment`` value; picked by the curriculum
                (given the chosen difficulty) when falsy.
            scenario_id: Scenario to load; also used as the state's scenario id.
            patient_id: Patient to load, forwarded to the scenario loader.

        Raises:
            ValueError: presumably raised by the enum constructors when
                *difficulty* or *sub_environment* is not a valid value.
        """
        run_seed = set_global_seed(seed if seed is not None else DEFAULT_SEED)
        diff = Difficulty(difficulty) if difficulty else pick_difficulty(self._episode_index)
        if sub_environment:
            chosen_sub_environment = SubEnvironment(sub_environment)
        else:
            chosen_sub_environment = pick_sub_environment(self._episode_index, diff)
        patient = load_or_generate_scenario(
            root=self.root,
            difficulty=diff,
            scenario_id=scenario_id,
            patient_id=patient_id,
            seed=run_seed,
        )
        scenario_key = scenario_id or patient.patient_id

        # Step budget: per sub-environment first, then per difficulty, then
        # the global default.
        sub_environment_steps = {
            SubEnvironment.DDI: 3,
            SubEnvironment.REGIMEN_RISK: 6,
            SubEnvironment.BANDIT_MINING: 6,
            SubEnvironment.PRECISION_DOSING: 8,
            SubEnvironment.LONGITUDINAL_DEPRESCRIBING: 10,
            SubEnvironment.WEB_SEARCH_MISSING_DATA: 5,
            SubEnvironment.ALTERNATIVE_SUGGESTION: 6,
            SubEnvironment.NEW_DRUG_DECOMPOSITION: 7,
        }
        difficulty_steps = {
            Difficulty.EASY: 3,
            Difficulty.MEDIUM: 6,
            Difficulty.HARD: 10,
        }
        max_steps = sub_environment_steps.get(
            chosen_sub_environment, difficulty_steps.get(diff, DEFAULT_MAX_STEPS)
        )

        risky_pairs = top_risky_pairs([m.drug for m in patient.medications])
        medication_count = len(patient.medications)
        raw_burden = medication_count / self._BURDEN_NORMALIZER
        self._state = PolyGuardState(
            episode_id=f"ep_{uuid.uuid4().hex[:8]}",
            seed=run_seed,
            scenario_id=scenario_key,
            difficulty=diff,
            sub_environment=chosen_sub_environment,
            step_count=0,
            max_steps=max_steps,
            patient=patient,
            risk_summary={
                "polypharmacy_count": float(medication_count),
                # NOTE(review): unlike ``burden_score`` below this copy is not
                # clamped to 1.0, and it surfaces as ``estimated_risk`` in
                # observations — confirm that is intended.
                "burden_score": raw_burden,
                "severe_pair_count": float(len(risky_pairs)),
            },
            burden_score=min(1.0, raw_burden),
            precision_dosing_flags=(
                ["dose_sensitive_case"]
                if chosen_sub_environment == SubEnvironment.PRECISION_DOSING
                else []
            ),
            unresolved_conflicts=list(patient.specialist_conflicts),
        )
        self._trace = []
        self._last_reward = None
        self._episode_started_at = time.monotonic()
        self._episode_index += 1
        obs = self._build_observation()
        # Step 0 records the initial observation with no action or reward.
        self._trace.append(
            StepTrace(
                step=0,
                observation_snapshot=obs,
                reward_components={},
            )
        )
        return obs

    def _build_observation(self) -> PolyGuardObservation:
        """Assemble the agent-facing observation from the current state."""
        state = self.state
        # Call order (candidates, uncertainty, risky pairs) is kept as-is in
        # case any of these helpers consume the global RNG.
        candidates = build_candidates(state)
        uncertainty = estimate_uncertainty(state)
        risky_pairs = top_risky_pairs([m.drug for m in state.patient.medications])
        warning_summary: list[str] = []
        if state.burden_score >= 0.7:
            warning_summary.append("high_polypharmacy_burden")
        if state.patient.monitoring_gaps:
            # Surface at most the first two gaps.
            warning_summary.extend([f"monitoring_gap:{gap}" for gap in state.patient.monitoring_gaps[:2]])
        if state.sub_environment == SubEnvironment.WEB_SEARCH_MISSING_DATA:
            warning_summary.append("missing_data_web_evidence_recommended")
        if state.sub_environment == SubEnvironment.NEW_DRUG_DECOMPOSITION:
            warning_summary.append("new_drug_component_analysis_recommended")
        return PolyGuardObservation(
            patient_summary={
                "patient_id": state.patient.patient_id,
                "age": state.patient.age,
                "sex": state.patient.sex,
                "adherence_estimate": state.patient.adherence_estimate,
                "sub_environment": state.sub_environment.value,
            },
            medication_table=[m.model_dump(mode="json") for m in state.patient.medications],
            comorbidity_summary=state.patient.comorbidities,
            organ_function_summary={
                "egfr": state.patient.labs.egfr,
                "ast": state.patient.labs.ast,
                "alt": state.patient.labs.alt,
            },
            labs_vitals_summary={**state.patient.labs.model_dump(mode="json"), **state.patient.vitals},
            graph_safety_summary={
                "top_risk_pairs": risky_pairs,
                "polypharmacy_count": len(state.patient.medications),
                "estimated_risk": state.risk_summary.get("burden_score", 0.5),
            },
            burden_score_summary={"burden_score": state.burden_score},
            precision_dosing_flags=state.precision_dosing_flags,
            unresolved_conflicts=state.unresolved_conflicts,
            candidate_action_set=candidates,
            step_budget_remaining=max(0, state.max_steps - state.step_count),
            action_history=state.action_history,
            warning_summary=warning_summary,
            abstention_indicators={
                "uncertainty": uncertainty,
                "recommended": uncertainty > self.ABSTENTION_THRESHOLD,
            },
            sub_environment=state.sub_environment,
            deterministic_contract={
                "seed": state.seed,
                "scenario_id": state.scenario_id,
                "difficulty": state.difficulty.value,
                "sub_environment": state.sub_environment.value,
            },
        )

    @staticmethod
    def _action_from_payload(action: PolyGuardAction | dict) -> PolyGuardAction:
        """Coerce *action* into a ``PolyGuardAction``.

        Accepts a ready ``PolyGuardAction``, a dict matching the
        ``PolyGuardAction`` schema, or a dict matching ``CandidateAction``
        (converted into an equivalent action).

        Raises:
            ValueError: if *action* is neither a ``PolyGuardAction`` nor a
                dict, or if the dict validates against neither schema
                (pydantic's ``ValidationError`` is a ``ValueError`` subclass).
        """
        if isinstance(action, PolyGuardAction):
            return action
        if not isinstance(action, dict):
            raise ValueError("Action must be a PolyGuardAction or dictionary payload.")
        try:
            return PolyGuardAction.model_validate(action)
        except ValueError:
            # Not a direct action payload; fall back to the candidate schema.
            # If that also fails, its error propagates with the first one as
            # context. Non-validation errors are deliberately not swallowed.
            candidate = CandidateAction.model_validate(action)
        return PolyGuardAction(
            mode=candidate.mode,
            action_type=candidate.action_type,
            target_drug=candidate.target_drug,
            replacement_drug=candidate.replacement_drug,
            dose_bucket=candidate.dose_bucket,
            taper_days=candidate.taper_days,
            monitoring_plan=candidate.monitoring_plan,
            evidence_query=candidate.evidence_query,
            new_drug_name=candidate.new_drug_name,
            candidate_components=candidate.candidate_components,
            candidate_id=candidate.candidate_id,
            # Confidence is derived from the candidate's uncertainty, floored at 0.45.
            confidence=max(0.45, 1.0 - candidate.uncertainty_score),
            rationale_brief=f"Candidate-selected action ({','.join(candidate.rationale_tags[:2]) or 'rule'})",
        )

    @staticmethod
    def _collect_failure_reasons(
        violations: list[str],
        anti_cheat_reasons: list[str],
        transition_delta: dict,
    ) -> list[str]:
        """Merge failure reasons from every gate, de-duplicated in first-seen order.

        The transition's ``reason`` entry (a string or a list) is counted only
        when the transition was not applied. An applied transition is not a
        failure and must not be reported as one.
        """
        transition_failures: list[str] = []
        if not transition_delta.get("applied", False):
            raw = transition_delta.get("reason") or []
            transition_failures = [raw] if isinstance(raw, str) else list(raw)
        return list(dict.fromkeys([*violations, *anti_cheat_reasons, *transition_failures]))

    def _already_done_response(self) -> tuple[PolyGuardObservation, float, bool, dict]:
        """Build the ``step`` result for an episode that has already terminated.

        Nothing is applied. The last computed reward is reported again, or
        0.001 (the lower end of the advertised reward range) if none exists.
        """
        observation = self._build_observation()
        last = self._last_reward
        reward = last.total_reward if last else 0.001
        info = {
            "termination_reason": "already_done",
            "reward_breakdown": last.model_dump(mode="json") if last else {},
            "transition_delta": {"applied": False, "reason": ["episode_already_complete"], "rolled_back": True},
        }
        return observation, reward, True, info

    def step(self, action: PolyGuardAction | dict) -> tuple[PolyGuardObservation, float, bool, dict]:
        """Apply *action* and advance the episode by one step.

        The action is checked by the legality verifier and the anti-cheat
        gate. Only an action passing both is applied; otherwise it is recorded
        as a rolled-back no-op that still consumes a step. A reward is
        computed in either case.

        Args:
            action: A ``PolyGuardAction`` or a dict accepted by
                ``_action_from_payload``.

        Returns:
            ``(observation, reward, done, info)``. ``info`` holds the
            termination reason, safety and anti-cheat reports, the reward
            breakdown, failure reasons, the transition delta and timing data.

        Raises:
            RuntimeError: if ``reset()`` has not been called.
            ValueError: if *action* cannot be parsed.
        """
        step_started_at = time.monotonic()
        state = self.state
        if state.done:
            return self._already_done_response()

        parsed = self._action_from_payload(action)
        # Pre-action metrics, passed to the reward computation for comparison.
        pre_burden = state.burden_score
        pre_risky_pairs = len(top_risky_pairs([m.drug for m in state.patient.medications]))
        safety_report = verify_action_legality(state, parsed)
        legal_candidate_ids = {c.candidate_id for c in build_candidates(state)}
        anti_cheat = evaluate_anti_cheat(state, parsed, legal_candidate_ids=legal_candidate_ids)

        if safety_report.legal and not anti_cheat.exploit_detected:
            transition_delta = apply_transition(state, parsed)
        else:
            transition_delta = {
                "applied": False,
                "reason": safety_report.violations or anti_cheat.reasons or ["blocked"],
                "rolled_back": True,
            }
            # A blocked action is still recorded and consumes a step.
            state.action_history.append(
                {"step": state.step_count, "action": parsed.model_dump(mode="json"), "applied": False}
            )
            state.step_count += 1

        uncertainty_report = self.get_uncertainty_report()
        reward = compute_reward_breakdown(
            state=state,
            action=parsed,
            safety_report=safety_report,
            anti_cheat_detected=anti_cheat.exploit_detected,
            uncertainty=uncertainty_report.overall_uncertainty,
            pre_burden=pre_burden,
            pre_risky_pairs=pre_risky_pairs,
        )
        self._last_reward = reward
        state.cumulative_reward += reward.total_reward

        elapsed = time.monotonic() - self._episode_started_at
        done, reason = check_termination_with_timeout(
            state=state,
            action=parsed,
            exploit_detected=anti_cheat.exploit_detected,
            elapsed_seconds=elapsed,
            wall_clock_limit_seconds=self._episode_timeout_seconds,
        )
        step_elapsed = time.monotonic() - step_started_at
        step_timeout = step_elapsed >= self._step_timeout_seconds
        # A slow step ends the episode unless termination already did.
        if step_timeout and not done:
            done, reason = True, "step_timeout"

        state.done = done
        timed_out = bool(step_timeout or reason == "wall_clock_timeout")
        invalid_action_count = sum(1 for item in state.action_history if item.get("applied") is False)
        failure_reasons = self._collect_failure_reasons(
            safety_report.violations, anti_cheat.reasons, transition_delta
        )
        observation = self._build_observation()
        self._trace.append(
            StepTrace(
                step=state.step_count,
                observation_snapshot=observation,
                selected_action=parsed,
                critic_output={"safety_report": safety_report.model_dump(mode="json"), "anti_cheat": anti_cheat.reasons},
                reward_components=reward.model_dump(mode="json"),
                transition_delta=transition_delta,
                uncertainty_report=uncertainty_report,
                failure_reasons=failure_reasons,
                timeout=timed_out,
            )
        )
        info = {
            "termination_reason": reason,
            "safety_report": safety_report.model_dump(mode="json"),
            "anti_cheat_reasons": anti_cheat.reasons,
            "reward_breakdown": reward.model_dump(mode="json"),
            "primary_reward_channels": {
                "safety_legality": reward.primary_safety_legality,
                "clinical_improvement": reward.primary_clinical_improvement,
                "dosing_quality": reward.primary_dosing_quality,
                "process_integrity": reward.primary_process_integrity,
            },
            "failure_reasons": failure_reasons,
            "transition_delta": transition_delta,
            "step_timeout": step_timeout,
            "episode_elapsed_seconds": round(elapsed, 3),
            "step_elapsed_seconds": round(step_elapsed, 3),
            "invalid_action_count": invalid_action_count,
            "checks": {
                "anti_cheat": bool(anti_cheat.reasons),
                "timeout": timed_out,
                "parser_exploit": "parser_exploit_pattern" in anti_cheat.reasons,
                "legality_gate": bool(safety_report.legal),
            },
        }
        return observation, reward.total_reward, done, info

    def get_state(self) -> dict:
        """Return the current episode state as a JSON-compatible dict."""
        return self.state.model_dump(mode="json")

    def get_reward_breakdown(self) -> dict:
        """Return the last reward breakdown, or ``{}`` before the first step."""
        return self._last_reward.model_dump(mode="json") if self._last_reward else {}

    def get_trace(self) -> list[dict]:
        """Return the episode trace (initial observation plus one entry per step)."""
        return [item.model_dump(mode="json") for item in self._trace]

    def get_legal_actions(self) -> list[dict]:
        """Return candidates passing the legality precheck, as ``PolyGuardAction`` dicts."""
        obs = self._build_observation()
        return [
            self._action_from_payload(candidate.model_dump(mode="json")).model_dump(mode="json")
            for candidate in obs.candidate_action_set
            if candidate.legality_precheck
        ]

    def get_candidate_actions(self) -> list[dict]:
        """Return every candidate action, legal or not, as JSON-compatible dicts."""
        obs = self._build_observation()
        return [candidate.model_dump(mode="json") for candidate in obs.candidate_action_set]

    def get_metadata(self) -> dict[str, object]:
        """Return static environment metadata plus the configured timeouts."""
        return {
            "name": "polyguard-openenv",
            "description": (
                "Polypharmacy safety and optimization environment with constrained "
                "actions, reward decomposition, and OpenEnv-compatible APIs."
            ),
            "version": "0.2.0",
            "openenv_mode": "simulation",
            "reward_range": [0.001, 0.999],
            "reward_precision": 3,
            "action_schema": "PolyGuardAction (strict)",
            "observation_schema": "PolyGuardObservation",
            "state_schema": "PolyGuardState",
            "step_timeout_seconds": self._step_timeout_seconds,
            "episode_timeout_seconds": self._episode_timeout_seconds,
        }

    def get_uncertainty_report(self) -> UncertaintyReport:
        """Summarise current uncertainty and flag missing eGFR / liver-enzyme labs."""
        state = self.state
        uncertainty = estimate_uncertainty(state)
        labs = state.patient.labs
        missing_flags: list[str] = []
        if labs.egfr is None:
            missing_flags.append("missing_egfr")
        if labs.ast is None or labs.alt is None:
            missing_flags.append("missing_liver_enzymes")
        return UncertaintyReport(
            overall_uncertainty=uncertainty,
            missing_data_flags=missing_flags,
            abstention_recommended=uncertainty > self.ABSTENTION_THRESHOLD,
        )
|
|