""" Visualization Utilities Produces publication-quality overlays, highlights, and detection visualizations for the SpotSeg pipeline. """ import numpy as np from PIL import Image, ImageDraw, ImageFont, ImageFilter import colorsys def hex_to_rgb(hex_color: str) -> tuple: """Convert hex color string to (R, G, B) tuple.""" hex_color = hex_color.lstrip("#") return tuple(int(hex_color[i : i + 2], 16) for i in (0, 2, 4)) def get_label_colors(n: int, base_hex: str = "#4fd1c5") -> list: """ Generate n visually distinct colors starting from a base hue. Returns list of (R, G, B) tuples. """ base_rgb = hex_to_rgb(base_hex) base_hsv = colorsys.rgb_to_hsv( base_rgb[0] / 255, base_rgb[1] / 255, base_rgb[2] / 255 ) colors = [] for i in range(n): hue = (base_hsv[0] + i * 0.618033988749895) % 1.0 # golden ratio sat = 0.7 + (i % 3) * 0.1 val = 0.85 r, g, b = colorsys.hsv_to_rgb(hue, sat, val) colors.append((int(r * 255), int(g * 255), int(b * 255))) return colors def create_highlight_overlay( image: Image.Image, mask: np.ndarray, highlight_color: str = "#4fd1c5", alpha: float = 0.45, ) -> Image.Image: """ Overlay a colored, semi-transparent highlight on the detected object. Areas outside the mask are slightly dimmed for emphasis. Args: image: Original PIL Image (RGB) mask: Float32 array [0, 1] same size as image highlight_color: Hex color for the highlight alpha: Opacity of the highlight overlay Returns: PIL Image with highlight applied """ img_array = np.array(image).astype(np.float32) color = hex_to_rgb(highlight_color) # Expand mask to 3 channels mask_3d = np.stack([mask] * 3, axis=-1) # Create highlight layer highlight = np.full_like(img_array, color, dtype=np.float32) # Dim background slightly dimmed = img_array * 0.4 # Blend: highlighted object + dimmed background result = np.where( mask_3d > 0, img_array * (1 - alpha) + highlight * alpha, dimmed, ) # Add subtle glow at mask edges from PIL import ImageFilter as IF mask_pil = Image.fromarray((mask * 255).astype(np.uint8)) edge_glow = mask_pil.filter(IF.GaussianBlur(radius=8)) edge_glow = np.array(edge_glow).astype(np.float32) / 255.0 edge_only = np.clip(edge_glow - mask, 0, 1) edge_3d = np.stack([edge_only] * 3, axis=-1) glow = np.full_like(img_array, color, dtype=np.float32) result = result + edge_3d * glow * 0.3 result = np.clip(result, 0, 255).astype(np.uint8) return Image.fromarray(result) def create_blur_background( image: Image.Image, mask: np.ndarray, blur_radius: int = 25, ) -> Image.Image: """ Keep the detected object sharp and blur the background. Creates a portrait-mode / bokeh effect. Args: image: Original PIL Image mask: Float32 array [0, 1] blur_radius: Gaussian blur strength for background Returns: PIL Image with blurred background """ blurred = image.filter(ImageFilter.GaussianBlur(radius=blur_radius)) img_array = np.array(image).astype(np.float32) blur_array = np.array(blurred).astype(np.float32) # Smooth the mask edges for a natural transition mask_pil = Image.fromarray((mask * 255).astype(np.uint8)) mask_smooth = mask_pil.filter(ImageFilter.GaussianBlur(radius=3)) mask_smooth = np.array(mask_smooth).astype(np.float32) / 255.0 mask_3d = np.stack([mask_smooth] * 3, axis=-1) result = img_array * mask_3d + blur_array * (1 - mask_3d) result = np.clip(result, 0, 255).astype(np.uint8) return Image.fromarray(result) def create_contour_outline( image: Image.Image, mask: np.ndarray, outline_color: str = "#4fd1c5", thickness: int = 3, ) -> Image.Image: """ Draw a glowing contour outline around the detected object. Args: image: Original PIL Image mask: Float32 array [0, 1] outline_color: Hex color for the contour thickness: Line thickness in pixels Returns: PIL Image with contour overlay """ import cv2 color = hex_to_rgb(outline_color) binary = (mask > 0).astype(np.uint8) * 255 # Find contours contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) result = np.array(image).copy() # Draw filled semi-transparent overlay first overlay = result.copy() cv2.drawContours(overlay, contours, -1, color, -1) alpha = 0.15 result = cv2.addWeighted(overlay, alpha, result, 1 - alpha, 0) # Draw contour outlines cv2.drawContours(result, contours, -1, color, thickness) # Add outer glow effect glow_mask = np.zeros(binary.shape, dtype=np.uint8) cv2.drawContours(glow_mask, contours, -1, 255, thickness + 4) glow_mask = cv2.GaussianBlur(glow_mask, (15, 15), 0) glow_3d = np.stack([glow_mask] * 3, axis=-1).astype(np.float32) / 255.0 glow_color = np.full_like(result, color, dtype=np.float32) result = result.astype(np.float32) + glow_3d * glow_color * 0.3 result = np.clip(result, 0, 255).astype(np.uint8) return Image.fromarray(result) def create_detection_visualization( image: Image.Image, detections: list, highlight_color: str = "#4fd1c5", ) -> Image.Image: """ Draw bounding boxes and labels for all detected objects. Args: image: Original PIL Image detections: List of dicts with 'label', 'confidence', 'bbox' highlight_color: Base color for generating palette Returns: PIL Image with detection boxes drawn """ result = image.copy() draw = ImageDraw.Draw(result) # Get distinct colors per unique class unique_labels = list(set(d["label"] for d in detections)) colors = get_label_colors(len(unique_labels), highlight_color) label_color_map = dict(zip(unique_labels, colors)) # Try to load a clean font, fall back to default try: font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 14) font_small = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 11) except (IOError, OSError): font = ImageFont.load_default() font_small = font for det in detections: label = det["label"] conf = det["confidence"] x1, y1, x2, y2 = det["bbox"] color = label_color_map[label] # Draw box with slight transparency for offset in range(2): draw.rectangle( [x1 - offset, y1 - offset, x2 + offset, y2 + offset], outline=color, width=1, ) # Label background text = f"{label} {conf:.0%}" bbox = draw.textbbox((0, 0), text, font=font) tw, th = bbox[2] - bbox[0], bbox[3] - bbox[1] label_y = max(y1 - th - 8, 0) draw.rectangle( [x1, label_y, x1 + tw + 10, label_y + th + 6], fill=color, ) draw.text( (x1 + 5, label_y + 2), text, fill=(0, 0, 0), font=font, ) return result