File size: 10,841 Bytes
# mediagent/core/pipeline.py
"""
Pipeline Orchestrator for MediAgent.
Manages execution of all medical imaging analysis agents with parallel
scheduling where possible, tracks real-time state, handles graceful error
recovery, and coordinates data flow between specialized AI components.
Parallelism strategy:
- INTAKE + VISION run concurrently (vision only needs the raw image)
- RESEARCH runs after VISION completes (needs findings)
- REPORT runs after all three complete
- CRITIC runs after REPORT
"""
import logging
import time
from concurrent.futures import ThreadPoolExecutor, wait as futures_wait
from datetime import datetime
from typing import Any, Callable, Dict, Optional
from core.models import (
AgentStatus,
FinalReport,
IntakeOutput,
PipelineState,
PatientInput,
ResearchOutput,
ReportSection,
VisionOutput,
)
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class PipelineOrchestrator:
    """
    Central execution engine that routes patient data through the 5-agent
    medical analysis pipeline. Runs INTAKE and VISION in parallel to cut
    wall-clock latency, then sequences RESEARCH → REPORT → CRITIC.

    Every agent is driven solely through its ``process(**kwargs)`` method.
    """

    AGENT_ORDER = ["INTAKE", "VISION", "RESEARCH", "REPORT", "CRITIC"]

    # PipelineState attribute that receives each agent's return value.
    # CRITIC has no entry: its return value is not stored on the state.
    _OUTPUT_FIELDS = {
        "INTAKE": "intake_output",
        "VISION": "vision_output",
        "RESEARCH": "research_output",
        "REPORT": "report_draft",
    }

    # Severity label <-> rank, used to pick the worst finding in a study.
    _SEVERITY_RANK = {"CRITICAL": 4, "SIGNIFICANT": 3, "INCIDENTAL": 2, "NORMAL": 1}
    _SEVERITY_BY_RANK = {4: "CRITICAL", 3: "SIGNIFICANT", 2: "INCIDENTAL", 1: "NORMAL"}

    def __init__(
        self,
        intake_agent: Any,
        vision_agent: Any,
        research_agent: Any,
        report_agent: Any,
        critic_agent: Any,
        on_status_update: Optional[Callable[[PipelineState], None]] = None,
    ):
        """
        Register the five agents and an optional status callback.

        Args:
            intake_agent: Agent invoked for the INTAKE step.
            vision_agent: Agent invoked for the VISION step.
            research_agent: Agent invoked for the RESEARCH step.
            report_agent: Agent invoked for the REPORT step.
            critic_agent: Agent invoked for the CRITIC step.
            on_status_update: Optional callable that receives the live
                PipelineState at the start and end of every step.
        """
        self.agents = {
            "INTAKE": intake_agent,
            "VISION": vision_agent,
            "RESEARCH": research_agent,
            "REPORT": report_agent,
            "CRITIC": critic_agent,
        }
        self.on_status_update = on_status_update

    def run(self, patient_input: PatientInput) -> PipelineState:
        """
        Execute the full diagnostic pipeline with parallel INTAKE+VISION stage.

        An INTAKE or REPORT agent error stops the pipeline early and the
        state is returned without a final report. VISION, RESEARCH and
        CRITIC errors are logged and the pipeline continues. Any exception
        that escapes a step is caught here: ``current_step`` becomes
        "FAILED" and a ``SYSTEM_FAILURE`` entry is appended to ``error_log``.

        Args:
            patient_input: Patient submission; ``symptoms`` and
                ``image_base64`` are read from it here, and the whole
                object is forwarded to the INTAKE agent.

        Returns:
            PipelineState: Complete execution state containing all outputs,
            agent statuses, and final report.
        """
        logger.info("🚀 Pipeline execution started | Input ID: %s", id(patient_input))
        state = PipelineState()
        try:
            # ─────────────────────────────────────────────────────────────
            # STAGE 1: INTAKE + VISION in parallel
            # Vision uses raw symptoms as clinical context so it doesn't
            # have to wait for intake normalization to complete.
            # ─────────────────────────────────────────────────────────────
            raw_context = patient_input.symptoms or ""
            with ThreadPoolExecutor(max_workers=2) as pool:
                intake_fut = pool.submit(
                    self._execute_step, state, "INTAKE",
                    patient_input=patient_input,
                )
                vision_fut = pool.submit(
                    self._execute_step, state, "VISION",
                    image_base64=patient_input.image_base64,
                    clinical_context=raw_context,
                )
                futures_wait([intake_fut, vision_fut])

            # _execute_step isolates agent errors, but anything raised
            # outside its try block (e.g. by the state bookkeeping calls)
            # is captured by the Future. Surface it instead of dropping it.
            for fut in (intake_fut, vision_fut):
                fut.result()

            if state.agent_statuses["INTAKE"] == AgentStatus.ERROR:
                logger.warning("⚠️ Intake failed. Halting pipeline for safety.")
                return state
            if state.agent_statuses["VISION"] == AgentStatus.ERROR:
                logger.warning("⚠️ Vision analysis failed. Continuing with degraded pipeline.")

            # INTAKE may finish without error yet return None; RESEARCH
            # needs its demographics, so fail with an explicit message
            # rather than an AttributeError on None.
            if state.intake_output is None:
                raise RuntimeError("INTAKE completed without producing output")

            # ─────────────────────────────────────────────────────────────
            # STAGE 2: RESEARCH (needs vision findings + intake demographics)
            # ─────────────────────────────────────────────────────────────
            vision = state.vision_output
            self._execute_step(
                state, "RESEARCH",
                vision_findings=vision.findings if vision else [],
                demographics=state.intake_output.extracted_demographics,
                detected_modality=vision.modality_detected.value if vision else "UNKNOWN",
            )
            if state.agent_statuses["RESEARCH"] == AgentStatus.ERROR:
                logger.warning("⚠️ Research failed. Generating report without differential augmentation.")

            # ─────────────────────────────────────────────────────────────
            # STAGE 3: REPORT (synthesizes all three upstream outputs)
            # ─────────────────────────────────────────────────────────────
            self._execute_step(
                state, "REPORT",
                intake=state.intake_output,
                vision=state.vision_output,
                research=state.research_output,
            )
            if state.agent_statuses["REPORT"] == AgentStatus.ERROR:
                logger.error("❌ Report generation failed. Pipeline cannot complete safely.")
                return state

            # ─────────────────────────────────────────────────────────────
            # STAGE 4: CRITIC (QA review of final draft)
            # ─────────────────────────────────────────────────────────────
            self._execute_step(
                state, "CRITIC",
                draft_report=state.report_draft,
                pipeline_state=state,
            )
            if state.agent_statuses["CRITIC"] == AgentStatus.ERROR:
                logger.warning("⚠️ Critic review failed. Finalizing report without QA review.")

            state.final_report = self._assemble_final_report(state)
            # NOTE(review): no _notify() follows, so the status callback
            # never observes "COMPLETE"/"FAILED"; callers must read the
            # returned state. Confirm this is intended.
            state.current_step = "COMPLETE"
            logger.info("✅ Pipeline execution completed successfully.")
        except Exception as e:
            logger.exception("💥 Unhandled pipeline failure: %s", e)
            state.current_step = "FAILED"
            state.error_log.append(f"SYSTEM_FAILURE: {e}")
        return state

    def _execute_step(
        self,
        state: PipelineState,
        agent_name: str,
        **kwargs
    ) -> None:
        """
        Generic step executor with state management, timing, and error isolation.

        Marks the agent as running, calls ``agent.process(**kwargs)``, stores
        a non-None result on the agent's own state field, and marks the step
        done. An exception from the agent is logged and recorded through
        ``state.mark_error`` instead of propagating. The status callback
        fires once before the agent starts and once after it ends.

        Concurrency: INTAKE and VISION run this method at the same time.
        Their outputs go to separate fields, but ``state.current_step`` is
        shared, so during that stage it holds whichever agent wrote last.

        Args:
            state: Live pipeline state, mutated in place.
            agent_name: Key into ``self.agents``.
            **kwargs: Forwarded unchanged to the agent's ``process`` method.
        """
        logger.info("▶️ Executing agent: %s", agent_name)
        state.current_step = agent_name
        state.mark_running(agent_name)
        self._notify(state)
        start_time = time.perf_counter()
        try:
            agent = self.agents[agent_name]
            output = agent.process(**kwargs)
            elapsed = time.perf_counter() - start_time
            target_field = self._OUTPUT_FIELDS.get(agent_name)
            if output is not None and target_field is not None:
                setattr(state, target_field, output)
            state.mark_done(agent_name)
            logger.info("✅ %s completed in %.3fs", agent_name, elapsed)
        except Exception as e:
            elapsed = time.perf_counter() - start_time
            logger.error(
                "❌ %s execution failed after %.3fs: %s",
                agent_name, elapsed, e,
                exc_info=True,
            )
            state.mark_error(agent_name, str(e))
        finally:
            self._notify(state)

    def _assemble_final_report(self, state: PipelineState) -> FinalReport:
        """
        Synthesize all agent outputs into the final deliverable report structure.

        Combines the report draft (an empty ReportSection when missing), a
        one-line vision summary, a one-line summary of the top three
        differentials, and the highest severity among the vision findings.
        The CRITIC step's result is not consulted here.

        Args:
            state: Pipeline state holding the upstream agent outputs.

        Returns:
            FinalReport: Report stamped with the current local time.
        """
        report_draft = state.report_draft or ReportSection()
        vision = state.vision_output
        research = state.research_output

        # Determine overall severity from vision findings. Unrecognized
        # severity values rank as NORMAL.
        overall_severity = "NORMAL"
        if vision and vision.findings:
            highest = max(
                self._SEVERITY_RANK.get(f.severity.value, 1) for f in vision.findings
            )
            overall_severity = self._SEVERITY_BY_RANK.get(highest, "NORMAL")

        # Build vision summary
        vision_summary = "No imaging analysis performed."
        if vision:
            anomaly_count = sum(1 for f in vision.findings if f.is_anomaly)
            vision_summary = (
                f"Modality: {vision.modality_detected.value} | "
                f"Quality: {vision.technical_quality} | "
                f"Anomalies Detected: {anomaly_count} | "
                f"Overall: {vision.overall_assessment}"
            )

        # Build research summary from the first three differentials
        research_summary = "No differential diagnosis generated."
        if research and research.differential_diagnoses:
            top_dx = research.differential_diagnoses[:3]
            dx_list = " | ".join(d.condition_name for d in top_dx)
            confidence = "/".join(f"{d.match_probability:.0f}%" for d in top_dx)
            research_summary = f"Top Differentials: {dx_list} | Confidence: {confidence}"

        return FinalReport(
            patient_metadata=state.intake_output.extracted_demographics if state.intake_output else {},
            sections=report_draft,
            vision_summary=vision_summary,
            research_summary=research_summary,
            overall_severity=overall_severity,
            agent_pipeline_status=state.agent_statuses,
            generation_timestamp=datetime.now(),
        )

    def _notify(self, state: PipelineState) -> None:
        """
        Invoke status callback if registered (used for SSE/UI polling).

        A failing callback is logged as a warning and otherwise ignored so
        that it cannot interrupt the pipeline.
        """
        if self.on_status_update:
            try:
                self.on_status_update(state)
            except Exception as e:
                logger.warning("⚠️ Status callback failed: %s", e)
|