Spaces:
Sleeping
Sleeping
Kyle Pearson commited on
Commit ·
570384a
1
Parent(s): 459ac47
cleaned up code
Browse files- .env.example +12 -0
- app.py +18 -33
- manifest.json +1 -0
- src/config.py +33 -1
- src/downloader.py +2 -8
- src/exporter.py +39 -26
- src/generator.py +53 -33
- src/pipeline.py +42 -82
- src/tiling.py +52 -0
- src/ui/generator_tab.py +55 -18
.env.example
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SDXL Model Merger - Environment Configuration Example
|
| 2 |
+
# Copy this to .env and customize for your deployment
|
| 3 |
+
|
| 4 |
+
# Deployment Environment (auto-detected: local, spaces)
|
| 5 |
+
DEPLOYMENT_ENV=local
|
| 6 |
+
|
| 7 |
+
# Default Model URLs - Use HF models for Spaces compatibility
|
| 8 |
+
DEFAULT_CHECKPOINT_URL=https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0.safetensors?download=true
|
| 9 |
+
DEFAULT_VAE_URL=https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/resolve/main/sdxl.vae.safetensors?download=true
|
| 10 |
+
|
| 11 |
+
# Default LoRA - using HF instead of CivitAI (may not be accessible on Spaces)
|
| 12 |
+
DEFAULT_LORA_URLS=https://huggingface.co/nerijs/pixel-art-xl/resolve/main/pixel-art-xl.safetensors?download=true
|
app.py
CHANGED
|
@@ -78,7 +78,7 @@ def create_app():
|
|
| 78 |
}
|
| 79 |
"""
|
| 80 |
|
| 81 |
-
from src.pipeline import load_pipeline
|
| 82 |
from src.generator import generate_image
|
| 83 |
from src.exporter import export_merged_model
|
| 84 |
from src.config import get_cached_models, get_cached_checkpoints, get_cached_vaes, get_cached_loras
|
|
@@ -227,7 +227,7 @@ def create_app():
|
|
| 227 |
)
|
| 228 |
|
| 229 |
cfg = gr.Slider(
|
| 230 |
-
minimum=1.0, maximum=20.0, value=
|
| 231 |
label="CFG Scale",
|
| 232 |
info="Higher values make outputs match prompt more strictly"
|
| 233 |
)
|
|
@@ -247,7 +247,7 @@ def create_app():
|
|
| 247 |
)
|
| 248 |
|
| 249 |
steps = gr.Slider(
|
| 250 |
-
minimum=1, maximum=100, value=
|
| 251 |
label="Inference Steps",
|
| 252 |
info="More steps = better quality but slower"
|
| 253 |
)
|
|
@@ -262,6 +262,13 @@ def create_app():
|
|
| 262 |
tile_x = gr.Checkbox(True, label="X-axis Seamless Tiling")
|
| 263 |
tile_y = gr.Checkbox(False, label="Y-axis Seamless Tiling")
|
| 264 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 265 |
with gr.Row():
|
| 266 |
gen_btn = gr.Button("✨ Generate Image", variant="secondary", size="lg")
|
| 267 |
|
|
@@ -361,8 +368,7 @@ def create_app():
|
|
| 361 |
return (
|
| 362 |
'<div class="status-warning">⏳ Loading started...</div>',
|
| 363 |
"Starting download...",
|
| 364 |
-
gr.update(interactive=False)
|
| 365 |
-
gr.update(visible=True)
|
| 366 |
)
|
| 367 |
|
| 368 |
def on_load_pipeline_complete(status_msg, progress_text):
|
|
@@ -371,44 +377,34 @@ def create_app():
|
|
| 371 |
return (
|
| 372 |
'<div class="status-success">✅ Pipeline loaded successfully!</div>',
|
| 373 |
progress_text,
|
| 374 |
-
gr.update(interactive=True)
|
| 375 |
-
gr.update(visible=False)
|
| 376 |
)
|
| 377 |
elif "⚠️" in status_msg or "cancelled" in status_msg.lower():
|
| 378 |
return (
|
| 379 |
'<div class="status-warning">⚠️ Download cancelled</div>',
|
| 380 |
progress_text,
|
| 381 |
-
gr.update(interactive=True)
|
| 382 |
-
gr.update(visible=False)
|
| 383 |
)
|
| 384 |
else:
|
| 385 |
return (
|
| 386 |
f'<div class="status-error">{status_msg}</div>',
|
| 387 |
progress_text,
|
| 388 |
-
gr.update(interactive=True)
|
| 389 |
-
gr.update(visible=False)
|
| 390 |
)
|
| 391 |
|
| 392 |
-
# Cancel button for pipeline loading
|
| 393 |
-
cancel_load_btn = gr.Button("🛑 Cancel Loading", variant="secondary", size="sm", visible=False)
|
| 394 |
-
|
| 395 |
load_btn.click(
|
| 396 |
fn=on_load_pipeline_start,
|
| 397 |
inputs=[],
|
| 398 |
-
outputs=[load_status, load_progress, load_btn
|
| 399 |
).then(
|
| 400 |
fn=load_pipeline,
|
| 401 |
inputs=[checkpoint_url, vae_url, lora_urls, lora_strengths],
|
| 402 |
outputs=[load_status, load_progress],
|
| 403 |
show_progress="full",
|
| 404 |
-
).then(
|
| 405 |
-
fn=lambda: (gr.update(interactive=True), gr.update(visible=False)),
|
| 406 |
-
inputs=[],
|
| 407 |
-
outputs=[cancel_load_btn, cancel_load_btn], # Just to hide it
|
| 408 |
).then(
|
| 409 |
fn=on_load_pipeline_complete,
|
| 410 |
inputs=[load_status, load_progress],
|
| 411 |
-
outputs=[load_status, load_progress, load_btn
|
| 412 |
).then(
|
| 413 |
fn=lambda: (
|
| 414 |
gr.update(choices=["(None found)"] + get_cached_checkpoints()),
|
|
@@ -419,17 +415,6 @@ def create_app():
|
|
| 419 |
outputs=[cached_checkpoints, cached_vaes, cached_loras],
|
| 420 |
)
|
| 421 |
|
| 422 |
-
# Cancel button handler
|
| 423 |
-
cancel_load_btn.click(
|
| 424 |
-
fn=lambda: (cancel_download(),
|
| 425 |
-
'<div class="status-warning">⏳ Cancelling...</div>',
|
| 426 |
-
"Cancelling download...",
|
| 427 |
-
gr.update(visible=False),
|
| 428 |
-
gr.update(interactive=True)),
|
| 429 |
-
inputs=[],
|
| 430 |
-
outputs=[load_status, load_progress, cancel_load_btn, load_btn],
|
| 431 |
-
)
|
| 432 |
-
|
| 433 |
def on_cached_checkpoint_change(cached_path):
|
| 434 |
"""Update URL when a cached checkpoint is selected."""
|
| 435 |
if cached_path and cached_path != "(None found)":
|
|
@@ -502,7 +487,7 @@ def create_app():
|
|
| 502 |
outputs=[gen_status, gen_progress, gen_btn],
|
| 503 |
).then(
|
| 504 |
fn=generate_image,
|
| 505 |
-
inputs=[prompt, negative_prompt, cfg, steps, height, width, tile_x, tile_y],
|
| 506 |
outputs=[image_output, gen_progress],
|
| 507 |
).then(
|
| 508 |
fn=lambda img, msg: on_generate_complete(msg, "Done", img),
|
|
@@ -543,7 +528,7 @@ def create_app():
|
|
| 543 |
fn=lambda inc, q, qt, fmt: export_merged_model(
|
| 544 |
include_lora=inc,
|
| 545 |
quantize=q and (qt != "none"),
|
| 546 |
-
qtype=qt
|
| 547 |
save_format=fmt,
|
| 548 |
),
|
| 549 |
inputs=[include_lora, quantize_toggle, qtype_dropdown, format_dropdown],
|
|
|
|
| 78 |
}
|
| 79 |
"""
|
| 80 |
|
| 81 |
+
from src.pipeline import load_pipeline
|
| 82 |
from src.generator import generate_image
|
| 83 |
from src.exporter import export_merged_model
|
| 84 |
from src.config import get_cached_models, get_cached_checkpoints, get_cached_vaes, get_cached_loras
|
|
|
|
| 227 |
)
|
| 228 |
|
| 229 |
cfg = gr.Slider(
|
| 230 |
+
minimum=1.0, maximum=20.0, value=3.0, step=0.5,
|
| 231 |
label="CFG Scale",
|
| 232 |
info="Higher values make outputs match prompt more strictly"
|
| 233 |
)
|
|
|
|
| 247 |
)
|
| 248 |
|
| 249 |
steps = gr.Slider(
|
| 250 |
+
minimum=1, maximum=100, value=8, step=1,
|
| 251 |
label="Inference Steps",
|
| 252 |
info="More steps = better quality but slower"
|
| 253 |
)
|
|
|
|
| 262 |
tile_x = gr.Checkbox(True, label="X-axis Seamless Tiling")
|
| 263 |
tile_y = gr.Checkbox(False, label="Y-axis Seamless Tiling")
|
| 264 |
|
| 265 |
+
seed = gr.Number(
|
| 266 |
+
value=80484030936239,
|
| 267 |
+
precision=0,
|
| 268 |
+
label="Seed",
|
| 269 |
+
info="Random seed for reproducible generation"
|
| 270 |
+
)
|
| 271 |
+
|
| 272 |
with gr.Row():
|
| 273 |
gen_btn = gr.Button("✨ Generate Image", variant="secondary", size="lg")
|
| 274 |
|
|
|
|
| 368 |
return (
|
| 369 |
'<div class="status-warning">⏳ Loading started...</div>',
|
| 370 |
"Starting download...",
|
| 371 |
+
gr.update(interactive=False)
|
|
|
|
| 372 |
)
|
| 373 |
|
| 374 |
def on_load_pipeline_complete(status_msg, progress_text):
|
|
|
|
| 377 |
return (
|
| 378 |
'<div class="status-success">✅ Pipeline loaded successfully!</div>',
|
| 379 |
progress_text,
|
| 380 |
+
gr.update(interactive=True)
|
|
|
|
| 381 |
)
|
| 382 |
elif "⚠️" in status_msg or "cancelled" in status_msg.lower():
|
| 383 |
return (
|
| 384 |
'<div class="status-warning">⚠️ Download cancelled</div>',
|
| 385 |
progress_text,
|
| 386 |
+
gr.update(interactive=True)
|
|
|
|
| 387 |
)
|
| 388 |
else:
|
| 389 |
return (
|
| 390 |
f'<div class="status-error">{status_msg}</div>',
|
| 391 |
progress_text,
|
| 392 |
+
gr.update(interactive=True)
|
|
|
|
| 393 |
)
|
| 394 |
|
|
|
|
|
|
|
|
|
|
| 395 |
load_btn.click(
|
| 396 |
fn=on_load_pipeline_start,
|
| 397 |
inputs=[],
|
| 398 |
+
outputs=[load_status, load_progress, load_btn],
|
| 399 |
).then(
|
| 400 |
fn=load_pipeline,
|
| 401 |
inputs=[checkpoint_url, vae_url, lora_urls, lora_strengths],
|
| 402 |
outputs=[load_status, load_progress],
|
| 403 |
show_progress="full",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 404 |
).then(
|
| 405 |
fn=on_load_pipeline_complete,
|
| 406 |
inputs=[load_status, load_progress],
|
| 407 |
+
outputs=[load_status, load_progress, load_btn],
|
| 408 |
).then(
|
| 409 |
fn=lambda: (
|
| 410 |
gr.update(choices=["(None found)"] + get_cached_checkpoints()),
|
|
|
|
| 415 |
outputs=[cached_checkpoints, cached_vaes, cached_loras],
|
| 416 |
)
|
| 417 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 418 |
def on_cached_checkpoint_change(cached_path):
|
| 419 |
"""Update URL when a cached checkpoint is selected."""
|
| 420 |
if cached_path and cached_path != "(None found)":
|
|
|
|
| 487 |
outputs=[gen_status, gen_progress, gen_btn],
|
| 488 |
).then(
|
| 489 |
fn=generate_image,
|
| 490 |
+
inputs=[prompt, negative_prompt, cfg, steps, height, width, tile_x, tile_y, seed],
|
| 491 |
outputs=[image_output, gen_progress],
|
| 492 |
).then(
|
| 493 |
fn=lambda img, msg: on_generate_complete(msg, "Done", img),
|
|
|
|
| 528 |
fn=lambda inc, q, qt, fmt: export_merged_model(
|
| 529 |
include_lora=inc,
|
| 530 |
quantize=q and (qt != "none"),
|
| 531 |
+
qtype=qt, # always pass the string value; exporter handles "none" correctly
|
| 532 |
save_format=fmt,
|
| 533 |
),
|
| 534 |
inputs=[include_lora, quantize_toggle, qtype_dropdown, format_dropdown],
|
manifest.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"name": "SDXL Model Merger", "short_name": "SDXL Merger", "description": "Merge SDXL checkpoints, LoRAs, and VAEs with optional quantization", "start_url": "/", "display": "standalone", "background_color": "#ffffff", "theme_color": "#10b981"}
|
src/config.py
CHANGED
|
@@ -114,11 +114,43 @@ print(f"🚀 Using device: {device_description}")
|
|
| 114 |
check_memory_requirements()
|
| 115 |
|
| 116 |
# ──────────────────────────────────────────────
|
| 117 |
-
# Global State
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
# ──────────────────────────────────────────────
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
pipe = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
download_cancelled = False
|
| 121 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
# ──────────────────────────────────────────────
|
| 123 |
# Generation Defaults
|
| 124 |
# ──────────────────────────────────────────────
|
|
|
|
| 114 |
check_memory_requirements()
|
| 115 |
|
| 116 |
# ──────────────────────────────────────────────
|
| 117 |
+
# Global Pipeline State
|
| 118 |
+
#
|
| 119 |
+
# IMPORTANT: Use get_pipe() / set_pipe() instead of importing `pipe` directly.
|
| 120 |
+
# Python's `from .config import pipe` binds the value (None) at import time.
|
| 121 |
+
# Subsequent set_pipe() calls update the mutable dict, so all modules that
|
| 122 |
+
# call get_pipe() will always see the current pipeline instance.
|
| 123 |
# ──────────────────────────────────────────────
|
| 124 |
+
_pipeline_state: dict = {"pipe": None}
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def get_pipe():
|
| 128 |
+
"""Get the currently loaded pipeline instance (always up-to-date)."""
|
| 129 |
+
return _pipeline_state["pipe"]
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def set_pipe(pipeline) -> None:
|
| 133 |
+
"""Set the globally loaded pipeline instance."""
|
| 134 |
+
_pipeline_state["pipe"] = pipeline
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
# Legacy alias kept for any code that references config.pipe directly.
|
| 138 |
+
# Do NOT use this for checking whether the pipeline is loaded — use get_pipe().
|
| 139 |
pipe = None
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
# ──────────────────────────────────────────────
|
| 143 |
+
# Download Cancellation Flag
|
| 144 |
+
# ──────────────────────────────────────────────
|
| 145 |
download_cancelled = False
|
| 146 |
|
| 147 |
+
|
| 148 |
+
def set_download_cancelled(value: bool) -> None:
|
| 149 |
+
"""Set the global download cancellation flag."""
|
| 150 |
+
global download_cancelled
|
| 151 |
+
download_cancelled = value
|
| 152 |
+
|
| 153 |
+
|
| 154 |
# ──────────────────────────────────────────────
|
| 155 |
# Generation Defaults
|
| 156 |
# ──────────────────────────────────────────────
|
src/downloader.py
CHANGED
|
@@ -135,7 +135,7 @@ class TqdmGradio(TqdmBase):
|
|
| 135 |
self.last_pct = 0
|
| 136 |
|
| 137 |
def update(self, n=1):
|
| 138 |
-
|
| 139 |
if download_cancelled:
|
| 140 |
raise KeyboardInterrupt("Download cancelled by user")
|
| 141 |
super().update(n)
|
|
@@ -147,12 +147,6 @@ class TqdmGradio(TqdmBase):
|
|
| 147 |
self.gradio_prog(pct / 100)
|
| 148 |
|
| 149 |
|
| 150 |
-
def set_download_cancelled(value: bool):
|
| 151 |
-
"""Set the global download cancellation flag."""
|
| 152 |
-
global download_cancelled
|
| 153 |
-
download_cancelled = value
|
| 154 |
-
|
| 155 |
-
|
| 156 |
def get_cached_file_size(url: str, suffix: str = "", type_prefix: str | None = None) -> tuple[Path | None, int | None]:
|
| 157 |
"""
|
| 158 |
Check if file exists in cache and matches expected size.
|
|
@@ -218,7 +212,7 @@ def download_file_with_progress(url: str, output_path: Path, progress_bar=None)
|
|
| 218 |
KeyboardInterrupt: If download is cancelled
|
| 219 |
requests.RequestException: If download fails
|
| 220 |
"""
|
| 221 |
-
|
| 222 |
|
| 223 |
# Handle local file:// URLs
|
| 224 |
if url.startswith("file://"):
|
|
|
|
| 135 |
self.last_pct = 0
|
| 136 |
|
| 137 |
def update(self, n=1):
|
| 138 |
+
from .config import download_cancelled
|
| 139 |
if download_cancelled:
|
| 140 |
raise KeyboardInterrupt("Download cancelled by user")
|
| 141 |
super().update(n)
|
|
|
|
| 147 |
self.gradio_prog(pct / 100)
|
| 148 |
|
| 149 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
def get_cached_file_size(url: str, suffix: str = "", type_prefix: str | None = None) -> tuple[Path | None, int | None]:
|
| 151 |
"""
|
| 152 |
Check if file exists in cache and matches expected size.
|
|
|
|
| 212 |
KeyboardInterrupt: If download is cancelled
|
| 213 |
requests.RequestException: If download fails
|
| 214 |
"""
|
| 215 |
+
from .config import download_cancelled
|
| 216 |
|
| 217 |
# Handle local file:// URLs
|
| 218 |
if url.startswith("file://"):
|
src/exporter.py
CHANGED
|
@@ -6,7 +6,8 @@ from pathlib import Path
|
|
| 6 |
import torch
|
| 7 |
from safetensors.torch import save_file
|
| 8 |
|
| 9 |
-
from .
|
|
|
|
| 10 |
|
| 11 |
|
| 12 |
def export_merged_model(
|
|
@@ -14,7 +15,7 @@ def export_merged_model(
|
|
| 14 |
quantize: bool,
|
| 15 |
qtype: str,
|
| 16 |
save_format: str = "safetensors",
|
| 17 |
-
)
|
| 18 |
"""
|
| 19 |
Export the merged pipeline model with optional LoRA baking and quantization.
|
| 20 |
|
|
@@ -24,56 +25,66 @@ def export_merged_model(
|
|
| 24 |
qtype: Quantization type - 'none', 'int8', 'int4', or 'float8'
|
| 25 |
save_format: Output format - 'safetensors' or 'bin'
|
| 26 |
|
|
|
|
|
|
|
|
|
|
| 27 |
Returns:
|
| 28 |
-
|
| 29 |
"""
|
| 30 |
-
|
| 31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
|
| 33 |
try:
|
| 34 |
# Validate quantization type
|
| 35 |
valid_qtypes = ("none", "int8", "int4", "float8")
|
| 36 |
if qtype not in valid_qtypes:
|
| 37 |
-
|
|
|
|
| 38 |
|
| 39 |
# Step 1: Unload LoRAs
|
| 40 |
yield "💾 Exporting model...", "Unloading LoRAs..."
|
| 41 |
if include_lora:
|
| 42 |
try:
|
| 43 |
-
|
| 44 |
except Exception as e:
|
| 45 |
print(f" ℹ️ Could not unload LoRAs: {e}")
|
| 46 |
-
pass
|
| 47 |
|
| 48 |
merged_state_dict = {}
|
| 49 |
|
| 50 |
# Step 2: Extract UNet weights
|
| 51 |
yield "💾 Exporting model...", "Extracting UNet weights..."
|
| 52 |
-
for k, v in
|
| 53 |
merged_state_dict[f"unet.{k}"] = v.contiguous().half()
|
| 54 |
|
| 55 |
# Step 3: Extract text encoder weights
|
| 56 |
yield "💾 Exporting model...", "Extracting text encoders..."
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
|
|
|
|
|
|
| 61 |
|
| 62 |
# Step 4: Extract VAE weights
|
| 63 |
yield "💾 Exporting model...", "Extracting VAE weights..."
|
| 64 |
-
|
| 65 |
-
|
|
|
|
| 66 |
|
| 67 |
# Step 5: Quantize if requested and optimum.quanto is available
|
| 68 |
try:
|
| 69 |
-
from optimum.quanto import quantize, QTensor
|
| 70 |
QUANTO_AVAILABLE = True
|
| 71 |
except ImportError:
|
| 72 |
QUANTO_AVAILABLE = False
|
| 73 |
|
| 74 |
if quantize and qtype != "none" and QUANTO_AVAILABLE:
|
| 75 |
yield "💾 Exporting model...", f"Applying {qtype} quantization..."
|
| 76 |
-
|
| 77 |
class FakeModel(torch.nn.Module):
|
| 78 |
pass
|
| 79 |
|
|
@@ -83,13 +94,13 @@ def export_merged_model(
|
|
| 83 |
# Select quantization method
|
| 84 |
if qtype == "int8":
|
| 85 |
from optimum.quanto import int8_weight_only
|
| 86 |
-
|
| 87 |
elif qtype == "int4":
|
| 88 |
from optimum.quanto import int4_weight_only
|
| 89 |
-
|
| 90 |
elif qtype == "float8":
|
| 91 |
from optimum.quanto import float8_dynamic_activation_float8_weight
|
| 92 |
-
|
| 93 |
else:
|
| 94 |
raise ValueError(f"Unsupported qtype: {qtype}")
|
| 95 |
|
|
@@ -98,11 +109,12 @@ def export_merged_model(
|
|
| 98 |
for k, v in fake_model.state_dict().items()
|
| 99 |
}
|
| 100 |
elif quantize and not QUANTO_AVAILABLE:
|
| 101 |
-
|
|
|
|
| 102 |
|
| 103 |
# Step 6: Save model
|
| 104 |
yield "💾 Exporting model...", "Saving weights..."
|
| 105 |
-
|
| 106 |
ext = ".bin" if save_format == "bin" else ".safetensors"
|
| 107 |
|
| 108 |
# Build filename based on options
|
|
@@ -125,13 +137,14 @@ def export_merged_model(
|
|
| 125 |
else:
|
| 126 |
msg = f"✅ Merged checkpoint saved: `{out_path}` ({size_gb:.2f} GB)"
|
| 127 |
|
| 128 |
-
yield
|
| 129 |
-
return str(out_path), msg
|
| 130 |
|
| 131 |
except ImportError as e:
|
| 132 |
-
|
| 133 |
except Exception as e:
|
| 134 |
-
|
|
|
|
|
|
|
| 135 |
|
| 136 |
|
| 137 |
def get_export_status() -> str:
|
|
|
|
| 6 |
import torch
|
| 7 |
from safetensors.torch import save_file
|
| 8 |
|
| 9 |
+
from . import config
|
| 10 |
+
from .config import SCRIPT_DIR
|
| 11 |
|
| 12 |
|
| 13 |
def export_merged_model(
|
|
|
|
| 15 |
quantize: bool,
|
| 16 |
qtype: str,
|
| 17 |
save_format: str = "safetensors",
|
| 18 |
+
):
|
| 19 |
"""
|
| 20 |
Export the merged pipeline model with optional LoRA baking and quantization.
|
| 21 |
|
|
|
|
| 25 |
qtype: Quantization type - 'none', 'int8', 'int4', or 'float8'
|
| 26 |
save_format: Output format - 'safetensors' or 'bin'
|
| 27 |
|
| 28 |
+
Yields:
|
| 29 |
+
Tuple of (status_message, progress_text) at each export stage.
|
| 30 |
+
|
| 31 |
Returns:
|
| 32 |
+
Final yielded tuple of (output_path or None, status message)
|
| 33 |
"""
|
| 34 |
+
# Fetch the pipeline at call time — avoids the stale import-by-value problem.
|
| 35 |
+
pipe = config.get_pipe()
|
| 36 |
+
|
| 37 |
+
if not pipe:
|
| 38 |
+
yield None, "⚠️ Please load a pipeline first."
|
| 39 |
+
return
|
| 40 |
|
| 41 |
try:
|
| 42 |
# Validate quantization type
|
| 43 |
valid_qtypes = ("none", "int8", "int4", "float8")
|
| 44 |
if qtype not in valid_qtypes:
|
| 45 |
+
yield None, f"❌ Invalid quantization type: {qtype}. Must be one of: {valid_qtypes}"
|
| 46 |
+
return
|
| 47 |
|
| 48 |
# Step 1: Unload LoRAs
|
| 49 |
yield "💾 Exporting model...", "Unloading LoRAs..."
|
| 50 |
if include_lora:
|
| 51 |
try:
|
| 52 |
+
pipe.unload_lora_weights()
|
| 53 |
except Exception as e:
|
| 54 |
print(f" ℹ️ Could not unload LoRAs: {e}")
|
|
|
|
| 55 |
|
| 56 |
merged_state_dict = {}
|
| 57 |
|
| 58 |
# Step 2: Extract UNet weights
|
| 59 |
yield "💾 Exporting model...", "Extracting UNet weights..."
|
| 60 |
+
for k, v in pipe.unet.state_dict().items():
|
| 61 |
merged_state_dict[f"unet.{k}"] = v.contiguous().half()
|
| 62 |
|
| 63 |
# Step 3: Extract text encoder weights
|
| 64 |
yield "💾 Exporting model...", "Extracting text encoders..."
|
| 65 |
+
if pipe.text_encoder is not None:
|
| 66 |
+
for k, v in pipe.text_encoder.state_dict().items():
|
| 67 |
+
merged_state_dict[f"text_encoder.{k}"] = v.contiguous().half()
|
| 68 |
+
if pipe.text_encoder_2 is not None:
|
| 69 |
+
for k, v in pipe.text_encoder_2.state_dict().items():
|
| 70 |
+
merged_state_dict[f"text_encoder_2.{k}"] = v.contiguous().half()
|
| 71 |
|
| 72 |
# Step 4: Extract VAE weights
|
| 73 |
yield "💾 Exporting model...", "Extracting VAE weights..."
|
| 74 |
+
if pipe.vae is not None:
|
| 75 |
+
for k, v in pipe.vae.state_dict().items():
|
| 76 |
+
merged_state_dict[f"first_stage_model.{k}"] = v.contiguous().half()
|
| 77 |
|
| 78 |
# Step 5: Quantize if requested and optimum.quanto is available
|
| 79 |
try:
|
| 80 |
+
from optimum.quanto import quantize as quanto_quantize, QTensor
|
| 81 |
QUANTO_AVAILABLE = True
|
| 82 |
except ImportError:
|
| 83 |
QUANTO_AVAILABLE = False
|
| 84 |
|
| 85 |
if quantize and qtype != "none" and QUANTO_AVAILABLE:
|
| 86 |
yield "💾 Exporting model...", f"Applying {qtype} quantization..."
|
| 87 |
+
|
| 88 |
class FakeModel(torch.nn.Module):
|
| 89 |
pass
|
| 90 |
|
|
|
|
| 94 |
# Select quantization method
|
| 95 |
if qtype == "int8":
|
| 96 |
from optimum.quanto import int8_weight_only
|
| 97 |
+
quanto_quantize(fake_model, int8_weight_only())
|
| 98 |
elif qtype == "int4":
|
| 99 |
from optimum.quanto import int4_weight_only
|
| 100 |
+
quanto_quantize(fake_model, int4_weight_only())
|
| 101 |
elif qtype == "float8":
|
| 102 |
from optimum.quanto import float8_dynamic_activation_float8_weight
|
| 103 |
+
quanto_quantize(fake_model, float8_dynamic_activation_float8_weight())
|
| 104 |
else:
|
| 105 |
raise ValueError(f"Unsupported qtype: {qtype}")
|
| 106 |
|
|
|
|
| 109 |
for k, v in fake_model.state_dict().items()
|
| 110 |
}
|
| 111 |
elif quantize and not QUANTO_AVAILABLE:
|
| 112 |
+
yield None, "❌ optimum.quanto not installed. Install with: pip install optimum-quanto"
|
| 113 |
+
return
|
| 114 |
|
| 115 |
# Step 6: Save model
|
| 116 |
yield "💾 Exporting model...", "Saving weights..."
|
| 117 |
+
|
| 118 |
ext = ".bin" if save_format == "bin" else ".safetensors"
|
| 119 |
|
| 120 |
# Build filename based on options
|
|
|
|
| 137 |
else:
|
| 138 |
msg = f"✅ Merged checkpoint saved: `{out_path}` ({size_gb:.2f} GB)"
|
| 139 |
|
| 140 |
+
yield str(out_path), msg
|
|
|
|
| 141 |
|
| 142 |
except ImportError as e:
|
| 143 |
+
yield None, f"❌ Missing dependency: {str(e)}"
|
| 144 |
except Exception as e:
|
| 145 |
+
import traceback
|
| 146 |
+
print(traceback.format_exc())
|
| 147 |
+
yield None, f"❌ Export failed: {str(e)}"
|
| 148 |
|
| 149 |
|
| 150 |
def get_export_status() -> str:
|
src/generator.py
CHANGED
|
@@ -2,8 +2,9 @@
|
|
| 2 |
|
| 3 |
import torch
|
| 4 |
|
| 5 |
-
from .
|
| 6 |
-
from .
|
|
|
|
| 7 |
|
| 8 |
|
| 9 |
def generate_image(
|
|
@@ -15,7 +16,8 @@ def generate_image(
|
|
| 15 |
width: int,
|
| 16 |
tile_x: bool = True,
|
| 17 |
tile_y: bool = False,
|
| 18 |
-
|
|
|
|
| 19 |
"""
|
| 20 |
Generate an image using the loaded SDXL pipeline.
|
| 21 |
|
|
@@ -28,49 +30,67 @@ def generate_image(
|
|
| 28 |
width: Output image width in pixels
|
| 29 |
tile_x: Enable seamless tiling on x-axis
|
| 30 |
tile_y: Enable seamless tiling on y-axis
|
|
|
|
| 31 |
|
| 32 |
-
|
| 33 |
-
Tuple of (
|
|
|
|
|
|
|
|
|
|
| 34 |
"""
|
| 35 |
-
|
| 36 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
|
| 38 |
# For CPU mode, use float32 and warn about slow generation
|
| 39 |
effective_dtype = dtype
|
| 40 |
effective_device = device
|
| 41 |
-
|
| 42 |
if is_running_on_spaces() and device == "cpu":
|
| 43 |
print(" ℹ️ CPU mode: using float32 for stability (generation will be slower)")
|
|
|
|
|
|
|
| 44 |
effective_dtype = torch.float32
|
| 45 |
# Update pipeline to use float32
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
if
|
| 50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
|
| 52 |
# Enable seamless tiling on UNet & VAE decoder
|
| 53 |
-
enable_seamless_tiling(
|
| 54 |
-
enable_seamless_tiling(
|
| 55 |
-
|
| 56 |
-
yield "🎨 Generating image...", f"Steps: 0/{steps} | CFG: {cfg}"
|
| 57 |
|
| 58 |
-
|
| 59 |
-
result = global_pipe(
|
| 60 |
-
prompt=prompt,
|
| 61 |
-
negative_prompt=negative_prompt,
|
| 62 |
-
width=int(width),
|
| 63 |
-
height=int(height),
|
| 64 |
-
num_inference_steps=int(steps),
|
| 65 |
-
guidance_scale=float(cfg),
|
| 66 |
-
generator=generator,
|
| 67 |
-
)
|
| 68 |
|
| 69 |
-
|
| 70 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
|
|
|
|
|
|
|
| 72 |
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
import torch
|
| 4 |
|
| 5 |
+
from . import config
|
| 6 |
+
from .config import device, dtype, is_running_on_spaces
|
| 7 |
+
from .tiling import enable_seamless_tiling
|
| 8 |
|
| 9 |
|
| 10 |
def generate_image(
|
|
|
|
| 16 |
width: int,
|
| 17 |
tile_x: bool = True,
|
| 18 |
tile_y: bool = False,
|
| 19 |
+
seed: int | None = None,
|
| 20 |
+
):
|
| 21 |
"""
|
| 22 |
Generate an image using the loaded SDXL pipeline.
|
| 23 |
|
|
|
|
| 30 |
width: Output image width in pixels
|
| 31 |
tile_x: Enable seamless tiling on x-axis
|
| 32 |
tile_y: Enable seamless tiling on y-axis
|
| 33 |
+
seed: Random seed for reproducibility (default: uses timestamp-based seed)
|
| 34 |
|
| 35 |
+
Yields:
|
| 36 |
+
Tuple of (intermediate_image_or_none, status_message)
|
| 37 |
+
- First yield: (None, progress_text) for initial progress update
|
| 38 |
+
- Final yield: (PIL Image, final_status) with the generated image,
|
| 39 |
+
or (None, error_message) if generation failed.
|
| 40 |
"""
|
| 41 |
+
# Fetch the pipeline at call time — avoids the stale import-by-value problem.
|
| 42 |
+
pipe = config.get_pipe()
|
| 43 |
+
|
| 44 |
+
if not pipe:
|
| 45 |
+
yield None, "⚠️ Please load a pipeline first."
|
| 46 |
+
return
|
| 47 |
|
| 48 |
# For CPU mode, use float32 and warn about slow generation
|
| 49 |
effective_dtype = dtype
|
| 50 |
effective_device = device
|
| 51 |
+
|
| 52 |
if is_running_on_spaces() and device == "cpu":
|
| 53 |
print(" ℹ️ CPU mode: using float32 for stability (generation will be slower)")
|
| 54 |
+
# Store original dtype and temporarily use float32
|
| 55 |
+
original_dtype = effective_dtype
|
| 56 |
effective_dtype = torch.float32
|
| 57 |
# Update pipeline to use float32
|
| 58 |
+
pipe.unet.to(dtype=torch.float32)
|
| 59 |
+
if pipe.text_encoder is not None:
|
| 60 |
+
pipe.text_encoder.to(dtype=torch.float32)
|
| 61 |
+
if pipe.text_encoder_2 is not None:
|
| 62 |
+
pipe.text_encoder_2.to(dtype=torch.float32)
|
| 63 |
+
if pipe.vae is not None:
|
| 64 |
+
pipe.vae.to(dtype=torch.float32)
|
| 65 |
+
else:
|
| 66 |
+
original_dtype = None
|
| 67 |
|
| 68 |
# Enable seamless tiling on UNet & VAE decoder
|
| 69 |
+
enable_seamless_tiling(pipe.unet, tile_x=tile_x, tile_y=tile_y)
|
| 70 |
+
enable_seamless_tiling(pipe.vae.decoder, tile_x=tile_x, tile_y=tile_y)
|
|
|
|
|
|
|
| 71 |
|
| 72 |
+
yield None, "🎨 Generating image..."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
|
| 74 |
+
try:
|
| 75 |
+
# Use provided seed or generate a random one if None
|
| 76 |
+
actual_seed = seed if seed is not None else int(torch.randint(0, 2**63, (1,)).item())
|
| 77 |
+
generator = torch.Generator(device=effective_device).manual_seed(actual_seed)
|
| 78 |
+
result = pipe(
|
| 79 |
+
prompt=prompt,
|
| 80 |
+
negative_prompt=negative_prompt,
|
| 81 |
+
width=int(width),
|
| 82 |
+
height=int(height),
|
| 83 |
+
num_inference_steps=int(steps),
|
| 84 |
+
guidance_scale=float(cfg),
|
| 85 |
+
generator=generator,
|
| 86 |
+
)
|
| 87 |
|
| 88 |
+
image = result.images[0]
|
| 89 |
+
yield image, f"✅ Complete! ({int(width)}x{int(height)})"
|
| 90 |
|
| 91 |
+
except Exception as e:
|
| 92 |
+
import traceback
|
| 93 |
+
error_msg = f"❌ Generation failed: {str(e)}"
|
| 94 |
+
print(error_msg)
|
| 95 |
+
print(traceback.format_exc())
|
| 96 |
+
yield None, error_msg
|
src/pipeline.py
CHANGED
|
@@ -2,55 +2,16 @@
|
|
| 2 |
|
| 3 |
from pathlib import Path
|
| 4 |
|
| 5 |
-
import torch
|
| 6 |
from diffusers import (
|
| 7 |
StableDiffusionXLPipeline,
|
| 8 |
AutoencoderKL,
|
| 9 |
DPMSolverSDEScheduler,
|
| 10 |
)
|
| 11 |
|
| 12 |
-
from .
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
"""Create patched forward for seamless tiling on Conv2d layers."""
|
| 17 |
-
original_forward = module._conv_forward
|
| 18 |
-
|
| 19 |
-
def patched_conv_forward(input, weight, bias):
|
| 20 |
-
if tile_x and tile_y:
|
| 21 |
-
input = torch.nn.functional.pad(input, (pad_w, pad_w, pad_h, pad_h), mode="circular")
|
| 22 |
-
elif tile_x:
|
| 23 |
-
input = torch.nn.functional.pad(input, (pad_w, pad_w, 0, 0), mode="circular")
|
| 24 |
-
input = torch.nn.functional.pad(input, (0, 0, pad_h, pad_h), mode="constant", value=0)
|
| 25 |
-
elif tile_y:
|
| 26 |
-
input = torch.nn.functional.pad(input, (0, 0, pad_h, pad_h), mode="circular")
|
| 27 |
-
input = torch.nn.functional.pad(input, (pad_w, pad_w, 0, 0), mode="constant", value=0)
|
| 28 |
-
else:
|
| 29 |
-
return original_forward(input, weight, bias)
|
| 30 |
-
|
| 31 |
-
return torch.nn.functional.conv2d(
|
| 32 |
-
input, weight, bias, module.stride, (0, 0), module.dilation, module.groups
|
| 33 |
-
)
|
| 34 |
-
|
| 35 |
-
return patched_conv_forward
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
def enable_seamless_tiling(model, tile_x: bool = True, tile_y: bool = False):
|
| 39 |
-
"""
|
| 40 |
-
Enable seamless tiling on a model's Conv2d layers.
|
| 41 |
-
|
| 42 |
-
Args:
|
| 43 |
-
model: PyTorch model with Conv2d layers (e.g., pipe.unet, pipe.vae.decoder)
|
| 44 |
-
tile_x: Enable tiling along x-axis
|
| 45 |
-
tile_y: Enable tiling along y-axis
|
| 46 |
-
"""
|
| 47 |
-
for module in model.modules():
|
| 48 |
-
if isinstance(module, torch.nn.Conv2d):
|
| 49 |
-
pad_h = module.padding[0]
|
| 50 |
-
pad_w = module.padding[1]
|
| 51 |
-
if pad_h == 0 and pad_w == 0:
|
| 52 |
-
continue
|
| 53 |
-
module._conv_forward = _make_asymmetric_forward(module, pad_h, pad_w, tile_x, tile_y)
|
| 54 |
|
| 55 |
|
| 56 |
def load_pipeline(
|
|
@@ -70,17 +31,18 @@ def load_pipeline(
|
|
| 70 |
lora_strengths_str: Comma-separated strength values for each LoRA
|
| 71 |
progress: Optional gr.Progress() object for UI updates
|
| 72 |
|
|
|
|
|
|
|
|
|
|
| 73 |
Returns:
|
| 74 |
-
|
| 75 |
"""
|
| 76 |
-
|
|
|
|
| 77 |
|
| 78 |
try:
|
| 79 |
set_download_cancelled(False)
|
| 80 |
|
| 81 |
-
# Import gr here to update button state if needed
|
| 82 |
-
import gradio as gr
|
| 83 |
-
|
| 84 |
print("=" * 60)
|
| 85 |
print("🔄 Loading SDXL Pipeline...")
|
| 86 |
print("=" * 60)
|
|
@@ -93,8 +55,7 @@ def load_pipeline(
|
|
| 93 |
|
| 94 |
# Validate cache file before using it
|
| 95 |
if checkpoint_cached:
|
| 96 |
-
|
| 97 |
-
is_valid, msg = validate_cache_file(checkpoint_path)
|
| 98 |
if not is_valid:
|
| 99 |
print(f" ⚠️ Cache invalid: {msg}")
|
| 100 |
checkpoint_path.unlink(missing_ok=True)
|
|
@@ -107,8 +68,7 @@ def load_pipeline(
|
|
| 107 |
|
| 108 |
# Validate VAE cache file before using it
|
| 109 |
if vae_cached:
|
| 110 |
-
|
| 111 |
-
is_valid, msg = validate_cache_file(vae_path)
|
| 112 |
if not is_valid:
|
| 113 |
print(f" ⚠️ VAE Cache invalid: {msg}")
|
| 114 |
vae_path.unlink(missing_ok=True)
|
|
@@ -124,7 +84,7 @@ def load_pipeline(
|
|
| 124 |
else:
|
| 125 |
status_msg = f"✅ Using cached {checkpoint_path.name}"
|
| 126 |
print(f" ✅ Using cached: {checkpoint_path.name}")
|
| 127 |
-
|
| 128 |
yield status_msg, "Starting download..."
|
| 129 |
|
| 130 |
if not checkpoint_cached:
|
|
@@ -136,10 +96,10 @@ def load_pipeline(
|
|
| 136 |
if vae_path:
|
| 137 |
status_msg = f"📥 Downloading {vae_path.name}..." if not vae_cached else f"✅ Using cached {vae_path.name}"
|
| 138 |
print(f" 📥 VAE: {vae_path.name}" if not vae_cached else f" ✅ VAE (cached): {vae_path.name}")
|
| 139 |
-
|
| 140 |
if progress:
|
| 141 |
progress(0.2, desc="Downloading VAE..." if not vae_cached else "Loading VAE...")
|
| 142 |
-
|
| 143 |
yield status_msg, f"Downloading VAE: {vae_path.name}" if not vae_cached else f"Using cached VAE: {vae_path.name}"
|
| 144 |
|
| 145 |
if not vae_cached:
|
|
@@ -158,7 +118,7 @@ def load_pipeline(
|
|
| 158 |
# Load base pipeline (yield progress during this heavy operation)
|
| 159 |
print(" ⚙️ Loading SDXL pipeline from single file...")
|
| 160 |
yield "⚙️ Loading SDXL pipeline...", "Loading model weights into memory..."
|
| 161 |
-
|
| 162 |
if progress:
|
| 163 |
progress(0.3, desc="Loading text encoders...")
|
| 164 |
|
|
@@ -166,38 +126,35 @@ def load_pipeline(
|
|
| 166 |
load_kwargs = {
|
| 167 |
"torch_dtype": dtype,
|
| 168 |
"use_safetensors": True,
|
| 169 |
-
"safety_checker": None,
|
| 170 |
}
|
| 171 |
|
| 172 |
if is_running_on_spaces() and device == "cpu":
|
| 173 |
print(" ℹ️ CPU mode detected: enabling device_map='auto' for better RAM management")
|
| 174 |
load_kwargs["device_map"] = "auto"
|
| 175 |
-
else:
|
| 176 |
-
# For GPU, we'll move to device after loading
|
| 177 |
-
load_kwargs["variant"] = "fp16" if dtype == torch.float16 else None
|
| 178 |
|
| 179 |
-
|
|
|
|
| 180 |
str(checkpoint_path),
|
| 181 |
**load_kwargs,
|
| 182 |
)
|
| 183 |
print(" ✅ Text encoders loaded")
|
| 184 |
-
|
| 185 |
if progress:
|
| 186 |
progress(0.5, desc="Loading UNet...")
|
| 187 |
|
| 188 |
print(" ✅ UNet loaded")
|
| 189 |
-
|
| 190 |
# Move to device (unless using device_map='auto' which handles this automatically)
|
| 191 |
if not is_running_on_spaces() or device != "cpu":
|
| 192 |
print(f" ⚙️ Moving pipeline to device: {device_description}...")
|
| 193 |
-
|
| 194 |
-
|
| 195 |
yield "⚙️ Pipeline loaded, setting up components...", f"Using device: {device_description}"
|
| 196 |
|
| 197 |
# Load VAE into pipeline if provided
|
| 198 |
if vae is not None:
|
| 199 |
print(" ⚙️ Setting custom VAE...")
|
| 200 |
-
|
| 201 |
yield "⚙️ Pipeline loaded, setting up components...", f"VAE loaded: {vae_path.name}"
|
| 202 |
|
| 203 |
# Parse LoRA URLs & ensure strengths list matches
|
|
@@ -214,35 +171,34 @@ def load_pipeline(
|
|
| 214 |
# Load and fuse each LoRA sequentially (only if URLs exist)
|
| 215 |
if lora_urls:
|
| 216 |
print(f" ⚙️ Moving pipeline to device: {device_description}...")
|
| 217 |
-
|
| 218 |
|
| 219 |
for i, (lora_url, strength) in enumerate(zip(lora_urls, strengths)):
|
| 220 |
lora_filename = get_safe_filename_from_url(lora_url, suffix="_lora")
|
| 221 |
lora_path = CACHE_DIR / lora_filename
|
| 222 |
lora_cached = lora_path.exists() and lora_path.stat().st_size > 0
|
| 223 |
-
|
| 224 |
# Validate LoRA cache file before using it
|
| 225 |
if lora_cached:
|
| 226 |
-
|
| 227 |
-
is_valid, msg = validate_cache_file(lora_path)
|
| 228 |
if not is_valid:
|
| 229 |
print(f" ⚠️ LoRA Cache invalid: {msg}")
|
| 230 |
lora_path.unlink(missing_ok=True)
|
| 231 |
lora_cached = False
|
| 232 |
-
|
| 233 |
if not lora_cached:
|
| 234 |
print(f" 📥 LoRA {i+1}/{len(lora_urls)}: Downloading {lora_path.name}...")
|
| 235 |
status_msg = f"📥 Downloading LoRA {i+1}/{len(lora_urls)}: {lora_path.name}..."
|
| 236 |
else:
|
| 237 |
print(f" ✅ LoRA {i+1}/{len(lora_urls)}: Using cached {lora_path.name}")
|
| 238 |
status_msg = f"✅ Using cached LoRA {i+1}/{len(lora_urls)}: {lora_path.name}"
|
| 239 |
-
|
| 240 |
yield (
|
| 241 |
status_msg,
|
| 242 |
f"Downloading LoRA {i+1}/{len(lora_urls)} ({lora_path.name})..." if not lora_cached
|
| 243 |
else f"Using cached LoRA {i+1}/{len(lora_urls)} ({lora_path.name})"
|
| 244 |
)
|
| 245 |
-
|
| 246 |
if not lora_cached:
|
| 247 |
download_file_with_progress(lora_url, lora_path)
|
| 248 |
|
|
@@ -250,16 +206,16 @@ def load_pipeline(
|
|
| 250 |
yield f"⚙️ Loading LoRA {i+1}/{len(lora_urls)}...", f"Fusing {lora_path.name}..."
|
| 251 |
if progress:
|
| 252 |
progress(0.7 + (0.2 * i / len(lora_urls)), desc=f"Loading LoRA {i+1}/{len(lora_urls)}...")
|
| 253 |
-
|
| 254 |
adapter_name = f"lora_{i}"
|
| 255 |
-
|
| 256 |
print(f" ⚙️ Fusing LoRA {i+1} with strength={strength}...")
|
| 257 |
-
|
| 258 |
-
|
| 259 |
else:
|
| 260 |
# Move pipeline to device even without LoRAs
|
| 261 |
print(f" ⚙️ Moving pipeline to device: {device_description}...")
|
| 262 |
-
|
| 263 |
|
| 264 |
# Set scheduler and finalize (do this once at the end)
|
| 265 |
print(" ⚙️ Configuring scheduler...")
|
|
@@ -268,22 +224,26 @@ def load_pipeline(
|
|
| 268 |
if progress:
|
| 269 |
progress(0.95, desc="Finalizing...")
|
| 270 |
|
| 271 |
-
|
| 272 |
-
|
| 273 |
algorithm_type="sde-dpmsolver++",
|
| 274 |
use_karras_sigmas=False,
|
| 275 |
)
|
| 276 |
|
|
|
|
|
|
|
|
|
|
| 277 |
print(" ✅ Pipeline ready!")
|
| 278 |
yield "✅ Pipeline ready!", f"Ready! Loaded {len(lora_urls)} LoRA(s)"
|
| 279 |
-
return ("✅ Pipeline loaded successfully!", f"Ready! Loaded {len(lora_urls)} LoRA(s)")
|
| 280 |
|
| 281 |
except KeyboardInterrupt:
|
| 282 |
set_download_cancelled(False)
|
|
|
|
| 283 |
print("\n⚠️ Download cancelled by user")
|
| 284 |
return ("⚠️ Download cancelled by user", "Cancelled")
|
| 285 |
except Exception as e:
|
| 286 |
import traceback
|
|
|
|
| 287 |
error_msg = f"❌ Error loading pipeline: {str(e)}"
|
| 288 |
print(f"\n{error_msg}")
|
| 289 |
print(traceback.format_exc())
|
|
@@ -297,4 +257,4 @@ def cancel_download():
|
|
| 297 |
|
| 298 |
def get_pipeline() -> StableDiffusionXLPipeline | None:
|
| 299 |
"""Get the currently loaded pipeline."""
|
| 300 |
-
return
|
|
|
|
| 2 |
|
| 3 |
from pathlib import Path
|
| 4 |
|
|
|
|
| 5 |
from diffusers import (
|
| 6 |
StableDiffusionXLPipeline,
|
| 7 |
AutoencoderKL,
|
| 8 |
DPMSolverSDEScheduler,
|
| 9 |
)
|
| 10 |
|
| 11 |
+
from . import config
|
| 12 |
+
from .config import device, dtype, CACHE_DIR, device_description, is_running_on_spaces, set_download_cancelled
|
| 13 |
+
from .downloader import get_safe_filename_from_url, download_file_with_progress
|
| 14 |
+
from .tiling import enable_seamless_tiling
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
|
| 17 |
def load_pipeline(
|
|
|
|
| 31 |
lora_strengths_str: Comma-separated strength values for each LoRA
|
| 32 |
progress: Optional gr.Progress() object for UI updates
|
| 33 |
|
| 34 |
+
Yields:
|
| 35 |
+
Tuple of (status_message, progress_text) at each loading stage.
|
| 36 |
+
|
| 37 |
Returns:
|
| 38 |
+
Final yielded tuple of (final_status_message, progress_text)
|
| 39 |
"""
|
| 40 |
+
# Clear any previously loaded pipeline so the UI reflects loading state
|
| 41 |
+
config.set_pipe(None)
|
| 42 |
|
| 43 |
try:
|
| 44 |
set_download_cancelled(False)
|
| 45 |
|
|
|
|
|
|
|
|
|
|
| 46 |
print("=" * 60)
|
| 47 |
print("🔄 Loading SDXL Pipeline...")
|
| 48 |
print("=" * 60)
|
|
|
|
| 55 |
|
| 56 |
# Validate cache file before using it
|
| 57 |
if checkpoint_cached:
|
| 58 |
+
is_valid, msg = config.validate_cache_file(checkpoint_path)
|
|
|
|
| 59 |
if not is_valid:
|
| 60 |
print(f" ⚠️ Cache invalid: {msg}")
|
| 61 |
checkpoint_path.unlink(missing_ok=True)
|
|
|
|
| 68 |
|
| 69 |
# Validate VAE cache file before using it
|
| 70 |
if vae_cached:
|
| 71 |
+
is_valid, msg = config.validate_cache_file(vae_path)
|
|
|
|
| 72 |
if not is_valid:
|
| 73 |
print(f" ⚠️ VAE Cache invalid: {msg}")
|
| 74 |
vae_path.unlink(missing_ok=True)
|
|
|
|
| 84 |
else:
|
| 85 |
status_msg = f"✅ Using cached {checkpoint_path.name}"
|
| 86 |
print(f" ✅ Using cached: {checkpoint_path.name}")
|
| 87 |
+
|
| 88 |
yield status_msg, "Starting download..."
|
| 89 |
|
| 90 |
if not checkpoint_cached:
|
|
|
|
| 96 |
if vae_path:
|
| 97 |
status_msg = f"📥 Downloading {vae_path.name}..." if not vae_cached else f"✅ Using cached {vae_path.name}"
|
| 98 |
print(f" 📥 VAE: {vae_path.name}" if not vae_cached else f" ✅ VAE (cached): {vae_path.name}")
|
| 99 |
+
|
| 100 |
if progress:
|
| 101 |
progress(0.2, desc="Downloading VAE..." if not vae_cached else "Loading VAE...")
|
| 102 |
+
|
| 103 |
yield status_msg, f"Downloading VAE: {vae_path.name}" if not vae_cached else f"Using cached VAE: {vae_path.name}"
|
| 104 |
|
| 105 |
if not vae_cached:
|
|
|
|
| 118 |
# Load base pipeline (yield progress during this heavy operation)
|
| 119 |
print(" ⚙️ Loading SDXL pipeline from single file...")
|
| 120 |
yield "⚙️ Loading SDXL pipeline...", "Loading model weights into memory..."
|
| 121 |
+
|
| 122 |
if progress:
|
| 123 |
progress(0.3, desc="Loading text encoders...")
|
| 124 |
|
|
|
|
| 126 |
load_kwargs = {
|
| 127 |
"torch_dtype": dtype,
|
| 128 |
"use_safetensors": True,
|
|
|
|
| 129 |
}
|
| 130 |
|
| 131 |
if is_running_on_spaces() and device == "cpu":
|
| 132 |
print(" ℹ️ CPU mode detected: enabling device_map='auto' for better RAM management")
|
| 133 |
load_kwargs["device_map"] = "auto"
|
|
|
|
|
|
|
|
|
|
| 134 |
|
| 135 |
+
# Use a local variable for the pipeline being built — only stored globally on success.
|
| 136 |
+
_pipe = StableDiffusionXLPipeline.from_single_file(
|
| 137 |
str(checkpoint_path),
|
| 138 |
**load_kwargs,
|
| 139 |
)
|
| 140 |
print(" ✅ Text encoders loaded")
|
| 141 |
+
|
| 142 |
if progress:
|
| 143 |
progress(0.5, desc="Loading UNet...")
|
| 144 |
|
| 145 |
print(" ✅ UNet loaded")
|
| 146 |
+
|
| 147 |
# Move to device (unless using device_map='auto' which handles this automatically)
|
| 148 |
if not is_running_on_spaces() or device != "cpu":
|
| 149 |
print(f" ⚙️ Moving pipeline to device: {device_description}...")
|
| 150 |
+
_pipe = _pipe.to(device=device, dtype=dtype)
|
| 151 |
+
|
| 152 |
yield "⚙️ Pipeline loaded, setting up components...", f"Using device: {device_description}"
|
| 153 |
|
| 154 |
# Load VAE into pipeline if provided
|
| 155 |
if vae is not None:
|
| 156 |
print(" ⚙️ Setting custom VAE...")
|
| 157 |
+
_pipe.vae = vae.to(device=device, dtype=dtype)
|
| 158 |
yield "⚙️ Pipeline loaded, setting up components...", f"VAE loaded: {vae_path.name}"
|
| 159 |
|
| 160 |
# Parse LoRA URLs & ensure strengths list matches
|
|
|
|
| 171 |
# Load and fuse each LoRA sequentially (only if URLs exist)
|
| 172 |
if lora_urls:
|
| 173 |
print(f" ⚙️ Moving pipeline to device: {device_description}...")
|
| 174 |
+
_pipe = _pipe.to(device=device, dtype=dtype)
|
| 175 |
|
| 176 |
for i, (lora_url, strength) in enumerate(zip(lora_urls, strengths)):
|
| 177 |
lora_filename = get_safe_filename_from_url(lora_url, suffix="_lora")
|
| 178 |
lora_path = CACHE_DIR / lora_filename
|
| 179 |
lora_cached = lora_path.exists() and lora_path.stat().st_size > 0
|
| 180 |
+
|
| 181 |
# Validate LoRA cache file before using it
|
| 182 |
if lora_cached:
|
| 183 |
+
is_valid, msg = config.validate_cache_file(lora_path)
|
|
|
|
| 184 |
if not is_valid:
|
| 185 |
print(f" ⚠️ LoRA Cache invalid: {msg}")
|
| 186 |
lora_path.unlink(missing_ok=True)
|
| 187 |
lora_cached = False
|
| 188 |
+
|
| 189 |
if not lora_cached:
|
| 190 |
print(f" 📥 LoRA {i+1}/{len(lora_urls)}: Downloading {lora_path.name}...")
|
| 191 |
status_msg = f"📥 Downloading LoRA {i+1}/{len(lora_urls)}: {lora_path.name}..."
|
| 192 |
else:
|
| 193 |
print(f" ✅ LoRA {i+1}/{len(lora_urls)}: Using cached {lora_path.name}")
|
| 194 |
status_msg = f"✅ Using cached LoRA {i+1}/{len(lora_urls)}: {lora_path.name}"
|
| 195 |
+
|
| 196 |
yield (
|
| 197 |
status_msg,
|
| 198 |
f"Downloading LoRA {i+1}/{len(lora_urls)} ({lora_path.name})..." if not lora_cached
|
| 199 |
else f"Using cached LoRA {i+1}/{len(lora_urls)} ({lora_path.name})"
|
| 200 |
)
|
| 201 |
+
|
| 202 |
if not lora_cached:
|
| 203 |
download_file_with_progress(lora_url, lora_path)
|
| 204 |
|
|
|
|
| 206 |
yield f"⚙️ Loading LoRA {i+1}/{len(lora_urls)}...", f"Fusing {lora_path.name}..."
|
| 207 |
if progress:
|
| 208 |
progress(0.7 + (0.2 * i / len(lora_urls)), desc=f"Loading LoRA {i+1}/{len(lora_urls)}...")
|
| 209 |
+
|
| 210 |
adapter_name = f"lora_{i}"
|
| 211 |
+
_pipe.load_lora_weights(str(lora_path), adapter_name=adapter_name)
|
| 212 |
print(f" ⚙️ Fusing LoRA {i+1} with strength={strength}...")
|
| 213 |
+
_pipe.fuse_lora(adapter_names=[adapter_name], lora_scale=strength)
|
| 214 |
+
_pipe.unload_lora_weights()
|
| 215 |
else:
|
| 216 |
# Move pipeline to device even without LoRAs
|
| 217 |
print(f" ⚙️ Moving pipeline to device: {device_description}...")
|
| 218 |
+
_pipe = _pipe.to(device=device, dtype=dtype)
|
| 219 |
|
| 220 |
# Set scheduler and finalize (do this once at the end)
|
| 221 |
print(" ⚙️ Configuring scheduler...")
|
|
|
|
| 224 |
if progress:
|
| 225 |
progress(0.95, desc="Finalizing...")
|
| 226 |
|
| 227 |
+
_pipe.scheduler = DPMSolverSDEScheduler.from_config(
|
| 228 |
+
_pipe.scheduler.config,
|
| 229 |
algorithm_type="sde-dpmsolver++",
|
| 230 |
use_karras_sigmas=False,
|
| 231 |
)
|
| 232 |
|
| 233 |
+
# ✅ Only publish the pipeline globally AFTER all steps succeed
|
| 234 |
+
config.set_pipe(_pipe)
|
| 235 |
+
|
| 236 |
print(" ✅ Pipeline ready!")
|
| 237 |
yield "✅ Pipeline ready!", f"Ready! Loaded {len(lora_urls)} LoRA(s)"
|
|
|
|
| 238 |
|
| 239 |
except KeyboardInterrupt:
|
| 240 |
set_download_cancelled(False)
|
| 241 |
+
config.set_pipe(None)
|
| 242 |
print("\n⚠️ Download cancelled by user")
|
| 243 |
return ("⚠️ Download cancelled by user", "Cancelled")
|
| 244 |
except Exception as e:
|
| 245 |
import traceback
|
| 246 |
+
config.set_pipe(None)
|
| 247 |
error_msg = f"❌ Error loading pipeline: {str(e)}"
|
| 248 |
print(f"\n{error_msg}")
|
| 249 |
print(traceback.format_exc())
|
|
|
|
| 257 |
|
| 258 |
def get_pipeline() -> StableDiffusionXLPipeline | None:
|
| 259 |
"""Get the currently loaded pipeline."""
|
| 260 |
+
return config.get_pipe()
|
src/tiling.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Seamless tiling utilities for SDXL Model Merger."""
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def _make_asymmetric_forward(module, pad_h: int, pad_w: int, tile_x: bool, tile_y: bool):
|
| 7 |
+
"""Create patched forward for seamless tiling on Conv2d layers."""
|
| 8 |
+
original_forward = module._conv_forward
|
| 9 |
+
|
| 10 |
+
def patched_conv_forward(input, weight, bias):
|
| 11 |
+
if tile_x and tile_y:
|
| 12 |
+
# Circular padding on both axes
|
| 13 |
+
input = torch.nn.functional.pad(input, (pad_w, pad_w, pad_h, pad_h), mode="circular")
|
| 14 |
+
elif tile_x:
|
| 15 |
+
# Circular padding only on left/right edges, constant (zero) on top/bottom
|
| 16 |
+
# Asymmetric padding for 360° panorama tiling
|
| 17 |
+
input = torch.nn.functional.pad(input, (pad_w, pad_w, 0, 0), mode="circular")
|
| 18 |
+
input = torch.nn.functional.pad(input, (0, 0, pad_h, pad_h), mode="constant", value=0)
|
| 19 |
+
elif tile_y:
|
| 20 |
+
# Circular padding only on top/bottom edges, constant (zero) on left/right
|
| 21 |
+
input = torch.nn.functional.pad(input, (0, 0, pad_h, pad_h), mode="circular")
|
| 22 |
+
input = torch.nn.functional.pad(input, (pad_w, pad_w, 0, 0), mode="constant", value=0)
|
| 23 |
+
else:
|
| 24 |
+
return original_forward(input, weight, bias)
|
| 25 |
+
|
| 26 |
+
return torch.nn.functional.conv2d(
|
| 27 |
+
input, weight, bias, module.stride, (0, 0), module.dilation, module.groups
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
return patched_conv_forward
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def enable_seamless_tiling(model, tile_x: bool = True, tile_y: bool = False):
|
| 34 |
+
"""
|
| 35 |
+
Enable seamless tiling on a model's Conv2d layers.
|
| 36 |
+
|
| 37 |
+
Args:
|
| 38 |
+
model: PyTorch model with Conv2d layers (e.g., pipe.unet, pipe.vae.decoder)
|
| 39 |
+
tile_x: Enable tiling along x-axis
|
| 40 |
+
tile_y: Enable tiling along y-axis
|
| 41 |
+
"""
|
| 42 |
+
for module in model.modules():
|
| 43 |
+
if isinstance(module, torch.nn.Conv2d):
|
| 44 |
+
pad_h = module.padding[0]
|
| 45 |
+
pad_w = module.padding[1]
|
| 46 |
+
if pad_h == 0 and pad_w == 0:
|
| 47 |
+
continue
|
| 48 |
+
current = getattr(module, "_tiling_config", None)
|
| 49 |
+
if current == (tile_x, tile_y):
|
| 50 |
+
continue # already patched with same config
|
| 51 |
+
module._tiling_config = (tile_x, tile_y)
|
| 52 |
+
module._conv_forward = _make_asymmetric_forward(module, pad_h, pad_w, tile_x, tile_y)
|
src/ui/generator_tab.py
CHANGED
|
@@ -8,7 +8,7 @@ from ..generator import generate_image
|
|
| 8 |
|
| 9 |
def create_generator_tab():
|
| 10 |
"""Create the image generation tab with all input controls."""
|
| 11 |
-
|
| 12 |
with gr.Accordion("🎨 2. Generate Image", open=True, elem_classes=["feature-card"]):
|
| 13 |
# Prompts section
|
| 14 |
with gr.Row():
|
|
@@ -19,19 +19,19 @@ def create_generator_tab():
|
|
| 19 |
lines=3,
|
| 20 |
placeholder="Describe the image you want to generate..."
|
| 21 |
)
|
| 22 |
-
|
| 23 |
cfg = gr.Slider(
|
| 24 |
minimum=1.0, maximum=20.0, value=7.5, step=0.5,
|
| 25 |
label="CFG Scale",
|
| 26 |
info="Higher values make outputs match prompt more strictly"
|
| 27 |
)
|
| 28 |
-
|
| 29 |
height = gr.Number(
|
| 30 |
value=1024, precision=0,
|
| 31 |
label="Height (pixels)",
|
| 32 |
info="Output image height"
|
| 33 |
)
|
| 34 |
-
|
| 35 |
with gr.Column(scale=1):
|
| 36 |
negative_prompt = gr.Textbox(
|
| 37 |
label="Negative Prompt",
|
|
@@ -39,28 +39,28 @@ def create_generator_tab():
|
|
| 39 |
lines=3,
|
| 40 |
placeholder="Elements to avoid in generation..."
|
| 41 |
)
|
| 42 |
-
|
| 43 |
steps = gr.Slider(
|
| 44 |
minimum=1, maximum=100, value=25, step=1,
|
| 45 |
label="Inference Steps",
|
| 46 |
info="More steps = better quality but slower"
|
| 47 |
)
|
| 48 |
-
|
| 49 |
width = gr.Number(
|
| 50 |
value=2048, precision=0,
|
| 51 |
label="Width (pixels)",
|
| 52 |
info="Output image width"
|
| 53 |
)
|
| 54 |
-
|
| 55 |
# Tiling options
|
| 56 |
with gr.Row():
|
| 57 |
tile_x = gr.Checkbox(True, label="X-axis Seamless Tiling")
|
| 58 |
tile_y = gr.Checkbox(False, label="Y-axis Seamless Tiling")
|
| 59 |
-
|
| 60 |
# Generate button and outputs
|
| 61 |
with gr.Row():
|
| 62 |
gen_btn = gr.Button("✨ Generate Image", variant="secondary", size="lg")
|
| 63 |
-
|
| 64 |
with gr.Row():
|
| 65 |
image_output = gr.Image(
|
| 66 |
label="Result",
|
|
@@ -68,12 +68,18 @@ def create_generator_tab():
|
|
| 68 |
show_label=True
|
| 69 |
)
|
| 70 |
with gr.Column():
|
| 71 |
-
gen_status = gr.
|
| 72 |
label="Generation Status",
|
| 73 |
-
|
| 74 |
)
|
| 75 |
-
|
| 76 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
gr.HTML("""
|
| 78 |
<div style="margin-top: 16px; padding: 12px; background: #f3f4f6; border-radius: 8px;">
|
| 79 |
<strong>💡 Tips:</strong>
|
|
@@ -84,21 +90,52 @@ def create_generator_tab():
|
|
| 84 |
</ul>
|
| 85 |
</div>
|
| 86 |
""")
|
| 87 |
-
|
| 88 |
return (
|
| 89 |
prompt, negative_prompt, cfg, steps, height, width,
|
| 90 |
-
tile_x, tile_y, gen_btn, image_output, gen_status
|
| 91 |
)
|
| 92 |
|
| 93 |
|
| 94 |
def setup_generator_events(
|
| 95 |
prompt, negative_prompt, cfg, steps, height, width,
|
| 96 |
-
tile_x, tile_y, gen_btn, image_output, gen_status
|
| 97 |
):
|
| 98 |
"""Setup event handlers for the generator tab."""
|
| 99 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
gen_btn.click(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
fn=generate_image,
|
| 102 |
inputs=[prompt, negative_prompt, cfg, steps, height, width, tile_x, tile_y],
|
| 103 |
-
outputs=[image_output,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
)
|
|
|
|
| 8 |
|
| 9 |
def create_generator_tab():
|
| 10 |
"""Create the image generation tab with all input controls."""
|
| 11 |
+
|
| 12 |
with gr.Accordion("🎨 2. Generate Image", open=True, elem_classes=["feature-card"]):
|
| 13 |
# Prompts section
|
| 14 |
with gr.Row():
|
|
|
|
| 19 |
lines=3,
|
| 20 |
placeholder="Describe the image you want to generate..."
|
| 21 |
)
|
| 22 |
+
|
| 23 |
cfg = gr.Slider(
|
| 24 |
minimum=1.0, maximum=20.0, value=7.5, step=0.5,
|
| 25 |
label="CFG Scale",
|
| 26 |
info="Higher values make outputs match prompt more strictly"
|
| 27 |
)
|
| 28 |
+
|
| 29 |
height = gr.Number(
|
| 30 |
value=1024, precision=0,
|
| 31 |
label="Height (pixels)",
|
| 32 |
info="Output image height"
|
| 33 |
)
|
| 34 |
+
|
| 35 |
with gr.Column(scale=1):
|
| 36 |
negative_prompt = gr.Textbox(
|
| 37 |
label="Negative Prompt",
|
|
|
|
| 39 |
lines=3,
|
| 40 |
placeholder="Elements to avoid in generation..."
|
| 41 |
)
|
| 42 |
+
|
| 43 |
steps = gr.Slider(
|
| 44 |
minimum=1, maximum=100, value=25, step=1,
|
| 45 |
label="Inference Steps",
|
| 46 |
info="More steps = better quality but slower"
|
| 47 |
)
|
| 48 |
+
|
| 49 |
width = gr.Number(
|
| 50 |
value=2048, precision=0,
|
| 51 |
label="Width (pixels)",
|
| 52 |
info="Output image width"
|
| 53 |
)
|
| 54 |
+
|
| 55 |
# Tiling options
|
| 56 |
with gr.Row():
|
| 57 |
tile_x = gr.Checkbox(True, label="X-axis Seamless Tiling")
|
| 58 |
tile_y = gr.Checkbox(False, label="Y-axis Seamless Tiling")
|
| 59 |
+
|
| 60 |
# Generate button and outputs
|
| 61 |
with gr.Row():
|
| 62 |
gen_btn = gr.Button("✨ Generate Image", variant="secondary", size="lg")
|
| 63 |
+
|
| 64 |
with gr.Row():
|
| 65 |
image_output = gr.Image(
|
| 66 |
label="Result",
|
|
|
|
| 68 |
show_label=True
|
| 69 |
)
|
| 70 |
with gr.Column():
|
| 71 |
+
gen_status = gr.HTML(
|
| 72 |
label="Generation Status",
|
| 73 |
+
value='<div class="status-success">✅ Ready to generate</div>',
|
| 74 |
)
|
| 75 |
+
|
| 76 |
+
gen_progress = gr.Textbox(
|
| 77 |
+
label="Generation Progress",
|
| 78 |
+
placeholder="Ready to generate...",
|
| 79 |
+
show_label=True,
|
| 80 |
+
visible=False
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
gr.HTML("""
|
| 84 |
<div style="margin-top: 16px; padding: 12px; background: #f3f4f6; border-radius: 8px;">
|
| 85 |
<strong>💡 Tips:</strong>
|
|
|
|
| 90 |
</ul>
|
| 91 |
</div>
|
| 92 |
""")
|
| 93 |
+
|
| 94 |
return (
|
| 95 |
prompt, negative_prompt, cfg, steps, height, width,
|
| 96 |
+
tile_x, tile_y, gen_btn, image_output, gen_status, gen_progress
|
| 97 |
)
|
| 98 |
|
| 99 |
|
| 100 |
def setup_generator_events(
|
| 101 |
prompt, negative_prompt, cfg, steps, height, width,
|
| 102 |
+
tile_x, tile_y, gen_btn, image_output, gen_status, gen_progress
|
| 103 |
):
|
| 104 |
"""Setup event handlers for the generator tab."""
|
| 105 |
+
|
| 106 |
+
def on_generate_start():
|
| 107 |
+
return (
|
| 108 |
+
'<div class="status-warning">⏳ Generating image...</div>',
|
| 109 |
+
"Starting generation...",
|
| 110 |
+
gr.update(interactive=False),
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
def on_generate_complete(img, msg):
|
| 114 |
+
"""Read both image and last progress message to determine success/failure."""
|
| 115 |
+
if img is None:
|
| 116 |
+
return (
|
| 117 |
+
f'<div class="status-error">{msg}</div>',
|
| 118 |
+
"",
|
| 119 |
+
gr.update(interactive=True),
|
| 120 |
+
gr.update(),
|
| 121 |
+
)
|
| 122 |
+
return (
|
| 123 |
+
'<div class="status-success">✅ Generation complete!</div>',
|
| 124 |
+
"Done",
|
| 125 |
+
gr.update(interactive=True),
|
| 126 |
+
gr.update(value=img),
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
gen_btn.click(
|
| 130 |
+
fn=on_generate_start,
|
| 131 |
+
inputs=[],
|
| 132 |
+
outputs=[gen_status, gen_progress, gen_btn],
|
| 133 |
+
).then(
|
| 134 |
fn=generate_image,
|
| 135 |
inputs=[prompt, negative_prompt, cfg, steps, height, width, tile_x, tile_y],
|
| 136 |
+
outputs=[image_output, gen_progress],
|
| 137 |
+
).then(
|
| 138 |
+
fn=on_generate_complete,
|
| 139 |
+
inputs=[image_output, gen_progress],
|
| 140 |
+
outputs=[gen_status, gen_progress, gen_btn, image_output],
|
| 141 |
)
|