"""Pipeline management for SDXL Model Merger."""

from collections.abc import Iterator

import torch
from diffusers import (
    StableDiffusionXLPipeline,
    AutoencoderKL,
    DPMSolverSDEScheduler,
)

from . import config
from .config import (
    device,
    dtype,
    CACHE_DIR,
    device_description,
    is_running_on_spaces,
    set_download_cancelled,
)
from .downloader import get_safe_filename_from_url, download_file_with_progress
from .gpu_decorator import GPU


@GPU(duration=300)
def _load_and_setup_pipeline(checkpoint_path, vae_path, lora_paths_and_strengths, load_kwargs):
    """GPU-decorated helper that performs all GPU-intensive pipeline setup.

    Args:
        checkpoint_path: Local path to the base SDXL ``.safetensors`` checkpoint.
        vae_path: Local path to an optional VAE ``.safetensors`` file, or ``None``.
        lora_paths_and_strengths: List of ``(path, strength)`` tuples, one per LoRA.
        load_kwargs: Extra kwargs forwarded to ``from_single_file`` (dtype,
            ``use_safetensors``, and optionally ``device_map='auto'``).

    Returns:
        The fully configured ``StableDiffusionXLPipeline``.
    """
    _pipe = StableDiffusionXLPipeline.from_single_file(
        str(checkpoint_path),
        **load_kwargs,
    )
    print("   ✅ Text encoders loaded")

    # Move to device (unless using device_map='auto', which handles placement
    # itself — 'auto' is only enabled for Spaces+CPU, hence this condition).
    if not is_running_on_spaces() or device != "cpu":
        print(f"   ⚙️ Moving pipeline to device: {device_description}...")
        _pipe = _pipe.to(device=device, dtype=dtype)

    # Load custom VAE if provided
    if vae_path is not None:
        print("   ⚙️ Loading VAE weights...")
        vae = AutoencoderKL.from_single_file(
            str(vae_path),
            torch_dtype=dtype,
        )
        print("   ⚙️ Setting custom VAE...")
        # VAE is kept in float32 to avoid decode artifacts (see end of function).
        _pipe.vae = vae.to(device=device, dtype=torch.float32)

    # Load and fuse each LoRA
    if lora_paths_and_strengths:
        # Ensure pipeline is on device for LoRA fusion
        _pipe = _pipe.to(device=device, dtype=dtype)
        for i, (lora_path, strength) in enumerate(lora_paths_and_strengths):
            adapter_name = f"lora_{i}"
            print(f"   ⚙️ Loading LoRA {i+1}/{len(lora_paths_and_strengths)}...")
            _pipe.load_lora_weights(str(lora_path), adapter_name=adapter_name)
            print(f"   ⚙️ Fusing LoRA {i+1} with strength={strength}...")
            _pipe.fuse_lora(adapter_names=[adapter_name], lora_scale=strength)
        # NOTE(review): original whitespace was mangled; unloading once after
        # all adapters are fused matches the diffusers fuse-then-unload flow —
        # confirm it was not intended per-iteration.
        _pipe.unload_lora_weights()
    else:
        # Move pipeline to device even without LoRAs
        _pipe = _pipe.to(device=device, dtype=dtype)

    # Set scheduler
    print("   ⚙️ Configuring scheduler...")
    _pipe.scheduler = DPMSolverSDEScheduler.from_config(
        _pipe.scheduler.config,
        algorithm_type="sde-dpmsolver++",
        use_karras_sigmas=False,
    )

    # Keep VAE in float32 to prevent colorful static output
    _pipe.vae.to(dtype=torch.float32)

    return _pipe


def load_pipeline(
    checkpoint_url: str,
    vae_url: str,
    lora_urls_str: str,
    lora_strengths_str: str,
    progress=None,
) -> Iterator[tuple[str, str]]:
    """Load SDXL pipeline with checkpoint, VAE, and LoRAs.

    This is a generator: it streams ``(status_message, progress_text)``
    tuples at each loading stage so a UI can display live progress.

    Args:
        checkpoint_url: URL to base model .safetensors file
        vae_url: Optional URL to VAE .safetensors file
        lora_urls_str: Newline-separated URLs for LoRA models
        lora_strengths_str: Comma-separated strength values for each LoRA
        progress: Optional gr.Progress() object for UI updates

    Yields:
        Tuple of (status_message, progress_text) at each loading stage;
        the final tuple is either the success message or (on cancellation/
        error) the failure status.
    """
    # Clear any previously loaded pipeline so the UI reflects loading state
    config.set_pipe(None)
    try:
        set_download_cancelled(False)
        print("=" * 60)
        print("🔄 Loading SDXL Pipeline...")
        print("=" * 60)

        checkpoint_filename = get_safe_filename_from_url(checkpoint_url, type_prefix="model")
        checkpoint_path = CACHE_DIR / checkpoint_filename

        # Check if checkpoint is already cached (non-empty file on disk)
        checkpoint_cached = checkpoint_path.exists() and checkpoint_path.stat().st_size > 0

        # Validate cache file before using it
        if checkpoint_cached:
            is_valid, msg = config.validate_cache_file(checkpoint_path)
            if not is_valid:
                print(f"   ⚠️ Cache invalid: {msg}")
                checkpoint_path.unlink(missing_ok=True)
                checkpoint_cached = False

        # VAE: Use suffix="_vae" and default to "vae.safetensors" for proper caching/dropdown matching
        vae_filename = (
            get_safe_filename_from_url(vae_url, default_name="vae.safetensors", suffix="_vae")
            if vae_url.strip()
            else None
        )
        vae_path = CACHE_DIR / vae_filename if vae_filename else None
        vae_cached = (
            vae_url.strip() and vae_path and vae_path.exists() and vae_path.stat().st_size > 0
        )

        # Validate VAE cache file before using it
        if vae_cached:
            is_valid, msg = config.validate_cache_file(vae_path)
            if not is_valid:
                print(f"   ⚠️ VAE Cache invalid: {msg}")
                vae_path.unlink(missing_ok=True)
                vae_cached = False

        # Download checkpoint (skips if already cached)
        if progress:
            progress(
                0.1,
                desc="Downloading base model..." if not checkpoint_cached else "Loading base model...",
            )
        if not checkpoint_cached:
            status_msg = f"📥 Downloading {checkpoint_path.name}..."
            print(f"   📥 Downloading: {checkpoint_path.name}")
        else:
            status_msg = f"✅ Using cached {checkpoint_path.name}"
            print(f"   ✅ Using cached: {checkpoint_path.name}")
        yield status_msg, "Starting download..."
        if not checkpoint_cached:
            download_file_with_progress(checkpoint_url, checkpoint_path)

        # Download VAE if provided (loading happens in _load_and_setup_pipeline)
        if vae_url and vae_url.strip():
            if vae_path:
                status_msg = (
                    f"📥 Downloading {vae_path.name}..."
                    if not vae_cached
                    else f"✅ Using cached {vae_path.name}"
                )
                print(
                    f"   📥 VAE: {vae_path.name}"
                    if not vae_cached
                    else f"   ✅ VAE (cached): {vae_path.name}"
                )
                if progress:
                    progress(0.2, desc="Downloading VAE..." if not vae_cached else "Loading VAE...")
                yield status_msg, (
                    f"Downloading VAE: {vae_path.name}"
                    if not vae_cached
                    else f"Using cached VAE: {vae_path.name}"
                )
                if not vae_cached:
                    download_file_with_progress(vae_url, vae_path)

        # For CPU/low-memory environments on Spaces, use device_map for better RAM management
        load_kwargs = {
            "torch_dtype": dtype,
            "use_safetensors": True,
        }
        if is_running_on_spaces() and device == "cpu":
            print("   ℹ️ CPU mode detected: enabling device_map='auto' for better RAM management")
            load_kwargs["device_map"] = "auto"

        # Parse LoRA URLs & ensure strengths list matches (missing/invalid
        # strength entries default to 1.0 so list lengths always line up).
        lora_urls = [u.strip() for u in lora_urls_str.split("\n") if u.strip()]
        strengths_raw = [s.strip() for s in lora_strengths_str.split(",")]
        strengths = []
        for i, url in enumerate(lora_urls):
            try:
                val = float(strengths_raw[i]) if i < len(strengths_raw) else 1.0
                strengths.append(val)
            except ValueError:
                strengths.append(1.0)

        # Download LoRAs (CPU-bound downloads, before GPU work)
        lora_paths_and_strengths = []
        if lora_urls:
            for i, (lora_url, strength) in enumerate(zip(lora_urls, strengths)):
                lora_filename = get_safe_filename_from_url(lora_url, suffix="_lora")
                lora_path = CACHE_DIR / lora_filename
                lora_cached = lora_path.exists() and lora_path.stat().st_size > 0

                # Validate LoRA cache file before using it
                if lora_cached:
                    is_valid, msg = config.validate_cache_file(lora_path)
                    if not is_valid:
                        print(f"   ⚠️ LoRA Cache invalid: {msg}")
                        lora_path.unlink(missing_ok=True)
                        lora_cached = False

                if not lora_cached:
                    print(f"   📥 LoRA {i+1}/{len(lora_urls)}: Downloading {lora_path.name}...")
                    status_msg = f"📥 Downloading LoRA {i+1}/{len(lora_urls)}: {lora_path.name}..."
                else:
                    print(f"   ✅ LoRA {i+1}/{len(lora_urls)}: Using cached {lora_path.name}")
                    status_msg = f"✅ Using cached LoRA {i+1}/{len(lora_urls)}: {lora_path.name}"
                yield (
                    status_msg,
                    f"Downloading LoRA {i+1}/{len(lora_urls)} ({lora_path.name})..."
                    if not lora_cached
                    else f"Using cached LoRA {i+1}/{len(lora_urls)} ({lora_path.name})",
                )
                if not lora_cached:
                    download_file_with_progress(lora_url, lora_path)
                lora_paths_and_strengths.append((lora_path, strength))

        # All downloads complete — now do GPU-intensive setup in one decorated call
        yield "⚙️ Loading SDXL pipeline...", "Loading model weights into memory..."
        if progress:
            progress(0.5, desc="Loading pipeline...")
        _pipe = _load_and_setup_pipeline(
            checkpoint_path, vae_path, lora_paths_and_strengths, load_kwargs
        )
        if progress:
            progress(0.95, desc="Finalizing...")

        # ✅ Only publish the pipeline globally AFTER all steps succeed
        config.set_pipe(_pipe)
        print("   ✅ Pipeline ready!")
        yield "✅ Pipeline ready!", f"Ready! Loaded {len(lora_urls)} LoRA(s)"
    except KeyboardInterrupt:
        set_download_cancelled(False)
        config.set_pipe(None)
        print("\n⚠️ Download cancelled by user")
        # BUG FIX: this function is a generator, so `return (msg, detail)` only
        # set StopIteration.value and the status tuple never reached the
        # consumer (the UI showed nothing on cancel). Yield it instead.
        yield ("⚠️ Download cancelled by user", "Cancelled")
    except Exception as e:
        import traceback

        config.set_pipe(None)
        error_msg = f"❌ Error loading pipeline: {str(e)}"
        print(f"\n{error_msg}")
        print(traceback.format_exc())
        # BUG FIX: same as above — yield the error status so it is displayed.
        yield (error_msg, f"Error: {str(e)}")


def cancel_download():
    """Set the global cancellation flag to stop any ongoing downloads."""
    set_download_cancelled(True)


def get_pipeline() -> StableDiffusionXLPipeline | None:
    """Get the currently loaded pipeline."""
    return config.get_pipe()