# MediAgent / core / models.py
# medi422's picture
# Upload 21 files
# 9a75c73 verified
# mediagent/core/models.py
"""
Pydantic data models for MediAgent multi-agent medical imaging pipeline.
Defines input, agent outputs, report structure, and pipeline state tracking.
"""
import enum
import uuid
from datetime import datetime
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field, field_validator
# ─────────────────────────────────────────────────────────────────────────────
# ENUMERATIONS
# ─────────────────────────────────────────────────────────────────────────────
class SeverityLevel(str, enum.Enum):
    """Severity grade attached to a finding or to a whole report.

    Members are also ``str`` instances, so each one compares equal to its
    own value (e.g. ``SeverityLevel.NORMAL == "NORMAL"``).
    """

    # Listed in the same order as the original definition; iteration order
    # follows definition order.
    NORMAL = "NORMAL"
    INCIDENTAL = "INCIDENTAL"
    SIGNIFICANT = "SIGNIFICANT"
    CRITICAL = "CRITICAL"
class ConfidenceLevel(str, enum.Enum):
    """Coarse confidence bucket reported alongside a model output.

    Members are ``str`` subclasses whose value equals their name.
    """

    LOW = "LOW"
    MEDIUM = "MEDIUM"
    HIGH = "HIGH"
class AgentStatus(str, enum.Enum):
    """Execution state of a single agent within the pipeline.

    Members are ``str`` subclasses whose value equals their name.
    """

    WAITING = "WAITING"
    RUNNING = "RUNNING"
    DONE = "DONE"
    ERROR = "ERROR"
class ImageModality(str, enum.Enum):
    """Imaging modalities the pipeline can label an image with."""

    # The member name has no hyphen, but the stored value does.
    XRAY = "X-RAY"
    MRI = "MRI"
    CT = "CT"
    # Used as the default by the models that carry a modality field.
    UNKNOWN = "UNKNOWN"
# ─────────────────────────────────────────────────────────────────────────────
# INPUT MODELS
# ─────────────────────────────────────────────────────────────────────────────
class PatientInput(BaseModel):
    """Request payload a client submits to start an analysis run.

    Carries the encoded image plus optional free-text and demographic
    context. ``image_base64`` is the only required field.
    """

    image_base64: str = Field(
        ...,
        description="Base64 encoded medical image (PNG/JPG format)",
    )
    symptoms: str = Field(
        "",
        description="Patient reported symptoms or chief complaint",
    )
    # Optional demographics: age is bounded to 0..120 and sex must match
    # one of the single-letter codes M, F or O.
    age: Optional[int] = Field(
        None,
        ge=0,
        le=120,
        description="Patient age in years",
    )
    sex: Optional[str] = Field(
        None,
        pattern="^(M|F|O)$",
        description="Patient biological sex",
    )
    clinical_context: Optional[str] = Field(
        None,
        description="Relevant medical history or referral details",
    )

    @field_validator("image_base64")
    @classmethod
    def validate_image_data(cls, v: str) -> str:
        """Reject image payloads shorter than 10 characters.

        Only the length is inspected; the string is not decoded.

        Raises:
            ValueError: if ``v`` is empty or has fewer than 10 characters.
        """
        # An empty string has length 0, so a single length comparison also
        # covers the "empty" case.
        if len(v) >= 10:
            return v
        raise ValueError("Invalid or empty base64 image data provided.")
# ─────────────────────────────────────────────────────────────────────────────
# AGENT OUTPUT MODELS
# ─────────────────────────────────────────────────────────────────────────────
class IntakeOutput(BaseModel):
    """Result record emitted by the Intake Agent.

    Every field has a default, so ``IntakeOutput()`` with no arguments is
    a valid instance.
    """

    validated: bool = Field(
        True,
        description="Whether input passed validation checks",
    )
    standardized_symptoms: str = Field(
        "",
        description="Clinically normalized symptom description",
    )
    # Container fields use factories so each instance gets its own object.
    extracted_demographics: Dict[str, Any] = Field(default_factory=dict)
    safety_flags: List[str] = Field(
        default_factory=list,
        description="Pre-analysis safety/alert flags",
    )
    recommended_modality: ImageModality = ImageModality.UNKNOWN
    processing_notes: str = ""
class VisionFinding(BaseModel):
    """A single observation reported by the Vision Agent.

    ``anatomical_region`` and ``description`` are required; the remaining
    fields default to the least alarming values (NORMAL severity, LOW
    confidence, score 0.0, not an anomaly).
    """

    anatomical_region: str = Field(
        ...,
        description="e.g., Left Lung Field, Medial Patella",
    )
    description: str = Field(
        ...,
        description="Detailed radiological description",
    )
    severity: SeverityLevel = SeverityLevel.NORMAL
    confidence: ConfidenceLevel = ConfidenceLevel.LOW
    # Numeric score constrained to the closed range 0.0..100.0.
    confidence_score: float = Field(0.0, ge=0.0, le=100.0)
    is_anomaly: bool = False
class VisionOutput(BaseModel):
    """Aggregate result of the Vision Agent's image analysis.

    Holds the detected modality, a quality note, the list of individual
    findings, a summary sentence and a free-form metadata mapping. All
    fields have defaults.
    """

    modality_detected: ImageModality = ImageModality.UNKNOWN
    technical_quality: str = Field(
        "Acceptable",
        description="Image quality/artifact assessment",
    )
    findings: List[VisionFinding] = Field(default_factory=list)
    overall_assessment: str = "No obvious abnormalities detected."
    metadata: Dict[str, Any] = Field(default_factory=dict)
class KnowledgeMatch(BaseModel):
    """One candidate diagnosis in the Research Agent's differential list.

    Attributes:
        condition_name: Medical condition or diagnosis (required).
        match_probability: Confidence percentage in the closed range
            0.0..100.0 (required).
        supporting_evidence: Pathophysiological/clinical correlation
            backing the match (required).
        differential_rank: Non-negative position in the differential.
            ``0`` (the default) means no rank has been assigned.
        icd10_code: Optional ICD-10 code; ``None`` when not supplied.
    """

    condition_name: str = Field(..., description="Medical condition or diagnosis")
    match_probability: float = Field(
        ...,
        ge=0.0,
        le=100.0,
        description="Confidence percentage",
    )
    supporting_evidence: str = Field(
        ...,
        description="Pathophysiological/clinical correlation",
    )
    # The lower bound is 0 (not 1) so that the default value satisfies its
    # own constraint. Pydantic does not validate defaults unless asked to,
    # so with ``ge=1`` an instance created with the default rank could not
    # be re-validated from its own dump.
    differential_rank: int = Field(default=0, ge=0)
    icd10_code: Optional[str] = Field(default=None)
class ResearchOutput(BaseModel):
    """Output of the Research Agent: differentials plus supporting material.

    All fields have defaults, so an empty ``ResearchOutput()`` is valid.
    """

    differential_diagnoses: List[KnowledgeMatch] = Field(default_factory=list)
    matched_conditions: List[str] = Field(default_factory=list)
    relevant_guidelines: List[str] = Field(default_factory=list)
    research_notes: str = ""
    # Kept as a literal default (rather than a factory) so the declared
    # default value stays exactly as originally written.
    sources_used: List[str] = Field(default=["internal_knowledge_base"])
# ─────────────────────────────────────────────────────────────────────────────
# REPORT MODELS
# ─────────────────────────────────────────────────────────────────────────────
class ReportSection(BaseModel):
    """Text sections that make up the body of a radiology report.

    Each section is a plain string with boilerplate default wording, so an
    empty ``ReportSection()`` already yields a complete "normal" report.
    """

    clinical_history: str = "Not provided."
    technique: str = "Digital advanced imaging acquisition."
    findings: str = "No abnormalities detected."
    impression: str = "Within normal limits."
    recommendations: str = "Routine follow-up as clinically indicated."
    # Split across two literals for line length only; the resulting string
    # is unchanged.
    disclaimer: str = (
        "This analysis is AI-generated and must be reviewed by a licensed "
        "radiologist before any clinical decisions are made."
    )
class FinalReport(BaseModel):
    """Fully assembled report combining all agent results.

    ``report_id`` and ``generation_timestamp`` are generated per instance
    when not supplied by the caller.
    """

    # "REP-" followed by the first 12 hex digits of a random UUID, uppercased.
    report_id: str = Field(
        default_factory=lambda: "REP-" + uuid.uuid4().hex[:12].upper(),
    )
    patient_metadata: Dict[str, Any] = Field(default_factory=dict)
    sections: ReportSection = Field(default_factory=ReportSection)
    vision_summary: str = ""
    research_summary: str = ""
    overall_severity: SeverityLevel = SeverityLevel.NORMAL
    # datetime.now() without a tz argument yields a naive local timestamp.
    generation_timestamp: datetime = Field(default_factory=datetime.now)
    agent_pipeline_status: Dict[str, AgentStatus] = Field(default_factory=dict)
# ─────────────────────────────────────────────────────────────────────────────
# CHAT / ADVISOR MODELS
# ─────────────────────────────────────────────────────────────────────────────
class ChatMessage(BaseModel):
    """One message exchanged in the post-report advisor chat.

    ``role`` and ``content`` are required; ``timestamp`` defaults to the
    moment the instance is created.
    """

    # The description names the expected values, but the field itself
    # accepts any string.
    role: str = Field(
        ...,
        description="'user' or 'assistant'",
    )
    content: str = Field(
        ...,
        description="Message text",
    )
    # datetime.now() without a tz argument yields a naive local timestamp.
    timestamp: datetime = Field(default_factory=datetime.now)
class ChatRequest(BaseModel):
    """Question submitted to the ClinicalAdvisorAgent."""

    # Required; length must be between 3 and 1000 characters inclusive.
    question: str = Field(
        ...,
        min_length=3,
        max_length=1000,
    )
class ChatResponse(BaseModel):
    """Answer returned by the ClinicalAdvisorAgent.

    ``answer`` and ``report_id`` are required; ``timestamp`` defaults to
    the moment the instance is created.
    """

    answer: str = Field(...)
    report_id: str = Field(...)
    # datetime.now() without a tz argument yields a naive local timestamp.
    timestamp: datetime = Field(default_factory=datetime.now)
# ─────────────────────────────────────────────────────────────────────────────
# PIPELINE STATE MODEL
# ─────────────────────────────────────────────────────────────────────────────
class PipelineState(BaseModel):
    """Mutable record of pipeline progress and intermediate agent results.

    Holds a per-agent status map, the optional output of each stage, and
    a running list of error messages.
    """

    current_step: str = "INIT"
    # Every known agent starts in WAITING; key order matches the original
    # literal (INTAKE, VISION, RESEARCH, REPORT, CRITIC).
    agent_statuses: Dict[str, AgentStatus] = Field(
        default_factory=lambda: dict.fromkeys(
            ("INTAKE", "VISION", "RESEARCH", "REPORT", "CRITIC"),
            AgentStatus.WAITING,
        ),
    )
    # Stage outputs stay None until the corresponding value is assigned.
    intake_output: Optional[IntakeOutput] = None
    vision_output: Optional[VisionOutput] = None
    research_output: Optional[ResearchOutput] = None
    report_draft: Optional[ReportSection] = None
    final_report: Optional[FinalReport] = None
    error_log: List[str] = Field(default_factory=list)

    def mark_running(self, agent_name: str) -> None:
        """Set ``agent_name``'s status to RUNNING (adds the key if absent)."""
        self.agent_statuses.update({agent_name: AgentStatus.RUNNING})

    def mark_done(self, agent_name: str) -> None:
        """Set ``agent_name``'s status to DONE (adds the key if absent)."""
        self.agent_statuses.update({agent_name: AgentStatus.DONE})

    def mark_error(self, agent_name: str, error_msg: str) -> None:
        """Set ``agent_name``'s status to ERROR and log ``error_msg``.

        The log entry has the form ``"[<agent_name>] <error_msg>"``.
        """
        self.agent_statuses.update({agent_name: AgentStatus.ERROR})
        log_entry = f"[{agent_name}] {error_msg}"
        self.error_log.append(log_entry)