File size: 7,201 Bytes
21c7db9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
"""Environment transition dynamics."""

from __future__ import annotations

from pathlib import Path

from app.common.enums import ActionType, DecisionMode, DoseBucket
from app.common.types import PolyGuardAction, PolyGuardState
from app.dataops.parser import extract_components, extract_drug_mentions
from app.dataops.source_manager import SourceManager
from app.dataops.web_fallback import scrape_with_fallback
from app.knowledge.ddi_knowledge import top_risky_pairs


# Per-medication contribution to the regimen burden, keyed by dose bucket.
# MEDIUM counts as one full unit. LOW and HOLD count for less, HIGH for more.
# NA also counts as a full unit (presumably "not applicable" — confirm).
# Buckets missing from this table are treated as 1.0 by callers using ``.get``.
DOSE_BURDEN_WEIGHT: dict[DoseBucket, float] = {
    DoseBucket.LOW: 0.70,
    DoseBucket.MEDIUM: 1.00,
    DoseBucket.HIGH: 1.25,
    DoseBucket.HOLD: 0.45,
    DoseBucket.NA: 1.00,
}


def _find_med_idx(state: PolyGuardState, drug: str | None) -> int | None:
    if not drug:
        return None
    for idx, med in enumerate(state.patient.medications):
        if med.drug == drug:
            return idx
    return None


# Weighted medication burden at which ``burden_score`` saturates at 1.0.
_BURDEN_SATURATION = 12.0


def _fetch_evidence_text(query: str, changes: list[str]) -> str:
    """Resolve an evidence query to text and record how it was obtained.

    A query starting with ``http://`` or ``https://`` (case-insensitive) is
    fetched via ``SourceManager.fetch_with_cache``, restricted to a fixed
    domain allow-list. If that raises, ``scrape_with_fallback`` is tried with
    the same allow-list. Any other query is used verbatim as the evidence text.

    Args:
        query: Stripped evidence query or URL.
        changes: Change log; exactly one provenance tag is appended to it.

    Returns:
        The evidence text (possibly empty).
    """
    allow_domains = ["who.int", "nih.gov", "fda.gov", "ema.europa.eu"]
    # Require a real scheme so free-text queries such as "httpx docs" are not
    # mistaken for URLs.
    if not query.lower().startswith(("http://", "https://")):
        changes.append("evidence_query_recorded")
        return query

    manager = SourceManager(root=Path(__file__).resolve().parents[2])
    try:
        fetched = manager.fetch_with_cache(
            url=query,
            allow_domains=allow_domains,
            namespace="evidence_fetch",
            offline_first=True,
        )
        text = str(fetched.get("text", ""))
    except Exception:
        # Deliberately best-effort: any failure of the cached/direct fetch falls
        # back to the scraper instead of aborting the environment step.
        fallback = scrape_with_fallback(query, allow_domains=allow_domains)
        text = str(fallback.get("text", ""))
        changes.append(f"evidence_fallback:{fallback.get('backend', 'none')}")
        return text
    changes.append("evidence_cached_or_fetched")
    return text


def apply_transition(state: PolyGuardState, action: PolyGuardAction) -> dict[str, object]:
    """Apply ``action`` to ``state`` in place and summarize what changed.

    Mutates ``state`` as follows:
      - sets ``active_mode`` from the action;
      - edits ``patient.medications`` according to ``action.action_type``;
      - may update ``unresolved_conflicts`` and ``risk_summary``;
      - appends to ``action_history`` and increments ``step_count``;
      - recomputes ``burden_score`` plus the derived ``risk_summary`` entries
        (polypharmacy count, burden, severe-pair count).

    Actions that need a target drug have no regimen effect when
    ``action.target_drug`` is not on the current medication list.

    Args:
        state: Environment state to mutate.
        action: Action chosen for this step.

    Returns:
        ``{"applied": True, "changes": [...], "state": {"step_count": int,
        "med_count": int}}``. ``"changes"`` holds one tag per effect that
        actually took place. It is empty when a targeted action could not be
        applied, e.g. an unknown target drug or a dose step past LOW/HIGH.
    """
    # Typed handle on the change log; ``delta["changes"]`` is this same list.
    changes: list[str] = []
    delta: dict[str, object] = {"applied": True, "changes": changes}
    meds = state.patient.medications
    target_idx = _find_med_idx(state, action.target_drug)
    target = meds[target_idx] if target_idx is not None else None
    kind = action.action_type
    state.active_mode = action.mode

    if kind == ActionType.KEEP_REGIMEN:
        changes.append("no_change")

    elif kind == ActionType.STOP_DRUG and target_idx is not None:
        removed = meds.pop(target_idx)
        changes.append(f"stopped:{removed.drug}")

    elif (
        kind in {ActionType.SUBSTITUTE_WITHIN_CLASS, ActionType.RECOMMEND_ALTERNATIVE}
        and target is not None
        and action.replacement_drug
    ):
        # Both actions rename the drug in place (all other medication fields are
        # kept); they differ only in the change tag they record.
        tag = "substituted" if kind == ActionType.SUBSTITUTE_WITHIN_CLASS else "alternative_recommended"
        old = target.drug
        target.drug = action.replacement_drug
        changes.append(f"{tag}:{old}->{action.replacement_drug}")

    elif kind in {ActionType.REDUCE_DOSE_BUCKET, ActionType.INCREASE_DOSE_BUCKET} and target is not None:
        bucket_order = [DoseBucket.LOW, DoseBucket.MEDIUM, DoseBucket.HIGH]
        current = target.dose_bucket
        # HOLD / NA doses are not on the ladder and are left untouched.
        if current in bucket_order:
            step = -1 if kind == ActionType.REDUCE_DOSE_BUCKET else 1
            new_idx = bucket_order.index(current) + step
            # LOW cannot be reduced and HIGH cannot be increased. Such a request
            # is a no-op and is not reported as a dose change.
            if 0 <= new_idx < len(bucket_order):
                target.dose_bucket = bucket_order[new_idx]
                changes.append(f"dose_change:{target.drug}:{current}->{target.dose_bucket}")

    elif kind == ActionType.DOSE_HOLD and target is not None:
        target.dose_bucket = DoseBucket.HOLD
        changes.append(f"held:{target.drug}")

    elif kind == ActionType.ORDER_MONITORING_AND_WAIT:
        if target is not None:
            target.dose_bucket = DoseBucket.HOLD
            changes.append(f"held_for_monitoring:{target.drug}")
        # Ordering monitoring clears any pending review requests.
        state.unresolved_conflicts = [c for c in state.unresolved_conflicts if not c.startswith("review_requested")]
        changes.append("monitoring_ordered")

    elif kind == ActionType.TAPER_INITIATE and target is not None:
        target.requires_taper = True
        changes.append(f"taper_start:{target.drug}:{action.taper_days or 7}d")

    elif kind == ActionType.TAPER_CONTINUE and target is not None:
        target.dose_bucket = DoseBucket.LOW
        changes.append(f"taper_continue:{target.drug}")

    elif kind in {ActionType.REQUEST_SPECIALIST_REVIEW, ActionType.REQUEST_PHARMACIST_REVIEW}:
        state.active_mode = DecisionMode.REVIEW
        state.unresolved_conflicts.append(f"review_requested:{kind.value}")
        changes.append(f"review:{kind.value}")

    elif kind == ActionType.FETCH_EXTERNAL_EVIDENCE:
        text = _fetch_evidence_text((action.evidence_query or "").strip(), changes)
        mentions = extract_drug_mentions(text)
        components = extract_components(text)
        state.risk_summary["external_mentions_count"] = float(len(mentions))
        state.risk_summary["external_components_count"] = float(len(components))
        # NOTE(review): missing-data conflicts are cleared even when no evidence
        # text was retrieved — confirm this is the intended reward shaping.
        state.unresolved_conflicts = [item for item in state.unresolved_conflicts if "missing_data" not in item]

    elif kind == ActionType.DECOMPOSE_NEW_DRUG:
        seed_text = (
            " ".join(action.candidate_components)
            if action.candidate_components
            else f"active ingredients: {(action.new_drug_name or '').replace('_', ' ')}"
        )
        # Prefer parser-extracted components; otherwise trust the non-empty
        # candidates supplied with the action.
        extracted = extract_components(seed_text)
        fallback_components = [token for token in (action.candidate_components or []) if token]
        components = extracted or fallback_components
        state.risk_summary["new_drug_component_count"] = float(len(components))
        state.risk_summary["new_drug_unknown_risk"] = 0.0 if components else 1.0
        state.unresolved_conflicts = [item for item in state.unresolved_conflicts if "new_drug_unknown" not in item]
        changes.append(f"new_drug_components:{','.join(components) if components else 'none'}")

    state.action_history.append({"step": state.step_count, "action": action.model_dump(mode="json")})
    state.step_count += 1

    # Dose-aware burden update so dose optimization has a real reward signal.
    dose_weighted_burden = sum(DOSE_BURDEN_WEIGHT.get(med.dose_bucket, 1.0) for med in meds)
    state.burden_score = max(0.0, min(1.0, dose_weighted_burden / _BURDEN_SATURATION))
    state.risk_summary["polypharmacy_count"] = float(len(meds))
    state.risk_summary["burden_score"] = float(state.burden_score)
    state.risk_summary["severe_pair_count"] = float(len(top_risky_pairs([m.drug for m in meds])))
    delta["state"] = {"step_count": state.step_count, "med_count": len(meds)}
    return delta