File size: 6,472 Bytes
9a75c73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
# mediagent/agents/vision.py
"""
Vision Agent for MediAgent.
Multimodal medical image analysis engine that processes base64-encoded
diagnostic images using local Qwen vision capabilities. Extracts
anatomical structures, detects pathologies, and structures findings
into standardized radiological observations with confidence scoring.
"""

import logging
from typing import Any, Dict, List, Optional

from core.llm import LLMClient
from core.models import (
    ConfidenceLevel,
    ImageModality,
    SeverityLevel,
    VisionFinding,
    VisionOutput,
)

logger = logging.getLogger(__name__)


class VisionAgent:
    """
    Specialized radiological analysis agent. Interprets medical imagery
    (X-ray, MRI, CT) and outputs structured findings with severity and
    confidence classifications. Operates with deterministic fallbacks
    to maintain pipeline continuity on LLM failures.

    Failure paths (empty image, LLM call error, unparseable JSON, response
    mapping error) return ``_get_fallback_output()`` instead of raising.
    Individual malformed findings are skipped rather than failing the whole
    response.
    """

    SYSTEM_PROMPT = """You are a board-certified radiologist. Analyze the medical image and return ONLY valid JSON:
{
  "modality_detected": "X-RAY"|"MRI"|"CT"|"UNKNOWN",
  "technical_quality": "brief quality/artifact/positioning note",
  "findings": [
    {
      "anatomical_region": "string",
      "description": "concise radiological observation using standard terminology",
      "severity": "NORMAL"|"INCIDENTAL"|"SIGNIFICANT"|"CRITICAL",
      "confidence": "LOW"|"MEDIUM"|"HIGH",
      "confidence_score": 0.0-100.0,
      "is_anomaly": boolean
    }
  ],
  "overall_assessment": "concise clinical summary"
}
Rules: precise radiological terms; correct anatomical orientation; distinguish normal variants from pathology; realistic confidence scores; no treatment plans; no markdown; no extra text."""

    # Substitutes used when the model omits a field or returns null/garbage for it.
    _DEFAULT_REGION = "Unspecified Region"
    _DEFAULT_DESCRIPTION = "Unable to generate detailed description."
    _DEFAULT_TECHNICAL_QUALITY = "Image quality acceptable for preliminary assessment."
    _DEFAULT_ASSESSMENT = "Unable to generate overall assessment from provided data."
    _DEFAULT_CONFIDENCE_SCORE = 50.0
    # Bounds of confidence_score as declared in SYSTEM_PROMPT ("0.0-100.0").
    _MIN_CONFIDENCE_SCORE = 0.0
    _MAX_CONFIDENCE_SCORE = 100.0
    # String spellings of "true" that models emit instead of a JSON boolean.
    # Needed because bool("false") is True in Python.
    _TRUTHY_STRINGS = frozenset({"true", "yes", "y", "1"})

    def __init__(self, llm_client: Optional[LLMClient] = None):
        """
        Args:
            llm_client: Client used for vision calls; a new ``LLMClient()`` is
                created when omitted.
        """
        # Identity check (not truthiness) so an injected client is never
        # silently replaced because it happens to evaluate as falsy.
        self.llm = llm_client if llm_client is not None else LLMClient()

    def process(self, image_base64: str, clinical_context: str = "") -> VisionOutput:
        """
        Execute multimodal image analysis.

        Args:
            image_base64: Base64 encoded medical image
            clinical_context: Standardized symptoms/demographics from Intake Agent

        Returns:
            VisionOutput: Structured radiological findings and metadata. When the
            image is empty, the LLM call fails, or the response cannot be parsed
            or mapped, the fallback output is returned instead of raising.
        """
        logger.info("👁️ Vision Agent initiated multimodal analysis")

        if not image_base64 or not image_base64.strip():
            # Nothing to analyze; avoid an LLM call that cannot produce findings.
            logger.warning("⚠️ Empty image payload supplied to Vision Agent.")
            return self._get_fallback_output()

        prompt_parts = ["Analyze this medical image carefully."]
        if clinical_context:
            prompt_parts.append(f"Clinical Context: {clinical_context}")
        prompt_parts.append(
            "Provide a structured radiological assessment following the specified JSON schema."
        )
        user_prompt = "\n\n".join(prompt_parts)

        result = self.llm.generate_vision(
            base64_image=image_base64,
            prompt=user_prompt,
            system_prompt=self.SYSTEM_PROMPT,
            temperature=0.0,
            max_tokens=2000,
        )

        if not result.get("success"):
            logger.error("❌ Vision LLM call failed: %s", result.get("error"))
            return self._get_fallback_output()

        raw_content = result.get("content") or ""
        parsed = LLMClient.extract_json_from_response(raw_content)
        # The schema requires a top-level JSON object; anything else cannot be mapped.
        if not parsed or not isinstance(parsed, dict):
            logger.warning("⚠️ Failed to parse vision LLM JSON response. Using fallback.")
            return self._get_fallback_output()

        try:
            return self._parse_vision_response(parsed, result.get("usage") or {})
        except Exception as e:
            # Top-level boundary: keep the pipeline alive, but keep the traceback.
            logger.exception("💥 Vision response mapping failed: %s", e)
            return self._get_fallback_output()

    def _parse_vision_response(self, data: Dict[str, Any], usage: Dict[str, int]) -> VisionOutput:
        """
        Map the raw LLM JSON object to a ``VisionOutput``.

        Findings that fail validation are logged and skipped; missing or
        malformed top-level fields fall back to safe defaults.

        Args:
            data: Parsed JSON object returned by the model.
            usage: Token-usage mapping reported by the LLM client; stored in metadata.

        Returns:
            VisionOutput built from the mapped fields.
        """
        raw_findings = data.get("findings")
        if not isinstance(raw_findings, list):
            if raw_findings is not None:
                logger.warning(
                    "⚠️ Ignoring non-list 'findings' field of type %s.",
                    type(raw_findings).__name__,
                )
            raw_findings = []

        findings: List[VisionFinding] = []
        for item in raw_findings:
            try:
                findings.append(self._build_finding(item))
            except Exception as e:
                # Best effort: one bad finding must not discard the others.
                logger.warning("⚠️ Skipping malformed finding due to validation error: %s", e)

        # Same tolerant conversion as severity/confidence; null or unknown -> UNKNOWN.
        modality = self._safe_enum(ImageModality, data.get("modality_detected"), ImageModality.UNKNOWN)

        return VisionOutput(
            modality_detected=modality,
            technical_quality=self._text_or_default(
                data.get("technical_quality"), self._DEFAULT_TECHNICAL_QUALITY
            ),
            findings=findings,
            overall_assessment=self._text_or_default(
                data.get("overall_assessment"), self._DEFAULT_ASSESSMENT
            ),
            metadata={
                "llm_usage": usage,
                "findings_count": len(findings),
                "anomalies_detected": sum(1 for f in findings if f.is_anomaly),
            },
        )

    def _build_finding(self, item: Any) -> VisionFinding:
        """
        Build one ``VisionFinding`` from a raw finding object.

        Raises:
            TypeError: if *item* is not a JSON object (dict).
            Any validation error raised by the ``VisionFinding`` constructor.
        """
        if not isinstance(item, dict):
            raise TypeError(f"expected a JSON object, got {type(item).__name__}")
        return VisionFinding(
            anatomical_region=self._text_or_default(item.get("anatomical_region"), self._DEFAULT_REGION),
            description=self._text_or_default(item.get("description"), self._DEFAULT_DESCRIPTION),
            severity=self._safe_enum(SeverityLevel, item.get("severity"), SeverityLevel.NORMAL),
            confidence=self._safe_enum(ConfidenceLevel, item.get("confidence"), ConfidenceLevel.MEDIUM),
            confidence_score=self._safe_score(item.get("confidence_score")),
            is_anomaly=self._safe_bool(item.get("is_anomaly")),
        )

    @staticmethod
    def _safe_enum(enum_cls, value, default):
        """
        Safely convert a model-supplied value to ``enum_cls``.

        Existing members pass through unchanged; strings are stripped and
        upper-cased before lookup. ``None`` or unrecognized values yield *default*.
        """
        if value is None:
            # Avoid str(None) == "NONE" accidentally matching a member.
            return default
        if isinstance(value, enum_cls):
            return value
        try:
            return enum_cls(str(value).strip().upper())
        except (ValueError, AttributeError, TypeError):
            return default

    @staticmethod
    def _text_or_default(value: Any, default: str) -> str:
        """Return *value* stripped when it is a non-blank string, otherwise *default*."""
        if isinstance(value, str) and value.strip():
            return value.strip()
        return default

    @classmethod
    def _safe_score(cls, value: Any) -> float:
        """
        Coerce a model-reported confidence score to a float.

        Accepts numbers and numeric strings, including a trailing ``%``.
        Missing, non-numeric, boolean or NaN values yield the default score.
        Results are clamped to the 0-100 range declared in SYSTEM_PROMPT.
        """
        import math

        if value is None or isinstance(value, bool):
            # bool is an int subclass, but True/False is not a meaningful score.
            return cls._DEFAULT_CONFIDENCE_SCORE
        if isinstance(value, str):
            value = value.strip().rstrip("%").strip()
        try:
            score = float(value)
        except (TypeError, ValueError):
            return cls._DEFAULT_CONFIDENCE_SCORE
        if math.isnan(score):
            return cls._DEFAULT_CONFIDENCE_SCORE
        return min(max(score, cls._MIN_CONFIDENCE_SCORE), cls._MAX_CONFIDENCE_SCORE)

    @classmethod
    def _safe_bool(cls, value: Any) -> bool:
        """
        Interpret a model-reported boolean.

        Strings are matched against ``_TRUTHY_STRINGS`` (so "false", "no" and
        "0" map to False). Other values use normal truthiness; a missing
        value (None) is False.
        """
        if isinstance(value, str):
            return value.strip().lower() in cls._TRUTHY_STRINGS
        return bool(value)

    def _get_fallback_output(self) -> VisionOutput:
        """Return a safe, non-failure-breaking VisionOutput when processing fails."""
        logger.warning("⚠️ Returning fallback VisionOutput due to processing failure.")
        return VisionOutput(
            modality_detected=ImageModality.UNKNOWN,
            technical_quality="Analysis unavailable due to system error. Manual review required.",
            findings=[],
            overall_assessment="Vision analysis could not be completed. Please verify image quality and system connectivity.",
            metadata={"error": "VISION_AGENT_FALLBACK", "llm_usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}}
        )