Spaces:
Running
Running
| """Constrained candidate action generation.""" | |
| from __future__ import annotations | |
| from app.common.enums import ActionType, DecisionMode, DoseBucket | |
| from app.common.types import CandidateAction, PolyGuardAction, PolyGuardState | |
| from app.env.verifier import verify_action_legality | |
| from app.knowledge.ddi_knowledge import top_risky_pairs | |
| from app.knowledge.hepatic_rules import is_hepatic_unsafe | |
| from app.knowledge.renal_rules import is_renal_unsafe | |
| from app.knowledge.substitution_rules import get_substitutions | |
def _base_candidate(
    idx: int,
    action_type: ActionType,
    target_drug: str | None = None,
    replacement_drug: str | None = None,
    mode: DecisionMode = DecisionMode.REGIMEN_OPT,
) -> CandidateAction:
    """Create a seed ``CandidateAction`` populated with default estimates.

    Args:
        idx: Ordinal used to form the id, rendered as ``cand_`` plus the
            number zero-padded to two digits.
        action_type: Action the candidate proposes.
        target_drug: Drug the action applies to, if any.
        replacement_drug: Drug proposed in place of ``target_drug``, if any.
        mode: Decision mode recorded on the candidate.

    Returns:
        A candidate with ``DoseBucket.NA``, fixed default scores, the single
        rationale tag ``rule_based_seed``, no required monitoring, and
        ``legality_precheck`` set to ``True``. ``taper_days`` is 14 only for
        ``TAPER_INITIATE`` and ``monitoring_plan`` is ``repeat_labs_7d`` only
        for ``ORDER_MONITORING_AND_WAIT``; both are ``None`` otherwise.
    """
    # Only taper actions receive a default taper length.
    taper_days = None
    if action_type == ActionType.TAPER_INITIATE:
        taper_days = 14

    # Only monitor-and-wait actions receive a default monitoring plan.
    monitoring_plan = None
    if action_type == ActionType.ORDER_MONITORING_AND_WAIT:
        monitoring_plan = "repeat_labs_7d"

    seed_fields = {
        "candidate_id": f"cand_{idx:02d}",
        "mode": mode,
        "action_type": action_type,
        "target_drug": target_drug,
        "replacement_drug": replacement_drug,
        "dose_bucket": DoseBucket.NA,
        "taper_days": taper_days,
        "monitoring_plan": monitoring_plan,
        "estimated_safety_delta": 0.02,
        "burden_delta": 0.0,
        "disease_stability_estimate": 0.85,
        "uncertainty_score": 0.45,
        "rationale_tags": ["rule_based_seed"],
        "required_monitoring": [],
        "legality_precheck": True,
    }
    return CandidateAction(**seed_fields)
def _to_action(candidate: CandidateAction) -> PolyGuardAction:
    """Convert a candidate into the ``PolyGuardAction`` used for the legality precheck.

    The descriptive fields are copied from ``candidate`` unchanged.
    ``confidence`` is ``1.0 - candidate.uncertainty_score`` with a floor of
    0.45, and ``rationale_brief`` is always ``candidate_precheck``.
    """
    passthrough_fields = (
        "mode",
        "action_type",
        "target_drug",
        "replacement_drug",
        "dose_bucket",
        "taper_days",
        "monitoring_plan",
        "evidence_query",
        "new_drug_name",
        "candidate_components",
        "candidate_id",
    )
    copied = {name: getattr(candidate, name) for name in passthrough_fields}
    # Floor the confidence at 0.45 (argument order of max() kept as-is).
    confidence = max(0.45, 1.0 - candidate.uncertainty_score)
    return PolyGuardAction(
        **copied,
        confidence=confidence,
        rationale_brief="candidate_precheck",
    )
def build_candidates(state: PolyGuardState) -> list[CandidateAction]:
    """Assemble the rule-based candidate actions for ``state`` and precheck each one.

    Candidates are appended in this order:

    1. Keep the current regimen.
    2. When the patient has medications: stop, dose-reduce and (if a
       substitution is known) substitute the focus drug. The focus drug is
       the first drug of the top risky pair, else the first medication.
    3. A dose hold for the first medication flagged by the renal or hepatic
       rule helpers.
    4. Monitoring-and-wait, pharmacist review and specialist review.
    5. One extra candidate specific to the current sub-environment
       (``BANDIT_MINING``, ``WEB_SEARCH_MISSING_DATA``,
       ``ALTERNATIVE_SUGGESTION`` or ``NEW_DRUG_DECOMPOSITION``).

    For three of the sub-environments the matching action type is moved to
    the front of the list, keeping the relative order of everything else.
    The list is then capped at 10 entries and padded with keep-regimen seeds
    up to 3 entries.

    Args:
        state: Current environment state; the patient's medications and labs
            and the sub-environment value are read from it.

    Returns:
        The candidates, each copied with ``legality_precheck`` set to the
        ``legal`` flag returned by ``verify_action_legality``.
    """
    patient = state.patient
    meds = patient.medications
    pool: list[CandidateAction] = []

    # Choose the drug the regimen-level candidates will focus on.
    pairs = top_risky_pairs([med.drug for med in meds])
    if pairs:
        risky_drug = pairs[0][0]
    elif meds:
        risky_drug = meds[0].drug
    else:
        risky_drug = None

    # Baseline: leave the regimen untouched.
    keep_seed = _base_candidate(1, ActionType.KEEP_REGIMEN)
    pool.append(
        keep_seed.model_copy(
            update={"estimated_safety_delta": -0.02, "uncertainty_score": 0.5}
        )
    )

    if meds:
        # Fall back to the first medication if the risky drug is falsy.
        focus = risky_drug or meds[0].drug

        stop_seed = _base_candidate(2, ActionType.STOP_DRUG, target_drug=focus)
        stop_stability = 0.68 if focus == "warfarin_like" else 0.81
        pool.append(
            stop_seed.model_copy(
                update={
                    "estimated_safety_delta": 0.26,
                    "burden_delta": 0.12,
                    "disease_stability_estimate": stop_stability,
                    "uncertainty_score": 0.42,
                    "rationale_tags": ["ddi_reduction", "deprescribing"],
                }
            )
        )

        reduce_seed = _base_candidate(
            3, ActionType.REDUCE_DOSE_BUCKET, target_drug=focus
        )
        pool.append(
            reduce_seed.model_copy(
                update={
                    "mode": DecisionMode.DOSE_OPT,
                    "dose_bucket": DoseBucket.LOW,
                    "estimated_safety_delta": 0.16,
                    "burden_delta": 0.03,
                    "uncertainty_score": 0.33,
                    "rationale_tags": ["dose_deintensification"],
                }
            )
        )

        substitutes = get_substitutions(focus)
        if substitutes:
            swap_seed = _base_candidate(
                4,
                ActionType.SUBSTITUTE_WITHIN_CLASS,
                target_drug=focus,
                replacement_drug=substitutes[0],
            )
            pool.append(
                swap_seed.model_copy(
                    update={
                        "estimated_safety_delta": 0.22,
                        "burden_delta": 0.05,
                        "uncertainty_score": 0.36,
                        "rationale_tags": ["therapeutic_substitution"],
                    }
                )
            )

    # Organ-function guardrail: only the first flagged medication gets a hold.
    # The generator is lazy, so the rule helpers stop being called once a
    # match is found.
    flagged = next(
        (
            med
            for med in meds
            if is_renal_unsafe(med.drug, patient.labs.egfr)
            or is_hepatic_unsafe(med.drug, patient.labs.ast, patient.labs.alt)
        ),
        None,
    )
    if flagged is not None:
        hold_seed = _base_candidate(
            5, ActionType.DOSE_HOLD, target_drug=flagged.drug, mode=DecisionMode.DOSE_OPT
        )
        pool.append(
            hold_seed.model_copy(
                update={
                    "monitoring_plan": "repeat_labs_72h",
                    "estimated_safety_delta": 0.2,
                    "disease_stability_estimate": 0.74,
                    "uncertainty_score": 0.28,
                    "required_monitoring": ["renal_or_hepatic_panel"],
                    "rationale_tags": ["organ_function_guardrail"],
                }
            )
        )

    # Monitoring and review candidates are offered regardless of medications.
    watch_seed = _base_candidate(
        8, ActionType.ORDER_MONITORING_AND_WAIT, mode=DecisionMode.DOSE_OPT
    )
    pool.append(
        watch_seed.model_copy(
            update={
                "monitoring_plan": "vitals_labs_7d",
                "estimated_safety_delta": 0.1,
                "disease_stability_estimate": 0.88,
                "uncertainty_score": 0.26,
                "rationale_tags": ["monitor_before_change"],
                "required_monitoring": ["cbc", "cmp"],
            }
        )
    )

    pharmacist_review = _base_candidate(
        9, ActionType.REQUEST_PHARMACIST_REVIEW, mode=DecisionMode.REVIEW
    ).model_copy(
        update={
            "estimated_safety_delta": 0.04,
            "uncertainty_score": 0.18,
            "rationale_tags": ["abstain_for_review"],
        }
    )
    specialist_review = _base_candidate(
        10, ActionType.REQUEST_SPECIALIST_REVIEW, mode=DecisionMode.REVIEW
    ).model_copy(
        update={
            "estimated_safety_delta": 0.04,
            "uncertainty_score": 0.2,
            "rationale_tags": ["abstain_for_review"],
        }
    )
    pool.append(pharmacist_review)
    pool.append(specialist_review)

    # Sub-environment specific extras; at most one branch can match.
    subenv = state.sub_environment.value
    if subenv == "BANDIT_MINING" and meds:
        explore_seed = _base_candidate(6, ActionType.KEEP_REGIMEN)
        pool.append(
            explore_seed.model_copy(
                update={
                    "candidate_id": "cand_06",
                    "mode": DecisionMode.REGIMEN_OPT,
                    "estimated_safety_delta": 0.08,
                    "burden_delta": 0.01,
                    "uncertainty_score": 0.31,
                    "rationale_tags": ["contextual_bandit_exploration"],
                }
            )
        )
    elif subenv == "WEB_SEARCH_MISSING_DATA":
        fetch_seed = _base_candidate(
            7, ActionType.FETCH_EXTERNAL_EVIDENCE, mode=DecisionMode.REVIEW
        )
        pool.append(
            fetch_seed.model_copy(
                update={
                    "candidate_id": "cand_07",
                    "evidence_query": "https://www.nih.gov",
                    "estimated_safety_delta": 0.11,
                    "disease_stability_estimate": 0.84,
                    "uncertainty_score": 0.22,
                    "rationale_tags": ["missing_data_recovery", "external_evidence_fetch"],
                }
            )
        )
    elif subenv == "ALTERNATIVE_SUGGESTION" and meds:
        # This branch targets the first medication, not the risky-pair drug.
        alt_target = meds[0].drug
        alt_options = get_substitutions(alt_target)
        if alt_options:
            alt_seed = _base_candidate(
                11,
                ActionType.RECOMMEND_ALTERNATIVE,
                target_drug=alt_target,
                replacement_drug=alt_options[0],
                mode=DecisionMode.REGIMEN_OPT,
            )
            pool.append(
                alt_seed.model_copy(
                    update={
                        "candidate_id": "cand_11",
                        "estimated_safety_delta": 0.24,
                        "burden_delta": 0.04,
                        "uncertainty_score": 0.29,
                        "rationale_tags": ["alternative_suggestion", "safer_addition_or_swap"],
                    }
                )
            )
    elif subenv == "NEW_DRUG_DECOMPOSITION":
        decompose_seed = _base_candidate(
            12, ActionType.DECOMPOSE_NEW_DRUG, mode=DecisionMode.REVIEW
        )
        pool.append(
            decompose_seed.model_copy(
                update={
                    "candidate_id": "cand_12",
                    "new_drug_name": "novel_combination_x",
                    "candidate_components": ["novel_component_a", "novel_component_b"],
                    "estimated_safety_delta": 0.14,
                    "disease_stability_estimate": 0.8,
                    "uncertainty_score": 0.24,
                    "rationale_tags": ["new_drug_component_analysis"],
                }
            )
        )

    # Move the sub-environment's signature action to the front. list.sort is
    # stable and False sorts before True, so matching items come first and
    # every other item keeps its relative position.
    priority_action = {
        "WEB_SEARCH_MISSING_DATA": ActionType.FETCH_EXTERNAL_EVIDENCE,
        "ALTERNATIVE_SUGGESTION": ActionType.RECOMMEND_ALTERNATIVE,
        "NEW_DRUG_DECOMPOSITION": ActionType.DECOMPOSE_NEW_DRUG,
    }.get(subenv)
    if priority_action is not None:
        pool.sort(key=lambda item: item.action_type != priority_action)

    # Strict 3..10: cap at ten, then pad with keep-regimen seeds up to three.
    shortlist = pool[:10]
    for offset in range(3 - len(shortlist)):
        shortlist.append(_base_candidate(offset + 10, ActionType.KEEP_REGIMEN))

    # Record the verifier's legality verdict on a copy of each candidate.
    return [
        item.model_copy(
            update={
                "legality_precheck": verify_action_legality(state, _to_action(item)).legal
            }
        )
        for item in shortlist
    ]