"""Multi-agent orchestration graph.""" from __future__ import annotations import os from typing import Any from app.agents.candidate_agent import CandidateAgent from app.agents.critic_agent import CriticAgent from app.agents.dosing_agent import DosingAgent from app.agents.evidence_agent import EvidenceAgent from app.agents.explainer_agent import ExplainerAgent from app.agents.graph_agent import GraphSafetyAgent from app.agents.medrec_agent import MedRecAgent from app.agents.planner_agent import PlannerAgent from app.agents.supervisor_agent import SupervisorAgent from app.common.enums import CoordinationMode from app.common.types import CandidateAction, PolyGuardAction from app.env.env_core import PolyGuardEnv from app.models.baselines.contextual_bandit_policy import ContextualBanditPolicy class Orchestrator: def __init__(self, env: PolyGuardEnv, coordination_mode: CoordinationMode = CoordinationMode.SEQUENTIAL) -> None: self.env = env self.coordination_mode = coordination_mode self.medrec = MedRecAgent() self.evidence = EvidenceAgent() self.graph = GraphSafetyAgent() self.dosing = DosingAgent() self.candidate = CandidateAgent() self.supervisor = SupervisorAgent() self.planner = PlannerAgent() self.critic = CriticAgent() self.explainer = ExplainerAgent() bandit_algo = os.getenv("POLYGUARD_BANDIT_ALGO", "linucb").strip().lower() if bandit_algo not in {"linucb", "thompson"}: bandit_algo = "linucb" self.bandit = ContextualBanditPolicy( algorithm=bandit_algo, # type: ignore[arg-type] alpha=float(os.getenv("POLYGUARD_BANDIT_ALPHA", "0.55")), epsilon=float(os.getenv("POLYGUARD_BANDIT_EPSILON", "0.1")), seed=int(os.getenv("POLYGUARD_BANDIT_SEED", "42")), ) self.policy_stack = os.getenv("POLYGUARD_POLICY_STACK", "llm+bandit").strip().lower() self.bandit_top_k = int(os.getenv("POLYGUARD_BANDIT_TOP_K", "3")) def set_mode(self, coordination_mode: CoordinationMode) -> None: self.coordination_mode = coordination_mode def run_step(self, coordination_mode: str | None = None) -> dict[str, Any]: if coordination_mode is not None: self.coordination_mode = CoordinationMode(coordination_mode) state = self.env.state medrec_out = self.medrec.run(state) evidence_out = self.evidence.run(state) graph_out = self.graph.run(state) dosing_out = self.dosing.run(state) candidate_out = self.candidate.run(state) candidates = [CandidateAction.model_validate(item) for item in candidate_out["candidates"]] supervisor_out = self.supervisor.run(state, dosing_active=dosing_out["dosing_active"]) planner_candidates = [c for c in candidates if c.mode.value == supervisor_out["mode"]] or candidates if self.coordination_mode == CoordinationMode.SUPERVISOR_ROUTED and supervisor_out["mode"] == "REVIEW": planner_candidates = [c for c in candidates if c.mode.value == "REVIEW"] or planner_candidates candidate_by_id = {item.candidate_id: item for item in planner_candidates} bandit_proposals = self.bandit.propose(planner_candidates, top_k=self.bandit_top_k) bandit_candidates = [candidate_by_id[item.candidate_id] for item in bandit_proposals if item.candidate_id in candidate_by_id] if not bandit_candidates: bandit_candidates = planner_candidates if self.policy_stack == "bandit-only": selected = bandit_candidates[0] proposed = PolyGuardAction( mode=selected.mode, action_type=selected.action_type, target_drug=selected.target_drug, replacement_drug=selected.replacement_drug, dose_bucket=selected.dose_bucket, taper_days=selected.taper_days, monitoring_plan=selected.monitoring_plan, candidate_id=selected.candidate_id, confidence=max(0.45, 1.0 - selected.uncertainty_score), rationale_brief="Bandit-only policy selected top contextual candidate.", ) elif self.policy_stack == "llm-only": proposed = self.planner.run(candidates=planner_candidates, mode=supervisor_out["mode"]) else: proposed = self.planner.run( candidates=bandit_candidates, mode=supervisor_out["mode"], provider_prompt={ "coordination_mode": self.coordination_mode.value, "policy_stack": self.policy_stack, "candidate_count": len(bandit_candidates), "sub_environment": state.sub_environment.value, }, ) critic_out = self.critic.run(state, proposed) final_action: PolyGuardAction = critic_out["final_action"] replan_triggered = False debate_rounds = 0 if self.coordination_mode in {CoordinationMode.REPLAN_ON_VETO, CoordinationMode.LIGHT_DEBATE} and not critic_out["approved"]: replan_triggered = True review_candidates = [c for c in candidates if c.mode.value == "REVIEW"] or candidates proposed = self.planner.run(candidates=review_candidates, mode="REVIEW") critic_out = self.critic.run(state, proposed) final_action = critic_out["final_action"] debate_rounds = 1 if self.coordination_mode == CoordinationMode.LIGHT_DEBATE and critic_out["approved"] and proposed.action_type != final_action.action_type: debate_rounds = 2 obs, reward, done, info = self.env.step(final_action) selected_for_update = candidate_by_id.get(final_action.candidate_id) if selected_for_update is not None: self.bandit.update(selected_for_update, reward=reward) explanation_out = self.explainer.run(state, final_action, critic_out["report"]) return { "medrec": medrec_out, "evidence": evidence_out, "graph": graph_out, "dosing": dosing_out, "supervisor": supervisor_out, "proposed_action": proposed.model_dump(mode="json"), "critic": critic_out["report"], "final_action": final_action.model_dump(mode="json"), "observation": obs.model_dump(mode="json"), "reward": reward, "done": done, "info": info, "explanation": explanation_out, "coordination_mode": self.coordination_mode.value, "policy_stack": self.policy_stack, "bandit_topk": [item.candidate_id for item in bandit_candidates], "bandit_scores": [ { "candidate_id": item.candidate_id, "score": item.score, "exploration_bonus": item.exploration_bonus, "algorithm": item.algorithm, } for item in bandit_proposals ], "replan_triggered": replan_triggered, "debate_rounds": debate_rounds, }