File size: 7,201 Bytes
21c7db9 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 | """Environment transition dynamics."""
from __future__ import annotations
from pathlib import Path
from app.common.enums import ActionType, DecisionMode, DoseBucket
from app.common.types import PolyGuardAction, PolyGuardState
from app.dataops.parser import extract_components, extract_drug_mentions
from app.dataops.source_manager import SourceManager
from app.dataops.web_fallback import scrape_with_fallback
from app.knowledge.ddi_knowledge import top_risky_pairs
# Per-dose-bucket multipliers used when summing a regimen's burden:
# higher dose buckets weigh more, and held medications still contribute,
# just at a reduced weight. NA is treated like a standard (MEDIUM) dose.
DOSE_BURDEN_WEIGHT: dict[DoseBucket, float] = {
    DoseBucket.LOW: 0.7,
    DoseBucket.MEDIUM: 1.0,  # baseline weight
    DoseBucket.HIGH: 1.25,
    DoseBucket.HOLD: 0.45,
    DoseBucket.NA: 1.0,  # unknown dose: same as baseline
}
def _find_med_idx(state: PolyGuardState, drug: str | None) -> int | None:
    """Return the index of *drug* in the patient's medication list.

    Only the first exact match on ``med.drug`` is reported. Returns ``None`` when
    *drug* is ``None`` or empty, or when it is not on the current regimen.
    """
    if drug:
        positions = (
            pos
            for pos, entry in enumerate(state.patient.medications)
            if entry.drug == drug
        )
        return next(positions, None)
    return None
def _shift_dose_bucket(current: DoseBucket, step: int) -> DoseBucket:
    """Move *current* *step* rungs along LOW < MEDIUM < HIGH, without wrapping.

    Buckets outside that ladder (e.g. HOLD, NA) and moves that would step past
    either end return *current* unchanged.
    """
    ladder = [DoseBucket.LOW, DoseBucket.MEDIUM, DoseBucket.HIGH]
    if current not in ladder:
        return current
    new_idx = ladder.index(current) + step
    if 0 <= new_idx < len(ladder):
        return ladder[new_idx]
    return current


def _fetch_evidence_text(query: str, changes: list[str]) -> str:
    """Resolve an evidence query to text, appending how it was obtained to *changes*.

    Queries with an ``http://``/``https://`` scheme are fetched through the
    SourceManager cache (restricted to an allow-list of domains), falling back to
    ``scrape_with_fallback`` if that raises. Any other query string is used as the
    evidence text itself.
    """
    if not query.startswith(("http://", "https://")):
        changes.append("evidence_query_recorded")
        return query
    allow_domains = ["who.int", "nih.gov", "fda.gov", "ema.europa.eu"]
    manager = SourceManager(root=Path(__file__).resolve().parents[2])
    try:
        fetched = manager.fetch_with_cache(
            url=query,
            allow_domains=allow_domains,
            namespace="evidence_fetch",
            offline_first=True,
        )
        text = str(fetched.get("text", ""))
    except Exception:
        # Deliberate best-effort: any cache/fetch failure degrades to scraping;
        # the backend used is recorded in the change log instead of being lost.
        fallback = scrape_with_fallback(query, allow_domains=allow_domains)
        changes.append(f"evidence_fallback:{fallback.get('backend', 'none')}")
        return str(fallback.get("text", ""))
    changes.append("evidence_cached_or_fetched")
    return text


def _decompose_new_drug(state: PolyGuardState, action: PolyGuardAction, changes: list[str]) -> None:
    """Extract a new drug's components and record the resulting risk metrics on *state*."""
    candidates = action.candidate_components
    if candidates:
        seed_text = " ".join(candidates)
    else:
        seed_text = f"active ingredients: {(action.new_drug_name or '').replace('_', ' ')}"
    # Prefer parsed components; otherwise use the non-empty caller-supplied candidates.
    components = extract_components(seed_text) or [token for token in (candidates or []) if token]
    state.risk_summary["new_drug_component_count"] = float(len(components))
    state.risk_summary["new_drug_unknown_risk"] = 0.0 if components else 1.0
    # NOTE(review): "new_drug_unknown" conflicts are cleared even when no components
    # were found (unknown_risk == 1.0) — confirm this is intended.
    state.unresolved_conflicts = [item for item in state.unresolved_conflicts if "new_drug_unknown" not in item]
    changes.append(f"new_drug_components:{','.join(components) if components else 'none'}")


def _refresh_regimen_metrics(state: PolyGuardState) -> None:
    """Recompute burden and risk-summary metrics from the current medication list.

    The burden is dose-aware (each medication weighted by DOSE_BURDEN_WEIGHT,
    default 1.0) so dose optimization has a real reward signal. It is divided by
    12.0 and clamped to [0, 1], so it saturates once the weighted count reaches 12.
    """
    meds = state.patient.medications
    weighted = sum(DOSE_BURDEN_WEIGHT.get(med.dose_bucket, 1.0) for med in meds)
    state.burden_score = max(0.0, min(1.0, weighted / 12.0))
    state.risk_summary["polypharmacy_count"] = float(len(meds))
    state.risk_summary["burden_score"] = float(state.burden_score)
    state.risk_summary["severe_pair_count"] = float(len(top_risky_pairs([m.drug for m in meds])))


def apply_transition(state: PolyGuardState, action: PolyGuardAction) -> dict[str, object]:
    """Apply *action* to *state* in place and return a summary of the step.

    Mutates *state* directly: the medication list and dose buckets, active mode,
    unresolved conflicts, risk summary, action history and step counter may all
    change. Actions that act on a specific drug (stop, substitute/alternative,
    dose changes, hold, taper) only take effect when ``action.target_drug`` is on
    the current regimen; substitutions also require ``action.replacement_drug``.

    Returns:
        A dict with:
        - ``"applied"``: False when the action type was not handled or its
          preconditions were not met, so the regimen was left untouched.
        - ``"changes"``: a list of short change tags.
        - ``"state"``: the post-step ``step_count`` and ``med_count``.
        The active mode, action history, step counter and regimen metrics are
        updated even when ``"applied"`` is False.
    """
    changes: list[str] = []
    applied = True
    meds = state.patient.medications
    target_idx = _find_med_idx(state, action.target_drug)
    kind = action.action_type

    state.active_mode = action.mode

    if kind == ActionType.KEEP_REGIMEN:
        changes.append("no_change")
    elif kind == ActionType.STOP_DRUG and target_idx is not None:
        removed = meds.pop(target_idx)
        changes.append(f"stopped:{removed.drug}")
    elif (
        kind in {ActionType.SUBSTITUTE_WITHIN_CLASS, ActionType.RECOMMEND_ALTERNATIVE}
        and target_idx is not None
        and action.replacement_drug
    ):
        # Both actions swap the drug name in place; they differ only in the change tag.
        label = "substituted" if kind == ActionType.SUBSTITUTE_WITHIN_CLASS else "alternative_recommended"
        old = meds[target_idx].drug
        meds[target_idx].drug = action.replacement_drug
        changes.append(f"{label}:{old}->{action.replacement_drug}")
    elif kind in {ActionType.REDUCE_DOSE_BUCKET, ActionType.INCREASE_DOSE_BUCKET} and target_idx is not None:
        med = meds[target_idx]
        current = med.dose_bucket
        step = -1 if kind == ActionType.REDUCE_DOSE_BUCKET else 1
        med.dose_bucket = _shift_dose_bucket(current, step)
        changes.append(f"dose_change:{med.drug}:{current}->{med.dose_bucket}")
    elif kind == ActionType.DOSE_HOLD and target_idx is not None:
        meds[target_idx].dose_bucket = DoseBucket.HOLD
        changes.append(f"held:{meds[target_idx].drug}")
    elif kind == ActionType.ORDER_MONITORING_AND_WAIT:
        if target_idx is not None:
            meds[target_idx].dose_bucket = DoseBucket.HOLD
            changes.append(f"held_for_monitoring:{meds[target_idx].drug}")
        # Ordering monitoring clears any outstanding review requests.
        state.unresolved_conflicts = [c for c in state.unresolved_conflicts if not c.startswith("review_requested")]
        changes.append("monitoring_ordered")
    elif kind == ActionType.TAPER_INITIATE and target_idx is not None:
        meds[target_idx].requires_taper = True
        # 7 days is only the default length shown in the change tag.
        changes.append(f"taper_start:{meds[target_idx].drug}:{action.taper_days or 7}d")
    elif kind == ActionType.TAPER_CONTINUE and target_idx is not None:
        meds[target_idx].dose_bucket = DoseBucket.LOW
        changes.append(f"taper_continue:{meds[target_idx].drug}")
    elif kind in {ActionType.REQUEST_SPECIALIST_REVIEW, ActionType.REQUEST_PHARMACIST_REVIEW}:
        state.active_mode = DecisionMode.REVIEW
        state.unresolved_conflicts.append(f"review_requested:{kind.value}")
        changes.append(f"review:{kind.value}")
    elif kind == ActionType.FETCH_EXTERNAL_EVIDENCE:
        text = _fetch_evidence_text((action.evidence_query or "").strip(), changes)
        mentions = extract_drug_mentions(text)
        components = extract_components(text)
        state.risk_summary["external_mentions_count"] = float(len(mentions))
        state.risk_summary["external_components_count"] = float(len(components))
        # NOTE(review): missing-data conflicts are cleared even if *text* is empty
        # (e.g. a blank query) — confirm this is intended.
        state.unresolved_conflicts = [item for item in state.unresolved_conflicts if "missing_data" not in item]
    elif kind == ActionType.DECOMPOSE_NEW_DRUG:
        _decompose_new_drug(state, action, changes)
    else:
        # Unhandled action type, or a precondition failed (target drug not on the
        # regimen / no replacement drug): the regimen is left untouched.
        applied = False

    state.action_history.append({"step": state.step_count, "action": action.model_dump(mode="json")})
    state.step_count += 1
    _refresh_regimen_metrics(state)

    return {
        "applied": applied,
        "changes": changes,
        "state": {"step_count": state.step_count, "med_count": len(meds)},
    }
|