File size: 4,056 Bytes
21c7db9 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 | """Simulator rollout helpers with scenario perturbations for evaluation."""
from __future__ import annotations

import os
import warnings
from typing import Any

from app.agents.orchestrator import Orchestrator
from app.common.enums import DoseBucket
from app.env.env_core import PolyGuardEnv
def _apply_perturbation(env: PolyGuardEnv, perturbation: str | None) -> None:
if not perturbation:
return
state = env.state
meds = state.patient.medications
if perturbation == "missing_labs":
state.patient.labs.egfr = None
state.patient.labs.ast = None
state.patient.labs.alt = None
elif perturbation == "noisy_dose_info":
for idx, med in enumerate(meds):
if idx % 2 == 0:
med.dose_bucket = DoseBucket.HIGH if med.dose_bucket != DoseBucket.HIGH else DoseBucket.LOW
elif perturbation == "conflicting_meds" and meds:
meds.append(meds[0].model_copy())
elif perturbation == "alias_noise" and meds:
meds[0].drug = f"{meds[0].drug}_alias"
elif perturbation == "hidden_duplicate" and meds:
meds.append(meds[0].model_copy(update={"drug": meds[0].drug}))
elif perturbation == "stale_evidence":
state.unresolved_conflicts.append("evidence_stale")
elif perturbation == "delayed_ade":
state.patient.latent_confounders["delayed_ade"] = 0.8
def _as_dict(value: Any) -> dict[str, Any]:
    """Return *value* unchanged if it is a dict, otherwise a fresh empty dict."""
    return value if isinstance(value, dict) else {}


def run_rollouts(
    episodes: int = 5,
    difficulty: str = "medium",
    sub_environment: str | None = None,
    perturbation: str | None = None,
    seed_offset: int = 900,
    policy_stack: str = "llm+bandit",
) -> list[dict[str, Any]]:
    """Run orchestrated episodes and collect one telemetry row per orchestrator step.

    Args:
        episodes: Number of episodes to run.
        difficulty: Forwarded to ``env.reset``.
        sub_environment: Forwarded to ``env.reset``.
        perturbation: Applied via ``_apply_perturbation`` after each reset and
            recorded verbatim in every row.
        seed_offset: Episode ``i`` is reset with seed ``seed_offset + i``.
        policy_stack: Exported as the ``POLYGUARD_POLICY_STACK`` environment
            variable for the duration of the call (presumably read by the
            env/orchestrator; verify) and recorded in every row.

    Returns:
        One dict per step with the step's reward, termination info, critic
        verdicts and the run configuration.

    Raises:
        Whatever the environment or orchestrator raises. The previous value of
        ``POLYGUARD_POLICY_STACK`` is restored (or the variable removed) even on
        error. Because ``os.environ`` is process-wide, concurrent calls in the
        same process interfere with each other.
    """
    previous_policy = os.getenv("POLYGUARD_POLICY_STACK")
    os.environ["POLYGUARD_POLICY_STACK"] = policy_stack
    rows: list[dict[str, Any]] = []
    try:
        env = PolyGuardEnv()
        orchestrator = Orchestrator(env)
        for i in range(episodes):
            env.reset(seed=seed_offset + i, difficulty=difficulty, sub_environment=sub_environment)
            _apply_perturbation(env, perturbation=perturbation)
            done = False
            while not done:
                out = orchestrator.run_step()
                done = bool(out.get("done"))
                # Tolerate missing or non-dict payloads from the orchestrator.
                info = _as_dict(out.get("info"))
                critic = _as_dict(out.get("critic"))
                final_action = _as_dict(out.get("final_action"))
                rows.append(
                    {
                        "episode": i,
                        "step": int(env.state.step_count),
                        "reward": float(out.get("reward", 0.0)),
                        "done": done,
                        "legal": bool(critic.get("legal", False)),
                        # NOTE(review): flags only steps with MORE than one violation;
                        # confirm whether a single violation should count as severe.
                        "severe_violation": len(critic.get("violations", [])) > 1,
                        # Actions named REQUEST_* are counted as abstentions.
                        "abstain": str(final_action.get("action_type", "")).startswith("REQUEST_"),
                        "termination_reason": info.get("termination_reason"),
                        "step_timeout": bool(info.get("step_timeout")),
                        "failure_reasons": info.get("failure_reasons", []),
                        "invalid_action_count": int(info.get("invalid_action_count", 0)),
                        "reward_breakdown": _as_dict(info.get("reward_breakdown")),
                        "primary_reward_channels": _as_dict(info.get("primary_reward_channels")),
                        "policy_stack": policy_stack,
                        "difficulty": difficulty,
                        "sub_environment": sub_environment,
                        "perturbation": perturbation,
                    }
                )
    finally:
        # Always undo the process-wide override, even if a rollout raised.
        if previous_policy is None:
            os.environ.pop("POLYGUARD_POLICY_STACK", None)
        else:
            os.environ["POLYGUARD_POLICY_STACK"] = previous_policy
    return rows
|