| """ |
| Professional Hair Segmentation Module |
| ===================================== |
| |
| This module provides high-quality hair segmentation for video processing |
| using SAM2 + MatAnyone pipeline with comprehensive error handling and fallbacks. |
| |
| Author: BackgroundFX Pro |
| License: MIT |
| """ |
|
|
import logging
import os
import warnings
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

# Cap the native thread pools used by numpy/torch/cv2. OpenMP/MKL runtimes
# typically read these variables when they are first loaded, so they must be
# set *before* the third-party imports below to have an effect.
# ``setdefault`` keeps any value the user has already exported.
os.environ.setdefault('OMP_NUM_THREADS', '4')
os.environ.setdefault('MKL_NUM_THREADS', '4')

import cv2  # noqa: E402  (must follow the thread-count settings above)
import numpy as np  # noqa: E402
import torch  # noqa: E402

# NOTE(review): configuring the root logger at import time also affects any
# application that imports this module; consider moving this into the
# ``__main__`` entry point.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
|
|
@dataclass
class SegmentationResult:
    """Per-frame output of the hair segmentation pipeline.

    Field order is part of the interface: instances may be built positionally.
    """

    mask: np.ndarray            # segmentation mask (float32 from this module's models)
    confidence: float           # edge-density / gradient-smoothness heuristic
    coverage_percent: float     # percentage of pixels whose mask value exceeds 0.5
    asymmetry_score: float      # mean |left half - mirrored right half|
    processing_time: float      # seconds spent segmenting (and smoothing) the frame
    fallback_used: bool         # True when the mask did not come from the active model
    quality_score: float        # weighted combination of the metrics above
    error_message: Optional[str] = None  # failure description, if any
|
|
class BaseSegmentationModel(ABC):
    """Interface implemented by every segmentation backend in this module.

    A subclass must override initialize(), segment() and get_model_name();
    the ABC machinery refuses to instantiate one that leaves any of them out.
    """

    @abstractmethod
    def initialize(self) -> bool:
        """Prepare the model for use; return True on success, False otherwise."""

    @abstractmethod
    def segment(self, frame: np.ndarray) -> np.ndarray:
        """Return a hair mask for *frame*."""

    @abstractmethod
    def get_model_name(self) -> str:
        """Return a short model name used in log messages."""
|
|
class SAM2Model(BaseSegmentationModel):
    """SAM2 segmentation model wrapper.

    Prompts SAM2 with a single positive point at the frame centre and returns
    the highest-scoring candidate mask as float32.
    """

    # Checkpoint used when ``model_path`` is None or does not exist.
    DEFAULT_CHECKPOINT = "sam2_hiera_large.pt"
    # Config name matching DEFAULT_CHECKPOINT (SAM2 "hiera large").
    # NOTE(review): newer SAM2 releases keep configs under "configs/..."; pass
    # ``model_config`` explicitly if this name does not resolve.
    DEFAULT_CONFIG = "sam2_hiera_l.yaml"

    def __init__(self, model_path: Optional[str] = None, device: str = 'auto',
                 model_config: str = DEFAULT_CONFIG):
        """
        Args:
            model_path: SAM2 checkpoint path. DEFAULT_CHECKPOINT is used when
                this is None or the file does not exist.
            device: 'auto' (CUDA if available, else CPU) or an explicit device.
            model_config: SAM2 model config name passed to ``build_sam2``.
        """
        self.model_path = model_path
        self.model_config = model_config
        self.device = self._get_best_device(device)
        self.predictor = None
        self.initialized = False

    def _get_best_device(self, device: str) -> str:
        """Resolve 'auto' to 'cuda' when available, otherwise 'cpu'."""
        if device == 'auto':
            return 'cuda' if torch.cuda.is_available() else 'cpu'
        return device

    def initialize(self) -> bool:
        """Build the SAM2 image predictor.

        Returns:
            True on success; False if SAM2 is not installed or loading fails.
        """
        try:
            logger.info("🤖 Initializing SAM2 model...")

            try:
                from sam2.build_sam import build_sam2
                from sam2.sam2_image_predictor import SAM2ImagePredictor
            except ImportError:
                logger.error("SAM2 not found. Please install SAM2.")
                return False

            if self.model_path and Path(self.model_path).exists():
                checkpoint = self.model_path
            else:
                if self.model_path:
                    logger.warning(
                        f"SAM2 checkpoint {self.model_path} not found; "
                        f"falling back to {self.DEFAULT_CHECKPOINT}"
                    )
                checkpoint = self.DEFAULT_CHECKPOINT

            # segment() uses the single-image API (set_image / predict), which
            # SAM2ImagePredictor provides; the video predictor does not.
            # build_sam2 takes (config, checkpoint), not a checkpoint alone.
            sam_model = build_sam2(self.model_config, checkpoint, device=self.device)
            self.predictor = SAM2ImagePredictor(sam_model)

            self.initialized = True
            logger.info(f"✅ SAM2 initialized on {self.device}")
            return True

        except Exception as e:
            logger.error(f"❌ SAM2 initialization failed: {e}")
            return False

    @staticmethod
    def _to_rgb(frame: np.ndarray) -> np.ndarray:
        """Convert a BGR, BGRA or grayscale frame to 3-channel RGB."""
        if frame.ndim == 2:
            return cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
        if frame.shape[2] == 1:
            return cv2.cvtColor(frame[:, :, 0], cv2.COLOR_GRAY2RGB)
        if frame.shape[2] == 4:
            return cv2.cvtColor(frame, cv2.COLOR_BGRA2RGB)
        return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    def segment(self, frame: np.ndarray) -> np.ndarray:
        """Segment the object under the frame centre.

        Args:
            frame: BGR (OpenCV convention), BGRA or grayscale image.

        Returns:
            float32 mask of shape (height, width); all zeros if SAM2 returns
            no candidates.

        Raises:
            RuntimeError: if the model has not been initialised.
        """
        if not self.initialized:
            raise RuntimeError("SAM2 model not initialized")

        try:
            frame_rgb = self._to_rgb(frame)
            self.predictor.set_image(frame_rgb)

            height, width = frame_rgb.shape[:2]
            # Single positive click at the image centre.
            center_point = np.array([[width // 2, height // 2]])

            masks, scores, _ = self.predictor.predict(
                point_coords=center_point,
                point_labels=np.array([1])
            )

            if len(masks) == 0:
                return np.zeros((height, width), dtype=np.float32)
            return masks[int(np.argmax(scores))].astype(np.float32)

        except Exception as e:
            logger.error(f"SAM2 segmentation failed: {e}")
            raise

    def get_model_name(self) -> str:
        return "SAM2"
|
|
class MatAnyoneModel(BaseSegmentationModel):
    """MatAnyone wrapper for alpha matting (image + trimap -> matte).

    It cannot segment a bare frame: ``segment`` always raises
    NotImplementedError.
    """

    # Capability flag. Callers can check it to avoid using this model as a
    # segmentation fallback, since segment() only raises NotImplementedError.
    supports_segmentation = False

    def __init__(self, use_hf_api: bool = True, hf_token: Optional[str] = None):
        self.use_hf_api = use_hf_api
        self.hf_token = hf_token
        self.client = None
        self.processor = None  # NOTE(review): not used anywhere in this file
        self.initialized = False
        self.quality_threshold = 0.3  # NOTE(review): not read anywhere in this file

    def initialize(self) -> bool:
        """Connect to the MatAnyone Hugging Face Space.

        Returns:
            True on success. False if local mode is requested (not implemented)
            or if the client cannot be created.
        """
        try:
            logger.info("🎭 Initializing MatAnyone model...")

            if not self.use_hf_api:
                logger.warning("Local MatAnyone not implemented yet")
                return False

            from gradio_client import Client
            self.client = Client("PeiqingYang/MatAnyone", hf_token=self.hf_token)
            logger.info("✅ MatAnyone HF API initialized")

            self.initialized = True
            return True

        except Exception as e:
            logger.error(f"❌ MatAnyone initialization failed: {e}")
            return False

    def segment(self, frame: np.ndarray) -> np.ndarray:
        """MatAnyone is primarily for matting, not segmentation"""
        raise NotImplementedError("MatAnyone is used for matting, not direct segmentation")

    @staticmethod
    def _write_temp_png(array: np.ndarray, created: List[str]) -> str:
        """Write *array* to a new temporary PNG, record the path in *created*, return it."""
        import tempfile

        # mkstemp + close so the file is not written through a second handle
        # while the first is still open (that fails on Windows).
        fd, path = tempfile.mkstemp(suffix='.png')
        os.close(fd)
        created.append(path)  # register before writing so cleanup covers failures
        if not cv2.imwrite(path, array):
            raise OSError(f"Failed to write temporary image: {path}")
        return path

    def matte(self, image: np.ndarray, trimap: np.ndarray) -> np.ndarray:
        """Apply matting using MatAnyone.

        Both arrays are written to temporary PNG files because the backend
        works on file paths. The files are always removed, even on failure.

        Raises:
            RuntimeError: if the model has not been initialised.
            OSError: if a temporary image cannot be written.
        """
        if not self.initialized:
            raise RuntimeError("MatAnyone model not initialized")

        temp_paths: List[str] = []
        try:
            img_path = self._write_temp_png(image, temp_paths)
            tri_path = self._write_temp_png(trimap, temp_paths)

            if self.use_hf_api:
                return self._process_hf_api(img_path, tri_path)
            return self._process_local(img_path, tri_path)

        except Exception as e:
            logger.error(f"MatAnyone matting failed: {e}")
            raise
        finally:
            for path in temp_paths:
                try:
                    os.unlink(path)
                except OSError as cleanup_error:
                    logger.warning(f"Could not remove temporary file {path}: {cleanup_error}")

    def _process_hf_api(self, image_path: str, trimap_path: str) -> np.ndarray:
        """Send the image/trimap files to the Space and load the result.

        NOTE(review): the endpoint name and keyword arguments are not verified
        against the Space's published API; confirm with ``client.view_api()``.

        Raises:
            RuntimeError: if the API returns a file path OpenCV cannot read.
        """
        try:
            result = self.client.predict(
                image=image_path,
                trimap=trimap_path,
                api_name="/predict"
            )

            # A string result is treated as a path to the output image.
            if isinstance(result, str):
                result_image = cv2.imread(result)
                if result_image is None:
                    raise RuntimeError(f"Could not read MatAnyone output: {result}")
                return result_image
            return result

        except Exception as e:
            logger.error(f"HF API processing failed: {e}")
            raise

    def _process_local(self, image_path: str, trimap_path: str) -> np.ndarray:
        """Process locally - placeholder for implementation"""
        raise NotImplementedError("Local MatAnyone processing not implemented")

    def get_model_name(self) -> str:
        return "MatAnyone"
|
|
class TraditionalCVModel(BaseSegmentationModel):
    """Traditional computer-vision fallback based on colour thresholds."""

    # (lower, upper) bounds in OpenCV 8-bit HSV (H in [0, 179], S/V in [0, 255]).
    HSV_HAIR_RANGES = (
        ((0, 0, 0), (180, 255, 80)),       # any hue with low value (dark pixels)
        ((8, 50, 20), (25, 255, 200)),     # orange/brown hues, presumably brown hair
        ((15, 30, 100), (35, 255, 255)),   # yellowish bright hues, presumably blonde hair
    )
    # Pixels with LAB lightness at or below this count as dark hair candidates.
    LAB_MAX_LIGHTNESS = 120

    def __init__(self):
        self.initialized = False

    def initialize(self) -> bool:
        """Nothing to load; always succeeds."""
        self.initialized = True
        return True

    @staticmethod
    def _to_bgr(frame: np.ndarray) -> np.ndarray:
        """Return a 3-channel BGR image; grayscale and BGRA inputs are converted."""
        if frame.ndim == 2:
            return cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR)
        if frame.shape[2] == 1:
            return cv2.cvtColor(frame[:, :, 0], cv2.COLOR_GRAY2BGR)
        if frame.shape[2] == 4:
            return cv2.cvtColor(frame, cv2.COLOR_BGRA2BGR)
        return frame

    def segment(self, frame: np.ndarray) -> np.ndarray:
        """Colour-threshold hair segmentation.

        Args:
            frame: BGR image (grayscale and BGRA are converted first).

        Returns:
            float32 mask in [0, 1] with the frame's height and width.
        """
        try:
            bgr = self._to_bgr(frame)
            hsv = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV)
            lab = cv2.cvtColor(bgr, cv2.COLOR_BGR2LAB)

            combined_mask = cv2.bitwise_or(
                self._detect_hair_hsv(hsv), self._detect_hair_lab(lab)
            )

            # Close small holes first, then remove isolated specks.
            kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
            combined_mask = cv2.morphologyEx(combined_mask, cv2.MORPH_CLOSE, kernel)
            combined_mask = cv2.morphologyEx(combined_mask, cv2.MORPH_OPEN, kernel)

            return combined_mask.astype(np.float32) / 255.0

        except Exception as e:
            logger.error(f"Traditional CV segmentation failed: {e}")
            raise

    def _detect_hair_hsv(self, hsv: np.ndarray) -> np.ndarray:
        """Return a uint8 mask (0/255): the union of all HSV_HAIR_RANGES."""
        final_mask = np.zeros(hsv.shape[:2], dtype=np.uint8)
        for lower, upper in self.HSV_HAIR_RANGES:
            final_mask = cv2.bitwise_or(
                final_mask, cv2.inRange(hsv, np.array(lower), np.array(upper))
            )
        return final_mask

    def _detect_hair_lab(self, lab: np.ndarray) -> np.ndarray:
        """Return a uint8 mask (0/255) of pixels with low LAB lightness."""
        l_channel = lab[:, :, 0]
        return cv2.inRange(l_channel, 0, self.LAB_MAX_LIGHTNESS)

    def get_model_name(self) -> str:
        return "TraditionalCV"
|
|
class TemporalSmoother:
    """Temporal smoothing for video mask sequences.

    When the mean absolute change from the previous mask exceeds
    ``change_threshold``, the output is
    ``smoothing_factor * current + (1 - smoothing_factor) * previous``.
    Otherwise the current mask is passed through unchanged.
    """

    def __init__(self, smoothing_factor: float = 0.7, change_threshold: float = 0.05):
        """
        Args:
            smoothing_factor: Weight of the current mask when blending.
            change_threshold: Mean absolute difference above which blending is applied.
        """
        self.smoothing_factor = smoothing_factor
        self.change_threshold = change_threshold
        self.previous_mask: Optional[np.ndarray] = None
        self.correction_count = 0
        self.total_frames = 0

    def reset(self, clear_stats: bool = False) -> None:
        """Forget the previous mask (e.g. at a sequence boundary).

        Args:
            clear_stats: Also zero the frame and correction counters.
        """
        self.previous_mask = None
        if clear_stats:
            self.correction_count = 0
            self.total_frames = 0

    def smooth(self, current_mask: np.ndarray) -> Tuple[np.ndarray, bool]:
        """Apply temporal smoothing to *current_mask*.

        A mask whose shape differs from the previous one (new video or a
        resolution change) cannot be blended. It is treated as the start of a
        new sequence instead of raising a broadcasting error.

        Returns:
            (mask, corrected), where corrected is True if blending was applied.
        """
        self.total_frames += 1
        previous = self.previous_mask
        corrected = False
        smoothed_mask = current_mask

        if previous is not None and previous.shape == current_mask.shape:
            diff = np.mean(np.abs(current_mask - previous))
            if diff > self.change_threshold:
                smoothed_mask = (self.smoothing_factor * current_mask +
                                 (1 - self.smoothing_factor) * previous)
                self.correction_count += 1
                corrected = True

        # Copy so later in-place edits by the caller cannot alter our history.
        self.previous_mask = smoothed_mask.copy()
        return smoothed_mask, corrected

    def get_correction_ratio(self) -> float:
        """Fraction of processed frames that needed correction (0 if none processed)."""
        return self.correction_count / max(self.total_frames, 1)
|
|
class HairSegmentationPipeline:
    """Main hair segmentation pipeline with a primary model and fallbacks.

    ``_setup_models`` registers the SAM2, MatAnyone and traditional-CV
    backends. ``initialize`` picks the primary ("active") model and collects
    every other initialised, segmentation-capable model as a fallback.
    ``segment_frame`` tries them in that order and scores the resulting mask.
    """

    def __init__(self, config: Optional[Dict] = None):
        """
        Args:
            config: Optional settings. Keys read by this class:
                ``sam2_model_path``, ``device``, ``use_hf_api``, ``hf_token``.
        """
        self.config = config or {}
        self.models = {}
        self.active_model = None
        self.fallback_models = []
        self.temporal_smoother = TemporalSmoother()
        self.initialized = False

        self._setup_models()

    def _setup_models(self):
        """Register (but do not initialise) every available backend."""
        try:
            self.models['sam2'] = SAM2Model(
                model_path=self.config.get('sam2_model_path'),
                device=self.config.get('device', 'auto')
            )
            self.models['matanyone'] = MatAnyoneModel(
                use_hf_api=self.config.get('use_hf_api', True),
                hf_token=self.config.get('hf_token')
            )
            self.models['traditional'] = TraditionalCVModel()
        except Exception as e:
            logger.error(f"Model setup failed: {e}")

    @staticmethod
    def _can_segment(model) -> bool:
        """False for models that declare ``supports_segmentation = False``."""
        # Models without the attribute are assumed to support segmentation.
        return bool(getattr(model, 'supports_segmentation', True))

    def initialize(self, preferred_model: str = 'sam2') -> bool:
        """Initialise the preferred model and every usable fallback.

        Safe to call repeatedly: previous selections are discarded first.

        Returns:
            True if at least one model can segment frames, otherwise False.
        """
        logger.info("🚀 Initializing Hair Segmentation Pipeline...")

        # Start clean so repeated calls do not duplicate fallback entries.
        self.active_model = None
        self.fallback_models = []
        self.initialized = False

        primary = self.models.get(preferred_model)
        if primary is None:
            logger.warning(f"⚠️ Primary model {preferred_model} is not registered")
        elif not self._can_segment(primary):
            logger.warning(f"⚠️ Primary model {preferred_model} cannot segment frames")
        elif primary.initialize():
            self.active_model = preferred_model
            logger.info(f"✅ Primary model {preferred_model} initialized")
        else:
            logger.warning(f"⚠️ Primary model {preferred_model} failed")

        for model_name, model in self.models.items():
            # The preferred model was handled above; retrying a failed
            # initialisation would only repeat the same error.
            if model_name == preferred_model or not model.initialize():
                continue
            if self._can_segment(model):
                self.fallback_models.append(model_name)
                logger.info(f"✅ Fallback model {model_name} ready")
            else:
                # Still initialised (e.g. for matting), but trying it per frame
                # as a segmenter would always fail.
                logger.info(f"ℹ️ Model {model_name} initialized (not used as a segmentation fallback)")

        if self.active_model or self.fallback_models:
            self.initialized = True
            logger.info(f"🎯 Pipeline ready - Active: {self.active_model}, Fallbacks: {self.fallback_models}")
            return True

        logger.error("❌ No working models available")
        return False

    def _candidate_models(self) -> List[str]:
        """Model names to try, in order: the active model, then the fallbacks."""
        ordered = [self.active_model] if self.active_model else []
        ordered.extend(name for name in self.fallback_models if name != self.active_model)
        return ordered

    def segment_frame(self, frame: np.ndarray,
                      apply_temporal_smoothing: bool = True) -> SegmentationResult:
        """Segment hair in a single frame.

        Args:
            frame: Image array (BGR in OpenCV convention).
            apply_temporal_smoothing: Blend with the previous frame's mask.

        Returns:
            SegmentationResult. If every model fails, the mask is all zeros and
            ``error_message`` starts with "All models failed".

        Raises:
            RuntimeError: if the pipeline has not been initialised.
            ValueError: if *frame* is None or has fewer than 2 dimensions.
        """
        if not self.initialized:
            raise RuntimeError("Pipeline not initialized")
        if frame is None or np.ndim(frame) < 2:
            raise ValueError("frame must be an image array with at least 2 dimensions")

        import time
        start_time = time.perf_counter()

        mask, model_used, error_msg = None, "none", None
        errors = []
        for model_name in self._candidate_models():
            mask, model_used, error_msg = self._try_segment_with_model(frame, model_name)
            if mask is not None:
                break
            errors.append(error_msg)

        if mask is None:
            h, w = frame.shape[:2]
            mask = np.zeros((h, w), dtype=np.float32)
            model_used = "none"
            error_msg = "All models failed" + (f": {'; '.join(errors)}" if errors else "")

        corrected = False
        if apply_temporal_smoothing:
            mask, corrected = self.temporal_smoother.smooth(mask)

        processing_time = time.perf_counter() - start_time
        confidence = self._calculate_confidence(mask)
        coverage = self._calculate_coverage(mask)
        asymmetry = self._calculate_asymmetry(mask)
        # Reuse the metrics above instead of recomputing them.
        quality = self._combine_quality(confidence, coverage, asymmetry)

        return SegmentationResult(
            mask=mask,
            confidence=confidence,
            coverage_percent=coverage,
            asymmetry_score=asymmetry,
            processing_time=processing_time,
            fallback_used=(model_used != self.active_model),
            quality_score=quality,
            error_message=error_msg
        )

    def _try_segment_with_model(self, frame: np.ndarray, model_name: str) -> Tuple[Optional[np.ndarray], str, Optional[str]]:
        """Run one model. Returns (mask or None, model_name, error message or None)."""
        if model_name not in self.models:
            return None, model_name, f"Model {model_name} not available"

        try:
            mask = self.models[model_name].segment(frame)
            return mask, model_name, None
        except Exception as e:
            error_msg = f"Model {model_name} failed: {str(e)}"
            logger.warning(error_msg)
            return None, model_name, error_msg

    def _calculate_confidence(self, mask: np.ndarray) -> float:
        """Heuristic confidence in [0, 1] from edge density and gradient smoothness.

        Uses OpenCV (Canny + Sobel) instead of scikit-image.
        """
        if mask.size == 0:
            return 0.0
        # Clip before the uint8 cast so out-of-range values cannot wrap around.
        mask_u8 = (np.clip(mask, 0.0, 1.0) * 255).astype(np.uint8)
        edges = cv2.Canny(mask_u8, 50, 150)
        edge_ratio = np.count_nonzero(edges) / mask.size

        grad_x = cv2.Sobel(mask, cv2.CV_64F, 1, 0, ksize=3)
        grad_y = cv2.Sobel(mask, cv2.CV_64F, 0, 1, ksize=3)
        smoothness = 1.0 / (1.0 + np.std(np.hypot(grad_x, grad_y)))

        return float(min(edge_ratio * 0.3 + smoothness * 0.7, 1.0))

    def _calculate_coverage(self, mask: np.ndarray) -> float:
        """Percentage of pixels whose mask value exceeds 0.5."""
        if mask.size == 0:
            return 0.0
        return float(np.count_nonzero(mask > 0.5) / mask.size * 100)

    def _calculate_asymmetry(self, mask: np.ndarray) -> float:
        """Mean absolute difference between the left half and the mirrored right half.

        For odd widths the centre column is excluded. Masks narrower than 2
        pixels have no halves to compare and score 0.
        """
        width = mask.shape[1]
        half = width // 2
        if half == 0:
            return 0.0
        left = mask[:, :half]
        right_mirrored = np.fliplr(mask[:, width - half:])
        return float(np.mean(np.abs(left - right_mirrored)))

    @staticmethod
    def _combine_quality(confidence: float, coverage_percent: float, asymmetry: float) -> float:
        """Weighted score: 50% confidence, 30% coverage fraction, 20% symmetry."""
        symmetry = 1.0 - min(asymmetry, 1.0)
        return float(confidence * 0.5 + (coverage_percent / 100.0) * 0.3 + symmetry * 0.2)

    def _calculate_quality(self, mask: np.ndarray) -> float:
        """Calculate the overall mask quality from its individual metrics."""
        return self._combine_quality(
            self._calculate_confidence(mask),
            self._calculate_coverage(mask),
            self._calculate_asymmetry(mask),
        )

    def get_pipeline_stats(self) -> Dict:
        """Return model selection and temporal-smoothing statistics."""
        return {
            'active_model': self.active_model,
            'fallback_models': list(self.fallback_models),
            'temporal_correction_ratio': self.temporal_smoother.get_correction_ratio(),
            'total_frames_processed': self.temporal_smoother.total_frames,
            'corrections_applied': self.temporal_smoother.correction_count
        }
|
|
| |
def create_pipeline(config: Optional[Dict] = None) -> HairSegmentationPipeline:
    """Build a HairSegmentationPipeline from *config* and initialise it.

    The boolean returned by ``initialize()`` is not checked here. Inspect
    ``pipeline.initialized`` to learn whether any model is usable.
    """
    seg_pipeline = HairSegmentationPipeline(config)
    seg_pipeline.initialize()
    return seg_pipeline
|
|
def segment_image(image_path: str, config: Optional[Dict] = None) -> SegmentationResult:
    """Segment hair in a single image file.

    The image is loaded before the pipeline is built, so a bad path fails fast
    instead of after the (potentially slow) model initialisation.

    Raises:
        FileNotFoundError: if *image_path* does not exist.
        ValueError: if the file exists but OpenCV cannot decode it.
    """
    if not Path(image_path).exists():
        raise FileNotFoundError(f"Image not found: {image_path}")
    # cv2.imread signals failure by returning None rather than raising.
    frame = cv2.imread(str(image_path))
    if frame is None:
        raise ValueError(f"Could not decode image: {image_path}")

    pipeline = create_pipeline(config)
    return pipeline.segment_frame(frame)
|
|
def segment_video_frames(video_frames: List[np.ndarray],
                         config: Optional[Dict] = None) -> List[SegmentationResult]:
    """Segment hair in a sequence of video frames.

    One pipeline is shared by all frames, so its temporal smoother carries
    state from each frame to the next.
    """
    seg_pipeline = create_pipeline(config)
    return [seg_pipeline.segment_frame(frame) for frame in video_frames]
|
|
| |
def main() -> None:
    """Smoke-test the pipeline on a random frame and print results and stats."""
    config = {
        'sam2_model_path': None,
        'device': 'auto',
        'use_hf_api': True,
        'hf_token': None
    }

    pipeline = create_pipeline(config)
    if not pipeline.initialized:
        # segment_frame() would raise RuntimeError; report it cleanly instead.
        print("No segmentation model could be initialized; see the log for details.")
        raise SystemExit(1)

    # Seeded for reproducible runs. The upper bound is exclusive, so 256
    # covers the full uint8 range.
    rng = np.random.default_rng(0)
    test_frame = rng.integers(0, 256, size=(480, 640, 3), dtype=np.uint8)

    result = pipeline.segment_frame(test_frame)

    print("Segmentation Results:")
    print(f" Coverage: {result.coverage_percent:.1f}%")
    print(f" Confidence: {result.confidence:.3f}")
    print(f" Quality: {result.quality_score:.3f}")
    print(f" Processing time: {result.processing_time:.2f}s")
    print(f" Fallback used: {result.fallback_used}")
    if result.error_message:
        print(f" Error: {result.error_message}")

    stats = pipeline.get_pipeline_stats()
    print("\nPipeline Stats:")
    print(f" Active model: {stats['active_model']}")
    print(f" Fallbacks: {stats['fallback_models']}")
    print(f" Frames processed: {stats['total_frames_processed']}")
    print(f" Correction ratio: {stats['temporal_correction_ratio']:.3f}")


if __name__ == "__main__":
    main()