MediAgent / core /pipeline.py
medi422's picture
Upload 21 files
9a75c73 verified
# mediagent/core/pipeline.py
"""
Pipeline Orchestrator for MediAgent.
Manages execution of all medical imaging analysis agents with parallel
scheduling where possible, tracks real-time state, handles graceful error
recovery, and coordinates data flow between specialized AI components.
Parallelism strategy:
- INTAKE + VISION run concurrently (vision only needs the raw image)
- RESEARCH runs after VISION completes (needs findings)
- REPORT runs after all three complete
- CRITIC runs after REPORT
"""
import logging
import time
from concurrent.futures import ThreadPoolExecutor, wait as futures_wait
from datetime import datetime
from typing import Any, Callable, Dict, Optional
from core.models import (
AgentStatus,
FinalReport,
IntakeOutput,
PipelineState,
PatientInput,
ResearchOutput,
ReportSection,
VisionOutput,
)
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class PipelineOrchestrator:
    """
    Central execution engine that routes patient data through the 5-agent
    medical analysis pipeline. Runs INTAKE and VISION in parallel to cut
    wall-clock latency, then sequences RESEARCH -> REPORT -> CRITIC.

    Every agent is invoked through its ``process(**kwargs)`` method. Outputs
    of INTAKE, VISION, RESEARCH and REPORT are stored on the shared
    ``PipelineState``; the return value of CRITIC is not stored.
    """

    AGENT_ORDER = ["INTAKE", "VISION", "RESEARCH", "REPORT", "CRITIC"]

    # Agent name -> PipelineState attribute that receives the agent's output.
    # CRITIC is intentionally absent: its return value is discarded.
    _OUTPUT_FIELDS = {
        "INTAKE": "intake_output",
        "VISION": "vision_output",
        "RESEARCH": "research_output",
        "REPORT": "report_draft",
    }

    # Severity labels ranked from least to most severe. Labels that are not
    # listed here are treated as rank 1 (NORMAL).
    _SEVERITY_RANK = {"CRITICAL": 4, "SIGNIFICANT": 3, "INCIDENTAL": 2, "NORMAL": 1}
    _RANK_TO_SEVERITY = {4: "CRITICAL", 3: "SIGNIFICANT", 2: "INCIDENTAL", 1: "NORMAL"}

    def __init__(
        self,
        intake_agent: Any,
        vision_agent: Any,
        research_agent: Any,
        report_agent: Any,
        critic_agent: Any,
        on_status_update: Optional[Callable[[PipelineState], None]] = None,
    ):
        """
        Register the five agents and an optional status callback.

        Args:
            intake_agent: Agent run for the INTAKE step.
            vision_agent: Agent run for the VISION step.
            research_agent: Agent run for the RESEARCH step.
            report_agent: Agent run for the REPORT step.
            critic_agent: Agent run for the CRITIC step.
            on_status_update: Optional callable invoked with the current
                PipelineState before and after every agent step.
        """
        self.agents = {
            "INTAKE": intake_agent,
            "VISION": vision_agent,
            "RESEARCH": research_agent,
            "REPORT": report_agent,
            "CRITIC": critic_agent,
        }
        self.on_status_update = on_status_update

    def run(self, patient_input: PatientInput) -> PipelineState:
        """
        Execute the full diagnostic pipeline with parallel INTAKE+VISION stage.

        Failure policy:
            - INTAKE error or missing intake output: stop and return the state.
            - VISION or RESEARCH error: continue with whatever was produced.
            - REPORT error: stop and return the state.
            - CRITIC error: log a warning and still assemble the final report.
            - Any other exception: recorded in ``state.error_log`` and
              ``state.current_step`` is set to "FAILED".

        Returns:
            PipelineState: Complete execution state containing all outputs,
            agent statuses, and final report.
        """
        logger.info("πŸš€ Pipeline execution started | Input ID: %s", id(patient_input))
        state = PipelineState()
        try:
            # ---------------------------------------------------------------
            # STAGE 1: INTAKE + VISION in parallel
            # Vision uses raw symptoms as clinical context so it doesn't
            # have to wait for intake normalization to complete.
            # ---------------------------------------------------------------
            raw_context = patient_input.symptoms or ""
            with ThreadPoolExecutor(max_workers=2) as pool:
                intake_fut = pool.submit(
                    self._execute_step, state, "INTAKE",
                    patient_input=patient_input,
                )
                vision_fut = pool.submit(
                    self._execute_step, state, "VISION",
                    image_base64=patient_input.image_base64,
                    clinical_context=raw_context,
                )
                futures_wait([intake_fut, vision_fut])

            # wait() only blocks; it never re-raises. Calling result() surfaces
            # anything that escaped _execute_step on a worker thread so the
            # outer handler records it instead of it being silently dropped.
            intake_fut.result()
            vision_fut.result()

            intake_failed = state.agent_statuses["INTAKE"] == AgentStatus.ERROR
            if not intake_failed and state.intake_output is None:
                # The agent returned None without raising. Every later stage
                # needs the intake output, so treat this as an intake failure
                # rather than crashing on a None attribute access below.
                state.mark_error("INTAKE", "Intake agent returned no output")
                self._notify(state)
                intake_failed = True
            if intake_failed:
                logger.warning("⚠️ Intake failed. Halting pipeline for safety.")
                return state

            if state.agent_statuses["VISION"] == AgentStatus.ERROR:
                logger.warning("⚠️ Vision analysis failed. Continuing with degraded pipeline.")

            # ---------------------------------------------------------------
            # STAGE 2: RESEARCH (needs vision findings + intake demographics)
            # ---------------------------------------------------------------
            vision = state.vision_output
            self._execute_step(
                state, "RESEARCH",
                vision_findings=vision.findings if vision else [],
                demographics=state.intake_output.extracted_demographics,
                detected_modality=vision.modality_detected.value if vision else "UNKNOWN",
            )
            if state.agent_statuses["RESEARCH"] == AgentStatus.ERROR:
                logger.warning("⚠️ Research failed. Generating report without differential augmentation.")

            # ---------------------------------------------------------------
            # STAGE 3: REPORT (synthesizes all three upstream outputs)
            # ---------------------------------------------------------------
            self._execute_step(
                state, "REPORT",
                intake=state.intake_output,
                vision=state.vision_output,
                research=state.research_output,
            )
            if state.agent_statuses["REPORT"] == AgentStatus.ERROR:
                logger.error("❌ Report generation failed. Pipeline cannot complete safely.")
                return state

            # ---------------------------------------------------------------
            # STAGE 4: CRITIC (QA review of final draft)
            # ---------------------------------------------------------------
            self._execute_step(
                state, "CRITIC",
                draft_report=state.report_draft,
                pipeline_state=state,
            )
            if state.agent_statuses["CRITIC"] == AgentStatus.ERROR:
                # The report is still delivered, but make the missing QA pass
                # visible next to the completion message.
                logger.warning(
                    "Critic review failed. Final report was assembled without QA review."
                )

            state.final_report = self._assemble_final_report(state)
            state.current_step = "COMPLETE"
            logger.info("βœ… Pipeline execution completed successfully.")
        except Exception as e:
            # Top-level boundary: never let a pipeline failure escape to the
            # caller; record it on the state instead.
            logger.exception("πŸ’₯ Unhandled pipeline failure: %s", e)
            state.current_step = "FAILED"
            state.error_log.append(f"SYSTEM_FAILURE: {e}")
        return state

    def _execute_step(
        self,
        state: PipelineState,
        agent_name: str,
        **kwargs
    ) -> None:
        """
        Generic step executor with state management, timing, and error isolation.

        Marks the agent as running, calls ``agent.process(**kwargs)``, stores a
        non-None result in the state field mapped to ``agent_name`` and marks
        the agent done. Any exception raised by the agent is logged and
        recorded through ``state.mark_error`` instead of propagating. The
        status callback fires once before the agent runs and once after.

        Concurrency note: during the parallel stage this method runs on two
        worker threads at the same time. Each agent writes its output to its
        own state field, but both threads also assign ``state.current_step``
        and invoke the status callback, so ``current_step`` reflects whichever
        thread wrote last and the callback may be entered concurrently.

        Args:
            state: Shared pipeline state that is mutated in place.
            agent_name: Key into ``self.agents`` (one of AGENT_ORDER).
            **kwargs: Forwarded unchanged to the agent's ``process`` method.
        """
        logger.info("▢️ Executing agent: %s", agent_name)
        state.current_step = agent_name
        state.mark_running(agent_name)
        self._notify(state)

        start_time = time.perf_counter()
        try:
            output = self.agents[agent_name].process(**kwargs)
            elapsed = time.perf_counter() - start_time

            target_field = self._OUTPUT_FIELDS.get(agent_name)
            if output is not None and target_field is not None:
                setattr(state, target_field, output)

            state.mark_done(agent_name)
            logger.info("βœ… %s completed in %.3fs", agent_name, elapsed)
        except Exception as e:
            elapsed = time.perf_counter() - start_time
            error_msg = f"{agent_name} execution failed after {elapsed:.3f}s: {e}"
            logger.error("❌ %s", error_msg, exc_info=True)
            state.mark_error(agent_name, str(e))
        finally:
            # Always publish the final status of this step, success or failure.
            self._notify(state)

    def _assemble_final_report(self, state: PipelineState) -> FinalReport:
        """
        Synthesize all agent outputs into the final deliverable report structure.

        Reads the report draft, vision output, research output and intake
        demographics from ``state`` and substitutes a default for each part
        that is missing. The report draft is used as stored on the state; this
        method does not itself read any critic output.

        Returns:
            FinalReport: Report sections plus vision/research summaries,
            overall severity, per-agent statuses and a generation timestamp.
        """
        report_draft = state.report_draft or ReportSection()
        vision = state.vision_output
        research = state.research_output

        # Overall severity is the most severe label among the vision findings.
        overall_severity = "NORMAL"
        if vision and vision.findings:
            highest = max(
                (self._SEVERITY_RANK.get(f.severity.value, 1) for f in vision.findings),
                default=1,
            )
            overall_severity = self._RANK_TO_SEVERITY.get(highest, "NORMAL")

        # Build vision summary
        vision_summary = "No imaging analysis performed."
        if vision:
            anomaly_count = sum(1 for f in vision.findings if f.is_anomaly)
            vision_summary = (
                f"Modality: {vision.modality_detected.value} | "
                f"Quality: {vision.technical_quality} | "
                f"Anomalies Detected: {anomaly_count} | "
                f"Overall: {vision.overall_assessment}"
            )

        # Build research summary from the first three differentials.
        research_summary = "No differential diagnosis generated."
        if research and research.differential_diagnoses:
            top_dx = research.differential_diagnoses[:3]
            dx_list = " | ".join(d.condition_name for d in top_dx)
            confidences = "/".join(f"{d.match_probability:.0f}%" for d in top_dx)
            research_summary = f"Top Differentials: {dx_list} | Confidence: {confidences}"

        return FinalReport(
            patient_metadata=state.intake_output.extracted_demographics if state.intake_output else {},
            sections=report_draft,
            vision_summary=vision_summary,
            research_summary=research_summary,
            overall_severity=overall_severity,
            agent_pipeline_status=state.agent_statuses,
            generation_timestamp=datetime.now(),
        )

    def _notify(self, state: PipelineState) -> None:
        """Invoke status callback if registered (used for SSE/UI polling)."""
        if not self.on_status_update:
            return
        try:
            self.on_status_update(state)
        except Exception as e:
            # Best-effort: a broken observer must never break the pipeline.
            logger.warning("⚠️ Status callback failed: %s", e)