"""Tests for the contextual-bandit baseline top-k helper and its stateful policy."""

from app.common.enums import ActionType, DecisionMode, DoseBucket
from app.common.types import CandidateAction
from app.models.baselines.contextual_bandit import choose_contextual_bandit_topk
from app.models.baselines.contextual_bandit_policy import ContextualBanditPolicy


def _candidate(idx: int, delta: float, uncertainty: float, legal: bool = True) -> CandidateAction:
    """Build a minimal KEEP_REGIMEN candidate for bandit tests.

    Only the fields the tests vary are parameters; everything else is a fixed
    placeholder.

    Args:
        idx: Numeric suffix for the candidate id, rendered as ``cand_NN``.
        delta: Value stored in ``estimated_safety_delta``.
        uncertainty: Value stored in ``uncertainty_score``.
        legal: Value stored in ``legality_precheck``.

    Returns:
        A ``CandidateAction`` in ``DecisionMode.REGIMEN_OPT`` with no drug change.
    """
    return CandidateAction(
        candidate_id=f"cand_{idx:02d}",
        mode=DecisionMode.REGIMEN_OPT,
        action_type=ActionType.KEEP_REGIMEN,
        target_drug=None,
        replacement_drug=None,
        dose_bucket=DoseBucket.NA,
        taper_days=None,
        monitoring_plan=None,
        estimated_safety_delta=delta,
        burden_delta=0.0,
        disease_stability_estimate=0.8,
        uncertainty_score=uncertainty,
        rationale_tags=["test"],
        required_monitoring=[],
        legality_precheck=legal,
    )


def test_bandit_topk_returns_ranked_candidates() -> None:
    """Top-k selection returns exactly ``top_k`` distinct candidates from the input pool."""
    items = [
        _candidate(1, 0.10, 0.50),
        _candidate(2, 0.25, 0.20),
        _candidate(3, 0.05, 0.10),
    ]

    topk = choose_contextual_bandit_topk(items, top_k=2, algorithm="linucb")

    assert len(topk) == 2
    returned_ids = [item.candidate_id for item in topk]
    # A top-k selection must not return the same candidate twice.
    assert len(set(returned_ids)) == len(returned_ids)
    assert set(returned_ids).issubset({item.candidate_id for item in items})


def test_bandit_policy_update_runs() -> None:
    """A policy can propose a candidate, accept a reward for it, and propose again."""
    items = [_candidate(1, 0.1, 0.4), _candidate(2, 0.2, 0.3)]
    # epsilon=0.0 disables random exploration; seed fixes any remaining randomness.
    policy = ContextualBanditPolicy(algorithm="linucb", epsilon=0.0, seed=4)

    proposal = policy.propose(items, top_k=1)
    assert proposal

    # Resolve the proposed id back to one of the input objects so update()
    # receives the original candidate; fail clearly if the id is unknown
    # instead of leaking a bare StopIteration.
    proposed_id = proposal[0].candidate_id
    chosen = next((item for item in items if item.candidate_id == proposed_id), None)
    assert chosen is not None, f"proposed candidate {proposed_id!r} is not one of the inputs"

    policy.update(chosen, reward=0.8)

    proposal2 = policy.propose(items, top_k=1)
    assert proposal2