| """ |
| Model downloader for BackgroundFX Pro. |
| Handles downloading, caching, and verification of models. |
| """ |
|
|
| import os |
| import shutil |
| import tempfile |
| import hashlib |
| import requests |
| from pathlib import Path |
| from typing import Optional, Callable, Dict, Any, List |
| from dataclasses import dataclass |
| from enum import Enum |
| import time |
| import threading |
| from urllib.parse import urlparse |
| from concurrent.futures import ThreadPoolExecutor, Future |
| import logging |
|
|
| from .registry import ModelInfo, ModelStatus, ModelRegistry |
|
|
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
|
|
|
class DownloadStatus(Enum):
    """Lifecycle state of a single model download.

    The usual path is PENDING -> DOWNLOADING -> VERIFYING -> COMPLETED.
    FAILED and CANCELLED are the alternative end states.
    """

    PENDING = "pending"          # tracked, worker has not started yet
    DOWNLOADING = "downloading"  # bytes are being transferred
    VERIFYING = "verifying"      # size / checksum checks in progress
    EXTRACTING = "extracting"    # NOTE(review): not set by the downloader in this module
    COMPLETED = "completed"      # file verified and registered
    FAILED = "failed"            # every attempt failed or an error occurred
    CANCELLED = "cancelled"      # stopped at the caller's request
|
|
|
|
@dataclass
class DownloadProgress:
    """Mutable progress record for one model download.

    The downloader updates these fields in place while a transfer runs and
    passes this same object to progress callbacks.
    """

    model_id: str
    status: DownloadStatus
    # Bytes received so far, including any prefix resumed from an earlier attempt.
    current_bytes: int = 0
    # Expected final size in bytes; 0 (the default) means the size is unknown.
    total_bytes: int = 0
    # Transfer rate in megabits per second.
    speed_mbps: float = 0.0
    # Estimated seconds until the transfer finishes.
    eta_seconds: float = 0.0
    # Human-readable failure reason, set when the download fails.
    error: Optional[str] = None

    @property
    def progress(self) -> float:
        """Return completion as a percentage, or 0.0 when the size is unknown."""
        expected = self.total_bytes
        if expected <= 0:
            return 0.0
        return (self.current_bytes / expected) * 100
|
|
|
|
class ModelDownloader:
    """Download models listed in a ModelRegistry.

    Downloads run on a thread pool and report progress through optional
    callbacks. Interrupted transfers resume from a ``<filename>.part`` file.
    Mirror URLs are tried after the primary URL. Each file is verified (size
    and optional SHA-256) before the registry entry is marked available.
    """

    def __init__(self,
                 registry: ModelRegistry,
                 max_workers: int = 3,
                 chunk_size: int = 8192,
                 timeout: int = 30,
                 max_retries: int = 3):
        """
        Initialize model downloader.

        Args:
            registry: Model registry instance; files are written to its
                ``models_dir``.
            max_workers: Maximum concurrent downloads
            chunk_size: Download and hashing chunk size in bytes
            timeout: Request timeout in seconds (passed to ``requests.get``)
            max_retries: Extra transfer attempts per URL after a failure
                (0 means each URL is tried once)
        """
        self.registry = registry
        self.max_workers = max_workers
        self.chunk_size = chunk_size
        self.timeout = timeout
        self.max_retries = max_retries

        # Per-download bookkeeping, keyed by the model ID the caller passed in.
        self.downloads: Dict[str, DownloadProgress] = {}
        self.executor = ThreadPoolExecutor(max_workers=max_workers)
        self.futures: Dict[str, Future] = {}
        self._stop_events: Dict[str, threading.Event] = {}

        self.cache_dir = registry.models_dir / ".cache"
        # parents=True so a fresh install without models_dir does not crash here.
        self.cache_dir.mkdir(parents=True, exist_ok=True)

    def download_model(self,
                       model_id: str,
                       progress_callback: Optional[Callable[[DownloadProgress], None]] = None,
                       force: bool = False) -> bool:
        """
        Download a model and block until it finishes.

        Args:
            model_id: Model ID to download
            progress_callback: Optional callback that receives the
                DownloadProgress object on each update
            force: Force re-download even if the model is already available

        Returns:
            True if the model is available afterwards
        """
        model = self.registry.get_model(model_id)
        if not model:
            logger.error(f"Model not found: {model_id}")
            return False

        if not force and model.status == ModelStatus.AVAILABLE:
            logger.info(f"Model already available: {model_id}")
            return True

        future = self._submit_download(model_id, model, progress_callback, force)
        try:
            return future.result()
        except Exception as e:
            logger.error(f"Download failed for {model_id}: {e}")
            return False

    def download_models_async(self,
                              model_ids: List[str],
                              progress_callback: Optional[Callable[[str, DownloadProgress], None]] = None,
                              force: bool = False) -> Dict[str, Future]:
        """
        Download multiple models asynchronously.

        Unknown IDs are logged and skipped. Already-available models are
        skipped unless ``force`` is set.

        Args:
            model_ids: List of model IDs
            progress_callback: Optional callback invoked as
                ``progress_callback(model_id, progress)``
            force: Force re-download

        Returns:
            Mapping of model ID to the Future of its download task
        """
        futures: Dict[str, Future] = {}

        for model_id in model_ids:
            model = self.registry.get_model(model_id)
            if not model:
                logger.warning(f"Model not found: {model_id}")
                continue

            if not force and model.status == ModelStatus.AVAILABLE:
                continue

            callback = None
            if progress_callback:
                # Bind model_id now. A plain closure would read the loop
                # variable's *final* value when worker threads call it.
                callback = (lambda p, _model_id=model_id:
                            progress_callback(_model_id, p))

            futures[model_id] = self._submit_download(model_id, model, callback, force)

        return futures

    def _submit_download(self,
                         model_id: str,
                         model: ModelInfo,
                         progress_callback: Optional[Callable],
                         force: bool) -> Future:
        """Register tracking state for ``model_id`` and queue its download task."""
        progress = DownloadProgress(
            model_id=model_id,
            status=DownloadStatus.PENDING,
            total_bytes=model.file_size
        )
        self.downloads[model_id] = progress
        # The stop event must exist before the task starts, because the task reads it.
        self._stop_events[model_id] = threading.Event()

        future = self.executor.submit(
            self._download_model_task,
            model,
            progress,
            progress_callback,
            force
        )
        self.futures[model_id] = future
        return future

    def _mark_cancelled(self,
                        progress: DownloadProgress,
                        progress_callback: Optional[Callable]) -> bool:
        """Record a cancellation, notify the callback, and return False."""
        progress.status = DownloadStatus.CANCELLED
        self._notify_progress(progress, progress_callback)
        return False

    def _download_model_task(self,
                             model: ModelInfo,
                             progress: DownloadProgress,
                             progress_callback: Optional[Callable],
                             force: bool) -> bool:
        """
        Worker body: fetch, verify and register one model.

        ``model.url`` is tried first, then each mirror. Each URL gets
        ``1 + max_retries`` transfer attempts, with backoff between them; a
        later attempt resumes from the partial file when the server allows it.
        A file that fails verification is deleted and the next URL is tried.

        Args:
            model: Model information
            progress: Progress tracker (updated in place)
            progress_callback: Progress callback
            force: Force re-download (the skip-if-available check happens
                before submission; this task does not read it)

        Returns:
            True if the model was downloaded and verified
        """
        try:
            # Stop events are registered under the caller's ID (see _submit_download).
            stop_event = self._stop_events[progress.model_id]

            progress.status = DownloadStatus.DOWNLOADING
            self._notify_progress(progress, progress_callback)

            # Tolerate a missing/None mirror list and skip empty URLs.
            urls = [u for u in [model.url, *(model.mirror_urls or [])] if u]
            output_path = self.registry.models_dir / model.filename
            attempts = 1 + max(0, self.max_retries)

            for url in urls:
                for attempt in range(attempts):
                    if attempt:
                        # Exponential backoff (1s, 2s, 4s, ... capped at 10s).
                        # Event.wait returns early if a cancel arrives.
                        stop_event.wait(min(2 ** (attempt - 1), 10))
                    if stop_event.is_set():
                        return self._mark_cancelled(progress, progress_callback)

                    try:
                        downloaded = self._download_file(
                            url,
                            output_path,
                            progress,
                            progress_callback,
                            progress.model_id
                        )
                    except Exception as e:
                        logger.warning(f"Download failed from {url}: {e}")
                        continue

                    if not downloaded:
                        continue

                    progress.status = DownloadStatus.VERIFYING
                    self._notify_progress(progress, progress_callback)

                    if self._verify_download(output_path, model):
                        model.status = ModelStatus.AVAILABLE
                        model.local_path = str(output_path)
                        model.download_date = time.time()
                        self.registry._save_registry()

                        progress.status = DownloadStatus.COMPLETED
                        self._notify_progress(progress, progress_callback)

                        logger.info(f"Successfully downloaded: {model.model_id}")
                        return True

                    # This source served bad content: discard it and move on
                    # to the next URL instead of re-fetching the same data.
                    output_path.unlink(missing_ok=True)
                    logger.warning(f"Verification failed for {model.model_id}")
                    progress.status = DownloadStatus.DOWNLOADING
                    break

            # A cancel that interrupted the final attempt is not a failure.
            if stop_event.is_set():
                return self._mark_cancelled(progress, progress_callback)

            progress.status = DownloadStatus.FAILED
            progress.error = "All download attempts failed"
            self._notify_progress(progress, progress_callback)
            return False

        except Exception as e:
            progress.status = DownloadStatus.FAILED
            progress.error = str(e)
            self._notify_progress(progress, progress_callback)
            logger.error(f"Download task failed: {e}")
            return False

    def _download_file(self,
                       url: str,
                       output_path: Path,
                       progress: DownloadProgress,
                       progress_callback: Optional[Callable],
                       model_id: str) -> bool:
        """
        Stream ``url`` to ``output_path``, resuming a previous partial file.

        Data goes to ``<output_path name>.part``. The file is renamed over
        ``output_path`` only after the full body has arrived, so an interrupted
        transfer never leaves a truncated file at the final path.

        Args:
            url: Download URL
            output_path: Output file path
            progress: Progress tracker (updated in place)
            progress_callback: Progress callback
            model_id: Key of the stop event to poll for cancellation

        Returns:
            True if the complete file is now at ``output_path``. False on
            cancellation, HTTP/network errors, a short transfer, or write
            errors. The partial file is kept (except after HTTP 416) so a
            later attempt can resume.
        """
        # Append the suffix instead of replacing it: with_suffix('.part')
        # would make e.g. "net.onnx" and "net.pth" share one temp file.
        temp_path = output_path.with_name(output_path.name + '.part')
        stop_event = self._stop_events[model_id]

        resume_pos = temp_path.stat().st_size if temp_path.exists() else 0
        headers = {}
        if resume_pos > 0:
            logger.info(f"Resuming download from {resume_pos} bytes")
            headers['Range'] = f'bytes={resume_pos}-'

        # Monotonic clock: wall-clock jumps must not distort speed/ETA.
        start_time = time.monotonic()

        try:
            # Context manager returns the pooled connection even on early return.
            with requests.get(
                url,
                headers=headers,
                stream=True,
                timeout=self.timeout
            ) as response:
                if resume_pos > 0 and response.status_code == 416:
                    # The server rejected our range: the partial file does not
                    # match (or is already complete). Discard it so the next
                    # attempt starts from scratch.
                    logger.warning(f"Server rejected resume range for {model_id}; discarding partial file")
                    temp_path.unlink(missing_ok=True)
                    return False
                response.raise_for_status()

                if resume_pos > 0 and response.status_code != 206:
                    # The server ignored Range and is sending the whole file.
                    # Appending it to the partial data would corrupt the result.
                    logger.info(f"Server does not support resume for {model_id}; restarting download")
                    resume_pos = 0

                content_length = response.headers.get('content-length')
                total_size = None
                if content_length is not None:
                    total_size = int(content_length) + resume_pos
                    progress.total_bytes = total_size
                # With a content encoding, Content-Length counts encoded bytes,
                # which cannot be compared with the decoded bytes we write.
                check_length = (
                    total_size is not None
                    and response.headers.get('content-encoding', 'identity') == 'identity'
                )

                bytes_downloaded = resume_pos
                progress.current_bytes = bytes_downloaded

                mode = 'ab' if resume_pos > 0 else 'wb'
                with open(temp_path, mode) as f:
                    for chunk in response.iter_content(chunk_size=self.chunk_size):
                        if stop_event.is_set():
                            logger.info(f"Download cancelled: {model_id}")
                            return False

                        if not chunk:
                            continue

                        f.write(chunk)
                        bytes_downloaded += len(chunk)
                        progress.current_bytes = bytes_downloaded

                        elapsed = time.monotonic() - start_time
                        if elapsed > 0:
                            speed_bps = (bytes_downloaded - resume_pos) / elapsed
                            progress.speed_mbps = (speed_bps * 8) / 1_000_000
                            if total_size and speed_bps > 0:
                                remaining = max(total_size - bytes_downloaded, 0)
                                progress.eta_seconds = remaining / speed_bps

                        self._notify_progress(progress, progress_callback)

            if check_length and bytes_downloaded < total_size:
                # The connection ended early. Keep the .part so a retry can resume.
                logger.warning(
                    f"Incomplete download for {model_id}: "
                    f"got {bytes_downloaded} of {total_size} bytes"
                )
                return False

            # Same directory, so this is an atomic rename. It also replaces an
            # existing file (e.g. on a forced re-download).
            os.replace(temp_path, output_path)
            return True

        except requests.exceptions.RequestException as e:
            logger.error(f"Download error: {e}")
            return False
        except Exception as e:
            logger.error(f"File write error: {e}")
            return False

    def _verify_download(self, file_path: Path, model: ModelInfo) -> bool:
        """
        Verify a downloaded file.

        The size check allows a difference of up to 1000 bytes from
        ``model.file_size`` and is skipped when the size is unknown. The
        SHA-256 check runs only when ``model.sha256`` is set and ignores case
        and surrounding whitespace.

        Args:
            file_path: Downloaded file path
            model: Model information

        Returns:
            True if verification passed
        """
        if not file_path.exists():
            return False

        expected_size = model.file_size or 0
        if expected_size > 0:
            actual_size = file_path.stat().st_size
            if abs(actual_size - expected_size) > 1000:
                logger.warning(f"Size mismatch: expected {expected_size}, got {actual_size}")
                return False

        if model.sha256:
            try:
                sha256 = self._calculate_sha256(file_path)
            except Exception as e:
                logger.error(f"SHA256 calculation failed: {e}")
                return False
            # hexdigest() is lowercase; registry values may not be.
            if sha256 != model.sha256.strip().lower():
                logger.warning(f"SHA256 mismatch for {model.model_id}")
                return False

        return True

    def _calculate_sha256(self, file_path: Path) -> str:
        """Return the hex SHA-256 digest of ``file_path``, read in chunks."""
        sha256_hash = hashlib.sha256()
        with open(file_path, "rb") as f:
            for byte_block in iter(lambda: f.read(self.chunk_size), b""):
                sha256_hash.update(byte_block)
        return sha256_hash.hexdigest()

    def _notify_progress(self, progress: DownloadProgress, callback: Optional[Callable]):
        """Invoke ``callback(progress)``, logging (not raising) callback errors."""
        if callback:
            try:
                callback(progress)
            except Exception as e:
                logger.error(f"Progress callback error: {e}")

    def cancel_download(self, model_id: str) -> bool:
        """
        Cancel an ongoing or queued download.

        Signals the task to stop, then waits up to 5 seconds for it to wind
        down. A download that has already completed or failed keeps its
        final status.

        Args:
            model_id: Model ID to cancel

        Returns:
            True if a download was known for ``model_id`` and was signalled
        """
        stop_event = self._stop_events.get(model_id)
        if stop_event is None:
            return False

        stop_event.set()

        future = self.futures.pop(model_id, None)
        if future is not None:
            # A task still waiting in the queue can be dropped outright.
            if not future.cancel():
                try:
                    future.result(timeout=5)
                except Exception as e:
                    # Best effort: the task may still be running or may have
                    # failed. We stop tracking it either way.
                    logger.debug(f"Cancelled task for {model_id} ended with: {e!r}")

        progress = self.downloads.get(model_id)
        if progress is not None and progress.status not in (
            DownloadStatus.COMPLETED, DownloadStatus.FAILED
        ):
            progress.status = DownloadStatus.CANCELLED

        logger.info(f"Download cancelled: {model_id}")
        return True

    def get_progress(self, model_id: str) -> Optional[DownloadProgress]:
        """Get download progress for model."""
        return self.downloads.get(model_id)

    def get_all_progress(self) -> Dict[str, DownloadProgress]:
        """Get a shallow copy of all download progress records."""
        return self.downloads.copy()

    def cleanup_partial_downloads(self):
        """Delete ``*.part`` files left in the models directory."""
        for file in self.registry.models_dir.glob("*.part"):
            try:
                file.unlink()
                logger.info(f"Removed partial download: {file.name}")
            except Exception as e:
                logger.error(f"Failed to remove {file}: {e}")

    def shutdown(self, wait: bool = True, cancel_pending: bool = False) -> None:
        """
        Release the worker thread pool.

        Args:
            wait: Block until running downloads have finished
            cancel_pending: Signal every tracked download to stop first
        """
        if cancel_pending:
            for event in self._stop_events.values():
                event.set()
        self.executor.shutdown(wait=wait)

    def download_required_models(self,
                                 task: Optional[str] = None,
                                 gpu_available: bool = True) -> bool:
        """
        Download all required models for a task.

        With ``task``, the registry's best model for that task is fetched.
        Without it, the built-in essential set is fetched (only entries the
        registry knows about).

        Args:
            task: Optional task name (a ``ModelTask`` value)
            gpu_available: GPU availability, forwarded as ``require_gpu``

        Returns:
            True if all downloads were successful (False for an unknown task)
        """
        required = []

        if task:
            from .registry import ModelTask
            try:
                task_enum = ModelTask(task)
            except ValueError:
                logger.error(f"Unknown task: {task}")
                return False
            model = self.registry.get_best_model(
                task_enum,
                require_gpu=bool(gpu_available)
            )
            if model:
                required.append(model.model_id)
        else:
            essential = ['rmbg-1.4', 'u2netp', 'modnet']
            required = [m for m in essential if self.registry.get_model(m)]

        if not required:
            return True

        logger.info(f"Downloading required models: {required}")
        futures = self.download_models_async(required)

        # Wait for every future (no short-circuit) so no download is left
        # running unobserved.
        success = True
        for model_id, future in futures.items():
            try:
                ok = future.result()
            except Exception as e:
                logger.error(f"Download failed for {model_id}: {e}")
                ok = False
            if not ok:
                success = False

        return success