"""Core PolyGuard environment implementation.""" from __future__ import annotations import time import uuid import os from pathlib import Path from typing import Optional from app.common.constants import ( DEFAULT_EPISODE_TIMEOUT_SECONDS, DEFAULT_MAX_STEPS, DEFAULT_SEED, DEFAULT_STEP_TIMEOUT_SECONDS, ) from app.common.enums import Difficulty, SubEnvironment from app.common.seeding import set_global_seed from app.common.types import ( CandidateAction, PolyGuardAction, PolyGuardObservation, PolyGuardState, RewardBreakdown, StepTrace, UncertaintyReport, ) from app.env.anti_cheat import evaluate_anti_cheat from app.env.curriculum import pick_difficulty, pick_sub_environment from app.env.reward_router import compute_reward_breakdown from app.env.scenario_loader import load_or_generate_scenario from app.env.termination import check_termination_with_timeout from app.env.transition import apply_transition from app.env.verifier import verify_action_legality from app.knowledge.ddi_knowledge import top_risky_pairs from app.models.policy.candidate_builder import build_candidates from app.models.policy.uncertainty import estimate_uncertainty class PolyGuardEnv: def __init__(self, root: Optional[Path] = None) -> None: self.root = root or Path(__file__).resolve().parents[2] self._episode_index = 0 self._state: Optional[PolyGuardState] = None self._trace: list[StepTrace] = [] self._last_reward: Optional[RewardBreakdown] = None self._episode_started_at: float = 0.0 self._episode_timeout_seconds: float = float( os.getenv("POLYGUARD_EPISODE_TIMEOUT_SECONDS", str(DEFAULT_EPISODE_TIMEOUT_SECONDS)) ) self._step_timeout_seconds: float = float( os.getenv("POLYGUARD_STEP_TIMEOUT_SECONDS", str(DEFAULT_STEP_TIMEOUT_SECONDS)) ) @property def state(self) -> PolyGuardState: if self._state is None: raise RuntimeError("Environment has not been reset.") return self._state def reset( self, seed: Optional[int] = None, difficulty: Optional[str] = None, sub_environment: Optional[str] = None, scenario_id: Optional[str] = None, patient_id: Optional[str] = None, ) -> PolyGuardObservation: run_seed = set_global_seed(seed if seed is not None else DEFAULT_SEED) diff = Difficulty(difficulty) if difficulty else pick_difficulty(self._episode_index) if sub_environment: chosen_sub_environment = SubEnvironment(sub_environment) else: chosen_sub_environment = pick_sub_environment(self._episode_index, diff) patient = load_or_generate_scenario( root=self.root, difficulty=diff, scenario_id=scenario_id, patient_id=patient_id, seed=run_seed, ) scenario_key = scenario_id or patient.patient_id max_steps = { SubEnvironment.DDI: 3, SubEnvironment.REGIMEN_RISK: 6, SubEnvironment.BANDIT_MINING: 6, SubEnvironment.PRECISION_DOSING: 8, SubEnvironment.LONGITUDINAL_DEPRESCRIBING: 10, SubEnvironment.WEB_SEARCH_MISSING_DATA: 5, SubEnvironment.ALTERNATIVE_SUGGESTION: 6, SubEnvironment.NEW_DRUG_DECOMPOSITION: 7, }.get(chosen_sub_environment, { Difficulty.EASY: 3, Difficulty.MEDIUM: 6, Difficulty.HARD: 10, }.get(diff, DEFAULT_MAX_STEPS)) risky_pairs = top_risky_pairs([m.drug for m in patient.medications]) self._state = PolyGuardState( episode_id=f"ep_{uuid.uuid4().hex[:8]}", seed=run_seed, scenario_id=scenario_key, difficulty=diff, sub_environment=chosen_sub_environment, step_count=0, max_steps=max_steps, patient=patient, risk_summary={ "polypharmacy_count": float(len(patient.medications)), "burden_score": len(patient.medications) / 12.0, "severe_pair_count": float(len(risky_pairs)), }, burden_score=min(1.0, len(patient.medications) / 12.0), precision_dosing_flags=["dose_sensitive_case"] if chosen_sub_environment == SubEnvironment.PRECISION_DOSING else [], unresolved_conflicts=list(patient.specialist_conflicts), ) self._trace = [] self._last_reward = None self._episode_started_at = time.monotonic() self._episode_index += 1 obs = self._build_observation() self._trace.append( StepTrace( step=0, observation_snapshot=obs, reward_components={}, ) ) return obs def _build_observation(self) -> PolyGuardObservation: state = self.state candidates = build_candidates(state) uncertainty = estimate_uncertainty(state) risky_pairs = top_risky_pairs([m.drug for m in state.patient.medications]) warning_summary: list[str] = [] if state.burden_score >= 0.7: warning_summary.append("high_polypharmacy_burden") if state.patient.monitoring_gaps: warning_summary.extend([f"monitoring_gap:{gap}" for gap in state.patient.monitoring_gaps[:2]]) if state.sub_environment == SubEnvironment.WEB_SEARCH_MISSING_DATA: warning_summary.append("missing_data_web_evidence_recommended") if state.sub_environment == SubEnvironment.NEW_DRUG_DECOMPOSITION: warning_summary.append("new_drug_component_analysis_recommended") return PolyGuardObservation( patient_summary={ "patient_id": state.patient.patient_id, "age": state.patient.age, "sex": state.patient.sex, "adherence_estimate": state.patient.adherence_estimate, "sub_environment": state.sub_environment.value, }, medication_table=[m.model_dump(mode="json") for m in state.patient.medications], comorbidity_summary=state.patient.comorbidities, organ_function_summary={ "egfr": state.patient.labs.egfr, "ast": state.patient.labs.ast, "alt": state.patient.labs.alt, }, labs_vitals_summary={**state.patient.labs.model_dump(mode="json"), **state.patient.vitals}, graph_safety_summary={ "top_risk_pairs": risky_pairs, "polypharmacy_count": len(state.patient.medications), "estimated_risk": state.risk_summary.get("burden_score", 0.5), }, burden_score_summary={"burden_score": state.burden_score}, precision_dosing_flags=state.precision_dosing_flags, unresolved_conflicts=state.unresolved_conflicts, candidate_action_set=candidates, step_budget_remaining=max(0, state.max_steps - state.step_count), action_history=state.action_history, warning_summary=warning_summary, abstention_indicators={"uncertainty": uncertainty, "recommended": uncertainty > 0.65}, sub_environment=state.sub_environment, deterministic_contract={ "seed": state.seed, "scenario_id": state.scenario_id, "difficulty": state.difficulty.value, "sub_environment": state.sub_environment.value, }, ) @staticmethod def _action_from_payload(action: PolyGuardAction | dict) -> PolyGuardAction: if isinstance(action, PolyGuardAction): return action if not isinstance(action, dict): raise ValueError("Action must be a PolyGuardAction or dictionary payload.") try: return PolyGuardAction.model_validate(action) except Exception: # noqa: BLE001 candidate = CandidateAction.model_validate(action) return PolyGuardAction( mode=candidate.mode, action_type=candidate.action_type, target_drug=candidate.target_drug, replacement_drug=candidate.replacement_drug, dose_bucket=candidate.dose_bucket, taper_days=candidate.taper_days, monitoring_plan=candidate.monitoring_plan, evidence_query=candidate.evidence_query, new_drug_name=candidate.new_drug_name, candidate_components=candidate.candidate_components, candidate_id=candidate.candidate_id, confidence=max(0.45, 1.0 - candidate.uncertainty_score), rationale_brief=f"Candidate-selected action ({','.join(candidate.rationale_tags[:2]) or 'rule'})", ) def step(self, action: PolyGuardAction | dict) -> tuple[PolyGuardObservation, float, bool, dict]: step_started_at = time.monotonic() state = self.state if state.done: observation = self._build_observation() reward = self._last_reward.total_reward if self._last_reward else 0.001 info = { "termination_reason": "already_done", "reward_breakdown": self._last_reward.model_dump(mode="json") if self._last_reward else {}, "transition_delta": {"applied": False, "reason": ["episode_already_complete"], "rolled_back": True}, } return observation, reward, True, info parsed = self._action_from_payload(action) pre_burden = state.burden_score pre_risky_pairs = len(top_risky_pairs([m.drug for m in state.patient.medications])) safety_report = verify_action_legality(state, parsed) legal_candidate_ids = {c.candidate_id for c in build_candidates(state)} anti_cheat = evaluate_anti_cheat(state, parsed, legal_candidate_ids=legal_candidate_ids) if safety_report.legal and not anti_cheat.exploit_detected: transition_delta = apply_transition(state, parsed) else: transition_delta = { "applied": False, "reason": safety_report.violations or anti_cheat.reasons or ["blocked"], "rolled_back": True, } state.action_history.append({"step": state.step_count, "action": parsed.model_dump(mode="json"), "applied": False}) state.step_count += 1 uncertainty_report = self.get_uncertainty_report() reward = compute_reward_breakdown( state=state, action=parsed, safety_report=safety_report, anti_cheat_detected=anti_cheat.exploit_detected, uncertainty=uncertainty_report.overall_uncertainty, pre_burden=pre_burden, pre_risky_pairs=pre_risky_pairs, ) self._last_reward = reward state.cumulative_reward += reward.total_reward elapsed = time.monotonic() - self._episode_started_at done, reason = check_termination_with_timeout( state=state, action=parsed, exploit_detected=anti_cheat.exploit_detected, elapsed_seconds=elapsed, wall_clock_limit_seconds=self._episode_timeout_seconds, ) step_elapsed = time.monotonic() - step_started_at step_timeout = step_elapsed >= self._step_timeout_seconds if step_timeout and not done: done = True reason = "step_timeout" state.done = done invalid_action_count = sum(1 for item in state.action_history if item.get("applied") is False) transition_failures = transition_delta.get("reason", []) if isinstance(transition_failures, str): transition_failures = [transition_failures] failure_reasons = list(dict.fromkeys([*safety_report.violations, *anti_cheat.reasons, *transition_failures])) observation = self._build_observation() self._trace.append( StepTrace( step=state.step_count, observation_snapshot=observation, selected_action=parsed, critic_output={"safety_report": safety_report.model_dump(mode="json"), "anti_cheat": anti_cheat.reasons}, reward_components=reward.model_dump(mode="json"), transition_delta=transition_delta, uncertainty_report=uncertainty_report, failure_reasons=failure_reasons, timeout=bool(step_timeout or reason == "wall_clock_timeout"), ) ) info = { "termination_reason": reason, "safety_report": safety_report.model_dump(mode="json"), "anti_cheat_reasons": anti_cheat.reasons, "reward_breakdown": reward.model_dump(mode="json"), "primary_reward_channels": { "safety_legality": reward.primary_safety_legality, "clinical_improvement": reward.primary_clinical_improvement, "dosing_quality": reward.primary_dosing_quality, "process_integrity": reward.primary_process_integrity, }, "failure_reasons": failure_reasons, "transition_delta": transition_delta, "step_timeout": step_timeout, "episode_elapsed_seconds": round(elapsed, 3), "step_elapsed_seconds": round(step_elapsed, 3), "invalid_action_count": invalid_action_count, "checks": { "anti_cheat": bool(anti_cheat.reasons), "timeout": bool(step_timeout or reason == "wall_clock_timeout"), "parser_exploit": "parser_exploit_pattern" in anti_cheat.reasons, "legality_gate": bool(safety_report.legal), }, } return observation, reward.total_reward, done, info def get_state(self) -> dict: return self.state.model_dump(mode="json") def get_reward_breakdown(self) -> dict: return self._last_reward.model_dump(mode="json") if self._last_reward else {} def get_trace(self) -> list[dict]: return [item.model_dump(mode="json") for item in self._trace] def get_legal_actions(self) -> list[dict]: obs = self._build_observation() return [ self._action_from_payload(candidate.model_dump(mode="json")).model_dump(mode="json") for candidate in obs.candidate_action_set if candidate.legality_precheck ] def get_candidate_actions(self) -> list[dict]: obs = self._build_observation() return [candidate.model_dump(mode="json") for candidate in obs.candidate_action_set] def get_metadata(self) -> dict[str, object]: return { "name": "polyguard-openenv", "description": ( "Polypharmacy safety and optimization environment with constrained " "actions, reward decomposition, and OpenEnv-compatible APIs." ), "version": "0.2.0", "openenv_mode": "simulation", "reward_range": [0.001, 0.999], "reward_precision": 3, "action_schema": "PolyGuardAction (strict)", "observation_schema": "PolyGuardObservation", "state_schema": "PolyGuardState", "step_timeout_seconds": self._step_timeout_seconds, "episode_timeout_seconds": self._episode_timeout_seconds, } def get_uncertainty_report(self) -> UncertaintyReport: state = self.state uncertainty = estimate_uncertainty(state) missing_flags: list[str] = [] if state.patient.labs.egfr is None: missing_flags.append("missing_egfr") if state.patient.labs.ast is None or state.patient.labs.alt is None: missing_flags.append("missing_liver_enzymes") return UncertaintyReport( overall_uncertainty=uncertainty, missing_data_flags=missing_flags, abstention_recommended=uncertainty > 0.65, )