Spaces:
Running on Zero
Running on Zero
Kyle Pearson
Add Zero-GPU support; enhance model export with quantization/GPU-acceleration helpers; optimize the inference pipeline with VAE fixes; modernize pipeline loading with unified decorators; implement the GPU decorator infrastructure.
3631a8e | """Pipeline management for SDXL Model Merger.""" | |
| import torch | |
| from diffusers import ( | |
| StableDiffusionXLPipeline, | |
| AutoencoderKL, | |
| DPMSolverSDEScheduler, | |
| ) | |
| from . import config | |
| from .config import device, dtype, CACHE_DIR, device_description, is_running_on_spaces, set_download_cancelled | |
| from .downloader import get_safe_filename_from_url, download_file_with_progress | |
| from .gpu_decorator import GPU | |
def _load_and_setup_pipeline(checkpoint_path, vae_path, lora_paths_and_strengths, load_kwargs):
    """Perform all compute-intensive pipeline setup after downloads complete.

    Args:
        checkpoint_path: Local path to the base SDXL checkpoint (.safetensors).
        vae_path: Optional local path to a custom VAE file, or None.
        lora_paths_and_strengths: List of (lora_path, strength) pairs to fuse.
        load_kwargs: Extra kwargs forwarded to ``from_single_file``. May carry
            ``device_map='auto'``, in which case accelerate owns device
            placement and explicit ``.to()`` calls must be skipped.

    Returns:
        A fully configured StableDiffusionXLPipeline (scheduler set, LoRAs
        fused, VAE forced to float32).

    NOTE(review): the original docstring called this a "GPU-decorated helper",
    but no ``@GPU`` decorator is applied here even though ``GPU`` is imported
    at module level — confirm whether the caller wraps it or the decorator
    was dropped.
    """
    _pipe = StableDiffusionXLPipeline.from_single_file(
        str(checkpoint_path),
        **load_kwargs,
    )
    print(" β Text encoders loaded")

    # When device_map='auto' is in effect, accelerate manages placement and a
    # manual .to() would fight it. Otherwise move the pipeline exactly once
    # (the original moved it up to twice more below, which also defeated the
    # device_map skip).
    uses_device_map = load_kwargs.get("device_map") is not None
    if not uses_device_map:
        print(f" βοΈ Moving pipeline to device: {device_description}...")
        _pipe = _pipe.to(device=device, dtype=dtype)

    # Load custom VAE if provided.
    if vae_path is not None:
        print(" βοΈ Loading VAE weights...")
        vae = AutoencoderKL.from_single_file(
            str(vae_path),
            torch_dtype=dtype,
        )
        print(" βοΈ Setting custom VAE...")
        _pipe.vae = vae.to(device=device, dtype=torch.float32)

    # Load and fuse each LoRA into the base weights, then drop the adapters.
    if lora_paths_and_strengths:
        for i, (lora_path, strength) in enumerate(lora_paths_and_strengths):
            adapter_name = f"lora_{i}"
            print(f" βοΈ Loading LoRA {i+1}/{len(lora_paths_and_strengths)}...")
            _pipe.load_lora_weights(str(lora_path), adapter_name=adapter_name)
            print(f" βοΈ Fusing LoRA {i+1} with strength={strength}...")
            _pipe.fuse_lora(adapter_names=[adapter_name], lora_scale=strength)
        _pipe.unload_lora_weights()

    # Set scheduler.
    print(" βοΈ Configuring scheduler...")
    _pipe.scheduler = DPMSolverSDEScheduler.from_config(
        _pipe.scheduler.config,
        algorithm_type="sde-dpmsolver++",
        use_karras_sigmas=False,
    )

    # Keep VAE in float32 to prevent colorful static output, regardless of
    # which path placed it on the device above.
    _pipe.vae.to(dtype=torch.float32)
    return _pipe
def load_pipeline(
    checkpoint_url: str,
    vae_url: str,
    lora_urls_str: str,
    lora_strengths_str: str,
    progress=None
) -> tuple[str, str]:
    """
    Load SDXL pipeline with checkpoint, VAE, and LoRAs.

    This is a generator: it yields (status_message, progress_text) tuples at
    each loading stage, including the terminal success/cancel/error message.

    Args:
        checkpoint_url: URL to base model .safetensors file
        vae_url: Optional URL to VAE .safetensors file
        lora_urls_str: Newline-separated URLs for LoRA models
        lora_strengths_str: Comma-separated strength values for each LoRA
        progress: Optional gr.Progress() object for UI updates

    Yields:
        Tuple of (status_message, progress_text) at each loading stage;
        the final yield reports success, cancellation, or the error.
    """
    # Clear any previously loaded pipeline so the UI reflects loading state
    config.set_pipe(None)
    try:
        set_download_cancelled(False)
        print("=" * 60)
        print("π Loading SDXL Pipeline...")
        print("=" * 60)

        checkpoint_filename = get_safe_filename_from_url(checkpoint_url, type_prefix="model")
        checkpoint_path = CACHE_DIR / checkpoint_filename
        # Check if checkpoint is already cached (non-empty file on disk).
        checkpoint_cached = checkpoint_path.exists() and checkpoint_path.stat().st_size > 0
        # Validate cache file before using it; drop it if corrupt.
        if checkpoint_cached:
            is_valid, msg = config.validate_cache_file(checkpoint_path)
            if not is_valid:
                print(f" β οΈ Cache invalid: {msg}")
                checkpoint_path.unlink(missing_ok=True)
                checkpoint_cached = False

        # VAE: Use suffix="_vae" and default to "vae.safetensors" for proper caching/dropdown matching
        vae_filename = get_safe_filename_from_url(vae_url, default_name="vae.safetensors", suffix="_vae") if vae_url.strip() else None
        vae_path = CACHE_DIR / vae_filename if vae_filename else None
        vae_cached = bool(vae_url.strip() and vae_path and vae_path.exists() and vae_path.stat().st_size > 0)
        # Validate VAE cache file before using it.
        if vae_cached:
            is_valid, msg = config.validate_cache_file(vae_path)
            if not is_valid:
                print(f" β οΈ VAE Cache invalid: {msg}")
                vae_path.unlink(missing_ok=True)
                vae_cached = False

        # Download checkpoint (skips if already cached)
        if progress:
            progress(0.1, desc="Downloading base model..." if not checkpoint_cached else "Loading base model...")
        if not checkpoint_cached:
            status_msg = f"π₯ Downloading {checkpoint_path.name}..."
            print(f" π₯ Downloading: {checkpoint_path.name}")
        else:
            status_msg = f"β Using cached {checkpoint_path.name}"
            print(f" β Using cached: {checkpoint_path.name}")
        yield status_msg, "Starting download..."
        if not checkpoint_cached:
            download_file_with_progress(checkpoint_url, checkpoint_path)

        # Download VAE if provided (loading happens in _load_and_setup_pipeline)
        if vae_url and vae_url.strip():
            if vae_path:
                status_msg = f"π₯ Downloading {vae_path.name}..." if not vae_cached else f"β Using cached {vae_path.name}"
                print(f" π₯ VAE: {vae_path.name}" if not vae_cached else f" β VAE (cached): {vae_path.name}")
                if progress:
                    progress(0.2, desc="Downloading VAE..." if not vae_cached else "Loading VAE...")
                yield status_msg, f"Downloading VAE: {vae_path.name}" if not vae_cached else f"Using cached VAE: {vae_path.name}"
                if not vae_cached:
                    download_file_with_progress(vae_url, vae_path)

        # For CPU/low-memory environments on Spaces, use device_map for better RAM management
        load_kwargs = {
            "torch_dtype": dtype,
            "use_safetensors": True,
        }
        if is_running_on_spaces() and device == "cpu":
            print(" βΉοΈ CPU mode detected: enabling device_map='auto' for better RAM management")
            load_kwargs["device_map"] = "auto"

        # Parse LoRA URLs & ensure strengths list matches; missing or
        # malformed strength entries default to 1.0.
        lora_urls = [u.strip() for u in lora_urls_str.split("\n") if u.strip()]
        strengths_raw = [s.strip() for s in lora_strengths_str.split(",")]
        strengths = []
        for i, url in enumerate(lora_urls):
            try:
                val = float(strengths_raw[i]) if i < len(strengths_raw) else 1.0
                strengths.append(val)
            except ValueError:
                strengths.append(1.0)

        # Download LoRAs (CPU-bound downloads, before GPU work)
        lora_paths_and_strengths = []
        for i, (lora_url, strength) in enumerate(zip(lora_urls, strengths)):
            lora_filename = get_safe_filename_from_url(lora_url, suffix="_lora")
            lora_path = CACHE_DIR / lora_filename
            lora_cached = lora_path.exists() and lora_path.stat().st_size > 0
            # Validate LoRA cache file before using it.
            if lora_cached:
                is_valid, msg = config.validate_cache_file(lora_path)
                if not is_valid:
                    print(f" β οΈ LoRA Cache invalid: {msg}")
                    lora_path.unlink(missing_ok=True)
                    lora_cached = False
            if not lora_cached:
                print(f" π₯ LoRA {i+1}/{len(lora_urls)}: Downloading {lora_path.name}...")
                status_msg = f"π₯ Downloading LoRA {i+1}/{len(lora_urls)}: {lora_path.name}..."
            else:
                print(f" β LoRA {i+1}/{len(lora_urls)}: Using cached {lora_path.name}")
                status_msg = f"β Using cached LoRA {i+1}/{len(lora_urls)}: {lora_path.name}"
            yield (
                status_msg,
                f"Downloading LoRA {i+1}/{len(lora_urls)} ({lora_path.name})..." if not lora_cached
                else f"Using cached LoRA {i+1}/{len(lora_urls)} ({lora_path.name})"
            )
            if not lora_cached:
                download_file_with_progress(lora_url, lora_path)
            lora_paths_and_strengths.append((lora_path, strength))

        # All downloads complete β now do GPU-intensive setup in one decorated call
        yield "βοΈ Loading SDXL pipeline...", "Loading model weights into memory..."
        if progress:
            progress(0.5, desc="Loading pipeline...")
        _pipe = _load_and_setup_pipeline(
            checkpoint_path, vae_path, lora_paths_and_strengths, load_kwargs
        )
        if progress:
            progress(0.95, desc="Finalizing...")
        # β Only publish the pipeline globally AFTER all steps succeed
        config.set_pipe(_pipe)
        print(" β Pipeline ready!")
        yield "β Pipeline ready!", f"Ready! Loaded {len(lora_urls)} LoRA(s)"
    except KeyboardInterrupt:
        set_download_cancelled(False)
        config.set_pipe(None)
        print("\nβ οΈ Download cancelled by user")
        # BUGFIX: this function is a generator — a plain `return value` only
        # populates StopIteration.value and never reaches the UI consumer.
        # Yield the terminal message so Gradio displays it, then stop.
        yield "β οΈ Download cancelled by user", "Cancelled"
        return
    except Exception as e:
        import traceback
        config.set_pipe(None)
        error_msg = f"β Error loading pipeline: {str(e)}"
        print(f"\n{error_msg}")
        print(traceback.format_exc())
        # Same generator-return fix as above: surface the error to the UI.
        yield error_msg, f"Error: {str(e)}"
        return
def cancel_download():
    """Request cancellation of any in-progress download.

    Raises the module-level cancellation flag that the download loop polls;
    the next poll aborts the transfer.
    """
    set_download_cancelled(True)
def get_pipeline() -> StableDiffusionXLPipeline | None:
    """Return the currently published pipeline, or None if none is loaded."""
    return config.get_pipe()