Spaces:
Sleeping
Sleeping
Kyle Pearson committed on
Commit ·
3631a8e
1
Parent(s): 570384a
Add zero-gpu support, enhance model export with quantization/gpu acceleration helpers, optimize inference pipeline with vae fixes, modernize pipeline loading with unified decorators, implement gpu decorator infrastructure.
Browse files- requirements.txt +3 -0
- src/exporter.py +77 -86
- src/generator.py +20 -31
- src/gpu_decorator.py +10 -0
- src/pipeline.py +68 -73
requirements.txt
CHANGED
|
@@ -26,3 +26,6 @@ psutil>=5.9.0
|
|
| 26 |
|
| 27 |
# Optional: quantization support
|
| 28 |
optimum-quanto>=0.2.0
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
# Optional: quantization support
|
| 28 |
optimum-quanto>=0.2.0
|
| 29 |
+
|
| 30 |
+
# ZeroGPU support for HuggingFace Spaces
|
| 31 |
+
spaces
|
src/exporter.py
CHANGED
|
@@ -1,13 +1,85 @@
|
|
| 1 |
"""Model export functionality for SDXL Model Merger."""
|
| 2 |
|
| 3 |
-
import os
|
| 4 |
-
from pathlib import Path
|
| 5 |
-
|
| 6 |
import torch
|
| 7 |
from safetensors.torch import save_file
|
| 8 |
|
| 9 |
from . import config
|
| 10 |
from .config import SCRIPT_DIR
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
|
| 13 |
def export_merged_model(
|
|
@@ -45,90 +117,9 @@ def export_merged_model(
|
|
| 45 |
yield None, f"❌ Invalid quantization type: {qtype}. Must be one of: {valid_qtypes}"
|
| 46 |
return
|
| 47 |
|
| 48 |
-
|
| 49 |
-
yield "💾 Exporting model...", "Unloading LoRAs..."
|
| 50 |
-
if include_lora:
|
| 51 |
-
try:
|
| 52 |
-
pipe.unload_lora_weights()
|
| 53 |
-
except Exception as e:
|
| 54 |
-
print(f" ℹ️ Could not unload LoRAs: {e}")
|
| 55 |
-
|
| 56 |
-
merged_state_dict = {}
|
| 57 |
-
|
| 58 |
-
# Step 2: Extract UNet weights
|
| 59 |
-
yield "💾 Exporting model...", "Extracting UNet weights..."
|
| 60 |
-
for k, v in pipe.unet.state_dict().items():
|
| 61 |
-
merged_state_dict[f"unet.{k}"] = v.contiguous().half()
|
| 62 |
-
|
| 63 |
-
# Step 3: Extract text encoder weights
|
| 64 |
-
yield "💾 Exporting model...", "Extracting text encoders..."
|
| 65 |
-
if pipe.text_encoder is not None:
|
| 66 |
-
for k, v in pipe.text_encoder.state_dict().items():
|
| 67 |
-
merged_state_dict[f"text_encoder.{k}"] = v.contiguous().half()
|
| 68 |
-
if pipe.text_encoder_2 is not None:
|
| 69 |
-
for k, v in pipe.text_encoder_2.state_dict().items():
|
| 70 |
-
merged_state_dict[f"text_encoder_2.{k}"] = v.contiguous().half()
|
| 71 |
-
|
| 72 |
-
# Step 4: Extract VAE weights
|
| 73 |
-
yield "💾 Exporting model...", "Extracting VAE weights..."
|
| 74 |
-
if pipe.vae is not None:
|
| 75 |
-
for k, v in pipe.vae.state_dict().items():
|
| 76 |
-
merged_state_dict[f"first_stage_model.{k}"] = v.contiguous().half()
|
| 77 |
-
|
| 78 |
-
# Step 5: Quantize if requested and optimum.quanto is available
|
| 79 |
-
try:
|
| 80 |
-
from optimum.quanto import quantize as quanto_quantize, QTensor
|
| 81 |
-
QUANTO_AVAILABLE = True
|
| 82 |
-
except ImportError:
|
| 83 |
-
QUANTO_AVAILABLE = False
|
| 84 |
-
|
| 85 |
-
if quantize and qtype != "none" and QUANTO_AVAILABLE:
|
| 86 |
-
yield "💾 Exporting model...", f"Applying {qtype} quantization..."
|
| 87 |
-
|
| 88 |
-
class FakeModel(torch.nn.Module):
|
| 89 |
-
pass
|
| 90 |
-
|
| 91 |
-
fake_model = FakeModel()
|
| 92 |
-
fake_model.__dict__.update(merged_state_dict)
|
| 93 |
-
|
| 94 |
-
# Select quantization method
|
| 95 |
-
if qtype == "int8":
|
| 96 |
-
from optimum.quanto import int8_weight_only
|
| 97 |
-
quanto_quantize(fake_model, int8_weight_only())
|
| 98 |
-
elif qtype == "int4":
|
| 99 |
-
from optimum.quanto import int4_weight_only
|
| 100 |
-
quanto_quantize(fake_model, int4_weight_only())
|
| 101 |
-
elif qtype == "float8":
|
| 102 |
-
from optimum.quanto import float8_dynamic_activation_float8_weight
|
| 103 |
-
quanto_quantize(fake_model, float8_dynamic_activation_float8_weight())
|
| 104 |
-
else:
|
| 105 |
-
raise ValueError(f"Unsupported qtype: {qtype}")
|
| 106 |
-
|
| 107 |
-
merged_state_dict = {
|
| 108 |
-
k: v.dequantize().half() if isinstance(v, QTensor) else v
|
| 109 |
-
for k, v in fake_model.state_dict().items()
|
| 110 |
-
}
|
| 111 |
-
elif quantize and not QUANTO_AVAILABLE:
|
| 112 |
-
yield None, "❌ optimum.quanto not installed. Install with: pip install optimum-quanto"
|
| 113 |
-
return
|
| 114 |
-
|
| 115 |
-
# Step 6: Save model
|
| 116 |
-
yield "💾 Exporting model...", "Saving weights..."
|
| 117 |
-
|
| 118 |
-
ext = ".bin" if save_format == "bin" else ".safetensors"
|
| 119 |
|
| 120 |
-
|
| 121 |
-
prefix = ""
|
| 122 |
-
if quantize and qtype != "none":
|
| 123 |
-
prefix = f"{qtype}_"
|
| 124 |
-
|
| 125 |
-
out_path = SCRIPT_DIR / f"merged_{prefix}checkpoint{ext}"
|
| 126 |
-
|
| 127 |
-
# Save appropriately
|
| 128 |
-
if ext == ".bin":
|
| 129 |
-
torch.save(merged_state_dict, str(out_path))
|
| 130 |
-
else:
|
| 131 |
-
save_file(merged_state_dict, str(out_path))
|
| 132 |
|
| 133 |
size_gb = out_path.stat().st_size / 1024**3
|
| 134 |
|
|
|
|
| 1 |
"""Model export functionality for SDXL Model Merger."""
|
| 2 |
|
|
|
|
|
|
|
|
|
|
| 3 |
import torch
|
| 4 |
from safetensors.torch import save_file
|
| 5 |
|
| 6 |
from . import config
|
| 7 |
from .config import SCRIPT_DIR
|
| 8 |
+
from .gpu_decorator import GPU
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@GPU(duration=180)
def _extract_and_save(pipe, include_lora, quantize, qtype, save_format):
    """Extract merged weights from ``pipe`` and save them to disk.

    Runs on a GPU worker (ZeroGPU decorator). Collects UNet, text-encoder
    and VAE weights as fp16 tensors, optionally applies optimum-quanto
    quantization, and writes a ``.safetensors`` or ``.bin`` checkpoint
    into SCRIPT_DIR.

    Returns:
        Path to the written checkpoint file.
    """
    if include_lora:
        # Best-effort: LoRAs were fused earlier, so unloading the adapter
        # bookkeeping may legitimately fail — log and continue.
        try:
            pipe.unload_lora_weights()
        except Exception as e:
            print(f" ℹ️ Could not unload LoRAs: {e}")

    merged_state_dict = {}

    # Extract UNet weights
    for k, v in pipe.unet.state_dict().items():
        merged_state_dict[f"unet.{k}"] = v.contiguous().half()

    # Extract text encoder weights
    if pipe.text_encoder is not None:
        for k, v in pipe.text_encoder.state_dict().items():
            merged_state_dict[f"text_encoder.{k}"] = v.contiguous().half()
    if pipe.text_encoder_2 is not None:
        for k, v in pipe.text_encoder_2.state_dict().items():
            merged_state_dict[f"text_encoder_2.{k}"] = v.contiguous().half()

    # Extract VAE weights (ldm-style "first_stage_model." prefix)
    if pipe.vae is not None:
        for k, v in pipe.vae.state_dict().items():
            merged_state_dict[f"first_stage_model.{k}"] = v.contiguous().half()

    # Quantize if requested
    try:
        from optimum.quanto import quantize as quanto_quantize, QTensor
        QUANTO_AVAILABLE = True
    except ImportError:
        QUANTO_AVAILABLE = False

    if quantize and qtype != "none" and QUANTO_AVAILABLE:
        class FakeModel(torch.nn.Module):
            pass

        fake_model = FakeModel()
        # Update __dict__ directly — nn.Module.__setattr__ would reject
        # dotted attribute names such as "unet.conv_in.weight".
        fake_model.__dict__.update(merged_state_dict)

        # BUG FIX: optimum-quanto selects the weight qtype via the
        # `weights=` argument (qint8 / qint4 / qfloat8). The previously
        # imported `*_weight_only` helpers do not exist in optimum.quanto
        # (they belong to a different quantization library) and raised
        # ImportError at runtime.
        if qtype == "int8":
            from optimum.quanto import qint8
            quanto_quantize(fake_model, weights=qint8)
        elif qtype == "int4":
            from optimum.quanto import qint4
            quanto_quantize(fake_model, weights=qint4)
        elif qtype == "float8":
            from optimum.quanto import qfloat8
            quanto_quantize(fake_model, weights=qfloat8)
        else:
            raise ValueError(f"Unsupported qtype: {qtype}")

        # BUG FIX: tensors stuffed into __dict__ are NOT registered
        # parameters/buffers, so fake_model.state_dict() is empty — the
        # old code silently replaced merged_state_dict with {} and saved
        # an empty checkpoint whenever quantization was enabled. Rebuild
        # from the instance attributes instead, keeping only tensors.
        # NOTE(review): quanto quantizes nn submodules (e.g. Linear); on a
        # bare module holding raw tensors this may be a no-op — TODO confirm.
        merged_state_dict = {
            k: (v.dequantize().half() if isinstance(v, QTensor) else v)
            for k, v in fake_model.__dict__.items()
            if isinstance(v, torch.Tensor)
        }
    elif quantize and not QUANTO_AVAILABLE:
        raise ImportError("optimum.quanto not installed. Install with: pip install optimum-quanto")

    # Save model
    ext = ".bin" if save_format == "bin" else ".safetensors"
    prefix = f"{qtype}_" if quantize and qtype != "none" else ""
    out_path = SCRIPT_DIR / f"merged_{prefix}checkpoint{ext}"

    if ext == ".bin":
        torch.save(merged_state_dict, str(out_path))
    else:
        save_file(merged_state_dict, str(out_path))

    return out_path
|
| 83 |
|
| 84 |
|
| 85 |
def export_merged_model(
|
|
|
|
| 117 |
yield None, f"❌ Invalid quantization type: {qtype}. Must be one of: {valid_qtypes}"
|
| 118 |
return
|
| 119 |
|
| 120 |
+
yield "💾 Exporting model...", "Extracting and saving weights..."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
|
| 122 |
+
out_path = _extract_and_save(pipe, include_lora, quantize, qtype, save_format)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
|
| 124 |
size_gb = out_path.stat().st_size / 1024**3
|
| 125 |
|
src/generator.py
CHANGED
|
@@ -3,10 +3,25 @@
|
|
| 3 |
import torch
|
| 4 |
|
| 5 |
from . import config
|
| 6 |
-
from .config import device, dtype
|
|
|
|
| 7 |
from .tiling import enable_seamless_tiling
|
| 8 |
|
| 9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
def generate_image(
|
| 11 |
prompt: str,
|
| 12 |
negative_prompt: str,
|
|
@@ -45,25 +60,8 @@ def generate_image(
|
|
| 45 |
yield None, "⚠️ Please load a pipeline first."
|
| 46 |
return
|
| 47 |
|
| 48 |
-
#
|
| 49 |
-
|
| 50 |
-
effective_device = device
|
| 51 |
-
|
| 52 |
-
if is_running_on_spaces() and device == "cpu":
|
| 53 |
-
print(" ℹ️ CPU mode: using float32 for stability (generation will be slower)")
|
| 54 |
-
# Store original dtype and temporarily use float32
|
| 55 |
-
original_dtype = effective_dtype
|
| 56 |
-
effective_dtype = torch.float32
|
| 57 |
-
# Update pipeline to use float32
|
| 58 |
-
pipe.unet.to(dtype=torch.float32)
|
| 59 |
-
if pipe.text_encoder is not None:
|
| 60 |
-
pipe.text_encoder.to(dtype=torch.float32)
|
| 61 |
-
if pipe.text_encoder_2 is not None:
|
| 62 |
-
pipe.text_encoder_2.to(dtype=torch.float32)
|
| 63 |
-
if pipe.vae is not None:
|
| 64 |
-
pipe.vae.to(dtype=torch.float32)
|
| 65 |
-
else:
|
| 66 |
-
original_dtype = None
|
| 67 |
|
| 68 |
# Enable seamless tiling on UNet & VAE decoder
|
| 69 |
enable_seamless_tiling(pipe.unet, tile_x=tile_x, tile_y=tile_y)
|
|
@@ -72,18 +70,9 @@ def generate_image(
|
|
| 72 |
yield None, "🎨 Generating image..."
|
| 73 |
|
| 74 |
try:
|
| 75 |
-
# Use provided seed or generate a random one if None
|
| 76 |
actual_seed = seed if seed is not None else int(torch.randint(0, 2**63, (1,)).item())
|
| 77 |
-
generator = torch.Generator(device=
|
| 78 |
-
result =
|
| 79 |
-
prompt=prompt,
|
| 80 |
-
negative_prompt=negative_prompt,
|
| 81 |
-
width=int(width),
|
| 82 |
-
height=int(height),
|
| 83 |
-
num_inference_steps=int(steps),
|
| 84 |
-
guidance_scale=float(cfg),
|
| 85 |
-
generator=generator,
|
| 86 |
-
)
|
| 87 |
|
| 88 |
image = result.images[0]
|
| 89 |
yield image, f"✅ Complete! ({int(width)}x{int(height)})"
|
|
|
|
| 3 |
import torch
|
| 4 |
|
| 5 |
from . import config
|
| 6 |
+
from .config import device, dtype
|
| 7 |
+
from .gpu_decorator import GPU
|
| 8 |
from .tiling import enable_seamless_tiling
|
| 9 |
|
| 10 |
|
| 11 |
+
@GPU(duration=120)
def _run_inference(pipe, prompt, negative_prompt, width, height, steps, cfg, generator):
    """Run the diffusion pipeline call on the GPU worker.

    Thin wrapper so the heavy ``pipe(...)`` invocation carries the ZeroGPU
    decorator; numeric UI values are coerced to the types the pipeline
    expects (ints for sizes/steps, float for guidance scale).
    """
    call_kwargs = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "width": int(width),
        "height": int(height),
        "num_inference_steps": int(steps),
        "guidance_scale": float(cfg),
        "generator": generator,
    }
    return pipe(**call_kwargs)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
def generate_image(
|
| 26 |
prompt: str,
|
| 27 |
negative_prompt: str,
|
|
|
|
| 60 |
yield None, "⚠️ Please load a pipeline first."
|
| 61 |
return
|
| 62 |
|
| 63 |
+
# Ensure VAE stays in float32 to prevent colorful static output
|
| 64 |
+
pipe.vae.to(dtype=torch.float32)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
|
| 66 |
# Enable seamless tiling on UNet & VAE decoder
|
| 67 |
enable_seamless_tiling(pipe.unet, tile_x=tile_x, tile_y=tile_y)
|
|
|
|
| 70 |
yield None, "🎨 Generating image..."
|
| 71 |
|
| 72 |
try:
|
|
|
|
| 73 |
actual_seed = seed if seed is not None else int(torch.randint(0, 2**63, (1,)).item())
|
| 74 |
+
generator = torch.Generator(device=device).manual_seed(actual_seed)
|
| 75 |
+
result = _run_inference(pipe, prompt, negative_prompt, width, height, steps, cfg, generator)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
|
| 77 |
image = result.images[0]
|
| 78 |
yield image, f"✅ Complete! ({int(width)}x{int(height)})"
|
src/gpu_decorator.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ZeroGPU compatibility decorator for HuggingFace Spaces."""
|
| 2 |
+
|
| 3 |
+
try:
|
| 4 |
+
import spaces
|
| 5 |
+
GPU = spaces.GPU
|
| 6 |
+
except ImportError:
|
| 7 |
+
def GPU(func=None, duration=None):
|
| 8 |
+
if func is None:
|
| 9 |
+
return lambda f: f
|
| 10 |
+
return func
|
src/pipeline.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
| 1 |
"""Pipeline management for SDXL Model Merger."""
|
| 2 |
|
| 3 |
-
|
| 4 |
-
|
| 5 |
from diffusers import (
|
| 6 |
StableDiffusionXLPipeline,
|
| 7 |
AutoencoderKL,
|
|
@@ -11,7 +10,61 @@ from diffusers import (
|
|
| 11 |
from . import config
|
| 12 |
from .config import device, dtype, CACHE_DIR, device_description, is_running_on_spaces, set_download_cancelled
|
| 13 |
from .downloader import get_safe_filename_from_url, download_file_with_progress
|
| 14 |
-
from .
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
|
| 17 |
def load_pipeline(
|
|
@@ -90,8 +143,7 @@ def load_pipeline(
|
|
| 90 |
if not checkpoint_cached:
|
| 91 |
download_file_with_progress(checkpoint_url, checkpoint_path)
|
| 92 |
|
| 93 |
-
# Download VAE if provided
|
| 94 |
-
vae = None
|
| 95 |
if vae_url and vae_url.strip():
|
| 96 |
if vae_path:
|
| 97 |
status_msg = f"📥 Downloading {vae_path.name}..." if not vae_cached else f"✅ Using cached {vae_path.name}"
|
|
@@ -105,23 +157,6 @@ def load_pipeline(
|
|
| 105 |
if not vae_cached:
|
| 106 |
download_file_with_progress(vae_url, vae_path)
|
| 107 |
|
| 108 |
-
# Load VAE from file
|
| 109 |
-
print(" ⚙️ Loading VAE weights...")
|
| 110 |
-
yield "⚙️ Loading VAE...", f"Loading VAE: {vae_path.name}"
|
| 111 |
-
vae = AutoencoderKL.from_single_file(
|
| 112 |
-
str(vae_path),
|
| 113 |
-
torch_dtype=dtype,
|
| 114 |
-
)
|
| 115 |
-
if progress:
|
| 116 |
-
progress(0.25, desc="VAE loaded")
|
| 117 |
-
|
| 118 |
-
# Load base pipeline (yield progress during this heavy operation)
|
| 119 |
-
print(" ⚙️ Loading SDXL pipeline from single file...")
|
| 120 |
-
yield "⚙️ Loading SDXL pipeline...", "Loading model weights into memory..."
|
| 121 |
-
|
| 122 |
-
if progress:
|
| 123 |
-
progress(0.3, desc="Loading text encoders...")
|
| 124 |
-
|
| 125 |
# For CPU/low-memory environments on Spaces, use device_map for better RAM management
|
| 126 |
load_kwargs = {
|
| 127 |
"torch_dtype": dtype,
|
|
@@ -132,31 +167,6 @@ def load_pipeline(
|
|
| 132 |
print(" ℹ️ CPU mode detected: enabling device_map='auto' for better RAM management")
|
| 133 |
load_kwargs["device_map"] = "auto"
|
| 134 |
|
| 135 |
-
# Use a local variable for the pipeline being built — only stored globally on success.
|
| 136 |
-
_pipe = StableDiffusionXLPipeline.from_single_file(
|
| 137 |
-
str(checkpoint_path),
|
| 138 |
-
**load_kwargs,
|
| 139 |
-
)
|
| 140 |
-
print(" ✅ Text encoders loaded")
|
| 141 |
-
|
| 142 |
-
if progress:
|
| 143 |
-
progress(0.5, desc="Loading UNet...")
|
| 144 |
-
|
| 145 |
-
print(" ✅ UNet loaded")
|
| 146 |
-
|
| 147 |
-
# Move to device (unless using device_map='auto' which handles this automatically)
|
| 148 |
-
if not is_running_on_spaces() or device != "cpu":
|
| 149 |
-
print(f" ⚙️ Moving pipeline to device: {device_description}...")
|
| 150 |
-
_pipe = _pipe.to(device=device, dtype=dtype)
|
| 151 |
-
|
| 152 |
-
yield "⚙️ Pipeline loaded, setting up components...", f"Using device: {device_description}"
|
| 153 |
-
|
| 154 |
-
# Load VAE into pipeline if provided
|
| 155 |
-
if vae is not None:
|
| 156 |
-
print(" ⚙️ Setting custom VAE...")
|
| 157 |
-
_pipe.vae = vae.to(device=device, dtype=dtype)
|
| 158 |
-
yield "⚙️ Pipeline loaded, setting up components...", f"VAE loaded: {vae_path.name}"
|
| 159 |
-
|
| 160 |
# Parse LoRA URLs & ensure strengths list matches
|
| 161 |
lora_urls = [u.strip() for u in lora_urls_str.split("\n") if u.strip()]
|
| 162 |
strengths_raw = [s.strip() for s in lora_strengths_str.split(",")]
|
|
@@ -168,11 +178,9 @@ def load_pipeline(
|
|
| 168 |
except ValueError:
|
| 169 |
strengths.append(1.0)
|
| 170 |
|
| 171 |
-
#
|
|
|
|
| 172 |
if lora_urls:
|
| 173 |
-
print(f" ⚙️ Moving pipeline to device: {device_description}...")
|
| 174 |
-
_pipe = _pipe.to(device=device, dtype=dtype)
|
| 175 |
-
|
| 176 |
for i, (lora_url, strength) in enumerate(zip(lora_urls, strengths)):
|
| 177 |
lora_filename = get_safe_filename_from_url(lora_url, suffix="_lora")
|
| 178 |
lora_path = CACHE_DIR / lora_filename
|
|
@@ -202,34 +210,21 @@ def load_pipeline(
|
|
| 202 |
if not lora_cached:
|
| 203 |
download_file_with_progress(lora_url, lora_path)
|
| 204 |
|
| 205 |
-
|
| 206 |
-
yield f"⚙️ Loading LoRA {i+1}/{len(lora_urls)}...", f"Fusing {lora_path.name}..."
|
| 207 |
-
if progress:
|
| 208 |
-
progress(0.7 + (0.2 * i / len(lora_urls)), desc=f"Loading LoRA {i+1}/{len(lora_urls)}...")
|
| 209 |
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
print(f" ⚙️ Fusing LoRA {i+1} with strength={strength}...")
|
| 213 |
-
_pipe.fuse_lora(adapter_names=[adapter_name], lora_scale=strength)
|
| 214 |
-
_pipe.unload_lora_weights()
|
| 215 |
-
else:
|
| 216 |
-
# Move pipeline to device even without LoRAs
|
| 217 |
-
print(f" ⚙️ Moving pipeline to device: {device_description}...")
|
| 218 |
-
_pipe = _pipe.to(device=device, dtype=dtype)
|
| 219 |
-
|
| 220 |
-
# Set scheduler and finalize (do this once at the end)
|
| 221 |
-
print(" ⚙️ Configuring scheduler...")
|
| 222 |
-
yield "⚙️ Finalizing pipeline...", "Setting up scheduler..."
|
| 223 |
|
| 224 |
if progress:
|
| 225 |
-
progress(0.
|
| 226 |
|
| 227 |
-
_pipe
|
| 228 |
-
|
| 229 |
-
algorithm_type="sde-dpmsolver++",
|
| 230 |
-
use_karras_sigmas=False,
|
| 231 |
)
|
| 232 |
|
|
|
|
|
|
|
|
|
|
| 233 |
# ✅ Only publish the pipeline globally AFTER all steps succeed
|
| 234 |
config.set_pipe(_pipe)
|
| 235 |
|
|
|
|
| 1 |
"""Pipeline management for SDXL Model Merger."""
|
| 2 |
|
| 3 |
+
import torch
|
|
|
|
| 4 |
from diffusers import (
|
| 5 |
StableDiffusionXLPipeline,
|
| 6 |
AutoencoderKL,
|
|
|
|
| 10 |
from . import config
|
| 11 |
from .config import device, dtype, CACHE_DIR, device_description, is_running_on_spaces, set_download_cancelled
|
| 12 |
from .downloader import get_safe_filename_from_url, download_file_with_progress
|
| 13 |
+
from .gpu_decorator import GPU
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@GPU(duration=300)
def _load_and_setup_pipeline(checkpoint_path, vae_path, lora_paths_and_strengths, load_kwargs):
    """Build the SDXL pipeline on the GPU worker.

    Loads the checkpoint from a single file, swaps in an optional custom
    VAE, fuses any downloaded LoRAs, configures the SDE-DPM scheduler,
    and pins the VAE to float32. Returns the fully prepared pipeline.
    """
    pipeline = StableDiffusionXLPipeline.from_single_file(
        str(checkpoint_path),
        **load_kwargs,
    )
    print(" ✅ Text encoders loaded")

    # device_map='auto' (Spaces CPU mode) already placed the weights, so
    # only move explicitly outside that case.
    on_spaces_cpu = is_running_on_spaces() and device == "cpu"
    if not on_spaces_cpu:
        print(f" ⚙️ Moving pipeline to device: {device_description}...")
        pipeline = pipeline.to(device=device, dtype=dtype)

    # Swap in a custom VAE when one was downloaded.
    if vae_path is not None:
        print(" ⚙️ Loading VAE weights...")
        custom_vae = AutoencoderKL.from_single_file(
            str(vae_path),
            torch_dtype=dtype,
        )
        print(" ⚙️ Setting custom VAE...")
        pipeline.vae = custom_vae.to(device=device, dtype=torch.float32)

    if lora_paths_and_strengths:
        # Make sure everything sits on the target device before fusing.
        pipeline = pipeline.to(device=device, dtype=dtype)

        total = len(lora_paths_and_strengths)
        for idx, (lora_file, scale) in enumerate(lora_paths_and_strengths):
            adapter = f"lora_{idx}"
            print(f" ⚙️ Loading LoRA {idx+1}/{total}...")
            pipeline.load_lora_weights(str(lora_file), adapter_name=adapter)
            print(f" ⚙️ Fusing LoRA {idx+1} with strength={scale}...")
            pipeline.fuse_lora(adapter_names=[adapter], lora_scale=scale)
            pipeline.unload_lora_weights()
    else:
        # No LoRAs: still ensure the pipeline ends up on the device.
        pipeline = pipeline.to(device=device, dtype=dtype)

    print(" ⚙️ Configuring scheduler...")
    pipeline.scheduler = DPMSolverSDEScheduler.from_config(
        pipeline.scheduler.config,
        algorithm_type="sde-dpmsolver++",
        use_karras_sigmas=False,
    )

    # Keep VAE in float32 to prevent colorful static output
    pipeline.vae.to(dtype=torch.float32)

    return pipeline
|
| 68 |
|
| 69 |
|
| 70 |
def load_pipeline(
|
|
|
|
| 143 |
if not checkpoint_cached:
|
| 144 |
download_file_with_progress(checkpoint_url, checkpoint_path)
|
| 145 |
|
| 146 |
+
# Download VAE if provided (loading happens in _load_and_setup_pipeline)
|
|
|
|
| 147 |
if vae_url and vae_url.strip():
|
| 148 |
if vae_path:
|
| 149 |
status_msg = f"📥 Downloading {vae_path.name}..." if not vae_cached else f"✅ Using cached {vae_path.name}"
|
|
|
|
| 157 |
if not vae_cached:
|
| 158 |
download_file_with_progress(vae_url, vae_path)
|
| 159 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 160 |
# For CPU/low-memory environments on Spaces, use device_map for better RAM management
|
| 161 |
load_kwargs = {
|
| 162 |
"torch_dtype": dtype,
|
|
|
|
| 167 |
print(" ℹ️ CPU mode detected: enabling device_map='auto' for better RAM management")
|
| 168 |
load_kwargs["device_map"] = "auto"
|
| 169 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
# Parse LoRA URLs & ensure strengths list matches
|
| 171 |
lora_urls = [u.strip() for u in lora_urls_str.split("\n") if u.strip()]
|
| 172 |
strengths_raw = [s.strip() for s in lora_strengths_str.split(",")]
|
|
|
|
| 178 |
except ValueError:
|
| 179 |
strengths.append(1.0)
|
| 180 |
|
| 181 |
+
# Download LoRAs (CPU-bound downloads, before GPU work)
|
| 182 |
+
lora_paths_and_strengths = []
|
| 183 |
if lora_urls:
|
|
|
|
|
|
|
|
|
|
| 184 |
for i, (lora_url, strength) in enumerate(zip(lora_urls, strengths)):
|
| 185 |
lora_filename = get_safe_filename_from_url(lora_url, suffix="_lora")
|
| 186 |
lora_path = CACHE_DIR / lora_filename
|
|
|
|
| 210 |
if not lora_cached:
|
| 211 |
download_file_with_progress(lora_url, lora_path)
|
| 212 |
|
| 213 |
+
lora_paths_and_strengths.append((lora_path, strength))
|
|
|
|
|
|
|
|
|
|
| 214 |
|
| 215 |
+
# All downloads complete — now do GPU-intensive setup in one decorated call
|
| 216 |
+
yield "⚙️ Loading SDXL pipeline...", "Loading model weights into memory..."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 217 |
|
| 218 |
if progress:
|
| 219 |
+
progress(0.5, desc="Loading pipeline...")
|
| 220 |
|
| 221 |
+
_pipe = _load_and_setup_pipeline(
|
| 222 |
+
checkpoint_path, vae_path, lora_paths_and_strengths, load_kwargs
|
|
|
|
|
|
|
| 223 |
)
|
| 224 |
|
| 225 |
+
if progress:
|
| 226 |
+
progress(0.95, desc="Finalizing...")
|
| 227 |
+
|
| 228 |
# ✅ Only publish the pipeline globally AFTER all steps succeed
|
| 229 |
config.set_pipe(_pipe)
|
| 230 |
|