| import PIL |
| from PIL import Image |
| import numpy as np |
| import torch |
| import cv2 as cv |
| import random |
| import os |
| import spaces |
| import gradio as gr |
|
|
| from diffusers import DiffusionPipeline |
| from peft import PeftModel, LoraConfig |
|
|
| from diffusers import ( |
| StableDiffusionPipeline, |
| StableDiffusionControlNetPipeline, |
| StableDiffusionControlNetImg2ImgPipeline, |
| DPMSolverMultistepScheduler, |
| PNDMScheduler, |
| ControlNetModel |
| ) |
| from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback |
| from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg, retrieve_timesteps |
| from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput |
| from diffusers.utils.torch_utils import randn_tensor |
| from diffusers.utils import load_image, make_image_grid |
|
|
|
|
# Upper bound used when drawing a random seed (largest signed 32-bit integer).
MAX_SEED = np.iinfo(np.int32).max
# Largest image side length; not referenced in the visible part of this file.
MAX_IMAGE_SIZE = 1024
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Base checkpoint used when the caller does not choose one.
model_id_default = "sd-legacy/stable-diffusion-v1-5"
# Base checkpoints offered for selection.
# NOTE(review): "stabilityai/sdxl-turbo" looks like an SDXL checkpoint while
# the loaders in this file are the SD 1.x pipelines - confirm it is usable.
model_dropdown = [
    "stabilityai/sdxl-turbo",
    "CompVis/stable-diffusion-v1-4",
    "sd-legacy/stable-diffusion-v1-5",
]
# Name of the local directory (relative to the working directory) holding the
# LoRA adapter weights.
model_lora_default = "lora"
|
|
|
|
def get_lora_sd_pipeline(
    ckpt_dir='./' + model_lora_default,
    base_model_name_or_path=None,
    dtype=torch.float16,
    device=DEVICE,
    adapter_name="default",
    controlnet=None,
    ip_adapter=None
):
    """Build a Stable Diffusion pipeline with LoRA adapters applied.

    The UNet adapter is read from ``<ckpt_dir>/unet``. A text-encoder adapter
    is applied as well when ``<ckpt_dir>/text_encoder`` exists.

    Args:
        ckpt_dir: Directory containing the ``unet`` (and optionally
            ``text_encoder``) LoRA sub-directories.
        base_model_name_or_path: Base checkpoint to load. When ``None`` it is
            read from the text-encoder LoRA config, if that directory exists.
        dtype: Torch dtype used to load the pipeline and the ControlNet.
        device: Device the finished pipeline is moved to.
        adapter_name: Adapter name handed to ``PeftModel.from_pretrained``.
        controlnet: Truthy to build a ControlNet pipeline around the
            ``lllyasviel/sd-controlnet-canny`` model.
        ip_adapter: Truthy to load the ``h94/IP-Adapter`` weights
            (``ip-adapter-plus_sd15.bin``) into the pipeline.

    Returns:
        The pipeline, with the safety checker disabled, on ``device``.

    Raises:
        ValueError: If no base model was given and none could be read from
            the LoRA config.
    """
    unet_sub_dir = os.path.join(ckpt_dir, "unet")
    text_encoder_sub_dir = os.path.join(ckpt_dir, "text_encoder")
    has_text_encoder_lora = os.path.exists(text_encoder_sub_dir)

    if has_text_encoder_lora and base_model_name_or_path is None:
        config = LoraConfig.from_pretrained(text_encoder_sub_dir)
        base_model_name_or_path = config.base_model_name_or_path

    if base_model_name_or_path is None:
        raise ValueError("Please specify the base model name or path")

    if controlnet:
        print('Pipe with ControlNet and IpAdapter' if ip_adapter else 'Pipe with ControlNet')
        # Load the ControlNet in the same dtype as the rest of the pipeline;
        # a hard-coded float16 here would disagree with any other `dtype`.
        controlnet_model = ControlNetModel.from_pretrained(
            "lllyasviel/sd-controlnet-canny",
            cache_dir="./models_cache",
            torch_dtype=dtype,
        )
        pipe = StableDiffusionControlNetPipeline.from_pretrained(
            base_model_name_or_path,
            torch_dtype=dtype,
            controlnet=controlnet_model,
        )
    else:
        print('Pipe with IpAdapter' if ip_adapter else 'Pipe with only SD')
        pipe = StableDiffusionPipeline.from_pretrained(
            base_model_name_or_path,
            torch_dtype=dtype,
        )

    if ip_adapter:
        # Loaded before the UNet is wrapped by PEFT, matching the original
        # ordering of these two steps.
        pipe.load_ip_adapter(
            "h94/IP-Adapter",
            subfolder="models",
            weight_name="ip-adapter-plus_sd15.bin",
        )

    pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_sub_dir, adapter_name=adapter_name)
    if has_text_encoder_lora:
        pipe.text_encoder = PeftModel.from_pretrained(
            pipe.text_encoder, text_encoder_sub_dir, adapter_name=adapter_name
        )

    if dtype in (torch.float16, torch.bfloat16):
        # Cast to the requested reduced precision. `.half()` would force
        # float16 even when bfloat16 was asked for.
        # NOTE(review): presumably needed because the adapter weights may load
        # in a different precision than the base model - confirm.
        pipe.unet.to(dtype)
        pipe.text_encoder.to(dtype)

    pipe.safety_checker = None
    pipe.to(device)
    return pipe
|
|
|
|
# Log line printed for each (ControlNet enabled, IP-Adapter enabled) combination.
_INFER_MODE_LABELS = {
    (False, False): "1. SD 1.5 + Lora",
    (True, False): "SD 1.5 + Lora + Controlnet",
    (False, True): "SD 1.5 + Lora + IpAdapter",
    (True, True): "SD 1.5 + Lora + IpAdapter + ControlNet",
}


def _canny_control_image(image, low_threshold=80, high_threshold=160):
    """Turn an uploaded image into a 3-channel Canny edge map for ControlNet."""
    if image is None:
        raise ValueError("ControlNet is enabled but no control image was provided")
    # np.asarray leaves an ndarray untouched and converts other array-likes.
    edges = cv.Canny(np.asarray(image), low_threshold, high_threshold)
    # Canny returns a single channel; replicate it into three channels.
    edges = np.repeat(edges[:, :, None], 3, axis=2)
    return Image.fromarray(edges)


@spaces.GPU
def infer(
    prompt,
    negative_prompt,
    randomize_seed,
    width=512,
    height=512,
    model_repo_id=model_id_default,
    seed=22,
    guidance_scale=7,
    num_inference_steps=50,

    use_advanced_controlnet=False,
    control_strength=None,
    image_upload_cn=None,

    use_advanced_ip=False,
    ip_adapter_scale=None,
    image_upload_ip=None,

    model_lora_id=model_lora_default,
    progress=gr.Progress(track_tqdm=True),
    dtype=torch.float16,
    device=DEVICE,
):
    """Generate one image with the LoRA pipeline, optionally using ControlNet
    and/or IP-Adapter.

    Args:
        prompt: Text prompt.
        negative_prompt: Negative text prompt.
        randomize_seed: When truthy, ``seed`` is replaced by a random value in
            ``[0, MAX_SEED]``.
        width: Requested output width, forwarded to the pipeline.
        height: Requested output height, forwarded to the pipeline.
        model_repo_id: Base checkpoint handed to ``get_lora_sd_pipeline``.
        seed: Seed for the torch generator (ignored when randomizing).
        guidance_scale: Guidance scale forwarded to the pipeline.
        num_inference_steps: Number of denoising steps.
        use_advanced_controlnet: Enable Canny-edge ControlNet conditioning.
        control_strength: ControlNet conditioning scale; when ``None`` the
            pipeline's own default is used.
        image_upload_cn: Image from which the Canny edge map is computed.
        use_advanced_ip: Enable IP-Adapter image conditioning.
        ip_adapter_scale: IP-Adapter scale; when ``None`` the scale is left
            as loaded.
        image_upload_ip: Reference image for the IP-Adapter.
        model_lora_id: Not used by this function.
            NOTE(review): probably meant to select the LoRA directory -
            confirm before wiring it into ``ckpt_dir``.
        progress: Not referenced in the body; presumably lets Gradio show
            tqdm progress - confirm.
        dtype: Torch dtype for the pipeline.
        device: Device the pipeline runs on.

    Returns:
        A ``(image, seed)`` tuple: the first generated image and the seed
        actually used.

    Raises:
        ValueError: If ControlNet or IP-Adapter is enabled but its image is
            missing.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator().manual_seed(seed)

    print(use_advanced_controlnet, use_advanced_ip)
    use_controlnet = bool(use_advanced_controlnet)
    use_ip = bool(use_advanced_ip)
    print(_INFER_MODE_LABELS[(use_controlnet, use_ip)])

    # Validate and prepare the inputs before the expensive model load.
    control_image = _canny_control_image(image_upload_cn) if use_controlnet else None
    if use_ip and image_upload_ip is None:
        raise ValueError("IP-Adapter is enabled but no reference image was provided")

    # The builder already moves the pipeline to `device`.
    pipe = get_lora_sd_pipeline(
        base_model_name_or_path=model_repo_id,
        dtype=dtype,
        device=device,
        controlnet=use_controlnet,
        ip_adapter=use_ip,
    )

    # Every mode receives the same core settings, so no user choice is
    # silently dropped (the original passed a misspelled `heigth` and omitted
    # several of these arguments in some branches).
    call_kwargs = {
        "negative_prompt": negative_prompt,
        "num_inference_steps": num_inference_steps,
        "guidance_scale": guidance_scale,
        "width": width,
        "height": height,
        "generator": generator,
    }

    if use_controlnet:
        call_kwargs["image"] = control_image
        # Only override the pipeline default when a strength was supplied.
        if control_strength is not None:
            call_kwargs["controlnet_conditioning_scale"] = control_strength

    if use_ip:
        if ip_adapter_scale is not None:
            pipe.set_ip_adapter_scale(ip_adapter_scale)
        call_kwargs["ip_adapter_image"] = image_upload_ip

    image = pipe(prompt, **call_kwargs).images[0]

    return image, seed