File size: 7,236 Bytes
877add7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
"""Multi-agent orchestration graph."""

from __future__ import annotations

import os
from typing import Any

from app.agents.candidate_agent import CandidateAgent
from app.agents.critic_agent import CriticAgent
from app.agents.dosing_agent import DosingAgent
from app.agents.evidence_agent import EvidenceAgent
from app.agents.explainer_agent import ExplainerAgent
from app.agents.graph_agent import GraphSafetyAgent
from app.agents.medrec_agent import MedRecAgent
from app.agents.planner_agent import PlannerAgent
from app.agents.supervisor_agent import SupervisorAgent
from app.common.enums import CoordinationMode
from app.common.types import CandidateAction, PolyGuardAction
from app.env.env_core import PolyGuardEnv
from app.models.baselines.contextual_bandit_policy import ContextualBanditPolicy


class Orchestrator:
    """Drive one environment step at a time through the agent pipeline.

    Each call to :meth:`run_step` runs the analysis agents on the current
    environment state, lets the supervisor pick a mode, has the contextual
    bandit shortlist candidates, obtains a proposed action (bandit, planner,
    or planner restricted to the bandit shortlist), passes it through the
    critic, steps the environment with the critic's final action and feeds
    the reward back to the bandit.
    """

    def __init__(self, env: PolyGuardEnv, coordination_mode: CoordinationMode = CoordinationMode.SEQUENTIAL) -> None:
        """Create the agents and the bandit policy.

        Args:
            env: Environment whose ``state`` is read and whose ``step`` is
                called by :meth:`run_step`.
            coordination_mode: Initial coordination mode.

        Environment variables read at construction time:
            POLYGUARD_BANDIT_ALGO: ``"linucb"`` (default) or ``"thompson"``;
                any other value falls back to ``"linucb"``.
            POLYGUARD_BANDIT_ALPHA: float, default ``0.55``.
            POLYGUARD_BANDIT_EPSILON: float, default ``0.1``.
            POLYGUARD_BANDIT_SEED: int, default ``42``.
            POLYGUARD_POLICY_STACK: default ``"llm+bandit"``; ``"bandit-only"``
                and ``"llm-only"`` select dedicated paths in
                :meth:`run_step`, every other value takes the combined path.
            POLYGUARD_BANDIT_TOP_K: int, default ``3``.

        Raises:
            ValueError: If one of the numeric variables cannot be parsed by
                ``float()`` / ``int()``.
        """
        self.env = env
        self.coordination_mode = coordination_mode
        self.medrec = MedRecAgent()
        self.evidence = EvidenceAgent()
        self.graph = GraphSafetyAgent()
        self.dosing = DosingAgent()
        self.candidate = CandidateAgent()
        self.supervisor = SupervisorAgent()
        self.planner = PlannerAgent()
        self.critic = CriticAgent()
        self.explainer = ExplainerAgent()
        bandit_algo = os.getenv("POLYGUARD_BANDIT_ALGO", "linucb").strip().lower()
        if bandit_algo not in {"linucb", "thompson"}:
            # Unknown algorithm names are tolerated and replaced by the default.
            bandit_algo = "linucb"
        self.bandit = ContextualBanditPolicy(
            algorithm=bandit_algo,  # type: ignore[arg-type]
            alpha=float(os.getenv("POLYGUARD_BANDIT_ALPHA", "0.55")),
            epsilon=float(os.getenv("POLYGUARD_BANDIT_EPSILON", "0.1")),
            seed=int(os.getenv("POLYGUARD_BANDIT_SEED", "42")),
        )
        self.policy_stack = os.getenv("POLYGUARD_POLICY_STACK", "llm+bandit").strip().lower()
        self.bandit_top_k = int(os.getenv("POLYGUARD_BANDIT_TOP_K", "3"))

    def set_mode(self, coordination_mode: CoordinationMode) -> None:
        """Replace the coordination mode used by subsequent steps."""
        self.coordination_mode = coordination_mode

    @staticmethod
    def _action_from_candidate(selected: CandidateAction) -> PolyGuardAction:
        """Turn a bandit-selected candidate directly into an action."""
        return PolyGuardAction(
            mode=selected.mode,
            action_type=selected.action_type,
            target_drug=selected.target_drug,
            replacement_drug=selected.replacement_drug,
            dose_bucket=selected.dose_bucket,
            taper_days=selected.taper_days,
            monitoring_plan=selected.monitoring_plan,
            candidate_id=selected.candidate_id,
            # Confidence is the complement of the uncertainty score, floored at 0.45.
            confidence=max(0.45, 1.0 - selected.uncertainty_score),
            rationale_brief="Bandit-only policy selected top contextual candidate.",
        )

    def _propose_action(
        self,
        state: Any,
        planner_candidates: list[CandidateAction],
        bandit_candidates: list[CandidateAction],
        mode: Any,
    ) -> PolyGuardAction:
        """Produce the first proposal according to ``self.policy_stack``."""
        if self.policy_stack == "bandit-only":
            # Raises IndexError when there are no candidates at all.
            return self._action_from_candidate(bandit_candidates[0])
        if self.policy_stack == "llm-only":
            # The planner sees the full mode-filtered set, not the bandit shortlist.
            return self.planner.run(candidates=planner_candidates, mode=mode)
        return self.planner.run(
            candidates=bandit_candidates,
            mode=mode,
            provider_prompt={
                "coordination_mode": self.coordination_mode.value,
                "policy_stack": self.policy_stack,
                "candidate_count": len(bandit_candidates),
                "sub_environment": state.sub_environment.value,
            },
        )

    def run_step(self, coordination_mode: str | None = None) -> dict[str, Any]:
        """Run the full pipeline once and step the environment.

        Args:
            coordination_mode: Optional coordination mode value. When given it
                is converted with ``CoordinationMode(...)`` and stored on the
                orchestrator, so it also applies to later calls.

        Returns:
            A dict with the raw agent outputs (``medrec``, ``evidence``,
            ``graph``, ``dosing``, ``supervisor``), the JSON dumps of the
            proposed and final actions, the critic report, the environment
            results (``observation``, ``reward``, ``done``, ``info``), the
            explanation, the active ``coordination_mode`` / ``policy_stack``,
            the bandit shortlist and scores, and the ``replan_triggered`` /
            ``debate_rounds`` bookkeeping.

        Raises:
            ValueError: If ``coordination_mode`` is not a valid
                ``CoordinationMode`` value.
            IndexError: In the ``"bandit-only"`` stack when the candidate
                agent produced no candidates.
        """
        if coordination_mode is not None:
            self.coordination_mode = CoordinationMode(coordination_mode)
        state = self.env.state
        medrec_out = self.medrec.run(state)
        evidence_out = self.evidence.run(state)
        graph_out = self.graph.run(state)
        dosing_out = self.dosing.run(state)
        candidate_out = self.candidate.run(state)
        candidates = [CandidateAction.model_validate(item) for item in candidate_out["candidates"]]

        supervisor_out = self.supervisor.run(state, dosing_active=dosing_out["dosing_active"])
        routed_mode = supervisor_out["mode"]
        # Keep only candidates in the supervisor's mode; fall back to all of
        # them when none match. This already covers SUPERVISOR_ROUTED with a
        # "REVIEW" mode, so that combination needs no extra narrowing.
        planner_candidates = [c for c in candidates if c.mode.value == routed_mode] or candidates

        candidate_by_id = {item.candidate_id: item for item in planner_candidates}
        bandit_proposals = self.bandit.propose(planner_candidates, top_k=self.bandit_top_k)
        bandit_candidates = [
            candidate_by_id[item.candidate_id]
            for item in bandit_proposals
            if item.candidate_id in candidate_by_id
        ]
        if not bandit_candidates:
            # No usable shortlist: let downstream choose from the full filtered set.
            bandit_candidates = planner_candidates

        proposed = self._propose_action(state, planner_candidates, bandit_candidates, routed_mode)

        critic_out = self.critic.run(state, proposed)
        final_action: PolyGuardAction = critic_out["final_action"]
        replan_triggered = False
        debate_rounds = 0

        replan_modes = {CoordinationMode.REPLAN_ON_VETO, CoordinationMode.LIGHT_DEBATE}
        if self.coordination_mode in replan_modes and not critic_out["approved"]:
            # Critic veto: re-plan once over REVIEW candidates (or everything
            # when there are none) and have the critic judge the new proposal.
            replan_triggered = True
            review_candidates = [c for c in candidates if c.mode.value == "REVIEW"] or candidates
            proposed = self.planner.run(candidates=review_candidates, mode="REVIEW")
            critic_out = self.critic.run(state, proposed)
            final_action = critic_out["final_action"]
            debate_rounds = 1

        if (
            self.coordination_mode == CoordinationMode.LIGHT_DEBATE
            and critic_out["approved"]
            and proposed.action_type != final_action.action_type
        ):
            debate_rounds = 2

        obs, reward, done, info = self.env.step(final_action)

        # Resolve the executed candidate against ALL candidates, not only the
        # mode-filtered subset: after a replan the final action may come from
        # a REVIEW candidate that was never in ``planner_candidates``, and its
        # reward would otherwise be dropped. The planner subset is overlaid
        # last so lookups that already resolved keep resolving to the same
        # object.
        feedback_lookup = {item.candidate_id: item for item in candidates}
        feedback_lookup.update(candidate_by_id)
        selected_for_update = feedback_lookup.get(final_action.candidate_id)
        if selected_for_update is not None:
            self.bandit.update(selected_for_update, reward=reward)

        # ``state`` is the reference obtained from the environment before the step.
        explanation_out = self.explainer.run(state, final_action, critic_out["report"])
        return {
            "medrec": medrec_out,
            "evidence": evidence_out,
            "graph": graph_out,
            "dosing": dosing_out,
            "supervisor": supervisor_out,
            "proposed_action": proposed.model_dump(mode="json"),
            "critic": critic_out["report"],
            "final_action": final_action.model_dump(mode="json"),
            "observation": obs.model_dump(mode="json"),
            "reward": reward,
            "done": done,
            "info": info,
            "explanation": explanation_out,
            "coordination_mode": self.coordination_mode.value,
            "policy_stack": self.policy_stack,
            "bandit_topk": [item.candidate_id for item in bandit_candidates],
            "bandit_scores": [
                {
                    "candidate_id": item.candidate_id,
                    "score": item.score,
                    "exploration_bonus": item.exploration_bonus,
                    "algorithm": item.algorithm,
                }
                for item in bandit_proposals
            ],
            "replan_triggered": replan_triggered,
            "debate_rounds": debate_rounds,
        }