| """Multi-agent orchestration graph.""" |
|
|
| from __future__ import annotations |
|
|
| import os |
| from typing import Any |
|
|
| from app.agents.candidate_agent import CandidateAgent |
| from app.agents.critic_agent import CriticAgent |
| from app.agents.dosing_agent import DosingAgent |
| from app.agents.evidence_agent import EvidenceAgent |
| from app.agents.explainer_agent import ExplainerAgent |
| from app.agents.graph_agent import GraphSafetyAgent |
| from app.agents.medrec_agent import MedRecAgent |
| from app.agents.planner_agent import PlannerAgent |
| from app.agents.supervisor_agent import SupervisorAgent |
| from app.common.enums import CoordinationMode |
| from app.common.types import CandidateAction, PolyGuardAction |
| from app.env.env_core import PolyGuardEnv |
| from app.models.baselines.contextual_bandit_policy import ContextualBanditPolicy |
|
|
|
|
class Orchestrator:
    """Coordinate the PolyGuard agents to choose and execute one environment action per step.

    Each :meth:`run_step` call runs the analysis agents on the current environment
    state, narrows the candidate actions (supervisor routing, then a contextual
    bandit), produces a proposal according to the configured policy stack, lets the
    critic approve or replace it (optionally replanning on veto), executes the final
    action in the environment and feeds the reward back to the bandit.

    Configuration read from environment variables at construction time:

    * ``POLYGUARD_BANDIT_ALGO`` -- ``"linucb"`` (default) or ``"thompson"``; any other
      value falls back to ``"linucb"``.
    * ``POLYGUARD_BANDIT_ALPHA`` (default ``0.55``), ``POLYGUARD_BANDIT_EPSILON``
      (default ``0.1``), ``POLYGUARD_BANDIT_SEED`` (default ``42``) -- forwarded to
      ``ContextualBanditPolicy``.
    * ``POLYGUARD_POLICY_STACK`` -- ``"llm+bandit"`` (default), ``"llm-only"`` or
      ``"bandit-only"``; any other value falls back to ``"llm+bandit"``.
    * ``POLYGUARD_BANDIT_TOP_K`` -- ``top_k`` passed to the bandit (default ``3``).
    """

    _BANDIT_ALGORITHMS = frozenset({"linucb", "thompson"})
    _DEFAULT_BANDIT_ALGORITHM = "linucb"
    _POLICY_STACKS = frozenset({"llm+bandit", "llm-only", "bandit-only"})
    _DEFAULT_POLICY_STACK = "llm+bandit"

    def __init__(self, env: PolyGuardEnv, coordination_mode: CoordinationMode = CoordinationMode.SEQUENTIAL) -> None:
        """Create all agents and the contextual bandit for ``env``.

        Args:
            env: Environment whose ``state`` is read and which ``step`` is called on.
            coordination_mode: Initial coordination mode; can be changed later via
                :meth:`set_mode` or the ``coordination_mode`` argument of :meth:`run_step`.

        Raises:
            ValueError: if a numeric ``POLYGUARD_BANDIT_*`` variable is not parseable
                (raised by ``float``/``int``).
        """
        self.env = env
        self.coordination_mode = coordination_mode
        self.medrec = MedRecAgent()
        self.evidence = EvidenceAgent()
        self.graph = GraphSafetyAgent()
        self.dosing = DosingAgent()
        self.candidate = CandidateAgent()
        self.supervisor = SupervisorAgent()
        self.planner = PlannerAgent()
        self.critic = CriticAgent()
        self.explainer = ExplainerAgent()

        bandit_algo = os.getenv("POLYGUARD_BANDIT_ALGO", self._DEFAULT_BANDIT_ALGORITHM).strip().lower()
        if bandit_algo not in self._BANDIT_ALGORITHMS:
            bandit_algo = self._DEFAULT_BANDIT_ALGORITHM
        self.bandit = ContextualBanditPolicy(
            algorithm=bandit_algo,
            alpha=float(os.getenv("POLYGUARD_BANDIT_ALPHA", "0.55")),
            epsilon=float(os.getenv("POLYGUARD_BANDIT_EPSILON", "0.1")),
            seed=int(os.getenv("POLYGUARD_BANDIT_SEED", "42")),
        )

        policy_stack = os.getenv("POLYGUARD_POLICY_STACK", self._DEFAULT_POLICY_STACK).strip().lower()
        # Unknown stacks are executed as "llm+bandit" by _propose anyway; normalise the
        # stored value so the reported policy_stack matches the behaviour actually used.
        if policy_stack not in self._POLICY_STACKS:
            policy_stack = self._DEFAULT_POLICY_STACK
        self.policy_stack = policy_stack
        self.bandit_top_k = int(os.getenv("POLYGUARD_BANDIT_TOP_K", "3"))

    def set_mode(self, coordination_mode: CoordinationMode) -> None:
        """Set the coordination mode used by subsequent :meth:`run_step` calls."""
        self.coordination_mode = coordination_mode

    def run_step(self, coordination_mode: str | None = None) -> dict[str, Any]:
        """Run one full orchestration cycle and execute the chosen action in the environment.

        Args:
            coordination_mode: Optional raw ``CoordinationMode`` value. When given, it is
                converted with ``CoordinationMode(...)`` and stored on the instance, so it
                also applies to later steps.

        Returns:
            A JSON-friendly trace of the step: each agent's output, the proposed and
            final actions, the environment transition (``observation``, ``reward``,
            ``done``, ``info``), the explanation, and bandit/replanning diagnostics.
        """
        if coordination_mode is not None:
            # NOTE(review): presumably raises ValueError for an unknown mode (Enum lookup).
            self.coordination_mode = CoordinationMode(coordination_mode)
        # Captured before env.step(); it is also what the explainer receives below.
        state = self.env.state

        medrec_out = self.medrec.run(state)
        evidence_out = self.evidence.run(state)
        graph_out = self.graph.run(state)
        dosing_out = self.dosing.run(state)
        candidate_out = self.candidate.run(state)
        candidates = [CandidateAction.model_validate(item) for item in candidate_out["candidates"]]

        supervisor_out = self.supervisor.run(state, dosing_active=dosing_out["dosing_active"])
        supervisor_mode = supervisor_out["mode"]
        planner_candidates = self._route_candidates(candidates, supervisor_mode)

        planner_by_id = {item.candidate_id: item for item in planner_candidates}
        bandit_proposals = self.bandit.propose(planner_candidates, top_k=self.bandit_top_k)
        # Keep the bandit's ranking order; fall back to every routed candidate if the
        # bandit returned nothing usable.
        bandit_candidates = [
            planner_by_id[item.candidate_id] for item in bandit_proposals if item.candidate_id in planner_by_id
        ] or planner_candidates

        proposed = self._propose(state, supervisor_mode, planner_candidates, bandit_candidates)

        critic_out = self.critic.run(state, proposed)
        final_action: PolyGuardAction = critic_out["final_action"]
        replan_triggered = False
        debate_rounds = 0

        # On a veto, the replanning modes ask the planner again, restricted to REVIEW
        # candidates (or all candidates if there are none), and re-run the critic.
        if (
            self.coordination_mode in {CoordinationMode.REPLAN_ON_VETO, CoordinationMode.LIGHT_DEBATE}
            and not critic_out["approved"]
        ):
            replan_triggered = True
            review_candidates = [c for c in candidates if c.mode.value == "REVIEW"] or candidates
            proposed = self.planner.run(candidates=review_candidates, mode="REVIEW")
            critic_out = self.critic.run(state, proposed)
            final_action = critic_out["final_action"]
            debate_rounds = 1

        # Light debate counts an approval whose final action type differs from the proposal.
        if (
            self.coordination_mode == CoordinationMode.LIGHT_DEBATE
            and critic_out["approved"]
            and proposed.action_type != final_action.action_type
        ):
            debate_rounds = 2

        obs, reward, done, info = self.env.step(final_action)

        # Credit the reward to the candidate that was actually executed. After a veto
        # the replanned action comes from the REVIEW pool of *all* candidates, which may
        # not be among the routed planner candidates, so fall back to the full list.
        executed = planner_by_id.get(final_action.candidate_id)
        if executed is None:
            executed = {item.candidate_id: item for item in candidates}.get(final_action.candidate_id)
        if executed is not None:
            self.bandit.update(executed, reward=reward)

        explanation_out = self.explainer.run(state, final_action, critic_out["report"])
        return {
            "medrec": medrec_out,
            "evidence": evidence_out,
            "graph": graph_out,
            "dosing": dosing_out,
            "supervisor": supervisor_out,
            "proposed_action": proposed.model_dump(mode="json"),
            "critic": critic_out["report"],
            "final_action": final_action.model_dump(mode="json"),
            "observation": obs.model_dump(mode="json"),
            "reward": reward,
            "done": done,
            "info": info,
            "explanation": explanation_out,
            "coordination_mode": self.coordination_mode.value,
            "policy_stack": self.policy_stack,
            "bandit_topk": [item.candidate_id for item in bandit_candidates],
            "bandit_scores": [
                {
                    "candidate_id": item.candidate_id,
                    "score": item.score,
                    "exploration_bonus": item.exploration_bonus,
                    "algorithm": item.algorithm,
                }
                for item in bandit_proposals
            ],
            "replan_triggered": replan_triggered,
            "debate_rounds": debate_rounds,
        }

    def _route_candidates(self, candidates: list[CandidateAction], supervisor_mode: str) -> list[CandidateAction]:
        """Keep candidates whose mode matches the supervisor's mode, or all of them if none match."""
        routed = [c for c in candidates if c.mode.value == supervisor_mode] or candidates
        # NOTE(review): when supervisor_mode == "REVIEW" the filter above already yields the
        # REVIEW candidates (or all candidates), so this branch never changes the result.
        # Kept for fidelity; confirm what SUPERVISOR_ROUTED is meant to do differently.
        if self.coordination_mode == CoordinationMode.SUPERVISOR_ROUTED and supervisor_mode == "REVIEW":
            routed = [c for c in candidates if c.mode.value == "REVIEW"] or routed
        return routed

    def _propose(
        self,
        state: Any,
        supervisor_mode: str,
        planner_candidates: list[CandidateAction],
        bandit_candidates: list[CandidateAction],
    ) -> PolyGuardAction:
        """Produce the initial (pre-critic) proposal according to ``self.policy_stack``.

        * ``"bandit-only"``: take the bandit's top candidate directly.
        * ``"llm-only"``: planner chooses among all routed candidates.
        * otherwise (``"llm+bandit"``): planner chooses among the bandit's shortlist,
          with extra provider context.
        """
        if self.policy_stack == "bandit-only":
            # Assumes the candidate agent produced at least one candidate; otherwise
            # this raises IndexError.
            return self._action_from_candidate(bandit_candidates[0])
        if self.policy_stack == "llm-only":
            return self.planner.run(candidates=planner_candidates, mode=supervisor_mode)
        return self.planner.run(
            candidates=bandit_candidates,
            mode=supervisor_mode,
            provider_prompt={
                "coordination_mode": self.coordination_mode.value,
                "policy_stack": self.policy_stack,
                "candidate_count": len(bandit_candidates),
                "sub_environment": state.sub_environment.value,
            },
        )

    @staticmethod
    def _action_from_candidate(selected: CandidateAction) -> PolyGuardAction:
        """Convert a bandit-selected candidate directly into an executable action."""
        return PolyGuardAction(
            mode=selected.mode,
            action_type=selected.action_type,
            target_drug=selected.target_drug,
            replacement_drug=selected.replacement_drug,
            dose_bucket=selected.dose_bucket,
            taper_days=selected.taper_days,
            monitoring_plan=selected.monitoring_plan,
            candidate_id=selected.candidate_id,
            # Confidence is 1 - uncertainty, floored at 0.45.
            confidence=max(0.45, 1.0 - selected.uncertainty_score),
            rationale_brief="Bandit-only policy selected top contextual candidate.",
        )
|
|