# TheJackBright: "Deploy GitHub root master to Space" (commit c296d62)
"""Constrained candidate action generation."""
from __future__ import annotations
from app.common.enums import ActionType, DecisionMode, DoseBucket
from app.common.types import CandidateAction, PolyGuardAction, PolyGuardState
from app.env.verifier import verify_action_legality
from app.knowledge.ddi_knowledge import top_risky_pairs
from app.knowledge.hepatic_rules import is_hepatic_unsafe
from app.knowledge.renal_rules import is_renal_unsafe
from app.knowledge.substitution_rules import get_substitutions
def _base_candidate(
    idx: int,
    action_type: ActionType,
    target_drug: str | None = None,
    replacement_drug: str | None = None,
    mode: DecisionMode = DecisionMode.REGIMEN_OPT,
) -> CandidateAction:
    """Return a rule-based seed candidate populated with fixed default scores.

    Args:
        idx: Numeric suffix of the id; rendered zero-padded to two digits
            (``cand_01``, ``cand_12``, ...).
        action_type: Action the candidate proposes.
        target_drug: Drug the action applies to, if any.
        replacement_drug: Drug to switch to, if any.
        mode: Decision mode; defaults to ``DecisionMode.REGIMEN_OPT``.

    Returns:
        A ``CandidateAction`` with default scores. A 14-day taper is attached
        only for ``TAPER_INITIATE`` and a ``"repeat_labs_7d"`` monitoring plan
        only for ``ORDER_MONITORING_AND_WAIT``. Callers in this module adjust
        the remaining fields afterwards with ``model_copy(update=...)``.
    """
    candidate_id = "cand_" + format(idx, "02d")

    taper_days = None
    if action_type == ActionType.TAPER_INITIATE:
        taper_days = 14

    monitoring_plan = None
    if action_type == ActionType.ORDER_MONITORING_AND_WAIT:
        monitoring_plan = "repeat_labs_7d"

    return CandidateAction(
        candidate_id=candidate_id,
        mode=mode,
        action_type=action_type,
        target_drug=target_drug,
        replacement_drug=replacement_drug,
        dose_bucket=DoseBucket.NA,
        taper_days=taper_days,
        monitoring_plan=monitoring_plan,
        # Neutral seed scores; every builder below overrides the relevant ones.
        estimated_safety_delta=0.02,
        burden_delta=0.0,
        disease_stability_estimate=0.85,
        uncertainty_score=0.45,
        rationale_tags=["rule_based_seed"],
        required_monitoring=[],
        legality_precheck=True,
    )
def _to_action(candidate: CandidateAction) -> PolyGuardAction:
    """Convert a candidate into the ``PolyGuardAction`` form.

    The action-defining fields are copied unchanged. ``confidence`` is
    ``1.0 - uncertainty_score`` floored at 0.45, and ``rationale_brief`` is
    always ``"candidate_precheck"``.
    """
    passthrough = (
        "mode",
        "action_type",
        "target_drug",
        "replacement_drug",
        "dose_bucket",
        "taper_days",
        "monitoring_plan",
        "evidence_query",
        "new_drug_name",
        "candidate_components",
        "candidate_id",
    )
    action_kwargs = {field: getattr(candidate, field) for field in passthrough}
    # Floor keeps highly-uncertain candidates from reporting near-zero confidence.
    action_kwargs["confidence"] = max(0.45, 1.0 - candidate.uncertainty_score)
    action_kwargs["rationale_brief"] = "candidate_precheck"
    return PolyGuardAction(**action_kwargs)
# Size bounds for the list returned by build_candidates.
_MIN_CANDIDATES = 3
_MAX_CANDIDATES = 10
# Seeded candidates use the fixed ids cand_01..cand_12, so filler candidates are
# numbered from 13 upward; this keeps every candidate_id in the result unique.
_FILLER_START_IDX = 13
# Sub-environments whose signature action is moved to the front of the list.
_PRIORITY_BY_SUBENV = {
    "WEB_SEARCH_MISSING_DATA": ActionType.FETCH_EXTERNAL_EVIDENCE,
    "ALTERNATIVE_SUGGESTION": ActionType.RECOMMEND_ALTERNATIVE,
    "NEW_DRUG_DECOMPOSITION": ActionType.DECOMPOSE_NEW_DRUG,
}


def _regimen_change_candidates(target: str) -> list[CandidateAction]:
    """Return stop, dose-reduction and (if a substitute exists) substitution candidates for *target*."""
    stop = _base_candidate(2, ActionType.STOP_DRUG, target_drug=target).model_copy(
        update={
            "estimated_safety_delta": 0.26,
            "burden_delta": 0.12,
            # A lower stability estimate is used when stopping "warfarin_like".
            "disease_stability_estimate": 0.68 if target == "warfarin_like" else 0.81,
            "uncertainty_score": 0.42,
            "rationale_tags": ["ddi_reduction", "deprescribing"],
        }
    )
    reduce_dose = _base_candidate(3, ActionType.REDUCE_DOSE_BUCKET, target_drug=target).model_copy(
        update={
            "mode": DecisionMode.DOSE_OPT,
            "dose_bucket": DoseBucket.LOW,
            "estimated_safety_delta": 0.16,
            "burden_delta": 0.03,
            "uncertainty_score": 0.33,
            "rationale_tags": ["dose_deintensification"],
        }
    )
    result = [stop, reduce_dose]

    substitutes = get_substitutions(target)
    if substitutes:
        result.append(
            _base_candidate(
                4,
                ActionType.SUBSTITUTE_WITHIN_CLASS,
                target_drug=target,
                replacement_drug=substitutes[0],
            ).model_copy(
                update={
                    "estimated_safety_delta": 0.22,
                    "burden_delta": 0.05,
                    "uncertainty_score": 0.36,
                    "rationale_tags": ["therapeutic_substitution"],
                }
            )
        )
    return result


def _organ_function_hold(state: PolyGuardState) -> CandidateAction | None:
    """Return a DOSE_HOLD candidate for the first medication flagged by the renal or hepatic rules, else None."""
    labs = state.patient.labs
    for med in state.patient.medications:
        if is_renal_unsafe(med.drug, labs.egfr) or is_hepatic_unsafe(med.drug, labs.ast, labs.alt):
            return _base_candidate(5, ActionType.DOSE_HOLD, target_drug=med.drug, mode=DecisionMode.DOSE_OPT).model_copy(
                update={
                    "monitoring_plan": "repeat_labs_72h",
                    "estimated_safety_delta": 0.2,
                    "disease_stability_estimate": 0.74,
                    "uncertainty_score": 0.28,
                    "required_monitoring": ["renal_or_hepatic_panel"],
                    "rationale_tags": ["organ_function_guardrail"],
                }
            )
    return None


def _sub_environment_candidates(state: PolyGuardState, subenv: str) -> list[CandidateAction]:
    """Return the extra candidate(s) specific to the *subenv* sub-environment (possibly none)."""
    meds = state.patient.medications
    extra: list[CandidateAction] = []

    if subenv == "BANDIT_MINING" and meds:
        extra.append(
            _base_candidate(6, ActionType.KEEP_REGIMEN).model_copy(
                update={
                    "estimated_safety_delta": 0.08,
                    "burden_delta": 0.01,
                    "uncertainty_score": 0.31,
                    "rationale_tags": ["contextual_bandit_exploration"],
                }
            )
        )

    if subenv == "WEB_SEARCH_MISSING_DATA":
        extra.append(
            _base_candidate(7, ActionType.FETCH_EXTERNAL_EVIDENCE, mode=DecisionMode.REVIEW).model_copy(
                update={
                    "evidence_query": "https://www.nih.gov",
                    "estimated_safety_delta": 0.11,
                    "disease_stability_estimate": 0.84,
                    "uncertainty_score": 0.22,
                    "rationale_tags": ["missing_data_recovery", "external_evidence_fetch"],
                }
            )
        )

    if subenv == "ALTERNATIVE_SUGGESTION" and meds:
        alt_target = meds[0].drug
        alt_replacements = get_substitutions(alt_target)
        if alt_replacements:
            extra.append(
                _base_candidate(
                    11,
                    ActionType.RECOMMEND_ALTERNATIVE,
                    target_drug=alt_target,
                    replacement_drug=alt_replacements[0],
                    mode=DecisionMode.REGIMEN_OPT,
                ).model_copy(
                    update={
                        "estimated_safety_delta": 0.24,
                        "burden_delta": 0.04,
                        "uncertainty_score": 0.29,
                        "rationale_tags": ["alternative_suggestion", "safer_addition_or_swap"],
                    }
                )
            )

    if subenv == "NEW_DRUG_DECOMPOSITION":
        extra.append(
            _base_candidate(12, ActionType.DECOMPOSE_NEW_DRUG, mode=DecisionMode.REVIEW).model_copy(
                update={
                    "new_drug_name": "novel_combination_x",
                    "candidate_components": ["novel_component_a", "novel_component_b"],
                    "estimated_safety_delta": 0.14,
                    "disease_stability_estimate": 0.8,
                    "uncertainty_score": 0.24,
                    "rationale_tags": ["new_drug_component_analysis"],
                }
            )
        )

    return extra


def _prioritize(candidates: list[CandidateAction], subenv: str) -> list[CandidateAction]:
    """Move candidates of *subenv*'s signature action to the front, keeping relative order otherwise."""
    priority_action = _PRIORITY_BY_SUBENV.get(subenv)
    if priority_action is None:
        return candidates
    prioritized = [c for c in candidates if c.action_type == priority_action]
    remainder = [c for c in candidates if c.action_type != priority_action]
    return prioritized + remainder


def _bound_and_precheck(state: PolyGuardState, candidates: list[CandidateAction]) -> list[CandidateAction]:
    """Clamp *candidates* to [_MIN_CANDIDATES, _MAX_CANDIDATES] and stamp each with its legality precheck."""
    bounded = candidates[:_MAX_CANDIDATES]

    # Pad with KEEP_REGIMEN fillers whose ids cannot collide with existing ones
    # (the previous `i + 10` numbering reused cand_10..cand_12).
    taken = {c.candidate_id for c in bounded}
    next_idx = _FILLER_START_IDX
    while len(bounded) < _MIN_CANDIDATES:
        filler = _base_candidate(next_idx, ActionType.KEEP_REGIMEN)
        next_idx += 1
        if filler.candidate_id not in taken:
            taken.add(filler.candidate_id)
            bounded.append(filler)

    return [
        c.model_copy(update={"legality_precheck": verify_action_legality(state, _to_action(c)).legal})
        for c in bounded
    ]


def build_candidates(state: PolyGuardState) -> list[CandidateAction]:
    """Generate the rule-based candidate action set for *state*.

    Candidates are assembled in this order:

    1. keep the current regimen;
    2. when the patient has medications, stop / reduce-dose / substitute
       candidates for the target drug: the first drug of the top entry from
       ``top_risky_pairs``, falling back to the first medication;
    3. a dose hold for the first medication flagged by the renal or hepatic
       rules, if any;
    4. order monitoring and wait;
    5. pharmacist review, then specialist review;
    6. any sub-environment-specific candidate.

    If the sub-environment has a signature action, candidates of that action
    are moved to the front. The list is then clamped to
    ``_MIN_CANDIDATES..._MAX_CANDIDATES`` entries. Each entry's
    ``legality_precheck`` is set from ``verify_action_legality``.

    Args:
        state: Current environment state (patient medications and labs,
            and ``sub_environment``).

    Returns:
        The ordered, legality-prechecked list of candidates.
    """
    meds = state.patient.medications
    risky_pairs = top_risky_pairs([m.drug for m in meds])
    target_risky_drug = risky_pairs[0][0] if risky_pairs else (meds[0].drug if meds else None)

    keep = _base_candidate(1, ActionType.KEEP_REGIMEN).model_copy(
        update={"estimated_safety_delta": -0.02, "uncertainty_score": 0.5}
    )
    candidates: list[CandidateAction] = [keep]

    if meds:
        candidates.extend(_regimen_change_candidates(target_risky_drug or meds[0].drug))

    hold = _organ_function_hold(state)
    if hold is not None:
        candidates.append(hold)

    candidates.append(
        _base_candidate(8, ActionType.ORDER_MONITORING_AND_WAIT, mode=DecisionMode.DOSE_OPT).model_copy(
            update={
                "monitoring_plan": "vitals_labs_7d",
                "estimated_safety_delta": 0.1,
                "disease_stability_estimate": 0.88,
                "uncertainty_score": 0.26,
                "rationale_tags": ["monitor_before_change"],
                "required_monitoring": ["cbc", "cmp"],
            }
        )
    )

    # Abstain-style options that hand the decision to a human reviewer.
    candidates.append(
        _base_candidate(9, ActionType.REQUEST_PHARMACIST_REVIEW, mode=DecisionMode.REVIEW).model_copy(
            update={"estimated_safety_delta": 0.04, "uncertainty_score": 0.18, "rationale_tags": ["abstain_for_review"]}
        )
    )
    candidates.append(
        _base_candidate(10, ActionType.REQUEST_SPECIALIST_REVIEW, mode=DecisionMode.REVIEW).model_copy(
            update={"estimated_safety_delta": 0.04, "uncertainty_score": 0.2, "rationale_tags": ["abstain_for_review"]}
        )
    )

    subenv = state.sub_environment.value
    candidates.extend(_sub_environment_candidates(state, subenv))
    candidates = _prioritize(candidates, subenv)
    return _bound_and_precheck(state, candidates)