Spaces:
Running
Running
File size: 1,751 Bytes
877add7 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 | from app.common.enums import ActionType, DecisionMode, DoseBucket
from app.common.types import CandidateAction
from app.models.baselines.contextual_bandit import choose_contextual_bandit_topk
from app.models.baselines.contextual_bandit_policy import ContextualBanditPolicy
def _candidate(idx: int, delta: float, uncertainty: float, legal: bool = True) -> CandidateAction:
    """Build a ``CandidateAction`` fixture for the bandit tests.

    Only the id, safety delta, uncertainty score and legality flag vary
    between fixtures; every other field is pinned to a constant.

    Args:
        idx: Number embedded, zero-padded to two digits, in ``candidate_id``.
        delta: Value stored as ``estimated_safety_delta``.
        uncertainty: Value stored as ``uncertainty_score``.
        legal: Value stored as ``legality_precheck``.

    Returns:
        The constructed ``CandidateAction``.
    """
    fields = {
        "candidate_id": f"cand_{idx:02d}",
        "mode": DecisionMode.REGIMEN_OPT,
        "action_type": ActionType.KEEP_REGIMEN,
        "target_drug": None,
        "replacement_drug": None,
        "dose_bucket": DoseBucket.NA,
        "taper_days": None,
        "monitoring_plan": None,
        "estimated_safety_delta": delta,
        "burden_delta": 0.0,
        "disease_stability_estimate": 0.8,
        "uncertainty_score": uncertainty,
        "rationale_tags": ["test"],
        "required_monitoring": [],
        "legality_precheck": legal,
    }
    return CandidateAction(**fields)
def test_bandit_topk_returns_ranked_candidates() -> None:
    """Request the top 2 of three candidates and check the result.

    Verifies that exactly two items come back and that each returned id
    belongs to the pool that was passed in.
    """
    # (idx, estimated safety delta, uncertainty score) per fixture.
    specs = ((1, 0.10, 0.50), (2, 0.25, 0.20), (3, 0.05, 0.10))
    pool = [_candidate(number, delta, spread) for number, delta, spread in specs]

    selected = choose_contextual_bandit_topk(pool, top_k=2, algorithm="linucb")

    assert len(selected) == 2
    known_ids = {"cand_01", "cand_02", "cand_03"}
    returned_ids = {entry.candidate_id for entry in selected}
    assert returned_ids <= known_ids
def test_bandit_policy_update_runs() -> None:
    """Smoke-test the propose -> update -> propose cycle of the policy.

    Only checks that both proposals are non-empty and that ``update``
    accepts the candidate matching the first proposal without raising.
    """
    pool = [_candidate(1, 0.1, 0.4), _candidate(2, 0.2, 0.3)]
    bandit = ContextualBanditPolicy(algorithm="linucb", epsilon=0.0, seed=4)

    first_pick = bandit.propose(pool, top_k=1)
    assert first_pick

    # Feed the reward back using the pool's own object for the proposed id.
    picked_id = first_pick[0].candidate_id
    matching = (entry for entry in pool if entry.candidate_id == picked_id)
    bandit.update(next(matching), reward=0.8)

    second_pick = bandit.propose(pool, top_k=1)
    assert second_pick
|