import argparse
import itertools
import math
import os
import random
from contextlib import nullcontext

import bitsandbytes as bnb
import PIL
import torch
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    DPMSolverMultistepScheduler,
    PNDMScheduler,
    StableDiffusionPipeline,
    UNet2DConditionModel,
)
from diffusers.optimization import get_scheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
|
def image_grid(imgs, rows, cols):
    """Paste *imgs* into a single rows x cols contiguous grid image.

    Args:
        imgs: sequence of PIL.Image.Image; all are assumed to share the
            size of the first image, which sets the cell size.
        rows: number of grid rows.
        cols: number of grid columns.

    Returns:
        A new RGB PIL.Image.Image of size (cols*w, rows*h).

    Raises:
        ValueError: if len(imgs) != rows * cols.
    """
    # Raise instead of assert: asserts are stripped under `python -O`.
    if len(imgs) != rows * cols:
        raise ValueError(f"expected {rows * cols} images, got {len(imgs)}")

    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))

    for i, img in enumerate(imgs):
        # Image i lands in cell (row, col) = (i // cols, i % cols).
        grid.paste(img, box=((i % cols) * w, (i // cols) * h))
    return grid
|
|
# Directory holding the trained pipeline weights (as saved by a training run).
output_dir = './'

# Load the pipeline in fp16, swapping in the DPM-Solver multistep scheduler
# so that few-step inference (25 steps below) produces good samples.
# NOTE(review): the pipeline is never moved to an accelerator here (no
# `.to("cuda")`); fp16 inference on CPU is generally unsupported — confirm
# the device placement happens before `pipe` is first called.
pipe = StableDiffusionPipeline.from_pretrained(
    output_dir,
    scheduler=DPMSolverMultistepScheduler.from_pretrained(output_dir, subfolder="scheduler"),
    torch_dtype=torch.float16,
)
|
|
| import gradio as gr |
|
|
def inference(prompt, num_samples):
    """Generate *num_samples* images for *prompt* with the global pipeline.

    Returns a fresh list of PIL images (25 denoising steps per call).
    """
    result = pipe(prompt, num_images_per_prompt=num_samples, num_inference_steps=25)
    return list(result.images)
|
|
# Minimal Gradio UI: prompt + sample-count inputs on the left, image
# gallery output on the right.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="prompt")
            samples = gr.Slider(label="Samples", value=1)
            run = gr.Button(value="Run")
        with gr.Column():
            gallery = gr.Gallery(show_label=False)

    run.click(inference, inputs=[prompt, samples], outputs=gallery)
    # Bug fix: the example row previously had 3 values ("Foods in tokyo", 1, 1)
    # for only 2 input components [prompt, samples]; trimmed to match.
    gr.Examples(
        [["Foods in tokyo", 1]],
        [prompt, samples],
        gallery,
        inference,
        cache_examples=False,
    )

demo.launch()
|
|
|
|