"""Contextual bandit baseline and top-k proposer.""" from __future__ import annotations import random from app.common.types import CandidateAction, PolyGuardAction from app.models.baselines.contextual_bandit_policy import BanditProposal, ContextualBanditPolicy from app.models.baselines.rules_only import choose_rules_only def choose_contextual_bandit(candidates: list[CandidateAction], epsilon: float = 0.2) -> PolyGuardAction: proposals = choose_contextual_bandit_topk(candidates=candidates, top_k=1, epsilon=epsilon) if not proposals: return choose_rules_only(candidates) candidate_map = {item.candidate_id: item for item in candidates} top = candidate_map.get(proposals[0].candidate_id) if top is None: return choose_rules_only(candidates) return PolyGuardAction( mode=top.mode, action_type=top.action_type, target_drug=top.target_drug, replacement_drug=top.replacement_drug, dose_bucket=top.dose_bucket, taper_days=top.taper_days, monitoring_plan=top.monitoring_plan, candidate_id=top.candidate_id, confidence=0.68, rationale_brief="Contextual bandit selected candidate.", ) def choose_contextual_bandit_topk( candidates: list[CandidateAction], top_k: int = 3, epsilon: float = 0.2, algorithm: str = "linucb", ) -> list[BanditProposal]: if not candidates: return [] if algorithm not in {"linucb", "thompson"}: algorithm = "linucb" policy = ContextualBanditPolicy( algorithm=algorithm, # type: ignore[arg-type] epsilon=max(0.0, min(1.0, epsilon)), seed=random.randint(1, 10_000), ) return policy.propose(candidates=candidates, top_k=top_k)