| from app.common.enums import ActionType, DecisionMode, DoseBucket |
| from app.common.types import CandidateAction |
| from app.models.baselines.contextual_bandit import choose_contextual_bandit_topk |
| from app.models.baselines.contextual_bandit_policy import ContextualBanditPolicy |
|
|
|
|
def _candidate(idx: int, delta: float, uncertainty: float, legal: bool = True) -> CandidateAction:
    """Build a ``CandidateAction`` fixture for the bandit tests.

    Only the identifier, safety delta, uncertainty score and legality flag
    vary between fixtures; every other field is pinned to a constant.

    Args:
        idx: Number rendered zero-padded to two digits in ``candidate_id``
            (e.g. ``1`` -> ``"cand_01"``).
        delta: Value stored in ``estimated_safety_delta``.
        uncertainty: Value stored in ``uncertainty_score``.
        legal: Value stored in ``legality_precheck``; defaults to ``True``.

    Returns:
        A ``CandidateAction`` in ``REGIMEN_OPT`` mode with action type
        ``KEEP_REGIMEN``.
    """
    # Keys are listed in the same order the constructor call used, so the
    # keyword arguments reach CandidateAction unchanged.
    fields = {
        "candidate_id": f"cand_{idx:02d}",
        "mode": DecisionMode.REGIMEN_OPT,
        "action_type": ActionType.KEEP_REGIMEN,
        "target_drug": None,
        "replacement_drug": None,
        "dose_bucket": DoseBucket.NA,
        "taper_days": None,
        "monitoring_plan": None,
        "estimated_safety_delta": delta,
        "burden_delta": 0.0,
        "disease_stability_estimate": 0.8,
        "uncertainty_score": uncertainty,
        "rationale_tags": ["test"],
        "required_monitoring": [],
        "legality_precheck": legal,
    }
    return CandidateAction(**fields)
|
|
|
|
def test_bandit_topk_returns_ranked_candidates() -> None:
    """Check that LinUCB top-k selection returns two known candidates.

    Three fixtures are offered with ``top_k=2``. The test asserts only the
    size of the result and that every returned id belongs to the offered
    pool; it does not verify the relative order of the results.
    """
    # (idx, estimated_safety_delta, uncertainty_score) per fixture.
    specs = (
        (1, 0.10, 0.50),
        (2, 0.25, 0.20),
        (3, 0.05, 0.10),
    )
    pool = [_candidate(idx, delta, uncertainty) for idx, delta, uncertainty in specs]

    selected = choose_contextual_bandit_topk(pool, top_k=2, algorithm="linucb")

    assert len(selected) == 2
    offered_ids = {"cand_01", "cand_02", "cand_03"}
    returned_ids = {entry.candidate_id for entry in selected}
    assert returned_ids <= offered_ids
|
|
|
|
def test_bandit_policy_update_runs() -> None:
    """Smoke-test a propose -> update -> propose cycle on the bandit policy.

    A LinUCB policy (``epsilon=0.0``, ``seed=4``) proposes one candidate out
    of two, is updated with a reward of ``0.8`` for that candidate, and is
    then asked to propose again. Both proposals only need to be non-empty.
    """
    pool = [_candidate(1, 0.1, 0.4), _candidate(2, 0.2, 0.3)]
    bandit = ContextualBanditPolicy(algorithm="linucb", epsilon=0.0, seed=4)

    first_round = bandit.propose(pool, top_k=1)
    assert first_round

    # Feed back the fixture object from our own pool (matched by id), not
    # the object returned by the policy.
    picked_id = first_round[0].candidate_id
    matched = next(entry for entry in pool if entry.candidate_id == picked_id)
    bandit.update(matched, reward=0.8)

    second_round = bandit.propose(pool, top_k=1)
    assert second_round
|
|