|
|
| import spaces |
| import os |
| import datetime |
| import einops |
| import gradio as gr |
| from gradio_imageslider import ImageSlider |
| import numpy as np |
| import torch |
| import random |
| from PIL import Image |
| from pathlib import Path |
| from torchvision import transforms |
| import torch.nn.functional as F |
| from torchvision.models import resnet50, ResNet50_Weights |
|
|
| from pytorch_lightning import seed_everything |
| from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor |
| from diffusers import AutoencoderKL, DDIMScheduler, PNDMScheduler, DPMSolverMultistepScheduler, UniPCMultistepScheduler |
|
|
| from pipelines.pipeline_pasd import StableDiffusionControlNetPipeline |
| from myutils.misc import load_dreambooth_lora, rand_name |
| from myutils.wavelet_color_fix import wavelet_color_fix |
| from annotator.retinaface import RetinaFaceDetection |
|
|
# Switch between the regular PASD networks and the "pasd_light" variant.
use_pasd_light = False

# Built once at import time.
face_detector = RetinaFaceDetection()

# Both packages expose the same class names, so the code below is agnostic
# to which variant was selected.
if not use_pasd_light:
    from models.pasd.unet_2d_condition import UNet2DConditionModel
    from models.pasd.controlnet import ControlNetModel
else:
    from models.pasd_light.unet_2d_condition import UNet2DConditionModel
    from models.pasd_light.controlnet import ControlNetModel

# Checkpoint locations, relative to the working directory.
pretrained_model_path = "checkpoints/stable-diffusion-v1-5"
ckpt_path = "runs/pasd/checkpoint-100000"
dreambooth_lora_path = "checkpoints/personalized_models/majicmixRealistic_v6.safetensors"

# Precision and device every diffusion sub-model is moved to.
weight_dtype = torch.float16
device = "cuda"
|
|
# Components taken from the base Stable Diffusion checkpoint.
scheduler = UniPCMultistepScheduler.from_pretrained(
    pretrained_model_path, subfolder="scheduler"
)
text_encoder = CLIPTextModel.from_pretrained(
    pretrained_model_path, subfolder="text_encoder"
)
tokenizer = CLIPTokenizer.from_pretrained(
    pretrained_model_path, subfolder="tokenizer"
)
vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
feature_extractor = CLIPImageProcessor.from_pretrained(
    f"{pretrained_model_path}/feature_extractor"
)

# UNet and ControlNet come from the PASD training run.
unet = UNet2DConditionModel.from_pretrained(ckpt_path, subfolder="unet")
controlnet = ControlNetModel.from_pretrained(ckpt_path, subfolder="controlnet")

# The script only runs inference, so no parameter needs gradients.
for _net in (vae, text_encoder, unet, controlnet):
    _net.requires_grad_(False)

# Apply the personalized (DreamBooth/LoRA) weights to the base networks.
unet, vae, text_encoder = load_dreambooth_lora(
    unet, vae, text_encoder, dreambooth_lora_path
)

# Move every network to the target device in the configured precision.
for _net in (text_encoder, vae, unet, controlnet):
    _net.to(device, dtype=weight_dtype)

validation_pipeline = StableDiffusionControlNetPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    feature_extractor=feature_extractor,
    unet=unet,
    controlnet=controlnet,
    scheduler=scheduler,
    safety_checker=None,
    requires_safety_checker=False,
)

# Configure the pipeline's tiled VAE with a decoder tile size of 224.
validation_pipeline._init_tiled_vae(decoder_tile_size=224)
|
|
# ResNet-50 classifier (default torchvision weights) that `inference` uses to
# append a detected category name to the prompt. It is never moved to `device`
# and is put in eval mode right after construction.
weights = ResNet50_Weights.DEFAULT
resnet = resnet50(weights=weights).eval()
# Input transform bundled with the selected weights.
preprocess = weights.transforms()
|
|
def resize_image(image_path, target_height):
    """Open an image and scale it so that it is ``target_height`` pixels tall.

    The width is scaled by the same factor, truncated to a whole pixel and
    clamped to at least 1 so the aspect ratio is kept as closely as possible.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts (the UI passes a
            file path).
        target_height: Height of the returned image, in pixels.

    Returns:
        A new ``PIL.Image.Image`` resampled with the LANCZOS filter. The
        source file is closed before the function returns.
    """
    with Image.open(image_path) as img:
        src_width, src_height = img.size
        scale = target_height / float(src_height)
        # int() truncation yields 0 for very tall, narrow inputs, and Pillow
        # refuses a zero-sized target; keep at least one pixel of width.
        new_width = max(1, int(float(src_width) * scale))
        return img.resize((new_width, target_height), Image.LANCZOS)
|
|
@spaces.GPU(enable_queue=True)
def inference(input_image, prompt, a_prompt, n_prompt, denoise_steps, upscale, alpha, cfg, seed):
    """Run PASD super-resolution on one uploaded image.

    The image is first normalised to a height of 512 px, then enlarged by
    ``upscale``, snapped to dimensions that are multiples of 8 and passed to
    ``validation_pipeline``. The top-1 category of the module-level ResNet-50
    is appended to the prompt when its softmax score is at least 0.1.

    Args:
        input_image: Path of the uploaded image file.
        prompt: User prompt; may be empty.
        a_prompt: Extra positive prompt text appended after ``prompt``.
        n_prompt: Negative prompt forwarded to the pipeline.
        denoise_steps: Number of inference steps for the pipeline.
        upscale: Enlargement factor applied to the 512 px-high image.
        alpha: Forwarded to the pipeline as ``conditioning_scale``.
        cfg: Forwarded to the pipeline as ``guidance_scale``.
        seed: RNG seed; ``-1`` is treated as ``0``.

    Returns:
        ``((input_path, result_path), result_path)`` where both paths are
        JPEG files written to the current working directory.

    Raises:
        gr.Error: If no image was supplied.
    """
    import traceback

    if input_image is None:
        raise gr.Error("Please upload an image first.")

    # Sliders may deliver floats; everything below needs plain ints.
    seed = int(seed)
    denoise_steps = int(denoise_steps)

    # Temporary workaround: the slider minimum (-1) is mapped to seed 0.
    if seed == -1:
        seed = 0

    input_image = resize_image(input_image, 512)

    # Used to build the output file names.
    timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")

    with torch.no_grad():
        seed_everything(seed)
        # A fresh torch.Generator has its own state that seed_everything()
        # does not touch, so it must be seeded explicitly for the seed
        # control to affect the generator passed to the pipeline.
        generator = torch.Generator(device=device).manual_seed(seed)

        input_image = input_image.convert('RGB')

        # Classify the image and use a sufficiently confident label as an
        # additional prompt term.
        batch = preprocess(input_image).unsqueeze(0)
        prediction = resnet(batch).squeeze(0).softmax(0)
        class_id = prediction.argmax().item()
        score = prediction[class_id].item()
        category_name = weights.meta["categories"][class_id] if score >= 0.1 else ''

        # Join only the non-empty parts so no dangling ", " is produced.
        prompt = ", ".join(part for part in (prompt, category_name, a_prompt) if part)

        ori_width, ori_height = input_image.size
        # Final output size; int() keeps this valid for fractional factors.
        out_size = (int(ori_width * upscale), int(ori_height * upscale))
        input_image = input_image.resize(out_size)

        # Snap both sides down to a multiple of 8 (but never to 0) before
        # handing the image to the pipeline.
        width = max(8, input_image.size[0] // 8 * 8)
        height = max(8, input_image.size[1] // 8 * 8)
        input_image = input_image.resize((width, height))

        try:
            image = validation_pipeline(
                None, prompt, input_image, num_inference_steps=denoise_steps, generator=generator,
                height=height, width=width, guidance_scale=cfg,
                negative_prompt=n_prompt, conditioning_scale=alpha, eta=0.0,
            ).images[0]

            image = wavelet_color_fix(image, input_image)

            # Undo the multiple-of-8 snapping so the result has exactly
            # upscale x the normalised input size.
            image = image.resize(out_size)
        except Exception:
            # Best effort: keep the UI responsive with a blank placeholder,
            # but leave a full traceback in the logs for diagnosis.
            traceback.print_exc()
            image = Image.new(mode="RGB", size=(512, 512))

    result_path = f'result_{timestamp}.jpg'
    input_path = f'input_{timestamp}.jpg'

    image.save(result_path, 'JPEG')
    input_image.save(input_path, 'JPEG')

    return (input_path, result_path), result_path
|
|
# Descriptive text for the demo; the layout visible in this file does not
# reference these three names directly.
title = "Pixel-Aware Stable Diffusion for Real-ISR"
description = (
    "Gradio Demo for PASD Real-ISR. "
    "To use it, simply upload your image, or click one of the examples to load them."
)
article = (
    "<a href='https://github.com/yangxy/PASD' target='_blank'>"
    "Github Repo Pytorch</a>"
)

# Stylesheet passed to gr.Blocks: centres the main column and lays the
# project badges out in a single row.
css = """
#col-container{
    margin: 0 auto;
    max-width: 720px;
}
#project-links{
    margin: 0 0 12px !important;
    column-gap: 8px;
    display: flex;
    justify-content: center;
    flex-wrap: nowrap;
    flex-direction: row;
    align-items: center;
}
"""
|
|
# Page layout: inputs and advanced settings on the left, before/after slider
# and downloadable result on the right.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        # Static header with project links (no interpolation, so a plain
        # string literal is used rather than an f-string).
        gr.HTML("""
        <h2 style="text-align: center;">
            PASD Magnify
        </h2>
        <p style="text-align: center;">
            Pixel-Aware Stable Diffusion for Realistic Image Super-resolution and Personalized Stylization
        </p>
        <p id="project-links" align="center">
            <a href='https://github.com/yangxy/PASD'><img src='https://img.shields.io/badge/Project-Page-Green'></a> <a href='https://huggingface.co/papers/2308.14469'><img src='https://img.shields.io/badge/Paper-Arxiv-red'></a>
        </p>
        <p style="margin:12px auto;display: flex;justify-content: center;">
            <a href="https://huggingface.co/spaces/fffiloni/PASD?duplicate=true"><img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-lg.svg" alt="Duplicate this Space"></a>
        </p>

        """)
        with gr.Row():
            with gr.Column():
                # type="filepath": `inference` receives a path string.
                input_image = gr.Image(type="filepath", sources=["upload"], value="samples/frog.png")
                prompt_in = gr.Textbox(label="Prompt", value="Frog")
                with gr.Accordion(label="Advanced settings", open=False):
                    added_prompt = gr.Textbox(label="Added Prompt", value='clean, high-resolution, 8k, best quality, masterpiece')
                    neg_prompt = gr.Textbox(label="Negative Prompt",value='dotted, noise, blur, lowres, oversmooth, longbody, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
                    denoise_steps = gr.Slider(label="Denoise Steps", minimum=10, maximum=50, value=20, step=1)
                    upsample_scale = gr.Slider(label="Upsample Scale", minimum=1, maximum=4, value=2, step=1)
                    condition_scale = gr.Slider(label="Conditioning Scale", minimum=0.5, maximum=1.5, value=1.1, step=0.1)
                    # Label typo fixed: "Classier-free" -> "Classifier-free".
                    classifier_free_guidance = gr.Slider(label="Classifier-free Guidance", minimum=0.1, maximum=10.0, value=7.5, step=0.1)
                    seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
                submit_btn = gr.Button("Submit")
            with gr.Column():
                b_a_slider = ImageSlider(label="B/A result", position=0.5)
                file_output = gr.File(label="Downloadable image result")

        # Input order must match the positional parameters of `inference`.
        submit_btn.click(
            fn=inference,
            inputs=[
                input_image, prompt_in,
                added_prompt, neg_prompt,
                denoise_steps,
                upsample_scale, condition_scale,
                classifier_free_guidance, seed,
            ],
            outputs=[
                b_a_slider,
                file_output,
            ],
        )

# Queue at most 20 pending requests, then start the server.
demo.queue(max_size=20).launch()