Kyle Pearson committed on
Commit
3631a8e
·
1 Parent(s): 570384a

Add ZeroGPU support, enhance model export with quantization/GPU acceleration helpers, optimize inference pipeline with VAE fixes, modernize pipeline loading with unified decorators, implement GPU decorator infrastructure.

Browse files
Files changed (5) hide show
  1. requirements.txt +3 -0
  2. src/exporter.py +77 -86
  3. src/generator.py +20 -31
  4. src/gpu_decorator.py +10 -0
  5. src/pipeline.py +68 -73
requirements.txt CHANGED
@@ -26,3 +26,6 @@ psutil>=5.9.0
26
 
27
  # Optional: quantization support
28
  optimum-quanto>=0.2.0
 
 
 
 
26
 
27
  # Optional: quantization support
28
  optimum-quanto>=0.2.0
29
+
30
+ # ZeroGPU support for HuggingFace Spaces
31
+ spaces
src/exporter.py CHANGED
@@ -1,13 +1,85 @@
1
  """Model export functionality for SDXL Model Merger."""
2
 
3
- import os
4
- from pathlib import Path
5
-
6
  import torch
7
  from safetensors.torch import save_file
8
 
9
  from . import config
10
  from .config import SCRIPT_DIR
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
 
13
  def export_merged_model(
@@ -45,90 +117,9 @@ def export_merged_model(
45
  yield None, f"❌ Invalid quantization type: {qtype}. Must be one of: {valid_qtypes}"
46
  return
47
 
48
- # Step 1: Unload LoRAs
49
- yield "💾 Exporting model...", "Unloading LoRAs..."
50
- if include_lora:
51
- try:
52
- pipe.unload_lora_weights()
53
- except Exception as e:
54
- print(f" ℹ️ Could not unload LoRAs: {e}")
55
-
56
- merged_state_dict = {}
57
-
58
- # Step 2: Extract UNet weights
59
- yield "💾 Exporting model...", "Extracting UNet weights..."
60
- for k, v in pipe.unet.state_dict().items():
61
- merged_state_dict[f"unet.{k}"] = v.contiguous().half()
62
-
63
- # Step 3: Extract text encoder weights
64
- yield "💾 Exporting model...", "Extracting text encoders..."
65
- if pipe.text_encoder is not None:
66
- for k, v in pipe.text_encoder.state_dict().items():
67
- merged_state_dict[f"text_encoder.{k}"] = v.contiguous().half()
68
- if pipe.text_encoder_2 is not None:
69
- for k, v in pipe.text_encoder_2.state_dict().items():
70
- merged_state_dict[f"text_encoder_2.{k}"] = v.contiguous().half()
71
-
72
- # Step 4: Extract VAE weights
73
- yield "💾 Exporting model...", "Extracting VAE weights..."
74
- if pipe.vae is not None:
75
- for k, v in pipe.vae.state_dict().items():
76
- merged_state_dict[f"first_stage_model.{k}"] = v.contiguous().half()
77
-
78
- # Step 5: Quantize if requested and optimum.quanto is available
79
- try:
80
- from optimum.quanto import quantize as quanto_quantize, QTensor
81
- QUANTO_AVAILABLE = True
82
- except ImportError:
83
- QUANTO_AVAILABLE = False
84
-
85
- if quantize and qtype != "none" and QUANTO_AVAILABLE:
86
- yield "💾 Exporting model...", f"Applying {qtype} quantization..."
87
-
88
- class FakeModel(torch.nn.Module):
89
- pass
90
-
91
- fake_model = FakeModel()
92
- fake_model.__dict__.update(merged_state_dict)
93
-
94
- # Select quantization method
95
- if qtype == "int8":
96
- from optimum.quanto import int8_weight_only
97
- quanto_quantize(fake_model, int8_weight_only())
98
- elif qtype == "int4":
99
- from optimum.quanto import int4_weight_only
100
- quanto_quantize(fake_model, int4_weight_only())
101
- elif qtype == "float8":
102
- from optimum.quanto import float8_dynamic_activation_float8_weight
103
- quanto_quantize(fake_model, float8_dynamic_activation_float8_weight())
104
- else:
105
- raise ValueError(f"Unsupported qtype: {qtype}")
106
-
107
- merged_state_dict = {
108
- k: v.dequantize().half() if isinstance(v, QTensor) else v
109
- for k, v in fake_model.state_dict().items()
110
- }
111
- elif quantize and not QUANTO_AVAILABLE:
112
- yield None, "❌ optimum.quanto not installed. Install with: pip install optimum-quanto"
113
- return
114
-
115
- # Step 6: Save model
116
- yield "💾 Exporting model...", "Saving weights..."
117
-
118
- ext = ".bin" if save_format == "bin" else ".safetensors"
119
 
120
- # Build filename based on options
121
- prefix = ""
122
- if quantize and qtype != "none":
123
- prefix = f"{qtype}_"
124
-
125
- out_path = SCRIPT_DIR / f"merged_{prefix}checkpoint{ext}"
126
-
127
- # Save appropriately
128
- if ext == ".bin":
129
- torch.save(merged_state_dict, str(out_path))
130
- else:
131
- save_file(merged_state_dict, str(out_path))
132
 
133
  size_gb = out_path.stat().st_size / 1024**3
134
 
 
1
  """Model export functionality for SDXL Model Merger."""
2
 
 
 
 
3
  import torch
4
  from safetensors.torch import save_file
5
 
6
  from . import config
7
  from .config import SCRIPT_DIR
8
+ from .gpu_decorator import GPU
9
+
10
+
11
+ @GPU(duration=180)
12
+ def _extract_and_save(pipe, include_lora, quantize, qtype, save_format):
13
+ """GPU-decorated helper that extracts weights and saves the model."""
14
+ if include_lora:
15
+ try:
16
+ pipe.unload_lora_weights()
17
+ except Exception as e:
18
+ print(f" ℹ️ Could not unload LoRAs: {e}")
19
+
20
+ merged_state_dict = {}
21
+
22
+ # Extract UNet weights
23
+ for k, v in pipe.unet.state_dict().items():
24
+ merged_state_dict[f"unet.{k}"] = v.contiguous().half()
25
+
26
+ # Extract text encoder weights
27
+ if pipe.text_encoder is not None:
28
+ for k, v in pipe.text_encoder.state_dict().items():
29
+ merged_state_dict[f"text_encoder.{k}"] = v.contiguous().half()
30
+ if pipe.text_encoder_2 is not None:
31
+ for k, v in pipe.text_encoder_2.state_dict().items():
32
+ merged_state_dict[f"text_encoder_2.{k}"] = v.contiguous().half()
33
+
34
+ # Extract VAE weights
35
+ if pipe.vae is not None:
36
+ for k, v in pipe.vae.state_dict().items():
37
+ merged_state_dict[f"first_stage_model.{k}"] = v.contiguous().half()
38
+
39
+ # Quantize if requested
40
+ try:
41
+ from optimum.quanto import quantize as quanto_quantize, QTensor
42
+ QUANTO_AVAILABLE = True
43
+ except ImportError:
44
+ QUANTO_AVAILABLE = False
45
+
46
+ if quantize and qtype != "none" and QUANTO_AVAILABLE:
47
+ class FakeModel(torch.nn.Module):
48
+ pass
49
+
50
+ fake_model = FakeModel()
51
+ fake_model.__dict__.update(merged_state_dict)
52
+
53
+ if qtype == "int8":
54
+ from optimum.quanto import int8_weight_only
55
+ quanto_quantize(fake_model, int8_weight_only())
56
+ elif qtype == "int4":
57
+ from optimum.quanto import int4_weight_only
58
+ quanto_quantize(fake_model, int4_weight_only())
59
+ elif qtype == "float8":
60
+ from optimum.quanto import float8_dynamic_activation_float8_weight
61
+ quanto_quantize(fake_model, float8_dynamic_activation_float8_weight())
62
+ else:
63
+ raise ValueError(f"Unsupported qtype: {qtype}")
64
+
65
+ merged_state_dict = {
66
+ k: v.dequantize().half() if isinstance(v, QTensor) else v
67
+ for k, v in fake_model.state_dict().items()
68
+ }
69
+ elif quantize and not QUANTO_AVAILABLE:
70
+ raise ImportError("optimum.quanto not installed. Install with: pip install optimum-quanto")
71
+
72
+ # Save model
73
+ ext = ".bin" if save_format == "bin" else ".safetensors"
74
+ prefix = f"{qtype}_" if quantize and qtype != "none" else ""
75
+ out_path = SCRIPT_DIR / f"merged_{prefix}checkpoint{ext}"
76
+
77
+ if ext == ".bin":
78
+ torch.save(merged_state_dict, str(out_path))
79
+ else:
80
+ save_file(merged_state_dict, str(out_path))
81
+
82
+ return out_path
83
 
84
 
85
  def export_merged_model(
 
117
  yield None, f"❌ Invalid quantization type: {qtype}. Must be one of: {valid_qtypes}"
118
  return
119
 
120
+ yield "💾 Exporting model...", "Extracting and saving weights..."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
 
122
+ out_path = _extract_and_save(pipe, include_lora, quantize, qtype, save_format)
 
 
 
 
 
 
 
 
 
 
 
123
 
124
  size_gb = out_path.stat().st_size / 1024**3
125
 
src/generator.py CHANGED
@@ -3,10 +3,25 @@
3
  import torch
4
 
5
  from . import config
6
- from .config import device, dtype, is_running_on_spaces
 
7
  from .tiling import enable_seamless_tiling
8
 
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  def generate_image(
11
  prompt: str,
12
  negative_prompt: str,
@@ -45,25 +60,8 @@ def generate_image(
45
  yield None, "⚠️ Please load a pipeline first."
46
  return
47
 
48
- # For CPU mode, use float32 and warn about slow generation
49
- effective_dtype = dtype
50
- effective_device = device
51
-
52
- if is_running_on_spaces() and device == "cpu":
53
- print(" ℹ️ CPU mode: using float32 for stability (generation will be slower)")
54
- # Store original dtype and temporarily use float32
55
- original_dtype = effective_dtype
56
- effective_dtype = torch.float32
57
- # Update pipeline to use float32
58
- pipe.unet.to(dtype=torch.float32)
59
- if pipe.text_encoder is not None:
60
- pipe.text_encoder.to(dtype=torch.float32)
61
- if pipe.text_encoder_2 is not None:
62
- pipe.text_encoder_2.to(dtype=torch.float32)
63
- if pipe.vae is not None:
64
- pipe.vae.to(dtype=torch.float32)
65
- else:
66
- original_dtype = None
67
 
68
  # Enable seamless tiling on UNet & VAE decoder
69
  enable_seamless_tiling(pipe.unet, tile_x=tile_x, tile_y=tile_y)
@@ -72,18 +70,9 @@ def generate_image(
72
  yield None, "🎨 Generating image..."
73
 
74
  try:
75
- # Use provided seed or generate a random one if None
76
  actual_seed = seed if seed is not None else int(torch.randint(0, 2**63, (1,)).item())
77
- generator = torch.Generator(device=effective_device).manual_seed(actual_seed)
78
- result = pipe(
79
- prompt=prompt,
80
- negative_prompt=negative_prompt,
81
- width=int(width),
82
- height=int(height),
83
- num_inference_steps=int(steps),
84
- guidance_scale=float(cfg),
85
- generator=generator,
86
- )
87
 
88
  image = result.images[0]
89
  yield image, f"✅ Complete! ({int(width)}x{int(height)})"
 
3
  import torch
4
 
5
  from . import config
6
+ from .config import device, dtype
7
+ from .gpu_decorator import GPU
8
  from .tiling import enable_seamless_tiling
9
 
10
 
11
+ @GPU(duration=120)
12
+ def _run_inference(pipe, prompt, negative_prompt, width, height, steps, cfg, generator):
13
+ """GPU-decorated helper that runs the actual inference."""
14
+ return pipe(
15
+ prompt=prompt,
16
+ negative_prompt=negative_prompt,
17
+ width=int(width),
18
+ height=int(height),
19
+ num_inference_steps=int(steps),
20
+ guidance_scale=float(cfg),
21
+ generator=generator,
22
+ )
23
+
24
+
25
  def generate_image(
26
  prompt: str,
27
  negative_prompt: str,
 
60
  yield None, "⚠️ Please load a pipeline first."
61
  return
62
 
63
+ # Ensure VAE stays in float32 to prevent colorful static output
64
+ pipe.vae.to(dtype=torch.float32)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
  # Enable seamless tiling on UNet & VAE decoder
67
  enable_seamless_tiling(pipe.unet, tile_x=tile_x, tile_y=tile_y)
 
70
  yield None, "🎨 Generating image..."
71
 
72
  try:
 
73
  actual_seed = seed if seed is not None else int(torch.randint(0, 2**63, (1,)).item())
74
+ generator = torch.Generator(device=device).manual_seed(actual_seed)
75
+ result = _run_inference(pipe, prompt, negative_prompt, width, height, steps, cfg, generator)
 
 
 
 
 
 
 
 
76
 
77
  image = result.images[0]
78
  yield image, f"✅ Complete! ({int(width)}x{int(height)})"
src/gpu_decorator.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ """ZeroGPU compatibility decorator for HuggingFace Spaces."""
2
+
3
+ try:
4
+ import spaces
5
+ GPU = spaces.GPU
6
+ except ImportError:
7
+ def GPU(func=None, duration=None):
8
+ if func is None:
9
+ return lambda f: f
10
+ return func
src/pipeline.py CHANGED
@@ -1,7 +1,6 @@
1
  """Pipeline management for SDXL Model Merger."""
2
 
3
- from pathlib import Path
4
-
5
  from diffusers import (
6
  StableDiffusionXLPipeline,
7
  AutoencoderKL,
@@ -11,7 +10,61 @@ from diffusers import (
11
  from . import config
12
  from .config import device, dtype, CACHE_DIR, device_description, is_running_on_spaces, set_download_cancelled
13
  from .downloader import get_safe_filename_from_url, download_file_with_progress
14
- from .tiling import enable_seamless_tiling
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
 
17
  def load_pipeline(
@@ -90,8 +143,7 @@ def load_pipeline(
90
  if not checkpoint_cached:
91
  download_file_with_progress(checkpoint_url, checkpoint_path)
92
 
93
- # Download VAE if provided
94
- vae = None
95
  if vae_url and vae_url.strip():
96
  if vae_path:
97
  status_msg = f"📥 Downloading {vae_path.name}..." if not vae_cached else f"✅ Using cached {vae_path.name}"
@@ -105,23 +157,6 @@ def load_pipeline(
105
  if not vae_cached:
106
  download_file_with_progress(vae_url, vae_path)
107
 
108
- # Load VAE from file
109
- print(" ⚙️ Loading VAE weights...")
110
- yield "⚙️ Loading VAE...", f"Loading VAE: {vae_path.name}"
111
- vae = AutoencoderKL.from_single_file(
112
- str(vae_path),
113
- torch_dtype=dtype,
114
- )
115
- if progress:
116
- progress(0.25, desc="VAE loaded")
117
-
118
- # Load base pipeline (yield progress during this heavy operation)
119
- print(" ⚙️ Loading SDXL pipeline from single file...")
120
- yield "⚙️ Loading SDXL pipeline...", "Loading model weights into memory..."
121
-
122
- if progress:
123
- progress(0.3, desc="Loading text encoders...")
124
-
125
  # For CPU/low-memory environments on Spaces, use device_map for better RAM management
126
  load_kwargs = {
127
  "torch_dtype": dtype,
@@ -132,31 +167,6 @@ def load_pipeline(
132
  print(" ℹ️ CPU mode detected: enabling device_map='auto' for better RAM management")
133
  load_kwargs["device_map"] = "auto"
134
 
135
- # Use a local variable for the pipeline being built — only stored globally on success.
136
- _pipe = StableDiffusionXLPipeline.from_single_file(
137
- str(checkpoint_path),
138
- **load_kwargs,
139
- )
140
- print(" ✅ Text encoders loaded")
141
-
142
- if progress:
143
- progress(0.5, desc="Loading UNet...")
144
-
145
- print(" ✅ UNet loaded")
146
-
147
- # Move to device (unless using device_map='auto' which handles this automatically)
148
- if not is_running_on_spaces() or device != "cpu":
149
- print(f" ⚙️ Moving pipeline to device: {device_description}...")
150
- _pipe = _pipe.to(device=device, dtype=dtype)
151
-
152
- yield "⚙️ Pipeline loaded, setting up components...", f"Using device: {device_description}"
153
-
154
- # Load VAE into pipeline if provided
155
- if vae is not None:
156
- print(" ⚙️ Setting custom VAE...")
157
- _pipe.vae = vae.to(device=device, dtype=dtype)
158
- yield "⚙️ Pipeline loaded, setting up components...", f"VAE loaded: {vae_path.name}"
159
-
160
  # Parse LoRA URLs & ensure strengths list matches
161
  lora_urls = [u.strip() for u in lora_urls_str.split("\n") if u.strip()]
162
  strengths_raw = [s.strip() for s in lora_strengths_str.split(",")]
@@ -168,11 +178,9 @@ def load_pipeline(
168
  except ValueError:
169
  strengths.append(1.0)
170
 
171
- # Load and fuse each LoRA sequentially (only if URLs exist)
 
172
  if lora_urls:
173
- print(f" ⚙️ Moving pipeline to device: {device_description}...")
174
- _pipe = _pipe.to(device=device, dtype=dtype)
175
-
176
  for i, (lora_url, strength) in enumerate(zip(lora_urls, strengths)):
177
  lora_filename = get_safe_filename_from_url(lora_url, suffix="_lora")
178
  lora_path = CACHE_DIR / lora_filename
@@ -202,34 +210,21 @@ def load_pipeline(
202
  if not lora_cached:
203
  download_file_with_progress(lora_url, lora_path)
204
 
205
- print(f" ⚙️ Loading LoRA {i+1}/{len(lora_urls)}...")
206
- yield f"⚙️ Loading LoRA {i+1}/{len(lora_urls)}...", f"Fusing {lora_path.name}..."
207
- if progress:
208
- progress(0.7 + (0.2 * i / len(lora_urls)), desc=f"Loading LoRA {i+1}/{len(lora_urls)}...")
209
 
210
- adapter_name = f"lora_{i}"
211
- _pipe.load_lora_weights(str(lora_path), adapter_name=adapter_name)
212
- print(f" ⚙️ Fusing LoRA {i+1} with strength={strength}...")
213
- _pipe.fuse_lora(adapter_names=[adapter_name], lora_scale=strength)
214
- _pipe.unload_lora_weights()
215
- else:
216
- # Move pipeline to device even without LoRAs
217
- print(f" ⚙️ Moving pipeline to device: {device_description}...")
218
- _pipe = _pipe.to(device=device, dtype=dtype)
219
-
220
- # Set scheduler and finalize (do this once at the end)
221
- print(" ⚙️ Configuring scheduler...")
222
- yield "⚙️ Finalizing pipeline...", "Setting up scheduler..."
223
 
224
  if progress:
225
- progress(0.95, desc="Finalizing...")
226
 
227
- _pipe.scheduler = DPMSolverSDEScheduler.from_config(
228
- _pipe.scheduler.config,
229
- algorithm_type="sde-dpmsolver++",
230
- use_karras_sigmas=False,
231
  )
232
 
 
 
 
233
  # ✅ Only publish the pipeline globally AFTER all steps succeed
234
  config.set_pipe(_pipe)
235
 
 
1
  """Pipeline management for SDXL Model Merger."""
2
 
3
+ import torch
 
4
  from diffusers import (
5
  StableDiffusionXLPipeline,
6
  AutoencoderKL,
 
10
  from . import config
11
  from .config import device, dtype, CACHE_DIR, device_description, is_running_on_spaces, set_download_cancelled
12
  from .downloader import get_safe_filename_from_url, download_file_with_progress
13
+ from .gpu_decorator import GPU
14
+
15
+
16
+ @GPU(duration=300)
17
+ def _load_and_setup_pipeline(checkpoint_path, vae_path, lora_paths_and_strengths, load_kwargs):
18
+ """GPU-decorated helper that performs all GPU-intensive pipeline setup."""
19
+ _pipe = StableDiffusionXLPipeline.from_single_file(
20
+ str(checkpoint_path),
21
+ **load_kwargs,
22
+ )
23
+ print(" ✅ Text encoders loaded")
24
+
25
+ # Move to device (unless using device_map='auto' which handles this automatically)
26
+ if not is_running_on_spaces() or device != "cpu":
27
+ print(f" ⚙️ Moving pipeline to device: {device_description}...")
28
+ _pipe = _pipe.to(device=device, dtype=dtype)
29
+
30
+ # Load custom VAE if provided
31
+ if vae_path is not None:
32
+ print(" ⚙️ Loading VAE weights...")
33
+ vae = AutoencoderKL.from_single_file(
34
+ str(vae_path),
35
+ torch_dtype=dtype,
36
+ )
37
+ print(" ⚙️ Setting custom VAE...")
38
+ _pipe.vae = vae.to(device=device, dtype=torch.float32)
39
+
40
+ # Load and fuse each LoRA
41
+ if lora_paths_and_strengths:
42
+ # Ensure pipeline is on device for LoRA fusion
43
+ _pipe = _pipe.to(device=device, dtype=dtype)
44
+
45
+ for i, (lora_path, strength) in enumerate(lora_paths_and_strengths):
46
+ adapter_name = f"lora_{i}"
47
+ print(f" ⚙️ Loading LoRA {i+1}/{len(lora_paths_and_strengths)}...")
48
+ _pipe.load_lora_weights(str(lora_path), adapter_name=adapter_name)
49
+ print(f" ⚙️ Fusing LoRA {i+1} with strength={strength}...")
50
+ _pipe.fuse_lora(adapter_names=[adapter_name], lora_scale=strength)
51
+ _pipe.unload_lora_weights()
52
+ else:
53
+ # Move pipeline to device even without LoRAs
54
+ _pipe = _pipe.to(device=device, dtype=dtype)
55
+
56
+ # Set scheduler
57
+ print(" ⚙️ Configuring scheduler...")
58
+ _pipe.scheduler = DPMSolverSDEScheduler.from_config(
59
+ _pipe.scheduler.config,
60
+ algorithm_type="sde-dpmsolver++",
61
+ use_karras_sigmas=False,
62
+ )
63
+
64
+ # Keep VAE in float32 to prevent colorful static output
65
+ _pipe.vae.to(dtype=torch.float32)
66
+
67
+ return _pipe
68
 
69
 
70
  def load_pipeline(
 
143
  if not checkpoint_cached:
144
  download_file_with_progress(checkpoint_url, checkpoint_path)
145
 
146
+ # Download VAE if provided (loading happens in _load_and_setup_pipeline)
 
147
  if vae_url and vae_url.strip():
148
  if vae_path:
149
  status_msg = f"📥 Downloading {vae_path.name}..." if not vae_cached else f"✅ Using cached {vae_path.name}"
 
157
  if not vae_cached:
158
  download_file_with_progress(vae_url, vae_path)
159
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
  # For CPU/low-memory environments on Spaces, use device_map for better RAM management
161
  load_kwargs = {
162
  "torch_dtype": dtype,
 
167
  print(" ℹ️ CPU mode detected: enabling device_map='auto' for better RAM management")
168
  load_kwargs["device_map"] = "auto"
169
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
  # Parse LoRA URLs & ensure strengths list matches
171
  lora_urls = [u.strip() for u in lora_urls_str.split("\n") if u.strip()]
172
  strengths_raw = [s.strip() for s in lora_strengths_str.split(",")]
 
178
  except ValueError:
179
  strengths.append(1.0)
180
 
181
+ # Download LoRAs (CPU-bound downloads, before GPU work)
182
+ lora_paths_and_strengths = []
183
  if lora_urls:
 
 
 
184
  for i, (lora_url, strength) in enumerate(zip(lora_urls, strengths)):
185
  lora_filename = get_safe_filename_from_url(lora_url, suffix="_lora")
186
  lora_path = CACHE_DIR / lora_filename
 
210
  if not lora_cached:
211
  download_file_with_progress(lora_url, lora_path)
212
 
213
+ lora_paths_and_strengths.append((lora_path, strength))
 
 
 
214
 
215
+ # All downloads complete — now do GPU-intensive setup in one decorated call
216
+ yield "⚙️ Loading SDXL pipeline...", "Loading model weights into memory..."
 
 
 
 
 
 
 
 
 
 
 
217
 
218
  if progress:
219
+ progress(0.5, desc="Loading pipeline...")
220
 
221
+ _pipe = _load_and_setup_pipeline(
222
+ checkpoint_path, vae_path, lora_paths_and_strengths, load_kwargs
 
 
223
  )
224
 
225
+ if progress:
226
+ progress(0.95, desc="Finalizing...")
227
+
228
  # ✅ Only publish the pipeline globally AFTER all steps succeed
229
  config.set_pipe(_pipe)
230