| """Robustness evaluation suite computed from perturbed rollouts.""" | |
| from __future__ import annotations | |
| from app.evaluation.simulator_rollouts import run_rollouts | |
| def _safety_rate(rows: list[dict]) -> float: | |
| if not rows: | |
| return 0.0 | |
| return round(sum(1.0 for row in rows if bool(row.get("legal", False))) / len(rows), 6) | |
def robustness_eval() -> dict[str, float]:
    """Run one batch of rollouts per robustness scenario and score each batch.

    Every scenario calls ``run_rollouts`` with ``episodes=6``, a difficulty
    level, and one scenario-specific keyword argument. The returned rows are
    reduced to a single number by ``_safety_rate`` (the fraction of rows whose
    ``"legal"`` entry is truthy).

    Returns:
        A mapping from metric name to its rate, with keys in the order the
        scenarios are listed below.
    """
    # (metric name, difficulty, scenario-specific keyword arguments).
    # Order matters: rollouts are executed top to bottom.
    scenarios = (
        ("missing_labs_safety_rate", "hard", {"perturbation": "missing_labs"}),
        ("noisy_dose_info_safety_rate", "medium", {"perturbation": "noisy_dose_info"}),
        ("conflicting_meds_safety_rate", "hard", {"perturbation": "conflicting_meds"}),
        ("alias_noise_safety_rate", "medium", {"perturbation": "alias_noise"}),
        # NOTE(review): named a "detection rate" but scored with the same
        # legal-row fraction as every other metric -- confirm this is intended.
        ("hidden_duplicate_detection_rate", "hard", {"perturbation": "hidden_duplicate"}),
        # This scenario passes no perturbation; it selects a policy stack instead.
        # NOTE(review): confirm "bandit-only" is what exercises wrong candidate ids.
        ("wrong_candidate_id_resilience", "medium", {"policy_stack": "bandit-only"}),
        ("stale_evidence_safety_rate", "hard", {"perturbation": "stale_evidence"}),
        ("delayed_ade_manifestation_safety_rate", "hard", {"perturbation": "delayed_ade"}),
    )

    return {
        metric: _safety_rate(run_rollouts(episodes=6, difficulty=level, **extra))
        for metric, level, extra in scenarios
    }