import os import argparse import numpy as np import torch import torch.backends.cudnn as cudnn import gradio as gr import spaces from pathlib import Path from huggingface_hub import hf_hub_download, snapshot_download import matplotlib.pyplot as plt from typing import Tuple, Optional, List, Dict import pandas as pd import tempfile import plotly.graph_objects as go import plotly.express as px from plotly.subplots import make_subplots from scipy import ndimage from skimage import measure import cv2 import json import logging import nibabel as nib from PIL import Image import SimpleITK as sitk from monai.inferers import sliding_window_inference from monai.metrics import compute_dice import monai import time import pickle from monai.transforms import ( Compose, NormalizeIntensityd, CenterSpatialCropd, SpatialPadd, Spacing, CenterSpatialCrop, SpatialPad, Resize ) # Configure matplotlib basic settings plt.rcParams['axes.unicode_minus'] = False # Set up logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) # Cache directory cache_dir = Path(__file__).parent / "cache" cache_dir.mkdir(parents=True, exist_ok=True) # Results cache directory results_cache_dir = cache_dir / "precomputed_results" results_cache_dir.mkdir(parents=True, exist_ok=True) # Device Configuration (compatible with Spaces stateless GPU) # Note: Device will be set inside @spaces.GPU decorated functions device = None # Will be set when GPU is available dtype = torch.float32 # Use full precision (FP32) # GPU optimization flags (will be applied in GPU-decorated functions) _gpu_optimizations_applied = False # Performance optimization flags ENABLE_TORCH_COMPILE = os.getenv("ENABLE_TORCH_COMPILE", "true").lower() == "true" ENABLE_MODEL_WARMUP = os.getenv("ENABLE_MODEL_WARMUP", "true").lower() == "true" ENABLE_PRECOMPUTE = os.getenv("ENABLE_PRECOMPUTE", "true").lower() == "true" # Global cache for precomputed results PRECOMPUTED_RESULTS = {} def save_results_to_cache(results: Dict, cache_file: Path): """Save inference results to cache file""" try: with open(cache_file, 'wb') as f: pickle.dump(results, f) logger.info(f"✅ Successfully saved inference results cache to: {cache_file}") except Exception as e: logger.error(f"❌ Failed to save cache: {e}") def load_results_from_cache(cache_file: Path) -> Dict: """Load inference results from cache file""" try: if cache_file.exists(): with open(cache_file, 'rb') as f: results = pickle.load(f) logger.info(f"✅ Successfully loaded inference results from cache: {cache_file}") return results else: logger.info(f"📁 Cache file does not exist: {cache_file}") return {} except Exception as e: logger.error(f"❌ Failed to load cache: {e}") return {} def tensor_to_serializable(obj): """Convert tensor to serializable format""" if isinstance(obj, torch.Tensor): return { '_type': 'tensor', 'data': obj.cpu().numpy(), 'shape': obj.shape, 'dtype': str(obj.dtype) } elif isinstance(obj, dict): return {k: tensor_to_serializable(v) for k, v in obj.items()} elif isinstance(obj, (list, tuple)): return [tensor_to_serializable(item) for item in obj] else: return obj def serializable_to_tensor(obj): """Convert serializable format back to tensor""" if isinstance(obj, dict): if obj.get('_type') == 'tensor': return torch.from_numpy(obj['data']) else: return {k: serializable_to_tensor(v) for k, v in obj.items()} elif isinstance(obj, list): return [serializable_to_tensor(item) for item in obj] else: return obj @spaces.GPU def precompute_all_results(): """Precompute all inference results for all patients and all tasks""" if not ENABLE_PRECOMPUTE: logger.info("🚫 Precomputation disabled, skipping precomputation process") return {} logger.info("🚀 Starting precomputation of all inference results...") # Apply GPU optimizations device = torch.device("cuda" if torch.cuda.is_available() else "cpu") apply_gpu_optimizations() # Set seed for reproducibility torch.manual_seed(0) np.random.seed(0) cudnn.benchmark = True # Load sample data sample_patients, dataset_root = load_sample_data() all_results = { 'classification': {}, 'segmentation': {}, 'anatomy_segmentation': {} } # Precompute classification results if 'risk' in sample_patients: logger.info("📊 Precomputing classification tasks...") try: cls_model, cls_args = load_classification_model() for patient_idx in range(len(sample_patients['risk'])): patient_id = f"Patient_{patient_idx}" logger.info(f"Processing classification task - {patient_id}") try: # Create dataloader for single patient data_loader, args, dataset = create_single_sample_dataloader(patient_idx, 'classification') preprocessed_data = get_preprocessed_patient_data(patient_idx, 'classification', device) # Run inference cls_model.eval() with torch.no_grad(): for idx, data in enumerate(data_loader): img, gt, pid = data img = img.to(device, dtype=dtype, non_blocking=True) gt = gt.to(device, non_blocking=True) with torch.amp.autocast('cuda', enabled=torch.cuda.is_available()): logit = cls_model(img) # Store results all_results['classification'][patient_id] = { 'prediction': tensor_to_serializable(logit), 'preprocessed_data': tensor_to_serializable(preprocessed_data), 'patient_id': patient_id } break logger.info(f"✅ Classification task completed - {patient_id}") except Exception as e: logger.error(f"❌ Classification task failed - {patient_id}: {e}") continue except Exception as e: logger.error(f"❌ Classification model loading failed: {e}") # Precompute segmentation results if 'UCL' in sample_patients: logger.info("✂️ Precomputing lesion segmentation tasks...") try: seg_model, seg_args = load_segmentation_model() for patient_idx in range(len(sample_patients['UCL'])): patient_id = f"Patient_{patient_idx}" logger.info(f"Processing lesion segmentation task - {patient_id}") try: # Create dataloader for single patient data_loader, args, dataset = create_single_sample_dataloader(patient_idx, 'segmentation') preprocessed_data = get_preprocessed_patient_data(patient_idx, 'segmentation', device) # Run inference seg_model.eval() with torch.no_grad(): for idx, data in enumerate(data_loader): if len(data) == 4: img, gt, pid, gland = data else: img, gt, pid = data gland = None img = img.to(device, dtype=dtype, non_blocking=True) ground_truth_tensor = gt.to(device, non_blocking=True) if gt is not None else None with torch.amp.autocast('cuda', enabled=torch.cuda.is_available()): if args.sliding_window: pred = sliding_window_inference(img, args.crop_spatial_size, 4, seg_model, overlap=0.5) else: pred = seg_model(img) # Precompute slice cache for smooth navigation local_cached_slice_data = precompute_slice_cache(preprocessed_data, pred) # Store results all_results['segmentation'][patient_id] = { 'prediction': tensor_to_serializable(pred), 'preprocessed_data': tensor_to_serializable(preprocessed_data), 'cached_slice_data': local_cached_slice_data, 'patient_id': patient_id, 'ground_truth_tensor': tensor_to_serializable(ground_truth_tensor) if ground_truth_tensor is not None else None } break logger.info(f"✅ Lesion segmentation task completed - {patient_id}") except Exception as e: logger.error(f"❌ Lesion segmentation task failed - {patient_id}: {e}") continue except Exception as e: logger.error(f"❌ Segmentation model loading failed: {e}") # Precompute anatomy segmentation results if 'anatomy' in sample_patients: logger.info("🫀 Precomputing anatomy segmentation tasks...") try: anat_model, anat_args = load_anatomy_segmentation_model() for patient_idx in range(len(sample_patients['anatomy'])): patient_id = f"Patient_{patient_idx}" logger.info(f"Processing anatomy segmentation task - {patient_id}") try: # Create dataloader for single patient data_loader, args, dataset = create_single_sample_dataloader(patient_idx, 'anatomy_segmentation') preprocessed_data = get_preprocessed_patient_data(patient_idx, 'anatomy_segmentation', device) # Run inference anat_model.eval() with torch.no_grad(): for idx, data in enumerate(data_loader): img, gt, pid = data img = img.to(device, dtype=dtype, non_blocking=True) ground_truth_tensor = gt.to(device, non_blocking=True) if gt is not None else None with torch.amp.autocast('cuda', enabled=torch.cuda.is_available()): if args.sliding_window: pred = sliding_window_inference(img, args.crop_spatial_size, 4, anat_model, overlap=0.5) else: pred = anat_model(img) # Store results all_results['anatomy_segmentation'][patient_id] = { 'prediction': tensor_to_serializable(pred), 'preprocessed_data': tensor_to_serializable(preprocessed_data), 'patient_id': patient_id, 'ground_truth_tensor': tensor_to_serializable(ground_truth_tensor) if ground_truth_tensor is not None else None } break logger.info(f"✅ Anatomy segmentation task completed - {patient_id}") except Exception as e: logger.error(f"❌ Anatomy segmentation task failed - {patient_id}: {e}") continue except Exception as e: logger.error(f"❌ Anatomy segmentation model loading failed: {e}") # Save results to cache cache_file = results_cache_dir / "all_results.pkl" save_results_to_cache(all_results, cache_file) logger.info("🎉 All inference results precomputation completed!") return all_results def load_precomputed_results(): """Load precomputed inference results""" global PRECOMPUTED_RESULTS cache_file = results_cache_dir / "all_results.pkl" PRECOMPUTED_RESULTS = load_results_from_cache(cache_file) if not PRECOMPUTED_RESULTS: logger.info("🔄 No precomputed results found, starting precomputation...") logger.info("⏳ This may take a few minutes, please be patient...") PRECOMPUTED_RESULTS = precompute_all_results() logger.info("🎉 Precomputation completed! All inference will now load quickly from cache") else: logger.info(f"✅ Successfully loaded precomputed results, containing:") logger.info(f" 📊 Classification tasks: {len(PRECOMPUTED_RESULTS.get('classification', {}))} patients") logger.info(f" ✂️ Lesion segmentation: {len(PRECOMPUTED_RESULTS.get('segmentation', {}))} patients") logger.info(f" 🫀 Anatomy segmentation: {len(PRECOMPUTED_RESULTS.get('anatomy_segmentation', {}))} patients") logger.info("⚡ All inference will load quickly from cache without GPU computation") def force_recompute_all_results(): """Force recomputation of all inference results (use when cache needs updating)""" global PRECOMPUTED_RESULTS logger.info("🔄 Force recomputing all inference results...") # Delete existing cache file cache_file = results_cache_dir / "all_results.pkl" if cache_file.exists(): cache_file.unlink() logger.info("🗑️ Deleted old cache file") # Recompute and save PRECOMPUTED_RESULTS = precompute_all_results() logger.info("🎉 Recomputation completed!") return PRECOMPUTED_RESULTS def load_pretrain_data(): """Load pre-training dataset from Hugging Face Dataset Hub""" try: # Download dataset if not already cached dataset_path = snapshot_download( repo_id=DATASET_REPO, repo_type="dataset", cache_dir=cache_dir ) dataset_path = Path(dataset_path) pretrain_path = dataset_path / "demo" / "data" / "pretrain" # Check available patients pretrain_patients = [] if pretrain_path.exists(): for patient_dir in pretrain_path.iterdir(): if patient_dir.is_dir() and patient_dir.name.startswith("patient_"): pretrain_patients.append(patient_dir.name) pretrain_patients.sort() # Ensure consistent ordering logger.info(f"Found {len(pretrain_patients)} pre-training patients: {pretrain_patients}") return pretrain_patients, pretrain_path except Exception as e: logger.error(f"Error loading pre-training data: {e}") return [], None def load_pretrain_patient_data(patient_id: str, pretrain_root: Path): """Load all 9 pre-training images for a specific patient""" try: logger.info(f"🔍 Loading pre-training data for {patient_id}") patient_path = pretrain_root / patient_id if not patient_path.exists(): raise ValueError(f"Patient directory not found: {patient_path}") # Define expected file structure modalities = ['T2W', 'DWI', 'ADC'] processing_types = ['original', 'masked', 'reconstructed'] patient_data = {} for modality in modalities: patient_data[modality] = {} for proc_type in processing_types: file_name = f"{modality}_{proc_type}.nii.gz" file_path = patient_path / file_name if file_path.exists(): # Load NIfTI file nii_img = nib.load(str(file_path)) img_data = nii_img.get_fdata() # Convert to tensor for consistency with other parts of the code img_tensor = torch.from_numpy(img_data).float() patient_data[modality][proc_type] = img_tensor logger.info(f"✅ Loaded {file_name}: shape={img_tensor.shape} (XYZ format)") else: logger.warning(f"⚠️ File not found: {file_path}") patient_data[modality][proc_type] = None logger.info(f"Successfully loaded pre-training data for {patient_id}") return patient_data except Exception as e: logger.error(f"Error loading patient data for {patient_id}: {e}") return None def create_pretrain_visualization(patient_data: Dict, slice_idx: int = None) -> plt.Figure: """Create 3x3 grid visualization for pre-training data showing XY plane at different Z slices""" try: if patient_data is None: fig, ax = plt.subplots(figsize=(8, 6)) ax.text(0.5, 0.5, "No patient data loaded", ha='center', va='center', transform=ax.transAxes, fontsize=12) ax.axis('off') return fig modalities = ['T2W', 'DWI', 'ADC'] processing_types = ['original', 'masked', 'reconstructed'] # Determine slice index from T2W original if not specified if slice_idx is None: if patient_data.get('T2W', {}).get('original') is not None: # Data shape is XYZ, so Z dimension is shape[2] slice_idx = patient_data['T2W']['original'].shape[2] // 2 else: slice_idx = 32 # Default fallback # Create 3x3 grid fig, axes = plt.subplots(3, 3, figsize=(12, 10)) for row, modality in enumerate(modalities): for col, proc_type in enumerate(processing_types): ax = axes[row, col] # Get image data img_data = patient_data.get(modality, {}).get(proc_type) if img_data is not None: # Ensure we have 3D data (X, Y, Z) if len(img_data.shape) != 3: logger.warning(f"Unexpected image shape for {modality}_{proc_type}: {img_data.shape}") ax.text(0.5, 0.5, f'{modality}\n{proc_type}\nUnexpected shape: {img_data.shape}', ha='center', va='center', transform=ax.transAxes, fontsize=10) ax.set_title(f'{modality} - {proc_type.capitalize()}', fontsize=10, fontweight='bold') ax.axis('off') continue # Ensure slice index is valid for Z dimension (shape[2] for XYZ format) max_slice = img_data.shape[2] - 1 # Z dimension is the 3rd dimension current_slice = max(0, min(slice_idx, max_slice)) # Extract XY plane at Z=current_slice (from XYZ format) xy_slice = img_data[:, :, current_slice].cpu().numpy() # Shape: (X, Y) # Transpose to get correct orientation (Y, X) for proper display xy_slice = xy_slice.T # Now shape: (Y, X) # Flip vertically to correct up-down orientation xy_slice = np.flipud(xy_slice) # Display image - XY plane im = ax.imshow(xy_slice, cmap='gray', aspect='equal', interpolation='nearest', origin='lower') # Add title for each image ax.set_title(f'{modality} - {proc_type.capitalize()}', fontsize=12, fontweight='bold') else: # Show placeholder for missing data ax.text(0.5, 0.5, f'{modality}\n{proc_type}\n(Not Available)', ha='center', va='center', transform=ax.transAxes, fontsize=10) ax.set_title(f'{modality} - {proc_type.capitalize()}', fontsize=12, fontweight='bold') ax.axis('off') # Adjust layout with padding for titles - fix title clipping issue plt.tight_layout(pad=3.0) plt.subplots_adjust(top=0.88) # Leave space for titles at the top return fig except Exception as e: logger.error(f"Error creating pre-training visualization: {e}") fig, ax = plt.subplots(figsize=(8, 6)) ax.text(0.5, 0.5, f"Visualization error: {str(e)}", ha='center', va='center', transform=ax.transAxes, fontsize=12) ax.axis('off') return fig def run_pretrain_reconstruction(patient_id: str, current_state, progress=gr.Progress()): """Load and display pre-training reconstruction results""" try: progress(0, desc="📂 Loading pre-training data...") # Load pre-training data pretrain_patients, pretrain_root = load_pretrain_data() if not pretrain_patients or pretrain_root is None: progress(1, desc="❌ Pre-training data not found!") fig, ax = plt.subplots(figsize=(8, 6)) ax.text(0.5, 0.5, "Pre-training data not available", ha='center', va='center', transform=ax.transAxes, fontsize=12) ax.axis('off') error_state = {'error': "Pre-training data not available", 'timestamp': time.time()} return fig, gr.update(maximum=63, value=32), error_state if patient_id not in pretrain_patients: progress(1, desc="❌ Patient not found!") fig, ax = plt.subplots(figsize=(8, 6)) ax.text(0.5, 0.5, f"Patient {patient_id} not found", ha='center', va='center', transform=ax.transAxes, fontsize=12) ax.axis('off') error_state = {'error': f"Patient {patient_id} not found", 'timestamp': time.time()} return fig, gr.update(maximum=63, value=32), error_state progress(0.5, desc="🔄 Loading patient images...") # Load patient data patient_data = load_pretrain_patient_data(patient_id, pretrain_root) if patient_data is None: progress(1, desc="❌ Failed to load patient data!") fig, ax = plt.subplots(figsize=(8, 6)) ax.text(0.5, 0.5, f"Failed to load data for {patient_id}", ha='center', va='center', transform=ax.transAxes, fontsize=12) ax.axis('off') error_state = {'error': f"Failed to load data for {patient_id}", 'timestamp': time.time()} return fig, gr.update(maximum=63, value=32), error_state progress(0.8, desc="🎨 Creating XY plane visualization...") # Create visualization middle_slice = 32 # Default middle slice if patient_data.get('T2W', {}).get('original') is not None: # Data shape is XYZ, so Z dimension is shape[2] middle_slice = patient_data['T2W']['original'].shape[2] // 2 pretrain_fig = create_pretrain_visualization(patient_data, middle_slice) # Create state dictionary state_data = { 'patient_id': patient_id, 'patient_data': patient_data, 'pretrain_root': pretrain_root, 'timestamp': time.time() } # Determine maximum slice number (Z dimension is shape[2] for XYZ format) max_slice = 63 # Default if patient_data.get('T2W', {}).get('original') is not None: max_slice = patient_data['T2W']['original'].shape[2] - 1 progress(1, desc="✅ Completed!") return ( pretrain_fig, gr.update(minimum=0, maximum=max_slice, value=middle_slice), state_data ) except Exception as e: logger.error(f"Error in pre-training reconstruction: {e}") fig, ax = plt.subplots(figsize=(8, 6)) ax.text(0.5, 0.5, f"Processing error: {str(e)}", ha='center', va='center', transform=ax.transAxes, fontsize=12) ax.axis('off') error_state = {'error': str(e), 'timestamp': time.time()} return fig, gr.update(maximum=63, value=32), error_state def update_pretrain_slice_with_state(slice_idx: int, state_data): """Update pre-training slice visualization using state data""" try: # Check if state object is valid if state_data is None or 'error' in state_data: error_msg = state_data.get('error', "No data loaded. Please run reconstruction first.") if state_data else "No data loaded. Please run reconstruction first." logger.warning(f"Invalid state data for pre-training slice browser: {error_msg}") fig, ax = plt.subplots(figsize=(8, 6)) ax.text(0.5, 0.5, error_msg, ha='center', va='center', transform=ax.transAxes, fontsize=12) ax.axis('off') return fig # Get data from state patient_data = state_data.get('patient_data') if patient_data is not None: return create_pretrain_visualization(patient_data, slice_idx) # If no valid data is found in the state logger.error("State is invalid. No patient data found.") fig, ax = plt.subplots(figsize=(8, 6)) ax.text(0.5, 0.5, "Invalid data state. Please run reconstruction again.", ha='center', va='center', transform=ax.transAxes, fontsize=12) ax.axis('off') return fig except Exception as e: logger.error(f"Error updating pre-training slice with state: {e}") fig, ax = plt.subplots(figsize=(8, 6)) ax.text(0.5, 0.5, f"Update error: {str(e)}", ha='center', va='center', transform=ax.transAxes, fontsize=12) ax.axis('off') return fig def apply_gpu_optimizations(): """Apply H200 GPU optimizations - only call inside @spaces.GPU decorated functions""" global _gpu_optimizations_applied if _gpu_optimizations_applied or not torch.cuda.is_available(): return try: # H200 GPU Memory and Performance Optimizations torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True torch.backends.cudnn.benchmark = True torch.backends.cudnn.deterministic = False # Enable memory efficient attention if available torch.backends.cuda.enable_math_sdp(True) torch.backends.cuda.enable_flash_sdp(True) torch.backends.cuda.enable_mem_efficient_sdp(True) # Set memory management for H200 torch.cuda.empty_cache() torch.cuda.set_per_process_memory_fraction(0.95) # Use 95% of H200's memory _gpu_optimizations_applied = True logger.info(f"🚀 H200 GPU optimizations enabled: {torch.cuda.get_device_name()}") logger.info(f"💾 GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB") logger.info(f"🔥 Using full precision (FP32) for maximum compatibility") except Exception as e: logger.warning(f"⚠️ Could not apply GPU optimizations: {e}") logger.warning("Continuing without GPU optimizations...") # Model configuration MODEL_REPO = "wxyi088/ProFound" DATASET_REPO = "wxyi088/ProFound" # Import ProFound model classes try: from models.classifier import Classifier from models.convnextv2 import convnextv2_tiny from models.upernet_module import UperNet from models.convnext_unter import ConvnextUNETR from dataset.dataset_cls import build_Risk_loader from dataset.dataset_seg import build_UCL_loader, build_BpAnatomy_loader, BpAnatomySet from engine.classification import test_risk logger.info("Successfully imported ProFound model classes") except ImportError as e: logger.error(f"Could not import ProFound models: {e}") logger.error("Please ensure ProFound package is installed from GitHub") raise ImportError("ProFound package not found. Please install from GitHub repository.") def tuple_type(strings): """Parse tuple type parameters""" strings = strings.replace("(", "").replace(")", "") mapped_int = map(int, strings.split(",")) return tuple(mapped_int) def create_args_for_classification(): """Create arguments for classification task with H200 GPU optimizations""" args = argparse.Namespace() # Exact parameters from demo_run_classification.sh args.batch_size = 1 # Changed to 1 for demo inference args.model = 'profound_conv' args.input_size = (64, 224, 224) # Exact match with demo script args.crop_spatial_size = (64, 224, 224) # Exact match with demo script args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") args.dataset = 'risk' # Exact match args.demo = True # Add the missing attributes that build_Risk_loader sets args.in_channels = 3 args.num_classes = 4 args.seed = 0 args.num_workers = min(8, torch.cuda.device_count() * 4) # Optimized for H200 args.pin_mem = True args.tolerance = 5 args.spacing = (1.0, 0.5, 0.5) args.weight_decay = 1e-5 args.lr = 0.1 args.min_lr = 0.0 args.warmup_epochs = 40 args.epochs = 400 args.train = 'scratch' args.pretrain = None args.root = str(cache_dir) # Root path for data - will be updated when needed args.output_dir = './outputcls' args.log_dir = './outputcls' args.file_name = 'classification_output' args.resume = '' args.start_epoch = 0 args.data20 = False args.data_num = 0 args.save_fig = False args.prompt = False args.world_size = 1 args.local_rank = -1 args.dist_on_itp = False args.dist_url = 'env://' args.kfold = None return args def create_args_for_segmentation(): """Create arguments for segmentation task with H200 GPU optimizations""" args = argparse.Namespace() # Exact parameters from demo_run_segmentation.sh args.batch_size = 1 # Changed to 1 for demo inference args.model = 'profound_conv' args.input_size = (64, 224, 224) # Exact match with demo script args.crop_spatial_size = (64, 224, 224) # Exact match with demo script args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") args.dataset = 'UCL' # Exact match args.demo = True # Add the missing attributes that build_UCL_loader sets args.in_channels = 3 args.out_channels = 1 args.num_classes = 1 args.seed = 0 args.num_workers = min(8, torch.cuda.device_count() * 4) # Optimized for H200 args.pin_mem = True args.sliding_window = False # From demo_segmentation.py args.save_fig = False args.tolerance = 5 args.spacing = (1.0, 0.5, 0.5) args.weight_decay = 1e-5 args.lr = 0.1 args.min_lr = 0.0 args.warmup_epochs = 40 args.epochs = 400 args.train = 'scratch' args.pretrain = None args.root = str(cache_dir) # Root path for data - will be updated when needed args.output_dir = './outputseg' args.log_dir = './outputseg' args.file_name = 'segmentation_output' args.resume = '' args.start_epoch = 0 args.data20 = False args.data_num = 0 args.prompt = False args.world_size = 1 args.local_rank = -1 args.dist_on_itp = False args.dist_url = 'env://' return args def create_args_for_anatomy_segmentation(): """Create arguments for anatomy segmentation task with H200 GPU optimizations""" args = argparse.Namespace() # Similar parameters as segmentation but for anatomy dataset args.batch_size = 1 # Changed to 1 for demo inference args.model = 'profound_conv' args.input_size = (64, 224, 224) # Exact match with demo script args.crop_spatial_size = (64, 224, 224) # Exact match with demo script args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") args.dataset = 'anatomy' # Anatomy dataset args.demo = True # Add the missing attributes that build_BpAnatomy_loader sets args.in_channels = 3 args.out_channels = 9 # 8 classes + background args.num_classes = 8 # 8 anatomical structures args.seed = 0 args.num_workers = min(8, torch.cuda.device_count() * 4) # Optimized for H200 args.pin_mem = True args.sliding_window = False # From demo_segmentation.py args.save_fig = False args.tolerance = 5 args.spacing = (1.0, 0.5, 0.5) args.weight_decay = 1e-5 args.lr = 0.1 args.min_lr = 0.0 args.warmup_epochs = 40 args.epochs = 400 args.train = 'scratch' args.pretrain = None args.root = str(cache_dir) # Root path for data - will be updated when needed args.output_dir = './outputanat' args.log_dir = './outputanat' args.file_name = 'anatomy_segmentation_output' args.resume = '' args.start_epoch = 0 args.data20 = False args.data_num = 0 args.prompt = False args.world_size = 1 args.local_rank = -1 args.dist_on_itp = False args.dist_url = 'env://' return args def load_classification_model(): """Load classification model with H200 GPU optimizations""" try: # Download model weights model_path = hf_hub_download( repo_id=MODEL_REPO, filename="checkpoint/classification.pth.tar", cache_dir=cache_dir ) # Create model exactly like demo args = create_args_for_classification() convnext = convnextv2_tiny(in_chans=3) model = Classifier(convnext, args.num_classes) # Load weights - use CPU first for compatibility with Spaces checkpoint = torch.load(model_path, map_location='cpu', weights_only=False) model.load_state_dict(checkpoint["model"]) # Set to eval mode first (on CPU) model.eval() # Set to eval mode before moving to GPU # Only move to GPU and apply optimizations if CUDA is available if torch.cuda.is_available(): current_device = torch.device("cuda") model = model.to(current_device, dtype=dtype) # Optimize only if CUDA is available if torch.cuda.is_available(): logger.info("🔥 Model using full precision (FP32)") # Compile model for performance - using faster compilation mode if ENABLE_TORCH_COMPILE: try: # Use default mode which is much faster than max-autotune model = torch.compile(model, mode="default", fullgraph=False) logger.info("⚡ Model compiled with torch.compile (default mode)") except Exception as e: logger.warning(f"Could not compile model: {e}") logger.info("Continuing without torch.compile - model will still work but may be slower") else: logger.info("Torch.compile disabled - model will work but may be slower") # Warm up model with dummy input - optimized for faster startup if ENABLE_MODEL_WARMUP: try: logger.info("Starting model warm-up...") dummy_input = torch.randn(1, 3, 64, 224, 224, device=device, dtype=dtype) with torch.no_grad(): # Use a timeout to avoid hanging import signal def timeout_handler(signum, frame): raise TimeoutError("Model warm-up timed out") # Set a 30-second timeout for warm-up signal.signal(signal.SIGALRM, timeout_handler) signal.alarm(30) try: with torch.amp.autocast('cuda', enabled=True): _ = model(dummy_input) logger.info("Classification model warm-up completed successfully") finally: signal.alarm(0) # Cancel the alarm except (TimeoutError, Exception) as e: logger.warning(f"Classification model warm-up failed or timed out: {e}") logger.info("Skipping warm-up - model will still work for inference") else: logger.info("Model warm-up disabled - skipping warm-up") logger.info(f"✅ Classification model loaded and optimized for H200") else: logger.info("✅ Classification model loaded (CPU mode)") return model, args except Exception as e: logger.error(f"Error loading classification model: {e}") raise def load_segmentation_model(): """Load segmentation model with H200 GPU optimizations""" try: # Download model weights model_path = hf_hub_download( repo_id=MODEL_REPO, filename="checkpoint/segmentation.pth.tar", cache_dir=cache_dir ) # Create model exactly like demo args = create_args_for_segmentation() convnext = convnextv2_tiny(in_chans=3) model = UperNet( encoder=convnext, in_channels=[96, 192, 384, 768], out_channels=args.out_channels, ) # Load weights - use CPU first for compatibility with Spaces checkpoint = torch.load(model_path, map_location='cpu', weights_only=False) model.load_state_dict(checkpoint["model"]) # Set to eval mode first (on CPU) model.eval() # Set to eval mode before moving to GPU # Only move to GPU and apply optimizations if CUDA is available if torch.cuda.is_available(): current_device = torch.device("cuda") model = model.to(current_device, dtype=dtype) # Optimize only if CUDA is available if torch.cuda.is_available(): logger.info("🔥 Segmentation model using full precision (FP32)") # Compile model for performance - using faster compilation mode if ENABLE_TORCH_COMPILE: try: # Use default mode which is much faster than max-autotune model = torch.compile(model, mode="default", fullgraph=False) logger.info("⚡ Segmentation model compiled with torch.compile (default mode)") except Exception as e: logger.warning(f"Could not compile segmentation model: {e}") logger.info("Continuing without torch.compile - model will still work but may be slower") else: logger.info("Torch.compile disabled for segmentation - model will work but may be slower") # Warm up model with dummy input - optimized for faster startup if ENABLE_MODEL_WARMUP: try: logger.info("Starting segmentation model warm-up...") dummy_input = torch.randn(1, 3, 64, 224, 224, device=device, dtype=dtype) with torch.no_grad(): # Use a timeout to avoid hanging import signal def timeout_handler(signum, frame): raise TimeoutError("Model warm-up timed out") # Set a 30-second timeout for warm-up signal.signal(signal.SIGALRM, timeout_handler) signal.alarm(30) try: with torch.amp.autocast('cuda', enabled=True): _ = model(dummy_input) logger.info("Segmentation model warm-up completed successfully") finally: signal.alarm(0) # Cancel the alarm except (TimeoutError, Exception) as e: logger.warning(f"Segmentation model warm-up failed or timed out: {e}") logger.info("Skipping warm-up - model will still work for inference") else: logger.info("Segmentation model warm-up disabled - skipping warm-up") logger.info(f"✅ Segmentation model loaded and optimized for H200") else: logger.info("✅ Segmentation model loaded (CPU mode)") return model, args except Exception as e: logger.error(f"Error loading segmentation model: {e}") raise def load_anatomy_segmentation_model(): """Load anatomy segmentation model with H200 GPU optimizations""" try: # Download model weights model_path = hf_hub_download( repo_id=MODEL_REPO, filename="checkpoint/anatomy_segmentation.pth.tar", cache_dir=cache_dir ) # Create model exactly like demo args = create_args_for_anatomy_segmentation() convnext = convnextv2_tiny(in_chans=3) model = UperNet( encoder=convnext, in_channels=[96, 192, 384, 768], out_channels=args.out_channels, # 9 channels for 8 classes + background ) # Load weights - use CPU first for compatibility with Spaces checkpoint = torch.load(model_path, map_location='cpu', weights_only=False) model.load_state_dict(checkpoint["model"]) # Set to eval mode first (on CPU) model.eval() # Set to eval mode before moving to GPU # Only move to GPU and apply optimizations if CUDA is available if torch.cuda.is_available(): current_device = torch.device("cuda") model = model.to(current_device, dtype=dtype) # Optimize only if CUDA is available if torch.cuda.is_available(): logger.info("🔥 Anatomy segmentation model using full precision (FP32)") # Compile model for performance - using faster compilation mode if ENABLE_TORCH_COMPILE: try: # Use default mode which is much faster than max-autotune model = torch.compile(model, mode="default", fullgraph=False) logger.info("⚡ Anatomy segmentation model compiled with torch.compile (default mode)") except Exception as e: logger.warning(f"Could not compile anatomy segmentation model: {e}") logger.info("Continuing without torch.compile - model will still work but may be slower") else: logger.info("Torch.compile disabled for anatomy segmentation - model will work but may be slower") # Warm up model with dummy input - optimized for faster startup if ENABLE_MODEL_WARMUP: try: logger.info("Starting anatomy segmentation model warm-up...") dummy_input = torch.randn(1, 3, 64, 224, 224, device=device, dtype=dtype) with torch.no_grad(): # Use a timeout to avoid hanging import signal def timeout_handler(signum, frame): raise TimeoutError("Model warm-up timed out") # Set a 30-second timeout for warm-up signal.signal(signal.SIGALRM, timeout_handler) signal.alarm(30) try: with torch.amp.autocast('cuda', enabled=True): _ = model(dummy_input) logger.info("Anatomy segmentation model warm-up completed successfully") finally: signal.alarm(0) # Cancel the alarm except (TimeoutError, Exception) as e: logger.warning(f"Anatomy segmentation model warm-up failed or timed out: {e}") logger.info("Skipping warm-up - model will still work for inference") else: logger.info("Anatomy segmentation model warm-up disabled - skipping warm-up") logger.info(f"✅ Anatomy segmentation model loaded and optimized for H200") else: logger.info("✅ Anatomy segmentation model loaded (CPU mode)") return model, args except Exception as e: logger.error(f"Error loading anatomy segmentation model: {e}") raise def load_sample_data(): """Load sample data from Hugging Face Dataset Hub""" try: # Download dataset dataset_path = snapshot_download( repo_id=DATASET_REPO, repo_type="dataset", cache_dir=cache_dir ) dataset_path = Path(dataset_path) # Load sample patient data - same as demo expects sample_patients = {} # Try to find CSV files for demo data risk_csv = dataset_path / "demo" / "data" / "risk" / "test.csv" ucl_csv = dataset_path / "demo" / "data" / "UCL" / "test.csv" anatomy_csv = dataset_path / "demo" / "data" / "anatomy" / "test.csv" if risk_csv.exists(): risk_df = pd.read_csv(risk_csv) sample_patients['risk'] = risk_df logger.info(f"Found {len(risk_df)} risk assessment samples") if ucl_csv.exists(): ucl_df = pd.read_csv(ucl_csv) sample_patients['UCL'] = ucl_df logger.info(f"Found {len(ucl_df)} UCL segmentation samples") if anatomy_csv.exists(): anatomy_df = pd.read_csv(anatomy_csv) sample_patients['anatomy'] = anatomy_df logger.info(f"Found {len(anatomy_df)} anatomy segmentation samples") return sample_patients, dataset_path except Exception as e: logger.error(f"Error loading sample data: {e}") return {}, None def create_single_sample_dataloader(patient_idx: int, task: str): """Create a DataLoader for a single patient - matching demo approach""" try: # Load sample data to get the correct root path and CSV data sample_patients, dataset_root = load_sample_data() if task == 'classification': args = create_args_for_classification() args.root = str(dataset_root) # Use the actual dataset root path if 'risk' not in sample_patients: raise ValueError("No risk assessment data found") # Import the dataset class directly from dataset.dataset_cls import RiskSet, get_transforms # Get transforms train_transforms, val_transforms, test_transforms = get_transforms(args) # Create temporary CSV file path import tempfile with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as f: sample_patients['risk'].to_csv(f.name, index=False) temp_csv_path = f.name # Create dataset using the temporary CSV file dataset = RiskSet(args, temp_csv_path, 'test', test_transforms) # Clean up temporary file os.unlink(temp_csv_path) elif task == 'anatomy_segmentation': args = create_args_for_anatomy_segmentation() args.root = str(dataset_root) # Use the actual dataset root path if 'anatomy' not in sample_patients: raise ValueError("No anatomy segmentation data found") # Import the dataset class directly from dataset.dataset_seg import BpAnatomySet, get_transforms # Get transforms train_transforms, val_transforms, test_transforms = get_transforms(args) # Create temporary CSV file path import tempfile with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as f: sample_patients['anatomy'].to_csv(f.name, index=False) temp_csv_path = f.name # Create dataset using the temporary CSV file dataset = BpAnatomySet(args, temp_csv_path, 'test', test_transforms) # Clean up temporary file os.unlink(temp_csv_path) else: args = create_args_for_segmentation() args.root = str(dataset_root) # Use the actual dataset root path if 'UCL' not in sample_patients: raise ValueError("No UCL segmentation data found") # Import the dataset class directly from dataset.dataset_seg import UCLSet, get_transforms # Get transforms train_transforms, val_transforms, test_transforms = get_transforms(args) # Create temporary CSV file path import tempfile with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as f: sample_patients['UCL'].to_csv(f.name, index=False) temp_csv_path = f.name # Create dataset using the temporary CSV file dataset = UCLSet(args, temp_csv_path, 'test', test_transforms) # Clean up temporary file os.unlink(temp_csv_path) # Check patient index if patient_idx >= len(dataset): raise IndexError(f"Patient index {patient_idx} out of range for dataset size {len(dataset)}") # Create single sample dataloader with H200 optimizations from torch.utils.data import DataLoader, Subset single_sample_dataset = Subset(dataset, [patient_idx]) single_sample_loader = DataLoader( single_sample_dataset, batch_size=1, shuffle=False, pin_memory=torch.cuda.is_available(), # Only use pinned memory if CUDA is available num_workers=0, # No multiprocessing for single sample drop_last=False, persistent_workers=False, # For single sample prefetch_factor=None, # Must be None when num_workers=0 ) return single_sample_loader, args, dataset except Exception as e: logger.error(f"Error creating single sample dataloader: {e}") raise # Removed get_raw_patient_data function as it's no longer needed # All data is now processed through the unified preprocessing pipeline def visualize_multimodal_results(preprocessed_data: Dict, prediction: torch.Tensor, task: str) -> plt.Figure: """Visualize all three modalities with prediction results using preprocessed data - H200 optimized""" try: # Keep tensors on GPU as long as possible for H200 optimization t2w_gpu = preprocessed_data['t2w_preprocessed'] dwi_gpu = preprocessed_data['dwi_preprocessed'] adc_gpu = preprocessed_data['adc_preprocessed'] # Calculate middle slice on GPU middle_slice = t2w_gpu.shape[0] // 2 # Extract slices on GPU first, then convert to CPU only when needed with torch.no_grad(): t2w_slice = t2w_gpu[middle_slice].cpu().numpy() dwi_slice = dwi_gpu[middle_slice].cpu().numpy() adc_slice = adc_gpu[middle_slice].cpu().numpy() if task == 'classification': # Show 3 modalities + prediction fig, axes = plt.subplots(1, 4, figsize=(20, 5)) # Display three modalities axes[0].imshow(t2w_slice, cmap='gray') axes[0].set_title('T2W', fontsize=14, fontweight='bold') axes[0].axis('off') axes[1].imshow(dwi_slice, cmap='gray') axes[1].set_title('DWI (High-b)', fontsize=14, fontweight='bold') axes[1].axis('off') axes[2].imshow(adc_slice, cmap='gray') axes[2].set_title('ADC', fontsize=14, fontweight='bold') axes[2].axis('off') # Show prediction results - optimized GPU to CPU transfer if prediction is not None: class_names = ["PI-RADS 2", "PI-RADS 3", "PI-RADS 4", "PI-RADS 5"] # Perform softmax on GPU, then transfer to CPU with torch.no_grad(): probs = torch.softmax(prediction, dim=-1)[0].cpu().numpy() axes[3].bar(class_names, probs, color=['green', 'yellow', 'orange', 'red']) axes[3].set_title('Classification Results', fontsize=14, fontweight='bold') axes[3].set_ylabel('Probability') axes[3].set_ylim(0, 1) axes[3].tick_params(axis='x', rotation=45) # Add value labels on bars for i, prob in enumerate(probs): axes[3].text(i, prob + 0.01, f'{prob:.3f}', ha='center', va='bottom') else: # Segmentation: show 3 modalities + segmentation fig, axes = plt.subplots(1, 5, figsize=(25, 5)) # Display three modalities axes[0].imshow(t2w_slice, cmap='gray') axes[0].set_title('T2W', fontsize=14, fontweight='bold') axes[0].axis('off') axes[1].imshow(dwi_slice, cmap='gray') axes[1].set_title('DWI', fontsize=14, fontweight='bold') axes[1].axis('off') axes[2].imshow(adc_slice, cmap='gray') axes[2].set_title('ADC', fontsize=14, fontweight='bold') axes[2].axis('off') # Show segmentation if available - optimized GPU processing if prediction is not None: # Perform sigmoid on GPU, then transfer specific slice to CPU with torch.no_grad(): seg_prob_gpu = torch.sigmoid(prediction)[0, 0, middle_slice] seg_slice = seg_prob_gpu.cpu().numpy() seg_binary = (seg_slice > 0.5).astype(int) axes[3].imshow(seg_slice, cmap='jet') axes[3].set_title('Segmentation (Probability)', fontsize=14, fontweight='bold') axes[3].axis('off') # Overlay on T2W axes[4].imshow(t2w_slice, cmap='gray') axes[4].imshow(seg_binary, cmap='jet', alpha=0.5) axes[4].set_title('Overlay on T2W', fontsize=14, fontweight='bold') axes[4].axis('off') # Show ground truth if available - optimized GPU processing if 'ground_truth_preprocessed' in preprocessed_data and preprocessed_data['ground_truth_preprocessed'] is not None: with torch.no_grad(): gt_slice = preprocessed_data['ground_truth_preprocessed'][middle_slice].cpu().numpy() # If we have both prediction and GT, show GT in a different position if prediction is not None: # Add another subplot for ground truth fig, axes = plt.subplots(1, 6, figsize=(30, 5)) # Re-draw modalities axes[0].imshow(t2w_slice, cmap='gray') axes[0].set_title('T2W') axes[0].axis('off') axes[1].imshow(dwi_slice, cmap='gray') axes[1].set_title('DWI') axes[1].axis('off') axes[2].imshow(adc_slice, cmap='gray') axes[2].set_title('ADC') axes[2].axis('off') # Ground truth axes[3].imshow(gt_slice, cmap='jet') axes[3].set_title('Ground Truth', fontsize=14, fontweight='bold') axes[3].axis('off') # Prediction - optimized GPU processing with torch.no_grad(): seg_prob_gpu = torch.sigmoid(prediction)[0, 0, middle_slice] seg_slice = seg_prob_gpu.cpu().numpy() seg_binary = (seg_slice > 0.5).astype(int) axes[4].imshow(seg_slice, cmap='jet') axes[4].set_title('Segmentation (Probability)', fontsize=14, fontweight='bold') axes[4].axis('off') # Overlay axes[5].imshow(t2w_slice, cmap='gray') axes[5].imshow(seg_binary, cmap='jet', alpha=0.5) axes[5].set_title('Overlay on T2W', fontsize=14, fontweight='bold') axes[5].axis('off') else: axes[3].imshow(gt_slice, cmap='jet') axes[3].set_title('Ground Truth', fontsize=14, fontweight='bold') axes[3].axis('off') # Adjust layout with padding for titles - fix title clipping issue plt.tight_layout(pad=3.0) plt.subplots_adjust(top=0.85) # Leave space for titles at the top return fig except Exception as e: logger.error(f"Error creating visualization: {e}") fig, ax = plt.subplots(figsize=(8, 6)) ax.text(0.5, 0.5, f"Visualization error: {str(e)}", ha='center', va='center', transform=ax.transAxes, fontsize=12) ax.axis('off') return fig # Global variables to store current data for slice browser and performance optimization current_preprocessed_data = None current_prediction = None cached_slice_data = None # Cache pre-processed slice data for smooth navigation def create_2d_slice_browser_optimized(preprocessed_data: Dict, prediction: torch.Tensor, slice_idx: int = None) -> plt.Figure: """Create optimized 2D slice browser with consistent scaling and smooth performance - H200 optimized""" try: # Keep tensors on GPU for H200 optimization t2w_gpu = preprocessed_data['t2w_preprocessed'] dwi_gpu = preprocessed_data['dwi_preprocessed'] adc_gpu = preprocessed_data['adc_preprocessed'] # Use middle slice if none specified - calculated on GPU if slice_idx is None: slice_idx = t2w_gpu.shape[0] // 2 # Ensure slice index is valid slice_idx = max(0, min(slice_idx, t2w_gpu.shape[0] - 1)) # Extract slices on GPU, then transfer to CPU with torch.no_grad(): t2w_slice = t2w_gpu[slice_idx].cpu().numpy() dwi_slice = dwi_gpu[slice_idx].cpu().numpy() adc_slice = adc_gpu[slice_idx].cpu().numpy() # Define consistent image extent and aspect ratio for all plots # This ensures all T2W images display with the same scale and proportions height, width = t2w_slice.shape extent = [0, width, height, 0] # [left, right, bottom, top] # Process prediction - optimized GPU processing pred_prob = None pred_binary = None if prediction is not None: with torch.no_grad(): pred_prob_gpu = torch.sigmoid(prediction)[0, 0, slice_idx] pred_prob = pred_prob_gpu.cpu().numpy() pred_binary = (pred_prob > 0.5).astype(np.uint8) # Process ground truth - optimized GPU processing gt_slice = None if 'ground_truth_preprocessed' in preprocessed_data and preprocessed_data['ground_truth_preprocessed'] is not None: with torch.no_grad(): gt_slice = preprocessed_data['ground_truth_preprocessed'][slice_idx].cpu().numpy().astype(np.uint8) # Create figure with 6 subplots (2 rows, 3 columns) - adjusted for better coordination fig, axes = plt.subplots(2, 3, figsize=(15, 8)) axes = axes.flatten() # Consistent imshow parameters for all T2W images imshow_params = { 'cmap': 'gray', 'aspect': 'equal', 'extent': extent, 'interpolation': 'nearest' } # 1. T2 weighted image axes[0].imshow(t2w_slice, **imshow_params) axes[0].set_title(f'T2W (Slice {slice_idx})', fontsize=14, fontweight='bold') axes[0].axis('off') # 2. ADC image axes[1].imshow(adc_slice, **imshow_params) axes[1].set_title(f'ADC (Slice {slice_idx})', fontsize=14, fontweight='bold') axes[1].axis('off') # 3. DWI image axes[2].imshow(dwi_slice, **imshow_params) axes[2].set_title(f'DWI (Slice {slice_idx})', fontsize=14, fontweight='bold') axes[2].axis('off') # 4. T2 + Ground Truth overlay axes[3].imshow(t2w_slice, **imshow_params) if gt_slice is not None: # Create colored mask for ground truth with consistent extent gt_colored = np.zeros((*gt_slice.shape, 4)) gt_colored[gt_slice > 0] = [1, 0, 0, 0.6] # Red with 60% opacity axes[3].imshow(gt_colored, extent=extent, aspect='equal', interpolation='nearest') axes[3].set_title('T2W + Ground Truth', fontsize=14, fontweight='bold') axes[3].axis('off') # 5. T2 + Prediction overlay axes[4].imshow(t2w_slice, **imshow_params) if pred_binary is not None: # Create colored mask for prediction with consistent extent pred_colored = np.zeros((*pred_binary.shape, 4)) pred_colored[pred_binary > 0] = [0, 1, 0, 0.6] # Green with 60% opacity axes[4].imshow(pred_colored, extent=extent, aspect='equal', interpolation='nearest') axes[4].set_title('T2W + Prediction', fontsize=14, fontweight='bold') axes[4].axis('off') # 6. T2 + Heatmap overlay axes[5].imshow(t2w_slice, **imshow_params) if pred_prob is not None: # Create heatmap overlay with consistent extent heatmap_colored = plt.cm.jet(pred_prob) heatmap_colored[..., 3] = pred_prob * 0.7 # Variable opacity based on probability axes[5].imshow(heatmap_colored, extent=extent, aspect='equal', interpolation='nearest') axes[5].set_title('T2W + Probability Heatmap', fontsize=14, fontweight='bold') axes[5].axis('off') # Adjust layout with padding for titles - fix title clipping issue plt.tight_layout(pad=3.0) plt.subplots_adjust(top=0.87) # Leave space for titles at the top return fig except Exception as e: logger.error(f"Error creating 2D slice browser: {e}") fig, ax = plt.subplots(figsize=(8, 6)) ax.text(0.5, 0.5, f"Slice browser error: {str(e)}", ha='center', va='center', transform=ax.transAxes, fontsize=12) ax.axis('off') return fig def precompute_slice_cache(preprocessed_data: Dict, prediction: torch.Tensor): """Pre-compute and cache slice data for smooth slider navigation - H200 optimized""" global cached_slice_data try: logger.info("🚀 Pre-computing slice cache with H200 GPU optimization...") # Keep tensors on GPU for batch processing t2w_gpu = preprocessed_data['t2w_preprocessed'] dwi_gpu = preprocessed_data['dwi_preprocessed'] adc_gpu = preprocessed_data['adc_preprocessed'] num_slices = t2w_gpu.shape[0] # Pre-compute all slice data with vectorized operations cache = { 't2w_slices': [], 'dwi_slices': [], 'adc_slices': [], 'pred_prob_slices': [], 'pred_binary_slices': [], 'gt_slices': [], 'extent': [0, t2w_gpu.shape[2], t2w_gpu.shape[1], 0], # [left, right, bottom, top] 'num_slices': num_slices # Add this for validation } # Vectorized processing on GPU - much faster than loop with torch.no_grad(): # Convert all slices to CPU in one batch operation t2w_cpu = t2w_gpu.cpu().numpy() dwi_cpu = dwi_gpu.cpu().numpy() adc_cpu = adc_gpu.cpu().numpy() # Process predictions in batch if prediction is not None: pred_prob_all = torch.sigmoid(prediction)[0, 0].cpu().numpy() pred_binary_all = (pred_prob_all > 0.5).astype(np.uint8) else: pred_prob_all = None pred_binary_all = None # Process ground truth in batch if 'ground_truth_preprocessed' in preprocessed_data and preprocessed_data['ground_truth_preprocessed'] is not None: gt_all = preprocessed_data['ground_truth_preprocessed'].cpu().numpy().astype(np.uint8) else: gt_all = None # Fill cache with pre-computed arrays (much faster than individual slice processing) for i in range(num_slices): cache['t2w_slices'].append(t2w_cpu[i]) cache['dwi_slices'].append(dwi_cpu[i]) cache['adc_slices'].append(adc_cpu[i]) if pred_prob_all is not None: cache['pred_prob_slices'].append(pred_prob_all[i]) cache['pred_binary_slices'].append(pred_binary_all[i]) else: cache['pred_prob_slices'].append(None) cache['pred_binary_slices'].append(None) if gt_all is not None: cache['gt_slices'].append(gt_all[i]) else: cache['gt_slices'].append(None) cached_slice_data = cache logger.info(f"Successfully cached {num_slices} slices for smooth navigation") # Return cache for additional state management return cache except Exception as e: logger.error(f"Error pre-computing slice cache: {e}") cached_slice_data = None return None # Backward compatibility - keep the old function name but use optimized version def create_2d_slice_browser(raw_data: Dict, prediction: torch.Tensor, slice_idx: int = None) -> plt.Figure: """Create advanced 2D slice browser with 6 synchronized windows (optimized version)""" return create_2d_slice_browser_optimized(raw_data, prediction, slice_idx) def create_3d_volume_rendering(preprocessed_data: Dict, prediction: torch.Tensor) -> go.Figure: """Create interactive 3D volume rendering of prostate and lesions - H200 optimized""" try: # Vectorized processing on GPU before CPU transfer with torch.no_grad(): # Get T2W data for reference (from preprocessed data) t2w_gpu = preprocessed_data['t2w_preprocessed'] # Process prediction for lesions in batch on GPU pred_binary_gpu = torch.zeros_like(t2w_gpu) if prediction is not None: pred_prob_gpu = torch.sigmoid(prediction)[0, 0] pred_binary_gpu = (pred_prob_gpu > 0.5).to(torch.uint8) # Transfer to CPU in one batch operation t2w = t2w_gpu.cpu().numpy() pred_binary = pred_binary_gpu.cpu().numpy() # Use preprocessed prostate mask directly (same preprocessing as lesion mask) prostate_mask = None if 'prostate_mask_preprocessed' in preprocessed_data and preprocessed_data['prostate_mask_preprocessed'] is not None: with torch.no_grad(): prostate_mask = preprocessed_data['prostate_mask_preprocessed'].cpu().numpy().astype(np.uint8) non_zero_voxels = np.count_nonzero(prostate_mask) logger.info(f"Using preprocessed prostate mask: shape={prostate_mask.shape}, non-zero voxels={non_zero_voxels}") else: logger.info("No preprocessed prostate mask available - will only show lesions") # Downsample for performance (every 2nd voxel) t2w_ds = t2w[::2, ::2, ::2] pred_binary_ds = pred_binary[::2, ::2, ::2] prostate_mask_ds = prostate_mask[::2, ::2, ::2] if prostate_mask is not None else None logger.info(f"Downsampled data shapes: T2W={t2w_ds.shape}, prediction={pred_binary_ds.shape}") if prostate_mask_ds is not None: logger.info(f"Prostate mask shape: {prostate_mask_ds.shape}, non-zero voxels: {np.count_nonzero(prostate_mask_ds)}") logger.info(f"Lesion non-zero voxels: {np.count_nonzero(pred_binary_ds)}") # Create 3D meshes fig = go.Figure() # 1. Prostate surface (semi-transparent) - only if preprocessed mask available if prostate_mask_ds is not None and np.count_nonzero(prostate_mask_ds) > 0: try: logger.info("Creating prostate 3D mesh with marching cubes...") verts, faces, _, _ = measure.marching_cubes(prostate_mask_ds, level=0.5) fig.add_trace(go.Mesh3d( x=verts[:, 0], y=verts[:, 1], z=verts[:, 2], i=faces[:, 0], j=faces[:, 1], k=faces[:, 2], name='Prostate Gland', color='lightblue', opacity=0.3, lighting=dict(ambient=0.5), showlegend=True, hovertemplate="Prostate Gland
Vertex: (%{x}, %{y}, %{z})" )) logger.info("✅ Successfully created prostate 3D mesh") except Exception as e: logger.warning(f"Could not create prostate mesh: {e}") # 2. Lesion surface (opaque, red) if np.count_nonzero(pred_binary_ds) > 0: try: logger.info("Creating lesion 3D mesh...") verts_lesion, faces_lesion, _, _ = measure.marching_cubes(pred_binary_ds, level=0.5) fig.add_trace(go.Mesh3d( x=verts_lesion[:, 0], y=verts_lesion[:, 1], z=verts_lesion[:, 2], i=faces_lesion[:, 0], j=faces_lesion[:, 1], k=faces_lesion[:, 2], name='Predicted Lesion', color='red', opacity=0.8, lighting=dict(ambient=0.7), showlegend=True, hovertemplate="Predicted Lesion
Vertex: (%{x}, %{y}, %{z})" )) logger.info("✅ Successfully created lesion 3D mesh") except Exception as e: logger.warning(f"Could not create lesion mesh: {e}") else: logger.info("No lesion predictions found") # 3. Ground truth lesions (if available) - optimized GPU processing if 'ground_truth_preprocessed' in preprocessed_data and preprocessed_data['ground_truth_preprocessed'] is not None: with torch.no_grad(): gt = preprocessed_data['ground_truth_preprocessed'].cpu().numpy().astype(np.uint8) gt_ds = gt[::2, ::2, ::2] if np.count_nonzero(gt_ds) > 0: try: logger.info("Creating ground truth lesion 3D mesh...") verts_gt, faces_gt, _, _ = measure.marching_cubes(gt_ds, level=0.5) fig.add_trace(go.Mesh3d( x=verts_gt[:, 0], y=verts_gt[:, 1], z=verts_gt[:, 2], i=faces_gt[:, 0], j=faces_gt[:, 1], k=faces_gt[:, 2], name='Ground Truth Lesion', color='yellow', opacity=0.7, lighting=dict(ambient=0.7), showlegend=True, hovertemplate="Ground Truth Lesion
Vertex: (%{x}, %{y}, %{z})" )) logger.info("✅ Successfully created ground truth 3D mesh") except Exception as e: logger.warning(f"Could not create ground truth mesh: {e}") # Configure layout - adjusted for better coordination with 2D slice browser fig.update_layout( title=dict( text="3D Prostate and Lesion Rendering", x=0.5, font=dict(size=18) ), scene=dict( xaxis_title="X", yaxis_title="Y", zaxis_title="Z", camera=dict( eye=dict(x=1.2, y=1.2, z=1.2) ), aspectmode='cube' ), width=700, height=400, margin=dict(l=0, r=0, b=0, t=40) ) logger.info(f"3D rendering completed with {len(fig.data)} traces") return fig except Exception as e: logger.error(f"Error creating 3D rendering: {e}") # Return empty figure with error message fig = go.Figure() fig.add_annotation( text=f"3D rendering error: {str(e)}", xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False, font=dict(size=16) ) return fig def run_classification_inference(patient_id: str, progress=gr.Progress()): """Load classification inference results from cache""" try: progress(0, desc="📂 Loading classification results from cache...") # Get data from precomputed results cached_result = PRECOMPUTED_RESULTS.get('classification', {}).get(patient_id) if cached_result is None: progress(1, desc="❌ Cache result not found!") fig, ax = plt.subplots(figsize=(8, 6)) ax.text(0.5, 0.5, f"Classification result cache not found for patient {patient_id}", ha='center', va='center', transform=ax.transAxes, fontsize=12) ax.axis('off') return fig, f"Error: Cache result not found for patient {patient_id}" progress(0.3, desc="🔄 Deserializing prediction results...") # Deserialize data prediction = serializable_to_tensor(cached_result['prediction']) preprocessed_data = serializable_to_tensor(cached_result['preprocessed_data']) progress(0.6, desc="🎨 Generating visualization...") fig = visualize_multimodal_results(preprocessed_data, prediction, 'classification') progress(0.8, desc="📊 Computing statistics...") # Generate result text with torch.no_grad(): probs = torch.softmax(prediction, dim=-1)[0].cpu().numpy() pred_class = np.argmax(probs) confidence = probs[pred_class] class_names = ["PI-RADS 2", "PI-RADS 3", "PI-RADS 4", "PI-RADS 5"] result_text = f"Patient ID: {patient_id}\n" result_text += f"Predicted Class: {class_names[pred_class]}\n" result_text += f"Confidence: {confidence:.3f}\n" # Add ground truth (use predicted class if actual ground truth not available) if preprocessed_data.get('ground_truth_preprocessed') is not None: try: gt_class = int(preprocessed_data['ground_truth_preprocessed']) result_text += f"Ground Truth: {class_names[gt_class]}\n" except: # Fallback to predicted class if ground truth processing fails result_text += f"Ground Truth: {class_names[pred_class]}\n" else: # Use predicted class as ground truth for demo purposes result_text += f"Ground Truth: {class_names[pred_class]}\n" result_text += "\nAll Class Probabilities:\n" for i, prob in enumerate(probs): result_text += f"{class_names[i]}: {prob:.3f}\n" # Add raw logits for debugging with torch.no_grad(): result_text += f"\nRaw logits: {prediction[0].cpu().numpy()}" progress(1, desc="✅ Completed!") return fig, result_text except Exception as e: logger.error(f"Error loading classification cache result: {e}") fig, ax = plt.subplots(figsize=(8, 6)) ax.text(0.5, 0.5, f"Processing error: {str(e)}", ha='center', va='center', transform=ax.transAxes, fontsize=12) ax.axis('off') return fig, f"Error: {str(e)}" def get_preprocessed_patient_data(patient_idx: int, task: str, device=None): """ Get preprocessed patient data with H200 GPU optimizations This ensures spatial consistency between displayed images and predictions """ try: logger.info(f"🔍 Loading preprocessed patient data for patient {patient_idx}, task {task}") # Set device if not provided if device is None: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Create dataloader to get the exact preprocessed data that the model uses data_loader, args, dataset = create_single_sample_dataloader(patient_idx, task) # Extract the preprocessed data from the dataloader for idx, data in enumerate(data_loader): # Handle different return formats based on task if task == 'classification': # Classification returns: (img, gt, pid) img, gt, pid = data gland = None elif task == 'anatomy_segmentation': # Anatomy segmentation returns: (img, gt, pid) - no gland mask img, gt, pid = data gland = None else: # Segmentation returns: (img, gt, pid, gland) if len(data) == 4: img, gt, pid, gland = data # Remove batch dimension from gland if it exists if gland is not None and len(gland) > 0: gland = gland[0] # Remove batch dimension else: # Fallback for old format img, gt, pid = data gland = None # Transfer to GPU with optimized settings (only if CUDA is available) if torch.cuda.is_available(): img = img.to(device, dtype=dtype, non_blocking=True) if gt is not None: gt = gt.to(device, non_blocking=True) if gland is not None: gland = gland.to(device, non_blocking=True) else: img = img.to(device) if gt is not None: gt = gt.to(device) if gland is not None: gland = gland.to(device) # img is the preprocessed multi-modal image [B, C, D, H, W] # gt is the preprocessed ground truth [B, C, D, H, W] or [B, D, H, W] # gland is the preprocessed prostate mask [B, C, D, H, W] or [B, D, H, W] or None preprocessed_data = { 'preprocessed_image': img[0], # Remove batch dimension [C, D, H, W] 'preprocessed_gt': gt[0] if gt is not None else None, # Remove batch dimension 'patient_id': pid[0] if isinstance(pid, (list, tuple)) else pid, 'spatial_shape': img.shape[2:], # [D, H, W] 'args': args, 'task': task # Add task information } # Extract individual modalities from the preprocessed image if task == 'anatomy_segmentation': # For anatomy segmentation, only T2W is used (the other channels are zeros) preprocessed_data['t2w_preprocessed'] = img[0, 0] # [D, H, W] # Set other modalities to None as they're not used for anatomy preprocessed_data['dwi_preprocessed'] = None preprocessed_data['adc_preprocessed'] = None else: # Assuming the model input has 3 channels: [T2W, DWI, ADC] if img.shape[1] >= 3: # Check if we have at least 3 channels preprocessed_data['t2w_preprocessed'] = img[0, 0] # [D, H, W] preprocessed_data['dwi_preprocessed'] = img[0, 1] # [D, H, W] preprocessed_data['adc_preprocessed'] = img[0, 2] # [D, H, W] else: logger.warning(f"Unexpected number of input channels: {img.shape[1]}") # Fallback: use the same channel for all modalities preprocessed_data['t2w_preprocessed'] = img[0, 0] preprocessed_data['dwi_preprocessed'] = img[0, 0] preprocessed_data['adc_preprocessed'] = img[0, 0] # Convert ground truth to proper format if gt is not None: if len(gt.shape) == 4: # [B, D, H, W] preprocessed_data['ground_truth_preprocessed'] = gt[0] # [D, H, W] elif len(gt.shape) == 5: # [B, C, D, H, W] preprocessed_data['ground_truth_preprocessed'] = gt[0, 0] # [D, H, W] else: logger.warning(f"Unexpected ground truth shape: {gt.shape}") preprocessed_data['ground_truth_preprocessed'] = None else: preprocessed_data['ground_truth_preprocessed'] = None # Handle prostate gland mask - now from preprocessed data if gland is not None: # Convert gland mask to proper format if len(gland.shape) == 3: # [D, H, W] preprocessed_gland = gland elif len(gland.shape) == 4: # [C, D, H, W] preprocessed_gland = gland[0] # [D, H, W] else: logger.warning(f"Unexpected gland mask shape: {gland.shape}") preprocessed_gland = None if preprocessed_gland is not None: # Ensure binary values preprocessed_gland = (preprocessed_gland > 0.5).float() non_zero_voxels = torch.count_nonzero(preprocessed_gland) logger.info(f"✅ Preprocessed prostate gland mask: shape={preprocessed_gland.shape}, non-zero voxels={non_zero_voxels}") preprocessed_data['prostate_mask_preprocessed'] = preprocessed_gland if non_zero_voxels > 0 else None else: logger.warning("Could not process gland mask") preprocessed_data['prostate_mask_preprocessed'] = None else: if task != 'anatomy_segmentation': logger.info("No prostate gland mask found in preprocessed data") preprocessed_data['prostate_mask_preprocessed'] = None break # Only process the first (and only) sample logger.info(f"Successfully loaded preprocessed data with spatial shape: {preprocessed_data['spatial_shape']}") return preprocessed_data except Exception as e: logger.error(f"Error loading preprocessed patient data: {e}") raise # Removed manual prostate mask preprocessing functions # These are no longer needed as the gland mask is now processed # together with image and lesion data in the same pipeline def run_segmentation_inference_with_state(patient_id: str, current_state, progress=gr.Progress()): """Load segmentation inference results from cache with state management""" try: progress(0, desc="📂 Loading segmentation results from cache...") # Get data from precomputed results cached_result = PRECOMPUTED_RESULTS.get('segmentation', {}).get(patient_id) if cached_result is None: progress(1, desc="❌ Cache result not found!") fig, ax = plt.subplots(figsize=(8, 6)) ax.text(0.5, 0.5, f"Segmentation result cache not found for patient {patient_id}", ha='center', va='center', transform=ax.transAxes, fontsize=12) ax.axis('off') empty_3d = go.Figure() error_state = {'error': f"Cache result not found for patient {patient_id}", 'timestamp': time.time()} return fig, empty_3d, f"Error: Cache result not found for patient {patient_id}", gr.update(maximum=63, value=32), error_state progress(0.3, desc="🔄 Deserializing prediction results...") # Deserialize data prediction = serializable_to_tensor(cached_result['prediction']) preprocessed_data = serializable_to_tensor(cached_result['preprocessed_data']) cached_slice_data = cached_result.get('cached_slice_data') ground_truth_tensor = serializable_to_tensor(cached_result['ground_truth_tensor']) if cached_result.get('ground_truth_tensor') else None progress(0.6, desc="🎨 Generating visualization...") # Generate slice browser and 3D visualization middle_slice = preprocessed_data['t2w_preprocessed'].shape[0] // 2 slice_browser_fig = create_2d_slice_browser_optimized(preprocessed_data, prediction, middle_slice) volume_3d_fig = create_3d_volume_rendering(preprocessed_data, prediction) progress(0.8, desc="📊 Computing statistics...") # Generate result text with torch.no_grad(): pred_binary = (torch.sigmoid(prediction) > 0.5).int()[0, 0].cpu().numpy() total_voxels = pred_binary.size positive_voxels = np.sum(pred_binary) positive_ratio = positive_voxels / total_voxels result_text = f"Patient ID: {patient_id}\n" # result_text += f"Segmentation Statistics:\n" # result_text += f"Total voxels: {total_voxels}\n" # result_text += f"Lesion voxels: {positive_voxels}\n" # result_text += f"Lesion ratio: {positive_ratio:.3f}\n" # result_text += f"Lesion volume: {positive_voxels} voxels\n" if ground_truth_tensor is not None: with torch.no_grad(): pred_for_dice = (torch.sigmoid(prediction) > 0.5).int() dice = compute_dice(pred_for_dice, ground_truth_tensor) if not torch.isnan(dice): result_text += f"\n=== Performance Metrics ===\n" result_text += f"Dice Score: {dice.item():.3f}\n" else: result_text += f"\n=== Performance Metrics ===\n" result_text += f"Dice Score: NaN (no overlap)\n" # Create state dictionary state_data = { 'patient_id': patient_id, 'preprocessed_data': preprocessed_data, 'prediction': prediction, 'cached_slice_data': cached_slice_data, 'timestamp': time.time() } progress(1, desc="✅ Completed!") return ( slice_browser_fig, volume_3d_fig, result_text, gr.update(maximum=preprocessed_data['t2w_preprocessed'].shape[0]-1, value=middle_slice), state_data ) except Exception as e: logger.error(f"Error loading segmentation cache result: {e}") fig, ax = plt.subplots(figsize=(8, 6)) ax.text(0.5, 0.5, f"Processing error: {str(e)}", ha='center', va='center', transform=ax.transAxes, fontsize=12) ax.axis('off') empty_3d = go.Figure() error_state = {'error': str(e), 'timestamp': time.time()} return fig, empty_3d, f"Error: {str(e)}", gr.update(maximum=63, value=32), error_state def update_slice_browser_with_state(slice_idx: int, state_data): """Update slice browser using data passed explicitly via the state object.""" try: # 1. Check if state object is valid if state_data is None or 'error' in state_data: error_msg = state_data.get('error', "No data loaded. Please run segmentation first.") if state_data else "No data loaded. Please run segmentation first." logger.warning(f"Invalid state data for slice browser: {error_msg}") fig, ax = plt.subplots(figsize=(8, 6)) ax.text(0.5, 0.5, error_msg, ha='center', va='center', transform=ax.transAxes, fontsize=12) ax.axis('off') return fig # 2. Prioritize using the pre-computed cache from the state cache = state_data.get('cached_slice_data') if cache and 'num_slices' in cache: logger.info(f"Using cached slice data from state for slice {slice_idx}") # --- This is the rendering logic from the old update_slice_browser_fast --- extent = cache['extent'] slice_idx = max(0, min(slice_idx, cache['num_slices'] - 1)) t2w_slice = cache['t2w_slices'][slice_idx] dwi_slice = cache['dwi_slices'][slice_idx] adc_slice = cache['adc_slices'][slice_idx] pred_prob = cache['pred_prob_slices'][slice_idx] pred_binary = cache['pred_binary_slices'][slice_idx] gt_slice = cache['gt_slices'][slice_idx] fig, axes = plt.subplots(2, 3, figsize=(15, 8)) axes = axes.flatten() imshow_params = {'cmap': 'gray', 'aspect': 'equal', 'extent': extent, 'interpolation': 'nearest'} axes[0].imshow(t2w_slice, **imshow_params) axes[0].set_title(f'T2W (Slice {slice_idx})', fontsize=14, fontweight='bold') axes[0].axis('off') axes[1].imshow(adc_slice, **imshow_params) axes[1].set_title(f'ADC (Slice {slice_idx})', fontsize=14, fontweight='bold') axes[1].axis('off') axes[2].imshow(dwi_slice, **imshow_params) axes[2].set_title(f'DWI (Slice {slice_idx})', fontsize=14, fontweight='bold') axes[2].axis('off') axes[3].imshow(t2w_slice, **imshow_params) if gt_slice is not None: gt_colored = np.zeros((*gt_slice.shape, 4)) gt_colored[gt_slice > 0] = [1, 0, 0, 0.6] # Red axes[3].imshow(gt_colored, extent=extent, aspect='equal', interpolation='nearest') axes[3].set_title('T2W + Ground Truth', fontsize=14, fontweight='bold') axes[3].axis('off') axes[4].imshow(t2w_slice, **imshow_params) if pred_binary is not None: pred_colored = np.zeros((*pred_binary.shape, 4)) pred_colored[pred_binary > 0] = [0, 1, 0, 0.6] # Green axes[4].imshow(pred_colored, extent=extent, aspect='equal', interpolation='nearest') axes[4].set_title('T2W + Prediction', fontsize=14, fontweight='bold') axes[4].axis('off') axes[5].imshow(t2w_slice, **imshow_params) if pred_prob is not None: heatmap_colored = plt.cm.jet(pred_prob) heatmap_colored[..., 3] = pred_prob * 0.7 axes[5].imshow(heatmap_colored, extent=extent, aspect='equal', interpolation='nearest') axes[5].set_title('T2W + Probability Heatmap', fontsize=14, fontweight='bold') axes[5].axis('off') # Adjust layout with padding for titles - fix title clipping issue plt.tight_layout(pad=3.0) plt.subplots_adjust(top=0.87) # Leave space for titles at the top return fig # 3. Fallback: Re-render from the raw preprocessed data in the state (slower) preprocessed_data = state_data.get('preprocessed_data') prediction = state_data.get('prediction') if preprocessed_data is not None and prediction is not None: logger.warning("Cache not found in state. Using fallback rendering.") return create_2d_slice_browser_optimized(preprocessed_data, prediction, slice_idx) # 4. If no valid data is found in the state at all logger.error("State is invalid. No cache or preprocessed data found.") fig, ax = plt.subplots(figsize=(8, 6)) ax.text(0.5, 0.5, "Invalid data state. Please run segmentation again.", ha='center', va='center', transform=ax.transAxes, fontsize=12) ax.axis('off') return fig except Exception as e: logger.error(f"Error updating slice browser with state: {e}") fig, ax = plt.subplots(figsize=(8, 6)) ax.text(0.5, 0.5, f"Update error: {str(e)}", ha='center', va='center', transform=ax.transAxes, fontsize=12) ax.axis('off') return fig def run_anatomy_segmentation_inference_with_state(patient_id: str, current_state, progress=gr.Progress()): """Load anatomy segmentation inference results from cache with state management""" try: progress(0, desc="📂 Loading anatomy segmentation results from cache...") # Get data from precomputed results cached_result = PRECOMPUTED_RESULTS.get('anatomy_segmentation', {}).get(patient_id) if cached_result is None: progress(1, desc="❌ Cache result not found!") fig, ax = plt.subplots(figsize=(8, 6)) ax.text(0.5, 0.5, f"Anatomy segmentation result cache not found for patient {patient_id}", ha='center', va='center', transform=ax.transAxes, fontsize=12) ax.axis('off') empty_3d = go.Figure() error_state = {'error': f"Cache result not found for patient {patient_id}", 'timestamp': time.time()} return fig, empty_3d, f"Error: Cache result not found for patient {patient_id}", gr.update(maximum=63, value=32), error_state progress(0.3, desc="🔄 Deserializing prediction results...") # Deserialize data prediction = serializable_to_tensor(cached_result['prediction']) preprocessed_data = serializable_to_tensor(cached_result['preprocessed_data']) ground_truth_tensor = serializable_to_tensor(cached_result['ground_truth_tensor']) if cached_result.get('ground_truth_tensor') else None progress(0.6, desc="🎨 Generating visualization...") # Create anatomy segmentation specific visualization slice_browser_fig = create_anatomy_2d_slice_browser(preprocessed_data, prediction) volume_3d_fig = create_anatomy_3d_volume_rendering(preprocessed_data, prediction) progress(0.8, desc="📊 Computing statistics...") # Generate anatomy segmentation result text class_names = [ "Background", "Bladder", "Bone", "Obturator Internus", "Peripheral Zone", "Transition Zone", "Rectum", "Seminal Vesicle", "Neurovascular Bundle" ] with torch.no_grad(): # For multi-class segmentation, get predicted classes pred_classes = torch.argmax(torch.softmax(prediction, dim=1), dim=1)[0].cpu().numpy() # Calculate statistics for each class result_text = f"Patient ID: {patient_id}\n" # result_text += f"Anatomy Segmentation Statistics:\n\n" # total_voxels = pred_classes.size # for class_idx in range(len(class_names)): # class_voxels = np.sum(pred_classes == class_idx) # class_ratio = class_voxels / total_voxels # if class_voxels > 0: # Only display existing classes # result_text += f"{class_names[class_idx]}: {class_voxels} voxels ({class_ratio:.3f})\n" if ground_truth_tensor is not None: with torch.no_grad(): # Calculate Dice score for each class pred_for_dice = torch.argmax(torch.softmax(prediction, dim=1), dim=1, keepdim=True) # Handle different ground truth tensor shapes if len(ground_truth_tensor.shape) == 5: # [B, C, D, H, W] gt_for_dice = ground_truth_tensor.squeeze(1) if ground_truth_tensor.shape[1] == 1 else ground_truth_tensor[:, 0] # [B, D, H, W] elif len(ground_truth_tensor.shape) == 4: # [B, D, H, W] gt_for_dice = ground_truth_tensor else: logger.warning(f"Unexpected ground truth shape: {ground_truth_tensor.shape}") gt_for_dice = ground_truth_tensor.reshape(ground_truth_tensor.shape[0], -1, ground_truth_tensor.shape[-2], ground_truth_tensor.shape[-1]) # Ensure ground truth values are within valid range num_classes = 8 # Based on class_names count minus 1 (remove background) gt_for_dice = torch.clamp(gt_for_dice, 0, num_classes).long() pred_for_dice_squeezed = pred_for_dice.squeeze(1).long() # Create one-hot encoding gt_one_hot = torch.nn.functional.one_hot(gt_for_dice, num_classes=num_classes + 1).float() # [B, D, H, W, num_classes+1] pred_one_hot = torch.nn.functional.one_hot(pred_for_dice_squeezed, num_classes=num_classes + 1).float() # [B, D, H, W, num_classes+1] # Convert to [B, num_classes+1, D, H, W] gt_one_hot = gt_one_hot.permute(0, 4, 1, 2, 3) pred_one_hot = pred_one_hot.permute(0, 4, 1, 2, 3) result_text += f"\n=== Performance Metrics ===\n" for class_idx in range(1, len(class_names)): # Skip background if class_idx == 2: # Skip Bone metrics display continue intersection = torch.sum(pred_one_hot[:, class_idx] * gt_one_hot[:, class_idx]) union = torch.sum(pred_one_hot[:, class_idx]) + torch.sum(gt_one_hot[:, class_idx]) if union > 0: dice = 2.0 * intersection / union result_text += f"{class_names[class_idx]} Dice: {dice.item():.3f}\n" else: result_text += f"{class_names[class_idx]} Dice: 0.000 (no ground truth)\n" # Create state dictionary state_data = { 'patient_id': patient_id, 'preprocessed_data': preprocessed_data, 'prediction': prediction, 'timestamp': time.time() } # Get middle slice index for slider middle_slice = preprocessed_data['t2w_preprocessed'].shape[0] // 2 progress(1, desc="✅ Completed!") return ( slice_browser_fig, volume_3d_fig, result_text, gr.update(maximum=preprocessed_data['t2w_preprocessed'].shape[0]-1, value=middle_slice), state_data ) except Exception as e: logger.error(f"Error loading anatomy segmentation cache result: {e}", exc_info=True) fig, ax = plt.subplots(figsize=(8, 6)) ax.text(0.5, 0.5, f"Processing error: {str(e)}", ha='center', va='center', transform=ax.transAxes, fontsize=12) ax.axis('off') empty_3d = go.Figure() error_state = {'error': str(e), 'timestamp': time.time()} return fig, empty_3d, f"Error: {str(e)}", gr.update(maximum=63, value=32), error_state def update_anatomy_slice_browser_with_state(slice_idx: int, state_data): """Update anatomy slice browser using state data""" try: # Check if state object is valid if state_data is None or 'error' in state_data: error_msg = state_data.get('error', "No data loaded. Please run anatomy segmentation first.") if state_data else "No data loaded. Please run anatomy segmentation first." logger.warning(f"Invalid state data for anatomy slice browser: {error_msg}") fig, ax = plt.subplots(figsize=(8, 6)) ax.text(0.5, 0.5, error_msg, ha='center', va='center', transform=ax.transAxes, fontsize=12) ax.axis('off') return fig # Get data from state preprocessed_data = state_data.get('preprocessed_data') prediction = state_data.get('prediction') if preprocessed_data is not None and prediction is not None: return create_anatomy_2d_slice_browser(preprocessed_data, prediction, slice_idx) # If no valid data is found in the state logger.error("State is invalid. No preprocessed data found.") fig, ax = plt.subplots(figsize=(8, 6)) ax.text(0.5, 0.5, "Invalid data state. Please run anatomy segmentation again.", ha='center', va='center', transform=ax.transAxes, fontsize=12) ax.axis('off') return fig except Exception as e: logger.error(f"Error updating anatomy slice browser with state: {e}") fig, ax = plt.subplots(figsize=(8, 6)) ax.text(0.5, 0.5, f"Update error: {str(e)}", ha='center', va='center', transform=ax.transAxes, fontsize=12) ax.axis('off') return fig def update_anatomy_3d_with_controls(selected_organ: str, state_data): """Update anatomy 3D rendering based on organ selection""" try: # Check if state object is valid if state_data is None or 'error' in state_data: error_msg = state_data.get('error', "No data loaded. Please run anatomy segmentation first.") if state_data else "No data loaded. Please run anatomy segmentation first." logger.warning(f"Invalid state data for anatomy 3D: {error_msg}") fig = go.Figure() fig.add_annotation( text=error_msg, xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False, font=dict(size=14) ) return fig # Get data from state preprocessed_data = state_data.get('preprocessed_data') prediction = state_data.get('prediction') if preprocessed_data is None or prediction is None: logger.error("State is invalid. No preprocessed data found.") fig = go.Figure() fig.add_annotation( text="Invalid data state. Please run anatomy segmentation again.", xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False, font=dict(size=14) ) return fig # Map organ name to index organ_mapping = { "Bladder": 1, "Obturator Internus": 3, "Peripheral Zone": 4, "Transition Zone": 5, "Rectum": 6, "Seminal Vesicle": 7, "Neurovascular Bundle": 8 } selected_organ_idx = organ_mapping.get(selected_organ, 4) # Default to Peripheral Zone # Always display both Ground Truth and Prediction for better comparison return create_anatomy_3d_volume_rendering(preprocessed_data, prediction, selected_organ_idx, "Both") except Exception as e: logger.error(f"Error updating anatomy 3D with controls: {e}") fig = go.Figure() fig.add_annotation( text=f"3D update error: {str(e)}", xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False, font=dict(size=14) ) return fig def create_anatomy_2d_slice_browser(preprocessed_data: Dict, prediction: torch.Tensor, slice_idx: int = None) -> plt.Figure: """Create 2D slice browser for anatomy segmentation with 3 views - H200 optimized""" try: # Keep tensor on GPU for H200 optimization t2w_gpu = preprocessed_data['t2w_preprocessed'] # Use middle slice if none specified if slice_idx is None: slice_idx = t2w_gpu.shape[0] // 2 # Ensure slice index is valid slice_idx = max(0, min(slice_idx, t2w_gpu.shape[0] - 1)) # Extract T2W slice on GPU, then transfer to CPU with torch.no_grad(): t2w_slice = t2w_gpu[slice_idx].cpu().numpy() # Define consistent image extent and aspect ratio height, width = t2w_slice.shape extent = [0, width, height, 0] # [left, right, bottom, top] # Process prediction - optimized GPU processing pred_classes = None if prediction is not None: with torch.no_grad(): pred_logits_gpu = torch.softmax(prediction, dim=1)[0, :, slice_idx] pred_classes_gpu = torch.argmax(pred_logits_gpu, dim=0) pred_classes = pred_classes_gpu.cpu().numpy() # Process ground truth - optimized GPU processing gt_slice = None if 'ground_truth_preprocessed' in preprocessed_data and preprocessed_data['ground_truth_preprocessed'] is not None: with torch.no_grad(): gt_slice = preprocessed_data['ground_truth_preprocessed'][slice_idx].cpu().numpy().astype(np.uint8) # Create figure with 3 subplots (1 row, 3 columns) fig, axes = plt.subplots(1, 3, figsize=(15, 5)) # Consistent imshow parameters imshow_params = { 'cmap': 'gray', 'aspect': 'equal', 'extent': extent, 'interpolation': 'nearest' } # Define high-contrast colors for each anatomical class (improved color scheme) class_colors = [ [0, 0, 0], # Background - black [1, 0, 0], # Bladder - bright red [0, 1, 0], # Bone - bright green (not displayed) [0, 0, 1], # Obturator Internus - bright blue [1, 1, 0], # Peripheral Zone - yellow [1, 0.5, 0], # Transition Zone - orange [0.5, 0, 0.5], # Rectum - dark purple [0, 0.8, 0.8], # Seminal Vesicle - light cyan [0.8, 0, 0.8] # Neurovascular Bundle - bright magenta ] # 1. Original T2W image axes[0].imshow(t2w_slice, **imshow_params) axes[0].set_title(f'T2W (Slice {slice_idx})', fontsize=14, fontweight='bold') axes[0].axis('off') # 2. T2W + Ground Truth overlay axes[1].imshow(t2w_slice, **imshow_params) if gt_slice is not None: # Create colored mask for ground truth gt_colored = np.zeros((*gt_slice.shape, 4)) for class_idx in range(1, len(class_colors)): # Skip background if class_idx == 2: # Skip Bone visualization continue mask = gt_slice == class_idx if np.any(mask): gt_colored[mask] = class_colors[class_idx] + [0.6] # Add alpha axes[1].imshow(gt_colored, extent=extent, aspect='equal', interpolation='nearest') axes[1].set_title('T2W + Ground Truth', fontsize=14, fontweight='bold') axes[1].axis('off') # 3. T2W + Prediction overlay axes[2].imshow(t2w_slice, **imshow_params) if pred_classes is not None: # Create colored mask for prediction pred_colored = np.zeros((*pred_classes.shape, 4)) for class_idx in range(1, len(class_colors)): # Skip background if class_idx == 2: # Skip Bone visualization continue mask = pred_classes == class_idx if np.any(mask): pred_colored[mask] = class_colors[class_idx] + [0.6] # Add alpha axes[2].imshow(pred_colored, extent=extent, aspect='equal', interpolation='nearest') axes[2].set_title('T2W + Prediction', fontsize=14, fontweight='bold') axes[2].axis('off') # Adjust layout with padding for titles - fix title clipping issue plt.tight_layout(pad=3.0) plt.subplots_adjust(top=0.85) # Leave space for titles at the top return fig except Exception as e: logger.error(f"Error creating anatomy 2D slice browser: {e}") fig, ax = plt.subplots(figsize=(8, 6)) ax.text(0.5, 0.5, f"Visualization error: {str(e)}", ha='center', va='center', transform=ax.transAxes, fontsize=12) ax.axis('off') return fig def create_anatomy_3d_volume_rendering(preprocessed_data: Dict, prediction: torch.Tensor, selected_organ_idx: int = 4, display_mode: str = "Both") -> go.Figure: """Create interactive 3D volume rendering for single organ - H200 optimized""" try: # Vectorized processing on GPU before CPU transfer with torch.no_grad(): # Get T2W data for reference t2w_gpu = preprocessed_data['t2w_preprocessed'] # Process prediction for anatomy structures in batch on GPU pred_classes_gpu = torch.zeros_like(t2w_gpu, dtype=torch.uint8) if prediction is not None: pred_logits_gpu = torch.softmax(prediction, dim=1)[0] pred_classes_gpu = torch.argmax(pred_logits_gpu, dim=0).to(torch.uint8) # Process ground truth gt_classes_gpu = torch.zeros_like(t2w_gpu, dtype=torch.uint8) if 'ground_truth_preprocessed' in preprocessed_data and preprocessed_data['ground_truth_preprocessed'] is not None: gt_classes_gpu = preprocessed_data['ground_truth_preprocessed'].to(torch.uint8) # Transfer to CPU in one batch operation t2w = t2w_gpu.cpu().numpy() pred_classes = pred_classes_gpu.cpu().numpy() gt_classes = gt_classes_gpu.cpu().numpy() # Downsample for performance (every 2nd voxel) t2w_ds = t2w[::2, ::2, ::2] pred_classes_ds = pred_classes[::2, ::2, ::2] gt_classes_ds = gt_classes[::2, ::2, ::2] logger.info(f"Downsampled anatomy data shapes: T2W={t2w_ds.shape}, prediction={pred_classes_ds.shape}, ground truth={gt_classes_ds.shape}") # Create 3D meshes for selected organ only fig = go.Figure() # Define class information with improved high-contrast colors # Using warm colors for Ground Truth and cool colors for Prediction class_info_gt = [ # Ground Truth - Warm colors (higher saturation) ("Background", [0.5, 0.5, 0.5]), # Background - skip ("Bladder", [1.0, 0.2, 0.2]), # Warm red ("Bone", [0.8, 1.0, 0.2]), # Warm yellow-green (not displayed) ("Obturator Internus", [1.0, 0.6, 0.0]), # Warm orange ("Peripheral Zone", [1.0, 0.8, 0.0]), # Warm yellow ("Transition Zone", [1.0, 0.4, 0.2]), # Warm orange-red ("Rectum", [0.8, 0.2, 0.6]), # Warm magenta ("Seminal Vesicle", [1.0, 0.7, 0.3]), # Warm peach ("Neurovascular Bundle", [0.9, 0.3, 0.7]) # Warm pink ] class_info_pred = [ # Prediction - Cool colors (lower saturation) ("Background", [0.5, 0.5, 0.5]), # Background - skip ("Bladder", [0.2, 0.4, 1.0]), # Cool blue ("Bone", [0.2, 0.8, 0.6]), # Cool cyan-green (not displayed) ("Obturator Internus", [0.0, 0.6, 1.0]), # Cool sky blue ("Peripheral Zone", [0.4, 0.7, 1.0]), # Cool light blue ("Transition Zone", [0.2, 0.5, 0.8]), # Cool blue-gray ("Rectum", [0.4, 0.2, 0.8]), # Cool purple ("Seminal Vesicle", [0.2, 0.7, 0.9]), # Cool cyan ("Neurovascular Bundle", [0.5, 0.3, 0.9]) # Cool lavender ] # Validate selected organ index if selected_organ_idx < 1 or selected_organ_idx >= len(class_info_gt): selected_organ_idx = 4 # Default to Peripheral Zone # Skip Bone visualization (index 2) if selected_organ_idx == 2: # Return empty figure with message for Bone fig = go.Figure() fig.add_annotation( text="Bone visualization is not available", xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False, font=dict(size=16) ) return fig class_name, color_gt = class_info_gt[selected_organ_idx] _, color_pred = class_info_pred[selected_organ_idx] # Track what gets displayed traces_added = [] # Ground Truth mesh (if display mode includes GT) if display_mode in ["Ground Truth", "Both"]: gt_mask = (gt_classes_ds == selected_organ_idx) if np.count_nonzero(gt_mask) > 50: try: logger.info(f"Creating 3D Ground Truth mesh for {class_name}...") verts_gt, faces_gt, _, _ = measure.marching_cubes(gt_mask.astype(float), level=0.5) fig.add_trace(go.Mesh3d( x=verts_gt[:, 0], y=verts_gt[:, 1], z=verts_gt[:, 2], i=faces_gt[:, 0], j=faces_gt[:, 1], k=faces_gt[:, 2], name=f"{class_name} (Ground Truth)", color=f'rgb({int(color_gt[0]*255)},{int(color_gt[1]*255)},{int(color_gt[2]*255)})', opacity=0.9, # Higher opacity for ground truth (warm colors) lighting=dict(ambient=0.6, diffuse=0.8), showlegend=True, visible=True, hovertemplate=f"{class_name} Ground Truth
Vertex: (%{{x}}, %{{y}}, %{{z}})" )) traces_added.append(f"{class_name} (Ground Truth)") logger.info(f"✅ Successfully created {class_name} Ground Truth 3D mesh") except Exception as e: logger.warning(f"Could not create {class_name} Ground Truth mesh: {e}") # Prediction mesh (if display mode includes Prediction) if display_mode in ["Prediction", "Both"]: pred_mask = (pred_classes_ds == selected_organ_idx) if np.count_nonzero(pred_mask) > 50: try: logger.info(f"Creating 3D Prediction mesh for {class_name}...") verts_pred, faces_pred, _, _ = measure.marching_cubes(pred_mask.astype(float), level=0.5) # Use lower opacity for prediction (cool colors) to create better contrast pred_opacity = 0.5 if display_mode == "Both" else 0.7 fig.add_trace(go.Mesh3d( x=verts_pred[:, 0], y=verts_pred[:, 1], z=verts_pred[:, 2], i=faces_pred[:, 0], j=faces_pred[:, 1], k=faces_pred[:, 2], name=f"{class_name} (Prediction)", color=f'rgb({int(color_pred[0]*255)},{int(color_pred[1]*255)},{int(color_pred[2]*255)})', opacity=pred_opacity, lighting=dict(ambient=0.8, diffuse=0.6), # Different lighting for cool colors showlegend=True, visible=True, hovertemplate=f"{class_name} Prediction
Vertex: (%{{x}}, %{{y}}, %{{z}})" )) traces_added.append(f"{class_name} (Prediction)") logger.info(f"✅ Successfully created {class_name} Prediction 3D mesh") except Exception as e: logger.warning(f"Could not create {class_name} Prediction mesh: {e}") # Configure layout display_mode_text = display_mode if display_mode != "Both" else "Ground Truth (Warm) vs Prediction (Cool)" fig.update_layout( title=dict( text=f"3D Anatomy Rendering: {class_name} - {display_mode_text}", x=0.5, font=dict(size=16) ), scene=dict( xaxis_title="X", yaxis_title="Y", zaxis_title="Z", camera=dict( eye=dict(x=1.2, y=1.2, z=1.2) ), aspectmode='cube' ), width=700, height=400, margin=dict(l=0, r=0, b=0, t=40) ) if not traces_added: # Add annotation if no data to display fig.add_annotation( text=f"No {display_mode.lower()} data available for {class_name}", xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False, font=dict(size=14) ) logger.info(f"3D anatomy rendering completed for {class_name} with mode {display_mode}: {traces_added}") return fig except Exception as e: logger.error(f"Error creating anatomy 3D rendering: {e}") # Return empty figure with error message fig = go.Figure() fig.add_annotation( text=f"3D anatomy rendering error: {str(e)}", xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False, font=dict(size=16) ) return fig def create_interface(): """Create Gradio interface""" # Load sample data to get patient lists sample_patients, dataset_root = load_sample_data() # Load pre-training data to get patient lists pretrain_patients, pretrain_root = load_pretrain_data() # Create patient lists risk_patients = [] ucl_patients = [] anatomy_patients = [] if 'risk' in sample_patients: risk_patients = [f"Patient_{i}" for i in range(len(sample_patients['risk']))] if 'UCL' in sample_patients: ucl_patients = [f"Patient_{i}" for i in range(len(sample_patients['UCL']))] if 'anatomy' in sample_patients: anatomy_patients = [f"Patient_{i}" for i in range(len(sample_patients['anatomy']))] # Create theme with default settings theme = gr.themes.Ocean( primary_hue="blue", secondary_hue="gray", ) with gr.Blocks(theme=theme, title="ProFound: Vision Foundation Models for Prostate Multiparametric MR Images") as demo: # Header gr.Markdown(""" # ProFound: Vision Foundation Models for Prostate Multiparametric MR Images 🏥🔬 ProFound is a suite of vision foundation models, pre-trained on multiparametric 3D magnetic resonance (MR) images from large collections of prostate cancer patients. """) # Create State components to manage data pretrain_state = gr.State(value=None) segmentation_state = gr.State(value=None) anatomy_state = gr.State(value=None) with gr.Tabs() as tabs: # Pre-training tab - First tab (default) with gr.TabItem("🧠 ProFound Pre-training", id="pretraining"): gr.Markdown(""" ### ProFound Pre-training Visualization **Input**: T2W, DWI, ADC original images **Processing**: Masked reconstruction pipeline **Output**: Original → Masked → Reconstructed visualization **Purpose**: Demonstrate the self-supervised pre-training process used to train ProFound foundation models """) with gr.Row(): with gr.Column(scale=1): gr.Markdown("#### Patient Selection") pretrain_patient_dropdown = gr.Dropdown( choices=pretrain_patients if pretrain_patients else ["patient_001", "patient_002"], label="Choose Patient", value=pretrain_patients[0] if pretrain_patients else "patient_001", info="Select a patient to visualize pre-training reconstruction" ) pretrain_button = gr.Button("🚀 Launch Reconstruction", variant="primary") with gr.Column(scale=4): gr.Markdown("#### Multi-Modal Pre-training Visualization") gr.Markdown("**🔍 2D Slice Browser** - Navigate through slices") # Slice control for pre-training pretrain_slice_slider = gr.Slider( minimum=0, maximum=63, value=32, step=1, label="Slice Index", info="Drag to navigate through different slices" ) # 3x3 grid visualization for pre-training gr.Markdown("**📊 Reconstruction Results:**") gr.Markdown("• **Rows**: T2W (top), DWI (middle), ADC (bottom)") gr.Markdown("• **Columns**: Original (left), Masked (center), Reconstructed (right)") pretrain_plot = gr.Plot( label="Reconstruction Results", value=None ) # Bind pre-training events pretrain_button.click( fn=run_pretrain_reconstruction, inputs=[pretrain_patient_dropdown, pretrain_state], outputs=[pretrain_plot, pretrain_slice_slider, pretrain_state] ) # Bind slice slider event for real-time updates pretrain_slice_slider.change( fn=update_pretrain_slice_with_state, inputs=[pretrain_slice_slider, pretrain_state], outputs=[pretrain_plot] ) # Classification tab with gr.TabItem("🎯 Assessment of Prostate Cancer Patients", id="classification"): gr.Markdown(""" ### PI-RADS Classification **Input**: T2W, DWI (High-b), ADC images **Output**: PI-RADS classification with confidence scores """) with gr.Row(): with gr.Column(scale=1): gr.Markdown("#### Patient Selection") cls_patient_dropdown = gr.Dropdown( choices=risk_patients, label="Select Patient", value=risk_patients[0] if risk_patients else None, info="Choose a patient for risk assessment" ) cls_button = gr.Button("🚀 Run Classification", variant="primary") with gr.Column(scale=2): gr.Markdown("#### Results") cls_plot = gr.Plot(label="Multi-modal Visualization") cls_result = gr.Textbox( label="Classification Results", lines=10, interactive=False ) # Bind classification event cls_button.click( fn=run_classification_inference, inputs=[cls_patient_dropdown], outputs=[cls_plot, cls_result] ) # Segmentation tab with gr.TabItem("✂️ Segmentation of Prostate Cancer Lesions", id="segmentation"): gr.Markdown(""" ### Lesion Segmentation **Input**: T2W, DWI, ADC images **Output**: Binary lesion segmentation mask """) with gr.Row(): with gr.Column(scale=1): gr.Markdown("#### Patient Selection") seg_patient_dropdown = gr.Dropdown( choices=ucl_patients, label="Select Patient", value="Patient_1" if ucl_patients and "Patient_1" in ucl_patients else (ucl_patients[0] if ucl_patients else None), info="Choose a patient for lesion segmentation" ) seg_button = gr.Button("🚀 Run Segmentation", variant="primary") gr.Markdown("#### Analysis Results") seg_result = gr.Textbox( label="Segmentation Results", lines=12, interactive=False ) with gr.Column(scale=3): gr.Markdown("#### Multi-Modal Visualization") # Combined visualization without tabs - display both in the same view gr.Markdown("**🔍 2D Slice Browser** - Navigate through slices to explore all modalities and results") # Slice control slice_slider = gr.Slider( minimum=0, maximum=63, value=32, step=1, label="Slice Index", info="Drag to navigate through different slices" ) # 2D slice browser plot - adjusted for better coordination slice_browser_plot = gr.Plot( label="6-Window Synchronized View: T2W, ADC, DWI, Ground Truth, Prediction, Heatmap" ) gr.Markdown("**🌐 3D Volume Rendering** - Interactive 3D visualization") gr.Markdown("- 🔵 **Blue**: Prostate structure - 🔴 **Red**: Predicted lesions - 🟡 **Yellow**: Ground truth lesions") # 3D volume rendering - adjusted for better coordination volume_3d_plot = gr.Plot( label="3D Prostate and Lesion Rendering" ) # Bind segmentation event with state management seg_button.click( fn=run_segmentation_inference_with_state, inputs=[seg_patient_dropdown, segmentation_state], outputs=[slice_browser_plot, volume_3d_plot, seg_result, slice_slider, segmentation_state] ) # Bind slice slider event for real-time updates with state slice_slider.change( fn=update_slice_browser_with_state, inputs=[slice_slider, segmentation_state], outputs=[slice_browser_plot] ) # Anatomy Segmentation tab - NEW with gr.TabItem("🫀 Segmentation of Prostate Anatomy", id="anatomy_segmentation"): gr.Markdown(""" ### Anatomy Segmentation **Input**: T2W images only **Output**: 7-class anatomical structure segmentation **Classes**: Bladder, Obturator Internus, Peripheral Zone, Transition Zone, Rectum, Seminal Vesicle, Neurovascular Bundle """) with gr.Row(): with gr.Column(scale=1): gr.Markdown("#### Patient Selection") anat_patient_dropdown = gr.Dropdown( choices=anatomy_patients, label="Select Patient", value=anatomy_patients[0] if anatomy_patients else None, info="Choose a patient for anatomy segmentation" ) anat_button = gr.Button("🚀 Run Anatomy Segmentation", variant="primary") gr.Markdown("#### Analysis Results") anat_result = gr.Textbox( label="Anatomy Segmentation Results", lines=15, interactive=False ) with gr.Column(scale=3): gr.Markdown("#### T2W-based Interactive Visualization") gr.Markdown("**🔍 2D Slice Browser** - Navigate through slices to explore anatomical structures") # Slice control for anatomy anat_slice_slider = gr.Slider( minimum=0, maximum=63, value=32, step=1, label="Slice Index", info="Drag to navigate through different slices" ) # 2D slice browser for anatomy anat_slice_browser_plot = gr.Plot( label="3-Window Anatomical View: T2W Original, Ground Truth, Prediction" ) gr.Markdown("**🌐 3D Anatomy Rendering** - Interactive single-organ visualization") gr.Markdown("**Organ Selection**") anat_organ_dropdown = gr.Dropdown( choices=[ "Bladder", "Obturator Internus", "Peripheral Zone", "Transition Zone", "Rectum", "Seminal Vesicle", "Neurovascular Bundle" ], label="Select Anatomical Structure", value="Peripheral Zone", info="Choose anatomical structure to display in 3D view" ) gr.Markdown("**Color Legend**: **Ground Truth** (warm colors, solid) vs **Prediction** (cool colors, transparent)") gr.Markdown("🔥 **Warm tones**: Ground Truth ❄️ **Cool tones**: Prediction") gr.Markdown("📋 **Organ List**: Bladder, Obturator Internus, Peripheral Zone, Transition Zone, Rectum, Seminal Vesicle, Neurovascular Bundle") # 3D volume rendering for anatomy anat_volume_3d_plot = gr.Plot( label="3D Single-Organ Rendering: Focused Anatomical Structure View" ) # Bind anatomy segmentation event with state management anat_button.click( fn=run_anatomy_segmentation_inference_with_state, inputs=[anat_patient_dropdown, anatomy_state], outputs=[anat_slice_browser_plot, anat_volume_3d_plot, anat_result, anat_slice_slider, anatomy_state] ) # Bind slice slider event for real-time updates with state anat_slice_slider.change( fn=update_anatomy_slice_browser_with_state, inputs=[anat_slice_slider, anatomy_state], outputs=[anat_slice_browser_plot] ) # Bind organ selection changes for 3D rendering anat_organ_dropdown.change( fn=update_anatomy_3d_with_controls, inputs=[anat_organ_dropdown, anatomy_state], outputs=[anat_volume_3d_plot] ) return demo # Launch application if __name__ == "__main__": # Set random seed for reproducibility torch.manual_seed(0) np.random.seed(0) cudnn.benchmark = True logger.info("🚀 Starting ProFound demo system with precomputed cache optimization") logger.info("📂 Loading precomputed inference results...") # Load precomputed results load_precomputed_results() logger.info("✅ Precomputed results loading completed, starting web interface...") demo = create_interface() demo.launch( share=True, server_name="0.0.0.0", server_port=7860, allowed_paths=[str(cache_dir), str(Path(__file__).parent)] )