# polyguard-openenv/app/models/baselines/contextual_bandit.py
# TheJackBright's picture
# Deploy PolyGuard OpenEnv Space
# 877add7 verified
"""Contextual bandit baseline and top-k proposer."""
from __future__ import annotations
import random
from app.common.types import CandidateAction, PolyGuardAction
from app.models.baselines.contextual_bandit_policy import BanditProposal, ContextualBanditPolicy
from app.models.baselines.rules_only import choose_rules_only
def choose_contextual_bandit(candidates: list[CandidateAction], epsilon: float = 0.2) -> PolyGuardAction:
    """Pick a single action using the contextual bandit proposer.

    Requests the top-1 proposal from ``choose_contextual_bandit_topk`` and
    converts the matching candidate into a ``PolyGuardAction``.

    Args:
        candidates: Candidate actions to choose among.
        epsilon: Forwarded unchanged to ``choose_contextual_bandit_topk``.

    Returns:
        A ``PolyGuardAction`` copied from the selected candidate, with a fixed
        confidence of 0.68. Falls back to ``choose_rules_only(candidates)``
        when the proposer returns nothing or its proposal's ``candidate_id``
        does not match any candidate.
    """
    ranked = choose_contextual_bandit_topk(candidates=candidates, top_k=1, epsilon=epsilon)

    winner = None
    if ranked:
        # Index candidates by id; on duplicate ids the later candidate wins.
        by_id = {}
        for candidate in candidates:
            by_id[candidate.candidate_id] = candidate
        winner = by_id.get(ranked[0].candidate_id)

    if winner is None:
        # No proposal, or the proposal refers to an unknown candidate.
        return choose_rules_only(candidates)

    return PolyGuardAction(
        mode=winner.mode,
        action_type=winner.action_type,
        target_drug=winner.target_drug,
        replacement_drug=winner.replacement_drug,
        dose_bucket=winner.dose_bucket,
        taper_days=winner.taper_days,
        monitoring_plan=winner.monitoring_plan,
        candidate_id=winner.candidate_id,
        confidence=0.68,
        rationale_brief="Contextual bandit selected candidate.",
    )
def choose_contextual_bandit_topk(
    candidates: list[CandidateAction],
    top_k: int = 3,
    epsilon: float = 0.2,
    algorithm: str = "linucb",
) -> list[BanditProposal]:
    """Propose up to ``top_k`` candidates with a freshly built bandit policy.

    Args:
        candidates: Candidate actions to rank. An empty list short-circuits
            to an empty result without constructing a policy.
        top_k: Forwarded unchanged to ``ContextualBanditPolicy.propose``.
        epsilon: Exploration parameter; clamped into ``[0.0, 1.0]`` before
            being handed to the policy.
        algorithm: ``"linucb"`` or ``"thompson"``. Any other value is
            silently replaced by ``"linucb"``.

    Returns:
        Whatever ``ContextualBanditPolicy.propose`` returns for the given
        candidates, or ``[]`` when there are no candidates.
    """
    if not candidates:
        return []

    supported = {"linucb", "thompson"}
    selected_algorithm = algorithm if algorithm in supported else "linucb"

    # Clamp in two steps: cap at 1.0 first, then floor at 0.0.
    capped_epsilon = min(1.0, epsilon)
    bounded_epsilon = max(0.0, capped_epsilon)

    # A new policy is created on every call, seeded from the module-level RNG.
    policy_seed = random.randint(1, 10_000)
    bandit = ContextualBanditPolicy(
        algorithm=selected_algorithm,  # type: ignore[arg-type]
        epsilon=bounded_epsilon,
        seed=policy_seed,
    )
    return bandit.propose(candidates=candidates, top_k=top_k)