| """Simulator rollout helpers with scenario perturbations for evaluation.""" |
|
|
| from __future__ import annotations |
|
|
| import os |
| from typing import Any |
|
|
| from app.agents.orchestrator import Orchestrator |
| from app.common.enums import DoseBucket |
| from app.env.env_core import PolyGuardEnv |
|
|
|
|
def _apply_perturbation(env: PolyGuardEnv, perturbation: str | None) -> None:
    """Mutate ``env.state`` in place to simulate one named scenario perturbation.

    Recognised names and their effect on the current state:

    * ``"missing_labs"``     -- sets the ``egfr``, ``ast`` and ``alt`` labs to ``None``.
    * ``"noisy_dose_info"``  -- for every even-indexed medication, sets the dose
      bucket to ``HIGH`` unless it already is ``HIGH``, in which case it becomes ``LOW``.
    * ``"conflicting_meds"`` -- appends a copy of the first medication.
    * ``"alias_noise"``      -- appends ``"_alias"`` to the first medication's drug name.
    * ``"hidden_duplicate"`` -- appends a copy of the first medication, with its
      ``drug`` field explicitly re-set to the same value.
    * ``"stale_evidence"``   -- appends ``"evidence_stale"`` to ``unresolved_conflicts``.
    * ``"delayed_ade"``      -- sets ``latent_confounders["delayed_ade"]`` to ``0.8``.

    A falsy ``perturbation`` or an unrecognised name leaves the state untouched,
    as do the medication-based perturbations when the medication list is empty.
    """
    if not perturbation:
        return

    current = env.state
    patient = current.patient
    med_list = patient.medications

    if perturbation == "missing_labs":
        labs = patient.labs
        labs.egfr = None
        labs.ast = None
        labs.alt = None
        return

    if perturbation == "noisy_dose_info":
        for position, entry in enumerate(med_list):
            if position % 2:
                # Only even positions are perturbed.
                continue
            if entry.dose_bucket != DoseBucket.HIGH:
                entry.dose_bucket = DoseBucket.HIGH
            else:
                entry.dose_bucket = DoseBucket.LOW
        return

    if perturbation == "stale_evidence":
        current.unresolved_conflicts.append("evidence_stale")
        return

    if perturbation == "delayed_ade":
        patient.latent_confounders["delayed_ade"] = 0.8
        return

    # Everything below operates on the first medication, so there is nothing
    # to do when the list is empty.
    if not med_list:
        return

    first = med_list[0]
    if perturbation == "conflicting_meds":
        med_list.append(first.model_copy())
    elif perturbation == "alias_noise":
        first.drug = f"{first.drug}_alias"
    elif perturbation == "hidden_duplicate":
        med_list.append(first.model_copy(update={"drug": first.drug}))
|
|
|
|
def _dict_or_empty(container: dict[str, Any], key: str) -> dict[str, Any]:
    """Return ``container[key]`` if it is a dict, otherwise a fresh empty dict."""
    value = container.get(key, {})
    return value if isinstance(value, dict) else {}


def run_rollouts(
    episodes: int = 5,
    difficulty: str = "medium",
    sub_environment: str | None = None,
    perturbation: str | None = None,
    seed_offset: int = 900,
    policy_stack: str = "llm+bandit",
) -> list[dict[str, Any]]:
    """Run simulator episodes and return one flat record per orchestrator step.

    For the duration of the call the ``POLYGUARD_POLICY_STACK`` environment
    variable is set to ``policy_stack``. Its previous value (or absence) is
    restored on exit, including when an episode raises.

    Args:
        episodes: Number of episodes to run.
        difficulty: Forwarded to ``env.reset``.
        sub_environment: Forwarded to ``env.reset``.
        perturbation: Name passed to ``_apply_perturbation`` after each reset;
            ``None`` leaves the freshly reset state unchanged.
        seed_offset: Episode ``i`` is reset with seed ``seed_offset + i``.
        policy_stack: Value exported as ``POLYGUARD_POLICY_STACK`` and echoed
            into every row.

    Returns:
        A list of per-step dicts holding the episode index, step count, reward,
        termination and legality flags, reward breakdowns, and the run
        configuration (policy stack, difficulty, sub-environment, perturbation).
    """
    previous_policy = os.getenv("POLYGUARD_POLICY_STACK")
    # Set before the env/orchestrator are built -- presumably the policy stack
    # is selected from this variable downstream; confirm against Orchestrator.
    os.environ["POLYGUARD_POLICY_STACK"] = policy_stack

    try:
        env = PolyGuardEnv()
        orchestrator = Orchestrator(env)
        rows: list[dict[str, Any]] = []

        for i in range(episodes):
            env.reset(seed=seed_offset + i, difficulty=difficulty, sub_environment=sub_environment)
            _apply_perturbation(env, perturbation=perturbation)

            done = False
            while not done:
                out = orchestrator.run_step()
                done = bool(out.get("done"))

                # Nested payloads are optional; anything that is not a dict is
                # treated as absent so the lookups below never fail on shape.
                info = _dict_or_empty(out, "info")
                critic = _dict_or_empty(out, "critic")
                reward_breakdown = _dict_or_empty(info, "reward_breakdown")
                primary_channels = _dict_or_empty(info, "primary_reward_channels")
                final_action = _dict_or_empty(out, "final_action")

                rows.append(
                    {
                        "episode": i,
                        "step": int(env.state.step_count),
                        "reward": float(out.get("reward", 0.0)),
                        "done": done,
                        "legal": bool(critic.get("legal", False)),
                        # NOTE(review): flagged only when there is MORE than one
                        # violation -- confirm that a single violation is
                        # intentionally not counted as severe.
                        "severe_violation": len(critic.get("violations", [])) > 1,
                        "abstain": str(final_action.get("action_type", "")).startswith("REQUEST_"),
                        "termination_reason": info.get("termination_reason"),
                        "step_timeout": bool(info.get("step_timeout")),
                        "failure_reasons": info.get("failure_reasons", []),
                        "invalid_action_count": int(info.get("invalid_action_count", 0)),
                        "reward_breakdown": reward_breakdown,
                        "primary_reward_channels": primary_channels,
                        "policy_stack": policy_stack,
                        "difficulty": difficulty,
                        "sub_environment": sub_environment,
                        "perturbation": perturbation,
                    }
                )

        return rows
    finally:
        # Always undo the override so a failing rollout cannot leak the policy
        # setting into the rest of the process.
        if previous_policy is None:
            os.environ.pop("POLYGUARD_POLICY_STACK", None)
        else:
            os.environ["POLYGUARD_POLICY_STACK"] = previous_policy
|
|