# SDXL-Model-Merger / src/pipeline.py
# Author: Kyle Pearson
# Commit 3631a8e: Add zero-GPU support, enhance model export with
# quantization/GPU acceleration helpers, optimize inference pipeline with
# VAE fixes, modernize pipeline loading with unified decorators, implement
# GPU decorator infrastructure.
"""Pipeline management for SDXL Model Merger."""
import torch
from diffusers import (
StableDiffusionXLPipeline,
AutoencoderKL,
DPMSolverSDEScheduler,
)
from . import config
from .config import device, dtype, CACHE_DIR, device_description, is_running_on_spaces, set_download_cancelled
from .downloader import get_safe_filename_from_url, download_file_with_progress
from .gpu_decorator import GPU
@GPU(duration=300)
def _load_and_setup_pipeline(checkpoint_path, vae_path, lora_paths_and_strengths, load_kwargs):
    """GPU-decorated helper that performs all GPU-intensive pipeline setup.

    Args:
        checkpoint_path: Local path to the base SDXL checkpoint (.safetensors).
        vae_path: Optional local path to a custom VAE file, or None.
        lora_paths_and_strengths: List of (path, strength) pairs to load and fuse.
        load_kwargs: Keyword arguments forwarded to ``from_single_file``; may
            contain ``device_map='auto'`` for low-RAM CPU environments.

    Returns:
        A fully configured StableDiffusionXLPipeline.
    """
    _pipe = StableDiffusionXLPipeline.from_single_file(
        str(checkpoint_path),
        **load_kwargs,
    )
    print(" βœ… Text encoders loaded")
    # When device_map='auto' is in use, accelerate owns device placement and
    # calling .to() on the pipeline is not allowed β€” skip ALL manual moves in
    # that case (the original code re-moved the pipe later regardless, which
    # defeated the device_map path).
    uses_device_map = "device_map" in load_kwargs
    if not uses_device_map:
        print(f" βš™οΈ Moving pipeline to device: {device_description}...")
        _pipe = _pipe.to(device=device, dtype=dtype)
    # Load custom VAE if provided
    if vae_path is not None:
        print(" βš™οΈ Loading VAE weights...")
        vae = AutoencoderKL.from_single_file(
            str(vae_path),
            torch_dtype=dtype,
        )
        print(" βš™οΈ Setting custom VAE...")
        # VAE runs in float32 to prevent colorful-static decode artifacts.
        _pipe.vae = vae.to(device=device, dtype=torch.float32)
    # Load and fuse each LoRA (pipeline is already on its target device)
    if lora_paths_and_strengths:
        for i, (lora_path, strength) in enumerate(lora_paths_and_strengths):
            adapter_name = f"lora_{i}"
            print(f" βš™οΈ Loading LoRA {i+1}/{len(lora_paths_and_strengths)}...")
            _pipe.load_lora_weights(str(lora_path), adapter_name=adapter_name)
            print(f" βš™οΈ Fusing LoRA {i+1} with strength={strength}...")
            _pipe.fuse_lora(adapter_names=[adapter_name], lora_scale=strength)
            # Drop adapter weights once fused so they don't accumulate in memory.
            _pipe.unload_lora_weights()
    # Set scheduler
    print(" βš™οΈ Configuring scheduler...")
    _pipe.scheduler = DPMSolverSDEScheduler.from_config(
        _pipe.scheduler.config,
        algorithm_type="sde-dpmsolver++",
        use_karras_sigmas=False,
    )
    # Keep VAE in float32 to prevent colorful static output
    _pipe.vae.to(dtype=torch.float32)
    return _pipe
def load_pipeline(
    checkpoint_url: str,
    vae_url: str,
    lora_urls_str: str,
    lora_strengths_str: str,
    progress=None
) -> tuple[str, str]:
    """
    Load SDXL pipeline with checkpoint, VAE, and LoRAs.

    This is a generator: callers (e.g. a Gradio UI) iterate it and render
    each yielded status pair. Errors and cancellations are therefore
    *yielded* as a final status pair β€” a generator's ``return`` value only
    lands in ``StopIteration.value`` and is invisible to a plain for-loop.

    Args:
        checkpoint_url: URL to base model .safetensors file
        vae_url: Optional URL to VAE .safetensors file
        lora_urls_str: Newline-separated URLs for LoRA models
        lora_strengths_str: Comma-separated strength values for each LoRA
        progress: Optional gr.Progress() object for UI updates

    Yields:
        Tuple of (status_message, progress_text) at each loading stage,
        ending with either a success pair or an error/cancellation pair.
    """
    # Clear any previously loaded pipeline so the UI reflects loading state
    config.set_pipe(None)
    try:
        set_download_cancelled(False)
        print("=" * 60)
        print("πŸ”„ Loading SDXL Pipeline...")
        print("=" * 60)
        checkpoint_filename = get_safe_filename_from_url(checkpoint_url, type_prefix="model")
        checkpoint_path = CACHE_DIR / checkpoint_filename
        # Check if checkpoint is already cached
        checkpoint_cached = checkpoint_path.exists() and checkpoint_path.stat().st_size > 0
        # Validate cache file before using it
        if checkpoint_cached:
            is_valid, msg = config.validate_cache_file(checkpoint_path)
            if not is_valid:
                print(f" ⚠️ Cache invalid: {msg}")
                checkpoint_path.unlink(missing_ok=True)
                checkpoint_cached = False
        # VAE: Use suffix="_vae" and default to "vae.safetensors" for proper caching/dropdown matching
        vae_filename = get_safe_filename_from_url(vae_url, default_name="vae.safetensors", suffix="_vae") if vae_url.strip() else None
        vae_path = CACHE_DIR / vae_filename if vae_filename else None
        vae_cached = bool(vae_url.strip() and vae_path and vae_path.exists() and vae_path.stat().st_size > 0)
        # Validate VAE cache file before using it
        if vae_cached:
            is_valid, msg = config.validate_cache_file(vae_path)
            if not is_valid:
                print(f" ⚠️ VAE Cache invalid: {msg}")
                vae_path.unlink(missing_ok=True)
                vae_cached = False
        # Download checkpoint (skips if already cached)
        if progress:
            progress(0.1, desc="Downloading base model..." if not checkpoint_cached else "Loading base model...")
        if not checkpoint_cached:
            status_msg = f"πŸ“₯ Downloading {checkpoint_path.name}..."
            print(f" πŸ“₯ Downloading: {checkpoint_path.name}")
        else:
            status_msg = f"βœ… Using cached {checkpoint_path.name}"
            print(f" βœ… Using cached: {checkpoint_path.name}")
        yield status_msg, "Starting download..."
        if not checkpoint_cached:
            download_file_with_progress(checkpoint_url, checkpoint_path)
        # Download VAE if provided (loading happens in _load_and_setup_pipeline)
        if vae_url and vae_url.strip():
            if vae_path:
                status_msg = f"πŸ“₯ Downloading {vae_path.name}..." if not vae_cached else f"βœ… Using cached {vae_path.name}"
                print(f" πŸ“₯ VAE: {vae_path.name}" if not vae_cached else f" βœ… VAE (cached): {vae_path.name}")
                if progress:
                    progress(0.2, desc="Downloading VAE..." if not vae_cached else "Loading VAE...")
                yield status_msg, f"Downloading VAE: {vae_path.name}" if not vae_cached else f"Using cached VAE: {vae_path.name}"
                if not vae_cached:
                    download_file_with_progress(vae_url, vae_path)
        # For CPU/low-memory environments on Spaces, use device_map for better RAM management
        load_kwargs = {
            "torch_dtype": dtype,
            "use_safetensors": True,
        }
        if is_running_on_spaces() and device == "cpu":
            print(" ℹ️ CPU mode detected: enabling device_map='auto' for better RAM management")
            load_kwargs["device_map"] = "auto"
        # Parse LoRA URLs & ensure strengths list matches; any missing or
        # non-numeric strength falls back to 1.0.
        lora_urls = [u.strip() for u in lora_urls_str.split("\n") if u.strip()]
        strengths_raw = [s.strip() for s in lora_strengths_str.split(",")]
        strengths = []
        for i, url in enumerate(lora_urls):
            try:
                val = float(strengths_raw[i]) if i < len(strengths_raw) else 1.0
                strengths.append(val)
            except ValueError:
                strengths.append(1.0)
        # Download LoRAs (CPU-bound downloads, before GPU work)
        lora_paths_and_strengths = []
        if lora_urls:
            for i, (lora_url, strength) in enumerate(zip(lora_urls, strengths)):
                lora_filename = get_safe_filename_from_url(lora_url, suffix="_lora")
                lora_path = CACHE_DIR / lora_filename
                lora_cached = lora_path.exists() and lora_path.stat().st_size > 0
                # Validate LoRA cache file before using it
                if lora_cached:
                    is_valid, msg = config.validate_cache_file(lora_path)
                    if not is_valid:
                        print(f" ⚠️ LoRA Cache invalid: {msg}")
                        lora_path.unlink(missing_ok=True)
                        lora_cached = False
                if not lora_cached:
                    print(f" πŸ“₯ LoRA {i+1}/{len(lora_urls)}: Downloading {lora_path.name}...")
                    status_msg = f"πŸ“₯ Downloading LoRA {i+1}/{len(lora_urls)}: {lora_path.name}..."
                else:
                    print(f" βœ… LoRA {i+1}/{len(lora_urls)}: Using cached {lora_path.name}")
                    status_msg = f"βœ… Using cached LoRA {i+1}/{len(lora_urls)}: {lora_path.name}"
                yield (
                    status_msg,
                    f"Downloading LoRA {i+1}/{len(lora_urls)} ({lora_path.name})..." if not lora_cached
                    else f"Using cached LoRA {i+1}/{len(lora_urls)} ({lora_path.name})"
                )
                if not lora_cached:
                    download_file_with_progress(lora_url, lora_path)
                lora_paths_and_strengths.append((lora_path, strength))
        # All downloads complete β€” now do GPU-intensive setup in one decorated call
        yield "βš™οΈ Loading SDXL pipeline...", "Loading model weights into memory..."
        if progress:
            progress(0.5, desc="Loading pipeline...")
        _pipe = _load_and_setup_pipeline(
            checkpoint_path, vae_path, lora_paths_and_strengths, load_kwargs
        )
        if progress:
            progress(0.95, desc="Finalizing...")
        # βœ… Only publish the pipeline globally AFTER all steps succeed
        config.set_pipe(_pipe)
        print(" βœ… Pipeline ready!")
        yield "βœ… Pipeline ready!", f"Ready! Loaded {len(lora_urls)} LoRA(s)"
    except KeyboardInterrupt:
        set_download_cancelled(False)
        config.set_pipe(None)
        print("\n⚠️ Download cancelled by user")
        # Yield (not return) so the iterating UI actually receives the message.
        yield "⚠️ Download cancelled by user", "Cancelled"
    except Exception as e:
        import traceback
        config.set_pipe(None)
        error_msg = f"❌ Error loading pipeline: {str(e)}"
        print(f"\n{error_msg}")
        print(traceback.format_exc())
        # Yield (not return) so the iterating UI actually receives the error.
        yield error_msg, f"Error: {str(e)}"
def cancel_download():
    """Request cancellation of any in-flight downloads via the shared flag."""
    set_download_cancelled(True)
def get_pipeline() -> StableDiffusionXLPipeline | None:
    """Return the pipeline currently published in config, or None if unset."""
    return config.get_pipe()