| """ |
| Edge processing and symmetry correction for BackgroundFX Pro. |
| Fixes hair segmentation asymmetry and improves edge quality. |
| """ |
|
|
| import numpy as np |
| import cv2 |
| import torch |
| import torch.nn.functional as F |
| from typing import Dict, List, Optional, Tuple, Any |
| from dataclasses import dataclass |
| from scipy import ndimage, signal |
| from scipy.spatial import distance |
| import logging |
|
|
# Module-level logger named after this module; no handlers or levels are
# configured here.
logger = logging.getLogger(__name__)
|
|
|
|
@dataclass
class EdgeConfig:
    """Configuration for edge processing.

    A single instance is shared by ``EdgeProcessor`` and the stages it
    builds (``HairSegmentation``, ``EdgeRefinement``, ``SymmetryCorrector``).
    """

    # Not read by any class visible in this module.
    edge_thickness: int = 3
    # Not read by any class visible in this module.
    smoothing_iterations: int = 2
    # Mean absolute left/right difference above which a mask is treated as
    # asymmetric and rebalanced.
    symmetry_threshold: float = 0.3
    # Cut-off applied to the masked hair probability map when binarising it.
    hair_detection_sensitivity: float = 0.7
    # Distance (in pixels) from an image edge within which mask-boundary
    # pixels are snapped to 0 or 1.
    refinement_radius: int = 5
    # Enables the guided-filter pass in the final smoothing step.
    use_guided_filter: bool = True
    # Forwarded to cv2.bilateralFilter as d, sigmaColor and sigmaSpace.
    bilateral_d: int = 9
    bilateral_sigma_color: float = 75.0
    bilateral_sigma_space: float = 75.0
    # Side length of the elliptical structuring element used for the final
    # morphological close/open.
    morphology_kernel_size: int = 5
    # Not read by any class visible in this module.
    edge_preservation_weight: float = 0.8
|
|
|
|
class EdgeProcessor:
    """Main edge processing and refinement system.

    Runs a foreground mask through hair segmentation, symmetry correction,
    edge refinement and a final smoothing pass (see ``process``).
    """

    def __init__(self, config: Optional["EdgeConfig"] = None):
        """Build the pipeline stages.

        Args:
            config: Shared settings. A default ``EdgeConfig`` is created when
                omitted.
        """
        self.config = config if config is not None else EdgeConfig()
        # Hand the resolved config to every stage, not the raw argument: the
        # stages read attributes from it and would fail on None.
        self.hair_segmentation = HairSegmentation(self.config)
        self.edge_refinement = EdgeRefinement(self.config)
        self.symmetry_corrector = SymmetryCorrector(self.config)

    def process(self, image: np.ndarray, mask: np.ndarray,
                detect_hair: bool = True) -> np.ndarray:
        """Process edges with the full pipeline.

        Args:
            image: Source image. Three-channel input is converted with the
                BGR colour codes further down the pipeline.
            mask: Foreground mask. It is scaled by 255 before edge
                detection, so float values in [0, 1] are presumably
                expected -- TODO confirm with callers.
            detect_hair: When True, a hair mask is segmented and blended in
                before the other stages run.

        Returns:
            The refined mask.
        """
        # Edges are taken from the incoming mask, before any stage alters it.
        edges = self._detect_edges(mask)

        if detect_hair:
            hair_mask = self.hair_segmentation.segment(image, mask)
            mask = self._blend_hair_mask(mask, hair_mask)

        mask = self.symmetry_corrector.correct(mask, image)
        mask = self.edge_refinement.refine(image, mask, edges)
        return self._final_smoothing(mask)

    def _detect_edges(self, mask: np.ndarray) -> np.ndarray:
        """Return the union of three Canny passes over the mask, as 0.0/1.0."""
        # Clip first so out-of-range values cannot wrap around in the uint8
        # cast.
        mask_uint8 = (np.clip(mask, 0.0, 1.0) * 255).astype(np.uint8)

        # Canny at several threshold pairs; a pixel counts as an edge when
        # any pass marks it.
        edges = cv2.Canny(mask_uint8, 50, 150)
        for low, high in ((30, 100), (70, 200)):
            edges = np.maximum(edges, cv2.Canny(mask_uint8, low, high))

        return edges / 255.0

    def _blend_hair_mask(self, original_mask: np.ndarray,
                         hair_mask: np.ndarray) -> np.ndarray:
        """Blend the hair mask into a copy of the original mask.

        Only pixels where ``hair_mask > 0.5`` change; there the result is
        ``0.7 * hair_mask + 0.3 * original_mask``. The input is not modified.
        """
        hair_regions = hair_mask > 0.5

        alpha = 0.7
        blended = original_mask.copy()
        blended[hair_regions] = (
            alpha * hair_mask[hair_regions]
            + (1 - alpha) * original_mask[hair_regions]
        )
        return blended

    def _final_smoothing(self, mask: np.ndarray) -> np.ndarray:
        """Apply the final smoothing pass.

        Optionally runs a self-guided filter (controlled by
        ``config.use_guided_filter``), then a morphological close followed by
        an open with an elliptical kernel of ``config.morphology_kernel_size``.
        """
        if self.config.use_guided_filter:
            mask = self._guided_filter(mask, mask)

        size = self.config.morphology_kernel_size
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (size, size))
        mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
        mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
        return mask

    def _guided_filter(self, input_img: np.ndarray,
                       guidance: np.ndarray,
                       radius: int = 4,
                       epsilon: float = 0.2**2) -> np.ndarray:
        """Apply a guided filter for edge-preserving smoothing.

        Args:
            input_img: Image to be filtered.
            guidance: Guidance image, same shape as ``input_img``.
            radius: Window radius; statistics are taken over a
                ``(2 * radius + 1)`` square so the window stays centred on
                each pixel.
            epsilon: Regulariser added to the local variance.

        Returns:
            The filtered image as float64.
        """
        # An even-sized (radius x radius) box is off-centre and shifts the
        # output by half a pixel; use the centred odd window instead.
        window = (2 * radius + 1, 2 * radius + 1)
        src = np.asarray(input_img, dtype=np.float64)
        guide = np.asarray(guidance, dtype=np.float64)

        def box_mean(arr: np.ndarray) -> np.ndarray:
            return cv2.boxFilter(arr, cv2.CV_64F, window)

        mean_I = box_mean(guide)
        mean_p = box_mean(src)
        cov_Ip = box_mean(guide * src) - mean_I * mean_p
        var_I = box_mean(guide * guide) - mean_I * mean_I

        # Per-window linear model q = a * I + b.
        a = cov_Ip / (var_I + epsilon)
        b = mean_p - a * mean_I

        # Average the coefficients of all windows covering each pixel.
        return box_mean(a) * guide + box_mean(b)
|
|
|
|
class HairSegmentation:
    """Specialized hair segmentation module."""

    def __init__(self, config: Optional["EdgeConfig"] = None):
        """Store the settings and create the hair detector.

        Args:
            config: Shared settings. A default ``EdgeConfig`` is used when
                None is given, so the class also works when built without
                one.
        """
        self.config = config if config is not None else EdgeConfig()
        self.hair_detector = HairDetector()

    def segment(self, image: np.ndarray, initial_mask: np.ndarray) -> np.ndarray:
        """Segment hair regions.

        Detects a hair probability map, restricts it to the neighbourhood of
        ``initial_mask``, rebalances left/right asymmetry and finally boosts
        thin strand-like structures.

        Returns:
            A soft hair mask with values in [0, 1].
        """
        hair_probability = self.hair_detector.detect(image)
        hair_mask = self._refine_with_mask(hair_probability, initial_mask)
        hair_mask = self._fix_hair_asymmetry(hair_mask, image)
        return self._enhance_hair_strands(hair_mask, image)

    def _refine_with_mask(self, hair_prob: np.ndarray,
                          initial_mask: np.ndarray) -> np.ndarray:
        """Restrict the hair probability to the area around the initial mask.

        The initial mask is dilated (15x15 kernel, two iterations), multiplied
        into the probability map, thresholded at
        ``config.hair_detection_sensitivity`` and softened with a small
        Gaussian blur.
        """
        kernel = np.ones((15, 15), np.uint8)
        search_area = cv2.dilate(initial_mask, kernel, iterations=2)

        gated = hair_prob * search_area
        binary = (gated > self.config.hair_detection_sensitivity).astype(np.float32)

        return cv2.GaussianBlur(binary, (5, 5), 1.0)

    def _fix_hair_asymmetry(self, mask: np.ndarray,
                            image: np.ndarray) -> np.ndarray:
        """Fix left/right asymmetry in the hair mask.

        The left half is compared with the mirrored right half about the
        image's vertical centre line. When their mean absolute difference
        exceeds ``config.symmetry_threshold`` both halves are replaced by
        their mirrored average. For odd widths the centre column is left
        untouched.

        Args:
            mask: Hair mask; not modified.
            image: Unused; kept for interface compatibility.

        Returns:
            ``mask`` itself when no correction is needed, otherwise a new
            balanced array.
        """
        half = mask.shape[1] // 2
        if half == 0 or mask.size == 0:
            return mask

        left = mask[:, :half]
        right_mirrored = np.fliplr(mask[:, mask.shape[1] - half:])
        asymmetry_score = np.mean(np.abs(left - right_mirrored))

        if asymmetry_score > self.config.symmetry_threshold:
            logger.info(f"Detected hair asymmetry: {asymmetry_score:.3f}")
            return self._mirror_balance(mask)

        return mask

    @staticmethod
    def _mirror_balance(mask: np.ndarray) -> np.ndarray:
        """Return a copy of ``mask`` whose halves are mirrored averages.

        The left half becomes the 50/50 blend of itself and the mirrored
        right half; the right half becomes the mirror image of that blend, so
        the result is symmetric about the centre line.
        """
        width = mask.shape[1]
        half = width // 2
        balanced = mask.copy()
        if half == 0:
            return balanced

        averaged = 0.5 * mask[:, :half] + 0.5 * np.fliplr(mask[:, width - half:])
        balanced[:, :half] = averaged
        # The right side must be the mirror of the blend. Writing the blend
        # un-mirrored would turn the halves into translated copies of each
        # other instead of a symmetric pair.
        balanced[:, width - half:] = np.fliplr(averaged)
        return balanced

    def _enhance_hair_strands(self, mask: np.ndarray,
                              image: np.ndarray) -> np.ndarray:
        """Enhance fine hair strands.

        Filters the grayscale image with Gabor kernels at 0, 45, 90 and 135
        degrees, normalises the strongest response to [0, 1] and adds 30% of
        it to the mask wherever the mask is not already full.
        """
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if len(image.shape) == 3 else image

        gabor_responses = []
        for angle in (0, 45, 90, 135):
            kernel = cv2.getGaborKernel(
                (21, 21), 4.0, np.deg2rad(angle), 10.0, 0.5, 0, ktype=cv2.CV_32F
            )
            gabor_responses.append(np.abs(cv2.filter2D(gray, cv2.CV_32F, kernel)))

        # Strongest response over all orientations, scaled to [0, 1].
        gabor_max = np.max(gabor_responses, axis=0)
        gabor_normalized = gabor_max / (np.max(gabor_max) + 1e-6)

        # Weight by (1 - mask) so confident foreground is left unchanged.
        hair_enhancement = gabor_normalized * (1 - mask)
        return np.clip(mask + 0.3 * hair_enhancement, 0, 1)
|
|
|
|
class HairDetector:
    """Detects hair regions in images."""

    # Inclusive HSV (lower, upper) bounds passed to cv2.inRange. The first
    # range accepts any hue with value <= 30 (very dark pixels); the others
    # are narrow hue bands, presumably brown, blonde and red hair tones --
    # TODO confirm the intended colours.
    _HSV_HAIR_RANGES = (
        ((0, 0, 0), (180, 255, 30)),
        ((10, 20, 20), (20, 255, 100)),
        ((15, 30, 50), (25, 255, 200)),
        ((0, 50, 50), (10, 255, 150)),
    )

    def detect(self, image: np.ndarray) -> np.ndarray:
        """Detect a hair probability map.

        Combines a colour cue (union of the HSV ranges above, weight 0.6)
        with a texture cue (weight 0.4) and smooths the result with a 7x7
        Gaussian blur.

        Args:
            image: BGR image, or a single-channel image which is expanded to
                three channels for the colour cue.

        Returns:
            A float map with the same height and width as ``image``.
        """
        # The texture cue already accepts grayscale input; make the colour
        # cue do the same instead of failing inside cvtColor.
        bgr = image if len(image.shape) == 3 else cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
        hsv = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV)

        hair_masks = [
            cv2.inRange(hsv, np.array(lower), np.array(upper))
            for lower, upper in self._HSV_HAIR_RANGES
        ]
        # A pixel matches when it falls in any of the ranges.
        color_mask = np.max(hair_masks, axis=0) / 255.0

        texture_mask = self._detect_hair_texture(image)

        hair_probability = 0.6 * color_mask + 0.4 * texture_mask
        return cv2.GaussianBlur(hair_probability, (7, 7), 2.0)

    def _detect_hair_texture(self, image: np.ndarray) -> np.ndarray:
        """Detect hair-like texture patterns.

        Scores each pixel by its Sobel gradient magnitude, damped by the
        variance of the gradient direction in a 9x9 neighbourhood, and
        normalises the result by its maximum.
        """
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if len(image.shape) == 3 else image

        dx = cv2.Sobel(gray, cv2.CV_32F, 1, 0, ksize=3)
        dy = cv2.Sobel(gray, cv2.CV_32F, 0, 1, ksize=3)

        magnitude = np.sqrt(dx**2 + dy**2)
        orientation = np.arctan2(dy, dx)

        # Local variance of the direction via E[x^2] - E[x]^2 over a box
        # window.
        # NOTE(review): this treats the angle as a linear quantity, so
        # directions near +pi and -pi count as far apart -- confirm whether a
        # circular statistic was intended.
        window_size = 9
        kernel = np.ones((window_size, window_size)) / (window_size**2)
        orient_mean = cv2.filter2D(orientation, -1, kernel)
        orient_sq_mean = cv2.filter2D(orientation**2, -1, kernel)
        # Round-off can push the difference slightly below zero; clamp so the
        # exponential weight below never exceeds 1.
        orient_var = np.maximum(orient_sq_mean - orient_mean**2, 0.0)

        texture_score = magnitude * np.exp(-orient_var)
        return texture_score / (np.max(texture_score) + 1e-6)
|
|
|
|
class EdgeRefinement:
    """Refines edges for better quality."""

    def __init__(self, config: Optional["EdgeConfig"] = None):
        """Store the settings.

        Args:
            config: Shared settings. A default ``EdgeConfig`` is used when
                None is given.
        """
        self.config = config if config is not None else EdgeConfig()

    def refine(self, image: np.ndarray, mask: np.ndarray,
               edges: np.ndarray) -> np.ndarray:
        """Refine mask edges.

        Applies, in order: bilateral smoothing, snapping to image edges,
        gradient-based sharpening and feathering of the boundary.

        Args:
            image: Source image (BGR or single channel).
            mask: Soft mask to refine.
            edges: Previously detected mask edges; forwarded to
                ``_snap_to_edges``.

        Returns:
            The refined mask.
        """
        refined = self._bilateral_smooth(mask, image)
        refined = self._snap_to_edges(refined, image, edges)
        refined = self._subpixel_refinement(refined, image)
        return self._apply_feathering(refined)

    def _bilateral_smooth(self, mask: np.ndarray,
                          image: np.ndarray) -> np.ndarray:
        """Apply bilateral filtering to the mask.

        The mask is quantised to 8 bits, filtered with the ``bilateral_*``
        settings and scaled back to [0, 1]. ``image`` is not used.
        """
        # Clip before the cast so out-of-range values cannot wrap around.
        mask_uint8 = (np.clip(mask, 0.0, 1.0) * 255).astype(np.uint8)

        smoothed = cv2.bilateralFilter(
            mask_uint8,
            self.config.bilateral_d,
            self.config.bilateral_sigma_color,
            self.config.bilateral_sigma_space,
        )
        return smoothed / 255.0

    def _snap_to_edges(self, mask: np.ndarray, image: np.ndarray,
                       detected_edges: np.ndarray) -> np.ndarray:
        """Snap mask boundaries to image edges.

        Pixels that lie near the mask boundary and within
        ``config.refinement_radius`` of a Canny edge of the image are
        binarised at 0.5. ``detected_edges`` is currently not used; the mask
        edges are recomputed from ``mask``.
        """
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if len(image.shape) == 3 else image
        image_edges = cv2.Canny(gray, 50, 150) / 255.0
        mask_edges = cv2.Canny((mask * 255).astype(np.uint8), 50, 150) / 255.0

        # Edge pixels are the zeros here, so the transform yields each
        # pixel's distance to the nearest image edge.
        dist_to_image_edge = cv2.distanceTransform(
            (1 - image_edges).astype(np.uint8), cv2.DIST_L2, 5
        )

        # Band around the mask boundary that is eligible for snapping.
        near_mask_edge = cv2.dilate(mask_edges, np.ones((5, 5), np.uint8)) > 0
        snap = (dist_to_image_edge < self.config.refinement_radius) & near_mask_edge

        refined = mask.copy()
        refined[snap] = np.where(mask[snap] > 0.5, 1.0, 0.0)
        return refined

    def _subpixel_refinement(self, mask: np.ndarray,
                             image: np.ndarray) -> np.ndarray:
        """Sharpen the mask where the image gradient is strong.

        Where the normalised Sobel gradient magnitude exceeds 0.3, mask
        values above 0.5 are raised by 0.1 and the others lowered by 0.1,
        clamped to [0, 1].
        """
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if len(image.shape) == 3 else image

        grad_x = cv2.Sobel(gray, cv2.CV_32F, 1, 0, ksize=3)
        grad_y = cv2.Sobel(gray, cv2.CV_32F, 0, 1, ksize=3)
        grad_mag = np.sqrt(grad_x**2 + grad_y**2)
        grad_mag = grad_mag / (np.max(grad_mag) + 1e-6)

        strong = grad_mag > 0.3
        values = mask[strong]

        refined = mask.copy()
        refined[strong] = np.where(
            values > 0.5,
            np.minimum(values + 0.1, 1.0),
            np.maximum(values - 0.1, 0.0),
        )
        return refined

    def _apply_feathering(self, mask: np.ndarray,
                          radius: int = 3) -> np.ndarray:
        """Apply feathering to edges.

        Pixels within ``radius`` of the 0.5 iso-contour, on either side, are
        replaced by a linear ramp across the contour. All other pixels keep
        their value.

        Args:
            mask: Soft mask.
            radius: Half-width of the feather band in pixels. Values <= 0
                disable feathering.

        Returns:
            The feathered mask, with band values in [0, 1].
        """
        if radius <= 0:
            return mask

        mask_binary = (mask > 0.5).astype(np.uint8)

        # Depth of each foreground pixel below the boundary (0 on background).
        inside_depth = cv2.distanceTransform(mask_binary, cv2.DIST_L2, 5)
        # Distance of each background pixel from the foreground (0 inside).
        outside_dist = cv2.distanceTransform(1 - mask_binary, cv2.DIST_L2, 5)

        return self._blend_feather_band(mask, inside_depth, outside_dist, radius)

    @staticmethod
    def _blend_feather_band(mask: np.ndarray, inside_depth: np.ndarray,
                            outside_dist: np.ndarray, radius: float) -> np.ndarray:
        """Write a clamped linear ramp into the band around the boundary.

        Each distance map is zero on the opposite side of the boundary, so
        the band must require a strictly positive distance. Testing only
        ``<= radius`` selects every pixel and overwrites the whole mask.
        """
        inner_band = (inside_depth > 0) & (inside_depth <= radius)
        outer_band = (outside_dist > 0) & (outside_dist <= radius)
        band = inner_band | outer_band
        if not np.any(band):
            return mask

        # Signed distance to the boundary: positive inside, negative outside.
        # The ramp goes 0 -> 1 over 2 * radius pixels and is clamped.
        signed = inside_depth - outside_dist
        ramp = np.clip(0.5 + signed / (2.0 * radius), 0.0, 1.0)

        feathered = np.array(mask, dtype=np.result_type(mask, np.float32))
        feathered[band] = ramp[band]
        return feathered
|
|
|
|
class SymmetryCorrector:
    """Corrects left/right asymmetry in masks."""

    def __init__(self, config: Optional["EdgeConfig"] = None):
        """Store the settings.

        Args:
            config: Shared settings. A default ``EdgeConfig`` is used when
                None is given.
        """
        self.config = config if config is not None else EdgeConfig()

    def correct(self, mask: np.ndarray, image: np.ndarray) -> np.ndarray:
        """Correct asymmetry in the mask.

        The mask is rebalanced about the x-coordinate of its centroid when
        its asymmetry score exceeds ``config.symmetry_threshold``; otherwise
        it is returned unchanged. ``image`` is not used.
        """
        center = self._find_center(mask)
        asymmetry_score = self._compute_asymmetry(mask, center)

        if asymmetry_score > self.config.symmetry_threshold:
            logger.info(f"Correcting asymmetry: {asymmetry_score:.3f}")
            mask = self._balance_mask(mask, center)

        return mask

    def _find_center(self, mask: np.ndarray) -> int:
        """Return the column of the object's vertical symmetry axis.

        This is the x-coordinate of the centroid of the pixels above 0.5,
        truncated to an int, or the middle column when no pixel is above 0.5.
        """
        mask_binary = (mask > 0.5).astype(np.uint8)
        moments = cv2.moments(mask_binary)

        if moments['m00'] > 0:
            return int(moments['m10'] / moments['m00'])
        return mask.shape[1] // 2

    def _compute_asymmetry(self, mask: np.ndarray, center: int) -> float:
        """Compute the asymmetry score about column ``center``.

        Compares the widest pair of equally sized strips on both sides of
        ``center`` and returns the mean absolute difference between the left
        strip and the mirrored right strip. Returns 0.0 when ``center`` lies
        on an image border.
        """
        width = mask.shape[1]
        span = min(center, width - center)
        if span <= 0:
            return 0.0

        left = mask[:, center - span:center]
        right_mirrored = np.fliplr(mask[:, center:center + span])
        return float(np.mean(np.abs(left - right_mirrored)))

    def _balance_mask(self, mask: np.ndarray, center: int) -> np.ndarray:
        """Balance the mask to reduce asymmetry about column ``center``.

        Each side is blended with the mirror image of the other. The weights
        are proportional to each side's confidence, taken as its mean
        distance from 0.5. A short horizontal blur then hides the seam at
        ``center``.

        Returns:
            A new array; ``mask`` is returned as is when ``center`` lies on
            an image border.
        """
        width = mask.shape[1]
        span = min(center, width - center)
        if span <= 0:
            return mask

        left = mask[:, center - span:center]
        right = mask[:, center:center + span]

        # Values far from 0.5 are decisive; the more decisive side dominates.
        left_confidence = np.mean(np.abs(left - 0.5))
        right_confidence = np.mean(np.abs(right - 0.5))
        total_conf = left_confidence + right_confidence + 1e-6
        left_weight = left_confidence / total_conf
        right_weight = right_confidence / total_conf

        balanced = mask.copy()
        balanced[:, center - span:center] = (
            left_weight * left + right_weight * np.fliplr(right)
        )
        balanced[:, center:center + span] = (
            right_weight * right + left_weight * np.fliplr(left)
        )

        # Smooth the seam horizontally. Blur a slightly wider strip and copy
        # back only the seam columns, so the 5-tap kernel sees the real
        # neighbouring pixels instead of reflected borders at the strip ends.
        seam_width = 5
        kernel_reach = 2  # half-width of the 5-tap kernel
        seam_start = max(0, center - seam_width)
        seam_end = min(width, center + seam_width)
        ctx_start = max(0, seam_start - kernel_reach)
        ctx_end = min(width, seam_end + kernel_reach)

        blurred = cv2.GaussianBlur(
            np.ascontiguousarray(balanced[:, ctx_start:ctx_end]), (5, 1), 1.0
        )
        offset = seam_start - ctx_start
        balanced[:, seam_start:seam_end] = (
            blurred[:, offset:offset + (seam_end - seam_start)]
        )

        return balanced
|
|
|
|
| |
# Names exported by ``from <module> import *``.
__all__ = [
    'EdgeProcessor',
    'EdgeConfig',
    'HairSegmentation',
    'EdgeRefinement',
    'SymmetryCorrector',
    'HairDetector'
]