| """ |
| Advanced hair segmentation pipeline for BackgroundFX Pro. |
| Specialized module for accurate hair detection and segmentation. |
| """ |
|
|
| import numpy as np |
| import cv2 |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from typing import Dict, List, Optional, Tuple, Any |
| from dataclasses import dataclass |
| import logging |
| from scipy import ndimage |
| |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
@dataclass
class HairConfig:
    """Configuration for hair segmentation.

    Attributes are grouped by pipeline stage: detection thresholds, strand
    handling, asymmetry correction, and final refinement/smoothing.
    """
    # Minimum probability for a pixel to be accepted as hair.
    min_hair_confidence: float = 0.6
    # Sensitivity of hair-edge thresholding (higher keeps fewer edges).
    edge_sensitivity: float = 0.8
    # Whether to detect individual fine strands.
    strand_detection: bool = True
    # Line thickness (px) used when rasterizing detected strand segments.
    strand_thickness: int = 2
    # Whether to detect and correct left/right asymmetry.
    asymmetry_correction: bool = True
    # Area imbalance beyond which the mask is treated as asymmetric.
    max_asymmetry_ratio: float = 1.5
    # Whether to run the (placeholder) deep feature extractor.
    use_deep_features: bool = False
    # Number of morphological refinement passes.
    refinement_iterations: int = 3
    # Whether to run guided-filter alpha matting.
    alpha_matting: bool = True
    # Preserve fine detail (bilateral filter) vs. plain Gaussian smoothing.
    preserve_details: bool = True
    # Sigma for the final Gaussian smoothing (when details not preserved).
    smoothing_sigma: float = 1.0
    # BUG FIX: AsymmetryDetector.detect reads this threshold but it was never
    # declared, raising AttributeError.  Appended (not inserted) so existing
    # positional construction remains valid.
    symmetry_threshold: float = 0.3
|
|
|
|
class HairSegmentationPipeline:
    """Complete hair segmentation pipeline.

    Combines color and texture cues (optionally deep features) into a hair
    probability map, corrects left/right asymmetry, recovers fine strands,
    then refines, edge-enhances, alpha-mattes and smooths the mask.
    """

    def __init__(self, config: Optional["HairConfig"] = None):
        """Create the pipeline and its helper components.

        Args:
            config: Segmentation settings; defaults to ``HairConfig()``.
        """
        self.config = config or HairConfig()
        # BUG FIX: helpers previously received the raw ``config`` argument,
        # which may be None; hand them the resolved configuration instead.
        self.mask_refiner = HairMaskRefiner(self.config)
        self.asymmetry_detector = AsymmetryDetector(self.config)
        self.edge_enhancer = HairEdgeEnhancer(self.config)

        # Optional learned feature extractor (placeholder network).
        self.deep_model = None
        if self.config.use_deep_features:
            self.deep_model = HairNet()

    def segment(self, image: np.ndarray,
                initial_mask: Optional[np.ndarray] = None,
                prompts: Optional[Dict] = None) -> Dict[str, np.ndarray]:
        """
        Perform complete hair segmentation.

        Args:
            image: BGR input image.
            initial_mask: Optional rough mask used to restrict the search
                region and to score confidence.
                NOTE(review): assumed to be float-valued in [0, 1] — the
                probability map is multiplied by its dilation; confirm the
                caller's convention (a 0/255 uint8 mask would blow it up).
            prompts: Currently unused; reserved for interactive prompting.

        Returns:
            Dictionary containing:
            - 'mask': Final hair mask
            - 'confidence': Confidence map
            - 'strands': Fine hair strands mask (None when strand detection
              is disabled)
            - 'edges': Hair edge map
        """
        h, w = image.shape[:2]

        # Stage 1: coarse hair probability from color + texture cues.
        hair_regions = self._detect_hair_regions(image, initial_mask)

        # Stage 2 (optional): blend in learned hair probability.
        if self.deep_model and self.config.use_deep_features:
            deep_features = self._extract_deep_features(image)
            hair_regions = self._enhance_with_deep_features(hair_regions, deep_features)

        # Stage 3 (optional): detect and correct left/right asymmetry.
        if self.config.asymmetry_correction:
            asymmetry_info = self.asymmetry_detector.detect(hair_regions, image)
            if asymmetry_info['is_asymmetric']:
                logger.info(f"Correcting hair asymmetry: {asymmetry_info['score']:.3f}")
                hair_regions = self.asymmetry_detector.correct(
                    hair_regions, asymmetry_info
                )

        # Stage 4 (optional): recover fine individual strands.
        strands_mask = None
        if self.config.strand_detection:
            strands_mask = self._detect_hair_strands(image, hair_regions)
            hair_regions = self._integrate_strands(hair_regions, strands_mask)

        # Stage 5: iterative morphological refinement.
        refined_mask = self.mask_refiner.refine(image, hair_regions)

        # Stage 6: sharpen mask along detected hair edges.
        edges = self.edge_enhancer.enhance(refined_mask, image)
        refined_mask = self._apply_edge_enhancement(refined_mask, edges)

        # Stage 7 (optional): soft alpha matte via guided filtering.
        if self.config.alpha_matting:
            refined_mask = self._apply_alpha_matting(image, refined_mask)

        # Stage 8: final smoothing pass.
        final_mask = self._final_smoothing(refined_mask)

        confidence = self._compute_confidence(final_mask, initial_mask)

        return {
            'mask': final_mask,
            'confidence': confidence,
            'strands': strands_mask,
            'edges': edges
        }

    def _detect_hair_regions(self, image: np.ndarray,
                             initial_mask: Optional[np.ndarray]) -> np.ndarray:
        """Detect hair regions by fusing color and texture cues.

        Returns a binary float32 mask with small components removed.
        """
        color_mask = self._detect_by_color(image)
        texture_mask = self._detect_by_texture(image)

        # Color dominates slightly; both maps are in [0, 1].
        hair_probability = 0.6 * color_mask + 0.4 * texture_mask

        if initial_mask is not None:
            # Restrict search to a dilated neighbourhood of the prior mask.
            # NOTE(review): assumes initial_mask is float in [0, 1].
            kernel = np.ones((15, 15), np.uint8)
            dilated_initial = cv2.dilate(initial_mask, kernel, iterations=2)
            hair_probability *= dilated_initial

        hair_mask = (hair_probability > self.config.min_hair_confidence).astype(np.float32)

        hair_mask = self._remove_small_regions(hair_mask)

        return hair_mask

    def _detect_by_color(self, image: np.ndarray) -> np.ndarray:
        """Detect hair by color: union of per-shade HSV range masks.

        Covers black, brown, blonde, red/auburn and gray hair; returns a
        smoothed soft map in [0, 1].
        """
        hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
        # NOTE(review): lab/ycrcb are computed but unused — kept in case
        # other call sites rely on side effects; candidates for removal.
        lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
        ycrcb = cv2.cvtColor(image, cv2.COLOR_BGR2YCrCb)

        masks = []

        # Black / very dark hair: any hue, low value.
        black_mask = cv2.inRange(hsv, (0, 0, 0), (180, 255, 30))
        masks.append(black_mask)

        # Brown hair: orange-ish hue, mid value.
        brown_mask = cv2.inRange(hsv, (10, 20, 20), (20, 255, 100))
        masks.append(brown_mask)

        # Blonde hair: yellow-ish hue, brighter.
        blonde_mask = cv2.inRange(hsv, (15, 30, 50), (25, 255, 200))
        masks.append(blonde_mask)

        # Red/auburn hair: hue wraps around 0/180, so two ranges.
        red_mask = cv2.inRange(hsv, (0, 50, 50), (10, 255, 150))
        auburn_mask = cv2.inRange(hsv, (160, 50, 50), (180, 255, 150))
        masks.append(cv2.bitwise_or(red_mask, auburn_mask))

        # Gray/white hair: low saturation, mid value.
        gray_mask = cv2.inRange(hsv, (0, 0, 50), (180, 30, 200))
        masks.append(gray_mask)

        # Pixel-wise union of all shade masks, scaled to [0, 1].
        combined = np.zeros_like(masks[0], dtype=np.float32)
        for mask in masks:
            combined = np.maximum(combined, mask.astype(np.float32) / 255.0)

        # Soften hard inRange boundaries.
        combined = cv2.GaussianBlur(combined, (7, 7), 2.0)

        return combined

    def _detect_by_texture(self, image: np.ndarray) -> np.ndarray:
        """Detect hair by texture via a multi-scale, multi-orientation Gabor bank."""
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if len(image.shape) == 3 else image

        texture_responses = []

        # 3 scales x 6 orientations = 18 filter responses.
        for scale in [3, 5, 7]:
            for angle in [0, 30, 60, 90, 120, 150]:
                theta = np.deg2rad(angle)
                kernel = cv2.getGaborKernel(
                    (21, 21), scale, theta, 10.0, 0.5, 0, ktype=cv2.CV_32F
                )
                response = cv2.filter2D(gray, cv2.CV_32F, kernel)
                texture_responses.append(np.abs(response))

        # Mean energy across the whole bank.
        texture_map = np.mean(texture_responses, axis=0)

        # Min-max normalize to [0, 1].
        texture_map = (texture_map - np.min(texture_map)) / (np.max(texture_map) - np.min(texture_map) + 1e-6)

        # Weight by orientation consistency: hair has coherent direction.
        coherence = self._compute_texture_coherence(texture_responses)

        hair_texture = texture_map * coherence

        return hair_texture

    def _compute_texture_coherence(self, responses: List[np.ndarray]) -> np.ndarray:
        """Compute texture coherence (consistency of orientation responses).

        Low variance-to-mean ratio across the filter bank means the texture
        responds similarly in all directions tested => coherence near 1.
        """
        if len(responses) < 2:
            return np.ones_like(responses[0])

        response_stack = np.stack(responses, axis=0)
        variance = np.var(response_stack, axis=0)
        mean = np.mean(response_stack, axis=0) + 1e-6

        # Map dispersion into [0, 1]; 1 = perfectly coherent.
        coherence = 1.0 - np.minimum(variance / mean, 1.0)

        return coherence

    def _detect_hair_strands(self, image: np.ndarray,
                             hair_mask: np.ndarray) -> np.ndarray:
        """Detect fine hair strands via Hough line segments + ridge filtering."""
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if len(image.shape) == 3 else image

        # Low Canny thresholds so faint strand edges survive.
        edges = cv2.Canny(gray, 10, 30)

        lines = cv2.HoughLinesP(
            edges, 1, np.pi/180, threshold=20,
            minLineLength=10, maxLineGap=5
        )

        strand_mask = np.zeros_like(gray, dtype=np.float32)

        # BUG FIX: the dilated hair mask was built inside the per-line loop
        # (wasteful) and was undefined when no lines were found, raising
        # NameError at the ridge-combination step below; hoist it here.
        kernel = np.ones((15, 15), np.uint8)
        dilated_hair = cv2.dilate(hair_mask, kernel, iterations=1)

        if lines is not None:
            for line in lines:
                x1, y1, x2, y2 = line[0]

                mid_x, mid_y = (x1 + x2) // 2, (y1 + y2) // 2

                # Keep only segments whose midpoint lies near existing hair.
                if dilated_hair[mid_y, mid_x] > 0:
                    cv2.line(strand_mask, (x1, y1), (x2, y2), 1.0, self.config.strand_thickness)

        # BUG FIX: the original called ``filters.frangi`` but ``filters``
        # (skimage.filters) was never imported, so the default
        # strand_detection=True path raised NameError.  Use a dependency-free
        # Hessian-based ridge response instead.
        ridges = self._ridge_response(gray)
        ridges = (ridges - np.min(ridges)) / (np.max(ridges) - np.min(ridges) + 1e-6)

        # Ridge evidence only counts near the existing hair region.
        strand_mask = np.maximum(strand_mask, ridges * dilated_hair)

        strand_mask = (strand_mask > 0.3).astype(np.float32)
        strand_mask = cv2.morphologyEx(strand_mask, cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))

        return strand_mask

    def _ridge_response(self, gray: np.ndarray,
                        sigmas: Tuple[int, ...] = (1, 2, 3)) -> np.ndarray:
        """Multi-scale Hessian ridge response (lightweight frangi stand-in).

        For each sigma, smooths the image and measures the eigenvalue spread
        of the Hessian, which is large for thin line-like structures; the
        per-pixel maximum over scales is returned (unnormalized).
        """
        img = gray.astype(np.float32)
        response = np.zeros_like(img)
        for sigma in sigmas:
            smoothed = cv2.GaussianBlur(img, (0, 0), sigma)
            dxx = cv2.Sobel(smoothed, cv2.CV_32F, 2, 0, ksize=3)
            dyy = cv2.Sobel(smoothed, cv2.CV_32F, 0, 2, ksize=3)
            dxy = cv2.Sobel(smoothed, cv2.CV_32F, 1, 1, ksize=3)
            # Difference of Hessian eigenvalues: sqrt((Dxx-Dyy)^2 + 4*Dxy^2).
            spread = np.sqrt((dxx - dyy) ** 2 + 4.0 * dxy ** 2)
            response = np.maximum(response, spread)
        return response

    def _integrate_strands(self, hair_mask: np.ndarray,
                           strands_mask: np.ndarray) -> np.ndarray:
        """Integrate detected strands into the main hair mask."""
        if strands_mask is None:
            return hair_mask

        # Strands contribute at reduced weight so they don't dominate.
        integrated = np.maximum(hair_mask, strands_mask * 0.8)

        # Blend strand edges into the surrounding mask.
        integrated = cv2.GaussianBlur(integrated, (5, 5), 1.0)

        return np.clip(integrated, 0, 1)

    def _extract_deep_features(self, image: np.ndarray) -> torch.Tensor:
        """Extract deep features from a BGR uint8 image using HairNet."""
        if not self.deep_model:
            return None

        # HWC uint8 -> NCHW float in [0, 1].
        input_tensor = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).float() / 255.0

        with torch.no_grad():
            features = self.deep_model.extract_features(input_tensor)

        return features

    def _enhance_with_deep_features(self, mask: np.ndarray,
                                    features: torch.Tensor) -> np.ndarray:
        """Blend the network's hair probability into the heuristic mask."""
        if features is None:
            return mask

        hair_prob = self.deep_model.process_features(features)
        hair_prob = hair_prob.squeeze().cpu().numpy()

        # Network output is at reduced resolution; resize to mask size.
        hair_prob = cv2.resize(hair_prob, (mask.shape[1], mask.shape[0]))

        # Heuristic mask keeps the larger weight.
        enhanced = 0.7 * mask + 0.3 * hair_prob

        return np.clip(enhanced, 0, 1)

    def _apply_alpha_matting(self, image: np.ndarray,
                             mask: np.ndarray) -> np.ndarray:
        """Apply guided-filter alpha matting for refined soft transparency."""
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if len(image.shape) == 3 else image
        gray = gray.astype(np.float32) / 255.0

        # Large radius so the matte follows coarse luminance structure.
        radius = 20
        epsilon = 0.01

        alpha = self._guided_filter(mask, gray, radius, epsilon)

        return np.clip(alpha, 0, 1)

    def _guided_filter(self, p: np.ndarray, I: np.ndarray,
                       radius: int, epsilon: float) -> np.ndarray:
        """Guided filter (He et al.): smooth ``p`` guided by image ``I``.

        Args:
            p: Filter input (the mask).
            I: Guidance image, float in [0, 1].
            radius: Box-filter window radius.
            epsilon: Regularization controlling edge preservation.
        """
        mean_I = cv2.boxFilter(I, cv2.CV_32F, (radius, radius))
        mean_p = cv2.boxFilter(p, cv2.CV_32F, (radius, radius))
        mean_Ip = cv2.boxFilter(I * p, cv2.CV_32F, (radius, radius))
        cov_Ip = mean_Ip - mean_I * mean_p

        mean_II = cv2.boxFilter(I * I, cv2.CV_32F, (radius, radius))
        var_I = mean_II - mean_I * mean_I

        # Per-window linear model q = a*I + b.
        a = cov_Ip / (var_I + epsilon)
        b = mean_p - a * mean_I

        mean_a = cv2.boxFilter(a, cv2.CV_32F, (radius, radius))
        mean_b = cv2.boxFilter(b, cv2.CV_32F, (radius, radius))

        q = mean_a * I + mean_b

        return q

    def _apply_edge_enhancement(self, mask: np.ndarray,
                                edges: np.ndarray) -> np.ndarray:
        """Boost the mask along detected hair edges."""
        edge_weight = 0.3
        enhanced = mask + edge_weight * edges

        return np.clip(enhanced, 0, 1)

    def _final_smoothing(self, mask: np.ndarray) -> np.ndarray:
        """Apply final smoothing while (optionally) preserving detail.

        The detail-preserving path runs a bilateral filter, which quantizes
        through uint8 and back; the alternative is a plain Gaussian blur.
        """
        if self.config.preserve_details:
            smoothed = cv2.bilateralFilter(
                (mask * 255).astype(np.uint8), 9, 75, 75
            ) / 255.0
        else:
            smoothed = cv2.GaussianBlur(
                mask, (5, 5), self.config.smoothing_sigma
            )

        return smoothed

    def _compute_confidence(self, mask: np.ndarray,
                            initial_mask: Optional[np.ndarray]) -> np.ndarray:
        """Compute a per-pixel confidence map for the segmentation.

        Confidence is high where the mask is decisive (near 0 or 1) and,
        when a prior mask exists, where the result agrees with it.
        """
        distance_from_middle = np.abs(mask - 0.5) * 2
        confidence = distance_from_middle

        if initial_mask is not None:
            agreement = 1 - np.abs(mask - initial_mask)
            confidence = 0.7 * confidence + 0.3 * agreement

        return np.clip(confidence, 0, 1)

    def _remove_small_regions(self, mask: np.ndarray,
                              min_size: int = 100) -> np.ndarray:
        """Remove disconnected components smaller than ``min_size`` pixels."""
        binary = (mask > 0.5).astype(np.uint8)

        num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(binary)

        # Label 0 is the background; keep only sufficiently large components,
        # preserving the original (soft) mask values inside them.
        cleaned = np.zeros_like(mask)
        for i in range(1, num_labels):
            if stats[i, cv2.CC_STAT_AREA] >= min_size:
                cleaned[labels == i] = mask[labels == i]

        return cleaned
|
|
|
|
class HairMaskRefiner:
    """Refines hair masks via iterative morphological cleanup."""

    def __init__(self, config: HairConfig):
        self.config = config

    def refine(self, image: np.ndarray, mask: np.ndarray) -> np.ndarray:
        """Refine the hair mask through ``config.refinement_iterations`` passes.

        Args:
            image: Source image (currently unused by the iterations; kept for
                API stability).
            mask: Float hair mask in [0, 1].

        Returns:
            The refined mask (same shape as input).
        """
        refined = mask.copy()

        for iteration in range(self.config.refinement_iterations):
            refined = self._refine_iteration(image, refined, iteration)

        return refined

    def _refine_iteration(self, image: np.ndarray, mask: np.ndarray,
                          iteration: int) -> np.ndarray:
        """Single refinement iteration with a kernel that shrinks per pass."""
        # BUG FIX: ``5 - iteration`` becomes zero/negative once
        # refinement_iterations exceeds 5, making getStructuringElement fail;
        # clamp to the smallest valid kernel size.
        kernel_size = max(1, 5 - iteration)
        kernel = cv2.getStructuringElement(
            cv2.MORPH_ELLIPSE, (kernel_size, kernel_size)
        )

        # Close small holes inside the hair region.
        refined = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)

        # Open to remove isolated speckles.
        refined = cv2.morphologyEx(refined, cv2.MORPH_OPEN, kernel)

        # Light blur softens the staircase left by morphology.
        refined = cv2.GaussianBlur(refined, (3, 3), 0.5)

        return refined
|
|
|
|
class AsymmetryDetector:
    """Detects and corrects left/right asymmetry in hair masks."""

    def __init__(self, config: HairConfig):
        self.config = config

    def detect(self, mask: np.ndarray, image: np.ndarray) -> Dict[str, Any]:
        """Measure left/right asymmetry of the mask about its centroid column.

        Args:
            mask: Float hair mask in [0, 1].
            image: Source image (currently unused; kept for API symmetry).

        Returns:
            Dict with 'is_asymmetric' flag, combined 'score', the dividing
            'center_x' column, and the individual metrics
            ('area_ratio', 'pixel_diff', 'edge_diff').
        """
        h, w = mask.shape[:2]

        # Split the mask at the horizontal centroid of its foreground.
        center_x = self._find_center_line(mask)

        left_mask = mask[:, :center_x]
        right_mask = mask[:, center_x:]

        # Crop both halves to a common width before mirror comparison.
        min_width = min(left_mask.shape[1], right_mask.shape[1])
        left_mask = left_mask[:, -min_width:] if left_mask.shape[1] > min_width else left_mask
        right_mask = right_mask[:, :min_width] if right_mask.shape[1] > min_width else right_mask

        right_flipped = np.fliplr(right_mask)

        # Metric 1: mean per-pixel difference between mirrored halves.
        pixel_diff = np.mean(np.abs(left_mask - right_flipped))

        # Metric 2: foreground area imbalance (>= 1).
        left_area = np.sum(left_mask > 0.5)
        right_area = np.sum(right_mask > 0.5)
        area_ratio = max(left_area, right_area) / (min(left_area, right_area) + 1e-6)

        # Metric 3: contour (edge map) mismatch.
        left_edges = cv2.Canny((left_mask * 255).astype(np.uint8), 50, 150)
        right_edges = cv2.Canny((right_mask * 255).astype(np.uint8), 50, 150)
        right_edges_flipped = np.fliplr(right_edges)
        # BUG FIX: subtracting two uint8 arrays wraps modulo 256, making
        # np.abs a no-op and the metric wildly wrong; compare in float.
        edge_diff = np.mean(
            np.abs(left_edges.astype(np.float32) - right_edges_flipped.astype(np.float32))
        ) / 255.0

        # Weighted combination of the three metrics.
        asymmetry_score = 0.4 * pixel_diff + 0.3 * (area_ratio - 1.0) / 2.0 + 0.3 * edge_diff

        # BUG FIX: HairConfig historically did not declare
        # `symmetry_threshold`; fall back to a sensible default rather than
        # raising AttributeError.
        symmetry_threshold = getattr(self.config, 'symmetry_threshold', 0.3)
        is_asymmetric = (asymmetry_score > symmetry_threshold or
                         area_ratio > self.config.max_asymmetry_ratio)

        return {
            'is_asymmetric': is_asymmetric,
            'score': asymmetry_score,
            'center_x': center_x,
            'area_ratio': area_ratio,
            'pixel_diff': pixel_diff,
            'edge_diff': edge_diff
        }

    def correct(self, mask: np.ndarray, asymmetry_info: Dict[str, Any]) -> np.ndarray:
        """Correct detected asymmetry by mirroring the denser side onto the sparser one."""
        center_x = asymmetry_info['center_x']
        h, w = mask.shape[:2]

        left_mask = mask[:, :center_x]
        right_mask = mask[:, center_x:]

        # Density decides which side is the reliable reference.
        left_density = np.mean(left_mask > 0.5)
        right_density = np.mean(right_mask > 0.5)

        # BUG FIX: the halves differ in width whenever the centroid is
        # off-center; the original blended slices of mismatched shapes and
        # could raise a broadcast error.  Blend only over the overlapping
        # width and leave the remainder of the sparser side untouched.
        overlap = min(left_mask.shape[1], right_mask.shape[1])

        corrected = mask.copy()
        if left_density > right_density:
            # Mirror the (denser) left half onto the right.
            mirrored = np.fliplr(left_mask)
            corrected[:, center_x:center_x + overlap] = (
                0.7 * mirrored[:, :overlap] + 0.3 * right_mask[:, :overlap]
            )
        else:
            # Mirror the (denser) right half onto the left.
            mirrored = np.fliplr(right_mask)
            corrected[:, center_x - overlap:center_x] = (
                0.7 * mirrored[:, -overlap:] + 0.3 * left_mask[:, -overlap:]
            )

        # Smooth a vertical band around the seam so the join is invisible.
        seam_width = 10
        seam_start = max(0, center_x - seam_width)
        seam_end = min(w, center_x + seam_width)
        corrected[:, seam_start:seam_end] = cv2.GaussianBlur(
            corrected[:, seam_start:seam_end], (7, 1), 2.0
        )

        return corrected

    def _find_center_line(self, mask: np.ndarray) -> int:
        """Return the x coordinate of the mask foreground's centroid column."""
        mask_binary = (mask > 0.5).astype(np.uint8)
        moments = cv2.moments(mask_binary)

        if moments['m00'] > 0:
            cx = int(moments['m10'] / moments['m00'])
        else:
            # Empty mask: fall back to the geometric center.
            cx = mask.shape[1] // 2

        return cx
|
|
|
|
class HairEdgeEnhancer:
    """Enhances edges in hair masks."""

    def __init__(self, config: HairConfig):
        self.config = config

    def enhance(self, mask: np.ndarray, image: np.ndarray) -> np.ndarray:
        """Build a thinned edge map combining image, mask and texture cues.

        Args:
            mask: Float hair mask in [0, 1].
            image: Source BGR (or grayscale) image.

        Returns:
            Float edge map in [0, 1], thinned by non-maximum suppression.
        """
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if len(image.shape) == 3 else image

        # Image edges at several scales.
        edges = self._multi_scale_edges(gray)

        # Boundary of the current mask.
        mask_edges = cv2.Canny((mask * 255).astype(np.uint8), 30, 100) / 255.0

        # Texture-specific (Gabor) hair edges restricted to the mask.
        hair_edges = self._detect_hair_edges(gray, mask)

        # Weighted per-pixel maximum keeps the strongest cue.
        combined_edges = np.maximum(edges * 0.3, np.maximum(mask_edges * 0.3, hair_edges * 0.4))

        # Thin the result to single-pixel-wide ridges.
        combined_edges = self._non_max_suppression(combined_edges)

        return combined_edges

    def _multi_scale_edges(self, gray: np.ndarray) -> np.ndarray:
        """Detect Canny edges at multiple scales and average them."""
        edges_list = []

        for scale in [1, 2, 3]:
            # Downscale so larger structures register as edges.
            if scale > 1:
                scaled = cv2.resize(gray, None, fx=1/scale, fy=1/scale)
            else:
                scaled = gray

            # Thresholds grow with scale to compensate for smoothing.
            edges = cv2.Canny(scaled, 30 * scale, 80 * scale)

            if scale > 1:
                edges = cv2.resize(edges, (gray.shape[1], gray.shape[0]))

            edges_list.append(edges / 255.0)

        combined = np.mean(edges_list, axis=0)

        return combined

    def _detect_hair_edges(self, gray: np.ndarray, mask: np.ndarray) -> np.ndarray:
        """Detect edges specific to hair texture via an oriented Gabor sweep."""
        hair_edges = np.zeros_like(gray, dtype=np.float32)

        # Max response over 6 orientations.
        for angle in range(0, 180, 30):
            theta = np.deg2rad(angle)
            kernel = cv2.getGaborKernel(
                (11, 11), 3.0, theta, 8.0, 0.5, 0, ktype=cv2.CV_32F
            )

            filtered = cv2.filter2D(gray, cv2.CV_32F, kernel)
            hair_edges = np.maximum(hair_edges, np.abs(filtered))

        # Normalize to [0, 1].
        hair_edges = hair_edges / (np.max(hair_edges) + 1e-6)

        # Only edges inside the hair mask count.
        hair_edges *= mask

        # Binarize at half the configured sensitivity.
        hair_edges = (hair_edges > self.config.edge_sensitivity * 0.5).astype(np.float32)

        return hair_edges

    def _non_max_suppression(self, edges: np.ndarray) -> np.ndarray:
        """Suppress non-maximal edge responses along the gradient direction.

        PERF FIX: vectorized replacement for the original per-pixel Python
        double loop, which was prohibitively slow on full-resolution frames.
        The per-pixel keep/suppress decision is unchanged: a pixel survives
        iff its gradient magnitude is >= both neighbors along the quantized
        gradient direction.  Border pixels remain zero, as before.
        """
        dx = cv2.Sobel(edges, cv2.CV_32F, 1, 0, ksize=3)
        dy = cv2.Sobel(edges, cv2.CV_32F, 0, 1, ksize=3)

        magnitude = np.sqrt(dx**2 + dy**2)
        direction = np.rad2deg(np.arctan2(dy, dx))
        direction[direction < 0] += 180

        m = magnitude
        center = m[1:-1, 1:-1]
        angle = direction[1:-1, 1:-1]

        # Quantize the gradient direction into four neighbor axes.
        horizontal = ((angle >= 0) & (angle < 22.5)) | ((angle >= 157.5) & (angle <= 180))
        diag_up = (angle >= 22.5) & (angle < 67.5)    # neighbors (i-1,j+1)/(i+1,j-1)
        vertical = (angle >= 67.5) & (angle < 112.5)
        # Remaining bucket: diagonal neighbors (i-1,j-1)/(i+1,j+1).

        n1 = np.where(horizontal, m[1:-1, :-2],
             np.where(diag_up, m[:-2, 2:],
             np.where(vertical, m[:-2, 1:-1], m[:-2, :-2])))
        n2 = np.where(horizontal, m[1:-1, 2:],
             np.where(diag_up, m[2:, :-2],
             np.where(vertical, m[2:, 1:-1], m[2:, 2:])))

        suppressed = np.zeros_like(magnitude)
        keep = (center >= n1) & (center >= n2)
        suppressed[1:-1, 1:-1] = np.where(keep, center, 0.0)

        # Normalize to [0, 1].
        suppressed = suppressed / (np.max(suppressed) + 1e-6)

        return suppressed
|
|
|
|
class HairNet(nn.Module):
    """Lightweight encoder/decoder network for hair feature extraction (placeholder).

    The encoder downsamples spatially by 4x while growing channels
    3 -> 32 -> 64 -> 128; the decoder mirrors it back up and emits a
    single-channel sigmoid probability map.
    """

    def __init__(self):
        super().__init__()

        # Encoder: conv+ReLU stages, pooling after the first two.
        encoder_layers = []
        for in_ch, out_ch, pool in ((3, 32, True), (32, 64, True), (64, 128, False)):
            encoder_layers.append(nn.Conv2d(in_ch, out_ch, 3, padding=1))
            encoder_layers.append(nn.ReLU())
            if pool:
                encoder_layers.append(nn.MaxPool2d(2))
        self.encoder = nn.Sequential(*encoder_layers)

        # Decoder: conv+ReLU+upsample stages, then the prediction head.
        decoder_layers = []
        for in_ch, out_ch in ((128, 64), (64, 32)):
            decoder_layers.append(nn.Conv2d(in_ch, out_ch, 3, padding=1))
            decoder_layers.append(nn.ReLU())
            decoder_layers.append(nn.Upsample(scale_factor=2))
        decoder_layers.append(nn.Conv2d(32, 1, 3, padding=1))
        decoder_layers.append(nn.Sigmoid())
        self.decoder = nn.Sequential(*decoder_layers)

    def extract_features(self, x: torch.Tensor) -> torch.Tensor:
        """Extract encoder features from an NCHW input image tensor."""
        return self.encoder(x)

    def process_features(self, features: torch.Tensor) -> torch.Tensor:
        """Decode encoder features into a hair probability map in (0, 1)."""
        return self.decoder(features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Full forward pass: encode, then decode to a probability map."""
        return self.process_features(self.extract_features(x))
|
|
|
|
| |
def visualize_hair_segmentation(image: np.ndarray,
                                results: Dict[str, np.ndarray],
                                save_path: Optional[str] = None) -> np.ndarray:
    """Visualize hair segmentation results as a labeled 2x2 montage.

    Layout: original (top-left), mask overlay (top-right), confidence heat
    map (bottom-left), edges/strands composite (bottom-right).

    Args:
        image: Source BGR uint8 image.
        results: Output dict from HairSegmentationPipeline.segment().
        save_path: Optional path; when given, the montage is written to disk.

    Returns:
        The BGR montage image (2h x 2w x 3).
    """
    height, width = image.shape[:2]

    canvas = np.zeros((height * 2, width * 2, 3), dtype=np.uint8)

    # Top-left: untouched input frame.
    canvas[:height, :width] = image

    # Top-right: hair mask blended over the image in the green channel.
    green = np.zeros_like(image)
    green[:, :, 1] = (results['mask'] * 255).astype(np.uint8)
    canvas[:height, width:] = cv2.addWeighted(image, 0.7, green, 0.3, 0)

    # Bottom-left: confidence as a JET heat map, when available.
    if 'confidence' in results:
        heat = cv2.applyColorMap(
            (results['confidence'] * 255).astype(np.uint8),
            cv2.COLORMAP_JET
        )
        canvas[height:, :width] = heat

    # Bottom-right: edges in red, strands (when present) in blue.
    if 'edges' in results and 'strands' in results:
        detail = np.zeros_like(image)
        detail[:, :, 2] = (results['edges'] * 255).astype(np.uint8)

        if results['strands'] is not None:
            detail[:, :, 0] = (results['strands'] * 255).astype(np.uint8)

        canvas[height:, width:] = detail

    # Quadrant captions.
    captions = (
        ("Original", (10, 30)),
        ("Hair Mask", (width + 10, 30)),
        ("Confidence", (10, height + 30)),
        ("Edges/Strands", (width + 10, height + 30)),
    )
    for text, anchor in captions:
        cv2.putText(canvas, text, anchor, cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)

    if save_path:
        cv2.imwrite(save_path, canvas)

    return canvas
|
|
|
|
| |
# Public API of this module; star-imports expose exactly these names.
__all__ = [
    'HairSegmentationPipeline',
    'HairConfig',
    'HairMaskRefiner',
    'AsymmetryDetector',
    'HairEdgeEnhancer',
    'HairNet',
    'visualize_hair_segmentation'
]