Spaces:
Running
Running
| """Multi-agent orchestration graph.""" | |
| from __future__ import annotations | |
| import os | |
| from typing import Any | |
| from app.agents.candidate_agent import CandidateAgent | |
| from app.agents.critic_agent import CriticAgent | |
| from app.agents.dosing_agent import DosingAgent | |
| from app.agents.evidence_agent import EvidenceAgent | |
| from app.agents.explainer_agent import ExplainerAgent | |
| from app.agents.graph_agent import GraphSafetyAgent | |
| from app.agents.medrec_agent import MedRecAgent | |
| from app.agents.planner_agent import PlannerAgent | |
| from app.agents.supervisor_agent import SupervisorAgent | |
| from app.common.enums import CoordinationMode | |
| from app.common.types import CandidateAction, PolyGuardAction | |
| from app.env.env_core import PolyGuardEnv | |
| from app.models.baselines.contextual_bandit_policy import ContextualBanditPolicy | |
class Orchestrator:
    """Drive one PolyGuard decision step through the multi-agent pipeline.

    Each ``run_step`` call runs the medrec, evidence, graph-safety, dosing and
    candidate agents on the current environment state, asks the supervisor
    for a mode, proposes an action with the configured policy stack, lets the
    critic review it (replanning once on a veto in the replan/debate
    coordination modes), executes the critic's final action in the
    environment, feeds the reward back to the contextual bandit and asks the
    explainer for an explanation.

    Configuration is read from environment variables when the object is built:

    * ``POLYGUARD_BANDIT_ALGO`` -- ``linucb`` or ``thompson`` (default
      ``linucb``; unrecognised values fall back to ``linucb``).
    * ``POLYGUARD_BANDIT_ALPHA`` / ``POLYGUARD_BANDIT_EPSILON`` /
      ``POLYGUARD_BANDIT_SEED`` -- bandit hyper-parameters (defaults 0.55,
      0.1 and 42).
    * ``POLYGUARD_POLICY_STACK`` -- ``llm+bandit``, ``bandit-only`` or
      ``llm-only`` (default ``llm+bandit``; unrecognised values fall back
      to ``llm+bandit``).
    * ``POLYGUARD_BANDIT_TOP_K`` -- how many bandit proposals to request
      (default 3).
    """

    _BANDIT_ALGORITHMS = frozenset({"linucb", "thompson"})
    _POLICY_STACKS = frozenset({"llm+bandit", "bandit-only", "llm-only"})
    _DEFAULT_POLICY_STACK = "llm+bandit"

    def __init__(
        self,
        env: PolyGuardEnv,
        coordination_mode: CoordinationMode = CoordinationMode.SEQUENTIAL,
    ) -> None:
        """Build every agent and the contextual bandit for ``env``.

        Args:
            env: Environment whose ``state`` the agents read and whose
                ``step`` executes the final action.
            coordination_mode: Initial coordination mode; it can be changed
                later via ``set_mode`` or ``run_step(coordination_mode=...)``.

        Raises:
            ValueError: If ``POLYGUARD_BANDIT_ALPHA``, ``POLYGUARD_BANDIT_EPSILON``,
                ``POLYGUARD_BANDIT_SEED`` or ``POLYGUARD_BANDIT_TOP_K`` cannot be
                parsed as a number.
        """
        self.env = env
        self.coordination_mode = coordination_mode
        self.medrec = MedRecAgent()
        self.evidence = EvidenceAgent()
        self.graph = GraphSafetyAgent()
        self.dosing = DosingAgent()
        self.candidate = CandidateAgent()
        self.supervisor = SupervisorAgent()
        self.planner = PlannerAgent()
        self.critic = CriticAgent()
        self.explainer = ExplainerAgent()

        bandit_algo = os.getenv("POLYGUARD_BANDIT_ALGO", "linucb").strip().lower()
        if bandit_algo not in self._BANDIT_ALGORITHMS:
            bandit_algo = "linucb"
        self.bandit = ContextualBanditPolicy(
            algorithm=bandit_algo,  # type: ignore[arg-type]
            alpha=float(os.getenv("POLYGUARD_BANDIT_ALPHA", "0.55")),
            epsilon=float(os.getenv("POLYGUARD_BANDIT_EPSILON", "0.1")),
            seed=int(os.getenv("POLYGUARD_BANDIT_SEED", "42")),
        )

        policy_stack = os.getenv("POLYGUARD_POLICY_STACK", self._DEFAULT_POLICY_STACK).strip().lower()
        if policy_stack not in self._POLICY_STACKS:
            # Any unrecognised value ends up running the llm+bandit stack, so
            # record that name instead of the raw value; otherwise the report
            # and the planner's provider_prompt would describe a stack that
            # never ran (mirrors the bandit-algorithm fallback above).
            policy_stack = self._DEFAULT_POLICY_STACK
        self.policy_stack = policy_stack
        self.bandit_top_k = int(os.getenv("POLYGUARD_BANDIT_TOP_K", "3"))

    def set_mode(self, coordination_mode: CoordinationMode) -> None:
        """Switch the coordination mode used by subsequent ``run_step`` calls."""
        self.coordination_mode = coordination_mode

    def run_step(self, coordination_mode: str | None = None) -> dict[str, Any]:
        """Run the whole agent pipeline once and advance the environment one step.

        Args:
            coordination_mode: Optional ``CoordinationMode`` value. When given it
                replaces ``self.coordination_mode`` before the step runs and
                stays in effect for later steps.

        Returns:
            A dict with each agent's output, the proposed and final actions
            (as ``model_dump(mode="json")``), the environment's observation,
            reward, done flag and info, the explanation, the active coordination
            mode and policy stack, the bandit's ranking/scores, and the
            replan/debate bookkeeping.

        Raises:
            ValueError: If ``coordination_mode`` is not a valid value (assuming
                ``CoordinationMode`` is a standard ``Enum``).
            IndexError: Under the ``bandit-only`` stack when no candidates exist.
        """
        if coordination_mode is not None:
            self.coordination_mode = CoordinationMode(coordination_mode)

        state = self.env.state
        medrec_out = self.medrec.run(state)
        evidence_out = self.evidence.run(state)
        graph_out = self.graph.run(state)
        dosing_out = self.dosing.run(state)
        candidate_out = self.candidate.run(state)
        candidates = [CandidateAction.model_validate(item) for item in candidate_out["candidates"]]
        supervisor_out = self.supervisor.run(state, dosing_active=dosing_out["dosing_active"])
        supervisor_mode = supervisor_out["mode"]

        # Plan only among candidates in the supervisor's mode (or all of them
        # if none match).
        # NOTE(review): SUPERVISOR_ROUTED previously re-applied a REVIEW-only
        # filter here when the supervisor mode was "REVIEW". That is exactly
        # the filter below, so it never changed the result and was removed;
        # confirm whether routed mode was meant to behave differently.
        planner_candidates = self._candidates_in_mode(candidates, supervisor_mode)
        candidate_by_id = {item.candidate_id: item for item in planner_candidates}

        # Keep only bandit proposals that map back to a planner candidate; if
        # none do, hand the whole planner pool on instead.
        bandit_proposals = self.bandit.propose(planner_candidates, top_k=self.bandit_top_k)
        bandit_candidates = [
            candidate_by_id[item.candidate_id]
            for item in bandit_proposals
            if item.candidate_id in candidate_by_id
        ] or planner_candidates

        proposed = self._propose_action(state, supervisor_mode, planner_candidates, bandit_candidates)
        critic_out = self.critic.run(state, proposed)
        final_action: PolyGuardAction = critic_out["final_action"]

        replan_triggered = False
        debate_rounds = 0
        if (
            self.coordination_mode in {CoordinationMode.REPLAN_ON_VETO, CoordinationMode.LIGHT_DEBATE}
            and not critic_out["approved"]
        ):
            # The critic vetoed the first proposal: replan once, restricted to
            # REVIEW candidates (or all candidates if there are none).
            replan_triggered = True
            review_candidates = self._candidates_in_mode(candidates, "REVIEW")
            proposed = self.planner.run(candidates=review_candidates, mode="REVIEW")
            critic_out = self.critic.run(state, proposed)
            final_action = critic_out["final_action"]
            debate_rounds = 1
            # NOTE(review): this second-round check is scoped to the replan
            # branch (so rounds count 0 -> 1 -> 2); confirm that matches the
            # intended LIGHT_DEBATE semantics.
            if (
                self.coordination_mode == CoordinationMode.LIGHT_DEBATE
                and critic_out["approved"]
                and proposed.action_type != final_action.action_type
            ):
                debate_rounds = 2

        obs, reward, done, info = self.env.step(final_action)

        # Credit the reward to the candidate that was actually executed. After a
        # veto-triggered replan (or a critic rewrite) that candidate may come
        # from outside the mode-filtered planner pool, so fall back to the full
        # candidate list instead of silently skipping the bandit update.
        executed = candidate_by_id.get(final_action.candidate_id)
        if executed is None:
            executed = next(
                (c for c in candidates if c.candidate_id == final_action.candidate_id),
                None,
            )
        if executed is not None:
            self.bandit.update(executed, reward=reward)

        # ``state`` is the object read before ``env.step``; whether it reflects
        # the post-step state depends on whether the env mutates it in place.
        explanation_out = self.explainer.run(state, final_action, critic_out["report"])
        return {
            "medrec": medrec_out,
            "evidence": evidence_out,
            "graph": graph_out,
            "dosing": dosing_out,
            "supervisor": supervisor_out,
            "proposed_action": proposed.model_dump(mode="json"),
            "critic": critic_out["report"],
            "final_action": final_action.model_dump(mode="json"),
            "observation": obs.model_dump(mode="json"),
            "reward": reward,
            "done": done,
            "info": info,
            "explanation": explanation_out,
            "coordination_mode": self.coordination_mode.value,
            "policy_stack": self.policy_stack,
            "bandit_topk": [item.candidate_id for item in bandit_candidates],
            "bandit_scores": [
                {
                    "candidate_id": item.candidate_id,
                    "score": item.score,
                    "exploration_bonus": item.exploration_bonus,
                    "algorithm": item.algorithm,
                }
                for item in bandit_proposals
            ],
            "replan_triggered": replan_triggered,
            "debate_rounds": debate_rounds,
        }

    @staticmethod
    def _candidates_in_mode(candidates: list[CandidateAction], mode: str) -> list[CandidateAction]:
        """Return candidates whose ``mode.value`` equals ``mode``, or all ``candidates`` if none match."""
        return [c for c in candidates if c.mode.value == mode] or candidates

    def _propose_action(
        self,
        state: Any,
        mode: str,
        planner_candidates: list[CandidateAction],
        bandit_candidates: list[CandidateAction],
    ) -> PolyGuardAction:
        """Produce the first-pass action according to ``self.policy_stack``.

        * ``bandit-only``: take the bandit's top candidate as-is (raises
          ``IndexError`` when ``bandit_candidates`` is empty).
        * ``llm-only``: let the planner choose from the full mode-filtered pool.
        * ``llm+bandit``: let the planner choose among the bandit's candidates,
          passing coordination mode, policy stack, candidate count and
          sub-environment as ``provider_prompt``.
        """
        if self.policy_stack == "bandit-only":
            return self._action_from_candidate(bandit_candidates[0])
        if self.policy_stack == "llm-only":
            return self.planner.run(candidates=planner_candidates, mode=mode)
        return self.planner.run(
            candidates=bandit_candidates,
            mode=mode,
            provider_prompt={
                "coordination_mode": self.coordination_mode.value,
                "policy_stack": self.policy_stack,
                "candidate_count": len(bandit_candidates),
                "sub_environment": state.sub_environment.value,
            },
        )

    @staticmethod
    def _action_from_candidate(selected: CandidateAction) -> PolyGuardAction:
        """Copy ``selected`` into a ``PolyGuardAction`` with confidence ``max(0.45, 1 - uncertainty_score)``."""
        return PolyGuardAction(
            mode=selected.mode,
            action_type=selected.action_type,
            target_drug=selected.target_drug,
            replacement_drug=selected.replacement_drug,
            dose_bucket=selected.dose_bucket,
            taper_days=selected.taper_days,
            monitoring_plan=selected.monitoring_plan,
            candidate_id=selected.candidate_id,
            confidence=max(0.45, 1.0 - selected.uncertainty_score),
            rationale_brief="Bandit-only policy selected top contextual candidate.",
        )