Spaces:
Running
Running
| """Simulator rollout helpers with scenario perturbations for evaluation.""" | |
| from __future__ import annotations | |
| import os | |
| from typing import Any | |
| from app.agents.orchestrator import Orchestrator | |
| from app.common.enums import DoseBucket | |
| from app.env.env_core import PolyGuardEnv | |
def _apply_perturbation(env: PolyGuardEnv, perturbation: str | None) -> None:
    """Mutate ``env.state`` in place according to a named scenario perturbation.

    A falsy ``perturbation`` (``None`` or ``""``) and any unrecognised name
    leave the state untouched. Recognised names:

    * ``"missing_labs"`` -- set the ``egfr``, ``ast`` and ``alt`` labs to ``None``.
    * ``"noisy_dose_info"`` -- for every medication at an even index, set the
      dose bucket to ``HIGH``, or to ``LOW`` if it already was ``HIGH``.
    * ``"conflicting_meds"`` -- append a copy of the first medication.
    * ``"alias_noise"`` -- append ``"_alias"`` to the first medication's drug name.
    * ``"hidden_duplicate"`` -- append a copy of the first medication, passing
      its current drug name as the copy's ``update``.
    * ``"stale_evidence"`` -- append ``"evidence_stale"`` to the unresolved conflicts.
    * ``"delayed_ade"`` -- set the ``"delayed_ade"`` latent confounder to ``0.8``.

    The three medication-based perturbations do nothing when the patient has
    no medications.
    """
    if not perturbation:
        return

    current = env.state
    med_list = current.patient.medications

    if perturbation == "missing_labs":
        labs = current.patient.labs
        labs.egfr = None
        labs.ast = None
        labs.alt = None
        return

    if perturbation == "noisy_dose_info":
        for position, med in enumerate(med_list):
            if position % 2:
                continue  # only even-indexed medications are altered
            if med.dose_bucket != DoseBucket.HIGH:
                med.dose_bucket = DoseBucket.HIGH
            else:
                med.dose_bucket = DoseBucket.LOW
        return

    if perturbation == "stale_evidence":
        current.unresolved_conflicts.append("evidence_stale")
        return

    if perturbation == "delayed_ade":
        current.patient.latent_confounders["delayed_ade"] = 0.8
        return

    # Everything below acts on the first medication, so an empty list is a no-op.
    if not med_list:
        return

    if perturbation == "conflicting_meds":
        med_list.append(med_list[0].model_copy())
    elif perturbation == "alias_noise":
        med_list[0].drug = f"{med_list[0].drug}_alias"
    elif perturbation == "hidden_duplicate":
        med_list.append(med_list[0].model_copy(update={"drug": med_list[0].drug}))
def _dict_field(container: dict[str, Any], key: str) -> dict[str, Any]:
    """Return ``container[key]`` if it is a dict, otherwise an empty dict."""
    value = container.get(key, {})
    return value if isinstance(value, dict) else {}


def run_rollouts(
    episodes: int = 5,
    difficulty: str = "medium",
    sub_environment: str | None = None,
    perturbation: str | None = None,
    seed_offset: int = 900,
    policy_stack: str = "llm+bandit",
) -> list[dict[str, Any]]:
    """Run orchestrated episodes and collect one result row per step.

    For the duration of the call the ``POLYGUARD_POLICY_STACK`` environment
    variable is set to ``policy_stack``. Its previous value (or absence) is
    restored before returning, including when an episode raises.

    Args:
        episodes: Number of episodes to run.
        difficulty: Forwarded to ``env.reset``.
        sub_environment: Forwarded to ``env.reset``.
        perturbation: Name passed to ``_apply_perturbation`` after each reset;
            ``None`` applies no perturbation.
        seed_offset: Episode ``i`` is reset with seed ``seed_offset + i``.
        policy_stack: Value exported as ``POLYGUARD_POLICY_STACK`` and
            recorded in every row.

    Returns:
        A list of row dicts holding the step's reward, termination info,
        critic verdicts and the run configuration.
    """
    policy_env_var = "POLYGUARD_POLICY_STACK"
    previous_policy = os.environ.get(policy_env_var)
    os.environ[policy_env_var] = policy_stack

    rows: list[dict[str, Any]] = []
    try:
        env = PolyGuardEnv()
        orchestrator = Orchestrator(env)
        for i in range(episodes):
            env.reset(seed=seed_offset + i, difficulty=difficulty, sub_environment=sub_environment)
            _apply_perturbation(env, perturbation=perturbation)
            done = False
            while not done:
                out = orchestrator.run_step()
                done = bool(out.get("done"))
                info = _dict_field(out, "info")
                critic = _dict_field(out, "critic")
                final_action = _dict_field(out, "final_action")
                # NOTE(review): the source's indentation was lost; rows are
                # recorded per step here because each row carries per-step
                # ``step``/``done`` fields -- confirm against the original file.
                rows.append(
                    {
                        "episode": i,
                        "step": int(env.state.step_count),
                        "reward": float(out.get("reward", 0.0)),
                        "done": done,
                        "legal": bool(critic.get("legal", False)),
                        # Flagged only when more than one violation is reported.
                        "severe_violation": len(critic.get("violations", [])) > 1,
                        # Action types prefixed "REQUEST_" are counted as abstentions.
                        "abstain": str(final_action.get("action_type", "")).startswith("REQUEST_"),
                        "termination_reason": info.get("termination_reason"),
                        "step_timeout": bool(info.get("step_timeout")),
                        "failure_reasons": info.get("failure_reasons", []),
                        "invalid_action_count": int(info.get("invalid_action_count", 0)),
                        "reward_breakdown": _dict_field(info, "reward_breakdown"),
                        "primary_reward_channels": _dict_field(info, "primary_reward_channels"),
                        "policy_stack": policy_stack,
                        "difficulty": difficulty,
                        "sub_environment": sub_environment,
                        "perturbation": perturbation,
                    }
                )
    finally:
        # Restore the caller's environment even if construction or a step
        # raised, so the override never leaks into the rest of the process.
        if previous_policy is None:
            os.environ.pop(policy_env_var, None)
        else:
            os.environ[policy_env_var] = previous_policy
    return rows