Kyle Pearson committed on
Commit
570384a
·
1 Parent(s): 459ac47

cleaned up code

Browse files
Files changed (10) hide show
  1. .env.example +12 -0
  2. app.py +18 -33
  3. manifest.json +1 -0
  4. src/config.py +33 -1
  5. src/downloader.py +2 -8
  6. src/exporter.py +39 -26
  7. src/generator.py +53 -33
  8. src/pipeline.py +42 -82
  9. src/tiling.py +52 -0
  10. src/ui/generator_tab.py +55 -18
.env.example ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SDXL Model Merger - Environment Configuration Example
2
+ # Copy this to .env and customize for your deployment
3
+
4
+ # Deployment Environment (auto-detected: local, spaces)
5
+ DEPLOYMENT_ENV=local
6
+
7
+ # Default Model URLs - Use HF models for Spaces compatibility
8
+ DEFAULT_CHECKPOINT_URL=https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0.safetensors?download=true
9
+ DEFAULT_VAE_URL=https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/resolve/main/sdxl.vae.safetensors?download=true
10
+
11
+ # Default LoRA - using HF instead of CivitAI (may not be accessible on Spaces)
12
+ DEFAULT_LORA_URLS=https://huggingface.co/nerijs/pixel-art-xl/resolve/main/pixel-art-xl.safetensors?download=true
app.py CHANGED
@@ -78,7 +78,7 @@ def create_app():
78
  }
79
  """
80
 
81
- from src.pipeline import load_pipeline, cancel_download
82
  from src.generator import generate_image
83
  from src.exporter import export_merged_model
84
  from src.config import get_cached_models, get_cached_checkpoints, get_cached_vaes, get_cached_loras
@@ -227,7 +227,7 @@ def create_app():
227
  )
228
 
229
  cfg = gr.Slider(
230
- minimum=1.0, maximum=20.0, value=7.5, step=0.5,
231
  label="CFG Scale",
232
  info="Higher values make outputs match prompt more strictly"
233
  )
@@ -247,7 +247,7 @@ def create_app():
247
  )
248
 
249
  steps = gr.Slider(
250
- minimum=1, maximum=100, value=25, step=1,
251
  label="Inference Steps",
252
  info="More steps = better quality but slower"
253
  )
@@ -262,6 +262,13 @@ def create_app():
262
  tile_x = gr.Checkbox(True, label="X-axis Seamless Tiling")
263
  tile_y = gr.Checkbox(False, label="Y-axis Seamless Tiling")
264
 
 
 
 
 
 
 
 
265
  with gr.Row():
266
  gen_btn = gr.Button("✨ Generate Image", variant="secondary", size="lg")
267
 
@@ -361,8 +368,7 @@ def create_app():
361
  return (
362
  '<div class="status-warning">⏳ Loading started...</div>',
363
  "Starting download...",
364
- gr.update(interactive=False),
365
- gr.update(visible=True)
366
  )
367
 
368
  def on_load_pipeline_complete(status_msg, progress_text):
@@ -371,44 +377,34 @@ def create_app():
371
  return (
372
  '<div class="status-success">✅ Pipeline loaded successfully!</div>',
373
  progress_text,
374
- gr.update(interactive=True),
375
- gr.update(visible=False)
376
  )
377
  elif "⚠️" in status_msg or "cancelled" in status_msg.lower():
378
  return (
379
  '<div class="status-warning">⚠️ Download cancelled</div>',
380
  progress_text,
381
- gr.update(interactive=True),
382
- gr.update(visible=False)
383
  )
384
  else:
385
  return (
386
  f'<div class="status-error">{status_msg}</div>',
387
  progress_text,
388
- gr.update(interactive=True),
389
- gr.update(visible=False)
390
  )
391
 
392
- # Cancel button for pipeline loading
393
- cancel_load_btn = gr.Button("🛑 Cancel Loading", variant="secondary", size="sm", visible=False)
394
-
395
  load_btn.click(
396
  fn=on_load_pipeline_start,
397
  inputs=[],
398
- outputs=[load_status, load_progress, load_btn, cancel_load_btn],
399
  ).then(
400
  fn=load_pipeline,
401
  inputs=[checkpoint_url, vae_url, lora_urls, lora_strengths],
402
  outputs=[load_status, load_progress],
403
  show_progress="full",
404
- ).then(
405
- fn=lambda: (gr.update(interactive=True), gr.update(visible=False)),
406
- inputs=[],
407
- outputs=[cancel_load_btn, cancel_load_btn], # Just to hide it
408
  ).then(
409
  fn=on_load_pipeline_complete,
410
  inputs=[load_status, load_progress],
411
- outputs=[load_status, load_progress, load_btn, cancel_load_btn],
412
  ).then(
413
  fn=lambda: (
414
  gr.update(choices=["(None found)"] + get_cached_checkpoints()),
@@ -419,17 +415,6 @@ def create_app():
419
  outputs=[cached_checkpoints, cached_vaes, cached_loras],
420
  )
421
 
422
- # Cancel button handler
423
- cancel_load_btn.click(
424
- fn=lambda: (cancel_download(),
425
- '<div class="status-warning">⏳ Cancelling...</div>',
426
- "Cancelling download...",
427
- gr.update(visible=False),
428
- gr.update(interactive=True)),
429
- inputs=[],
430
- outputs=[load_status, load_progress, cancel_load_btn, load_btn],
431
- )
432
-
433
  def on_cached_checkpoint_change(cached_path):
434
  """Update URL when a cached checkpoint is selected."""
435
  if cached_path and cached_path != "(None found)":
@@ -502,7 +487,7 @@ def create_app():
502
  outputs=[gen_status, gen_progress, gen_btn],
503
  ).then(
504
  fn=generate_image,
505
- inputs=[prompt, negative_prompt, cfg, steps, height, width, tile_x, tile_y],
506
  outputs=[image_output, gen_progress],
507
  ).then(
508
  fn=lambda img, msg: on_generate_complete(msg, "Done", img),
@@ -543,7 +528,7 @@ def create_app():
543
  fn=lambda inc, q, qt, fmt: export_merged_model(
544
  include_lora=inc,
545
  quantize=q and (qt != "none"),
546
- qtype=qt if qt != "none" else None,
547
  save_format=fmt,
548
  ),
549
  inputs=[include_lora, quantize_toggle, qtype_dropdown, format_dropdown],
 
78
  }
79
  """
80
 
81
+ from src.pipeline import load_pipeline
82
  from src.generator import generate_image
83
  from src.exporter import export_merged_model
84
  from src.config import get_cached_models, get_cached_checkpoints, get_cached_vaes, get_cached_loras
 
227
  )
228
 
229
  cfg = gr.Slider(
230
+ minimum=1.0, maximum=20.0, value=3.0, step=0.5,
231
  label="CFG Scale",
232
  info="Higher values make outputs match prompt more strictly"
233
  )
 
247
  )
248
 
249
  steps = gr.Slider(
250
+ minimum=1, maximum=100, value=8, step=1,
251
  label="Inference Steps",
252
  info="More steps = better quality but slower"
253
  )
 
262
  tile_x = gr.Checkbox(True, label="X-axis Seamless Tiling")
263
  tile_y = gr.Checkbox(False, label="Y-axis Seamless Tiling")
264
 
265
+ seed = gr.Number(
266
+ value=80484030936239,
267
+ precision=0,
268
+ label="Seed",
269
+ info="Random seed for reproducible generation"
270
+ )
271
+
272
  with gr.Row():
273
  gen_btn = gr.Button("✨ Generate Image", variant="secondary", size="lg")
274
 
 
368
  return (
369
  '<div class="status-warning">⏳ Loading started...</div>',
370
  "Starting download...",
371
+ gr.update(interactive=False)
 
372
  )
373
 
374
  def on_load_pipeline_complete(status_msg, progress_text):
 
377
  return (
378
  '<div class="status-success">✅ Pipeline loaded successfully!</div>',
379
  progress_text,
380
+ gr.update(interactive=True)
 
381
  )
382
  elif "⚠️" in status_msg or "cancelled" in status_msg.lower():
383
  return (
384
  '<div class="status-warning">⚠️ Download cancelled</div>',
385
  progress_text,
386
+ gr.update(interactive=True)
 
387
  )
388
  else:
389
  return (
390
  f'<div class="status-error">{status_msg}</div>',
391
  progress_text,
392
+ gr.update(interactive=True)
 
393
  )
394
 
 
 
 
395
  load_btn.click(
396
  fn=on_load_pipeline_start,
397
  inputs=[],
398
+ outputs=[load_status, load_progress, load_btn],
399
  ).then(
400
  fn=load_pipeline,
401
  inputs=[checkpoint_url, vae_url, lora_urls, lora_strengths],
402
  outputs=[load_status, load_progress],
403
  show_progress="full",
 
 
 
 
404
  ).then(
405
  fn=on_load_pipeline_complete,
406
  inputs=[load_status, load_progress],
407
+ outputs=[load_status, load_progress, load_btn],
408
  ).then(
409
  fn=lambda: (
410
  gr.update(choices=["(None found)"] + get_cached_checkpoints()),
 
415
  outputs=[cached_checkpoints, cached_vaes, cached_loras],
416
  )
417
 
 
 
 
 
 
 
 
 
 
 
 
418
  def on_cached_checkpoint_change(cached_path):
419
  """Update URL when a cached checkpoint is selected."""
420
  if cached_path and cached_path != "(None found)":
 
487
  outputs=[gen_status, gen_progress, gen_btn],
488
  ).then(
489
  fn=generate_image,
490
+ inputs=[prompt, negative_prompt, cfg, steps, height, width, tile_x, tile_y, seed],
491
  outputs=[image_output, gen_progress],
492
  ).then(
493
  fn=lambda img, msg: on_generate_complete(msg, "Done", img),
 
528
  fn=lambda inc, q, qt, fmt: export_merged_model(
529
  include_lora=inc,
530
  quantize=q and (qt != "none"),
531
+ qtype=qt, # always pass the string value; exporter handles "none" correctly
532
  save_format=fmt,
533
  ),
534
  inputs=[include_lora, quantize_toggle, qtype_dropdown, format_dropdown],
manifest.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"name": "SDXL Model Merger", "short_name": "SDXL Merger", "description": "Merge SDXL checkpoints, LoRAs, and VAEs with optional quantization", "start_url": "/", "display": "standalone", "background_color": "#ffffff", "theme_color": "#10b981"}
src/config.py CHANGED
@@ -114,11 +114,43 @@ print(f"🚀 Using device: {device_description}")
114
  check_memory_requirements()
115
 
116
  # ──────────────────────────────────────────────
117
- # Global State
 
 
 
 
 
118
  # ──────────────────────────────────────────────
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
  pipe = None
 
 
 
 
 
120
  download_cancelled = False
121
 
 
 
 
 
 
 
 
122
  # ──────────────────────────────────────────────
123
  # Generation Defaults
124
  # ──────────────────────────────────────────────
 
114
  check_memory_requirements()
115
 
116
  # ──────────────────────────────────────────────
117
+ # Global Pipeline State
118
+ #
119
+ # IMPORTANT: Use get_pipe() / set_pipe() instead of importing `pipe` directly.
120
+ # Python's `from .config import pipe` binds the value (None) at import time.
121
+ # Subsequent set_pipe() calls update the mutable dict, so all modules that
122
+ # call get_pipe() will always see the current pipeline instance.
123
  # ──────────────────────────────────────────────
124
+ _pipeline_state: dict = {"pipe": None}
125
+
126
+
127
+ def get_pipe():
128
+ """Get the currently loaded pipeline instance (always up-to-date)."""
129
+ return _pipeline_state["pipe"]
130
+
131
+
132
+ def set_pipe(pipeline) -> None:
133
+ """Set the globally loaded pipeline instance."""
134
+ _pipeline_state["pipe"] = pipeline
135
+
136
+
137
+ # Legacy alias kept for any code that references config.pipe directly.
138
+ # Do NOT use this for checking whether the pipeline is loaded — use get_pipe().
139
  pipe = None
140
+
141
+
142
+ # ──────────────────────────────────────────────
143
+ # Download Cancellation Flag
144
+ # ──────────────────────────────────────────────
145
  download_cancelled = False
146
 
147
+
148
+ def set_download_cancelled(value: bool) -> None:
149
+ """Set the global download cancellation flag."""
150
+ global download_cancelled
151
+ download_cancelled = value
152
+
153
+
154
  # ──────────────────────────────────────────────
155
  # Generation Defaults
156
  # ──────────────────────────────────────────────
src/downloader.py CHANGED
@@ -135,7 +135,7 @@ class TqdmGradio(TqdmBase):
135
  self.last_pct = 0
136
 
137
  def update(self, n=1):
138
- global download_cancelled
139
  if download_cancelled:
140
  raise KeyboardInterrupt("Download cancelled by user")
141
  super().update(n)
@@ -147,12 +147,6 @@ class TqdmGradio(TqdmBase):
147
  self.gradio_prog(pct / 100)
148
 
149
 
150
- def set_download_cancelled(value: bool):
151
- """Set the global download cancellation flag."""
152
- global download_cancelled
153
- download_cancelled = value
154
-
155
-
156
  def get_cached_file_size(url: str, suffix: str = "", type_prefix: str | None = None) -> tuple[Path | None, int | None]:
157
  """
158
  Check if file exists in cache and matches expected size.
@@ -218,7 +212,7 @@ def download_file_with_progress(url: str, output_path: Path, progress_bar=None)
218
  KeyboardInterrupt: If download is cancelled
219
  requests.RequestException: If download fails
220
  """
221
- global download_cancelled
222
 
223
  # Handle local file:// URLs
224
  if url.startswith("file://"):
 
135
  self.last_pct = 0
136
 
137
  def update(self, n=1):
138
+ from .config import download_cancelled
139
  if download_cancelled:
140
  raise KeyboardInterrupt("Download cancelled by user")
141
  super().update(n)
 
147
  self.gradio_prog(pct / 100)
148
 
149
 
 
 
 
 
 
 
150
  def get_cached_file_size(url: str, suffix: str = "", type_prefix: str | None = None) -> tuple[Path | None, int | None]:
151
  """
152
  Check if file exists in cache and matches expected size.
 
212
  KeyboardInterrupt: If download is cancelled
213
  requests.RequestException: If download fails
214
  """
215
+ from .config import download_cancelled
216
 
217
  # Handle local file:// URLs
218
  if url.startswith("file://"):
src/exporter.py CHANGED
@@ -6,7 +6,8 @@ from pathlib import Path
6
  import torch
7
  from safetensors.torch import save_file
8
 
9
- from .config import SCRIPT_DIR, pipe as global_pipe
 
10
 
11
 
12
  def export_merged_model(
@@ -14,7 +15,7 @@ def export_merged_model(
14
  quantize: bool,
15
  qtype: str,
16
  save_format: str = "safetensors",
17
- ) -> tuple[str | None, str]:
18
  """
19
  Export the merged pipeline model with optional LoRA baking and quantization.
20
 
@@ -24,56 +25,66 @@ def export_merged_model(
24
  qtype: Quantization type - 'none', 'int8', 'int4', or 'float8'
25
  save_format: Output format - 'safetensors' or 'bin'
26
 
 
 
 
27
  Returns:
28
- Tuple of (output_path or None, status message)
29
  """
30
- if not global_pipe:
31
- return None, "⚠️ Please load a pipeline first."
 
 
 
 
32
 
33
  try:
34
  # Validate quantization type
35
  valid_qtypes = ("none", "int8", "int4", "float8")
36
  if qtype not in valid_qtypes:
37
- return None, f"❌ Invalid quantization type: {qtype}. Must be one of: {valid_qtypes}"
 
38
 
39
  # Step 1: Unload LoRAs
40
  yield "💾 Exporting model...", "Unloading LoRAs..."
41
  if include_lora:
42
  try:
43
- global_pipe.unload_lora_weights()
44
  except Exception as e:
45
  print(f" ℹ️ Could not unload LoRAs: {e}")
46
- pass
47
 
48
  merged_state_dict = {}
49
 
50
  # Step 2: Extract UNet weights
51
  yield "💾 Exporting model...", "Extracting UNet weights..."
52
- for k, v in global_pipe.unet.state_dict().items():
53
  merged_state_dict[f"unet.{k}"] = v.contiguous().half()
54
 
55
  # Step 3: Extract text encoder weights
56
  yield "💾 Exporting model...", "Extracting text encoders..."
57
- for k, v in global_pipe.text_encoder.state_dict().items():
58
- merged_state_dict[f"text_encoder.{k}"] = v.contiguous().half()
59
- for k, v in global_pipe.text_encoder_2.state_dict().items():
60
- merged_state_dict[f"text_encoder_2.{k}"] = v.contiguous().half()
 
 
61
 
62
  # Step 4: Extract VAE weights
63
  yield "💾 Exporting model...", "Extracting VAE weights..."
64
- for k, v in global_pipe.vae.state_dict().items():
65
- merged_state_dict[f"first_stage_model.{k}"] = v.contiguous().half()
 
66
 
67
  # Step 5: Quantize if requested and optimum.quanto is available
68
  try:
69
- from optimum.quanto import quantize, QTensor
70
  QUANTO_AVAILABLE = True
71
  except ImportError:
72
  QUANTO_AVAILABLE = False
73
 
74
  if quantize and qtype != "none" and QUANTO_AVAILABLE:
75
  yield "💾 Exporting model...", f"Applying {qtype} quantization..."
76
-
77
  class FakeModel(torch.nn.Module):
78
  pass
79
 
@@ -83,13 +94,13 @@ def export_merged_model(
83
  # Select quantization method
84
  if qtype == "int8":
85
  from optimum.quanto import int8_weight_only
86
- quantize(fake_model, int8_weight_only())
87
  elif qtype == "int4":
88
  from optimum.quanto import int4_weight_only
89
- quantize(fake_model, int4_weight_only())
90
  elif qtype == "float8":
91
  from optimum.quanto import float8_dynamic_activation_float8_weight
92
- quantize(fake_model, float8_dynamic_activation_float8_weight())
93
  else:
94
  raise ValueError(f"Unsupported qtype: {qtype}")
95
 
@@ -98,11 +109,12 @@ def export_merged_model(
98
  for k, v in fake_model.state_dict().items()
99
  }
100
  elif quantize and not QUANTO_AVAILABLE:
101
- return None, "❌ optimum.quanto not installed. Install with: pip install optimum-quanto"
 
102
 
103
  # Step 6: Save model
104
  yield "💾 Exporting model...", "Saving weights..."
105
-
106
  ext = ".bin" if save_format == "bin" else ".safetensors"
107
 
108
  # Build filename based on options
@@ -125,13 +137,14 @@ def export_merged_model(
125
  else:
126
  msg = f"✅ Merged checkpoint saved: `{out_path}` ({size_gb:.2f} GB)"
127
 
128
- yield "💾 Exporting model...", msg
129
- return str(out_path), msg
130
 
131
  except ImportError as e:
132
- return None, f"❌ Missing dependency: {str(e)}"
133
  except Exception as e:
134
- return None, f"❌ Export failed: {str(e)}"
 
 
135
 
136
 
137
  def get_export_status() -> str:
 
6
  import torch
7
  from safetensors.torch import save_file
8
 
9
+ from . import config
10
+ from .config import SCRIPT_DIR
11
 
12
 
13
  def export_merged_model(
 
15
  quantize: bool,
16
  qtype: str,
17
  save_format: str = "safetensors",
18
+ ):
19
  """
20
  Export the merged pipeline model with optional LoRA baking and quantization.
21
 
 
25
  qtype: Quantization type - 'none', 'int8', 'int4', or 'float8'
26
  save_format: Output format - 'safetensors' or 'bin'
27
 
28
+ Yields:
29
+ Tuple of (status_message, progress_text) at each export stage.
30
+
31
  Returns:
32
+ Final yielded tuple of (output_path or None, status message)
33
  """
34
+ # Fetch the pipeline at call time — avoids the stale import-by-value problem.
35
+ pipe = config.get_pipe()
36
+
37
+ if not pipe:
38
+ yield None, "⚠️ Please load a pipeline first."
39
+ return
40
 
41
  try:
42
  # Validate quantization type
43
  valid_qtypes = ("none", "int8", "int4", "float8")
44
  if qtype not in valid_qtypes:
45
+ yield None, f"❌ Invalid quantization type: {qtype}. Must be one of: {valid_qtypes}"
46
+ return
47
 
48
  # Step 1: Unload LoRAs
49
  yield "💾 Exporting model...", "Unloading LoRAs..."
50
  if include_lora:
51
  try:
52
+ pipe.unload_lora_weights()
53
  except Exception as e:
54
  print(f" ℹ️ Could not unload LoRAs: {e}")
 
55
 
56
  merged_state_dict = {}
57
 
58
  # Step 2: Extract UNet weights
59
  yield "💾 Exporting model...", "Extracting UNet weights..."
60
+ for k, v in pipe.unet.state_dict().items():
61
  merged_state_dict[f"unet.{k}"] = v.contiguous().half()
62
 
63
  # Step 3: Extract text encoder weights
64
  yield "💾 Exporting model...", "Extracting text encoders..."
65
+ if pipe.text_encoder is not None:
66
+ for k, v in pipe.text_encoder.state_dict().items():
67
+ merged_state_dict[f"text_encoder.{k}"] = v.contiguous().half()
68
+ if pipe.text_encoder_2 is not None:
69
+ for k, v in pipe.text_encoder_2.state_dict().items():
70
+ merged_state_dict[f"text_encoder_2.{k}"] = v.contiguous().half()
71
 
72
  # Step 4: Extract VAE weights
73
  yield "💾 Exporting model...", "Extracting VAE weights..."
74
+ if pipe.vae is not None:
75
+ for k, v in pipe.vae.state_dict().items():
76
+ merged_state_dict[f"first_stage_model.{k}"] = v.contiguous().half()
77
 
78
  # Step 5: Quantize if requested and optimum.quanto is available
79
  try:
80
+ from optimum.quanto import quantize as quanto_quantize, QTensor
81
  QUANTO_AVAILABLE = True
82
  except ImportError:
83
  QUANTO_AVAILABLE = False
84
 
85
  if quantize and qtype != "none" and QUANTO_AVAILABLE:
86
  yield "💾 Exporting model...", f"Applying {qtype} quantization..."
87
+
88
  class FakeModel(torch.nn.Module):
89
  pass
90
 
 
94
  # Select quantization method
95
  if qtype == "int8":
96
  from optimum.quanto import int8_weight_only
97
+ quanto_quantize(fake_model, int8_weight_only())
98
  elif qtype == "int4":
99
  from optimum.quanto import int4_weight_only
100
+ quanto_quantize(fake_model, int4_weight_only())
101
  elif qtype == "float8":
102
  from optimum.quanto import float8_dynamic_activation_float8_weight
103
+ quanto_quantize(fake_model, float8_dynamic_activation_float8_weight())
104
  else:
105
  raise ValueError(f"Unsupported qtype: {qtype}")
106
 
 
109
  for k, v in fake_model.state_dict().items()
110
  }
111
  elif quantize and not QUANTO_AVAILABLE:
112
+ yield None, "❌ optimum.quanto not installed. Install with: pip install optimum-quanto"
113
+ return
114
 
115
  # Step 6: Save model
116
  yield "💾 Exporting model...", "Saving weights..."
117
+
118
  ext = ".bin" if save_format == "bin" else ".safetensors"
119
 
120
  # Build filename based on options
 
137
  else:
138
  msg = f"✅ Merged checkpoint saved: `{out_path}` ({size_gb:.2f} GB)"
139
 
140
+ yield str(out_path), msg
 
141
 
142
  except ImportError as e:
143
+ yield None, f"❌ Missing dependency: {str(e)}"
144
  except Exception as e:
145
+ import traceback
146
+ print(traceback.format_exc())
147
+ yield None, f"❌ Export failed: {str(e)}"
148
 
149
 
150
  def get_export_status() -> str:
src/generator.py CHANGED
@@ -2,8 +2,9 @@
2
 
3
  import torch
4
 
5
- from .config import device, dtype, pipe as global_pipe, is_running_on_spaces
6
- from .pipeline import enable_seamless_tiling
 
7
 
8
 
9
  def generate_image(
@@ -15,7 +16,8 @@ def generate_image(
15
  width: int,
16
  tile_x: bool = True,
17
  tile_y: bool = False,
18
- ) -> tuple[object | None, str]:
 
19
  """
20
  Generate an image using the loaded SDXL pipeline.
21
 
@@ -28,49 +30,67 @@ def generate_image(
28
  width: Output image width in pixels
29
  tile_x: Enable seamless tiling on x-axis
30
  tile_y: Enable seamless tiling on y-axis
 
31
 
32
- Returns:
33
- Tuple of (PIL Image or None, status message)
 
 
 
34
  """
35
- if not global_pipe:
36
- return None, "⚠️ Please load a pipeline first."
 
 
 
 
37
 
38
  # For CPU mode, use float32 and warn about slow generation
39
  effective_dtype = dtype
40
  effective_device = device
41
-
42
  if is_running_on_spaces() and device == "cpu":
43
  print(" ℹ️ CPU mode: using float32 for stability (generation will be slower)")
 
 
44
  effective_dtype = torch.float32
45
  # Update pipeline to use float32
46
- global_pipe.unet.to(dtype=torch.float32)
47
- global_pipe.text_encoder.to(dtype=torch.float32)
48
- global_pipe.text_encoder_2.to(dtype=torch.float32)
49
- if global_pipe.vae:
50
- global_pipe.vae.to(dtype=torch.float32)
 
 
 
 
51
 
52
  # Enable seamless tiling on UNet & VAE decoder
53
- enable_seamless_tiling(global_pipe.unet, tile_x=tile_x, tile_y=tile_y)
54
- enable_seamless_tiling(global_pipe.vae.decoder, tile_x=tile_x, tile_y=tile_y)
55
-
56
- yield "🎨 Generating image...", f"Steps: 0/{steps} | CFG: {cfg}"
57
 
58
- generator = torch.Generator(device=effective_device).manual_seed(42) # Fixed seed for reproducibility
59
- result = global_pipe(
60
- prompt=prompt,
61
- negative_prompt=negative_prompt,
62
- width=int(width),
63
- height=int(height),
64
- num_inference_steps=int(steps),
65
- guidance_scale=float(cfg),
66
- generator=generator,
67
- )
68
 
69
- image = result.images[0]
70
- yield "🎨 Generating image...", f"✅ Complete! ({width}x{height})"
 
 
 
 
 
 
 
 
 
 
 
71
 
 
 
72
 
73
- def set_pipeline(pipe):
74
- """Set the global pipeline instance."""
75
- global global_pipe
76
- global_pipe = pipe
 
 
 
2
 
3
  import torch
4
 
5
+ from . import config
6
+ from .config import device, dtype, is_running_on_spaces
7
+ from .tiling import enable_seamless_tiling
8
 
9
 
10
  def generate_image(
 
16
  width: int,
17
  tile_x: bool = True,
18
  tile_y: bool = False,
19
+ seed: int | None = None,
20
+ ):
21
  """
22
  Generate an image using the loaded SDXL pipeline.
23
 
 
30
  width: Output image width in pixels
31
  tile_x: Enable seamless tiling on x-axis
32
  tile_y: Enable seamless tiling on y-axis
33
+ seed: Random seed for reproducibility (default: uses timestamp-based seed)
34
 
35
+ Yields:
36
+ Tuple of (intermediate_image_or_none, status_message)
37
+ - First yield: (None, progress_text) for initial progress update
38
+ - Final yield: (PIL Image, final_status) with the generated image,
39
+ or (None, error_message) if generation failed.
40
  """
41
+ # Fetch the pipeline at call time — avoids the stale import-by-value problem.
42
+ pipe = config.get_pipe()
43
+
44
+ if not pipe:
45
+ yield None, "⚠️ Please load a pipeline first."
46
+ return
47
 
48
  # For CPU mode, use float32 and warn about slow generation
49
  effective_dtype = dtype
50
  effective_device = device
51
+
52
  if is_running_on_spaces() and device == "cpu":
53
  print(" ℹ️ CPU mode: using float32 for stability (generation will be slower)")
54
+ # Store original dtype and temporarily use float32
55
+ original_dtype = effective_dtype
56
  effective_dtype = torch.float32
57
  # Update pipeline to use float32
58
+ pipe.unet.to(dtype=torch.float32)
59
+ if pipe.text_encoder is not None:
60
+ pipe.text_encoder.to(dtype=torch.float32)
61
+ if pipe.text_encoder_2 is not None:
62
+ pipe.text_encoder_2.to(dtype=torch.float32)
63
+ if pipe.vae is not None:
64
+ pipe.vae.to(dtype=torch.float32)
65
+ else:
66
+ original_dtype = None
67
 
68
  # Enable seamless tiling on UNet & VAE decoder
69
+ enable_seamless_tiling(pipe.unet, tile_x=tile_x, tile_y=tile_y)
70
+ enable_seamless_tiling(pipe.vae.decoder, tile_x=tile_x, tile_y=tile_y)
 
 
71
 
72
+ yield None, "🎨 Generating image..."
 
 
 
 
 
 
 
 
 
73
 
74
+ try:
75
+ # Use provided seed or generate a random one if None
76
+ actual_seed = seed if seed is not None else int(torch.randint(0, 2**63, (1,)).item())
77
+ generator = torch.Generator(device=effective_device).manual_seed(actual_seed)
78
+ result = pipe(
79
+ prompt=prompt,
80
+ negative_prompt=negative_prompt,
81
+ width=int(width),
82
+ height=int(height),
83
+ num_inference_steps=int(steps),
84
+ guidance_scale=float(cfg),
85
+ generator=generator,
86
+ )
87
 
88
+ image = result.images[0]
89
+ yield image, f"✅ Complete! ({int(width)}x{int(height)})"
90
 
91
+ except Exception as e:
92
+ import traceback
93
+ error_msg = f"❌ Generation failed: {str(e)}"
94
+ print(error_msg)
95
+ print(traceback.format_exc())
96
+ yield None, error_msg
src/pipeline.py CHANGED
@@ -2,55 +2,16 @@
2
 
3
  from pathlib import Path
4
 
5
- import torch
6
  from diffusers import (
7
  StableDiffusionXLPipeline,
8
  AutoencoderKL,
9
  DPMSolverSDEScheduler,
10
  )
11
 
12
- from .config import device, dtype, pipe as global_pipe, CACHE_DIR, device_description, is_running_on_spaces
13
-
14
-
15
- def _make_asymmetric_forward(module, pad_h: int, pad_w: int, tile_x: bool, tile_y: bool):
16
- """Create patched forward for seamless tiling on Conv2d layers."""
17
- original_forward = module._conv_forward
18
-
19
- def patched_conv_forward(input, weight, bias):
20
- if tile_x and tile_y:
21
- input = torch.nn.functional.pad(input, (pad_w, pad_w, pad_h, pad_h), mode="circular")
22
- elif tile_x:
23
- input = torch.nn.functional.pad(input, (pad_w, pad_w, 0, 0), mode="circular")
24
- input = torch.nn.functional.pad(input, (0, 0, pad_h, pad_h), mode="constant", value=0)
25
- elif tile_y:
26
- input = torch.nn.functional.pad(input, (0, 0, pad_h, pad_h), mode="circular")
27
- input = torch.nn.functional.pad(input, (pad_w, pad_w, 0, 0), mode="constant", value=0)
28
- else:
29
- return original_forward(input, weight, bias)
30
-
31
- return torch.nn.functional.conv2d(
32
- input, weight, bias, module.stride, (0, 0), module.dilation, module.groups
33
- )
34
-
35
- return patched_conv_forward
36
-
37
-
38
- def enable_seamless_tiling(model, tile_x: bool = True, tile_y: bool = False):
39
- """
40
- Enable seamless tiling on a model's Conv2d layers.
41
-
42
- Args:
43
- model: PyTorch model with Conv2d layers (e.g., pipe.unet, pipe.vae.decoder)
44
- tile_x: Enable tiling along x-axis
45
- tile_y: Enable tiling along y-axis
46
- """
47
- for module in model.modules():
48
- if isinstance(module, torch.nn.Conv2d):
49
- pad_h = module.padding[0]
50
- pad_w = module.padding[1]
51
- if pad_h == 0 and pad_w == 0:
52
- continue
53
- module._conv_forward = _make_asymmetric_forward(module, pad_h, pad_w, tile_x, tile_y)
54
 
55
 
56
  def load_pipeline(
@@ -70,17 +31,18 @@ def load_pipeline(
70
  lora_strengths_str: Comma-separated strength values for each LoRA
71
  progress: Optional gr.Progress() object for UI updates
72
 
 
 
 
73
  Returns:
74
- Tuple of (final_status_message, progress_text)
75
  """
76
- global global_pipe
 
77
 
78
  try:
79
  set_download_cancelled(False)
80
 
81
- # Import gr here to update button state if needed
82
- import gradio as gr
83
-
84
  print("=" * 60)
85
  print("🔄 Loading SDXL Pipeline...")
86
  print("=" * 60)
@@ -93,8 +55,7 @@ def load_pipeline(
93
 
94
  # Validate cache file before using it
95
  if checkpoint_cached:
96
- from src.config import validate_cache_file
97
- is_valid, msg = validate_cache_file(checkpoint_path)
98
  if not is_valid:
99
  print(f" ⚠️ Cache invalid: {msg}")
100
  checkpoint_path.unlink(missing_ok=True)
@@ -107,8 +68,7 @@ def load_pipeline(
107
 
108
  # Validate VAE cache file before using it
109
  if vae_cached:
110
- from src.config import validate_cache_file
111
- is_valid, msg = validate_cache_file(vae_path)
112
  if not is_valid:
113
  print(f" ⚠️ VAE Cache invalid: {msg}")
114
  vae_path.unlink(missing_ok=True)
@@ -124,7 +84,7 @@ def load_pipeline(
124
  else:
125
  status_msg = f"✅ Using cached {checkpoint_path.name}"
126
  print(f" ✅ Using cached: {checkpoint_path.name}")
127
-
128
  yield status_msg, "Starting download..."
129
 
130
  if not checkpoint_cached:
@@ -136,10 +96,10 @@ def load_pipeline(
136
  if vae_path:
137
  status_msg = f"📥 Downloading {vae_path.name}..." if not vae_cached else f"✅ Using cached {vae_path.name}"
138
  print(f" 📥 VAE: {vae_path.name}" if not vae_cached else f" ✅ VAE (cached): {vae_path.name}")
139
-
140
  if progress:
141
  progress(0.2, desc="Downloading VAE..." if not vae_cached else "Loading VAE...")
142
-
143
  yield status_msg, f"Downloading VAE: {vae_path.name}" if not vae_cached else f"Using cached VAE: {vae_path.name}"
144
 
145
  if not vae_cached:
@@ -158,7 +118,7 @@ def load_pipeline(
158
  # Load base pipeline (yield progress during this heavy operation)
159
  print(" ⚙️ Loading SDXL pipeline from single file...")
160
  yield "⚙️ Loading SDXL pipeline...", "Loading model weights into memory..."
161
-
162
  if progress:
163
  progress(0.3, desc="Loading text encoders...")
164
 
@@ -166,38 +126,35 @@ def load_pipeline(
166
  load_kwargs = {
167
  "torch_dtype": dtype,
168
  "use_safetensors": True,
169
- "safety_checker": None,
170
  }
171
 
172
  if is_running_on_spaces() and device == "cpu":
173
  print(" ℹ️ CPU mode detected: enabling device_map='auto' for better RAM management")
174
  load_kwargs["device_map"] = "auto"
175
- else:
176
- # For GPU, we'll move to device after loading
177
- load_kwargs["variant"] = "fp16" if dtype == torch.float16 else None
178
 
179
- global_pipe = StableDiffusionXLPipeline.from_single_file(
 
180
  str(checkpoint_path),
181
  **load_kwargs,
182
  )
183
  print(" ✅ Text encoders loaded")
184
-
185
  if progress:
186
  progress(0.5, desc="Loading UNet...")
187
 
188
  print(" ✅ UNet loaded")
189
-
190
  # Move to device (unless using device_map='auto' which handles this automatically)
191
  if not is_running_on_spaces() or device != "cpu":
192
  print(f" ⚙️ Moving pipeline to device: {device_description}...")
193
- global_pipe = global_pipe.to(device=device, dtype=dtype)
194
-
195
  yield "⚙️ Pipeline loaded, setting up components...", f"Using device: {device_description}"
196
 
197
  # Load VAE into pipeline if provided
198
  if vae is not None:
199
  print(" ⚙️ Setting custom VAE...")
200
- global_pipe.vae = vae.to(device=device, dtype=dtype)
201
  yield "⚙️ Pipeline loaded, setting up components...", f"VAE loaded: {vae_path.name}"
202
 
203
  # Parse LoRA URLs & ensure strengths list matches
@@ -214,35 +171,34 @@ def load_pipeline(
214
  # Load and fuse each LoRA sequentially (only if URLs exist)
215
  if lora_urls:
216
  print(f" ⚙️ Moving pipeline to device: {device_description}...")
217
- global_pipe = global_pipe.to(device=device, dtype=dtype)
218
 
219
  for i, (lora_url, strength) in enumerate(zip(lora_urls, strengths)):
220
  lora_filename = get_safe_filename_from_url(lora_url, suffix="_lora")
221
  lora_path = CACHE_DIR / lora_filename
222
  lora_cached = lora_path.exists() and lora_path.stat().st_size > 0
223
-
224
  # Validate LoRA cache file before using it
225
  if lora_cached:
226
- from src.config import validate_cache_file
227
- is_valid, msg = validate_cache_file(lora_path)
228
  if not is_valid:
229
  print(f" ⚠️ LoRA Cache invalid: {msg}")
230
  lora_path.unlink(missing_ok=True)
231
  lora_cached = False
232
-
233
  if not lora_cached:
234
  print(f" 📥 LoRA {i+1}/{len(lora_urls)}: Downloading {lora_path.name}...")
235
  status_msg = f"📥 Downloading LoRA {i+1}/{len(lora_urls)}: {lora_path.name}..."
236
  else:
237
  print(f" ✅ LoRA {i+1}/{len(lora_urls)}: Using cached {lora_path.name}")
238
  status_msg = f"✅ Using cached LoRA {i+1}/{len(lora_urls)}: {lora_path.name}"
239
-
240
  yield (
241
  status_msg,
242
  f"Downloading LoRA {i+1}/{len(lora_urls)} ({lora_path.name})..." if not lora_cached
243
  else f"Using cached LoRA {i+1}/{len(lora_urls)} ({lora_path.name})"
244
  )
245
-
246
  if not lora_cached:
247
  download_file_with_progress(lora_url, lora_path)
248
 
@@ -250,16 +206,16 @@ def load_pipeline(
250
  yield f"⚙️ Loading LoRA {i+1}/{len(lora_urls)}...", f"Fusing {lora_path.name}..."
251
  if progress:
252
  progress(0.7 + (0.2 * i / len(lora_urls)), desc=f"Loading LoRA {i+1}/{len(lora_urls)}...")
253
-
254
  adapter_name = f"lora_{i}"
255
- global_pipe.load_lora_weights(str(lora_path), adapter_name=adapter_name)
256
  print(f" ⚙️ Fusing LoRA {i+1} with strength={strength}...")
257
- global_pipe.fuse_lora(adapter_names=[adapter_name], lora_scale=strength)
258
- global_pipe.unload_lora_weights()
259
  else:
260
  # Move pipeline to device even without LoRAs
261
  print(f" ⚙️ Moving pipeline to device: {device_description}...")
262
- global_pipe = global_pipe.to(device=device, dtype=dtype)
263
 
264
  # Set scheduler and finalize (do this once at the end)
265
  print(" ⚙️ Configuring scheduler...")
@@ -268,22 +224,26 @@ def load_pipeline(
268
  if progress:
269
  progress(0.95, desc="Finalizing...")
270
 
271
- global_pipe.scheduler = DPMSolverSDEScheduler.from_config(
272
- global_pipe.scheduler.config,
273
  algorithm_type="sde-dpmsolver++",
274
  use_karras_sigmas=False,
275
  )
276
 
 
 
 
277
  print(" ✅ Pipeline ready!")
278
  yield "✅ Pipeline ready!", f"Ready! Loaded {len(lora_urls)} LoRA(s)"
279
- return ("✅ Pipeline loaded successfully!", f"Ready! Loaded {len(lora_urls)} LoRA(s)")
280
 
281
  except KeyboardInterrupt:
282
  set_download_cancelled(False)
 
283
  print("\n⚠️ Download cancelled by user")
284
  return ("⚠️ Download cancelled by user", "Cancelled")
285
  except Exception as e:
286
  import traceback
 
287
  error_msg = f"❌ Error loading pipeline: {str(e)}"
288
  print(f"\n{error_msg}")
289
  print(traceback.format_exc())
@@ -297,4 +257,4 @@ def cancel_download():
297
 
298
  def get_pipeline() -> StableDiffusionXLPipeline | None:
299
  """Get the currently loaded pipeline."""
300
- return global_pipe
 
2
 
3
  from pathlib import Path
4
 
 
5
  from diffusers import (
6
  StableDiffusionXLPipeline,
7
  AutoencoderKL,
8
  DPMSolverSDEScheduler,
9
  )
10
 
11
+ from . import config
12
+ from .config import device, dtype, CACHE_DIR, device_description, is_running_on_spaces, set_download_cancelled
13
+ from .downloader import get_safe_filename_from_url, download_file_with_progress
14
+ from .tiling import enable_seamless_tiling
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
 
17
  def load_pipeline(
 
31
  lora_strengths_str: Comma-separated strength values for each LoRA
32
  progress: Optional gr.Progress() object for UI updates
33
 
34
+ Yields:
35
+ Tuple of (status_message, progress_text) at each loading stage.
36
+
37
  Returns:
38
+ Final yielded tuple of (final_status_message, progress_text)
39
  """
40
+ # Clear any previously loaded pipeline so the UI reflects loading state
41
+ config.set_pipe(None)
42
 
43
  try:
44
  set_download_cancelled(False)
45
 
 
 
 
46
  print("=" * 60)
47
  print("🔄 Loading SDXL Pipeline...")
48
  print("=" * 60)
 
55
 
56
  # Validate cache file before using it
57
  if checkpoint_cached:
58
+ is_valid, msg = config.validate_cache_file(checkpoint_path)
 
59
  if not is_valid:
60
  print(f" ⚠️ Cache invalid: {msg}")
61
  checkpoint_path.unlink(missing_ok=True)
 
68
 
69
  # Validate VAE cache file before using it
70
  if vae_cached:
71
+ is_valid, msg = config.validate_cache_file(vae_path)
 
72
  if not is_valid:
73
  print(f" ⚠️ VAE Cache invalid: {msg}")
74
  vae_path.unlink(missing_ok=True)
 
84
  else:
85
  status_msg = f"✅ Using cached {checkpoint_path.name}"
86
  print(f" ✅ Using cached: {checkpoint_path.name}")
87
+
88
  yield status_msg, "Starting download..."
89
 
90
  if not checkpoint_cached:
 
96
  if vae_path:
97
  status_msg = f"📥 Downloading {vae_path.name}..." if not vae_cached else f"✅ Using cached {vae_path.name}"
98
  print(f" 📥 VAE: {vae_path.name}" if not vae_cached else f" ✅ VAE (cached): {vae_path.name}")
99
+
100
  if progress:
101
  progress(0.2, desc="Downloading VAE..." if not vae_cached else "Loading VAE...")
102
+
103
  yield status_msg, f"Downloading VAE: {vae_path.name}" if not vae_cached else f"Using cached VAE: {vae_path.name}"
104
 
105
  if not vae_cached:
 
118
  # Load base pipeline (yield progress during this heavy operation)
119
  print(" ⚙️ Loading SDXL pipeline from single file...")
120
  yield "⚙️ Loading SDXL pipeline...", "Loading model weights into memory..."
121
+
122
  if progress:
123
  progress(0.3, desc="Loading text encoders...")
124
 
 
126
  load_kwargs = {
127
  "torch_dtype": dtype,
128
  "use_safetensors": True,
 
129
  }
130
 
131
  if is_running_on_spaces() and device == "cpu":
132
  print(" ℹ️ CPU mode detected: enabling device_map='auto' for better RAM management")
133
  load_kwargs["device_map"] = "auto"
 
 
 
134
 
135
+ # Use a local variable for the pipeline being built — only stored globally on success.
136
+ _pipe = StableDiffusionXLPipeline.from_single_file(
137
  str(checkpoint_path),
138
  **load_kwargs,
139
  )
140
  print(" ✅ Text encoders loaded")
141
+
142
  if progress:
143
  progress(0.5, desc="Loading UNet...")
144
 
145
  print(" ✅ UNet loaded")
146
+
147
  # Move to device (unless using device_map='auto' which handles this automatically)
148
  if not is_running_on_spaces() or device != "cpu":
149
  print(f" ⚙️ Moving pipeline to device: {device_description}...")
150
+ _pipe = _pipe.to(device=device, dtype=dtype)
151
+
152
  yield "⚙️ Pipeline loaded, setting up components...", f"Using device: {device_description}"
153
 
154
  # Load VAE into pipeline if provided
155
  if vae is not None:
156
  print(" ⚙️ Setting custom VAE...")
157
+ _pipe.vae = vae.to(device=device, dtype=dtype)
158
  yield "⚙️ Pipeline loaded, setting up components...", f"VAE loaded: {vae_path.name}"
159
 
160
  # Parse LoRA URLs & ensure strengths list matches
 
171
  # Load and fuse each LoRA sequentially (only if URLs exist)
172
  if lora_urls:
173
  print(f" ⚙️ Moving pipeline to device: {device_description}...")
174
+ _pipe = _pipe.to(device=device, dtype=dtype)
175
 
176
  for i, (lora_url, strength) in enumerate(zip(lora_urls, strengths)):
177
  lora_filename = get_safe_filename_from_url(lora_url, suffix="_lora")
178
  lora_path = CACHE_DIR / lora_filename
179
  lora_cached = lora_path.exists() and lora_path.stat().st_size > 0
180
+
181
  # Validate LoRA cache file before using it
182
  if lora_cached:
183
+ is_valid, msg = config.validate_cache_file(lora_path)
 
184
  if not is_valid:
185
  print(f" ⚠️ LoRA Cache invalid: {msg}")
186
  lora_path.unlink(missing_ok=True)
187
  lora_cached = False
188
+
189
  if not lora_cached:
190
  print(f" 📥 LoRA {i+1}/{len(lora_urls)}: Downloading {lora_path.name}...")
191
  status_msg = f"📥 Downloading LoRA {i+1}/{len(lora_urls)}: {lora_path.name}..."
192
  else:
193
  print(f" ✅ LoRA {i+1}/{len(lora_urls)}: Using cached {lora_path.name}")
194
  status_msg = f"✅ Using cached LoRA {i+1}/{len(lora_urls)}: {lora_path.name}"
195
+
196
  yield (
197
  status_msg,
198
  f"Downloading LoRA {i+1}/{len(lora_urls)} ({lora_path.name})..." if not lora_cached
199
  else f"Using cached LoRA {i+1}/{len(lora_urls)} ({lora_path.name})"
200
  )
201
+
202
  if not lora_cached:
203
  download_file_with_progress(lora_url, lora_path)
204
 
 
206
  yield f"⚙️ Loading LoRA {i+1}/{len(lora_urls)}...", f"Fusing {lora_path.name}..."
207
  if progress:
208
  progress(0.7 + (0.2 * i / len(lora_urls)), desc=f"Loading LoRA {i+1}/{len(lora_urls)}...")
209
+
210
  adapter_name = f"lora_{i}"
211
+ _pipe.load_lora_weights(str(lora_path), adapter_name=adapter_name)
212
  print(f" ⚙️ Fusing LoRA {i+1} with strength={strength}...")
213
+ _pipe.fuse_lora(adapter_names=[adapter_name], lora_scale=strength)
214
+ _pipe.unload_lora_weights()
215
  else:
216
  # Move pipeline to device even without LoRAs
217
  print(f" ⚙️ Moving pipeline to device: {device_description}...")
218
+ _pipe = _pipe.to(device=device, dtype=dtype)
219
 
220
  # Set scheduler and finalize (do this once at the end)
221
  print(" ⚙️ Configuring scheduler...")
 
224
  if progress:
225
  progress(0.95, desc="Finalizing...")
226
 
227
+ _pipe.scheduler = DPMSolverSDEScheduler.from_config(
228
+ _pipe.scheduler.config,
229
  algorithm_type="sde-dpmsolver++",
230
  use_karras_sigmas=False,
231
  )
232
 
233
+ # ✅ Only publish the pipeline globally AFTER all steps succeed
234
+ config.set_pipe(_pipe)
235
+
236
  print(" ✅ Pipeline ready!")
237
  yield "✅ Pipeline ready!", f"Ready! Loaded {len(lora_urls)} LoRA(s)"
 
238
 
239
  except KeyboardInterrupt:
240
  set_download_cancelled(False)
241
+ config.set_pipe(None)
242
  print("\n⚠️ Download cancelled by user")
243
  return ("⚠️ Download cancelled by user", "Cancelled")
244
  except Exception as e:
245
  import traceback
246
+ config.set_pipe(None)
247
  error_msg = f"❌ Error loading pipeline: {str(e)}"
248
  print(f"\n{error_msg}")
249
  print(traceback.format_exc())
 
257
 
258
def get_pipeline() -> StableDiffusionXLPipeline | None:
    """Return the currently loaded SDXL pipeline, or ``None`` if none is loaded.

    Thin accessor that delegates to ``config.get_pipe()`` so callers read the
    shared pipeline state through one place instead of touching module globals.
    """
    return config.get_pipe()
src/tiling.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Seamless tiling utilities for SDXL Model Merger."""
2
+
3
+ import torch
4
+
5
+
6
def _make_asymmetric_forward(module, pad_h: int, pad_w: int, tile_x: bool, tile_y: bool):
    """Create patched forward for seamless tiling on Conv2d layers."""
    # Captured at patch time; used verbatim when no tiling axis is requested.
    fallback_forward = module._conv_forward

    def patched_conv_forward(input, weight, bias):
        pad_fn = torch.nn.functional.pad
        if tile_x and tile_y:
            # Wrap both axes circularly so the output tiles in every direction.
            padded = pad_fn(input, (pad_w, pad_w, pad_h, pad_h), mode="circular")
        elif tile_x:
            # Asymmetric padding for 360° panorama tiling:
            # wrap left/right circularly, zero-fill top/bottom.
            padded = pad_fn(input, (pad_w, pad_w, 0, 0), mode="circular")
            padded = pad_fn(padded, (0, 0, pad_h, pad_h), mode="constant", value=0)
        elif tile_y:
            # Wrap top/bottom circularly, zero-fill left/right.
            padded = pad_fn(input, (0, 0, pad_h, pad_h), mode="circular")
            padded = pad_fn(padded, (pad_w, pad_w, 0, 0), mode="constant", value=0)
        else:
            # Neither axis tiled: defer to the forward captured above.
            return fallback_forward(input, weight, bias)

        # Padding was applied explicitly, so the convolution runs with padding=0.
        return torch.nn.functional.conv2d(
            padded, weight, bias, module.stride, (0, 0), module.dilation, module.groups
        )

    return patched_conv_forward
31
+
32
+
33
def enable_seamless_tiling(model, tile_x: bool = True, tile_y: bool = False):
    """
    Enable seamless tiling on a model's Conv2d layers.

    Patches each padded Conv2d's ``_conv_forward`` so convolution inputs are
    padded circularly along the requested axes, making outputs tile seamlessly.
    Safe to call repeatedly with different tile settings.

    Args:
        model: PyTorch model with Conv2d layers (e.g., pipe.unet, pipe.vae.decoder)
        tile_x: Enable tiling along x-axis
        tile_y: Enable tiling along y-axis
    """
    for module in model.modules():
        if not isinstance(module, torch.nn.Conv2d):
            continue
        pad_h = module.padding[0]
        pad_w = module.padding[1]
        if pad_h == 0 and pad_w == 0:
            continue  # unpadded conv: nothing to make circular
        if getattr(module, "_tiling_config", None) == (tile_x, tile_y):
            continue  # already patched with same config
        # Bug fix: when re-patching with a *different* config, the wrapper factory
        # must capture the pristine forward, not the previously patched one —
        # otherwise wrappers chain and padding is applied twice per call.
        if hasattr(module, "_original_conv_forward"):
            module._conv_forward = module._original_conv_forward
        else:
            module._original_conv_forward = module._conv_forward
        module._tiling_config = (tile_x, tile_y)
        module._conv_forward = _make_asymmetric_forward(module, pad_h, pad_w, tile_x, tile_y)
src/ui/generator_tab.py CHANGED
@@ -8,7 +8,7 @@ from ..generator import generate_image
8
 
9
  def create_generator_tab():
10
  """Create the image generation tab with all input controls."""
11
-
12
  with gr.Accordion("🎨 2. Generate Image", open=True, elem_classes=["feature-card"]):
13
  # Prompts section
14
  with gr.Row():
@@ -19,19 +19,19 @@ def create_generator_tab():
19
  lines=3,
20
  placeholder="Describe the image you want to generate..."
21
  )
22
-
23
  cfg = gr.Slider(
24
  minimum=1.0, maximum=20.0, value=7.5, step=0.5,
25
  label="CFG Scale",
26
  info="Higher values make outputs match prompt more strictly"
27
  )
28
-
29
  height = gr.Number(
30
  value=1024, precision=0,
31
  label="Height (pixels)",
32
  info="Output image height"
33
  )
34
-
35
  with gr.Column(scale=1):
36
  negative_prompt = gr.Textbox(
37
  label="Negative Prompt",
@@ -39,28 +39,28 @@ def create_generator_tab():
39
  lines=3,
40
  placeholder="Elements to avoid in generation..."
41
  )
42
-
43
  steps = gr.Slider(
44
  minimum=1, maximum=100, value=25, step=1,
45
  label="Inference Steps",
46
  info="More steps = better quality but slower"
47
  )
48
-
49
  width = gr.Number(
50
  value=2048, precision=0,
51
  label="Width (pixels)",
52
  info="Output image width"
53
  )
54
-
55
  # Tiling options
56
  with gr.Row():
57
  tile_x = gr.Checkbox(True, label="X-axis Seamless Tiling")
58
  tile_y = gr.Checkbox(False, label="Y-axis Seamless Tiling")
59
-
60
  # Generate button and outputs
61
  with gr.Row():
62
  gen_btn = gr.Button("✨ Generate Image", variant="secondary", size="lg")
63
-
64
  with gr.Row():
65
  image_output = gr.Image(
66
  label="Result",
@@ -68,12 +68,18 @@ def create_generator_tab():
68
  show_label=True
69
  )
70
  with gr.Column():
71
- gen_status = gr.Textbox(
72
  label="Generation Status",
73
- placeholder="Ready to generate..."
74
  )
75
-
76
- # Quick tips
 
 
 
 
 
 
77
  gr.HTML("""
78
  <div style="margin-top: 16px; padding: 12px; background: #f3f4f6; border-radius: 8px;">
79
  <strong>💡 Tips:</strong>
@@ -84,21 +90,52 @@ def create_generator_tab():
84
  </ul>
85
  </div>
86
  """)
87
-
88
  return (
89
  prompt, negative_prompt, cfg, steps, height, width,
90
- tile_x, tile_y, gen_btn, image_output, gen_status
91
  )
92
 
93
 
94
  def setup_generator_events(
95
  prompt, negative_prompt, cfg, steps, height, width,
96
- tile_x, tile_y, gen_btn, image_output, gen_status
97
  ):
98
  """Setup event handlers for the generator tab."""
99
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  gen_btn.click(
 
 
 
 
101
  fn=generate_image,
102
  inputs=[prompt, negative_prompt, cfg, steps, height, width, tile_x, tile_y],
103
- outputs=[image_output, gen_status],
 
 
 
 
104
  )
 
8
 
9
  def create_generator_tab():
10
  """Create the image generation tab with all input controls."""
11
+
12
  with gr.Accordion("🎨 2. Generate Image", open=True, elem_classes=["feature-card"]):
13
  # Prompts section
14
  with gr.Row():
 
19
  lines=3,
20
  placeholder="Describe the image you want to generate..."
21
  )
22
+
23
  cfg = gr.Slider(
24
  minimum=1.0, maximum=20.0, value=7.5, step=0.5,
25
  label="CFG Scale",
26
  info="Higher values make outputs match prompt more strictly"
27
  )
28
+
29
  height = gr.Number(
30
  value=1024, precision=0,
31
  label="Height (pixels)",
32
  info="Output image height"
33
  )
34
+
35
  with gr.Column(scale=1):
36
  negative_prompt = gr.Textbox(
37
  label="Negative Prompt",
 
39
  lines=3,
40
  placeholder="Elements to avoid in generation..."
41
  )
42
+
43
  steps = gr.Slider(
44
  minimum=1, maximum=100, value=25, step=1,
45
  label="Inference Steps",
46
  info="More steps = better quality but slower"
47
  )
48
+
49
  width = gr.Number(
50
  value=2048, precision=0,
51
  label="Width (pixels)",
52
  info="Output image width"
53
  )
54
+
55
  # Tiling options
56
  with gr.Row():
57
  tile_x = gr.Checkbox(True, label="X-axis Seamless Tiling")
58
  tile_y = gr.Checkbox(False, label="Y-axis Seamless Tiling")
59
+
60
  # Generate button and outputs
61
  with gr.Row():
62
  gen_btn = gr.Button("✨ Generate Image", variant="secondary", size="lg")
63
+
64
  with gr.Row():
65
  image_output = gr.Image(
66
  label="Result",
 
68
  show_label=True
69
  )
70
  with gr.Column():
71
+ gen_status = gr.HTML(
72
  label="Generation Status",
73
+ value='<div class="status-success">✅ Ready to generate</div>',
74
  )
75
+
76
+ gen_progress = gr.Textbox(
77
+ label="Generation Progress",
78
+ placeholder="Ready to generate...",
79
+ show_label=True,
80
+ visible=False
81
+ )
82
+
83
  gr.HTML("""
84
  <div style="margin-top: 16px; padding: 12px; background: #f3f4f6; border-radius: 8px;">
85
  <strong>💡 Tips:</strong>
 
90
  </ul>
91
  </div>
92
  """)
93
+
94
  return (
95
  prompt, negative_prompt, cfg, steps, height, width,
96
+ tile_x, tile_y, gen_btn, image_output, gen_status, gen_progress
97
  )
98
 
99
 
100
def setup_generator_events(
    prompt, negative_prompt, cfg, steps, height, width,
    tile_x, tile_y, gen_btn, image_output, gen_status, gen_progress
):
    """Setup event handlers for the generator tab."""

    def _start_generation():
        # Show busy status, surface the progress box text, and lock the button.
        return (
            '<div class="status-warning">⏳ Generating image...</div>',
            "Starting generation...",
            gr.update(interactive=False),
        )

    def _finish_generation(img, msg):
        """Read both image and last progress message to determine success/failure."""
        if img is not None:
            return (
                '<div class="status-success">✅ Generation complete!</div>',
                "Done",
                gr.update(interactive=True),
                gr.update(value=img),
            )
        # No image produced: treat the last progress message as the error text.
        return (
            f'<div class="status-error">{msg}</div>',
            "",
            gr.update(interactive=True),
            gr.update(),
        )

    # Three-stage chain: flip UI into "busy" state, run generation, then
    # inspect the results to restore the UI and report success or failure.
    busy_event = gen_btn.click(
        fn=_start_generation,
        inputs=[],
        outputs=[gen_status, gen_progress, gen_btn],
    )
    generate_event = busy_event.then(
        fn=generate_image,
        inputs=[prompt, negative_prompt, cfg, steps, height, width, tile_x, tile_y],
        outputs=[image_output, gen_progress],
    )
    generate_event.then(
        fn=_finish_generation,
        inputs=[image_output, gen_progress],
        outputs=[gen_status, gen_progress, gen_btn, image_output],
    )