| """Environment transition dynamics.""" |
|
|
| from __future__ import annotations |
|
|
| from pathlib import Path |
|
|
| from app.common.enums import ActionType, DecisionMode, DoseBucket |
| from app.common.types import PolyGuardAction, PolyGuardState |
| from app.dataops.parser import extract_components, extract_drug_mentions |
| from app.dataops.source_manager import SourceManager |
| from app.dataops.web_fallback import scrape_with_fallback |
| from app.knowledge.ddi_knowledge import top_risky_pairs |
|
|
|
|
# Per-dose-bucket multiplier for a medication's share of the regimen burden:
# HIGH weighs most, HOLD least, and MEDIUM / NA count at the neutral 1.0.
DOSE_BURDEN_WEIGHT: dict[DoseBucket, float] = {
    DoseBucket.LOW: 0.7,
    DoseBucket.MEDIUM: 1.0,
    DoseBucket.HIGH: 1.25,
    DoseBucket.HOLD: 0.45,
    DoseBucket.NA: 1.0,
}
|
|
|
|
def _find_med_idx(state: PolyGuardState, drug: str | None) -> int | None:
    """Locate ``drug`` in the patient's medication list.

    Returns the index of the first medication whose ``drug`` equals ``drug``,
    or ``None`` if there is no such medication. A falsy ``drug`` (``None`` or
    ``""``) short-circuits to ``None`` without scanning the list.
    """
    if drug:
        positions = (
            position
            for position, medication in enumerate(state.patient.medications)
            if medication.drug == drug
        )
        return next(positions, None)
    return None
|
|
|
|
def _shift_dose_bucket(state: PolyGuardState, idx: int, *, increase: bool) -> str | None:
    """Move medication ``idx`` one step along the LOW -> MEDIUM -> HIGH dose ladder.

    A medication already at the end of the ladder in the requested direction
    keeps its bucket, but the (unchanged) transition is still reported.

    Returns:
        The ``dose_change:<drug>:<old>-><new>`` tag, or ``None`` when the current
        bucket is not on the ladder (e.g. HOLD or NA) and nothing was attempted.
    """
    med = state.patient.medications[idx]
    ladder = [DoseBucket.LOW, DoseBucket.MEDIUM, DoseBucket.HIGH]
    current = med.dose_bucket
    if current not in ladder:
        return None
    pos = ladder.index(current)
    if increase and pos < len(ladder) - 1:
        med.dose_bucket = ladder[pos + 1]
    elif not increase and pos > 0:
        med.dose_bucket = ladder[pos - 1]
    return f"dose_change:{med.drug}:{current}->{med.dose_bucket}"


def _fetch_external_evidence(state: PolyGuardState, action: PolyGuardAction) -> str:
    """Resolve ``action.evidence_query`` to text and record what the text mentions.

    A query starting with ``"http"`` is fetched through ``SourceManager``
    (offline-first, restricted to an allow-list of domains). If that raises,
    ``scrape_with_fallback`` is tried instead. Any other query is used verbatim
    as the evidence text.

    Side effects: sets ``external_mentions_count`` / ``external_components_count``
    in ``state.risk_summary`` and drops every unresolved conflict containing
    ``"missing_data"``.

    Returns:
        The change tag describing how the evidence was obtained.
    """
    allow_domains = ["who.int", "nih.gov", "fda.gov", "ema.europa.eu"]
    query = (action.evidence_query or "").strip()
    if query.startswith("http"):
        manager = SourceManager(root=Path(__file__).resolve().parents[2])
        try:
            fetched = manager.fetch_with_cache(
                url=query,
                allow_domains=allow_domains,
                namespace="evidence_fetch",
                offline_first=True,
            )
            text = str(fetched.get("text", ""))
            tag = "evidence_cached_or_fetched"
        except Exception:
            # Deliberate best effort: any failure in the cached fetch (or in reading
            # its result) falls back to scraping; the backend used is recorded in the tag.
            fallback = scrape_with_fallback(query, allow_domains=allow_domains)
            text = str(fallback.get("text", ""))
            tag = f"evidence_fallback:{fallback.get('backend', 'none')}"
    else:
        text = query
        tag = "evidence_query_recorded"

    mentions = extract_drug_mentions(text)
    components = extract_components(text)
    state.risk_summary["external_mentions_count"] = float(len(mentions))
    state.risk_summary["external_components_count"] = float(len(components))
    state.unresolved_conflicts = [item for item in state.unresolved_conflicts if "missing_data" not in item]
    return tag


def _decompose_new_drug(state: PolyGuardState, action: PolyGuardAction) -> str:
    """Extract the components of a new drug and record whether its risk is unknown.

    The seed text is the space-joined ``candidate_components`` when given, or
    otherwise an "active ingredients: ..." phrase built from ``new_drug_name``.
    If extraction yields nothing, the non-empty candidate tokens are used as-is.

    Side effects: sets ``new_drug_component_count`` and ``new_drug_unknown_risk``
    (0.0 when at least one component was found, else 1.0) in ``state.risk_summary``
    and drops every unresolved conflict containing ``"new_drug_unknown"``.

    Returns:
        The ``new_drug_components:...`` change tag (``none`` when empty).
    """
    seed_text = (
        " ".join(action.candidate_components)
        if action.candidate_components
        else f"active ingredients: {(action.new_drug_name or '').replace('_', ' ')}"
    )
    extracted = extract_components(seed_text)
    fallback_components = [token for token in (action.candidate_components or []) if token]
    components = extracted or fallback_components
    state.risk_summary["new_drug_component_count"] = float(len(components))
    state.risk_summary["new_drug_unknown_risk"] = 0.0 if components else 1.0
    state.unresolved_conflicts = [item for item in state.unresolved_conflicts if "new_drug_unknown" not in item]
    return f"new_drug_components:{','.join(components) if components else 'none'}"


def _refresh_regimen_metrics(state: PolyGuardState) -> None:
    """Recompute burden and regimen-level risk metrics from the current medication list."""
    meds = state.patient.medications
    dose_weighted_burden = sum(DOSE_BURDEN_WEIGHT.get(med.dose_bucket, 1.0) for med in meds)
    # Dividing by 12.0 and clamping maps a weighted burden of 12 or more to the max score 1.0.
    state.burden_score = max(0.0, min(1.0, dose_weighted_burden / 12.0))
    state.risk_summary["polypharmacy_count"] = float(len(meds))
    state.risk_summary["burden_score"] = float(state.burden_score)
    state.risk_summary["severe_pair_count"] = float(len(top_risky_pairs([m.drug for m in meds])))


def apply_transition(state: PolyGuardState, action: PolyGuardAction) -> dict[str, object]:
    """Apply ``action`` to ``state`` in place and describe what happened.

    Any of these may be mutated: the patient's medication list,
    ``state.active_mode``, ``state.unresolved_conflicts``, ``state.risk_summary``,
    ``state.burden_score``, ``state.action_history`` and ``state.step_count``.

    A medication-level action whose ``target_drug`` is not on the medication list
    (or a substitution without a ``replacement_drug``) leaves the medications
    untouched. The step is still appended to the history, the step counter still
    advances and the regimen metrics are still recomputed.

    Args:
        state: Environment state; mutated in place.
        action: The action to apply.

    Returns:
        A dict with ``"applied"`` (always ``True``), ``"changes"`` (list of change
        tags, possibly empty) and ``"state"`` (new ``step_count`` and ``med_count``).
    """
    # ``delta`` is typed ``dict[str, object]``, so appending through
    # ``delta["changes"]`` does not type-check; keep a typed handle on the same list.
    changes: list[str] = []
    delta: dict[str, object] = {"applied": True, "changes": changes}
    meds = state.patient.medications
    target_idx = _find_med_idx(state, action.target_drug)
    kind = action.action_type
    state.active_mode = action.mode

    if kind == ActionType.KEEP_REGIMEN:
        changes.append("no_change")

    elif kind == ActionType.STOP_DRUG and target_idx is not None:
        removed = meds.pop(target_idx)
        changes.append(f"stopped:{removed.drug}")

    elif kind == ActionType.SUBSTITUTE_WITHIN_CLASS and target_idx is not None and action.replacement_drug:
        old = meds[target_idx].drug
        meds[target_idx].drug = action.replacement_drug
        changes.append(f"substituted:{old}->{action.replacement_drug}")

    elif kind == ActionType.RECOMMEND_ALTERNATIVE and target_idx is not None and action.replacement_drug:
        old = meds[target_idx].drug
        meds[target_idx].drug = action.replacement_drug
        changes.append(f"alternative_recommended:{old}->{action.replacement_drug}")

    elif kind in {ActionType.REDUCE_DOSE_BUCKET, ActionType.INCREASE_DOSE_BUCKET} and target_idx is not None:
        change = _shift_dose_bucket(state, target_idx, increase=kind == ActionType.INCREASE_DOSE_BUCKET)
        if change is not None:
            changes.append(change)

    elif kind == ActionType.DOSE_HOLD and target_idx is not None:
        meds[target_idx].dose_bucket = DoseBucket.HOLD
        changes.append(f"held:{meds[target_idx].drug}")

    elif kind == ActionType.ORDER_MONITORING_AND_WAIT:
        if target_idx is not None:
            meds[target_idx].dose_bucket = DoseBucket.HOLD
            changes.append(f"held_for_monitoring:{meds[target_idx].drug}")
        # Ordering monitoring clears any pending review requests.
        state.unresolved_conflicts = [c for c in state.unresolved_conflicts if not c.startswith("review_requested")]
        changes.append("monitoring_ordered")

    elif kind == ActionType.TAPER_INITIATE and target_idx is not None:
        meds[target_idx].requires_taper = True
        changes.append(f"taper_start:{meds[target_idx].drug}:{action.taper_days or 7}d")

    elif kind == ActionType.TAPER_CONTINUE and target_idx is not None:
        meds[target_idx].dose_bucket = DoseBucket.LOW
        changes.append(f"taper_continue:{meds[target_idx].drug}")

    elif kind in {ActionType.REQUEST_SPECIALIST_REVIEW, ActionType.REQUEST_PHARMACIST_REVIEW}:
        state.active_mode = DecisionMode.REVIEW
        state.unresolved_conflicts.append(f"review_requested:{kind.value}")
        changes.append(f"review:{kind.value}")

    elif kind == ActionType.FETCH_EXTERNAL_EVIDENCE:
        changes.append(_fetch_external_evidence(state, action))

    elif kind == ActionType.DECOMPOSE_NEW_DRUG:
        changes.append(_decompose_new_drug(state, action))

    state.action_history.append({"step": state.step_count, "action": action.model_dump(mode="json")})
    state.step_count += 1

    _refresh_regimen_metrics(state)
    delta["state"] = {"step_count": state.step_count, "med_count": len(meds)}
    return delta
|
|