| """ |
| Main processing pipeline for BackgroundFX Pro. |
| Orchestrates the complete background removal and replacement workflow. |
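
Example (illustrative sketch; the input and output file names are placeholders):

    pipeline = ProcessingPipeline(PipelineConfig(quality_preset="high"))
    result = pipeline.process_image("portrait.jpg", background="studio.jpg")
    if result.success:
        cv2.imwrite("output.png", result.output_image)
    pipeline.shutdown()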
| """ |
|
|
| import cv2 |
| import numpy as np |
| import torch |
| from typing import Dict, List, Optional, Tuple, Union, Callable, Any |
| from dataclasses import dataclass, field |
| from enum import Enum |
| from pathlib import Path |
import time
import json
import hashlib
from concurrent.futures import ThreadPoolExecutor
|
|
| from ..utils.logger import setup_logger |
| from ..utils.device import DeviceManager |
| from ..utils.config import ConfigManager |
| from ..utils import TimeEstimator, MemoryMonitor |
|
|
| from ..core.models import ModelFactory, ModelType |
| from ..core.temporal import TemporalCoherence |
| from ..core.quality import QualityAnalyzer |
| from ..core.edge import EdgeRefinement |
| from ..core.hair_segmentation import HairSegmentation |
|
|
| from ..processing.matting import AlphaMatting, MattingConfig, CompositingEngine |
| from ..processing.fallback import FallbackStrategy, FallbackLevel |
| from ..processing.effects import BackgroundEffects, CompositeEffects, EffectType |
|
|
| logger = setup_logger(__name__) |
|
|
| class ProcessingMode(Enum): |
| """Processing mode types.""" |
| PHOTO = "photo" |
| VIDEO = "video" |
| REALTIME = "realtime" |
| BATCH = "batch" |
|
|
| class PipelineStage(Enum): |
| """Pipeline processing stages.""" |
| INITIALIZATION = "initialization" |
| PREPROCESSING = "preprocessing" |
| SEGMENTATION = "segmentation" |
| MATTING = "matting" |
| REFINEMENT = "refinement" |
| EFFECTS = "effects" |
| COMPOSITING = "compositing" |
| POSTPROCESSING = "postprocessing" |
| COMPLETE = "complete" |
|
|
| @dataclass |
| class PipelineConfig: |
| """Configuration for the processing pipeline.""" |
| |
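    # Model selection and device placement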
| model_type: ModelType = ModelType.RMBG_1_4 |
| use_gpu: bool = True |
| device: Optional[str] = None |
| |
| |
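    # Processing mode and feature toggles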
| mode: ProcessingMode = ProcessingMode.PHOTO |
| enable_temporal: bool = True |
| enable_hair_refinement: bool = True |
| enable_edge_refinement: bool = True |
| enable_fallback: bool = True |
| |
| |
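    # Quality preset and working resolution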
| quality_preset: str = "high" |
| target_resolution: Optional[Tuple[int, int]] = None |
| maintain_aspect_ratio: bool = True |
| |
| |
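    # Alpha matting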
| matting_method: str = "auto" |
| matting_config: MattingConfig = field(default_factory=MattingConfig) |
| |
| |
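    # Background treatment and optional effects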
| background_blur: bool = False |
| blur_strength: float = 15.0 |
| apply_effects: List[EffectType] = field(default_factory=list) |
| |
| |
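    # Performance and caching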
| batch_size: int = 1 |
| num_workers: int = 4 |
| enable_caching: bool = True |
| cache_size_mb: int = 500 |
| |
| |
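    # Output settings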
| output_format: str = "png" |
| output_quality: int = 95 |
| preserve_metadata: bool = True |
| |
| |
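    # Progress reporting callbacks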
| progress_callback: Optional[Callable[[float, str], None]] = None |
| stage_callback: Optional[Callable[[PipelineStage, Dict], None]] = None |
|
|
| @dataclass |
| class PipelineResult: |
| """Result from pipeline processing.""" |
| success: bool |
| output_image: Optional[np.ndarray] = None |
| alpha_matte: Optional[np.ndarray] = None |
| foreground: Optional[np.ndarray] = None |
| background: Optional[np.ndarray] = None |
| metadata: Dict[str, Any] = field(default_factory=dict) |
| processing_time: float = 0.0 |
| stages_completed: List[PipelineStage] = field(default_factory=list) |
| errors: List[str] = field(default_factory=list) |
| quality_score: float = 0.0 |
|
|
| class ProcessingPipeline: |
| """Main processing pipeline orchestrator.""" |
| |
| def __init__(self, config: Optional[PipelineConfig] = None): |
| """ |
| Initialize the processing pipeline. |
| |
| Args: |
| config: Pipeline configuration |
| """ |
| self.config = config or PipelineConfig() |
| self.logger = setup_logger(f"{__name__}.ProcessingPipeline") |
| |
| |
| self._initialize_components() |
| |
| |
| self.current_stage = PipelineStage.INITIALIZATION |
| self.processing_stats = {} |
| self.cache = {} |
| self.is_processing = False |
| |
| |
| self.executor = ThreadPoolExecutor(max_workers=self.config.num_workers) |
| |
| self.logger.info("Pipeline initialized successfully") |
| |
| def _initialize_components(self): |
| """Initialize all pipeline components.""" |
| try: |
| |
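            # Device selection (an explicit device overrides use_gpu)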
| self.device_manager = DeviceManager() |
| if self.config.device: |
| self.device_manager.set_device(self.config.device) |
| elif not self.config.use_gpu: |
| self.device_manager.set_device('cpu') |
| |
| |
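            # Core analysis and refinement components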
| self.model_factory = ModelFactory() |
| self.quality_analyzer = QualityAnalyzer() |
| self.edge_refinement = EdgeRefinement() |
| self.temporal_coherence = TemporalCoherence() if self.config.enable_temporal else None |
| self.hair_segmentation = HairSegmentation() if self.config.enable_hair_refinement else None |
| |
| |
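            # Matting, compositing, and effects engines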
| self.alpha_matting = AlphaMatting(self.config.matting_config) |
| self.compositing_engine = CompositingEngine() |
| self.background_effects = BackgroundEffects() |
| self.composite_effects = CompositeEffects() |
| |
| |
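            # Optional fallback strategy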
| self.fallback_strategy = FallbackStrategy() if self.config.enable_fallback else None |
| |
| |
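            # Resource monitoring and time estimation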
| self.memory_monitor = MemoryMonitor() |
| self.time_estimator = TimeEstimator() |
| |
| |
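            # Load the segmentation model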
| self._load_model() |
| |
| except Exception as e: |
| self.logger.error(f"Component initialization failed: {e}") |
| raise |
| |
| def _load_model(self): |
| """Load the segmentation model.""" |
| try: |
| self.logger.info(f"Loading model: {self.config.model_type.value}") |
| |
| self.model = self.model_factory.load_model( |
| self.config.model_type, |
| device=self.device_manager.get_device(), |
| optimize=True |
| ) |
| |
| self.logger.info("Model loaded successfully") |
| |
        except Exception as e:
            self.logger.error(f"Model loading failed: {e}")
            if self.config.enable_fallback:
                self.logger.info("Attempting fallback model loading")
                self.config.model_type = ModelType.U2NET_LITE
                self.model = self.model_factory.load_model(
                    self.config.model_type,
                    device='cpu'
                )
            else:
                raise
| |
| def process_image(self, |
| image: Union[np.ndarray, str, Path], |
| background: Optional[Union[np.ndarray, str, Path]] = None, |
| **kwargs) -> PipelineResult: |
| """ |
| Process a single image through the pipeline. |
| |
| Args: |
| image: Input image (array or path) |
| background: Optional background image/path |
| **kwargs: Additional processing parameters |
| |
| Returns: |
| PipelineResult with processed image and metadata |
| """ |
        start_time = time.time()
        self.is_processing = True
        result = PipelineResult(success=False)
        image_array = None
        bg_array = None
| |
| try: |
| |
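            # Load and validate inputs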
| self._update_stage(PipelineStage.INITIALIZATION) |
| image_array = self._load_image(image) |
| bg_array = self._load_image(background) if background is not None else None |
| |
| |
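            # Cache key derived from the input pixels and call parameters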
| cache_key = self._generate_cache_key(image_array, kwargs) |
| |
| |
| if self.config.enable_caching and cache_key in self.cache: |
| self.logger.info("Using cached result") |
| cached_result = self.cache[cache_key] |
| cached_result.processing_time = time.time() - start_time |
| return cached_result |
| |
| |
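            # Preprocess (resize / enhance) before segmentation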
| self._update_stage(PipelineStage.PREPROCESSING) |
| preprocessed = self._preprocess_image(image_array) |
| result.metadata['original_size'] = image_array.shape[:2] |
| result.metadata['preprocessed_size'] = preprocessed.shape[:2] |
| |
| |
| quality_metrics = self.quality_analyzer.analyze_frame(preprocessed) |
| result.metadata['quality_metrics'] = quality_metrics |
| |
| |
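            # Primary foreground segmentation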
| self._update_stage(PipelineStage.SEGMENTATION) |
| segmentation_mask = self._segment_image(preprocessed) |
| |
| |
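            # Optional hair-aware mask refinement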
| if self.config.enable_hair_refinement: |
| self.logger.info("Applying hair refinement") |
| hair_mask = self.hair_segmentation.segment_hair(preprocessed) |
| segmentation_mask = self._combine_masks(segmentation_mask, hair_mask) |
| |
| |
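            # Alpha matting from the segmentation mask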
| self._update_stage(PipelineStage.MATTING) |
| matting_result = self.alpha_matting.process( |
| preprocessed, |
| segmentation_mask, |
| method=self.config.matting_method |
| ) |
| alpha_matte = matting_result['alpha'] |
| result.metadata['matting_confidence'] = matting_result['confidence'] |
| |
| |
| self._update_stage(PipelineStage.REFINEMENT) |
| if self.config.enable_edge_refinement: |
| alpha_matte = self.edge_refinement.refine_edges( |
| preprocessed, |
| (alpha_matte * 255).astype(np.uint8) |
| ) / 255.0 |
| |
| |
| if preprocessed.shape[:2] != image_array.shape[:2]: |
| alpha_matte = cv2.resize( |
| alpha_matte, |
| (image_array.shape[1], image_array.shape[0]), |
| interpolation=cv2.INTER_LINEAR |
| ) |
| |
| |
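            # Extract the foreground layer at the original resolution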
| foreground = self._extract_foreground(image_array, alpha_matte) |
| |
| |
| self._update_stage(PipelineStage.EFFECTS) |
| |
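            # Prepare the background: resize to fit, optional blur and effects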
| if bg_array is not None: |
| |
| bg_array = self._resize_background(bg_array, image_array.shape[:2]) |
| |
| |
| if self.config.background_blur: |
| bg_array = self.background_effects.apply_blur( |
| bg_array, |
| strength=self.config.blur_strength, |
| mask=1 - alpha_matte |
| ) |
| |
| |
| if self.config.apply_effects: |
| bg_array = self._apply_effects(bg_array, alpha_matte) |
| else: |
| |
| bg_array = np.zeros_like(image_array) |
| |
| |
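            # Composite the foreground over the prepared background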
| self._update_stage(PipelineStage.COMPOSITING) |
| |
            if EffectType.LIGHT_WRAP in self.config.apply_effects:
| foreground = self.background_effects.apply_light_wrap( |
| foreground, bg_array, alpha_matte |
| ) |
| |
| composited = self.compositing_engine.composite( |
| foreground, bg_array, alpha_matte |
| ) |
| |
| |
| if self.config.apply_effects: |
| composited = self._apply_post_effects(composited, alpha_matte) |
| |
| |
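            # Final color and sharpening pass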
| self._update_stage(PipelineStage.POSTPROCESSING) |
| final_output = self._postprocess_image(composited, alpha_matte) |
| |
| |
| result.quality_score = self._calculate_quality_score( |
| final_output, alpha_matte, quality_metrics |
| ) |
| |
| |
| result.success = True |
| result.output_image = final_output |
| result.alpha_matte = alpha_matte |
| result.foreground = foreground |
| result.background = bg_array |
| result.stages_completed = list(PipelineStage) |
| result.processing_time = time.time() - start_time |
| |
| |
| if self.config.enable_caching: |
| self._cache_result(cache_key, result) |
| |
| |
| self._update_stage(PipelineStage.COMPLETE) |
| self.logger.info(f"Processing completed in {result.processing_time:.2f}s") |
| |
| |
| self._update_statistics(result) |
| |
| except Exception as e: |
| self.logger.error(f"Pipeline processing failed: {e}") |
| result.errors.append(str(e)) |
| |
            if self.config.enable_fallback and self.fallback_strategy and image_array is not None:
                self.logger.info("Attempting fallback processing")
                result = self._fallback_processing(image_array, bg_array)
| |
| finally: |
| self.is_processing = False |
| |
| return result |
| |
| def _preprocess_image(self, image: np.ndarray) -> np.ndarray: |
| """Preprocess image for optimal processing.""" |
| processed = image.copy() |
| |
| |
| if self.config.target_resolution: |
| target_h, target_w = self.config.target_resolution |
| h, w = image.shape[:2] |
| |
| if self.config.maintain_aspect_ratio: |
| scale = min(target_w / w, target_h / h) |
| new_w = int(w * scale) |
| new_h = int(h * scale) |
| else: |
| new_w, new_h = target_w, target_h |
| |
| if (new_w, new_h) != (w, h): |
| processed = cv2.resize(processed, (new_w, new_h), |
| interpolation=cv2.INTER_AREA) |
| |
| |
        if self.config.quality_preset == "low":
            # Fast denoising for the low-quality preset (color variant for BGR input)
            processed = cv2.fastNlMeansDenoisingColored(processed, None, 10, 10, 7, 21)
        elif self.config.quality_preset in ["high", "ultra"]:
            # Detail enhancement for the high-quality presets
            processed = cv2.detailEnhance(processed, sigma_s=10, sigma_r=0.15)
| |
| return processed |
| |
| def _segment_image(self, image: np.ndarray) -> np.ndarray: |
| """Perform image segmentation.""" |
| try: |
| |
| with torch.no_grad(): |
| |
| input_tensor = self._prepare_input_tensor(image) |
| |
| |
| output = self.model(input_tensor) |
| |
| |
| if isinstance(output, tuple): |
| output = output[0] |
| |
| |
| mask = output.squeeze().cpu().numpy() |
| |
| |
| mask = (mask > 0.5).astype(np.uint8) * 255 |
| |
| |
                if mask.shape[:2] != image.shape[:2]:
                    mask = cv2.resize(mask, (image.shape[1], image.shape[0]),
                                      interpolation=cv2.INTER_NEAREST)
| |
| return mask |
| |
| except Exception as e: |
| self.logger.error(f"Segmentation failed: {e}") |
| if self.config.enable_fallback: |
| |
| from ..processing.fallback import ProcessingFallback |
| fallback = ProcessingFallback() |
| return fallback.basic_segmentation(image) |
| raise |
| |
| def _prepare_input_tensor(self, image: np.ndarray) -> torch.Tensor: |
| """Prepare image tensor for model input.""" |
| |
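        # NOTE: a fixed 512x512 input scaled to [0, 1] is assumed here; no channel
        # reordering or mean/std normalization is applied, which must match what
        # the loaded model expects.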
| model_size = 512 |
| resized = cv2.resize(image, (model_size, model_size)) |
| |
| |
| tensor = torch.from_numpy(resized.transpose(2, 0, 1)).float() |
| tensor = tensor.unsqueeze(0) / 255.0 |
| |
| |
| tensor = tensor.to(self.device_manager.get_device()) |
| |
| return tensor |
| |
| def _combine_masks(self, mask1: np.ndarray, mask2: np.ndarray) -> np.ndarray: |
| """Combine two masks intelligently.""" |
| |
| m1 = mask1.astype(np.float32) / 255.0 |
| m2 = mask2.astype(np.float32) / 255.0 |
| |
| |
| combined = np.maximum(m1, m2) |
| |
| |
| return (combined * 255).astype(np.uint8) |
| |
| def _extract_foreground(self, image: np.ndarray, |
| alpha: np.ndarray) -> np.ndarray: |
| """Extract foreground using alpha matte.""" |
| if len(alpha.shape) == 2: |
| alpha = np.expand_dims(alpha, axis=2) |
| |
| if alpha.shape[2] == 1: |
| alpha = np.repeat(alpha, 3, axis=2) |
| |
| |
| foreground = image.astype(np.float32) * alpha |
| |
| return foreground.astype(np.uint8) |
| |
| def _resize_background(self, background: np.ndarray, |
| target_shape: Tuple[int, int]) -> np.ndarray: |
| """Resize background to match target shape.""" |
| h, w = target_shape |
| bg_h, bg_w = background.shape[:2] |
| |
| if (bg_h, bg_w) == (h, w): |
| return background |
| |
| |
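        # Scale the background to cover the target area, then center-crop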
| scale = max(h / bg_h, w / bg_w) |
| new_h = int(bg_h * scale) |
| new_w = int(bg_w * scale) |
| |
| |
| resized = cv2.resize(background, (new_w, new_h), |
| interpolation=cv2.INTER_LINEAR) |
| |
| |
| start_y = (new_h - h) // 2 |
| start_x = (new_w - w) // 2 |
| cropped = resized[start_y:start_y + h, start_x:start_x + w] |
| |
| return cropped |
| |
| def _apply_effects(self, image: np.ndarray, |
| mask: np.ndarray) -> np.ndarray: |
| """Apply configured effects to image.""" |
| result = image.copy() |
| |
| for effect in self.config.apply_effects: |
| if effect == EffectType.BOKEH: |
| result = self.background_effects.apply_bokeh(result) |
| elif effect == EffectType.VIGNETTE: |
| result = self.background_effects.add_vignette(result) |
| elif effect == EffectType.FILM_GRAIN: |
| result = self.background_effects.add_film_grain(result) |
| |
| return result |
| |
| def _apply_post_effects(self, image: np.ndarray, |
| mask: np.ndarray) -> np.ndarray: |
| """Apply post-composite effects.""" |
| result = image.copy() |
| |
| for effect in self.config.apply_effects: |
| if effect == EffectType.SHADOW: |
| result = self.background_effects.add_shadow(result, mask) |
| elif effect == EffectType.REFLECTION: |
| result = self.background_effects.add_reflection(result, mask) |
| elif effect == EffectType.GLOW: |
| result = self.background_effects.add_glow(result, mask) |
| elif effect == EffectType.CHROMATIC_ABERRATION: |
| result = self.background_effects.chromatic_aberration(result) |
| |
| return result |
| |
| def _postprocess_image(self, image: np.ndarray, |
| alpha: np.ndarray) -> np.ndarray: |
| """Apply final postprocessing.""" |
| result = image.copy() |
| |
| |
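        # Contrast enhancement on the LAB luminance channel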
| if self.config.quality_preset in ["high", "ultra"]: |
| |
| lab = cv2.cvtColor(result, cv2.COLOR_BGR2LAB) |
| l, a, b = cv2.split(lab) |
| l = cv2.equalizeHist(l) |
| result = cv2.cvtColor(cv2.merge([l, a, b]), cv2.COLOR_LAB2BGR) |
| |
| |
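        # 3x3 sharpening kernel for the ultra preset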
| if self.config.quality_preset == "ultra": |
| kernel = np.array([[-1,-1,-1], |
| [-1, 9,-1], |
| [-1,-1,-1]]) |
| result = cv2.filter2D(result, -1, kernel) |
| |
| return result |
| |
| def _calculate_quality_score(self, image: np.ndarray, |
| alpha: np.ndarray, |
| metrics: Dict) -> float: |
| """Calculate overall quality score.""" |
| scores = [] |
| |
| |
| edge_score = metrics.get('edge_clarity', 0.5) |
| scores.append(edge_score) |
| |
| |
| alpha_std = np.std(alpha) |
| alpha_score = min(alpha_std * 2, 1.0) |
| scores.append(alpha_score) |
| |
| |
| quality_score = metrics.get('overall_quality', 0.5) |
| scores.append(quality_score) |
| |
| return np.mean(scores) |
| |
| def _load_image(self, source: Union[np.ndarray, str, Path]) -> np.ndarray: |
| """Load image from various sources.""" |
| if isinstance(source, np.ndarray): |
| return source |
| |
| path = Path(source) if not isinstance(source, Path) else source |
| if not path.exists(): |
| raise FileNotFoundError(f"Image not found: {path}") |
| |
| image = cv2.imread(str(path)) |
| if image is None: |
| raise ValueError(f"Failed to load image: {path}") |
| |
| return image |
| |
| def _generate_cache_key(self, image: np.ndarray, |
| params: Dict) -> str: |
| """Generate cache key for result.""" |
| |
| hasher = hashlib.md5() |
| hasher.update(image.tobytes()) |
        hasher.update(json.dumps(params, sort_keys=True, default=str).encode())
| return hasher.hexdigest() |
| |
| def _cache_result(self, key: str, result: PipelineResult): |
| """Cache processing result.""" |
| self.cache[key] = result |
| |
| |
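        # Approximate cache footprint from composited outputs only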
| cache_memory = sum( |
| r.output_image.nbytes if r.output_image is not None else 0 |
| for r in self.cache.values() |
| ) |
| |
| max_bytes = self.config.cache_size_mb * 1024 * 1024 |
| |
| if cache_memory > max_bytes: |
| |
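            # Evict the oldest quarter of entries (dicts preserve insertion order)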
| for old_key in list(self.cache.keys())[:len(self.cache)//4]: |
| del self.cache[old_key] |
| |
| def _update_stage(self, stage: PipelineStage): |
| """Update current processing stage.""" |
| self.current_stage = stage |
| |
| if self.config.stage_callback: |
| self.config.stage_callback(stage, { |
| 'timestamp': time.time(), |
| 'memory_usage': self.memory_monitor.get_usage() |
| }) |
| |
| if self.config.progress_callback: |
            progress = (list(PipelineStage).index(stage) + 1) / len(PipelineStage)
| self.config.progress_callback(progress, stage.value) |
| |
| def _update_statistics(self, result: PipelineResult): |
| """Update processing statistics.""" |
| if 'total_processed' not in self.processing_stats: |
| self.processing_stats['total_processed'] = 0 |
| self.processing_stats['total_time'] = 0 |
| self.processing_stats['avg_quality'] = 0 |
| |
| self.processing_stats['total_processed'] += 1 |
| self.processing_stats['total_time'] += result.processing_time |
| self.processing_stats['avg_time'] = ( |
| self.processing_stats['total_time'] / |
| self.processing_stats['total_processed'] |
| ) |
| |
| |
| n = self.processing_stats['total_processed'] |
| old_avg = self.processing_stats['avg_quality'] |
| self.processing_stats['avg_quality'] = ( |
| (old_avg * (n - 1) + result.quality_score) / n |
| ) |
| |
| def _fallback_processing(self, image: np.ndarray, |
| background: Optional[np.ndarray]) -> PipelineResult: |
| """Fallback processing when main pipeline fails.""" |
| from ..processing.fallback import ProcessingFallback |
| |
| result = PipelineResult(success=False) |
| fallback = ProcessingFallback() |
| |
| try: |
| |
| mask = fallback.basic_segmentation(image) |
| |
| |
| alpha = fallback.basic_matting(image, mask) |
| |
| |
| if background is not None: |
| background = self._resize_background(background, image.shape[:2]) |
| output = self.compositing_engine.composite( |
| image, background, alpha |
| ) |
| else: |
| output = image |
| |
| result.success = True |
| result.output_image = output |
| result.alpha_matte = alpha |
| result.metadata['fallback_used'] = True |
| |
| except Exception as e: |
| self.logger.error(f"Fallback processing also failed: {e}") |
| result.errors.append(str(e)) |
| |
| return result |
| |
| def process_batch(self, images: List[Union[np.ndarray, str, Path]], |
| background: Optional[Union[np.ndarray, str, Path]] = None, |
| **kwargs) -> List[PipelineResult]: |
| """ |
| Process multiple images in batch. |
| |
| Args: |
| images: List of input images |
| background: Optional background for all images |
| **kwargs: Additional processing parameters |
| |
| Returns: |
| List of PipelineResults |
| """ |
| results = [] |
| total = len(images) |
| |
| self.logger.info(f"Processing batch of {total} images") |
| |
| |
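        # Submit all images to the shared executor (a single model instance serves all workers)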
| futures = [] |
| for i, image in enumerate(images): |
| future = self.executor.submit( |
| self.process_image, image, background, **kwargs |
| ) |
| futures.append(future) |
| |
| |
| for i, future in enumerate(futures): |
| try: |
| result = future.result(timeout=30) |
| results.append(result) |
| |
| if self.config.progress_callback: |
| progress = (i + 1) / total |
| self.config.progress_callback( |
| progress, |
| f"Processed {i + 1}/{total}" |
| ) |
| |
| except Exception as e: |
| self.logger.error(f"Batch item {i} failed: {e}") |
| results.append(PipelineResult( |
| success=False, |
| errors=[str(e)] |
| )) |
| |
| return results |
| |
| def get_statistics(self) -> Dict[str, Any]: |
| """Get processing statistics.""" |
| return { |
| **self.processing_stats, |
| 'cache_size': len(self.cache), |
| 'current_stage': self.current_stage.value, |
| 'is_processing': self.is_processing, |
| 'device': str(self.device_manager.get_device()), |
| 'model_type': self.config.model_type.value |
| } |
| |
| def clear_cache(self): |
| """Clear the result cache.""" |
| self.cache.clear() |
| self.logger.info("Cache cleared") |
| |
| def shutdown(self): |
| """Shutdown the pipeline and cleanup resources.""" |
| self.executor.shutdown(wait=True) |
| self.clear_cache() |
| |
| |
        # Release the model and any cached GPU memory
        if hasattr(self, 'model'):
            del self.model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
| |
| self.logger.info("Pipeline shutdown complete") |