| import numpy as np |
| import gradio as gr |
| import requests |
| import time |
| import json |
| import base64 |
| import os |
| from PIL import Image |
| from io import BytesIO |
|
|
class Prodia:
    """Thin client for the Prodia REST API.

    Every request sends the API key in the ``X-Prodia-Key`` header. Any
    response whose status code is not 200 raises ``Exception`` whose message
    contains the status code.
    """

    def __init__(self, api_key, base=None, timeout=30):
        """Create a client.

        Args:
            api_key: Value sent as the ``X-Prodia-Key`` header on every request.
            base: API root URL; falls back to ``https://api.prodia.com/v1``
                when falsy.
            timeout: Per-request timeout in seconds passed to ``requests``.
                ``None`` waits indefinitely (the previous behaviour).
        """
        self.base = base or "https://api.prodia.com/v1"
        self.headers = {
            "X-Prodia-Key": api_key
        }
        # Without a timeout ``requests`` can block forever on a stalled
        # connection, hanging whoever called into the client.
        self.timeout = timeout

    def generate(self, params):
        """Submit an SDXL generation request and return the decoded JSON job.

        Args:
            params: JSON-serialisable dict sent as the request body.
        """
        response = self._post(f"{self.base}/sdxl/generate", params)
        return response.json()

    def get_job(self, job_id):
        """Fetch the current state of job ``job_id`` and return the decoded JSON."""
        response = self._get(f"{self.base}/job/{job_id}")
        return response.json()

    def wait(self, job, timeout=300, poll_interval=0.25):
        """Poll ``job`` until it reaches a terminal status and return the final job.

        Args:
            job: Job dict as returned by :meth:`generate`. Its ``"status"`` is
                checked first; its ``"job"`` value is the id that gets polled.
            timeout: Maximum seconds to keep polling. ``None`` polls forever.
            poll_interval: Seconds to sleep between polls.

        Returns:
            The last job dict seen, whose ``"status"`` is ``"succeeded"`` or
            ``"failed"``.

        Raises:
            TimeoutError: if ``timeout`` elapses before a terminal status is
                seen (previously an unknown or stuck status looped forever).
        """
        # monotonic() is immune to wall-clock adjustments during the wait.
        deadline = None if timeout is None else time.monotonic() + timeout
        job_result = job

        while job_result['status'] not in ('succeeded', 'failed'):
            if deadline is not None and time.monotonic() >= deadline:
                raise TimeoutError(
                    f"Prodia job {job['job']} did not finish within {timeout} seconds "
                    f"(last status: {job_result['status']!r})"
                )
            time.sleep(poll_interval)
            job_result = self.get_job(job['job'])

        return job_result

    def list_models(self):
        """Return the decoded JSON response of ``/models/list``."""
        response = self._get(f"{self.base}/models/list")
        return response.json()

    def _post(self, url, params):
        """POST ``params`` as a JSON body to ``url``; raise unless the status is 200."""
        headers = {
            **self.headers,
            "Content-Type": "application/json"
        }
        response = requests.post(
            url, headers=headers, data=json.dumps(params), timeout=self.timeout
        )

        if response.status_code != 200:
            raise Exception(f"Bad Prodia Response: {response.status_code}")

        return response

    def _get(self, url):
        """GET ``url`` with the auth headers; raise unless the status is 200."""
        response = requests.get(url, headers=self.headers, timeout=self.timeout)

        if response.status_code != 200:
            raise Exception(f"Bad Prodia Response: {response.status_code}")

        return response
|
|
|
|
def image_to_base64(image_path):
    """Re-encode the image at ``image_path`` as PNG and return it as base64 text.

    The file is opened with Pillow and written into an in-memory PNG buffer.
    The buffer's bytes are then base64-encoded, and the resulting ASCII bytes
    are decoded to a ``str``.
    """
    png_buffer = BytesIO()
    with Image.open(image_path) as img:
        img.save(png_buffer, format="PNG")

    encoded_bytes = base64.b64encode(png_buffer.getvalue())
    return encoded_bytes.decode('utf-8')
|
|
|
|
|
|
# Shared API client used by the Gradio callback. The key comes from the
# PRODIA_API_KEY environment variable and is None when that variable is unset.
prodia_client = Prodia(os.environ.get("PRODIA_API_KEY"))
|
|
def flip_text(prompt, negative_prompt, model, steps, sampler, cfg_scale, width, height, seed):
    """Generate one SDXL image through Prodia and return its URL.

    Used as the click handler of the Generate button; the arguments arrive in
    the order of the button's ``inputs`` list.

    Args:
        prompt: Positive text prompt.
        negative_prompt: Text describing what to avoid.
        model: Checkpoint name understood by the Prodia API.
        steps: Number of sampling steps (sent as an int).
        sampler: Sampler name.
        cfg_scale: Classifier-free guidance scale.
        width: Output width in pixels (sent as an int).
        height: Output height in pixels (sent as an int).
        seed: Generation seed (sent as an int); per the UI text, -1 is random.

    Returns:
        The ``imageUrl`` of the finished job.

    Raises:
        RuntimeError: if the job does not finish with status ``"succeeded"``
            or its result carries no ``imageUrl`` (previously an opaque
            ``KeyError``).
    """
    def _as_int(value):
        # gr.Number with the default precision yields floats (e.g. -1.0), and
        # sliders can too; send integer API fields as JSON integers. None is
        # forwarded unchanged.
        return None if value is None else int(round(value))

    result = prodia_client.generate({
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "model": model,
        "steps": _as_int(steps),
        "sampler": sampler,
        "cfg_scale": cfg_scale,
        "width": _as_int(width),
        "height": _as_int(height),
        "seed": _as_int(seed),
    })

    job = prodia_client.wait(result)

    if job.get("status") != "succeeded" or "imageUrl" not in job:
        raise RuntimeError(
            f"Prodia job {job.get('job')} ended with status {job.get('status')!r}"
        )
    return job["imageUrl"]
|
|
# Stylesheet passed to gr.Blocks below. It makes the element with
# elem_id="generate" (the Generate button) fill the height of its container.
css = """
#generate {
height: 100%;
}
"""
|
|
# Build the Gradio UI, wire the Generate button to flip_text, then start the
# server. Component creation order and `with` nesting define the page layout.
with gr.Blocks(css=css, theme="Base") as demo:
    # Page title.
    with gr.Row():
        gr.Markdown("<h1><center>Stable Diffusion XL</center></h1>")
    with gr.Tab("Playground"):
        # Prompt inputs on the left, the Generate button on the right.
        with gr.Row():
            with gr.Column(scale=6, min_width=600):
                prompt = gr.Textbox(label="Prompt", placeholder="beautiful cat, 8k", show_label=True, lines=2)
                negative_prompt = gr.Textbox(label="Negative Prompt", value="text, blurry, fuzziness", placeholder="text, blurry, fuzziness", show_label=True, lines=3)
            with gr.Column():
                # elem_id "generate" is targeted by the `#generate` rule in `css`.
                text_button = gr.Button("Generate", variant='primary', elem_id="generate")

        # Output area: flip_text returns an image URL that is displayed here.
        with gr.Row():
            with gr.Column(scale=2):
                image_output = gr.Image()

        # Generation parameters, collapsed by default.
        # NOTE(review): the exact nesting of this accordion (tab level, below
        # the output row) is assumed — confirm it renders where intended.
        with gr.Accordion("Advanced options", open=False):
            with gr.Row():
                with gr.Column(scale=6):
                    # The value must be one of `choices`; these strings are sent
                    # verbatim to the API as the "model" field.
                    model = gr.Dropdown(interactive=True,value="sd_xl_base_1.0.safetensors [be9edd61]", show_label=True, label="Stable Diffusion Checkpoint", choices=[
                        "sd_xl_base_1.0.safetensors [be9edd61]",
                        "dynavisionXL_0411.safetensors [c39cc051]",
                        "dreamshaperXL10_alpha2.safetensors [c8afe2ef]",
                    ])

            with gr.Row():
                with gr.Column(scale=1):
                    sampler = gr.Dropdown(value="DPM++ SDE", show_label=True, label="Sampler", choices=[
                        "Euler",
                        "Euler a",
                        "LMS",
                        "Heun",
                        "DPM2",
                        "DPM2 a",
                        "DPM++ 2S a",
                        "DPM++ 2M",
                        "DPM++ SDE",
                        "DPM fast",
                        "DPM adaptive",
                        "LMS Karras",
                        "DPM2 Karras",
                        "DPM2 a Karras",
                        "DPM++ 2S a Karras",
                        "DPM++ 2M Karras",
                        "DPM++ SDE Karras",
                        "DDIM",
                        "PLMS",
                    ])

                with gr.Column(scale=1):
                    steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=30, step=1)

            with gr.Row():
                with gr.Column(scale=1):
                    # NOTE(review): no `minimum` is given, so Gradio's default lower
                    # bound applies (0 in the versions I know) — a 0-pixel size is
                    # likely rejected by the API; consider an explicit minimum.
                    width = gr.Slider(label="Width", maximum=1024, value=1024, step=8)
                    height = gr.Slider(label="Height", maximum=1024, value=1024, step=8)

                with gr.Column(scale=1):
                    # These two sliders are display-only: they are not in the
                    # `inputs` list of text_button.click, and maximum=1 pins them at 1.
                    batch_size = gr.Slider(label="Batch Size", maximum=1, value=1)
                    batch_count = gr.Slider(label="Batch Count", maximum=1, value=1)

            cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, value=7, step=1)
            # Per the info text, -1 requests a random seed.
            seed = gr.Number(label="Seed", value=-1, info="""'-1' is random seed""")

        # The input order must match flip_text's parameter order.
        text_button.click(flip_text, inputs=[prompt, negative_prompt, model, steps, sampler, cfg_scale, width, height, seed], outputs=image_output)

# Handle one queued event at a time.
# NOTE(review): `concurrency_count` is a Gradio 3.x `queue()` argument that
# later major versions dropped — pin the Gradio version this app targets.
demo.queue(concurrency_count=1)
# show_error=False: errors raised by the callback are presumably not surfaced
# in the browser UI — confirm this is intended.
demo.launch(debug=False, share=False, show_error=False, show_api=False)
|
|