| |
| """ |
| Fallback strategies for BackgroundFX Pro. |
| Implements robust fallback mechanisms when primary processing fails. |
| """ |
|
|
| import cv2 |
| import numpy as np |
| import torch |
| from typing import Dict, List, Optional, Tuple, Any |
| from dataclasses import dataclass |
| from enum import Enum |
| import logging |
| import traceback |
|
|
| |
| from utils.logger import setup_logger |
| from utils.device import DeviceManager |
| from utils.config import ConfigManager |
| from core.quality import QualityAnalyzer |
|
|
# Module-wide logger, configured through the project's setup_logger helper.
logger = setup_logger(__name__)
|
|
class FallbackLevel(Enum):
    """Severity of the degradation applied after a processing failure.

    Higher values mean more aggressive fallbacks; ``PASSTHROUGH`` means
    processing was abandoned and the input is returned unmodified.
    """
    NONE = 0                 # primary pipeline succeeded; no fallback applied
    QUALITY_REDUCTION = 1    # resolution/quality reduced (memory errors)
    METHOD_SWITCH = 2        # device or model variant switched (GPU/model errors)
    BASIC_PROCESSING = 3     # refinement/temporal extras disabled (timeouts)
    MINIMAL_PROCESSING = 4   # fast mode, refinement skipped (repeated generic errors)
    PASSTHROUGH = 5          # give up: original input returned as-is
|
|
@dataclass
class FallbackConfig:
    """Tunable parameters controlling retry and degradation behaviour."""
    max_retries: int = 3                     # attempts before the final passthrough fallback
    quality_reduction_factor: float = 0.75   # multiplier applied to resolution/quality per fallback
    min_quality: float = 0.3                 # floor for the 'quality' kwarg after reductions
    enable_caching: bool = True              # NOTE(review): not read by FallbackStrategy in this file — confirm usage elsewhere
    cache_size: int = 10                     # NOTE(review): unused in the visible code as well
    timeout_seconds: float = 30.0            # NOTE(review): unused here; presumably enforced by callers — confirm
    gpu_fallback_to_cpu: bool = True         # allow switching the torch device to CPU on GPU errors
    progressive_downscale: bool = True       # allow resizing image args on memory errors
    min_resolution: Tuple[int, int] = (320, 240)  # (width, height) floor for downscaling
|
|
class FallbackStrategy:
    """Retry a processing callable with progressively degraded arguments.

    ``execute_with_fallback`` runs the callable up to ``config.max_retries``
    times.  After each failure the raised error is classified by keyword
    (GPU, memory, timeout, model, generic) and the arguments are degraded
    accordingly before the next attempt.  If every attempt fails, a
    passthrough result is returned instead of raising.
    """

    def __init__(self, config: Optional[FallbackConfig] = None):
        self.config = config or FallbackConfig()
        self.device_manager = DeviceManager()
        self.quality_analyzer = QualityAnalyzer()
        self.cache: Dict[str, Any] = {}  # reserved; not consulted by the methods below
        self.fallback_history: List[Dict[str, Any]] = []  # one entry per failed attempt
        self.current_level = FallbackLevel.NONE

    def execute_with_fallback(self, func, *args, **kwargs) -> Dict[str, Any]:
        """Run *func* with retries, degrading ``args``/``kwargs`` on failure.

        Returns:
            On success: ``{'success': True, 'result', 'attempts',
            'fallback_level'}`` where ``fallback_level`` is the degradation
            level that actually produced the result.
            On exhaustion: the passthrough dict from ``_final_fallback``.
        """
        attempt = 0
        last_error: Optional[Exception] = None
        # Pristine reference, used only by the final passthrough fallback.
        original_args = args

        while attempt < self.config.max_retries:
            try:
                logger.info(f"Attempt {attempt + 1}/{self.config.max_retries} for {func.__name__}")
                result = func(*args, **kwargs)
                # BUGFIX: capture the level that produced the result before
                # resetting it (previously the reset happened first, so
                # callers always saw FallbackLevel.NONE even after a
                # degraded success).
                level_used = self.current_level
                self.current_level = FallbackLevel.NONE
                return {
                    'success': True,
                    'result': result,
                    'attempts': attempt + 1,
                    'fallback_level': level_used
                }
            except Exception as e:
                last_error = e
                logger.warning(f"Attempt {attempt + 1} failed: {str(e)}")
                # BUGFIX: degrade the *current* args/kwargs rather than the
                # originals, so successive fallbacks compound (e.g. a second
                # memory error shrinks the already-shrunk image) and the
                # original kwargs are never mutated by the handlers.
                fallback_result = self._apply_fallback(func, e, attempt, args, dict(kwargs))
                if fallback_result['handled']:
                    args = fallback_result.get('new_args', args)
                    kwargs = fallback_result.get('new_kwargs', kwargs)
                else:
                    break
            attempt += 1

        logger.error(f"All attempts failed for {func.__name__}")
        return self._final_fallback(func, last_error, original_args)

    def _apply_fallback(self, func, error: Exception, attempt: int, args: tuple, kwargs: dict) -> Dict[str, Any]:
        """Classify *error* and dispatch to the matching degradation handler.

        Returns a dict with key ``handled`` plus optional ``new_args`` /
        ``new_kwargs`` for the next attempt.
        """
        error_type = type(error).__name__
        self.fallback_history.append({
            'function': func.__name__,
            'error': error_type,
            'attempt': attempt
        })

        message = str(error)
        # Keyword-based classification; CUDA/GPU is checked first so a CUDA
        # out-of-memory error goes to the device handler, not the generic
        # memory handler.
        if 'CUDA' in message or 'GPU' in message:
            return self._handle_gpu_error(kwargs)
        elif 'memory' in message.lower():
            return self._handle_memory_error(args, kwargs)
        elif 'timeout' in message.lower():
            return self._handle_timeout_error(kwargs)
        elif 'model' in message.lower():
            return self._handle_model_error(kwargs)
        else:
            return self._handle_generic_error(attempt, kwargs)

    def _handle_gpu_error(self, kwargs: dict) -> Dict[str, Any]:
        """Move processing to CPU (and halve any batch size) after a GPU failure."""
        logger.info("GPU error detected, falling back to CPU")
        if self.config.gpu_fallback_to_cpu:
            self.device_manager.device = torch.device('cpu')
            kwargs['device'] = 'cpu'
            if 'batch_size' in kwargs:
                # Halve the batch (never below 1) to relieve memory pressure.
                kwargs['batch_size'] = max(1, kwargs['batch_size'] // 2)
            self.current_level = FallbackLevel.METHOD_SWITCH
            return {
                'handled': True,
                'new_kwargs': kwargs
            }
        return {'handled': False}

    def _handle_memory_error(self, args: tuple, kwargs: dict) -> Dict[str, Any]:
        """Shrink the first image-like positional arg, or reduce the 'quality' kwarg."""
        logger.info("Memory error detected, reducing quality")
        image = None
        image_idx = -1
        for i, arg in enumerate(args):
            # Treat the first 3-D ndarray (assumed H x W x C) as the image payload.
            if isinstance(arg, np.ndarray) and len(arg.shape) == 3:
                image = arg
                image_idx = i
                break
        if image is not None and self.config.progressive_downscale:
            h, w = image.shape[:2]
            new_h = int(h * self.config.quality_reduction_factor)
            new_w = int(w * self.config.quality_reduction_factor)
            # Never shrink below the configured (width, height) floor.
            new_h = max(new_h, self.config.min_resolution[1])
            new_w = max(new_w, self.config.min_resolution[0])
            if new_h < h or new_w < w:
                resized = cv2.resize(image, (new_w, new_h))
                new_args = list(args)
                new_args[image_idx] = resized
                self.current_level = FallbackLevel.QUALITY_REDUCTION
                return {
                    'handled': True,
                    'new_args': tuple(new_args),
                    'new_kwargs': kwargs
                }
        if 'quality' in kwargs:
            kwargs['quality'] = max(
                self.config.min_quality,
                kwargs['quality'] * self.config.quality_reduction_factor
            )
        # BUGFIX: record the degradation level on this path too (previously it
        # was left at its prior value even though quality was reduced, unlike
        # every sibling handler).
        self.current_level = FallbackLevel.QUALITY_REDUCTION
        return {
            'handled': True,
            'new_kwargs': kwargs
        }

    def _handle_timeout_error(self, kwargs: dict) -> Dict[str, Any]:
        """Disable expensive refinement options so the next attempt is faster."""
        logger.info("Timeout detected, simplifying processing")
        simplifications = {
            'use_refinement': False,
            'use_temporal': False,
            'use_guided_filter': False,
            'iterations': 1,
            'num_samples': 1
        }
        # Only override options the caller actually passed.
        for key, value in simplifications.items():
            if key in kwargs:
                kwargs[key] = value
        self.current_level = FallbackLevel.BASIC_PROCESSING
        return {
            'handled': True,
            'new_kwargs': kwargs
        }

    def _handle_model_error(self, kwargs: dict) -> Dict[str, Any]:
        """Step down to the next smaller model variant, or disable the model."""
        logger.info("Model error detected, using simpler model")
        if 'model_type' in kwargs:
            # Ordered largest to smallest; step one size down per failure.
            model_hierarchy = ['large', 'base', 'small', 'tiny']
            current = kwargs.get('model_type', 'base')
            if current in model_hierarchy:
                idx = model_hierarchy.index(current)
                if idx < len(model_hierarchy) - 1:
                    kwargs['model_type'] = model_hierarchy[idx + 1]
                    self.current_level = FallbackLevel.METHOD_SWITCH
                    return {
                        'handled': True,
                        'new_kwargs': kwargs
                    }
        # Already at the smallest model (or no model_type given): fall back to
        # processing without a learned model.
        kwargs['use_model'] = False
        self.current_level = FallbackLevel.BASIC_PROCESSING
        return {
            'handled': True,
            'new_kwargs': kwargs
        }

    def _handle_generic_error(self, attempt: int, kwargs: dict) -> Dict[str, Any]:
        """Apply a progressively harsher degradation based on the attempt count."""
        logger.info(f"Generic error, applying degradation level {attempt + 1}")
        if attempt == 0:
            self.current_level = FallbackLevel.QUALITY_REDUCTION
            if 'quality' in kwargs:
                kwargs['quality'] *= 0.8
        elif attempt == 1:
            self.current_level = FallbackLevel.METHOD_SWITCH
            kwargs['method'] = 'basic'
        else:
            self.current_level = FallbackLevel.MINIMAL_PROCESSING
            kwargs['skip_refinement'] = True
            kwargs['fast_mode'] = True
        return {
            'handled': True,
            'new_kwargs': kwargs
        }

    def _final_fallback(self, func, error: Exception, original_args: tuple) -> Dict[str, Any]:
        """Last resort: return the first ndarray argument unmodified (passthrough)."""
        logger.error(f"Final fallback for {func.__name__}: {str(error)}")
        self.current_level = FallbackLevel.PASSTHROUGH
        for arg in original_args:
            if isinstance(arg, np.ndarray):
                return {
                    'success': False,
                    'result': arg,
                    'fallback_level': self.current_level,
                    'error': str(error)
                }
        # No image to pass through; signal total failure with a None result.
        return {
            'success': False,
            'result': None,
            'fallback_level': self.current_level,
            'error': str(error)
        }
|
|
class ProcessingFallback:
    """Classical-CV fallbacks used when learned segmentation/matting fails."""

    def __init__(self):
        self.logger = setup_logger(f"{__name__}.ProcessingFallback")
        self.quality_analyzer = QualityAnalyzer()

    def basic_segmentation(self, image: np.ndarray) -> np.ndarray:
        """Segment the subject with GrabCut; on failure return a centred blob.

        Returns a uint8 mask (0 = background, 255 = foreground).
        """
        try:
            if len(image.shape) == 3:
                gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
                bgr = image
            else:
                gray = image
                # BUGFIX: cv2.grabCut requires a 3-channel image; the original
                # passed the grayscale input straight through, so every 2-D
                # input crashed into the exception path.
                bgr = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
            mask = np.zeros(gray.shape[:2], np.uint8)
            bgd_model = np.zeros((1, 65), np.float64)
            fgd_model = np.zeros((1, 65), np.float64)
            h, w = gray.shape[:2]
            # Assume the subject occupies roughly the central 80% of the frame.
            rect = (int(w * 0.1), int(h * 0.1), int(w * 0.8), int(h * 0.8))
            cv2.grabCut(bgr, mask, rect, bgd_model, fgd_model, 5, cv2.GC_INIT_WITH_RECT)
            # GrabCut labels 0/2 as (probable) background, 1/3 as foreground.
            mask2 = np.where((mask == 2) | (mask == 0), 0, 255).astype('uint8')
            return mask2
        except Exception as e:
            self.logger.error(f"Basic segmentation failed: {e}")
            return self._center_blob_mask(image.shape[:2])

    def _center_blob_mask(self, shape: Tuple[int, int]) -> np.ndarray:
        """Return a binary uint8 mask containing a soft-edged central ellipse."""
        h, w = shape
        mask = np.zeros((h, w), dtype=np.uint8)
        center = (w // 2, h // 2)
        axes = (w // 3, h // 3)
        cv2.ellipse(mask, center, axes, 0, 0, 360, 255, -1)
        # Blur then re-threshold to round off the ellipse boundary.
        mask = cv2.GaussianBlur(mask, (21, 21), 10)
        _, mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
        return mask

    def basic_matting(self, image: np.ndarray, mask: np.ndarray) -> np.ndarray:
        """Clean *mask* morphologically and return a float32 alpha in [0, 1]."""
        try:
            if mask.dtype != np.uint8:
                # Assumes non-uint8 masks are scaled to [0, 1] — TODO confirm callers.
                mask = (mask * 255).astype(np.uint8)
            kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
            mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)  # fill pinholes
            mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)   # drop speckles
            mask = cv2.GaussianBlur(mask, (5, 5), 2)                # soften the edge
            alpha = mask.astype(np.float32) / 255.0
            return alpha
        except Exception as e:
            self.logger.error(f"Basic matting failed: {e}")
            return mask.astype(np.float32) / 255.0

    def color_difference_keying(self, image: np.ndarray, key_color: Optional[np.ndarray] = None, threshold: float = 30) -> np.ndarray:
        """Key out pixels close to *key_color* (sampled from corners if None).

        Returns a float32 foreground mask in [0, 1].
        """
        try:
            if key_color is None:
                h, w = image.shape[:2]
                # Average 10x10 patches from the four corners; assumes the
                # corners show the background colour to be keyed out.
                corners = [
                    image[0:10, 0:10],
                    image[0:10, w-10:w],
                    image[h-10:h, 0:10],
                    image[h-10:h, w-10:w]
                ]
                key_color = np.mean([np.mean(c, axis=(0, 1)) for c in corners], axis=0)
            # BUGFIX: compute the distance in float32 — with a caller-supplied
            # uint8 key_color the original uint8 subtraction wrapped around
            # instead of going negative, producing wrong distances.
            diff = np.sqrt(np.sum((image.astype(np.float32) - key_color) ** 2, axis=2))
            mask = (diff > threshold).astype(np.float32)
            mask = cv2.GaussianBlur(mask, (5, 5), 2)
            return mask
        except Exception as e:
            self.logger.error(f"Color keying failed: {e}")
            # Best effort: treat everything as foreground.
            return np.ones(image.shape[:2], dtype=np.float32)

    def edge_based_segmentation(self, image: np.ndarray) -> np.ndarray:
        """Segment via Canny edges + largest closed contour; blob mask on failure."""
        try:
            if len(image.shape) == 3:
                gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
            else:
                gray = image
            edges = cv2.Canny(gray, 50, 150)
            kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5))
            # Close gaps so edge fragments form fillable contours.
            closed = cv2.morphologyEx(edges, cv2.MORPH_CLOSE, kernel, iterations=2)
            contours, _ = cv2.findContours(closed, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            mask = np.zeros(gray.shape, dtype=np.uint8)
            if contours:
                largest = max(contours, key=cv2.contourArea)
                cv2.drawContours(mask, [largest], -1, 255, -1)
            return mask
        except Exception as e:
            self.logger.error(f"Edge segmentation failed: {e}")
            return self._center_blob_mask(image.shape[:2])

    def cached_result(self, cache_key: str, fallback_func, *args, **kwargs) -> Any:
        """Memoize *fallback_func* results under *cache_key*; None on failure.

        The cache is created lazily and evicted FIFO (oldest 20 entries)
        once it grows past 100 items.
        """
        if not hasattr(self, '_cache'):
            self._cache = {}
        if cache_key in self._cache:
            self.logger.info(f"Using cached result for {cache_key}")
            return self._cache[cache_key]
        try:
            result = fallback_func(*args, **kwargs)
            self._cache[cache_key] = result
            if len(self._cache) > 100:
                # dicts preserve insertion order, so slicing the key list
                # evicts the oldest entries.
                keys = list(self._cache.keys())
                for key in keys[:20]:
                    del self._cache[key]
            return result
        except Exception as e:
            self.logger.error(f"Cached computation failed: {e}")
            return None
|