File size: 6,889 Bytes
877add7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
"""Environment transition dynamics."""

from __future__ import annotations

from pathlib import Path

from app.common.enums import ActionType, DecisionMode, DoseBucket
from app.common.types import PolyGuardAction, PolyGuardState
from app.dataops.parser import extract_components, extract_drug_mentions
from app.dataops.source_manager import SourceManager
from app.dataops.web_fallback import scrape_with_fallback
from app.knowledge.ddi_knowledge import top_risky_pairs


def _find_med_idx(state: PolyGuardState, drug: str | None) -> int | None:
    """Locate *drug* in the patient's medication list.

    Returns the index of the first medication whose ``drug`` equals *drug*,
    or ``None`` when there is no match. A falsy *drug* (``None`` or ``""``)
    never matches, and in that case ``state`` is not inspected at all.
    """
    if drug:
        return next(
            (
                position
                for position, entry in enumerate(state.patient.medications)
                if entry.drug == drug
            ),
            None,
        )
    return None


# Ordered dose ladder stepped through by the REDUCE/INCREASE dose actions.
# Buckets outside this ladder (e.g. DoseBucket.HOLD) are left untouched by them.
_DOSE_LADDER = (DoseBucket.LOW, DoseBucket.MEDIUM, DoseBucket.HIGH)

# Domains that external evidence may be fetched from.
_EVIDENCE_ALLOW_DOMAINS = ("who.int", "nih.gov", "fda.gov", "ema.europa.eu")

# Only queries carrying an explicit web scheme are fetched as URLs; anything
# else (including words that merely begin with "http") is recorded as text.
_URL_PREFIXES = ("http://", "https://")

# Medication count at which burden_score saturates at 1.0.
_BURDEN_SATURATION_COUNT = 12.0


def _shift_dose_bucket(med, action_type: ActionType) -> str | None:
    """Move ``med`` one rung down/up the dose ladder and return a change tag.

    Returns ``None`` (and leaves ``med`` unchanged) when the current bucket is
    not on the ladder. At either end of the ladder the bucket stays put, but a
    ``dose_change`` tag is still returned, mirroring the historical behaviour.
    """
    current = med.dose_bucket
    if current not in _DOSE_LADDER:
        return None
    pos = _DOSE_LADDER.index(current)
    if action_type == ActionType.REDUCE_DOSE_BUCKET and pos > 0:
        med.dose_bucket = _DOSE_LADDER[pos - 1]
    if action_type == ActionType.INCREASE_DOSE_BUCKET and pos < len(_DOSE_LADDER) - 1:
        med.dose_bucket = _DOSE_LADDER[pos + 1]
    return f"dose_change:{med.drug}:{current}->{med.dose_bucket}"


def _gather_evidence_text(query: str, changes: list[str]) -> str:
    """Return the evidence text for ``query`` and append a provenance tag to ``changes``.

    URL queries are fetched through ``SourceManager.fetch_with_cache``
    (offline-first, restricted to the allow-listed domains). If that raises,
    ``scrape_with_fallback`` is used instead. Any other query string is
    returned verbatim as the evidence text.
    """
    if not query.startswith(_URL_PREFIXES):
        changes.append("evidence_query_recorded")
        return query

    manager = SourceManager(root=Path(__file__).resolve().parents[2])
    try:
        fetched = manager.fetch_with_cache(
            url=query,
            allow_domains=list(_EVIDENCE_ALLOW_DOMAINS),
            namespace="evidence_fetch",
            offline_first=True,
        )
        text = str(fetched.get("text", ""))
    except Exception:
        # Best-effort: any cache/fetch failure (including a malformed cache
        # result) falls back to the scraper. The backend used is recorded in
        # the change tag rather than being silently swallowed.
        fallback = scrape_with_fallback(query, allow_domains=list(_EVIDENCE_ALLOW_DOMAINS))
        changes.append(f"evidence_fallback:{fallback.get('backend', 'none')}")
        return str(fallback.get("text", ""))
    changes.append("evidence_cached_or_fetched")
    return text


def _resolve_new_drug_components(action: PolyGuardAction) -> list[str]:
    """Return the component list for a DECOMPOSE_NEW_DRUG action.

    Components are extracted from the candidate components (or from the new
    drug name when none are given). If extraction yields nothing, the
    non-empty candidate components are used as-is.
    """
    if action.candidate_components:
        seed_text = " ".join(action.candidate_components)
    else:
        seed_text = f"active ingredients: {(action.new_drug_name or '').replace('_', ' ')}"
    extracted = extract_components(seed_text)
    fallback_components = [token for token in (action.candidate_components or []) if token]
    return extracted or fallback_components


def apply_transition(state: PolyGuardState, action: PolyGuardAction) -> dict[str, object]:
    """Apply one agent action to ``state`` in place and report what changed.

    The medication list, dose buckets, active mode, unresolved conflicts,
    risk summary, action history and step counter on ``state`` are mutated
    directly. Medication-targeted actions whose ``target_drug`` is not on the
    patient's list fall through without touching any medication, but the step
    is still recorded and the summary metrics are still refreshed.

    Args:
        state: Environment state to mutate.
        action: Action chosen by the agent.

    Returns:
        A dict with ``"applied"`` (always ``True`` in this implementation),
        ``"changes"`` (ordered list of short string tags, one per change) and
        ``"state"`` (``{"step_count": ..., "med_count": ...}`` after the update).
    """
    # Typed local list so ``.append`` type-checks; the same object is exposed
    # through ``delta["changes"]``.
    changes: list[str] = []
    delta: dict[str, object] = {"applied": True, "changes": changes}
    meds = state.patient.medications
    target_idx = _find_med_idx(state, action.target_drug)
    target = meds[target_idx] if target_idx is not None else None
    kind = action.action_type
    state.active_mode = action.mode

    if kind == ActionType.KEEP_REGIMEN:
        changes.append("no_change")

    elif kind == ActionType.STOP_DRUG and target_idx is not None:
        removed = meds.pop(target_idx)
        changes.append(f"stopped:{removed.drug}")

    elif kind == ActionType.SUBSTITUTE_WITHIN_CLASS and target is not None and action.replacement_drug:
        old = target.drug
        target.drug = action.replacement_drug
        changes.append(f"substituted:{old}->{action.replacement_drug}")

    elif kind == ActionType.RECOMMEND_ALTERNATIVE and target is not None and action.replacement_drug:
        # NOTE(review): the recommended alternative replaces the drug in the
        # regimen immediately, exactly like a substitution.
        old = target.drug
        target.drug = action.replacement_drug
        changes.append(f"alternative_recommended:{old}->{action.replacement_drug}")

    elif kind in {ActionType.REDUCE_DOSE_BUCKET, ActionType.INCREASE_DOSE_BUCKET} and target is not None:
        tag = _shift_dose_bucket(target, kind)
        if tag is not None:
            changes.append(tag)

    elif kind == ActionType.DOSE_HOLD and target is not None:
        target.dose_bucket = DoseBucket.HOLD
        changes.append(f"held:{target.drug}")

    elif kind == ActionType.ORDER_MONITORING_AND_WAIT:
        if target is not None:
            target.dose_bucket = DoseBucket.HOLD
            changes.append(f"held_for_monitoring:{target.drug}")
        # Ordering monitoring clears any outstanding review-request conflicts.
        state.unresolved_conflicts = [c for c in state.unresolved_conflicts if not c.startswith("review_requested")]
        changes.append("monitoring_ordered")

    elif kind == ActionType.TAPER_INITIATE and target is not None:
        target.requires_taper = True
        # A falsy taper_days (unset or 0) is reported as the 7-day default.
        changes.append(f"taper_start:{target.drug}:{action.taper_days or 7}d")

    elif kind == ActionType.TAPER_CONTINUE and target is not None:
        target.dose_bucket = DoseBucket.LOW
        changes.append(f"taper_continue:{target.drug}")

    elif kind in {ActionType.REQUEST_SPECIALIST_REVIEW, ActionType.REQUEST_PHARMACIST_REVIEW}:
        state.active_mode = DecisionMode.REVIEW
        state.unresolved_conflicts.append(f"review_requested:{kind.value}")
        changes.append(f"review:{kind.value}")

    elif kind == ActionType.FETCH_EXTERNAL_EVIDENCE:
        text = _gather_evidence_text((action.evidence_query or "").strip(), changes)
        mentions = extract_drug_mentions(text)
        components = extract_components(text)
        state.risk_summary["external_mentions_count"] = float(len(mentions))
        state.risk_summary["external_components_count"] = float(len(components))
        # NOTE(review): missing-data conflicts are cleared even when no text
        # was retrieved — confirm this is intended.
        state.unresolved_conflicts = [item for item in state.unresolved_conflicts if "missing_data" not in item]

    elif kind == ActionType.DECOMPOSE_NEW_DRUG:
        components = _resolve_new_drug_components(action)
        state.risk_summary["new_drug_component_count"] = float(len(components))
        state.risk_summary["new_drug_unknown_risk"] = 0.0 if components else 1.0
        state.unresolved_conflicts = [item for item in state.unresolved_conflicts if "new_drug_unknown" not in item]
        changes.append(f"new_drug_components:{','.join(components) if components else 'none'}")

    state.action_history.append({"step": state.step_count, "action": action.model_dump(mode="json")})
    state.step_count += 1

    # Coarse burden update, recomputed from the (possibly modified) medication list.
    state.burden_score = max(0.0, min(1.0, len(meds) / _BURDEN_SATURATION_COUNT))
    state.risk_summary["polypharmacy_count"] = float(len(meds))
    state.risk_summary["burden_score"] = float(state.burden_score)
    state.risk_summary["severe_pair_count"] = float(len(top_risky_pairs([m.drug for m in meds])))
    delta["state"] = {"step_count": state.step_count, "med_count": len(meds)}
    return delta