# VARestorer/tools/infer4eval.py
import os
import os.path as osp
import hashlib
import time
import argparse
import json
import shutil
import glob
import re
import sys
import cv2
from tqdm.auto import tqdm
import torch
import numpy as np
from pytorch_lightning import seed_everything
from run_varestorer import *
from conf import HF_TOKEN, HF_HOME
from transformers import BlipForConditionalGeneration, BlipProcessor
# set environment variables
os.environ['HF_TOKEN'] = HF_TOKEN
os.environ['HF_HOME'] = HF_HOME
os.environ['XFORMERS_FORCE_DISABLE_TRITON'] = '1'
if __name__ == '__main__':
parser = argparse.ArgumentParser()
add_common_arguments(parser)
parser.add_argument('--out_dir', type=str, default='')
parser.add_argument('--n_samples', type=int, default=1)
parser.add_argument('--metadata_file', type=str, default='evaluation/image_reward/benchmark-prompts.json')
parser.add_argument('--rewrite_prompt', type=int, default=0, choices=[0,1])
###
    parser.add_argument('--noise_apply_layers', type=int, default=0)
    parser.add_argument('--noise_apply_requant', type=int, default=1)
    parser.add_argument('--noise_apply_strength', type=float, default=0.3)
    parser.add_argument('--debug_bsc', type=int, default=0)
###
args = parser.parse_args()
    # parse cfg: accepts a single value or a comma-separated list
args.cfg = list(map(float, args.cfg.split(',')))
if len(args.cfg) == 1:
args.cfg = args.cfg[0]
with open(args.metadata_file) as fp:
metadatas = json.load(fp)
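    # Load the generative backbone selected by args.model_type (the flag itself comes from add_common_arguments).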
if args.model_type == 'sdxl':
from diffusers import DiffusionPipeline
base = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
).to("cuda")
refiner = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-refiner-1.0",
text_encoder_2=base.text_encoder_2,
vae=base.vae,
torch_dtype=torch.float16,
use_safetensors=True,
variant="fp16",
).to("cuda")
elif args.model_type == 'sd3':
from diffusers import StableDiffusion3Pipeline
pipe = StableDiffusion3Pipeline.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16)
pipe = pipe.to("cuda")
elif args.model_type == 'pixart_sigma':
from diffusers import PixArtSigmaPipeline
pipe = PixArtSigmaPipeline.from_pretrained(
"PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", torch_dtype=torch.float16
).to("cuda")
elif args.model_type == 'flux_1_dev':
from diffusers import FluxPipeline
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")
elif args.model_type == 'flux_1_dev_schnell':
from diffusers import FluxPipeline
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to("cuda")
elif 'infinity' in args.model_type:
# load text encoder
text_tokenizer, text_encoder = load_tokenizer(t5_path=args.text_encoder_ckpt)
# load vae
vae = load_visual_tokenizer(args)
# load infinity
infinity = load_transformer(vae, args)
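    # Optional prompt rewriting via tools.prompt_rewriter.PromptRewriter (enabled with --rewrite_prompt 1).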
if args.rewrite_prompt:
from tools.prompt_rewriter import PromptRewriter
prompt_rewriter = PromptRewriter(system='', few_shot_history=[])
    os.makedirs(args.out_dir, exist_ok=True)
save_metadatas = []
#####
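    # BLIP image-captioning model; it is passed to the generation helper together with lq_img_path, presumably to caption the low-quality input when no prompt is available.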
blip_processor = BlipProcessor.from_pretrained("weights/blip-image-captioning-large")
blip_model = BlipForConditionalGeneration.from_pretrained("weights/blip-image-captioning-large", torch_dtype=torch.float16).to("cuda")
#####
#####
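    # SwinIR network built from an inline config and frozen pretrained weights; it is handed to gen_one_img_eval_long, likely as a preliminary restorer for the low-quality input.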
swinir_config = {
"target": "infinity.models.swinir.SwinIR",
"params": {
"img_size": 64,
"patch_size": 1,
"in_chans": 3,
"embed_dim": 180,
"depths": [6, 6, 6, 6, 6, 6, 6, 6],
"num_heads": [6, 6, 6, 6, 6, 6, 6, 6],
"window_size": 8,
"mlp_ratio": 2,
"sf": 8,
"img_range": 1.0,
"upsampler": "nearest+conv",
"resi_connection": "1conv",
"unshuffle": True,
"unshuffle_scale": 8
}
}
swinir: SwinIR = instantiate_from_config(swinir_config)
sd = torch.load('weights/general_swinir_v1.ckpt', map_location="cpu")
if "state_dict" in sd:
sd = sd["state_dict"]
sd = {
(k[len("module.") :] if k.startswith("module.") else k): v
for k, v in sd.items()
}
swinir.load_state_dict(sd, strict=True)
for p in swinir.parameters():
p.requires_grad = False
swinir.eval().to("cuda")
#####
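    # Main evaluation loop: one low-quality image (and optional prompt) per metadata entry.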
    for index, metadata in tqdm(enumerate(metadatas), total=len(metadatas)):
seed_everything(args.seed)
lq_img_path = metadata['lq_img_path']
prompt = metadata.get('prompt', None)
        img_name = os.path.basename(lq_img_path)  # keep the LQ image's file name for the output
sample_path = os.path.join(args.out_dir, img_name)
tau = args.tau
cfg = args.cfg
if args.rewrite_prompt:
refined_prompt = prompt_rewriter.rewrite(prompt)
input_key_val = extract_key_val(refined_prompt)
prompt = input_key_val['prompt']
print(f'prompt: {prompt}, refined_prompt: {refined_prompt}')
images = []
        bitwise_self_correction = BitwiseSelfCorrection(vae, args)
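        # Per-sample generation; t1/t2 bracket the generation time (currently not logged).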
for _ in range(args.n_samples): #####n_samples==1
t1 = time.time()
if args.model_type == 'sdxl':
image = base(
prompt=prompt,
num_inference_steps=40,
denoising_end=0.8,
output_type="latent",
).images
image = refiner(
prompt=prompt,
num_inference_steps=40,
denoising_start=0.8,
image=image,
).images[0]
elif args.model_type == 'sd3':
image = pipe(
prompt,
negative_prompt="",
num_inference_steps=28,
guidance_scale=7.0,
num_images_per_prompt=1,
).images[0]
elif args.model_type == 'flux_1_dev':
image = pipe(
prompt,
height=1024,
width=1024,
guidance_scale=3.5,
num_inference_steps=50,
max_sequence_length=512,
num_images_per_prompt=1,
).images[0]
elif args.model_type == 'flux_1_dev_schnell':
image = pipe(
prompt,
height=1024,
width=1024,
guidance_scale=0.0,
num_inference_steps=4,
max_sequence_length=256,
generator=torch.Generator("cpu").manual_seed(0)
).images[0]
elif args.model_type == 'pixart_sigma':
image = pipe(prompt).images[0]
elif 'infinity' in args.model_type:
h_div_w_template = 1.000
scale_schedule = dynamic_resolution_h_w[h_div_w_template][args.pn]['scales']
scale_schedule = [(1, h, w) for (_, h, w) in scale_schedule]
tgt_h, tgt_w = dynamic_resolution_h_w[h_div_w_template][args.pn]['pixel']
# image,prompt = gen_one_img_eval(infinity,
# vae,
# text_tokenizer,
# text_encoder,
# prompt,
# tau_list=tau,
# cfg_sc=3,
# cfg_list=cfg,
# scale_schedule=scale_schedule,
# cfg_insertion_layer=[args.cfg_insertion_layer],
# vae_type=args.vae_type,
# lq_img_path=lq_img_path,
# args=args,
# blip_model=blip_model,
# blip_processor=blip_processor,
# )
image,prompt = gen_one_img_eval_long(infinity,
vae,
text_tokenizer,
text_encoder,
prompt,
tau_list=tau,
cfg_sc=3,
cfg_list=cfg,
scale_schedule=scale_schedule,
cfg_insertion_layer=[args.cfg_insertion_layer],
vae_type=args.vae_type,
lq_img_path=lq_img_path,
args=args,
blip_model=blip_model,
blip_processor=blip_processor,
swinir=swinir,
bitwise_self_correction=bitwise_self_correction
)
            else:
                raise ValueError(f'unsupported model_type: {args.model_type}')
t2 = time.time()
images.append(image)
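        # Write the generated image(s); every sample shares the same sample_path, so only the last one is kept when n_samples > 1.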
for i, image in enumerate(images):
if 'infinity' in args.model_type:
cv2.imwrite(sample_path, image.cpu().numpy())
else:
image.save(sample_path)
        metadata['prompt'] = prompt
save_metadatas.append(metadata)
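        # Persist the metadata with the prompt that was actually used, next to the input metadata file.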
save_metadata_file_path = os.path.join(os.path.dirname(args.metadata_file), "metadata_w_prompt.json")
with open(save_metadata_file_path, "w") as fp:
json.dump(save_metadatas, fp)
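# Example invocation (illustrative only: shared flags such as --model_type, --seed, --tau, --cfg
# and the checkpoint paths are defined by add_common_arguments in run_varestorer.py, so consult
# that file for the exact names and valid values):
#   python tools/infer4eval.py \
#       --model_type infinity \
#       --out_dir results/infer4eval \
#       --metadata_file evaluation/image_reward/benchmark-prompts.json \
#       --n_samples 1 --rewrite_prompt 0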