| |
| """ |
| Intake Agent for MediAgent. |
| Validates patient submissions, normalizes clinical terminology, |
| extracts demographics, detects imaging modality hints, and flags |
| urgent safety concerns before routing to downstream agents. |
| """ |
|
|
| import logging |
| import re |
| from typing import Any, Dict, List, Optional |
|
|
| from core.llm import LLMClient |
| from core.models import IntakeOutput, ImageModality, PatientInput |
|
|
# Module-scoped logger, keyed by this module's import name.
logger = logging.getLogger(name=__name__)
|
|
|
|
class IntakeAgent:
    """
    First-stage pipeline agent responsible for input validation,
    clinical text normalization, demographic extraction, and safety triage.
    Ensures downstream agents receive structured, standardized input.
    """

    # Phrases that produce an "URGENT_TERM_DETECTED: <phrase>" flag when found
    # in the symptoms or clinical context. Matching is a case-insensitive plain
    # substring test, so negated mentions (e.g. "no fever") also fire --
    # presumably acceptable since over-flagging is the safer failure; confirm.
    SAFETY_KEYWORDS = [
        "acute trauma", "chest pain", "shortness of breath", "dyspnea",
        "stroke symptoms", "neurological deficit", "hemoptysis", "massive bleed",
        "pediatric emergency", "pregnant", "anaphylaxis", "sepsis", "fever",
        "head injury", "spinal trauma", "acute abdomen", "suspected fracture",
    ]

    # Text hint -> imaging modality. Hints are matched as whole words, in
    # insertion order; the first match wins (see _infer_modality).
    MODALITY_KEYWORDS = {
        "x-ray": ImageModality.XRAY,
        "xr": ImageModality.XRAY,
        "radiograph": ImageModality.XRAY,
        "ct scan": ImageModality.CT,
        "ct": ImageModality.CT,
        "computed tomography": ImageModality.CT,
        "mri": ImageModality.MRI,
        "magnetic resonance": ImageModality.MRI,
        "mammogram": ImageModality.XRAY,
    }

    # Lay phrase -> clinical terminology, applied in insertion order by
    # _fast_normalize.
    LAYMAN_TERMS = {
        "can't breathe": "dyspnea", "hard to breathe": "dyspnea", "difficulty breathing": "dyspnea",
        "stomach pain": "abdominal pain", "belly pain": "abdominal pain", "tummy pain": "abdominal pain",
        "chest tightness": "chest pain/pressure", "heart racing": "palpitations",
        "blurry vision": "visual disturbance", "can't see clearly": "visual disturbance",
        "dizzy": "dizziness/vertigo", "feel faint": "presyncope", "passed out": "syncope",
        "throwing up": "vomiting", "nausea and vomiting": "nausea/emesis",
        "back pain": "dorsal pain", "leg pain": "lower extremity pain",
        "arm pain": "upper extremity pain", "neck pain": "cervicalgia",
        "headache": "cephalgia", "head pain": "cephalgia",
        "swollen": "edema", "swelling": "edema", "bruise": "ecchymosis",
        "lump": "mass/nodule", "bump": "mass/nodule",
    }

    # Combined symptoms+context text up to this length, containing none of
    # _COMPLEXITY_INDICATORS, is normalized deterministically (no LLM call).
    FAST_PATH_MAX_CHARS = 120
    _COMPLEXITY_INDICATORS = (
        "history of", "diagnosed with", "chronic", "prior", "previous",
        "medication", "allerg",
    )

    # Bounds on the image payload string, in characters of the raw string
    # (minimum) and of the prefix-stripped base64 body (maximum).
    MIN_IMAGE_PAYLOAD_CHARS = 500
    MAX_IMAGE_PAYLOAD_CHARS = 20 * 1024 * 1024

    def __init__(self, llm_client: Optional[LLMClient] = None):
        """
        Args:
            llm_client: Client used for LLM-backed normalization. A new
                ``LLMClient()`` is constructed when none is supplied.
        """
        # Explicit None check so a supplied client is kept even if it happens
        # to evaluate as falsy (e.g. a test double defining __len__/__bool__).
        self.llm = llm_client if llm_client is not None else LLMClient()

    def process(self, patient_input: PatientInput) -> IntakeOutput:
        """
        Main intake processing method.

        Validates the image payload (a failure is only logged), collects
        deterministic safety flags, normalizes the clinical text, infers the
        imaging modality, and assembles an ``IntakeOutput``.

        Args:
            patient_input: Raw validated patient submission

        Returns:
            IntakeOutput: Structured, normalized, safety-checked data. If
            building the output raises, a degraded fallback output with
            ``validated=False`` is returned instead of propagating the error.
        """
        logger.info("📋 Intake Agent processing initiated")

        if not self._validate_image_payload(patient_input.image_base64):
            logger.warning("⚠️ Image payload validation failed. Proceeding with warnings.")

        safety_flags = self._check_deterministic_safety(patient_input)
        structured_data = self._normalize_with_llm(patient_input, safety_flags)
        modality = self._infer_modality(patient_input, structured_data)

        try:
            output = IntakeOutput(
                validated=True,
                standardized_symptoms=structured_data.get(
                    "standardized_symptoms", patient_input.symptoms or ""
                ),
                extracted_demographics=structured_data.get("extracted_demographics", {}),
                safety_flags=self._merge_flags(
                    safety_flags, structured_data.get("safety_flags", [])
                ),
                recommended_modality=modality,
                processing_notes=structured_data.get("processing_notes", ""),
            )
        except Exception as exc:
            # Degrade instead of aborting the pipeline; keep the traceback.
            logger.exception("💥 IntakeOutput validation failed: %s", exc)
            return self._get_fallback_output(patient_input, safety_flags)

        logger.info("✅ Intake Agent completed successfully")
        return output

    @staticmethod
    def _merge_flags(primary: List[str], extra: Any) -> List[str]:
        """
        Union two flag collections without duplicates, keeping first-seen order.

        ``extra`` comes from the normalization output: a falsy value means no
        flags and a bare string is treated as a single flag.
        """
        if not extra:
            extra = []
        elif isinstance(extra, str):
            extra = [extra]
        # dict.fromkeys preserves insertion order; set() iteration order for
        # strings can differ between interpreter runs (hash randomization).
        return list(dict.fromkeys([*primary, *extra]))

    def _validate_image_payload(self, base64_data: Optional[str]) -> bool:
        """
        Validate base64 image integrity and size constraints.

        An optional ``data:image/<subtype>;base64,`` prefix is stripped and
        whitespace (e.g. line wrapping) is ignored. Returns True only when the
        raw string has at least ``MIN_IMAGE_PAYLOAD_CHARS`` characters, the
        base64 body is shorter than ``MAX_IMAGE_PAYLOAD_CHARS`` and the body
        is strictly valid base64.
        """
        import base64

        if not base64_data or len(base64_data) < self.MIN_IMAGE_PAYLOAD_CHARS:
            return False

        clean = re.sub(
            r"^data:image/[\w.+-]+;base64,", "", base64_data, flags=re.IGNORECASE
        )
        clean = "".join(clean.split())

        # Reject oversized payloads before decoding them into memory.
        if len(clean) >= self.MAX_IMAGE_PAYLOAD_CHARS:
            return False
        try:
            # validate=True rejects non-alphabet characters instead of silently
            # discarding them, so arbitrary text no longer "decodes".
            base64.b64decode(clean, validate=True)
        except ValueError:  # binascii.Error is a ValueError subclass
            return False
        return True

    def _check_deterministic_safety(self, inp: PatientInput) -> List[str]:
        """
        Scan raw input for high-priority clinical safety terms.

        Returns one "URGENT_TERM_DETECTED: <term>" flag per SAFETY_KEYWORDS
        entry found in the symptoms or clinical context, plus an age flag when
        the patient is under 18 or over 75.
        """
        text = f"{inp.symptoms or ''} {inp.clinical_context or ''}".lower()
        flags = [
            f"URGENT_TERM_DETECTED: {kw}"
            for kw in self.SAFETY_KEYWORDS
            if kw.lower() in text
        ]
        if inp.age is not None:
            if inp.age < 18:
                flags.append("PATIENT_AGE: PEDIATRIC_REQUIRES_EXPERT_REVIEW")
            elif inp.age > 75:
                flags.append("PATIENT_AGE: GERIATRIC_CONSIDERATIONS_RECOMMENDED")
        return flags

    def _normalize_with_llm(self, inp: PatientInput, existing_flags: List[str]) -> Dict[str, Any]:
        """
        Normalize clinical text. Uses fast deterministic mapping for simple inputs;
        falls back to LLM only for complex or lengthy clinical context.

        Returns a dict with the keys ``standardized_symptoms``,
        ``extracted_demographics``, ``safety_flags`` and ``processing_notes``
        on the deterministic paths. On the LLM path it returns whatever JSON
        object the model produced. If the LLM call raises, or it yields no
        JSON object, the deterministic fallback dict is returned instead.
        """
        combined_text = f"{inp.symptoms or ''} {inp.clinical_context or ''}".strip()
        lowered = combined_text.lower()
        is_complex = any(ind in lowered for ind in self._COMPLEXITY_INDICATORS)

        if len(combined_text) <= self.FAST_PATH_MAX_CHARS and not is_complex:
            logger.debug("⚡ Short input detected — using fast deterministic normalization (skipping LLM)")
            return self._fast_normalize(inp, existing_flags)

        prompt = f"""You are a clinical data standardization expert.
Convert raw patient input to standardized clinical terminology. Respond ONLY with JSON:
{{"standardized_symptoms":"string","extracted_demographics":{{"age":int|null,"sex":"M|F|O"|null,"comorbidities":["string"]}},"safety_flags":["string"],"processing_notes":"string"}}

Input:
- Symptoms: "{inp.symptoms or 'Not provided'}"
- Age: {inp.age}
- Sex: {inp.sex}
- Clinical Context: "{inp.clinical_context or 'Not provided'}"
- Existing Flags: {existing_flags}

Rules: convert layman terms to medical terminology; extract comorbidities; add safety flags; no markdown."""

        try:
            result = self.llm.generate_text(prompt=prompt, force_json=True)
            if result and result.get("success") and result.get("content"):
                parsed = LLMClient.extract_json_from_response(result["content"])
                # Only a non-empty JSON object is usable: process() calls
                # .get() on the result.
                if isinstance(parsed, dict) and parsed:
                    return parsed
        except Exception as exc:
            # A deterministic fallback exists, so an LLM/client failure must
            # not abort intake.
            logger.exception("💥 LLM normalization raised: %s", exc)

        logger.warning("⚠️ LLM normalization failed. Using deterministic fallback.")
        return self._build_fallback_dict(inp, existing_flags)

    @staticmethod
    def _term_pattern(term: str, allow_plural: bool = False) -> str:
        """
        Build a regex matching ``term`` as a whole word/phrase.

        Assumes ``term`` starts and ends with a word character, which holds
        for every key in MODALITY_KEYWORDS and LAYMAN_TERMS. With
        ``allow_plural`` a single trailing "s" is also accepted.
        """
        suffix = "s?" if allow_plural else ""
        return rf"\b{re.escape(term)}{suffix}\b"

    def _fast_normalize(self, inp: PatientInput, flags: List[str]) -> Dict[str, Any]:
        """
        Deterministic normalization using term mapping — zero LLM calls.

        Each LAYMAN_TERMS phrase found in the symptoms is replaced with its
        clinical equivalent. Matching is case-insensitive, whole-word, and
        accepts a plural "s". Text outside the matches keeps its original
        casing, and the clinical context is not rewritten.
        """
        normalized = inp.symptoms or "No symptoms provided"
        for layman, medical in self.LAYMAN_TERMS.items():
            normalized = re.sub(
                self._term_pattern(layman, allow_plural=True),
                # Function replacement: the clinical term is inserted literally.
                lambda _match, term=medical: term,
                normalized,
                flags=re.IGNORECASE,
            )
        return {
            "standardized_symptoms": normalized.strip(),
            "extracted_demographics": {
                "age": inp.age,
                "sex": inp.sex,
                "comorbidities": [],
            },
            "safety_flags": flags,
            "processing_notes": "Fast deterministic normalization applied.",
        }

    def _infer_modality(self, inp: PatientInput, llm_data: Dict[str, Any]) -> ImageModality:
        """
        Infer imaging modality from text hints or default to UNKNOWN.

        ``llm_data`` is not read here; it is kept for interface compatibility.
        """
        text = f"{inp.symptoms or ''} {inp.clinical_context or ''}".lower()
        for kw, mod in self.MODALITY_KEYWORDS.items():
            # Whole-word match: a bare substring test lets the short key "ct"
            # fire inside words such as "acute", "fracture" or "infarct".
            if re.search(self._term_pattern(kw), text):
                return mod
        return ImageModality.UNKNOWN

    def _build_fallback_dict(self, inp: PatientInput, flags: List[str]) -> Dict[str, Any]:
        """Deterministic fallback when LLM normalization is unavailable; raw symptoms are kept."""
        return {
            "standardized_symptoms": inp.symptoms or "No symptoms provided",
            "extracted_demographics": {
                "age": inp.age,
                "sex": inp.sex,
                "comorbidities": [],
            },
            "safety_flags": flags,
            "processing_notes": "LLM normalization unavailable. Raw input preserved.",
        }

    def _get_fallback_output(self, inp: PatientInput, flags: List[str]) -> IntakeOutput:
        """
        Return a safe, minimally structured IntakeOutput on critical failure.

        The output is marked ``validated=False``, carries the deterministic
        flags plus "INTAKE_AGENT_FALLBACK_MODE", and uses ImageModality.UNKNOWN.
        """
        return IntakeOutput(
            validated=False,
            standardized_symptoms=inp.symptoms or "",
            extracted_demographics={"age": inp.age, "sex": inp.sex, "comorbidities": []},
            safety_flags=flags + ["INTAKE_AGENT_FALLBACK_MODE"],
            recommended_modality=ImageModality.UNKNOWN,
            processing_notes="Intake agent encountered critical validation failure. Pipeline continues with degraded state.",
        )
|
|