File size: 1,751 Bytes
877add7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from app.common.enums import ActionType, DecisionMode, DoseBucket
from app.common.types import CandidateAction
from app.models.baselines.contextual_bandit import choose_contextual_bandit_topk
from app.models.baselines.contextual_bandit_policy import ContextualBanditPolicy


def _candidate(idx: int, delta: float, uncertainty: float, legal: bool = True) -> CandidateAction:
    """Build a minimal KEEP_REGIMEN ``CandidateAction`` fixture for the bandit tests.

    Args:
        idx: Numeric suffix rendered zero-padded to two digits in ``candidate_id``
            (e.g. ``1`` -> ``"cand_01"``).
        delta: Value stored as ``estimated_safety_delta``.
        uncertainty: Value stored as ``uncertainty_score``.
        legal: Value stored as ``legality_precheck``.

    Returns:
        A ``CandidateAction`` in ``REGIMEN_OPT`` mode whose remaining fields are
        fixed, neutral placeholders, so only the arguments vary between fixtures.
    """
    fields = {
        # Identity and per-fixture knobs.
        "candidate_id": f"cand_{idx:02d}",
        "estimated_safety_delta": delta,
        "uncertainty_score": uncertainty,
        "legality_precheck": legal,
        # Action description: a no-op "keep regimen" with no drug/dose/taper details.
        "mode": DecisionMode.REGIMEN_OPT,
        "action_type": ActionType.KEEP_REGIMEN,
        "target_drug": None,
        "replacement_drug": None,
        "dose_bucket": DoseBucket.NA,
        "taper_days": None,
        "monitoring_plan": None,
        # Shared constant scores and metadata.
        "burden_delta": 0.0,
        "disease_stability_estimate": 0.8,
        "rationale_tags": ["test"],
        "required_monitoring": [],
    }
    return CandidateAction(**fields)


def test_bandit_topk_returns_ranked_candidates() -> None:
    """LinUCB top-k selection returns exactly ``top_k`` distinct input candidates.

    Only cardinality, distinctness and membership are checked here. The exact
    ranking order of an untrained LinUCB model is not pinned by this test.
    """
    items = [
        _candidate(1, 0.10, 0.50),
        _candidate(2, 0.25, 0.20),
        _candidate(3, 0.05, 0.10),
    ]
    topk = choose_contextual_bandit_topk(items, top_k=2, algorithm="linucb")
    assert len(topk) == 2
    returned_ids = [item.candidate_id for item in topk]
    # Collapsing to a set alone would hide a duplicated pick (e.g. cand_01 twice),
    # so require the returned ids to be pairwise distinct as well.
    assert len(set(returned_ids)) == len(returned_ids), returned_ids
    assert set(returned_ids) <= {"cand_01", "cand_02", "cand_03"}


def test_bandit_policy_update_runs() -> None:
    """A LinUCB policy can propose, absorb a reward for its pick, and propose again.

    ``epsilon=0.0`` and a fixed ``seed`` are passed so the run is intended to be
    reproducible. ``update`` receives the original input ``CandidateAction`` that
    matches the proposed ``candidate_id``.
    """
    items = [_candidate(1, 0.1, 0.4), _candidate(2, 0.2, 0.3)]
    policy = ContextualBanditPolicy(algorithm="linucb", epsilon=0.0, seed=4)

    proposal = policy.propose(items, top_k=1)
    # top_k=1 over two candidates: expect exactly one proposal, not merely "some".
    assert len(proposal) == 1

    # Map the proposal back to its input object. A bare next() here would surface
    # an opaque StopIteration if the policy invented an id; fail with context instead.
    by_id = {item.candidate_id: item for item in items}
    chosen_id = proposal[0].candidate_id
    assert chosen_id in by_id, f"proposal returned unknown candidate {chosen_id!r}"
    policy.update(by_id[chosen_id], reward=0.8)

    proposal2 = policy.propose(items, top_k=1)
    assert len(proposal2) == 1