File size: 6,393 Bytes
d695f2c
 
 
 
 
8c5c5a4
d695f2c
26fbae5
8c5c5a4
d695f2c
26fbae5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d695f2c
 
19ac4ad
d695f2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19ac4ad
 
d695f2c
 
19ac4ad
d695f2c
 
8c5c5a4
 
26fbae5
 
 
 
d695f2c
 
 
19ac4ad
d695f2c
26fbae5
19ac4ad
d695f2c
26fbae5
d695f2c
26fbae5
d695f2c
 
 
19ac4ad
 
26fbae5
19ac4ad
 
 
26fbae5
 
d695f2c
 
 
19ac4ad
d695f2c
 
19ac4ad
d695f2c
19ac4ad
d695f2c
26fbae5
d695f2c
 
 
 
19ac4ad
 
d695f2c
 
19ac4ad
 
d695f2c
19ac4ad
d695f2c
 
 
 
19ac4ad
 
d695f2c
19ac4ad
 
 
d695f2c
19ac4ad
d695f2c
 
19ac4ad
 
d695f2c
 
 
19ac4ad
 
 
d695f2c
 
 
 
 
 
 
19ac4ad
 
 
d695f2c
 
 
 
19ac4ad
 
 
d695f2c
 
 
26fbae5
d695f2c
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import gradio as gr
from huggingface_hub import InferenceClient
import os
import random

# Hugging Face API token read from the environment; None when unset
# (requests made through the client below will then fail at call time).
HF_TOKEN = os.environ.get("HF_TOKEN")

# Single shared inference client; all generations are routed through the
# fal-ai provider on the HF Inference API.
client = InferenceClient(provider="fal-ai", api_key=HF_TOKEN)

# UI model name -> routing/config. "max_steps" is the hard cap enforced in
# generate_image(); "recommended_steps" is informational (used by EXAMPLES).
MODELS = {
    "SDXL Base": {
        "model_id": "stabilityai/stable-diffusion-xl-base-1.0",
        "max_steps": 50,
        "recommended_steps": 30,
    },
    "FLUX.1-schnell (4 steps)": {
        "model_id": "black-forest-labs/FLUX.1-schnell",
        "max_steps": 4,
        "recommended_steps": 4,
    },
    "FLUX.1-dev (28 steps)": {
        "model_id": "black-forest-labs/FLUX.1-dev",
        "max_steps": 28,
        "recommended_steps": 28,
    },
}

# Fallback model key, also the dropdown's initial value.
DEFAULT_MODEL = "SDXL Base"

# Style name -> prompt suffix appended verbatim to the user prompt.
# "None" maps to the empty string (no styling).
STYLE_PRESETS = {
    "None": "",
    "Anime": ", anime style, vibrant colors, detailed linework",
    "Photorealistic": ", photorealistic, 8k, ultra detailed, professional photography",
    "Digital Art": ", digital art, highly detailed, concept art, trending on artstation",
    "Oil Painting": ", oil painting style, textured brushstrokes, classical art",
    "Watercolor": ", watercolor painting, soft colors, fluid strokes",
    "3D Render": ", 3D render, octane render, highly detailed, volumetric lighting",
    "Pixel Art": ", pixel art style, retro game, 16-bit",
    "Comic Book": ", comic book style, bold outlines, vibrant colors, halftone",
    "Fantasy": ", fantasy art, magical, epic, detailed, cinematic lighting",
}


def generate_image(prompt, model_name=DEFAULT_MODEL, style="None", negative_prompt="",
    guidance_scale=7.5, num_inference_steps=30, seed=-1, width=1024, height=1024):
    """Generate an image from a text prompt via the HF Inference API.

    Args:
        prompt: User prompt text; must be non-empty after stripping.
        model_name: Key into MODELS; unknown names fall back to DEFAULT_MODEL.
        style: Key into STYLE_PRESETS; its suffix is appended to the prompt.
        negative_prompt: Optional negative prompt; blank means "none".
        guidance_scale: Classifier-free guidance strength.
        num_inference_steps: Requested steps; clamped to [1, model max].
        seed: RNG seed; -1 draws a random 32-bit seed.
        width, height: Output dimensions in pixels.

    Returns:
        (PIL image, info string "Seed: ... | Model: ... | Steps: ...").

    Raises:
        gr.Error: if the prompt is empty or the provider call fails.
    """
    # Guard against both None (possible from a cleared Gradio field) and
    # whitespace-only input.
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a prompt!")

    # Append the style suffix only for a real, non-"None" preset.
    if style in STYLE_PRESETS and style != "None":
        full_prompt = prompt + STYLE_PRESETS[style]
    else:
        full_prompt = prompt

    if seed == -1:
        seed = random.randint(0, 2**32 - 1)

    model_info = MODELS.get(model_name, MODELS[DEFAULT_MODEL])
    model_id = model_info["model_id"]
    max_steps = model_info["max_steps"]

    # Clamp inference steps to the model's supported range (at least 1).
    steps = max(1, min(num_inference_steps, max_steps))

    # None-safe: Gradio may hand back None instead of "" for a cleared box.
    neg = negative_prompt if negative_prompt and negative_prompt.strip() else None

    try:
        image = client.text_to_image(
            prompt=full_prompt, model=model_id,
            negative_prompt=neg,
            guidance_scale=guidance_scale, num_inference_steps=steps,
            width=width, height=height, seed=seed,
        )
        return image, f"Seed: {seed} | Model: {model_name} | Steps: {steps}"
    except gr.Error:
        # Don't re-wrap errors we raised deliberately.
        raise
    except Exception as e:
        # Surface the provider failure to the UI instead of crashing the app.
        raise gr.Error(f"Generation failed: {str(e)}")


# Example rows for gr.Examples. Column order must match the `inputs` list
# wired to generate_image: [prompt, model_name, style, negative_prompt,
# guidance_scale, num_inference_steps, seed, width, height].
EXAMPLES = [
    ["A majestic dragon flying over a crystal lake at sunset, epic fantasy art", "SDXL Base", "Fantasy", "", 7.5, 30, -1, 1024, 1024],
    ["A cute robot exploring a colorful garden", "SDXL Base", "Anime", "", 7.5, 30, -1, 1024, 1024],
    ["An astronaut riding a horse on Mars", "FLUX.1-schnell (4 steps)", "Photorealistic", "", 7.5, 4, -1, 1024, 1024],
    ["A cozy coffee shop interior with rain outside", "SDXL Base", "Digital Art", "", 7.5, 30, -1, 1024, 1024],
    ["A futuristic city skyline at night", "SDXL Base", "3D Render", "", 7.5, 30, -1, 1024, 1024],
    ["Beautiful mountain landscape with a river, golden hour", "SDXL Base", "Oil Painting", "", 7.5, 30, -1, 1024, 1024],
    ["A magical forest with glowing mushrooms", "FLUX.1-schnell (4 steps)", "Fantasy", "", 7.5, 4, -1, 1024, 1024],
    ["Portrait of a wise old wizard", "FLUX.1-dev (28 steps)", "Digital Art", "", 7.5, 28, -1, 1024, 1024],
]



# Build the Gradio UI: prompt/config column on the left, output on the right.
with gr.Blocks(
    title="🎨 Text-to-Image Generator",
    theme=gr.themes.Soft(),
    css=".generate-btn { background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important; }"
) as demo:
    gr.Markdown("# 🎨 Text-to-Image Generator\n\nGenerate images from text using SDXL and FLUX via HuggingFace Inference API.")

    with gr.Row():
        with gr.Column(scale=3):
            prompt_input = gr.Textbox(
                label="✏️ Prompt", placeholder="Describe the image you want...",
                lines=3, max_lines=5,
            )
            with gr.Row():
                model_dropdown = gr.Dropdown(choices=list(MODELS.keys()), value=DEFAULT_MODEL, label="πŸ€– Model")
                style_dropdown = gr.Dropdown(choices=list(STYLE_PRESETS.keys()), value="None", label="🎭 Style")
            negative_prompt_input = gr.Textbox(
                label="🚫 Negative Prompt", placeholder="blur, low quality...", lines=2,
            )

            with gr.Accordion("βš™οΈ Advanced Settings", open=False):
                with gr.Row():
                    guidance_scale = gr.Slider(1.0, 20.0, value=7.5, step=0.5, label="Guidance Scale")
                    num_steps = gr.Slider(10, 50, value=30, step=1, label="Inference Steps")
                with gr.Row():
                    width = gr.Slider(512, 1536, value=1024, step=128, label="Width")
                    height = gr.Slider(512, 1536, value=1024, step=128, label="Height")
                seed_input = gr.Number(value=-1, label="Seed (-1 random)", precision=0)

            generate_btn = gr.Button("πŸš€ Generate Image", variant="primary", size="lg", elem_classes="generate-btn")

        with gr.Column(scale=4):
            output_image = gr.Image(label="Generated Image", type="pil", height=600)
            info_text = gr.Textbox(label="Info", interactive=False)

    # Single source of truth for the generate_image argument wiring; order
    # must match the function's positional parameters (and EXAMPLES rows).
    generation_inputs = [prompt_input, model_dropdown, style_dropdown,
                         negative_prompt_input, guidance_scale, num_steps,
                         seed_input, width, height]
    generation_outputs = [output_image, info_text]

    gr.Examples(
        examples=EXAMPLES,
        inputs=generation_inputs,
        outputs=generation_outputs,
        fn=generate_image,
        cache_examples=False,
    )

    # Generate on button click and on Enter in the prompt box.
    generate_btn.click(
        fn=generate_image,
        inputs=generation_inputs,
        outputs=generation_outputs,
    )
    prompt_input.submit(
        fn=generate_image,
        inputs=generation_inputs,
        outputs=generation_outputs,
    )

    gr.Markdown("---\n**Models:** SDXL (Stability AI) | FLUX.1-schnell/dev (Black Forest Labs) via fal-ai")

if __name__ == "__main__":
    demo.launch()