Spaces:
Running
Running
| """Contextual bandit baseline and top-k proposer.""" | |
| from __future__ import annotations | |
| import random | |
| from app.common.types import CandidateAction, PolyGuardAction | |
| from app.models.baselines.contextual_bandit_policy import BanditProposal, ContextualBanditPolicy | |
| from app.models.baselines.rules_only import choose_rules_only | |
def choose_contextual_bandit(candidates: list[CandidateAction], epsilon: float = 0.2) -> PolyGuardAction:
    """Select one action from ``candidates`` via the contextual bandit.

    Requests a single proposal from ``choose_contextual_bandit_topk`` and
    resolves it back to the originating candidate by ``candidate_id``.

    Args:
        candidates: Candidate actions to choose from.
        epsilon: Forwarded unchanged to ``choose_contextual_bandit_topk``.

    Returns:
        A ``PolyGuardAction`` copying the chosen candidate's fields, with a
        fixed confidence of 0.68. When no proposal comes back, or the
        proposed id matches no candidate, the result of
        ``choose_rules_only(candidates)`` is returned instead.
    """
    ranked = choose_contextual_bandit_topk(candidates=candidates, top_k=1, epsilon=epsilon)

    chosen = None
    if ranked:
        # On duplicate ids the later candidate wins, as with a dict build.
        by_id = {}
        for candidate in candidates:
            by_id[candidate.candidate_id] = candidate
        chosen = by_id.get(ranked[0].candidate_id)

    # Both "no proposal" and "unknown proposal id" defer to the rules baseline.
    if chosen is None:
        return choose_rules_only(candidates)

    # Fields copied verbatim from the winning candidate, in constructor order.
    copied_fields = (
        "mode",
        "action_type",
        "target_drug",
        "replacement_drug",
        "dose_bucket",
        "taper_days",
        "monitoring_plan",
        "candidate_id",
    )
    payload = {name: getattr(chosen, name) for name in copied_fields}
    return PolyGuardAction(
        **payload,
        confidence=0.68,
        rationale_brief="Contextual bandit selected candidate.",
    )
def choose_contextual_bandit_topk(
    candidates: list[CandidateAction],
    top_k: int = 3,
    epsilon: float = 0.2,
    algorithm: str = "linucb",
) -> list[BanditProposal]:
    """Ask a freshly built ``ContextualBanditPolicy`` for ranked proposals.

    Args:
        candidates: Candidate actions to rank; an empty input short-circuits.
        top_k: Passed straight through to ``policy.propose``.
        epsilon: Clamped into the closed interval [0.0, 1.0] before use.
        algorithm: ``"linucb"`` or ``"thompson"``; any other value is
            silently replaced by ``"linucb"``.

    Returns:
        An empty list when ``candidates`` is empty, otherwise whatever
        ``policy.propose(candidates=..., top_k=...)`` returns.
    """
    if not candidates:
        return []

    # Unknown algorithm names fall back to the default rather than raising.
    selected_algorithm = algorithm if algorithm in {"linucb", "thompson"} else "linucb"

    # Clamp in two steps: cap at 1.0 first, then floor at 0.0.
    capped_epsilon = min(1.0, epsilon)
    bounded_epsilon = max(0.0, capped_epsilon)

    # A new seed is drawn from the global ``random`` state on every call.
    bandit = ContextualBanditPolicy(
        algorithm=selected_algorithm,  # type: ignore[arg-type]
        epsilon=bounded_epsilon,
        seed=random.randint(1, 10_000),
    )
    return bandit.propose(candidates=candidates, top_k=top_k)