| """ |
| MedGemma Agent - LLM agent with tool calling and staged thinking feedback |
| |
| Pipeline: MedGemma independent exam → Tools (MONET/ConvNeXt/GradCAM) → MedGemma reconciliation → Management |
| """ |
|
|
| import sys |
| import time |
| import random |
| import json |
| import os |
| import subprocess |
| import threading |
| from typing import Optional, Generator, Dict, Any |
| from PIL import Image |
|
|
|
|
class MCPClient:
    """
    Minimal MCP client that talks to a FastMCP subprocess over stdio.

    Speaks raw newline-delimited JSON-RPC 2.0, so this process does not need
    the ``mcp`` library itself. The server script is launched with the same
    interpreter that runs this process (``sys.executable``).

    The server's stdout is drained by a daemon thread into a queue. Waiting on
    the queue (instead of ``select()`` on the pipe) keeps lines that are
    already sitting in Python's read buffer visible and also works on
    platforms where ``select()`` does not support pipes.
    """

    # Sentinel queued by the reader thread once the server's stdout hits EOF.
    _EOF = object()

    def __init__(self):
        self._process = None
        self._lock = threading.Lock()   # serialises request/response pairs
        self._id_counter = 0
        self._messages = None           # queue.Queue of raw stdout lines
        self._reader = None             # thread feeding self._messages

    def _next_id(self) -> int:
        """Return the next JSON-RPC request id (1, 2, 3, ...)."""
        self._id_counter += 1
        return self._id_counter

    def _send(self, obj: dict):
        """
        Write one JSON-RPC message to the server as a single line.

        Raises:
            RuntimeError: if the server is not running or the pipe is broken.
        """
        process = self._process
        if process is None or process.stdin is None:
            raise RuntimeError("MCP server is not running; call start() first")
        try:
            process.stdin.write(json.dumps(obj) + "\n")
            process.stdin.flush()
        except (OSError, ValueError) as exc:
            # OSError covers BrokenPipeError; ValueError is a closed file.
            raise RuntimeError(f"Failed to send request to MCP server: {exc}") from exc

    @staticmethod
    def _pump_stdout(stream, sink):
        """Reader-thread body: forward every stdout line to *sink*, then EOF."""
        try:
            for line in stream:
                sink.put(line)
        except (OSError, ValueError):
            pass  # pipe closed during shutdown; EOF is still signalled below
        finally:
            sink.put(MCPClient._EOF)

    def _recv(self, timeout: int = 300, expected_id: Optional[int] = None) -> dict:
        """
        Wait for the next JSON-RPC response from the server.

        Notifications, server-initiated requests (anything with a ``method``),
        blank lines and non-JSON lines are skipped.

        Args:
            timeout: Maximum number of seconds to wait.
            expected_id: When given, only the response whose ``id`` matches is
                returned; other responses are discarded as stale. An error
                response with a null ``id`` is always returned so the caller
                can surface it.

        Raises:
            RuntimeError: on timeout, server exit, or closed connection.
        """
        import queue

        process = self._process
        if process is None or self._messages is None:
            raise RuntimeError("MCP server is not running; call start() first")

        deadline = time.monotonic() + timeout
        while True:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                raise RuntimeError(
                    f"MCP server did not respond within {timeout}s"
                )
            try:
                # Wake up at least every 5s so a dead server is noticed.
                line = self._messages.get(timeout=min(remaining, 5))
            except queue.Empty:
                if process.poll() is not None:
                    raise RuntimeError(
                        f"MCP server exited with code {process.returncode}"
                    )
                continue

            if line is self._EOF:
                # Leave the sentinel in place for any later caller.
                self._messages.put(self._EOF)
                if process.poll() is not None:
                    raise RuntimeError(
                        f"MCP server exited with code {process.returncode}"
                    )
                raise RuntimeError("MCP server closed connection unexpectedly")

            line = line.strip()
            if not line:
                continue
            try:
                msg = json.loads(line)
            except json.JSONDecodeError:
                continue  # stray non-protocol output on stdout

            if not isinstance(msg, dict) or "method" in msg or "id" not in msg:
                continue  # notification or server-initiated request
            if msg["id"] is None and "error" in msg:
                return msg  # server could not attribute the error to a request
            if expected_id is not None and msg["id"] != expected_id:
                continue  # response to some earlier, abandoned request
            return msg

    def _initialize(self):
        """
        Send the MCP initialize handshake.

        Raises:
            RuntimeError: if the server answers ``initialize`` with an error.
        """
        req_id = self._next_id()
        self._send({
            "jsonrpc": "2.0",
            "id": req_id,
            "method": "initialize",
            "params": {
                "protocolVersion": "2024-11-05",
                "capabilities": {},
                "clientInfo": {"name": "SkinProAI", "version": "1.0.0"},
            },
        })
        response = self._recv(expected_id=req_id)
        if "error" in response:
            raise RuntimeError(f"MCP initialize failed: {response['error']}")

        # Notification (no id): tells the server the handshake is complete.
        self._send({
            "jsonrpc": "2.0",
            "method": "notifications/initialized",
            "params": {},
        })

    def start(self):
        """
        Spawn the MCP server subprocess and complete the handshake.

        Calling this while a server is already running is a no-op. If the
        handshake fails the subprocess is stopped before the error propagates.
        """
        import queue

        if self._process is not None:
            if self._process.poll() is None:
                return
            self.stop()  # previous server died; clean up before respawning

        root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        server_script = os.path.join(root, "mcp_server", "server.py")

        env = os.environ.copy()
        # NOTE(review): presumably read by the server to pick the device its
        # tools run on — confirm against mcp_server/server.py.
        env["SKINPRO_TOOL_DEVICE"] = "cpu"

        self._process = subprocess.Popen(
            [sys.executable, server_script],
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=None,            # server stderr goes to our stderr
            text=True,
            encoding="utf-8",       # MCP stdio messages are UTF-8
            bufsize=1,
            env=env,
        )
        self._messages = queue.Queue()
        self._reader = threading.Thread(
            target=self._pump_stdout,
            args=(self._process.stdout, self._messages),
            daemon=True,
        )
        self._reader.start()

        try:
            self._initialize()
        except Exception:
            self.stop()
            raise

    def call_tool_sync(self, tool_name: str, arguments: dict) -> dict:
        """
        Call a tool synchronously and return the parsed result dict.

        The first text content item of the tool result is decoded as JSON.

        Raises:
            RuntimeError: if the server is not running, the call fails at the
                protocol or tool level, or the result is not JSON text.
        """
        with self._lock:
            req_id = self._next_id()
            self._send({
                "jsonrpc": "2.0",
                "id": req_id,
                "method": "tools/call",
                "params": {"name": tool_name, "arguments": arguments},
            })
            response = self._recv(expected_id=req_id)

        # Protocol-level error (unknown tool, invalid params, ...).
        if "error" in response:
            raise RuntimeError(
                f"MCP tool '{tool_name}' failed: {response['error']}"
            )

        result = response.get("result") or {}
        content_text = None
        for item in result.get("content") or []:
            if isinstance(item, dict) and "text" in item:
                content_text = item["text"]
                break

        # Tool-level error: the tool ran but reported a failure.
        if result.get("isError"):
            raise RuntimeError(
                f"MCP tool '{tool_name}' error: {content_text or 'no details provided'}"
            )
        if content_text is None:
            raise RuntimeError(f"MCP tool '{tool_name}' returned no text content")

        try:
            return json.loads(content_text)
        except json.JSONDecodeError as exc:
            raise RuntimeError(
                f"MCP tool '{tool_name}' returned non-JSON content: {exc}"
            ) from exc

    def stop(self):
        """Terminate the MCP server subprocess (no-op when not running)."""
        process = self._process
        if process is None:
            return
        reader = self._reader
        self._process = None
        self._reader = None
        self._messages = None

        try:
            if process.stdin is not None:
                process.stdin.close()
        except (OSError, ValueError):
            pass  # pipe already broken or closed

        try:
            process.terminate()
            process.wait(timeout=5)
        except subprocess.TimeoutExpired:
            # Server ignored SIGTERM; do not leave it running.
            process.kill()
            process.wait()
        except OSError:
            pass  # process already gone

        if reader is not None:
            reader.join(timeout=1)
        try:
            if process.stdout is not None:
                process.stdout.close()
        except (OSError, ValueError):
            pass
|
|
|
|
| |
# Present-participle verbs used to vary the "<verb> image..." status line.
ANALYSIS_VERBS = [
    "Analyzing",
    "Examining",
    "Processing",
    "Inspecting",
    "Evaluating",
    "Scanning",
    "Assessing",
    "Reviewing",
    "Studying",
    "Interpreting",
]
|
|
| |
| COMPREHENSIVE_EXAM_PROMPT = """Perform a systematic dermoscopic examination of this skin lesion. Assess ALL of the following in a SINGLE concise analysis: |
| |
| 1. PATTERN: Overall architecture, symmetry (symmetric/asymmetric), organization |
| 2. COLORS: List all colors present (brown, black, blue, white, red, pink) and distribution |
| 3. BORDER: Sharp vs gradual, regular vs irregular, any disruptions |
| 4. STRUCTURES: Pigment network, dots/globules, streaks, blue-white veil, regression, vessels |
| |
| Then provide: |
| - Top 3 differential diagnoses with brief reasoning |
| - Concern level (1-5, where 5=urgent) |
| - Single most important feature driving your assessment |
| |
| Be CONCISE - focus on clinically relevant findings only.""" |
|
|
|
|
def get_verb():
    """Pick one entry of ANALYSIS_VERBS at random for the status line."""
    verb = random.choice(ANALYSIS_VERBS)
    return verb
|
|
|
|
class MedGemmaAgent:
    """
    Medical image analysis agent built around a MedGemma pipeline.

    - Streams staged, tag-delimited progress text ([STAGE], [THINKING],
      [RESPONSE], [TOOL_OUTPUT], [RESULT], [ERROR], [COMPLETE], ...)
    - Calls analysis tools (MONET, ConvNeXt, Grad-CAM, guideline search,
      image comparison) through an MCP server subprocess
    - Keeps the results of the most recent analysis for follow-up calls
    """

    def __init__(self, verbose: bool = True):
        self.verbose = verbose
        self.pipe = None
        self.model_id = "google/medgemma-4b-it"
        self.loaded = False

        # In-process tool objects; only populated by load_tools().
        self.monet_tool = None
        self.convnext_tool = None
        self.gradcam_tool = None
        self.rag_tool = None
        self.tools_loaded = False

        # MCP client used by the streaming pipeline; set by load_tools_via_mcp().
        self.mcp_client = None

        # Results of the most recent analysis (cleared by reset_state()).
        self.last_diagnosis = None
        self.last_monet_result = None
        self.last_image = None
        self.last_medgemma_exam = None
        self.last_reconciliation = None
        self.last_management_response = None

    def reset_state(self):
        """Reset analysis state for new analysis (keeps models loaded)"""
        self.last_diagnosis = None
        self.last_monet_result = None
        self.last_image = None
        self.last_medgemma_exam = None
        self.last_reconciliation = None
        self.last_management_response = None

    def close(self):
        """Stop the MCP server subprocess, if one was started."""
        if self.mcp_client is not None:
            self.mcp_client.stop()
            self.mcp_client = None
            # In-process tools (if any) are unaffected by stopping MCP.
            self.tools_loaded = self.monet_tool is not None

    def _print(self, message: str):
        """Print if verbose"""
        if self.verbose:
            print(message, flush=True)

    # ------------------------------------------------------------------
    # Small shared helpers
    # ------------------------------------------------------------------

    def _generate(self, prompt: str, image=None, max_new_tokens: int = 200) -> str:
        """Run one single-turn user message through the pipeline; return the reply text."""
        content = []
        if image is not None:
            content.append({"type": "image", "image": image})
        content.append({"type": "text", "text": prompt})
        messages = [{"role": "user", "content": content}]
        output = self.pipe(messages, max_new_tokens=max_new_tokens)
        return output[0]["generated_text"][-1]["content"]

    @staticmethod
    def _stream_response(text: str, delay: float) -> Generator[str, None, None]:
        """Yield *text* word by word inside [RESPONSE] tags, pausing *delay* seconds per word."""
        yield "[RESPONSE]\n"
        words = text.split()
        last = len(words) - 1
        for i, word in enumerate(words):
            time.sleep(delay)
            yield word + (" " if i < last else "")
        yield "\n[/RESPONSE]\n"

    @staticmethod
    def _score_bar(value: float, width: int) -> str:
        """Render *value* as a fixed-width text bar; the fill is clamped to [0, width]."""
        filled = max(0, min(width, int(value * width)))
        return "|" + "=" * filled + "-" * (width - filled) + "|"

    @staticmethod
    def _format_top_monet_features(monet_result: Optional[Dict[str, Any]], limit: int = 5) -> str:
        """Return the *limit* highest-scoring MONET features as 'name: NN%' pairs."""
        features = (monet_result or {}).get("features") or {}
        ranked = sorted(features.items(), key=lambda item: item[1], reverse=True)[:limit]
        return ", ".join(
            f"{name.replace('MONET_', '').replace('_', ' ')}: {score:.0%}"
            for name, score in ranked
        )

    @staticmethod
    def _has_relevant_guidelines(diagnosis: str, references) -> bool:
        """
        Decide whether any retrieved reference is relevant to *diagnosis*.

        A reference counts when its source matches a melanoma or actinic
        diagnosis by keyword, or when its retrieval score exceeds 0.7.
        """
        diagnosis_lower = diagnosis.lower()
        for ref in references:
            source_lower = str(ref.get('source', '')).lower()
            if 'melanoma' in diagnosis_lower and 'melanoma' in source_lower:
                return True
            if 'actinic' in diagnosis_lower and 'actinic' in source_lower:
                return True
            if ref.get('score', 0) > 0.7:
                return True
        return False

    @staticmethod
    def _parse_action(response: str):
        """
        Extract ``(action, justification)`` from the model's classification reply.

        The action is the earliest category named on the first non-empty line,
        falling back to the earliest one anywhere in the reply, and finally to
        "FOLLOWUP". The justification is the second non-empty line, or the
        remainder of the first line when there is no second line.
        """
        import re

        pattern = re.compile(r"BIOPSY|EXCISION|FOLLOW[\s_-]?UP|DISCHARGE", re.IGNORECASE)
        lines = [ln.strip() for ln in response.splitlines() if ln.strip()]
        first_line = lines[0] if lines else ""
        justification = lines[1] if len(lines) > 1 else ""

        match = pattern.search(first_line)
        if match is not None and not justification:
            justification = first_line[match.end():].lstrip(" :-–—.*").strip()
        if match is None:
            match = pattern.search(response)
        if match is None:
            return "FOLLOWUP", justification

        action = match.group().upper()
        if action.startswith("FOLLOW"):
            action = "FOLLOWUP"
        return action, justification

    @staticmethod
    def _parse_change_level(text: str) -> str:
        """
        Extract the change level from the model's comparison reply.

        Prefers the value following a "CHANGE_LEVEL" label. Otherwise scans
        for the level tokens in order of clinical severity, matching whole
        words only so that e.g. "unstable" is not read as STABLE. Returns
        "UNKNOWN" when no level is found.
        """
        import re

        labelled = re.search(
            r"CHANGE[\s_-]*LEVEL\W*"
            r"(SIGNIFICANT[\s_-]+CHANGE|MINOR[\s_-]+CHANGE|STABLE|IMPROVED)(?![A-Za-z])",
            text,
            re.IGNORECASE,
        )
        if labelled is not None:
            return re.sub(r"[\s_-]+", "_", labelled.group(1).upper())

        for level in ("SIGNIFICANT_CHANGE", "MINOR_CHANGE", "STABLE", "IMPROVED"):
            if re.search(r"(?<![A-Za-z])" + level + r"(?![A-Za-z])", text, re.IGNORECASE):
                return level
        return "UNKNOWN"

    # ------------------------------------------------------------------
    # Loading
    # ------------------------------------------------------------------

    def load_model(self):
        """
        Load the MedGemma image-text-to-text pipeline (no-op once loaded).

        Logs in to the Hugging Face Hub when the HF_TOKEN environment
        variable is set.
        """
        if self.loaded:
            return

        self._print("Initializing MedGemma agent...")

        # Heavy imports are deferred until the model is actually needed.
        import torch
        from transformers import pipeline, AutoProcessor

        hf_token = os.environ.get("HF_TOKEN")
        if hf_token:
            from huggingface_hub import login
            login(token=hf_token, add_to_git_credential=False)
            self._print("Authenticated with HF Hub")
        else:
            self._print("Warning: HF_TOKEN not set — gated models will fail")

        self._print(f"Loading model: {self.model_id}")

        # Informational only: placement is decided by device_map="auto" below.
        if torch.cuda.is_available():
            self._print(f"Using GPU: {torch.cuda.get_device_name(0)}")
        elif torch.backends.mps.is_available():
            self._print("Using Apple Silicon (MPS)")
        else:
            self._print("Using CPU")

        model_kwargs = dict(
            dtype=torch.bfloat16,
            device_map="auto",
        )

        start = time.time()
        processor = AutoProcessor.from_pretrained(
            self.model_id, use_fast=True, token=hf_token,
        )
        self.pipe = pipeline(
            "image-text-to-text",
            model=self.model_id,
            model_kwargs=model_kwargs,
            token=hf_token,
            image_processor=processor.image_processor,
            tokenizer=processor.tokenizer,
        )

        # NOTE(review): presumably cleared so the default max_length does not
        # clash with the max_new_tokens passed on every call — confirm.
        if hasattr(self.pipe.model, "generation_config"):
            self.pipe.model.generation_config.max_length = None

        self._print(f"Model loaded in {time.time() - start:.1f}s")
        self.loaded = True

    def load_tools(self):
        """Load tool models in-process (MONET + ConvNeXt + Grad-CAM + RAG)"""
        if self.tools_loaded:
            return

        from models.monet_tool import MonetTool
        self.monet_tool = MonetTool()
        self.monet_tool.load()

        from models.convnext_classifier import ConvNeXtClassifier
        self.convnext_tool = ConvNeXtClassifier()
        self.convnext_tool.load()

        from models.gradcam_tool import GradCAMTool
        self.gradcam_tool = GradCAMTool(classifier=self.convnext_tool)
        self.gradcam_tool.load()

        from models.guidelines_rag import get_guidelines_rag
        self.rag_tool = get_guidelines_rag()
        if not self.rag_tool.loaded:
            self.rag_tool.load_index()

        self.tools_loaded = True

    def load_tools_via_mcp(self):
        """
        Start the MCP server subprocess and mark tools as loaded.

        No-op when an MCP client already exists. The client is only stored
        once its handshake has succeeded, so a failed start can be retried.
        """
        if self.mcp_client is not None:
            return
        client = MCPClient()
        client.start()
        self.mcp_client = client
        self._print("MCP server started successfully")
        self.tools_loaded = True

    # ------------------------------------------------------------------
    # Analysis pipeline
    # ------------------------------------------------------------------

    def _multi_pass_visual_exam(self, image, question: Optional[str] = None) -> Generator[str, None, Dict[str, str]]:
        """
        MedGemma performs comprehensive visual examination BEFORE tools run.
        Single prompt covers pattern, colors, borders, structures, and differentials.

        Yields the streamed output. The findings dict (key 'synthesis') is the
        generator's return value and is also stored in self.last_medgemma_exam.
        """
        findings = {}

        yield "\n[STAGE:medgemma_exam]MedGemma Visual Examination[/STAGE]\n"
        yield "[THINKING]Performing systematic dermoscopic assessment...[/THINKING]\n"

        exam_prompt = COMPREHENSIVE_EXAM_PROMPT
        if question:
            exam_prompt += f"\n\nCLINICAL QUESTION: {question}"

        try:
            time.sleep(0.2)
            result = self._generate(exam_prompt, image, max_new_tokens=200)
            findings['synthesis'] = result
            yield from self._stream_response(result, 0.015)
        except Exception as e:
            findings['synthesis'] = f"Analysis failed: {e}"
            yield f"[ERROR]Visual examination failed: {e}[/ERROR]\n"

        self.last_medgemma_exam = findings
        return findings

    def _reconcile_findings(
        self,
        image,
        medgemma_exam: Dict[str, str],
        monet_result: Dict[str, Any],
        convnext_result: Dict[str, Any],
        question: Optional[str] = None
    ) -> Generator[str, None, None]:
        """
        MedGemma reconciles its independent findings with tool outputs.
        Identifies agreements, disagreements, and provides integrated assessment.

        The reply is stored in self.last_reconciliation. Failures are reported
        as an [ERROR] chunk rather than raised.
        """
        yield "\n[STAGE:reconciliation]Reconciling MedGemma Findings with Tool Results[/STAGE]\n"
        yield "[THINKING]Comparing independent visual assessment against AI classification tools...[/THINKING]\n"

        try:
            predictions = convnext_result.get('predictions') or []
            if not predictions:
                raise ValueError("classifier returned no predictions")
            top = predictions[0]
            runner_up = predictions[1] if len(predictions) > 1 else None

            monet_str = self._format_top_monet_features(monet_result)
            runner_up_line = (
                f"- Runner-up: {runner_up['full_name']} ({runner_up['probability']:.1%})"
                if runner_up else ""
            )
            question_line = f"CLINICAL QUESTION: {question}" if question else ""
            own_assessment = medgemma_exam.get('synthesis', 'Not available')[:600]

            reconciliation_prompt = f"""You performed an independent visual examination of this lesion and concluded:

YOUR ASSESSMENT:
{own_assessment}

The AI classification tools produced these results:
- ConvNeXt classifier: {top['full_name']} ({top['probability']:.1%} confidence)
{runner_up_line}
- Key MONET features: {monet_str}

{question_line}

Reconcile your visual findings with the AI classification:
1. AGREEMENT/DISAGREEMENT: Do your findings support the AI diagnosis? Any conflicts?
2. INTEGRATED ASSESSMENT: Final diagnosis considering all evidence
3. CONFIDENCE (1-10): How certain? What would change your assessment?

Be concise and specific."""

            reconciliation = self._generate(reconciliation_prompt, image, max_new_tokens=150)
            self.last_reconciliation = reconciliation
            yield from self._stream_response(reconciliation, 0.015)
        except Exception as e:
            yield f"[ERROR]Reconciliation failed: {e}[/ERROR]\n"

    def analyze_image_stream(
        self,
        image_path: str,
        question: Optional[str] = None,
        max_tokens: int = 512,
        use_tools: bool = True
    ) -> Generator[str, None, None]:
        """
        Stream analysis with new pipeline:
        1. MedGemma independent multi-pass exam
        2. MONET + ConvNeXt + GradCAM tools
        3. MedGemma reconciliation
        4. Management guidance

        Results of any earlier analysis are cleared once the new image has
        loaded, so follow-up calls never mix two lesions. Tool failures are
        reported as [ERROR] chunks and the remaining stages still run.

        Args:
            image_path: Path of the image to analyze.
            question: Optional clinical question appended to the prompts.
            max_tokens: Accepted for interface compatibility; each stage
                currently uses its own fixed token budget.
            use_tools: When False, only the MedGemma examination runs.
        """
        if not self.loaded:
            yield "[STAGE:loading]Initializing MedGemma...[/STAGE]\n"
            self.load_model()

        yield f"[STAGE:image]{get_verb()} image...[/STAGE]\n"

        try:
            image = Image.open(image_path).convert("RGB")
        except Exception as e:
            yield f"[ERROR]Failed to load image: {e}[/ERROR]\n"
            return

        self.reset_state()
        self.last_image = image

        # Keyed on the client rather than tools_loaded: in-process tools set
        # tools_loaded too, but this pipeline needs the MCP client itself.
        if use_tools and self.mcp_client is None:
            yield "[STAGE:tools]Loading analysis tools...[/STAGE]\n"
            try:
                self.load_tools_via_mcp()
            except Exception as e:
                yield f"[ERROR]Failed to start analysis tools: {e}[/ERROR]\n"

        # `yield from` forwards every chunk and captures the generator's
        # return value (the findings dict).
        medgemma_exam = yield from self._multi_pass_visual_exam(image, question)
        medgemma_exam = medgemma_exam or self.last_medgemma_exam or {}

        monet_result = None
        convnext_result = None

        if use_tools and self.mcp_client is not None:
            yield "\n[STAGE:tools_run]Running AI Classification Tools[/STAGE]\n"
            yield "[THINKING]Now running MONET and ConvNeXt to compare against visual examination...[/THINKING]\n"

            # --- MONET feature extraction ---
            time.sleep(0.2)
            yield "\n[STAGE:monet]MONET Feature Extraction[/STAGE]\n"

            try:
                monet_result = self.mcp_client.call_tool_sync(
                    "monet_analyze", {"image_path": image_path}
                )
                self.last_monet_result = monet_result

                yield "[TOOL_OUTPUT:MONET Features]\n"
                for name, score in monet_result["features"].items():
                    short_name = name.replace("MONET_", "").replace("_", " ").title()
                    yield f"  {short_name}: {self._score_bar(score, 10)} {score:.0%}\n"
                yield "[/TOOL_OUTPUT]\n"
            except Exception as e:
                yield f"[ERROR]MONET failed: {e}[/ERROR]\n"

            # --- ConvNeXt classification (fed the MONET vector when available) ---
            time.sleep(0.2)
            yield "\n[STAGE:convnext]ConvNeXt Classification[/STAGE]\n"

            try:
                monet_scores = monet_result.get("vector") if monet_result else None
                convnext_result = self.mcp_client.call_tool_sync(
                    "classify_lesion",
                    {
                        "image_path": image_path,
                        "monet_scores": monet_scores,
                    },
                )
                self.last_diagnosis = convnext_result

                yield "[TOOL_OUTPUT:Classification Results]\n"
                for pred in convnext_result["predictions"][:5]:
                    prob = pred['probability']
                    yield f"  {pred['class']}: {self._score_bar(prob, 20)} {prob:.1%}\n"
                    yield f"    {pred['full_name']}\n"
                yield "[/TOOL_OUTPUT]\n"

                top = convnext_result['predictions'][0]
                yield f"[RESULT]ConvNeXt Primary: {top['full_name']} ({top['probability']:.1%})[/RESULT]\n"
            except Exception as e:
                yield f"[ERROR]ConvNeXt failed: {e}[/ERROR]\n"

            # --- Grad-CAM attention map ---
            time.sleep(0.2)
            yield "\n[STAGE:gradcam]Grad-CAM Attention Map[/STAGE]\n"

            try:
                gradcam_result = self.mcp_client.call_tool_sync(
                    "generate_gradcam", {"image_path": image_path}
                )
                gradcam_path = gradcam_result["gradcam_path"]
                yield f"[GRADCAM_IMAGE:{gradcam_path}]\n"
            except Exception as e:
                yield f"[ERROR]Grad-CAM failed: {e}[/ERROR]\n"

        # Reconciliation needs all three independent inputs.
        if convnext_result and monet_result and medgemma_exam:
            yield from self._reconcile_findings(
                image, medgemma_exam, monet_result, convnext_result, question
            )

        if convnext_result and self.mcp_client:
            yield from self.generate_management_guidance()

    def generate_management_guidance(
        self,
        user_confirmed: bool = True,
        user_feedback: Optional[str] = None
    ) -> Generator[str, None, None]:
        """
        Generate LESION-SPECIFIC management guidance using RAG + MedGemma reasoning.
        References specific findings from this analysis, not generic textbook management.

        When the clinician rejects the AI diagnosis (user_confirmed=False with
        user_feedback), the feedback is used as the diagnosis and the
        classifier's result is passed to the model as a suggestion only.

        A failed guideline search or generation is reported as an [ERROR]
        chunk. self.last_management_response holds the generated plan, or
        None when generation failed.
        """
        if not self.last_diagnosis:
            yield "[ERROR]No diagnosis available. Please analyze an image first.[/ERROR]\n"
            return

        predictions = self.last_diagnosis.get('predictions') or []
        if not predictions:
            yield "[ERROR]No diagnosis available. Please analyze an image first.[/ERROR]\n"
            return

        top = predictions[0]
        runner_up = predictions[1] if len(predictions) > 1 else None
        diagnosis = top['full_name']

        overridden = (not user_confirmed) and bool(user_feedback)
        if overridden:
            yield f"[THINKING]Clinician provided alternative assessment: {user_feedback}[/THINKING]\n"
            diagnosis = user_feedback

        time.sleep(0.3)
        yield f"\n[STAGE:guidelines]Searching clinical guidelines for {diagnosis}...[/STAGE]\n"

        # Guideline retrieval is best-effort: on any failure the plan is
        # generated without guideline context.
        features_desc = self.last_monet_result.get('description', '') if self.last_monet_result else ''
        context = ""
        references = []
        references_display = ""
        if self.mcp_client is None:
            yield "[ERROR]Guideline search unavailable: analysis tools are not running[/ERROR]\n"
        else:
            try:
                rag_data = self.mcp_client.call_tool_sync(
                    "search_guidelines",
                    {"query": features_desc, "diagnosis": diagnosis},
                )
                context = rag_data.get("context") or ""
                references = rag_data.get("references") or []
                references_display = rag_data.get("references_display") or ""
            except Exception as e:
                yield f"[ERROR]Guideline search failed: {e}[/ERROR]\n"

        if not references or not self._has_relevant_guidelines(diagnosis, references):
            yield f"[THINKING]No specific published guidelines for {diagnosis}. Using clinical knowledge.[/THINKING]\n"
            context = "No specific clinical guidelines available."
            references = []

        monet_features = self._format_top_monet_features(self.last_monet_result)

        time.sleep(0.3)
        yield "\n[STAGE:management]Generating Lesion-Specific Management Plan[/STAGE]\n"
        yield "[THINKING]Creating management plan tailored to THIS lesion's specific characteristics...[/THINKING]\n"

        if overridden:
            # The classifier's probability belongs to its own diagnosis, not
            # to the clinician's, so it is reported on a separate line.
            diagnosis_line = f"DIAGNOSIS: {diagnosis} (clinician assessment)"
            alternative_line = f"AI classifier suggested: {top['full_name']} ({top['probability']:.1%})"
        else:
            diagnosis_line = f"DIAGNOSIS: {diagnosis} ({top['probability']:.1%})"
            alternative_line = (
                f"Alternative: {runner_up['full_name']} ({runner_up['probability']:.1%})"
                if runner_up else ""
            )
        guidelines_line = f"GUIDELINES: {context[:800]}" if context else ""

        management_prompt = f"""Generate a CONCISE management plan for this lesion:

{diagnosis_line}
{alternative_line}
KEY FEATURES: {monet_features}

{guidelines_line}

Provide:
1. RECOMMENDED ACTION: Biopsy, excision, monitoring, or discharge - with specific reasoning
2. URGENCY: Routine vs urgent vs same-day referral
3. KEY CONCERNS: What features drive this recommendation

Be specific to THIS lesion. 3-5 sentences maximum."""

        start = time.time()
        response = None
        try:
            response = self._generate(management_prompt, self.last_image, max_new_tokens=150)
            yield from self._stream_response(response, 0.015)
        except Exception as e:
            yield f"[ERROR]Management generation failed: {e}[/ERROR]\n"

        if references and references_display:
            yield references_display

        yield f"\n[COMPLETE]Lesion-specific management plan generated in {time.time() - start:.1f}s[/COMPLETE]\n"

        # None after a failed generation, so extract_recommendation() cannot
        # act on a plan from an earlier run.
        self.last_management_response = response

    def extract_recommendation(self) -> Generator[str, None, Dict[str, Any]]:
        """
        Extract structured recommendation from management guidance.
        Determines: BIOPSY, EXCISION, FOLLOWUP, or DISCHARGE

        Yields progress chunks; the generator's return value is a dict with
        'action' and 'justification' (or 'action': 'UNKNOWN' on failure).
        """
        if not self.last_management_response or self.last_image is None:
            yield "[ERROR]No management guidance available[/ERROR]\n"
            return {"action": "UNKNOWN"}

        yield "\n[STAGE:recommendation]Extracting Clinical Recommendation[/STAGE]\n"

        classification_prompt = f"""Based on the management plan you just provided:

{self.last_management_response[:1000]}

Classify the PRIMARY recommended action into exactly ONE of these categories:
- BIOPSY: If punch biopsy, shave biopsy, or incisional biopsy is recommended
- EXCISION: If complete surgical excision is recommended
- FOLLOWUP: If monitoring with repeat photography/dermoscopy is recommended
- DISCHARGE: If the lesion is clearly benign and no follow-up needed

Respond with ONLY the category name (BIOPSY, EXCISION, FOLLOWUP, or DISCHARGE) on the first line.
Then on the second line, provide a brief (1 sentence) justification."""

        try:
            response = self._generate(
                classification_prompt, self.last_image, max_new_tokens=100
            ).strip()
            action, justification = self._parse_action(response)

            yield f"[RESULT]Recommended Action: {action}[/RESULT]\n"
            yield f"[OBSERVATION]{justification}[/OBSERVATION]\n"

            return {
                "action": action,
                "justification": justification
            }
        except Exception as e:
            yield f"[ERROR]Failed to extract recommendation: {e}[/ERROR]\n"
            return {"action": "UNKNOWN", "error": str(e)}

    def compare_followup_images(
        self,
        previous_image_path: str,
        current_image_path: str
    ) -> Generator[str, None, None]:
        """
        Compare a follow-up image with the previous one.

        When the MCP client is running, the "compare_images" tool is given
        both paths and its comparison image, Grad-CAM pair and MONET deltas
        are streamed. MedGemma then examines the CURRENT image against the
        text of the previous examination (self.last_medgemma_exam); the
        previous image itself is not shown to the model.
        """
        yield "\n[STAGE:comparison]Follow-up Comparison Analysis[/STAGE]\n"

        try:
            current_image = Image.open(current_image_path).convert("RGB")
        except Exception as e:
            yield f"[ERROR]Failed to load images: {e}[/ERROR]\n"
            return

        if not self.loaded:
            yield "[STAGE:loading]Initializing MedGemma...[/STAGE]\n"
            self.load_model()

        prev_exam = self.last_medgemma_exam

        yield "\n[STAGE:current_analysis]Analyzing Current Image[/STAGE]\n"

        if self.mcp_client is not None:
            try:
                compare_data = self.mcp_client.call_tool_sync(
                    "compare_images",
                    {
                        "image1_path": previous_image_path,
                        "image2_path": current_image_path,
                    },
                )
                yield f"[COMPARISON_IMAGE:{compare_data['comparison_path']}]\n"

                prev_gc = compare_data.get("prev_gradcam_path")
                curr_gc = compare_data.get("curr_gradcam_path")
                if prev_gc and curr_gc:
                    yield f"[GRADCAM_COMPARE:{prev_gc}:{curr_gc}]\n"

                monet_deltas = compare_data.get("monet_deltas")
                if monet_deltas:
                    yield "[TOOL_OUTPUT:Feature Comparison]\n"
                    for name, delta_info in monet_deltas.items():
                        prev_val = delta_info["previous"]
                        curr_val = delta_info["current"]
                        diff = delta_info["delta"]
                        short_name = name.replace("MONET_", "").replace("_", " ").title()
                        if diff > 0:
                            direction = "↑"
                        elif diff < 0:
                            direction = "↓"
                        else:
                            direction = "="  # unchanged: no arrow
                        yield f"  {short_name}: {prev_val:.0%} → {curr_val:.0%} ({direction}{abs(diff):.0%})\n"
                    yield "[/TOOL_OUTPUT]\n"
            except Exception as e:
                yield f"[ERROR]MCP comparison failed: {e}[/ERROR]\n"

        previous_summary = (
            prev_exam.get('synthesis', 'Not available')[:500] if prev_exam else 'Not available'
        )

        comparison_prompt = f"""You are comparing TWO images of the same skin lesion taken at different times.

PREVIOUS ANALYSIS:
{previous_summary}

Now examine the CURRENT image and compare to your memory of the previous findings.

Assess for changes in:
1. SIZE: Has the lesion grown, shrunk, or stayed the same?
2. COLOR: Any new colors appeared? Any colors faded?
3. SHAPE/SYMMETRY: Has the shape changed? More or less symmetric?
4. BORDERS: Sharper, more irregular, or unchanged?
5. STRUCTURES: New dermoscopic structures? Lost structures?

Provide your assessment:
- CHANGE_LEVEL: SIGNIFICANT_CHANGE / MINOR_CHANGE / STABLE / IMPROVED
- Specific changes observed
- Clinical recommendation based on changes"""

        try:
            yield "[THINKING]Comparing current image to previous findings...[/THINKING]\n"
            comparison_result = self._generate(comparison_prompt, current_image, max_new_tokens=200)

            yield from self._stream_response(comparison_result, 0.02)

            change_level = self._parse_change_level(comparison_result)
            if change_level == "SIGNIFICANT_CHANGE":
                yield "[RESULT]⚠️ SIGNIFICANT CHANGES DETECTED - Further evaluation recommended[/RESULT]\n"
            elif change_level == "IMPROVED":
                yield "[RESULT]✓ LESION IMPROVED - Continue monitoring[/RESULT]\n"
            elif change_level == "STABLE":
                yield "[RESULT]✓ LESION STABLE - Continue scheduled follow-up[/RESULT]\n"
            elif change_level == "MINOR_CHANGE":
                yield "[RESULT]Minor changes noted - Clinical correlation recommended[/RESULT]\n"
            else:
                # Do not claim "minor changes" when no level could be parsed.
                yield "[RESULT]Change level could not be determined - Clinical correlation recommended[/RESULT]\n"
        except Exception as e:
            yield f"[ERROR]Comparison analysis failed: {e}[/ERROR]\n"

        yield "\n[COMPLETE]Follow-up comparison complete[/COMPLETE]\n"

    # ------------------------------------------------------------------
    # Chat
    # ------------------------------------------------------------------

    def chat(self, message: str, image_path: Optional[str] = None) -> str:
        """
        Simple chat interface: send one message (optionally with an image)
        and return the model's reply. Loads the model on first use.
        """
        if not self.loaded:
            self.load_model()

        image = Image.open(image_path).convert("RGB") if image_path else None
        return self._generate(message, image, max_new_tokens=200)

    def chat_followup(self, message: str) -> Generator[str, None, None]:
        """
        Handle follow-up questions using the stored analysis context.
        Uses the last analyzed image and diagnosis to provide contextual responses.
        """
        if not self.loaded:
            yield "[ERROR]Model not loaded[/ERROR]\n"
            return

        predictions = (self.last_diagnosis or {}).get('predictions') or []
        if not predictions or self.last_image is None:
            yield "[ERROR]No previous analysis context. Please analyze an image first.[/ERROR]\n"
            return

        top_diagnosis = predictions[0]
        differentials = ", ".join(
            f"{p['class']} ({p['probability']:.0%})" for p in predictions[:3]
        )

        monet_desc = ""
        if self.last_monet_result:
            monet_desc = self.last_monet_result.get('description', '')

        context_prompt = f"""You are a dermatology assistant helping with skin lesion analysis.

PREVIOUS ANALYSIS CONTEXT:
- Primary diagnosis: {top_diagnosis['full_name']} ({top_diagnosis['probability']:.1%} confidence)
- Differential diagnoses: {differentials}
- Visual features: {monet_desc}

The user has a follow-up question about this lesion. Please provide a helpful, medically accurate response.

USER QUESTION: {message}

Provide a concise, informative response. If the question is outside your expertise or requires in-person examination, say so."""

        try:
            yield "[THINKING]Considering your question in context of the previous analysis...[/THINKING]\n"
            time.sleep(0.2)

            response = self._generate(context_prompt, self.last_image, max_new_tokens=200)
            yield from self._stream_response(response, 0.02)
        except Exception as e:
            yield f"[ERROR]Failed to generate response: {e}[/ERROR]\n"
|
|
|
|
def main():
    """
    Interactive terminal interface.

    Commands: ``analyze <path>``, ``chat <message>``, ``quit``. The loop ends
    on quit/exit/q, Ctrl-C, or end of input (Ctrl-D / closed stdin). The MCP
    tool server, if one was started during the session, is stopped on exit.
    """
    print("=" * 60)
    print("  MedGemma Agent - Medical Image Analysis")
    print("=" * 60)

    agent = MedGemmaAgent(verbose=True)
    agent.load_model()

    print("\nCommands: analyze <path>, chat <message>, quit")

    try:
        while True:
            try:
                user_input = input("\n> ").strip()
                if not user_input:
                    continue

                if user_input.lower() in ["quit", "exit", "q"]:
                    break

                parts = user_input.split(maxsplit=1)
                cmd = parts[0].lower()

                if cmd == "analyze" and len(parts) > 1:
                    for chunk in agent.analyze_image_stream(parts[1].strip()):
                        print(chunk, end="", flush=True)
                elif cmd == "chat" and len(parts) > 1:
                    print(agent.chat(parts[1]))
                else:
                    print("Unknown command")

            except (KeyboardInterrupt, EOFError):
                # EOFError must end the loop: once stdin is exhausted every
                # further input() call raises it again.
                break
            except Exception as e:
                print(f"Error: {e}")
    finally:
        # Do not leave the tool server subprocess running after exit.
        client = getattr(agent, "mcp_client", None)
        if client is not None:
            client.stop()


if __name__ == "__main__":
    main()
|
|