| """Planner agent.""" |
|
|
| from __future__ import annotations |
|
|
| from typing import Any |
|
|
| from app.common.types import CandidateAction, PolyGuardAction |
| from app.models.policy.provider_runtime import PolicyProviderRouter, default_provider_preference |
| from app.models.policy.safety_ranker import rank_candidates |
|
|
|
|
class PlannerAgent:
    """Pick one candidate action and turn it into a ``PolyGuardAction``.

    Selection is delegated to a ``PolicyProviderRouter``. If the router's
    chosen ``candidate_id`` is not in the considered pool, the first result
    of ``rank_candidates`` over that pool is used instead.
    """

    name = "PlannerAgent"
    # Lower bound on the ``confidence`` reported for the chosen action
    # (formerly a literal inside ``run``); override per subclass/instance.
    confidence_floor = 0.45

    def __init__(self) -> None:
        self.provider_router = PolicyProviderRouter()

    def run(
        self,
        candidates: list[CandidateAction],
        mode: str,
        provider_prompt: dict[str, Any] | None = None,
        provider_preference: tuple[str, ...] | None = None,
    ) -> PolyGuardAction:
        """Select one candidate and convert it into a ``PolyGuardAction``.

        Args:
            candidates: Candidate actions to choose from.
            mode: Mode name compared against ``candidate.mode.value`` to
                narrow the pool; if no candidate matches, the full
                ``candidates`` list is used instead.
            provider_prompt: Prompt forwarded to the provider router. A falsy
                value (``None`` or an empty dict) is replaced by
                ``{"mode": mode}``.
            provider_preference: Provider ordering forwarded to the router. A
                falsy value is replaced by ``default_provider_preference()``.

        Returns:
            A ``PolyGuardAction`` copying the chosen candidate's mode, action,
            drug, dose, taper and monitoring fields, with
            ``confidence = max(confidence_floor, 1.0 - uncertainty_score)``
            and the router's ``rationale`` as ``rationale_brief``.

        Raises:
            IndexError: If the router's pick is not in the pool and
                ``rank_candidates`` returns an empty list (e.g. when
                ``candidates`` is empty and the router itself does not
                raise first).
        """
        # Prefer candidates for the requested mode, but never hand the router
        # an empty pool just because no candidate matched the mode.
        filtered = [c for c in candidates if c.mode.value == mode] or candidates
        selection = self.provider_router.select_candidate(
            candidates=filtered,
            prompt=provider_prompt or {"mode": mode},
            provider_preference=provider_preference or default_provider_preference(),
        )
        by_id = {item.candidate_id: item for item in filtered}
        top = by_id.get(selection.candidate_id)
        if top is None:
            # Rank only when the router's pick is unusable. Passing
            # ``rank_candidates(...)[0]`` as the ``dict.get`` default would
            # evaluate it on every call and could raise even after a valid
            # selection.
            top = rank_candidates(filtered)[0]
        return PolyGuardAction(
            mode=top.mode,
            action_type=top.action_type,
            target_drug=top.target_drug,
            replacement_drug=top.replacement_drug,
            dose_bucket=top.dose_bucket,
            taper_days=top.taper_days,
            monitoring_plan=top.monitoring_plan,
            candidate_id=top.candidate_id,
            confidence=max(self.confidence_floor, 1.0 - top.uncertainty_score),
            rationale_brief=selection.rationale,
        )
|
|