"""Gradio demo serving a Stable Diffusion pipeline saved in ``output_dir``."""

import argparse
import itertools
import math
import os
from contextlib import nullcontext
import random

import torch
import PIL
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    DPMSolverMultistepScheduler,
    PNDMScheduler,
    StableDiffusionPipeline,
    UNet2DConditionModel,
)
from diffusers.optimization import get_scheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
import bitsandbytes as bnb
import gradio as gr


def image_grid(imgs, rows, cols):
    """Paste ``rows * cols`` images into a single RGB grid image.

    Every cell is sized like the first image; images are placed row-major.

    Args:
        imgs: Sequence of PIL images, exactly ``rows * cols`` long.
        rows: Number of grid rows.
        cols: Number of grid columns.

    Returns:
        A new RGB ``PIL.Image.Image`` of size ``(cols * w, rows * h)``.

    Raises:
        ValueError: If ``imgs`` is empty or its length is not ``rows * cols``.
    """
    if not imgs or len(imgs) != rows * cols:
        raise ValueError(
            f"expected {rows * cols} images for a {rows}x{cols} grid, got {len(imgs)}"
        )
    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        # Row-major placement: column = i % cols, row = i // cols.
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid


# Directory the pipeline was saved to; the scheduler config is read from its
# "scheduler" subfolder.
output_dir = "./"

# float16 is only practical on GPU (many CPU kernels lack half support), so
# fall back to float32 when CUDA is unavailable.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

pipe = StableDiffusionPipeline.from_pretrained(
    output_dir,
    scheduler=DPMSolverMultistepScheduler.from_pretrained(output_dir, subfolder="scheduler"),
    torch_dtype=dtype,
).to(device)


def inference(prompt, num_samples):
    """Generate images for ``prompt`` with the loaded pipeline.

    Args:
        prompt: Text prompt passed to the pipeline.
        num_samples: Number of images requested. It comes from a Gradio slider,
            so it may arrive as a float; it is coerced to an int >= 1.

    Returns:
        List of images produced by the pipeline (for the Gallery output).
    """
    num_samples = max(1, int(num_samples))
    images = pipe(
        prompt,
        num_images_per_prompt=num_samples,
        num_inference_steps=25,
    ).images
    return list(images)


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="prompt")
            # Integer step and minimum of 1: the value feeds num_images_per_prompt.
            samples = gr.Slider(label="Samples", value=1, minimum=1, step=1)
            run = gr.Button(value="Run")
        with gr.Column():
            gallery = gr.Gallery(show_label=False)

    run.click(inference, inputs=[prompt, samples], outputs=gallery)
    # Each example row must supply exactly one value per input component.
    gr.Examples(
        [["Foods in tokyo", 1]],
        [prompt, samples],
        gallery,
        inference,
        cache_examples=False,
    )


if __name__ == "__main__":
    demo.launch()