"""Constrained candidate action generation.""" from __future__ import annotations from app.common.enums import ActionType, DecisionMode, DoseBucket from app.common.types import CandidateAction, PolyGuardAction, PolyGuardState from app.env.verifier import verify_action_legality from app.knowledge.ddi_knowledge import top_risky_pairs from app.knowledge.hepatic_rules import is_hepatic_unsafe from app.knowledge.renal_rules import is_renal_unsafe from app.knowledge.substitution_rules import get_substitutions def _base_candidate( idx: int, action_type: ActionType, target_drug: str | None = None, replacement_drug: str | None = None, mode: DecisionMode = DecisionMode.REGIMEN_OPT, ) -> CandidateAction: return CandidateAction( candidate_id=f"cand_{idx:02d}", mode=mode, action_type=action_type, target_drug=target_drug, replacement_drug=replacement_drug, dose_bucket=DoseBucket.NA, taper_days=14 if action_type == ActionType.TAPER_INITIATE else None, monitoring_plan="repeat_labs_7d" if action_type == ActionType.ORDER_MONITORING_AND_WAIT else None, estimated_safety_delta=0.02, burden_delta=0.0, disease_stability_estimate=0.85, uncertainty_score=0.45, rationale_tags=["rule_based_seed"], required_monitoring=[], legality_precheck=True, ) def _to_action(candidate: CandidateAction) -> PolyGuardAction: return PolyGuardAction( mode=candidate.mode, action_type=candidate.action_type, target_drug=candidate.target_drug, replacement_drug=candidate.replacement_drug, dose_bucket=candidate.dose_bucket, taper_days=candidate.taper_days, monitoring_plan=candidate.monitoring_plan, evidence_query=candidate.evidence_query, new_drug_name=candidate.new_drug_name, candidate_components=candidate.candidate_components, candidate_id=candidate.candidate_id, confidence=max(0.45, 1.0 - candidate.uncertainty_score), rationale_brief="candidate_precheck", ) def build_candidates(state: PolyGuardState) -> list[CandidateAction]: meds = state.patient.medications candidates: list[CandidateAction] = [] risky_pairs = top_risky_pairs([m.drug for m in meds]) target_risky_drug = risky_pairs[0][0] if risky_pairs else (meds[0].drug if meds else None) keep = _base_candidate(1, ActionType.KEEP_REGIMEN) keep = keep.model_copy(update={"estimated_safety_delta": -0.02, "uncertainty_score": 0.5}) candidates.append(keep) if meds: first = target_risky_drug or meds[0].drug stop = _base_candidate(2, ActionType.STOP_DRUG, target_drug=first) stop = stop.model_copy( update={ "estimated_safety_delta": 0.26, "burden_delta": 0.12, "disease_stability_estimate": 0.68 if first == "warfarin_like" else 0.81, "uncertainty_score": 0.42, "rationale_tags": ["ddi_reduction", "deprescribing"], } ) candidates.append(stop) dose_candidate = _base_candidate(3, ActionType.REDUCE_DOSE_BUCKET, target_drug=first) candidates.append( dose_candidate.model_copy( update={ "mode": DecisionMode.DOSE_OPT, "dose_bucket": DoseBucket.LOW, "estimated_safety_delta": 0.16, "burden_delta": 0.03, "uncertainty_score": 0.33, "rationale_tags": ["dose_deintensification"], } ) ) subs = get_substitutions(first) if subs: preferred = subs[0] candidates.append( _base_candidate( 4, ActionType.SUBSTITUTE_WITHIN_CLASS, target_drug=first, replacement_drug=preferred, ).model_copy( update={ "estimated_safety_delta": 0.22, "burden_delta": 0.05, "uncertainty_score": 0.36, "rationale_tags": ["therapeutic_substitution"], } ) ) for med in meds: if is_renal_unsafe(med.drug, state.patient.labs.egfr) or is_hepatic_unsafe(med.drug, state.patient.labs.ast, state.patient.labs.alt): hold = _base_candidate(5, ActionType.DOSE_HOLD, target_drug=med.drug, mode=DecisionMode.DOSE_OPT).model_copy( update={ "monitoring_plan": "repeat_labs_72h", "estimated_safety_delta": 0.2, "disease_stability_estimate": 0.74, "uncertainty_score": 0.28, "required_monitoring": ["renal_or_hepatic_panel"], "rationale_tags": ["organ_function_guardrail"], } ) candidates.append(hold) break monitoring = _base_candidate(8, ActionType.ORDER_MONITORING_AND_WAIT, mode=DecisionMode.DOSE_OPT).model_copy( update={ "monitoring_plan": "vitals_labs_7d", "estimated_safety_delta": 0.1, "disease_stability_estimate": 0.88, "uncertainty_score": 0.26, "rationale_tags": ["monitor_before_change"], "required_monitoring": ["cbc", "cmp"], } ) candidates.append(monitoring) pharm = _base_candidate(9, ActionType.REQUEST_PHARMACIST_REVIEW, mode=DecisionMode.REVIEW).model_copy( update={"estimated_safety_delta": 0.04, "uncertainty_score": 0.18, "rationale_tags": ["abstain_for_review"]} ) spec = _base_candidate(10, ActionType.REQUEST_SPECIALIST_REVIEW, mode=DecisionMode.REVIEW).model_copy( update={"estimated_safety_delta": 0.04, "uncertainty_score": 0.2, "rationale_tags": ["abstain_for_review"]} ) candidates.extend([pharm, spec]) if state.sub_environment.value == "BANDIT_MINING" and meds: bandit = _base_candidate(6, ActionType.KEEP_REGIMEN).model_copy( update={ "candidate_id": "cand_06", "mode": DecisionMode.REGIMEN_OPT, "estimated_safety_delta": 0.08, "burden_delta": 0.01, "uncertainty_score": 0.31, "rationale_tags": ["contextual_bandit_exploration"], } ) candidates.append(bandit) if state.sub_environment.value == "WEB_SEARCH_MISSING_DATA": candidates.append( _base_candidate(7, ActionType.FETCH_EXTERNAL_EVIDENCE, mode=DecisionMode.REVIEW).model_copy( update={ "candidate_id": "cand_07", "evidence_query": "https://www.nih.gov", "estimated_safety_delta": 0.11, "disease_stability_estimate": 0.84, "uncertainty_score": 0.22, "rationale_tags": ["missing_data_recovery", "external_evidence_fetch"], } ) ) if state.sub_environment.value == "ALTERNATIVE_SUGGESTION" and meds: alt_target = meds[0].drug alt_replacements = get_substitutions(alt_target) if alt_replacements: candidates.append( _base_candidate( 11, ActionType.RECOMMEND_ALTERNATIVE, target_drug=alt_target, replacement_drug=alt_replacements[0], mode=DecisionMode.REGIMEN_OPT, ).model_copy( update={ "candidate_id": "cand_11", "estimated_safety_delta": 0.24, "burden_delta": 0.04, "uncertainty_score": 0.29, "rationale_tags": ["alternative_suggestion", "safer_addition_or_swap"], } ) ) if state.sub_environment.value == "NEW_DRUG_DECOMPOSITION": candidates.append( _base_candidate(12, ActionType.DECOMPOSE_NEW_DRUG, mode=DecisionMode.REVIEW).model_copy( update={ "candidate_id": "cand_12", "new_drug_name": "novel_combination_x", "candidate_components": ["novel_component_a", "novel_component_b"], "estimated_safety_delta": 0.14, "disease_stability_estimate": 0.8, "uncertainty_score": 0.24, "rationale_tags": ["new_drug_component_analysis"], } ) ) priority_by_subenv = { "WEB_SEARCH_MISSING_DATA": ActionType.FETCH_EXTERNAL_EVIDENCE, "ALTERNATIVE_SUGGESTION": ActionType.RECOMMEND_ALTERNATIVE, "NEW_DRUG_DECOMPOSITION": ActionType.DECOMPOSE_NEW_DRUG, } priority_action = priority_by_subenv.get(state.sub_environment.value) if priority_action is not None: prioritized = [item for item in candidates if item.action_type == priority_action] non_prioritized = [item for item in candidates if item.action_type != priority_action] candidates = prioritized + non_prioritized # Strict 3..10. limited = candidates[:10] if len(limited) < 3: limited.extend([_base_candidate(i + 10, ActionType.KEEP_REGIMEN) for i in range(3 - len(limited))]) validated: list[CandidateAction] = [] for candidate in limited: legal = verify_action_legality(state, _to_action(candidate)).legal validated.append(candidate.model_copy(update={"legality_precheck": legal})) return validated