# TheJackBright's picture
# Deploy GitHub root master to Space
# c296d62
"""Core PolyGuard environment implementation."""
from __future__ import annotations
import time
import uuid
import os
from pathlib import Path
from typing import Optional
from app.common.constants import (
DEFAULT_EPISODE_TIMEOUT_SECONDS,
DEFAULT_MAX_STEPS,
DEFAULT_SEED,
DEFAULT_STEP_TIMEOUT_SECONDS,
)
from app.common.enums import Difficulty, SubEnvironment
from app.common.seeding import set_global_seed
from app.common.types import (
CandidateAction,
PolyGuardAction,
PolyGuardObservation,
PolyGuardState,
RewardBreakdown,
StepTrace,
UncertaintyReport,
)
from app.env.anti_cheat import evaluate_anti_cheat
from app.env.curriculum import pick_difficulty, pick_sub_environment
from app.env.reward_router import compute_reward_breakdown
from app.env.scenario_loader import load_or_generate_scenario
from app.env.termination import check_termination_with_timeout
from app.env.transition import apply_transition
from app.env.verifier import verify_action_legality
from app.knowledge.ddi_knowledge import top_risky_pairs
from app.models.policy.candidate_builder import build_candidates
from app.models.policy.uncertainty import estimate_uncertainty
class PolyGuardEnv:
    """Polypharmacy-safety environment exposing a ``reset``/``step`` API.

    Each episode loads (or generates) a patient scenario and picks a difficulty
    and sub-environment. Submitted actions are checked by a legality verifier
    and an anti-cheat gate before being applied. Every step is scored with
    ``compute_reward_breakdown`` and recorded in an in-memory trace.
    """

    # Uncertainty strictly above this value marks abstention as recommended.
    # Shared by the observation's ``abstention_indicators`` and by
    # ``get_uncertainty_report`` so the two views cannot drift apart.
    _ABSTENTION_THRESHOLD: float = 0.65
    # Medication count at which the polypharmacy burden score saturates at 1.0.
    _BURDEN_SATURATION_MED_COUNT: float = 12.0

    def __init__(self, root: Optional[Path] = None) -> None:
        """Create an environment. Call ``reset`` before ``step``.

        Args:
            root: Project root passed to the scenario loader. Defaults to
                ``Path(__file__).resolve().parents[2]``.

        Timeouts are read from ``POLYGUARD_EPISODE_TIMEOUT_SECONDS`` and
        ``POLYGUARD_STEP_TIMEOUT_SECONDS``. When those are unset or blank, the
        package defaults are used.
        """
        self.root = root or Path(__file__).resolve().parents[2]
        self._episode_index = 0
        self._state: Optional[PolyGuardState] = None
        self._trace: list[StepTrace] = []
        self._last_reward: Optional[RewardBreakdown] = None
        self._episode_started_at: float = 0.0
        self._episode_timeout_seconds: float = self._timeout_from_env(
            "POLYGUARD_EPISODE_TIMEOUT_SECONDS", DEFAULT_EPISODE_TIMEOUT_SECONDS
        )
        self._step_timeout_seconds: float = self._timeout_from_env(
            "POLYGUARD_STEP_TIMEOUT_SECONDS", DEFAULT_STEP_TIMEOUT_SECONDS
        )

    @staticmethod
    def _timeout_from_env(name: str, default: float) -> float:
        """Read a timeout in seconds from environment variable *name*.

        An unset or blank variable yields ``float(default)``. A blank value is
        treated like an unset one instead of failing inside ``float('')``.
        A non-numeric value still raises ``ValueError``, so misconfiguration is
        reported rather than silently ignored.
        """
        raw = os.getenv(name)
        if raw is None or not raw.strip():
            return float(default)
        return float(raw)

    @staticmethod
    def _burden_from_medication_count(count: int) -> float:
        """Return the polypharmacy burden ``count / 12`` capped at 1.0."""
        return min(1.0, count / PolyGuardEnv._BURDEN_SATURATION_MED_COUNT)

    @property
    def state(self) -> PolyGuardState:
        """Current episode state.

        Raises:
            RuntimeError: If ``reset`` has not been called yet.
        """
        if self._state is None:
            raise RuntimeError("Environment has not been reset.")
        return self._state

    def reset(
        self,
        seed: Optional[int] = None,
        difficulty: Optional[str] = None,
        sub_environment: Optional[str] = None,
        scenario_id: Optional[str] = None,
        patient_id: Optional[str] = None,
    ) -> PolyGuardObservation:
        """Start a new episode and return its initial observation.

        Args:
            seed: Global RNG seed. ``DEFAULT_SEED`` is used when omitted.
            difficulty: A ``Difficulty`` value. When omitted, the curriculum
                picks one from the episode index.
            sub_environment: A ``SubEnvironment`` value. When omitted, the
                curriculum picks one from the episode index and difficulty.
            scenario_id: Specific scenario to load; forwarded to the loader.
            patient_id: Specific patient to load; forwarded to the loader.

        Returns:
            The step-0 observation, which is also stored as the first trace entry.

        Raises:
            ValueError: If ``difficulty`` or ``sub_environment`` is not a valid
                enum value.
        """
        run_seed = set_global_seed(seed if seed is not None else DEFAULT_SEED)
        diff = Difficulty(difficulty) if difficulty else pick_difficulty(self._episode_index)
        if sub_environment:
            chosen_sub_environment = SubEnvironment(sub_environment)
        else:
            chosen_sub_environment = pick_sub_environment(self._episode_index, diff)

        patient = load_or_generate_scenario(
            root=self.root,
            difficulty=diff,
            scenario_id=scenario_id,
            patient_id=patient_id,
            seed=run_seed,
        )
        scenario_key = scenario_id or patient.patient_id

        # Step budget per sub-environment. A sub-environment missing from this
        # table falls back to a difficulty-based budget, then DEFAULT_MAX_STEPS.
        max_steps = {
            SubEnvironment.DDI: 3,
            SubEnvironment.REGIMEN_RISK: 6,
            SubEnvironment.BANDIT_MINING: 6,
            SubEnvironment.PRECISION_DOSING: 8,
            SubEnvironment.LONGITUDINAL_DEPRESCRIBING: 10,
            SubEnvironment.WEB_SEARCH_MISSING_DATA: 5,
            SubEnvironment.ALTERNATIVE_SUGGESTION: 6,
            SubEnvironment.NEW_DRUG_DECOMPOSITION: 7,
        }.get(chosen_sub_environment, {
            Difficulty.EASY: 3,
            Difficulty.MEDIUM: 6,
            Difficulty.HARD: 10,
        }.get(diff, DEFAULT_MAX_STEPS))

        medication_count = len(patient.medications)
        burden = self._burden_from_medication_count(medication_count)
        risky_pairs = top_risky_pairs([m.drug for m in patient.medications])
        self._state = PolyGuardState(
            episode_id=f"ep_{uuid.uuid4().hex[:8]}",
            seed=run_seed,
            scenario_id=scenario_key,
            difficulty=diff,
            sub_environment=chosen_sub_environment,
            step_count=0,
            max_steps=max_steps,
            patient=patient,
            risk_summary={
                "polypharmacy_count": float(medication_count),
                # Clamped exactly like ``burden_score`` below, so the
                # observation's ``estimated_risk`` (read from here) stays <= 1.0.
                "burden_score": burden,
                "severe_pair_count": float(len(risky_pairs)),
            },
            burden_score=burden,
            precision_dosing_flags=(
                ["dose_sensitive_case"]
                if chosen_sub_environment == SubEnvironment.PRECISION_DOSING
                else []
            ),
            unresolved_conflicts=list(patient.specialist_conflicts),
        )
        self._trace = []
        self._last_reward = None
        self._episode_started_at = time.monotonic()
        self._episode_index += 1

        obs = self._build_observation()
        self._trace.append(
            StepTrace(
                step=0,
                observation_snapshot=obs,
                reward_components={},
            )
        )
        return obs

    def _build_observation(self) -> PolyGuardObservation:
        """Build the agent-facing observation from the current state.

        The warning summary includes:
        - a high-burden flag when ``burden_score >= 0.7``;
        - up to two monitoring gaps;
        - a hint for the web-search sub-environment;
        - a hint for the new-drug decomposition sub-environment.
        """
        state = self.state
        candidates = build_candidates(state)
        uncertainty = estimate_uncertainty(state)
        risky_pairs = top_risky_pairs([m.drug for m in state.patient.medications])

        warning_summary: list[str] = []
        if state.burden_score >= 0.7:
            warning_summary.append("high_polypharmacy_burden")
        if state.patient.monitoring_gaps:
            warning_summary.extend(
                [f"monitoring_gap:{gap}" for gap in state.patient.monitoring_gaps[:2]]
            )
        if state.sub_environment == SubEnvironment.WEB_SEARCH_MISSING_DATA:
            warning_summary.append("missing_data_web_evidence_recommended")
        if state.sub_environment == SubEnvironment.NEW_DRUG_DECOMPOSITION:
            warning_summary.append("new_drug_component_analysis_recommended")

        return PolyGuardObservation(
            patient_summary={
                "patient_id": state.patient.patient_id,
                "age": state.patient.age,
                "sex": state.patient.sex,
                "adherence_estimate": state.patient.adherence_estimate,
                "sub_environment": state.sub_environment.value,
            },
            medication_table=[m.model_dump(mode="json") for m in state.patient.medications],
            comorbidity_summary=state.patient.comorbidities,
            organ_function_summary={
                "egfr": state.patient.labs.egfr,
                "ast": state.patient.labs.ast,
                "alt": state.patient.labs.alt,
            },
            labs_vitals_summary={
                **state.patient.labs.model_dump(mode="json"),
                **state.patient.vitals,
            },
            graph_safety_summary={
                "top_risk_pairs": risky_pairs,
                "polypharmacy_count": len(state.patient.medications),
                "estimated_risk": state.risk_summary.get("burden_score", 0.5),
            },
            burden_score_summary={"burden_score": state.burden_score},
            precision_dosing_flags=state.precision_dosing_flags,
            unresolved_conflicts=state.unresolved_conflicts,
            candidate_action_set=candidates,
            step_budget_remaining=max(0, state.max_steps - state.step_count),
            action_history=state.action_history,
            warning_summary=warning_summary,
            abstention_indicators={
                "uncertainty": uncertainty,
                "recommended": uncertainty > self._ABSTENTION_THRESHOLD,
            },
            sub_environment=state.sub_environment,
            deterministic_contract={
                "seed": state.seed,
                "scenario_id": state.scenario_id,
                "difficulty": state.difficulty.value,
                "sub_environment": state.sub_environment.value,
            },
        )

    @staticmethod
    def _action_from_payload(action: PolyGuardAction | dict) -> PolyGuardAction:
        """Coerce *action* into a ``PolyGuardAction``.

        A ``PolyGuardAction`` is returned unchanged. A dict is first validated
        as a ``PolyGuardAction``. If that fails, it is validated as a
        ``CandidateAction`` and converted. The converted action gets
        confidence ``1 - uncertainty_score`` (floored at 0.45) and a rationale
        built from the first two rationale tags.

        Raises:
            ValueError: If *action* is neither a ``PolyGuardAction`` nor a dict.
            Exception: Whatever ``CandidateAction.model_validate`` raises when
                the dict matches neither schema.
        """
        if isinstance(action, PolyGuardAction):
            return action
        if not isinstance(action, dict):
            raise ValueError("Action must be a PolyGuardAction or dictionary payload.")
        try:
            return PolyGuardAction.model_validate(action)
        except Exception:  # noqa: BLE001 - deliberate fallback to candidate schema
            candidate = CandidateAction.model_validate(action)
            return PolyGuardAction(
                mode=candidate.mode,
                action_type=candidate.action_type,
                target_drug=candidate.target_drug,
                replacement_drug=candidate.replacement_drug,
                dose_bucket=candidate.dose_bucket,
                taper_days=candidate.taper_days,
                monitoring_plan=candidate.monitoring_plan,
                evidence_query=candidate.evidence_query,
                new_drug_name=candidate.new_drug_name,
                candidate_components=candidate.candidate_components,
                candidate_id=candidate.candidate_id,
                confidence=max(0.45, 1.0 - candidate.uncertainty_score),
                rationale_brief=f"Candidate-selected action ({','.join(candidate.rationale_tags[:2]) or 'rule'})",
            )

    def step(self, action: PolyGuardAction | dict) -> tuple[PolyGuardObservation, float, bool, dict]:
        """Apply one action and advance the episode.

        The action is parsed, then checked by ``verify_action_legality`` and
        ``evaluate_anti_cheat``. It is applied via ``apply_transition`` only when
        it is legal and no exploit is detected. Otherwise it is logged in
        ``action_history`` with ``applied=False``.

        In both cases the step counter advances, the reward is computed, and
        termination is decided by ``check_termination_with_timeout``. A step
        that takes at least the per-step timeout also ends the episode, with
        reason ``"step_timeout"``.

        Once the episode is done, further calls change nothing. They return the
        current observation, the last total reward (0.001 if none), ``True``,
        and an ``"already_done"`` info dict.

        Returns:
            ``(observation, total_reward, done, info)``.

        Raises:
            RuntimeError: If ``reset`` has not been called.
            ValueError: If *action* is neither a ``PolyGuardAction`` nor a dict.
        """
        step_started_at = time.monotonic()
        state = self.state
        if state.done:
            observation = self._build_observation()
            reward = self._last_reward.total_reward if self._last_reward else 0.001
            info = {
                "termination_reason": "already_done",
                "reward_breakdown": self._last_reward.model_dump(mode="json") if self._last_reward else {},
                "transition_delta": {
                    "applied": False,
                    "reason": ["episode_already_complete"],
                    "rolled_back": True,
                },
            }
            return observation, reward, True, info

        parsed = self._action_from_payload(action)
        # Snapshot pre-action metrics; the reward compares against them.
        pre_burden = state.burden_score
        pre_risky_pairs = len(top_risky_pairs([m.drug for m in state.patient.medications]))

        safety_report = verify_action_legality(state, parsed)
        legal_candidate_ids = {c.candidate_id for c in build_candidates(state)}
        anti_cheat = evaluate_anti_cheat(state, parsed, legal_candidate_ids=legal_candidate_ids)

        if safety_report.legal and not anti_cheat.exploit_detected:
            # NOTE(review): assumes apply_transition records the applied action in
            # action_history itself and does not advance step_count — confirm.
            transition_delta = apply_transition(state, parsed)
        else:
            transition_delta = {
                "applied": False,
                "reason": safety_report.violations or anti_cheat.reasons or ["blocked"],
                "rolled_back": True,
            }
            state.action_history.append(
                {"step": state.step_count, "action": parsed.model_dump(mode="json"), "applied": False}
            )
        state.step_count += 1

        uncertainty_report = self.get_uncertainty_report()
        reward = compute_reward_breakdown(
            state=state,
            action=parsed,
            safety_report=safety_report,
            anti_cheat_detected=anti_cheat.exploit_detected,
            uncertainty=uncertainty_report.overall_uncertainty,
            pre_burden=pre_burden,
            pre_risky_pairs=pre_risky_pairs,
        )
        self._last_reward = reward
        state.cumulative_reward += reward.total_reward

        elapsed = time.monotonic() - self._episode_started_at
        done, reason = check_termination_with_timeout(
            state=state,
            action=parsed,
            exploit_detected=anti_cheat.exploit_detected,
            elapsed_seconds=elapsed,
            wall_clock_limit_seconds=self._episode_timeout_seconds,
        )
        step_elapsed = time.monotonic() - step_started_at
        step_timeout = step_elapsed >= self._step_timeout_seconds
        if step_timeout and not done:
            done = True
            reason = "step_timeout"
        state.done = done
        # Either the per-step timeout or the episode wall-clock limit fired.
        timed_out = bool(step_timeout or reason == "wall_clock_timeout")

        invalid_action_count = sum(1 for item in state.action_history if item.get("applied") is False)
        transition_failures = transition_delta.get("reason", [])
        if isinstance(transition_failures, str):
            transition_failures = [transition_failures]
        # Merge all failure sources, dropping duplicates while keeping first-seen order.
        failure_reasons = list(
            dict.fromkeys([*safety_report.violations, *anti_cheat.reasons, *transition_failures])
        )

        observation = self._build_observation()
        self._trace.append(
            StepTrace(
                step=state.step_count,
                observation_snapshot=observation,
                selected_action=parsed,
                critic_output={
                    "safety_report": safety_report.model_dump(mode="json"),
                    "anti_cheat": anti_cheat.reasons,
                },
                reward_components=reward.model_dump(mode="json"),
                transition_delta=transition_delta,
                uncertainty_report=uncertainty_report,
                failure_reasons=failure_reasons,
                timeout=timed_out,
            )
        )
        info = {
            "termination_reason": reason,
            "safety_report": safety_report.model_dump(mode="json"),
            "anti_cheat_reasons": anti_cheat.reasons,
            "reward_breakdown": reward.model_dump(mode="json"),
            "primary_reward_channels": {
                "safety_legality": reward.primary_safety_legality,
                "clinical_improvement": reward.primary_clinical_improvement,
                "dosing_quality": reward.primary_dosing_quality,
                "process_integrity": reward.primary_process_integrity,
            },
            "failure_reasons": failure_reasons,
            "transition_delta": transition_delta,
            "step_timeout": step_timeout,
            "episode_elapsed_seconds": round(elapsed, 3),
            "step_elapsed_seconds": round(step_elapsed, 3),
            "invalid_action_count": invalid_action_count,
            "checks": {
                "anti_cheat": bool(anti_cheat.reasons),
                "timeout": timed_out,
                "parser_exploit": "parser_exploit_pattern" in anti_cheat.reasons,
                "legality_gate": bool(safety_report.legal),
            },
        }
        return observation, reward.total_reward, done, info

    def get_state(self) -> dict:
        """Return the current state serialized in pydantic JSON mode."""
        return self.state.model_dump(mode="json")

    def get_reward_breakdown(self) -> dict:
        """Return the last step's reward breakdown, or ``{}`` if no step has been taken."""
        return self._last_reward.model_dump(mode="json") if self._last_reward else {}

    def get_trace(self) -> list[dict]:
        """Return the episode trace (step 0 plus one entry per applied step) as dicts."""
        return [item.model_dump(mode="json") for item in self._trace]

    def get_legal_actions(self) -> list[dict]:
        """Return the candidates whose ``legality_precheck`` passes, as ``PolyGuardAction`` dicts."""
        obs = self._build_observation()
        return [
            self._action_from_payload(candidate.model_dump(mode="json")).model_dump(mode="json")
            for candidate in obs.candidate_action_set
            if candidate.legality_precheck
        ]

    def get_candidate_actions(self) -> list[dict]:
        """Return every candidate action for the current state, unfiltered."""
        obs = self._build_observation()
        return [candidate.model_dump(mode="json") for candidate in obs.candidate_action_set]

    def get_metadata(self) -> dict[str, object]:
        """Return static environment metadata plus the configured timeouts."""
        return {
            "name": "polyguard-openenv",
            "description": (
                "Polypharmacy safety and optimization environment with constrained "
                "actions, reward decomposition, and OpenEnv-compatible APIs."
            ),
            "version": "0.2.0",
            "openenv_mode": "simulation",
            "reward_range": [0.001, 0.999],
            "reward_precision": 3,
            "action_schema": "PolyGuardAction (strict)",
            "observation_schema": "PolyGuardObservation",
            "state_schema": "PolyGuardState",
            "step_timeout_seconds": self._step_timeout_seconds,
            "episode_timeout_seconds": self._episode_timeout_seconds,
        }

    def get_uncertainty_report(self) -> UncertaintyReport:
        """Summarize model uncertainty and missing key labs for the current state.

        Flags ``missing_egfr`` when eGFR is absent and ``missing_liver_enzymes``
        when either AST or ALT is absent.
        """
        state = self.state
        uncertainty = estimate_uncertainty(state)
        missing_flags: list[str] = []
        if state.patient.labs.egfr is None:
            missing_flags.append("missing_egfr")
        if state.patient.labs.ast is None or state.patient.labs.alt is None:
            missing_flags.append("missing_liver_enzymes")
        return UncertaintyReport(
            overall_uncertainty=uncertainty,
            missing_data_flags=missing_flags,
            abstention_recommended=uncertainty > self._ABSTENTION_THRESHOLD,
        )