File size: 4,056 Bytes
21c7db9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
"""Simulator rollout helpers with scenario perturbations for evaluation."""

from __future__ import annotations

import os
from typing import Any

from app.agents.orchestrator import Orchestrator
from app.common.enums import DoseBucket
from app.env.env_core import PolyGuardEnv


def _apply_perturbation(env: PolyGuardEnv, perturbation: str | None) -> None:
    """Inject a named scenario perturbation into ``env.state`` in place.

    A falsy ``perturbation`` leaves the state untouched. Any name not listed
    below is also ignored without error.

    Supported names:
        ``missing_labs``      -- set the patient's ``egfr``, ``ast`` and ``alt``
                                 lab values to ``None``.
        ``noisy_dose_info``   -- toggle ``dose_bucket`` on every even-indexed
                                 medication (HIGH -> LOW, anything else -> HIGH).
        ``conflicting_meds``  -- append ``model_copy()`` of the first medication.
        ``alias_noise``       -- append ``"_alias"`` to the first medication's
                                 ``drug`` name.
        ``hidden_duplicate``  -- append ``model_copy(update={"drug": ...})`` of
                                 the first medication (same drug name).
        ``stale_evidence``    -- append ``"evidence_stale"`` to
                                 ``state.unresolved_conflicts``.
        ``delayed_ade``       -- set latent confounder ``"delayed_ade"`` to 0.8.

    The three medication-based perturbations (``conflicting_meds``,
    ``alias_noise``, ``hidden_duplicate``) are no-ops when the medication list
    is empty.
    """
    if not perturbation:
        return

    episode_state = env.state
    patient = episode_state.patient
    medications = patient.medications

    if perturbation == "missing_labs":
        labs = patient.labs
        labs.egfr = None
        labs.ast = None
        labs.alt = None
        return

    if perturbation == "noisy_dose_info":
        # Every other medication, starting with the first one.
        for med in medications[::2]:
            if med.dose_bucket == DoseBucket.HIGH:
                med.dose_bucket = DoseBucket.LOW
            else:
                med.dose_bucket = DoseBucket.HIGH
        return

    if perturbation in ("conflicting_meds", "alias_noise", "hidden_duplicate"):
        if not medications:
            return
        first_med = medications[0]
        if perturbation == "conflicting_meds":
            medications.append(first_med.model_copy())
        elif perturbation == "alias_noise":
            first_med.drug = f"{first_med.drug}_alias"
        else:
            # NOTE(review): the update re-sets `drug` to its current value, so the
            # appended copy carries the same name as the original. If a disguised
            # duplicate was intended, the name is not actually altered -- confirm.
            medications.append(first_med.model_copy(update={"drug": first_med.drug}))
        return

    if perturbation == "stale_evidence":
        episode_state.unresolved_conflicts.append("evidence_stale")
    elif perturbation == "delayed_ade":
        patient.latent_confounders["delayed_ade"] = 0.8
    # NOTE(review): unrecognised perturbation names fall through silently.


def _as_dict(value: Any) -> dict[str, Any]:
    """Return ``value`` unchanged if it is a ``dict``, otherwise an empty dict.

    Used to read optional nested payloads from an orchestrator step result
    without trusting their type (missing keys and ``None`` both become ``{}``).
    """
    return value if isinstance(value, dict) else {}


def run_rollouts(
    episodes: int = 5,
    difficulty: str = "medium",
    sub_environment: str | None = None,
    perturbation: str | None = None,
    seed_offset: int = 900,
    policy_stack: str = "llm+bandit",
) -> list[dict[str, Any]]:
    """Roll out evaluation episodes and return one flat record per orchestrator step.

    Episode ``i`` (``0 <= i < episodes``) is reset with seed ``seed_offset + i``
    and the given ``difficulty`` / ``sub_environment``. ``perturbation`` is then
    applied via ``_apply_perturbation``, and ``Orchestrator.run_step`` is
    called until its result reports ``done``.

    The ``POLYGUARD_POLICY_STACK`` environment variable is set to
    ``policy_stack`` for the duration of the call. Afterwards it is always
    restored to its previous value, or removed if it was previously unset.
    This happens even when a rollout raises.

    Args:
        episodes: Number of episodes to run.
        difficulty: Passed through to ``env.reset``.
        sub_environment: Passed through to ``env.reset``.
        perturbation: Scenario perturbation name applied after each reset.
        seed_offset: Base seed; episode ``i`` uses ``seed_offset + i``.
        policy_stack: Value exported as ``POLYGUARD_POLICY_STACK`` while running.

    Returns:
        A list of dicts, one per step, with the episode/step indices, reward,
        critic-derived flags, info-derived diagnostics and the run configuration.
    """
    previous_policy = os.getenv("POLYGUARD_POLICY_STACK")
    os.environ["POLYGUARD_POLICY_STACK"] = policy_stack
    try:
        env = PolyGuardEnv()
        orchestrator = Orchestrator(env)
        rows: list[dict[str, Any]] = []

        for episode in range(episodes):
            env.reset(seed=seed_offset + episode, difficulty=difficulty, sub_environment=sub_environment)
            _apply_perturbation(env, perturbation=perturbation)

            done = False
            # NOTE(review): there is no step cap here; termination relies on the
            # orchestrator eventually reporting done.
            while not done:
                out = orchestrator.run_step()
                done = bool(out.get("done"))
                info = _as_dict(out.get("info"))
                critic = _as_dict(out.get("critic"))
                final_action = _as_dict(out.get("final_action"))

                rows.append(
                    {
                        "episode": episode,
                        "step": int(env.state.step_count),
                        # `or` fallbacks also cover keys that are present but None.
                        "reward": float(out.get("reward") or 0.0),
                        "done": done,
                        "legal": bool(critic.get("legal", False)),
                        # NOTE(review): only steps with 2+ violations are flagged;
                        # confirm a single violation is intentionally not "severe".
                        "severe_violation": len(critic.get("violations") or []) > 1,
                        # Any REQUEST_* action type is recorded as an abstention.
                        "abstain": str(final_action.get("action_type", "")).startswith("REQUEST_"),
                        "termination_reason": info.get("termination_reason"),
                        "step_timeout": bool(info.get("step_timeout")),
                        "failure_reasons": info.get("failure_reasons", []),
                        "invalid_action_count": int(info.get("invalid_action_count") or 0),
                        "reward_breakdown": _as_dict(info.get("reward_breakdown")),
                        "primary_reward_channels": _as_dict(info.get("primary_reward_channels")),
                        "policy_stack": policy_stack,
                        "difficulty": difficulty,
                        "sub_environment": sub_environment,
                        "perturbation": perturbation,
                    }
                )

        return rows
    finally:
        # Restore the caller's environment even if a rollout failed part-way.
        if previous_policy is None:
            os.environ.pop("POLYGUARD_POLICY_STACK", None)
        else:
            os.environ["POLYGUARD_POLICY_STACK"] = previous_policy