"""Download utilities for SDXL Model Merger with Gradio progress integration.""" import re import requests from pathlib import Path from tqdm import tqdm as TqdmBase from .config import download_cancelled def extract_model_id(url: str) -> str | None: """Extract CivitAI model ID from URL.""" match = re.search(r'/models/(\d+)', url) return match.group(1) if match else None def is_huggingface_url(url: str) -> bool: """Check if URL is a HuggingFace model download URL.""" return "huggingface.co" in url.lower() def get_safe_filename_from_url( url: str, default_name: str = "model.safetensors", suffix: str = "", type_prefix: str | None = None ) -> str: """ Generate a safe filename with model ID from URL. For CivitAI URLs like https://civitai.com/api/download/models/12345?type=... Naming patterns: - Checkpoint (type_prefix='model'): 12345_model.safetensors or 12345_model_anime_style.safetensors - VAE (suffix='_vae'): 12345_vae.safetensors (no name extraction to avoid double suffix) - LoRA (suffix='_lora'): 12345_lora.safetensors (no name extraction to avoid double suffix) For HuggingFace URLs without model IDs, attempts to extract name from path or uses suffix-based naming. Args: url: The download URL default_name: Fallback filename if extraction fails suffix: Optional suffix to append before .safetensors (e.g., '_vae', '_lora') type_prefix: Optional prefix after model_id (e.g., 'model' -> 12345_model.safetensors) """ model_id = extract_model_id(url) # If no CivitAI model ID, try to generate a name from HuggingFace path if not model_id and "huggingface.co" in url: # Try to extract name from URL path (e.g., sdxl-vae-fp16-fix -> fp16_fix) try: parts = url.split("huggingface.co/")[1] if "huggingface.co/" in url else "" if parts: # Get the repo name (second part after org/) path_parts = [p for p in parts.split("/") if p] if len(path_parts) >= 2: repo_name = path_parts[1] # Clean up and create a simple identifier clean_repo = re.sub(r'[^a-zA-Z0-9]', '_', repo_name)[:30].strip('_') if clean_repo: model_id = f"hf_{clean_repo}" except Exception: pass if not model_id: return default_name # Special handling for VAE/LoRA with HuggingFace URLs to avoid double suffix is_special_type = suffix in ("_vae", "_lora") # Strip common suffixes from model_id when adding corresponding suffix # (e.g., "sdxl_vae_fp16_fix" + "_vae" -> "sdxl_fp16_fix" + "_vae") if is_special_type: strip_suffix = suffix.lstrip('_') # "vae" or "lora" model_id_lower = model_id.lower() # Check if model_id contains the type (with underscore boundaries) if f"_{strip_suffix}_" in model_id_lower or model_id_lower.endswith(f"_{strip_suffix}"): # Remove the suffix from model_id if model_id_lower.endswith(f"_{strip_suffix}"): model_id = model_id[:-len(strip_suffix)-1] else: # Find and remove _suffix_ pattern pattern = f"_{strip_suffix}_" idx = model_id_lower.find(pattern) if idx >= 0: model_id = model_id[:idx] + model_id[idx+len(pattern):] # Build the name portion: either clean name from URL or fallback name_part = "" # For VAE/LoRA types, skip Content-Disposition parsing to avoid double naming # (e.g., sdxl_vae_vae instead of just vae) if not is_special_type: try: response = requests.head(url, timeout=10, allow_redirects=True) cd = response.headers.get('Content-Disposition', '') match = re.search(r'filename="([^"]+)"', cd) if match: filename = match.group(1) # Extract base name without extension base_name = Path(filename).stem # Clean up the name (remove special chars) clean_name = re.sub(r'[^\w\s-]', '', base_name)[:50] clean_name = re.sub(r'[-\s]+', '_', clean_name.strip('-_')) if clean_name: name_part = clean_name except Exception: pass # Build filename with model_id, optional type_prefix, optional name_part, and suffix parts = [model_id] if type_prefix: parts.append(type_prefix) if name_part: parts.append(name_part) # Handle suffix - for VAE/LoRA we only add the suffix, not double naming if suffix: if is_special_type: # For _vae and _lora: just use model_id + suffix directly return f"{model_id}{suffix}.safetensors" else: # For other types (checkpoint), append suffix after name_part parts.append(suffix.lstrip('_')) return '_'.join(p for p in parts if p).replace('__', '_') + '.safetensors' class TqdmGradio(TqdmBase): """tqdm subclass that sends progress updates to Gradio's gr.Progress()""" def __init__(self, *args, gradio_prog=None, **kwargs): super().__init__(*args, **kwargs) self.gradio_prog = gradio_prog self.last_pct = 0 def update(self, n=1): from .config import download_cancelled if download_cancelled: raise KeyboardInterrupt("Download cancelled by user") super().update(n) if self.gradio_prog and self.total: pct = int(100 * self.n / self.total) # Only update UI every ~5% to avoid spamming if pct != self.last_pct and pct % 5 == 0: self.last_pct = pct self.gradio_prog(pct / 100) def get_cached_file_size(url: str, suffix: str = "", type_prefix: str | None = None) -> tuple[Path | None, int | None]: """ Check if file exists in cache and matches expected size. Uses the same filename generation logic as download operations to find cached files by URL. Args: url: The download URL to check for cached file suffix: Optional suffix (e.g., '_vae', '_lora') for special file types type_prefix: Optional prefix after model_id (e.g., 'model') Returns: Tuple of (cached_file_path, file_size) if valid cache exists, or (None, None) if no valid cache found. """ from .config import CACHE_DIR # Generate the expected filename for this URL default_name = "vae.safetensors" if suffix == "_vae" else ( "lora.safetensors" if suffix == "_lora" else "model.safetensors" ) cached_filename = get_safe_filename_from_url( url, default_name=default_name, suffix=suffix, type_prefix=type_prefix ) cached_path = CACHE_DIR / cached_filename if cached_path.exists() and cached_path.is_file(): try: file_size = cached_path.stat().st_size # Only return valid cache if file has content if file_size > 0: return cached_path, file_size except OSError: pass return None, None def download_file_with_progress(url: str, output_path: Path, progress_bar=None) -> Path: """ Download a file with Gradio-synced progress bar + cancel support. Checks for existing cached files before downloading. If a valid cache exists (file exists with matching expected size), skips re-download. Supports both HTTP(S) and HuggingFace Hub URLs. Args: url: File URL to download (http/https/file) output_path: Destination path for downloaded file progress_bar: Optional gr.Progress() object for UI updates Returns: Path to the downloaded (or cached) file Raises: KeyboardInterrupt: If download is cancelled requests.RequestException: If download fails """ from .config import download_cancelled # Handle local file:// URLs if url.startswith("file://"): local_path = Path(url[7:]) # Remove "file://" prefix if local_path.exists(): import shutil output_path.parent.mkdir(parents=True, exist_ok=True) print(f" 📁 Copying from cache: {local_path.name} → {output_path.name}") # Copy the file to cache location shutil.copy2(str(local_path), str(output_path)) # Update progress bar for cached files if progress_bar: progress_bar(1.0) return output_path else: raise FileNotFoundError(f"Local file not found: {local_path}") print(f" 📥 Downloading to cache: {output_path.name}") # Early cache check: if file exists and size matches URL's content-length, skip re-download expected_size = None try: head = requests.head(url, timeout=10) expected_size = int(head.headers.get('content-length', 0)) except Exception: pass # Skip header fetch on errors if output_path.exists() and expected_size is not None: try: cached_size = output_path.stat().st_size if cached_size == expected_size: print(f" ✅ Cache hit: {output_path.name} ({cached_size / (1024**2):.1f} MB)") # Cache hit - file exists with correct size if progress_bar: progress_bar(1.0) return output_path # Skip re-download! except OSError: pass # File access error, proceed with download output_path.parent.mkdir(parents=True, exist_ok=True) session = requests.Session() response = session.get(url, stream=True, timeout=30) response.raise_for_status() total_size = expected_size or int(response.headers.get('content-length', 0)) block_size = 8192 # Use TqdmGradio to sync progress with Gradio tqdm_kwargs = { 'unit': 'B', 'unit_scale': True, 'desc': f"Downloading {output_path.name}", 'gradio_prog': progress_bar, 'disable': False, 'bar_format': '{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]', } with open(output_path, "wb") as f: try: for data in TqdmGradio( response.iter_content(block_size), total=total_size // block_size if total_size else 0, **tqdm_kwargs, ): if download_cancelled: raise KeyboardInterrupt("Download cancelled by user") f.write(data) except KeyboardInterrupt: # Clean partial file on cancel output_path.unlink(missing_ok=True) raise # Verify the downloaded file is complete try: actual_size = output_path.stat().st_size # For safetensors files, check header is valid if output_path.suffix == ".safetensors": import struct with open(output_path, "rb") as f: header_size_bytes = f.read(8) if len(header_size_bytes) < 8: raise OSError(f"Safetensors file too small: {output_path.name}") header_size = struct.unpack("