| """ |
| Professional Edge Detection & Refinement Module |
| ============================================== |
| |
| This module provides advanced edge detection, refinement, and processing |
| specifically optimized for hair segmentation in video processing pipelines. |
| |
| Features: |
| - Multi-scale edge detection |
| - Hair-specific edge refinement |
| - Temporal edge consistency |
| - Sub-pixel edge accuracy |
| - GPU-accelerated processing |
| |
| Author: BackgroundFX Pro |
| License: MIT |
| """ |
|
|
| import os |
| import cv2 |
| import numpy as np |
| import logging |
| from typing import Dict, List, Tuple, Optional, Union |
| from dataclasses import dataclass |
| from enum import Enum |
| import time |
|
|
| try: |
| import torch |
| import torch.nn.functional as F |
| TORCH_AVAILABLE = True |
| except ImportError: |
| TORCH_AVAILABLE = False |
| logging.warning("PyTorch not available - using CPU-only edge detection") |
|
|
| |
# Import-time logging setup: basicConfig attaches a default handler to the
# root logger at INFO level (it does nothing if the root logger already has
# handlers configured).
logging.basicConfig(level=logging.INFO)
# Module-scoped logger used by the detectors and the pipeline in this file.
logger = logging.getLogger(__name__)
|
|
class EdgeDetectionMethod(Enum):
    """Identifiers for the edge detection strategies known to this module.

    Each member's value is the lowercase string reported as
    ``EdgeDetectionResult.method_used``. Not every member has a detector
    registered in ``EdgeDetectionPipeline``; unregistered methods fall back
    to Canny there.
    """
    CANNY = "canny"
    SOBEL = "sobel"
    LAPLACIAN = "laplacian"
    SCHARR = "scharr"
    PREWITT = "prewitt"
    ROBERTS = "roberts"
    MULTISCALE = "multiscale"
    HAIR_OPTIMIZED = "hair_optimized"
    # EdgeDetectionPipeline registers its GPU detector under this member when
    # PyTorch is importable; without it that lookup raises AttributeError.
    GPU = "gpu"
|
|
@dataclass
class EdgeDetectionResult:
    """Bundle of everything produced by a single edge detection run."""

    # Final edge map after detection and any optional post-processing.
    edges: np.ndarray
    # Per-pixel confidence; same shape as ``edges``.
    confidence_map: np.ndarray
    # Scalar summary of how strong the detected edges are.
    edge_strength: float
    # Time spent producing this result.
    processing_time: float
    # Name of the detection method that actually produced ``edges``.
    method_used: str
    # Scalar quality score for the edge map.
    quality_score: float
|
|
class EdgeQualityMetrics:
    """Static helpers that score an edge map.

    Every metric treats any value greater than zero as an edge pixel and
    returns a plain Python ``float``.
    """

    @staticmethod
    def calculate_edge_strength(edges: np.ndarray) -> float:
        """Return the mean value of the edge pixels, or 0.0 when there are none."""
        edge_mask = edges > 0
        if not np.any(edge_mask):
            return 0.0
        return float(np.mean(edges[edge_mask]))

    @staticmethod
    def calculate_edge_density(edges: np.ndarray) -> float:
        """Return the fraction of pixels that are edge pixels.

        An empty array has no pixels to divide by, so it yields 0.0.
        """
        if edges.size == 0:
            return 0.0
        return float(np.count_nonzero(edges > 0)) / edges.size

    @staticmethod
    def calculate_edge_continuity(edges: np.ndarray) -> float:
        """Return a continuity score in [0, 1]; 1.0 means no small gaps.

        The edge map is morphologically closed (dilate, then erode) with a
        3x3 elliptical kernel, which fills small gaps. Closing never removes
        pixels, so the closed map has at least as many edge pixels as the
        input. The score is ``input_pixels / closed_pixels``: it drops as
        more gap pixels have to be filled in. An edge-free map scores 0.0.
        """
        edge_count = int(np.count_nonzero(edges > 0))
        if edge_count == 0:
            return 0.0

        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
        closed = cv2.erode(
            cv2.dilate(edges, kernel, iterations=1),
            kernel,
            iterations=1,
        )
        closed_count = int(np.count_nonzero(closed > 0))

        return edge_count / max(closed_count, 1)

    @staticmethod
    def calculate_edge_thickness_variation(edges: np.ndarray) -> float:
        """Return the coefficient of variation of edge thickness.

        Thickness is sampled from the L2 distance transform of the binary
        edge mask at the edge pixels. Returns 0.0 when there are no edges.
        """
        edge_mask = edges > 0
        # Checked before touching OpenCV so an edge-free map is cheap.
        if not np.any(edge_mask):
            return 0.0

        distances = cv2.distanceTransform(
            edge_mask.astype(np.uint8),
            cv2.DIST_L2,
            5
        )
        thicknesses = distances[edge_mask]
        # The epsilon guards the division when the mean is zero.
        return float(np.std(thicknesses) / (np.mean(thicknesses) + 1e-6))

    @staticmethod
    def calculate_overall_quality(edges: np.ndarray) -> float:
        """Return a weighted quality score capped at 1.0.

        Weights: strength 0.3, density 0.2, continuity 0.4 and thickness
        uniformity (one minus the variation, variation capped at 1) 0.1.
        """
        strength = EdgeQualityMetrics.calculate_edge_strength(edges)
        density = EdgeQualityMetrics.calculate_edge_density(edges)
        continuity = EdgeQualityMetrics.calculate_edge_continuity(edges)
        thickness_var = EdgeQualityMetrics.calculate_edge_thickness_variation(edges)

        quality = (
            strength * 0.3 +
            density * 0.2 +
            continuity * 0.4 +
            (1.0 - min(thickness_var, 1.0)) * 0.1
        )

        return float(min(quality, 1.0))
|
|
class BaseEdgeDetector:
    """Common interface implemented by the edge detectors in this module."""

    def __init__(self, name: str):
        # Human-readable label for the detector.
        self.name = name

    def detect(self, image: np.ndarray, **kwargs) -> np.ndarray:
        """Compute an edge map for ``image``; subclasses must override this."""
        raise NotImplementedError

    def get_default_params(self) -> Dict:
        """Return the detector's default keyword parameters (none here)."""
        return dict()
|
|
class CannyEdgeDetector(BaseEdgeDetector):
    """Canny edge detector whose thresholds default to Otsu-derived values."""

    def __init__(self):
        super().__init__("Canny")

    def detect(self, image: np.ndarray, **kwargs) -> np.ndarray:
        """Run Canny edge detection on ``image``.

        Args:
            image: 2-D grayscale image, or a 3-D image that is converted
                with ``cv2.COLOR_BGR2GRAY``. NOTE(review): Otsu thresholding
                and ``cv2.Canny`` operate on 8-bit data, so uint8 input is
                assumed here - confirm against callers.
            **kwargs: Optional overrides (see ``get_default_params``):
                ``low_threshold`` / ``high_threshold`` - hysteresis
                thresholds; any that is ``None`` is derived from Otsu's
                threshold (high = Otsu, low = half of it).
                ``blur_kernel`` - Gaussian kernel size, ``<= 0`` disables
                blurring; even sizes are bumped to the next odd size.
                ``aperture_size`` / ``l2_gradient`` - forwarded to
                ``cv2.Canny``.

        Returns:
            float32 array of the image's height and width holding 0.0 or 1.0.
        """
        if len(image.shape) == 3:
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        else:
            gray = image

        low_threshold = kwargs.get('low_threshold', None)
        high_threshold = kwargs.get('high_threshold', None)

        if low_threshold is None or high_threshold is None:
            # cv2.threshold returns (threshold_value, thresholded_image);
            # the scalar is what the Canny thresholds are derived from.
            otsu_thresh, _ = cv2.threshold(
                gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU
            )
            # Only fill in what the caller left unspecified.
            if low_threshold is None:
                low_threshold = 0.5 * otsu_thresh
            if high_threshold is None:
                high_threshold = otsu_thresh

        blur_kernel = int(kwargs.get('blur_kernel', 5))
        if blur_kernel > 0:
            # GaussianBlur rejects even kernel sizes.
            if blur_kernel % 2 == 0:
                blur_kernel += 1
            gray = cv2.GaussianBlur(gray, (blur_kernel, blur_kernel), 0)

        edges = cv2.Canny(
            gray,
            int(low_threshold),
            int(high_threshold),
            apertureSize=kwargs.get('aperture_size', 3),
            L2gradient=kwargs.get('l2_gradient', False)
        )

        # Canny emits 0/255; scale to 0.0/1.0.
        return edges.astype(np.float32) / 255.0

    def get_default_params(self) -> Dict:
        """Return the defaults used by ``detect`` when a keyword is omitted."""
        return {
            'low_threshold': None,
            'high_threshold': None,
            'blur_kernel': 5,
            'aperture_size': 3,
            'l2_gradient': False
        }
|
|
class HairOptimizedEdgeDetector(BaseEdgeDetector):
    """Gradient-based multi-scale detector tuned for thin, hair-like edges."""

    # Per-scale weights used with the default three scales
    # (native resolution first, so it dominates the blend).
    _DEFAULT_SCALE_WEIGHTS = (0.5, 0.25, 0.25)

    def __init__(self):
        super().__init__("HairOptimized")

    def detect(self, image: np.ndarray, **kwargs) -> np.ndarray:
        """Detect hair edges by fusing gradient maps from several scales.

        Args:
            image: 2-D grayscale image, or a 3-D image converted with
                ``cv2.COLOR_BGR2GRAY``.
            **kwargs: ``scales`` - resize factors to analyse at
                (default ``[1.0, 0.7, 1.4]``). Remaining keywords are
                forwarded to ``_detect_single_scale``.

        Returns:
            Edge map at the input resolution, clipped to [0, 1].
        """
        if len(image.shape) == 3:
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        else:
            gray = image

        height, width = gray.shape[:2]
        scales = kwargs.get('scales', [1.0, 0.7, 1.4])
        edge_maps = []

        for scale in scales:
            if scale == 1.0:
                edge_maps.append(self._detect_single_scale(gray, **kwargs))
                continue

            # Never ask cv2.resize for a zero-sized image on tiny inputs.
            scaled_w = max(1, int(width * scale))
            scaled_h = max(1, int(height * scale))
            scaled_gray = cv2.resize(gray, (scaled_w, scaled_h))

            scale_edges = self._detect_single_scale(scaled_gray, **kwargs)
            # Bring the map back to the input resolution before fusing.
            edge_maps.append(cv2.resize(scale_edges, (width, height)))

        combined_edges = self._combine_edge_maps(edge_maps)
        refined_edges = self._hair_specific_refinement(combined_edges, gray)

        # The refinement mix can overshoot 1.0; keep the output in the same
        # [0, 1] range the other detectors in this module produce.
        return np.clip(refined_edges, 0.0, 1.0)

    def _detect_single_scale(self, gray: np.ndarray, **kwargs) -> np.ndarray:
        """Return a max-normalised float32 gradient magnitude map.

        Sobel (weight 0.6) and Scharr (weight 0.4) magnitudes are blended
        and divided by the blend's maximum. ``kwargs`` is accepted but not
        read.
        """
        sobel_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
        sobel_y = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
        sobel_magnitude = np.sqrt(sobel_x**2 + sobel_y**2)

        scharr_x = cv2.Scharr(gray, cv2.CV_64F, 1, 0)
        scharr_y = cv2.Scharr(gray, cv2.CV_64F, 0, 1)
        scharr_magnitude = np.sqrt(scharr_x**2 + scharr_y**2)

        combined = 0.6 * sobel_magnitude + 0.4 * scharr_magnitude
        # The epsilon avoids 0/0 on a completely flat image.
        combined = combined / (np.max(combined) + 1e-6)

        return combined.astype(np.float32)

    def _combine_edge_maps(self, edge_maps: List[np.ndarray]) -> np.ndarray:
        """Return the weighted sum of ``edge_maps``.

        Three maps use the default 0.5 / 0.25 / 0.25 weights. Any other
        count is averaged uniformly so that no map is silently dropped and
        the weights still sum to one.
        """
        count = len(edge_maps)
        if count == len(self._DEFAULT_SCALE_WEIGHTS):
            weights = self._DEFAULT_SCALE_WEIGHTS
        else:
            weights = (1.0 / count,) * count if count else ()

        combined = np.zeros_like(edge_maps[0])
        for edge_map, weight in zip(edge_maps, weights):
            combined += edge_map * weight

        return combined

    def _hair_specific_refinement(self, edges: np.ndarray, original: np.ndarray) -> np.ndarray:
        """Boost thin line responses, then thin the result.

        ``original`` is accepted for interface symmetry but is not read.
        """
        # 3x3 kernel that responds to one-pixel-high horizontal lines.
        kernel_thin = np.array([[-1, -1, -1],
                                [ 2,  2,  2],
                                [-1, -1, -1]]) / 3.0

        thin_enhanced = cv2.filter2D(edges, -1, kernel_thin)

        # Blend the raw map with the magnitude of the line response.
        refined = 0.7 * edges + 0.3 * np.abs(thin_enhanced)

        return self._thin_edge_nms(refined)

    def _thin_edge_nms(self, edges: np.ndarray) -> np.ndarray:
        """Keep only pixels that equal the maximum of their 3x3 neighbourhood."""
        kernel = np.ones((3, 3), np.uint8)
        # Dilation yields the neighbourhood maximum at each pixel.
        dilated = cv2.dilate(edges, kernel, iterations=1)

        return np.where(edges == dilated, edges, 0)
|
|
class MultiScaleEdgeDetector(BaseEdgeDetector):
    """Canny at several Gaussian blur scales, fused by a weighted sum."""

    def __init__(self):
        super().__init__("MultiScale")

    def detect(self, image: np.ndarray, **kwargs) -> np.ndarray:
        """Run Canny at each scale and blend the resulting edge maps.

        Args:
            image: 2-D grayscale image, or a 3-D image converted with
                ``cv2.COLOR_BGR2GRAY``.
            **kwargs: ``scales`` - positive blur scale factors (default
                ``[0.5, 1.0, 1.5, 2.0]``); ``sigma_base`` - Gaussian sigma
                at scale 1.0 (default 1.0).

        Returns:
            float32 map holding the weighted sum of the per-scale 0/1 edge
            maps. Four scales use weights 0.1 / 0.4 / 0.3 / 0.2; any other
            number of scales is averaged uniformly.
        """
        if len(image.shape) == 3:
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        else:
            gray = image

        scales = kwargs.get('scales', [0.5, 1.0, 1.5, 2.0])
        sigma_base = kwargs.get('sigma_base', 1.0)

        edge_pyramid = []
        for scale in scales:
            # A (0, 0) kernel lets OpenCV derive the size from sigma.
            blurred = cv2.GaussianBlur(gray, (0, 0), sigma_base * scale)

            # Heavier blur flattens gradients, so thresholds shrink with scale.
            edges = cv2.Canny(
                blurred,
                int(50 / scale),
                int(150 / scale),
                apertureSize=3
            )
            edge_pyramid.append(edges.astype(np.float32) / 255.0)

        default_weights = np.array([0.1, 0.4, 0.3, 0.2])
        if len(edge_pyramid) == len(default_weights):
            weights = default_weights
        else:
            # The fixed weights only fit four scales; with a custom number
            # of scales weight them equally instead of truncating.
            count = len(edge_pyramid)
            weights = np.full(count, 1.0 / count) if count else np.array([])

        combined_edges = np.zeros_like(edge_pyramid[0])
        for level_edges, weight in zip(edge_pyramid, weights):
            combined_edges += level_edges * weight

        return combined_edges
|
|
class GPUEdgeDetector(BaseEdgeDetector):
    """Sobel-magnitude edge detector implemented with PyTorch.

    Runs on CUDA when available and on the CPU otherwise. If PyTorch could
    not be imported, ``detect`` delegates to ``CannyEdgeDetector``.
    """

    def __init__(self):
        super().__init__("GPU")

        # Only touch ``torch`` once we know the import succeeded; otherwise
        # the name does not exist and construction would raise NameError.
        if TORCH_AVAILABLE:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = None
            logger.warning("PyTorch not available - GPU edge detection disabled")

    def detect(self, image: np.ndarray, **kwargs) -> np.ndarray:
        """Return a binary edge map from thresholded Sobel magnitude.

        Args:
            image: 2-D grayscale image, or a 3-D image converted with
                ``cv2.COLOR_BGR2GRAY``. Values are divided by 255 before
                filtering.
            **kwargs: ``threshold`` - cut-off on the gradient magnitude of
                the scaled image (default 0.1). Without PyTorch all
                keywords are forwarded to ``CannyEdgeDetector.detect``.

        Returns:
            float32 array of the image's height and width holding 0.0 or 1.0.
        """
        if not TORCH_AVAILABLE:
            return CannyEdgeDetector().detect(image, **kwargs)

        if len(image.shape) == 3:
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        else:
            gray = image

        # from_numpy cannot wrap negatively-strided views (e.g. flipped images).
        gray = np.ascontiguousarray(gray)

        # Shape (1, 1, H, W): batch and channel axes expected by conv2d.
        tensor = torch.from_numpy(gray).float().unsqueeze(0).unsqueeze(0).to(self.device)
        tensor = tensor / 255.0

        sobel_x = torch.tensor(
            [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32
        ).view(1, 1, 3, 3).to(self.device)
        sobel_y = torch.tensor(
            [[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32
        ).view(1, 1, 3, 3).to(self.device)

        # padding=1 keeps the output the same size as the input.
        grad_x = F.conv2d(tensor, sobel_x, padding=1)
        grad_y = F.conv2d(tensor, sobel_y, padding=1)
        magnitude = torch.sqrt(grad_x**2 + grad_y**2)

        threshold = kwargs.get('threshold', 0.1)
        edges = (magnitude > threshold).float()

        # Index away only the batch/channel axes; a bare squeeze() would
        # also drop a height or width of 1.
        return edges[0, 0].cpu().numpy()
|
|
class TemporalEdgeConsistency:
    """Blend each frame's edges with a short history of previous frames.

    Args:
        memory_frames: Maximum number of past edge maps kept in
            ``frame_buffer`` (at least one is always kept).
        consistency_threshold: Stored on the instance; not used by the
            blending in this class.
        consistency_factor: Share of the blended result taken from the
            history average (default 0.3); the rest comes from the
            current frame.
    """

    def __init__(self, memory_frames: int = 3, consistency_threshold: float = 0.1,
                 consistency_factor: float = 0.3):
        self.memory_frames = memory_frames
        self.consistency_threshold = consistency_threshold
        self.consistency_factor = consistency_factor
        self.frame_buffer = []

    def reset(self) -> None:
        """Forget all buffered frames (e.g. when a new clip starts)."""
        self.frame_buffer.clear()

    def apply_temporal_consistency(self, current_edges: np.ndarray) -> np.ndarray:
        """Return ``current_edges`` blended with the buffered history.

        The first frame (or the first frame after a resolution change) is
        returned unchanged and only seeds the buffer. The unblended input
        is what gets stored, so history never compounds earlier blends.
        """
        # Frames of a different shape cannot be averaged together; start over.
        if self.frame_buffer and self.frame_buffer[-1].shape != current_edges.shape:
            self.frame_buffer.clear()

        if not self.frame_buffer:
            self.frame_buffer.append(current_edges.copy())
            return current_edges

        consistent_edges = self._calculate_consistent_edges(current_edges)

        self.frame_buffer.append(current_edges.copy())
        # Always keep at least one frame so there is history to blend with.
        excess = len(self.frame_buffer) - max(self.memory_frames, 1)
        if excess > 0:
            del self.frame_buffer[:excess]

        return consistent_edges

    def _calculate_consistent_edges(self, current_edges: np.ndarray) -> np.ndarray:
        """Return the weighted blend of the current frame and the history.

        History frames are weighted linearly from 0.1 (oldest) to 0.9
        (newest), normalised to sum to one.
        """
        # Accumulate in a floating dtype so integer edge maps neither fail
        # the in-place add nor get truncated; float inputs keep their dtype.
        work_dtype = np.result_type(current_edges.dtype, np.float32)

        weights = np.linspace(0.1, 0.9, len(self.frame_buffer))
        weights = weights / np.sum(weights)

        avg_previous = np.zeros(current_edges.shape, dtype=work_dtype)
        for frame, weight in zip(self.frame_buffer, weights):
            avg_previous += frame * weight

        factor = self.consistency_factor
        current = current_edges.astype(work_dtype, copy=False)
        return (1 - factor) * current + factor * avg_previous
|
|
class EdgeRefinementProcessor:
    """Static post-processing steps applied to an edge map."""

    @staticmethod
    def remove_noise(edges: np.ndarray, min_area: int = 10) -> np.ndarray:
        """Zero out connected components smaller than ``min_area`` pixels.

        Components are found with 8-connectivity on an 8-bit copy of the
        map (``edges * 255``); surviving pixels keep their original values.
        """
        # Clip before the cast so values above 1.0 saturate instead of
        # wrapping around in uint8.
        edges_uint8 = np.clip(edges * 255, 0, 255).astype(np.uint8)
        _, labels, stats, _ = cv2.connectedComponentsWithStats(edges_uint8, connectivity=8)

        # One boolean per label, then a single lookup through the label
        # image - avoids rescanning the whole image once per component.
        keep_label = stats[:, cv2.CC_STAT_AREA] >= min_area
        keep_label[0] = False  # label 0 is the background

        return np.where(keep_label[labels], edges, np.zeros_like(edges))

    @staticmethod
    def smooth_edges(edges: np.ndarray, iterations: int = 1) -> np.ndarray:
        """Return a copy of ``edges`` blurred ``iterations`` times.

        Each pass is a 3x3 Gaussian blur with sigma 0.5.
        """
        smoothed = edges.copy()

        for _ in range(iterations):
            smoothed = cv2.GaussianBlur(smoothed, (3, 3), 0.5)

        return smoothed

    @staticmethod
    def enhance_hair_edges(edges: np.ndarray, original_image: np.ndarray) -> np.ndarray:
        """Boost edges where the image has strongly oriented structure.

        A structure tensor is built from Sobel gradients of the (grayscale)
        image and smoothed with a 5x5 Gaussian. A per-pixel coherence
        measure derived from its trace and determinant is normalised by its
        maximum and used to amplify ``edges`` by up to 50 percent.

        Returns:
            The amplified edge map, clipped to [0, 1].
        """
        if len(original_image.shape) == 3:
            gray = cv2.cvtColor(original_image, cv2.COLOR_BGR2GRAY)
        else:
            gray = original_image

        grad_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
        grad_y = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)

        # Smoothed structure tensor components.
        J11 = cv2.GaussianBlur(grad_x * grad_x, (5, 5), 1.0)
        J22 = cv2.GaussianBlur(grad_y * grad_y, (5, 5), 1.0)
        J12 = cv2.GaussianBlur(grad_x * grad_y, (5, 5), 1.0)

        trace = J11 + J22
        # Mathematically non-negative, but rounding can push it slightly
        # below zero, and sqrt of that would turn the whole map into NaN.
        det = np.maximum(J11 * J22 - J12 * J12, 0.0)

        # Flat regions (trace ~ 0) carry no orientation: leave them at 0
        # rather than dividing a tiny number by a tinier one, which would
        # dominate the normalisation below.
        eps = 1e-6
        coherence = np.divide(
            (trace - 2 * np.sqrt(det))**2,
            trace**2,
            out=np.zeros_like(trace),
            where=trace > eps
        )

        coherence = coherence / (np.max(coherence) + 1e-6)

        enhanced_edges = edges * (1.0 + coherence * 0.5)

        return np.clip(enhanced_edges, 0, 1)
|
|
class EdgeDetectionPipeline:
    """Edge detection front end: detector selection plus post-processing.

    Recognised ``config`` keys: ``temporal_memory``,
    ``consistency_threshold``, ``min_edge_area`` and
    ``smoothing_iterations``.
    """

    def __init__(self, config: Optional[Dict] = None):
        self.config = config or {}
        self.detectors = {}
        self.temporal_processor = TemporalEdgeConsistency(
            memory_frames=self.config.get('temporal_memory', 3),
            consistency_threshold=self.config.get('consistency_threshold', 0.1)
        )
        self.refinement_processor = EdgeRefinementProcessor()

        self._initialize_detectors()

    def _initialize_detectors(self):
        """Register the detectors that can run in this environment."""
        self.detectors[EdgeDetectionMethod.CANNY] = CannyEdgeDetector()
        self.detectors[EdgeDetectionMethod.HAIR_OPTIMIZED] = HairOptimizedEdgeDetector()
        self.detectors[EdgeDetectionMethod.MULTISCALE] = MultiScaleEdgeDetector()

        # Look the member up defensively: if the enum has no GPU entry the
        # GPU detector is simply not offered instead of failing construction.
        gpu_method = getattr(EdgeDetectionMethod, 'GPU', None)
        if TORCH_AVAILABLE and gpu_method is not None:
            self.detectors[gpu_method] = GPUEdgeDetector()

    def detect_edges(self,
                     image: np.ndarray,
                     method: EdgeDetectionMethod = EdgeDetectionMethod.HAIR_OPTIMIZED,
                     apply_temporal_consistency: bool = True,
                     apply_refinement: bool = True,
                     **kwargs) -> EdgeDetectionResult:
        """Detect edges in ``image`` and optionally post-process them.

        Args:
            image: Input image handed to the selected detector.
            method: Detector to use; unregistered methods fall back to Canny.
            apply_temporal_consistency: Blend with previous frames seen by
                this pipeline instance.
            apply_refinement: Run noise removal, smoothing and hair edge
                enhancement, in that order.
            **kwargs: Forwarded to the detector's ``detect``.

        Returns:
            ``EdgeDetectionResult`` with the edge map, a confidence map
            (currently a copy of the edges), timing and quality metrics.

        Raises:
            Exception: Whatever the Canny detector raises, when Canny was
                requested or when the fallback to Canny fails as well.
        """
        # perf_counter is monotonic, so the duration cannot go negative if
        # the wall clock is adjusted mid-call.
        start_time = time.perf_counter()

        if method not in self.detectors:
            logger.warning(f"Method {method} not available, using Canny")
            method = EdgeDetectionMethod.CANNY

        detector = self.detectors[method]

        try:
            edges = detector.detect(image, **kwargs)
        except Exception as e:
            logger.error(f"Edge detection failed with {method.value}: {e}")
            # Retrying Canny after Canny itself failed would just repeat
            # the same failure; surface the original error instead.
            if method is EdgeDetectionMethod.CANNY:
                raise
            edges = self.detectors[EdgeDetectionMethod.CANNY].detect(image, **kwargs)
            method = EdgeDetectionMethod.CANNY

        if apply_temporal_consistency:
            edges = self.temporal_processor.apply_temporal_consistency(edges)

        if apply_refinement:
            edges = self.refinement_processor.remove_noise(
                edges,
                min_area=self.config.get('min_edge_area', 10)
            )
            edges = self.refinement_processor.smooth_edges(
                edges,
                iterations=self.config.get('smoothing_iterations', 1)
            )
            edges = self.refinement_processor.enhance_hair_edges(edges, image)

        # Timing covers detection and post-processing, not the metrics below.
        processing_time = time.perf_counter() - start_time
        quality_score = EdgeQualityMetrics.calculate_overall_quality(edges)
        edge_strength = EdgeQualityMetrics.calculate_edge_strength(edges)

        confidence_map = edges.copy()

        return EdgeDetectionResult(
            edges=edges,
            confidence_map=confidence_map,
            edge_strength=edge_strength,
            processing_time=processing_time,
            method_used=method.value,
            quality_score=quality_score
        )

    def get_best_method_for_image(self, image: np.ndarray) -> EdgeDetectionMethod:
        """Pick a detection method from simple grayscale statistics.

        Standard deviation above 50 selects Canny; standard deviation below
        20 or mean below 50 selects the hair-optimised detector; everything
        else selects the multi-scale detector.
        """
        if len(image.shape) == 3:
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        else:
            gray = image

        contrast = np.std(gray)
        brightness = np.mean(gray)

        if contrast > 50:
            return EdgeDetectionMethod.CANNY

        if contrast < 20 or brightness < 50:
            return EdgeDetectionMethod.HAIR_OPTIMIZED

        return EdgeDetectionMethod.MULTISCALE
|
|
| |
def detect_hair_edges(image: np.ndarray, config: Optional[Dict] = None) -> EdgeDetectionResult:
    """Run the hair-optimised detector on a single image.

    A fresh pipeline is built from ``config``; refinement is enabled and
    temporal blending is disabled, since there is only one frame.
    """
    options = dict(
        method=EdgeDetectionMethod.HAIR_OPTIMIZED,
        apply_temporal_consistency=False,
        apply_refinement=True,
    )
    return EdgeDetectionPipeline(config).detect_edges(image, **options)
|
|
def detect_video_edges(frames: List[np.ndarray], config: Optional[Dict] = None) -> List[EdgeDetectionResult]:
    """Run the hair-optimised detector over a sequence of frames.

    All frames share one pipeline, processed in order, so each frame is
    blended with the ones before it. Returns one result per frame.
    """
    pipeline = EdgeDetectionPipeline(config)
    return [
        pipeline.detect_edges(
            frame,
            method=EdgeDetectionMethod.HAIR_OPTIMIZED,
            apply_temporal_consistency=True,
            apply_refinement=True,
        )
        for frame in frames
    ]
|
|
| |
if __name__ == "__main__":
    # Smoke test: run each registered detector on a random colour image and
    # print its metrics, then show which method the heuristic would choose.
    test_image = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)

    config = dict(
        temporal_memory=3,
        consistency_threshold=0.1,
        min_edge_area=10,
        smoothing_iterations=1,
    )
    pipeline = EdgeDetectionPipeline(config)

    methods = [
        EdgeDetectionMethod.CANNY,
        EdgeDetectionMethod.HAIR_OPTIMIZED,
        EdgeDetectionMethod.MULTISCALE,
    ]

    for method in methods:
        if method not in pipeline.detectors:
            continue

        result = pipeline.detect_edges(test_image, method=method)
        report = (
            f"\n{method.value} Results:",
            f" Edge strength: {result.edge_strength:.3f}",
            f" Quality score: {result.quality_score:.3f}",
            f" Processing time: {result.processing_time:.3f}s",
        )
        for line in report:
            print(line)

    best_method = pipeline.get_best_method_for_image(test_image)
    print(f"\nBest method for this image: {best_method.value}")