Spaces:
Running
Running
| from app.common.enums import ActionType, DecisionMode, DoseBucket | |
| from app.common.types import CandidateAction | |
| from app.models.baselines.contextual_bandit import choose_contextual_bandit_topk | |
| from app.models.baselines.contextual_bandit_policy import ContextualBanditPolicy | |
def _candidate(idx: int, delta: float, uncertainty: float, legal: bool = True) -> CandidateAction:
    """Build a KEEP_REGIMEN ``CandidateAction`` fixture for the bandit tests.

    Only the identifier, safety delta, uncertainty score and legality flag vary
    between fixtures. Every other field is pinned to the same constant value.

    Args:
        idx: Numeric suffix for ``candidate_id``, zero-padded to two digits.
        delta: Value stored in ``estimated_safety_delta``.
        uncertainty: Value stored in ``uncertainty_score``.
        legal: Value stored in ``legality_precheck``; defaults to ``True``.

    Returns:
        A ``CandidateAction`` in ``REGIMEN_OPT`` mode with no drug, dose or
        taper changes.
    """
    # Constant fields shared by every fixture. They are rebuilt on each call, so
    # the list values are never shared between returned objects.
    fixed_fields = dict(
        mode=DecisionMode.REGIMEN_OPT,
        action_type=ActionType.KEEP_REGIMEN,
        target_drug=None,
        replacement_drug=None,
        dose_bucket=DoseBucket.NA,
        taper_days=None,
        monitoring_plan=None,
        burden_delta=0.0,
        disease_stability_estimate=0.8,
        rationale_tags=["test"],
        required_monitoring=[],
    )
    return CandidateAction(
        candidate_id=f"cand_{idx:02d}",
        estimated_safety_delta=delta,
        uncertainty_score=uncertainty,
        legality_precheck=legal,
        **fixed_fields,
    )
def test_bandit_topk_returns_ranked_candidates() -> None:
    """Top-k selection returns ``top_k`` distinct candidates drawn from the input pool."""
    items = [
        _candidate(1, 0.10, 0.50),
        _candidate(2, 0.25, 0.20),
        _candidate(3, 0.05, 0.10),
    ]
    # Derive the allowed IDs from the fixtures so they cannot drift apart.
    pool_ids = {item.candidate_id for item in items}

    topk = choose_contextual_bandit_topk(items, top_k=2, algorithm="linucb")

    assert len(topk) == 2
    chosen_ids = [item.candidate_id for item in topk]
    # A ranked top-k list is expected to name each candidate at most once.
    # The subset check alone would accept the same candidate picked twice.
    assert len(set(chosen_ids)) == len(chosen_ids), f"duplicate picks: {chosen_ids}"
    assert set(chosen_ids) <= pool_ids
def test_bandit_policy_update_runs() -> None:
    """A reward update on a proposed candidate succeeds and the policy can still propose."""
    items = [_candidate(1, 0.1, 0.4), _candidate(2, 0.2, 0.3)]
    # epsilon=0.0 presumably disables random exploration, and the fixed seed
    # presumably keeps the run reproducible. TODO confirm against
    # ContextualBanditPolicy.
    policy = ContextualBanditPolicy(algorithm="linucb", epsilon=0.0, seed=4)

    proposal = policy.propose(items, top_k=1)
    assert len(proposal) == 1

    # Map the proposal back to its fixture. A bare next() would report an
    # unknown ID as an opaque StopIteration rather than a readable failure.
    chosen = next(
        (item for item in items if item.candidate_id == proposal[0].candidate_id),
        None,
    )
    assert chosen is not None, f"policy proposed unknown candidate {proposal[0].candidate_id!r}"
    policy.update(chosen, reward=0.8)

    # After the update the policy must still honour top_k.
    proposal2 = policy.propose(items, top_k=1)
    assert len(proposal2) == 1