"""Simulator rollout helpers with scenario perturbations for evaluation.""" from __future__ import annotations import os from typing import Any from app.agents.orchestrator import Orchestrator from app.common.enums import DoseBucket from app.env.env_core import PolyGuardEnv def _apply_perturbation(env: PolyGuardEnv, perturbation: str | None) -> None: if not perturbation: return state = env.state meds = state.patient.medications if perturbation == "missing_labs": state.patient.labs.egfr = None state.patient.labs.ast = None state.patient.labs.alt = None elif perturbation == "noisy_dose_info": for idx, med in enumerate(meds): if idx % 2 == 0: med.dose_bucket = DoseBucket.HIGH if med.dose_bucket != DoseBucket.HIGH else DoseBucket.LOW elif perturbation == "conflicting_meds" and meds: meds.append(meds[0].model_copy()) elif perturbation == "alias_noise" and meds: meds[0].drug = f"{meds[0].drug}_alias" elif perturbation == "hidden_duplicate" and meds: meds.append(meds[0].model_copy(update={"drug": meds[0].drug})) elif perturbation == "stale_evidence": state.unresolved_conflicts.append("evidence_stale") elif perturbation == "delayed_ade": state.patient.latent_confounders["delayed_ade"] = 0.8 def run_rollouts( episodes: int = 5, difficulty: str = "medium", sub_environment: str | None = None, perturbation: str | None = None, seed_offset: int = 900, policy_stack: str = "llm+bandit", ) -> list[dict[str, Any]]: previous_policy = os.getenv("POLYGUARD_POLICY_STACK") os.environ["POLYGUARD_POLICY_STACK"] = policy_stack env = PolyGuardEnv() orchestrator = Orchestrator(env) rows: list[dict[str, Any]] = [] for i in range(episodes): env.reset(seed=seed_offset + i, difficulty=difficulty, sub_environment=sub_environment) _apply_perturbation(env, perturbation=perturbation) done = False while not done: out = orchestrator.run_step() done = bool(out.get("done")) info = out.get("info", {}) if isinstance(out.get("info", {}), dict) else {} critic = out.get("critic", {}) if isinstance(out.get("critic", {}), dict) else {} reward_breakdown = info.get("reward_breakdown", {}) if isinstance(info.get("reward_breakdown", {}), dict) else {} primary_channels = ( info.get("primary_reward_channels", {}) if isinstance(info.get("primary_reward_channels", {}), dict) else {} ) final_action = out.get("final_action", {}) if isinstance(out.get("final_action", {}), dict) else {} rows.append( { "episode": i, "step": int(env.state.step_count), "reward": float(out.get("reward", 0.0)), "done": done, "legal": bool(critic.get("legal", False)), "severe_violation": len(critic.get("violations", [])) > 1, "abstain": str(final_action.get("action_type", "")).startswith("REQUEST_"), "termination_reason": info.get("termination_reason"), "step_timeout": bool(info.get("step_timeout")), "failure_reasons": info.get("failure_reasons", []), "invalid_action_count": int(info.get("invalid_action_count", 0)), "reward_breakdown": reward_breakdown, "primary_reward_channels": primary_channels, "policy_stack": policy_stack, "difficulty": difficulty, "sub_environment": sub_environment, "perturbation": perturbation, } ) if previous_policy is None: os.environ.pop("POLYGUARD_POLICY_STACK", None) else: os.environ["POLYGUARD_POLICY_STACK"] = previous_policy return rows