#!/usr/bin/env python3
"""
SDXL Model Merger - Modernized with modular architecture and improved UI/UX.

This application allows you to:
- Load SDXL checkpoints with optional VAE and multiple LoRAs
- Generate images with seamless tiling support
- Export merged models with quantization options

Author: Qwen Code Assistant
"""

import html

try:
    import spaces  # noqa: F401 — must be imported before torch/CUDA packages
except ImportError:
    pass

import gradio as gr

# Placeholder entry shown at the top of every "Cached ..." dropdown.
# Selecting it is a no-op: it never overwrites the URL textboxes.
NONE_FOUND = "(None found)"

# Page styles, applied via gr.Blocks(css=...) in create_app().
HEADER_CSS = """
.header-gradient {
    background: linear-gradient(135deg, #10b981 0%, #7c3aed 100%);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
    background-clip: text;
}
.feature-card {
    border-radius: 12px;
    padding: 20px;
    margin-bottom: 16px;
    box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1);
    transition: transform 0.2s ease;
}
.feature-card:hover {
    transform: translateY(-2px);
    box-shadow: 0 10px 15px -3px rgba(0, 0, 0, 0.1);
}
.gradio-container .label {
    font-weight: 600;
    color: #374151;
    margin-bottom: 8px;
}
.status-success { color: #059669 !important; font-weight: 600; }
.status-error { color: #dc2626 !important; font-weight: 600; }
.status-warning { color: #d97706 !important; font-weight: 600; }
.gradio-container .btn {
    border-radius: 8px;
    padding: 12px 24px;
    font-weight: 600;
}
.gradio-container textarea,
.gradio-container input[type="number"],
.gradio-container input[type="text"] {
    border-radius: 8px;
    border-color: #d1d5db;
}
.gradio-container textarea:focus,
.gradio-container input:focus {
    outline: none;
    border-color: #6366f1;
    box-shadow: 0 0 0 3px rgba(99, 102, 241, 0.1);
}
.gradio-container .tabitem {
    background: transparent;
    border-radius: 12px;
}
.progress-text { font-weight: 500; color: #6b7280 !important; }
"""

# (icon, title, blurb) for the feature-highlight cards under the header.
_FEATURES = (
    ("🚀", "Fast Loading", "With progress tracking & cache"),
    ("🎨", "Panorama Gen", "Seamless tiling support"),
    ("📦", "Export Ready", "Quantization & format options"),
)


def _status_html(message, css_class="status-success"):
    """Wrap *message* in a ``<div>`` styled by one of the classes in HEADER_CSS.

    *message* is inserted verbatim; callers must ``html.escape`` plain text.
    """
    return f'<div class="{css_class}">{message}</div>'


def _cached_choices(items):
    """Return dropdown choices: the NONE_FOUND placeholder followed by *items*."""
    return [NONE_FOUND, *items]


def _cached_dropdown(items, label, info):
    """Build a "Cached ..." dropdown; preselect the placeholder only if *items* is empty."""
    return gr.Dropdown(
        choices=_cached_choices(items),
        label=label,
        value=None if items else NONE_FOUND,
        info=info,
    )


def _is_selection(value):
    """Return True when *value* is a real cached path (not empty, not the placeholder)."""
    return bool(value) and value != NONE_FOUND


def _on_cached_file_selected(cached_path):
    """Point a URL textbox at the selected cached file; leave it alone otherwise."""
    if _is_selection(cached_path):
        return gr.update(value=f"file://{cached_path}")
    return gr.update()


def _on_cached_lora_selected(cached_path, current_urls):
    """Append the selected cached LoRA as a ``file://`` line, skipping duplicates.

    Blank lines and surrounding whitespace in the existing list are dropped.
    """
    if not _is_selection(cached_path):
        return gr.update()
    urls = [u.strip() for u in (current_urls or "").split("\n") if u.strip()]
    file_url = f"file://{cached_path}"
    if file_url not in urls:
        urls.append(file_url)
    return gr.update(value="\n".join(urls))


def _on_load_start():
    """Show an in-progress status and disable the load button."""
    return (
        _status_html("⏳ Loading started...", "progress-text"),
        "Starting download...",
        gr.update(interactive=False),
    )


def _on_load_complete(status_msg, progress_text):
    """Turn the loader's final status into a styled message and re-enable the button."""
    status_msg = status_msg or ""
    if "✅" in status_msg:
        status = _status_html("✅ Pipeline loaded successfully!")
    # Match the bare U+26A0 so both "⚠" and "⚠️" (with variation selector) count.
    elif "⚠" in status_msg or "cancelled" in status_msg.lower():
        status = _status_html("⚠️ Download cancelled", "status-warning")
    else:
        # load_pipeline writes into the gr.HTML status itself, so its message
        # is passed through as HTML rather than escaped.
        status = _status_html(status_msg, "status-error")
    return status, progress_text, gr.update(interactive=True)


def _on_generate_start():
    """Show an in-progress status and disable the generate button."""
    return (
        _status_html("⏳ Generating image...", "progress-text"),
        "Starting generation...",
        gr.update(interactive=False),
    )


def _on_generate_complete(image, status_msg):
    """Report success, or show the progress text as an error when no image was produced."""
    # `image is None` rather than `not image`: the Image component may hand
    # back an array, whose truth value is ambiguous.
    if image is None:
        return (
            _status_html(html.escape(status_msg or ""), "status-error"),
            "",
            gr.update(interactive=True),
        )
    return (
        _status_html("✅ Generation complete!"),
        "Done",
        gr.update(interactive=True),
    )


def _on_export_start():
    """Show an in-progress status and disable the export button."""
    return (
        _status_html("⏳ Export started...", "progress-text"),
        "Starting export...",
        gr.update(interactive=False),
    )


def _on_export_complete(file_path, status_msg):
    """Report success, or show the progress text as an error when no file was produced."""
    if not file_path:
        return (
            _status_html(html.escape(status_msg or ""), "status-error"),
            "",
            gr.update(interactive=True),
        )
    return (
        _status_html("✅ Export complete!"),
        "Exported successfully",
        gr.update(interactive=True),
    )


def _build_header():
    """Render the title banner and the feature-highlight cards."""
    with gr.Column(elem_classes=["feature-card"]):
        gr.HTML(
            '<h1 class="header-gradient">SDXL Model Merger</h1>'
            "<p>Merge checkpoints, LoRAs, and VAEs - then bake LoRAs into a single "
            "exportable checkpoint with optional quantization.</p>"
        )
    with gr.Row():
        for icon, title, blurb in _FEATURES:
            with gr.Column(scale=1):
                gr.HTML(
                    f'<div class="feature-card"><div>{icon}</div>'
                    f"<strong>{html.escape(title)}</strong>"
                    f"<p>{html.escape(blurb)}</p></div>"
                )


def _build_load_tab(load_pipeline, get_cached_checkpoints, get_cached_vaes, get_cached_loras):
    """Build the "Load Pipeline" tab and wire its events (call inside a Tab context)."""
    gr.Markdown("### Load SDXL Pipeline with Checkpoint, VAE, and LoRAs")

    load_progress = gr.Textbox(
        label="Loading Progress",
        placeholder="Ready to start...",
        show_label=True,
        info="Real-time status of model downloads and pipeline setup",
    )

    with gr.Row():
        with gr.Column(scale=2):
            checkpoint_url = gr.Textbox(
                label="Base Model (.safetensors) URL",
                value="https://civitai.com/api/download/models/354657?type=Model&format=SafeTensor&size=full&fp=fp16",
                placeholder="e.g., https://civitai.com/api/download/models/...",
                info="Download link for the base SDXL checkpoint",
            )
            cached_checkpoints = _cached_dropdown(
                list(get_cached_checkpoints()),
                label="Cached Checkpoints",
                info="Models already downloaded to .cache/",
            )
            vae_url = gr.Textbox(
                label="VAE (.safetensors) URL",
                value="https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/resolve/main/sdxl.vae.safetensors?download=true",
                placeholder="Leave blank to use model's built-in VAE",
                info="Optional custom VAE for improved quality",
            )
            cached_vaes = _cached_dropdown(
                list(get_cached_vaes()),
                label="Cached VAEs",
                info="Select a VAE to load",
            )

        with gr.Column(scale=1):
            lora_urls = gr.Textbox(
                label="LoRA URLs (one per line)",
                lines=5,
                value="https://civitai.com/api/download/models/143197?type=Model&format=SafeTensor",
                placeholder="https://civit.ai/...\nhttps://huggingface.co/...",
                info="Multiple LoRAs can be loaded and fused together",
            )
            cached_loras = _cached_dropdown(
                list(get_cached_loras()),
                label="Cached LoRAs",
                info="Select a LoRA to add to the list above",
            )
            lora_strengths = gr.Textbox(
                label="LoRA Strengths",
                value="1.0",
                placeholder="e.g., 0.8,1.0,0.5",
                info="Comma-separated strength values for each LoRA",
            )

    with gr.Row():
        load_btn = gr.Button("🚀 Load Pipeline", variant="primary", size="lg")

    load_status = gr.HTML(
        label="Status",
        value=_status_html("✅ Ready to load pipeline"),
    )

    def refresh_cached_dropdowns():
        """Re-scan the cache so anything the load just stored becomes selectable."""
        return tuple(
            gr.update(choices=_cached_choices(list(getter())))
            for getter in (get_cached_checkpoints, get_cached_vaes, get_cached_loras)
        )

    # .then() runs even if the previous step raised, so the button is always re-enabled.
    load_btn.click(
        fn=_on_load_start,
        inputs=[],
        outputs=[load_status, load_progress, load_btn],
    ).then(
        fn=load_pipeline,
        inputs=[checkpoint_url, vae_url, lora_urls, lora_strengths],
        outputs=[load_status, load_progress],
        show_progress="full",
    ).then(
        fn=_on_load_complete,
        inputs=[load_status, load_progress],
        outputs=[load_status, load_progress, load_btn],
    ).then(
        fn=refresh_cached_dropdowns,
        inputs=[],
        outputs=[cached_checkpoints, cached_vaes, cached_loras],
    )

    cached_checkpoints.change(
        fn=_on_cached_file_selected,
        inputs=cached_checkpoints,
        outputs=checkpoint_url,
    )
    cached_vaes.change(
        fn=_on_cached_file_selected,
        inputs=cached_vaes,
        outputs=vae_url,
    )
    cached_loras.change(
        fn=_on_cached_lora_selected,
        inputs=[cached_loras, lora_urls],
        outputs=lora_urls,
    )


def _build_generate_tab(generate_image):
    """Build the "Generate Image" tab and wire its events (call inside a Tab context)."""
    gr.Markdown("### Generate Panorama Images with Seamless Tiling")

    gen_progress = gr.Textbox(
        label="Generation Progress",
        placeholder="Ready to generate...",
        show_label=True,
        info="Real-time status of image generation",
    )

    with gr.Row():
        with gr.Column(scale=1):
            prompt = gr.Textbox(
                label="Positive Prompt",
                value="Glowing mushrooms around pyramids amidst a cosmic backdrop, equirectangular, 360 panorama, cinematic",
                lines=4,
                placeholder="Describe the image you want to generate...",
            )
            cfg = gr.Slider(
                minimum=1.0,
                maximum=20.0,
                value=3.0,
                step=0.5,
                label="CFG Scale",
                info="Higher values make outputs match prompt more strictly",
            )
            height = gr.Number(
                value=1024,
                precision=0,
                label="Height (pixels)",
                info="Output image height",
            )

        with gr.Column(scale=1):
            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                value="boring, text, signature, watermark, low quality, bad quality",
                lines=4,
                placeholder="Elements to avoid in generation...",
            )
            steps = gr.Slider(
                minimum=1,
                maximum=100,
                value=8,
                step=1,
                label="Inference Steps",
                info="More steps = better quality but slower",
            )
            width = gr.Number(
                value=2048,
                precision=0,
                label="Width (pixels)",
                info="Output image width",
            )

    with gr.Row():
        tile_x = gr.Checkbox(value=True, label="X-axis Seamless Tiling")
        tile_y = gr.Checkbox(value=False, label="Y-axis Seamless Tiling")
        seed = gr.Number(
            value=80484030936239,
            precision=0,
            label="Seed",
            info="Random seed for reproducible generation",
        )

    with gr.Row():
        gen_btn = gr.Button("✨ Generate Image", variant="secondary", size="lg")

    with gr.Row():
        image_output = gr.Image(label="Result", height=400, show_label=True)
        with gr.Column():
            gen_status = gr.HTML(
                label="Generation Status",
                value=_status_html("✅ Ready to generate"),
            )
            gr.HTML(
                "<div><strong>💡 Tips:</strong><ul>"
                "<li>Use a 2:1 width-to-height ratio (e.g. 2048×1024) for "
                "equirectangular 360° panoramas.</li>"
                "<li>Enable X-axis tiling so the left and right edges wrap seamlessly.</li>"
                "<li>Reuse the same seed to reproduce a result.</li>"
                "</ul></div>"
            )

    gen_btn.click(
        fn=_on_generate_start,
        inputs=[],
        outputs=[gen_status, gen_progress, gen_btn],
    ).then(
        fn=generate_image,
        inputs=[prompt, negative_prompt, cfg, steps, height, width, tile_x, tile_y, seed],
        outputs=[image_output, gen_progress],
    ).then(
        fn=_on_generate_complete,
        inputs=[image_output, gen_progress],
        outputs=[gen_status, gen_progress, gen_btn],
    )


def _build_export_tab(export_merged_model):
    """Build the "Export Model" tab and wire its events (call inside a Tab context)."""
    gr.Markdown("### Export Merged Checkpoint with Quantization Options")

    export_progress = gr.Textbox(
        label="Export Progress",
        placeholder="Ready to export...",
        show_label=True,
        info="Real-time status of model export and quantization",
    )

    with gr.Row():
        include_lora = gr.Checkbox(
            value=True,
            label="Include Fused LoRAs",
            info="Bake the loaded LoRAs into the exported model",
        )
        quantize_toggle = gr.Checkbox(
            value=False,
            label="Apply Quantization",
            info="Reduce model size with quantization",
        )

    # Hidden until the quantization toggle (default off) is checked.
    qtype_row = gr.Row(visible=False)
    with qtype_row:
        qtype_dropdown = gr.Dropdown(
            choices=["none", "int8", "int4", "float8"],
            value="int8",
            label="Quantization Method",
            info="Trade quality for smaller file size",
        )

    with gr.Row():
        format_dropdown = gr.Dropdown(
            choices=["safetensors", "bin"],
            value="safetensors",
            label="Export Format",
            info="safetensors is recommended for safety",
        )

    with gr.Row():
        export_btn = gr.Button("💾 Save Merged Checkpoint", variant="primary", size="lg")

    with gr.Row():
        download_link = gr.File(label="Download Merged File", show_label=True)
        with gr.Column():
            export_status = gr.HTML(
                label="Export Status",
                value=_status_html("✅ Ready to export"),
            )
            gr.HTML(
                "<div><strong>ℹ️ About Quantization:</strong>"
                "<p>Reduces model size by lowering precision. Int8 is usually "
                "near-lossless for inference while cutting size in half.</p></div>"
            )

    def run_export(include, quantize, qtype, save_format):
        """Call the exporter; choosing "none" as the method disables quantization."""
        return export_merged_model(
            include_lora=include,
            quantize=quantize and qtype != "none",
            qtype=qtype,  # always pass the string value; exporter handles "none" correctly
            save_format=save_format,
        )

    export_btn.click(
        fn=_on_export_start,
        inputs=[],
        outputs=[export_status, export_progress, export_btn],
    ).then(
        fn=run_export,
        inputs=[include_lora, quantize_toggle, qtype_dropdown, format_dropdown],
        outputs=[download_link, export_progress],
    ).then(
        fn=_on_export_complete,
        inputs=[download_link, export_progress],
        outputs=[export_status, export_progress, export_btn],
    )

    quantize_toggle.change(
        fn=lambda checked: gr.update(visible=checked),
        inputs=[quantize_toggle],
        outputs=qtype_row,
    )


def create_app():
    """Create and configure the Gradio app.

    Backend functions from ``src`` are imported when the app is built, not at
    module import time. Event listeners are registered by the tab builders
    while the ``gr.Blocks`` context below is active.

    Returns:
        The configured ``gr.Blocks`` instance (not yet launched).
    """
    from src.pipeline import load_pipeline
    from src.generator import generate_image
    from src.exporter import export_merged_model
    from src.config import get_cached_checkpoints, get_cached_vaes, get_cached_loras

    with gr.Blocks(title="SDXL Model Merger", css=HEADER_CSS) as demo:
        _build_header()
        gr.Markdown("---")

        with gr.Tab("Load Pipeline"):
            _build_load_tab(
                load_pipeline, get_cached_checkpoints, get_cached_vaes, get_cached_loras
            )
        with gr.Tab("Generate Image"):
            _build_generate_tab(generate_image)
        with gr.Tab("Export Model"):
            _build_export_tab(export_merged_model)

    return demo


# Built at import time so hosts that import this module (e.g. Spaces) find `demo`.
demo = create_app()

if __name__ == "__main__":
    demo.launch()