import spaces
import gradio as gr
import torch
from PIL import Image
from diffusers import (
    AutoencoderKL,
    DiffusionPipeline,
    DPMSolverMultistepScheduler,
    EulerDiscreteScheduler,
    StableDiffusionXLImg2ImgPipeline,
)
|
|
# Everything runs in half precision on the GPU. NOTE(review): on a ZeroGPU
# Space the `.to("cuda")` at import time is handled by the `spaces` shim —
# confirm this Space actually runs on ZeroGPU hardware.
device = "cuda"

base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
refiner_id = "stabilityai/stable-diffusion-xl-refiner-1.0"

# Replacement SDXL VAE loaded in float16 — presumably the fp16-safe variant
# is used because the stock SDXL VAE misbehaves in half precision; TODO
# confirm against the madebyollin/sdxl-vae-fp16-fix model card.
vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix",
    torch_dtype=torch.float16,
)

# Base text-to-image pipeline (fp16 safetensors weights).
base_pipeline = DiffusionPipeline.from_pretrained(
    base_model_id,
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
).to(device)

# Refiner (img2img) pipeline; it reuses the base pipeline's second text
# encoder and the fp16-fix VAE instead of loading its own copies.
refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    refiner_id,
    text_encoder_2=base_pipeline.text_encoder_2,
    vae=vae,
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
).to(device)
|
|
|
|
|
|
# UI sampler name -> factory building a scheduler from an existing scheduler
# config (e.g. pipeline.scheduler.config).
# NOTE(review): this map is currently dead code — nothing in the UI selects a
# sampler; wire it into `generate` or delete it.
SAMPLER_MAP = {
    # The correct kwarg is `use_karras_sigmas`; the original `use_karras` is
    # not a DPMSolverMultistepScheduler parameter and was silently ignored by
    # `from_config`, so Karras sigmas were never actually enabled.
    "DPM++ Karras SDE": lambda config: DPMSolverMultistepScheduler.from_config(
        config, use_karras_sigmas=True, algorithm_type="sde-dpmsolver++"
    ),
    "Euler": lambda config: EulerDiscreteScheduler.from_config(config),
}
|
|
|
|
|
|
@spaces.GPU(duration=59)
def generate(
    prompt,
    negative_prompt,
    num_inference_steps,
    denoising_switch,
    width, height,
    guidance_scale,
):
    """Generate one image with the SDXL base + refiner ensemble of experts.

    The base pipeline denoises from pure noise up to the `denoising_switch`
    fraction of the schedule and hands off its latents; the refiner resumes
    denoising from that same fraction and decodes the final image.

    Args:
        prompt: Positive text prompt.
        negative_prompt: Negative text prompt.
        num_inference_steps: Total scheduler steps, shared by both stages.
        denoising_switch: Fraction in (0, 1] where base hands off to refiner
            (base `denoising_end` == refiner `denoising_start`).
        width, height: Output resolution in pixels (base stage only).
        guidance_scale: Classifier-free guidance scale for both stages.

    Returns:
        The final PIL image from the refiner.
    """
    # Stage 1: base model, stopped early and kept in latent space so the
    # refiner can continue on the same latents without a decode/encode trip.
    base_latents = base_pipeline(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=num_inference_steps,
        denoising_end=denoising_switch,
        width=width,
        height=height,
        guidance_scale=guidance_scale,
        output_type="latent",
    ).images

    # Stage 2: refiner (img2img). Its __call__ does not accept width/height —
    # output size is derived from the input latents — so passing them (as the
    # original code did) raises TypeError at runtime; they are omitted here.
    refined_image = refiner(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=num_inference_steps,
        denoising_start=denoising_switch,
        guidance_scale=guidance_scale,
        image=base_latents,
    ).images[0]

    return refined_image
|
|
|
|
def create_ui():
    """Build the Gradio Blocks UI and wire the Run button to `generate`.

    Returns:
        The assembled `gr.Blocks` app (not yet launched).
    """
    with gr.Blocks() as demo:
        # Read-only indicators showing which checkpoints are loaded.
        with gr.Row():
            gr.Radio(label="Base model", choices=[base_model_id], value=base_model_id, interactive=False)
            gr.Radio(label="Refiner model", choices=[refiner_id], value=refiner_id, interactive=False)

        with gr.Row():
            with gr.Column():
                prompt_box = gr.Textbox(label="Prompt", lines=3)
                negative_box = gr.Textbox(label="Negative Prompt", lines=3, value="low quality, bad quality")
                steps_slider = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=30)
                switch_slider = gr.Slider(label="Denoising Switch", minimum=0.01, maximum=1, step=0.01, value=0.8)
                width_slider = gr.Slider(label="Width", minimum=64, maximum=2048, step=16, value=1024)
                height_slider = gr.Slider(label="Height", minimum=64, maximum=2048, step=16, value=1024)
                cfg_slider = gr.Slider(label="Guidance Scale", minimum=1.1, maximum=30, step=0.1, value=7.5)
            with gr.Column():
                result_image = gr.Image(interactive=False)
                run_button = gr.Button("Run", variant="primary")

        run_button.click(
            generate,
            inputs=[
                prompt_box,
                negative_box,
                steps_slider,
                switch_slider,
                width_slider,
                height_slider,
                cfg_slider,
            ],
            outputs=[result_image],
        )

    return demo
|
|
|
|
if __name__ == "__main__":
    # Build the UI and launch with a public share link.
    create_ui().launch(share=True)