image-processing-studio / segment_neuroimaging.py
mmrech's picture
Upload segment_neuroimaging.py with huggingface_hub
1c561bc verified
"""
NPH Neuroimaging Segmentation Module
Segmentation and quantitative biomarker analysis for Normal Pressure Hydrocephalus
Supports CT Head, MRI T1, T2, FLAIR.
Computes: Evans' index, callosal angle, temporal horn width, z-Evans index,
DESH pattern assessment, periventricular hyperintensity scoring.
Author: Matheus Rech, MD
Version: 2.0.0 (NPH-focused)
"""
import cv2
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from typing import Dict, List, Tuple, Optional, Union
from dataclasses import dataclass, field
from enum import Enum
import warnings
# ---------------------------------------------------------------------------
# Enums and data classes
# ---------------------------------------------------------------------------
class Modality(Enum):
    """Imaging modalities understood by the NPH pipeline.

    ``segment_nph`` resolves its ``modality`` string with
    ``Modality[modality.upper()]``, so the member *names* are the public
    spelling; the *values* are lowercase string identifiers.
    """

    CT_HEAD = "ct_head"      # CT of the head
    T1 = "t1_weighted"       # MRI, T1-weighted
    T1_GD = "t1_gadolinium"  # MRI, T1-weighted with gadolinium
    T2 = "t2_weighted"       # MRI, T2-weighted
    FLAIR = "flair"          # MRI, FLAIR
class CSFAppearance(Enum):
    """Relative brightness of CSF versus brain tissue on a given modality.

    Mapped per modality in ``CSF_MODE``.
    """

    DARK = "dark"      # CT, T1, FLAIR
    BRIGHT = "bright"  # T2
@dataclass
class SegmentationResult:
    """Bundle of outputs produced by the NPH segmentation pipeline."""

    # Structure name -> binary mask.
    masks: Dict[str, np.ndarray]
    # Visualisation image (segment_nph stores the annotated overlay here).
    overlay: np.ndarray
    # Structure name -> list of contours (nested point lists in segment_nph).
    contours: Dict[str, List] = field(default_factory=dict)
    # Free-form metadata such as modality, biomarkers and DICOM tags.
    metadata: Dict = field(default_factory=dict)
# ---------------------------------------------------------------------------
# NPH color palette
# ---------------------------------------------------------------------------
# RGB colour per structure name; create_overlay / add_annotations look names up
# here and fall back to grey for unknown keys. Aliases share a colour.
COLORS = dict(
    # Ventricular system / CSF
    lateral_ventricles=(0, 150, 255),
    ventricles=(0, 150, 255),
    third_ventricle=(0, 100, 200),
    temporal_horns=(0, 200, 255),
    csf=(0, 150, 255),
    # Brain tissue
    parenchyma=(100, 200, 100),
    # Periventricular hyperintensity (short and long key)
    pvh=(255, 200, 0),
    periventricular_hyperintensity=(255, 200, 0),
    # Subarachnoid spaces assessed for DESH
    sylvian_fissures=(200, 100, 255),
    high_convexity_sulci=(255, 150, 100),
    # Bone (two aliases)
    skull=(255, 255, 200),
    bone=(255, 255, 200),
    # Other findings / tissue classes
    aqueductal_flow_void=(255, 80, 80),
    transependymal_flow=(255, 180, 0),
    subdural_collection=(180, 60, 60),
    hemorrhage=(200, 50, 100),
    white_matter=(180, 180, 140),
    gray_matter=(140, 160, 140),
)
# ---------------------------------------------------------------------------
# Modality-specific CSF behavior
# ---------------------------------------------------------------------------
# How CSF presents on each modality: only T2 shows CSF brighter than tissue;
# CT, T1 (with or without gadolinium) and FLAIR (CSF signal suppressed) show it dark.
CSF_MODE = {
    modality: CSFAppearance.BRIGHT if modality is Modality.T2 else CSFAppearance.DARK
    for modality in (
        Modality.CT_HEAD,
        Modality.T1,
        Modality.T1_GD,
        Modality.T2,
        Modality.FLAIR,
    )
}

# Default inclusive 8-bit CSF intensity windows for ventricle segmentation
# (consumed with cv2.inRange). Dark-CSF modalities use a low window, T2 a high
# one. For DICOM CT, segment_nph applies the brain window before thresholding.
VENTRICLE_THRESHOLDS = {
    modality: {"csf_low": csf_low, "csf_high": csf_high}
    for modality, csf_low, csf_high in (
        (Modality.CT_HEAD, 0, 55),   # dark on brain window
        (Modality.T1, 0, 45),        # hypointense
        (Modality.T1_GD, 0, 45),
        (Modality.T2, 170, 255),     # hyperintense
        (Modality.FLAIR, 0, 50),     # suppressed
    )
}

# FLAIR periventricular hyperintensity: 8-bit intensities >= this value count
# as hyperintense (score_pvh uses it as the inclusive lower bound of inRange).
PVH_THRESHOLD = 145

# Named CT display windows as (center, width) in Hounsfield units,
# passed to apply_ct_window(hu, center, width).
CT_WINDOWS = dict(
    brain=(40, 80),
    subdural=(75, 215),
    bone=(400, 1800),
)
# ===========================================================================
# IMAGE LOADING AND PREPROCESSING
# ===========================================================================
def preprocess_image(image_path: str) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Read an image from disk and derive the arrays the pipeline works on.

    Args:
        image_path: Path to a raster image readable by OpenCV.
    Returns:
        ``(original_rgb, grayscale, blurred)``. ``blurred`` is the grayscale
        image smoothed with a 5x5 Gaussian kernel.
    Raises:
        FileNotFoundError: if OpenCV cannot read the file (``cv2.imread`` returned None).
    """
    bgr = cv2.imread(image_path)
    if bgr is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
    return rgb, gray, cv2.GaussianBlur(gray, (5, 5), 0)
def apply_ct_window(hu_image: np.ndarray, center: float, width: float) -> np.ndarray:
    """
    Map Hounsfield units to 8-bit grayscale using a CT display window.

    Values at or below ``center - width/2`` map to 0 and values at or above
    ``center + width/2`` map to 255. Values in between are scaled linearly,
    and fractional results are truncated by the uint8 cast.

    Args:
        hu_image: Array of HU values (any numeric dtype).
        center: Window level in HU.
        width: Window width in HU; must be positive.
    Returns:
        uint8 array with the same shape as ``hu_image``.
    Raises:
        ValueError: if ``width`` is not positive. A zero width would divide
            by zero and yield undefined uint8 values.
    """
    # `not width > 0` also rejects NaN.
    if not width > 0:
        raise ValueError(f"CT window width must be positive, got {width!r}")
    low = center - width / 2.0
    high = center + width / 2.0
    windowed = np.clip(np.asarray(hu_image, dtype=np.float64), low, high)
    return ((windowed - low) / (high - low) * 255.0).astype(np.uint8)
def load_dicom(filepath: str) -> Tuple[np.ndarray, dict]:
    """
    Load a DICOM file and return its rescaled pixel data plus basic metadata.

    Pixel values are converted with ``RescaleSlope``/``RescaleIntercept``
    (giving HU for CT), rounded to the nearest integer, and clipped to the
    int16 range. Out-of-range values therefore saturate instead of wrapping
    around during the int16 cast.

    Args:
        filepath: Path to a DICOM file.
    Returns:
        ``(pixels, meta)``:
        - ``pixels`` is an int16 array.
        - ``meta`` holds patient ID, modality, series description, slice
          thickness, pixel spacing in mm, and matrix size.
        - Missing or empty rescale / thickness / spacing tags fall back to
          slope 1, intercept 0, thickness 0, and spacing [1.0, 1.0].
    Raises:
        ImportError: if pydicom is not installed.
    """
    try:
        import pydicom
    except ImportError as exc:
        raise ImportError("pydicom required for DICOM. Install: pip install pydicom") from exc

    def _tag_float(dataset, keyword: str, default: float) -> float:
        # A tag can be absent, or present with an empty value (None / "");
        # both fall back to the default instead of crashing float().
        value = getattr(dataset, keyword, None)
        if value is None or value == "":
            return default
        return float(value)

    ds = pydicom.dcmread(filepath)
    pixel_array = ds.pixel_array.astype(np.float64)
    slope = _tag_float(ds, "RescaleSlope", 1.0)
    intercept = _tag_float(ds, "RescaleIntercept", 0.0)
    int16_info = np.iinfo(np.int16)
    scaled = np.rint(pixel_array * slope + intercept)
    hu = np.clip(scaled, int16_info.min, int16_info.max).astype(np.int16)

    spacing = getattr(ds, "PixelSpacing", None) or [1.0, 1.0]
    meta = {
        "patient_id": str(getattr(ds, "PatientID", "")),
        "modality": str(getattr(ds, "Modality", "")),
        "series_description": str(getattr(ds, "SeriesDescription", "")),
        "slice_thickness": _tag_float(ds, "SliceThickness", 0.0),
        "pixel_spacing_mm": [float(s) for s in spacing],
        "rows": int(ds.Rows),
        "columns": int(ds.Columns),
    }
    return hu, meta
# ===========================================================================
# CORE SEGMENTATION PRIMITIVES
# ===========================================================================
def create_roi_mask(blurred: np.ndarray, threshold: int = 15) -> np.ndarray:
    """
    Build a foreground mask by discarding the dark background.

    Pixels brighter than ``threshold`` are kept (set to 255). A 10x10
    morphological closing, applied 3 times, then bridges small gaps and holes.
    """
    square = np.ones((10, 10), dtype=np.uint8)
    foreground = cv2.threshold(blurred, threshold, 255, cv2.THRESH_BINARY)[1]
    return cv2.morphologyEx(foreground, cv2.MORPH_CLOSE, square, iterations=3)
def morphological_cleanup(
    mask: np.ndarray,
    kernel_size: int = 5,
    close_iter: int = 2,
    open_iter: int = 2,
) -> np.ndarray:
    """
    Smooth a binary mask with a square structuring element.

    First a closing fills small gaps (``close_iter`` times). Then an opening
    removes small specks (``open_iter`` times).
    """
    square = np.ones((kernel_size,) * 2, dtype=np.uint8)
    closed = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, square, iterations=close_iter)
    return cv2.morphologyEx(closed, cv2.MORPH_OPEN, square, iterations=open_iter)
def filter_by_area(mask: np.ndarray, min_area: int = 300) -> np.ndarray:
    """
    Keep only outer regions whose contour area exceeds ``min_area``.

    Only external contours are traced, and survivors are redrawn filled
    (value 255). Holes inside a kept region are therefore filled in the output.
    """
    outer_contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    kept = [c for c in outer_contours if cv2.contourArea(c) > min_area]
    result = np.zeros_like(mask)
    for contour in kept:
        cv2.drawContours(result, [contour], -1, 255, -1)
    return result
def segment_adaptive(
    gray: np.ndarray,
    block_size: int = 51,
    C: int = 10,
    roi_mask: Optional[np.ndarray] = None,
) -> np.ndarray:
    """
    Gaussian-weighted adaptive threshold, tolerant of slow intensity drift
    (field inhomogeneity).

    Args:
        gray: Single-channel image.
        block_size: Neighbourhood size. An even value is bumped to the next
            odd number, as OpenCV requires.
        C: Constant subtracted from the local weighted mean.
        roi_mask: Optional mask; the result is ANDed with it.
    """
    odd_block = block_size + 1 if block_size % 2 == 0 else block_size
    binary = cv2.adaptiveThreshold(
        gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, odd_block, C
    )
    return binary if roi_mask is None else cv2.bitwise_and(binary, roi_mask)
def segment_otsu(
    gray: np.ndarray,
    roi_mask: Optional[np.ndarray] = None,
) -> Tuple[np.ndarray, float]:
    """
    Global Otsu threshold on a 5x5-Gaussian-smoothed image.

    Returns:
        ``(mask, threshold_value)``. The mask is ANDed with ``roi_mask`` when one is given.
    """
    smoothed = cv2.GaussianBlur(gray, (5, 5), 0)
    otsu_value, binary = cv2.threshold(smoothed, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)
    if roi_mask is not None:
        binary = cv2.bitwise_and(binary, roi_mask)
    return binary, float(otsu_value)
def region_growing(
    gray: np.ndarray,
    seed: Tuple[int, int],
    tolerance: int = 15,
    roi_mask: Optional[np.ndarray] = None,
) -> np.ndarray:
    """
    Region growing from a seed point within an intensity tolerance.

    Starting at ``seed`` (given as ``(x, y)``), the region collects every
    8-connected pixel that meets both conditions:
    - its intensity lies in ``[seed_value - tolerance, seed_value + tolerance]``,
      with the bounds clamped to 0..255;
    - it is inside ``roi_mask`` (non-zero), when a mask is given.

    Returns:
        uint8 mask (0 or 255) of the grown region.
    Raises:
        IndexError: if the seed lies outside the image. Negative coordinates
            are rejected explicitly; NumPy would otherwise wrap them to the
            opposite edge and grow a region from the wrong pixel.
    """
    h, w = gray.shape[:2]
    sx, sy = seed
    if not (0 <= sx < w and 0 <= sy < h):
        raise IndexError(f"Seed {seed!r} is outside the {w}x{h} image")

    seed_val = int(gray[sy, sx])
    low, high = max(0, seed_val - tolerance), min(255, seed_val + tolerance)
    visited = np.zeros((h, w), dtype=np.uint8)
    mask = np.zeros((h, w), dtype=np.uint8)
    neighbors = [(-1, -1), (-1, 0), (-1, 1), (0, -1),
                 (0, 1), (1, -1), (1, 0), (1, 1)]
    # Pixels are marked visited when pushed, so each is examined at most once.
    stack = [(sx, sy)]
    visited[sy, sx] = 1
    while stack:
        cx, cy = stack.pop()
        val = int(gray[cy, cx])
        if low <= val <= high:
            if roi_mask is not None and roi_mask[cy, cx] == 0:
                continue  # in range but outside ROI: neither kept nor expanded
            mask[cy, cx] = 255
            for dx, dy in neighbors:
                nx, ny = cx + dx, cy + dy
                if 0 <= nx < w and 0 <= ny < h and visited[ny, nx] == 0:
                    visited[ny, nx] = 1
                    stack.append((nx, ny))
    return mask
def watershed_segment(
    gray: np.ndarray,
    roi_mask: Optional[np.ndarray] = None,
    min_distance: int = 20,
    threshold_ratio: float = 0.5,
) -> Tuple[np.ndarray, int]:
    """
    Marker-based watershed on a grayscale image.

    Steps:
    - Sure foreground: distance-transform values above
      ``threshold_ratio * max``.
    - Sure background: ``roi_mask`` dilated 3x with a 3x3 kernel.
    - Markers: the sure-foreground components, shifted by +1; the uncertain
      band between the two sets is 0.

    When ``roi_mask`` is omitted, an Otsu threshold of ``gray`` is used.

    NOTE(review): ``min_distance`` is accepted but not used by this implementation.

    Returns:
        ``(label_image, num_labels)``, where ``label_image`` is the marker array
        returned by ``cv2.watershed`` and ``num_labels`` is the component count
        from ``cv2.connectedComponents`` (which includes label 0).
    """
    if roi_mask is None:
        roi_mask = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1]

    distance = cv2.distanceTransform(roi_mask, cv2.DIST_L2, 5)
    sure_fg = cv2.threshold(
        distance, threshold_ratio * distance.max(), 255, cv2.THRESH_BINARY
    )[1].astype(np.uint8)
    sure_bg = cv2.dilate(roi_mask, np.ones((3, 3), np.uint8), iterations=3)
    uncertain = cv2.subtract(sure_bg, sure_fg)

    n_labels, markers = cv2.connectedComponents(sure_fg)
    markers = markers + 1          # reserve 0 for "unknown"
    markers[uncertain == 255] = 0
    markers = cv2.watershed(cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR), markers)
    return markers, n_labels
def detect_edges_canny(
    gray: np.ndarray, low: int = 50, high: int = 150,
    roi_mask: Optional[np.ndarray] = None,
) -> np.ndarray:
    """Canny edges (hysteresis ``low``/``high``) on a 5x5-Gaussian-smoothed image, optionally ANDed with ``roi_mask``."""
    edge_map = cv2.Canny(cv2.GaussianBlur(gray, (5, 5), 0), low, high)
    return edge_map if roi_mask is None else cv2.bitwise_and(edge_map, roi_mask)
def detect_edges_sobel(
    gray: np.ndarray, ksize: int = 3,
    roi_mask: Optional[np.ndarray] = None,
) -> np.ndarray:
    """
    Gradient-magnitude edge map from Sobel derivatives.

    Steps:
    - Smooth the image with a 5x5 Gaussian.
    - Take the x/y Sobel derivatives in float64.
    - Rescale the magnitude so its maximum becomes 255 (an all-zero gradient
      stays zero), then cast to uint8.
    - Zero pixels outside ``roi_mask`` when a mask is given.
    """
    smoothed = cv2.GaussianBlur(gray, (5, 5), 0)
    grad_x = cv2.Sobel(smoothed, cv2.CV_64F, 1, 0, ksize=ksize)
    grad_y = cv2.Sobel(smoothed, cv2.CV_64F, 0, 1, ksize=ksize)
    magnitude = np.sqrt(grad_x ** 2 + grad_y ** 2)
    peak = magnitude.max()
    if peak > 0:
        magnitude = magnitude / peak * 255
    edge_map = magnitude.astype(np.uint8)
    if roi_mask is None:
        return edge_map
    return cv2.bitwise_and(edge_map, edge_map, mask=roi_mask)
# ===========================================================================
# VENTRICLE SEGMENTATION
# ===========================================================================
def segment_ventricles(
    gray: np.ndarray,
    modality: Modality,
    roi_mask: Optional[np.ndarray] = None,
    custom_thresholds: Optional[Dict] = None,
) -> np.ndarray:
    """
    Segment ventricular CSF on any supported modality.

    The CSF intensity window comes from ``VENTRICLE_THRESHOLDS``. The window
    itself encodes whether CSF is dark (low range: CT/T1/FLAIR) or bright
    (high range: T2), so a single inclusive in-range test serves every modality.

    Args:
        gray: Grayscale image (uint8)
        modality: Imaging modality
        roi_mask: Optional brain ROI mask; built with ``create_roi_mask`` when omitted
        custom_thresholds: Optional dict overriding 'csf_low' and/or 'csf_high'.
            Keys that are not supplied keep the modality default.
    Returns:
        Binary ventricle mask (uint8, 0 or 255)
    Raises:
        KeyError: if ``modality`` has no default thresholds.
    """
    blurred = cv2.GaussianBlur(gray, (5, 5), 0)
    if roi_mask is None:
        roi_mask = create_roi_mask(blurred, threshold=15)

    # Start from the modality defaults so a partial override (e.g. only
    # 'csf_high') works instead of raising KeyError.
    thresh = dict(VENTRICLE_THRESHOLDS[modality])
    if custom_thresholds:
        thresh.update(custom_thresholds)

    mask = cv2.inRange(blurred, thresh["csf_low"], thresh["csf_high"])
    mask = cv2.bitwise_and(mask, roi_mask)
    mask = morphological_cleanup(mask, kernel_size=5, close_iter=3, open_iter=2)
    return filter_by_area(mask, min_area=300)
def segment_skull(
    gray: np.ndarray,
    threshold: int = 200,
) -> np.ndarray:
    """
    Segment bright bone (skull) for skull-diameter measurement.

    Intended for CT, where bone is bright, or any image with a visible skull.

    Steps:
    - Keep pixels above ``threshold``.
    - Smooth with a 7x7 closing (3x) followed by an opening (1x).
    - Drop regions whose contour area is <= 1000 px.

    Because ``filter_by_area`` redraws survivors as filled external contours,
    a closed skull ring comes back filled (outer outline), not as a thin ring.
    """
    thresholded = cv2.threshold(gray, threshold, 255, cv2.THRESH_BINARY)[1]
    return filter_by_area(
        morphological_cleanup(thresholded, kernel_size=7, close_iter=3, open_iter=1),
        min_area=1000,
    )
# ===========================================================================
# NPH BIOMARKER COMPUTATIONS
# ===========================================================================
def compute_evans_index(
    ventricle_mask: np.ndarray,
    skull_mask: Optional[np.ndarray] = None,
    image_width: Optional[int] = None,
    pixel_spacing_mm: Optional[float] = None,
) -> Dict:
    """
    Compute Evans' Index from ventricle and skull masks.

    Frontal-horn width is the widest left-to-right ventricle extent over all
    rows; the first row reaching the maximum wins. The skull diameter is
    measured on that same row:

    - If the skull mask shows separate bone runs on that row (a ring), the
      inner-table distance is used: from the end of the first run to the
      start of the last run.
    - If the row is one continuous run (e.g. a filled skull mask), the outer
      extent of that run is used.
    - If there is no usable skull mask, ``image_width`` (or the mask width)
      is a proxy, which is less accurate and underestimates the index.

    Args:
        ventricle_mask: Binary ventricle mask (2-D, non-zero = ventricle)
        skull_mask: Optional binary skull mask of the same shape
        image_width: Fallback for skull diameter (image pixel width)
        pixel_spacing_mm: Optional pixel spacing for mm conversion
    Returns:
        Dict with the following keys:
        - 'evans_index': rounded to 4 decimals.
        - 'frontal_horn_width_px': a Python int.
        - 'skull_diameter_px'.
        - 'measurement_row'.
        - 'frontal_horn_width_mm' and 'skull_diameter_mm': only when
          ``pixel_spacing_mm`` is given.
    """
    h, w = ventricle_mask.shape[:2]
    fallback_diameter = image_width or w

    # Row-wise leftmost / rightmost ventricle column, vectorised over rows.
    max_row = 0
    max_frontal_width = 0
    occupied = ventricle_mask > 0
    if occupied.size:
        has_csf = occupied.any(axis=1)
        first_col = occupied.argmax(axis=1)
        last_col = w - 1 - occupied[:, ::-1].argmax(axis=1)
        widths = np.where(has_csf, last_col - first_col, 0)
        # argmax returns the first row achieving the maximum (0 if all zero).
        max_row = int(widths.argmax())
        max_frontal_width = int(widths[max_row])

    skull_diameter = fallback_diameter
    if skull_mask is not None:
        skull_cols = np.flatnonzero(skull_mask[max_row, :] > 0)
        if skull_cols.size > 1:
            # Indices where consecutive bone pixels are separated by a gap.
            gaps = np.flatnonzero(np.diff(skull_cols) > 1)
            if gaps.size:
                left_inner = skull_cols[gaps[0]]        # last pixel of first run
                right_inner = skull_cols[gaps[-1] + 1]  # first pixel of last run
                skull_diameter = int(right_inner - left_inner)
            else:
                skull_diameter = int(skull_cols[-1] - skull_cols[0])
    if skull_diameter == 0:
        skull_diameter = w

    evans_index = max_frontal_width / skull_diameter if skull_diameter else 0.0
    result = {
        "evans_index": round(evans_index, 4),
        "frontal_horn_width_px": max_frontal_width,
        "skull_diameter_px": skull_diameter,
        "measurement_row": max_row,
    }
    if pixel_spacing_mm is not None:
        result["frontal_horn_width_mm"] = round(max_frontal_width * pixel_spacing_mm, 2)
        result["skull_diameter_mm"] = round(skull_diameter * pixel_spacing_mm, 2)
    return result
def compute_callosal_angle(
    ventricle_mask: np.ndarray,
    midline_half_width: int = 20,
) -> Dict:
    """
    Estimate callosal angle from a coronal ventricle mask.

    Heuristic:
    - Apex: the topmost ventricle pixel within ``midline_half_width`` columns
      of the image centre; its x is the mean column of that row inside the band.
    - Left/right points: the topmost ventricle pixel in each image half; x is
      the mean column of that row within the half.
    - Angle: measured at the apex, between the vectors to the two side points.

    Args:
        ventricle_mask: Binary ventricle mask on a coronal slice (uint8)
        midline_half_width: Half-width in pixels of the central band searched
            for the apex (default 20).
    Returns:
        On success, a dict with 'callosal_angle_deg', 'apex_point',
        'left_point' and 'right_point'; points are ``(x, y)`` int tuples.
        When the geometry cannot be established, returns
        ``{'callosal_angle_deg': None, 'error': ...}``.
    """
    h, w = ventricle_mask.shape[:2]
    midline_x = w // 2

    rows_with_csf = np.where(ventricle_mask.any(axis=1))[0]
    if len(rows_with_csf) == 0:
        return {"callosal_angle_deg": None, "error": "No ventricles detected"}
    top_row = int(rows_with_csf[0])

    # Apex: topmost point near the midline.
    band_start = max(0, midline_x - midline_half_width)
    band_stop = min(w, midline_x + midline_half_width)
    midline_band = ventricle_mask[:, band_start:band_stop]
    midline_rows = np.where(midline_band.any(axis=1))[0]
    if len(midline_rows) == 0:
        # No midline CSF -- use topmost row
        apex_y = top_row
        apex_x = midline_x
    else:
        apex_y = int(midline_rows[0])
        apex_col_in_band = np.where(midline_band[apex_y, :] > 0)[0]
        # Offset by the band's real start: it is clamped at 0 on narrow images.
        apex_x = band_start + int(np.mean(apex_col_in_band))

    # Topmost point of the left ventricle (left of midline)
    left_mask = ventricle_mask[:, :midline_x]
    left_rows = np.where(left_mask.any(axis=1))[0]
    if len(left_rows) == 0:
        return {"callosal_angle_deg": None, "error": "Left ventricle not detected"}
    left_y = int(left_rows[0])
    left_x = int(np.mean(np.where(left_mask[left_y, :] > 0)[0]))

    # Topmost point of the right ventricle (right of midline)
    right_mask = ventricle_mask[:, midline_x:]
    right_rows = np.where(right_mask.any(axis=1))[0]
    if len(right_rows) == 0:
        return {"callosal_angle_deg": None, "error": "Right ventricle not detected"}
    right_y = int(right_rows[0])
    right_x = midline_x + int(np.mean(np.where(right_mask[right_y, :] > 0)[0]))

    # Angle at the apex between the two lines
    vec_left = np.array([left_x - apex_x, left_y - apex_y], dtype=float)
    vec_right = np.array([right_x - apex_x, right_y - apex_y], dtype=float)
    norm_left = np.linalg.norm(vec_left)
    norm_right = np.linalg.norm(vec_right)
    if norm_left == 0 or norm_right == 0:
        return {"callosal_angle_deg": None, "error": "Degenerate geometry"}
    cos_angle = np.clip(np.dot(vec_left, vec_right) / (norm_left * norm_right), -1.0, 1.0)
    angle_deg = float(np.degrees(np.arccos(cos_angle)))
    return {
        "callosal_angle_deg": round(angle_deg, 1),
        "apex_point": (apex_x, apex_y),
        "left_point": (left_x, left_y),
        "right_point": (right_x, right_y),
    }
def compute_temporal_horn_width(
    ventricle_mask: np.ndarray,
    pixel_spacing_mm: Optional[float] = None,
) -> Dict:
    """
    Estimate temporal horn width from an axial ventricle mask.

    Method:
    - Scan rows from 55 % to 85 % of the image height (taken as the
      approximate temporal horn level).
    - For each row, split the ventricle pixels into left and right image
      halves.
    - Take each half's span (last column - first column).
    - Keep spans with 3 < span < width // 4 and report the largest (0 if none).

    Returns:
        Dict with 'temporal_horn_width_px' (and '_mm' if spacing given)
    """
    height, width = ventricle_mask.shape[:2]
    half = width // 2
    band = ventricle_mask[int(height * 0.55):int(height * 0.85), :]

    spans = []
    for row_pixels in band:
        hits = np.flatnonzero(row_pixels > 0)
        for side in (hits[hits < half], hits[hits >= half]):
            if side.size:
                spans.append(side[-1] - side[0])
    max_width = max((s for s in spans if 3 < s < width // 4), default=0)

    result = {"temporal_horn_width_px": max_width}
    if pixel_spacing_mm is not None:
        result["temporal_horn_width_mm"] = round(max_width * pixel_spacing_mm, 2)
    return result
def compute_third_ventricle_width(
    ventricle_mask: np.ndarray,
    pixel_spacing_mm: Optional[float] = None,
) -> Dict:
    """
    Estimate third ventricle width from an axial ventricle mask.

    Looks only at a central band of +/-30 columns around the image midline.
    For each row it takes the span between the first and last ventricle pixel
    in that band, keeps spans with 2 < span < 60, and reports the largest
    (0 if none).

    Returns:
        Dict with 'third_ventricle_width_px' (and '_mm' if spacing given)
    """
    height, width = ventricle_mask.shape[:2]
    centre = width // 2
    band = ventricle_mask[:, max(0, centre - 30):min(width, centre + 30)]

    row_hits = (np.flatnonzero(row_pixels > 0) for row_pixels in band)
    spans = (hits[-1] - hits[0] for hits in row_hits if hits.size)
    max_width = max((s for s in spans if 2 < s < 60), default=0)

    result = {"third_ventricle_width_px": max_width}
    if pixel_spacing_mm is not None:
        result["third_ventricle_width_mm"] = round(max_width * pixel_spacing_mm, 2)
    return result
def score_pvh(
    flair_gray: np.ndarray,
    ventricle_mask: np.ndarray,
    dilation_px: int = 15,
    pvh_threshold: int = PVH_THRESHOLD,
) -> Dict:
    """
    Score periventricular hyperintensity (PVH) on FLAIR.

    Method:
    - The periventricular zone is the ring obtained by dilating the ventricle
      mask once with a ``dilation_px`` square and subtracting the original mask.
    - PVH pixels are zone pixels with intensity in ``[pvh_threshold, 255]``.
    - The ratio ``PVH area / zone area`` is graded (Fazekas-like):
      - ratio < 0.05 → grade 0 (none);
      - ratio < 0.15 → grade 1 (pencil-thin rim);
      - ratio < 0.35 → grade 2 (smooth halo);
      - otherwise → grade 3 (irregular).

    Args:
        flair_gray: FLAIR grayscale image (uint8)
        ventricle_mask: Binary ventricle mask
        dilation_px: Size of periventricular zone in pixels
        pvh_threshold: Intensity threshold for hyperintensity
    Returns:
        Dict with 'pvh_grade' (0-3), 'pvh_ratio', 'pvh_area_px',
        'periventricular_zone_area_px' and the 'pvh_mask' array.
    """
    square = np.ones((dilation_px, dilation_px), np.uint8)
    grown = cv2.dilate(ventricle_mask, square, iterations=1)
    zone = cv2.subtract(grown, ventricle_mask)

    bright = cv2.inRange(flair_gray, pvh_threshold, 255)
    pvh_in_zone = cv2.bitwise_and(bright, zone)

    zone_area = (zone > 0).sum()
    pvh_area = (pvh_in_zone > 0).sum()
    pvh_ratio = pvh_area / zone_area if zone_area > 0 else 0.0

    # Grade = number of cut-offs the ratio reaches.
    grade = int(sum(pvh_ratio >= cutoff for cutoff in (0.05, 0.15, 0.35)))

    return {
        "pvh_grade": grade,
        "pvh_ratio": round(pvh_ratio, 4),
        "pvh_area_px": int(pvh_area),
        "periventricular_zone_area_px": int(zone_area),
        "pvh_mask": pvh_in_zone,
    }
def assess_desh(
    ventricle_mask: np.ndarray,
    gray: np.ndarray,
    roi_mask: np.ndarray,
    modality: Modality,
    pixel_spacing_mm: Optional[float] = None,
) -> Dict:
    """
    Assess DESH (Disproportionately Enlarged Subarachnoid-space Hydrocephalus) pattern.

    Works on a single slice using fixed image regions:
    - sulcal CSF: CSF-range pixels inside ``roi_mask`` minus ``ventricle_mask``;
    - sylvian region: rows 35-65 % of the height, outer 30 % of the width on each side;
    - high-convexity region: top 25 % of the rows.

    Component scores (0-2 each):
    - ventriculomegaly: from Evans' index, with the image width as the
      skull-diameter proxy;
    - sylvian dilation: from the sylvian/convexity sulcal-CSF area ratio;
    - convexity tightness: from the sulcal-CSF fraction of brain in the top region.

    DESH-positive requires all of: ventriculomegaly score >= 1 (EI >= 0.3),
    sylvian score 2, and convexity score 2. ``total_score`` is informational
    only. If the top region contains no brain, tightness cannot be judged:
    it is scored 0 and ``convexity_assessable`` is False.

    Args:
        ventricle_mask: Binary ventricle mask
        gray: Grayscale image
        roi_mask: Brain ROI mask
        modality: Imaging modality
        pixel_spacing_mm: Optional pixel spacing
    Returns:
        Dict with DESH component scores, overall assessment, areas,
        'convexity_assessable', and the 'sylvian_mask' / 'convexity_mask' arrays.
    """
    h, w = gray.shape[:2]

    # Evans' index component
    ei_data = compute_evans_index(ventricle_mask, image_width=w, pixel_spacing_mm=pixel_spacing_mm)
    ei = ei_data["evans_index"]
    if ei < 0.3:
        vm_score = 0
    elif ei <= 0.33:
        vm_score = 1
    else:
        vm_score = 2

    # All CSF (sulcal + ventricular), then remove ventricles -> sulcal CSF
    thresh = VENTRICLE_THRESHOLDS[modality]
    blurred = cv2.GaussianBlur(gray, (5, 5), 0)
    all_csf = cv2.inRange(blurred, thresh["csf_low"], thresh["csf_high"])
    all_csf = cv2.bitwise_and(all_csf, roi_mask)
    sulcal_csf = cv2.subtract(all_csf, ventricle_mask)
    sulcal_csf = morphological_cleanup(sulcal_csf, kernel_size=3, close_iter=1, open_iter=1)

    # Sylvian fissure region: middle third vertically, outer 30 % on each side
    sylvian_region = np.zeros_like(gray, dtype=np.uint8)
    y_start, y_end = int(h * 0.35), int(h * 0.65)
    sylvian_region[y_start:y_end, :int(w * 0.3)] = 255
    sylvian_region[y_start:y_end, int(w * 0.7):] = 255
    sylvian_csf = cv2.bitwise_and(sulcal_csf, sylvian_region)
    sylvian_csf_area = (sylvian_csf > 0).sum()

    # High convexity region (top 25% of image)
    convexity_region = np.zeros_like(gray, dtype=np.uint8)
    convexity_region[:int(h * 0.25), :] = 255
    convexity_csf = cv2.bitwise_and(sulcal_csf, convexity_region)
    convexity_csf_area = (convexity_csf > 0).sum()

    # Sylvian/convexity ratio
    if convexity_csf_area > 0:
        ratio = sylvian_csf_area / convexity_csf_area
    else:
        ratio = float("inf") if sylvian_csf_area > 0 else 0.0
    if ratio < 1.5:
        sylvian_score = 0  # Proportionate
    elif ratio < 3.0:
        sylvian_score = 1  # Mildly disproportionate
    else:
        sylvian_score = 2  # Markedly disproportionate (DESH pattern)

    # Convexity tightness
    brain_top_area = (roi_mask[:int(h * 0.25), :] > 0).sum()
    convexity_assessable = bool(brain_top_area > 0)
    convexity_ratio = convexity_csf_area / brain_top_area if convexity_assessable else 0
    if not convexity_assessable:
        # No brain in the top region: absence of sulcal CSF here is not
        # evidence of effacement, so do not score it as tight.
        convexity_score = 0
    elif convexity_ratio > 0.1:
        convexity_score = 0  # Normal sulci
    elif convexity_ratio > 0.04:
        convexity_score = 1  # Mildly tight
    else:
        convexity_score = 2  # Effaced (tight high convexity)

    # DESH positive: ventriculomegaly (score >= 1) plus maximal sylvian and
    # convexity scores. The total score is reported but not used as a criterion.
    total_score = vm_score + sylvian_score + convexity_score
    is_desh_positive = (vm_score >= 1 and sylvian_score >= 2 and convexity_score >= 2)
    return {
        "is_desh_positive": is_desh_positive,
        "total_score": total_score,
        "ventriculomegaly_score": vm_score,
        "sylvian_dilation_score": sylvian_score,
        "convexity_tightness_score": convexity_score,
        "convexity_assessable": convexity_assessable,
        "evans_index": ei,
        "sylvian_convexity_ratio": round(ratio, 2) if ratio != float("inf") else "inf",
        "sylvian_csf_area_px": int(sylvian_csf_area),
        "convexity_csf_area_px": int(convexity_csf_area),
        "sylvian_mask": sylvian_csf,
        "convexity_mask": convexity_csf,
    }
# ===========================================================================
# FOUNDATION MODEL WRAPPERS
# ===========================================================================
def sam_segment(
    image: np.ndarray,
    points: List[Tuple[int, int]],
    labels: List[int],
    checkpoint: str = "sam_vit_h.pth",
    model_type: str = "vit_h",
) -> np.ndarray:
    """
    Segment with Segment Anything (SAM) from point prompts.

    The model is built from ``checkpoint`` on every call. SAM is asked for
    multiple candidate masks, and the candidate with the highest predicted
    score is returned.

    Args:
        image: Image handed to ``SamPredictor.set_image``
            (NOTE(review): presumably HxWx3 RGB uint8 -- confirm against SAM docs).
        points: Prompt coordinates.
        labels: One prompt label per point (SAM convention: 1 foreground, 0 background).
        checkpoint: Path to the SAM weights.
        model_type: Key into ``sam_model_registry``.
    Returns:
        uint8 mask with values 0 or 255.
    Raises:
        ImportError: if segment-anything is not installed.
    """
    try:
        from segment_anything import SamPredictor, sam_model_registry
    except ImportError:
        raise ImportError("segment-anything required. See: https://github.com/facebookresearch/segment-anything")
    predictor = SamPredictor(sam_model_registry[model_type](checkpoint=checkpoint))
    predictor.set_image(image)
    candidate_masks, candidate_scores, _ = predictor.predict(
        point_coords=np.array(points),
        point_labels=np.array(labels),
        multimask_output=True,
    )
    best = np.argmax(candidate_scores)
    return candidate_masks[best].astype(np.uint8) * 255
def medsam_segment(
    image: np.ndarray,
    bbox: Tuple[int, int, int, int],
    checkpoint: Optional[str] = None,
) -> np.ndarray:
    """
    Segment with MedSAM from a bounding-box prompt.

    Args:
        image: Image forwarded unchanged to ``MedSAMPredictor.predict``.
        bbox: Box prompt forwarded unchanged
            (NOTE(review): confirm the coordinate order MedSAM expects).
        checkpoint: Optional model checkpoint path.
    Returns:
        uint8 mask scaled by 255. Assumes the predictor returns a 0/1 or
        boolean mask -- TODO confirm.
    Raises:
        ImportError: if ``medsam.MedSAMPredictor`` cannot be imported
            (NOTE(review): verify this import path exists in the MedSAM project).
    """
    try:
        from medsam import MedSAMPredictor
    except ImportError:
        raise ImportError("MedSAM required. See: https://github.com/bowang-lab/MedSAM")
    raw_mask = MedSAMPredictor(checkpoint=checkpoint).predict(image, bbox)
    return raw_mask.astype(np.uint8) * 255
# ===========================================================================
# VISUALIZATION
# ===========================================================================
def create_overlay(
    img_rgb: np.ndarray,
    masks: Dict[str, np.ndarray],
    alpha: float = 0.45,
    draw_contours: bool = True,
) -> np.ndarray:
    """
    Blend coloured masks onto an RGB image.

    Steps:
    - Paint each mask's pixels with its ``COLORS`` entry (grey if the name is
      unknown); later masks overwrite earlier ones where they overlap.
    - Alpha-blend painted pixels with the original (``alpha`` = colour weight);
      pixels covered by no mask are copied unchanged.
    - Optionally draw 2 px outlines of each mask's external contours whose
      area exceeds 100 px.
    """
    grey = (200, 200, 200)
    painted = img_rgb.copy()
    covered = np.zeros(img_rgb.shape[:2], dtype=np.uint8)
    for name, mask in masks.items():
        painted[mask > 0] = COLORS.get(name, grey)
        covered = cv2.bitwise_or(covered, mask)

    blended = (alpha * painted + (1 - alpha) * img_rgb).astype(np.uint8)
    result = np.where((covered > 0)[..., None], blended, img_rgb)

    if draw_contours:
        for name, mask in masks.items():
            outline_color = COLORS.get(name, grey)
            outer, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            for contour in outer:
                if cv2.contourArea(contour) > 100:
                    cv2.drawContours(result, [contour], -1, outline_color, 2)
    return result
def add_annotations(
    image: np.ndarray,
    masks: Dict[str, np.ndarray],
    title: str = "NPH Segmentation",
    biomarkers: Optional[Dict] = None,
) -> np.ndarray:
    """
    Add title, legend, and optionally biomarker values to image.

    The legend (bottom-left) lists each key of ``masks`` with its ``COLORS``
    swatch. The biomarker panel (top-right) shows Evans' index, callosal
    angle, temporal horn width, PVH grade and DESH status when present in
    ``biomarkers``. Both panels sit on semi-transparent black backgrounds.

    Args:
        image: RGB image array (e.g. the output of ``create_overlay``).
        masks: Structure name -> mask; only the keys are used, for the legend.
        title: Text drawn near the top centre.
        biomarkers: Optional dict of biomarker values to display.
    Returns:
        A new annotated image array.
    """
    pil_img = Image.fromarray(image)
    # On an RGB image Pillow drops the alpha of 4-tuple fills unless the Draw
    # object is opened in "RGBA" mode, which blends fills into the image.
    # Other image modes keep their native drawing mode.
    draw = ImageDraw.Draw(pil_img, "RGBA" if pil_img.mode == "RGB" else None)
    height, width = image.shape[:2]
    try:
        font_title = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 18)
        font_label = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 13)
        font_small = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 12)
    except Exception:
        # Fonts are cosmetic; fall back to Pillow's built-in bitmap font.
        font_title = font_label = font_small = ImageFont.load_default()

    # Title
    draw.text((width // 2 - 80, 8), title, fill=(255, 255, 255), font=font_title)

    # Legend (bottom-left)
    num_items = len(masks)
    legend_x, legend_y = 12, height - 28 * num_items - 20
    box_height = 28 * num_items + 12
    draw.rectangle(
        [(legend_x, legend_y), (legend_x + 180, legend_y + box_height)],
        fill=(0, 0, 0, 180), outline=(150, 150, 150),
    )
    y = legend_y + 8
    for name in masks.keys():
        color = COLORS.get(name, (200, 200, 200))
        draw.rectangle([(legend_x + 8, y), (legend_x + 24, y + 14)], fill=color, outline=(200, 200, 200))
        draw.text((legend_x + 32, y - 1), name.replace("_", " ").title(), fill=(255, 255, 255), font=font_label)
        y += 24

    # Biomarker panel (top-right)
    if biomarkers:
        bm_x = width - 220
        bm_y = 35
        bm_items = []
        if "evans_index" in biomarkers:
            ei = biomarkers["evans_index"]
            status = "ABNORMAL" if ei > 0.3 else "normal"
            bm_items.append(f"Evans' Index: {ei:.3f} ({status})")
        if "callosal_angle_deg" in biomarkers and biomarkers["callosal_angle_deg"] is not None:
            ca = biomarkers["callosal_angle_deg"]
            bm_items.append(f"Callosal Angle: {ca:.1f} deg")
        if "temporal_horn_width_px" in biomarkers:
            thw = biomarkers["temporal_horn_width_px"]
            bm_items.append(f"Temporal Horn: {thw} px")
        if "pvh_grade" in biomarkers:
            bm_items.append(f"PVH Grade: {biomarkers['pvh_grade']}/3")
        if "is_desh_positive" in biomarkers:
            desh = "POSITIVE" if biomarkers["is_desh_positive"] else "negative"
            bm_items.append(f"DESH: {desh}")
        if bm_items:
            box_h = 20 * len(bm_items) + 12
            draw.rectangle(
                [(bm_x - 5, bm_y - 5), (bm_x + 210, bm_y + box_h)],
                fill=(0, 0, 0, 200), outline=(100, 200, 255),
            )
            for i, text in enumerate(bm_items):
                draw.text((bm_x + 3, bm_y + 3 + i * 20), text, fill=(220, 240, 255), font=font_small)
    return np.array(pil_img)
def create_comparison(
    original: np.ndarray,
    segmented: np.ndarray,
    title: str = "Original vs NPH Segmentation",
) -> np.ndarray:
    """
    Place two same-sized RGB images side by side on a dark canvas with captions.

    Layout:
    - The canvas is ``2 * width + 20`` wide and ``height + 60`` tall.
    - Both images start 55 px from the top, separated by a 20 px gap.
    - The title and the two captions go in the top band.
    """
    h, w = original.shape[:2]
    gap = 20
    top = 55
    canvas_w = 2 * w + gap
    canvas = np.full((h + 60, canvas_w, 3), (25, 25, 30), dtype=np.uint8)
    canvas[top:top + h, :w] = original
    canvas[top:top + h, w + gap:] = segmented

    pil = Image.fromarray(canvas)
    draw = ImageDraw.Draw(pil)
    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 18)
    except Exception:
        font = ImageFont.load_default()
    caption_color = (200, 200, 200)
    draw.text((w // 2 - 30, 20), "Original", fill=caption_color, font=font)
    draw.text((w + gap + w // 2 - 60, 20), "NPH Analysis", fill=caption_color, font=font)
    draw.text((canvas_w // 2 - 100, 2), title, fill=(255, 255, 255), font=font)
    return np.array(pil)
# ===========================================================================
# QUALITY METRICS
# ===========================================================================
def dice_coefficient(pred: np.ndarray, gt: np.ndarray) -> float:
    """
    Dice overlap ``2|P∩G| / (|P| + |G|)`` between two masks.

    Values > 0 count as foreground. Two empty masks score 1.0.
    """
    pred_fg = pred > 0
    gt_fg = gt > 0
    denominator = np.sum(pred_fg) + np.sum(gt_fg)
    if denominator == 0:
        return 1.0
    return 2.0 * np.sum(np.logical_and(pred_fg, gt_fg)) / denominator
def iou_score(pred: np.ndarray, gt: np.ndarray) -> float:
    """
    Intersection over Union (Jaccard index) between two masks.

    Values > 0 count as foreground. Two empty masks score 1.0.
    """
    pred_fg = pred > 0
    gt_fg = gt > 0
    union = np.count_nonzero(pred_fg | gt_fg)
    if union == 0:
        return 1.0
    return float(np.count_nonzero(pred_fg & gt_fg)) / float(union)
# ===========================================================================
# INTERNAL HELPERS
# ===========================================================================
def _extract_contours(masks: Dict[str, np.ndarray], min_area: int = 200) -> Dict[str, List]:
    """
    Return, per mask name, its external contours with area > ``min_area``.

    Contours are converted to nested Python lists via ``ndarray.tolist()``.
    """
    def _significant(mask: np.ndarray) -> List:
        found, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        return [contour.tolist() for contour in found if cv2.contourArea(contour) > min_area]

    return {name: _significant(mask) for name, mask in masks.items()}
# ===========================================================================
# MAIN NPH PIPELINE
# ===========================================================================
def _dicom_pixels_to_gray(pixels: np.ndarray, modality: Modality) -> np.ndarray:
    """
    Convert DICOM pixel data to the 8-bit grayscale image the pipeline expects.

    CT pixels are Hounsfield units, so the standard brain window is applied.
    Other modalities (MR) have scanner-dependent intensities that typically lie
    far above the 0-80 HU brain window and would mostly saturate to white.
    They are instead linearly stretched between the 0.5th and 99.5th
    intensity percentiles.
    """
    if modality == Modality.CT_HEAD:
        center, width_hu = CT_WINDOWS["brain"]
        return apply_ct_window(pixels, center, width_hu)
    data = pixels.astype(np.float64)
    lo, hi = np.percentile(data, [0.5, 99.5])
    if hi <= lo:
        # Flat image: nothing to stretch.
        return np.zeros(data.shape, dtype=np.uint8)
    return ((np.clip(data, lo, hi) - lo) / (hi - lo) * 255.0).astype(np.uint8)


def segment_nph(
    image_path: str,
    modality: str = "CT_HEAD",
    structures: Optional[List[str]] = None,
    output_path: Optional[str] = None,
    pixel_spacing_mm: Optional[float] = None,
    compute_biomarkers: bool = True,
) -> SegmentationResult:
    """
    Main entry point for NPH neuroimaging segmentation and analysis.

    DICOM inputs (``.dcm`` / ``.dicom``) are converted to 8-bit grayscale
    according to ``modality``: CT gets the brain window, and other
    modalities are percentile-stretched. All other files are read with OpenCV.

    Args:
        image_path: Path to input image (DICOM, PNG, JPG)
        modality: 'CT_HEAD', 'T1', 'T1_GD', 'T2', 'FLAIR' (case-insensitive)
        structures: Optional list of structures to segment
        output_path: Optional path to save comparison image
        pixel_spacing_mm: Pixel spacing for real-world measurements; for
            DICOM it defaults to the file's first PixelSpacing value
        compute_biomarkers: Whether to compute Evans' index, etc.
    Returns:
        SegmentationResult with masks, overlay, contours, and metadata
        including NPH biomarkers
    Raises:
        KeyError: if ``modality`` is not a ``Modality`` member name.
    """
    mod = Modality[modality.upper()]

    # Load image
    is_dicom = image_path.lower().endswith((".dcm", ".dicom"))
    dicom_meta: Dict = {}
    hu: Optional[np.ndarray] = None
    if is_dicom:
        hu, dicom_meta = load_dicom(image_path)
        pixel_spacing_mm = pixel_spacing_mm or dicom_meta.get("pixel_spacing_mm", [1.0])[0]
        gray = _dicom_pixels_to_gray(hu, mod)
        img_rgb = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
    else:
        img_rgb, gray, _ = preprocess_image(image_path)
    roi_mask = create_roi_mask(cv2.GaussianBlur(gray, (5, 5), 0), threshold=15)

    # Default structures
    if structures is None:
        if mod == Modality.FLAIR:
            structures = ["ventricles", "pvh", "parenchyma"]
        else:
            structures = ["ventricles", "parenchyma"]

    masks: Dict[str, np.ndarray] = {}

    # Ventricles are always segmented (every biomarker depends on them).
    vent_mask = segment_ventricles(gray, mod, roi_mask)
    masks["ventricles"] = vent_mask

    # Parenchyma (brain tissue excluding ventricles)
    if "parenchyma" in structures:
        masks["parenchyma"] = cv2.bitwise_and(roi_mask, cv2.bitwise_not(vent_mask))

    # PVH on FLAIR
    pvh_data = None
    if "pvh" in structures and mod == Modality.FLAIR:
        pvh_data = score_pvh(gray, vent_mask)
        masks["pvh"] = pvh_data["pvh_mask"]

    # Skull (for Evans' index); HU are only available for DICOM CT.
    skull_mask = None
    if is_dicom and mod == Modality.CT_HEAD:
        bone_gray = apply_ct_window(hu, *CT_WINDOWS["bone"])
        skull_mask = segment_skull(bone_gray, threshold=180)
        masks["skull"] = skull_mask

    # Compute biomarkers
    biomarkers = {}
    if compute_biomarkers:
        biomarkers.update(compute_evans_index(vent_mask, skull_mask, gray.shape[1], pixel_spacing_mm))
        biomarkers.update(compute_temporal_horn_width(vent_mask, pixel_spacing_mm))
        biomarkers.update(compute_third_ventricle_width(vent_mask, pixel_spacing_mm))
        if pvh_data:
            biomarkers["pvh_grade"] = pvh_data["pvh_grade"]
            biomarkers["pvh_ratio"] = pvh_data["pvh_ratio"]
        # DESH assessment
        desh_data = assess_desh(vent_mask, gray, roi_mask, mod, pixel_spacing_mm)
        biomarkers["is_desh_positive"] = desh_data["is_desh_positive"]
        biomarkers["desh_total_score"] = desh_data["total_score"]
        biomarkers["desh_details"] = {
            k: v for k, v in desh_data.items()
            if k not in ("sylvian_mask", "convexity_mask")
        }

    # The skull mask is only a measurement aid; keep it out of the display.
    display_masks = {k: v for k, v in masks.items() if k != "skull"}

    # Visualization
    overlay = create_overlay(img_rgb, display_masks)
    annotated = add_annotations(overlay, display_masks, f"{modality} - NPH Analysis", biomarkers)
    contours = _extract_contours(display_masks)

    metadata = {
        "modality": modality,
        "structures_found": list(display_masks.keys()),
        "image_shape": img_rgb.shape,
        "pixel_spacing_mm": pixel_spacing_mm,
        "dicom_meta": dicom_meta,
    }
    metadata.update(biomarkers)

    result = SegmentationResult(
        masks=display_masks, overlay=annotated, contours=contours, metadata=metadata
    )
    if output_path:
        comparison = create_comparison(img_rgb, annotated, f"{modality} - NPH Analysis")
        Image.fromarray(comparison).save(output_path)
        print(f"Saved: {output_path}")
    return result


# Alias for backward compatibility
segment_brain_image = segment_nph
# ===========================================================================
# CLI
# ===========================================================================
def _default_output_path(image_path: str) -> str:
    """
    Derive the default CLI output path for *image_path*.

    ``_nph_analysis`` is inserted before the file extension only, so dots in
    directory names are left alone. Inputs without an extension, and DICOM
    inputs (which PIL cannot write), get ``.png``. This also ensures the
    input file is never overwritten.
    """
    import os

    stem, ext = os.path.splitext(image_path)
    if ext.lower() in ("", ".dcm", ".dicom"):
        ext = ".png"
    return f"{stem}_nph_analysis{ext}"


if __name__ == "__main__":
    import sys

    if len(sys.argv) < 2:
        print("Usage: python segment_neuroimaging.py <image_path> [modality] [output_path]")
        print(" modality: CT_HEAD, T1, T1_GD, T2, FLAIR (default: CT_HEAD)")
        sys.exit(1)

    image_path = sys.argv[1]
    modality = sys.argv[2] if len(sys.argv) > 2 else "CT_HEAD"
    output_path = sys.argv[3] if len(sys.argv) > 3 else _default_output_path(image_path)

    result = segment_nph(image_path, modality, output_path=output_path)

    print("\n--- NPH Analysis Results ---")
    print(f"Structures: {result.metadata['structures_found']}")
    ei = result.metadata.get("evans_index")
    if ei is not None:
        print(f"Evans' Index: {ei:.3f} ({'ABNORMAL (>0.3)' if ei > 0.3 else 'Normal'})")
    thw = result.metadata.get("temporal_horn_width_px")
    if thw is not None:
        print(f"Temporal Horn Width: {thw} px")
    pvh = result.metadata.get("pvh_grade")
    if pvh is not None:
        print(f"PVH Grade: {pvh}/3")
    desh = result.metadata.get("is_desh_positive")
    if desh is not None:
        print(f"DESH Pattern: {'POSITIVE' if desh else 'negative'}")