Spaces:
Running
Running
File size: 7,236 Bytes
877add7 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 | """Multi-agent orchestration graph."""
from __future__ import annotations
import os
from typing import Any
from app.agents.candidate_agent import CandidateAgent
from app.agents.critic_agent import CriticAgent
from app.agents.dosing_agent import DosingAgent
from app.agents.evidence_agent import EvidenceAgent
from app.agents.explainer_agent import ExplainerAgent
from app.agents.graph_agent import GraphSafetyAgent
from app.agents.medrec_agent import MedRecAgent
from app.agents.planner_agent import PlannerAgent
from app.agents.supervisor_agent import SupervisorAgent
from app.common.enums import CoordinationMode
from app.common.types import CandidateAction, PolyGuardAction
from app.env.env_core import PolyGuardEnv
from app.models.baselines.contextual_bandit_policy import ContextualBanditPolicy
class Orchestrator:
    """Coordinate the PolyGuard agents for one environment step at a time.

    Each call to :meth:`run_step` runs the analysis agents on the current
    environment state and routes the generated candidates through the
    supervisor, a contextual bandit and/or the planner (depending on
    ``policy_stack``). The critic then approves or replaces the proposal,
    with an optional replan. Finally the chosen action is executed in the
    environment and the reward is fed back to the bandit.

    Configuration is read from environment variables at construction time:

    * ``POLYGUARD_BANDIT_ALGO``: ``"linucb"`` (default) or ``"thompson"``.
      Any other value falls back to ``"linucb"``.
    * ``POLYGUARD_BANDIT_ALPHA`` (float, default ``0.55``),
      ``POLYGUARD_BANDIT_EPSILON`` (float, default ``0.1``) and
      ``POLYGUARD_BANDIT_SEED`` (int, default ``42``): passed to the bandit.
    * ``POLYGUARD_POLICY_STACK``: ``"bandit-only"`` or ``"llm-only"``.
      Any other value (default ``"llm+bandit"``) feeds the bandit shortlist
      to the planner.
    * ``POLYGUARD_BANDIT_TOP_K`` (int, default ``3``): the ``top_k`` passed
      to the bandit.
    """

    def __init__(
        self,
        env: PolyGuardEnv,
        coordination_mode: CoordinationMode = CoordinationMode.SEQUENTIAL,
    ) -> None:
        self.env = env
        self.coordination_mode = coordination_mode
        self.medrec = MedRecAgent()
        self.evidence = EvidenceAgent()
        self.graph = GraphSafetyAgent()
        self.dosing = DosingAgent()
        self.candidate = CandidateAgent()
        self.supervisor = SupervisorAgent()
        self.planner = PlannerAgent()
        self.critic = CriticAgent()
        self.explainer = ExplainerAgent()
        bandit_algo = os.getenv("POLYGUARD_BANDIT_ALGO", "linucb").strip().lower()
        if bandit_algo not in {"linucb", "thompson"}:
            bandit_algo = "linucb"
        self.bandit = ContextualBanditPolicy(
            algorithm=bandit_algo,  # type: ignore[arg-type]
            alpha=self._env_number("POLYGUARD_BANDIT_ALPHA", "0.55", float),
            epsilon=self._env_number("POLYGUARD_BANDIT_EPSILON", "0.1", float),
            seed=self._env_number("POLYGUARD_BANDIT_SEED", "42", int),
        )
        self.policy_stack = os.getenv("POLYGUARD_POLICY_STACK", "llm+bandit").strip().lower()
        self.bandit_top_k = self._env_number("POLYGUARD_BANDIT_TOP_K", "3", int)

    @staticmethod
    def _env_number(name: str, default: str, cast: type[float] | type[int]) -> Any:
        """Read environment variable ``name`` (or ``default``) and convert it with ``cast``.

        Raises:
            ValueError: If the value cannot be converted. The message names the
                offending variable so a misconfiguration is easy to diagnose.
        """
        raw = os.getenv(name, default)
        try:
            return cast(raw)
        except ValueError as err:
            raise ValueError(f"{name} must be a valid {cast.__name__}, got {raw!r}") from err

    def set_mode(self, coordination_mode: CoordinationMode) -> None:
        """Replace the coordination mode used by subsequent steps."""
        self.coordination_mode = coordination_mode

    def run_step(self, coordination_mode: str | None = None) -> dict[str, Any]:
        """Run every agent once, execute the final action and return the trace.

        Args:
            coordination_mode: Optional ``CoordinationMode`` value. When given it
                replaces ``self.coordination_mode`` persistently, not just for
                this step.

        Returns:
            A dict containing:

            * each agent's output;
            * the proposed and final actions (JSON-dumped);
            * the environment's observation, reward, done flag and info;
            * the explanation;
            * bandit and coordination diagnostics.
        """
        if coordination_mode is not None:
            self.coordination_mode = CoordinationMode(coordination_mode)
        state = self.env.state
        medrec_out = self.medrec.run(state)
        evidence_out = self.evidence.run(state)
        graph_out = self.graph.run(state)
        dosing_out = self.dosing.run(state)
        candidate_out = self.candidate.run(state)
        candidates = [CandidateAction.model_validate(item) for item in candidate_out["candidates"]]
        supervisor_out = self.supervisor.run(state, dosing_active=dosing_out["dosing_active"])
        supervisor_mode = supervisor_out["mode"]

        planner_candidates = self._route_candidates(candidates, supervisor_mode)
        bandit_proposals, bandit_candidates = self._bandit_shortlist(planner_candidates)
        proposed = self._propose(state, supervisor_mode, planner_candidates, bandit_candidates)
        proposed, critic_out, replan_triggered, debate_rounds = self._critique(
            state, proposed, candidates
        )
        final_action: PolyGuardAction = critic_out["final_action"]

        obs, reward, done, info = self.env.step(final_action)
        # Credit the bandit with the reward of whichever candidate was executed.
        # The lookup spans *all* candidates, because a replan draws REVIEW
        # candidates from the full set, which may lie outside the routed subset.
        all_by_id = {item.candidate_id: item for item in candidates}
        executed = all_by_id.get(final_action.candidate_id)
        if executed is not None:
            self.bandit.update(executed, reward=reward)
        explanation_out = self.explainer.run(state, final_action, critic_out["report"])
        return {
            "medrec": medrec_out,
            "evidence": evidence_out,
            "graph": graph_out,
            "dosing": dosing_out,
            "supervisor": supervisor_out,
            "proposed_action": proposed.model_dump(mode="json"),
            "critic": critic_out["report"],
            "final_action": final_action.model_dump(mode="json"),
            "observation": obs.model_dump(mode="json"),
            "reward": reward,
            "done": done,
            "info": info,
            "explanation": explanation_out,
            "coordination_mode": self.coordination_mode.value,
            "policy_stack": self.policy_stack,
            "bandit_topk": [item.candidate_id for item in bandit_candidates],
            "bandit_scores": [
                {
                    "candidate_id": item.candidate_id,
                    "score": item.score,
                    "exploration_bonus": item.exploration_bonus,
                    "algorithm": item.algorithm,
                }
                for item in bandit_proposals
            ],
            "replan_triggered": replan_triggered,
            "debate_rounds": debate_rounds,
        }

    @staticmethod
    def _route_candidates(
        candidates: list[CandidateAction], supervisor_mode: str
    ) -> list[CandidateAction]:
        """Keep candidates whose mode matches the supervisor's; fall back to all of them."""
        # NOTE(review): routing is identical for every coordination mode. A
        # SUPERVISOR_ROUTED-specific REVIEW filter would be redundant here,
        # because a "REVIEW" supervisor mode already selects REVIEW candidates.
        # Confirm whether SUPERVISOR_ROUTED was meant to route differently.
        return [c for c in candidates if c.mode.value == supervisor_mode] or candidates

    def _bandit_shortlist(
        self, planner_candidates: list[CandidateAction]
    ) -> tuple[list[Any], list[CandidateAction]]:
        """Ask the bandit for its top-k picks among the routed candidates.

        Returns:
            ``(proposals, shortlist)``. ``proposals`` is the raw bandit output
            (used for score reporting). ``shortlist`` holds the matching
            routed candidates in proposal order, or all routed candidates when
            no proposal matches.
        """
        routed_by_id = {item.candidate_id: item for item in planner_candidates}
        proposals = self.bandit.propose(planner_candidates, top_k=self.bandit_top_k)
        shortlist = [
            routed_by_id[item.candidate_id]
            for item in proposals
            if item.candidate_id in routed_by_id
        ]
        return proposals, shortlist or planner_candidates

    def _propose(
        self,
        state: Any,
        supervisor_mode: str,
        planner_candidates: list[CandidateAction],
        bandit_candidates: list[CandidateAction],
    ) -> PolyGuardAction:
        """Produce the initial action according to ``self.policy_stack``."""
        if self.policy_stack == "bandit-only":
            # Take the bandit's first pick verbatim; confidence is floored at 0.45.
            selected = bandit_candidates[0]
            return PolyGuardAction(
                mode=selected.mode,
                action_type=selected.action_type,
                target_drug=selected.target_drug,
                replacement_drug=selected.replacement_drug,
                dose_bucket=selected.dose_bucket,
                taper_days=selected.taper_days,
                monitoring_plan=selected.monitoring_plan,
                candidate_id=selected.candidate_id,
                confidence=max(0.45, 1.0 - selected.uncertainty_score),
                rationale_brief="Bandit-only policy selected top contextual candidate.",
            )
        if self.policy_stack == "llm-only":
            return self.planner.run(candidates=planner_candidates, mode=supervisor_mode)
        # Default stack: the planner chooses among the bandit shortlist.
        return self.planner.run(
            candidates=bandit_candidates,
            mode=supervisor_mode,
            provider_prompt={
                "coordination_mode": self.coordination_mode.value,
                "policy_stack": self.policy_stack,
                "candidate_count": len(bandit_candidates),
                "sub_environment": state.sub_environment.value,
            },
        )

    def _critique(
        self,
        state: Any,
        proposed: PolyGuardAction,
        candidates: list[CandidateAction],
    ) -> tuple[PolyGuardAction, dict[str, Any], bool, int]:
        """Run the critic, replanning once in REVIEW mode on a veto when enabled.

        Returns:
            ``(proposed, critic_out, replan_triggered, debate_rounds)``, where
            ``proposed`` is the replanned action if a replan happened.
        """
        critic_out = self.critic.run(state, proposed)
        replan_triggered = False
        debate_rounds = 0
        replan_modes = {CoordinationMode.REPLAN_ON_VETO, CoordinationMode.LIGHT_DEBATE}
        if self.coordination_mode in replan_modes and not critic_out["approved"]:
            replan_triggered = True
            # The replan uses the full candidate set, not the routed subset.
            review_candidates = [c for c in candidates if c.mode.value == "REVIEW"] or candidates
            proposed = self.planner.run(candidates=review_candidates, mode="REVIEW")
            critic_out = self.critic.run(state, proposed)
            final_action = critic_out["final_action"]
            debate_rounds = 1
            # NOTE(review): the pasted source lost its indentation. This check
            # is assumed to belong to the replan branch; confirm.
            if (
                self.coordination_mode == CoordinationMode.LIGHT_DEBATE
                and critic_out["approved"]
                and proposed.action_type != final_action.action_type
            ):
                debate_rounds = 2
        return proposed, critic_out, replan_triggered, debate_rounds
|