import os
import argparse
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import gradio as gr
import spaces
from pathlib import Path
from huggingface_hub import hf_hub_download, snapshot_download
import matplotlib.pyplot as plt
from typing import Tuple, Optional, List, Dict
import pandas as pd
import tempfile
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
from scipy import ndimage
from skimage import measure
import cv2
import json
import logging
import nibabel as nib
from PIL import Image
import SimpleITK as sitk
from monai.inferers import sliding_window_inference
from monai.metrics import compute_dice
import monai
import time
import pickle
from monai.transforms import (
Compose,
NormalizeIntensityd,
CenterSpatialCropd,
SpatialPadd,
Spacing,
CenterSpatialCrop,
SpatialPad,
Resize
)
# Configure matplotlib basic settings
plt.rcParams['axes.unicode_minus'] = False  # render minus signs with ASCII hyphen (avoids missing-glyph boxes)
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Cache directory (sits next to this file; holds downloaded checkpoints/datasets)
cache_dir = Path(__file__).parent / "cache"
cache_dir.mkdir(parents=True, exist_ok=True)
# Results cache directory (pickled precomputed inference results live here)
results_cache_dir = cache_dir / "precomputed_results"
results_cache_dir.mkdir(parents=True, exist_ok=True)
# Device Configuration (compatible with Spaces stateless GPU)
# Note: Device will be set inside @spaces.GPU decorated functions
device = None # Will be set when GPU is available
dtype = torch.float32 # Use full precision (FP32)
# GPU optimization flags (will be applied in GPU-decorated functions)
_gpu_optimizations_applied = False
# Performance optimization flags — env-var toggles, all default to enabled
ENABLE_TORCH_COMPILE = os.getenv("ENABLE_TORCH_COMPILE", "true").lower() == "true"
ENABLE_MODEL_WARMUP = os.getenv("ENABLE_MODEL_WARMUP", "true").lower() == "true"
ENABLE_PRECOMPUTE = os.getenv("ENABLE_PRECOMPUTE", "true").lower() == "true"
# Global cache for precomputed results (populated by load_precomputed_results)
PRECOMPUTED_RESULTS = {}
def save_results_to_cache(results: Dict, cache_file: Path):
    """Pickle *results* to *cache_file*.

    Best-effort: any failure is logged and swallowed so a broken cache
    never interrupts the inference pipeline.
    """
    try:
        with cache_file.open('wb') as handle:
            pickle.dump(results, handle)
        logger.info(f"✅ Successfully saved inference results cache to: {cache_file}")
    except Exception as exc:
        logger.error(f"❌ Failed to save cache: {exc}")
def load_results_from_cache(cache_file: Path) -> Dict:
    """Unpickle previously saved inference results from *cache_file*.

    Returns an empty dict when the file is missing or cannot be read, so
    callers can treat "no cache" and "bad cache" identically.
    """
    try:
        if not cache_file.exists():
            logger.info(f"📁 Cache file does not exist: {cache_file}")
            return {}
        with cache_file.open('rb') as handle:
            cached = pickle.load(handle)
        logger.info(f"✅ Successfully loaded inference results from cache: {cache_file}")
        return cached
    except Exception as exc:
        logger.error(f"❌ Failed to load cache: {exc}")
        return {}
def tensor_to_serializable(obj):
    """Recursively replace torch tensors with pickle-friendly marker dicts.

    A tensor becomes ``{'_type': 'tensor', 'data': <ndarray>, 'shape', 'dtype'}``;
    dicts and lists are walked recursively, tuples come back as lists, and any
    other value passes through unchanged.
    """
    if isinstance(obj, torch.Tensor):
        # Move to host memory; shape/dtype are recorded for inspection.
        return {
            '_type': 'tensor',
            'data': obj.cpu().numpy(),
            'shape': obj.shape,
            'dtype': str(obj.dtype)
        }
    if isinstance(obj, dict):
        converted = {}
        for key, value in obj.items():
            converted[key] = tensor_to_serializable(value)
        return converted
    if isinstance(obj, (list, tuple)):
        return [tensor_to_serializable(element) for element in obj]
    return obj
def serializable_to_tensor(obj):
    """Inverse of ``tensor_to_serializable``: rebuild tensors from markers.

    Dicts tagged with ``'_type' == 'tensor'`` are turned back into torch
    tensors; other dicts and lists are walked recursively; everything else
    is returned untouched.
    """
    if isinstance(obj, dict):
        if obj.get('_type') != 'tensor':
            return {key: serializable_to_tensor(value) for key, value in obj.items()}
        return torch.from_numpy(obj['data'])
    if isinstance(obj, list):
        return [serializable_to_tensor(element) for element in obj]
    return obj
@spaces.GPU
def precompute_all_results():
    """Precompute all inference results for all patients and all tasks.

    Runs classification, lesion segmentation and anatomy segmentation over
    every sample patient, converts tensors to serializable form, and pickles
    everything to the results cache so the UI can answer from cache without
    further GPU work.

    Returns:
        dict: ``{'classification': {...}, 'segmentation': {...},
        'anatomy_segmentation': {...}}`` keyed by ``"Patient_<idx>"``.
        Empty dict when ``ENABLE_PRECOMPUTE`` is false.
    """
    if not ENABLE_PRECOMPUTE:
        logger.info("🚫 Precomputation disabled, skipping precomputation process")
        return {}
    logger.info("🚀 Starting precomputation of all inference results...")
    # Apply GPU optimizations
    # NOTE: this local `device` intentionally shadows the module-level
    # placeholder (None) — inside a @spaces.GPU function CUDA is available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    apply_gpu_optimizations()
    # Set seed for reproducibility
    torch.manual_seed(0)
    np.random.seed(0)
    cudnn.benchmark = True
    # Load sample data
    sample_patients, dataset_root = load_sample_data()
    all_results = {
        'classification': {},
        'segmentation': {},
        'anatomy_segmentation': {}
    }
    # Precompute classification results
    if 'risk' in sample_patients:
        logger.info("📊 Precomputing classification tasks...")
        try:
            cls_model, cls_args = load_classification_model()
            for patient_idx in range(len(sample_patients['risk'])):
                patient_id = f"Patient_{patient_idx}"
                logger.info(f"Processing classification task - {patient_id}")
                try:
                    # Create dataloader for single patient
                    data_loader, args, dataset = create_single_sample_dataloader(patient_idx, 'classification')
                    preprocessed_data = get_preprocessed_patient_data(patient_idx, 'classification', device)
                    # Run inference
                    cls_model.eval()
                    with torch.no_grad():
                        for idx, data in enumerate(data_loader):
                            img, gt, pid = data
                            img = img.to(device, dtype=dtype, non_blocking=True)
                            gt = gt.to(device, non_blocking=True)
                            with torch.amp.autocast('cuda', enabled=torch.cuda.is_available()):
                                logit = cls_model(img)
                            # Store results
                            all_results['classification'][patient_id] = {
                                'prediction': tensor_to_serializable(logit),
                                'preprocessed_data': tensor_to_serializable(preprocessed_data),
                                'patient_id': patient_id
                            }
                            # Single-sample loader: only the first batch is needed
                            break
                    logger.info(f"✅ Classification task completed - {patient_id}")
                except Exception as e:
                    # One failing patient must not abort the whole precompute run
                    logger.error(f"❌ Classification task failed - {patient_id}: {e}")
                    continue
        except Exception as e:
            logger.error(f"❌ Classification model loading failed: {e}")
    # Precompute segmentation results
    if 'UCL' in sample_patients:
        logger.info("✂️ Precomputing lesion segmentation tasks...")
        try:
            seg_model, seg_args = load_segmentation_model()
            for patient_idx in range(len(sample_patients['UCL'])):
                patient_id = f"Patient_{patient_idx}"
                logger.info(f"Processing lesion segmentation task - {patient_id}")
                try:
                    # Create dataloader for single patient
                    data_loader, args, dataset = create_single_sample_dataloader(patient_idx, 'segmentation')
                    preprocessed_data = get_preprocessed_patient_data(patient_idx, 'segmentation', device)
                    # Run inference
                    seg_model.eval()
                    with torch.no_grad():
                        for idx, data in enumerate(data_loader):
                            # Some loaders additionally yield a gland mask as 4th item
                            if len(data) == 4:
                                img, gt, pid, gland = data
                            else:
                                img, gt, pid = data
                                gland = None
                            img = img.to(device, dtype=dtype, non_blocking=True)
                            ground_truth_tensor = gt.to(device, non_blocking=True) if gt is not None else None
                            with torch.amp.autocast('cuda', enabled=torch.cuda.is_available()):
                                if args.sliding_window:
                                    pred = sliding_window_inference(img, args.crop_spatial_size, 4, seg_model, overlap=0.5)
                                else:
                                    pred = seg_model(img)
                            # Precompute slice cache for smooth navigation
                            local_cached_slice_data = precompute_slice_cache(preprocessed_data, pred)
                            # Store results
                            all_results['segmentation'][patient_id] = {
                                'prediction': tensor_to_serializable(pred),
                                'preprocessed_data': tensor_to_serializable(preprocessed_data),
                                'cached_slice_data': local_cached_slice_data,
                                'patient_id': patient_id,
                                'ground_truth_tensor': tensor_to_serializable(ground_truth_tensor) if ground_truth_tensor is not None else None
                            }
                            break
                    logger.info(f"✅ Lesion segmentation task completed - {patient_id}")
                except Exception as e:
                    logger.error(f"❌ Lesion segmentation task failed - {patient_id}: {e}")
                    continue
        except Exception as e:
            logger.error(f"❌ Segmentation model loading failed: {e}")
    # Precompute anatomy segmentation results
    if 'anatomy' in sample_patients:
        logger.info("🫀 Precomputing anatomy segmentation tasks...")
        try:
            anat_model, anat_args = load_anatomy_segmentation_model()
            for patient_idx in range(len(sample_patients['anatomy'])):
                patient_id = f"Patient_{patient_idx}"
                logger.info(f"Processing anatomy segmentation task - {patient_id}")
                try:
                    # Create dataloader for single patient
                    data_loader, args, dataset = create_single_sample_dataloader(patient_idx, 'anatomy_segmentation')
                    preprocessed_data = get_preprocessed_patient_data(patient_idx, 'anatomy_segmentation', device)
                    # Run inference
                    anat_model.eval()
                    with torch.no_grad():
                        for idx, data in enumerate(data_loader):
                            img, gt, pid = data
                            img = img.to(device, dtype=dtype, non_blocking=True)
                            ground_truth_tensor = gt.to(device, non_blocking=True) if gt is not None else None
                            with torch.amp.autocast('cuda', enabled=torch.cuda.is_available()):
                                if args.sliding_window:
                                    pred = sliding_window_inference(img, args.crop_spatial_size, 4, anat_model, overlap=0.5)
                                else:
                                    pred = anat_model(img)
                            # Store results
                            all_results['anatomy_segmentation'][patient_id] = {
                                'prediction': tensor_to_serializable(pred),
                                'preprocessed_data': tensor_to_serializable(preprocessed_data),
                                'patient_id': patient_id,
                                'ground_truth_tensor': tensor_to_serializable(ground_truth_tensor) if ground_truth_tensor is not None else None
                            }
                            break
                    logger.info(f"✅ Anatomy segmentation task completed - {patient_id}")
                except Exception as e:
                    logger.error(f"❌ Anatomy segmentation task failed - {patient_id}: {e}")
                    continue
        except Exception as e:
            logger.error(f"❌ Anatomy segmentation model loading failed: {e}")
    # Save results to cache
    cache_file = results_cache_dir / "all_results.pkl"
    save_results_to_cache(all_results, cache_file)
    logger.info("🎉 All inference results precomputation completed!")
    return all_results
def load_precomputed_results():
    """Populate the global PRECOMPUTED_RESULTS cache from disk.

    Falls back to running the (slow) precomputation pass when no usable
    cache file exists.
    """
    global PRECOMPUTED_RESULTS
    cache_file = results_cache_dir / "all_results.pkl"
    PRECOMPUTED_RESULTS = load_results_from_cache(cache_file)
    if PRECOMPUTED_RESULTS:
        # Cache hit: report the contents and skip all GPU work.
        logger.info(f"✅ Successfully loaded precomputed results, containing:")
        logger.info(f" 📊 Classification tasks: {len(PRECOMPUTED_RESULTS.get('classification', {}))} patients")
        logger.info(f" ✂️ Lesion segmentation: {len(PRECOMPUTED_RESULTS.get('segmentation', {}))} patients")
        logger.info(f" 🫀 Anatomy segmentation: {len(PRECOMPUTED_RESULTS.get('anatomy_segmentation', {}))} patients")
        logger.info("⚡ All inference will load quickly from cache without GPU computation")
    else:
        # Cache miss: compute everything once, then serve from memory.
        logger.info("🔄 No precomputed results found, starting precomputation...")
        logger.info("⏳ This may take a few minutes, please be patient...")
        PRECOMPUTED_RESULTS = precompute_all_results()
        logger.info("🎉 Precomputation completed! All inference will now load quickly from cache")
def force_recompute_all_results():
    """Drop the on-disk cache and recompute every inference result.

    Use when model weights or sample data change and the pickled cache
    is stale. Returns the freshly computed results dict.
    """
    global PRECOMPUTED_RESULTS
    logger.info("🔄 Force recomputing all inference results...")
    # Remove any stale pickle so the next load cannot pick it up.
    stale_cache = results_cache_dir / "all_results.pkl"
    if stale_cache.exists():
        stale_cache.unlink()
        logger.info("🗑️ Deleted old cache file")
    # Recompute from scratch (also re-saves the cache internally).
    PRECOMPUTED_RESULTS = precompute_all_results()
    logger.info("🎉 Recomputation completed!")
    return PRECOMPUTED_RESULTS
def load_pretrain_data():
    """Fetch the pre-training demo dataset from the Hugging Face Hub.

    Returns:
        tuple: (sorted list of ``patient_*`` directory names, Path to the
        pretrain folder), or ``([], None)`` on any failure.
    """
    try:
        # snapshot_download is effectively a no-op once the dataset is cached
        dataset_path = Path(snapshot_download(
            repo_id=DATASET_REPO,
            repo_type="dataset",
            cache_dir=cache_dir
        ))
        pretrain_path = dataset_path / "demo" / "data" / "pretrain"
        # Collect patient directories in deterministic (sorted) order
        pretrain_patients = []
        if pretrain_path.exists():
            pretrain_patients = sorted(
                entry.name
                for entry in pretrain_path.iterdir()
                if entry.is_dir() and entry.name.startswith("patient_")
            )
        logger.info(f"Found {len(pretrain_patients)} pre-training patients: {pretrain_patients}")
        return pretrain_patients, pretrain_path
    except Exception as exc:
        logger.error(f"Error loading pre-training data: {exc}")
        return [], None
def load_pretrain_patient_data(patient_id: str, pretrain_root: Path):
    """Load all 9 pre-training volumes (3 modalities x 3 processing types).

    Missing files are stored as ``None`` so the visualization can render a
    placeholder instead of failing. Returns a nested dict
    ``{modality: {proc_type: tensor-or-None}}`` or ``None`` on error.
    """
    try:
        logger.info(f"🔍 Loading pre-training data for {patient_id}")
        patient_path = pretrain_root / patient_id
        if not patient_path.exists():
            raise ValueError(f"Patient directory not found: {patient_path}")
        patient_data = {}
        for modality in ('T2W', 'DWI', 'ADC'):
            per_modality = {}
            for proc_type in ('original', 'masked', 'reconstructed'):
                file_name = f"{modality}_{proc_type}.nii.gz"
                file_path = patient_path / file_name
                if not file_path.exists():
                    logger.warning(f"⚠️ File not found: {file_path}")
                    per_modality[proc_type] = None
                    continue
                # Read the NIfTI volume and keep it as a float tensor for
                # consistency with the rest of the pipeline.
                volume = nib.load(str(file_path)).get_fdata()
                img_tensor = torch.from_numpy(volume).float()
                per_modality[proc_type] = img_tensor
                logger.info(f"✅ Loaded {file_name}: shape={img_tensor.shape} (XYZ format)")
            patient_data[modality] = per_modality
        logger.info(f"Successfully loaded pre-training data for {patient_id}")
        return patient_data
    except Exception as exc:
        logger.error(f"Error loading patient data for {patient_id}: {exc}")
        return None
def create_pretrain_visualization(patient_data: Dict, slice_idx: int = None) -> plt.Figure:
    """Create 3x3 grid visualization for pre-training data showing XY plane at different Z slices.

    Rows are modalities (T2W/DWI/ADC), columns are processing types
    (original/masked/reconstructed). Volumes are assumed to be stored in
    XYZ order, so the slice index addresses ``shape[2]``.

    Args:
        patient_data: nested dict {modality: {proc_type: 3D tensor or None}}.
        slice_idx: Z slice to display; defaults to the middle slice of
            the T2W original volume (or 32 if that volume is missing).

    Returns:
        matplotlib Figure — either the 3x3 grid or a single-axes error figure.
    """
    try:
        if patient_data is None:
            # Nothing loaded yet: return a placeholder figure.
            fig, ax = plt.subplots(figsize=(8, 6))
            ax.text(0.5, 0.5, "No patient data loaded", ha='center', va='center',
                    transform=ax.transAxes, fontsize=12)
            ax.axis('off')
            return fig
        modalities = ['T2W', 'DWI', 'ADC']
        processing_types = ['original', 'masked', 'reconstructed']
        # Determine slice index from T2W original if not specified
        if slice_idx is None:
            if patient_data.get('T2W', {}).get('original') is not None:
                # Data shape is XYZ, so Z dimension is shape[2]
                slice_idx = patient_data['T2W']['original'].shape[2] // 2
            else:
                slice_idx = 32  # Default fallback
        # Create 3x3 grid
        fig, axes = plt.subplots(3, 3, figsize=(12, 10))
        for row, modality in enumerate(modalities):
            for col, proc_type in enumerate(processing_types):
                ax = axes[row, col]
                # Get image data
                img_data = patient_data.get(modality, {}).get(proc_type)
                if img_data is not None:
                    # Ensure we have 3D data (X, Y, Z)
                    if len(img_data.shape) != 3:
                        logger.warning(f"Unexpected image shape for {modality}_{proc_type}: {img_data.shape}")
                        ax.text(0.5, 0.5, f'{modality}\n{proc_type}\nUnexpected shape: {img_data.shape}',
                                ha='center', va='center', transform=ax.transAxes, fontsize=10)
                        ax.set_title(f'{modality} - {proc_type.capitalize()}', fontsize=10, fontweight='bold')
                        ax.axis('off')
                        continue
                    # Clamp the requested slice into the valid Z range
                    max_slice = img_data.shape[2] - 1  # Z dimension is the 3rd dimension
                    current_slice = max(0, min(slice_idx, max_slice))
                    # Extract XY plane at Z=current_slice (from XYZ format)
                    xy_slice = img_data[:, :, current_slice].cpu().numpy()  # Shape: (X, Y)
                    # Transpose to get correct orientation (Y, X) for proper display
                    xy_slice = xy_slice.T  # Now shape: (Y, X)
                    # Flip vertically to correct up-down orientation
                    xy_slice = np.flipud(xy_slice)
                    # Display image - XY plane
                    im = ax.imshow(xy_slice, cmap='gray', aspect='equal', interpolation='nearest', origin='lower')
                    # Add title for each image
                    ax.set_title(f'{modality} - {proc_type.capitalize()}', fontsize=12, fontweight='bold')
                else:
                    # Show placeholder for missing data
                    ax.text(0.5, 0.5, f'{modality}\n{proc_type}\n(Not Available)',
                            ha='center', va='center', transform=ax.transAxes, fontsize=10)
                    ax.set_title(f'{modality} - {proc_type.capitalize()}', fontsize=12, fontweight='bold')
                # NOTE(review): axis('off') placement reconstructed as applying to
                # both branches (display and placeholder) — confirm against the
                # original indentation.
                ax.axis('off')
        # Adjust layout with padding for titles - fix title clipping issue
        plt.tight_layout(pad=3.0)
        plt.subplots_adjust(top=0.88)  # Leave space for titles at the top
        return fig
    except Exception as e:
        logger.error(f"Error creating pre-training visualization: {e}")
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.text(0.5, 0.5, f"Visualization error: {str(e)}", ha='center', va='center',
                transform=ax.transAxes, fontsize=12)
        ax.axis('off')
        return fig
def run_pretrain_reconstruction(patient_id: str, current_state, progress=gr.Progress()):
    """Load and display pre-training reconstruction results.

    Gradio callback: downloads/locates the pretrain dataset, loads all nine
    volumes for *patient_id*, and renders the middle-slice 3x3 grid.

    Args:
        patient_id: directory name such as "patient_0".
        current_state: previous Gradio state (unused; kept for the callback
            signature).
        progress: Gradio progress reporter.

    Returns:
        tuple: (figure, slider gr.update with the valid Z range, new state
        dict). On failure the state dict carries an 'error' key and the
        slider falls back to range 0-63 / value 32.
    """
    try:
        progress(0, desc="📂 Loading pre-training data...")
        # Load pre-training data
        pretrain_patients, pretrain_root = load_pretrain_data()
        if not pretrain_patients or pretrain_root is None:
            progress(1, desc="❌ Pre-training data not found!")
            fig, ax = plt.subplots(figsize=(8, 6))
            ax.text(0.5, 0.5, "Pre-training data not available", ha='center', va='center',
                    transform=ax.transAxes, fontsize=12)
            ax.axis('off')
            error_state = {'error': "Pre-training data not available", 'timestamp': time.time()}
            return fig, gr.update(maximum=63, value=32), error_state
        if patient_id not in pretrain_patients:
            progress(1, desc="❌ Patient not found!")
            fig, ax = plt.subplots(figsize=(8, 6))
            ax.text(0.5, 0.5, f"Patient {patient_id} not found", ha='center', va='center',
                    transform=ax.transAxes, fontsize=12)
            ax.axis('off')
            error_state = {'error': f"Patient {patient_id} not found", 'timestamp': time.time()}
            return fig, gr.update(maximum=63, value=32), error_state
        progress(0.5, desc="🔄 Loading patient images...")
        # Load patient data
        patient_data = load_pretrain_patient_data(patient_id, pretrain_root)
        if patient_data is None:
            progress(1, desc="❌ Failed to load patient data!")
            fig, ax = plt.subplots(figsize=(8, 6))
            ax.text(0.5, 0.5, f"Failed to load data for {patient_id}", ha='center', va='center',
                    transform=ax.transAxes, fontsize=12)
            ax.axis('off')
            error_state = {'error': f"Failed to load data for {patient_id}", 'timestamp': time.time()}
            return fig, gr.update(maximum=63, value=32), error_state
        progress(0.8, desc="🎨 Creating XY plane visualization...")
        # Create visualization at the middle Z slice of the T2W original
        middle_slice = 32  # Default middle slice
        if patient_data.get('T2W', {}).get('original') is not None:
            # Data shape is XYZ, so Z dimension is shape[2]
            middle_slice = patient_data['T2W']['original'].shape[2] // 2
        pretrain_fig = create_pretrain_visualization(patient_data, middle_slice)
        # Create state dictionary (kept in Gradio state for slice browsing)
        state_data = {
            'patient_id': patient_id,
            'patient_data': patient_data,
            'pretrain_root': pretrain_root,
            'timestamp': time.time()
        }
        # Determine maximum slice number (Z dimension is shape[2] for XYZ format)
        max_slice = 63  # Default
        if patient_data.get('T2W', {}).get('original') is not None:
            max_slice = patient_data['T2W']['original'].shape[2] - 1
        progress(1, desc="✅ Completed!")
        return (
            pretrain_fig,
            gr.update(minimum=0, maximum=max_slice, value=middle_slice),
            state_data
        )
    except Exception as e:
        logger.error(f"Error in pre-training reconstruction: {e}")
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.text(0.5, 0.5, f"Processing error: {str(e)}", ha='center', va='center',
                transform=ax.transAxes, fontsize=12)
        ax.axis('off')
        error_state = {'error': str(e), 'timestamp': time.time()}
        return fig, gr.update(maximum=63, value=32), error_state
def update_pretrain_slice_with_state(slice_idx: int, state_data):
    """Re-render the 3x3 pre-training grid at a new Z slice from cached state.

    Returns the updated figure, or a single-axes message figure when the
    state is missing, carries an 'error', or lacks patient data.
    """

    def _placeholder(message):
        # Build a blank figure that just displays *message*.
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.text(0.5, 0.5, message, ha='center', va='center', transform=ax.transAxes, fontsize=12)
        ax.axis('off')
        return fig

    try:
        if state_data is None or 'error' in state_data:
            if state_data:
                error_msg = state_data.get('error', "No data loaded. Please run reconstruction first.")
            else:
                error_msg = "No data loaded. Please run reconstruction first."
            logger.warning(f"Invalid state data for pre-training slice browser: {error_msg}")
            return _placeholder(error_msg)
        patient_data = state_data.get('patient_data')
        if patient_data is None:
            # State exists but carries no volumes — ask the user to re-run.
            logger.error("State is invalid. No patient data found.")
            return _placeholder("Invalid data state. Please run reconstruction again.")
        return create_pretrain_visualization(patient_data, slice_idx)
    except Exception as e:
        logger.error(f"Error updating pre-training slice with state: {e}")
        return _placeholder(f"Update error: {str(e)}")
def apply_gpu_optimizations():
    """Apply H200 GPU optimizations - only call inside @spaces.GPU decorated functions.

    Idempotent: guarded by the module-level ``_gpu_optimizations_applied``
    flag and a CUDA availability check. Enables TF32 matmul/cudnn, SDP
    attention backends, cudnn autotuning, and reserves 95% of GPU memory.
    Failures are logged and ignored so inference still works unoptimized.
    """
    global _gpu_optimizations_applied
    if _gpu_optimizations_applied or not torch.cuda.is_available():
        return
    try:
        # H200 GPU Memory and Performance Optimizations
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cudnn.benchmark = True  # autotune conv algorithms for fixed input shapes
        torch.backends.cudnn.deterministic = False  # speed over bit-exact reproducibility
        # Enable memory efficient attention if available
        torch.backends.cuda.enable_math_sdp(True)
        torch.backends.cuda.enable_flash_sdp(True)
        torch.backends.cuda.enable_mem_efficient_sdp(True)
        # Set memory management for H200
        torch.cuda.empty_cache()
        torch.cuda.set_per_process_memory_fraction(0.95)  # Use 95% of H200's memory
        _gpu_optimizations_applied = True
        logger.info(f"🚀 H200 GPU optimizations enabled: {torch.cuda.get_device_name()}")
        logger.info(f"💾 GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
        logger.info(f"🔥 Using full precision (FP32) for maximum compatibility")
    except Exception as e:
        logger.warning(f"⚠️ Could not apply GPU optimizations: {e}")
        logger.warning("Continuing without GPU optimizations...")
# Model configuration — both checkpoints and the demo dataset live in the
# same Hugging Face repository.
MODEL_REPO = "wxyi088/ProFound"
DATASET_REPO = "wxyi088/ProFound"
# Import ProFound model classes — fail fast with an actionable message if
# the package is not installed, since nothing below can work without it.
try:
    from models.classifier import Classifier
    from models.convnextv2 import convnextv2_tiny
    from models.upernet_module import UperNet
    from models.convnext_unter import ConvnextUNETR
    from dataset.dataset_cls import build_Risk_loader
    from dataset.dataset_seg import build_UCL_loader, build_BpAnatomy_loader, BpAnatomySet
    from engine.classification import test_risk
    logger.info("Successfully imported ProFound model classes")
except ImportError as e:
    logger.error(f"Could not import ProFound models: {e}")
    logger.error("Please ensure ProFound package is installed from GitHub")
    raise ImportError("ProFound package not found. Please install from GitHub repository.")
def tuple_type(strings):
    """Parse a string like ``"(64,224,224)"`` into a tuple of ints.

    Used as an argparse ``type=`` callable; parentheses are optional.
    """
    cleaned = strings.replace("(", "").replace(")", "")
    return tuple(int(piece) for piece in cleaned.split(","))
def create_args_for_classification():
    """Build the inference Namespace for the risk-classification task.

    Mirrors demo_run_classification.sh exactly, with batch_size forced to 1
    for single-patient demo inference and worker count tuned for H200.
    """
    return argparse.Namespace(
        # Exact parameters from demo_run_classification.sh
        batch_size=1,  # single-sample demo inference
        model='profound_conv',
        input_size=(64, 224, 224),
        crop_spatial_size=(64, 224, 224),
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        dataset='risk',
        demo=True,
        # Attributes normally filled in by build_Risk_loader
        in_channels=3,
        num_classes=4,
        seed=0,
        num_workers=min(8, torch.cuda.device_count() * 4),  # optimized for H200
        pin_mem=True,
        tolerance=5,
        spacing=(1.0, 0.5, 0.5),
        weight_decay=1e-5,
        lr=0.1,
        min_lr=0.0,
        warmup_epochs=40,
        epochs=400,
        train='scratch',
        pretrain=None,
        root=str(cache_dir),  # data root — updated when needed
        output_dir='./outputcls',
        log_dir='./outputcls',
        file_name='classification_output',
        resume='',
        start_epoch=0,
        data20=False,
        data_num=0,
        save_fig=False,
        prompt=False,
        world_size=1,
        local_rank=-1,
        dist_on_itp=False,
        dist_url='env://',
        kfold=None,
    )
def create_args_for_segmentation():
    """Build the inference Namespace for the UCL lesion-segmentation task.

    Mirrors demo_run_segmentation.sh exactly, with batch_size forced to 1
    for single-patient demo inference and worker count tuned for H200.
    """
    return argparse.Namespace(
        # Exact parameters from demo_run_segmentation.sh
        batch_size=1,  # single-sample demo inference
        model='profound_conv',
        input_size=(64, 224, 224),
        crop_spatial_size=(64, 224, 224),
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        dataset='UCL',
        demo=True,
        # Attributes normally filled in by build_UCL_loader
        in_channels=3,
        out_channels=1,  # single-channel lesion mask
        num_classes=1,
        seed=0,
        num_workers=min(8, torch.cuda.device_count() * 4),  # optimized for H200
        pin_mem=True,
        sliding_window=False,  # from demo_segmentation.py
        save_fig=False,
        tolerance=5,
        spacing=(1.0, 0.5, 0.5),
        weight_decay=1e-5,
        lr=0.1,
        min_lr=0.0,
        warmup_epochs=40,
        epochs=400,
        train='scratch',
        pretrain=None,
        root=str(cache_dir),  # data root — updated when needed
        output_dir='./outputseg',
        log_dir='./outputseg',
        file_name='segmentation_output',
        resume='',
        start_epoch=0,
        data20=False,
        data_num=0,
        prompt=False,
        world_size=1,
        local_rank=-1,
        dist_on_itp=False,
        dist_url='env://',
    )
def create_args_for_anatomy_segmentation():
    """Build the inference Namespace for the anatomy-segmentation task.

    Same shape as the lesion-segmentation config but targeting the anatomy
    dataset with 8 structures (+ background channel).
    """
    return argparse.Namespace(
        batch_size=1,  # single-sample demo inference
        model='profound_conv',
        input_size=(64, 224, 224),
        crop_spatial_size=(64, 224, 224),
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        dataset='anatomy',
        demo=True,
        # Attributes normally filled in by build_BpAnatomy_loader
        in_channels=3,
        out_channels=9,  # 8 classes + background
        num_classes=8,  # 8 anatomical structures
        seed=0,
        num_workers=min(8, torch.cuda.device_count() * 4),  # optimized for H200
        pin_mem=True,
        sliding_window=False,  # from demo_segmentation.py
        save_fig=False,
        tolerance=5,
        spacing=(1.0, 0.5, 0.5),
        weight_decay=1e-5,
        lr=0.1,
        min_lr=0.0,
        warmup_epochs=40,
        epochs=400,
        train='scratch',
        pretrain=None,
        root=str(cache_dir),  # data root — updated when needed
        output_dir='./outputanat',
        log_dir='./outputanat',
        file_name='anatomy_segmentation_output',
        resume='',
        start_epoch=0,
        data20=False,
        data_num=0,
        prompt=False,
        world_size=1,
        local_rank=-1,
        dist_on_itp=False,
        dist_url='env://',
    )
def load_classification_model():
    """Load the ProFound classification model with H200 GPU optimizations.

    Downloads the checkpoint from the Hugging Face Hub, builds the
    ConvNeXtV2-tiny backbone + Classifier head, loads weights on CPU
    (required for Spaces stateless GPU), then moves the model to CUDA
    and optionally torch.compiles and warms it up.

    Returns:
        tuple: (model, args) — an eval-mode model and the inference
        ``argparse.Namespace``.

    Raises:
        Exception: re-raised after logging if download/build/load fails.
    """
    try:
        # Download model weights (cached locally after the first call)
        model_path = hf_hub_download(
            repo_id=MODEL_REPO,
            filename="checkpoint/classification.pth.tar",
            cache_dir=cache_dir
        )
        # Create model exactly like the demo script
        args = create_args_for_classification()
        convnext = convnextv2_tiny(in_chans=3)
        model = Classifier(convnext, args.num_classes)
        # Load weights - use CPU first for compatibility with Spaces
        checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
        model.load_state_dict(checkpoint["model"])
        model.eval()  # set eval mode before moving to GPU
        if torch.cuda.is_available():
            current_device = torch.device("cuda")
            model = model.to(current_device, dtype=dtype)
            logger.info("🔥 Model using full precision (FP32)")
            # Compile model for performance - default mode starts up much
            # faster than max-autotune
            if ENABLE_TORCH_COMPILE:
                try:
                    model = torch.compile(model, mode="default", fullgraph=False)
                    logger.info("⚡ Model compiled with torch.compile (default mode)")
                except Exception as e:
                    logger.warning(f"Could not compile model: {e}")
                    logger.info("Continuing without torch.compile - model will still work but may be slower")
            else:
                logger.info("Torch.compile disabled - model will work but may be slower")
            # Warm up model with a dummy input so the first real inference is fast
            if ENABLE_MODEL_WARMUP:
                try:
                    logger.info("Starting model warm-up...")
                    # BUG FIX: the dummy input must live on the same device as
                    # the model. The previous code used the module-level
                    # `device`, which is still None here, producing a CPU
                    # tensor against a CUDA model — every warm-up silently
                    # failed inside the except below.
                    dummy_input = torch.randn(1, 3, 64, 224, 224, device=current_device, dtype=dtype)
                    with torch.no_grad():
                        # Guard against a hanging first forward pass.
                        # NOTE(review): SIGALRM only works on the main thread
                        # of a Unix process — fine for Spaces, but confirm if
                        # loading ever moves to a worker thread.
                        import signal
                        def timeout_handler(signum, frame):
                            raise TimeoutError("Model warm-up timed out")
                        # Set a 30-second timeout for warm-up
                        signal.signal(signal.SIGALRM, timeout_handler)
                        signal.alarm(30)
                        try:
                            with torch.amp.autocast('cuda', enabled=True):
                                _ = model(dummy_input)
                            logger.info("Classification model warm-up completed successfully")
                        finally:
                            signal.alarm(0)  # always cancel the alarm
                except (TimeoutError, Exception) as e:
                    logger.warning(f"Classification model warm-up failed or timed out: {e}")
                    logger.info("Skipping warm-up - model will still work for inference")
            else:
                logger.info("Model warm-up disabled - skipping warm-up")
            logger.info(f"✅ Classification model loaded and optimized for H200")
        else:
            logger.info("✅ Classification model loaded (CPU mode)")
        return model, args
    except Exception as e:
        logger.error(f"Error loading classification model: {e}")
        raise
def load_segmentation_model():
    """Load the ProFound lesion-segmentation model with H200 GPU optimizations.

    Downloads the checkpoint from the Hugging Face Hub, builds the
    ConvNeXtV2-tiny + UperNet decoder, loads weights on CPU (required for
    Spaces stateless GPU), then moves the model to CUDA and optionally
    torch.compiles and warms it up.

    Returns:
        tuple: (model, args) — an eval-mode model and the inference
        ``argparse.Namespace``.

    Raises:
        Exception: re-raised after logging if download/build/load fails.
    """
    try:
        # Download model weights (cached locally after the first call)
        model_path = hf_hub_download(
            repo_id=MODEL_REPO,
            filename="checkpoint/segmentation.pth.tar",
            cache_dir=cache_dir
        )
        # Create model exactly like the demo script
        args = create_args_for_segmentation()
        convnext = convnextv2_tiny(in_chans=3)
        model = UperNet(
            encoder=convnext,
            in_channels=[96, 192, 384, 768],
            out_channels=args.out_channels,
        )
        # Load weights - use CPU first for compatibility with Spaces
        checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
        model.load_state_dict(checkpoint["model"])
        model.eval()  # set eval mode before moving to GPU
        if torch.cuda.is_available():
            current_device = torch.device("cuda")
            model = model.to(current_device, dtype=dtype)
            logger.info("🔥 Segmentation model using full precision (FP32)")
            # Compile model for performance - default mode starts up much
            # faster than max-autotune
            if ENABLE_TORCH_COMPILE:
                try:
                    model = torch.compile(model, mode="default", fullgraph=False)
                    logger.info("⚡ Segmentation model compiled with torch.compile (default mode)")
                except Exception as e:
                    logger.warning(f"Could not compile segmentation model: {e}")
                    logger.info("Continuing without torch.compile - model will still work but may be slower")
            else:
                logger.info("Torch.compile disabled for segmentation - model will work but may be slower")
            # Warm up model with a dummy input so the first real inference is fast
            if ENABLE_MODEL_WARMUP:
                try:
                    logger.info("Starting segmentation model warm-up...")
                    # BUG FIX: the dummy input must live on the same device as
                    # the model. The previous code used the module-level
                    # `device`, which is still None here, producing a CPU
                    # tensor against a CUDA model — every warm-up silently
                    # failed inside the except below.
                    dummy_input = torch.randn(1, 3, 64, 224, 224, device=current_device, dtype=dtype)
                    with torch.no_grad():
                        # Guard against a hanging first forward pass.
                        # NOTE(review): SIGALRM only works on the main thread
                        # of a Unix process — fine for Spaces, but confirm if
                        # loading ever moves to a worker thread.
                        import signal
                        def timeout_handler(signum, frame):
                            raise TimeoutError("Model warm-up timed out")
                        # Set a 30-second timeout for warm-up
                        signal.signal(signal.SIGALRM, timeout_handler)
                        signal.alarm(30)
                        try:
                            with torch.amp.autocast('cuda', enabled=True):
                                _ = model(dummy_input)
                            logger.info("Segmentation model warm-up completed successfully")
                        finally:
                            signal.alarm(0)  # always cancel the alarm
                except (TimeoutError, Exception) as e:
                    logger.warning(f"Segmentation model warm-up failed or timed out: {e}")
                    logger.info("Skipping warm-up - model will still work for inference")
            else:
                logger.info("Segmentation model warm-up disabled - skipping warm-up")
            logger.info(f"✅ Segmentation model loaded and optimized for H200")
        else:
            logger.info("✅ Segmentation model loaded (CPU mode)")
        return model, args
    except Exception as e:
        logger.error(f"Error loading segmentation model: {e}")
        raise
def load_anatomy_segmentation_model():
    """Load the anatomy segmentation model (UperNet over ConvNeXtV2-tiny).

    Downloads the checkpoint from the model hub, restores the weights on CPU,
    switches to eval mode, then — only when CUDA is available — moves the
    model to the GPU, optionally compiles it with ``torch.compile`` and runs a
    time-boxed warm-up forward pass.

    Returns:
        tuple: ``(model, args)`` where ``model`` is the eval-mode ``UperNet``
        and ``args`` is the namespace used to construct it.

    Raises:
        Exception: re-raised after logging if any step fails.
    """
    try:
        # Download model weights (cached locally by huggingface_hub).
        model_path = hf_hub_download(
            repo_id=MODEL_REPO,
            filename="checkpoint/anatomy_segmentation.pth.tar",
            cache_dir=cache_dir
        )
        # Create model exactly like the reference demo.
        args = create_args_for_anatomy_segmentation()
        convnext = convnextv2_tiny(in_chans=3)
        model = UperNet(
            encoder=convnext,
            in_channels=[96, 192, 384, 768],
            out_channels=args.out_channels,  # 9 channels for 8 classes + background
        )
        # Load weights on CPU first for compatibility with Spaces stateless GPU.
        checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
        model.load_state_dict(checkpoint["model"])
        # Set to eval mode before any device transfer.
        model.eval()
        # Only move to GPU and apply optimizations if CUDA is available.
        if torch.cuda.is_available():
            current_device = torch.device("cuda")
            model = model.to(current_device, dtype=dtype)
            logger.info("🔥 Anatomy segmentation model using full precision (FP32)")
            # Compile for performance; "default" mode compiles much faster
            # than "max-autotune" and is good enough for a demo.
            if ENABLE_TORCH_COMPILE:
                try:
                    model = torch.compile(model, mode="default", fullgraph=False)
                    logger.info("⚡ Anatomy segmentation model compiled with torch.compile (default mode)")
                except Exception as e:
                    logger.warning(f"Could not compile anatomy segmentation model: {e}")
                    logger.info("Continuing without torch.compile - model will still work but may be slower")
            else:
                logger.info("Torch.compile disabled for anatomy segmentation - model will work but may be slower")
            # Warm up with a dummy input so the first real inference does not
            # pay the compile / cudnn-autotune cost.
            if ENABLE_MODEL_WARMUP:
                try:
                    logger.info("Starting anatomy segmentation model warm-up...")
                    # BUGFIX: allocate the dummy input on the model's device.
                    # The module-level ``device`` global is still None here,
                    # which previously put the tensor on CPU while the model
                    # sat on CUDA, so warm-up failed on every start-up.
                    dummy_input = torch.randn(1, 3, 64, 224, 224, device=current_device, dtype=dtype)
                    with torch.no_grad():
                        # NOTE(review): SIGALRM timeouts only work on the main
                        # thread of POSIX systems; any failure is swallowed by
                        # the except handler below either way.
                        import signal

                        def timeout_handler(signum, frame):
                            raise TimeoutError("Model warm-up timed out")

                        # 30-second budget for the warm-up pass.
                        signal.signal(signal.SIGALRM, timeout_handler)
                        signal.alarm(30)
                        try:
                            with torch.amp.autocast('cuda', enabled=True):
                                _ = model(dummy_input)
                            logger.info("Anatomy segmentation model warm-up completed successfully")
                        finally:
                            signal.alarm(0)  # Cancel the alarm
                except Exception as e:  # TimeoutError is already an Exception
                    logger.warning(f"Anatomy segmentation model warm-up failed or timed out: {e}")
                    logger.info("Skipping warm-up - model will still work for inference")
            else:
                logger.info("Anatomy segmentation model warm-up disabled - skipping warm-up")
            logger.info("✅ Anatomy segmentation model loaded and optimized for H200")
        else:
            logger.info("✅ Anatomy segmentation model loaded (CPU mode)")
        return model, args
    except Exception as e:
        logger.error(f"Error loading anatomy segmentation model: {e}")
        raise
def load_sample_data():
    """Fetch the demo dataset snapshot from the Hub and index its test CSVs.

    Returns:
        tuple: ``(sample_patients, dataset_path)`` where ``sample_patients``
        maps task keys ('risk', 'UCL', 'anatomy') to DataFrames and
        ``dataset_path`` is the snapshot root; ``({}, None)`` on failure.
    """
    try:
        # Download (or reuse the cached) dataset snapshot.
        snapshot_root = Path(snapshot_download(
            repo_id=DATASET_REPO,
            repo_type="dataset",
            cache_dir=cache_dir
        ))
        sample_patients = {}
        # (dict key, csv sub-directory, human-readable label) per task.
        csv_specs = (
            ('risk', 'risk', 'risk assessment'),
            ('UCL', 'UCL', 'UCL segmentation'),
            ('anatomy', 'anatomy', 'anatomy segmentation'),
        )
        for key, subdir, label in csv_specs:
            csv_path = snapshot_root / "demo" / "data" / subdir / "test.csv"
            if csv_path.exists():
                frame = pd.read_csv(csv_path)
                sample_patients[key] = frame
                logger.info(f"Found {len(frame)} {label} samples")
        return sample_patients, snapshot_root
    except Exception as e:
        logger.error(f"Error loading sample data: {e}")
        return {}, None
def _build_dataset_from_df(df, dataset_cls, args, get_transforms):
    """Write *df* to a temporary CSV and build the test-split dataset from it."""
    # Only the test transforms are needed for inference.
    train_transforms, val_transforms, test_transforms = get_transforms(args)
    with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as f:
        df.to_csv(f.name, index=False)
        temp_csv_path = f.name
    try:
        return dataset_cls(args, temp_csv_path, 'test', test_transforms)
    finally:
        # Remove the temp CSV even if dataset construction fails.
        os.unlink(temp_csv_path)


def create_single_sample_dataloader(patient_idx: int, task: str):
    """Create a DataLoader for a single patient, matching the demo approach.

    The whole test CSV for the task is materialized as a temp file, the
    task-specific dataset is built from it, and a DataLoader over just
    ``patient_idx`` is returned.

    Args:
        patient_idx: index of the patient in the task's test CSV.
        task: 'classification', 'anatomy_segmentation', or anything else for
            UCL lesion segmentation.

    Returns:
        tuple: ``(single_sample_loader, args, dataset)``.

    Raises:
        ValueError: if the task's CSV was not found in the sample data.
        IndexError: if ``patient_idx`` is out of range for the dataset.
    """
    try:
        # Load sample data to get the dataset root path and per-task CSVs.
        sample_patients, dataset_root = load_sample_data()
        # Resolve the task-specific pieces: argument namespace, dataset class,
        # transform factory, and the key into ``sample_patients``.
        if task == 'classification':
            args = create_args_for_classification()
            from dataset.dataset_cls import RiskSet as dataset_cls, get_transforms
            csv_key, missing_msg = 'risk', "No risk assessment data found"
        elif task == 'anatomy_segmentation':
            args = create_args_for_anatomy_segmentation()
            from dataset.dataset_seg import BpAnatomySet as dataset_cls, get_transforms
            csv_key, missing_msg = 'anatomy', "No anatomy segmentation data found"
        else:
            args = create_args_for_segmentation()
            from dataset.dataset_seg import UCLSet as dataset_cls, get_transforms
            csv_key, missing_msg = 'UCL', "No UCL segmentation data found"
        args.root = str(dataset_root)  # Use the actual dataset root path
        if csv_key not in sample_patients:
            raise ValueError(missing_msg)
        # ``tempfile`` is already imported at module level; the previous
        # per-branch ``import tempfile`` statements were redundant.
        dataset = _build_dataset_from_df(
            sample_patients[csv_key], dataset_cls, args, get_transforms)
        # Check patient index.
        if patient_idx >= len(dataset):
            raise IndexError(f"Patient index {patient_idx} out of range for dataset size {len(dataset)}")
        # Single-sample DataLoader; settings mirror the original H200 tuning.
        from torch.utils.data import DataLoader, Subset
        single_sample_dataset = Subset(dataset, [patient_idx])
        single_sample_loader = DataLoader(
            single_sample_dataset,
            batch_size=1,
            shuffle=False,
            pin_memory=torch.cuda.is_available(),  # pinned memory only with CUDA
            num_workers=0,  # no multiprocessing for a single sample
            drop_last=False,
            persistent_workers=False,  # meaningless with num_workers=0
            prefetch_factor=None,  # must be None when num_workers=0
        )
        return single_sample_loader, args, dataset
    except Exception as e:
        logger.error(f"Error creating single sample dataloader: {e}")
        raise
# Removed get_raw_patient_data function as it's no longer needed
# All data is now processed through the unified preprocessing pipeline
def visualize_multimodal_results(preprocessed_data: Dict, prediction: torch.Tensor, task: str) -> plt.Figure:
    """Visualize the three modalities plus the model output for one patient.

    Displays the middle axial slice of the preprocessed T2W/DWI/ADC volumes.
    For ``task == 'classification'`` a PI-RADS probability bar chart is
    appended; otherwise segmentation probability/overlay panels are shown,
    with a ground-truth panel when one is present in ``preprocessed_data``.

    Returns:
        plt.Figure: the composed figure, or a one-axes error figure on failure.
    """
    try:
        # Keep tensors on GPU as long as possible; only the displayed middle
        # slice is copied to the CPU.
        t2w_gpu = preprocessed_data['t2w_preprocessed']
        dwi_gpu = preprocessed_data['dwi_preprocessed']
        adc_gpu = preprocessed_data['adc_preprocessed']
        middle_slice = t2w_gpu.shape[0] // 2
        with torch.no_grad():
            t2w_slice = t2w_gpu[middle_slice].cpu().numpy()
            dwi_slice = dwi_gpu[middle_slice].cpu().numpy()
            adc_slice = adc_gpu[middle_slice].cpu().numpy()
        if task == 'classification':
            # 3 modalities + class-probability bar chart.
            fig, axes = plt.subplots(1, 4, figsize=(20, 5))
            axes[0].imshow(t2w_slice, cmap='gray')
            axes[0].set_title('T2W', fontsize=14, fontweight='bold')
            axes[0].axis('off')
            axes[1].imshow(dwi_slice, cmap='gray')
            axes[1].set_title('DWI (High-b)', fontsize=14, fontweight='bold')
            axes[1].axis('off')
            axes[2].imshow(adc_slice, cmap='gray')
            axes[2].set_title('ADC', fontsize=14, fontweight='bold')
            axes[2].axis('off')
            if prediction is not None:
                class_names = ["PI-RADS 2", "PI-RADS 3", "PI-RADS 4", "PI-RADS 5"]
                # Softmax on GPU, single transfer to CPU.
                with torch.no_grad():
                    probs = torch.softmax(prediction, dim=-1)[0].cpu().numpy()
                axes[3].bar(class_names, probs, color=['green', 'yellow', 'orange', 'red'])
                axes[3].set_title('Classification Results', fontsize=14, fontweight='bold')
                axes[3].set_ylabel('Probability')
                axes[3].set_ylim(0, 1)
                axes[3].tick_params(axis='x', rotation=45)
                # Add value labels on bars.
                for i, prob in enumerate(probs):
                    axes[3].text(i, prob + 0.01, f'{prob:.3f}', ha='center', va='bottom')
        else:
            # Segmentation: 3 modalities + probability map + overlay.
            fig, axes = plt.subplots(1, 5, figsize=(25, 5))
            axes[0].imshow(t2w_slice, cmap='gray')
            axes[0].set_title('T2W', fontsize=14, fontweight='bold')
            axes[0].axis('off')
            axes[1].imshow(dwi_slice, cmap='gray')
            axes[1].set_title('DWI', fontsize=14, fontweight='bold')
            axes[1].axis('off')
            axes[2].imshow(adc_slice, cmap='gray')
            axes[2].set_title('ADC', fontsize=14, fontweight='bold')
            axes[2].axis('off')
            if prediction is not None:
                # Sigmoid on GPU, then pull only the displayed slice to CPU.
                with torch.no_grad():
                    seg_prob_gpu = torch.sigmoid(prediction)[0, 0, middle_slice]
                    seg_slice = seg_prob_gpu.cpu().numpy()
                seg_binary = (seg_slice > 0.5).astype(int)
                axes[3].imshow(seg_slice, cmap='jet')
                axes[3].set_title('Segmentation (Probability)', fontsize=14, fontweight='bold')
                axes[3].axis('off')
                # Overlay on T2W.
                axes[4].imshow(t2w_slice, cmap='gray')
                axes[4].imshow(seg_binary, cmap='jet', alpha=0.5)
                axes[4].set_title('Overlay on T2W', fontsize=14, fontweight='bold')
                axes[4].axis('off')
            # Ground truth panel(s), when a preprocessed GT volume exists.
            if 'ground_truth_preprocessed' in preprocessed_data and preprocessed_data['ground_truth_preprocessed'] is not None:
                with torch.no_grad():
                    gt_slice = preprocessed_data['ground_truth_preprocessed'][middle_slice].cpu().numpy()
                if prediction is not None:
                    # BUGFIX: close the 5-panel figure before building the
                    # 6-panel replacement — previously the old figure was
                    # leaked on every call in this long-running server.
                    plt.close(fig)
                    fig, axes = plt.subplots(1, 6, figsize=(30, 5))
                    # Re-draw modalities.
                    axes[0].imshow(t2w_slice, cmap='gray')
                    axes[0].set_title('T2W')
                    axes[0].axis('off')
                    axes[1].imshow(dwi_slice, cmap='gray')
                    axes[1].set_title('DWI')
                    axes[1].axis('off')
                    axes[2].imshow(adc_slice, cmap='gray')
                    axes[2].set_title('ADC')
                    axes[2].axis('off')
                    # Ground truth.
                    axes[3].imshow(gt_slice, cmap='jet')
                    axes[3].set_title('Ground Truth', fontsize=14, fontweight='bold')
                    axes[3].axis('off')
                    # Prediction probability and overlay (recomputed for the
                    # fresh axes; same tensors as above).
                    with torch.no_grad():
                        seg_prob_gpu = torch.sigmoid(prediction)[0, 0, middle_slice]
                        seg_slice = seg_prob_gpu.cpu().numpy()
                    seg_binary = (seg_slice > 0.5).astype(int)
                    axes[4].imshow(seg_slice, cmap='jet')
                    axes[4].set_title('Segmentation (Probability)', fontsize=14, fontweight='bold')
                    axes[4].axis('off')
                    axes[5].imshow(t2w_slice, cmap='gray')
                    axes[5].imshow(seg_binary, cmap='jet', alpha=0.5)
                    axes[5].set_title('Overlay on T2W', fontsize=14, fontweight='bold')
                    axes[5].axis('off')
                else:
                    # No prediction: the GT takes the probability-map slot.
                    axes[3].imshow(gt_slice, cmap='jet')
                    axes[3].set_title('Ground Truth', fontsize=14, fontweight='bold')
                    axes[3].axis('off')
        # Padding so the bold titles are not clipped.
        plt.tight_layout(pad=3.0)
        plt.subplots_adjust(top=0.85)  # Leave space for titles at the top
        return fig
    except Exception as e:
        logger.error(f"Error creating visualization: {e}")
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.text(0.5, 0.5, f"Visualization error: {str(e)}", ha='center', va='center',
                transform=ax.transAxes, fontsize=12)
        ax.axis('off')
        return fig
# Module-level UI state shared across Gradio callbacks (slice browser and
# performance caching). These are rebound by the inference/result handlers.
current_preprocessed_data = None  # dict of preprocessed tensors for the currently loaded patient
current_prediction = None  # latest model prediction tensor for that patient
cached_slice_data = None  # Cache pre-processed slice data for smooth navigation
def create_2d_slice_browser_optimized(preprocessed_data: Dict, prediction: torch.Tensor, slice_idx: int = None) -> plt.Figure:
    """Render a 6-panel 2D slice view for one axial slice.

    Top row: raw T2W / ADC / DWI. Bottom row: T2W with ground-truth overlay,
    T2W with binary-prediction overlay, and T2W with a probability heatmap.
    Every panel shares the same extent and aspect so the images stay aligned.

    Returns:
        plt.Figure: the 2x3 figure, or a one-axes error figure on failure.
    """
    try:
        t2w_vol = preprocessed_data['t2w_preprocessed']
        dwi_vol = preprocessed_data['dwi_preprocessed']
        adc_vol = preprocessed_data['adc_preprocessed']
        # Default to the middle slice, then clamp into the valid range.
        if slice_idx is None:
            slice_idx = t2w_vol.shape[0] // 2
        slice_idx = max(0, min(slice_idx, t2w_vol.shape[0] - 1))
        # Pull only the requested slice back to the CPU.
        with torch.no_grad():
            t2w_img = t2w_vol[slice_idx].cpu().numpy()
            dwi_img = dwi_vol[slice_idx].cpu().numpy()
            adc_img = adc_vol[slice_idx].cpu().numpy()
        # Shared extent keeps every panel at an identical scale.
        height, width = t2w_img.shape
        extent = [0, width, height, 0]  # [left, right, bottom, top]
        # Model prediction for this slice: probability map + binary mask.
        prob_img = None
        mask_img = None
        if prediction is not None:
            with torch.no_grad():
                prob_img = torch.sigmoid(prediction)[0, 0, slice_idx].cpu().numpy()
            mask_img = (prob_img > 0.5).astype(np.uint8)
        # Ground-truth mask for this slice, when available.
        gt_img = None
        gt_vol = preprocessed_data.get('ground_truth_preprocessed')
        if gt_vol is not None:
            with torch.no_grad():
                gt_img = gt_vol[slice_idx].cpu().numpy().astype(np.uint8)
        fig, axes = plt.subplots(2, 3, figsize=(15, 8))
        axes = axes.flatten()
        # Identical imshow parameters for every grayscale base image.
        base_kwargs = dict(cmap='gray', aspect='equal', extent=extent,
                           interpolation='nearest')
        # Panels 0-2: the three raw modalities.
        modality_panels = (
            (t2w_img, f'T2W (Slice {slice_idx})'),
            (adc_img, f'ADC (Slice {slice_idx})'),
            (dwi_img, f'DWI (Slice {slice_idx})'),
        )
        for ax, (img, title) in zip(axes[:3], modality_panels):
            ax.imshow(img, **base_kwargs)
            ax.set_title(title, fontsize=14, fontweight='bold')
            ax.axis('off')
        # Panel 3: T2W + ground truth (red mask).
        axes[3].imshow(t2w_img, **base_kwargs)
        if gt_img is not None:
            gt_rgba = np.zeros((*gt_img.shape, 4))
            gt_rgba[gt_img > 0] = [1, 0, 0, 0.6]  # red, 60% opacity
            axes[3].imshow(gt_rgba, extent=extent, aspect='equal', interpolation='nearest')
        axes[3].set_title('T2W + Ground Truth', fontsize=14, fontweight='bold')
        axes[3].axis('off')
        # Panel 4: T2W + binary prediction (green mask).
        axes[4].imshow(t2w_img, **base_kwargs)
        if mask_img is not None:
            pred_rgba = np.zeros((*mask_img.shape, 4))
            pred_rgba[mask_img > 0] = [0, 1, 0, 0.6]  # green, 60% opacity
            axes[4].imshow(pred_rgba, extent=extent, aspect='equal', interpolation='nearest')
        axes[4].set_title('T2W + Prediction', fontsize=14, fontweight='bold')
        axes[4].axis('off')
        # Panel 5: T2W + jet probability heatmap (opacity scales with probability).
        axes[5].imshow(t2w_img, **base_kwargs)
        if prob_img is not None:
            heat_rgba = plt.cm.jet(prob_img)
            heat_rgba[..., 3] = prob_img * 0.7
            axes[5].imshow(heat_rgba, extent=extent, aspect='equal', interpolation='nearest')
        axes[5].set_title('T2W + Probability Heatmap', fontsize=14, fontweight='bold')
        axes[5].axis('off')
        # Extra padding so the bold titles are not clipped.
        plt.tight_layout(pad=3.0)
        plt.subplots_adjust(top=0.87)
        return fig
    except Exception as e:
        logger.error(f"Error creating 2D slice browser: {e}")
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.text(0.5, 0.5, f"Slice browser error: {str(e)}", ha='center', va='center',
                transform=ax.transAxes, fontsize=12)
        ax.axis('off')
        return fig
def precompute_slice_cache(preprocessed_data: Dict, prediction: torch.Tensor):
    """Pre-compute per-slice numpy arrays for smooth slider navigation.

    Each volume is moved to the CPU in one batched transfer, then split into
    per-slice 2-D views. The result is stored in the module-global
    ``cached_slice_data`` and also returned.

    Returns:
        dict | None: the cache dict, or None if pre-computation failed.
    """
    global cached_slice_data
    try:
        logger.info("🚀 Pre-computing slice cache with H200 GPU optimization...")
        t2w_gpu = preprocessed_data['t2w_preprocessed']
        dwi_gpu = preprocessed_data['dwi_preprocessed']
        adc_gpu = preprocessed_data['adc_preprocessed']
        num_slices = t2w_gpu.shape[0]
        with torch.no_grad():
            # One batched device->host copy per volume is far cheaper than a
            # per-slice transfer loop.
            t2w_cpu = t2w_gpu.cpu().numpy()
            dwi_cpu = dwi_gpu.cpu().numpy()
            adc_cpu = adc_gpu.cpu().numpy()
            # Probability and binary masks computed for the whole volume.
            if prediction is not None:
                pred_prob_all = torch.sigmoid(prediction)[0, 0].cpu().numpy()
                pred_binary_all = (pred_prob_all > 0.5).astype(np.uint8)
            else:
                pred_prob_all = None
                pred_binary_all = None
            # Ground truth in one batch, if present.
            gt_vol = preprocessed_data.get('ground_truth_preprocessed')
            gt_all = gt_vol.cpu().numpy().astype(np.uint8) if gt_vol is not None else None
        # IDIOM: ``list(volume)`` yields one 2-D view per slice — equivalent
        # to the former per-index append loop, without the Python-level
        # iteration overhead.
        cache = {
            't2w_slices': list(t2w_cpu),
            'dwi_slices': list(dwi_cpu),
            'adc_slices': list(adc_cpu),
            'pred_prob_slices': list(pred_prob_all) if pred_prob_all is not None else [None] * num_slices,
            'pred_binary_slices': list(pred_binary_all) if pred_binary_all is not None else [None] * num_slices,
            'gt_slices': list(gt_all) if gt_all is not None else [None] * num_slices,
            'extent': [0, t2w_gpu.shape[2], t2w_gpu.shape[1], 0],  # [left, right, bottom, top]
            'num_slices': num_slices,  # kept for downstream validation
        }
        cached_slice_data = cache
        logger.info(f"Successfully cached {num_slices} slices for smooth navigation")
        # Return cache for additional state management.
        return cache
    except Exception as e:
        logger.error(f"Error pre-computing slice cache: {e}")
        cached_slice_data = None
        return None
def create_2d_slice_browser(raw_data: Dict, prediction: torch.Tensor, slice_idx: int = None) -> plt.Figure:
    """Backward-compatible entry point kept for old callers.

    Simply delegates to :func:`create_2d_slice_browser_optimized`.
    """
    browser_fig = create_2d_slice_browser_optimized(raw_data, prediction, slice_idx)
    return browser_fig
def create_3d_volume_rendering(preprocessed_data: Dict, prediction: torch.Tensor) -> go.Figure:
    """Create an interactive 3D rendering of prostate, predicted lesions, and
    (when available) ground-truth lesions.

    Volumes are downsampled 2x per axis before marching cubes to keep the
    Plotly figure responsive.

    Returns:
        go.Figure: the 3D figure, or an annotated empty figure on error.
    """
    try:
        with torch.no_grad():
            # Reference T2W volume (already preprocessed like the model input).
            t2w_gpu = preprocessed_data['t2w_preprocessed']
            # Binary lesion mask from the prediction; all-zero when absent.
            pred_binary_gpu = torch.zeros_like(t2w_gpu)
            if prediction is not None:
                pred_prob_gpu = torch.sigmoid(prediction)[0, 0]
                pred_binary_gpu = (pred_prob_gpu > 0.5).to(torch.uint8)
            # Single batched device->host transfer.
            t2w = t2w_gpu.cpu().numpy()
            pred_binary = pred_binary_gpu.cpu().numpy()
        # Use the preprocessed prostate mask directly (same preprocessing
        # pipeline as the lesion mask, so geometries line up).
        prostate_mask = None
        if 'prostate_mask_preprocessed' in preprocessed_data and preprocessed_data['prostate_mask_preprocessed'] is not None:
            with torch.no_grad():
                prostate_mask = preprocessed_data['prostate_mask_preprocessed'].cpu().numpy().astype(np.uint8)
            non_zero_voxels = np.count_nonzero(prostate_mask)
            logger.info(f"Using preprocessed prostate mask: shape={prostate_mask.shape}, non-zero voxels={non_zero_voxels}")
        else:
            logger.info("No preprocessed prostate mask available - will only show lesions")
        # Downsample for performance (every 2nd voxel in each axis).
        t2w_ds = t2w[::2, ::2, ::2]
        pred_binary_ds = pred_binary[::2, ::2, ::2]
        prostate_mask_ds = prostate_mask[::2, ::2, ::2] if prostate_mask is not None else None
        logger.info(f"Downsampled data shapes: T2W={t2w_ds.shape}, prediction={pred_binary_ds.shape}")
        if prostate_mask_ds is not None:
            logger.info(f"Prostate mask shape: {prostate_mask_ds.shape}, non-zero voxels: {np.count_nonzero(prostate_mask_ds)}")
        logger.info(f"Lesion non-zero voxels: {np.count_nonzero(pred_binary_ds)}")
        fig = go.Figure()
        # 1. Prostate surface (semi-transparent) - only if a mask is available.
        if prostate_mask_ds is not None and np.count_nonzero(prostate_mask_ds) > 0:
            try:
                logger.info("Creating prostate 3D mesh with marching cubes...")
                verts, faces, _, _ = measure.marching_cubes(prostate_mask_ds, level=0.5)
                fig.add_trace(go.Mesh3d(
                    x=verts[:, 0],
                    y=verts[:, 1],
                    z=verts[:, 2],
                    i=faces[:, 0],
                    j=faces[:, 1],
                    k=faces[:, 2],
                    name='Prostate Gland',
                    color='lightblue',
                    opacity=0.3,
                    lighting=dict(ambient=0.5),
                    showlegend=True,
                    # BUGFIX: the line break in a Plotly hovertemplate must be
                    # the HTML tag "<br>"; the literal newline previously
                    # embedded in the string broke the source.
                    hovertemplate="Prostate Gland<br>Vertex: (%{x}, %{y}, %{z})"
                ))
                logger.info("✅ Successfully created prostate 3D mesh")
            except Exception as e:
                logger.warning(f"Could not create prostate mesh: {e}")
        # 2. Lesion surface (opaque, red).
        if np.count_nonzero(pred_binary_ds) > 0:
            try:
                logger.info("Creating lesion 3D mesh...")
                verts_lesion, faces_lesion, _, _ = measure.marching_cubes(pred_binary_ds, level=0.5)
                fig.add_trace(go.Mesh3d(
                    x=verts_lesion[:, 0],
                    y=verts_lesion[:, 1],
                    z=verts_lesion[:, 2],
                    i=faces_lesion[:, 0],
                    j=faces_lesion[:, 1],
                    k=faces_lesion[:, 2],
                    name='Predicted Lesion',
                    color='red',
                    opacity=0.8,
                    lighting=dict(ambient=0.7),
                    showlegend=True,
                    hovertemplate="Predicted Lesion<br>Vertex: (%{x}, %{y}, %{z})"
                ))
                logger.info("✅ Successfully created lesion 3D mesh")
            except Exception as e:
                logger.warning(f"Could not create lesion mesh: {e}")
        else:
            logger.info("No lesion predictions found")
        # 3. Ground truth lesions (yellow), if available.
        if 'ground_truth_preprocessed' in preprocessed_data and preprocessed_data['ground_truth_preprocessed'] is not None:
            with torch.no_grad():
                gt = preprocessed_data['ground_truth_preprocessed'].cpu().numpy().astype(np.uint8)
            gt_ds = gt[::2, ::2, ::2]
            if np.count_nonzero(gt_ds) > 0:
                try:
                    logger.info("Creating ground truth lesion 3D mesh...")
                    verts_gt, faces_gt, _, _ = measure.marching_cubes(gt_ds, level=0.5)
                    fig.add_trace(go.Mesh3d(
                        x=verts_gt[:, 0],
                        y=verts_gt[:, 1],
                        z=verts_gt[:, 2],
                        i=faces_gt[:, 0],
                        j=faces_gt[:, 1],
                        k=faces_gt[:, 2],
                        name='Ground Truth Lesion',
                        color='yellow',
                        opacity=0.7,
                        lighting=dict(ambient=0.7),
                        showlegend=True,
                        hovertemplate="Ground Truth Lesion<br>Vertex: (%{x}, %{y}, %{z})"
                    ))
                    logger.info("✅ Successfully created ground truth 3D mesh")
                except Exception as e:
                    logger.warning(f"Could not create ground truth mesh: {e}")
        # Layout tuned to sit next to the 2D slice browser.
        fig.update_layout(
            title=dict(
                text="3D Prostate and Lesion Rendering",
                x=0.5,
                font=dict(size=18)
            ),
            scene=dict(
                xaxis_title="X",
                yaxis_title="Y",
                zaxis_title="Z",
                camera=dict(
                    eye=dict(x=1.2, y=1.2, z=1.2)
                ),
                aspectmode='cube'
            ),
            width=700,
            height=400,
            margin=dict(l=0, r=0, b=0, t=40)
        )
        logger.info(f"3D rendering completed with {len(fig.data)} traces")
        return fig
    except Exception as e:
        logger.error(f"Error creating 3D rendering: {e}")
        # Return an empty figure carrying the error message.
        fig = go.Figure()
        fig.add_annotation(
            text=f"3D rendering error: {str(e)}",
            xref="paper", yref="paper",
            x=0.5, y=0.5,
            showarrow=False,
            font=dict(size=16)
        )
        return fig
def run_classification_inference(patient_id: str, progress=gr.Progress()):
    """Load precomputed classification results for *patient_id* from cache.

    Args:
        patient_id: key into ``PRECOMPUTED_RESULTS['classification']``.
        progress: Gradio progress callback.

    Returns:
        tuple: ``(figure, result_text)``; both degrade to an error rendering
        when the cache entry is missing or processing fails.
    """
    try:
        progress(0, desc="📂 Loading classification results from cache...")
        # Look up the precomputed entry for this patient.
        cached_result = PRECOMPUTED_RESULTS.get('classification', {}).get(patient_id)
        if cached_result is None:
            progress(1, desc="❌ Cache result not found!")
            fig, ax = plt.subplots(figsize=(8, 6))
            ax.text(0.5, 0.5, f"Classification result cache not found for patient {patient_id}", ha='center', va='center',
                    transform=ax.transAxes, fontsize=12)
            ax.axis('off')
            return fig, f"Error: Cache result not found for patient {patient_id}"
        progress(0.3, desc="🔄 Deserializing prediction results...")
        # Deserialize the cached tensors.
        prediction = serializable_to_tensor(cached_result['prediction'])
        preprocessed_data = serializable_to_tensor(cached_result['preprocessed_data'])
        progress(0.6, desc="🎨 Generating visualization...")
        fig = visualize_multimodal_results(preprocessed_data, prediction, 'classification')
        progress(0.8, desc="📊 Computing statistics...")
        # Softmax on GPU, then a single transfer to CPU.
        with torch.no_grad():
            probs = torch.softmax(prediction, dim=-1)[0].cpu().numpy()
        pred_class = np.argmax(probs)
        confidence = probs[pred_class]
        class_names = ["PI-RADS 2", "PI-RADS 3", "PI-RADS 4", "PI-RADS 5"]
        result_text = f"Patient ID: {patient_id}\n"
        result_text += f"Predicted Class: {class_names[pred_class]}\n"
        result_text += f"Confidence: {confidence:.3f}\n"
        # Ground truth (falls back to the predicted class when unavailable —
        # this is intentional demo behavior).
        if preprocessed_data.get('ground_truth_preprocessed') is not None:
            try:
                gt_class = int(preprocessed_data['ground_truth_preprocessed'])
                result_text += f"Ground Truth: {class_names[gt_class]}\n"
            except Exception:  # BUGFIX: was a bare ``except:`` which also
                # swallowed SystemExit/KeyboardInterrupt.
                # Fallback to predicted class if ground truth processing fails
                result_text += f"Ground Truth: {class_names[pred_class]}\n"
        else:
            # Use predicted class as ground truth for demo purposes
            result_text += f"Ground Truth: {class_names[pred_class]}\n"
        result_text += "\nAll Class Probabilities:\n"
        for i, prob in enumerate(probs):
            result_text += f"{class_names[i]}: {prob:.3f}\n"
        # Raw logits kept for debugging.
        with torch.no_grad():
            result_text += f"\nRaw logits: {prediction[0].cpu().numpy()}"
        progress(1, desc="✅ Completed!")
        return fig, result_text
    except Exception as e:
        logger.error(f"Error loading classification cache result: {e}")
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.text(0.5, 0.5, f"Processing error: {str(e)}", ha='center', va='center',
                transform=ax.transAxes, fontsize=12)
        ax.axis('off')
        return fig, f"Error: {str(e)}"
def get_preprocessed_patient_data(patient_idx: int, task: str, device=None):
    """
    Get preprocessed patient data with H200 GPU optimizations.

    Loads the exact tensors the model consumes (through the same DataLoader
    pipeline), which keeps displayed images spatially consistent with the
    predictions.

    Args:
        patient_idx: index of the patient in the task's test split.
        task: 'classification', 'anatomy_segmentation', or UCL segmentation
            for any other value.
        device: target torch device; defaults to CUDA when available.
            NOTE: this parameter intentionally shadows the module-level
            ``device`` global.

    Returns:
        dict: preprocessed image/GT/mask tensors plus metadata
        ('patient_id', 'spatial_shape', 'args', 'task').

    Raises:
        RuntimeError: if the dataloader yielded no samples.
    """
    try:
        logger.info(f"🔍 Loading preprocessed patient data for patient {patient_idx}, task {task}")
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # The dataloader applies the exact preprocessing the model expects.
        data_loader, args, dataset = create_single_sample_dataloader(patient_idx, task)
        # BUGFIX: initialized so an empty loader raises a clear error below
        # instead of a NameError at the final logging line.
        preprocessed_data = None
        for data in data_loader:  # (the enumerate index was unused)
            # Unpack according to the task-specific return format.
            if task == 'classification':
                img, gt, pid = data  # (img, gt, pid)
                gland = None
            elif task == 'anatomy_segmentation':
                img, gt, pid = data  # (img, gt, pid) - no gland mask
                gland = None
            else:
                # Segmentation returns (img, gt, pid, gland).
                if len(data) == 4:
                    img, gt, pid, gland = data
                    # Remove batch dimension from gland if it exists.
                    if gland is not None and len(gland) > 0:
                        gland = gland[0]
                else:
                    # Fallback for old 3-tuple format.
                    img, gt, pid = data
                    gland = None
            # Transfer to the target device (async copies when CUDA is up).
            if torch.cuda.is_available():
                img = img.to(device, dtype=dtype, non_blocking=True)
                if gt is not None:
                    gt = gt.to(device, non_blocking=True)
                if gland is not None:
                    gland = gland.to(device, non_blocking=True)
            else:
                img = img.to(device)
                if gt is not None:
                    gt = gt.to(device)
                if gland is not None:
                    gland = gland.to(device)
            # img: [B, C, D, H, W]; gt: [B, C, D, H, W] or [B, D, H, W];
            # gland (unbatched above): [C, D, H, W], [D, H, W], or None.
            preprocessed_data = {
                'preprocessed_image': img[0],  # drop batch dim -> [C, D, H, W]
                'preprocessed_gt': gt[0] if gt is not None else None,
                'patient_id': pid[0] if isinstance(pid, (list, tuple)) else pid,
                'spatial_shape': img.shape[2:],  # [D, H, W]
                'args': args,
                'task': task
            }
            # Split the input channels into individual modalities.
            if task == 'anatomy_segmentation':
                # Only T2W is meaningful for anatomy (other channels are zeros).
                preprocessed_data['t2w_preprocessed'] = img[0, 0]  # [D, H, W]
                preprocessed_data['dwi_preprocessed'] = None
                preprocessed_data['adc_preprocessed'] = None
            else:
                # Channels assumed ordered [T2W, DWI, ADC].
                if img.shape[1] >= 3:
                    preprocessed_data['t2w_preprocessed'] = img[0, 0]
                    preprocessed_data['dwi_preprocessed'] = img[0, 1]
                    preprocessed_data['adc_preprocessed'] = img[0, 2]
                else:
                    logger.warning(f"Unexpected number of input channels: {img.shape[1]}")
                    # Fallback: reuse the single channel for all modalities.
                    preprocessed_data['t2w_preprocessed'] = img[0, 0]
                    preprocessed_data['dwi_preprocessed'] = img[0, 0]
                    preprocessed_data['adc_preprocessed'] = img[0, 0]
            # Normalize ground truth to [D, H, W].
            if gt is not None:
                if len(gt.shape) == 4:  # [B, D, H, W]
                    preprocessed_data['ground_truth_preprocessed'] = gt[0]
                elif len(gt.shape) == 5:  # [B, C, D, H, W]
                    preprocessed_data['ground_truth_preprocessed'] = gt[0, 0]
                else:
                    logger.warning(f"Unexpected ground truth shape: {gt.shape}")
                    preprocessed_data['ground_truth_preprocessed'] = None
            else:
                preprocessed_data['ground_truth_preprocessed'] = None
            # Normalize the prostate gland mask to a binary [D, H, W] tensor.
            if gland is not None:
                if len(gland.shape) == 3:  # [D, H, W]
                    preprocessed_gland = gland
                elif len(gland.shape) == 4:  # [C, D, H, W]
                    preprocessed_gland = gland[0]
                else:
                    logger.warning(f"Unexpected gland mask shape: {gland.shape}")
                    preprocessed_gland = None
                if preprocessed_gland is not None:
                    # Force binary values; empty masks become None downstream.
                    preprocessed_gland = (preprocessed_gland > 0.5).float()
                    non_zero_voxels = torch.count_nonzero(preprocessed_gland)
                    logger.info(f"✅ Preprocessed prostate gland mask: shape={preprocessed_gland.shape}, non-zero voxels={non_zero_voxels}")
                    preprocessed_data['prostate_mask_preprocessed'] = preprocessed_gland if non_zero_voxels > 0 else None
                else:
                    logger.warning("Could not process gland mask")
                    preprocessed_data['prostate_mask_preprocessed'] = None
            else:
                if task != 'anatomy_segmentation':
                    logger.info("No prostate gland mask found in preprocessed data")
                preprocessed_data['prostate_mask_preprocessed'] = None
            break  # Only process the first (and only) sample
        if preprocessed_data is None:
            raise RuntimeError("Dataloader produced no samples")
        logger.info(f"Successfully loaded preprocessed data with spatial shape: {preprocessed_data['spatial_shape']}")
        return preprocessed_data
    except Exception as e:
        logger.error(f"Error loading preprocessed patient data: {e}")
        raise
# Removed manual prostate mask preprocessing functions
# These are no longer needed as the gland mask is now processed
# together with image and lesion data in the same pipeline
def run_segmentation_inference_with_state(patient_id: str, current_state, progress=gr.Progress()):
    """Load segmentation inference results from cache with state management.

    Looks up the patient's precomputed lesion-segmentation result in the
    module-level PRECOMPUTED_RESULTS cache, deserializes it, and builds all
    UI artifacts: a 2D slice-browser figure, a 3D plotly volume rendering,
    a text summary (with a Dice score when ground truth is cached), a slider
    update sized to the preprocessed volume depth, and a state dict consumed
    by the slice-browser callbacks.

    Args:
        patient_id: Key into PRECOMPUTED_RESULTS['segmentation'].
        current_state: Previous Gradio state (unused; kept for the event signature).
        progress: Gradio progress reporter.

    Returns:
        (slice_browser_fig, volume_3d_fig, result_text, slider_update, state_data).
        On any failure the figures carry an error message and state_data holds
        an 'error' key so downstream callbacks can detect the failure.
    """
    try:
        progress(0, desc="📂 Loading segmentation results from cache...")
        # Get data from precomputed results
        cached_result = PRECOMPUTED_RESULTS.get('segmentation', {}).get(patient_id)
        if cached_result is None:
            progress(1, desc="❌ Cache result not found!")
            fig, ax = plt.subplots(figsize=(8, 6))
            ax.text(0.5, 0.5, f"Segmentation result cache not found for patient {patient_id}", ha='center', va='center',
                    transform=ax.transAxes, fontsize=12)
            ax.axis('off')
            empty_3d = go.Figure()
            error_state = {'error': f"Cache result not found for patient {patient_id}", 'timestamp': time.time()}
            # Slider bounds fall back to a 64-slice default when no volume is available.
            return fig, empty_3d, f"Error: Cache result not found for patient {patient_id}", gr.update(maximum=63, value=32), error_state
        progress(0.3, desc="🔄 Deserializing prediction results...")
        # Deserialize data (tensors were stored in a serialization-friendly form)
        prediction = serializable_to_tensor(cached_result['prediction'])
        preprocessed_data = serializable_to_tensor(cached_result['preprocessed_data'])
        cached_slice_data = cached_result.get('cached_slice_data')
        # Ground truth is optional in the cache; keep None when absent.
        ground_truth_tensor = serializable_to_tensor(cached_result['ground_truth_tensor']) if cached_result.get('ground_truth_tensor') else None
        progress(0.6, desc="🎨 Generating visualization...")
        # Generate slice browser and 3D visualization, starting at the middle slice
        middle_slice = preprocessed_data['t2w_preprocessed'].shape[0] // 2
        slice_browser_fig = create_2d_slice_browser_optimized(preprocessed_data, prediction, middle_slice)
        volume_3d_fig = create_3d_volume_rendering(preprocessed_data, prediction)
        progress(0.8, desc="📊 Computing statistics...")
        # Generate result text
        with torch.no_grad():
            # Threshold sigmoid probabilities at 0.5 to obtain a binary lesion mask.
            # NOTE(review): prediction appears to be raw logits shaped [B, C, D, H, W] — confirm.
            pred_binary = (torch.sigmoid(prediction) > 0.5).int()[0, 0].cpu().numpy()
            total_voxels = pred_binary.size
            positive_voxels = np.sum(pred_binary)
            positive_ratio = positive_voxels / total_voxels
        result_text = f"Patient ID: {patient_id}\n"
        # Detailed voxel statistics intentionally disabled in the UI:
        # result_text += f"Segmentation Statistics:\n"
        # result_text += f"Total voxels: {total_voxels}\n"
        # result_text += f"Lesion voxels: {positive_voxels}\n"
        # result_text += f"Lesion ratio: {positive_ratio:.3f}\n"
        # result_text += f"Lesion volume: {positive_voxels} voxels\n"
        if ground_truth_tensor is not None:
            with torch.no_grad():
                pred_for_dice = (torch.sigmoid(prediction) > 0.5).int()
                dice = compute_dice(pred_for_dice, ground_truth_tensor)
                if not torch.isnan(dice):
                    result_text += f"\n=== Performance Metrics ===\n"
                    result_text += f"Dice Score: {dice.item():.3f}\n"
                else:
                    # compute_dice yields NaN when prediction and GT share no foreground.
                    result_text += f"\n=== Performance Metrics ===\n"
                    result_text += f"Dice Score: NaN (no overlap)\n"
        # Create state dictionary passed to the slice-browser update callbacks
        state_data = {
            'patient_id': patient_id,
            'preprocessed_data': preprocessed_data,
            'prediction': prediction,
            'cached_slice_data': cached_slice_data,
            'timestamp': time.time()
        }
        progress(1, desc="✅ Completed!")
        return (
            slice_browser_fig,
            volume_3d_fig,
            result_text,
            gr.update(maximum=preprocessed_data['t2w_preprocessed'].shape[0]-1, value=middle_slice),
            state_data
        )
    except Exception as e:
        logger.error(f"Error loading segmentation cache result: {e}")
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.text(0.5, 0.5, f"Processing error: {str(e)}", ha='center', va='center', transform=ax.transAxes, fontsize=12)
        ax.axis('off')
        empty_3d = go.Figure()
        error_state = {'error': str(e), 'timestamp': time.time()}
        return fig, empty_3d, f"Error: {str(e)}", gr.update(maximum=63, value=32), error_state
def update_slice_browser_with_state(slice_idx: int, state_data):
    """Update slice browser using data passed explicitly via the state object."""
    def _message_figure(message):
        # Blank matplotlib figure showing only a centered text message.
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.text(0.5, 0.5, message, ha='center', va='center', transform=ax.transAxes, fontsize=12)
        ax.axis('off')
        return fig
    try:
        # Reject missing or failed state up front.
        if state_data is None or 'error' in state_data:
            error_msg = state_data.get('error', "No data loaded. Please run segmentation first.") if state_data else "No data loaded. Please run segmentation first."
            logger.warning(f"Invalid state data for slice browser: {error_msg}")
            return _message_figure(error_msg)
        # Fast path: render from the pre-computed per-slice cache in the state.
        cache = state_data.get('cached_slice_data')
        if cache and 'num_slices' in cache:
            logger.info(f"Using cached slice data from state for slice {slice_idx}")
            extent = cache['extent']
            slice_idx = max(0, min(slice_idx, cache['num_slices'] - 1))
            t2w = cache['t2w_slices'][slice_idx]
            dwi = cache['dwi_slices'][slice_idx]
            adc = cache['adc_slices'][slice_idx]
            prob = cache['pred_prob_slices'][slice_idx]
            binary = cache['pred_binary_slices'][slice_idx]
            gt = cache['gt_slices'][slice_idx]
            fig, axes = plt.subplots(2, 3, figsize=(15, 8))
            axes = axes.flatten()
            base_kwargs = {'cmap': 'gray', 'aspect': 'equal', 'extent': extent, 'interpolation': 'nearest'}
            # Top row: the three raw modalities.
            for ax, image, title in (
                (axes[0], t2w, f'T2W (Slice {slice_idx})'),
                (axes[1], adc, f'ADC (Slice {slice_idx})'),
                (axes[2], dwi, f'DWI (Slice {slice_idx})'),
            ):
                ax.imshow(image, **base_kwargs)
                ax.set_title(title, fontsize=14, fontweight='bold')
                ax.axis('off')
            # Bottom row: T2W with ground-truth, prediction, and probability overlays.
            axes[3].imshow(t2w, **base_kwargs)
            if gt is not None:
                rgba = np.zeros((*gt.shape, 4))
                rgba[gt > 0] = [1, 0, 0, 0.6]  # Red
                axes[3].imshow(rgba, extent=extent, aspect='equal', interpolation='nearest')
            axes[3].set_title('T2W + Ground Truth', fontsize=14, fontweight='bold')
            axes[3].axis('off')
            axes[4].imshow(t2w, **base_kwargs)
            if binary is not None:
                rgba = np.zeros((*binary.shape, 4))
                rgba[binary > 0] = [0, 1, 0, 0.6]  # Green
                axes[4].imshow(rgba, extent=extent, aspect='equal', interpolation='nearest')
            axes[4].set_title('T2W + Prediction', fontsize=14, fontweight='bold')
            axes[4].axis('off')
            axes[5].imshow(t2w, **base_kwargs)
            if prob is not None:
                heat = plt.cm.jet(prob)
                heat[..., 3] = prob * 0.7  # alpha scales with predicted probability
                axes[5].imshow(heat, extent=extent, aspect='equal', interpolation='nearest')
            axes[5].set_title('T2W + Probability Heatmap', fontsize=14, fontweight='bold')
            axes[5].axis('off')
            # Extra padding keeps the bold titles from being clipped.
            plt.tight_layout(pad=3.0)
            plt.subplots_adjust(top=0.87)
            return fig
        # Slow path: re-render from the raw preprocessed data carried in the state.
        preprocessed_data = state_data.get('preprocessed_data')
        prediction = state_data.get('prediction')
        if preprocessed_data is not None and prediction is not None:
            logger.warning("Cache not found in state. Using fallback rendering.")
            return create_2d_slice_browser_optimized(preprocessed_data, prediction, slice_idx)
        # State exists but carries no usable data at all.
        logger.error("State is invalid. No cache or preprocessed data found.")
        return _message_figure("Invalid data state. Please run segmentation again.")
    except Exception as e:
        logger.error(f"Error updating slice browser with state: {e}")
        return _message_figure(f"Update error: {str(e)}")
def run_anatomy_segmentation_inference_with_state(patient_id: str, current_state, progress=gr.Progress()):
    """Load anatomy segmentation inference results from cache with state management.

    Fetches the patient's precomputed multi-class anatomy segmentation from
    PRECOMPUTED_RESULTS, deserializes it, and produces the slice-browser
    figure, 3D rendering, a per-organ Dice summary (when ground truth is
    cached), a slider update sized to the volume depth, and a state dict for
    the downstream browser/3D callbacks.

    Args:
        patient_id: Key into PRECOMPUTED_RESULTS['anatomy_segmentation'].
        current_state: Previous Gradio state (unused; kept for the event signature).
        progress: Gradio progress reporter.

    Returns:
        (slice_browser_fig, volume_3d_fig, result_text, slider_update, state_data);
        error paths return placeholder figures and a state dict with an 'error' key.
    """
    try:
        progress(0, desc="📂 Loading anatomy segmentation results from cache...")
        # Get data from precomputed results
        cached_result = PRECOMPUTED_RESULTS.get('anatomy_segmentation', {}).get(patient_id)
        if cached_result is None:
            progress(1, desc="❌ Cache result not found!")
            fig, ax = plt.subplots(figsize=(8, 6))
            ax.text(0.5, 0.5, f"Anatomy segmentation result cache not found for patient {patient_id}", ha='center', va='center',
                    transform=ax.transAxes, fontsize=12)
            ax.axis('off')
            empty_3d = go.Figure()
            error_state = {'error': f"Cache result not found for patient {patient_id}", 'timestamp': time.time()}
            # Slider bounds fall back to a 64-slice default when no volume is available.
            return fig, empty_3d, f"Error: Cache result not found for patient {patient_id}", gr.update(maximum=63, value=32), error_state
        progress(0.3, desc="🔄 Deserializing prediction results...")
        # Deserialize data (tensors were stored in a serialization-friendly form)
        prediction = serializable_to_tensor(cached_result['prediction'])
        preprocessed_data = serializable_to_tensor(cached_result['preprocessed_data'])
        # Ground truth is optional in the cache; keep None when absent.
        ground_truth_tensor = serializable_to_tensor(cached_result['ground_truth_tensor']) if cached_result.get('ground_truth_tensor') else None
        progress(0.6, desc="🎨 Generating visualization...")
        # Create anatomy segmentation specific visualization
        slice_browser_fig = create_anatomy_2d_slice_browser(preprocessed_data, prediction)
        volume_3d_fig = create_anatomy_3d_volume_rendering(preprocessed_data, prediction)
        progress(0.8, desc="📊 Computing statistics...")
        # Class index -> label. Index 0 is background; index 2 (Bone) is
        # deliberately excluded from the metrics report below.
        class_names = [
            "Background", "Bladder", "Bone", "Obturator Internus",
            "Peripheral Zone", "Transition Zone", "Rectum",
            "Seminal Vesicle", "Neurovascular Bundle"
        ]
        with torch.no_grad():
            # For multi-class segmentation, take the argmax class per voxel.
            # NOTE(review): prediction appears to be logits shaped [B, num_classes, D, H, W] — confirm.
            pred_classes = torch.argmax(torch.softmax(prediction, dim=1), dim=1)[0].cpu().numpy()
        # Calculate statistics for each class
        result_text = f"Patient ID: {patient_id}\n"
        # Per-class voxel counts intentionally disabled in the UI:
        # result_text += f"Anatomy Segmentation Statistics:\n\n"
        # total_voxels = pred_classes.size
        # for class_idx in range(len(class_names)):
        #     class_voxels = np.sum(pred_classes == class_idx)
        #     class_ratio = class_voxels / total_voxels
        #     if class_voxels > 0:  # Only display existing classes
        #         result_text += f"{class_names[class_idx]}: {class_voxels} voxels ({class_ratio:.3f})\n"
        if ground_truth_tensor is not None:
            with torch.no_grad():
                # Calculate Dice score for each class
                pred_for_dice = torch.argmax(torch.softmax(prediction, dim=1), dim=1, keepdim=True)
                # Handle different ground truth tensor shapes
                if len(ground_truth_tensor.shape) == 5:  # [B, C, D, H, W]
                    gt_for_dice = ground_truth_tensor.squeeze(1) if ground_truth_tensor.shape[1] == 1 else ground_truth_tensor[:, 0]  # [B, D, H, W]
                elif len(ground_truth_tensor.shape) == 4:  # [B, D, H, W]
                    gt_for_dice = ground_truth_tensor
                else:
                    # Last-resort reshape so the one-hot encoding below does not crash.
                    logger.warning(f"Unexpected ground truth shape: {ground_truth_tensor.shape}")
                    gt_for_dice = ground_truth_tensor.reshape(ground_truth_tensor.shape[0], -1, ground_truth_tensor.shape[-2], ground_truth_tensor.shape[-1])
                # Ensure ground truth values are within valid range
                num_classes = 8  # Based on class_names count minus 1 (remove background)
                gt_for_dice = torch.clamp(gt_for_dice, 0, num_classes).long()
                pred_for_dice_squeezed = pred_for_dice.squeeze(1).long()
                # Create one-hot encoding so Dice can be computed per class
                gt_one_hot = torch.nn.functional.one_hot(gt_for_dice, num_classes=num_classes + 1).float()  # [B, D, H, W, num_classes+1]
                pred_one_hot = torch.nn.functional.one_hot(pred_for_dice_squeezed, num_classes=num_classes + 1).float()  # [B, D, H, W, num_classes+1]
                # Convert to [B, num_classes+1, D, H, W]
                gt_one_hot = gt_one_hot.permute(0, 4, 1, 2, 3)
                pred_one_hot = pred_one_hot.permute(0, 4, 1, 2, 3)
                result_text += f"\n=== Performance Metrics ===\n"
                for class_idx in range(1, len(class_names)):  # Skip background
                    if class_idx == 2:  # Skip Bone metrics display
                        continue
                    # Dice = 2*|A∩B| / (|A|+|B|) on the binary one-hot masks.
                    intersection = torch.sum(pred_one_hot[:, class_idx] * gt_one_hot[:, class_idx])
                    union = torch.sum(pred_one_hot[:, class_idx]) + torch.sum(gt_one_hot[:, class_idx])
                    if union > 0:
                        dice = 2.0 * intersection / union
                        result_text += f"{class_names[class_idx]} Dice: {dice.item():.3f}\n"
                    else:
                        result_text += f"{class_names[class_idx]} Dice: 0.000 (no ground truth)\n"
        # Create state dictionary consumed by the slice-browser / 3D callbacks
        state_data = {
            'patient_id': patient_id,
            'preprocessed_data': preprocessed_data,
            'prediction': prediction,
            'timestamp': time.time()
        }
        # Get middle slice index for slider
        middle_slice = preprocessed_data['t2w_preprocessed'].shape[0] // 2
        progress(1, desc="✅ Completed!")
        return (
            slice_browser_fig,
            volume_3d_fig,
            result_text,
            gr.update(maximum=preprocessed_data['t2w_preprocessed'].shape[0]-1, value=middle_slice),
            state_data
        )
    except Exception as e:
        logger.error(f"Error loading anatomy segmentation cache result: {e}", exc_info=True)
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.text(0.5, 0.5, f"Processing error: {str(e)}", ha='center', va='center', transform=ax.transAxes, fontsize=12)
        ax.axis('off')
        empty_3d = go.Figure()
        error_state = {'error': str(e), 'timestamp': time.time()}
        return fig, empty_3d, f"Error: {str(e)}", gr.update(maximum=63, value=32), error_state
def update_anatomy_slice_browser_with_state(slice_idx: int, state_data):
    """Update anatomy slice browser using state data."""
    def _text_only_figure(message):
        # Blank matplotlib figure showing only a centered message.
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.text(0.5, 0.5, message, ha='center', va='center', transform=ax.transAxes, fontsize=12)
        ax.axis('off')
        return fig
    try:
        # Reject missing or failed state up front.
        if state_data is None or 'error' in state_data:
            error_msg = state_data.get('error', "No data loaded. Please run anatomy segmentation first.") if state_data else "No data loaded. Please run anatomy segmentation first."
            logger.warning(f"Invalid state data for anatomy slice browser: {error_msg}")
            return _text_only_figure(error_msg)
        # Happy path: re-render the requested slice from the stored volume and prediction.
        volume_data = state_data.get('preprocessed_data')
        prediction = state_data.get('prediction')
        if volume_data is not None and prediction is not None:
            return create_anatomy_2d_slice_browser(volume_data, prediction, slice_idx)
        # State exists but carries no usable data.
        logger.error("State is invalid. No preprocessed data found.")
        return _text_only_figure("Invalid data state. Please run anatomy segmentation again.")
    except Exception as e:
        logger.error(f"Error updating anatomy slice browser with state: {e}")
        return _text_only_figure(f"Update error: {str(e)}")
def update_anatomy_3d_with_controls(selected_organ: str, state_data):
    """Update anatomy 3D rendering based on organ selection."""
    def _annotated_figure(message):
        # Empty plotly figure carrying a centered annotation message.
        fig = go.Figure()
        fig.add_annotation(
            text=message,
            xref="paper", yref="paper",
            x=0.5, y=0.5,
            showarrow=False,
            font=dict(size=14)
        )
        return fig
    try:
        # Reject missing or failed state up front.
        if state_data is None or 'error' in state_data:
            error_msg = state_data.get('error', "No data loaded. Please run anatomy segmentation first.") if state_data else "No data loaded. Please run anatomy segmentation first."
            logger.warning(f"Invalid state data for anatomy 3D: {error_msg}")
            return _annotated_figure(error_msg)
        volume_data = state_data.get('preprocessed_data')
        prediction = state_data.get('prediction')
        if volume_data is None or prediction is None:
            logger.error("State is invalid. No preprocessed data found.")
            return _annotated_figure("Invalid data state. Please run anatomy segmentation again.")
        # UI label -> segmentation class index (Bone, index 2, is intentionally absent).
        organ_mapping = {
            "Bladder": 1,
            "Obturator Internus": 3,
            "Peripheral Zone": 4,
            "Transition Zone": 5,
            "Rectum": 6,
            "Seminal Vesicle": 7,
            "Neurovascular Bundle": 8
        }
        organ_idx = organ_mapping.get(selected_organ, 4)  # default: Peripheral Zone
        # Always render Ground Truth and Prediction together for direct comparison.
        return create_anatomy_3d_volume_rendering(volume_data, prediction, organ_idx, "Both")
    except Exception as e:
        logger.error(f"Error updating anatomy 3D with controls: {e}")
        return _annotated_figure(f"3D update error: {str(e)}")
def create_anatomy_2d_slice_browser(preprocessed_data: Dict, prediction: torch.Tensor, slice_idx: int = None) -> plt.Figure:
    """Create 2D slice browser for anatomy segmentation with 3 views - H200 optimized."""
    try:
        # Keep the volume on its current device; only the selected slice moves to CPU.
        t2w_volume = preprocessed_data['t2w_preprocessed']
        # Default to the middle slice and clamp the index into the valid range.
        if slice_idx is None:
            slice_idx = t2w_volume.shape[0] // 2
        slice_idx = max(0, min(slice_idx, t2w_volume.shape[0] - 1))
        with torch.no_grad():
            t2w_slice = t2w_volume[slice_idx].cpu().numpy()
        # A single extent keeps all three panels pixel-aligned.
        height, width = t2w_slice.shape
        extent = [0, width, height, 0]  # [left, right, bottom, top]
        # Per-voxel predicted class labels for this slice (argmax over softmax).
        pred_classes = None
        if prediction is not None:
            with torch.no_grad():
                slice_probs = torch.softmax(prediction, dim=1)[0, :, slice_idx]
                pred_classes = torch.argmax(slice_probs, dim=0).cpu().numpy()
        # Ground-truth labels for this slice, when present in the preprocessed data.
        gt_slice = None
        if 'ground_truth_preprocessed' in preprocessed_data and preprocessed_data['ground_truth_preprocessed'] is not None:
            with torch.no_grad():
                gt_slice = preprocessed_data['ground_truth_preprocessed'][slice_idx].cpu().numpy().astype(np.uint8)
        # High-contrast RGB color per anatomical class (index == class label).
        class_colors = [
            [0, 0, 0],       # Background - black
            [1, 0, 0],       # Bladder - bright red
            [0, 1, 0],       # Bone - bright green (not displayed)
            [0, 0, 1],       # Obturator Internus - bright blue
            [1, 1, 0],       # Peripheral Zone - yellow
            [1, 0.5, 0],     # Transition Zone - orange
            [0.5, 0, 0.5],   # Rectum - dark purple
            [0, 0.8, 0.8],   # Seminal Vesicle - light cyan
            [0.8, 0, 0.8]    # Neurovascular Bundle - bright magenta
        ]
        def _class_overlay(label_map):
            # RGBA mask with each visible class at 0.6 alpha; Bone (class 2) is hidden.
            rgba = np.zeros((*label_map.shape, 4))
            for idx in range(1, len(class_colors)):
                if idx == 2:
                    continue
                selection = label_map == idx
                if np.any(selection):
                    rgba[selection] = class_colors[idx] + [0.6]
            return rgba
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        base_kwargs = {
            'cmap': 'gray',
            'aspect': 'equal',
            'extent': extent,
            'interpolation': 'nearest'
        }
        # Panel 1: raw T2W image.
        axes[0].imshow(t2w_slice, **base_kwargs)
        axes[0].set_title(f'T2W (Slice {slice_idx})', fontsize=14, fontweight='bold')
        # Panel 2: T2W with the ground-truth label overlay.
        axes[1].imshow(t2w_slice, **base_kwargs)
        if gt_slice is not None:
            axes[1].imshow(_class_overlay(gt_slice), extent=extent, aspect='equal', interpolation='nearest')
        axes[1].set_title('T2W + Ground Truth', fontsize=14, fontweight='bold')
        # Panel 3: T2W with the predicted label overlay.
        axes[2].imshow(t2w_slice, **base_kwargs)
        if pred_classes is not None:
            axes[2].imshow(_class_overlay(pred_classes), extent=extent, aspect='equal', interpolation='nearest')
        axes[2].set_title('T2W + Prediction', fontsize=14, fontweight='bold')
        for ax in axes:
            ax.axis('off')
        # Extra padding keeps the bold titles from being clipped.
        plt.tight_layout(pad=3.0)
        plt.subplots_adjust(top=0.85)
        return fig
    except Exception as e:
        logger.error(f"Error creating anatomy 2D slice browser: {e}")
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.text(0.5, 0.5, f"Visualization error: {str(e)}", ha='center', va='center',
                transform=ax.transAxes, fontsize=12)
        ax.axis('off')
        return fig
def create_anatomy_3d_volume_rendering(preprocessed_data: Dict, prediction: torch.Tensor, selected_organ_idx: int = 4, display_mode: str = "Both") -> go.Figure:
    """Create interactive 3D volume rendering for a single organ - H200 optimized.

    Runs marching cubes on a 2x-downsampled label volume and adds plotly
    Mesh3d surfaces: ground truth in warm colors, prediction in cool colors.

    Args:
        preprocessed_data: Must contain 't2w_preprocessed' (tensor volume) and
            optionally 'ground_truth_preprocessed' (integer label volume).
        prediction: Multi-class logits (argmaxed over dim 1), or None.
        selected_organ_idx: Class index to render; out-of-range values fall
            back to 4 (Peripheral Zone). Index 2 (Bone) is never rendered.
        display_mode: "Ground Truth", "Prediction", or "Both".

    Returns:
        A plotly Figure; on failure or when no mesh can be built, a figure
        carrying an explanatory annotation instead of meshes.

    Fix: the two hovertemplate f-strings were previously split across physical
    lines (a SyntaxError); restored with plotly's "<br>" line-break tag.
    """
    try:
        # Vectorized processing on GPU before CPU transfer
        with torch.no_grad():
            # Get T2W data for reference
            t2w_gpu = preprocessed_data['t2w_preprocessed']
            # Per-voxel predicted class labels (argmax over softmax), zeros when no prediction
            pred_classes_gpu = torch.zeros_like(t2w_gpu, dtype=torch.uint8)
            if prediction is not None:
                pred_logits_gpu = torch.softmax(prediction, dim=1)[0]
                pred_classes_gpu = torch.argmax(pred_logits_gpu, dim=0).to(torch.uint8)
            # Ground-truth class labels, zeros when no ground truth is available
            gt_classes_gpu = torch.zeros_like(t2w_gpu, dtype=torch.uint8)
            if 'ground_truth_preprocessed' in preprocessed_data and preprocessed_data['ground_truth_preprocessed'] is not None:
                gt_classes_gpu = preprocessed_data['ground_truth_preprocessed'].to(torch.uint8)
            # Transfer to CPU in one batch operation
            t2w = t2w_gpu.cpu().numpy()
            pred_classes = pred_classes_gpu.cpu().numpy()
            gt_classes = gt_classes_gpu.cpu().numpy()
        # Downsample for performance (every 2nd voxel) before marching cubes
        t2w_ds = t2w[::2, ::2, ::2]
        pred_classes_ds = pred_classes[::2, ::2, ::2]
        gt_classes_ds = gt_classes[::2, ::2, ::2]
        logger.info(f"Downsampled anatomy data shapes: T2W={t2w_ds.shape}, prediction={pred_classes_ds.shape}, ground truth={gt_classes_ds.shape}")
        # Create 3D meshes for selected organ only
        fig = go.Figure()
        # Color tables indexed by class: warm palette for Ground Truth,
        # cool palette for Prediction, so the two surfaces stay distinguishable.
        class_info_gt = [  # Ground Truth - Warm colors (higher saturation)
            ("Background", [0.5, 0.5, 0.5]),          # Background - skip
            ("Bladder", [1.0, 0.2, 0.2]),             # Warm red
            ("Bone", [0.8, 1.0, 0.2]),                # Warm yellow-green (not displayed)
            ("Obturator Internus", [1.0, 0.6, 0.0]),  # Warm orange
            ("Peripheral Zone", [1.0, 0.8, 0.0]),     # Warm yellow
            ("Transition Zone", [1.0, 0.4, 0.2]),     # Warm orange-red
            ("Rectum", [0.8, 0.2, 0.6]),              # Warm magenta
            ("Seminal Vesicle", [1.0, 0.7, 0.3]),     # Warm peach
            ("Neurovascular Bundle", [0.9, 0.3, 0.7]) # Warm pink
        ]
        class_info_pred = [  # Prediction - Cool colors (lower saturation)
            ("Background", [0.5, 0.5, 0.5]),          # Background - skip
            ("Bladder", [0.2, 0.4, 1.0]),             # Cool blue
            ("Bone", [0.2, 0.8, 0.6]),                # Cool cyan-green (not displayed)
            ("Obturator Internus", [0.0, 0.6, 1.0]),  # Cool sky blue
            ("Peripheral Zone", [0.4, 0.7, 1.0]),     # Cool light blue
            ("Transition Zone", [0.2, 0.5, 0.8]),     # Cool blue-gray
            ("Rectum", [0.4, 0.2, 0.8]),              # Cool purple
            ("Seminal Vesicle", [0.2, 0.7, 0.9]),     # Cool cyan
            ("Neurovascular Bundle", [0.5, 0.3, 0.9]) # Cool lavender
        ]
        # Validate selected organ index
        if selected_organ_idx < 1 or selected_organ_idx >= len(class_info_gt):
            selected_organ_idx = 4  # Default to Peripheral Zone
        # Skip Bone visualization (index 2)
        if selected_organ_idx == 2:
            fig = go.Figure()
            fig.add_annotation(
                text="Bone visualization is not available",
                xref="paper", yref="paper",
                x=0.5, y=0.5,
                showarrow=False,
                font=dict(size=16)
            )
            return fig
        class_name, color_gt = class_info_gt[selected_organ_idx]
        _, color_pred = class_info_pred[selected_organ_idx]
        # Track what gets displayed so we can annotate an empty figure.
        traces_added = []
        # Ground Truth mesh (if display mode includes GT)
        if display_mode in ["Ground Truth", "Both"]:
            gt_mask = (gt_classes_ds == selected_organ_idx)
            # Require a minimum voxel count so marching cubes has a surface to extract.
            if np.count_nonzero(gt_mask) > 50:
                try:
                    logger.info(f"Creating 3D Ground Truth mesh for {class_name}...")
                    verts_gt, faces_gt, _, _ = measure.marching_cubes(gt_mask.astype(float), level=0.5)
                    fig.add_trace(go.Mesh3d(
                        x=verts_gt[:, 0],
                        y=verts_gt[:, 1],
                        z=verts_gt[:, 2],
                        i=faces_gt[:, 0],
                        j=faces_gt[:, 1],
                        k=faces_gt[:, 2],
                        name=f"{class_name} (Ground Truth)",
                        color=f'rgb({int(color_gt[0]*255)},{int(color_gt[1]*255)},{int(color_gt[2]*255)})',
                        opacity=0.9,  # Higher opacity for ground truth (warm colors)
                        lighting=dict(ambient=0.6, diffuse=0.8),
                        showlegend=True,
                        visible=True,
                        # "<br>" is plotly's line break inside hovertemplates
                        hovertemplate=f"{class_name} Ground Truth<br>Vertex: (%{{x}}, %{{y}}, %{{z}})"
                    ))
                    traces_added.append(f"{class_name} (Ground Truth)")
                    logger.info(f"✅ Successfully created {class_name} Ground Truth 3D mesh")
                except Exception as e:
                    logger.warning(f"Could not create {class_name} Ground Truth mesh: {e}")
        # Prediction mesh (if display mode includes Prediction)
        if display_mode in ["Prediction", "Both"]:
            pred_mask = (pred_classes_ds == selected_organ_idx)
            if np.count_nonzero(pred_mask) > 50:
                try:
                    logger.info(f"Creating 3D Prediction mesh for {class_name}...")
                    verts_pred, faces_pred, _, _ = measure.marching_cubes(pred_mask.astype(float), level=0.5)
                    # Use lower opacity for prediction (cool colors) to create better contrast
                    pred_opacity = 0.5 if display_mode == "Both" else 0.7
                    fig.add_trace(go.Mesh3d(
                        x=verts_pred[:, 0],
                        y=verts_pred[:, 1],
                        z=verts_pred[:, 2],
                        i=faces_pred[:, 0],
                        j=faces_pred[:, 1],
                        k=faces_pred[:, 2],
                        name=f"{class_name} (Prediction)",
                        color=f'rgb({int(color_pred[0]*255)},{int(color_pred[1]*255)},{int(color_pred[2]*255)})',
                        opacity=pred_opacity,
                        lighting=dict(ambient=0.8, diffuse=0.6),  # Different lighting for cool colors
                        showlegend=True,
                        visible=True,
                        # "<br>" is plotly's line break inside hovertemplates
                        hovertemplate=f"{class_name} Prediction<br>Vertex: (%{{x}}, %{{y}}, %{{z}})"
                    ))
                    traces_added.append(f"{class_name} (Prediction)")
                    logger.info(f"✅ Successfully created {class_name} Prediction 3D mesh")
                except Exception as e:
                    logger.warning(f"Could not create {class_name} Prediction mesh: {e}")
        # Configure layout
        display_mode_text = display_mode if display_mode != "Both" else "Ground Truth (Warm) vs Prediction (Cool)"
        fig.update_layout(
            title=dict(
                text=f"3D Anatomy Rendering: {class_name} - {display_mode_text}",
                x=0.5,
                font=dict(size=16)
            ),
            scene=dict(
                xaxis_title="X",
                yaxis_title="Y",
                zaxis_title="Z",
                camera=dict(
                    eye=dict(x=1.2, y=1.2, z=1.2)
                ),
                aspectmode='cube'
            ),
            width=700,
            height=400,
            margin=dict(l=0, r=0, b=0, t=40)
        )
        if not traces_added:
            # Add annotation if no data to display
            fig.add_annotation(
                text=f"No {display_mode.lower()} data available for {class_name}",
                xref="paper", yref="paper",
                x=0.5, y=0.5,
                showarrow=False,
                font=dict(size=14)
            )
        logger.info(f"3D anatomy rendering completed for {class_name} with mode {display_mode}: {traces_added}")
        return fig
    except Exception as e:
        logger.error(f"Error creating anatomy 3D rendering: {e}")
        # Return empty figure with error message
        fig = go.Figure()
        fig.add_annotation(
            text=f"3D anatomy rendering error: {str(e)}",
            xref="paper", yref="paper",
            x=0.5, y=0.5,
            showarrow=False,
            font=dict(size=16)
        )
        return fig
def create_interface() -> gr.Blocks:
    """Create Gradio interface.

    Builds the full ProFound demo UI: four tabs (pre-training reconstruction,
    PI-RADS classification, lesion segmentation, anatomy segmentation), each
    with its own patient dropdown, action button, result widgets and — where
    applicable — a slice slider and 3D plot wired to the corresponding
    inference/update callbacks defined elsewhere in this file.

    Returns:
        gr.Blocks: the assembled (not yet launched) Gradio application.
    """
    # Load sample data to get patient lists
    sample_patients, dataset_root = load_sample_data()
    # Load pre-training data to get patient lists
    pretrain_patients, pretrain_root = load_pretrain_data()
    # Create patient lists — one "Patient_<idx>" label per loaded sample,
    # indexed from 0 in load order.
    risk_patients = []
    ucl_patients = []
    anatomy_patients = []
    if 'risk' in sample_patients:
        risk_patients = [f"Patient_{i}" for i in range(len(sample_patients['risk']))]
    if 'UCL' in sample_patients:
        ucl_patients = [f"Patient_{i}" for i in range(len(sample_patients['UCL']))]
    if 'anatomy' in sample_patients:
        anatomy_patients = [f"Patient_{i}" for i in range(len(sample_patients['anatomy']))]
    # Create theme with default settings
    theme = gr.themes.Ocean(
        primary_hue="blue",
        secondary_hue="gray",
    )
    with gr.Blocks(theme=theme, title="ProFound: Vision Foundation Models for Prostate Multiparametric MR Images") as demo:
        # Header
        gr.Markdown("""
# ProFound: Vision Foundation Models for Prostate Multiparametric MR Images 🏥🔬
ProFound is a suite of vision foundation models, pre-trained on multiparametric 3D magnetic resonance (MR) images from large collections of prostate cancer patients.
""")
        # Create State components to manage data; each tab keeps its own
        # per-session inference result so slider moves don't re-run the model.
        pretrain_state = gr.State(value=None)
        segmentation_state = gr.State(value=None)
        anatomy_state = gr.State(value=None)
        with gr.Tabs() as tabs:
            # Pre-training tab - First tab (default)
            with gr.TabItem("🧠 ProFound Pre-training", id="pretraining"):
                gr.Markdown("""
### ProFound Pre-training Visualization
**Input**: T2W, DWI, ADC original images
**Processing**: Masked reconstruction pipeline
**Output**: Original → Masked → Reconstructed visualization
**Purpose**: Demonstrate the self-supervised pre-training process used to train ProFound foundation models
""")
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("#### Patient Selection")
                        pretrain_patient_dropdown = gr.Dropdown(
                            # Placeholder patient IDs shown when no data loaded
                            choices=pretrain_patients if pretrain_patients else ["patient_001", "patient_002"],
                            label="Choose Patient",
                            value=pretrain_patients[0] if pretrain_patients else "patient_001",
                            info="Select a patient to visualize pre-training reconstruction"
                        )
                        pretrain_button = gr.Button("🚀 Launch Reconstruction", variant="primary")
                    with gr.Column(scale=4):
                        gr.Markdown("#### Multi-Modal Pre-training Visualization")
                        gr.Markdown("**🔍 2D Slice Browser** - Navigate through slices")
                        # Slice control for pre-training (0-63 matches the
                        # 64-slice volume depth used elsewhere in the demo)
                        pretrain_slice_slider = gr.Slider(
                            minimum=0,
                            maximum=63,
                            value=32,
                            step=1,
                            label="Slice Index",
                            info="Drag to navigate through different slices"
                        )
                        # 3x3 grid visualization for pre-training
                        gr.Markdown("**📊 Reconstruction Results:**")
                        gr.Markdown("• **Rows**: T2W (top), DWI (middle), ADC (bottom)")
                        gr.Markdown("• **Columns**: Original (left), Masked (center), Reconstructed (right)")
                        pretrain_plot = gr.Plot(
                            label="Reconstruction Results",
                            value=None
                        )
                # Bind pre-training events
                pretrain_button.click(
                    fn=run_pretrain_reconstruction,
                    inputs=[pretrain_patient_dropdown, pretrain_state],
                    outputs=[pretrain_plot, pretrain_slice_slider, pretrain_state]
                )
                # Bind slice slider event for real-time updates
                pretrain_slice_slider.change(
                    fn=update_pretrain_slice_with_state,
                    inputs=[pretrain_slice_slider, pretrain_state],
                    outputs=[pretrain_plot]
                )
            # Classification tab
            with gr.TabItem("🎯 Assessment of Prostate Cancer Patients", id="classification"):
                gr.Markdown("""
### PI-RADS Classification
**Input**: T2W, DWI (High-b), ADC images
**Output**: PI-RADS classification with confidence scores
""")
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("#### Patient Selection")
                        cls_patient_dropdown = gr.Dropdown(
                            choices=risk_patients,
                            label="Select Patient",
                            value=risk_patients[0] if risk_patients else None,
                            info="Choose a patient for risk assessment"
                        )
                        cls_button = gr.Button("🚀 Run Classification", variant="primary")
                    with gr.Column(scale=2):
                        gr.Markdown("#### Results")
                        cls_plot = gr.Plot(label="Multi-modal Visualization")
                        cls_result = gr.Textbox(
                            label="Classification Results",
                            lines=10,
                            interactive=False
                        )
                # Bind classification event (stateless: no slice browser here)
                cls_button.click(
                    fn=run_classification_inference,
                    inputs=[cls_patient_dropdown],
                    outputs=[cls_plot, cls_result]
                )
            # Segmentation tab
            with gr.TabItem("✂️ Segmentation of Prostate Cancer Lesions", id="segmentation"):
                gr.Markdown("""
### Lesion Segmentation
**Input**: T2W, DWI, ADC images
**Output**: Binary lesion segmentation mask
""")
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("#### Patient Selection")
                        seg_patient_dropdown = gr.Dropdown(
                            choices=ucl_patients,
                            label="Select Patient",
                            # Prefer "Patient_1" as a showcase default when present
                            value="Patient_1" if ucl_patients and "Patient_1" in ucl_patients else (ucl_patients[0] if ucl_patients else None),
                            info="Choose a patient for lesion segmentation"
                        )
                        seg_button = gr.Button("🚀 Run Segmentation", variant="primary")
                        gr.Markdown("#### Analysis Results")
                        seg_result = gr.Textbox(
                            label="Segmentation Results",
                            lines=12,
                            interactive=False
                        )
                    with gr.Column(scale=3):
                        gr.Markdown("#### Multi-Modal Visualization")
                        # Combined visualization without tabs - display both in the same view
                        gr.Markdown("**🔍 2D Slice Browser** - Navigate through slices to explore all modalities and results")
                        # Slice control
                        slice_slider = gr.Slider(
                            minimum=0,
                            maximum=63,
                            value=32,
                            step=1,
                            label="Slice Index",
                            info="Drag to navigate through different slices"
                        )
                        # 2D slice browser plot - adjusted for better coordination
                        slice_browser_plot = gr.Plot(
                            label="6-Window Synchronized View: T2W, ADC, DWI, Ground Truth, Prediction, Heatmap"
                        )
                        gr.Markdown("**🌐 3D Volume Rendering** - Interactive 3D visualization")
                        gr.Markdown("- 🔵 **Blue**: Prostate structure - 🔴 **Red**: Predicted lesions - 🟡 **Yellow**: Ground truth lesions")
                        # 3D volume rendering - adjusted for better coordination
                        volume_3d_plot = gr.Plot(
                            label="3D Prostate and Lesion Rendering"
                        )
                # Bind segmentation event with state management
                seg_button.click(
                    fn=run_segmentation_inference_with_state,
                    inputs=[seg_patient_dropdown, segmentation_state],
                    outputs=[slice_browser_plot, volume_3d_plot, seg_result, slice_slider, segmentation_state]
                )
                # Bind slice slider event for real-time updates with state
                slice_slider.change(
                    fn=update_slice_browser_with_state,
                    inputs=[slice_slider, segmentation_state],
                    outputs=[slice_browser_plot]
                )
            # Anatomy Segmentation tab - NEW
            with gr.TabItem("🫀 Segmentation of Prostate Anatomy", id="anatomy_segmentation"):
                gr.Markdown("""
### Anatomy Segmentation
**Input**: T2W images only
**Output**: 7-class anatomical structure segmentation
**Classes**: Bladder, Obturator Internus, Peripheral Zone, Transition Zone, Rectum, Seminal Vesicle, Neurovascular Bundle
""")
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("#### Patient Selection")
                        anat_patient_dropdown = gr.Dropdown(
                            choices=anatomy_patients,
                            label="Select Patient",
                            value=anatomy_patients[0] if anatomy_patients else None,
                            info="Choose a patient for anatomy segmentation"
                        )
                        anat_button = gr.Button("🚀 Run Anatomy Segmentation", variant="primary")
                        gr.Markdown("#### Analysis Results")
                        anat_result = gr.Textbox(
                            label="Anatomy Segmentation Results",
                            lines=15,
                            interactive=False
                        )
                    with gr.Column(scale=3):
                        gr.Markdown("#### T2W-based Interactive Visualization")
                        gr.Markdown("**🔍 2D Slice Browser** - Navigate through slices to explore anatomical structures")
                        # Slice control for anatomy
                        anat_slice_slider = gr.Slider(
                            minimum=0,
                            maximum=63,
                            value=32,
                            step=1,
                            label="Slice Index",
                            info="Drag to navigate through different slices"
                        )
                        # 2D slice browser for anatomy
                        anat_slice_browser_plot = gr.Plot(
                            label="3-Window Anatomical View: T2W Original, Ground Truth, Prediction"
                        )
                        gr.Markdown("**🌐 3D Anatomy Rendering** - Interactive single-organ visualization")
                        gr.Markdown("**Organ Selection**")
                        # One organ at a time keeps the 3D mesh readable;
                        # choices mirror the 7 anatomy classes listed above.
                        anat_organ_dropdown = gr.Dropdown(
                            choices=[
                                "Bladder",
                                "Obturator Internus",
                                "Peripheral Zone",
                                "Transition Zone",
                                "Rectum",
                                "Seminal Vesicle",
                                "Neurovascular Bundle"
                            ],
                            label="Select Anatomical Structure",
                            value="Peripheral Zone",
                            info="Choose anatomical structure to display in 3D view"
                        )
                        gr.Markdown("**Color Legend**: **Ground Truth** (warm colors, solid) vs **Prediction** (cool colors, transparent)")
                        gr.Markdown("🔥 **Warm tones**: Ground Truth ❄️ **Cool tones**: Prediction")
                        gr.Markdown("📋 **Organ List**: Bladder, Obturator Internus, Peripheral Zone, Transition Zone, Rectum, Seminal Vesicle, Neurovascular Bundle")
                        # 3D volume rendering for anatomy
                        anat_volume_3d_plot = gr.Plot(
                            label="3D Single-Organ Rendering: Focused Anatomical Structure View"
                        )
                # Bind anatomy segmentation event with state management
                anat_button.click(
                    fn=run_anatomy_segmentation_inference_with_state,
                    inputs=[anat_patient_dropdown, anatomy_state],
                    outputs=[anat_slice_browser_plot, anat_volume_3d_plot, anat_result, anat_slice_slider, anatomy_state]
                )
                # Bind slice slider event for real-time updates with state
                anat_slice_slider.change(
                    fn=update_anatomy_slice_browser_with_state,
                    inputs=[anat_slice_slider, anatomy_state],
                    outputs=[anat_slice_browser_plot]
                )
                # Bind organ selection changes for 3D rendering
                anat_organ_dropdown.change(
                    fn=update_anatomy_3d_with_controls,
                    inputs=[anat_organ_dropdown, anatomy_state],
                    outputs=[anat_volume_3d_plot]
                )
    return demo
# Launch application
if __name__ == "__main__":
    # Set random seed for reproducibility
    torch.manual_seed(0)
    np.random.seed(0)
    # NOTE(review): cudnn.benchmark=True favors throughput over bitwise
    # reproducibility across runs — confirm this trade-off is intended
    # alongside the fixed seeds above.
    cudnn.benchmark = True
    logger.info("🚀 Starting ProFound demo system with precomputed cache optimization")
    logger.info("📂 Loading precomputed inference results...")
    # Load precomputed results so demo interactions can be served from cache
    load_precomputed_results()
    logger.info("✅ Precomputed results loading completed, starting web interface...")
    demo = create_interface()
    # Bind on all interfaces on port 7860 (the Hugging Face Spaces default);
    # allowed_paths lets Gradio serve files from the cache and app directories.
    demo.launch(
        share=True,
        server_name="0.0.0.0",
        server_port=7860,
        allowed_paths=[str(cache_dir), str(Path(__file__).parent)]
    )