Spaces:
Running
Running
| """Environment transition dynamics.""" | |
| from __future__ import annotations | |
| from pathlib import Path | |
| from app.common.enums import ActionType, DecisionMode, DoseBucket | |
| from app.common.types import PolyGuardAction, PolyGuardState | |
| from app.dataops.parser import extract_components, extract_drug_mentions | |
| from app.dataops.source_manager import SourceManager | |
| from app.dataops.web_fallback import scrape_with_fallback | |
| from app.knowledge.ddi_knowledge import top_risky_pairs | |
def _find_med_idx(state: PolyGuardState, drug: str | None) -> int | None:
    """Locate ``drug`` in the patient's medication list.

    Args:
        state: Environment state whose ``patient.medications`` is searched.
        drug: Drug name to match exactly against each medication's ``drug``.

    Returns:
        The position of the first medication whose ``drug`` equals ``drug``, or
        ``None`` when ``drug`` is empty/``None`` or nothing matches. The
        medication list is not touched at all when ``drug`` is falsy.
    """
    if drug:
        return next(
            (pos for pos, entry in enumerate(state.patient.medications) if entry.drug == drug),
            None,
        )
    return None
def _fetch_evidence_text(query: str, changes: list[str]) -> str:
    """Resolve an evidence query to text, recording how it was obtained in ``changes``.

    URL queries (``http://`` / ``https://``) are fetched through ``SourceManager``
    (offline-first cache, restricted to an allow-list of domains). If that raises,
    ``scrape_with_fallback`` is tried with the same allow-list. Any other query is
    used verbatim as the evidence text.

    Args:
        query: Stripped evidence query or URL.
        changes: Change log; exactly one tag describing the source is appended.

    Returns:
        The evidence text. A missing/``None`` payload text becomes ``""``.
    """
    allow_domains = ["who.int", "nih.gov", "fda.gov", "ema.europa.eu"]
    # Require a real scheme so plain-text queries that merely start with "http"
    # (e.g. "httpx ...") are not treated as URLs.
    if not query.startswith(("http://", "https://")):
        changes.append("evidence_query_recorded")
        return query
    manager = SourceManager(root=Path(__file__).resolve().parents[2])
    try:
        fetched = manager.fetch_with_cache(
            url=query,
            allow_domains=allow_domains,
            namespace="evidence_fetch",
            offline_first=True,
        )
        # `or ""` keeps a None payload from becoming the literal text "None".
        text = str(fetched.get("text") or "")
    except Exception:
        # Best-effort: any failure of the cached fetch falls back to scraping.
        fallback = scrape_with_fallback(query, allow_domains=allow_domains)
        changes.append(f"evidence_fallback:{fallback.get('backend', 'none')}")
        return str(fallback.get("text") or "")
    changes.append("evidence_cached_or_fetched")
    return text


def apply_transition(state: PolyGuardState, action: PolyGuardAction) -> dict[str, object]:
    """Apply one agent action to ``state`` in place and summarise the effect.

    The patient's medication list, ``active_mode``, ``unresolved_conflicts`` and
    ``risk_summary`` are mutated directly. Regardless of the action, the step is
    appended to ``action_history``, ``step_count`` is incremented, and the
    polypharmacy/burden/severe-pair metrics are recomputed from the resulting
    medication list.

    Args:
        state: Environment state to mutate.
        action: The action chosen for this step.

    Returns:
        A dict with keys:

        - ``"applied"``: ``False`` when the action left the regimen untouched
          because its target drug is not in the medication list, a required
          ``replacement_drug`` was missing, or the action type has no rule here.
          Otherwise ``True``.
        - ``"changes"``: list of change tags (e.g. ``"stopped:<drug>"``).
        - ``"state"``: ``{"step_count": ..., "med_count": ...}`` after the step.
    """
    changes: list[str] = []
    applied = True
    meds = state.patient.medications
    target_idx = _find_med_idx(state, action.target_drug)
    target = meds[target_idx] if target_idx is not None else None
    kind = action.action_type
    state.active_mode = action.mode

    if kind == ActionType.KEEP_REGIMEN:
        changes.append("no_change")
    elif kind == ActionType.STOP_DRUG and target_idx is not None:
        removed = meds.pop(target_idx)
        changes.append(f"stopped:{removed.drug}")
    elif kind == ActionType.SUBSTITUTE_WITHIN_CLASS and target is not None and action.replacement_drug:
        old = target.drug
        target.drug = action.replacement_drug
        changes.append(f"substituted:{old}->{action.replacement_drug}")
    elif kind == ActionType.RECOMMEND_ALTERNATIVE and target is not None and action.replacement_drug:
        old = target.drug
        target.drug = action.replacement_drug
        changes.append(f"alternative_recommended:{old}->{action.replacement_drug}")
    elif kind in {ActionType.REDUCE_DOSE_BUCKET, ActionType.INCREASE_DOSE_BUCKET} and target is not None:
        bucket_order = [DoseBucket.LOW, DoseBucket.MEDIUM, DoseBucket.HIGH]
        current = target.dose_bucket
        if current in bucket_order:
            shift = -1 if kind == ActionType.REDUCE_DOSE_BUCKET else 1
            new_idx = bucket_order.index(current) + shift
            if 0 <= new_idx < len(bucket_order):
                target.dose_bucket = bucket_order[new_idx]
        # Buckets outside the LOW/MEDIUM/HIGH ladder (e.g. HOLD) and moves past
        # either end leave the bucket unchanged, but the step is still logged.
        changes.append(f"dose_change:{target.drug}:{current}->{target.dose_bucket}")
    elif kind == ActionType.DOSE_HOLD and target is not None:
        target.dose_bucket = DoseBucket.HOLD
        changes.append(f"held:{target.drug}")
    elif kind == ActionType.ORDER_MONITORING_AND_WAIT:
        if target is not None:
            target.dose_bucket = DoseBucket.HOLD
            changes.append(f"held_for_monitoring:{target.drug}")
        # Clears the "review_requested:*" conflicts that the review branch adds.
        state.unresolved_conflicts = [c for c in state.unresolved_conflicts if not c.startswith("review_requested")]
        changes.append("monitoring_ordered")
    elif kind == ActionType.TAPER_INITIATE and target is not None:
        target.requires_taper = True
        changes.append(f"taper_start:{target.drug}:{action.taper_days or 7}d")
    elif kind == ActionType.TAPER_CONTINUE and target is not None:
        target.dose_bucket = DoseBucket.LOW
        changes.append(f"taper_continue:{target.drug}")
    elif kind in {ActionType.REQUEST_SPECIALIST_REVIEW, ActionType.REQUEST_PHARMACIST_REVIEW}:
        state.active_mode = DecisionMode.REVIEW  # overrides the mode set from action.mode
        state.unresolved_conflicts.append(f"review_requested:{kind.value}")
        changes.append(f"review:{kind.value}")
    elif kind == ActionType.FETCH_EXTERNAL_EVIDENCE:
        text = _fetch_evidence_text((action.evidence_query or "").strip(), changes)
        mentions = extract_drug_mentions(text)
        components = extract_components(text)
        state.risk_summary["external_mentions_count"] = float(len(mentions))
        state.risk_summary["external_components_count"] = float(len(components))
        state.unresolved_conflicts = [item for item in state.unresolved_conflicts if "missing_data" not in item]
    elif kind == ActionType.DECOMPOSE_NEW_DRUG:
        seed_text = (
            " ".join(action.candidate_components)
            if action.candidate_components
            else f"active ingredients: {(action.new_drug_name or '').replace('_', ' ')}"
        )
        extracted = extract_components(seed_text)
        # If extraction finds nothing, fall back to the non-empty caller-supplied tokens.
        fallback_components = [token for token in (action.candidate_components or []) if token]
        components = extracted or fallback_components
        state.risk_summary["new_drug_component_count"] = float(len(components))
        state.risk_summary["new_drug_unknown_risk"] = 0.0 if components else 1.0
        state.unresolved_conflicts = [item for item in state.unresolved_conflicts if "new_drug_unknown" not in item]
        changes.append(f"new_drug_components:{','.join(components) if components else 'none'}")
    else:
        # Targeted action whose drug is absent (or whose replacement is missing),
        # or an action type with no transition rule: the regimen is untouched.
        applied = False

    state.action_history.append({"step": state.step_count, "action": action.model_dump(mode="json")})
    state.step_count += 1

    # Coarse burden update: 12 or more medications saturates the score at 1.0.
    # NOTE(review): HOLD-bucket medications are still in `meds`, so they count
    # toward polypharmacy and severe-pair totals; confirm this is intended.
    state.burden_score = min(1.0, len(meds) / 12.0)
    state.risk_summary["polypharmacy_count"] = float(len(meds))
    state.risk_summary["burden_score"] = float(state.burden_score)
    state.risk_summary["severe_pair_count"] = float(len(top_risky_pairs([m.drug for m in meds])))

    delta: dict[str, object] = {
        "applied": applied,
        "changes": changes,
        "state": {"step_count": state.step_count, "med_count": len(meds)},
    }
    return delta