Spaces:
Running
Running
| """Planner agent.""" | |
| from __future__ import annotations | |
| from typing import Any | |
| from app.common.types import CandidateAction, PolyGuardAction | |
| from app.models.policy.provider_runtime import PolicyProviderRouter | |
| from app.models.policy.safety_ranker import rank_candidates | |
class PlannerAgent:
    """Pick one candidate action and package it as a ``PolyGuardAction``.

    The choice is delegated to a ``PolicyProviderRouter``; ``rank_candidates``
    is used only as a fallback when the router's answer does not match any of
    the candidates that were offered to it.
    """

    name = "PlannerAgent"

    def __init__(self) -> None:
        """Create the provider router used to select among candidates."""
        self.provider_router = PolicyProviderRouter()

    def run(
        self,
        candidates: list[CandidateAction],
        mode: str,
        provider_prompt: dict[str, Any] | None = None,
        provider_preference: tuple[str, ...] = ("transformers",),
    ) -> PolyGuardAction:
        """Select a candidate for ``mode`` and convert it into an action.

        Args:
            candidates: Pool of candidate actions to choose from.
            mode: Compared against each candidate's ``mode.value``. If no
                candidate matches, the whole pool is used instead.
            provider_prompt: Passed to the router as ``prompt``. When omitted
                or empty, ``{"mode": mode}`` is sent instead.
            provider_preference: Forwarded unchanged to the router.

        Returns:
            A ``PolyGuardAction`` copying the chosen candidate's fields, with
            ``confidence`` set to ``max(0.45, 1.0 - uncertainty_score)`` and
            ``rationale_brief`` taken from the router's selection.

        NOTE(review): with an empty ``candidates`` list the fallback indexes
        element 0 of the ranking, which presumably raises ``IndexError`` --
        confirm how ``select_candidate`` and ``rank_candidates`` treat an
        empty pool.
        """
        # Prefer candidates of the requested mode; an empty match list is
        # falsy, so ``or`` falls back to the full, unfiltered pool.
        filtered = [c for c in candidates if c.mode.value == mode] or candidates

        selection = self.provider_router.select_candidate(
            candidates=filtered,
            prompt=provider_prompt or {"mode": mode},
            provider_preference=provider_preference,
        )

        by_id = {item.candidate_id: item for item in filtered}
        top = by_id.get(selection.candidate_id)
        if top is None:
            # The router returned an id that is not among the offered
            # candidates. Rank only now: passing the ranking as the default
            # argument of ``dict.get`` would evaluate it on every call, even
            # when the router's choice is valid and the result is discarded.
            top = rank_candidates(filtered)[0]

        return PolyGuardAction(
            mode=top.mode,
            action_type=top.action_type,
            target_drug=top.target_drug,
            replacement_drug=top.replacement_drug,
            dose_bucket=top.dose_bucket,
            taper_days=top.taper_days,
            monitoring_plan=top.monitoring_plan,
            candidate_id=top.candidate_id,
            # Confidence is the complement of uncertainty, floored at 0.45.
            confidence=max(0.45, 1.0 - top.uncertainty_score),
            rationale_brief=selection.rationale,
        )