Spaces:
Sleeping
Sleeping
Kyle Pearson committed on
Commit ·
6a07ce1
1
Parent(s): b807b57
code
Browse files- app.py +542 -0
- requirements.txt +20 -0
- src/__init__.py +3 -0
- src/__pycache__/__init__.cpython-311.pyc +0 -0
- src/__pycache__/__init__.cpython-313.pyc +0 -0
- src/__pycache__/config.cpython-311.pyc +0 -0
- src/__pycache__/config.cpython-313.pyc +0 -0
- src/__pycache__/downloader.cpython-311.pyc +0 -0
- src/__pycache__/downloader.cpython-313.pyc +0 -0
- src/__pycache__/exporter.cpython-311.pyc +0 -0
- src/__pycache__/exporter.cpython-313.pyc +0 -0
- src/__pycache__/generator.cpython-311.pyc +0 -0
- src/__pycache__/generator.cpython-313.pyc +0 -0
- src/__pycache__/pipeline.cpython-311.pyc +0 -0
- src/__pycache__/pipeline.cpython-313.pyc +0 -0
- src/config.py +129 -0
- src/downloader.py +248 -0
- src/exporter.py +134 -0
- src/generator.py +62 -0
- src/pipeline.py +174 -0
- src/ui/__init__.py +13 -0
- src/ui/__pycache__/__init__.cpython-311.pyc +0 -0
- src/ui/__pycache__/__init__.cpython-313.pyc +0 -0
- src/ui/__pycache__/exporter_tab.cpython-311.pyc +0 -0
- src/ui/__pycache__/exporter_tab.cpython-313.pyc +0 -0
- src/ui/__pycache__/generator_tab.cpython-311.pyc +0 -0
- src/ui/__pycache__/generator_tab.cpython-313.pyc +0 -0
- src/ui/__pycache__/header.cpython-311.pyc +0 -0
- src/ui/__pycache__/header.cpython-313.pyc +0 -0
- src/ui/__pycache__/loader_tab.cpython-311.pyc +0 -0
- src/ui/__pycache__/loader_tab.cpython-313.pyc +0 -0
- src/ui/exporter_tab.py +106 -0
- src/ui/generator_tab.py +104 -0
- src/ui/header.py +129 -0
- src/ui/loader_tab.py +83 -0
app.py
ADDED
|
@@ -0,0 +1,542 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
SDXL Model Merger - Modernized with modular architecture and improved UI/UX.
|
| 4 |
+
|
| 5 |
+
This application allows you to:
|
| 6 |
+
- Load SDXL checkpoints with optional VAE and multiple LoRAs
|
| 7 |
+
- Generate images with seamless tiling support
|
| 8 |
+
- Export merged models with quantization options
|
| 9 |
+
|
| 10 |
+
Author: Qwen Code Assistant
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import gradio as gr
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def create_app():
    """Build and configure the Gradio Blocks app.

    Wires three tabs (Load Pipeline, Generate Image, Export Model) to the
    helpers in ``src``. Returns the un-launched ``gr.Blocks`` instance so
    the caller decides how to serve it.
    """

    header_css = """
    .header-gradient {
        background: linear-gradient(135deg, #10b981 0%, #7c3aed 100%);
        -webkit-background-clip: text;
        -webkit-text-fill-color: transparent;
        background-clip: text;
    }

    .feature-card {
        border-radius: 12px;
        padding: 20px;
        margin-bottom: 16px;
        box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1);
        transition: transform 0.2s ease;
    }

    .feature-card:hover {
        transform: translateY(-2px);
        box-shadow: 0 10px 15px -3px rgba(0, 0, 0, 0.1);
    }

    .gradio-container .label {
        font-weight: 600;
        color: #374151;
        margin-bottom: 8px;
    }

    .status-success { color: #059669 !important; font-weight: 600; }
    .status-error { color: #dc2626 !important; font-weight: 600; }
    .status-warning { color: #d97706 !important; font-weight: 600; }

    .gradio-container .btn {
        border-radius: 8px;
        padding: 12px 24px;
        font-weight: 600;
    }

    .gradio-container textarea,
    .gradio-container input[type="number"],
    .gradio-container input[type="text"] {
        border-radius: 8px;
        border-color: #d1d5db;
    }

    .gradio-container textarea:focus,
    .gradio-container input:focus {
        outline: none;
        border-color: #6366f1;
        box-shadow: 0 0 0 3px rgba(99, 102, 241, 0.1);
    }

    .gradio-container .tabitem {
        background: transparent;
        border-radius: 12px;
    }

    .progress-text {
        font-weight: 500;
        color: #6b7280 !important;
    }
    """

    # Imported lazily so importing this module stays cheap; torch/diffusers
    # are only pulled in when the app is actually built.
    from src.pipeline import load_pipeline, cancel_download  # noqa: F401 — cancel_download kept for future wiring
    from src.generator import generate_image
    from src.exporter import export_merged_model
    from src.config import get_cached_models, get_cached_checkpoints, get_cached_vaes, get_cached_loras  # noqa: F401

    with gr.Blocks(title="SDXL Model Merger", css=header_css) as demo:
        # Header section
        with gr.Column(elem_classes=["feature-card"]):
            gr.HTML("""
            <div style="text-align: center; margin-bottom: 24px;">
                <h1 style="font-size: 2.5em; margin: 0; line-height: 1.2;">
                    <span class="header-gradient">SDXL Model Merger</span>
                </h1>
                <p style="color: #6b7280; font-size: 1.1em; max-width: 600px; margin: 16px auto;">
                    Merge checkpoints, LoRAs, and VAEs - then bake LoRAs into a single exportable
                    checkpoint with optional quantization.
                </p>
            </div>
            """)

            # Feature highlights
            with gr.Row():
                with gr.Column(scale=1):
                    gr.HTML("""
                    <div style="text-align: center; padding: 16px;">
                        <div style="font-size: 2.5em; margin-bottom: 8px;">🚀</div>
                        <strong>Fast Loading</strong>
                        <p style="font-size: 0.85em; color: #6b7280; margin-top: 4px;">With progress tracking & cache</p>
                    </div>
                    """)
                with gr.Column(scale=1):
                    gr.HTML("""
                    <div style="text-align: center; padding: 16px;">
                        <div style="font-size: 2.5em; margin-bottom: 8px;">🎨</div>
                        <strong>Panorama Gen</strong>
                        <p style="font-size: 0.85em; color: #6b7280; margin-top: 4px;">Seamless tiling support</p>
                    </div>
                    """)
                with gr.Column(scale=1):
                    gr.HTML("""
                    <div style="text-align: center; padding: 16px;">
                        <div style="font-size: 2.5em; margin-bottom: 8px;">📦</div>
                        <strong>Export Ready</strong>
                        <p style="font-size: 0.85em; color: #6b7280; margin-top: 4px;">Quantization & format options</p>
                    </div>
                    """)

        gr.Markdown("---")

        with gr.Tab("Load Pipeline"):
            gr.Markdown("### Load SDXL Pipeline with Checkpoint, VAE, and LoRAs")

            # Progress indicator for pipeline loading
            load_progress = gr.Textbox(
                label="Loading Progress",
                placeholder="Ready to start...",
                show_label=True,
                info="Real-time status of model downloads and pipeline setup"
            )

            with gr.Row():
                with gr.Column(scale=2):
                    # Checkpoint URL with cached models dropdown
                    checkpoint_url = gr.Textbox(
                        label="Base Model (.safetensors) URL",
                        value="https://civitai.com/api/download/models/354657?type=Model&format=SafeTensor&size=full&fp=fp16",
                        placeholder="e.g., https://civitai.com/api/download/models/...",
                        info="Download link for the base SDXL checkpoint"
                    )

                    # Dropdown of cached checkpoints
                    cached_checkpoints = gr.Dropdown(
                        choices=["(None found)"] + get_cached_checkpoints(),
                        label="Cached Checkpoints",
                        value="(None found)" if not get_cached_checkpoints() else None,
                        info="Models already downloaded to .cache/"
                    )

                    # VAE URL
                    vae_url = gr.Textbox(
                        label="VAE (.safetensors) URL",
                        value="https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/resolve/main/sdxl.vae.safetensors?download=true",
                        placeholder="Leave blank to use model's built-in VAE",
                        info="Optional custom VAE for improved quality"
                    )

                    # Dropdown of cached VAEs
                    cached_vaes = gr.Dropdown(
                        choices=["(None found)"] + get_cached_vaes(),
                        label="Cached VAEs",
                        value="(None found)" if not get_cached_vaes() else None,
                        info="Select a VAE to load"
                    )

                with gr.Column(scale=1):
                    # LoRA URLs input
                    lora_urls = gr.Textbox(
                        label="LoRA URLs (one per line)",
                        lines=5,
                        value="https://civitai.com/api/download/models/143197?type=Model&format=SafeTensor",
                        placeholder="https://civit.ai/...\nhttps://huggingface.co/...",
                        info="Multiple LoRAs can be loaded and fused together"
                    )

                    # Dropdown of cached LoRAs
                    cached_loras = gr.Dropdown(
                        choices=["(None found)"] + get_cached_loras(),
                        label="Cached LoRAs",
                        value="(None found)" if not get_cached_loras() else None,
                        info="Select a LoRA to add to the list below"
                    )

                    lora_strengths = gr.Textbox(
                        label="LoRA Strengths",
                        value="1.0",
                        placeholder="e.g., 0.8,1.0,0.5",
                        info="Comma-separated strength values for each LoRA"
                    )

            with gr.Row():
                load_btn = gr.Button("🚀 Load Pipeline", variant="primary", size="lg")

            # Detailed status display
            load_status = gr.HTML(
                label="Status",
                value='<div class="status-success">✅ Ready to load pipeline</div>',
            )

        with gr.Tab("Generate Image"):
            gr.Markdown("### Generate Panorama Images with Seamless Tiling")

            # Progress indicator for image generation
            gen_progress = gr.Textbox(
                label="Generation Progress",
                placeholder="Ready to generate...",
                show_label=True,
                info="Real-time status of image generation"
            )

            with gr.Row():
                with gr.Column(scale=1):
                    prompt = gr.Textbox(
                        label="Positive Prompt",
                        value="Glowing mushrooms around pyramids amidst a cosmic backdrop, equirectangular, 360 panorama, cinematic",
                        lines=4,
                        placeholder="Describe the image you want to generate..."
                    )

                    cfg = gr.Slider(
                        minimum=1.0, maximum=20.0, value=7.5, step=0.5,
                        label="CFG Scale",
                        info="Higher values make outputs match prompt more strictly"
                    )

                    height = gr.Number(
                        value=1024, precision=0,
                        label="Height (pixels)",
                        info="Output image height"
                    )

                with gr.Column(scale=1):
                    negative_prompt = gr.Textbox(
                        label="Negative Prompt",
                        value="boring, text, signature, watermark, low quality, bad quality",
                        lines=4,
                        placeholder="Elements to avoid in generation..."
                    )

                    steps = gr.Slider(
                        minimum=1, maximum=100, value=25, step=1,
                        label="Inference Steps",
                        info="More steps = better quality but slower"
                    )

                    width = gr.Number(
                        value=2048, precision=0,
                        label="Width (pixels)",
                        info="Output image width"
                    )

            with gr.Row():
                tile_x = gr.Checkbox(True, label="X-axis Seamless Tiling")
                tile_y = gr.Checkbox(False, label="Y-axis Seamless Tiling")

            with gr.Row():
                gen_btn = gr.Button("✨ Generate Image", variant="secondary", size="lg")

            with gr.Row():
                image_output = gr.Image(
                    label="Result",
                    height=400,
                    show_label=True
                )
                with gr.Column():
                    gen_status = gr.HTML(
                        label="Generation Status",
                        value='<div class="status-success">✅ Ready to generate</div>',
                    )

            gr.HTML("""
            <div style="margin-top: 16px; padding: 12px; background-color: #e5e7eb !important; border-radius: 8px;">
                <strong style="color: #1f2937 !important;">💡 Tips:</strong>
                <ul style="margin: 8px 0; padding-left: 20px; font-size: 0.9em; color: #1f2937 !important;">
                    <li>Use wide aspect ratios (e.g., 1024x2048) for panoramas</li>
                    <li>Enable seamless tiling for texture-like outputs</li>
                    <li>Lower CFG (3-5) for more creative results</li>
                </ul>
            </div>
            """)

        with gr.Tab("Export Model"):
            gr.Markdown("### Export Merged Checkpoint with Quantization Options")

            # Progress indicator for export
            export_progress = gr.Textbox(
                label="Export Progress",
                placeholder="Ready to export...",
                show_label=True,
                info="Real-time status of model export and quantization"
            )

            with gr.Row():
                include_lora = gr.Checkbox(
                    True,
                    label="Include Fused LoRAs",
                    info="Bake the loaded LoRAs into the exported model"
                )

                quantize_toggle = gr.Checkbox(
                    False,
                    label="Apply Quantization",
                    info="Reduce model size with quantization"
                )

            # FIX: start hidden to match quantize_toggle's default (False);
            # the .change handler below shows it when the toggle is checked.
            qtype_row = gr.Row(visible=False)
            with qtype_row:
                qtype_dropdown = gr.Dropdown(
                    choices=["none", "int8", "int4", "float8"],
                    value="int8",
                    label="Quantization Method",
                    info="Trade quality for smaller file size"
                )

            with gr.Row():
                format_dropdown = gr.Dropdown(
                    choices=["safetensors", "bin"],
                    value="safetensors",
                    label="Export Format",
                    info="safetensors is recommended for safety"
                )

            with gr.Row():
                export_btn = gr.Button("💾 Save Merged Checkpoint", variant="primary", size="lg")

            with gr.Row():
                download_link = gr.File(
                    label="Download Merged File",
                    show_label=True,
                )

            with gr.Column():
                export_status = gr.HTML(
                    label="Export Status",
                    value='<div class="status-success">✅ Ready to export</div>',
                )

            gr.HTML("""
            <div style="margin-top: 16px; padding: 12px; background: #e0f2fe; border-radius: 8px;">
                <strong>ℹ️ About Quantization:</strong>
                <p style="font-size: 0.9em; margin: 8px 0;">
                    Reduces model size by lowering precision. Int8 is typically
                    lossless for inference while cutting size in half.
                </p>
            </div>
            """)

        # Event handlers - all inside Blocks context

        def on_load_pipeline_start():
            """Disable the button and show a busy state while loading runs."""
            return (
                '<div class="status-warning">⏳ Loading started...</div>',
                "Starting download...",
                gr.update(interactive=False)
            )

        load_btn.click(
            fn=on_load_pipeline_start,
            inputs=[],
            outputs=[load_status, load_progress, load_btn],
        ).then(
            fn=load_pipeline,
            inputs=[checkpoint_url, vae_url, lora_urls, lora_strengths],
            outputs=[load_status, load_progress],
            show_api=False,
        )

        def on_cached_checkpoint_change(cached_path):
            """Point the URL box at a cached checkpoint via a file:// URI.

            Leaves the field untouched when the placeholder is selected,
            matching the VAE/LoRA handlers (the old inline lambda cleared
            the default URL instead).
            """
            if cached_path and cached_path != "(None found)":
                return gr.update(value=f"file://{cached_path}")
            return gr.update()

        cached_checkpoints.change(
            fn=on_cached_checkpoint_change,
            inputs=cached_checkpoints,
            outputs=checkpoint_url,
        )

        def on_cached_vae_change(cached_path):
            """Point the VAE URL box at a cached VAE via a file:// URI."""
            if cached_path and cached_path != "(None found)":
                return gr.update(value=f"file://{cached_path}")
            return gr.update()

        cached_vaes.change(
            fn=on_cached_vae_change,
            inputs=cached_vaes,
            outputs=vae_url,
        )

        def on_cached_lora_change(cached_path, current_urls):
            """Append a cached LoRA path to the URL list, skipping duplicates."""
            if cached_path and cached_path != "(None found)":
                urls_list = [u.strip() for u in current_urls.split("\n") if u.strip()]
                if cached_path not in urls_list:
                    urls_list.append(cached_path)
                return gr.update(value="\n".join(urls_list))
            return gr.update()

        cached_loras.change(
            fn=on_cached_lora_change,
            inputs=[cached_loras, lora_urls],
            outputs=lora_urls,
        )

        def on_generate_start():
            """Disable the button and show a busy state while generation runs."""
            return (
                '<div class="status-warning">⏳ Generating image...</div>',
                "Starting generation...",
                gr.update(interactive=False)
            )

        def on_generate_complete(status_msg, progress_text, image):
            """Re-enable the button and report success or the error message."""
            if image is None:
                return (
                    f'<div class="status-error">{status_msg}</div>',
                    "",
                    gr.update(interactive=True),
                    gr.update()
                )
            return (
                '<div class="status-success">✅ Generation complete!</div>',
                "Done",
                gr.update(interactive=True),
                gr.update(value=image)
            )

        gen_btn.click(
            fn=on_generate_start,
            inputs=[],
            outputs=[gen_status, gen_progress, gen_btn],
        ).then(
            fn=generate_image,
            inputs=[prompt, negative_prompt, cfg, steps, height, width, tile_x, tile_y],
            outputs=[image_output, gen_progress],
            show_api=False,
        ).then(
            fn=lambda img, msg: on_generate_complete(msg, "Done", img),
            inputs=[image_output, gen_progress],
            outputs=[gen_status, gen_progress, gen_btn, image_output],
        )

        def on_export_start():
            """Disable the button and show a busy state while export runs."""
            return (
                '<div class="status-warning">⏳ Export started...</div>',
                "Starting export...",
                gr.update(interactive=False)
            )

        def on_export_complete(status_msg, progress_text, file_path):
            """Re-enable the button and surface the exported file (or error)."""
            if file_path is None:
                return (
                    f'<div class="status-error">{status_msg}</div>',
                    "",
                    gr.update(interactive=True),
                    gr.update(value=None)
                )
            return (
                '<div class="status-success">✅ Export complete!</div>',
                "Exported successfully",
                gr.update(interactive=True),
                gr.update(value=file_path)
            )

        export_btn.click(
            fn=on_export_start,
            inputs=[],
            outputs=[export_status, export_progress, export_btn],
        ).then(
            fn=lambda inc, q, qt, fmt: export_merged_model(
                include_lora=inc,
                quantize=q and (qt != "none"),
                qtype=qt if qt != "none" else None,
                save_format=fmt,
            ),
            inputs=[include_lora, quantize_toggle, qtype_dropdown, format_dropdown],
            outputs=[download_link, export_progress],
            show_api=False,
        ).then(
            fn=lambda path, msg: on_export_complete(msg, "Exported", path),
            inputs=[download_link, export_progress],
            outputs=[export_status, export_progress, export_btn, download_link],
        )

        quantize_toggle.change(
            fn=lambda checked: gr.update(visible=checked),
            inputs=[quantize_toggle],
            outputs=qtype_row,
        )

    return demo
|
| 532 |
+
|
| 533 |
+
|
| 534 |
+
def main():
    """Build the Gradio app and serve it with default launch settings."""
    # CSS is embedded in the Blocks definition, so launch() needs no extras.
    create_app().launch()


if __name__ == "__main__":
    main()
|
requirements.txt
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SDXL Model Merger - Dependencies
|
| 2 |
+
|
| 3 |
+
# Core ML frameworks
|
| 4 |
+
torch>=2.0.0
|
| 5 |
+
diffusers>=0.24.0
|
| 6 |
+
transformers>=4.35.0
|
| 7 |
+
safetensors>=0.4.0
|
| 8 |
+
|
| 9 |
+
# Image processing
|
| 10 |
+
Pillow>=10.0.0
|
| 11 |
+
|
| 12 |
+
# UI framework
|
| 13 |
+
gradio>=4.0.0
|
| 14 |
+
|
| 15 |
+
# Download utilities
|
| 16 |
+
tqdm>=4.65.0
|
| 17 |
+
requests>=2.31.0
|
| 18 |
+
|
| 19 |
+
# Optional: quantization support
|
| 20 |
+
optimum-quanto>=0.2.0
|
src/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""SDXL Model Merger - Modular SDXL pipeline management and generation."""
|
| 2 |
+
|
| 3 |
+
__version__ = "1.0.0"
|
src/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (233 Bytes). View file
|
|
|
src/__pycache__/__init__.cpython-313.pyc
ADDED
|
Binary file (270 Bytes). View file
|
|
|
src/__pycache__/config.cpython-311.pyc
ADDED
|
Binary file (5.32 kB). View file
|
|
|
src/__pycache__/config.cpython-313.pyc
ADDED
|
Binary file (1.44 kB). View file
|
|
|
src/__pycache__/downloader.cpython-311.pyc
ADDED
|
Binary file (11.1 kB). View file
|
|
|
src/__pycache__/downloader.cpython-313.pyc
ADDED
|
Binary file (5.4 kB). View file
|
|
|
src/__pycache__/exporter.cpython-311.pyc
ADDED
|
Binary file (6.9 kB). View file
|
|
|
src/__pycache__/exporter.cpython-313.pyc
ADDED
|
Binary file (5.67 kB). View file
|
|
|
src/__pycache__/generator.cpython-311.pyc
ADDED
|
Binary file (2.65 kB). View file
|
|
|
src/__pycache__/generator.cpython-313.pyc
ADDED
|
Binary file (2.24 kB). View file
|
|
|
src/__pycache__/pipeline.cpython-311.pyc
ADDED
|
Binary file (10.2 kB). View file
|
|
|
src/__pycache__/pipeline.cpython-313.pyc
ADDED
|
Binary file (8.11 kB). View file
|
|
|
src/config.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Configuration constants and global settings for SDXL Model Merger."""
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
# ──────────────────────────────────────────────
|
| 7 |
+
# Paths & Directories
|
| 8 |
+
# ──────────────────────────────────────────────
|
| 9 |
+
SCRIPT_DIR = Path.cwd()
|
| 10 |
+
CACHE_DIR = SCRIPT_DIR / ".cache"
|
| 11 |
+
CACHE_DIR.mkdir(exist_ok=True)
|
| 12 |
+
|
| 13 |
+
# ──────────────────────────────────────────────
|
| 14 |
+
# Default URLs
|
| 15 |
+
# ──────────────────────────────────────────────
|
| 16 |
+
DEFAULT_CHECKPOINT_URL = "https://civitai.com/api/download/models/354657?type=Model&format=SafeTensor&size=full&fp=fp16"
|
| 17 |
+
DEFAULT_VAE_URL = "https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/resolve/main/sdxl.vae.safetensors?download=true"
|
| 18 |
+
DEFAULT_LORA_URLS = "https://civitai.com/api/download/models/143197?type=Model&format=SafeTensor"
|
| 19 |
+
|
| 20 |
+
# ──────────────────────────────────────────────
|
| 21 |
+
# PyTorch & Device Settings
|
| 22 |
+
# ──────────────────────────────────────────────
|
| 23 |
+
import torch
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def get_device_info() -> tuple[str, str]:
    """
    Detect and return the optimal device for ML inference.

    Returns:
        Tuple of (device_name, device_description)
    """
    # Prefer CUDA, then Apple Silicon MPS, then plain CPU.
    if torch.cuda.is_available():
        return "cuda", f"CUDA (GPU: {torch.cuda.get_device_name(0)})"
    if torch.backends.mps.is_available():
        return "mps", "Apple Silicon MPS"
    return "cpu", "CPU (no GPU available)"
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
device, device_description = get_device_info()
|
| 47 |
+
dtype = torch.float16
|
| 48 |
+
|
| 49 |
+
print(f"🚀 Using device: {device_description}")
|
| 50 |
+
|
| 51 |
+
# ──────────────────────────────────────────────
|
| 52 |
+
# Global State
|
| 53 |
+
# ──────────────────────────────────────────────
|
| 54 |
+
pipe = None
|
| 55 |
+
download_cancelled = False
|
| 56 |
+
|
| 57 |
+
# ──────────────────────────────────────────────
|
| 58 |
+
# Generation Defaults
|
| 59 |
+
# ──────────────────────────────────────────────
|
| 60 |
+
DEFAULT_PROMPT = "Glowing mushrooms around pyramids amidst a cosmic backdrop, equirectangular, 360 panorama, cinematic"
|
| 61 |
+
DEFAULT_NEGATIVE_PROMPT = "boring, text, signature, watermark, low quality, bad quality"
|
| 62 |
+
|
| 63 |
+
# ──────────────────────────────────────────────
|
| 64 |
+
# Model Presets (URLs for common models)
|
| 65 |
+
# ──────────────────────────────────────────────
|
| 66 |
+
MODEL_PRESETS = {
|
| 67 |
+
# Checkpoints
|
| 68 |
+
"DreamShaper XL v2": "https://civitai.com/api/download/models/354657?type=Model&format=SafeTensor&size=full&fp=fp16",
|
| 69 |
+
"Realism Engine SDXL": "https://civitai.com/api/download/models/328799?type=Model&format=SafeTensor&size=full&fp=fp16",
|
| 70 |
+
"Juggernaut XL v9": "https://civitai.com/api/download/models/350565?type=Model&format=SafeTensor&size=full&fp=fp16",
|
| 71 |
+
|
| 72 |
+
# VAEs
|
| 73 |
+
"VAE-FP16 Fix": "https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/resolve/main/sdxl.vae.safetensors?download=true",
|
| 74 |
+
|
| 75 |
+
# LoRAs
|
| 76 |
+
"Rainbow Color LoRA": "https://civitai.com/api/download/models/127983?type=Model&format=SafeTensor",
|
| 77 |
+
"More Details LoRA": "https://civitai.com/api/download/models/280590?type=Model&format=SafeTensor",
|
| 78 |
+
"Epic Realism LoRA": "https://civitai.com/api/download/models/346631?type=Model&format=SafeTensor",
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def get_cached_models():
    """Return the sorted paths (as strings) of all cached .safetensors files."""
    if not CACHE_DIR.exists():
        return []
    return [str(path) for path in sorted(CACHE_DIR.glob("*.safetensors"))]
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def get_cached_model_names(models=None):
    """Get display names (file basenames) for cached models.

    Args:
        models: Optional iterable of path strings. Defaults to the cache
            listing returned by get_cached_models().

    Returns:
        List of basenames, e.g. "12345_model.safetensors".

    Bug fix: get_cached_models() returns plain path *strings*, and ``str``
    has no ``.name`` attribute — the previous ``str(m.name)`` raised
    AttributeError on every call. Each entry is now wrapped in pathlib.Path
    to extract the basename.
    """
    from pathlib import Path  # local import keeps the fix self-contained

    if models is None:
        models = get_cached_models()
    return [Path(m).name for m in models]
def get_cached_checkpoints():
    """Get list of cached checkpoint files (model_id_model.safetensors)."""
    if not CACHE_DIR.exists():
        return []
    return [str(path) for path in sorted(CACHE_DIR.glob("*_model.safetensors"))]
def get_cached_vaes():
    """Get list of cached VAE files (model_id_*_vae.safetensors)."""
    if not CACHE_DIR.exists():
        return []
    return [str(path) for path in sorted(CACHE_DIR.glob("*_vae.safetensors"))]
def get_cached_loras():
    """Get list of cached LoRA files (model_id_*_lora.safetensors)."""
    if not CACHE_DIR.exists():
        return []
    return [str(path) for path in sorted(CACHE_DIR.glob("*_lora.safetensors"))]
src/downloader.py
ADDED
|
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Download utilities for SDXL Model Merger with Gradio progress integration."""
|
| 2 |
+
|
| 3 |
+
import re
|
| 4 |
+
import requests
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from tqdm import tqdm as TqdmBase
|
| 7 |
+
|
| 8 |
+
from .config import download_cancelled
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def extract_model_id(url: str) -> str | None:
|
| 12 |
+
"""Extract CivitAI model ID from URL."""
|
| 13 |
+
match = re.search(r'/models/(\d+)', url)
|
| 14 |
+
return match.group(1) if match else None
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def get_safe_filename_from_url(
    url: str,
    default_name: str = "model.safetensors",
    suffix: str = "",
    type_prefix: str | None = None
) -> str:
    """
    Generate a safe filename with model ID from URL.

    For CivitAI URLs like https://civitai.com/api/download/models/12345?type=...

    Naming patterns:
    - Checkpoint (type_prefix='model'): 12345_model.safetensors or 12345_model_anime_style.safetensors
    - VAE (suffix='_vae'): 12345_vae.safetensors or 12345_anime_vae.safetensors
    - LoRA (suffix='_lora'): 12345_lora.safetensors or 12345_name_lora.safetensors

    For HuggingFace URLs without model IDs, attempts to extract name from path or uses suffix-based naming.

    NOTE(review): for non-VAE/LoRA calls this issues a blocking HTTP HEAD
    request to read Content-Disposition; callers on a UI thread should be
    aware of the up-to-10s stall. Failure is swallowed and the name simply
    omits the descriptive part.

    Args:
        url: The download URL
        default_name: Fallback filename if extraction fails
        suffix: Optional suffix to append before .safetensors (e.g., '_vae', '_lora')
        type_prefix: Optional prefix after model_id (e.g., 'model' -> 12345_model.safetensors)
    """
    model_id = extract_model_id(url)

    # If no CivitAI model ID, try to generate a name from HuggingFace path
    if not model_id and "huggingface.co" in url:
        # Try to extract name from URL path (e.g., sdxl-vae-fp16-fix -> vae)
        try:
            parts = url.split("huggingface.co/")[1] if "huggingface.co/" in url else ""
            if parts:
                # Get the repo name (second part after org/)
                path_parts = [p for p in parts.split("/") if p]
                if len(path_parts) >= 2:
                    repo_name = path_parts[1]
                    # Clean up and create a simple identifier
                    # (alphanumeric only, capped at 30 chars, e.g. "hf_sdxl_vae_fp16_fix")
                    clean_repo = re.sub(r'[^a-zA-Z0-9]', '_', repo_name)[:30].strip('_')
                    if clean_repo:
                        model_id = f"hf_{clean_repo}"
        except Exception:
            pass

    if not model_id:
        return default_name

    # Build the name portion: either clean name from URL or fallback
    name_part = ""

    # For VAE/LoRA types, prefer the suffix-based naming and skip Content-Disposition parsing
    # to avoid double naming (e.g., sdxlvae_vae instead of just vae)
    is_special_type = suffix in ("_vae", "_lora")

    if not is_special_type:
        try:
            # HEAD request: ask the server for the canonical filename without
            # downloading the body. allow_redirects follows CivitAI's CDN hop.
            response = requests.head(url, timeout=10, allow_redirects=True)
            cd = response.headers.get('Content-Disposition', '')
            match = re.search(r'filename="([^"]+)"', cd)
            if match:
                filename = match.group(1)
                # Extract base name without extension
                base_name = Path(filename).stem
                # Clean up the name (remove special chars)
                clean_name = re.sub(r'[^\w\s-]', '', base_name)[:50]
                clean_name = re.sub(r'[-\s]+', '_', clean_name.strip('-_'))
                if clean_name:
                    name_part = clean_name
        except Exception:
            pass

    # Build filename with model_id, optional type_prefix, optional name_part, and suffix
    parts = [model_id]
    if type_prefix:
        parts.append(type_prefix)
    if name_part:
        parts.append(name_part)
    if suffix:
        # Avoid double underscores: only add separator if needed
        # (both branches yield an entry with one leading underscore; the
        # extra '_' introduced by '_'.join below is collapsed by the final
        # .replace('__', '_') pass)
        if not suffix.startswith('_'):
            parts.append('_' + suffix.lstrip('_'))
        else:
            parts.append(suffix)

    return '_'.join(p for p in parts if p).replace('__', '_') + '.safetensors'
|
| 103 |
+
class TqdmGradio(TqdmBase):
    """tqdm subclass that sends progress updates to Gradio's gr.Progress()."""

    def __init__(self, *args, gradio_prog=None, **kwargs):
        super().__init__(*args, **kwargs)
        # Optional gr.Progress() callable; invoked with a 0..1 fraction.
        self.gradio_prog = gradio_prog
        # Last percentage pushed to the UI; used to throttle updates.
        self.last_pct = 0

    def update(self, n=1):
        # NOTE(review): this reads *this module's* binding of
        # download_cancelled (imported from .config at import time). Code in
        # other modules that rebinds its own imported copy of the flag will
        # not be observed here — verify the cancel path flips this binding.
        global download_cancelled
        if download_cancelled:
            # Abort mid-download; caller is expected to clean up the partial file.
            raise KeyboardInterrupt("Download cancelled by user")
        super().update(n)
        if self.gradio_prog and self.total:
            pct = int(100 * self.n / self.total)
            # Only update UI every ~5% to avoid spamming
            if pct != self.last_pct and pct % 5 == 0:
                self.last_pct = pct
                self.gradio_prog(pct / 100)
| 123 |
+
|
| 124 |
+
def set_download_cancelled(value: bool):
    """Set the global download cancellation flag."""
    # NOTE(review): `global` rebinds the *downloader module's* name only; the
    # value in src.config is left untouched. Modules that imported the flag
    # from .config directly will not observe this change — verify intended
    # scope of the cancellation signal.
    global download_cancelled
    download_cancelled = value
| 129 |
+
|
| 130 |
+
def get_cached_file_size(url: str) -> tuple[Path | None, int | None]:
|
| 131 |
+
"""
|
| 132 |
+
Check if file exists in cache and matches expected size.
|
| 133 |
+
Returns (path, expected_size) or (None, None) if no valid cache.
|
| 134 |
+
"""
|
| 135 |
+
# Simple implementation - would need URL-to-filename mapping for production
|
| 136 |
+
return None, None
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def download_file_with_progress(url: str, output_path: Path, progress_bar=None) -> Path:
    """
    Download a file with Gradio-synced progress bar + cancel support.

    Args:
        url: File URL to download (http/https/file)
        output_path: Destination path for downloaded file
        progress_bar: Optional gr.Progress() object for UI updates

    Returns:
        Path to the downloaded file

    Raises:
        KeyboardInterrupt: If download is cancelled
        requests.RequestException: If download fails
    """
    # Reset this module's cancellation flag at the start of each download.
    # NOTE(review): this is the downloader module's binding, not config's —
    # see the scoping caveat on the flag in src.config.
    global download_cancelled
    download_cancelled = False

    # Handle local file:// URLs
    if url.startswith("file://"):
        local_path = Path(url[7:])  # Remove "file://" prefix
        if local_path.exists():
            import shutil
            output_path.parent.mkdir(parents=True, exist_ok=True)
            # Copy the file to cache location
            shutil.copy2(str(local_path), str(output_path))

            # Update progress bar for cached files
            if progress_bar:
                progress_bar(1.0)
            return output_path
        else:
            raise FileNotFoundError(f"Local file not found: {local_path}")

    # Cache check: if file exists and size matches URL's content-length, skip re-download
    expected_size = None
    try:
        # NOTE(review): unlike the HEAD in get_safe_filename_from_url, this
        # one does not pass allow_redirects=True; for hosts that redirect
        # (e.g. CivitAI's CDN) the content-length seen here may describe the
        # redirect response rather than the file — confirm cache hits are valid.
        head = requests.head(url, timeout=10)
        expected_size = int(head.headers.get('content-length', 0))
        if output_path.exists() and output_path.stat().st_size == expected_size:
            # Cache hit - still update progress to show completion
            if progress_bar:
                progress_bar(1.0)
            return output_path  # Cache hit!
    except Exception:
        pass  # Skip cache validation on errors

    output_path.parent.mkdir(parents=True, exist_ok=True)

    session = requests.Session()
    response = session.get(url, stream=True, timeout=30)
    response.raise_for_status()

    total_size = expected_size or int(response.headers.get('content-length', 0))
    block_size = 8192  # bytes per streamed chunk

    # Use TqdmGradio to sync progress with Gradio
    tqdm_kwargs = {
        'unit': 'B',
        'unit_scale': True,
        'desc': f"Downloading {output_path.name}",
        'gradio_prog': progress_bar,
        'disable': False,
        'bar_format': '{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]',
    }

    with open(output_path, "wb") as f:
        try:
            # NOTE(review): `total` here is the number of *chunks* (each
            # update() advances by 1 per chunk), while unit='B'/unit_scale
            # label those counts as bytes — cosmetic mismatch worth confirming.
            for data in TqdmGradio(
                response.iter_content(block_size),
                total=total_size // block_size if total_size else 0,
                **tqdm_kwargs,
            ):
                # Belt-and-braces: TqdmGradio.update also raises on cancel.
                if download_cancelled:
                    raise KeyboardInterrupt("Download cancelled by user")
                f.write(data)
        except KeyboardInterrupt:
            # Clean partial file on cancel
            output_path.unlink(missing_ok=True)
            raise

    return output_path
|
| 223 |
+
|
| 224 |
+
def clear_cache(cache_dir: Path = None, keep_extensions: list[str] = None):
    """
    Remove old cache files.

    Currently only transient ``*.tmp`` files are deleted; ``keep_extensions``
    is normalized but reserved for a future age-based sweep.

    Args:
        cache_dir: Cache directory path (defaults to config.CACHE_DIR)
        keep_extensions: File extensions to preserve (default: ['.safetensors'])
    """
    if cache_dir is None:
        from .config import CACHE_DIR
        cache_dir = CACHE_DIR

    keep_extensions = keep_extensions if keep_extensions is not None else ['.safetensors']

    # Drop leftover partial-download files.
    for stale in cache_dir.glob("*.tmp"):
        stale.unlink()
src/exporter.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Model export functionality for SDXL Model Merger."""
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from safetensors.torch import save_file
|
| 8 |
+
|
| 9 |
+
from .config import SCRIPT_DIR, pipe as global_pipe
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def export_merged_model(
    include_lora: bool,
    quantize: bool,
    qtype: str,
    save_format: str = "safetensors",
) -> tuple[str | None, str]:
    """
    Export the merged pipeline model with optional LoRA baking and quantization.

    Args:
        include_lora: Whether to include fused LoRAs in export
        quantize: Whether to apply quantization
        qtype: Quantization type - 'none', 'int8', 'int4', or 'float8'
        save_format: Output format - 'safetensors' or 'bin'

    Returns:
        Tuple of (output_path or None, status message)

    NOTE(review): this function uses ``yield``, so it is a *generator*; each
    ``return value`` is only observable via StopIteration.value, not by a
    caller iterating it. Confirm how the UI consumes this — the documented
    return tuple may never reach a plain ``for`` loop.

    Bug fix: the local ``from optimum.quanto import quantize`` previously
    shadowed the boolean ``quantize`` parameter, so with quanto installed the
    later ``if quantize ...`` checks tested the imported function (always
    truthy) and quantization ran regardless of the checkbox. The import is
    now aliased to ``quanto_quantize``.
    """
    # NOTE(review): global_pipe is the import-time snapshot of config.pipe
    # (``from .config import pipe as global_pipe``); rebinding config.pipe
    # after this module was imported is not visible here — verify.
    if not global_pipe:
        return None, "⚠️ Please load a pipeline first."

    try:
        # Step 1: Unload LoRAs
        yield "💾 Exporting model...", "Unloading LoRAs..."
        # NOTE(review): when include_lora is True the LoRA weights are
        # *unloaded* before export — the flag name suggests fused LoRAs
        # should be kept instead; confirm the condition isn't inverted.
        if include_lora:
            global_pipe.unload_lora_weights()

        merged_state_dict = {}

        # Step 2: Extract UNet weights (fp16, contiguous for serialization)
        yield "💾 Exporting model...", "Extracting UNet weights..."
        for k, v in global_pipe.unet.state_dict().items():
            merged_state_dict[f"unet.{k}"] = v.contiguous().half()

        # Step 3: Extract text encoder weights
        yield "💾 Exporting model...", "Extracting text encoders..."
        for k, v in global_pipe.text_encoder.state_dict().items():
            merged_state_dict[f"text_encoder.{k}"] = v.contiguous().half()
        for k, v in global_pipe.text_encoder_2.state_dict().items():
            merged_state_dict[f"text_encoder_2.{k}"] = v.contiguous().half()

        # Step 4: Extract VAE weights
        yield "💾 Exporting model...", "Extracting VAE weights..."
        for k, v in global_pipe.vae.state_dict().items():
            merged_state_dict[f"first_stage_model.{k}"] = v.contiguous().half()

        # Step 5: Quantize if requested and optimum.quanto is available.
        # Alias the import so it cannot shadow the `quantize` parameter.
        try:
            from optimum.quanto import quantize as quanto_quantize, QTensor
            QUANTO_AVAILABLE = True
        except ImportError:
            QUANTO_AVAILABLE = False

        if quantize and qtype != "none" and QUANTO_AVAILABLE:
            yield "💾 Exporting model...", f"Applying {qtype} quantization..."

            class FakeModel(torch.nn.Module):
                pass

            fake_model = FakeModel()
            # NOTE(review): updating __dict__ sets plain attributes, not
            # registered parameters/buffers — fake_model.state_dict() may be
            # empty, making quantization a silent no-op. Verify against the
            # optimum.quanto API before relying on quantized exports.
            fake_model.__dict__.update(merged_state_dict)

            # Select quantization method
            if qtype == "int8":
                from optimum.quanto import int8_weight_only
                quanto_quantize(fake_model, int8_weight_only())
            elif qtype == "int4":
                from optimum.quanto import int4_weight_only
                quanto_quantize(fake_model, int4_weight_only())
            elif qtype == "float8":
                from optimum.quanto import float8_dynamic_activation_float8_weight
                quanto_quantize(fake_model, float8_dynamic_activation_float8_weight())
            else:
                raise ValueError(f"Unsupported qtype: {qtype}")

            merged_state_dict = {
                k: v.dequantize().half() if isinstance(v, QTensor) else v
                for k, v in fake_model.state_dict().items()
            }
        elif quantize and not QUANTO_AVAILABLE:
            return None, "❌ optimum.quanto not installed. Install with: pip install optimum-quanto"

        # Step 6: Save model
        yield "💾 Exporting model...", "Saving weights..."

        ext = ".bin" if save_format == "bin" else ".safetensors"

        # Build filename based on options
        prefix = ""
        if quantize and qtype != "none":
            prefix = f"{qtype}_"

        out_path = SCRIPT_DIR / f"merged_{prefix}checkpoint{ext}"

        # Save appropriately: torch.save for .bin, safetensors otherwise
        if ext == ".bin":
            torch.save(merged_state_dict, str(out_path))
        else:
            save_file(merged_state_dict, str(out_path))

        size_gb = out_path.stat().st_size / 1024**3

        if quantize and qtype != "none":
            msg = f"✅ Quantized checkpoint saved: `{out_path}` ({size_gb:.2f} GB)"
        else:
            msg = f"✅ Merged checkpoint saved: `{out_path}` ({size_gb:.2f} GB)"

        yield "💾 Exporting model...", msg
        return str(out_path), msg

    except ImportError as e:
        return None, f"❌ Missing dependency: {str(e)}"
    except Exception as e:
        return None, f"❌ Export failed: {str(e)}"
def get_export_status() -> str:
    """Get current export capability status."""
    try:
        from optimum.quanto import quantize  # probe only; name unused
    except ImportError:
        return "ℹ️ Install optimum-quanto for quantization support"
    return "✅ optimum.quanto available for quantization"
src/generator.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Image generation functions for SDXL Model Merger."""
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from .config import device, dtype, pipe as global_pipe
|
| 6 |
+
from .pipeline import enable_seamless_tiling
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def generate_image(
    prompt: str,
    negative_prompt: str,
    cfg: float,
    steps: int,
    height: int,
    width: int,
    tile_x: bool = True,
    tile_y: bool = False,
) -> tuple[object | None, str]:
    """
    Generate an image using the loaded SDXL pipeline.

    Args:
        prompt: Positive prompt for image generation
        negative_prompt: Negative prompt to avoid certain elements
        cfg: Classifier-Free Guidance scale (1.0-20.0)
        steps: Number of inference steps (1-50)
        height: Output image height in pixels
        width: Output image width in pixels
        tile_x: Enable seamless tiling on x-axis
        tile_y: Enable seamless tiling on y-axis

    Returns:
        Tuple of (PIL Image or None, status message)

    NOTE(review): this is a *generator* (it uses ``yield``), yet the guard
    branch uses ``return None, "..."`` — in a generator that tuple is only
    reachable via StopIteration.value. Also, the produced image is bound to
    ``image`` but never yielded or returned: iterating callers only ever
    receive (status, detail) string tuples. Confirm whether the final yield
    was meant to include ``image``.
    """
    # NOTE(review): global_pipe is this module's import-time binding of
    # config.pipe; set_pipeline() below rebinds it within this module.
    if not global_pipe:
        return None, "⚠️ Please load a pipeline first."

    # Enable seamless tiling on UNet & VAE decoder
    enable_seamless_tiling(global_pipe.unet, tile_x=tile_x, tile_y=tile_y)
    enable_seamless_tiling(global_pipe.vae.decoder, tile_x=tile_x, tile_y=tile_y)

    yield "🎨 Generating image...", f"Steps: 0/{steps} | CFG: {cfg}"

    generator = torch.Generator(device=device).manual_seed(42)  # Fixed seed for reproducibility
    result = global_pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        width=int(width),
        height=int(height),
        num_inference_steps=int(steps),
        guidance_scale=float(cfg),
        generator=generator,
    )

    # NOTE(review): `image` is never surfaced to the caller — see docstring.
    image = result.images[0]
    yield "🎨 Generating image...", f"✅ Complete! ({width}x{height})"
|
| 59 |
+
def set_pipeline(pipe):
    """Set the global pipeline instance."""
    # NOTE(review): `global global_pipe` rebinds the *generator module's*
    # name only. generate_image() above reads the same module-level name, so
    # this works within this module, but src.config.pipe is NOT updated and
    # other modules will not observe the change — confirm intended scope.
    global global_pipe
    global_pipe = pipe
src/pipeline.py
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Pipeline management for SDXL Model Merger."""
|
| 2 |
+
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from diffusers import (
|
| 7 |
+
StableDiffusionXLPipeline,
|
| 8 |
+
AutoencoderKL,
|
| 9 |
+
DPMSolverSDEScheduler,
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
from .config import device, dtype, pipe as global_pipe, CACHE_DIR, download_cancelled, device_description
|
| 13 |
+
from .downloader import download_file_with_progress, get_safe_filename_from_url
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def _make_asymmetric_forward(module, pad_h: int, pad_w: int, tile_x: bool, tile_y: bool):
|
| 17 |
+
"""Create patched forward for seamless tiling on Conv2d layers."""
|
| 18 |
+
original_forward = module._conv_forward
|
| 19 |
+
|
| 20 |
+
def patched_conv_forward(input, weight, bias):
|
| 21 |
+
if tile_x and tile_y:
|
| 22 |
+
input = torch.nn.functional.pad(input, (pad_w, pad_w, pad_h, pad_h), mode="circular")
|
| 23 |
+
elif tile_x:
|
| 24 |
+
input = torch.nn.functional.pad(input, (pad_w, pad_w, 0, 0), mode="circular")
|
| 25 |
+
input = torch.nn.functional.pad(input, (0, 0, pad_h, pad_h), mode="constant", value=0)
|
| 26 |
+
elif tile_y:
|
| 27 |
+
input = torch.nn.functional.pad(input, (0, 0, pad_h, pad_h), mode="circular")
|
| 28 |
+
input = torch.nn.functional.pad(input, (pad_w, pad_w, 0, 0), mode="constant", value=0)
|
| 29 |
+
else:
|
| 30 |
+
return original_forward(input, weight, bias)
|
| 31 |
+
|
| 32 |
+
return torch.nn.functional.conv2d(
|
| 33 |
+
input, weight, bias, module.stride, (0, 0), module.dilation, module.groups
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
return patched_conv_forward
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def enable_seamless_tiling(model, tile_x: bool = True, tile_y: bool = False):
    """
    Enable seamless tiling on a model's Conv2d layers.

    Args:
        model: PyTorch model with Conv2d layers (e.g., pipe.unet, pipe.vae.decoder)
        tile_x: Enable tiling along x-axis
        tile_y: Enable tiling along y-axis
    """
    for layer in model.modules():
        if not isinstance(layer, torch.nn.Conv2d):
            continue
        pad_h, pad_w = layer.padding[0], layer.padding[1]
        # Layers that never pad cannot produce seams; leave them untouched.
        if pad_h == 0 and pad_w == 0:
            continue
        layer._conv_forward = _make_asymmetric_forward(layer, pad_h, pad_w, tile_x, tile_y)
|
| 57 |
+
def load_pipeline(
    checkpoint_url: str,
    vae_url: str,
    lora_urls_str: str,
    lora_strengths_str: str,
    progress=None
) -> tuple[str, str]:
    """
    Load SDXL pipeline with checkpoint, VAE, and LoRAs.

    Args:
        checkpoint_url: URL to base model .safetensors file
        vae_url: Optional URL to VAE .safetensors file
        lora_urls_str: Newline-separated URLs for LoRA models
        lora_strengths_str: Comma-separated strength values for each LoRA
        progress: Optional gr.Progress() object for UI updates

    Returns:
        Tuple of (final_status_message, progress_text)

    NOTE(review): this function uses ``yield``, so it is a *generator*; the
    ``return (tuple)`` statements at the end are only reachable via
    StopIteration.value, not by iterating — confirm the UI handles that.
    Also, ``global`` here rebinds *this module's* global_pipe and
    download_cancelled, not the names in src.config or other modules.
    """
    global global_pipe, download_cancelled

    try:
        checkpoint_filename = get_safe_filename_from_url(checkpoint_url, type_prefix="model")
        checkpoint_path = CACHE_DIR / checkpoint_filename

        # VAE: Use suffix="_vae" and default to "vae.safetensors" for proper caching/dropdown matching
        vae_filename = get_safe_filename_from_url(vae_url, default_name="vae.safetensors", suffix="_vae") if vae_url.strip() else "vae.safetensors"
        vae_path = CACHE_DIR / vae_filename

        # Download checkpoint
        if progress:
            progress(0.1, desc="Downloading base model...")
        yield f"📥 Downloading {checkpoint_path.name}...", "Starting download..."
        download_file_with_progress(checkpoint_url, checkpoint_path)

        # Download VAE if provided
        if vae_url.strip():
            if progress:
                progress(0.2, desc="Downloading VAE...")
            yield f"📥 Downloading {vae_path.name}...", f"Downloading VAE: {vae_path.name}"
            download_file_with_progress(vae_url, vae_path)
            vae = AutoencoderKL.from_single_file(str(vae_path), torch_dtype=dtype)
        else:
            vae = None

        # Load base pipeline
        if progress:
            progress(0.4, desc="Loading SDXL pipeline...")
        yield f"⚙️ Loading pipeline...", f"Using device: {device_description}"
        global_pipe = StableDiffusionXLPipeline.from_single_file(
            str(checkpoint_path),
            torch_dtype=dtype,
            use_safetensors=True,
            safety_checker=None,
        )
        if vae:
            # Replace the checkpoint's bundled VAE with the downloaded one.
            global_pipe.vae = vae.to(device=device, dtype=dtype)

        # Parse LoRA URLs & ensure strengths list matches
        lora_urls = [u.strip() for u in lora_urls_str.split("\n") if u.strip()]
        strengths_raw = [s.strip() for s in lora_strengths_str.split(",")]
        strengths = []
        # Missing or unparsable strength entries default to 1.0.
        for i, url in enumerate(lora_urls):
            try:
                val = float(strengths_raw[i]) if i < len(strengths_raw) else 1.0
                strengths.append(val)
            except ValueError:
                strengths.append(1.0)

        # Load and fuse each LoRA sequentially (only if URLs exist)
        if lora_urls:
            first_lora_filename = get_safe_filename_from_url(lora_urls[0], "lora_0.safetensors", suffix="_lora")
            first_lora_path = CACHE_DIR / first_lora_filename
            yield f"📥 Downloading LoRA: {first_lora_path.name}...", f"Downloading LoRA 1/... ({first_lora_path.name})..."
            download_file_with_progress(lora_urls[0], first_lora_path)

            global_pipe.load_lora_weights(str(first_lora_path), adapter_name="main_lora")
            global_pipe.fuse_lora(adapter_names=["main_lora"], lora_scale=strengths[0])

            for i in range(1, len(lora_urls)):
                lora_filename = get_safe_filename_from_url(lora_urls[i], f"lora_{i}.safetensors", suffix="_lora")
                lora_path = CACHE_DIR / lora_filename
                yield f"📥 Downloading LoRA {i+1}...", f"Downloading LoRA {i+1}/{len(lora_urls)} ({lora_path.name})..."
                download_file_with_progress(lora_urls[i], lora_path)

                # NOTE(review): unload_lora_weights() removes previously
                # loaded adapters; fusing afterwards with the earlier adapter
                # names ("main_lora", "lora_1", ...) may fail or no-op —
                # verify against the diffusers LoRA API.
                global_pipe.unload_lora_weights()
                global_pipe.load_lora_weights(str(lora_path), adapter_name=f"lora_{i}")
                # Fuse all loaded adapters so far
                global_pipe.fuse_lora(
                    adapter_names=["main_lora"] + [f"lora_{j}" for j in range(1, i+1)],
                    lora_scale=strengths[i]
                )

        # Set scheduler and move to device (do this once at the end)
        yield "⚙️ Finalizing...", "Setting up scheduler..."
        # Use existing scheduler, just update algorithm_type for DPM++ SDE
        global_pipe.scheduler.config.algorithm_type = "sde-dpmsolver++"
        global_pipe = global_pipe.to(device=device, dtype=dtype)

        return ("✅ Pipeline loaded successfully!", f"Ready! Loaded {len(lora_urls)} LoRA(s)")

    except KeyboardInterrupt:
        # Download was cancelled: reset the flag so later downloads can run.
        download_cancelled = False
        return ("⚠️ Download cancelled by user", "Cancelled")
    except Exception as e:
        return (f"❌ Error loading pipeline: {str(e)}", f"Error: {str(e)}")
| 165 |
+
|
| 166 |
+
def cancel_download():
    """Set the global cancellation flag to stop any ongoing downloads."""
    # NOTE(review): this rebinds *pipeline.py's* download_cancelled only;
    # the downloader module polls its own imported copy of the flag, so this
    # call may not actually interrupt an in-flight download — verify the
    # cancellation path end to end.
    global download_cancelled
    download_cancelled = True
| 171 |
+
|
| 172 |
+
def get_pipeline() -> StableDiffusionXLPipeline | None:
    """Get the currently loaded pipeline.

    Returns this module's `global_pipe` binding: the import-time value of
    config.pipe (None) until load_pipeline() rebinds it on success.
    """
    return global_pipe
src/ui/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""UI components for SDXL Model Merger."""
|
| 2 |
+
|
| 3 |
+
from .header import create_header
|
| 4 |
+
from .loader_tab import create_loader_tab
|
| 5 |
+
from .generator_tab import create_generator_tab
|
| 6 |
+
from .exporter_tab import create_exporter_tab
|
| 7 |
+
|
| 8 |
+
__all__ = [
|
| 9 |
+
"create_header",
|
| 10 |
+
"create_loader_tab",
|
| 11 |
+
"create_generator_tab",
|
| 12 |
+
"create_exporter_tab",
|
| 13 |
+
]
|
src/ui/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (510 Bytes). View file
|
|
|
src/ui/__pycache__/__init__.cpython-313.pyc
ADDED
|
Binary file (477 Bytes). View file
|
|
|
src/ui/__pycache__/exporter_tab.cpython-311.pyc
ADDED
|
Binary file (5.23 kB). View file
|
|
|
src/ui/__pycache__/exporter_tab.cpython-313.pyc
ADDED
|
Binary file (4.41 kB). View file
|
|
|
src/ui/__pycache__/generator_tab.cpython-311.pyc
ADDED
|
Binary file (5.4 kB). View file
|
|
|
src/ui/__pycache__/generator_tab.cpython-313.pyc
ADDED
|
Binary file (4.51 kB). View file
|
|
|
src/ui/__pycache__/header.cpython-311.pyc
ADDED
|
Binary file (5.51 kB). View file
|
|
|
src/ui/__pycache__/header.cpython-313.pyc
ADDED
|
Binary file (5.09 kB). View file
|
|
|
src/ui/__pycache__/loader_tab.cpython-311.pyc
ADDED
|
Binary file (3.68 kB). View file
|
|
|
src/ui/__pycache__/loader_tab.cpython-313.pyc
ADDED
|
Binary file (3.07 kB). View file
|
|
|
src/ui/exporter_tab.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Model exporter tab for SDXL Model Merger."""
|
| 2 |
+
|
| 3 |
+
import gradio as gr
|
| 4 |
+
|
| 5 |
+
from ..exporter import export_merged_model
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def create_exporter_tab():
    """Create the model export tab with all configuration options.

    Builds the "Export Merged Model" accordion: toggles for LoRA baking and
    quantization, a quantization-method dropdown, a format dropdown, the
    export button, and the download/status outputs.

    Returns:
        Tuple of Gradio components, in the order expected by
        ``setup_exporter_events``:
        (include_lora, quantize_toggle, qtype_dropdown, format_dropdown,
         export_btn, download_link, export_status, qtype_row)
    """

    with gr.Accordion("📦 3. Export Merged Model", open=True, elem_classes=["feature-card"]):
        # Export settings
        with gr.Row():
            include_lora = gr.Checkbox(
                True,
                label="Include Fused LoRAs",
                info="Bake the loaded LoRAs into the exported model"
            )

            quantize_toggle = gr.Checkbox(
                False,
                label="Apply Quantization",
                info="Reduce model size with quantization"
            )

        # Quantization options.
        # BUGFIX: start hidden so the row matches quantize_toggle's default
        # value (False). It was visible=True, which showed the quantization
        # dropdown on first render even though quantization was unchecked;
        # setup_exporter_events only syncs visibility after the first toggle.
        with gr.Row(visible=False) as qtype_row:
            qtype_dropdown = gr.Dropdown(
                choices=["none", "int8", "int4", "float8"],
                value="int8",
                label="Quantization Method",
                info="Trade quality for smaller file size"
            )

        # Format options
        with gr.Row():
            format_dropdown = gr.Dropdown(
                choices=["safetensors", "bin"],
                value="safetensors",
                label="Export Format",
                info="safetensors is recommended for safety"
            )

        # Export button and output
        with gr.Row():
            export_btn = gr.Button("💾 Save Merged Checkpoint", variant="primary", size="lg")

        with gr.Row():
            download_link = gr.File(
                label="Download Merged File",
                show_label=True,
            )

        with gr.Column():
            export_status = gr.Textbox(
                label="Export Status",
                placeholder="Ready to export..."
            )

        # Info about quantization
        gr.HTML("""
        <div style="margin-top: 16px; padding: 12px; background: #e0f2fe; border-radius: 8px;">
            <strong>ℹ️ About Quantization:</strong>
            <p style="font-size: 0.9em; margin: 8px 0;">
                Reduces model size by lowering precision. Int8 is typically
                lossless for inference while cutting size in half.
            </p>
        </div>
        """)

    return (
        include_lora, quantize_toggle, qtype_dropdown, format_dropdown,
        export_btn, download_link, export_status, qtype_row
    )
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def setup_exporter_events(
    include_lora, quantize_toggle, qtype_dropdown, format_dropdown,
    export_btn, download_link, export_status, qtype_row
):
    """Setup event handlers for the exporter tab.

    Args:
        All arguments are the Gradio components returned by
        ``create_exporter_tab``, passed in the same order.
    """

    # Show the quantization options only while the toggle is checked.
    quantize_toggle.change(
        fn=lambda checked: gr.update(visible=checked),
        inputs=[quantize_toggle],
        outputs=qtype_row,
    )

    # Run the export and publish the resulting file plus a status message.
    # "none" in the dropdown means no quantization even if the toggle is on.
    # BUGFIX: a `.then(clear_download_link)` chain previously reset
    # `download_link` to None immediately after the click handler finished,
    # so the freshly exported file disappeared before the user could
    # download it. The chain is removed; the link now persists until the
    # next export replaces it.
    export_btn.click(
        fn=lambda inc, q, qt, fmt: export_merged_model(
            include_lora=inc,
            quantize=q and (qt != "none"),
            qtype=qt if qt != "none" else None,
            save_format=fmt,
        ),
        inputs=[include_lora, quantize_toggle, qtype_dropdown, format_dropdown],
        outputs=[download_link, export_status],
    )
|
src/ui/generator_tab.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Image generator tab for SDXL Model Merger."""
|
| 2 |
+
|
| 3 |
+
import gradio as gr
|
| 4 |
+
|
| 5 |
+
from ..config import DEFAULT_PROMPT, DEFAULT_NEGATIVE_PROMPT
|
| 6 |
+
from ..generator import generate_image
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def create_generator_tab():
    """Build the "Generate Image" accordion and return its components.

    Two side-by-side columns hold the prompts and sampler settings, followed
    by tiling toggles, the generate button, and the image/status outputs.

    Returns:
        Tuple in the order expected by ``setup_generator_events``:
        (prompt, negative_prompt, cfg, steps, height, width,
         tile_x, tile_y, gen_btn, image_output, gen_status)
    """

    with gr.Accordion("🎨 2. Generate Image", open=True, elem_classes=["feature-card"]):
        # Prompt inputs and per-column sampler settings.
        with gr.Row():
            with gr.Column(scale=1):
                prompt = gr.Textbox(
                    value=DEFAULT_PROMPT,
                    lines=3,
                    label="Positive Prompt",
                    placeholder="Describe the image you want to generate..."
                )

                cfg = gr.Slider(
                    minimum=1.0, maximum=20.0, step=0.5, value=7.5,
                    label="CFG Scale",
                    info="Higher values make outputs match prompt more strictly"
                )

                height = gr.Number(
                    value=1024, precision=0,
                    label="Height (pixels)",
                    info="Output image height"
                )

            with gr.Column(scale=1):
                negative_prompt = gr.Textbox(
                    value=DEFAULT_NEGATIVE_PROMPT,
                    lines=3,
                    label="Negative Prompt",
                    placeholder="Elements to avoid in generation..."
                )

                steps = gr.Slider(
                    minimum=1, maximum=100, step=1, value=25,
                    label="Inference Steps",
                    info="More steps = better quality but slower"
                )

                width = gr.Number(
                    value=2048, precision=0,
                    label="Width (pixels)",
                    info="Output image width"
                )

        # Seamless-tiling toggles (X on by default).
        with gr.Row():
            tile_x = gr.Checkbox(value=True, label="X-axis Seamless Tiling")
            tile_y = gr.Checkbox(value=False, label="Y-axis Seamless Tiling")

        # Trigger button and result/status outputs.
        with gr.Row():
            gen_btn = gr.Button("✨ Generate Image", variant="secondary", size="lg")

        with gr.Row():
            image_output = gr.Image(
                label="Result",
                height=400,
                show_label=True
            )
        with gr.Column():
            gen_status = gr.Textbox(
                label="Generation Status",
                placeholder="Ready to generate..."
            )

        # Usage hints shown beneath the controls.
        gr.HTML("""
        <div style="margin-top: 16px; padding: 12px; background: #f3f4f6; border-radius: 8px;">
            <strong>💡 Tips:</strong>
            <ul style="margin: 8px 0; padding-left: 20px; font-size: 0.9em;">
                <li>Use wide aspect ratios (e.g., 1024x2048) for panoramas</li>
                <li>Enable seamless tiling for texture-like outputs</li>
                <li>Lower CFG (3-5) for more creative results</li>
            </ul>
        </div>
        """)

    return (
        prompt, negative_prompt, cfg, steps, height, width,
        tile_x, tile_y, gen_btn, image_output, gen_status
    )
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def setup_generator_events(
    prompt, negative_prompt, cfg, steps, height, width,
    tile_x, tile_y, gen_btn, image_output, gen_status
):
    """Attach the click handler that runs image generation.

    Wires ``gen_btn`` to ``generate_image``: the prompt/sampler/tiling
    components feed its inputs, and the rendered image plus status message
    land in ``image_output`` / ``gen_status``.
    """

    generation_inputs = [prompt, negative_prompt, cfg, steps, height, width, tile_x, tile_y]

    gen_btn.click(
        fn=generate_image,
        inputs=generation_inputs,
        outputs=[image_output, gen_status],
    )
|
src/ui/header.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Header component with title and styling for SDXL Model Merger."""
|
| 2 |
+
|
| 3 |
+
import gradio as gr
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def create_header() -> str:
    """Create the header section with title, description, and custom styling.

    Renders the gradient title, tagline, and three feature-highlight cards
    inside a Gradio column, then returns the custom CSS string. The CSS is
    not applied here — presumably the caller passes it to
    ``gr.Blocks(css=...)``; verify against the caller.
    """

    # Custom CSS for modern look. Defines the classes referenced by
    # elem_classes throughout the UI ("feature-card", "header-gradient")
    # plus generic gradio-container overrides.
    css = """
    /* Header gradient text */
    .header-gradient {
        background: linear-gradient(135deg, #10b981 0%, #7c3aed 100%);
        -webkit-background-clip: text;
        -webkit-text-fill-color: transparent;
        background-clip: text;
    }

    /* Feature cards */
    .feature-card {
        border-radius: 12px;
        padding: 20px;
        margin-bottom: 16px;
        box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1);
        transition: transform 0.2s ease;
    }

    .feature-card:hover {
        transform: translateY(-2px);
        box-shadow: 0 10px 15px -3px rgba(0, 0, 0, 0.1);
    }

    /* Label styling */
    .gradio-container .label {
        font-weight: 600;
        color: #374151;
        margin-bottom: 8px;
    }

    /* Status message colors */
    .status-success {
        color: #059669 !important;
        font-weight: 600;
    }
    .status-error {
        color: #dc2626 !important;
        font-weight: 600;
    }
    .status-warning {
        color: #d97706 !important;
        font-weight: 600;
    }

    /* Button improvements */
    .gradio-container .btn {
        border-radius: 8px;
        padding: 12px 24px;
        font-weight: 600;
    }

    /* Input field styling */
    .gradio-container textarea,
    .gradio-container input[type="number"],
    .gradio-container input[type="text"] {
        border-radius: 8px;
        border-color: #d1d5db;
    }

    .gradio-container textarea:focus,
    .gradio-container input:focus {
        outline: none;
        border-color: #6366f1;
        box-shadow: 0 0 0 3px rgba(99, 102, 241, 0.1);
    }

    /* Tab styling */
    .gradio-container .tabitem {
        background: transparent;
        border-radius: 12px;
    }

    /* Progress bar improvements */
    .gradio-container .progress-bar {
        border-radius: 8px;
        overflow: hidden;
    }
    """

    # Title and tagline banner.
    with gr.Column(elem_classes=["feature-card"]):
        gr.HTML("""
        <div style="text-align: center; margin-bottom: 24px;">
            <h1 style="font-size: 2.5em; margin: 0; line-height: 1.2;">
                <span class="header-gradient">SDXL Model Merger</span>
            </h1>
            <p style="color: #6b7280; font-size: 1.1em; max-width: 600px; margin: 16px auto;">
                Merge checkpoints, LoRAs, and VAEs — then bake LoRAs into a single exportable
                checkpoint with optional quantization.
            </p>
        </div>
        """)

        # Feature highlights: three equal-width static HTML cards.
        with gr.Row():
            with gr.Column(scale=1):
                gr.HTML("""
                <div style="text-align: center; padding: 16px;">
                    <div style="font-size: 2.5em; margin-bottom: 8px;">🚀</div>
                    <strong>Fast Loading</strong>
                    <p style="font-size: 0.85em; color: #6b7280; margin-top: 4px;">With progress tracking & cache</p>
                </div>
                """)
            with gr.Column(scale=1):
                gr.HTML("""
                <div style="text-align: center; padding: 16px;">
                    <div style="font-size: 2.5em; margin-bottom: 8px;">🎨</div>
                    <strong>Panorama Gen</strong>
                    <p style="font-size: 0.85em; color: #6b7280; margin-top: 4px;">Seamless tiling support</p>
                </div>
                """)
            with gr.Column(scale=1):
                gr.HTML("""
                <div style="text-align: center; padding: 16px;">
                    <div style="font-size: 2.5em; margin-bottom: 8px;">📦</div>
                    <strong>Export Ready</strong>
                    <p style="font-size: 0.85em; color: #6b7280; margin-top: 4px;">Quantization & format options</p>
                </div>
                """)

    # Hand the CSS back to the caller rather than applying it here.
    return css
|
src/ui/loader_tab.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Pipeline loader tab for SDXL Model Merger."""
|
| 2 |
+
|
| 3 |
+
import gradio as gr
|
| 4 |
+
|
| 5 |
+
from ..config import (
|
| 6 |
+
DEFAULT_CHECKPOINT_URL,
|
| 7 |
+
DEFAULT_VAE_URL,
|
| 8 |
+
DEFAULT_LORA_URLS,
|
| 9 |
+
)
|
| 10 |
+
from ..pipeline import load_pipeline, cancel_download
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def create_loader_tab():
    """Build the "Load Pipeline" accordion and return its components.

    Contains the checkpoint/VAE URL fields, the LoRA URL and strength
    inputs, the load/cancel buttons, and a status textbox.

    Returns:
        Tuple in the order expected by ``setup_loader_events``:
        (checkpoint_url, vae_url, lora_urls, lora_strengths,
         load_btn, cancel_btn, load_status)
    """

    with gr.Accordion("⚙️ 1. Load Pipeline", open=True, elem_classes=["feature-card"]):
        with gr.Row():
            with gr.Column(scale=2):
                # Source of the base SDXL checkpoint.
                checkpoint_url = gr.Textbox(
                    value=DEFAULT_CHECKPOINT_URL,
                    label="Base Model (.safetensors) URL",
                    placeholder="e.g., https://civitai.com/api/download/models/...",
                    info="Download link for the base SDXL checkpoint"
                )

                # Optional replacement VAE; blank means use the model's own.
                vae_url = gr.Textbox(
                    value=DEFAULT_VAE_URL,
                    label="VAE (.safetensors) URL",
                    placeholder="Leave blank to use model's built-in VAE",
                    info="Optional custom VAE for improved quality"
                )

            with gr.Column(scale=1):
                # One LoRA download URL per line.
                lora_urls = gr.Textbox(
                    value=DEFAULT_LORA_URLS,
                    lines=5,
                    label="LoRA URLs (one per line)",
                    placeholder="https://civit.ai/...\nhttps://huggingface.co/...",
                    info="Multiple LoRAs can be loaded and fused together"
                )

                # Comma-separated fuse strength per LoRA.
                lora_strengths = gr.Textbox(
                    value="1.0",
                    label="LoRA Strengths",
                    placeholder="e.g., 0.8,1.0,0.5",
                    info="Comma-separated strength values for each LoRA"
                )

        # Load / cancel actions.
        with gr.Row():
            load_btn = gr.Button("🚀 Load Pipeline", variant="primary", size="lg")
            cancel_btn = gr.Button("🛑 Cancel Download", variant="stop", size="lg")

        # Progress / result message.
        load_status = gr.Textbox(
            label="Status",
            placeholder="Ready to load pipeline...",
            show_label=True,
        )

    return (
        checkpoint_url, vae_url, lora_urls, lora_strengths,
        load_btn, cancel_btn, load_status
    )
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def setup_loader_events(
    checkpoint_url, vae_url, lora_urls, lora_strengths,
    load_btn, cancel_btn, load_status
):
    """Attach event handlers for pipeline loading and cancellation.

    ``load_pipeline`` consumes the URL/strength fields and writes its
    message into ``load_status``; ``cancel_download`` takes no inputs and
    just flips the module-level cancellation flag.
    """

    pipeline_inputs = [checkpoint_url, vae_url, lora_urls, lora_strengths]

    load_btn.click(
        fn=load_pipeline,
        inputs=pipeline_inputs,
        outputs=load_status,
    )

    # Cancellation needs no inputs or outputs — it only sets a flag.
    cancel_btn.click(fn=cancel_download)
|