Spaces:
Running on Zero
Running on Zero
| """ | |
| Juggernaut Z Image Generation Demo | |
| ZeroGPU Space for RunDiffusion/Juggernaut-Z-Image | |
| """ | |
| import spaces | |
| import random | |
| import re | |
| import torch | |
| import gradio as gr | |
| from diffusers import ZImagePipeline | |
# ==================== Configuration ====================
# Base Z-Image pipeline; the Juggernaut-Z transformer weights are loaded on top of it.
BASE_MODEL = "Tongyi-MAI/Z-Image"

# ==================== Resolution Choices ====================
# Preset "WIDTHxHEIGHT ( ratio )" labels, grouped by approximate size tier.
# get_resolution() extracts the numeric width/height from these labels.
RES_CHOICES = {
    "720": [
        "720x720 ( 1:1 )",
        "896x512 ( 16:9 )",
        "512x896 ( 9:16 )",
        "832x544 ( 3:2 )",
        "544x832 ( 2:3 )",
        "800x576 ( 4:3 )",
        "576x800 ( 3:4 )",
    ],
    "1024": [
        "1024x1024 ( 1:1 )",
        "1152x896 ( 9:7 )",
        "896x1152 ( 7:9 )",
        "1152x864 ( 4:3 )",
        "864x1152 ( 3:4 )",
        "1248x832 ( 3:2 )",
        "832x1248 ( 2:3 )",
        "1280x720 ( 16:9 )",
        "720x1280 ( 9:16 )",
        "1344x576 ( 21:9 )",
        "576x1344 ( 9:21 )",
    ],
    "1280": [
        "1280x1280 ( 1:1 )",
        "1440x1120 ( 9:7 )",
        "1120x1440 ( 7:9 )",
        "1472x1104 ( 4:3 )",
        "1104x1472 ( 3:4 )",
        "1536x1024 ( 3:2 )",
        "1024x1536 ( 2:3 )",
        "1536x864 ( 16:9 )",
        "864x1536 ( 9:16 )",
        "1680x720 ( 21:9 )",
        "720x1680 ( 9:21 )",
    ],
}

# Every preset from every tier, flattened in tier order (dropdown choices).
RESOLUTION_SET = [label for tier_labels in RES_CHOICES.values() for label in tier_labels]

# One value per row because gr.Examples feeds a single input component (the prompt).
EXAMPLE_PROMPTS = [
    ["Cinematic portrait of a cyberpunk warrior, neon lights reflecting off chrome armor, rain-soaked streets, dramatic lighting, 8k, photorealistic"],
    ["Ethereal forest scene with bioluminescent mushrooms, misty atmosphere, magical lighting, fantasy art style"],
    ["Majestic mountain landscape at golden hour, snow-capped peaks, alpine lake reflection, cinematic photography"],
    ["Futuristic cityscape at night, flying cars, holographic billboards, cyberpunk aesthetic, highly detailed"],
    ["Portrait of an elegant woman in Victorian dress, ornate jewelry, soft natural lighting, studio portrait"],
]
# ==================== Helper Functions ====================
def get_resolution(resolution: str) -> tuple[int, int]:
    """Extract ``(width, height)`` from a label such as ``"1024x1024 ( 1:1 )"``.

    The first ``<digits> x <digits>`` run in the string is used. The separator
    may be an ASCII ``x`` or a ``×`` sign, with optional whitespace around it.
    A string containing no such run yields ``(1024, 1024)``.
    """
    found = re.search(r"(\d+)\s*[×x]\s*(\d+)", resolution)
    if found is None:
        return 1024, 1024
    width, height = (int(part) for part in found.groups())
    return width, height
# ==================== Model Loading (Global Context) ====================
# Runs once at import time: load the base pipeline, overlay the Juggernaut-Z
# transformer weights, then move the pipeline to CUDA.
print(f"Loading Z-Image pipeline from {BASE_MODEL}...")
pipe = ZImagePipeline.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.bfloat16,
)

# Load Juggernaut-Z transformer weights
print("Loading Juggernaut-Z fine-tuned weights...")
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# Download (or reuse the cached copy of) the safetensors checkpoint.
checkpoint_path = hf_hub_download(
    repo_id="RunDiffusion/Juggernaut-Z-Image",
    filename="Juggernaut_Z_V1_by_RunDiffusion.safetensors",
)
state_dict = load_file(checkpoint_path)

# Prefixes a checkpoint may place in front of the transformer's own parameter
# names. pipe.transformer.load_state_dict() expects the *bare* names (they are
# relative to the transformer module), so these prefixes must be stripped, not
# added. "model.diffusion_model." is a layout commonly used by single-file
# checkpoints. NOTE(review): confirm which (if any) this file actually uses.
_CHECKPOINT_KEY_PREFIXES = ("transformer.", "model.diffusion_model.")


def _to_transformer_key(key: str) -> str:
    """Return *key* with the first matching checkpoint prefix removed."""
    for prefix in _CHECKPOINT_KEY_PREFIXES:
        if key.startswith(prefix):
            return key[len(prefix):]
    return key


transformer_state_dict = {_to_transformer_key(k): v for k, v in state_dict.items()}
del state_dict  # only the remapped mapping is needed from here on

# strict=False (below) tolerates *partial* mismatches, but if nothing matches
# at all the fine-tune would silently not be applied. Fail loudly instead.
expected_keys = set(pipe.transformer.state_dict())
matched_keys = expected_keys.intersection(transformer_state_dict)
if not matched_keys:
    raise RuntimeError(
        "No tensor in Juggernaut_Z_V1_by_RunDiffusion.safetensors matches a "
        "pipe.transformer parameter name; the Juggernaut-Z weights would not "
        "be applied. Check the checkpoint's key naming / prefixes."
    )
print(f"Matched {len(matched_keys)} of {len(expected_keys)} transformer tensors.")

# Load into transformer with strict=False to handle partial/key mismatches.
missing, unexpected = pipe.transformer.load_state_dict(transformer_state_dict, strict=False)
print(f"Loaded Juggernaut-Z weights. Missing keys: {len(missing)}, Unexpected: {len(unexpected)}")
# load_state_dict copied the values into the module; free the CPU-side checkpoint.
del transformer_state_dict

pipe.to("cuda")
print("Pipeline loaded successfully with Juggernaut-Z fine-tune!")
# ==================== Generation Function ====================
# NOTE(review): uses the default ZeroGPU time budget; pass duration=... to
# spaces.GPU if 60-step runs at the 1280 tier get cut off.
@spaces.GPU
def generate(
    prompt: str,
    negative_prompt: str = "",
    resolution: str = "1024x1024 ( 1:1 )",
    seed: int = 42,
    num_inference_steps: int = 35,
    guidance_scale: float = 6.0,
    cfg_normalization: bool = False,
    progress=gr.Progress(track_tqdm=True),
):
    """Render one image with the Juggernaut-Z pipeline.

    Decorated with ``@spaces.GPU`` because ZeroGPU only attaches a GPU inside
    decorated functions. Without it, the CUDA generator and the pipeline
    call below have no device in a ZeroGPU Space.

    Args:
        prompt: Text prompt; must contain non-whitespace characters.
        negative_prompt: Optional text to steer away from. A blank or
            whitespace-only value means "no negative prompt".
        resolution: Label parsed by ``get_resolution`` into width and height.
        seed: Seed for the CUDA ``torch.Generator`` (coerced to ``int``).
        num_inference_steps: Denoising steps passed to the pipeline.
        guidance_scale: Guidance scale passed to the pipeline.
        cfg_normalization: Forwarded to the pipeline's ``cfg_normalization``.
        progress: Gradio progress tracker (tqdm tracking enabled).

    Returns:
        ``(image, meta)``: the first generated image and a dict describing
        the settings that produced it.

    Raises:
        gr.Error: If the prompt is empty or whitespace-only.
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Prompt is required.")

    width, height = get_resolution(resolution)
    seed = int(seed)  # sliders may deliver floats; keep generator and metadata consistent
    generator = torch.Generator("cuda").manual_seed(seed)
    # A whitespace-only negative prompt is treated the same as none at all.
    negative = (negative_prompt or "").strip() or None

    image = pipe(
        prompt=prompt.strip(),
        negative_prompt=negative,
        height=height,
        width=width,
        num_inference_steps=int(num_inference_steps),
        guidance_scale=float(guidance_scale),
        cfg_normalization=bool(cfg_normalization),
        generator=generator,
    ).images[0]

    meta = {
        "model": "Juggernaut-Z (RunDiffusion)",
        "base_model": BASE_MODEL,
        "weights": "Juggernaut_Z_V1_by_RunDiffusion.safetensors",
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "resolution": f"{width} x {height}",
        "guidance_scale": guidance_scale,
        "steps": num_inference_steps,
        "seed": seed,
        "cfg_normalization": cfg_normalization,
    }
    return image, meta
# ==================== Custom Theme ====================
# Stylesheet handed to gr.Blocks(css=CSS). It defines the color / radius design
# tokens in :root and styles the classes the UI markup uses: the hero header
# (.app-hero and its .tagline / .pills / .pill), section captions (.sec-label),
# text inputs, the Generate button (.gen-btn) and the result image
# (.result-gallery). It also styles accordions (.gr-accordion) and the
# scrollbars. The Outfit and Fira Code fonts are pulled in from Google Fonts
# via @import.
# NOTE(review): ".result-gallery .grid-wrap" and ".gr-accordion" may not match
# the DOM of the installed Gradio version (the result is a gr.Image, not a
# Gallery); confirm these selectors still apply.
CSS = """
@import url('https://fonts.googleapis.com/css2?family=Outfit:wght@300;400;500;600;700;800&family=Fira+Code:wght@400;500&display=swap');
:root {
--bg: #080a0e;
--surf: #0d1017;
--card: #111520;
--border: #1c2133;
--border2: #252d45;
--amber: #f59e0b;
--gold: #fbbf24;
--cream: #fef3c7;
--text: #e2e8f8;
--muted: #4a5578;
--r: 14px;
--r-sm: 8px;
}
*, *::before, *::after { box-sizing: border-box; }
body, .gradio-container {
background: var(--bg) !important;
font-family: 'Outfit', sans-serif !important;
color: var(--text) !important;
}
.gradio-container::before {
content: '';
position: fixed; inset: 0; pointer-events: none; z-index: 0;
background:
radial-gradient(ellipse 70% 50% at 50% -10%, rgba(245,158,11,0.07) 0%, transparent 65%),
radial-gradient(ellipse 40% 30% at 90% 90%, rgba(251,191,36,0.04) 0%, transparent 60%);
}
.app-hero { padding: 52px 0 28px; text-align: center; }
.app-hero h1 {
font-size: 3rem; font-weight: 800; letter-spacing: -0.05em;
line-height: 1; margin: 0 0 12px;
background: linear-gradient(135deg, var(--cream) 0%, var(--gold) 40%, var(--amber) 100%);
-webkit-background-clip: text; -webkit-text-fill-color: transparent; background-clip: text;
}
.app-hero .tagline {
color: var(--muted); font-size: 0.88rem; font-weight: 300;
letter-spacing: 0.06em; text-transform: uppercase; margin: 0 0 20px;
}
.app-hero .pills { display: flex; justify-content: center; gap: 8px; flex-wrap: wrap; }
.app-hero .pill {
background: var(--card); border: 1px solid var(--border2); border-radius: 100px;
padding: 4px 14px; font-size: 0.74rem; font-weight: 500; color: var(--muted);
font-family: 'Fira Code', monospace;
}
.app-hero .pill.gold { color: var(--amber); border-color: rgba(245,158,11,0.3); }
.sec-label {
font-size: 0.62rem !important; font-weight: 700 !important;
letter-spacing: 0.15em !important; text-transform: uppercase !important;
color: var(--amber) !important; margin: 0 0 8px !important; display: block;
}
label > span {
font-family: 'Outfit', sans-serif !important; font-size: 0.72rem !important;
font-weight: 500 !important; color: var(--muted) !important;
text-transform: uppercase; letter-spacing: 0.08em;
}
textarea, input[type="text"] {
background: var(--surf) !important; border: 1px solid var(--border) !important;
border-radius: var(--r-sm) !important; color: var(--text) !important;
font-family: 'Outfit', sans-serif !important; font-size: 0.95rem !important;
transition: border-color 0.2s, box-shadow 0.2s;
}
textarea:focus, input[type="text"]:focus {
border-color: var(--amber) !important;
box-shadow: 0 0 0 3px rgba(245,158,11,0.12) !important;
outline: none !important;
}
.gen-btn {
background: linear-gradient(135deg, var(--amber), #d97706) !important;
border: none !important; border-radius: var(--r) !important;
color: #000 !important; font-family: 'Outfit', sans-serif !important;
font-weight: 700 !important; font-size: 1rem !important;
height: 54px !important; width: 100% !important;
letter-spacing: 0.02em !important; cursor: pointer !important;
transition: opacity 0.18s, transform 0.15s, box-shadow 0.2s !important;
box-shadow: 0 4px 20px rgba(245,158,11,0.28) !important;
}
.gen-btn:hover {
opacity: 0.88 !important; transform: translateY(-1px) !important;
box-shadow: 0 8px 30px rgba(245,158,11,0.48) !important;
}
.gen-btn:active { transform: translateY(0) !important; }
.result-gallery .grid-wrap {
background: var(--surf) !important;
border: 1px solid var(--border) !important;
border-radius: var(--r) !important;
}
.result-gallery img { border-radius: 10px !important; }
.gr-accordion {
background: var(--card) !important; border: 1px solid var(--border) !important;
border-radius: var(--r) !important; margin-top: 10px !important;
}
::-webkit-scrollbar { width: 5px; }
::-webkit-scrollbar-track { background: var(--surf); }
::-webkit-scrollbar-thumb { background: var(--border2); border-radius: 3px; }
::-webkit-scrollbar-thumb:hover { background: var(--amber); }
"""
# ==================== Gradio Interface ====================
# Two-column layout (prompt inputs | result + metadata) plus a collapsible
# settings panel. Both the Generate button and pressing Enter in the prompt
# box run generate_wrapper.
# NOTE(review): layout nesting was reconstructed; confirm the settings
# accordion is meant to sit full-width below the two columns.
with gr.Blocks(css=CSS) as demo:
    gr.HTML("""
<div class="app-hero">
<h1>Juggernaut Z</h1>
<p class="tagline">Cinematic Fine-tune of Z-Image Base</p>
<div class="pills">
<span class="pill gold">ZeroGPU ⚡</span>
<span class="pill">RunDiffusion</span>
<span class="pill">bfloat16</span>
</div>
</div>
""")
    with gr.Row():
        with gr.Column(scale=1, min_width=320):
            gr.HTML('<span class="sec-label">① Prompt</span>')
            prompt = gr.Textbox(
                label="",
                lines=5,
                placeholder="Cinematic portrait of a warrior queen, golden armor, dramatic lighting, 8k, photorealistic...",
                container=False,
            )
            gr.HTML('<div style="height:8px"></div>')
            negative_prompt = gr.Textbox(
                label="Negative prompt",
                lines=2,
                placeholder="Optional: describe what to avoid...",
                value="",
            )
            gr.HTML('<div style="height:10px"></div>')
            run_btn = gr.Button("▶ Generate", variant="primary", elem_classes=["gen-btn"])
            gr.Examples(
                examples=EXAMPLE_PROMPTS,
                inputs=[prompt],
                label="Example prompts",
            )
        with gr.Column(scale=1, min_width=320):
            gr.HTML('<span class="sec-label">② Result</span>')
            result = gr.Image(
                label="",
                type="pil",
                height=512,
                container=False,
                elem_classes=["result-gallery"],
            )
            gr.HTML('<div style="height:8px"></div>')
            gr.HTML('<span class="sec-label">Generation Metadata</span>')
            metadata = gr.JSON(label="", show_label=False)

    with gr.Accordion("⚙ Generation Settings", open=False):
        gr.HTML('<span class="sec-label" style="margin-top:4px">Resolution</span>')
        resolution = gr.Dropdown(
            label="",
            choices=RESOLUTION_SET,
            value="1024x1024 ( 1:1 )",
            container=False,
        )
        gr.HTML('<div style="height:10px"></div>')
        with gr.Row():
            guidance_scale = gr.Slider(
                label="Guidance Scale",
                minimum=3.0,
                maximum=12.0,
                step=0.5,
                value=6.0,
                info="Juggernaut Z: 6-9 (higher = more prompt adherence)",
            )
            num_inference_steps = gr.Slider(
                label="Steps",
                minimum=20,
                maximum=60,
                step=1,
                value=35,
                info="Juggernaut Z: 25-45 recommended",
            )
        with gr.Row():
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=2_147_483_647,
                step=1,
                value=42,
            )
            randomize_seed = gr.Checkbox(
                label="Randomize seed",
                value=False,
            )
        cfg_normalization = gr.Checkbox(
            label="CFG Normalization",
            value=False,
            info="Enable for more stable CFG behavior at high values",
        )

    def generate_wrapper(
        prompt,
        negative_prompt,
        resolution,
        seed,
        num_inference_steps,
        guidance_scale,
        cfg_normalization,
        randomize_seed,
        progress=gr.Progress(track_tqdm=True),
    ):
        """Event handler for the Generate button and prompt submit.

        When *randomize_seed* is set, draws a fresh seed in
        ``[0, 2_147_483_647]`` before delegating to ``generate``. The progress
        tracker is declared here because Gradio injects a live ``gr.Progress``
        only into the handler it calls directly. The default on ``generate``
        itself is never wired up, so it is forwarded explicitly.
        """
        if randomize_seed:
            seed = random.randint(0, 2_147_483_647)
        return generate(
            prompt,
            negative_prompt,
            resolution,
            seed,
            num_inference_steps,
            guidance_scale,
            cfg_normalization,
            progress=progress,
        )

    inputs = [
        prompt, negative_prompt, resolution, seed,
        num_inference_steps, guidance_scale, cfg_normalization, randomize_seed,
    ]
    run_btn.click(
        fn=generate_wrapper,
        inputs=inputs,
        outputs=[result, metadata],
        api_name="generate",
    )
    # Same handler on Enter, but not exposed as a second API endpoint.
    prompt.submit(
        fn=generate_wrapper,
        inputs=inputs,
        outputs=[result, metadata],
        api_name=False,
    )
# Cap the pending-request queue at 20 jobs, then start the server
# (Blocks.queue returns the Blocks instance, so the calls chain).
demo.queue(max_size=20).launch()