| import os |
| import os.path as osp |
| import hashlib |
| import time |
| import argparse |
| import json |
| import shutil |
| import glob |
| import re |
| import sys |
|
|
| import cv2 |
| from tqdm.auto import tqdm |
| import torch |
| import numpy as np |
| from pytorch_lightning import seed_everything |
|
|
| from run_varestorer import * |
| from conf import HF_TOKEN, HF_HOME |
| from transformers import BlipForConditionalGeneration,BlipProcessor |
|
|
| |
# Configure Hugging Face credentials/cache and disable xFormers' Triton
# backend before any model is loaded.
os.environ.update({
    'HF_TOKEN': HF_TOKEN,
    'HF_HOME': HF_HOME,
    'XFORMERS_FORCE_DISABLE_TRITON': '1',
})
|
|
|
|
if __name__ == '__main__':
    # CLI: shared generation flags (add_common_arguments) plus the
    # script-specific options below.
    parser = argparse.ArgumentParser()
    add_common_arguments(parser)
    parser.add_argument('--out_dir', type=str, default='')
    parser.add_argument('--n_samples', type=int, default=1)
    parser.add_argument('--metadata_file', type=str, default='evaluation/image_reward/benchmark-prompts.json')
    parser.add_argument('--rewrite_prompt', type=int, default=0, choices=[0,1])

    parser.add_argument('--noise_apply_layers', type=int, default=0)
    parser.add_argument('--noise_apply_requant', type=int, default=1)
    parser.add_argument('--noise_apply_strength', type=float, default=0.3)
    parser.add_argument('--debug_bsc', type=int, default=0)

    args = parser.parse_args()

    # '--cfg' arrives as a comma-separated string; a single value collapses
    # to a scalar, multiple values stay a list.
    cfg_values = [float(v) for v in args.cfg.split(',')]
    args.cfg = cfg_values[0] if len(cfg_values) == 1 else cfg_values

    # Each metadata record describes one low-quality input image (and,
    # optionally, its prompt).
    with open(args.metadata_file) as fp:
        metadatas = json.load(fp)
|
|
    # Instantiate the requested backbone; only the branch matching
    # args.model_type runs, and every pipeline is placed on the first GPU.
    if args.model_type == 'sdxl':
        from diffusers import DiffusionPipeline
        # SDXL ensemble-of-experts: the base emits latents which the refiner
        # finishes; text_encoder_2/vae are shared to save GPU memory.
        base = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
        ).to("cuda")
        refiner = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-refiner-1.0",
            text_encoder_2=base.text_encoder_2,
            vae=base.vae,
            torch_dtype=torch.float16,
            use_safetensors=True,
            variant="fp16",
        ).to("cuda")
    elif args.model_type == 'sd3':
        from diffusers import StableDiffusion3Pipeline
        pipe = StableDiffusion3Pipeline.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16)
        pipe = pipe.to("cuda")
    elif args.model_type == 'pixart_sigma':
        from diffusers import PixArtSigmaPipeline
        pipe = PixArtSigmaPipeline.from_pretrained(
            "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", torch_dtype=torch.float16
        ).to("cuda")
    elif args.model_type == 'flux_1_dev':
        from diffusers import FluxPipeline
        pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")
    elif args.model_type == 'flux_1_dev_schnell':
        from diffusers import FluxPipeline
        # NOTE(review): despite the 'flux_1_dev_schnell' name this loads the
        # plain FLUX.1-schnell checkpoint — confirm that is intended.
        pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to("cuda")
    elif 'infinity' in args.model_type:
        # Infinity path: T5 text tokenizer/encoder, visual tokenizer (vae)
        # and autoregressive transformer (helpers from run_varestorer).
        text_tokenizer, text_encoder = load_tokenizer(t5_path=args.text_encoder_ckpt)

        vae = load_visual_tokenizer(args)

        infinity = load_transformer(vae, args)
    if args.rewrite_prompt:
        # Optional prompt refinement, applied per image in the main loop.
        from tools.prompt_rewriter import PromptRewriter
        prompt_rewriter = PromptRewriter(system='', few_shot_history=[])
| |
| |
    # Output directory for restored images (one file per metadata entry).
    os.makedirs(args.out_dir,exist_ok=True)
    # Collects metadata records with the prompt that was actually used.
    save_metadatas = []

    # BLIP captioner — presumably used to caption the LQ input when no
    # prompt is provided; consumed inside gen_one_img_eval_long (verify).
    blip_processor = BlipProcessor.from_pretrained("weights/blip-image-captioning-large")
    blip_model = BlipForConditionalGeneration.from_pretrained("weights/blip-image-captioning-large", torch_dtype=torch.float16).to("cuda")

    # SwinIR restoration network config (8x upsampling with pixel-unshuffle
    # input), instantiated reflectively via its dotted "target" path.
    swinir_config = {
        "target": "infinity.models.swinir.SwinIR",
        "params": {
            "img_size": 64,
            "patch_size": 1,
            "in_chans": 3,
            "embed_dim": 180,
            "depths": [6, 6, 6, 6, 6, 6, 6, 6],
            "num_heads": [6, 6, 6, 6, 6, 6, 6, 6],
            "window_size": 8,
            "mlp_ratio": 2,
            "sf": 8,
            "img_range": 1.0,
            "upsampler": "nearest+conv",
            "resi_connection": "1conv",
            "unshuffle": True,
            "unshuffle_scale": 8
        }
    }
    swinir: SwinIR = instantiate_from_config(swinir_config)
    # NOTE(review): pickle-based torch.load; torch>=2.6 flips the
    # weights_only default — confirm this checkpoint still loads there.
    sd = torch.load('weights/general_swinir_v1.ckpt', map_location="cpu")
    if "state_dict" in sd:
        sd = sd["state_dict"]
    # Strip DataParallel's 'module.' prefix so keys match the bare model.
    sd = {
        (k[len("module.") :] if k.startswith("module.") else k): v
        for k, v in sd.items()
    }
    swinir.load_state_dict(sd, strict=True)
    # Inference only: freeze parameters and move to GPU in eval mode.
    for p in swinir.parameters():
        p.requires_grad = False
    swinir.eval().to("cuda")
| |
| |
| for index, metadata in tqdm(enumerate(metadatas)): |
| seed_everything(args.seed) |
| |
| lq_img_path = metadata['lq_img_path'] |
| prompt = metadata.get('prompt', None) |
| img_name = os.path.relpath(lq_img_path, start=os.path.dirname(lq_img_path)) |
| sample_path = os.path.join(args.out_dir, img_name) |
|
|
| tau = args.tau |
| cfg = args.cfg |
| if args.rewrite_prompt: |
| refined_prompt = prompt_rewriter.rewrite(prompt) |
| input_key_val = extract_key_val(refined_prompt) |
| prompt = input_key_val['prompt'] |
| print(f'prompt: {prompt}, refined_prompt: {refined_prompt}') |
| |
| images = [] |
| bitwise_self_correction= BitwiseSelfCorrection(vae, args) |
| for _ in range(args.n_samples): |
| t1 = time.time() |
| if args.model_type == 'sdxl': |
| image = base( |
| prompt=prompt, |
| num_inference_steps=40, |
| denoising_end=0.8, |
| output_type="latent", |
| ).images |
| image = refiner( |
| prompt=prompt, |
| num_inference_steps=40, |
| denoising_start=0.8, |
| image=image, |
| ).images[0] |
| elif args.model_type == 'sd3': |
| image = pipe( |
| prompt, |
| negative_prompt="", |
| num_inference_steps=28, |
| guidance_scale=7.0, |
| num_images_per_prompt=1, |
| ).images[0] |
| elif args.model_type == 'flux_1_dev': |
| image = pipe( |
| prompt, |
| height=1024, |
| width=1024, |
| guidance_scale=3.5, |
| num_inference_steps=50, |
| max_sequence_length=512, |
| num_images_per_prompt=1, |
| ).images[0] |
| elif args.model_type == 'flux_1_dev_schnell': |
| image = pipe( |
| prompt, |
| height=1024, |
| width=1024, |
| guidance_scale=0.0, |
| num_inference_steps=4, |
| max_sequence_length=256, |
| generator=torch.Generator("cpu").manual_seed(0) |
| ).images[0] |
| elif args.model_type == 'pixart_sigma': |
| image = pipe(prompt).images[0] |
| elif 'infinity' in args.model_type: |
| h_div_w_template = 1.000 |
| scale_schedule = dynamic_resolution_h_w[h_div_w_template][args.pn]['scales'] |
| scale_schedule = [(1, h, w) for (_, h, w) in scale_schedule] |
| tgt_h, tgt_w = dynamic_resolution_h_w[h_div_w_template][args.pn]['pixel'] |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| image,prompt = gen_one_img_eval_long(infinity, |
| vae, |
| text_tokenizer, |
| text_encoder, |
| prompt, |
| tau_list=tau, |
| cfg_sc=3, |
| cfg_list=cfg, |
| scale_schedule=scale_schedule, |
| cfg_insertion_layer=[args.cfg_insertion_layer], |
| vae_type=args.vae_type, |
| lq_img_path=lq_img_path, |
| args=args, |
| blip_model=blip_model, |
| blip_processor=blip_processor, |
| swinir=swinir, |
| bitwise_self_correction=bitwise_self_correction |
| ) |
| else: |
| raise ValueError |
| t2 = time.time() |
| images.append(image) |
| |
| |
| for i, image in enumerate(images): |
| if 'infinity' in args.model_type: |
| cv2.imwrite(sample_path, image.cpu().numpy()) |
| else: |
| image.save(sample_path) |
| |
| metadata['prompt']=prompt |
| save_metadatas.append(metadata) |
|
|
| save_metadata_file_path = os.path.join(os.path.dirname(args.metadata_file), "metadata_w_prompt.json") |
| with open(save_metadata_file_path, "w") as fp: |
| json.dump(save_metadatas, fp) |
|
|
|
|
|
|
|
|