TheJackBright's picture
Deploy GitHub root master to Space
c296d62
"""Environment transition dynamics."""
from __future__ import annotations

import logging
from pathlib import Path

from app.common.enums import ActionType, DecisionMode, DoseBucket
from app.common.types import PolyGuardAction, PolyGuardState
from app.dataops.parser import extract_components, extract_drug_mentions
from app.dataops.source_manager import SourceManager
from app.dataops.web_fallback import scrape_with_fallback
from app.knowledge.ddi_knowledge import top_risky_pairs
# Per-medication contribution to the dose-weighted burden score, keyed by dose bucket.
# Buckets missing from this table are weighted 1.0 by the caller's `.get(..., 1.0)`.
DOSE_BURDEN_WEIGHT: dict[DoseBucket, float] = {
    DoseBucket.LOW: 0.7,  # lighter than the MEDIUM baseline
    DoseBucket.MEDIUM: 1.0,  # baseline: one full unit of burden
    DoseBucket.HIGH: 1.25,  # heavier than baseline
    DoseBucket.HOLD: 0.45,  # held medications still carry some residual weight
    DoseBucket.NA: 1.0,  # same weight as the MEDIUM baseline
}
def _find_med_idx(state: PolyGuardState, drug: str | None) -> int | None:
    """Return the index of the first patient medication whose ``drug`` equals ``drug``.

    Returns ``None`` when ``drug`` is falsy (``None`` or ``""``) or when no
    medication on ``state.patient.medications`` has that name.
    """
    if drug:
        return next(
            (
                position
                for position, medication in enumerate(state.patient.medications)
                if medication.drug == drug
            ),
            None,
        )
    return None
def _step_dose_bucket(med, increase: bool) -> None:
    """Move ``med.dose_bucket`` one rung along LOW -> MEDIUM -> HIGH, clamped at both ends.

    Buckets outside that ladder (e.g. HOLD, NA) are left unchanged.
    """
    ladder = [DoseBucket.LOW, DoseBucket.MEDIUM, DoseBucket.HIGH]
    if med.dose_bucket not in ladder:
        return
    new_idx = ladder.index(med.dose_bucket) + (1 if increase else -1)
    if 0 <= new_idx < len(ladder):
        med.dose_bucket = ladder[new_idx]


def _gather_evidence_text(query: str, changes: list[str]) -> str:
    """Return evidence text for ``query`` and append a tag saying how it was obtained.

    Queries with an ``http://`` / ``https://`` scheme are fetched through
    ``SourceManager.fetch_with_cache`` (offline-first, restricted to an allow-list
    of domains). If that raises, ``scrape_with_fallback`` is used instead; this
    path is deliberately best-effort, so the failure is logged, not re-raised.
    Any other query is returned verbatim as the evidence text.
    """
    allow_domains = ["who.int", "nih.gov", "fda.gov", "ema.europa.eu"]
    # Require a real URL scheme: a bare "http" prefix would misclassify free text such as "httpx ...".
    if not query.lower().startswith(("http://", "https://")):
        changes.append("evidence_query_recorded")
        return query
    manager = SourceManager(root=Path(__file__).resolve().parents[2])
    try:
        fetched = manager.fetch_with_cache(
            url=query,
            allow_domains=allow_domains,
            namespace="evidence_fetch",
            offline_first=True,
        )
        text = str(fetched.get("text", ""))
    except Exception:
        logging.getLogger(__name__).warning(
            "Evidence fetch failed for %s; falling back to scraper", query, exc_info=True
        )
        fallback = scrape_with_fallback(query, allow_domains=allow_domains)
        changes.append(f"evidence_fallback:{fallback.get('backend', 'none')}")
        return str(fallback.get("text", ""))
    changes.append("evidence_cached_or_fetched")
    return text


def _decompose_new_drug(state: PolyGuardState, action: PolyGuardAction, changes: list[str]) -> None:
    """Record the component breakdown of a new drug in ``state.risk_summary``.

    The action's candidate components are parsed if present. Otherwise a synthetic
    "active ingredients: <name>" line is built from ``new_drug_name`` and parsed.
    If the parser finds nothing, the raw non-empty candidates are used as-is.
    """
    candidates = action.candidate_components or []
    if candidates:
        seed_text = " ".join(candidates)
    else:
        seed_text = f"active ingredients: {(action.new_drug_name or '').replace('_', ' ')}"
    components = extract_components(seed_text) or [token for token in candidates if token]
    state.risk_summary["new_drug_component_count"] = float(len(components))
    state.risk_summary["new_drug_unknown_risk"] = 0.0 if components else 1.0
    if components:
        # Only an actual decomposition resolves "new_drug_unknown". With no components the
        # risk above stays at 1.0, so clearing the conflict would contradict it.
        state.unresolved_conflicts = [
            item for item in state.unresolved_conflicts if "new_drug_unknown" not in item
        ]
    changes.append(f"new_drug_components:{','.join(components) if components else 'none'}")


def _refresh_burden_metrics(state: PolyGuardState) -> None:
    """Recompute the dose-weighted burden score and regimen counters on ``state``."""
    meds = state.patient.medications
    # Dose-aware burden so dose optimization has a real reward signal. Dividing by 12
    # means twelve MEDIUM-weight (1.0) medications saturate the score at 1.0.
    dose_weighted_burden = sum(DOSE_BURDEN_WEIGHT.get(med.dose_bucket, 1.0) for med in meds)
    state.burden_score = max(0.0, min(1.0, dose_weighted_burden / 12.0))
    state.risk_summary["polypharmacy_count"] = float(len(meds))
    state.risk_summary["burden_score"] = float(state.burden_score)
    state.risk_summary["severe_pair_count"] = float(len(top_risky_pairs([m.drug for m in meds])))


def apply_transition(state: PolyGuardState, action: PolyGuardAction) -> dict[str, object]:
    """Apply ``action`` to ``state`` in place and return a description of the transition.

    ``state`` is mutated directly. Depending on the action, any of these may change:
    medications, dose buckets, the active decision mode, unresolved conflicts,
    risk-summary entries, the action history, the step counter and the burden score.
    An action whose preconditions are not met (e.g. a ``target_drug`` that is not
    on the medication list) makes no regimen change. It is still appended to the
    history and still advances the step counter.

    Args:
        state: Environment state to update.
        action: Action chosen by the agent.

    Returns:
        ``{"applied": True, "changes": [...], "state": {"step_count": ..., "med_count": ...}}``.
        ``changes`` lists short tags describing each modification, in order.
    """
    changes: list[str] = []
    delta: dict[str, object] = {"applied": True, "changes": changes}
    meds = state.patient.medications
    target_idx = _find_med_idx(state, action.target_drug)
    target = None if target_idx is None else meds[target_idx]
    action_type = action.action_type
    state.active_mode = action.mode

    if action_type == ActionType.KEEP_REGIMEN:
        changes.append("no_change")
    elif action_type == ActionType.STOP_DRUG and target_idx is not None:
        removed = meds.pop(target_idx)
        changes.append(f"stopped:{removed.drug}")
    elif (
        action_type in {ActionType.SUBSTITUTE_WITHIN_CLASS, ActionType.RECOMMEND_ALTERNATIVE}
        and target is not None
        and action.replacement_drug
    ):
        # NOTE(review): RECOMMEND_ALTERNATIVE rewrites the regimen exactly like a substitution.
        # Confirm that a "recommendation" is meant to take effect immediately.
        tag = "substituted" if action_type == ActionType.SUBSTITUTE_WITHIN_CLASS else "alternative_recommended"
        old = target.drug
        target.drug = action.replacement_drug
        changes.append(f"{tag}:{old}->{action.replacement_drug}")
    elif action_type in {ActionType.REDUCE_DOSE_BUCKET, ActionType.INCREASE_DOSE_BUCKET} and target is not None:
        previous = target.dose_bucket
        _step_dose_bucket(target, increase=action_type == ActionType.INCREASE_DOSE_BUCKET)
        # Tagged even when clamped or off-ladder (previous == current in that case).
        changes.append(f"dose_change:{target.drug}:{previous}->{target.dose_bucket}")
    elif action_type == ActionType.DOSE_HOLD and target is not None:
        target.dose_bucket = DoseBucket.HOLD
        changes.append(f"held:{target.drug}")
    elif action_type == ActionType.ORDER_MONITORING_AND_WAIT:
        if target is not None:
            target.dose_bucket = DoseBucket.HOLD
            changes.append(f"held_for_monitoring:{target.drug}")
        # Clears the "review_requested:*" conflicts added by the review actions below.
        state.unresolved_conflicts = [
            c for c in state.unresolved_conflicts if not c.startswith("review_requested")
        ]
        changes.append("monitoring_ordered")
    elif action_type == ActionType.TAPER_INITIATE and target is not None:
        target.requires_taper = True
        changes.append(f"taper_start:{target.drug}:{action.taper_days or 7}d")
    elif action_type == ActionType.TAPER_CONTINUE and target is not None:
        target.dose_bucket = DoseBucket.LOW
        changes.append(f"taper_continue:{target.drug}")
    elif action_type in {ActionType.REQUEST_SPECIALIST_REVIEW, ActionType.REQUEST_PHARMACIST_REVIEW}:
        state.active_mode = DecisionMode.REVIEW
        state.unresolved_conflicts.append(f"review_requested:{action_type.value}")
        changes.append(f"review:{action_type.value}")
    elif action_type == ActionType.FETCH_EXTERNAL_EVIDENCE:
        text = _gather_evidence_text((action.evidence_query or "").strip(), changes)
        state.risk_summary["external_mentions_count"] = float(len(extract_drug_mentions(text)))
        state.risk_summary["external_components_count"] = float(len(extract_components(text)))
        # NOTE(review): missing-data conflicts are cleared even when no evidence text was
        # obtained (empty query or empty fetch result). Confirm this is intended.
        state.unresolved_conflicts = [
            item for item in state.unresolved_conflicts if "missing_data" not in item
        ]
    elif action_type == ActionType.DECOMPOSE_NEW_DRUG:
        _decompose_new_drug(state, action, changes)

    state.action_history.append({"step": state.step_count, "action": action.model_dump(mode="json")})
    state.step_count += 1
    _refresh_burden_metrics(state)
    delta["state"] = {"step_count": state.step_count, "med_count": len(meds)}
    return delta