# polyguard-openenv/app/agents/orchestrator.py
# Page-header residue from the hosting site (not part of the module):
#   TheJackBright's picture | Deploy PolyGuard OpenEnv Space | 877add7 verified
"""Multi-agent orchestration graph."""
from __future__ import annotations
import os
from typing import Any
from app.agents.candidate_agent import CandidateAgent
from app.agents.critic_agent import CriticAgent
from app.agents.dosing_agent import DosingAgent
from app.agents.evidence_agent import EvidenceAgent
from app.agents.explainer_agent import ExplainerAgent
from app.agents.graph_agent import GraphSafetyAgent
from app.agents.medrec_agent import MedRecAgent
from app.agents.planner_agent import PlannerAgent
from app.agents.supervisor_agent import SupervisorAgent
from app.common.enums import CoordinationMode
from app.common.types import CandidateAction, PolyGuardAction
from app.env.env_core import PolyGuardEnv
from app.models.baselines.contextual_bandit_policy import ContextualBanditPolicy
class Orchestrator:
    """Drive one environment step through the multi-agent pipeline.

    Each call to :meth:`run_step` runs the specialist agents on the current
    environment state, narrows the candidate actions (supervisor mode, then a
    contextual-bandit shortlist), asks the planner (or the bandit alone) for a
    proposal, lets the critic approve or replace it, executes the critic's
    final action in the environment and feeds the reward back to the bandit.

    Configuration is read from environment variables at construction time:

    * ``POLYGUARD_BANDIT_ALGO`` -- ``"linucb"`` (default) or ``"thompson"``;
      any other value falls back to ``"linucb"``.
    * ``POLYGUARD_BANDIT_ALPHA`` -- float, default ``0.55``.
    * ``POLYGUARD_BANDIT_EPSILON`` -- float, default ``0.1``.
    * ``POLYGUARD_BANDIT_SEED`` -- int, default ``42``.
    * ``POLYGUARD_POLICY_STACK`` -- ``"llm+bandit"`` (default),
      ``"bandit-only"`` or ``"llm-only"``; any unrecognised value follows the
      combined (planner over bandit shortlist) path.
    * ``POLYGUARD_BANDIT_TOP_K`` -- int, default ``3``; size of the bandit
      shortlist requested from the policy.
    """

    def __init__(self, env: PolyGuardEnv, coordination_mode: CoordinationMode = CoordinationMode.SEQUENTIAL) -> None:
        """Create the agents and the bandit policy.

        Args:
            env: Environment whose ``state`` is read and whose ``step`` is
                called by :meth:`run_step`.
            coordination_mode: Initial coordination strategy.

        Raises:
            ValueError: If a numeric environment variable cannot be parsed by
                ``float()`` / ``int()``.
        """
        self.env = env
        self.coordination_mode = coordination_mode

        self.medrec = MedRecAgent()
        self.evidence = EvidenceAgent()
        self.graph = GraphSafetyAgent()
        self.dosing = DosingAgent()
        self.candidate = CandidateAgent()
        self.supervisor = SupervisorAgent()
        self.planner = PlannerAgent()
        self.critic = CriticAgent()
        self.explainer = ExplainerAgent()

        # Unknown algorithm names silently fall back to LinUCB.
        bandit_algo = os.getenv("POLYGUARD_BANDIT_ALGO", "linucb").strip().lower()
        if bandit_algo not in {"linucb", "thompson"}:
            bandit_algo = "linucb"
        self.bandit = ContextualBanditPolicy(
            algorithm=bandit_algo,  # type: ignore[arg-type]
            alpha=float(os.getenv("POLYGUARD_BANDIT_ALPHA", "0.55")),
            epsilon=float(os.getenv("POLYGUARD_BANDIT_EPSILON", "0.1")),
            seed=int(os.getenv("POLYGUARD_BANDIT_SEED", "42")),
        )
        self.policy_stack = os.getenv("POLYGUARD_POLICY_STACK", "llm+bandit").strip().lower()
        self.bandit_top_k = int(os.getenv("POLYGUARD_BANDIT_TOP_K", "3"))

    def set_mode(self, coordination_mode: CoordinationMode) -> None:
        """Replace the coordination mode used by subsequent steps."""
        self.coordination_mode = coordination_mode

    @staticmethod
    def _action_from_candidate(selected: CandidateAction) -> PolyGuardAction:
        """Build the bandit-only action by copying the candidate's fields."""
        return PolyGuardAction(
            mode=selected.mode,
            action_type=selected.action_type,
            target_drug=selected.target_drug,
            replacement_drug=selected.replacement_drug,
            dose_bucket=selected.dose_bucket,
            taper_days=selected.taper_days,
            monitoring_plan=selected.monitoring_plan,
            candidate_id=selected.candidate_id,
            # Confidence is floored at 0.45 regardless of the uncertainty score.
            confidence=max(0.45, 1.0 - selected.uncertainty_score),
            rationale_brief="Bandit-only policy selected top contextual candidate.",
        )

    def _propose_action(
        self,
        state: Any,
        planner_candidates: list[CandidateAction],
        bandit_candidates: list[CandidateAction],
        supervisor_mode: Any,
    ) -> PolyGuardAction:
        """Produce the initial proposal according to ``self.policy_stack``."""
        if self.policy_stack == "bandit-only":
            # The first shortlist entry is taken as the bandit's choice; an
            # empty shortlist raises IndexError here.
            return self._action_from_candidate(bandit_candidates[0])
        if self.policy_stack == "llm-only":
            # The planner sees the full mode-filtered pool, not the shortlist.
            return self.planner.run(candidates=planner_candidates, mode=supervisor_mode)
        # Default / combined stack: the planner chooses among the shortlist.
        return self.planner.run(
            candidates=bandit_candidates,
            mode=supervisor_mode,
            provider_prompt={
                "coordination_mode": self.coordination_mode.value,
                "policy_stack": self.policy_stack,
                "candidate_count": len(bandit_candidates),
                "sub_environment": state.sub_environment.value,
            },
        )

    def run_step(self, coordination_mode: str | None = None) -> dict[str, Any]:
        """Run the full agent pipeline once and step the environment.

        Args:
            coordination_mode: Optional mode value. When given it is converted
                with ``CoordinationMode(...)`` and stored on the instance, so
                it also applies to later calls.

        Returns:
            A dict with the per-agent outputs (``medrec``, ``evidence``,
            ``graph``, ``dosing``, ``supervisor``, ``critic``,
            ``explanation``), the JSON-dumped ``proposed_action``,
            ``final_action`` and ``observation``, the environment's
            ``reward``, ``done`` and ``info``, the active
            ``coordination_mode`` and ``policy_stack``, the bandit shortlist
            (``bandit_topk``, ``bandit_scores``) and the ``replan_triggered``
            / ``debate_rounds`` bookkeeping.

        Raises:
            IndexError: In the ``"bandit-only"`` stack when there are no
                candidates to choose from.
        """
        if coordination_mode is not None:
            self.coordination_mode = CoordinationMode(coordination_mode)

        # Snapshot reference taken before env.step(); all agents below,
        # including the explainer, receive this same object.
        state = self.env.state

        medrec_out = self.medrec.run(state)
        evidence_out = self.evidence.run(state)
        graph_out = self.graph.run(state)
        dosing_out = self.dosing.run(state)
        candidate_out = self.candidate.run(state)
        candidates = [CandidateAction.model_validate(item) for item in candidate_out["candidates"]]
        supervisor_out = self.supervisor.run(state, dosing_active=dosing_out["dosing_active"])
        supervisor_mode = supervisor_out["mode"]

        # Prefer candidates matching the supervisor's mode; if none match,
        # keep the whole list so the planner always has something to pick.
        planner_candidates = [c for c in candidates if c.mode.value == supervisor_mode] or candidates
        # NOTE(review): with the filter above, this branch currently yields
        # the same list it would already have; kept to preserve the original
        # control flow -- confirm the intended SUPERVISOR_ROUTED behaviour.
        if self.coordination_mode == CoordinationMode.SUPERVISOR_ROUTED and supervisor_mode == "REVIEW":
            planner_candidates = [c for c in candidates if c.mode.value == "REVIEW"] or planner_candidates

        candidate_by_id = {item.candidate_id: item for item in planner_candidates}
        bandit_proposals = self.bandit.propose(planner_candidates, top_k=self.bandit_top_k)
        bandit_candidates = [
            candidate_by_id[item.candidate_id]
            for item in bandit_proposals
            if item.candidate_id in candidate_by_id
        ]
        if not bandit_candidates:
            # No usable proposal: fall back to the unfiltered planner pool.
            bandit_candidates = planner_candidates

        proposed = self._propose_action(state, planner_candidates, bandit_candidates, supervisor_mode)

        critic_out = self.critic.run(state, proposed)
        final_action: PolyGuardAction = critic_out["final_action"]

        replan_triggered = False
        debate_rounds = 0
        replan_modes = {CoordinationMode.REPLAN_ON_VETO, CoordinationMode.LIGHT_DEBATE}
        if self.coordination_mode in replan_modes and not critic_out["approved"]:
            # Critic veto: re-plan once, restricted to REVIEW candidates drawn
            # from the full candidate list (not only the planner pool).
            replan_triggered = True
            review_candidates = [c for c in candidates if c.mode.value == "REVIEW"] or candidates
            proposed = self.planner.run(candidates=review_candidates, mode="REVIEW")
            critic_out = self.critic.run(state, proposed)
            final_action = critic_out["final_action"]
            debate_rounds = 1
            if (
                self.coordination_mode == CoordinationMode.LIGHT_DEBATE
                and critic_out["approved"]
                and proposed.action_type != final_action.action_type
            ):
                debate_rounds = 2

        obs, reward, done, info = self.env.step(final_action)

        # Credit the reward to the candidate that was actually executed.
        # The planner pool is consulted first; the full candidate list is the
        # fallback so that a replanned REVIEW action (which may lie outside
        # the planner pool) still produces a bandit update instead of the
        # reward being silently discarded.
        executed = candidate_by_id.get(final_action.candidate_id)
        if executed is None:
            executed = next(
                (c for c in candidates if c.candidate_id == final_action.candidate_id),
                None,
            )
        if executed is not None:
            self.bandit.update(executed, reward=reward)

        explanation_out = self.explainer.run(state, final_action, critic_out["report"])

        return {
            "medrec": medrec_out,
            "evidence": evidence_out,
            "graph": graph_out,
            "dosing": dosing_out,
            "supervisor": supervisor_out,
            "proposed_action": proposed.model_dump(mode="json"),
            "critic": critic_out["report"],
            "final_action": final_action.model_dump(mode="json"),
            "observation": obs.model_dump(mode="json"),
            "reward": reward,
            "done": done,
            "info": info,
            "explanation": explanation_out,
            "coordination_mode": self.coordination_mode.value,
            "policy_stack": self.policy_stack,
            "bandit_topk": [item.candidate_id for item in bandit_candidates],
            "bandit_scores": [
                {
                    "candidate_id": item.candidate_id,
                    "score": item.score,
                    "exploration_bonus": item.exploration_bonus,
                    "algorithm": item.algorithm,
                }
                for item in bandit_proposals
            ],
            "replan_triggered": replan_triggered,
            "debate_rounds": debate_rounds,
        }