Spaces:
Running
Running
| """Core PolyGuard environment implementation.""" | |
| from __future__ import annotations | |
| import time | |
| import uuid | |
| import os | |
| from pathlib import Path | |
| from typing import Optional | |
| from app.common.constants import ( | |
| DEFAULT_EPISODE_TIMEOUT_SECONDS, | |
| DEFAULT_MAX_STEPS, | |
| DEFAULT_SEED, | |
| DEFAULT_STEP_TIMEOUT_SECONDS, | |
| ) | |
| from app.common.enums import Difficulty, SubEnvironment | |
| from app.common.seeding import set_global_seed | |
| from app.common.types import ( | |
| CandidateAction, | |
| PolyGuardAction, | |
| PolyGuardObservation, | |
| PolyGuardState, | |
| RewardBreakdown, | |
| StepTrace, | |
| UncertaintyReport, | |
| ) | |
| from app.env.anti_cheat import evaluate_anti_cheat | |
| from app.env.curriculum import pick_difficulty, pick_sub_environment | |
| from app.env.reward_router import compute_reward_breakdown | |
| from app.env.scenario_loader import load_or_generate_scenario | |
| from app.env.termination import check_termination_with_timeout | |
| from app.env.transition import apply_transition | |
| from app.env.verifier import verify_action_legality | |
| from app.knowledge.ddi_knowledge import top_risky_pairs | |
| from app.models.policy.candidate_builder import build_candidates | |
| from app.models.policy.uncertainty import estimate_uncertainty | |
class PolyGuardEnv:
    """Episode-based PolyGuard environment.

    An episode is started with :meth:`reset`, which loads or generates a
    patient scenario and returns the first observation. Each call to
    :meth:`step` verifies the submitted action, applies it when it is legal
    and not flagged by the anti-cheat check, scores it, and reports whether
    the episode has terminated.
    """

    # Uncertainty strictly above this value marks abstention as recommended.
    _ABSTENTION_THRESHOLD = 0.65

    def __init__(self, root: Optional[Path] = None) -> None:
        """Create an environment that has not been reset yet.

        Args:
            root: Directory handed to the scenario loader. When falsy,
                ``Path(__file__).resolve().parents[2]`` is used.

        Raises:
            ValueError: If ``POLYGUARD_EPISODE_TIMEOUT_SECONDS`` or
                ``POLYGUARD_STEP_TIMEOUT_SECONDS`` is set to a value that
                ``float()`` cannot parse.
        """
        self.root = root or Path(__file__).resolve().parents[2]
        self._episode_index = 0
        self._state: Optional[PolyGuardState] = None
        self._trace: list[StepTrace] = []
        self._last_reward: Optional[RewardBreakdown] = None
        self._episode_started_at: float = 0.0
        # Both timeouts can be overridden through environment variables.
        self._episode_timeout_seconds: float = float(
            os.getenv("POLYGUARD_EPISODE_TIMEOUT_SECONDS", str(DEFAULT_EPISODE_TIMEOUT_SECONDS))
        )
        self._step_timeout_seconds: float = float(
            os.getenv("POLYGUARD_STEP_TIMEOUT_SECONDS", str(DEFAULT_STEP_TIMEOUT_SECONDS))
        )

    # Exposed as a property because every caller in this class reads
    # ``self.state`` as an attribute (``self.state.patient``, ...).
    @property
    def state(self) -> PolyGuardState:
        """Return the current episode state.

        Raises:
            RuntimeError: If :meth:`reset` has not been called yet.
        """
        if self._state is None:
            raise RuntimeError("Environment has not been reset.")
        return self._state

    def reset(
        self,
        seed: Optional[int] = None,
        difficulty: Optional[str] = None,
        sub_environment: Optional[str] = None,
        scenario_id: Optional[str] = None,
        patient_id: Optional[str] = None,
    ) -> PolyGuardObservation:
        """Start a new episode and return its initial observation.

        Args:
            seed: Seed passed to ``set_global_seed``; ``DEFAULT_SEED`` is used
                when ``None``.
            difficulty: Difficulty value; when falsy it is chosen by
                ``pick_difficulty`` from the running episode index.
            sub_environment: Sub-environment value; when falsy it is chosen
                by ``pick_sub_environment``.
            scenario_id: Optional scenario identifier forwarded to the
                scenario loader. Also used as the state's scenario key; the
                patient id is used when it is falsy.
            patient_id: Optional patient identifier forwarded to the loader.

        Returns:
            The observation for step 0, which is also recorded as the first
            trace entry.
        """
        run_seed = set_global_seed(seed if seed is not None else DEFAULT_SEED)
        diff = Difficulty(difficulty) if difficulty else pick_difficulty(self._episode_index)
        if sub_environment:
            chosen_sub_environment = SubEnvironment(sub_environment)
        else:
            chosen_sub_environment = pick_sub_environment(self._episode_index, diff)

        patient = load_or_generate_scenario(
            root=self.root,
            difficulty=diff,
            scenario_id=scenario_id,
            patient_id=patient_id,
            seed=run_seed,
        )
        scenario_key = scenario_id or patient.patient_id

        # Step budget: the sub-environment table wins; otherwise fall back to
        # the difficulty table and finally to DEFAULT_MAX_STEPS.
        steps_by_sub_environment = {
            SubEnvironment.DDI: 3,
            SubEnvironment.REGIMEN_RISK: 6,
            SubEnvironment.BANDIT_MINING: 6,
            SubEnvironment.PRECISION_DOSING: 8,
            SubEnvironment.LONGITUDINAL_DEPRESCRIBING: 10,
            SubEnvironment.WEB_SEARCH_MISSING_DATA: 5,
            SubEnvironment.ALTERNATIVE_SUGGESTION: 6,
            SubEnvironment.NEW_DRUG_DECOMPOSITION: 7,
        }
        steps_by_difficulty = {
            Difficulty.EASY: 3,
            Difficulty.MEDIUM: 6,
            Difficulty.HARD: 10,
        }
        fallback_steps = steps_by_difficulty.get(diff, DEFAULT_MAX_STEPS)
        max_steps = steps_by_sub_environment.get(chosen_sub_environment, fallback_steps)

        risky_pairs = top_risky_pairs([m.drug for m in patient.medications])
        medication_count = len(patient.medications)
        if chosen_sub_environment == SubEnvironment.PRECISION_DOSING:
            dosing_flags = ["dose_sensitive_case"]
        else:
            dosing_flags = []

        self._state = PolyGuardState(
            episode_id=f"ep_{uuid.uuid4().hex[:8]}",
            seed=run_seed,
            scenario_id=scenario_key,
            difficulty=diff,
            sub_environment=chosen_sub_environment,
            step_count=0,
            max_steps=max_steps,
            patient=patient,
            risk_summary={
                "polypharmacy_count": float(medication_count),
                # NOTE(review): unlike ``burden_score`` below, this value is
                # not clamped to 1.0 -- confirm whether that is intended.
                "burden_score": medication_count / 12.0,
                "severe_pair_count": float(len(risky_pairs)),
            },
            burden_score=min(1.0, medication_count / 12.0),
            precision_dosing_flags=dosing_flags,
            unresolved_conflicts=list(patient.specialist_conflicts),
        )
        self._trace = []
        self._last_reward = None
        self._episode_started_at = time.monotonic()
        self._episode_index += 1

        obs = self._build_observation()
        self._trace.append(
            StepTrace(
                step=0,
                observation_snapshot=obs,
                reward_components={},
            )
        )
        return obs

    def _build_observation(self) -> PolyGuardObservation:
        """Assemble the observation for the current state.

        Raises:
            RuntimeError: If the environment has not been reset.
        """
        state = self.state
        candidates = build_candidates(state)
        uncertainty = estimate_uncertainty(state)
        risky_pairs = top_risky_pairs([m.drug for m in state.patient.medications])

        warning_summary: list[str] = []
        if state.burden_score >= 0.7:
            warning_summary.append("high_polypharmacy_burden")
        if state.patient.monitoring_gaps:
            # Only the first two monitoring gaps are surfaced as warnings.
            warning_summary.extend(
                [f"monitoring_gap:{gap}" for gap in state.patient.monitoring_gaps[:2]]
            )
        if state.sub_environment == SubEnvironment.WEB_SEARCH_MISSING_DATA:
            warning_summary.append("missing_data_web_evidence_recommended")
        if state.sub_environment == SubEnvironment.NEW_DRUG_DECOMPOSITION:
            warning_summary.append("new_drug_component_analysis_recommended")

        return PolyGuardObservation(
            patient_summary={
                "patient_id": state.patient.patient_id,
                "age": state.patient.age,
                "sex": state.patient.sex,
                "adherence_estimate": state.patient.adherence_estimate,
                "sub_environment": state.sub_environment.value,
            },
            medication_table=[m.model_dump(mode="json") for m in state.patient.medications],
            comorbidity_summary=state.patient.comorbidities,
            organ_function_summary={
                "egfr": state.patient.labs.egfr,
                "ast": state.patient.labs.ast,
                "alt": state.patient.labs.alt,
            },
            # Vitals are merged last, so they win on any key collision.
            labs_vitals_summary={**state.patient.labs.model_dump(mode="json"), **state.patient.vitals},
            graph_safety_summary={
                "top_risk_pairs": risky_pairs,
                "polypharmacy_count": len(state.patient.medications),
                "estimated_risk": state.risk_summary.get("burden_score", 0.5),
            },
            burden_score_summary={"burden_score": state.burden_score},
            precision_dosing_flags=state.precision_dosing_flags,
            unresolved_conflicts=state.unresolved_conflicts,
            candidate_action_set=candidates,
            step_budget_remaining=max(0, state.max_steps - state.step_count),
            action_history=state.action_history,
            warning_summary=warning_summary,
            abstention_indicators={
                "uncertainty": uncertainty,
                "recommended": uncertainty > self._ABSTENTION_THRESHOLD,
            },
            sub_environment=state.sub_environment,
            deterministic_contract={
                "seed": state.seed,
                "scenario_id": state.scenario_id,
                "difficulty": state.difficulty.value,
                "sub_environment": state.sub_environment.value,
            },
        )

    # Declared static because it takes no ``self`` yet is invoked as
    # ``self._action_from_payload(...)``.
    @staticmethod
    def _action_from_payload(action: PolyGuardAction | dict) -> PolyGuardAction:
        """Coerce an incoming payload into a ``PolyGuardAction``.

        A ``PolyGuardAction`` is returned unchanged. A dictionary is first
        validated as a ``PolyGuardAction``; if that fails it is validated as
        a ``CandidateAction`` and converted field by field.

        Raises:
            ValueError: If ``action`` is neither a ``PolyGuardAction`` nor a
                dictionary. Errors from the ``CandidateAction`` fallback
                validation propagate to the caller.
        """
        if isinstance(action, PolyGuardAction):
            return action
        if not isinstance(action, dict):
            raise ValueError("Action must be a PolyGuardAction or dictionary payload.")
        try:
            return PolyGuardAction.model_validate(action)
        except Exception:  # noqa: BLE001
            # Deliberately broad: any failure falls back to interpreting the
            # payload as a candidate action.
            candidate = CandidateAction.model_validate(action)
            tag_label = ",".join(candidate.rationale_tags[:2]) or "rule"
            return PolyGuardAction(
                mode=candidate.mode,
                action_type=candidate.action_type,
                target_drug=candidate.target_drug,
                replacement_drug=candidate.replacement_drug,
                dose_bucket=candidate.dose_bucket,
                taper_days=candidate.taper_days,
                monitoring_plan=candidate.monitoring_plan,
                evidence_query=candidate.evidence_query,
                new_drug_name=candidate.new_drug_name,
                candidate_components=candidate.candidate_components,
                candidate_id=candidate.candidate_id,
                # Confidence is derived from the candidate's uncertainty,
                # with a floor of 0.45.
                confidence=max(0.45, 1.0 - candidate.uncertainty_score),
                rationale_brief=f"Candidate-selected action ({tag_label})",
            )

    def step(self, action: PolyGuardAction | dict) -> tuple[PolyGuardObservation, float, bool, dict]:
        """Apply one action and advance the episode by a step.

        Args:
            action: A ``PolyGuardAction`` or a dictionary payload accepted by
                :meth:`_action_from_payload`.

        Returns:
            ``(observation, total_reward, done, info)``, where ``info`` holds
            the termination reason, safety report, reward breakdown, timing
            figures, and check flags.

        Raises:
            RuntimeError: If the environment has not been reset.
            ValueError: If ``action`` is of an unsupported type.
        """
        step_started_at = time.monotonic()
        state = self.state
        parsed = self._action_from_payload(action)

        # Snapshot pre-transition risk figures for the reward computation.
        pre_burden = state.burden_score
        pre_risky_pairs = len(top_risky_pairs([m.drug for m in state.patient.medications]))

        safety_report = verify_action_legality(state, parsed)
        legal_candidate_ids = {c.candidate_id for c in build_candidates(state)}
        anti_cheat = evaluate_anti_cheat(state, parsed, legal_candidate_ids=legal_candidate_ids)

        if safety_report.legal and not anti_cheat.exploit_detected:
            transition_delta = apply_transition(state, parsed)
        else:
            # Blocked actions leave the patient state untouched but are still
            # recorded in the history as not applied.
            transition_delta = {
                "applied": False,
                "reason": safety_report.violations or anti_cheat.reasons or ["blocked"],
                "rolled_back": True,
            }
            state.action_history.append(
                {"step": state.step_count, "action": parsed.model_dump(mode="json"), "applied": False}
            )
        state.step_count += 1

        uncertainty_report = self.get_uncertainty_report()
        reward = compute_reward_breakdown(
            state=state,
            action=parsed,
            safety_report=safety_report,
            anti_cheat_detected=anti_cheat.exploit_detected,
            uncertainty=uncertainty_report.overall_uncertainty,
            pre_burden=pre_burden,
            pre_risky_pairs=pre_risky_pairs,
        )
        self._last_reward = reward
        state.cumulative_reward += reward.total_reward

        elapsed = time.monotonic() - self._episode_started_at
        done, reason = check_termination_with_timeout(
            state=state,
            action=parsed,
            exploit_detected=anti_cheat.exploit_detected,
            elapsed_seconds=elapsed,
            wall_clock_limit_seconds=self._episode_timeout_seconds,
        )

        # A step that overran its own budget ends the episode, unless some
        # other termination reason already applies.
        step_elapsed = time.monotonic() - step_started_at
        step_timeout = step_elapsed >= self._step_timeout_seconds
        if step_timeout and not done:
            done = True
            reason = "step_timeout"
        state.done = done
        timed_out = bool(step_timeout or reason == "wall_clock_timeout")

        invalid_action_count = sum(1 for item in state.action_history if item.get("applied") is False)
        transition_failures = transition_delta.get("reason", [])
        if isinstance(transition_failures, str):
            transition_failures = [transition_failures]
        # dict.fromkeys de-duplicates while keeping first-seen order.
        failure_reasons = list(
            dict.fromkeys([*safety_report.violations, *anti_cheat.reasons, *transition_failures])
        )

        observation = self._build_observation()
        self._trace.append(
            StepTrace(
                step=state.step_count,
                observation_snapshot=observation,
                selected_action=parsed,
                critic_output={
                    "safety_report": safety_report.model_dump(mode="json"),
                    "anti_cheat": anti_cheat.reasons,
                },
                reward_components=reward.model_dump(mode="json"),
                transition_delta=transition_delta,
                uncertainty_report=uncertainty_report,
                failure_reasons=failure_reasons,
                timeout=timed_out,
            )
        )

        info = {
            "termination_reason": reason,
            "safety_report": safety_report.model_dump(mode="json"),
            "anti_cheat_reasons": anti_cheat.reasons,
            "reward_breakdown": reward.model_dump(mode="json"),
            "primary_reward_channels": {
                "safety_legality": reward.primary_safety_legality,
                "clinical_improvement": reward.primary_clinical_improvement,
                "dosing_quality": reward.primary_dosing_quality,
                "process_integrity": reward.primary_process_integrity,
            },
            "failure_reasons": failure_reasons,
            "transition_delta": transition_delta,
            "step_timeout": step_timeout,
            "episode_elapsed_seconds": round(elapsed, 3),
            "step_elapsed_seconds": round(step_elapsed, 3),
            "invalid_action_count": invalid_action_count,
            "checks": {
                "anti_cheat": bool(anti_cheat.reasons),
                "timeout": timed_out,
                "parser_exploit": "parser_exploit_pattern" in anti_cheat.reasons,
                "legality_gate": bool(safety_report.legal),
            },
        }
        return observation, reward.total_reward, done, info

    def get_state(self) -> dict:
        """Return the current state serialized in JSON mode."""
        return self.state.model_dump(mode="json")

    def get_reward_breakdown(self) -> dict:
        """Return the last step's reward breakdown, or ``{}`` if there is none."""
        return self._last_reward.model_dump(mode="json") if self._last_reward else {}

    def get_trace(self) -> list[dict]:
        """Return every recorded trace entry of the episode, serialized."""
        return [item.model_dump(mode="json") for item in self._trace]

    def get_legal_actions(self) -> list[dict]:
        """Return serialized actions for candidates that pass the legality precheck."""
        obs = self._build_observation()
        return [
            self._action_from_payload(candidate.model_dump(mode="json")).model_dump(mode="json")
            for candidate in obs.candidate_action_set
            if candidate.legality_precheck
        ]

    def get_candidate_actions(self) -> list[dict]:
        """Return all candidate actions of the current observation, serialized."""
        obs = self._build_observation()
        return [candidate.model_dump(mode="json") for candidate in obs.candidate_action_set]

    def get_metadata(self) -> dict[str, object]:
        """Return static environment metadata plus the configured timeouts."""
        return {
            "name": "polyguard-openenv",
            "description": (
                "Polypharmacy safety and optimization environment with constrained "
                "actions, reward decomposition, and OpenEnv-compatible APIs."
            ),
            "version": "0.2.0",
            "openenv_mode": "simulation",
            "reward_range": [0.001, 0.999],
            "reward_precision": 3,
            "action_schema": "PolyGuardAction (strict)",
            "observation_schema": "PolyGuardObservation",
            "state_schema": "PolyGuardState",
            "step_timeout_seconds": self._step_timeout_seconds,
            "episode_timeout_seconds": self._episode_timeout_seconds,
        }

    def get_uncertainty_report(self) -> UncertaintyReport:
        """Build an uncertainty report for the current state.

        Missing eGFR and missing liver enzymes (AST or ALT) are listed as
        missing-data flags.

        Raises:
            RuntimeError: If the environment has not been reset.
        """
        state = self.state
        uncertainty = estimate_uncertainty(state)
        labs = state.patient.labs

        missing_flags: list[str] = []
        if labs.egfr is None:
            missing_flags.append("missing_egfr")
        if labs.ast is None or labs.alt is None:
            missing_flags.append("missing_liver_enzymes")

        return UncertaintyReport(
            overall_uncertainty=uncertainty,
            missing_data_flags=missing_flags,
            abstention_recommended=uncertainty > self._ABSTENTION_THRESHOLD,
        )