| """ |
| BackgroundFX Pro Models Module. |
| Comprehensive model management, optimization, and deployment. |
| """ |
|
|
from typing import Optional

from .registry import (
    ModelRegistry,
    ModelInfo,
    ModelStatus,
    ModelTask,
    ModelFramework
)


from .downloader import (
    ModelDownloader,
    DownloadStatus,
    DownloadProgress
)


from .loaders.model_loader import (
    ModelLoader,
    LoadedModel
)


from .optimizer import (
    ModelOptimizer,
    OptimizationResult
)
|
|
# Public API of the package (what ``from <package> import *`` exposes).
# ``ModelManager`` is included because it is the return type of
# ``create_model_manager`` and the parameter type of the convenience helpers.
__all__ = [
    # Registry
    'ModelRegistry',
    'ModelInfo',
    'ModelStatus',
    'ModelTask',
    'ModelFramework',

    # Downloading
    'ModelDownloader',
    'DownloadStatus',
    'DownloadProgress',

    # Loading
    'ModelLoader',
    'LoadedModel',

    # Optimization
    'ModelOptimizer',
    'OptimizationResult',

    # High-level interface and convenience functions
    'ModelManager',
    'create_model_manager',
    'download_all_models',
    'optimize_for_deployment',
    'benchmark_models'
]


# Package version string.
__version__ = '1.0.0'
|
|
|
|
class ModelManager:
    """
    High-level model management interface.

    Combines registry, downloading, loading, and optimization behind a single
    object. All four components are constructed eagerly in ``__init__``.
    """

    def __init__(self, models_dir: Optional[str] = None, device: str = 'auto'):
        """
        Initialize model manager.

        Args:
            models_dir: Directory for model storage. When not given, defaults
                to ``~/.backgroundfx/models``.
            device: Device for model loading; stored on ``self.device`` and
                passed through to ``ModelLoader``.
        """
        from pathlib import Path

        self.models_dir = Path(models_dir) if models_dir else Path.home() / ".backgroundfx" / "models"
        self.device = device

        # Wired in dependency order: the downloader and loader consult the
        # registry, and the optimizer works through the loader.
        self.registry = ModelRegistry(self.models_dir)
        self.downloader = ModelDownloader(self.registry)
        self.loader = ModelLoader(self.registry, device=device)
        self.optimizer = ModelOptimizer(self.loader)

    def _resolve_model_id(self, model_id: Optional[str], task: Optional[str]) -> Optional[str]:
        """
        Pick the model ID to use for a request.

        Returns *model_id* when it is truthy; otherwise, if *task* is given,
        the ID of ``registry.get_best_model`` for that task; otherwise None.
        ``ModelTask(task)`` presumably raises ValueError for an unknown task
        string (TODO confirm against ModelTask's definition).
        """
        if model_id:
            return model_id
        if not task:
            return None
        best_model = self.registry.get_best_model(ModelTask(task))
        return best_model.model_id if best_model else None

    def setup(self, task: Optional[str] = None, download: bool = True) -> bool:
        """
        Setup models for a specific task.

        Args:
            task: Task type (segmentation, matting, etc.); forwarded to the
                downloader as-is (may be None).
            download: Download missing models

        Returns:
            The downloader's result when *download* is true, else True.
        """
        if download:
            return self.downloader.download_required_models(task)
        return True

    def get_model(self, model_id: Optional[str] = None, task: Optional[str] = None) -> Optional[LoadedModel]:
        """
        Get a loaded model by ID or task.

        Args:
            model_id: Specific model ID; takes precedence over *task*.
            task: Task type used to pick the registry's best model when no
                *model_id* is given.

        Returns:
            The result of ``loader.load_model`` for the resolved ID, or None
            when neither argument resolves to a model.
        """
        resolved_id = self._resolve_model_id(model_id, task)
        if resolved_id:
            return self.loader.load_model(resolved_id)
        return None

    def predict(self, input_data, model_id: Optional[str] = None, task: Optional[str] = None, **kwargs):
        """
        Run prediction with a model.

        Args:
            input_data: Input data, passed unchanged to ``loader.predict``.
            model_id: Model ID; takes precedence over *task*.
            task: Task type used to pick the registry's best model when no
                *model_id* is given.
            **kwargs: Additional arguments forwarded to ``loader.predict``.

        Returns:
            Prediction result, or None when no model could be resolved.
        """
        resolved_id = self._resolve_model_id(model_id, task)
        if resolved_id:
            return self.loader.predict(resolved_id, input_data, **kwargs)
        return None

    def optimize(self, model_id: str, optimization_type: str = 'quantization', **kwargs):
        """
        Optimize a model.

        Args:
            model_id: Model to optimize
            optimization_type: Type of optimization
            **kwargs: Optimization parameters forwarded to the optimizer.

        Returns:
            Whatever ``optimizer.optimize_model`` returns.
        """
        return self.optimizer.optimize_model(model_id, optimization_type, **kwargs)

    def benchmark(self, task: Optional[str] = None) -> dict:
        """
        Benchmark available models.

        Each registry entry with status ``ModelStatus.AVAILABLE`` (filtered to
        *task* when given) is loaded via ``loader.load_model``; its registry
        metadata plus the loader-reported ``load_time`` is collected.

        Args:
            task: Optional task filter

        Returns:
            Mapping of model ID to a dict of benchmark fields. Models whose
            ``load_model`` result is falsy are omitted.
        """
        models = self.registry.list_models()
        if task:
            task_enum = ModelTask(task)
            models = [m for m in models if m.task == task_enum]

        results = {}
        for model_info in models:
            if model_info.status != ModelStatus.AVAILABLE:
                continue
            loaded = self.loader.load_model(model_info.model_id)
            if not loaded:
                continue
            results[model_info.model_id] = {
                'name': model_info.name,
                'framework': model_info.framework.value,
                # NOTE(review): assumes file_size is in bytes — confirm.
                'size_mb': model_info.file_size / (1024 * 1024),
                'speed_fps': model_info.speed_fps,
                'accuracy': model_info.accuracy,
                'memory_mb': model_info.memory_mb,
                'load_time': loaded.load_time
            }

        return results

    def cleanup(self, days: int = 30):
        """
        Clean up unused models.

        Args:
            days: Days threshold for unused models

        Returns:
            Whatever ``registry.cleanup_unused_models`` returns (documented
            upstream as the list of removed models).
        """
        return self.registry.cleanup_unused_models(days)

    def get_stats(self) -> dict:
        """Get model management statistics from the registry, loader and downloader."""
        return {
            'registry': self.registry.get_statistics(),
            'loader': self.loader.get_memory_usage(),
            'downloads': {
                model_id: progress.progress
                for model_id, progress in self.downloader.get_all_progress().items()
            }
        }
|
|
|
|
| |
|
|
def create_model_manager(models_dir: str = None, device: str = 'auto') -> ModelManager:
    """
    Factory for ``ModelManager``.

    Args:
        models_dir: Directory for models, forwarded to ``ModelManager``.
        device: Device for loading, forwarded to ``ModelManager``.

    Returns:
        A freshly constructed ``ModelManager``.
    """
    manager = ModelManager(models_dir=models_dir, device=device)
    return manager
|
|
|
|
def download_all_models(manager: Optional[ModelManager] = None, force: bool = False) -> bool:
    """
    Download every model listed by the registry.

    All model IDs from ``manager.registry.list_models()`` are handed to
    ``manager.downloader.download_models_async``, which is expected to return
    a mapping of model ID to future. Every future is awaited, so a single
    failure does not stop the remaining downloads from being collected.

    Args:
        manager: Model manager instance; one is created with
            ``create_model_manager()`` when omitted.
        force: Force re-download (passed through to the downloader).

    Returns:
        True only if every future resolved to a truthy value without raising.
    """
    import logging

    if manager is None:
        manager = create_model_manager()

    model_ids = [m.model_id for m in manager.registry.list_models()]
    futures = manager.downloader.download_models_async(model_ids, force=force)

    logger = logging.getLogger(__name__)
    success = True
    for model_id, future in futures.items():
        try:
            downloaded = future.result()
        except Exception:
            # Best-effort: keep awaiting the other downloads, but record the
            # failure instead of silently discarding it. Narrowed from a bare
            # ``except:`` so KeyboardInterrupt/SystemExit still propagate.
            logger.exception("Download of model %s failed", model_id)
            success = False
            continue
        if not downloaded:
            success = False

    return success
|
|
|
|
def optimize_for_deployment(manager: Optional[ModelManager] = None,
                            target: str = 'edge',
                            models: Optional[list] = None) -> dict:
    """
    Optimize models for a deployment target.

    Target to optimization mapping:
        - ``'edge'``: quantization with ``quantization_type='dynamic'``
        - ``'mobile'``: ``'coreml'`` when ``manager.device == 'mps'``,
          else ``'tflite'``
        - ``'cloud'``: ``'tensorrt'`` when ``manager.device == 'cuda'``,
          else ``'onnx'``, both with ``fp16=True``
        - anything else: ``'onnx'`` with no extra parameters

    Args:
        manager: Model manager; one is created with
            ``create_model_manager()`` when omitted.
        target: Deployment target (edge, cloud, mobile)
        models: Specific model IDs to optimize. When None, every model the
            registry lists with ``ModelStatus.AVAILABLE`` is used. An
            explicit empty list optimizes nothing.

    Returns:
        Mapping of model ID to its optimization result. Only truthy
        results are kept.
    """
    if manager is None:
        manager = create_model_manager()

    # NOTE(review): device is compared against literal strings; with the
    # default device='auto' neither the 'mps' nor the 'cuda' branch is taken.
    if target == 'edge':
        optimization = 'quantization'
        kwargs = {'quantization_type': 'dynamic'}
    elif target == 'mobile':
        optimization = 'coreml' if manager.device == 'mps' else 'tflite'
        kwargs = {}
    elif target == 'cloud':
        optimization = 'tensorrt' if manager.device == 'cuda' else 'onnx'
        kwargs = {'fp16': True}
    else:
        optimization = 'onnx'
        kwargs = {}

    # Only fall back to "all available models" when no selection was given.
    # Previously an explicit empty list also meant "optimize everything".
    if models is None:
        available = manager.registry.list_models(status=ModelStatus.AVAILABLE)
        models = [m.model_id for m in available]

    results = {}
    for model_id in models:
        result = manager.optimize(model_id, optimization, **kwargs)
        if result:
            results[model_id] = result

    return results
|
|
|
|
def benchmark_models(manager: ModelManager = None, task: str = None) -> dict:
    """
    Benchmark model performance via ``ModelManager.benchmark``.

    Args:
        manager: Model manager to use; when falsy, a new one is built with
            ``create_model_manager()``.
        task: Optional task filter forwarded to ``benchmark``.

    Returns:
        The benchmark results dict produced by ``ModelManager.benchmark``.
    """
    active_manager = manager if manager else create_model_manager()
    return active_manager.benchmark(task)