Kyle Pearson commited on
Commit
6a07ce1
·
1 Parent(s): b807b57
app.py ADDED
@@ -0,0 +1,542 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ SDXL Model Merger - Modernized with modular architecture and improved UI/UX.
4
+
5
+ This application allows you to:
6
+ - Load SDXL checkpoints with optional VAE and multiple LoRAs
7
+ - Generate images with seamless tiling support
8
+ - Export merged models with quantization options
9
+
10
+ Author: Qwen Code Assistant
11
+ """
12
+
13
+ import gradio as gr
14
+
15
+
16
+ def create_app():
17
+ """Create and configure the Gradio app."""
18
+
19
+ header_css = """
20
+ .header-gradient {
21
+ background: linear-gradient(135deg, #10b981 0%, #7c3aed 100%);
22
+ -webkit-background-clip: text;
23
+ -webkit-text-fill-color: transparent;
24
+ background-clip: text;
25
+ }
26
+
27
+ .feature-card {
28
+ border-radius: 12px;
29
+ padding: 20px;
30
+ margin-bottom: 16px;
31
+ box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1);
32
+ transition: transform 0.2s ease;
33
+ }
34
+
35
+ .feature-card:hover {
36
+ transform: translateY(-2px);
37
+ box-shadow: 0 10px 15px -3px rgba(0, 0, 0, 0.1);
38
+ }
39
+
40
+ .gradio-container .label {
41
+ font-weight: 600;
42
+ color: #374151;
43
+ margin-bottom: 8px;
44
+ }
45
+
46
+ .status-success { color: #059669 !important; font-weight: 600; }
47
+ .status-error { color: #dc2626 !important; font-weight: 600; }
48
+ .status-warning { color: #d97706 !important; font-weight: 600; }
49
+
50
+ .gradio-container .btn {
51
+ border-radius: 8px;
52
+ padding: 12px 24px;
53
+ font-weight: 600;
54
+ }
55
+
56
+ .gradio-container textarea,
57
+ .gradio-container input[type="number"],
58
+ .gradio-container input[type="text"] {
59
+ border-radius: 8px;
60
+ border-color: #d1d5db;
61
+ }
62
+
63
+ .gradio-container textarea:focus,
64
+ .gradio-container input:focus {
65
+ outline: none;
66
+ border-color: #6366f1;
67
+ box-shadow: 0 0 0 3px rgba(99, 102, 241, 0.1);
68
+ }
69
+
70
+ .gradio-container .tabitem {
71
+ background: transparent;
72
+ border-radius: 12px;
73
+ }
74
+
75
+ .progress-text {
76
+ font-weight: 500;
77
+ color: #6b7280 !important;
78
+ }
79
+ """
80
+
81
+ from src.pipeline import load_pipeline, cancel_download
82
+ from src.generator import generate_image
83
+ from src.exporter import export_merged_model
84
+ from src.config import get_cached_models, get_cached_checkpoints, get_cached_vaes, get_cached_loras
85
+
86
+ with gr.Blocks(title="SDXL Model Merger", css=header_css) as demo:
87
+ # Header section
88
+ with gr.Column(elem_classes=["feature-card"]):
89
+ gr.HTML("""
90
+ <div style="text-align: center; margin-bottom: 24px;">
91
+ <h1 style="font-size: 2.5em; margin: 0; line-height: 1.2;">
92
+ <span class="header-gradient">SDXL Model Merger</span>
93
+ </h1>
94
+ <p style="color: #6b7280; font-size: 1.1em; max-width: 600px; margin: 16px auto;">
95
+ Merge checkpoints, LoRAs, and VAEs - then bake LoRAs into a single exportable
96
+ checkpoint with optional quantization.
97
+ </p>
98
+ </div>
99
+ """)
100
+
101
+ # Feature highlights
102
+ with gr.Row():
103
+ with gr.Column(scale=1):
104
+ gr.HTML("""
105
+ <div style="text-align: center; padding: 16px;">
106
+ <div style="font-size: 2.5em; margin-bottom: 8px;">🚀</div>
107
+ <strong>Fast Loading</strong>
108
+ <p style="font-size: 0.85em; color: #6b7280; margin-top: 4px;">With progress tracking & cache</p>
109
+ </div>
110
+ """)
111
+ with gr.Column(scale=1):
112
+ gr.HTML("""
113
+ <div style="text-align: center; padding: 16px;">
114
+ <div style="font-size: 2.5em; margin-bottom: 8px;">🎨</div>
115
+ <strong>Panorama Gen</strong>
116
+ <p style="font-size: 0.85em; color: #6b7280; margin-top: 4px;">Seamless tiling support</p>
117
+ </div>
118
+ """)
119
+ with gr.Column(scale=1):
120
+ gr.HTML("""
121
+ <div style="text-align: center; padding: 16px;">
122
+ <div style="font-size: 2.5em; margin-bottom: 8px;">📦</div>
123
+ <strong>Export Ready</strong>
124
+ <p style="font-size: 0.85em; color: #6b7280; margin-top: 4px;">Quantization & format options</p>
125
+ </div>
126
+ """)
127
+
128
+ gr.Markdown("---")
129
+
130
+ with gr.Tab("Load Pipeline"):
131
+ gr.Markdown("### Load SDXL Pipeline with Checkpoint, VAE, and LoRAs")
132
+
133
+ # Progress indicator for pipeline loading
134
+ load_progress = gr.Textbox(
135
+ label="Loading Progress",
136
+ placeholder="Ready to start...",
137
+ show_label=True,
138
+ info="Real-time status of model downloads and pipeline setup"
139
+ )
140
+
141
+ with gr.Row():
142
+ with gr.Column(scale=2):
143
+ # Checkpoint URL with cached models dropdown
144
+ checkpoint_url = gr.Textbox(
145
+ label="Base Model (.safetensors) URL",
146
+ value="https://civitai.com/api/download/models/354657?type=Model&format=SafeTensor&size=full&fp=fp16",
147
+ placeholder="e.g., https://civitai.com/api/download/models/...",
148
+ info="Download link for the base SDXL checkpoint"
149
+ )
150
+
151
+ # Dropdown of cached checkpoints
152
+ cached_checkpoints = gr.Dropdown(
153
+ choices=["(None found)"] + get_cached_checkpoints(),
154
+ label="Cached Checkpoints",
155
+ value="(None found)" if not get_cached_checkpoints() else None,
156
+ info="Models already downloaded to .cache/"
157
+ )
158
+
159
+ # VAE URL
160
+ vae_url = gr.Textbox(
161
+ label="VAE (.safetensors) URL",
162
+ value="https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/resolve/main/sdxl.vae.safetensors?download=true",
163
+ placeholder="Leave blank to use model's built-in VAE",
164
+ info="Optional custom VAE for improved quality"
165
+ )
166
+
167
+ # Dropdown of cached VAEs
168
+ cached_vaes = gr.Dropdown(
169
+ choices=["(None found)"] + get_cached_vaes(),
170
+ label="Cached VAEs",
171
+ value="(None found)" if not get_cached_vaes() else None,
172
+ info="Select a VAE to load"
173
+ )
174
+
175
+ with gr.Column(scale=1):
176
+ # LoRA URLs input
177
+ lora_urls = gr.Textbox(
178
+ label="LoRA URLs (one per line)",
179
+ lines=5,
180
+ value="https://civitai.com/api/download/models/143197?type=Model&format=SafeTensor",
181
+ placeholder="https://civit.ai/...\nhttps://huggingface.co/...",
182
+ info="Multiple LoRAs can be loaded and fused together"
183
+ )
184
+
185
+ # Dropdown of cached LoRAs
186
+ cached_loras = gr.Dropdown(
187
+ choices=["(None found)"] + get_cached_loras(),
188
+ label="Cached LoRAs",
189
+ value="(None found)" if not get_cached_loras() else None,
190
+ info="Select a LoRA to add to the list below"
191
+ )
192
+
193
+ lora_strengths = gr.Textbox(
194
+ label="LoRA Strengths",
195
+ value="1.0",
196
+ placeholder="e.g., 0.8,1.0,0.5",
197
+ info="Comma-separated strength values for each LoRA"
198
+ )
199
+
200
+ with gr.Row():
201
+ load_btn = gr.Button("🚀 Load Pipeline", variant="primary", size="lg")
202
+
203
+ # Detailed status display
204
+ load_status = gr.HTML(
205
+ label="Status",
206
+ value='<div class="status-success">✅ Ready to load pipeline</div>',
207
+ )
208
+
209
+ with gr.Tab("Generate Image"):
210
+ gr.Markdown("### Generate Panorama Images with Seamless Tiling")
211
+
212
+ # Progress indicator for image generation
213
+ gen_progress = gr.Textbox(
214
+ label="Generation Progress",
215
+ placeholder="Ready to generate...",
216
+ show_label=True,
217
+ info="Real-time status of image generation"
218
+ )
219
+
220
+ with gr.Row():
221
+ with gr.Column(scale=1):
222
+ prompt = gr.Textbox(
223
+ label="Positive Prompt",
224
+ value="Glowing mushrooms around pyramids amidst a cosmic backdrop, equirectangular, 360 panorama, cinematic",
225
+ lines=4,
226
+ placeholder="Describe the image you want to generate..."
227
+ )
228
+
229
+ cfg = gr.Slider(
230
+ minimum=1.0, maximum=20.0, value=7.5, step=0.5,
231
+ label="CFG Scale",
232
+ info="Higher values make outputs match prompt more strictly"
233
+ )
234
+
235
+ height = gr.Number(
236
+ value=1024, precision=0,
237
+ label="Height (pixels)",
238
+ info="Output image height"
239
+ )
240
+
241
+ with gr.Column(scale=1):
242
+ negative_prompt = gr.Textbox(
243
+ label="Negative Prompt",
244
+ value="boring, text, signature, watermark, low quality, bad quality",
245
+ lines=4,
246
+ placeholder="Elements to avoid in generation..."
247
+ )
248
+
249
+ steps = gr.Slider(
250
+ minimum=1, maximum=100, value=25, step=1,
251
+ label="Inference Steps",
252
+ info="More steps = better quality but slower"
253
+ )
254
+
255
+ width = gr.Number(
256
+ value=2048, precision=0,
257
+ label="Width (pixels)",
258
+ info="Output image width"
259
+ )
260
+
261
+ with gr.Row():
262
+ tile_x = gr.Checkbox(True, label="X-axis Seamless Tiling")
263
+ tile_y = gr.Checkbox(False, label="Y-axis Seamless Tiling")
264
+
265
+ with gr.Row():
266
+ gen_btn = gr.Button("✨ Generate Image", variant="secondary", size="lg")
267
+
268
+ with gr.Row():
269
+ image_output = gr.Image(
270
+ label="Result",
271
+ height=400,
272
+ show_label=True
273
+ )
274
+ with gr.Column():
275
+ gen_status = gr.HTML(
276
+ label="Generation Status",
277
+ value='<div class="status-success">✅ Ready to generate</div>',
278
+ )
279
+
280
+ gr.HTML("""
281
+ <div style="margin-top: 16px; padding: 12px; background-color: #e5e7eb !important; border-radius: 8px;">
282
+ <strong style="color: #1f2937 !important;">💡 Tips:</strong>
283
+ <ul style="margin: 8px 0; padding-left: 20px; font-size: 0.9em; color: #1f2937 !important;">
284
+ <li>Use wide aspect ratios (e.g., 1024x2048) for panoramas</li>
285
+ <li>Enable seamless tiling for texture-like outputs</li>
286
+ <li>Lower CFG (3-5) for more creative results</li>
287
+ </ul>
288
+ </div>
289
+ """)
290
+
291
+ with gr.Tab("Export Model"):
292
+ gr.Markdown("### Export Merged Checkpoint with Quantization Options")
293
+
294
+ # Progress indicator for export
295
+ export_progress = gr.Textbox(
296
+ label="Export Progress",
297
+ placeholder="Ready to export...",
298
+ show_label=True,
299
+ info="Real-time status of model export and quantization"
300
+ )
301
+
302
+ with gr.Row():
303
+ include_lora = gr.Checkbox(
304
+ True,
305
+ label="Include Fused LoRAs",
306
+ info="Bake the loaded LoRAs into the exported model"
307
+ )
308
+
309
+ quantize_toggle = gr.Checkbox(
310
+ False,
311
+ label="Apply Quantization",
312
+ info="Reduce model size with quantization"
313
+ )
314
+
315
+ qtype_row = gr.Row(visible=True)
316
+ with qtype_row:
317
+ qtype_dropdown = gr.Dropdown(
318
+ choices=["none", "int8", "int4", "float8"],
319
+ value="int8",
320
+ label="Quantization Method",
321
+ info="Trade quality for smaller file size"
322
+ )
323
+
324
+ with gr.Row():
325
+ format_dropdown = gr.Dropdown(
326
+ choices=["safetensors", "bin"],
327
+ value="safetensors",
328
+ label="Export Format",
329
+ info="safetensors is recommended for safety"
330
+ )
331
+
332
+ with gr.Row():
333
+ export_btn = gr.Button("💾 Save Merged Checkpoint", variant="primary", size="lg")
334
+
335
+ with gr.Row():
336
+ download_link = gr.File(
337
+ label="Download Merged File",
338
+ show_label=True,
339
+ )
340
+
341
+ with gr.Column():
342
+ export_status = gr.HTML(
343
+ label="Export Status",
344
+ value='<div class="status-success">✅ Ready to export</div>',
345
+ )
346
+
347
+ gr.HTML("""
348
+ <div style="margin-top: 16px; padding: 12px; background: #e0f2fe; border-radius: 8px;">
349
+ <strong>ℹ️ About Quantization:</strong>
350
+ <p style="font-size: 0.9em; margin: 8px 0;">
351
+ Reduces model size by lowering precision. Int8 is typically
352
+ lossless for inference while cutting size in half.
353
+ </p>
354
+ </div>
355
+ """)
356
+
357
+ # Event handlers - all inside Blocks context
358
+
359
+ def on_load_pipeline_start():
360
+ """Called when pipeline loading starts."""
361
+ return (
362
+ '<div class="status-warning">⏳ Loading started...</div>',
363
+ "Starting download...",
364
+ gr.update(interactive=False)
365
+ )
366
+
367
+ def on_load_pipeline_complete(status_msg, progress_text):
368
+ """Called when pipeline loading completes."""
369
+ if "✅" in status_msg:
370
+ return (
371
+ '<div class="status-success">✅ Pipeline loaded successfully!</div>',
372
+ progress_text,
373
+ gr.update(interactive=True)
374
+ )
375
+ elif "⚠️" in status_msg:
376
+ return (
377
+ '<div class="status-warning">⚠️ Download cancelled</div>',
378
+ progress_text,
379
+ gr.update(interactive=True)
380
+ )
381
+ else:
382
+ return (
383
+ f'<div class="status-error">{status_msg}</div>',
384
+ progress_text,
385
+ gr.update(interactive=True)
386
+ )
387
+
388
+ load_btn.click(
389
+ fn=on_load_pipeline_start,
390
+ inputs=[],
391
+ outputs=[load_status, load_progress, load_btn],
392
+ ).then(
393
+ fn=load_pipeline,
394
+ inputs=[checkpoint_url, vae_url, lora_urls, lora_strengths],
395
+ outputs=[load_status, load_progress],
396
+ show_api=False,
397
+ )
398
+
399
+ def on_cached_checkpoint_change(cached_path):
400
+ """Update URL when a cached checkpoint is selected."""
401
+ if cached_path and cached_path != "(None found)":
402
+ return gr.update(value=f"file://{cached_path}")
403
+ return gr.update()
404
+
405
+ cached_checkpoints.change(
406
+ fn=lambda x: gr.update(value=f"file://{x}" if x and x != "(None found)" else ""),
407
+ inputs=cached_checkpoints,
408
+ outputs=checkpoint_url,
409
+ )
410
+
411
+ def on_cached_vae_change(cached_path):
412
+ """Update VAE URL when a cached VAE is selected."""
413
+ if cached_path and cached_path != "(None found)":
414
+ return gr.update(value=f"file://{cached_path}")
415
+ return gr.update()
416
+
417
+ cached_vaes.change(
418
+ fn=on_cached_vae_change,
419
+ inputs=cached_vaes,
420
+ outputs=vae_url,
421
+ )
422
+
423
+ def on_cached_lora_change(cached_path, current_urls):
424
+ """Add cached LoRA to the list."""
425
+ if cached_path and cached_path != "(None found)":
426
+ # Add new LoRA to existing URLs (avoid duplicate)
427
+ urls_list = [u.strip() for u in current_urls.split("\n") if u.strip()]
428
+ if cached_path not in urls_list:
429
+ urls_list.append(cached_path)
430
+ return gr.update(value="\n".join(urls_list))
431
+ return gr.update()
432
+
433
+ cached_loras.change(
434
+ fn=on_cached_lora_change,
435
+ inputs=[cached_loras, lora_urls],
436
+ outputs=lora_urls,
437
+ )
438
+
439
+
440
+ def on_generate_start():
441
+ """Called when image generation starts."""
442
+ return (
443
+ '<div class="status-warning">⏳ Generating image...</div>',
444
+ "Starting generation...",
445
+ gr.update(interactive=False)
446
+ )
447
+
448
+ def on_generate_complete(status_msg, progress_text, image):
449
+ """Called when image generation completes."""
450
+ if image is None:
451
+ return (
452
+ f'<div class="status-error">{status_msg}</div>',
453
+ "",
454
+ gr.update(interactive=True),
455
+ gr.update()
456
+ )
457
+ else:
458
+ return (
459
+ '<div class="status-success">✅ Generation complete!</div>',
460
+ "Done",
461
+ gr.update(interactive=True),
462
+ gr.update(value=image)
463
+ )
464
+
465
+ gen_btn.click(
466
+ fn=on_generate_start,
467
+ inputs=[],
468
+ outputs=[gen_status, gen_progress, gen_btn],
469
+ ).then(
470
+ fn=generate_image,
471
+ inputs=[prompt, negative_prompt, cfg, steps, height, width, tile_x, tile_y],
472
+ outputs=[image_output, gen_progress],
473
+ show_api=False,
474
+ ).then(
475
+ fn=lambda img, msg: on_generate_complete(msg, "Done", img),
476
+ inputs=[image_output, gen_progress],
477
+ outputs=[gen_status, gen_progress, gen_btn, image_output],
478
+ )
479
+
480
+ def on_export_start():
481
+ """Called when export starts."""
482
+ return (
483
+ '<div class="status-warning">⏳ Export started...</div>',
484
+ "Starting export...",
485
+ gr.update(interactive=False)
486
+ )
487
+
488
+ def on_export_complete(status_msg, progress_text, file_path):
489
+ """Called when export completes."""
490
+ if file_path is None:
491
+ return (
492
+ f'<div class="status-error">{status_msg}</div>',
493
+ "",
494
+ gr.update(interactive=True),
495
+ gr.update(value=None)
496
+ )
497
+ else:
498
+ return (
499
+ '<div class="status-success">✅ Export complete!</div>',
500
+ "Exported successfully",
501
+ gr.update(interactive=True),
502
+ gr.update(value=file_path)
503
+ )
504
+
505
+ export_btn.click(
506
+ fn=on_export_start,
507
+ inputs=[],
508
+ outputs=[export_status, export_progress, export_btn],
509
+ ).then(
510
+ fn=lambda inc, q, qt, fmt: export_merged_model(
511
+ include_lora=inc,
512
+ quantize=q and (qt != "none"),
513
+ qtype=qt if qt != "none" else None,
514
+ save_format=fmt,
515
+ ),
516
+ inputs=[include_lora, quantize_toggle, qtype_dropdown, format_dropdown],
517
+ outputs=[download_link, export_progress],
518
+ show_api=False,
519
+ ).then(
520
+ fn=lambda path, msg: on_export_complete(msg, "Exported", path),
521
+ inputs=[download_link, export_progress],
522
+ outputs=[export_status, export_progress, export_btn, download_link],
523
+ )
524
+
525
+ quantize_toggle.change(
526
+ fn=lambda checked: gr.update(visible=checked),
527
+ inputs=[quantize_toggle],
528
+ outputs=qtype_row,
529
+ )
530
+
531
+ return demo
532
+
533
+
534
+ def main():
535
+ """Create and launch the Gradio app."""
536
+ app = create_app()
537
+ # CSS is embedded in the Blocks, so we pass it to launch for Gradio 6+
538
+ app.launch()
539
+
540
+
541
+ if __name__ == "__main__":
542
+ main()
requirements.txt ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SDXL Model Merger - Dependencies
2
+
3
+ # Core ML frameworks
4
+ torch>=2.0.0
5
+ diffusers>=0.24.0
6
+ transformers>=4.35.0
7
+ safetensors>=0.4.0
8
+
9
+ # Image processing
10
+ Pillow>=10.0.0
11
+
12
+ # UI framework
13
+ gradio>=4.0.0
14
+
15
+ # Download utilities
16
+ tqdm>=4.65.0
17
+ requests>=2.31.0
18
+
19
+ # Optional: quantization support
20
+ optimum-quanto>=0.2.0
src/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ """SDXL Model Merger - Modular SDXL pipeline management and generation."""
2
+
3
+ __version__ = "1.0.0"
src/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (233 Bytes). View file
 
src/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (270 Bytes). View file
 
src/__pycache__/config.cpython-311.pyc ADDED
Binary file (5.32 kB). View file
 
src/__pycache__/config.cpython-313.pyc ADDED
Binary file (1.44 kB). View file
 
src/__pycache__/downloader.cpython-311.pyc ADDED
Binary file (11.1 kB). View file
 
src/__pycache__/downloader.cpython-313.pyc ADDED
Binary file (5.4 kB). View file
 
src/__pycache__/exporter.cpython-311.pyc ADDED
Binary file (6.9 kB). View file
 
src/__pycache__/exporter.cpython-313.pyc ADDED
Binary file (5.67 kB). View file
 
src/__pycache__/generator.cpython-311.pyc ADDED
Binary file (2.65 kB). View file
 
src/__pycache__/generator.cpython-313.pyc ADDED
Binary file (2.24 kB). View file
 
src/__pycache__/pipeline.cpython-311.pyc ADDED
Binary file (10.2 kB). View file
 
src/__pycache__/pipeline.cpython-313.pyc ADDED
Binary file (8.11 kB). View file
 
src/config.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Configuration constants and global settings for SDXL Model Merger."""
2
+
3
+ import os
4
+ from pathlib import Path
5
+
6
+ # ──────────────────────────────────────────────
7
+ # Paths & Directories
8
+ # ──────────────────────────────────────────────
9
+ SCRIPT_DIR = Path.cwd()
10
+ CACHE_DIR = SCRIPT_DIR / ".cache"
11
+ CACHE_DIR.mkdir(exist_ok=True)
12
+
13
+ # ──────────────────────────────────────────────
14
+ # Default URLs
15
+ # ──────────────────────────────────────────────
16
+ DEFAULT_CHECKPOINT_URL = "https://civitai.com/api/download/models/354657?type=Model&format=SafeTensor&size=full&fp=fp16"
17
+ DEFAULT_VAE_URL = "https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/resolve/main/sdxl.vae.safetensors?download=true"
18
+ DEFAULT_LORA_URLS = "https://civitai.com/api/download/models/143197?type=Model&format=SafeTensor"
19
+
20
+ # ──────────────────────────────────────────────
21
+ # PyTorch & Device Settings
22
+ # ──────────────────────────────────────────────
23
+ import torch
24
+
25
+
26
+ def get_device_info() -> tuple[str, str]:
27
+ """
28
+ Detect and return the optimal device for ML inference.
29
+
30
+ Returns:
31
+ Tuple of (device_name, device_description)
32
+ """
33
+ if torch.cuda.is_available():
34
+ device_name = "cuda"
35
+ device_desc = f"CUDA (GPU: {torch.cuda.get_device_name(0)})"
36
+ elif torch.backends.mps.is_available():
37
+ device_name = "mps"
38
+ device_desc = "Apple Silicon MPS"
39
+ else:
40
+ device_name = "cpu"
41
+ device_desc = "CPU (no GPU available)"
42
+
43
+ return device_name, device_desc
44
+
45
+
46
+ device, device_description = get_device_info()
47
+ dtype = torch.float16
48
+
49
+ print(f"🚀 Using device: {device_description}")
50
+
51
+ # ──────────────────────────────────────────────
52
+ # Global State
53
+ # ──────────────────────────────────────────────
54
+ pipe = None
55
+ download_cancelled = False
56
+
57
+ # ──────────────────────────────────────────────
58
+ # Generation Defaults
59
+ # ──────────────────────────────────────────────
60
+ DEFAULT_PROMPT = "Glowing mushrooms around pyramids amidst a cosmic backdrop, equirectangular, 360 panorama, cinematic"
61
+ DEFAULT_NEGATIVE_PROMPT = "boring, text, signature, watermark, low quality, bad quality"
62
+
63
+ # ──────────────────────────────────────────────
64
+ # Model Presets (URLs for common models)
65
+ # ──────────────────────────────────────────────
66
+ MODEL_PRESETS = {
67
+ # Checkpoints
68
+ "DreamShaper XL v2": "https://civitai.com/api/download/models/354657?type=Model&format=SafeTensor&size=full&fp=fp16",
69
+ "Realism Engine SDXL": "https://civitai.com/api/download/models/328799?type=Model&format=SafeTensor&size=full&fp=fp16",
70
+ "Juggernaut XL v9": "https://civitai.com/api/download/models/350565?type=Model&format=SafeTensor&size=full&fp=fp16",
71
+
72
+ # VAEs
73
+ "VAE-FP16 Fix": "https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/resolve/main/sdxl.vae.safetensors?download=true",
74
+
75
+ # LoRAs
76
+ "Rainbow Color LoRA": "https://civitai.com/api/download/models/127983?type=Model&format=SafeTensor",
77
+ "More Details LoRA": "https://civitai.com/api/download/models/280590?type=Model&format=SafeTensor",
78
+ "Epic Realism LoRA": "https://civitai.com/api/download/models/346631?type=Model&format=SafeTensor",
79
+ }
80
+
81
+
82
+ def get_cached_models():
83
+ """Get list of cached model files."""
84
+ if not CACHE_DIR.exists():
85
+ return []
86
+
87
+ models = []
88
+ for file in sorted(CACHE_DIR.glob("*.safetensors")):
89
+ models.append(str(file))
90
+ return models
91
+
92
+
93
+ def get_cached_model_names():
94
+ """Get display names for cached models."""
95
+ models = get_cached_models()
96
+ return [str(m.name) for m in models]
97
+
98
+
99
+ def get_cached_checkpoints():
100
+ """Get list of cached checkpoint files (model_id_model.safetensors)."""
101
+ if not CACHE_DIR.exists():
102
+ return []
103
+
104
+ models = []
105
+ for file in sorted(CACHE_DIR.glob("*_model.safetensors")):
106
+ models.append(str(file))
107
+ return models
108
+
109
+
110
+ def get_cached_vaes():
111
+ """Get list of cached VAE files (model_id_*_vae.safetensors)."""
112
+ if not CACHE_DIR.exists():
113
+ return []
114
+
115
+ models = []
116
+ for file in sorted(CACHE_DIR.glob("*_vae.safetensors")):
117
+ models.append(str(file))
118
+ return models
119
+
120
+
121
+ def get_cached_loras():
122
+ """Get list of cached LoRA files (model_id_*_lora.safetensors)."""
123
+ if not CACHE_DIR.exists():
124
+ return []
125
+
126
+ models = []
127
+ for file in sorted(CACHE_DIR.glob("*_lora.safetensors")):
128
+ models.append(str(file))
129
+ return models
src/downloader.py ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Download utilities for SDXL Model Merger with Gradio progress integration."""
2
+
3
+ import re
4
+ import requests
5
+ from pathlib import Path
6
+ from tqdm import tqdm as TqdmBase
7
+
8
+ from .config import download_cancelled
9
+
10
+
11
+ def extract_model_id(url: str) -> str | None:
12
+ """Extract CivitAI model ID from URL."""
13
+ match = re.search(r'/models/(\d+)', url)
14
+ return match.group(1) if match else None
15
+
16
+
17
+ def get_safe_filename_from_url(
18
+ url: str,
19
+ default_name: str = "model.safetensors",
20
+ suffix: str = "",
21
+ type_prefix: str | None = None
22
+ ) -> str:
23
+ """
24
+ Generate a safe filename with model ID from URL.
25
+
26
+ For CivitAI URLs like https://civitai.com/api/download/models/12345?type=...
27
+
28
+ Naming patterns:
29
+ - Checkpoint (type_prefix='model'): 12345_model.safetensors or 12345_model_anime_style.safetensors
30
+ - VAE (suffix='_vae'): 12345_vae.safetensors or 12345_anime_vae.safetensors
31
+ - LoRA (suffix='_lora'): 12345_lora.safetensors or 12345_name_lora.safetensors
32
+
33
+ For HuggingFace URLs without model IDs, attempts to extract name from path or uses suffix-based naming.
34
+
35
+ Args:
36
+ url: The download URL
37
+ default_name: Fallback filename if extraction fails
38
+ suffix: Optional suffix to append before .safetensors (e.g., '_vae', '_lora')
39
+ type_prefix: Optional prefix after model_id (e.g., 'model' -> 12345_model.safetensors)
40
+ """
41
+ model_id = extract_model_id(url)
42
+
43
+ # If no CivitAI model ID, try to generate a name from HuggingFace path
44
+ if not model_id and "huggingface.co" in url:
45
+ # Try to extract name from URL path (e.g., sdxl-vae-fp16-fix -> vae)
46
+ try:
47
+ parts = url.split("huggingface.co/")[1] if "huggingface.co/" in url else ""
48
+ if parts:
49
+ # Get the repo name (second part after org/)
50
+ path_parts = [p for p in parts.split("/") if p]
51
+ if len(path_parts) >= 2:
52
+ repo_name = path_parts[1]
53
+ # Clean up and create a simple identifier
54
+ clean_repo = re.sub(r'[^a-zA-Z0-9]', '_', repo_name)[:30].strip('_')
55
+ if clean_repo:
56
+ model_id = f"hf_{clean_repo}"
57
+ except Exception:
58
+ pass
59
+
60
+ if not model_id:
61
+ return default_name
62
+
63
+ # Build the name portion: either clean name from URL or fallback
64
+ name_part = ""
65
+
66
+ # For VAE/LoRA types, prefer the suffix-based naming and skip Content-Disposition parsing
67
+ # to avoid double naming (e.g., sdxlvae_vae instead of just vae)
68
+ is_special_type = suffix in ("_vae", "_lora")
69
+
70
+ if not is_special_type:
71
+ try:
72
+ response = requests.head(url, timeout=10, allow_redirects=True)
73
+ cd = response.headers.get('Content-Disposition', '')
74
+ match = re.search(r'filename="([^"]+)"', cd)
75
+ if match:
76
+ filename = match.group(1)
77
+ # Extract base name without extension
78
+ base_name = Path(filename).stem
79
+ # Clean up the name (remove special chars)
80
+ clean_name = re.sub(r'[^\w\s-]', '', base_name)[:50]
81
+ clean_name = re.sub(r'[-\s]+', '_', clean_name.strip('-_'))
82
+ if clean_name:
83
+ name_part = clean_name
84
+ except Exception:
85
+ pass
86
+
87
+ # Build filename with model_id, optional type_prefix, optional name_part, and suffix
88
+ parts = [model_id]
89
+ if type_prefix:
90
+ parts.append(type_prefix)
91
+ if name_part:
92
+ parts.append(name_part)
93
+ if suffix:
94
+ # Avoid double underscores: only add separator if needed
95
+ if not suffix.startswith('_'):
96
+ parts.append('_' + suffix.lstrip('_'))
97
+ else:
98
+ parts.append(suffix)
99
+
100
+ return '_'.join(p for p in parts if p).replace('__', '_') + '.safetensors'
101
+
102
+
103
+ class TqdmGradio(TqdmBase):
104
+ """tqdm subclass that sends progress updates to Gradio's gr.Progress()"""
105
+
106
+ def __init__(self, *args, gradio_prog=None, **kwargs):
107
+ super().__init__(*args, **kwargs)
108
+ self.gradio_prog = gradio_prog
109
+ self.last_pct = 0
110
+
111
+ def update(self, n=1):
112
+ global download_cancelled
113
+ if download_cancelled:
114
+ raise KeyboardInterrupt("Download cancelled by user")
115
+ super().update(n)
116
+ if self.gradio_prog and self.total:
117
+ pct = int(100 * self.n / self.total)
118
+ # Only update UI every ~5% to avoid spamming
119
+ if pct != self.last_pct and pct % 5 == 0:
120
+ self.last_pct = pct
121
+ self.gradio_prog(pct / 100)
122
+
123
+
124
+ def set_download_cancelled(value: bool):
125
+ """Set the global download cancellation flag."""
126
+ global download_cancelled
127
+ download_cancelled = value
128
+
129
+
130
+ def get_cached_file_size(url: str) -> tuple[Path | None, int | None]:
131
+ """
132
+ Check if file exists in cache and matches expected size.
133
+ Returns (path, expected_size) or (None, None) if no valid cache.
134
+ """
135
+ # Simple implementation - would need URL-to-filename mapping for production
136
+ return None, None
137
+
138
+
139
+ def download_file_with_progress(url: str, output_path: Path, progress_bar=None) -> Path:
140
+ """
141
+ Download a file with Gradio-synced progress bar + cancel support.
142
+
143
+ Args:
144
+ url: File URL to download (http/https/file)
145
+ output_path: Destination path for downloaded file
146
+ progress_bar: Optional gr.Progress() object for UI updates
147
+
148
+ Returns:
149
+ Path to the downloaded file
150
+
151
+ Raises:
152
+ KeyboardInterrupt: If download is cancelled
153
+ requests.RequestException: If download fails
154
+ """
155
+ global download_cancelled
156
+ download_cancelled = False
157
+
158
+ # Handle local file:// URLs
159
+ if url.startswith("file://"):
160
+ local_path = Path(url[7:]) # Remove "file://" prefix
161
+ if local_path.exists():
162
+ import shutil
163
+ output_path.parent.mkdir(parents=True, exist_ok=True)
164
+ # Copy the file to cache location
165
+ shutil.copy2(str(local_path), str(output_path))
166
+
167
+ # Update progress bar for cached files
168
+ if progress_bar:
169
+ progress_bar(1.0)
170
+ return output_path
171
+ else:
172
+ raise FileNotFoundError(f"Local file not found: {local_path}")
173
+
174
+ # Cache check: if file exists and size matches URL's content-length, skip re-download
175
+ expected_size = None
176
+ try:
177
+ head = requests.head(url, timeout=10)
178
+ expected_size = int(head.headers.get('content-length', 0))
179
+ if output_path.exists() and output_path.stat().st_size == expected_size:
180
+ # Cache hit - still update progress to show completion
181
+ if progress_bar:
182
+ progress_bar(1.0)
183
+ return output_path # Cache hit!
184
+ except Exception:
185
+ pass # Skip cache validation on errors
186
+
187
+ output_path.parent.mkdir(parents=True, exist_ok=True)
188
+
189
+ session = requests.Session()
190
+ response = session.get(url, stream=True, timeout=30)
191
+ response.raise_for_status()
192
+
193
+ total_size = expected_size or int(response.headers.get('content-length', 0))
194
+ block_size = 8192
195
+
196
+ # Use TqdmGradio to sync progress with Gradio
197
+ tqdm_kwargs = {
198
+ 'unit': 'B',
199
+ 'unit_scale': True,
200
+ 'desc': f"Downloading {output_path.name}",
201
+ 'gradio_prog': progress_bar,
202
+ 'disable': False,
203
+ 'bar_format': '{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]',
204
+ }
205
+
206
+ with open(output_path, "wb") as f:
207
+ try:
208
+ for data in TqdmGradio(
209
+ response.iter_content(block_size),
210
+ total=total_size // block_size if total_size else 0,
211
+ **tqdm_kwargs,
212
+ ):
213
+ if download_cancelled:
214
+ raise KeyboardInterrupt("Download cancelled by user")
215
+ f.write(data)
216
+ except KeyboardInterrupt:
217
+ # Clean partial file on cancel
218
+ output_path.unlink(missing_ok=True)
219
+ raise
220
+
221
+ return output_path
222
+
223
+
224
+ def clear_cache(cache_dir: Path = None, keep_extensions: list[str] = None):
225
+ """
226
+ Remove old cache files.
227
+
228
+ Args:
229
+ cache_dir: Cache directory path (defaults to config.CACHE_DIR)
230
+ keep_extensions: File extensions to preserve (default: ['.safetensors'])
231
+ """
232
+ if cache_dir is None:
233
+ from .config import CACHE_DIR
234
+ cache_dir = CACHE_DIR
235
+
236
+ if keep_extensions is None:
237
+ keep_extensions = ['.safetensors']
238
+
239
+ # Remove temp files
240
+ for file in cache_dir.glob("*.tmp"):
241
+ file.unlink()
242
+
243
+ # Optional: age-based cleanup (7 days)
244
+ # import time
245
+ # cutoff = time.time() - 86400 * 7
246
+ # for f in cache_dir.iterdir():
247
+ # if f.is_file() and f.stat().st_mtime < cutoff:
248
+ # f.unlink()
src/exporter.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Model export functionality for SDXL Model Merger."""
2
+
3
+ import os
4
+ from pathlib import Path
5
+
6
+ import torch
7
+ from safetensors.torch import save_file
8
+
9
+ from .config import SCRIPT_DIR, pipe as global_pipe
10
+
11
+
12
+ def export_merged_model(
13
+ include_lora: bool,
14
+ quantize: bool,
15
+ qtype: str,
16
+ save_format: str = "safetensors",
17
+ ) -> tuple[str | None, str]:
18
+ """
19
+ Export the merged pipeline model with optional LoRA baking and quantization.
20
+
21
+ Args:
22
+ include_lora: Whether to include fused LoRAs in export
23
+ quantize: Whether to apply quantization
24
+ qtype: Quantization type - 'none', 'int8', 'int4', or 'float8'
25
+ save_format: Output format - 'safetensors' or 'bin'
26
+
27
+ Returns:
28
+ Tuple of (output_path or None, status message)
29
+ """
30
+ if not global_pipe:
31
+ return None, "⚠️ Please load a pipeline first."
32
+
33
+ try:
34
+ # Step 1: Unload LoRAs
35
+ yield "💾 Exporting model...", "Unloading LoRAs..."
36
+ if include_lora:
37
+ global_pipe.unload_lora_weights()
38
+
39
+ merged_state_dict = {}
40
+
41
+ # Step 2: Extract UNet weights
42
+ yield "💾 Exporting model...", "Extracting UNet weights..."
43
+ for k, v in global_pipe.unet.state_dict().items():
44
+ merged_state_dict[f"unet.{k}"] = v.contiguous().half()
45
+
46
+ # Step 3: Extract text encoder weights
47
+ yield "💾 Exporting model...", "Extracting text encoders..."
48
+ for k, v in global_pipe.text_encoder.state_dict().items():
49
+ merged_state_dict[f"text_encoder.{k}"] = v.contiguous().half()
50
+ for k, v in global_pipe.text_encoder_2.state_dict().items():
51
+ merged_state_dict[f"text_encoder_2.{k}"] = v.contiguous().half()
52
+
53
+ # Step 4: Extract VAE weights
54
+ yield "💾 Exporting model...", "Extracting VAE weights..."
55
+ for k, v in global_pipe.vae.state_dict().items():
56
+ merged_state_dict[f"first_stage_model.{k}"] = v.contiguous().half()
57
+
58
+ # Step 5: Quantize if requested and optimum.quanto is available
59
+ try:
60
+ from optimum.quanto import quantize, QTensor
61
+ QUANTO_AVAILABLE = True
62
+ except ImportError:
63
+ QUANTO_AVAILABLE = False
64
+
65
+ if quantize and qtype != "none" and QUANTO_AVAILABLE:
66
+ yield "💾 Exporting model...", f"Applying {qtype} quantization..."
67
+
68
+ class FakeModel(torch.nn.Module):
69
+ pass
70
+
71
+ fake_model = FakeModel()
72
+ fake_model.__dict__.update(merged_state_dict)
73
+
74
+ # Select quantization method
75
+ if qtype == "int8":
76
+ from optimum.quanto import int8_weight_only
77
+ quantize(fake_model, int8_weight_only())
78
+ elif qtype == "int4":
79
+ from optimum.quanto import int4_weight_only
80
+ quantize(fake_model, int4_weight_only())
81
+ elif qtype == "float8":
82
+ from optimum.quanto import float8_dynamic_activation_float8_weight
83
+ quantize(fake_model, float8_dynamic_activation_float8_weight())
84
+ else:
85
+ raise ValueError(f"Unsupported qtype: {qtype}")
86
+
87
+ merged_state_dict = {
88
+ k: v.dequantize().half() if isinstance(v, QTensor) else v
89
+ for k, v in fake_model.state_dict().items()
90
+ }
91
+ elif quantize and not QUANTO_AVAILABLE:
92
+ return None, "❌ optimum.quanto not installed. Install with: pip install optimum-quanto"
93
+
94
+ # Step 6: Save model
95
+ yield "💾 Exporting model...", "Saving weights..."
96
+
97
+ ext = ".bin" if save_format == "bin" else ".safetensors"
98
+
99
+ # Build filename based on options
100
+ prefix = ""
101
+ if quantize and qtype != "none":
102
+ prefix = f"{qtype}_"
103
+
104
+ out_path = SCRIPT_DIR / f"merged_{prefix}checkpoint{ext}"
105
+
106
+ # Save appropriately
107
+ if ext == ".bin":
108
+ torch.save(merged_state_dict, str(out_path))
109
+ else:
110
+ save_file(merged_state_dict, str(out_path))
111
+
112
+ size_gb = out_path.stat().st_size / 1024**3
113
+
114
+ if quantize and qtype != "none":
115
+ msg = f"✅ Quantized checkpoint saved: `{out_path}` ({size_gb:.2f} GB)"
116
+ else:
117
+ msg = f"✅ Merged checkpoint saved: `{out_path}` ({size_gb:.2f} GB)"
118
+
119
+ yield "💾 Exporting model...", msg
120
+ return str(out_path), msg
121
+
122
+ except ImportError as e:
123
+ return None, f"❌ Missing dependency: {str(e)}"
124
+ except Exception as e:
125
+ return None, f"❌ Export failed: {str(e)}"
126
+
127
+
128
+ def get_export_status() -> str:
129
+ """Get current export capability status."""
130
+ try:
131
+ from optimum.quanto import quantize
132
+ return "✅ optimum.quanto available for quantization"
133
+ except ImportError:
134
+ return "ℹ️ Install optimum-quanto for quantization support"
src/generator.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Image generation functions for SDXL Model Merger."""
2
+
3
+ import torch
4
+
5
+ from .config import device, dtype, pipe as global_pipe
6
+ from .pipeline import enable_seamless_tiling
7
+
8
+
9
+ def generate_image(
10
+ prompt: str,
11
+ negative_prompt: str,
12
+ cfg: float,
13
+ steps: int,
14
+ height: int,
15
+ width: int,
16
+ tile_x: bool = True,
17
+ tile_y: bool = False,
18
+ ) -> tuple[object | None, str]:
19
+ """
20
+ Generate an image using the loaded SDXL pipeline.
21
+
22
+ Args:
23
+ prompt: Positive prompt for image generation
24
+ negative_prompt: Negative prompt to avoid certain elements
25
+ cfg: Classifier-Free Guidance scale (1.0-20.0)
26
+ steps: Number of inference steps (1-50)
27
+ height: Output image height in pixels
28
+ width: Output image width in pixels
29
+ tile_x: Enable seamless tiling on x-axis
30
+ tile_y: Enable seamless tiling on y-axis
31
+
32
+ Returns:
33
+ Tuple of (PIL Image or None, status message)
34
+ """
35
+ if not global_pipe:
36
+ return None, "⚠️ Please load a pipeline first."
37
+
38
+ # Enable seamless tiling on UNet & VAE decoder
39
+ enable_seamless_tiling(global_pipe.unet, tile_x=tile_x, tile_y=tile_y)
40
+ enable_seamless_tiling(global_pipe.vae.decoder, tile_x=tile_x, tile_y=tile_y)
41
+
42
+ yield "🎨 Generating image...", f"Steps: 0/{steps} | CFG: {cfg}"
43
+
44
+ generator = torch.Generator(device=device).manual_seed(42) # Fixed seed for reproducibility
45
+ result = global_pipe(
46
+ prompt=prompt,
47
+ negative_prompt=negative_prompt,
48
+ width=int(width),
49
+ height=int(height),
50
+ num_inference_steps=int(steps),
51
+ guidance_scale=float(cfg),
52
+ generator=generator,
53
+ )
54
+
55
+ image = result.images[0]
56
+ yield "🎨 Generating image...", f"✅ Complete! ({width}x{height})"
57
+
58
+
59
+ def set_pipeline(pipe):
60
+ """Set the global pipeline instance."""
61
+ global global_pipe
62
+ global_pipe = pipe
src/pipeline.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Pipeline management for SDXL Model Merger."""
2
+
3
+ from pathlib import Path
4
+
5
+ import torch
6
+ from diffusers import (
7
+ StableDiffusionXLPipeline,
8
+ AutoencoderKL,
9
+ DPMSolverSDEScheduler,
10
+ )
11
+
12
+ from .config import device, dtype, pipe as global_pipe, CACHE_DIR, download_cancelled, device_description
13
+ from .downloader import download_file_with_progress, get_safe_filename_from_url
14
+
15
+
16
+ def _make_asymmetric_forward(module, pad_h: int, pad_w: int, tile_x: bool, tile_y: bool):
17
+ """Create patched forward for seamless tiling on Conv2d layers."""
18
+ original_forward = module._conv_forward
19
+
20
+ def patched_conv_forward(input, weight, bias):
21
+ if tile_x and tile_y:
22
+ input = torch.nn.functional.pad(input, (pad_w, pad_w, pad_h, pad_h), mode="circular")
23
+ elif tile_x:
24
+ input = torch.nn.functional.pad(input, (pad_w, pad_w, 0, 0), mode="circular")
25
+ input = torch.nn.functional.pad(input, (0, 0, pad_h, pad_h), mode="constant", value=0)
26
+ elif tile_y:
27
+ input = torch.nn.functional.pad(input, (0, 0, pad_h, pad_h), mode="circular")
28
+ input = torch.nn.functional.pad(input, (pad_w, pad_w, 0, 0), mode="constant", value=0)
29
+ else:
30
+ return original_forward(input, weight, bias)
31
+
32
+ return torch.nn.functional.conv2d(
33
+ input, weight, bias, module.stride, (0, 0), module.dilation, module.groups
34
+ )
35
+
36
+ return patched_conv_forward
37
+
38
+
39
+ def enable_seamless_tiling(model, tile_x: bool = True, tile_y: bool = False):
40
+ """
41
+ Enable seamless tiling on a model's Conv2d layers.
42
+
43
+ Args:
44
+ model: PyTorch model with Conv2d layers (e.g., pipe.unet, pipe.vae.decoder)
45
+ tile_x: Enable tiling along x-axis
46
+ tile_y: Enable tiling along y-axis
47
+ """
48
+ for module in model.modules():
49
+ if isinstance(module, torch.nn.Conv2d):
50
+ pad_h = module.padding[0]
51
+ pad_w = module.padding[1]
52
+ if pad_h == 0 and pad_w == 0:
53
+ continue
54
+ module._conv_forward = _make_asymmetric_forward(module, pad_h, pad_w, tile_x, tile_y)
55
+
56
+
57
+ def load_pipeline(
58
+ checkpoint_url: str,
59
+ vae_url: str,
60
+ lora_urls_str: str,
61
+ lora_strengths_str: str,
62
+ progress=None
63
+ ) -> tuple[str, str]:
64
+ """
65
+ Load SDXL pipeline with checkpoint, VAE, and LoRAs.
66
+
67
+ Args:
68
+ checkpoint_url: URL to base model .safetensors file
69
+ vae_url: Optional URL to VAE .safetensors file
70
+ lora_urls_str: Newline-separated URLs for LoRA models
71
+ lora_strengths_str: Comma-separated strength values for each LoRA
72
+ progress: Optional gr.Progress() object for UI updates
73
+
74
+ Returns:
75
+ Tuple of (final_status_message, progress_text)
76
+ """
77
+ global global_pipe, download_cancelled
78
+
79
+ try:
80
+ checkpoint_filename = get_safe_filename_from_url(checkpoint_url, type_prefix="model")
81
+ checkpoint_path = CACHE_DIR / checkpoint_filename
82
+
83
+ # VAE: Use suffix="_vae" and default to "vae.safetensors" for proper caching/dropdown matching
84
+ vae_filename = get_safe_filename_from_url(vae_url, default_name="vae.safetensors", suffix="_vae") if vae_url.strip() else "vae.safetensors"
85
+ vae_path = CACHE_DIR / vae_filename
86
+
87
+ # Download checkpoint
88
+ if progress:
89
+ progress(0.1, desc="Downloading base model...")
90
+ yield f"📥 Downloading {checkpoint_path.name}...", "Starting download..."
91
+ download_file_with_progress(checkpoint_url, checkpoint_path)
92
+
93
+ # Download VAE if provided
94
+ if vae_url.strip():
95
+ if progress:
96
+ progress(0.2, desc="Downloading VAE...")
97
+ yield f"📥 Downloading {vae_path.name}...", f"Downloading VAE: {vae_path.name}"
98
+ download_file_with_progress(vae_url, vae_path)
99
+ vae = AutoencoderKL.from_single_file(str(vae_path), torch_dtype=dtype)
100
+ else:
101
+ vae = None
102
+
103
+ # Load base pipeline
104
+ if progress:
105
+ progress(0.4, desc="Loading SDXL pipeline...")
106
+ yield f"⚙️ Loading pipeline...", f"Using device: {device_description}"
107
+ global_pipe = StableDiffusionXLPipeline.from_single_file(
108
+ str(checkpoint_path),
109
+ torch_dtype=dtype,
110
+ use_safetensors=True,
111
+ safety_checker=None,
112
+ )
113
+ if vae:
114
+ global_pipe.vae = vae.to(device=device, dtype=dtype)
115
+
116
+ # Parse LoRA URLs & ensure strengths list matches
117
+ lora_urls = [u.strip() for u in lora_urls_str.split("\n") if u.strip()]
118
+ strengths_raw = [s.strip() for s in lora_strengths_str.split(",")]
119
+ strengths = []
120
+ for i, url in enumerate(lora_urls):
121
+ try:
122
+ val = float(strengths_raw[i]) if i < len(strengths_raw) else 1.0
123
+ strengths.append(val)
124
+ except ValueError:
125
+ strengths.append(1.0)
126
+
127
+ # Load and fuse each LoRA sequentially (only if URLs exist)
128
+ if lora_urls:
129
+ first_lora_filename = get_safe_filename_from_url(lora_urls[0], "lora_0.safetensors", suffix="_lora")
130
+ first_lora_path = CACHE_DIR / first_lora_filename
131
+ yield f"📥 Downloading LoRA: {first_lora_path.name}...", f"Downloading LoRA 1/... ({first_lora_path.name})..."
132
+ download_file_with_progress(lora_urls[0], first_lora_path)
133
+
134
+ global_pipe.load_lora_weights(str(first_lora_path), adapter_name="main_lora")
135
+ global_pipe.fuse_lora(adapter_names=["main_lora"], lora_scale=strengths[0])
136
+
137
+ for i in range(1, len(lora_urls)):
138
+ lora_filename = get_safe_filename_from_url(lora_urls[i], f"lora_{i}.safetensors", suffix="_lora")
139
+ lora_path = CACHE_DIR / lora_filename
140
+ yield f"📥 Downloading LoRA {i+1}...", f"Downloading LoRA {i+1}/{len(lora_urls)} ({lora_path.name})..."
141
+ download_file_with_progress(lora_urls[i], lora_path)
142
+
143
+ global_pipe.unload_lora_weights()
144
+ global_pipe.load_lora_weights(str(lora_path), adapter_name=f"lora_{i}")
145
+ # Fuse all loaded adapters so far
146
+ global_pipe.fuse_lora(
147
+ adapter_names=["main_lora"] + [f"lora_{j}" for j in range(1, i+1)],
148
+ lora_scale=strengths[i]
149
+ )
150
+
151
+ # Set scheduler and move to device (do this once at the end)
152
+ yield "⚙️ Finalizing...", "Setting up scheduler..."
153
+ # Use existing scheduler, just update algorithm_type for DPM++ SDE
154
+ global_pipe.scheduler.config.algorithm_type = "sde-dpmsolver++"
155
+ global_pipe = global_pipe.to(device=device, dtype=dtype)
156
+
157
+ return ("✅ Pipeline loaded successfully!", f"Ready! Loaded {len(lora_urls)} LoRA(s)")
158
+
159
+ except KeyboardInterrupt:
160
+ download_cancelled = False
161
+ return ("⚠️ Download cancelled by user", "Cancelled")
162
+ except Exception as e:
163
+ return (f"❌ Error loading pipeline: {str(e)}", f"Error: {str(e)}")
164
+
165
+
166
+ def cancel_download():
167
+ """Set the global cancellation flag to stop any ongoing downloads."""
168
+ global download_cancelled
169
+ download_cancelled = True
170
+
171
+
172
+ def get_pipeline() -> StableDiffusionXLPipeline | None:
173
+ """Get the currently loaded pipeline."""
174
+ return global_pipe
src/ui/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """UI components for SDXL Model Merger."""
2
+
3
+ from .header import create_header
4
+ from .loader_tab import create_loader_tab
5
+ from .generator_tab import create_generator_tab
6
+ from .exporter_tab import create_exporter_tab
7
+
8
+ __all__ = [
9
+ "create_header",
10
+ "create_loader_tab",
11
+ "create_generator_tab",
12
+ "create_exporter_tab",
13
+ ]
src/ui/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (510 Bytes). View file
 
src/ui/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (477 Bytes). View file
 
src/ui/__pycache__/exporter_tab.cpython-311.pyc ADDED
Binary file (5.23 kB). View file
 
src/ui/__pycache__/exporter_tab.cpython-313.pyc ADDED
Binary file (4.41 kB). View file
 
src/ui/__pycache__/generator_tab.cpython-311.pyc ADDED
Binary file (5.4 kB). View file
 
src/ui/__pycache__/generator_tab.cpython-313.pyc ADDED
Binary file (4.51 kB). View file
 
src/ui/__pycache__/header.cpython-311.pyc ADDED
Binary file (5.51 kB). View file
 
src/ui/__pycache__/header.cpython-313.pyc ADDED
Binary file (5.09 kB). View file
 
src/ui/__pycache__/loader_tab.cpython-311.pyc ADDED
Binary file (3.68 kB). View file
 
src/ui/__pycache__/loader_tab.cpython-313.pyc ADDED
Binary file (3.07 kB). View file
 
src/ui/exporter_tab.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Model exporter tab for SDXL Model Merger."""
2
+
3
+ import gradio as gr
4
+
5
+ from ..exporter import export_merged_model
6
+
7
+
8
+ def create_exporter_tab():
9
+ """Create the model export tab with all configuration options."""
10
+
11
+ with gr.Accordion("📦 3. Export Merged Model", open=True, elem_classes=["feature-card"]):
12
+ # Export settings
13
+ with gr.Row():
14
+ include_lora = gr.Checkbox(
15
+ True,
16
+ label="Include Fused LoRAs",
17
+ info="Bake the loaded LoRAs into the exported model"
18
+ )
19
+
20
+ quantize_toggle = gr.Checkbox(
21
+ False,
22
+ label="Apply Quantization",
23
+ info="Reduce model size with quantization"
24
+ )
25
+
26
+ # Quantization options
27
+ with gr.Row(visible=True) as qtype_row:
28
+ qtype_dropdown = gr.Dropdown(
29
+ choices=["none", "int8", "int4", "float8"],
30
+ value="int8",
31
+ label="Quantization Method",
32
+ info="Trade quality for smaller file size"
33
+ )
34
+
35
+ # Format options
36
+ with gr.Row():
37
+ format_dropdown = gr.Dropdown(
38
+ choices=["safetensors", "bin"],
39
+ value="safetensors",
40
+ label="Export Format",
41
+ info="safetensors is recommended for safety"
42
+ )
43
+
44
+ # Export button and output
45
+ with gr.Row():
46
+ export_btn = gr.Button("💾 Save Merged Checkpoint", variant="primary", size="lg")
47
+
48
+ with gr.Row():
49
+ download_link = gr.File(
50
+ label="Download Merged File",
51
+ show_label=True,
52
+ )
53
+
54
+ with gr.Column():
55
+ export_status = gr.Textbox(
56
+ label="Export Status",
57
+ placeholder="Ready to export..."
58
+ )
59
+
60
+ # Info about quantization
61
+ gr.HTML("""
62
+ <div style="margin-top: 16px; padding: 12px; background: #e0f2fe; border-radius: 8px;">
63
+ <strong>ℹ️ About Quantization:</strong>
64
+ <p style="font-size: 0.9em; margin: 8px 0;">
65
+ Reduces model size by lowering precision. Int8 is typically
66
+ lossless for inference while cutting size in half.
67
+ </p>
68
+ </div>
69
+ """)
70
+
71
+ return (
72
+ include_lora, quantize_toggle, qtype_dropdown, format_dropdown,
73
+ export_btn, download_link, export_status, qtype_row
74
+ )
75
+
76
+
77
+ def setup_exporter_events(
78
+ include_lora, quantize_toggle, qtype_dropdown, format_dropdown,
79
+ export_btn, download_link, export_status, qtype_row
80
+ ):
81
+ """Setup event handlers for the exporter tab."""
82
+
83
+ # Toggle quantization row visibility
84
+ quantize_toggle.change(
85
+ fn=lambda checked: gr.update(visible=checked),
86
+ inputs=[quantize_toggle],
87
+ outputs=qtype_row,
88
+ )
89
+
90
+ # Clear download link after use
91
+ def clear_download_link():
92
+ return None
93
+
94
+ export_btn.click(
95
+ fn=lambda inc, q, qt, fmt: export_merged_model(
96
+ include_lora=inc,
97
+ quantize=q and (qt != "none"),
98
+ qtype=qt if qt != "none" else None,
99
+ save_format=fmt,
100
+ ),
101
+ inputs=[include_lora, quantize_toggle, qtype_dropdown, format_dropdown],
102
+ outputs=[download_link, export_status],
103
+ ).then(
104
+ fn=clear_download_link,
105
+ outputs=[download_link],
106
+ )
src/ui/generator_tab.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Image generator tab for SDXL Model Merger."""
2
+
3
+ import gradio as gr
4
+
5
+ from ..config import DEFAULT_PROMPT, DEFAULT_NEGATIVE_PROMPT
6
+ from ..generator import generate_image
7
+
8
+
9
+ def create_generator_tab():
10
+ """Create the image generation tab with all input controls."""
11
+
12
+ with gr.Accordion("🎨 2. Generate Image", open=True, elem_classes=["feature-card"]):
13
+ # Prompts section
14
+ with gr.Row():
15
+ with gr.Column(scale=1):
16
+ prompt = gr.Textbox(
17
+ label="Positive Prompt",
18
+ value=DEFAULT_PROMPT,
19
+ lines=3,
20
+ placeholder="Describe the image you want to generate..."
21
+ )
22
+
23
+ cfg = gr.Slider(
24
+ minimum=1.0, maximum=20.0, value=7.5, step=0.5,
25
+ label="CFG Scale",
26
+ info="Higher values make outputs match prompt more strictly"
27
+ )
28
+
29
+ height = gr.Number(
30
+ value=1024, precision=0,
31
+ label="Height (pixels)",
32
+ info="Output image height"
33
+ )
34
+
35
+ with gr.Column(scale=1):
36
+ negative_prompt = gr.Textbox(
37
+ label="Negative Prompt",
38
+ value=DEFAULT_NEGATIVE_PROMPT,
39
+ lines=3,
40
+ placeholder="Elements to avoid in generation..."
41
+ )
42
+
43
+ steps = gr.Slider(
44
+ minimum=1, maximum=100, value=25, step=1,
45
+ label="Inference Steps",
46
+ info="More steps = better quality but slower"
47
+ )
48
+
49
+ width = gr.Number(
50
+ value=2048, precision=0,
51
+ label="Width (pixels)",
52
+ info="Output image width"
53
+ )
54
+
55
+ # Tiling options
56
+ with gr.Row():
57
+ tile_x = gr.Checkbox(True, label="X-axis Seamless Tiling")
58
+ tile_y = gr.Checkbox(False, label="Y-axis Seamless Tiling")
59
+
60
+ # Generate button and outputs
61
+ with gr.Row():
62
+ gen_btn = gr.Button("✨ Generate Image", variant="secondary", size="lg")
63
+
64
+ with gr.Row():
65
+ image_output = gr.Image(
66
+ label="Result",
67
+ height=400,
68
+ show_label=True
69
+ )
70
+ with gr.Column():
71
+ gen_status = gr.Textbox(
72
+ label="Generation Status",
73
+ placeholder="Ready to generate..."
74
+ )
75
+
76
+ # Quick tips
77
+ gr.HTML("""
78
+ <div style="margin-top: 16px; padding: 12px; background: #f3f4f6; border-radius: 8px;">
79
+ <strong>💡 Tips:</strong>
80
+ <ul style="margin: 8px 0; padding-left: 20px; font-size: 0.9em;">
81
+ <li>Use wide aspect ratios (e.g., 1024x2048) for panoramas</li>
82
+ <li>Enable seamless tiling for texture-like outputs</li>
83
+ <li>Lower CFG (3-5) for more creative results</li>
84
+ </ul>
85
+ </div>
86
+ """)
87
+
88
+ return (
89
+ prompt, negative_prompt, cfg, steps, height, width,
90
+ tile_x, tile_y, gen_btn, image_output, gen_status
91
+ )
92
+
93
+
94
+ def setup_generator_events(
95
+ prompt, negative_prompt, cfg, steps, height, width,
96
+ tile_x, tile_y, gen_btn, image_output, gen_status
97
+ ):
98
+ """Setup event handlers for the generator tab."""
99
+
100
+ gen_btn.click(
101
+ fn=generate_image,
102
+ inputs=[prompt, negative_prompt, cfg, steps, height, width, tile_x, tile_y],
103
+ outputs=[image_output, gen_status],
104
+ )
src/ui/header.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Header component with title and styling for SDXL Model Merger."""
2
+
3
+ import gradio as gr
4
+
5
+
6
+ def create_header():
7
+ """Create the header section with title, description, and custom styling."""
8
+
9
+ # Custom CSS for modern look
10
+ css = """
11
+ /* Header gradient text */
12
+ .header-gradient {
13
+ background: linear-gradient(135deg, #10b981 0%, #7c3aed 100%);
14
+ -webkit-background-clip: text;
15
+ -webkit-text-fill-color: transparent;
16
+ background-clip: text;
17
+ }
18
+
19
+ /* Feature cards */
20
+ .feature-card {
21
+ border-radius: 12px;
22
+ padding: 20px;
23
+ margin-bottom: 16px;
24
+ box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1);
25
+ transition: transform 0.2s ease;
26
+ }
27
+
28
+ .feature-card:hover {
29
+ transform: translateY(-2px);
30
+ box-shadow: 0 10px 15px -3px rgba(0, 0, 0, 0.1);
31
+ }
32
+
33
+ /* Label styling */
34
+ .gradio-container .label {
35
+ font-weight: 600;
36
+ color: #374151;
37
+ margin-bottom: 8px;
38
+ }
39
+
40
+ /* Status message colors */
41
+ .status-success {
42
+ color: #059669 !important;
43
+ font-weight: 600;
44
+ }
45
+ .status-error {
46
+ color: #dc2626 !important;
47
+ font-weight: 600;
48
+ }
49
+ .status-warning {
50
+ color: #d97706 !important;
51
+ font-weight: 600;
52
+ }
53
+
54
+ /* Button improvements */
55
+ .gradio-container .btn {
56
+ border-radius: 8px;
57
+ padding: 12px 24px;
58
+ font-weight: 600;
59
+ }
60
+
61
+ /* Input field styling */
62
+ .gradio-container textarea,
63
+ .gradio-container input[type="number"],
64
+ .gradio-container input[type="text"] {
65
+ border-radius: 8px;
66
+ border-color: #d1d5db;
67
+ }
68
+
69
+ .gradio-container textarea:focus,
70
+ .gradio-container input:focus {
71
+ outline: none;
72
+ border-color: #6366f1;
73
+ box-shadow: 0 0 0 3px rgba(99, 102, 241, 0.1);
74
+ }
75
+
76
+ /* Tab styling */
77
+ .gradio-container .tabitem {
78
+ background: transparent;
79
+ border-radius: 12px;
80
+ }
81
+
82
+ /* Progress bar improvements */
83
+ .gradio-container .progress-bar {
84
+ border-radius: 8px;
85
+ overflow: hidden;
86
+ }
87
+ """
88
+
89
+ with gr.Column(elem_classes=["feature-card"]):
90
+ gr.HTML("""
91
+ <div style="text-align: center; margin-bottom: 24px;">
92
+ <h1 style="font-size: 2.5em; margin: 0; line-height: 1.2;">
93
+ <span class="header-gradient">SDXL Model Merger</span>
94
+ </h1>
95
+ <p style="color: #6b7280; font-size: 1.1em; max-width: 600px; margin: 16px auto;">
96
+ Merge checkpoints, LoRAs, and VAEs — then bake LoRAs into a single exportable
97
+ checkpoint with optional quantization.
98
+ </p>
99
+ </div>
100
+ """)
101
+
102
+ # Feature highlights
103
+ with gr.Row():
104
+ with gr.Column(scale=1):
105
+ gr.HTML("""
106
+ <div style="text-align: center; padding: 16px;">
107
+ <div style="font-size: 2.5em; margin-bottom: 8px;">🚀</div>
108
+ <strong>Fast Loading</strong>
109
+ <p style="font-size: 0.85em; color: #6b7280; margin-top: 4px;">With progress tracking & cache</p>
110
+ </div>
111
+ """)
112
+ with gr.Column(scale=1):
113
+ gr.HTML("""
114
+ <div style="text-align: center; padding: 16px;">
115
+ <div style="font-size: 2.5em; margin-bottom: 8px;">🎨</div>
116
+ <strong>Panorama Gen</strong>
117
+ <p style="font-size: 0.85em; color: #6b7280; margin-top: 4px;">Seamless tiling support</p>
118
+ </div>
119
+ """)
120
+ with gr.Column(scale=1):
121
+ gr.HTML("""
122
+ <div style="text-align: center; padding: 16px;">
123
+ <div style="font-size: 2.5em; margin-bottom: 8px;">📦</div>
124
+ <strong>Export Ready</strong>
125
+ <p style="font-size: 0.85em; color: #6b7280; margin-top: 4px;">Quantization & format options</p>
126
+ </div>
127
+ """)
128
+
129
+ return css
src/ui/loader_tab.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Pipeline loader tab for SDXL Model Merger."""
2
+
3
+ import gradio as gr
4
+
5
+ from ..config import (
6
+ DEFAULT_CHECKPOINT_URL,
7
+ DEFAULT_VAE_URL,
8
+ DEFAULT_LORA_URLS,
9
+ )
10
+ from ..pipeline import load_pipeline, cancel_download
11
+
12
+
13
+ def create_loader_tab():
14
+ """Create the pipeline loading tab with all input controls."""
15
+
16
+ with gr.Accordion("⚙️ 1. Load Pipeline", open=True, elem_classes=["feature-card"]):
17
+ with gr.Row():
18
+ with gr.Column(scale=2):
19
+ # Checkpoint URL
20
+ checkpoint_url = gr.Textbox(
21
+ label="Base Model (.safetensors) URL",
22
+ value=DEFAULT_CHECKPOINT_URL,
23
+ placeholder="e.g., https://civitai.com/api/download/models/...",
24
+ info="Download link for the base SDXL checkpoint"
25
+ )
26
+
27
+ # VAE URL (optional)
28
+ vae_url = gr.Textbox(
29
+ label="VAE (.safetensors) URL",
30
+ value=DEFAULT_VAE_URL,
31
+ placeholder="Leave blank to use model's built-in VAE",
32
+ info="Optional custom VAE for improved quality"
33
+ )
34
+
35
+ with gr.Column(scale=1):
36
+ # LoRA URLs
37
+ lora_urls = gr.Textbox(
38
+ label="LoRA URLs (one per line)",
39
+ lines=5,
40
+ value=DEFAULT_LORA_URLS,
41
+ placeholder="https://civit.ai/...\nhttps://huggingface.co/...",
42
+ info="Multiple LoRAs can be loaded and fused together"
43
+ )
44
+
45
+ # LoRA strengths
46
+ lora_strengths = gr.Textbox(
47
+ label="LoRA Strengths",
48
+ value="1.0",
49
+ placeholder="e.g., 0.8,1.0,0.5",
50
+ info="Comma-separated strength values for each LoRA"
51
+ )
52
+
53
+ # Action buttons
54
+ with gr.Row():
55
+ load_btn = gr.Button("🚀 Load Pipeline", variant="primary", size="lg")
56
+ cancel_btn = gr.Button("🛑 Cancel Download", variant="stop", size="lg")
57
+
58
+ # Status output
59
+ load_status = gr.Textbox(
60
+ label="Status",
61
+ placeholder="Ready to load pipeline...",
62
+ show_label=True,
63
+ )
64
+
65
+ return (
66
+ checkpoint_url, vae_url, lora_urls, lora_strengths,
67
+ load_btn, cancel_btn, load_status
68
+ )
69
+
70
+
71
+ def setup_loader_events(
72
+ checkpoint_url, vae_url, lora_urls, lora_strengths,
73
+ load_btn, cancel_btn, load_status
74
+ ):
75
+ """Setup event handlers for the loader tab."""
76
+
77
+ load_btn.click(
78
+ fn=load_pipeline,
79
+ inputs=[checkpoint_url, vae_url, lora_urls, lora_strengths],
80
+ outputs=load_status,
81
+ )
82
+
83
+ cancel_btn.click(fn=cancel_download)