"""Environment transition dynamics.""" from __future__ import annotations from pathlib import Path from app.common.enums import ActionType, DecisionMode, DoseBucket from app.common.types import PolyGuardAction, PolyGuardState from app.dataops.parser import extract_components, extract_drug_mentions from app.dataops.source_manager import SourceManager from app.dataops.web_fallback import scrape_with_fallback from app.knowledge.ddi_knowledge import top_risky_pairs DOSE_BURDEN_WEIGHT = { DoseBucket.LOW: 0.7, DoseBucket.MEDIUM: 1.0, DoseBucket.HIGH: 1.25, DoseBucket.HOLD: 0.45, DoseBucket.NA: 1.0, } def _find_med_idx(state: PolyGuardState, drug: str | None) -> int | None: if not drug: return None for idx, med in enumerate(state.patient.medications): if med.drug == drug: return idx return None def apply_transition(state: PolyGuardState, action: PolyGuardAction) -> dict[str, object]: delta: dict[str, object] = {"applied": True, "changes": []} meds = state.patient.medications target_idx = _find_med_idx(state, action.target_drug) state.active_mode = action.mode if action.action_type == ActionType.KEEP_REGIMEN: delta["changes"].append("no_change") elif action.action_type == ActionType.STOP_DRUG and target_idx is not None: removed = meds.pop(target_idx) delta["changes"].append(f"stopped:{removed.drug}") elif action.action_type == ActionType.SUBSTITUTE_WITHIN_CLASS and target_idx is not None and action.replacement_drug: old = meds[target_idx].drug meds[target_idx].drug = action.replacement_drug delta["changes"].append(f"substituted:{old}->{action.replacement_drug}") elif action.action_type == ActionType.RECOMMEND_ALTERNATIVE and target_idx is not None and action.replacement_drug: old = meds[target_idx].drug meds[target_idx].drug = action.replacement_drug delta["changes"].append(f"alternative_recommended:{old}->{action.replacement_drug}") elif action.action_type in {ActionType.REDUCE_DOSE_BUCKET, ActionType.INCREASE_DOSE_BUCKET} and target_idx is not None: bucket_order = [DoseBucket.LOW, DoseBucket.MEDIUM, DoseBucket.HIGH] current = meds[target_idx].dose_bucket if current in bucket_order: cur_idx = bucket_order.index(current) if action.action_type == ActionType.REDUCE_DOSE_BUCKET and cur_idx > 0: meds[target_idx].dose_bucket = bucket_order[cur_idx - 1] if action.action_type == ActionType.INCREASE_DOSE_BUCKET and cur_idx < len(bucket_order) - 1: meds[target_idx].dose_bucket = bucket_order[cur_idx + 1] delta["changes"].append(f"dose_change:{meds[target_idx].drug}:{current}->{meds[target_idx].dose_bucket}") elif action.action_type == ActionType.DOSE_HOLD and target_idx is not None: meds[target_idx].dose_bucket = DoseBucket.HOLD delta["changes"].append(f"held:{meds[target_idx].drug}") elif action.action_type == ActionType.ORDER_MONITORING_AND_WAIT: if target_idx is not None: meds[target_idx].dose_bucket = DoseBucket.HOLD delta["changes"].append(f"held_for_monitoring:{meds[target_idx].drug}") state.unresolved_conflicts = [c for c in state.unresolved_conflicts if not c.startswith("review_requested")] delta["changes"].append("monitoring_ordered") elif action.action_type == ActionType.TAPER_INITIATE and target_idx is not None: meds[target_idx].requires_taper = True delta["changes"].append(f"taper_start:{meds[target_idx].drug}:{action.taper_days or 7}d") elif action.action_type == ActionType.TAPER_CONTINUE and target_idx is not None: meds[target_idx].dose_bucket = DoseBucket.LOW delta["changes"].append(f"taper_continue:{meds[target_idx].drug}") elif action.action_type in {ActionType.REQUEST_SPECIALIST_REVIEW, ActionType.REQUEST_PHARMACIST_REVIEW}: state.active_mode = DecisionMode.REVIEW state.unresolved_conflicts.append(f"review_requested:{action.action_type.value}") delta["changes"].append(f"review:{action.action_type.value}") elif action.action_type == ActionType.FETCH_EXTERNAL_EVIDENCE: text = "" allow_domains = ["who.int", "nih.gov", "fda.gov", "ema.europa.eu"] query = (action.evidence_query or "").strip() if query.startswith("http"): manager = SourceManager(root=Path(__file__).resolve().parents[2]) try: fetched = manager.fetch_with_cache( url=query, allow_domains=allow_domains, namespace="evidence_fetch", offline_first=True, ) text = str(fetched.get("text", "")) delta["changes"].append("evidence_cached_or_fetched") except Exception: fallback = scrape_with_fallback(query, allow_domains=allow_domains) text = str(fallback.get("text", "")) delta["changes"].append(f"evidence_fallback:{fallback.get('backend', 'none')}") else: text = query delta["changes"].append("evidence_query_recorded") mentions = extract_drug_mentions(text) components = extract_components(text) state.risk_summary["external_mentions_count"] = float(len(mentions)) state.risk_summary["external_components_count"] = float(len(components)) state.unresolved_conflicts = [item for item in state.unresolved_conflicts if "missing_data" not in item] elif action.action_type == ActionType.DECOMPOSE_NEW_DRUG: seed_text = ( " ".join(action.candidate_components) if action.candidate_components else f"active ingredients: {(action.new_drug_name or '').replace('_', ' ')}" ) extracted = extract_components(seed_text) fallback_components = [token for token in (action.candidate_components or []) if token] components = extracted or fallback_components state.risk_summary["new_drug_component_count"] = float(len(components)) state.risk_summary["new_drug_unknown_risk"] = 0.0 if components else 1.0 state.unresolved_conflicts = [item for item in state.unresolved_conflicts if "new_drug_unknown" not in item] delta["changes"].append(f"new_drug_components:{','.join(components) if components else 'none'}") state.action_history.append({"step": state.step_count, "action": action.model_dump(mode="json")}) state.step_count += 1 # Dose-aware burden update so dose optimization has a real reward signal. dose_weighted_burden = sum(DOSE_BURDEN_WEIGHT.get(med.dose_bucket, 1.0) for med in meds) state.burden_score = max(0.0, min(1.0, dose_weighted_burden / 12.0)) state.risk_summary["polypharmacy_count"] = float(len(meds)) state.risk_summary["burden_score"] = float(state.burden_score) state.risk_summary["severe_pair_count"] = float(len(top_risky_pairs([m.drug for m in meds]))) delta["state"] = {"step_count": state.step_count, "med_count": len(meds)} return delta