File size: 1,060 Bytes
21c7db9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
"""Subgroup evaluation."""

from __future__ import annotations

from collections.abc import Callable

from app.evaluation.simulator_rollouts import run_rollouts


def subgroup_eval(
    *,
    episodes: int = 6,
    difficulty: str = "hard",
    rollout_fn: Callable[..., list[dict]] | None = None,
) -> dict[str, dict[str, float]]:
    """Summarise simulator rollouts for a fixed set of patient subgroups.

    Each subgroup is evaluated by running rollouts in one simulator
    sub-environment (optionally with a perturbation). The resulting rows are
    reduced to an average reward and the fraction of rows marked legal.

    Args:
        episodes: Number of episodes requested per subgroup. Defaults to 6,
            the previously hard-coded value.
        difficulty: Difficulty passed to every rollout. Defaults to "hard",
            the previously hard-coded value.
        rollout_fn: Callable that produces rollout rows. It is invoked with
            keyword arguments only. ``None`` (the default) means
            ``run_rollouts``. It can be injected so the aggregation can be
            exercised without the simulator.

    Returns:
        Mapping of subgroup name (``"renal_compromise"``,
        ``"hepatic_compromise"``, ``"frail"``, in that order) to
        ``{"avg_reward": float, "legal_rate": float}``. Both values are
        rounded to 6 decimal places. A subgroup with no rows reports 0.0
        for both.
    """
    run = run_rollouts if rollout_fn is None else rollout_fn

    # subgroup name -> (sub_environment, perturbation or None).
    # NOTE(review): each clinical subgroup is represented by a
    # sub-environment/perturbation pair; confirm this mapping with the
    # simulator's owners. "frail" passes no perturbation keyword at all, so
    # whatever default the rollout function has applies there.
    subgroups: dict[str, tuple[str, str | None]] = {
        "renal_compromise": ("PRECISION_DOSING", "missing_labs"),
        "hepatic_compromise": ("REGIMEN_RISK", "stale_evidence"),
        "frail": ("LONGITUDINAL_DEPRESCRIBING", None),
    }

    def _summary(rows: list[dict]) -> dict[str, float]:
        # Rows are read with .get(): a missing "reward" counts as 0.0 and a
        # missing "legal" counts as not legal.
        if not rows:
            return {"avg_reward": 0.0, "legal_rate": 0.0}
        n = len(rows)
        total_reward = sum(float(r.get("reward", 0.0)) for r in rows)
        legal_count = sum(1 for r in rows if r.get("legal", False))
        return {
            "avg_reward": round(total_reward / n, 6),
            "legal_rate": round(legal_count / n, 6),
        }

    results: dict[str, dict[str, float]] = {}
    # Dict order is insertion order, so rollouts run in the same sequence as
    # before (renal, hepatic, frail).
    for name, (sub_environment, perturbation) in subgroups.items():
        kwargs: dict[str, object] = {
            "episodes": episodes,
            "difficulty": difficulty,
            "sub_environment": sub_environment,
        }
        if perturbation is not None:
            kwargs["perturbation"] = perturbation
        results[name] = _summary(run(**kwargs))
    return results