| """Constrained candidate action generation.""" |
|
|
| from __future__ import annotations |
|
|
| from app.common.enums import ActionType, DecisionMode, DoseBucket |
| from app.common.types import CandidateAction, PolyGuardAction, PolyGuardState |
| from app.env.verifier import verify_action_legality |
| from app.knowledge.ddi_knowledge import top_risky_pairs |
| from app.knowledge.hepatic_rules import is_hepatic_unsafe |
| from app.knowledge.renal_rules import is_renal_unsafe |
| from app.knowledge.substitution_rules import get_substitutions |
|
|
|
|
def _base_candidate(
    idx: int,
    action_type: ActionType,
    target_drug: str | None = None,
    replacement_drug: str | None = None,
    mode: DecisionMode = DecisionMode.REGIMEN_OPT,
) -> CandidateAction:
    """Build a seed candidate that carries neutral default scores.

    Callers specialise the returned object with ``model_copy(update=...)``.

    Args:
        idx: Numeric suffix of the candidate id, rendered zero-padded to two
            digits (``1`` -> ``"cand_01"``).
        action_type: Action the candidate proposes.
        target_drug: Drug the action applies to, if any.
        replacement_drug: Substitute drug, if any.
        mode: Decision mode; defaults to ``DecisionMode.REGIMEN_OPT``.

    Returns:
        A ``CandidateAction`` with ``dose_bucket=DoseBucket.NA`` and
        ``legality_precheck=True``. ``TAPER_INITIATE`` actions get
        ``taper_days=14``. ``ORDER_MONITORING_AND_WAIT`` actions get
        ``monitoring_plan="repeat_labs_7d"``. All other actions leave both
        fields as ``None``.
    """
    # Action-specific defaults; every other action type leaves these unset.
    taper_days: int | None = None
    if action_type == ActionType.TAPER_INITIATE:
        taper_days = 14

    monitoring_plan: str | None = None
    if action_type == ActionType.ORDER_MONITORING_AND_WAIT:
        monitoring_plan = "repeat_labs_7d"

    # Neutral scores shared by every seed candidate.
    seed_scores = {
        "estimated_safety_delta": 0.02,
        "burden_delta": 0.0,
        "disease_stability_estimate": 0.85,
        "uncertainty_score": 0.45,
    }

    return CandidateAction(
        candidate_id=f"cand_{idx:02d}",
        mode=mode,
        action_type=action_type,
        target_drug=target_drug,
        replacement_drug=replacement_drug,
        dose_bucket=DoseBucket.NA,
        taper_days=taper_days,
        monitoring_plan=monitoring_plan,
        rationale_tags=["rule_based_seed"],
        required_monitoring=[],
        legality_precheck=True,
        **seed_scores,
    )
|
|
|
|
def _to_action(candidate: CandidateAction) -> PolyGuardAction:
    """Convert a candidate into a ``PolyGuardAction``.

    In this module it is used to feed ``verify_action_legality`` during the
    candidate precheck.

    The action-describing fields are copied verbatim. Confidence is
    ``1.0 - candidate.uncertainty_score``, floored at ``0.45``. The rationale
    is the fixed tag ``"candidate_precheck"``.
    """
    copied_fields = (
        "mode",
        "action_type",
        "target_drug",
        "replacement_drug",
        "dose_bucket",
        "taper_days",
        "monitoring_plan",
        "evidence_query",
        "new_drug_name",
        "candidate_components",
        "candidate_id",
    )
    payload = {name: getattr(candidate, name) for name in copied_fields}

    # Never report confidence below this floor, however uncertain the candidate.
    confidence_floor = 0.45
    payload["confidence"] = max(confidence_floor, 1.0 - candidate.uncertainty_score)
    payload["rationale_brief"] = "candidate_precheck"

    return PolyGuardAction(**payload)
|
|
|
|
# Upper bound on how many candidates are returned per call.
_MAX_CANDIDATES = 10
# Lower bound; any shortfall is padded with KEEP_REGIMEN fillers.
_MIN_CANDIDATES = 3
# Sub-environments whose signature action is moved to the front of the list.
_PRIORITY_BY_SUBENV = {
    "WEB_SEARCH_MISSING_DATA": ActionType.FETCH_EXTERNAL_EVIDENCE,
    "ALTERNATIVE_SUGGESTION": ActionType.RECOMMEND_ALTERNATIVE,
    "NEW_DRUG_DECOMPOSITION": ActionType.DECOMPOSE_NEW_DRUG,
}


def _keep_candidate() -> CandidateAction:
    """Return the always-offered 'leave the regimen unchanged' baseline."""
    return _base_candidate(1, ActionType.KEEP_REGIMEN).model_copy(
        update={"estimated_safety_delta": -0.02, "uncertainty_score": 0.5}
    )


def _regimen_change_candidates(target: str) -> list[CandidateAction]:
    """Return stop, dose-reduce and (if one exists) substitute candidates for ``target``."""
    stop = _base_candidate(2, ActionType.STOP_DRUG, target_drug=target).model_copy(
        update={
            "estimated_safety_delta": 0.26,
            "burden_delta": 0.12,
            # Stopping this drug is scored as less stabilising than stopping others.
            "disease_stability_estimate": 0.68 if target == "warfarin_like" else 0.81,
            "uncertainty_score": 0.42,
            "rationale_tags": ["ddi_reduction", "deprescribing"],
        }
    )
    dose = _base_candidate(3, ActionType.REDUCE_DOSE_BUCKET, target_drug=target).model_copy(
        update={
            "mode": DecisionMode.DOSE_OPT,
            "dose_bucket": DoseBucket.LOW,
            "estimated_safety_delta": 0.16,
            "burden_delta": 0.03,
            "uncertainty_score": 0.33,
            "rationale_tags": ["dose_deintensification"],
        }
    )
    changes = [stop, dose]

    subs = get_substitutions(target)
    if subs:
        # Only the first substitution returned by the rules is offered.
        changes.append(
            _base_candidate(
                4,
                ActionType.SUBSTITUTE_WITHIN_CLASS,
                target_drug=target,
                replacement_drug=subs[0],
            ).model_copy(
                update={
                    "estimated_safety_delta": 0.22,
                    "burden_delta": 0.05,
                    "uncertainty_score": 0.36,
                    "rationale_tags": ["therapeutic_substitution"],
                }
            )
        )
    return changes


def _organ_guardrail_candidate(state: PolyGuardState) -> CandidateAction | None:
    """Return a DOSE_HOLD for the first renally or hepatically unsafe medication.

    Returns ``None`` when no medication trips either rule. Only one hold is
    ever proposed, even if several medications are flagged.
    """
    labs = state.patient.labs
    for med in state.patient.medications:
        if is_renal_unsafe(med.drug, labs.egfr) or is_hepatic_unsafe(med.drug, labs.ast, labs.alt):
            return _base_candidate(
                5, ActionType.DOSE_HOLD, target_drug=med.drug, mode=DecisionMode.DOSE_OPT
            ).model_copy(
                update={
                    "monitoring_plan": "repeat_labs_72h",
                    "estimated_safety_delta": 0.2,
                    "disease_stability_estimate": 0.74,
                    "uncertainty_score": 0.28,
                    "required_monitoring": ["renal_or_hepatic_panel"],
                    "rationale_tags": ["organ_function_guardrail"],
                }
            )
    return None


def _monitoring_and_review_candidates() -> list[CandidateAction]:
    """Return the always-offered monitor-and-wait, pharmacist and specialist candidates."""
    monitoring = _base_candidate(
        8, ActionType.ORDER_MONITORING_AND_WAIT, mode=DecisionMode.DOSE_OPT
    ).model_copy(
        update={
            "monitoring_plan": "vitals_labs_7d",
            "estimated_safety_delta": 0.1,
            "disease_stability_estimate": 0.88,
            "uncertainty_score": 0.26,
            "rationale_tags": ["monitor_before_change"],
            "required_monitoring": ["cbc", "cmp"],
        }
    )
    pharm = _base_candidate(9, ActionType.REQUEST_PHARMACIST_REVIEW, mode=DecisionMode.REVIEW).model_copy(
        update={"estimated_safety_delta": 0.04, "uncertainty_score": 0.18, "rationale_tags": ["abstain_for_review"]}
    )
    spec = _base_candidate(10, ActionType.REQUEST_SPECIALIST_REVIEW, mode=DecisionMode.REVIEW).model_copy(
        update={"estimated_safety_delta": 0.04, "uncertainty_score": 0.2, "rationale_tags": ["abstain_for_review"]}
    )
    return [monitoring, pharm, spec]


def _subenv_candidates(state: PolyGuardState, sub_env: str) -> list[CandidateAction]:
    """Return the extra candidate(s) specific to the active sub-environment.

    ``sub_env`` is ``state.sub_environment.value``. Unrecognised
    sub-environments contribute nothing.
    """
    meds = state.patient.medications

    if sub_env == "BANDIT_MINING":
        if not meds:
            return []
        return [
            _base_candidate(6, ActionType.KEEP_REGIMEN).model_copy(
                update={
                    "estimated_safety_delta": 0.08,
                    "burden_delta": 0.01,
                    "uncertainty_score": 0.31,
                    "rationale_tags": ["contextual_bandit_exploration"],
                }
            )
        ]

    if sub_env == "WEB_SEARCH_MISSING_DATA":
        return [
            _base_candidate(7, ActionType.FETCH_EXTERNAL_EVIDENCE, mode=DecisionMode.REVIEW).model_copy(
                update={
                    "evidence_query": "https://www.nih.gov",
                    "estimated_safety_delta": 0.11,
                    "disease_stability_estimate": 0.84,
                    "uncertainty_score": 0.22,
                    "rationale_tags": ["missing_data_recovery", "external_evidence_fetch"],
                }
            )
        ]

    if sub_env == "ALTERNATIVE_SUGGESTION":
        if not meds:
            return []
        alt_target = meds[0].drug
        alt_replacements = get_substitutions(alt_target)
        if not alt_replacements:
            return []
        return [
            _base_candidate(
                11,
                ActionType.RECOMMEND_ALTERNATIVE,
                target_drug=alt_target,
                replacement_drug=alt_replacements[0],
                mode=DecisionMode.REGIMEN_OPT,
            ).model_copy(
                update={
                    "estimated_safety_delta": 0.24,
                    "burden_delta": 0.04,
                    "uncertainty_score": 0.29,
                    "rationale_tags": ["alternative_suggestion", "safer_addition_or_swap"],
                }
            )
        ]

    if sub_env == "NEW_DRUG_DECOMPOSITION":
        return [
            _base_candidate(12, ActionType.DECOMPOSE_NEW_DRUG, mode=DecisionMode.REVIEW).model_copy(
                update={
                    "new_drug_name": "novel_combination_x",
                    "candidate_components": ["novel_component_a", "novel_component_b"],
                    "estimated_safety_delta": 0.14,
                    "disease_stability_estimate": 0.8,
                    "uncertainty_score": 0.24,
                    "rationale_tags": ["new_drug_component_analysis"],
                }
            )
        ]

    return []


def _prioritize(candidates: list[CandidateAction], sub_env: str) -> list[CandidateAction]:
    """Stably move the sub-environment's signature action to the front, if it has one."""
    priority_action = _PRIORITY_BY_SUBENV.get(sub_env)
    if priority_action is None:
        return candidates
    prioritized = [item for item in candidates if item.action_type == priority_action]
    rest = [item for item in candidates if item.action_type != priority_action]
    return prioritized + rest


def _pad_candidates(candidates: list[CandidateAction]) -> list[CandidateAction]:
    """Pad ``candidates`` with KEEP_REGIMEN fillers up to ``_MIN_CANDIDATES``.

    Filler ids start at ``cand_10`` and skip any id already present. This
    keeps ``candidate_id`` unique: ``cand_10`` and ``cand_11`` are also used
    by real candidates.
    """
    padded = list(candidates)
    taken = {item.candidate_id for item in padded}
    idx = 10
    while len(padded) < _MIN_CANDIDATES:
        filler = _base_candidate(idx, ActionType.KEEP_REGIMEN)
        idx += 1
        if filler.candidate_id in taken:
            continue
        taken.add(filler.candidate_id)
        padded.append(filler)
    return padded


def build_candidates(state: PolyGuardState) -> list[CandidateAction]:
    """Generate the rule-based candidate actions for ``state``.

    Candidates are emitted in this order:

    - keep regimen;
    - stop / dose-reduce / substitute, when there are medications;
    - one organ-function dose hold, if a medication is renal- or
      hepatic-unsafe;
    - monitor-and-wait, pharmacist review and specialist review;
    - any sub-environment-specific candidate.

    For sub-environments listed in ``_PRIORITY_BY_SUBENV``, the matching
    action is moved to the front. The list is then truncated to
    ``_MAX_CANDIDATES`` and padded to at least ``_MIN_CANDIDATES``.

    Returns:
        The candidates, each with ``legality_precheck`` set from
        ``verify_action_legality`` on its ``PolyGuardAction`` form.
    """
    meds = state.patient.medications
    sub_env = state.sub_environment.value

    # Aim regimen changes at the riskiest interacting drug. Fall back to the
    # first medication when no risky pair is found.
    risky_pairs = top_risky_pairs([m.drug for m in meds])
    target_risky_drug = risky_pairs[0][0] if risky_pairs else (meds[0].drug if meds else None)

    candidates: list[CandidateAction] = [_keep_candidate()]
    if meds:
        candidates.extend(_regimen_change_candidates(target_risky_drug or meds[0].drug))

    hold = _organ_guardrail_candidate(state)
    if hold is not None:
        candidates.append(hold)

    candidates.extend(_monitoring_and_review_candidates())
    candidates.extend(_subenv_candidates(state, sub_env))

    candidates = _prioritize(candidates, sub_env)
    limited = _pad_candidates(candidates[:_MAX_CANDIDATES])

    return [
        candidate.model_copy(
            update={"legality_precheck": verify_action_legality(state, _to_action(candidate)).legal}
        )
        for candidate in limited
    ]
|
|