File size: 9,884 Bytes
# mediagent/agents/intake.py
"""
Intake Agent for MediAgent.
Validates patient submissions, normalizes clinical terminology,
extracts demographics, detects imaging modality hints, and flags
urgent safety concerns before routing to downstream agents.
"""
import logging
import re
from typing import Any, Dict, List, Optional
from core.llm import LLMClient
from core.models import IntakeOutput, ImageModality, PatientInput
logger = logging.getLogger(__name__)
class IntakeAgent:
"""
First-stage pipeline agent responsible for input validation,
clinical text normalization, demographic extraction, and safety triage.
Ensures downstream agents receive structured, standardized input.
"""
# Phrases that raise an urgent flag as soon as they appear in the raw text.
# Checked deterministically (plain substring test), in the order listed here.
SAFETY_KEYWORDS = [
    "acute trauma",
    "chest pain",
    "shortness of breath",
    "dyspnea",
    "stroke symptoms",
    "neurological deficit",
    "hemoptysis",
    "massive bleed",
    "pediatric emergency",
    "pregnant",
    "anaphylaxis",
    "sepsis",
    "fever",
    "head injury",
    "spinal trauma",
    "acute abdomen",
    "suspected fracture",
]
# Text hints mapped to the imaging modality they imply. Insertion order is
# significant: the first keyword found in the text decides the modality.
MODALITY_KEYWORDS = {
    **dict.fromkeys(("x-ray", "xr", "radiograph"), ImageModality.XRAY),
    **dict.fromkeys(("ct scan", "ct", "computed tomography"), ImageModality.CT),
    **dict.fromkeys(("mri", "magnetic resonance"), ImageModality.MRI),
    # Technically X-ray based; kept last so it is only reached when no
    # earlier keyword matched.
    "mammogram": ImageModality.XRAY,
}
def __init__(self, llm_client: Optional[LLMClient] = None):
self.llm = llm_client or LLMClient()
def process(self, patient_input: PatientInput) -> IntakeOutput:
    """
    Run the full intake pipeline for one patient submission.

    Steps: validate the image payload (a failure is only logged), collect
    deterministic safety flags, normalize the clinical text, infer the
    imaging modality, then assemble the ``IntakeOutput``.

    Args:
        patient_input: Raw validated patient submission

    Returns:
        IntakeOutput: Structured, normalized, safety-checked data. If
        assembling the output raises, the result of
        ``_get_fallback_output`` is returned instead.
    """
    logger.info("📋 Intake Agent processing initiated")

    # 1. Validate image payload
    # NOTE(review): a failed check is logged but not reflected in the
    # output (``validated`` stays True) — confirm this is intended.
    if not self._validate_image_payload(patient_input.image_base64):
        logger.warning("⚠️ Image payload validation failed. Proceeding with warnings.")

    # 2. Apply deterministic safety triage
    safety_flags = self._check_deterministic_safety(patient_input)

    # 3. Clinical normalization & demographic extraction via LLM
    structured_data = self._normalize_with_llm(patient_input, safety_flags)

    # 4. Enrich modality detection
    modality = self._infer_modality(patient_input, structured_data)

    # 5. Assemble & validate output
    try:
        # The normalizer may hand back null / a bare string for this field.
        extra_flags = structured_data.get("safety_flags") or []
        if isinstance(extra_flags, str):
            extra_flags = [extra_flags]
        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # flag list is identical from run to run (a set would not be).
        merged_flags = list(dict.fromkeys([*safety_flags, *extra_flags]))

        output = IntakeOutput(
            validated=True,
            # ``or`` also covers keys that are present but null/empty.
            standardized_symptoms=(
                structured_data.get("standardized_symptoms")
                or patient_input.symptoms
                or ""
            ),
            extracted_demographics=structured_data.get("extracted_demographics") or {},
            safety_flags=merged_flags,
            recommended_modality=modality,
            processing_notes=structured_data.get("processing_notes") or "",
        )
    except Exception as e:
        # Broad on purpose: any failure while building the output must
        # degrade to the fallback object rather than abort the pipeline.
        logger.error("💥 IntakeOutput validation failed: %s", e)
        return self._get_fallback_output(patient_input, safety_flags)

    logger.info("✅ Intake Agent completed successfully")
    return output
def _validate_image_payload(self, base64_data: str) -> bool:
"""Validate base64 image integrity and size constraints."""
if not base64_data or len(base64_data) < 500:
return False
# Check for valid base64 pattern (ignoring data URI prefix)
clean = re.sub(r"^data:image/[a-z]+;base64,", "", base64_data)
try:
import base64
base64.b64decode(clean)
return len(clean) < 20 * 1024 * 1024 # < 20MB limit
except Exception:
return False
def _check_deterministic_safety(self, inp: PatientInput) -> List[str]:
"""Scan raw input for high-priority clinical safety terms."""
text = f"{inp.symptoms} {inp.clinical_context}".lower()
flags = []
for kw in self.SAFETY_KEYWORDS:
if kw.lower() in text:
flags.append(f"URGENT_TERM_DETECTED: {kw}")
if inp.age is not None and inp.age < 18:
flags.append("PATIENT_AGE: PEDIATRIC_REQUIRES_EXPERT_REVIEW")
if inp.age is not None and inp.age > 75:
flags.append("PATIENT_AGE: GERIATRIC_CONSIDERATIONS_RECOMMENDED")
return flags
# Everyday phrasing -> clinical term, used by the fast deterministic
# normalization path (no LLM call).
LAYMAN_TERMS = {
    # Breathing
    "can't breathe": "dyspnea",
    "hard to breathe": "dyspnea",
    "difficulty breathing": "dyspnea",
    # Abdomen
    "stomach pain": "abdominal pain",
    "belly pain": "abdominal pain",
    "tummy pain": "abdominal pain",
    # Chest / heart
    "chest tightness": "chest pain/pressure",
    "heart racing": "palpitations",
    # Vision
    "blurry vision": "visual disturbance",
    "can't see clearly": "visual disturbance",
    # Balance / consciousness
    "dizzy": "dizziness/vertigo",
    "feel faint": "presyncope",
    "passed out": "syncope",
    # Gastrointestinal
    "throwing up": "vomiting",
    "nausea and vomiting": "nausea/emesis",
    # Regional pain
    "back pain": "dorsal pain",
    "leg pain": "lower extremity pain",
    "arm pain": "upper extremity pain",
    "neck pain": "cervicalgia",
    "headache": "cephalgia",
    "head pain": "cephalgia",
    # Swelling / skin / masses
    "swollen": "edema",
    "swelling": "edema",
    "bruise": "ecchymosis",
    "lump": "mass/nodule",
    "bump": "mass/nodule",
}
def _normalize_with_llm(self, inp: PatientInput, existing_flags: List[str]) -> Dict[str, Any]:
    """
    Normalize clinical text. Uses fast deterministic mapping for simple inputs;
    falls back to LLM only for complex or lengthy clinical context.

    Args:
        inp: Patient submission to normalize.
        existing_flags: Safety flags already raised by deterministic triage;
            included in the prompt and passed on to the deterministic paths.

    Returns:
        A dict with the keys requested in the prompt when the LLM answers
        with a JSON object; otherwise the result of ``_fast_normalize``
        (short, simple input) or ``_build_fallback_dict`` (LLM failure).
    """
    combined_text = f"{inp.symptoms or ''} {inp.clinical_context or ''}".strip()
    lowered = combined_text.lower()
    # Phrases suggesting history/medication detail the term map cannot handle.
    complexity_markers = (
        "history of", "diagnosed with", "chronic", "prior",
        "previous", "medication", "allerg",
    )

    # Skip LLM for short/simple inputs — deterministic normalization is sufficient
    if len(combined_text) <= 120 and not any(marker in lowered for marker in complexity_markers):
        logger.debug("⚡ Short input detected — using fast deterministic normalization (skipping LLM)")
        return self._fast_normalize(inp, existing_flags)

    # NOTE(review): patient-supplied text is interpolated verbatim, so a
    # double quote or instruction-like text in the input reaches the model
    # unescaped. Treat the model's answer as untrusted.
    prompt = f"""You are a clinical data standardization expert.
Convert raw patient input to standardized clinical terminology. Respond ONLY with JSON:
{{"standardized_symptoms":"string","extracted_demographics":{{"age":int|null,"sex":"M|F|O"|null,"comorbidities":["string"]}},"safety_flags":["string"],"processing_notes":"string"}}
Input:
- Symptoms: "{inp.symptoms or 'Not provided'}"
- Age: {inp.age}
- Sex: {inp.sex}
- Clinical Context: "{inp.clinical_context or 'Not provided'}"
- Existing Flags: {existing_flags}
Rules: convert layman terms to medical terminology; extract comorbidities; add safety flags; no markdown."""

    parsed = None
    try:
        result = self.llm.generate_text(prompt=prompt, force_json=True)
        if isinstance(result, dict) and result.get("success") and result.get("content"):
            parsed = LLMClient.extract_json_from_response(result["content"])
    except Exception:
        # A failing client or extractor must degrade to the deterministic
        # fallback below instead of aborting intake; keep the traceback.
        logger.exception("LLM normalization raised an exception")

    # Callers use ``.get`` on the result, so only a non-empty JSON object is
    # acceptable (a JSON array or scalar would break them).
    if isinstance(parsed, dict) and parsed:
        return parsed

    logger.warning("⚠️ LLM normalization failed. Using deterministic fallback.")
    return self._build_fallback_dict(inp, existing_flags)
def _fast_normalize(self, inp: PatientInput, flags: List[str]) -> Dict[str, Any]:
"""Deterministic normalization using term mapping — zero LLM calls."""
text = f"{inp.symptoms or ''} {inp.clinical_context or ''}".lower()
normalized = inp.symptoms or "No symptoms provided"
for layman, medical in self.LAYMAN_TERMS.items():
if layman in text:
normalized = normalized.lower().replace(layman, medical)
return {
"standardized_symptoms": normalized.strip(),
"extracted_demographics": {
"age": inp.age,
"sex": inp.sex,
"comorbidities": []
},
"safety_flags": flags,
"processing_notes": "Fast deterministic normalization applied."
}
def _infer_modality(self, inp: PatientInput, llm_data: Dict[str, Any]) -> ImageModality:
"""Infer imaging modality from text hints or default to UNKNOWN."""
text = f"{inp.symptoms} {inp.clinical_context}".lower()
for kw, mod in self.MODALITY_KEYWORDS.items():
if kw in text:
return mod
return ImageModality.UNKNOWN
def _build_fallback_dict(self, inp: PatientInput, flags: List[str]) -> Dict[str, Any]:
"""Deterministic fallback when LLM is unavailable."""
return {
"standardized_symptoms": inp.symptoms or "No symptoms provided",
"extracted_demographics": {
"age": inp.age,
"sex": inp.sex,
"comorbidities": []
},
"safety_flags": flags,
"processing_notes": "LLM normalization unavailable. Raw input preserved."
}
def _get_fallback_output(self, inp: PatientInput, flags: List[str]) -> IntakeOutput:
    """
    Build the degraded ``IntakeOutput`` returned after a critical failure.

    Args:
        inp: Patient submission; ``symptoms``, ``age`` and ``sex`` are read.
        flags: Safety flags collected so far; the list is not modified.

    Returns:
        An ``IntakeOutput`` with ``validated=False``, the raw symptom text,
        the supplied demographics, the given flags plus a fallback marker,
        and an unknown modality.
    """
    demographics = {"age": inp.age, "sex": inp.sex, "comorbidities": []}
    # Concatenation creates a new list, leaving the caller's flags untouched.
    flags_with_marker = flags + ["INTAKE_AGENT_FALLBACK_MODE"]
    note = "Intake agent encountered critical validation failure. Pipeline continues with degraded state."

    return IntakeOutput(
        validated=False,
        standardized_symptoms=inp.symptoms or "",
        extracted_demographics=demographics,
        safety_flags=flags_with_marker,
        recommended_modality=ImageModality.UNKNOWN,
        processing_notes=note,
    )
|