| import os |
| import logging |
| import cv2 |
| import numpy as np |
| from PIL import Image |
| import torch |
| import json |
| from datetime import datetime |
| import tensorflow as tf |
| from transformers import pipeline |
| from ultralytics import YOLO |
| from tensorflow.keras.models import load_model |
| from langchain_community.document_loaders import PyPDFLoader |
| from langchain.text_splitter import RecursiveCharacterTextSplitter |
| from langchain_community.embeddings import HuggingFaceEmbeddings |
| from langchain_community.vectorstores import FAISS |
| from huggingface_hub import HfApi, HfFolder |
| import spaces |
| import time |
| from typing import Dict, Any, Optional, Tuple |
|
|
| from .config import Config |
|
|
class EnhancedAIProcessor:
    """Wound-image analysis pipeline with dashboard-oriented result tracking.

    On construction it tries to load:
    - a MedGemma image-text-to-text pipeline;
    - a YOLO wound detector;
    - a Keras segmentation model;
    - a Hugging Face wound classifier;
    - a sentence-embedding model;
    - a FAISS index over the PDFs listed in ``Config.GUIDELINE_PDFS``.

    A component that fails to load is left out of ``models_cache``. Steps that
    need it then either fall back (report generation, segmentation,
    classification) or raise (detection).
    """

    # Words that negate a free-text patient answer, e.g. "Non-diabetic", "Not present".
    _NEGATION_WORDS = frozenset({"no", "non", "not", "none", "absent", "negative"})

    # Keywords shared by the risk score and the fallback report; "diabet"
    # matches both "diabetic" and "diabetes".
    _DIABETIC_KEYWORDS = ("yes", "diabet")

    # (label, patient_info key, suffix) rows rendered by _format_patient_info.
    _PATIENT_FIELDS = (
        ("Name", "patient_name", ""),
        ("Age", "patient_age", " years"),
        ("Gender", "patient_gender", ""),
        ("Wound Location", "wound_location", ""),
        ("Wound Duration", "wound_duration", ""),
        ("Pain Level", "pain_level", "/10"),
        ("Diabetic Status", "diabetic_status", ""),
        ("Infection Signs", "infection_signs", ""),
        ("Previous Treatment", "previous_treatment", ""),
        ("Medical History", "medical_history", ""),
        ("Current Medications", "medications", ""),
        ("Known Allergies", "allergies", ""),
    )

    def __init__(self):
        """Create the processor and eagerly load all models and the knowledge base."""
        self.models_cache = {}  # model key -> loaded model / pipeline object
        self.knowledge_base_cache = {}  # holds 'vectorstore' (FAISS index or None)
        self.config = Config()
        self.px_per_cm = self.config.PIXELS_PER_CM  # pixels per centimetre, used for measurements
        self.model_version = "v1.2.0"
        self._initialize_models()

    # ------------------------------------------------------------------ #
    # Model / knowledge-base loading
    # ------------------------------------------------------------------ #

    def _load_model(self, key: str, label: str, factory) -> None:
        """Call ``factory()`` and cache the result under ``key``; log and skip on failure."""
        try:
            self.models_cache[key] = factory()
            logging.info(f"✅ {label} loaded successfully")
        except Exception as e:
            logging.warning(f"{label} not available: {e}")

    @spaces.GPU(enable_queue=True, duration=90)
    def _initialize_models(self):
        """Load every model best-effort, then build the guideline knowledge base."""
        try:
            if self.config.HF_TOKEN:
                HfFolder.save_token(self.config.HF_TOKEN)
                logging.info("HuggingFace token set successfully")

            self._load_model(
                "medgemma_pipe",
                "MedGemma pipeline",
                lambda: pipeline(
                    "image-text-to-text",
                    model="google/medgemma-4b-it",
                    torch_dtype=torch.bfloat16,
                    offload_folder="offload",
                    device_map="auto",
                    token=self.config.HF_TOKEN,
                ),
            )
            self._load_model(
                "det",
                "YOLO detection model",
                lambda: YOLO(self.config.YOLO_MODEL_PATH),
            )
            self._load_model(
                "seg",
                "Segmentation model",
                lambda: load_model(self.config.SEG_MODEL_PATH, compile=False),
            )
            self._load_model(
                "cls",
                "Wound classification model",
                lambda: pipeline(
                    "image-classification",
                    model="Hemg/Wound-classification",
                    token=self.config.HF_TOKEN,
                    device="cpu",
                ),
            )
            self._load_model(
                "embedding_model",
                "Embedding model",
                lambda: HuggingFaceEmbeddings(
                    model_name="sentence-transformers/all-MiniLM-L6-v2",
                    model_kwargs={'device': 'cpu'},
                ),
            )

            # Report what actually loaded rather than claiming every model did.
            loaded = ", ".join(self.models_cache) or "none"
            logging.info(f"✅ Model initialization finished; loaded: {loaded}")
            self._load_knowledge_base()

        except Exception as e:
            logging.error(f"Error initializing AI models: {e}")

    def _load_knowledge_base(self):
        """Build a FAISS vector store from the PDFs in ``Config.GUIDELINE_PDFS``.

        Each PDF is loaded independently, so one unreadable file does not
        discard the others. When no documents were loaded, when the embedding
        model is missing, or on any other error, the store is set to ``None``.
        """
        self.knowledge_base_cache['vectorstore'] = None
        try:
            documents = []
            for pdf_path in self.config.GUIDELINE_PDFS:
                if not os.path.exists(pdf_path):
                    logging.warning(f"Guideline PDF not found: {pdf_path}")
                    continue
                try:
                    documents.extend(PyPDFLoader(pdf_path).load())
                    logging.info(f"Loaded PDF: {pdf_path}")
                except Exception as e:
                    logging.warning(f"Failed to load PDF {pdf_path}: {e}")

            embedding_model = self.models_cache.get('embedding_model')
            if not documents or embedding_model is None:
                logging.warning("Knowledge base not available - no PDFs found or embedding model unavailable")
                return

            text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
            chunks = text_splitter.split_documents(documents)
            self.knowledge_base_cache['vectorstore'] = FAISS.from_documents(chunks, embedding_model)
            logging.info(f"✅ Knowledge base loaded with {len(chunks)} chunks")

        except Exception as e:
            logging.warning(f"Knowledge base loading error: {e}")
            self.knowledge_base_cache['vectorstore'] = None

    # ------------------------------------------------------------------ #
    # Small parsing helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _to_number(value: Any, default: float = 0.0) -> float:
        """Coerce a form value (int, float, numeric string, None) to float, else ``default``."""
        try:
            return float(value)
        except (TypeError, ValueError):
            return default

    @classmethod
    def _mentions(cls, value: Any, keywords: Tuple[str, ...]) -> bool:
        """Return True if free-text ``value`` affirms one of ``keywords``.

        Keywords are matched as lowercase substrings. A negating word
        ("no", "non", "not", ...) turns the answer into False, so "Non-diabetic"
        and "Not present" are not treated as positive. An explicit "yes" word
        wins over negations when "yes" is one of the keywords.
        """
        text = str(value or "").lower()
        if not any(keyword in text for keyword in keywords):
            return False
        words = set("".join(ch if ch.isalnum() else " " for ch in text).split())
        if "yes" in keywords and "yes" in words:
            return True
        return not (words & cls._NEGATION_WORDS)

    @staticmethod
    def _count_guideline_sources(guideline_context: str) -> int:
        """Count the ``[Result i/N]`` entries produced by ``query_guidelines``.

        Status messages (knowledge base unavailable, no results, errors)
        contain no such entries and therefore count as 0.
        """
        if not guideline_context:
            return 0
        return sum(1 for line in guideline_context.splitlines() if line.startswith("[Result "))

    # ------------------------------------------------------------------ #
    # End-to-end analysis
    # ------------------------------------------------------------------ #

    def perform_comprehensive_analysis(self, image_pil: Image.Image, patient_info: Dict[str, Any]) -> Dict[str, Any]:
        """
        Run visual analysis, guideline retrieval, report generation and risk scoring.

        Returns a dict for dashboard integration. If any step raises an
        ``Exception``, the exception is not propagated. Instead a reduced dict
        is returned with ``error``, ``processing_time``, ``risk_score`` (0),
        ``model_version`` and ``analysis_timestamp``.
        """
        start_time = time.perf_counter()  # monotonic clock for durations

        try:
            visual_results = self.perform_visual_analysis(image_pil)

            guideline_query = f"wound care {visual_results.get('wound_type', 'general')} treatment recommendations"
            guideline_context = self.query_guidelines(guideline_query)

            report = self.generate_final_report(patient_info, visual_results, guideline_context, image_pil)
            risk_score = self._calculate_risk_score(visual_results, patient_info)
            processing_time = round(time.perf_counter() - start_time, 2)

            analysis_data = {
                'visual_results': visual_results,
                'patient_info': patient_info,
                'guideline_context': guideline_context,
                'report': report,
                'processing_time': processing_time,
                'risk_score': risk_score,
                'model_version': self.model_version,
                'analysis_timestamp': datetime.now().isoformat(),
                'analysis_metadata': {
                    # Keys of every model that loaded, not only those this run touched.
                    'models_used': list(self.models_cache.keys()),
                    'image_dimensions': image_pil.size,
                    'guideline_sources': self._count_guideline_sources(guideline_context),
                },
            }

            logging.info(f"✅ Comprehensive analysis completed in {processing_time}s with risk score {risk_score}")
            return analysis_data

        except Exception as e:
            processing_time = round(time.perf_counter() - start_time, 2)
            logging.exception(f"❌ Analysis failed after {processing_time}s: {e}")

            return {
                'error': str(e),
                'processing_time': processing_time,
                'risk_score': 0,
                'model_version': self.model_version,
                'analysis_timestamp': datetime.now().isoformat(),
            }

    # ------------------------------------------------------------------ #
    # Visual analysis
    # ------------------------------------------------------------------ #

    def perform_visual_analysis(self, image_pil: Image.Image) -> Dict[str, Any]:
        """Detect, segment, measure and classify the wound in ``image_pil``.

        Steps:
        - take the first YOLO box and crop it;
        - save a detection overlay (and a segmentation overlay, if the model
          is available) under ``<UPLOADS_DIR>/analysis``;
        - derive length/breadth/area in cm from the segmentation mask using
          ``px_per_cm``;
        - classify the crop.

        Raises:
            ValueError: if the detector is missing, no wound is detected, the
                box is empty, or any other step raises.
        """
        try:
            # PIL images may be RGBA, grayscale ("L") or palette ("P"); normalise
            # to 3-channel RGB before converting to OpenCV's BGR order.
            image_cv = cv2.cvtColor(np.array(image_pil.convert("RGB")), cv2.COLOR_RGB2BGR)

            if 'det' not in self.models_cache:
                raise ValueError("YOLO detection model not available.")

            results = self.models_cache['det'].predict(image_cv, verbose=False, device="cpu")
            if not results or not results[0].boxes:
                raise ValueError("No wound detected in the image.")

            first_box = results[0].boxes[0]
            detection_confidence = float(first_box.conf.cpu().item())

            # Clamp to the image so slicing never wraps (negative indices) or
            # runs past the edge; reject boxes that collapse to nothing.
            img_h, img_w = image_cv.shape[:2]
            box = np.clip(first_box.xyxy[0].cpu().numpy().astype(int), 0, [img_w, img_h, img_w, img_h])
            x1, y1, x2, y2 = (int(v) for v in box)
            if x2 <= x1 or y2 <= y1:
                raise ValueError("Detected wound region is empty.")
            region_cv = image_cv[y1:y2, x1:x2]

            analysis_dir = os.path.join(self.config.UPLOADS_DIR, "analysis")
            os.makedirs(analysis_dir, exist_ok=True)
            # Microsecond resolution so back-to-back analyses don't overwrite each other's images.
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S_%f')

            detection_image_cv = image_cv.copy()
            cv2.rectangle(detection_image_cv, (x1, y1), (x2, y2), (0, 255, 0), 2)
            detection_image_path = os.path.join(analysis_dir, f"detection_{timestamp}.png")
            cv2.imwrite(detection_image_path, detection_image_cv)
            detection_image_pil = Image.fromarray(cv2.cvtColor(detection_image_cv, cv2.COLOR_BGR2RGB))

            segmentation = self._segment_region(region_cv, analysis_dir, timestamp)
            wound_type, classification_confidence, classification_scores = self._classify_region(region_cv)
            segmentation_confidence = segmentation['segmentation_confidence']

            return {
                'wound_type': wound_type,
                'length_cm': segmentation['length_cm'],
                'breadth_cm': segmentation['breadth_cm'],
                'surface_area_cm2': segmentation['surface_area_cm2'],
                'detection_confidence': detection_confidence,
                'segmentation_confidence': segmentation_confidence,
                'classification_confidence': classification_confidence,
                'classification_scores': classification_scores,
                'bounding_box': box.tolist(),
                'detection_image_path': detection_image_path,
                'detection_image_pil': detection_image_pil,
                'segmentation_image_path': segmentation['segmentation_image_path'],
                'segmentation_image_pil': segmentation['segmentation_image_pil'],
                'analysis_quality': {
                    'detection_quality': 'high' if detection_confidence > 0.8 else 'medium',
                    'segmentation_quality': 'high' if segmentation_confidence > 0.7 else 'medium',
                    'classification_quality': 'high' if classification_confidence > 0.8 else 'medium',
                },
            }

        except Exception as e:
            logging.error(f"Visual analysis error: {e}")
            raise ValueError(f"Visual analysis failed: {str(e)}") from e

    def _segment_region(self, region_cv: np.ndarray, analysis_dir: str, timestamp: str) -> Dict[str, Any]:
        """Segment the cropped wound, save an overlay and measure the largest contour.

        Best-effort: this mirrors how classification failures are handled. If
        the model is missing or any step raises, zero measurements and ``None``
        image fields are returned.
        """
        outcome = {
            'length_cm': 0,
            'breadth_cm': 0,
            'surface_area_cm2': 0,
            'segmentation_confidence': 0.0,
            'segmentation_image_path': None,
            'segmentation_image_pil': None,
        }
        if 'seg' not in self.models_cache:
            return outcome

        try:
            seg_model = self.models_cache['seg']
            in_h, in_w = seg_model.input_shape[1:3]
            resized_region = cv2.resize(region_cv, (in_w, in_h))
            # NOTE(review): the crop is fed in OpenCV BGR order scaled to [0, 1];
            # confirm this matches how the segmentation model was trained.
            seg_input = np.expand_dims(resized_region / 255.0, 0)
            prob_map = seg_model.predict(seg_input, verbose=0)[0][:, :, 0]
            mask = (prob_map > 0.5).astype(np.uint8)

            # Mean probability over pixels classified as wound. The plain mean
            # over the whole crop mostly measures how much of the crop is wound.
            confidence = float(prob_map[mask == 1].mean()) if mask.any() else 0.0

            mask_full = cv2.resize(mask, (region_cv.shape[1], region_cv.shape[0]), interpolation=cv2.INTER_NEAREST)

            overlay = region_cv.copy()
            overlay[mask_full == 1] = [0, 0, 255]  # red in BGR
            blended = cv2.addWeighted(region_cv, 0.7, overlay, 0.3, 0)

            image_path = os.path.join(analysis_dir, f"segmentation_{timestamp}.png")
            cv2.imwrite(image_path, blended)
            image_pil = Image.fromarray(cv2.cvtColor(blended, cv2.COLOR_BGR2RGB))

            length = breadth = area = 0
            contours, _ = cv2.findContours(mask_full, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            if contours:
                largest = max(contours, key=cv2.contourArea)
                _, _, box_w, box_h = cv2.boundingRect(largest)
                length = round(box_h / self.px_per_cm, 2)
                breadth = round(box_w / self.px_per_cm, 2)
                area = round(cv2.contourArea(largest) / (self.px_per_cm ** 2), 2)

            # Commit all values together so a mid-way failure never yields a mix.
            outcome.update(
                length_cm=length,
                breadth_cm=breadth,
                surface_area_cm2=area,
                segmentation_confidence=confidence,
                segmentation_image_path=image_path,
                segmentation_image_pil=image_pil,
            )
        except Exception as e:
            logging.warning(f"Segmentation error: {e}")

        return outcome

    def _classify_region(self, region_cv: np.ndarray) -> Tuple[str, float, list]:
        """Classify the BGR wound crop.

        Returns ``(label, confidence, all_scores)``. The result is
        ``("Unknown", 0.0, [])`` when the classifier is missing, returns
        nothing, or raises.
        """
        wound_type, confidence, scores = "Unknown", 0.0, []
        if 'cls' not in self.models_cache:
            return wound_type, confidence, scores

        try:
            region_pil = Image.fromarray(cv2.cvtColor(region_cv, cv2.COLOR_BGR2RGB))
            predictions = self.models_cache['cls'](region_pil)
            if predictions:
                best = max(predictions, key=lambda p: p['score'])
                wound_type = best['label']
                confidence = float(best['score'])
                scores = [{'label': p['label'], 'score': float(p['score'])} for p in predictions]
        except Exception as e:
            logging.warning(f"Wound classification error: {e}")

        return wound_type, confidence, scores

    # ------------------------------------------------------------------ #
    # Risk scoring
    # ------------------------------------------------------------------ #

    def _calculate_risk_score(self, visual_results: Dict[str, Any], patient_info: Dict[str, Any]) -> int:
        """
        Calculate a heuristic risk score (0-100) from visual analysis and patient data.

        Points awarded:
        - wound area: 5-25
        - wound type keywords: 10-20
        - age: 5-15
        - diabetes: 0 or 15
        - pain level: 3-10
        - infection signs: 5-15

        The total is clamped to [0, 100]. Numeric fields tolerate strings and
        ``None``; text fields are matched negation-aware, so "Non-diabetic"
        does not count as diabetic. Returns 50 if scoring itself fails.
        """
        try:
            risk_score = 0

            area = self._to_number(visual_results.get('surface_area_cm2'))
            if area > 10:
                risk_score += 25
            elif area > 5:
                risk_score += 15
            elif area > 2:
                risk_score += 10
            else:
                risk_score += 5

            wound_type = str(visual_results.get('wound_type') or '').lower()
            high_risk_types = ('ulcer', 'necrotic', 'infected', 'diabetic')
            medium_risk_types = ('pressure', 'venous', 'arterial')
            if any(risk_type in wound_type for risk_type in high_risk_types):
                risk_score += 20
            elif any(risk_type in wound_type for risk_type in medium_risk_types):
                risk_score += 15
            else:
                risk_score += 10

            age = self._to_number(patient_info.get('patient_age'))
            if age > 70:
                risk_score += 15
            elif age > 50:
                risk_score += 10
            else:
                risk_score += 5

            if self._mentions(patient_info.get('diabetic_status'), self._DIABETIC_KEYWORDS):
                risk_score += 15

            pain_level = self._to_number(patient_info.get('pain_level'))
            if pain_level > 7:
                risk_score += 10
            elif pain_level > 4:
                risk_score += 7
            else:
                risk_score += 3

            infection_signs = patient_info.get('infection_signs')
            if self._mentions(infection_signs, ('yes', 'present')):
                risk_score += 15
            elif self._mentions(infection_signs, ('possible', 'mild')):
                risk_score += 10
            else:
                risk_score += 5

            risk_score = min(max(risk_score, 0), 100)

            logging.info(f"Calculated risk score: {risk_score}")
            return risk_score

        except Exception as e:
            logging.error(f"Error calculating risk score: {e}")
            return 50

    # ------------------------------------------------------------------ #
    # Guideline retrieval
    # ------------------------------------------------------------------ #

    def query_guidelines(self, query: str, top_k: int = 10) -> str:
        """Retrieve up to ``top_k`` guideline chunks relevant to ``query``.

        Returns the chunks as ``[Result i/N] Source: ..., Page: ...`` entries
        separated by blank lines. When the knowledge base is unavailable,
        nothing matches, or retrieval raises, a human-readable status message
        is returned instead.
        """
        try:
            vector_store = self.knowledge_base_cache.get("vectorstore")
            if vector_store is None:
                return "Knowledge base unavailable - clinical guidelines not loaded"

            retriever = vector_store.as_retriever(search_kwargs={"k": top_k})
            docs = retriever.invoke(query)
            if not docs:
                return "No relevant guidelines found for the query"

            # Label with the real hit count; the retriever may return fewer than top_k.
            total = len(docs)
            formatted_results = [
                f"[Result {rank}/{total}] Source: {doc.metadata.get('source', 'Unknown')}, "
                f"Page: {doc.metadata.get('page', 'N/A')}\nContent: {doc.page_content.strip()}"
                for rank, doc in enumerate(docs, start=1)
            ]

            guideline_text = "\n\n".join(formatted_results)
            logging.info(f"Retrieved {total} guideline documents for query: {query[:50]}...")
            return guideline_text

        except Exception as e:
            logging.error(f"Guidelines query error: {e}")
            return f"Error querying guidelines: {str(e)}"

    # ------------------------------------------------------------------ #
    # Report generation
    # ------------------------------------------------------------------ #

    def _build_report_prompt(self, patient_info: Dict[str, Any], visual_results: Dict[str, Any],
                             guideline_context: str) -> str:
        """Assemble the MedGemma user prompt.

        Guidelines are truncated to 2000 characters; "..." is appended only
        when text was actually cut.
        """
        analysis_quality = visual_results.get('analysis_quality', {})
        guidelines = guideline_context or ""
        guideline_excerpt = guidelines[:2000] + ("..." if len(guidelines) > 2000 else "")
        return f"""
# SmartHeal AI Wound Care Report

## Patient Information
{self._format_patient_info(patient_info)}

## Visual Analysis Summary
- Wound Type: {visual_results.get('wound_type', 'Unknown')} (Confidence: {visual_results.get('classification_confidence', 0):.2f})
- Dimensions: {visual_results.get('length_cm', 0)} × {visual_results.get('breadth_cm', 0)} cm
- Surface Area: {visual_results.get('surface_area_cm2', 0)} cm²
- Detection Quality: {analysis_quality.get('detection_quality', 'medium')}
- Segmentation Quality: {analysis_quality.get('segmentation_quality', 'medium')}

## Clinical Reference Guidelines
{guideline_excerpt}

## Analysis Request
You are SmartHeal-AI Agent, a specialized wound care AI with expertise in clinical assessment and evidence-based treatment planning.

Based on the comprehensive data provided (patient information, precise wound measurements, clinical guidelines, and visual analysis), generate a structured clinical report with the following sections:

### 1. Clinical Assessment
- Detailed wound characterization based on visual analysis
- Tissue type assessment (granulation, slough, necrotic, epithelializing)
- Peri-wound skin condition evaluation
- Infection risk assessment

### 2. Treatment Recommendations
- Specific wound care dressing recommendations based on wound characteristics
- Topical treatments if indicated
- Debridement recommendations if needed
- Pressure offloading strategies if applicable

### 3. Risk Stratification
- Patient-specific risk factors analysis
- Healing prognosis assessment
- Complications to monitor

### 4. Follow-up Plan
- Recommended assessment frequency
- Key monitoring parameters
- Escalation criteria for specialist referral

Generate a concise, evidence-based report suitable for clinical documentation.
"""

    @staticmethod
    def _extract_generated_text(generated_content: Any) -> str:
        """Pull the assistant reply out of the pipeline's ``generated_text``.

        Chat-format output is a list of messages. The last assistant message
        is used, and its text parts are concatenated. Any other shape is
        stringified.
        """
        if isinstance(generated_content, list):
            for message in reversed(generated_content):
                if message.get('role') == 'assistant':
                    content = message.get('content', '')
                    if isinstance(content, list):
                        return ''.join(item.get('text', '') for item in content if item.get('type') == 'text')
                    return str(content)
        return str(generated_content)

    @spaces.GPU(enable_queue=True, duration=90)
    def generate_final_report(self, patient_info: Dict[str, Any], visual_results: Dict[str, Any],
                              guideline_context: str, image_pil: Image.Image,
                              max_new_tokens: Optional[int] = None) -> str:
        """Generate a clinical report with MedGemma, falling back to a template.

        Images are sent first, followed by the prompt, in this order:
        original, detection overlay, segmentation overlay (missing ones are
        skipped). ``max_new_tokens`` defaults to ``Config.MAX_NEW_TOKENS``.
        If the pipeline is missing or generation raises, the result of
        ``_generate_fallback_report`` is returned.
        """
        try:
            if 'medgemma_pipe' not in self.models_cache:
                return self._generate_fallback_report(patient_info, visual_results, guideline_context)

            max_tokens = max_new_tokens or self.config.MAX_NEW_TOKENS
            analysis_quality = visual_results.get('analysis_quality', {})
            prompt = self._build_report_prompt(patient_info, visual_results, guideline_context)

            # Build images-then-text explicitly; positional insert() misplaced
            # images after the prompt whenever an earlier image was missing.
            candidate_images = (
                image_pil,
                visual_results.get('detection_image_pil'),
                visual_results.get('segmentation_image_pil'),
            )
            content_list = [{"type": "image", "image": img} for img in candidate_images if img is not None]
            content_list.append({"type": "text", "text": prompt})

            messages = [
                {
                    "role": "system",
                    "content": [{"type": "text", "text": "You are a specialized medical AI assistant for wound care with expertise in clinical assessment, treatment planning, and evidence-based recommendations. Provide structured, actionable clinical reports."}],
                },
                {
                    "role": "user",
                    "content": content_list,
                },
            ]

            output = self.models_cache['medgemma_pipe'](
                text=messages,
                max_new_tokens=max_tokens,
                do_sample=False,
            )
            report_text = self._extract_generated_text(output[0]['generated_text'])

            report_with_metadata = f"""
{report_text}

---
**Report Metadata:**
- Generated by: SmartHeal AI v{self.model_version}
- Analysis Quality: Detection ({analysis_quality.get('detection_quality', 'medium')}), Segmentation ({analysis_quality.get('segmentation_quality', 'medium')})
- Generated at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
"""

            logging.info("✅ MedGemma report generated successfully")
            return report_with_metadata

        except Exception as e:
            logging.exception(f"MedGemma report generation error: {e}")
            return self._generate_fallback_report(patient_info, visual_results, guideline_context)

    def _format_patient_info(self, patient_info: Dict[str, Any]) -> str:
        """Render patient fields as ``- Label: value`` lines, using 'N/A' for missing keys."""
        lines = [
            f"- {label}: {patient_info.get(key, 'N/A')}{suffix}"
            for label, key, suffix in self._PATIENT_FIELDS
        ]
        return "\n".join(lines).strip()

    def _generate_fallback_report(self, patient_info: Dict[str, Any], visual_results: Dict[str, Any],
                                  guideline_context: str) -> str:
        """Generate a template report when MedGemma is unavailable or fails.

        ``guideline_context`` is accepted for signature parity with the
        MedGemma path but is not used here. Numeric patient fields tolerate
        strings and ``None``, so this method is safe to call from
        ``generate_final_report``'s error handler.
        """
        wound_type = visual_results.get('wound_type', 'Unknown')
        length = visual_results.get('length_cm', 0)
        breadth = visual_results.get('breadth_cm', 0)
        area = visual_results.get('surface_area_cm2', 0)

        risk_factors = []
        if self._to_number(patient_info.get('patient_age')) > 65:
            risk_factors.append("Advanced age")
        if self._mentions(patient_info.get('diabetic_status'), self._DIABETIC_KEYWORDS):
            risk_factors.append("Diabetes mellitus")
        if self._to_number(patient_info.get('pain_level')) > 6:
            risk_factors.append("High pain level")
        if self._to_number(area) > 5:
            risk_factors.append("Large wound size")

        if risk_factors:
            risk_factor_lines = "\n".join(f"- {factor}" for factor in risk_factors)
        else:
            risk_factor_lines = "- No significant risk factors identified"

        report = f"""
# SmartHeal AI Wound Assessment Report

## Clinical Summary
**Patient:** {patient_info.get('patient_name', 'N/A')}, {patient_info.get('patient_age', 'N/A')} years old {patient_info.get('patient_gender', '')}

**Wound Characteristics:**
- Type: {wound_type}
- Location: {patient_info.get('wound_location', 'N/A')}
- Dimensions: {length} × {breadth} cm (Area: {area} cm²)
- Duration: {patient_info.get('wound_duration', 'N/A')}
- Pain Level: {patient_info.get('pain_level', 'N/A')}/10

## Risk Assessment
**Identified Risk Factors:**
{risk_factor_lines}

## Treatment Recommendations
**Wound Care:**
- Regular wound assessment and documentation
- Appropriate dressing selection based on wound characteristics
- Maintain moist wound environment
- Monitor for signs of infection

**Patient Management:**
- Pain management as indicated
- Nutritional assessment and optimization
- Patient education on wound care

## Follow-up Plan
- Reassess wound in 1-2 weeks
- Monitor for signs of healing or deterioration
- Consider specialist referral if no improvement in 4 weeks

---
**Report Generated by:** SmartHeal AI Fallback System v{self.model_version}
**Generated at:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
**Note:** This is a basic assessment. For comprehensive analysis, ensure all AI models are properly loaded.
"""

        logging.info("✅ Fallback report generated")
        return report
|
|
|
|